I'm trying to use gradient_override_map with TensorFlow 2.0. There is an example in the documentation, which I will use as the example here as well.
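For context, the documentation example in question overrides the gradient of the Square op via gradient_override_map. The snippet below is a minimal sketch of that TF 1.x-style pattern (the CustomSquare name and the zero gradient follow that example; the exact details may differ from the original docs):

import tensorflow as tf

# TF 1.x-style graph pattern; under TF 2.x the session and gradients
# APIs live in tf.compat.v1.
@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  # Illustrative override: report a zero gradient for Square.
  return tf.constant(0.0)

with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  s_1 = tf.square(c)                 # uses the default Square gradient
  with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(c)               # uses _custom_square_grad instead

  with tf.compat.v1.Session() as sess:
    print(sess.run(tf.compat.v1.gradients(s_1, c)[0]))  # 10.0
    print(sess.run(tf.compat.v1.gradients(s_2, c)[0]))  # 0.0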
There is no built-in mechanism in TensorFlow 2.0 to override all gradients for a built-in operator within a scope. However, if you are able to modify the call-site for each call to the built-in operator, you can use the tf.custom_gradient
decorator as follows:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    # Custom gradient: always report zero for the square.
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.Graph().as_default() as g:
  x = tf.Variable(5.0)
  with tf.GradientTape() as tape:
    s_2 = custom_square(x)

  # The session is created inside the graph context so it runs against g.
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    print(sess.run(tape.gradient(s_2, x)))  # prints 0.0
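In TF 2.0's default eager mode you don't need a graph or session at all; a minimal sketch of the same idea, reusing the custom_square definition from above:

# Eager-mode equivalent: no Graph or Session required in TF 2.0.
x = tf.Variable(5.0)
with tf.GradientTape() as tape:
  s_2 = custom_square(x)
print(tape.gradient(s_2, x))  # tf.Tensor(0.0, shape=(), dtype=float32)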