PyTorch: most efficient Jacobian/Hessian calculation

Submitted by 天涯浪子 on 2019-12-23 19:43:21

Question


I am looking for the most efficient way to compute the Jacobian of a function in PyTorch, and have so far come up with the following solutions:

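import torch
from time import time
from torch.autograd import grad

def func(X):
    # Row-wise sums of the 2nd, 3rd and 4th powers; output shape (batch, 3)
    return torch.stack((
        X.pow(2).sum(1),
        X.pow(3).sum(1),
        X.pow(4).sum(1),
    ), 1)

# Leaf tensor created directly on the GPU (Variable is deprecated since PyTorch 0.4)
X = torch.full((1, int(1e5)), 2.00094, device="cuda", requires_grad=True)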

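# Solution 1: one autograd call per output component
t = time()
Y = func(X)
J = torch.zeros(3, int(1e5), device=X.device)
for i in range(3):
    J[i] = grad(Y[0][i], X, create_graph=True, retain_graph=True, allow_unused=True)[0]
# Note: CUDA kernels run asynchronously, so without torch.cuda.synchronize()
# this wall-clock timing may understate the real cost.
print(time()-t)
# Output: 0.002 s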

# Solution 2: replicate the input once per output and backpropagate an identity matrix
def Jacobian(f, X):
    X_batch = X.detach().repeat(3, 1).requires_grad_(True)
    f(X_batch).backward(torch.eye(3, device=X.device), retain_graph=True)
    return X_batch.grad

t = time()
J2 = Jacobian(func, X)
print(time()-t)
# Output: 0.001 s

Since there does not seem to be a big difference between the loop in the first solution and the batched approach in the second, I wanted to ask whether there is a still faster way to calculate a Jacobian in PyTorch.
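For comparison, PyTorch releases from 1.5 onward ship a built-in helper, torch.autograd.functional.jacobian. The sketch below applies it to the same func on a small CPU tensor. The input size and the final reshape are illustrative assumptions, and no claim is made here that it is faster than either solution above.

```python
import torch
from torch.autograd.functional import jacobian

def func(X):
    # Same function as in the question: row-wise sums of powers 2, 3 and 4.
    return torch.stack((X.pow(2).sum(1), X.pow(3).sum(1), X.pow(4).sum(1)), 1)

# Small CPU input for illustration (the question uses 1e5 columns on the GPU).
X = torch.full((1, 5), 2.0)

# The result has shape func(X).shape + X.shape, here (1, 3, 1, 5).
J_full = jacobian(func, X)

# Collapse the two singleton batch dimensions to get a (3, 5) matrix.
J = J_full.reshape(3, -1)
```

The helper also accepts a vectorize=True argument in later releases. The documentation describes it as experimental, so it is worth timing on the actual workload before relying on it.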

My second question is what the most efficient way to calculate the Hessian would be.
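As a starting point for the Hessian, the same module provides torch.autograd.functional.hessian, which expects a scalar-valued function. The sketch below uses a hypothetical scalar function, scalar_func, which is not part of the original question, chosen because its Hessian is easy to verify by hand.

```python
import torch
from torch.autograd.functional import hessian

def scalar_func(x):
    # Hypothetical example: f(x) = sum_i x_i^3, whose Hessian is diag(6 * x).
    return x.pow(3).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# For a 1-D input of length n the result is an (n, n) matrix.
H = hessian(scalar_func, x)
```

For a vector-valued function such as func, one option is to call hessian once per output component. Whether that is efficient enough depends on the problem size.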

Finally, does anyone know whether something like this can be done more easily or more efficiently in TensorFlow?


Answer 1:


I had a similar problem, which I solved by defining the Jacobian manually (calculating the derivatives by hand). For my problem this was feasible, but I can imagine that it is not always the case. On my machine (CPU), the hand-written version runs roughly 20 times faster than the second solution (539 µs versus 11.7 ms).

# Solution 2
def Jacobian(f, X):
    # Replicate the input once per output, then backpropagate an identity matrix
    X_batch = X.detach().repeat(3, 1).requires_grad_(True)
    f(X_batch).backward(torch.eye(3, device=X.device), retain_graph=True)
    return X_batch.grad

%timeit Jacobian(func,X)
11.7 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

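# Solution 3: Jacobian written out by hand
def J_func(X):
    # Analytic derivatives of x^2, x^3 and x^4; for X of shape (1, N) the result has shape (1, 3, N)
    return torch.stack((
        2*X,
        3*X.pow(2),
        4*X.pow(3),
    ), 1)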

%timeit J_func(X)
539 µs ± 24.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
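A hand-derived Jacobian is only useful if it agrees with autograd, so a quick check on a small input is worthwhile. The sketch below compares J_func against a per-row torch.autograd.grad loop. The input size, the value 1.5, and the variable names are illustrative assumptions.

```python
import torch
from torch.autograd import grad

def func(X):
    # Function from the question: row-wise sums of powers 2, 3 and 4.
    return torch.stack((X.pow(2).sum(1), X.pow(3).sum(1), X.pow(4).sum(1)), 1)

def J_func(X):
    # Hand-derived Jacobian from Solution 3.
    return torch.stack((2 * X, 3 * X.pow(2), 4 * X.pow(3)), 1)

# Small CPU input so the comparison is cheap.
X = torch.full((1, 4), 1.5, requires_grad=True)
Y = func(X)

# One autograd call per output component; each gradient has shape (1, 4).
rows = [grad(Y[0, i], X, retain_graph=True)[0] for i in range(3)]
J_auto = torch.cat(rows, 0)  # shape (3, 4)

# J_func returns shape (1, 3, 4); drop the leading batch dimension.
J_manual = J_func(X.detach()).squeeze(0)

matches = torch.allclose(J_auto, J_manual)
```

If matches is False, the hand-written derivatives (or the assumed tensor layout) need another look before the faster version is used.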


Source: https://stackoverflow.com/questions/56480578/pytorch-most-efficient-jacobian-hessian-calculation
