What does the parameter retain_graph mean in the Variable's backward() method?

隐瞒了意图╮ 2020-12-12 16:13

I'm going through the neural transfer PyTorch tutorial and am confused about the use of retain_variable (deprecated, now referred to as retain_graph).

2 Answers
  •  盖世英雄少女心
    2020-12-12 16:53

    @cleros is pretty much on point about the use of retain_graph=True. In essence, it will retain any information necessary to calculate a certain variable, so that we can do a backward pass on it.

    An illustrative example

    Suppose that we have the computation graph built by the code below. The variables d and e are the outputs, and a is the input. For example,

    import torch
    from torch.autograd import Variable  # Variable is deprecated; plain tensors behave the same way
    a = Variable(torch.rand(1, 4), requires_grad=True)  # input
    b = a**2
    c = b*2
    d = c.mean()  # first output
    e = c.sum()   # second output
    

    When we do d.backward(), that is fine. After this computation, the part of the graph that calculates d will be freed by default to save memory. So if we then do e.backward(), an error message will pop up. In order to do e.backward(), we have to set the parameter retain_graph to True in d.backward(), i.e.,

    d.backward(retain_graph=True)
    

    As long as you use retain_graph=True in your backward() calls, you can do backward passes whenever you want:

    d.backward(retain_graph=True) # fine
    e.backward(retain_graph=True) # fine
    d.backward() # also fine, but this call frees the graph
    e.backward() # error will occur, because the graph has already been freed!
    

    More useful discussion can be found here.
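
    For reference, here is a minimal self-contained sketch of the same experiment, written with the current tensor API instead of the deprecated Variable wrapper; the exact wording of the error message may differ between PyTorch versions:

    import torch

    a = torch.rand(1, 4, requires_grad=True)
    b = a ** 2
    c = b * 2
    d = c.mean()
    e = c.sum()

    d.backward(retain_graph=True)   # fine: the graph is kept alive
    e.backward(retain_graph=True)   # fine: the graph is still alive
    d.backward()                    # also fine, but this call frees the graph
    try:
        e.backward()                # backward through an already-freed graph
    except RuntimeError as err:
        print("backward failed:", err)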

    A real use case

    Right now, a real use case is multi-task learning, where you have multiple losses which may sit at different layers. Suppose that you have two losses, loss1 and loss2, residing in different layers. In order to backprop the gradients of loss1 and loss2 w.r.t. the learnable weights of your network independently, you have to use retain_graph=True in the backward() call of the first back-propagated loss.

    # suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
    loss1.backward(retain_graph=True) # keep the graph alive for the second backward pass
    loss2.backward() # now the graph is freed, and the next step of gradient descent is ready
    optimizer.step() # update the network parameters
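
    To make the pattern concrete, here is a small runnable sketch of such a two-loss setup; the network, the losses and the data below are hypothetical placeholders, not something from the original post:

    import torch
    import torch.nn as nn

    # hypothetical shared trunk with two task-specific heads
    trunk = nn.Linear(10, 16)
    head1 = nn.Linear(16, 1)
    head2 = nn.Linear(16, 1)

    params = list(trunk.parameters()) + list(head1.parameters()) + list(head2.parameters())
    optimizer = torch.optim.SGD(params, lr=0.01)

    x = torch.randn(8, 10)       # dummy input batch
    target1 = torch.randn(8, 1)  # dummy targets for the two tasks
    target2 = torch.randn(8, 1)

    features = trunk(x)          # shared part of the graph used by both losses
    loss1 = nn.functional.mse_loss(head1(features), target1)
    loss2 = nn.functional.mse_loss(head2(features), target2)

    optimizer.zero_grad()
    loss1.backward(retain_graph=True)  # keep the shared graph for the second backward
    loss2.backward()                   # frees the graph; gradients of both losses accumulate
    optimizer.step()                   # update with the accumulated gradients

    Note that the two backward calls simply accumulate gradients into .grad, so the single optimizer.step() applies the sum of both.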
    
