far.in.net


~Switching to JAX

JAX is a Python library for writing and transforming array programs, such as arise in the course of deep learning research. The JAX API for writing array programs is similar to that of NumPy, but the JAX API for transforming array programs is utterly unique, effortlessly powerful, and delightfully elegant.
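To make the two halves of that description concrete, here is a minimal sketch. The toy linear model, the `loss` function, and the data are all made up for illustration and are not taken from any particular project. The first half writes an array program with the NumPy-like `jax.numpy` API; the second half transforms it with `jax.grad`, `jax.jit`, and `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Writing an array program: jax.numpy mirrors the NumPy API.
def loss(w, x, y):
    """Mean squared error of a toy linear model (hypothetical example)."""
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Transforming array programs: each transformation takes a function
# and returns a new function.
grad_loss = jax.grad(loss)            # gradient with respect to the first argument, w
fast_grad_loss = jax.jit(grad_loss)   # just-in-time compiled version of that gradient
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))  # share w, map over rows of x and y

# Illustrative data: three examples with two features each.
w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, -2.0, 0.0])

g = fast_grad_loss(w, x, y)           # array with the same shape as w
losses = per_example_loss(w, x, y)    # one loss value per example
```

Note that the transformations compose: `fast_grad_loss` is simply `jit` applied to the output of `grad`, and `loss` itself is never modified.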

A few years ago, I switched from PyTorch to JAX for my deep learning research. I have since fallen in love with JAX, and I would never go back. Here are some reasons why.

This isn’t to say that I think JAX is perfect, or the right choice for everyone. Here are some barriers you might face if you want to make the switch, and my thoughts on each.

If you think about neural networks in terms of equations; if you are not afraid of learning how to solve problems in new ways; if you don’t want to let your tools hold you back from pushing back the unknown as fast and as far as possible; consider switching to JAX.

—MFR