変圧器用ユーティリティ

10import torch

後続マスクにより、将来の(後続の)タイムステップからデータを隠すことができます

13def subsequent_mask(seq_len):
17    mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)
18    return mask
21def _subsequent_mask():
22    from labml.logger import inspect
23    inspect(subsequent_mask(10)[:, :, 0])
24
25
26if __name__ == '__main__':
27    _subsequent_mask()