Test all functions with jax.test_util.check_grads to ensure differentiability #438
Open
Description
Is Your Feature Request Related to a Problem? Please Describe
As we move towards porting code to the Array API to make it GPU-enabled, we ultimately want it to be differentiable too. Rather than coming up with a solution that works on GPU but then needs a massive rewrite for differentiability, it would be good to ensure that functions are en route to differentiability at the same time.
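A minimal sketch of what such a test might look like. It assumes the ported functions can be called with `jax.numpy` arrays as their array namespace; `gaussian_log_density` is a hypothetical stand-in for a real project function. `check_grads` compares forward- and reverse-mode derivatives against finite differences and raises an `AssertionError` on mismatch. Double precision is enabled so the finite-difference comparison is not dominated by float32 round-off.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Assumption: tests run in double precision so finite differences are accurate.
jax.config.update("jax_enable_x64", True)


def gaussian_log_density(x, mu, sigma):
    """Hypothetical example function written only with array operations."""
    z = (x - mu) / sigma
    return jnp.sum(-0.5 * z**2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi))


def test_gaussian_log_density_grads():
    x = jnp.linspace(-1.0, 1.0, 5)
    mu = jnp.array(0.3)
    sigma = jnp.array(1.7)
    # Checks first- and second-order derivatives in both forward ("fwd")
    # and reverse ("rev") mode against numerical differentiation.
    check_grads(gaussian_log_density, (x, mu, sigma), order=2, modes=("fwd", "rev"))
```

Written as a plain `test_*` function, a check like this could be collected by pytest alongside existing tests, with one such test per public function.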
Describe the Solution You'd Like
No response
Describe Alternatives You've Considered
No response
Additional Context
No response