Question
I am implementing Sharpness-Aware Minimization (SAM) in TensorFlow. Simplified, the algorithm is:
1. Compute the gradient at the current weights W
2. Compute ε according to the equation in the paper
3. Compute the gradient at the weights W + ε
4. Update the model using the gradient from step 3
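For reference, the closed-form perturbation from the SAM paper, with g = ∇_W L(W), elementwise sign and powers, and 1/p + 1/q = 1, is

$$\hat{\epsilon}(W) = \rho \, \operatorname{sign}(g) \, \frac{|g|^{q-1}}{\left(\lVert g \rVert_q^q\right)^{1/p}}$$

With p = q = 2 this reduces to ρ g / ‖g‖₂, and in general ‖ε̂‖_p = ρ. Below is a small NumPy sketch of that formula. It assumes the gradients arrive as a list of arrays (one per variable) and takes the q-norm over all of them jointly, since W in the paper is the full weight vector. `sam_epsilon` is a hypothetical helper name.

```python
import numpy as np


def sam_epsilon(grads, rho=0.05, p=2.0, q=2.0):
    """Hypothetical helper: SAM perturbation eps_hat for a list of gradient arrays."""
    if not np.isclose(1.0 / p + 1.0 / q, 1.0):
        raise ValueError("p and q must satisfy 1/p + 1/q = 1")
    # ||g||_q over the concatenation of every gradient array (one global norm)
    global_q_norm = sum(np.sum(np.abs(g) ** q) for g in grads) ** (1.0 / q)
    # (||g||_q^q)^(1/p) is the same as ||g||_q^(q/p)
    denom = global_q_norm ** (q / p)
    # sign and power are elementwise, so each eps has the shape of its gradient
    return [rho * np.sign(g) * np.abs(g) ** (q - 1) / denom for g in grads]


# Usage: with p = q = 2 the result is rho * g / ||g||_2
grads = [np.array([3.0, -4.0]), np.array([[0.0]])]
eps = sam_epsilon(grads, rho=0.1)
print([e.tolist() for e in eps])
```

The sign is that of the gradient, not of the loss. Taking the sign of a non-negative loss would discard the direction of g entirely.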
 
I have already implemented steps 1 and 2, but I am having trouble implementing step 3 in the code below:
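def train_step(self, data, rho=0.05, p=2, q=2):
    # p and q must be conjugate exponents; compare with a tolerance to avoid float round-off
    if abs(1 / p + 1 / q - 1) > 1e-6:
        # tf's InvalidArgumentError requires (node_def, op, message), so use a plain ValueError
        raise ValueError('p, q must be specified so that 1/p + 1/q = 1')
    x, y = data

    # compute first backprop (gradient at the current weights W)
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(y, y_pred)
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # ||g||_q over all trainable variables jointly (W in the paper is the full weight vector)
    grad_norm_q = tf.pow(
        tf.add_n([tf.reduce_sum(tf.pow(tf.abs(g), q)) for g in gradients]), 1 / q)

    # compute neighborhoods (epsilon_hat) from first backprop;
    # use the sign of the gradient g, not the sign of the loss
    trainable_w_plus_epsilon_hat = [
        w + rho * tf.sign(g) * tf.pow(tf.abs(g), q - 1) / tf.pow(grad_norm_q, q / p)
        for w, g in zip(trainable_vars, gradients)
    ]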
        
    ### HOW TO SET TRAINABLE WEIGHTS TO `w_plus_epsilon_hat`?
    #
    # TODO:
    #     1. compute gradient using trainable weights from `trainable_w_plus_epsilon_hat`
    #     2. update `trainable_vars` using gradient from step 1
    #
    #########################################################
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
Is there any way to compute the gradient with the trainable weights set to the values in trainable_w_plus_epsilon_hat?
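One common way to fill in the TODO is sketched below, assuming TF 2.x. Instead of building a separate list of tensors, the sketch temporarily moves the variables themselves to W + ε̂ with `assign_add`. It then takes the second gradient, undoes the move with `assign_sub`, and only after that lets the optimizer apply the second gradient to the restored W. The code is written as a standalone function so it can run outside `fit()`. `sam_update` and its parameters are hypothetical names, the q-norm is taken over all variables jointly, and the small constant 1e-12 in the denominator is an added assumption to avoid 0/0.

```python
import numpy as np
import tensorflow as tf


def sam_update(model, loss_fn, optimizer, x, y, rho=0.05, p=2.0, q=2.0):
    """Hypothetical helper: one SAM step on `model`; returns the loss at the original W."""
    trainable_vars = model.trainable_variables

    # Step 1: gradient at the current weights W
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    gradients = tape.gradient(loss, trainable_vars)

    # Step 2: eps_hat; the q-norm runs over all variables jointly, and the
    # 1e-12 term (an assumption) avoids 0/0 when every gradient is zero
    grad_norm_q = tf.pow(
        tf.add_n([tf.reduce_sum(tf.pow(tf.abs(g), q)) for g in gradients]), 1.0 / q)
    epsilons = [
        rho * tf.sign(g) * tf.pow(tf.abs(g), q - 1) / (tf.pow(grad_norm_q, q / p) + 1e-12)
        for g in gradients
    ]

    # Move the variables themselves to W + eps_hat so the next forward pass uses them
    for w, e in zip(trainable_vars, epsilons):
        w.assign_add(e)

    # Step 3: gradient at W + eps_hat
    with tf.GradientTape() as tape:
        perturbed_loss = loss_fn(y, model(x, training=True))
    sam_gradients = tape.gradient(perturbed_loss, trainable_vars)

    # Restore W: SAM applies the perturbed gradient to W, not to W + eps_hat
    for w, e in zip(trainable_vars, epsilons):
        w.assign_sub(e)

    # Step 4: update W with the gradient from step 3
    optimizer.apply_gradients(zip(sam_gradients, trainable_vars))
    return loss


# Usage on a one-weight model: loss = w**2 at x = 1, y = 0, starting from w = 1
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(1, use_bias=False,
                          kernel_initializer=tf.keras.initializers.Constant(1.0)),
])
x = np.array([[1.0]], dtype=np.float32)
y = np.array([[0.0]], dtype=np.float32)
loss = sam_update(model, tf.keras.losses.MeanSquaredError(),
                  tf.keras.optimizers.SGD(learning_rate=0.1), x, y, rho=0.05)
# w is moved to 1.05 for the second gradient (2.1), so SGD leaves w at about 0.79
print(float(loss), model.get_weights()[0])
```

Inside `train_step`, the same body should carry over with a few substitutions: `self` in place of `model`, `self.compiled_loss` in place of `loss_fn`, and `self.optimizer` in place of `optimizer`. The existing `self.compiled_metrics.update_state(y, y_pred)` call then follows. Restoring W before `apply_gradients` is the important detail. SAM evaluates the gradient at W + ε̂ but applies it to W.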
Source: https://stackoverflow.com/questions/65381773/computing-gradient-of-the-model-with-modified-weights