Pytorch实现简单自定义网络层的方法

 更新时间:2022年05月20日 10:07:15   作者:ting_qifengl  
这篇文章主要给大家介绍了关于Pytorch实现简单自定义网络层的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下

前言

Pytorch、Tensoflow等许多深度学习框架集成了大量常见的网络层,为我们搭建神经网络提供了诸多便利。但在实际工作中,因为项目要求、研究需要或者发论文需要等等,大家一般都会需要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,就必须构建自定义层。

博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。

一、不带参数的层

首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。

import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()
 
    def forward(self, X):
        return X - X.mean()

输入一些数据,验证一下网络是否能正常工作:

layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))

输出结果如下:

tensor([-2., -1.,  0.,  1.,  2.])

运行正常,表明网络没有问题。

现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y.mean())  # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可

结果如下:

tensor(-5.5879e-09, grad_fn=<MeanBackward0>)

二、带参数的层

这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。

这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。

class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))
    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)

接下来实例化类并访问其模型参数:

linear = MyLinear(5, 3)
print(linear.weight)

结果如下:

Parameter containing:
tensor([[-0.3708,  1.2196,  1.3658],
        [ 0.4914, -0.2487, -0.9602],
        [ 1.8458,  0.3016, -0.3956],
        [ 0.0616, -0.3942,  1.6172],
        [ 0.7839,  0.6693, -0.8890]], requires_grad=True)

而后输入一些数据,查看模型输出结果:

print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
        [1.3514, 0.0968, 0.6667]])

我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。

net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
        [0.2567]])

三、总结

我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。

在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。

层可以有局部参数,这些参数可以通过内置函数创建。

四、参考

《动手学深度学习》 — 动手学深度学习 2.0.0-beta0 documentation

https://zh-v2.d2l.ai/

附:pytorch获取网络的层数和每层的名字

#创建自己的网络
import models
model = models.__dict__["resnet50"](pretrained=True)

for index ,(name, param) in enumerate(model.named_parameters()):
    print( str(index) + " " +name)

结果如下:

0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight

到此这篇关于Pytorch实现简单自定义网络层的文章就介绍到这了,更多相关Pytorch自定义网络层内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python 3行代码提取音乐高潮部分

    Python 3行代码提取音乐高潮部分

    这篇文章主要介绍了利用Python代码提取音乐高潮部分,文章围绕Python代码的相关详情展开提取音乐的内容,需要的小伙伴可以参考一下
    2022-01-01
  • python实现的DES加密算法和3DES加密算法实例

    python实现的DES加密算法和3DES加密算法实例

    这篇文章主要介绍了python实现的DES加密算法和3DES加密算法,以实例形式较为详细的分析了DES加密算法和3DES加密算法的原理与实现技巧,需要的朋友可以参考下
    2015-06-06
  • python requests 库请求带有文件参数的接口实例

    python requests 库请求带有文件参数的接口实例

    今天小编就为大家分享一篇python requests 库请求带有文件参数的接口实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python将Excel转换为多种图片格式的方法(PNG, JPG, BMP, SVG)

    Python将Excel转换为多种图片格式的方法(PNG, JPG, BMP, SVG)

    有时,你可能希望以图片形式分享Excel数据,以防止他人对数据进行修改或编辑,将Excel转换为图片可以将数据锁定为静态图片,确保数据的完整性和准确性,这篇文章将探讨如何使用Python实现将Excel工作表转换为多种图片格式,如PNG,JPG,BMP和SVG,需要的朋友可以参考下
    2025-03-03
  • Python使用Dijkstra算法实现求解图中最短路径距离问题详解

    Python使用Dijkstra算法实现求解图中最短路径距离问题详解

    这篇文章主要介绍了Python使用Dijkstra算法实现求解图中最短路径距离问题,简单描述了Dijkstra算法的原理并结合具体实例形式分析了Python使用Dijkstra算法实现求解图中最短路径距离的相关步骤与操作技巧,需要的朋友可以参考下
    2018-05-05
  • Python使用三种方法实现PCA算法

    Python使用三种方法实现PCA算法

    这篇文章主要介绍了Python使用三种方法实现PCA算法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-12-12
  • Python中的变量赋值

    Python中的变量赋值

    这篇文章主要介绍了Python中的变量赋值,Python中的变量在使用中很流畅,可以不关注类型,任意赋值,对于开发来说效率得到了提升,但不了解其中的机理,往往也会犯一些小错,让开发进行的不那么流畅,本文就从语言设计和底层原理的角度,带大家理解Python中的变量。
    2021-10-10
  • python实现解数独程序代码

    python实现解数独程序代码

    最近在带孩子学习数独,职业使然,就上网搜了下相关程序的解法,这里分享给大家,希望对大家学习python有所帮助
    2017-04-04
  • Python Flask 上传文件测试示例

    Python Flask 上传文件测试示例

    这篇文章主要为大家介绍了Python Flask 上传文件测试的方法示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-07-07
  • Python实现树莓派WiFi断线自动重连的实例代码

    Python实现树莓派WiFi断线自动重连的实例代码

    实现 WiFi 断线自动重连,原理是用 Python 监测网络是否断线,如果断线则重启网络服务。接下来给大家分享实现代码,需要的朋友参考下
    2017-03-03

最新评论