Tensorflow输出

tensorflow由于其基于静态图的模式,导致写代码的时候很难调试,除了用官方的调试工具外,最直接的方法就是把中间结果输出出来查看,然而,直接用print函数只能输出tensor变量的形状(如("Placeholder:0", shape=(128, 346), dtype=float32),原因是:A Tensor object is a symbolic handle to the result of an operation, but does not actually hold the values of the operation’s output.),而不是数值,而我们通常想要输出tensor的具体数值。

print tensor属性信息

如shape,直接使用(必须使用int()转换,否则float(num_features)报错:TypeError: float() argument must be a string or a number, not 'Dimension')

num_features = int(batch_samples.shape[1])

或者int(batch_samples.get_shape()[1])

或者传入sess后使用tf.shape(batch_samples).eval(session=sess)

输出的shape为?即None怎么办?

输出的shape为(batch_size,?)不能直接输入到dense layer中,否则报错:flatten dense ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.

        如果是keras,main_input = Input(shape=(max_seq_len,), dtype='float64', name='main_input')的维度为shape=(?, max_seq_len)。可以加上main_input = Input(shape=(max_seq_len,), batch_shape=(batch_size,) + tuple((maxlen,)), dtype='float64', name='main_input')这样输出就是shape=(batch_size, max_seq_len)。通过下面的debug找到的:进入Input函数的代码文件input_layer.py:

if shape is not None and not batch_shape:
    batch_shape = (None,) + tuple(shape)
    # batch_shape = (128,) + tuple(shape) # batch_size = 128

        如果是tensorflow dataset.padded_batch(params.get('batch_size', 20), shapes, defaults, drop_remainder=True)可以加上drop_remainder=True,即the last batch should be dropped in the case it has fewer than batch_size elements。这样所有的batch大小才都是batch_size,输出的dataset的shape为shapes: (((20, max_seq_len), (20,)), (20, ?)),而不是shapes: (((?, max_seq_len), (?,)), (?, ?))。通过进入函数padded_batch文件内部多层后,dataset_ops.py中这句发现的:

tensor_shape.vector(
    tensor_util.constant_value(self._batch_size) if smart_cond.
    smart_constant_value(self._drop_remainder) else None)


print tensor数值

1 Session.run()Tensor.eval()

所以评估Tensor对象的实际值的最简单的方法是将其传递给Session.run()方法,或者在有默认会话时调用Tensor.eval()(即在with tf.Session():块中)。一般来说,不能在会话中运行某些代码时打印张量的值。

使用Session.run()方法时,placeholder需要feed才能有值,所以需要在函数中返回这个placeholder进行feed,通过run it inside a session输出结果。

但是如果想要输出函数中的中间值而该值又未传回主函数呢?如estimator这种端到端的高级封装,这种情况下无法在函数中开启一个新的Session,但是可以用tf.Print建立op来实现。

示例

import tensorflow as tf
x = tf.constant([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])y = tf.constant([[[11, 12, 13],
                   [14, 15, 16]],
                  [[17, 18, 19],
                   [110, 111, 112]]])print(x.shape)  # (2, 2, 3)
 with tf.Session().as_default() as sess:
    print(x.eval())   
 print(y.eval(session=sess))
 print(sess.run(y))

2 tf.Print

计算一个变量的同时,指定打印一组变量

tf.Print(
    input_,
    data,
    message=None,
    first_n=None,
    summarize=None,
    name=None
)

input_:通过这个操作的张量。 (流入的数据流)
data:计算 op 时要打印的张量列表。(用[ ]引起来的一串需要打印的东西,用逗号隔开)
message:一个字符串,错误消息的前缀。
first_n:只记录 first_n 次数。负数日志,这是默认的。
summarize:只打印每个张量的固定数目的条目(忽略维度信息)。如果没有,则每个输入张量最多打印3个元素。summarize=-1打印所有数值。

name:操作的名称(可选)

如:

pred_id = tf.Print(pred_id, [pred_id], summarize=20, message='*pred_id*')
        x=tf.Print(x,[x.shape,'test', x],message='Debug message:',summarize=100)

示例1:

import tensorflow as tf
score = tf.constant([[0.0], [0.2], [0.3],
                          [0.2], [1.0], [4.0]])score = tf.Print(score, ['score: ', score], summarize=10)
 score_sig = tf.sigmoid(score)
 score_sig = tf.Print(score_sig, ['score_sig: ', score_sig], summarize=10)
 score_scale = score_sig * 2.0 - 1.0
 score_scale = tf.Print(score_scale, ['score_scale: ', score_scale], summarize=10)with tf.Session() as sess:
   sess.run(score_scale)[score: ][[0][0.2][0.3][0.2][1][4]]
 [score_sig: ][[0.5][0.549833953][0.574442506][0.549833953][0.731058598][0.982013762]]
 [score_scale: ][[0][0.0996679068][0.148885012][0.0996679068][0.462117195][0.964027524]]

示例2:

def test():
    a=tf.constant(0)
    a_print = tf.Print(a,['a_value: ',a])
    for i in range(10):  
        a=a_print+1
    return a
    
if __name__=='__main__':
    with tf.Session() as sess:
        sess.run(test())

会输出几次值呢?这其实并不是看下文中a_print被使用了几次,而是看数据流要从该节点上流经几次,可以理解为a_print这个op被“定义”了几次。

但是如果for循环中改成a_print=a_print+1,就不会输出任何东西,因为a_print这个op没有和别的变量发生关系,它没有被别的变量使用,在图里为孤立的一个节点,没有数据流过,就不会被执行。

像计算loss的logits这种方法可以每次输出,但是像计算acc指标的pred_id可能只会在每次batch执行完执行一次以及在load checkpoint时执行多次(这个不知道为什么,可能是分析不行)。

[tensorflow在函数中用tf.Print输出中间值的方法]

示例3:

        输出estimator这种端到端的高级封装中后面没有用到的tensor。如想输出其中一个 tensor变量,但它又不是返回的结果,没有数据流过,所以即使tf.Print()也不会打印出来。

解决:在返回时可以返回dict的tensor,加一个额外的other_print_tensor就可以,也会有数据流过。如predict时的:

other_print_tensor = tf.Print(other_print_tensor, [other_print_tensor])
predictions = {
    ...,
    'other_print_tensor': other_print_tensor
}
return tf.estimator.EstimatorSpec(mode, predictions=predictions)

但是这里dict返回的tensor的batch size必须是一样大小,否则会出错:ValueError: Batch length of predictions should be same. pred_strings1 has different batch length than others.而在estimator中dict其它的shape又是(None, None)的没法获取。

reverse_vocab_tags = tf.contrib.lookup.index_to_string_table_from_file(params['tags'])
other_print_tensor = tf.Print(reverse_vocab_tags.export()[0], [reverse_vocab_tags.export()[0]])

所以像这种index_to_string_table_from_file最好在外面另起一个代码用sess.run()测试下。

3 使用summary方法,在tensorboard中观测

[How to print the value of a Tensor object in TensorFlow?][如何在TensorFlow中打印Tensor对象的值?]

[Tensorflow print in function]

[Tensorflow之调试(Debug)及打印变量]*

4 logging方法(不行)

tf.logging.set_verbosity(tf.logging.INFO)
...
tf.logging.info("Created vocabulary with %d words" % len(vocab))

5 第三方debug/可视化工具[https://github.com/ericjang/tdb]

解决tensorflow打印tensor有省略号的问题

import numpy as np #借助numpy模块的set_printoptions()函数,将打印上限设置为无限即可
np.set_printoptions(threshold=np.inf)

但是这个只对run出来的打印有效,对tf.Print无效。

[解决tensorflow打印tensor有省略号的问题]

如何查看TensorFlow中Tensor, Variable, Constant的值?

模型训练完成后,如何获取模型的参数?

通过tf.trainable_variables()得到训练参数

[TensorFlow中遇到的问题及解决方法]

tensorflow创建变量以及根据名称查找变量

[tensorflow创建变量以及根据名称查找变量]

tensorflow去除warning信息输出

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    warnings.warn("deprecated", DeprecationWarning)
    import tensorflow as tf

from tensorflow.python.util import deprecation

deprecation._PRINT_DEPRECATION_WARNINGS = False

[warnings — Warning control]