Switching to JAX
JAX is a Python library for writing and transforming array programs, such as arise in the course of deep learning research. The JAX API for writing array programs is similar to that of NumPy, but the JAX API for transforming array programs is utterly unique, effortlessly powerful, and delightfully elegant.
A few years ago, I switched from PyTorch to JAX for my deep learning research. I have since fallen in love with JAX, and I would never go back. Here are some reasons why.
JAX is fast. The main selling point of JAX is its integration with the XLA array program compiler. XLA minimises unnecessary memory allocations and reshapes computations to fit nicely on whatever processors are available. This can lead to multiple-times speed boosts over uncompiled programs, everywhere from large-scale training runs on a cluster to small-scale pilot/debugging runs on a laptop CPU.
These days, I’m told PyTorch supports compilation and should be able to match JAX for speed, at least for standard deep learning workflows. However, the feature is in fundamental tension with the rest of PyTorch’s eager execution design. By contrast, JAX is built from the ground up around the concept of compilation.
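To make the compilation story concrete, here is a minimal sketch of `jax.jit` in action. The function and its name are illustrative; the point is that XLA fuses the chain of elementwise operations into a single compiled kernel, and the compiled version is cached and reused on subsequent calls.

```python
import jax
import jax.numpy as jnp

# A toy computation: XLA fuses these elementwise ops into one kernel,
# avoiding intermediate memory allocations.
def normalise(x):
    return (x - x.mean()) / (x.std() + 1e-8)

fast_normalise = jax.jit(normalise)  # compiled on first call, cached after

x = jnp.arange(8.0)
y = fast_normalise(x)         # first call triggers compilation
z = fast_normalise(x + 1.0)   # same shapes/dtypes: reuses the compiled kernel
```

Shifting the whole array by a constant doesn’t change the normalised result, so `y` and `z` are equal, and no recompilation happens because the input shape and dtype are unchanged.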
JAX is portable. JAX isn’t just useful on GPUs. It’s equally happy to optimise the same codebase to run on my M2 Air. When I do need to reach for more compute, I can take advantage of Google’s TPU Research Cloud programme and train on a free TPU cluster.
PyTorch is actually pretty portable too, at least across CPU, GPU, and Apple hardware. However, TPUs in particular are a pain point. The PyTorch/XLA backend exists, and I’ve tried it, but I would not use the word “portable” to describe the experience. It was a nightmare to set up the environment and modify the training code to work with PyTorch/XLA. I hope the situation gets better in the future, but in the meantime, JAX was made to work seamlessly with TPUs.
JAX is elegant. The JAX library design celebrates the powerful abstractions forged over hundreds of years by mathematicians (tensors, functions, differentiation), alongside the most fundamental computational operations (batching, iteration, parallelisation, compilation). The API makes these elements of array programming effortless individually and in composition. As a result, natural ideas are natural to implement in JAX.
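A small sketch of what “effortless in composition” means in practice: differentiation (`jax.grad`), batching (`jax.vmap`), and compilation (`jax.jit`) are ordinary function transformations, so they nest freely. The example function is mine, not from the post.

```python
import jax
import jax.numpy as jnp

# f: R -> R. grad gives its derivative, vmap maps that derivative over
# a batch, and jit compiles the whole composition.
f = lambda x: x ** 3

df = jax.jit(jax.vmap(jax.grad(f)))  # d/dx x^3 = 3x^2, batched and compiled

xs = jnp.array([1.0, 2.0, 3.0])
grads = df(xs)  # [3., 12., 27.]
```

Each transformation is independent of the others, which is exactly the compositionality the mathematical notation has: you write the scalar function once and derive the batched gradient from it.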
PyTorch itself defeated TensorFlow by removing boilerplate and friction in implementing natural ideas. But PyTorch’s design is overfit to the current deep learning paradigm (neural network, forward pass, backward pass, train, evaluate). JAX’s flexible and composable function transformations bring new possibilities within reach. In the course of research—where the goal is to push the boundaries of the possible—the best tool is the one that most closely mirrors the mathematical language in which we have and share our insights.
JAX is good for your code. JAX will only transform functions that conform to certain restrictions: JAX arrays are immutable, functions can’t normally have side effects, and types (including array shapes) must be statically determinable. These restrictions nudge your code to align with best-practice functional programming principles, making interdependencies more explicit, and thereby making your code more straightforward, more readable, and less likely to contain errors.
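The immutability restriction is the one people usually meet first. A quick sketch of what it looks like: in-place assignment raises an error, and the `.at[...].set(...)` idiom returns an updated copy instead (under `jit`, XLA typically turns this back into an in-place update, so there’s no copy in the compiled program).

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# x[1] = 5.0  # would raise a TypeError: JAX arrays are immutable

y = x.at[1].set(5.0)  # functional update: returns a new array
# x is untouched, so no function can mutate its caller's data behind its back
```

The payoff is exactly the explicitness mentioned above: every data dependency is visible in the return values, so there is no hidden state for a transformation (or a reader) to trip over.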
In PyTorch, if you forget to track gradients (or accidentally stop tracking them), your code still runs, and you have no idea what is wrong. Determinism requires seeding several different procedural random number generators, and you had better not change how many draws you make, or the order in which you make them, anywhere in your codebase. These problems are difficult to run into in JAX due to the way the autograd and PRNG APIs are designed. JAX has its foot-guns, but at least they don’t have silencers on (at least, not by default).
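For the PRNG point specifically, here is a minimal sketch of the design. Instead of a hidden global generator, randomness flows through explicit key values that you split deterministically; reusing a key reproduces the same draws, and adding a draw elsewhere in your codebase can’t silently shift everyone else’s stream.

```python
import jax

key = jax.random.PRNGKey(0)           # an explicit seed value, not global state
key, subkey = jax.random.split(key)   # deterministic split into fresh keys

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))   # same key, same draws: fully reproducible
```

Because the key is an ordinary value threaded through your functions, determinism is a local property of the code you can see, not a global property of everything that ran before it.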
This isn’t to say that I think JAX is perfect, or the right choice for everyone. Here are some barriers you might face if you want to make the switch, and my thoughts on each.
JAX has limitations. For example, it inherits Python’s lack of support for type checking array programs. Also, it’s slightly awkward to have to twist your for loops and if statements into unnatural shapes so that JAX can accelerate them (though I think the problem has been overstated). I have a small wishlist of other minor and miscellaneous API design improvements.
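To show what “twisting your loops and if statements” amounts to, here is a sketch of the two common substitutions, with illustrative function names of my own. A Python `if` on a traced value fails under `jit`, so the branch becomes `jnp.where`; a Python `for` loop is unrolled at trace time, so `jax.lax.fori_loop` keeps it as a loop inside the compiled program.

```python
import jax
import jax.numpy as jnp

# Branch expressed as data flow rather than Python control flow.
@jax.jit
def leaky(x, slope=0.1):
    return jnp.where(x > 0, x, slope * x)

# Loop expressed with lax.fori_loop so it stays a loop after compilation.
@jax.jit
def running_sum(xs):
    def body(i, acc):
        return acc + xs[i]
    return jax.lax.fori_loop(0, xs.shape[0], body, 0.0)

total = running_sum(jnp.array([1.0, 2.0, 3.0]))
out = leaky(jnp.array([-1.0, 2.0]))
```

It’s an unnatural shape at first, but the patterns are few and mechanical, which is why I think the problem is overstated: once you’ve rewritten a handful of loops this way, the translation becomes automatic.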
But it’s on the right track. JAX is far and away the closest thing that exists to my conception of the ideal array programming library, and it’s under active development by a team that has shown they can get the big picture so right. The future looks bright.
JAX is less popular. Since JAX hasn’t been around as long as PyTorch, the latter has a larger and more well-resourced ecosystem. If you switch to JAX, you might have to write some code yourself when the paper you want to replicate or the architecture you want to try provides a PyTorch codebase. Your colleagues might refuse to port their existing code to work with JAX.
But you are not alone. While still small relative to the PyTorch ecosystem, the JAX ecosystem is decent in absolute terms today. There is an active community that has built many awesome JAX tools. I see new projects posted every week.
JAX is hard to learn. Using JAX requires thinking about array programs in new ways. You will encounter challenges you have no idea how to solve, and this will make you feel like you’re a beginner programmer again. I hit this wall myself. It’s the same kind of wall I saw hundreds of students hit back when I taught Haskell to Python programmers. Crossing this kind of wall is difficult and it takes time.
But you can do it. My students always cracked the “Haskell mindset” with enough persistence. I myself eventually cracked the “JAX mindset.” I am confident that you can too. You already learned to program once. In fact, I bet you found it thrilling—do you remember what it was like every time you mastered a new pattern of solving coding problems?
I learned JAX by reading the (refreshingly excellent) documentation, and working through the many challenges that arose in the course of building a deep reinforcement learning research codebase. Since then, I was inspired to turn my experience into a free online course, where I show people how to work through the same kinds of challenges. I hope it might help you cross the wall, if you’re willing to try.
If you think about neural networks in terms of equations; if you are not afraid of learning how to solve problems in new ways; if you don’t want to let your tools hold you back from pushing back the unknown as fast and as far as possible; consider switching to JAX.
—MFR