Not OP. I prefer JAX for non-AI tasks in scientific computing because its mental model is different from PyTorch's. In JAX, you think about functions and the gradients of functions. In PyTorch, you think about tensors that accumulate gradients as they are manipulated through functions. JAX just suits my way of thinking much better.
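To make the "gradients of functions" view concrete, here's a minimal sketch of my own (not anyone's production code): jax.grad takes a function and returns another function that computes its gradient, with no mutable tensor state anywhere.

    import jax
    import jax.numpy as jnp

    # A plain mathematical function: f(x) = sum(x^2)
    def f(x):
        return jnp.sum(x ** 2)

    # grad(f) is itself a function: x -> df/dx = 2x
    df = jax.grad(f)

    x = jnp.array([1.0, 2.0, 3.0])
    print(df(x))  # [2. 4. 6.]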
I also like that jax.jit forces you to write "functional" functions free of side effects or in-place array updates. It might feel weird at first (and not every algorithm is suited to this style), but ultimately it leads to clearer and faster code.
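As a toy example of what "no in-place updates" looks like in practice (my own sketch): inside a jitted function you express updates with the functional .at[...] syntax, which returns a new array instead of mutating the input.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def step(x):
        # No mutation: .at[...].set(...) returns a new array, and XLA can
        # often turn it back into an in-place update at compile time.
        y = x.at[0].set(0.0)
        return y * 2.0

    x = jnp.arange(4.0)
    print(step(x))  # [0. 2. 4. 6.]
    print(x)        # original array untouched: [0. 1. 2. 3.]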
I am surprised that JIT in PyTorch gets so little attention. Maybe it's less impactful for PyTorch's usual use case of large networks, as opposed to general scientific computing?
> I also like that jax.jit forces you to write "functional" functions free of side effects or in-place array updates. It might feel weird at first (and not every algorithm is suited to this style), but ultimately it leads to clearer and faster code.
It's not weird. It's actually the most natural way of doing things for me. You just write down your math equations in JAX and you're done.
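For instance (my own illustration, not the commenter's code), Newton's method can be written almost exactly as the update equation reads, with the derivative coming from jax.grad instead of being hand-coded:

    import jax

    # Newton's method: x_{n+1} = x_n - f(x_n) / f'(x_n)
    def newton(f, x0, steps=10):
        df = jax.grad(f)
        x = x0
        for _ in range(steps):
            x = x - f(x) / df(x)
        return x

    # Root of f(x) = x^2 - 2 is sqrt(2)
    print(newton(lambda x: x ** 2 - 2.0, 1.0))  # ~1.4142135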
> You just write down your math equations in JAX and you're done.
It's natural when your basic unit is a whole vector (tensor), manipulated by some linear algebra expression. It's less natural if your basic unit is an element of a vector.
If you're solving Sudoku, for example, the obvious 'update' is in-place.
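A concrete illustration of that friction (a toy sketch, assuming a 9x9 integer board): placing a single digit in JAX means building a whole new board array for every cell you fill, where the natural thing to write is a single in-place assignment.

    import jax.numpy as jnp

    board = jnp.zeros((9, 9), dtype=jnp.int32)

    # What you'd like to write (NumPy / in-place style):
    #   board[4, 4] = 7
    # What JAX requires instead: a fresh board per cell update.
    board = board.at[4, 4].set(7)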
In-place updates are also often the right answer for performance reasons, such as writing the output of a .map() operation directly into the destination tensor. JAX leans heavily on compile-time optimizations to turn the mathematically nice code into computer-nice code, so the delta between eager JAX and compiled JAX is much larger than the delta between eager PyTorch and compiled PyTorch.
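One way compiled JAX recovers in-place-style memory behaviour (a hedged sketch; buffer donation may be ignored on some backends, e.g. older CPU setups, with only a warning) is by donating input buffers to jit so XLA can reuse them for the output:

    from functools import partial
    import jax
    import jax.numpy as jnp

    # donate_argnums tells XLA it may reuse the input's buffer for the output.
    @partial(jax.jit, donate_argnums=(0,))
    def scale(x):
        return 2.0 * x

    x = jnp.ones((1024, 1024))
    y = scale(x)  # x's buffer may be reused; don't touch x afterwards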