Pytorch阅读文档中的flatten函数

 更新时间:2023年11月08日 09:51:23   作者:GhostintheCode  
PyTorch提供了一个非常方便的函数flatten()来完成这个任务,本文将介绍Pytorch阅读文档中的flatten函数,并提供一些示例代码,感兴趣的朋友一起看看吧

Pytorch阅读文档中的flatten函数

pytorch中flatten函数

torch.flatten()

#展平一个连续范围的维度,输出类型为Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
# Parameters:input (Tensor) – 输入为Tensor
#start_dim (int) – 展平的开始维度
#end_dim (int) – 展平的最后维度
#example
#一个3x2x2的三维张量
>>> t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]],
                  [[9, 10],
                       [11, 12]]])
#当开始维度为0,最后维度为-1,展开为一维
>>> torch.flatten(t)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
#当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
>>> torch.flatten(t, start_dim=1)
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
>>> torch.flatten(t, start_dim=1).size()
torch.Size([3, 4])
#下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
#前面的就会合并
>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
>>> torch.flatten(t, start_dim=0, end_dim=1).size()
torch.Size([6, 2])

torch.nn.Flatten()

Class torch.nn.Flatten(start_dim=1, end_dim=-1)
#Flattens a contiguous range of dims into a tensor. 
#For use with Sequential. :
#param start_dim: first dim to flatten (default = 1). 
#param end_dim: last dim to flatten (default = -1).
#能力有限,个人认为是用于卷积中的
#Shape:
#Input: (N, *dims)(N,∗dims)
#Output: (N, \prod *dims)(N,∏∗dims) (for the default case).
#官方example
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
#源代码为 TORCH.NN.MODULES.FLATTEN
from .module import Module
[docs]class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).
    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).
    Examples::
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
    """
    __constants__ = ['start_dim', 'end_dim']
    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim
    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)

torch.Tensor.flatten()

和torch.flatten()一样

PyTorch中Flatten(start_dim=1, end_dim=-1)是什么意思

`Flatten(start_dim=1, end_dim=-1)` 是PyTorch中的一个函数,用于将输入张量进行扁平化操作。它可以将多维的张量转换为一维张量,保持数据的顺序不变。

参数:
- `start_dim`(可选):指定开始扁平化的维度。默认值为 1,表示从第二个维度开始扁平化。注意,维度索引是从 0 开始的。
- `end_dim`(可选):指定结束扁平化的维度。默认值为 -1,表示扁平化到最后一个维度。

返回值:
- 返回一个新的张量,是输入张量扁平化后的结果。

下面是一个示例,说明如何使用 `Flatten()` 函数:

import torch
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
output = torch.flatten(input, start_dim=0, end_dim=1)
print(output)
tensor([1, 2, 3, 4, 5, 6])

在上面的示例中,输入张量 `input` 是一个 2D 张量,形状为 (2, 3)。使用 `torch.flatten()` 函数对 `input` 进行扁平化操作,将其转换为一维张量。由于没有指定 `start_dim` 和 `end_dim`,默认从第二个维度(即行维度)开始扁平化,并扁平化到最后一个维度(即列维度)。最终的输出张量 `output` 是一个一维张量,包含了原始张量中的所有元素,按照原始张量的顺序排列。

请注意,`Flatten()` 函数返回的是一个新的张量,原始张量保持不变。

到此这篇关于Pytorch阅读文档中的flatten函数的文章就介绍到这了,更多相关Pytorch flatten函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • YOLOv5以txt或json格式输出预测结果的方法详解

    YOLOv5以txt或json格式输出预测结果的方法详解

    这篇文章主要给大家介绍了关于YOLOv5以txt或json格式输出预测结果的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2023-03-03
  • 实例讲解Python中global语句下全局变量的值的修改

    实例讲解Python中global语句下全局变量的值的修改

    global是Python中的一个关键字用来,声明一个局部变量为全局变量,这里我们来以实例讲解Python中global语句下全局变量的值的修改,需要的朋友可以参考下.
    2016-06-06
  • Python中的copy()函数详解(list,array)

    Python中的copy()函数详解(list,array)

    这篇文章主要介绍了Python中的copy()函数详解(list,array),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • 详解python中@的用法

    详解python中@的用法

    这篇文章主要介绍了python中@的用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • Python通过TensorFLow进行线性模型训练原理与实现方法详解

    Python通过TensorFLow进行线性模型训练原理与实现方法详解

    这篇文章主要介绍了Python通过TensorFLow进行线性模型训练原理与实现方法,结合实例形式详细分析了Python通过TensorFLow进行线性模型训练相关概念、算法设计与训练操作技巧,需要的朋友可以参考下
    2020-01-01
  • python中的异步爬虫详解

    python中的异步爬虫详解

    这篇文章主要介绍了python中的异步爬虫详解,所谓的异步异步 IO,就是发起一个 IO 阻塞的操作,但是不用等到它结束,可以在它执行 IO 的过程中继续做别的事情,当 IO 执行完毕之后会收到它的通知,需要的朋友可以参考下
    2023-08-08
  • python实现数据清洗(缺失值与异常值处理)

    python实现数据清洗(缺失值与异常值处理)

    今天小编就为大家分享一篇python实现数据清洗(缺失值与异常值处理),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • PyCharm 2020.2.2 x64 下载并安装的详细教程

    PyCharm 2020.2.2 x64 下载并安装的详细教程

    这篇文章主要介绍了PyCharm 2020.2.2 x64 下载并安装的详细教程,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10
  • python如何将多个PDF进行合并

    python如何将多个PDF进行合并

    这篇文章主要为大家详细介绍了python如何将多个PDF进行合并,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-08-08
  • Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍

    Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍

    这篇文章主要介绍了Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍,需要的朋友可以参考下
    2014-06-06

最新评论