Question
I am trying to merge two trained neural networks. I have two trained Keras model files A and B.
Model A is for image super-resolution and model B is for image colorization.
I want to merge the two trained networks so that I can run SR + colorization inference faster. (I do not want a single network that performs both tasks; I need two separate networks, one for SR and one for colorization.)
Any tips on how to merge two Keras neural networks?
Answer 1:
As long as the shape of the output of network A is compatible with the shape of the input of model B, it is possible.
Since tf.keras.models.Model inherits from tf.keras.layers.Layer, you can use a Model as you would use a Layer when creating your Keras model.
A simple example:
Let's first create two simple networks, A and B, with the constraint that the input of B has the same shape as the output of A.
import tensorflow as tf

A = tf.keras.models.Sequential(
    [
        tf.keras.Input((10,)),
        tf.keras.layers.Dense(5, activation="tanh"),
    ],
    name="A",
)
B = tf.keras.models.Sequential(
    [
        tf.keras.Input((5,)),
        tf.keras.layers.Dense(10, activation="tanh"),
    ],
    name="B",
)
Then we can merge the two models into one, in this case using the functional API (the Sequential API works just as well):
merged_input = tf.keras.Input((10,))
x = A(merged_input)
merged_output = B(x)
merged_model = tf.keras.Model(inputs=merged_input, outputs=merged_output, name="merged_AB")
resulting in the following network:
>>> merged_model.summary()
Model: "merged_AB"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 10)]              0
_________________________________________________________________
A (Sequential)               (None, 5)                 55
_________________________________________________________________
B (Sequential)               (None, 10)                60
=================================================================
Total params: 115
Trainable params: 115
Non-trainable params: 0
_________________________________________________________________
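For completeness, here is a minimal sketch of the Sequential variant mentioned above, together with a shape check and a sanity test that the merged model matches calling A and then B. The toy models stand in for the real trained ones; the name merged_seq and the file name in the comment are only illustrative assumptions, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Stand-ins for the two trained models. In practice they would be loaded from
# disk, e.g. tf.keras.models.load_model("model_A.h5") (hypothetical file name).
A = tf.keras.Sequential(
    [tf.keras.Input((10,)), tf.keras.layers.Dense(5, activation="tanh")],
    name="A",
)
B = tf.keras.Sequential(
    [tf.keras.Input((5,)), tf.keras.layers.Dense(10, activation="tanh")],
    name="B",
)

# A's output shape must match B's input shape (ignoring the batch dimension).
if A.output_shape[1:] != B.input_shape[1:]:
    raise ValueError("Output of A is not compatible with input of B")

# Sequential variant of the merge: each model is used as a single layer.
merged_seq = tf.keras.Sequential(
    [tf.keras.Input((10,)), A, B],
    name="merged_AB_seq",
)

# The merged model should reproduce the result of running A, then B.
x = np.random.rand(4, 10).astype("float32")
expected = np.asarray(B(A(x)))
actual = np.asarray(merged_seq(x))
print(np.allclose(expected, actual, atol=1e-6))
```

The merged model reuses the layers of A and B rather than copying them, so their trained weights are carried over as-is.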
Source: https://stackoverflow.com/questions/65556636/merging-two-trained-networks-for-inferencing-sequentially