PyTorch基于MNIST的手写数字识别

 更新时间:2026年01月19日 08:54:56   作者:子夜江寒  
本文介绍了使用PyTorch框架构建深度学习模型处理MNIST手写数字识别的完整流程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

1. 深度学习与PyTorch简介

深度学习作为机器学习的重要分支,已在计算机视觉、自然语言处理等领域取得了显著成果。PyTorch是由Facebook开源的深度学习框架,以其动态计算图和直观的API设计而广受欢迎。本文以经典的MNIST手写数字数据集为例,展示如何利用PyTorch框架构建并训练深度学习模型。

2. 环境配置与数据准备

2.1 环境检查

首先检查PyTorch及相关库的版本,确保环境配置正确:

import torch
import torchvision
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt

print(torch.__version__)
print(torchaudio.__version__)
print(torchvision.__version__)

2.2 数据加载与预处理

MNIST数据集包含60,000个训练样本和10,000个测试样本,每个样本为28×28像素的灰度手写数字图像。

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

参数

  • root:数据存储路径
  • train:是否为训练集
  • download:是否自动下载
  • transform:数据预处理转换,ToTensor()将PIL图像转换为张量并归一化到[0,1]

2.3 数据可视化

我们可以查看数据集的样本分布:

print(len(training_data))

figure = plt.figure()
for i in range(9):
    img, label = training_data[i + 59000]
    figure.add_subplot(3, 3, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2.4 数据批量加载

使用DataLoader实现数据的批量加载和随机打乱:

# 增加批次大小
train_dataloader = DataLoader(training_data, batch_size=128)  # 增大batch size
test_dataloader = DataLoader(test_data, batch_size=128)

for X, y in test_dataloader:
    print(f"Shape of X[N,C,H,W]:{X.shape}")
    print(f"Shape of y:{y.shape} {y.dtype}")
    break

3. 神经网络模型设计

3.1 设备选择

根据可用硬件选择计算设备:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

3.2 神经网络架构

设计一个包含多个全连接层的深度神经网络:

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 10
        self.flatten = nn.Flatten()
        原始架构
        self.hidden1 = nn.Linear(28 * 28, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)
        
    
    def forward(self, x):
        # 原始前向传播
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.sigmoid(x)
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        return x

3.3 模型实例化

model = NeuralNetwork().to(device)
print(model)

4. 训练与评估流程

4.1 训练函数

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model.forward(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

训练步骤

  1. model.train():设置为训练模式(启用Dropout)
  2. 前向传播计算预测值
  3. 计算损失函数值
  4. optimizer.zero_grad():清空梯度
  5. loss.backward():反向传播计算梯度
  6. optimizer.step():更新模型参数

4.2 测试函数

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss = loss_fn(pred, y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size

    print(f"Test result:\n Accuracy:{(100 * correct):.2f}%, Avg loss: {test_loss}")

测试要点

  • model.eval():设置为评估模式(禁用Dropout)
  • torch.no_grad():禁用梯度计算,节省内存
  • pred.argmax(1):获取预测类别

5. 损失函数配置

loss_fn = nn.CrossEntropyLoss()

损失函数说明

  • 使用CrossEntropyLoss,适用于多分类问题
  • 结合了LogSoftmax和NLLLoss,直接输出分类概率

6. 模型训练与评估

6.1 优化器配置

# 原始优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

6.2 单次训练与测试

train(train_dataloader, model, loss_fn, optimizer)
test(train_dataloader, model, loss_fn)

6.3 多轮训练(可选)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n----------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

7. 提高准确率的优化方式

  1. 层数增加:从2层隐藏层增加到3层,增强模型表达能力
  2. 神经元增加:第一层从128个神经元增加到512个
  3. 激活函数:用ReLU替代sigmoid,缓解梯度消失问题
  4. 正则化:添加Dropout层(0.2丢弃率),防止过拟合
  5. 改进优化器:降低学习率
        # 改进架构
        self.hidden1 = nn.Linear(28 * 28, 512)  # 增加神经元
        self.dropout1 = nn.Dropout(0.2)  # 添加Dropout
        self.hidden2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.2)  # 添加Dropout
        self.hidden3 = nn.Linear(256, 128)  # 增加一层
        self.out = nn.Linear(128, 10)
        # 改进的前向传播
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.relu(x)  # 使用ReLU替代sigmoid
        x = self.dropout1(x)  # 训练时随机丢弃
        x = self.hidden2(x)
        x = torch.relu(x)  # 使用ReLU替代sigmoid
        x = self.dropout2(x)  # 训练时随机丢弃
        x = self.hidden3(x)
        x = torch.relu(x)
        x = self.out(x)
# 改进优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 降低学习率

到此这篇关于PyTorch基于MNIST的手写数字识别的文章就介绍到这了,更多相关PyTorch MNIST手写数字识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python为图片和PDF去水印详解

    python为图片和PDF去水印详解

    大家好,本篇文章主要讲的是python为图片和PDF去水印详解,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-01-01
  • 深入Python Tkinter 模块

    深入Python Tkinter 模块

    本文介绍Python Tkinter模块,强调其轻量、跨平台、无依赖的优势,适用于快速开发无需复杂环境配置的桌面工具,通过布局、事件回调、样式美化等核心技巧,结合PyInstaller打包,实现从脚本到可交付程序的完整流程,适合制作简单实用的小工具,感兴趣的朋友跟随小编一起看看吧
    2025-08-08
  • Python报错:TypeError: ‘xxx‘ object is not subscriptable解决办法

    Python报错:TypeError: ‘xxx‘ object is not subscriptable解决

    这篇文章主要给大家介绍了关于Python报错:TypeError: ‘xxx‘ object is not subscriptable的解决办法,TypeError是Python中的一种错误,表示操作或函数应用于不合适类型的对象时发生,文中将解决办法介绍的非常详细,需要的朋友可以参考下
    2024-08-08
  • Python实现OpenCV的安装与使用示例

    Python实现OpenCV的安装与使用示例

    这篇文章主要介绍了Python实现OpenCV的安装与使用,结合实例形式分析了Python中OpenCV的安装及针对图片的相关操作技巧,需要的朋友可以参考下
    2018-03-03
  • Python cookbook(数据结构与算法)筛选及提取序列中元素的方法

    Python cookbook(数据结构与算法)筛选及提取序列中元素的方法

    这篇文章主要介绍了Python cookbook(数据结构与算法)筛选及提取序列中元素的方法,涉及Python列表推导式、生成器表达式及filter()函数相关使用技巧,需要的朋友可以参考下
    2018-03-03
  • python中literal_eval函数的使用小结

    python中literal_eval函数的使用小结

    literal_eval是Python标准库ast模块中的一个安全函数,用于将包含 Python字面量表达式的字符串安全地转换为对应的Python对象,下面就来介绍一下literal_eval函数的使用
    2025-08-08
  • Python实现对二维码数据进行压缩

    Python实现对二维码数据进行压缩

    当前二维码的应用越来越广泛,包括疫情时期的健康码也是应用二维码的典型案例。本文的目标很明确,就是使用python,实现一张二维码显示更多信息,代码简单实用,感兴趣的可以了解一下
    2023-02-02
  • Python字典和集合讲解

    Python字典和集合讲解

    这篇文章主要给大家假关节的是Python字典和集合,字典是Python内置的数据结构之一,是一个无序的序列;而集合是python语言提供的内置数据结构,没有value的字典,集合类型与其他类型最大的区别在于,它不包含重复元素。想具体了解有关python字典与集合,请看下面文章内容
    2021-10-10
  • Keras自动下载的数据集/模型存放位置介绍

    Keras自动下载的数据集/模型存放位置介绍

    这篇文章主要介绍了Keras自动下载的数据集/模型存放位置介绍,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 如何用Python检查SQLite数据库中表是否存在

    如何用Python检查SQLite数据库中表是否存在

    Python查询表中数据有多种方法,具体取决于你使用的数据库类型和查询工具,这篇文章主要介绍了如何用Python检查SQLite数据库中表是否存在的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2025-11-11

最新评论