tensorflow入门:TFRecordDataset变长数据的batch读取详解

 更新时间:2020年01月20日 10:43:29   作者:yeqiustu  
今天小编就为大家分享一篇tensorflow入门:TFRecordDataset变长数据的batch读取详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 一文解决Python切换版本问题

    一文解决Python切换版本问题

    由于mac默认都会安装python2.x,这给我们python开发造成不便,我们经常要用到python3.x的版本来进行测试、开发,所以本文主要介绍了Python切换版本问题,感兴趣的可以了解一下
    2021-07-07
  • Python+tkinter模拟“记住我”自动登录实例代码

    Python+tkinter模拟“记住我”自动登录实例代码

    这篇文章主要介绍了Python+tkinter模拟“记住我”自动登录实例代码,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • Python实现FTP文件定时自动下载的步骤

    Python实现FTP文件定时自动下载的步骤

    这篇文章主要介绍了Python实现FTP文件定时自动下载的示例,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-12-12
  • Python中注释(多行注释和单行注释)的用法实例

    Python中注释(多行注释和单行注释)的用法实例

    这篇文章主要给大家介绍了关于Python中注释(多行注释和单行注释)用法的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-08-08
  • Python实战使用Selenium爬取网页数据

    Python实战使用Selenium爬取网页数据

    这篇文章主要为大家介绍了Python实战使用Selenium爬取网页数据示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步早日升职加薪
    2023-05-05
  • python模块smtplib实现纯文本邮件发送功能

    python模块smtplib实现纯文本邮件发送功能

    这篇文章主要为大家详细介绍了python模块smtplib实现纯文本邮件发送功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-05-05
  • python 类相关概念理解

    python 类相关概念理解

    这篇文章主要介绍了简单了解python类概念,具有一定借鉴价值,需要的朋友可以参考下,希望能够给你带来帮助
    2021-09-09
  • 使用Python Turtle库带你玩转创意绘图(画个心,写个花)

    使用Python Turtle库带你玩转创意绘图(画个心,写个花)

    Python的turtle库提供了一种有趣且易于上手的编程绘图方式,适合初学者学习,通过本文的介绍,你将了解到如何进行画布设置、画笔属性的调整、画笔的移动与控制,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2024-11-11
  • Flask模板渲染与Get和Post请求详细介绍

    Flask模板渲染与Get和Post请求详细介绍

    这篇文章主要介绍了Flask模板渲染与Get和Post请求,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-09-09
  • Python中异常处理的5个最佳实践分享

    Python中异常处理的5个最佳实践分享

    异常处理是编写健壮可靠的 Python 代码的一个基本方面,这篇文章为大家整理了Python中异常处理的5个最佳实践,文中的示例代码讲解详细,希望对大家有所帮助
    2024-01-01

最新评论