pytorch 利用lstm做mnist手写数字识别分类的实例

 更新时间:2020年01月10日 10:43:23   作者:xckkcxxck  
今天小编就为大家分享一篇pytorch 利用lstm做mnist手写数字识别分类的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 对numpy中轴与维度的理解

    对numpy中轴与维度的理解

    下面小编就为大家分享一篇对numpy中轴与维度的理解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python实现图像和办公文档处理的方法和技巧

    Python实现图像和办公文档处理的方法和技巧

    本文介绍了Python实现图像和办公文档处理的方法和技巧,包括使用Pillow库处理图像、使用OpenCV库进行图像识别和处理、使用PyPDF2库处理PDF文档、使用docx和xlwt库处理Word和Excel文档等,帮助读者更好地掌握Python在图像和办公文档处理方面的应用
    2023-05-05
  • 使用Python在Excel工作表中创建图表的实现步骤

    使用Python在Excel工作表中创建图表的实现步骤

    在现代企业中,数据驱动的决策变得越来越重要,Excel作为企业中最常用的数据分析工具,其强大的表格和图表功能在日常工作中不可或缺,然而,当面对成百上千条数据或需要生成定期报告时,手动制作图表不仅耗时,还容易出错,所以本文介绍了如何实现Excel图表自动化生成
    2025-12-12
  • 使用Python 统计文件夹内所有pdf页数的小工具

    使用Python 统计文件夹内所有pdf页数的小工具

    这篇文章主要介绍了Python 统计文件夹内所有pdf页数的小工具,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-03-03
  • pycharm使用anaconda全过程

    pycharm使用anaconda全过程

    这篇文章主要介绍了pycharm使用anaconda全过程,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • python3爬取各类天气信息

    python3爬取各类天气信息

    这篇文章主要为大家详细介绍了python3爬取各类天气信息,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-02-02
  • 使用Python的Twisted框架编写非阻塞程序的代码示例

    使用Python的Twisted框架编写非阻塞程序的代码示例

    Twisted是基于异步模式的开发框架,因而利用Twisted进行非阻塞编程自然也是必会的用法,下面我们就来一起看一下使用Python的Twisted框架编写非阻塞程序的代码示例:
    2016-05-05
  • TensorFlow2.1.0最新版本安装详细教程

    TensorFlow2.1.0最新版本安装详细教程

    TensorFlow是一款优秀的深度学习框架,支持多种常见的操作系统,对大家的学习或工作具有一定的参考借鉴价值,这篇文章主要介绍了TensorFlow2.1.0最新版本安装详细教程,需要的朋友可以参考下
    2020-04-04
  • Python 实现自动化Excel报表的步骤

    Python 实现自动化Excel报表的步骤

    这篇文章主要介绍了Python 实现自动化Excel报表的步骤,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-04-04
  • torch.optim优化算法理解之optim.Adam()解读

    torch.optim优化算法理解之optim.Adam()解读

    这篇文章主要介绍了torch.optim优化算法理解之optim.Adam()解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11

最新评论