Pytorch 实现数据集自定义读取

 更新时间:2020年01月18日 17:20:27   作者:_寒潭雁影  
今天小编就为大家分享一篇Pytorch 实现数据集自定义读取,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python深度学习albumentations数据增强库

    Python深度学习albumentations数据增强库

    下面开始albumenations的正式介绍,在这里我强烈建议英语基础还好的读者去官方网站跟着教程一步步学习,而这里的内容主要是我自己的一个总结以及方便英语能力较弱的读者学习
    2021-09-09
  • Python自动化开发学习之三级菜单制作

    Python自动化开发学习之三级菜单制作

    这篇文章主要为大家详细介绍了Python自动化开发学习之三级菜单的制作方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-07-07
  • python多继承(钻石继承)问题和解决方法简单示例

    python多继承(钻石继承)问题和解决方法简单示例

    这篇文章主要介绍了python多继承(钻石继承)问题和解决方法,结合实例形式分析了Python多继承调用父类初始化方法相关操作技巧,需要的朋友可以参考下
    2019-10-10
  • 如何在Python中将字符串转换为数组详解

    如何在Python中将字符串转换为数组详解

    最近在用Python,做一个小脚本,有个操作就是要把内容换成数组对象再进行相关操作,下面这篇文章主要给大家介绍了关于如何在Python中将字符串转换为数组的相关资料,需要的朋友可以参考下
    2022-12-12
  • python基于itchat模块实现微信防撤回

    python基于itchat模块实现微信防撤回

    这篇文章主要为大家详细介绍了python实现微信防撤回,基于itchat模块,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-04-04
  • Python异常之常见的Bug类型解决方法

    Python异常之常见的Bug类型解决方法

    这篇文章主要介绍了Python异常之常见的Bug类型解决方法,主要分享一些粗心导致和知识不熟练导致的语法错误以及被迫掉坑等内容,文章介绍非常详细需要的小伙伴可以参考一下
    2022-03-03
  • Python使用装饰器模拟用户登陆验证功能示例

    Python使用装饰器模拟用户登陆验证功能示例

    这篇文章主要介绍了Python使用装饰器模拟用户登陆验证功能,结合登录验证实例形式分析了装饰器的简单使用技巧,需要的朋友可以参考下
    2018-08-08
  • 使用python将csv数据导入mysql数据库

    使用python将csv数据导入mysql数据库

    这篇文章主要为大家详细介绍了如何使用python将csv数据导入mysql数据库,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2024-05-05
  • 在Django下测试与调试REST API的方法详解

    在Django下测试与调试REST API的方法详解

    今天小编就为大家分享一篇在Django下测试与调试REST API的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python Pandas学习之Pandas数据结构详解

    Python Pandas学习之Pandas数据结构详解

    Pandas中一共有三种数据结构,分别为:Series、DataFrame和MultiIndex(老版本中叫Panel )。其中Series是一维数据结构,DataFrame是二维的表格型数据结构,MultiIndex是三维的数据结构。本文将详细为大家讲解这三个数据结构,需要的可以参考一下
    2022-02-02

最新评论