Keras lstm with masking layer for variable-length inputs


I know this is a subject with a lot of questions, but I couldn't find any solution to my problem.

I am training an LSTM network on variable-length inputs using a masking layer.

1 Answer
  • 2020-12-07 18:50

    The Lambda layer, by default, does not propagate masks. In other words, the mask tensor computed by the Masking layer is thrown away by the Lambda layer, and thus the Masking layer has no effect on the output loss.

    If you want the compute_mask method of a Lambda layer to propagate the previous mask, you have to provide the mask argument when the layer is created, as can be seen from the source code of the Lambda layer:

    def __init__(self, function, output_shape=None,
                 mask=None, arguments=None, **kwargs):
        # ...
        if mask is not None:
            self.supports_masking = True
        self.mask = mask
    
    # ...
    
    def compute_mask(self, inputs, mask=None):
        if callable(self.mask):
            return self.mask(inputs, mask)
        return self.mask
    

    Because the default value of mask is None, compute_mask returns None and the loss is not masked at all.
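
    This can be checked directly by calling compute_mask on a plain Lambda layer. The following is a minimal sketch with made-up shapes (timesteps=10, features=3, and a slice offset of 2 standing in for N):

    from keras.layers import Input, Masking, Lambda

    inp = Input(shape=(10, 3))                             # hypothetical (timesteps, features)
    prev_mask = Masking(mask_value=0.).compute_mask(inp)   # boolean mask of shape (batch, 10)

    plain_lambda = Lambda(lambda x: x[:, 2:, :])           # no mask argument given
    print(plain_lambda.compute_mask(inp, prev_mask))       # None -- the incoming mask is discarded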

    To fix the problem, since your Lambda layer itself does not introduce any additional masking, the compute_mask method should just return the mask from the previous layer (with appropriate slicing to match the output shape of the layer).

    from keras.models import Sequential
    from keras.layers import Masking, LSTM, Lambda

    masking_func = lambda inputs, previous_mask: previous_mask[:, N:]
    model = Sequential()
    model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
    model.add(LSTM(128, return_sequences=True))
    model.add(LSTM(64, return_sequences=True))
    model.add(LSTM(1, return_sequences=True))
    model.add(Lambda(lambda x: x[:, N:, :], mask=masking_func))
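
    As a quick sanity check (a sketch with the same made-up shapes as above), a Lambda layer constructed with the mask argument now propagates a sliced mask instead of None:

    from keras.layers import Input, Masking, Lambda

    N = 2                                                  # hypothetical offset
    inp = Input(shape=(10, 3))                             # hypothetical (timesteps, features)
    prev_mask = Masking(mask_value=0.).compute_mask(inp)   # boolean mask of shape (batch, 10)

    masking_func = lambda inputs, previous_mask: previous_mask[:, N:]
    masked_lambda = Lambda(lambda x: x[:, N:, :], mask=masking_func)
    print(masked_lambda.compute_mask(inp, prev_mask))      # a (batch, 10 - N) boolean tensor, not None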
    

    Now you should be able to see the correct loss value.

    >> model.evaluate(x_test, y_test, verbose=0)
    0.2660679519176483
    >> out = model.predict(x_test)
    >> print('wo mask', mean_absolute_error(y_test.ravel(), out.ravel()))
    wo mask 0.26519736809498456
    >> print('w mask', mean_absolute_error(y_test[~(x_test[:,N:] == 0).all(axis=2)].ravel(), out[~(x_test[:,N:] == 0).all(axis=2)].ravel()))
    w mask 0.2660679670482195
    

    Using NaN values for padding does not work, because masking is done by multiplying the loss tensor with a binary mask (0 * nan is still nan, so the mean value would still be nan).
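
    To make the arithmetic concrete, here is a tiny NumPy sketch (with made-up per-timestep losses) of how a masked mean behaves with zero padding versus NaN padding:

    import numpy as np

    mask = np.array([1., 1., 0., 0.])                  # 1 = real timestep, 0 = padding
    loss_zero_pad = np.array([0.3, 0.1, 0.0, 0.0])     # padded steps contribute a finite loss
    loss_nan_pad = np.array([0.3, 0.1, np.nan, np.nan])

    print((loss_zero_pad * mask).sum() / mask.sum())   # 0.2 -- padding is masked out correctly
    print((loss_nan_pad * mask).sum() / mask.sum())    # nan -- 0 * nan is nan, so the mean is nan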
