How to Feed Batched Sequences of Images through Tensorflow conv2d

纵然是瞬间 submitted on 2020-01-14 10:43:49

Question


This seems like a trivial question, but I've been unable to find the answer.

I have batched sequences of images of shape:

[batch_size, number_of_frames, frame_height, frame_width, number_of_channels]

and I would like to pass each frame through a few convolutional and pooling layers. However, TensorFlow's conv2d layer accepts 4D inputs of shape:

[batch_size, frame_height, frame_width, number_of_channels]

My first attempt was to use tf.map_fn over axis=1, but I discovered that this function does not propagate gradients.

My second attempt was to use tf.unstack over axis 1 (the frame dimension) and then use tf.while_loop. However, my batch_size and number_of_frames are determined dynamically (i.e. both are None), and tf.unstack raises {ValueError} Cannot infer num from shape (?, ?, 30, 30, 3) if num is unspecified. I tried specifying num=tf.shape(self.observations)[1], but this raises {TypeError} Expected int for argument 'num' not <tf.Tensor 'A2C/infer/strided_slice:0' shape=() dtype=int32>.


Answer 1:


Since every frame is passed through the same convolutional model, you can merge the batch and frame dimensions into a single dimension, run the ordinary 4D convolution, and then split the result back into batch and frames. This only needs tf.reshape, as shown below:


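import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)

# input with shape [batch_size, number_of_frames, frame_height, frame_width, number_of_channels]
x = tf.placeholder(tf.float32, [None, None, 32, 32, 3])

# merge the batch and frame dimensions for the conv input
x_reshaped = tf.reshape(x, [-1, 32, 32, 3])

For the (10, 5, 32, 32, 3) input fed below, x_reshaped has shape (50, 32, 32, 3).

# define your conv network (5 filters, so the output has 5 channels)
y = tf.layers.conv2d(x_reshaped, 5, kernel_size=(3, 3), padding='SAME')
# (50, 32, 32, 5)

# split the merged dimension back into batch and frames
out = tf.reshape(y, [-1, tf.shape(x)[1], 32, 32, 5])

The output keeps the input's batch and frame dimensions, with the channel dimension now equal to the number of filters: (10, 5, 32, 32, 5).

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run(out, {x: np.random.normal(size=(10, 5, 32, 32, 3))}).shape)
    # (10, 5, 32, 32, 5)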
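The reshape trick works because a row-major reshape of [B, T, H, W, C] to [B*T, H, W, C] puts frame t of sequence b at index b*T + t, and reshaping back with the same T undoes that mapping exactly. The NumPy sketch below checks this against an explicit loop over frames. It is illustrative only: frame_fn is a hypothetical stand-in for the conv/pool stack (a 1x1 convolution done as a channel-mixing matmul, then 2x2 max pooling), not the TensorFlow layers from the answer.

```python
import numpy as np


def frame_fn(frames, weights):
    """Hypothetical per-frame layer stack on a 4D batch [N, H, W, C_in].

    A 1x1 convolution (channel mixing via matmul) followed by
    2x2 max pooling with stride 2 (assumes even H and W).
    """
    mixed = frames @ weights  # [N, H, W, C_out]
    n, h, w, c = mixed.shape
    return mixed.reshape(n, h // 2, 2, w // 2, 2, c).max(axis=(2, 4))


def merged_apply(x, weights):
    """Merge batch and frames, apply frame_fn once, then split them back."""
    merged = x.reshape(-1, *x.shape[2:])  # [B*T, H, W, C]
    y = frame_fn(merged, weights)  # [B*T, H', W', C']
    return y.reshape(-1, x.shape[1], *y.shape[1:])  # [B, T, H', W', C']


def per_frame_loop(x, weights):
    """Reference: apply frame_fn to each x[:, t] and stack along axis 1."""
    return np.stack([frame_fn(x[:, t], weights) for t in range(x.shape[1])], axis=1)


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 5, 32, 32, 3))
weights = rng.normal(size=(3, 4))

out = merged_apply(x, weights)
print(out.shape)  # (10, 5, 16, 16, 4)
print(np.allclose(out, per_frame_loop(x, weights)))  # True

# frame t of sequence b sits at row b * T + t of the merged array
merged = x.reshape(-1, 32, 32, 3)
print(np.array_equal(merged[2 * 5 + 3], x[2, 3]))  # True
```

The equivalence holds only because frame_fn treats each item of its leading dimension independently, which is also true of conv2d and pooling layers.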


Source: https://stackoverflow.com/questions/50786077/how-to-feed-batched-sequences-of-images-through-tensorflow-conv2d
