Quantization-aware training — STE, BitNet, QLoRA
§24.1 and §24.2 covered post-training quantization (PTQ): train in fp16, then quantize for deployment. That gets you to ~4 bits with a modest perplexity increase. Below 4 bits, PTQ degrades sharply and you need a different approach: quantization-aware training (QAT) puts the quantizer inside the training loop. The forward pass uses quantized weights; gradients flow through as if the quantizer didn’t exist (the straight-through estimator). The network learns weights that are quantization-friendly. This unlocks 2-bit and even 1-bit weights with much smaller quality loss. The same trick, STE through a non-differentiable operation, also powers BitNet (1.58-bit models trained from scratch). A closely related idea, training small full-precision adapters on top of a frozen 4-bit base, gives QLoRA (the technique that lets you fine-tune a 65B-70B-class model on a single 48GB GPU). This section closes the chapter by walking the math, the kernel, and the production techniques.
The straight-through estimator (STE)
The quantization operation Q(W) = scale · round(W / scale) has zero derivative almost everywhere (it’s a step function) and undefined derivative at the step boundaries. Plain backprop through it would produce no gradient. The straight-through estimator sidesteps this by pretending in the backward pass that Q is the identity: the forward pass computes W_q = Q(W), and the backward pass sets ∂L/∂W = ∂L/∂W_q, i.e. it treats ∂Q/∂W as 1.
The STE is “wrong” — Q’s true derivative isn’t 1. But it’s wrong in a way that works empirically: the gradient still points in roughly the right direction (small perturbations of W produce small perturbations of Q(W) on average, even though pointwise it jumps), and the network learns weights that, when quantized, give good outputs.
The fake-quantize-in-forward + STE-in-backward pattern is the entire idea behind QAT. The kernel below trains a 2-layer MLP two ways and measures the deployment cost:
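#include <math.h>   /* fabsf, roundf */

#ifndef QK
#define QK 32       /* block size; 32 is assumed here, the full kernel defines its own */
#endif

/* Symmetric int4 fake-quantize (per-block of QK): round each block to
   integers in [-8, 7] using an absmax scale, then dequantize back to float. */
static void fake_quantize_int4(const float* W, float* W_fq, int n) {
    for (int b = 0; b < n; b += QK) {
        int end = b + QK; if (end > n) end = n;
        /* per-block absmax determines the scale */
        float amax = 0;
        for (int i = b; i < end; i++) {
            float a = fabsf(W[i]); if (a > amax) amax = a;
        }
        float scale = amax / 7.0f;
        if (scale == 0) scale = 1.0f;   /* all-zero block: avoid dividing by zero */
        for (int i = b; i < end; i++) {
            int q = (int)roundf(W[i] / scale);
            if (q > 7) q = 7;
            if (q < -8) q = -8;
            W_fq[i] = q * scale;        /* store the dequantized ("fake-quantized") value */
        }
    }
}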
/* Generate a synthetic regression dataset: y = sin(sum(x)) + noise. */
Output:
Mode 1: train in fp32 (no QAT)
fp32 ep 0 loss = 0.79713
fp32 ep 100 loss = 0.37293
fp32 ep 199 loss = 0.34471
Mode 2: train with QAT (forward pass uses Q(W) via STE)
qat ep 0 loss = 0.80807
qat ep 100 loss = 0.37556
qat ep 199 loss = 0.34672
--- Final eval ---
Model               fp32 deploy   int4 deploy   gap
trained fp32        0.34449       0.35018       +1.65%
trained QAT (STE)   0.34799       0.34424       -1.08%
Read carefully: the fp32-trained model has loss 0.344 in fp32 deployment and 0.350 in int4, a 1.65% degradation. The QAT-trained model has slightly worse fp32 loss (0.348) but no degradation when deployed in int4: its int4 loss (0.344) is 1.08% lower than its own fp32 loss. On the deployment metric (int4), QAT wins, 0.344 against 0.350. It traded a tiny bit of fp32 quality for a model whose weights are pre-adapted to be quantized.
The mathematical problem:
The quantizer Q(W) = scale · round(W / scale) is a step function. Its derivative is 0 almost everywhere and undefined at the step transitions. A pure chain-rule backprop through it would produce zero gradient and the network couldn’t learn.
What STE does:
Pretend Q’s derivative is 1 in the backward pass. The forward uses Q(W); the backward computes gradients as if Q were just the identity.
This is provably “wrong” — Q is not the identity. But for the purpose of training, it gives a workable signal.
Why it works:
(1) On AVERAGE, Q approximates the identity. Pointwise, Q(W + δ) − Q(W) is either 0 or a full step, but averaged over weights that sit at different offsets from the quantization grid, E[Q(W + δ) − Q(W)] ≈ δ: a fraction δ/scale of the weights cross a boundary, and each of those jumps by scale. So while the pointwise derivative is wrong, the expected derivative behaves like 1.
(2) The loss landscape isn’t sharp. Quantization noise produces local fluctuations in the loss, but the BROAD landscape (averaged over many quantization-boundary crossings) is smooth. The STE gradient points in the right direction in this smoothed sense.
(3) The fp32 “shadow” weights matter, not the gradient at the quantization boundary itself. The optimiser updates the fp32 weights W; the quantized weights Q(W) are derived. As W moves smoothly through space, Q(W) jumps occasionally — but the fp32 W’s trajectory is what learns, and STE gives a reasonable signal for that trajectory.
(4) Empirically. Bengio 2013, Hinton 2012 lectures, and a decade of follow-up work showed STE is the workhorse for training low-bit networks. The “wrong” gradient is good enough.
The deeper insight: STE is a particular instance of a broader pattern in deep learning: when a forward operation is non-differentiable, replace its backward pass with the identity (or a smooth surrogate). The same pattern appears in straight-through Gumbel-Softmax (for discrete sampling) and in hard (argmax) attention. REINFORCE tackles the same problem for sampled actions by a different route: its score-function gradient is unbiased but high-variance, rather than biased like the STE. The surrogate gradients are all “wrong” gradients that work.
Learned Step Size Quantization (LSQ)
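A refinement: instead of using a fixed scale per block, make the scale a learnable parameter. The optimizer trains both the weights and the per-block scale (Esser 2019, “Learned Step Size Quantization”). The forward pass is Q(W) = s · clip(round(W / s), −Q_N, Q_P), and the STE is applied to round() so that the step size s receives a usable gradient.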
LSQ lets the model discover the optimal scale per block during training, instead of using the absmax heuristic. Typical gains: 0.2-0.5 perplexity over absmax-based QAT at the same bpw. Standard in modern QAT pipelines for sub-4-bit quantization.
BitNet — 1-bit (and 1.58-bit) from scratch
Wang 2023 “BitNet” took QAT to its extreme: 1-bit weights, trained from scratch.
A 1-bit weight matrix has no per-element scale and only one sign bit per weight. The matmul Y = X · W becomes a sequence of additions and subtractions — no multiplications. On custom hardware, this is dramatically cheaper than even int8 matmul.
Ma 2024 “BitNet b1.58” refined this to three states {−1, 0, +1}, which is 1.58 bits per weight (log₂ 3 ≈ 1.585). The third “0” state lets the network gate connections, which empirically recovers most of the quality gap to fp16. BitNet b1.58 reportedly matches fp16 Llama at sizes ≥ 3B params on standard benchmarks, with 8× memory reduction and similar compute reduction on hardware that exploits the ternary structure.
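The "additions and subtractions only" claim is easy to see in code. The following is a minimal sketch of a dot product against ternary weights; storing one weight per `int8_t` and the function name are assumptions made for readability (a real kernel would pack several weights per byte).

```c
#include <stdint.h>

/* Dot product of activations x with ternary weights w in {-1, 0, +1}.
 * No multiplications: each weight selects add, subtract, or skip. */
static float ternary_dot(const float *x, const int8_t *w, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) {
        if (w[i] > 0)      acc += x[i];   /* +1: add the activation */
        else if (w[i] < 0) acc -= x[i];   /* -1: subtract it */
        /* 0: the connection is gated off, nothing to do */
    }
    return acc;
}
```

The zero branch is the "gating" behaviour the third state buys: a connection can be switched off entirely rather than forced to +1 or −1.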
BitNet b1.58 is the strongest evidence that quantization-aware training can fundamentally change the cost / quality frontier — not just compress an already-trained model, but produce a fundamentally cheaper model that’s as good.
The catch: BitNet has to be trained from scratch. You can’t take a Llama 3 70B fp16 checkpoint and convert it to 1.58-bit without losing massive quality, because the fp16 weights aren’t in a 1.58-bit-friendly configuration. Training from scratch costs about the same as training fp16, because the optimizer still keeps full-precision shadow weights and the backward pass runs in high precision; only the weights used in the forward pass are ternary. So BitNet is a deployment win, not a training win.
The crucial structural difference:
fp16 training settles into a weight distribution with continuous-valued weights spanning a wide range of small magnitudes: a near-continuous Gaussian shape with std ~0.02. This distribution has nothing in common with {−1, 0, +1}. Rounding it to ternary is catastrophic.
1.58-bit training-from-scratch with STE settles into a fundamentally different weight distribution. Throughout training, the model’s effective forward pass uses ternary weights — so the network LEARNS to encode information in ternary form. The continuous “shadow” weights that the optimizer updates are constantly being rounded; they evolve toward configurations that ROUND to good ternary patterns.
What changes in the loss landscape:
fp16 training: gradient descent in a continuous space, exploring a Gaussian-shaped weight distribution.
1.58-bit QAT: gradient descent in the same continuous space BUT the loss is computed using ternary weights. The optimizer learns to find ternary-aligned local minima — points where small perturbations of the shadow weights don’t change the ternary result, but the ternary configuration is locally optimal.
Why this works at all:
The Lottery Ticket Hypothesis (Frankle 2018) and follow-up work showed that sparse / ternary subnetworks WITHIN a dense network can match the dense network’s performance. BitNet finds these structures directly by constraining the search space to ternary throughout training, instead of trying to find them after the fact via pruning + quantization of an already-trained dense model.
Why PTQ fails:
fp16 weights and ternary weights live in different “neighborhoods” of weight space. There’s no smooth path from a typical fp16 weight value (e.g., 0.0184) to its rounded ternary equivalent (0) — the rounding throws away the value. PTQ assumes the two regimes are close; they’re not, for 1.58-bit. QAT navigates the network to a ternary-friendly region of weight space directly.
The training cost: BitNet has to be trained fully from scratch — you cannot “convert” Llama 3 to BitNet b1.58 without retraining. So the win is at deployment (8× memory, much cheaper compute) but the upfront training investment is the same as fp16.
QLoRA — the production trick
The most impactful application of QAT-adjacent techniques is QLoRA — Dettmers 2023 — the technique that put fine-tuning of 70B models within reach of solo researchers.
The data type used for the base is NF4 (NormalFloat 4-bit), which uses 16 quantization levels positioned at the quantiles of a normal distribution. Since LLM weights ARE approximately normally distributed, NF4 represents them more accurately than evenly-spaced int4. The choice is matched to the empirical weight distribution.
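A sketch of what NF4 encoding looks like in practice: scale a block by its absmax, then snap each value to the nearest of 16 fixed levels. The level values below are quoted from the QLoRA paper's appendix and should be checked against the reference implementation before being relied on; the function names and the one-code-per-byte storage are assumptions made for clarity.

```c
#include <math.h>

/* The 16 NF4 levels (quantiles of a standard normal, normalized to [-1, 1]).
 * Values as published with QLoRA; verify against the source before use. */
static const float NF4_LEVELS[16] = {
    -1.0f,       -0.6961928f, -0.5250731f, -0.3949175f,
    -0.2844414f, -0.1847734f, -0.0910500f,  0.0f,
     0.0795803f,  0.1609302f,  0.2461123f,  0.3379152f,
     0.4407098f,  0.5626170f,  0.7229568f,  1.0f
};

/* Index of the NF4 level nearest to v, where v is already scaled into [-1, 1]. */
static int nf4_nearest(float v) {
    int best = 0;
    float best_d = fabsf(v - NF4_LEVELS[0]);
    for (int i = 1; i < 16; i++) {
        float d = fabsf(v - NF4_LEVELS[i]);
        if (d < best_d) { best_d = d; best = i; }
    }
    return best;
}

/* Quantize one block: returns the absmax scale, writes one 4-bit code per byte. */
static float nf4_quantize_block(const float *W, unsigned char *codes, int n) {
    float amax = 0.0f;
    for (int i = 0; i < n; i++) {
        float a = fabsf(W[i]); if (a > amax) amax = a;
    }
    float inv = amax > 0.0f ? 1.0f / amax : 0.0f;
    for (int i = 0; i < n; i++)
        codes[i] = (unsigned char)nf4_nearest(W[i] * inv);
    return amax;
}

/* Dequantize a single code back to float. */
static float nf4_dequantize(unsigned char code, float amax) {
    return NF4_LEVELS[code] * amax;
}
```

Note how the levels crowd together near zero, where a normal distribution puts most of its mass, and spread out toward ±1. Evenly spaced int4 would spend the same resolution on the sparse tails as on the dense centre.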
Why this works as well as fine-tuning the full fp16 model:
- The LoRA adapters B · A have full fp16 precision, so they can represent any necessary low-rank correction to W_base.
- The frozen 4-bit base is “good enough” — its quantization error is in directions the LoRA adapter can correct.
- The base + adapter combination at inference is effectively a higher-precision matrix than 4-bit alone.
The memory math is brutal in QLoRA’s favor. A 70B model:
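A rough version of this arithmetic can be sketched as below. The byte counts are assumptions, not figures from the chapter: fp16 weights and gradients (2 bytes each) plus fp32 Adam moments (8 bytes) for every trained parameter, 0.5 bytes per parameter for the 4-bit base, and activations, quantization constants, and framework overhead all ignored.

```c
/* Rough training-memory estimates in GB (1 GB = 1e9 bytes).
 * Activations and framework overhead are deliberately ignored. */

/* Full fine-tune: fp16 weights (2 B) + fp16 grads (2 B) + fp32 Adam m, v (8 B). */
static double full_finetune_gb(double n_params) {
    return n_params * (2.0 + 2.0 + 8.0) / 1e9;
}

/* QLoRA: frozen 4-bit base (0.5 B/param) plus fully trained adapter parameters. */
static double qlora_gb(double n_params, double n_adapter_params) {
    return (n_params * 0.5 + n_adapter_params * (2.0 + 2.0 + 8.0)) / 1e9;
}
```

Under these assumptions `full_finetune_gb(70e9)` comes to 840 GB, while `qlora_gb(70e9, 0.4e9)` (with a hypothetical 0.4B adapter parameters) comes to about 40 GB, of which 35 GB is the frozen base.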
QLoRA is the workhorse of single-GPU LLM customization. Many “fine-tune your own Llama” tutorials on Hugging Face go through QLoRA.
Tensors and dtypes at fine-tuning time:
For each linear layer L in the base model:
- W_base — 4-bit NF4, frozen. Stored on GPU. Dequantized to fp16 on-the-fly inside the layer’s forward pass.
- A — fp16, trainable. Shape (r, d_out). Small (r typically 16-64).
- B — fp16, trainable. Shape (d_in, r). Initialized to zero so the LoRA adapter starts as a no-op.
- m, v (Adam state) — fp16 or fp32, for A and B only.
- Gradient buffers ∂L/∂A, ∂L/∂B — fp16. For A and B only.
Forward pass for layer L:
W_eff = dequantize_nf4(W_base) + B @ A (W_eff in fp16, held briefly)
Y = X @ W_eff (matmul in fp16)
The dequantized W_base is materialized briefly, used for the matmul, then discarded. Activation X is checkpointed for backward.
Backward pass for layer L:
Incoming: ∂L/∂Y, X.
∂L/∂W_eff = X^T @ ∂L/∂Y (the gradient w.r.t. the effective weight — straightforward matmul backward)
∂L/∂A = B^T @ ∂L/∂W_eff (chain rule through W_eff = … + B @ A)
∂L/∂B = ∂L/∂W_eff @ A^T
∂L/∂W_base = ∂L/∂W_eff ← never stored or applied (W_base is frozen; ∂L/∂W_eff exists only as the intermediate used for A and B)
∂L/∂X = ∂L/∂Y @ W_eff^T (passed back to previous layer)
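The forward and backward equations above can be sketched with naive row-major loops. This is an illustration under stated assumptions, not the QLoRA implementation: `W_deq` stands for the already-dequantized base weights, everything is `float` rather than fp16, and all function names are hypothetical.

```c
/* C[m x n] = A[m x k] @ B[k x n] */
static void mm(const float *A, const float *B, float *C, int m, int k, int n) {
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++) {
            float s = 0.0f;
            for (int p = 0; p < k; p++) s += A[i * k + p] * B[p * n + j];
            C[i * n + j] = s;
        }
}

/* C[k x n] = A^T @ B, with A[m x k] and B[m x n] */
static void mm_tn(const float *A, const float *B, float *C, int m, int k, int n) {
    for (int i = 0; i < k; i++)
        for (int j = 0; j < n; j++) {
            float s = 0.0f;
            for (int p = 0; p < m; p++) s += A[p * k + i] * B[p * n + j];
            C[i * n + j] = s;
        }
}

/* C[m x k] = A @ B^T, with A[m x n] and B[k x n] */
static void mm_nt(const float *A, const float *B, float *C, int m, int n, int k) {
    for (int i = 0; i < m; i++)
        for (int j = 0; j < k; j++) {
            float s = 0.0f;
            for (int p = 0; p < n; p++) s += A[i * n + p] * B[j * n + p];
            C[i * k + j] = s;
        }
}

/* Forward: W_eff = W_deq + B @ A, then Y = X @ W_eff.
 * Shapes: X[n x d_in], W_deq[d_in x d_out], B[d_in x r], A[r x d_out]. */
static void lora_forward(const float *X, const float *W_deq, const float *B,
                         const float *A, float *W_eff, float *Y,
                         int n, int d_in, int d_out, int r) {
    mm(B, A, W_eff, d_in, r, d_out);
    for (int i = 0; i < d_in * d_out; i++) W_eff[i] += W_deq[i];
    mm(X, W_eff, Y, n, d_in, d_out);
}

/* Backward: gradients for A, B and X only; nothing is produced for the base. */
static void lora_backward(const float *X, const float *W_eff, const float *B,
                          const float *A, const float *dY, float *dW_eff,
                          float *dA, float *dB, float *dX,
                          int n, int d_in, int d_out, int r) {
    mm_tn(X, dY, dW_eff, n, d_in, d_out);   /* dL/dW_eff = X^T @ dY (scratch) */
    mm_tn(B, dW_eff, dA, d_in, r, d_out);   /* dL/dA = B^T @ dL/dW_eff */
    mm_nt(dW_eff, A, dB, d_in, d_out, r);   /* dL/dB = dL/dW_eff @ A^T */
    mm_nt(dY, W_eff, dX, n, d_out, d_in);   /* dL/dX = dY @ W_eff^T */
}
```

One consequence of the zero initialization of B is visible here: on the first step ∂L/∂A = Bᵀ · ∂L/∂W_eff is exactly zero, so only B moves at first, and A starts receiving gradient once B is non-zero.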
What’s critical:
- The 4-bit W_base is NEVER updated. We only read it during forward (and to compute ∂L/∂X during backward), but we don’t track gradients for it.
- The dequantization happens on-the-fly during forward — no fp16 copy of W_base is stored long-term. This is the memory win.
- The LoRA adapters A, B are tiny (rank r ≈ 16) but full-precision (fp16). They learn to compensate for both the base model’s deficiencies on the new task AND the quantization noise in W_base.
- The optimizer only updates A and B. Adam state, gradients, and parameter copies are needed only for these — typically <1% of the parameter count.
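The "<1%" figure follows directly from the adapter shapes. A minimal sketch of the arithmetic for one linear layer, with a hypothetical function name and illustrative dimensions:

```c
/* Fraction of a d_in x d_out layer's parameter count that a rank-r adapter
 * trains: A is r x d_out and B is d_in x r, so r * (d_in + d_out) in total. */
static double lora_param_fraction(double d_in, double d_out, double r) {
    return r * (d_in + d_out) / (d_in * d_out);
}
```

For a square 8192 × 8192 projection with r = 16, this gives 262,144 trainable parameters against 67,108,864 frozen ones, a fraction of 1/256 or about 0.39%.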
Why this matches full fine-tuning quality:
The LoRA adapters can express the same range of weight modifications as full fine-tuning at the granularity that matters for the fine-tuning task (low-rank updates capture most task-specific changes). The 4-bit base introduces a ~0.04 perplexity increase vs the fp16 base, which is recovered by the LoRA adapters. Net: QLoRA fine-tunes match full fp16 fine-tunes within ~0.1 perplexity at a small fraction of the memory cost.
QLoRA is arguably the most important deployment-side training technique of the post-Chinchilla era. A large share of the Llama derivatives on Hugging Face that say “fine-tuned” were probably QLoRA-fine-tuned.
The picture across the chapter
Three layers of quantization, three different scopes:
Where each lives in production:
- PTQ (q4_K_M, AWQ, GPTQ): every llama.cpp deployment, every Hugging Face quantized model. Default for inference.
- QAT: research and bespoke deployments, increasingly mainstream for 2-3 bit production targets. Apple’s foundation models reportedly use it.
- BitNet b1.58: still mostly research as of 2026, but Microsoft’s 100B-class BitNet 4.5 model is in active development and may ship.
- QLoRA: every fine-tune-your-own tutorial; every research project that customises a 70B model; the entire “fine-tune Llama” ecosystem on Hugging Face.
END OF CH.24 — Quantization.
§1 (PTQ basics, blockwise vs per-tensor, outliers, LLM.int8 → GPTQ → AWQ) ·
§2 (the GGML family: q4_0 to q6_K with exact byte layouts, q4_K_M model naming, IQ-quants + imatrix) ·
§3 (QAT + STE, BitNet b1.58, QLoRA — the workhorses of low-bit training and single-GPU fine-tuning).
Coming next: back to the normal chapter order. Ch.16 — Pretraining, Chinchilla scaling laws, the token budget.