对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

cross_entropy pytorch 输入 pytorch crossentropyloss weight_pytorch类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_02的张量,cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_03cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_04,在K维情况下,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_05
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_06之间的类索引,cross_entropy pytorch 输入 pytorch crossentropyloss weight_pytorch为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_08

  其中,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_09是输入,cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_10是目标值,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_11是weight,cross_entropy pytorch 输入 pytorch crossentropyloss weight_pytorch是类别数,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_13为batch size。在reduction不为'none'时(默认为'mean'),有:
cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_14

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_15
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的cross_entropy pytorch 输入 pytorch crossentropyloss weight_python_16

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])


3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_17
  其中,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_09为输入,cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_10为目标值,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_11为weight,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_13为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_22
  可以看出,当reduction'mean'时,即是对cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_23求加权平均值。weight参数默认为1,因此默认情况下,即是对cross_entropy pytorch 输入 pytorch crossentropyloss weight_损失函数_23求平均值。又cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_25,所以weight为默认值1时,cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_26。故此时,即是cross_entropy pytorch 输入 pytorch crossentropyloss weight_官网_27求平均后取负。 这部分对于公式(2)中的cross_entropy pytorch 输入 pytorch crossentropyloss weight_python_28

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

cross_entropy pytorch 输入 pytorch crossentropyloss weight_深度学习_29
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 cross_entropy pytorch 输入 pytorch crossentropyloss weight_pytorch_30,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference



  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss ↩︎
  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax ↩︎
  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss ↩︎