Hi, JAX!
An introduction to JAX for deep learning researchers
This page is for the online edition of Hi, JAX! See also: 2024 edition.
This is the syllabus for a course I plan to record and release in December 2025. The course will be freely available to anyone on YouTube.
§Prerequisites
Programming:
- Prior experience programming in Python (e.g., collections, dataclasses, functools, type annotations).
- Prior experience programming in NumPy.
- The NumPy basics tutorial should be sufficient.
- Prior experience with einops.
- The einops basics tutorial is more than sufficient.
Machine learning:
- Basic vector calculus and optimisation (stochastic gradient descent).
- Basic deep learning architectures (MLP, CNN, Transformer).
- 3Blue1Brown Season 3 should be sufficient (skip chapter 4).
Helpful/optional (we will cover what is required):
- Basic reinforcement learning (Markov decision process formalism, policy gradients).
- Prior experience programming in, e.g., Rust (the concept of immutability).
- Prior experience programming in, e.g., Haskell (the concepts of pure functions, mapping, folding).
You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.
§Syllabus
The course covers the following topics. Each lecture includes a demonstration which I walk through line by line, along with a challenge where you take what you have learned and implement something yourself.
Overture:
Hi, JAX! How’s life? Quick overview of many JAX topics.
Demonstration: Port elementary cellular automaton simulator from NumPy to JAX.
Challenge: Accelerated Conway’s game of life simulator.
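To give a flavour of the kind of code involved, here is a minimal sketch of one game of life step using a 2D convolution (an illustration only, not the course's reference solution):
```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Count the eight neighbours of every cell with a 2D convolution.
KERNEL = jnp.array([[1., 1., 1.],
                    [1., 0., 1.],
                    [1., 1., 1.]])

@jax.jit
def life_step(grid):
    # grid: 2D float array of 0s and 1s
    neighbours = convolve2d(grid, KERNEL, mode="same")
    alive = (grid == 1.0) & ((neighbours == 2.0) | (neighbours == 3.0))
    born = (grid == 0.0) & (neighbours == 3.0)
    return (alive | born).astype(grid.dtype)
```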
Act I: Basics.
Hi, automatic differentiation! Call/init function API for model parameters, jax.grad transformation.
Demonstration: Train a teacher–student linear regression model with full-batch gradient descent.
Challenge: Train the teacher as well as the student.
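For a taste of the jax.grad transformation, here is a minimal teacher–student sketch (shapes, learning rate, and step count are illustrative):
```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
key_x, key_w = jax.random.split(key)
x = jax.random.normal(key_x, (128, 4))      # inputs
w_teacher = jax.random.normal(key_w, (4,))  # fixed teacher weights
y = x @ w_teacher                           # targets generated by the teacher

def loss(w_student):
    return jnp.mean((x @ w_student - y) ** 2)

w_student = jnp.zeros(4)
for _ in range(200):
    # jax.grad(loss) is a new function returning d(loss)/d(w_student)
    w_student = w_student - 0.1 * jax.grad(loss)(w_student)
```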
Hi, pseudorandom number generation! Immutable PRNG state management, jax.random.split.
Demonstration: Implement and train a classical perceptron with classical SGD.
Challenge: Implement and train a multi-layer perceptron.
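A minimal sketch of explicit PRNG state management with jax.random.split (key handling only; the perceptron itself is omitted):
```python
import jax

key = jax.random.key(42)

# Split once per consumer; never reuse a key for two different draws.
key, w_key, b_key = jax.random.split(key, 3)
w = jax.random.normal(w_key, (4,))
b = jax.random.normal(b_key, ())

def sample_batch(key, batch_size=32):
    # Return the carried-forward key alongside the fresh randomness.
    key, subkey = jax.random.split(key)
    batch = jax.random.normal(subkey, (batch_size, 4))
    return key, batch

key, batch = sample_batch(key)
```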
Hi, PyTrees! PyTrees, jax.tree.map.
Demonstration: Implement and train an MLP on MNIST with minibatch SGD.
Challenge: Add additional layers to the MLP.
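An illustrative sketch of treating MLP parameters as a PyTree and updating every leaf with jax.tree.map (shapes are placeholders):
```python
import jax
import jax.numpy as jnp

# MLP parameters as a nested dict of arrays: a PyTree.
params = {
    "layer1": {"w": jnp.zeros((784, 128)), "b": jnp.zeros(128)},
    "layer2": {"w": jnp.zeros((128, 10)), "b": jnp.zeros(10)},
}

def sgd_update(params, grads, lr=0.01):
    # Apply the same update rule to every leaf of the tree.
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)
```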
Hi, automatic vectorisation! Vectorisation with jax.vmap.
Demonstration: Implement and train a CNN on MNIST with minibatch SGD.
Challenge: Train an ensemble of CNNs.
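A minimal sketch of jax.vmap, here mapping a single-example prediction function over a batch (the function and shapes are illustrative):
```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # x is a single input vector of shape (4,)
    return jnp.tanh(x @ w)

w = jnp.ones((4, 3))
xs = jnp.ones((32, 4))  # a batch of 32 inputs

# Map over axis 0 of xs while broadcasting w to every example.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
ys = batched_predict(w, xs)  # shape (32, 3)
```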
Hi, stateful optimisation! Managing state during a training loop.
Demonstration: Implement Adam in vanilla JAX.
Challenge: Implement Adam with weight decay.
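A hedged sketch of what a hand-rolled Adam step over a PyTree of parameters might look like (hyperparameter defaults follow the original Adam paper; the course's version may differ):
```python
import jax
import jax.numpy as jnp

def adam_init(params):
    zeros = lambda p: jnp.zeros_like(p)
    return {"m": jax.tree.map(zeros, params),
            "v": jax.tree.map(zeros, params),
            "t": 0}

def adam_update(params, grads, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    t = state["t"] + 1
    # Exponential moving averages of the gradient and its square.
    m = jax.tree.map(lambda m, g: b1 * m + (1 - b1) * g, state["m"], grads)
    v = jax.tree.map(lambda v, g: b2 * v + (1 - b2) * g * g, state["v"], grads)
    def leaf_update(p, m, v):
        m_hat = m / (1 - b1 ** t)   # bias correction
        v_hat = v / (1 - b2 ** t)
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    new_params = jax.tree.map(leaf_update, params, m, v)
    return new_params, {"m": m, "v": v, "t": t}
```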
Act II: Acceleration.
Hi, just-in-time compilation! Compilation with jax.jit, tracing versus execution, side-effects, debugging tools.
Demonstration: JIT dojo, accelerate CNN training on MNIST.
Challenge: Implement and train a residual network.
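A small sketch of tracing versus execution under jax.jit (the print is there only to show when tracing happens):
```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(w, x, y):
    # Side effects run at trace time, not on every call.
    print("tracing!")
    return jnp.mean((x @ w - y) ** 2)

w, x, y = jnp.zeros(4), jnp.ones((8, 4)), jnp.ones(8)
loss(w, x, y)  # traces and compiles: prints "tracing!"
loss(w, x, y)  # reuses the compiled program: no print
```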
Hi, loop acceleration! Looping computations with jax.lax.scan.
Demonstration: Accelerate a whole forward pass and a whole training loop.
Challenge: Vectorise a hyperparameter sweep and replicate some error rate from Yann LeCun’s table.
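A minimal sketch of rolling a training loop into jax.lax.scan (the model, batches, and learning rate are illustrative):
```python
import jax
import jax.numpy as jnp

def train_step(w, batch):
    x, y = batch
    grad = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.1 * grad, None  # (new carry, per-step output)

# 100 pre-generated batches stacked along a leading axis.
xs = jnp.ones((100, 8, 4))
ys = jnp.ones((100, 8))
w0 = jnp.zeros(4)

# One compiled loop instead of 100 Python iterations.
w_final, _ = jax.lax.scan(train_step, w0, (xs, ys))
```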
Hi, static arguments! Compile errors due to non-static shapes, flagging static arguments, compilation cache.
Demonstration: Implement and train a byte-transformer on the Sherlock Holmes canon.
Challenge: Add dropout modules to the transformer.
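A small sketch of flagging a static argument so that a shape is known at trace time (names are illustrative):
```python
import functools
import jax
import jax.numpy as jnp

# hidden_size determines an output shape, so it must be static.
@functools.partial(jax.jit, static_argnames="hidden_size")
def init_hidden(batch, hidden_size):
    return jnp.zeros((batch.shape[0], hidden_size))

x = jnp.ones((32, 8))
h = init_hidden(x, hidden_size=64)   # compiles for hidden_size=64
h = init_hidden(x, hidden_size=128)  # new static value, so it compiles again
```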
Hi, branching computation! Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select.
Demonstration: Implement a simple grid-world maze environment.
Challenge: Determine solvability by implementing accelerated depth-first search or breadth-first search.
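A minimal sketch of conditional computation with jax.lax.cond (the "environment" here is just a scalar position):
```python
import jax
import jax.numpy as jnp

def step(position, hit_wall):
    # Both branches are traced; which one runs is decided at execution time.
    return jax.lax.cond(
        hit_wall,
        lambda p: p,      # blocked: stay put
        lambda p: p + 1,  # free: move forward
        position,
    )

positions = jnp.arange(4)
walls = jnp.array([True, False, True, False])
new_positions = jax.vmap(step)(positions, walls)  # cond lowers to select under vmap
```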
Hi, algorithms! Performance considerations for branching computation and parallelism.
Demonstration: Comparative implementation of Kruskal’s minimum spanning tree algorithm with different union–find data structures.
Challenge: Implement and accelerate the Floyd–Warshall algorithm to compute all-pairs shortest paths.
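For a flavour of expressing a classical algorithm in JAX, here is a sketch of Floyd–Warshall written as a scan over pivot nodes (the distance matrix is assumed dense, with jnp.inf for missing edges):
```python
import jax
import jax.numpy as jnp

def floyd_warshall(dist):
    # dist: (n, n) matrix of edge weights, jnp.inf where there is no edge.
    def relax(dist, k):
        # Allow paths that pass through pivot node k.
        through_k = dist[:, k, None] + dist[None, k, :]
        return jnp.minimum(dist, through_k), None
    n = dist.shape[0]
    dist, _ = jax.lax.scan(relax, dist, jnp.arange(n))
    return dist
```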
Act III: Deep learning ecosystem.
Hi, optimisation! Stateful optimisation with optax.
Demonstration: Replicate our Adam optimiser using optax.
Challenge: Replicate Adam with weight decay using optax.
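A minimal sketch of the optax update pattern (the model is a toy linear regression; names are illustrative):
```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(4)}
optimiser = optax.adam(learning_rate=1e-3)
opt_state = optimiser.init(params)

def loss(params, x, y):
    return jnp.mean((x @ params["w"] - y) ** 2)

x, y = jnp.ones((8, 4)), jnp.ones(8)
grads = jax.grad(loss)(params, x, y)
updates, opt_state = optimiser.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```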
Hi, data loading! Whirlwind tour of data loading libraries usable with JAX.
Demonstration: Replicate our MNIST data loading in various libraries.
Challenge: Find a new data set, choose a data loading method, and implement it.
Hi, modules! Tour of deep neural networks with Patrick Kidger’s equinox and Google DeepMind’s Flax NNX (plus legacy libraries Google’s Flax Linen and DeepMind’s Haiku).
Demonstration: Comparative replication of our vanilla-JAX CNN with equinox, flax.linen, dm-haiku, and flax.nnx, contrasting the various libraries’ approaches to modularity, initialisation, and state management.
Challenge: Choose one of equinox, flax.linen, dm-haiku, or flax.nnx and then replicate our vanilla-JAX transformer.
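As one example of the module libraries, here is a minimal Equinox sketch (an Equinox module is itself a PyTree of parameters; the sizes are placeholders):
```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.key(0)
mlp = eqx.nn.MLP(in_size=784, out_size=10, width_size=128, depth=2, key=key)

# The module is a PyTree, so it composes with vmap, grad, jit, etc.
batch = jnp.ones((32, 784))
logits = jax.vmap(mlp)(batch)  # shape (32, 10)
```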
Hi, checkpointing! Checkpoint trained models with orbax.checkpoint.
Demonstration: Save a trained model and then reload it for evaluation.
Challenge: Implement checkpointing during training and recover from a crashed training run, preserving random state (re-training with the same seed, without the crash, should produce an identical model).
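A hedged sketch of saving and restoring a PyTree of parameters with orbax.checkpoint (the exact API varies between Orbax versions; the path is illustrative and must not already exist):
```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)}

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/hijax_checkpoint", params)       # write the PyTree to disk
restored = checkpointer.restore("/tmp/hijax_checkpoint")  # read it back
```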
Finale:
Hi, deep reinforcement learning! Revision of previous fundamental topics, reverse scan.
Demonstration: Accelerated PPO with GAE, train a policy to solve a small maze.
Challenge: Solve larger mazes.
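A sketch of generalised advantage estimation written as a reverse scan (gamma and lambda take their usual default values, shapes assume a single trajectory, and episode termination handling is omitted):
```python
import jax
import jax.numpy as jnp

def gae(rewards, values, last_value, gamma=0.99, lam=0.95):
    # rewards, values: shape (T,); last_value: value of the state after the
    # final step, used to bootstrap the last TD error.
    next_values = jnp.concatenate([values[1:], last_value[None]])

    def step(carry, xs):
        reward, value, next_value = xs
        delta = reward + gamma * next_value - value
        advantage = delta + gamma * lam * carry
        return advantage, advantage

    _, advantages = jax.lax.scan(
        step,
        jnp.zeros_like(last_value),
        (rewards, values, next_values),
        reverse=True,
    )
    return advantages
```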
§Projects
The best way to consolidate your understanding is to implement your own non-trivial JAX project. As a suggestion, you could choose one of the following research replication projects on the topics of mechanistic interpretability and the science of deep learning.
Hi, mechanistic interpretability! Take some image model we have trained and then switch to optimising over the space of input images (rather than parameters) to produce a feature visualisation for some neuron along the lines of [1].
Hi, computational mechanics! Implement a simple HMM generative process and train a transformer on this data, then probe for belief state geometry in the residual stream, replicating figures 4BCD and 6A from [2].
Hi, sparse autoencoders! Take an existing image model we have trained and then train a sparse autoencoder (SAE) on it to produce a feature visualisation along the lines of [3].
Hi, singular learning theory! Implement local learning coefficient estimation and then replicate (a rescaled version of) figure 3 from [4].
Hi, science of deep learning! Implement a synthetic in-context linear regression data generator and in-context dMMSE and ridge regression algorithmic baselines, and then train a transformer to replicate a (low-resolution or small-architecture version of) figure 2 from [5].
Hi, goal misgeneralisation! Implement the “cheese in the corner” environment (including the distribution shift) and the domain randomisation RL algorithm, then train a policy to replicate the black line from figures 4 (left) and H.1 (bottom left) from [6].
For those with a need for more speed and scale, you may be interested in these advanced self-study projects instead:
Hi, profiling tools! Learn how to use JAX profiling tools, then find and remove a 2x memory or runtime bottleneck in the hijax GitHub repository.
Hi, distributed training! Learn how to use distributed data loading and automatic parallelism, then find a hardware and hyperparameter configuration for some learning task where parallelism leads to an order of magnitude speed-up.
Whatever project you select, please let me know if you complete it! If you like, I would be delighted to list your name and a link to your completed project here:
2024.07.27 Billy Snikkers completed a custom project (training a neural network to implement the update rule from Conway’s game of life) [repo].
2024.08.07 Rohan Hitchcock completed project 4 (Hi, singular learning theory!) [repo].
§Other JAX resources
If you want to learn more about JAX, here are some good resources to know about.
The official JAX documentation offers beginner and advanced tutorials, advice on frequent issues, a detailed API reference, and more.
The University of Amsterdam’s Deep Learning Course notebooks have tutorials covering basic JAX and various deep learning topics implemented in JAX.
An ACM SIGPLAN tutorial on JAX from Matthew Johnson (JAX developer).
A PyCon tutorial on JAX from Simon Pressler.
For more resources, see Awesome JAX, a GitHub repository maintaining a list of JAX libraries, learning resources, papers, blog posts, and more.