Keras hypernetwork implementation?

女生的网名这么多〃 posted on 2020-08-10 20:43:09

Question


What would be the most straightforward way to implement a hypernetwork in Keras? That is, where one leg of the network creates the weights for another? In particular, I would like to do template matching where I feed the template into a CNN leg that generates a convolutional kernel for a leg that operates on the main image. The part I'm unsure of is how to have a CNN layer that is fed its weights externally while the gradients still flow through properly for training.


Answer 1:


The weights leg:

For the weights leg, just create a regular network as you would with Keras.

Be sure that its output(s) have a shape like (spatial_kernel_size1, spatial_kernel_size2, input_channels, output_channels).

Using the functional API you can create a few weights, for instance:

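from keras import backend as K
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Reshape, Lambda
from keras.models import Model

inputs = Input((imgSize1, imgSize2, imgChannels))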

w1 = Conv2D(desired_channels, ....)(inputs)
w2 = Conv2D(desired_channels2, ....)(inputs or w1)
....

You should apply some kind of pooling here: the convolution outputs still have spatial dimensions close to those of the input, which is far larger than the small filters you probably want (3x3, 5x5, etc.).

w1 = GlobalAveragePooling2D()(w1)  # or GlobalMaxPooling2D
w2 = GlobalAveragePooling2D()(w2)

If you're using fixed image sizes, you could also use other kinds of pooling, or Flatten followed by Dense, etc.

Make sure you reshape the weights to the correct shape.

w1 = Reshape((size1,size2,input_channels, output_channels))(w1)
w2 = Reshape((sizeA, sizeB, input_channels2, output_channels2))(w2)
....

The number of channels is up to you to optimize, as long as each pooled output holds exactly as many values as the kernel it is reshaped into.
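The size the weights leg has to produce follows directly from the target kernel shape: before Reshape can be applied, each sample's vector must hold size1 * size2 * input_channels * output_channels values. A minimal sketch of that arithmetic in plain Python (the helper name units_for_kernel and the example shape are illustrative, not part of the original answer):

```python
from math import prod


def units_for_kernel(kernel_shape):
    """Number of values the weights leg must emit per sample for one kernel."""
    return prod(kernel_shape)


# (size1, size2, input_channels, output_channels), illustrative values only
kernel_shape = (3, 3, 16, 32)
desired_channels = units_for_kernel(kernel_shape)
print(desired_channels)  # 4608
```

With global pooling, this is the number of filters the last Conv2D of the weights leg would need; a Dense layer after the pooling is another way to reach the same size.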

The convolutional leg:

This leg uses only "non-trainable" convolutions. They are available directly in the backend and can be wrapped in Lambda layers:

out1 = Lambda(lambda x: K.conv2d(x[0], x[1]))([inputs,w1])
out2 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out1,w2])

How you interleave the layers, how many generated weights you use, etc., is again something you should optimize for yourself.

Create the model:

model = Model(inputs, out2)
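One thing to watch when wiring this up: tensors produced by Keras layers carry a leading batch axis, so a reshaped kernel is five-dimensional, (batch, size1, size2, input_channels, output_channels), while a 2D convolution expects a four-dimensional kernel. The following is a rough sketch of one way to handle that, not the answer's exact method. It assumes TensorFlow 2.x, uses a separate template input as in the question, and swaps the Lambda/K.conv2d call for a small custom layer that applies tf.nn.conv2d per sample through tf.map_fn. The layer name PerSampleConv2D, the function build_hypernetwork, the Dense layer, the "SAME" padding, and all sizes are assumptions made for illustration.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

K_SIZE, C_IN, C_OUT = 3, 3, 1  # illustrative kernel size and channel counts


class PerSampleConv2D(layers.Layer):
    """Hypothetical helper: convolves each image with its own generated kernel."""

    def call(self, inputs):
        images, kernels = inputs  # (B, H, W, C_IN), (B, k, k, C_IN, C_OUT)

        def conv_one(pair):
            image, kernel = pair
            # tf.nn.conv2d wants a batch axis on the image and a 4D kernel.
            out = tf.nn.conv2d(image[tf.newaxis], kernel, strides=1, padding="SAME")
            return out[0]

        return tf.map_fn(conv_one, (images, kernels), fn_output_signature=images.dtype)


def build_hypernetwork(image_shape=(32, 32, C_IN), template_shape=(8, 8, C_IN)):
    image = keras.Input(image_shape)
    template = keras.Input(template_shape)

    # Weights leg: template -> pooled features -> flat vector -> kernel tensor.
    h = layers.Conv2D(16, 3, activation="relu")(template)
    h = layers.GlobalAveragePooling2D()(h)
    h = layers.Dense(K_SIZE * K_SIZE * C_IN * C_OUT)(h)
    kernel = layers.Reshape((K_SIZE, K_SIZE, C_IN, C_OUT))(h)

    # Convolutional leg: it has no trainable weights of its own.
    out = PerSampleConv2D()([image, kernel])
    return keras.Model([image, template], out)


model = build_hypernetwork()
```

In this sketch every trainable variable lives in the weights leg. Because the convolution is an ordinary differentiable op, a loss computed on the model output backpropagates through the generated kernel into those variables, which is the gradient flow the question asks about.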

Interleaving

You may take an output from this leg as input for the weight generator leg too:

w3 = Conv2D(filters, ...)(out2)
w3 = GlobalAveragePooling2D()(w3)
w3 = Reshape((sizeI, sizeII, inputC, outputC))(w3)
out3 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out2,w3])


Source: https://stackoverflow.com/questions/56812831/keras-hypernetwork-implementation
