利用TensorFlow Object Detection API的预训练模型训练自己的数据
文章目录
- 利用TensorFlow Object Detection API的预训练模型训练自己的数据
- 1.前言介绍
- 2.前期准备
- 2.1环境搭建
- 2.2数据准备
- 2.3模型准备
- 3.训练过程
- 3.1修改配置文件(config文件)
- 3.2开始训练
- 3.3保存模型
- 3.4Tensorboard实时查看训练效果
- 4.测试结果
1.前言介绍
- pb文件为训练好的模型,可以直接拿来使用
- ckpt文件就是预训练模型,用来训练自己的数据
2.前期准备
- 准备一个保存收集图片的文件夹,包含Image和Annotations,分别用来保存图片和标注后的xml文件
- 另外准备一个文件夹放训练有关的数据,里面包含三个下属文件data,export,model,分别用来存放训练可用的数据,生成的最终模型,训练产生的文件
2.1环境搭建
- 配置Tensorflow环境,Windows或Ubuntu都可
2.2数据准备
- 根据自己训练需要收集所需要的图片
- 将所收集的图片进行排序后进行筛选然后再排序
如果是处理自己采集的数据集,一定要先排序再筛选!!否则可能会遗漏掉一些本该筛选的图片在标注时增加自己的工作量
我用的方法是按顺序对所有文件进行重命名
import os
i = 1
for filename in os.listdir('D:/DataCollection/hand_data/Image/test/'):
newname = str(i) + '.jpg'
print(newname)
os.rename('D:/DataCollection/hand_data/Image/test/'+filename, 'D:/DataCollection/hand_data/Image/test/'+newname)
i += 1
- 对排序后的图片进行标注
标注图片用的软件是labelImg,可以选择标注的图片位置(Image),以及生成的xml文件保存的位置即Annotations文件夹
W是标注 D是下一张 A是上一张 空格保存
- 格式转换
这里的生成的csv以及tfrecord文件都放在data文件夹下
图片需转换成tensorflow可以识别的格式
- 先由xml转为csv
"""
将文件夹内所有XML文件的信息记录到CSV文件中
"""
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
os.chdir('E:/tensorflow/hand_data_new/hand_data/test') #xml文件保存路径 使用时需改为自己的路径
path = 'E:/tensorflow/hand_data_new/hand_data/test'
def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
print('test')
for member in root.findall('object'):
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
xml_list.append(value)
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
xml_df = pd.DataFrame(xml_list, columns=column_name)
return xml_df
def main():
image_path = path
xml_df = xml_to_csv(image_path)
xml_df.to_csv('E:/tensorflow/hand_set/data/eval.csv', index=None) #得到的csv文件保存路径
print('Successfully converted xml to csv.')
main()
- 然后将csv文件转为tfrecord
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
flags = tf.app.flags
flags.DEFINE_string('csv_input', 'E:/tensorflow/hand_set/data/eval.csv', 'Path to the CSV input')#csv文件
flags.DEFINE_string('output_path', 'E:/tensorflow/hand_set/data/eval.record', 'Path to output TFRecord')#TFRecord文件
flags.DEFINE_string('image_dir', 'E:/tensorflow/hand_data_new/hand_data/Image/TEST', 'Path to images')#对应的图片位置
FLAGS = flags.FLAGS
# TO-DO replace this with label map
#从1开始根据自己训练的类别数和标签来写
def class_text_to_int(row_label):
if row_label == 'DOWN':
return 1
elif row_label == 'FIVE':
return 2
else:
None
def split(df, group):
data = namedtuple('data', ['filename', 'object'])
gb = df.groupby(group)
return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
def create_tf_example(group, path):
with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
width, height = image.size
filename = group.filename.encode('utf8')
image_format = b'jpg'
xmins = []
xmaxs = []
ymins = []
ymaxs = []
classes_text = []
classes = []
for index, row in group.object.iterrows():
xmins.append(row['xmin'] / width)
xmaxs.append(row['xmax'] / width)
ymins.append(row['ymin'] / height)
ymaxs.append(row['ymax'] / height)
classes_text.append(row['class'].encode('utf8'))
classes.append(class_text_to_int(row['class']))
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(filename),
'image/source_id': dataset_util.bytes_feature(filename),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature(image_format),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
}))
return tf_example
def main(_):
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
path = os.path.join(FLAGS.image_dir)
examples = pd.read_csv(FLAGS.csv_input)
grouped = split(examples, 'filename')
for group in grouped:
tf_example = create_tf_example(group, path)
writer.write(tf_example.SerializeToString())
writer.close()
output_path = os.path.join(os.getcwd(), FLAGS.output_path)
print('Successfully created the TFRecords: {}'.format(output_path))
if __name__ == '__main__':
tf.app.run()
- 训练数据准备完以后还需要准备一个pbtxt文件
例如hand.pbtxt,放在data文件夹里
内容如下,根据自己的类别数而定
item {
id: 1
name: 'DOWN'
}
item{
id: 2
name: 'FIVE'
}
2.3模型准备
下载Tensorflow模型
下载protoc
下载预训练模型
- 利用protoc编译TensorFlow Object Detection API,转换为py文件
- 建立一个专门的文件夹来保存预训练模型,记住下载路径,之后会用到里面的ckpt文件
- 在下载的Tensorflow模型的文件下找到models\research\object_detection\samples\configs,在里面找到自己所用的预训练模型对应的config文件,拷贝一份放在最初建立的model文件夹下
3.训练过程
3.1修改配置文件(config文件)
以下仅仅是列出最主要的修改,其他有关训练配置可根据实际情况再做调整
- 改成自己训练的类别数量
num_classes: 4
- 根据自己的机器性能适当修改也可以不改
batch_size: 24
- 改成所用的预训练模型路径
fine_tune_checkpoint: "E:/tensorflow/pretrained_models/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
- 训练所需的tfrecord文件路径,测试的则改为测试集的路径
tf_record_input_reader {
input_path: "E:/tensorflow/hand_set/data/train.record"
}
- 标签映射文件,即pbtxt文件位置,训练与测试共用一个
label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
3.2开始训练
执行语句
python E:/tensorflow/models/research/object_detection/legacy/train.py --train_dir=E:/tensorflow/hand_set/model/model5 --pipeline_config_path=E:/tensorflow/hand_set/model/ssd_mobilenet_v2_quantized_300x300_coco.config --logtostderr
- train.py在下载的Tensorflow模型文件夹下
- train_dir是训练时的数据保存位置,放在最初建立的model文件夹下,因为我训练了多个模型,因此我在model文件夹下建立了多个子文件夹命名为model1等等,例如此例子我将该模型保存在model/model5中
- pipeline是config文件的位置,我放在model文件下
3.3保存模型
最初建立的三个文件夹data是用来存放数据集的,而model是训练时的数据,主要包括各个检查点对应的能够生成模型的ckpt文件,以及训练过程中的信息,而export就是保存我们导出的模型
执行语句
python E:/tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path E:/tensorflow/hand_set/model/ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix E:/tensorflow/hand_set/model/model4/model.ckpt-5997 --output_directory E:/tensorflow/hand_set/export/model4/
- export_inference_graph.py在下载的Tensorflow模型文件夹下
- pipeline_config_path位置同上
- trained_checkpoint_prefix选择效果最好的检查点来生成模型,一般选择最新的
- output_directory模型保存路径
- 最后生成的pb文件就是我们可以用的模型
3.4Tensorboard实时查看训练效果
win+r,输入cmd,执行语句,输入刚刚训练保存模型的绝对路径,记得输入绝对路径不容易出错
tensorboard --logdir=E:\tensorflow\hand_set\model\model5
然后在浏览器里输入https://localhost:6006 就可以查看训练效果了
4.测试结果
- 在Tensorflow模型文件夹下tensorflow\models\research\object_detection 找到object_detection_tutorial.ipynb文件,将代码复制出来
- 将模型修改为我们自己训练的模型地址,即pb文件的地址
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH =
'E:/tensorflow/hand_set/export/model4/frozen_inference_graph.pb'
- pbtxt文件地址也改为我们自己的文件地址
PATH_TO_LABELS = os.path.join('E:/tensorflow/hand_set/data', 'object_detection.pbtxt')
- 设置测试图片路径
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
- 也可以改为摄像头实时测试
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
while True:
ret, image = capture.read()
if ret is True:
image_np_expanded = np.expand_dims(image, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes,scores,classes,num_detections)=sess.run([boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
vis_util.visualize_boxes_and_labels_on_image_array(
image,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
min_score_thresh=0.6, #置信度
use_normalized_coordinates=True,
line_thickness=4
)
c = cv.waitKey(5)
if c == 27: # ESC
break
cv.imshow("Demo", image)
else:
break
cv.waitKey(0)
cv.destoryAllWindows()
- 运行代码