Pytorch 冻结参数更新的三种方法

在Pytorch中,默认情况下,所有设置requires_grad=True的张量都能跟踪它们的梯度计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们已经训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。此时,这个禁用梯度跟踪操作就显得很重要,即冻结某个变量或模块的参数更新。下面介绍三种方式实现这个操作。
第一种方法:requires_grad_(False)冻结

import torch
import torch.nn as nn
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.l1 = nn.Linear(3,3).requires_grad_(False)
        self.l2 = nn.Linear(3,3)
    def forward(self, x):
        out = self.l1(x) +self.l2(x) 
        return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
    data = torch.randn(6,3)
    out = model(x)
    loss=nn.functional.mse_loss(y,out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(list(model.named_parameters()))
    print(loss)

通过上述结果,我们会发现l1层的参数,完成没有更新,从而达到冻结的作用。可以通过list(model.named_parameters())查看参数更新情况
其次,我们查看一下,被冻结的参数是否进入学习。

model.state_dict()

我们发现这个被冻结的参数在学习层中(model.state_dict()只保存参与学习的参数),因此这是最纯粹的冻结。
第二种方法:detach()分离操作

import torch
import torch.nn as nn
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.l1 = nn.Linear(3,3)
        self.l2 = nn.Linear(3,3)
    def forward(self, x):
        out = self.l1(x).detach()
        out = out+self.l2(x) 
        return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
    data = torch.randn(6,3)
    out = model(x)
    loss=nn.functional.mse_loss(y,out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(list(model.named_parameters()))
    print(loss)
model.state_dict()

我们发现这个被冻结的参数同样在学习层中,因此detach()与requires_grad()的冻结方法是一样的。
但是从形式上看,第一种冻结方法要更直观易懂
同样可以达到相应的功能,但是detach()只能放于前向中,因为nn.Module没有detach属性,只有张量才这个属性。
第三种方法:with torch.no_grad()环境

import torch
import torch.nn as nn
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        with torch.no_grad():  
            self.l1 = nn.Linear(3,3)
        self.l2 = nn.Linear(3,3)
    def forward(self, x):
        out = self.l1(x)+self.l2(x) 
        return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
    data = torch.randn(6,3)
    out = model(x)
    loss=nn.functional.mse_loss(y,out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(list(model.named_parameters()))
    print(f'Loss is : {loss}')
model.state_dict()

特别注意:运行这个程序,我们会发现,l1的梯度仍然在更新,那是因为with torch.no_grad()这个环境对工厂函数没有作用(参见官网说明:https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad)。像nn.Linear,nn.Parameters这样的函数均为工厂函数,如果像detach一样,放在forward呢?则就能冻结。

import torch
import torch.nn as nn
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.l1 = nn.Linear(3,3)
        self.l2 = nn.Linear(3,3)
    def forward(self, x):
        with torch.no_grad():  
            out = self.l1(x)
        out = out+self.l2(x) 
        return out
model = my_model()
y=torch.rand(6,3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for i in range(2):
    data = torch.randn(6,3)
    out = model(x)
    loss=nn.functional.mse_loss(y,out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(list(model.named_parameters()))
    print(f'Loss is : {loss}')
model.state_dict()

总结:上面介绍了三种冻结参数的方法,如果正确使用,三种方法的结果是一样。但是,如果是冻结模块参数,建议使用requires_grad_(False)冻结,这种方法十分直观易懂。如果是冻结某个变量的参数更新,建议使用detach()和with torch.no_grad()。当然,这三种方法均可以放在训练过程,冻结某个部分或变量的参数更新。