これは、論文「ノイズ除去拡散確率モデル」のPyTorch実装/チュートリアルです。
簡単に言うと、データから画像を取得し、段階的にノイズを追加します。次に、モデルをトレーニングして各ステップでそのノイズを予測し、そのモデルを使用して画像を生成します。
以下の定義と派生は、これがどのように機能するかを示しています。詳しくは論文を参照してください。
フォワード処理では、タイムステップのノイズがデータに追加されます。
差異スケジュールはどこですか。
以下を使用すると、任意のタイムステップでサンプリングできます。
どこと
逆の処理では、4 つのタイムステップからノイズを除去します。
私たちがトレーニングするパラメータです。
(ジェンソンの不等式から)ELBOを負の対数確率で最適化します。
損失は次のように書き直すことができます。
私たちは一定なので一定です。
次の条件で条件付けされたフォワードプロセスの事後処理は
用紙セットは定数またはに設定されています。
次に、
与えられたノイズに対して
これにより、
ノイズを予測するためのモデルによる再パラメータ化
与えられたものを予測する学習済み関数はどこか
これにより、
つまり、ノイズを予測するためのトレーニングを行っています。
これにより、ウェイトを廃棄するタイミングと廃棄するタイミングを最小限に抑えることができます。重みを捨てると、(ノイズレベルが高い)高いほうに与えられる重みが増え、サンプルの品質が向上します
。このファイルには、トレーニング中に画像を生成するために使用する損失計算と基本的なサンプリング方法が実装されています。
これがコードを提供してトレーニングするUNetモデルです。このファイルでは、トレーニング済みのモデルからサンプルと補間を生成できます
。162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather
172class DenoiseDiffusion:
eps_model
モデルですn_steps
は device
定数を配置するデバイスです177 def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
183 super().__init__()
184 self.eps_model = eps_model
直線的に増加する差異スケジュールの作成
187 self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
190 self.alpha = 1. - self.beta
192 self.alpha_bar = torch.cumprod(self.alpha, dim=0)
194 self.n_steps = n_steps
196 self.sigma2 = self.beta
198 def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
210 var = 1 - gather(self.alpha_bar, t)
212 return mean, var
214 def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
224 if eps is None:
225 eps = torch.randn_like(x0)
取得
228 mean, var = self.q_xt_x0(x0, t)
からのサンプル
230 return mean + (var ** 0.5) * eps
232 def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
246 eps_theta = self.eps_model(xt, t)
250 alpha = gather(self.alpha, t)
252 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
255 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
257 var = gather(self.sigma2, t)
260 eps = torch.randn(xt.shape, device=xt.device)
[サンプル]
262 return mean + (var ** .5) * eps
264 def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
バッチサイズを取得
273 batch_size = x0.shape[0]
バッチ内の各サンプルをランダムに取得
275 t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
278 if noise is None:
279 noise = torch.randn_like(x0)
のサンプル
282 xt = self.q_sample(x0, t, eps=noise)
取得
284 eps_theta = self.eps_model(xt, t)
MSE ロス
287 return F.mse_loss(noise, eps_theta)