
Understanding Adaptive Sparse Flash Attention

#transformers #gpu-optimizations #low-level

This blog post explores Adaptive Sparse Flash Attention (AdaSplash), an optimization technique that combines the efficiency of Flash Attention with sparse attention patterns. We'll cover:

  1. Vanilla Attention: Understanding the baseline transformer attention mechanism and its computational challenges
  2. Flash Attention: How block-wise computation and careful memory management improve efficiency
  3. α-entmax: A differentiable sparse alternative to softmax that learns to focus on relevant tokens
  4. AdaSplash: The novel combination of Flash Attention's memory optimizations with α-entmax's sparsity
  5. Implementation: Practical code examples and performance analysis

Vanilla Attention

Vanilla attention is the foundational mechanism in transformer architectures, computing relevance scores between all pairs of tokens in a sequence. The standard formula for attention is:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
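
As a reference point, here is a minimal NumPy sketch of this formula (shapes and variable names are illustrative, not tied to any particular library):

```python
import numpy as np

def vanilla_attention(Q, K, V):
    """Dense attention: materializes the full (n x n) score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (n, n): quadratic in n
    scores -= scores.max(axis=-1, keepdims=True)     # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V

# Example: sequence length 1024, head dimension 64
n, d = 1024, 64
Q, K, V = (np.random.randn(n, d).astype(np.float32) for _ in range(3))
out = vanilla_attention(Q, K, V)   # the intermediate (n, n) matrix alone is 4 MB
```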

However, vanilla attention has significant computational and memory challenges:

  1. Quadratic Memory: The attention matrix QK^T has size (n \times n), meaning memory requirements grow quadratically with sequence length. For a sequence of length 1024, we need to store over 1 million attention scores.

  2. Quadratic Compute: Computing all pairwise interactions between tokens requires O(n^2) operations. This becomes prohibitively expensive for long sequences.

  3. Memory Access Pattern: The algorithm requires multiple passes over the large attention matrix:

    • First to compute QK^T
    • Then to apply softmax normalization
    • Finally to multiply with V

For example, with a sequence length of 1024 and hidden dimension of 64:

  • QK^T computation: ~67 million multiply-accumulate operations
  • Memory required for attention matrix: 4MB (assuming float32)
  • Total memory bandwidth used: >12MB due to multiple passes
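
These back-of-the-envelope numbers can be reproduced directly (here I count one multiply-accumulate per score entry and assume three full passes over the matrix; exact figures depend on the implementation):

```python
n, d, bytes_per_float = 1024, 64, 4   # float32

qk_macs = n * n * d                      # ~67 million multiply-accumulates for QK^T
attn_bytes = n * n * bytes_per_float     # 4 MiB for the full attention matrix
traffic = 3 * attn_bytes                 # score write, softmax pass, V multiply

print(f"QK^T MACs:        {qk_macs / 1e6:.0f} M")        # 67 M
print(f"Attention matrix: {attn_bytes / 2**20:.0f} MiB")  # 4 MiB
print(f"Matrix traffic:   {traffic / 2**20:.0f} MiB")     # 12 MiB
```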

These limitations make vanilla attention impractical for longer sequences, motivating the development of more efficient variants like Flash Attention and Adaptive Splash Attention.

Flash Attention (Dao et al.)

Flash Attention splits the QQ and KK matrices into blocks that fit in GPU SRAM (a small but fast memory cache, typically 64-128KB per SM). By carefully managing data movement between SRAM and slower DRAM memory, it significantly reduces memory bandwidth and improves performance.

The key insight is breaking down attention computation into SRAM-sized blocks:

For a block of queries Q_i and keys K_j, we compute:

S_{ij} = Q_i K_j^T = \begin{bmatrix} q_{i1} \cdot k_{j1} & q_{i1} \cdot k_{j2} & \cdots \\ q_{i2} \cdot k_{j1} & q_{i2} \cdot k_{j2} & \cdots \\ \vdots & \vdots & \ddots \end{bmatrix}

The algorithm:

  1. Loads Q and K blocks into SRAM
  2. Computes local attention scores
  3. Updates softmax statistics
  4. Multiplies with V block
  5. Accumulates results

It maintains running statistics for stable softmax:

m_i = \max_{j \leq B} S_{ij} \quad \quad l_i = \sum_{j \leq B} e^{S_{ij} - m_i}
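
Putting the block loop and the running statistics together, here is a simplified NumPy sketch of the forward pass (single head, no masking; the block size and two-loop structure are illustrative, while the real kernel fuses this into a single pass over SRAM tiles):

```python
import numpy as np

def flash_attention(Q, K, V, block_size=128):
    """Block-wise attention with online softmax; never stores the n x n matrix."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)                  # unnormalized output accumulator
    m = np.full(n, -np.inf)               # running row-wise max  (m_i)
    l = np.zeros(n)                       # running row-wise sum  (l_i)

    for i in range(0, n, block_size):                 # query blocks (kept in SRAM)
        qi = slice(i, min(i + block_size, n))
        for j in range(0, n, block_size):             # key/value blocks
            kj = slice(j, min(j + block_size, n))
            S = Q[qi] @ K[kj].T * scale               # local scores for this tile
            m_new = np.maximum(m[qi], S.max(axis=1))  # update running max
            P = np.exp(S - m_new[:, None])            # local exponentials
            corr = np.exp(m[qi] - m_new)              # rescale previously accumulated stats
            l[qi] = l[qi] * corr + P.sum(axis=1)
            O[qi] = O[qi] * corr[:, None] + P @ V[kj]
            m[qi] = m_new
    return O / l[:, None]                             # final softmax normalization

# Sanity check against dense attention on a small example
n, d = 256, 64
Q, K, V = (np.random.randn(n, d) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
dense = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention(Q, K, V), dense)
```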

Key benefits:

  • O(n) memory complexity vs O(n²)
  • ~10x less memory bandwidth
  • Better cache usage
  • No full attention matrix storage

While Flash Attention requires slightly more compute operations, the dramatic reduction in memory access makes it significantly faster than vanilla attention, especially for long sequences.

Splash Attention

In natural language, not every word provides valuable information, or "attention", to every other word.

For example in the sentence "The cat quickly jumped over the brown fence and landed gracefully on the other side", words like "the" and "and" contribute relatively little attention to understanding the core meaning compared to content words like "cat", "jumped", "fence", and "gracefully".

Therefore, we can introduce sparsity by using a "sliding window" of sorts. In this case, we introduce a mask that reduces compute from O(N^2 D) to O(NSD), where every query attends to S \ll N keys.

This is where sparsity arises. In the dense case, we have to multiply every Q_i with every K_j in the sequence. Sparse attention masks everything that isn't within a region, as in the following window pattern (sketched in code after the list):

  • Q_0: {K_0, K_1}
  • Q_1: {K_0, K_1, K_2}
  • Q_2: {K_1, K_2, K_3}
  • Q_3: {K_2, K_3, K_4}
  • Q_4: {K_3, K_4, K_5}
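
A small NumPy sketch of this banded pattern (the window of one neighbor on each side matches the list above; all names are illustrative):

```python
import numpy as np

def sliding_window_mask(n, window=1):
    """Boolean (n x n) mask: query i may attend only to keys within +/- window."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= window

mask = sliding_window_mask(6, window=1)
for i in range(6):
    print(f"Q_{i} attends to:", [f"K_{j}" for j in np.flatnonzero(mask[i])])
# Q_0 attends to: ['K_0', 'K_1']
# Q_1 attends to: ['K_0', 'K_1', 'K_2']
# Q_2 attends to: ['K_1', 'K_2', 'K_3'] ... and so on.
# Only O(N * S) scores are needed, with S = 2 * window + 1 << N.
```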

Adaptive Splash Attention (Goncalves et al.)

Now that we have seen examples of splash attention and flash attention, we can dive into the literature on Adaptive Splash Attention. Something to keep in mind is that "sparse" has many definitions in attention optimizations.

  • Region-based sparsity: Only compute attention within predefined blocks or regions of the sequence
  • Threshold-based pruning: Zero out attention weights that fall below a certain threshold, removing weak connections
  • Learnable sparsity: Use trainable parameters to adaptively determine which attention connections to keep or drop during training

So currently, the original softmax is dense, meaning that it puts non-zero probability on all tokens.

An alternative to this is the α-entmax transformation, which can learn to put exactly zero probability on some tokens, creating sparse attention patterns. The α parameter controls how sparse the output distribution becomes - as α increases, more tokens get zero probability.

Softmax Alternative

Let's break down the α-entmax formula:

\text{α-entmax}(s) = [(α-1)s - τ]_+^{\frac{1}{α-1}}

Where:

  • s is the input score vector (logits)
  • τ is a threshold/normalizing constant
  • [·]_+ is the ReLU function that zeros out negative values
  • α > 1 is the sparsity parameter

This formula is quite elegant in how it achieves sparsity:

  1. The term (α-1)s - τ shifts and scales the input scores

  2. The ReLU function [·]_+ zeros out any values below the threshold τ/(α-1), creating sparsity

  3. The exponent \frac{1}{α-1} rescales the remaining non-zero values to form a valid probability distribution

The key insight is that unlike softmax, which always gives non-zero probabilities, α-entmax can output exact zeros when input scores fall below the learned threshold τ. The α parameter controls how aggressive this thresholding is:

  • As α \to 1: Approaches softmax (dense)
  • α = 1.5: Moderate sparsity
  • α = 2.0: High sparsity (sparsemax)
  • α > 2: Very sparse attention

The threshold parameter τ plays a crucial role in α-entmax by determining which attention scores get zeroed out. Specifically:

  1. τ is computed to ensure the output probabilities sum to 1: \sum_i [(α-1)s_i - τ]_+^{\frac{1}{α-1}} = 1

  2. Values of s_i where (α-1)s_i < τ get mapped to exactly 0

This thresholding behavior is what enables α-entmax to learn sparse attention patterns adaptively during training, focusing only on the most relevant tokens.
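
To make the thresholding concrete, here is a small worked example (the scores are chosen purely for illustration). Take α = 2 (sparsemax), so the exponent \frac{1}{α-1} = 1 and each probability is simply p_i = [s_i - τ]_+. For scores s = (1.2, 0.8, 0.1), assuming only the two largest entries stay active:

(1.2 - τ) + (0.8 - τ) = 1 \implies τ = 0.5

which is consistent since 0.1 - 0.5 < 0. The result is α-entmax(s) = (0.7, 0.3, 0): the third token receives exactly zero probability, whereas softmax would still assign it roughly 0.17.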

Halley-Bisection Root Finding

To find this τ, we could run a traditional bisection (binary search) algorithm. However, we can instead use Halley's method, which uses first and second derivatives to achieve cubic convergence.

Let's assume that we are trying to find the root of this function:

f(\tau) = \sum_i [(α-1)s_i - \tau]_+^{\frac{1}{α-1}} - 1

The bisection algorithm updates the search interval based on the function value:

  • If f(τ) < 0: Set the interval to (τ_{lo}, τ)
  • Otherwise: Set the interval to (τ, τ_{hi})
  • After each iteration, we update τ to the midpoint: τ = \frac{τ_{lo} + τ_{hi}}{2}
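
A minimal NumPy sketch of this bisection scheme end to end (the bracketing bounds and the fixed iteration count are my own illustrative choices, not the paper's exact procedure):

```python
import numpy as np

def entmax_bisect(s, alpha=1.5, n_iter=50):
    """alpha-entmax for a 1-D score vector, finding tau by bisection."""
    def probs(tau):
        return np.clip((alpha - 1) * s - tau, 0, None) ** (1 / (alpha - 1))

    # f(tau) = sum(probs) - 1 is decreasing in tau, and these bounds bracket its root:
    tau_hi = (alpha - 1) * s.max()       # all entries clipped to 0  -> f(tau_hi) = -1 < 0
    tau_lo = tau_hi - 1                  # largest entry maps to 1   -> f(tau_lo) >= 0
    for _ in range(n_iter):
        tau = 0.5 * (tau_lo + tau_hi)
        if probs(tau).sum() - 1 < 0:     # f(tau) < 0: root lies below tau
            tau_hi = tau
        else:                            # f(tau) >= 0: root lies above tau
            tau_lo = tau
    return probs(0.5 * (tau_lo + tau_hi))

s = np.array([1.5, 1.0, 0.2, -1.0])
print(entmax_bisect(s, alpha=1.5))   # ~[0.66, 0.32, 0.03, 0.00]: exact zero on the last entry
print(entmax_bisect(s, alpha=2.0))   # ~[0.75, 0.25, 0.00, 0.00]: sparser as alpha grows
```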

With Halley's method, we instead use the first and second derivatives:

f'(\tau) = -\sum_i \frac{1}{α-1} [(α-1)s_i - \tau]_+^{\frac{1}{α-1}-1}
f''(\tau) = \sum_i \frac{2-α}{(α-1)^2} [(α-1)s_i - \tau]_+^{\frac{1}{α-1}-2}

which gives Halley's root-finding update:

\tau_{n+1} = \tau_n - \frac{2f(\tau_n)f'(\tau_n)}{2(f'(\tau_n))^2 - f(\tau_n)f''(\tau_n)}

Additionally, to guarantee convergence and keep the estimate of τ bracketed, there is a fail-safe mechanism that falls back to a bisection step whenever Halley's method produces an update that moves outside the current bisection bounds.
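
A sketch of this safeguarded iteration, using the f, f', and f'' defined above (the bracketing bounds, iteration count, and scalar 1-D setting are simplifications of mine; the paper's kernel does this per query row, block by block):

```python
import numpy as np

def tau_halley_bisect(s, alpha=1.5, n_iter=10):
    """Find the alpha-entmax threshold tau with Halley steps plus a bisection fail-safe."""
    exp = 1.0 / (alpha - 1)

    def f_and_derivs(tau):
        t = (alpha - 1) * s - tau
        t = t[t > 0]                                   # only the support contributes
        f = (t ** exp).sum() - 1
        f1 = -(exp * t ** (exp - 1)).sum()             # f'(tau)
        f2 = (exp * (exp - 1) * t ** (exp - 2)).sum()  # f''(tau) = sum (2-a)/(a-1)^2 [.]^(exp-2)
        return f, f1, f2

    tau_hi = (alpha - 1) * s.max()     # f(tau_hi) < 0
    tau_lo = tau_hi - 1                # f(tau_lo) >= 0
    tau = 0.5 * (tau_lo + tau_hi)
    for _ in range(n_iter):
        f, f1, f2 = f_and_derivs(tau)
        if f < 0:                      # keep the root bracketed (f is decreasing in tau)
            tau_hi = tau
        else:
            tau_lo = tau
        denom = 2 * f1 * f1 - f * f2
        tau_next = tau - 2 * f * f1 / denom if denom != 0 else np.inf
        if not (tau_lo < tau_next < tau_hi):   # fail-safe: fall back to a bisection step
            tau_next = 0.5 * (tau_lo + tau_hi)
        tau = tau_next
    return tau

s = np.array([1.5, 1.0, 0.2, -1.0])
tau = tau_halley_bisect(s, alpha=1.5)
p = np.clip((1.5 - 1) * s - tau, 0, None) ** (1 / (1.5 - 1))
print(tau, p, p.sum())   # p sums to ~1, with exact zeros below the threshold
```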

Forward Pass Implementation

For the forward pass, I implemented block-wise computation to keep memory usage low. Let me walk you through how it works:

  1. The computation is split into blocks so we don't need to store the whole attention matrix at once. The algorithm splits Q into T_r blocks and K into T_c blocks.

    f(\tau) = \sum_{j=1}^{T_c} f(\tau; S_i^{(j)})

    Here S_i^{(j)} is just the slice of the score matrix for query block i and key block j, so the root-finding function decomposes into a sum over the key blocks.

  2. The S and P matrices are recomputed during backpropagation, similar to gradient checkpointing. This trades extra compute for a large reduction in memory.

  3. Sparse masking is performed at the block level, based on whether any individual score satisfies S_{ij} > τ_i:

M_{ij} = \begin{cases} 1 & \text{if } \exists\, i' \in \mathcal{I}(i),\ j' \in \mathcal{J}(j) \text{ such that } S_{i'j'} > \tau_{i'} \\ 0 & \text{otherwise} \end{cases}
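
The following sketch shows the idea of this block-level mask in NumPy. For clarity it takes a precomputed score matrix S and a per-row threshold vector tau as inputs; the actual kernel computes this on the fly, tile by tile, without ever materializing S (block sizes here are arbitrary):

```python
import numpy as np

def block_mask(S, tau, block_r=64, block_c=64):
    """(T_r x T_c) boolean grid: True iff the (i, j) tile contains any score
    above its row's entmax threshold, i.e. any entry that survives alpha-entmax."""
    n_q, n_k = S.shape
    T_r = -(-n_q // block_r)               # ceiling division
    T_c = -(-n_k // block_c)
    keep = np.zeros((T_r, T_c), dtype=bool)
    for i in range(T_r):
        rows = slice(i * block_r, min((i + 1) * block_r, n_q))
        for j in range(T_c):
            cols = slice(j * block_c, min((j + 1) * block_c, n_k))
            keep[i, j] = np.any(S[rows, cols] > tau[rows, None])
    return keep

# Tiles with keep[i, j] == False contribute only zeros after alpha-entmax,
# so both the forward and backward kernels can skip them entirely.
```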

Backward Pass

The backward pass leverages sparsity in the α-entmax Jacobian for memory and compute efficiency. Let's break down how this works:

First, we need to differentiate through:

\mathbf{p} = \text{α-entmax}(\mathbf{s})

where \mathbf{s} is the attention score vector and \mathbf{p} is the attention weight vector. We need \frac{\partial \mathbf{p}}{\partial \mathbf{s}} for backpropagation.

The Jacobian of α-entmax (Peters et al., 2019) is:

\frac{\partial\, \text{α-entmax}(\mathbf{s})}{\partial \mathbf{s}} = \text{Diag}(\mathbf{u}) - \frac{\mathbf{u} \mathbf{u}^\top}{\|\mathbf{u}\|_1}

where u_j = p_j^{2 - \alpha}. This Jacobian is naturally sparse since many p_j = 0 (and thus u_j = 0), zeroing out the corresponding rows and columns.

For efficient block-wise computation, we define:

  • U \in \mathbb{R}^{n \times n}: matrix with U_{lk} = P_{lk}^{2 - \alpha}
  • U_i^{(j)} \in \mathbb{R}^{B_r \times B_c}: block of U for query block i and key block j

The gradient with respect to scores is then:

dS_i^{(j)} = U_i^{(j)} \odot dP_i^{(j)} - \text{Diag}(\delta_i)\, U_i^{(j)}

where:

  • dP_i^{(j)} = dO_i V_j^\top
  • \delta_l = \frac{\sum_k U_{lk} \cdot dP_{lk}}{\sum_k U_{lk}} for each row l

The first term U_i^{(j)} \odot dP_i^{(j)} handles element-wise gradient scaling based on the sparsity pattern, while \text{Diag}(\delta_i)\, U_i^{(j)} provides the correction coming from the sum-to-one constraint, ensuring each full row of dS sums to zero.
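
As a rough sketch of how one block of this gradient could be computed (the function name and the per-block δ are my simplifications; in the full algorithm the row sums defining δ run over all key blocks of the row, and the sketch assumes 1 < α ≤ 2):

```python
import numpy as np

def entmax_grad_scores_block(P_blk, dP_blk, alpha=1.5):
    """dS for one (B_r x B_c) block: dS = U ⊙ dP - Diag(delta) U,
    with U = P^(2 - alpha) restricted to the support (U = 0 where P = 0)."""
    U = np.where(P_blk > 0, P_blk ** (2 - alpha), 0.0)    # sparse: zero outside the support
    delta = (U * dP_blk).sum(axis=1) / U.sum(axis=1)      # per-row correction delta_l
    return U * dP_blk - delta[:, None] * U

# Example with a toy sparse attention block and an upstream gradient dP = dO @ V^T
P = np.array([[0.7, 0.3, 0.0],
              [0.0, 0.4, 0.6]])
dP = np.random.randn(2, 3)
dS = entmax_grad_scores_block(P, dP)
print(dS)    # gradients are exactly zero wherever P is zero
```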

This formulation is efficient because it:

  • Avoids materializing the full Jacobian by only computing non-zero elements
  • Works block-wise to align with GPU memory layout
  • Enables fast, memory-efficient backpropagation through α-entmax attention

The rest of the gradient updates for Q, K, V follow from straightforward chain-rule applications of the derivative w.r.t. s.

Code Implementation

You can find my full code implementation in CUDA here.

You can find the author's full code implementation in Triton here.