PyTorch 分布式训练的实现

 更新时间:2025年05月15日 10:27:48   作者:handsomeboysk  
本文主要介绍了PyTorch 分布式训练的实现,包括数据并行、模型并行、混合并行和流水线并行等模式,感兴趣的可以了解一下

在深度学习模型变得日益庞大之后,单个 GPU 的显存已经无法满足高效训练的需求。此时,“分布式训练(Distributed Training)”技术应运而生,成为加速训练的重要手段。

本文将围绕以下三行典型的 PyTorch 分布式训练代码进行详细解析,并扩展介绍分布式训练的核心概念和实践方法:

local_rank = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
global_rank = int(os.getenv('RANK', -1))
world_size = int(os.getenv('WORLD_SIZE', 1))

一、什么是分布式训练?

分布式训练是指将模型训练过程划分到多个计算设备(通常是多个 GPU,甚至是多台机器)上进行协同处理,目标是加速训练速度扩展模型容量

分布式训练可以分为以下几种模式:

  • 数据并行(Data Parallelism):每个 GPU 处理不同的数据子集,同步梯度。
  • 模型并行(Model Parallelism):将模型拆成多个部分,分别部署到不同的 GPU。
  • 混合并行(Hybrid Parallelism):结合模型并行和数据并行。
  • 流水线并行(Pipeline Parallelism):按层切分模型,不同 GPU 处理不同阶段。

二、理解分布式训练的核心概念

1. World Size(全局进程数)

world_size = int(os.getenv('WORLD_SIZE', 1))
  • 含义:分布式训练中,所有参与训练的进程总数。通常等于 GPU 总数。
  • 作用:用于初始化进程组(torch.distributed.init_process_group()),让每个进程知道集群的规模。

比如你有两台机器,每台 4 块 GPU,那么 world_size = 8。

2. Rank(全局进程编号)

global_rank = int(os.getenv('RANK', -1))
  • 含义:标识每个训练进程在所有进程中的唯一编号(从 0 开始)。
  • 作用:常用于标记主节点(rank == 0),控制日志输出、模型保存等。

例如:

  • rank=0:负责打印日志、保存模型
  • rank=1,2,…:只做训练

3. Local Rank(本地进程编号)

local_rank = int(os.getenv('LOCAL_RANK', -1))
  • 含义:当前训练进程在本地机器上的 GPU 编号。一般与 CUDA_VISIBLE_DEVICES 配合使用。

  • 作用:用于指定该进程应该使用哪块 GPU,如:

    torch.cuda.set_device(local_rank)
    

三、环境变量的设置方式

这些环境变量通常由 分布式启动器 设置。例如使用 torchrun

torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=12345 train.py

torchrun 会自动为每个进程设置:

  • LOCAL_RANK
  • RANK
  • WORLD_SIZE

也可以手动导出:

export WORLD_SIZE=8
export RANK=3
export LOCAL_RANK=3

四、分布式训练初始化流程(PyTorch 示例)

在 PyTorch 中,典型的初始化流程如下:

import os
import torch
import torch.distributed as dist

def setup_distributed():
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    global_rank = int(os.getenv('RANK', -1))
    world_size = int(os.getenv('WORLD_SIZE', 1))

    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend='nccl',  # GPU 用 nccl,CPU 用 gloo
        init_method='env://',
        world_size=world_size,
        rank=global_rank
    )
  • init_method='env://':表示从环境变量中读取初始化信息。
  • nccl 是 NVIDIA 的高性能通信库,支持 GPU 间高速通信。

五、分布式训练的代码结构

使用 PyTorch 实现分布式训练的基本框架:

def train():
    setup_distributed()

    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=64)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        for batch in dataloader:
            # 正常训练流程

六、Elastic Training(弹性训练)

值得注意的是,示例代码中注释中提到的链接:https://pytorch.org/docs/stable/elastic/run.html

这是指 PyTorch 的 弹性分布式训练(Elastic Training),支持在训练过程中动态增加或移除节点,具备高容错性。

  • 工具:torch.distributed.elastic
  • 启动命令:torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py

该特性对于大规模、长时间训练任务非常重要。

七、总结

变量名含义来源典型用途
WORLD_SIZE全局进程数量torchrun 设置初始化进程组,全局通信
RANK当前进程的全局编号torchrun 设置控制主节点行为
LOCAL_RANK当前进程在本地的 GPU 编号torchrun 设置显卡绑定:torch.cuda.set_device

这三行代码虽然简单,却是 PyTorch 分布式训练的入口。理解它们,就理解了 PyTorch 在分布式场景下的通信机制和训练框架。

如果你想要进一步深入了解 PyTorch 分布式训练,推荐官方文档:

到此这篇关于PyTorch 分布式训练的实现的文章就介绍到这了,更多相关PyTorch 分布式训练内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家! 

相关文章

  • Python中使用HTMLParser解析html实例

    Python中使用HTMLParser解析html实例

    这篇文章主要介绍了Python中使用HTMLParser解析html实例,本文直接给出使用示例,并总结出HTMLParser含有的方法分为两类,一类是需要显式调用的,而另一类不需显示调用,需要的朋友可以参考下
    2015-02-02
  • Python使用LSTM实现销售额预测详解

    Python使用LSTM实现销售额预测详解

    大家经常会遇到一些需要预测的场景,比如预测品牌销售额,预测产品销量。本文给大家分享一波使用 LSTM 进行端到端时间序列预测的完整代码和详细解释,需要的可以参考一下
    2022-07-07
  • Python Django查询集的延迟加载特性详解

    Python Django查询集的延迟加载特性详解

    在 Django 的开发过程中,查询集(QuerySet)是我们与数据库进行交互的重要工具,本文将深入探讨 Django 查询集的延迟加载特性,帮助新手理解其工作原理及优缺点,提供一些实用的代码示例来展示延迟加载如何在实际项目中使用,需要的朋友可以参考下
    2024-10-10
  • Python设计模式之代理模式实例详解

    Python设计模式之代理模式实例详解

    这篇文章主要介绍了Python设计模式之代理模式,结合实例形式较为详细的分析了代理模式的概念、原理及Python定义、使用代理模式相关操作技巧,需要的朋友可以参考下
    2019-01-01
  • Python中self关键字的用法解析

    Python中self关键字的用法解析

    在Python中,self是一个经常出现的关键字,特别是在类定义中的方法,这篇文章主要和大家self的作用和用法,希望可以帮助大家更好地理解为什么需要它以及如何正确使用它
    2023-11-11
  • python迭代器的使用方法实例

    python迭代器的使用方法实例

    这篇文章主要介绍了python迭代器的使用方法,代码很简单,大家可以参考使用
    2013-11-11
  • Python编程判断一个正整数是否为素数的方法

    Python编程判断一个正整数是否为素数的方法

    这篇文章主要介绍了Python编程判断一个正整数是否为素数的方法,涉及Python数学运算相关操作技巧,需要的朋友可以参考下
    2017-04-04
  • python实现单链表的方法示例

    python实现单链表的方法示例

    这篇文章主要给大家介绍了关于python实现单链表的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • 对于Python异常处理慎用“except:pass”建议

    对于Python异常处理慎用“except:pass”建议

    这篇文章主要介绍了对于Python异常处理方法的建议,摘选自StackOverflow上的热门问题的回答,阐述了except:pass的使用时需要注意的地方,需要的朋友可以参考下
    2015-04-04
  • python matplotlib如何给图中的点加标签

    python matplotlib如何给图中的点加标签

    这篇文章主要介绍了python matplotlib给图中的点加标签,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11

最新评论