文章目录

  • 1. FocalLoss的应用场景
  • 2. 二分类场景下FocalLoss原理解释
  • 2.1 FocalLoss如何调节正负样本权重
  • 2.2 FocalLoss如何调节难易样本权重
  • 2.3 整合上述过程,完成FocalLoss
  • 2.4 Pytorch 实现FocalLoss
  • 3. 多分类场景下的FocalLoss
  • 3.1 FocalLoss调节多分类的类别权重
  • 3.2 FocalLoss调节多分类难易样本权重
  • 3.3 整合上述过程,完成多分类的FocalLoss
  • 3.4 Pytorch 实现多分类FocalLoss

1. FocalLoss的应用场景

学一个东西,首先要知道这个东西是干嘛用的。

FocalLoss主要有两个作用,这也决定了它的应用场景:

  1. FocalLoss可以调节正负样本的loss权重。这意味着,当正负样本数量及其不平衡时,可以考虑使用FocalLoss。
  2. FocalLoss可以调节难易样本的loss权重。这意味着,当训练样本的难易程度不平衡时,可以考虑使用FocalLoss

这也是“Focal Loss”的名字的含义,把目光聚焦(Focal)在那些“少的,难的”样本上。

虽然大部分博客讨论FocalLoss都是在目标检测场景下,但其实FocalLoss其他场景下都可以用。

举个NLP的应用场景:

  1. 当我们在情感分类(好评/差评)时,若99%都是好评,只有1%是差评,就可以考虑使用FocalLoss通过loss来调节数据不平衡问题。
  2. 情感分类问题有些样本很难,例如:“我家狗吃了你的菜连夜给我做了四菜一汤”。而有些样本很简单,例如“差评,太难吃了”。这种场景下,FocalLoss可以帮助调节难易样本的loss权重,从而更好的学习到难样本的特征。

2. 二分类场景下FocalLoss原理解释

本节会分别讨论FocalLoss是如何实现其两个功能的,然后再进行整合。

2.1 FocalLoss如何调节正负样本权重

二分类问题我们通常使用交叉熵计算Loss,损失函数如下:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_权重

其中CE是CrossEntropy的缩写,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_02

是预测结果,例如0.8。

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_权重_03

假设我们99%的样本都是负样本,那么最终计算出的loss负样本占比极大。要进行调节,很简单,只需要乘个权重就行了。比如:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_04

我们让正样本和负样本的loss给个9:1的权重就行了。将其0.9写成变量

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_05

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_多分类_06

其中,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_07

2.2 FocalLoss如何调节难易样本权重

当我们在训练二分类问题时,经过sigmoid后最终的输出是0到1的概率,表示为正样本的概率是多少。

那假设标签为1的样本:

  • 若预测为为0.95,意味该样本是一个比较简单的样本。
  • 若预测值为0.65,意味着该样本稍微有点难
  • 若预测值为0.28,意味着该样本非常难。

负样本同理。即 预测值距离真值越远,则样本越难

难样本想要多学习,那就给它的loss分个较大的权重,简单样本易学习,那就给个较小的权重。那我们可以直接用它的难易程度给它分权重嘛,例如:

假设标签为1的样本:

  • 若预测为为0.95,意味该样本是一个比较简单的样本。权重为 (1-0.95) = 0.05
  • 若预测值为0.65,意味着该样本稍微有点难。权重为 (1-0.65) = 0.35
  • 若预测值为0.28,意味着该样本非常难。权重为 (1-0.28) = 0.72

按照这个思路,我们就可以得到如下损失函数:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_权重_08

这样你可能还不过瘾,你想让简单样本权重更低,难样本权重更高,那么也很简单,只需要加个平方就行了,这样小的会更小,大的会更大。这样我们会得到如下公式:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_09

但你可能会觉得平方太小或太大,那么我们把平方写成超参数

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_10

,此时公式就变成了如下:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_11

这样我们就完成了难易样本权重的调节。最后再总结一下参数

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_10


  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_13

  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_多分类_14

  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_多分类_14

  • 通常取
  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_16

2.3 整合上述过程,完成FocalLoss

整合过程很简单,把

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_05


FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_10

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_19

这样写稍显难看,所以我们定义两个新的变量

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_20


FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_21

, 其中:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_22

那么FocalLoss就可以写成如下的最终公式:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_23

这就是FocalLoss的公式。

2.4 Pytorch 实现FocalLoss

import torch
from torch import nn


class BinaryFocalLoss(nn.Module):
"""
参考 https://github.com/lonePatient/TorchBlocks
"""

def __init__(self, gamma=2.0, alpha=0.25, epsilnotallow=1.e-9):
super(BinaryFocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.epsilon = epsilon

def forward(self, input, target):
"""
Args:
input: model's output, shape of [batch_size, num_cls]
target: ground truth labels, shape of [batch_size]
Returns:
shape of [batch_size]
"""
multi_hot_key = target
logits = input
# 如果模型没有做sigmoid的话,这里需要加上
# logits = torch.sigmoid(logits)
zero_hot_key = 1 - multi_hot_key
loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
return loss.mean()


if __name__ == '__main__':
m = nn.Sigmoid()
loss = BinaryFocalLoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
print("loss:", output)
output.backward()

3. 多分类场景下的FocalLoss

有了前面二分类的基础,多分类就影刃而解了。

3.1 FocalLoss调节多分类的类别权重

假设我们有个三分类的场景,y=(1, 2, 3),他们的样本数量分别是100个,2000个和10000个。

那么此时我们的

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_20

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_25

在多分类场景下,我们的

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_05

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_27

其中

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_28

表示有

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_28

在大部分博客甚至开源项目上,在多分类问题上

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_30

3.2 FocalLoss调节多分类难易样本权重

同样,假设我们有个三分类的场景,y=(1, 2, 3),对于某个样本的预测结果如下:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_31

  • 若标签为1,那么构造难易程度的调制因子时只需要考虑
  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_32

  • 若标签为2,同理,调制因子只需要考虑
  • FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_33

  • ,同时说明这个样本很难。
  • 若标签为3,同理。

为了达到上述目的,我们可以使用one-hot向量来把不关心的非标签值给抹去,即:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_多分类_34

此时,我们把其与调制因子结合,为

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_多分类_35

,为:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_36

这里,我们将 one-hot 向量用

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_37

表示。这里的

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_38

3.3 整合上述过程,完成多分类的FocalLoss

综上所述,在多分类场景下,FocalLoss的公式变成了如下:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_39

这里,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_20


FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_21

的含义与二分类不同,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_20

为一个列表,里面是每个类别的权重。而

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_21

是输出的概率分布,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_37

举个实际的例子来看一下该公式:

假设我们有个三分类的场景,y=(1, 2, 3),其中

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_深度学习_45


FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_46

。对于样本

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_47

,输出的概率分布为

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_48

,,则FocalLoss为:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_49

从上面例子可以看出,因为one-hot的存在,真正对loss起作用的其实只有样本所在的那一行。

因此,我们可以将FocalLoss公式改进为如下:

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_50

其中

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_51

为当前样本的类别,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_分类_52

表示类别c对应的权重,

FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现_机器学习_53

3.4 Pytorch 实现多分类FocalLoss

class FocalLoss(nn.Module):
"""
参考 https://github.com/lonePatient/TorchBlocks
"""

def __init__(self, gamma=2.0, alpha=1, epsilnotallow=1.e-9, device=None):
super(FocalLoss, self).__init__()
self.gamma = gamma
if isinstance(alpha, list):
self.alpha = torch.Tensor(alpha, device=device)
else:
self.alpha = alpha
self.epsilon = epsilon

def forward(self, input, target):
"""
Args:
input: model's output, shape of [batch_size, num_cls]
target: ground truth labels, shape of [batch_size]
Returns:
shape of [batch_size]
"""
num_labels = input.size(-1)
idx = target.view(-1, 1).long()
one_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)
one_hot_key = one_hot_key.scatter_(1, idx, 1)
logits = torch.softmax(input, dim=-1)
loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
loss = loss.sum(1)
return loss.mean()


if __name__ == '__main__':
loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
print(output)
output.backward()