Flatten batch in tensorflow

小鲜肉 2020-12-05 04:45

I have an input to TensorFlow of shape [None, 9, 2] (where None is the batch dimension).

To perform further actions on it (e.g. matmul), I need to transform it to shape [None, 18].

  •  轻奢々
    2020-12-05 05:18

    You can do it easily with tf.reshape() without knowing the batch size.

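    import numpy
    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    x = tf.placeholder(tf.float32, shape=[None, 9, 2])
    shape = x.get_shape().as_list()        # static shape as a list: [None, 9, 2]
    dim = int(numpy.prod(shape[1:]))       # dim = 9 * 2 = 18
    x2 = tf.reshape(x, [-1, dim])          # -1: infer this dimension (the batch size)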
    

    The -1 in the last line tells tf.reshape() to infer that dimension at runtime, so it matches whatever the batch size turns out to be. See the tf.reshape() documentation for details.
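    tf.reshape() treats -1 the same way NumPy's reshape does: at most one dimension may be -1, and its size is computed so the total number of elements stays the same. Since the answer already uses NumPy, a minimal NumPy-only sketch can show the inference for several batch sizes (the helper name `flatten_batch` is hypothetical):

```python
import numpy as np


def flatten_batch(x):
    """Collapse all non-batch dimensions into one, mirroring tf.reshape(x, [-1, dim])."""
    dim = int(np.prod(x.shape[1:]))  # 9 * 2 = 18 for a [batch, 9, 2] input
    return x.reshape(-1, dim)        # -1: batch size inferred as x.size // dim


# The same call works whatever the batch size is.
for batch in (1, 4, 32):
    x = np.zeros((batch, 9, 2), dtype=np.float32)
    print(batch, flatten_batch(x).shape)
```

    Elements are laid out in row-major order, so row i of the result holds the 18 values of sample i.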


    Update: shape = [None, 3, None]

    Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() together with tf.reduce_prod() instead.

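    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    x = tf.placeholder(tf.float32, shape=[None, 3, None])
    dim = tf.reduce_prod(tf.shape(x)[1:])  # product of all non-batch dims, computed at runtime
    x2 = tf.reshape(x, [-1, dim])          # [batch, 3 * d2], resolved when the graph runs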
    

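    tf.shape() returns a shape Tensor that is evaluated at runtime, whereas x.get_shape() returns the static shape known when the graph is built, with None for any unknown dimension. The difference between x.get_shape() and tf.shape() is explained in the documentation.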
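    To see the static/dynamic distinction concretely, here is a small sketch assuming TensorFlow 2.x (the answer itself targets the 1.x API). A tf.function with an input_signature stands in for the placeholder, and the hypothetical `runtime_shape` records `x.shape` while tracing and returns `tf.shape(x)`:

```python
import tensorflow as tf

# Hypothetical list that collects the static shape seen while tracing.
traced_static_shapes = []


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3, None], dtype=tf.float32)])
def runtime_shape(x):
    # x.shape is the static shape: only what is known when the graph is built.
    # This Python line runs during tracing, not on every call.
    traced_static_shapes.append(x.shape.as_list())
    # tf.shape(x) is a tensor whose value is computed each time the function runs.
    return tf.shape(x)


print(runtime_shape(tf.zeros([4, 3, 5])).numpy())  # [4 3 5]
print(runtime_shape(tf.zeros([2, 3, 7])).numpy())  # [2 3 7]
print(traced_static_shapes)
```

    The recorded static shape keeps None for the unknown dimensions, while the returned tensor holds the actual sizes of each call.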

    I also tried tf.contrib.layers.flatten(), suggested in another answer. It is the simplest option for the first case, but it can't handle the second.
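    The snippets in this answer use the TensorFlow 1.x graph API; in TensorFlow 2.x, tf.placeholder only exists as tf.compat.v1.placeholder. A minimal sketch, assuming TF 2.x, can handle both cases with one hypothetical helper `flatten_batch`: use the static shape when every non-batch dimension is known, and fall back to tf.shape() otherwise. Wrapping it in a tf.function with an input_signature to mimic the placeholder is an assumption of this sketch, not part of the original answer.

```python
import tensorflow as tf


def flatten_batch(x):
    """Reshape [batch, d1, d2, ...] to [batch, d1 * d2 * ...] (hypothetical helper)."""
    # num_elements() returns None when any non-batch dimension is unknown.
    static_dim = x.shape[1:].num_elements()
    if static_dim is not None:
        # First case, e.g. [None, 9, 2]: dim is the Python int 18.
        return tf.reshape(x, [-1, static_dim])
    # Second case, e.g. [None, 3, None]: compute the product at runtime.
    dim = tf.reduce_prod(tf.shape(x)[1:])
    return tf.reshape(x, [-1, dim])


# A tf.function with an input_signature plays the role of tf.placeholder.
flatten_static = tf.function(
    flatten_batch,
    input_signature=[tf.TensorSpec(shape=[None, 9, 2], dtype=tf.float32)],
)
flatten_dynamic = tf.function(
    flatten_batch,
    input_signature=[tf.TensorSpec(shape=[None, 3, None], dtype=tf.float32)],
)

print(flatten_static(tf.zeros([4, 9, 2])).shape)   # (4, 18)
print(flatten_dynamic(tf.zeros([2, 3, 7])).shape)  # (2, 21)
```

    Preferring the static path keeps the output's second dimension known at graph-build time (useful before a matmul), while the tf.shape() fallback still works when it cannot be known.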
