返回一个one-ho张量。

tf.one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)

索引中由索引表示的位置取值on_value,而所有其他位置取值off_value。on_value和off_value必须具有匹配的数据类型。如果还提供了dtype,则它们必须与dtype指定的数据类型相同。如果没有提供on_value,它将默认为值1,类型为dtype。如果没有提供off_value,它将默认值为0,类型为dtype。如果输入索引的秩为N,那么输出的秩为N+1。新轴是在维度轴上创建的(缺省值:新轴附加在末尾)。如果索引是标量,则输出形状将是长度深度向量。如果索引是长度特征向量,则输出形状为:

  features x depth if axis == -1
  depth x features if axis == 0

如果索引是一个形状为[batch, features]的矩阵(batch),则输出形状为:

  batch x features x depth if axis == -1
  batch x depth x features if axis == 1
  depth x batch x features if axis == 0

如果没有提供dtype,它将尝试假设数据类型为on_value或off_value(如果传入了一个或两个值)。如果不提供on_value、off_value或dtype, dtype将默认为tf.float32。

注意:如果需要非数值数据类型输出(tf.string, tf.bool等),,on_value和off_value都必须提供给one_hot。

参数:

  • indices:指标的张量。
  • depth:定义one hot维的深度。
  • on_value:一个标量,定义了当索引[j] = i时要填充输出的值。
  • off_value:一个标量,定义当索引[j] != i.(默认值:0)时要填充输出的值。
  • axis:要填充的轴(默认值:-1,一个新的最内层轴)。
  • dtype:输出张量的数据类型。
  • name:操作的名称(可选)。

返回值:

  • output: 一个独热向量。

可能产生的异常:

  • TypeError: If dtype of either on_value or off_value don't match dtype
  • TypeError: If dtype of on_value and off_value don't match one another

例:

import numpy as np
import tensorflow as tf
 
#第一个案例:最简单的普通形式
indices=[0,1,2]   #rank=1
depth=3
a=tf.one_hot(indices,depth)   #rank=2,输出为[3,3]
 
#设置了on_value和off_value的大小,注意观察结果,
# 你会发现indices中数字指示的地方用on_value代替,
# 其他地方用的是off_value
indices=[0,2,-1,1]  #rank=1
depth=3
b=tf.one_hot(indices,depth,on_value=5.0,off_value=1.0,axis=-1)
#rank=2,输出的shape为[4,3]
 
#第三个案例是一个矩阵,rank=2,那么输出的rank=3,depth=3,axis=-1,
#输出的shape为[2,2,3]
indices=[[0,2],[-1,1]]
depth=3
c=tf.one_hot(indices,depth,on_value=1,axis=-1)
 
with tf.Session() as sess:
    print(sess.run(a))
    print("...............")
    print(sess.run(b))
    print("---------------")
    print(sess.run(c))

Output:
-----------------------------
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
...............
[[5. 1. 1.]
 [1. 1. 5.]
 [1. 1. 1.]
 [1. 5. 1.]]
---------------
[[[1 0 0]
  [0 0 1]]

 [[0 0 0]
  [0 1 0]]]
-----------------------------