import tensorflow as tf
# Demonstrate masking each 3x3 slice of a batch down to its diagonal.
# `mask` is a 3x3 matrix with ones on the diagonal (built by tf.diag from a
# vector of ones). Multiplying the [2, 3, 3] tensor by the [3, 3] mask
# broadcasts the mask over the leading dimension, so every off-diagonal
# entry of each 3x3 slice becomes zero.
# NOTE: tf.Session and tf.diag are TensorFlow 1.x graph-mode APIs.
inputs = tf.ones([2, 3, 3]) * 2  # not named `input`, which shadows the builtin
mask = tf.diag(tf.ones([3]))

# Use the session as a context manager so it is closed when done.
with tf.Session() as sess:
    print(sess.run(mask))
    print(sess.run(inputs * mask))

Printed output (the 3x3 mask, then the masked 2x3x3 tensor):
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

[[[2. 0. 0.]
  [0. 2. 0.]
  [0. 0. 2.]]

 [[2. 0. 0.]
  [0. 2. 0.]
  [0. 0. 2.]]]