PyTorch 解决Dataset和Dataloader遇到的问题

 更新时间:2020年01月08日 14:14:34   作者:xgbm_k  
今天小编就为大家分享一篇PyTorch 解决Dataset和Dataloader遇到的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

今天在使用PyTorch中Dataset遇到了一个问题。先看代码

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)

。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。

Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完

整代码如下:

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])

以上这篇PyTorch 解决Dataset和Dataloader遇到的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Selenium+BeautifulSoup+json获取Script标签内的json数据

    Selenium+BeautifulSoup+json获取Script标签内的json数据

    这篇文章主要介绍了Selenium+BeautifulSoup+json获取Script标签内的json数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • 手动安装python3.6的操作过程详解

    手动安装python3.6的操作过程详解

    这篇文章主要介绍了如何手动安装python3.6,本文给大家带来了安装步骤,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-01-01
  • Python反射机制案例超详细讲解

    Python反射机制案例超详细讲解

    反射就是通过字符串的形式,导入模块;通过字符串的形式,去模块寻找指定函数,并执行。利用字符串的形式去对象(模块)中操作(查找/获取/删除/添加)成员,一种基于字符串的事件驱动
    2022-09-09
  • Python 抓取微信公众号账号信息的方法

    Python 抓取微信公众号账号信息的方法

    搜狗微信搜索提供两种类型的关键词搜索,一种是搜索公众号文章内容,另一种是直接搜索微信公众号。这篇文章主要介绍了Python 抓取微信公众号账号信息,需要的朋友可以参考下
    2019-06-06
  • 用pyqt5 给按钮设置图标和css样式的方法

    用pyqt5 给按钮设置图标和css样式的方法

    今天小编就为大家分享一篇用pyqt5 给按钮设置图标和css样式的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • python利用socketserver实现并发套接字功能

    python利用socketserver实现并发套接字功能

    这篇文章主要为大家详细介绍了python利用socketserver实现并发套接字功能,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-01-01
  • Python3中的map函数调用后内存释放问题

    Python3中的map函数调用后内存释放问题

    这篇文章主要介绍了Python3中的map函数调用后内存释放问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • Python实现一个优先级队列的方法

    Python实现一个优先级队列的方法

    这篇文章主要介绍了Python实现一个优先级队列的方法,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • Jmeter并发执行Python 脚本的完整流程

    Jmeter并发执行Python 脚本的完整流程

    这篇文章主要介绍了Jmeter并发执行 Python 脚本的问题详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09
  • 利用python实现在微信群刷屏的方法

    利用python实现在微信群刷屏的方法

    今天小编就为大家分享一篇利用python实现在微信群刷屏的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02

最新评论