有了 jpg 读取和 tfrecord 写入的经验之后,开始尝试把 jpg 图像写入 tfrecord 文件,同时也尝试从 tfrecord 文件中读出 jpg 图像。
下面的示例把 jpg 图片的二进制数据以及高度(height)和宽度(width)信息保存进 tfrecord 文件:
# Write one JPEG image (raw encoded bytes plus its height and width) into a
# TFRecord file, using the TensorFlow 1.x graph/session API.
import tensorflow as tf

# Graph that decodes JPEG bytes fed through a placeholder; it is only used to
# discover the image dimensions, the record itself stores the encoded bytes.
decode_jpeg_data = tf.placeholder(dtype=tf.string)
decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)

tfrecords_filename = './tfrecords'

# Read the encoded image; the context manager guarantees the handle is closed.
with tf.gfile.GFile("C:/Users/shenwei/Desktop/timg.jpg", 'rb') as image_file:
    image_data = image_file.read()
print(type(image_data))

with tf.Session() as sess:
    image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
# The decoded array is indexed as (height, width, channels).
print(image.shape[0])
print(image.shape[1])

example = tf.train.Example(features=tf.train.Features(
    feature={
        'encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_data])),
        'height': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[image.shape[0]])),
        'width': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[image.shape[1]])),
    }))

# Open the writer only once the example is ready, so a failed read/decode does
# not leave an empty record file behind; `with` closes it even on error.
with tf.python_io.TFRecordWriter(tfrecords_filename) as writer:
    writer.write(example.SerializeToString())
********************下面是一段完整的保存以及读取的示例********************
import os

import tensorflow as tf

slim = tf.contrib.slim


def create_record_file(image_dir="C:/Users/shenwei/Desktop/test/",
                       train_filename="train.tfrecords",
                       num_images=10):
    """Create a TFRecord file from numbered JPEG images.

    Reads ``<image_dir>0.jpg`` .. ``<image_dir><num_images - 1>.jpg`` and
    writes one ``tf.train.Example`` per image containing the encoded bytes,
    the format tag, the width/height and the image index as its label.

    Args:
        image_dir: Path prefix (including the trailing separator) of the
            directory holding the numbered ``.jpg`` files.
        train_filename: Name of the TFRecord file to (re)create in the
            current working directory.
        num_images: How many images, numbered from 0, to write.
    """
    if os.path.exists(train_filename):
        os.remove(train_filename)

    # Build the decode op once and feed the bytes through a placeholder.
    # Calling tf.image.decode_jpeg(img_raw) inside the loop would add a new
    # op (and a constant holding the whole image) to the graph per image.
    jpeg_bytes = tf.placeholder(dtype=tf.string)
    decoded = tf.image.decode_jpeg(jpeg_bytes)

    # Both the writer and the session are closed even if an image is missing
    # or cannot be decoded.
    with tf.python_io.TFRecordWriter('./' + train_filename) as writer, \
            tf.Session() as sess:
        for i in range(num_images):
            with tf.gfile.GFile(image_dir + str(i) + ".jpg", 'rb') as f:
                img_raw = f.read()
            # Decoded array is indexed as (height, width, channels).
            image_shape = sess.run(
                decoded, feed_dict={jpeg_bytes: img_raw}).shape
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'image/encoded': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img_raw])),
                    'image/format': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[b'jpg'])),
                    'image/width': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[image_shape[1]])),
                    'image/height': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[image_shape[0]])),
                    'image/label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[i])),
                }))
            writer.write(example.SerializeToString())  # serialize and store
    print("保存tfrecord文件成功。")


def read_record_file(tfrecords_filename="train.tfrecords", num_samples=10):
    """Read images back from a TFRecord file using TF-Slim and print shapes.

    Args:
        tfrecords_filename: Path of the TFRecord file to read.
        num_samples: Number of examples stored in the file; also the number
            of examples fetched and printed.
    """
    # How tf parses each serialized tf.train.Example back into raw features.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
    }

    # How slim assembles the parsed features into higher-level items.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format',
                                              channels=3),
        'label': slim.tfexample_decoder.Tensor('image/label'),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                      items_to_handlers)

    # The dataset object bundles the file location, the reader, the decoder
    # and other metadata about the data set.
    dataset = slim.dataset.Dataset(
        data_sources=tfrecords_filename,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,  # total number of examples
        items_to_descriptions=None,
        num_classes=10,
    )

    # The provider reads examples according to the dataset description.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=1,
        common_queue_capacity=20,
        common_queue_min=1)

    [image, label, height, width] = provider.get(
        ['image', 'label', 'height', 'width'])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_samples):
                img, l, h, w = sess.run([image, label, height, width])
                # `img` is already a numpy array. To save it to disk, PIL can
                # be used, e.g. Image.fromarray(img, 'RGB').save(...).
                print(img.shape)
        finally:
            # Always stop and join the queue-runner threads, even when a
            # fetch fails, so the process can exit cleanly.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Run create_record_file() once beforehand to generate train.tfrecords.
    read_record_file()