目标检测中一些对数据集的处理脚本

解释说明: 该部分脚本是作者在做目标检测算法时所用到的一些数据集的处理脚本。脚本主要包含以下内容

  1. 将xml数据转换为coco数据
  2. 数据集的划分,分别有传统的划分方式和按照类别进行划分的方式
  3. 获取数据集中的包含的类别和各类别包含的个数
  4. 将含有不同标注的xml文件合并
  5. 删除xml文件中的指定类别
  6. 删除xml文件中空类别的xml文件
  7. 获取包含指定类别的xml文件
  8. 删除xml文件中图像边缘的类别
  9. 删除xml文件中超出图像的框
  10. 找出xml文件中图像size记录为0的文件
  11. 为xml文件添加某些属性
  12. 查找那些无法用CV2读取的图片
  13. 查找那些无法转换为矩阵的图片
  14. 将标注后的数据中指定类别的图像裁剪出来
  15. 数据的扩增处理

1 将VOC格式的数据集转换为COCO数据集格式

1.1 直接将所有xml文件转换

脚本描述: 对已经划分好训练集和验证集的VOC数据集,根据txt文件中列出的文件名,将对应的XML标注写入到json文件中。

#!/usr/bin/python
# -*- coding:utf-8 -*-
# @Author: hj
# @Time: 2018-01-29 2020-10-9
# @Description: 

import os, sys, json

from xml.etree.ElementTree import ElementTree, Element

data_root = r'/home/blue/hanjian/datasets/2021_bisai_data/fushusheshi/'   # root directory of the VOC dataset

TXT_PATH = data_root + 'ImageSets/Main/train.txt'       # split file (train.txt or val.txt) listing the image ids
XML_PATH = data_root + 'Annotations'                    # directory holding the xml annotations
JSON_PATH = data_root + 'train.json'                    # path of the json file to write

# Accumulators filled by the conversion loop below and written out at the end.
json_obj = {}
images = []
annotations = []
categories = []
categories_list = []
image_id = 0         # next image id to assign
annotation_id = 0    # next annotation id to assign
catID = 0
classes = ("fnsssh","ghptwbq","gjptwbq","ghpsh","gjpsh","sbpsh")    # class names to convert

def read_xml(in_path):
    """Parse the XML file at ``in_path`` and return it as an ElementTree."""
    document = ElementTree()
    document.parse(source=in_path)
    return document

def if_match(node, kv_map):
    """Return True when ``node.get(key)`` equals ``kv_map[key]`` for every key.

    An empty ``kv_map`` matches every node.
    """
    return all(node.get(key) == kv_map.get(key) for key in kv_map)
    
def get_node_by_keyvalue(nodelist, kv_map):
    """Return, in order, the nodes of ``nodelist`` that ``if_match`` accepts for ``kv_map``."""
    return [candidate for candidate in nodelist if if_match(candidate, kv_map)]

def find_nodes(tree, path):
    """Return the result of ``tree.findall(path)``."""
    return tree.findall(path)

print("-----------------Start------------------")

# Build the annotation file names from the split file: every non-blank line
# is an image id to which ".xml" is appended.  The builtin ``sum`` is no
# longer shadowed by a counter; use ``len(xml_names)`` if the total is needed.
with open(TXT_PATH) as f:
    # Blank lines are skipped, otherwise they would turn into a bogus ".xml" entry.
    xml_names = [line.strip() + ".xml" for line in f if line.strip()]

# The COCO format expects integer category ids, so each class name is mapped
# to a 1-based id following its first position in ``classes`` (duplicate
# names collapse onto one id).
category_ids = {name: number for number, name in enumerate(dict.fromkeys(classes), start=1)}

for xml in xml_names:
    tree = read_xml(XML_PATH + "/" + xml)
    object_nodes = tree.findall("object")
    if not object_nodes:
        # Images without any annotated object are not exported.
        print(xml, "no object")
        continue

    kept_any = False
    for node in object_nodes:
        name = node.findtext("name")
        if name not in category_ids:
            # Objects whose class is not listed in ``classes`` are dropped.
            continue

        xmin = int(node.findtext("bndbox/xmin"))
        ymin = int(node.findtext("bndbox/ymin"))
        xmax = int(node.findtext("bndbox/xmax"))
        ymax = int(node.findtext("bndbox/ymax"))
        width = xmax - xmin
        height = ymax - ymin

        annotations.append({
            # The box corners serve as a rectangular polygon:
            # top-left, bottom-left, bottom-right, top-right.
            "segmentation": [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]],
            "area": width * height,
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": [xmin, ymin, width, height],
            "category_id": category_ids[name],
            "id": annotation_id,
            "ignore": 0,
        })
        annotation_id += 1
        kept_any = True

    # An image is only recorded when at least one of its objects was kept.
    if kept_any:
        images.append({
            "file_name": os.path.splitext(xml)[0] + ".jpg",
            "width": int(tree.findtext("size/width")),
            "height": int(tree.findtext("size/height")),
            "id": image_id,
        })
        image_id += 1
        print("processing " + xml)

for name, number in category_ids.items():
    categories.append({"supercategory": None, "id": number, "name": name})

json_obj["images"] = images
json_obj["type"] = "instances"
json_obj["annotations"] = annotations
json_obj["categories"] = categories

# The context manager guarantees the JSON is flushed and the file closed.
with open(JSON_PATH, "w") as f:
    json.dump(json_obj, f)

print("------------------End-------------------")

1.2 通过txt文件转换

解释说明: 读取txt文件中的xml文件名,然后去读取xml文件

import os, sys, json

from xml.etree.ElementTree import ElementTree, Element


TXT_PATH = 'D:/Dataset/VOCdevkit/VOC2012/ImageSets/Main/val.txt'   # split file listing the image ids
XML_PATH = 'D:/Dataset/VOCdevkit/VOC2012/Annotations'              # directory holding the xml annotations
JSON_PATH = 'D:/Dataset/val.json'                                  # path of the json file to write

# Accumulators filled by the conversion loop below and written out at the end.
json_obj = {}
images = []
annotations = []
categories = []
categories_list = []
image_id = 0
annotation_id = 0
catID = 0

# The 20 PASCAL VOC classes.  The previous tuple listed 'aeroplane' and
# 'bottle' twice and omitted 'bicycle', so bicycle boxes were silently dropped.
classes = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
           'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')



def read_xml(in_path):
    """Load the XML document stored at ``in_path`` into a new ElementTree and return it."""
    parsed = ElementTree()
    parsed.parse(source=in_path)
    return parsed


def if_match(node, kv_map):
    """Tell whether ``node`` carries every key/value pair of ``kv_map``.

    Values are read with ``node.get``; an empty ``kv_map`` always matches.
    """
    mismatches = (key for key in kv_map if node.get(key) != kv_map.get(key))
    return not any(True for _ in mismatches)


def get_node_by_keyvalue(nodelist, kv_map):
    """Filter ``nodelist`` down to the nodes for which ``if_match(node, kv_map)`` holds."""
    return [entry for entry in nodelist if if_match(entry, kv_map)]


def find_nodes(tree, path):
    """Return the result of ``tree.findall(path)``."""
    return tree.findall(path)


print("-----------------Start------------------")

# Build the annotation file names from the split file: every non-blank line
# is an image id to which ".xml" is appended.  The builtin ``sum`` is no
# longer shadowed by a counter; use ``len(xml_names)`` if the total is needed.
with open(TXT_PATH) as f:
    # Blank lines are skipped, otherwise they would turn into a bogus ".xml" entry.
    xml_names = [line.strip() + ".xml" for line in f if line.strip()]



# The COCO format expects integer category ids, so each class name is mapped
# to a 1-based id following its first position in ``classes`` (duplicate
# names collapse onto one id).
category_ids = {name: number for number, name in enumerate(dict.fromkeys(classes), start=1)}

for xml in xml_names:
    tree = read_xml(XML_PATH + "/" + xml)
    object_nodes = tree.findall("object")
    if not object_nodes:
        # Images without any annotated object are not exported.
        print(xml, "no object")
        continue

    kept_any = False
    for node in object_nodes:
        name = node.findtext("name")
        if name not in category_ids:
            # Objects whose class is not listed in ``classes`` are dropped.
            continue

        xmin = int(node.findtext("bndbox/xmin"))
        ymin = int(node.findtext("bndbox/ymin"))
        xmax = int(node.findtext("bndbox/xmax"))
        ymax = int(node.findtext("bndbox/ymax"))
        width = xmax - xmin
        height = ymax - ymin

        annotations.append({
            # The box corners serve as a rectangular polygon:
            # top-left, bottom-left, bottom-right, top-right.
            "segmentation": [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]],
            "area": width * height,
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": [xmin, ymin, width, height],
            "category_id": category_ids[name],
            "id": annotation_id,
            "ignore": 0,
        })
        annotation_id += 1
        kept_any = True

    # An image is only recorded when at least one of its objects was kept.
    if kept_any:
        images.append({
            "file_name": os.path.splitext(xml)[0] + ".jpg",
            "width": int(tree.findtext("size/width")),
            "height": int(tree.findtext("size/height")),
            "id": image_id,
        })
        image_id += 1
        print("processing " + xml)

for name, number in category_ids.items():
    categories.append({"supercategory": None, "id": number, "name": name})

json_obj["images"] = images
json_obj["type"] = "instances"
json_obj["annotations"] = annotations
json_obj["categories"] = categories

# The context manager guarantees the JSON is flushed and the file closed.
with open(JSON_PATH, "w") as f:
    json.dump(json_obj, f)

print("------------------End-------------------")

2 VOC数据集的训练集和验证集的划分

2.1 传统的划分方式

脚本说明: 按照传统的划分方式,数据集的训练集和验证集的数据是按照固定比例划分的,但是此种划分存在的局限性是:训练集和验证集中具体的类别不是按照固定比例的。

#!/usr/bin/python
# -*- coding:utf-8 -*-
# @author: hj
# @description: 数据预处理:根据xml划分trainval、train、val、test.txt

import os
import random
'''
		假设总数据为:100
		 trainval_percent = 0.7   
		 train_percent = 0.8 
		则训练集  train = 100 * 0.7 * 0.8 = 56
		验证集   val = 100 * 0.7 * (1-0.8) = 14 
		测试集   test = 100 *(1 - 0.7) = 30
	
	注意: 在实际目标检测中(工程应用中)一般是不需要测试集的,只有在比赛环境中才可能用到测试集。
		可根据自己的实际情况选取是否要测试集。
'''
trainval_percent = 0.7   # share of all samples used for train+val; the remaining 0.3 becomes the test set
train_percent = 0.8      # share of the train+val samples used for training

xmlfilepath = '/home/djm/djm_work_dir/mmdetection/data/datasets/anyi/xml/'    # directory holding the xml files
txtsavepath = '/home/djm/djm_work_dir/mmdetection/data/datasets/anyi/annotations/' # directory where the txt files are written

# List the annotation directory; only ".xml" entries take part in the split,
# so stray files no longer end up in the txt lists.
total_xml = sorted(entry for entry in os.listdir(xmlfilepath) if entry.endswith('.xml'))

# Indices of all annotation files.
num = len(total_xml)
numlist = range(num)

tv = int(num * trainval_percent)   # size of train+val
tr = int(tv * train_percent)       # size of train

trainval = random.sample(numlist, tv)
trainval.sort()
test = sorted(set(numlist).difference(trainval))

train = random.sample(trainval, tr)
train.sort()
val = sorted(set(trainval).difference(train))

# Sets make the membership tests below O(1) instead of scanning the lists.
trainval_set = set(trainval)
train_set = set(train)

with open(os.path.join(txtsavepath, 'trainval.txt'), 'w') as ftrainval, \
        open(os.path.join(txtsavepath, 'test.txt'), 'w') as ftest, \
        open(os.path.join(txtsavepath, 'train.txt'), 'w') as ftrain, \
        open(os.path.join(txtsavepath, 'val.txt'), 'w') as fval:
    for i in numlist:
        # Each list stores the file name without its ".xml" extension.
        name = os.path.splitext(total_xml[i])[0] + '\n'
        if i in trainval_set:
            ftrainval.write(name)
            if i in train_set:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)

2.2 按照类别数对数据集进行划分

脚本说明: 由于传统的数据划分存在局限性,如:当A类有100个,B类有100个,此时按照传统的划分方式,可能将A类划分30个而B类只划分10个,存在较大的随机性。

本次脚本的划分方式是按照数据集中的类别比例进行划分,如我们划分可能是A类20个和B类20个。

#!/usr/bin/python
# -*- coding:utf-8 -*-
# @author: hj



#数据集划分
import os
import random
import xml.etree.ElementTree as ET

root_dir = r''  # directory holding the xml files

classname = ["bl_jyz_wh","bl_jyz_zb","c_jyz_wh","czjyz","fhjyz","jyhyw"]     # classes to balance

# For every class, 20% of the files containing it go to the validation set;
# every file not selected that way goes to the training set.

total_xml = os.listdir(root_dir)

# Parse each annotation once and record, per class, the files that contain it
# (in directory-listing order).
xml_by_class = {obj_name: [] for obj_name in classname}
for obj_xml in total_xml:
    root = ET.parse(os.path.join(root_dir, obj_xml)).getroot()
    present = {obj.findtext("name") for obj in root.findall("object")}
    for obj_name in xml_by_class:
        if obj_name in present:
            xml_by_class[obj_name].append(obj_xml)

need_xml = []        # validation files, in selection order, without duplicates
selected = set()
for obj_name in classname:
    name_num = xml_by_class[obj_name]
    num = len(name_num)
    # Randomly draw 20% of the files that contain this class.
    val = random.sample(range(num), int(num * 0.2))
    print(len(val))
    for i in val:
        # A file holding several classes may be drawn more than once.
        if name_num[i] not in selected:
            selected.add(name_num[i])
            need_xml.append(name_num[i])

val_need_name = [os.path.splitext(xml)[0] for xml in need_xml]
train_need_name = [os.path.splitext(xml)[0] for xml in total_xml if xml not in selected]

print(len(val_need_name))
print(len(train_need_name))

root_dirs = r"/"   # VOC source directory where train.txt / val.txt are written

with open(root_dirs+'train.txt', 'w') as ftrain, open(root_dirs+'val.txt', 'w') as fval:
    for i in val_need_name:
        fval.write(i+'\n')
    for j in train_need_name:
        ftrain.write(j+'\n')
print("数据划分完毕==========")

3. 对标准的VOC数据集的一些操作

3.1 获取数据集中的包含的类别和各类别包含的个数

解释说明: 通过VOC数据集中的所有XML文件获取到数据集中包含哪些类别,并输出各个类别的数目。

import os
from tqdm import tqdm
import xml.dom.minidom

def ReadXml(FilePath):
    """Read the objects of a VOC annotation file.

    Returns None when ``FilePath`` does not exist, otherwise a list with one
    ``[xmin, ymin, xmax, ymax, name]`` entry per ``object`` element.
    """
    if not os.path.exists(FilePath):
        return None

    def _int_field(parent, tag):
        # Text of the first <tag> element below ``parent``, as an int.
        return int(parent.getElementsByTagName(tag)[0].firstChild.data)

    document = xml.dom.minidom.parse(FilePath)
    info = []
    for obj in document.documentElement.getElementsByTagName('object'):
        label = obj.getElementsByTagName("name")[0].firstChild.data
        box = obj.getElementsByTagName("bndbox")[0]
        info.append([
            _int_field(box, "xmin"),
            _int_field(box, "ymin"),
            _int_field(box, "xmax"),
            _int_field(box, "ymax"),
            label,
        ])
    return info

def CountLabelKind(Path):
    """Count how often each label occurs in the XML files below ``Path``.

    Walks ``Path`` recursively, reads every ``.xml`` file with ``ReadXml`` and
    returns a dict mapping label name to its number of occurrences, with keys
    in sorted order.
    """
    LabelDict = {}
    print("Star to count label kinds....")
    for root, dirs, files in os.walk(Path):
        for file in tqdm(files):
            if not file.endswith('.xml'):
                continue
            # os.path.join keeps the path valid on every platform.
            Infos = ReadXml(os.path.join(root, file))
            if Infos is None:
                # ReadXml returns None when the file is missing.
                continue
            for Info in Infos:
                LabelDict[Info[-1]] = LabelDict.get(Info[-1], 0) + 1

    return dict(sorted(LabelDict.items(), key=lambda x: x[0]))

if __name__ == '__main__':

    SrcDir = r""  # directory holding the dataset's xml files

    # Count the labels, then report the totals and the per-label numbers.
    label_counts = CountLabelKind(SrcDir)
    label_names = sorted(label_counts)
    total_labels = sum(label_counts.values())
    print("%d kind labels and %d labels in total:" % (len(label_names), total_labels))
    print(label_names)
    print("Label Name and it's number:")
    for label_name in label_names:
        print("%s\t: %d" % (label_name, label_counts[label_name]))

3.2 通过TXT文件获取到val、train数据中各个类别的数目

解释说明:数据集划分完毕后,想要获取到train和val中各个类别的数目。

import xml.etree.ElementTree as ET
import os
import numpy as np
from tqdm import tqdm
# import matplotlib.pyplot as plt

# xml文件路径
xmlpath = r''  # directory holding the xml files
classes = {}   # class name -> number of annotated objects
with open("【txt文件路径】", 'r', encoding='utf-8') as f:
    filename = f.readlines()
    print(filename)

# One annotation file per non-blank line of the split file.
xmlnames = [c.strip() + '.xml' for c in filename if c.strip()]

print(xmlnames)

# enumerate supplies the progress position directly; looking it up with
# list.index() was quadratic and wrong for repeated names.
for position, xmlname in enumerate(tqdm(xmlnames), start=1):
    if xmlname.endswith('.xml'):
        tree = ET.parse(os.path.join(xmlpath, xmlname))
        for obj in tree.findall('object'):
            cls = obj.find('name').text
            classes[cls] = classes.get(cls, 0) + 1

    print('当前统计 ',xmlname,'{:d}/{:d}'.format(position,len(xmlnames)))


# Class names and their counts as parallel lists.
index = []
values = []
for cls1 in classes.keys():
    print(cls1,':',classes[cls1])
    index.append(cls1)
    values.append(classes[cls1])

4. 对XML文件的一些操作

4.1 将含有不同标注的xml文件合并

解释说明: 对同一张图片生成了两个xml文件(两个xml文件是同名的),两个xml文件中标注的类别信息不一样。需要将两个xml文件合并为一个xml文件。

# -*- coding:utf-8 -*-
import os, shutil
import xml.etree.ElementTree as ET
import time
from multiprocessing import Pool
#from tqdm import tqdm

"""
@功能:合并同名称的xml节点,并将两份数据集整合成一份。
@作者:HJ

"""

"""
思路: 1。首先按照xml文件名,再找对应的图片名,拷贝到相应目录;
      2。如果出现说名称重复的,则将对应的两个xml节点合并成一个xml发到目录里面;
      3。源目录剩下重复的xml跟图片文件。     

"""
# 打开文件
# Source tree that is searched for xml directories.  A raw string keeps the
# backslashes literal; the value is unchanged, but the invalid escape
# sequences of the non-raw literal are gone.
rootPath = r"D:\基地装置和通道和杆塔\ceshi"

selectPicPathw = r"D:\基地装置和通道和杆塔\jpg/"   # destination directory for images
selectXmlPathw = r"D:\基地装置和通道和杆塔\xml/"   # destination directory for xml files


def gci(filepath):
    """Recursively visit the directories below ``filepath``.

    Every sub-directory is descended into first.  When its path contains
    ``'xml'`` it is then treated as an annotation directory: the image
    directory is derived by replacing everything from the first ``'xml'``
    onwards with ``'jpg/'`` and both are handed to ``merge_2xml``.
    """
    for entry in os.listdir(filepath):
        child = os.path.join(filepath, entry)

        # Plain files are ignored; only directories are followed.
        if not os.path.isdir(child):
            continue

        gci(child)
        if 'xml' not in child:
            continue

        xmlPath = child + '/'
        jpgpath = xmlPath.split('xml')[0] + 'jpg/'
        print('jpgpath', jpgpath)
        print('xmlPath', xmlPath)
        merge_2xml(jpgpath, xmlPath, selectPicPathw, selectXmlPathw)


def merge_2xml(ImgPath, AnnoPath, selectPicPathw, selectXmlPathw):
    """Move annotations and images into the target directories, merging duplicates.

    For every entry of ``AnnoPath``:

    * if an xml of the same name already exists in ``selectXmlPathw``, the
      ``object`` nodes of the new file are appended to the existing one and
      the source xml and image stay where they are;
    * otherwise the xml is moved to ``selectXmlPathw`` and the matching
      ``.jpg`` from ``ImgPath`` to ``selectPicPathw``.  A failed move (for
      example a missing image) is reported and the next file is processed.
    """
    for xml in os.listdir(AnnoPath):
        stem, _ = os.path.splitext(xml)

        imgfile = os.path.join(ImgPath, stem + '.jpg')
        xmlfile = os.path.join(AnnoPath, stem + '.xml')
        xmlfile_exists = os.path.join(selectXmlPathw, stem + '.xml')

        if os.path.exists(xmlfile_exists):
            # Duplicate name: merge the new objects into the existing file.
            tree_exists = ET.parse(xmlfile_exists)
            root_exists = tree_exists.getroot()
            for it in ET.parse(xmlfile).getroot().iter('object'):
                root_exists.append(it)
            tree_exists.write(xmlfile_exists)
            continue

        try:
            shutil.move(xmlfile, selectXmlPathw)
            shutil.move(imgfile, selectPicPathw)
        except OSError as r:
            # shutil.Error is an OSError subclass, so this also covers
            # "destination already exists" for the image.
            print('error:', r)

if __name__ == '__main__':
    # When run as a script, walk the source tree starting at rootPath.
    gci(rootPath)

4.2 删除xml文件中的指定类别

解释说明: xml文件中包含了不需要训练的类别,可通过此脚本删除。

import os
import xml.etree.ElementTree as ET
import tqdm

def del_delete_eq_1(xml_path, delete_classes=("需要删除的类别",)):
    """Remove the objects of the given classes from an annotation file.

    Args:
        xml_path: path of the xml file; it is rewritten in place.
        delete_classes: collection of class names whose ``object`` nodes
            are removed.  Defaults to the placeholder used so far.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()

    # findall returns a list, so removing nodes while looping is safe.
    for obj in root.findall('object'):
        if str(obj.findtext('name')) in delete_classes:
            root.remove(obj)

    tree.write(xml_path)


def main():
    """Apply ``del_delete_eq_1`` to every entry of the annotation directory."""
    root_dir = r""  # directory holding the xml files
    xml_path_list = []
    for entry in os.listdir(root_dir):
        xml_path_list.append(os.path.join(root_dir, entry))

    # tqdm shows the progress.
    for xml_path in tqdm.tqdm(xml_path_list):
        del_delete_eq_1(xml_path)


if __name__ == '__main__':
    main()

4.3 删除xml文件中空类别的xml文件

解释说明:数据清洗完毕后,xml文件中可能存在较多空标签的xml文件,可通过此脚本删除

# Remove every xml annotation without objects, together with its image.
# xml.etree.cElementTree was removed in Python 3.9; ElementTree is its
# drop-in replacement.
import xml.etree.ElementTree as ET
import os

path_root = r''   # directory holding the xml files
path_jpg = r''     # directory holding the jpg images

xml_list = os.listdir(path_root)

count = 0   # number of empty annotations removed so far
for axml in xml_list:
    if not axml.endswith('.xml'):
        # Anything that is not an annotation file is left alone.
        continue
    path_xml = os.path.join(path_root, axml)
    file_name = os.path.splitext(axml)[0]
    root = ET.parse(path_xml).getroot()

    if root.findall('object'):
        continue

    count = count + 1
    os.remove(path_xml)
    # Adjust the extension when the images are not ".jpg".  A missing image
    # must not abort the clean-up of the remaining files.
    image_path = os.path.join(path_jpg, file_name + ".jpg")
    if os.path.exists(image_path):
        os.remove(image_path)
    print(count)

4.4 获取包含指定类别的xml文件

解释说明:将xml文件中包含指定类别的xml文件提取出来。

import cv2
import os,shutil
import xml.etree.ElementTree as ET
import pdb
import time
from multiprocessing import Pool


AnnoPath =r"" # directory holding the xml files
selectXmlPathw =r"" # directory the matching xml files are moved to


ClassList2=["jccj","jcps","jclzym"]  # classes to extract

PicCount = 0  # number of files handled so far (updated by ProcessPic)

def ProcessPic(imgPath):
    """Move one annotation to ``selectXmlPathw`` if it contains a wanted class.

    Args:
        imgPath: file name inside ``AnnoPath``; only its stem is used, the
            annotation read is ``<stem>.xml``.

    The file is moved as soon as one object whose name is in ``ClassList2``
    is found.  ``PicCount`` is printed and then incremented on every call.
    """
    global PicCount
    print("count:", PicCount)
    # The counter used to be added to itself and therefore stayed at 0.
    # NOTE(review): under multiprocessing each worker presumably keeps its
    # own copy of this global, so the value is per process - confirm.
    PicCount += 1

    stem, _ = os.path.splitext(imgPath)
    xmlfile = os.path.join(AnnoPath, stem + '.xml')
    root = ET.parse(xmlfile).getroot()

    for obj in root.findall("object"):
        if obj.findtext('name') in ClassList2:
            shutil.move(xmlfile, selectXmlPathw)
            break

if __name__ == '__main__':

    # Hand every entry of the annotation directory to a pool of worker
    # processes and report the elapsed wall-clock time in seconds.
    xml_entries = os.listdir(AnnoPath)
    started = time.time()
    workers = Pool()
    workers.map(ProcessPic, xml_entries)
    workers.close()
    workers.join()
    finished = time.time()
    print(finished - started)

4.5 删除xml文件中图像边缘的类别

解释说明: 删除xml文件中图像边缘的xml类别(作者在做数据扩增的时候发现可能存在部分xml文件做扩增后有些标注会叠加在图像的边缘)。

import os
import xml.etree.ElementTree as ET
import tqdm

def del_delete_eq_1(xml_path, min_size=2):
    """Remove degenerate boxes from an annotation file.

    Args:
        xml_path: path of the xml file; it is rewritten in place.
        min_size: an object is removed when the width or the height of its
            box is at most this many pixels.  Defaults to 2.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    # findall returns a list, so removing nodes while looping is safe.
    for obj in root.findall('object'):
        box = obj.find("bndbox")
        xmin = int(box.find("xmin").text)
        ymin = int(box.find("ymin").text)
        xmax = int(box.find("xmax").text)
        ymax = int(box.find("ymax").text)
        if abs(xmax - xmin) <= min_size or abs(ymax - ymin) <= min_size:
            root.remove(obj)
    tree.write(xml_path)

def main():
    """Apply ``del_delete_eq_1`` to every entry of the annotation directory."""
    root_dir = r""  # directory holding the xml files
    xml_path_list = []
    for entry in os.listdir(root_dir):
        xml_path_list.append(os.path.join(root_dir, entry))

    # tqdm shows the progress.
    for xml_path in tqdm.tqdm(xml_path_list):
        del_delete_eq_1(xml_path)
if __name__ == '__main__':
    main()

4.6 删除xml文件中超出图像的框

解释说明:数据扩增后可能有部分类别的框会超出增强后图像的边缘。

import os
import xml.etree.ElementTree as ET
import tqdm
def del_delete_eq_1(xml_path):
    """Remove every object whose box leaves the image, then rewrite the file.

    The image extent is read from ``size/width`` and ``size/height``; the
    file path is printed once for each removed object.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    width = int(root.find("size").find("width").text)
    height = int(root.find("size").find("height").text)

    for obj in root.findall('object'):
        xmin, ymin, xmax, ymax = (
            int(obj.find("bndbox").find(tag).text)
            for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        inside = xmin >= 0 and ymin >= 0 and xmax <= width and ymax <= height
        if inside:
            continue
        root.remove(obj)
        print(xml_path)

    tree.write(xml_path)

def main():
    """Apply ``del_delete_eq_1`` to every entry of the annotation directory."""
    root_dir = r""  # directory holding the xml files
    xml_path_list = []
    for entry in os.listdir(root_dir):
        xml_path_list.append(os.path.join(root_dir, entry))

    for xml_path in xml_path_list:
        del_delete_eq_1(xml_path)
if __name__ == '__main__':
    main()

4.7 找出xml文件中图像size记录为0的文件

解释说明:数据中可能会存在xml文件记录图像的size为0的数据。

import os
import xml.etree.ElementTree as ET
import tqdm


def chacuo(xml_path):
    """Print ``xml_path`` when its recorded image width or height is 0."""
    size_node = ET.parse(xml_path).getroot().find("size")
    dimensions = [int(size_node.find(tag).text) for tag in ("width", "height")]

    if 0 in dimensions:
        print(xml_path)


def main():
    """Run ``chacuo`` on every entry of the annotation directory."""
    root_dir = r"D:\基地装置和通道和杆塔\xml/"
    xml_path_list = []
    for entry in os.listdir(root_dir):
        xml_path_list.append(os.path.join(root_dir, entry))
    # tqdm shows the progress.
    for xml_path in tqdm.tqdm(xml_path_list):
        chacuo(xml_path)
if __name__ == '__main__':
    main()

4.8 为xml文件添加某些属性

解释说明:不同的标注软件会产生不同的xml文件,但是有些标注软件在写入xml文件时会忽略某些属性。作者在做实验时发现,不同的算法源码对XML文件的解析不一样,有些源码会找一些不必要的属性,当找不到这些属性时会报错。如:旷视开源的yolox源码在训练时会检测xml文件的difficult,pose等属性。

import os
import xml.etree.ElementTree as ET
import tqdm


def insert(xml_path):
    """Add ``<difficult>0</difficult>`` to every object that lacks it.

    The file is rewritten, and its path printed, only when at least one
    element was added.
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()

    changed = False
    for obj in root.findall('object'):
        # find returns None for a missing child; test it explicitly instead
        # of relying on a bare except around the attribute access.
        if obj.find("difficult") is None:
            node = ET.SubElement(obj, "difficult")   # attribute to add
            node.text = "0"                          # its value
            changed = True

    if changed:
        tree.write(xml_path)
        print(xml_path)

def main():
    """Run ``insert`` on every entry of the annotation directory."""
    root_dir = r""  # directory holding the xml files
    xml_path_list = []
    for entry in os.listdir(root_dir):
        xml_path_list.append(os.path.join(root_dir, entry))

    for xml_path in xml_path_list:
        insert(xml_path)
if __name__ == '__main__':
    main()

5. 对JPG文件的一些操作

5.1 查找那些无法用CV2读取的图片

import cv2
import os
import tqdm
path_img = r"E:\wurenji\JPEGImages/"   # directory holding the images to check
path_jpg = os.listdir(path_img)

for path in tqdm.tqdm(path_jpg):
    img = cv2.imread(path_img + path, cv2.IMREAD_COLOR)
    # Test the result directly: an assert would be stripped under ``python -O``
    # and no unreadable file would be reported.
    if img is None:
        print(path)

5.2 查找那些无法转换为矩阵的图片

import os
import shutil

#查找那些无法读取的jpg文件

path_img = r"E:\wurenji\fushusheshi\JPEGImages/"   # directory holding the images to check

import numpy as np
from PIL import Image
path_jpg = os.listdir(path_img)
count = 0   # number of images reported as corrupt

for absolute_path in path_jpg:
    # Any failure while opening or decoding marks the image as corrupt, so
    # the handlers are deliberately broad.
    try:
        img = Image.open(path_img + absolute_path)
    except Exception:
        count = count + 1
        print("corrupt img",absolute_path)
        continue

    try:
        # Convert the decoded image itself; converting the path string, as
        # before, could never fail.
        np.asarray(img)
    except Exception:
        count = count + 1
        print('corrupt img', absolute_path)
    finally:
        img.close()
print(count)

5.3 将标注后的数据中指定类别的图像裁剪出来

解释说明:当数据较为庞大的时候,我们不可能一条条地去检查标注人员是否标注错误。此时可以通过此脚本,根据标注后的XML文件在图像中进行裁剪,得到指定类别的所有裁剪图片,这样检查起来会非常方便。

import os
import cv2
from xml.etree.ElementTree import ElementTree, Element


jpg_dir = r'E:\lsqlm\zzy\jpg'   # directory holding the jpg images
xml_dir = r'E:\lsqlm\zzy\xml'   # directory holding the xml annotations
jpc_crop_dir = r'E:\lsqlm\zzy\bcqlm' # directory the crops are written to
cat_id = ['bcqlm']              # classes to crop
et = ElementTree()
xml_files = os.listdir(xml_dir)
resize_scale = 224               # length of the longer side of every saved crop

for xml_file in xml_files:
    print(xml_file)
    jpg_file = xml_file.replace('.xml', '.jpg')
    jpg_path = os.path.join(jpg_dir, jpg_file)
    xml_path = os.path.join(xml_dir, xml_file)

    img = cv2.imread(jpg_path)
    if img is None:
        # cv2.imread returns None for a missing or unreadable image.
        print('cannot read image:', jpg_path)
        continue

    # ElementTree.parse returns the root element of the parsed document.
    tree = et.parse(xml_path)
    object_num = 0   # running index of the crops taken from this image
    for object_node in tree.findall('object'):
        if object_node.find('name').text not in cat_id:
            continue
        object_num += 1

        xmin = int(object_node.find('bndbox/xmin').text)
        ymin = int(object_node.find('bndbox/ymin').text)
        xmax = int(object_node.find('bndbox/xmax').text)
        ymax = int(object_node.find('bndbox/ymax').text)
        object_region = img[ymin:ymax, xmin:xmax]
        w, h = xmax - xmin, ymax - ymin
        if w <= 0 or h <= 0 or object_region.size == 0:
            # A degenerate box, or one outside the image, has nothing to crop.
            print('skip empty box in', xml_file)
            continue

        # Scale so that the longer side equals resize_scale; very thin boxes
        # keep at least one pixel on the shorter side.
        ratio = resize_scale / max(w, h)
        new_w, new_h = max(1, int(ratio * w)), max(1, int(ratio * h))
        resize_object_region = cv2.resize(object_region,(new_w,new_h))
        jpg_crop_name = jpg_file.replace('.jpg', '_'+str(object_num)+'.jpg')
        cv2.imwrite(os.path.join(jpc_crop_dir, jpg_crop_name), resize_object_region)

6. 数据的扩增处理

解释说明: 当数据集中的部分类别较少时,我们可以对指定类别进行扩增处理,但是需要注意的是扩增后的数据是不能够用来验证的,因为数据不管怎么扩增它和原始数据的相似度还是非常大的,当放入验证集时验证结果就不客观了。

import os
import shutil
import numpy as np
import imgaug as ia
import xml.etree.ElementTree as ET

from PIL import Image
from imgaug import augmenters as iaa

ia.seed(1)


def read_xml_annotation(root, image_id):
    """Read the bounding boxes of one annotation file.

    Args:
        root: directory holding the annotation file.
        image_id: file name of the annotation inside ``root``.

    Returns:
        A list with one ``[xmin, ymin, xmax, ymax]`` entry (ints) per
        ``object`` element, in document order.
    """
    # The context manager closes the handle that used to be leaked.
    with open(os.path.join(root, image_id), encoding="utf-8") as in_file:
        tree = ET.parse(in_file)

    bndboxlist = []
    for obj in tree.getroot().findall('object'):
        bndbox = obj.find('bndbox')
        bndboxlist.append(
            [int(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        )

    return bndboxlist


def change_xml_annotation(root, image_id, new_target):
    """Overwrite the box of the first object of ``<image_id>.xml``.

    Args:
        root: directory holding the annotation file (also the output
            directory).
        image_id: stem of the annotation file to read.
        new_target: ``[xmin, ymin, xmax, ymax]`` to store.

    The result is written to ``root``.  An integer ``image_id`` is written
    under its zero-padded six-digit name (``7`` -> ``000007.xml``); any other
    id is written back to ``<image_id>.xml``.  The previous code formatted
    the builtin ``id`` with ``%06d`` and therefore always raised TypeError.
    """
    new_xmin, new_ymin, new_xmax, new_ymax = (
        new_target[0], new_target[1], new_target[2], new_target[3]
    )

    with open(os.path.join(root, str(image_id) + '.xml')) as in_file:
        tree = ET.parse(in_file)

    # Only the first <object> of the file is updated.
    bndbox = tree.getroot().find('object').find('bndbox')
    bndbox.find('xmin').text = str(new_xmin)
    bndbox.find('ymin').text = str(new_ymin)
    bndbox.find('xmax').text = str(new_xmax)
    bndbox.find('ymax').text = str(new_ymax)

    if isinstance(image_id, int):
        out_name = "%06d" % image_id + '.xml'
    else:
        out_name = str(image_id) + '.xml'
    tree.write(os.path.join(root, out_name))


def change_xml_list_annotation(root, image_id, new_target, saveroot, _id):
    """Write a copy of an annotation with new boxes and a new file name.

    Args:
        root: directory holding the source annotation.
        image_id: stem of the source annotation (``<image_id>.xml``).
        new_target: one ``[xmin, ymin, xmax, ymax]`` entry per ``object``
            element, in document order.
        saveroot: directory the new annotation is written to.
        _id: stem of the new annotation; ``filename`` is set to
            ``<_id>.jpg`` and the file is saved as ``<_id>.xml``.

    Raises:
        IndexError: if ``new_target`` has fewer entries than the file has
            objects.
    """
    # The context manager closes the handle that used to be leaked.
    with open(os.path.join(root, str(image_id) + '.xml'), encoding="utf-8") as in_file:
        tree = ET.parse(in_file)

    xmlroot = tree.getroot()
    elem = tree.find('filename')
    if elem is None:
        # Some annotation tools omit <filename>; add it instead of crashing.
        elem = ET.SubElement(xmlroot, 'filename')
    elem.text = _id + '.jpg'

    for index, obj in enumerate(xmlroot.findall('object')):
        bndbox = obj.find('bndbox')
        box = new_target[index]
        bndbox.find('xmin').text = str(box[0])
        bndbox.find('ymin').text = str(box[1])
        bndbox.find('xmax').text = str(box[2])
        bndbox.find('ymax').text = str(box[3])

    tree.write(os.path.join(saveroot, _id + '.xml'))


def mkdir(path):
    """Create the directory ``path`` unless it already exists.

    Surrounding whitespace and trailing backslashes are stripped from
    ``path`` first.  Returns True when the directory was created and False
    when the path already existed; a message is printed in both cases.
    """
    path = path.strip().rstrip("\\")

    if os.path.exists(path):
        # Nothing to do: report the existing path.
        print(path + ' 目录已存在')
        return False

    os.makedirs(path)
    print(path + ' 创建成功')
    return True


if __name__ == "__main__":
    IMG_DIR = r""  # directory of the images to augment
    XML_DIR = r""  # directory of the matching XML annotations

    AUG_IMG_DIR = r""  # output directory for the augmented images
    AUG_XML_DIR = r""  # output directory for the augmented XML annotations

    # Start from empty output directories.
    for out_dir in (AUG_IMG_DIR, AUG_XML_DIR):
        try:
            shutil.rmtree(out_dir)
        except FileNotFoundError:
            pass
        mkdir(out_dir)

    AUGLOOP = 3  # number of augmented copies produced per image

    # Augmentation pipeline.
    seq = iaa.Sequential([
        # Flip 80% of the images horizontally.
        iaa.Fliplr(0.8),
        # Apply between zero and two of the following augmenters.
        iaa.SomeOf((0,2),
                       [
                           # Random crop of 0-20% per side.
                           iaa.Crop(percent=(0,0.2)),

                           iaa.Affine(
                               scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},  # scale to 80%-120%
                               cval=(0),  # constant used when filling with a constant
                               mode=ia.ALL  # fill mode for the area outside the image
                           ),

                           # One of three blurs.
                           iaa.OneOf([
                               iaa.GaussianBlur((0, 3.0)),
                               iaa.AverageBlur(k=(2, 7)),  # kernel size between 2 and 7
                               iaa.MedianBlur(k=(3, 11)),
                           ]),
                           # Contrast factor between 0.9 and 1.2.
                           # NOTE(review): newer imgaug releases deprecate
                           # ContrastNormalization in favour of LinearContrast - confirm.
                           iaa.ContrastNormalization((0.9,1.2),per_channel=0.5),

                       ],
                        random_order=True  # apply the chosen augmenters in random order
                   )
    ])

    for root, sub_folders, files in os.walk(XML_DIR):

        for name in files:
            if not name.endswith('.xml'):
                continue
            stem = name[:-4]

            # Read from the directory being walked so that annotations in
            # sub-directories are found as well.
            bndbox = read_xml_annotation(root, name)

            # The source image is the same for every epoch: read it once.
            with Image.open(os.path.join(IMG_DIR, stem + '.jpg')) as source_image:
                img = np.asarray(source_image)

            for epoch in range(AUGLOOP):
                # A deterministic copy applies the same transform to the
                # boxes and to the image.
                seq_det = seq.to_deterministic()
                new_bndbox_list = []
                bbs = None  # stays None when the annotation has no boxes

                # Augment the box coordinates.
                for box in bndbox:
                    bbs = ia.BoundingBoxesOnImage([
                        ia.BoundingBox(x1=box[0], y1=box[1], x2=box[2], y2=box[3]),
                    ], shape=img.shape)

                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]
                    moved = bbs_aug.bounding_boxes[0]

                    # Clamp the coordinates to [1, image size].
                    n_x1 = int(max(1, min(img.shape[1], moved.x1)))
                    n_y1 = int(max(1, min(img.shape[0], moved.y1)))
                    n_x2 = int(max(1, min(img.shape[1], moved.x2)))
                    n_y2 = int(max(1, min(img.shape[0], moved.y2)))
                    if n_x1 == 1 and n_x1 == n_x2:
                        n_x2 += 1
                    if n_y1 == 1 and n_y2 == n_y1:
                        n_y2 += 1
                    if n_x1 >= n_x2 or n_y1 >= n_y2:
                        print('error', name)
                    new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])

                # Save the augmented image.
                image_aug = seq_det.augment_images([img])[0]
                path = os.path.join(AUG_IMG_DIR, stem + '_' + str(epoch) + '.jpg')
                if bbs is None:
                    # No boxes: there is nothing to draw.
                    image_auged = image_aug
                else:
                    image_auged = bbs.draw_on_image(image_aug, thickness=0)
                try:
                    Image.fromarray(image_auged).save(path)
                except (OSError, ValueError):
                    print(path)

                # Save the augmented annotation.
                change_xml_list_annotation(root, stem, new_bndbox_list, AUG_XML_DIR,
                                           stem + '_' + str(epoch))
                print(stem + '_' + str(epoch) + '.jpg')