Attention
The math behind attention kernels. From dot products to softmax to FlashAttention—understand what every operation actually does.
By the end of this chapter, you should be able to:
- Implement standard scaled dot-product attention
- Explain why naive attention is O(n²) in memory
- Describe how FlashAttention achieves O(n) memory
- Implement online softmax for numerical stability
- Apply tiling strategies to attention computation
This chapter uses linear algebra and exponentials.
The Dot Product: Measuring Relevance
Attention computes how relevant each cached token is to your query. The dot product is the measuring stick—it tells you how much two vectors "agree."
High dot product = vectors point in similar directions = high relevance.
Zero dot product = vectors are perpendicular = no relationship.
In attention, you compute the dot product q · k between your query and every cached key. With a 4096-token context, that's 4096 dot products just to process one query. This is why attention is the bottleneck.
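As a rough NumPy sketch (the head dimension and context length here are just example values), scoring one decode query against a 4096-token KV cache looks like this:
# Sketch: scoring one query against every cached key (NumPy; shapes are illustrative)
import numpy as np

d = 64                          # head dimension (example value)
q = np.random.randn(d)          # the query vector for the current decode step
K = np.random.randn(4096, d)    # 4096 cached key vectors

scores = K @ q / np.sqrt(d)     # 4096 scaled dot products, one per cached token
print(scores.shape)             # (4096,)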
Softmax: From Scores to Probabilities
Raw dot products can be any value—positive, negative, huge, tiny. Softmax converts them to a probability distribution: all positive, sums to 1.
The exponential amplifies differences. A score of 10 vs 5 becomes e^10/e^5 ≈ 150x more weight, not 2x. This makes attention "sharp": it focuses on the most relevant tokens.
The problem: exp(100) ≈ 2.7 × 10^43. That overflows FP16 (max ≈ 65504). Your kernel crashes. This is why we need the numerical stability trick.
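Here's a small NumPy illustration of both effects; the FP16 cast is only there to mimic a low-precision kernel:
# Sketch: softmax sharpness and the FP16 overflow problem
import numpy as np

scores = np.array([10.0, 5.0])
weights = np.exp(scores) / np.exp(scores).sum()
print(weights)                   # ~[0.993, 0.007]: e^10 carries ~150x the weight of e^5

x = np.float16(100.0)
print(np.exp(x))                 # inf -- exp(100) overflows FP16
print(np.exp(x - x))             # 1.0 -- subtracting the max keeps exp() in range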
Online Softmax: Streaming Without Overflow
Two problems with naive softmax:
1. Overflow: Large values explode exp(). Solution: Subtract max first.
2. Memory: You need to see ALL values to compute max. But in attention, you're processing in blocks to stay in fast SRAM. Solution: Online algorithm that updates incrementally.
The algorithm maintains running statistics (a running max and a running sum of exponentials) as new values stream in. This is how FlashAttention processes attention in blocks.
Click "Add Random Block" to start the simulation.
# The online softmax update rule
import numpy as np

def update(m_old, l_old, new_block):
    # m_old: running max so far; l_old: running sum of exp(score - m_old)
    m_block = new_block.max()
    m_new = max(m_old, m_block)
    # Rescale the old accumulator to the new max
    l_new = l_old * np.exp(m_old - m_new)
    # Add the new block's contribution
    l_new += np.sum(np.exp(new_block - m_new))
    return m_new, l_new
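As a sanity check (hypothetical driver code built on the update function above), streaming a score vector block by block should reproduce the statistics of a one-shot softmax:
# Sketch: streamed blocks reproduce the one-shot softmax statistics
import numpy as np

scores = np.random.randn(4096) * 10
m, l = -np.inf, 0.0                      # running max and running normalizer
for block in np.split(scores, 64):       # 64 blocks of 64 scores each
    m, l = update(m, l, block)

assert np.isclose(m, scores.max())
assert np.isclose(l, np.exp(scores - scores.max()).sum())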
Floating Point: What FP8 and NVFP4 Actually Are
Your KV cache is quantized. Understanding the bit layout tells you what precision you're trading for memory bandwidth.
Each format splits its bits into three fields: sign, exponent, and mantissa.
Why NVFP4 works: Neural network values cluster tightly. A per-block scaling factor (stored in FP8) shifts the representable range to where your values actually are. You get 4x memory reduction vs FP16 with ~1% accuracy loss.
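The real NVFP4 block format is defined by the hardware; the toy sketch below only shows the per-block scaling idea, using a plain symmetric 4-bit integer grid and a made-up block size of 16.
# Sketch: per-block scaling (toy 4-bit grid, not the actual NVFP4 encoding)
import numpy as np

def quantize_block(x, levels=7):            # 4-bit symmetric grid: integers in [-7, 7]
    scale = np.abs(x).max() / levels        # per-block scale (NVFP4 keeps this in FP8)
    q = np.round(x / scale).astype(np.int8)
    return q, scale

def dequantize_block(q, scale):
    return q.astype(np.float32) * scale

block = np.random.randn(16).astype(np.float32) * 0.02   # tightly clustered values
q, scale = quantize_block(block)
print(np.abs(dequantize_block(q, scale) - block).max())  # small error: the scale tracks the block's range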
Putting It Together: Full Attention
Now you have all the pieces. The complete attention equation is Attention(Q, K, V) = softmax(QKᵀ / √d_k) V: the dot products QKᵀ measure relevance, the softmax turns raw scores into weights, and the weighted sum over V produces the output.
What your kernel must do efficiently:
1. Load Q (single vector for decode)
2. Stream through KV cache in blocks (fits in SRAM)
3. Compute dot products, track online softmax statistics
4. Accumulate weighted V vectors
5. Output final attention result
The bottleneck is memory bandwidth—loading all those K and V vectors from HBM. FP8/NVFP4 quantization halves or quarters that traffic.
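To tie the five steps together, here is a minimal NumPy sketch of blocked decode attention with an online softmax. Shapes and the block size are illustrative; a real kernel does this with tiles resident in SRAM and fused memory access.
# Sketch: blocked decode attention (steps 1-5), NumPy, illustrative shapes
import numpy as np

def blocked_decode_attention(q, K, V, block_size=128):
    d = q.shape[0]
    m, l = -np.inf, 0.0                          # running max and normalizer
    acc = np.zeros(V.shape[1])                   # running weighted sum of V rows
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]         # step 2: stream one KV block
        Vb = V[start:start + block_size]
        s = Kb @ q / np.sqrt(d)                  # step 3: scaled dot products
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)           # rescale old statistics to the new max
        p = np.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ Vb          # step 4: accumulate weighted V
        m = m_new
    return acc / l                               # step 5: normalize once at the end

d = 64
q, K, V = np.random.randn(d), np.random.randn(4096, d), np.random.randn(4096, d)
out = blocked_decode_attention(q, K, V)

# Reference: naive attention that materializes the full score vector at once
s = K @ q / np.sqrt(d)
ref = np.exp(s - s.max()) @ V / np.exp(s - s.max()).sum()
assert np.allclose(out, ref)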
Citations & Further Reading
Video Resources
- Outstanding visual explanation of attention, QKV, and the transformer architecture (YouTube).
- Build a transformer from scratch with a detailed attention implementation (YouTube).
Foundational Papers
- Attention Is All You Need - Vaswani et al., NeurIPS 2017. arXiv:1706.03762
- FlashAttention - Dao et al., NeurIPS 2022. arXiv:2205.14135
- FlashAttention-2 - Dao, 2023. arXiv:2307.08691