import tensorflow as tf
# A rank-3 tensor of zeros with shape [10, 10, 3].
image = tf.zeros([10, 10, 3])
print(image.shape.as_list())

# tf.expand_dims inserts a new size-1 axis at the requested position:
# axis=0 prepends it, axis=1 places it between the first two dimensions.
for new_axis in (0, 1):
    print(tf.expand_dims(image, axis=new_axis).shape.as_list())

# A negative axis counts from the end, so -1 appends a trailing size-1 axis.
cc = tf.expand_dims(image, axis=-1)
print(cc.shape.as_list())

# With no `axis` argument, tf.squeeze drops *every* size-1 dimension. Here the
# only such dimension is the one added above, so the original shape comes back;
# pass axis=-1 instead if other dimensions could also have size 1.
print(tf.squeeze(cc).shape.as_list())
# Output:
# [10, 10, 3]
# [1, 10, 10, 3]
# [10, 1, 10, 3]
# [10, 10, 3, 1]
# [10, 10, 3]
# • tf.squeeze removes dimensions of size 1 (all of them when no `axis` is given).