Source code for zae_engine.examples.ddpm_mnist

import math
import os.path
import shutil
from typing import Dict, Any, Union, Type, Sequence, Callable, Tuple, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

from zae_engine.data import CollateBase
from zae_engine.models import AutoEncoder
from zae_engine.schedulers import CosineAnnealingScheduler, WarmUpScheduler, SchedulerChain
from zae_engine.nn_night.blocks import UNetBlock
from zae_engine.trainer import Trainer


class CustomMNISTDataset(Dataset):
    """
    Custom Dataset class for MNIST data.

    This class wraps the torchvision MNIST dataset and returns a dictionary containing the image.

    Parameters
    ----------
    root : str
        Root directory of dataset where MNIST exists or will be saved.
    train : bool, optional
        If True, creates dataset from training set, otherwise from test set.
    transform : callable, optional
        A function/transform that takes in an image and returns a transformed version.
    download : bool, optional
        If True, downloads the dataset from the internet and puts it in root directory.
    """

    def __init__(self, root, train=True, transform=None, download=False):
        self.mnist = datasets.MNIST(root=root, train=train, transform=transform, download=download)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        image, label = self.mnist[idx]
        return {"pixel_values": image}

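# Illustrative usage (a sketch added for this page, not part of the original
# module): each item is a dict with a single "pixel_values" tensor, which is
# the key the collator and ForwardDiffusion below expect.
#
#   >>> ds = CustomMNISTDataset(root="./mnist_example", train=True,
#   ...                         transform=transforms.ToTensor(), download=True)
#   >>> sorted(ds[0].keys())
#   ['pixel_values']
#   >>> ds[0]["pixel_values"].shape
#   torch.Size([1, 28, 28])
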
class TimestepEmbedding(nn.Module):
    """
    Sinusoidal embedding for timesteps.

    This module generates sinusoidal embeddings for each timestep, similar to
    positional encodings used in Transformers.

    Parameters
    ----------
    embed_dim : int
        Dimension of the embedding vector.
    """

    def __init__(self, embed_dim):
        super(TimestepEmbedding, self).__init__()
        self.embed_dim = embed_dim
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, t: Tensor) -> Tensor:
        """
        Forward pass for timestep embedding.

        Parameters
        ----------
        t : Tensor
            Timestep tensor of shape (batch_size,) or higher.

        Returns
        -------
        Tensor
            Embedded timestep tensor of shape (batch_size, embed_dim).
        """
        if t.dim() > 1:
            t = t.view(-1)  # Ensure t is (batch_size,)
        half_dim = self.embed_dim // 2
        emb_scale = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -emb_scale)
        emb = t[:, None].float() * emb[None, :]  # (batch_size, half_dim)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (batch_size, embed_dim)
        emb = self.linear(emb)  # (batch_size, embed_dim)
        return emb

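# A minimal sanity check (illustrative, not part of the original source): the
# embedding maps a batch of integer timesteps to (batch_size, embed_dim),
# using the same sin/cos scheme as Transformer positional encodings.
#
#   >>> emb = TimestepEmbedding(embed_dim=256)
#   >>> t = torch.randint(0, 1000, (8,))
#   >>> emb(t).shape
#   torch.Size([8, 256])
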
class NoiseScheduler:
    """
    Scheduler for managing the noise levels in DDPM.

    This class defines the noise schedule used in the diffusion process,
    supporting both linear and cosine schedules.

    Parameters
    ----------
    timesteps : int, optional
        Total number of diffusion steps. Default is 1000.
    schedule : str, optional
        Type of noise schedule ('linear', 'cosine'). Default is 'linear'.
    beta_start : float, optional
        Starting value of beta. Default is 1e-4.
    beta_end : float, optional
        Ending value of beta. Default is 0.02.

    Attributes
    ----------
    beta : torch.Tensor
        Noise levels for each timestep.
    alpha : torch.Tensor
        1 - beta for each timestep.
    alpha_bar : torch.Tensor
        Cumulative product of alpha up to each timestep.
    sqrt_alpha_bar : torch.Tensor
        Square root of alpha_bar.
    sqrt_one_minus_alpha_bar : torch.Tensor
        Square root of (1 - alpha_bar).
    posterior_variance : torch.Tensor
        Variance used in the posterior distribution.
    timesteps : int
        Total number of diffusion steps.
    """

    def __init__(self, timesteps=1000, schedule="linear", beta_start=1e-4, beta_end=0.02):
        self.timesteps = timesteps
        if schedule == "linear":
            self.beta = self.linear_beta_schedule(timesteps, beta_start, beta_end)
        elif schedule == "cosine":
            self.beta = self.cosine_beta_schedule(timesteps)
        else:
            raise NotImplementedError(f"Schedule '{schedule}' is not implemented.")
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

        # Precompute terms for efficiency
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)
        self.posterior_variance = self.beta[1:] * (1 - self.alpha_bar[:-1]) / (1 - self.alpha_bar[1:])

    def linear_beta_schedule(self, timesteps: int, beta_start: float, beta_end: float) -> torch.Tensor:
        """
        Linear schedule for beta.

        Parameters
        ----------
        timesteps : int
            Total number of diffusion steps.
        beta_start : float
            Starting value of beta.
        beta_end : float
            Ending value of beta.

        Returns
        -------
        torch.Tensor
            Beta schedule.
        """
        return torch.linspace(beta_start, beta_end, timesteps)

    def cosine_beta_schedule(self, timesteps: int, s: float = 0.008) -> torch.Tensor:
        """
        Cosine schedule for beta as proposed in https://arxiv.org/abs/2102.09672.

        Parameters
        ----------
        timesteps : int
            Total number of diffusion steps.
        s : float, optional
            Small offset to prevent beta from being exactly 0. Default is 0.008.

        Returns
        -------
        torch.Tensor
            Beta schedule.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        # Taking the ratio over all (timesteps + 1) points yields exactly
        # `timesteps` betas; truncating alphas_cumprod beforehand would leave
        # beta one element short.
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = torch.clamp(betas, min=1e-4, max=0.999)
        return betas

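    # For reference (a note added for this page, not part of the original
    # source), the method above implements the schedule from Nichol & Dhariwal
    # (2021), "Improved Denoising Diffusion Probabilistic Models":
    #
    #   alpha_bar(t) = f(t) / f(0),  f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2
    #   beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clamped away from 0 and 1
    #
    # The small offset s keeps beta_t from vanishing near t = 0.
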
    def get_sigma(self, t: torch.Tensor, ddim: bool = False) -> torch.Tensor:
        """
        Get the sigma value for a given timestep.

        Parameters
        ----------
        t : torch.Tensor
            Timestep tensor.
        ddim : bool, optional
            Whether to use DDIM sampling. If True, uses sigma=0 for deterministic sampling.

        Returns
        -------
        torch.Tensor
            Sigma value for the given timestep.
        """
        if ddim:
            return torch.zeros_like(self.beta[t])  # Deterministic sampling for DDIM
        else:
            return torch.sqrt(self.posterior_variance[t])

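# Illustrative usage (a sketch added for this page, not part of the original
# module): with the linear defaults, beta ramps from 1e-4 to 0.02 over 1000
# steps and alpha_bar decays towards 0, so x_T is effectively pure noise.
#
#   >>> sched = NoiseScheduler(timesteps=1000, schedule="linear")
#   >>> sched.beta.shape
#   torch.Size([1000])
#   >>> bool(sched.alpha_bar[-1] < 1e-4)
#   True
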
class ForwardDiffusion:
    """
    Class for performing the forward diffusion process by adding noise to the input data.

    This class adds noise to the input data at a randomly sampled timestep, producing the
    noised data `x_t`, along with the original data `x0`, the timestep `t`, and the noise added.

    Parameters
    ----------
    noise_scheduler : NoiseScheduler
        Instance of NoiseScheduler managing the noise levels.
    """

    def __init__(self, noise_scheduler: NoiseScheduler):
        self.noise_scheduler = noise_scheduler

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Applies forward diffusion by adding noise to the input data at a random timestep.

        Parameters
        ----------
        batch : Dict[str, Any]
            The input batch containing data under 'pixel_values'.

        Returns
        -------
        Dict[str, Any]
            The batch with added noise, original data, and timestep information. Contains:
            - 'x_t': Noised data.
            - 'x0': Original data.
            - 't': Timestep.
            - 'noise': Added noise.
        """
        key = "pixel_values"
        if key not in batch:
            raise KeyError(f"Batch must contain '{key}' key.")
        origin = batch[key]  # origin has shape [channel, height, width]

        # Sample random timestep
        t = torch.randint(0, self.noise_scheduler.timesteps, (1,)).long()
        noise = torch.randn_like(origin)

        # Calculate x_t
        sqrt_alpha_bar_t = self.noise_scheduler.sqrt_alpha_bar[t].view(1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.noise_scheduler.sqrt_one_minus_alpha_bar[t].view(1, 1, 1)
        x_t = sqrt_alpha_bar_t * origin + sqrt_one_minus_alpha_bar_t * noise

        batch["x_t"] = x_t
        batch["x0"] = origin
        batch["t"] = t
        batch["noise"] = noise
        return batch

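# For reference (a note added for this page, not part of the original source),
# __call__ above uses the closed-form forward process from Ho et al. (2020):
#
#   q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
#
# so a training pair (x_t, eps) for any timestep t is drawn in a single step,
# without simulating the t intermediate noising steps.
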
class DDPM(AutoEncoder):
    """
    Denoising Diffusion Probabilistic Model (DDPM) implemented as an AutoEncoder.

    This model integrates timestep embeddings into the bottleneck of the AutoEncoder
    architecture, allowing the model to condition on the diffusion timestep during training.

    Parameters
    ----------
    block : Type[Union[UNetBlock, nn.Module]]
        The block type to use in the AutoEncoder (e.g., UNetBlock).
    ch_in : int
        Number of input channels.
    ch_out : int
        Number of output channels.
    width : int
        Base width of the network.
    layers : Sequence[int]
        Number of layers in each block.
    groups : int, optional
        Number of groups for group normalization, by default 1.
    dilation : int, optional
        Dilation rate for convolutions, by default 1.
    norm_layer : Callable[..., nn.Module], optional
        Normalization layer to use, by default nn.BatchNorm2d.
    skip_connect : bool, optional
        Whether to use skip connections, by default False.
    timestep_embed_dim : int, optional
        Dimension of the timestep embedding, by default 256.

    Attributes
    ----------
    timestep_embedding : TimestepEmbedding
        Module for generating timestep embeddings.
    t_embed_proj : nn.Linear
        Linear layer to project timestep embeddings to match the bottleneck dimensions.
    """

    def __init__(
        self,
        block: Type[Union[UNetBlock, nn.Module]],
        ch_in: int,
        ch_out: int,
        width: int,
        layers: Sequence[int],
        groups: int = 1,
        dilation: int = 1,
        norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d,
        skip_connect: bool = False,
        timestep_embed_dim: int = 256,  # Dimension of the timestep embedding
    ):
        super(DDPM, self).__init__(block, ch_in, ch_out, width, layers, groups, dilation, norm_layer, skip_connect)
        # Timestep embedding module
        self.timestep_embedding = TimestepEmbedding(timestep_embed_dim)
        # Additional layer to project timestep embeddings to match bottleneck dimensions
        self.t_embed_proj = nn.Linear(timestep_embed_dim, width * 16)
        print(f"DDPM initialized with timestep_embed_dim={timestep_embed_dim}")

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for DDPM.

        Parameters
        ----------
        x : torch.Tensor
            Noised image (x_t) of shape (batch_size, channels, height, width).
        t : torch.Tensor
            Timestep tensor of shape (batch_size,).

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor of shape (batch_size, channels, height, width).
        """
        # Ensure t is (batch_size,)
        if t.dim() > 1:
            t = t.view(-1)

        self.feature_vectors = []
        # Forward through the encoder and collect feature vectors via hooks
        _ = self.encoder(x)
        if not self.feature_vectors:
            raise ValueError("No feature vectors collected from encoder.")
        feat = self.bottleneck(self.feature_vectors.pop())

        # Timestep embedding
        t_emb = self.timestep_embedding(t)  # Shape: (batch_size, embed_dim)
        t_emb = self.t_embed_proj(t_emb)  # Shape: (batch_size, width * 16)
        t_emb = t_emb[:, :, None, None]  # Shape: (batch_size, width * 16, 1, 1)
        feat = feat + t_emb  # Broadcasting addition

        # Decoder with skip connections if enabled
        for up_pool, dec in zip(self.up_pools, self.decoder):
            feat = up_pool(feat)
            if self.skip_connect and len(self.feature_vectors) > 0:
                feat = torch.cat((feat, self.feature_vectors.pop()), dim=1)
            feat = dec(feat)

        output = self.sig(self.fc(feat))
        return output

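# Illustrative shape check (a sketch added for this page, not part of the
# original module), matching the configuration used in the __main__ block
# below; output shape should mirror the input since the model predicts noise:
#
#   >>> net = DDPM(block=UNetBlock, ch_in=1, ch_out=1, width=8,
#   ...            layers=[1, 1, 1, 1], skip_connect=True, timestep_embed_dim=256)
#   >>> x_t = torch.randn(4, 1, 64, 64)
#   >>> t = torch.randint(0, 1000, (4,))
#   >>> net(x_t, t).shape
#   torch.Size([4, 1, 64, 64])
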
class DDPMTrainer(Trainer):
    """
    Trainer class specialized for DDPM training and sampling.

    Inherits from the abstract Trainer class and implements the train_step and test_step
    methods. Additionally, it includes methods for generating and visualizing samples
    using the trained model.
    """

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Perform a training step for DDPM.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            A batch of data containing 'x_t', 'x0', 't', 'noise'.

        Returns
        -------
        Dict[str, torch.Tensor]
            A dictionary containing the loss and the model's output.
        """
        x_t = batch["x_t"]
        t = batch["t"]
        noise = batch.get("noise", None)
        output = self.model(x_t, t)  # Predicted noise
        loss = nn.MSELoss()(output, noise) if noise is not None else None
        return {"loss": loss, "output": output}

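    # For reference (a note added for this page, not part of the original
    # source), this is the "simple" DDPM objective from Ho et al. (2020):
    #
    #   L_simple = E_{t, x_0, eps} [ || eps - eps_theta(x_t, t) ||^2 ]
    #
    # i.e. the network is trained to predict the noise injected by
    # ForwardDiffusion, not to reconstruct x_0 directly.
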
    def test_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Perform a testing step for DDPM.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            A batch of data containing 'x_t', 'x0', 't', 'noise'.

        Returns
        -------
        Dict[str, torch.Tensor]
            A dictionary containing the loss and the model's output.
        """
        return self.train_step(batch=batch)

    def noise_scheduling(self, noise_scheduler: NoiseScheduler) -> None:
        """
        Update the noise scheduler used by the trainer.

        Parameters
        ----------
        noise_scheduler : NoiseScheduler
            New noise scheduler to be used.
        """
        self.noise_scheduler = noise_scheduler

    def generate(
        self, n_samples: int, channels: int, height: int, width: int, intermediate: int = 0, ddim: bool = False
    ) -> Tuple[Tensor, List[Any]]:
        """
        Generate new samples using the trained diffusion model.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate.
        channels : int
            Number of channels in the generated images.
        height : int
            Height of the generated images.
        width : int
            Width of the generated images.
        intermediate : int, optional
            Number of intermediate samples to save during generation, by default 0.
        ddim : bool, optional
            Whether to use DDIM sampling. If True, uses DDIM; otherwise, uses DDPM.

        Returns
        -------
        Tuple[Tensor, List[Any]]
            Generated samples tensor of shape (n_samples, channels, height, width).
            List of intermediate samples if specified.
        """
        timesteps = self.noise_scheduler.timesteps
        alpha = self.noise_scheduler.alpha
        alpha_bar = self.noise_scheduler.alpha_bar
        posterior_variance = self.noise_scheduler.posterior_variance

        # Initialize with standard normal noise
        x = torch.randn(n_samples, channels, height, width)
        self.toggle("test")

        save_step = timesteps // intermediate if intermediate > 0 else 0
        save_x = []
        for t_step in reversed(range(timesteps)):
            t_tensor = torch.full((n_samples,), t_step, dtype=torch.long)
            batch = {"x_t": x, "t": t_tensor}
            self.run_batch(batch)
            predict = self.log_test["output"][0]

            if ddim:
                # Deterministic DDIM update (eta = 0) using the cumulative alpha_bar terms
                if t_step > 0:
                    sqrt_alpha_bar = torch.sqrt(alpha_bar[t_step])
                    sqrt_alpha_bar_prev = torch.sqrt(alpha_bar[t_step - 1])
                    x_prev = (x - self.noise_scheduler.sqrt_one_minus_alpha_bar[t_step] * predict) / sqrt_alpha_bar
                    x_prev = (
                        sqrt_alpha_bar_prev * x_prev
                        + self.noise_scheduler.sqrt_one_minus_alpha_bar[t_step - 1] * predict
                    )
                else:
                    # Final timestep (t=0) handling
                    x_prev = (x - self.noise_scheduler.sqrt_one_minus_alpha_bar[t_step] * predict) / torch.sqrt(
                        alpha_bar[t_step]
                    )
                x = x_prev
            else:
                # DDPM sampling
                t_noise = self.noise_scheduler.beta[t_step]
                if t_step > 0:
                    sqrt_recip_alpha = 1 / torch.sqrt(alpha[t_step])
                    mu_theta = sqrt_recip_alpha * (
                        x - (t_noise / self.noise_scheduler.sqrt_one_minus_alpha_bar[t_step]) * predict
                    )
                    # Sample from the posterior
                    noise = torch.randn_like(x)
                    var = torch.sqrt(posterior_variance[t_step - 1]).view(-1, 1, 1, 1)
                    x = mu_theta + var * noise
                else:
                    x = (x - (t_noise / self.noise_scheduler.sqrt_one_minus_alpha_bar[t_step]) * predict) / torch.sqrt(
                        alpha[t_step]
                    )

            if intermediate > 0 and t_step % save_step == save_step - 1:
                save_x.append(x.clone())

        return x, save_x

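    # For reference (a note added for this page, not part of the original
    # source), the DDPM branch above implements the ancestral sampling step
    # from Ho et al. (2020), Algorithm 2:
    #
    #   x_{t-1} = (1 / sqrt(alpha_t))
    #             * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * eps_theta(x_t, t))
    #             + sigma_t * z,  z ~ N(0, I)
    #
    # with sigma_t^2 taken from the precomputed posterior variance, while the
    # DDIM branch drops the stochastic term for deterministic sampling.
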
    def visualize_samples(
        self,
        final_samples: torch.Tensor,
        intermediate_images: List[Any] = None,
        train_losses: List[float] = None,
        valid_losses: List[float] = None,
        lr_history: List[float] = None,
    ) -> None:
        """
        Visualize generated samples and training progress.

        Parameters
        ----------
        final_samples : torch.Tensor
            Final generated samples. Shape: (n_samples, channels, height, width).
        intermediate_images : List[Any], optional
            Intermediate images for selected samples.
            Shape: (num_selected, channels, height, width), by default None.
        train_losses : List[float], optional
            Training loss history, by default None.
        valid_losses : List[float], optional
            Validation loss history, by default None.
        lr_history : List[float], optional
            Learning rate history, by default None.
        """
        fig = plt.figure(figsize=(24, 18))
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.2)

        # Upper left (0:2, 0:2) - 2x2 grid region: 16 generated images in a 4x4 grid
        ax1 = fig.add_subplot(gs[0:2, 0:2])
        grid_img = make_grid(final_samples, nrow=4, padding=2, normalize=True)
        grid_img_np = grid_img.permute(1, 2, 0).cpu().numpy()
        ax1.imshow(grid_img_np)
        ax1.set_title("Generated Images (4x4 Grid)")
        ax1.axis("off")

        # Upper right (0:2, 2:4) - 2x2 grid region: intermediate outputs of selected images
        if intermediate_images is not None and len(intermediate_images) > 0:
            # Merge intermediate outputs into a single grid
            intermediate_grids = []
            for img in intermediate_images:
                # Accept either a (steps, C, H, W) tensor or a list of (C, H, W) tensors
                steps = torch.stack(list(img), dim=0)
                steps_grid = make_grid(steps, nrow=1, padding=2, normalize=True)
                intermediate_grids.append(steps_grid)
            # Merge multiple grids into a single image
            intermediate_grids = torch.stack(intermediate_grids, dim=0)  # (num_samples, C, H, W)
            intermediate_grids = make_grid(intermediate_grids, nrow=len(intermediate_grids), padding=1, normalize=True)
            intermediate_grids_np = intermediate_grids.permute(1, 2, 0).cpu().numpy()
            ax2 = fig.add_subplot(gs[0:2, 2:4])
            ax2.imshow(intermediate_grids_np)
            ax2.set_title("Intermediate Steps of Selected Images")
            ax2.axis("off")

        # Lower (2, 0:4): line chart of train/valid loss & learning-rate history
        ax3 = fig.add_subplot(gs[2, :])
        if train_losses is not None:
            ax3.plot(train_losses, label="Train Loss", color="blue")
        if valid_losses is not None:
            ax3.plot(valid_losses, label="Valid Loss", color="orange")
        if lr_history is not None:
            ax4 = ax3.twinx()  # Create a second y-axis
            ax4.plot(lr_history, label="Learning Rate", color="green")
            ax4.set_ylabel("Learning Rate", color="green")
            ax4.tick_params(axis="y", labelcolor="green")
            ax4.legend(loc="upper right")
        ax3.set_title("Training and Validation Loss with Learning Rate")
        ax3.set_xlabel("Epoch")
        ax3.set_ylabel("Loss")
        ax3.legend(loc="upper left")
        ax3.grid(True)

        plt.tight_layout()
        plt.show()

if __name__ == "__main__":
    # Config
    DDIM = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 32
    epoch = 100
    learning_rate = 5e-3
    target_width = target_height = 64
    data_path = "./mnist_example"

    # NoiseScheduler & ForwardDiffusion
    noise_scheduler = NoiseScheduler()
    forward_diffusion = ForwardDiffusion(noise_scheduler=noise_scheduler)
    print("ForwardDiffusion initialized.")

    transform = transforms.Compose(
        [
            transforms.Resize((target_height, target_width)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    print("Transform defined.")

    dataset = CustomMNISTDataset(root=data_path, train=True, transform=transform, download=True)
    print("CustomMNISTDataset initialized.")

    collator = CollateBase(x_key=["pixel_values"], y_key=[], aux_key=[])
    collator.add_fn(name="forward_diffusion", fn=forward_diffusion)

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collator.wrap())
    print("DataLoader initialized.")

    model = DDPM(
        block=UNetBlock, ch_in=1, ch_out=1, width=8, layers=[1, 1, 1, 1], skip_connect=True, timestep_embed_dim=256
    )
    print("Model defined.")

    # Optimizer & Scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    warm_up_steps = int(0.1 * len(train_loader) * epoch)
    scheduler = SchedulerChain(
        WarmUpScheduler(optimizer=optimizer, total_iters=warm_up_steps),
        CosineAnnealingScheduler(optimizer=optimizer, total_iters=len(train_loader) * epoch - warm_up_steps),
    )
    print("Optimizer and scheduler initialized.")

    # Trainer
    trainer = DDPMTrainer(
        model=model,
        device=device,
        mode="train",
        optimizer=optimizer,
        scheduler=scheduler,
        log_bar=True,
        scheduler_step_on_batch=True,
        gradient_clip=0.0,
    )
    trainer.noise_scheduling(noise_scheduler)
    print("Trainer initialized.")

    # Run training
    print("Starting training...")
    trainer.run(n_epoch=epoch, loader=train_loader, valid_loader=None)
    trainer.save_model(os.path.join("../../ddpm_model.pth"))
    print("Training completed.")

    train_loss = trainer.log_train.get("loss", None)
    valid_loss = trainer.log_test.get("loss", None)

    # Generate and visualize samples
    print("Generating samples...")
    trainer.toggle("test")
    generated = trainer.generate(
        n_samples=16,
        channels=1,
        height=target_height,
        width=target_width,
        intermediate=4,
        ddim=DDIM,
    )
    generated_samples, generated_intermediate_samples = generated

    # Handle intermediate outputs
    generated_intermediate_samples = torch.stack([inter[:4] for inter in generated_intermediate_samples])

    trainer.visualize_samples(
        final_samples=generated_samples,
        intermediate_images=generated_intermediate_samples.permute(1, 0, 2, 3, 4),
        train_losses=trainer.get_loss_history("train"),
        valid_losses=None,
    )
    print("Sample generation and visualization completed.")
    shutil.rmtree(data_path)