这里写自定义目录标题
- 背景
- 数据读入部分
- transform
- Dataset
- DataLoader
- 网络定义
背景
上一篇文章写了pytorch版本yolov3的源码。代码较为简单。这篇文章准备写一篇代码较为复杂的SSD实现版本。该版本的github地址为:
https://github.com/amdegroot/ssd.pytorch
在该github下的使用操作方法比较完善,就不在这里记录了。在这里只记录代码的解析。
数据读入部分
数据读入部分的代码为
dataset=VOCDetection(root=args.dataset_root,transform=SSDAugmentation(cfg['min_dim'],MEANS))
transform
其中SSDAugmentation函数为图像、标签转换函数,具体的定义为之在utils/augmentations.py
class SSDAugmentation(object):
def __init__(self, size=300, mean=(104, 117, 123)):
self.mean = mean
self.size = size
self.augment = Compose([
ConvertFromInts(), #数据类型转换
ToAbsoluteCoords(), #位置信息转换
PhotometricDistort(), #镜像翻转
Expand(self.mean), #扩展图像
RandomSampleCrop(), #随机裁剪
RandomMirror(), #随机镜像翻转
ToPercentCoords(), #位置归一化
Resize(self.size), #图像尺寸缩放
SubtractMeans(self.mean) #图像去均值
])
def __call__(self, img, boxes, labels):
return self.augment(img, boxes, labels)
其中__call__方法是python的一个方法,该方法的定义表明该类可以直接调用。
在augmentations.py中详细的定义了每一个数据转换的方法,具体的定义就不叙述了。
Dataset
类VOCDetetction继承于torch.utils.data.Dataset类别,需要定义imgs,getitem、__len__方法。
python的__getitem__方法可以让对象实现迭代功能。在这里,会返回单张图像及其标签。定义如下:
class VOCDetection(data.Dataset):
def __init__(self, root,
image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
transform=None, target_transform=VOCAnnotationTransform(),
dataset_name='VOC0712'):
self.root = root #设置数据集的根目录
self.image_set = image_sets #设置要选用的数据集
self.transform = transform #定义图像转换方法
self.target_transform = target_transform #定义标签的转换方法
self.name = dataset_name #定义数据集名称
self._annopath = osp.join('%s', 'Annotations', '%s.xml') #记录标签的位置
self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') #记录图像的位置
self.ids = list() #记录数据集中的所有图像的名字
#读入数据集中的图像名称,可以依照该名称和_annopath、_imgpath推断出图片、描述文件存储的位置
for (year, name) in image_sets:
rootpath = osp.join(self.root, 'VOC' + year)
for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
self.ids.append((rootpath, line.strip()))
def __getitem__(self, index):
im, gt, h, w = self.pull_item(index)
return im, gt
def __len__(self):
return len(self.ids)
def pull_item(self, index):
img_id = self.ids[index] #获取index对应的img名称
target = ET.parse(self._annopath % img_id).getroot() #读取xml文件
img = cv2.imread(self._imgpath % img_id) #获取图像
height, width, channels = img.shape #获取图像的尺寸
if self.target_transform is not None:
target = self.target_transform(target, width, height) #获取target
if self.transform is not None:
target = np.array(target)
img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) #对图像、target进行转换
# to rgb
img = img[:, :, (2, 1, 0)] #opencv读入图像的顺序是BGR,该操作将图像转为RGB
# img = img.transpose(2, 0, 1)
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
return torch.from_numpy(img).permute(2, 0, 1), target, height, width #返回image、label、宽高.这里的permute(2,0,1)是将原有的三维(28,28,3)变为(3,28,28),将通道数提前,为了统一torch的后续训练操作。
def pull_image(self, index):
img_id = self.ids[index]
return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
def pull_anno(self, index):
img_id = self.ids[index]
anno = ET.parse(self._annopath % img_id).getroot()
gt = self.target_transform(anno, 1, 1)
return img_id[1], gt
def pull_tensor(self, index):
return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
DataLoader
实际的使用过程中,使用DataLoader,批量的读入数据。不知道什么原因,在windows下执行的时候,worker只能设置为0,否则跑不起来。DataLoader实现了一个并行读入图像、标签的功能。
data_loader = data.DataLoader(dataset, args.batch_size, #数据loader
num_workers=args.num_workers,
shuffle=True, collate_fn=detection_collate,
pin_memory=True)
data_loader可以用以下两种方式调用
#方法一
for (img,target) in data_loader
print(img,shape) #height,width,channels
print(target)
#方法二:
batch_iterator=iter(data_loader)
img,target=next(batch_iterator)
data_loader获取到的数据为(batch,channels,height,width),复合图像网络层需要的数据定义。
网络定义
网络定义在ssd.py文件中。网络定义继承于torch.nn.Module。torch的优势之一是能够进行自动求导的过程。当定义了网络的正向的传播方向,会依照结构进行反向的传播过程,因此在网络的定义过程中只需要定义每个层以及对应层的forward功能。
作者将ssd网络定义分成两个部分。第一步获取各个模块的层,用list进行装填。第二步实现forward的串接
def build_ssd(phase, size=300, num_classes=21):
if phase != "test" and phase != "train":
print("ERROR: Phase: " + phase + " not recognized")
return
if size != 300:
print("ERROR: You specified size " + repr(size) + ". However, " +
"currently only SSD300 (size=300) is supported!")
return
base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
add_extras(extras[str(size)], 1024),
mbox[str(size)], num_classes) #获取三部分需要的卷积层
return SSD(phase, size, base_, extras_, head_, num_classes)
multibox生成三部分网络层,分别用list装填(实际上,head_为两个module,分别用于计算类别和位置偏移量)
base_为基础网络部分可更换,extras_为基础网络后降采样的部分,head_为类别、位置偏移量计算的卷积部分。
super() 函数是用于调用父类(超类)的一个方法