解决pytorch model代码内tensor device不一致的问题

 更新时间:2023年07月04日 16:34:38   作者:_LvP  
这篇文章主要介绍了pytorch model代码内tensor device不一致的问题,本文给大家分享完美解决方案,对pytorch tensor device不一致问题解决方案感兴趣的朋友跟随小编一起看看吧

pytorch model代码内tensor device不一致的问题

在编写一段处理两个tensor的代码如下,需要在forward函数内编写函数创建一个新的tensor进行索引的掩码计算

# todo(liang)空间交换
def compute_sim_and_swap(t1, t2, threshold=0.7):
     n, c, h, w = t1.shape
     sim = torch.nn.functional.cosine_similarity(t1, t2, dim=1) # n, h, w
     sim = sim.unsqueeze(0) # c, n, h, w
     expand_tensor = sim.clone()
     # 使用拼接构建相同的维度
     for _ in range(c-1): # c, n, h, w
         sim = torch.cat([sim, expand_tensor], dim=0)
     sim = sim.permute(1, 0, 2, 3) # n, c, h, w
     # 创建逻辑掩码,小于 threshold 的将掩码变为 True 用于交换
     mask = sim < threshold
     indices = torch.rand(mask.shape) < 0.5
     t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
     return t1, t2

这段代码报了这个错误

File "xxx/network.py", line 347, in compute_sim_and_swap
t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

统一下进行掩码计算的张量的设备即可

device = mask.Device
indices = indices.to(device)

PyTorch 多GPU使用torch.nn.DataParallel训练参数不一致问题

在多GPU训练时,遇到了下述的错误:

1. Expected tensor for argument  1 'input' to have the same device as tensor for argument  2 'weight'; but device 0 does not equal 1 
2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

造成这个错误的可能性有挺多,总起来是模型、输入、模型内参数不在一个GPU上。本人是在调试RandLA-Net pytorch源码,希望使用双GPU训练,经过尝试解决这个问题,此处做一个记录,希望给后来人一个提醒。经过调试,发现报错的地方主要是在数据拼接的时候,即一个数据在GPU0上,一个数据在GPU1上,这就会出现错误,相关代码如下:

return torch.cat((
            self.mlp(concat),
            features.expand(B, -1, N, K)
        ), dim=-3)

上述代码中,必须保证self.mlp(concat)与features.expand(B, -1, N, K)在同一个GPU中。在多GPU运算时,features(此时是输入变量)有可能放在任何一个GPU中,因此此处在拼接前,获取一下features的GPU,然后将concat放入相应的GPU中,再进行数据拼接就可以了,代码如下:

device = features.device
concat = concat.to(device)
return torch.cat((
            self.mlp(concat),
            features.expand(B, -1, N, K)
        ), dim=-3)

该源码中默认状态下device是一个固定的值,在多GPU训练状态下就会报错,代码中还有几处数据融合,大家可以依据上述思路做修改。此外该源码中由于把device的值写死了,训练好的模型也必须在相应的GPU中做推理,如在cuda0中训练的模型如果在cuda1中推理就会报错,各位可以依据此思路对源码做相应的修改。如果修改有困难,可以私信我,我可以把相关修改后的源码分享。

到此这篇关于pytorch model代码内tensor device不一致的问题的文章就介绍到这了,更多相关pytorch tensor device不一致内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python实现与redis交互操作详解

    python实现与redis交互操作详解

    这篇文章主要介绍了python实现与redis交互操作,结合实例形式分析了Python Redis模块的安装、导入、连接与简单操作相关实现技巧,需要的朋友可以参考下
    2020-04-04
  • 提高python代码可读性利器pycodestyle使用详解

    提高python代码可读性利器pycodestyle使用详解

    鉴于 Python 在数据科学中的流行,我将深入研究 pycodestyle 的使用方法,以提高 Python 代码的质量和可读性。如果你想提升代码质量,欢迎收藏学习,有所收获,点赞支持
    2021-11-11
  • python字符串对其居中显示的方法

    python字符串对其居中显示的方法

    这篇文章主要介绍了python字符串对其居中显示的方法,涉及Python打印输出显示的相关技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-07-07
  • 两个很实用的Python装饰器详解

    两个很实用的Python装饰器详解

    这篇文章主要为大家介绍了Python的装饰器,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助,希望能够给你带来帮助
    2021-11-11
  • Python pandas找出、删除重复的数据实例

    Python pandas找出、删除重复的数据实例

    在面试中很可能遇到给定一个含有重复元素的列表,删除其中重复的元素,下面这篇文章主要给大家介绍了关于Python pandas找出、删除重复数据的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-07-07
  • python 对dataframe下面的值进行大规模赋值方法

    python 对dataframe下面的值进行大规模赋值方法

    今天小编就为大家分享一篇python 对dataframe下面的值进行大规模赋值方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python操作多维数组输出和矩阵运算示例

    Python操作多维数组输出和矩阵运算示例

    这篇文章主要介绍了Python操作多维数组输出和矩阵运算,结合实例形式分析了Python多维数组的生成、打印输出及矩阵运算相关操作技巧,需要的朋友可以参考下
    2019-11-11
  • python2和python3的输入和输出区别介绍

    python2和python3的输入和输出区别介绍

    这篇文章主要介绍了python2和python3的输入和输出区别介绍,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-11-11
  • Python通过DOM和SAX方式解析XML的应用实例分享

    Python通过DOM和SAX方式解析XML的应用实例分享

    这篇文章主要介绍了Python通过DOM和SAX方式解析XML的应用实例分享,针对这两种解析方式Python都有相关的模块可供使用,需要的朋友可以参考下
    2015-11-11
  • Python增强下git那长长的指令详解

    Python增强下git那长长的指令详解

    这篇文章主要介绍了Python增强下git那长长的指令 ,在开发中用到的代码目录结构,本文也给大家详细讲解,需要的朋友可以参考下
    2021-09-09

最新评论