State Space Models and the Mamba Architecture

This article is adapted from this lecture.

The quadratic cost in input sequence length that self-attention imposes on the transformer architecture has motivated different lines of research into reducing this cost. For example, Tri Dao’s Flash Attention leverages hardware to significantly improve runtime costs, and approaches like TransformerFAM and Leave No Context Behind introduce modifications to the transformer architecture that enable it to effectively work with indefinitely long input sequences. More generally, there’s been an ongoing effort to increase the effective size of the context window in large language models (LLMs).

LLM context size by Cobus Greyling, August 2023.

Another line of research that has been at least partly motivated by this quadratic cost is research into alternative architectures for language modeling. One such architecture is the state space model, an ongoing line of research for several years that came to more mainstream prominence in late 2023 with the publication of the Mamba model. In this article, I’ll give a high-level overview of state space models, including S4 (a precursor to Mamba), as well as an overview of the Mamba model. For more details, see the references at the end of the article. I will assume you are familiar with the transformer architecture for sequence-to-sequence modeling.

State Space Models

Long-range dependencies are a big challenge for sequence models. As a result, benchmarks such as the Long Range Arena have been designed to foster research in this direction, with more efficient variants of the transformer typically faring better than the vanilla transformer architecture. A family of models that excels at this type of task is state space models.

So, what is a state space model? A continuous state space model is a model that does two things:

  1. Maps a 1-dimensional signal over time \(x(t)\) to an n-dimensional latent state \(\mathbf{h}(t)\).
  2. Projects that latent state (plus the input signal) to a 1-dimensional output \(y(t)\).

Formally, that is:

\[\begin{align} d\mathbf{h}(t)/dt &= \mathbf{Ah}(t) + \mathbf{B}x(t) \\ y(t) &= \mathbf{Ch}(t) + \mathbf{D}x(t), \\ \end{align}\]

where \(\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}\) are learned parameters and \(t\) is time. Normally, \(\mathbf{D} = 0\) for simplicity, because it acts as a skip connection between input and output that is easy to compute. The first equation is a differential equation that describes how the latent state \(\mathbf{h}\) changes over time. The second equation is a signal-to-signal mapping between \(x(t)\) and \(y(t)\). Importantly, both equations describe \(\mathbf{h}\) and \(y\) as a function of the input signal and the latent state \(\mathbf{h}\). This makes the model Markovian: at every time step \(t\), the latent state \(\mathbf{h}\) ideally encodes the entire history of the input signal \(x(t)\).

Now, given that the input signal in language modeling is a discrete sequence of tokens, we need to discretize this continuous model. One intuitive way to see what this discretization process looks like is the Euler method. First, approximate the derivative with a finite difference over a small step size \(\Delta\): \(d\mathbf{h}(t)/dt \approx (\mathbf{h}(t + \Delta) - \mathbf{h}(t)) / \Delta\). Cleaning up notation a bit, where we set \(x_t = x(t)\), \(\mathbf{h}_{t} = \mathbf{h}(t)\) and \(\mathbf{h}_{t+1} = \mathbf{h}(t+\Delta)\), and applying this approximation to the first equation in our model, we have:

\[\begin{align} \mathbf{h}_{t+1} - \mathbf{h}_t &= \Delta(\mathbf{Ah}_t + \mathbf{B}x_t) \\ \mathbf{h}_{t+1} &= \Delta\mathbf{Ah}_t + \Delta\mathbf{B}x_t + \mathbf{h}_t \\ \mathbf{h}_{t+1} &= (\Delta\mathbf{A} + I)\mathbf{h}_t + \Delta\mathbf{B}x_t \\ \mathbf{h}_{t+1} &= (I + \Delta\mathbf{A})\mathbf{h}_t + (\Delta\mathbf{B})x_t. \end{align}\]

If we now set \(\mathbf{A}^* = (I + \Delta\mathbf{A})\), \(\mathbf{B}^* = \Delta\mathbf{B}\) and \(\mathbf{D} = 0\), we get:

\[\begin{align} \mathbf{h}_{t+1} &= \mathbf{A}^*\mathbf{h}_t + \mathbf{B}^*x_t \\ y_t &= \mathbf{Ch}_t. \end{align}\]

Hopefully you recognize that as an RNN.

SSMs as RNNs.

Discrete SSMs are essentially linear RNNs, i.e. RNNs with the identity function \(I\) as activation function. From this perspective, we can see the discretization steps for computing \(\mathbf{A}^*\) and \(\mathbf{B}^*\) as the first step in the compute graph.
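To make the recurrent view concrete, here is a minimal NumPy sketch of a discretized SSM run as a linear RNN. The shapes and parameter names are illustrative assumptions, not the actual S4 setup:

```python
import numpy as np

def ssm_recurrent(A_star, B_star, C, x):
    """Run a discretized SSM as a linear RNN over a scalar input sequence x.

    h_t = A* h_{t-1} + B* x_t,  y_t = C h_t,  with h_{-1} = 0.
    Shapes: A_star (n, n), B_star (n,), C (n,), x (L,) of scalars.
    """
    h = np.zeros(A_star.shape[0])
    ys = []
    for x_t in x:
        h = A_star @ h + B_star * x_t  # purely linear update: no activation function
        ys.append(C @ h)
    return np.array(ys)
```

Note that the update contains no non-linearity, which is exactly what makes the convolutional view discussed below possible.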

Now, discrete SSMs don’t actually use the Euler method for discretization; different state space models use different discretization approaches. For example, S4 (more on this model soon) uses a bilinear method:

\[\begin{align} \mathbf{A}^* &= (I - \Delta/2 \mathbf{A})^{-1}(I + \Delta/2 \mathbf{A}) \\ \mathbf{B}^* &= (I - \Delta/2 \mathbf{A})^{-1}\Delta\mathbf{B}. \end{align}\]

Mamba uses a process called zero-order hold:

\[\begin{align} \mathbf{A}^* &= \text{exp}(\Delta\mathbf{A}) \\ \mathbf{B}^* &= (\Delta\mathbf{A})^{-1}(\text{exp}(\Delta\mathbf{A}) - I )\Delta\mathbf{B}. \end{align}\]
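Both discretization rules can be written down directly from the formulas above. A small NumPy sketch follows; the truncated-series matrix exponential is my own simplification for self-containedness (a real implementation would use a proper `expm`):

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear rule (used by S4): A* = (I - Δ/2 A)^{-1}(I + Δ/2 A), B* = (I - Δ/2 A)^{-1} ΔB."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (delta / 2) * A)
    return inv @ (I + (delta / 2) * A), inv @ (delta * B)

def _expm(M, terms=40):
    """Matrix exponential via truncated Taylor series (fine for small, well-scaled M)."""
    out, term = np.eye(M.shape[0]), np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ M / k
        out = out + term
    return out

def discretize_zoh(A, B, delta):
    """Zero-order hold (used by Mamba): A* = exp(ΔA), B* = (ΔA)^{-1}(exp(ΔA) - I) ΔB."""
    dA = delta * A
    A_star = _expm(dA)
    B_star = np.linalg.inv(dA) @ (A_star - np.eye(A.shape[0])) @ (delta * B)
    return A_star, B_star
```
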

In general, discrete SSMs are thus parameterized by \(\mathbf{\theta} = [\Delta, \mathbf{A}, \mathbf{B}, \mathbf{C}]\).

“Wait, so… we are back to RNNs? And linear ones at that?” Good question! At this point, yes. But let’s not forget that RNNs fell out of favor for two specific reasons:

  1. Difficulty modeling long range dependencies.
  2. Slow to train due to sequential computation of hidden states.

If SSMs were to ever be successful, they’d have to overcome these limitations, which is precisely what drove SSM research for years. Let’s have a high-level look at how SSMs overcame these two challenges.

First, for modeling long sequences, SSMs rely on the HiPPO framework, where HiPPO stands for high-order polynomial projection operators. This is a very technical mathematical framework (described as magic by some sources) that proposes a principled way to initialize matrix \(\mathbf{A}\) so state \(\mathbf{h}_t\) is able to memorize input sequence \(\mathbf{x}\). Specifically, they propose that \(\mathbf{A}\) be initialized as a complex matrix of the following upper-triangular form:

Initialization of matrix A according to the HiPPO framework (Gu et al., 2020).

More generally, they used principles from approximation theory and signal processing to model the notion of memory as online function approximation. As validation for their framework, they used it to derive GRUs and LMUs. This work from Gu et al. (2020) put SSMs on the map, as they reported massive improvements on sequential MNIST (the model must classify handwritten digits by “reading” them sequentially, one pixel at a time), bringing state-of-the-art results from 68% to 98%. In short, the HiPPO framework is a principled way to get SSMs to perform much better on long-range dependencies. But, as with RNNs, they were still slow to train.
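As one concrete, hedged illustration: a commonly quoted instance of the framework is the real-valued HiPPO-LegS matrix from Gu et al. (2020), sketched below (other variants, including the complex parameterizations used by later SSM work, also exist):

```python
import numpy as np

def hippo_legs(n):
    """HiPPO-LegS matrix from Gu et al. (2020); the SSM then uses h' = -A h + B x.

    A[i, j] = sqrt(2i+1) * sqrt(2j+1)  if i > j
            = i + 1                    if i == j
            = 0                        otherwise
    """
    A = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i > j:
                A[i, j] = np.sqrt(2 * i + 1) * np.sqrt(2 * j + 1)
            elif i == j:
                A[i, j] = i + 1
    return A
```
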

To see how SSMs overcame the issue of slow sequential training that plagued RNNs, we use the fact that the model is linear and thus, the computations of \(\mathbf{h}_{t+1}\) and \(y\) can be unrolled as follows:

\[\begin{align} \mathbf{h}_0 &= \mathbf{B}^*x_0, \quad \mathbf{h}_1 = \mathbf{A^*B^*}x_0 + \mathbf{B}^*x_1, \quad \mathbf{h}_2 = \mathbf{A^{*2}B^*}x_0 + \mathbf{A^*B^*}x_1 + \mathbf{B}^*x_2,\\ y_{0} &= \mathbf{CB}^*x_0, \quad y_1 = \mathbf{CA^*B^*}x_0 + \mathbf{CB}^*x_1, \quad y_2 = \mathbf{CA^{*2}B^*}x_0 + \mathbf{CA^*B^*}x_1 + \mathbf{CB}^*x_2, \end{align}\]

and so on until \(\mathbf{h}_{t+1}\) starting from an initial state \(\mathbf{h}_{-1} = \mathbf{0}\). Thus, given input sequence \(\mathbf{x}\), we can see the computation of the output signal as \(\mathbf{y}=\mathbf{K}\ast\mathbf{x}\), where \(\ast\) denotes a discrete convolution with kernel \(\mathbf{K}\) defined as:

\[\begin{align} \mathbf{K} &= (\mathbf{CB}^*, \mathbf{CA}^*\mathbf{B}^*, \mathbf{CA}^{*2}\mathbf{B}^*, \ldots, \mathbf{CA}^{*|x|-1}\mathbf{B}^*). \end{align}\]

In other words, we can see the entire forward pass of such a model as a single global convolution. This is convenient, because convolutions can be computed efficiently using Fast Fourier Transforms (FFTs). Still, the computation of this convolution can be quite expensive, particularly because computing \(\mathbf{K}\) relies on repeated matrix multiplications with \(\mathbf{A}^*\). This is where the S4 model came in: its major contribution was achieving this computation in linear time.
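The convolutional view can be sketched in NumPy as follows. The kernel is built naively here by repeated matrix-vector products, which is exactly the expensive part that S4 avoids; the FFT handles the convolution itself:

```python
import numpy as np

def ssm_kernel(A_star, B_star, C, L):
    """K = (CB*, CA*B*, CA*^2 B*, ..., CA*^{L-1} B*), built by repeated matvecs."""
    K, v = [], B_star
    for _ in range(L):
        K.append(C @ v)
        v = A_star @ v
    return np.array(K)

def ssm_conv(A_star, B_star, C, x):
    """Compute the whole output sequence as one causal convolution, done via FFT."""
    L = len(x)
    K = ssm_kernel(A_star, B_star, C, L)
    n = 2 * L  # zero-pad so the circular FFT convolution matches a linear one
    return np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(x, n), n)[:L]
```

Running this on the same parameters as the recurrent formulation yields the same outputs, which is exactly the train-time (convolutional) versus inference-time (recurrent) duality described above.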

S4 stands for structured state space sequence model and was proposed by Gu et al. in 2022. To compute \(\mathbf{K}\) in linear time, they enforced more structure on parameter matrix \(\mathbf{A}\), which was far from trivial, as \(\mathbf{A}\) already had a specific structure dictated by the HiPPO framework (shown above), which had to be kept for long-range performance. Further, the authors designed an S4 block in order to construct deep models. The proposed S4 block is made up of an SSM layer + dropout + a non-linearity + a linear projection. But note that, as seen so far, an SSM processes a single scalar signal. So the S4 block contains as many instances of an SSM as the desired hidden size, meaning it can process sequences of vectors as transformers do, but in linear time! During training, sequences are processed via a global convolution, and at inference time, autoregressively, as nicely shown in code here. S4 was the key step that allowed SSMs to train efficiently without losing their great performance on long-range dependency tasks.
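A heavily simplified sketch of that block structure, assuming H precomputed per-channel kernels (the names `kernels` and `W_out` and the ReLU are my own placeholders; the actual S4 block uses GELU, dropout, and learned SSM parameters per channel):

```python
import numpy as np

def s4_block_sketch(x, kernels, W_out):
    """x: (L, H) sequence of H-dim vectors; kernels: (H, L), one SSM kernel per channel;
    W_out: (H, H) linear projection mixing channels after the channelwise SSMs."""
    L, H = x.shape
    n = 2 * L  # zero-padding for causal FFT convolution
    y = np.empty_like(x)
    for h in range(H):  # each channel gets its own independent scalar SSM
        y[:, h] = np.fft.irfft(np.fft.rfft(kernels[h], n) * np.fft.rfft(x[:, h], n), n)[:L]
    y = np.maximum(y, 0.0)  # placeholder non-linearity; dropout omitted
    return y @ W_out
```

The key design point is that sequence mixing (the per-channel SSMs) and channel mixing (the projection) are separate steps, much like attention and the MLP in a transformer block.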

SSMs vs RNNs

Before discussing the Mamba model, which is an extension of the S4 model, it’s worth summarizing the relation between SSMs and RNNs.

  1. Both SSMs and linear RNNs can be efficiently parallelized via a global convolution.
  2. SSMs are linear RNNs, but with special requirements for how some parameters are computed (discretization).
  3. SSMs are complex-valued and initialized according to the HiPPO framework.
  4. Parameters \(\mathbf{A^*}\) and \(\mathbf{B^*}\) make SSMs look like linear RNNs, but they are themselves computed from shared underlying parameters, namely \(\Delta\) and \(\mathbf{A}\).

These differences between RNNs and SSMs account for the success of the latter. This was made evident by the really nice work of Orvieto et al., who, starting from a linear RNN, “ablated” their way to S4. Specifically, they introduced diagonalization, a special initialization and parameterization, as well as normalization into RNNs to get them to reach the performance of S4.

The Mamba Model

Coming soon!

References

Contact

Questions? Comments? Drop me a line at daniel@ruffinelli.io