pytorch打印网络结构的实例

 更新时间:2019年08月19日 14:51:34   作者:每天都要深度学习  
今天小编就为大家分享一篇pytorch打印网络结构的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python numpy下几种fft函数的使用方式

    Python numpy下几种fft函数的使用方式

    numpy中有一个fft的库,scipy中也有一个fftpack的库,各自都有fft函数,两者的用法基本是一致的,下面这篇文章主要给大家介绍了关于Python numpy下几种fft函数的使用方式,需要的朋友可以参考下
    2022-08-08
  • 利用信号如何监控Django模型对象字段值的变化详解

    利用信号如何监控Django模型对象字段值的变化详解

    这篇文章主要给大家介绍了关于利用信号如何监控Django模型对象字段值变化的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考借鉴,下面随着小编来一起学习学习吧。
    2017-11-11
  • python 执行终端/控制台命令的例子

    python 执行终端/控制台命令的例子

    今天小编就为大家分享一篇python 执行终端/控制台命令的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • 使用OpenCV实现逐帧获取视频图片

    使用OpenCV实现逐帧获取视频图片

    这篇文章主要为大家详细介绍了如何使用OpenCV实现逐帧获取视频中的图片用来标注,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2024-03-03
  • Python 机器学习之线性回归详解分析

    Python 机器学习之线性回归详解分析

    回归是监督学习的一个重要问题,回归用于预测输入变量和输出变量之间的关系,特别是当输入变量的值发生变化时,输出变量的值也随之发生变化。回归模型正是表示从输入变量到输出变量之间映射的函数
    2021-11-11
  • python代码中怎么换行

    python代码中怎么换行

    这篇文章主要介绍了python代码中怎么换行的相关知识点以及方法,需要的朋友们可以学习下。
    2020-06-06
  • python安装第三方库如xlrd的方法

    python安装第三方库如xlrd的方法

    这篇文章主要介绍了python安装第三方库如xlrd的方法,本文通过两种方法给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10
  • Python的re模块正则表达式操作

    Python的re模块正则表达式操作

    这篇文章主要介绍了Python的re模块正则表达式操作 的相关资料,需要的朋友可以参考下
    2016-05-05
  • 不可错过的十本Python好书

    不可错过的十本Python好书

    不可错过的十本Python好书,分别适合入门、进阶到精深三个不同阶段的人来阅读,感兴趣的小伙伴们可以参考一下
    2017-07-07
  • 删除python pandas.DataFrame 的多重index实例

    删除python pandas.DataFrame 的多重index实例

    今天小编就为大家分享一篇删除python pandas.DataFrame 的多重index实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06

最新评论