PyTorch之关于hook机制

 更新时间:2023年08月02日 15:35:56   作者:harry_tea  
这篇文章主要介绍了PyTorch之关于hook机制的理解,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

PyTorch: hook机制

在训练神经网络的时候我们有时需要输出网络中间层,一般来说我们有两种处理方法:

一种是在model的forward中保存中间层的变量,然后再return的时候将其和结果一起返回;

另一种是使用pytorch自带的register_forward_hook,即hook机制

register_forward_hook

register_forward_hook(hook)
  • 返回module中的一个前向的hook,这个hook每次在执行forward的时候都会被调用
  • hook: hook(module, input, output)

可能不是很好理解,我们直接用一个例子来说明,如下所示,首先我们将hook包装在类SaveValues中,我们现在想要获取模型Net中的l1的输入和输出,因此将model.l1存入到类中:value = SaveValues(model.l1),在类中定义一个hook_fn_act函数,此函数的作用是随着我们的register_forward_hook函数获取Net的某一层的名字,输入以及输出,在这里对应的就是model.l1, 他的输入和输出,最终我们将他获取的网络层的名字、输入以及输出保存到类SaveValues中方便我们输出

注意:hook_fn_act函数必须有三个参数,分别对应module,input以及output

import torch
import torch.nn as nn
class SaveValues():
    def __init__(self, layer):
        self.model  = None
        self.input  = None
        self.output = None
        self.grad_input  = None
        self.grad_output = None
        self.forward_hook  = layer.register_forward_hook(self.hook_fn_act)
        self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad)
    def hook_fn_act(self, module, input, output):
        self.model  = module
        self.input  = input[0]
        self.output = output
    def hook_fn_grad(self, module, grad_input, grad_output):
        self.grad_input  = grad_input[0]
        self.grad_output = grad_output[0]
    def remove(self):
        self.forward_hook.remove()
        self.backward_hook.remove()
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = nn.Linear(2, 5)
        self.l2 = nn.Linear(5, 10)
    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x
l1loss = nn.L1Loss()
model  = Net()
value  = SaveValues(model.l2)
gt = torch.ones((10,), dtype=torch.float32, requires_grad=False)
x  = torch.ones((2,), dtype=torch.float32, requires_grad=False)
y = model(x)
loss  = l1loss(y, gt)
loss.backward()
x += 1.2
value.remove()

运行上述程序,当我们运行到y = model(x)这一行时,我们看一下value中的值(图左),当我们运行完y = model(x)时,我们看一下value中的值(图右),这是因为在执行net中的forward函数时,我们的hook机制会从中提取出网络的输入和输出,不执行forward就不会提取

注意:

当我们不想在提取网络中间层时,我们调用value.remove()即可,即删除了网络中的hook。

但是在训练网络时我们可能需要输出每个epoch的中间层信息,那么在for循环中就不需要删除hook啦

register_full_backward_hook

好像这个反向hook很少用到?

register_forward_hook(hook)
  • 返回module中的一个反向的hook,这个hook每次在执行forward的时候都会被调用
  • hook: hook(module, grad_input, grad_output)

继续上述的代码,这次我们运行到loss.backward()之前与之后查看value中存储的grad的变化,如下所示,可以发现在没有反向传播之前grad为None,当我们执行反向传播之后grad就有值了

注意:

这里将layer换成了l2,因为第一层l1经过backward之后依然是左图不变,可能是第一层没有梯度?

value  = SaveValues(model.l2)  # modify here: model.l1--->model.l2

remove

关于remove其实如果显存足够可以不用remove,虽然每个epoch的时候hook的值都会变化,但是只占用一个hook的内存,除非开销很大可以考虑remove

visual

当我们的SaveValues类提取出特征图之后,就可以对value.output进行可视化啦

当然如果有需要也可以用input、output或者grad进行相应的操作

使用Pytorch的hook机制提取特征时踩的一个坑

因为项目需求,需要用DenseNet模型提取图片特征,在使用Pytorch的hook机制提取特征,调试的时候发现提取出来的特征数值上全部大于等于0。

很明显提取出来的特征是经过ReLU的。现在来看一下笔者是怎么定义hook的:

fmap_block = []
# 注册hook
def forward_hook(module, input, output):
    fmap_block.append(output)
get_feature_model = densenet121(num_classes=2, pretrained=False)
model_dict = torch.load(model_weight_path)
get_feature_model = nn.DataParallel(get_feature_model.cuda())
get_feature_model.module.features.register_forward_hook(forward_hook)

模型定义的时候因项目需求,笔者并没有使用预训练模型。而是自己训练了一个DenseNet121模型,并且使用了DataParallel进行包装。这里有两点需要注意:

1.大部分的官方模型都会分成两部分,分别是特征层和分类层。

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out

这是DenseNet模型前向传播代码,很明显就是笔者上诉说的那样。所以在使用Pytorch的hook进行提取特征的时候可以很方便的定义成这个样子:

DenseNet类实例.features.register_forward_hook(forward_hook)

2.眼尖的读者可以发现笔者的代码里并不是这样定义的,多了一个.module(这也算是一个小小的坑)。这是因为笔者使用了DataParallel进行包装模型,使之可以使用多GPU训练,下面来看一下DataParallel的源码:

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []

可以看到初始化DataParallel类的时候,将model作为一个参数传给了module,所以得多加一个.module才能定位到我们需要的feature。

看到这里,估计很多人已经发现问题在哪里了,没错,问题出现了前向传播部分,更准确的来说是relu函数。

out = F.relu(features, inplace=True)

inplace表示原地修改张量,所以经过relu层时提前放在列表中的特征张量就会被修改。两种解决方法:

将inplace置为False,这样就不会原地修改张量了。修改hook函数

def forward_hook(module, input, output):
    fmap_block.append(output.detach().cpu())

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中zipfile压缩文件模块的基本使用教程

    Python中zipfile压缩文件模块的基本使用教程

    这篇文章主要给大家介绍了关于Python中zipfile压缩文件模块的基本使用教程,文中通过示例代码介绍的非常详细,对大家学习或者使用Python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2020-06-06
  • Python全栈之学习MySQL(1)

    Python全栈之学习MySQL(1)

    这篇文章主要为大家介绍了Python全栈之MySQL,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-01-01
  • python如何实现不可变字典inmutabledict

    python如何实现不可变字典inmutabledict

    这篇文章主要介绍了python如何实现不可变字典inmutabledict,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • Python自动化办公之图片转PDF的实现

    Python自动化办公之图片转PDF的实现

    实现图片转换成PDF文档的操作方法有很多,综合对比以后感觉fpdf这个模块用起来比较方便而且代码量相当少。所以本文将利用Python语言实现图片转PDF,感兴趣的可以了解一下
    2022-04-04
  • python 3.6 tkinter+urllib+json实现火车车次信息查询功能

    python 3.6 tkinter+urllib+json实现火车车次信息查询功能

    这篇文章主要介绍了python 3.6 tkinter+urllib+json 火车车次信息查询功能,本文以查询火车车次至南京的信息为例,需要的朋友可以参考下
    2017-12-12
  • Python安装Matplotlib包完整步骤记录

    Python安装Matplotlib包完整步骤记录

    这篇文章主要给大家介绍了关于Python安装Matplotlib包的相关资料,Matplotlib是一个Python 2D绘图库,它以多种硬拷贝格式和跨平台的交互式环境生成出版物质量的图形,需要的朋友可以参考下
    2023-12-12
  • Python标准库之Sys模块使用详解

    Python标准库之Sys模块使用详解

    这篇文章主要介绍了Python标准库之Sys模块使用详解,本文讲解了使用sys模块获得脚本的参数、处理模块、使用sys模块操作模块搜索路径、使用sys模块查找内建模块、使用sys模块查找已导入的模块等使用案例,需要的朋友可以参考下
    2015-05-05
  • 基于Pytorch实现逻辑回归

    基于Pytorch实现逻辑回归

    这篇文章主要为大家详细介绍了基于Pytorch实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-07-07
  • python 如何通过KNN来填充缺失值

    python 如何通过KNN来填充缺失值

    这篇文章主要介绍了python 通过KNN来填充缺失值的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python 测试框架unittest和pytest的优劣

    Python 测试框架unittest和pytest的优劣

    这篇文章主要介绍了Python 测试框架unittest和pytest的优劣,帮助大家更好的进行python程序的测试,感兴趣的朋友可以了解下
    2020-09-09

最新评论