Multi-head, MQA, GQA — projection subspaces
§13.1 covered a single attention head. Real transformers run many heads in parallel. The motivation is simple: a single attention head produces ONE softmax weighting per query position, one “what should I pay attention to?” answer. But a sentence has many simultaneous structures (syntactic dependencies, anaphora, semantic associations, positional patterns). Letting the model run h independent attention computations in parallel, each in a lower-dimensional subspace of the embedding, gives it room to learn different relations in different heads. This section covers (1) the math of multi-head attention and why d_k = d_model / h, (2) the KV cache that dominates inference memory, and (3) the MQA → GQA evolution that most modern open models have adopted to shrink that cache.
Multi-head attention
Split the model dimension d_model into h heads, each operating in dimension d_k = d_model / h. Each head has its own W_Q, W_K, W_V projections; runs its own scaled-dot-product attention; produces its own output. The h outputs are concatenated and projected once more by W_O.
Multi-head attention is the original Vaswani 2017 design and the default for nearly a decade.
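Here is a minimal pure-Python sketch of that construction. It assumes matrices stored as lists of lists, random weights, and no masking or biases. Helper names such as `multi_head_attention` are illustrative and do not come from any library. Slicing one d_model × d_model projection into h column blocks of width d_k is equivalent to giving each head its own d_model × d_k W_Q, W_K and W_V.

```python
import math
import random

def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix (lists of lists)."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(scores):
    """Row-wise softmax, subtracting the row max for numerical stability."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def columns(a, start, stop):
    """Take columns [start, stop) of matrix a."""
    return [row[start:stop] for row in a]

def multi_head_attention(x, w_q, w_k, w_v, w_o, h):
    """x: N x d_model; each w_*: d_model x d_model. Returns N x d_model."""
    d_model = len(x[0])
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_model // h
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    head_outputs = []
    for i in range(h):
        # Head i sees only its own d_k-wide slice of Q, K, V.
        q_i = columns(q, i * d_k, (i + 1) * d_k)
        k_i = columns(k, i * d_k, (i + 1) * d_k)
        v_i = columns(v, i * d_k, (i + 1) * d_k)
        scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q_i, transpose(k_i))]
        weights = softmax_rows(scores)             # one N x N weighting per head
        head_outputs.append(matmul(weights, v_i))  # N x d_k
    # Concatenate heads along the feature axis, then mix them with W_O.
    concat = [sum((head[n] for head in head_outputs), []) for n in range(len(x))]
    return matmul(concat, w_o)

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
N, d_model, h = 5, 16, 4
x = rand_matrix(N, d_model, rng)
w_q, w_k, w_v, w_o = (rand_matrix(d_model, d_model, rng) for _ in range(4))
out = multi_head_attention(x, w_q, w_k, w_v, w_o, h)
print(len(out), len(out[0]))  # 5 16
```

Setting h = 1 in this sketch recovers ordinary single-head attention with d_k = d_model.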
Three observations about this construction:
- Parameter count is unchanged vs single-head with the same d_model. A single head of width d_model has W_Q, W_K, W_V each of size d_model × d_model, for a total of 3 · d_model². Multi-head with h heads, each of width d_k = d_model / h, has 3h projections (one W_Q, W_K, W_V per head), each of size d_model × d_k, for a total of 3 · h · d_model · d_k = 3 · d_model². Same. (W_O, if present, is d_model × d_model in both cases.)
- Compute is also unchanged. Each head does O(N² · d_k) work; h heads in parallel do O(h · N² · d_k) = O(N² · d_model). Same as a single head of width d_model.
- So what does multi-head buy? Structure, not capacity. Different heads see different lower-dimensional subspaces; the softmax in each head produces a distinct weighting; the outputs cover different patterns. Visualisations by Voita 2019 and Clark 2019 show heads specialising in distinct linguistic relations: one head tracks coreference, another tracks adjacent-token bigrams, another tracks subject–verb agreement, etc. A single wide softmax can’t represent multiple overlapping weightings simultaneously.
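Both counts can be checked with plain arithmetic. This sketch counts only the Q/K/V projection weights. W_O cancels because it is d_model × d_model either way. On the compute side it counts the multiply-accumulates for QKᵀ plus the weighted sum over V. The function names are hypothetical.

```python
def qkv_params(d_model, h):
    """Q, K, V projection weights for h heads of width d_k = d_model / h."""
    d_k = d_model // h
    return h * 3 * d_model * d_k

def attention_macs(n, d_model, h):
    """Multiply-accumulates for Q K^T and for weights @ V, summed over h heads."""
    d_k = d_model // h
    per_head = n * n * d_k + n * n * d_k  # (n x d_k)(d_k x n), then (n x n)(n x d_k)
    return h * per_head

d_model, n = 4096, 2048
for h in (1, 8, 32):
    # Every row prints the same two numbers: h cancels out.
    print(h, qkv_params(d_model, h), attention_macs(n, d_model, h))
```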
(Figure: the same input sentence under two attention heads with different learned projections. The same query position attends to DIFFERENT keys in each head; that is the structural win.)
Parameter and compute counts are unchanged. What changes is the structure of the softmax constraint:
A single softmax-over-keys produces ONE weighting per query position: a single convex combination of value vectors. If the model needs to simultaneously attend to (a) the syntactic parent, (b) the most recent verb, and (c) the semantic referent of a pronoun, a single softmax cannot keep those three signals apart. It must settle on one weighting, either concentrating on one target or blending the three value vectors into a single average.
Multi-head splits this into h INDEPENDENT softmaxes, each in its own subspace with its own learned Q/K/V projections. Different heads can attend to different things in parallel. Head 1 can softmax-weight the syntactic parent; head 2 the recent verb; head 3 the pronoun’s referent. The concatenated output mixes all three signals.
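To make “h independent softmaxes” concrete, here is a toy sketch. It uses hand-picked, not learned, per-head projections over 4-dimensional embeddings, and the same matrix serves as both Q and K projection within each head. All vectors and matrices are made-up illustrative values. The same query position puts most of its weight on a different key in each head.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def project(vec, w):
    """Row vector (length d_model) times w (d_model x d_k)."""
    return [sum(vec[i] * w[i][j] for i in range(len(vec))) for j in range(len(w[0]))]

def head_weights(query, keys, w_q, w_k):
    """Softmax weights of one head for one query position over all keys."""
    q = project(query, w_q)
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, project(k, w_k))) / scale for k in keys]
    return softmax(scores)

# Toy 4-dim embeddings: dims 0-1 carry one feature, dims 2-3 another.
query = [2.0, 0.0, 0.0, 2.0]
keys = [
    [2.0, 0.0, 0.0, 0.0],  # key 0: matches the query on dims 0-1
    [0.0, 0.0, 0.0, 2.0],  # key 1: matches the query on dims 2-3
    [0.0, 2.0, 2.0, 0.0],  # key 2: matches neither
]
# Head A reads dims 0-1, head B reads dims 2-3 (same matrix for Q and K here).
w_a = [[1, 0], [0, 1], [0, 0], [0, 0]]
w_b = [[0, 0], [0, 0], [1, 0], [0, 1]]

for name, w in (("A", w_a), ("B", w_b)):
    weights = head_weights(query, keys, w, w)
    best = max(range(len(keys)), key=lambda i: weights[i])
    print(name, "attends most to key", best)  # A -> key 0, B -> key 1
```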
So multi-head is best understood as a structural constraint on the softmax, not as added capacity: the model gets h “votes” instead of one. Empirically, removing heads after training shows large variation in head importance. Some heads matter a lot, while many can be pruned with little loss (Michel 2019), which is consistent with heads specialising.
The trade-off: smaller d_k means each head’s dot products are over a smaller feature space, so each head is individually less expressive. But the gain of h independent weightings outweighs this. In practice, most large models fix d_k at 64 or 128 and set h = d_model / d_k. That gives h ≈ 16 to 64 for typical d_model ∈ [2K, 8K] (e.g., d_model = 8192 with 64 heads of d_k = 128).
The KV cache: where inference memory goes
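So far the math treats Q, K, V as if they’re all computed fresh on every forward pass. During training this is true, since the whole sequence is processed in one pass. During autoregressive inference it is wasteful, and the waste dominates. Generating token t would recompute K and V for all t − 1 earlier positions, even though, under causal masking, those vectors never change once computed.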
The fix: cache the K and V vectors as they’re computed, and only compute fresh K, V for the newest token.
The KV cache stores, for each attention layer and head, the K and V vectors of every previous position. Each generation step works like this: a new token comes in → compute Q_new, K_new, V_new for just that token → append K_new and V_new to the cache → score Q_new against all N cached keys, and output the softmax-weighted sum of all N cached values.
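Here is a minimal single-layer, single-head sketch of that loop in pure Python, with random projections and illustrative names. At every step it projects only the newest token and appends its K and V to the cache. It then asserts that attending over the cache gives the same output as recomputing K and V for the whole prefix, which is exactly the work the cache saves.

```python
import math
import random

def matvec(w, x):
    """w: d_out x d_in (list of rows), x: length d_in."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]

def attend(q, keys, values):
    """Softmax(q . k / sqrt(d)) weighted sum over the value vectors."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, values)) / total for j in range(len(values[0]))]

rng = random.Random(0)
d_model, d_k, steps = 8, 4, 6
w_q, w_k, w_v = ([[rng.gauss(0, 0.3) for _ in range(d_model)] for _ in range(d_k)] for _ in range(3))
tokens = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(steps)]

k_cache, v_cache = [], []
for t, x_new in enumerate(tokens):
    # Only the newest token is projected; earlier K, V come from the cache.
    q_new = matvec(w_q, x_new)
    k_cache.append(matvec(w_k, x_new))
    v_cache.append(matvec(w_v, x_new))
    out_cached = attend(q_new, k_cache, v_cache)

    # Reference: recompute K, V for the whole prefix (what the cache avoids).
    prefix = tokens[: t + 1]
    out_full = attend(q_new, [matvec(w_k, x) for x in prefix], [matvec(w_v, x) for x in prefix])
    assert all(abs(a - b) < 1e-12 for a, b in zip(out_cached, out_full))

print(len(k_cache), len(k_cache[0]))  # 6 4
```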
KV-cache size at inference, for a 70B-class model with L = 80 layers, H = 64 heads and d_k = 128. The estimate assumes full MHA (every head keeps its own K and V) at N = 32K context:
Estimate: KV cache bytes = 2 (K + V) · L · H · N · d_k · 2 (fp16) = 4 · L · H · N · d_k.
= 4 · 80 · 64 · 32768 · 128 ≈ 86 GB per request.
Why dominant: the fp16 weights are 70B params · 2 bytes/param = 140 GB, but they are a fixed cost per machine. You shard the model across GPUs and reuse the same weights for every request. The KV cache, by contrast, is PER-REQUEST: each concurrent user needs their own cache. A single 32K request already needs more than half as much memory as the weights, and two requests need more than the weights. Serving 16 users in parallel at 32K context takes 16 · 86 GB ≈ 1.4 TB of KV-cache memory. For batched serving at long contexts, KV-cache memory dominates everything, which is why every modern inference stack obsesses about KV-cache size.
Scaling: cache_bytes = 4 · L · H_kv · N · d_k, where H_kv is the number of distinct KV heads (H_kv = H for MHA).
- Linear in N (context length): doubling the context doubles the cache.
- Linear in H_kv (KV heads): this is what MQA/GQA exploit. They reduce H_kv from H to 1 or g, which shrinks the cache by H/H_kv.
- Linear in L (layers) — bigger models hurt more.
- Linear in d_k — but d_k is small (~128) and rarely changed.
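The estimate above can be reproduced with a few lines of arithmetic. This sketch takes the number of KV heads as a parameter, so it also covers MQA/GQA, and it uses decimal gigabytes. The model shape is the same hypothetical 70B-class configuration (L = 80, H = 64, d_k = 128) used in the estimate.

```python
def kv_cache_bytes(layers, kv_heads, seq_len, d_k, bytes_per_elem=2):
    """2 tensors (K and V) per layer, per KV head, per position, each d_k wide."""
    return 2 * layers * kv_heads * seq_len * d_k * bytes_per_elem

GB = 1e9  # decimal gigabytes, as in the text

# Hypothetical 70B-class shape with full MHA: L = 80, H = 64, d_k = 128, fp16.
per_request = kv_cache_bytes(layers=80, kv_heads=64, seq_len=32768, d_k=128)
print(round(per_request / GB, 1))         # 85.9
print(round(16 * per_request / 1e12, 2))  # 1.37 (TB for 16 concurrent requests)

# Each axis is linear: doubling the context doubles the cache.
assert kv_cache_bytes(80, 64, 65536, 128) == 2 * per_request
# 1 byte per element (int8) halves it relative to fp16.
assert 2 * kv_cache_bytes(80, 64, 32768, 128, bytes_per_elem=1) == per_request
```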
Optimisations target each axis:
- GQA/MQA shrink H_kv.
- Quantisation shrinks bytes per element (int8 KV is 2× smaller than fp16).
- Paged attention (vLLM) allocates each request’s cache in fixed-size blocks, which nearly eliminates fragmentation and padding waste when batching requests.
- Sliding-window attention drops old cache entries.
- Sparse attention skips most cache reads.
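For the quantisation axis, here is a minimal sketch of one simple scheme: symmetric absmax int8 with one float scale per cached vector. This scheme is an assumption chosen for illustration, not the one any particular inference engine uses. Each element then takes 1 byte instead of 2, plus a small per-vector overhead for the scale.

```python
def quantize_int8(vec):
    """Symmetric absmax quantisation: one float scale per vector, int8 codes."""
    scale = max(abs(x) for x in vec) / 127 or 1.0  # avoid a zero scale
    codes = [max(-127, min(127, round(x / scale))) for x in vec]
    return codes, scale

def dequantize_int8(codes, scale):
    return [c * scale for c in codes]

k_vec = [0.5, -1.27, 0.03, 0.9]
codes, scale = quantize_int8(k_vec)
restored = dequantize_int8(codes, scale)
print(codes)  # [50, -127, 3, 90]
```

The rounding error per element is at most half a quantisation step (scale / 2).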
MQA: one KV head, h query heads
Shazeer 2019 “Fast Transformer Decoding” proposed the simplest cache reduction: multi-query attention. Keep h query heads but use only ONE K head and ONE V head, shared across all queries.
The math change is tiny — in the attention formula, every head’s query dots against the SAME K matrix; every head’s softmax weights the SAME V matrix. So all h heads’ outputs share the K/V information but differ in their query projections.
The trade-off: MQA gives up some expressivity. Different heads can no longer have different K projections; they can only differ in HOW they query a shared K. Empirically, MQA shows modest quality degradation vs MHA (around 1% on standard benchmarks). Most of it can be recovered in one of two ways. You can train with MQA from the start. Or you can convert a pretrained MHA checkpoint and briefly continue training it (“uptraining”; Ainslie 2023 use about 5% of the original pre-training compute).
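Ainslie 2023 describe a starting point for that continued training. Each shared K (and V) projection is built by mean-pooling the projections of the original heads it replaces, and training then continues briefly. Here is a minimal sketch of just the conversion step. It assumes consecutive heads are grouped and weights are nested lists, and the function name is hypothetical. With g = 1 it produces an MQA checkpoint; larger g gives the grouped variant introduced next.

```python
def mean_pool_kv_heads(head_weights, g):
    """Collapse h per-head K (or V) projection matrices into g group matrices.

    head_weights: list of h matrices, each d_model x d_k (lists of lists).
    Assumption: consecutive heads form one group.
    """
    h = len(head_weights)
    assert h % g == 0, "h must be divisible by g"
    size = h // g
    pooled = []
    for grp in range(g):
        members = head_weights[grp * size:(grp + 1) * size]
        rows, cols = len(members[0]), len(members[0][0])
        pooled.append([[sum(m[r][c] for m in members) / size for c in range(cols)]
                       for r in range(rows)])
    return pooled

# Toy example: h = 4 heads, d_model = 2, d_k = 1; head i's matrix is all i.
w_k_heads = [[[float(i)], [float(i)]] for i in range(4)]
print(mean_pool_kv_heads(w_k_heads, g=1))  # [[[1.5], [1.5]]]
print(mean_pool_kv_heads(w_k_heads, g=2))  # [[[0.5], [0.5]], [[2.5], [2.5]]]
```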
GQA: groups of KV heads
Ainslie 2023 “GQA: Generalized Multi-Query” generalises: grouped-query attention uses g KV heads where 1 ≤ g ≤ h. Each KV head is shared by h/g query heads.
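Here is a minimal sketch of one GQA decode step for a single layer. It assumes that query head i reads KV head i // (h / g), so consecutive query heads share a group, and that the caches are nested lists. The function name is hypothetical. Setting g = h gives MHA and g = 1 gives MQA; the cache holds g KV heads instead of h.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def gqa_decode_step(queries, k_cache, v_cache):
    """One decode step of grouped-query attention for a single layer.

    queries: h query vectors (each length d_k) for the newest token.
    k_cache, v_cache: g lists (one per KV head) of cached vectors, each length d_k.
    """
    h, g = len(queries), len(k_cache)
    assert h % g == 0, "h must be divisible by g"
    group_size = h // g
    outputs = []
    for i, q in enumerate(queries):
        kv = i // group_size  # which shared KV head this query head reads
        keys, values = k_cache[kv], v_cache[kv]
        d_k = len(q)
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys])
        outputs.append([sum(p * v[j] for p, v in zip(w, values)) for j in range(d_k)])
    return outputs

# h = 4 query heads sharing g = 2 KV heads; 3 cached positions, d_k = 2.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
k_cache = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],  # KV head 0 (query heads 0, 1)
           [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]]  # KV head 1 (query heads 2, 3)
v_cache = [[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
           [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]]
out = gqa_decode_step(queries, k_cache, v_cache)
print(len(out))  # 4, one output per query head
```

Query heads 0 and 2 have the same query vector but read different KV heads, so they weight different cached positions most heavily. With a single shared KV head (the MQA case), identical queries necessarily produce identical outputs.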
GQA hits the sweet spot. It is now the default in most major open model families: Llama 2 70B (g=8), Llama 3 (g=8), Mistral 7B (g=8 with sliding window), Qwen 2/3 (g=8 in many sizes). It dominates the design space for two reasons. It saves a factor of h/g in cache vs MHA (8× for a 64-head model at g = 8), and it recovers most of the quality that MQA loses.
KV-cache cost also makes “decoder-only vs encoder-decoder” architecture comparisons subtle. An encoder-decoder model computes K and V for the encoder output once and reuses them, as cross-attention keys and values, at every decoder step. The decoder’s own self-attention cache grows only with the generated tokens. A decoder-only model at the same total parameter count keeps a single cache that grows with the full context (prompt plus generated tokens). GQA partly closes this gap.
Cache size scales linearly in g (the number of distinct KV heads). At g = 64 (MHA), 32K context = 86 GB cache per request. Each halving of g halves the cache.
g = 64 (MHA): 86 GB / request — too big for batched serving.
g = 16: 21.5 GB, still significant (16× the MQA cache).
g = 8: 10.75 GB, 8× smaller than MHA. A batch of 8 needs about 86 GB of cache, which is manageable once the weights are sharded across several GPUs.
g = 4: 5.4 GB — minor additional saving over g = 8.
g = 1 (MQA): 1.35 GB — biggest savings but quality drops noticeably.
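The list above follows from the same formula with H replaced by g. Here is a short sketch under the same hypothetical 70B-class shape (L = 80, d_k = 128, fp16, N = 32K). It computes from the unrounded 85.9 GB rather than 86 GB, so the g = 8 and g = 1 rows come out as 10.74 and 1.34 GB instead of 10.75 and 1.35.

```python
def kv_cache_gb(g, layers=80, seq_len=32768, d_k=128, bytes_per_elem=2):
    """Per-request KV cache in decimal GB with g distinct KV heads."""
    return 2 * layers * g * seq_len * d_k * bytes_per_elem / 1e9

for g in (64, 16, 8, 4, 1):
    print(f"g = {g:2d}: {kv_cache_gb(g):6.2f} GB")
```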
Why g = 8 is the sweet spot:
- Cache savings have diminishing returns past 8×. Going from MHA → g=8 saves 75 GB. Going from g=8 → g=1 saves only another 9 GB. The first 8× swing matters more than the next 8×.
- Quality depends on g/h in a non-linear way. With too few KV heads, head specialisation breaks down: query heads that should learn distinct K patterns are forced to share one KV head and compromise. Ainslie 2023 found that GQA with 8 groups reaches quality close to MHA at decoding speed close to MQA.
- Hardware alignment. For h = 64, g = 8 means h/g = 8 query heads per KV group, a clean SIMD/warp-friendly number for kernel implementations.
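The decision rule: pick g as the LARGEST value where the cache fits comfortably in your serving memory budget at the target batch size. Because g is fixed when the model is trained (or uptrained), this is a design-time choice. Consider 32K-context serving of a 70B Llama-class model. At g = 8, a batch of 8 needs about 86 GB of KV cache on top of 140 GB of fp16 weights, which already calls for several 80 GB H100s. If your context grows to 128K, the same batch at g = 8 needs about 344 GB of cache. If that exceeds your budget, drop to g = 4 or move to MQA; the quality loss is now worth it.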
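The decision rule can be written as a search over the divisors of h. This sketch makes several assumptions: the same hypothetical 70B-class shape, a KV-cache budget given directly in GB (the memory left after the sharded weights), and g must divide h so that groups are equal-sized. It ignores quality, which the rule handles by preferring the largest g that fits. The function names and the 100 GB budget are hypothetical.

```python
def kv_cache_gb(g, seq_len, layers=80, d_k=128, bytes_per_elem=2):
    """Per-request KV cache in decimal GB with g KV heads (hypothetical 70B-class shape)."""
    return 2 * layers * g * seq_len * d_k * bytes_per_elem / 1e9

def largest_g_that_fits(cache_budget_gb, batch, seq_len, h=64):
    """Largest divisor g of h whose KV cache for `batch` requests fits the budget."""
    divisors = [d for d in range(1, h + 1) if h % d == 0]
    for g in sorted(divisors, reverse=True):
        if batch * kv_cache_gb(g, seq_len) <= cache_budget_gb:
            return g
    return None  # not even MQA (g = 1) fits

print(largest_g_that_fits(cache_budget_gb=100, batch=8, seq_len=32768))   # 8
print(largest_g_that_fits(cache_budget_gb=100, batch=8, seq_len=131072))  # 2
```

At a fixed budget and batch, quadrupling the context from 32K to 128K quarters the largest g that fits (8 to 2), since the cache is linear in both N and g.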
This is the kind of decision inference engineers make routinely. The math is mechanical; the budget constraint is what drives the answer.
Next: §13.3, FlashAttention. Now that we have the cost model (N² for the attention matrix, KV cache for inference), the §12.3 online-softmax recurrence becomes the key to streaming this entire computation over blocks instead of materialising the N × N matrix.