Parallelizing Linear Transformers with the Delta Rule over Sequence Length (arxiv link).

Motivation

  • Linear attention's state update is purely additive: each step only adds a new key-value association, and nothing already stored can ever be erased.
  • Gated variants add a forgetting (decay) factor, but a decay shrinks the whole memory uniformly and still cannot fully forget a specific piece of previous context.
  • Instead, the model should learn to selectively remove the less important memories, which is what the delta rule provides.
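
For concreteness (standard linear-attention notation, with state $S_t$, key $k_t$, value $v_t$, and decay $\gamma_t$; the delta-rule update itself is given in the next section):

$$
S_t = S_{t-1} + v_t k_t^{\top} \qquad \text{(vanilla linear attention: purely additive)}
$$

$$
S_t = \gamma_t S_{t-1} + v_t k_t^{\top}, \quad \gamma_t \in (0, 1) \qquad \text{(gated variant: the whole state decays uniformly)}
$$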

The Delta Rule

DeltaNet replaces the purely additive update with the classical delta rule, using a data-dependent writing strength $\beta_t \in (0, 1)$:

$$
S_t = S_{t-1} - \beta_t \left( S_{t-1} k_t - v_t \right) k_t^{\top}
$$

This can be regarded as optimizing an online regression loss using a single step of SGD with learning rate $\beta_t$:

$$
\mathcal{L}_t(S) = \tfrac{1}{2} \left\lVert S k_t - v_t \right\rVert^2,
\qquad
S_t = S_{t-1} - \beta_t \, \nabla \mathcal{L}_t(S_{t-1})
$$

An alternative interpretation of DeltaNet comes from the perspective of key-value retrieval: at each step, the model

  1. Retrieves the old value associated with the current key $k_t$: $v_t^{\text{old}} = S_{t-1} k_t$.
  2. Obtains a new value by interpolating between the old value and the current value $v_t$, i.e. $v_t^{\text{new}} = (1 - \beta_t) \, v_t^{\text{old}} + \beta_t v_t$, which then replaces $v_t^{\text{old}}$ in the memory: $S_t = S_{t-1} - v_t^{\text{old}} k_t^{\top} + v_t^{\text{new}} k_t^{\top}$.
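
As a sanity check, here is a minimal sequential sketch of this recurrence in PyTorch (function and variable names are my own; this is the naive $O(T)$ loop, not the paper's chunkwise-parallel algorithm):

```python
import torch

def delta_rule_recurrent(q, k, v, beta):
    """Naive sequential delta-rule recurrence for a single head.

    q, k: (T, d_k) L2-normalized queries and keys
    v:    (T, d_v) values
    beta: (T,)     per-token writing strengths in (0, 1)
    Returns outputs of shape (T, d_v).
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    S = v.new_zeros(d_v, d_k)                            # memory state S_t
    outs = []
    for t in range(T):
        v_old = S @ k[t]                                 # v_old = S_{t-1} k_t
        v_new = (1 - beta[t]) * v_old + beta[t] * v[t]   # interpolate old and new value
        S = S + torch.outer(v_new - v_old, k[t])         # overwrite the association for k_t
        outs.append(S @ q[t])                            # read out with the query: o_t = S_t q_t
    return torch.stack(outs)
```

The loop is inherently sequential; the paper's contribution (hence the title) is reformulating the same recurrence so it can be parallelized over the sequence length during training.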

DeltaNet Transformer

```mermaid
---
config:
  look: handDrawn
---
flowchart BT
in(("in"))
qk_proj("Linear")
qk_conv("Conv<sup>*</sup>")
qk_silu("SiLU")
qk_l2norm("L2 Norm")
v_proj("Linear")
v_conv("Conv<sup>*</sup>")
v_silu("SiLU")
beta_proj("Linear")
beta_sigmoid("Sigmoid")
delta_rule("Delta Rule")
rmsnorm("RMSNorm")
o_proj("Linear")
out(("out"))


in --> qk_proj --> qk_conv --> qk_silu --> qk_l2norm -- "q, k" --> delta_rule
in --> v_proj --> v_conv --> v_silu -- "v" --> delta_rule
in --> beta_proj --> beta_sigmoid -- "beta" --> delta_rule
delta_rule --> rmsnorm --> o_proj
o_proj --> out
```

  • The convolution layers (Conv*) are depthwise-separable convolutions, which generalize the shift SSM.
  • One hybrid variant interleaves DeltaNet layers with sliding-window attention (SWA) layers.
  • Another hybrid variant keeps only two global-attention layers: the 2nd and the ($L/2$)-th layer (with $L$ layers in total).
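
The diagram maps fairly directly onto code. Below is a single-head sketch of the token-mixing block (module names, dimensions, and kernel size are my assumptions; the diagram's shared q/k chain is split into two projections here, and `delta_rule_recurrent` is the naive loop from above, not the paper's chunkwise-parallel kernel):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeltaNetBlock(nn.Module):
    """Single-head sketch of the DeltaNet token-mixing layer in the diagram above."""

    def __init__(self, d_model: int, d_key: int, d_value: int, conv_size: int = 4):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_key, bias=False)
        self.k_proj = nn.Linear(d_model, d_key, bias=False)
        self.v_proj = nn.Linear(d_model, d_value, bias=False)
        self.beta_proj = nn.Linear(d_model, 1, bias=False)
        # Depthwise (per-channel) causal short convolutions: the Conv* boxes.
        self.q_conv = nn.Conv1d(d_key, d_key, conv_size, groups=d_key, padding=conv_size - 1)
        self.k_conv = nn.Conv1d(d_key, d_key, conv_size, groups=d_key, padding=conv_size - 1)
        self.v_conv = nn.Conv1d(d_value, d_value, conv_size, groups=d_value, padding=conv_size - 1)
        self.norm = nn.RMSNorm(d_value)   # requires a recent PyTorch; otherwise hand-roll RMSNorm
        self.o_proj = nn.Linear(d_value, d_model, bias=False)

    def _causal_conv(self, conv, x):
        # x: (T, d) -> depthwise conv over time, truncated to keep causality -> (T, d)
        T = x.shape[0]
        return conv(x.T.unsqueeze(0))[0, :, :T].T

    def forward(self, x):                                        # x: (T, d_model)
        q = F.silu(self._causal_conv(self.q_conv, self.q_proj(x)))
        k = F.silu(self._causal_conv(self.k_conv, self.k_proj(x)))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)    # L2 Norm boxes
        v = F.silu(self._causal_conv(self.v_conv, self.v_proj(x)))
        beta = torch.sigmoid(self.beta_proj(x)).squeeze(-1)      # (T,) writing strengths
        o = delta_rule_recurrent(q, k, v, beta)                  # Delta Rule box (naive loop above)
        return self.o_proj(self.norm(o))                         # RMSNorm -> output projection
```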