Question
I'm running this TensorFlow NMT tutorial: https://github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb
When I try to save the decoder:
decoder.save('decoder')
I get:
AssertionError: Tried to export a function which references untracked object Tensor("LuongAttention/memory_layer/Tensordot:0", shape=(1024, 23, 256), dtype=float32).TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
I also tried registering the LuongAttention object like this:
attention_mechanism = tfa.seq2seq.LuongAttention(units=units, memory=None, memory_sequence_length=BATCH_SIZE*[max_length_input])
custom_objects = {"LuongAttention": attention_mechanism}
with tf.keras.utils.custom_object_scope(custom_objects):
    decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, attention_mechanism)
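The tracking rule quoted in the error message can be illustrated with a small `tf.Module` that is unrelated to the tutorial's `Decoder`. In this sketch (the `Scaler` class is hypothetical), the variable captured by the `tf.function` is assigned to an attribute of the module, so `tf.saved_model.save` can find and export it. A variable that the function reached only through a Python global, rather than through an attribute, would not be tracked by the module.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module used only to show the tracking rule."""

    def __init__(self):
        super().__init__()
        # Assigning the variable to an attribute of a tf.Module makes it tracked,
        # so tf.saved_model.save can find it and write it to disk.
        self.factor = tf.Variable(2.0, name="factor")

    # A fixed input_signature lets the function be exported without example calls.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        # The function captures self.factor, which is reachable from the module.
        return x * self.factor


module = Scaler()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# The restored object exposes the exported __call__ and the saved variable.
loaded = tf.saved_model.load(export_dir)
restored_output = loaded(tf.constant([1.0, 2.0]))
print(restored_output.numpy())  # [2. 4.]
```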
Answer 1:
To save and load a model that contains custom layers, or a subclassed model, you should override the get_config method and, optionally, the from_config method. You should also register the custom object so that Keras can find it when the model is loaded.
See the Keras guide on saving and serialization: https://www.tensorflow.org/guide/keras/save_and_serialize
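A minimal sketch of the pattern described in the answer, assuming TensorFlow 2.x with Keras. The `ScaledDense` layer, its `scale` argument, the `Demo` package name, and the file name are hypothetical and not part of the NMT tutorial; the sketch illustrates the general pattern and does not reproduce the tutorial's `Decoder`. The layer returns its constructor arguments from `get_config`, relies on the default `from_config` (which passes that dict back to the constructor), and is registered with `tf.keras.utils.register_keras_serializable`, so `tf.keras.models.load_model` can rebuild it without a `custom_objects` argument.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Registration records the class under "Demo>ScaledDense", so load_model can
# map the saved class name back to this Python class.
@tf.keras.utils.register_keras_serializable(package="Demo")
class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical custom layer: a bias-free dense projection times a constant."""

    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)  # forwards name, dtype, trainable, ...
        self.units = units
        self.scale = scale

    def build(self, input_shape):
        # Weights created with add_weight are tracked and saved with the model.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) * self.scale

    def get_config(self):
        # Include every constructor argument; the default from_config passes
        # this dict back to __init__ when the model is loaded.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


inputs = tf.keras.Input(shape=(4,))
outputs = ScaledDense(3, scale=2.0)(inputs)
model = tf.keras.Model(inputs, outputs)

x = np.random.rand(2, 4).astype("float32")
expected_output = np.asarray(model(x))

path = os.path.join(tempfile.mkdtemp(), "scaled_model.keras")
model.save(path)

# No custom_objects argument is needed because the class is registered.
restored = tf.keras.models.load_model(path)
restored_output = np.asarray(restored(x))
print(np.allclose(restored_output, expected_output))  # True
```

Without global registration, the per-call alternative is `tf.keras.models.load_model(path, custom_objects={"ScaledDense": ScaledDense})`. Note that `custom_objects` maps a name to the class itself, whereas the attempt in the question maps `"LuongAttention"` to an already constructed instance.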
Source: https://stackoverflow.com/questions/64847875/tf-keras-model-save-assertionerror-tried-to-export-a-function-which-reference