pytorch 利用lstm做mnist手写数字识别分类的实例

 更新时间:2020年01月10日 10:43:23   作者:xckkcxxck  
今天小编就为大家分享一篇pytorch 利用lstm做mnist手写数字识别分类的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python GUI库图形界面开发之PyQt5开发环境配置与基础使用

    python GUI库图形界面开发之PyQt5开发环境配置与基础使用

    这篇文章主要介绍了python GUI库图形界面开发之PyQt5开发环境配置与基础使用,需要的朋友可以参考下
    2020-02-02
  • 详解Python中生成随机数据的示例详解

    详解Python中生成随机数据的示例详解

    在日常工作编程中存在着各种随机事件,同样在编程中生成随机数字的时候也是一样。每当在 Python 中生成随机数据、字符串或数字时,最好至少大致了解这些数据是如何生成的。所以本文将详细为大家讲解一下Python是如何生成随机数据,需要的可以参考一下
    2022-04-04
  • 利用python清除移动硬盘中的临时文件

    利用python清除移动硬盘中的临时文件

    本篇文章的目的是在移动硬盘插入到电脑的同时,利用Python自动化和Windows服务删除掉这些临时文件。感兴趣的朋友可以了解下
    2020-10-10
  • Python写的服务监控程序实例

    Python写的服务监控程序实例

    这篇文章主要介绍了Python写的服务监控程序实例,本文直接给出实现代码,需要的朋友可以参考下
    2015-01-01
  • Python下Fabric的简单部署方法

    Python下Fabric的简单部署方法

    这篇文章主要介绍了Python下Fabric的简单部署方法,Fabric是Python下一个流行的自动化工具,需要的朋友可以参考下
    2015-07-07
  • Python天气语音播报小助手

    Python天气语音播报小助手

    马上就要迎来国庆小长假了,激不激动,兴不兴奋!那今年国庆:天气怎么样?能不能出门逛街?能不能出去旅游?旅游出门就要挑个好的天气!下雨天哪儿哪儿都不舒服。今天小编带大家写一款Python天气语音播报小助手
    2021-09-09
  • Python super( )函数用法总结

    Python super( )函数用法总结

    今天给大家带来的知识是关于Python的相关知识,文章围绕着super( )函数展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • jupyter-lab设置自启动及远程连接开发环境

    jupyter-lab设置自启动及远程连接开发环境

    本文主要介绍了jupyter-lab设置自启动及远程连接开发环境,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • python打印n位数“水仙花数”(实例代码)

    python打印n位数“水仙花数”(实例代码)

    这篇文章主要介绍了python打印n位数“水仙花数”,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • PyTorch实现FedProx联邦学习算法

    PyTorch实现FedProx联邦学习算法

    这篇文章主要为大家介绍了PyTorch实现FedProx的联邦学习算法,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05

最新评论