pytorch利用Dataset读取数据报错问题及解决
报错点
如下:
Traceback (most recent call last):
File "read_data.py", line 100, in <module>
for i , (image,seg) in enumerate(train_loader):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 819, in __next__
return self._process_data(data)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
data.reraise()
File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 369, in reraise
raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "read_data.py", line 91, in __getitem__
])(img)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 61, in __call__
img = t(img)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 238, in __call__
return F.center_crop(img, self.size)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 374, in center_crop
w, h = img.size
TypeError: 'int' object is not iterable
原来我其实没有注意Datset与PIL下面的Image的关系:
def __getitem__(self, index):
img = cv2.imread(self.image_name[index],cv2.COLOR_BGR2RGB)
#img = np.transpose(img,(2,1,0))
img = cv2.resize(img,(self.size,self.size))
seg = cv2.imread(self.image_seg[index],cv2.COLOR_BGR2RGB)
seg = cv2.resize(seg,(self.size,self.size) )
seg = convert_from_color_segmentation(seg)
#seg = torch.from_numpy(seg)
if self.transform is not None:
img = self.transform(img)
return img , seg报错中清晰提及这个问题
我突然反应过来,是自己的读取数据错误了:
应该为:
def __getitem__(self, index):
#img = cv2.imread(self.image_name[index],cv2.COLOR_BGR2RGB)
img = Image.open(self.image_name[index])
#img = np.transpose(img,(2,1,0))
#img = cv2.resize(img,(self.size,self.size))
seg = cv2.imread(self.image_seg[index],cv2.COLOR_BGR2RGB)
seg = cv2.resize(seg,(self.size,self.size) )
seg = convert_from_color_segmentation(seg)
#seg = torch.from_numpy(seg)
if self.transform is not None:
img = self.transform(img)
return img , seg测试打印数据
完美解决
transform = transforms.Compose([transforms.Resize((300,300)),transforms.RandomCrop((224,224)),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
#transform = transforms.Compose([
# transforms.CenterCrop((278,278)),transforms.Resize((224,224)),transforms.ToTensor()
# ])
train_data = GetParasetData(size=224,train=True,transform=transform)
train_loader = DataLoader(train_data,batch_size=64,shuffle=True,num_workers=2)
for i , (image,seg) in enumerate(train_loader):
print(image.shape,seg.shape)
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
tensorflow模型的save与restore,及checkpoint中读取变量方式
这篇文章主要介绍了tensorflow模型的save与restore,及checkpoint中读取变量方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-05-05
Python利用memory_profiler查看内存占用情况
memory_profiler是第三方模块,用于监视进程的内存消耗以及python程序内存消耗的逐行分析。本文将利用memory_profiler查看代码运行占用内存情况,感兴趣的可以了解一下2022-06-06
python使用requests.post方法传递form-data类型的Excel数据的示例代码
这篇文章介绍了python使用requests.post方法传递form-data类型的Excel数据的示例代码,某些post接口,需要发送multipart/form-data类型的数据,如何使用python requests来模拟这种类型的请求发送呢?补充讲解了python使用requests post请求发送form-data类型数据,一起看看吧2024-01-01
Python 使用folium绘制leaflet地图的实现方法
今天小编就为大家分享一篇Python 使用folium绘制leaflet地图的实现方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2019-07-07


最新评论