tf.shape() gets the wrong shape in TensorFlow

南笙 2020-12-04 10:21

I define a tensor like this:

x = tf.get_variable("x", [100])

But when I try to print the shape of the tensor:

print(tf.shape(x))

  •  粉色の甜心
    2020-12-04 10:52

    A similar question is nicely explained in the TensorFlow FAQ:

    In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true) shape. The static shape can be read using the tf.Tensor.get_shape method: this shape is inferred from the operations that were used to create the tensor, and may be partially complete. If the static shape is not fully defined, the dynamic shape of a Tensor t can be determined by evaluating tf.shape(t).
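A minimal sketch of the difference, assuming graph mode through `tensorflow.compat.v1` (so it runs on TF 1.x and 2.x): a placeholder whose first dimension is unknown when the graph is built has a partially defined static shape, while `tf.shape()` yields the concrete shape only once real data is fed in a session. The placeholder `p` and the fed batch size of 5 are illustrative choices, not part of the original answer.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use graph mode, as in the TF 1.x code of this thread

# The first dimension is unknown (None) when the graph is built.
p = tf.placeholder(tf.float32, shape=[None, 3])

static_shape = p.get_shape().as_list()  # inferred at graph-construction time
dynamic_shape_op = tf.shape(p)          # a symbolic 1-D int32 tensor

with tf.Session() as sess:
    # The true shape is only known once actual data is fed.
    dynamic_shape = sess.run(dynamic_shape_op, feed_dict={p: np.zeros((5, 3))})

print(static_shape)            # [None, 3]
print(dynamic_shape.tolist())  # [5, 3]
```

The static shape keeps `None` for the unknown dimension, whereas the dynamic shape reports the 5 rows that were actually fed.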

    So tf.shape() returns a tensor, not a Python list. That tensor always has shape (N,), where N is the rank of the input, and its value is computed only when you run it in a session:

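    import tensorflow as tf  # TensorFlow 1.x API (use tensorflow.compat.v1 on TF 2.x)
    a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # variables must be initialized before they are read
        print(sess.run(tf.shape(a)))  # [2 3 4]
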

    On the other hand, you can extract the static shape with x.get_shape().as_list(), which works anywhere, without a session. For the x in the question, it returns [100].
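Applied to the variable from the question, a minimal sketch (again assuming graph mode via `tensorflow.compat.v1`) shows both routes to the shape: printing `tf.shape(x)` only describes a symbolic tensor, the static shape is available immediately, and the dynamic shape needs a session with initialized variables.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, matching the question's TF 1.x code

x = tf.get_variable("x", [100])

# Printing tf.shape(x) shows a symbolic Tensor description, not the values.
print(tf.shape(x))

# Static shape: known at graph-construction time, no session required.
static = x.get_shape().as_list()

# Dynamic shape: evaluated by running the shape op in a session.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    dynamic = sess.run(tf.shape(x)).tolist()

print(static)   # [100]
print(dynamic)  # [100]
```

Because x has a fully defined shape here, both approaches agree; the static route is simply cheaper and does not touch the graph at run time.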
