tensorflow实战之数据加载进阶

老子叫甜甜 提交于 2020-02-06 00:55:29

我们知道Dataset对象能够方便的进行数据加载,而Dataset数据集一般使用流程如下:

(1)创建Dataset对象;

(2)对Dataset对象进行变换操作;

(3)创建Dataset迭代器;

(4)在会话Session中取数据。

对于(1)(2)我们已经有详细的了解,在实际使用过程中如何创建迭代器?如何进行多epoch的训练?当遍历完数据集时如何继续输入数据进行训练?另外,对于不同的数据类型如何加载,如Imagenet数据、yolo等目标检测数据,本篇将继续深入的介绍。


目录

一、Dataset对象迭代器的创建方法

总结

二、数据加载方法汇总

1、ImageNet数据

2、VOC2014数据

3、文本数据

一、Dataset对象迭代器的创建方法

  • make_one_shot_iterator

该迭代器只会将对应的数据遍历一次,不会多次遍历。因此,该迭代器内部自动实现了迭代器的初始化,而其他类型的迭代器需要额外的初始化。这种迭代器适用于之遍历一次整个数据集的情况,多次遍历需要其他类型的迭代器。

import tensorflow as tf
import numpy as np

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()			#从到到尾读一次
one_element = iterator.get_next()					#从iterator里取出一个元素
with tf.Session() as sess:	# 建立会话(session)
    while True:
        try:
            print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            break

运行结果:

注意,在Session中通常需要try和except tf.errors.OutOfRangeError来配合使用,因为一般当 Dataset 中的数据被读取完毕的时候,程序会抛出异常,获取这个异常就可以从容结束本次数据的迭代。

  • make_initializable_iterator

可初始化的迭代器,它能够重复的遍历同一个数据集,之所以能够重复进行遍历,主要是当遍历一遍后通过初始化迭代器再重新进行遍历,所以这里会用到try-catch来设置迭代器的重新初始化。一般在实际运用中经常用到。

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0]))
iterator = dataset.make_initializable_iterator()			
one_element = iterator.get_next()					#从iterator里取出一个元素
init=iterator.initializer
with tf.Session() as sess:	# 建立会话(session)
    for epoch in range(2):
        print("================")
        sess.run(init)
        while True:
            try:
                print(sess.run(one_element))
            except tf.errors.OutOfRangeError:
                break

运行结果:

跟单次 Iterator 的代码只有 2 处不同:

1、创建的方式不同,dataset.make_initializable_iterator();

2、每次重新初始化的时候,都要调用sess.run(iterator.initializer)。

  • tf.data.Iterator.from_structure

更加灵活的迭代器,其支持多个数据集读取,和上面一个迭代器类似,也支持迭代器的初始化。

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0]))
iterator = tf.data.Iterator.from_structure(dataset.output_types,dataset.output_shapes)
one_element = iterator.get_next()					
init=iterator.make_initializer(dataset)
with tf.Session() as sess:	# 建立会话(session)
    for epoch in range(2):
        print("================")
        sess.run(init)
        while True:
            try:
                print(sess.run(one_element))
            except tf.errors.OutOfRangeError:
                break

运行结果:

总结

1、 单次 Iterator ,它最简单,但无法重用,无法处理数据集参数化的要求。
2、 可以初始化的 Iterator ,它可以满足 Dataset 重复加载数据,满足了参数化要求。
3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。

目前来讲,需要掌握最后两种,这两种用的比较多。同时这节也说明了如何重复使用数据,如何通过捕获异常来进行数据的重复训练,这些都是要注意的点。

参考链接:

TensorFlow』数据读取类_data.Dataset

【Tensorflow】Dataset 中的 Iterator

二、数据加载方法汇总

1、ImageNet数据

2、VOC2014数据

3、文本数据

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!