tf.GradientTape详解

ぃ、小莉子 提交于 2020-01-31 20:39:20

参考文献:https://blog.csdn.net/guanxs/article/details/102471843

        在TensorFlow 1.x静态图时代,我们知道每个静态图都有两部分,一部分是前向图,另一部分是反向图。反向图就是用来计算梯度的,用在整个训练过程中。而TensorFlow 2.0默认是eager模式,每行代码顺序执行,没有了构建图的过程(也取消了control_dependency的用法)。但也不能每行都计算一下梯度吧?计算量太大,也没必要。因此,需要一个上下文管理器(context manager)来连接需要计算梯度的函数和变量,方便求解同时也提升效率。
       举个例子:计算y=x^2在x = 3时的导数:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x) # y’ = 2*x = 2*3 = 6

例子中的watch函数把需要计算梯度的变量x加进来了。GradientTape默认只监控由tf.Variable创建的traiable=True属性(默认)的变量。上面例子中的x是constant,因此计算梯度需要增加g.watch(x)函数。当然,也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False就行了(一般用不到)。

GradientTape也可以嵌套多层用来计算高阶导数,例如:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)      # y’ = 2*x = 2*3 =6
d2y_dx2 = g.gradient(dy_dx, x)  # y’’ = 2

另外,默认情况下GradientTape的资源在调用gradient函数后就被释放,再次调用就无法计算了。所以如果需要多次计算梯度,需要开启persistent=True属性,例如:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # z = y^2 = x^4, z’ = 4*x^3 = 4*3^3
dy_dx = g.gradient(y, x)  # y’ = 2*x = 2*3 = 6
del g  # 删除这个上下文tape

最后,一般在网络中使用时,不需要显式调用watch函数,使用默认设置,GradientTape会监控可训练变量,例如:

with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)

这样即可计算出所有可训练变量的梯度,然后进行下一步的更新。对于TensorFlow 2.0,推荐大家使用这种方式计算梯度,并且可以在eager模式下查看具体的梯度值。

根据上面的例子说一下tf.GradientTape这个类的常见的属性和函数,更多的可以去官方文档来看。

__init__(persistent=False,watch_accessed_variables=True)
作用:创建一个新的GradientTape
参数:

persistent: 布尔值,用来指定新创建的gradient tape是否是可持续性的。默认是False,意味着只能够调用一次gradient()函数。
watch_accessed_variables: 布尔值,表明这个gradien tap是不是会自动追踪任何能被训练(trainable)的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。
比如在上面的例子里面,新创建的gradient tape设定persistent为True,便可以在这个上面反复调用gradient()函数。

watch(tensor)
作用:确保某个tensor被tape追踪

参数:

tensor: 一个Tensor或者一个Tensor列表
gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
作用:根据tape上面的上下文来计算某个或者某些tensor的梯度
参数:

target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值
sources: Tensors 或者Variables列表(当然可以只有一个值). 你可以理解为函数的某个变量
output_gradients: a list of gradients, one for each element of target. Defaults to None.
unconnected_gradients: a value which can either hold ‘none’ or ‘zero’ and alters the value which will be returned if the target and sources are unconnected. The possible values and effects are detailed in ‘UnconnectedGradients’ and it defaults to ‘none’.
返回:
一个列表表示各个变量的梯度值,和source中的变量列表一一对应,表明这个变量的梯度。

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!