PyTorch中nn.Module示例详解

 更新时间:2025年07月24日 09:48:54   作者:捂一捂啊啊  
本文详解PyTorch中nn.Module,涵盖参数管理、训练模式、设备迁移、保存加载等核心功能,并对比nn.Sequential的差异,强调其灵活性与适用场景,建议根据网络复杂度选择使用,感兴趣的朋友一起看看吧

直接print(dir(nn.Module)),得到如下内容:

一、模型结构与参数

  1. parameters()

    • 用途:返回模块的所有可训练参数(如权重、偏置)。
    • 示例
      for param in model.parameters():
          print(param.shape)
      
  2. named_parameters()

    • 用途:返回带名称的参数迭代器,便于调试和访问特定参数。
    • 示例
      for name, param in model.named_parameters():
          if 'weight' in name:
              print(name, param.shape)
      
  3. children()

    • 用途:返回直接子模块的迭代器。
    • 示例
      for child in model.children():
          print(type(child))
      
  4. modules()

    • 用途:递归返回所有子模块(包括自身)。
    • 示例
      for module in model.modules():
          if isinstance(module, nn.Conv2d):
              print(module.kernel_size)
      

二、模型状态与模式

  1. train()eval()

    • 用途:切换训练/推理模式(影响Dropout、BatchNorm等层)。
    • 示例
      model.train()  # 训练模式
      model.eval()   # 推理模式
      
  2. training

    • 用途:布尔属性,指示当前模式(True 为训练,False 为推理)。
    • 示例
      print(model.training)  # 输出:True/False
      

三、模型保存与加载

  1. state_dict()

    • 用途:返回包含模型所有参数的字典(OrderedDict)。
    • 示例
      torch.save(model.state_dict(), 'model.pth')
      
  2. load_state_dict()

    • 用途:从字典加载模型参数。
    • 示例
      model.load_state_dict(torch.load('model.pth'))
      

四、设备与数据类型

  1. to()

    • 用途:将模型移动到指定设备(如GPU)或转换数据类型。
    • 示例
      model.to('cuda')          # 移动到GPU
      model.to(torch.float16)   # 转换为半精度
      
  2. cpu()cuda()

    • 用途:快捷方法,分别将模型移动到CPU或GPU。
    • 示例
      model.cuda()  # 等价于 model.to('cuda')
      

五、前向传播与计算

  1. forward()

    • 用途:定义模型的前向传播逻辑(需在自定义模块中重写)。
    • 示例
      class MyModel(nn.Module):
          def forward(self, x):
              return self.layer(x)
      
  2. __call__()

    • 用途:调用模型实例时触发(内部调用 forward(),支持钩子函数)。
    • 示例
      output = model(x)  # 等价于 output = model.forward(x)
      

六、参数初始化与优化

  1. zero_grad()

    • 用途:清空所有参数的梯度(通常在每个训练步骤前调用)。
    • 示例
      optimizer.zero_grad()  # 等价于 model.zero_grad()
      
  2. requires_grad_()

    • 用途:设置参数是否需要梯度(用于冻结部分模型)。
    • 示例
      for param in model.parameters():
          param.requires_grad = False  # 冻结所有参数
      

七、调试与信息

  1. extra_repr()

    • 用途:自定义模块打印信息(需在子类中重写)。
    • 示例
      class MyModel(nn.Module):
          def extra_repr(self):
              return f"hidden_size={self.hidden_size}"
      
  2. dump_patches()

    • 用途:打印模型的补丁信息(用于调试版本差异)。

八、其他实用方法

  1. apply()

    • 用途:递归应用函数到所有子模块(如初始化权重)。
    • 示例
      def init_weights(m):
          if isinstance(m, nn.Conv2d):
              nn.init.kaiming_normal_(m.weight)
      model.apply(init_weights)
      
  2. register_forward_hook()

    • 用途:注册前向传播钩子(用于捕获中间输出,调试或特征提取)。

总结

日常使用中,最频繁的方法包括:

  • 模型构建parameters(), children(), modules()
  • 训练与推理train(), eval(), zero_grad(), forward()
  • 保存与加载state_dict(), load_state_dict()
  • 设备管理to(), cuda(), cpu()

其他方法根据具体需求选择使用,例如钩子函数用于高级调试,apply() 用于统一初始化。

与nn.Sequential对比:

1. 继承关系与基础属性

  • nn.Module

    • 是所有神经网络模块的基类,提供最基础的功能(如参数管理、钩子机制)。
    • 包含核心属性:_parameters, _modules, _buffers 等。
  • nn.Sequential

    • nn.Module 的子类,继承了所有基础功能。
    • 额外添加了与顺序执行相关的属性(如 __getitem__append)。

2. 核心差异对比

功能类别nn.Modulenn.Sequential
模块构建需要手动实现 forward 方法自动按顺序执行子模块,无需定义 forward
子模块访问通过属性名(如 self.conv1通过索引或命名(如 model[0]
动态修改需手动管理子模块支持 appendextendinsert 等操作
适用场景复杂网络结构(如ResNet、U-Net)简单顺序结构(如LeNet卷积部分)

3. 具体方法对比

3.1 公共方法(两者都有)
# 模型参数与结构
['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules']
# 模型状态
['train', 'eval', 'training', 'zero_grad', 'requires_grad_']
# 设备与数据类型
['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16']
# 保存与加载
['state_dict', 'load_state_dict']
# 钩子机制
['register_forward_hook', 'register_backward_hook']
3.2nn.Sequential特有的方法
# 列表操作(动态修改模块顺序)
['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop']
# 索引相关
['_get_item_by_idx']
3.3nn.Module特有的方法
# 自定义实现
['forward', 'extra_repr']
# 高级管理
['add_module', 'register_module', 'register_parameter', 'register_buffer']

4. 示例对比

4.1 创建模型
# nn.Module(需自定义 forward)
class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.conv(x))
# nn.Sequential(自动按顺序执行)
seq_model = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU()
)
4.2 访问子模块
# nn.Module
custom_model.conv  # 通过属性名访问
# nn.Sequential
seq_model[0]       # 通过索引访问
seq_model.append(nn.MaxPool2d(2))  # 动态添加模块

5. 总结

特性nn.Modulenn.Sequential
灵活性高(自定义任意逻辑)低(仅支持顺序执行)
代码复杂度较高(需手动实现 forward低(自动处理前向传播)
动态修改不支持直接操作(需手动管理)支持 appendinsert 等操作
适用场景复杂网络、分支结构、自定义操作简单堆叠模块(如CNN的卷积部分)

建议:

  • 对于简单的顺序网络,优先使用 nn.Sequential 以减少代码量。
  • 对于包含复杂逻辑(如残差连接、多输入输出)的网络,使用 nn.Module 自定义实现。

到此这篇关于PyTorch中nn.Module详解的文章就介绍到这了,更多相关PyTorch nn.Module内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • django将网络中的图片,保存成model中的ImageField的实例

    django将网络中的图片,保存成model中的ImageField的实例

    今天小编就为大家分享一篇django将网络中的图片,保存成model中的ImageField的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python实现通过flask和前端进行数据收发

    python实现通过flask和前端进行数据收发

    今天小编就为大家分享一篇python实现通过flask和前端进行数据收发,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python游戏开发之精灵和精灵组

    Python游戏开发之精灵和精灵组

    python作为当前非常受欢迎的编程语言,很大一部分原因是拥有丰富的库,这篇文章主要给大家介绍了关于Python游戏开发之精灵和精灵组的相关资料,文中通过图文介绍的非常详细,需要的朋友可以参考下
    2023-05-05
  • 使用python telnetlib批量备份交换机配置的方法

    使用python telnetlib批量备份交换机配置的方法

    今天小编就为大家分享一篇使用python telnetlib批量备份交换机配置的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • DjangoUeditor图片不显示img的src没有域名问题

    DjangoUeditor图片不显示img的src没有域名问题

    在使用DjangoUeditor过程中,可能遇到图片上传后不显示问题,解决办法是修改源码view.py,加入代码使得保存的图片URL带有协议和域名,具体做法是在保存图片代码中添加request.scheme获取协议,request.META['HTTP_HOST']获取域名
    2024-09-09
  • python字典key不能是可以是啥类型

    python字典key不能是可以是啥类型

    在本篇文章里小编给大家整理了关于python字典key不能是可以是啥类型的相关知识点,需要的朋友们可以参考下。
    2020-08-08
  • python 中 __init__的意义以及作用

    python 中 __init__的意义以及作用

    python中的__init__是一个私有函数(方法),访问私有函数中的变量在python中用self,在PHP中用$this,这篇文章主要介绍了python 中 __init__的意义以及作用,需要的朋友可以参考下
    2023-02-02
  • python 装饰器带参数和不带参数步骤详解

    python 装饰器带参数和不带参数步骤详解

    装饰器是Python语言中一种特殊的语法,用于在不修改原函数代码的情况下,为函数添加额外的功能或修改函数的行为,这篇文章主要介绍了python装饰器带参数和不带参数的相关知识,需要的朋友可以参考下
    2024-05-05
  • python 读取摄像头数据并保存的实例

    python 读取摄像头数据并保存的实例

    今天小编就为大家分享一篇python 读取摄像头数据并保存的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-08-08
  • 使用Python3 poplib模块删除服务器多天前的邮件实现代码

    使用Python3 poplib模块删除服务器多天前的邮件实现代码

    这篇文章主要介绍了使用Python3 poplib模块删除多天前的邮件的实现代码,代码简单易懂,非常不错,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04

最新评论