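The notes below annotate the implementation details of the large-scale face classification layer in insightface.
The complete classification code is at the end of the page.
The basic idea behind large-scale face classification is that the classification layer can be split linearly, so its weights can be divided evenly across N GPUs. As the number of classes grows, only the number of GPUs needs to increase; the portion of the classification-layer parameters held on each GPU can stay the same size.
(1) The code below shows how the classification layer is split. If the number of classes is not divisible by the number of GPUs, the remainder is handed to the lowest-ranked GPUs, one extra class each. self.num_local is the number of classes the current GPU is responsible for, and self.class_start is the offset of the current GPU's weights within the global classification weight.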
self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
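The split rule can be checked in isolation. The following is a minimal sketch in plain Python; the helper name partition_classes is hypothetical and not part of insightface. It applies the same two formulas to every rank.

```python
def partition_classes(num_classes: int, world_size: int):
    """Return a (class_start, num_local) pair for every rank."""
    base, extra = divmod(num_classes, world_size)
    parts = []
    for rank in range(world_size):
        # The first `extra` ranks each take one additional class.
        num_local = base + int(rank < extra)
        # Every earlier rank holds `base` classes, plus one for each
        # earlier rank that received an extra class.
        class_start = base * rank + min(rank, extra)
        parts.append((class_start, num_local))
    return parts


parts = partition_classes(num_classes=10, world_size=4)
print(parts)  # [(0, 3), (3, 3), (6, 2), (8, 2)]
```

With 10 classes on 4 GPUs, ranks 0 and 1 hold three classes each and ranks 2 and 3 hold two, and the ranges are contiguous and cover all classes exactly once.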
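(2) The features on every GPU are gathered to form the full feature set total_features, which is then multiplied with the local classification weights on the current GPU to obtain the local logits. Every GPU holds a copy of the full total_features.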
total_features = torch.zeros(
    size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
total_features.requires_grad = True
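(3) Likewise, the labels on every GPU are gathered to form the full label set total_label. Because each GPU stores only part of the classification weights, only the samples whose labels fall in the current GPU's class range have their ground-truth class on this GPU. index_positive marks those samples, and their labels are shifted to local indices by subtracting self.class_start. The remaining samples have no ground-truth class on this GPU and are marked with -1.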
total_label = torch.zeros(size=[self.batch_size * self.world_size], device=self.device,dtype=torch.long)
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
self.sample(total_label)
@torch.no_grad()
def sample(self, total_label):
    """
    Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
    `num_sample`.
    total_label: tensor
        Label after all gather, which cross all GPUs.
    """
    index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
    total_label[~index_positive] = -1
    total_label[index_positive] -= self.class_start
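The remapping is easier to see on a small example. Below is a minimal sketch in plain Python lists instead of tensors; the helper name localize_labels is hypothetical. It mirrors the three lines above for a rank that owns classes 3, 4 and 5.

```python
def localize_labels(total_label, class_start, num_local):
    """Map global labels to local indices; labels owned by other ranks become -1."""
    local = []
    for y in total_label:
        if class_start <= y < class_start + num_local:
            local.append(y - class_start)  # positive on this rank
        else:
            local.append(-1)               # ground truth lives on another rank
    return local


total_label = [0, 4, 7, 9, 3]
rank1_label = localize_labels(total_label, class_start=3, num_local=3)
print(rank1_label)  # [-1, 1, -1, -1, 0]
```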
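(4) The logits below are the local logits of all samples on the current GPU. An all-reduce with MAX obtains the global maximum logit of each sample, which is subtracted before exp so that the exponential cannot overflow. The class probability is \(p_k = \frac{e^{x_k-x_m}}{\sum_j e^{x_j-x_m}}\): exp is computed first, an all-reduce with SUM then adds up the exponentials across GPUs, and dividing logits_exp by that sum in place turns it into the class probabilities of each sample.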
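The two all-reduce steps can be simulated in a single process. The following is a minimal sketch for one sample, assuming each inner list holds the logits that one rank computed for its own classes; the function name sharded_softmax is hypothetical, and the built-in max and sum stand in for dist.all_reduce with MAX and SUM.

```python
import math


def sharded_softmax(logit_shards):
    """logit_shards[r] holds one sample's logits for the classes on rank r."""
    # Stand-in for all_reduce(MAX): every rank learns the global maximum.
    global_max = max(max(shard) for shard in logit_shards)
    # Each rank exponentiates only its own shifted logits.
    exp_shards = [[math.exp(x - global_max) for x in shard] for shard in logit_shards]
    # Stand-in for all_reduce(SUM): every rank learns the global denominator.
    global_sum = sum(sum(shard) for shard in exp_shards)
    # Each rank normalises its own slice; together the slices form the softmax.
    return [[e / global_sum for e in shard] for shard in exp_shards]


shards = [[1000.0, 1001.0], [999.0, 1002.0]]
probs = sharded_softmax(shards)
flat = [p for shard in probs for p in shard]
print(flat)
```

Logits around 1000 would overflow math.exp if used directly; after subtracting the global maximum every exponent is at most 0, and the slices held by the two simulated ranks still sum to 1.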
(5) index = torch.where(total_label != -1)[0]
retrieves the indices of the samples whose labels fall on the current GPU. A one-hot representation of the labels is then built for these samples.
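(6) For ease of description, the loss computation is explained with prob in place of grad. Let N be the total number of images across all GPUs and let the current GPU store classification weights of shape [d, M]; prob then has shape [N, M]. Only samples whose labels fall on this GPU are considered, that is, only the rows selected by index. For each of those samples, the probability at its true-label position is written into the matching row of loss. An all-reduce then sums loss across GPUs. The sum is element-wise, so loss is still an array of length N afterwards. For example, with 2 GPUs and N = N1 + N2, the loss computed on gpu0 has only part of its entries filled in, and the loss on gpu1 has the remaining entries filled in. all_reduce with SUM therefore collects the true-label prob of every sample, and only afterwards is the negative log taken and averaged.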
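The assembly of the loss vector can be sketched without any GPU. The example below is a minimal illustration in plain Python, with made-up probabilities and a hypothetical helper name assemble_loss: each simulated rank fills in the true-label probability only for the samples it owns and leaves 0 elsewhere, and an element-wise sum stands in for dist.all_reduce with SUM.

```python
import math


def assemble_loss(per_rank_target_prob):
    """per_rank_target_prob[r][i] is p_c of sample i if rank r owns its label, else 0."""
    n = len(per_rank_target_prob[0])
    # Stand-in for all_reduce(SUM): element-wise sum over ranks.
    total = [sum(rank[i] for rank in per_rank_target_prob) for i in range(n)]
    # Clamp before the log, as in the original code, then take the mean of -log.
    return -sum(math.log(max(p, 1e-30)) for p in total) / n


rank0 = [0.5, 0.0, 0.25, 0.0]   # owns the labels of samples 0 and 2
rank1 = [0.0, 0.8, 0.0, 0.1]    # owns the labels of samples 1 and 3
loss_v = assemble_loss([rank0, rank1])
print(loss_v)
```

Because exactly one rank writes a non-zero value for each sample, the sum simply merges the partial vectors rather than mixing probabilities.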
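(7) \(p_k = \frac{e^{x_k-x_m}}{\sum_j e^{x_j-x_m}}\) is the probability that a sample belongs to class k. Let c be the true label of sample x. The loss produced by this sample and its derivative are: \(L=-\log(p_c)=-(x_c-x_m)+\log\left(\sum_j e^{x_j-x_m}\right)\)
\(\frac{\partial L}{\partial x_k}= \begin{cases} p_k, & k\neq c \\ p_k - 1, & k = c \end{cases}\)
The gradient with respect to the input x is therefore \(p_k\), with an additional -1 at the true-label position. This is the gradient for a single sample; in practice the gradient is taken over all samples in one forward pass. Because the loss is a mean, the gradient must be averaged as well, which means dividing by the total number of samples across all GPUs (self.batch_size * self.world_size).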
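The gradient formula can be checked numerically. The sketch below is a plain-Python check for a single sample, not part of the original code, and the function names are hypothetical. It compares \(p_k - \mathbb{1}[k=c]\) with a central finite difference of \(L=-\log(p_c)\).

```python
import math


def loss(x, c):
    """Cross-entropy -log(p_c), computed with the max-shift for stability."""
    m = max(x)
    return -(x[c] - m) + math.log(sum(math.exp(v - m) for v in x))


def analytic_grad(x, c):
    """p_k for every class, minus 1 at the true label c."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)
    return [e / s - (1.0 if k == c else 0.0) for k, e in enumerate(exps)]


def numeric_grad(x, c, eps=1e-6):
    """Central finite difference of the loss with respect to each logit."""
    grads = []
    for k in range(len(x)):
        hi = x[:k] + [x[k] + eps] + x[k + 1:]
        lo = x[:k] + [x[k] - eps] + x[k + 1:]
        grads.append((loss(hi, c) - loss(lo, c)) / (2 * eps))
    return grads


x, c = [2.0, -1.0, 0.5, 3.0], 2
g_analytic = analytic_grad(x, c)
g_numeric = numeric_grad(x, c)
print(g_analytic)
print(g_numeric)
```

The two lists agree to within the finite-difference error, and the analytic gradient sums to zero because the probabilities sum to 1.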
# calculate loss (prob stands in for grad, as in the description in (6))
loss = torch.zeros(prob.size()[0], 1, device=prob.device)
loss[index] = prob[index].gather(1, total_label[index, None])
dist.all_reduce(loss, dist.ReduceOp.SUM)
loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
total_label, norm_weight = self.prepare(label, optimizer)
total_features = torch.zeros(
    size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
total_features.requires_grad = True
logits = self.forward(total_features, norm_weight)
logits = self.margin_softmax(logits, total_label)
with torch.no_grad():
    max_fc = torch.max(logits, dim=1, keepdim=True)[0]
    dist.all_reduce(max_fc, dist.ReduceOp.MAX)
    # calculate exp(logits) and all-reduce
    logits_exp = torch.exp(logits - max_fc)
    logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
    dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
    # calculate prob
    logits_exp.div_(logits_sum_exp)
    # get one-hot
    grad = logits_exp
    index = torch.where(total_label != -1)[0]
    one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
    one_hot.scatter_(1, total_label[index, None], 1)
    # calculate loss
    loss = torch.zeros(grad.size()[0], 1, device=grad.device)
    loss[index] = grad[index].gather(1, total_label[index, None])
    dist.all_reduce(loss, dist.ReduceOp.SUM)
    loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
    # calculate grad
    grad[index] -= one_hot
    grad.div_(self.batch_size * self.world_size)
logits.backward(grad)
Full code:
import logging
import os
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter
class PartialFC(Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
        """
        rank: int
            Unique process(GPU) ID from 0 to world_size - 1.
        local_rank: int
            Unique process(GPU) ID within the server from 0 to 7.
        world_size: int
            Number of GPU.
        batch_size: int
            Batch size on current rank(GPU).
        resume: bool
            Select whether to restore the weight of softmax.
        margin_softmax: callable
            A function of margin softmax, eg: cosface, arcface.
        num_classes: int
            Total number of classes (class centers) across all ranks; each rank stores about
            num_classes // world_size of them, required.
        sample_rate: float
            The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
            can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
        embedding_size: int
            The feature dimension, default is 512.
        prefix: str
            Path for save checkpoint, default is './'.
        """
        super(PartialFC, self).__init__()
        #
        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
                    raise IndexError
                logging.info("softmax weight resume successfully!")
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight init!")
                logging.info("softmax weight mom init!")
        else:
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
        self.index = None
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    def save_params(self):
        """ Save softmax weight for each rank on prefix
        """
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        """
        Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
        `num_sample`.
        total_label: tensor
            Label after all gather, which cross all GPUs.
        """
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        """ Partial fc forward, `logits = X * sample(W)`
        """
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        """ Set updated weight and weight_mom to memory bank.
        """
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        """
        get sampled class centers for cal softmax.
        label: tensor
            Label tensor on each rank.
        optimizer: opt
            Optimizer for partial fc, which need to get weight mom.
        """
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight

    def forward_backward(self, label, features, optimizer):
        """
        Partial fc forward and backward with model parallel
        label: tensor
            Label tensor on each rank(GPU)
        features: tensor
            Features tensor on each rank(GPU)
        optimizer: optimizer
            Optimizer for partial fc
        Returns:
        --------
        x_grad: tensor
            The gradient of features.
        loss_v: tensor
            Loss value for cross entropy.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True
        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)
        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)
            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
            # calculate prob
            logits_exp.div_(logits_sum_exp)
            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)
            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)
        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient all-reduce
        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v