Question
What would be the most straightforward way to implement a hypernetwork in Keras? That is, where one leg of the network creates the weights for another? In particular, I would like to do template matching where I feed the template in to a CNN leg that generates a convolutional kernel for a leg that operates on the main image. The part I'm unsure of is where I have a CNN layer that is fed weights externally, yet the gradients still flow through properly for training.
Answer 1:
The weights leg:
For the weights leg, just create a regular network as you would with Keras.
Make sure its output(s) have the shape (spatial_kernel_size1, spatial_kernel_size2, input_channels, output_channels).
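To make that kernel layout concrete, here is a plain-NumPy sketch (not Keras code) of a "valid", stride-1 2-D convolution using the same channels-last convention that Keras backend convolutions use by default: a kernel of shape (kh, kw, in_channels, out_channels). Like Keras, it computes a cross-correlation, i.e. the kernel is not flipped. The function name conv2d_valid is hypothetical.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Cross-correlate one image (H, W, C_in) with a kernel (kh, kw, C_in, C_out).

    Hypothetical helper for illustration: 'valid' padding, stride 1,
    channels-last layout as expected by Keras backend convolutions.
    """
    h, w, c_in = image.shape
    kh, kw, k_in, c_out = kernel.shape
    assert c_in == k_in, "kernel input channels must match image channels"
    out = np.zeros((h - kh + 1, w - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = image[i:i + kh, j:j + kw, :]  # (kh, kw, C_in)
            # Contract the patch against all C_out filters at once.
            out[i, j, :] = np.tensordot(patch, kernel, axes=([0, 1, 2], [0, 1, 2]))
    return out

image = np.random.rand(8, 8, 3)
kernel = np.random.rand(3, 3, 3, 4)
print(conv2d_valid(image, kernel).shape)  # (6, 6, 4)
```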
Using the functional API, you can create several weight tensors, for instance:
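from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Reshape, Lambda
from keras.models import Model
from keras import backend as K

inputs = Input((imgSize1, imgSize2, imgChannels))
w1 = Conv2D(desired_channels, ....)(inputs)
w2 = Conv2D(desired_channels2, ....)(w1)  # or (inputs)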
....
You should apply some kind of pooling here: the convolution outputs keep roughly the spatial size of the input, which is far too large for a kernel, and you probably want small filters such as 3x3 or 5x5.
w1 = GlobalAveragePooling2D()(w1) #maybe GlobalMaxPooling2D
w2 = GlobalAveragePooling2D()(w2)
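A quick shape check, assuming TensorFlow 2.x with its bundled Keras (the 36 filters are an arbitrary choice), shows why pooling helps: GlobalAveragePooling2D collapses both spatial axes, so the size of the weights leg's output depends only on its number of filters, not on the input's height and width.

```python
import numpy as np
from tensorflow.keras import layers

pool = layers.GlobalAveragePooling2D()
for h, w in [(16, 16), (40, 25)]:
    feature_map = np.random.rand(2, h, w, 36).astype("float32")  # (batch, H, W, filters)
    pooled = pool(feature_map)
    print(feature_map.shape, "->", tuple(pooled.shape))  # (2, 36) regardless of h and w
```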
If your image sizes are fixed, you could also use other kinds of pooling, or Flatten followed by Dense, etc.
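For fixed input sizes, the Flatten + Dense variant might look like the following sketch (TensorFlow 2.x Keras assumed). The 32x32x1 template and the generated 3x3x1x8 kernel are illustrative assumptions; the point is that the Dense layer must output exactly 3 * 3 * 1 * 8 values so they can be reshaped into a kernel.

```python
import numpy as np
from tensorflow.keras import layers, Model

KH, KW, C_IN, C_OUT = 3, 3, 1, 8          # assumed kernel shape to generate

template = layers.Input((32, 32, 1))      # fixed-size input (assumed 32x32 grayscale)
x = layers.Conv2D(16, 3, strides=2, activation="relu")(template)  # -> (15, 15, 16)
x = layers.Flatten()(x)                   # only possible because the input size is fixed
x = layers.Dense(KH * KW * C_IN * C_OUT)(x)
kernel = layers.Reshape((KH, KW, C_IN, C_OUT))(x)

weight_leg = Model(template, kernel)
out = weight_leg.predict(np.zeros((4, 32, 32, 1), dtype="float32"), verbose=0)
print(out.shape)  # (4, 3, 3, 1, 8)
```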
Make sure you reshape the weights into the correct kernel shape:
w1 = Reshape((size1,size2,input_channels, output_channels))(w1)
w2 = Reshape((sizeA, sizeB, input_channels2, output_channels2))(w2)
....
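The number of channels is up to you to tune, with one constraint: since global pooling leaves one value per channel, desired_channels must equal size1 * size2 * input_channels * output_channels for the reshape to work.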
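A sketch tying the pooling and reshape steps together, assuming TensorFlow 2.x Keras, a variable-size 3-channel input, and an arbitrary 5x5x3x4 kernel. Note that the result keeps a leading batch axis: the weights leg produces one kernel per sample, not a single kernel for the whole batch.

```python
import numpy as np
from tensorflow.keras import layers, Model

size1, size2, input_channels, output_channels = 5, 5, 3, 4           # assumed kernel shape
desired_channels = size1 * size2 * input_channels * output_channels  # 300 values per kernel

inputs = layers.Input((None, None, input_channels))  # any spatial size
w1 = layers.Conv2D(32, 3, activation="relu")(inputs)
w1 = layers.Conv2D(desired_channels, 3)(w1)
w1 = layers.GlobalAveragePooling2D()(w1)             # (batch, desired_channels)
w1 = layers.Reshape((size1, size2, input_channels, output_channels))(w1)

weights_leg = Model(inputs, w1)
batch = np.random.rand(2, 20, 30, input_channels).astype("float32")
print(weights_leg.predict(batch, verbose=0).shape)  # (2, 5, 5, 3, 4): one kernel per sample
```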
The convolutional leg:
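This leg uses only "non-trainable" convolutions, which you can find directly in the backend and call inside Lambda layers. They have no weights of their own, but they are differentiable with respect to both the input and the kernel, so gradients still flow back through the generated kernels into the weights leg during training: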
out1 = Lambda(lambda x: K.conv2d(x[0], x[1]))([inputs,w1])
out2 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out1,w2])
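One caveat about these Lambda lines: K.conv2d takes a single kernel of shape (kh, kw, in, out) shared by the whole batch, whereas the reshaped w1 has an extra leading batch axis (one kernel per sample). One way around this, swapped in here rather than taken from the answer, is to map a convolution over the batch with tf.map_fn. The sketch below assumes TensorFlow 2.x; the helper name per_sample_conv2d is hypothetical. It also checks the point raised in the question: the convolution has no trainable weights, yet the loss still produces a gradient for the tensor that generated the kernels.

```python
import tensorflow as tf

def per_sample_conv2d(images, kernels, padding="VALID"):
    """Convolve every image in the batch with its own kernel.

    images:  (batch, H, W, C_in)
    kernels: (batch, kh, kw, C_in, C_out), e.g. the reshaped output of the weights leg
    Hypothetical helper for illustration, not part of the Keras API.
    """
    def conv_one(pair):
        image, kernel = pair
        # tf.nn.conv2d needs a batch axis on the image and a plain 4-D kernel.
        return tf.nn.conv2d(image[tf.newaxis], kernel, strides=1, padding=padding)[0]

    # map_fn iterates over the batch axis; it is differentiable w.r.t. both inputs.
    return tf.map_fn(conv_one, (images, kernels), fn_output_signature=images.dtype)

images = tf.random.normal((2, 10, 10, 3))
# A trainable tensor standing in for the weights leg's pooled output.
raw = tf.Variable(tf.random.normal((2, 3 * 3 * 3 * 4)))

with tf.GradientTape() as tape:
    kernels = tf.reshape(raw, (2, 3, 3, 3, 4))
    out = per_sample_conv2d(images, kernels)
    loss = tf.reduce_sum(out ** 2)

grad = tape.gradient(loss, raw)
print(out.shape)   # (2, 8, 8, 4)
print(grad.shape)  # (2, 108)
```

Inside a model, the helper would take the place of K.conv2d, for example out1 = Lambda(lambda x: per_sample_conv2d(x[0], x[1]))([inputs, w1]). Because map_fn runs one convolution per sample instead of a single batched op, it may be slower than an ordinary Conv2D layer.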
How you interleave the layers, how many generated kernels you use, and so on, is also something you should tune for yourself.
Create the model:
model = Model(inputs, out2)
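For the template-matching case in the question, where a separate template input drives the kernel generator, the pieces could be combined as in this sketch. Assumptions: TensorFlow 2.x Keras, RGB inputs, a single generated 7x7x3x1 kernel giving a one-channel response map, a hypothetical per_sample_conv2d helper built on tf.map_fn (to handle the per-sample kernels), and random placeholder data used only to show that fit runs and updates the weights leg.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

KH, KW, C_IN, C_OUT = 7, 7, 3, 1  # assumed: one 7x7 RGB kernel -> 1-channel response map

def per_sample_conv2d(args):
    """Hypothetical helper: convolve each image with its own generated kernel."""
    images, kernels = args
    conv_one = lambda p: tf.nn.conv2d(p[0][tf.newaxis], p[1], strides=1, padding="SAME")[0]
    return tf.map_fn(conv_one, (images, kernels), fn_output_signature=tf.float32)

# Weights leg: the template is turned into one kernel per sample.
template = layers.Input((None, None, C_IN), name="template")
t = layers.Conv2D(32, 3, activation="relu")(template)
t = layers.Conv2D(KH * KW * C_IN * C_OUT, 3, name="kernel_conv")(t)
t = layers.GlobalAveragePooling2D()(t)
kernel = layers.Reshape((KH, KW, C_IN, C_OUT))(t)

# Convolutional leg: the main image is convolved with the generated kernel.
image = layers.Input((None, None, C_IN), name="image")
response = layers.Lambda(per_sample_conv2d)([image, kernel])

model = Model([image, template], response)
model.compile(optimizer="adam", loss="mse")

# Random placeholder data, only to show that training runs end to end.
imgs = np.random.rand(8, 64, 64, C_IN).astype("float32")
tmpls = np.random.rand(8, 16, 16, C_IN).astype("float32")
targets = np.random.rand(8, 64, 64, C_OUT).astype("float32")

before = model.get_layer("kernel_conv").get_weights()[0].copy()
history = model.fit([imgs, tmpls], targets, epochs=1, batch_size=4, verbose=0)
after = model.get_layer("kernel_conv").get_weights()[0]

print(model.predict([imgs[:2], tmpls[:2]], verbose=0).shape)  # (2, 64, 64, 1)
print(not np.allclose(before, after))  # True: the weights leg received gradients
```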
Interleaving:
You can also feed an output of the convolutional leg into the weight-generator leg:
w3 = Conv2D(filters, ...)(out2)
w3 = GlobalAveragePooling2D()(w3)
w3 = Reshape((sizeI, sizeII, inputC, outputC))(w3)
out3 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out2,w3])
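The interleaved structure can be sketched end to end as follows, assuming TensorFlow 2.x Keras and arbitrary sizes (8 intermediate channels, 3x3 kernels). The helpers per_sample_conv2d (a tf.map_fn-based per-sample convolution, used so that each sample gets its own generated kernel) and kernel_generator are hypothetical. As in the answer, w1 and w2 come from the main input, while w3 is generated from out2.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

def per_sample_conv2d(args):
    """Hypothetical helper: convolve each sample with its own kernel ('SAME' padding)."""
    x, k = args
    conv_one = lambda p: tf.nn.conv2d(p[0][tf.newaxis], p[1], strides=1, padding="SAME")[0]
    return tf.map_fn(conv_one, (x, k), fn_output_signature=tf.float32)

def kernel_generator(features, kh, kw, c_in, c_out):
    """Hypothetical helper: Conv2D -> global pooling -> reshape into per-sample kernels."""
    w = layers.Conv2D(kh * kw * c_in * c_out, 3, padding="same")(features)
    w = layers.GlobalAveragePooling2D()(w)
    return layers.Reshape((kh, kw, c_in, c_out))(w)

inputs = layers.Input((None, None, 3))
w1 = kernel_generator(inputs, 3, 3, 3, 8)              # generated from the main input
out1 = layers.Lambda(per_sample_conv2d)([inputs, w1])  # (batch, H, W, 8)
w2 = kernel_generator(inputs, 3, 3, 8, 8)              # also generated from the main input
out2 = layers.Lambda(per_sample_conv2d)([out1, w2])    # (batch, H, W, 8)
# Interleaving: the convolutional leg's own output drives the next kernel generator.
w3 = kernel_generator(out2, 3, 3, 8, 1)
out3 = layers.Lambda(per_sample_conv2d)([out2, w3])    # (batch, H, W, 1)

model = Model(inputs, out3)
x = np.random.rand(2, 32, 32, 3).astype("float32")
print(model.predict(x, verbose=0).shape)  # (2, 32, 32, 1)
```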
Source: https://stackoverflow.com/questions/56812831/keras-hypernetwork-implementation