How to register a custom gradient for an operation composed of tf operations

既然无缘 2020-12-09 12:00

More specifically, I have a simple fprop that is a composition of tf operations. I want to override the TensorFlow gradient computation with my own gradient method using Regi

3 Answers
  •  时光说笑
    2020-12-09 12:21

    If you want to use tf.RegisterGradient() for this purpose, I'm not sure it is a proper solution, because the official documentation (https://www.tensorflow.org/api_docs/python/tf/RegisterGradient) says:

    This decorator is only used when defining a new op type.

    That means you need to define a new op, either written in C++ or wrapped with py_func. I'm not sure it can be applied to a group of ordinary tf ops like the one you describe.
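    A minimal sketch of the py_func route, assuming TensorFlow 2.x with its tf.compat.v1 graph APIs. The forward pass runs in NumPy inside a PyFunc op, and tf.Graph.gradient_override_map points that op at a gradient registered with tf.RegisterGradient. The gradient name HypotheticalSquareGrad and the squaring function are hypothetical, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical gradient name; any name not already registered would do.
@tf.RegisterGradient("HypotheticalSquareGrad")
def _square_grad(op, grad):
    # d(x^2)/dx = 2x, chained with the incoming (upstream) gradient.
    x = op.inputs[0]
    return grad * 2.0 * x

def np_square(x):
    # Forward pass executed as plain NumPy inside the PyFunc op.
    return np.square(x).astype(np.float32)

g = tf.Graph()
with g.as_default():
    x = tf.constant([1.0, 2.0, 3.0])
    # tf.compat.v1.py_func (stateful by default) creates an op of type "PyFunc";
    # the override map makes gradient lookup use the registered function instead.
    with g.gradient_override_map({"PyFunc": "HypotheticalSquareGrad"}):
        y = tf.compat.v1.py_func(np_square, [x], tf.float32)
    dy_dx = tf.compat.v1.gradients(y, x)[0]

with tf.compat.v1.Session(graph=g) as sess:
    y_val, grad_val = sess.run([y, dy_dx])
```

    PyFunc ops have no gradient registered by default, so the override map is what makes dy_dx computable here; the registered function only ever sees the op's inputs and the upstream gradient.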


    However, you can refer to the "trick" described in this thread:

    How Can I Define Only the Gradient for a Tensorflow Subgraph?

    where you combine tf.stop_gradient() with tf.Graph.gradient_override_map() to redefine the gradients of a group of operations.
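    A minimal sketch of that combination, assuming TensorFlow 2.x with its tf.compat.v1 graph APIs. The fprop sin(x) * exp(x), the gradient name HypotheticalFpropGrad, and the hand-written derivative are hypothetical placeholders. The forward value comes from the composed ops, while the backward pass goes only through an Identity op whose gradient is overridden.

```python
import numpy as np
import tensorflow as tf

# Hypothetical gradient name for the overridden Identity op.
@tf.RegisterGradient("HypotheticalFpropGrad")
def _fprop_grad(op, grad):
    # "My own gradient method": here, the hand-written derivative of
    # sin(x) * exp(x), i.e. exp(x) * (sin(x) + cos(x)).
    x = op.inputs[0]
    return grad * tf.exp(x) * (tf.sin(x) + tf.cos(x))

def composed_fprop(x):
    # Hypothetical fprop built only from ordinary tf ops.
    return tf.sin(x) * tf.exp(x)

x_np = np.array([0.0, 1.0, 2.0], dtype=np.float32)

g = tf.Graph()
with g.as_default():
    x = tf.constant(x_np)
    # The Identity op created in this scope uses the registered gradient above.
    with g.gradient_override_map({"Identity": "HypotheticalFpropGrad"}):
        t = tf.identity(x)
    # Forward value: t + (fprop(x) - t) == fprop(x).
    # Backward: stop_gradient blocks the fprop branch, so the only gradient
    # path to x goes through t and therefore through the custom function.
    y = t + tf.stop_gradient(composed_fprop(x) - t)
    dy_dx = tf.compat.v1.gradients(y, x)[0]

with tf.compat.v1.Session(graph=g) as sess:
    y_val, grad_val = sess.run([y, dy_dx])
```

    Because the registered function here happens to be the analytic derivative, the result matches what TensorFlow would compute on its own; putting any other expression in that function changes only the backward pass, not the forward value.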
