element-wise multiplication with broadcasting in keras custom layer

后端 未结 2 845
迷失自我
迷失自我 2020-12-19 07:57

I am creating a custom layer with weights that need to be multiplied element-wise with the input before activation. I can get it to work when the output and the input have the same shape. The

相关标签:
2条回答
  • 2020-12-19 08:17

    Before multiplying, you need to repeat the elements to increase the shape. You can use K.repeat_elements for that. (import keras.backend as K)

    class MyLayer(Layer):
        """Repeat the input along a new axis and scale it by a trainable kernel.

        When ``repeat_count`` is greater than 1, a new axis is inserted at
        position 1 and the input is repeated ``repeat_count`` times along it,
        so an input of shape ``(batch, d1, ...)`` yields an output of shape
        ``(batch, repeat_count, d1, ...)``. Otherwise the input shape is kept.

        The kernel has the output shape with the batch axis replaced by 1, so
        the same weights are broadcast over every sample of the batch.

        Args:
            repeat_count: Number of copies created along the new axis 1.
            **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
        """

        # Handling arbitrary target shapes is awkward, so this layer takes a
        # single 'repeat_count' and grows exactly one dimension.
        def __init__(self, repeat_count, **kwargs):
            self.repeat_count = repeat_count
            super(MyLayer, self).__init__(**kwargs)

        def build(self, input_shape):
            """Create the kernel, initialised to ones, for ``input_shape``."""
            # The kernel must match the *output* of call(), not the input.
            output_shape = self.compute_output_shape(input_shape)
            # Replace the batch size by 1 so the kernel broadcasts over it.
            weight_shape = (1,) + output_shape[1:]

            self.kernel = self.add_weight(name='kernel',
                                          shape=weight_shape,
                                          initializer='ones',
                                          trainable=True)

            super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

        def call(self, x):
            """Return ``x`` (repeated when requested) times the kernel."""
            if self.repeat_count > 1:
                # Add the extra dimension ...
                x = K.expand_dims(x, axis=1)
                # ... and replicate the elements along it.
                x = K.repeat_elements(x, rep=self.repeat_count, axis=1)

            # Element-wise multiply; the kernel's leading 1 broadcasts over
            # the batch axis.
            return x * self.kernel

        def compute_output_shape(self, input_shape):
            """Return the output shape; must mirror what ``call`` does."""
            if self.repeat_count > 1:
                return (input_shape[0], self.repeat_count) + input_shape[1:]
            return input_shape

        def get_config(self):
            """Return the layer config, including ``repeat_count``.

            Without this entry the constructor argument is missing from the
            saved config and the layer cannot be rebuilt from it.
            """
            config = super(MyLayer, self).get_config()
            config['repeat_count'] = self.repeat_count
            return config
    
    0 讨论(0)
  • 2020-12-19 08:36

    Here is another solution, based on the answer by Daniel Möller, that uses tf.multiply like the original code.

    class MyLayer(Layer):
        """Broadcast the input over a new axis and scale it by a trainable kernel.

        An input of shape ``(batch, d1, ...)`` yields an output of shape
        ``(batch, output_dim, d1, ...)``. The kernel has shape
        ``(1, output_dim, d1, ...)``, so the same weights are shared by every
        sample of the batch.

        Args:
            output_dim: Size of the axis inserted at position 1 of the output.
            **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
        """

        def __init__(self, output_dim, **kwargs):
            self.output_dim = output_dim

            super(MyLayer, self).__init__(**kwargs)

        def build(self, input_shape):
            """Create the kernel, initialised to ones, for ``input_shape``."""
            # Create a trainable weight variable for this layer. It follows the
            # output shape, with the batch size replaced by 1 for broadcasting.
            output_shape = self.compute_output_shape(input_shape)
            self.kernel = self.add_weight(name='kernel',
                                          shape=(1,) + output_shape[1:],
                                          initializer='ones',
                                          trainable=True)

            super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

        def call(self, x):
            """Return ``x`` broadcast over the new axis 1 times the kernel."""
            # The kernel has one more axis than x. Without an explicit axis at
            # position 1, broadcasting right-aligns the shapes and pairs the
            # batch axis of x with the output_dim axis of the kernel, so the
            # result would not match compute_output_shape.
            x = K.expand_dims(x, axis=1)

            # (batch, 1, ...) * (1, output_dim, ...) -> (batch, output_dim, ...)
            return K.tf.multiply(x, self.kernel)

        def compute_output_shape(self, input_shape):
            """Return the output shape; must mirror what ``call`` does."""
            return (input_shape[0], self.output_dim) + input_shape[1:]

        def get_config(self):
            """Return the layer config, including ``output_dim``.

            Without this entry the constructor argument is missing from the
            saved config and the layer cannot be rebuilt from it.
            """
            config = super(MyLayer, self).get_config()
            config['output_dim'] = self.output_dim
            return config
    
    0 讨论(0)
提交回复
热议问题