计算图是tensorflow中最基本的一个概念,Tensorflow中的所有计算都会被转化为计算图上的节点。
1.计算图的概念
Tensorflow的名字中已经说明它最重要的两个概念--Tensor和Flow。Tensor就是张量,在Tensorflow中,张量可以被简单的理解为多维数组。Flow翻译成中文是“流”,它直观的表达了张量之间通过计算相互转化。
Tensorflow是一个通过计算图的形式来表述计算的编程系统。Tensorflow中的每一个计算都是计算图上的一个节点,而节点之间的边描述了计算之间的依赖关系。
如果说Tensorflow的第一个词Tensor表明了它的数据结构,那么Flow则体现了它的计算模型。
2.计算图的使用
Tensorflow程序一般可以分为两个阶段。在第一个阶段需要定义计算图中所有的计算,第二个阶段为执行计算。
以下代码给出了计算定义阶段的样例。
import tensorflow as tf
a = tf.constant([1.0, 2.0] , name ='a')
b = tf.constant([2.0, 3.0] , name ='b')
result = a + b
在这个过程中,Tensorflow会自动将定义的计算转化为计算图上的节点。在Tensorflow程序中,系统会自动维护一个默认的计算图,通过tf.get_default_graph函数可以获取当前默认的计算图。以下代码示意了如何获取默认计算图以及如何查看一个运算所属的计算图。
# 通过a.graph可以查看张量所属的计算图。
# 因为没有特意指定,所以这个计算图应该等于当前默认的计算图。
# 所以下面这个操作输出值为True。
print(a.graph is tf.get_default_graph())
除了使用默认的计算图,Tensorflow支持通过tf.Graph函数来生成新的计算图。不同计算图上的张量和运算都不会共享。以下代码示意了如何在不同计算图上定义和使用变量。
import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
# 在计算图 g1 中定义变量“v”,并设置初始值为0
v = tf.get_variable(
"v", initializer = tf.zeros_initializer()(shape=[1]))
g2 = tf.Graph()
with g2.as_default():
# 在计算图 g2 中定义变量“v”,并设置初始值为1
v = tf.get_variable(
"v", initializer = tf.ones_initializer()(shape=[1]))
g2 = tf.Graph()
with g2.as_default():
# 在计算图 g2 中定义变量“v”,并设置初始值为1
v = tf.get_variable(
"v", initializer = tf.ones_initializer()(shape=[1]))
# 在计算图 g1 中读取变量“v”的取值
with tf.Session(graph=g1) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope("" , reuse=True):
# 在计算图 g1 中, 变量“v”的取值应该为0,所以下面这行会输出[0.]
print(sess.run(tf.get_variable("v")))
# 在计算图 g2 中读取变量“v”的取值
with tf.Session(graph=g2) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope("" , reuse=True):
# 在计算图 g2 中, 变量“v”的取值应该为1,所以下面这行会输出[1.]
print(sess.run(tf.get_variable("v")))
上面的代码产生了两个计算图,每个计算图中定义了一个名字为“v”的变量。在计算图g1中,将v初始化为0;在计算图g2中,将v初始化为1。可以看到当运行不同计算图时,变量v的值也是不一样的。
Tensorflow中的计算图不仅仅可以用来隔离张量和计算,它还提供了管理张量和计算的机制。计算图可以通过tf.Graph.device函数来指定运行计算的设备。这为Tensorflow使用GPU提供了机制。下面的程序可以将加法计算跑在GPU上。
g = tf.Graph()
# 指定计算运行的设备
with g.device('/gpu:0'):
result = a + b
在一个计算图中,可以通过集合(collection)来管理不同类别的资源。比如通过tf.add_to_collection函数可以将资源加入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源。这里的资源可以是张量、变量或者运行Tensorflow程序需要的队列资源,等等。