I have an input to TensorFlow of shape [None, 9, 2] (where the None is the batch dimension).
To perform further actions (e.g. matmul) on it I need to tra
You can use dynamic reshaping: get the value of the batch dimension through tf.shape at run time, compute the full set of new dimensions, and pass them to tf.reshape. Here's an example of reshaping a flat list into a square matrix without knowing the list length in advance.
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# stack the two scalar tensors into one 1-D shape tensor, here [3., 3.]
# (tf.pack was renamed to tf.stack in TensorFlow 1.0)
newshape = tf.stack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})
You should see
array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=int32)
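The shape arithmetic done inside the graph can be checked in plain Python. The sketch below is only an illustration and not part of TensorFlow: `square_shape` is a hypothetical helper that computes the [side, side] target for a flat length. It also rejects lengths that are not perfect squares, a case where the float square root followed by an integer cast would truncate and the reshape would then fail at run time.

```python
import math


def square_shape(length):
    """Return [side, side] for a flat list of `length` elements (hypothetical helper)."""
    side = math.isqrt(length)  # integer square root, no float round trip
    if side * side != length:
        raise ValueError(f"{length} elements cannot form a square matrix")
    return [side, side]


print(square_shape(9))  # [3, 3]
```

Doing the same check up front in a real pipeline gives a clearer error than a failed reshape.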
You can do it easily with tf.reshape() without knowing the batch size.
import numpy
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 9, 2])
shape = x.get_shape().as_list()   # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])       # dim = 9 * 2 = 18
x2 = tf.reshape(x, [-1, dim])     # -1 lets TensorFlow infer the batch size
The -1 in the last line tells tf.reshape to infer that dimension from the total number of elements, so it works whatever the batch size is at run time. See the documentation for tf.reshape().
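To make the -1 rule concrete, here is a minimal plain-Python sketch of the arithmetic. `infer_shape` is a hypothetical helper written for illustration, not a TensorFlow function: it divides the total element count by the product of the known dimensions to fill in the single -1 entry.

```python
from math import prod


def infer_shape(num_elements, shape):
    """Resolve a single -1 entry in `shape` (hypothetical helper, not a TF API)."""
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in shape if d != -1)
    if -1 not in shape:
        if known != num_elements:
            raise ValueError("shape does not match the number of elements")
        return list(shape)
    if known == 0 or num_elements % known != 0:
        raise ValueError("number of elements is not divisible by the known dimensions")
    return [num_elements // known if d == -1 else d for d in shape]


# A batch of 5 inputs of shape [9, 2] holds 5 * 9 * 2 = 90 values.
print(infer_shape(90, [-1, 18]))  # [5, 18]
```

Only one dimension can be inferred this way, which is why the other dimension (18 here) has to be computed explicitly.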
Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() with tf.reduce_prod() instead.
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
tf.shape() returns a shape Tensor that is evaluated at run time. The difference between Tensor.get_shape() (called as x.get_shape()) and tf.shape() is described in the docs.
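The static/dynamic distinction can be sketched in plain Python without TensorFlow. In the example below, `static_shape` stands in for what x.get_shape().as_list() would report, `runtime_shape` holds made-up values of the kind tf.shape(x) would yield for one particular batch, and `flat_dim` is a hypothetical helper. It uses each statically known dimension and falls back to the run-time value wherever the static entry is None.

```python
def flat_dim(static_shape, runtime_shape):
    """Product of all non-batch dims, preferring statically known values.

    Hypothetical helper for illustration; not a TensorFlow function.
    """
    dim = 1
    for static, actual in zip(static_shape[1:], runtime_shape[1:]):
        dim *= actual if static is None else static
    return dim


static_shape = [None, 3, None]  # known when the graph is built
runtime_shape = [4, 3, 7]       # assumed values, known only once data is fed

print(flat_dim(static_shape, runtime_shape))  # 21
```

With a static shape such as [None, 9, 2] the run-time values are never consulted for the trailing dimensions, which is why the first approach can compute 18 before any data is fed.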
I also tried tf.contrib.layers.flatten(). It is the simplest option for the first case, but it can't handle the second.
flat_inputs = tf.layers.flatten(inputs)