Pytorch中的torch.nn.Linear()方法用法解读

 更新时间:2024年02月26日 10:09:14   作者:拥抱晨曦之温暖  
这篇文章主要介绍了Pytorch中的torch.nn.Linear()方法用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

Pytorch torch.nn.Linear()方法

torch.nn.Linear()作为深度学习中最简单的线性变换方法,其主要作用是对输入数据应用线性转换

看一下官方的解释及介绍

class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
    Examples::
        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor
 
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
 
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
 
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)
 
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
 
 
# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript

这里我们主要看__init__()方法,很容易知道,当我们使用这个方法时一般需要传入2~3个参数,分别是in_features: int, out_features: int, bias: bool = True,第三个参数是说是否加偏置(bias),简单来讲,这个函数其实就是一个'一次函数':y = xA^T + b,(T表示张量A的转置),首先super(Linear, self).__init__()就是老生常谈的方法,之后初始化in_features和out_features,接下来就是比较重要的weight的设置,我们可以很清晰的看到weight的shape是(out_features,in_features)的,而我们在做xA^T时,并不是x和A^T相乘的,而是x和A.weight^T相乘的,这里需要大大留意,也就是说先对A做转置得到A.weight,然后在丢入y = xA^T + b中,得出结果。

接下来奉上一个小例子来实践一下:

import torch
 
# 随机初始化一个shape为(128,20)的Tensor
x = torch.randn(128,20)
# 构造线性变换函数y = xA^T + b,且参数(20,30)指的是A的shape,则A.weight的shape就是(30,20)了
y= torch.nn.Linear(20,30)
output = y(x)
# 按照以上逻辑使用torch中的简单乘法函数进行检验,结果很显然与上述符合
# 下面的y.weight可以理解为一个shape为(30,20)的一个可学习的矩阵,.t()表示转置
# y.bias若为TRUE,则bias是一个Tensor,且其shape为out_features,在该程序中应为30
# 更加细致的表达一下y = (128 * 20) * (30 * 20)^T + (if bias (1,30) ,else: 0)
ans = torch.mm(x,y.weight.t())+y.bias
print('ans.shape:\n',ans.shape)
print(torch.equal(ans,output))

对torch.nn.Linear的理解

torch.nn.Linear是pytorch的线性变换层

定义如下:

Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)

全连接层 Fully Connect 一般就就用这个函数来实现。

因此在潜意识里,变换的输入张量的 shape 为 (batchsize, in_features),输出张量的 shape 为 (batchsize, out_features)。

当然这是常用的方式,但是 Linear 的输入张量的维度其实并不需要必须为上述的二维,多维也是完全可以的,Linear 仅是对输入的最后一维做线性变换,不影响其他维。

可以看下官网的解释

Linear — PyTorch 1.11.0 documentation

一个例子

如下:

import torch
input = torch.randn(30, 20, 10)  # [30, 20, 10]
linear = torch.nn.Linear(10, 15)  # (*, 10) --> (*, 15)
output = linear(input)
print(output.size()) # 输出 [30, 20, 15]

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于python 字符编码的理解

    基于python 字符编码的理解

    下面小编就为大家带来一篇基于python 字符编码的理解。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-09-09
  • 十个Python经典小游戏的代码合集

    十个Python经典小游戏的代码合集

    这篇文章主要为大家分享十个Python经典的小游戏代码,非常适合Python初学者练手。文中的示例代码讲解详细,感兴趣的小伙伴可以尝试一下
    2022-05-05
  • Python虚拟机字节码教程之装饰器实现详解

    Python虚拟机字节码教程之装饰器实现详解

    在本篇文章当中主要给大家介绍在 cpython 当中一些比较常见的字节码,从根本上理解 python 程序的执行。在本文当中主要介绍一些 python 基本操作的字节码,并且将从字节码的角度分析函数装饰器的原理
    2023-04-04
  • Python+Pygame实现经典魂斗罗游戏

    Python+Pygame实现经典魂斗罗游戏

    《魂斗罗》(Contra)是由Konami于1987年推出的一系列卷轴射击类单机游戏。本文将利用Python中的Pygame库实现这一经典游戏,感兴趣的可以了解一下
    2022-05-05
  • python 日期操作类代码

    python 日期操作类代码

    这篇文章主要介绍了python 日期操作类代码,里面涉及了python日期操作的一些基础知识,需要的朋友可以参考下
    2018-05-05
  • 在Docker上开始部署Python应用的教程

    在Docker上开始部署Python应用的教程

    这篇文章主要介绍了在Docker上开始部署Python应用的教程,Docker是时下最火爆的虚拟机,正在被各大云主机服务商所采用,需要的朋友可以参考下
    2015-04-04
  • 利用Python的tkinter模块实现界面化的批量修改文件名

    利用Python的tkinter模块实现界面化的批量修改文件名

    这篇文章主要介绍了利用Python的tkinter模块实现界面化的批量修改文件名,用Python编写过批量修改文件名的脚本程序,代码很简单,运行也比较快,详细内容需要的小伙伴可以参考一下下面文章内容
    2022-08-08
  • Python详细对比讲解break和continue区别

    Python详细对比讲解break和continue区别

    这篇文章主要介绍了python循环控制语句 break 与 continue,break就像是终止按键,不管执行到哪一步,只要遇到break,不管什么后续步骤,直接跳出当前循环
    2022-06-06
  • Python unittest单元测试框架总结

    Python unittest单元测试框架总结

    这篇文章主要介绍了Python unittest单元测试框架总结,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-09-09
  • Python中tkinter+MySQL实现增删改查

    Python中tkinter+MySQL实现增删改查

    这篇文章主要介绍了Python中tkinter+MySQL实现增删改查,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04

最新评论