以为自己会用的是pytorch架构,没有想到还是得先学会tensorflow啊!!!
本来打算先读FastFCN的源码的,但是因为刚出来可以学习的笔记太少,所以我还是选择了deeplabv3+源码学习!
大佬的github:https://codeload.github.com/rishizek/tensorflow-deeplab-v3-plus/zip/master
因为VOC2012的Augmented data一直下载不了,我选择用camvid进行训练,emmm,直接开始吧。(注意camvid的label是通过labelme进行图像标注,然后把得到的json转换为图片,最后把RGB图片再转换成单通道的灰度图片,数据集中的mask已经是单通道的灰度图片,故可以直接转为tfrecord数据格式)
1.TFRecord数据文件
Tensorflow拥有直接的数据输入的格式,所以在进行训练的第一步当然是如何把图片数据image和标签数据label转换成为TFrecored数据。下面介绍转换过程中常用的一些函数:
Img_raw=tf.gfile.FastGFile(dir,'rb').read() #dir具体到每张图片的地址,得到的图片类型是bytes,就不用转换类型了
# 变形记开始啦
Writer=tf.python_io.TFRecordWriter(output_dir) #output_dir是转换后文件地址,文件后缀可以直接是.tfrecord
# 将得到的图片数据转换成example protocol buffer
Example=tf.train.Example(features=tf.train.Features(feature={
'img_raw':_bytes_feature(Img_raw)
}))
Writer.write(Example.SerializeToString()) #将信息写入这个数据结构
Writer.close()
以下是deeplabv3+中的数据转换,create_pascal_tf_record.py文件解读:
def create_tf_record(output_filename,image_dir,label_dir,examples):
"""Creates a TFRecord file from examples.
Args:
output_filename: Path to where output file is saved.
image_dir: Directory where image files are stored.
label_dir: Directory where label files are stored.
examples: Examples to parse and save to tf record.
"""
# 创建一个类writer
writer = tf.python_io.TFRecordWriter(output_filename)
for idx, example in enumerate(examples):
if idx % 500 == 0:
tf.logging.info('On image %d of %d', idx, len(examples))
# 得到图片image和label的具体地址
image_path = os.path.join(image_dir, example + '.png')
label_path = os.path.join(label_dir, example + '.png')
if not os.path.exists(image_path):
tf.logging.warning('Could not find %s, ignoring example.', image_path)
continue
elif not os.path.exists(label_path):
tf.logging.warning('Could not find %s, ignoring example.', label_path)
continue
# 将两个地址都送入dict_to_tf_example中,得到example对象
try:
tf_example = dict_to_tf_example(image_path, label_path)
writer.write(tf_example.SerializeToString()) # 将对象写入文件地址中
except ValueError:
tf.logging.warning('Invalid example: %s, ignoring.', example)
writer.close()
# 将image和label都转换为TFrecord文件
def dict_to_tf_example(image_path, label_path):
#以rb读二进制的方式打开图片所在地址
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read() # 读取图片内容
encoded_jpg_io = io.BytesIO(encoded_jpg) # 创建一个临时二进制文件
image = PIL.Image.open(encoded_jpg_io) # 读取这个二进制文件,但是返回的不是numpy数据,但可以通过numpy.array进行数据转换
if image.format != 'png':
raise ValueError('Image format not PNG')
# 同上,对label数据进行转换。
with tf.gfile.GFile(label_path, 'rb') as fid:
encoded_label = fid.read()
encoded_label_io = io.BytesIO(encoded_label)
label = PIL.Image.open(encoded_label_io)
if label.format != 'PNG':
raise ValueError('Label format not PNG')
if image.size != label.size:
raise ValueError('The size of image does not match with that of label.')
# 获得图片的宽高(480,360)
width, height = image.size
# 创建Example对象,每一个example都有下面这些feature
# bytes_feature和int64_feature是TF数据的两种类型,即字符串型和int64
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('png'.encode('utf8')),
'label/encoded': dataset_util.bytes_feature(encoded_label),
'label/format': dataset_util.bytes_feature('png'.encode('utf8')),
}))
return example
2.train.py文件解读
def main(unused_argv):
# Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
if FLAGS.clean_model_dir:
shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
#1.创建RunConfig来更改checkpoint的时间
run_config=tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
#2.实例化Estimator,model_fn得到的是resnet-101的网络架构,model_dir是预训练好的模型参数,params参数会自动传送给deeplabv3_plus_model_fn
model=tf.estimator.Estimator(model_fn=deeplab_model.deeplabv3_plus_model_fn, model_dir=FLAGS.model_dir,
config=run_config, params={
'output_stride':FLAGS.output_stride,
'batch_size':FLAGS.batch_size,
'base_architecture':FLAGS.base_architecture,
'pre_trained_model':FLAGS.pre_trained_model,
'batch_norm_decay':_BATCH_NORM_DECAY,
'num_classes':_NUM_CLASSES,
'tensorboard_images_max_outputs':FLAGS.tensorboard_images_max_outputs,
'weight_decay':FLAGS.weight_decay,
'learning_rate_policy':FLAGS.learning_rate_policy,
'num_train':_NUM_IMAGES['train'],
'initial_learning_rate':FLAGS.initial_learning_rate,
'max_iter':FLAGS.max_iter,
'end_learning_rate':FLAGS.end_learning_rate,
'power':_POWER,
'momentum':_MOMENTUM,
'freeze_batch_norm':FLAGS.freeze_batch_norm,
'initial_global_step':FLAGS.initial_global_step
})
# 会被打印出来的内容包括学习率,交叉熵,像素点的准确度和mIOU
for_inrange(FLAGS.train_epochs//FLAGS.epochs_per_eval):
tensors_to_log={
'learning_rate':'learning_rate',
'cross_entropy':'cross_entropy',
'train_px_accuracy':'train_px_accuracy',
'train_mean_iou':'train_mean_iou',
}
# 设置每迭代10次就会打印一次
logging_hook=tf.train.LoggingTensorHook(
tensors=tensors_to_log,every_n_iter=10)
train_hooks=[logging_hook]
eval_hooks=None
# 调用TF的调试器
if FLAGS.debug:
debug_hook=tf_debug.LocalCLIDebugHook()
train_hooks.append(debug_hook)
eval_hooks=[debug_hook]
# 3.开始训练模型,函数input_fn作为数据输入的来源
tf.logging.info("Starttraining.")
model.train(input_fn=lambda:input_fn(True,FLAGS.data_dir,FLAGS.batch_size,FLAGS.epochs_per_eval),hooks=train_hooks,)
# 4.开始进行模型评估,函数input_fn作为数据输入来源
tf.logging.info("Startevaluation.")
eval_results=model.evaluate(input_fn=lambda:input_fn(False,FLAGS.data_dir,1),hooks=eval_hooks,)
print(eval_results)
3.input_fn()数据输入函数:
def input_fn(is_training,data_dir,batch_size,num_epochs=1):
# 1.切片处理,输入数据,在第一个维度内进行切片!!!!
dataset=tf.data.Dataset.from_tensor_slices(get_filenames(is_training,data_dir)) # 得到的是TFrecord的地址
# print(dataset.output_shapes)
#print(dataset.types) # 可以查看dataset里面的数据类型
dataset=dataset.flat_map(tf.data.TFRecordDataset) # 解析tfrecord文件的每一条记录,序列化后#tf.train.Example,解析函数是parse_record中的parse_single_example()
# 2.shuffle()操作:随机化处理
ifis_training:
dataset=dataset.shuffle(buffer_size=_NUM_IMAGES['train']) # 训练时,将输入数据的顺序进行打乱
dataset=dataset.map(parse_record) # 指定parse_record方法对数据进行改变,返回的是image和label
dataset=dataset.map(lambdaimage,label:preprocess_image(image,label,is_training))
dataset=dataset.prefetch(batch_size)
# repeat操作
dataset=dataset.repeat(num_epochs) # 指定重复的次数
dataset=dataset.batch(batch_size)
# 3.创建迭代器
iterator=dataset.make_one_shot_iterator()
images,labels=iterator.get_next()
Return images,labels
4.parse_record()函数
下面介绍得到TFrecord之后对该文件的解析,即dataset.map(parserecord)中的parse_record()函数
def parse_record(raw_record):
keys_to_features={
'image/height':
tf.FixedLenFeature((),tf.int64),
'image/width':
tf.FixedLenFeature((),tf.int64),
'image/encoded':
tf.FixedLenFeature((),tf.string,default_value=''),
'image/format':
tf.FixedLenFeature((),tf.string,default_value='png'),
'label/encoded':
tf.FixedLenFeature((),tf.string,default_value=''),
'label/format':
tf.FixedLenFeature((),tf.string,default_value='png'),
}
# 解析每一条记录
parsed=tf.parse_single_example(raw_record,keys_to_features) # 解析record的每条记录
#height=tf.cast(parsed['image/height'],tf.int32) # 进行类型转换!!!
#width=tf.cast(parsed['image/width'],tf.int32)
image=tf.image.decode_image(tf.reshape(parsed['image/encoded'],shape=[]) ,_DEPTH)
image=tf.to_float(tf.image.convert_image_dtype(image,dtype=tf.uint8))
image.set_shape([None,None,3])
label=tf.image.decode_image(
tf.reshape(parsed['label/encoded'],shape=[]),1)
label=tf.to_int32(tf.image.convert_image_dtype(label,dtype=tf.uint8))
label.set_shape([None,None,1])
return image,label
在input_fn函数中,解析之后得到的文件还有进过再处理,即preprocess_image()函数:
def preprocess_image(image,label,is_training):
"""Preprocessa single image of layout [height,width,depth]."""
ifis_training:
#Randomly scale thei mage and label.
image,label=preprocessing.random_rescale_image_and_label(image,label,_MIN_SCALE,_MAX_SCALE)
#Randomly crop or pad a [_HEIGHT,_WIDTH] section of the image and label.
image,label=preprocessing.random_crop_or_pad_image_and_label(
image,label,_HEIGHT,_WIDTH,_IGNORE_LABEL)
#Randomly flip the image and label horizontally.
image,label=preprocessing.random_flip_left_right_image_and_label(
image,label)
image.set_shape([_HEIGHT,_WIDTH,3])
label.set_shape([_HEIGHT,_WIDTH,1])
image=preprocessing.mean_image_subtraction(image)
return image,label
经过该函数可以知道,解析后的图片经过再次的处理之后图片的shape=[height,width,3],注意label.shape=[height,width,1]其中的1是它的通道数,因为label是单通道的灰度图,所以为1。
至此,数据的输入就讲完了,deeplabv3+的知识就讲完了,涉及到的argparse和tf.eatimator在onenote中进行过笔记处理,闲时再整理吧。