far.in.net


~Transformers: Perceptrons in disguise

Attention is a key element of the transformer architecture. It can be understood from several related perspectives: parallel soft search, reading from / writing to residual streams, loosely coupled QK/OV circuits, and message passing.

In this technical note, I explore another perspective, quite different from those standard ones: Attention as a context-dependent single-hidden-layer feed-forward neural network, as in, a multi-layer perceptron for which the weights depend on the context.

Acknowledgement: This perspective is not original to me. I exapted it from a paper I reviewed. After writing this note, I had the delightful experience of discovering that Schmidhuber had already published a similar idea.

§The attention mechanism

Attention blocks are the defining component of a transformer sequence model. The first attention block takes the position-augmented embeddings of the prompt’s tokens as inputs, moves information between them in a specific way, and produces a sequence of residual vectors to be added back to the respective residual streams. Subsequent attention blocks repeat the same kind of operation with the current state of each residual stream as input. Eventually, after several attention blocks interleaved with MLP blocks that operate on each residual stream in isolation, the transformer produces its predictions.

For the purposes of this note, let’s focus on the part of an attention block that produces a residual vector for the single stream at some position $T \in \mathbb{N}$. The input to this computation is the current state of this stream and all previous streams. (A similar computation is used at each position; discussing a single step will dramatically simplify our notation and remove the need to discuss masking.)

Formally, let $D \in \mathbb{N}$ be an embedding dimension and $H \in \mathbb{N}$ a hidden/head dimension (usually a divisor of $D$). Let $W_Q, W_K, W_V, W_O \in \mathbb{R}^{H \times D}$ be query, key, value, and output matrices parametrising this attention block. The attention mechanism for producing the residual vector at position $T$ is then a function $f_{W_Q,W_K,W_V,W_O} : (\mathbb{R}^D)^T \to \mathbb{R}^D$ where, for $x_1, \ldots, x_T \in \mathbb{R}^D$,
$$
f_{W_Q,W_K,W_V,W_O}(x_1, \ldots, x_T)
= \sum_{i=1}^{T} W_O^\top W_V x_i \,
\frac{\exp\left( \frac{1}{\sqrt{H}} x_i^\top W_K^\top W_Q x_T \right)}
     {\sum_{j=1}^{T} \exp\left( \frac{1}{\sqrt{H}} x_j^\top W_K^\top W_Q x_T \right)}.
$$
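As a concrete reference point, here is a minimal NumPy sketch of this single-position attention computation (the function and variable names are my own; this is an illustration of the formula above, not an optimised implementation):

```python
import numpy as np

def attention_residual(xs, W_Q, W_K, W_V, W_O):
    """Residual vector for the final position T, given streams x_1..x_T.

    xs: array of shape (T, D), one row per residual stream.
    W_Q, W_K, W_V, W_O: arrays of shape (H, D).
    """
    H = W_Q.shape[0]
    x_T = xs[-1]                                        # current position's stream
    scores = (xs @ W_K.T) @ (W_Q @ x_T) / np.sqrt(H)    # (T,) scaled query-key dot products
    weights = np.exp(scores - scores.max())             # numerically stable exponentials
    weights /= weights.sum()                            # softmax over positions
    return W_O.T @ (W_V @ (weights @ xs))               # weighted value sum, then output map
```

Subtracting `scores.max()` before exponentiating leaves the softmax unchanged but avoids overflow; otherwise the code follows the expression term by term.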

§The standard interpretation of attention

The usual way of interpreting the above attention expression is as a parallel soft search: the current position issues a query $W_Q x_T$, each position $i$ advertises a key $W_K x_i$, the scaled query–key dot products are softmaxed into attention weights, and the output is the attention-weighted average of the values $W_V x_i$, mapped back into the residual stream by $W_O^\top$.

In my previous note, I discuss this perspective in more detail, along with a few related interpretations. In the next sections, I’ll discuss a different interpretation of the attention expression that I don’t often see discussed.

§Attention as a contextual multi-layer perceptron

A multi-layer perceptron is the simplest deep neural network. A typical formulation of a multi-input, multi-output perceptron is to let $D \in \mathbb{N}$ be a number of input/output neurons and let $T \in \mathbb{N}$ be a number of hidden neurons. Then, let $a_1, \ldots, a_T, b_1, \ldots, b_T \in \mathbb{R}^D$ be sets of outgoing weight vectors ($a_i$) and incoming weight vectors ($b_i$), and let $\sigma : \mathbb{R} \to \mathbb{R}$ be some activation function. Then, a (biasless, single-hidden-layer) multi-layer perceptron is a function $g_{a_1,\ldots,a_T,b_1,\ldots,b_T,\sigma} : \mathbb{R}^D \to \mathbb{R}^D$ where, for $x \in \mathbb{R}^D$,
$$
g_{a_1,\ldots,a_T,b_1,\ldots,b_T,\sigma}(x) = \sum_{i=1}^{T} a_i \, \sigma( b_i^\top x ).
$$
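In code, this perceptron $g$ is a one-liner (again a hypothetical sketch with names of my own choosing, with $\sigma$ passed in as a function):

```python
import numpy as np

def mlp(x, a, b, sigma):
    """Biasless single-hidden-layer MLP: sum_i a_i * sigma(b_i . x).

    x: input of shape (D,).
    a: outgoing weight vectors, shape (T, D) (row i is a_i).
    b: incoming weight vectors, shape (T, D) (row i is b_i).
    sigma: element-wise activation function.
    """
    return a.T @ sigma(b @ x)   # (D, T) @ (T,) -> (D,)
```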

The point I want to make is that we can realise the attention expression in this format as follows. Assume we are given an input embedding sequence $x_1, \ldots, x_T \in \mathbb{R}^D$ and an attention block with parameters $W_Q, W_K, W_V, W_O \in \mathbb{R}^{H \times D}$. Then:

- take the number of hidden neurons to be $T$, the sequence length, and the activation to be $\sigma = \exp$;
- for each unit $i$, take the incoming weight vector to be $b_i = \frac{1}{\sqrt{H}} W_Q^\top W_K x_i$, so that $b_i^\top x_T$ is the scaled query–key score for position $i$;
- for each unit $i$, take the outgoing weight vector to be $a_i = \frac{1}{Z} W_O^\top W_V x_i$, where $Z = \sum_{j=1}^{T} \exp\left( \frac{1}{\sqrt{H}} x_j^\top W_K^\top W_Q x_T \right)$ folds the softmax normalisation into the outgoing weights.

After making this identification, we have
$$
f_{W_Q,W_K,W_V,W_O}(x_1, \ldots, x_T) = g_{a_1,\ldots,a_T,b_1,\ldots,b_T,\sigma}(x_T).
$$

Note that this identity requires that we apply $g$ to the particular input, $x_T$, used in its construction. The same MLP won’t coincide with attention for other inputs, because $x_T$ was used in determining the normalisation as well as the weights for the final neuron ($a_T$ and $b_T$). We could partially recover from this issue by considering a global softmax activation on the hidden layer rather than an element-wise exponential activation with hard-coded normalisation, but the dependency of the weights of unit $T$ on $x_T$ would remain.
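This identity is easy to check numerically. The sketch below (my own construction, reading the required weights $a_i$ and $b_i$ off the two expressions above) builds the contextual weights from a random context and confirms that the MLP applied to $x_T$ reproduces the attention output:

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, H = 5, 8, 4
xs = rng.normal(size=(T, D))
W_Q, W_K, W_V, W_O = (rng.normal(size=(H, D)) for _ in range(4))
x_T = xs[-1]

# attention output, computed directly from the definition of f
scores = (xs @ W_K.T) @ (W_Q @ x_T) / np.sqrt(H)
weights = np.exp(scores) / np.exp(scores).sum()
f_out = W_O.T @ W_V @ (weights @ xs)

# contextual MLP weights: b_i from the QK circuit, a_i from the OV circuit
b = xs @ W_K.T @ W_Q / np.sqrt(H)   # row i is b_i = (1/sqrt(H)) W_Q^T W_K x_i
Z = np.exp(scores).sum()            # softmax normaliser, folded into a_i
a = (xs @ W_V.T @ W_O) / Z          # row i is a_i = (1/Z) W_O^T W_V x_i

# g(x_T) = sum_i a_i exp(b_i . x_T)
g_out = a.T @ np.exp(b @ x_T)

assert np.allclose(f_out, g_out)
```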

Anyway, at least when the inputs line up correctly, this identity gives us our new interpretation of the attention mechanism: the context $x_1, \ldots, x_T$ determines the weights of a single-hidden-layer perceptron, and the attention output is that perceptron applied to the current embedding $x_T$.

That is what I mean by “attention as a context-dependent single-hidden-layer feed-forward neural network, as in, a multi-layer perceptron for which the weights depend on the context.”

§Implications for understanding transformers

Generally, the more perspectives you have on an object, the better able you will be to understand it. This is because different perspectives will be useful for different questions you might want to ask. There are a lot of unanswered questions about how large-scale transformers work. My hope in sharing this perspective is that it might show the way to answers we have previously overlooked.

For example, much of mechanistic interpretability attempts to understand the computation of transformer blocks in terms of attention patterns and other observables that are downstream of the search/communication perspectives. The contextual multi-layer perceptron perspective suggests a new view of attention as a recipe for constructing MLPs from contexts. Can we study the query–key and value–output transforms in terms of the incoming/outgoing weight vectors they induce for different prompts? Can we generally study attention blocks as contextually-determined functions of the current token’s embedding?

As another example, in generative language modelling, the outputs of the attention computation loop back to become inputs into the next generation, adding progressively more hidden units to an ever-widening MLP (up to normalisation). The outputs of the transformer determine its inputs. In the case of chain of thought prompting, the transformer’s task is literally to write the inputs that determine the weights of its own contextual MLPs, programming the neurons it uses to determine its own final answer. If we train the transformer with reinforcement learning, we are effectively teaching the transformer to do this in a way that corresponds to getting high-scoring final answers. It’s unclear to me how to think of this from the search/communication perspectives at all.
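A toy sketch of this widening (my own construction): in the contextual-MLP reading, hidden unit $i$ has incoming weight vector $b_i = \frac{1}{\sqrt{H}} W_Q^\top W_K x_i$, which depends only on $x_i$ — so appending a generated token adds one hidden unit and leaves the incoming weights of the existing units untouched (only the outgoing weights are rescaled through the shared normaliser):

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 8, 4
W_Q, W_K = rng.normal(size=(H, D)), rng.normal(size=(H, D))

def incoming_weights(xs):
    """Row i is b_i = (1/sqrt(H)) W_Q^T W_K x_i for hidden unit i."""
    return xs @ W_K.T @ W_Q / np.sqrt(H)

xs5 = rng.normal(size=(5, D))                # a "prompt" of five streams
b5 = incoming_weights(xs5)

xs6 = np.vstack([xs5, rng.normal(size=D)])   # generation appends one stream
b6 = incoming_weights(xs6)

assert b6.shape[0] == b5.shape[0] + 1        # one more hidden unit
assert np.allclose(b6[:5], b5)               # existing incoming weights unchanged
```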

I’m yet to try applying this new perspective to a serious problem involving understanding transformers, but now that I have it in my toolbox, I’m sure an opportunity to use it will come up sooner or later.