I have an input to TensorFlow of shape [None, 9, 2] (where the None is the batch dimension).
To perform further actions (e.g. matmul) on it I need to tra
You can do it easily with tf.reshape() without knowing the batch size.
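import numpy as np
import tensorflow as tf  # TensorFlow 1.x: tf.placeholder is a graph-mode API

x = tf.placeholder(tf.float32, shape=[None, 9, 2])
shape = x.get_shape().as_list()  # static shape as a list: [None, 9, 2]
dim = np.prod(shape[1:])         # dim = 9 * 2 = 18
x2 = tf.reshape(x, [-1, dim])    # -1 lets TensorFlow infer the batch dimension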
The -1 in the last line tells tf.reshape() to infer that dimension from the total number of elements, so the reshape works whatever the batch size is at runtime. This behavior is described in the tf.reshape() documentation.
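To make the -1 inference concrete, here is a minimal sketch that checks the shape arithmetic with plain NumPy arrays instead of a TensorFlow graph. It relies on NumPy's reshape following the same -1 convention as tf.reshape(); the helper name flatten_batch and the batch sizes 4 and 32 are made up for illustration.

```python
import numpy as np

def flatten_batch(batch):
    # Hypothetical helper: collapse every dimension after the first into one.
    dim = int(np.prod(batch.shape[1:]))  # for shape (N, 9, 2): 9 * 2 = 18
    return batch.reshape(-1, dim)        # -1 is inferred as the batch size N

# The same call works for any batch size, because only the trailing
# dimensions are used to compute dim.
small = flatten_batch(np.zeros((4, 9, 2)))
large = flatten_batch(np.zeros((32, 9, 2)))
print(small.shape, large.shape)
```

Only one dimension can be given as -1, which is why the trailing dimensions have to be multiplied out explicitly.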
Thanks @kbrose. When more than one dimension is undefined, the static shape contains several None entries and cannot be multiplied out ahead of time. In that case we can use tf.shape() with tf.reduce_prod() instead.
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
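tf.shape() returns a shape Tensor that is evaluated at runtime, whereas x.get_shape() returns the static shape known when the graph is built, with None for unknown dimensions. The difference between x.get_shape() and tf.shape() is covered in the docs.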
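The static/dynamic distinction can be sketched as follows. This assumes TensorFlow 2.x rather than the 1.x placeholder API used above: tf.function with an input_signature stands in for the placeholder, and the function name flatten_dynamic and the input sizes are hypothetical.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x. The signature leaves the batch and the last
# dimension unknown, like the placeholder of shape [None, 3, None].
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3, None], dtype=tf.float32)])
def flatten_dynamic(x):
    # Static shape: fixed when the function is traced; unknown entries are None.
    static_shape = x.shape.as_list()  # [None, 3, None]
    assert static_shape[1] == 3
    # Dynamic shape: an int32 tensor whose values exist only when the function runs.
    dim = tf.reduce_prod(tf.shape(x)[1:])
    return tf.reshape(x, [-1, dim])

# The same traced function handles inputs with different runtime shapes.
out_a = flatten_dynamic(tf.zeros([5, 3, 4]))  # trailing dims 3 * 4
out_b = flatten_dynamic(tf.zeros([2, 3, 7]))  # trailing dims 3 * 7
print(out_a.shape, out_b.shape)
```

Because dim is itself a tensor, the flattened size is not known statically inside the function; it is only resolved once concrete inputs are supplied.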
I also tried tf.contrib.layers.flatten(). It is the simplest option for the first case, but it can't handle the second.