ATTENTION, FULLY ASSEMBLED
Section 13.2

Multi-head, MQA, GQA — projection subspaces

§13.1 covered a single attention head. Real transformers run many heads in parallel. The motivation is simple: a single attention head produces ONE softmax weighting per query position — one “what should I pay attention to?” answer. But a sentence has many simultaneous structures (syntactic dependencies, anaphora, semantic associations, positional patterns). Letting the model run h independent attention computations in parallel — each in a lower-dimensional subspace of the embedding — gives it room to learn different relations in different heads. This section covers (1) the math of multi-head attention and why head_dim = d_model / h, (2) the KV cache that dominates inference memory, and (3) the MQA → GQA evolution that every modern open model adopted to shrink that cache.

Multi-head attention

Split the model dimension d_model into h heads, each operating in dimension d_k = d_model / h. Each head has its own W_Q, W_K, W_V projections; runs its own scaled-dot-product attention; produces its own output. The h outputs are concatenated and projected once more by W_O.

For heads i = 1..h, each with its own learned projections:

  Q_i = X · W_Q^i      W_Q^i ∈ ℝ^{d_model × d_k}      d_k = d_model / h
  K_i = X · W_K^i      W_K^i ∈ ℝ^{d_model × d_k}
  V_i = X · W_V^i      W_V^i ∈ ℝ^{d_model × d_v}      d_v = d_model / h

  head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i      ∈ ℝ^{N × d_v}

  MultiHead(X) = Concat(head_1, ..., head_h) · W_O      W_O ∈ ℝ^{d_model × d_model}

Multi-head attention is the original Vaswani 2017 design and the default for nearly a decade.
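A minimal NumPy sketch of the construction may help make the shapes concrete. It assumes the per-head projections are packed side by side as column blocks of one d_model × d_model matrix (a common layout, but an assumption here), uses no causal mask (matching the formula as written), and the function and variable names are illustrative only.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    """X: (N, d_model). W_Q, W_K, W_V, W_O: (d_model, d_model).

    Assumption: head i owns columns [i*d_k, (i+1)*d_k) of W_Q, W_K, W_V,
    so a single matmul projects all heads at once.
    """
    N, d_model = X.shape
    d_k = d_model // n_heads
    # Project, then split the feature axis into heads: (n_heads, N, d_k).
    Q = (X @ W_Q).reshape(N, n_heads, d_k).transpose(1, 0, 2)
    K = (X @ W_K).reshape(N, n_heads, d_k).transpose(1, 0, 2)
    V = (X @ W_V).reshape(N, n_heads, d_k).transpose(1, 0, 2)
    # One independent softmax per head: weights is (n_heads, N, N).
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))
    heads = weights @ V  # (n_heads, N, d_k)
    # Concatenate the heads along the feature axis, then apply W_O.
    concat = heads.transpose(1, 0, 2).reshape(N, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
N, d_model, n_heads = 6, 32, 4
X = rng.normal(size=(N, d_model))
W_Q, W_K, W_V, W_O = (
    rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4)
)
out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads)
print(out.shape, weights.shape)
```

Each slice weights[i] is a full N × N softmax table for head i, which is the per-head weighting the rest of this section talks about.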

Three observations about this construction:

  1. Parameter count is unchanged vs single-head with the same d_model. A single head of width d_model would have W_Q, W_K, W_V each of size d_model × d_model, for a total of 3 · d_model². Multi-head with h heads, each of width d_k = d_model / h, has h sets of three projections, each of size d_model × d_k, for a total of h · 3 · d_model · d_k = 3 · d_model². Same.
  2. Compute is also unchanged. Each head does O(N² · d_k) work; h heads in parallel is O(h · N² · d_k) = O(N² · d_model). Same as a single head of width d_model.
  3. So what does multi-head buy? The structure constraint. Different heads see different lower-dimensional subspaces; the softmax in each head produces a distinct weighting; the outputs cover different patterns. Voita 2019, Clark 2019 visualisations show heads specialising in distinct linguistic relations: one head tracks coreference, another tracks adjacent-token bigrams, another tracks subject–verb agreement, etc. A single fat softmax can’t represent multiple overlapping weightings simultaneously.
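The two "unchanged" claims in the list above are plain arithmetic, so they are easy to check. The sketch below counts only the Q/K/V projection parameters and the multiply-adds of the score matrices; the function names are made up for illustration, and d_model = 8192, N = 4096 are just example values.

```python
def qkv_params(d_model, n_heads):
    d_k = d_model // n_heads
    # n_heads heads, three projections (Q, K, V), each of size d_model x d_k.
    return n_heads * 3 * d_model * d_k

def score_multiply_adds(n_tokens, d_model, n_heads):
    d_k = d_model // n_heads
    # Each head builds an N x N score matrix from d_k-dimensional dot products.
    return n_heads * n_tokens * n_tokens * d_k

for n_heads in (1, 8, 64):
    print(n_heads, qkv_params(8192, n_heads), score_multiply_adds(4096, 8192, n_heads))
```

Every row prints the same two totals, because n_heads · d_k always multiplies back to d_model.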
[Interactive figure: attention weights per head for the sentence “the cat sat on the mat”. Controls: attend FROM (query position) and mode. Snapshot shown: KV cache ratio 4× / 4× MHA = 100%.]

          the     cat     sat     on      the     mat
head 0    0.154   0.119   0.133   0.112   0.335   0.147
head 1    0.170   0.131   0.223   0.120   0.105   0.250
head 2    0.219   0.141   0.235   0.122   0.163   0.121
head 3    0.197   0.108   0.294   0.126   0.162   0.114
Each head has its own Q projection. K and V are either per-head (MHA), shared in groups (GQA), or fully shared (MQA). Watch how the softmax weights diverge across heads in MHA — different heads attend to different tokens for the same query. In MQA, all four heads share the same K, V; the weights cluster more tightly because the only freedom left per head is the query projection.
Four attention heads with independent Q projections. MHA/GQA/MQA control how K and V are shared. KV cache cost is proportional to the number of distinct KV heads — h for MHA, g for GQA, 1 for MQA.

The viz: same input sentence, four attention heads with different learned projections. Watch how the same query position attends to DIFFERENT keys in each head; that is the structural win.

— think, then check —

Parameter and compute counts are unchanged. What changes is the structure of the softmax constraint:

A single softmax-over-keys produces ONE weighting per query position — a single convex combination of value vectors. If the model needs to simultaneously attend to (a) the syntactic parent, (b) the most recent verb, (c) the semantic referent of a pronoun — a single softmax cannot encode all three. It must compromise to a single weighting.

Multi-head splits this into h INDEPENDENT softmaxes, each in its own subspace with its own learned Q/K/V projections. Different heads can attend to different things in parallel. Head 1 can softmax-weight the syntactic parent; head 2 the recent verb; head 3 the pronoun’s referent. The concatenated output mixes all three signals.

So multi-head is best understood as a structural constraint on the softmax, not as added capacity. The model gets h “votes” instead of one. Empirically: removing heads after training shows large variation in head importance — some heads matter a lot, some are prunable (Michel 2019), confirming heads specialise.

The trade: smaller d_k means each head’s dot products are over a smaller feature space, so each head is less expressive individually. But the gain of h independent weightings outweighs this. Optimal h has been found empirically to be around 16-32 for typical d_model ∈ [2K, 8K].

The KV cache: where inference memory goes

So far the math treats Q, K, V as if they’re all computed fresh on every forward pass. During training this is true. During autoregressive inference, it’s wasteful — and the waste is dominant.

The fix: cache the K and V vectors as they’re computed, and only compute fresh K, V for the newest token.

KV cache stores, for each attention layer and head, the K and V vectors for every previous position. New token in → compute fresh Q_new, K_new, V_new for just that token → append K_new, V_new to the cache → attention is Q_new against all cached K (size N), output is weighted sum of all cached V.
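The decode loop described above can be sketched for a single layer and a single head. KVCache and its step method are hypothetical names for illustration, not any library's API; real inference stacks preallocate the cache rather than growing it with vstack.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

class KVCache:
    """K/V cache for one layer and one head (illustrative helper)."""

    def __init__(self, d_k):
        self.K = np.zeros((0, d_k))
        self.V = np.zeros((0, d_k))

    def step(self, x_new, W_Q, W_K, W_V):
        # Project only the newest token: three vectors of length d_k.
        q, k, v = x_new @ W_Q, x_new @ W_K, x_new @ W_V
        # Append K_new, V_new; earlier rows are reused, never recomputed.
        self.K = np.vstack([self.K, k])
        self.V = np.vstack([self.V, v])
        # Q_new against all cached K, then a weighted sum of all cached V.
        w = softmax(self.K @ q / np.sqrt(q.shape[0]))
        return w @ self.V

rng = np.random.default_rng(0)
d_model, d_k, N = 16, 8, 5
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
X = rng.normal(size=(N, d_model))

cache = KVCache(d_k)
outs = np.stack([cache.step(x, W_Q, W_K, W_V) for x in X])
print(outs.shape, cache.K.shape)
```

Feeding the tokens one at a time this way gives the same outputs as causal attention over the whole sequence, but each step does O(N) work against the cache instead of recomputing every K and V.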

KV-cache size at inference:

Per layer, per head:

  K-cache: N · d_k · sizeof(dtype)
  V-cache: N · d_v · sizeof(dtype)

Total cache for an L-layer, H-head model at sequence length N (fp16 = 2 bytes, d_v = d_k):

  cache_bytes = 2 · L · H · N · d_k · 2 bytes = 4 · L · H · N · d_k bytes

Example: Llama 2 70B shapes, counted as if all 64 heads kept their own K and V, i.e. full MHA (L = 80, H = 64, d_k = 128, N = 4096 default):

  cache = 4 · 80 · 64 · 4096 · 128 = 10.7 GB (per request)

Example: same model at N = 32768 (extended context):

  cache = 4 · 80 · 64 · 32768 · 128 = 86 GB (per request)

86 GB per request is more than half the size of the fp16 model weights (70B params · 2 bytes = 140 GB) and larger than the same weights quantised to int8 (~70 GB). For batched serving at long contexts, KV-cache memory dominates everything. This is why every modern inference stack obsesses about KV-cache size.
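The cache-size formula is small enough to keep as a helper. This is a sketch with a made-up function name; n_kv_heads is the number of heads that actually store K and V (H for plain MHA), and GB here means 10⁹ bytes.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_k, bytes_per_elem=2):
    # Leading 2 covers K and V; bytes_per_elem=2 corresponds to fp16.
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_elem

GB = 1e9
print(kv_cache_bytes(80, 64, 4096, 128) / GB)   # about 10.7
print(kv_cache_bytes(80, 64, 32768, 128) / GB)  # about 85.9
```

Passing bytes_per_elem=1 models an int8 cache, which halves both figures.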

— think, then check —

Estimate: KV cache bytes = 2 (K + V) · L · H · N · d_k · 2 (fp16) = 4 · L · H · N · d_k.

= 4 · 80 · 64 · 32768 · 128 = 86 GB per request.

Why dominant: at N = 32K, a single request’s cache (86 GB) is already more than half the size of the model weights (70B params · 2 bytes/param = 140 GB), and the two scale differently. The weights are a fixed cost per machine: you can shard the model across GPUs and reuse the same weights for every request. The KV cache, by contrast, is PER-REQUEST: each concurrent user needs their own cache. Serving 16 users in parallel at 32K context: 16 · 86 GB = 1.4 TB of KV-cache memory. That’s the inference deployment cost.

Scaling: cache_bytes = 4 · L · H · N · d_k.

  • Linear in N (context length) — doubling the context doubles the cache.
  • Linear in H (heads) — this is what MQA/GQA exploits: reduce H_kv to 1 or g, cache shrinks by H/H_kv.
  • Linear in L (layers) — bigger models hurt more.
  • Linear in d_k — but d_k is small (~128) and rarely changed.

Optimisations target each axis: GQA/MQA shrink H_kv; quantisation shrinks bytes per element (KV in int8 → 2× smaller); paged attention (vLLM) packs multiple requests’ caches with no padding waste; sliding-window attention drops old cache entries; sparse attention skips most cache reads.

MQA: one KV head, h query heads

Shazeer 2019 “Fast Transformer Decoding” proposed the simplest cache reduction: multi-query attention. Keep h query heads but use only ONE K head and ONE V head, shared across all queries.

          MHA            MQA
  Q       h heads        h heads (same as MHA)
  K       h heads        1 head, broadcast to all queries
  V       h heads        1 head, broadcast to all queries

KV cache: factor of h smaller. For Llama 2 70B (h = 64) at N = 32K, MHA cache 86 GB → MQA cache 1.35 GB.

The math change is tiny — in the attention formula, every head’s query dots against the SAME K matrix; every head’s softmax weights the SAME V matrix. So all h heads’ outputs share the K/V information but differ in their query projections.
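In array terms, the change amounts to dropping the head axis from K and V and letting broadcasting do the sharing. A minimal sketch, assuming Q has already been projected per head and K, V are the single shared head (names are illustrative, no causal mask):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention(Q, K, V):
    """Q: (h, N, d_k) per-head queries. K, V: (N, d_k), one shared head."""
    d_k = Q.shape[-1]
    # K.T is (d_k, N); matmul broadcasts it across the h query heads.
    weights = softmax(Q @ K.T / np.sqrt(d_k))  # (h, N, N)
    # Every head's weights are applied to the same V matrix.
    return weights @ V, weights                # (h, N, d_k), (h, N, N)

rng = np.random.default_rng(0)
h, N, d_k = 4, 6, 8
Q = rng.normal(size=(h, N, d_k))
K = rng.normal(size=(N, d_k))  # one K head for all h query heads
V = rng.normal(size=(N, d_k))  # one V head for all h query heads
out, weights = multi_query_attention(Q, K, V)
print(out.shape, weights.shape)
```

The heads still produce different weightings because their queries differ, but only one K and one V per position ever need to be cached.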

The trade-off. MQA gives up some expressivity. Different heads can no longer have different K projections — they can only differ in HOW they query a shared K. Empirically, MQA shows modest quality degradation vs MHA (around 1% on standard benchmarks), but it’s largely recoverable by training with MQA from the start or doing a brief MQA fine-tune of a pretrained MHA model.

GQA: groups of KV heads

Ainslie 2023 “GQA: Generalized Multi-Query” generalises: grouped-query attention uses g KV heads where 1 ≤ g ≤ h. Each KV head is shared by h/g query heads.

General formula: g KV heads, h query heads, h/g queries per KV group

  MHA: g = h       (each query head has its own KV)
  MQA: g = 1       (all query heads share one KV)
  GQA: 1 < g < h   (g groups, each shared by h/g queries)

Llama 2 70B uses g = 8, h = 64 (8 queries per KV group)
Cache reduction: factor of h/g = 8× smaller than MHA.
Quality: within ~0.1% of MHA, within ~0.5% above MQA.
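The grouping can be sketched by expanding the g stored KV heads to h at attention time. This assumes consecutive query heads share a KV head (query head i reads KV head i // (h/g)), which is one natural layout rather than the only possible one; the function name is illustrative and there is no causal mask.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (h, N, d_k). K, V: (g, N, d_k), with g dividing h."""
    h, N, d_k = Q.shape
    g = K.shape[0]
    assert h % g == 0, "h must be a multiple of g"
    # Only the g heads are stored; repeat expands them to h for the matmul.
    # np.repeat duplicates each KV head h // g times in a row, so query
    # head i ends up paired with KV head i // (h // g).
    K_full = np.repeat(K, h // g, axis=0)  # (h, N, d_k)
    V_full = np.repeat(V, h // g, axis=0)  # (h, N, d_k)
    weights = softmax(Q @ K_full.transpose(0, 2, 1) / np.sqrt(d_k))
    return weights @ V_full                # (h, N, d_k)

rng = np.random.default_rng(0)
h, g, N, d_k = 8, 2, 5, 4
Q = rng.normal(size=(h, N, d_k))
K = rng.normal(size=(g, N, d_k))
V = rng.normal(size=(g, N, d_k))
out = grouped_query_attention(Q, K, V)
print(out.shape)
```

Setting g = h reproduces MHA and g = 1 reproduces MQA, so the same function covers all three rows of the table above.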

GQA hits the sweet spot. It’s now the default for most major open models: Llama 2 70B (g=8), Llama 3 (g=8), Mistral 7B (g=8 with sliding window), Qwen 2/3 (g=8). The cache savings of 8× vs MHA plus the quality recovery vs MQA make it dominate the design space.

GQA is also the reason “decoder-only vs encoder-decoder” architecture comparisons need to be careful about KV-cache cost. An encoder-decoder model computes K, V for the encoder once and caches them across all decoder steps; the per-step cost is dominated by the decoder’s KV cache, which is much smaller (the decoder generates token by token). A decoder-only model at the same total parameter count has a single cache that grows with the full context. GQA partly closes this gap.

— think, then check —

Cache size scales linearly in g (the number of distinct KV heads). At g = 64 (MHA), 32K context = 86 GB cache per request. Each halving of g halves the cache.

g = 64 (MHA): 86 GB / request — too big for batched serving.

g = 16: 21.5 GB, still significant; 4× smaller than MHA but 16× the MQA cache.

g = 8: 10.75 GB, 8× smaller than MHA; a batch of 8 needs about 86 GB of cache, which is manageable once the model is sharded across several 96GB GPUs.

g = 4: 5.4 GB — minor additional saving over g = 8.

g = 1 (MQA): 1.35 GB — biggest savings but quality drops noticeably.

Why g = 8 is the sweet spot:

  1. Cache savings have diminishing returns past 8×. Going from MHA → g=8 saves 75 GB. Going from g=8 → g=1 saves only another 9 GB. The first 8× swing matters more than the next 8×.
  2. Quality recovery is g/H dependent in a non-linear way. Below g ≈ 8, head specialisation breaks down: queries that should learn distinct K patterns share a KV head, forcing them to compromise. Above g ≈ 8, there’s headroom for h/g ≈ 8 queries to share one K because their queries naturally cluster (Ainslie’s analysis: heads tend to form ~8 functional clusters in many models).
  3. Hardware alignment. g = 8 means h/g = 8 query heads per KV group — a clean SIMD/warp-friendly number for kernel implementations.

The decision rule: pick g as the LARGEST value where the cache fits comfortably in your serving memory budget at the target batch size. For Llama-class models serving 32K context on a node of 80GB H100s: g = 8 with batch ~8 fits (about 86 GB of cache in total, spread across the GPUs). If your context grows to 128K, drop to g = 4 or move to MQA; quality loss is now worth it.
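The rule is mechanical enough to write down. The sketch below tries every g that divides h, largest first, and returns the first one whose batched cache fits; the function names and the 100 GB cache budget are made-up illustration values, not measurements from any deployment.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_k, bytes_per_elem=2):
    # Leading 2 covers K and V; bytes_per_elem=2 corresponds to fp16.
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_elem

def largest_g_that_fits(budget_bytes, batch, n_layers, n_heads, seq_len, d_k):
    # Candidate g values are the divisors of n_heads, largest first.
    divisors = [d for d in range(1, n_heads + 1) if n_heads % d == 0]
    for g in sorted(divisors, reverse=True):
        if batch * kv_cache_bytes(n_layers, g, seq_len, d_k) <= budget_bytes:
            return g
    return None  # not even MQA (g = 1) fits this budget

budget = 100e9  # hypothetical memory left for KV cache, in bytes
print(largest_g_that_fits(budget, 8, 80, 64, 32768, 128))   # g for 32K context
print(largest_g_that_fits(budget, 8, 80, 64, 131072, 128))  # g for 128K context
```

Under these assumed numbers the 32K case lands on g = 8 and the 128K case is pushed down to g = 2, which shows how quickly a longer context eats the budget.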

This is the genre of decision modern inference engineers make weekly. The math is mechanical; the budget constraint is what drives the answer.

Next: §13.3 — FlashAttention. Now that we have the cost model (N² for the attention matrix, KV cache for inference), the §12.3 online-softmax recurrence becomes the key to making this entire computation streaming over blocks instead of materialising the N × N matrix.