Hi, JAX!

A short introduction to JAX for deep learning researchers.

This course runs from July 11, 2024 to September 12, 2024.

Anyone can join the course. Simply join the metauni Discord server (invite link) and introduce yourself in the #hijax channel. (Note: unlike other metauni events, this course will not use Roblox.)

Weekly workshops: Each week features a 60-minute live code-along workshop implementing a short ML-related project that requires understanding a new JAX concept. See the syllabus for details.

The code will be available on GitHub to help you follow each week’s workshop. Participants are invited to share their own solutions by maintaining a public fork of the course repository.

Optional homework: Participants are invited to complete additional projects outside of the workshops, such as the weekly challenges listed in the syllabus.

These are suggestions only, and participants are encouraged to pursue their own alternative project ideas and share their work.

Community: Participants are welcome to use the #hijax channel in the metauni Discord server to discuss course topics, workshops, and challenges. This channel will also be used for scheduling updates.

You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.
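To check your environment before the first workshop, the following minimal sketch assumes JAX has already been installed with pip (for example, `pip install -U jax`, which gives a CPU-only build that needs no GPU). It lists the devices JAX can see and runs one small computation.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can use; a CPU-only install shows only CPU devices.
print(jax.devices())

# A tiny computation to confirm that jax.numpy works end to end.
x = jnp.arange(5.0)
print(jnp.dot(x, x))  # 0 + 1 + 4 + 9 + 16 = 30.0
```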

Syllabus

The course will cover the following nine topics. Workshops run on Thursdays at 2pm AEST on the listed dates.

  1. Hi, JAX! How’s life? (July 11) Intro to JAX, course overview, immutability and jax.numpy, randomness with jax.random. Demonstration: Elementary cellular automaton. Challenge: Conway’s game of life.

  2. Hi, automatic differentiation! (July 18) Call/init function API for models, jax.grad transformation. Demonstration: Classical perceptron trained with classical SGD, vanilla JAX. Challenge: Multi-layer perceptron, vanilla JAX.

  3. Hi, flax and optax! (July 25) Pytrees, flax modules, optax optimisers, train state. Demonstration: Train an MLP on MNIST with minibatch SGD. Challenge: Implement a drop-in replacement for optax SGD. Harder challenge: Implement a drop-in replacement for a stateful optimiser (Adam).

  4. Hi, automatic vectorisation! (August 1) Vectorisation with jax.vmap. Demonstration: Train a CNN on MNIST with minibatch SGD. Challenge: Implement your own convolution and pooling modules.

  5. Hi, just-in-time compilation! (August 8) Compilation with jax.jit, tracing vs. execution, side-effects. Demonstration: Train an accelerated CNN on ImageNet. Challenge: Implement and train a ResNet on CIFAR.

  6. Hi again, just-in-time compilation! (August 15) Compile errors due to non-static shapes, static arguments and recompilation. Demonstration: Train a byte-transformer on the Sherlock Holmes canon. Challenge: Add dropout modules to the transformer.

  7. Hi, loop acceleration! (August 22) Looping computations with jax.lax.scan. Demonstration: Accelerate a whole training loop. Challenge: Vectorise a hyperparameter sweep.

Note: No workshop on August 29.

  8. Hi, branching computation! (September 5) Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select. Demonstration: Simple gridworld maze environment. Challenge: Determine solvability by implementing an accelerated DFS/BFS. Alternative challenge: Tabular Q-learning.

  9. Hi, deep reinforcement learning! (September 12) PPO algorithm, reverse scan, revision of previous topics. Demonstration: Accelerated PPO, solve small mazes. Challenge: Solve larger mazes.
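For a preview, here is a compact sketch touching most of the core JAX APIs named in the syllabus: jax.random, jax.grad, jax.vmap, jax.jit with a static argument, jax.lax.scan (forward and reverse), jax.lax.select, and jax.lax.cond. It is not the workshop code. The linear model, the gridworld rule in `grid_step`, and all function names here are hypothetical examples chosen for brevity, and flax/optax (week 3) are left out so the sketch depends on JAX alone.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Week 1: explicit PRNG keys, and immutable arrays updated via .at[...].
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
w0 = jax.random.normal(subkey, (3,))
w_zeroed = w0.at[0].set(0.0)  # returns a new array; w0 is unchanged


def loss(w, x, y):
    """Squared error of a hypothetical linear model on a single example."""
    return (jnp.dot(w, x) - y) ** 2


# Week 2: jax.grad differentiates with respect to the first argument, w.
grad_loss = jax.grad(loss)

# Week 4: vmap the per-example loss over the batch axis of x and y (not w).
batched_loss = jax.vmap(loss, in_axes=(None, 0, 0))


# Week 5: jit-compile the mean loss with XLA.
@jax.jit
def mean_loss(w, xs, ys):
    return jnp.mean(batched_loss(w, xs, ys))


# Weeks 6-7: the whole training loop runs inside lax.scan. num_steps sets the
# scan length, so it must be static; each new value triggers a recompile.
@partial(jax.jit, static_argnames="num_steps")
def train(w, xs, ys, lr=0.1, num_steps=100):
    def step(w, _):
        value, g = jax.value_and_grad(mean_loss)(w, xs, ys)
        return w - lr * g, value  # (new carry, per-step loss)
    return jax.lax.scan(step, w, None, length=num_steps)


# Week 8: branching on traced values inside jit.
@jax.jit
def grid_step(pos, move, size):
    """Move on a hypothetical size x size grid; moves off the grid are ignored."""
    new = pos + move
    inside = jnp.all((new >= 0) & (new < size))
    # lax.select picks elementwise between two already-computed arrays.
    return jax.lax.select(jnp.broadcast_to(inside, new.shape), new, pos)


@jax.jit
def reward(pos, goal):
    # lax.cond picks which branch function to apply from a traced predicate.
    return jax.lax.cond(jnp.all(pos == goal), lambda _: 1.0, lambda _: 0.0, pos)


# Week 9: a reverse scan accumulates discounted returns from the last step back.
def discounted_returns(rewards, gamma=0.99):
    def step(g, r):
        g = r + gamma * g
        return g, g
    init = jnp.zeros((), dtype=rewards.dtype)
    # With reverse=True the outputs still line up with the inputs' order.
    _, returns = jax.lax.scan(step, init, rewards, reverse=True)
    return returns
```

As a quick usage example, `discounted_returns(jnp.array([0.0, 0.0, 1.0]), gamma=0.5)` returns `[0.25, 0.5, 1.0]`, and `train(w0, xs, ys, num_steps=200)` returns the final weights together with the loss recorded at every step.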

Bounty board

The first person to complete each of the following tasks will receive a prize of one hexadecimal Australian dollar (2.56 AUD) and name recognition on this webpage.

  1. Hi, profiling tools! Learn how to use the JAX profiling tools, then use them to find and remove a 2x performance bottleneck in the hijax GitHub repository.

  2. Hi, automatic parallelisation! Find and demonstrate a hardware/hyperparameter configuration where using jax.pmap yields a speedup of at least 10x.

  3. Hi, developmental interpretability! Accelerate and vectorise local learning coefficient (LLC) estimation, and replicate a figure of your choice from [1] or [2] (authors of these works are ineligible).

  4. Hi, mechanistic interpretability! IDK, train a sparse autoencoder (SAE) or something. Suggestions welcome.

Other JAX resources

If you want to learn more about JAX, here are some good resources to know about.

  1. The official JAX documentation offers beginner and advanced tutorials, advice on frequent issues, a detailed API reference, and more.

  2. The University of Amsterdam’s Deep Learning Course notebooks have tutorials covering basic JAX and various deep learning topics implemented in JAX.

  3. Awesome JAX is a GitHub repository maintaining a list of JAX libraries, learning resources, papers, blog posts, and more.