pytorch固定BN层参数的操作

 更新时间:2021年05月27日 08:58:15   作者:grllery  
这篇文章主要介绍了pytorch固定BN层参数的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

背景:

基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:

未固定主分支BN层中的running_mean和running_var。

解决方法:

将需要固定的BN层状态设置为eval。

问题示例:

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])

补充:pytorch---之BN层参数详解及应用(1,2,3)(1,2)?

BN层参数详解(1,2)

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对于BN层测试的均值和方差是通过统计训练的时候所有的batch的均值和方差的平均值)或者Dropout层(对于Dropout层在测试的时候所有神经元都是激活的)。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。(这里是一个可学习参数)

trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性(意思就是说新的batch依赖于之前的batch的均值和方差这里使用momentum参数,参考了指数移动平均的算法EMA)。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

应用技巧:(1,2)

通常pytorch都会用到optimizer.zero_grad() 来清空以前的batch所累加的梯度,因为pytorch中Variable计算的梯度会进行累计,所以每一个batch都要重新清空一次梯度,原始的做法是下面这样的:

问题:参数non_blocking,以及pytorch的整体框架??

代码(1)

for index,data,target in enumerate(dataloader):
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)
    output = model(data)
    loss = criterion(output,target)
    
    #清空梯度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

而这里为了模仿minibacth,我们每次batch不清0,累积到一定次数再清0,再更新权重:

for index, data, target in enumerate(dataloader):
    #如果不是Tensor,一般要用到torch.from_numpy()
    data = data.cuda(non_blocking = True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)
    output = model(data)
    loss = criterion(data, target)
    loss.backward()
    if index%accumulation == 0:
        #用累积的梯度更新权重
        optimizer.step()
        #清空梯度
        optimizer.zero_grad()

虽然这里的梯度是相当于原来的accumulation倍,但是实际在前向传播的过程中,对于BN几乎没有影响,因为前向的BN还是只是一个batch的均值和方差,这个时候可以用pytorch中BN的momentum参数,默认是0.1,BN参数如下,就是指数移动平均

x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pytorch中的广播语义

    pytorch中的广播语义

    这篇文章主要介绍了pytorch中的广播语义,pytorch的广播语义即broadcasting semantics,和numpy的很像,下面文章介绍更多相关内容的介绍,需要的小伙伴可以参考一下
    2022-03-03
  • Python读取mat文件,并转为csv文件的实例

    Python读取mat文件,并转为csv文件的实例

    今天小编就为大家分享一篇Python读取mat文件,并转为csv文件的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • pip命令无法使用的解决方法

    pip命令无法使用的解决方法

    今天小编就为大家分享一篇pip命令无法使用的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python 安装移动复制第三方库操作

    python 安装移动复制第三方库操作

    这篇文章主要介绍了python 安装移动复制第三方库操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • python 获取毫秒数,计算调用时长的方法

    python 获取毫秒数,计算调用时长的方法

    今天小编就为大家分享一篇python 获取毫秒数,计算调用时长的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • 关于Keras Dense层整理

    关于Keras Dense层整理

    这篇文章主要介绍了关于Keras Dense层整理,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python常见内置高效率函数用法示例

    Python常见内置高效率函数用法示例

    这篇文章主要介绍了Python常见内置高效率函数用法,结合实例形式分析了Python中filter()、map()、reduce()、lambda匿名函数等功能与简单使用技巧,需要的朋友可以参考下
    2018-07-07
  • python3读取csv和xlsx文件的实例

    python3读取csv和xlsx文件的实例

    今天小编就为大家分享一篇python3读取csv和xlsx文件的实例,具有很好的参考价值,希望对的大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 详解pandas赋值失败问题解决

    详解pandas赋值失败问题解决

    这篇文章主要介绍了详解pandas赋值失败问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python使用 OpenCV 进行图像投影变换

    Python使用 OpenCV 进行图像投影变换

    这篇文章主要介绍了Python使用 OpenCV 进行图像投影变换,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-08-08

最新评论