def _time_optimizer_steps(name, optimizer, device, warmup_steps, timed_steps):
    """Run ``optimizer.step()`` in a warmup section, then in a measured section.

    Args:
        name: Label for the measured ``monit.section``; the warmup section is
            labelled ``'<name> warmup'``.
        optimizer: Optimizer whose ``step()`` is called repeatedly. Gradients are
            not recomputed here, so every step reuses whatever is already stored
            in the parameters' ``.grad``.
        device: ``torch.device`` the benchmarked tensors live on. It is used only
            to decide whether a CUDA synchronization barrier is needed.
        warmup_steps: Number of steps run in the warmup section.
        timed_steps: Number of steps run in the measured section.
    """
    def sync():
        # CUDA kernels are launched asynchronously. Without a barrier, a section
        # would close as soon as its kernels were *queued*, and the unfinished
        # GPU work would be billed to whichever section runs next.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    # Drain work queued before this benchmark (e.g. the backward pass) so it is
    # not attributed to the warmup section.
    sync()
    for section, n_steps in ((f'{name} warmup', warmup_steps), (name, timed_steps)):
        with monit.section(section):
            for _ in range(n_steps):
                optimizer.step()
            sync()


def test(warmup_steps: int = 100, timed_steps: int = 1000):
    """Compare the step time of ``MyAdam`` against ``TorchAdam`` on one model.

    Builds a ``Model`` on device 0 (requested via ``DeviceInfo(use_cuda=True,
    cuda_device=0)``) and feeds it a random batch of shape (64, 1, 28, 28) with
    all-ones class targets. It runs one forward/backward pass to populate the
    gradients, then times repeated ``step()`` calls of each optimizer inside
    ``monit.section`` blocks.

    Args:
        warmup_steps: Steps run per optimizer before the measured section
            (default 100, as originally hard-coded).
        timed_steps: Steps run per optimizer in the measured section
            (default 1000, as originally hard-coded).
    """
    device_info = DeviceInfo(use_cuda=True, cuda_device=0)
    print(device_info)
    inp = torch.randn((64, 1, 28, 28), device=device_info.device)
    target = torch.ones(64, dtype=torch.long, device=device_info.device)
    loss_func = nn.CrossEntropyLoss()
    model = Model().to(device_info.device)
    # Both optimizers wrap the SAME parameters. MyAdam runs first and mutates
    # them, so TorchAdam starts from the updated values. This presumably does
    # not change the per-step cost being measured.
    my_adam = MyAdam(model.parameters())
    torch_adam = TorchAdam(model.parameters())
    # A single forward/backward populates .grad once; every timed step reuses it.
    loss = loss_func(model(inp), target)
    loss.backward()
    # Take the device from an actual tensor so the CUDA check is reliable
    # regardless of how DeviceInfo represents its device.
    device = inp.device
    _time_optimizer_steps('MyAdam', my_adam, device, warmup_steps, timed_steps)
    _time_optimizer_steps('TorchAdam', torch_adam, device, warmup_steps, timed_steps)
52
53
if __name__ == "__main__":
    # Run the optimizer benchmark only when executed as a script, not on import.
    test()