How to compute gradient of output wrt input in Tensorflow 2.0

◇◆丶佛笑我妖孽 submitted on 2020-06-25 12:16:07

Question


I have a trained TensorFlow 2.0 model (built with tf.keras.Sequential()) that takes an input layer with 26 columns (X) and produces an output layer with 1 column (Y).

In TF 1.x I was able to calculate the gradient of the output with respect to the input with the following:

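import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

model = load_model('mymodel.h5')
sess = K.get_session()
grad_func = tf.gradients(model.output, model.input)
gradients = sess.run(grad_func, feed_dict={model.input: X})[0]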

In TF2 when I try to run tf.gradients(), I get the error:

RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.
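The error concerns eager execution only. As a sketch, tf.gradients can still be used in TF2 inside a function decorated with tf.function, because that code is traced into a graph. The model, the function name input_gradients and the random data below are hypothetical stand-ins for mymodel.h5 and X:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the model loaded from 'mymodel.h5': 26 inputs -> 1 output
model = tf.keras.Sequential([
    tf.keras.Input(shape=(26,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])

@tf.function
def input_gradients(x):
    # Inside tf.function the code runs as a graph, so tf.gradients is allowed
    y = model(x)
    return tf.gradients(y, x)[0]

# Hypothetical input data: 25 rows, 26 columns
X = np.random.normal(size=(25, 26)).astype(np.float32)
grads = input_gradients(tf.constant(X))
print(grads.shape)  # (25, 26)
```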

An answer to the question "In TensorFlow 2.0 with eager-execution, how to compute the gradients of a network output wrt a specific layer?" shows how to calculate gradients with respect to intermediate layers, but I don't see how to apply it to gradients with respect to the inputs. The TensorFlow documentation for tf.GradientTape has examples that compute gradients of simple functions, but not of neural networks.

How can tf.GradientTape be used to calculate the gradient of the output with respect to the input?


Answer 1:


This should work in TF2:

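import numpy as np
import tensorflow as tf

# The input shape must match the model's input layer, e.g. (batch_size, 26) for the model in the question
inp = tf.Variable(np.random.normal(size=(25, 26)), dtype=tf.float32)

with tf.GradientTape() as tape:
    # Trainable tf.Variables are watched automatically, so no tape.watch() is needed
    preds = model(inp)

grads = tape.gradient(preds, inp)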

The idea is the same as in TF1: tf.GradientTape records the forward pass, and tape.gradient() takes the place of tf.gradients().
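A self-contained sketch of the same approach, assuming the data is a plain tensor rather than a tf.Variable. A plain tensor is not tracked by the tape automatically, so it has to be registered with tape.watch(). The small model and random data are hypothetical stand-ins for the asker's trained model and X:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained model: 26 input columns -> 1 output column
model = tf.keras.Sequential([
    tf.keras.Input(shape=(26,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])

# Hypothetical input data: 25 rows, 26 columns
X = np.random.normal(size=(25, 26)).astype(np.float32)
x = tf.convert_to_tensor(X)

with tf.GradientTape() as tape:
    # Constants are not watched by default, so watch the input explicitly
    tape.watch(x)
    y = model(x)

# tape.gradient sums over the (25, 1) target; since rows do not interact
# in this model, each row of the result is dy[i]/dx[i]
grads = tape.gradient(y, x).numpy()
print(grads.shape)  # (25, 26)
```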




Answer 2:


I hope this is what you're looking for. This will give the gradients of the output w.r.t. the inputs.

import numpy as np
import tensorflow as tf

# Whatever input you like goes in as the initial_value (its shape must match the model, e.g. 26 columns)
x = tf.Variable(np.random.normal(size=(25, 26)), dtype=tf.float32)

with tf.GradientTape() as tape:
    # Call the model directly: model.predict() returns a NumPy array, which the tape cannot differentiate
    pred = model(x)

# The method is tape.gradient(), not tape.gradients()
grads = tape.gradient(pred, x)
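tape.gradient() sums over a non-scalar target, so with a (batch, 1) output it returns one gradient row per sample. When the per-sample derivative should be explicit, GradientTape.batch_jacobian() returns one Jacobian per row. A minimal sketch, assuming a hypothetical stand-in model with 26 inputs and 1 output:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model with 26 inputs and 1 output
model = tf.keras.Sequential([
    tf.keras.Input(shape=(26,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(1),
])

x = tf.constant(np.random.normal(size=(25, 26)), dtype=tf.float32)

# persistent=True lets the same tape be queried twice
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = model(x)  # call the model directly, not model.predict(x)

jac = tape.batch_jacobian(y, x)  # shape (25, 1, 26): dy[i]/dx[i] for each row i
grads = tape.gradient(y, x)      # shape (25, 26): summed over the output dimension
del tape  # release the resources held by the persistent tape

# With a single output column the two agree once the size-1 axis is dropped
print(np.allclose(jac[:, 0, :].numpy(), grads.numpy(), atol=1e-5))
```

batch_jacobian becomes the more useful of the two if the model ever has more than one output column, because tape.gradient would then add the per-output gradients together.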


Source: https://stackoverflow.com/questions/59145221/how-to-compute-gradient-of-output-wrt-input-in-tensorflow-2-0
