PyTorch实现MNIST数据集手写数字识别详情

 更新时间:2022年09月06日 14:17:45   作者:长浔  
这篇文章主要介绍了PyTorch实现MNIST数据集手写数字识别详情,文章围绕主题展开详细的内容戒杀,具有一定的参考价值,需要的朋友可以参考一下

前言:

本篇文章基于卷积神经网络CNN,使用PyTorch实现MNIST数据集手写数字识别。

一、PyTorch是什么?

PyTorch 是一个 Torch7 团队开源的 Python 优先的深度学习框架,提供两个高级功能:

  • 强大的 GPU 加速 Tensor 计算(类似 numpy)
  • 构建基于 tape 的自动升级系统上的深度神经网络

你可以重用你喜欢的 python 包,如 numpy、scipy 和 Cython ,在需要时扩展 PyTorch。

二、程序示例

下面案例可供运行参考

1.引入必要库

import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

2.下载数据集

这里设置download=True,将会自动下载数据集,并存储在./data文件夹。

train_data = torchvision.datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

3.加载数据集

batch_size=32表示每一个batch中包含32张手写数字图片,shuffle=True表示打乱测试集(data和target仍一一对应)

train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)

4.搭建CNN模型并实例化

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.con1(x)))
        x = F.relu(self.pooling(self.con2(x)))
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
#模型实例化        
model = Net()

5.交叉熵损失函数损失函数及SGD算法优化器

lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

6.训练函数

def train(epoch):
    running_loss = 0.0
    for i,(inputs,targets) in enumerate(train_loader,0):
        # inputs,targets = inputs.to(device),targets.to(device)
        opt.zero_grad()
        outputs = model(inputs)
        loss = lossfun(outputs,targets)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if i % 300 == 299:
            print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
            running_loss = 0.0

7.测试函数

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for (inputs,targets) in test_loader:
            # inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print(100*correct/total)

8.运行

if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

三、总结

到此这篇关于PyTorch实现MNIST数据集手写数字识别详情的文章就介绍到这了,更多相关PyTorch MNIST 内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python实现Smtplib发送带有各种附件的邮件实例

    Python实现Smtplib发送带有各种附件的邮件实例

    本篇文章主要介绍了Python实现Smtplib发送带有各种附件的邮件实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-06-06
  • Python GUI实现PDF转Word功能

    Python GUI实现PDF转Word功能

    这篇文章主要介绍了如何使用 wxPython 创建一个简单的图形用户界面(GUI)应用程序,结合 pdf2docx 库,实现将 PDF 转换为 Word 文档的功能,需要的可以参考下
    2024-12-12
  • Python numpy线性代数用法实例解析

    Python numpy线性代数用法实例解析

    这篇文章主要介绍了Python numpy线性代数用法实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11
  • python查看模块安装位置的方法

    python查看模块安装位置的方法

    今天小编就为大家分享一篇python查看模块安装位置的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python爬取股票信息,并可视化数据的示例

    Python爬取股票信息,并可视化数据的示例

    这篇文章主要介绍了Python爬取股票信息,并可视化数据的示例,帮助大家更好的理解和使用python爬虫,感兴趣的朋友可以了解下
    2020-09-09
  • Python爬取附近餐馆信息代码示例

    Python爬取附近餐馆信息代码示例

    这篇文章主要介绍了Python爬取附近餐馆信息代码示例,具有一定借鉴价值,需要的朋友可以参考下。
    2017-12-12
  • python使用fork实现守护进程的方法

    python使用fork实现守护进程的方法

    守护进程(Daemon)也称为精灵进程是一种生存期较长的一种进程。它们独立于控制终端并且周期性的执行某种任务或等待处理某些发生的事件。他们常常在系统引导装入时启动,在系统关闭时终止。
    2017-11-11
  • python 合并多个excel中同名的sheet

    python 合并多个excel中同名的sheet

    这篇文章主要介绍了python 如何合并多个excel中同名的sheet,帮助大家更好的利用python处理excel表格,感兴趣的朋友可以了解下
    2021-01-01
  • 入门tensorflow教程之TensorBoard可视化模型训练

    入门tensorflow教程之TensorBoard可视化模型训练

    在本篇文章中,主要介绍 了TensorBoard 的基础知识,并了解如何可视化训练模型中的一些基本信息,希望对大家的TensorBoard可视化模型训练有所帮助
    2021-08-08
  • 用python爬取豆瓣前一百电影

    用python爬取豆瓣前一百电影

    大家好,本篇文章主要讲的是用python爬取豆瓣前一百电影,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-01-01

最新评论