In TensorFlow, tf.while_loop is fully differentiable: gradients can be computed through the loop, for example with tf.GradientTape.
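A minimal sketch of the TensorFlow side, assuming TensorFlow 2.x. The helper name power_loop and the x**3 example are hypothetical and serve only as illustration. The helper computes x**3 by repeated multiplication inside tf.while_loop, wrapped in tf.function so that a graph While op is actually built. tf.GradientTape then differentiates through that loop:

```python
import tensorflow as tf

@tf.function
def power_loop(x, n=3):
    """Hypothetical helper: computes x**n by repeated multiplication in tf.while_loop."""
    i0 = tf.constant(0)
    acc0 = tf.ones_like(x)

    def cond(i, acc):
        # Keep looping while fewer than n multiplications have been done.
        return i < n

    def body(i, acc):
        # One multiplication per iteration; the loop carries (counter, accumulator).
        return i + 1, acc * x

    _, result = tf.while_loop(cond, body, loop_vars=(i0, acc0))
    return result

x = tf.Variable(2.0)
with tf.GradientTape() as tape:
    y = power_loop(x)          # y = x**3

# Reverse-mode gradient through the while loop: d(x**3)/dx = 3 * x**2
dy_dx = tape.gradient(y, x)
print(float(y), float(dy_dx))  # 8.0 12.0
```

The gradient matches the analytic derivative 3 * 2.0**2 = 12.0, so backpropagation passes through every iteration of the loop.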
In JAX, according to the documentation of jax.lax.while_loop:
while_l