DDPM 实验的实用程序函数

10import torch.utils.data

收集要素地图形状的常量并将其整形为要素地图形状

13def gather(consts: torch.Tensor, t: torch.Tensor):
15    c = consts.gather(-1, t)
16    return c.reshape(-1, 1, 1, 1)