tf.cond()的用法

 

由于tensorflow使用的是graph的计算概念,在没有涉及控制数据流向的时候编程和普通编程语言的编程差别不大,但是涉及到控制数据流向的操作时,就要特别小心,不然很容易出错。这也是TensorFlow比较反直觉的地方。

在TensorFlow中,tf.cond()类似于c语言中的if...else...,用来控制数据流向,但是仅仅类似而已,其中差别还是挺大的。关于tf.cond()函数的具体操作,我参考了tf的说明文档。

format:tf.cond(pred, fn1, fn2, name=None)

Return :either fn1() or fn2() based on the boolean predicate `pred`.(注意这里,也就是说'fnq'和‘fn2’是两个函数)

arguments:`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have the same non-zero number and type of outputs('fnq'和‘fn2’返回的是非零的且类型相同的输出)

官方例子:

  1. ​z = tf.multiply(a, b)​
  2. ​result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))​

上面例子执行这样的操作,如果x<y则result这个操作是tf.add(x,z),反之则是tf.square(y)。这一点上,确实很像逻辑控制中的if...else...,但是官方说明里也提到

Since z is needed for at least one  branch of the cond,branch of the cond, the tf.mul operation is always executed, unconditionally.

因为z在cond函数中的至少一个分支被用到,所以

 

z = tf.multiply(a, b)

总是被无条件执行,这个确实很反直觉,跟我想象中的不太一样,按一般的逻辑不应该是不用到就不执行么?,然后查阅官方文档,我感受到了来之官方文档深深的鄙视0.0

 Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics.

翻译过来应该是:尽管这样的操作与TensorFlow的数据流模型一致,但是偶尔还是会令那些期望慵懒语法的用户吃惊。(应该是这么翻译的吧,淦,我就那个懒人0.0)

好吧,我就大概记录一下我自己的理解(如果错了,欢迎拍砖)。因为TensorFlow是基于图的计算,数据以流的形式存在,所以只要构建好了图,有数据源,那么应该都会 数据流过,所以在执行tf.cond之前,两个数据流一个是tf.add()中的x,z,一个是tf.square(y)中的y,而tf.cond()就决定了是数据流x,z从tf.add()流过,还是数据流y从tf.square()流过。这里这个tf.cond也就像个控制水流的阀门,水流管道x,z,y在这个阀门交汇,而tf.cond决定了谁将流向后面的管道,但是不管哪一个水流流向下一个管道,在阀门作用之前,水流应该都是要到达阀门的。(啰啰嗦嗦了一大堆,还是不太理解)

栗子:

import tensorflow as tf

a=tf.constant(2)

b=tf.constant(3)

x=tf.constant(4)

y=tf.constant(5)

z = tf.multiply(a, b)

result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

with tf.Session() as session:

print(result.eval())