机器学习中的分类问题常用到交叉熵作为损失函数,那么Pytorch中如何使用交叉熵损失函数呢?这就涉及到torch.nn中的三个类:nn.LogSoftmax、nn.NLLLoss、nn.CrossEntropyLoss,今天剖析一下这几个类,便于以后查找。

一、nn.LogSoftmax

softmax常用在网络的输出层上,以得到每个类别的概率,顾名思义,nn.LogSoftmax就是对softmax的结果取了一个log。

打印 pytorch学习率 pytorch打印loss_机器学习

来源:https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html#torch.nn.LogSoftmax

注意,使用这个类时最好要指定dim,即沿着tensor的哪一个维度做softmax,如果不指定,也能做,那么沿着哪一维做呢?通过层层查看源码,我们发现:

打印 pytorch学习率 pytorch打印loss_机器学习_02

来源:https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#log_softmax

打印 pytorch学习率 pytorch打印loss_html_03

来源:https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#log_softmax

如果不指定dim,torch会调用到_get_softmax_dim函数,该函数会根据输入tensor的维度总数指定一个,0、1、3维tensor,沿着第0维做;其他的,沿着第1维做。同时,该函数给我们了警告,告诉我们应该人为指定dim.

二、nn.NLLLoss

这个loss的全称叫负对数似然loss(negative log likelihood loss),里面的操作我认为其实就是取了个负。。如果不考虑前面可选的权重的话:

打印 pytorch学习率 pytorch打印loss_pytorch_04

来源:https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss

因为它要求输入就已经是每个类的对数值了。值得注意的是,target并不是one-hot向量,而是范围在[0, C-1]之间的类别索引。这一点和后面要说的CrossEntropyLoss是一样的。

三、nn.CrossEntropyLoss

nn.CrossEntropyLoss可以看作是nn.LogSoftmax和nn.NLLLoss的结合:

打印 pytorch学习率 pytorch打印loss_打印 pytorch学习率_05

来源:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropy#torch.nn.CrossEntropyLoss

即对输入数据先做log_softmax,再过NLLLoss。

注意体会红框内的计算过程,可以理解为什么它要求target不是one-hot向量,而是类别索引的标号:target中的类别序号yn实际上给出了计算softmax时的使用的输入向量x的索引,如此一来,

打印 pytorch学习率 pytorch打印loss_打印 pytorch学习率_06

计算的恰好是第yn类对应的似然值。也就是说x_n,yn代表的是某batch中第n个样本的yn个值。

这个计算形式跟我们熟悉的交叉熵的形式:

loss = -(

打印 pytorch学习率 pytorch打印loss_html_07

ylogy'+(1-y)log(1-y') )

是等效的。

还有一点要注意的是,如果reduction不是'none'的话,默认会取平均,这个平均不仅是对batch取平均,如果有其它维度的话,对其它维度也是平均的