ATTENTION, FULLY ASSEMBLED
Section 13.1

Attention as soft dictionary lookup

By the end of Ch.11 we had embeddings — every token is a vector. By the end of Ch.12 we had softmax — turn any real vector into a probability distribution, do it in a numerically stable way, do it in blocks. Now we have everything we need for attention. Attention answers a simple question: for each position in a sequence, what information from the OTHER positions should I pull in? The mechanism is a soft, differentiable dictionary lookup. Every position emits three vectors — a query (what am I looking for?), a key (what do I offer to be found by?), and a value (what’s my actual content?). Queries get matched against keys via dot products; the matches softmax into weights; the weights pick out a soft combination of values. That’s it. The math is straightforward; the implementation is straightforward; the memory cost (an N × N matrix per attention head per layer) is what makes long-context attention hard and is what FlashAttention solves.

The three roles

The Vaswani 2017 “Attention is all you need” formulation is famously compact:

Attention(Q, K, V) = softmax( Q · Kᵀ / √d_k ) · V

  Q ∈ ℝ^{N × d_k}    one query vector per position
  K ∈ ℝ^{N × d_k}    one key vector per position
  V ∈ ℝ^{N × d_v}    one value vector per position
  d_k                query/key dimension (usually = d_model / h_heads)
  d_v                value dimension (usually = d_k)
  N                  sequence length

Where do Q, K, V come from? Each is a linear projection of the input embeddings:

Given input X ∈ ℝ^{N × d_model}, the three projections are:

  Q = X · W_Q    W_Q ∈ ℝ^{d_model × d_k}
  K = X · W_K    W_K ∈ ℝ^{d_model × d_k}
  V = X · W_V    W_V ∈ ℝ^{d_model × d_v}

W_Q, W_K, W_V are LEARNED: they are what attention learns.

Within the attention operation itself, these projections are the only place where the model gets to learn what it should look for, what it should offer, and what it should hand over. Everything else in the formula (the scaling, the softmax, the matrix multiplications, the residual connection around it) has no parameters and is fixed.

Query (Q) — the request. “I’m position 7; here’s what I’m looking for.”

Key (K) — the advertised handle. “I’m position 3; here’s the signature you’d recognise me by.”

Value (V) — the actual payload. “I’m position 3; if you decide to attend to me, here’s what you’ll get.”

The lookup itself is just dot products. Position i’s query is dotted with EVERY position’s key. The score for position i attending to position j is Q_i · K_jᵀ. With N positions, this produces an N × N matrix of scores — the attention score matrix.
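The projections and the lookup described so far can be sketched in a few lines of plain Python. This is a minimal single-head illustration under stated assumptions, not a reference implementation: the helper names (`matmul`, `softmax`, `attention`) and the toy matrices are made up for the example, and matrices are stored as lists of rows.

```python
import math

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    """Numerically stable softmax of one list of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(X, W_Q, W_K, W_V):
    """Single-head attention. X is N x d_model; returns (weights, output)."""
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    # N x N score matrix: entry (i, j) is Q_i . K_j / sqrt(d_k)
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    weights = [softmax(row) for row in scores]   # each row sums to 1
    return weights, matmul(weights, V)           # output is N x d_v

# Toy input (hypothetical numbers): N = 3 positions, d_model = d_k = d_v = 2.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[1.0, 0.0], [0.0, 1.0]]
W_V = [[2.0, 0.0], [0.0, 3.0]]
weights, out = attention(X, W_Q, W_K, W_V)
```

With these toy matrices, position 0's query is [1, 0], so it scores positions 0 and 2 equally and position 1 lower. Its output row is the corresponding weighted mix of the three value vectors.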

Attending FROM position 1 ('cat'), with d_k = 8:

  position j   token     Q · K_jᵀ   ÷√d_k   softmax weight
  0            the         1.78      0.63       0.210
  1            cat         0.15      0.05       0.118
  2            sat        -1.34     -0.47       0.070
  3            on         -1.09     -0.39       0.076
  4            the         0.03      0.01       0.113
  5            mat         0.97      0.34       0.158
  6            and         0.31      0.11       0.125
  7            purred      0.39      0.14       0.129

output for position 1 ('cat') = Σ weight_j · V_j
  = [-0.35, -0.39, 0.19, -0.01, 0.03, 0.17, -0.23, 0.22]
Toggle the 1/√d_k scaling and watch the softmax weights. With raw dot products (d_k = 8, so std ≈ √8 ≈ 2.83), weights sharpen aggressively onto one or two positions — the soft lookup becomes a near-hard argmax, which is what kills training at initialisation. The scaling restores a learnable, diffuse weighting.
A query at one position dots against every key to produce N raw scores. The 1/√d_k scaling keeps softmax in a learnable regime regardless of head dimension. The output is the softmax-weighted sum of value vectors.

The viz: change the query position, watch the per-key scores and the resulting softmax weights. The output is a convex combination of value vectors, weighted by how well each key matches the query.
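For readers without the interactive figure, the same numbers can be recomputed directly. The sketch below takes the raw scores for query position 1 ('cat') from the table, assumes d_k = 8 as in the caption, and compares the softmax weights with and without the 1/√d_k scaling. The variable names are illustrative only.

```python
import math

def softmax(scores):
    """Numerically stable softmax of one list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 8
# Raw scores Q_1 . K_j for j = 0..7, copied from the table.
raw = [1.78, 0.15, -1.34, -1.09, 0.03, 0.97, 0.31, 0.39]

scaled = [s / math.sqrt(d_k) for s in raw]
w_scaled = softmax(scaled)   # weights with the 1/sqrt(d_k) scaling
w_raw = softmax(raw)         # weights with scaling toggled off

print("scaled  :", [round(w, 3) for w in w_scaled])
print("unscaled:", [round(w, 3) for w in w_raw])
```

The scaled weights should agree with the last column of the table to about three decimals. The unscaled weights put noticeably more mass on position 0, the sharpening that the toggle demonstrates.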

— think, then check —

Query (Q) — what this position is looking for. Q = X · W_Q.

Key (K) — what this position advertises (so others can find it). K = X · W_K.

Value (V) — what this position will hand over if attended to. V = X · W_V.

Why three? The role a position plays as a retriever is fundamentally different from its role as a retrievable target, which is fundamentally different from its role as a content source. If you tied them, you’d lose expressivity:

  • Q = K would force “what you’re looking for” to match “how you advertise yourself” — a token searching for verbs would also advertise itself as a verb, which is wrong; a noun looking for its verb is not itself a verb.
  • K = V would force “how you advertise” = “what content you hand over” — but a position might be useful as a syntactic anchor (its key is its part-of-speech signal) while its actual value is its semantic content.
  • In practice, the standard formulation keeps Q, K, and V decoupled; tying them restricts the attention patterns a head can express.

The three projections W_Q, W_K, W_V are how the model learns what role each dimension of the embedding plays in retrieval. Untying them is what gives attention its expressive power.
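One concrete consequence of tying Q = K is easy to check numerically: the score matrix becomes symmetric, so position i scores position j exactly as highly as j scores i. The sketch below uses made-up 2-dimensional inputs and hypothetical weight matrices purely to show that effect; it says nothing about how any trained model behaves.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def scores(X, W_q, W_k):
    """Unscaled score matrix (X W_q)(X W_k)^T."""
    return matmul(matmul(X, W_q), transpose(matmul(X, W_k)))

def is_symmetric(S, tol=1e-12):
    n = len(S)
    return all(abs(S[i][j] - S[j][i]) <= tol for i in range(n) for j in range(n))

# Hypothetical toy data: N = 3 positions, d_model = d_k = 2.
X = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
W_Q = [[1.0, 0.0], [2.0, 1.0]]
W_K = [[1.0, 1.0], [0.0, 1.0]]

S_tied = scores(X, W_Q, W_Q)     # Q = K: same projection used for both roles
S_untied = scores(X, W_Q, W_K)   # separate projections

print(is_symmetric(S_tied), is_symmetric(S_untied))
```

With separate projections the scores can be asymmetric, which is what lets a noun look for its verb without the verb having to look for the noun with the same strength.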

Why the 1/√d_k scaling?

The softmax in the attention formula is applied to scaled scores: Q · Kᵀ / √d_k, not raw Q · Kᵀ. The scaling factor matters. Why?

Recall from Ch.6 §3 that the dot product of two random vectors with unit-variance components has variance that grows with the dimension. Specifically, if q, k ∼ 𝒩(0, I) are independent standard Gaussians in d_k dimensions, then:

Var(q · k) = Var( Σ_l q_l · k_l )
           = Σ_l Var(q_l · k_l)                        (independence)
           = Σ_l ( E[q_l² · k_l²] − E[q_l · k_l]² )
           = Σ_l ( 1 · 1 − 0 )                         (zero-mean factors)
           = d_k

⇒ std(q · k) = √d_k

So raw dot products have standard deviation that grows like √d_k. For d_k = 64 (a typical head dimension), the typical dot product magnitude is around ±8. For d_k = 128, around ±11.
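The √d_k growth is easy to confirm empirically. The following is a small Monte Carlo sketch using only the standard library; the function name `dot_std`, the sample count, and the seed are arbitrary choices for illustration.

```python
import math
import random

def dot_std(d_k, samples=4000, seed=0):
    """Empirical std of q . k for independent standard-Gaussian q, k in d_k dims."""
    rng = random.Random(seed)
    dots = []
    for _ in range(samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / samples
    var = sum((x - mean) ** 2 for x in dots) / samples
    return math.sqrt(var)

for d_k in (8, 64, 128):
    # Columns: d_k, measured std, predicted sqrt(d_k)
    print(d_k, round(dot_std(d_k), 2), round(math.sqrt(d_k), 2))
```

The measured values should land within a few percent of √d_k. Dividing each dot product by √d_k would bring all three back to roughly unit standard deviation.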

Why this breaks the softmax. Softmax sharpens hard when its inputs span a wide range. If raw scores have standard deviation 8, then at random initialisation a long row will contain a few scores near +24 and a few near −24 (three standard deviations out), and e^24 / (e^24 + 60·e^0) ≈ 1, so the softmax saturates to a near-one-hot distribution. The gradient through softmax at saturation is tiny: the softmax Jacobian, diag(p) − p·pᵀ, goes to zero as p approaches one-hot (compare the softmax+CE gradient p − y from §12.2, which also vanishes once p has saturated onto y). Training stalls at initialisation.
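The saturation argument can be made concrete with the numbers from the paragraph above: one score at +24 among sixty scores at 0. The sketch below computes the softmax and its Jacobian, J[i][j] = p_i (δ_ij − p_j), for that row, and again after dividing every score by 8 (that is, √d_k for d_k = 64). The helper names are illustrative.

```python
import math

def softmax(scores):
    """Numerically stable softmax of one list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(p):
    """J[i][j] = d p_i / d s_j = p_i * (delta_ij - p_j)."""
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)]
            for i in range(n)]

def max_abs_entry(J):
    return max(abs(x) for row in J for x in row)

raw_row = [24.0] + [0.0] * 60             # one outlier score, sixty at zero
scaled_row = [s / 8.0 for s in raw_row]   # same row after the 1/sqrt(d_k) scaling

p_raw = softmax(raw_row)
p_scaled = softmax(scaled_row)

print("top weight, unscaled:", p_raw[0])
print("top weight, scaled  :", p_scaled[0])
print("largest Jacobian entry, unscaled:", max_abs_entry(softmax_jacobian(p_raw)))
print("largest Jacobian entry, scaled  :", max_abs_entry(softmax_jacobian(p_scaled)))
```

In the unscaled case the top weight is indistinguishable from 1 and every Jacobian entry is vanishingly small, so almost no gradient reaches the scores. After scaling, the top weight is a moderate fraction and the Jacobian has entries of ordinary size.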

The fix: divide by √d_k so that the scaled scores have unit standard deviation regardless of head size:

Q_i · K_jᵀ / √d_k has variance approximately 1 (at initialisation).

The softmax now sees inputs in a "reasonable" range: neither saturated nor too flat to learn from.

The 1/√d_k scaling is one of those small implementation details that doesn’t show up in the high-level “attention is all you need” pitch but is necessary for training stability.

— think, then check —

The derivation: For independent standard-Gaussian q, k ∈ ℝ^d:

Var(q · k) = Σ_l Var(q_l · k_l) = Σ_l E[q_l² k_l²] = Σ_l 1 = d.

So std(q · k) = √d. For d_k = 128, raw dot products have std ≈ 11.

What breaks without scaling: The softmax over a row of N scores with std 11 will sharpen hard at initialisation. A score 22 above the mean (2 standard deviations away) gets a weight e^22 ≈ 3.6 billion times higher than a score at the mean. The softmax saturates to a near-one-hot distribution, focusing almost all attention on the highest-score position.

Why this kills training: The Jacobian of softmax is diag(p) − p·pᵀ. When p saturates to one-hot, every entry of that Jacobian is near zero (the same saturation that makes the p − y gradient of §12.2 vanish once p matches y), so almost no gradient flows back through the softmax to the scores, and from there (through QKᵀ) to W_Q and W_K. The network can’t learn what to attend to because attention is too sharp from the start.

The fix: Divide by √d_k so the scaled scores have variance 1 regardless of head size. Softmax now operates in a learnable regime — neither saturated nor flat. The same softmax can learn to sharpen progressively over training.

If you increased d_k from 64 to 256 without changing scaling, attention training would degrade — softmax would saturate harder. The √d_k scaling is what makes attention head dimension a free hyperparameter.

The full attention computation, sized

Putting it all together. Suppose N = 8192 (a moderate context), d_model = 4096, d_k = d_v = 128 per head. The computation per attention head:

1. Projection: Q, K, V = X · W_Q, X · W_K, X · W_V
   cost: 3 · N · d_model · d_k ≈ 3 · 8192 · 4096 · 128 ≈ 13 GFLOPs
   memory for Q, K, V: 3 · N · d_k · 2 bytes ≈ 6 MB

2. Scores: S = Q · Kᵀ / √d_k        shape N × N
   cost: N² · d_k ≈ 8192² · 128 ≈ 8.6 GFLOPs
   memory for S: N² · 2 bytes ≈ 128 MB

3. Softmax: A = softmax_row(S)       shape N × N (same as S)
   cost: N² (a few flops per element) ≈ 0.3 GFLOPs
   memory for A: N² · 2 bytes ≈ 128 MB

4. Apply: O = A · V                  shape N × d_v
   cost: N² · d_v ≈ 8.6 GFLOPs
   memory for O: N · d_v · 2 bytes ≈ 2 MB

TOTAL per head: ≈ 30 GFLOPs, ≈ 264 MB intermediate memory.

For a 32-head, 32-layer transformer at N = 8192:
   compute: 30 GFLOPs × 32 heads × 32 layers ≈ 30 TFLOPs per forward pass
   intermediate memory: 264 MB × 32 heads × 32 layers ≈ 270 GB if naïvely stored.

(These counts treat one multiply-add as one operation and assume float16 storage, 2 bytes per element.)

The 270 GB number is what makes naïve long-context attention impractical. That is more than the HBM on a typical single GPU, and even when the matrices fit, transfers between HBM and on-chip SRAM dominate runtime. The intermediate S and A matrices (each N²) are the bottleneck: at N = 64K, S and A together are 16 GB per head per layer in float16.
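The sizing arithmetic is simple enough to script, which makes it easy to try other sequence lengths. The sketch below is a rough calculator under the same assumptions as the text (one multiply-add counted as one operation, 2 bytes per element). The function name and dictionary keys are made up for illustration, and the small softmax cost is left out.

```python
def attention_head_cost(N, d_model, d_k, d_v, bytes_per_elem=2):
    """Per-head multiply-add counts and intermediate bytes for naive attention."""
    macs = {
        "projection": N * d_model * (2 * d_k + d_v),  # Q, K, V projections
        "scores": N * N * d_k,                        # S = Q K^T
        "apply": N * N * d_v,                         # O = A V
    }
    mem = {
        "QKV": N * (2 * d_k + d_v) * bytes_per_elem,
        "S": N * N * bytes_per_elem,
        "A": N * N * bytes_per_elem,
        "O": N * d_v * bytes_per_elem,
    }
    return macs, mem

MiB = 2 ** 20
macs, mem = attention_head_cost(N=8192, d_model=4096, d_k=128, d_v=128)
per_head_bytes = sum(mem.values())
network_bytes = per_head_bytes * 32 * 32   # 32 heads x 32 layers

print({name: count / 1e9 for name, count in macs.items()})   # in units of 1e9 multiply-adds
print(per_head_bytes / MiB, "MiB per head")
print(network_bytes / 2 ** 30, "GiB across the network")

# Same head at N = 64K: only the N x N matrices matter.
_, mem_64k = attention_head_cost(N=65536, d_model=4096, d_k=128, d_v=128)
print((mem_64k["S"] + mem_64k["A"]) / 2 ** 30, "GiB for S and A per head per layer")
```

Nearly all of the per-head memory comes from the two N × N entries, `S` and `A`. The projections and the output are a rounding error by comparison.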

This is the problem FlashAttention solves. The key observation from §12.3: the softmax can be computed by streaming over blocks of K and V, maintaining running state (m, ℓ). If we ALSO maintain a running output O, we never have to materialise the full N × N score matrix S or the full N × N weight matrix A — only the running per-query-row state.

— think, then check —

Dominant cost: the N × N attention score matrix S and the N × N attention weight matrix A — one of each per head per layer.

Each is N² = 67M float16 entries = 128 MB. Multiplied by 32 heads × 32 layers = 1024 attention computations per forward pass, that is ~128 MB × 1024 = ~131 GB for each of the two matrices, or ~262 GB for S + A across the whole network (per forward pass at N = 8192). And this scales as N²: at N = 64K, eight times the length, it’s 64 times bigger.

Why it’s specifically this matrix: The compute (Q · Kᵀ and A · V matmuls) is dominated by the N² · d_k flops, but compute is “cheap” in the sense that GPUs are fast at matmul. The problem is the N × N matrix itself: it has to be written to HBM (the large but slow GPU memory), read back into SRAM for the softmax, written out again, then read once more for the V multiplication. The HBM ↔ SRAM bandwidth is the actual wall-clock bottleneck.

FlashAttention’s fix: never materialise S or A in HBM. Tile K and V into blocks small enough to fit in SRAM; compute partial Q · K_block scores in SRAM; softmax them streaming with running (m, ℓ); accumulate the partial A_block · V_block into a running output O. The N × N matrix only ever exists in tiny per-block fragments that stay in SRAM. Extra memory drops from O(N²) to O(N). HBM accesses drop by a factor on the order of (SRAM size / d_k²), which does not shrink as N grows; in practice this shows up as a speedup of roughly 5–10× at long contexts.
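The streaming idea can be checked on a single query row in plain Python. The sketch below is a simplified illustration of the running (m, ℓ, O) update, not FlashAttention itself: it ignores tiling over queries, SRAM sizing, and everything GPU-specific. The function names and the toy q, K, V values are made up for the example.

```python
import math

def naive_row(q, K, V):
    """Attention output for one query, materialising all N scores at once."""
    d_k, d_v = len(q), len(V[0])
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[c] for e, v in zip(exps, V)) / total for c in range(d_v)]

def streaming_row(q, K, V, block=2):
    """Same output, visiting K and V one block at a time."""
    d_k, d_v = len(q), len(V[0])
    m = float("-inf")    # running max of the scores seen so far
    ell = 0.0            # running sum of exp(score - m)
    o = [0.0] * d_v      # running unnormalised output
    for start in range(0, len(K), block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        s_blk = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K_blk]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)   # rescales old state; 0.0 on the first block
        p_blk = [math.exp(s - m_new) for s in s_blk]
        ell = ell * scale + sum(p_blk)
        o = [o_c * scale + sum(p * v[c] for p, v in zip(p_blk, V_blk))
             for c, o_c in enumerate(o)]
        m = m_new
    return [o_c / ell for o_c in o]

# Hypothetical toy data: N = 5 keys/values, d_k = d_v = 2.
q = [1.0, 0.5]
K = [[0.2, 0.1], [1.0, -1.0], [0.0, 3.0], [-2.0, 0.5], [0.7, 0.7]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, 0.9]]

reference = naive_row(q, K, V)
streamed = streaming_row(q, K, V, block=2)
```

The two results agree up to floating-point rounding for any block size, while the streaming version never holds more than one block of scores at a time. That property is what lets the full N × N matrix stay out of slow memory.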

We’ve now framed both the upside and the cost of attention. The upside — every position can pull information from every other position via learned content addressing — is the reason transformers replaced RNNs and CNNs as the default sequence model. The cost — O(N²) memory for the intermediate score matrix — is the constraint that shapes every long-context optimisation: FlashAttention (tiling), grouped-query attention (KV-cache reduction), sparse attention (only attend to a subset), linear attention (a different formulation entirely that trades O(N²) for O(N) at the cost of expressivity). The next two sections cover multi-head/GQA, then FlashAttention proper.

Next: §13.2 — Multi-head attention, MQA, GQA, and why head_dim = d_model / h.