tf.where(
condition,
x=None,
y=None,
name=None
)
根据条件返回元素(x或y)。 如果x和y都为空,那么这个操作返回条件的真元素的坐标。坐标在二维张量中返回,其中第一个维度(行)表示真实元素的数量,第二个维度(列)表示真实元素的坐标。记住,输出张量的形状可以根据输入中有多少个真值而变化。索引按行主顺序输出。如果两者都是非零,则x和y必须具有相同的形状。如果x和y是标量,条件张量必须是标量。如果x和y是更高秩的向量,那么条件必须是大小与x的第一个维度匹配的向量,或者必须具有与x相同的形状。条件张量充当一个掩码,它根据每个元素的值选择输出中对应的元素/行是来自x(如果为真)还是来自y(如果为假)。如果条件是一个向量,x和y是高秩矩阵,那么它选择从x和y复制哪一行(外维),如果条件与x和y形状相同,那么它选择从x和y复制哪一个元素。
参数:
- condition: bool类型的张量
- x: 一个张量,它的形状可能和条件相同。如果条件为秩1,x的秩可能更高,但是它的第一个维度必须与条件的大小匹配
- y: 与x形状和类型相同的张量
- name: 操作的名称(可选)
返回值:
- 一个与x, y相同类型和形状的张量,如果它们是非零的话。一个带形状(num_true, dim_size(condition))的张量。
可能引发的异常:
-
ValueError
: When exactly one ofx
ory
is non-None.
官方文档中只有tf.where(input, name=None)一种用法,在实际应用中发现了另外一种使用方法tf.where(input, a,b),其中a,b均为尺寸一致的tensor,作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值,下面使用代码来说明:
import tensorflow as tf
import numpy as np
sess=tf.Session()
a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
print(sess.run(tf.equal(a,1)))
Output:
----------------------
[[ True False False]
[False True True]]
----------------------
import tensorflow as tf
import numpy as np
sess=tf.Session()
a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))
Output:
-------------
[[ 3 -1 -2]
[-3 5 6]]
-------------
import tensorflow as tf
import numpy as np
sess=tf.Session()
a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
print(sess.run(tf.where(tf.equal(a,0),a1,1-a1)))
Output:
-------------
[[-2 2 3]
[ 4 -4 -5]]
-------------
对比两行代码的不同可以发现该函数的作用。
- 不同之处为tf.equal(a,0)和tf.equal(a,1)
- tf.equal()返回tensor中满足条件的位置