How to return intermediate gradients (for non-leaf nodes) in PyTorch?

Submitted by 送分小仙女 on 2021-02-10 14:25:44

Question


My question is concerning the syntax of pytorch register_hook.

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2   # intermediate (non-leaf) tensor
z = 2*y

x.register_hook(print)  # called with dz/dx during backward()
y.register_hook(print)  # called with dz/dy during backward()

z.backward()

outputs:

tensor([2.])
tensor([4.])

This snippet prints the gradients of z w.r.t. y and then x, in the order backpropagation reaches them: dz/dy = 2 and dz/dx = 4x = 4 at x = 1.

Now my (most likely trivial) question is: how do I return these intermediate gradients (rather than only printing them)?

UPDATE:

It appears that calling retain_grad() solves the issue for leaf nodes, e.g. y.retain_grad() (see the sketch after this update).

However, retain_grad does not seem to solve it for non-leaf nodes. Any suggestions?
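
For reference, a minimal sketch of that retain_grad() route, using the same graph as the first snippet (y is the intermediate, non-leaf tensor here; retain_grad() asks autograd to keep its gradient after backward()):

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2          # intermediate (non-leaf) tensor
z = 2*y

y.retain_grad()   # ask autograd to keep y's gradient after backward()
z.backward()

print(x.grad)     # tensor([4.]) -- leaf gradients are stored automatically
print(y.grad)     # tensor([2.]) -- populated only because of retain_grad()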


Answer 1:


I think you can use those hooks to store the gradients in a global list:

import torch

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

x.register_hook(lambda d: grads.append(d))  # appends dz/dx
y.register_hook(lambda d: grads.append(d))  # appends dz/dy

z.backward()

But you most likely also need to remember which tensor each gradient was computed for. In that case, we can slightly extend the above by using a dict instead of a list:

import torch

grads = {}
x = torch.tensor([1., 2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad, parent):
    print(grad, parent)
    grads[parent] = grad.clone()  # clone so later in-place ops can't change it

x.register_hook(lambda grad: store(grad, x))
y.register_hook(lambda grad: store(grad, y))

z.sum().backward()

Now you can, for example, access tensor y's grad simply using grads[y].
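
For instance, after running the snippet above (the values shown assume the x = torch.tensor([1., 2.]) input used there):

print(grads[y])   # tensor([2., 2.]) -- d(z.sum())/dy is 2 per element
print(grads[x])   # tensor([4., 8.]) -- d(z.sum())/dx = 4*x

Note that dicts hash tensors by object identity, so grads[y] only works with the very same tensor object the hook was registered on.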



Source: https://stackoverflow.com/questions/55305262/how-to-return-intermideate-gradients-for-non-leaf-nodes-in-pytorch
