tf.shape() gets the wrong shape in TensorFlow

南笙 2020-12-04 10:21

I define a tensor like this:

import tensorflow as tf  # TensorFlow 1.x API
x = tf.get_variable("x", [100])

But when I try to print the shape of the tensor, I don't get the shape I expect:

print(tf.shape(x))

6 answers
  •  死守一世寂寞
    2020-12-04 11:06

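    tf.shape() does not return the shape values directly; it returns a tensor that only holds them once it is evaluated in a session. A quick example to make this clear: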

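    import tensorflow as tf  # TensorFlow 1.x (tf.Session is not available in 2.x)

    a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
    print('-' * 60)
    print("v1", tf.shape(a))       # symbolic tensor holding the dynamic shape
    print('-' * 60)
    print("v2", a.get_shape())     # static shape, known when the graph is built
    print('-' * 60)
    with tf.Session() as sess:
        print("v3", sess.run(tf.shape(a)))  # dynamic shape, evaluated in a session
    print('-' * 60)
    print("v4", a.shape)           # same as a.get_shape()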
    

    The output will be:

    ------------------------------------------------------------
    v1 Tensor("Shape:0", shape=(3,), dtype=int32)
    ------------------------------------------------------------
    v2 (2, 3, 4)
    ------------------------------------------------------------
    v3 [2 3 4]
    ------------------------------------------------------------
    v4 (2, 3, 4)
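The example above relies on `tf.Session`, which exists only in TensorFlow 1.x. As a sketch, assuming TensorFlow 2.x with its default eager execution, `tf.shape()` returns a concrete tensor whose values can be read directly with `.numpy()`, so no session is needed:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, where eager execution is on by default

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))

dynamic = tf.shape(a)   # eager mode: a concrete int32 tensor, not a symbolic one
static = a.shape        # TensorShape, same as a.get_shape()

print(dynamic.numpy())  # [2 3 4]
print(static)           # (2, 3, 4)
```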
    

    This related question should also help: "How to understand static shape and dynamic shape in TensorFlow?"
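For a variable with a fully defined shape, the static and dynamic shapes carry the same numbers, as v2 and v3 show. They differ when a dimension is unknown while the graph is being built. A minimal sketch, assuming TensorFlow 1.14+ or 2.x (it uses the `tf.compat.v1` graph-mode API); the placeholder `x` and the batch size of 5 are illustrative only:

```python
import numpy as np
import tensorflow as tf

# Build a TF1-style graph explicitly so the sketch also runs under TensorFlow 2.x.
graph = tf.Graph()
with graph.as_default():
    # Illustrative placeholder: the first (batch) dimension is unknown until run time.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
    static_shape = x.shape        # TensorShape with an unknown first dimension
    dynamic_shape = tf.shape(x)   # symbolic int32 tensor, filled in when the graph runs

with tf.compat.v1.Session(graph=graph) as sess:
    feed = np.zeros((5, 3), dtype=np.float32)
    runtime_shape = sess.run(dynamic_shape, feed_dict={x: feed})

print(static_shape.as_list())     # [None, 3]
print(runtime_shape.tolist())     # [5, 3]
```

The static shape only knows what was declared, while the dynamic shape reflects the data actually fed in, which is why `tf.shape()` has to be run before it yields numbers.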
