Transformers: Perceptrons in disguise
Attention is a key element of the transformer architecture. It can be understood from several related perspectives: parallel soft search, reading from / writing to residual streams, loosely coupled QK/OV circuits, and message passing.
In this technical note, I explore another perspective, quite different from those standard ones: Attention as a context-dependent single-hidden-layer feed-forward neural network, as in, a multi-layer perceptron for which the weights depend on the context.
Acknowledgement: This perspective is not original to me. I exapted it from a paper I reviewed. After writing this note, I had the delightful experience of discovering that Schmidhuber had already published a similar idea.
§The attention mechanism
Attention blocks are the defining component of a transformer sequence model. The first attention block takes the position-augmented embeddings of the prompt’s tokens as inputs, moves information between them in a specific way, and produces a sequence of residual vectors to be added back to the respective residual streams. Subsequent attention blocks repeat the same kind of operation with the current state of each residual stream as input. Eventually, after several attention blocks interleaved with MLP blocks that operate on each residual stream in isolation, the transformer produces its predictions.
For the purposes of this note, let’s focus on the part of an attention block that produces a residual vector for the single stream at some position $t$. The input to this computation is the current state of this stream and all previous streams, $x_1, \dots, x_t$. (A similar computation is used at each position; discussing a single step will dramatically simplify our notation and remove the need to discuss masking.)
Formally, let $d$ be an embedding dimension and $h$ a hidden/head dimension (usually, a divisor of $d$). Let $Q, K, V \in \mathbb{R}^{h \times d}$ and $O \in \mathbb{R}^{d \times h}$ be query, key, value, and output matrices parametrising this attention block. The attention mechanism for producing the residual vector at position $t$ is then a function $A : (\mathbb{R}^d)^t \to \mathbb{R}^d$ where, for $x_1, \dots, x_t \in \mathbb{R}^d$,
$$A(x_1, \dots, x_t) = O \sum_{s=1}^{t} \frac{\exp(\langle Q x_t, K x_s \rangle / \sqrt{h})}{\sum_{r=1}^{t} \exp(\langle Q x_t, K x_r \rangle / \sqrt{h})} \, V x_s.$$
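As a concrete sketch, the expression above can be implemented directly in NumPy (the dimensions and random parameters here are illustrative stand-ins, not values from any real model):

```python
import numpy as np

rng = np.random.default_rng(0)
d, h, t = 16, 4, 5                 # embedding dim d, head dim h, positions t

Q = rng.normal(size=(h, d))        # query matrix
K = rng.normal(size=(h, d))        # key matrix
V = rng.normal(size=(h, d))        # value matrix
O = rng.normal(size=(d, h))        # output matrix

def attention_at_t(xs, Q, K, V, O):
    """Residual vector for the final position, given embeddings xs of shape (t, d)."""
    h = Q.shape[0]
    scores = (Q @ xs[-1]) @ (K @ xs.T) / np.sqrt(h)  # <Q x_t, K x_s> / sqrt(h)
    a = np.exp(scores - scores.max())
    a /= a.sum()                                     # softmax attention pattern
    return O @ ((V @ xs.T) @ a)                      # mix values, project back

xs = rng.normal(size=(t, d))       # stand-in embeddings x_1, ..., x_t
r = attention_at_t(xs, Q, K, V, O)
print(r.shape)                     # one residual vector of dimension d
```

The `scores.max()` subtraction is a standard numerical-stability trick and cancels in the softmax.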
§The standard interpretation of attention
The usual way of interpreting the above attention expression would be something like the following:
Interpret $q = Q x_t$ as a vector query and each $k_s = K x_s$ as a vector search key.
Their inner products $\langle q, k_s \rangle$ are to be used as the weights for a softmax distribution with inverse temperature $1/\sqrt{h}$:
$$a_s = \frac{\exp(\langle q, k_s \rangle / \sqrt{h})}{\sum_{r=1}^{t} \exp(\langle q, k_r \rangle / \sqrt{h})}.$$
This distribution forms the ‘attention pattern’ for the token at position $t$.
Interpret each $v_s = V x_s$ as a vector ‘value,’ representing what meaningful information is to be conveyed to the current residual stream.
Produce a single integrated value vector $v = \sum_{s=1}^{t} a_s v_s$, mixing together the individual value vectors according to the softmax attention pattern.
Produce the final residual vector $r_t = O v$, using the output matrix $O$ to project the result from the head/hidden space back into the embedding space.
In my previous note, I discuss this perspective in more detail, along with a few related interpretations. In the next sections, I’ll discuss a different interpretation of the attention expression that I don’t often see discussed.
§Attention as a contextual multi-layer perceptron
A multi-layer perceptron is the simplest deep neural network. For a typical multi-input, multi-output formulation, let $d$ be a number of input/output neurons and let $t$ be a number of hidden neurons. Then, let $w_1, \dots, w_t \in \mathbb{R}^d$ and $u_1, \dots, u_t \in \mathbb{R}^d$ be sets of incoming weight vectors ($w_s$) and outgoing weight vectors ($u_s$), and let $\sigma : \mathbb{R} \to \mathbb{R}$ be some activation function. Then, a (biasless, single-hidden-layer) multi-layer perceptron is a function $M : \mathbb{R}^d \to \mathbb{R}^d$ where, for $x \in \mathbb{R}^d$,
$$M(x) = \sum_{s=1}^{t} \sigma(\langle w_s, x \rangle) \, u_s.$$
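A minimal sketch of this biasless single-hidden-layer MLP (the dimensions and the choice of tanh as a stand-in activation are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
d, t = 16, 5                      # d input/output neurons, t hidden neurons

W = rng.normal(size=(t, d))       # rows are incoming weight vectors w_s
U = rng.normal(size=(t, d))       # rows are outgoing weight vectors u_s

def mlp(x, W, U, sigma=np.tanh):
    """M(x) = sum_s sigma(<w_s, x>) u_s, with no biases."""
    hidden = sigma(W @ x)          # activations of the t hidden neurons
    return U.T @ hidden            # weighted sum of outgoing weight vectors

x = rng.normal(size=d)
y = mlp(x, W, U)
print(y.shape)                     # maps R^d back to R^d
```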
The point I want to make is that we can realise the attention expression in this format as follows. Assume we are given an input embedding sequence $x_1, \dots, x_t \in \mathbb{R}^d$ and an attention block with parameters $Q$, $K$, $V$, $O$. Then:
For the incoming weights, set $w_s = Q^\top K x_s$. Thus, $\langle w_s, x_t \rangle = \langle Q x_t, K x_s \rangle$.
For the outgoing weights, set $u_s = O V x_s$.
For the activation function, use $\sigma$ such that $\sigma(z) = \exp(z / \sqrt{h}) / Z$ where $Z = \sum_{r=1}^{t} \exp(\langle Q x_t, K x_r \rangle / \sqrt{h})$.
After making this identification, we have
$$M(x_t) = A(x_1, \dots, x_t).$$
Note that this identity requires that we use $M$ as defined in terms of its input, $x_t$. The same MLP won’t coincide with attention for other inputs, because $x_t$ was used in determining the normalisation $Z$ as well as the weights for the final neuron ($w_t$ and $u_t$). We could partially recover from this issue by considering a global softmax activation on the hidden layer rather than an element-wise exponential activation with hard-coded normalisation, but the dependency of the weights of unit $t$ on $x_t$ would remain.
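Both the identity and its input dependence can be checked numerically. The sketch below constructs the contextual MLP weights from a random context and compares it against attention (all names, dimensions, and the random parameters are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
d, h, t = 16, 4, 5
Q = rng.normal(size=(h, d))
K = rng.normal(size=(h, d))
V = rng.normal(size=(h, d))
O = rng.normal(size=(d, h))
xs = rng.normal(size=(t, d))
x_t = xs[-1]

def attention(xs):
    scores = (Q @ xs[-1]) @ (K @ xs.T) / np.sqrt(h)
    a = np.exp(scores)
    a /= a.sum()
    return O @ ((V @ xs.T) @ a)

# Contextual MLP: w_s = Q^T K x_s, u_s = O V x_s, exponential activation
W_in = (Q.T @ K @ xs.T).T                      # rows w_s, shape (t, d)
U_out = (O @ V @ xs.T).T                       # rows u_s, shape (t, d)
Z = np.exp((W_in @ x_t) / np.sqrt(h)).sum()    # normaliser hard-coded from x_t

def contextual_mlp(x):
    return U_out.T @ (np.exp((W_in @ x) / np.sqrt(h)) / Z)

print(np.allclose(contextual_mlp(x_t), attention(xs)))   # True: identity holds at x_t

x_other = rng.normal(size=d)
print(np.allclose(contextual_mlp(x_other),
                  attention(np.vstack([xs[:-1], x_other]))))
# False: Z, w_t, and u_t still depend on the original x_t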
Anyway, at least when the inputs line up correctly, this identity gives us our new interpretation of the attention mechanism:
The prompt, including the embedding at position $t$, determines the weights and activation function of a $t$-hidden-neuron MLP:
The query–key product constructs an incoming weight vector for each hidden neuron from the embedding at the corresponding position.
The output–value transform constructs the corresponding outgoing weight vector.
The activation function is an exponential that is normalised so that the activations across all neurons sum to one when the input happens to be $x_t$.
The residual for the token at position $t$ is given by applying this MLP to the embedding at that token position: $r_t = M(x_t)$.
That is what I mean by “attention as a context-dependent single-hidden-layer feed-forward neural network, as in, a multi-layer perceptron for which the weights depend on the context.”
§Implications for understanding transformers
Generally, the more perspectives you have on an object, the better able you will be to understand it. This is because different perspectives will be useful for different questions you might want to ask. There are a lot of unanswered questions about how large-scale transformers work. My hope in sharing this perspective is that it might show the way to answers we have previously overlooked.
For example, much of mechanistic interpretability attempts to understand the computation of transformer blocks in terms of attention patterns and other observables that are downstream of the search/communication perspectives. The contextual multi-layer perceptron perspective suggests a new view of attention as a recipe for constructing MLPs from contexts. Can we study the query–key and value–output transforms in terms of the incoming/outgoing weight vectors they induce for different prompts? Can we generally study attention blocks as contextually-determined functions of the current token’s embedding?
As another example, in generative language modelling, the outputs of the attention computation loop back to become inputs into the next generation, adding progressively more hidden units to an ever-widening MLP (up to normalisation). The outputs of the transformer determine its inputs. In the case of chain of thought prompting, the transformer’s task is literally to write the inputs that determine the weights of its own contextual MLPs, programming the neurons it uses to determine its own final answer. If we train the transformer with reinforcement learning, we are effectively teaching the transformer to do this in a way that corresponds to getting high-scoring final answers. It’s unclear to me how to think of this from the search/communication perspectives at all.
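The widening can be made concrete with a toy sketch. Here, random parameters and a greedy feedback rule (the attention output fed straight back in as the next embedding) stand in for a trained model; the point is only that each generated position adds one hidden unit to the contextual MLP:

```python
import numpy as np

rng = np.random.default_rng(3)
d, h = 16, 4
Q = rng.normal(size=(h, d))
K = rng.normal(size=(h, d))
V = rng.normal(size=(h, d))
O = rng.normal(size=(d, h))

xs = [rng.normal(size=d)]            # the 'prompt': a single embedding
for step in range(4):                # 'generate' four more positions
    X = np.stack(xs)
    W_in = (Q.T @ K @ X.T).T         # one incoming weight vector per position so far
    print(f"step {step}: contextual MLP has {W_in.shape[0]} hidden units")
    # toy next 'token': the attention output itself becomes the next embedding
    scores = (W_in @ xs[-1]) / np.sqrt(h)
    a = np.exp(scores - scores.max())
    a /= a.sum()
    xs.append(O @ ((V @ X.T) @ a))
```

In a real transformer the loop passes through an unembedding, sampling, and re-embedding step, but the structural point survives: each output widens the MLP that produces the next one.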
I’m yet to try applying this new perspective to a serious problem involving understanding transformers, but now that I have it in my toolbox, I’m sure an opportunity to use it will come up sooner or later.