Author: Jyx
Description: AI study notes
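TensorFlow Object Detection API
The TensorFlow Object Detection API is not part of the main TensorFlow repository. It is developed in a separate project, located at https://github.com/tensorflow/models/tree/master/research/object_detection
Basic workflow of TensorFlow Object Detection
- Data preparation. Pack the data into a format the Object Detection API supports, usually tfrecord
- Configuration. Define the pipeline config file
- Model training. Run the training script
- Visualization
As the steps above show, the API hides most of the work: once the data is prepared and the model parameters are defined in the pipeline config, the framework can be used with little additional code.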
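###1. Data preparation
The object detection framework takes a tfrecord file as input, whereas most datasets ship the images and the annotations as separate files, so a format conversion step is needed.
A tfrecord is a binary file made of serialized data objects. Each object must be a tf.train.Example, whose members are tf.train.Feature values. Data preparation therefore consists of:
- Building a dictionary of the data and storing each entry as a Feature. The dictionary keys can follow the definitions in https://github.com/tensorflow/models/blob/master/research/object_detection/core/standard_fields.py
- Wrapping the features in a tf.train.Example
- Serializing the Example and writing it to a file with tf.python_io.TFRecordWriter
A simple implementation looks like this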
import tensorflow as tf  # TF 1.x API (tf.gfile, tf.python_io) is used below
import PIL.Image as im
import numpy as np
import os, io
import xmltodict
import matplotlib.pyplot as plt
import random
%matplotlib inline
dpath = '../../data'
image_dir = os.path.join(dpath, 'images')
annotation_dir = os.path.join(dpath, 'annotations', 'xmls')
tfrecord_dir = 'data'
# Make sure the output directory exists before the writer opens the file
os.makedirs(tfrecord_dir, exist_ok=True)
train_file_name = os.path.join(tfrecord_dir, 'train.record')
annotation_names = os.listdir(annotation_dir)
# For convenience, this example randomly picks only 10 files as input.
# random.sample draws without replacement, so no file is picked twice.
annotation_names = random.sample(annotation_names, k=min(10, len(annotation_names)))
image_names = [os.path.splitext(annotation_name)[0] + '.jpg' for annotation_name in annotation_names]
image_paths = [os.path.join(image_dir, image_name) for image_name in image_names]
annotation_paths = [os.path.join(annotation_dir, annotation_name) for annotation_name in annotation_names]
# Class name -> integer id (ids start at 1; 0 is reserved for the background)
labels_map = {
    'cat': 1,
    'dog': 2
}
def build_record(image_path, annotation_path):
    # Read the encoded image bytes; they are stored in the record as-is
    with tf.gfile.GFile(image_path, 'rb') as fimage:
        image_data = fimage.read()
    # Decode only to get the image size, needed to normalize the box coordinates
    with im.open(io.BytesIO(image_data)) as fim:
        width, height = fim.size
    with open(annotation_path, 'rb') as fxml:
        xml_data = xmltodict.parse(fxml.read())
    # Assumes exactly one <object> per annotation file
    obj = xml_data['annotation']['object']
    bndbox = obj['bndbox']
    # Assumes <name> holds one of the keys of labels_map ('cat' or 'dog')
    class_name = obj['name']
    # Build the feature dictionary
    feature_dict = {
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
        'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'jpeg'])),
        'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=[float(bndbox['xmin']) / width])),
        'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=[float(bndbox['ymin']) / height])),
        'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=[float(bndbox['xmax']) / width])),
        'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=[float(bndbox['ymax']) / height])),
        'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[class_name.encode('utf8')])),
        'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels_map[class_name]])),
    }
    # Pack the dictionary into a Features message
    features = tf.train.Features(feature=feature_dict)
    # Wrap it in an Example
    example = tf.train.Example(features=features)
    return example
with tf.python_io.TFRecordWriter(train_file_name) as writer:
    for image_path, annotation_path in zip(image_paths, annotation_paths):
        example = build_record(image_path, annotation_path)
        # Serialize the Example and append it to the tfrecord file
        writer.write(example.SerializeToString())
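The code above shows how to create a simple tfrecord file. In practice, you can use the conversion scripts that ship with the framework. They live under https://github.com/tensorflow/models/tree/master/research/object_detection/dataset_tools, and each module name includes the dataset it handles. You may need to adapt them to your own data.
A typical invocation looks roughly like this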
python object_detection/dataset_tools/create_pet_tf_record.py \
--label_map_path=object_detection/data/pet_label_map.pbtxt \
--data_dir=data_dir \
--output_dir=output_dir
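The `build_record` example above reads a single `bndbox` per image. With `xmltodict`, an `<object>` element that occurs several times is parsed into a list instead of a dict, so a more defensive version may be useful for other datasets. The following is a minimal sketch in plain Python; `normalized_boxes` is a hypothetical helper name, and the input is assumed to have the same shape as `xml_data['annotation']` above.

```python
def normalized_boxes(annotation, width, height):
    """Return per-coordinate lists of box corners scaled to [0, 1].

    `annotation` is assumed to look like xml_data['annotation'] as produced
    by xmltodict: 'object' is a dict for one object and a list of dicts
    for several objects.
    """
    objects = annotation.get('object', [])
    if isinstance(objects, dict):  # a single <object> element
        objects = [objects]

    boxes = {'xmin': [], 'ymin': [], 'xmax': [], 'ymax': []}
    for obj in objects:
        bndbox = obj['bndbox']
        # x coordinates are divided by the width, y coordinates by the height
        boxes['xmin'].append(float(bndbox['xmin']) / width)
        boxes['xmax'].append(float(bndbox['xmax']) / width)
        boxes['ymin'].append(float(bndbox['ymin']) / height)
        boxes['ymax'].append(float(bndbox['ymax']) / height)
    return boxes


# Hand-written stand-in for a parsed annotation with two objects.
sample = {
    'object': [
        {'name': 'cat',
         'bndbox': {'xmin': '50', 'ymin': '20', 'xmax': '150', 'ymax': '120'}},
        {'name': 'dog',
         'bndbox': {'xmin': '0', 'ymin': '100', 'xmax': '100', 'ymax': '200'}},
    ]
}
boxes = normalized_boxes(sample, width=200, height=200)
print(boxes['xmin'])  # [0.25, 0.0]
```

Each of the four lists can then be passed as the `value` of the matching `tf.train.FloatList`, so that one Example holds all boxes of an image.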
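###2. Configuration
In the object detection framework, users do not have to build the model themselves. User-defined models are supported too, but they are outside the scope of this article.
The config file has five main parts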
model {
(... Add model config here...)
}
train_config : {
(... Add train_config here...)
}
train_input_reader: {
(... Add train_input configuration here...)
}
eval_config: {
}
eval_input_reader: {
(... Add eval_input configuration here...)
}
model:
The options in model depend on the chosen model. The available models are in the models folder: https://github.com/tensorflow/models/tree/master/research/object_detection/models
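train_config:
- Model parameter initialization
- Input preprocessing
- SGD optimizer parameters
train_input_reader:
There are two main parameters
Input file (nested inside tf_record_input_reader in the config file) input_path: "/usr/home/username/data/train.record"
Label map file label_map_path: "/usr/home/username/data/label_map.pbtxt"
eval_config:
Evaluation parameters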
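eval_input_reader:
There are two main parameters. The input file should be a separate evaluation record rather than the training record
Input file input_path: "/usr/home/username/data/eval.record"
Label map file label_map_path: "/usr/home/username/data/label_map.pbtxt"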
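The file behind `label_map_path` is a text-format protobuf that maps each class name to an integer id, with ids starting at 1. As a minimal sketch, it can be generated from a dictionary such as the `labels_map` used in the data preparation step. `format_label_map` and `write_label_map` are hypothetical helpers written for this note, not functions of the framework.

```python
import os
import tempfile


def format_label_map(labels_map):
    """Render a {name: id} dict as label map text, one `item` per class."""
    items = []
    # Sort by id so the file is stable and easy to read
    for name, class_id in sorted(labels_map.items(), key=lambda kv: kv[1]):
        if class_id < 1:
            raise ValueError('ids must start at 1; 0 is reserved for background')
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return '\n'.join(items)


def write_label_map(labels_map, path):
    """Write the rendered label map to `path`."""
    with open(path, 'w') as f:
        f.write(format_label_map(labels_map))


labels_map = {'cat': 1, 'dog': 2}
label_map_text = format_label_map(labels_map)
print(label_map_text)

# Write to a temporary directory here; in a real project the target would be
# the path given as label_map_path in the pipeline config.
out_path = os.path.join(tempfile.mkdtemp(), 'label_map.pbtxt')
write_label_map(labels_map, out_path)
```

The same file has to be referenced by both `train_input_reader` and `eval_input_reader`, and its ids have to match the class labels stored in the tfrecord files.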
###3. Model training:
Training mainly consists of calling model_main.py with a few input arguments
PIPELINE_CONFIG_PATH=ssd_mobilenet_v1_pets.config
MODEL_DIR=./model
NUM_TRAIN_STEPS=20003
NUM_EVAL_STEPS=2000
python -m object_detection.model_main \
--pipeline_config_path=${PIPELINE_CONFIG_PATH} \
--model_dir=${MODEL_DIR} \
--num_train_steps=${NUM_TRAIN_STEPS} \
--num_eval_steps=${NUM_EVAL_STEPS} \
--alsologtostderr
###4. Visualization
You can use the visualize_boxes_and_labels_on_image_array function from https://github.com/tensorflow/models/blob/master/research/object_detection/utils/visualization_utils.py.
# Image, load_image_into_numpy_array, run_inference_for_single_image,
# detection_graph, category_index, vis_util and IMAGE_SIZE are assumed to be
# defined earlier, as in the Object Detection API tutorial notebook.
image = Image.open(image_path)
# the array based representation of the image will be used later in order to prepare the
# result image with boxes and labels on it.
image_np = load_image_into_numpy_array(image)
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
# (not used below: image_np is passed to the inference helper directly)
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
output_dict = run_inference_for_single_image(image_np, detection_graph)
# Visualization of the results of a detection (draws on image_np in place).
vis_util.visualize_boxes_and_labels_on_image_array(
    image_np,
    output_dict['detection_boxes'],
    output_dict['detection_classes'],
    output_dict['detection_scores'],
    category_index,
    instance_masks=output_dict.get('detection_masks'),
    use_normalized_coordinates=True,
    line_thickness=8)
# Save the annotated image, then show it inline
plt.imsave(os.path.join('./data/output', os.path.basename(image_path)), image_np)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)
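With `use_normalized_coordinates=True`, each detection box is read as `[ymin, xmin, ymax, xmax]` in fractions of the image size, and boxes whose score does not exceed `min_score_thresh` (0.5 by default, as far as I know) are not drawn. The sketch below illustrates that filtering and scaling in plain Python. It is not the library's implementation, and `boxes_to_pixels` is a hypothetical helper.

```python
def boxes_to_pixels(boxes, scores, width, height, min_score_thresh=0.5):
    """Keep confident boxes and convert them to pixel coordinates.

    boxes:  list of [ymin, xmin, ymax, xmax], each value in [0, 1]
    scores: list of confidence scores, same length as boxes
    Returns a list of (left, top, right, bottom) tuples in pixels.
    """
    kept = []
    for (ymin, xmin, ymax, xmax), score in zip(boxes, scores):
        if score <= min_score_thresh:
            continue  # low-confidence detections are skipped
        kept.append((round(xmin * width), round(ymin * height),
                     round(xmax * width), round(ymax * height)))
    return kept


# Two made-up detections on a 640x480 image; only the first is confident.
boxes = [[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]]
scores = [0.9, 0.3]
pixel_boxes = boxes_to_pixels(boxes, scores, width=640, height=480)
print(pixel_boxes)  # [(128, 48, 384, 240)]
```

Lowering `min_score_thresh` in the call to `visualize_boxes_and_labels_on_image_array` is a quick way to check whether a model that appears to detect nothing is only producing low-confidence boxes.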