pytorch之torch.flatten()和torch.nn.Flatten()的用法

 更新时间:2025年04月10日 14:42:10   作者:三つ叶  
这篇文章主要介绍了pytorch之torch.flatten()和torch.nn.Flatten()的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.flatten()和torch.nn.Flatten()的用法

flatten()函数的作用是将tensor铺平成一维

torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor
  • input (Tensor) – the input tensor.
  • start_dim (int) – the first dim to flatten
  • end_dim (int) – the last dim to flatten

start_dim和end_dim构成了整个你要选择铺平的维度范围

下面举例说明

x = torch.tensor([[1,2], [3,4], [5,6]])
x = x.flatten(0)
x
------------------------
tensor([1, 2, 3, 4, 5, 6])

对于图片数据,我们往往期望进入fc层的维度为(channels, N)这样

x = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
x = x.flatten(1)
x
-------------------------
tensor([[1, 2],
        [3, 4],
        [5, 6]])

注:

torch.nn.Flatten(start_dim=1, end_dim=- 1)

start_dim 默认为 1

所以在构建网络时,下面两种是等价的

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # The arguments for commonly used modules:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # torch.nn.MaxPool2d(kernel_size, stride=None, padding=0)

        # input image size: [3, 128, 128]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 128, 128]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

        # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),

            nn.Flatten(),

            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
       
        x = self.layers(x)

        return x

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python中__new__函数的具体使用

    python中__new__函数的具体使用

    new是object基类提供的内置静态方法,本文主要介绍了python中__new__函数的具体使用,具有一定的参考价值,感兴趣的可以了解一下
    2025-09-09
  • Python pandas 计算每行的增长率与累计增长率

    Python pandas 计算每行的增长率与累计增长率

    这篇文章主要介绍了Python pandas 计算每行的增长率与累计增长率,文章举例详细说明。需要的小伙伴可以参考一下
    2022-03-03
  • 使用Python将PDF转成Excel的代码实现

    使用Python将PDF转成Excel的代码实现

    在日常工作中,您是否曾被困扰于从复杂的PDF文档中手动提取数据,特别是表格数据,然后逐一录入到Excel,这项任务不仅耗时耗力,还极易引入人为错误,严重影响工作效率,本文将深入探讨如何利用Spire.PDF for Python这一高效库,轻松实现PDF转Excel的需求
    2025-10-10
  • tensorflow saver 保存和恢复指定 tensor的实例讲解

    tensorflow saver 保存和恢复指定 tensor的实例讲解

    今天小编就为大家分享一篇tensorflow saver 保存和恢复指定 tensor的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python编写猜数字小游戏

    python编写猜数字小游戏

    这篇文章主要为大家详细介绍了python编写猜数字小游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-10-10
  • Python实现OFD文件转PDF

    Python实现OFD文件转PDF

    OFD 文件是由中国国家标准化管理委员会制定的国家标准,是一种开放式文档格式,具有高度可扩展性和可编辑性,本文主要介绍了如何利用Python实现OFD文件转PDF,需要的可以参考下
    2024-10-10
  • tensorflow实现对图片的读取的示例代码

    tensorflow实现对图片的读取的示例代码

    本篇文章主要介绍了tensorflow实现对图片的读取的示例代码,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-02-02
  • 在Python编程过程中用单元测试法调试代码的介绍

    在Python编程过程中用单元测试法调试代码的介绍

    这篇文章主要介绍了在Python编程过程中用单元测试法调试代码的介绍,包括使用断言等,有助于debug时的效率提升,需要的朋友可以参考下
    2015-04-04
  • Python2随机数列生成器简单实例

    Python2随机数列生成器简单实例

    这篇文章主要介绍了Python2随机数列生成器,结合简单实例形式分析了Python基于random模块操作随机数的相关实现技巧,需要的朋友可以参考下
    2017-09-09
  • Python网络爬虫技术高阶用法

    Python网络爬虫技术高阶用法

    网络爬虫成为了自动化数据抓取的核心工具,Python 拥有强大的第三方库支持,在网络爬虫领域的应用尤为广泛,本文将深入探讨 Python 网络爬虫的高阶用法,包括处理反爬虫机制、动态网页抓取、分布式爬虫以及并发和异步爬虫等技术,帮助读者掌握高级Python爬虫技术
    2024-12-12

最新评论