pytorch中构建模型的3种方法详解

 更新时间:2023年09月22日 16:00:20   作者:hxh207  
这篇文章主要介绍了pytorch中构建模型的3种方法,分别是使用继承nn.Module基类构建自定义模型,使用nn.Sequential按层顺序构建模型或者,继承nn.Module基类构建模型并辅助应用模型容器进行封装(nn.Sequential,nn.ModuleList,nn.ModuleDict),需要的朋友可以参考下

可以使用以下3种方式构建模型:

1,继承nn.Module基类构建自定义模型。

2,使用nn.Sequential按层顺序构建模型。

3,继承nn.Module基类构建模型并辅助应用模型容器进行封装(nn.Sequential,nn.ModuleList,nn.ModuleDict)。

其中 第1种方式最为常见,第2种方式最简单,第3种方式最为灵活也较为复杂。推荐使用第1种方式构建模型。

一、继承nn.Module基类构建自定义模型

以下是继承nn.Module基类构建自定义模型的一个范例。模型中的用到的层一般在__init__函数中定义,然后在forward方法中定义模型的正向传播逻辑

from torch import nn 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)
        self.pool1 = nn.MaxPool2d(kernel_size = 2,stride = 2)
        self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)
        self.pool2 = nn.MaxPool2d(kernel_size = 2,stride = 2)
        self.dropout = nn.Dropout2d(p = 0.1)
        self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64,32)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(32,1)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.dropout(x)
        x = self.adaptive_pool(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        y = self.linear2(x)
        return y
net = Net()
print(net)

image.png

from torchkeras import summary 
summary(net,input_shape= (3,32,32));

nn.Conv1d:普通一维卷积,常用于文本。参数个数 = 输入通道数×卷积核尺寸(如3)×卷积核个数 + 卷积核尺寸(如3)=卷积核尺寸(如3乘3)x输出通道数+输出通道数(偏置数量)

nn.Conv2d:普通二维卷积,常用于图像。参数个数 = 输入通道数×卷积核尺寸(如3乘3)×卷积核个数 + 卷积核尺寸(如3乘3)。=卷积核尺寸(如3乘3)x输入通道数x输出通道数+输出通道数(偏置数量)) 通过调整dilation参数大于1,可以变成空洞卷积,增加感受野。 通过调整groups参数不为1,可以变成分组卷积。分组卷积中每个卷积核仅对其对应的一个分组进行操作。 当groups参数数量等于输入通道数时,相当于tensorflow中的二维深度卷积层tf.keras.layers.DepthwiseConv2D。 利用分组卷积和1乘1卷积的组合操作,可以构造相当于Keras中的二维深度可分离卷积层tf.keras.layers.SeparableConv2D。

image.png

二、使用nn.Sequential按层顺序构建模型

使用nn.Sequential按层顺序构建模型无需定义forward方法。仅仅适合于简单的模型。以下是使用nn.Sequential搭建模型的一些等价方法。

利用add_module方法

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,1))
print(net)

image.png

利用变长参数

这种方式构建时不能给每个层指定名称。

net = nn.Sequential(
    nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
    nn.MaxPool2d(kernel_size = 2,stride = 2),
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
    nn.MaxPool2d(kernel_size = 2,stride = 2),
    nn.Dropout2d(p = 0.1),
    nn.AdaptiveMaxPool2d((1,1)),
    nn.Flatten(),
    nn.Linear(64,32),
    nn.ReLU(),
    nn.Linear(32,1)
)
print(net)

image.png

利用OrderedDict

键值对形式:键为层的名字,值为层的定义

from collections import OrderedDict
net = nn.Sequential(OrderedDict(
          [("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)),
            ("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)),
            ("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)),
            ("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2)),
            ("dropout",nn.Dropout2d(p = 0.1)),
            ("adaptive_pool",nn.AdaptiveMaxPool2d((1,1))),
            ("flatten",nn.Flatten()),
            ("linear1",nn.Linear(64,32)),
            ("relu",nn.ReLU()),
            ("linear2",nn.Linear(32,1))
          ])
        )
print(net)

image.png

三、继承nn.Module基类构建模型并辅助应用模型容器进行封装

当模型的结构比较复杂时,我们可以应用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict)对模型的部分结构进行封装。

这样做会让模型整体更加有层次感,有时候也能减少代码量。(复杂模型的时候比较常用)注意,在下面的范例中我们每次仅仅使用一种模型容器,但实际上这些模型容器的使用是非常灵活的,可以在一个模型中任意组合任意嵌套使用。

相当于结合以上两种方式。

nn.Sequential作为模型容器

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1))
        )
        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,1)
        )
    def forward(self,x):
        x = self.conv(x)
        y = self.dense(x)
        return y 
net = Net()
print(net)

image.png

nn.ModuleList作为模型容器

注意下面中的ModuleList不能用Python中的列表代替。(即不用省略)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,1)]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
net = Net()
print(net)

image.png

nn.ModuleDict作为模型容器

注意下面中的ModuleDict不能用Python中的字典代替。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers_dict = nn.ModuleDict({"conv1":nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
               "pool": nn.MaxPool2d(kernel_size = 2,stride = 2),
               "conv2":nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
               "dropout": nn.Dropout2d(p = 0.1),
               "adaptive":nn.AdaptiveMaxPool2d((1,1)),
               "flatten": nn.Flatten(),
               "linear1": nn.Linear(64,32),
               "relu":nn.ReLU(),
               "linear2": nn.Linear(32,1)
              })
    def forward(self,x):
        layers = ["conv1","pool","conv2","pool","dropout","adaptive",
                  "flatten","linear1","relu","linear2","sigmoid"]
        for layer in layers:
            x = self.layers_dict[layer](x) # 只找有的 sigmoid是没有的
        return x
net = Net()
print(net)

image.png

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

到此这篇关于pytorch中构建模型的3种方法的文章就介绍到这了,更多相关pytorch构建模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:

相关文章

  • python 数字类型和字符串类型的相互转换实例

    python 数字类型和字符串类型的相互转换实例

    今天小编就为大家分享一篇python 数字类型和字符串类型的相互转换实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • Python使用flask框架操作sqlite3的两种方式

    Python使用flask框架操作sqlite3的两种方式

    这篇文章主要介绍了Python使用flask框架操作sqlite3的两种方式,结合实例形式分析了Python基于flask框架操作sqlite3数据库的两种常用操作技巧,需要的朋友可以参考下
    2018-01-01
  • python协程之yield和yield from实例详解

    python协程之yield和yield from实例详解

    Python在并发处理上不仅提供了多进程和多线程的处理,还包括了协程,下面这篇文章主要给大家介绍了关于python协程之yield和yield from的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-12-12
  • 解决python2中unicode()函数在python3中报错的问题

    解决python2中unicode()函数在python3中报错的问题

    这篇文章主要介绍了在python2中unicode()函数在python3中报错的解决方案,希望给大家做个参考,下次出现这个问题的时候,也知道如何应对
    2021-05-05
  • python实现Dijkstra算法的最短路径问题

    python实现Dijkstra算法的最短路径问题

    这篇文章主要介绍了python实现Dijkstra算法的最短路径问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-06-06
  • 浅谈Python type的使用

    浅谈Python type的使用

    今天小编就为大家分享一篇浅谈Python type的使用,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • 利用PyQt中的QThread类实现多线程

    利用PyQt中的QThread类实现多线程

    本文主要给大家分享的是python实现多线程及线程间通信的简单方法,非常的实用,有需要的小伙伴可以参考下
    2020-02-02
  • pandas 转换成行列表进行读取与Nan处理的方法

    pandas 转换成行列表进行读取与Nan处理的方法

    今天小编就为大家分享一篇pandas 转换成行列表进行读取与Nan处理的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • PyTorch中关于tensor.repeat()的使用

    PyTorch中关于tensor.repeat()的使用

    这篇文章主要介绍了PyTorch中关于tensor.repeat()的使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • 如何使用python-dotenv解决代码与敏感信息的分离

    如何使用python-dotenv解决代码与敏感信息的分离

    我们开发的每个系统都离不开配置信息,这些信息都非常敏感,一旦泄露出去后果非常严重,被泄露的原因一般是程序员将配置信息和代码混在一起导致的,这篇文章主要给大家介绍了关于如何使用python-dotenv解决代码与敏感信息的分离,需要的朋友可以参考下
    2022-03-03

最新评论