pytorch cnn 识别手写的字实现自建图片数据

 更新时间:2018年05月20日 17:03:26   作者:瓦力冫  
这篇文章主要介绍了pytorch cnn 识别手写的字实现自建图片数据,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt

结果如下:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • Python 使用双重循环打印图形菱形操作

    Python 使用双重循环打印图形菱形操作

    这篇文章主要介绍了Python 使用双重循环打印图形菱形操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-08-08
  • python实现汽车管理系统

    python实现汽车管理系统

    这篇文章主要为大家详细介绍了python实现汽车管理系统,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • python创建列表并给列表赋初始值的方法

    python创建列表并给列表赋初始值的方法

    这篇文章主要介绍了python创建列表并给列表赋初始值的方法,涉及Python列表的定义与赋值技巧,需要的朋友可以参考下
    2015-07-07
  • Python OpenCV超详细讲解基本功能

    Python OpenCV超详细讲解基本功能

    OpenCV用C++语言编写,它具有C ++,Python,Java和MATLAB接口,并支持Windows,Linux,Android和Mac OS,OpenCV主要倾向于实时视觉应用,并在可用时利用MMX和SSE指令,本篇文章带你了解OpenCV的基本功能
    2022-04-04
  • Opencv Python实现两幅图像匹配

    Opencv Python实现两幅图像匹配

    这篇文章主要为大家详细介绍了Opencv Python实现两幅图像匹配,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-06-06
  • python3+PyQt5重新实现QT事件处理程序

    python3+PyQt5重新实现QT事件处理程序

    这篇文章主要为大家详细介绍了python3+PyQt5重新实现QT事件处理程序,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • Python调用DeepSeek API实现对本地数据库的AI管理

    Python调用DeepSeek API实现对本地数据库的AI管理

    这篇文章主要为大家详细介绍了Python如何基于DeepSeek模型实现对本地数据库的AI管理,文中的示例代码简洁易懂,有需要的小伙伴可以跟随小编一起学习一下
    2025-02-02
  • python 已知三条边求三角形的角度案例

    python 已知三条边求三角形的角度案例

    这篇文章主要介绍了python 已知三条边求三角形的角度案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Python selenium find_element()示例详解

    Python selenium find_element()示例详解

    selenium定位元素的函数/方法可以分为两类:find_element及find_elements,下面这篇文章主要给大家介绍了关于Python selenium find_element()的相关资料,需要的朋友可以参考下
    2022-07-07
  • 一文详解如何在Matplotlib中更改图例字体大小

    一文详解如何在Matplotlib中更改图例字体大小

    在我们处理数据的时候,需要对大量的数据进行绘图,就免不了要使用到Matplotlib,下面这篇文章主要给大家介绍了关于如何在Matplotlib中更改图例字体大小的相关资料,需要的朋友可以参考下
    2023-05-05

最新评论