Chapter 10

Training Kernels

Understanding backward passes, mixed-precision training, and memory optimization techniques that make training large models possible on limited GPU memory.

What You'll Learn
  1. Implement backward passes for common operations
  2. Apply mixed-precision training with loss scaling
  3. Use activation checkpointing to trade compute for memory
  4. Implement gradient accumulation for larger effective batch sizes
  5. Debug training-specific numerical issues
01 - FORWARD VS BACKWARD

The Computational Graph

Training neural networks requires computing gradients via automatic differentiation. For every forward operation, there's a corresponding backward operation that propagates gradients.

Memory Requirements

Training consumes far more memory than inference because you must store:

  • Weights — the model parameters
  • Gradients — same size as the weights
  • Optimizer states — 2x the weights for Adam
  • Activations — scales with batch size × number of layers

For a 7B parameter model in FP32, the ZeRO paper breaks down memory consumption:

  Component   | Size
  Weights     | 28 GB
  Gradients   | 28 GB
  Adam states | 56 GB
  Total (min) | 112 GB
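To make the arithmetic concrete, here is a quick sketch (the helper name `training_memory_gb` is ours; decimal GB, 4 bytes per FP32 value):

```python
# Back-of-envelope training memory for a dense model with Adam, all in FP32.
def training_memory_gb(n_params: float, bytes_per_param: int = 4) -> float:
    weights = n_params * bytes_per_param
    grads = n_params * bytes_per_param      # same size as weights
    adam = 2 * n_params * bytes_per_param   # momentum + variance
    return (weights + grads + adam) / 1e9

print(training_memory_gb(7e9))  # 7B model -> 112.0, matching the table
```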
Why Activations Matter

Activations scale with batch size and sequence length. For transformers, attention activations grow as O(seq_len²) per layer unless you use FlashAttention, which never materializes the full attention matrix.

Forward vs Backward Compute

A common rule of thumb: the backward pass costs ~2x the forward pass. This comes from computing gradients with respect to both the inputs and the weights. For a linear layer Y = XW:

Forward Pass
# Forward: 1 matmul
Y = X @ W  # (B, in) @ (in, out) = (B, out)
Backward Pass
# Backward: 2 matmuls
dX = dY @ W.T  # gradient w.r.t. input
dW = X.T @ dY  # gradient w.r.t. weights
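These two formulas can be sanity-checked against autograd with a throwaway sketch:

```python
import torch

# Verify dX = dY @ W.T and dW = X.T @ dY against PyTorch autograd.
torch.manual_seed(0)
X = torch.randn(4, 3, requires_grad=True)   # (B, in)
W = torch.randn(3, 5, requires_grad=True)   # (in, out)
dY = torch.randn(4, 5)                      # upstream gradient, (B, out)

(X @ W).backward(dY)

dX_manual = dY @ W.detach().T
dW_manual = X.detach().T @ dY
assert torch.allclose(X.grad, dX_manual, atol=1e-6)
assert torch.allclose(W.grad, dW_manual, atol=1e-6)
```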
A 13B parameter model trained with Adam in FP32 needs at minimum how much memory for weights, gradients, and optimizer states?
52 GB (4 bytes × 13B)
104 GB (weights + gradients)
208 GB (weights + gradients + 2x optimizer states)

02 - BACKWARD KERNELS

Writing Backward Passes

Understanding backward kernels helps you write custom operations and debug gradient issues. Each operation must define how gradients flow backwards.

Linear Layer Backward

The linear layer backward pass requires saving the input tensor from the forward pass:

Triton Linear Backward
@triton.jit
def linear_backward_dx(
    dY_ptr, W_ptr, dX_ptr,
    M, N, K,  # M=batch, N=in_features, K=out_features
    stride_dy_m, stride_dy_k,
    stride_w_n, stride_w_k,
    stride_dx_m, stride_dx_n,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
):
    """dX = dY @ W.T"""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    # Accumulate dX[m, n] = sum_k(dY[m, k] * W[n, k])
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        dy = tl.load(dY_ptr + offs_m[:, None] * stride_dy_m + 
                     (k + offs_k[None, :]) * stride_dy_k)
        w = tl.load(W_ptr + offs_n[:, None] * stride_w_n + 
                    (k + offs_k[None, :]) * stride_w_k)
        acc += tl.dot(dy, tl.trans(w))
    
    tl.store(dX_ptr + offs_m[:, None] * stride_dx_m + 
             offs_n[None, :] * stride_dx_n, acc)
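A host-side wrapper usually dispatches to the kernel on CUDA and falls back to a plain matmul elsewhere; here is a sketch under that pattern (the wrapper name is ours, and the CUDA branch is left as a stub):

```python
import torch

def linear_backward_dx(dY: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """dX = dY @ W.T, matching the kernel's dX[m, n] = sum_k dY[m, k] * W[n, k]."""
    if dY.is_cuda:
        # Launch the Triton kernel here with grid
        # (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)); omitted in this sketch.
        raise NotImplementedError
    return dY @ W.T  # CPU reference path

dY = torch.randn(4, 5)  # (batch, out_features)
W = torch.randn(3, 5)   # (in_features, out_features), indexed [n, k]
dX = linear_backward_dx(dY, W)
assert dX.shape == (4, 3)
```

Keeping a reference path like this also gives you something to `torch.allclose` the kernel's output against when debugging.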

LayerNorm Backward

LayerNorm backward is more complex because normalization creates dependencies between elements. The gradients must account for the mean and variance computations:

LayerNorm Gradient Equations
# Forward: y = (x - mean) / std * gamma + beta,  where std = sqrt(var + eps)

# Backward proceeds in three steps (D = feature dimension):
# 1. Gradient contributions through the normalization
dx_hat = dY * gamma                                      # (B, D)
dvar = sum(dx_hat * (x - mean) * -0.5 * std**-3, dim=-1)
dmean = sum(-dx_hat / std, dim=-1) + dvar * sum(-2 * (x - mean), dim=-1) / D

# 2. Gradient w.r.t. the input
dx = dx_hat / std + dvar * 2 * (x - mean) / D + dmean / D

# 3. Gradients w.r.t. the affine parameters
dgamma = sum(dY * (x - mean) / std, dim=0)  # reduce over batch
dbeta = sum(dY, dim=0)
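The three input-gradient steps can be checked against autograd; a quick sketch (batch and feature sizes are arbitrary):

```python
import torch

# Check the three-step LayerNorm input gradient against autograd.
torch.manual_seed(0)
B, D, eps = 4, 8, 1e-5
x = torch.randn(B, D, requires_grad=True)
gamma, beta = torch.randn(D), torch.randn(D)

mean = x.mean(-1, keepdim=True)
std = (x.var(-1, unbiased=False, keepdim=True) + eps).sqrt()
y = (x - mean) / std * gamma + beta

dY = torch.randn(B, D)
y.backward(dY)

with torch.no_grad():
    xc = x - mean                       # centered input, saved from forward
    dx_hat = dY * gamma
    dvar = (dx_hat * xc * -0.5 * std**-3).sum(-1, keepdim=True)
    dmean = (-dx_hat / std).sum(-1, keepdim=True) \
            + dvar * (-2 * xc).sum(-1, keepdim=True) / D
    dx = dx_hat / std + dvar * 2 * xc / D + dmean / D

assert torch.allclose(x.grad, dx, atol=1e-5)
```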
Numerical Stability in Backward

Division by std can cause issues when variance is very small. Always use the same epsilon in backward as forward, and consider using numerically stable formulations.

Attention Backward Memory

Standard attention backward is memory-intensive because it requires the full attention matrix. For sequence length L = 2048, batch B = 8, H = 32 heads, and head dim d = 128, stored in FP32:

  What to store    | Size              | Example size
  Q, K, V tensors  | 3 × B × H × L × d | ~0.8 GB
  Attention matrix | B × H × L × L     | ~4.3 GB
  Softmax output   | B × H × L × L     | ~4.3 GB

This is why FlashAttention recomputes attention during backward instead of storing it—trading compute for memory.
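The O(L²) term dominates quickly as sequences grow; a one-liner makes this visible (the helper name is ours; decimal GB, FP32 — halve for FP16):

```python
# Size of one materialized B x H x L x L attention matrix, in decimal GB.
def attn_matrix_gb(batch: int, heads: int, seq_len: int,
                   bytes_per_el: int = 4) -> float:
    return batch * heads * seq_len ** 2 * bytes_per_el / 1e9

print(attn_matrix_gb(8, 32, 2048))  # ~4.3 GB in FP32
print(attn_matrix_gb(8, 32, 8192))  # 16x larger when L quadruples
```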

Why does FlashAttention recompute the attention matrix during backward pass?
Recomputation is faster than memory access
Storing O(L²) attention matrix is prohibitively expensive for long sequences
The attention matrix changes between forward and backward

03 - MIXED PRECISION

FP16/BF16 Training with Loss Scaling

Mixed-precision training uses lower precision (FP16 or BF16) for forward/backward passes while maintaining FP32 master weights. This reduces memory and leverages Tensor Cores for speed.

The Precision Hierarchy

  Format | Bits          | Dynamic range | Use case
  FP32   | 32            | ~10^38        | Master weights, optimizer states
  FP16   | 16            | max ~65,504   | Forward/backward (requires loss scaling)
  BF16   | 16            | ~10^38        | Forward/backward (no loss scaling needed)
  TF32   | 19 (internal) | ~10^38        | Tensor Core matmuls on Ampere+

Loss Scaling

FP16 can't represent gradients smaller than ~6×10^-8. Small gradients underflow to zero, causing training to stall. Loss scaling multiplies the loss by a large factor, scaling up all gradients:
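You can watch the underflow happen directly; a tiny sketch (65536 is a typical initial scale factor):

```python
import torch

g = torch.tensor(1e-8)                # a realistically tiny gradient value

print(g.to(torch.float16))            # FP16: underflows to 0
print((g * 65536).to(torch.float16))  # scaled first: survives as 6.5536e-4
print(g.to(torch.bfloat16))           # BF16: nonzero without any scaling
```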

PyTorch AMP with Loss Scaling
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    
    # Forward in FP16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    
    # Backward with scaled loss
    scaler.scale(loss).backward()  # loss × scale_factor
    
    # Unscale gradients, check for inf/nan, step
    scaler.step(optimizer)  # only steps if grads valid
    scaler.update()  # adjust scale factor

Dynamic Loss Scaling

The scale factor is adjusted dynamically:

  • If any gradient contains inf/NaN, the step is skipped and the scale is multiplied by backoff_factor (default 0.5)
  • After growth_interval consecutive valid steps (default 2000), the scale is multiplied by growth_factor (default 2.0)
  • The scale thus hovers near the largest value that doesn't overflow

BF16 Advantage

BF16 has the same exponent range as FP32, so gradients rarely underflow. Most modern training uses BF16 without loss scaling when hardware supports it (Ampere+ GPUs, TPUs).

What Stays in FP32

Even in mixed-precision training, some operations must remain in FP32:

  • Master weights and optimizer states
  • Loss computation and large reductions (sums, norms)
  • Softmax and normalization statistics (autocast routes these to FP32 automatically)

Why does BF16 training usually not require loss scaling?
BF16 has more mantissa bits than FP16
BF16 has the same exponent range as FP32, so gradients don't underflow
BF16 uses stochastic rounding

04 - CHECKPOINTING

Trading Compute for Memory

Activation checkpointing (gradient checkpointing) saves memory by not storing all intermediate activations. Instead, it recomputes them during the backward pass.

How It Works

Without checkpointing, every layer's activations (L1…L6) are saved until the backward pass, so activation memory grows as O(n) in the depth n.

With checkpointing every few layers, only the activations at checkpoint boundaries are kept; the layers in between are recomputed during backward. Spacing checkpoints roughly every √n layers reduces stored activations to O(√n).

Implementation

PyTorch Checkpointing
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock() for _ in range(n_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            # checkpoint() doesn't save activations for this segment
            # Instead, recomputes them during backward
            x = checkpoint(layer, x, use_reentrant=False)
        return x
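One reassuring property: checkpointing changes when activations exist, not the gradients themselves. A quick equivalence check (a sketch with an arbitrary small model):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Gradients with and without checkpointing should match.
torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8)

layer(x).sum().backward()
g_plain = [p.grad.clone() for p in layer.parameters()]

layer.zero_grad()
checkpoint(layer, x, use_reentrant=False).sum().backward()
g_ckpt = [p.grad.clone() for p in layer.parameters()]

assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_plain, g_ckpt))
```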

Memory vs Compute Tradeoff

  Strategy                  | Memory         | Compute overhead   | When to use
  No checkpointing          | O(n)           | 0%                 | Memory not a constraint
  Checkpoint every k layers | O(n/k)         | ~(k-1)/k × forward | Moderate memory pressure
  Checkpoint every layer    | O(1) per layer | ~100% (2x forward) | Severe memory constraints
  Selective checkpointing   | Variable       | Variable           | Checkpoint attention, keep FFN
Selective Checkpointing

Not all operations benefit equally from checkpointing. Attention stores O(L²) activations yet its recompute is relatively cheap in FLOPs, while the FFN stores only O(L×d) activations but its recompute is dominated by large matmuls. Selective checkpointing therefore recomputes the memory-intensive attention/softmax portions and keeps the FFN activations.

Checkpointing every layer reduces activation memory to O(1) per layer, but at what compute cost?
No extra compute—just memory savings
~50% more forward compute
~100% more forward compute (recompute entire forward during backward)

05 - GRADIENT ACCUMULATION

Simulating Larger Batch Sizes

Gradient accumulation lets you train with effective batch sizes larger than what fits in memory by accumulating gradients over multiple forward-backward passes before updating weights.

The Pattern

  Micro-batch 1: FWD → BWD
  Micro-batch 2: FWD → BWD (accumulate)
  Micro-batch 3: FWD → BWD (accumulate)
  Micro-batch 4: FWD → BWD (accumulate) → UPDATE

With 4 accumulation steps and micro-batch size 8, effective batch size = 32.

Implementation

Gradient Accumulation Loop
accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, targets) in enumerate(dataloader):
    # Forward + backward (gradients accumulate)
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss = loss / accumulation_steps  # Scale loss
    
    scaler.scale(loss).backward()
    
    # Only update every N steps
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
Divide Loss by Accumulation Steps

Since gradients are summed (not averaged) across accumulation steps, you must divide the loss by the number of accumulation steps to get correct gradient magnitudes. Alternatively, ensure your loss is already a mean over the batch.
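With the loss divided by the step count, accumulation reproduces the full-batch gradient up to float error; a check under that assumption (mean-reduction loss, equal micro-batch sizes):

```python
import torch
import torch.nn as nn

# 4 accumulated micro-batches (loss / 4) vs one full batch of 16.
torch.manual_seed(0)
model = nn.Linear(8, 1)
data, target = torch.randn(16, 8), torch.randn(16, 1)
loss_fn = nn.MSELoss()  # mean over the micro-batch

model.zero_grad()
loss_fn(model(data), target).backward()
g_full = [p.grad.clone() for p in model.parameters()]

model.zero_grad()
for xb, yb in zip(data.chunk(4), target.chunk(4)):
    (loss_fn(model(xb), yb) / 4).backward()  # scale, then accumulate
g_accum = [p.grad.clone() for p in model.parameters()]

assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_full, g_accum))
```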

BatchNorm Interaction

BatchNorm computes statistics over micro-batches, not the effective batch. This can cause training instability with small micro-batches. Solutions:

  • Replace BatchNorm with LayerNorm or GroupNorm, which don't use batch statistics
  • Use larger micro-batches with fewer accumulation steps
  • Freeze BatchNorm running statistics (keep BN layers in eval mode) when fine-tuning

With gradient accumulation over 8 steps and micro-batch size 16, what is the effective batch size?
16 (micro-batch only)
128 (8 × 16)
8 (accumulation steps only)

06 - TRAINING NUMERICS

Stability and Debugging

Training instability manifests as loss spikes, NaN losses, or gradients that explode/vanish. Understanding the numerical causes helps you fix them.

Gradient Clipping

Gradient clipping prevents exploding gradients by capping their magnitude:

Gradient Clipping in PyTorch
# After backward, before optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

Common max_norm values: 1.0 for transformers, 0.1-0.5 for RNNs. Monitor gradient norms—if clipping activates frequently, investigate the cause.
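To see what clipping actually does, here is a small sketch (the inflated loss is contrived to force large gradients): `clip_grad_norm_` returns the total norm computed before clipping and rescales gradients in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 8)
((model(torch.randn(4, 8)) ** 2).sum() * 100).backward()  # inflate gradients

pre_clip = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(float(pre_clip), float(post_clip))  # global norm after clipping is <= 1.0
```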

Loss Spikes

Sudden loss increases often indicate numerical issues:

  Symptom             | Likely cause                   | Fix
  Spike then recovery | Outlier batch, FP16 overflow   | Skip batch, adjust loss scale
  Spike then NaN      | Gradient explosion             | Gradient clipping, lower LR
  Gradual divergence  | LR too high, bad hyperparams   | LR warmup, hyperparameter search
  NaN from start      | Weight init, input data issues | Check data pipeline, init scale

Debugging NaN Gradients

NaN Detection Hook
def check_nan_hook(module, grad_input, grad_output):
    for i, grad in enumerate(grad_output):
        if grad is not None and torch.isnan(grad).any():
            print(f"NaN in {module.__class__.__name__} grad_output[{i}]")
            print(f"  grad stats: min={grad.min()}, max={grad.max()}")

# Register on all modules (use full hooks; register_backward_hook is deprecated)
for module in model.modules():
    module.register_full_backward_hook(check_nan_hook)
Anomaly Detection Mode

PyTorch's anomaly detection tracks which operation produced NaN: torch.autograd.set_detect_anomaly(True). Slows training significantly—use only for debugging.

Weight Initialization

Proper initialization prevents vanishing/exploding gradients at training start. Common schemes:

  • Xavier/Glorot — variance scaled by fan-in and fan-out; suited to tanh/sigmoid
  • Kaiming/He — variance 2/fan_in; suited to ReLU-family activations
  • Depth-scaled residual initialization — shrink residual-branch weights as depth grows

GPT-style models often scale residual path weights by 1/sqrt(2 * n_layers) as noted in the GPT-2 paper.
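Concretely, the schemes above look like this in PyTorch (the 0.02 base std follows the GPT-2 convention; layer sizes and depth are arbitrary):

```python
import math
import torch
import torch.nn as nn

n_layers = 12

lin_relu = nn.Linear(256, 256)
nn.init.kaiming_normal_(lin_relu.weight, nonlinearity='relu')  # He: std = sqrt(2/fan_in)

lin_tanh = nn.Linear(256, 256)
nn.init.xavier_normal_(lin_tanh.weight)  # Glorot: std = sqrt(2/(fan_in+fan_out))

# GPT-2 style: residual projections drawn from N(0, 0.02), scaled by 1/sqrt(2*n_layers)
proj = nn.Linear(256, 256)
nn.init.normal_(proj.weight, std=0.02 / math.sqrt(2 * n_layers))
```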

You see loss = NaN after a few hundred steps. Gradient norms were growing before the crash. Most likely fix?
Increase learning rate
Add gradient clipping
Remove loss scaling

07 - SUMMARY

Key Takeaways

Training Memory Breakdown
  • Training needs ~4x model weights for Adam (weights + gradients + 2x optimizer states)
  • Activations scale with batch size, sequence length, and model depth
  • Backward pass costs ~2x forward due to computing gradients for both inputs and weights
Memory Optimization Techniques
  • Mixed precision — FP16/BF16 halves memory for activations and gradients
  • Activation checkpointing — trade compute for memory by recomputing activations
  • Gradient accumulation — simulate larger batches without more memory
Numerical Stability
  • Use loss scaling with FP16 (not needed for BF16)
  • Gradient clipping prevents explosion—common max_norm = 1.0
  • Proper initialization scales with network depth