PyTorch核心方法之state_dict()、parameters()参数打印与应用案例

 更新时间:2025年12月13日 14:56:20   作者:木棉知行者  
PyTorch是一个流行的开源深度学习框架,提供了灵活且高效的方式来训练和部署神经网络,这篇文章主要介绍了PyTorch核心方法之state_dict()、parameters()参数打印与应用案例的相关资料,需要的朋友可以参考下

前言

本文以 LeNet-5 模型为案例,介绍了 PyTorch 中打印模型参数的相关方法。首先展示了 LeNet-5 模型的结构定义及打印结果;随后详细说明了三种获取模型参数的方式:

  • state_dict()方法返回有序字典形式的可学习参数,包含参数名称和对应张量;
  • parameters()方法返回生成器,仅包含各层参数信息;
  • named_parameters()方法返回生成器,包含模型名称和对应参数信息;
    最后提供了利用named_parameters()进行模型结构冻结的示例,可打印确认冻结的网络名称。

模型案例

本文以LeNet-5为基础模型,快速验证模型参数打印过程。

import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import torch 
import torch.nn.functional as F 
import torch.nn as nn

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 这里论文上写的是conv,官方教程用了线性层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = LeNet5()
print(net)

模型结构打印如下。

A. state_dict()方法验证

在 PyTorch 中,state_dict() 是核心方法之一,用于以有序字典(OrderedDict)的形式返回模型 / 优化器等实例的可学习参数(或状态),是模型保存、加载、迁移学习的基础。

state_dict() 本质是一个 Python 字典(PyTorch 中为 OrderedDict),键为参数 / 状态的名称(字符串),值为对应的张量(torch.Tensor)。

print(type(net.state_dict()))   # <class 'collections.OrderedDict'>
## 遍历打印
for model_key in net.state_dict():      # 【字典格式】的遍历,获取的是模型的名称
    print(f"{model_key}: {net.state_dict()[model_key].size()}")

对于Lenet-5模型进行打印,可以看到state_dict()的类型为 <class 'collections.OrderedDict'>,各层名称及参数尺寸如下图所示。

B. parameters()

parameters()方法也可以获取到模型的参数。可以看出,parameters()获取到的是一个生成器,其中仅包含各层参数的信息。

params = net.parameters()   
print(type(params))   # <class 'generator'>  生成器  

for param in params:    
    print(param.size())   # 只包含参数信息:具体的参数尺寸

对Lenet-5进行模型参数打印。

如果也需要模型名称信息,可以使用named_parameters()方法。该方法获取的也是一个生成器,其中返回的是一个元组,包括模型名称和对应的参数。

named_params = net.named_parameters()   
print(type(named_params))   # <class 'generator'>  也是一个生成器

for name, param in named_params:
    print(f"{name}: {param.size()}")   # 同时获取网络名称和网络参数

对Lenet-5进行模型名称及参数尺寸信息打印:

C. 模型结构冻结示例

该方法可以在对模型结构冻结时使用,如下述示例对模型结构m的参数进行冻结,同时打印确认冻结包含哪些网络结构。

# 示例
for name, param in m.named_parameters():
	param.requires_grad = False
	print(f"Freezing layer {name}")

总结 

到此这篇关于PyTorch核心方法之state_dict()、parameters()参数打印与应用案例的文章就介绍到这了,更多相关PyTorch state_dict()、parameters()参数打印内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 详解Python爬虫爬取博客园问题列表所有的问题

    详解Python爬虫爬取博客园问题列表所有的问题

    这篇文章主要介绍了详解Python爬虫爬取博客园问题列表所有的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Python代码实现删除一个list里面重复元素的方法

    Python代码实现删除一个list里面重复元素的方法

    今天小编就为大家分享一篇关于Python代码实现删除一个list里面重复元素的方法,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-04-04
  • Python Numpy之linspace用法说明

    Python Numpy之linspace用法说明

    这篇文章主要介绍了Python Numpy之linspace用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04
  • python操作excel之openpyxl模块读写xlsx格式使用方法详解

    python操作excel之openpyxl模块读写xlsx格式使用方法详解

    这篇文章主要介绍了python操作excel之openpyxl模块读写xlsx格式使用方法详解,需要的朋友可以参考下
    2022-12-12
  • python读取nc数据并绘图的方法实例

    python读取nc数据并绘图的方法实例

    最近项目中需要处理和分析NC数据,所以下面这篇文章主要给大家介绍了关于python读取nc数据并绘图的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • 解决pycharm无法删除invalid interpreter(无效解析器)的问题

    解决pycharm无法删除invalid interpreter(无效解析器)的问题

    这篇文章主要介绍了pycharm无法删除invalid interpreter(无效解析器)的问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-07-07
  • python 动态规划问题解析(背包问题和最长公共子串)

    python 动态规划问题解析(背包问题和最长公共子串)

    这篇文章主要介绍了python 动态规划(背包问题和最长公共子串),在动态规划中,你要将某个指标最大化。在这个例子中,你要找出两个单词的最长公共子串。fish和fosh都包含的最长子串是什么呢,感兴趣的朋友跟随小编一起看看吧
    2022-05-05
  • Python按照24个实用大方向精选的上千种工具库汇总整理

    Python按照24个实用大方向精选的上千种工具库汇总整理

    本文整理了Python生态中近千个库,涵盖数据处理、图像处理、网络开发、Web框架、人工智能、科学计算、GUI工具、测试框架、环境管理等多个领域,列举了如difflib、requests、Django、TensorFlow等代表性工具,展示Python在各场景下的强大功能与灵活性
    2025-08-08
  • Linux中Python 环境软件包安装步骤

    Linux中Python 环境软件包安装步骤

    本文给大家分享的是在Linux系统中Python环境的安装步骤,以及常用的软件的安装升级,非常的实用,有需要的小伙伴可以参考下
    2016-03-03
  • Python OpenCV图像处理之图像滤波特效详解

    Python OpenCV图像处理之图像滤波特效详解

    图像滤波按图像域可分为两种类型:邻域滤波和频域滤波。按图像频率滤除效果主要分为两种类型:低通滤波和高通滤波。本文将通过案例为大家详细介绍一下OpenCV中的图像滤波特效,需要的可以参考一下
    2022-02-02

最新评论