• TensorFlow Datasets 数据集载入

    TensorFlow Datasets 数据集载入

    TensorFlow Datasets 是一个开箱即用的数据集集合,包含数十种常用的机器学习数据集。通过简单的几行代码即可将数据以 tf.data.Datasets 的格式载入。关于 tf.data.Datasets 的使用可参考 tf.data。

    该工具是一个独立的Python包,可以通过:

    1. pip install tensorflow-datasets

    安装。

    在使用时,首先使用import导入该包

    1. import tensorflow as tf
    2. import tensorflow_datasets as tfds

    然后,最基础的用法是使用 tfds.load 方法,载入所需的数据集,如:

    1. dataset = tfds.load("mnist", split=tfds.Split.TRAIN)
    2. dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
    3. dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)

    该方法返回一个 tf.data.Datasets 对象。部分重要的参数如下:

    • as_supervised :若为True,则根据数据集的特性返回为 (input, label) 格式,否则返回所有特征的字典。

    • split:指定返回数据集的特定部分,若无则返回整个数据集。一般有 tfds.Split.TRAIN (训练集)和 tfds.Split.TEST (测试集)选项。

    当前支持的数据集可在 官方文档 或使用 tfds.list_builders() 查看。

    当得到了 tf.data.Datasets 类型的数据集后,我们即可使用 tf.data 对数据集进行各种预处理以及读取数据。例如:

    1. # 使用 TessorFlow Datasets 载入“tf_flowers”数据集
    2. dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
    3. # 对 dataset 进行大小调整、打散和分批次操作
    4. dataset = dataset.map(lambda img, label: (tf.image.resize(img, [224, 224]) / 255.0, label)) \
    5. .shuffle(1024) \
    6. .batch(32)
    7. # 迭代数据
    8. for images, labels in dataset:
    9. # 对images和labels进行操作

    详细操作说明可见 本文档的 tf.data 一节 。

    提示

    在使用 TensorFlow Datasets 时,可能需要设置代理。较为简易的方式是设置 TFDS_HTTPS_PROXY 环境变量,即

    1. export TFDS_HTTPS_PROXY=http://代理服务器IP:端口