Hi, JAX!
A short introduction to JAX for deep learning researchers.
This course runs from July 11, 2024 to September 12, 2024.
Anyone can join the course. Simply join the metauni Discord server (invite link) and introduce yourself in the #hijax channel. (Note: unlike other metauni events, this course will not use Roblox.)
Format
Weekly workshops: A 60-minute live code-along workshop on implementing a short ML-related project that requires understanding a new JAX concept. See the syllabus for details.
- Time: Thursdays 2pm AEST (= Wednesdays 9pm PDT, Thursdays 6am CEST).
- Place: metauni Discord server (invite link), General voice channel.
The code will be available on GitHub to help you follow each week’s workshop. Participants are invited to share their own solutions by maintaining a public fork of the course repository.
Optional homework: Participants are invited to complete additional projects outside of the workshops, including:
- Weekly challenges: suggested self-directed tasks that reinforce the concepts from the workshop. See the syllabus for details.
- Course bounties: Five additional bounty projects invite participants to learn new JAX concepts. See the bounty board for details.
These are suggestions only, and participants are encouraged to pursue their own alternative project ideas and share their work.
Community: Participants are welcome to use the #hijax channel in the metauni Discord server to discuss course topics, workshops, and challenges. This channel will also be used for scheduling updates.
Prerequisites
Programming:
- Prior experience programming in Python.
- Prior experience programming in NumPy.
  - The NumPy basics tutorial should be sufficient.
- Helpful/optional: Prior experience with einops.
  - The Einops basics tutorial is more than sufficient.
- Helpful/optional: Prior experience programming in Rust (immutability).
- Helpful/optional: Prior experience programming in Haskell (side-effect-free state management).
Theory:
- Basic vector calculus, optimisation, and deep learning architectures (stochastic gradient descent; basic architectures such as MLP, CNN, Transformer).
  - 3Blue1Brown Season 3 (skip chapter 4) should be sufficient.
- Helpful/optional: Reinforcement learning (policy gradients).
You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.
Syllabus
The course will cover the following nine topics. Workshops run on Thursdays at 2pm AEST on the listed dates.
1. Hi, JAX! How’s life? (July 11) Intro to JAX, course overview, immutability and jax.numpy, randomness with jax.random. Demonstration: Elementary cellular automaton. Challenge: Conway’s game of life.
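If these two ideas are new to you, here is a minimal illustrative sketch (not the workshop code) of functional array updates and explicit randomness:

```python
import jax
import jax.numpy as jnp

# JAX arrays are immutable: in place of `x[0] = 1.0`, a functional
# update returns a brand-new array via the `.at` syntax.
x = jnp.zeros(3)
y = x.at[0].set(1.0)   # x is unchanged; y == [1., 0., 0.]

# Randomness is explicit: a PRNG key is threaded through the program
# and split whenever an independent random draw is needed.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(3,))
```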
2. Hi, automatic differentiation! (July 18) Call/init function API for models, jax.grad transformation. Demonstration: Classical perceptron trained with classical SGD, vanilla JAX. Challenge: Multi-layer perceptron, vanilla JAX.
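As a taste of the transformation (an illustrative sketch with a toy linear model, not the workshop’s perceptron):

```python
import jax
import jax.numpy as jnp

def loss(w, xs, ys):
    # Mean squared error of a linear model, written as a pure function of w.
    return jnp.mean((xs @ w - ys) ** 2)

# jax.grad transforms loss into a function that returns dloss/dw.
grad_loss = jax.grad(loss)

w = jnp.zeros(2)
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ys = jnp.array([1.0, 2.0])
w = w - 0.1 * grad_loss(w, xs, ys)   # one classical SGD step
```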
3. Hi, pytrees! (July 25) Pytrees, jax.tree.map, equinox modules. Demonstration: Train an MLP on MNIST with minibatch SGD. Challenge: Manually register the MLP modules as pytrees (obviating the equinox dependency).
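A minimal sketch of the leaf-wise update pattern, using a hypothetical two-leaf parameter dict (illustrative only):

```python
import jax
import jax.numpy as jnp

# A pytree of parameters: any nest of dicts/lists/tuples with array leaves.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((2, 2), 0.5), "b": jnp.full(2, 0.5)}

# jax.tree.map applies a function leaf-wise across matching pytrees,
# so a single line expresses an SGD update for every parameter at once.
params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
```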
4. Hi, deep learning ecosystem! Hi, automatic vectorisation! (August 1) Modules with equinox.nn, vectorisation with jax.vmap, stateful optimisation with optax. Demonstration: Train a CNN on MNIST with minibatch SGD and Adam. Challenge: Implement a drop-in replacement for optax.adam.
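An illustrative sketch of both tools (assuming optax is installed; the toy model and placeholder gradients stand in for the workshop’s CNN):

```python
import jax
import jax.numpy as jnp
import optax

def predict(params, x):
    # A single-example linear model; vmap lifts it to a whole batch.
    return params["w"] @ x + params["b"]

batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}
ys = batched_predict(params, jnp.ones((8, 3)))   # shape (8, 2)

# Stateful optimisation with optax: the optimiser state is explicit.
optimiser = optax.adam(learning_rate=1e-3)
opt_state = optimiser.init(params)
grads = jax.tree.map(jnp.ones_like, params)      # placeholder gradients
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```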
5. Hi, just-in-time compilation! (August 8) Compilation with jax.jit, tracing vs. execution, side-effects. Demonstration: JIT dojo (part 1), and train an accelerated CNN on MNIST. Challenge: Implement and train more historic network architectures.
6. Hi again, just-in-time compilation! (August 15) Compile errors due to non-static shapes, static arguments and recompilation. Demonstration: Train a byte-transformer on the Sherlock Holmes canon. Challenge: Add dropout modules to the transformer.
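A small taste of tracing behaviour and static arguments ahead of these two workshops (illustrative only):

```python
from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # The print is a Python side-effect, so it fires during tracing
    # (once per input shape/dtype), not on every call.
    print("tracing with", x)
    return jnp.sin(x) * 2.0

f(jnp.arange(3.0))   # traces, compiles, then executes
f(jnp.arange(3.0))   # same shape: reuses the compiled code, no print

# Values that determine shapes must be marked static, otherwise jit
# fails because shapes are unknown at trace time.
@partial(jax.jit, static_argnums=0)
def make_zeros(n):
    return jnp.zeros(n)   # recompiles for each distinct n
```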
7. Hi, loop acceleration! (August 22) Looping computations with jax.lax.scan. Demonstration: Accelerate a whole training loop. Challenge: Vectorise a hyperparameter sweep and replicate some error rate from Yann LeCun’s table.
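A minimal sketch of the scan pattern, on a toy cumulative sum rather than a full training loop:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # One iteration: carry is the running state, x the per-step input.
    new_carry = carry + x
    return new_carry, new_carry   # (next state, per-step output)

# The loop itself lives inside the compiled program, so the whole thing
# can be jitted as a single operation instead of a Python-level loop.
total, partials = jax.lax.scan(step, 0.0, jnp.arange(5.0))
# total == 10.0; partials == [0., 1., 3., 6., 10.]
```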
Note: No workshop on August 29.
8. Hi, branching computation! (September 5) Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select. Demonstration: Simple gridworld maze environment. Challenge: Determine solvability by implementing an accelerated DFS/BFS. Alternative challenge: Tabular Q-learning.
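An illustrative sketch of both branching primitives (toy branches, not the maze environment):

```python
import jax
import jax.numpy as jnp

def branch(pred, x):
    # Unlike a Python `if`, both branches are traced and compiled;
    # lax.cond picks one at run time without recompiling.
    return jax.lax.cond(pred, lambda v: v * 2.0, lambda v: v - 1.0, x)

branch(True, jnp.float32(3.0))    # 6.0
branch(False, jnp.float32(3.0))   # 2.0

# jax.lax.select is the element-wise analogue: both operands are
# computed, and a boolean mask chooses between them per element.
mask = jnp.array([True, False, True])
out = jax.lax.select(mask, jnp.ones(3), jnp.zeros(3))   # [1., 0., 1.]
```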
9. Hi, deep reinforcement learning! (September 12) PPO algorithm, reverse scan, revision of previous topics. Demonstration: Accelerated PPO, solve small mazes. Challenge: Solve larger mazes.
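As a taste of what a reverse scan is for, here is a sketch of a discounted-return computation of the kind PPO needs (illustrative; not the workshop’s PPO code):

```python
import jax
import jax.numpy as jnp

def discounted_returns(rewards, gamma=0.99):
    # Returns satisfy ret[t] = rewards[t] + gamma * ret[t + 1], so they
    # accumulate from the final timestep backwards: a scan with reverse=True.
    def step(carry, r):
        ret = r + gamma * carry
        return ret, ret
    _, returns = jax.lax.scan(step, 0.0, rewards, reverse=True)
    return returns

discounted_returns(jnp.array([1.0, 0.0, 1.0]))   # [1.9801, 0.99, 1.0]
```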
Bounty board
The first person to complete each of the following tasks will receive a prize of one hexadecimal Australian dollar (2.56 AUD) and name recognition on this webpage.
1. Hi, profiling tools! Learn how to use JAX profiling tools and find and remove a 2x performance bottleneck in the hijax GitHub repository.
2. Hi, automatic parallelisation! Find and demonstrate a hardware/hyperparameter configuration where using jax.pmap yields a speedup of at least 10x.
3. Hi, developmental interpretability! Accelerate and vectorise LLC estimation, and replicate your choice of figure from [1] or [2] (authors from these works are ineligible).
   - Bounty 3 claimed by Rohan Hitchcock on 2024.08.07. [Rohan’s fork]
4. Hi, mechanistic interpretability! Take an existing image model we have trained and then switch to optimising over the space of inputs (rather than parameters) to produce a feature visualisation along the lines of [3].
5. Hi again, mechanistic interpretability! Take an existing image model we have trained and then train a sparse autoencoder (SAE) on it to produce a feature visualisation along the lines of [4].
Other JAX resources
If you want to learn more about JAX, here are some good resources to know about.
The official JAX documentation offers beginner and advanced tutorials, advice on frequent issues, a detailed API reference, and more.
The University of Amsterdam’s Deep Learning Course notebooks offer tutorials covering basic JAX and various deep learning topics implemented in JAX.
Awesome JAX is a GitHub repository maintaining a list of JAX libraries, learning resources, papers, blog posts, and more.