对Pytorch中nn.ModuleList 和 nn.Sequential详解

 更新时间:2019年08月18日 08:58:05   作者:ustc_lijia  
今天小编就为大家分享一篇对Pytorch中nn.ModuleList 和 nn.Sequential详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python自定义模块的创建与使用

    Python自定义模块的创建与使用

    这篇文章主要给大家介绍了关于Python自定义模块创建与使用的相关资料,文中还给大家分享了python打包用户自定义模块的方法,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • pytorch 常用线性函数详解

    pytorch 常用线性函数详解

    今天小编就为大家分享一篇pytorch 常用线性函数详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 解决jupyter notebook 出现In[*]的问题

    解决jupyter notebook 出现In[*]的问题

    这篇文章主要介绍了解决jupyter notebook 出现In[*]的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • python处理文本文件并生成指定格式的文件

    python处理文本文件并生成指定格式的文件

    本节主要介绍了python如何处理文本文件并生成指定格式的文件,需要的朋友可以参考下
    2014-07-07
  • 用Python的绘图库(matplotlib)绘制小波能量谱

    用Python的绘图库(matplotlib)绘制小波能量谱

    这篇文章主要介绍了用Python的绘图库(matplotlib)绘制小波能量谱,代码简单详细,思路清晰,需要的朋友可以参考下
    2021-04-04
  • python3如何获取子线程中函数返回值

    python3如何获取子线程中函数返回值

    这篇文章主要介绍了python3如何获取子线程中函数返回值问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-11-11
  • pytorch加载自己的数据集源码分享

    pytorch加载自己的数据集源码分享

    这篇文章主要介绍了pytorch加载自己的数据集源码分享,标准的数据集流程梳理分为数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用,需要的朋友可以参考下
    2022-08-08
  • Python3使用 GitLab API 进行批量合并分支

    Python3使用 GitLab API 进行批量合并分支

    这篇文章主要介绍了Python3使用 GitLab API 进行批量合并分支的思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10
  • Python pandas用法最全整理

    Python pandas用法最全整理

    在本篇文章里小编给大家分享的是关于Python pandas用法以及相关实例代码,需要的朋友们可以学习下。
    2019-08-08
  • Python 制作查询商品历史价格的小工具

    Python 制作查询商品历史价格的小工具

    这篇文章主要介绍了Python 如何制作查询商品历史价格的小工具,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-10-10

最新评论