Summary

  • Flash Linear Attention (FLA): a hardware- and I/O-aware linear attention kernel.
    • Its chunk-wise block-parallel form interpolates between the fully parallel and fully recurrent forms.
  • Gated Linear Attention (GLA): a data-dependent gating mechanism for linear attention that still admits a hardware-efficient chunk-wise form.

Motivation

  • Current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention such as FlashAttention.
  • Forget gates have been shown to be crucial in gated RNNs (e.g., the LSTM).
    • RetNet: uses a fixed, data-independent decay factor.
    • Mamba's selective recurrence does not admit a matmul-based chunk-wise parallel form; its scan-based training cannot exploit tensor cores and is less hardware-efficient.
    • Mamba2: uses a more restricted gating mechanism, in which the decay $\gamma_t$ is a single data-dependent scalar per time step.

Methodology

Chunk-Wise Block-Parallel

  • Softmax attention can be fully parallelized during training, but only at the cost of quadratic complexity in sequence length, whereas the recurrent form of linear attention is linear in sequence length but sequential.
  • The chunk-wise block-parallel form splits the sequence into chunks of length $C$: interactions within a chunk are computed in parallel with matmuls (quadratic only in $C$), while interactions across chunks flow through the recurrent state, interpolating between the two forms (a plain-PyTorch reference sketch follows).
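
A minimal, non-optimized PyTorch reference for the chunk-wise form of plain (ungated, feature-map-free) linear attention, just to make the interpolation concrete. The function and variable names are my own; the actual FLA kernels fuse these steps in Triton and are I/O-aware, and the gated (GLA) variant additionally applies the decay inside and across chunks.

```python
import torch

def chunkwise_linear_attention(q, k, v, chunk_size=64):
    """Reference (non-I/O-aware) chunk-wise causal linear attention.

    q, k, v: (batch, seq_len, d). Sequence length must be divisible by
    chunk_size. No feature map, normalizer, or gating is applied here.
    """
    B, L, d = q.shape
    assert L % chunk_size == 0
    o = torch.empty_like(v)
    S = q.new_zeros(B, d, d)                       # running inter-chunk state: sum of k^T v
    for start in range(0, L, chunk_size):
        qc = q[:, start:start + chunk_size]        # (B, C, d)
        kc = k[:, start:start + chunk_size]
        vc = v[:, start:start + chunk_size]
        inter = qc @ S                             # contribution of all previous chunks
        attn = (qc @ kc.transpose(-1, -2)).tril()  # causal intra-chunk scores, (B, C, C)
        intra = attn @ vc                          # quadratic only in chunk_size
        o[:, start:start + chunk_size] = inter + intra
        S = S + kc.transpose(-1, -2) @ vc          # fold this chunk into the state
    return o
```

With `chunk_size = L` this recovers the fully parallel quadratic form (the state is never used), and with `chunk_size = 1` it degenerates to the recurrent form, which is exactly the interpolation described above.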

GLA

Recurrent Form

A general data-dependent forgetting gate can be formulated as

$$S_t = G_t \odot S_{t-1} + k_t^\top v_t, \qquad o_t = q_t S_t,$$

where $G_t \in (0,1)^{d_k \times d_v}$ is a gating matrix computed from the current input $x_t$.

  • A naive mapping $x_t \mapsto G_t$ would require producing a matrix of size $d_k \times d_v$ at every step, which would be parameter-inefficient (see the rough comparison below).
  • Outer-product-based low-rank parameterization: $G_t = \alpha_t^\top \beta_t$, with $\alpha_t \in (0,1)^{d_k}$ and $\beta_t \in (0,1)^{d_v}$.
  • Mamba2 uses $G_t = \gamma_t \mathbf{1}\mathbf{1}^\top$, where $\gamma_t \in (0,1)$ is a data-dependent scalar, i.e. the same decay for every entry of $S_t$.
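
For intuition, a back-of-the-envelope comparison of the parameters needed just to produce the gate, assuming $d = d_k = d_v = 1024$ and a low rank of 16 (illustrative assumptions, not numbers from the paper):

$$
\begin{aligned}
\text{naive } x_t \mapsto G_t :&\quad d \cdot d_k d_v = 1024^3 \approx 1.1 \times 10^9 \\
\text{outer product } G_t = \alpha_t^\top \beta_t :&\quad d\,d_k + d\,d_v = 2 \cdot 1024^2 \approx 2.1 \times 10^6 \\
\text{Mamba2 } G_t = \gamma_t \mathbf{1}\mathbf{1}^\top :&\quad d \cdot 1 = 1024 \\
\text{GLA } G_t = \alpha_t^\top \mathbf{1} \text{ (rank 16)} :&\quad d \cdot 16 + 16 \cdot d_k \approx 3.3 \times 10^4
\end{aligned}
$$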

GLA adopts a middle ground: it fixes $\beta_t = \mathbf{1}$, so that $G_t = \alpha_t^\top \mathbf{1}$ and $G_t \odot S_{t-1} = \mathrm{Diag}(\alpha_t)\, S_{t-1}$, giving

$$S_t = \mathrm{Diag}(\alpha_t)\, S_{t-1} + k_t^\top v_t, \qquad \alpha_t = \sigma\!\left(x_t W_{\alpha^-} W_{\alpha^+}\right) \in (0,1)^{d_k},$$

where $\alpha_t$ is produced by a low-rank down-/up-projection (the paper additionally raises the sigmoid output to a power $1/\tau$ to keep the gate close to 1). See the diagram below and the PyTorch sketch after it:

flowchart LR

subgraph Recurrent
    q_proj("Q")
    k_proj("K")
    v_proj("V")
    alpha_down[/"`$$W_{\alpha^-}$$`"\]
    alpha_up[\"`$$W_{\alpha^+}$$`"/]
    sigmoid(("`$$\sigma$$`"))
    matmul(("`$$\times$$`"))
    add(("`$$+$$`"))
    matmul2(("`$$\times$$`"))
    diag("`$$\text{Diag}(\cdot)$$`")
    
    outer_product(("`$$\times$$`"))
    v_proj --> outer_product
    k_proj --> outer_product
    outer_product --> add
    alpha_down --> alpha_up
    alpha_up --> sigmoid --"`$$\alpha_t$$`"--> diag --"`$$G_t$$`"--> matmul --> add
    add --> matmul2
    q_proj --> matmul2
end

state(("`$$S_{t - 1}$$`"))
input(("`$$x_t$$`"))
output(("`$$o_t$$`"))
state_next(("`$$S_t$$`"))

input --> q_proj
input --> k_proj
input --> v_proj 
input --> alpha_down
state --> matmul
matmul2 --> output
add --> state_next
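
As a concrete companion to the diagram, here is a minimal, non-optimized PyTorch sketch of a single GLA recurrence step. The class name, the low-rank width of 16, and the omission of the sigmoid temperature and of any output gating/normalization are simplifications and assumptions on my part; the actual implementation trains with the chunk-wise form via fused Triton kernels rather than this step-by-step recurrence.

```python
import torch
import torch.nn as nn

class GLARecurrentCell(nn.Module):
    """Sketch of one GLA recurrence step, matching the diagram above:
        alpha_t = sigmoid(x_t W_{alpha^-} W_{alpha^+})
        S_t     = Diag(alpha_t) S_{t-1} + k_t^T v_t
        o_t     = q_t S_t
    """
    def __init__(self, d_model, d_key, d_value, low_rank=16):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_key, bias=False)
        self.k_proj = nn.Linear(d_model, d_key, bias=False)
        self.v_proj = nn.Linear(d_model, d_value, bias=False)
        self.alpha_down = nn.Linear(d_model, low_rank, bias=False)  # W_{alpha^-}
        self.alpha_up = nn.Linear(low_rank, d_key, bias=True)       # W_{alpha^+}

    def forward(self, x_t, S_prev):
        # x_t: (B, d_model), S_prev: (B, d_key, d_value)
        q_t = self.q_proj(x_t)
        k_t = self.k_proj(x_t)
        v_t = self.v_proj(x_t)
        # data-dependent decay via the low-rank down-/up-projection
        alpha_t = torch.sigmoid(self.alpha_up(self.alpha_down(x_t)))      # (B, d_key)
        # Diag(alpha_t) S_{t-1} + outer product k_t^T v_t
        S_t = alpha_t.unsqueeze(-1) * S_prev \
              + k_t.unsqueeze(-1) * v_t.unsqueeze(-2)                     # (B, d_key, d_value)
        o_t = torch.einsum('bk,bkv->bv', q_t, S_t)                        # q_t S_t
        return o_t, S_t
```

For example, `o_t, S_t = cell(x_t, S_prev)` rolls the state forward by one token; unrolling this over time reproduces the recurrent form used at inference, while training uses the chunk-wise form sketched earlier.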