Keras LSTM with Masking layer for variable-length inputs

我们两清 submitted on 2019-12-02 19:40:25

The Lambda layer, by default, does not propagate masks. In other words, the mask tensor computed by the Masking layer is thrown away by the Lambda layer, and thus the Masking layer has no effect on the output loss.

If you want the compute_mask method of a Lambda layer to propagate the previous mask, you have to provide the mask argument when the layer is created. As can be seen from the source code of the Lambda layer:

def __init__(self, function, output_shape=None,
             mask=None, arguments=None, **kwargs):
    # ...
    if mask is not None:
        self.supports_masking = True
    self.mask = mask

# ...

def compute_mask(self, inputs, mask=None):
    if callable(self.mask):
        return self.mask(inputs, mask)
    return self.mask

Because the default value of mask is None, compute_mask returns None and the loss is not masked at all.
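
The effect of that default can be reproduced without Keras. The sketch below copies only the two quoted methods into a hypothetical stand-in class (LambdaMaskLogic is not a Keras class) and uses a plain list in place of a mask tensor, so treat it as an illustration of the logic rather than of the real layer.

```python
class LambdaMaskLogic:
    """Hypothetical stand-in that mirrors only the mask handling quoted above."""

    def __init__(self, mask=None):
        self.supports_masking = mask is not None
        self.mask = mask

    def compute_mask(self, inputs, mask=None):
        if callable(self.mask):
            return self.mask(inputs, mask)
        return self.mask


previous_mask = [True, True, False]  # mask produced by an earlier layer

# Default construction: the incoming mask is dropped.
dropped = LambdaMaskLogic().compute_mask(None, previous_mask)

# With a callable, the incoming mask is handed to it and passed on.
kept = LambdaMaskLogic(mask=lambda inputs, m: m).compute_mask(None, previous_mask)

print(dropped)  # None
print(kept)     # [True, True, False]
```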

Since your Lambda layer does not introduce any additional masking itself, the fix is for compute_mask to simply return the mask from the previous layer, sliced to match the output shape of the layer:

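from keras.models import Sequential
from keras.layers import Masking, LSTM, Lambda

# N, timesteps and features are assumed to be defined already.
# The mask is sliced exactly like the output so the two stay aligned.
masking_func = lambda inputs, previous_mask: previous_mask[:, N:]
model = Sequential()
model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64, return_sequences=True))
model.add(LSTM(1, return_sequences=True))
model.add(Lambda(lambda x: x[:, N:, :], mask=masking_func))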
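
To see why the mask has to be sliced the same way as the output, here is a small sketch that uses nested lists in place of tensors. The sizes and the value of N are made up for illustration, and slice_time is a hypothetical helper that plays the role of the [:, N:] indexing.

```python
N = 2  # illustrative value only

# One sample, five timesteps; the last two are padding.
outputs = [[0.1, 0.2, 0.3, 0.4, 0.5]]
previous_mask = [[True, True, True, False, False]]


def slice_time(batch, start):
    """Drop the first `start` timesteps of every sample (like batch[:, start:])."""
    return [sample[start:] for sample in batch]


sliced_outputs = slice_time(outputs, N)
sliced_mask = slice_time(previous_mask, N)

# Each remaining output step lines up with its own mask entry.
print(list(zip(sliced_outputs[0], sliced_mask[0])))
# [(0.3, True), (0.4, False), (0.5, False)]
```

If the mask were passed through unsliced, it would have five entries per sample against three output steps, and the two could no longer be matched up step by step.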

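Now model.evaluate should report the correct, masked loss. It agrees with the mean absolute error computed manually over the non-padded timesteps ('w mask' below), not with the unmasked one ('wo mask'):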

>> model.evaluate(x_test, y_test, verbose=0)
0.2660679519176483
>> out = model.predict(x_test)
>> print('wo mask', mean_absolute_error(y_test.ravel(), out.ravel()))
wo mask 0.26519736809498456
>> print('w mask', mean_absolute_error(y_test[~(x_test[:,N:] == 0).all(axis=2)].ravel(), out[~(x_test[:,N:] == 0).all(axis=2)].ravel()))
w mask 0.2660679670482195
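
The 'w mask' line above packs the whole check into one expression. Spelled out as a loop, the idea might look like the sketch below. It is written in plain Python on toy data, masked_mae is a hypothetical helper, and it assumes one scalar target per timestep and inputs already sliced to the same timesteps as the targets.

```python
def masked_mae(x, y_true, y_pred, mask_value=0.0):
    """MAE over timesteps whose input features are not all equal to mask_value."""
    total, count = 0.0, 0
    for x_seq, t_seq, p_seq in zip(x, y_true, y_pred):
        for feats, t, p in zip(x_seq, t_seq, p_seq):
            if all(f == mask_value for f in feats):
                continue  # padded timestep, skipped like a masked one
            total += abs(t - p)
            count += 1
    return total / count


# Toy data: one sequence, three timesteps, two features; the last step is padding.
x = [[[1.0, 2.0], [0.0, 3.0], [0.0, 0.0]]]
y_true = [[1.0, 2.0, 0.0]]
y_pred = [[1.5, 2.0, 9.0]]

print(masked_mae(x, y_true, y_pred))  # 0.25
```

Note that the second timestep is kept even though one of its features is 0: a step counts as padding only when every feature equals the mask value, which is what .all(axis=2) expresses in the original one-liner.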

Using NaN as the padding value does not work, because masking is done by multiplying the loss tensor by a binary mask: 0 * nan is still nan, so the mean would be nan.
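
A quick way to convince yourself of this is to do the multiplication by hand. The sketch below is a simplified model of masking by multiplication in plain Python; masked_mean is a hypothetical helper and the exact normalization Keras applies may differ, but the behaviour of 0 * nan is the same.

```python
import math

nan = float("nan")


def masked_mean(losses, mask):
    """Mean of loss * mask, normalized by the number of unmasked steps."""
    weighted = [loss * m for loss, m in zip(losses, mask)]
    return sum(weighted) / sum(mask)


mask = [1.0, 1.0, 0.0]  # last timestep is padding

# Finite loss on the padded step: multiplying by 0 removes it.
finite_result = masked_mean([0.25, 0.75, 5.0], mask)
print(finite_result)  # 0.5

# NaN loss on the padded step: 0 * nan is nan, and it poisons the sum.
nan_result = masked_mean([0.25, 0.75, nan], mask)
print(math.isnan(nan_result))  # True
```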
