

Attention: All you need to know1

Saturday, March 29th, 2025

In this technical note, I explore attention, a key element of the transformer architecture driving recent progress in deep learning applications. I review four standard perspectives on how to understand the attention mechanism. The four perspectives are as follows.

  1. Attention as a soft, parallel search for relevant information from a sequence of vectors. This is the perspective adopted during the original development of attention for machine translation (Bahdanau et al., 2015; Vaswani et al., 2017).

  2. Attention as moving information between “residual streams.” This perspective was developed by Elhage et al. (2021) in “A mathematical framework for transformer circuits,” (henceforth the “mathematical framework”), and forms the foundation for much work on mechanistic interpretability of language models.

  3. Attention as two coupled (but individually simple) low-rank transforms. This perspective is another contribution of the “mathematical framework,” and is mainly of mathematical interest.

  4. Attention as a message exchange between nodes in a network. This perspective comes from Andrej Karpathy, who introduces it in a lecture as part of his perspective on transformers as differentiable computers (Karpathy, 2023).

These perspectives are quite closely related. Of course, they all describe the same mechanism. But, there are subtle (and not-so-subtle) differences in emphasis and in the meanings they attach to different variables in the attention mechanism. My hope is, by laying out all these perspectives, I’ll be in a position to use the most appropriate perspective in any given situation, and I’ll achieve a somehow deeper understanding of attention in the end.

Acknowledgements: Thanks to Edmund Lau for discussion and references.

Contents

This note has the following sections.

  1. Background: Sequences and transformers
  2. Preliminaries: Defining attention
  3. Perspective 1: Attention as a soft search (and remarks)
  4. Perspective 2: Attention as information routing (and remarks)
  5. Perspective 3: Attention as two coupled transforms (and remarks)
  6. Perspective 4: Attention as a message exchange (and remarks)
  7. Conclusion
  8. Bibliography

Pandoc counts 6214 words plus 189 mathematics environments.

Background: Sequences and transformers

First, a brief introduction to sequence modelling and the transformer architecture.

Sequence modelling is the task of learning to predict the continuations of sequences of vectors. There are a couple of variants of the ‘type signature’ of this problem. For this note, I want to think about a sequence model as a function with the following signature: it maps a sequence of input vectors $u_1, \ldots, u_T \in \mathbb{R}^d$ to a sequence of output vectors $v_1, \ldots, v_T \in \mathbb{R}^{d'}$, that is, a function $(\mathbb{R}^d)^T \to (\mathbb{R}^{d'})^T$.

A transformer is a deep neural network architecture designed for sequence modelling. The architecture itself consists in the specific chain of linear and nonlinear transformations used to map an input sequence to an output sequence. The prototypical transformer architecture involves the following components.

  1. An embedding block, mapping the sequence of input vectors $u_1, \ldots, u_T \in \mathbb{R}^d$ into an embedded sequence in some internal representation space, $x_1, \ldots, x_T \in \mathbb{R}^e$.

  2. A number of blocks, each of which computes an additive update to each internal representation vector in the sequence. There are two kinds of blocks.

    1. Attention blocks operate on pairs of internal representations.

    2. Multi-layer perceptron (MLP) blocks operate on each internal representation independently.

    Usually, we alternate attention and MLP blocks. The term transformer block refers to a pair (attention block, MLP block), and the number of “layers” in a transformer is considered to be the number of transformer blocks.

  3. Finally, an unembedding block, mapping from the final sequence of updated internal representation vectors $x_1', \ldots, x_T' \in \mathbb{R}^e$ to the overall transformer’s output sequence $v_1, \ldots, v_T \in \mathbb{R}^{d'}$. (A minimal code sketch of this overall forward pass follows the list.)
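To make the shape of the computation concrete, here is a minimal sketch of the forward pass just described, in Python/NumPy. The names `embed`, `blocks`, and `unembed` are hypothetical stand-ins for the components above; this shows the structure, not any particular implementation.

```python
import numpy as np

def transformer_forward(U, embed, blocks, unembed):
    """Prototypical transformer forward pass, as described above.

    U stacks the input vectors u_1, ..., u_T as columns (shape d x T).
    `embed`, `unembed`, and each element of `blocks` are placeholder
    functions standing in for the embedding block, the attention/MLP
    blocks, and the unembedding block.
    """
    X = embed(U)              # (d x T) -> (e x T): embedded representations
    for block in blocks:      # alternating attention and MLP blocks
        X = X + block(X)      # each block computes an *additive* update
    return unembed(X)         # (e x T) -> (d' x T): output vectors

# Toy usage with random linear maps standing in for each component.
rng = np.random.default_rng(0)
d, e, d_out, T = 4, 8, 4, 5
E = rng.normal(size=(e, d))
W_unembed = rng.normal(size=(d_out, e))
blocks = [(lambda X, W=0.1 * rng.normal(size=(e, e)): W @ X) for _ in range(4)]
V = transformer_forward(rng.normal(size=(d, T)),
                        lambda U: E @ U, blocks, lambda X: W_unembed @ X)
print(V.shape)  # (4, 5)
```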

This note is mostly about item 2.1—the attention blocks. Therefore, the main thing to understand is how attention fits into the overall architecture. In particular, attention is the step in which information is transmitted between different parts of the sequence. In fact, it’s the only such step, as all other steps operate on each sequence element independently. Indeed, this is the essential idea of the transformer architecture. The contribution of “Attention is all you need” (Vaswani et al., 2017; the paper that invented the transformer architecture) was to show that attention could play this role, where previous architectures used recurrent or convolutional networks to transfer information between sequence positions.

The above is a very high-level description of the transformer architecture. There are lots of details to add, including positional encodings, multi-head attention (running multiple attention mechanisms in parallel in one attention block), layer normalisation, and dropout. For this note, we don’t need to worry about most of these details.

For a more detailed introduction to the rest of the transformer architecture, I recommend 3Blue1Brown’s LLM playlist or Karpathy’s lecture. To really understand the computations involved, there’s nothing better than following Karpathy’s tutorial. Alternatively, for those who prefer equations over code, spend some time with the relevant sections of the “mathematical framework”, or with “Formal algorithms for transformers” from Phuong and Hutter (2022).

Preliminaries: Defining attention

So, then, what is this attention mechanism? Consider a fixed context length $T \in \mathbb{N}$. The input to the mechanism is a sequence of embedded vectors $x_1, \ldots, x_T \in \mathbb{R}^e$, and the output is another such sequence, $y_1, \ldots, y_T \in \mathbb{R}^e$, where $e \in \mathbb{N}$ is the embedding dimension (these outputs are the ‘additive updates’ to the input vectors I mentioned earlier). We can stack the input and output vectors together as the columns2 of two matrices $X, Y \in \mathbb{R}^{e \times T}$, that is,
$$ X = \begin{bmatrix} | & | & & | \\ x_1 & x_2 & \cdots & x_T \\ | & | & & | \end{bmatrix} \qquad \text{and} \qquad Y = \begin{bmatrix} | & | & & | \\ y_1 & y_2 & \cdots & y_T \\ | & | & & | \end{bmatrix}. $$

The attention mechanism is parameterised by four weight matrices:

  1. a query transform $W_Q \in \mathbb{R}^{h \times e}$;
  2. a key transform $W_K \in \mathbb{R}^{h \times e}$;
  3. a value transform $W_V \in \mathbb{R}^{h' \times e}$; and
  4. an output transform $W_O \in \mathbb{R}^{h' \times e}$.

Here, $h, h' \in \mathbb{N}$ are hidden dimensions of the query/key and value/output pairs, respectively. Now, let’s write down the matrix equation for “scaled dot-product attention,” which is the form used in “Attention is all you need” (modulo some transpositions):
$$ Y = W_O^\top W_V X \, \sigma\!\left( X^\top W_K^\top W_Q X \right). $$
Here, $\sigma : \mathbb{R}^{T \times T} \to \mathbb{R}^{T \times T}$ is a scaled and masked column-wise softmax function, such that, for $A \in \mathbb{R}^{T \times T}$, $\sigma(A) = \tilde A$, a column stochastic matrix with entries
$$ \tilde A_{j,i} = \frac{ M_{j,i} \exp\!\left(\beta \cdot A_{j,i}\right) }{ \sum_{k=1}^{T} M_{k,i} \exp\!\left(\beta \cdot A_{k,i}\right) } $$
where $\beta$ is an inverse temperature (usually taken to be $1/\sqrt h$) that helps with optimisation, and $M \in \{0,1\}^{T \times T}$ is a mask matrix (for autoregressive language modelling, we would use an upper triangular matrix of ones, which would make $\tilde A$ also upper triangular).
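To make the definition concrete, here is a direct transcription into Python/NumPy. This is a minimal single-head sketch following the note’s conventions (inputs as columns of $X$, mask applied inside the softmax); it omits numerical-stability tricks and everything else a practical implementation would add.

```python
import numpy as np

def attention(X, W_Q, W_K, W_V, W_O, M):
    """Scaled dot-product attention: Y = W_O^T W_V X sigma(X^T W_K^T W_Q X).

    X: (e, T) inputs as columns.           M: (T, T) 0/1 mask matrix.
    W_Q, W_K: (h, e) query/key transforms. W_V, W_O: (h', e) value/output transforms.
    Returns Y: (e, T), the additive update for each sequence position.
    """
    beta = 1.0 / np.sqrt(W_Q.shape[0])          # inverse temperature, 1/sqrt(h)
    A = X.T @ W_K.T @ W_Q @ X                   # score matrix, (T, T)
    E = M * np.exp(beta * A)                    # exponentiate scaled scores, apply mask
    A_tilde = E / E.sum(axis=0, keepdims=True)  # column-wise normalisation
    return W_O.T @ W_V @ X @ A_tilde            # (e, T)

# Example with a causal mask (position j is a valid result for query i iff j <= i).
rng = np.random.default_rng(0)
e, h, h_out, T = 8, 4, 4, 5
X = rng.normal(size=(e, T))
W_Q, W_K = rng.normal(size=(h, e)), rng.normal(size=(h, e))
W_V, W_O = rng.normal(size=(h_out, e)), rng.normal(size=(h_out, e))
M = np.triu(np.ones((T, T)))
Y = attention(X, W_Q, W_K, W_V, W_O, M)  # shape (e, T)
```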

There’s a lot going on in this definition. In order to understand it, we’ll step through the attention equation. We’ll go through it once in detail under the ‘soft search’ perspective. Then, we’ll revisit several aspects from each of the other perspectives in turn.

Perspective 1: Attention as a soft search

The original perspective on attention is that it’s a ‘soft search’ over a set of input vectors. Imagine you’re at a particular point in the input sequence, $i \in \{1, \ldots, T\}$. You think there might be some relevant information represented at at least one of the preceding input positions $j \in \{1, \ldots, i\}$, but you’re not sure where. Attention helps you look up that information.

Recall the definition of attention:
$$ Y = W_O^\top W_V X \, \sigma\!\left( X^\top W_K^\top W_Q X \right). $$
Let’s step through the expression on the right hand side. We’ll start with the inside of the $\sigma$ function.

  1. First, for each vector in the input sequence $x_i$, we formulate a query (column) vector using the query transform, $q_i = W_Q x_i \in \mathbb{R}^h$, representing what kind of information we’re searching for at this part of the sequence. These query vectors form the columns of the query matrix $Q = W_Q X \in \mathbb{R}^{h \times T}$.

  2. Then, for each vector in the input sequence $x_j$, we formulate a key vector using the key transform, $k_j = W_K x_j \in \mathbb{R}^h$, representing what kind of information is available at this part of the sequence. These key vectors form the columns of the key matrix $K = W_K X \in \mathbb{R}^{h \times T}$.

  3. To determine the relevance score of key $k_j$ to query $q_i$, we take the dot product of the key and the query, $a_{j,i} = k_j^\top q_i \in \mathbb{R}$. The relevance score $a_{j,i}$ is exactly the entry at row $j$, column $i$ of the score matrix $A = K^\top Q = X^\top W_K^\top W_Q X \in \mathbb{R}^{T \times T}$. (A small loop-based sketch of these three steps follows this list.)
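If it helps to see steps 1 to 3 without the matrix algebra, here is a loop-based sketch (again NumPy, with illustrative names); it produces the same score matrix as the expression $A = X^\top W_K^\top W_Q X$.

```python
import numpy as np

def score_matrix(X, W_Q, W_K):
    """Relevance scores a_{j,i} = k_j . q_i, computed pair by pair.

    Equivalent to the matrix expression A = X^T W_K^T W_Q X; the loops are
    just to make the role of each query and key explicit.
    """
    e, T = X.shape
    Q = W_Q @ X                           # query vectors q_i as columns, (h, T)
    K = W_K @ X                           # key vectors k_j as columns, (h, T)
    A = np.zeros((T, T))
    for i in range(T):                    # position asking the query
        for j in range(T):                # position offering the key
            A[j, i] = K[:, j] @ Q[:, i]   # relevance of key j to query i
    return A
```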

We’ve completed the first major step of the calculation (the inside of the $\sigma$ function) and found that it represents scoring search results using the query and key transforms. The next step is to perform the softmax calculation itself. This time, rather than stepping through the calculation, let’s step through three versions of the calculation of progressively increasing complexity.

  1. For a given $i$, define $a_i \in \mathbb{R}^T$ to be the $i$th column of the score matrix,
    $$ a_i = \begin{bmatrix} a_{1,i} \\ \vdots \\ a_{T,i} \end{bmatrix}. $$
    This column represents how relevant each key is to the query from position $i$. What we want to do next is somehow ‘rank’ these matches. The best result for a given query would be the row with the highest score. We could represent this result as a one-hot column vector.

    However, we want the computation to be differentiable, and we might actually want to consider multiple results, so we’ll need something ‘softer’ than max…

  2. Accordingly, let’s use softmax to normalise the column of scores into a probability distribution representing a scaled soft result ranking,
    $$ \tilde a_i = \frac{ \exp(\beta a_i) }{ \|\exp(\beta a_i)\|_1 } = \frac{1}{\sum_{j=1}^T \exp(\beta a_{j,i})} \begin{bmatrix} \exp(\beta a_{1,i}) \\ \vdots \\ \exp(\beta a_{T,i}) \end{bmatrix} \in \Delta^{T-1} \subset \mathbb{R}^{T} $$
    where $\beta \in \mathbb{R}^{+}$ is an inverse temperature (included to control how soft the softmax is) and $\Delta^{T-1}$ is the probability simplex in $\mathbb{R}^T$ (indicating that $\tilde a_i$ encodes a probability distribution).

    Actually, we might want to restrict which results can be considered by each query. For example, in autoregressive language modelling, we only want to consider matches from earlier sequence positions. Therefore, we need to include masking…

  3. We get the final version by forcing the probability to zero for responses we want to exclude. We can achieve this by including a mask matrix $M \in \{0,1\}^{T \times T}$ such that $M_{j,i}$ contains a $1$ if position $j$ is a valid result for the query at position $i$. Let $m_i$ be the $i$th column of $M$. Thus we can define the masked and scaled soft result ranking:
    $$ \tilde a_i = \frac{ m_i \odot \exp(\beta a_i) }{ \|m_i \odot \exp(\beta a_i)\|_1 } = \frac{1}{\sum_{j=1}^T M_{j,i} \exp(\beta a_{j,i})} \begin{bmatrix} M_{1,i} \exp(\beta a_{1,i}) \\ \vdots \\ M_{T,i} \exp(\beta a_{T,i}) \end{bmatrix} \in \Delta^{T-1} \subset \mathbb{R}^{T} $$
    where $\odot$ is an elementwise vector product.

    These masked and scaled soft result rankings form the columns of the attention matrix, a column stochastic matrix:
    $$ \tilde A = \sigma(A) = \sigma\!\left(X^\top W_K^\top W_Q X \right) \in \mathbb{R}^{T \times T}. $$
    (A small sketch of this softmax step follows this list.)
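As a sketch of the softmax step, the following computes the masked and scaled soft ranking $\tilde a_i$ for a single column of scores (names are illustrative).

```python
import numpy as np

def soft_ranking(a_i, m_i, beta):
    """Masked, scaled softmax of one column of scores.

    a_i: (T,) scores for the query at position i (a column of A).
    m_i: (T,) 0/1 mask column saying which positions are valid results.
    beta: inverse temperature, e.g. 1 / sqrt(h).
    Returns a probability vector over positions (zero wherever masked out).
    """
    weights = m_i * np.exp(beta * a_i)   # exponentiate, then zero invalid results
    return weights / weights.sum()       # normalise into a distribution
```

Stacking `soft_ranking(A[:, i], M[:, i], beta)` for each `i` as columns reproduces the attention matrix $\tilde A = \sigma(A)$.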

We’re nearly there! The attention matrix represents the (soft) ranking of the search results. All we need to do now is retrieve the results themselves (using the value transform) and interpret them (using the output transform).

  1. First, for each vector in the input sequence $x_t$, we formulate a value vector3 using the value transform, $v_t = W_V x_t \in \mathbb{R}^{h'}$, representing the actual information content available at this part of the sequence. These value vectors form the columns of the value matrix $V = W_V X \in \mathbb{R}^{h' \times T}$.

  2. The next step is to take the product $V \tilde A$. Since $\tilde A$ is a column stochastic matrix, the columns of this product matrix are mixtures of the columns of $V$ according to those probability vectors. In detail, consider query $i$. Recall that we have a probability distribution $\tilde a_i \in \mathbb{R}^T$. We use this probability distribution to mix the value vectors $v_1, \ldots, v_T$, producing a single mixture result
    $$ \tilde v_i = V \tilde a_i = \sum_{t=1}^T (\tilde a_i)_t \, v_t \in \mathbb{R}^{h'}. $$
    The mixture result $\tilde v_i$ represents the combined information returned from the search at position $i$. These mixture results form the columns of the mixed value matrix
    $$ \tilde V = V \tilde A = W_V X \, \sigma\!\left(X^\top W_K^\top W_Q X\right) \in \mathbb{R}^{h' \times T}. $$

  3. Finally, for each mixed result $\tilde v_i$, we formulate an output vector using the output transform, $y_i = W_O^\top \tilde v_i \in \mathbb{R}^e$, representing the interpretation of the search result in the original space. This matches the definition of attention, since these output vectors form the columns of the output matrix
    $$ Y = W_O^\top \tilde V = W_O^\top W_V X \, \sigma\!\left(X^\top W_K^\top W_Q X\right) \in \mathbb{R}^{e \times T}, $$
    as required. (A sketch of these retrieval steps follows this list.)
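Here is a per-position sketch of these retrieval steps, assuming a soft ranking `a_tilde_i` computed as above; stacking the outputs $y_i$ as columns recovers $Y$ from the matrix form.

```python
import numpy as np

def retrieve(X, W_V, W_O, a_tilde_i):
    """Mix the value vectors by one soft ranking, then map back to embedding space.

    X: (e, T) inputs as columns.  W_V, W_O: (h', e).
    a_tilde_i: (T,) masked and scaled soft result ranking for position i.
    Returns y_i: (e,), the additive update for position i.
    """
    V = W_V @ X               # value vectors v_t as columns, (h', T)
    v_mixed = V @ a_tilde_i   # mixture result, a weighted average of the v_t
    return W_O.T @ v_mixed    # output vector y_i in the embedding space
```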

As I said, a lot going on! But with one perspective under our belt, the others will be much easier to follow.

Remarks on perspective 1

Before moving on to the next perspective, I wanted to make some brief remarks on what this perspective offers us. The idea of attention was originally motivated in the context of machine translation. As I understand, the history looks roughly like this:

  1. Bahdanau et al. (2015) introduced attention as an addition to recurrent encoder-decoder models for neural machine translation, letting the decoder softly search the source sentence for the words most relevant to the word it is currently producing.

  2. Vaswani et al. (2017) then showed, in “Attention is all you need,” that the recurrence could be dropped entirely: self-attention within the encoder and decoder (plus cross-attention between them) suffices to move information between sequence positions.

Since then, individual elements of the transformer architecture, including ‘all-pairs’ attention, have been refined, but the decoder from the original paper has essentially the form of the architecture used in modern large language models.

Perspective 2: Attention as information routing

This brings us to the early days of Anthropic, and the launch of the transformer circuits thread with the article “A mathematical framework for transformer circuits” (Elhage et al., 2021). The interpretability team at Anthropic adopted the goal of understanding the internal computations going on inside large-scale transformers. Their first move was to get their hands dirty with the architecture’s basic equations, including attention, and see what they could make of them. Among other insights (see the full article), they came to the following observation:

The fundamental action of attention heads is moving information. They read information from the residual stream of one token, and write it to the residual stream of another token.

To understand this perspective, let’s first break down what they are referring to as the “residual stream of a token.”

  1. Note that the attention/MLP blocks of the transformer operate on sequences of vectors. After the embedding block, these vectors are the initial embedded input vectors. Each successive block modifies these internal representations in some way, producing an altered sequence of vectors, which then becomes the input for the next block. The output of the final attention/MLP block becomes the input of the unembedding block, producing the transformer’s overall output vector sequence.

  2. Each block preserves the size of the sequence of vectors. We can therefore pick a particular sequence position, and follow the chain of vectors at that position through each block, accumulating the various modifications. This gives us a different sequence of vectors, one with length $B+1$, where $B$ is the number of attention/MLP blocks in the transformer. This sequence of vectors is called a residual stream, and there is one for each position in the input/output sequences (from $1$ to $T$).

  3. Why is it called a residual stream? Recall that the attention (and MLP) blocks compute additive updates to their input. The name for this kind of block is a residual block, because the role of each attention/MLP block is to learn to output the difference (or “residual”), $y_i$, between its input $x_i$ and the block’s output $x_i + y_i$. The neural connection by which each $x_i$ bypasses the attention/MLP mechanisms and still gets included in the output is referred to as a “skip connection” or a “residual connection”, and these connections line up directly with the residual streams.

So, that explains the first part of their perspective. How about the part about attention “reading from” and “writing to” the residual stream? This suggests that we should be able to view the contents of the residual stream at a particular point in the forward pass as a store of information that is dynamically changing with each block, and in particular, attention plays the role of reading from and/or writing to these information stores. To see this, let’s revisit the final stage of the first perspective on attention, involving the role of the value and output transforms.

  1. Consider the product of the value transform $W_V \in \mathbb{R}^{h' \times e}$ and the input matrix $X \in \mathbb{R}^{e \times T}$ in the attention mechanism. Previously, we thought of the value vectors $v_t = W_V x_t \in \mathbb{R}^{h'}$ as representing information retrieved in response to a search query. The second perspective subtly recasts the value transform as reading information stored somewhere in the $t$th residual stream.

  2. The next step in the calculation is to right-multiply by the column stochastic attention matrix, $\tilde A \in \mathbb{R}^{T \times T}$, forming the mixed value matrix $\tilde V = V \tilde A \in \mathbb{R}^{h' \times T}$. Recall that the columns of this matrix are of the form
    $$ \tilde v_i = V \tilde a_i = \sum_{t=1}^T (\tilde a_i)_t \, v_t \in \mathbb{R}^{h'}, $$
    where $\tilde a_i \in \mathbb{R}^T$ is the $i$th column of the attention matrix, a probability vector. In the second perspective, these vectors represent the mixture of information, read from the various residual streams, that is to be routed to residual stream $i$.

    Compare this to the search perspective, where the emphasis was mostly on the detailed structure of the attention matrix in terms of soft-ranked matches between queries and keys. The second perspective emphasises the form of the attention matrix less than its function, which is to direct the information that was previously read from each residual stream on its way to its destination residual stream.

  3. The last piece of the picture is for the information routed to residual stream $i$ to be written out to that stream. Thus, we come to the output transform, $W_O \in \mathbb{R}^{h' \times e}$. The product $Y = W_O^\top \tilde V$ transforms each column of the mixed value matrix, $\tilde v_i \in \mathbb{R}^{h'}$, into an output $y_i = W_O^\top \tilde v_i \in \mathbb{R}^e$. Since we’re using residual connections, these outputs are subsequently added to the previous state of that residual stream, which we can interpret as writing to that residual stream.

There you have it: attention as routing information between residual streams, governed by the attention matrix (for deciding which streams to connect) and the value and output transforms (for deciding which information to read and write, respectively).
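To tie this language back to the equations, here is the same computation as the earlier `attention` sketch, rearranged so that the read, route, and write steps (and the final residual addition) are explicit. This is a minimal sketch with the attention matrix taken as given; the decomposition, not the code, is the point.

```python
import numpy as np

def attention_as_routing(X, W_V, W_O, A_tilde):
    """Attention as read -> route -> write, given a precomputed attention matrix.

    X: (e, T) current contents of the T residual streams (as columns).
    A_tilde: (T, T) column-stochastic attention matrix (the routing weights).
    Returns the updated residual streams, X plus the attention block's writes.
    """
    reads = W_V @ X           # read: information taken from each stream, (h', T)
    routed = reads @ A_tilde  # route: mix reads according to the attention matrix
    writes = W_O.T @ routed   # write: translate back to the embedding space, (e, T)
    return X + writes         # residual connection: add the writes to the streams
```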

Remarks on perspective 2

Before we move on to discuss the next perspective, I wanted to make a few remarks about the impact the above perspective has had on attempts to understand transformer cognition.

First, the idea of the attention matrix as directing the communication between residual streams has played a fundamental role in work on mechanistic interpretability. Most mechanistic interpretability studies of transformers I have seen involve visualising attention patterns, which is also well-motivated from this perspective.

However, the emphasis on the attention pattern is not uniquely motivated by this perspective. Even the original attention paper visualised cross-attention patterns to confirm that the learned alignments between source sentence tokens and target sentence tokens matched their intuitive relationships. The emphasis on the attention pattern appears to arise naturally from each of the standard perspectives on attention.

Something that does appear to be unique to this framework is the central emphasis on the residual stream itself as a key part of the broader transformer architecture. This is a literal change in perspective on the architecture—compare the illustrations of transformers in the “mathematical framework” to the standard illustrations such as in “Attention is all you need.” The former illustrations place the residual stream at the center—a direct line from embedding to unembedding—with attention and MLP blocks to the side, as if they are optional detours. In contrast, the standard illustrations of the architecture emphasise the sequence of transformer blocks and leave the residual stream implicit.

The idea to emphasise the residual stream appears to have had a deep impact. Many mechanistic interpretability studies have looked at probing the contents of the residual stream to see what information can be reconstructed from it at different points in the forward pass. This method for studying the contents of transformer cognition is particularly well-motivated from the above perspective.4

Perspective 3: Attention as two coupled transforms

The “mathematical framework” briefly discusses one further perspective on the attention mechanism, namely the perspective that it can be factorised into two simple computations that interact through the softmax function. Recall, one final time, the equation for attention,
$$ Y = W_O^\top W_V X \, \sigma\!\left( X^\top W_K^\top W_Q X \right). $$
Elhage et al. (2021) observed that the query/key transforms only interact through their product, and likewise for the output/value transforms. Moreover, the role of each of these products has a particularly neat interpretation.

  1. $W_K^\top W_Q \in \mathbb{R}^{e \times e}$ is a square matrix (with rank at most $h$, due to the shapes of $W_Q$ and $W_K$). Call this product matrix the attention transform, $W_A = W_K^\top W_Q$. The attention transform determines a quadratic relationship that gives the entries of the score matrix:
    $$ A = X^\top W_A X = \begin{bmatrix} x_1^\top W_A x_1 & x_1^\top W_A x_2 & \cdots & x_1^\top W_A x_T \\ x_2^\top W_A x_1 & x_2^\top W_A x_2 & \cdots & x_2^\top W_A x_T \\ \vdots & \vdots & \ddots & \vdots \\ x_T^\top W_A x_1 & x_T^\top W_A x_2 & \cdots & x_T^\top W_A x_T \end{bmatrix} \in \mathbb{R}^{T \times T}. $$
    This is a particularly simple way of looking at the score matrix.

  2. $W_O^\top W_V \in \mathbb{R}^{e \times e}$ is a square matrix (with rank at most $h'$). Call this product the information transform, $W_I = W_O^\top W_V$. The information transform is a single linear transformation that takes the contents of a residual stream and determines how it should be turned into an ‘update’ for other residual streams, $u_i = W_I x_i \in \mathbb{R}^e$. We can aggregate these updates as the columns of an update matrix $U = W_I X \in \mathbb{R}^{e \times T}$.

  3. Attention is essentially the interaction of these two very simple (linear or quadratic) computations through the scaled and masked softmax: first, $\sigma$ turns the score matrix $A$ into a column-probability matrix $\tilde A = \sigma(A)$; then, the update matrix $U$ is multiplied by this column-probability matrix to determine which mixtures of updates are going to be performed by the attention block, $Y = U \tilde A$. (A short numerical check of this reparameterisation follows this list.)
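As a quick sanity check on the reparameterisation, the following sketch computes attention from the product matrices directly; fed $W_A = W_K^\top W_Q$ and $W_I = W_O^\top W_V$, it agrees with the factored form. The commented `allclose` check assumes the earlier `attention` sketch and its example variables are in scope.

```python
import numpy as np

def attention_products(X, W_A, W_I, M, beta):
    """Attention parameterised directly by the attention and information transforms.

    W_A plays the role of W_K^T W_Q and W_I the role of W_O^T W_V;
    given those products, this matches the factored computation exactly.
    """
    A = X.T @ W_A @ X                           # scores via the attention transform
    E = M * np.exp(beta * A)
    A_tilde = E / E.sum(axis=0, keepdims=True)  # scaled, masked, column-wise softmax
    return W_I @ X @ A_tilde                    # updates via the information transform

# With the earlier single-head `attention` sketch in scope, the two agree:
#   np.allclose(attention(X, W_Q, W_K, W_V, W_O, M),
#               attention_products(X, W_K.T @ W_Q, W_O.T @ W_V, M, 1 / np.sqrt(h)))
```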

Remarks on perspective 3

What does this perspective tell us? One idea you might have, upon noticing that your architecture is really about two ‘product’ transforms rather than four ‘factor’ transforms, is to ask: why learn the factors in the first place, rather than learning the product matrices directly? Concretely, why don’t people parameterise attention by $W_A$ and $W_I$ directly instead of $W_Q$, $W_K$, $W_V$, and $W_O$? I can think of a couple of reasons that might explain why the factor parameterisation persists:

  1. Efficiency. When the head dimensions are small ($h, h' \ll e$), the factored parameterisation has far fewer parameters ($2eh$ rather than $e^2$ per product), and computing $h$-dimensional queries, keys, and values is cheaper than applying full $e \times e$ transforms.

  2. Learning behaviour. Even where the two parameterisations can represent the same functions, gradient descent does not treat them the same way: the factored, low-rank parameterisation comes with its own optimisation dynamics and implicit biases, which may well matter in practice.

So, what is the use of this reparameterisation? Even if we don’t choose to parameterise our transformers this way for reasons of their learning behaviour or efficiency, it’s still a valid mathematical perspective on the computations that we are effectively doing.

This perspective offers us an abstraction: instead of focusing on the distinction between queries and keys (central to the search perspective), or the distinction between reading and writing (central to the residual stream perspective), it views attention abstractly in terms of these two mathematically simple transformations of the input and their interaction through the nonlinear softmax function.

This perspective can be helpful when mathematically analysing attention, because it suggests that rather than looking individually at the weights of the factor transforms, we could be looking at the product transforms instead, and thinking of the mechanism in terms of the effect of the attention transform and the effect of the information transform.

Perspective 4: Attention as a message exchange

This brings us to the fourth and final perspective: attention as a message exchange protocol between communicating computational units. I attribute this perspective to Andrej Karpathy, who introduces it in a Stanford lecture on the transformer architecture and its history (Karpathy, 2023).

This perspective is quite closely related to the search perspective and the routing perspective, but places more emphasis on the sequence positions and how they are connected. In fact, it begins with an explicit network representation of those positions and their connections.

  1. The network has one node for each sequence position. Let’s denote these nodes by their indices $1, \ldots, T \in \mathbb{N}$. The important thing about these nodes is that they have state that evolves throughout the forward pass. The state is the value of the (modified) internal representation. That is, the state of the nodes is essentially the residual stream from perspective 2.

  2. The network has edges between nodes given by the mask matrix. Recall from perspective 1 that the mask matrix $M \in \{0,1\}^{T \times T}$ is such that $M_{j,i}=1$ indicates that position $j$ is a valid result for the query from position $i$. Let’s use this as the condition for creating an edge from node $j$ to node $i$. In other words, $M$ is the adjacency matrix of the graph.

Given such a graph, the attention mechanism is a protocol for the nodes to exchange messages along the edges of the network. The effect of the protocol is to map the initial state of each node $x_1, \ldots, x_T \in \mathbb{R}^e$ to a corresponding state update $y_1, \ldots, y_T \in \mathbb{R}^e$ (that is, the new state of node $i$ would then be $x_i + y_i$). The protocol for passing messages and computing these updates proceeds as follows.

  1. Each node acts as a receiver. Node $i$ uses the query transform $W_Q \in \mathbb{R}^{h \times e}$ and its current state $x_i \in \mathbb{R}^e$ to compute a receiving channel $q_i = W_Q x_i \in \mathbb{R}^h$. As before, we can group the receiving channels into a receiving channel matrix $Q = W_Q X \in \mathbb{R}^{h \times T}$.

  2. Each node also acts as a sender. Node $j$ uses the key transform $W_K \in \mathbb{R}^{h \times e}$ and its current state $x_j \in \mathbb{R}^e$ to compute a sending channel $k_j = W_K x_j \in \mathbb{R}^h$. As before, we can group the sending channels into a sending channel matrix $K = W_K X \in \mathbb{R}^{h \times T}$.

  3. You’re getting the picture: for each edge $(j, i)$, we compute a channel overlap score $a_{j,i} = k_j^\top q_i \in \mathbb{R}$. The channel overlap scores are the entries of a channel overlap matrix $A = K^\top Q \in \mathbb{R}^{T \times T}$. (As written, this matrix also has channel overlap scores for pairs of nodes that don’t have edges between them; the masking will be applied in the next step.)

  4. Now we use the scaled and masked softmax function to transform the channel overlap scores into an attention matrix. The math is the same as before, of course, but the interpretation can be given a little more succinctly here in terms of the graph. The $i$th column of the attention matrix is a probability vector encoding the softmax distribution of the scaled channel overlap scores along all incoming edges,
    $$ \tilde a_i = \frac{1}{\sum_{j=1}^T M_{j,i} \exp(\beta a_{j,i})} \begin{bmatrix} M_{1,i} \exp(\beta a_{1,i}) \\ \vdots \\ M_{T,i} \exp(\beta a_{T,i}) \end{bmatrix} \in \Delta^{T-1} \subset \mathbb{R}^{T}. $$
    I’ll call this distribution $\tilde a_i$ the listening distribution for node $i$, because it will be used to mix together the incoming messages. The attention matrix is the column-stochastic matrix with $\tilde a_i$ as its columns, $\tilde A = [\, \tilde a_1 ~ \cdots ~ \tilde a_T \,] \in \mathbb{R}^{T \times T}$.

  5. In its capacity as a sender, each node $j$ then has to decide what message it wants to send on its sending channel. We use the value transform $W_V \in \mathbb{R}^{h' \times e}$ and the current state $x_j \in \mathbb{R}^e$ to compute an outgoing message $v_j = W_V x_j \in \mathbb{R}^{h'}$. As before, we can group the outgoing messages into an outgoing message matrix $V = W_V X \in \mathbb{R}^{h' \times T}$.

  6. In its capacity as a receiver, each node receives messages from its neighbours along its receiving channel. The messages come mixed together using the listening distribution, so that the aggregate message received by node $i$, called the mixed incoming message, is
    $$ \tilde v_i = V \tilde a_i = \sum_{t=1}^T (\tilde a_i)_t \, v_t \in \mathbb{R}^{h'}. $$

  7. Communication having concluded, the final step of the protocol is for each node to transform its received message into a state update. We use the output transform $W_O \in \mathbb{R}^{h' \times e}$ to transform the mixed incoming message into a state update, $y_i = W_O^\top \tilde v_i \in \mathbb{R}^e$. (A loop-based sketch of the whole protocol follows this list.)
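For completeness, here is a loop-based sketch of the protocol, organised around nodes and their incoming edges rather than matrices (names like `senders` and `listening` are just illustrative); it computes the same updates as the matrix form.

```python
import numpy as np

def message_exchange(X, W_Q, W_K, W_V, W_O, M):
    """One round of the message-exchange protocol over the masked graph.

    X[:, i] is the state of node i; M[j, i] = 1 means there is an edge from
    node j to node i. Assumes every node has at least one incoming edge
    (true for a causal mask, where each node has a self-edge).
    Returns the state updates y_1, ..., y_T as columns of an (e, T) array.
    """
    e, T = X.shape
    beta = 1.0 / np.sqrt(W_Q.shape[0])
    Q = W_Q @ X   # receiving channels, one per node
    K = W_K @ X   # sending channels, one per node
    V = W_V @ X   # outgoing messages, one per node (the same for every listener)
    Y = np.zeros_like(X)
    for i in range(T):                               # each node, as a receiver
        senders = [j for j in range(T) if M[j, i]]   # nodes with an edge into i
        overlaps = np.array([K[:, j] @ Q[:, i] for j in senders])
        listening = np.exp(beta * overlaps)
        listening = listening / listening.sum()      # listening distribution
        mixed = sum(p * V[:, j] for p, j in zip(listening, senders))
        Y[:, i] = W_O.T @ mixed                      # node i's state update
    return Y
```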

That concludes the protocol, and with it, the description of the fourth and final perspective on the attention computation. Before wrapping up, let me make a few remarks about what this perspective offers us.

Remarks on perspective 4

For me, studying this message exchange perspective made a few subtle points about attention more salient than they were before.5

  1. Each node computes only one sending channel (one key), announced identically to every potential receiver; a node cannot advertise different information to different queries.

  2. Likewise, each node composes only one outgoing message (one value), and that same message is sent to every node that listens to it. The only receiver-specific part of the exchange is how much each receiver chooses to listen.

As a consequence, we can think of an attention mechanism as having a somewhat limited bandwidth for mutual information exchange. In practice, this limitation is ameliorated by having multiple attention heads running in parallel, which allows different heads to specialise in specific types of information exchange.
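Here is a minimal sketch of that amelioration, assuming the single-head `attention` function sketched earlier: each head runs independently and their outputs simply add into the residual stream (which is equivalent to the usual concatenate-then-project formulation, since the output projection acts blockwise on the concatenated heads).

```python
def multi_head_attention(X, heads, M):
    """Sum of independent attention heads writing into the same residual streams.

    `heads` is a list of (W_Q, W_K, W_V, W_O) tuples, one per head, and
    `attention` is the single-head sketch from earlier in the note. Summing
    per-head outputs is equivalent to concatenating and projecting.
    """
    return sum(attention(X, W_Q, W_K, W_V, W_O, M)
               for (W_Q, W_K, W_V, W_O) in heads)
```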

In a similar vein, Karpathy emphasises that this message exchange perspective on attention provides a neat interpretation of the mechanism’s role within the broader transformer. He casts the body of the forward pass as an alternation between:

  1. communication phases, wherein (multi-headed) attention blocks exchange information along the edges of the network; and

  2. computation phases, wherein MLP blocks do complex nonlinear operations on the information stored at each node.

Karpathy extends this perspective into a view in which a transformer is an efficient, general-purpose, differentiable computational processing unit.

While these properties are individually satisfied for various other systems, transformers combine the best of all of them. And the key to all of these properties is the attention mechanism, which plays the role of allowing the computation to process sequences.

Conclusion

As I said, these perspectives are quite closely related to each other. If there are important differences, they are very subtle. I could attempt to draw the following distinctions between perspectives 1, 2, and 4 (though, to be honest, this breakdown still seems a little forced):

  1. The soft search perspective describes attention from the perspective of a given sequence position looking out over the rest of the sequence. The ‘computational unit’ at that position is ‘seeking’ information that may be present elsewhere in the sequence. Attention describes the process by which that information is identified and retrieved.

  2. The information routing perspective describes attention from a perspective ‘outside’ of the individual sequence positions. The attention mechanism itself plays an active role coordinating the movement of information between residual streams based on their current contents.

  3. The message exchange perspective is somewhere between these perspectives. It returns to the perspective of individual units driving the interaction, but places greater emphasis on the units actively sending the messages (not just those receiving messages), achieving a similar symmetry between senders and receivers as between readers and writers in the information routing perspective.

Of course, it’s well understood in computer science that reading/writing from storage or sending/receiving messages are equivalent pursuits in some sense. In this case we see a form of information retrieval (soft search) that is also equivalent to information storage and information communication. That these three perspectives are related is not a unique property of attention.

The exception to this similarity is the third perspective. That perspective seems genuinely different from the other three. To elide the distinction between queries and keys and focus on the product as a (low-rank) quadratic form, and likewise for the information transform, is an example of an abstraction. This perspective seems to offer something valuable, at least if we just want to understand the input/output mapping of a given attention mechanism. On the other hand, if we want to understand the statistical properties or training dynamics, the fact that we actually use a factored parameterisation rather than the product parameterisation becomes important, and this abstraction becomes inappropriate.

Nevertheless, I am satisfied to have teased apart these four perspectives. I noticed a few points about the architecture I hadn’t noticed before. I don’t know if I have a deeper understanding about attention, but I do feel prepared to adopt the appropriate perspective in a given situation.

I am only left to wonder whether there are other fruitful perspectives on attention that are not as prominent. Dear reader, do you know of any? Please share them. For my part, I have been toying with an apparently quite different perspective of attention as a contextual multi-layer perceptron, that is, an MLP with weights determined by the state of the input sequence. This perspective is implicit in an earlier note I wrote about a paper on universal approximation in transformers. It doesn’t seem like a perfect fit, unless you have a very specific attention mask structure. I need to think about this a bit more. If I figure it out, I’ll make a follow-up post.

Bibliography

Academic papers on the attention mechanism:

  1. Bahdanau, D., Cho, K., and Bengio, Y. (2015). “Neural machine translation by jointly learning to align and translate.” ICLR 2015.

  2. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. (2017). “Attention is all you need.” NeurIPS 2017.

  3. Elhage, N., Nanda, N., Olsson, C., et al. (2021). “A mathematical framework for transformer circuits.” Transformer Circuits Thread, Anthropic.

  4. Phuong, M., and Hutter, M. (2022). “Formal algorithms for transformers.” arXiv:2207.09238.

Technical introductions to the transformer architecture, including attention. Listed in roughly increasing order of time investment required.

More references I haven’t read myself:


  1. No guarantees.↩︎

  2. Throughout, I use column vectors and generally try to keep track of what the columns of matrices represent. A linear transform $F : \mathbb{R}^a \to \mathbb{R}^b$ is represented as a matrix $F \in \mathbb{R}^{b \times a}$, applied by pre-multiplying against a column vector, $F \cdot x = F(x)$, or against a matrix of column vectors, $F \cdot [\, x_1 ~ \cdots ~ x_T \,] = [\, F(x_1) ~ \cdots ~ F(x_T) \,]$. The most noticeable consequences are for the attention matrix. First, we compute a product of keys and queries, rather than queries and keys, and second, after the softmax, the columns are probability distributions (rather than the rows).↩︎

  3. To extend the search analogy, I might have called $v_t$ a ‘result vector’ or something.↩︎

  4. People have been probing neural network activations since well before transformers. However, probes have a special relevance to transformers due to the use of residual connections and the idea that attention and MLP blocks operate on the contents of the stream through simple addition operations. I am not so familiar with the history of interpretability before transformers—it’s possible that this perspective emerged earlier, in the context of residual networks, where (at least parts of) architectures have similar residual streams. However, to my knowledge, at least in the context of transformers, the residual stream perspective is original to the “mathematical framework.”↩︎

  5. These points are at least implicit in the other perspectives. Each potential ‘result’ position selects a single key in response to all queries under the soft search perspective, and the same information is read from each residual stream in the information routing perspective, but I didn’t notice these implications until they became obvious when writing out the above protocol. The mathematical perspective might have made these things more explicit—I guess these are properties of the quadratic and linear structure of the separate attention and information pathways—however, I’m not a ‘native speaker’ in that part of mathematics and I missed the implication initially. Of course, the broader point is that it doesn’t matter that I missed these implications initially, that’s what the fourth perspective is for!↩︎