ノイズ除去拡散確率モデル (DDPM)

Open In Colab

これは、論文「ノイズ除去拡散確率モデル」のPyTorch実装/チュートリアルです

簡単に言うと、データから画像を取得し、段階的にノイズを追加します。次に、モデルをトレーニングして各ステップでそのノイズを予測し、そのモデルを使用して画像を生成します。

以下の定義と派生は、これがどのように機能するかを示しています。詳しくは論文を参照してください

転送プロセス

フォワード処理ではタイムステップのノイズがデータに追加されます。

差異スケジュールはどこですか。

以下を使用すると、任意のタイムステップでサンプリングできます。

どこと

リバースプロセス

逆の処理では、4 つのタイムステップからノイズを除去します。

私たちがトレーニングするパラメータです。

損失

(ジェンソンの不等式から)ELBOを負の対数確率で最適化します。

損失は次のように書き直すことができます。

私たちは一定なので一定です。

コンピューティング

次の条件で条件付けされたフォワードプロセスの事後処理は

用紙セットは定数またはに設定されています。

次に、

与えられたノイズに対して

これにより、

ノイズを予測するためのモデルによる再パラメータ化

与えられたものを予測する学習済み関数はどこか

これにより、

つまり、ノイズを予測するためのトレーニングを行っています。

簡易損失

これにより、ウェイトを廃棄するタイミングと廃棄するタイミングを最小限に抑えることができます。重みを捨てると、(ノイズレベルが高い)高いほうに与えられる重みが増え、サンプルの品質が向上します

このファイルには、トレーニング中に画像を生成するために使用する損失計算と基本的なサンプリング方法が実装されています。

これがコードを提供してトレーニングするUNetモデルですこのファイルでは、トレーニング済みのモデルからサンプルと補間を生成できます

162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather

ノイズ除去

172class DenoiseDiffusion:
  • eps_model モデルです
  • n_steps
  • device 定数を配置するデバイスです
177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
183        super().__init__()
184        self.eps_model = eps_model

直線的に増加する差異スケジュールの作成

187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

190        self.alpha = 1. - self.beta

192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

194        self.n_steps = n_steps

196        self.sigma2 = self.beta

ディストリビューションを取得

198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

収集して計算する

208        mean = gather(self.alpha_bar, t) ** 0.5 * x0

210        var = 1 - gather(self.alpha_bar, t)

212        return mean, var

からのサンプル

214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

224        if eps is None:
225            eps = torch.randn_like(x0)

取得

228        mean, var = self.q_xt_x0(x0, t)

からのサンプル

230        return mean + (var ** 0.5) * eps

からのサンプル

232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

246        eps_theta = self.eps_model(xt, t)

集まる

248        alpha_bar = gather(self.alpha_bar, t)

250        alpha = gather(self.alpha, t)

252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

257        var = gather(self.sigma2, t)

260        eps = torch.randn(xt.shape, device=xt.device)

[サンプル]

262        return mean + (var ** .5) * eps

簡易損失

264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

バッチサイズを取得

273        batch_size = x0.shape[0]

バッチ内の各サンプルをランダムに取得

275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

278        if noise is None:
279            noise = torch.randn_like(x0)

のサンプル

282        xt = self.q_sample(x0, t, eps=noise)

取得

284        eps_theta = self.eps_model(xt, t)

MSE ロス

287        return F.mse_loss(noise, eps_theta)