Understanding cdist() function

Submitted by 生来就可爱ヽ(ⅴ<●) on 2021-01-28 09:40:56

Question


What does this new_cdist() function actually do? More specifically:

  1. Why is there a sqrt() operation when the AdderNet paper does not use it in its backward propagation equation?
  2. How is needs_input_grad[] used?
import numpy as np
import torch


def new_cdist(p, eta):
    class cdist(torch.autograd.Function):
        @staticmethod
        def forward(ctx, W, X):
            # Negated pairwise p-norm distance between the rows of W and the rows of X.
            ctx.save_for_backward(W, X)
            out = -torch.cdist(W, X, p)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            W, X = ctx.saved_tensors
            grad_W = grad_X = None
            if ctx.needs_input_grad[0]:
                # Gradient with respect to the filters W.
                _temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                grad_W = torch.matmul(grad_output, _temp)
                # print('before norm: ', torch.norm(grad_W))
                # Adaptive learning rate scaling: rescale grad_W so that its l2-norm
                # becomes eta * sqrt(number of elements in grad_W).
                grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
                print('after norm: ', torch.norm(grad_W))
            if ctx.needs_input_grad[1]:
                # Gradient with respect to the input X; the pairwise distance term is
                # clipped to [-1, 1] with HardTanh before forming grad_X.
                _temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                _temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
                grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
            return grad_W, grad_X

    return cdist().apply
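
For context, here is a minimal sketch of how I would call it (the shapes, p and eta values below are just illustrative, not from the AdderNet code):

import torch

W = torch.randn(8, 16, requires_grad=True)   # e.g. 8 adder filters of dimension 16
X = torch.randn(32, 16, requires_grad=True)  # e.g. 32 unfolded input patches
dist_fn = new_cdist(p=1, eta=0.2)            # L1 distance; eta value is just a guess
out = dist_fn(W, X)                          # shape (8, 32): negated L1 distances
out.sum().backward()                         # runs the custom backward() above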

What I mean is that it seems to implement a new kind of back-propagation rule together with an adaptive learning rate.


Answer 1:


Actually, the AdderNet paper does use the sqrt. It is in the adaptive learning rate computation (Algorithm 1, line 6). More specifically, you can see that Eq. 12,

ΔF_l = γ × α_l × ΔL(F_l)

is what is written in this line (the global learning rate γ is applied later by the optimizer):

# alpha_l = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W)
grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W

and the sqrt() comes from Eq. 13:

α_l = η √k / ||ΔL(F_l)||_2

where k denotes the number of elements in F_l used to average the l2-norm, and η is a hyper-parameter that controls the learning rate of the adder filters.
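
In code, that adaptive scaling amounts to something like the following standalone sketch (not the AdderNet source; the eta value and shapes are illustrative, and grad_W here is just a random stand-in for ΔL(F_l)):

import torch

eta = 0.2                                      # hyper-parameter from Eq. 13 (value is illustrative)
grad_W = torch.randn(8, 16)                    # stand-in for the gradient ΔL(F_l)

k = grad_W.numel()                             # number of elements in the filter F_l
alpha_l = eta * k ** 0.5 / torch.norm(grad_W)  # Eq. 13: alpha_l = eta * sqrt(k) / ||ΔL(F_l)||_2
scaled_grad = alpha_l * grad_W                 # Eq. 12, without the global learning rate gamma

# After scaling, the l2-norm of the gradient is eta * sqrt(k), regardless of its original magnitude.
print(torch.norm(scaled_grad), eta * k ** 0.5)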


About your second question: needs_input_grad is a tuple of booleans, one per input to forward(), telling you whether that input actually requires a gradient, so backward() can skip computations that are not needed. [0] in this case refers to W, and [1] to X. You can read more about it in the PyTorch notes on extending torch.autograd.
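
For example, here is a minimal custom Function (unrelated to AdderNet, just to show the mechanism):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        # needs_input_grad has one flag per forward input, in order:
        # [0] -> x, [1] -> w.  Skip the work for inputs that don't need gradients.
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * w
        if ctx.needs_input_grad[1]:
            grad_w = grad_output * x
        return grad_x, grad_w

x = torch.randn(3, requires_grad=True)
w = torch.randn(3)                       # no requires_grad -> needs_input_grad[1] is False
Scale.apply(x, w).sum().backward()       # only grad_x is computed; w.grad stays None
print(x.grad)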



Source: https://stackoverflow.com/questions/61154470/understanding-cdist-function
