I know this is a subject with a lot of questions, but I couldn't find any solution to my problem. I am training an LSTM network on variable-length inputs using a masking layer.
The `Lambda` layer, by default, does not propagate masks. In other words, the mask tensor computed by the `Masking` layer is thrown away by the `Lambda` layer, and thus the `Masking` layer has no effect on the output loss.
If you want the `compute_mask` method of a `Lambda` layer to propagate the previous mask, you have to provide the `mask` argument when the layer is created, as can be seen from the source code of the `Lambda` layer:
```python
def __init__(self, function, output_shape=None,
             mask=None, arguments=None, **kwargs):
    # ...
    if mask is not None:
        self.supports_masking = True
    self.mask = mask
    # ...

def compute_mask(self, inputs, mask=None):
    if callable(self.mask):
        return self.mask(inputs, mask)
    return self.mask
```
Because the default value of `mask` is `None`, `compute_mask` returns `None` and the loss is not masked at all.
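You can check this directly against the source quoted above. A minimal sketch (the layer and mask here are stand-ins, not part of the model below): whatever mask comes in, `None` comes out.

```python
import numpy as np
from keras.layers import Lambda

lam = Lambda(lambda x: x)                 # no mask argument supplied
incoming = np.ones((2, 10), dtype=bool)   # stand-in for a Masking layer's mask
print(lam.compute_mask(None, incoming))   # None -- the incoming mask is discarded
```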
To fix the problem, since your `Lambda` layer itself does not introduce any additional masking, the `compute_mask` method should just return the mask from the previous layer (with appropriate slicing to match the output shape of the layer):
```python
from keras.models import Sequential
from keras.layers import Masking, LSTM, Lambda

# N, timesteps and features are assumed defined as in the question.
# The Lambda layer drops the first N timesteps, so the mask is sliced
# the same way to stay aligned with the layer's output.
masking_func = lambda inputs, previous_mask: previous_mask[:, N:]

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64, return_sequences=True))
model.add(LSTM(1, return_sequences=True))
model.add(Lambda(lambda x: x[:, N:, :], mask=masking_func))
```
Now you should be able to see the correct loss value:

```python
>>> model.evaluate(x_test, y_test, verbose=0)
0.2660679519176483
>>> out = model.predict(x_test)
>>> print('wo mask', mean_absolute_error(y_test.ravel(), out.ravel()))
wo mask 0.26519736809498456
>>> print('w mask', mean_absolute_error(y_test[~(x_test[:, N:] == 0).all(axis=2)].ravel(),
...                                     out[~(x_test[:, N:] == 0).all(axis=2)].ravel()))
w mask 0.2660679670482195
```
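If you want to sanity-check the mask propagation itself, a quick sketch like the one below (shapes are made up for illustration: a batch of 2, 10 timesteps, and N = 3) shows that the fixed layer now forwards the sliced mask instead of `None`:

```python
import numpy as np
from keras.layers import Lambda

N = 3                                     # hypothetical slice offset
masking_func = lambda inputs, previous_mask: previous_mask[:, N:]

lam = Lambda(lambda x: x[:, N:, :], mask=masking_func)
prev_mask = np.ones((2, 10), dtype=bool)  # stand-in for the upstream mask
print(lam.compute_mask(None, prev_mask).shape)  # (2, 7): mask sliced to match
```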
Using NaN values for padding does not work because masking is done by multiplying the loss tensor with a binary mask (`0 * nan` is still `nan`, so the mean value would be `nan`).
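A tiny NumPy illustration of that failure mode (the loss values here are made up):

```python
import numpy as np

losses = np.array([0.3, 0.1, np.nan])  # per-timestep losses; NaN from padding
mask = np.array([1., 1., 0.])          # binary mask zeroing the padded step
print((losses * mask).mean())          # nan -- 0 * nan is still nan
```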