前言
项目地址:https://github.com/zzubqh/CenterNetDect-pytorch.git
之前写过一篇CenterNet源码结构解析,能看到官方的源码结构有点复杂,虽然已经剖析过了但是真到实际应用中查找细节的时候还是有点绕,虽然文章中也列出了另外一个简单实现,但是个人感觉结构依然不是太清晰,所以基于这个简单实现的源码进行了重构,主要修改的地方:
- 将数据增强部分封装成了一个类,基于imgaug包实现;
- 去掉了基于coco数据集的依赖,改成解析自己的数据集,数据集的格式变得很简单,格式如下:
“文件路径 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id"
F:\code\Data\DectDataset\case_006/coronal/245.jpg 762,112,850,195,3 756,292,843,366,3
F:\code\Data\DectDataset\case_024/sagittal/69.jpg 736,139,817,212,3
F:\code\Data\DectDataset\case_024/sagittal/67.jpg 737,136,822,213,3
F:\code\Data\DectDataset\case_012/sagittal/179.jpg 834,95,912,161,3
- 重写训练代码,改成了一个类,不再将训练函数和验证函数放在一起而是封装到了一个类中
- 重写了验证函数,使得验证后给出的结果更加直观明了
- 去掉了pose预测的代码,只专注于目标检测
- 修正了原项目中的一些错误,比如在hourglass.py源文件中有句代码“if self.training or ind == self.nstack - 1:”,但是整个类是没有self.training这个属性的,还有在get_hourglass中,源作者没有将类的个数即num_classess这个变量传入,所以整个工程默认80个类,即便在train.py中给了num_classess的值在构造hourglass实例的时候依然是80个类
- 基本上除了保留了核心的hourglass.py网络结构以外全部进行了重写,但依然很感谢https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45作者的辛苦,表示敬意!
训练自己的数据
网络架构配置在config.py的cfg.arch =‘large_hourglass’ 中,名称只支持large_hourglass和small_hourglass,在large_hourglass模式下,显存至少8G以上。
数据准备
使用你自己熟悉的打标工具进行标注,如果使用的是labelme 4.5.6版本标注的,可以使用以下代码转成上面的数据格式
import os
import json
def create_image_dataset():
"""
输出label.txt格式: image_path x1,y1,x2,y2,class_id1 x1,y1,x2,y2,class_id2
"""
annotation_file = r'bone_annotation.md' # 输出的label文件
input_root_dir = r'F:\code\Data\DectDataset'
jpeg_dir = [os.path.join(input_root_dir, 'img_dir']
with open(annotation_file, 'w', encoding='utf-8') as wf:
for child_dir in jpeg_dir:
json_files = glob.glob(child_dir + '/*.json')
print('load {0} json files success!'.format(child_dir))
json_values = [json_pares(json_file) for json_file in json_files]
labeled_img = {item['img_name']: {'bbox': item['bbox'], 'label': item['label']} for item in json_values}
img_files = glob.glob(child_dir + '/*.jpg')
for img_path in tqdm.tqdm(img_files):
img_name = os.path.basename(img_path)
bbox = []
label = []
key_name = img_name
if key_name in labeled_img.keys():
bbox = labeled_img[key_name]['bbox']
label = labeled_img[key_name]['label']
line_str = img_path
for box_index, rect in enumerate(bbox):
item_str = ' {0},{1},{2},{3},{4}'.format(rect[0, 0], rect[0, 1], rect[1, 0], rect[1, 1], label[box_index])
line_str += item_str
wf.write(line_str + '\r')
def json_pares(json_file):
value_data = dict()
with open(json_file, 'r', encoding='utf-8') as rf:
json_data = json.load(rf)
value_data.setdefault('img_name', json_data['imagePath'].replace('png', 'jpg'))
value_data.setdefault('bbox', [])
value_data.setdefault('label', [])
for shape in json_data['shapes']:
value_data['label'].append(shape['label'].lower())
p1 = shape['points'][0]
p2 = shape['points'][1]
bbox = np.array([[p1[0], p1[1]], [p2[0], p2[1]]], dtype=np.int)
value_data['bbox'].append(bbox)
return value_data
注意:由于labelme中的label是文字描述的,而代码中需要使用label对应的序号,这个地方需要注意一下,可以在代码中直接改掉,也可以转换好后通过“查找-替换”来做,请自行搞定。
代码修改
- 在dataset.py文件中,找到self.max_objs,这里是一张图片中最多出现几个object,比如你的数据集中一张图片中最多要检测100个目标,则这里改成100,根据实际情况修改;将class_names = []修改成你的类别的名称
- 在config.py文件中,cfg.net_input_size是网络的输入图片尺寸,按[w, h]格式输入,我的是长方形的图片所以我填的是[512, 256];还有一个是cfg.num_classes修改成你数据集的类别总数,后面加1是因为背景类,所以num_class = 你的类别总数 + 1
- dataset.py文件中,有个get_image_id的函数,这里的作用是为了后面的验证时使用的id,只要保证不同的图片不一样的值就行了,我只是给出了一个例子,也可以直接使用UUID来替代
训练
根据实际情况修改好上面的内容后,开始训练
python train.py
训练过程:
测试
修改detect.py文件,分别将weights_file和img_path改成你模型的保存路径和测试图片的路径即可。
运行:
python detect.py
会自动将检测的结果显示出来