pytorch查看模型weight与grad方式

 更新时间:2020年06月24日 08:56:26   作者:YongjieShi  
这篇文章主要介绍了pytorch查看模型weight与grad方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features – C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Pygame游戏开发之太空射击实战入门篇

    Pygame游戏开发之太空射击实战入门篇

    相信大多数8090后都玩过太空射击游戏,在过去游戏不多的年代太空射击自然属于经典好玩的一款了,今天我们来自己动手实现它,在编写学习中回顾过往展望未来,下面开始入门篇
    2022-08-08
  • pytorch 搭建神经网路的实现

    pytorch 搭建神经网路的实现

    这篇文章主要介绍了pytorch 搭建神经网路,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • python中Pycharm 输出中文或打印中文乱码现象的解决办法

    python中Pycharm 输出中文或打印中文乱码现象的解决办法

    本篇文章主要介绍了python中Pycharm 输出中文或打印中文乱码现象的解决办法 ,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-06-06
  • 利用PyQt5模拟实现网页鼠标移动特效

    利用PyQt5模拟实现网页鼠标移动特效

    不知道大家有没有发现,博客园有些博客左侧会有鼠标移动特效。通过移动鼠标,会形成类似蜘蛛网的特效,本文将用PyQt5实现这一特效,需要的可以参考一下
    2022-03-03
  • Python免费试用最新Openai API的步骤

    Python免费试用最新Openai API的步骤

    本文主要介绍了Python免费试用最新Openai API,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • python3 将阶乘改成函数形式进行调用的操作

    python3 将阶乘改成函数形式进行调用的操作

    这篇文章主要介绍了python3 将阶乘改成函数形式进行调用的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • python抓取百度首页的方法

    python抓取百度首页的方法

    这篇文章主要介绍了python抓取百度首页的方法,涉及Python使用urllib模块实现页面抓取的相关技巧,需要的朋友可以参考下
    2015-05-05
  • python语音识别指南终极版(有这一篇足矣)

    python语音识别指南终极版(有这一篇足矣)

    这篇文章主要介绍了python语音识别指南终极版的相关资料,包括语音识别的工作原理及使用代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-09-09
  • pythotn条件分支与循环详解(2)

    pythotn条件分支与循环详解(2)

    这篇文章主要介绍了Python条件分支和循环用法,结合实例形式较为详细的分析了Python逻辑运算操作符,条件分支语句,循环语句等功能与基本用法,需要的朋友可以参考下
    2021-08-08
  • Python实现求取表格文件某个区域内单元格的最大值

    Python实现求取表格文件某个区域内单元格的最大值

    这篇文章主要介绍基于Python语言,基于Excel表格文件内某一列的数据,计算这一列数据在每一个指定数量的行的范围内(例如每一个4行的范围内)的区间最大值的方法,需要的朋友可以参考下
    2023-08-08

最新评论