pytorch 可视化feature map的示例代码

 更新时间:2019年08月20日 10:10:32   作者:牛丸4  
今天小编就为大家分享一篇pytorch 可视化feature map的示例代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。

不多说了,直接看代码:

import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle

from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model

model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))



class myNet(nn.Module):
  def __init__(self,pretrained_model,layers):
    super(myNet,self).__init__()
    self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
    self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
    self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])

  def forward(self,x):
    out1 = self.net1(x)
    out2 = self.net(out1)
    out3 = self.net(out2)
    return out1,out2,out3

def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model 
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。

'''
  net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) 
#  print net1 
  out1 = net1(x) 

  net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]]) 
#  print net2 
  out2 = net2(out1) 

  #net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]]) 
  #out3 = net3(out2) 

  return out1, out2
with open('test.pickle','rb') as f:
  data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
  x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
  model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])

以上这篇pytorch 可视化feature map的示例代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用python和pygame制作挡板弹球游戏

    使用python和pygame制作挡板弹球游戏

    这篇文章主要介绍了使用python和pygame制作挡板弹球游戏,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • Keras框架中的epoch、bacth、batch size、iteration使用介绍

    Keras框架中的epoch、bacth、batch size、iteration使用介绍

    这篇文章主要介绍了Keras框架中的epoch、bacth、batch size、iteration使用介绍,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python Django 后台管理之后台模型属性详解

    Python Django 后台管理之后台模型属性详解

    这篇文章主要介绍了Python Django 后台管理之后台模型属性,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-04-04
  • python学习--使用QQ邮箱发送邮件代码实例

    python学习--使用QQ邮箱发送邮件代码实例

    这篇文章主要介绍了python使用QQ邮箱发送邮件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-04-04
  • python程序 线程队列queue使用方法解析

    python程序 线程队列queue使用方法解析

    这篇文章主要介绍了python程序 线程队列queue使用方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python中表达式x += y和x = x+y 的区别详解

    Python中表达式x += y和x = x+y 的区别详解

    这篇文章主要跟大家介绍了关于Python中x += y和x = x+y 的区别的相关资料,文中通过示例代码介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友们下面来一起看看吧。
    2017-06-06
  • python网络应用开发知识点浅析

    python网络应用开发知识点浅析

    在本篇内容中小编给学习python的朋友们整理了关于网络应用开发的相关知识点以及实例内容,需要的朋友们参考下。
    2019-05-05
  • django限制匿名用户访问及重定向的方法实例

    django限制匿名用户访问及重定向的方法实例

    这篇文章主要给大家介绍了关于django限制匿名用户访问及重定向的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧。
    2018-02-02
  • 卷积神经网络如何实现提取特征

    卷积神经网络如何实现提取特征

    这篇文章主要介绍了卷积神经网络如何实现提取特征问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-04-04
  • Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度

    Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度

    Matplotlib是Python提供的一个二维绘图库,所有类型的平面图,包括直方图、散点图、折线图、点图、热图以及其他各种类型,都能由Python制作出来。本文主要介绍了关于Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度的相关资料,需要的朋友可以参考下。
    2018-04-04

最新评论