I'm trying to use gradient_override_map with TensorFlow 2.0. There is an example in the documentation, which I will use as the example here as well.
In addition to mrry's answer, there are two points I would like to add:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        # Override: always return 0.0 as the gradient
        return tf.constant(0.0)
    return tf.square(x), grad

with tf.GradientTape() as tape:
    x = tf.Variable(5.0)
    s_2 = custom_square(x)

print(tape.gradient(s_2, x).numpy())
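For comparison, here is a minimal sketch of my own (not part of the original answer) of the same computation without the override; the default gradient of tf.square is 2*x = 10.0 at x = 5.0, whereas the overridden version above prints 0.0:

import tensorflow as tf

# Default (non-overridden) gradient of tf.square: d(x^2)/dx = 2x
with tf.GradientTape() as tape:
    x = tf.Variable(5.0)
    s_2 = tf.square(x)

print(tape.gradient(s_2, x).numpy())  # 10.0, versus 0.0 with the custom gradient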
Be careful: gradient calculation is a chained computation, so we should multiply our custom grad by dy (the previously computed gradient). Without doing this, our customized function will break the chain calculation. This is an example:
@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        return tf.constant(4.0)
    return tf.square(x), grad

with tf.GradientTape(persistent=True) as tape:
    x = tf.Variable(5.0)
    s_2 = custom_square(x)
    s_4 = custom_square(s_2)

print("Grad from s_4 to x: ", tape.gradient(s_4, x).numpy())
print("Grad from s_4 to s_2: ", tape.gradient(s_4, s_2).numpy())
print("Grad from s_2 to x: ", tape.gradient(s_2, x).numpy())
The result:
Grad from s_4 to x: 4.0
Grad from s_4 to s_2: 4.0
Grad from s_2 to x: 4.0
Grad from s_4 to x should be 16: by the chain rule it is the grad from s_4 to s_2 multiplied by the grad from s_2 to x, i.e. 4.0 * 4.0. But the result was 4.0, which means it didn't accumulate the gradient from the previous step.
Multiplying the custom grad by dy solves the problem:
@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        return tf.constant(4.0) * dy
    return tf.square(x), grad

with tf.GradientTape(persistent=True) as tape:
    x = tf.Variable(5.0)
    s_2 = custom_square(x)
    s_4 = custom_square(s_2)

print("Grad from s_4 to x: ", tape.gradient(s_4, x).numpy())
print("Grad from s_4 to s_2: ", tape.gradient(s_4, s_2).numpy())
print("Grad from s_2 to x: ", tape.gradient(s_2, x).numpy())
Here is the result:
Grad from s_4 to x: 16.0
Grad from s_4 to s_2: 4.0
Grad from s_2 to x: 4.0
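To see the chaining explicitly, here is a small sketch of my own (the name verbose_square is just for illustration) that prints dy inside grad: during backprop the outer call receives dy = 1.0 and the inner call receives dy = 4.0, which is how the product 16.0 is built up.

import tensorflow as tf

@tf.custom_gradient
def verbose_square(x):
    def grad(dy):
        # dy is the gradient flowing in from downstream ops (1.0 for the outermost call)
        tf.print("dy =", dy)
        return tf.constant(4.0) * dy
    return tf.square(x), grad

with tf.GradientTape() as tape:
    x = tf.Variable(5.0)
    y = verbose_square(verbose_square(x))

print(tape.gradient(y, x).numpy())  # prints dy = 1, then dy = 4, and finally 16.0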
You can try the implementation through Colab here: https://colab.research.google.com/drive/1gbLopOLJiyznDA-Cr473bZEeWkWh_KGG?usp=sharing
There is no built-in mechanism in TensorFlow 2.0 to override all gradients for a built-in operator within a scope. However, if you are able to modify the call-site for each call to the built-in operator, you can use the tf.custom_gradient
decorator as follows:
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        return tf.constant(0.0)
    return tf.square(x), grad

with tf.Graph().as_default() as g:
    x = tf.Variable(5.0)
    with tf.GradientTape() as tape:
        s_2 = custom_square(x)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        print(sess.run(tape.gradient(s_2, x)))
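To illustrate the call-site point, here is a minimal sketch of my own (the function f below is hypothetical, and it runs in eager mode rather than the Graph/Session setup above): only the calls rewritten to use custom_square get the overridden gradient, while untouched tf.square calls keep their normal gradient.

import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        return tf.constant(0.0)
    return tf.square(x), grad

def f(x):
    # Only the first term uses the override; the second keeps the default gradient.
    return custom_square(x) + tf.square(x)

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
    y = f(x)

print(tape.gradient(y, x).numpy())  # 10.0: only the plain tf.square contributes 2*x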