"""
This is the code to test and measure the performance of our flash attention implementation.
"""

import torch
import triton

from labml import logger, monit
from labml_nn.transformers.flash import attention

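# The reference implementation computes softmax in this higher precision,
# which keeps the numerical error of the baseline low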
HI_PRES_TORCH = torch.float32


@torch.no_grad()
def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
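    """
    Calculate the maximum absolute error and the maximum relative error between
    `a` and `b`, after discounting an absolute tolerance of `atol`
    """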
    d = (a - b).abs()
    max_abs = d.max()
    # Discount the absolute tolerance before computing the relative error
    d = (d - atol).clamp(min=0)
    d = d / b.abs()
    max_rel = d.max()

    return max_abs.cpu().item(), max_rel.cpu().item()


def test_fwd_bwd(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
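    """
    Test the forward and backward passes of our implementation against a plain PyTorch reference.

    There are `n_heads` query heads for `k_heads` key/value heads (grouped query attention),
    so `n_heads` must be divisible by `k_heads`.
    """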
    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
        torch.manual_seed(20)
        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        # Softmax scale of 1 / sqrt(d_head)
        sm_scale = d_head ** -0.5
        d_out = torch.randn_like(q)
        # Causal mask for the reference implementation
        mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
        torch.cuda.synchronize()

    with monit.section('Pytorch'):
        # Reshape the query heads into groups that share the same key/value heads
        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
        if causal:
            p[:, :, :, ~mask] = float("-inf")
        # Compute softmax in higher precision
        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
        ref_out = torch.matmul(p, v[:, :, None, :, :])
        ref_out = ref_out.view(q.shape)
        ref_out.backward(d_out)
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None
        torch.cuda.synchronize()

    with monit.section('Triton'):
        assert q.dtype == dtype
        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
        monit.progress(0.5)

        tri_out.backward(d_out)
        monit.progress(0.9)
        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
        torch.cuda.synchronize()

    with monit.section('Test') as s:
        # Compare the outputs and the gradients
        passed = True
        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
            passed = False
        # Use a larger relative tolerance for the gradients
        rtol = 1e-1
        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
            passed = False

        if passed:
            logger.log('[PASSED]', logger.Text.success)
            s.success = True
        else:
            s.success = False
        torch.cuda.synchronize()


# Get a zero-argument function to benchmark the forward pass of our implementation
def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
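    """
    Returns a function that runs a forward pass of our implementation,
    with `k_heads * n_groups` query heads sharing `k_heads` key/value heads
    """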
    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    sm_scale = d_head ** -0.5
    return lambda: attention(q, k, v, causal, sm_scale)


# Get a zero-argument function to benchmark the original flash attention implementation
def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
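    """
    Returns a function that runs a forward pass of `flash_attn_func`,
    which expects tensors of shape `[batch_size, seq_len, n_heads, d_head]`
    """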
    # Import inside the function; this keeps `flash-attn` an optional dependency
    from flash_attn import flash_attn_func

    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    return lambda: flash_attn_func(q, k, v, causal=causal)


def measure_performance(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
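    """
    Benchmark `fn` with `triton.testing.do_bench` and log the runtime
    along with the achieved TFLOPS
    """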
    if is_bwd:
        # Run the forward pass once and benchmark only the backward pass
        o = fn()
        do = torch.randn_like(o)
        fn = lambda: o.backward(do, retain_graph=True)
    ms = triton.testing.do_bench(fn)

    # Attention does two matrix multiplications (QK^T and PV),
    # each costing 2 * seq_len * seq_len * d_head FLOPs per head
    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
    total_flops = 2 * flops_per_matmul
    # A causal mask skips computation for about half of the positions
    if causal:
        total_flops *= 0.5
    if is_bwd:
        total_flops *= 2.5  # 2.0 (backward) + 0.5 (recompute)

    # Achieved TFLOPS
    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
    logger.log((f'{name}', logger.Text.key), ': ', f'{ms:,.1f}ms', ' ', f'{tf_ps:,.2f}TFps')


def main():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

    dtype = torch.float16

    # Only works on post-Ampere GPUs right now
    test_fwd_bwd(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
    test_fwd_bwd(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)

    # Benchmark configuration: 8 * 4 = 32 query heads sharing 8 key/value heads
    _conf = {
        'batch_size': 16,
        'k_heads': 8,
        'n_groups': 4,
        'seq_len': 2048,
        'd_head': 128,
    }

    for _causal in [False, True]:
        for is_bwd in [False, True]:
            logger.log(f'{"Causal" if _causal else "Non-causal"}{" Backward" if is_bwd else ""}', logger.Text.title)
            measure_performance('flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)
            measure_performance('triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)


if __name__ == "__main__":
    main()