Question:
Following the upgrade to Keras 2.0.9, I have been using the multi_gpu_model utility, but I can't save my models or best weights using
model.save('path')
The error I get is
TypeError: can't pickle module objects
I suspect there is some problem gaining access to the model object. Is there a workaround for this issue?
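For reference, a minimal sketch of the setup that triggers this (create_model() and the GPU count are placeholders; the exact import path of multi_gpu_model can differ slightly between Keras versions):

from keras.utils import multi_gpu_model   # may live in keras.utils.training_utils in some versions

model = create_model()                     # placeholder for your model-building code
parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(loss='categorical_crossentropy', optimizer='adam')
# ... train with parallel_model.fit(...) ...
parallel_model.save('model.h5')            # raises: TypeError: can't pickle module objects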
Answer 1:
Workaround
Here's a patched version that doesn't fail while saving:
from keras.layers import Lambda, concatenate
from keras import Model

import tensorflow as tf

def multi_gpu_model(model, gpus):
    if isinstance(gpus, (list, tuple)):
        num_gpus = len(gpus)
        target_gpu_ids = gpus
    else:
        num_gpus = gpus
        target_gpu_ids = range(num_gpus)

    def get_slice(data, i, parts):
        shape = tf.shape(data)
        batch_size = shape[:1]
        input_shape = shape[1:]
        step = batch_size // parts
        if i == num_gpus - 1:
            size = batch_size - step * i
        else:
            size = step
        size = tf.concat([size, input_shape], axis=0)
        stride = tf.concat([step, input_shape * 0], axis=0)
        start = stride * i
        return tf.slice(data, start, size)

    all_outputs = []
    for i in range(len(model.outputs)):
        all_outputs.append([])

    # Place a copy of the model on each GPU,
    # each getting a slice of the inputs.
    for i, gpu_id in enumerate(target_gpu_ids):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('replica_%d' % gpu_id):
                inputs = []
                # Retrieve a slice of the input.
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_i = Lambda(get_slice,
                                     output_shape=input_shape,
                                     arguments={'i': i, 'parts': num_gpus})(x)
                    inputs.append(slice_i)

                # Apply model on slice
                # (creating a model replica on the target device).
                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save the outputs for merging back together later.
                for o in range(len(outputs)):
                    all_outputs[o].append(outputs[o])

    # Merge outputs on CPU.
    with tf.device('/cpu:0'):
        merged = []
        for name, outputs in zip(model.output_names, all_outputs):
            merged.append(concatenate(outputs, axis=0, name=name))
        return Model(model.inputs, merged)
You can use this multi_gpu_model function until the bug is fixed in Keras. Also, when loading the model, it's important to provide the tensorflow module object:
model = load_model('multi_gpu_model.h5', {'tf': tf})
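A minimal end-to-end sketch of how the patched function above would be used; create_model(), the GPU count, the loss/optimizer, and the file name are placeholders:

import tensorflow as tf
from keras.models import load_model

# Build the template model on the CPU, then wrap it with the patched multi_gpu_model above.
with tf.device('/cpu:0'):
    model = create_model()                 # placeholder for your model-building code
parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(loss='categorical_crossentropy', optimizer='adam')
# ... train with parallel_model.fit(...) ...

parallel_model.save('multi_gpu_model.h5')                 # now saves without the pickle error
restored = load_model('multi_gpu_model.h5', {'tf': tf})   # tf must be passed as a custom object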
How it works
The problem is with the import tensorflow line in the middle of multi_gpu_model:
def multi_gpu_model(model, gpus):
    ...
    import tensorflow as tf
    ...
This creates a closure for the get_slice lambda function, which includes the number of GPUs (that's ok) and the tensorflow module (not ok). Saving the model tries to serialize all layers, including the ones that call get_slice, and it fails precisely because tf is in the closure.
The solution is to move the import out of multi_gpu_model, so that tf becomes a global object, though it is still needed for get_slice to work. This fixes saving, but when loading one has to provide tf explicitly.
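To see why the closure matters, here is a small standalone sketch, independent of Keras internals: only the names a nested function actually references end up in its closure, and a module object among them cannot be pickled (the exact error wording varies by Python version).

import pickle

def build(num_gpus):
    import tensorflow as tf                        # local import: tf becomes a free variable of get_slice
    def get_slice(data, i):
        step = tf.shape(data)[0] // num_gpus       # num_gpus is captured too (that part is fine)
        return tf.slice(data, tf.stack([step * i]), tf.stack([step]))
    return get_slice

fn = build(num_gpus=2)
cells = tuple(c.cell_contents for c in fn.__closure__)   # contains 2 and the tensorflow module
pickle.dumps(cells)   # TypeError: can't pickle module objects

Moving the import to module level removes tf from the closure, so only picklable values (like num_gpus) remain, which is exactly what the patched version above does.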
Answer 2:
To be honest, the easiest approach to this is to examine the multi-GPU parallel model using
parallel_model.summary()
(the parallel model is simply the model after applying the multi_gpu_model function). This clearly shows the actual model (in, I think, the penultimate layer; I am not at my computer right now). Then you can use the name of this layer to get and save the model, as in the sketch below:
model = parallel_model.get_layer('sequential_1')
Often it's called sequential_1, but if you are using a published architecture, it may be 'googlenet' or 'alexnet'. You will see the name of the layer in the summary.
Then it's simple to just save it:
model.save()
Maxim's approach works, but it's overkill, I think.
Rem: you will need to compile both the model and the parallel model.
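A short sketch of this approach, assuming a pre-built model and that the inner model shows up in the summary as 'sequential_1' (the layer name on your model may differ, and the multi_gpu_model import path may vary by Keras version):

from keras.utils import multi_gpu_model

parallel_model = multi_gpu_model(model, gpus=2)
model.compile(loss='categorical_crossentropy', optimizer='adam')
parallel_model.compile(loss='categorical_crossentropy', optimizer='adam')
# ... train with parallel_model.fit(...) ...

parallel_model.summary()                          # find the layer that holds the original model
inner = parallel_model.get_layer('sequential_1')  # name taken from the summary output
inner.save('single_gpu_model.h5')                 # saves cleanly, no Lambda closures involved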
Answer 3:
This needs a small workaround: copy the multi_gpu_model weights into the regular model and save that model instead, e.g.
# 1. instantiate your base model on a cpu
with tf.device("/cpu:0"):
    model = create_model()

# 2. put your model to multiple gpus, say 2
multi_model = multi_gpu_model(model, 2)

# 3. compile both models
model.compile(loss=your_loss, optimizer=your_optimizer(lr))
multi_model.compile(loss=your_loss, optimizer=your_optimizer(lr))

# 4. train the multi gpu model
# multi_model.fit() or multi_model.fit_generator()

# 5. save weights
model.set_weights(multi_model.get_weights())
model.save(filepath=filepath)
reference: https://github.com/fchollet/keras/issues/8123
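Because the weights are saved through the plain single-GPU model, loading it later is just the standard call and needs no custom objects (the file name below is only an example):

from keras.models import load_model

model = load_model('single_gpu_model.h5')   # no {'tf': tf} needed with this approach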
Answer 4:
@Maxim: The workaround works great for saving the model. Unfortunately, I am facing an error (see snippet below) when loading the model (as specified above), using Keras version 2.0.9:
File "", line 1, in
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/models.py", line 239, in load_model
model = model_from_config(model_config, custom_objects=custom_objects)
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/models.py", line 313, in model_from_config return layer_module.deserialize(config, custom_objects=custom_objects)
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/layers/init.py", line 55, in deserialize printable_module_name='layer')
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object list(custom_objects.items())))
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 2490, in from_config process_layer(layer_data)
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 2476, in process_layer custom_objects=custom_objects)
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/layers/init.py", line 55, in deserialize printable_module_name='layer')
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object list(custom_objects.items())))
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/layers/core.py", line 699, in from_config function = func_load(config['function'], globs=globs)
File "/home/sci/prafulag/KeplerCluster/anaconda3/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 206, in func_load closure=closure)
TypeError: arg 5 (closure) must be None or tuple