补充:TFRECORD文件学习


import os
import random
import sys

import tensorflow as tf


# Number of TFRecord files to generate
_NUM_BLOCK = 2
# Directory containing the source images
DATASET_DIR = "./Images/SourceImgs/"
# Output directory for the generated TFRecord files (also the readers' input)
GOALSET_DIR = "./Images/TFImgs/"
# Output directory for images written back out by the TFRecordDataset reader
GOALTFSET_DIR = "./Images/TFSourceImgs/"
# Output directory for images written back out by the file-queue reader
QUEUE_DIR = "./Images/QueueImgs/"


一:生成TFRECORD文件

(一)获取图片信息


# Collect the image file paths
def get_file_info(dataSet_dir=DATASET_DIR):
    """Return the paths of the files found directly inside ``dataSet_dir``.

    Args:
        dataSet_dir: directory to scan; defaults to DATASET_DIR.

    Returns:
        A list of ``os.path.join(dataSet_dir, name)`` paths, sorted by file
        name so that splitting them into TFRecord blocks is reproducible
        between runs. Sub-directories are skipped because they cannot be
        read as image files.
    """
    # os.listdir returns entries in arbitrary order; sort for determinism.
    return [
        os.path.join(dataSet_dir, name)
        for name in sorted(os.listdir(dataSet_dir))
        if os.path.isfile(os.path.join(dataSet_dir, name))
    ]


(二)写入TFRECORD文件


def _format_record(image_data, label):
    """Build a tf.train.Example holding one image and its label.

    The feature keys and types ('label' and 'data', both bytes) mirror the
    parse spec ``tf.FixedLenFeature((), tf.string)`` used when reading back.

    Args:
        image_data: raw encoded image bytes as read from disk.
        label: label string (or bytes) derived from the file name.

    Returns:
        A ``tf.train.Example`` proto.
    """
    if isinstance(label, str):
        label = label.encode("utf-8")
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
        'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
    }))


# No tf.Session is needed here: TFRecordWriter and gfile are plain Python I/O.
files_list = get_file_info()
total = len(files_list)
# Ceiling division: floor division would silently drop the last
# `total % _NUM_BLOCK` files (or all of them when total < _NUM_BLOCK).
num_per_block = -(-total // _NUM_BLOCK)
for _id in range(_NUM_BLOCK):
    tfr_name = "image_%d.tfrecord" % (_id + 1)
    tfr_dir = os.path.join(GOALSET_DIR, tfr_name)
    with tf.python_io.TFRecordWriter(tfr_dir) as writer:
        start_idx = _id * num_per_block
        end_idx = min((_id + 1) * num_per_block, total)

        for i in range(start_idx, end_idx):
            sys.stdout.write("\r>>Converting images %d/%d to block %d" % (i + 1, total, _id + 1))
            sys.stdout.flush()
            try:
                # Read the raw image bytes; `with` guarantees the handle is closed.
                with tf.gfile.GFile(files_list[i], 'rb') as f:
                    image_data = f.read()
            except (IOError, tf.errors.OpError) as e:
                # gfile raises tf.errors.* (not IOError) for missing/unreadable files.
                print("Could not read:", files_list[i])
                print("Error", e)
                print("Skip it\n")
                continue
            # The label is the file name without its extension.
            label = os.path.splitext(os.path.basename(files_list[i]))[0]

            example = _format_record(image_data, label)

            writer.write(example.SerializeToString())
sys.stdout.write("\n")
sys.stdout.flush()


Tensorflow踩坑系列---TFRECORD文件读写_获取图片

二:使用 tf.data.TFRecordDataset 直接读取TFRECORD文件

(一)解析文件


def _parse_record(example_proto):
    """Decode one serialized tf.train.Example into its 'label' and 'data' features.

    Both features are parsed as scalar strings
    (``tf.FixedLenFeature((), tf.string)``).

    Args:
        example_proto: a serialized Example, e.g. an element yielded by
            ``tf.data.TFRecordDataset``.

    Returns:
        A dict mapping 'label' and 'data' to string tensors.
    """
    spec = {key: tf.FixedLenFeature((), tf.string) for key in ('label', 'data')}
    return tf.parse_single_example(example_proto, features=spec)


(二)读取所有文件


with tf.Session() as sess:
    tf_files = [os.path.join(GOALSET_DIR, fn) for fn in os.listdir(GOALSET_DIR)]

    # TFRecordDataset can read all the TFRecord files in a single pass.
    dataSet = tf.data.TFRecordDataset(tf_files)
    dataSet = dataSet.map(_parse_record)  # decode each serialized record

    iterator = dataSet.make_one_shot_iterator()
    # Build the get_next op once: calling iterator.get_next() inside the loop
    # would add new ops to the graph on every iteration.
    next_element = iterator.get_next()

    sess.run(tf.local_variables_initializer())
    while True:
        try:
            Singledata = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            # Raised once the dataset is exhausted; any other error propagates.
            print("Read finish!!!")
            break
        label = Singledata['label'].decode()
        image_data = Singledata['data']
        with tf.gfile.GFile(os.path.join(GOALTFSET_DIR, "%s.jpg" % label), "wb") as f:
            f.write(image_data)


Tensorflow踩坑系列---TFRECORD文件读写_sed_02

三:使用文件队列读取多个tfrecord文件



tf_files = [os.path.join(GOALSET_DIR, fn) for fn in os.listdir(GOALSET_DIR)]

# string_input_producer builds a filename queue over all the TFRecord files,
# cycling through them num_epochs=3 times in shuffled order.
filename_queue = tf.train.string_input_producer(tf_files, shuffle=True, num_epochs=3)

# The reader pulls serialized records from the filename queue.
reader = tf.TFRecordReader()
key, value = reader.read(filename_queue)  # (record key, serialized record)
features = tf.parse_single_example(value, features={
    'label': tf.FixedLenFeature((), tf.string),
    'data': tf.FixedLenFeature((), tf.string),
})
img_data = features['data']
label = features['label']

image_batch, label_batch = tf.train.shuffle_batch(
    [img_data, label], batch_size=8, num_threads=2, allow_smaller_final_batch=True,
    capacity=500, min_after_dequeue=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize global variables
    # num_epochs is tracked in a local variable, so local init is required.
    sess.run(tf.local_variables_initializer())

    coord = tf.train.Coordinator()
    # The queues only start filling after start_queue_runners is called.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    j = 1
    try:
        while not coord.should_stop():
            images_data, labels_data = sess.run([image_batch, label_batch])
            for img, lbl in zip(images_data, labels_data):
                out_path = os.path.join(QUEUE_DIR, "%s-%d.jpg" % (lbl.decode(), j))
                with open(out_path, "wb") as f:
                    f.write(img)
                j += 1
    except tf.errors.OutOfRangeError:
        # Raised once all num_epochs have been consumed; other errors propagate.
        print("read all files")
    finally:
        coord.request_stop()  # ask the reader threads to stop
    coord.join(threads)  # wait for the reader threads to finish


Tensorflow踩坑系列---TFRECORD文件读写_获取图片_03