pytorch之torch.flatten()和torch.nn.Flatten()的用法

 更新时间:2025年04月10日 14:42:10   作者:三つ叶  
这篇文章主要介绍了pytorch之torch.flatten()和torch.nn.Flatten()的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.flatten()和torch.nn.Flatten()的用法

flatten()函数的作用是将tensor铺平成一维

torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor
  • input (Tensor) – the input tensor.
  • start_dim (int) – the first dim to flatten
  • end_dim (int) – the last dim to flatten

start_dim和end_dim构成了整个你要选择铺平的维度范围

下面举例说明

x = torch.tensor([[1,2], [3,4], [5,6]])
x = x.flatten(0)
x
------------------------
tensor([1, 2, 3, 4, 5, 6])

对于图片数据,我们往往期望进入fc层的维度为(channels, N)这样

x = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
x = x.flatten(1)
x
-------------------------
tensor([[1, 2],
        [3, 4],
        [5, 6]])

注:

torch.nn.Flatten(start_dim=1, end_dim=- 1)

start_dim 默认为 1

所以在构建网络时,下面两种是等价的

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # The arguments for commonly used modules:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # torch.nn.MaxPool2d(kernel_size, stride=None, padding=0)

        # input image size: [3, 128, 128]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 128, 128]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

        # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),

            nn.Flatten(),

            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
       
        x = self.layers(x)

        return x

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python的三种主要模块介绍

    Python的三种主要模块介绍

    这篇文章介绍了Python的三类主要模块,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-07-07
  • 浅析python继承与多重继承

    浅析python继承与多重继承

    在本篇文章中我们给大家分析了python继承与多重继承的相关知识点内容,有兴趣的读者们参考下。
    2018-09-09
  • Python导入Excel表格数据并以字典dict格式保存的操作方法

    Python导入Excel表格数据并以字典dict格式保存的操作方法

    本文介绍基于Python语言,将一个Excel表格文件中的数据导入到Python中,并将其通过字典格式来存储的方法,感兴趣的朋友一起看看吧
    2023-01-01
  • python网络爬虫采集联想词示例

    python网络爬虫采集联想词示例

    这篇文章主要介绍了python网络爬虫采集联想词示例,需要的朋友可以参考下
    2014-02-02
  • Django中外键使用总结

    Django中外键使用总结

    本文主要介绍了Django中外键使用总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-07-07
  • Python中字符串去空格的五种方法介绍与对比

    Python中字符串去空格的五种方法介绍与对比

    在 Python 中,去除字符串中的空格是一个常见的操作,这篇文章小编将为大家盘点一下python中常用的的去空格的方法,需要的可以参考一下
    2025-02-02
  • Python数据传输黏包问题

    Python数据传输黏包问题

    这篇文章主要介绍了Python数据传输黏包问题,黏包指数据与数据之间没有明确的分界线,导致不能正确的读取数据,更多相关内容需要的小伙伴可以参考一下
    2022-04-04
  • Keras实现将两个模型连接到一起

    Keras实现将两个模型连接到一起

    这篇文章主要介绍了Keras实现将两个模型连接到一起,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • openCV-Python笔记之解读图像的读取、显示和保存问题

    openCV-Python笔记之解读图像的读取、显示和保存问题

    这篇文章主要介绍了openCV-Python笔记之解读图像的读取、显示和保存问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • 详解Python数据结构与算法中的顺序表

    详解Python数据结构与算法中的顺序表

    线性表在计算机中的表示可以采用多种方法,采用不同存储方法的线性表也有着不同的名称和特点。线性表有两种基本的存储结构:顺序存储结构和链式存储结构。本文将介绍顺序存储结构的特点以及各种基本运算的实现。需要的可以参考一下
    2022-01-01

最新评论