GPT on a TPU VM with PyTorch/XLA
How I Learned to Stop Worrying and Love the Brrr
independent AI safety researcher
- Learn about Tensor Processing Units (TPUs), TPU Virtual Machines (TPU VMs), and Accelerated Linear Algebra (XLA), and how they can potentially accelerate your computing workflows.
- Access and run code on a real TPU VM, including cloning a private GitHub repository, training a GPT-style transformer on the CPU using PyTorch, and then training the same transformer on the TPU using PyTorch/XLA.
- (Optional) apply to the TPU Research Cloud (TRC) for your own free 30-day (at first) TPU VM allocation on Google Cloud Platform (GCP) to accelerate your research.
- Knowledge of Python and PyTorch will help you appreciate the examples.
- A shell with the commands
- Any terminal emulator on linux or macos should be fine.
- On windows, consider using WSL, git bash, conda, or a better OS.
- Basic familiarity with unix command line.
- A GitHub account.
- (If you are not following a live workshop, but found this guide
- Someone needs to lend you access to a TPU VM to run this workshop. Otherwise, first set up your own free TPU VMs (see part 5).
- The workshop also makes use of a private repository containing some PyTorch example code. I’ll eventually make this public. In the meantime, you can ask me and I am happy to share it.
Edit: The repository is now public.
Overview of steps:
- Access a TPU VM via SSH.
- Share code with the VM via Git and GitHub.
- Train a transformer on the CPU with PyTorch.
- Train a transformer on the TPU with PyTorch/XLA.
- (Optional) Apply for your own TPU VMs.
- (Homework) Exercises.
Along the way, we’ll learn about TPU VMs, SSH, GPT, XLA, GCP and the TRC!
Prelude: WTF? Any more initialisms?
OK, it’s not a great sign that there are more capital letters in that title than there are lower case letters. Let me spell out the main terms used in this workshop.
TPU: A Tensor Processing Unit is a specialised processor designed by Google to execute large batches of neural network operations (e.g. float32 matrix multiplications) more efficiently than on a general-purpose CPU or a GPU specialised for rendering.
More information (worth reading afterwards):
- Wikipedia on the history of TPU development (since 2015).
- Google Cloud on how TPUs accelerate certain operations and what they are (and aren’t) most useful for.
VM: A Virtual Machine is an emulated computer system, offering the functionality of a full operating system but running on shared hardware. In our case, the main hardware of interest is a TPU. A TPU VM is a VM that has access to a TPU.
XLA: Accelerated Linear Algebra is a compiler that takes computational graphs (like those specified by a Python program) and transforms them for efficient execution on specific hardware (including CPUs, GPUs, TPUs, and other ML hardware).
Note, XLA is not specific to TPUs. However, running code on TPUs currently requires XLA, because TPUs don’t (yet?) run uncompiled computational graphs. That’s why we will be using XLA today.
- XLA is open source and more information can be found on its public GitHub repository.
- PyTorch/XLA is a library that bridges XLA and PyTorch. There is more information in the PyTorch/XLA GitHub repository or documentation.
SSH: The Secure
SHell protocol allows you to securely log in to and run
commands on a private server (such as a TPU VM) over a network. To do so
you need an SSH client on your local machine. We will use the
GPT: Recall that a Generative Pre-trained Transformer is a foundation model based on a decoder-only transformer architecture.
GCP: Google Cloud Platform is basically Google’s version of Microsoft Azure or Amazon AWS (don’t get me started on AWS initialisms…). This is where my TPU VMs live.
TRC: The TPU Research Cloud is a program by Google Research that supports open-access machine learning research with free GCP compute resources, including TPU VMs such as the ones of mine we’ll be using today. You can find out more and apply for a TPU allocation on the TRC website (more on this later).
Part 1: Access a TPU VM via SSH
Goal: Obtain shell access to a live TPU VM.
Options for obtaining a TPU VM:
- You could rent a TPU VM and configure shell access.
- You could apply for a free TPU VM allocation and configure shell access.
- I could just configure shell access to one of my TPU VMs.
Options for actually obtaining shell access, once you have a TPU VM:
- You could visit Google’s data center and plug in a monitor and a keyboard.
- You could use the Google Cloud web app’s built-in shell.
- You could use a graphical SSH client.
- You could just use
ssh, a terminal SSH client.
Let’s do the latter in both cases.
Step 1.1: Generate a key pair with
Open your shell.
Run the following command:
ssh-keygen -t ed25519 -f <keypath> -C <username>
<username> is replaced with an arbitrary username, and
<keypath> is the place to store the keys:
- On macos/linux/wsl, suggest
- On windows, maybe something like
Step 1.2: Add the public key to my TPU VMs on Google Cloud Platform.
Open the file
<keypath>.pub, for example
Copy the contents and send them to me via Discord.
I will add them to my TPU VMs, please wait.
Warning: Sending me your public key is safer than sending me your private key, even if it ends up in the recording—only with the private key can you log in with this account.
However, after this we will all have root privileges on all of my TPU VMs, which means everyone can in principle see what is in everyone else’s account. This will later include a GitHub key. Hopefully this is acceptable for today. Consider revoking this key from GitHub after the workshop if you prefer.
Step 1.3: Connect to the TPU VM with
Open your shell.
Run the following command:
ssh -i <keypath> <username>@<ip-address>
<keypath> and <username> are the same as you used in step 1.1, and
<ip-address> is the address I share with you via Discord now.
Interlude: Where am I?
If this command works, you will now have an active shell connection to the TPU VM. That is, the commands you run going forward will run on the TPU VM, not your local machine. Conceptually, the shell you see now is the same as you would see if you were in Google’s data centre with a monitor and a keyboard (modulo some lag).
Now that we are here, where is here? Your TPU VM runs Ubuntu Linux. It has the following hardware:
- 96 GB disk space,
- a 96-core CPU,
- 335 GB of main memory,
- an 8-core TPU (v2 or v3), and
- 128 GB of TPU memory.
You can play with the CPU now. The rest of this workshop will work towards exercising that TPU.
Part 2: Share code with the VM via Git and GitHub
Goal: Get code from your machine to the VM.
Options for doing this:
- You could just directly send code files to the server using rsync (for example). You can send updates every time the code changes.
- You could store the code in a private online git repository, such as via GitHub. You can make changes either locally or online and track them with git.
Let’s do the latter.
Edit: The example repository is now public, so an SSH key is not strictly required. However, you might like to follow these instructions anyway, because an SSH key will be necessary later if you want to clone your own private repository.
Step 2.1: Generate a(nother) key-pair on the TPU VM, and authorise it to your GitHub account.
Make sure you are inside a shell on the TPU VM, not your local machine. Note that this means everyone is using linux for this stage, unlike last time.
Generate an SSH key on the VM:
ssh-keygen -t ed25519 -f ~/.ssh/github -C "<your-github-email>"
(where the email is obviously your actual GitHub email).
Add that SSH key to the SSH agent on the VM:
eval "$(ssh-agent -s)"
ssh-add ~/.ssh/github
Add the generated SSH key to your GitHub account. This is done via the GitHub website. Full instructions here (Or watch me).
Full, more detailed instructions for steps 2 and 3 from GitHub here.
Step 2.2: Let me add you to a private repository with some pre-prepared code.
The repository is at https://github.com/matomatical/bitesizedGPT.
If I have not added you already, send me your GitHub username via Discord. I will invite you, and once you accept the invitation you will be able to use this link.
Edit: The repository is now public.
Step 2.3: Clone the GitHub repository on the TPU VM.
On the TPU VM shell, in the root directory, run:
git clone git@github.com:matomatical/bitesizedGPT.git
Interlude: Inside the repository…
The repository is called bitesizedGPT. It’s a simplified implementation based on Andrej Karpathy’s transformer tutorial Let’s build GPT: from scratch, in code, spelled out and his public transformer implementation nanoGPT.
The main differences are:
- This transformer architecture is simpler than nanoGPT (e.g. no dropout).
- This implementation uses ASCII code-points=characters=bytes directly (instead of a small subset of unicode as in the tutorial, or BPE tokens as in nanoGPT).
- We use the full Sherlock Holmes canon as a data set (rather than complete works of Shakespeare).
- The total lines of code is just 264 (compared to ~600 for nanoGPT).
Whereas nanoGPT can fit on my laptop, bitesizedGPT can fit in my head.
I highly recommend studying Karpathy’s resources if you are interested in coding and training transformers.
For now, let me briefly tour the repository contents:
data: this folder stores some text files we will use as our training corpus.
- The main file of interest is sherlock-ascii.txt, which is the complete Sherlock Holmes canon with non-ASCII characters replaced with ASCII equivalents. It’s modified from here.
model.py: defines a transformer that operates on byte vectors and associated utilities for converting Python strings to/from byte vectors. The underlying decoder-only transformer is the same code I am using for ICL experiments, using PyTorch.
train.py: a very simple PyTorch training loop.
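To make the "ASCII code points = characters = bytes" idea concrete, here is a minimal sketch of converting between Python strings and sequences of byte values. The function names `encode` and `decode` are hypothetical and are not necessarily the ones used in model.py, which works with PyTorch tensors rather than plain lists.

```python
def encode(text: str) -> list[int]:
    """Turn an ASCII string into a list of byte values (one token per character)."""
    return list(text.encode("ascii"))


def decode(tokens: list[int]) -> str:
    """Turn a list of byte values (each in the range 0..127) back into a string."""
    return bytes(tokens).decode("ascii")


prompt = '"Elementary, my dear'
tokens = encode(prompt)
print(tokens[:5])      # the first five byte values of the prompt
print(decode(tokens))  # round-trips back to the original prompt
```

With this scheme the vocabulary is just the 128 ASCII code points, so no separate tokeniser needs to be trained or stored.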
Part 3: Train a transformer on the CPU with PyTorch
Goal: Train the transformer on the CPU.
The code uses PyTorch, but other frameworks are available:
- Julia (not Python)
All I know is PyTorch so far.
Step 3.1: Run the code on the CPU.
In a shell on the TPU VM, change into the bitesizedGPT directory:
(Normally, at this stage, you would need to install some non-standard dependencies, but in this case on my TPU VMs all of the dependencies are already installed. For your own TPU VMs, you will need to manage your own environment, and you might choose to use global installs for simplicity or a virtual environment for more control).
Now training the model should be as simple as running the script:
Note: We are not using the TPU yet, but most of the 96-core CPU is being used by PyTorch, and it’s running a bit faster than on my M2 macbook with MPS (Metal Performance Shader) GPU acceleration.
Interlude: From CPU to TPU
We don’t have time to wait for CPUs to train models, the bitter lesson is on our heels. What is it going to take to get this code running on the TPU instead?
Like I said at the start, training on a TPU requires a compilation step. That’s where the compiler XLA (and the PyTorch wrapper library “PyTorch/XLA”) comes in.
Conceptually, this is what needs to happen:
As the PyTorch code runs, instead of performing actual matrix operations, an abstract computational graph is constructed.
At certain ‘break points’ (whenever a concrete value is requested, or whenever a break point is requested directly), XLA compiles the computational graph and then runs it on the TPU. Part of PyTorch/XLA’s role is to bring this ‘lazy’ execution style to PyTorch models, to enable the graph to be compiled.
Compiling is actually pretty slow and destroys the efficiency gains of using the TPU to run the compiled code, but often in machine learning we are compiling the same computational graph repeatedly (e.g. a training step), so XLA caches the computational graphs it compiles so that the total computational cost is kept to a minimum.
For this reason, XLA is not going to be very effective at speeding up programs that don’t have this kind of repetition. We will see an example later.
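The lazy-then-cache idea can be illustrated with a toy model in plain Python. This is only a sketch of the concept under simplifying assumptions, not how XLA or PyTorch/XLA is implemented: `LazyValue`, `build`, and `materialise` are hypothetical names, and "compiling" here just means building a Python callable.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


class LazyValue:
    """A node in a toy computational graph: it records an operation instead of running it."""

    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, inputs, value

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))

    def signature(self):
        """The structure of the graph, ignoring the concrete input values."""
        if self.op == "input":
            return "x"
        return (self.op,) + tuple(node.signature() for node in self.inputs)

    def leaves(self):
        """Concrete input values, in the order the compiled program consumes them."""
        if self.op == "input":
            return [self.value]
        return [v for node in self.inputs for v in node.leaves()]


def build(sig):
    """Stand-in for the slow compile step: turn a signature into a callable."""
    if sig == "x":
        return next  # consume the next input value from an iterator
    op, *args = sig
    parts = [build(arg) for arg in args]
    return lambda values: OPS[op](*[part(values) for part in parts])


CACHE = {}
STATS = {"compiles": 0}


def materialise(node):
    """Break point: compile the graph (unless it is cached), then run it on the inputs."""
    sig = node.signature()
    if sig not in CACHE:
        STATS["compiles"] += 1
        CACHE[sig] = build(sig)
    return CACHE[sig](iter(node.leaves()))


def training_step(a, b):
    """A fixed computation (a * b + 1), standing in for one training step."""
    x, y, one = (LazyValue("input", value=v) for v in (a, b, 1.0))
    return x * y + one


# The same graph with different inputs: compiled once, then served from the cache.
results = [materialise(training_step(float(i), 2.0)) for i in range(5)]
compiles_after_loop = STATS["compiles"]

# A graph with a different structure triggers a fresh compilation.
other = materialise(training_step(1.0, 2.0) + LazyValue("input", value=3.0))
print(results, compiles_after_loop, STATS["compiles"])
```

The repeated training step pays the compile cost once, while any change to the structure of the graph pays it again, which is why programs without repetition see little benefit.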
PyTorch/XLA takes care of most of this, but unfortunately, it’s still not as simple as writing device='xla'. We need to import torch_xla and make some changes to the code… Let’s do it.
Part 4: Train a transformer on the TPU with PyTorch/XLA
Goal: Train the transformer on the TPU.
I have two options for you, and this time you can genuinely choose either option (see steps 4.2a and 4.2b respectively):
- “Thanks I’ll do it myself”-option: Modify the code on the VM
- You can follow the changes I make or copy-paste the snippets below, live.
- You will need to use a text-based text editor, such as vim, on the VM, which is not that straightforward the first time.
- “Here’s one I prepared earlier”-option: Simply check out the branch tpu with the modifications already made.
- This is a bit easier and faster
- You can still see the changes I make live, or using
Step 4.1: If still running, quit the current CPU training run.
- You can just press control C and it should shut down.
Step 4.2a: Follow these changes I make to
The code we add should only run if we are using XLA. Define a flag XLA at the start of the
XLA = (device == 'xla')
We need to initialise a device object for the TPU. Insert the following code after the
if XLA:
    print("set environment variables expected by torch_xla...")
    import os
    os.environ['PJRT_DEVICE'] = "TPU"
    print("importing torch_xla...")
    import torch_xla.core.xla_model as xm
    print("initialising default xla device...")
    device = xm.xla_device()
The code does the following:
- Sets the environment variable PJRT_DEVICE to TPU ahead of importing PyTorch/XLA. This tells PyTorch/XLA that we want to compile to the TPU, not, for example, the CPU.
- Imports the library torch_xla, which will check the environment variable and prepare to target the TPU.
- Creates the actual torch.device object representing the lazy/caching XLA compiler targeting the TPU.
We need to insert break points in various places to force compilation and execution. Each break point is a line of the form:
if XLA: xm.mark_step()
I am not very experienced at this, so I just add break points wherever I think they might be needed. The basic idea is to separate initialisation, each training step, and each evaluation step, so that each can be compiled once and executed repeatedly.
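One possible placement of break points is sketched below as a loop skeleton. This is an illustration under assumptions rather than the actual contents of train.py: `run_training`, `train_step`, and `eval_step` are hypothetical names, and `mark_step` is passed in as a parameter so that the sketch also runs without torch_xla installed.

```python
def run_training(train_step, eval_step, num_steps, eval_every, mark_step=None):
    """Loop skeleton with a break point after initialisation, each training
    step, and each evaluation step.

    train_step and eval_step are hypothetical callables standing in for the
    real forward/backward/optimiser code. On a TPU, mark_step would be
    xm.mark_step; with mark_step=None the loop runs as ordinary eager code.
    """
    def break_point():
        if mark_step is not None:
            mark_step()

    break_point()  # separate initialisation from the first training step
    for step in range(1, num_steps + 1):
        train_step(step)
        break_point()  # one graph per training step
        if step % eval_every == 0:
            eval_step(step)
            break_point()  # one graph per evaluation step


# Usage with stand-in steps that only record the order of events.
calls = []
run_training(
    train_step=lambda s: calls.append(("train", s)),
    eval_step=lambda s: calls.append(("eval", s)),
    num_steps=4,
    eval_every=2,
    mark_step=lambda: calls.append(("mark",)),
)
print(calls)
```

On the TPU VM, the same skeleton would be called with mark_step=xm.mark_step after importing torch_xla.core.xla_model as xm.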
Step 4.2b: Switch branch and look at the diff.
Switch to the
git checkout -t origin/tpu
See the changes this makes to the code by running a diff:
git diff main
Step 4.3: Run the model with
python train.py xla
Interlude: Want more speed?
The training run should take around 14 minutes for 50,000 training steps and 500 evaluation steps. That’s the fastest of our available benchmarks:
- The TPU VM CPUs estimate it would take ~21 minutes.
- My M2 macbook air GPU with MPS acceleration estimates it would take ~24 minutes (before thermal throttling).
- My M2 macbook air CPU estimates it would take ~38 minutes (before thermal throttling).
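As a quick worked comparison, the estimates above can be turned into relative speed-ups. The numbers are the ones quoted in this section; nothing here is newly measured.

```python
# Estimated minutes for the full run, as quoted above.
estimates_min = {
    "TPU": 14,
    "TPU VM CPUs": 21,
    "M2 MacBook Air GPU (MPS)": 24,
    "M2 MacBook Air CPU": 38,
}

tpu_minutes = estimates_min["TPU"]
speedups = {name: minutes / tpu_minutes for name, minutes in estimates_min.items()}

for name, ratio in speedups.items():
    print(f"{name}: {estimates_min[name]} min, so the TPU run is {ratio:.2f}x as fast")
```

The ratios come out at roughly 1.5x over the VM's CPUs, 1.7x over the laptop GPU, and 2.7x over the laptop CPU.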
So the TPU is pretty fast… But it’s not, like, lightning fast. Why? Could it be made even faster?
Bigger models: Training a bigger model will be slower in absolute terms, but relatively speaking, TPU acceleration might give more of a boost to larger models than small ones like this, since it’s possible we are not fully utilising the TPU’s capacity (I don’t know).
To test this, try increasing the batch size and/or the embed size and rerunning the CPU/GPU/TPU comparison.
Carefully controlling XLA: It can be a bit hard to see what’s happening inside the XLA device, and small changes to code can lead accidentally to excessive recompilation (e.g. because of variation in computational graphs) or moving data between CPU and TPU repeatedly (e.g. for operations not supported on the TPU), harming performance gains.
Example: In our code, it appears that the prompt completion function is actually running slower on the TPU than the CPUs, including after first compilation. Can anyone guess why this might be the case?
- PyTorch/XLA repository and documentation contain some useful advice, examples, and further links.
- torch_xla.debug contains some methods that will report some diagnostic metrics such as number of cache hits and misses.
- Google TPU Research Cloud documentation has some tutorials on more detailed profiling methods. I haven’t followed this one yet.
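As a starting point for the diagnostics mentioned above, the following sketch prints PyTorch/XLA's metrics report. The wrapper function `xla_metrics_report` is a hypothetical helper; the import guard is only there so that the snippet also runs on a machine without torch_xla installed.

```python
def xla_metrics_report():
    """Return PyTorch/XLA's metrics report as a string, or None if torch_xla
    is not installed (for example on a laptop)."""
    try:
        import torch_xla.debug.metrics as met
    except ImportError:
        return None
    return met.metrics_report()


report = xla_metrics_report()
print(report if report is not None else "torch_xla is not available here")
```

Calling this at the end of a training run on the TPU VM should give a text summary to look through. As far as I know, the compile-time metric and any counters whose names start with aten:: (operations that fell back to the CPU) are the usual things to check first, but treat that as a hint to verify against the PyTorch/XLA documentation.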
Frameworks: I’ve heard JAX integrates better with TPUs. I don’t know JAX, but want to learn it and try this.
- For a guide to TPU VMs that emphasises JAX, see ayaka14932’s guide.
- TPU Research Cloud documentation has many examples for TF and JAX (unlike PyTorch). There are probably also other good resources online, I haven’t searched for nor evaluated any.
Parallelisation: So far, we have only been using a quarter of one 8-core TPU because each TPU v2-8 or v3-8 has 4 ‘devices’ (chips) with two cores each, and we used one device by default. Mind blown!
An easy way to gain more effective speed is to run independent training runs on different devices (four per TPU v-8) and different VMs at the same time. Each training run takes the same time, but you can run many at once. This is beyond the scope of this tutorial but I have another guide in the works on this.
A harder way to gain more speed is to parallelise a particular training run across TPU cores and/or across TPU VMs. This is less straightforward and requires some parallel programming, but may be necessary for really large training runs. I haven’t looked into this yet.
(Note: for really large training runs, TPU VM volatility becomes an issue. Google Cloud Platform sometimes restarts VMs. You’ll want to make sure you don’t lose much progress if a TPU VM drops out unexpectedly).
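One common way to limit lost progress is to save a checkpoint periodically and resume from it on start-up. The sketch below shows the pattern with a toy state stored as JSON; `save_checkpoint`, `load_checkpoint`, and `train` are hypothetical helpers, and the "training step" is a placeholder.

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    """Write the state to a temporary file, then rename it into place, so a
    crash in the middle of writing never corrupts the previous checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path, default):
    """Load the latest checkpoint, or return the default if none exists yet."""
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)


def train(path, num_steps, save_every):
    """Resume from the last checkpoint (if any) and train up to num_steps."""
    state = load_checkpoint(path, {"step": 0, "loss_history": []})
    for step in range(state["step"] + 1, num_steps + 1):
        state["step"] = step
        state["loss_history"].append(1.0 / step)  # placeholder for a real training step
        if step % save_every == 0:
            save_checkpoint(path, state)
    return state


with tempfile.TemporaryDirectory() as directory:
    checkpoint = os.path.join(directory, "checkpoint.json")
    train(checkpoint, num_steps=5, save_every=2)  # pretend the VM restarts after step 5
    resumed_from = load_checkpoint(checkpoint, None)["step"]
    final = train(checkpoint, num_steps=10, save_every=2)

print(resumed_from, final["step"])
```

In a real run the state would hold the model and optimiser state dictionaries, saved with something like torch.save (or xm.save for tensors on an XLA device) instead of JSON. The checkpoint interval is a trade-off between time spent saving and the amount of work repeated after a restart.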
Part 5: (Optional) Apply for your own TPU VMs
Optional goal: You get your own TPU allocation for 30 days.
- Pay thousands of dollars per TPU VM.
- Ingratiate yourself to Google Research and be allocated 10 TPU VMs for free.
More seriously, Google Research’s TPU Research Cloud (TRC) offers researchers reasonably generous periods of free access to the TPU VMs we have been using today, nominally for the purpose of promoting ML research. More information on this scheme is on the TRC website.
Step 5.1: Apply for the TPU allocation.
- Go to the TRC website.
- Click ‘apply now’ and fill in the form.
- Then, wait to hear back (took ~30 minutes for me).
- Follow the instructions in the reply email to proceed to initialise and configure your own TPU VMs.
Step 5.2: ???
I have a draft of a separate guide on launching and configuring a TPU VM, which would bridge the gap between step 5 and repeating steps 1–4 on your own TPU.
The draft is currently tailored for the ICL project but I’m considering working on a more general guide, and either way it’s still generally informative.
Anyway, the next steps are out of scope today (you can ask me for access to the draft when you get up to that).
Interlude: A free lunch?
Is all of this really free? So far, I haven’t paid a cent, yet… Here’s the situation, as I understand it:
You get free TPUs for 30 days, potentially longer: The TPU VMs are free for 30 days after they are granted. I have heard that the TPU Research Cloud are very generous in offering extensions on request.
For example, in my case, when my thirty days were nearly up, they sent me an email asking if I needed any help using the TPUs, and I replied asking for an extension. Within a couple of days, they granted me an extension from October 7th to December 1st.
I have heard that some people have been asked to fill out feedback surveys on TPU usage in order to get further extensions. This kind of thing would also be consistent with the conditions stated on the TRC website.
If the TRC did decide not to grant my extension, that would be the end of my TPU access—the full price for a single TPU VM is thousands of dollars per month.
You get free starting credits for 90 days: Google Cloud Platform is also a paid platform, selling various computing resources other than TPU VMs, including, for example:
- Storage buckets (for storage beyond the 96 GB on each TPU VM).
- Network traffic (e.g. for sending checkpoints to another storage platform).
New accounts (or just students? academic staff? can’t remember) get 300–400 USD of starting credits usable for 90 days. I’m at around 30 days into my trial, using only TPU VMs and network traffic, which has used around 50 USD of my starting credits, so I haven’t had to pay anything yet.
After 90 days, I anticipate if I want to keep using TPUs, I would have to pay a small–medium amount of ongoing monthly costs for network traffic, depending on usage.
As long as TRC are allowed/willing to continue to support researchers by extending their allocations, paying for the network traffic will be a small price to pay for access to such powerful compute.
You can continue to use my TPUs for a bit: I don’t mind if you keep using my TPUs for the next week or so (including for the hackathon). I don’t need them right now. If you do use them, please keep in mind these norms as a courtesy:
- Stick to your 1 TPU VM unless you coordinate with someone else and want to use theirs, or if you coordinate with me and want to use the spares (since only one process can use each TPU device at a time).
- Don’t specifically seek to spend my free credits on e.g. networking.
- Don’t abuse root access (e.g. don’t read my GitHub keys or other people’s GitHub keys). (If anyone is worried, revoke this GitHub key).
- Try not to break the system e.g. by messing with system packages or files.
- Take steps towards getting your own allocation so you are not using mine indefinitely.
I am not able to authenticate additional users to the TPUs at this stage.
Coda: Elementary, my dear…
Is the training run finished yet? What did your model have to say? Let’s go around the room and read the completions. Send them to me and I will add them to this website, so that future foundation models can learn from them.
"Elementary, my dear
"Elementary, my dear Or before? A some down, you have
quite full see an end you. But the name turnised of the mims the
shatre.then whalll shain opmatHen rhamporfurm.1 Chawnapopseatheroat&
a nemes aney
"Elementary, my dear through me is made which into
devening how about to up him vasigning pointle entression, not I was
o "Younzs "Tolelly olist lidg man russe slill nore.
"Elementary, my dearlial for a
rave of him the smakel child to no alcosing the har for his you
was a much under wcouldrsung anctucheonspppondeou fapeeen
i undoudekn'ougeon hugssnineon, clondiculsucheneon siggrimerllep as
op as th
"Elementary, my dear acheyckner, just
country selfom expectly his bon us a feweering at metrishe.
Gad he madvdue they His man
Thkkkend Ithkkkkkenkedkeathmankeazy minke chemunddeablman matk
"Elementary, my dear impluacity. But
the other with invellor occurian Godor Mostanly. He he some not an little
to Botsss plomsssssesh. a shas shas sI sI howassss enshey
e Werqff y896 "II's Riencqoassssseym shewerve shervs?" she
Bonus question: How many times does the famous phrase “Elementary, my dear Watson” occur in the entire Holmesian canon?
Part 6: Homework exercises
Congratulations on completing the workshop! But there is much work still to do. I hope that what you have learned today will directly help you accelerate your own experiments.
If you are not ready to run experiments on a TPU yet, consider the following exercises based on the bitesizedGPT repository.
Exercises 6.1: Improve model performance on this data set.
- More training: Run the training loop for longer and see if the performance continues improving. The transformer has probably seen each training byte ~60 times, could it still benefit from reading the canon again?
- Hyperparameter sweep: Experiment with different learning rates, batch sizes, and architecture parameters.
- Architectural changes: Does an attention-only version of the transformer perform as well? Does adding dropout help?
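For the "more training" exercise, the number of times each training byte has been seen can be estimated from the training settings. The sketch below assumes each step processes batch_size sequences of context_length bytes sampled from the corpus; apart from the 50,000 steps, the numbers are hypothetical round values for illustration and are not read from the repository.

```python
def passes_over_corpus(steps, batch_size, context_length, corpus_bytes):
    """Average number of times each training byte is seen as an input."""
    return steps * batch_size * context_length / corpus_bytes


# Hypothetical settings: substitute the values from train.py and the size in
# bytes of the training file to get the real figure.
print(passes_over_corpus(steps=50_000, batch_size=32, context_length=100,
                         corpus_bytes=4_000_000))
```

Doubling the number of steps, the batch size, or the context length doubles the estimate, which gives a rough sense of how much more of the canon a longer run would cover.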
Exercises 6.2: Improve TPU performance for prompt completion.
- Moving to CPU: What’s faster?
- the current method of computing prompt completions, with the model on the TPU, or
- moving the model to the CPU, computing the prompt completion, and then moving the model back to the TPU afterwards?
- Breaking loops: Does adding additional break points within the prompt completion inner loop help?
- Flagging loops: Does PyTorch/XLA have a method for telling it about a loop, so that it can be smarter about compiling it? I think I read about this somewhere but forgot where, go hunting.
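For the timing comparisons in these exercises, a small harness along the following lines may help. It is a sketch with hypothetical stand-in functions: `complete_on_tpu` and `complete_via_cpu` just do some arithmetic here and would need to be replaced with the two real prompt completion strategies.

```python
import time


def best_time(fn, repeats=3):
    """Best wall-clock time in seconds over a few calls of fn()."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)


# Hypothetical stand-ins for the two strategies being compared.
def complete_on_tpu():
    return sum(i * i for i in range(20_000))


def complete_via_cpu():
    return sum(i * i for i in range(10_000))


timings = {
    "on TPU": best_time(complete_on_tpu),
    "via CPU": best_time(complete_via_cpu),
}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1000:.2f} ms")
```

Two cautions when timing XLA code: because execution is lazy, each timed function should end by requesting a concrete value (such as the completed string) so that the work has actually run before the timer stops, and the first call includes compilation, so it is worth repeating the measurement.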
Exercises 6.3: Explore other frameworks for TPU computation.
- JAX: Re-implement this transformer in JAX and get it running on the TPU.
In all cases, I would be very interested to see and discuss your results!