

Hi, JAX!

An introduction to vanilla JAX for deep learning

This page is for the online edition of Hi, JAX! See also: 2024 edition.

This is the syllabus for a course I am currently developing and will be releasing shortly. The course will be released for free on my YouTube channel. The code is in development on GitHub.

Development progress:

Estimated release date: January 2026

§Why learn JAX?

I’m a deep learning scientist. I try to understand the natural principles underpinning the structure and function of learned neural networks. In the course of my research, I have to train, test, and dissect many neural networks, large and small. Neural networks are essentially just arrays of floating point parameters, so the best way to interact with them is using an array programming library. That is where JAX comes in.

JAX is a Python library for writing and transforming array programs. The JAX API for writing array programs is similar to that of NumPy, but the JAX API for transforming array programs is utterly unique, effortlessly powerful, and delightfully elegant.
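
To make “writing and transforming array programs” concrete, here is a minimal sketch (a toy example of mine, not from the course). The loss function below is written with the NumPy-like jax.numpy API; jax.grad and jax.jit then transform it into new functions.

    import jax
    import jax.numpy as jnp

    # An ordinary array program, written with the NumPy-like API.
    def loss(w, x, y):
        return jnp.mean((x @ w - y) ** 2)

    # Transformations build new functions out of old ones.
    grad_loss = jax.grad(loss)        # gradient of the loss with respect to w
    fast_grad = jax.jit(grad_loss)    # the same function, compiled with XLA

    w = jnp.zeros(3)
    x = jnp.ones((8, 3))
    y = jnp.ones(8)
    print(fast_grad(w, x, y))         # an array of shape (3,)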

Since switching from using PyTorch to using JAX for my research, I have personally noticed the following benefits.

  1. JAX is fast. The main selling point of JAX is its integration with the XLA array program compiler. XLA gives a speed boost even for small-scale JAX programs, which can run several times faster than equivalent uncompiled array programs on the same hardware.

  2. JAX is portable. JAX isn’t just useful on GPUs. It’s equally happy to optimise code I want to run on my M2 Air. When I do need to reach for more compute, I don’t have to cough up for a GPU node; I can take advantage of Google’s TPU Research Cloud programme and run my code on a free TPU cluster.

  3. JAX is elegant. Deep learning researchers describe their insights with the powerful abstractions forged over hundreds of years by mathematicians. The JAX library’s design celebrates mathematical abstractions, rather than over-fitting to some subset of methods that happen to be in vogue.

  4. JAX is good for your code. It can be challenging to comply with the various restrictions JAX enforces (immutability, pure functions, static shapes). But even if they weren’t justified by the powerful optimisations they enable, these restrictions are all good ways to make code more explicit, readable, and, ultimately, less wrong.

JAX isn’t perfect. But it’s far and away the closest to my conception of the ideal array programming library. Plus, it’s under active development by a team that has shown they can get the big picture so right, so the future looks bright.

All else equal, I’d choose JAX for my projects, every time. Of course, all else is not always equal. Here are some reasons people might not want to switch to JAX.

  1. JAX is less popular. It hasn’t been around as long as PyTorch, which has a larger and better-resourced ecosystem. If you switch to JAX, you might have to write more code yourself when the paper you want to replicate or the architecture you want to try only provides a PyTorch codebase. Your colleagues might refuse to port their existing code to work with JAX.

    But you are not alone. While still small relative to the PyTorch ecosystem, the JAX ecosystem is decent in absolute terms today. There are many awesome JAX tools out there and I see new projects posted every week. Maybe I will see your contribution one day?

  2. JAX is hard to learn. Using JAX requires thinking about array programs in new ways. You will encounter challenges you have no idea how to solve, and this will make you feel like you’re a beginner programmer again. This is difficult, and there aren’t many learning resources out there that convey the “JAX mindset.”

    That changes, now. I went through this journey myself. It took me months to crack the “JAX mindset.” But I did eventually crack it, and now I have made this course—including programming demonstrations drawn from real challenges I came up against when learning JAX.

I can feel which way the wind is blowing. Can you?

§Prerequisites

Programming:

Machine learning:

Helpful/optional (we will cover what is required):

You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.

§Syllabus

The course covers the following topics. Each lecture includes a demonstration which I walk through line by line, along with a challenge where you take what you have learned and implement something yourself.

§§Overture

In which we first meet JAX and get a taste of how it differs from NumPy.

  1. Hi, JAX! How’s life? Quick overview of JAX features, jax.numpy library.

    • Demonstration: Port elementary cellular automaton simulator from NumPy to JAX.

    • Challenge: Accelerated Conway’s game of life simulator.
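
As a taste of this territory, here is a hedged sketch of one update step of an elementary cellular automaton written with jax.numpy (my own toy version, not the course’s demonstration code):

    import jax
    import jax.numpy as jnp

    def eca_step(cells, rule_bits):
        # cells: a 0/1 array; rule_bits: the 8-entry lookup table of the rule.
        left = jnp.roll(cells, 1)
        right = jnp.roll(cells, -1)
        pattern = 4 * left + 2 * cells + right   # neighbourhood code, 0..7
        return rule_bits[pattern]

    rule = 110
    rule_bits = jnp.array([(rule >> i) & 1 for i in range(8)])
    cells = jnp.zeros(64, dtype=jnp.int32).at[32].set(1)   # one live cell
    cells = jax.jit(eca_step)(cells, rule_bits)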

§§Act I: Basics

In which we learn the elementary components of JAX programs while implementing and training increasingly complex neural networks.

  1. Hi, automatic differentiation! Functional model API, jax.grad transformation.

    • Demonstration: Train a teacher–student linear regression model with full-batch gradient descent.

    • Challenge: Train the teacher as well as the student.

  2. Hi, pseudo-random number generation! Immutable PRNG state management, jax.random library.

    • Demonstration: Implement and train a classical perceptron with classical SGD.

    • Challenge: Implement and train a multi-layer perceptron.

  3. Hi, PyTrees! PyTrees, jax.tree.map.

    • Demonstration: Implement and train an MLP with minibatch SGD.

    • Challenge: Add additional layers to the MLP.

  4. Hi, automatic vectorisation! Vectorisation with jax.vmap.

    • Demonstration: Implement and train a CNN on MNIST with minibatch SGD.

    • Challenge: Train an ensemble of CNNs.

  5. Hi, stateful optimisation! Managing state during a training loop.

    • Demonstration: Implement Adam in vanilla JAX.

    • Challenge: Implement Adam with weight decay.
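
As a taste of how these pieces fit together, here is a hedged end-to-end sketch combining explicit PRNG keys, parameters as a PyTree, jax.grad, a jax.tree.map SGD update, and jax.vmap over a batch. The names and architecture are my own toy choices, not the course’s demonstration code.

    import jax
    import jax.numpy as jnp

    def init_mlp(key, sizes):
        params = []
        for d_in, d_out in zip(sizes[:-1], sizes[1:]):
            key, w_key = jax.random.split(key)
            params.append({
                'w': jax.random.normal(w_key, (d_in, d_out)) / jnp.sqrt(d_in),
                'b': jnp.zeros(d_out),
            })
        return params                          # a PyTree: a list of dicts of arrays

    def forward(params, x):                    # x: one example of shape (d_in,)
        for layer in params[:-1]:
            x = jnp.tanh(x @ layer['w'] + layer['b'])
        return x @ params[-1]['w'] + params[-1]['b']

    def loss(params, xs, ys):
        preds = jax.vmap(forward, in_axes=(None, 0))(params, xs)   # batch the forward pass
        return jnp.mean((preds - ys) ** 2)

    def sgd_step(params, xs, ys, lr=0.01):
        grads = jax.grad(loss)(params, xs, ys)                     # grads mirror the params PyTree
        return jax.tree.map(lambda p, g: p - lr * g, params, grads)

    key = jax.random.key(42)
    key, init_key, data_key = jax.random.split(key, 3)
    params = init_mlp(init_key, sizes=(4, 32, 1))
    xs = jax.random.normal(data_key, (128, 4))
    ys = jnp.sum(xs, axis=1, keepdims=True)                        # a toy regression target
    for _ in range(100):
        params = sgd_step(params, xs, ys)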

§§Act II: Acceleration

In which we explore various aspects of just-in-time compilation and the kinds of tricks we need to use to prepare our computational graphs for the XLA compiler.

  1. Hi, just-in-time compilation! Compilation with jax.jit, tracing versus execution, side-effects, debugging tools.

    • Demonstration: JIT dojo, accelerate CNN training on MNIST.

    • Challenge: Implement and train a residual network.

  2. Hi, loop acceleration! Looping computations with jax.lax.scan.

    • Demonstration: Accelerate a whole forward pass and a whole training loop.

    • Challenge: Vectorise a hyperparameter sweep and tune the learning rate for one of our previous models.

  3. Hi, static arguments! Compile errors due to non-static shapes, flagging static arguments.

    • Demonstration: Implement and train a byte-transformer on the Sherlock Holmes canon.

    • Challenge: Add dropout modules to the transformer.

  4. Hi, branching computation! Stateful environment API, conditional computation with jax.lax.select, jax.numpy.where, and expression-level branching.

    • Demonstration: Implement a simple grid-world maze environment.

    • Alternative challenge: Add a locked door and a key to the grid-world environment.

  5. Hi, algorithms! Performance considerations for branching computation and parallelism, jax.lax.while_loop.

    • Demonstration: Comparative implementation of Kruskal’s minimum spanning tree algorithm with different union–find data structures.

    • Challenge: Determine solvability of a maze by implementing and accelerating depth-first search. Solve the maze by implementing and accelerating breadth-first search. For a harder challenge, implement and accelerate the Floyd–Warshall algorithm to compute all-pairs shortest paths.
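
Here is a hedged sketch of the Act II machinery on toy examples of my own (not the course’s demonstrations): jax.lax.scan rolls a loop into one operation, static arguments let shapes depend on a Python value, and jnp.where with jax.lax.while_loop handle branching and looping on traced values.

    from functools import partial
    import jax
    import jax.numpy as jnp

    # jax.lax.scan: a loop with a carried value, compiled as a single operation.
    def cumsum_scan(xs):
        def body(carry, x):
            carry = carry + x
            return carry, carry               # (new carry, per-step output)
        _, outputs = jax.lax.scan(body, jnp.float32(0.0), xs)
        return outputs

    # Static arguments: jnp.arange needs a concrete length, so n is flagged as
    # static and baked into the compiled program.
    @partial(jax.jit, static_argnums=0)
    def first_n_squares(n):
        return jnp.arange(n) ** 2

    # Branching on traced values: jnp.where selects elementwise, and
    # jax.lax.while_loop repeats the body until the condition fails.
    def count_collatz_steps(x):
        def cond(state):
            value, steps = state
            return value != 1
        def body(state):
            value, steps = state
            value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
            return value, steps + 1
        _, steps = jax.lax.while_loop(cond, body, (x, 0))
        return steps

    print(cumsum_scan(jnp.arange(5.0)))                 # [ 0.  1.  3.  6. 10.]
    print(first_n_squares(4))                           # [0 1 4 9]
    print(jax.jit(count_collatz_steps)(jnp.int32(6)))   # 8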

§§Finale

In which we bring together everything we have learned to accelerate an end-to-end deep reinforcement learning environment simulation and training loop, one of the most effective uses of JAX for deep learning.

  1. Hi, deep reinforcement learning! Revision of previous fundamental topics, reverse scan.

    • Demonstration: Accelerated PPO with GAE, train a policy to solve a small maze.

    • Challenge: Solve larger mazes.
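
The “reverse scan” trick is worth a preview: generalised advantage estimation runs backwards over a trajectory, which maps neatly onto jax.lax.scan with reverse=True. Here is a hedged sketch with my own naming and formulation, not the course’s PPO code:

    import jax
    import jax.numpy as jnp

    def gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
        def body(advantage, step):
            reward, value, next_value, done = step
            delta = reward + gamma * next_value * (1.0 - done) - value
            advantage = delta + gamma * lam * (1.0 - done) * advantage
            return advantage, advantage
        next_values = jnp.append(values[1:], last_value)
        _, advantages = jax.lax.scan(
            body,
            jnp.float32(0.0),
            (rewards, values, next_values, dones),
            reverse=True,                     # walk the trajectory back to front
        )
        return advantages

    T = 8
    advantages = gae(jnp.ones(T), jnp.zeros(T), 0.0, jnp.zeros(T))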

§Projects

The best way to consolidate your understanding is to implement your own non-trivial JAX project. As a suggestion, you could choose one of the following research replication projects on the topics of mechanistic interpretability and the science of deep learning.

  1. Hi, mechanistic interpretability! Take some image model we have trained and then switch to optimising over the space of input images (rather than parameters) to produce a feature visualisation for some neuron along the lines of [1].

  2. Hi, computational mechanics! Implement a simple HMM generative process and train a transformer on this data, then probe for belief state geometry in the residual stream, replicating figures 4BCD and 6A from [2].

  3. Hi, sparse autoencoders! Take an existing image model we have trained and then train a sparse autoencoder (SAE) on it to produce a feature visualisation along the lines of [3].

  4. Hi, singular learning theory! Implement local learning coefficient estimation and then replicate (a rescaled version of) figure 3 from [4].

  5. Hi, science of deep learning! Implement a synthetic in-context linear regression data generator and in-context dMMSE and ridge regression algorithmic baselines, and then train a transformer to replicate a (low-resolution or small-architecture version of) figure 2 from [5].

  6. Hi, goal misgeneralisation! Implement the “cheese in the corner” environment including the distribution shift, and the domain randomisation RL algorithm, and train a policy to replicate the black line from figures 4(left) and H.1 (bottom left) from [6].
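
Several of these projects boil down to pointing jax.grad at something other than model parameters. For example, the core loop of project 1 might look roughly like the following sketch, where apply (a forward function returning a vector of activations) and params are hypothetical stand-ins for a model trained earlier in the course:

    import jax
    import jax.numpy as jnp

    def visualise_neuron(apply, params, neuron, key, steps=256, lr=0.05):
        # Objective: one neuron's activation as a function of the input image.
        def objective(image):
            return apply(params, image)[neuron]
        image = 0.1 * jax.random.normal(key, (28, 28))       # start from noise
        for _ in range(steps):
            image = image + lr * jax.grad(objective)(image)  # gradient ascent on the input
        return image

    # Toy usage with a stand-in "model" (a random linear map), just to make the
    # sketch runnable; in the project, apply/params come from a trained CNN.
    key = jax.random.key(0)
    params = jax.random.normal(key, (10, 28 * 28))
    apply = lambda p, image: p @ image.reshape(-1)
    print(visualise_neuron(apply, params, neuron=3, key=key).shape)   # (28, 28)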

If you have a need for more speed and scale, you may be interested in these advanced self-study projects instead:

  1. Hi, profiling tools! Learn how to use JAX profiling tools, then find and remove a 2x memory or runtime bottleneck in the hijax GitHub repository.

  2. Hi, distributed training! Learn how to use distributed data loading and automatic parallelism, then find a hardware and hyperparameter configuration for some learning task where parallelism leads to an order of magnitude speed-up.
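
As a starting point for the profiling project, the basic capture pattern is small. Here is a hedged sketch (the matrix multiply is just a stand-in workload, and the trace directory is an arbitrary choice of mine):

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.key(0), (4096, 4096))
    f = jax.jit(lambda a: (a @ a).sum())

    # Capture a trace you can open in TensorBoard's profiler or Perfetto.
    with jax.profiler.trace("/tmp/jax-trace"):
        f(x).block_until_ready()   # block so the work finishes inside the trace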

Whatever project you select, please let me know if you complete it! And if you like, I would be delighted to list your name and a link to your completed project here:

§Beyond the fundamentals

Mastering the fundamentals of JAX is valuable because vanilla JAX is a powerful tool for accelerated Python programming.

That’s great, but there’s no reason you have to stick to vanilla JAX. There is a growing ecosystem of libraries that build on top of JAX. Here is an incomplete and opinionated overview of some of the libraries I have come across (based on my own experience as of late 2025, mostly deep learning related).

Neural network modules: You are apparently spoiled for choice. Each of the following options has its own way of specifying a set of neural network parameters with some variant of an init/forward API. I have personally found myself coming back to the vanilla style developed in this course for my own (small-scale) projects.

Optimisation: We built our own Adam implementation in lecture 05 for pedagogical reasons, but in practice I would definitely use optax:
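
For reference, the usual optax pattern (init, update, apply_updates) looks roughly like the following hedged sketch, shown on a throwaway linear model rather than anything from the course:

    import jax
    import jax.numpy as jnp
    import optax

    params = {'w': jnp.zeros(3), 'b': jnp.zeros(())}

    def loss(params, xs, ys):
        preds = xs @ params['w'] + params['b']
        return jnp.mean((preds - ys) ** 2)

    optimiser = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
    opt_state = optimiser.init(params)

    @jax.jit
    def train_step(params, opt_state, xs, ys):
        grads = jax.grad(loss)(params, xs, ys)
        updates, opt_state = optimiser.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    xs = jnp.ones((16, 3))
    ys = jnp.ones(16)
    params, opt_state = train_step(params, opt_state, xs, ys)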

Checkpointing: For when you want to save models or training state so that you can pick back up where you left off or evaluate your models later. As far as I am aware this space is monopolised by a library from Google:

Data loading: A bit of a blind spot for the JAX ecosystem. I normally work with synthetic data or RL so I can get away without it, but I have heard of these tools:

Here are two related tutorials (1) (2).

Environments: There are many cool environment libraries for RL in JAX. I haven’t used these myself, instead building my own suite of procedurally generated grid-worlds for studying goal misgeneralisation. However, some libraries I have noted include:

Whatever well-known pre-JAX RL environment you are looking for, these days it’s quite possible there is already a JAX version, or there will be one soon.

Beyond deep learning: Awesome JAX is a GitHub repository maintaining a list of JAX libraries. It also lists learning resources, papers, blog posts, and more.

Learning more about JAX: If you want to learn more about JAX, here are some good resources to know about.

Bye!