Mamba: SSM, Theory, and Implementation in Keras and TensorFlow

Author: Murphy  |  Time: 2025-03-22 22:25:03
Source: AI-generated (SDXL)

Submitted to arXiv on 1 December 2023, the paper titled "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" proposed an interesting approach to sequence modeling. Its authors, Albert Gu and Tri Dao, introduced Mamba, which uses 'selective' state space models (SSMs) to achieve results that compete with the performance of the now-ubiquitous Transformer model.

What's so unique about Mamba?

Transformers have become hugely popular with the rise of Large Language Models (LLMs) such as Llama 2, GPT-4, Claude, and Gemini, but they suffer from a limited context window. The problem lies at their core: the multi-head attention mechanism.

The main issue with multi-head attention is that, for an input sequence of length n, its time complexity and (in a naive implementation) space complexity scale as O(n²), because every token attends to every other token. This limits the length of an LLM's context window: increasing it by 10x multiplies the size of the attention score matrix, and with it the hardware requirement (most notably GPU VRAM) of a naive implementation, by 100x.
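The quadratic-versus-linear difference is easy to see with some quick arithmetic. This is an illustrative sketch only: it counts the entries of a naively materialized attention score matrix against the per-token updates of a fixed-size recurrent state, and the function names and the state size of 16 are arbitrary assumptions, not measured properties of any model.

```python
def attention_score_count(seq_len: int, num_heads: int = 1) -> int:
    """Entries in naively materialized attention score matrices."""
    # every query token gets a score against every key token, per head
    return num_heads * seq_len * seq_len


def recurrent_update_count(seq_len: int, state_dim: int = 16) -> int:
    """State-element updates for a linear recurrence over the sequence."""
    # one fixed-size state update per token (state_dim=16 is illustrative)
    return seq_len * state_dim


for n in (1_000, 10_000):
    print(n, attention_score_count(n), recurrent_update_count(n))
# 1000 1000000 16000
# 10000 100000000 160000
```

Going from 1,000 to 10,000 tokens multiplies the attention count by 100 but the recurrent count only by 10.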

Mamba, on the other hand, scales linearly with sequence length: O(n)!

Plot from the Mamba paper comparing FlashAttention with the Mamba approach (indicated by 'scan(ours)' in the legend) [1]

This linear scaling is what has led researchers to speculate that Mamba might be the future of sequence modeling.

The backbone of Mamba: State Space Models

The core of the Mamba model comes from the concept of State Space Models. Like Transformers and RNNs, State Space Models process sequences of information, such as text, audio signals, video frames, and DNA sequences.

State Space Models come from the idea of describing a physical system by its inputs, outputs, and state, linked by four variables: A, B, C, and D. Given an input x(t), an SSM tracks an internal state vector h(t). The rate of change of h(t) is a weighted combination of h(t) and x(t), with weights A and B. The output y(t) is also a weighted combination of h(t) and x(t), with weights C and D. In its simplest form (continuous and time-invariant), the formulation looks like:
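The standard continuous, time-invariant state-space equations below use h for the hidden state, x for the input, and y for the output:

```latex
\begin{aligned}
h'(t) &= A\,h(t) + B\,x(t) \\
y(t)  &= C\,h(t) + D\,x(t)
\end{aligned}
```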

Source: Wikipedia [6]

h(t) is often called the 'hidden' or the 'latent' state. For clarity, I will call it the 'hidden' state. It is important to note that A, B, C, and D are learnt parameters in an SSM.

What are the variables?

The variables A, B, C, and D can be described as:

  • A: How much the previous hidden state (h) should be considered when calculating the new hidden state.
  • B: How much the input (x) should be considered when calculating the new hidden state.
  • C: How much the new hidden state should be considered when calculating the output (y).
  • D: How much the input (x) should be considered when calculating the output (y).

D comes at the end of the computation and does not affect how the hidden state is calculated. Hence, it is usually treated as separate from the SSM and can be thought of as a skip connection.
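To make these roles concrete, here is a minimal plain-Python sketch that evaluates the continuous equations h'(t) = A h(t) + B x(t) and y(t) = C h(t) + D x(t) at a single instant. It assumes a single-input, single-output SSM, so B and C are vectors and x, y, and D are scalars. Matrices are plain lists, and the helper names `matvec` and `continuous_ssm` are hypothetical. The D x term is added separately to show its skip-connection role.

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def continuous_ssm(A, B, C, D, h, x):
    """Evaluate h'(t) = A h + B x and y = C h + D x at one instant.

    A: N x N matrix, B: length-N vector, C: length-N vector,
    D: scalar, h: current hidden state (length N), x: scalar input.
    """
    dh_dt = [ah_i + b_i * x for ah_i, b_i in zip(matvec(A, h), B)]  # A h + B x
    y_ssm = sum(c_i * h_i for c_i, h_i in zip(C, h))                # C h
    y = y_ssm + D * x                                               # + D x (skip)
    return dh_dt, y


dh_dt, y = continuous_ssm(A=[[-1.0, 0.0], [0.0, -2.0]], B=[1.0, 1.0],
                          C=[1.0, 1.0], D=0.5, h=[1.0, 1.0], x=2.0)
print(dh_dt, y)  # [1.0, 0.0] 3.0
```

Only A and B shape how the hidden state evolves. D touches nothing but the output, which is why it can live outside the SSM.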

Going from continuous spaces to discrete spaces

The above formulation applies to systems whose inputs and outputs are continuous signals. In cases like language modeling, however, the input and output are discrete sequences (tokens from a vocabulary, one per step). Finding h(t) analytically is also challenging. Both problems can be addressed by discretizing the SSM with a zero-order hold.

In a zero-order hold, every time an input is received, the model holds its value until the next input arrives. This turns the discrete input sequence into a continuous (piecewise-constant) signal.

How zero-order hold works

The length of this 'hold' is determined by a new, learnable parameter called the step size ∆. It can be thought of as the resolution of the input: the smaller ∆ is, the more closely the discrete model follows the continuous one.

Mathematically, Zero-order hold can be described as:
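In the notation of the Mamba paper [1], the zero-order hold maps the continuous parameters (∆, A, B) to discrete ones (Ā, B̄):

```latex
\begin{aligned}
\bar{A} &= \exp(\Delta A) \\
\bar{B} &= (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \cdot \Delta B
\end{aligned}
```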

Finally, we obtain a discrete SSM:
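The discrete SSM is written as a recurrence over time steps k. Ā and B̄ are the zero-order-hold versions of A and B, and C and D are unchanged:

```latex
\begin{aligned}
h_k &= \bar{A}\,h_{k-1} + \bar{B}\,x_k \\
y_k &= C\,h_k + D\,x_k
\end{aligned}
```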

Since D is used in a skip connection outside the SSM, the output can be reduced to:
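Dropping the D x_k term, the discrete SSM becomes:

```latex
\begin{aligned}
h_k &= \bar{A}\,h_{k-1} + \bar{B}\,x_k \\
y_k &= C\,h_k
\end{aligned}
```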

The D x(t) term is treated as a skip connection, so it is applied outside the SSM
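The discretization step can be sketched in plain Python. This minimal sketch makes three assumptions: a diagonal A, given as a list of its nonzero (typically negative) diagonal entries a_i, which is the structure Mamba itself uses; a single input channel; and a scalar step size. The function name `discretize_zoh_diagonal` is hypothetical. With a diagonal A, the matrix exponential and inverse act element-wise. The ZOH formulas Ā = exp(∆A) and B̄ = (∆A)^-1 (exp(∆A) - I) ∆B therefore reduce to scalar operations per state.

```python
import math


def discretize_zoh_diagonal(a, B, delta):
    """Zero-order hold for a diagonal A (illustrative sketch).

    a     : diagonal entries of A (nonzero, typically negative)
    B     : entries of B, one per state (single input channel)
    delta : step size
    Returns (A_bar, B_bar) as lists of the same length as a.
    """
    # A_bar = exp(delta * A) acts element-wise on a diagonal matrix
    A_bar = [math.exp(delta * a_i) for a_i in a]
    # (delta*A)^-1 (exp(delta*A) - I) delta*B simplifies, per state, to
    # (exp(delta*a_i) - 1) / a_i * B_i
    B_bar = [(ab_i - 1.0) / a_i * b_i for ab_i, a_i, b_i in zip(A_bar, a, B)]
    return A_bar, B_bar


A_bar, B_bar = discretize_zoh_diagonal(a=[-1.0, -2.0], B=[1.0, 1.0], delta=0.1)
print(A_bar, B_bar)
```

For a small step size, Ā ≈ 1 + ∆a_i and B̄ ≈ ∆B_i, so each step changes the state only slightly. A larger ∆ drives Ā toward 0 for negative a_i, letting the current input overwrite more of the state.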

SSM and recurrence

In an SSM, the hidden state is carried over from one input to the next and updated each time a new input is received. This is similar to how Recurrent Neural Networks (RNNs) work.

Comparison of RNN and SSM

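Like an RNN, this recurrent form of an SSM can be unrolled over time. Unlike an RNN, however, there is no nonlinear recurrence that must be computed one step at a time, which is what makes RNNs slow to train. A linear time-invariant SSM can instead process the whole input sequence in parallel (much like a Transformer), which makes training faster.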

Unrolled form of SSM

Note that D is used in a skip connection, which sits outside the SSM.
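The recurrent view can be sketched in a few lines of plain Python. This minimal sketch assumes already-discretized parameters, a diagonal Ā stored as a list of its diagonal entries, a scalar input per step, and h0 = 0. The function name `ssm_recurrent` is hypothetical. Each output depends on the state left over from the previous step, so this loop runs sequentially, just like an RNN.

```python
def ssm_recurrent(xs, A_bar, B_bar, C):
    """Run h_k = A_bar h_{k-1} + B_bar x_k, y_k = C h_k step by step.

    Assumes a diagonal A_bar (a list of its diagonal entries),
    a scalar input per step, and h_0 = 0.
    """
    h = [0.0] * len(A_bar)
    ys = []
    for x in xs:
        # element-wise update, valid because A_bar is diagonal
        h = [a_i * h_i + b_i * x for a_i, h_i, b_i in zip(A_bar, h, B_bar)]
        ys.append(sum(c_i * h_i for c_i, h_i in zip(C, h)))
    return ys


# A single impulse input: the response decays by A_bar each step
print(ssm_recurrent([1.0, 0.0, 0.0], A_bar=[0.5], B_bar=[1.0], C=[2.0]))
# [2.0, 1.0, 0.5]
```

Feeding in a single impulse shows how past inputs fade by a factor of Ā per step. That decaying response (C B̄, C Ā B̄, C Ā² B̄, ...) is exactly the convolution kernel K discussed below.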

The key to how SSMs make training fast is to fold the (discretized) variables A, B, and C into a pre-computed convolutional kernel. Maarten Grootendorst wrote a really good explanation of how this 'convolutional' kernel is constructed, but here's a simple mathematical explanation.

Consider the output y for an input sequence x1, x2, ..., xk. Assuming h0 = 0, the first outputs can be represented as:
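The derivation starts from h_0 = 0 and applies the discrete recurrence h_k = Ā h_{k-1} + B̄ x_k, y_k = C h_k:

```latex
\begin{aligned}
h_1 &= \bar{A}\,h_0 + \bar{B}\,x_1 = \bar{B}\,x_1,
  & y_1 &= C\bar{B}\,x_1 \\
h_2 &= \bar{A}\,h_1 + \bar{B}\,x_2 = \bar{A}\bar{B}\,x_1 + \bar{B}\,x_2,
  & y_2 &= C\bar{A}\bar{B}\,x_1 + C\bar{B}\,x_2
\end{aligned}
```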

Similarly, y3 can be represented as:
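Applying the recurrence once more (h_3 = Ā h_2 + B̄ x_3, y_3 = C h_3) gives:

```latex
y_3 = C\bar{A}^{2}\bar{B}\,x_1 + C\bar{A}\bar{B}\,x_2 + C\bar{B}\,x_3
```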

Extrapolating the pattern, yk can be represented as:
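In general, for h_0 = 0:

```latex
y_k = C\bar{A}^{k-1}\bar{B}\,x_1 + C\bar{A}^{k-2}\bar{B}\,x_2 + \cdots
      + C\bar{A}\bar{B}\,x_{k-1} + C\bar{B}\,x_k
    = \sum_{j=1}^{k} C\bar{A}^{\,k-j}\bar{B}\,x_j
```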

This formulation can be further reduced to:
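Collecting the coefficients C Ā^j B̄ into a kernel K, indexed from j = 0, turns the sum into a convolution:

```latex
\begin{aligned}
K &= \left(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^{2}\bar{B},\; \dots,\; C\bar{A}^{k-1}\bar{B}\right) \\
y &= x \ast K, \qquad \text{i.e.} \quad y_k = \sum_{j=0}^{k-1} K_j\, x_{k-j}
\end{aligned}
```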

The funny-looking multiplication symbol denotes a convolution, with K as the convolution kernel. Notice that K does not depend on x, so it can be pre-computed once and then applied to the whole sequence in parallel, which makes training faster.
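The convolutional view can be checked in plain Python. This minimal sketch makes three assumptions: a diagonal Ā given by its diagonal entries, a scalar input per step, and h0 = 0. The function names `ssm_kernel` and `ssm_convolutional` are hypothetical. Python lists are 0-based, so `xs[0]` plays the role of x1. The loops are naive (quadratic in sequence length) for readability. What matters is that K is built once without looking at the input values, and every output is then an independent weighted sum that could be computed in parallel.

```python
def ssm_kernel(A_bar, B_bar, C, length):
    """K_j = C A_bar^j B_bar for j = 0..length-1 (diagonal A_bar)."""
    return [sum(c_i * (a_i ** j) * b_i for a_i, b_i, c_i in zip(A_bar, B_bar, C))
            for j in range(length)]


def ssm_convolutional(xs, A_bar, B_bar, C):
    """Compute y_k = sum_j K_j x_{k-j} with a kernel built once up front."""
    K = ssm_kernel(A_bar, B_bar, C, len(xs))  # depends only on the parameters
    # each output is an independent weighted sum over past inputs
    return [sum(K[j] * xs[k - j] for j in range(k + 1))
            for k in range(len(xs))]


print(ssm_convolutional([1.0, 2.0, 3.0], A_bar=[0.5], B_bar=[1.0], C=[2.0]))
# [2.0, 5.0, 8.5]
```

With Ā = 0.5, B̄ = 1 and C = 2, stepping the recurrence h_k = Ā h_{k-1} + B̄ x_k, y_k = C h_k by hand over [1, 2, 3] also gives 2, 5, 8.5, so both views agree.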

Mamba and 'Selective' SSM

As good as the computational efficiency of SSMs sounds, they turn out to be pretty meh compared to Transformers on quality metrics such as accuracy.

The core issue lies with the variables ∆, A, B, and C. Because the same values are applied to every input, regardless of its content, an SSM cannot selectively focus on or ignore parts of the sequence based on context.

SSMs are inflexible in the way they process data[4]

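So what's so special about Mamba? Mamba uses a 'selective' SSM, in which the variables ∆, B, and C are computed from the input. A remains a learned, input-independent parameter, although its discretized form Ā still varies with the input-dependent ∆.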
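The selection idea can be sketched for a single channel in plain Python. In this minimal sketch, the projection weights (`w_delta`, `b_delta`, `w_B`, `w_C`) are hypothetical scalars standing in for the learned linear layers of a real model; a real model projects from the whole token embedding, not a single scalar. ∆ is kept positive with a softplus, and A stays input-independent. Because Ā, B̄, and C now change at every step, the kernel K can no longer be pre-computed, so the sketch falls back to the sequential recurrence. The Mamba paper addresses this with a hardware-aware parallel scan [1].

```python
import math


def softplus(z: float) -> float:
    """log(1 + e^z), computed stably; used to keep the step size positive."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_ssm(xs, a, w_delta, b_delta, w_B, w_C):
    """Single-channel selective SSM scan (illustrative sketch).

    xs               : scalar inputs x_1..x_L
    a                : diagonal of A (nonzero, typically negative); input-independent
    w_delta, b_delta : hypothetical scalar projection producing Delta_t
    w_B, w_C         : hypothetical per-state weights producing B_t and C_t
    """
    h = [0.0] * len(a)  # h_0 = 0
    ys = []
    for x in xs:
        # Selection: Delta, B and C are recomputed from the current input
        delta = softplus(w_delta * x + b_delta)
        B = [wb * x for wb in w_B]
        C = [wc * x for wc in w_C]
        for i, a_i in enumerate(a):
            # Zero-order hold with this step's Delta (diagonal A)
            A_bar = math.exp(delta * a_i)
            B_bar = (A_bar - 1.0) / a_i * B[i]
            h[i] = A_bar * h[i] + B_bar * x
        ys.append(sum(c_i * h_i for c_i, h_i in zip(C, h)))
    return ys


print(selective_ssm([1.0, -0.5, 2.0], a=[-1.0, -2.0],
                    w_delta=0.5, b_delta=0.0, w_B=[1.0, 0.5], w_C=[1.0, 1.0]))
```

Compared with the time-invariant SSM, the only change is that ∆, B, and C are recomputed inside the loop. That change is what lets the model decide, token by token, how much of its state to keep.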

