Python神经网络5之数据读取2
- 数据读取
- TFRecords
- TFRecords文件
- 案例:CIFAR10数据存入TFRecords文件
- 读取TFRecords文件API
- 案例:读取CIFAR的TFRecords文件
数据读取
TFRecords
TFRecords文件
TFRecords其实是一种二进制文件,虽然不如其他格式好理解,但是能够更好地利用内存,更方便复制和移动,并且不需要单独的标签文件(特征值和目标值存放在同一个Example中)
使用步骤:
1.获取数据
2.将数据填入到Example协议内存块(protocol buffer)
3.将协议内存块序列化为字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecords文件
- 文件格式 *.tfrecords
Example结构解析:
- tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段Features)
Features包含了一个feature字段(键为名字、值为Feature的字典)
Feature中包含要写入的数据,并指明数据类型。
这是一个样本的结构,批数据需要循环存入这样的结构
- tf.train.Example(features=None)
构造要写入tfrecords文件的样本
features:tf.train.Features类型的特征实例
return:example格式协议块
- tf.train.Features(feature=None)
构建每个样本的信息键值对
feature:字典数据,key为要保存的名字
value为tf.train.Feature实例
return:Features类型
- tf.train.Feature(options)
- options:例如
- bytes_list=tf.train.BytesList(value=[Bytes])
- int64_list=tf.train.Int64List(value=[Value])
- 支持存入的类型如下
- tf.train.Int64List(value=[Value])
- tf.train.BytesList(value=[Bytes])
- tf.train.FloatList(value=[value])
# Build one Example: `image` must be a bytes object (BytesList) and
# `label` an integer (Int64List).
example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
将example序列化:example.SerializeToString()
案例:CIFAR10数据存入TFRecords文件
- 构造存储实例,tf.python_io.TFRecordWriter(path)
- 写入tfrecords文件
- path:TFRecords文件的路径
- return:写文件器(TFRecordWriter实例)
- method方法
- write(record):向文件中写入一个example
- close():关闭文件写入器
- 循环将数据填入到Example协议内存块(protocol buffer)
def write_to_tfrecords(self, image_batch, label_batch, path="cifar10.tfrecords"):
    """Write every (image, label) sample of a batch into a TFRecords file.

    :param image_batch: sequence of numpy image arrays, one per sample.
    :param label_batch: sequence of label rows; ``label_batch[i][0]`` is the
        label of sample ``i``.
    :param path: output TFRecords file path.
    :return: None
    :raises ValueError: if the two batches have different lengths.
    """
    # Fail early instead of writing a truncated file (or hitting IndexError).
    if len(image_batch) != len(label_batch):
        raise ValueError(
            "image_batch and label_batch must have the same length, got %d and %d"
            % (len(image_batch), len(label_batch)))
    with tf.python_io.TFRecordWriter(path) as writer:
        # Build one Example per sample, serialize it and append it to the file.
        for image, label in zip(image_batch, label_batch):
            example = tf.train.Example(features=tf.train.Features(feature={
                # tobytes() is the non-deprecated spelling of tostring().
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                # Convert the numpy scalar to a plain int for Int64List.
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label[0])])),
            }))
            writer.write(example.SerializeToString())
    return None
生成cifar10.tfrecords文件:
读取TFRecords文件API
读取这种文件的整个过程与读取其他文件一样,只不过多了一个解析Example的步骤:先用tf.TFRecordReader从TFRecords文件中读取序列化的数据,再用tf.parse_single_example解析器将Example协议内存块(protocol buffer)解析为张量
- tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字,值为解析出的张量
- tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string
案例:读取CIFAR的TFRecords文件
- 构造文件名队列
- 读取和解码
  - 读取
  - 解析Example
  - 解码
- 构造批处理队列
def read_tfrecords(self, path="cifar10.tfrecords"):
    """Read one batch of samples back from a TFRecords file.

    :param path: path of the TFRecords file (as produced by write_to_tfrecords).
    :return: ``(image_value, label_value)`` numpy arrays: uint8 images of shape
        ``(100, self.height, self.width, self.channels)`` and int64 labels of
        shape ``(100,)``.
    """
    # 1. Build the filename queue.
    file_queue = tf.train.string_input_producer([path])
    # 2. Read one serialized Example and parse it into tensors.
    reader = tf.TFRecordReader()
    _, value = reader.read(file_queue)
    feature = tf.parse_single_example(value, features={
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
    })
    image = feature["image"]
    label = feature["label"]
    print("read_tf_image:\n", image)
    print("read_tf_label:\n", label)
    # Decode the raw image bytes, then restore the [height, width, channels] shape.
    image_decoded = tf.decode_raw(image, tf.uint8)
    print("image_decoded:\n", image_decoded)
    image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
    print("image_reshaped:\n", image_reshaped)
    # 3. Build the batching queue.
    image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
    print("image_batch:\n", image_batch)
    print("label_batch:\n", label_batch)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value:\n", image_value)
            print("label_value:\n", label_value)
        finally:
            # Stop and join the queue-runner threads even if sess.run raises,
            # so they are not left running.
            coord.request_stop()
            coord.join(threads)
    return image_value, label_value
全部代码:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
class Cifar(object):
    """Read CIFAR-10 binary batches and round-trip them through a TFRecords file.

    All TensorFlow work uses the 1.x graph/session API with queue runners
    (``tf.train.string_input_producer``, ``tf.train.batch``, ``tf.Session``).
    """

    def __init__(self):
        # Geometry of one CIFAR-10 image.
        self.height = 32
        self.width = 32
        self.channels = 3
        # Byte counts per record: image bytes, label bytes, and their sum,
        # which read_binary uses as the fixed record length.
        self.image = self.height * self.width * self.channels
        self.label = 1
        self.sample = self.image + self.label

    def read_binary(self, data_dir="./cifar-10-batches-bin"):
        """Read one batch of 100 samples from the CIFAR-10 ``.bin`` files.

        :param data_dir: directory holding the binary files; every file whose
            name ends with ``.bin`` is fed to the filename queue.
        :return: ``(label_value, image_value)`` uint8 numpy arrays of shape
            ``(100, 1)`` and ``(100, height, width, channels)``.
        :raises FileNotFoundError: if ``data_dir`` contains no ``.bin`` file.
        """
        # 1. Build the filename queue (sorted for a deterministic file list).
        filename_list = sorted(os.listdir(data_dir))
        print("file_name:\n", filename_list)
        file_list = [os.path.join(data_dir, name) for name in filename_list if name.endswith(".bin")]
        print("file_list:\n", file_list)
        if not file_list:
            raise FileNotFoundError("no .bin files found in %r" % (data_dir,))
        file_queue = tf.train.string_input_producer(file_list)
        # 2. Read fixed-length records: self.label bytes followed by self.image bytes.
        reader = tf.FixedLengthRecordReader(self.sample)
        # key identifies the record's source file; value holds one sample's bytes.
        key, value = reader.read(file_queue)
        print("key:\n", key)
        print("value:\n", value)
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # Split the record into the leading label byte(s) and the image bytes.
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label:\n", label)
        print("image:\n", image)
        # The image bytes are laid out as [channels, height, width] ...
        image_reshaped = tf.reshape(image, shape=[self.channels, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)
        # ... so transpose them to [height, width, channels].
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)
        # 3. Batch.
        label_batch, image_batch = tf.train.batch([label, image_transposed], batch_size=100, num_threads=2, capacity=100)
        print("label_batch:\n", label_batch)
        print("image_batch:\n", image_batch)
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                label_value, image_value = sess.run([label_batch, image_batch])
                print("label_value:\n", label_value)
                print("image_value:\n", image_value)
            finally:
                # Stop and join the queue-runner threads even if sess.run raises.
                coord.request_stop()
                coord.join(threads)
        return label_value, image_value

    @staticmethod
    def _iter_examples(image_batch, label_batch):
        """Yield ``(image_bytes, label_int)`` pairs from parallel batches.

        Each image is serialized with ``tobytes()`` (the non-deprecated form of
        ``tostring()``); the first element of each label row is converted to a
        plain ``int`` so that ``tf.train.Int64List`` accepts it.
        """
        for image, label in zip(image_batch, label_batch):
            yield image.tobytes(), int(label[0])

    def write_to_tfrecords(self, image_batch, label_batch, path="cifar10.tfrecords"):
        """Write every (image, label) sample of a batch into a TFRecords file.

        :param image_batch: sequence of numpy image arrays, one per sample
            (e.g. the ``image_value`` returned by read_binary).
        :param label_batch: sequence of label rows; ``label_batch[i][0]`` is
            the label of sample ``i``.
        :param path: output TFRecords file path.
        :return: None
        :raises ValueError: if the two batches have different lengths.
        """
        # Fail early instead of writing a truncated file (or hitting IndexError).
        if len(image_batch) != len(label_batch):
            raise ValueError(
                "image_batch and label_batch must have the same length, got %d and %d"
                % (len(image_batch), len(label_batch)))
        with tf.python_io.TFRecordWriter(path) as writer:
            # Build one Example per sample, serialize it and append it to the file.
            for image, label in self._iter_examples(image_batch, label_batch):
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                writer.write(example.SerializeToString())
        return None

    def read_tfrecords(self, path="cifar10.tfrecords"):
        """Read one batch of samples back from a TFRecords file.

        :param path: path of the TFRecords file (as produced by write_to_tfrecords).
        :return: ``(image_value, label_value)`` numpy arrays: uint8 images of
            shape ``(100, height, width, channels)`` and int64 labels of shape
            ``(100,)``.
        """
        # 1. Build the filename queue.
        file_queue = tf.train.string_input_producer([path])
        # 2. Read one serialized Example and parse it into tensors.
        reader = tf.TFRecordReader()
        _, value = reader.read(file_queue)
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)
        # Decode the raw image bytes, then restore the [height, width, channels] shape.
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshaped:\n", image_reshaped)
        # 3. Build the batching queue.
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:\n", image_value)
                print("label_value:\n", label_value)
            finally:
                # Stop and join the queue-runner threads even if sess.run raises.
                coord.request_stop()
                coord.join(threads)
        return image_value, label_value
if __name__ == "__main__":
    dataset = Cifar()
    # Optional first run: convert the raw CIFAR-10 binaries into cifar10.tfrecords.
    # label_value, image_value = dataset.read_binary()
    # dataset.write_to_tfrecords(image_value, label_value)
    # Read one batch back from the TFRecords file.
    dataset.read_tfrecords()