> The bitter lesson [1] is going to eventually come for all of these. Eventually we'll figure out how to machine-learn the heuristic rather than hard code it. Recurrent neural networks (RNNs) do this implicitly, but we don't yet know how to effectively train RNNs on ultra-deep sequences.
HiPPO was brilliant - instead of working with the raw sequence, you work with its weighted Laplace transform, and instead of actually computing that transform from scratch, you find the rule for updating it whenever new data arrives. Furthermore, you can 'band-limit' the transform (similar to truncating a PCA) to keep only the most important components while still preserving most of the information in the sequence - a common and quite effective compression technique.
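A minimal sketch of that online-update idea (not the actual HiPPO operator - the matrices below are arbitrary stand-ins for HiPPO's principled projection basis): a small fixed-size state summarizes the whole history, and each new element is folded in with a cheap recurrent update instead of re-transforming the sequence.

```python
import numpy as np

# Sketch only: a fixed-size state `c` summarizing the history so far,
# updated online as each new input arrives. A and B are random stable
# stand-ins for HiPPO's carefully chosen basis-projection matrices.

rng = np.random.default_rng(0)
N = 16                                               # size of the compressed state
A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))   # stable-ish dynamics
B = rng.standard_normal((N, 1))
dt = 0.01                                            # discretization step

def step(c, u):
    """Fold one new scalar input u into the running summary c."""
    return c + dt * (A @ c + B * u)                  # Euler step of dc/dt = Ac + Bu

c = np.zeros((N, 1))
for u in np.sin(np.linspace(0, 10, 1000)):           # stream of incoming data
    c = step(c, u)

print(c.shape)   # (16, 1): the whole history lives in a fixed-size state
```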
Any 'fast' transformer is going to be working with some kind of sampling, aggregation, or compression of the long sequence. Sampling is ultimately going to be too noisy, and standard aggregations are going to be too coarse. So the thing to bet on is better compression techniques, which is what the S4/RWKV groups are ultimately working on.
The sequence of model activations is being compressed. S4 treats each activation channel as an independent sequence, applies a learned version of the Laplace transform, and drops the less-significant components.
This is similar to the basic compression you get with PCA or Fourier transforms. These transforms are fully invertible unless you drop the less-significant components. Dropping them means you can only reconstruct a degraded version of the input, and the transform makes it easy to pick the right components to drop.
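To make the analogy concrete, here is a short Fourier-transform version of that trade-off: the transform itself is lossless, and keeping only the largest components gives a graceful, lossy reconstruction.

```python
import numpy as np

# Keep only the largest-magnitude Fourier components of a signal and
# reconstruct from them: the transform is invertible, dropping
# components is what makes the compression lossy.

rng = np.random.default_rng(0)
t = np.linspace(0, 1, 512)
x = (np.sin(2 * np.pi * 5 * t)
     + 0.5 * np.sin(2 * np.pi * 40 * t)
     + 0.1 * rng.standard_normal(t.size))

X = np.fft.rfft(x)                      # invertible transform of the sequence
k = 16                                  # number of components to keep
keep = np.argsort(np.abs(X))[-k:]       # indices of the k largest components
X_compressed = np.zeros_like(X)
X_compressed[keep] = X[keep]            # drop everything else

x_approx = np.fft.irfft(X_compressed, n=x.size)   # degraded reconstruction
print("relative error:", np.linalg.norm(x - x_approx) / np.linalg.norm(x))
```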
I think the jury is still out on whether these will actually scale to ultra-long language-understanding sequences. RWKV, for example, is still trained like GPT, but is architected so it can be run as an RNN at inference time. This is awesome, but it is unclear whether the training regime will limit the effective use of long-range recurrent context.
Training as GPT vs. as an RNN gives numerically identical results with RWKV; they are just two ways of computing the same thing. It's trained in GPT mode because it's cheaper to train that way -- you can parallelize over the sequence length. In practice it isn't going to be any different from training with back-propagation through time for the same sequence length.
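A toy illustration of that equivalence, using a plain linear recurrence as a stand-in for RWKV's actual time-mixing update (this is not the RWKV formula): one pass computes it step by step like an RNN, the other computes the whole sequence at once, and the numbers match.

```python
import numpy as np

# Toy recurrence h_t = d * h_{t-1} + x_t (a stand-in for RWKV's update):
# a sequential RNN-style pass and a whole-sequence parallel pass give
# numerically identical results.

rng = np.random.default_rng(0)
T, d = 128, 0.95
x = rng.standard_normal(T)

# RNN-style: sequential, O(1) state, what you run at inference time
h_rnn = np.zeros(T)
h = 0.0
for t in range(T):
    h = d * h + x[t]
    h_rnn[t] = h

# GPT-style: h_t = sum_{s<=t} d^(t-s) * x_s for the whole sequence at once,
# which is what lets training parallelize over the sequence length
s = np.arange(T)
weights = np.tril(d ** (s[:, None] - s[None, :]))   # d^(t-s) for s <= t, else 0
h_par = weights @ x

print(np.allclose(h_rnn, h_par))   # True: two ways of computing the same thing
```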
The current versions of RWKV slowly go insane when exposed to sequences that are too long, because the state gradually diverges once you run past the context length used during training. They are experimenting with ways to avoid this, though: https://github.com/Blealtan/RWKV-LM-LoRA/tree/dev-infctx
Linear RNNs and RWKV are examples of RNNs on deep sequences:
https://arxiv.org/abs/2303.06349
https://arxiv.org/abs/2305.13048