PyTorch如何利用parameters()获取模型参数

 更新时间:2023年09月12日 08:46:55   作者:玉笛仙踪  
这篇文章主要介绍了PyTorch如何利用parameters()获取模型参数问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

利用parameters()获取模型参数

在PyTorch中,可以使用parameters函数来获取模型中的所有可学习参数。

以下是一个示例:

import torch.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
model = MyModel()
params = list(model.parameters())

在这个示例中,我们首先定义了一个包含两个线性层的神经网络,然后通过list(model.parameters())获取了模型中的所有可学习参数。

这些参数存储在一个Python列表中,可以用于进行优化器的初始化和模型的保存和加载。

PyTorch中模型的parameters()方法

首先先定义一个模型:

import torch as t
import torch.nn as nn
class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 3)
        self.conv2 = nn.Conv2d(2, 2, 3)
        self.conv3 = nn.Conv2d(2, 2, 3)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

然后打印出该模型的参数:

pythona = A()
print(a.parameters()) #<generator object Module.parameters at 0x7f7b740d2360>

以上代码说明parameters()会返回一个生成器(迭代器)

然后将其迭代打印出来:

print(list(a.parameters())):#将迭代器转换成列表
Parameter containing:
tensor([[[[-0.0299,  0.0891,  0.0303],
          [ 0.0869, -0.0230, -0.1760],
          [ 0.1408,  0.0348,  0.1795]],
         [[ 0.2001,  0.0023, -0.1775],
          [ 0.0947, -0.0231, -0.1756],
          [ 0.1201, -0.0997, -0.0303]]],
        [[[-0.0425,  0.0748, -0.1754],
          [-0.1191, -0.1203, -0.1219],
          [-0.0794,  0.0895, -0.1719]],
         [[ 0.1968, -0.0463,  0.0550],
          [-0.0386,  0.1594,  0.1282],
          [-0.0009,  0.2167, -0.1783]]]], requires_grad=True)
Parameter containing:
tensor([ 0.0147, -0.0406], requires_grad=True)
Parameter containing:
tensor([[[[-0.0578, -0.1114, -0.1194],
          [-0.1469, -0.1175, -0.1616],
          [-0.2289, -0.0975, -0.1700]],
         [[-0.0894,  0.0074,  0.1222],
          [-0.0176, -0.0509,  0.1622],
          [-0.0405, -0.1349,  0.1782]]],
        [[[-0.0739,  0.2167,  0.1864],
          [ 0.0956, -0.1761,  0.0464],
          [ 0.0062, -0.0685,  0.0748]],
         [[ 0.1085,  0.1481,  0.1334],
          [ 0.2236, -0.0706, -0.0224],
          [ 0.0079, -0.1835, -0.0407]]]], requires_grad=True)
Parameter containing:
tensor([-8.0720e-05,  1.6026e-01], requires_grad=True)
Parameter containing:
tensor([[[[-0.0702,  0.1846,  0.0419],
          [-0.1891, -0.0893, -0.0024],
          [-0.0349, -0.0213,  0.0936]],
         [[-0.1062,  0.1242,  0.0391],
          [-0.1924,  0.0535, -0.1480],
          [ 0.0400, -0.0487, -0.2317]]],
        [[[ 0.1202,  0.0961,  0.2336],
          [ 0.2225, -0.2294, -0.2283],
          [-0.0963, -0.0311, -0.2354]],
         [[ 0.0676, -0.0439, -0.0962],
          [-0.2316, -0.0639, -0.0671],
          [ 0.1737, -0.1169, -0.1751]]]], requires_grad=True)
Parameter containing:
tensor([-0.1939, -0.0959], requires_grad=True)

从以上结果可以看出列表中有6个元素,由于nn.Conv2d()的参数包括self.weight和self.bias两部分,所以每个2D卷积层包括两部分的参数.注意self.bias是加在每个通道上的,所以self.bias的长度与output_channl相同

心得:

parameters()会返回一个生成器(迭代器),生成器每次生成的是Tensor类型的数据.

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python机器学习库之Scikit-learn基本用法详解

    Python机器学习库之Scikit-learn基本用法详解

    Scikit-learn 是 Python 中最著名的机器学习库之一,它提供了大量实用的机器学习算法以及相关的工具,可以方便我们进行数据挖掘和数据分析,在这篇文章中,我们将介绍 Scikit-learn 的基本使用,包括如何导入数据、预处理数据、选择和训练模型,以及评估模型的性能
    2023-07-07
  • Python Matplotlib绘制箱线图boxplot()函数详解

    Python Matplotlib绘制箱线图boxplot()函数详解

    箱线图一般用来展现数据的分布(如上下四分位值、中位数等),同时也可以用箱线图来反映数据的异常情况,下面这篇文章主要给大家介绍了关于Python Matplotlib绘制箱线图boxplot()函数的相关资料,需要的朋友可以参考下
    2022-07-07
  • Python实现向好友发送微信消息优化篇

    Python实现向好友发送微信消息优化篇

    利用python可以实现微信消息发送功能,怎么实现呢?你肯定会想着很复杂,但是python的好处就是很多人已经把接口打包做好了,只需要调用即可,今天通过本文给大家分享使用 Python 实现微信消息发送的思路代码,一起看看吧
    2022-06-06
  • Python集成开发环境pycharm配置git的实现步骤

    Python集成开发环境pycharm配置git的实现步骤

    本文主要介绍了Python集成开发环境pycharm配置git的实现步骤,文中通过图文的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2024-05-05
  • python中超简单的字符分割算法记录(车牌识别、仪表识别等)

    python中超简单的字符分割算法记录(车牌识别、仪表识别等)

    这篇文章主要给大家介绍了关于python中超简单的字符分割算法记录,如车牌识别、仪表识别等,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2021-09-09
  • selenium环境搭建及基本元素定位方式详解

    selenium环境搭建及基本元素定位方式详解

    selenium最初是一个自动化测试工具,而爬虫中使用它主要是为了解决requests无法执行javaScript代码的问题,这篇文章主要介绍了selenium环境搭建及基本元素定位方式,需要的朋友可以参考下
    2023-04-04
  • Anaconda修改默认虚拟环境安装位置的方案分享

    Anaconda修改默认虚拟环境安装位置的方案分享

    新安装Anaconda后,在创建环境时环境自动安装在C盘,但是C盘空间有限,下面这篇文章主要给大家介绍了关于Anaconda修改默认虚拟环境安装位置的相关资料,需要的朋友可以参考下
    2023-01-01
  • Python全栈之面向对象基础

    Python全栈之面向对象基础

    这篇文章主要为大家介绍了Python面向对象基础,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2021-11-11
  • Python通过pyperclip库操作剪贴板

    Python通过pyperclip库操作剪贴板

    pyperclip是一个python库用于操作剪贴板,可以非常方便地将文本复制到剪贴板或从剪贴板获取文本,下面就跟随小编一起了解一下pyperclip的具体使用吧
    2024-11-11
  • Python简单生成8位随机密码的方法

    Python简单生成8位随机密码的方法

    这篇文章主要介绍了Python简单生成8位随机密码的方法,结合实例形式分析了2种简单生成随机密码的方法,非常简单实用,需要的朋友可以参考下
    2017-05-05

最新评论