Pytorch实现的手写数字mnist识别功能完整示例

 更新时间:2019年12月13日 10:47:02   作者:nudt_qxx  
这篇文章主要介绍了Pytorch实现的手写数字mnist识别功能,结合完整实例形式分析了Pytorch模块手写字识别具体步骤与相关实现技巧,需要的朋友可以参考下

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程

希望本文所述对大家Python程序设计有所帮助。

相关文章

  • Python中字符串格式化的方法小结

    Python中字符串格式化的方法小结

    在Python中,格式化字符串输出是一项非常常见的任务,Python提供了多种方式来实现字符串格式化,每种方式都有其独特的优势和用法,下面我们就来学习一下这些方法的具体操作吧
    2023-11-11
  • Python实现仓库管理系统

    Python实现仓库管理系统

    这篇文章主要为大家详细介绍了Python实现仓库管理系统,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-05-05
  • 利用Python将list列表写入文件并读取的方法汇总

    利用Python将list列表写入文件并读取的方法汇总

    因为实验需要,实现了一下写入txt文件,下面这篇文章主要给大家介绍了关于如何利用Python将list列表写入文件并读取的几种方法,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-03-03
  • 基于Python实现二维图像双线性插值

    基于Python实现二维图像双线性插值

    双线性插值,又称为双线性内插。在数学上,双线性插值是有两个变量的插值函数的线性插值扩展,其核心思想是在两个方向分别进行一次线性插值。本文将用Python实现二维图像双线性插值,感兴趣的可以了解下
    2022-06-06
  • Python学习之os包使用教程详解

    Python学习之os包使用教程详解

    本文将详细介绍python的内置包——OS 包。OS 包拥有着普遍的操作系统功能,拥有着各种各样的函数来操作系统的驱动功能。快来跟随小编一起学习一下OS包的使用方法吧
    2022-03-03
  • Python基于pygame实现单机版五子棋对战

    Python基于pygame实现单机版五子棋对战

    这篇文章主要为大家详细介绍了Python基于pygame实现单机版五子棋对战,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-12-12
  • PO模式在selenium自动化测试框架的优势

    PO模式在selenium自动化测试框架的优势

    大家都知道po模式可以提高代码的可读性和减少了代码的重复,但是相对的缺点还有,今天通过本文一起学习下PO模式在selenium自动化测试框架的优势,需要的朋友可以参考下
    2022-03-03
  • Python如何用filter函数筛选数据

    Python如何用filter函数筛选数据

    这篇文章主要介绍了Python如何用filter函数筛选数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • python 中 os.walk() 函数详解

    python 中 os.walk() 函数详解

    os.walk()是一种遍历目录数的函数,它以一种深度优先的策略(depth-first)访问指定的目录。这篇文章主要介绍了python 中 os.walk() 函数,需要的朋友可以参考下
    2021-11-11
  • python目标检测数据增强的代码参数解读及应用

    python目标检测数据增强的代码参数解读及应用

    这篇文章主要为大家介绍了python目标检测数据增强的代码参数解读及应用,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05

最新评论