pytorch配置双显卡方式,使用双显卡跑代码

 更新时间:2024年06月26日 09:12:53   作者:好好好好饭  
这篇文章主要介绍了pytorch配置双显卡方式,使用双显卡跑代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

项目场景

Linux系统,pytorch环境

问题描述

使用的服务器有两张显卡,感觉一张显卡跑代码比较慢,想配置两张显卡同时跑代码,只需要在你的代码中添加几行,就可以使用双显卡,亲测有效。

解决方案

提示:这里填写该问题的具体解决方案:

先看以下官方示例代码,插入添加的地方是需要我们添加在代码中的代码行

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os 
#######添加
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 这里输入你的GPU_id
 
# Parameters and DataLoaders
input_size = 5
output_size = 2
 
batch_size = 30
data_size = 100
#######添加
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
# Dummy DataSet
class RandomDataset(Dataset):
 
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
 
    def __getitem__(self, index):
        return self.data[index]
 
    def __len__(self):
        return self.len
 
rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)
 
# Simple Model
class Model(nn.Module):
    # Our model
 
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
 
    def forward(self, input):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())
 
        return output
################添加
# Create Model and DataParallel
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  model = nn.DataParallel(model)
model.to(device)
 
 
#Run the Model
for data in rand_loader:
    input = data.to(device)
    output = model(input)
    print("Outside: input size", input.size(),
          "output_size", output.size())

其中我将model = nn.DataParallel(model)修改为model = nn.DataParallel(model.cuda()),这一步直接参照网上修改的,因此这一步没有报错。

比如我自己在我代码中添加如下

from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import scipy.io as scio
 
 
from .base import TrainBase
from model.optimization import BertAdam
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
###############添加
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 这里输入你的GPU_id
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
class Trainer(TrainBase):
 
    def __init__(self,
                rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        self.run()
 
    def _init_model(self):
        self.logger.info("init model.")
        linear = False
        if self.args.hash_layer == "linear":
            linear = True
 
        self.logger.info("ViT+GPT!")
        HashModel = DCMHT
        self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
                            writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
####################################添加
        self.model= nn.DataParallel(self.model.cuda())
        if torch.cuda.device_count() >1:
            print("Lets use",torch.cuda.device_count(),"GPUs!")
        self.model.to(device)
 
        if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
            self.logger.info("load pretrained model.")
            self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
        
        self.model.float()
        self.optimizer = BertAdam([
                    {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
                    {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
                    {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
                    ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', 
                    b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
                    weight_decay=self.args.weight_decay, max_grad_norm=1.0)
                
        print(self.model)

添加以上代码后一般还会报如下错误

“AttributeError: ‘DataParallel’ object has no attribute ‘xxx’”

解决办法为先在dataparallel后的model调用module模块,然后再调用xxx

比如在上述我自己的代码中会报错

AttributeError: ‘DataParallel’ object has no attribute ‘clip’

解决办法:

是将model,修改为model.module.,后续报错大致相同,将你的代码中涉及到model.的地方修改为model.module.即可。

self.optimizer = BertAdam([
                    {'params': self.model.module.clip.parameters(), 'lr': self.args.clip_lr},
                    {'params': self.model.module.image_hash.parameters(), 'lr': self.args.lr},
                    {'params': self.model.module.text_hash.parameters(), 'lr': self.args.lr}
                    ], lr=self.args.lr, warmup=self.args.warmup_proportion, 

检查显卡使用情况

打开终端,在终端输入nvidia-smi命令可查看显卡使用情况

成功使用双显卡跑代码!

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 详解如何在Pandas中删除常量列

    详解如何在Pandas中删除常量列

    常数列不提供可变性,这意味着它们无助于区分不同的数据点,在许多机器学习模型中,这些列会引入冗余或不相关的数据,从而对性能产生负面影响,因此,通常必须删除常量列,所以本文我们将探索如何使用Python识别和删除Pandas DataFrame中的常量列,需要的朋友可以参考下
    2025-03-03
  • 关于Python参数解析器argparse的应用场景

    关于Python参数解析器argparse的应用场景

    这篇文章主要介绍了关于Python参数解析器argparse的应用场景,argparse 模块使编写用户友好的命令行界面变得容易,程序定义了所需的参数,而 argparse 将找出如何从 sys.argv 中解析这些参数,需要的朋友可以参考下
    2023-08-08
  • linux环境下安装python虚拟环境及注意事项

    linux环境下安装python虚拟环境及注意事项

    这篇文章主要介绍了linux环境下安装python虚拟环境,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-01-01
  • Python实现图片格式转换

    Python实现图片格式转换

    经常会遇到图片格式需要转换的情况,这篇文章主要为大家详细介绍了Python实现图片格式转换,文中示例代码介绍的非常详细、实用,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-08-08
  • python flask之模板继承方式

    python flask之模板继承方式

    这篇文章主要介绍了python flask之模板继承方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-03-03
  • tensorflow 1.X迁移至tensorflow2 的代码写法

    tensorflow 1.X迁移至tensorflow2 的代码写法

    本文主要介绍了tensorflow 1.X迁移至tensorflow2 的代码写法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-12-12
  • python寻找list中最大值、最小值并返回其所在位置的方法

    python寻找list中最大值、最小值并返回其所在位置的方法

    今天小编就为大家分享一篇python寻找list中最大值、最小值并返回其所在位置的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 在pandas中一次性删除dataframe的多个列方法

    在pandas中一次性删除dataframe的多个列方法

    下面小编就为大家分享一篇在pandas中一次性删除dataframe的多个列方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 关于pycharm找不到MySQLdb模块的解决方法

    关于pycharm找不到MySQLdb模块的解决方法

    MySQLdb是用于Python链接Mysql数据库的接口,它实现了Python数据库API规范V2.0,基于MySql C API上建立的,本文给大家介绍pycharm找不到MySQLdb模块解决方法,需要的朋友参考下吧
    2021-06-06
  • Python实现单词翻译功能

    Python实现单词翻译功能

    这篇文章主要为大家详细介绍了Python实现单词翻译功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-06-06

最新评论