【YOLOv4-pytorch】训练自己的数据集实践记录及问题总结
使用pytorch-yolov4训练自己的目标检测数据集
代码:https://github.com/Tianxiaomo/pytorch-YOLOv4
预训练模型:
yolov4.pth(链接:https://pan.baidu.com/s/17GivIeUbItyfwVdooUVsDw?pwd=p8ub 提取码:p8ub)
yolov4.conv.137.pth(链接:https://pan.baidu.com/s/1n50M_v_gd2KT4ROx7cySWg?pwd=zdzv 提取码:zdzv)
环境配置和模型测试就不写了,和之前的大同小异,参考附录1的博主写的很详细。
目录
- 【YOLOv4-pytorch】训练自己的数据集实践记录及问题总结
- 一. 数据集准备
- 1. 生成trainId.txt,valId.txt
- 2. 基于train/val.txt生成yolov4要求的train.txt,val.txt
- 二. 训练过程
- 1. 修改cfg.py文件
- 2. 训练
- 三. 训练中遇到的问题&解决方案
- 1. cv2.error: OpenCV(4.5.5) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'
- 2. OpenCV can't augment image: 608 x 608
- 3. RuntimeError: shape '[4, 3, 20, 76, 76]' is invalid for input of size 5891520
- 4. cv2.error: Caught error in DataLoader worker process 7
- 5. 196 x2num.py[line:14] WARNING: NaN or Inf found in input tensor.
- 6. cv2.error: OpenCV(3.4.4) /io/opencv/modules/imgproc/src/color.cpp:181: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'
一. 数据集准备
Yolov4要求的train.txt,val.txt内容为
影像名,bbox左上角坐标(x1,y1),右下角坐标(x2,y2),类别id
image_path1 x1,y1,x2,y2,id x1,y1,x2,y2,id x1,y1,x2,y2,id ...
image_path2 x1,y1,x2,y2,id x1,y1,x2,y2,id x1,y1,x2,y2,id ...
...
- image_path : Image Name
- x1,y1 : Coordinates of the upper left corner
- x2,y2 : Coordinates of the lower right corner
- id : Object Class
我的数据集是DOTA,不过之前用别的模型训练时将DOTA标签数据转化成了xml,所以这次直接将xml转换为yolov4要求的txt格式。
1. 生成trainId.txt,valId.txt
文本中按行存放影像名。如果你的数据是voc格式,那原来的ImageSets/Main/下的txt文件即为所求,将它改名即可。
import os
txt_path="/home/DOTA/ImageSets/trainId.txt"
file_path = "/home/DOTA/train2017/"
path_list = os.listdir(file_path) #遍历整个文件夹下的文件name并返回一个列表
path_list.sort() #--------------
path_name = []
for i in path_list:
path_name.append(i.split(".")[0]) #若带有后缀名,利用循环遍历path_list列表,split去掉后缀名
#path_name.append(i)
for file_name in path_name:
# "a"表示以不覆盖的形式写入到文件中,当前文件夹如果没有"save.txt"会自动创建
with open(txt_path, "a") as file:
file.write(file_name + "\n")
#print(file_name)
file.close()
2. 基于train/val.txt生成yolov4要求的train.txt,val.txt
使用附录1博主给出的代码略加修改,生成的文件位于/ImageSets/下
import xml.etree.ElementTree as ET
import os
from os import getcwd
sets = [('2017','train'), ('2017','val')]
classes = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field',
'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
'basketball-court', 'storage-tank', 'soccer-ball-field',
'roundabout', 'harbor', 'swimming-pool', 'helicopter']
def convert_annotation(year, image_id, list_file):
in_file=open('/home/DOTA/Annotations/%s.xml'%(image_id),'r',encoding='utf-8')
tree = ET.parse(in_file)
root = tree.getroot()
for obj in root.iter('object'):
difficult = 0
if obj.find('difficult') != None:
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult) == 1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text),
int(xmlbox.find('ymax').text))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
wd = "/home/DOTA" #getcwd()
for year, image_set in sets:
image_ids = open('/home/DOTA/ImageSets/%sId.txt' % (image_set)).read().strip().split()
list_file = open('/home/DOTA/ImageSets/%s.txt' % (image_set), 'w')
for image_id in image_ids:
list_file.write('%s/%s2017/%s.png' % (wd, image_set,image_id))
convert_annotation(year, image_id, list_file)
list_file.write('\n')
list_file.close()
二. 训练过程
1. 修改cfg.py文件
必须修改的:类别数,label路径
Cfg.classes = 15
Cfg.train_label = “/home/DOTA/ImageSets/train.txt”
Cfg.val_label = “/home/n/DOTA/ImageSets/val.txt”
按照自己需求设置的:batch一般设置的8的倍数,在内存不溢出的情况下设置为32,64,96等值,batch设置愈大,对内存要求越高;subdivisions一般设置为batch的1/4或1/8,设置越大,对内存要求越小;接下来两项其实不一定要修改,max_batches一般设置为类别数*2000,steps范围为80%-90%*max_batches;EPOCHS按照需要设置,其他点动量学习率之类保持默认或者按照需要设置。(更清楚的说明参考附录4)
Cfg.batch = 64
Cfg.subdivisions = 8
Cfg.max_batches = 30000
Cfg.steps = [24000, 27000]
Cfg.TRAIN_EPOCHS = 500
这里有一个Cfg.use_darknet_cfg = True,后续可能报错如问题3,此时将其改为 False.
2. 训练
python train.py -l 0.0001 -g 0 -pretrained ./weights/yolov4.conv.137.pth -classes 15 -dir /home/DOTA/train2017 -train_label_path /home/DOTA/ImageSets/train.txt
#l learning rate
#g gpu id
#pretrained Pre-trained backbone network, converted from yolov4.conv.137 of darknet given by AlexeyAB
#classes NO. of classes
#dir Training image dir
三. 训练中遇到的问题&解决方案
1. cv2.error: OpenCV(4.5.5) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function ‘cvtColor’
图片没有成功打开,总结一下,原因类型总共有以下几种:
(1)90%是因为图片路径错误,检查图片路径,将路径中的“\”改为“/”(前者具有转义功能);相对路径改为绝对路径;路径不要包含中文名;检查后缀名是否完整正确。
(2)检查标签文件train.txt以及val.txt转换是否有误,检查图片文件夹中图片名称数量和标签文件是否对应,图片文件夹是否存在其它多余文件。
后问题(6)也是和这个报错差不多,不过那是另一个原因了。
2. OpenCV can’t augment image: 608 x 608
opencv版本太高,降低版本
pip install opencv_python==3.4.4.19
3. RuntimeError: shape ‘[4, 3, 20, 76, 76]’ is invalid for input of size 5891520
定位到语句是这句:
output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
有一个说法是
在model.py文件中有一个bug,Yolov4类中,将Classes定义为80,将它改为自己的类别数,不过试了一下没有用。。。
解决方法:将cfg.py中的Cfg.use_darknet_cfg = True
改为Cfg.use_darknet_cfg = False
4. cv2.error: Caught error in DataLoader worker process 7
这个问题是训练开始了一会儿时弹出来的,
方法,将train.py中的train_loader的num_workers值改为0
5. 196 x2num.py[line:14] WARNING: NaN or Inf found in input tensor.
可能是初始学习率过大导致梯度消失或者梯度爆炸,将初始学习率降低。
6. cv2.error: OpenCV(3.4.4) /io/opencv/modules/imgproc/src/color.cpp:181: error: (-215:Assertion failed) !_src.empty() in function ‘cvtColor’
很奇怪,在训练开始一段时间后总会突然弹出这个错误,而我检查了图片文件夹、train.txt文件内容、类别,卡了一整天也不知道哪儿错了,最后发现txt文件应该是image_path,x1,y2,x2,y2的格式,我使用的博主xml转txt的代码(参考3)中这儿误写成了x1,x2,y1,y2,把这个错误改正后,再训练试试。(这儿确实有错误但是并不是这个导致的)
…过去了三天了…
我!终于!!找到了原因!!!
花了两天检查是不是我的标签文件中哪里有错误加上一天消极怠工,都想要放弃这个版本的yolov4了,今天用了最简单的方法找到了原因
报错定位到代码上是这句 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB),于是试着在执行代码前把图片名打印出来看看是哪张图片出了错误,这也正是我之前一直很疑惑的事情,每次是第一个Epoch读取图片一段时间报这个错误,训练终止,而每次终止时读取图片的数目并不一样,这说明并不是我以为的那样按照train.txt文件中的影像目录来读取(我排好了序),事实证明读取图像确实是打乱顺序的。终于!找到!原因了!,几次训练终止打印出的图片名都有一个特点–没有目标框(除了进行数据增强时裁剪等操作会造成这个错误,我的DOTA数据集本身也含有十来张负样本影像),这说明这个报错错误并不是没有成功读进来图片,而是图片标签读取失败,把这几张图删除,终于成功训练了(热泪盈眶)。
ps:开始训练一段时间时,我也弹出了You could also create your own ‘get_image_id’ function的关于get_image_id的问题,看了一下,就是需要每张图片返回一个int型的id值,这个是要求根据自己的图片名称修改dataset.py中的get_image_id函数的。比如,DOTA数据集命名为PXXXX.png的格式,所以使用如下代码,这样比如P0001.png就获取到1作为id值。
parts = filename.split('/')
id=int(parts[-1][1:-4]) #修改这句
return id
可以跑了
训练中…