I am trying to optimize the following function with JAX's jit:
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0, 1))
def coocurrence_helper(