tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

根据条件返回元素(x或y)。 如果x和y都为空,那么这个操作返回条件的真元素的坐标。坐标在二维张量中返回,其中第一个维度(行)表示真实元素的数量,第二个维度(列)表示真实元素的坐标。记住,输出张量的形状可以根据输入中有多少个真值而变化。索引按行主顺序输出。如果两者都是非零,则x和y必须具有相同的形状。如果x和y是标量,条件张量必须是标量。如果x和y是更高秩的向量,那么条件必须是大小与x的第一个维度匹配的向量,或者必须具有与x相同的形状。条件张量充当一个掩码,它根据每个元素的值选择输出中对应的元素/行是来自x(如果为真)还是来自y(如果为假)。如果条件是一个向量,x和y是高秩矩阵,那么它选择从x和y复制哪一行(外维),如果条件与x和y形状相同,那么它选择从x和y复制哪一个元素。

参数:

  • condition:  bool类型的张量
  • x:  一个张量,它的形状可能和条件相同。如果条件为秩1,x的秩可能更高,但是它的第一个维度必须与条件的大小匹配
  • y:  与x形状和类型相同的张量
  • name:  操作的名称(可选)

返回值:

  • 一个与x, y相同类型和形状的张量,如果它们是非零的话。一个带形状(num_true, dim_size(condition))的张量。

可能引发的异常:

  • ValueError: When exactly one of x or y is non-None.

官方文档中只有tf.where(input, name=None)一种用法,在实际应用中发现了另外一种使用方法tf.where(input, a,b),其中a,b均为尺寸一致的tensor,作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值,下面使用代码来说明:

import tensorflow as tf
import numpy as np
sess=tf.Session()

a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])

print(sess.run(tf.equal(a,1)))

Output:
----------------------
[[ True False False]
 [False  True  True]]
----------------------


                        

import tensorflow as tf
import numpy as np
sess=tf.Session()

a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])

print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))

Output:
-------------
[[ 3 -1 -2]
 [-3  5  6]]
-------------

 

import tensorflow as tf
import numpy as np
sess=tf.Session()

a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])

print(sess.run(tf.where(tf.equal(a,0),a1,1-a1)))


Output:
-------------
[[-2  2  3]
 [ 4 -4 -5]]
-------------

 

 

对比两行代码的不同可以发现该函数的作用。

  • 不同之处为tf.equal(a,0)和tf.equal(a,1)
  • tf.equal()返回tensor中满足条件的位置