人工智能学习pyTorch自建数据集及可视化结果实现过程

 更新时间:2021年11月11日 14:16:57   作者:Swayzzu  
这篇文章主要为大家介绍了人工智能学习pyTorch自建数据集及可视化结果的实现过程,有需要的朋友可以借鉴参考下,希望能够有所帮助

一、自定义数据集

现有数据如下:

5个文件夹,每个文件夹是神奇宝贝的一种。

每个图片形状、大小、格式不一。

我们训练CNN的时候需要的是tensor类型的数据,因此需要将所有的图片进行下列转换:

1.对文件夹编号,进行映射,比如妙蛙种子文件夹编号0,皮卡丘编号1等。

2.对文件夹中所有图片,进行编号的对应,这个就是标签。并保存为一个csv文件。

3.图片信息获取:分为train,val,test

4.处理图片,使其成为torch可以处理的类型

1.文件夹映射

前半部分为文件夹的映射。我们希望传入数据的时候直接传入文件夹的名字,而文件夹所在的路径就是py文件所在的路径,因此这样可以直接读取。对于路径的操作使用os.path.join进行。

2.图片对应标签

输入的filename,就是我们将图片和标签信息存储的文件。

使用glob.glob方法,可以轻松调取路径下的所有指定类型的文件。

将名字和标签对应好后,通过csv.writer,可以将信息以csv格式写入新文件。

以上是保存的部分,在这个函数中,我们还要重新读取一下这个文件,因为要在这个类中获得最终的图片,以及标签,并且返回。

3.训练及测试数据分割

这里是第一步的图片的后半部分,导入了图片之后,对其进行分割,这里是按照训练、交叉验证、测试,分别是0.6,0.2,0.2进行分割的。

分割完毕后的self.images, self.labels,就可以拿来进行tensor相关的处理了。

4.数据处理

上面几步是准备工作,接下来定义的__getitem__是为了能够使train_loader = DataLoader()这一语句实现。在这里面直接将数据进行我们希望进行的转换。比如大小、旋转、裁剪等。

最后返回处理好的图片,以及tensor化的标签。

另外,还需要定义一个__len__,使得我们可以获得数据集长度。

二、ResNet处理

我们要用ResNet对图片进行处理,因此其中的参数需要进行一定的修改。

主要的修改部分是ResNet18之中的resblock模块。因为我们希望输入的是3通道,224*224的图片,因此在这里对通道,步长进行一定的修改,并进行测试,成功之后便可以进行训练了。

三、训练及可视化

1.数据集导入

同时把GPU设备相关代码准备好,并且由于需要可视化,因此先实例化visdom,并且在终端上输入python -m visdom.server,打开visdom监视终端。

2.测试函数

先把模式改为eval(),接下来就是通过model,去训练测试集,得到标签,并统计正确率。

3.训练过程及可视化

和之前的一样,还是先实例化一个优化器,选择损失函数模式,实例化ResNet18,然后进行训练。

在这里由于要展示,因此先对损失值,交叉验证分数分别设置一个初始的线,通过append的方法,画出我们的损失曲线,以及交叉验证分数曲线。

通过torch.save方法存储我们的最优解。

最后通过把存储好的最优解调用起来,使用测试集,来测试最终的效果。

最终获得的交叉验证准确率89%,测试集准确率88%,损失值及交叉验证结果的图像如下:

以上就是人工智能学习pyTorch自建数据集及可视化结果实现过程的详细内容,更多关于pyTorch自建数据集及可视化结果实现的资料请关注脚本之家其它相关文章!

相关文章

  • python实现2048小游戏

    python实现2048小游戏

    本文给大家分享的是个人修改自某网友的Python实现2048小游戏的代码,推荐给大家,有需要的小伙伴可以参考下。
    2015-03-03
  • Python列表推导式、字典推导式与集合推导式用法实例分析

    Python列表推导式、字典推导式与集合推导式用法实例分析

    这篇文章主要介绍了Python列表推导式、字典推导式与集合推导式用法,结合实例形式分析了Python三种推导式的概念、使用方法及相关注意事项,需要的朋友可以参考下
    2018-02-02
  • Python获取svn版本信息

    Python获取svn版本信息

    本文主要介绍了Python获取svn版本信息,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07
  • pandas学习之txt与sql文件的基本操作指南

    pandas学习之txt与sql文件的基本操作指南

    Pandas是Python的第三方库,提供高性能易用的数据类型和分析工具,下面这篇文章主要给大家介绍了关于pandas学习之txt与sql文件的基本操作指南,需要的朋友可以参考下
    2021-08-08
  • Python中魔术方法的定义及一些常用方法

    Python中魔术方法的定义及一些常用方法

    所有以双下划线__包起来的方法,统称为Magic Method(魔术方法),它是一种的特殊方法,这篇文章主要给大家介绍了关于Python中魔术方法的定义及一些常用方法,需要的朋友可以参考下
    2024-02-02
  • Python实现带下标索引的遍历操作示例

    Python实现带下标索引的遍历操作示例

    这篇文章主要介绍了Python实现带下标索引的遍历操作,结合具体实例形式分析了2种带索引的遍历操作实现方法及相关操作注意事项,需要的朋友可以参考下
    2019-05-05
  • 浅谈python中拼接路径os.path.join斜杠的问题

    浅谈python中拼接路径os.path.join斜杠的问题

    今天小编就为大家分享一篇浅谈python中拼接路径os.path.join斜杠的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • python pickle 和 shelve模块的用法

    python pickle 和 shelve模块的用法

    pickle和shelve模块都可以把python对象存储到文件中,下面来看看它们的用法吧
    2013-09-09
  • python爬虫---requests库的用法详解

    python爬虫---requests库的用法详解

    requests是python实现的简单易用的HTTP库,使用起来比urllib简洁很多,这里就为大家分享一下
    2020-09-09
  • Python数据拟合与广义线性回归算法学习

    Python数据拟合与广义线性回归算法学习

    这篇文章主要为大家详细介绍了Python数据拟合与广义线性回归算法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12

最新评论