Verified · arXiv:2205.14135 · 22 min read
Architecture · Systems

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Attention was never short on compute. It was waiting on memory.

The slow part of attention is not the matrix multiplies. It is shuttling a big intermediate matrix to and from slow GPU memory. FlashAttention computes exactly the same attention, in one fused kernel that keeps the work on-chip, and the kernel runs several times faster while using far less memory.

Explaining the paper · FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness · Dao, Fu, Ermon, Rudra, Ré · Stanford · NeurIPS 2022 · arXiv:2205.14135

What if the slow part of attention was never the math?

Everyone learns the same headline about attention: its time and memory grow with the square of the sequence length, $N^2$. The usual reading is that the arithmetic is the problem, all those multiplications in the $N \times N$ score matrix. FlashAttention makes a sharper claim. On a modern GPU the arithmetic was almost never the bottleneck. The attention kernel spends most of its wall-clock time moving that $N \times N$ matrix between two kinds of memory, and the multiplications wait on the moving. Fix the data movement and you get a kernel that computes the identical answer, several times faster, on a fraction of the memory.

The fix rests on two old ideas put together carefully. The first is tiling: cut the inputs into blocks small enough to fit in the GPU's fast on-chip memory, and compute attention one block at a time so the big matrix is never assembled in slow memory at all. The second is recomputation: instead of saving the big matrix for the backward pass, throw it away and rebuild it from scratch when the gradients need it. Both look like they should cost something. Packed into a single hand-written GPU kernel, they cost almost nothing and win a lot, and the paper proves the result is close to the best any exact method could do.

To see why, we build up a short tower of facts: how GPU memory is actually laid out, why the standard attention kernel wastes it, the one piece of math (a streaming softmax) that makes tiling possible, the tiled algorithm itself, how the backward pass dodges the same trap, and finally the cost accounting that turns "more arithmetic" into "less time." Stacked up, they explain the whole paper, and a good deal of why most Transformers you use today run on some descendant of this kernel.

Attention was waiting on memory

Start with the operation itself. Given a sequence of $N$ tokens, attention forms three matrices, each of shape $N \times d$ where $d$ is the head dimension (small, usually 64 or 128): the queries $Q$, the keys $K$, and the values $V$. The computation is three steps.

$$S = QK^\top \in \mathbb{R}^{N\times N}, \qquad P = \mathrm{softmax}(S) \in \mathbb{R}^{N\times N}, \qquad O = PV \in \mathbb{R}^{N\times d} \tag{1}$$

Every query scores every key (that is $S$, the $N \times N$ matrix), a row-wise softmax turns each row of scores into weights that sum to one (that is $P$), and each output is the weighted average of the values (that is $O$). In practice $S$ is also scaled by $1/\sqrt{d}$ before the softmax, as in the code below; the scale changes nothing in what follows. The squareness is real: $S$ and $P$ each have $N^2$ entries, and for a long sequence that is enormous. With $N = 1024$ and $d = 64$, as in GPT-2, the score matrix holds over a million numbers ($1024^2 = 1{,}048{,}576$), sixteen times the size of any one of the $N \times d$ input matrices.

This is the part that gets misread. The expensive thing about that N×NN \times N matrix on a GPU is not computing it. It is storing it and reading it back. To see that, you have to know one fact about how the hardware is built, and it is a fact most people coding in PyTorch never have to think about.

A GPU has two relevant kinds of memory. There is a large pool, called HBM (high-bandwidth memory), where your tensors live: tens of gigabytes, fast by everyday standards. And there is a tiny pool of on-chip memory, called SRAM, sitting right next to the arithmetic units: about 20 MB in total on the GPU this paper uses, but more than ten times faster to read and write. Every GPU kernel follows the same pattern: load inputs from HBM into SRAM, compute, and write the result back to HBM.

The catch is that over the last decade arithmetic has gotten much faster while memory bandwidth has crept up only slowly, so for many operations the chip finishes the math and then sits idle, waiting on the next load from HBM. Such operations are called memory-bound: they do too little arithmetic per byte moved. A row-wise softmax over a giant matrix is the textbook case, with only a few operations per number it touches. A matrix multiply, by contrast, reuses every number it loads across a whole row or column of outputs. That is why attention's elementwise middle is the bottleneck and its two matmuls are not.

So the right way to think about an attention kernel is not "how many multiplications." It is "how many bytes cross the slow link to HBM." That single change of accounting is the whole paper. The authors call it being IO-aware: counting reads and writes between levels of memory, the way databases and numerical libraries have done for decades, instead of counting only floating-point operations.

The memory you forget you pay for

Picture a cook at a station. The countertop is small but everything on it is within arm's reach. The walk-in fridge is enormous but every trip to it costs you a walk. A fast kitchen minimizes trips to the fridge, not fridge size. SRAM is the countertop; HBM is the fridge.

The concrete numbers for the NVIDIA A100, the GPU the paper benchmarks on, make the gap vivid. HBM is 40 GB with a bandwidth of around 1.5 TB/s. The on-chip SRAM is 192 KB per streaming multiprocessor (SM, the GPU's basic compute unit), and there are 108 SMs, so about 20 MB in total, but with an estimated bandwidth of 19 TB/s. The fast memory is therefore roughly two thousand times smaller than the slow memory and more than ten times faster. In Figure 1, bar length is bandwidth, and capacity grows toward the bottom. The fastest memory is the smallest one, which is exactly the tension FlashAttention has to manage.

Figure 1 · the memory hierarchy
On-chip SRAM is tiny (~20 MB) but ~19 TB/s; HBM is large (40 GB) but ~1.5 TB/s; CPU DRAM is huge and far slower. Bandwidth bars shorten as capacity grows. The whole game is keeping the working set on the small fast tier so the slow link is touched as little as possible.

Two definitions are worth nailing down, because they get tangled. First, the SRAM here is the per-SM scratchpad (the combined L1 cache and shared memory on each SM), not the A100's 40 MB L2 cache, which is larger but shared across SMs and slower. Second, the way to tell memory-bound from compute-bound work is arithmetic intensity: the number of arithmetic operations done per byte of memory traffic. A big matrix multiply with a fat inner dimension has high intensity and is compute-bound. A softmax, an elementwise dropout, or a mask has low intensity and is memory-bound, and those low-intensity steps make up most of what attention does apart from its two matrix multiplies.
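A back-of-the-envelope roofline check puts numbers on that split. This is a minimal sketch under stated assumptions: 312 TFLOP/s is NVIDIA's dense FP16 tensor-core peak for the A100 (our figure, not the paper's), 1.5 TB/s is the HBM bandwidth quoted above, values are 2-byte fp16, and the softmax is charged about five operations per element. An operation whose intensity falls below the ridge point (peak FLOP/s divided by bandwidth) cannot keep the arithmetic units busy.

```python
# Rough roofline check on an A100. Assumptions (illustrative, not from the paper):
# 312 TFLOP/s dense FP16 tensor-core peak, 1.5 TB/s HBM, 2-byte fp16 values.
PEAK_FLOPS = 312e12
HBM_BYTES_PER_S = 1.5e12
BYTES_PER_VALUE = 2

# Ridge point: below this many FLOPs per byte, memory bandwidth is the limit.
RIDGE = PEAK_FLOPS / HBM_BYTES_PER_S  # about 208 FLOP/byte


def matmul_intensity(n, k, m):
    """(n x k) @ (k x m): 2*n*k*m FLOPs; read both inputs, write the output once."""
    flops = 2 * n * k * m
    traffic = BYTES_PER_VALUE * (n * k + k * m + n * m)
    return flops / traffic


def softmax_intensity(rows, cols, flops_per_elem=5):
    """Row softmax: ~5 FLOPs per element (assumed); read and write each element once."""
    elems = rows * cols
    return flops_per_elem * elems / (BYTES_PER_VALUE * 2 * elems)


for name, ai in [("4096^3 matmul", matmul_intensity(4096, 4096, 4096)),
                 ("1024 x 1024 softmax", softmax_intensity(1024, 1024))]:
    bound = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{name}: {ai:.1f} FLOP/byte -> {bound}")
```

The large matmul lands far above the ridge point and the softmax far below it; changing the assumed peak or bandwidth moves the ridge, but not by enough to flip either verdict.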

The standard trick for memory-bound work is kernel fusion: if several operations touch the same data, load it once, do all of them on-chip, and write once, instead of making a separate round trip per operation. Compilers fuse simple elementwise chains automatically. The trouble is that during training the intermediate results usually have to be written out to HBM anyway, to be kept for the backward pass, which quietly undoes most of the benefit. Attention has a second, deeper obstacle on top of that: the softmax couples a whole row of scores, so a stock compiler cannot fuse the score, softmax, and value steps tile by tile at all. Getting around both is what FlashAttention does.

Where the time actually goes

Now watch the standard attention kernel spend its time. It runs the three steps of Eq. (1) as three separate passes, and the big $N \times N$ matrix crosses the slow link on every one.

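# standard attention: three kernels, three trips through slow HBM
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
N, d = 1024, 64                                              # GPT-2 sizes from the text
Q, K, V = (torch.randn(N, d, device=device) for _ in range(3))
S = (Q @ K.T) / d**0.5        # kernel 1: write the N×N score matrix to HBM
P = torch.softmax(S, dim=-1)  # kernel 2: read N×N back, softmax, write N×N to HBM
O = P @ V                     # kernel 3: read N×N + V, write the N×d output O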

Read it as memory traffic, not arithmetic. The first kernel computes the scores and writes all $N^2$ of them to HBM. The second reads those $N^2$ numbers back, applies the softmax, and writes $N^2$ numbers out again. The third reads them back once more to multiply by $V$. The matrix that is sixteen times bigger than a single input is written twice and read twice, and the masking and dropout that a real Transformer applies to it add still more passes. The arithmetic in between is cheap; the kernel spends most of its time waiting on those round trips.
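Counting elements rather than seconds makes the imbalance concrete. A minimal sketch, assuming the three-kernel pipeline above, where each kernel reads its inputs from HBM and writes its output back, and ignoring masking, dropout, and caches; the function name is ours.

```python
# Element-level HBM traffic for the three-kernel standard attention.
# Assumption: every kernel reads its inputs from HBM and writes its result
# back; masking, dropout, and caching effects are ignored.
def standard_attention_traffic(N, d):
    """Return (elements moved for N×N intermediates, elements moved for N×d tensors)."""
    nxn = 0
    nxd = 0
    # Kernel 1: S = Q K^T    -> read Q and K, write S
    nxd += 2 * N * d
    nxn += N * N
    # Kernel 2: P = softmax(S) -> read S, write P
    nxn += 2 * N * N
    # Kernel 3: O = P V      -> read P and V, write O
    nxn += N * N
    nxd += 2 * N * d
    return nxn, nxd


N, d = 1024, 64  # GPT-2 sizes from the text
nxn, nxd = standard_attention_traffic(N, d)
print(f"N×N traffic: {nxn:,} elements; N×d traffic: {nxd:,} elements")
print(f"share from the N×N matrix: {nxn / (nxn + nxd):.1%}")
```

The $N \times N$ share works out to $N/(N+d)$, so it only grows as sequences get longer.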

FlashAttention's goal is to never write that matrix to HBM at all. Figure 2 contrasts the two. In the standard kernel the $N \times N$ matrix lives in HBM and makes several round trips up to the chip and back. In FlashAttention a single fused kernel builds the matrix one tile at a time inside SRAM and never spills it; only the $N \times d$ inputs, the output $O$, and a few per-row statistics ever cross the link.

Figure 2 · IO round trips
Standard attention writes the N×N matrix to HBM and reads it back across separate kernels, moving 40.3 GB through HBM. FlashAttention fuses everything into one kernel: the N×N matrix is tiled in SRAM and never written to HBM, so traffic drops to 4.4 GB. Same math, far fewer bytes moved.

Stated as goals, FlashAttention has to do two things that look impossible at first. It has to compute the softmax without ever having the whole row of scores in memory at once, and it has to make the backward pass work without the saved N×NN \times N matrix. The first takes a piece of math, the second a piece of accounting.

The obstacle: softmax wants the whole row

Tiling means processing the scores in blocks. But the softmax is what makes that hard, because a softmax couples a whole row. To turn a row of scores into weights you divide each exponential by the sum of all of them, and to do it without overflowing you first subtract the row's maximum. Both the max and the sum need the entire row of scores. The numerically safe softmax of a vector $x \in \mathbb{R}^B$ is

$$m(x) = \max_i x_i, \qquad f(x) = \big[\,e^{x_1 - m(x)},\ \dots,\ e^{x_B - m(x)}\,\big], \qquad \ell(x) = \sum_i f(x)_i, \qquad \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)} \tag{2}$$
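Equation (2) maps directly onto code. A minimal pure-Python sketch (function names are ours): the max shift changes nothing mathematically, since it cancels between numerator and denominator, but it keeps `math.exp` from overflowing on large scores.

```python
import math


def naive_softmax(x):
    """Softmax without the max shift: overflows for large scores."""
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def safe_softmax(x):
    """Eq. (2): m = max(x), f = exp(x - m), l = sum(f), softmax = f / l."""
    m = max(x)
    f = [math.exp(v - m) for v in x]
    l = sum(f)
    return [fi / l for fi in f]


small = [1.0, 2.0, 3.0]
print(safe_softmax(small))   # agrees with naive_softmax(small) up to rounding

big = [1000.0, 1001.0, 1002.0]
print(safe_softmax(big))     # same shifted inputs as `small`, so the same weights
try:
    naive_softmax(big)
except OverflowError:        # math.exp(1000.0) is out of float range
    print("naive softmax overflowed")
```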

If you only have a block of the row, you cannot finish, because a bigger value might be waiting in a block you have not seen, and that would change the maximum and rescale everything. The fragility comes from the exponentials being shift-sensitive: every weight formed so far was computed as $e^{x_i - m}$ relative to the running max, so a new, larger max means every earlier exponential was taken against the wrong reference and has to be rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$. So it looks like you are stuck holding the whole row, which is the thing you were trying to avoid.

The way out is a streaming trick that predates this paper by four years (Milakov and Gimelshein, 2018), in the same spirit as Welford's running-variance update. You do not need the whole row. You need to carry two running numbers per row as the blocks arrive: the running maximum $m$ and the running normalizer $\ell$. When a new block shows up with its own local max $\tilde m$ and local sum $\tilde\ell$, you merge them: whichever sum was computed against the smaller max is rescaled by an exponential factor so that both share the new reference, and then the two are added.

$$m^{\text{new}} = \max(m,\, \tilde m), \qquad \ell^{\text{new}} = e^{\,m - m^{\text{new}}}\,\ell \;+\; e^{\,\tilde m - m^{\text{new}}}\,\tilde\ell \tag{3}$$

That factor $e^{m - m^{\text{new}}}$ is the whole mechanism. It is one when the max did not change, and it shrinks the old contribution when a larger value arrives and resets the reference point. It works because the algebra closes exactly: $e^{x - m}\,e^{m - m^{\text{new}}} = e^{x - m^{\text{new}}}$. The factor undoes the old reference and applies the new one, so after every block, not just the last, the running pair $(m, \ell)$ is exactly what a full pass over the row so far would have produced. The same correction is applied to the running output, so the partial weighted average stays consistent as the reference shifts. It is like grading a stack of exams on a curve as they trickle in: keep the running top score and the running tally, and when a higher exam appears it moves the curve, so you rescale the ones you already counted. Nothing is approximated. After the last block the running $(m, \ell)$ equal the true row max and the true row sum exactly, and the output equals the full-softmax result.
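The merge rule of Eq. (3) fits in a few lines. A minimal pure-Python sketch with made-up scores (sixteen of them, streamed in four blocks as in the figure's setup; the values and function names are ours): the streamed pair $(m, \ell)$ matches the pair computed from the whole row.

```python
import math


def block_stats(block):
    """Local statistics of one block: its max and its max-shifted exp-sum."""
    m = max(block)
    return m, sum(math.exp(x - m) for x in block)


def merge(m, l, m_blk, l_blk):
    """Eq. (3): rescale both partial sums to the new max, then add them."""
    m_new = max(m, m_blk)
    return m_new, math.exp(m - m_new) * l + math.exp(m_blk - m_new) * l_blk


def online_stats(scores, block_size):
    """Stream a row block by block, carrying only the running (m, l)."""
    m, l = float("-inf"), 0.0  # empty-row start: exp(-inf - m_new) == 0.0
    for start in range(0, len(scores), block_size):
        m_blk, l_blk = block_stats(scores[start:start + block_size])
        if l > 0 and m_blk > m:
            print(f"block at {start}: max rises {m} -> {m_blk}, rescaling old sum")
        m, l = merge(m, l, m_blk, l_blk)
    return m, l


scores = [0.5, -1.0, 2.0, 0.3, 1.2, 3.1, -0.4, 0.0,
          2.5, 4.0, 1.1, -2.0, 0.7, 3.9, 0.2, 1.5]
m, l = online_stats(scores, block_size=4)
m_full = max(scores)
l_full = sum(math.exp(x - m_full) for x in scores)
print(m, m_full)   # identical maxima
print(l, l_full)   # identical normalizers up to rounding
```

Any block size gives the same pair up to rounding, including blocks that do not divide the row evenly.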

Step through it in Figure 3. Sixteen scores arrive in four blocks. The figure carries only $m$ and $\ell$, rescales the accumulator when a block brings a larger max (the amber note), and tracks the running output. By the last block the streamed output lands exactly on the answer you would get from the full softmax. This is what makes attention an exact computation you can do block by block.

Figure 3 · online softmax
Scores arrive in four blocks. Carry just the running max m and normalizer ℓ; when a block raises the max, rescale the accumulated stats by exp(m_old − m_new), then fold the block in. The running output converges to, and exactly equals, the full-softmax answer (the amber target).

This is also the difference between this work and the closest prior art. Rabe and Staats (2021) used the same streaming softmax to show attention needs only linear extra memory. But they were minimizing the peak memory, not the memory traffic, so their version ran at about the same speed as standard attention. FlashAttention takes the same math and asks the harder question: how do you order the reads and writes so the chip is never waiting?

Tiling: attention one block at a time

With the streaming softmax in hand, the algorithm falls out. Cut $Q$ into row-blocks and $K, V$ into column-blocks, sized so a block fits in SRAM. The paper sets the block sizes from the SRAM size $M$:

$$B_c = \left\lceil \frac{M}{4d} \right\rceil, \qquad B_r = \min\!\left(\left\lceil \frac{M}{4d} \right\rceil,\ d\right) \tag{4}$$

The $4d$ is there because four $d$-wide tiles ($Q_i$ and $O_i$ with $B_r$ rows, $K_j$ and $V_j$ with $B_c$ rows) have to share the on-chip budget at once; capping $B_r$ at $d$ keeps the $B_r \times B_c$ score tile no larger than one of them. Then you loop. The outer loop walks the column-blocks of $K, V$; the inner loop walks the row-blocks of $Q$. For each tile you load the pieces into SRAM, compute the small $B_r \times B_c$ score tile, do the local softmax, and fold it into a running output and running statistics for those query rows. The big matrix is built a tile at a time and never assembled in HBM.

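# FlashAttention forward (v1 loop order), written in PyTorch for readability.
# The real kernel keeps each tile in SRAM inside one fused CUDA kernel; here
# each tile is just a tensor slice, so this shows the math, not the speed.
import math
import torch

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    N, d = Q.shape
    O = torch.zeros_like(Q)                      # running output (in HBM)
    l = Q.new_zeros(N, 1)                        # running normalizer per query row
    m = Q.new_full((N, 1), float("-inf"))        # running max per query row
    for j in range(0, N, B_c):                   # OUTER: load one K_j, V_j to SRAM
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):               # INNER: stream Q_i, O_i, l_i, m_i
            Q_i, O_i = Q[i:i + B_r], O[i:i + B_r]
            l_i, m_i = l[i:i + B_r], m[i:i + B_r]
            S_ij = (Q_i @ K_j.T) / math.sqrt(d)             # B_r×B_c tile, on chip
            m_ij = S_ij.max(dim=1, keepdim=True).values     # local row max
            P_ij = torch.exp(S_ij - m_ij)                   # local softmax, max-shifted
            l_ij = P_ij.sum(dim=1, keepdim=True)            # local row sum
            m_new = torch.maximum(m_i, m_ij)                # merge the running max
            l_new = torch.exp(m_i - m_new) * l_i + torch.exp(m_ij - m_new) * l_ij
            O[i:i + B_r] = (l_i * torch.exp(m_i - m_new) * O_i
                            + torch.exp(m_ij - m_new) * (P_ij @ V_j)) / l_new  # rescale + add
            l[i:i + B_r], m[i:i + B_r] = l_new, m_new       # write l_i, m_i back
    return O, l, m                               # also save l, m for the backward pass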

Figure 4 animates the sweep. One tile is live in SRAM at any moment (bright). The rest of the $N \times N$ grid is ghosted, because it is never materialized. The output column on the right fills in as more $K, V$ columns are folded into it. That accumulating column is the running output $O_i$ from the streaming softmax, one block of query rows at a time.

Figure 4 · tiling and the two loops
The N×N score grid, tiled. FlashAttention v1 loops outer over K, V column blocks and inner over Q row blocks: one tile is live in SRAM (bright), the rest of the grid is ghosted and never written to HBM, and the output column on the right accumulates as columns are folded in.

Why that particular block size, and not bigger? The block size is a real knob, and it pulls in two directions. A wider block does more work per load and makes fewer passes over the sequence, so the total number of HBM round trips falls as the block grows. But a block's working set, the $Q_i, O_i, K_j, V_j$ tiles plus the $B_r \times B_c$ score tile, has to sit in SRAM all at once, and that footprint grows with the block. SRAM is tiny, so the footprint soon hits the ceiling. The largest block whose working set still fits is the best you can do, which is exactly what $B_c = \lceil M / 4d \rceil$ picks: take the SRAM budget $M$, divide by $4d$ (four $d$-wide tiles), and that is your block width.
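Equation (4) and the trade-off it resolves can be tabulated directly. A minimal sketch, assuming $M$ counts elements and an illustrative budget of 100 KB holding 4-byte values ($M = 25{,}600$), with $N = 4096$; these numbers are ours, not the paper's kernel configuration. Each outer pass is one more time the inner loop re-streams $Q$ and $O$ through HBM.

```python
import math


def block_sizes(M, d):
    """Eq. (4): B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d). M is in elements."""
    b = math.ceil(M / (4 * d))
    return b, min(b, d)


# Illustrative assumption: ~100 KB of SRAM holding 4-byte values -> 25,600 elements.
M, d, N = 25_600, 64, 4096
B_c, B_r = block_sizes(M, d)
print(f"B_c = {B_c}, B_r = {B_r}")  # B_c = 100, B_r = 64

# Wider blocks mean fewer outer passes, but the working set ~4*B*d must fit in M.
for B in (16, 32, 64, 100, 128, 256):
    passes = math.ceil(N / B)
    working_set = 4 * B * d
    print(f"B={B:4d}  outer passes={passes:4d}  "
          f"working set={working_set:6d}  fits={working_set <= M}")
```

The last row that still fits is $B = 100$, the value Eq. (4) returns; doubling $d$ halves it.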

Figure 5 · choosing the block size
The block width B_c is a trade-off. The round-trip count falls like 1/B as the block grows; the on-chip working set rises like 4·B·d and must stay under the SRAM ceiling (~100 KB, the per-SM scratchpad, not L2). The largest block that still fits, B = M/4d, is the sweet spot. The slider changes only the constant in front of the cost: traffic stays quadratic in N, which is held fixed. The relationship is the paper's (Eq. 4); the numbers are illustrative for a typical SRAM.

The exact update inside the loop is the one line that carries the streaming softmax into the output. The running output for a block of rows is not a sum of exponentials but a sum of value vectors weighted by them, so when the max shifts, every contribution already folded in is rescaled by the same factor before the new tile is added:

$$O_i \leftarrow \mathrm{diag}(\ell_i^{\text{new}})^{-1}\Big(\mathrm{diag}(\ell_i)\,e^{\,m_i - m_i^{\text{new}}}\,O_i \;+\; e^{\,\tilde m_{ij} - m_i^{\text{new}}}\,\tilde P_{ij} V_j\Big) \tag{5}$$
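Equation (5) can be checked numerically for a single query row. A minimal pure-Python sketch with made-up scores and 2-dimensional value vectors (our illustrative data and names): after each block the running output is already the correctly normalized average over the keys seen so far, and after the last block it equals $\mathrm{softmax}(s)\,V$ computed in one shot.

```python
import math


def eq5_stream(scores, values, block_size):
    """Fold (score, value-vector) blocks into a running output using Eq. (5)."""
    d = len(values[0])
    m, l, o = float("-inf"), 0.0, [0.0] * d
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_blk = max(s_blk)
        p_blk = [math.exp(s - m_blk) for s in s_blk]  # this row of P~_ij
        l_blk = sum(p_blk)
        pv = [sum(p * v[k] for p, v in zip(p_blk, v_blk)) for k in range(d)]  # P~_ij V_j
        m_new = max(m, m_blk)
        l_new = math.exp(m - m_new) * l + math.exp(m_blk - m_new) * l_blk
        o = [(l * math.exp(m - m_new) * o_k + math.exp(m_blk - m_new) * pv_k) / l_new
             for o_k, pv_k in zip(o, pv)]
        m, l = m_new, l_new
        # o is now the exact softmax-weighted average of the values seen so far.
    return o


def direct(scores, values):
    """Reference: full softmax over the row, then the weighted average of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[k] for wi, v in zip(w, values)) / total
            for k in range(len(values[0]))]


scores = [0.1, 2.0, -1.0, 3.5, 0.7, 1.9]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5], [4.0, -2.0]]
print(eq5_stream(scores, values, block_size=2))
print(direct(scores, values))  # agrees with the line above up to rounding
```

Keeping $O_i$ normalized after every step, as Eq. (5) does, is one choice; dividing by $\ell$ only once at the end gives the same result and is a common variant in later kernels.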

Because tiling lets everything happen on-chip, the whole attention (the two matrix multiplies, the softmax, and in a real Transformer the masking and dropout too) fuses into a single CUDA kernel. Load the blocks once, compute, write the output once. That is the kernel fusion that ordinary frameworks could not pull off, because they would have spilled the intermediate matrix to HBM in between.

One detail is worth flagging for anyone who looks at the modern code: this is the version-one loop order, outer over $K, V$. Because $Q$ is on the inner loop, each output block and its running statistics are read and rewritten in HBM once per outer pass. FlashAttention-2, the 2023 follow-up, swapped the loops to put $Q$ on the outside, which removes most of that extra traffic, parallelizes better across the sequence, and roughly doubled the speed again. The idea is the same; the diagram is mirrored.

The backward pass: recompute, don't remember

Training needs gradients, and the gradient of attention normally wants the matrices $S$ and $P$ that the forward pass produced. Saving them is the $O(N^2)$ memory cost we just worked so hard to avoid. So FlashAttention does not save them. It saves only the output $O$ and the two tiny per-row statistics $(m, \ell)$, plus the random-number seed used for dropout, and then recomputes the score tiles from $Q, K, V$ in SRAM during the backward pass.
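A quick count shows what skipping the saved matrices buys. A minimal sketch, assuming (as the text frames it) that standard attention keeps both $S$ and $P$ for the backward pass, while FlashAttention keeps $O$ plus $m$ and $\ell$; real implementations differ in exactly what they save, and the function names are ours.

```python
# Elements kept for the backward pass, per head and per sequence.
# Assumptions: standard attention saves S and P (N x N each); FlashAttention saves
# O (N x d) plus the per-row statistics m and l (N each). Q, K, V are kept in both
# cases and left out; the dropout RNG state is negligible.
def saved_standard(N, d):
    return 2 * N * N


def saved_flash(N, d):
    return N * d + 2 * N


d = 64
for N in (1_024, 4_096, 16_384, 65_536):
    std, fl = saved_standard(N, d), saved_flash(N, d)
    print(f"N={N:6d}: standard {std:>14,}  flash {fl:>10,}  ratio {std / fl:8.1f}x")
```

The ratio itself grows with $N$, which is the quadratic-versus-linear gap in miniature.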

This sounds like the classic technique called gradient checkpointing, where you discard activations on the forward pass and recompute them on the backward pass to save memory. But classic checkpointing is a trade: you pay more compute to use less memory, and it runs slower. FlashAttention's recomputation is not a trade. It runs faster, because the backward pass is memory-bound too. Recomputing the score tile on-chip is cheap arithmetic; the thing it avoids is reading the giant matrix back from HBM, which was the slow part. Doing more math to move fewer bytes is a win whenever you are memory-bound, as you are here.

The backward pass also reuses a small algebraic shortcut. The softmax Jacobian introduces a term $D_i = P_{i:}^\top dP_{i:}$ that, computed naively, sums over a whole row of length $N$. Because $O = PV$ and $dP = dO\,V^\top$, that row reduction equals the much cheaper dot product $do_i^\top o_i$ of two vectors of length $d$, which fits on-chip:

$$D_i = P_{i:}^\top \, dP_{i:} = do_i^\top o_i, \qquad dS_{ij} = P_{ij}\,(dP_{ij} - D_i) \tag{6}$$

That $d$-length reduction in place of an $N$-length one is the small thing that keeps the backward pass entirely in fast memory, tile by tile, the same way the forward pass does.
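The identity behind Eq. (6) is one line of algebra: $\sum_j P_{ij}\,dP_{ij} = \sum_j P_{ij}\,(do_i \cdot v_j) = do_i \cdot \sum_j P_{ij} v_j = do_i \cdot o_i$. A minimal pure-Python check for one query row, with random illustrative data and names of our choosing; the finite-difference comparison confirms that $dS$ from Eq. (6) is the gradient of the scalar $do_i \cdot o_i$ with respect to the scores.

```python
import math
import random

random.seed(0)
N, d = 8, 4  # tiny illustrative sizes
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
do_i = [random.gauss(0, 1) for _ in range(d)]  # upstream gradient dO for query row i


def row_forward(s):
    """Softmax of one score row, then o_i = P_{i:} V."""
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    P = [x / total for x in e]
    o = [sum(P[j] * V[j][k] for j in range(N)) for k in range(d)]
    return P, o


def loss(s):
    """Scalar whose gradient with respect to o_i is do_i."""
    return sum(g * ok for g, ok in zip(do_i, row_forward(s)[1]))


s = [random.gauss(0, 1) for _ in range(N)]
P, o = row_forward(s)
dP = [sum(do_i[k] * V[j][k] for k in range(d)) for j in range(N)]  # dP_{i:} = do_i V^T

D_long = sum(P[j] * dP[j] for j in range(N))     # length-N reduction
D_short = sum(do_i[k] * o[k] for k in range(d))  # length-d dot product
dS = [P[j] * (dP[j] - D_short) for j in range(N)]  # Eq. (6)

# Central finite differences of the loss, as an independent check on dS.
eps = 1e-6
fd = []
for j in range(N):
    up, dn = list(s), list(s)
    up[j] += eps
    dn[j] -= eps
    fd.append((loss(up) - loss(dn)) / (2 * eps))

print(D_long, D_short)
print(max(abs(a - b) for a, b in zip(dS, fd)))  # tiny: finite-difference error only
```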

More FLOPs, less time

FlashAttention does more arithmetic than standard attention, because it recomputes the score matrix in the backward pass, and yet it runs several times faster. The paper measures it directly on GPT-2 medium: 75.2 GFLOP against the standard 66.6, about 13% more compute, but 4.4 GB of HBM traffic against 40.3, about 9× fewer bytes, and so 7.3 ms against 41.7, about 5.7× faster overall.

Figure 6 · more FLOPs, less time
GPT-2 medium, forward + backward, one A100. FlashAttention does ~13% more arithmetic than standard attention (recomputation), moves ~9× fewer bytes to HBM, and finishes ~5.7× sooner. When a kernel is memory-bound, the byte count sets the clock, not the FLOP count.

That is the byte accounting doing its work: trade a little extra arithmetic for a lot less memory traffic and you come out ahead. The attention kernel alone speeds up by as much as 7.6×, while end-to-end training speedups are more modest, because attention is only one part of the model.
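The measured ratios, and why the end-to-end gain is smaller, take a few lines to check. A minimal sketch: the first half recomputes the ratios from the GPT-2 medium numbers quoted above; the second half applies Amdahl's law (a standard speedup bound, not something the paper states) with a hypothetical attention share of half the step time.

```python
# Ratios from the GPT-2 medium measurement quoted above (forward + backward, one A100).
standard = {"gflop": 66.6, "hbm_gb": 40.3, "ms": 41.7}
flash = {"gflop": 75.2, "hbm_gb": 4.4, "ms": 7.3}

extra_flops = flash["gflop"] / standard["gflop"] - 1  # ~13% more arithmetic
fewer_bytes = standard["hbm_gb"] / flash["hbm_gb"]    # ~9x less HBM traffic
speedup = standard["ms"] / flash["ms"]                # ~5.7x less time
print(f"{extra_flops:.0%} more FLOPs, {fewer_bytes:.1f}x fewer bytes, {speedup:.1f}x faster")


def end_to_end_speedup(attention_fraction, attention_speedup):
    """Amdahl's law: only the attention share of the step time gets faster."""
    return 1 / ((1 - attention_fraction) + attention_fraction / attention_speedup)


# Hypothetical: attention takes half of the step time and speeds up 7.6x.
print(f"{end_to_end_speedup(0.5, 7.6):.2f}x end-to-end")
```

Even a very large kernel speedup is capped by the share of time the rest of the model takes.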

You can also count the bytes asymptotically. Standard attention reads and writes HBM a number of times proportional to $Nd + N^2$, dominated by the $N^2$ matrix. FlashAttention reduces this to a count that looks larger but, for real head dimensions, is far smaller:

$$\Theta\!\left(\frac{N^2 d^2}{M}\right)\ \text{HBM accesses}, \qquad \text{versus} \qquad \Theta\!\left(Nd + N^2\right)\ \text{for standard attention} \tag{7}$$

For real values, $d$ is 64 to 128 and $M$ is around 100 KB. Squaring the head dimension gives $d^2$ in the low thousands to low tens of thousands ($64^2 = 4096$, $128^2 = 16384$), small against $M$, so $d^2/M$ sits well below one and FlashAttention moves many times fewer bytes (up to 9× in the paper's measurements). That fraction, applied to every byte of score traffic, is the entire win, and it is a large constant factor, not a change in the growth rate. The $N^2$ is still there. FlashAttention does not make attention linear: its memory traffic and its wall-clock time stay quadratic in the sequence length, and the only thing that becomes linear is the extra memory the kernel needs, $O(N)$ instead of $O(N^2)$. Linear memory, quadratic time: that distinction is the single most misquoted fact about this paper.
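Plugging numbers into Eq. (7) shows both halves of the claim. A minimal sketch with the $\Theta$ constants dropped and an illustrative $M = 100{,}000$ (our assumptions), so the absolute ratios are not predictions of measured speedups: the saving stays close to $M/d^2$ whatever $N$ is, while both counts quadruple each time $N$ doubles.

```python
# Eq. (7) with the Theta constants dropped (an illustrative simplification).
def standard_accesses(N, d):
    return N * d + N * N


def flash_accesses(N, d, M):
    return N * N * d * d / M


M = 100_000  # assumption: ~100 KB of SRAM, treated as 100,000 elements
for d in (64, 128):
    for N in (1_024, 2_048, 4_096):
        std, fl = standard_accesses(N, d), flash_accesses(N, d, M)
        print(f"d={d:3d} N={N:5d}: standard {std:>12,.0f}  "
              f"flash {fl:>12,.0f}  ratio {std / fl:5.1f}")
```

For $d = 64$ the ratio hovers near $M/d^2 \approx 24$; for $d = 128$ it drops to about a quarter of that, which is why larger head dimensions gain less.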

The authors go one step further and prove that this count cannot be beaten in general: no exact attention algorithm can asymptotically improve on $\Theta(N^2 d^2 / M)$ HBM accesses across all SRAM sizes. For exact attention on a single GPU, the cook is already making essentially as few trips to the walk-in fridge as the problem allows.

Exact, not approximate

The word exact is what separates this paper from the years of efficient-attention research that came before it. Faced with the $N^2$ cost, most prior work changed the math: sparse attention only looks at some pairs, low-rank methods like Linformer project the keys and values down to a shorter length, and kernel methods like Performer replace the softmax with a cheaper approximation. These are approximate attention methods. They cut the operation count, but they change the answer, and so they trade away some model quality.

The frustrating thing, and the observation that motivates the whole paper, is that many of those approximate methods do not actually run faster in wall-clock time. They reduce FLOPs, but they ignore memory access, and so the chip is still waiting on HBM. Reducing a count that was never the bottleneck does not help. That goes a long way toward explaining why so many clever sub-quadratic attention variants never got adopted.

FlashAttention does not touch the math. It computes $\mathrm{softmax}(QK^\top)V$, the same function. In the strict sense, "exact" means mathematically identical, which is why a model trained with FlashAttention reaches the same perplexity as the same model trained with standard attention, with the same training curves. In finite-precision floating point the output can differ in the last few bits, because the streaming softmax accumulates the sum in a different order and recomputation reorders it again; that is rounding, not approximation, and it does not cost quality. Approximate attention reads a summary of the document; FlashAttention reads every word, only faster.

What linear memory buys

The first win is straightforward speed, with no change to the model. On BERT-large, FlashAttention beats the MLPerf 1.1 training-speed record by 15% (17.4 minutes against 20.0). On GPT-2 it trains up to 3× faster than the standard HuggingFace implementation and about 1.7× faster than Megatron-LM, reaching the identical perplexity. On the Long-Range Arena benchmark it is 2.4× faster than a standard Transformer at matching accuracy. These are drop-in wins: same weights, same results, less wall-clock.

The second win comes from the memory side rather than the speed side. Because the kernel's memory footprint grows linearly with sequence length instead of quadratically, you can fit far longer sequences on the same GPU. Standard attention's footprint grows like $N^2$ and runs out of memory on a 40 GB A100 before reaching 64K tokens; FlashAttention's grows like $N$ and keeps going. The paper measures up to 20× less memory than exact attention baselines, and 2× less than even Linformer, an approximate method. Figure 7 shows the gap opening up as the sequence grows.

Figure 7 · memory and context length
Standard attention stores the N×N matrix, so its memory grows like N² and it runs out on a 40GB A100 before 64K. FlashAttention never stores it, so memory grows like N and reaches 64K. Linear memory is what unlocks long context: 16K for Path-X, 64K for Path-256.

Longer context is not just convenient; it changes what the model can learn. Training GPT-2 with 4× the context length, while still training faster than Megatron-LM at the shorter length, lowers perplexity by 0.7. On two long-document classification tasks, stretching the input to longer sequences (8K to 16K) lifts accuracy by 6.4 points on average (4.3 on clinical notes at 16K, 8.5 on legal cases at 8K). The 0.7 is not exact FlashAttention beating exact standard attention on the same model; on the same model the two are identical. The improvement comes entirely from the longer context that the memory savings unlock.

The sharpest demonstration is a benchmark no Transformer had ever passed. Path-X asks a model to decide whether two dots in a 128×128 image are joined by a path, fed one pixel at a time, which means a sequence of 16K tokens. Before this, every Transformer either ran out of memory or scored at chance. FlashAttention is the first to do better than chance, at 61.4% accuracy, purely by being able to hold the longer sequence. With the block-sparse extension, the model reaches 64K tokens and clears Path-256 at 63.1%.

That block-sparse extension is the one place FlashAttention itself becomes approximate, and it is worth keeping straight. By giving the kernel a block mask and simply skipping the zero blocks, it computes only the tiles that matter, with IO cost $\Theta(Nd + N^2 d^2 M^{-1} s)$, where $s$ is the fraction of nonzero blocks. That makes it 2 to 4× faster than dense FlashAttention and faster than any approximate method the authors tested, while staying within reach of full accuracy. So "FlashAttention" really names two things: an exact algorithm, which is the heart of the paper, and an optional approximate kernel built on the same machinery.
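The block-sparse cost can be estimated the same way. A minimal sketch with a hypothetical banded block mask (keep tile $(i, j)$ only when $|i - j| \le 1$) and the $\Theta$ constants dropped; $N$, $d$, and $M$ are illustrative assumptions, and the helper names are ours. The kernel itself would simply skip every tile whose mask entry is false.

```python
def banded_block_mask(num_blocks, bandwidth):
    """Hypothetical block mask: keep tile (i, j) only if |i - j| <= bandwidth."""
    return [[abs(i - j) <= bandwidth for j in range(num_blocks)]
            for i in range(num_blocks)]


def nonzero_fraction(mask):
    """s in the text: the fraction of tiles the kernel actually computes."""
    kept = sum(sum(row) for row in mask)
    return kept / sum(len(row) for row in mask)


def io_estimate(N, d, M, s=1.0):
    """Theta(N d + N^2 d^2 s / M) with constants dropped; s = 1 is dense FlashAttention."""
    return N * d + N * N * d * d * s / M


mask = banded_block_mask(num_blocks=16, bandwidth=1)
s = nonzero_fraction(mask)
N, d, M = 16_384, 64, 100_000
ratio = io_estimate(N, d, M) / io_estimate(N, d, M, s)
print(f"s = {s:.3f}; estimated IO, dense / block-sparse = {ratio:.1f}x")

# The tiled loop would visit only these (i, j) pairs:
visited = [(i, j) for i, row in enumerate(mask) for j, keep in enumerate(row) if keep]
print(len(visited), "of", 16 * 16, "tiles computed")
```

The $Nd$ term is not scaled by $s$, so the saving approaches $1/s$ only when the quadratic term dominates.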

The limitations are honest and worth stating. Every variant needs a new CUDA kernel, hand-written in a much lower-level language than PyTorch, and not always portable across GPU generations. The authors point to the need for a compiler that takes a high-level attention description and emits the IO-aware kernel, the way Halide did for image processing. And the IO analysis is for a single GPU; spreading attention across many GPUs adds a whole new layer of data-movement accounting.

Step back and the argument is four facts long. On a GPU, attention is bottlenecked by memory traffic, not arithmetic. A streaming softmax lets you compute attention exactly, one block at a time, without ever assembling the big matrix. Tiling and recomputation keep every block on-chip, so the big matrix never touches slow memory in either the forward or the backward pass. And once you count bytes instead of operations, doing a little more arithmetic to move far fewer bytes is the obvious move, provably close to optimal. The squared cost in the sequence length is still there. It just stopped being the thing you wait on.

Provenance · Verified against primary literature
Online softmax (2018) · Milakov & Gimelshein: running max + normalizer with rescaling. Exact, and four years older than this paper.
Self-attention memory (2021) · Rabe & Staats: attention without quadratic extra memory, via the same streaming softmax.
Aggarwal & Vitter (1988) · The external-memory (IO-complexity) model the proofs are stated in.
Gradient checkpointing (2016) · Chen et al.: recompute activations to save memory, normally at a speed cost.
A100 specs (NVIDIA) · 192 KB on-chip SRAM per SM × 108 SMs; HBM at 1.5–2.0 TB/s. Verified against the Ampere whitepaper.
Correction · FlashAttention is often described as making attention linear. It does not. HBM accesses and wall-clock time stay quadratic in N (Θ(N²d²/M), a large constant-factor win); only the extra memory is linear, O(N). We also teach the v1 loop order (outer loop over K, V); FlashAttention-2 later swapped the loops.

Questions you might still have


Does FlashAttention make attention linear-time?
No. Time and HBM traffic both stay quadratic in the sequence length N. The win is a large constant factor (the N² traffic is scaled down by roughly d²/M), not a change in the asymptotics. Only the extra memory it needs is linear in N. The paper proves the count is Θ(N²d²/M), which is still N² up to that factor.


If it does more arithmetic, why is it faster?
Because attention on a GPU is bottlenecked by memory access, not arithmetic. Runtime tracks the bytes moved to and from HBM, and FlashAttention moves about 9× fewer of them. The extra FLOPs from recomputation run on-chip for nearly free, so it still finishes about 5.7× sooner.


Is the output identical to standard attention?
Mathematically yes: it computes softmax(QKᵀ)V, the same function, which is what "exact" means. In floating point the last few bits can differ, because the sum is accumulated in a different order. It is not a quality-trading approximation, and it reports the same perplexity as the baselines.


Then where does the 0.7 better perplexity come from?
Not from the same model. On the same model the perplexity is identical. The memory savings let you train a model with 4× the context length while still training faster than Megatron-LM at the original length, and the longer context is what lowers perplexity. The 6.4-point lift on long documents is the same story.


Why does it need a hand-written CUDA kernel?
The whole method is fusing scores, softmax, masking, dropout, and the value product into one kernel that explicitly stages blocks through SRAM. PyTorch and TensorFlow run those as separate ops that each spill to HBM, so the fusion has to be written by hand. That is also the main limitation: a new kernel per variant and per GPU.

Footnotes & further reading

  1. The paper: Dao, Fu, Ermon, Rudra, Ré, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Stanford, NeurIPS 2022). Code.
  2. The streaming softmax: Milakov & Gimelshein, Online normalizer calculation for softmax (NVIDIA, 2018). It is the same flavor of trick as Welford's running-variance update: carry a running summary, correct it as each new chunk arrives.
  3. The closest prior art on memory-efficient attention: Rabe & Staats, Self-attention Does Not Need O(n²) Memory (2021), which reduces peak memory but not memory traffic.
  4. The IO-complexity model: Aggarwal & Vitter, The input/output complexity of sorting and related problems (CACM, 1988). And gradient checkpointing: Chen et al., Training Deep Nets with Sublinear Memory Cost (2016).
  5. The follow-up that swapped the loops for another ~2×: Tri Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023). The A100 hardware figures are from NVIDIA's Ampere architecture whitepaper.