Multi-Query Attention
Share one set of keys and values across attention heads, and the decoder stops waiting on memory.
A Transformer generates one token at a time, and at every step it has to reload the keys and values of every token it has ever seen. Shazeer (2019) drops that reload by a factor of h, the number of heads, with a change small enough to fit on one line of code.
Explaining the paper: Fast Transformer Decoding: One Write-Head is All You Need
Training is parallel and arithmetic-bound; decoding is sequential and memory-bound. So shrink the thing that gets reloaded.
A six-page paper from Noam Shazeer at Google, written in late 2019, sits underneath nearly every modern large language model. It named the cost a Transformer pays once you ship it: at decode time, the attention layer is not waiting on its multiplications, it is waiting on its memory traffic. Read the keys and values back, score, write the new key and value, repeat. The arithmetic per step is small, and the bytes moved per step are what determines how long each step takes.
One edit does it. The keys and values in standard multi-head attention live in tensors of shape $[b, h, m, k]$ (and $[b, h, m, v]$), one set per head. Multi-query attention drops the head dimension on keys and values: one shared $K$, one shared $V$, all query heads point at them. The cache shrinks by a factor of $h$, the cache bytes moved per step shrink by the same factor, and decoding gets much faster. The cost is about 0.2 BLEU on WMT14 English-German, and the paper shows that price is far cheaper than the other cache-shrinking alternatives it tests.
The argument runs through why decoding is memory-bound while training is not, what exactly sits in the KV cache, what one line of einsum changes, and what the resulting arithmetic intensity buys at deploy time. The math is short; the consequences run forward through grouped-query attention, FlashAttention, and PagedAttention.
Decoding waits on memory
One asymmetry between training and decoding sets up everything that follows. During training, a Transformer sees the whole target sequence at once. All $n$ queries fire in parallel against the same $K$ and $V$ on the same forward pass, and a backward pass on each batch reuses the tensors many times. Per token, the layer does a lot of arithmetic for relatively few bytes moved through memory. That is what a GPU is built for: big matrix multiplies that pull each weight off the wire once and use it many times.
Generation does not look like that: to produce the next token, the model runs a single forward pass with exactly one new query. That query attends to every previous position, so the keys and values of every prior token have to be available. They were already computed at earlier steps, so the model stores them in a KV cache and reads the cache rather than redoing the work. The arithmetic per step is small (one new query against $n$ stored entries), but every step has to pull the entire cache back through memory. On modern accelerators that traffic costs more than the arithmetic does, by orders of magnitude.
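Section 2.3.1 of the paper makes the asymmetry exact. Under the simplifying assumptions $m = n$, $k = v = d/h$, and $n \le d$, batched training does $\Theta(bnd^2)$ arithmetic and $O(bnd + bhn^2 + d^2)$ bytes of memory access, for a memory-to-arithmetic ratio of $O\!\left(\frac{1}{k} + \frac{1}{bn}\right)$. Section 2.4.1 redoes the count for incremental decoding and gets a different ratio:
$$\frac{\text{memory}}{\text{arithmetic}} = \Theta\!\left(\frac{n}{d} + \frac{1}{b}\right) \tag{1}$$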
The two terms come from different parts of the layer. The $n/d$ piece is the cache itself: $K$ and $V$ are size $bnd$ (because $hk = d$), so the bytes moved across $n$ steps grow as $bn^2d$, and dividing by the total arithmetic $bnd^2$ gives $n/d$. The $1/b$ piece is the projection weights $P_q, P_k, P_v, P_o$, which are independent of batch but get loaded once per step regardless. Big batches dilute the second; nothing dilutes the first.
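The split between the two terms can be checked numerically. The sketch below is only illustrative: it treats the asymptotic counts from Section 2.4.1 as if they were exact (all constants dropped), the function name `decode_memory_terms` is hypothetical, and the shapes are the paper's WMT ones.

```python
def decode_memory_terms(n: int, d: int, b: int) -> dict:
    """Memory-to-arithmetic terms for multi-head incremental decoding.

    Asymptotic counts with constants dropped (an assumption of this sketch):
      arithmetic        ~ b * n * d^2
      cache traffic     ~ b * n^2 * d   (K and V re-read at every step)
      weight traffic    ~ n * d^2       (projection matrices, once per step)
    """
    arithmetic = b * n * d**2
    cache_traffic = b * n**2 * d
    weight_traffic = n * d**2
    return {
        "cache_term": cache_traffic / arithmetic,    # reduces to n / d
        "weight_term": weight_traffic / arithmetic,  # reduces to 1 / b
    }


# Growing the batch dilutes the weight term but leaves the cache term alone.
for batch in (1, 32, 1024):
    terms = decode_memory_terms(n=128, d=1024, b=batch)
    print(batch, terms["cache_term"], terms["weight_term"])
```

The cache term stays at 0.125 for every batch size, while the weight term falls as 1/b.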
What ratio is too high? A modern GPU is compute-bound when arithmetic intensity (the inverse of (1)) sits above roughly 100 to 150 FLOPs per byte of memory accessed, the rough compute-to-bandwidth ratio of every recent device. The paper is more conservative: it points out the gap can be two orders of magnitude. When the ratio in (1) is too high, the multipliers stall because the bytes have not arrived yet. As the cache gets longer, the bytes moved and the arithmetic done per step grow in lockstep, so the gap to the compute roof never closes.
Two pieces of context before going further. First, the gap is a property of decoding, not of any specific GPU: the compute-to-bandwidth ratio above 100 has held across roughly a decade of accelerator releases. Second, the "1 FLOP per byte" baseline is for plain multi-head attention; the same counting under multi-query gives 32 FLOPs per byte, and the gap to the roof closes by exactly the same factor.
The KV cache, and why it grows by one row per token
Attention starts from one query vector $q$ and a set of $m$ key-value pairs assembled into matrices $K$ and $V$. The output is a weighted average of the values, with weights from the softmax of dot products between the query and the keys:
$$y = \mathrm{softmax}\!\left(q K^\top\right) V$$
That is everything attention does. The dot product is large when the query and a key point the same way, the softmax turns the scores into weights that sum to one, and the result is an average of the values dominated by the best-matching keys. (The original Transformer scales the logits by $1/\sqrt{d_k}$ to keep them well-behaved in high dimension. Shazeer drops the scale because it can be folded into the projection that produces $q$ or $K$.)
At decode time the queries arrive one at a time. The keys and values of every prior token do not change once they have been computed, so the layer stores them in a cache and reuses them. The query is fresh each step, used once, then thrown away; the keys and values are computed once and kept forever, and that asymmetry is what makes a cache useful at all. The cache grows by one row per generated token:
$$K_t = \begin{bmatrix} K_{t-1} \\ k_t \end{bmatrix}, \qquad V_t = \begin{bmatrix} V_{t-1} \\ v_t \end{bmatrix}$$
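A single-head decode loop makes the bookkeeping concrete. This is a minimal NumPy sketch, not the paper's code: the names `decode_with_cache`, `W_q`, `W_k`, `W_v` are hypothetical, the logits are left unscaled as in the paper, and the cache is grown with `np.vstack` for readability where a real implementation would preallocate.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def decode_with_cache(X, W_q, W_k, W_v):
    """Single-head incremental attention over the rows of X, one token per step.

    X is [n, d]; W_q and W_k are [d, k]; W_v is [d, v].
    Returns the per-step outputs [n, v] and the final key cache [n, k].
    """
    K_cache = np.zeros((0, W_k.shape[1]))  # starts empty, one row added per token
    V_cache = np.zeros((0, W_v.shape[1]))
    outputs = []
    for x in X:
        q = x @ W_q                              # fresh query, used once, then dropped
        K_cache = np.vstack([K_cache, x @ W_k])  # append this token's key
        V_cache = np.vstack([V_cache, x @ W_v])  # append this token's value
        weights = softmax(q @ K_cache.T)         # one query against every cached key
        outputs.append(weights @ V_cache)        # weighted average of cached values
    return np.stack(outputs), K_cache
```

Each step reads the whole cache and writes one new row, which is the traffic pattern the rest of the argument counts.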
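Count the per-step bytes once. At position $n$, the model has cached $n$ keys and $n$ values, each a vector of width $k$ (or $v$), for every one of $h$ attention heads, across $L$ layers. That is
$$2 \cdot L \cdot h \cdot k \cdot n$$
numbers per sequence (taking $v = k$).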
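For a model with $L = 6$, $h = 8$, $k = 128$, and a 128-token target (the paper's WMT setup), the cache at full length is about 1.6 million numbers per sequence. That is small. Scale to today's deployed models, $L = 32$ and $h = 32$ with $k = 128$ and a 32,000-token context, and the same arithmetic gives:
$$2 \cdot 32 \cdot 32 \cdot 128 \cdot 32{,}000 \approx 8.4 \times 10^{9} \text{ numbers},$$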
around 16 GB at 16-bit precision. At that size the cache stops being an implementation detail and starts dominating what a serving system has to plan around.
Inside one decode step, all 8.4 billion numbers (or 1.6 million, in the paper's setup) get streamed past the multipliers exactly once, and then they have to be streamed again on the next token, and the next. The arithmetic on each token is small (one query against every cached entry); the bytes are large; the ratio is what (1) measures.
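The two cache sizes quoted above follow from one multiplication. In the sketch below the helper name `kv_cache_numbers` is hypothetical, and the large-model shape (32 layers, 32 heads, head width 128) is an assumption chosen because it reproduces the 8.4 billion figure; it is not taken from the paper.

```python
def kv_cache_numbers(layers: int, heads: int, head_dim: int, tokens: int) -> int:
    """Cached scalars per sequence: keys plus values for every layer, head, and token."""
    return 2 * layers * heads * head_dim * tokens


# The paper's WMT decoder at full target length.
paper = kv_cache_numbers(layers=6, heads=8, head_dim=128, tokens=128)

# Assumed shape for a present-day model (illustrative, not from the paper).
large = kv_cache_numbers(layers=32, heads=32, head_dim=128, tokens=32_000)

print(f"{paper:,} numbers")                    # 1,572,864 numbers
print(f"{large:,} numbers")                    # 8,388,608,000 numbers
print(f"{large * 2 / 1e9:.1f} GB at 16 bits")  # 16.8 GB at 16 bits
```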
Multi-head: every head, its own keys and values
Standard multi-head attention runs $h$ attention layers in parallel and gives each one its own projection from the input. A single input vector $x$ of width $d$ is multiplied by $h$ learned matrices to make $h$ queries, then by $h$ more to make $h$ keys, then by $h$ more to make $h$ values. Every head sees the same $x$ and produces its own version of all three. The outputs of the $h$ layers run through one more learned projection and sum.
# multi-head, batched (paper, section 2.3)
Q = einsum("bnd,hdk->bhnk", X, P_q) # h independent query heads
K = einsum("bmd,hdk->bhmk", M, P_k) # h independent key heads
V = einsum("bmd,hdv->bhmv", M, P_v) # h independent value heads
logits = einsum("bhnk,bhmk->bhnm", Q, K)
weights = softmax(logits + mask)
O = einsum("bhnm,bhmv->bhnv", weights, V)
Y = einsum("bhnv,hdv->bnd", O, P_o)

The tensor shapes carry the cost. P_k and P_v are tensors of shape $[h, d, k]$ and $[h, d, v]$: the head dimension $h$ is in there explicitly, which is why $K$ comes out shaped $[b, h, m, k]$ and lives in the cache that way. Eight heads, eight different key projections, eight cached keys per token.
The eight heads carry meaningful diversity. One commonly attends to the previous token, another to a content word in the prompt, another to a broad neighborhood. The standard intuition for why this is the right shape: a single attention distribution has to point somewhere, and a multi-head layer lets the model point eight places at once. (How clean those roles actually are in a trained model is more debatable than the textbook telling, but the diversity-of-attention argument is real, and it is the one the design rests on.)
Multi-query: one shared key and value, all queries
The paper's observation is that only the query side of attention has to differ across heads. The key and value side supplies a single shared set of vectors that every head dots against, and a single set can be dotted against from many directions. So drop the $h$ dimension from the key and value projections:
# multi-query, batched (paper, section 3)
Q = einsum("bnd,hdk->bhnk", X, P_q) # h query heads, same as before
K = einsum("bmd,dk->bmk", M, P_k) # the h is gone: ONE shared key
V = einsum("bmd,dv->bmv", M, P_v) # ONE shared value
logits = einsum("bhnk,bmk->bhnm", Q, K)
weights = softmax(logits + mask)
O = einsum("bhnm,bmv->bhnv", weights, V)
Y = einsum("bhnv,hdv->bnd", O, P_o)

The diff is one letter per line. P_k is now shape $[d, k]$, not $[h, d, k]$. The cached $K$ drops the head dimension to $[b, m, k]$, and $V$ drops it to $[b, m, v]$. There are still $h$ queries, but they all point at one shared set of keys and one shared set of values. The output projection stays per-head, since each head's output is its own vector even when it shares the K and V it dots against. Dropping $h$ from $K$ and $V$ shrinks those two tensors, and the cache that stores them, by a factor of $h$; the logits, weights, and per-head outputs keep their shapes.
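Both listings can be run side by side to confirm the shapes. The sketch below swaps the paper's TensorFlow-style pseudocode for `np.einsum`, omits the mask (so it behaves like unmasked cross-attention), and uses small made-up dimensions; the function names are hypothetical.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def multi_head(X, M, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)  # one key set per head
    V = np.einsum("bmd,hdv->bhmv", M, P_v)  # one value set per head
    weights = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K))
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K, V


def multi_query(X, M, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)    # single key set shared by all heads
    V = np.einsum("bmd,dv->bmv", M, P_v)    # single value set shared by all heads
    weights = softmax(np.einsum("bhnk,bmk->bhnm", Q, K))
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K, V


# Illustrative sizes only.
b, n, m, d, h, k, v = 2, 3, 5, 16, 4, 4, 4
rng = np.random.default_rng(0)
X, M = rng.normal(size=(b, n, d)), rng.normal(size=(b, m, d))
P_q, P_o = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, v))

Y_mh, K_mh, V_mh = multi_head(
    X, M, P_q, rng.normal(size=(h, d, k)), rng.normal(size=(h, d, v)), P_o)
Y_mq, K_mq, V_mq = multi_query(
    X, M, P_q, rng.normal(size=(d, k)), rng.normal(size=(d, v)), P_o)

print(K_mh.shape, K_mq.shape)   # (2, 4, 5, 4) (2, 5, 4)
print(K_mh.size // K_mq.size)   # 4, the number of heads
print(Y_mh.shape == Y_mq.shape) # True: the layer's output shape is unchanged
```

The output has the same shape either way; only the tensors that would be cached lose the head axis.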
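Multi-head, grouped-query, and multi-query attention differ only in how the query heads are wired to KV heads. The eight query heads stay; the eight KV heads of multi-head collapse to one for multi-query, and to a small number for grouped-query, the middle ground from Ainslie et al. (2023) that Mistral picks up four years after Shazeer's paper. The cache size tracks the number of KV heads exactly.
[Figure: query heads wired to KV heads under multi-head, grouped-query, and multi-query attention, with a bar for cache size. Setting shown: GQA-8 (Mistral), 4 query heads share each KV head.]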
The cost is real and the paper does not hide it. Removing per-head keys and values is a constraint on what the model can do. Multiple heads were the explicit answer to "a single attention distribution can only point one place"; with shared keys and values, the eight heads can still point eight places, but every place has to be encoded in the same shared K and V. The claim is that the query side already carries enough of the diversity for the eight heads to remain useful, and Section 4.2 confirms it on WMT and the Billion-Word LM benchmark.
The memory-to-arithmetic ratio gets cut by $h$
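Redo Section 2.4.1's count with the new shapes. The cache is now $K$ and $V$ of size $bnk$ rather than $bnd$, a factor of $h$ smaller. The arithmetic is unchanged at $\Theta(bd^2)$ per step, $\Theta(bnd^2)$ in total, because the query heads still do all the multiplies they did before; the savings live entirely in the bytes moved. Across $n$ steps the memory total becomes $\Theta(bnd + bn^2k + nd^2)$, and dividing by the arithmetic gives the new ratio:
$$\frac{\text{memory}}{\text{arithmetic}} = \Theta\!\left(\frac{1}{d} + \frac{n}{dh} + \frac{1}{b}\right)$$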
Compare to (1). The dominant $n/d$ term has split off a factor of $h$ and become $n/(dh)$, which is the headline result. The $1/b$ term is unchanged, and the $1/d$ term was present all along but hidden under $n/d$, so the precise statement is that the dominant term drops by $h$, not that the ratio as a whole does.
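Concretely, take the paper's WMT setup, $d = 1024$, $h = 8$, $n = 128$. The cache term works out as:
$$\frac{n}{d} = \frac{128}{1024} = 0.125 \quad\longrightarrow\quad \frac{n}{dh} = \frac{128}{1024 \cdot 8} \approx 0.0156$$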
The cut to the $n/d$ term is $8\times$, which is $h$. The other two terms contribute on the order of $1/d \approx 0.001$ and $1/b$, which depends only on the batch size. The dominant cost was the cache, and it shrank by $8\times$. Push to today's shapes, with $h = 32$ heads, and the factor is 32, and the gap between "memory bound" and "compute bound" closes by the same factor.
Two pieces of the bound do not get touched. The $1/d$ term is the input and output vectors $x, q, o, y$, which are loaded each step regardless of how the cache is structured. The $1/b$ term is the four projection matrices $P_q, P_k, P_v, P_o$, which get pulled from memory once per step and do not depend on the batch. Both are constants that batched decoding already handles by adding more sequences to a step, the same way bigger batches dilute weight loads in training. MQA leaves them where they were and goes straight after the term that actually scales with sequence length.
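The three terms can be laid side by side for the two variants. This is a sketch with constants dropped; the function name `decode_ratio_terms` is hypothetical, and the batch size of 128 is an assumption picked only to give the weight term a concrete value.

```python
def decode_ratio_terms(n: int, d: int, h: int, b: int, shared_kv: bool) -> dict:
    """Memory-to-arithmetic terms for incremental decoding (constants dropped).

    shared_kv=False is multi-head; shared_kv=True is multi-query.
    The paper's multi-head formula omits the activation term because n/d dominates it.
    """
    kv_heads = 1 if shared_kv else h
    return {
        "activations": 1 / d,             # x, q, o, y: loaded every step
        "cache": n * kv_heads / (d * h),  # n/d for multi-head, n/(d*h) for multi-query
        "weights": 1 / b,                 # projection matrices, amortized over the batch
    }


# Paper's WMT shapes; b=128 is an illustrative assumption.
mha = decode_ratio_terms(n=128, d=1024, h=8, b=128, shared_kv=False)
mqa = decode_ratio_terms(n=128, d=1024, h=8, b=128, shared_kv=True)

print(mha["cache"], mqa["cache"])          # 0.125 0.015625
print(mha["cache"] / mqa["cache"])         # 8.0
print(mha["weights"] == mqa["weights"])    # True: untouched by MQA
```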
WMT and LM: an order of magnitude on decode, 0.2 BLEU on quality
Section 4 runs the comparison on two benchmarks: WMT 2014 English to German translation, and the Billion-Word Language Modeling Benchmark. The translation baseline is a 6-layer encoder-decoder with $d = 1024$, $h = 8$, and $k = v = 128$, 211 million parameters. The multi-query model has the same total parameter count (Shazeer widens the feed-forward layer $d_{ff}$ from 4096 to 5440 to make up for the parameters MQA removes), and all the attention layers in the model, encoder self-attention, decoder self-attention, and encoder-decoder attention, are switched over.
On the WMT dev set the multi-query model trails the baseline slightly in log-perplexity, and is at BLEU 26.5 against 26.7, a 0.2 point drop. On the WMT14 test set, multi-query lands at BLEU 27.5 with greedy decoding and 28.5 with beam-4 search, against the baseline's 27.7 and 28.4. With beam search multi-query is actually a hair ahead, although the gap is well within evaluation noise and Shazeer reads it as a tie. On Billion-Word LM the multi-query model gets perplexity 30.2 against the baseline 29.9. So multi-query is slightly worse on both benchmarks, by a fraction of a BLEU point on the first and 0.3 perplexity points on the second.
The speed difference is what the paper buys for that 0.2 BLEU. Table 2 gives amortized cost per output token on one TPUv2 with 8 cores. Greedy decoding takes 46 microseconds per token on the multi-head decoder. The multi-query decoder takes 3.8 microseconds, about a $12\times$ speedup. Beam-4 search drops from 203 microseconds to 32, about a $6\times$ speedup (smaller because beam search is closer to compute-bound, since each step does beam-width times as much arithmetic). The encoder, which processes the whole input in parallel and is not memory-bound, barely changes: 1.7 to 1.5 microseconds. Training is a wash: 13.2 to 13.0 microseconds per token.
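The speedups are plain ratios of the Table 2 decoder figures quoted above. A short sketch, with hypothetical variable names:

```python
# Decoder cost in microseconds per output token, as quoted from Table 2.
decoder_us = {
    "greedy": {"multi_head": 46.0, "multi_query": 3.8},
    "beam_4": {"multi_head": 203.0, "multi_query": 32.0},
}

# Speedup is the multi-head cost divided by the multi-query cost.
speedups = {
    mode: cost["multi_head"] / cost["multi_query"]
    for mode, cost in decoder_us.items()
}

for mode, factor in speedups.items():
    print(f"{mode}: {factor:.1f}x")
# greedy: 12.1x
# beam_4: 6.3x
```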
The paper also runs the obvious cheap alternative as a sanity check: four versions of the multi-head model with fewer or smaller heads ($h \in \{1, 2, 4, 8\}$, with $k$ and $v$ rescaled to $128/h$), each with the feed-forward layer widened to keep the parameter count matched. Those alternatives shrink the cache too, but they pay a lot more for it. The best of them ($h = 2$) lands at BLEU 26.2, 0.3 points worse than multi-query. The two extremes ($h = 1$ and $h = 8$) are at 25.8, 0.7 points worse. Plotted as speed against quality, multi-query is the only point that gets both at once.
[Figure: decode speed against quality for each variant. Highlighted point: multi-query, h=8 queries, 1 shared KV.]
On Billion-Word LM the pattern is the same: multi-query at perplexity 30.2, the cheaper alternatives at 30.9 through 31.2. The shrink-the-heads alternatives are not bad, but they sit on the wrong side of the trade, paying more in quality for less in speed. Multi-query is the one variant that recovers almost all of the decode cost while costing almost no quality.
One last detail the paper lists in a footnote, because it matters when you reproduce the numbers: the measured decode time was on padded fixed-shape tensors, the kind of implementation TPUs run well, so every step takes the same time regardless of how many cached tokens it actually attends to. A smarter implementation that grew the cache incrementally would be faster early in the sequence; the 46-microsecond and 3.8-microsecond figures are the steady-state numbers at full length.
Where MQA sits in the KV-cache lineage
The 0.2-BLEU drop is the part of the paper that later work walked back. By 2023 it was clear that multi-query attention left a little quality on the table for a few large models, and a softer version of the same idea showed up: Ainslie et al., GQA. Group the $h$ query heads into $g$ groups; let each group share one K/V; multi-head is $g = h$; multi-query is $g = 1$. Practical models settle in the middle. Llama 2 70B uses $g = 8$ for $h = 64$; Mistral 7B uses $g = 8$ for $h = 32$. The cache shrinks by 4x or 8x rather than 32x, and most of the quality loss disappears. GQA is the practical setting, but multi-query is its limit: knowing what is at the extreme is what makes the middle ground well-defined.
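Grouped-query attention fits the same einsum pattern with a group axis in place of the head axis on K and V. The sketch below is an illustration in NumPy, not code from Ainslie et al.: the function name is hypothetical, consecutive query heads are assumed to share a group, and the mask is omitted.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def grouped_query(X, M, P_q, P_k, P_v, P_o):
    """P_q: [h, d, k]; P_o: [h, d, v]; P_k: [g, d, k]; P_v: [g, d, v]; g divides h."""
    h, g = P_q.shape[0], P_k.shape[0]
    assert h % g == 0, "query heads must split evenly into groups"
    b, n, _ = X.shape
    # Split the h query heads into g groups of r = h // g heads each.
    Q = np.einsum("bnd,hdk->bhnk", X, P_q).reshape(b, g, h // g, n, -1)
    K = np.einsum("bmd,gdk->bgmk", M, P_k)  # the cache holds g key sets, not h
    V = np.einsum("bmd,gdv->bgmv", M, P_v)
    weights = softmax(np.einsum("bgrnk,bgmk->bgrnm", Q, K))
    O = np.einsum("bgrnm,bgmv->bgrnv", weights, V).reshape(b, h, n, -1)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K, V


# Illustrative sizes only.
b, n, m, d, h, k, v = 1, 2, 6, 16, 8, 4, 4
rng = np.random.default_rng(0)
X, M = rng.normal(size=(b, n, d)), rng.normal(size=(b, m, d))
P_q, P_o = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, v))

for g in (8, 2, 1):  # multi-head, grouped-query, multi-query
    _, K, _ = grouped_query(
        X, M, P_q, rng.normal(size=(g, d, k)), rng.normal(size=(g, d, v)), P_o)
    print(g, K.shape)
# 8 (1, 8, 6, 4)
# 2 (1, 2, 6, 4)
# 1 (1, 1, 6, 4)
```

Setting g = h recovers multi-head and g = 1 recovers multi-query, so one function covers the whole line between them.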
MQA is also the first paper to call out the broader cost: attention's expensive component on a deployed model is not the matrix multiplies, it is the bytes being moved. A small library of papers that followed picked up other parts of the same problem and sit alongside this one. FlashAttention (Dao et al., 2022) keeps the attention math exact, with no approximation, but tiles it so the score matrix never materializes in slow memory, cutting the bandwidth cost of one attention call by an order of magnitude. PagedAttention (Kwon et al., 2023) attacks how the cache is physically laid out across many concurrent users in a serving system, where naive contiguous allocations fragment DRAM and waste 60 to 80 percent of it; paging it like virtual memory recovers the rest. All three sit on different axes of the same cost and stack cleanly.
Read them together and each one attacks a different piece of the same decode cost. MQA shrinks the cache the model keeps. GQA gives back some of the quality MQA traded away by keeping a few KV heads. FlashAttention shrinks the work each attention call has to do over that cache. PagedAttention shrinks the wasted space when many sequences share a GPU. Every model deployed at scale today, every chat assistant and every code model, is running some combination of those four, and multi-query attention is the one that ran first and made the rest possible.
Questions you might still have
How can dropping every head except one cost only 0.2 BLEU?
The heads still differ on the query side, the side that varies across heads. Eight different queries can still be evaluated against one shared set of keys and values, and pull out eight different weighted averages. The only restriction is that all eight heads have to use the same K and V to do it.
Why does training time barely change while decoding gets 12x faster?
Training already runs n queries against the same K and V in parallel, so K and V are loaded once per sequence and the per-token cost is amortized. Decoding fires one query at a time and must reload the cache every step, which is exactly what the n/d term measures.
Is multi-query attention the same as grouped-query attention?
No. Multi-query is the extreme: a single shared K, V pair. Grouped-query is the middle ground (a few groups, each with its own K, V), introduced four years later by Ainslie et al. and adopted by Llama 2 and Mistral. The siblings sit on a line: MHA -> GQA -> MQA, with the cache shrinking by a factor of 1, then h/g for g groups, then h.
How does MQA fit alongside FlashAttention and PagedAttention?
They attack three different costs in the same attention layer. MQA shrinks the cache itself by sharing K, V across heads. FlashAttention shrinks the cost of one attention call, keeping the math exact but tiling so the N x N scores never materialize in slow memory (/flashattention-fast-and-memory-efficient-exact/). PagedAttention shrinks the wasted space in how the cache is laid out in DRAM across many concurrent users (/paged-attention/). All three can stack.
Footnotes & further reading
- The paper: Noam Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (Google, 2019). Six pages, including the two tables of WMT results.
- The Transformer that MQA edits: Vaswani et al., Attention Is All You Need (2017), which Section 2.2 of Shazeer's paper reviews in einsum form.
- The middle ground that came later: Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (2023). Adopted by Llama 2 (Touvron et al.) and Mistral.
- The kernel-side complement, tiling so the score matrix stays in SRAM: Dao et al., FlashAttention (2022). Explainer at /flashattention-fast-and-memory-efficient-exact/.
- The serving-side complement, paging the cache for many concurrent users: Kwon et al., Efficient Memory Management for LLMs with PagedAttention (2023). Sibling explainer at /paged-attention/.
- The local-attention work MQA is described as orthogonal to (it shrinks cache per token; locality shrinks the number of tokens attended to): Liu, Saleh, Pot, Goodrich, Sepassi, Kaiser, Shazeer, Generating Wikipedia by Summarizing Long Sequences (ICLR 2018); Povey et al., A time-restricted self-attention layer for ASR (ICASSP 2018); Zhang et al., Accelerating Neural Transformer via Average Attention Network (2018).
- The compute-versus-bandwidth language MQA is using without naming: Williams, Waterman, Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures (Comms. of the ACM, 2009), the canonical statement of arithmetic intensity as the deciding ratio.