Not OP. I have production/scale experience in PyTorch and toy/hobby experience in JAX. I wish I had the time or liberty to use JAX more. It consists of a small, orthogonal set of ideas that combine like Lego blocks. I can attempt to reason from first principles about performance. The documentation is super readable and strives to make you understand things.
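To make the "Lego blocks" point concrete, here is a minimal sketch of three JAX transformations (jax.grad, jax.vmap, jax.jit) stacked on one plain function. The loss function and the toy inputs are made up for illustration; only the three transformations are real JAX API.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical scalar loss for a single example: squared dot product.
    return jnp.dot(w, x) ** 2


# grad: derivative of loss with respect to its first argument (w).
grad_loss = jax.grad(loss)

# vmap: apply grad_loss across a batch of x (axis 0) while sharing w.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0))

# jit: compile the composed function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

# Toy inputs, chosen only so the result is easy to check by hand.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# One gradient per example; analytically each row is 2 * (w . x) * x.
grads = fast_per_example_grads(w, xs)
print(grads.shape)  # (3, 2)
```

Each transformation takes a function and returns a function, so they can be reordered or dropped independently, which is roughly what I mean by orthogonal.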
JAX seems well engineered. One could argue so was TensorFlow. But the ideas behind JAX were built outside Google (autograd), so it has struck the right balance with being close to idiomatic Python / NumPy.
PyTorch is where the tailwinds are, though. It is a wildly successful project which has acquired a ton of code over the years. So it is a little harder to figure out how something works (say torch.compile) from first principles.