tf.keras.Model save: “AssertionError: Tried to export a function which references untracked object Tensor”

Submitted by 一曲冷凌霜 on 2021-01-29 08:12:32

Question


I'm running this Tensorflow NMT tutorial: https://github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

When I try to save the decoder with decoder.save('decoder'), I get:

AssertionError: Tried to export a function which references untracked object Tensor("LuongAttention/memory_layer/Tensordot:0", shape=(1024, 23, 256), dtype=float32).TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

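I also tried registering the LuongAttention object like this:

```python
import tensorflow as tf
import tensorflow_addons as tfa
# units, BATCH_SIZE, max_length_input, vocab_tar_size, embedding_dim and Decoder are defined in the tutorial notebook.
attention_mechanism = tfa.seq2seq.LuongAttention(units=units, memory=None, memory_sequence_length=BATCH_SIZE*[max_length_input])
# Note: custom_objects normally maps a name to a class or function, not to an instance.
custom_objects = {"LuongAttention": attention_mechanism}
with tf.keras.utils.custom_object_scope(custom_objects):
    decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, attention_mechanism)
```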
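The error message states the rule the SavedModel exporter enforces: a TensorFlow object captured by a traced function must be reachable as an attribute of a tracked object. Below is a minimal sketch of a module that follows that rule, using a variable. The `Scaler` class and its values are hypothetical, and the sketch does not reproduce the tutorial's captured `LuongAttention` tensor.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module whose exported function captures a variable."""

    def __init__(self):
        super().__init__()
        # Assigning the variable to an attribute of the module makes it
        # tracked, so the exporter can save it together with the function.
        self.scale = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        # The traced graph captures self.scale, which is tracked.
        return x * self.scale


def export_and_reload(module):
    """Save the module as a SavedModel in a temporary directory and load it back."""
    path = tempfile.mkdtemp()
    tf.saved_model.save(module, path)
    return tf.saved_model.load(path)
```

Had `scale` been a module-level variable instead of `self.scale`, the exporter would typically refuse to save the function with an error similar to the one above (the exact wording varies between TensorFlow versions).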

Answer 1:


In order to save/load a model with custom-defined layers, or a subclassed model, you should override the get_config and optionally the from_config methods. Additionally, you should register the custom object so that Keras is aware of it.

See here: https://www.tensorflow.org/guide/keras/save_and_serialize
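A minimal sketch of this advice for a functional model that contains a custom layer: `get_config` returns the constructor arguments, and `tf.keras.utils.register_keras_serializable` registers the class so that `load_model` can find it. The `ScaledDense` layer and the `demo` package name are hypothetical; this is not the tutorial's `Decoder`, and whether these steps alone fix the tutorial's error is not established here.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Registering under a package name lets Keras look the class up at load time
# without passing custom_objects explicitly.
@tf.keras.utils.register_keras_serializable(package="demo")
class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical custom layer: a Dense projection followed by a constant scale."""

    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale
        # Sub-layers assigned as attributes are tracked automatically.
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs) * self.scale

    def get_config(self):
        # Include every constructor argument so the layer can be rebuilt.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


def build_model():
    inputs = tf.keras.Input(shape=(4,))
    outputs = ScaledDense(3, scale=2.0)(inputs)
    return tf.keras.Model(inputs, outputs)


def save_and_reload(model):
    # Recent releases write the native Keras format for a .keras suffix;
    # older 2.x releases may write HDF5 instead. Both rely on get_config.
    path = os.path.join(tempfile.mkdtemp(), "model.keras")
    model.save(path)
    return tf.keras.models.load_model(path)
```

`from_config` is not overridden here: the inherited version builds the layer from the dictionary returned by `get_config`, which is enough when that dictionary holds only plain Python values. Overriding it becomes necessary when a constructor argument is itself an object, such as a layer or an attention mechanism, that has to be deserialized first.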



Source: https://stackoverflow.com/questions/64847875/tf-keras-model-save-assertionerror-tried-to-export-a-function-which-reference
