Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads/writes between GPU high-bandwidth memory (HBM) and GPU on-chip SRAM.
It was introduced in the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and further optimized in FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The official CUDA implementation can be found at Dao-AILab/flash-attention.
Our implementation is based on Triton's example implementation.
You can run test.py to check the correctness and measure the performance of this implementation.
Here's the attention forward pass. The formulas are for a single attention head. $q_i$ is the query vector (row vector) at position $i$, and $k_j$ and $v_j$ are the key and value row vectors at position $j$. $o_i$ is the output vector at position $i$.

$$S_{ij} = \sigma \, q_i k_j^\top \qquad L_i = \sum_j e^{S_{ij}} \qquad P_{ij} = \frac{e^{S_{ij}}}{L_i} \qquad o_i = \sum_j P_{ij} v_j$$

$S$ is the attention score matrix before softmax, $L$ is the softmax denominator, and $P$ is the attention matrix after softmax. $\sigma$ is the softmax scale (sm_scale).

You can compute $o_i$, instead of doing the full softmax, by computing the sum of exponentials $l_i$ and the unnormalized output $\tilde{o}_i$ while iterating over keys:

$$l_i \leftarrow l_i + e^{S_{ij}} \qquad \tilde{o}_i \leftarrow \tilde{o}_i + e^{S_{ij}} v_j$$

Finally you can compute,

$$o_i = \frac{\tilde{o}_i}{l_i}$$
To make it numerically stable, flash attention subtracts the current maximum $m_i$ of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys: the running maximum $m_i$, the sum of exponentials $l_i = \sum_j e^{S_{ij} - m_i}$, and the unnormalized output $\tilde{o}_i$.

For each block of keys $j_1 \dots j_2$ it updates them:

$$m_i^{\text{new}} = \max\big(m_i, \max_{j_1 \le j < j_2} S_{ij}\big)$$
$$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}}$$
$$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j_1 \le j < j_2} \tilde{P}_{ij}$$
$$\tilde{o}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{o}_i + \sum_{j_1 \le j < j_2} \tilde{P}_{ij} v_j$$
$$m_i \leftarrow m_i^{\text{new}}$$

Then finally,

$$o_i = \frac{\tilde{o}_i}{l_i}$$
This reduces memory usage since we don't have to compute the full $S$ or $P$ matrices. It also speeds things up since we don't have to load these large matrices; instead only blocks of $K$ and $V$ are loaded as we iterate over them.
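To make the block-wise update concrete, here's a minimal sketch of the online softmax for a single query row in plain PyTorch. This only illustrates the math above and is separate from the Triton kernels; names like block_size are made up.

```python
import torch

def online_attention_row(q, K, V, sigma, block_size=4):
    """Compute one output row o_i with the running (m_i, l_i, o~_i) update."""
    m = torch.tensor(float("-inf"))  # running maximum of the scores
    l = torch.tensor(0.0)            # running sum of exp(S - m)
    o = torch.zeros(V.shape[1])      # unnormalized output o~_i
    for j in range(0, K.shape[0], block_size):
        s = sigma * (q @ K[j:j + block_size].T)  # block of scores S_ij
        m_new = torch.maximum(m, s.max())
        p = torch.exp(s - m_new)                 # P~_ij
        scale = torch.exp(m - m_new)             # rescales the old l and o~
        l = l * scale + p.sum()
        o = o * scale + p @ V[j:j + block_size]
        m = m_new
    return o / l                                 # final normalization

# Matches the full softmax:
d = 8
q, K, V = torch.randn(d), torch.randn(16, d), torch.randn(16, d)
ref = torch.softmax((q @ K.T) / d ** 0.5, dim=-1) @ V
assert torch.allclose(online_attention_row(q, K, V, 1 / d ** 0.5), ref, atol=1e-5)
```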
Here's the standard backward pass. $do_i$ is the gradient vector on the output $o_i$:

$$dv_j = \sum_i P_{ij} \, do_i \qquad dP_{ij} = do_i \cdot v_j$$

$$dS_{ij} = \sum_k dP_{ik} \frac{\partial P_{ik}}{\partial S_{ij}} \qquad \frac{\partial P_{ik}}{\partial S_{ij}} = P_{ik} \big(\delta_{kj} - P_{ij}\big)$$

where $\delta_{kj}$ is $1$ when $k = j$ and $0$ otherwise.

The flash attention paper introduces

$$D_i = \sum_k P_{ik} \, dP_{ik} = do_i \cdot o_i$$

to simplify the computation. Then,

$$dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big) \qquad dq_i = \sigma \sum_j dS_{ij} \, k_j \qquad dk_j = \sigma \sum_i dS_{ij} \, q_i$$
Flash attention saves the log-sum-exp $\mathrm{lse}_i = m_i + \log l_i$ from the forward pass since it doesn't take much memory. So during the backward pass it doesn't have to keep recomputing $m_i$ or $l_i$; it can directly recompute $P_{ij} = e^{S_{ij} - \mathrm{lse}_i}$.
It first computes $D_i$. Then it iterates over the queries to compute (accumulate) $dk_j$ and $dv_j$. Finally it iterates over the keys to compute (accumulate) $dq_i$.
In both the forward and backward passes we calculate logarithms and exponentials in base $2$ instead of base $e$ for performance; this is why the softmax scale is multiplied by $\log_2 e = 1.4426950408889634$ before being passed to the kernels.
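For reference, the backward formulas above can be written as a few lines of plain PyTorch for a single head. Again, this is only an illustrative sketch, separate from the Triton kernels below.

```python
import torch

def attention_backward_reference(q, k, v, do, sigma):
    """Single-head reference backward pass following the formulas above."""
    p = torch.softmax(sigma * q @ k.T, dim=-1)  # P = softmax(S), S = sigma Q K^T
    o = p @ v
    dv = p.T @ do                               # dV = P^T dO
    dp = do @ v.T                               # dP = dO V^T
    D = (o * do).sum(dim=-1, keepdim=True)      # D_i = do_i . o_i
    ds = p * (dp - D)                           # dS = P * (dP - D)
    dq = sigma * ds @ k                         # dQ = sigma dS K
    dk = sigma * ds.T @ q                       # dK = sigma dS^T Q
    return dq, dk, dv

# Check against autograd
n, m, d = 8, 16, 4
q = torch.randn(n, d, requires_grad=True)
k = torch.randn(m, d, requires_grad=True)
v = torch.randn(m, d, requires_grad=True)
do = torch.randn(n, d)
(torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v).backward(do)
dq, dk, dv = attention_backward_reference(q, k, v, do, 1 / d ** 0.5)
assert torch.allclose(dq, q.grad, atol=1e-5)
assert torch.allclose(dk, k.grad, atol=1e-5)
assert torch.allclose(dv, v.grad, atol=1e-5)
```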
148from typing import Any, Tuple
149
150import torch
151import triton
152import triton.language as tl
153
154HI_PRES_TL: tl.constexpr = tl.float32
155HI_PRES_TORCH: torch.dtype = torch.float32
158class AttentionFunc(torch.autograd.Function):
Group query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head].

ctx: the context for torch autograd
q: has shape [batch_size, n_heads, q_seq_len, d_head]
k: has shape [batch_size, k_heads, kv_seq_len, d_head]
v: has shape [batch_size, k_heads, kv_seq_len, d_head]
causal: whether to apply the causal attention mask
sm_scale: softmax scale factor

159    @staticmethod
160 def forward(ctx: Any,
161 q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
162 causal: bool, sm_scale: float) -> torch.Tensor:
176 batch_size, n_heads, q_seq_len, d_head = q.shape
177 _, k_heads, kv_seq_len, _ = k.shape
178 assert n_heads % k_heads == 0
179 n_groups = n_heads // k_heads
Shape constraints
182 assert d_head == k.shape[-1] == v.shape[-1]
183 assert d_head in {16, 32, 64, 128, 256}
Reshape the tensors, combining the key/value heads with the batch dimension
186 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
187 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
188 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
Make sure the tensors are contiguous and that the strides are the same
191 assert q.is_contiguous()
192 assert k.is_contiguous()
193 assert v.is_contiguous()
194 assert k.stride() == v.stride()
Tensor for the output
197 o = torch.empty_like(q)
Tensor for log of sum of exponentials
199 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
The forward computation is parallelized across the combined batch/head dimension and across the queries in blocks of size BLOCK_Q
202 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
203 _attn_fwd[grid](
204 q, k, v, sm_scale * 1.4426950408889634, lse, o,
205 n_groups=n_groups,
206 q_seq_len=q_seq_len,
207 kv_seq_len=kv_seq_len,
208 d_head=d_head,
209 is_causal=causal,
210 )
Save the reshaped inputs and outputs for the backward pass
213 ctx.save_for_backward(q, k, v, o, lse)
214 ctx.sm_scale = sm_scale
215 ctx.n_groups = n_groups
216 ctx.causal = causal
Return the output in shape [batch_size, n_heads, q_seq_len, d_head]
219 return o.view(batch_size, n_heads, q_seq_len, d_head)
The backward pass computes the gradients of the input tensors.
ctx: the context for torch autograd
do: the gradient tensor of the attention output, with shape [batch_size, n_heads, q_seq_len, d_head]
221 @staticmethod
222 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
Get saved tensors and attributes
233 n_groups = ctx.n_groups
234 sm_scale = ctx.sm_scale
235 causal = ctx.causal
236 q, k, v, o, lse = ctx.saved_tensors
Get shapes
239 batch_size, n_heads, q_seq_len, d_head = do.shape
240 _, kv_seq_len, _ = k.shape
241 k_heads = n_heads // n_groups
Combine the heads with the batch dimension of the output gradients tensor
244 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
Make sure it's contiguous and the strides are the same
247 assert do.is_contiguous()
248 assert k.stride() == v.stride()
249 assert q.stride() == o.stride() == do.stride()
Create tensors for input gradients
252 dq = torch.empty_like(q)
253 dk = torch.empty_like(k)
254 dv = torch.empty_like(v)
Precompute the keys scaled by $\sigma \log_2 e$, and allocate a tensor pdp for $D_i = o_i \cdot do_i$
257 k_scaled = k * (sm_scale * 1.4426950408889634)
259 pdp = torch.empty_like(lse)
We use a fixed BLOCK_Q for the backward pass kernel that computes $D_i$
265 BLOCK_Q = 16
266 pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
267 _attn_bwd_d[pre_grid](
268 o, do,
269 pdp,
270 BLOCK_Q=16,
271 d_head=d_head,
272 q_seq_len=q_seq_len,
273 n_groups=n_groups,
274 num_stages=1,
275 )
280 grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
281 _attn_bwd_dkdv[grid](
282 q, k_scaled, v, sm_scale, do, dk, dv,
283 lse, pdp,
284 q_seq_len, kv_seq_len, n_groups, d_head,
285 is_causal=causal,
286
287 )
292 grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
293 _attn_bwd_dq[grid](
294 q, k_scaled, v, do,
295 dq,
296 lse, pdp,
297 q_seq_len, kv_seq_len, n_groups, d_head,
298 is_causal=causal,
299 )
Split the combined batch and heads
302 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
303 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
304 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
307 return dq, dk, dv, None, None
308
309
310attention = AttentionFunc.apply
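A minimal usage sketch (the shapes are illustrative, and a CUDA device is assumed since the kernels are written in Triton):

```python
import math
import torch

batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 1024, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16)

# Causal grouped-query attention with the usual 1/sqrt(d_head) softmax scale
out = attention(q, k, v, True, 1 / math.sqrt(d_head))
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```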
313def _get_autotune_configs(inner_loop: str) -> list:
318 configs = []
Possible options for BLOCK_Q
321 for bq in [64, 128, 256]:
Possible options for BLOCK_K
323 for bk in [64, 128, 256]:
If the inner loop is along the keys, BLOCK_Q must be a multiple of BLOCK_K for causal masking
325 if inner_loop == 'key' and bq % bk != 0:
326 continue
Similarly, when the inner loop is along the queries, BLOCK_K must be a multiple of BLOCK_Q
328 if inner_loop == 'query' and bk % bq != 0:
329 continue
Number of stages and warps
332 for s in [2, 3, 4]:
333 for w in [4, 8]:
334 if bq * bk < 128 * 128 and w == 8:
335 continue
336
337 configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))
Use return configs to autotune over all combinations. Trying all combinations is slow for testing, so only the first configuration is returned here.
340 return configs[:1]
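As a quick sanity check of the divisibility filters, this standalone snippet (it mirrors the loops above without importing Triton) lists the (BLOCK_Q, BLOCK_K) pairs that survive for each inner-loop direction:

```python
for inner_loop in ('key', 'query'):
    pairs = [(bq, bk)
             for bq in (64, 128, 256)
             for bk in (64, 128, 256)
             if not (inner_loop == 'key' and bq % bk != 0)
             and not (inner_loop == 'query' and bk % bq != 0)]
    print(inner_loop, pairs)
# key:   (64, 64), (128, 64), (128, 128), (256, 64), (256, 128), (256, 256)
# query: (64, 64), (64, 128), (64, 256), (128, 128), (128, 256), (256, 256)
```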
t_q: queries
t_k: keys
t_v: values
sm_scale_log2e: softmax scale multiplied by $\log_2 e$
t_lse: $\mathrm{lse}_i$ (output)
t_o: output
n_groups: number of groups in GQA
q_seq_len: query sequence length
kv_seq_len: key/value sequence length
d_head: number of dimensions in a head
BLOCK_Q: block size for the query sequence length
BLOCK_K: block size for the key sequence length
is_causal: whether to apply causal attention

Strides z, h, m and d denote the stride of the corresponding dimensions (batch_size, n_heads, q_seq_len, d_head) in the query. Stride n denotes the stride on kv_seq_len of the key.
343@triton.autotune(_get_autotune_configs(inner_loop='key'),
344 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
345@triton.jit
346def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
347 n_groups: tl.constexpr,
348 q_seq_len: tl.constexpr,
349 kv_seq_len: tl.constexpr,
350 d_head: tl.constexpr,
351 is_causal: tl.constexpr,
352 BLOCK_Q: tl.constexpr,
353 BLOCK_K: tl.constexpr,
354 ):
We are computing the attention for queries i * BLOCK_Q ... (i + 1) * BLOCK_Q of batch/head combination z and GQA group g.
378 i = tl.program_id(0)
379 z = tl.program_id(1) // n_groups
380 g = tl.program_id(1) % n_groups
383 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
384 (q_seq_len, d_head),
385 (d_head, 1),
386 (i * BLOCK_Q, 0),
387 (BLOCK_Q, d_head),
388 (1, 0))
389 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
390 (kv_seq_len, d_head),
391 (d_head, 1),
392 (0, 0),
393 (BLOCK_K, d_head),
394 (1, 0))
395 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
396 (d_head, kv_seq_len),
397 (1, d_head),
398 (0, 0),
399 (d_head, BLOCK_K),
400 (0, 1))
401 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
402 (q_seq_len, d_head),
403 (d_head, 1),
404 (i * BLOCK_Q, 0),
405 (BLOCK_Q, d_head),
406 (1, 0))
407 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
408 (q_seq_len,),
409 (1,),
410 (i * BLOCK_Q,),
411 (BLOCK_Q,),
412 (0,))
Initialize offsets
415 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
416 offs_j = tl.arange(0, BLOCK_K)
Mask for the last block of queries
419 i_mask = offs_i < q_seq_len
Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$, so in the first update the contribution of the initial $l_i$ is $e^{m_i - m_i^{\text{new}}} \cdot 1 = 0$.

b_m will be storing the running maximum $m_i$ (of the $\log_2 e$-scaled scores) and b_l the running sum of exponentials $l_i$.
425 b_m = tl.where(i_mask, -float("inf"), 0.0)
426 b_l = tl.where(i_mask, 1.0, 0.0)
429 b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
Load $q_i$ outside the loop since it will be reused throughout the loop over the keys.
432 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
433
434 if is_causal:
Inner loop up to the diagonal block
436 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
437 p_kT, p_v,
438 sm_scale_log2e,
439 BLOCK_Q, d_head, BLOCK_K,
440 offs_i, offs_j,
441 j=tl.full([], 0, tl.int32), # type: ignore
442 steps=(i * BLOCK_Q) // BLOCK_K,
443 MASK=False,
444 q_seq_len=q_seq_len,
445 kv_seq_len=kv_seq_len
446 )
Diagonal block with masking within it
448 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
449 sm_scale_log2e,
450 BLOCK_Q, d_head, BLOCK_K,
451 offs_i, offs_j,
452 j=i * BLOCK_Q,
453 steps=BLOCK_Q // BLOCK_K,
454 MASK=True,
455 q_seq_len=q_seq_len,
456 kv_seq_len=kv_seq_len
457 )
458 else:
Iterate through all key blocks
460 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
461 sm_scale_log2e,
462 BLOCK_Q, d_head, BLOCK_K,
463 offs_i, offs_j,
464 j=tl.full([], 0, tl.int32), # type: ignore
465 steps=tl.cdiv(kv_seq_len, BLOCK_K),
466 MASK=False,
467 q_seq_len=q_seq_len,
468 kv_seq_len=kv_seq_len
469 )
Store LSE
472 tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
Store $o_i = \tilde{o}_i / l_i$
474 tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
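A short note on what gets stored as the LSE here. Since the scores were pre-multiplied by $\log_2 e$ and exp2 was used throughout,

$$\mathrm{lse}_i = m'_i + \log_2 l_i = \log_2 \sum_j 2^{\,S_{ij} \log_2 e} = \log_2 \sum_j e^{S_{ij}} = \frac{1}{\ln 2} \operatorname{logsumexp}_j S_{ij}$$

where $m'_i$ is the running maximum of the $\log_2 e$-scaled scores. The backward kernels can then recover $P_{ij} = 2^{\,S_{ij} \log_2 e - \mathrm{lse}_i}$ without needing $m_i$ or $l_i$ separately.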
This iterates through the keys and values starting from j, for steps steps. In each step it processes BLOCK_K entries of keys/values.
477@triton.jit
478def _attn_fwd_inner(b_o, b_l, b_m, b_q,
479 p_kT, p_v,
480 sm_scale_log2e,
481 BLOCK_Q: tl.constexpr,
482 d_head: tl.constexpr,
483 BLOCK_K: tl.constexpr,
484 offs_i, offs_j,
485 j,
486 steps,
487 MASK: tl.constexpr,
488 q_seq_len: tl.constexpr,
489 kv_seq_len: tl.constexpr
490 ):
497 tl.static_assert(BLOCK_Q % BLOCK_K == 0)
Move the $k^T$ and $v$ pointers to the starting position j
500 p_kT = tl.advance(p_kT, (0, j))
501 p_v = tl.advance(p_v, (j, 0))
Iterate over the key/value blocks, and update $\tilde{o}_i$, $l_i$ and $m_i$
504 for _ in range(steps):
Load $k_j^T$
506 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
Compute $S_{ij} \log_2 e$
508 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
509 b_s = b_s * sm_scale_log2e
Apply causal mask
512 if MASK:
513 causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
514 b_s = tl.where(causal_mask, b_s, -float("inf"))
Mask out entries where the key position is beyond the end of the key sequence
517 j_mask = (j + offs_j) < kv_seq_len
518 b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
521 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
527 b_p = tl.math.exp2(b_s - b_m_new[:, None])
530 b_l_new = tl.sum(b_p, -1)
532 b_m_m_new = tl.math.exp2(b_m - b_m_new)
534 b_l = b_l * b_m_m_new + b_l_new
537 b_o = b_o * b_m_m_new[:, None]
538 b_p = b_p.to(b_q.dtype) # TODO
539 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
540 b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
543 b_m = b_m_new
Move pointers
546 j += BLOCK_K
547 p_v = tl.advance(p_v, (BLOCK_K, 0))
548 p_kT = tl.advance(p_kT, (0, BLOCK_K))
549
550 tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
551
552 return b_o, b_l, b_m
555@triton.jit
556def _attn_bwd_d(t_o, t_do,
557 t_pdp,
558 BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
559 q_seq_len: tl.constexpr,
560 n_groups: tl.constexpr,
561 ):
565 i = tl.program_id(0) * BLOCK_Q
566 z = tl.program_id(1)
Create block pointers
569 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
570 (n_groups, q_seq_len, d_head),
571 (q_seq_len * d_head, d_head, 1),
572 (0, i, 0),
573 (n_groups, BLOCK_Q, d_head),
574 (2, 1, 0))
575 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
576 (n_groups, q_seq_len, d_head),
577 (q_seq_len * d_head, d_head, 1),
578 (0, i, 0),
579 (n_groups, BLOCK_Q, d_head),
580 (2, 1, 0))
581 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
582 (n_groups, q_seq_len),
583 (q_seq_len, 1),
584 (0, i),
585 (n_groups, BLOCK_Q),
586 (1, 0))
Load $o_i$
589 o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
Load $do_i$
591 do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
Calculate $D_i = o_i \cdot do_i$
593 d = tl.sum(o * do, axis=-1)
Save $D_i$
595 tl.store(p_pdp, d, boundary_check=(1,))
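The reason summing $o_i \odot do_i$ gives the $D_i$ defined earlier:

$$D_i = \sum_j P_{ij}\, dP_{ij} = \sum_j P_{ij}\, (do_i \cdot v_j) = do_i \cdot \sum_j P_{ij} v_j = do_i \cdot o_i$$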
598@triton.autotune(_get_autotune_configs(inner_loop='query'),
599 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
600@triton.jit
601def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
602 t_do,
603 t_dk, t_dv,
604 t_lse, t_pdp,
605 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
606 n_groups: tl.constexpr, d_head: tl.constexpr,
607 is_causal: tl.constexpr,
608 BLOCK_Q: tl.constexpr,
609 BLOCK_K: tl.constexpr,
610 ):
Compute $dk_j$ and $dv_j$ for keys j ... j + BLOCK_K by iterating over the queries
616 j = tl.program_id(0) * BLOCK_K
617 z = tl.program_id(1)
Create block pointers
620 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
621 (kv_seq_len, d_head),
622 (d_head, 1),
623 (j, 0),
624 (BLOCK_K, d_head),
625 (1, 0))
626 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
627 (kv_seq_len, d_head),
628 (d_head, 1),
629 (j, 0),
630 (BLOCK_K, d_head),
631 (1, 0))
632 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
633 (kv_seq_len, d_head),
634 (d_head, 1),
635 (j, 0),
636 (BLOCK_K, d_head),
637 (1, 0))
638 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
639 (kv_seq_len, d_head),
640 (d_head, 1),
641 (j, 0),
642 (BLOCK_K, d_head),
643 (1, 0))
Initialize $dk_j$ and $dv_j$
646 b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
647 b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
Load $k_j$ and $v_j$ outside the loop.
650 b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
651 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
Iterate through the query head groups in GQA
654 for g in range(n_groups):
Create block pointers
656 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
657 (d_head, q_seq_len),
658 (1, d_head),
659 (0, 0),
660 (d_head, BLOCK_Q),
661 (0, 1))
662
663 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
664 (q_seq_len, d_head),
665 (d_head, 1),
666 (0, 0),
667 (BLOCK_Q, d_head),
668 (1, 0))
669 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
670 (q_seq_len,),
671 (1,),
672 (0,),
673 (BLOCK_Q,),
674 (0,))
675 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
676 (q_seq_len,),
677 (1,),
678 (0,),
679 (BLOCK_Q,),
680 (0,))
681
682 if is_causal:
Inner loop at the diagonal block
684 b_dk, b_dv = _attn_bwd_dkdv_inner(
685 b_dk, b_dv,
686 p_qT, b_k, b_v, p_do,
687 p_lse, p_pdp,
688 BLOCK_Q, BLOCK_K,
689 d_head,
690 j=j, i=j,
691 steps=BLOCK_K // BLOCK_Q,
692 MASK=True,
693 q_seq_len=q_seq_len,
694 kv_seq_len=kv_seq_len,
695 )
Inner loop on queries after the diagonal
698 b_dk, b_dv = _attn_bwd_dkdv_inner(
699 b_dk, b_dv,
700 p_qT, b_k, b_v, p_do,
701 p_lse, p_pdp,
702 BLOCK_Q, BLOCK_K,
703 d_head,
704 j=j, i=j + BLOCK_K,
705 steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
706 MASK=False,
707 q_seq_len=q_seq_len,
708 kv_seq_len=kv_seq_len
709 )
710 else:
Iterate through all queries
712 b_dk, b_dv = _attn_bwd_dkdv_inner(
713 b_dk, b_dv,
714 p_qT, b_k, b_v, p_do,
715 p_lse, p_pdp,
716 BLOCK_Q, BLOCK_K,
717 d_head,
718 j=j, i=tl.full([], 0, tl.int32),
719 steps=tl.cdiv(q_seq_len, BLOCK_Q),
720 MASK=False,
721 q_seq_len=q_seq_len,
722 kv_seq_len=kv_seq_len
723 )
Save $dv_j$
726 tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
So far b_dk holds $\sum_i dS_{ij} q_i$, so multiply by the softmax scale $\sigma$ to get $dk_j = \sigma \sum_i dS_{ij} q_i$
729 b_dk *= sm_scale
Save $dk_j$
732 tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
735@triton.jit
736def _attn_bwd_dkdv_inner(b_dk, b_dv,
737 p_qT, b_k, b_v, p_do,
738 p_lse, p_pdp,
739 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
740 d_head: tl.constexpr,
741 j, i, steps,
742 MASK: tl.constexpr,
743 q_seq_len: tl.constexpr,
744 kv_seq_len: tl.constexpr):
To apply the causal mask block by block, BLOCK_K must be a multiple of BLOCK_Q
750 tl.static_assert(BLOCK_K % BLOCK_Q == 0)
Offsets and mask
753 offs_i = i + tl.arange(0, BLOCK_Q)
754 offs_j = j + tl.arange(0, BLOCK_K)
Move the pointers
757 p_qT = tl.advance(p_qT, (0, i))
758 p_do = tl.advance(p_do, (i, 0))
759 p_lse = tl.advance(p_lse, (i,))
760 p_pdp = tl.advance(p_pdp, (i,))
Iterate over the query blocks
763 for _ in range(steps):
Load $q_i^T$ and $\mathrm{lse}_i$
765 b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
768 b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
771 b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
780 b_pT = tl.math.exp2(b_sT - b_l[None, :])
Autoregressive masking
783 if MASK:
784 mask = (offs_i[None, :] >= offs_j[:, None])
785 b_pT = tl.where(mask, b_pT, 0.0)
Mask out queries that are beyond the end of the query sequence.
Note: there is no need to mask based on the key positions, because contributions at key positions outside the boundary never get stored into $dk_j$ or $dv_j$. Masking by the query positions may also not be strictly necessary, since the out-of-bounds values are zero-padded on loading.
792 i_mask = offs_i < q_seq_len
793 b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
796 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
797 b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
800 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
802 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
804 b_dsT = b_pT * (b_dpT - b_pdp[None, :])
806 b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
Increment pointers.
809 offs_i += BLOCK_Q
810 p_lse = tl.advance(p_lse, (BLOCK_Q,))
811 p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
812 p_qT = tl.advance(p_qT, (0, BLOCK_Q))
813 p_do = tl.advance(p_do, (BLOCK_Q, 0))
Return the accumulated $dk_j$ and $dv_j$
816 return b_dk, b_dv
819@triton.autotune(_get_autotune_configs(inner_loop='key'),
820 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
821@triton.jit
822def _attn_bwd_dq(t_q, t_k, t_v, t_do,
823 t_dq,
824 t_lse, t_pdp,
825 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
826 n_groups: tl.constexpr, d_head: tl.constexpr,
827 is_causal: tl.constexpr,
828 BLOCK_Q: tl.constexpr,
829 BLOCK_K: tl.constexpr,
830 ):
835 i = tl.program_id(0) * BLOCK_Q
836 z = tl.program_id(1) // n_groups
837 g = tl.program_id(1) % n_groups # TODO
Create block pointers
840 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
841 (q_seq_len, d_head),
842 (d_head, 1),
843 (i, 0),
844 (BLOCK_Q, d_head),
845 (1, 0))
846 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
847 (q_seq_len, d_head),
848 (d_head, 1),
849 (i, 0),
850 (BLOCK_Q, d_head),
851 (1, 0))
852 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
853 (q_seq_len, d_head),
854 (d_head, 1),
855 (i, 0),
856 (BLOCK_Q, d_head),
857 (1, 0))
858 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
859 (d_head, kv_seq_len),
860 (1, d_head),
861 (0, 0),
862 (d_head, BLOCK_K),
863 (0, 1))
864 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
865 (d_head, kv_seq_len),
866 (1, d_head),
867 (0, 0),
868 (d_head, BLOCK_K),
869 (0, 1))
870 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
871 (q_seq_len,),
872 (1,),
873 (i,),
874 (BLOCK_Q,),
875 (0,))
876 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
877 (q_seq_len,),
878 (1,),
879 (i,),
880 (BLOCK_Q,),
881 (0,))
Load $q_i$, $do_i$, $D_i$, and $\mathrm{lse}_i$ outside the loop
884 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
885 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
886 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
887 b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
Initialize $dq_i$
890 b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
894 if is_causal:
Compute $dq_i$ for the masked (diagonal) blocks.
896 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
897 b_do, b_lse, b_pdp,
898 BLOCK_Q, BLOCK_K,
899 i=i, j=i,
900 steps=BLOCK_Q // BLOCK_K,
901 MASK=True,
902 q_seq_len=q_seq_len,
903 kv_seq_len=kv_seq_len
904 )
Compute $dq_i$ for the remaining (unmasked) blocks
907 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
908 b_do, b_lse, b_pdp,
909 BLOCK_Q, BLOCK_K,
910 i=i, j=tl.full([], 0, tl.int32), # type: ignore
911 steps=i // BLOCK_K,
912 MASK=False,
913 q_seq_len=q_seq_len,
914 kv_seq_len=kv_seq_len
915 )
916 else:
Iterate through all key blocks
918 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
919 b_do, b_lse, b_pdp,
920 BLOCK_Q, BLOCK_K,
921 i=i, j=tl.full([], 0, tl.int32), # type: ignore
922 steps=tl.cdiv(kv_seq_len, BLOCK_K),
923 MASK=False,
924 q_seq_len=q_seq_len,
925 kv_seq_len=kv_seq_len
926 )
So far b_dq holds $dq_i \log_2 e$ (because the keys were pre-scaled by $\sigma \log_2 e$), so multiply by $\ln 2 = 0.6931\ldots$ to get $dq_i$
929 b_dq *= 0.6931471824645996
Save $dq_i$
932 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
935@triton.jit
936def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
937 b_do, b_lse, b_pdp,
938 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
939 i, j, steps,
940 MASK: tl.constexpr,
941 q_seq_len: tl.constexpr,
942 kv_seq_len: tl.constexpr):
Offsets
948 offs_i = i + tl.arange(0, BLOCK_Q)
949 offs_j = j + tl.arange(0, BLOCK_K)
Move the pointers
952 p_kT = tl.advance(p_kT, (0, j))
953 p_vT = tl.advance(p_vT, (0, j))
954
955 tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
Iterate over the key blocks
958 for _ in range(steps):
Load $k_j^T$
960 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
Load $v_j^T$
962 b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
965 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
974 b_p = tl.math.exp2(b_s - b_lse[:, None])
Autoregressive masking
977 if MASK:
978 causal_mask = (offs_i[:, None] >= offs_j[None, :])
979 b_p = tl.where(causal_mask, b_p, 0.0)
Mask out keys that are beyond the end of the key sequence
982 j_mask = offs_j < kv_seq_len
983 b_p = tl.where(j_mask[None, :], b_p, 0.0)
988 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
990 b_ds = b_p * (b_dp - b_pdp[:, None])
992 b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)
Increment pointers.
995 offs_j += BLOCK_K
996 p_kT = tl.advance(p_kT, (0, BLOCK_K))
997 p_vT = tl.advance(p_vT, (0, BLOCK_K))
Return the accumulated $dq_i$
1000 return b_dq
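One way to exercise all three backward kernels is to compare the gradients against PyTorch's scaled_dot_product_attention with the key/value heads repeated for GQA. A hedged sketch follows (it assumes a CUDA device and PyTorch 2.1+ for the scale argument; tolerances are left to the reader, so it only prints the maximum differences):

```python
import math
import torch

batch_size, n_heads, k_heads, seq_len, d_head = 1, 4, 2, 512, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
do = torch.randn_like(q)
sm_scale = 1 / math.sqrt(d_head)

# Gradients through the Triton kernels
dq, dk, dv = torch.autograd.grad(attention(q, k, v, True, sm_scale), (q, k, v), do)

# Gradients through PyTorch's SDPA, repeating k/v heads to match the query heads
n_groups = n_heads // k_heads
ref = torch.nn.functional.scaled_dot_product_attention(
    q, k.repeat_interleave(n_groups, dim=1), v.repeat_interleave(n_groups, dim=1),
    is_causal=True, scale=sm_scale)
dq_ref, dk_ref, dv_ref = torch.autograd.grad(ref, (q, k, v), do)

for name, a, b in [('dq', dq, dq_ref), ('dk', dk, dk_ref), ('dv', dv, dv_ref)]:
    print(name, (a - b).abs().max().item())
```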