Load a pre-trained Keras model for continued training on Google Cloud

社会主义新天地 submitted on 2020-02-25 04:43:09

Question


I am trying to load a pre-trained Keras model for continued training on Google Cloud. It works locally by simply loading the discriminator and generator with:

from tensorflow.keras.models import load_model  # or keras.models with standalone Keras
model = load_model('myPretrainedModel.h5')

But obviously this doesn't work on Google Cloud. I have tried using the same method I use to read the training data from my Google Cloud Storage bucket:

fil = "gs://mygcbucket/myPretrainedModel.h5"    
f = BytesIO(file_io.read_file_to_string(fil, binary_mode=True))
return np.load(f)

However, this doesn't seem to work for loading a model; I get the following error when running the job:

ValueError: Cannot load file containing pickled data when allow_pickle=False

Adding allow_pickle=True throws another error:

OSError: Failed to interpret file <_io.BytesIO object at 0x7fdf2bb42620> as a pickle
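Both errors follow from how np.load works: it recognizes .npy and .npz files by their leading "magic" bytes and treats anything else as a pickle. With allow_pickle=False it refuses (the ValueError), and with allow_pickle=True pickle cannot parse the data (the OSError). A Keras .h5 model is an HDF5 file, so np.load cannot read it whatever the flag. The sketch below checks a file's first bytes to show which case it falls into; sniff_format and the constants are hypothetical helpers, not part of NumPy or TensorFlow.

```python
# File signatures ("magic bytes") found at the start of each format.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"  # HDF5 superblock signature (Keras .h5 files)
NPY_MAGIC = b"\x93NUMPY"                # single-array .npy files
ZIP_MAGIC = b"PK\x03\x04"               # .npz archives are zip files


def sniff_format(head: bytes) -> str:
    """Guess a file's format from its first few bytes."""
    if head.startswith(HDF5_SIGNATURE):
        return "hdf5"
    if head.startswith(NPY_MAGIC):
        return "npy"
    if head.startswith(ZIP_MAGIC):
        return "npz"
    # Anything else is what np.load hands to pickle, which is why an .h5
    # model triggers the allow_pickle errors above.
    return "other"


# Hypothetical check on the job (requires TensorFlow):
# head = file_io.read_file_to_string(fil, binary_mode=True)[:8]
# print(sniff_format(head))  # a Keras .h5 model is expected to report "hdf5"
```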

I then tried something someone suggested for a similar issue. As I understand it, this copies the model from the bucket to a temporary local file (local to where the job is running) and then loads it from there:

fil = "gs://mygcbucket/myPretrainedModel.h5"  
model_file = file_io.FileIO(fil, mode='rb')
file_stream = file_io.FileIO(model_file, mode='r')  # model_file is already a FileIO object, not a path string
temp_model_location = './temp_model.h5'
temp_model_file = open(temp_model_location, 'wb')
temp_model_file.write(file_stream.read())
temp_model_file.close()
file_stream.close()
model = load_model(temp_model_location)
return model

However, this throws the following error:

TypeError: Expected binary or unicode string, got tensorflow.python.lib.io.file_io.FileIO object
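The TypeError comes from file_io.FileIO(model_file, mode='r'): FileIO expects a path string, but model_file is already an open FileIO object. The second open also uses text mode 'r', while the data is written to a binary file. The copy-then-load idea itself is sound; the sketch below opens the source once, by path, in binary mode and streams it to a local file. copy_to_local is a hypothetical helper; open_fn is assumed to be something like tf.io.gfile.GFile or file_io.FileIO on the job (the built-in open works for local paths). TensorFlow also provides tf.io.gfile.copy(src, dst, overwrite=True), which performs the same copy in a single call.

```python
import os
import shutil
import tempfile


def copy_to_local(src_path, open_fn=open, local_dir=None):
    """Copy src_path to a local file and return the local path.

    open_fn must accept (path, mode) and return a binary file object usable
    as a context manager, e.g. the built-in open for local files, or
    tf.io.gfile.GFile / file_io.FileIO for gs:// paths (assumption).
    """
    # Use a fresh temporary directory unless the caller supplies one.
    if local_dir is None:
        local_dir = tempfile.mkdtemp()
    local_path = os.path.join(local_dir, os.path.basename(src_path))
    # Open the source once, by path, in binary mode, and stream it in chunks
    # rather than reading the whole model into memory.
    with open_fn(src_path, "rb") as src, open(local_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return local_path


# Usage on the training job (requires TensorFlow):
# import tensorflow as tf
# local = copy_to_local("gs://mygcbucket/myPretrainedModel.h5", open_fn=tf.io.gfile.GFile)
# model = tf.keras.models.load_model(local)
```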

I must admit I am not really sure what I need to do to actually load a pre-trained Keras model from my storage bucket and then use it in my training job on Google Cloud. Any help is deeply appreciated.


Answer 1:


I would suggest using AI Platform Notebooks for this. Download the trained model using this method; check the Python code under the Code samples tab. Once the model is on the VM where the notebook is running, you can load it as you did locally. Here is an example that uses tf.keras.models.load_model.
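The download step can be sketched with the google-cloud-storage client library (storage.Client, Client.bucket, Bucket.blob, Blob.download_to_filename). This assumes that package is installed and that default credentials are available on the VM; split_gcs_uri and download_model are hypothetical helper names. The library is imported inside download_model so the URI parsing works without it.

```python
from typing import Tuple


def split_gcs_uri(uri: str) -> Tuple[str, str]:
    """Split 'gs://bucket/path/to/object' into ('bucket', 'path/to/object')."""
    prefix = "gs://"
    if not uri.startswith(prefix):
        raise ValueError(f"not a gs:// URI: {uri!r}")
    bucket, _, blob = uri[len(prefix):].partition("/")
    if not bucket or not blob:
        raise ValueError(f"URI must name a bucket and an object: {uri!r}")
    return bucket, blob


def download_model(uri: str, local_path: str) -> str:
    """Download a gs:// object to local_path and return local_path."""
    # Assumes google-cloud-storage is installed; imported here so that
    # split_gcs_uri can be used without it.
    from google.cloud import storage

    bucket_name, blob_name = split_gcs_uri(uri)
    client = storage.Client()  # picks up the VM's default credentials
    client.bucket(bucket_name).blob(blob_name).download_to_filename(local_path)
    return local_path


# Usage on the notebook VM (requires TensorFlow and google-cloud-storage):
# path = download_model("gs://mygcbucket/myPretrainedModel.h5", "myPretrainedModel.h5")
# model = tf.keras.models.load_model(path)
```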



Source: https://stackoverflow.com/questions/59927900/load-pre-trained-keras-model-for-continued-training-on-google-cloud
