🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

HRNet原理详解篇

写在前面

Hello,大家好,我是小苏👦🏽👦🏽👦🏽

今天我打算来给大家介绍一个新的专题——姿态估计。先让我来搜搜姿态估计,看看百度出来的结果,如下:

关键点检测——HRNet原理详解篇_数据

看到这些图,我觉得大家应该还是蛮熟悉的。这些都在对人体关键点进行检测,在计算机视觉领域,关键点检测是一个非常常见的任务,那么他和姿态估计有什么差异呢?我谈谈我的理解,它们之间确实是存在一定的区别,我感觉用“使用关键点检测技术来实现姿态估计”这句话来表示它们的关系是比较贴切的。也就是说,关键点检测是一项技术,而姿态估计是一种应用。

那么今天所讲的HRNet其实就是一个实现关键点检测任务的网络,作者是我们的中国人——王井东老师。话不多说,让我们进入到本节的HRNet网络原理的讲解中。🚀🚀🚀

大家阅读此篇博客前强烈建议先了解一下COCO数据集关键点检测标注文件,我已经写了相关博客,点击☞☞☞了解详情。

姿态估计概述

在具体介绍HRNet的网络结构之前,我想先给大家介绍一下姿态估计概述,包括常见方法、数据集和评价指标以及应用场景,为此,我绘制了一个思维导图供大家查看,如下:

关键点检测——HRNet原理详解篇_2d_02

【注:上图的一些细节可能看不清楚,需要的可以私信我,发Xmind源文件】

那么我们再来看看本文介绍的HRNet属于上述思维导图的哪种方法,其属于–2D姿态估计–>单人检测–>基于热力图–🌱🌱🌱

整体框架

我们先来看看实现的效果,如下图所示:

关键点检测——HRNet原理详解篇_HRNet_03

当我们将一张图片输入HRNet网络后,会得到一个输出的特征图,然后对输出的特征图做一些后处理,就可以得到在原图上关键点的坐标。


看了上面的图,我想你大概知道HRNet实现了一个什么样的功能了,下面我们将来详细分析一下HRNet的网络结果:

图片来自B站霹雳吧啦wz

【这个图显示的不是很清楚,大家点击这个链接下载查看:)】

我们可以大体来看一下这个结果,其实看上去并不是很复杂,主要还是将不同尺寸的特征进行融合,这里就不带大家分析为什么这么设计了。【哈哈哈哈因为我也不知道咋分析🍀🍀🍀大家感兴趣的可以看看关于HRNet对王井东老师的采访,看看当时他的灵感是怎么来的,可以点击☞☞☞前往观看。】


我觉得大家要搞清楚网络结构到底是如何实现的,自己动手调试是必要的。所以我也不会再对这个网络结构做太细致的介绍,只说一些需要注意的点。

  • 上图中粉红色的Conv是指一个CBA结构,即卷积、BN和激活函数,橙色的Conv2d表示卷积
  • Layer1就是resnet中的layer1,我们可以看看相关代码,如下:
self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64)
        )

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
  • 剩下的就是transitionStage的结构,这里的代码我觉得写的很巧妙,我就不细说了,大家自己动手去调试调试叭,是很容易的。

可能大家想吐槽,这小节感觉什么也没说,只是放了一个网络结构图,一点解析都没有。确实是这样哈,这是因为我觉得这部分难度不大,大家完全可以自己看明白,更重要的是大家应该更关注王井东老师涉及这个网络的构思,想想他当时是怎么想出这个网络的,关于这点,可以去看前文给出的对王井东老师采访的视频。

原理详解

在上一小节,为大家介绍了HRNet的网络结构。在这一小节中,我想和大家唠唠这个网络的过程及原理。首先对于COCO数据集中的一张尺寸为H×W的3通道图片,我们会对齐进行一系列数据增强手段,如仿射变换、随机水平翻转等等,经过数据增强后,我们会将原来H×W×3的图像resize到256×196×3的大小,之后这个256×196×3的图像就作为网络的输入。

这里的数据增强是理解HRNet的重难点,因为在数据增强的过程中会涉及关键点位置的变换。HRNet中做了HalfBodyAffineTransformRandomHorizontalFlip等数据增强手段,关于这些我们将在HRNet源码实战篇详细为大家介绍。

除了数据增强外,由于HRNet是基于热力图的关键点检测方法,所以我们需要将关键点映射成热力图,那么其是怎么将关键点映射成热力图的呢,这里我们来结合代码来详细看看这一步骤:

首先,先来看看其__init__函数:

def __init__(self,
             heatmap_hw: Tuple[int, int] = (256 // 4, 192 // 4),
             gaussian_sigma: int = 2,
             keypoints_weights=None):
    self.heatmap_hw = heatmap_hw
    self.sigma = gaussian_sigma
    self.kernel_radius = self.sigma * 3
    self.use_kps_weights = False if keypoints_weights is None else True
    self.kps_weights = keypoints_weights

    # generate gaussian kernel(not normalized)
    kernel_size = 2 * self.kernel_radius + 1
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    x_center = y_center = kernel_size // 2
    for x in range(kernel_size):
        for y in range(kernel_size):
            kernel[y, x] = np.exp(-((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2))
            # print(kernel)

            self.kernel = kernel

这段主要定义了存储热力图的宽度和高度、高斯标准差和关键点权重等信息,然后生成了一个大小为13*13的高斯核kernel(中间的值大,往四周扩散值越来越小),如下图所示:

关键点检测——HRNet原理详解篇_HRNet_04

接着我们来看__call__函数:

def __call__(self, image, target):
    kps = target["keypoints"]
    num_kps = kps.shape[0]
    kps_weights = np.ones((num_kps,), dtype=np.float32)
    if "visible" in target:
        visible = target["visible"]
        kps_weights = visible

        heatmap = np.zeros((num_kps, self.heatmap_hw[0], self.heatmap_hw[1]), dtype=np.float32)
        heatmap_kps = (kps / 4 + 0.5).astype(np.int)  # round
        for kp_id in range(num_kps):
            v = kps_weights[kp_id]
            if v < 0.5:
                # 如果该点的可见度很低,则直接忽略
                continue

                x, y = heatmap_kps[kp_id]
                ul = [x - self.kernel_radius, y - self.kernel_radius]  # up-left x,y
                br = [x + self.kernel_radius, y + self.kernel_radius]  # bottom-right x,y
                # 如果以xy为中心kernel_radius为半径的辐射范围内与heatmap没交集,则忽略该点(该规则并不严格)
                if ul[0] > self.heatmap_hw[1] - 1 or \
                ul[1] > self.heatmap_hw[0] - 1 or \
                br[0] < 0 or \
                br[1] < 0:
                    # If not, just return the image as is
                    kps_weights[kp_id] = 0
                    continue

                    # Usable gaussian range
                    # 计算高斯核有效区域(高斯核坐标系)
                    g_x = (max(0, -ul[0]), min(br[0], self.heatmap_hw[1] - 1) - ul[0])
                    g_y = (max(0, -ul[1]), min(br[1], self.heatmap_hw[0] - 1) - ul[1])
                    # image range
                    # 计算heatmap中的有效区域(heatmap坐标系)
                    img_x = (max(0, ul[0]), min(br[0], self.heatmap_hw[1] - 1))
                    img_y = (max(0, ul[1]), min(br[1], self.heatmap_hw[0] - 1))

                    if kps_weights[kp_id] > 0.5:
                        # 将高斯核有效区域复制到heatmap对应区域
                        heatmap[kp_id][img_y[0]:img_y[1] + 1, img_x[0]:img_x[1] + 1] = \
                        self.kernel[g_y[0]:g_y[1] + 1, g_x[0]:g_x[1] + 1]

                        if self.use_kps_weights:
                            kps_weights = np.multiply(kps_weights, self.kps_weights)

                            plot_heatmap(image, heatmap, kps, kps_weights)

                            target["heatmap"] = torch.as_tensor(heatmap, dtype=torch.float32)
                            target["kps_weights"] = torch.as_tensor(kps_weights, dtype=torch.float32)

                            return image, target

我给大家解释一下可能难理解的地方:

heatmap_kps = (kps / 4 + 0.5).astype(np.int)

这句是将关键点的坐标映射到热力图上,因为最终的热力图相较于原图像下采样了4倍,所以要除以4,这里加上0.5是起到一个四舍五入的作用,因为后面要将坐标转为int格式。

ul = [x - self.kernel_radius, y - self.kernel_radius]  # up-left x,y
br = [x + self.kernel_radius, y + self.kernel_radius]  # bottom-right x,y

这两句是找到某个关键点对应热力图的左上角(ul)和右下角(br)的坐标,kernel_radius是高斯核的半径,如下图所示,hw坐标系表示热力图坐标,中间的⚪表示关键点在热力图上的坐标,坐标为(x,y):

关键点检测——HRNet原理详解篇_关键点检测_05

# 如果以xy为中心kernel_radius为半径的辐射范围内与heatmap没交集,则忽略该点(该规则并不严格)
if ul[0] > self.heatmap_hw[1] - 1 or \
        ul[1] > self.heatmap_hw[0] - 1 or \
        br[0] < 0 or \
        br[1] < 0:
    # If not, just return the image as is
    kps_weights[kp_id] = 0
    continue

这句是看看以xy为中心kernel_radius为半径的辐射范围内(就是上图中的正方形区域内)与heatmap(就是上图的hw坐标系,当然其h=64,w=48,并不是无线延长的坐标系)有没有交集,若无交集,则将kps_weights[kp_id]置为0。

# Usable gaussian range
# 计算高斯核有效区域(高斯核坐标系)
g_x = (max(0, -ul[0]), min(br[0], self.heatmap_hw[1] - 1) - ul[0])
g_x = (max(0, -ul[1]), min(br[1], self.heatmap_hw[0] - 1) - ul[1])
# image range
# 计算heatmap中的有效区域(heatmap坐标系)
img_x = (max(0, ul[0]), min(br[0], self.heatmap_hw[1] - 1))
img_y = (max(0, ul[1]), min(br[1], self.heatmap_hw[0] - 1))

这几句分别计算高斯核有效区域和heatmap中的有效区域,为下一步将将高斯核有效区域复制到heatmap对应区域做准备:

if kps_weights[kp_id] > 0.5:
    # 将高斯核有效区域复制到heatmap对应区域
    heatmap[kp_id][img_y[0]:img_y[1] + 1, img_x[0]:img_x[1] + 1] = \
        self.kernel[g_y[0]:g_y[1] + 1, g_x[0]:g_x[1] + 1]

这几句到底实现了什么呢,其实就是把高斯核kernel复制到热力图中,至于复制到什么位置,复制多少,就看g_x、g_x、img_x和img_y了。我调试帮助大家理解一下,比如现在g_x=(0,12)、g_y=(0,12)、img_x=(25,37)和img_y=(12,24)。

关键点检测——HRNet原理详解篇_ide_06

g_x[0]:g_x[1]+1=0:12+1、g_y[0]:g_y[1]+1=0:12+1表示复制kernel的x方向(0,12+1)范围内的值和y方向(0,12+1)范围内,你看kernel的shape你会发现,其大小为13*13,那么这个(0,12+1)就是复制整个kernel数组**(这里刚好是整个数组,你调试的话会有不同的结果)**:

关键点检测——HRNet原理详解篇_关键点检测_07

那么把这个数组复制到哪里呢,其实就是热力图的对应区域,这是就用到了img_x=(25,37)和img_y=(12,24),将其复制到热力图w方向(25,37+1)和h方向(12,24+1)的位置,如下图所示:

关键点检测——HRNet原理详解篇_数据_08


这里展示一下图片和产生热力图的结果,如下图所示:【注:由于不是同一次调试的结果,所以这里的图像和之前的有所差异】

关键点检测——HRNet原理详解篇_数据_09


最后我还想说一个小点,就是kps_weights这个值,表示的是关键点的权重,如果没有指定这个参数,那么其就默认是关键点的可见性,如果指定了这个参数,其会让原来的可见性乘这个指定的参数,在HRNet中,这个kps_weights默认如下:

关键点检测——HRNet原理详解篇_HRNet_10


热力图构建完成后,我们一切的准备工作就做完了,接下来就会将这个256×196×3的图像送入HRNet中,其会得到一个大小为64×48×17的特征图。我们可以看到输出特征图的宽和高相较于输入下采样了4倍,然后这个17表示有17个关键点的特征图,每个特征图的尺寸都为64×48大小的。【注:COCO数据集中标注了17个人体关键点位置,不清楚的可以看看我这篇对COCO数据集关键点检测的分析。】

其实我们对这个64×48×17大小的特征图进行一些后处理操作,就可以得到17个关键点的坐标信息,具体怎么做的,我们来结合代码为大家介绍一下。首先我们要将得到的特征图变成坐标,实现方法如下:

def get_max_preds(batch_heatmaps):
    """
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    """
    assert isinstance(batch_heatmaps, torch.Tensor), 'batch_heatmaps should be torch.Tensor'
    assert len(batch_heatmaps.shape) == 4, 'batch_images should be 4-ndim'

    batch_size, num_joints, h, w = batch_heatmaps.shape
    heatmaps_reshaped = batch_heatmaps.reshape(batch_size, num_joints, -1)
    maxvals, idx = torch.max(heatmaps_reshaped, dim=2)

    maxvals = maxvals.unsqueeze(dim=-1)
    idx = idx.float()

    preds = torch.zeros((batch_size, num_joints, 2)).to(batch_heatmaps)     

    preds[:, :, 0] = idx % w  # column 对应最大值的x坐标
    preds[:, :, 1] = torch.floor(idx / w)  # row 对应最大值的y坐标

    pred_mask = torch.gt(maxvals, 0.0).repeat(1, 1, 2).float().to(batch_heatmaps.device)

    preds *= pred_mask
    return preds, maxvals

这段代码实现了什么呢,我来解释一下,首先会将刚刚(1,17,64,48)的特征图resize到(1,17,3072),即将高度和宽度合并成一维,这个维度表示有17个一维向量(17个表示17个关键点),每个一维向量有3072个值,我们计算出每个一维向量即3072个值中的最大值和最大值对应的索引,然后通过最大值索引来计算关键点的坐标,为了方便大家理解,作图如下:

关键点检测——HRNet原理详解篇_关键点检测_11

最后还需要将设置一个模板,过滤掉maxvals小于0的坐标,如下:

pred_mask = torch.gt(maxvals, 0.0).repeat(1, 1, 2).float().to(batch_heatmaps.device)
preds *= pred_mask

这个maxvals其实就是一个置信度分数,这步操作完后,我们就有了关键点在特征图上的坐标和置信度分数了,接下来其实就只要将这个坐标映射到原图上就可以了,如下:

for i in range(coords.shape[0]):
    preds[i] = affine_points(preds[i], trans[i])
def affine_points(pt, t):
    ones = np.ones((pt.shape[0], 1), dtype=float)
    pt = np.concatenate([pt, ones], axis=1).T
    new_pt = np.dot(t, pt)
    return new_pt.T

这里是通过仿射变换的逆变换将关键点从特征图映射回原图上的,因为我们在图像预处理过程中使用了仿射变换。但是代码中还对刚刚得到的坐标做了后处理,如下:

# post-processing
if post_processing:
    for n in range(coords.shape[0]):
        for p in range(coords.shape[1]):
            hm = batch_heatmaps[n][p]
            px = int(math.floor(coords[n][p][0] + 0.5))
            py = int(math.floor(coords[n][p][1] + 0.5))
            if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                diff = torch.tensor(
                    [
                        hm[py][px + 1] - hm[py][px - 1],
                        hm[py + 1][px] - hm[py - 1][px]
                    ]
                ).to(batch_heatmaps.device)
                coords[n][p] += torch.sign(diff) * .25

preds = coords.clone().cpu().numpy()

这段代码主要是想得到更加精确的坐标,画图帮大家理解:

关键点检测——HRNet原理详解篇_HRNet_12

这样有了关键点的坐标,就可以将其映射到原图上,下图展示了映射一个关键点nose的结果,其它的关键点原理相同:

关键点检测——HRNet原理详解篇_2d_13

损失计算

HRNet中的损失计算非常简单,使用的是MSE均方误差,关键代码如下:

self.criterion = torch.nn.MSELoss(reduction='none')
loss = self.criterion(logits, heatmaps).mean(dim=[2, 3])

这里criterion有两个传入的值,一个是logits,一个是heatmaps,这两个都是什么呢,我来解释一下。logits很好理解,其就是网络的输出结果,是一个64×48×17大小的特征图;那么heatmaps是什么呢,我们知道,损失计算肯定是要用到预测值和真实值,logits是网络输出,logits是预测值,那么heatmaps就应该是真实值。但是heatmaps到底是什么呢?我们关键点检测的真实值不是关键的检测的坐标吗【坐标的话应该维度应该是2×17】,怎么会是一个heatmaps?【两个进行MSE损失计算,heatmaps维度应该为64×48×17,和logits一致】

不知道大家能否想到,其实啊,这就是基于热力图(heatmaps)进行关键点检测的关键,如果标签是单纯的坐标,那么其实就是基于回归的方式实现关键点检测。又说回来,HRNet怎么将关键点坐标转换成热力图的呢?其实就是在对图像进行数据增强时使用了transforms.KeypointToHeatMap,关于此方法在HRNet源码详解篇有详细介绍,大家一定要去看,对你理解HRNet有很大帮助。

评价指标

在关键点检测的任务中,我们一般使用OKS来衡量预测keypoints和真实keypoints的相似程度,它取值在0~1之间,越大表示越相似,其表达式如下:

关键点检测——HRNet原理详解篇_HRNet_14

看到这个公式你懵了,我也懵了。🥀🥀🥀对相关变量做一定的解释:

  • 关键点检测——HRNet原理详解篇_ide_15表示第关键点检测——HRNet原理详解篇_ide_15个关键点
  • 关键点检测——HRNet原理详解篇_关键点检测_17表示第关键点检测——HRNet原理详解篇_ide_15个预测关键点和真实关键点的欧式距离
  • s表示groundtruth中所占面积的平方根,是可以直接获取的,那么关键点检测——HRNet原理详解篇_ide_19即表示面积
  • 关键点检测——HRNet原理详解篇_HRNet_20表示第关键点检测——HRNet原理详解篇_ide_15个骨骼点的归一化因子,是个常数
  • 关键点检测——HRNet原理详解篇_数据_22表示第i个关键点的可见性
  • 关键点检测——HRNet原理详解篇_2d_23表示当x为True时,值为1,当x为False时,值为0。关键点检测——HRNet原理详解篇_2d_24表示当关键点在图像中标注了(v=1或v=2),则为1,没有标注(v=0)则为0

结合COCO中相关的代码来解释一下,主要看看这个公式是不是和代码一致,代码如下:

def computeOks(self, imgId, catId):
    p = self.params
    # dimention here should be Nxm
    gts = self._gts[imgId, catId]
    dts = self._dts[imgId, catId]
    inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
    dts = [dts[i] for i in inds]
    if len(dts) > p.maxDets[-1]:
        dts = dts[0:p.maxDets[-1]]
    # if len(gts) == 0 and len(dts) == 0:
    if len(gts) == 0 or len(dts) == 0:
        return []
    ious = np.zeros((len(dts), len(gts)))
    sigmas = p.kpt_oks_sigmas
    vars = (sigmas * 2)**2
    k = len(sigmas)
    # compute oks between each detection and ground truth object
    for j, gt in enumerate(gts):
        # create bounds for ignore regions(double the gt bbox)
        g = np.array(gt['keypoints'])
        xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
        k1 = np.count_nonzero(vg > 0)
        bb = gt['bbox']
        x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
        y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
        for i, dt in enumerate(dts):
            d = np.array(dt['keypoints'])
            xd = d[0::3]; yd = d[1::3]
            if k1>0:
                # measure the per-keypoint distance if keypoints visible
                dx = xd - xg
                dy = yd - yg
            else:
                # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
                z = np.zeros((k))
                dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
                dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
            e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
            if k1 > 0:
                e=e[vg > 0]
            ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
    return ious

我们主要来看最后几行:

e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
if k1 > 0:
    e=e[vg > 0]
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]

先来看这句:e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2,它就对应公式的关键点检测——HRNet原理详解篇_关键点检测_25,其中dx**2 + dy**2表示预测关键点和真实关键点的欧式距离的平方,即关键点检测——HRNet原理详解篇_2d_26。vars为表示的是关键点检测——HRNet原理详解篇_关键点检测_27,是个常数, (gt['area']+np.spacing(1))表示关键点检测——HRNet原理详解篇_关键点检测_28,加上np.spacing(1)是防止分母为0。

综上,关键点检测——HRNet原理详解篇_HRNet_29,再来看if k1 > 0: e=e[vg > 0]表示如果存在可见关键点,就从之前计算的 e 中筛选出可见的关键点对应的值。**【注意一下代码中的K1和公式中关键点检测——HRNet原理详解篇_数据_30表示的不是一个,代码中K1表示关键点的个数,即关键点检测——HRNet原理详解篇_2d_31】**这步对应公式关键点检测——HRNet原理详解篇_数据_32

最后再经过ious[i, j] = np.sum(np.exp(-e)) / e.shape[0],这里的e.shape[0]其实就是K1,表示可见关键点个数,那么经过这步之后,ious[i, j]的值就表示OKS,即关键点检测——HRNet原理详解篇_2d_33

小结

这节就为大家介绍到这里啦,我觉得看到这里大家都是懵懵的,没关系,因为HRNet我是准备分三小结来为大家介绍,所以这节内容写的较为简略。在下一节,我将花一万字给大家好好解析HRNet的源码,大家看完所有的内容,再回来消化消化这部分,说不定有意想不到的收获喔。🍊🍊🍊

拜拜啦~~~我们下期见。🥗🥗🥗

参考链接

HRNet论文🍁🍁🍁

HRNet网络简介🍁🍁🍁