

Hi, JAX!

A short introduction to JAX for deep learning researchers.

The pilot of this course will run from June to August 2024. To express interest, either email me or join the metauni Discord server (invite link; see the #hijax channel), where the classes will probably take place.

Format

Weekly:

Source code (including solutions) for the weekly demonstrations will be available on GitHub. Participants are invited to share their own solutions and projects by maintaining a public fork of the course repository.

Prerequisites

Programming:

Theory:

Syllabus

The syllabus is still under construction. Tentative topics are as follows.

Overview:

  1. Hi, JAX! Intro to JAX, course overview, immutability and jax.numpy, randomness with jax.random. Demonstration: Elementary cellular automaton. Challenge: Conway’s game of life.
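
To give a flavour of this first lesson, here is a minimal sketch (not the course's demonstration code) of one elementary-cellular-automaton step written with immutable jax.numpy arrays and explicit jax.random keys; rule 110 and the grid size are arbitrary choices for illustration.

    import jax
    import jax.numpy as jnp

    def step(cells, rule=110):
        # One update of an elementary cellular automaton. jax.numpy arrays
        # are immutable, so we build a new array rather than writing into
        # `cells` in place.
        rule_table = jnp.array([(rule >> i) & 1 for i in range(8)], dtype=jnp.int32)
        left = jnp.roll(cells, 1)
        right = jnp.roll(cells, -1)
        neighbourhood = 4 * left + 2 * cells + right
        return rule_table[neighbourhood]

    # randomness is explicit: keys are split, never mutated in place
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    cells = jax.random.bernoulli(subkey, 0.5, shape=(64,)).astype(jnp.int32)
    cells = step(cells)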

Deep learning in JAX:

  1. Hi, automatic differentiation! Call/init function API for models, jax.grad transformation. Demonstration: Classical perceptron trained with classical SGD, vanilla JAX. Challenge: Multi-layer perceptron, vanilla JAX. (See the jax.grad sketch after this list.)

  2. Hi, flax and optax! Pytrees, flax modules, optax optimisers, train state. Demonstration: Train an MLP on MNIST with minibatch SGD. Challenge: Implement a drop-in replacement for optax SGD. Harder challenge: Implement a drop-in replacement for a stateful optimiser (Adam). (See the flax/optax sketch after this list.)

  3. Hi, automatic vectorisation! Vectorisation with jax.vmap. Demonstration: Train a CNN on MNIST with minibatch SGD. Challenge: Implement your own convolution and pooling modules. (See the jax.vmap sketch after this list.)

  4. Hi, just-in-time compilation! Compilation with jax.jit, tracing vs. execution, side-effects. Demonstration: Train an accelerated CNN on ImageNet. Challenge: Implement and train a ResNet on CIFAR. (See the jax.jit sketch after this list, which also covers static arguments from the next lesson.)

  5. Hi again, just-in-time compilation! Compile errors due to non-static shapes, static arguments and recompilation. Demonstration: Train a byte-transformer on the Sherlock Holmes canon. Challenge: Add dropout modules to the transformer.

  6. Hi, loop acceleration! Looping computations with jax.lax.scan. Demonstration: Accelerate a whole training loop. Challenge: Vectorise a hyperparameter sweep. (See the jax.lax.scan sketch after this list.)

  7. How are you going? Chance to revise some previous topics. Demonstration: Accelerated and vectorised LLC estimation. Challenge: Replicate some findings from [1] and [2].
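
For lesson 1, a minimal sketch of the init/apply function pattern and a plain SGD step via jax.grad. The logistic-loss perceptron is a placeholder, and the names init, apply, loss, and sgd_step are illustrative rather than a fixed course API.

    import jax
    import jax.numpy as jnp

    def init(key, dim):
        # parameters as an ordinary dict (a pytree)
        return {"w": 0.01 * jax.random.normal(key, (dim,)),
                "b": jnp.zeros(())}

    def apply(params, xs):
        return xs @ params["w"] + params["b"]

    def loss(params, xs, ys):
        # logistic loss on labels in {-1, +1}
        margins = ys * apply(params, xs)
        return jnp.mean(jnp.log1p(jnp.exp(-margins)))

    def sgd_step(params, xs, ys, lr=0.1):
        grads = jax.grad(loss)(params, xs, ys)
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)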
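
For lesson 2, a sketch of the flax/optax pattern, keeping params and opt_state as separate variables rather than using flax's train-state abstraction; the layer sizes and learning rate are arbitrary.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import optax

    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(128)(x))
            return nn.Dense(10)(x)

    model = MLP()
    params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 784)))
    optimiser = optax.sgd(learning_rate=0.01)
    opt_state = optimiser.init(params)

    def loss_fn(params, images, labels):
        logits = model.apply(params, images)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    def train_step(params, opt_state, images, labels):
        grads = jax.grad(loss_fn)(params, images, labels)
        updates, opt_state = optimiser.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state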
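
For lesson 3, the core jax.vmap idea: write a function for a single example, then map it over a batch axis.

    import jax
    import jax.numpy as jnp

    def predict(w, x):
        # prediction for a single example x of shape (dim,)
        return jnp.dot(w, x)

    # map over the leading axis of xs, holding w fixed
    batched_predict = jax.vmap(predict, in_axes=(None, 0))

    w = jnp.ones(3)
    xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 examples
    ys = batched_predict(w, xs)           # shape (4,)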
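
For lessons 4 and 5, a sketch of jax.jit tracing and of a static argument that determines an output shape: each new static value triggers a fresh trace and compilation.

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.jit, static_argnames="n")
    def make_ramp(n, scale):
        # `n` determines the output shape, so it must be static; passing
        # it as a traced argument would fail, since the shape would not
        # be known at trace time
        print("tracing with n =", n)   # runs at trace time, not on every call
        return scale * jnp.arange(n)

    make_ramp(4, 2.0)   # traces and compiles
    make_ramp(4, 3.0)   # reuses the compiled function (no print)
    make_ramp(5, 2.0)   # new static value: traces and compiles again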
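
For lesson 6, the shape of a jax.lax.scan loop, here just accumulating partial sums as a stand-in for scanning a training step over a dataset.

    import jax
    import jax.numpy as jnp

    def body(carry, x):
        carry = carry + x
        return carry, carry   # (new carry, per-step output)

    xs = jnp.arange(5.0)
    total, partial_sums = jax.lax.scan(body, jnp.zeros(()), xs)
    # total == 10.0, partial_sums == [0., 1., 3., 6., 10.]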

Deep reinforcement learning in JAX:

  1. Hi, branching computation! Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select. Demonstration: Simple gridworld maze environment. Challenge: Determine solvability by implementing an accelerated DFS/BFS. Alternative challenge: Tabular Q-learning. (See the jax.lax.cond sketch after this list.)

  2. Hi, deep reinforcement learning! DQN algorithm, revision of previous topics. Demonstration: Accelerated DQN. Challenge: Enhancements such as double DQN.

  3. Hi again, deep reinforcement learning! PPO algorithm, reverse scan, revision of previous topics. Demonstration: Accelerated PPO. Challenge: Solve larger mazes.
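
For the first reinforcement-learning lesson, a minimal jax.lax.cond sketch: inside traced code a Python `if` on a traced boolean fails, so both branches are staged out instead. The maze-movement framing is only illustrative, not the course's environment API.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def move(position, delta, blocked):
        # `blocked` is a traced boolean, so `if blocked:` would raise an
        # error under tracing; lax.cond stages out both branches instead
        return jax.lax.cond(
            blocked,
            lambda p: p,           # blocked: stay put
            lambda p: p + delta,   # free: take the step
            position,
        )

    pos = jnp.array([1, 1])
    print(move(pos, jnp.array([0, 1]), True))   # [1 1]
    print(move(pos, jnp.array([0, 1]), False))  # [1 2]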

More topics may be included from the following list.