PyTorch 训练过程中,冻结某些层的方法

前言

    深度学习训练自己的数据集时,由于计算资源,训练耗时等原因,通常我们站在巨人的肩膀上(迁移学习方式),来进行自己的数据训练。
 	在迁移学习的过程中,通常有两种方式:
 		1.直接使用预训练模型,重头到尾进行参数更新;	
 		2.首先,冻结预训练模型某些层的参数,进行初次训练,保存最好的模型结果best,然后,利用加载best与unfreeze某些层的参数,
 		再次训练更新权重参数,最后,保存结果最好的模型。
    具体选择哪种迁移学习方式,根据自己的数据集与训练初始化权重的数据集相似程度而定。以下我们主要讲解冻结某些网络层的三种方法。

冻结方法

1. require_grad=False

def freeze(model):
    for param in model.parameters():
        param.requires_grad = False

a. 冻结某一层参数

import torch
import torch.nn as nn
#lenet
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        freeze(self.conv2) # 冻结self.conv2层
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        # print(self)
        self.fc3 = nn.Linear(84, 10) 
        

b. 如果是冻结某一层上面所有层(传self)

import torch
import torch.nn as nn
#lenet
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        freeze(self) # 冻结con1 / conv2层
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10) 

2. torch.no_grad()

import torch
import torch.nn as nn
import torch.nn.functional as F 

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.maxPool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        with torch.no_grad():
            x = self.conv1(x)
            x = self.maxPool1(F.reLU(x))
            x = self.maxPool1(F.reLU(self.conv1(x)))

        x = x.view(-1, 16 * 5 * 5)  
        x = F.relu(self.fc1(x))  
        x = F.relu(self.fc2(x))  
        x = self.fc3(x)
        
        return x

3. detach

	import torch
	
 	x = torch.randn(3,4, requires_grad=True)
    print(x.requires_grad)
    y = x.mul(x)
    print(y.requires_grad)
    y1 = y.detach()
    print(y1.requires_grad)

输出结果:
True
True
False

总结

  1. requires_grad:在最开始创建Tensor时候可以设置的属性,用于表明是否追踪当前Tensor的计算操作。后面也可以通过
    requires_grad_()方法设置该参数,但是只有叶子节点才可以设置该参数。
  2. detach()方法:则是用于将某一个Tensor从计算图中分离出来。返回的是一个内存共享的Tensor。
  3. torch.no_grad():对所有包裹的计算操作进行分离。但是torch.no_grad()将会使用更少的内存,因为从包裹的开始,就表明不需要计算梯度了,因此就不需要保存中间结果

参考文献:

  1. https://blog.csdn.net/weixin_42855362/article/details/127284573
  2. https://blog.csdn.net/qq_44785998/article/details/126103165