pytorch 损失函数详解及自定义方法

损失函数是机器学习与深度学习解决问题中非常重要的一部分,可以说,损失函数给出了问题的定义,也就是需要优化的目标:怎么样可以认为这个模型是否够好、怎样可以认为当前训练是否有效等。

pytorch框架上手十分方便,也为我们定义了很多常用的损失函数。当然,面对特殊的应用场景或实际问题,往往也需要自行定义损失函数。

本文首先介绍如何自定义损失函数,再选择一些常用或经典的损失函数进行介绍。

1·Pytorch自定义损失函数

首先pytorch为我们提供的损失函数也位于torch.nn下。其实注意观察我们会发现torch.nn中的如各种layer和loss,在torch.nn.functional中也能找到,二者的区别在pytorch官网的discuss中能找到一些讨论,链接如下:
https://discuss.pytorch.org/t/whats-the-difference-between-torch-nn-functional-and-torch-nn/681

个人观点是torch.nn实现了torch.nn.functional中功能的封装,简化了使用接口,并且保存了训练参数、状态信息等。这一点我目前还没有仔细研究,以后可能会更详细的讨论,也欢迎大家发表看法。

PyTorch中的损失函数继承自其基类 _Loss(或_Loss的派生类 _WeightedLoss)例如 MSE Loss继承自_Loss:

@weak_module
class MSELoss(_Loss):

而CrossEntropyLoss继承自_WeightedLoss

@weak_module
class CrossEntropyLoss(_WeightedLoss):

查看_Loss和_WeightedLoss的源码如下,也使派生自abstract的Module类:

class _Loss(Module):
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction


class _WeightedLoss(_Loss):
    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)

在损失函数的定义中,首先有一大段关于损失函数的介绍,包括其数学定义、使用场景、参数含义、使用样例等。接着是重写的__init__()和forward()方法,跟我们自定义网络等是一样的。
除去说明部分,MSE loss的定义是这样的:

__constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(MSELoss, self).__init__(size_average, reduce, reduction)

    @weak_script_method
    def forward(self, input, target):
        return F.mse_loss(input, target, reduction=self.reduction)

可见其forward方法确实调用了torch.nn.functional中的方法(这里的F)

这里也不需要自己实现backward()方法,根据pytorch自动求导机制,使用Tensor的各种算数操作中有一个Tensor的require_grad=True,这个操作就会具有自动求导的功能。关于自动求导我之后会再详细讨论。

通过上述的观察我们发现,只要仿照pytorch本身的损失函数定义方法,继承自_Loss或_WeightedLoss,但显然按照python习惯,命名是头部下划线的是不希望被外部访问的,所以我们只要继承自抽象类Module就好了,同时重写__init__()和forward()方法即可。例如:

class MyMSELoss(torch.nn.Module):
    def __init__(self):
      super(MyMSELoss, self).__init__()
        
    def forward(self, output, label):
      return torch.mean(torch.pow(output - label),2)

一般来说,loss的返回值要是一个标量。

上述的方法可以认为是利用pytorch本身对Tensor操作的支持实现了一些原本没有提供的损失函数。当然我们也可以通过numpy或scipy等完全自定义损失的计算。这种方法需要我们自定义损失函数的类继承自torch.autograd.Function并且实现forward()和backward()函数。

2·部分损失函数介绍

损失函数的简单介绍与使用其实在官网说明的比较细致了如MSE Loss :

@weak_module
class MSELoss(_Loss):
    r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
    each element in the input :math:`x` and target :math:`y`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \left( x_n - y_n \right)^2,

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), &  \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  &  \text{if reduction} = \text{'sum'.}
        \end{cases}

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`n` elements each.

    The sum operation still operates over all the elements, and divides by :math:`n`.

    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.

    Args:
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means, any number of additional
          dimensions
        - Target: :math:`(N, *)`, same shape as the input

    Examples::

        >>> loss = nn.MSELoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(MSELoss, self).__init__(size_average, reduce, reduction)

    @weak_script_method
    def forward(self, input, target):
        return F.mse_loss(input, target, reduction=self.reduction)

不仅介绍了损失函数的意义、数学定义:
pytorch内置函数 pytorch函数说明_pytorch内置函数
pytorch内置函数 pytorch函数说明_自定义_02
也给出了参数的定义、使用样例等等,非常详细。

pytorch的损失函数都具有 size_average, reduce, reduction三个参数,reduce=False时,只计算Loss,其返回值也是一个张量,包含batch内每个数据得到的loss;reduce=True时,若size_average=True对batch内计算的每个数据对应的loss求均值,size_average=False时求和。按照官方的说法,size_average, reduce两个参数已被弃用(看上边的描述就知道用起来很繁琐),所以按照默认值设为None即可。

现在只使用第三个参数reduction即可,这个参数输入的是字符串,其取值有 'sum', 'mean', 'None'三种,分别对应了对batch内loss求和、求均值和直接返回三种。

这里简要介绍一下常用的几种损失函数:
BCELoss:二元交叉熵
pytorch内置函数 pytorch函数说明_自定义_03
二元交叉熵一般用于二分类任务

NLLLoss:
pytorch内置函数 pytorch函数说明_自定义_04
使用时输入是input和target,input包含每一类的对数概率的Tensor,target则是数据所属的类别索引,适用于多分类任务。

CrossEntropyLoss:交叉熵
pytorch内置函数 pytorch函数说明_pytorch内置函数_05
交叉熵损失一般用于多类别的分类任务。不过这里的交叉熵跟我们平常理解的不太一样, 其实可以看做是log-softmax和NLLLoss的结合。

MSELoss:前文中有提到,可以认为是输入与标签L2范数的平方

L1Loss:
pytorch内置函数 pytorch函数说明_ide_06