Let’s talk about a topic that is - thankfully - becoming more and more relevant in AI: alternative architectures to Transformers.
What might look like me geeking out on non-mainstream stuff really isn't. Transformers suffer from a few important limitations, and different groups of researchers and developers are hard at work to mitigate or remove them; they have come up with some very innovative ideas.
We will talk about the main issues with Transformers, how RNNs and SSMs are being proposed as valid alternatives and where this might lead us in the near future. Skip to the Conclusions section if your attention span is less than that of a chihuahua.
What’s the Issue with Transformers?
The same architectural choice that makes Transformers so powerful is also the source of most of their bottlenecks. I'm talking about Self-Attention.
The Compute Cost of Attention
Self-Attention is the mechanism that lets the model attend to every token in a sentence to work out what each token refers to.
Attention helps Transformers focus on parts of the inputs that are most important before making a prediction, such as providing an answer to a question. Take the sentence “She spoke to the person on the boat”. Attention helps a transformer understand whether "on the boat" describes "the person" (as being on the boat), or if it means "she" spoke while they both were on the boat.
Attention is also the Transformer's Achilles' heel. A vanilla implementation is quadratic in cost: the compute required grows with the square of the input length, O(n²).
As an example, GPT-3 initially had a context window of 4k tokens, while GPT-4 has 128k tokens. So GPT-4's context window is 32x larger, but the compute cost is 1024x higher, and that hurts Sam quite a bit. So he made GPT-4 more stupid (I guess?).
Context is the data you provide to an LLM; the number of tokens you can pass to the model in a single go is the context window. Longer context windows mean more data the LLM can reason over, but they also mean higher compute costs.
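To make the quadratic scaling concrete, here is a minimal NumPy sketch of what a single vanilla attention head computes. The dimensions are made up and not tied to any particular model; the thing to notice is the n × n score matrix that has to be built before anything else happens.

```python
# Minimal sketch of a single vanilla attention head (illustrative dimensions,
# not any specific model). The (n, n) score matrix is the culprit: both the
# matmul work and the intermediate values grow with the square of n.
import numpy as np

def naive_attention(Q, K, V):
    # Q, K, V: (n, d) matrices of query / key / value vectors
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                     # (n, n)  <- quadratic in n
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)    # row-wise softmax
    return weights @ V                                # (n, d)

n, d = 4_096, 64
Q = K = V = np.random.randn(n, d).astype(np.float32)
out = naive_attention(Q, K, V)
print(out.shape)  # (4096, 64), after materializing a 4096 x 4096 score matrix
```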
It's clear by now what the most important issue to address is if we want to scale LLMs.
Problem #1: Attention is expensive and it accounts for the majority of compute costs at inference time (and it’s as expensive at training time too).
And there's more: Attention is calculated over a fixed context window, so if we need to pass more information to the model - say, because we want to analyze a large document - we simply can't; we are stuck with the context window we have.
The Memory Cost of Attention
As mentioned above, Attention… attends to every token in the input, meaning that all intermediate values representing which token attends to which other token must be stored somewhere (in memory, and VRAM is very expensive) for the entire input length. Guess how this matrix grows? Quadratically as well.
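A quick back-of-the-envelope calculation - my own illustrative numbers, assuming fp16 and a single fully materialized matrix per head, per layer - shows how fast that grows:

```python
# Back-of-the-envelope: bytes needed to materialize one n x n attention matrix
# in fp16 (2 bytes per entry), per head and per layer. Purely illustrative.
def attn_matrix_gib(n_tokens: int, bytes_per_entry: int = 2) -> float:
    return n_tokens**2 * bytes_per_entry / 1024**3

for n in (4_096, 32_768, 128_000):
    print(f"{n:>7} tokens -> {attn_matrix_gib(n):6.1f} GiB per head, per layer")
# 4k tokens is ~0.03 GiB; 128k tokens is ~30.5 GiB. Hence all the tricks
# (FlashAttention and friends) that avoid materializing the full matrix.
```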
Problem #2: The memory requirements to compute Self-Attention are also quadratic.
The Time Cost of Attention
At this point I think it’s clear where this is going. If compute and memory both go up quadratically, the time to compute a prediction also grows, so for longer input sequences, Transformers will take longer to produce a token than they would for shorter inputs.
Problem #3: Longer inputs require longer processing time to output a prediction
The Data Cost of Attention
We can safely say the jury is still out on this one, and I'll try to distill the concern as much as possible: we don't know exactly how efficient Transformers are at learning, or in other words, how many tokens per parameter are required to fully train an LLM of a given size.
I'd like to give an overview of how complex this estimation is, because whatever number comes out at the end of this journey will define exactly who can successfully train an LLM and who cannot.
In 2020, OpenAI published a paper stating that to properly train an LLM one needs 1.7 tokens per parameter (let’s create a new unit and call it TPP).
In 2022 a seminal paper came up with the Chinchilla Scaling Law stating that you needed 20 TPP.
Earlier this month (April 2024) another group tried to reproduce Chinchilla's results; they found an issue (confirmed by the original authors) in the method used, and the corrected estimate came out at 25 TPP.
At this point you might extrapolate and say: alright, let's pencil in a range of 20-25 TPP and call it a day, right? If only life were that simple! You would be missing two fundamental factors: data quality and Meta.
In this paper researchers tried to figure out how LLMs' capabilities increase as their size (number of parameters) grows. They found that if good data is mixed with crappy data, all other things being equal, the model's capabilities decrease 20x! To add even more complexity: even the rarity of a piece of knowledge has an important impact on the model's capabilities.
Then comes Meta, which decided to throw out everything we learned from Chinchilla and trained Llama 3 at 1875 TPP!! In the previous update we discussed Llama 3 extensively, and we found that, at comparable size, it eats Mistral for dinner - and Mistral was, until a few days ago, the best LLM you could get your hands on.
🤔 What's going on here? Mistral (small) and Mixtral 8x7B were already trained way above the Chinchilla threshold, at 429 TPP and 172 TPP respectively, and both of them outperformed, by a large margin, all other LLMs of the same size.
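If you want to play with the TPP arithmetic yourself, it's a one-liner. The sketch below uses two publicly reported figures - Chinchilla's 70B parameters trained on 1.4T tokens, and Meta's "15T+ tokens" for Llama 3 8B - which is where the 20 and 1875 above come from.

```python
# Tokens-per-parameter (TPP), the made-up unit from above: training tokens
# divided by parameter count. Two examples with publicly reported numbers.
def tpp(train_tokens: float, n_params: float) -> float:
    return train_tokens / n_params

print(tpp(1.4e12, 70e9))  # Chinchilla 70B on 1.4T tokens -> 20.0 TPP
print(tpp(15e12, 8e9))    # Llama 3 8B on ~15T tokens     -> 1875.0 TPP
```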
For the longest time I had the feeling (very unscientific, I know) that LLMs were undertrained. I built this impression not by training a lot of LLMs - unfortunately Jensen keeps replying to my letters by enclosing a tortured chip in the envelope, maybe a hint to stop asking him for GPUs? - but by studying them quite a lot. And what was just an intuition is now showing up as a real possibility.
Llama 3 set a new performance standard, so what's next? Where is the inflection point at which we hit diminishing returns? We don't know, but one thing is becoming increasingly clear.
Problem #4: Transformers appear to be more data-hungry than we suspected
The Problem with Complex Problems
Unexpectedly, I know: complex problems are hard! Go figure 🤷♂️.
A class of decision problems called NC1-complete contains problems that are inherently sequential. The counterpart is the class TC0, whose problems are instead fully parallelizable.
Two simple examples:
NC1: code evaluation, checking the parity of a string (whether the number of 1s is even or odd)
TC0: sorting numbers, multiplying two numbers
LLMs can express problems in TC0, but they cannot find exact solutions to NC1 problems (hard state-tracking problems). You can do a quick check yourself using parity and a long enough string, but you should read this paper for a more rigorous treatment of the claim.
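If you want to run that parity check yourself, here's a throwaway snippet to generate the test (the string length is arbitrary; longer strings trip models up more reliably). Paste the string into your favorite chatbot, ask for the parity, and compare with the exact answer.

```python
# Quick-and-dirty parity test to run against any chatbot: generate a long
# random bit string, compute the exact parity, then ask the model the same
# question and compare. Length is arbitrary; try 50, 100, 200...
import random

n = 200
bits = "".join(random.choice("01") for _ in range(n))
parity = "odd" if bits.count("1") % 2 else "even"

print(f"Ask the model: 'Is the number of 1s in {bits} even or odd?'")
print(f"Exact answer: {parity}")
```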
Alright, so we have a class of problems - useful in real life - that Transformers cannot solve. This is a more fundamental issue than the ones we analyzed before, because it means that no amount of optimization will make these problems tractable under a regular Transformer architecture.
Problem #5: Transformers do not learn hard state tracking
That was a lot. Go drink some kombucha, make yourself an avocado 🥑 toast, put on a good vinyl, change into your best dungarees and then come back for the second part.
RNNs and SSMs to the Rescue? Maybe.
For clarity: there is a lot of… Attention on the issues with Self-Attention. There have been many developments - Flash Attention, Ring Attention, Flash Decoding, Grouped Query Attention, Infini-Attention and more - aimed at mitigating quadratic scaling with minimal loss of performance, but while mitigated, the core problem is still there.
Another approach is to use a different architecture altogether, and that's where RNNs (Recurrent Neural Networks) and SSMs (State-Space Models) enter the picture.
The concept of Attention was not invented with Transformers; it was proposed in 2014 for Neural Machine Translation, to give an RNN the ability to soft-search across its context instead of cramming everything into a fixed-length vector. RNNs also have a tendency to focus on the last part of the input and gradually "forget" the beginning, and Attention helped mitigate this issue. To be completely fair, even if not called "Attention", similar mechanisms were already used in LSTMs and other architectures. It's a long story, let's leave it at this for now.
RNNs, as powerful as they are, suffer from their own issues too - but don't we all 🙃? Yes, RNNs can solve problems in NC1 and they can - in theory - have a context window of arbitrary length. This means we can pass the model a couple of pages or the entire dictionary from Aardvark to Zyzzyva - sound of me frantically flipping through the dictionary to finally figure out what on Earth a Zyzzyva is - without any restriction. But RNNs are hard to train, and the issue with longer contexts is not completely solved.
RWKV
The friendliest acronym in ML stands for Receptance Weighted Key Value, and although the authors say you should pronounce it as RWaKuV, just don't: a single wrong attempt will most likely summon Satan, or one of his acolytes, straight into your living room.
I was making jokes about Yann naming things, but at least he still uses vowels! Ok enough naming-things-shaming for today.
Anyway, imagine a Transformer with Linear Attention and you've got yourself - sort of - RWKV. This architecture tries to address the most important limitations of RNNs: they cannot be easily parallelized and scaled.
One of the most important benefits of this architecture is that, during inference, the memory and compute costs per token remain constant regardless of context length. The other is that context can be nearly infinite. Ok, this is not really true in practice - there is no free lunch, after all - but context can be arbitrarily long.
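To build some intuition for why the per-token cost can stay flat, here is a generic linear-attention recurrence. This is a sketch of the general idea, not RWKV's actual formulas (which add their own decay and receptance terms): the whole past is folded into a fixed-size state, so each new token costs the same no matter how long the context is.

```python
# Generic linear-attention recurrence (an intuition sketch, NOT RWKV's exact
# formulation): the entire past is summarized in a fixed-size state S, so each
# new token costs O(d^2) regardless of how many tokens came before it.
import numpy as np

d = 64
S = np.zeros((d, d))   # running sum of phi(k) v^T outer products
z = np.zeros(d)        # running sum of phi(k), used for normalization

rng = np.random.default_rng(0)
for _ in range(10_000):                      # 10k tokens, constant memory
    q, k, v = rng.standard_normal((3, d)) * 0.1
    phi_q, phi_k = np.exp(q), np.exp(k)      # simple positive feature map
    S += np.outer(phi_k, v)                  # fold this token into the state
    z += phi_k
    out = (phi_q @ S) / (phi_q @ z)          # output for the current token

print(S.shape, z.shape)  # the state size never depends on context length
```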
So you might wonder, problem solved? We all switch to RWKV tomorrow?
If only life was that simple!
Yours truly, a few pages ago
Hold your pandas 🐼 cowperson, not so fast! While it is true that memory and compute remain constant with context length, the statement about context length must be taken with caution.
First of all, Transformers can learn more complex relationships precisely because the Self-Attention mechanism can attend to any individual token, while linear attention compresses the past into a fixed-size state and cannot recall arbitrary tokens exactly, so that additional compute burden comes with some (expensive) advantages.
I can't recall the paper titles, but it seems (if you trust my memory) that the Transformer's learned attention matrix is low-rank - in simple words, it can be made much simpler - which I think is one of the assumptions the authors of RWKV worked from.
Secondly, RWKV remains more sensitive to the prompt than Transformers, because past context information goes through a fixed-size state (aka: it gets compressed).
This is not the end of the world as we know it, but we must be aware of the limitations of each architecture. The RWKV foundation has trained EagleX and published benchmarks: the model does better than Llama 2 7B on some tasks, worse on others, and it's generally a little below Mistral. Mind you, this is nonetheless an amazing result. A new model that performs near the SOTA for a fraction of the cost is no small feat - and you can directly tune RWKV's state to control its behavior, which is insanely cool.
You should keep an eye on RWKV and how it evolves: the community behind it is active, works hard to improve the model, and releases updates quite frequently.
Mamba 🐍
The other major contender is Mamba🐍, which is a State-Space Model. SSMs come from a different direction than you'd expect - Control Theory, specifically - where they're used to model dynamical systems.
Mamba is a model without attention; the math is a little complex, but the key part is something called Discretization.
Discretization allows us to pass from a Continuous representation to a Recurrent or a Convolutional one. You might be like this right now: 🤔.
What this means is that you can use whichever of these representations suits a given task, or to make it simpler: you can choose the Convolutional view for training (as it can be parallelized), the Recurrent view for inference (which is constant-time per token, like in RWKV) and the Continuous view when you need to work with continuous data. Talk about freedom!
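A toy example may help. The sketch below uses a plain 1-D, time-invariant SSM - the S4-style base case; Mamba's selection mechanism makes the parameters input-dependent, which is why it leans on the recurrent view - and checks that the recurrent and convolutional views produce the same output.

```python
# Toy 1-D, time-invariant discretized SSM: x_t = A x_{t-1} + B u_t, y_t = C x_t.
# The recurrent view and the convolutional view (kernel [CB, CAB, CA^2B, ...])
# give the same output; Mamba makes A/B input-dependent ("selective"), but the
# underlying duality is the same idea.
import numpy as np

rng = np.random.default_rng(0)
n = 4                                     # state size
A = np.diag(rng.uniform(0.5, 0.9, n))     # stable discrete-time state matrix
B = rng.standard_normal((n, 1))
C = rng.standard_normal((1, n))
u = rng.standard_normal(16)               # input sequence of length 16

# Recurrent view: one state update per step (what you'd use at inference time)
x = np.zeros((n, 1))
y_rec = []
for u_t in u:
    x = A @ x + B * u_t
    y_rec.append((C @ x).item())

# Convolutional view: precompute the kernel, then convolve (parallelizable training)
K = np.array([(C @ np.linalg.matrix_power(A, t) @ B).item() for t in range(len(u))])
y_conv = np.convolve(u, K)[: len(u)]

print(np.allclose(y_rec, y_conv))         # True: the two views agree
```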
As you may have gathered by now, Mamba is not an RNN (except that its recurrent representation can be modelled as one) nor a traditional Neural Network; in fact, it's something quite distinct.
Benchmarks have shown remarkable performance, in the sense that Mamba matches and sometimes surpasses SOTA models on the same tasks, which is a big win. This means we already have 2 viable alternatives to Transformers - equally or maybe a little more capable - whose compute costs scale linearly!
In Conclusion
You either made it to this point (congrats Axel! 💪) or you skipped here after reading 3 paragraphs (hi Terry 👋), in any case, what do we make of all of this?
Side note: I couldn't really address all possible pros and cons. Transformers, for instance, are better than all other architectures at copying their inputs back to the output; RWKV can selectively forget parts of the input; and so on. It was just too much to fit into a single post, and I had to leave a lot of concepts out.
Transformers are powerful and all large companies are working to make them more efficient in an effort to drive down the cost of the AI services they provide. But as customer demand for more powerful LLMs grows, at some point the dinosaur 🦖 in the room will need to be addressed.
So far, none of the large companies has - at least publicly - invested significant resources in training large versions of either RWKV or Mamba, but the models we do have are working well and the cost reduction at scale is indeed significant.
There is no winner-takes-all situation here. My view is that we will move to a multiple-model scenario fronted by a Router Model, and possibly a policy layer to optimize parameters like cost, quality and performance. Offering multiple models has no significant cost impact vs hosting a single one (at least at scale). I see no harm in directing queries to Mamba for specific data types, using RWKV for RAG, and a regular Transformer - or even a series of Transformers, either via MoE or simply by using multiple models - to answer the remaining queries.
Cost is a concern and the input query is very long? Route to RWKV. You need to answer a question about a dense legal document? Route to a GPT. You need to analyze sequential data? Route to Mamba. You need to quickly support a customer's request? Send it to Llama 3 8B. You need to solve a complex state-tracking problem? Send it to RWKV, and so on. You get the gist.
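To make the idea concrete, here is a toy rule-based router. The thresholds and model names are entirely made up, and a real Router Model would be learned rather than hand-written - this is just the shape of the thing.

```python
# Toy rule-based router (thresholds and model names are made up for
# illustration; a production router would be a learned model plus a policy
# layer balancing cost, quality and latency).
def route(query: str, n_tokens: int, task: str) -> str:
    if task == "state_tracking":      # hard state tracking -> recurrent model
        return "rwkv"
    if task == "sequence_analysis":   # long sequential data -> SSM
        return "mamba"
    if n_tokens > 32_000:             # very long, cost-sensitive context
        return "rwkv"
    if task == "legal_qa":            # dense documents where exact recall matters
        return "gpt-4-class-transformer"
    return "llama-3-8b"               # cheap default for everyday queries

print(route("summarize this contract", 50_000, "rag"))  # -> rwkv
```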
It's exciting to have the privilege of choice and the ability to optimize across different dimensions to support different business cases - and use cases as well - as needed.
So far, I can't say that any of these architectures will replace Transformers, but that's not the point. We must always use the best tool for the job; sometimes that tool will be a Transformer, sometimes it won't. Before, we didn't have a viable alternative - now we do.
See you next time!
PS
The Illusion of State in State-Space Models?
I started reading this paper with half a heartbreak 💔 while I was writing this update, and it seemed fair to add this section, since we talked about Mamba🐍 and the paper appeared to be a rebuttal of SSMs.
The paper states that SSMs, like Transformers, cannot express computation beyond the TC0 class, contrary to what I wrote above. So the state is, in fact, an illusion. Bummer.
But, in a plot twist, the same researchers found a way, with a minimal modification, to ensure SSMs can learn hard state tracking in the same way RNNs do.
Within the space of a single paper, they gutted the poor Mamba🐍 and then lazarused it back to life with new powers. Pheww 😰.
First time I started sweating while reading a research paper!
Yes, SSM state appears to be an illusion but a minimal modification makes it real. I couldn’t have hoped for a better ending!