如何从csv文件构建Tensorflow的数据集

 更新时间:2020年09月21日 11:13:15   作者:Sight Tech.  
这篇文章主要介绍了如何从csv文件构建Tensorflow的数据集,帮助大家更好的理解和使用Tensorflow,感兴趣的朋友可以了解下

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注脚本之家其它相关文章!

相关文章

  • 常用的Python代码调试工具总结

    常用的Python代码调试工具总结

    今天给大家带来的是关于Python的相关知识,文章围绕着Python代码调试工具展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • 基于PyTorch实现一个简单的CNN图像分类器

    基于PyTorch实现一个简单的CNN图像分类器

    本文记录了一个简单的基于pytorch的图像多分类器模型构造过程,参考自Pytorch官方文档、磐创团队的《PyTorch官方教程中文版》以及余霆嵩的《PyTorch 模型训练实用教程》。从加载数据集开始,包括了模型设计、训练、测试等过程。
    2021-05-05
  • python boto和boto3操作bucket的示例

    python boto和boto3操作bucket的示例

    这篇文章主要介绍了python boto和boto3操作bucket的示例,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-10-10
  • python系统指定文件的查找只输出目录下所有文件及文件夹

    python系统指定文件的查找只输出目录下所有文件及文件夹

    这篇文章主要介绍了python系统指定文件的查找只输出目录下所有文件及文件夹,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-01-01
  • python中asyncio异步编程学习

    python中asyncio异步编程学习

    这篇文章主要介绍了python中asyncio异步编程学习,内部就是基于协程实现的异步编程,如果想研究异步编程的同学,要仔细看哦
    2021-04-04
  • 使用Python解析JSON数据的基本方法

    使用Python解析JSON数据的基本方法

    这篇文章主要介绍了使用Python解析JSON数据的基本方法,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-10-10
  • Springboo如何t动态修改配置文件属性

    Springboo如何t动态修改配置文件属性

    这篇文章主要介绍了Springboo如何t动态修改配置文件属性问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • 关于python3安装pip及requests库的导入问题

    关于python3安装pip及requests库的导入问题

    小编最近快毕业了,闲着无事学习下python的内容在学习到requsets库的导入问题时遇到一些问题,通过查找相关资料问题顺利解决,今天小编把问题解决思路及注意事项分享给大家供大家参考学习
    2021-05-05
  • 用Python实现插值算法

    用Python实现插值算法

    大家好,本篇文章主要讲的是用Python实现插值算法,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-02-02
  • python实现文件分片上传的接口自动化

    python实现文件分片上传的接口自动化

    这篇文章主要为大家详细介绍了python实现文件分片上传的接口自动化,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-11-11

最新评论