PyTorch: what are the gradient arguments?

北荒 2020-11-30 16:44

I am reading through the documentation of PyTorch and found an example where they write

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradien         


        
4 Answers
  •  Happy的楠姐
    2020-11-30 17:35

    Here, the output of forward(), i.e. y, is a 3-vector.

    The three values are the gradients arriving at the output of the network, one per element of y. They are usually set to 1.0 if y is the final output, but they can take other values as well, especially if y is part of a bigger network.
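
As a rough illustration of the point above, the sketch below compares a scalar final output, where calling backward() with no argument behaves like passing 1.0, with a 3-vector output, where the three values have to be supplied explicitly. The functions used (a sum of squares and y = 2 * x) and the variable names are assumptions chosen for the example, not taken from the documentation being quoted.

```python
import torch

# Scalar final output: backward() with no argument acts like a seed of 1.0.
x1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss1 = (x1 ** 2).sum()            # illustrative function, not from the docs
loss1.backward()
grad_implicit = x1.grad.clone()

# Same computation on a fresh graph, with the seed 1.0 passed explicitly.
x2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss2 = (x2 ** 2).sum()
loss2.backward(torch.tensor(1.0))  # seed must have the same shape as loss2
grad_explicit = x2.grad.clone()

# 3-vector output: the gradient argument is required and must match y's shape.
x3 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = 2 * x3                         # illustrative function, dy_i/dx_i = 2
gradients = torch.tensor([0.1, 1.0, 0.0001])
y.backward(gradients)
grad_vector = x3.grad.clone()      # each entry is 2 * gradients[i]

print(grad_implicit, grad_explicit, grad_vector)
```

In this sketch the first two gradients coincide, and the third is simply the supplied values scaled by the local derivative of y.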

    For example, suppose x is the input and y = [y1, y2, y3] is an intermediate output that is used to compute the final output z.

    Then,

    dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx
    

    So here, the three values passed to backward() are

    [dz/dy1, dz/dy2, dz/dy3]
    

    and backward() then combines them with dy1/dx, dy2/dx and dy3/dx to compute dz/dx.
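
The chain rule above can be checked numerically. The following is a minimal sketch, assuming a scalar input x, an intermediate y = [x, x^2, x^3], and a final output z that is a weighted sum of y, so that [dz/dy1, dz/dy2, dz/dy3] equals the weights. All of these choices, including the names dz_dy, x_a and x_b, are hypothetical and only serve the illustration.

```python
import torch

# Assumed weights, so that dz/dy_i = dz_dy[i] for z = sum(dz_dy * y).
dz_dy = torch.tensor([0.1, 1.0, 0.0001])

def intermediate(x):
    # Hypothetical intermediate output y = [y1, y2, y3] = [x, x^2, x^3].
    return torch.stack([x, x ** 2, x ** 3])

# Path A: build the full graph up to z and call backward() on the scalar z.
x_a = torch.tensor(2.0, requires_grad=True)
z = (dz_dy * intermediate(x_a)).sum()
z.backward()
grad_via_z = x_a.grad.clone()

# Path B: stop at y and hand [dz/dy1, dz/dy2, dz/dy3] to backward().
x_b = torch.tensor(2.0, requires_grad=True)
y = intermediate(x_b)
y.backward(dz_dy)
grad_via_y = x_b.grad.clone()

# By hand: dz/dx = 0.1 * 1 + 1.0 * 2x + 0.0001 * 3x^2, evaluated at x = 2.
expected = torch.tensor(0.1 * 1 + 1.0 * 2 * 2.0 + 0.0001 * 3 * 2.0 ** 2)

print(grad_via_z, grad_via_y, expected)
```

Both paths should give the same dz/dx, which is what makes the gradient argument useful when y is only one part of a bigger network.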
