Understanding Adaptive Sparse Flash Attention
— #transformers #gpu-optimizations #low-level
This blog post explores Adaptive Sparse Flash Attention (AdaSplash), an optimization technique that combines the efficiency of Flash Attention with sparse attention patterns. We'll cover:
- Vanilla Attention: Understanding the baseline transformer attention mechanism and its computational challenges
- Flash Attention: How block-wise computation and careful memory management improve efficiency
- α-entmax: A differentiable sparse alternative to softmax that learns to focus on relevant tokens
- AdaSplash: The novel combination of Flash Attention's memory optimizations with α-entmax's sparsity
- Implementation: Practical code examples and performance analysis
Vanilla Attention
Vanilla attention is the foundational mechanism in transformer architectures, computing relevance scores between all pairs of tokens in a sequence. The standard formula for attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
However, vanilla attention has significant computational and memory challenges:
Quadratic Memory: The attention matrix has size $n \times n$, meaning memory requirements grow quadratically with sequence length. For a sequence of length 1024, we need to store over 1 million attention scores.
Quadratic Compute: Computing all pairwise interactions between tokens requires $O(n^2 \cdot d)$ operations. This becomes prohibitively expensive for long sequences.
Memory Access Pattern: The algorithm requires multiple passes over the large attention matrix:
- First to compute the scores $S = \frac{QK^{\top}}{\sqrt{d_k}}$
- Then to apply the softmax normalization $P = \mathrm{softmax}(S)$
- Finally to multiply with $V$ to get the output $O = PV$
For example, with a sequence length of 1024 and hidden dimension of 64:
- $QK^{\top}$ computation: ~67 million FLOPs
- Memory required for attention matrix: 4MB (assuming float32)
- Total memory bandwidth used: >12MB due to multiple passes
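To make these passes and the memory math concrete, here is a minimal PyTorch sketch of vanilla attention. The function name and shapes are illustrative for this post, not taken from any particular implementation:

```python
import torch

def vanilla_attention(Q, K, V):
    """Naive attention: materializes the full (n x n) score matrix."""
    d_k = Q.shape[-1]
    # Pass 1: all pairwise scores -- an (n, n) matrix, O(n^2) memory.
    S = Q @ K.transpose(-2, -1) / d_k ** 0.5
    # Pass 2: row-wise softmax over the full matrix.
    P = torch.softmax(S, dim=-1)
    # Pass 3: weighted sum of the values.
    return P @ V

n, d = 1024, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
out = vanilla_attention(Q, K, V)
# The intermediate (n, n) matrix alone is 1024 * 1024 * 4 bytes = 4 MB in float32.
print(out.shape)  # torch.Size([1024, 64])
```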
These limitations make vanilla attention impractical for longer sequences, motivating the development of more efficient variants like Flash Attention and Adaptive Splash Attention.
Flash Attention (Dao et al.)
Flash Attention splits the $Q$, $K$, and $V$ matrices into blocks that fit in GPU SRAM (a small but fast memory cache, typically 64-128KB per SM). By carefully managing data movement between SRAM and slower DRAM memory, it significantly reduces memory bandwidth and improves performance.
The key insight is breaking down attention computation into SRAM-sized blocks:
For a block of queries $Q_i$ and keys $K_j$, we compute:

$$S_{ij} = \frac{Q_i K_j^{\top}}{\sqrt{d_k}}$$
The algorithm:
- Loads $Q$, $K$, and $V$ blocks into SRAM
- Computes local attention scores
- Updates softmax statistics
- Multiplies with V block
- Accumulates results
It maintains running statistics for a numerically stable softmax:

$$m^{\text{new}} = \max\!\left(m, \operatorname{rowmax}(S_{ij})\right), \qquad \ell^{\text{new}} = e^{\,m - m^{\text{new}}}\,\ell + \operatorname{rowsum}\!\left(e^{\,S_{ij} - m^{\text{new}}}\right)$$
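To make the block loop and the running statistics concrete, here is a simplified PyTorch reference sketch. It is not a fused GPU kernel, and the block sizes and names are illustrative:

```python
import torch

def flash_attention_reference(Q, K, V, block_q=128, block_k=128):
    """Block-wise attention with an online softmax; never forms the full n x n matrix."""
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)
    for qs in range(0, n, block_q):
        Qi = Q[qs:qs + block_q]                           # query block (would live in SRAM)
        m = torch.full((Qi.shape[0],), float("-inf"))     # running row-wise max
        l = torch.zeros(Qi.shape[0])                      # running softmax denominator
        acc = torch.zeros(Qi.shape[0], d)                 # unnormalized running output
        for ks in range(0, n, block_k):
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = (Qi @ Kj.T) * scale                       # local scores for this block pair
            m_new = torch.maximum(m, S.max(dim=-1).values)
            alpha = torch.exp(m - m_new)                  # rescales the old statistics
            P = torch.exp(S - m_new[:, None])             # local unnormalized weights
            l = l * alpha + P.sum(dim=-1)
            acc = acc * alpha[:, None] + P @ Vj
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]
    return O
```

Up to floating-point error this matches the vanilla implementation, but the largest intermediate it ever holds is a single block of scores rather than the full $n \times n$ matrix.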
Key benefits:
- Linear $O(n)$ extra memory instead of the $O(n^2)$ attention matrix
- ~10x less memory bandwidth
- Better cache usage
- No full attention matrix storage
While Flash Attention requires slightly more compute operations, the dramatic reduction in memory access makes it significantly faster than vanilla attention, especially for long sequences.
Splash Attention
In natural language, not every word provides valuable information, or "attention," to every other word.
For example in the sentence "The cat quickly jumped over the brown fence and landed gracefully on the other side", words like "the" and "and" contribute relatively little attention to understanding the core meaning compared to content words like "cat", "jumped", "fence", and "gracefully".
Therefore, we can introduce sparsity with a "sliding window" of sorts: a mask that reduces compute from $O(n^2)$ to $O(n \cdot w)$, where every query attends only to its $w$ nearest keys.
This is where the sparsity arises. In the dense case, every query $q_i$ must be multiplied with every key $k_j$ in the sequence; the sparse mask zeroes out everything that is not within the window. For a toy five-token sequence with a window of one position on each side, the per-query key sets look like this (a code sketch of the masking follows after the list):
- $q_1$: $\{k_1, k_2\}$
- $q_2$: $\{k_1, k_2, k_3\}$
- $q_3$: $\{k_2, k_3, k_4\}$
- $q_4$: $\{k_3, k_4, k_5\}$
- $q_5$: $\{k_4, k_5\}$
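Here is a small PyTorch sketch of this idea (the window size and names are illustrative). Note that this naive version still computes the full score matrix and only masks it afterwards, so it merely illustrates the pattern; a real sparse kernel would skip the masked blocks entirely:

```python
import torch

def sliding_window_attention(Q, K, V, window=2):
    """Each query attends only to keys within +/- `window` positions of itself."""
    n, d = Q.shape
    S = (Q @ K.T) / d ** 0.5
    idx = torch.arange(n)
    inside = (idx[:, None] - idx[None, :]).abs() <= window   # True where |i - j| <= window
    S = S.masked_fill(~inside, float("-inf"))                # masked keys get zero weight
    return torch.softmax(S, dim=-1) @ V
```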
Adaptive Splash Attention (Gonçalves et al.)
Now that we have seen splash attention and flash attention, we can dive into the literature on Adaptive Splash Attention. Something to keep in mind is that "sparse" has many definitions in attention optimization.
- Region-based sparsity: Only compute attention within predefined blocks or regions of the sequence
- Threshold-based pruning: Zero out attention weights that fall below a certain threshold, removing weak connections
- Learnable sparsity: Use trainable parameters to adaptively determine which attention connections to keep or drop during training
The original softmax is dense, meaning that it puts non-zero probability on all tokens.
An alternative to this is the α-entmax transformation, which can learn to put exactly zero probability on some tokens, creating sparse attention patterns. The α parameter controls how sparse the output distribution becomes - as α increases, more tokens get zero probability.
Softmax Alternative
Let's break down the α-entmax formula:

$$\alpha\text{-entmax}(z) = \left[(\alpha - 1)z - \tau\mathbf{1}\right]_+^{1/(\alpha - 1)}$$
Where:
- $z$ is the input score vector (logits)
- $\tau$ is a threshold/normalizing constant
- $[\,\cdot\,]_+$ is the ReLU function that zeros out negative values
- $\alpha$ is the sparsity parameter
This formula is quite elegant in how it achieves sparsity:
The term $(\alpha - 1)z - \tau$ shifts and scales the input scores
The ReLU function zeros out any values that fall below the threshold, creating sparsity
The exponent $1/(\alpha - 1)$ rescales the remaining non-zero values to form a valid probability distribution
The key insight is that unlike softmax, which always gives non-zero probabilities, α-entmax can output exact zeros when input scores fall below the learned threshold $\tau$. The parameter $\alpha$ controls how aggressive this thresholding is:
- As $\alpha \to 1$: Approaches softmax (dense)
- $\alpha = 1.5$: Moderate sparsity
- $\alpha = 2$: High sparsity (sparsemax)
- $\alpha > 2$: Very sparse attention
The threshold parameter $\tau$ plays a crucial role in α-entmax by determining which attention scores get zeroed out. Specifically:
$\tau$ is computed to ensure the output probabilities sum to 1:

$$\sum_j \left[(\alpha - 1)z_j - \tau\right]_+^{1/(\alpha - 1)} = 1$$

Values of $z_i$ where $(\alpha - 1)z_i \le \tau$ get mapped to exactly 0
This thresholding behavior is what enables α-entmax to learn sparse attention patterns adaptively during training, focusing only on the most relevant tokens.
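To see the exact zeros concretely, here is a small PyTorch sketch of the special case $\alpha = 2$ (sparsemax), where $\tau$ has a closed form via sorting; for general $\alpha$, $\tau$ has to be found numerically, as the next section describes:

```python
import torch

def sparsemax(z):
    """alpha-entmax with alpha = 2: p_i = max(z_i - tau, 0), tau chosen so the p_i sum to 1."""
    z_sorted, _ = torch.sort(z, descending=True)
    cumsum = torch.cumsum(z_sorted, dim=0)
    k = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    # Support size: the largest k with 1 + k * z_(k) > sum of the top-k scores.
    k_z = (1 + k * z_sorted > cumsum).sum()
    tau = (cumsum[k_z - 1] - 1) / k_z
    return torch.clamp(z - tau, min=0)

z = torch.tensor([1.0, 0.8, 0.1, -1.0])
print(torch.softmax(z, dim=0))  # every entry strictly positive
print(sparsemax(z))             # tensor([0.6, 0.4, 0.0, 0.0]) -- exact zeros
```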
Halley-Bisection Method
To find this perfect $\tau$, we could run a traditional bisection (binary search) algorithm. However, we can instead use Halley's method, which uses the first and second derivatives to achieve cubic convergence.
Let's assume that we are trying to find the root of this function:

$$f(\tau) = \sum_j \left[(\alpha - 1)z_j - \tau\right]_+^{1/(\alpha - 1)} - 1$$
The bisection algorithm updates the search interval $[\tau_{\text{lo}}, \tau_{\text{hi}}]$ based on the function value (recall that $f$ is decreasing in $\tau$):
- If $f(\tau) > 0$: Set the interval to $[\tau, \tau_{\text{hi}}]$
- Otherwise: Set the interval to $[\tau_{\text{lo}}, \tau]$
- After each iteration, we update $\tau$ to be the midpoint: $\tau = \frac{\tau_{\text{lo}} + \tau_{\text{hi}}}{2}$
However, with Halley's method, we use the first and second derivatives:

$$f'(\tau) = -\frac{1}{\alpha - 1}\sum_{j \in S(\tau)} \left[(\alpha - 1)z_j - \tau\right]^{\frac{1}{\alpha - 1} - 1}, \qquad f''(\tau) = \frac{2 - \alpha}{(\alpha - 1)^2}\sum_{j \in S(\tau)} \left[(\alpha - 1)z_j - \tau\right]^{\frac{1}{\alpha - 1} - 2}$$

where $S(\tau) = \{\, j : (\alpha - 1)z_j > \tau \,\}$ is the current support, and we get Halley's root-finding update:

$$\tau_{n+1} = \tau_n - \frac{2\, f(\tau_n)\, f'(\tau_n)}{2\, f'(\tau_n)^2 - f(\tau_n)\, f''(\tau_n)}$$
Additionally, to guarantee convergence and keep the estimate of $\tau$ inside the bracketing interval, there is a fail-safe mechanism that falls back to a bisection step whenever Halley's method produces an update that would move the solution outside the bisection bounds.
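Here is a rough PyTorch sketch of such a hybrid root-finder. The fixed iteration count, the bracket initialization, and the names are illustrative choices for this post, not the paper's exact implementation:

```python
import torch

def entmax_threshold(z, alpha=1.5, n_iter=30):
    """Find tau with sum_j [(alpha - 1) * z_j - tau]_+^(1/(alpha - 1)) = 1."""
    c = 1.0 / (alpha - 1)
    zs = (alpha - 1) * z

    def f_derivs(tau):
        u = torch.clamp(zs - tau, min=0)
        on = u > 0                                   # current support
        f = u.pow(c).sum() - 1
        fp = -c * torch.where(on, u.pow(c - 1), torch.zeros_like(u)).sum()
        fpp = c * (c - 1) * torch.where(on, u.pow(c - 2), torch.zeros_like(u)).sum()
        return f, fp, fpp

    # f is decreasing in tau, and this bracket always contains the root.
    lo, hi = zs.max() - 1.0, zs.max()
    tau = (lo + hi) / 2
    for _ in range(n_iter):
        f, fp, fpp = f_derivs(tau)
        if f > 0:            # root lies to the right: tighten the lower bound
            lo = tau
        else:                # root lies to the left: tighten the upper bound
            hi = tau
        denom = 2 * fp * fp - f * fpp
        halley = tau - 2 * f * fp / denom if denom != 0 else tau
        # Fail-safe: if Halley's step leaves the bracket, take a bisection step instead.
        tau = halley if lo < halley < hi else (lo + hi) / 2
    return tau

z = torch.tensor([2.0, 1.0, 0.2, -1.0])
alpha = 1.5
tau = entmax_threshold(z, alpha)
p = torch.clamp((alpha - 1) * z - tau, min=0) ** (1 / (alpha - 1))
print(p, p.sum())  # sparse weights summing to ~1
```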
Forward Pass Implementation
For the forward pass, I implemented block-wise computation to keep memory usage low. Let me walk you through how I did it:
The computation is split into blocks so that the whole attention matrix never has to be stored at once: $Q$ is split into query blocks $Q_i$, and $K$, $V$ into key/value blocks $K_j$, $V_j$.
Each $S_{ij} = Q_i K_j^{\top}$ is just a slice of the score matrix for one block pair, and the root-finding function $f(\tau)$ from the previous section is evaluated as a sum over all such blocks.
These $S_{ij}$ (and the resulting attention-weight blocks) are recomputed during backpropagation, similar to gradient checkpointing. This trades a small increase in compute for a large decrease in memory.
Sparse masking is performed at the block level: if no individual score in a block survives the threshold, the whole block is skipped.
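Below is a simplified reference sketch of this block-wise forward pass, assuming the per-row threshold $\tau$ has already been computed (for example, with a solver like the one above). The block sizes and names are illustrative, and none of the kernel-level details of the real Triton/CUDA implementations are modeled:

```python
import torch

def adasplash_forward_reference(Q, K, V, tau, alpha=1.5, block_q=128, block_k=128):
    """Block-wise alpha-entmax attention, given a per-row threshold tau of shape (n,).

    With the correct tau, the weights [(alpha - 1) * s - tau]_+^(1/(alpha - 1)) already
    sum to 1 in every row, so no extra normalization pass is needed.
    """
    n, d = Q.shape
    scale = d ** -0.5
    c = 1.0 / (alpha - 1)
    O = torch.zeros_like(Q)
    for qs in range(0, n, block_q):
        Qi = Q[qs:qs + block_q]
        tau_i = tau[qs:qs + block_q, None]
        acc = torch.zeros(Qi.shape[0], d)
        for ks in range(0, n, block_k):
            S = (Qi @ K[ks:ks + block_k].T) * scale
            U = torch.clamp((alpha - 1) * S - tau_i, min=0)
            if U.max() == 0:
                continue       # no score in this block survives the threshold: skip it
            acc = acc + U.pow(c) @ V[ks:ks + block_k]
        O[qs:qs + block_q] = acc
    return O
```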
Backward Pass
The backward pass leverages sparsity in the α-entmax Jacobian for memory and compute efficiency. Let's break down how this works:
First, we need to differentiate through:

$$p = \alpha\text{-entmax}(z)$$

where $z$ is the attention score vector and $p$ is the attention weight vector. We need $\frac{\partial \mathcal{L}}{\partial z}$ for backpropagation.
The Jacobian of α-entmax (Peters et al., 2019) is:

$$\frac{\partial p}{\partial z} = \mathrm{diag}(s) - \frac{s\, s^{\top}}{\sum_k s_k}$$

where $s_i = p_i^{\,2 - \alpha}$ if $p_i > 0$ and $s_i = 0$ otherwise. This Jacobian is naturally sparse, since many $p_i = 0$ (and thus $s_i = 0$), zeroing out the corresponding rows and columns.
For efficient block-wise computation, we define:
- $U$: the matrix with entries $U_{kl} = P_{kl}^{\,2 - \alpha}$ wherever $P_{kl} > 0$ (and $0$ elsewhere)
- $U_{ij}$: the block of $U$ for query block $i$ and key block $j$
The gradient with respect to the scores is then:

$$\frac{\partial \mathcal{L}}{\partial S_{ij}} = U_{ij} \odot \left(\frac{\partial \mathcal{L}}{\partial P_{ij}} - \delta_i \mathbf{1}^{\top}\right)$$

where:
- $\delta_i = \dfrac{\sum_j \left(U_{ij} \odot \frac{\partial \mathcal{L}}{\partial P_{ij}}\right)\mathbf{1}}{\sum_j U_{ij}\,\mathbf{1}}$, computed per query row of block $i$
The first term handles element-wise gradient scaling based on the sparsity pattern, while the $\delta_i$ term provides the necessary correction so that the gradient is projected orthogonally to $\mathbf{1}$, preserving the sum-to-one constraint. A short code sketch of this gradient, written row-wise, follows the list below.
This formulation is efficient because it:
- Avoids materializing the full Jacobian by only computing non-zero elements
- Works block-wise to align with GPU memory layout
- Enables fast, memory-efficient backpropagation through α-entmax attention
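As promised above, here is a row-wise PyTorch sketch of this gradient, assuming the attention weights $p$ and the upstream gradient $\partial \mathcal{L}/\partial p$ are already available (the names are illustrative):

```python
import torch

def entmax_score_grad(p, dp, alpha=1.5):
    """Gradient w.r.t. the scores z, given p = alpha-entmax(z) and dp = dL/dp (row-wise)."""
    # s_i = p_i^(2 - alpha) on the support, 0 elsewhere -- the sparse Jacobian diagonal.
    s = torch.where(p > 0, p.pow(2 - alpha), torch.zeros_like(p))
    # delta = <s, dp> / sum(s), one value per row.
    delta = (s * dp).sum(dim=-1, keepdim=True) / s.sum(dim=-1, keepdim=True)
    return s * (dp - delta)
```

Rows with more exact zeros in $p$ contribute fewer non-zero terms, which is exactly what the block-wise kernel exploits by skipping fully zero blocks.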
The rest of the gradient updates for $Q$, $K$, and $V$ are straightforward chain-rule computations that follow from the derivative with respect to the scores $S$.
Code Implementation
You can find my full code implementation in CUDA here.
You can find the author's full code implementation in Triton here.