Verified · arXiv:1911.02150 · 18 min read
Architecture · Efficiency

Multi-Query Attention

Share one set of keys and values across attention heads, and the decoder stops waiting on memory.

A Transformer generates one token at a time, and at every step it has to reload the keys and values of every token it has ever seen. Shazeer (2019) drops that reload by a factor of h, the number of heads, with a change small enough to fit on one line of code.

Explaining the paper: Fast Transformer Decoding: One Write-Head is All You Need · Noam Shazeer · Google · 2019 · arXiv:1911.02150

Training is parallel and arithmetic-bound; decoding is sequential and memory-bound. So shrink the thing that gets reloaded.

A six-page paper from Noam Shazeer at Google, written in late 2019, sits underneath most modern large language models, directly or through its grouped-query descendant. It named the cost a Transformer pays once you ship it: at decode time, the attention layer is not waiting on its multiplications, it is waiting on its memory traffic. Read the keys and values back, score, write the new key and value, repeat. The arithmetic per step is small, and the bytes moved per step are what determine how long each step takes.

One edit does it. The keys and values in standard multi-head attention live in tensors of shape $h \times d_k$ (and $h \times d_v$), one set per head. Multi-query attention drops the head dimension on keys and values: one shared $K$, one shared $V$, and all $h$ query heads attend to them. The cache shrinks by a factor of $h$, the dominant bytes-per-step term shrinks by the same factor, and in the paper's measurements the decoder runs about 12 times faster under greedy decoding. The cost is about 0.2 BLEU on WMT14 English-German, and the paper shows that price is far lower than what the other cache-shrinking alternatives it tests pay.

The argument runs through why decoding is memory-bound while training is not, what exactly sits in the KV cache, what one line of einsum changes, and what the resulting arithmetic intensity buys at deploy time. The math is short; the same focus on memory traffic runs forward through grouped-query attention, FlashAttention, and PagedAttention.

Decoding waits on memory

One asymmetry between training and decoding sets up everything that follows. During training, a Transformer sees the whole target sequence at once. All $n$ queries fire in parallel against the same $K$ and $V$ in one forward pass, and the backward pass reuses those tensors again. Per token, the layer does a lot of arithmetic for relatively few bytes moved through memory. That is what a GPU is built for: big matrix multiplies that load each weight from memory once and use it many times.

Generation does not look like that: to produce the next token, the model runs a single forward pass with exactly one new query. That query attends to every previous position, so the keys and values of every prior token have to be available. They were already computed at earlier steps, so the model stores them in a KV cache and reads the cache rather than redoing the work. The arithmetic per step is small (one new query against $n$ stored entries), but every step has to pull the entire cache back through memory. On modern accelerators that traffic, not the arithmetic, sets the step time, and the imbalance can reach two orders of magnitude.

Section 2.3.1 of the paper makes the asymmetry exact. Under the simplifying assumptions $m = n$, $k = v = d/h$, and $n \le d$, batched training does $\Theta(bnd^2)$ arithmetic and $O(bnd + bhn^2 + d^2)$ bytes of memory access, for a memory-to-arithmetic ratio of $O(1/k + 1/(bn))$. Section 2.4.1 redoes the count for incremental decoding and gets a different ratio:

$$\underbrace{\frac{\text{bytes moved}}{\text{arithmetic ops}}}_{\text{decode, per step}} \;=\; \Theta\!\left(\frac{n}{d} \;+\; \frac{1}{b}\right) \tag{1}$$

The two terms come from different parts of the layer. The $n/d$ piece is the cache itself: at full length $K$ and $V$ hold $bhnk = bnd$ numbers (because $hk = d$), and they are reread at each of the $n$ steps, so the total traffic grows as $bn^2 d$; dividing by the total arithmetic over those $n$ steps, $bnd^2$, gives $n/d$. The $1/b$ piece is the projection weights $P_q, P_k, P_v, P_o$: $d^2$ numbers that do not depend on the batch but are reloaded once per step, $nd^2$ in total, and $nd^2 / bnd^2 = 1/b$. Big batches dilute the second; nothing dilutes the first.
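To make the two regimes concrete, here is a minimal sketch that plugs numbers into the leading-order counts quoted above, with constants dropped the way the paper drops them. The function names and the example shapes ($b = 128$, $n = d = 1024$, $h = 8$) are illustrative assumptions, not configurations from the paper.

```python
def training_ratio(b, n, d, h):
    """Leading-order memory/arithmetic ratio for batched training (Section 2.3.1 counts)."""
    memory = b * n * d + b * h * n * n + d * d   # activations + attention logits + projection weights
    arithmetic = b * n * d * d                   # the projections dominate the FLOPs
    return memory / arithmetic                   # Theta(1/k + 1/(b n)) with k = d / h and n <= d

def decode_ratio(b, n, d):
    """Leading-order memory/arithmetic ratio over n steps of incremental multi-head decoding."""
    memory = b * n * n * d + n * d * d           # cache reread every step + weights reloaded every step
    arithmetic = b * n * d * d                   # same total arithmetic as training
    return memory / arithmetic                   # n/d + 1/b, equation (1)

b, n, d, h = 128, 1024, 1024, 8
print(training_ratio(b, n, d, h))   # about 0.0088, close to 1/k = 1/128
print(decode_ratio(b, n, d))        # 1.0078125 = n/d + 1/b
```

At the same shapes the decode ratio is more than a hundred times the training ratio, and unlike the training ratio it keeps growing with $n$.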

What ratio is too high? A processor is compute-bound when arithmetic intensity (the inverse of (1)) exceeds its compute-to-bandwidth ratio, which for recent data-center GPUs is on the order of 100 or more FLOPs per byte of memory accessed. The paper frames the same gap in general terms, putting it at up to two orders of magnitude. When the ratio is too high, the multipliers stall because the bytes have not arrived yet. Slide the cache length below and watch the two bars move in lockstep, with the gap to the compute roof never closing:

Figure 1 · why decoding is memory-bound
One decode step, in the configuration of a 32-layer model with 32 query heads and head width 128. Attention arithmetic and bytes moved through memory both grow linearly with the cache, so the arithmetic intensity is pinned. With multi-head attention it sits at exactly 1 FLOP per byte (counting a multiply-add as two FLOPs and storing K and V in 16-bit). A GPU needs around 150 FLOPs per byte to be limited by arithmetic, a typical compute-to-bandwidth ratio rather than a specific card's spec, so the step is memory-bound at every cache size. Toggle to GQA-8 and the bytes drop 4x; toggle MQA and the bytes drop 32x. The arithmetic does not move; only the memory does, and shrinking the memory is what the paper does.

Two pieces of context before going further. First, the gap is a property of decoding, not of any specific GPU: recent generations of data-center accelerators all sit above 100 FLOPs per byte, far above what a decode step supplies. Second, the "1 FLOP per byte" baseline is for plain multi-head attention; the same counting under multi-query gives 32 FLOPs per byte, so the gap to the roof closes by a factor of 32, the number of heads.
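The counting behind Figure 1 fits in a few lines. A minimal sketch, assuming each cached token costs a multiply-add per element for the query-key score and another for the weighted sum of values (two FLOPs each), that K and V are stored in 16-bit, and that only the cache reads are counted (weights and the new token's projections are ignored). `attention_intensity` and the 150 FLOPs-per-byte roof are hypothetical, taken from the caption's assumptions rather than from the paper.

```python
def attention_intensity(n_query_heads, n_kv_heads, head_dim=128, cache_len=4096, bytes_per_number=2):
    """FLOPs per byte for one layer's attention over the KV cache in one decode step."""
    # Each query head scores every cached key (2*head_dim FLOPs) and mixes every cached value (2*head_dim).
    flops = n_query_heads * cache_len * 4 * head_dim
    # Each KV head's cached K row and V row is read once per step, however many query heads use it.
    bytes_moved = n_kv_heads * cache_len * 2 * head_dim * bytes_per_number
    return flops / bytes_moved

ROOF = 150  # assumed compute-to-bandwidth ratio, FLOPs per byte
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    intensity = attention_intensity(32, kv_heads)
    print(name, intensity, "memory-bound" if intensity < ROOF else "compute-bound")
```

The cache length cancels out of the ratio, which is why the bars in Figure 1 move in lockstep. Under these assumptions even multi-query's 32 FLOPs per byte sits below the assumed roof; what changes is how far below.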

The KV cache, and why it grows by one row per token

Attention starts from one query vector $q$ and a set of $m$ key-value pairs assembled into matrices $K$ and $V$. The output is a weighted average of the values, with weights from the softmax of dot products between the query and the keys:

$$\text{Attention}(q, K, V) \;=\; \operatorname{softmax}\!\left(q K^{\top}\right) V \tag{2}$$

That is everything attention does. The dot product $qK^{\top}$ is large where the query and a key point the same way, the softmax turns the scores into weights that sum to one, and the result is a weighted average of the values, dominated by those whose keys best match the query. (The original Transformer scales the logits by $1/\sqrt{d_k}$ to keep them well-behaved in high dimension. Shazeer drops the scale because it can be folded into the projection that produces $q$ or $K$.)
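Equation (2) for a single query can be written directly in NumPy. This is a minimal sketch, without the $1/\sqrt{d_k}$ scale (matching the paper's convention) and without batching or masking; the function name is illustrative.

```python
import numpy as np

def attention(q, K, V):
    """Equation (2) for one query. q: (k,), K: (m, k), V: (m, v); returns shape (v,)."""
    logits = K @ q                            # one dot product per key
    weights = np.exp(logits - logits.max())   # subtract the max so exp cannot overflow
    weights /= weights.sum()                  # softmax: non-negative weights summing to one
    return weights @ V                        # weighted average of the value rows

K = np.array([[1.0, 0.0], [0.0, 1.0]])
V = np.array([[1.0, 2.0], [3.0, 4.0]])
print(attention(np.array([10.0, 0.0]), K, V))  # close to V[0]: the first key matches best
print(attention(np.zeros(2), K, V))            # [2. 3.]: equal scores give the plain mean
```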

At decode time the queries arrive one at a time. The keys and values of every prior token do not change once they have been computed, so the layer stores them in a cache and reuses them. The query is fresh each step, used once, then thrown away; the keys and values are computed once and kept forever, and that asymmetry is what makes a cache useful at all. The cache grows by one row per generated token:

Figure 2 · the KV cache grows by one row per token
Each generated token stores its key and value so they are never recomputed. The current token attends to all of them. So the cache, and the bytes read per step, climb with the sequence length. That climb is the $n/d$ term in (1). Multi-query attention will scale this whole picture down by a factor of $h$.
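A toy version of the loop in Figure 2: one layer, one head, with hypothetical random projection matrices `W_q`, `W_k`, `W_v` standing in for a trained model. It is a sketch of the bookkeeping only; real implementations preallocate the cache rather than copying it with `np.vstack` every step.

```python
import numpy as np

class KVCache:
    """Single-layer, single-head cache: one key row and one value row per token seen so far."""
    def __init__(self, d_k, d_v):
        self.K = np.zeros((0, d_k))
        self.V = np.zeros((0, d_v))

    def append(self, k, v):
        # Keys and values never change once written, so a step only adds a row.
        self.K = np.vstack([self.K, k])
        self.V = np.vstack([self.V, v])

def decode_step(cache, x, W_q, W_k, W_v):
    """Attend from the newest token x (shape (d,)) to every cached position, itself included."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v    # projections for the new token only
    cache.append(k, v)
    logits = cache.K @ q                   # reads every cached key...
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ cache.V                     # ...and every cached value, on every step

rng = np.random.default_rng(0)
d, d_k = 16, 8
W_q, W_k, W_v = (rng.standard_normal((d, d_k)) for _ in range(3))
cache = KVCache(d_k, d_k)
for _ in range(17):
    out = decode_step(cache, rng.standard_normal(d), W_q, W_k, W_v)
print(cache.K.shape, cache.V.shape)  # (17, 8) (17, 8)
```

The query is computed and discarded each step; only `k` and `v` persist, which is the asymmetry that makes the cache worth keeping.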

Count the per-step bytes once. At position $i$, the model has cached $i$ keys and values, each a vector of width $d_k$ (or $d_v$), for every one of $h$ attention heads, across $L$ layers. With $d_v = d_k$, that is

$$\text{cache at step } i \;=\; \underbrace{2}_{K,\,V} \times h \times d_k \times L \times i \qquad \text{(numbers, per batch element)} \tag{3}$$

For a model with $h = 8$, $d_k = 128$, $L = 6$ and a 128-token target (the paper's WMT setup), the decoder self-attention cache at full length is about 1.6 million numbers per sequence. That is small. Scale to a typical large model of today, with $h = 32$, $L = 32$ and a 32,000-token context, and the same arithmetic gives:

$$2 \times 32 \times 128 \times 32 \times 32{,}000 \;\approx\; 8.4 \times 10^9 \;\text{numbers per sequence}$$

about 16.8 GB (15.6 GiB) at 16-bit precision. At that size the cache stops being an implementation detail and starts dominating what a serving system has to plan around.

Inside one decode step, all 8.4 billion numbers (or 1.6 million, in the paper's setup) get streamed past the multipliers exactly once, and then they have to be streamed again on the next token, and the next. The arithmetic on each token is small (one query against every cached entry); the bytes are large; the ratio is what (1) measures.
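Equation (3) as a function, checked against the two worked examples above. A sketch assuming $d_v = d_k$ and 16-bit storage; the last configuration swaps the 32 KV heads for a single shared one to preview what multi-query does to the same budget.

```python
def kv_cache_numbers(n_kv_heads, d_k, n_layers, seq_len):
    """Equation (3) at full length: cached numbers per sequence, assuming d_v == d_k."""
    return 2 * n_kv_heads * d_k * n_layers * seq_len   # the leading 2 counts K and V

paper_wmt = kv_cache_numbers(8, 128, 6, 128)         # the paper's WMT decoder
large_mha = kv_cache_numbers(32, 128, 32, 32_000)    # 32 heads, 32 layers, 32k context
large_mqa = kv_cache_numbers(1, 128, 32, 32_000)     # same model, one shared KV head

BYTES_PER_NUMBER = 2  # 16-bit storage
print(f"{paper_wmt:,}")                                    # 1,572,864
print(f"{large_mha * BYTES_PER_NUMBER / 1e9:.1f} GB")      # 16.8 GB
print(f"{large_mqa * BYTES_PER_NUMBER / 1e9:.2f} GB")      # 0.52 GB
```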

Multi-head: every head, its own keys and values

Standard multi-head attention runs hh attention layers in parallel and gives each one its own projection from the input. A single input vector xx of width dd is multiplied by hh learned matrices to make hh queries, then by hh more to make hh keys, then by hh more to make hh values. Every head sees the same xx and produces its own version of all three. The outputs of the hh layers run through one more learned projection and sum.

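# multi-head, batched (paper, section 2.3), written out as runnable NumPy
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))   # subtract the max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def multihead_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)     # h independent query heads
    K = np.einsum("bmd,hdk->bhmk", M, P_k)     # h independent key heads
    V = np.einsum("bmd,hdv->bhmv", M, P_v)     # h independent value heads
    logits = np.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = softmax(logits + mask)           # mask: 0 where allowed, -inf where not
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)  # per-head output projections, summed over heads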

The tensor shapes carry the cost. P_k and P_v are tensors of shape $(h, d, d_k)$ and $(h, d, d_v)$: the head dimension is in there explicitly, which is why $K$ comes out shaped $(b, h, m, d_k)$ and lives in the cache that way. Eight heads, eight different key projections, eight cached keys per token.

The eight heads carry meaningful diversity. One commonly attends to the previous token, another to a content word in the prompt, another to a broad neighborhood. The standard intuition for why this is the right shape: a single attention distribution has to point somewhere, and a multi-head layer lets the model point eight places at once. (How clean those roles actually are in a trained model is more debatable than the textbook telling, but the diversity-of-attention argument is real, and it is the one the design rests on.)

Multi-query: one shared key and value, all $h$ queries

The paper's bet is that only the query side of attention needs to differ across heads. The key and value side supplies the vectors that every head dots against, and a single shared set can be dotted against from many directions. So drop the $h$ dimension from the key and value projections:

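# multi-query, batched (paper, section 3), written out as runnable NumPy
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))   # subtract the max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)     # h query heads, same as before
    K = np.einsum("bmd,dk->bmk", M, P_k)       # the h is gone: ONE shared key
    V = np.einsum("bmd,dv->bmv", M, P_v)       # ONE shared value
    logits = np.einsum("bhnk,bmk->bhnm", Q, K) # every head dots against the same K
    weights = softmax(logits + mask)           # mask: 0 where allowed, -inf where not
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)  # P_o stays per-head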

The diff is one letter per changed line. P_k is now shape $(d, d_k)$, not $(h, d, d_k)$. The cached $K$ drops the head dimension to $(b, m, d_k)$, and so does $V$. There are still $h$ queries, but they all point at one shared set of keys and one shared set of values. The output projection $P_o$ stays per-head, since each head's output is its own vector even when it shares the $K$ and $V$ it dots against. Dropping $h$ from $P_k$ and $P_v$ shrinks every tensor downstream of them by the same factor.
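One way to see what multi-query gives up: it is exactly multi-head attention with the key and value projections tied across heads. The sketch below checks that numerically with NumPy versions of the two einsum listings above (masking omitted; the shapes are small hypothetical values), and compares the size of the cached $K$.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mha(X, M, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)         # cached as (b, h, m, k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    W = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K))
    O = np.einsum("bhnm,bhmv->bhnv", W, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K

def mqa(X, M, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)           # cached as (b, m, k): no head axis
    V = np.einsum("bmd,dv->bmv", M, P_v)
    W = softmax(np.einsum("bhnk,bmk->bhnm", Q, K))
    O = np.einsum("bhnm,bmv->bhnv", W, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K

rng = np.random.default_rng(0)
b, n, m, d, h, k = 2, 3, 5, 16, 4, 8
X, M = rng.standard_normal((b, n, d)), rng.standard_normal((b, m, d))
P_q, P_o = rng.standard_normal((h, d, k)), rng.standard_normal((h, d, k))
P_k, P_v = rng.standard_normal((d, k)), rng.standard_normal((d, k))

y_mqa, K_mqa = mqa(X, M, P_q, P_k, P_v, P_o)
# Give every head the SAME key and value projection, then run ordinary multi-head attention.
y_mha, K_mha = mha(X, M, P_q, np.stack([P_k] * h), np.stack([P_v] * h), P_o)
print(np.allclose(y_mqa, y_mha))   # True: same outputs
print(K_mha.size // K_mqa.size)    # 4: the tied multi-head cache stores h identical copies
```

Read this way, multi-query is a weight-tying constraint on multi-head attention, and the tied multi-head model would waste a factor of $h$ caching identical keys and values.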

Toggle the three settings below and watch the wiring collapse. The 32 query heads stay; the 32 KV heads of multi-head collapse to one for multi-query, and to a small number for grouped-query, the middle ground from Ainslie et al. (2023) that Mistral adopted the same year. The cache bar tracks the number of KV heads exactly:

Figure 3 · multi-head, grouped-query, multi-query

GQA-8 · Mistral: 4 query heads share each KV head

32 query heads wired down to a varying number of KV heads. Multi-head keeps 32 (one per query). Multi-query collapses them all to 1, the paper's proposal. Grouped-query keeps a few (Mistral picks 8), the compromise that came later. The cache scales with the number of KV heads, so 1 head is a 32x smaller cache than 32, and 8 is a 4x smaller cache.

The cost is real and the paper does not hide it. Removing per-head keys and values is a constraint on what the model can do. Multiple heads were the explicit answer to "a single attention distribution can only point one place"; with shared keys and values, the eight heads can still point eight places, but every place has to be encoded in the same shared K and V. The claim is that the query side already carries enough of the diversity for the eight heads to remain useful, and Section 4.2 supports it on WMT and the Billion-Word LM benchmark.

The dominant memory term gets cut by $h$

Redo Section 2.4.1's count with the new shapes. The cache is now $K$ and $V$ of size $bmd_k$ rather than $bhmd_k$, a factor of $h$ smaller. The total arithmetic is unchanged at $\Theta(bnd^2)$, because the $h$ query heads still do all the multiplies they did before; the savings live entirely in the bytes moved. Across $n$ steps the memory total becomes $\Theta(bnd + bn^2k + nd^2)$, and dividing by the arithmetic gives the new ratio:

$$\frac{\text{bytes moved}}{\text{arithmetic ops}} \;=\; \Theta\!\left(\frac{1}{d} \;+\; \frac{n}{d\,h} \;+\; \frac{1}{b}\right) \tag{4}$$

Compare to (1). The dominant $n/d$ term has split off a factor of $h$ and become $n/(dh)$, which is the headline result. The $1/d$ and $1/b$ terms are unchanged, so the precise statement is that the dominant term drops by $h$, not that the ratio as a whole does. The Provenance panel below expands on the one place the paper's framing is easy to overread.

Concretely, take the paper's WMT setup, $h = 8$, $d = 1024$, $n = 128$. The cache term works out as:

$$\frac{n}{d} \;=\; \frac{128}{1024} \;=\; 0.125 \quad \xrightarrow{\;\text{MQA}\;} \quad \frac{n}{d\,h} \;=\; \frac{128}{1024 \cdot 8} \;\approx\; 0.016$$

That is exactly an $8\times$ cut to the term, which is $h$. The other two terms contribute about $1/1024 \approx 0.001$ each (the $1/b$ term at batch $b = 1024$). The dominant cost was the cache, and it shrank by $h$. Push to a larger model's shapes ($h = 32$, $d = 4096$, $n = 4096$) and the factor is 32, and the gap between memory-bound and compute-bound narrows by roughly the same factor.

Two pieces of the bound do not get touched. The $1/d$ term is the input and output vectors $x, q, o, y$, which are loaded each step regardless of how the cache is structured; it is already small because $d$ is large. The $1/b$ term is the four projection matrices $P_q, P_k, P_v, P_o$, which get pulled from memory once per step whatever the batch size, so batched decoding dilutes it by packing more sequences into each step, the same way bigger batches dilute weight loads in training. MQA leaves both where they were and goes straight after the term that actually scales with sequence length.
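Equations (1) and (4) side by side, evaluated at the WMT shapes above and at a larger configuration. A sketch: the batch size $b = 1024$ and the larger shapes ($d = 32 \times 128 = 4096$, $n = 4096$, $h = 32$) are assumptions chosen for illustration, and constants are dropped as in the paper.

```python
def mha_ratio(n, d, b):
    return n / d + 1 / b                   # equation (1)

def mqa_ratio(n, d, h, b):
    return 1 / d + n / (d * h) + 1 / b     # equation (4)

b = 1024  # assumed batch size
for label, n, d, h in [("WMT", 128, 1024, 8), ("larger model", 4096, 4096, 32)]:
    before, after = mha_ratio(n, d, b), mqa_ratio(n, d, h, b)
    cache_cut = (n / d) / (n / (d * h))    # the cache term alone: exactly h
    print(f"{label}: {before:.4f} -> {after:.4f}, "
          f"whole ratio / {before / after:.1f}, cache term / {cache_cut:.0f}")
```

In both cases the whole ratio falls by a little less than $h$ (about 7.2 and 30.8 here), because the $1/d$ and $1/b$ terms do not move; only the cache term falls by exactly $h$.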

WMT and LM: an order of magnitude on decode, 0.2 BLEU on quality

Section 4 runs the comparison on two benchmarks: WMT 2014 English to German translation, and the Billion-Word Language Modeling Benchmark. The translation baseline is a 6-layer encoder-decoder with $d_{\text{model}} = 1024$, $d_{ff} = 4096$, $h = 8$ and $d_k = d_v = 128$: 211 million parameters. The multi-query model has the same total parameter count (Shazeer widens $d_{ff}$ to 5440 to make up for the parameters MQA removes), and every attention layer in the model (encoder self-attention, decoder self-attention, and encoder-decoder attention) is switched over.

On the WMT dev set the multi-query model is at $\ln(\text{PPL}) = 1.439$ against the baseline's $1.424$, and at BLEU 26.5 against 26.7, a 0.2-point drop. On the WMT14 test set, multi-query lands at BLEU 27.5 with greedy decoding and 28.5 with beam-4 search, against the baseline's 27.7 and 28.4. With beam search multi-query is a hair ahead, though a 0.1-point gap is too small to read as a real win. On Billion-Word LM the multi-query model gets perplexity 30.2 against the baseline's 29.9. So multi-query is slightly worse on both benchmarks: by a fraction of a BLEU point on the first and 0.3 perplexity points on the second.

The speed difference is what the paper buys for that 0.2 BLEU. Table 2 gives amortized cost per output token on one TPUv2 with 8 cores. Greedy decoding takes 46 microseconds per token on the multi-head decoder. The multi-query decoder takes 3.8 microseconds, a $12\times$ speedup. Beam-4 search drops from 203 microseconds to 32, a $6\times$ speedup (smaller, plausibly because beam search does beam-width times as much arithmetic per step and so sits closer to compute-bound). The encoder, which processes the whole source in parallel and is not memory-bound, barely changes: 1.7 to 1.5 microseconds. Training is a wash: 13.2 to 13.0 microseconds per token.
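The speedups quoted above, recomputed from the Table 2 figures as this section reports them. A small sketch; the last line, which adds the encoder's cost per token to the decoder's, is a derived number rather than one the paper states.

```python
def speedup(before_us, after_us):
    """Ratio of per-token costs, both in microseconds per output token."""
    return before_us / after_us

greedy_decoder = speedup(46.0, 3.8)                    # multi-head vs multi-query decoder
beam4_decoder = speedup(203.0, 32.0)
greedy_with_encoder = speedup(1.7 + 46.0, 1.5 + 3.8)   # encoder + decoder, greedy

print(f"greedy decoder: {greedy_decoder:.1f}x")                 # 12.1x
print(f"beam-4 decoder: {beam4_decoder:.1f}x")                  # 6.3x
print(f"greedy, encoder included: {greedy_with_encoder:.1f}x")  # 9.0x
```

Counting the encoder shrinks the greedy win to about 9x: once the decoder is fast, the encoder's 1.5 microseconds become a sizable share of what is left.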

The paper also runs the obvious cheap alternative as a control: four versions of the multi-head model with fewer or narrower heads ($h = 1, 2, 4, 8$, with $d_k = d_v$ rescaled so that the total key width $h\,d_k = 128$ matches multi-query's single shared key), then widens $d_{ff}$ to keep parameter count matched. Those alternatives shrink the cache too, but they pay a lot more for it. The best of them ($h = 2$, $d_k = 64$) lands at BLEU 26.2, 0.3 points worse than multi-query. The two extremes ($h = 1$, $d_k = 128$ and $h = 8$, $d_k = 16$) are at 25.8, 0.7 points worse. Plotted as speed against quality, multi-query is the only point that gets both at once:

Figure 4 · the speed-quality trade, on WMT14 en-de

multi-query · h=8 queries, 1 shared KV

Decoder microseconds per output token (log scale) against dev BLEU. Multi-query sits in the fast, high-BLEU corner: a 12x faster decoder for 0.2 BLEU. Multi-head is the baseline. The four shrink-the-heads variants cache the same 128-wide keys and values per layer as multi-query, yet land well below it on BLEU. Toggle to ln(PPL) and the ranking holds. None of the other ways of trimming the cache reaches the same Pareto front.

On Billion-Word LM the pattern is the same: multi-query at perplexity 30.2, the shrink-the-heads alternatives at 30.9 through 31.2. Those alternatives are not bad, but they sit on the wrong side of the trade, paying more in quality for the same cut to the cache. Multi-query is the one variant that removes almost all of the decode cost while costing almost no quality.
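Why the shrink-the-heads controls are a fair fight: counting cached numbers per token per layer for each WMT configuration named above shows they all land at the same cache size as multi-query. A sketch assuming $d_v = d_k$; the configuration labels are illustrative, and the $h = 4$, $d_k = 32$ row follows from the $h\,d_k = 128$ rule rather than being quoted directly.

```python
def cached_per_token_per_layer(n_kv_heads, d_k):
    """One K row and one V row of width d_k for each KV head (assumes d_v == d_k)."""
    return 2 * n_kv_heads * d_k

configs = {                          # label: (number of K/V heads, d_k)
    "multi-head h=8, d_k=128": (8, 128),
    "multi-query h=8, d_k=128": (1, 128),
    "multi-head h=1, d_k=128": (1, 128),
    "multi-head h=2, d_k=64": (2, 64),
    "multi-head h=4, d_k=32": (4, 32),
    "multi-head h=8, d_k=16": (8, 16),
}
sizes = {label: cached_per_token_per_layer(*cfg) for label, cfg in configs.items()}
for label, size in sizes.items():
    print(f"{label:26s} {size:5d}")
```

The baseline caches 2,048 numbers per token per layer; every other row caches 256, an 8x cut. With the cache held equal, the quality differences above come down to how those 256 numbers are organized.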

One last detail the paper lists in a footnote, because it matters when you reproduce the numbers: the measured decode time was on padded fixed-shape tensors, the kind of implementation TPUs run well, so every step takes the same time regardless of how many cached tokens it actually attends to. A smarter implementation that grew the cache incrementally would be faster early in the sequence; the 46-microsecond and 3.8-microsecond figures are the steady-state numbers at full length.
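The fixed-shape style described in that footnote can be sketched as a preallocated cache plus a mask. This is an illustration under stated assumptions (one head, hypothetical sizes `MAX_LEN` and `D_K`), not the paper's TPU code: every step does `MAX_LEN` dot products whatever the position, which is why the per-step time stays flat.

```python
import numpy as np

MAX_LEN, D_K = 128, 8   # hypothetical sizes

def fixed_shape_step(K_buf, V_buf, pos, q, k, v):
    """Write row `pos` of a preallocated cache, then attend over every row with unwritten rows masked."""
    K_buf[pos] = k
    V_buf[pos] = v
    logits = K_buf @ q                                               # always MAX_LEN dot products
    logits = np.where(np.arange(MAX_LEN) <= pos, logits, -np.inf)    # rows after pos are padding
    w = np.exp(logits - logits.max())                                # exp(-inf) == 0: padding gets no weight
    w /= w.sum()
    return w @ V_buf

rng = np.random.default_rng(0)
K_buf, V_buf = np.zeros((MAX_LEN, D_K)), np.zeros((MAX_LEN, D_K))
for pos in range(5):
    q, k, v = rng.standard_normal((3, D_K))
    out = fixed_shape_step(K_buf, V_buf, pos, q, k, v)

# The padded step matches plain attention over only the 5 rows actually written.
ref_logits = K_buf[:5] @ q
ref_w = np.exp(ref_logits - ref_logits.max())
ref_w /= ref_w.sum()
print(np.allclose(out, ref_w @ V_buf[:5]))   # True
```

The design trades wasted work early in the sequence for static shapes, which compilers for accelerators like TPUs handle well.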

Where MQA sits in the KV-cache lineage

The 0.2-BLEU drop is the part of the paper that later work softened. By 2023, Ainslie et al. reported that multi-query attention can cost quality (and training stability) relative to multi-head, and proposed a softer version of the same idea, grouped-query attention (GQA). Group the $h$ query heads into $G$ groups and let each group share one K/V head: multi-head is $G = h$, multi-query is $G = 1$. Practical models settle in the middle. Llama 2 70B uses $G = 8$ for $h = 64$; Mistral 7B uses $G = 8$ for $h = 32$. The cache shrinks by 8x or 4x rather than the 64x or 32x of full multi-query, and most of the quality loss disappears. GQA is the practical setting, but multi-query is its limit: knowing what is at the extreme is what makes the middle ground well-defined.
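Grouped-query attention needs only one more axis in the multi-query einsums. A NumPy sketch, assuming query heads are assigned to groups in contiguous blocks (heads 0 to h/G - 1 share group 0, and so on) and omitting masking; the function name is hypothetical. With $G = h$ it reproduces multi-head attention and with $G = 1$ multi-query.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(X, M, P_q, P_k, P_v, P_o):
    """P_q, P_o: (h, d, k), one per query head. P_k, P_v: (G, d, k), one per group; G must divide h."""
    h, G = P_q.shape[0], P_k.shape[0]
    assert h % G == 0, "the number of groups must divide the number of query heads"
    b, n = X.shape[0], X.shape[1]
    Q = np.einsum("bnd,hdk->bhnk", X, P_q).reshape(b, G, h // G, n, -1)  # (b, G, heads per group, n, k)
    K = np.einsum("bmd,gdk->bgmk", M, P_k)             # cached as (b, G, m, k): G K/V heads, not h
    V = np.einsum("bmd,gdv->bgmv", M, P_v)
    W = softmax(np.einsum("bgrnk,bgmk->bgrnm", Q, K))  # every head in a group scores the same keys
    O = np.einsum("bgrnm,bgmv->bgrnv", W, V).reshape(b, h, n, -1)
    return np.einsum("bhnv,hdv->bnd", O, P_o), K

rng = np.random.default_rng(1)
b, n, m, d, h, k = 2, 3, 5, 16, 8, 4
X, M = rng.standard_normal((b, n, d)), rng.standard_normal((b, m, d))
P_q, P_o = rng.standard_normal((h, d, k)), rng.standard_normal((h, d, k))
for G in (8, 2, 1):   # multi-head, grouped-query, multi-query
    P_k, P_v = rng.standard_normal((G, d, k)), rng.standard_normal((G, d, k))
    y, K = grouped_query_attention(X, M, P_q, P_k, P_v, P_o)
    print(G, y.shape, K.shape)   # the cache shape scales with G, the output shape does not
```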

MQA is also an early, explicit statement of the broader cost: the expensive part of attention in a deployed model is not the matrix multiplies, it is the bytes being moved. A small library of papers that followed picked up other parts of the same problem and sit alongside this one. FlashAttention (Dao et al., 2022) keeps the attention math exact but tiles it so the $N \times N$ score matrix never materializes in slow memory, sharply cutting the memory traffic of one attention call. PagedAttention (Kwon et al., 2023) attacks how the cache is physically laid out across many concurrent users in a serving system, where naive contiguous allocations fragment GPU memory and waste 60 to 80 percent of it; paging the cache like virtual memory recovers most of that. All three sit on different axes of the same cost and stack cleanly.

Read them together and each one attacks a different piece of the same decode cost. MQA shrinks the cache the model keeps. GQA gives back some of the quality MQA traded away by keeping a few KV heads. FlashAttention shrinks the memory traffic each attention call generates. PagedAttention shrinks the wasted space when many sequences share a GPU. Most large models deployed today, chat assistants and code models included, run some combination of those four, and multi-query attention is the one that came first.

Provenance · verified against primary literature
Vaswani et al. (2017): The multi-head attention layer that MQA modifies. Section 2 of Shazeer reviews it in einsum form.
Williams, Waterman, Patterson (2009): The roofline model. Arithmetic intensity (FLOPs per byte) determines whether a kernel is bound by compute or memory.
Liu, Saleh, et al. (2018), Povey et al. (2018): Local and time-restricted attention, the prior approaches to shrinking the K, V tensors that MQA is presented as orthogonal to.
GQA (Ainslie et al., 2023): The middle ground between MHA and MQA: a few shared KV heads instead of one. The sibling explainer at /mistral/ uses it.
PagedAttention (Kwon et al., 2023): The KV cache becomes the serving bottleneck; the sibling explainer at /paged-attention/ tackles its fragmentation. Same cache, a different cost.
Correction: Section 2.4.1 puts the multi-head incremental memory-to-FLOPs ratio at O(n/d + 1/b). Section 3.1 gives MQA O(1/d + n/(d·h) + 1/b). The factor that drops by h is the n/d term, not the entire ratio: the 1/b and 1/d pieces are unchanged. So the precise statement is that the dominant term shrinks by h, not that the cost shrinks by h.

Questions you might still have


How can sharing one set of keys and values across all the heads cost only 0.2 BLEU?
Every head keeps its own query projection; only the keys and values are shared. Eight different queries can still be evaluated against one shared set of keys and values and pull out eight different weighted averages. The only restriction is that all eight heads have to use the same K and V to do it.


Why does training time barely change while decoding gets 12x faster?
Training already runs n queries against the same K and V in parallel, so K and V are loaded once per sequence and the per-token cost is amortized. Decoding fires one query at a time and must reload the cache every step, which is exactly what the n/d term measures.


Is multi-query attention the same as grouped-query attention?
No. Multi-query is the extreme: a single shared K, V pair. Grouped-query is the middle ground (a few groups, each with its own K, V), introduced four years later by Ainslie et al. and adopted by Llama 2 and Mistral. The siblings sit on a line, MHA -> GQA -> MQA, with the cache shrinking by a factor of 1, h/G (for G groups), and h respectively.


How does MQA fit alongside FlashAttention and PagedAttention?
They attack three different costs in the same attention layer. MQA shrinks the cache itself by sharing K, V across heads. FlashAttention shrinks the cost of one attention call, keeping the math exact but tiling so the N x N scores never materialize in slow memory (/flashattention-fast-and-memory-efficient-exact/). PagedAttention shrinks the wasted space in how the cache is laid out in DRAM across many concurrent users (/paged-attention/). All three can stack.

Footnotes & further reading

  1. The paper: Noam Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (Google, 2019). Six pages, including the two tables of WMT results.
  2. The Transformer that MQA edits: Vaswani et al., Attention Is All You Need (2017), which Section 2.2 of Shazeer's paper reviews in einsum form.
  3. The middle ground that came later: Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (2023). Adopted by Llama 2 (Touvron et al.) and Mistral.
  4. The kernel-side complement, tiling so the score matrix stays in SRAM: Dao et al., FlashAttention (2022). Explainer at /flashattention-fast-and-memory-efficient-exact/.
  5. The serving-side complement, paging the cache for many concurrent users: Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention (2023). Sibling explainer at /paged-attention/.
  6. The local-attention work MQA is described as orthogonal to (it shrinks cache per token; locality shrinks the number of tokens attended to): Liu, Saleh, Pot, Goodrich, Sepassi, Kaiser, Shazeer, Generating Wikipedia by Summarizing Long Sequences (ICLR 2018); Povey et al., A time-restricted self-attention layer for ASR (ICASSP 2018); Zhang et al., Accelerating Neural Transformer via an Average Attention Network (2018).
  7. The compute-versus-bandwidth language MQA is using without naming: Williams, Waterman, Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures (Comms. of the ACM, 2009), the canonical statement of arithmetic intensity as the deciding ratio.