Persistent Variable in keras Custom Layer

不思量自难忘° 2020-12-17 05:37

I want to write a custom layer that keeps a variable in memory between runs. For example:

class MyLayer(Layer):
    def __init__(self, out_dim=51, **kwargs):
        ...

1 Answer
  • 2020-12-17 05:58

    The trick is to call self.add_update(...) in your call function to register a function that will run every time your model is evaluated (I found this by digging into the source code of the stateful RNNs). If you set self.stateful = True, your custom update runs on every training and prediction call; otherwise it runs only during training. For example:

    import keras.backend as K
    import numpy as np
    from keras.engine.topology import Layer
    
    class CounterLayer(Layer):
      def __init__(self, stateful=False,**kwargs):
        self.stateful = stateful # True means it will increment counter on predict and train, false means it will only increment counter on train 
        super(CounterLayer, self).__init__(**kwargs)
    
    
      def build(self, input_shape):
        # Define variables in build
        self.count = K.variable(0, name="count")
        super(CounterLayer, self).build(input_shape)
    
      def call(self, x, mask=None):
        updates = []
        # Each update is a (variable, new_value) pair,
        # so this one means: self.count = self.count + 1
        updates.append((self.count, self.count+1))
    
        # You can append more updates to this list or call add_update more
        # times if you want
    
        # Add our custom update
    
        # We stick x here so it calls our update function every time our layer 
        # is given a new x
        self.add_update(updates, x)
    
        # Return the current value of the counter
        # (declared as shape (1, 1) in get_output_shape_for below)
        return self.count
      # in newer keras versions you might need to name this compute_output_shape instead
      def get_output_shape_for(self, input_shape):
        # We will just return our count as an array ([[count]])
        return (1,1)
    
      def reset_states(self):
        K.set_value(self.count, 0)  # works across backends, unlike .set_value()
    

    Example usage:

    from keras.layers import Input
    from keras.models import Model
    from keras.optimizers import RMSprop
    inputLayer = Input(shape=(10,))
    counter = CounterLayer() # Don't update on predict
    # counter = CounterLayer(stateful=True) # This will update each time you call predict
    counterLayer = counter(inputLayer)
    model = Model(input=inputLayer, output=counterLayer)
    optimizer = RMSprop(lr=0.001)
    model.compile(loss="mse", optimizer=optimizer)
    
    
    # See the value of our counter
    print(counter.count.get_value())
    
    # This won't actually train anything but each epoch will update our counter
    
    # Note that the update runs once per batch, so with 5 batches per epoch
    # the counter increments 5 times per epoch
    model.fit(np.zeros([1, 10]), np.array([0]), batch_size=1, nb_epoch=5)
    
    # The value of our counter has now changed
    print(counter.count.get_value())
    
    model.predict(np.zeros([1, 10]))
    
    # If we did stateful=False, this didn't change, otherwise it did
    print(counter.count.get_value())
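    The code above targets Keras 1 (note `nb_epoch` and `Model(input=..., output=...)`). As a rough sketch of the same idea in TensorFlow 2 / tf.keras (this is my adaptation, not the original answer's code): create the counter with `add_weight(trainable=False)` in `build`, and mutate it with `assign_add` inside `call` — the variable persists across calls without any explicit `add_update` bookkeeping:

    ```python
    import tensorflow as tf

    class CounterLayer(tf.keras.layers.Layer):
        """TF2 sketch of the same idea: a non-trainable weight that
        persists between calls and is updated inside call()."""

        def build(self, input_shape):
            # Non-trainable, so it is never touched by the optimizer
            self.count = self.add_weight(
                name="count", shape=(), initializer="zeros", trainable=False)

        def call(self, inputs):
            # Increment once per call (i.e. once per batch)
            self.count.assign_add(1.0)
            # Return the counter as a (1, 1) tensor
            return tf.reshape(self.count, (1, 1))

        def reset_states(self):
            self.count.assign(0.0)

    layer = CounterLayer()
    _ = layer(tf.zeros([1, 10]))
    _ = layer(tf.zeros([1, 10]))
    print(float(layer.count.numpy()))  # -> 2.0
    ```

    In eager mode the counter increments on every call, whether from `fit` or `predict`; the Keras 1 `stateful` flag has no direct equivalent here, so you would gate the `assign_add` on the `training` argument of `call` if you want train-only updates.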
    