利用TensorFlow Object Detection API的预训练模型训练自己的数据

文章目录

  • 利用TensorFlow Object Detection API的预训练模型训练自己的数据
  • 1.前言介绍
  • 2.前期准备
  • 2.1环境搭建
  • 2.2数据准备
  • 2.3模型准备
  • 3.训练过程
  • 3.1修改配置文件(config文件)
  • 3.2开始训练
  • 3.3保存模型
  • 3.4Tensorboard实时查看训练效果
  • 4.测试结果


1.前言介绍

  • pb文件为训练好的模型,可以直接拿来使用
  • ckpt文件就是预训练模型,用来训练自己的数据

2.前期准备

  • 准备一个保存收集图片的文件夹,包含Image和Annotations,分别用来保存图片和标注后的xml文件
  • 另外准备一个文件夹放训练有关的数据,里面包含三个下属文件data,export,model,分别用来存放训练可用的数据,生成的最终模型,训练产生的文件

2.1环境搭建

  • 配置Tensorflow环境,Windows或Ubuntu都可

2.2数据准备

  1. 根据自己训练需要收集所需要的图片
  2. 将所收集的图片进行排序后进行筛选然后再排序
    如果是处理自己采集的数据集,一定要先排序再筛选!!否则可能会遗漏掉一些本该筛选的图片在标注时增加自己的工作量
    我用的方法是按顺序对所有文件进行重命名
import os
i = 1
for filename in os.listdir('D:/DataCollection/hand_data/Image/test/'):
	newname = str(i) + '.jpg'
	print(newname)
	os.rename('D:/DataCollection/hand_data/Image/test/'+filename, 'D:/DataCollection/hand_data/Image/test/'+newname)
	i += 1
  1. 对排序后的图片进行标注
    标注图片用的软件是labelImg,可以选择标注的图片位置(Image),以及生成的xml文件保存的位置即Annotations文件夹

W是标注 D是下一张 A是上一张 空格保存

  1. 格式转换
    这里的生成的csv以及tfrecord文件都放在data文件夹下
    图片需转换成tensorflow可以识别的格式
  • 先由xml转为csv
"""
将文件夹内所有XML文件的信息记录到CSV文件中
"""

import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
 
os.chdir('E:/tensorflow/hand_data_new/hand_data/test')  #xml文件保存路径 使用时需改为自己的路径
path = 'E:/tensorflow/hand_data_new/hand_data/test'


def xml_to_csv(path):
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        print('test')
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df


def main():
    image_path = path
    xml_df = xml_to_csv(image_path)
    xml_df.to_csv('E:/tensorflow/hand_set/data/eval.csv', index=None)  #得到的csv文件保存路径
    print('Successfully converted xml to csv.')

main()
  • 然后将csv文件转为tfrecord
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict

flags = tf.app.flags

flags.DEFINE_string('csv_input', 'E:/tensorflow/hand_set/data/eval.csv', 'Path to the CSV input')#csv文件
flags.DEFINE_string('output_path', 'E:/tensorflow/hand_set/data/eval.record', 'Path to output TFRecord')#TFRecord文件
flags.DEFINE_string('image_dir', 'E:/tensorflow/hand_data_new/hand_data/Image/TEST', 'Path to images')#对应的图片位置

FLAGS = flags.FLAGS

# TO-DO replace this with label map
#从1开始根据自己训练的类别数和标签来写
def class_text_to_int(row_label):
    if row_label == 'DOWN':
        return 1
    elif row_label == 'FIVE':
        return 2
    else:
        None

def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:

        encoded_jpg = fid.read()

    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={

        'image/height': dataset_util.int64_feature(height),

        'image/width': dataset_util.int64_feature(width),

        'image/filename': dataset_util.bytes_feature(filename),

        'image/source_id': dataset_util.bytes_feature(filename),

        'image/encoded': dataset_util.bytes_feature(encoded_jpg),

        'image/format': dataset_util.bytes_feature(image_format),

        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),

        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),

        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),

        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),

        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),

        'image/object/class/label': dataset_util.int64_list_feature(classes),

    }))

    return tf_example


def main(_):

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

    path = os.path.join(FLAGS.image_dir)

    examples = pd.read_csv(FLAGS.csv_input)

    grouped = split(examples, 'filename')

    for group in grouped:

        tf_example = create_tf_example(group, path)

        writer.write(tf_example.SerializeToString())


    writer.close()

    output_path = os.path.join(os.getcwd(), FLAGS.output_path)

    print('Successfully created the TFRecords: {}'.format(output_path))



if __name__ == '__main__':

    tf.app.run()
  1. 训练数据准备完以后还需要准备一个pbtxt文件
    例如hand.pbtxt,放在data文件夹里
    内容如下,根据自己的类别数而定
item {
  id: 1
  name: 'DOWN'
}

item{
  id: 2
  name: 'FIVE'
}

2.3模型准备

下载Tensorflow模型

下载地址:https://github.com/tensorflow/models

下载protoc

下载地址:https://github.com/protocolbuffers/protobuf/releases

下载预训练模型

下载地址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

  1. 利用protoc编译TensorFlow Object Detection API,转换为py文件
  2. 建立一个专门的文件夹来保存预训练模型,记住下载路径,之后会用到里面的ckpt文件
  3. 在下载的Tensorflow模型的文件下找到models\research\object_detection\samples\configs,在里面找到自己所用的预训练模型对应的config文件,拷贝一份放在最初建立的model文件夹下

3.训练过程

3.1修改配置文件(config文件)

以下仅仅是列出最主要的修改,其他有关训练配置可根据实际情况再做调整

  • 改成自己训练的类别数量
num_classes: 4
  • 根据自己的机器性能适当修改也可以不改
batch_size: 24
  • 改成所用的预训练模型路径
fine_tune_checkpoint: "E:/tensorflow/pretrained_models/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
  • 训练所需的tfrecord文件路径,测试的则改为测试集的路径
tf_record_input_reader {
    input_path: "E:/tensorflow/hand_set/data/train.record"
  }
  • 标签映射文件,即pbtxt文件位置,训练与测试共用一个
label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"

3.2开始训练

执行语句

python E:/tensorflow/models/research/object_detection/legacy/train.py --train_dir=E:/tensorflow/hand_set/model/model5 --pipeline_config_path=E:/tensorflow/hand_set/model/ssd_mobilenet_v2_quantized_300x300_coco.config --logtostderr
  • train.py在下载的Tensorflow模型文件夹下
  • train_dir是训练时的数据保存位置,放在最初建立的model文件夹下,因为我训练了多个模型,因此我在model文件夹下建立了多个子文件夹命名为model1等等,例如此例子我将该模型保存在model/model5中
  • pipeline是config文件的位置,我放在model文件下

3.3保存模型

最初建立的三个文件夹data是用来存放数据集的,而model是训练时的数据,主要包括各个检查点对应的能够生成模型的ckpt文件,以及训练过程中的信息,而export就是保存我们导出的模型

执行语句

python E:/tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path E:/tensorflow/hand_set/model/ssd_mobilenet_v2_coco.config  --trained_checkpoint_prefix E:/tensorflow/hand_set/model/model4/model.ckpt-5997  --output_directory E:/tensorflow/hand_set/export/model4/
  • export_inference_graph.py在下载的Tensorflow模型文件夹下
  • pipeline_config_path位置同上
  • trained_checkpoint_prefix选择效果最好的检查点来生成模型,一般选择最新的
  • output_directory模型保存路径
  • 最后生成的pb文件就是我们可以用的模型

3.4Tensorboard实时查看训练效果

win+r,输入cmd,执行语句,输入刚刚训练保存模型的绝对路径,记得输入绝对路径不容易出错

tensorboard --logdir=E:\tensorflow\hand_set\model\model5

然后在浏览器里输入https://localhost:6006 就可以查看训练效果了

4.测试结果

  • 在Tensorflow模型文件夹下tensorflow\models\research\object_detection 找到object_detection_tutorial.ipynb文件,将代码复制出来
  • 将模型修改为我们自己训练的模型地址,即pb文件的地址
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH =
'E:/tensorflow/hand_set/export/model4/frozen_inference_graph.pb'
  • pbtxt文件地址也改为我们自己的文件地址
PATH_TO_LABELS = os.path.join('E:/tensorflow/hand_set/data', 'object_detection.pbtxt')
  • 设置测试图片路径
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
  • 也可以改为摄像头实时测试
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        while  True:
            ret, image = capture.read()
            if ret is True:
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = detection_graph.get_tensor_by_name('detection_scores:0')
                classes = detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name('num_detections:0')

                (boxes,scores,classes,num_detections)=sess.run([boxes, scores, classes, num_detections],
                                                                feed_dict={image_tensor: image_np_expanded})
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    min_score_thresh=0.6, #置信度
                    use_normalized_coordinates=True,
                    line_thickness=4
                )
                c = cv.waitKey(5)
                if c == 27:  # ESC
                    break
                cv.imshow("Demo", image)
            else:
                break
        cv.waitKey(0)
        cv.destoryAllWindows()
  • 运行代码