pytorch深度神经网络入门准备自己的图片数据

 更新时间:2022年06月29日 17:31:24   作者:denny402  
这篇文章主要为大家介绍了pytorch深度神经网络入门准备自己的图片数据示例过程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

正文

图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

一、所有图片放在一个文件夹内

这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
    './mnist', train=False, download=True
)
print('test set:', len(mnist_test))
f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
    img_path="./mnist_test/"+str(i)+".jpg"
    io.imsave(img_path,img)
    f.write(img_path+' '+str(label)+'\n')
f.close()

经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

前期工作就装备好了,接着就进入正题了:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
    return Image.open(path).convert('RGB')
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    def __len__(self):
        return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
    grid = utils.make_grid(imgs)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(),batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

二、不同类别的图片放在不同的文件夹内

同样先准备数据,这里以flowers数据集为例

提取 链接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6 

花总共有五类,分别放在5个文件夹下。大致如下图:

我的路径是d:/flowers/.

数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                                            transform=transforms.Compose([
                                                transforms.Scale(256),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor()])
                                            )
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
    grid = utils.make_grid(imgs,nrow=5)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(), batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

以上就是pytorch深度神经网络入门准备自己的图片数据的详细内容,更多关于pytorch图片数据准备的资料请关注脚本之家其它相关文章!

相关文章

  • Python使用VIF实现检测多重共线性

    Python使用VIF实现检测多重共线性

    多重共线性是指多元回归模型中有两个或两个以上的自变量,它们之间具有高度的相关性,本文主要介绍了如何使用VIF实现检测多重共线性,需要的可以参考下
    2023-12-12
  • Python 的第三方调试库 ​​​pysnooper​​ 使用示例

    Python 的第三方调试库 ​​​pysnooper​​ 使用示例

    这篇文章主要介绍了Python 的第三方调试库 ​​​pysnooper​​ 使用示例的相关资料,需要的朋友可以参考下
    2023-02-02
  • python 读取视频,处理后,实时计算帧数fps的方法

    python 读取视频,处理后,实时计算帧数fps的方法

    今天小编就为大家分享一篇python 读取视频,处理后,实时计算帧数fps的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • Python批量获取并保存手机号归属地和运营商的示例

    Python批量获取并保存手机号归属地和运营商的示例

    这篇文章主要介绍了Python批量获取并保存手机号的归属地和运营商的示例,帮助大家更好的利用python处理数据,感兴趣的朋友可以了解下
    2020-10-10
  • Python使用遗传算法解决最大流问题

    Python使用遗传算法解决最大流问题

    这篇文章主要为大家详细介绍了Python使用遗传算法解决最大流问题,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-01-01
  • python爬虫请求头设置代码

    python爬虫请求头设置代码

    在本篇文章里小编给大家整理的是一篇关于python爬虫请求头如何设置内容,需要的朋友们可以学习下。
    2020-07-07
  • Python中作用域的深入讲解

    Python中作用域的深入讲解

    这篇文章主要给大家介绍了关于Python中作用域的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2018-12-12
  • 简单了解为什么python函数后有多个括号

    简单了解为什么python函数后有多个括号

    这篇文章主要介绍了简单了解为什么python函数后有多个括号,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • 基于python cut和qcut的用法及区别详解

    基于python cut和qcut的用法及区别详解

    今天小编就为大家分享一篇基于python cut和qcut的用法及区别详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • python标准库学习之sys模块详解

    python标准库学习之sys模块详解

    sys模块是最常用的和python解释器交互的模块,sys模块可供访问由解释器(interpreter)使用或维护的变量和与解释器进行交互的函数,下面这篇文章主要给大家介绍了关于python标准库学习之sys模块的相关资料,需要的朋友可以参考下
    2022-08-08

最新评论