Parallelizing Linear Transformers with the Delta Rule over Sequence Length (arxiv link).
Motivation
- Linear attention's memory update is purely additive: each step adds a new key-value association to the state and never removes anything.
- A forgetting (decay) factor shrinks the memory over time but cannot fully forget specific previous context.
- Instead, the model should learn to selectively remove less important memories.
The Delta Rule
The delta-rule state update can be regarded as taking a single step of SGD on an online regression loss.
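Concretely, writing the memory as $S_t$, the current key and value as $k_t, v_t$, and the write strength $\beta_t$ as the step size (notation reconstructed from the paper, so treat this as a sketch rather than a quotation):

$$
\mathcal{L}_t(S) = \tfrac{1}{2}\,\lVert S k_t - v_t \rVert^2,
\qquad
S_t = S_{t-1} - \beta_t \,\nabla_S \mathcal{L}_t(S_{t-1})
    = S_{t-1} - \beta_t \,(S_{t-1} k_t - v_t)\, k_t^\top .
$$

With L2-normalized keys, $\beta_t = 1$ exactly replaces the value stored under $k_t$, while $\beta_t = 0$ leaves the memory unchanged.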
An alternative interpretation of DeltaNet is from the key-value retrieval perspective:
- Retrieve the old value $v_t^{\text{old}} = S_{t-1} k_t$ currently associated with the key $k_t$.
- Form a new value $v_t^{\text{new}} = \beta_t v_t + (1 - \beta_t)\, v_t^{\text{old}}$ by interpolating between the old value and the current value $v_t$, and write it back in place of $v_t^{\text{old}}$ in the memory (see the sketch below).
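A minimal single-head sketch of this recurrence in plain PyTorch; the shapes and names are my own, and the paper's actual contribution is a chunkwise-parallel algorithm for this recurrence rather than the sequential loop shown here:

```python
import torch

def delta_rule_recurrent(q, k, v, beta):
    """Naive O(L) delta-rule recurrence (key-value retrieval view).

    Assumed shapes (not the paper's code):
      q, k: (L, d_k)  -- queries/keys, L2-normalized
      v:    (L, d_v)  -- values
      beta: (L,)      -- write strengths in (0, 1)
    Returns outputs of shape (L, d_v).
    """
    L, d_k = k.shape
    d_v = v.shape[-1]
    S = k.new_zeros(d_v, d_k)                    # memory mapping keys -> values
    outputs = []
    for t in range(L):
        k_t, v_t, b_t = k[t], v[t], beta[t]
        v_old = S @ k_t                          # retrieve value stored under k_t
        v_new = b_t * v_t + (1 - b_t) * v_old    # interpolate old and new value
        S = S + torch.outer(v_new - v_old, k_t)  # overwrite v_old with v_new
        outputs.append(S @ q[t])                 # read out with the query
    return torch.stack(outputs)
```

Since the memory change is the rank-1 correction $(v_t^{\text{new}} - v_t^{\text{old}})\,k_t^\top = -\beta_t (S_{t-1} k_t - v_t)\,k_t^\top$, this is algebraically the same update as the SGD form above.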
DeltaNet Transformer
```mermaid
---
config:
  look: handDrawn
---
flowchart BT
    in(("in"))
    qk_proj("Linear")
    qk_conv("Conv<sup>*</sup>")
    qk_silu("SiLU")
    qk_l2norm("L2 Norm")
    v_proj("Linear")
    v_conv("Conv<sup>*</sup>")
    v_silu("SiLU")
    beta_proj("Linear")
    beta_sigmoid("Sigmoid")
    delta_rule("Delta Rule")
    rmsnorm("RMSNorm")
    o_proj("Linear")
    out(("out"))

    in --> qk_proj --> qk_conv --> qk_silu --> qk_l2norm -- "q, k" --> delta_rule
    in --> v_proj --> v_conv --> v_silu -- "v" --> delta_rule
    in --> beta_proj --> beta_sigmoid -- "beta" --> delta_rule
    delta_rule --> rmsnorm --> o_proj
    o_proj --> out
```
- The convolution layers (Conv<sup>*</sup> in the diagram) are depthwise-separable convolutions that generalize the shift SSM (see the block sketch below).
- Interleave DeltaNet layers with sliding-window attention (SWA) layers.
- Two layers use global attention: the 2nd and the -th layer.
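For concreteness, here is a minimal PyTorch sketch of the block in the diagram above. Everything in it is an assumption made for illustration (single head, separate q/k projections, kernel size 4 for the depthwise causal convolution, `nn.RMSNorm` from recent PyTorch versions), and the delta rule is run as a naive loop rather than in the paper's chunkwise-parallel form:

```python
import torch
import torch.nn.functional as F
from torch import nn

class DeltaNetBlock(nn.Module):
    """Single-head sketch of the DeltaNet token-mixing block (assumed shapes)."""

    def __init__(self, d_model: int, d_head: int = 64, conv_size: int = 4):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_head, bias=False)
        self.k_proj = nn.Linear(d_model, d_head, bias=False)
        self.v_proj = nn.Linear(d_model, d_head, bias=False)
        self.beta_proj = nn.Linear(d_model, 1, bias=False)
        # Depthwise causal convolutions over the sequence (generalizing a shift SSM).
        self.q_conv = nn.Conv1d(d_head, d_head, conv_size, groups=d_head, padding=conv_size - 1)
        self.k_conv = nn.Conv1d(d_head, d_head, conv_size, groups=d_head, padding=conv_size - 1)
        self.v_conv = nn.Conv1d(d_head, d_head, conv_size, groups=d_head, padding=conv_size - 1)
        self.norm = nn.RMSNorm(d_head)            # requires a recent PyTorch release
        self.o_proj = nn.Linear(d_head, d_model, bias=False)

    def _causal_conv(self, conv, x):              # x: (B, L, d)
        L = x.shape[1]
        return conv(x.transpose(1, 2))[..., :L].transpose(1, 2)  # trim right padding

    def forward(self, x):                         # x: (B, L, d_model)
        q = F.silu(self._causal_conv(self.q_conv, self.q_proj(x)))
        k = F.silu(self._causal_conv(self.k_conv, self.k_proj(x)))
        v = F.silu(self._causal_conv(self.v_conv, self.v_proj(x)))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)    # L2 norm on q and k
        beta = torch.sigmoid(self.beta_proj(x))                  # (B, L, 1) write strengths
        B, L, d = q.shape
        S = x.new_zeros(B, d, d)                                  # per-sequence memory
        outs = []
        for t in range(L):                                        # naive delta-rule recurrence
            k_t, v_t, b_t = k[:, t], v[:, t], beta[:, t]
            err = torch.einsum('bij,bj->bi', S, k_t) - v_t        # S k_t - v_t
            S = S - b_t.unsqueeze(-1) * torch.einsum('bi,bj->bij', err, k_t)
            outs.append(torch.einsum('bij,bj->bi', S, q[:, t]))   # o_t = S_t q_t
        o = torch.stack(outs, dim=1)
        return self.o_proj(self.norm(o))
```

In the actual model this token mixer would run per head inside a standard Transformer block, with the sequential loop replaced by the parallel-over-sequence-length algorithm that the paper develops.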