交叉熵损失函数是用来度量两个概率分布间的差异性,有关交叉熵损失函数的原理在这篇博客中讲解得很好。而本文主要对以下几种tensorflow中常用的交叉熵损失函数进行比较和总结:

  • tf.losses.sigmoid_cross_entropy
  • tf.nn.sigmoid_cross_entropy_with_logits
  • tf.losses.softmax_cross_entropy
  • tf.nn.softmax_cross_entropy_with_logits_v2
  • tf.losses.sparse_softmax_cross_entropy

1. tf.losses.sigmoid_cross_entropy

import tensorflow as tf

batch_size = 4
num_classes = 2
'''
tf.losses.sigmoid_cross_entropy适用于二分类问题,是对logits先进行sigmoid再求交叉熵
args:
    logits:不经过sigmoid处理的神经网络输出,是分类器对每个类别打的分数,shape:[batch_size, num_classes]
    labels:真实标签值,shape:[batch_size, num_classes]
'''
logits = tf.constant([[9., 2.],
                      [1, 7.],
                      [5., 4.],
                      [2., 8.]])
labels = tf.constant([0, 1, 0, 1])
one_hot_labels = tf.one_hot(labels, depth=num_classes, dtype=tf.int32)
loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=one_hot_labels, logits=logits)

sess = tf.InteractiveSession()
print(loss.eval())

打印的结果为:loss:1.1991692

2. tf.nn.sigmoid_cross_entropy_with_logits

import tensorflow as tf

batch_size = 4
num_classes = 2
'''
tf.nn.sigmoid_cross_entropy_with_logits与tf.losses.sigmoid_cross_entropy的功能类似,主要差别如下:
    1. 前者要求logits和labels不仅要有相同的shape,还要有相同的type
    2. 前者的输出为一个list,后者的输出为一个具体的数值
args:
    logits:不经过sigmoid处理的神经网络输出,是分类器对每个类别打的分数,shape:[batch_size, num_classes]
    labels:真实标签值,shape:[batch_size,num_classes]
'''
logits = tf.constant([[9., 2.],
                      [1., 7.],
                      [5., 4.],
                      [2., 8.]])
# logits = tf.cast(logits, tf.float32)
labels = tf.constant([0, 1, 0, 1], dtype=tf.int32)
labels = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

sess = tf.InteractiveSession()
print(loss.eval())

打印的结果为:loss:1.1991692,与方法1的结果相同

3.tf.losses.softmax_cross_entropy

import tensorflow as tf

batch_size = 4
num_classes = 3
'''
tf.losses.softmax_cross_entropy适用于多分类问题,是对logits先进行softmax再求交叉熵
args:
    logits:不经过softmax处理的神经网络输出,是分类器对每个类别打的分数,shape:[batch_size, num_classes]
    labels:真实标签值,shape:[batch_size, num_classes]
'''
logits = tf.constant([[9., 2., 4.],
                      [1., 7., 3.],
                      [5., 4., 8.],
                      [2., 8., 9.]])
labels = tf.constant([1, 0, 2, 1])
one_hot_labels = tf.one_hot(labels, depth=num_classes, dtype=tf.int32)
loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)

sess = tf.InteractiveSession()
print(loss.eval())

打印结果为:loss:3.6020036

4.tf.nn.softmax_cross_entropy_with_logits_v2

import tensorflow as tf

batch_size = 4
num_classes = 3
'''
tf.nn.softmax_cross_entropy_with_logits_v2的功能与tf.losses.softmax_cross_entropy类似,两者的差别如下:
    1. 前者的输出为一个list,后者的输出为一个具体的值
args:
    logits:不经过softmax处理的神经网络输出,是分类器对每个类别打的分数,shape:[batch_size, num_classes]
    labels:真实标签值,shape:[batch_size, num_classes]
'''
logits = tf.constant([[9., 2., 4.],
                      [1., 7., 3.],
                      [5., 4., 8.],
                      [2., 8., 9.]])
labels = tf.constant([1, 0, 2, 1])
one_hot_labels = tf.one_hot(labels, depth=num_classes, dtype=tf.int32)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot_labels, logits=logits))

sess = tf.InteractiveSession()
print(loss.eval())

打印的结果为:loss:3.6020036,与方法3的结果相同

5.tf.losses.sparse_softmax_cross_entropy

import tensorflow as tf

batch_size = 4
num_classes = 3
'''
tf.losses.sparse_softmax_cross_entropy的功能与tf.losses.softmax_cross_entropy类似,两者的差别如下:
    1. 前者的参数labels不需要ont_hot编码
args:
    logits:不经过softmax处理的神经网络输出,是分类器对每个类别打的分数,shape:[batch_size, num_classes]
    labels:真实标签值,shape:[batch_size,]
'''
logits = tf.constant([[9., 2., 4.],
                      [1., 7., 3.],
                      [5., 4., 8.],
                      [2., 8., 9.]])
labels = tf.constant([1, 0, 2, 1])
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

sess = tf.InteractiveSession()
print(loss.eval())

打印的结果为:loss:3.6020036,与方法3和方法4的结果相同

总结:

  1. sigmoid交叉熵损失函数适用于二分类问题,softmax交叉熵损失函数适用于多分类问题。
  2. tf.nn模块中的损失函数输出为一个list,可用tf.reduce_mean()函数求均值作为最终损失;tf.losses模块中的损失函数输出为一个值。
  3. sparse_softmax_cross_entropy与softmax_cross_entropy的区别是前者的labels参数不需要one_hot编码,可以减少一些内存占用。