深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现
KeepAugment: A Simple Information-Preserving Data Augmentation Approach
PDF: ​​​https://arxiv.org/pdf/2011.11778.pdf​​​ PyTorch代码: ​​https://github.com/shanglianlm0525/CvPytorch​​ PyTorch代码: ​​https://github.com/shanglianlm0525/PyTorch-Networks​

1 概述

KeepAugment 提出了一种简单但高效的方法,称为"保持增强",以提高增强图像的保真度。KeepAugment 首先使用显著图来检测原始图像上的重要区域,然后在增广过程中保留这些信息区域。这种信息保存策略使我们能够生成更多保真的训练样本。

2 KeepAugment

KeepAugment主要可以分为(1)计算saliency map(2)保留重要的矩形区域这两个步骤

深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_pytorch

2-1 Saliency map

KeepAugment 通过 vanilla gradient 方法获取saliency map,具体来说,给定图像x及其对应标签logit value 深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_深度学习_02,KeepAugment 将深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_sed_03设为vanilla gradients的绝对值深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_pytorch_04。对于RBG图像,采用通道最大值,以获取每个像素(i,j)的单个显着性值。重要性得分定义的公式如下所示:

深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_sed_05

2-2 Selective-Cut

KeepAugment对于区域级的数据增强方法(如:cutout), 通过确保被切割的区域不会具有较大的重要性得分来控制数据增强的保真度。即随机采样要切割的区域S,直到其重要性得分I(S,x,y)小于给定的阈值τ。

深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_pytorch_06

2-3 Selective-Paste

图像级变换共同修改了整个图像,所以我们通过粘贴具有较高重要性的随机区域来确保变换的保真度。即对满足阈值τ的I(S, x, y)>τ的区域S进行均匀采样,然后将原始图像x的regionS粘贴到x’。

深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_sed_07

3 Efficient Implementation of KeepAugment

KeepAugment要求在每个训练步骤中通过反向传播来计算 saliency map。直接计算的话会导致计算成本增加两倍。因此作者提出两种不同的近似策略。

深度学习论文: KeepAugment: A Simple Information-Preserving Data Augmentation Approach及其PyTorch实现_深度学习_08


Low resolution based approximation: 对于给定的图像x,首先生成一个低分辨率副本,然后计算其saliency map;将低分辨率saliency map映射到其相应的原始分辨率。

**Early head based approximation:**在网络前面层添加loss,然后用此loss来生成saliency map。