基于Python实现的简单数字识别程序

 更新时间:2025年12月25日 08:54:51   作者:ufdf  
文章介绍了如何使用全连接神经网络(MLP)进行MNIST数字识别,包括代码模型定义、训练和测试的步骤,并解释了模型权重保存文件的内容,需要的朋友可以参考下

这里我们使用全连接神经网络(MLP) 实现的 MNIST 数字识别代码,结构更简单,仅包含几个线性层和激活函数。

简易代码

模型定义代码,model.py

import torch.nn as nn

# 定义一个简单的 CNN 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)  # [B, 1, 28, 28] -> [B, 784]
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)  # 输出层不加激活(CrossEntropyLoss 内部含 softmax)
        return x

然后训练代码,train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import SimpleModel  # 👈 从 model.py 导入

# 配置
batch_size = 64
learning_rate = 0.001
num_epochs = 10
model_save_path = 'mnist_mlp.pth'

# 数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 模型、损失、优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练
print(f"Training on {device}...")
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

# 保存
torch.save(model.state_dict(), model_save_path)
print(f"✅ Model saved to {model_save_path}")

训练

在训练之前我们需要安装下python依赖

pip install torch torchvision

然后我们就可以开始训练模型啦!执行命令python ./train.py,你会看到类似输出

Training on cpu...
Epoch [1/10], Loss: 0.3501
Epoch [2/10], Loss: 0.1702
Epoch [3/10], Loss: 0.1335
Epoch [4/10], Loss: 0.1141
Epoch [5/10], Loss: 0.1027
Epoch [6/10], Loss: 0.0915
Epoch [7/10], Loss: 0.0884
Epoch [8/10], Loss: 0.0801
Epoch [9/10], Loss: 0.0769
Epoch [10/10], Loss: 0.0715
✅ Model saved to mnist_mlp.pth

目录下会生成一个mnist_mlp.pthmnist_mlp.pth 是一个 PyTorch 模型权重保存文件,本质上是一个 序列化后的字典(state_dict),存储了神经网络中所有可学习参数(如权重和偏置)的数值。

测试模型

现在我们拿我们的模型去试试我们的数字图片了~
predict.py

# predict.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import SimpleModel
import argparse
import os

def predict_image(image_path, model_path='mnist_mlp.pth', device='cpu'):
    # 1. 加载模型
    model = SimpleModel()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # 推理模式

    # 2. 图像预处理(必须和训练时一致!)
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 转灰度
        transforms.Resize((28, 28)),                   # 调整为 28x28
        transforms.ToTensor(),                         # 转为 Tensor [0,1]
        transforms.Normalize((0.1307,), (0.3081,))    # 用 MNIST 的均值/标准差
    ])

    # 3. 加载并预处理图像
    image = Image.open(image_path).convert('L')  # 强制灰度(兼容 RGB 输入)
    input_tensor = transform(image)              # shape: [1, 28, 28]
    input_batch = input_tensor.unsqueeze(0)      # 增加 batch 维度 → [1, 1, 28, 28]

    # 4. 推理
    with torch.no_grad():
        output = model(input_batch)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class].item()

    return predicted_class, confidence

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict digit in an image using trained MLP')
    parser.add_argument('image_path', type=str, help='Path to the input image (e.g., digit.png)')
    args = parser.parse_args()

    if not os.path.exists(args.image_path):
        print(f"❌ Error: Image file '{args.image_path}' not found!")
        exit(1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    digit, conf = predict_image(args.image_path, device=device)

    print(f"✅ Predicted digit: {digit}")
    print(f"📊 Confidence: {conf:.4f} ({conf*100:.2f}%)")

我们可以python .\predict.py .\data\digit.png来看看预测的结果如何。

到此这篇关于基于Python实现的简单数字识别程序的文章就介绍到这了,更多相关Python数字识别程序内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python 如何快速复制序列

    python 如何快速复制序列

    这篇文章主要介绍了python 如何快速复制序列,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-09-09
  • python包装和授权学习教程

    python包装和授权学习教程

    包装是指对一个已经存在的对象进行系定义加工,实现授权是包装的一个特性,下面这篇文章主要给大家介绍了关于python包装和授权的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2023-06-06
  • tensorflow之获取tensor的shape作为max_pool的ksize实例

    tensorflow之获取tensor的shape作为max_pool的ksize实例

    今天小编就为大家分享一篇tensorflow之获取tensor的shape作为max_pool的ksize实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python实现根据Excel生成Model和数据导入脚本

    Python实现根据Excel生成Model和数据导入脚本

    最近遇到一个需求,有几十个Excel,每个的字段都不一样,然后都差不多是第一行是表头,后面几千上万的数据,需要把这些Excel中的数据全都加入某个已经上线的Django项目。所以我造了个自动生成 Model和导入脚本的轮子,希望对大家有所帮助
    2022-11-11
  • 使用Python和OpenCV实现实时文档扫描与矫正系统

    使用Python和OpenCV实现实时文档扫描与矫正系统

    在日常工作和学习中,我们经常需要将纸质文档数字化,手动拍摄文档照片常常会出现角度倾斜、透 视变形等问题,影响后续使用,本文将介绍如何使用Python和OpenCV构建一个实时文档扫描与矫正系统,能够通过摄像头自动检测文档边缘并进行透 视变换矫正,需要的朋友可以参考下
    2025-05-05
  • 解决phantomjs截图失败,phantom.exit位置的问题

    解决phantomjs截图失败,phantom.exit位置的问题

    今天小编就为大家分享一篇解决phantomjs截图失败,phantom.exit位置的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • 自定义Django_rest_framework_jwt登陆错误返回的解决

    自定义Django_rest_framework_jwt登陆错误返回的解决

    这篇文章主要介绍了自定义Django_rest_framework_jwt登陆错误返回的解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-10-10
  • pandas dataframe rolling移动计算方式

    pandas dataframe rolling移动计算方式

    在Pandas中,rolling()方法用于执行移动窗口计算,常用于时间序列数据分析,例如,计算某商品的7天或1个月销售总量,可以通过rolling()轻松实现,该方法的关键参数包括window(窗口大小),min_periods(最小计算周期)
    2024-09-09
  • Python三维绘图之Matplotlib库的使用方法

    Python三维绘图之Matplotlib库的使用方法

    这篇文章主要给大家介绍了关于Python三维绘图之Matplotlib库的使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • 使用Python实现一键往Word文档的表格中填写数据

    使用Python实现一键往Word文档的表格中填写数据

    在工作中,我们经常遇到将Excel表中的部分信息填写到Word文档的对应表格中,以生成报告,方便打印,所以本文小编就给大家介绍了如何使用Python实现一键往Word文档的表格中填写数据,文中有详细的代码示例供大家参考,需要的朋友可以参考下
    2023-12-12

最新评论