Pytorch 使用CNN图像分类的实现

 更新时间:2020年06月16日 09:56:21   作者:NULL  
这篇文章主要介绍了Pytorch 使用CNN图像分类的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

需求

在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类

如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1类

想法

  • 通过numpy、PIL构造4*4的图像数据集
  • 构造自己的数据集类
  • 读取数据集对数据集选取减少偏斜
  • cnn设计因为特征少,直接1*1卷积层
  • 或者在4*4外围添加padding成6*6,设计2*2的卷积核得出3*3再接上全连接层

代码

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

构造数据集

import csv
import collections
import os
import shutil

def buildDataset(root,dataType,dataSize):
  """构造数据集
  构造的图片存到root/{dataType}Data
  图片地址和标签的csv文件存到 root/{dataType}DataInfo.csv
  Args:
    root:str
      项目目录
    dataType:str
      'train'或者‘test'
    dataNum:int
      数据大小
  Returns:
  """
  dataInfo = []
  dataPath = f'{root}/{dataType}Data'
  if not os.path.exists(dataPath):
    os.makedirs(dataPath)
  else:
    shutil.rmtree(dataPath)
    os.mkdir(dataPath)
    
  for i in range(dataSize):
    # 创建0,1 数组
    imageArray=np.random.randint(0,2,(4,4))
    # 计算0,1数量得到标签
    allBlackNum = collections.Counter(imageArray.flatten())[0]
    innerBlackNum = collections.Counter(imageArray[1:3,1:3].flatten())[0]
    label = 0 if (allBlackNum-innerBlackNum)>innerBlackNum else 1
    # 将图片保存
    path = f'{dataPath}/{i}.jpg'
    dataInfo.append([path,label])
    im = Image.fromarray(np.uint8(imageArray*255))
    im = im.convert('1') 
    im.save(path)
  # 将图片地址和标签存入csv文件
  filePath = f'{root}/{dataType}DataInfo.csv'
  with open(filePath, 'w') as f:
    writer = csv.writer(f)
    writer.writerows(dataInfo)
root=r'/Users/null/Documents/PythonProject/Classifier'

构造训练数据集

buildDataset(root,'train',20000)

构造测试数据集

buildDataset(root,'test',10000)

读取数据集

class MyDataset(torch.utils.data.Dataset):

  def __init__(self, root, datacsv, transform=None):
    super(MyDataset, self).__init__()
    with open(f'{root}/{datacsv}', 'r') as f:
      imgs = []
      # 读取csv信息到imgs列表
      for path,label in map(lambda line:line.rstrip().split(','),f):
        imgs.append((path, int(label)))
    self.imgs = imgs
    self.transform = transform if transform is not None else lambda x:x
    
  def __getitem__(self, index):
    path, label = self.imgs[index]
    img = self.transform(Image.open(path).convert('1'))
    return img, label

  def __len__(self):
    return len(self.imgs)
trainData=MyDataset(root = root,datacsv='trainDataInfo.csv', transform=transforms.ToTensor())
testData=MyDataset(root = root,datacsv='testDataInfo.csv', transform=transforms.ToTensor())

处理数据集使得数据集不偏斜

import itertools

def chooseData(dataset,scale):
  # 将类别为1的排序到前面
  dataset.imgs.sort(key=lambda x:x[1],reverse=True)
  # 获取类别1的数目 ,取scale倍的数组,得数据不那么偏斜
  trueNum =collections.Counter(itertools.chain.from_iterable(dataset.imgs))[1]
  end = min(trueNum*scale,len(dataset))
  dataset.imgs=dataset.imgs[:end]
scale = 4
chooseData(trainData,scale)
chooseData(testData,scale)
len(trainData),len(testData)
(2250, 1122)
import torch.utils.data as Data

# 超参数
batchSize = 50
lr = 0.1
numEpochs = 20

trainIter = Data.DataLoader(dataset=trainData, batch_size=batchSize, shuffle=True)
testIter = Data.DataLoader(dataset=testData, batch_size=batchSize)

定义模型

from torch import nn
from torch.autograd import Variable
from torch.nn import Module,Linear,Sequential,Conv2d,ReLU,ConstantPad2d
import torch.nn.functional as F
class Net(Module):  
  def __init__(self):
    super(Net, self).__init__()

    self.cnnLayers = Sequential(
      # padding添加1层常数1,设定卷积核为2*2
      ConstantPad2d(1, 1),
      Conv2d(1, 1, kernel_size=2, stride=2,bias=True)
    )
    self.linearLayers = Sequential(
      Linear(9, 2)
    )

  def forward(self, x):
    x = self.cnnLayers(x)
    x = x.view(x.shape[0], -1)
    x = self.linearLayers(x)
    return x
class Net2(Module):  
  def __init__(self):
    super(Net2, self).__init__()

    self.cnnLayers = Sequential(
      Conv2d(1, 1, kernel_size=1, stride=1,bias=True)
    )
    self.linearLayers = Sequential(
      ReLU(),
      Linear(16, 2)
    )

  def forward(self, x):
    x = self.cnnLayers(x)
    x = x.view(x.shape[0], -1)
    x = self.linearLayers(x)
    return x

定义损失函数

# 交叉熵损失函数
loss = nn.CrossEntropyLoss()
loss2 = nn.CrossEntropyLoss()

定义优化算法

net = Net()
optimizer = torch.optim.SGD(net.parameters(),lr = lr)
net2 = Net2()
optimizer2 = torch.optim.SGD(net2.parameters(),lr = lr)

训练模型

# 计算准确率
def evaluateAccuracy(dataIter, net):
  accSum, n = 0.0, 0
  with torch.no_grad():
    for X, y in dataIter:
      accSum += (net(X).argmax(dim=1) == y).float().sum().item()
      n += y.shape[0]
  return accSum / n
def train(net, trainIter, testIter, loss, numEpochs, batchSize,
       optimizer):
  for epoch in range(numEpochs):
    trainLossSum, trainAccSum, n = 0.0, 0.0, 0
    for X,y in trainIter:
      yHat = net(X)
      l = loss(yHat,y).sum()
      optimizer.zero_grad()
      l.backward()
      optimizer.step()
      # 计算训练准确度和loss
      trainLossSum += l.item()
      trainAccSum += (yHat.argmax(dim=1) == y).sum().item()
      n += y.shape[0]
    # 评估测试准确度
    testAcc = evaluateAccuracy(testIter, net)
    print('epoch {:d}, loss {:.4f}, train acc {:.3f}, test acc {:.3f}'.format(epoch + 1, trainLossSum / n, trainAccSum / n, testAcc))  

Net模型训练

train(net, trainIter, testIter, loss, numEpochs, batchSize,optimizer)
epoch 1, loss 0.0128, train acc 0.667, test acc 0.667
epoch 2, loss 0.0118, train acc 0.683, test acc 0.760
epoch 3, loss 0.0104, train acc 0.742, test acc 0.807
epoch 4, loss 0.0093, train acc 0.769, test acc 0.772
epoch 5, loss 0.0085, train acc 0.797, test acc 0.745
epoch 6, loss 0.0084, train acc 0.798, test acc 0.807
epoch 7, loss 0.0082, train acc 0.804, test acc 0.816
epoch 8, loss 0.0078, train acc 0.816, test acc 0.812
epoch 9, loss 0.0077, train acc 0.818, test acc 0.817
epoch 10, loss 0.0074, train acc 0.824, test acc 0.826
epoch 11, loss 0.0072, train acc 0.836, test acc 0.819
epoch 12, loss 0.0075, train acc 0.823, test acc 0.829
epoch 13, loss 0.0071, train acc 0.839, test acc 0.797
epoch 14, loss 0.0067, train acc 0.849, test acc 0.824
epoch 15, loss 0.0069, train acc 0.848, test acc 0.843
epoch 16, loss 0.0064, train acc 0.864, test acc 0.851
epoch 17, loss 0.0062, train acc 0.867, test acc 0.780
epoch 18, loss 0.0060, train acc 0.871, test acc 0.864
epoch 19, loss 0.0057, train acc 0.881, test acc 0.890
epoch 20, loss 0.0055, train acc 0.885, test acc 0.897

Net2模型训练

# batchSize = 50 
# lr = 0.1
# numEpochs = 15 下得出的结果
train(net2, trainIter, testIter, loss2, numEpochs, batchSize,optimizer2)

epoch 1, loss 0.0119, train acc 0.638, test acc 0.676
epoch 2, loss 0.0079, train acc 0.823, test acc 0.986
epoch 3, loss 0.0046, train acc 0.987, test acc 0.977
epoch 4, loss 0.0030, train acc 0.983, test acc 0.973
epoch 5, loss 0.0023, train acc 0.981, test acc 0.976
epoch 6, loss 0.0019, train acc 0.980, test acc 0.988
epoch 7, loss 0.0016, train acc 0.984, test acc 0.984
epoch 8, loss 0.0014, train acc 0.985, test acc 0.986
epoch 9, loss 0.0013, train acc 0.987, test acc 0.992
epoch 10, loss 0.0011, train acc 0.989, test acc 0.993
epoch 11, loss 0.0010, train acc 0.989, test acc 0.996
epoch 12, loss 0.0010, train acc 0.992, test acc 0.994
epoch 13, loss 0.0009, train acc 0.993, test acc 0.994
epoch 14, loss 0.0008, train acc 0.995, test acc 0.996
epoch 15, loss 0.0008, train acc 0.994, test acc 0.998

测试

test = torch.Tensor([[[[0,0,0,0],[0,1,1,0],[0,1,1,0],[0,0,0,0]]],
         [[[1,1,1,1],[1,0,0,1],[1,0,0,1],[1,1,1,1]]],
         [[[0,1,0,1],[1,0,0,1],[1,0,0,1],[0,0,0,1]]],
         [[[0,1,1,1],[1,0,0,1],[1,0,0,1],[0,0,0,1]]],
         [[[0,0,1,1],[1,0,0,1],[1,0,0,1],[1,0,1,0]]],
         [[[0,0,1,0],[0,1,0,1],[0,0,1,1],[1,0,1,0]]],
         [[[1,1,1,0],[1,0,0,1],[1,0,1,1],[1,0,1,1]]]
         ])

target=torch.Tensor([0,1,0,1,1,0,1])
test
tensor([[[[0., 0., 0., 0.],
     [0., 1., 1., 0.],
     [0., 1., 1., 0.],
     [0., 0., 0., 0.]]],

​

    [[[1., 1., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [1., 1., 1., 1.]]],

​

    [[[0., 1., 0., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [0., 0., 0., 1.]]],

​

    [[[0., 1., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [0., 0., 0., 1.]]],

​

    [[[0., 0., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [1., 0., 1., 0.]]],

​

    [[[0., 0., 1., 0.],
     [0., 1., 0., 1.],
     [0., 0., 1., 1.],
     [1., 0., 1., 0.]]],

​

    [[[1., 1., 1., 0.],
     [1., 0., 0., 1.],
     [1., 0., 1., 1.],
     [1., 0., 1., 1.]]]])



with torch.no_grad():
  output = net(test)
  output2 = net2(test)
predictions =output.argmax(dim=1)
predictions2 =output2.argmax(dim=1)
# 比较结果
print(f'Net测试结果{predictions.eq(target)}')
print(f'Net2测试结果{predictions2.eq(target)}')
Net测试结果tensor([ True, True, False, True, True, True, True])
Net2测试结果tensor([False, True, False, True, True, False, True])

到此这篇关于Pytorch 使用CNN图像分类的实现的文章就介绍到这了,更多相关Pytorch CNN图像分类内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python三种遍历文件目录的方法实例代码

    Python三种遍历文件目录的方法实例代码

    这篇文章主要介绍了Python三种遍历文件目录的方法实例代码,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • 详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程

    详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程

    这篇文章主要介绍了Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-11-11
  • django连接oracle时setting 配置方法

    django连接oracle时setting 配置方法

    今天小编就为大家分享一篇django连接oracle时setting 配置方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python项目--使用Tkinter的日历GUI应用程序

    python项目--使用Tkinter的日历GUI应用程序

    在 Python 中,我们可以使用 Tkinter 制作 GUI。如果你非常有想象力和创造力,你可以用 Tkinter 做出很多有趣的东西,希望本篇文章能够帮到你
    2021-08-08
  • python用Configobj模块读取配置文件

    python用Configobj模块读取配置文件

    这篇文章主要介绍了python用Configobj模块读取配置文件,帮助大家更好的利用python处理文件,感兴趣的朋友可以了解下
    2020-09-09
  • 深入了解Python在HDA中的应用

    深入了解Python在HDA中的应用

    这篇文章主要介绍了深入了解Python在HDA中的应用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 详解在SpringBoot如何优雅的使用多线程

    详解在SpringBoot如何优雅的使用多线程

    这篇文章主要带大家快速了解一下@Async注解的用法,包括异步方法无返回值、有返回值,最后总结了@Async注解失效的几个坑,感兴趣的小伙伴可以了解一下
    2023-02-02
  • 详解python os.walk()方法的使用

    详解python os.walk()方法的使用

    今天给大家带来的是关于Python的相关知识,文章围绕python os.walk()方法的使用展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • Python3 列表list合并的4种方法

    Python3 列表list合并的4种方法

    这篇文章主要介绍了Python3 列表list合并的4种方法,需要的朋友可以参考下
    2021-04-04
  • Pycharm 使用 Pipenv 新建的虚拟环境(图文详解)

    Pycharm 使用 Pipenv 新建的虚拟环境(图文详解)

    pipenv 是 Pipfile 主要倡导者、requests 作者 Kenneth Reitz 写的一个命令行工具,主要包含了Pipfile、pip、click、requests和virtualenv。这篇文章主要介绍了Pycharm 使用 Pipenv 新建的虚拟环境的问题,需要的朋友可以参考下
    2020-04-04

最新评论