Python Neural Networks 5: Data Reading, Part 2

  • Data Reading
  • TFRecords
  • TFRecords Files
  • Case Study: Storing CIFAR10 Data in a TFRecords File
  • API for Reading TFRecords Files
  • Case Study: Reading the CIFAR TFRecords File


Data Reading

TFRecords

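TFRecords Files

TFRecords is a binary file format. Although it is less human-readable than other formats, it makes better use of memory, is easier to copy and move, and does not need a separate label file.
Usage steps:
1. Obtain the data
2. Fill the data into an Example protocol buffer
3. Serialize the protocol buffer to a string and write it to a TFRecords file with tf.python_io.TFRecordWriter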

  • File format: *.tfrecords

Structure of an Example:


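  • tf.train.Example protocol buffer (the protocol buffer contains a features field of type Features)
    Features contains a feature field
    Each Feature holds the data to be written and specifies its data type.
    This is the structure of a single sample; for a batch, loop and store every sample in this structure
  • tf.train.Example(features=None)
    The record that is written to the tfrecords file
    features: a feature instance of type tf.train.Features
    return: an Example protocol buffer
  • tf.train.Features(feature=None)
    Builds the key-value pairs that describe each sample
    feature: a dict; each key is the name to save the value under,
    each value is a tf.train.Feature instance
    return: a Features instance
  • tf.train.Feature(options)
    options: for example
      bytes_list=tf.train.BytesList(value=[Bytes])
      int64_list=tf.train.Int64List(value=[Value])
    The supported list types are:
      tf.train.Int64List(value=[Value])
      tf.train.BytesList(value=[Bytes])
      tf.train.FloatList(value=[value])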

example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
Serialize the example: example.SerializeToString()
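
As a minimal sketch of the structure described above, the following builds one Example, serializes it, and parses it back with the protobuf FromString class method to check that both fields survive the round trip. The helper name make_example and the 12-byte stand-in image are illustrative assumptions, not part of the original code.

```python
import tensorflow as tf


def make_example(image_bytes, label):
    """Wrap one sample (raw image bytes plus an integer label) in an Example."""
    return tf.train.Example(features=tf.train.Features(feature={
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))


# Hypothetical stand-in for real image data: 12 raw bytes.
image_bytes = bytes(range(12))
serialized = make_example(image_bytes, 7).SerializeToString()

# Parse the serialized bytes back into an Example and read both fields.
restored = tf.train.Example.FromString(serialized)
restored_image = restored.features.feature["image"].bytes_list.value[0]
restored_label = restored.features.feature["label"].int64_list.value[0]
```

Each list type stores a sequence, which is why a single image or label is still wrapped in a one-element list and read back with value[0].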

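Case Study: Storing CIFAR10 Data in a TFRecords File

  • Construct the writer instance: tf.python_io.TFRecordWriter(path)
    Writes to a tfrecords file
    path: path of the TFRecords file
    return: a file writer
    Methods:
      write(record): write one serialized example to the file
      close(): close the file writer
  • Loop over the data, filling each sample into an Example protocol buffer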
def write_to_tfrecords(self, image_batch, label_batch):
        """
        Write the sample features and targets together into a tfrecords file
        :param image_batch: batch of image arrays
        :param label_batch: batch of label arrays
        :return: None
        """
        with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # Build an Example for each sample, serialize it, and write it to the file
            for i in range(100):
                # tobytes() replaces the deprecated tostring()
                image = image_batch[i].tobytes()
                label = int(label_batch[i][0])
                # print("tfrecords_image:\n",image)
                # print("tfrecords_label:\n",label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # Write the serialized example to the file
                writer.write(example.SerializeToString())
        return None

Running this generates the cifar10.tfrecords file.

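
To check that a write like the one above produced usable records, one option is to read the file straight back and count what is in it. The sketch below is not the queue-based reader used in this article: it swaps in tf.io.TFRecordWriter and tf.data.TFRecordDataset under TensorFlow 2 eager execution, and the temporary path and the three tiny 4-byte "images" are hypothetical stand-ins for the CIFAR data.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical output path inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")

# Write three tiny records; each "image" is 4 raw bytes.
with tf.io.TFRecordWriter(path) as writer:
    for label in range(3):
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes([label] * 4)])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())

# Read the raw records back in order and recover the labels.
labels = []
for raw_record in tf.data.TFRecordDataset(path):
    parsed = tf.train.Example.FromString(raw_record.numpy())
    labels.append(parsed.features.feature["label"].int64_list.value[0])
```

The labels come back in the order they were written, which makes this a quick sanity check before wiring the file into a training pipeline.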

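API for Reading TFRecords Files

Reading a TFRecords file follows the same overall process as reading other files, with one extra step: parsing the Example. Records are read from the file with tf.TFRecordReader and then parsed with tf.parse_single_example, which turns an Example protocol buffer into tensors.

  • tf.parse_single_example(serialized,features=None,name=None)
    Parses a single Example proto
    serialized: a scalar string Tensor holding one serialized Example
    features: a dict whose keys are the feature names to read and whose values are FixedLenFeature
    return: a dict of key-value pairs, keyed by the feature names that were read
  • tf.FixedLenFeature(shape,dtype)
    shape: shape of the input data; usually left unspecified, as an empty list
    dtype: type of the input data; it must match the type stored in the file
    The type can only be float32, int64, or string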
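
The parsing step can be tried on its own, without a file or a queue. The sketch below assumes TensorFlow 2 eager execution and uses the tf.io names (tf.io.parse_single_example, tf.io.FixedLenFeature, tf.io.decode_raw) rather than the v1 names listed above; the 2x2 RGB image and the label 3 are made-up values for illustration.

```python
import tensorflow as tf

# Hypothetical 2x2 RGB image stored as 12 raw uint8 bytes, with label 3.
height, width, channels = 2, 2, 3
serialized = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(range(12))])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
})).SerializeToString()

# The keys and dtypes must match what was stored in the Example.
parsed = tf.io.parse_single_example(serialized, features={
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
})

# The image comes back as one string scalar; decode it to uint8 and restore its shape.
image = tf.reshape(tf.io.decode_raw(parsed["image"], tf.uint8), [height, width, channels])
label = int(parsed["label"].numpy())
```

The image feature is declared as a string scalar because the whole image was stored as a single bytes value; the pixel shape is only recovered by the decode and reshape afterwards.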

Case Study: Reading the CIFAR TFRecords File

  1. Build the filename queue
  2. Read and decode
    Read
    Parse the Example
    Decode
  3. Build the batching queue
def read_tfrecords(self):
        """
        Read the TFRecords file
        :return: None
        """
        # 1. Build the filename queue
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2. Read and decode
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # Parse the Example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)
        # Decode the raw bytes into uint8 values
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # Reshape the image to [height, width, channels]
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshaped:\n", image_reshaped)

        # 3. Build the batching queue
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        # Start a session
        with tf.Session() as sess:
            # Start the queue-runner threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            #image_value, label_value, image_decoded_value = sess.run([image, label, image_decoded])
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value:\n", image_value)
            print("label_value:\n", label_value)

            # Stop the threads and release resources
            coord.request_stop()
            coord.join(threads)

        return None

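
The decode step only reinterprets raw bytes, so the dtype passed to tf.decode_raw has to match the dtype the image had when it was written. The sketch below illustrates the arithmetic with the standard-library array module standing in for tf.decode_raw (an assumption made purely for illustration): 3072 bytes give 3072 values when read as uint8, but only 768 when read as 4-byte floats, which could not be reshaped to 32x32x3.

```python
from array import array

height, width, channels = 32, 32, 3
expected_values = height * width * channels  # 3072 values per CIFAR-10 image

# Stand-in for one stored image: 3072 raw bytes, one byte per uint8 value.
raw = bytes(i % 256 for i in range(expected_values))

as_uint8 = array("B", raw)    # 1 byte per element, like decoding as uint8
as_float32 = array("f", raw)  # 4 bytes per element, like decoding as float32

# Only the uint8 interpretation has enough elements for a 32x32x3 reshape.
uint8_ok = len(as_uint8) == expected_values
float32_ok = len(as_float32) == expected_values
```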


Full code:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os


class Cifar(object):

    def __init__(self):
        # Image dimensions
        self.height = 32
        self.width = 32
        self.channels = 3

        # Number of bytes per image, per label, and per sample
        self.image = self.height * self.width * self.channels
        self.label = 1
        self.sample = self.image + self.label

    def read_binary(self):
        """
        Read the CIFAR-10 binary files
        :return: label_value, image_value
        """
        # 1. Build the filename queue
        filename_list = os.listdir("./cifar-10-batches-bin")
        print("file_name:\n", filename_list)
        # Build the list of file paths, keeping only the .bin files
        file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in filename_list if file[-3:] == "bin"]
        print("file_list:\n", file_list)
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read and decode
        reader = tf.FixedLengthRecordReader(self.sample)
        # key is the file name, value is one sample
        key, value = reader.read(file_queue)
        print("key:\n", key)
        print("value:\n", value)
        # Decode the raw bytes into uint8 values
        image_decoded = tf.decode_raw(value, tf.uint8)
        print("image_decoded:\n", image_decoded)

        # Slice the sample into the target (label) and the features (image)
        label = tf.slice(image_decoded, [0], [self.label])
        image = tf.slice(image_decoded, [self.label], [self.image])
        print("label:\n", label)
        print("image:\n", image)

        # Reshape the image; the file stores it as channels, height, width
        image_reshaped = tf.reshape(image, shape=[self.channels, self.height, self.width])
        print("image_reshaped:\n", image_reshaped)

        # Transpose so the axis order becomes height, width, channels
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 3. Batch the samples
        label_batch, image_batch = tf.train.batch([label, image_transposed], batch_size=100, num_threads=2, capacity=100)
        print("label_batch:\n", label_batch)
        print("image_batch:\n", image_batch)

        # Start a session
        with tf.Session() as sess:
            # Start the queue-runner threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            label_value, image_value = sess.run([label_batch, image_batch])

            print("label_value:\n", label_value)
            print("image_value:\n", image_value)

            # Stop and join the threads
            coord.request_stop()
            coord.join(threads)
        return label_value, image_value

    def write_to_tfrecords(self, image_batch, label_batch):
        """
        Write the sample features and targets together into a tfrecords file
        :param image_batch: batch of image arrays
        :param label_batch: batch of label arrays
        :return: None
        """
        with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # Build an Example for each sample, serialize it, and write it to the file
            for i in range(100):
                # tobytes() replaces the deprecated tostring()
                image = image_batch[i].tobytes()
                label = int(label_batch[i][0])
                # print("tfrecords_image:\n",image)
                # print("tfrecords_label:\n",label)
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # Write the serialized example to the file
                writer.write(example.SerializeToString())
        return None

    def read_tfrecords(self):
        """
        Read the TFRecords file
        :return: None
        """
        # 1. Build the filename queue
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])

        # 2. Read and decode
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # Parse the Example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)
        # Decode the raw bytes into uint8 values
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # Reshape the image to [height, width, channels]
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        print("image_reshaped:\n", image_reshaped)

        # 3. Build the batching queue
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)

        # Start a session
        with tf.Session() as sess:
            # Start the queue-runner threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            #image_value, label_value, image_decoded_value = sess.run([image, label, image_decoded])
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value:\n", image_value)
            print("label_value:\n", label_value)

            # Stop the threads and release resources
            coord.request_stop()
            coord.join(threads)

        return None


if __name__ == "__main__":
    # Instantiate Cifar
    cifar = Cifar()
    # Uncomment the next two lines on the first run to create cifar10.tfrecords
    # label_value,image_value=cifar.read_binary()
    # cifar.write_to_tfrecords(image_value,label_value)
    cifar.read_tfrecords()
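
The slice, reshape, and transpose in read_binary follow from how each CIFAR-10 binary record is laid out: one label byte, then the red, green, and blue planes one after another. The sketch below reproduces that indexing in plain Python on a made-up 2x2 record so the reordering is easy to trace by hand; the function name split_record and the sample byte values are hypothetical.

```python
def split_record(record, height, width, channels, label_bytes=1):
    """Split one CIFAR-style record into (label, image[height][width][channel])."""
    label = record[0]
    pixels = record[label_bytes:]
    plane = height * width  # number of values in one colour plane
    image = [
        [[pixels[c * plane + h * width + w] for c in range(channels)]
         for w in range(width)]
        for h in range(height)
    ]
    return label, image


# Hypothetical 2x2 RGB record: label 5, then the R plane, G plane, and B plane.
record = bytes([5, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33])
label, image = split_record(record, height=2, width=2, channels=3)
```

Here image[0][0] collects the first value of each plane, which is the same regrouping that tf.transpose(image_reshaped, [1, 2, 0]) performs on the full 32x32x3 image.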