Checkpointing keras model: TypeError: can't pickle _thread.lock objects

Asked by 时光说笑 on 2020-12-01 11:37

It seems like this error has occurred in the past in different contexts here, but I'm not dumping the model directly -- I'm using the ModelCheckpoint callback. Any idea what might be going on?

3 Answers
Answered by -上瘾入骨i on 2020-12-01 11:44

    To clarify: this is not a problem of Keras being unable to pickle a Tensor in a Lambda layer (other scenarios are possible, see below), but rather that the arguments of the Python function (here: a lambda) are serialized independently of the function itself (here: outside the context of the lambda). This works for 'static' arguments but fails otherwise. To circumvent it, wrap the non-static function arguments in another function.
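
    For context, here is a minimal sketch of the pattern that triggers the error (my own illustration, not from the question; exact behavior may vary across Keras/TF versions): TF tensors captured directly by the lambda end up in the serialized closure:

    import numpy as np
    import tensorflow as tf
    from keras.layers import Input, Lambda
    from keras.models import Model
    
    low = tf.constant(np.random.rand(30, 3).astype('float32'))
    high = tf.constant(1 + np.random.rand(30, 3).astype('float32'))
    
    x = Input(shape=(30, 3))
    # the lambda closes over the TF tensors `low` and `high`; serializing the
    # model (e.g. from ModelCheckpoint) tries to pickle them and can raise
    # "TypeError: can't pickle _thread.lock objects"
    clipped = Lambda(lambda t: tf.clip_by_value(t, low, high))(x)
    Model(inputs=x, outputs=clipped).save('will_fail.h5')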

    Here are a few workarounds:


    1. Use static variables, such as plain Python/NumPy values (as mentioned above):
    import numpy as np
    import tensorflow as tf
    from keras.layers import Input, Lambda
    
    # NumPy arrays are picklable, so capturing them directly is fine
    low = np.random.rand(30, 3)
    high = 1 + np.random.rand(30, 3)
    
    x = Input(shape=(30,3))
    clipped_out_position = Lambda(lambda x: tf.clip_by_value(x, low, high))(x)
    

    2. Use functools.partial to wrap your lambda function:
    import functools
    
    # `low` and `high` are the NumPy arrays defined in the previous snippet
    clip_by_value = functools.partial(
        tf.clip_by_value,
        clip_value_min=low,
        clip_value_max=high)
    
    x = Input(shape=(30,3))
    clipped_out_position = Lambda(lambda x: clip_by_value(x))(x)
    

    3. Use a closure to wrap your lambda function:
    low = tf.constant(np.random.rand(30, 3).astype('float32'))
    high = tf.constant(1 + np.random.rand(30, 3).astype('float32'))
    
    # the TF constants are referenced only inside this named function, so
    # the serialized lambda itself does not capture them
    def clip_by_value(t):
        return tf.clip_by_value(t, low, high)
    
    x = Input(shape=(30,3))
    clipped_out_position = Lambda(lambda x: clip_by_value(x))(x)
    

    Note: although you can sometimes drop the explicit lambda function and write this cleaner snippet instead:

    clipped_out_position = Lambda(clip_by_value)(x)
    

    the absence of the extra wrapping lambda (that is, lambda t: clip_by_value(t)) may still lead to the same problem when the function arguments are deep-copied, so it should be avoided.


    4. Finally, you can wrap your model logic into a separate Keras layer, which in this particular case may look a bit over-engineered:
    x = Input(shape=(30,3))
    low = Lambda(lambda t: tf.constant(np.random.rand(30, 3).astype('float32')))(x)
    high = Lambda(lambda t: tf.constant(1 + np.random.rand(30, 3).astype('float32')))(x)
    clipped_out_position = Lambda(lambda x: tf.clip_by_value(*x))((x, low, high))
    

    Note: the tf.clip_by_value(*x) in the last Lambda layer is just an unpacked argument tuple; it can also be written more verbosely as tf.clip_by_value(x[0], x[1], x[2]).
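
    Whichever workaround you choose, a quick sanity check (a sketch; clip_model.h5 is an arbitrary file name) is to save the model once before wiring up ModelCheckpoint, using the x and clipped_out_position tensors from any of the snippets above:

    from keras.models import Model
    
    model = Model(inputs=x, outputs=clipped_out_position)
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.save('clip_model.h5')  # should no longer raise the pickling TypeError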


    As a side note: a scenario where your lambda function tries to capture (a part of) a class instance will also break serialization, due to late binding:

    import numpy as np
    import tensorflow as tf
    from keras.layers import Input, Lambda
    from keras.models import Model
    from keras.optimizers import Adam
    from keras.callbacks import ModelCheckpoint
    
    class MyModel:
        def __init__(self):
            self.low = np.random.rand(30, 3)
            self.high = 1 + np.random.rand(30, 3)
    
        def run(self):
            x = Input(shape=(30,3))
            # the lambda closes over `self`; `self.low` and `self.high` are
            # looked up late, so serialization drags the whole instance along
            clipped_out_position = Lambda(lambda x: tf.clip_by_value(x, self.low, self.high))(x)
            model = Model(inputs=x, outputs=[clipped_out_position])
            optimizer = Adam(lr=.1)
            model.compile(optimizer=optimizer, loss="mean_squared_error")
            checkpoint = ModelCheckpoint("debug.hdf", monitor="val_loss", verbose=1, save_best_only=True, mode="min")
            training_callbacks = [checkpoint]
            model.fit(np.random.rand(100, 30, 3), 
                     [np.random.rand(100, 30, 3)], callbacks=training_callbacks, epochs=50, batch_size=10, validation_split=0.33)
    
    MyModel().run()
    

    This can be solved by forcing early binding with the default-argument trick:

            (...)
            clipped_out_position = Lambda(lambda x, l=self.low, h=self.high: tf.clip_by_value(x, l, h))(x)
            (...)
    
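    The underlying Python behavior is easy to demonstrate in isolation (a generic illustration, independent of Keras): a closure looks names up when the function runs, whereas a default argument is evaluated once, when the function is defined:

    # late binding: every lambda sees the final value of i
    late = [lambda: i for i in range(3)]
    print([f() for f in late])    # [2, 2, 2]
    
    # early binding: the default argument captures i at definition time
    early = [lambda i=i: i for i in range(3)]
    print([f() for f in early])   # [0, 1, 2]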
