pytorch cnn 识别手写的字实现自建图片数据

 更新时间:2018年05月20日 17:03:26   作者:瓦力冫  
这篇文章主要介绍了pytorch cnn 识别手写的字实现自建图片数据,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt

结果如下:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 利于python脚本编写可视化nmap和masscan的方法

    利于python脚本编写可视化nmap和masscan的方法

    这篇文章主要介绍了利于python脚本编写可视化nmap和masscan的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-12-12
  • matlab xlabel位置的设置方式

    matlab xlabel位置的设置方式

    这篇文章主要介绍了matlab xlabel位置的设置方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • 十一个案例带你吃透Python函数参数

    十一个案例带你吃透Python函数参数

    这篇文章主要通过十一个案例带大家一起了解一下Python中的函数参数,文中的示例代码讲解详细,对我们学习Python有一定帮助,需要的可以参考一下
    2022-08-08
  • TensorFlow人工智能学习Keras高层接口应用示例

    TensorFlow人工智能学习Keras高层接口应用示例

    这篇文章主要为大家介绍了TensorFlow人工智能学习中Keras高层接口的应用示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步
    2021-11-11
  • 关于django 数据库迁移(migrate)应该知道的一些事

    关于django 数据库迁移(migrate)应该知道的一些事

    今天小编就为大家分享一篇关于django 数据库迁移(migrate)应该知道的一些事,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python基于unittest实现测试用例执行

    Python基于unittest实现测试用例执行

    这篇文章主要介绍了Python基于unittest实现测试用例执行,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • Python实现蚁群算法

    Python实现蚁群算法

    本文主要介绍了Python实现蚁群算法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-03-03
  • Python使用Slider组件实现调整曲线参数功能示例

    Python使用Slider组件实现调整曲线参数功能示例

    这篇文章主要介绍了Python使用Slider组件实现调整曲线参数功能,结合实例形式分析了Python使用matplotlib与Slider组件进行图形绘制相关操作技巧,需要的朋友可以参考下
    2019-09-09
  • django中账号密码验证登陆功能的实现方法

    django中账号密码验证登陆功能的实现方法

    这篇文章主要介绍了django中账号密码验证登陆功能的实现方法,本文图文并茂给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • python实现得到当前登录用户信息的方法

    python实现得到当前登录用户信息的方法

    这篇文章主要介绍了python实现得到当前登录用户信息的方法,结合实例形式分析了Python在Linux平台以及Windows平台使用相关模块获取用户信息的相关操作技巧,需要的朋友可以参考下
    2019-06-06

最新评论