DAB-DETR是在吸收了Deformable-DETR、Conditional-DETR、Anchor-DETR等工作的基础上完善而来的。其主要贡献为将query初始化为x,y,w,h四维坐标形式。

DAB-DETR主要是对Decoder模块进行改进，因此博主也主要对Decoder模块的代码进行解析。

## 位置编码的温度值调整

``````
class PositionEmbeddingSineHW(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
        super().__init__()
        # Channels per spatial axis; the final embedding has 2 * num_pos_feats channels.
        self.num_pos_feats = num_pos_feats
        # Independent temperatures for the height and width axes — the DAB-DETR tweak
        # over the single-temperature DETR sine embedding.
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        # A custom scale only makes sense on coordinates normalized to [0, 1].
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        # tensor_list bundles the batched image tensor with its padding mask.
        x = tensor_list.tensors
        # import ipdb; ipdb.set_trace()
        # NOTE(review): y_embed / x_embed are used below but never defined in this
        # excerpt — the upstream lines that build them (cumulative sums over the
        # non-padded mask) appear to have been dropped; restore them from the
        # original DAB-DETR source before running this code.
        if self.normalize:
            eps = 1e-6
            # Rescale cumulative row/column positions into [0, scale].
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Frequency divisors for the width axis: temperatureW ** (2i / num_pos_feats).
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx
        # Frequency divisors for the height axis: temperatureH ** (2i / num_pos_feats).
        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty
        # Interleave sin (even channels) and cos (odd channels), then flatten the
        # trailing (num_pos_feats/2, 2) pair back into num_pos_feats channels.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y- then x-embeddings and move channels first: (B, H, W, C) -> (B, C, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # import ipdb; ipdb.set_trace()
        return pos
``````

## Transformer整体架构

src：由backbone提取的特征信息，shape初始为 torch.Size([2, 256,19,24]) 后变为torch.Size([456, 2, 256])
mask：对图像进行补全掩码信息，shape初始为 torch.Size([2, 19, 24]) 后展平为 torch.Size([2, 456])

refpoint_embed：参考点坐标编码，即object_query，torch.Size([300, 4])。在Decoder模块使用，其是在DAB-DETR模块中定义并初始化的：self.refpoint_embed = nn.Embedding(num_queries, query_dim)，初始为torch.Size([300, 4])，后经过refpoint_embed = refpoint_embed.unsqueeze(1).repeat(1, bs, 1)变为torch.Size([300, 2, 4])。

pos_embed：位置编码信息，shape初始为 torch.Size([2, 256,19,24]) 后变为torch.Size([456, 2, 256])

``````# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape  #初始为2，256，19，24
src = src.flatten(2).permute(2, 0, 1)#拉平：
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
refpoint_embed = refpoint_embed.unsqueeze(1).repeat(1, bs, 1)

``memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)``

``````num_queries = refpoint_embed.shape[0]
if self.num_patterns == 0:
tgt = torch.zeros(num_queries, bs, self.d_model, device=refpoint_embed.device)
else:
tgt = self.patterns.weight[:, None, None, :].repeat(1, self.num_queries, bs, 1).flatten(0, 1) # n_q*n_pat, bs, d_model
refpoint_embed = refpoint_embed.repeat(self.num_patterns, 1, 1) # n_q*n_pat, bs, d_model``````

``````hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, refpoints_unsigmoid=refpoint_embed)
return hs, references``````

## Encoder模块构建

DAB-DETR的Encoder模块与DETR并没有太大差别。

### EncoderLayer

`src_key_padding_mask`:将图片补全shape为【2，456】
`src`：通过ResNet提取到的特征，由二维转为一维，shape为 torch.Size([456, 2, 256])
`pos`：位置编码信息，原本有两种，分别为sincos位置编码与可学习的位置编码；此外，DAB-DETR还提出一种可分别调整宽、高方向温度值的位置编码方式。shape为 torch.Size([456, 2, 256])
`src2` 通过self-attention获得，shape为 torch.Size([456, 2, 256])，随后经过dropout层，norm层。最终的输出结果为src：torch.Size([456, 2, 256])，将该结果送入Decoder。

``````q = k = self.with_pos_embed(src, pos)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src``````

``````def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos``````

## Encoder模块

Encoder由6个EncoderLayer构成。

``````
class TransformerEncoder(nn.Module):
    """Stack of cloned encoder layers plus the DAB-DETR per-layer query-scale MLP."""

    def __init__(self, encoder_layer, num_layers, norm=None, d_model=256):
        super().__init__()
        # Deep-copy the prototype layer num_layers times.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        # MLP producing per-position scales used to rescale the positional embedding.
        self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.norm = norm

    def forward(self, src,
                pos: Optional[Tensor] = None):
        output = src
        for layer_id, layer in enumerate(self.layers):
            # rescale the content and pos sim
            pos_scales = self.query_scale(output)
            # NOTE(review): the call into `layer(...)` that consumes pos_scales is
            # missing from this excerpt — as written the loop never updates `output`;
            # restore `output = layer(output, ..., pos=pos * pos_scales)` from the
            # original DAB-DETR source.
        if self.norm is not None:
            output = self.norm(output)
        return output
``````

## Decoder模块代码实现

``output = tgt``

`reference_points`归一化，shape仍为torch.Size([300, 2, 4])

``reference_points = refpoints_unsigmoid.sigmoid()``

``````obj_center = reference_points[..., :self.query_dim]  #torch.Size([300, 2, 4])
query_sine_embed = gen_sineembed_for_position(obj_center) #torch.Size([300,2,512])
query_pos = self.ref_point_head(query_sine_embed) #torch.Size([300, 2, 256])``````

`gen_sineembed_for_position`方法如下：

``````def gen_sineembed_for_position(pos_tensor):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
return pos``````

`ref_anchor_head`是一个MLP，`self.ref_anchor_head = MLP(d_model, d_model, 2, 2)`：输入维度为256，中间层宽度为256，输出维度为2，网络层数为2。
refHW_cond为torch.Size([300, 2, 2])
query_sine_embed 初始为torch.Size([300, 2, 512]),经过下面`query_sine_embed = query_sine_embed[...,:self.d_model] * pos_transformation`后变为torch.Size([300, 2, 256])，该句代码意思为取前256维

``````if self.query_scale_type != 'fix_elewise':#执行
if layer_id == 0:#第一层时执行
pos_transformation = 1
else:
pos_transformation = self.query_scale(output) #query_scale为MLP
else:
pos_transformation = self.query_scale.weight[layer_id]
#取出  query_sine_embed的前256维，即x,y与pos_transformation相乘
query_sine_embed = query_sine_embed[...,:self.d_model] * pos_transformation
if self.modulate_hw_attn:
refHW_cond = self.ref_anchor_head(output).sigmoid() #将其送入MLP后进行归一化 torch.Size([300, 2, 2])
query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)``````

``````output = layer(output, memory, tgt_mask=tgt_mask,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0))``````

## 第一层DecoderLayer模块

### Self_Attention

`tgt`即上一层DecoderLayer的输出结果，此时全为0，shape为 torch.Size([300, 2, 256])

``````if not self.rm_self_attn_decoder:
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.sa_qcontent_proj(tgt)      # target is the input of the first decoder layer. zero by default.
q_pos = self.sa_qpos_proj(query_pos)
k_content = self.sa_kcontent_proj(tgt)
k_pos = self.sa_kpos_proj(query_pos)
v = self.sa_v_proj(tgt)
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
q = q_content + q_pos
k = k_content + k_pos
#tgt2为Attention计算结果，torch.Size([300, 2, 256])
# ========== End of Self-Attention =============
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)``````

### Cross_Attention

``````q_content = self.ca_qcontent_proj(tgt)#torch.Size([300, 2, 256])
k_content = self.ca_kcontent_proj(memory)#torch.Size([456, 2, 256])
v = self.ca_v_proj(memory)#torch.Size([456, 2, 256])

k_pos = self.ca_kpos_proj(pos)#对K进行位置编码，pos来自于Encoder。torch.Size([456, 2, 256])``````

``````if is_first or self.keep_query_pos:#self.keep_query_pos默认为False
q_pos = self.ca_qpos_proj(query_pos)# query_pos:torch.Size([300, 2, 256])
q = q_content + q_pos
k = k_content + k_pos
else:
q = q_content
k = k_content``````

``````q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)# q分头：torch.Size([300, 2, 8, 32])
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)#query_sine_embed即
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
#q经过拼接变为torch.Size([300, 2, 512])
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)#torch.Size([456, 2, 512])``````

``tgt2 = self.cross_attn(query=q, key=k, value=v, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]``

``````return multi_head_attention_forward(
self.in_proj_weight, self.in_proj_bias,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,

## 锚点更新策略

``````if self.bbox_embed is not None:
if self.bbox_embed_diff_each_layer:#是否共享参数：false
tmp = self.bbox_embed[layer_id](output)
else:
tmp = self.bbox_embed(output)#经过MLP获得output偏移量x,y,w,h torch.Size([300, 2, 4])
# import ipdb; ipdb.set_trace()
tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
new_reference_points = tmp[..., :self.query_dim].sigmoid()
if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(self.norm(output))``````

## 第二层DecoderLayer模块

``````obj_center = reference_points[..., :self.query_dim]
query_sine_embed = gen_sineembed_for_position(obj_center)

``self.query_scale = MLP(d_model, d_model, d_model, 2)``

## Decoder模块

``````if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)``````

``````if self.return_intermediate:
if self.bbox_embed is not None:
return [
torch.stack(intermediate).transpose(1, 2),
torch.stack(ref_points).transpose(1, 2),
]
else:
return [
torch.stack(intermediate).transpose(1, 2),
reference_points.unsqueeze(0).transpose(1, 2)
]``````

Transformer的decoder模块最终返回结果：

``````hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, refpoints_unsigmoid=refpoint_embed)``````

## DAB-DETR整体模块

``````if not self.bbox_embed_diff_each_layer:#是否权值共享
reference_before_sigmoid = inverse_sigmoid(reference)#反归一化
tmp = self.bbox_embed(hs)#torch.Size([6, 2, 300, 4])
tmp[..., :self.query_dim] += reference_before_sigmoid
outputs_coord = tmp.sigmoid()``````

outputs_coord值即预测框的xywh

pred_logits 为类别预测（这里是91类）torch.Size([2, 300, 91])

pred_boxes为box框预测 torch.Size([2, 300, 4])

aux_outputs为前5层DecoderLayer的结果。为list，有5个值。