Tensorflow: Convolutions with different filter for each sample in the mini-batch

伪装坚强ぢ 2021-01-05 02:31

I would like to perform a 2D convolution in TensorFlow where the filter depends on the sample in the mini-batch. Any ideas how one could do that, especially if the number of …

4 answers
  •  自闭症患者
    2021-01-05 02:47

    You could use tf.map_fn as follows:

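    import tensorflow as tf  # TF 1.x graph-mode API (tf.placeholder)

    # h, w, c_in, c_out, fh, fw are Python ints defined elsewhere
    inp = tf.placeholder(tf.float32, [None, h, w, c_in])
    # One kernel per sample; its batch size must match inp's at run time
    kernels = tf.placeholder(tf.float32, [None, fh, fw, c_in, c_out])

    def single_conv(tupl):
        # x: [1, h, w, c_in] (a batch of one), kernel: [fh, fw, c_in, c_out]
        x, kernel = tupl
        return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='VALID')

    # map_fn walks axis 0 of both tensors and stacks the per-sample results
    # into [batch, 1, h', w', c_out]; squeeze then drops the extra axis.
    batch_wise_conv = tf.squeeze(tf.map_fn(
        single_conv, (tf.expand_dims(inp, 1), kernels), dtype=tf.float32),
        axis=1
    )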

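    Specifying dtype for map_fn is required here: when fn returns a structure different from that of elems (a single tensor instead of a tuple of two tensors), map_fn needs dtype to describe the output. In effect, this solution builds one 2D convolution per sample, i.e. batch-size separate conv2d operations.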
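    For checking shapes and values, here is a minimal pure-Python sketch of what batch_wise_conv computes: output[b] is a stride-1, VALID-padded conv2d of inp[b] with kernels[b]. It assumes nested lists in NHWC layout and kernels laid out as [fh][fw][c_in][c_out], as tf.nn.conv2d expects; the helper names conv2d_valid and batch_wise_conv are hypothetical, and this is only a reference for testing, not a substitute for the map_fn graph op. Like tf.nn.conv2d, it computes a cross-correlation (the kernel is not flipped).

```python
# Hypothetical reference helpers (not part of TensorFlow), assuming NHWC
# nested lists and kernels shaped [fh][fw][c_in][c_out].

def conv2d_valid(x, kernel):
    """Single-sample 2D cross-correlation, stride 1, VALID padding.

    x:      nested list [h][w][c_in]
    kernel: nested list [fh][fw][c_in][c_out]
    returns nested list [h - fh + 1][w - fw + 1][c_out]
    """
    h, w, c_in = len(x), len(x[0]), len(x[0][0])
    fh, fw, c_out = len(kernel), len(kernel[0]), len(kernel[0][0][0])
    out = []
    for i in range(h - fh + 1):
        row = []
        for j in range(w - fw + 1):
            pixel = []
            for k in range(c_out):
                # Sum over the window and all input channels; no kernel flip,
                # matching the cross-correlation tf.nn.conv2d computes.
                s = 0.0
                for di in range(fh):
                    for dj in range(fw):
                        for q in range(c_in):
                            s += x[i + di][j + dj][q] * kernel[di][dj][q][k]
                pixel.append(s)
            row.append(pixel)
        out.append(row)
    return out


def batch_wise_conv(inp, kernels):
    """Apply kernels[b] to inp[b] for every sample b, as the map_fn does."""
    if len(inp) != len(kernels):
        raise ValueError("need exactly one kernel per sample")
    return [conv2d_valid(x, k) for x, k in zip(inp, kernels)]


if __name__ == "__main__":
    # Same 3x3 single-channel image for both samples, different 2x2 kernels.
    img = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]], [[7.0], [8.0], [9.0]]]
    pick_top_left = [[[[1.0]], [[0.0]]], [[[0.0]], [[0.0]]]]  # 2x2x1x1
    box_sum = [[[[1.0]], [[1.0]]], [[[1.0]], [[1.0]]]]        # 2x2x1x1
    print(batch_wise_conv([img, img], [pick_top_left, box_sum]))
    # [[[[1.0], [2.0]], [[4.0], [5.0]]], [[[12.0], [16.0]], [[24.0], [28.0]]]]
```

    Because each sample gets its own kernel, the two outputs differ even though the inputs are identical, which is exactly the per-sample behaviour the question asks for.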
