I am confused about how to compute higher-order multivariate derivatives in JAX.
For example, how do you compute the mixed partial derivative d^2f / dx dy for
def f(x, y): ret
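The body of `f` is cut off above, so the sketch below uses a hypothetical stand-in, f(x, y) = x**2 * y**3, whose mixed partial d^2f / dx dy = 6 * x * y**2 is easy to check by hand. It assumes the usual approach of nesting `jax.grad` calls and picking the variable at each level with `argnums`. It also shows `jax.hessian` with `argnums=(0, 1)`, which returns the full nested block of second derivatives. This is a minimal sketch, not necessarily what the asker's real `f` needs.

```python
import jax

# Hypothetical stand-in for the truncated f in the question:
# f(x, y) = x**2 * y**3, so d^2f/dx dy = 6 * x * y**2.
def f(x, y):
    return x**2 * y**3

# Differentiate w.r.t. y (argnums=1), then differentiate that result w.r.t. x (argnums=0).
# jax.grad needs a scalar-valued function and floating-point inputs.
d2f_dxdy = jax.grad(jax.grad(f, argnums=1), argnums=0)

value = d2f_dxdy(2.0, 3.0)
print(value)  # expected 6 * 2 * 3**2 = 108

# Alternative: jax.hessian with argnums=(0, 1) returns a nested tuple of blocks;
# H[0][1] is the mixed partial (d/dy of df/dx), which equals 108 here as well.
H = jax.hessian(f, argnums=(0, 1))(2.0, 3.0)
print(H[0][1])
```

Nesting `jax.grad` is the most direct route when only one mixed partial is needed. `jax.hessian` computes every second derivative at once, which is convenient when the whole matrix is wanted but does extra work otherwise.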