I would like to perform a 2D convolution in TensorFlow with a filter that depends on the sample in the mini-batch. Any ideas how one could do that, especially if the number of
You could use tf.map_fn as follows:
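import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

# h, w, c_in are the static input height, width and channel count
inp = tf.placeholder(tf.float32, [None, h, w, c_in])

def single_conv(tupl):
    # x: [1, h, w, c_in]; kernel: [fh, fw, c_in, c_out]
    x, kernel = tupl
    return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='VALID')

# Assume kernels shape is [tf.shape(inp)[0], fh, fw, c_in, c_out]
batch_wise_conv = tf.squeeze(tf.map_fn(
    single_conv, (tf.expand_dims(inp, 1), kernels), dtype=tf.float32),
    axis=1
)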
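It is important to specify dtype for map_fn, because the input is a tuple (x, kernel) while the output is a single tensor, so the output type cannot be inferred from the inputs. Basically, this solution defines batch_dim_size 2D convolution operations, one per sample in the mini-batch.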
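To make concrete what the per-sample convolution computes, here is a minimal NumPy sketch that can serve as a reference when checking the map_fn result on small inputs. The function name batchwise_conv2d_valid is hypothetical (it is not a TensorFlow or NumPy API), and the sketch assumes stride 1 and 'VALID' padding as in the snippet above. Like tf.nn.conv2d, it does not flip the kernel.

```python
import numpy as np

def batchwise_conv2d_valid(inp, kernels):
    """Convolve each sample with its own kernel (stride 1, VALID padding).

    inp:     [batch, h, w, c_in]
    kernels: [batch, fh, fw, c_in, c_out]
    returns: [batch, h - fh + 1, w - fw + 1, c_out]
    """
    batch, h, w, _ = inp.shape
    _, fh, fw, _, c_out = kernels.shape
    oh, ow = h - fh + 1, w - fw + 1
    out = np.zeros((batch, oh, ow, c_out), dtype=inp.dtype)
    # Accumulate the contribution of each kernel offset (i, j).
    for i in range(fh):
        for j in range(fw):
            patch = inp[:, i:i + oh, j:j + ow, :]   # [batch, oh, ow, c_in]
            weights = kernels[:, i, j]              # [batch, c_in, c_out]
            # The shared index b pairs sample b with kernel b.
            out += np.einsum('bhwc,bco->bhwo', patch, weights)
    return out

# Two samples, each with a different 2x2 kernel.
inp = np.ones((2, 3, 3, 1), dtype=np.float32)
kernels = np.stack([
    np.ones((2, 2, 1, 1), dtype=np.float32),
    2.0 * np.ones((2, 2, 1, 1), dtype=np.float32),
])
out = batchwise_conv2d_valid(inp, kernels)
```

The shared batch index in the einsum is what ties each sample to its own kernel, which is the same pairing that map_fn produces by iterating over both tensors along their first axis.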