pytorch制作自己的LMDB数据操作示例

 更新时间:2019年12月18日 12:04:58   作者:团长sama  
这篇文章主要介绍了pytorch制作自己的LMDB数据操作,结合实例形式分析了pytorch使用lmdb的相关操作技巧与使用注意事项,需要的朋友可以参考下

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程

希望本文所述对大家Python程序设计有所帮助。

相关文章

  • Python 随机生成测试数据的模块:faker基本使用方法详解

    Python 随机生成测试数据的模块:faker基本使用方法详解

    这篇文章主要介绍了Python 随机生成测试数据的模块:faker基本使用方法,结合实例形式详细分析了Python 随机生成测试数据的模块faker基本功能、原理、使用方法及操作注意事项,需要的朋友可以参考下
    2020-04-04
  • Python如何基于smtplib发不同格式的邮件

    Python如何基于smtplib发不同格式的邮件

    这篇文章主要介绍了Python如何基于smtplib发不同格式的邮件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • python3.7 的新特性详解

    python3.7 的新特性详解

    这篇文章主要介绍了python3.7 的新特性详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python Tkinter Gui运行不卡顿(解决多线程解决界面卡死问题)

    Python Tkinter Gui运行不卡顿(解决多线程解决界面卡死问题)

    最近写的Python代码不知为何,总是执行到一半卡住不动,所以下面这篇文章主要给大家介绍了关于Python Tkinter Gui运行不卡顿,解决多线程解决界面卡死问题的相关资料,需要的朋友可以参考下
    2023-02-02
  • python2与python3爬虫中get与post对比解析

    python2与python3爬虫中get与post对比解析

    这篇文章主要介绍了python2与python3爬虫中get与post对比解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 基于Python打造账号共享浏览器功能

    基于Python打造账号共享浏览器功能

    这篇文章主要介绍了基于Python打造账号共享浏览器功能,本文图文并茂给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-05-05
  • Python定时任务随机时间执行的实现方法

    Python定时任务随机时间执行的实现方法

    这篇文章主要介绍了Python定时任务随机时间执行的实现方法,文中给大家提到了python定时执行任务的三种方式 ,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • Python的Random库的使用方法详解

    Python的Random库的使用方法详解

    这篇文章主要介绍了Python的Random库的使用方法详解,random库是使用随机数的Python标准库,python中用于生成伪随机数的函数库是random,需要的朋友可以参考下
    2023-07-07
  • python实现梯度下降算法

    python实现梯度下降算法

    这篇文章主要为大家详细介绍了python实现梯度下降算法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-08-08
  • Pyside6 安装和简单界面开发过程详细介绍

    Pyside6 安装和简单界面开发过程详细介绍

    PySide是跨平台应用程序框架Qt的Python绑定,Qt是跨平台C++图形可视化界面应用开发框架,自推出以来深受业界盛赞,Pyside6是利用Python语言进行开发的GUI,所以在使用Pyside6前要先安装Python环境,本文给大家介绍Pyside6 安装和简单界面开发过程,一起看看吧
    2023-10-10

最新评论