Pytorch写数字识别LeNet模型

 更新时间:2022年01月26日 17:54:21   作者:Jokic_Rn   
这篇文章主要介绍了Pytorch写数字识别LeNet模型,LeNet-5是一个较简单的卷积神经网络,  LeNet-5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。是其他深度学习模型的基础, 这里我们对LeNet-5进行深入分析,需要的朋友可以参考下

LeNet网络

LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm

class LeNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Flatten(),
                                        nn.Linear(16*25,120),nn.Sigmoid(),
                                        nn.Linear(120,84),nn.Sigmoid(),
                                        nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28,120),nn.Sigmoid(),
                          nn.Linear(120,84),nn.Sigmoid(),
                          nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose  = transforms.Compose([transforms.ToTensor(),
                    ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)

model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
  train_loss = 0
  test_loss = 0
  correct_train = 0
  correct_test = 0
  for index,(x,y) in enumerate(train_loader):
    x = x.to(device)
    y = y.to(device)
    predict = model(x)
    L = loss(predict,y)
    optimizer.zero_grad()
    L.backward()
    optimizer.step()
    train_loss = train_loss + L
    correct_train += (predict.argmax(dim=1)==y).sum()
  acc_train = correct_train/(batch*len(train_loader))
  with torch.no_grad():
    for index,(x,y) in enumerate(test_loader):
      [x,y] = [x.to(device),y.to(device)]
      predict = model(x)
      L1 = loss(predict,y)
      test_loss = test_loss + L1
      correct_test += (predict.argmax(dim=1)==y).sum()
    acc_test = correct_test/(batch*len(test_loader))
  print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

训练结果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229

泛化能力测试

找了一张图片,将其分割成只含一个数字的图片进行测试

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#图片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
  for j in range(16):
    test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

将所有用来泛化性测试的图片进行准确率测试:

correct = 0
i = 0
cnt = 1
for sample in test_images:
  sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
  sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
  predict = model(sample_tensor)
  output = predict.argmax()
  if(output==i):
    correct+=1
  if(cnt%16==0):
    i+=1
  cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')

如果不反色,acc_g=0.15

acc_g:0.50625

到此这篇关于Pytorch写数字识别LeNet模型的文章就介绍到这了,更多相关Pytorch写数字识别LeNet模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python实现将一个大文件按段落分隔为多个小文件的简单操作方法

    Python实现将一个大文件按段落分隔为多个小文件的简单操作方法

    这篇文章主要介绍了Python实现将一个大文件按段落分隔为多个小文件的简单操作方法,涉及Python针对文件的读取、遍历、转换、写入等相关操作技巧,需要的朋友可以参考下
    2017-04-04
  • Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境

    Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境

    这篇文章主要介绍了Ubuntu18.04安装 PyCharm并使用 Anaconda 管理的Python环境的教程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Keras神经网络efficientnet模型搭建yolov3目标检测平台

    Keras神经网络efficientnet模型搭建yolov3目标检测平台

    这篇文章主要为大家介绍了Keras利用efficientnet系列模型搭建yolov3目标检测平台的过程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • python中验证码连通域分割的方法详解

    python中验证码连通域分割的方法详解

    这篇文章主要给大家介绍了关于python中验证码连通域分割的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2018-06-06
  • 你真的了解Python的random模块吗?

    你真的了解Python的random模块吗?

    这篇文章主要介绍了Python的random模块的相关内容,具有一定借鉴价值,需要的朋友可以参考下。
    2017-12-12
  • python字典的值可以修改吗

    python字典的值可以修改吗

    在本篇文章里小编给大家分享的是一篇关于python字典的值修改的方法步骤,需要的朋友们可以学习下。
    2020-06-06
  • Pandas删除数据的几种情况(小结)

    Pandas删除数据的几种情况(小结)

    这篇文章主要介绍了Pandas删除数据的几种情况(小结),详细的介绍了4种方式,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-06-06
  • Python实现12306火车票抢票系统

    Python实现12306火车票抢票系统

    这篇文章主要介绍了Python实现12306火车票抢票系统,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值 ,需要的朋友可以参考下
    2019-07-07
  • 关于Qt6中QtMultimedia多媒体模块的重大改变分析

    关于Qt6中QtMultimedia多媒体模块的重大改变分析

    如果您一直在 Qt 5 中使用 Qt Multimedia,则需要对您的实现进行更改。这篇博文将尝试引导您完成最大的变化,同时查看 API 和内部结构
    2021-09-09
  • opencv实践项目之图像拼接详细步骤

    opencv实践项目之图像拼接详细步骤

    OpenCV的应用领域非常广泛,包括图像拼接、图像降噪、产品质检、人机交互、人脸识别、动作识别、动作跟踪、无人驾驶等,下面这篇文章主要给大家介绍了关于opencv实践项目之图像拼接的相关资料,需要的朋友可以参考下
    2023-05-05

最新评论