QKᵀ, scaling, softmax, V; multi-head → GQA/MQA. FlashAttention as tiling + online softmax. The capstone of Part III.
Attention is a content-addressable lookup over a sequence: every position emits a query Q, a key K, and a value V; each query scores every key by a dot product scaled by 1/√d_k; a row-wise softmax turns the scores into weights; and the output is the softmax-weighted sum of values, softmax(QKᵀ/√d_k)V. After the Q, K, V projections, the core is two matmuls (QKᵀ and the product with V) around one softmax. The softmax, however, runs over the rows of an N × N score matrix, and materialising that matrix is the central bottleneck FlashAttention removes.
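A minimal pure-Python sketch of single-head attention as described above, assuming small inputs given as lists of lists. The names `softmax` and `attention` are illustrative, not from the source. It deliberately builds the full score matrix, so its memory grows with N × M.

```python
import math

def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Naive single-head attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: N x d_k, K: M x d_k, V: M x d_v, all lists of lists.
    Materialises the full N x M score matrix.
    """
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    # Matmul 1: scores[i][j] = (q_i . k_j) / sqrt(d_k)
    scores = [[scale * sum(q * k for q, k in zip(qi, kj)) for kj in K] for qi in Q]
    # Row-wise softmax: each query's weights over all keys sum to 1.
    weights = [softmax(row) for row in scores]
    # Matmul 2: output row i = sum_j weights[i][j] * v_j
    d_v = len(V[0])
    return [[sum(w * vj[c] for w, vj in zip(wi, V)) for c in range(d_v)] for wi in weights]

# Usage: one query, two identical keys -> equal weights -> mean of the values.
out = attention([[0.5, -0.5]], [[1.0, 1.0], [1.0, 1.0]], [[2.0], [4.0]])
# out == [[3.0]]
```

With identical keys the softmax weights are equal, so the output is the plain mean of the values; this is a quick check that the normalisation is right.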
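Multi-head attention runs h independent attention computations in parallel, each in a d_k = d_model/h subspace, and concatenates the results before an output projection. Different heads can learn different relations. For autoregressive inference, the K and V tensors of past tokens are cached: with h_kv KV heads the cache holds 2 · N · h_kv · d_k elements per layer (one K and one V per head), which dominates memory at long contexts. MQA (one KV head shared by all query heads) and GQA (g KV heads, each shared by a group of h/g query heads) shrink the cache by a factor of h and h/g respectively (e.g. 4× with 32 query heads and 8 KV heads), with minimal quality loss. GQA is used by most recent open models, including Llama 3, Mistral, and Qwen.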
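A small sketch of the KV-cache arithmetic and of how GQA maps query heads onto shared KV heads. The function names, the contiguous grouping convention, and the example shapes (Llama-3-8B-like: 32 layers, 32 query heads, 8 KV heads, head dimension 128, fp16, 8192-token context) are assumptions for illustration, not figures from the source.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, d_head, bytes_per_elem):
    # Factor 2: one K tensor and one V tensor are cached per layer.
    return 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem

def kv_head_for_query(q_head, n_q_heads, n_kv_heads):
    # GQA (assumed contiguous grouping): query heads are split into n_kv_heads
    # consecutive groups, and every query head in a group reads the same K/V head.
    # n_kv_heads == n_q_heads is ordinary multi-head attention; n_kv_heads == 1 is MQA.
    assert n_q_heads % n_kv_heads == 0
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

# Assumed, illustrative shapes: 32 layers, head dim 128, fp16 (2 bytes), 8192 tokens.
mha = kv_cache_bytes(32, 8192, 32, 128, 2)  # 32 KV heads: 4 GiB
gqa = kv_cache_bytes(32, 8192, 8, 128, 2)   # 8 KV heads:  1 GiB
mqa = kv_cache_bytes(32, 8192, 1, 128, 2)   # 1 KV head:   128 MiB
```

The ratio between configurations is simply h/g, independent of sequence length and layer count, so the saving holds at every context size.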
FlashAttention computes exact attention by tiling Q, K, and V into blocks small enough to fit in GPU on-chip SRAM. The online-softmax recurrence (§12.3), plus a running output buffer O rescaled by the same factor as ℓ whenever the running max changes, eliminates the need to materialise the N × N attention matrix. Extra memory drops from O(N²) to O(N). The kernel in this section verifies that its output matches naïve attention across tile sizes; because tiling changes the order of floating-point operations, the match is to within rounding error rather than bit-for-bit. This is the algorithm introduced by Dao et al. (2022), and it is now widely used in transformer inference engines.
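A pure-Python sketch of the tiling arithmetic, assuming a single head with K and V split into tiles of `block` rows while each query row is processed on its own. Real FlashAttention kernels also tile Q and keep tiles in GPU SRAM; this sketch only demonstrates the recurrence. `naive_attention` and `tiled_attention` are illustrative names. Each query keeps a running max `m`, a running denominator `l` (the ℓ above), and an unnormalised output `o`, all rescaled by exp(m_old − m_new) when a new tile raises the max.

```python
import math

def naive_attention(Q, K, V):
    # Reference: build each full score row, then softmax it.
    scale = 1.0 / math.sqrt(len(K[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / l for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block):
    """Attention over K/V tiles of `block` rows via the online-softmax recurrence.

    Per query row, only m (running max), l (running denominator) and o
    (unnormalised output) are kept; no full score row is ever stored.
    """
    scale = 1.0 / math.sqrt(len(K[0]))
    d_v = len(V[0])
    out = []
    for q in Q:
        m, l, o = -math.inf, 0.0, [0.0] * d_v
        for start in range(0, len(K), block):
            K_tile = K[start:start + block]
            V_tile = V[start:start + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_tile]
            m_new = max(m, max(s))
            # Rescale old accumulators to the new max; exp(-inf) == 0.0 on the first tile.
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            o = [alpha * o[c] + sum(pj * v[c] for pj, v in zip(p, V_tile)) for c in range(d_v)]
            m = m_new
        out.append([x / l for x in o])  # normalise once, after the last tile
    return out

# Usage: compare against the reference on small deterministic inputs.
N, d = 6, 4
Q = [[math.sin(i + 0.5 * j) for j in range(d)] for i in range(N)]
K = [[math.cos(0.7 * i - j) for j in range(d)] for i in range(N)]
V = [[math.sin(0.3 * i * j + 1.0) for j in range(d)] for i in range(N)]
ref = naive_attention(Q, K, V)
for block in (1, 2, 4, 6):
    tiled = tiled_attention(Q, K, V, block)
    err = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
    print(block, err)  # rounding-level differences, not necessarily exactly 0.0
```

The printed differences stay at floating-point rounding level for every tile size, which is the sense in which tiled attention is exact.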