基于Python实现的简单数字识别程序
这里我们使用全连接神经网络(MLP) 实现的 MNIST 数字识别代码,结构更简单,仅包含几个线性层和激活函数。
简易代码
模型定义代码,model.py
import torch.nn as nn
# 定义一个简单的 CNN 模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.2)
def forward(self, x):
x = self.flatten(x) # [B, 1, 28, 28] -> [B, 784]
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x) # 输出层不加激活(CrossEntropyLoss 内部含 softmax)
return x然后训练代码,train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import SimpleModel # 👈 从 model.py 导入
# 配置
batch_size = 64
learning_rate = 0.001
num_epochs = 10
model_save_path = 'mnist_mlp.pth'
# 数据
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 模型、损失、优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练
print(f"Training on {device}...")
model.train()
for epoch in range(num_epochs):
total_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')
# 保存
torch.save(model.state_dict(), model_save_path)
print(f"✅ Model saved to {model_save_path}")训练
在训练之前我们需要安装下python依赖
pip install torch torchvision
然后我们就可以开始训练模型啦!执行命令python ./train.py,你会看到类似输出
Training on cpu... Epoch [1/10], Loss: 0.3501 Epoch [2/10], Loss: 0.1702 Epoch [3/10], Loss: 0.1335 Epoch [4/10], Loss: 0.1141 Epoch [5/10], Loss: 0.1027 Epoch [6/10], Loss: 0.0915 Epoch [7/10], Loss: 0.0884 Epoch [8/10], Loss: 0.0801 Epoch [9/10], Loss: 0.0769 Epoch [10/10], Loss: 0.0715 ✅ Model saved to mnist_mlp.pth
目录下会生成一个mnist_mlp.pth,mnist_mlp.pth 是一个 PyTorch 模型权重保存文件,本质上是一个 序列化后的字典(state_dict),存储了神经网络中所有可学习参数(如权重和偏置)的数值。
测试模型
现在我们拿我们的模型去试试我们的数字图片了~
predict.py
# predict.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import SimpleModel
import argparse
import os
def predict_image(image_path, model_path='mnist_mlp.pth', device='cpu'):
# 1. 加载模型
model = SimpleModel()
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval() # 推理模式
# 2. 图像预处理(必须和训练时一致!)
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 转灰度
transforms.Resize((28, 28)), # 调整为 28x28
transforms.ToTensor(), # 转为 Tensor [0,1]
transforms.Normalize((0.1307,), (0.3081,)) # 用 MNIST 的均值/标准差
])
# 3. 加载并预处理图像
image = Image.open(image_path).convert('L') # 强制灰度(兼容 RGB 输入)
input_tensor = transform(image) # shape: [1, 28, 28]
input_batch = input_tensor.unsqueeze(0) # 增加 batch 维度 → [1, 1, 28, 28]
# 4. 推理
with torch.no_grad():
output = model(input_batch)
probabilities = torch.softmax(output, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence = probabilities[0][predicted_class].item()
return predicted_class, confidence
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict digit in an image using trained MLP')
parser.add_argument('image_path', type=str, help='Path to the input image (e.g., digit.png)')
args = parser.parse_args()
if not os.path.exists(args.image_path):
print(f"❌ Error: Image file '{args.image_path}' not found!")
exit(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
digit, conf = predict_image(args.image_path, device=device)
print(f"✅ Predicted digit: {digit}")
print(f"📊 Confidence: {conf:.4f} ({conf*100:.2f}%)")我们可以python .\predict.py .\data\digit.png来看看预测的结果如何。
到此这篇关于基于Python实现的简单数字识别程序的文章就介绍到这了,更多相关Python数字识别程序内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
tensorflow之获取tensor的shape作为max_pool的ksize实例
今天小编就为大家分享一篇tensorflow之获取tensor的shape作为max_pool的ksize实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-01-01
解决phantomjs截图失败,phantom.exit位置的问题
今天小编就为大家分享一篇解决phantomjs截图失败,phantom.exit位置的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-05-05
自定义Django_rest_framework_jwt登陆错误返回的解决
这篇文章主要介绍了自定义Django_rest_framework_jwt登陆错误返回的解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2020-10-10
pandas dataframe rolling移动计算方式
在Pandas中,rolling()方法用于执行移动窗口计算,常用于时间序列数据分析,例如,计算某商品的7天或1个月销售总量,可以通过rolling()轻松实现,该方法的关键参数包括window(窗口大小),min_periods(最小计算周期)2024-09-09


最新评论