文章目录

  • 重点提示
  • 使用场景
  • 公式
  • 数学背景
  • 用法


重点提示

注意,PyTorch的Cross Entropy Loss与其它框架的不同,因为PyTorch中该损失函数其实自带了“nn.LogSoftmax”与“nn.NLLLoss”两个方法。因此,在PyTorch的Cross Entropy Loss之前请勿再使用Softmax方法!

使用场景

当现在面临多分类问题(不限于二分类问题)需要Loss函数时,Cross Entropy Loss是一个很方便的工具。

公式

infonce loss代码pytorch_Cross
其中class为样本x的类别,x[class]为样本类别对应的预测分数;对于中间的公式,核心思想就是计算样本目标类别分值的softmax后取负对数作为分类损失。

数学背景

infonce loss代码pytorch_深度学习_02
易知s其实是分类正确的概率,介于0~1之间;对照着log函数的图像可以看出,当s=1时(此时分类完全正确),取完对数为0;当s接近0时(分类完全错误),取完对数为负无穷,取复数后变为正无穷;恰好可以充当loss函数:分类正确时损失为0,分类越错误损失越大。

除此之外,可以看出log函数越接近0,梯度越大;log函数越接近1,梯度越小。因此在更新参数时,当网络分类错的很离谱(loss较大时),求导后会得到比较大的梯度,从而大幅更新网络参数;随着网络正确率的升高,梯度也会逐渐平缓,渐渐进入“微调阶段”。

infonce loss代码pytorch_python_03

用法

loss = torch.nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()

其中,input为样本的预测结果矩阵,形状为(样本数量,类别数量),例如100个样本实现二分类形状就是(100,2),每一列的index分别表示对应的类别;target为标签向量,形状为(样本数量),其中为各样本对应的类别index。
假如是二分类,对应的最理想(完全正确)的预测结果应如下所示:

target:1,0,1

input:

0(类)

1(类)

0

1

1

0

0

1

target向量的长度为3,说明现在有三个样本,其中第一个、第三个样本的标签均为1,第二个样本的标签为0;而input矩阵的形状为(3,2),每行为对应样本的预测结果分值,而每列为对应类别的分值。当我们希望获得当前样本被分类为1的分值时,我们取第1列,对应的向量为(1,0,1);当我们希望获得当前样本被分类为0的分值时,我们取第0列,对应的向量为(0,1,0)

实际预测中,很少能达到这么完美的情况,加上CrossEntropyLoss一般与Softmax连用,因此input矩阵中的每个元素表示的其实是第i个样本(i行)被分类为j类(j列)的概率

我们以二分类为例,如下所示:
input:

0(类)

1(类)

0.3

0.7

0.6

0.4

0.2

0.8

该input矩阵表示的其实是第1个样本被分类到1类的概率是0.7,第2个样本被分类到1类的概率是0.4,第3个样本被分类到1类的概率是0.8。