Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
The selective SSM, rewritten as a matrix you can multiply on a GPU.
Mamba walked the sequence one token at a time and kept a small state. Useful, but the work was scalar, and modern GPUs are built for big matrix multiplications. Mamba-2 restricts one piece of the math, and the recurrence collapses into a structured matrix you can chunk and multiply, two to eight times faster.
A 2.7B model that, trained on the same data as Pythia-2.8B, beats Pythia-6.9B on common downstream tasks. The architecture is a selective state space model, the same family as Mamba, but the inner layer is computed by a chunked matrix multiplication. Same big-O cost as Mamba's scan, much higher arithmetic intensity.
Mamba-2, from Tri Dao and Albert Gu, is in one sense a follow-up to Mamba: same selective state space layer at the core, same fixed-size running state, same linear-time scan over the sequence. In another sense it is a much larger claim. Mamba-1 made the argument that a carefully gated recurrence could compete with attention on language. Mamba-2 makes the argument that the recurrence and attention are the same object, viewed through different decompositions of one structured matrix. That matrix is what the paper calls structured state space duality (SSD), and the algorithm that comes out of it looks like attention from the outside (big matrix multiplications, tensor cores busy) and like a recurrence on the inside (linear cost in sequence length, constant state).
The mechanism is one restriction and one decomposition. The restriction tightens the per-step dynamics of the SSM from a diagonal matrix down to a scalar times the identity, so that across state slots there is one shared decay per token rather than independent ones. That restriction is what factors the sequence-transformation matrix $M$ into a 1-semiseparable mask (one cumulative-product scalar per pair of positions) multiplied entrywise by an attention-shaped Gram matrix: every coordinate of $M$ is now (decay) $\times$ (Gram entry). The decomposition then chunks $M$ into a grid of small blocks: diagonal blocks are computed by a small attention-shaped matmul, off-diagonal blocks are absorbed by carrying a state across chunk boundaries. The work runs in linear time in the sequence length, with the inner operation being the dense matrix multiplication GPUs are built for.
A few pieces carry the argument: an SSM written as a triangular matrix, the semiseparable structure of that matrix, the duality between this matrix and a kind of masked attention, and the chunked algorithm that exploits both views. Each piece is short on its own, and they only do their job once you see them chain together.
Why Mamba was slow on a GPU
Start with the problem Mamba-2 is solving. Mamba introduced a selective state space model: a small running state $h_t$ updated at every token by
$$h_t = A_t h_{t-1} + B_t x_t, \qquad y_t = C_t^\top h_t \tag{1}$$
with $A_t$ a diagonal matrix per token and $B_t$, $C_t$ linear projections of the input. The state has size $N$ (typically 16); the head dimension is $P = 1$ (one channel), so the whole layer runs one independent SSM per channel, all in parallel.
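As a concrete reference point, here is a minimal NumPy sketch of that recurrence for a single channel. The function name `selective_scan` and the random inputs are illustrative assumptions, not Mamba's fused kernel; the diagonal of each per-token step matrix is stored as a row of `a`.

```python
import numpy as np

def selective_scan(a, B, C, x):
    """Run h_t = A_t h_{t-1} + B_t x_t, y_t = C_t^T h_t for one channel.

    a: (T, N) diagonal of the step matrix at each token
    B, C: (T, N) input and output projections; x: (T,) scalar inputs
    """
    T, N = B.shape
    h = np.zeros(N)                 # running state, zero before the first token
    y = np.empty(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]  # element-wise multiply-add on an N-vector
        y[t] = C[t] @ h             # reduction over the N state slots
    return y

rng = np.random.default_rng(0)
T, N = 6, 4                         # hypothetical toy sizes
a = rng.uniform(0.5, 1.0, size=(T, N))
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
x = rng.normal(size=T)
y = selective_scan(a, B, C, x)
print(y.shape)  # (6,)
```

Every step of the loop touches only length-`N` vectors, which is the shape of work the next paragraph is about.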
The cost is not the FLOP count. A scan over $T$ tokens does $O(TN)$ work per channel, the same asymptotic as everyone else. The cost is what those FLOPs look like to a GPU: each step is a small element-wise multiply-add followed by a reduction, nothing wider than a vector. Modern accelerators do dense matrix multiplications on tensor cores at hundreds of teraFLOPs per second, while scalar code on the same chip runs at a small fraction of that. Mamba's selective scan needs a careful CUDA kernel just to keep the GPU vaguely busy, and even with that kernel it leaves headroom on the table.
FlashAttention had already shown that for attention, the work-per-second on a GPU matters more than the FLOP count, and that phrasing the work as matrix multiplications and loading data into SRAM in tiles is what unlocks tensor cores. Mamba-2 asks the same question of selective SSMs: can the recurrence be rewritten so the inner work is a matrix multiplication, without giving up linear-in-$T$ cost? It can, and seeing why takes one move: write the SSM as a matrix.
A recurrence is a triangular matrix
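Unroll equation (1) from a zero initial state, $h_0 = 0$. The state at step $t$ is a sum of every past input pushed through the chain of step matrices since:
$$h_t = \sum_{s \le t} A_t A_{t-1} \cdots A_{s+1} B_s x_s$$
Multiplying by $C_t^\top$ collapses the output to a single sum:
$$y_t = \sum_{s \le t} C_t^\top A_t \cdots A_{s+1} B_s x_s \tag{2}$$
That is a matrix multiplication $y = Mx$, where $M$ is lower triangular (the SSM is causal: $y_t$ only sees $x_s$ for $s \le t$) and each entry $M_{ts} = C_t^\top A_t \cdots A_{s+1} B_s$ is the product of a contraction ($C_t^\top$), a chain of step matrices ($A_t \cdots A_{s+1}$), and an expansion ($B_s$). Nothing has been added or lost. The recurrence and the matrix multiplication compute the same numbers; they are two algorithms for the same map.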
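The equivalence can be checked numerically. The sketch below is an illustration under assumed toy sizes, not code from the paper: it builds the lower-triangular matrix entry by entry from contraction, chain, and expansion, then compares a plain matrix-vector product against the step-by-step recurrence. The helper name `ssm_matrix` is hypothetical.

```python
import numpy as np

def ssm_matrix(a, B, C):
    """M[t, s] = C_t^T A_t ... A_{s+1} B_s for s <= t, zero above the diagonal.

    a: (T, N) diagonals of the step matrices; B, C: (T, N).
    """
    T, N = B.shape
    M = np.zeros((T, T))
    for t in range(T):
        chain = np.ones(N)               # empty product (identity) when s == t
        for s in range(t, -1, -1):
            M[t, s] = C[t] @ (chain * B[s])
            chain = chain * a[s]         # now A_t ... A_s, ready for column s - 1
    return M

rng = np.random.default_rng(1)
T, N = 7, 3                              # hypothetical toy sizes
a = rng.uniform(0.5, 1.0, size=(T, N))
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
x = rng.normal(size=T)

# Reference: the recurrence, one token at a time.
h = np.zeros(N)
y_scan = np.empty(T)
for t in range(T):
    h = a[t] * h + B[t] * x[t]
    y_scan[t] = C[t] @ h

M = ssm_matrix(a, B, C)
y_matrix = M @ x
print(np.allclose(y_scan, y_matrix))
```

Materializing the full matrix costs quadratic memory in the sequence length, so this is only a correctness check, not a way to run the layer.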
The matrix $M$ is not a generic array. It has very specific structure: the entry at $(t, s)$ with $s \le t$ is a product whose "middle" depends on the steps between $s$ and $t$ and whose "endpoints" depend on $t$ and $s$. That structure has a name.
Off-diagonal blocks are low rank
Pick any block of $M$ that lies strictly below the diagonal. Pick rows $t_1 \le t \le t_2$ and columns $s_1 \le s \le s_2$, with $s_2 < t_1$, so the block sits entirely below the diagonal (no diagonal entries inside). Every entry of that block is
$$M_{ts} = \big(C_t^\top A_t \cdots A_{t_1+1}\big)\,\big(A_{t_1} \cdots A_{s_2+1}\big)\,\big(A_{s_2} \cdots A_{s+1} B_s\big)$$
for the corresponding indices. The center factor $A_{t_1} \cdots A_{s_2+1}$ is a single $N \times N$ matrix that does not depend on which row or column you are in: it is fixed by the block. The left factor varies with the row, the right factor varies with the column, and every entry of the block is therefore (left vector) $\times$ (fixed middle) $\times$ (right vector). That is a rank-$N$ factorization: the block factors through an $N$-dimensional middle.
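A quick numerical check of the rank bound, as a sketch under assumed toy sizes (state size 2, sequence length 8). The helper `ssm_matrix` is a hypothetical name that materializes the SSM matrix from diagonal step matrices; the block chosen below sits strictly under the diagonal.

```python
import numpy as np

def ssm_matrix(a, B, C):
    """M[t, s] = C_t^T A_t ... A_{s+1} B_s for s <= t (a holds the diagonals)."""
    T, N = B.shape
    M = np.zeros((T, T))
    for t in range(T):
        chain = np.ones(N)
        for s in range(t, -1, -1):
            M[t, s] = C[t] @ (chain * B[s])
            chain = chain * a[s]
    return M

rng = np.random.default_rng(2)
T, N = 8, 2                               # hypothetical toy sizes
a = rng.uniform(0.5, 1.0, size=(T, N))
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
M = ssm_matrix(a, B, C)

block = M[5:8, 0:4]                       # rows 5..7, columns 0..3: strictly below the diagonal
block_rank = np.linalg.matrix_rank(block)
full_rank = np.linalg.matrix_rank(M)
print(block_rank, "<=", N)                # the off-diagonal block factors through N dimensions
```

The full matrix is generically far from low rank; only its off-diagonal blocks are, which is exactly the property the next paragraph names.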
A matrix whose every off-diagonal block has rank at most $N$ is called $N$-semiseparable. The structured representation that writes each entry as $M_{ts} = C_t^\top A_t \cdots A_{s+1} B_s$ is called the sequentially semiseparable (SSS) form. Two non-trivial facts about this class. First, every $N$-semiseparable matrix can be written in SSS form, so the SSM family and the semiseparable family are the same matrix family. Second, an $N$-semiseparable matrix of size $T \times T$ can be stored in $O(NT)$ parameters and multiplied by a vector in $O(NT)$ work, even when the underlying blocks are dense (Pernet, Signargout & Villard 2023, restated as Proposition 3.6 in the paper).
The first fact, when you take it seriously, is the centre of the paper: any algorithm for multiplying a semiseparable matrix is an algorithm for evaluating an SSM, and conversely. Years of work on rank-structured matrix algebra apply directly to SSMs, and the SSM literature can borrow whichever decomposition is most convenient.
Mamba-2: one scalar per step
Mamba uses a diagonal $A_t$: each of the $N$ state slots has its own decay rate per token. That is flexible (each slot can remember on its own time scale), but with $N$ independent decays per token there is no single scalar to pull out as a 1-semiseparable mask. To get one, we want one decay per token, shared across the slots. So Mamba-2 makes a restriction: $A_t = a_t I$, a scalar times the identity. Per-step dynamics now boil down to a single number $a_t$.
What does the restriction cost? In principle, model capacity per step. In practice, on the kinds of sequence the paper measures, the cost vanishes: language modeling perplexity matches Mamba and Transformer++ at every size from 125M to 1.3B (Figure 9). On the MQAR synthetic the picture is even stranger: Mamba-2 with $N = 16$ is noticeably better than Mamba-1 with the same $N$, despite being strictly less expressive per step. The paper does not pin down the cause and says so.
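What does the restriction buy? With $A_t = a_t I$ a scalar, the cumulative product $a_t a_{t-1} \cdots a_{s+1}$ is also a scalar, call it $a_{t:s}$. The matrix
$$L_{ts} = \begin{cases} a_{t:s} = a_t a_{t-1} \cdots a_{s+1} & t \ge s \\ 0 & t < s \end{cases} \tag{3}$$
is itself a structured matrix: lower-triangular, with each entry a chain product (and ones on the diagonal, where the chain is empty). This is a 1-semiseparable matrix, the simplest member of the family we just defined. And the SSM matrix factors through it. Writing $A_t \cdots A_{s+1} = a_{t:s} I$,
$$M_{ts} = C_t^\top (a_{t:s} I) B_s = L_{ts}\,(C_t^\top B_s) \tag{4}$$
and therefore $M = L \circ (C B^\top)$, where $\circ$ is the elementwise (Hadamard) product. We have walked the SSM into the form of masked attention without softmax, with the causal mask replaced by a data-dependent decay mask.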
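The factorization is easy to verify on a toy example. The sketch below assumes one scalar decay per token stored in `a`, builds the decay mask from cumulative sums of log-decays (the same idea as the `segsum` helper in the paper's listing), and compares the masked Gram matrix against entries computed directly from the chain of scalar-times-identity step matrices. The name `one_ss_mask` and the sizes are illustrative.

```python
import numpy as np

def one_ss_mask(a):
    """L[t, s] = a_t * ... * a_{s+1} for s <= t (1 on the diagonal), 0 above it."""
    cs = np.cumsum(np.log(a))             # cs[t] = log a_0 + ... + log a_t
    L = np.exp(cs[:, None] - cs[None, :]) # exp of the log-decay summed over (s, t]
    return np.tril(L)                     # zero out the non-causal upper triangle

rng = np.random.default_rng(3)
T, N = 6, 3                               # hypothetical toy sizes
a = rng.uniform(0.5, 1.0, size=T)         # one scalar decay per token
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))

# Direct definition: M[t, s] = C_t^T (A_t ... A_{s+1}) B_s with A_k = a_k * I.
M = np.zeros((T, T))
for t in range(T):
    for s in range(t + 1):
        A_chain = np.prod(a[s + 1:t + 1]) * np.eye(N)   # empty product is 1 when s == t
        M[t, s] = C[t] @ A_chain @ B[s]

L = one_ss_mask(a)
M_factored = L * (C @ B.T)                # decay mask times Gram matrix, entrywise
print(np.allclose(M, M_factored))
```

Working in log space and exponentiating differences avoids forming long products of small numbers one entry at a time.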
The same matrix as masked attention
Equation (4) is the duality. The left side is the SSM, written as a linear recurrence on a state of size $N$; you compute it by walking the sequence and updating $h_t$ and $y_t$, work linear in $T$. The right side is masked attention with the softmax removed and the causal mask replaced by $L$; you compute it by forming the Gram matrix $C B^\top$, multiplying by $L$ elementwise, and contracting with $x$, work quadratic in $T$.
Same $M$. Same numbers out. Two algorithms.
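Side by side, the two algorithms look like this. It is a minimal single-head sketch with assumed shapes (scalar decay per token, inputs with `P` channels per head); the function names are illustrative, not from the paper's code.

```python
import numpy as np

def ssd_recurrent(a, B, C, X):
    """Linear form: walk the sequence, carrying an (N, P) state."""
    T, P = X.shape
    N = B.shape[1]
    h = np.zeros((N, P))
    Y = np.empty((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], X[t])   # scalar decay, rank-1 update
        Y[t] = C[t] @ h                       # read out through C_t
    return Y

def ssd_quadratic(a, B, C, X):
    """Quadratic form: masked attention without softmax, Y = (L * (C B^T)) X."""
    cs = np.cumsum(np.log(a))
    L = np.tril(np.exp(cs[:, None] - cs[None, :]))   # T x T decay mask
    return (L * (C @ B.T)) @ X

rng = np.random.default_rng(4)
T, N, P = 10, 4, 5                            # hypothetical toy sizes
a = rng.uniform(0.5, 1.0, size=T)
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
X = rng.normal(size=(T, P))

Y_lin = ssd_recurrent(a, B, C, X)
Y_quad = ssd_quadratic(a, B, C, X)
print(np.allclose(Y_lin, Y_quad))
```

The first function never builds anything larger than the state; the second never loops over time but forms a `T` by `T` matrix.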
The duality is more than an analogy. The paper proves, in the other direction, that any masked-attention variant with an efficient autoregressive recurrence must use a semiseparable mask (Theorem 5.2). So the matrix class is the right place to look for fast sequence layers: structured masked attention with semiseparable $L$ is exactly the family that admits a linear-time recurrence, and that family includes both selective SSMs and linear attention.
What we have so far is two views and a cost: linear-in-$T$ but scalar, or quadratic-in-$T$ but matmul. To win we need an algorithm that is both. That is what the SSD chunked decomposition is for.
The SSD algorithm: chunk the matrix
Partition the $T \times T$ matrix $M$ into a $T/Q \times T/Q$ grid of chunks of size $Q \times Q$. There are now $T/Q$ diagonal chunks and a triangular array of off-diagonal chunks below the main diagonal. The diagonal chunks are full-rank, but they are small (only $Q \times Q$). The off-diagonal chunks inherit the structure of $M$: they are off-diagonal submatrices of an $N$-semiseparable matrix, so each one has rank at most $N$.
Now compute by parts. The diagonal chunks are blocks where you actually do want the quadratic form, because $Q$ is small and the form is a dense matmul. The off-diagonal chunks you do not store at all: you just remember the chunk-boundary states and let a much shorter recurrence carry them across the diagonal. Concretely:
- Diagonal blocks. Inside each chunk, compute the small masked attention $(L \circ C B^\top) X$ on that chunk's own slice of $B$, $C$, and $X$. All chunks are independent and run as a single batched matmul.
- Right factors (input → state). For each chunk, compute the final state at the chunk boundary, assuming the initial state was zero. One matmul per chunk, giving a state of shape $P \times N$ per head.
- Center factors (state → state). The chunk-boundary states form a short sequence of length $T/Q$. Run a 1-semiseparable scan on that sequence to turn each chunk-state into its true initial state. This is a scan, but on a sequence $Q$ times shorter than the original.
- Left factors (state → output). Multiply each chunk's true initial state by its $C$ values (and the in-chunk decay) to get the cross-chunk contribution to outputs. Another batched matmul per chunk.
- Add. The output of a chunk is the intra-chunk piece plus the cross-chunk piece.
Count the cost on the structure above. With $N = P = Q$ all of the matmul terms become $O(TN^2)$, so the total FLOP count is $O(TN^2)$, the same as the linear recurrent form. The innermost operation is now a dense matrix multiplication, which is what GPU tensor cores are designed for; the chunk-state scan is a length-$T/Q$ reduction with negligible cost. There is no $O(T^2)$ blow-up because no block of $M$ whose row and column fall in different chunks is ever materialized.
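To see the scaling rather than the constants, here is a rough multiply-add tally for one head. The per-step counts are a back-of-the-envelope assumption (they ignore the mask multiply, exponentials, and memory traffic, and they model the chunk-state pass as a sequential scan), so treat the numbers as illustrative only; the function names are hypothetical.

```python
def ssd_chunked_madds(T, N, P, Q):
    """Approximate multiply-adds for the chunked algorithm, one head."""
    chunks = T // Q
    diag = chunks * (Q * Q * N + Q * Q * P)  # Gram matrix per chunk, then apply it to X
    to_state = chunks * Q * N * P            # input -> chunk-boundary state
    scan = chunks * N * P                    # one state update per chunk boundary
    to_out = chunks * Q * N * P              # state -> outputs inside each chunk
    return diag + to_state + scan + to_out

def quadratic_madds(T, N, P):
    """Approximate multiply-adds for the full T x T masked-attention form."""
    return T * T * N + T * T * P

N = P = Q = 64                               # the N = P = Q setting used in the text
for T in (4096, 8192, 16384):
    chunked = ssd_chunked_madds(T, N, P, Q)
    quad = quadratic_madds(T, N, P)
    print(T, chunked, quad, round(quad / chunked, 1))
```

Doubling the sequence length doubles the chunked count and quadruples the quadratic one, which is the whole point of the decomposition.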
The full algorithm fits in a page of PyTorch. The paper's Listing 1, slightly annotated:
    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def segsum(x):
        # out[..., i, j] = x[j+1] + ... + x[i] for j <= i, and -inf above the diagonal
        T = x.size(-1)
        x_cumsum = torch.cumsum(x, dim=-1)
        x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
        mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
        return x_segsum.masked_fill(~mask, -torch.inf)

    def ssd(X, A, B, C, block_len=64, initial_states=None):
        # the full SSD layer, PyTorch. block_len is the chunk length Q (must divide T).
        # X: (batch, T, n_heads, d_head)   the inputs (analog of V)
        # A: (batch, T, n_heads)           the SCALAR per-step decay log a_t
        # B: (batch, T, n_heads, d_state)  the input projection (analog of K)
        # C: (batch, T, n_heads, d_state)  the output projection (analog of Q)
        # 1. cut everything into chunks of length Q along the time axis.
        X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
        A = rearrange(A, "b c l h -> b h c l")
        A_cumsum = torch.cumsum(A, dim=-1)  # log a_{t:t0}
        # 2. INTRA-CHUNK (the diagonal blocks): small attention, in parallel.
        L = torch.exp(segsum(A))  # the 1-SS mask, exp of segment sum
        Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
        # 3. each chunk's exit state, supposing its initial state was 0:
        decay_states = torch.exp(A_cumsum[..., -1:] - A_cumsum)
        states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
        # 4. INTER-CHUNK recurrence on the chunk-states (1-SS scan, length T/Q):
        if initial_states is None:
            initial_states = torch.zeros_like(states[:, :1])
        states = torch.cat([initial_states, states], dim=1)  # prepend the state before chunk 0
        decay_chunk = torch.exp(segsum(F.pad(A_cumsum[..., -1], (1, 0))))
        new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
        states, final_state = new_states[:, :-1], new_states[:, -1]
        # 5. CARRY each chunk's true initial state to its outputs:
        state_decay_out = torch.exp(A_cumsum)
        Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
        # 6. fuse: intra-chunk + cross-chunk.
        Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
        return Y, final_state

segsum is a segment-cumulative-sum that turns the log-decay sequence into the 1-SS mask: L = exp(segsum(A)) assembles the lower-triangular matrix of cumulative products $a_{t:s}$. Everything else is one of two operations: a chunk-shaped Einstein summation (the matmul lines) or a length-$T/Q$ recurrence on chunk-states (the decay_chunk line). Even in vanilla PyTorch without custom CUDA, this version is competitive with Mamba's heavily-tuned selective scan, because the work the GPU is being asked to do is precisely the kind it does best.
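For readers who want to check the chunking logic without PyTorch or einops, the sketch below re-derives the same steps for a single head in NumPy and compares them against the plain recurrence. It is an illustration, not the paper's code: the chunk-state pass is written as an explicit Python loop rather than a batched einsum, decays are passed as values in (0, 1] rather than logs, and the names `ssd_chunked` and `ssd_recurrent` are hypothetical.

```python
import numpy as np

def ssd_recurrent(a, B, C, X):
    """Reference: h_t = a_t h_{t-1} + B_t x_t^T, y_t = C_t^T h_t."""
    T, P = X.shape
    h = np.zeros((B.shape[1], P))
    Y = np.empty((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], X[t])
        Y[t] = C[t] @ h
    return Y

def ssd_chunked(a, B, C, X, Q):
    """Chunked form for one head. a: (T,), B, C: (T, N), X: (T, P); Q must divide T."""
    T, P = X.shape
    N = B.shape[1]
    Y = np.zeros((T, P))
    state = np.zeros((N, P))                  # true state entering the current chunk
    for start in range(0, T, Q):
        sl = slice(start, start + Q)
        cs = np.cumsum(np.log(a[sl]))         # in-chunk cumulative log-decay
        L = np.tril(np.exp(cs[:, None] - cs[None, :]))    # Q x Q decay mask
        # diagonal block: small masked attention on this chunk only
        Y[sl] = (L * (C[sl] @ B[sl].T)) @ X[sl]
        # cross-chunk piece: incoming state, decayed to each position, read out by C
        Y[sl] += np.exp(cs)[:, None] * (C[sl] @ state)
        # exit state: decayed incoming state plus this chunk's own contributions
        decay_to_end = np.exp(cs[-1] - cs)
        state = np.exp(cs[-1]) * state + (B[sl] * decay_to_end[:, None]).T @ X[sl]
    return Y

rng = np.random.default_rng(5)
T, N, P, Q = 32, 4, 6, 8                      # hypothetical toy sizes
a = rng.uniform(0.7, 1.0, size=T)
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
X = rng.normal(size=(T, P))

Y_ref = ssd_recurrent(a, B, C, X)
Y_chunk = ssd_chunked(a, B, C, X, Q)
print(np.allclose(Y_ref, Y_chunk))
```

Changing `Q` to any divisor of `T` should leave the output unchanged; only the split between matmul work and scan work moves.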
The Mamba-2 block
The block itself is one small structural change from Mamba-1 and one normalization addition. In Mamba-1 the projections that produce $A$, $B$, $C$ are functions of the SSM input $X$, so they sit serially inside the block: $X$ is projected from the residual stream, then a second projection computes $A$, $B$, $C$ from $X$ (Figure 6 in the paper). In Mamba-2 the layer is viewed as a map $(A, X, B, C) \mapsto Y$, so all four are projected in parallel at the start of the block, the way $Q$, $K$, $V$ are in attention. That removes the inner sync point, which lets tensor parallelism shard the layer cleanly across GPUs the way Megatron shards attention and MLP layers.
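A rough picture of "all four projected in parallel", as a NumPy sketch. Everything here is an assumption for illustration: the sizes, the single fused weight `W_in`, the split order, and the negative-softplus parameterization of the log-decay are not taken from the paper or the reference implementation, and the gate branch, convolution, and normalization are left out.

```python
import numpy as np

rng = np.random.default_rng(6)
T, d_model = 8, 32                           # hypothetical sizes
n_heads, d_head, d_state = 4, 16, 8
d_inner = n_heads * d_head

# One fused projection from the residual stream; its output is split into X, B, C, A.
W_in = rng.normal(size=(d_model, d_inner + 2 * d_state + n_heads)) / np.sqrt(d_model)
u = rng.normal(size=(T, d_model))            # residual-stream input
proj = u @ W_in                              # a single matmul, no serial dependency

X, B, C, A_raw = np.split(
    proj, [d_inner, d_inner + d_state, d_inner + 2 * d_state], axis=-1
)
X = X.reshape(T, n_heads, d_head)            # per-head inputs
A_log = -np.logaddexp(0.0, A_raw)            # -softplus: log-decay < 0, so each a_t is in (0, 1)

print(X.shape, B.shape, C.shape, A_log.shape)
```

Because nothing in the split depends on another piece having been computed first, the fused weight can be sharded column-wise across devices in the usual way.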
A GroupNorm sits just before the final output projection, after the multiplicative gate. The paper finds this normalization fixes instabilities at larger scale, and notes the same pattern in TransNormerLLM and RetNet (NormFormer-style late normalization). The head pattern Mamba-2 uses is MVA / MIS: $B$ and $C$ are shared across the heads of $X$, the same head structure as multi-value attention.
The other minor change is a larger head dimension: Mamba-1 uses $P = 1$ (one channel per SSM), Mamba-2 uses $P = 64$ or $P = 128$, the same head dimension a Transformer would use. That change is what makes the inner operation a matrix-matrix product instead of a matrix-vector one.
Bigger state, sharper recall
One thing the SSD algorithm enables, and the change of head dimension cheaply allows, is a much larger state size $N$. Mamba uses $N = 16$ by default; Mamba-2's default is several times larger, and the paper shows the algorithm continues to be efficient up to $N = 256$ or higher. That matters most for tasks that ask the model to memorize many key/value associations in its context, the place fixed-state recurrences traditionally lose to attention.
The multi-query associative recall (MQAR) synthetic is the cleanest test. The model is shown a sequence of key-value pairs and is later prompted with some of the keys; it must produce the matching values. Attention does this well at any sequence length because it caches every token; a small-state SSM has to compress everything into $N$ slots, so at long sequence length or low $N$ it fails outright. The paper's MQAR figure tracks what happens to Mamba-2 as $N$ increases.
Two things to read off that figure. First, Mamba-2 at $N = 16$ outperforms Mamba at $N = 16$: the architecture change wins even at matched state size, which the paper flags as surprising and does not fully explain. Second, the trend up to $N = 256$ is monotonic; SSD's ability to keep the inner work in matmul form is what lets a model that size remain fast enough to train.
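To make the task concrete, here is a toy generator for one MQAR-style example. The token layout (keys and values drawn from disjoint id ranges, pairs laid out back to back, queries appended afterwards) is a simplifying assumption for illustration and does not reproduce the exact format of the Zoology benchmark; `make_mqar_example` is a hypothetical name.

```python
import random

def make_mqar_example(n_pairs, n_queries, vocab=64, seed=0):
    """Return (context tokens, queried keys, expected values) for one toy example."""
    rng = random.Random(seed)
    keys = rng.sample(range(vocab), n_pairs)                   # distinct key ids
    values = [rng.randrange(vocab, 2 * vocab) for _ in keys]   # value ids from a disjoint range
    context = [tok for pair in zip(keys, values) for tok in pair]  # k1 v1 k2 v2 ...
    queried = rng.sample(keys, n_queries)                      # keys the model is asked about later
    lookup = dict(zip(keys, values))
    targets = [lookup[k] for k in queried]                     # the values it must recall
    return context, queried, targets

context, queried, targets = make_mqar_example(n_pairs=8, n_queries=3)
print(len(context), len(queried), len(targets))  # 16 3 3
```

The number of pairs is the memory load: a recurrent model must hold all of them in its fixed-size state, because it cannot know in advance which keys will be queried.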
Speed and language modeling
The headline efficiency claim is on an A100 80GB: a single SSD layer is two to eight times faster than Mamba's fused selective scan, depending on state size, and faster than FlashAttention-2 from sequence length two thousand onward. At sixteen thousand tokens, SSD is about six times faster than FlashAttention-2. Both comparisons turn on the same mechanism: SSD does the same number of FLOPs as Mamba, but those FLOPs are in big matmuls instead of small scalar steps; SSD does fewer FLOPs than FlashAttention-2 (linear vs quadratic in $T$), and runs them at the same matmul utilization.
On standard language modeling the picture is competitive rather than dominant. Mamba-2 matches Mamba and the modern Transformer recipe (RoPE + SwiGLU + RMSNorm, no biases) on Chinchilla-style scaling sweeps from 125M to 1.3B parameters. At 2.7B trained to 300B tokens on the Pile, Mamba-2 lands a Pile perplexity of 6.09 against Mamba's 6.22 and Pythia-2.8B's 6.73, with consistent gains on common downstream tasks (LAMBADA, HellaSwag, PIQA, Arc-E, Arc-C, WinoGrande, OpenbookQA). And the 2.7B Mamba-2 outperforms Pythia-6.9B trained on the same data at most of those tasks, despite less than half the parameter count.
The hybrid results are the most interesting forward-looking piece. With ~10% of the layers in a 350M / 48-layer model switched from SSD to softmax attention, Pile perplexity drops from 8.60 (pure Mamba-2) to 8.26 (six attention layers), below the 8.68 of pure Transformer++. A 2.7B Mamba-2-Attention hybrid with six attention layers (out of 64) hits 5.95 Pile perplexity, below both pure Mamba-2 (6.09) and Transformer++ (6.13). So the architecture that emerges is mostly recurrence with a few attention layers stacked in; the recurrence does the bulk-mixing work cheaply, and a few attention layers handle the lookups the recurrence cannot.
Where Mamba-2 gives ground: the scalar-A restriction is real, and any task that genuinely needs per-slot decay rates will see a difference (none of the language benchmarks did, but benchmarks that show no gap are not proof that the restriction is free). Tensor parallelism, sequence parallelism, and variable-length packing all need new conventions for SSMs that the paper has to introduce from scratch, and the ecosystem is still much thinner than Transformers'. Pure associative recall at very long sequences still favors attention even at large $N$, which is why the hybrids exist.
An SSM is a triangular matrix whose off-diagonal blocks have rank at most $N$, the state size. Restrict $A_t$ to a scalar and that matrix becomes the same object as masked attention without softmax, with a 1-semiseparable mask in place of the causal mask. Chunk that matrix into a grid: diagonals are small attention, off-diagonals collapse to a short scan on chunk-states. The work is linear in $T$, the inner operation is a dense matmul, and the language-modeling cost of the scalar restriction is, on the kinds of sequence anyone runs, zero.
Questions you might still have
Is Mamba-2 the same as Mamba with a faster scan?
No. It is a different layer. Mamba uses a diagonal A_t, so each of the N state slots has its own decay; that flexibility is what stopped you from collapsing the recurrence into a single matrix multiply. Mamba-2 restricts A_t to a scalar times identity, so the per-step decay is one number. With that restriction the recurrence becomes a structured matrix you can chunk and multiply with tensor cores, which is what the SSD algorithm does. Same family, smaller model class, faster compute.
Does the scalar-A restriction cost accuracy?
On language modeling, no measurable cost: Mamba-2 matches Mamba and Transformer++ on the Chinchilla-style scaling sweep, and at 2.7B trained to 300B tokens it slightly outperforms Mamba at the same size and beats Pythia-6.9B, a model more than twice its size, on most downstream tasks. On MQAR (the multi-query associative recall synthetic) Mamba-2 actually beats Mamba-1 at matched state size, which is surprising and the paper does not fully explain.
If SSD is "attention without softmax", does it inherit attention's long-context behavior?
Only partly. The state is still finite (N × P numbers per head) and the mask L is data-dependent decay, not an unbounded lookup. So associative recall still hurts at very long sequence lengths unless N grows. In practice this line of work has converged on hybrids: stack mostly SSD layers and sprinkle a few full-attention layers in. The paper finds ~10% attention layers is optimal at the 350M scale.
Why does the algorithm care so much about matrix multiplications, not FLOPs?
Modern GPUs do dense matmuls on tensor cores at hundreds of teraFLOP/s; the same number of scalar FLOPs in a sequential scan runs at a small fraction of that, because tensor cores sit idle. Mamba-1's selective scan is linear-in-T but scalar; SSD is linear-in-T and matmul-shaped. Same big-O, much higher utilization. The Mamba-2 paper is largely an argument that this is the lever that matters.
Footnotes & further reading
- The paper: Dao, Gu, Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (Princeton & CMU, ICML 2024). Code.
- The predecessor: Gu, Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces. The Mamba explainer covers the selective scan and the input-dependent parameters.
- Linear attention and the masked-attention recurrence trick: Katharopoulos et al., Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. The Mamba-2 title is a deliberate echo.
- Semiseparable matrices: Pernet, Signargout, Villard, Exact computations with quasiseparable matrices (the parameter result), and the textbook survey Vandebril, Van Barel, Mastronardi, Matrix Computations and Semiseparable Matrices.
- The IO-aware mindset SSD inherits: Dao, FlashAttention-2, and the original FlashAttention explainer.
- The two NormFormer-style adjacent papers Mamba-2 cites for late normalization: Qin et al., TransNormerLLM and Sun et al., RetNet.
- The MQAR task: Arora et al., Zoology: Measuring and Improving Recall in Efficient Language Models.
- A more discursive walk through SSD from the authors: State Space Duality (Mamba-2) on the Goomba Lab blog.