12import argparse
13
14import torch
15from torch import nn
16
17from labml_nn.neox.evaluation import run_eval_harness
18from labml_nn.neox.model import LayerGenerator
def main():
    """Parse command-line flags and assemble a float16 model on ``cuda:0``.

    The only flag is ``--flash``; when given, the layers are requested from
    ``LayerGenerator`` with ``is_flash_attention=True``.
    """
    # Argument parser
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--flash", action='store_true', help="whether to use Flash Attention")
    args = arg_parser.parse_args()

    # Device
    device = torch.device('cuda:0')

    # Load layers
    generator = LayerGenerator(
        is_clone_layers=True,
        filter_layers=None,
        dtype=torch.float16,
        device=device,
        is_flash_attention=args.flash,
    )
    layers = list(generator.load())

    # Create the `nn.Sequential` model from the loaded layers
    model = nn.Sequential(*layers)
    # NOTE(review): within the visible lines `model` is built but not used, and
    # the file-level import `run_eval_harness` is never called. The statements
    # that consume the model are presumably in lines missing from this view --
    # confirm against the full file.
# Script entry point: call `main()` only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()