分享Pytorch获取中间层输出的3种方法

 更新时间:2022年03月09日 17:24:20   作者:我是蛋蛋  
这篇文章主要给大家分享了Pytorch获取中间层输出的3种方法,文章内容介绍详细,需要的小伙伴可以参考一下,希望对你的学习或工作有所帮助

【1】方法一:获取nn.Sequential的中间层输出

import torch
import torch.nn as nn
model = nn.Sequential(
            nn.Conv2d(3, 9, 1, 1, 0, bias=False),
            nn.BatchNorm2d(9),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

# 假如想要获得ReLu的输出
x = torch.rand([2, 3, 224, 224])
for i in range(len(model)):
    x = model[i](x)
    if i == 2:
        ReLu_out = x
print('ReLu_out.shape:\n\t',ReLu_out.shape)
print('x.shape:\n\t',x.shape)

结果:

ReLu_out.shape:
  torch.Size([2, 9, 224, 224])
x.shape:
  torch.Size([2, 9, 1, 1])

【2】方法二:IntermediateLayerGetter

from collections import OrderedDict
 
import torch
from torch import nn
 
 
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    
    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
 
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break
 
        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers
 
    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
# example
m = torchvision.models.resnet18(pretrained=True)
# extract layer1 and layer3, giving as names `feat1` and feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

作用:

在定义它的时候注明作用的模型(如下例中的m)和要返回的layer(如下例中的layer1,layer3),得到new_m。

使用时喂输入变量,返回的就是对应的layer

举例:

m = torchvision.models.resnet18(pretrained=True)
 # extract layer1 and layer3, giving as names `feat1` and feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])

输出结果:

[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

【3】方法三:钩子

class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()

        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

features_in_hook = []
features_out_hook = []

def hook(module, fea_in, fea_out):
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

net = TestForHook()

第一种写法,按照类型勾,但如果有重复类型的layer比较复杂

net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

推荐下面我改的这种写法,因为我自己的网络中,在Sequential中有很多层,
这种方式可以直接先print(net)一下,找出自己所需要那个layer的名称,按名称勾出来

layer_name = 'relu_6'
for (name, module) in net.named_modules():
    if name == layer_name:
        module.register_forward_hook(hook=hook)

print(features_in_hook)  # 勾的是指定层的输入
print(features_out_hook)  # 勾的是指定层的输出

到此这篇关于分享Pytorch获取中间层输出的3种方法的文章就介绍到这了,更多相关Pytorch获取中间层输出方法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 用Python爬取LOL所有的英雄信息以及英雄皮肤的示例代码

    用Python爬取LOL所有的英雄信息以及英雄皮肤的示例代码

    这篇文章主要介绍了用Python爬取LOL所有的英雄信息以及英雄皮肤的示例代码,主要分为两部分,获取网页上数据和图片保存到本地等,感兴趣的可以了解一下
    2020-07-07
  • 没编程基础可以学python吗

    没编程基础可以学python吗

    在本篇文章里小编给大家整理的是关于没编程基础可以学python吗的相关知识点,需要的朋友们可以学习下。
    2020-06-06
  • Django中Middleware中的函数详解

    Django中Middleware中的函数详解

    这篇文章主要介绍了Django中Middleware中的函数详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • python 布尔操作实现代码

    python 布尔操作实现代码

    python布尔操作也是我们经常写代码需要用到的,首先我们需要明白在python里面,哪些被解释器当做真,哪些当做假
    2013-03-03
  • python爬取w3shcool的JQuery课程并且保存到本地

    python爬取w3shcool的JQuery课程并且保存到本地

    本文主要介绍python爬取w3shcool的JQuery的课程并且保存到本地的方法解析。具有很好的参考价值。下面跟着小编一起来看下吧
    2017-04-04
  • 基于python的MD5脚本开发思路

    基于python的MD5脚本开发思路

    这篇文章主要介绍了基于python的MD5脚本,通过 string模块自动生成字典,使用permutations()函数,对字典进行全排列,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2022-03-03
  • 关于Python数据处理中的None、NULL和NaN的理解与应用

    关于Python数据处理中的None、NULL和NaN的理解与应用

    这篇文章主要介绍了关于Python数据处理中的None、NULL和NaN的理解与应用,None表示空值,一个特殊Python对象,None的类型是NoneType,需要的朋友可以参考下
    2023-08-08
  • Python3.10和Python3.9版本之间的差异介绍

    Python3.10和Python3.9版本之间的差异介绍

    大家好,本篇文章主要讲的是Python3.10和Python3.9版本之间的差异介绍,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下哦
    2021-12-12
  • Python Flask搭建yolov3目标检测系统详解流程

    Python Flask搭建yolov3目标检测系统详解流程

    YOLOv3没有太多的创新,主要是借鉴一些好的方案融合到YOLO里面。不过效果还是不错的,在保持速度优势的前提下,提升了预测精度,尤其是加强了对小物体的识别能力
    2021-11-11
  • Python3实现打格点算法的GPU加速实例详解

    Python3实现打格点算法的GPU加速实例详解

    这篇文章主要给大家介绍了关于Python3实现打格点算法的GPU加速的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友可以参考下
    2021-09-09

最新评论