Keras - passing different parameter for different data point onto Lambda Layer

只愿长相守 posted on 2019-12-24 09:06:58

Question


I am working on a CNN model in Keras with the TensorFlow backend. After the final convolutional layer, I need to pool the output maps from the filters. Instead of using GlobalAveragePooling or any other standard pooling, I have to pool over a range of time frames, which run along the width of the output map.

Say the output of one filter is n x m, where n is the number of time frames and m is the number of outputs along the feature axis. I only need to pool the output from frames n1 to n2, where n1 and n2 <= n, so the slice I pool over has size (n2-n1)*m. I came across the Keras Lambda layer for this, but I am stuck because n1 and n2 are different for each data point. So my question is: how can I pass a custom argument for each data point to a Lambda layer? Or am I approaching this the wrong way?

A sample snippet:

import tensorflow as tf
from keras.layers import Lambda

# for slicing a tensor
def time_based_slicing(x, crop_at):
    dim = x.get_shape()
    len_ = crop_at[1] - crop_at[0]
    return tf.slice(x, [0, crop_at[0], 0, 0], [1, len_, dim[2], dim[3]])

# for output shape
def return_out_shape(input_shape):
    return tuple([input_shape[0], None, input_shape[2], input_shape[3]])

# lambda layer addition
model.add(Lambda(time_based_slicing, output_shape=return_out_shape, arguments={'crop_at': (2, 5)}))

The crop_at argument above needs to be different for each data point when fitting in a loop. Any pointers or clues would be helpful.


Answer 1:


Since you know in advance which time-frame indices belong to each data point, you can store them in a text file and pass them as an additional Input to your model:

slice_input = Input((2,))

Then use those values in your time_based_slicing function.
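One practical wrinkle is that per-sample slices have different lengths, so they cannot be stacked back into a single batch tensor. A common workaround is to keep the full tensor and pool through a mask built from the indices. The NumPy sketch below only illustrates that arithmetic; the function name masked_time_mean, the (batch, n, m) layout, and the half-open [n1, n2) convention (chosen to match len_ = crop_at[1] - crop_at[0] in the question) are assumptions, not part of the original answer.

```python
import numpy as np

def masked_time_mean(x, crop_at):
    """Average each sample over its own frame range.

    x: (batch, n, m) array; crop_at: (batch, 2) integer array holding
    a half-open range [n1, n2) per sample (assumed convention).
    """
    frames = np.arange(x.shape[1])[None, :]                 # (1, n)
    mask = (frames >= crop_at[:, :1]) & (frames < crop_at[:, 1:2])  # (batch, n)
    mask = mask[:, :, None].astype(x.dtype)                 # (batch, n, 1)
    total = (x * mask).sum(axis=(1, 2))                     # sum over kept frames
    count = mask.sum(axis=(1, 2)) * x.shape[2]              # kept frames * m
    return total / np.maximum(count, 1)                     # avoid divide-by-zero

x = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)
crop_at = np.array([[2, 5], [0, 2]])
pooled = masked_time_mean(x, crop_at)

# Reference: slice each sample separately, then average.
expected = np.array([x[0, 2:5].mean(), x[1, 0:2].mean()])
```

Here pooled and expected agree, which is the property that lets a mask stand in for a per-sample slice when the pooling is an average.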




Answer 2:


Switch away from the Sequential API, which starts to fall apart when you need multiple inputs, and use the Functional API instead: https://keras.io/models/model/

Assuming that your lambda functions are correct:

def time_based_slicing(inputs_list):
    x, crop_at = inputs_list
    ...  # will probably need some work to index into crop_at, since it is now a tensor instead of constants

inp = Input(your_shape)
inp_additional = Input((2,))
x = YOUR_CNN_LOGIC(inp)
out = Lambda(time_based_slicing)([x, inp_additional])
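As a rough illustration of how the two-input pattern above might be filled in, here is a minimal sketch using tf.keras. It swaps the literal slice for a mask-based average, because tf.slice cannot return a different length per sample within one batch. The function name masked_time_average, the input shape (8, 4, 1), the single Conv2D standing in for the CNN, and the half-open [n1, n2) convention are all assumptions made for the example.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

def masked_time_average(inputs):
    # x: (batch, n, m, channels); crop_at: (batch, 2) holding [n1, n2)
    x, crop_at = inputs
    frames = tf.range(tf.shape(x)[1])[tf.newaxis, :]        # (1, n)
    start = tf.cast(crop_at[:, 0:1], tf.int32)              # (batch, 1)
    end = tf.cast(crop_at[:, 1:2], tf.int32)                # (batch, 1)
    mask = tf.cast((frames >= start) & (frames < end), x.dtype)
    mask = mask[:, :, tf.newaxis, tf.newaxis]               # (batch, n, 1, 1)
    total = tf.reduce_sum(x * mask, axis=[1, 2])            # (batch, channels)
    # Number of pooled cells per sample: kept frames times m.
    count = tf.reduce_sum(mask, axis=[1, 2]) * tf.cast(tf.shape(x)[2], x.dtype)
    return total / tf.maximum(count, 1.0)

# Hypothetical shapes: 8 time frames, 4 features, 1 channel.
feat_in = layers.Input(shape=(8, 4, 1))
crop_in = layers.Input(shape=(2,), dtype="int32")
x = layers.Conv2D(3, 3, padding="same")(feat_in)            # stand-in for the CNN
pooled = layers.Lambda(masked_time_average)([x, crop_in])
model = Model([feat_in, crop_in], pooled)

# Each sample carries its own (n1, n2) pair alongside its features.
feats = np.random.rand(2, 8, 4, 1).astype("float32")
crops = np.array([[2, 5], [0, 8]], dtype="int32")
out = model.predict([feats, crops], verbose=0)              # shape (2, 3)
```

When training, the crop indices would be passed the same way, as a second array next to the features, for example model.fit([feats, crops], targets).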


Source: https://stackoverflow.com/questions/51557383/keras-passing-different-parameter-for-different-data-point-onto-lambda-layer
