Hi, JAX!
An introduction to vanilla JAX for deep learning
This page is for the online edition of Hi, JAX! See also: 2024 edition.
This is the syllabus for a course I am currently developing and will be releasing shortly. The course will be released freely for anyone on my YouTube channel. The code is in development on GitHub.
Development progress:
- Demonstrations written: 12 / 12.
- Lectures recorded: 11 / 12.
- Lectures edited: 10 / 12.
Estimated release date: January 2026
§Why learn JAX?
I’m a deep learning scientist. I try to understand the natural principles underpinning the structure and function of learned neural networks. In the course of my research, I have to train, test, and dissect many neural networks, large and small. Neural networks are essentially just arrays of floating point parameters, so the best way to interact with them is using an array programming library. That is where JAX comes in.
JAX is a Python library for writing and transforming array programs. The JAX API for writing array programs is similar to that of NumPy, but the JAX API for transforming array programs is utterly unique, effortlessly powerful, and delightfully elegant.
Since switching from using PyTorch to using JAX for my research, I have personally noticed the following benefits.
JAX is fast. The main selling point of JAX is its integration with the XLA array program compiler. XLA gives a speed boost for even small-scale JAX programs, which run multiple times faster than uncompiled array programs, even on the same processors.
JAX is portable. JAX isn’t just useful on GPUs. It’s equally happy to optimise code I want to run on my M2 Air. When I do need to reach for more compute, I don’t have to cough up for a GPU node, I can take advantage of Google’s TPU Research Cloud programme and run my code on a free TPU cluster.
JAX is elegant. Deep learning researchers describe their insights with the powerful abstractions forged over hundreds of years by mathematicians. The JAX library’s design celebrates mathematical abstractions, rather than over-fitting to some subset of methods that happen to be in vogue.
JAX is good for your code. It can be challenging to comply with the various restrictions JAX enforces (immutability, pure functions, static shapes). But even if they weren’t justified by the powerful optimisations they enable, these restrictions are all good ways to make code more explicit, readable, and, ultimately, less wrong.
JAX isn’t perfect. But it’s far and away the closest to my conception of the ideal array programming library. Plus, it’s under active development by a team that has shown they can get the big picture so right, so the future looks bright.
All else equal, I’d choose JAX for my projects, every time. Of course, all else is not always equal. Here are some reasons people might not want to switch to JAX.
JAX is less popular. It hasn’t been around as long as PyTorch, which has a larger and more well-resourced ecosystem. If you switch to JAX, you might have to write more code yourself when the paper you want to replicate or the architecture you want to try provides a PyTorch codebase. Your colleagues might refuse to port their existing code to work with JAX.
But you are not alone. While still small relative to the PyTorch ecosystem, the JAX ecosystem is decent in absolute terms today. There are many awesome JAX tools out there and I see new projects posted every week. Maybe I will see your contribution one day?
JAX is hard to learn. Using JAX requires thinking about array programs in new ways. You will encounter challenges you have no idea how to solve, and this will make you feel like you’re a beginner programmer again. This is difficult, and there aren’t many learning resources out there that convey the “JAX mindset.”
That changes, now. I went through this journey myself. It took me months to crack the “JAX mindset.” But I did eventually crack it, and now I have made this course—including programming demonstrations drawn from real challenges I came up against when learning JAX.
I can feel which way the wind is blowing. Can you?
§Prerequisites
Programming:
- Prior experience programming in Python (e.g., dataclasses, functools, type annotations).
- Prior experience programming in NumPy.
  - NumPy basics tutorial may be sufficient.
- Prior experience with einops.
  - Einops basics tutorial recommended (more than sufficient).
Machine learning:
- Basic vector calculus and optimisation (stochastic gradient descent).
- Basic deep learning architectures (MLP, CNN, Transformer).
- 3Blue1Brown Season 3 should be sufficient (skip chapter 4).
Helpful/optional (we will cover what is required):
- Basic reinforcement learning (Markov decision process formalism, policy gradients).
- Prior experience programming in, e.g., Rust (the concept of immutability).
- Prior experience programming in, e.g., Haskell (the concepts of pure functions, mapping, folding).
You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.
§Syllabus
The course covers the following topics. Each lecture includes a demonstration which I walk through line by line, along with a challenge where you take what you have learned and implement something yourself.
§§Overture
In which we first meet JAX and get a taste of how it differs from NumPy.
Hi, JAX! How’s life? Quick overview of JAX features, jax.numpy library.
Demonstration: Port an elementary cellular automaton simulator from NumPy to JAX.
Challenge: Accelerated Conway’s game of life simulator.
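To give a taste of the NumPy-to-JAX port before the lecture, here is a minimal sketch of one automaton update step. This is not the course's demonstration code; the rule number and array sizes are my own illustrative choices.

```python
import jax.numpy as jnp

# A minimal sketch: one update step of an elementary cellular automaton,
# written against jax.numpy instead of numpy.
def step(cells, rule=110):
    left = jnp.roll(cells, 1)          # wrap-around neighbourhoods via rolls
    right = jnp.roll(cells, -1)
    index = 4 * left + 2 * cells + right                # neighbourhood code 0..7
    table = jnp.right_shift(rule, jnp.arange(8)) & 1    # rule bits as a lookup table
    return table[index]

# jnp arrays are immutable, so in-place writes become functional .at updates.
cells = jnp.zeros(64, dtype=jnp.int32).at[32].set(1)
cells = step(cells)
```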
§§Act I: Basics
In which we learn the elementary components of JAX programs while implementing and training increasingly complex neural networks.
Hi, automatic differentiation! Functional model API, jax.grad transformation.
Demonstration: Train a teacher–student linear regression model with full-batch gradient descent.
Challenge: Train the teacher as well as the student.
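For a flavour of the functional style this lecture builds on, here is a minimal sketch (my own shapes and names, not the demonstration code): the model is just a parameter vector, and jax.grad differentiates a pure loss function with respect to its first argument.

```python
import jax
import jax.numpy as jnp

def loss(w_student, xs, ys):
    preds = xs @ w_student
    return jnp.mean((preds - ys) ** 2)

w_teacher = jax.random.normal(jax.random.key(0), (4,))
xs = jax.random.normal(jax.random.key(1), (128, 4))
ys = xs @ w_teacher                                   # teacher labels the data

w_student = jnp.zeros(4)
for _ in range(100):
    grads = jax.grad(loss)(w_student, xs, ys)         # full-batch gradient
    w_student = w_student - 0.1 * grads               # plain gradient descent
```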
Hi, procedural random number generation! Immutable PRNG state management, jax.random library.
Demonstration: Implement and train a classical perceptron with classical SGD.
Challenge: Implement and train a multi-layer perceptron.
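Here is a minimal sketch of explicit PRNG state management (the names are my own): keys are ordinary immutable values that you split and pass around, rather than hidden global state.

```python
import jax

key = jax.random.key(42)
key, w_key, noise_key = jax.random.split(key, 3)

w = jax.random.normal(w_key, (8, 2))           # reproducible initialisation
noise = jax.random.normal(noise_key, (100, 2))

# Re-using a key reproduces the same numbers; split again when you need more.
assert (jax.random.normal(w_key, (8, 2)) == w).all()
```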
Hi, PyTrees! PyTrees, jax.tree.map.
Demonstration: Implement and train an MLP with minibatch SGD.
Challenge: Add additional layers to the MLP.
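As a minimal sketch with a hypothetical two-layer MLP (not the course code): parameters live in an ordinary nested dict (a PyTree), and jax.tree.map applies the SGD update to every leaf array at once.

```python
import jax
import jax.numpy as jnp

params = {
    "layer1": {"w": jnp.zeros((784, 64)), "b": jnp.zeros(64)},
    "layer2": {"w": jnp.zeros((64, 10)), "b": jnp.zeros(10)},
}

def sgd_update(params, grads, lr=0.01):
    # grads comes from jax.grad(loss)(params, ...) and shares params' structure,
    # so tree.map can zip the two trees together leaf by leaf.
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)
```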
Hi, automatic vectorisation! Vectorisation with jax.vmap.
Demonstration: Implement and train a CNN on MNIST with minibatch SGD.
Challenge: Train an ensemble of CNNs.
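A minimal sketch of the vmap idea (illustrative shapes only, not the demonstration code): write the model for a single example, then let jax.vmap add the batch dimension.

```python
import jax
import jax.numpy as jnp

def forward(params, x):                    # x: one input of shape (features,)
    return x @ params["w"] + params["b"]

batched_forward = jax.vmap(forward, in_axes=(None, 0))   # map over axis 0 of x

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros(3)}
xs = jnp.ones((32, 4))                                   # a batch of 32 inputs
ys = batched_forward(params, xs)                         # shape (32, 3)

# The same trick vectorises over models instead: vmap over a stacked ensemble
# of parameter PyTrees with in_axes=(0, None).
```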
Hi, stateful optimisation! Managing state during a training loop.
Demonstration: Implement Adam in vanilla JAX.
Challenge: Implement Adam with weight decay.
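For a flavour of what "stateful optimisation" means in a purely functional setting, here is a simplified Adam sketch with the optimiser state carried as a PyTree (the lecture's implementation may differ).

```python
import jax
import jax.numpy as jnp

def adam_init(params):
    zeros = jax.tree.map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": 0}

def adam_update(params, grads, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    t = state["t"] + 1
    m = jax.tree.map(lambda m, g: b1 * m + (1 - b1) * g, state["m"], grads)
    v = jax.tree.map(lambda v, g: b2 * v + (1 - b2) * g**2, state["v"], grads)
    def leaf_step(p, m, v):
        m_hat = m / (1 - b1**t)              # bias correction
        v_hat = v / (1 - b2**t)
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return jax.tree.map(leaf_step, params, m, v), {"m": m, "v": v, "t": t}
```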
§§Act II: Acceleration
In which we explore various aspects of just-in-time compilation and the kinds of tricks we need to use to prepare our computational graphs for the XLA compiler.
Hi, just-in-time compilation! Compilation with jax.jit, tracing versus execution, side-effects, debugging tools.
Demonstration: JIT dojo, accelerate CNN training on MNIST.
Challenge: Implement and train a residual network.
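A minimal sketch of tracing versus execution (my own toy function, not course code): the Python body, including the print, runs once per input shape/dtype at trace time; after that, XLA runs the compiled program directly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing with abstract value:", x)     # trace-time side-effect
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f(jnp.arange(3.0))    # prints a tracer, compiles, then executes
f(jnp.arange(3.0))    # cache hit: no print, no recompilation
jax.debug.print("runs every call: {}", f(jnp.arange(3.0)))   # runtime printing
```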
Hi, loop acceleration! Looping computations with jax.lax.scan.
Demonstration: Accelerate a whole forward pass and a whole training loop.
Challenge: Vectorise a hyperparameter sweep and tune the learning rate for one of our previous models.
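Here is a minimal sketch (hypothetical shapes, not the demonstration code) of the core move: jax.lax.scan replaces a Python loop with a single compiled loop primitive, threading a weight vector through a sequence of minibatch SGD steps.

```python
import jax
import jax.numpy as jnp

def train_step(w, batch, lr=0.1):
    xs, ys = batch
    grads = jax.grad(lambda w: jnp.mean((xs @ w - ys) ** 2))(w)
    return w - lr * grads, None            # (new carry, per-step output)

def train(w0, all_xs, all_ys):
    # all_xs: (num_steps, batch_size, features); scan maps over the leading axis.
    final_w, _ = jax.lax.scan(train_step, w0, (all_xs, all_ys))
    return final_w
```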
Hi, static arguments! Compile errors due to non-static shapes, flagging static arguments.
Demonstration: Implement and train a byte-transformer on the Sherlock Holmes canon.
Challenge: Add dropout modules to the transformer.
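A minimal sketch of the static-argument idea (my own example): values that determine array shapes, here a sequence length, must be static. Marking them static makes jit specialise on them, at the cost of a recompile whenever the value changes.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("seq_len",))
def causal_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

causal_mask(seq_len=128)   # compiles for seq_len=128
causal_mask(seq_len=256)   # different static value: triggers a recompile
```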
Hi, branching computation! Stateful environment API, conditional computation with jax.lax.select, jax.numpy.where, and expression-level branching.
Demonstration: Implement a simple grid-world maze environment.
Alternative challenge: Add a locked door and a key to the grid-world environment.
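A minimal sketch of branchless conditionals in a hypothetical grid world (not the demonstration environment): both "branches" are just values, and jnp.where picks between them element-wise, which is how data-dependent ifs survive tracing and jit.

```python
import jax.numpy as jnp

def move(pos, delta, walls):
    proposed = pos + delta
    blocked = walls[proposed[0], proposed[1]]
    return jnp.where(blocked, pos, proposed)   # stay put if we would hit a wall

# jax.lax.select is the lower-level primitive with the same element-wise
# semantics; expression-level branching like this avoids Python `if` on
# traced values entirely.
```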
Hi, algorithms! Performance considerations for branching computation and parallelism, jax.lax.while_loop.
Demonstration: Comparative implementation of Kruskal’s minimum spanning tree algorithm with different union–find data structures.
Challenge: Determine solvability of a maze by implementing and accelerating depth-first search. Solve the maze by implementing and accelerating breadth-first search. For a harder challenge, implement and accelerate the Floyd–Warshall algorithm to compute all-pairs shortest paths.
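A minimal sketch of jax.lax.while_loop (my own toy example, not course code): it expresses data-dependent iteration counts, which a Python while cannot under jit.

```python
import jax
import jax.numpy as jnp

def first_power_of_two_above(threshold):
    cond = lambda x: x < threshold
    body = lambda x: x * 2
    return jax.lax.while_loop(cond, body, jnp.int32(1))

first_power_of_two_above(jnp.int32(1000))   # -> Array(1024, dtype=int32)
```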
§§Finale
In which we bring together everything we have learned to accelerate an end-to-end deep reinforcement learning environment simulation and training loop, one of the most effective uses of JAX for deep learning.
Hi, deep reinforcement learning! Revision of previous fundamental topics, reverse scan.
Demonstration: Accelerated PPO with GAE, train a policy to solve a small maze.
Challenge: Solve larger mazes.
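As a sketch of the "reverse scan" ingredient, here is a simplified generalised advantage estimation pass (my own simplification, with hypothetical inputs: per-step rewards, value estimates, next-state values, and termination flags; the demonstration may differ).

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    def step(carry, inputs):
        r, v, v_next, d = inputs
        delta = r + gamma * v_next * (1.0 - d) - v
        advantage = delta + gamma * lam * (1.0 - d) * carry
        return advantage, advantage
    _, advantages = jax.lax.scan(
        step, jnp.float32(0.0), (rewards, values, next_values, dones),
        reverse=True,                      # accumulate from the end of the rollout
    )
    return advantages
```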
§Projects
The best way to consolidate your understanding is to implement your own non-trivial JAX project. As a suggestion, you could choose one of the following research replication projects on the topics of mechanistic interpretability and the science of deep learning.
Hi, mechanistic interpretability! Take some image model we have trained and then switch to optimising over the space of input images (rather than parameters) to produce a feature visualisation for some neuron along the lines of [1].
Hi, computational mechanics! Implement a simple HMM generative process and train a transformer on this data, then probe for belief state geometry in the residual stream, replicating figures 4BCD and 6A from [2].
Hi, sparse autoencoders! Take an existing image model we have trained and then train a sparse autoencoder (SAE) on it to produce a feature visualisation along the lines of [3].
Hi, singular learning theory! Implement local learning coefficient estimation and then replicate (a rescaled version of) figure 3 from [4].
Hi, science of deep learning! Implement a synthetic in-context linear regression data generator and in-context dMMSE and ridge regression algorithmic baselines, and then train a transformer to replicate a (low-resolution or small-architecture version of) figure 2 from [5].
Hi, goal misgeneralisation! Implement the “cheese in the corner” environment including the distribution shift, and the domain randomisation RL algorithm, and train a policy to replicate the black line from figures 4(left) and H.1 (bottom left) from [6].
For those with a need for more speed and scale, you may be interested in one of these advanced self-study projects instead:
Hi, profiling tools! Learn how to use JAX profiling tools, then find and remove a 2x memory or runtime bottleneck in the hijax GitHub repository.
Hi, distributed training! Learn how to use distributed data loading and automatic parallelism, then find a hardware and hyperparameter configuration for some learning task where parallelism leads to an order-of-magnitude speed-up.
Whatever project you select, please let me know if you complete it! And if you like, I would be delighted to list your name and a link to your completed project here:
2024.07.27 Billy Snikkers completed a custom project (training a neural network to implement the update rule from Conway’s game of life) [repo].
2024.08.07 Rohan Hitchcock completed project 4 (Hi, singular learning theory!) [repo].
§Beyond the fundamentals
Mastering the fundamentals of JAX is valuable because vanilla JAX is a powerful tool for accelerated Python programming.
That’s great, but there’s no reason you have to stick to vanilla JAX. There is a growing ecosystem of libraries that build on top of JAX. Here is an incomplete and opinionated overview of some of the libraries I have come across (based on my own experience as of late 2025, mostly deep learning related).
Neural network modules: You would apparently be spoiled for choice. Each of the following options has its own way of specifying a set of neural network parameters with some variant of an init/forward API. I have personally found myself coming back to the vanilla style developed in this course for my own (small-scale) projects.
Patrick Kidger’s equinox hews closest to the JAX philosophy and would be my recommendation if people don’t want to roll their own modules. (equinox is more than just a neural network module library, and I am happy to recommend it generally.)
I have not had the opportunity to try Google DeepMind’s Flax NNX, but from what I have seen, it betrays the JAX philosophy in some important ways, and I am not excited to try it.
Flax NNX is a new project born in the GDM merger, deprecating two now-legacy libraries that you might still see around: Flax Linen (from Google proper) and Haiku (from DeepMind). I have tried Flax Linen and became frustrated by it. I considered Haiku but didn’t get into it.
Optimisation: We built our own Adam implementation in lecture 05 for pedagogical reasons, but in practice I would definitely use optax:
- Optax is a library for processing and applying gradients to models, and it has your back for optimisers, regularisers, learning rate schedules, and more.
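For a flavour of the optax API, here is my own minimal sketch (not course code) of the usual pattern: build an optimiser, initialise its state from your params PyTree, then apply updates each step. The params PyTree here is a dummy.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)}

schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=1e-4)
opt_state = optimizer.init(params)

def update(params, opt_state, grads):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```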
Checkpointing: For when you want to save models or training state so that you can pick back up where you left off or evaluate your models later. As far as I am aware, this space is monopolised by a library from Google:
- orbax.checkpoint. Whenever I have touched orbax I have become extremely frustrated by its inconsistent documentation and by how many imports and objects are required to do very simple things (I’m sure it works great at Google scale, but it seems over-engineered for my use cases).
Data loading: A bit of a blind spot for the JAX ecosystem. I normally work with synthetic data or RL so I can get away without it, but I have heard of these tools:
- Google’s grain: maybe useful for larger-scale systems, but it seems like it would make simple things painful?
- It’s possible to use data loaders from other frameworks, including TensorFlow or PyTorch, and convert the data to JAX arrays at load time.
- HuggingFace data loading has been recommended. Here are two related tutorials (1) (2).
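As a sketch of the convert-at-the-boundary approach with a toy PyTorch dataset (my own example, not from the course): keep batching and shuffling in PyTorch, and convert each batch to a JAX array as it arrives.

```python
import jax.numpy as jnp
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).reshape(-1, 1))
loader = DataLoader(dataset, batch_size=128, shuffle=True)

for (batch,) in loader:
    batch = jnp.asarray(batch.numpy())   # host NumPy -> JAX array (device copy)
    # ...feed `batch` into a jitted training step here...
```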
Environments: There are many cool environment libraries for RL in JAX. I haven’t used these myself, having instead built my own suite of procedurally generated grid-worlds for studying goal misgeneralisation. However, some libraries I have noted include:
- gymnax: Something like gymnasium.
- Jumanji: Combinatorial games and other puzzles.
- Octax: Arcade game emulator.
- Craftax: 2d Minecraft-like + roguelike.
- Pgx: for board games.
- Plus a few more in my link queue!
Whatever well-known pre-JAX RL environment you are looking for, these days it’s quite possible there is a JAX version, or there will be one soon.
Beyond deep learning: Awesome JAX is a GitHub repository maintaining a list of JAX libraries. It also lists learning resources, papers, blog posts, and more.
Learning more about JAX: If you want to learn more about JAX, here are some good resources to know about.
Your first port of call should always be the official JAX documentation, which offers beginner and advanced tutorials, advice on frequent issues, a detailed API reference, and more.
For another hands-on introduction to JAX, see the University of Amsterdam’s Deep Learning Course notebooks, which include tutorials covering basic JAX and various deep learning topics implemented in JAX.
Two short, informative tutorials I enjoyed were (1) an ACM SIGPLAN tutorial from Matthew Johnson (a JAX developer) and (2) a PyCon tutorial from Simon Pressler.
Bye!