pytorch实现建立自己的数据集(以mnist为例)

 更新时间:2020年01月18日 16:06:16   作者:sjtu_leexx  
今天小编就为大家分享一篇pytorch实现建立自己的数据集(以mnist为例),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

  LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 详解requirements.txt的生成和安装

    详解requirements.txt的生成和安装

    本文主要介绍了详解requirements.txt的生成和安装,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • python优化测试稳定性的失败重试工具pytest-rerunfailures详解

    python优化测试稳定性的失败重试工具pytest-rerunfailures详解

    笔者在执行自动化测试用例时,会发现有时候用例失败并非代码问题,而是由于服务正在发版,导致请求失败,从而降低了自动化用例的稳定性,那该如何增加失败重试机制呢?带着问题我们一起探索
    2023-10-10
  • 一文教会你用Python绘制动态可视化图表

    一文教会你用Python绘制动态可视化图表

    数据可视化是数据科学中关键的一步,下面这篇文章主要给大家介绍了关于如何利用Python绘制动态可视化图表的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • 解决python flask中config配置管理的问题

    解决python flask中config配置管理的问题

    今天小编就为大家分享一篇解决python flask中config配置管理的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • python爬虫实战之最简单的网页爬虫教程

    python爬虫实战之最简单的网页爬虫教程

    在我们日常上网浏览网页的时候,经常会看到一些好看的图片,我们就希望把这些图片保存下载,或者用户用来做桌面壁纸,或者用来做设计的素材。下面这篇文章就来给大家介绍了关于利用python实现最简单的网页爬虫的相关资料,需要的朋友可以参考借鉴,下面来一起看看吧。
    2017-08-08
  • Python如何实现感知器的逻辑电路

    Python如何实现感知器的逻辑电路

    这篇文章主要介绍了Python如何实现感知器的逻辑电路,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-12-12
  • Python判断以什么结尾以什么开头的实例

    Python判断以什么结尾以什么开头的实例

    今天小编就为大家分享一篇Python判断以什么结尾以什么开头的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python利用pip安装tar.gz格式的离线资源包

    Python利用pip安装tar.gz格式的离线资源包

    这篇文章主要给大家介绍了关于Python利用pip安装tar.gz格式的离线资源包的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • python rsa 加密解密

    python rsa 加密解密

    本篇文章主要介绍了python rsa加密解密 (编解码,base64编解码)的相关知识。具有很好的参考价值,下面跟着小编一起来看下吧
    2017-03-03
  • Django ORM数据库操作处理全面指南

    Django ORM数据库操作处理全面指南

    本文深度探讨Django ORM的概念、基础使用、进阶操作以及详细解析在实际使用中如何处理数据库操作,同时,我们还讨论了模型深入理解,如何进行CRUD操作,并且深化理解到数据库迁移等高级主题
    2023-09-09

最新评论