解决torch.to(device)是否赋值的坑

 更新时间:2024年06月27日 14:45:39   作者:不会卷积  
这篇文章主要介绍了解决torch.to(device)是否赋值的坑,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.to(device)是否赋值的坑

在我们用GPU跑程序时,需要在程序中把变量和模型放到GPU里面。

有一些坑需要注意,本文用RNN模型实例

首先,定义device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

对于变量,需要进行赋值操作才能真正转到GPU上:

all_input_batch=all_input_batch.to(device)

对于模型,不需要进行赋值:

 model = TextRNN()
 model.to(device)

对模型进行to(device),还有一种方法,就是在定义模型的时候全部对模型网络参数to(device),这样就可以不需要model.to(device)这句话。

class TextRNN(nn.Module):

    def __init__(self):
        super(TextRNN, self).__init__()
        #self.cnt = 0
        self.C = nn.Embedding(n_class, embedding_dim=emb_size,device=device)
        self.rnn = nn.RNN(input_size=emb_size, hidden_size=n_hidden,device=device)
        self.W = nn.Linear(n_hidden, n_class, bias=False,device=device)
        self.b = nn.Parameter(torch.ones([n_class])).to(device)


    def forward(self, X):
        X = self.C(X)
        #print(X.is_cuda)
        X = X.transpose(0, 1) # X : [n_step, batch_size, embeding size]
        outputs, hidden = self.rnn(X)
        # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]
        # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        outputs = outputs[-1] # [batch_size, num_directions(=1) * n_hidden]
        model = self.W(outputs) + self.b # model : [batch_size, n_class]
        return model

pytorch中model=model.to(device)用法

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上

torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上

确保对输入的tensors调用input = input.to(device)方法。

map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python3批量移动指定文件到指定文件夹方法示例

    Python3批量移动指定文件到指定文件夹方法示例

    这篇文章主要给大家介绍了关于Python3批量移动指定文件到指定文件夹的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Python3具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • Python写的一个简单DNS服务器实例

    Python写的一个简单DNS服务器实例

    这篇文章主要介绍了Python写的一个简单DNS服务器实例,需要的朋友可以参考下
    2014-06-06
  • pytorch图片分割原理分析

    pytorch图片分割原理分析

    自Transformer模型被应用于计算机视觉领域后,图像分割技术得到了进一步的发展,但图像分割操作复杂,特别是对张量的处理,涉及多种变换方法,其中,view/reshape用于改变数据形状,而permute/transpose用于改变数据的维度顺序
    2024-10-10
  • 基于python win32setpixel api 实现计算机图形学相关操作(推荐)

    基于python win32setpixel api 实现计算机图形学相关操作(推荐)

    这篇文章主要介绍了基于python win32setpixel api 实现计算机图形学相关操作,这次的主要分为2个主要模块,一个是实现画线,画圆的算法,还有填充的算法,以及裁剪的算法,需要的朋友可以参考下
    2021-12-12
  • Python实现图片和视频的相互转换

    Python实现图片和视频的相互转换

    有时候我们需要把很多的图片合成视频,或者说自己写一个脚本去加快或者放慢视频;也有时候需要把视频裁剪成图片,进行后续操作。这篇文章就将为大家介绍如何通过Python实现图片和视频的相互转换,需要的可以参考一下
    2021-12-12
  • Python利用神经网络解决非线性回归问题实例详解

    Python利用神经网络解决非线性回归问题实例详解

    这篇文章主要介绍了Python利用神经网络解决非线性回归问题,结合实例形式详细分析了Python使用神经网络解决非线性回归问题的相关原理与实现技巧,需要的朋友可以参考下
    2019-07-07
  • Python调用ollama本地大模型进行批量识别PDF

    Python调用ollama本地大模型进行批量识别PDF

    现在市场上有很多PDF文件的识别,然而随着AI的兴起,本地大模型的部署,这些成为一种很方便的方法,本文我们就来看看Python如何调用ollama本地大模型进行PDF相关操作吧
    2025-03-03
  • Notepad 轻量级文本编辑器的安装及基本使用

    Notepad 轻量级文本编辑器的安装及基本使用

    notepad–是一个国产跨平台、轻量级的文本编辑器,是替换notepad++的一种选择,notepad特点支持Window/Mac/Linux操作系统平台,支持其他notepad竞品的常用功能,这篇文章给大家介绍Notepad 轻量级文本编辑器的安装及基本使用,感兴趣的朋友一起看看吧
    2024-01-01
  • tensorflow入门之训练简单的神经网络方法

    tensorflow入门之训练简单的神经网络方法

    本篇文章主要介绍了tensorflow入门之训练简单的神经网络方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-02-02
  • python 实现屏幕录制示例

    python 实现屏幕录制示例

    今天小编就为大家分享一篇python 实现屏幕录制示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12

最新评论