解决pytorch model代码内tensor device不一致的问题

 更新时间:2023年07月04日 16:34:38   作者:_LvP  
这篇文章主要介绍了pytorch model代码内tensor device不一致的问题,本文给大家分享完美解决方案,对pytorch tensor device不一致问题解决方案感兴趣的朋友跟随小编一起看看吧

pytorch model代码内tensor device不一致的问题

在编写一段处理两个tensor的代码如下,需要在forward函数内编写函数创建一个新的tensor进行索引的掩码计算

# todo(liang)空间交换
def compute_sim_and_swap(t1, t2, threshold=0.7):
     n, c, h, w = t1.shape
     sim = torch.nn.functional.cosine_similarity(t1, t2, dim=1) # n, h, w
     sim = sim.unsqueeze(0) # c, n, h, w
     expand_tensor = sim.clone()
     # 使用拼接构建相同的维度
     for _ in range(c-1): # c, n, h, w
         sim = torch.cat([sim, expand_tensor], dim=0)
     sim = sim.permute(1, 0, 2, 3) # n, c, h, w
     # 创建逻辑掩码,小于 threshold 的将掩码变为 True 用于交换
     mask = sim < threshold
     indices = torch.rand(mask.shape) < 0.5
     t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
     return t1, t2

这段代码报了这个错误

File "xxx/network.py", line 347, in compute_sim_and_swap
t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

统一下进行掩码计算的张量的设备即可

device = mask.Device
indices = indices.to(device)

PyTorch 多GPU使用torch.nn.DataParallel训练参数不一致问题

在多GPU训练时,遇到了下述的错误:

1. Expected tensor for argument  1 'input' to have the same device as tensor for argument  2 'weight'; but device 0 does not equal 1 
2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

造成这个错误的可能性有挺多,总起来是模型、输入、模型内参数不在一个GPU上。本人是在调试RandLA-Net pytorch源码,希望使用双GPU训练,经过尝试解决这个问题,此处做一个记录,希望给后来人一个提醒。经过调试,发现报错的地方主要是在数据拼接的时候,即一个数据在GPU0上,一个数据在GPU1上,这就会出现错误,相关代码如下:

return torch.cat((
            self.mlp(concat),
            features.expand(B, -1, N, K)
        ), dim=-3)

上述代码中,必须保证self.mlp(concat)与features.expand(B, -1, N, K)在同一个GPU中。在多GPU运算时,features(此时是输入变量)有可能放在任何一个GPU中,因此此处在拼接前,获取一下features的GPU,然后将concat放入相应的GPU中,再进行数据拼接就可以了,代码如下:

device = features.device
concat = concat.to(device)
return torch.cat((
            self.mlp(concat),
            features.expand(B, -1, N, K)
        ), dim=-3)

该源码中默认状态下device是一个固定的值,在多GPU训练状态下就会报错,代码中还有几处数据融合,大家可以依据上述思路做修改。此外该源码中由于把device的值写死了,训练好的模型也必须在相应的GPU中做推理,如在cuda0中训练的模型如果在cuda1中推理就会报错,各位可以依据此思路对源码做相应的修改。如果修改有困难,可以私信我,我可以把相关修改后的源码分享。

到此这篇关于pytorch model代码内tensor device不一致的问题的文章就介绍到这了,更多相关pytorch tensor device不一致内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 将Python脚本打包成MACOSAPP程序过程

    将Python脚本打包成MACOSAPP程序过程

    我们编写python程序时,有时候需要想将python脚本转成可执行的程序或者app,可以直接通过双击执行即可,像Windows上可以将其通过工具转换成exe程序,那么在MACOS下我们可以将其打包成MACOS APP程序
    2021-09-09
  • django如何实现用户的注册、登录、注销功能

    django如何实现用户的注册、登录、注销功能

    本文详细介绍了创建Django项目的步骤,包括配置数据库、编写用户模型、创建迁移文件、编写表单校验、编写前端页面、编写视图类、编写路由、使用Django自带的管理后台以及具体的文件结构,通过这些步骤,可以实现一个基本的Django项目
    2025-01-01
  • python画微信表情符的实例代码

    python画微信表情符的实例代码

    这篇文章主要介绍了python画微信表情的实例代码,代码简单易懂,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-10-10
  • python实现不同数据库间数据同步功能

    python实现不同数据库间数据同步功能

    这篇文章主要介绍了python实现不同数据库间数据同步功能,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • Python web实战教程之Django文件上传和处理详解

    Python web实战教程之Django文件上传和处理详解

    Django和Flask都是Python的Web框架,用于开发Web应用程序,这篇文章主要给大家介绍了关于Python web实战教程之Django文件上传和处理的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2023-12-12
  • python3.6.4安装opencv3.4.2的实现

    python3.6.4安装opencv3.4.2的实现

    这篇文章主要介绍了python3.6.4安装opencv3.4.2的实现方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-10-10
  • 详解Python编程中基本的数学计算使用

    详解Python编程中基本的数学计算使用

    这篇文章主要介绍了Python编程中基本的数学计算使用,其中重点讲了除法运算及相关division模块的使用,需要的朋友可以参考下
    2016-02-02
  • python pyheatmap包绘制热力图

    python pyheatmap包绘制热力图

    这篇文章主要为大家详细介绍了python pyheatmap包绘制热力图,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • 谈谈如何手动释放Python的内存

    谈谈如何手动释放Python的内存

    Python不会自动清理这些内存,这篇文章主要介绍了谈谈如何手动释放Python的内存,具有一定的参考价值,感兴趣的小伙伴们可以参考一下。
    2016-12-12
  • python绘制多个子图的实例

    python绘制多个子图的实例

    今天小编就为大家分享一篇python绘制多个子图的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07

最新评论