Hi, JAX!
A short introduction to JAX for deep learning researchers.
The pilot of this course will run from June to August, 2024. To express interest, either email me or join the metauni Discord server (invite link; see the #hijax channel), where the classes will probably take place.
Format
Weekly:
- 60–90 minute live code-along demonstration (via Discord): implementing a short ML-related project requiring understanding some new JAX concept.
- Optional homework challenges: suggestions for self-directed tasks that reinforce the concepts from the week’s demonstration. Participants are encouraged to pursue their own interests and share their work.
- Q&A (via Discord): opportunity to discuss course topics, demonstrations, and challenges.
Source code (including solutions) for the weekly demonstrations will be available on GitHub. Participants are invited to share their own solutions and projects by maintaining a public fork of the course repository.
Prerequisites
Programming:
- Prior experience programming in Python and NumPy.
- Helpful/optional: Prior experience with einops (basics).
- Helpful/optional: Prior experience programming in Rust (immutability).
- Helpful/optional: Prior experience programming in Haskell (side-effect-free state management).
Theory:
- Basic linear algebra (matrices).
- Basic vector calculus (gradient vectors).
- Basic optimisation (stochastic gradient descent).
- Basic deep learning (MLPs, CNNs, Transformers).
- Helpful/optional: Singular learning theory (LLC estimation [1, 2]).
- Helpful/optional: Reinforcement learning (Q-learning and policy gradients).
Syllabus
The syllabus is still under construction. Tentative topics are as follows.
Overview:
- Hi, JAX! Intro to JAX, course overview, immutability and jax.numpy, randomness with jax.random. Demonstration: Elementary cellular automaton. Challenge: Conway's game of life.
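As a taste of the first session's concepts, here is a minimal sketch (illustrative only, not the demonstration code) of the two ideas named above: immutable updates with jax.numpy and explicit randomness with jax.random.

```python
import jax
import jax.numpy as jnp

# JAX arrays are immutable: there is no in-place assignment.
# Instead, .at[...].set(...) returns a NEW array with the update applied.
x = jnp.zeros(3)
y = x.at[1].set(5.0)   # x itself is unchanged

# Randomness is explicit: a PRNG key is split into fresh subkeys,
# and no key is ever reused.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=(2,))
```

Note that after the update, x still holds zeros while y holds the new value; this functional style is what makes JAX transformations composable.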
Deep learning in JAX:
- Hi, automatic differentiation! Call/init function API for models, jax.grad transformation. Demonstration: Classical perceptron trained with classical SGD, vanilla JAX. Challenge: Multi-layer perceptron, vanilla JAX.
- Hi, flax and optax! Pytrees, flax modules, optax optimisers, train state. Demonstration: Train an MLP on MNIST with minibatch SGD. Challenge: Implement a drop-in replacement for optax SGD. Harder challenge: Implement a drop-in replacement for a stateful optimiser (Adam).
- Hi, automatic vectorisation! Vectorisation with jax.vmap. Demonstration: Train a CNN on MNIST with minibatch SGD. Challenge: Implement your own convolution and pooling modules.
- Hi, just-in-time compilation! Compilation with jax.jit, tracing vs. execution, side-effects. Demonstration: Train an accelerated CNN on ImageNet. Challenge: Implement and train a ResNet on CIFAR.
- Hi again, just-in-time compilation! Compile errors due to non-static shapes, static arguments and recompilation. Demonstration: Train a byte-transformer on the Sherlock Holmes canon. Challenge: Add dropout modules to the transformer.
- Hi, loop acceleration! Looping computations with jax.lax.scan. Demonstration: Accelerate a whole training loop. Challenge: Vectorise a hyperparameter sweep.
- How are you going? Chance to revise some previous topics. Demonstration: Accelerated and vectorised LLC estimation. Challenge: Replicate some findings from [1] and [2].
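Several of the transformations in this block compose in a few lines. A minimal sketch (the model and names are illustrative, not the course's actual code) of jax.grad, jax.jit, and jax.vmap on a linear model:

```python
import jax
import jax.numpy as jnp

# Mean-squared-error loss for a linear model. The parameters travel
# around as a (w, b) tuple, a simple example of a JAX pytree.
def loss(params, x, y):
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument;
# jax.jit compiles the whole gradient computation.
grad_fn = jax.jit(jax.grad(loss))

params = (jnp.ones(3), 0.0)
x = jnp.ones((4, 3))   # four inputs of dimension three
y = jnp.zeros(4)

g_w, g_b = grad_fn(params, x, y)

# One functional SGD step: a new pytree of parameters is built,
# and the old one is untouched.
lr = 0.1
new_params = jax.tree_util.tree_map(
    lambda p, g: p - lr * g, params, (g_w, g_b))

# jax.vmap maps loss over the batch axis, giving per-example losses
# without a Python loop.
per_example = jax.vmap(loss, in_axes=(None, 0, 0))(params, x, y)
```

With these inputs every prediction is 3.0 against a target of 0.0, so the gradient with respect to each weight is 6.0 and one step with lr = 0.1 moves each weight from 1.0 to 0.4.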
Deep reinforcement learning in JAX:
- Hi, branching computation! Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select. Demonstration: Simple gridworld maze environment. Challenge: Determine solvability by implementing an accelerated DFS/BFS. Alternative challenge: Tabular Q-learning.
- Hi, deep reinforcement learning! DQN algorithm, revision of previous topics. Demonstration: Accelerated DQN. Challenge: Enhancements such as double DQN.
- Hi again, deep reinforcement learning! PPO algorithm, reverse scan, revision of previous topics. Demonstration: Accelerated PPO. Challenge: Solve larger mazes.
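The branching primitive above can be previewed in miniature. A minimal sketch of jax.lax.cond inside a jitted function; the reward values are made up for illustration:

```python
import jax
import jax.numpy as jnp

# Under jit, a Python `if` on a traced value raises an error, because
# tracing sees an abstract value rather than True/False. jax.lax.cond
# selects a branch at run time instead (jax.lax.select and jnp.where
# are the alternatives that compute both branches).
@jax.jit
def step_reward(at_goal):
    return jax.lax.cond(
        at_goal,
        lambda: 1.0,     # reward on reaching the goal
        lambda: -0.01,   # small per-step penalty otherwise
    )
```

Because only the selected branch runs, cond is the natural fit when the branches are expensive, e.g. a full environment transition versus a no-op for already-terminated episodes.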
More topics may be included from the following list.
- Hi, einops!
- Hi, debugging tools!
- Hi, profiling tools!
- Hi, automatic parallelisation!
- Hi, mechanistic interpretability!