Flash Attention

Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads/writes between GPU high-bandwidth memory (HBM) and GPU on-chip SRAM.

It was introduced in the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and further optimized in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The official CUDA implementation can be found at Dao-AILab/flash-attention.

Our implementation is based on Triton's example implementation.


You can run test.py to check the correctness and measure the performance of this implementation.

Forward pass

Here's the attention forward pass. The formulas represent a single attention head. $q_i$ is the query vector (row vector) at position $i$, and $k_j$ and $v_j$ are the key and value row vectors at position $j$. $o_i$ is the output vector at position $i$.

$$S_{ij} = \sigma \, q_i k_j^\top \qquad l_i = \sum_j e^{S_{ij}} \qquad P_{ij} = \frac{e^{S_{ij}}}{l_i} \qquad o_i = \sum_j P_{ij} v_j$$

where $\sigma$ is the softmax scale (usually $1/\sqrt{d_{head}}$).

$S$ is the attention score matrix before softmax, $l$ is the softmax denominator, and $P$ is the attention matrix after softmax.
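
As a reference, here is a minimal, unoptimized PyTorch version of the same computation (our own sketch; names like naive_attention are ours). It materializes the full $S$ and $P$ matrices, which is exactly what flash attention avoids:

    import torch

    def naive_attention(q, k, v, sm_scale, causal=False):
        # q, k, v: [batch_size, n_heads, seq_len, d_head]
        s = torch.einsum('bhid,bhjd->bhij', q, k) * sm_scale  # S_ij = sigma q_i . k_j
        if causal:
            i = torch.arange(s.shape[-2], device=s.device)[:, None]
            j = torch.arange(s.shape[-1], device=s.device)[None, :]
            s = s.masked_fill(j > i, float('-inf'))            # mask future positions
        p = torch.softmax(s, dim=-1)                           # P_ij = exp(S_ij) / l_i
        return torch.einsum('bhij,bhjd->bhid', p, v)           # o_i = sum_j P_ij v_j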

Flash Attention Optimization

You can compute $o_i$, instead of doing the full softmax, by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{o}_i$ while iterating over keys:

$$l_i = \sum_j e^{S_{ij}} \qquad \tilde{o}_i = \sum_j e^{S_{ij}} v_j$$

Finally you can compute,

$$o_i = \frac{\tilde{o}_i}{l_i}$$

To make it numerically stable, flash attention subtracts the current maximum of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys:

  • $m_i$, the maximum score $\max_j S_{ij}$,
  • $l_i$, the sum of exponents $\sum_j e^{S_{ij} - m_i}$, and
  • $\tilde{o}_i$, the unnormalized output $\sum_j e^{S_{ij} - m_i} v_j$

For each block of keys $j_1 \dots j_2$ it updates them:

$$m_i^{\text{new}} = \max\big(m_i, \max_{j_1 \le j < j_2} S_{ij}\big)$$
$$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}}$$
$$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j_1 \le j < j_2} \tilde{P}_{ij}$$
$$\tilde{o}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{o}_i + \sum_{j_1 \le j < j_2} \tilde{P}_{ij} v_j$$
$$m_i \leftarrow m_i^{\text{new}}$$

Then finally,

$$o_i = \frac{\tilde{o}_i}{l_i}$$

This reduces memory usage since we never materialize the full $S$ matrix or $P$ matrix. It also speeds things up since we don't have to load these large matrices from HBM; instead only blocks of $K$ and $V$ are loaded as we iterate over them.
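
Here is a small PyTorch sketch of that streaming computation for a single query row (our own illustration, ignoring the softmax scale; the function name and block size are made up). It keeps only $m$, $l$ and the unnormalized output while visiting one block of keys at a time:

    import torch

    def online_softmax_attention_row(q, K, V, block_k=64):
        # q: [d_head], K and V: [kv_seq_len, d_head]
        m = torch.tensor(float('-inf'))   # running maximum of the scores
        l = torch.tensor(0.0)             # running sum of exponents
        o = torch.zeros_like(q)           # unnormalized output
        for j in range(0, K.shape[0], block_k):
            s = K[j:j + block_k] @ q                  # scores for this block of keys
            m_new = torch.maximum(m, s.max())
            p = torch.exp(s - m_new)                  # exponents relative to the new max
            correction = torch.exp(m - m_new)         # rescales the previous blocks
            l = l * correction + p.sum()
            o = o * correction + p @ V[j:j + block_k]
            m = m_new
        return o / l                                  # normalize once at the end

The result matches torch.softmax(K @ q, dim=0) @ V up to floating point error.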

Backward pass

Here's the standard backward pass. $dO_i$ is the gradient vector on the output $o_i$.

$$dV_j = \sum_i P_{ij} dO_i$$
$$dP_{ij} = dO_i \cdot v_j$$
$$dS_{ij} = \sum_l dP_{il} \, P_{il} (\delta_{jl} - P_{ij})$$
$$dQ_i = \sigma \sum_j dS_{ij} k_j \qquad dK_j = \sigma \sum_i dS_{ij} q_i$$

where $\delta_{jl}$ is $1$ when $j = l$ and $0$ otherwise.

The flash attention paper introduces

$$D_i = \sum_j P_{ij} dP_{ij} = dO_i \cdot o_i$$

to simplify computation.

Then,

$$dS_{ij} = P_{ij} (dP_{ij} - D_i)$$
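
For completeness, this follows from the softmax Jacobian in one line (our own derivation, written out here for clarity):

$$dS_{ij} = \sum_l dP_{il} P_{il} (\delta_{jl} - P_{ij}) = P_{ij} dP_{ij} - P_{ij} \sum_l P_{il} dP_{il} = P_{ij} (dP_{ij} - D_i)$$

and $D_i = \sum_l P_{il} (dO_i \cdot v_l) = dO_i \cdot \sum_l P_{il} v_l = dO_i \cdot o_i$.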

Flash attention saves the log-sum-exp $L_i = m_i + \log l_i$ from the forward pass since it doesn't take much memory. So during the backward pass it doesn't have to recompute $m_i$ or $l_i$; it recomputes $P_{ij} = e^{S_{ij} - L_i}$ directly.

It first computes $D_i$. Then it iterates over the queries and computes (accumulates) $dK_j$ and $dV_j$. Finally it iterates over the keys and computes (accumulates) $dQ_i$.

In both the forward and backward passes we calculate logarithms and exponentials in base $2$ instead of base $e$ for performance, since GPUs have fast instructions for base-2 exponent and logarithm.
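
Concretely, $e^x = 2^{x \log_2 e}$ and $\log_2 x = (\ln x) \log_2 e$, so the kernels fold $\log_2 e \approx 1.4426950408889634$ into the softmax scale and then use exp2/log2 everywhere. A quick sanity check of the identity (our own snippet, not part of the implementation):

    import math

    LOG2_E = 1.4426950408889634                      # log2(e) = 1 / ln(2)
    x = 0.37
    assert math.isclose(math.exp(x), 2.0 ** (x * LOG2_E))
    assert math.isclose(math.log(x) * LOG2_E, math.log2(x))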

148from typing import Any, Tuple
149
150import torch
151import triton
152import triton.language as tl
153
154HI_PRES_TL: tl.constexpr = tl.float32
155HI_PRES_TORCH: torch.dtype = torch.float32
158class AttentionFunc(torch.autograd.Function):

Forward pass

Grouped query attention (GQA) forward pass. Returns the output with shape [batch_size, n_heads, q_seq_len, d_head].

  • ctx is the autograd context (used to save tensors for the backward pass)
  • q has shape [batch_size, n_heads, q_seq_len, d_head]
  • k has shape [batch_size, k_heads, kv_seq_len, d_head]
  • v has shape [batch_size, k_heads, kv_seq_len, d_head]
  • causal whether to apply causal attention mask
  • sm_scale softmax scale factor
159    @staticmethod
160    def forward(ctx: Any,
161                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
162                causal: bool, sm_scale: float) -> torch.Tensor:
176        batch_size, n_heads, q_seq_len, d_head = q.shape
177        _, k_heads, kv_seq_len, _ = k.shape
178        assert n_heads % k_heads == 0
179        n_groups = n_heads // k_heads

Shape constraints

182        assert d_head == k.shape[-1] == v.shape[-1]
183        assert d_head in {16, 32, 64, 128, 256}

Reshape the tensors, combining the key/value heads with the batch dimension

186        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
187        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
188        v = v.view(batch_size * k_heads, kv_seq_len, d_head)

Make sure the tensors are contiguous and that the strides are the same

191        assert q.is_contiguous()
192        assert k.is_contiguous()
193        assert v.is_contiguous()
194        assert k.stride() == v.stride()

Tensor for the output

197        o = torch.empty_like(q)

Tensor for log of sum of exponentials

199        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_Q

202        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
203        _attn_fwd[grid](
204            q, k, v, sm_scale * 1.4426950408889634, lse, o,
205            n_groups=n_groups,
206            q_seq_len=q_seq_len,
207            kv_seq_len=kv_seq_len,
208            d_head=d_head,
209            is_causal=causal,
210        )
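
For example, with q_seq_len = 2048, batch_size = 2, k_heads = 8, n_groups = 2 and an autotuned BLOCK_Q of 128 (illustrative values), the grid would be (16, 32, 1): 16 query blocks for each of the 32 batch/head/group combinations.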

Save the reshaped inputs and outputs for the backward pass

213        ctx.save_for_backward(q, k, v, o, lse)
214        ctx.sm_scale = sm_scale
215        ctx.n_groups = n_groups
216        ctx.causal = causal

Return the output in shape [batch_size, n_heads, q_seq_len, d_head]

219        return o.view(batch_size, n_heads, q_seq_len, d_head)

Backward pass

The backward pass computes the gradients of the input tensors.

  • ctx is the autograd context with the tensors saved from the forward pass
  • do is the gradient tensor of the attention output with shape [batch_size, n_heads, q_seq_len, d_head]
221    @staticmethod
222    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:

Get saved tensors and attributes

233        n_groups = ctx.n_groups
234        sm_scale = ctx.sm_scale
235        causal = ctx.causal
236        q, k, v, o, lse = ctx.saved_tensors

Get shapes

239        batch_size, n_heads, q_seq_len, d_head = do.shape
240        _, kv_seq_len, _ = k.shape
241        k_heads = n_heads // n_groups

Combine the heads with the batch dimension of the output gradients tensor

244        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

Make sure it's contiguous and the strides are the same

247        assert do.is_contiguous()
248        assert k.stride() == v.stride()
249        assert q.stride() == o.stride() == do.stride()

Create tensors for input gradients

252        dq = torch.empty_like(q)
253        dk = torch.empty_like(k)
254        dv = torch.empty_like(v)

Precompute $K \cdot \sigma \log_2 e$ so the kernels can use exp2 directly

257        k_scaled = k * (sm_scale * 1.4426950408889634)
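
Tensor for $D_i = \sum_j P_{ij} dP_{ij} = dO_i \cdot O_i$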

259        pdp = torch.empty_like(lse)

We use a fixed BLOCK_Q for the backward pass kernel that computes $D_i$

Compute $D_i = dO_i \cdot O_i$

This is parallelized along the batch and queries in blocks of size BLOCK_Q

265        BLOCK_Q = 16
266        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
267        _attn_bwd_d[pre_grid](
268            o, do,
269            pdp,
270            BLOCK_Q=16,
271            d_head=d_head,
272            q_seq_len=q_seq_len,
273            n_groups=n_groups,
274            num_stages=1,
275        )

Compute $dK$ and $dV$

This is parallelized along the batch and keys in blocks of size BLOCK_K

280        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
281        _attn_bwd_dkdv[grid](
282            q, k_scaled, v, sm_scale, do, dk, dv,
283            lse, pdp,
284            q_seq_len, kv_seq_len, n_groups, d_head,
285            is_causal=causal,
286
287        )

Compute $dQ$

This is parallelized along the batch and queries in blocks of size BLOCK_Q

292        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
293        _attn_bwd_dq[grid](
294            q, k_scaled, v, do,
295            dq,
296            lse, pdp,
297            q_seq_len, kv_seq_len, n_groups, d_head,
298            is_causal=causal,
299        )

Split the combined batch and heads

302        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
303        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
304        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

307        return dq, dk, dv, None, None
308
309
310attention = AttentionFunc.apply
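
A minimal usage sketch (our own; shapes and tolerances are illustrative, it assumes a CUDA device, that this module is imported, and a recent PyTorch for the scale argument of scaled_dot_product_attention). test.py does this more thoroughly:

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 1024, 64
    q = torch.randn((batch_size, n_heads, seq_len, d_head), device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn_like(k).requires_grad_()
    sm_scale = d_head ** -0.5

    out = attention(q, k, v, True, sm_scale)    # causal flash attention
    out.backward(torch.ones_like(out))          # populates q.grad, k.grad, v.grad

    # Reference: expand the key/value heads to match the query heads (GQA)
    k_rep = k.repeat_interleave(n_heads // k_heads, dim=1)
    v_rep = v.repeat_interleave(n_heads // k_heads, dim=1)
    ref = scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True, scale=sm_scale)
    assert torch.allclose(out, ref, atol=1e-2)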

Configs for auto-tuning

313def _get_autotune_configs(inner_loop: str) -> list:
318    configs = []

Possible options for BLOCK_Q

321    for bq in [64, 128, 256]:

Possible options for BLOCK_K

323        for bk in [64, 128, 256]:

If the inner loop is along keys, BLOCK_Q must be a multiple of BLOCK_K for causal masking

325            if inner_loop == 'key' and bq % bk != 0:
326                continue

Similarly, when the inner loop is along queries, BLOCK_K must be a multiple of BLOCK_Q

328            if inner_loop == 'query' and bk % bq != 0:
329                continue

Number of stages and warps

332            for s in [2, 3, 4]:
333                for w in [4, 8]:
334                    if bq * bk < 128 * 128 and w == 8:
335                        continue
336
337                    configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))

Use return configs to autotune over all combinations. Trying all of them is slow for testing, so only the first config is returned here.

340    return configs[:1]

Triton kernel for Flash attention forward pass

  • t_q queries
  • t_k keys
  • t_v values
  • sm_scale_log2e softmax scale multiplied by $\log_2 e$
  • t_lse (out) the log-sum-exp of the scores in base 2, $\log_2 \sum_j e^{S_{ij}}$
  • t_o output
  • n_groups number of groups in GQA
  • q_seq_len query sequence length
  • kv_seq_len key/value sequence length
  • d_head number of dimensions in a head
  • BLOCK_Q block size for query sequence length
  • BLOCK_K block size for key sequence length
  • is_causal whether causal attention

The tensors are assumed to be contiguous, so the kernels compute offsets into q, k, v, o and lse directly from the shapes (q_seq_len, kv_seq_len, n_groups, d_head) instead of taking strides as arguments.

343@triton.autotune(_get_autotune_configs(inner_loop='key'),
344                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
345@triton.jit
346def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
347              n_groups: tl.constexpr,
348              q_seq_len: tl.constexpr,
349              kv_seq_len: tl.constexpr,
350              d_head: tl.constexpr,
351              is_causal: tl.constexpr,
352              BLOCK_Q: tl.constexpr,
353              BLOCK_K: tl.constexpr,
354              ):

We are computing the attention for queries i ... i + BLOCK_Q in batch/head combination z and query group g.

378    i = tl.program_id(0)
379    z = tl.program_id(1) // n_groups
380    g = tl.program_id(1) % n_groups

Create block pointers

383    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
384                            (q_seq_len, d_head),
385                            (d_head, 1),
386                            (i * BLOCK_Q, 0),
387                            (BLOCK_Q, d_head),
388                            (1, 0))
389    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
390                            (kv_seq_len, d_head),
391                            (d_head, 1),
392                            (0, 0),
393                            (BLOCK_K, d_head),
394                            (1, 0))
395    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
396                             (d_head, kv_seq_len),
397                             (1, d_head),
398                             (0, 0),
399                             (d_head, BLOCK_K),
400                             (0, 1))
401    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
402                            (q_seq_len, d_head),
403                            (d_head, 1),
404                            (i * BLOCK_Q, 0),
405                            (BLOCK_Q, d_head),
406                            (1, 0))
407    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
408                              (q_seq_len,),
409                              (1,),
410                              (i * BLOCK_Q,),
411                              (BLOCK_Q,),
412                              (0,))

Initialize offsets

415    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
416    offs_j = tl.arange(0, BLOCK_K)

Mask for the last block of queries

419    i_mask = offs_i < q_seq_len

Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update, the initial $l_i$ has no effect because it gets multiplied by $2^{m_i - m_i^{\text{new}}} = 2^{-\infty} = 0$.

b_m will be storing the running maximum in base-2 scale, i.e. $m_i \log_2 e$, since the scores are pre-multiplied by $\log_2 e$

425    b_m = tl.where(i_mask, -float("inf"), 0.0)
426    b_l = tl.where(i_mask, 1.0, 0.0)

429    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

Load $q_i$ outside the loop since it will be reused throughout the loop over the keys.

432    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
433
434    if is_causal:

Inner loop up to the diagonal block

436        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
437                                        p_kT, p_v,
438                                        sm_scale_log2e,
439                                        BLOCK_Q, d_head, BLOCK_K,
440                                        offs_i, offs_j,
441                                        j=tl.full([], 0, tl.int32),  # type: ignore
442                                        steps=(i * BLOCK_Q) // BLOCK_K,
443                                        MASK=False,
444                                        q_seq_len=q_seq_len,
445                                        kv_seq_len=kv_seq_len
446                                        )

Diagonal block with masking within it

448        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
449                                        sm_scale_log2e,
450                                        BLOCK_Q, d_head, BLOCK_K,
451                                        offs_i, offs_j,
452                                        j=i * BLOCK_Q,
453                                        steps=BLOCK_Q // BLOCK_K,
454                                        MASK=True,
455                                        q_seq_len=q_seq_len,
456                                        kv_seq_len=kv_seq_len
457                                        )
458    else:

Iterate through all keys

460        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
461                                        sm_scale_log2e,
462                                        BLOCK_Q, d_head, BLOCK_K,
463                                        offs_i, offs_j,
464                                        j=tl.full([], 0, tl.int32),  # type: ignore
465                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
466                                        MASK=False,
467                                        q_seq_len=q_seq_len,
468                                        kv_seq_len=kv_seq_len
469                                        )

Store LSE

472    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))

Store the normalized output $o_i = \tilde{o}_i / l_i$

474    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))

Inner loop to calculate the attention output for a range of key blocks

This iterates through keys and values starting from j for steps number of steps. In each step it processes BLOCK_K entries of keys/values.

477@triton.jit
478def _attn_fwd_inner(b_o, b_l, b_m, b_q,
479                    p_kT, p_v,
480                    sm_scale_log2e,
481                    BLOCK_Q: tl.constexpr,
482                    d_head: tl.constexpr,
483                    BLOCK_K: tl.constexpr,
484                    offs_i, offs_j,
485                    j,
486                    steps,
487                    MASK: tl.constexpr,
488                    q_seq_len: tl.constexpr,
489                    kv_seq_len: tl.constexpr
490                    ):
497    tl.static_assert(BLOCK_Q % BLOCK_K == 0)

Move the $K^T$ and $V$ pointers to the starting position j

500    p_kT = tl.advance(p_kT, (0, j))
501    p_v = tl.advance(p_v, (j, 0))

Iterate over the keys, and update $\tilde{o}_i$, $l_i$ and $m_i$

504    for _ in range(steps):

Load $k_j^T$

506        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Compute $S_{ij} = \sigma q_i k_j^T$, scaled by $\log_2 e$

508        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
509        b_s = b_s * sm_scale_log2e

Apply causal mask

512        if MASK:
513            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
514            b_s = tl.where(causal_mask, b_s, -float("inf"))

Mask out keys beyond the end of the key sequence

517        j_mask = (j + offs_j) < kv_seq_len
518        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
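
The new running maximum $m_i^{\text{new}} = \max(m_i, \max_j S_{ij})$ (in the base-2 scale)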

521        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
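
$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}}$, computed with exp2 since the scores are pre-multiplied by $\log_2 e$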

527        b_p = tl.math.exp2(b_s - b_m_new[:, None])

530        b_l_new = tl.sum(b_p, -1)
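
The correction factor $e^{m_i - m_i^{\text{new}}}$ that rescales the previous blocks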

532        b_m_m_new = tl.math.exp2(b_m - b_m_new)
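
Update the sum of exponents: $l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_j \tilde{P}_{ij}$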

534        b_l = b_l * b_m_m_new + b_l_new
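
Rescale the accumulated output and add the current block: $\tilde{o}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{o}_i + \sum_j \tilde{P}_{ij} v_j$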

537        b_o = b_o * b_m_m_new[:, None]
538        b_p = b_p.to(b_q.dtype)  # cast to the lower-precision input dtype for the dot product
539        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
540        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

543        b_m = b_m_new

Move pointers

546        j += BLOCK_K
547        p_v = tl.advance(p_v, (BLOCK_K, 0))
548        p_kT = tl.advance(p_kT, (0, BLOCK_K))
549
550    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
551
552    return b_o, b_l, b_m

Triton kernel to compute $D_i = dO_i \cdot O_i$

555@triton.jit
556def _attn_bwd_d(t_o, t_do,
557                t_pdp,
558                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
559                q_seq_len: tl.constexpr,
560                n_groups: tl.constexpr,
561                ):
565    i = tl.program_id(0) * BLOCK_Q
566    z = tl.program_id(1)

Create block pointers

569    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
570                            (n_groups, q_seq_len, d_head),
571                            (q_seq_len * d_head, d_head, 1),
572                            (0, i, 0),
573                            (n_groups, BLOCK_Q, d_head),
574                            (2, 1, 0))
575    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
576                             (n_groups, q_seq_len, d_head),
577                             (q_seq_len * d_head, d_head, 1),
578                             (0, i, 0),
579                             (n_groups, BLOCK_Q, d_head),
580                             (2, 1, 0))
581    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
582                              (n_groups, q_seq_len),
583                              (q_seq_len, 1),
584                              (0, i),
585                              (n_groups, BLOCK_Q),
586                              (1, 0))

Load $O_i$

589    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")

Load $dO_i$

591    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)

Calculate $D_i = O_i \cdot dO_i$

593    d = tl.sum(o * do, axis=-1)

Save $D_i$

595    tl.store(p_pdp, d, boundary_check=(1,))

Triton kernel to compute $dK_j$ and $dV_j$

598@triton.autotune(_get_autotune_configs(inner_loop='query'),
599                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
600@triton.jit
601def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
602                   t_do,
603                   t_dk, t_dv,
604                   t_lse, t_pdp,
605                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
606                   n_groups: tl.constexpr, d_head: tl.constexpr,
607                   is_causal: tl.constexpr,
608                   BLOCK_Q: tl.constexpr,
609                   BLOCK_K: tl.constexpr,
610                   ):

Compute $dK_j$ and $dV_j$ for keys j ... j + BLOCK_K by iterating over the queries

616    j = tl.program_id(0) * BLOCK_K
617    z = tl.program_id(1)

Create block pointers

620    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
621                            (kv_seq_len, d_head),
622                            (d_head, 1),
623                            (j, 0),
624                            (BLOCK_K, d_head),
625                            (1, 0))
626    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
627                            (kv_seq_len, d_head),
628                            (d_head, 1),
629                            (j, 0),
630                            (BLOCK_K, d_head),
631                            (1, 0))
632    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
633                             (kv_seq_len, d_head),
634                             (d_head, 1),
635                             (j, 0),
636                             (BLOCK_K, d_head),
637                             (1, 0))
638    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
639                             (kv_seq_len, d_head),
640                             (d_head, 1),
641                             (j, 0),
642                             (BLOCK_K, d_head),
643                             (1, 0))

Initialize the $dK_j$ and $dV_j$ accumulators

646    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
647    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)

Load $K_j$ and $V_j$ outside the loop.

650    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
651    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")

Iterate through the query head groups in GQA

654    for g in range(n_groups):

Create block pointers

656        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
657                                 (d_head, q_seq_len),
658                                 (1, d_head),
659                                 (0, 0),
660                                 (d_head, BLOCK_Q),
661                                 (0, 1))
662
663        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
664                                 (q_seq_len, d_head),
665                                 (d_head, 1),
666                                 (0, 0),
667                                 (BLOCK_Q, d_head),
668                                 (1, 0))
669        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
670                                  (q_seq_len,),
671                                  (1,),
672                                  (0,),
673                                  (BLOCK_Q,),
674                                  (0,))
675        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
676                                  (q_seq_len,),
677                                  (1,),
678                                  (0,),
679                                  (BLOCK_Q,),
680                                  (0,))
681
682        if is_causal:

Inner loop at the diagonal block

684            b_dk, b_dv = _attn_bwd_dkdv_inner(
685                b_dk, b_dv,
686                p_qT, b_k, b_v, p_do,
687                p_lse, p_pdp,
688                BLOCK_Q, BLOCK_K,
689                d_head,
690                j=j, i=j,
691                steps=BLOCK_K // BLOCK_Q,
692                MASK=True,
693                q_seq_len=q_seq_len,
694                kv_seq_len=kv_seq_len,
695            )

Inner loop on queries after the diagonal

698            b_dk, b_dv = _attn_bwd_dkdv_inner(
699                b_dk, b_dv,
700                p_qT, b_k, b_v, p_do,
701                p_lse, p_pdp,
702                BLOCK_Q, BLOCK_K,
703                d_head,
704                j=j, i=j + BLOCK_K,
705                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
706                MASK=False,
707                q_seq_len=q_seq_len,
708                kv_seq_len=kv_seq_len
709            )
710        else:

Iterate through all queries

712            b_dk, b_dv = _attn_bwd_dkdv_inner(
713                b_dk, b_dv,
714                p_qT, b_k, b_v, p_do,
715                p_lse, p_pdp,
716                BLOCK_Q, BLOCK_K,
717                d_head,
718                j=j, i=tl.full([], 0, tl.int32),
719                steps=tl.cdiv(q_seq_len, BLOCK_Q),
720                MASK=False,
721                q_seq_len=q_seq_len,
722                kv_seq_len=kv_seq_len
723            )

Save $dV_j$

726    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))

b_dk accumulated $\sum_i dS_{ij} q_i$ without the softmax scale, so multiply by $\sigma$ to get $dK_j = \sigma \sum_i dS_{ij} q_i$

729    b_dk *= sm_scale

Save $dK_j$

732    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))

Inner loop to calculate $dK_j$ and $dV_j$ for a range of query blocks

735@triton.jit
736def _attn_bwd_dkdv_inner(b_dk, b_dv,
737                         p_qT, b_k, b_v, p_do,
738                         p_lse, p_pdp,
739                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
740                         d_head: tl.constexpr,
741                         j, i, steps,
742                         MASK: tl.constexpr,
743                         q_seq_len: tl.constexpr,
744                         kv_seq_len: tl.constexpr):

To apply the causal mask block by block, BLOCK_K must be a multiple of BLOCK_Q

750    tl.static_assert(BLOCK_K % BLOCK_Q == 0)

Offsets and mask

753    offs_i = i + tl.arange(0, BLOCK_Q)
754    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

757    p_qT = tl.advance(p_qT, (0, i))
758    p_do = tl.advance(p_do, (i, 0))
759    p_lse = tl.advance(p_lse, (i,))
760    p_pdp = tl.advance(p_pdp, (i,))

Iterate over the query blocks

763    for _ in range(steps):

Load $q_i^T$

765        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")

768        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
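
Compute $S^T_{ji} = \sigma q_i k_j^T$, in base-2 scale (the keys were pre-multiplied by $\sigma \log_2 e$)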

771        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
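
Recompute $P^T_{ji} = e^{S_{ij} - L_i}$ from the saved log-sum-exp; both terms carry the $\log_2 e$ factor, so exp2 gives the softmax value directly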

780        b_pT = tl.math.exp2(b_sT - b_l[None, :])

Autoregressive masking

783        if MASK:
784            mask = (offs_i[None, :] >= offs_j[:, None])
785            b_pT = tl.where(mask, b_pT, 0.0)

Mask out queries beyond the end of the query sequence

Note: There is no need to mask out based on $j$ (the key index) because contributions from positions outside the boundary never get stored in $dK_j$ or $dV_j$ (those stores are boundary-checked). Masking by $i$ may also not be strictly necessary since the tensors are zero-padded on loading.

792        i_mask = offs_i < q_seq_len
793        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
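
Accumulate $dV_j \leftarrow dV_j + \sum_i P_{ij} dO_i$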

796        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
797        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)

800        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
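
Compute $dP^T_{ji} = dO_i \cdot v_j$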

802        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
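
Compute $dS^T_{ji} = P_{ij} (dP_{ij} - D_i)$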

804        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
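
Accumulate $\sum_i dS_{ij} q_i$; the softmax scale $\sigma$ is applied after the loop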

806        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)

Increment pointers.

809        offs_i += BLOCK_Q
810        p_lse = tl.advance(p_lse, (BLOCK_Q,))
811        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
812        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
813        p_do = tl.advance(p_do, (BLOCK_Q, 0))

Return the accumulated $dK_j$ and $dV_j$

816    return b_dk, b_dv

Triton kernel to compute $dQ_i$

819@triton.autotune(_get_autotune_configs(inner_loop='key'),
820                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
821@triton.jit
822def _attn_bwd_dq(t_q, t_k, t_v, t_do,
823                 t_dq,
824                 t_lse, t_pdp,
825                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
826                 n_groups: tl.constexpr, d_head: tl.constexpr,
827                 is_causal: tl.constexpr,
828                 BLOCK_Q: tl.constexpr,
829                 BLOCK_K: tl.constexpr,
830                 ):
835    i = tl.program_id(0) * BLOCK_Q
836    z = tl.program_id(1) // n_groups
837    g = tl.program_id(1) % n_groups  # query group index within GQA

Create block pointers

840    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
841                            (q_seq_len, d_head),
842                            (d_head, 1),
843                            (i, 0),
844                            (BLOCK_Q, d_head),
845                            (1, 0))
846    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
847                             (q_seq_len, d_head),
848                             (d_head, 1),
849                             (i, 0),
850                             (BLOCK_Q, d_head),
851                             (1, 0))
852    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
853                             (q_seq_len, d_head),
854                             (d_head, 1),
855                             (i, 0),
856                             (BLOCK_Q, d_head),
857                             (1, 0))
858    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
859                             (d_head, kv_seq_len),
860                             (1, d_head),
861                             (0, 0),
862                             (d_head, BLOCK_K),
863                             (0, 1))
864    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
865                             (d_head, kv_seq_len),
866                             (1, d_head),
867                             (0, 0),
868                             (d_head, BLOCK_K),
869                             (0, 1))
870    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
871                              (q_seq_len,),
872                              (1,),
873                              (i,),
874                              (BLOCK_Q,),
875                              (0,))
876    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
877                              (q_seq_len,),
878                              (1,),
879                              (i,),
880                              (BLOCK_Q,),
881                              (0,))

Load $q_i$, $dO_i$, $D_i$ and the log-sum-exp $L_i$ outside the loop

884    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
885    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
886    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
887    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

Initialize $dQ_i$

890    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

894    if is_causal:

Compute $dQ_i$ for the masked (diagonal) blocks.

896        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
897                                  b_do, b_lse, b_pdp,
898                                  BLOCK_Q, BLOCK_K,
899                                  i=i, j=i,
900                                  steps=BLOCK_Q // BLOCK_K,
901                                  MASK=True,
902                                  q_seq_len=q_seq_len,
903                                  kv_seq_len=kv_seq_len
904                                  )

Compute $dQ_i$ for the remaining (unmasked) blocks

907        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
908                                  b_do, b_lse, b_pdp,
909                                  BLOCK_Q, BLOCK_K,
910                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
911                                  steps=i // BLOCK_K,
912                                  MASK=False,
913                                  q_seq_len=q_seq_len,
914                                  kv_seq_len=kv_seq_len
915                                  )
916    else:

Iterate through all keys

918        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
919                                  b_do, b_lse, b_pdp,
920                                  BLOCK_Q, BLOCK_K,
921                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
922                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
923                                  MASK=False,
924                                  q_seq_len=q_seq_len,
925                                  kv_seq_len=kv_seq_len
926                                  )

b_dq stores $dQ_i \log_2 e$, because the keys were pre-scaled by $\sigma \log_2 e$, so multiply by $\ln 2$ to get $dQ_i$

929    b_dq *= 0.6931471824645996

Save $dQ_i$

932    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))

Inner loop to calculate $dQ_i$ for a range of key blocks

935@triton.jit
936def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
937                       b_do, b_lse, b_pdp,
938                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
939                       i, j, steps,
940                       MASK: tl.constexpr,
941                       q_seq_len: tl.constexpr,
942                       kv_seq_len: tl.constexpr):

Offsets

948    offs_i = i + tl.arange(0, BLOCK_Q)
949    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

952    p_kT = tl.advance(p_kT, (0, j))
953    p_vT = tl.advance(p_vT, (0, j))
954
955    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')

Iterate over the key blocks

958    for _ in range(steps):

Load $k_j^T$

960        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Load $v_j^T$

962        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")

965        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
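
Recompute $P_{ij} = e^{S_{ij} - L_i}$ from the saved log-sum-exp (both in base-2 scale, so exp2 is used)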

974        b_p = tl.math.exp2(b_s - b_lse[:, None])

Autoregressive masking

977        if MASK:
978            causal_mask = (offs_i[:, None] >= offs_j[None, :])
979            b_p = tl.where(causal_mask, b_p, 0.0)

Mask out keys beyond the end of the key sequence

982        j_mask = offs_j < kv_seq_len
983        b_p = tl.where(j_mask[None, :], b_p, 0.0)
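
Compute $dP_{ij} = dO_i \cdot v_j$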

988        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
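
Compute $dS_{ij} = P_{ij} (dP_{ij} - D_i)$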

990        b_ds = b_p * (b_dp - b_pdp[:, None])
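
Accumulate $\sum_j dS_{ij} k_j$; since the keys carry the $\sigma \log_2 e$ factor, this is $dQ_i \log_2 e$ and it is multiplied by $\ln 2$ after the loop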

992        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)

Increment pointers.

995        offs_j += BLOCK_K
996        p_kT = tl.advance(p_kT, (0, BLOCK_K))
997        p_vT = tl.advance(p_vT, (0, BLOCK_K))

Return the accumulated $dQ_i$

1000    return b_dq