1、概述
在前一篇文章中详细讨论了迭代器与数据集的相关内容。由于数据集与迭代器是链接原始数据与程序连接的渠道,所以本文主要讨论如何从原始数据中构建数据集,主要涉及以下场景:
- 内存
- TFRecord data
- 文本文件
- csv文件
2、从内存中读取数据
如果所有的数据都以numpy数据组的形式预先保存到了内存当中,那么我们使用Dataset.from_tensor_slices()方法可以非常方便的将一个这样的数组转化为tensorflow的张量对象。下面以手写数字的数据集为例来说明这个场景的应用。
1)首先像我们在使用卷积神经网络处理手写数字一样,我们先来下载数据集,参考代码属下:
# mnist数据集
import tensorflow as tf
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
2)由于只是讨论数据的处理,这里我们使用数据量比较小的测试数据集
import numpy as np
eval_data = mnist.test.images
print(eval_data.shape)
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
print(eval_labels.shape)
输出结果如下:
3)现在这个数据就是一个numpy array 的数据集,下面将其转化为一个Dataset,同时构造一个one-shot的迭代器来读取里边的数据
dataset = tf.data.Dataset.from_tensor_slices((eval_data, eval_labels))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
就目前tensorflow的机制来讲,所有保存在内存中的数据都需要转化为numpy数组才能够被更好的处理,如果是通过pandas读取的数据需要先转化为nd array才能够输入给tensorflow。
3、读取TFRecord 格式的数据
由于数据量非常大并不适用于直接读入到内存当中,只能从磁盘文件进行数据读入,tf.data支持各种格式的文件读入。TFRecord文件格式是一种简单的面向记录的二进制格式,许多TensorFlow应用程序使用它来训练数据。
1)关于TFRecord
tfrecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList、tf.train.Int64List、tf.train.FloatList写入tf.train.Feature,如下所示:
#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
我们可以使用TFRecordWriter类来实现TFRecord格式的生成,下面是生成的详细代码:
import numpy as np
eval_data = mnist.test.images
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
writer = tf.python_io.TFRecordWriter('tfdata/img.tfrecords')
for i in range(len(eval_labels)):
row_data = eval_data[i]
labels = eval_labels[i]
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[row_data.tostring()]))
}))
# 将信息写入指定路径
writer.write(example.SerializeToString())
下面可以使用dataset来读取数据,请参照如下代码:
def parse_data(data):
feats = tf.parse_single_example(data, features={'img_raw':tf.FixedLenFeature([], tf.string),'label':tf.FixedLenFeature([],tf.int64)})
image = tf.decode_raw(feats['img_raw'], tf.float32)
label = feats['label']
return image, label
dataset = tf.data.TFRecordDataset('tfdata/img.tfrecords')
dataset = dataset.map(parse_data)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
value = sess.run(next_element)
print(value)
4、读取文本文件
很多数据集都是作为一个或多个文本文件分布的。tf.data.TextLineDataset 提供了一种从一个或多个文本文件中提取行的简单方法。给定一个或多个文件名,TextLineDataset 会为这些文件的每行生成一个字符串值元素。像 TFRecordDataset 一样,TextLineDataset 将 filenames 视为 tf.Tensor,因此您可以通过传递 tf.placeholder(tf.string) 来进行参数化,请参照如下代码:
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)