Summary
- Flash Linear Attention (FLA): hardware/IO-aware linear attention kernel.
- Chunk-wise block-parallel form interpolates between the parallel and recurrent forms.
- Data-dependent gating mechanism for linear attention that still admits a hardware-efficient chunkwise form.
Motivation
- Existing implementations of linear attention lack I/O-awareness and are thus slower in practice than highly optimized implementations of softmax attention such as FlashAttention.
- Forget gates have been shown to be crucial in RNNs.
- RetNet: non-data-dependent decay factor.
- Mamba's element-wise recurrence cannot be trained with matmul-based parallelism; its parallel scan does not use tensor cores, making training less hardware-efficient.
- Mamba2 uses a more restricted gating mechanism: the decay is a single data-dependent scalar per head.
Methodology
Chunk-Wise Block-Parallel
- Softmax attention can be fully parallelized across the sequence during training, at the cost of quadratic complexity in sequence length.
- The chunk-wise form splits the sequence into chunks: intra-chunk interactions are computed in parallel with matmuls (quadratic only within a chunk), while inter-chunk interactions pass through the recurrent state, as sketched below.
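A minimal PyTorch reference sketch of this chunk-wise computation for plain (ungated) linear attention, assuming the sequence length divides evenly into chunks; the real FLA kernels are fused, I/O-aware GPU kernels, and the function and variable names here are purely illustrative:

```python
import torch

def chunkwise_linear_attention(q, k, v, chunk_size=64):
    """Chunk-wise (block-parallel) causal linear attention, without gating
    or normalization.  q, k: (B, L, d_k), v: (B, L, d_v); L must be divisible
    by chunk_size."""
    B, L, d_k = q.shape
    d_v = v.shape[-1]
    C = chunk_size
    # Reshape the sequence into chunks of length C: (B, L // C, C, d).
    q, k, v = (x.reshape(B, L // C, C, -1) for x in (q, k, v))

    S = q.new_zeros(B, d_k, d_v)   # recurrent chunk-level state (sum of k^T v)
    causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device))
    out = torch.empty(B, L // C, C, d_v, dtype=q.dtype, device=q.device)

    for i in range(L // C):
        q_i, k_i, v_i = q[:, i], k[:, i], v[:, i]
        # Inter-chunk part: attend to everything before this chunk via the state.
        inter = q_i @ S                                        # (B, C, d_v)
        # Intra-chunk part: causal attention inside the chunk (quadratic in C only).
        scores = (q_i @ k_i.transpose(-1, -2)).masked_fill(~causal, 0.0)
        intra = scores @ v_i                                   # (B, C, d_v)
        out[:, i] = inter + intra
        # Update the state with this chunk's key-value outer products.
        S = S + k_i.transpose(-1, -2) @ v_i                    # (B, d_k, d_v)

    return out.reshape(B, L, d_v)
```

Setting `chunk_size=1` recovers the recurrent form, while `chunk_size=L` recovers the fully parallel (quadratic) form, which is the sense in which the chunk-wise form interpolates between the two.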
GLA
Recurrent Form
A general data-dependent forgetting gate can be formulated as $S_t = G_t \odot S_{t-1} + k_t^\top v_t$, where $G_t \in (0, 1)^{d_k \times d_v}$ is computed from the input $x_t$.
- A naive mapping from $x_t \in \mathbb{R}^d$ to $G_t$ would require a weight matrix of size $d \times d_k d_v$, which would be parameter-inefficient.
- Outer-product-based low-rank parameterization: $G_t = \alpha_t^\top \beta_t$, with $\alpha_t = \sigma(x_t W_\alpha)$ and $\beta_t = \sigma(x_t W_\beta)$.
- Mamba2 uses $G_t = \gamma_t \mathbf{1}\mathbf{1}^\top$, where $\gamma_t$ is a data-dependent scalar.
GLA adopts a middle ground: $\beta_t = \mathbf{1}$, so $G_t = \alpha_t^\top \mathbf{1}$ and the recurrence becomes $S_t = \mathrm{Diag}(\alpha_t)\, S_{t-1} + k_t^\top v_t$ with output $o_t = q_t S_t$, where $\alpha_t = \sigma(x_t W_{\alpha^-} W_{\alpha^+})$ is produced by a low-rank projection, as shown in the diagram and code sketch below:
```mermaid
flowchart LR
    subgraph Recurrent
        q_proj("Q")
        k_proj("K")
        v_proj("V")
        alpha_down[/"`$$W_{\alpha^-}$$`"\]
        alpha_up[\"`$$W_{\alpha^+}$$`"/]
        sigmoid(("`$$\sigma$$`"))
        matmul(("`$$\times$$`"))
        add(("`$$+$$`"))
        matmul2(("`$$\times$$`"))
        diag("`$$\text{Diag}(\cdot)$$`")
        outer_product(("`$$\times$$`"))
        v_proj --> outer_product
        k_proj --> outer_product
        outer_product --> add
        alpha_down --> alpha_up
        alpha_up --> sigmoid --"`$$\alpha_t$$`"--> diag --"`$$G_t$$`"--> matmul --> add
        add --> matmul2
        q_proj --> matmul2
    end
    state(("`$$S_{t - 1}$$`"))
    input(("`$$x_t$$`"))
    output(("`$$o_t$$`"))
    state_next(("`$$S_t$$`"))
    input --> q_proj
    input --> k_proj
    input --> v_proj
    input --> alpha_down
    state --> matmul
    matmul2 --> output
    add --> state_next
```
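A step-by-step recurrent-form sketch in PyTorch of one GLA head matching the diagram above, for clarity only; it is not the hardware-efficient chunk-wise training form, it omits details of the full model (gate temperature, normalization, output gating), and the class and argument names are illustrative:

```python
import torch
import torch.nn as nn

class GLARecurrentHead(nn.Module):
    """One GLA head in recurrent form: S_t = Diag(alpha_t) S_{t-1} + k_t^T v_t,  o_t = q_t S_t."""

    def __init__(self, d_model, d_k, d_v, alpha_rank=16):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_k, bias=False)
        self.k_proj = nn.Linear(d_model, d_k, bias=False)
        self.v_proj = nn.Linear(d_model, d_v, bias=False)
        # Low-rank gate parameterization: W_{alpha^-} (down) followed by W_{alpha^+} (up).
        self.alpha_down = nn.Linear(d_model, alpha_rank, bias=False)
        self.alpha_up = nn.Linear(alpha_rank, d_k, bias=False)

    def forward(self, x):
        """x: (B, L, d_model) -> o: (B, L, d_v)."""
        B, L, _ = x.shape
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        alpha = torch.sigmoid(self.alpha_up(self.alpha_down(x)))   # (B, L, d_k), entries in (0, 1)
        S = x.new_zeros(B, q.shape[-1], v.shape[-1])               # initial state S_0 = 0
        outputs = []
        for t in range(L):
            # Diag(alpha_t) S_{t-1}: broadcast alpha_t over the rows of S,
            # then add the outer product k_t^T v_t.
            S = alpha[:, t].unsqueeze(-1) * S \
                + k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(-2)
            outputs.append(q[:, t].unsqueeze(-2) @ S)              # o_t = q_t S_t, shape (B, 1, d_v)
        return torch.cat(outputs, dim=-2)
```

The broadcasted multiply `alpha[:, t].unsqueeze(-1) * S` applies $\mathrm{Diag}(\alpha_t)\, S_{t-1}$ without materializing the diagonal matrix.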