pytorch 带batch的tensor类型图像显示操作

 更新时间:2021年05月20日 14:38:55   作者:Xavier Jiezou  
这篇文章主要介绍了pytorch 带batch的tensor类型图像显示操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

在这里插入图片描述

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

在这里插入图片描述

最终效果

在这里插入图片描述

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

 # 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于python+pandoc实现html批量转word

    基于python+pandoc实现html批量转word

    pandoc是一个强大的文档格式转换工具,支持丰富的格式转换,并尽可能的保留原来的排版,号称文档格式转换的瑞士军刀,本文将给大家介绍一下使用python搭配pandoc实现html批量转word,感兴趣的朋友可以参考阅读下
    2023-09-09
  • appium中常见的几种点击方式

    appium中常见的几种点击方式

    本文主要介绍了appium中常见的几种点击方式,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-02-02
  • Flask之flask-session的具体使用

    Flask之flask-session的具体使用

    这篇文章主要介绍了Flask之flask-session的具体使用,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-07-07
  • Python 日志记录模块的综合指南

    Python 日志记录模块的综合指南

    这篇文章主要为大家介绍了Python 日志记录模块的综合指南,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-12-12
  • Python pip指定安装源的方法详解

    Python pip指定安装源的方法详解

    pip是Python包管理工具,该工具提供了对Python包的查找、下载、安装、卸载的功能,这篇文章主要给大家介绍了关于Python pip指定安装源的相关资料,需要的朋友可以参考下
    2023-12-12
  • python利用多种方式来统计词频(单词个数)

    python利用多种方式来统计词频(单词个数)

    这篇文章主要介绍了python利用多种方式来统计词频(单词个数),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-05-05
  • Python如何使用正则表达式爬取京东商品信息

    Python如何使用正则表达式爬取京东商品信息

    这篇文章主要介绍了Python如何使用正则表达式爬取京东商品信息,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • Python网络编程基于多线程实现多用户全双工聊天功能示例

    Python网络编程基于多线程实现多用户全双工聊天功能示例

    这篇文章主要介绍了Python网络编程基于多线程实现多用户全双工聊天功能,结合实例形式分析了Python网络编程中使用多线程进行多用户异步通信的原理与相关实现技巧,需要的朋友可以参考下
    2018-04-04
  • 详解Swift中属性的声明与作用

    详解Swift中属性的声明与作用

    Swift中的属性可以被分为存储属性和计算属性,本文将为大家详解Swift中属性的声明与作用,需要的朋友可以参考下
    2016-06-06
  • pyqt5 tablewidget 利用线程动态刷新数据的方法

    pyqt5 tablewidget 利用线程动态刷新数据的方法

    今天小编就为大家分享一篇pyqt5 tablewidget 利用线程动态刷新数据的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06

最新评论