How to use gradient_override_map in TensorFlow 2.0?

既然无缘 2021-01-01 02:11

I'm trying to use gradient_override_map with TensorFlow 2.0. The documentation contains an example, which I will also use here.
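For reference, here is a minimal sketch of the graph-mode pattern that gradient_override_map was designed for, run under TensorFlow 2.x inside an explicit tf.Graph. It assumes tf.RegisterGradient, tf.Graph.gradient_override_map and tf.compat.v1.gradients behave as in the TF 1.x graph API; the gradient name "CustomSquare" is hypothetical. In this sketch the override is consumed by tf.compat.v1.gradients. As far as I can tell, tf.GradientTape looks gradients up by op type and does not consult the override map.

```python
import tensorflow as tf

# "CustomSquare" is a hypothetical name chosen for this sketch; any unused
# name works. The function receives the forward op and the upstream gradient.
@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
    # Return zero instead of the true derivative 2 * x.
    return tf.zeros_like(op.inputs[0])


g = tf.Graph()
with g.as_default():
    c = tf.constant(5.0)
    s_1 = tf.square(c)  # uses the default gradient registered for Square
    # Every op of type "Square" created inside this block is tagged to use
    # the "CustomSquare" gradient instead of the default one.
    with g.gradient_override_map({"Square": "CustomSquare"}):
        s_2 = tf.square(c)

    # tf.compat.v1.gradients reads that tag when it builds the gradient ops.
    grad_1 = tf.compat.v1.gradients(s_1, c)[0]
    grad_2 = tf.compat.v1.gradients(s_2, c)[0]

with tf.compat.v1.Session(graph=g) as sess:
    default_grad, overridden_grad = sess.run([grad_1, grad_2])

print(default_grad, overridden_grad)
```

The override applies by op type within the block, which is exactly the scope-wide behavior that has no direct equivalent in eager TensorFlow 2.x.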

In

2 Answers
  •  耶瑟儿~
    2021-01-01 02:35

    TensorFlow 2.0 has no built-in mechanism for overriding the gradient of a built-in operator for every call within a scope. However, if you can modify each call site of that operator, you can use the tf.custom_gradient decorator instead:

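    import tensorflow as tf

    # Wrap tf.square so that its gradient is replaced by a custom function.
    @tf.custom_gradient
    def custom_square(x):
      def grad(dy):
        # Ignore the upstream gradient dy and return zero instead of 2 * x.
        # zeros_like keeps the gradient's shape equal to the input's shape.
        return tf.zeros_like(x)
      return tf.square(x), grad

    # Build the computation in an explicit graph, as in the documentation example.
    with tf.Graph().as_default() as g:
      x = tf.Variable(5.0)
      with tf.GradientTape() as tape:
        s_2 = custom_square(x)

      # Graph mode needs a Session to initialize variables and evaluate tensors.
      with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        print(sess.run(tape.gradient(s_2, x)))  # evaluates the overridden gradient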
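Because TensorFlow 2.x executes eagerly by default, the same tf.custom_gradient approach should also work without tf.Graph or tf.compat.v1.Session. The sketch below assumes eager mode and uses a persistent tape so it can compare the default gradient of tf.square with the overridden one from the same recording.

```python
import tensorflow as tf

@tf.custom_gradient
def custom_square(x):
    def grad(dy):
        # Discard the upstream gradient dy and return zero instead of 2 * x.
        return tf.zeros_like(x)
    return tf.square(x), grad

x = tf.constant(5.0)
# persistent=True allows tape.gradient to be called more than once.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # constants are not watched automatically
    s_1 = tf.square(x)       # default gradient: 2 * x
    s_2 = custom_square(x)   # overridden gradient: 0

default_grad = tape.gradient(s_1, x)
custom_grad = tape.gradient(s_2, x)
del tape  # release the resources held by the persistent tape

print(default_grad.numpy(), custom_grad.numpy())
```

Unlike gradient_override_map, this only affects calls that go through custom_square; any other tf.square call keeps its default gradient, which is why each call site has to be changed.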
    
