I have an input to TensorFlow of shape [None, 9, 2] (where the None is the batch dimension).
To perform further actions (e.g. matmul) on it I need to transform it to [None, 18] shape. How do I do that?
You can use dynamic reshaping to get the value of the batch dimension through tf.shape at runtime, and compute the full set of new dimensions to pass into tf.reshape. Here's an example of reshaping a flat list into a square matrix without knowing the list length.
import numpy as np
import tensorflow as tf

tf.reset_default_graph()
sess = tf.InteractiveSession()
a = tf.placeholder(dtype=tf.int32)
# get the dynamic shape, e.g. [9]
ashape = tf.shape(a)
# slice the shape list from the 0th to the 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape to a scalar, i.e. from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# stack the two scalar side lengths into a [2]-element shape tensor, here [3., 3.]
newshape = tf.stack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones(9)})
You should see
array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=int32)
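For reference, the same dynamic reshape can be written more compactly by indexing the shape tensor and casting inline. This is a minimal sketch assuming the same TF 1.x graph/session setup as above, not a different technique:

import numpy as np
import tensorflow as tf

a = tf.placeholder(dtype=tf.int32)
# dynamic length of the flat input, as a scalar tensor
n = tf.shape(a)[0]
# side length of the square matrix: sqrt(n), computed in float and cast back to int
side = tf.to_int32(tf.sqrt(tf.to_float(n)))
a_square = tf.reshape(a, tf.stack([side, side]))

with tf.Session() as sess:
    print(sess.run(a_square, feed_dict={a: np.ones(9, dtype=np.int32)}))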
You can do it easily with tf.reshape() without knowing the batch size.
import numpy
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 9, 2])
shape = x.get_shape().as_list()  # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])      # dim = 9 * 2 = 18
x2 = tf.reshape(x, [-1, dim])    # -1 means "infer this dimension"
The -1 in the last line tells tf.reshape() to infer that dimension from the total size of the input, so the reshape works no matter what the batch size is at runtime. See the tf.reshape() documentation for details.
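As a quick usage check (the batch size 4 here is arbitrary, chosen just for illustration), the inferred dimension tracks whatever batch size is fed at runtime:

import numpy
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 9, 2])
dim = numpy.prod(x.get_shape().as_list()[1:])  # 9 * 2 = 18
x2 = tf.reshape(x, [-1, dim])

with tf.Session() as sess:
    out = sess.run(x2, feed_dict={x: numpy.zeros((4, 9, 2))})
    print(out.shape)  # (4, 18)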
Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() with tf.reduce_prod() instead.
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3, None])
# product of all non-batch dimensions, computed at runtime as a tensor
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
tf.shape() returns a shape tensor that is evaluated at runtime, whereas x.get_shape() returns the static shape known when the graph is built. The difference between the two is explained in the docs.
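To make the contrast concrete, here is a minimal sketch (the feed shape (4, 3, 5) is an arbitrary example): the static shape still shows the None entries, while tf.shape() yields the actual runtime values.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3, None])
print(x.get_shape().as_list())  # static shape: [None, 3, None]
dyn = tf.shape(x)               # dynamic shape, a tensor evaluated at runtime

with tf.Session() as sess:
    print(sess.run(dyn, feed_dict={x: np.zeros((4, 3, 5))}))  # [4 3 5]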
I also tried tf.contrib.layers.flatten(). It is the simplest for the first case, but it can't handle the second.
flat_inputs = tf.layers.flatten(inputs)
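For the first case this one-liner is indeed the simplest. A minimal sketch, assuming TF 1.x where tf.layers.flatten is available:

import tensorflow as tf

inputs = tf.placeholder(tf.float32, shape=[None, 9, 2])
flat_inputs = tf.layers.flatten(inputs)
print(flat_inputs.get_shape().as_list())  # [None, 18]

With shape=[None, 3, None], however, the flattened width cannot be determined statically, which is why the tf.reduce_prod() approach above is needed for that case.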