GPT on a TPU VM with PyTorch/XLA

How I Learned to Stop Worrying and Love the Brrr

Interactive workshop
presented at the
Singular Learning Theory Seminar
MetaUni, October, 2023

Presented by
Matthew Farrugia-Roberts
independent AI safety researcher


Intended outcomes:


Overview of steps:

  1. Access a TPU VM via SSH.
  2. Share code with the VM via Git and GitHub.
  3. Train a transformer on the CPU with PyTorch.
  4. Train a transformer on the TPU with PyTorch/XLA.
  5. (Optional) Apply for your own TPU VMs.
  6. (Homework) Exercises.

Along the way, we’ll learn about TPU VMs, SSH, GPT, XLA, GCP and the TRC!

Prelude: WTF? Any more initialisms?

OK, it’s not a great sign that there are more capital letters in that title than there are lower case letters. Let me spell out the main terms used in this workshop.

TPU: A Tensor Processing Unit is a specialised processor designed by Google to execute large batches of neural network operations (e.g. float32 matrix multiplications) more efficiently than on a general-purpose CPU or a GPU specialised for rendering.

More information (worth reading afterwards):

VM: A Virtual Machine is an emulated computer system, offering the functionality of a full operating system but running on shared hardware. In our case, the main hardware of interest is a TPU. A TPU VM is a VM that has access to a TPU.

XLA: Accelerated Linear Algebra is a compiler that takes computational graphs (like those specified by a Python program) and transforms them for efficient execution on specific hardware (including CPUs, GPUs, TPUs, and other ML hardware).

Note, XLA is not specific to TPUs. However, running code on TPUs currently requires XLA, because TPUs don’t (yet?) run uncompiled computational graphs. That’s why we will be using XLA today.

More information:

SSH: The Secure SHell protocol allows you to securely login to and run commands on a private server (such as a TPU VM) over a network. To do so you need an SSH client on your local machine. We will use the terminal client ssh.

GPT: Recall that a Generative Pre-trained Transformer is a foundation model based on a decoder-only transformer architecture.

GCP: Google Cloud Platform is basically Google’s version of Microsoft Azure or Amazon AWS (don’t get me started on AWS initialisms…). This is where my TPU VMs live.

TRC: The TPU Research Cloud is a program by Google Research that supports open-access machine learning research with free GCP compute resources, including TPU VMs such as the ones we’ll be using today. You can find out more and apply for a TPU allocation on the TRC website (more on this later).

Part 1: Access a TPU VM via SSH

Goal: Obtain shell access to a live TPU VM.

Options for obtaining a TPU VM:

Options for actually obtaining shell access, once you have a TPU VM:

Let’s do the latter in both cases.

Step 1.1: Generate a key pair with ssh-keygen.

  1. Open your shell.

  2. Run the following command:

    ssh-keygen -t ed25519 -f <keypath> -C <username>

    where <username> is replaced with an arbitrary username, and <keypath> is the place to store the keys:

    • On macOS/Linux/WSL, I suggest ~/.ssh/matts-tpuvms
    • On Windows, maybe something like c/Users/<you>/.ssh/matts-tpuvms

Step 1.2: Add the public key to my TPU VMs on Google Cloud Platform.

  1. Open the file <keypath>.pub, for example ~/.ssh/matts-tpuvms.pub.

  2. Copy the contents and send them to me via Discord.

  3. I will add them to my TPU VMs, please wait.

Warning: Sending me your public key is safer than sending me your private key, even if it ends up in the recording—only with the private key can you log in with this account.

However, after this we will all have root privileges on all of my TPU VMs, which means everyone can in principle see what is in everyone else’s account. This will later include a GitHub key. Hopefully this is acceptable for today. Consider revoking this key from GitHub after the workshop if you prefer.

Step 1.3: Connect to the TPU VM with ssh.

  1. Open your shell.

  2. Run the following command:

    ssh -i <keypath> <username>@<ip-address>

    where <keypath> and <username> are the same as you used in step 1.1.2, and <ip-address> is the address I share with you via Discord now.

Interlude: Where am I?

If this command works, you will now have an active shell connection to the TPU VM. That is, the commands you run going forward will run on the TPU VM, not your local machine. Conceptually, the shell you see now is the same as you would see if you were in Google’s data centre with a monitor and a keyboard (modulo some lag).

Now that we are here, where is here? Your TPU VM runs Ubuntu Linux. It has the following hardware:

You can play with the CPU now. The rest of this workshop will work towards exercising that TPU.
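
One quick way to see what the VM offers is to ask Python directly. The following is a minimal sketch using only the standard library; the function name `describe_machine` is made up for illustration, and nothing here touches the TPU.

```python
import os
import platform


def describe_machine():
    """Collect a few basic facts about the machine this runs on."""
    return {
        "system": platform.system(),    # operating system name, e.g. 'Linux'
        "machine": platform.machine(),  # CPU architecture string
        "cpu_count": os.cpu_count(),    # logical CPU cores visible to Python
    }


info = describe_machine()
for key, value in info.items():
    print(f"{key}: {value}")
```

Run on the TPU VM, the reported core count should match the CPU the VM advertises; run on your laptop, it describes your laptop instead, which is a handy check of which machine your shell is really on.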

Part 2: Share code with the VM via Git and GitHub

Goal: Get code from your machine to the VM.

Options for doing this:

Let’s do the latter.

Edit: The example repository is now public, so an SSH key is not strictly required. However, you might like to follow these instructions anyway, because an SSH key will be necessary later if you want to clone your own private repository.

Step 2.1: Generate a(nother) key-pair on the TPU VM, and authorise it to your GitHub account.

  1. Make sure you are inside a shell on the TPU VM, not your local machine. Note that this means everyone is using Linux for this stage, unlike last time.

  2. Generate an SSH key on the VM:

    ssh-keygen -t ed25519 -f ~/.ssh/github -C "<your-github-email>"

    (where the email is obviously your actual GitHub email).

  3. Add that SSH key to the SSH agent on the VM:

    eval "$(ssh-agent -s)"
    ssh-add ~/.ssh/github

  4. Add the generated SSH key to your GitHub account. This is done via the GitHub website. Full instructions here (or watch me).

Full, more detailed instructions for steps 2 and 3 from GitHub here.

Step 2.2: Let me add you to a private repository with some pre-prepared code.

  1. The repository is at https://github.com/matomatical/bitesizedGPT.

  2. If I have not added you already, send me your GitHub username via Discord and I will invite you. Once you accept the invitation, you will be able to use this link.

    Edit: The repository is now public.

Step 2.3: Clone the GitHub repository on the TPU VM.

  1. On the TPU VM shell, in your home directory (~), run:

    git clone git@github.com:matomatical/bitesizedGPT.git

Interlude: Inside the repository…

The repository is called bitesizedGPT. It’s a simplified implementation based on Andrej Karpathy’s transformer tutorial Let’s build GPT: from scratch, in code, spelled out and his public transformer implementation nanoGPT.

The main differences are:

Whereas nanoGPT can fit on my laptop, bitesizedGPT can fit in my head.

Figure adapted from nanoGPT

I highly recommend studying Karpathy’s resources if you are interested in coding and training transformers.

For now, let me briefly tour the repository contents:

Part 3: Train a transformer on the CPU with PyTorch

Goal: Train the transformer on the CPU.

The code uses PyTorch, but other frameworks are available:

All I know is PyTorch so far.

Step 3.1: Run the code on the CPU.

  1. In a shell on the TPU VM, change into the bitesizedGPT directory:

    cd ~/bitesizedGPT

  2. (Normally, at this stage, you would need to install some non-standard dependencies, but in this case on my TPU VMs all of the dependencies are already installed. For your own TPU VMs, you will need to manage your own environment, and you might choose to use global installs for simplicity or a virtual environment for more control.)

  3. Now training the model should be as simple as running the script:

    python train.py

Note: We are not using the TPU yet, but most of the 96-core CPU is being used by PyTorch, and it’s running a bit faster than on my M2 MacBook with MPS (Metal Performance Shaders) GPU acceleration.

Interlude: From CPU to TPU

We don’t have time to wait for CPUs to train models, the bitter lesson is on our heels. What is it going to take to get this code running on the TPU instead?

Like I said at the start, training on a TPU requires a compilation step. That’s where the compiler XLA (and the PyTorch wrapper library “PyTorch/XLA”) comes in.

Conceptually, this is what needs to happen:

  1. As the PyTorch code runs, instead of performing actual matrix operations, an abstract computational graph is constructed.

  2. At certain ‘break points’ (whenever a concrete value is requested, or whenever a break point is requested directly), XLA compiles the computational graph and then runs it on the TPU. Part of PyTorch/XLA’s role is to bring this ‘lazy’ execution style to PyTorch models, to enable the graph to be compiled.

  3. Compiling is actually pretty slow, and on its own it would cancel out the efficiency gains of running the compiled code on the TPU. However, in machine learning we often run the same computational graph repeatedly (e.g. a training step), so XLA caches the graphs it compiles and the compilation cost is paid only once per distinct graph.

    For this reason, XLA is not going to be very effective at speeding up programs that don’t have this kind of repetition. We will see an example later.
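
The three steps above can be imitated in a few lines of plain Python. This is a toy sketch for intuition only, not how XLA or PyTorch/XLA are implemented: the names `LazyValue`, `signature`, `build`, `leaves_of`, `materialise` and `training_step` are all made up, and "compiling" here just means building a Python closure.

```python
class LazyValue:
    """A value that records how it would be computed instead of computing it."""

    def __init__(self, op, inputs=(), data=None):
        self.op = op          # 'input', 'add' or 'mul'
        self.inputs = inputs  # upstream LazyValue nodes
        self.data = data      # concrete number, only used by 'input' nodes

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))


def signature(node):
    """Describe the shape of the graph, ignoring the concrete input values."""
    if node.op == "input":
        return "x"
    return f"{node.op}({', '.join(signature(i) for i in node.inputs)})"


def build(node):
    """Stand-in for compilation: turn the graph into a callable over leaf values."""
    if node.op == "input":
        return lambda leaves: next(leaves)
    left, right = (build(i) for i in node.inputs)
    if node.op == "add":
        return lambda leaves: left(leaves) + right(leaves)
    return lambda leaves: left(leaves) * right(leaves)


def leaves_of(node):
    """List the concrete input values in the order the built callable reads them."""
    if node.op == "input":
        return [node.data]
    return [value for i in node.inputs for value in leaves_of(i)]


compile_cache = {}
stats = {"compiles": 0}


def materialise(node):
    """The 'break point': compile the graph (or reuse a cached build), then run it."""
    key = signature(node)
    if key not in compile_cache:
        compile_cache[key] = build(node)  # the slow step, once per graph shape
        stats["compiles"] += 1
    return compile_cache[key](iter(leaves_of(node)))


def training_step(w, x):
    """Same graph shape every call; only the input values change."""
    return LazyValue("input", data=w) * LazyValue("input", data=x) + LazyValue("input", data=1.0)


results = [materialise(training_step(2.0, x)) for x in (1.0, 2.0, 3.0)]
print(results, stats["compiles"])  # prints [3.0, 5.0, 7.0] 1
```

Three "steps" run but only one compilation happens, because the graph shape repeats. A program whose graph changes on every call would miss the cache every time and pay the compile cost again and again.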

PyTorch/XLA takes care of most of this, but unfortunately, it’s still not as simple as writing device='xla'. We need to import torch_xla and make some changes to the code… Let’s do it!

Part 4: Train a transformer on the TPU with PyTorch/XLA

Goal: Train the transformer on the TPU.

I have two options for you, and this time you can genuinely choose either option (see steps 4.2a and 4.2b respectively):

Step 4.1: If still running, quit the current CPU training run.

  1. You can just press control C and it should shut down.

Step 4.2a: Follow these changes I make to train.py’s train() function:

  1. The code we add should only run if we are using XLA. Define a flag XLA at the start of the train function:

    XLA = (device == 'xla')

  2. We need to initialise a device object for the TPU. Insert the following code after the XLA flag:

    if XLA:
        print("set environment variables expected by torch_xla...")
        import os
        os.environ['PJRT_DEVICE'] = "TPU"
        print("importing torch_xla...")
        import torch_xla.core.xla_model as xm
        print("initialising default xla device...")
        device = xm.xla_device()

    The code does the following:

    • Sets the environment variable PJRT_DEVICE to TPU ahead of importing PyTorch/XLA. This tells PyTorch/XLA that we want to compile to the TPU, not, for example, the CPU.
    • Imports the library torch_xla, which will check the environment variable and prepare to target the TPU.
    • Creates the actual torch.device object representing the lazy/caching XLA compiler targeting the TPU.

  3. We need to insert break points in various places to force compilation and execution. Each break point is a line of the form:

    if XLA: xm.mark_step()

    I am not very experienced at this, so I just add break points wherever I think they might be needed. The basic idea is to separate initialisation, each training step, and each evaluation step, so that each can be compiled once and executed repeatedly.
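
To make the placement of break points concrete, here is a minimal sketch of a training loop with the same structure. It is not the code in train.py: the function name `train_sketch`, the tiny linear model and the random data are stand-ins chosen so the example stays short, and the break point positions simply follow the idea described above.

```python
def train_sketch(device_name="cpu", num_steps=20, eval_every=10):
    """Tiny training loop showing where the break points might go."""
    import torch  # imported here so this block can be loaded without PyTorch installed

    XLA = (device_name == "xla")
    if XLA:
        import os
        os.environ["PJRT_DEVICE"] = "TPU"
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    else:
        device = torch.device(device_name)

    torch.manual_seed(0)
    # Stand-in model and data (assumptions; the real script trains a transformer).
    model = torch.nn.Linear(4, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs = torch.randn(64, 4, device=device)
    targets = inputs.sum(dim=1, keepdim=True)
    if XLA: xm.mark_step()  # break point: finish initialisation

    eval_losses = []
    for step in range(num_steps):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if XLA: xm.mark_step()  # break point: one training step per graph

        if (step + 1) % eval_every == 0:
            with torch.no_grad():
                eval_loss = torch.nn.functional.mse_loss(model(inputs), targets)
            if XLA: xm.mark_step()  # break point: one evaluation step per graph
            eval_losses.append(eval_loss.item())  # .item() requests a concrete value
    return eval_losses
```

Calling `train_sketch("cpu")` exercises the same loop without any XLA code running; calling `train_sketch("xla")` on a TPU VM takes the XLA branch. Because every training step builds the same graph, the first step should be slow (it compiles) and later steps should reuse the cached result.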

Step 4.2b: Switch branch and look at the diff.

  1. Switch to the tpu branch:

    git checkout -t origin/tpu

  2. See the changes this makes to the code by running a diff:

    git diff main

Step 4.3: Run the model with device="xla" now.

  1. Just run:

    python train.py xla

Interlude: Want more speed?

The training run should take around 14 minutes for 50,000 training steps and 500 evaluation steps. That’s the fastest of our available benchmarks:

So the TPU is pretty fast… But it’s not, like, lightning fast. Why? Could it be made even faster?

Bigger models: Training a bigger model will be slower in absolute terms, but relatively speaking, TPU acceleration might give more of a boost to larger models than small ones like this, since it’s possible we are not fully utilising the TPU’s capacity (I don’t know).

To test this, try increasing the batch size and/or the embed size and rerunning the CPU/GPU/TPU comparison.
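
For a rough sense of how much more work a larger embed size creates, the following sketch counts the weights in a standard transformer block. It assumes a conventional layout (four attention projection matrices and an MLP whose hidden layer is `mlp_ratio` times the embed size) and ignores biases, layer norms and embeddings; bitesizedGPT's exact architecture may differ, and the function name `approx_block_params` is made up.

```python
def approx_block_params(embed_size, num_layers, mlp_ratio=4):
    """Approximate weight count for a stack of standard transformer blocks."""
    attention = 4 * embed_size ** 2        # query, key, value and output projections
    mlp = 2 * mlp_ratio * embed_size ** 2  # up-projection and down-projection
    return num_layers * (attention + mlp)


for embed_size in (32, 64, 128):
    print(embed_size, approx_block_params(embed_size, num_layers=4))
```

Under these assumptions, doubling the embed size quadruples the number of weights, so even a modest increase gives the TPU substantially more matrix multiplication to do per step.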

Carefully controlling XLA: It can be a bit hard to see what’s happening inside the XLA device, and small changes to code can lead accidentally to excessive recompilation (e.g. because of variation in computational graphs) or moving data between CPU and TPU repeatedly (e.g. for operations not supported on the TPU), harming performance gains.

Example: In our code, it appears that the function that continues a prompt (prompt completion) actually runs slower on the TPU than on the CPU, even after the first compilation. Can anyone guess why this might be the case?
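
One way to investigate questions like this is PyTorch/XLA's debug metrics module, `torch_xla.debug.metrics`. The sketch below is a starting point rather than a tested recipe: the helper names `cpu_fallback_ops` and `print_xla_diagnostics` are made up, and the reading of `aten::` counters as operations that fell back to the CPU follows my understanding of the PyTorch/XLA troubleshooting notes, so check it against the documentation for your installed version.

```python
def cpu_fallback_ops(counter_names):
    """Pick out counters named 'aten::...'.

    In PyTorch/XLA these usually mark operations that were not lowered
    to XLA and ran on the CPU instead.
    """
    return sorted(name for name in counter_names if name.startswith("aten::"))


def print_xla_diagnostics():
    """Print compile counts and CPU fallbacks. Call on the TPU VM after some steps."""
    import torch_xla.debug.metrics as met  # only importable where torch_xla is installed

    compile_data = met.metric_data("CompileTime")  # assumed None if nothing compiled yet
    num_compiles = compile_data[0] if compile_data else 0
    print(f"graphs compiled so far: {num_compiles}")
    for name in cpu_fallback_ops(met.counter_names()):
        print(f"CPU fallback: {name} x{met.counter_value(name)}")
```

If the compile count keeps growing while a function is called repeatedly, the graph is probably changing from call to call; if `aten::` counters grow, data is probably moving between the TPU and the CPU.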


Frameworks: I’ve heard JAX integrates better with TPUs. I don’t know JAX, but want to learn it and try this.

Parallelisation: So far, we have only been using a quarter of one 8-core TPU because each TPU v2-8 or v3-8 has 4 ‘devices’ (chips) with two cores each, and we used one device by default. Mind blown!

An easy way to gain more effective speed is to run independent training runs on different devices (four per TPU v2-8 or v3-8) and different VMs at the same time. Each training run takes the same time, but you can run many at once. This is beyond the scope of this tutorial, but I have another guide in the works on this.

A harder way to gain more speed is to parallelise a particular training run across TPU cores and/or across TPU VMs. This is less straightforward and requires some parallel programming, but may be necessary for really large training runs. I haven’t looked into this yet.

(Note: for really large training runs, TPU VM volatility becomes an issue. Google Cloud Platform sometimes restarts VMs. You’ll want to make sure you don’t lose much progress if a TPU VM drops out unexpectedly).
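
A common way to limit lost progress is to checkpoint periodically and resume from the latest checkpoint on startup. The sketch below shows the pattern with the standard library only; the function names, the `state` dictionary and the fake "training step" are stand-ins. In a real PyTorch run, the state would typically hold the model and optimiser state dicts and be written with torch.save.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, state):
    """Write state to a temporary file, then rename it over the old checkpoint.

    Renaming within one filesystem replaces the file in a single step, so a
    restart in the middle of a write should not leave a half-written checkpoint.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path, default):
    """Return the saved state, or the default if no checkpoint exists yet."""
    if not os.path.exists(path):
        return default
    with open(path, "rb") as f:
        return pickle.load(f)


def run(path, total_steps, save_every=100):
    """Resume from the last checkpoint and continue up to total_steps."""
    state = load_checkpoint(path, {"step": 0, "weights": 0.0})
    while state["step"] < total_steps:
        state["weights"] += 0.01  # stand-in for one training step
        state["step"] += 1
        if state["step"] % save_every == 0:
            save_checkpoint(path, state)
    return state
```

If the VM restarts partway through, calling `run` again with the same path picks up from the most recent multiple of `save_every` steps rather than from zero.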

Part 5: (Optional) Apply for your own TPU VMs

Optional goal: You get your own TPU allocation for 30 days.


More seriously, Google Research’s TPU Research Cloud (TRC) offers researchers reasonably generous periods of free access to the TPU VMs we have been using today, nominally for the purpose of promoting ML research. More information on this scheme is on the TRC website.

Step 5.1: Apply for the TPU allocation.

  1. Go to the TRC website.
  2. Click ‘apply now’ and fill in the form.
  3. Then, wait to hear back (took ~30 minutes for me).
  4. Follow the instructions in the reply email to proceed to initialise and configure your own TPU VMs.

Step 5.2: ???

I have a draft of a separate guide on launching and configuring a TPU VM, which would bridge the gap between step 5 and repeating steps 1–4 on your own TPU.

The draft is currently tailored for the ICL project but I’m considering working on a more general guide, and either way it’s still generally informative.

Anyway, the next steps are out of scope today (you can ask me for access to the draft when you get up to that).

Interlude: A free lunch?

Is all of this really free? So far, I haven’t paid a cent, yet… Here’s the situation, as I understand it:

You get free TPUs for 30 days, potentially longer: The TPU VMs are free for 30 days after they are granted. I have heard that the TPU Research Cloud are very generous in offering extensions on request.

For example, in my case, when my thirty days were coming to an end recently, they sent me an email asking if I needed any help using the TPUs, and I replied asking for an extension. Within a couple of days, they granted me an extension from October 7th to December 1st.

I have heard that some people have been asked to fill out feedback surveys on TPU usage in order to get further extensions. This kind of thing would also be consistent with the conditions stated on the TRC website.

If the TRC did decide not to grant my extension, that would be the end of my TPU access—the full price for a single TPU VM is thousands of dollars per month.

You get free starting credits for 90 days: Google Cloud Platform is also a paid platform, selling various computing resources other than TPU VMs, including, for example:

New accounts (or just students? academic staff? can’t remember) get 300–400 USD of starting credits usable for 90 days. I’m at around 30 days into my trial, using only TPU VMs and network traffic, which has used around 50 USD of my starting credits, so I haven’t had to pay anything yet.

After 90 days, I anticipate that if I want to keep using TPUs, I will have to pay a small-to-medium ongoing monthly cost for network traffic, depending on usage.

As long as TRC are allowed/willing to continue to support researchers by extending their allocations, paying for the network traffic will be a small price to pay for access to such powerful compute.

You can continue to use my TPUs for a bit: I don’t mind if you keep using my TPUs for the next week or so (including for the hackathon). I don’t need them right now. If you do use them, please keep in mind these norms as a courtesy:

  1. Stick to your 1 TPU VM unless you coordinate with someone else and want to use theirs, or if you coordinate with me and want to use the spares (since only one process can use each TPU device at a time).
  2. Don’t specifically seek to spend my free credits on e.g. networking.
  3. Don’t abuse root access (e.g. don’t read my GitHub keys or other people’s GitHub keys). (If anyone is worried, revoke this GitHub key).
  4. Try not to break the system e.g. by messing with system packages or files.
  5. Take steps towards getting your own allocation so you are not using mine indefinitely.

I am not able to authenticate additional users to the TPUs at this stage.

Coda: Elementary, my dear…

Is the training run finished yet? What did your model have to say? Let’s go around the room and read the completions. Send them to me and I will add them to this website, so that future foundation models can learn from them.

Prompt: "Elementary, my dear


"Elementary, my dear Or before? A some down, you have
     quite full see an end you. But the name turnised of the mims the
         shatre.then whalll shain opmatHen rhamporfurm.1 Chawnapopseatheroat&
a   nemes aney


                nenestallloan on
"Elementary, my dear through me is made which into
     devening how about to up him vasigning pointle entression, not I was
       Trolime him"s.

o     "Younzs  "Tolelly olist lidg man  russe slill nore.
"Elementary, my dearlial for a
     rave of him the smakel child to no alcosing the har for his you
     was a much under wcouldrsung  anctucheonspppondeou fapeeen
i     undoudekn'ougeon  hugssnineon,  clondiculsucheneon  siggrimerllep  as
        op as th
"Elementary, my dear acheyckner, just
     country selfom expectly his bon us a feweering at metrishe.
     "Then God?"

     Gad he madvdue they His man

       Thkkkend Ithkkkkkenkedkeathmankeazy minke chemunddeablman matk


"Elementary, my dear impluacity. But
     the other with invellor occurian Godor Mostanly.  He he some not an little
     to Botsss plomsssssesh.  a shas  shas  sI  sI howassss enshey
e            Werqff y896 "II's Riencqoassssseym shewerve shervs?" she

Bonus question: How many times does the famous phrase “Elementary, my dear Watson” occur in the entire Holmesian canon?


Part 6: Homework exercises

Congratulations on completing the workshop! But there is much work still to do. I hope that what you have learned today will directly help you accelerate your own experiments.

If you are not ready to run experiments on a TPU yet, consider the following exercises based on the bitesizedGPT repository.

Exercises 6.1: Improve model performance on this data set.

  1. More training: Run the training loop for longer and see if the performance continues improving. The transformer has probably seen each training byte ~60 times, could it still benefit from reading the canon again?
  2. Hyperparameter sweep: Experiment with different learning rates, batch sizes, and architecture parameters.
  3. Architectural changes: Does an attention-only version of the transformer perform as well? Does adding dropout help?

Exercises 6.2: Improve TPU performance for prompt completion.

  1. Moving to CPU: What’s faster?
    1. the current method of computing prompt completions, with the model on the TPU, or
    2. moving the model to the CPU, computing the prompt completion, and then moving the model back to the TPU afterwards?
    Hint: this shouldn’t depend on the transformer’s weights, you can test using a randomly initialised transformer.
  2. Breaking loops: Does adding additional break points within the prompt completion inner loop help?
  3. Flagging loops: Does PyTorch/XLA have a way for you to tell it about a loop so that it can compile it more intelligently? I think I read about this somewhere but forgot where; go hunting.
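
For these comparisons, a small timing harness helps. The following is a minimal sketch with made-up names (`time_call`, `summarise`). The important detail for a lazy device is the optional `sync` argument: without something that blocks until the queued work has finished, you may only be timing how long it takes to build the graph. On the TPU, `xm.mark_step()` followed by `xm.wait_device_ops()` is one candidate for `sync`, but treat that as an assumption to verify.

```python
import time


def time_call(fn, repeats=5, sync=None):
    """Return the per-call wall-clock times of fn, in seconds.

    sync, if given, is called after fn inside the timed region; on a lazy
    device it should block until queued work has actually finished.
    """
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return times


def summarise(times):
    """Separate the first call (which may include compilation) from the rest.

    Needs at least two timings.
    """
    return {"first": times[0], "best_of_rest": min(times[1:])}
```

Comparing `first` against `best_of_rest` for each method shows how much of the cost is one-off compilation and how much is paid on every call.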

Exercises 6.3: Explore other frameworks for TPU computation.

  1. JAX: Re-implement this transformer in JAX and get it running on the TPU.

In all cases, I would be very interested to see and discuss your results!