SOFTMAX & THE EXPONENTIAL FAMILY
Section 12.3

Online softmax — the FlashAttention key

The whole conversation that started this book was about FlashAttention’s tiling: being able to compute attention without materialising the full N × N attention matrix. The Ch.2 §2 three-axis story explained why matrix multiplication tiles freely: independent loops over M, N and K, with K the only “non-trivial” reduction axis. But attention has a softmax between two matmuls, and softmax is, naïvely, a global operation: it needs the max across all input positions, then a sum across all input positions, before any output can be computed. It looks impossible to tile. The breakthrough, which underpins the attention kernels behind most long-context LLMs, is that softmax CAN be computed by streaming over blocks of the input while maintaining two running scalars per row: the max and the sum-of-exps. This section derives the identity that makes this work and runs it in code. Once you have online softmax, FlashAttention falls out almost mechanically, which is what Ch.13 shows.

The reduction problem

Recall §12.1’s stable softmax:

$$m = \max_j z_j \qquad \text{(first pass: find the max)}$$
$$\ell = \sum_j \exp(z_j - m) \qquad \text{(second pass: sum the shifted exponentials)}$$
$$p_i = \exp(z_i - m) \,/\, \ell \qquad \text{(third pass: compute probabilities)}$$

Three passes over the input. For N = 64K logits, that is roughly 192K element reads (three per logit) just for the softmax, plus the output writes and the actual cost of computing exp.

For tiled / block-streaming computation, this is a problem: the second pass needs m from the first pass, and the third pass needs ℓ from the second. There’s no obvious way to process the input block by block, because each block needs the global m first, which requires seeing every block.

This is essentially the same problem Ch.2 §2 named: softmax is a non-associative reduction. The max operation is associative (you can combine partial maxes), but the sum-of-exps with max-subtraction is not, at least not naïvely: each partial sum is shifted by whichever max was current when it was computed, and that max changes as new blocks arrive. Adding partial sums computed against different shifts gives the wrong answer. Naïve tiling breaks the math.
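A two-element example makes the failure concrete. Take z = (0, 10), split into two one-element blocks. Each block’s local max is its only element, so each local sum-of-exps is exp(0) = 1, and adding them gives the wrong normaliser:

$$
\begin{aligned}
\ell^{\text{loc}}_1 + \ell^{\text{loc}}_2 &= \exp(0 - 0) + \exp(10 - 10) = 2,\\
\ell &= \exp(0 - 10) + \exp(10 - 10) = e^{-10} + 1 \approx 1.0000454.
\end{aligned}
$$

The first partial sum is expressed relative to a shift of 0 and the second relative to a shift of 10, so they cannot be added as-is; the first one has to be multiplied by exp(0 − 10) before the addition.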

The online identity

Milakov & Gimelshein 2018 (“Online normalizer calculation for softmax,” arXiv:1805.02867) proved that the softmax CAN be computed by streaming over the input, with a clever update rule that handles the changing max.

Let’s see why. Suppose we’ve processed blocks 1, 2, …, t and have:

$$m_t = \max_{j \,\in\, \text{blocks } 1..t} z_j \qquad \text{(running max)}$$
$$\ell_t = \sum_{j \,\in\, \text{blocks } 1..t} \exp(z_j - m_t) \qquad \text{(running sum-of-exps, normalised by the current max)}$$

Now block t+1 arrives. Compute its local max and local sum-of-exps:

$$m^{\text{loc}}_{t+1} = \max_{j \,\in\, \text{block } t+1} z_j$$
$$\ell^{\text{loc}}_{t+1} = \sum_{j \,\in\, \text{block } t+1} \exp\bigl(z_j - m^{\text{loc}}_{t+1}\bigr)$$

The new running max is the max of the old running max and the block’s max:

$$m_{\text{new}} = \max\bigl(m_t,\; m^{\text{loc}}_{t+1}\bigr)$$

Now we need to rescale BOTH the old running sum (its exponents were shifted by the old max m_t) and the block’s sum (its exponents were shifted by the block’s max m_{t+1}^{loc}) so that both use the new shift m_new:

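$$\ell_{\text{new}} = \ell_t \exp\bigl(m_t - m_{\text{new}}\bigr) + \ell^{\text{loc}}_{t+1} \exp\bigl(m^{\text{loc}}_{t+1} - m_{\text{new}}\bigr)$$

The trick, term by term:

$$\underbrace{\exp(z_j - m_t)}_{\text{old normalisation}} \cdot \underbrace{\exp(m_t - m_{\text{new}})}_{\text{rescale}} = \underbrace{\exp(z_j - m_{\text{new}})}_{\text{correctly normalised by new max}}$$

Same for the block:

$$\exp\bigl(z_j - m^{\text{loc}}_{t+1}\bigr) \cdot \exp\bigl(m^{\text{loc}}_{t+1} - m_{\text{new}}\bigr) = \exp(z_j - m_{\text{new}})$$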

Both terms are now correctly normalised by m_new, and we can sum them. The running state (m_t, ℓ_t) is updated to (m_new, ℓ_new) and we move on to the next block.

Final probabilities are computed at the end with one more pass over the input:

$$p_i = \exp(z_i - m_{\text{final}}) \,/\, \ell_{\text{final}}$$

That is the entire mechanism of online softmax: two scalars of state per row, one rescaling step per block, and no need to see the whole input before processing the first block.

[Interactive figure: online softmax over 16 example logits, z = (−0.54, 1.03, 3.95, 4.50, −5.67, 2.24, −2.43, 3.40, 1.30, 0.02, 0.44, −3.50, 3.61, 1.41, 0.44, 2.94), processed in blocks. Final running state vs naïve full-batch: m = 4.500 (online) and 4.500 (naïve); ℓ = 2.80141 (online) and 2.80141 (naïve).]
Online softmax steps through the N input positions in blocks of B. Each block produces a local max and a local sum-of-exps, and the running state (m, ℓ) is updated with the two-term rescale formula. After the last block, the running (m, ℓ) matches the naïve full-batch computation for any block size (exactly in real arithmetic, up to rounding in floating point). That identity is the structural fact FlashAttention (Ch.13) uses to compute attention in O(N) memory instead of O(N²).


— think, then check —

(1) Compute the block’s local max m_loc and local sum-of-exps ℓ_loc = Σ exp(z − m_loc).

(2) New running max: m_new = max(m, m_loc).

(3) Rescale and combine: ℓ_new = ℓ · exp(m − m_new) + ℓ_loc · exp(m_loc − m_new).

(4) Update: m ← m_new, ℓ ← ℓ_new.

The exp(m − m_new) factor ‘re-shifts’ the old running sum from its previous max basis to the new max basis. Same for the block. After the rescale, both terms are normalised by m_new and can be added. Final probabilities: p_i = exp(z_i − m_final) / ℓ_final.

Two scalars of state per row. Computing a block’s local (m_loc, ℓ_loc) costs O(B); folding it into the running state is then a constant-time update. Block size doesn’t affect correctness, only memory and parallelism. The smaller the block, the more granular the streaming.

Now make it run

The kernel processes the same 1024-element input with seven different block sizes and compares to naïve full-batch softmax:

online_softmax.c (C): softmax_online vs. naïve full-batch softmax
#include <math.h>

/* Naïve numerically-stable softmax: full batch, three passes. */
static void softmax_naive(const float* z, float* out, int N) {
    float m = -INFINITY;
    for (int i = 0; i < N; i++) if (z[i] > m) m = z[i];
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        out[i] = expf(z[i] - m);
        sum += out[i];
    }
    for (int i = 0; i < N; i++) out[i] /= sum;
}

/* Online softmax: process blocks of size B (B >= 1). */
static void softmax_online(const float* z, float* out, int N, int B) {
    float m = -INFINITY;  /* running max */
    float ell = 0.0f;     /* running sum of exp(z_i - m) */

    /* Two passes: pass 1 builds the running (m, ℓ); pass 2 produces the probs. */
    for (int start = 0; start < N; start += B) {
        int end = start + B; if (end > N) end = N;
        float m_block = -INFINITY;
        for (int i = start; i < end; i++) if (z[i] > m_block) m_block = z[i];
        float ell_block = 0.0f;
        for (int i = start; i < end; i++) ell_block += expf(z[i] - m_block);

        /* Rescale both sums to the new max basis, then combine. On the first
           block m = -INFINITY, so expf(m - m_new) = 0 and the old term vanishes. */
        float m_new = fmaxf(m, m_block);
        ell = ell * expf(m - m_new) + ell_block * expf(m_block - m_new);
        m = m_new;
    }

    /* Pass 2: normalise every element against the final (m, ℓ). */
    for (int i = 0; i < N; i++) out[i] = expf(z[i] - m) / ell;
}

The output:

Online (streaming) softmax — block size sweep, N = 1024
block size   max |online−naive|    total prob
1            0.00e+00              1.000001
2            7.15e-07              1.000000
8            4.77e-07              1.000000
32           5.36e-07              1.000000
128          4.77e-07              1.000000
512          3.58e-07              1.000000
1024         0.00e+00              1.000001

At every block size, from 1 (per-element streaming) to 1024 (whole batch), the online output agrees with the naïve full-batch output to within float roundoff (at most about 7 × 10⁻⁷). The probabilities sum to 1 within 10⁻⁶. Block size doesn’t change correctness, only memory and parallelism.

Why this unlocks FlashAttention

Attention is conceptually QKᵀ, then softmax, then a matmul with V. The N×N attention matrix is the bottleneck: for N = 64K, that matrix alone is 16 GiB in float32 and 8 GiB in bfloat16. It can’t fit in GPU SRAM (the fast on-chip memory), so a standard implementation writes it to HBM and reads it back.

FlashAttention’s trick is to process the keys and values in blocks. For each query position:

  1. Iterate over key/value blocks one at a time.
  2. Compute the partial attention scores for the current key block: s = q · K_block^T.
  3. Update the running softmax statistics (m, ℓ) using the online identity.
  4. Update the running output O: rescale it by exp(m − m_new), the same factor applied to ℓ, then add exp(s − m_new) · V_block.

The key insight: the same online-softmax recurrence applies, plus a parallel rescaling of the running output. Each block contributes to m, ℓ, and O; the per-block work fits entirely in SRAM; the full N×N attention matrix never materialises.

Memory cost drops from O(N²) to O(N). Compute cost stays essentially the same (about 4N²d flops per head: 2N²d for QKᵀ and 2N²d for the product with V, plus a small rescaling overhead), but you no longer pay the bandwidth cost of writing the N×N matrix to HBM and reading it back.

We’ll derive the full FlashAttention algorithm in Ch.13 — but the heart of it is now sitting on this page. Online softmax + a running output buffer = FlashAttention.

— think, then check —

What online softmax avoids materialising: the N × N attention score matrix s = QK^T, and the corresponding N × N softmax weight matrix. With online softmax, FlashAttention processes K and V in blocks and maintains per-query-row running state (m, ℓ, O). Each block’s contribution to the output is computed and combined into the running state, then discarded. The full N × N matrix never exists in memory.

Why this is different from vanilla matmul tiling: Ch.2 §2 said matmul tiles freely along M and N (independent loops) and is awkward along K because K is the contraction axis. K-tiling produces partial sums that get added.

For attention’s softmax, the K axis (the sequence length you’re attending OVER) is similarly awkward — but the reduction is the softmax-normalised sum, which is non-associative without correction. You can’t just sum partial contributions; the softmax denominator depends on the global max, which changes as more blocks arrive.

The online softmax recurrence is the correction: it carries the running (m, ℓ) and renormalises BOTH the old partial output AND the new block’s contribution to the same max basis before combining. After this correction, the non-associative reduction becomes associative-with-state: each block is combined via the same rescaling formula, in any order, and the result is mathematically identical to processing all of K at once (in floating point it agrees up to rounding, as the block-size sweep above shows).

So FlashAttention isn’t ‘just tiling’; it’s tiling-with-state. The state is (m, ℓ, O) per query row, and the rescaling rule for the state on each new block is what makes the math correct. Vanilla matmul K-tiling carries nothing beyond the output accumulator itself; FlashAttention carries 2 + d_v scalars per query row (m, ℓ, and the d_v-dimensional O). That extra state, and the rescaling it requires, is what the algorithm pays to make softmax tileable.

This is the formal answer to the original conversation’s question: the softmax reduction was the obstacle to attention tiling; the online-softmax recurrence is the unlock. Once you have it, attention reduces to a stream of small block-matmuls plus running-state updates — and that’s the FlashAttention kernel.

— think, then check —

Online softmax fuses the FIRST two passes (find the max, sum the exponentials) into one streaming pass that maintains the running (m, ℓ). The THIRD pass, dividing by the final ℓ, is unavoidable for the output probabilities themselves: you need the global ℓ before normalising any output.

But in FlashAttention the third pass isn’t needed, because attention’s output isn’t the probabilities themselves; it’s the weighted sum of V vectors. You can carry a running O (the unnormalised, exp-weighted sum of V rows) along with (m, ℓ), and rescale O on each block update. After the last block, O is divided by the final ℓ once, but O is a d-dimensional vector, not an N-dimensional probability distribution. So the per-row cost at the end is O(d), not O(N).

This is why FlashAttention’s extra memory is O(N) instead of O(N²): no N-element probability row (and so no N × N matrix) is ever materialised, only an O(d) running output per query row.

Hardware impact: GPU SRAM (per-SM shared memory) is fast but small, around 100 KB per streaming multiprocessor on many recent consumer GPUs and up to 228 KB on H100. The N-dim softmax probability buffer for N = 64K is 256 KB in float32, which doesn’t fit. The d-dim FlashAttention running output for d = 128 is 512 bytes, which fits easily and runs at SRAM speed. Keeping the whole computation in SRAM, instead of round-tripping the score matrix through HBM, is what makes FlashAttention several times faster in wall-clock time than standard attention at long sequence lengths.

END OF CH.12 — Softmax & the exponential family.
§1 (softmax + max-subtraction stability) · §2 (cross-entropy = KL divergence + entropy; label smoothing, perplexity) · §3 (online softmax — the FlashAttention key).

We now have all the pieces for attention. The math: Ch.4 §3 (chain rule), Ch.2 §2 (matmul tiling on three axes), §12.3 (online softmax). The systems: Ch.2 §4 (microkernel pattern), Ch.3 §3 (int8 SIMD), Ch.20 will revisit HBM/SRAM. Ch.13 puts them all together into the attention layer and FlashAttention algorithm — closing the loop on the original conversation that started this book.

Coming next: Ch.13 — Attention, fully assembled. The conceptual capstone.