Flatten batch in TensorFlow

小鲜肉 2020-12-05 04:45

I have a TensorFlow input of shape [None, 9, 2] (where None is the batch dimension).

To perform further actions (e.g. matmul) on it I need to tra

3 answers
  • 2020-12-05 05:08

    You can use dynamic reshaping: get the value of the batch dimension through tf.shape at runtime, compute the full set of new dimensions, and pass them to tf.reshape. Here's an example of reshaping a flat list into a square matrix without knowing the list length in advance.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    tf.reset_default_graph()
    sess = tf.InteractiveSession("")
    a = tf.placeholder(dtype=tf.int32)
    # dynamic shape of a, e.g. [9]
    ashape = tf.shape(a)
    # slice the shape vector: start at position 0, take 1 element
    ashape0 = tf.slice(ashape, [0], [1])
    # reshape list to scalar, i.e. from [9] to 9
    ashape0_flat = tf.reshape(ashape0, ())
    # tf.sqrt doesn't support int, so cast to float
    ashape0_flat_float = tf.cast(ashape0_flat, tf.float32)
    newshape0 = tf.sqrt(ashape0_flat_float)
    # stack the two scalars into a 1-D tensor, here [3.0, 3.0]
    # (tf.stack replaces the older tf.pack)
    newshape = tf.stack([newshape0, newshape0])
    # tf.reshape doesn't accept float, so convert back to int
    newshape_int = tf.cast(newshape, tf.int32)
    a_reshaped = tf.reshape(a, newshape_int)
    sess.run(a_reshaped, feed_dict={a: np.ones((9))})
    

    You should see

    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]], dtype=int32)
    
  • 2020-12-05 05:18

    You can do it easily with tf.reshape() without knowing the batch size.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x

    x = tf.placeholder(tf.float32, shape=[None, 9, 2])
    shape = x.get_shape().as_list()        # a list: [None, 9, 2]
    dim = np.prod(shape[1:])               # dim = prod(9, 2) = 18
    x2 = tf.reshape(x, [-1, dim])          # -1: infer the batch dimension
    

    The -1 in the last line tells tf.reshape() to infer that dimension from the total number of elements, so it works whatever the batch size is at runtime. See the tf.reshape() documentation.
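
    The arithmetic behind the -1 can be sketched in plain Python. This only illustrates the rule and is not TensorFlow's implementation; resolve_shape is a hypothetical helper name introduced here.

```python
from math import prod


def resolve_shape(total_elements, requested_shape):
    """Hypothetical helper: replace a single -1 with the size that makes
    the element count of requested_shape equal total_elements."""
    if requested_shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    # Product of all dimensions that were given explicitly.
    known = prod(d for d in requested_shape if d != -1)
    if -1 not in requested_shape:
        if known != total_elements:
            raise ValueError("shape does not match the element count")
        return list(requested_shape)
    if known == 0 or total_elements % known != 0:
        raise ValueError("element count is not divisible by the known dims")
    inferred = total_elements // known
    return [inferred if d == -1 else d for d in requested_shape]


# A batch of 4 inputs of shape [9, 2] holds 4 * 9 * 2 = 72 elements.
print(resolve_shape(4 * 9 * 2, [-1, 18]))  # [4, 18]
```

    The same call with a different batch size gives a different first dimension, which is why the reshape needs no knowledge of the batch size when the graph is built.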


    Update: shape = [None, 3, None]

    Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() with tf.reduce_prod() instead.

    x = tf.placeholder(tf.float32, shape=[None, 3, None])
    dim = tf.reduce_prod(tf.shape(x)[1:])
    x2 = tf.reshape(x, [-1, dim])
    

    tf.shape() returns a shape Tensor that is evaluated at runtime, whereas x.get_shape() returns the static shape known when the graph is built. The difference between the two is described in the docs.
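
    To make the static/dynamic distinction concrete, here is a minimal sketch. It assumes a TensorFlow build that ships the tensorflow.compat.v1 module (an assumption, not something stated in this thread) so that the graph/session style used above still runs. The fed array shape (4, 3, 5) is arbitrary.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: keep the TF 1.x graph/session style

x = tf.placeholder(tf.float32, shape=[None, 3, None])

# Static shape: a plain Python list, fixed when the graph is built.
static_shape = x.get_shape().as_list()

# Dynamic shape: a Tensor whose value depends on what is fed at run time.
dynamic_shape = tf.shape(x)

# Flatten everything except the batch dimension using the dynamic shape.
dim = tf.reduce_prod(dynamic_shape[1:])
x2 = tf.reshape(x, [-1, dim])

with tf.Session() as sess:
    feed = {x: np.zeros((4, 3, 5), dtype=np.float32)}
    runtime_shape, flat = sess.run([dynamic_shape, x2], feed_dict=feed)

print(static_shape)            # [None, 3, None]
print(runtime_shape.tolist())  # [4, 3, 5]
print(flat.shape)              # (4, 15)
```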

    I also tried tf.contrib.layers.flatten(). It is the simplest option for the first case, but it can't handle the second.

  • 2020-12-05 05:22
    flat_inputs = tf.layers.flatten(inputs)
    