MATRICES AS TRANSFORMATIONS
Section 2.2
02

Matmul as composition; the three axes

§1 made one promise: matrices are functions. Now collect on it. If B and A are two linear functions, applying first B then A gives a new linear function — the composition A ∘ B. That composed function is also a matrix (every linear function is). The matrix that represents it is exactly A · B. Matrix multiplication isn’t a strange notation invented by 19th-century algebraists; it’s the bookkeeping that makes composition compute correctly. Once you see that, three otherwise-arbitrary rules fall into place at once, and a fourth — the rule that drives every fast attention kernel since 2022 — becomes visible.

Composition, then notation

Two linear functions in sequence:

B     : ℝᵏ → ℝⁿ    // takes a k-vector, produces an n-vector
A     : ℝⁿ → ℝᵐ    // takes an n-vector, produces an m-vector
A ∘ B : ℝᵏ → ℝᵐ    // input via B's input, output via A's output

Apply it to an input x ∈ ℝᵏ in two steps: (A ∘ B)(x) = A(Bx). The result lives in ℝᵐ. The intermediate Bx lives in ℝⁿ — the space where B lands and A picks up. That intermediate dimension n is the agreement between the two functions; it’s the only thing they have to share for the composition to make sense.

Now apply the §1 picture: the composed function A ∘ B is linear, so it’s a matrix. Its j-th column is wherever (A ∘ B) sends eⱼ, the j-th basis vector of ℝᵏ. Compute that:

(A ∘ B)(eⱼ) = A(B eⱼ) = A · (j-th column of B)

So column j of the composed matrix is what A does to column j of B. Stack those over all j and you have the full composed matrix:

      ┌                                        ┐
A·B = │ A·col₁(B)   A·col₂(B)   ⋯   A·colₖ(B) │   ← k columns, each in ℝᵐ
      └                                        ┘

That’s matrix multiplication. The “rows-times-columns” mechanics you memorized is what you get when you write out the A·colⱼ(B) step using A’s row picture — every entry of every column is one dot product. But the operation is composition; the rows-and-columns formula is the implementation.
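To make the column picture concrete, here is a minimal sketch that builds A·B one column at a time, exactly as "A applied to column j of B". The names matvec, matmul_by_columns, and the scratch bound MAX_DIM are illustrative choices for this sketch, not taken from the chapter's code; storage is assumed row-major, with A of shape m × n and B of shape n × k as in the composition picture.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_DIM 16  /* illustrative scratch bound for this sketch only */

/* y = A x, where A is m x n (row-major), x has n entries, y has m entries. */
void matvec(const float *A, const float *x, float *y, int m, int n) {
    for (int i = 0; i < m; i++) {
        float s = 0.0f;
        for (int p = 0; p < n; p++)
            s += A[(size_t)i * n + p] * x[p];
        y[i] = s;
    }
}

/* C = A B, built column by column: column j of C is A applied to column j of B.
 * A is m x n, B is n x k, C is m x k, all row-major. */
void matmul_by_columns(const float *A, const float *B, float *C,
                       int m, int n, int k) {
    float col[MAX_DIM], out[MAX_DIM];
    assert(m > 0 && n > 0 && k > 0 && m <= MAX_DIM && n <= MAX_DIM);
    for (int j = 0; j < k; j++) {
        for (int p = 0; p < n; p++)          /* col = B e_j           */
            col[p] = B[(size_t)p * k + j];
        matvec(A, col, out, m, n);           /* out = A (B e_j)       */
        for (int i = 0; i < m; i++)          /* store as column j     */
            C[(size_t)i * k + j] = out[i];
    }
}
```

Writing out the matvec call entry by entry gives back the rows-times-columns formula, which is why this version and the usual triple loop agree.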

The four rules of matmul, derived in one paragraph.
(1) Dimension match. A·B only makes sense when A’s column count equals B’s row count: that is the n they share as functions.
(2) Shape of result. m × n times n × k gives m × k — the outer dims, because the composed function is ℝᵏ → ℝᵐ.
(3) Non-commutativity. A·B ≠ B·A because function composition isn’t commutative — “rotate then scale” is a different function from “scale then rotate.”
(4) Associativity. (A·B)·C = A·(B·C) because ((f∘g)∘h)(x) = (f∘g)(h(x)) = f(g(h(x))) = (f∘(g∘h))(x). Composition is associative by definition.
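Rules (2) and (4) can be checked mechanically. The sketch below uses integer entries so that equality is exact; matmul_int and associativity_demo are hypothetical helper names, and the matrix values are arbitrary. Shapes follow the convention of this subsection: m × n times n × k gives m × k.

```c
#include <assert.h>
#include <stddef.h>

/* C = A B with A m x n, B n x k, C m x k (row-major, integer entries). */
void matmul_int(const int *A, const int *B, int *C, int m, int n, int k) {
    assert(m > 0 && n > 0 && k > 0);
    for (int i = 0; i < m; i++)
        for (int j = 0; j < k; j++) {
            int s = 0;
            for (int p = 0; p < n; p++)
                s += A[(size_t)i * n + p] * B[(size_t)p * k + j];
            C[(size_t)i * k + j] = s;
        }
}

/* Returns 1 when (P Q) R and P (Q R) agree entry by entry.
 * P is 2 x 3, Q is 3 x 4, R is 4 x 2, so both groupings yield 2 x 2. */
int associativity_demo(void) {
    const int P[2 * 3] = {1, 2, 3, 4, 5, 6};
    const int Q[3 * 4] = {1, 0, 2, -1, 0, 1, -1, 2, 3, -2, 0, 1};
    const int R[4 * 2] = {2, 1, 0, -1, 1, 3, -2, 0};
    int PQ[2 * 4], QR[3 * 2], left[2 * 2], right[2 * 2];

    matmul_int(P, Q, PQ, 2, 3, 4);      /* 2x3 times 3x4 -> 2x4 */
    matmul_int(PQ, R, left, 2, 4, 2);   /* 2x4 times 4x2 -> 2x2 */
    matmul_int(Q, R, QR, 3, 4, 2);      /* 3x4 times 4x2 -> 3x2 */
    matmul_int(P, QR, right, 2, 3, 2);  /* 2x3 times 3x2 -> 2x2 */

    for (int t = 0; t < 4; t++)
        if (left[t] != right[t]) return 0;
    return 1;
}
```

The two groupings pass through different intermediate shapes (2 × 4 versus 3 × 2) yet land on the same 2 × 2 result, which is the freedom that makes reordering a chain of products a legitimate optimization.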

— think, then check —

The rightmost: B is applied first, then A. The expression literally reads “A applied to (B applied to x).” This is the opposite of how you read a math expression left-to-right, and it’s the opposite of a Unix pipeline (‘x | B | A’ would mean the same thing in pipe notation). The direction surprises people; once you fix the picture as “rightmost is innermost,” everything from gradient backprop ordering to attention’s QKᵀV ordering makes sense.

The three axes

Matrix multiplication has — count them — three independent loop dimensions. The textbook writes them in the formula:

C[i][j] = Σₚ A[i][p] · B[p][j] for i = 0..m-1, j = 0..n-1, p = 0..k-1

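Three indices, three nested loops, three independent axes. Worth naming explicitly: M is the output-row axis (index i), N is the output-column axis (index j), and K is the contraction axis (index p), the one being summed over. Note the relabeling relative to the composition picture above: here A is m × k, B is k × n, and C is m × n, so the shared dimension that was called n there is called k here.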

The viz makes this visible. Pick an axis, see which tiles of each input the corresponding output tile depends on. Hover an output cell to see its full K-axis contraction.

[MatmulThreeAxes viz: interactive figure, static values below]

A (4×6):
  -2   0   2  -1   1  -2
   1  -2   0   2  -1   1
  -1   1  -2   0   2  -1
   2  -1   1  -2   0   2

B (6×5):
  -2   1  -1   2   0
   0  -2   1  -1   2
   2   0  -2   1  -1
  -1   2   0  -2   1
   1  -1   2   0  -2
  -2   1  -1   2   0

C = AB (4×5):
  14  -7   2  -4  -5
  -7  11  -6   2   0
   2  -6  11  -7   0
  -4   2  -7  14  -5
C[i][j] = Σₚ A[i][p] · B[p][j] ← that Σ is the K-axis
Three matrices, three loop dimensions. M = output rows, N = output columns, K = contraction (the dimension being summed over). Each axis tiles independently — the structural fact that makes FlashAttention possible.

The cell coloring on the inputs shows which entries are read when tiling along the selected axis. Notice: an M-axis tile of C needs only the matching rows of A, but all of B. An N-axis tile of C needs only the matching columns of B, but all of A. A K-axis tile needs a partial slice of both A and B, and its outputs aren’t finished: a K-tile produces a partial sum that the other K-tiles add to.
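The M and N cases can be sketched as a function that computes one rectangular block of C and nothing else. The name gemm_output_tile and its half-open index ranges are assumptions made for this illustration; the layout is row-major with A of shape m × k, B of shape k × n, and C of shape m × n.

```c
#include <assert.h>
#include <stddef.h>

/* Computes only the block of C = A B with rows [i0, i1) and columns [j0, j1).
 * Reads rows [i0, i1) of A and columns [j0, j1) of B over the full K range,
 * and writes finished values: no other output tile touches these entries. */
void gemm_output_tile(const float *A, const float *B, float *C,
                      int m, int n, int k,
                      int i0, int i1, int j0, int j1) {
    assert(0 <= i0 && i0 <= i1 && i1 <= m);
    assert(0 <= j0 && j0 <= j1 && j1 <= n);
    for (int i = i0; i < i1; i++) {
        for (int j = j0; j < j1; j++) {
            float s = 0.0f;
            for (int p = 0; p < k; p++)   /* the whole contraction, untiled */
                s += A[(size_t)i * k + p] * B[(size_t)p * n + j];
            C[(size_t)i * n + j] = s;
        }
    }
}
```

Calling this once per tile, in any order or from different workers, fills C without any tile needing to see another tile's result, because the blocks are disjoint.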

Why the three axes matter

A loop you can tile is a loop you can reorder, parallelize, distribute, or fuse. The structural payoff of seeing three axes is that you can reason about each one independently, and they have different properties: M and N split the output into disjoint pieces, while K splits the sum inside every output entry.

Now make it run — naïve gemm

The textbook formula transliterated, three loops, no cleverness. Here it is — the slowest matmul kernel you will ever write, included so the three axes are unmistakable in source:

gemm_naive (ijk) C, three nested loops · the formula
#include <string.h>

/* C = A * B, all row-major. A is m x k, B is k x n, C is m x n. */
void gemm_naive(const float* A, const float* B, float* C,
                int m, int n, int k) {
    memset(C, 0, (size_t)m * n * sizeof(float));
    for (int i = 0; i < m; i++) {              /* M axis: output rows     */
        for (int j = 0; j < n; j++) {          /* N axis: output cols     */
            float s = 0.0f;
            for (int p = 0; p < k; p++) {      /* K axis: contraction     */
                s += A[(size_t)i * k + p] * B[(size_t)p * n + j];
            }
            C[(size_t)i * n + j] = s;
        }
    }
}

The loops are labeled with the axis each one traverses. Move them around (ikj, kij, jik, and so on): six permutations exist, and all of them produce the same matrix because matmul’s definition doesn’t depend on iteration order (each entry still accumulates its K terms in ascending p, whichever loop is outermost). Different orders have wildly different cache behavior (the topic of §4), but the math doesn’t care:

gemm_naive_kij same matmul, K loop hoisted outermost
/* gemm_naive_kij: K loop outermost. Loop order affects cache behavior (the
 * topic of §4). Included here so the test harness can prove that all six
 * loop orderings produce identical results: matmul doesn't care about
 * loop order, only about the three axes. */
void gemm_naive_kij(const float* A, const float* B, float* C,
                    int m, int n, int k) {
    memset(C, 0, (size_t)m * n * sizeof(float));   /* C accumulates below */
    for (int p = 0; p < k; p++) {              /* K outermost             */
        for (int i = 0; i < m; i++) {          /* M middle                */
            float a_ip = A[(size_t)i * k + p];
            for (int j = 0; j < n; j++) {      /* N innermost             */
                C[(size_t)i * n + j] += a_ip * B[(size_t)p * n + j];
            }
        }
    }
}

The test in code/ch02/test_gemm.c checks three things at once, two of them taken directly from the “four rules” above and made executable:

AB = [7 2; 3 1]
BA = [1 2; 3 7]   <- different, as expected
ijk and kij orderings produce identical 6x7 output
(PQ)R == P(QR)  -- associativity holds (4x2 output)

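Non-commutativity (rule 3), loop-order invariance (a corollary of the axes being independent of iteration order), and associativity (rule 4) are not quirks of these test matrices. Associativity holds for every triple whose shapes conform (exactly in exact arithmetic; floating-point rounding can differ between the two groupings), and loop-order invariance holds for every input. Non-commutativity is the general case: particular pairs, such as any square matrix and the identity, do commute, but you can never assume it. The first and third come from the function-composition view; the second comes from the structure of the sum itself.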

— think, then check —

K is the odd one out. M and N appear in the output (they index C’s rows and columns); K does not. K is the dimension being summed over. A tile along K produces a partial sum — not a finished value — and the partial sums from all K-tiles must be combined before the output is correct.

Mechanically: M-tiles and N-tiles partition C into disjoint pieces; K-tiles partition the computation of every C entry. That’s why parallel matmul typically splits over M and N (no coordination) and resorts to either an additive reduction (sum-K) or carried statistics (online-softmax style) when it splits over K.
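The additive-reduction case can be sketched as a two-way split over K. The names gemm_k_slice and gemm_split_k, the caller-provided scratch buffer, and the choice of exactly two slices are assumptions for this illustration; a real kernel would pick the split and the reduction strategy to suit the hardware.

```c
#include <assert.h>
#include <stddef.h>

/* Adds the contribution of the K-range [p0, p1) to every entry of partial.
 * A is m x k, B is k x n, partial is m x n, all row-major.
 * The caller must zero partial before the first slice. */
void gemm_k_slice(const float *A, const float *B, float *partial,
                  int m, int n, int k, int p0, int p1) {
    assert(0 <= p0 && p0 <= p1 && p1 <= k);
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float s = 0.0f;
            for (int p = p0; p < p1; p++)
                s += A[(size_t)i * k + p] * B[(size_t)p * n + j];
            partial[(size_t)i * n + j] += s;   /* unfinished value */
        }
    }
}

/* Split-K in two halves. Each half could run on its own worker, since the
 * two slices write to separate buffers; the final loop is the reduction. */
void gemm_split_k(const float *A, const float *B, float *C, float *scratch,
                  int m, int n, int k) {
    int half = k / 2;
    size_t count = (size_t)m * n;
    for (size_t t = 0; t < count; t++) {
        C[t] = 0.0f;
        scratch[t] = 0.0f;
    }
    gemm_k_slice(A, B, C, m, n, k, 0, half);        /* first K-tile  */
    gemm_k_slice(A, B, scratch, m, n, k, half, k);  /* second K-tile */
    for (size_t t = 0; t < count; t++)
        C[t] += scratch[t];                         /* additive reduction */
}
```

Neither buffer holds a correct answer until the last loop runs, which is the coordination cost that M and N tiling avoid. Splitting K also changes the order in which terms are added, so for general floating-point inputs the result can differ from the unsplit kernel in the last bits.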

The forward connection — why this section matters for Ch.13

Attention computes S = QKᵀ, then a softmax along each row of S, then O = softmax(S) · V. Both S = QKᵀ and O = softmax(S) · V are matrix multiplications, and tiling them is exactly the three-axis problem above. M and N tiles are free; K is the awkward one, and the two multiplies contract over different dimensions: the head dimension d in S = QKᵀ, and the key sequence length in O = softmax(S) · V.

For S = QKᵀ, you can tile freely along all three axes because nothing downstream of the multiply complicates the reduction yet: K-tile partial sums simply add. For the second multiply, the K-axis of O = softmax(S) · V is the sequence length (conventionally also written N, not to be confused with the N axis above), and that is the axis FlashAttention streams, carrying the softmax running max and running sum across tiles. The strange-looking “online softmax” you may have heard of is just K-axis tiling where the reduction is not a plain sum: what has been accumulated so far must be rescaled whenever a later tile raises the running max. Now you have the substrate to understand it.

— think, then check —

Let R be a 90° rotation and S be a shear along the x-axis: R = [[0,-1],[1,0]] sends e₁ to e₂; S = [[1,1],[0,1]] sends e₂ to e₁ + e₂.

Compute the composed functions on e₁:

(R·S)(e₁) = R(S(e₁)) = R(e₁) = e₂ — the shear leaves e₁ alone, then the rotation lifts it up.

(S·R)(e₁) = S(R(e₁)) = S(e₂) = e₁ + e₂ — the rotation lifts e₁ to e₂, then the shear pushes it sideways.

Two different output vectors, same input. The functions are different functions, and their matrices differ accordingly. Geometrically: rotation rotates wherever its input currently is; shear shears wherever its input currently is. Order matters because each one acts on what it sees, not what was originally there. This is the same reason “put on socks, then shoes” doesn’t compose to the same outfit as “put on shoes, then socks.”
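The same check in code, as a small sketch with 2 × 2 integer matrices stored row-major. The helper names apply2 and compose_apply2 are hypothetical; the argument order of compose_apply2 mirrors the written order of the product, so the second argument is the map that runs first.

```c
#include <assert.h>

/* y = M x for a 2 x 2 integer matrix M (row-major) and a 2-vector x. */
void apply2(const int M[4], const int x[2], int y[2]) {
    assert(x != y);   /* in-place use would overwrite x[0] too early */
    y[0] = M[0] * x[0] + M[1] * x[1];
    y[1] = M[2] * x[0] + M[3] * x[1];
}

/* out = outer(inner(x)): as in outer · inner, the rightmost map runs first. */
void compose_apply2(const int outer[4], const int inner[4],
                    const int x[2], int out[2]) {
    int mid[2];
    apply2(inner, x, mid);    /* innermost function first   */
    apply2(outer, mid, out);  /* then the one written left  */
}
```

With R = {0, -1, 1, 0}, S = {1, 1, 0, 1}, and e₁ = {1, 0}, calling compose_apply2(R, S, e1, out) yields {0, 1}, which is e₂, while compose_apply2(S, R, e1, out) yields {1, 1}, which is e₁ + e₂.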

END OF CH.2 §2 — Matmul as composition; the three axes.
Built: MatmulThreeAxes viz (hover an output cell to see its K-axis contraction; click an axis to see how it tiles), gemm_naive.c with two loop orderings, a test that proves non-commutativity, loop-order invariance, and associativity. Three recall items.
Coming next: §2.3 — Orthogonal and rotation matrices.