pytorch 利用lstm做mnist手写数字识别分类的实例

 更新时间:2020年01月10日 10:43:23   作者:xckkcxxck  
今天小编就为大家分享一篇pytorch 利用lstm做mnist手写数字识别分类的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python+Pygame实现怀旧游戏飞机大战

    Python+Pygame实现怀旧游戏飞机大战

    第一次见到飞机大战是在小学五年级下半学期的时候,这个游戏中可以说包含了几乎所有我目前可接触到的pygame知识。本文就来利用Pygame实现飞机大战游戏,需要的可以参考一下
    2022-11-11
  • 详解pytorch 0.4.0迁移指南

    详解pytorch 0.4.0迁移指南

    这篇文章主要介绍了详解pytorch 0.4.0迁移指南,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-06-06
  • 用 Python 制作地球仪的方法

    用 Python 制作地球仪的方法

    这篇文章主要介绍了如何用 Python 制作地球仪,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考即将价值,需要的朋友可以参考下
    2020-04-04
  • pandas.concat实现DataFrame竖着拼接、横着拼接方式

    pandas.concat实现DataFrame竖着拼接、横着拼接方式

    这篇文章主要介绍了pandas.concat实现DataFrame竖着拼接、横着拼接方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-10-10
  • Python中的choice()方法使用详解

    Python中的choice()方法使用详解

    这篇文章主要介绍了Python中的choice()方法使用详解,是Python入门中的基础知识,需要的朋友可以参考下
    2015-05-05
  • 如何在Python项目中引入日志

    如何在Python项目中引入日志

    在开发一些大型项目的时候,都会使用日志来记录项目运行时产生的信息,以备出错时定位分析和从日志信息中提取数据统计分析等。在 Python 中使用 logging 内置模块即可对项目进行日志的配置。
    2021-05-05
  • python编程开发之textwrap文本样式处理技巧

    python编程开发之textwrap文本样式处理技巧

    这篇文章主要介绍了python编程开发之textwrap文本样式处理技巧,实例分析了Python中textwrap的常用方法与处理文本样式的相关使用技巧,需要的朋友可以参考下
    2015-11-11
  • 初步解析Python下的多进程编程

    初步解析Python下的多进程编程

    这篇文章主要介绍了初步解析Python下的多进程编程,使用多进程编程一直是Python编程当中的重点和难点,需要的朋友可以参考下
    2015-04-04
  • Python随机生成手机号、数字的方法详解

    Python随机生成手机号、数字的方法详解

    这篇文章主要介绍了Python随机生成手机号、数字的方法,结合完整实例形式分析了Python编程生成随机手机号与数字的实现方法及相关函数用法,需要的朋友可以参考下
    2017-07-07
  • 解决Jupyter notebook更换主题工具栏被隐藏及添加目录生成插件问题

    解决Jupyter notebook更换主题工具栏被隐藏及添加目录生成插件问题

    这篇文章主要介绍了解决Jupyter notebook更换主题工具栏被隐藏及添加目录生成插件问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04

最新评论