Test Flash Attention Implementation

This is the code to test the correctness of our flash attention implementation against a naive PyTorch attention, and to measure its performance alongside the original flash attention implementation.

import torch
import triton

from labml import logger, monit
from labml_nn.transformers.flash import attention

# Higher-precision dtype used for the softmax in the reference implementation
HI_PRES_TORCH = torch.float32

Calculate absolute and relative error for reporting

@torch.no_grad()
def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
    # Maximum absolute error
    d = (a - b).abs()
    max_abs = d.max()
    # Maximum relative error, after allowing an absolute tolerance of `atol`
    d = (d - atol).clamp(min=0)
    d = d / b.abs()
    max_rel = d.max()

    return max_abs.cpu().item(), max_rel.cpu().item()
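For instance (an illustrative check of my own, not part of the module), with atol=1e-2 a reference value of 2.0 approximated by 2.015 gives a maximum absolute error of about 0.015 and an atol-adjusted relative error of about (0.015 - 0.01) / 2.0 = 0.0025:

# Illustrative only
a = torch.tensor([2.015])
b = torch.tensor([2.0])
abs_err, rel_err = _calc_abs_rel_error(a, b)  # abs_err ~ 0.015, rel_err ~ 0.0025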

Compare our implementation with naive PyTorch attention

def test_fwd_bwd(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
        torch.manual_seed(20)
        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        sm_scale = d_head ** -0.5
        d_out = torch.randn_like(q)

Reference implementation in plain PyTorch

        mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
        torch.cuda.synchronize()

    with monit.section('Pytorch'):
        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
        if causal:
            p[:, :, :, ~mask] = float("-inf")
        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
        ref_out = torch.matmul(p, v[:, :, None, :, :])
        ref_out = ref_out.view(q.shape)
        ref_out.backward(d_out)
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None
        torch.cuda.synchronize()
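To spell out the broadcasting in the reference above (an illustrative walk-through of the grouped-query shapes, assuming n_heads = k_heads * n_groups as in the tests below): each group of query heads attends against a single shared key/value head, and the result is reshaped back to match q.

# Shape walk-through for the grouped-query reference (illustrative only):
#   q.view(batch_size, k_heads, n_groups, q_seq_len, d_head)
#   k.transpose(2, 3)[:, :, None]   -> (batch_size, k_heads, 1, d_head, kv_seq_len)
#   p = q_grouped @ k_t             -> (batch_size, k_heads, n_groups, q_seq_len, kv_seq_len)
#   softmax(p) @ v[:, :, None]      -> (batch_size, k_heads, n_groups, q_seq_len, d_head)
#   .view(q.shape)                  -> (batch_size, n_heads, q_seq_len, d_head)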
    with monit.section('Triton'):
        assert q.dtype == dtype
        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
        monit.progress(0.5)

        tri_out.backward(d_out)
        monit.progress(0.9)
        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
        torch.cuda.synchronize()

    with monit.section('Test') as s:

Compare the outputs and gradients against the reference

        passed = True
        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
            passed = False
        rtol = 1e-1
        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
            passed = False

        if passed:
            logger.log('[PASSED]', logger.Text.success)
            s.success = True
        else:
            s.success = False
        torch.cuda.synchronize()
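For reference, torch.allclose(a, b, atol=..., rtol=...) checks |a - b| <= atol + rtol * |b| elementwise, so the output comparison only allows an absolute slack of 1e-2 while the gradient comparisons additionally allow a 10% relative slack. A quick illustration (not part of the module):

# Illustrative only
a = torch.tensor([1.015])
b = torch.tensor([1.0])
assert not torch.allclose(a, b, atol=1e-2, rtol=0.)  # 0.015 > 0.01
assert torch.allclose(a, b, atol=1e-2, rtol=1e-1)    # 0.015 <= 0.01 + 0.1 * 1.0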

Get a function to test the performance of our implementation

def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    sm_scale = d_head ** -0.5
    return lambda: attention(q, k, v, causal, sm_scale)

Get a function to test the performance of the original flash attention implementation

def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
    from flash_attn import flash_attn_func

    # flash_attn_func expects tensors in (batch_size, seq_len, n_heads, d_head) layout
    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    return lambda: flash_attn_func(q, k, v, causal=causal)

Measure the speed

def measure_performance(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
    if is_bwd:
        o = fn()
        do = torch.randn_like(o)
        fn = lambda: o.backward(do, retain_graph=True)
    ms = triton.testing.do_bench(fn)

    # Two matmuls per attention: Q @ K^T and P @ V
    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if is_bwd:
        total_flops *= 2.5  # 2.0 (bwd) + 0.5 (recompute)

    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
    logger.log((name, logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')
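As a rough sanity check on the numbers this prints (my own arithmetic, not part of the module, using the benchmark configuration from main below):

# Illustrative arithmetic only: batch_size=16, k_heads=8, n_groups=4, seq_len=2048, d_head=128
flops_per_matmul = 2.0 * 16 * 8 * 4 * 2048 * 2048 * 128  # ~ 5.5e11
total_flops = 2 * flops_per_matmul                        # ~ 1.1e12 for a non-causal forward pass
ms = 10.0                                                 # hypothetical forward time in milliseconds
print(total_flops * 1e-12 / (ms * 1e-3))                  # ~ 110 TFLOPS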
def main():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

    dtype = torch.float16

This only works on Ampere or newer GPUs at the moment

    test_fwd_bwd(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
    test_fwd_bwd(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)

    _conf = {
        'batch_size': 16,
        'k_heads': 8,
        'n_groups': 4,
        'seq_len': 2048,
        'd_head': 128,
    }

    for _causal in [False, True]:
        for is_bwd in [False, True]:
            logger.log(f'{"Causal" if _causal else "Non-causal"}{" Backward" if is_bwd else ""}', logger.Text.title)
            measure_performance('flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)
            measure_performance('triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)


if __name__ == "__main__":
    main()