1.in_top_k(predictions, targets, k, name=None)
Args:
predictions: 一种tf.float的张量。一个batch_size的x类张量。预测值,one-hot编码,size为[batch_size,label类别数]
如在cifar10的分类上为[128,10]
targets: 一个张量。必须是下列类型之一:int32, int64。size只有一维,也就意味着不能是one-hot编码的。理由举例就知道了
k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。
name : 操作的名称(可选)。
举例:假设预测值logits为【10,5】的张量,5表示预测为5个类别,labels就为【10】
import tensorflow as tf
logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(logits.eval())
print(labels.eval())
print(top_1_op.eval())
print(top_2_op.eval())
结果:
解读第一个top_1_op.eval()值False的来源;
首先看第一行,前1个最大的值的索引为1,而labels第一个值为0,不想等,所以为False.以此类推.......
解读第一个top_2_op.eval()值False的来源;
首先看第一行,前2个最大的值的索引分别为1和0,而labels第一个值为0,有一个与labels相等,所以为True.以此类推.......
从这个过程我们就可以知道,labels如果也是一个one-hot编码的话,即使找到logits前一个最大值的索引,你要同labels(假设为【0,0,1,0,0】)去比较值相等,显然是不可能的,因为labels本身就不是一个值,而是一个列表,你怎么将一个数和一个列表比较相不相等呢?所以,用这种方法labels是不能够用one-hot编码的。
举个错误的例子,将这里的labels改为one-hot编码。看看报错怎么样。
import tensorflow as tf
logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
n_classes = 5
labels = tf.one_hot(labels, depth=n_classes)
print(labels)
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(logits.eval())
print(labels.eval())
print(top_1_op.eval())
print(top_2_op.eval())
Tensor("one_hot:0", shape=(10, 5), dtype=float32)
TypeError: Value passed to parameter 'targets' has DataType float32 not in list of allowed values: int32, int64