Chapter 7

Common Kernels

Beyond attention: the essential kernels that make up modern neural networks. LayerNorm, RMSNorm, fused activations, embeddings, and optimizers.

What You'll Learn
  1. Implement fused element-wise operations (activation + bias)
  2. Write efficient LayerNorm and RMSNorm kernels
  3. Optimize embedding lookups for arbitrary indices
  4. Understand fused optimizer patterns (Adam)
  5. Know when to write custom kernels vs using cuDNN

01 - KERNEL FUSION

Why Fusion Matters

Most neural network operations are memory-bound—limited by how fast you can move data, not compute. Fusion reduces memory traffic by combining multiple operations into one kernel.

Unfused vs Fused: Memory Traffic

Unfused (2 kernel launches, 6 memory operations):
  Kernel 1: load x → add bias → store tmp1
  Kernel 2: load tmp1 → GELU → store out

Fused (1 kernel launch, 2 memory operations):
  Kernel 1: load x → add bias + GELU → store out

Result: 3x less memory traffic.
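
The savings can be estimated with back-of-envelope arithmetic. The sketch below counts only full-tensor loads and stores (the broadcast bias row is tiny by comparison, so it is ignored here; exact op counts depend on what you choose to count, but the tmp1 round-trip is the traffic that fusion eliminates). The tensor size and fp16 element width are arbitrary choices for illustration:

```python
def traffic_gb(n_elems, bytes_per_elem, n_tensor_ops):
    """DRAM traffic for n_tensor_ops full-tensor loads/stores, in GB."""
    return n_elems * bytes_per_elem * n_tensor_ops / 1e9

N = 4096 * 4096  # one fp16 activation tensor

# Unfused: load x, store tmp1, load tmp1, store out
unfused = traffic_gb(N, 2, 4)

# Fused: load x, store out -- tmp1 never touches DRAM
fused = traffic_gb(N, 2, 2)
```

In this accounting the fused version moves exactly half the data; on a memory-bound op, that translates almost directly into runtime.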

Fused Bias + Activation in Triton

Triton: Fused bias + GELU (Triton Tutorial Style)
@triton.jit
def fused_bias_gelu(
    x_ptr, bias_ptr, out_ptr,
    N,  # total number of elements
    BIAS_SIZE: tl.constexpr,  # length of the bias vector (the last dim)
    BLOCK: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    
    # Load x and bias (bias is broadcast along the last dimension)
    x = tl.load(x_ptr + offsets, mask=mask)
    bias = tl.load(bias_ptr + offsets % BIAS_SIZE, mask=mask)
    
    # Fused: add bias, then GELU
    x = x + bias
    # GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    out = x * 0.5 * (1.0 + tl.math.tanh(
        0.7978845608 * (x + 0.044715 * x * x * x)
    ))
    
    tl.store(out_ptr + offsets, out, mask=mask)

GELU Approximation

The GELU activation (Hendrycks & Gimpel, 2016) uses a tanh approximation for efficiency. PyTorch's nn.GELU(approximate='tanh') implements this form; the exact form uses erf(), which is slower.

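To see how tight the tanh approximation is, here is a small CPU check in pure Python (standard library only), comparing it against the exact erf form over a representative input range:

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF via erf
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # The tanh approximation used in the kernel above
    return x * 0.5 * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

# Maximum absolute error over [-5, 5]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x))
              for x in (i / 100.0 for i in range(-500, 501)))
```

The maximum error stays well below 1e-3, which is negligible next to fp16 precision.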
Fusing two element-wise operations primarily improves performance by:
  a) Reducing computation
  b) Reducing memory traffic (fewer loads/stores)  ← correct
  c) Using fewer threads

02 - NORMALIZATION

LayerNorm and RMSNorm

Layer Normalization (Ba et al., 2016) is ubiquitous in Transformers. RMSNorm (Zhang & Sennrich, 2019) is a simplified variant used in LLaMA and other modern architectures.

LayerNorm vs RMSNorm

Operation        | LayerNorm                   | RMSNorm
Formula          | γ * (x - μ) / √(σ² + ε) + β | γ * x / √(mean(x²) + ε)
Compute mean?    | Yes (for centering)         | No
Learnable bias?  | Yes (β)                     | No
Passes over data | 2 (mean, then variance)     | 1 (just RMS)
Used in          | BERT, GPT-2, T5             | LLaMA, Mistral, Gemma
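
The two formulas side by side as NumPy reference implementations (a CPU sketch for checking kernel outputs, not a performant path):

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)          # population variance
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def rmsnorm(x, gamma, eps=1e-6):
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)
    return gamma * x / np.sqrt(mean_sq + eps)
```

On zero-mean rows the two coincide (with matching eps), since var = mean(x²) when μ = 0; RMSNorm simply drops the centering step.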

LayerNorm Kernel Anatomy

  1. Load row into registers/shared memory
    Each thread block handles one or more rows. Load the entire row for the reduction.
  2. Compute mean via parallel reduction
    Sum all elements, divide by the count. Use warp shuffles for efficiency.
  3. Compute variance via parallel reduction
    Sum (x - mean)², divide by the count. Welford's algorithm can fuse steps 2 and 3 into a single pass.
  4. Normalize and apply the affine transform
    Compute (x - mean) * rsqrt(var + eps) * gamma + beta. Write the output.
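
The Welford option mentioned in step 3 can be sketched in plain Python: it produces mean and variance in a single pass over the row. (A scalar loop here; a GPU version runs the same recurrence per thread and merges partial results with a parallel combine.)

```python
def welford(row):
    """Single-pass (population) mean and variance via Welford's recurrence."""
    mean, m2, n = 0.0, 0.0, 0
    for x in row:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)   # uses the *updated* mean: numerically stable
    return mean, m2 / n
```

Unlike the naive mean(x²) - mean(x)² formula, this never subtracts two large nearly-equal numbers.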

RMSNorm: Simpler and Faster

RMSNorm in Triton
@triton.jit
def rmsnorm_kernel(
    x_ptr, weight_ptr, out_ptr,
    stride, N,
    eps: tl.constexpr,
    BLOCK: tl.constexpr
):
    row = tl.program_id(0)
    x_ptr += row * stride
    out_ptr += row * stride
    
    # Load row
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    
    # Compute RMS (no mean subtraction!)
    x_sq = x * x
    mean_sq = tl.sum(x_sq) / N
    rms = tl.rsqrt(mean_sq + eps)
    
    # Normalize and scale
    weight = tl.load(weight_ptr + cols, mask=mask)
    out = x * rms * weight
    
    tl.store(out_ptr + cols, out, mask=mask)

Numerical Stability

Always compute reductions in FP32, even for FP16 inputs. Accumulating many small values in FP16 causes precision loss. Cast back to the output dtype only at the final store.
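
A NumPy demonstration of the failure mode (the constant 0.01 and the length 4096 are arbitrary choices for illustration):

```python
import numpy as np

vals = np.full(4096, 0.01, dtype=np.float16)

# fp16 accumulator: once the running sum reaches 32, the spacing between
# adjacent fp16 values (0.03125) is so wide that adding 0.01 rounds back
# to the same value -- the sum stops growing entirely
acc16 = np.float16(0.0)
for v in vals:
    acc16 = np.float16(acc16 + v)

# fp32 accumulation: cast up first, reduce, cast back only at the end
acc32 = vals.astype(np.float32).sum()
```

The fp16 accumulator stalls near 32 while the fp32 sum lands near the true value of about 41, a roughly 20% error from the accumulator dtype alone.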

Why is RMSNorm faster than LayerNorm?
  a) It uses less memory
  b) It skips mean computation and has no bias parameter  ← correct
  c) It uses Tensor Cores
  d) It has better cache behavior

03 - EMBEDDINGS

Embedding Lookups

Embedding tables convert token IDs to vectors. The challenge: indices are arbitrary, causing non-coalesced memory access.

The Coalescing Problem

Sequential indices (rare)

  Indices: [0, 1, 2, 3, 4, 5, 6, 7]
  Adjacent rows of the table → coalesced: 1 memory transaction

Random indices (typical)

  Indices: [42, 7, 1024, 3, 999, 15, 42, 100]
  Scattered rows of the table → up to 8 separate memory transactions
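
A toy model of the effect: count how many distinct cache lines a warp touches. (Assumptions for the sketch: each thread fetches one 4-byte element at its index and transactions are 128-byte aligned segments; real embedding lookups fetch whole rows, each spanning several lines.)

```python
def transactions(indices, elem_bytes=4, line_bytes=128):
    """Distinct memory transactions (cache lines) touched when each
    thread loads one element at the given index."""
    return len({(i * elem_bytes) // line_bytes for i in indices})

sequential = transactions(range(32))                          # one full line
scattered = transactions([42, 7, 1024, 3, 999, 15, 42, 100])  # several lines
```

The 32 sequential 4-byte loads pack into exactly one 128-byte transaction; the scattered indices hit five distinct lines (duplicates and accidental locality keep it under the worst case of eight).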

Optimization Strategies

Strategy             | When to Use                                     | Trade-off
Vectorized loads     | Embedding dim divisible by 4                    | Load float4 instead of float → 4x fewer transactions per row
L2 cache persistence | Repeated access to same embeddings              | Use cudaAccessPolicyWindow to pin hot embeddings
Sorted indices       | Batch allows reordering                         | Sort indices to improve locality, then unsort output
Embedding bag        | Sum/mean pooling over variable-length sequences | Fuse gather + reduction in one kernel

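The sorted-indices strategy in NumPy form (a correctness sketch; on the GPU the payoff comes from the gather walking the table in ascending order rather than jumping around):

```python
import numpy as np

def gather_sorted(weight, indices):
    order = np.argsort(indices, kind="stable")  # permutation that sorts indices
    rows = weight[indices[order]]               # gather with improved locality
    out = np.empty_like(rows)
    out[order] = rows                           # scatter back to original order
    return out
```

The result is bit-identical to the direct gather weight[indices]; only the access pattern changes.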
Vectorized Embedding Lookup
@triton.jit
def embedding_kernel(
    indices_ptr, weight_ptr, out_ptr,
    seq_len, embed_dim,
    weight_stride,
    BLOCK_SIZE: tl.constexpr
):
    # Each program handles one token
    token_idx = tl.program_id(0)
    
    # Load the vocabulary index for this token; cast to int64 so the
    # row-offset multiply can't overflow int32 for large tables
    vocab_idx = tl.load(indices_ptr + token_idx).to(tl.int64)
    
    # Pointer to this row of the embedding table
    embed_ptr = weight_ptr + vocab_idx * weight_stride
    
    # Load the whole embedding row as one vectorized block
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < embed_dim
    embedding = tl.load(embed_ptr + offsets, mask=mask)
    
    # Store to output
    out_offset = token_idx * embed_dim
    tl.store(out_ptr + out_offset + offsets, embedding, mask=mask)

Why are embedding lookups memory-inefficient compared to matrix multiplies?
  a) Random indices cause non-coalesced (scattered) memory access  ← correct
  b) Embeddings use more memory
  c) They can't use Tensor Cores

04 - FUSED OPTIMIZERS

Fused Adam

The Adam optimizer (Kingma & Ba, 2014) updates parameters using first and second moment estimates. A naive implementation requires multiple kernel launches and memory passes.

Adam Memory Traffic

Unfused (5 separate kernels):

  Operation                         Memory ops
  Load gradient                     1 load
  Update m = β₁m + (1-β₁)g          load m, store m
  Update v = β₂v + (1-β₂)g²         load g, load v, store v
  Compute m̂, v̂ (bias correction)    load m, load v
  Update p = p - lr * m̂ / (√v̂ + ε)  load p, store p
  Total: 11 memory ops across 5 kernel launches

Fused (one kernel):

  Load g, m, v, p once; update in registers; store m, v, p
  Total: 7 memory ops in 1 kernel launch

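The same update as a NumPy reference (useful for checking a fused kernel's output against known-good math; the default hyperparameters here are the common Adam settings, not values from this chapter):

```python
import numpy as np

def adam_step(p, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    """One Adam update; returns the new (param, m, v) triple."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** step)          # bias correction
    v_hat = v / (1 - beta2 ** step)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v
```

On the first step with fresh (zero) moments, bias correction makes m̂ = g and v̂ = g², so the update is approximately -lr * sign(g).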
Fused Adam Kernel (simplified)
@triton.jit
def fused_adam(
    param_ptr, grad_ptr, m_ptr, v_ptr,
    lr, beta1, beta2, eps,
    bias_corr1, bias_corr2,  # 1 - beta1**step and 1 - beta2**step
    N, BLOCK: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    
    # Load all four tensors once
    p = tl.load(param_ptr + offsets, mask=mask)
    g = tl.load(grad_ptr + offsets, mask=mask)
    m = tl.load(m_ptr + offsets, mask=mask)
    v = tl.load(v_ptr + offsets, mask=mask)
    
    # Update first and second moments
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    
    # Bias correction (the scalar divisors are computed once on the host,
    # rather than recomputing the power per element in the kernel)
    m_hat = m / bias_corr1
    v_hat = v / bias_corr2
    
    # Update parameters
    p = p - lr * m_hat / (tl.sqrt(v_hat) + eps)
    
    # Store updated values
    tl.store(param_ptr + offsets, p, mask=mask)
    tl.store(m_ptr + offsets, m, mask=mask)
    tl.store(v_ptr + offsets, v, mask=mask)

Multi-Tensor Apply

Real-world optimizers use multi-tensor apply: process all model parameters in one kernel launch. This amortizes kernel launch overhead (which can be ~5-10μs per launch) across all parameters.
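
The idea in miniature (NumPy, with plain SGD instead of Adam to keep it short; real implementations such as APEX pass chunked pointer lists to a single kernel rather than physically concatenating tensors):

```python
import numpy as np

def multi_tensor_sgd(params, grads, lr):
    """Update a whole list of parameter tensors in one flat pass."""
    flat_p = np.concatenate([p.ravel() for p in params])
    flat_g = np.concatenate([g.ravel() for g in grads])
    flat_p -= lr * flat_g                       # one elementwise pass over everything
    # split back into the original shapes
    out, i = [], 0
    for p in params:
        out.append(flat_p[i:i + p.size].reshape(p.shape))
        i += p.size
    return out
```

One pass over one flat buffer stands in for one kernel launch over all parameters, instead of one launch per tensor.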

Fused Adam stores which tensors?
  a) Just the updated parameters
  b) Parameters and gradients
  c) Parameters, first moment (m), and second moment (v)  ← correct
  d) Only the moments

05 - LIBRARIES VS CUSTOM

When to Write Custom Kernels

cuDNN and cuBLAS are highly optimized. Don't reinvent the wheel—but know when custom kernels win.

Use Libraries When...

Operation              | Library        | Why
Matrix multiply (GEMM) | cuBLAS         | Tensor Core optimized, auto-tuned per GPU
Convolution            | cuDNN          | Multiple algorithms (Winograd, FFT, implicit GEMM), auto-tuned
Batch normalization    | cuDNN          | Fused forward+backward, running stats handled
Attention (standard)   | FlashAttention | IO-aware, extensively optimized

Write Custom When...

Scenario                  | Example                           | Why Custom Wins
Fusion opportunities      | Bias + GELU + Dropout             | Libraries can't fuse across op boundaries
Non-standard shapes       | Very small matrices, odd dimensions | Libraries are optimized for common sizes
Custom attention patterns | Sliding window, sparse patterns   | Standard attention doesn't support arbitrary masking patterns
Research ops              | Novel activations, custom losses  | No library implementation exists

Benchmark First

Always benchmark your custom kernel against the library version. A "clever" custom kernel that's 20% slower than cuDNN is a waste of engineering time. Profile with realistic batch sizes and shapes.
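
A minimal CPU-side harness sketch using the standard library's timeit (for GPU kernels you would instead use CUDA events or triton.testing.do_bench, which handle device synchronization; the matrix shapes here are arbitrary stand-ins for "realistic shapes"):

```python
import timeit
import numpy as np

def bench(fn, *args, repeat=5, number=10):
    """Best-of-repeat average seconds per call; min filters out scheduler noise."""
    times = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    return min(times) / number

a = np.random.rand(256, 256).astype(np.float32)
b = np.random.rand(256, 256).astype(np.float32)
baseline = bench(np.matmul, a, b)   # always time the library path first
```

Time the library path first and keep it as the bar your custom kernel must clear on the same shapes and dtypes.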


REFERENCES

Citations & Further Reading

Papers

  1. Layer Normalization
    Ba, Kiros, Hinton, 2016
    arXiv:1607.06450
  2. Root Mean Square Layer Normalization
    Zhang & Sennrich, 2019
    arXiv:1910.07467
  3. Gaussian Error Linear Units (GELUs)
    Hendrycks & Gimpel, 2016
    arXiv:1606.08415
  4. Adam: A Method for Stochastic Optimization
    Kingma & Ba, 2014
    arXiv:1412.6980
  5. LLaMA: Open and Efficient Foundation Language Models (uses RMSNorm)
    Touvron et al., 2023
    arXiv:2302.13971

Documentation

  1. Triton Tutorials
    Fused softmax, matrix multiply, and more
    triton-lang.org/tutorials
  2. CUDA C++ Best Practices Guide
    Memory coalescing, occupancy, optimization
    docs.nvidia.com/cuda
  3. cuDNN Documentation
    Convolution, normalization, attention APIs
    docs.nvidia.com/cudnn
  4. APEX Fused Optimizers
    Multi-tensor apply, fused Adam/LAMB
    nvidia.github.io/apex
All material licensed under CC BY-NC-SA 4.0