Pretraining — next-token prediction at trillion-token scale
From a math standpoint, pretraining is the simplest part of an LLM. The architecture is fixed (Ch.15: pre-norm decoder-only, RMSNorm, RoPE, SwiGLU); the objective is fixed (next-token cross-entropy from Ch.12 §2); the optimizer is fixed (AdamW with linear warmup + cosine decay, Ch.8). What changes is the scale: a 70B Llama-class run is on the order of 10²⁵ floating-point operations (about 6 FLOPs per parameter per token, so 6 × 7×10¹⁰ parameters × a token budget of roughly 1.5×10¹³ ≈ 6×10²⁴). This section walks the actual training loop with real-shape numbers, explains the mixed-precision trick that makes bf16 training stable, and counts how many bits move between GPUs per step (the answer surprises people). Most of the engineering complexity is invisible: it lives in the data pipeline (§16.2) and the scaling-law-driven budget (§16.3).
The training loop
Real pretraining loops are remarkably uniform across labs. Llama 3, Mistral, Qwen: all run the same basic structure, an inner loop over micro-batches nested inside an outer loop over gradient steps.
Two details that look small but dominate the engineering:
Mixed precision training means the forward and backward passes run in bf16 (or fp16), but the optimizer maintains an fp32 “master copy” of every parameter, and the gradient accumulation buffer is fp32. Compared with fp32, the bf16 passes are roughly twice as fast and use half the memory; the fp32 master prevents the slow accumulation drift that a pure bf16 optimizer would suffer over the many optimizer steps of a full run.
Gradient accumulation is how you train with a 4M-token effective batch when only 32K tokens fit in any single GPU. Each micro-batch’s forward + backward computes a partial gradient; you don’t step the optimizer until you’ve accumulated 128 micro-batches’ worth (or whatever the ratio is). Activation memory is bounded by the micro-batch size; the effective batch size is whatever you want.
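The equivalence behind gradient accumulation can be checked on a toy problem. The sketch below is illustrative only: it assumes a one-parameter model `y = w * x` with mean-squared-error loss and equal-sized micro-batches, and the function names are hypothetical.

```python
# Toy model: predict y = w * x, loss = mean squared error over the batch.
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def accumulated_grad(w, xs, ys, micro_batch_size):
    """Process the batch in micro-batches, accumulating a running gradient."""
    assert len(xs) % micro_batch_size == 0, "sketch assumes equal-sized micro-batches"
    num_micro = len(xs) // micro_batch_size
    grad_buffer = 0.0  # stands in for the fp32 accumulation buffer
    for i in range(num_micro):
        lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
        # Each micro-batch gradient is a mean over its own examples, so it is
        # scaled by 1/num_micro to make the running sum a mean over the full batch.
        grad_buffer += grad_mse(w, xs[lo:hi], ys[lo:hi]) / num_micro
    return grad_buffer

xs = [float(i) for i in range(1, 9)]
ys = [3.0 * x + 0.5 for x in xs]
full = grad_mse(0.0, xs, ys)                               # one pass over all 8 examples
accum = accumulated_grad(0.0, xs, ys, micro_batch_size=2)  # 4 micro-batches of 2
```

Only one micro-batch of activations is alive at a time, yet `accum` matches `full` up to floating-point rounding, which is why the effective batch size can be chosen independently of per-GPU memory.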
Once per micro-batch (inner loop):
- Sample text chunk; tokenize; move to GPU.
- Forward pass in bf16; compute cross-entropy.
- Backward pass; accumulate gradients into fp32 buffer.
Once per gradient step (outer loop, after N micro-batches):
- all_reduce(gradients) — sync across all data-parallel ranks (this is the bandwidth-dominant step).
- grad_clip + optimizer.step() — apply AdamW update to fp32 master copy.
- scheduler.step() — advance learning rate (cosine decay).
- checkpoint (every few hundred steps) — save state for resumability.
The split matters because the per-micro-batch operations are local (no cross-GPU communication), while the per-step operations require all-reduce (slow, communication-bound). Maximising work per step amortises the communication cost.
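The inner/outer structure above can be sketched in a few lines. This is a simplified stand-in, not any lab's training code: plain SGD replaces AdamW, the all_reduce is only marked by a comment, and the warmup length, peak learning rate, and function names are assumptions chosen for illustration.

```python
import math

def lr_at(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm):
    """Rescale the whole gradient vector if its L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grads]

def train(params, micro_batch_grad, num_steps, accum_steps, peak_lr=0.1):
    """Outer loop over gradient steps, inner loop over micro-batches."""
    for step in range(num_steps):
        grad_buffer = [0.0] * len(params)
        for micro in range(accum_steps):  # inner loop: local work only
            g = micro_batch_grad(params, step, micro)
            grad_buffer = [b + gi / accum_steps for b, gi in zip(grad_buffer, g)]
        # In a real run, the all_reduce across data-parallel ranks happens here.
        grad_buffer = clip_by_global_norm(grad_buffer, max_norm=1.0)
        lr = lr_at(step, peak_lr, warmup_steps=10, total_steps=num_steps)
        # Plain SGD update, used here as a stand-in for AdamW.
        params = [p - lr * g for p, g in zip(params, grad_buffer)]
    return params

# Toy usage: minimise sum(p^2); its gradient is 2p for every micro-batch.
final = train([1.0, -2.0], lambda p, step, micro: [2.0 * x for x in p],
              num_steps=200, accum_steps=4)
```

Everything inside the inner loop is local to one rank; the clip, schedule, and update run once per gradient step, which is where the cross-GPU synchronisation would sit.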
The numerical recipe
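The “bf16 everywhere” version saves memory, but training can diverge past ~50K steps. The fp32 master copy, fp32 gradient accumulation, and fp32 Adam state are what give mixed precision its stability. bf16 has the same exponent range as fp32, so the problem is precision rather than range: Adam’s running statistics (especially the v term, an exponential average of squared gradients) receive per-step increments that can be smaller than the bf16 spacing around their current value, and those increments are rounded away unless the state is kept in fp32.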
What actually happens at one step
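Pretending you’re a 70B Llama in the middle of training, one step takes on the order of a second of wall-clock on a 16K-H100 cluster: a 16M-token batch (the size worked out in the setup below) costs about 6 × 7×10¹⁰ × 1.6×10⁷ ≈ 7×10¹⁸ FLOPs, and 16K H100s sustain several 10¹⁸ FLOP/s if you assume roughly 40% utilisation of their dense bf16 peak. (A step time of about 6 seconds is what the same batch costs a 405B-class model.) In that time, every data-parallel rank runs its forward and backward passes, the gradients are synchronised, and the optimizer takes one step.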
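A step time can be sanity-checked against the common approximation of about 6 FLOPs per parameter per token. The sketch below assumes a dense bf16 peak of roughly 989 TFLOP/s per H100 and takes utilisation (MFU) as an input; both are assumptions for illustration rather than figures from this section, and the function names are hypothetical.

```python
def flops_per_step(num_params, tokens_per_step):
    """Approximate training FLOPs: ~6 per parameter per token (forward + backward)."""
    return 6.0 * num_params * tokens_per_step

def step_seconds(num_params, tokens_per_step, num_gpus, peak_flops_per_gpu, mfu):
    """Wall-clock estimate for an assumed model FLOPs utilisation (MFU) in (0, 1]."""
    sustained = num_gpus * peak_flops_per_gpu * mfu
    return flops_per_step(num_params, tokens_per_step) / sustained

def implied_mfu(num_params, tokens_per_step, num_gpus, peak_flops_per_gpu, seconds):
    """Utilisation that a measured step time would imply."""
    return flops_per_step(num_params, tokens_per_step) / (
        num_gpus * peak_flops_per_gpu * seconds)

H100_PEAK_BF16 = 989e12        # assumed dense bf16 peak per GPU, in FLOP/s
TOKENS_PER_STEP = 16 * 2**20   # a 16M-token batch

work = flops_per_step(70e9, TOKENS_PER_STEP)
t = step_seconds(70e9, TOKENS_PER_STEP, 16384, H100_PEAK_BF16, mfu=0.4)
```

`implied_mfu` runs the same arithmetic in reverse: given a measured step time, it reports what fraction of peak the cluster would have to be achieving, which is a quick way to tell whether a quoted step time is plausible for a given model size.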
The all-reduce is interesting. Modern training uses ZeRO / FSDP to shard parameters, gradients, and optimizer state across the cluster: each GPU holds only 1/N of the model, parameters are gathered just-in-time per layer during forward, and the same shard structure absorbs the optimizer step. This is what lets a 70B model train at all: at 16 to 18 bytes per parameter, one full copy of weights, gradients, and optimizer state is over 1 TB, far more than the 80 GB on a single H100, and replicating it on every one of 16K GPUs would take roughly 20 PB.
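The memory pressure that motivates sharding can be tallied directly. The byte counts below follow the precision choices described in this section (bf16 weights, fp32 master copy, fp32 gradient buffer, fp32 Adam moments); the even split across all GPUs is an idealisation that ignores activations, communication buffers, and the way tensor and pipeline parallelism change the shard layout.

```python
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "fp32 master weights": 4,
    "fp32 gradient buffer": 4,
    "fp32 Adam m": 4,
    "fp32 Adam v": 4,
}

def replica_bytes(num_params):
    """Bytes for one complete, unsharded copy of the training state."""
    return num_params * sum(BYTES_PER_PARAM.values())

def per_gpu_bytes(num_params, shard_count):
    """Idealised even split of the training state across shard_count GPUs."""
    return replica_bytes(num_params) / shard_count

full = replica_bytes(70e9)             # one unsharded copy, in bytes
sharded = per_gpu_bytes(70e9, 16384)   # per-GPU share if split across every GPU
```

Under these assumptions `full` is about 1.26 TB, so no single 80 GB device can hold a replica, while the sharded share is tens of megabytes per GPU; in practice activations and gathered parameters dominate what each GPU actually holds.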
What goes wrong without fp32 master:
bf16 has 7 explicit mantissa bits (8 bits of precision counting the implicit leading 1), which is about 2 to 3 decimal digits. The spacing between adjacent representable values just above 1.0 is 2⁻⁷ ≈ 0.008.
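AdamW’s update step is on the order of lr × m / √v ≈ 1e-4 × O(1) = 1e-4 per step. With round-to-nearest, an update smaller than half the spacing between adjacent bf16 values is lost entirely: the weight rounds back to its original value and the update silently disappears. For a 1e-4 update that happens for weights of magnitude ≥ 2⁻⁵ ≈ 0.03; for weights somewhat smaller than that, the update is applied only coarsely, rounded to a whole number of spacing steps.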
Over millions of training steps, this manifests as “training plateaus” — the model stops improving because most weight updates are being rounded to zero. The model converges to a worse optimum than fp32 would have reached.
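The rounding behaviour described above can be reproduced without a GPU by emulating bf16 in software. The sketch below assumes round-to-nearest-even and relies on bf16 being the top 16 bits of an IEEE float32; `to_bf16` and `update_survives` are hypothetical helper names, and NaN/inf handling is omitted.

```python
import struct

def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (ties to even)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that truncating the low 16 bits rounds to nearest-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def update_survives(weight, delta):
    """True if adding delta moves the weight to a different bf16 value."""
    w = to_bf16(weight)
    return to_bf16(w + delta) != w

lost_at_one = not update_survives(1.0, 1e-4)   # 1e-4 is far below the spacing at 1.0
kept_small = update_survives(0.001, 1e-4)      # spacing near 0.001 is much finer
```

Probing `update_survives` at different weight magnitudes shows where a fixed-size update starts to vanish: large weights ignore it entirely, small weights still move.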
The fp32 master copy fix:
Keep a fp32 copy of every parameter. Each step: bf16 forward + backward produces bf16 gradients; cast gradients to fp32; accumulate in fp32 gradient buffer; AdamW update applied to fp32 master; cast master to bf16 for the next forward.
Memory cost: two copies of the weights (one bf16, one fp32), so ~6 bytes per parameter instead of 2, before counting Adam’s fp32 state. Acceptable.
Benefit: AdamW updates are computed in fp32, not lost to bf16 rounding. The fp32 master accumulates micro-updates correctly over training; the bf16 copy is just a “current best snapshot” used for the next forward.
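The effect of the master copy can be seen in a small simulation. This is a deliberately idealised sketch: it applies the same 1e-4 update on every step (real updates vary in size and sign), emulates bf16 and fp32 in software with round-to-nearest, and uses hypothetical helper names.

```python
import struct

def to_bf16(x):
    """Round to the nearest bfloat16 value (ties to even); NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def to_fp32(x):
    """Round a Python float (fp64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def run(steps, update, use_master):
    """Apply `update` repeatedly to a weight that starts at 1.0."""
    weight_bf16 = to_bf16(1.0)
    master = to_fp32(1.0)
    for _ in range(steps):
        if use_master:
            master = to_fp32(master - update)   # update applied in fp32
            weight_bf16 = to_bf16(master)       # re-cast for the next forward
        else:
            weight_bf16 = to_bf16(weight_bf16 - update)  # update applied in bf16
    return weight_bf16

stuck = run(1000, 1e-4, use_master=False)   # every update rounds away
moved = run(1000, 1e-4, use_master=True)    # updates accumulate in fp32
```

Without the master copy the weight never leaves 1.0; with it, the fp32 value drifts steadily toward 0.9 and the bf16 snapshot follows in coarse steps whenever the master crosses a bf16 rounding boundary.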
Why inference doesn’t have this problem:
Inference does ONE forward pass per token. There’s no optimizer; no millions of cumulative weight updates to add up. bf16 is fine for the forward pass because the per-layer error compounds over only 32-100 layers, not millions of optimization steps. Inference engines routinely use pure bf16 or fp16 with no fp32 master because the failure mode that motivates the master simply doesn’t apply.
This is a recurring theme in deep learning: the optimization process is what’s numerically delicate, not the model evaluation. Inference can be much lower precision than training without quality loss.
Setup:
Cluster: 16384 H100s. Each H100 holds a shard of the model + can do forward + backward on a batch of ~32K tokens at a time.
Four batching dimensions:
- Sequence length per micro-batch: 8K tokens (context length).
- Sequences per micro-batch: 4, so 4 × 8K = 32K tokens per micro-batch per GPU (roughly the activation memory limit).
- Gradient accumulation: 1 (none; one micro-batch is the whole per-GPU contribution). If you wanted a larger effective batch, you would accumulate.
- Data-parallel ranks: 16384 / (TP × PP) = 512 if tensor and pipeline parallel are 8 × 4. So 512 ranks contribute gradients each step.
Effective batch = 32K tokens/rank × 512 ranks = 16M tokens per step.
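The batch arithmetic above is short enough to write out. The sketch below uses the setup's numbers (8K sequence length, 4 sequences per micro-batch, tensor parallel 8, pipeline parallel 4); the function names are hypothetical.

```python
def data_parallel_ranks(num_gpus, tensor_parallel, pipeline_parallel):
    """GPUs are grouped into model replicas of size TP * PP; each replica is one DP rank."""
    model_parallel = tensor_parallel * pipeline_parallel
    assert num_gpus % model_parallel == 0, "GPU count must divide evenly into replicas"
    return num_gpus // model_parallel

def tokens_per_step(seq_len, seqs_per_micro_batch, accum_steps, dp_ranks):
    """Effective batch size in tokens for one optimizer step."""
    return seq_len * seqs_per_micro_batch * accum_steps * dp_ranks

def accum_steps_for(target_tokens, seq_len, seqs_per_micro_batch, dp_ranks):
    """Micro-batches each rank must accumulate to reach a target batch (rounded up)."""
    per_pass = seq_len * seqs_per_micro_batch * dp_ranks
    return -(-target_tokens // per_pass)  # ceiling division

dp = data_parallel_ranks(16384, tensor_parallel=8, pipeline_parallel=4)
batch = tokens_per_step(seq_len=8192, seqs_per_micro_batch=4, accum_steps=1, dp_ranks=dp)
single_gpu_accum = accum_steps_for(4 * 2**20, seq_len=8192, seqs_per_micro_batch=4, dp_ranks=1)
```

`single_gpu_accum` reproduces the earlier single-GPU example: reaching a 4M-token batch from 32K-token micro-batches takes 128 accumulation steps, whereas 512 data-parallel ranks reach 16M tokens with no accumulation at all.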
The cost-throughput trade-off:
- Smaller batch: faster wall-clock per step and more optimizer updates per token processed; but gradient noise dominates the training signal at very small batches. Empirically, batches under ~512K tokens tend to hurt training quality on LLMs.
- Larger batch: better gradient statistics; lower variance; smoother loss curves; but past ~16M tokens, returns diminish — you’re adding compute without proportional learning improvement.
- Communication scaling: all-reduce time grows with DP rank count. At very large DP, communication can dominate compute and decrease utilisation.
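To put numbers on the communication term, one common model is the ring all-reduce, in which each rank sends its payload in chunks around a ring, first as a reduce-scatter and then as an all-gather. The sketch below assumes that algorithm; production libraries choose among several (tree-based variants, for example), so treat it as an estimate rather than a description of any particular stack. The function names are hypothetical.

```python
def ring_allreduce_bytes_sent_per_rank(payload_bytes, num_ranks):
    """Bytes each rank sends in a ring all-reduce (reduce-scatter + all-gather)."""
    if num_ranks == 1:
        return 0.0
    return 2.0 * (num_ranks - 1) / num_ranks * payload_bytes

def ring_allreduce_rounds(num_ranks):
    """Sequential send/receive rounds; this is the latency term."""
    return 2 * (num_ranks - 1)

# Illustrative payload: a 70B model split 32 ways by TP x PP, gradients in fp32.
grad_shard_bytes = 70e9 / 32 * 4
sent = ring_allreduce_bytes_sent_per_rank(grad_shard_bytes, 512)
rounds = ring_allreduce_rounds(512)
```

In this model the bytes each rank sends approach twice the payload and barely change as ranks are added, while the number of sequential rounds grows linearly with rank count; that latency term is one reason very large data-parallel groups lose utilisation.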
16M tokens has been treated as the empirical sweet spot for ~70B-class models in recent runs. Smaller (1-4M) for smaller models; larger (32M+) for 1T+ MoE models.
The deeper point: “batch size” in LLM training is a hardware constraint dressed as a hyperparameter. You pick the largest batch that fits the compute, memory, and communication budget; gradient accumulation is the bridge between per-GPU memory and effective-batch needs. The 16M-token batch in this example isn’t optimal because of math; it is what 16K H100s comfortably process per gradient step.
Next: §16.2 — The data pipeline. How Common Crawl → C4 → RefinedWeb → FineWeb evolved, and why “what tokens you train on” is now the dominant lever in LLM quality.