Image Diffusion with PyTorch

Xiwei Pan / 2025-01-27


Preface #

In my previous blogs, I discussed the theoretical derivation of diffusion models, so I’d like to take this opportunity to walk through the code implementation of Denoising Diffusion Probabilistic Models (DDPM) and explore the logic behind building deep learning networks in PyTorch.

The code I use to investigate diffusion models was originally released by dome272, so the associated code and data resources are easily accessible through the linked GitHub page. Here, I hope to achieve two goals: first, by studying the diffusion model code in detail, to deepen my understanding of Python programming and the key PyTorch libraries; second, to integrate the conclusions about DDPM network construction from my previous blogs into a cohesive framework, so that I and other readers can move directly to the practical implementation details rather than getting lost in lengthy derivations.

Forward Diffusion #

According to the second restriction for an MHVAE to qualify as a VDM, the encoder structure must be pre-defined as a Gaussian distribution centered around the output of the previous timestep, with mean $\pmb{\mu}_t(\pmb{x}_t)=\sqrt{\alpha_t}\pmb{x}_{t-1}$ and variance $\pmb{\Sigma}_t(\pmb{x}_t)=(1-\alpha_t)\pmb{\mathrm{I}}$. Mathematically, this is given by $$q(\pmb{x}_t|\pmb{x}_{t-1})=\mathcal{N}(\pmb{x}_t; \sqrt{\alpha_t}\pmb{x}_{t-1},(1-\alpha_t)\pmb{\mathrm{I}}).\tag{1} \label{eq1}$$ Through the reparameterization trick, a sample at timestep $t$ ($\pmb{x}_t\sim q(\pmb{x}_t|\pmb{x}_{t-1})$) can be related to a standard Gaussian $\pmb{\epsilon}$, that is $$\pmb{x}_t=\sqrt{\alpha_t}\pmb{x}_{t-1}+\sqrt{1-\alpha_t}\pmb{\epsilon},\quad\pmb{\epsilon}\sim\mathcal{N}(\pmb{\epsilon};\pmb{0},\pmb{\mathrm{I}}). \tag{2} \label{eq2}$$

Note that the notation $\beta_t=1-\alpha_t$, commonly referred to as the “diffusion coefficient”, is also widely used in both papers and code. In the original DDPM paper, $\beta_t$ increases linearly from $\beta_1=10^{-4}$ to $\beta_T=0.02$, with $T=1000$ diffusion steps.

The diffused image $q(\pmb{x}_t|\pmb{x}_0)$ can now be derived recursively using the reparameterization trick: $$\pmb{x}_t=\sqrt{\bar{\alpha}_t}\pmb{x}_0+\sqrt{1-\bar{\alpha}_t}\pmb{\epsilon}\sim\mathcal{N}\left(\pmb{x}_t;\sqrt{\bar{\alpha}_t}\pmb{x}_0,(1-\bar{\alpha}_t)\pmb{\mathrm{I}}\right), \tag{3} \label{eq3}$$ where $\bar{\alpha}_t=\prod_{i=1}^t\alpha_i$. Since $\alpha_t$ decreases as $\beta_t$ increases, as $t\to T$ (typically a large number of steps) we have $\bar{\alpha}_T\to 0$, so the resulting image $\pmb{x}_T\to \pmb{\epsilon}$, i.e., pure Gaussian noise. The image diffusion process is illustrated in Fig. 1, and the corresponding Python code is provided below.

Figure 1: Progressively diffuse an image using random noise.

  import torch
  from torchvision.utils import make_grid, save_image
  from torchvision import transforms
  from PIL import Image

  beta = torch.linspace(0.02, 0.1, 1000).double()  # a steeper schedule than the DDPM default (1e-4 to 0.02) so the noising is clearly visible within the 15 steps plotted below
  alpha = 1. - beta
  alpha_bar = torch.cumprod(alpha, dim=0)
  sqrt_alpha_bar = torch.sqrt(alpha_bar)
  sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar)

  img = Image.open('car.jpg')  # read image
  trans = transforms.Compose([
      transforms.Resize(224),  # resize the shorter side of the input image to 224 pixels (aspect ratio preserved)
      transforms.ToTensor()  # convert to tensor form
  ])
  x_0 = trans(img)
  img_list = [x_0]
  noise = torch.randn_like(x_0)  # standard Gaussian distribution
  for i in range(15):
      x_t = sqrt_alpha_bar[i] * x_0 + sqrt_one_minus_alpha_bar[i] * noise
      img_list.append(x_t)  # add items to the end of an image list

  all_img = torch.stack(img_list, dim=0)  # combine a list of tensors into a single tensor along a new dimension (dim=0 means the batch dimension)
  all_img = make_grid(all_img)  # default nrow=8, so the 16 images form a 2x8 grid
  save_image(all_img, 'car_noise.jpg')
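
Note that the demo above deliberately uses a steeper noise schedule than the DDPM default so that the corruption is visible within only 15 steps. As a quick sanity check of the claim that $\bar{\alpha}_T\to 0$ under the original schedule ($\beta_1=10^{-4}$, $\beta_T=0.02$, $T=1000$), the short sketch below prints the first and last cumulative products:

  import torch

  beta = torch.linspace(1e-4, 0.02, 1000)        # the linear schedule used in the DDPM paper
  alpha_bar = torch.cumprod(1. - beta, dim=0)    # \bar{\alpha}_t = \prod_{i<=t} (1 - \beta_i)
  print(alpha_bar[0].item())                     # close to 1: x_1 is nearly the clean image
  print(alpha_bar[-1].item())                    # ~4e-5: x_T is essentially pure Gaussian noise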

Reverse Denoising #

The key to denoising is the computation of the posterior probability $q(\pmb{x}_{t-1}|\pmb{x}_t,\pmb{x}_0)$ given the original image $\pmb{x}_0$. By Bayes’ rule, we have \begin{align} q(\pmb{x}_{t-1}|\pmb{x}_t,\pmb{x}_0)&=\frac{q(\pmb{x}_t|\pmb{x}_{t-1},\pmb{x}_0)q(\pmb{x}_{t-1}|\pmb{x}_0)}{q(\pmb{x}_t|\pmb{x}_0)}\\ &=\frac{\mathcal{N}\left(\pmb{x}_t;\sqrt{\alpha_t}\pmb{x}_{t-1},(1-\alpha_t)\pmb{\mathrm{I}}\right)\mathcal{N}\left(\pmb{x}_{t-1};\sqrt{\bar{\alpha}_{t-1}}\pmb{x}_0,(1-\bar{\alpha}_{t-1})\pmb{\mathrm{I}}\right)}{\mathcal{N}\left(\pmb{x}_t;\sqrt{\bar{\alpha}_t}\pmb{x}_0,(1-\bar{\alpha}_t)\pmb{\mathrm{I}}\right)}\\ &\propto\mathcal{N}\left(\pmb{x}_{t-1};\underbrace{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\pmb{x}_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\pmb{x}_0}{1-\bar{\alpha}_t}}_{\mu_q(\pmb{x}_t,\pmb{x}_0)},\underbrace{\frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}}_{\sigma_q^2(t)}\pmb{\mathrm{I}}\right). \tag{4} \label{eq4} \end{align}

Substituting Equation $\eqref{eq3}$ into the mean term in Equation $\eqref{eq4}$ and eliminating $\pmb{x}_0$ gives $$\mu_q(\pmb{x}_t,\pmb{x}_0)=\frac{1}{\sqrt{\alpha_t}}\left(\pmb{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\pmb{\epsilon}\right), \tag{5} \label{eq5}$$ where the noise $\pmb{\epsilon}$ is unknown in the reverse process, which is exactly what a neural network will be trained to predict.
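
For completeness, the intermediate step is as follows: solving Equation $\eqref{eq3}$ for $\pmb{x}_0=\frac{\pmb{x}_t-\sqrt{1-\bar{\alpha}_t}\pmb{\epsilon}}{\sqrt{\bar{\alpha}_t}}$, substituting it into $\mu_q(\pmb{x}_t,\pmb{x}_0)$, and using $\sqrt{\bar{\alpha}_{t-1}}/\sqrt{\bar{\alpha}_t}=1/\sqrt{\alpha_t}$, the coefficient of $\pmb{x}_t$ becomes $$\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})+\frac{1-\alpha_t}{\sqrt{\alpha_t}}}{1-\bar{\alpha}_t}=\frac{\alpha_t(1-\bar{\alpha}_{t-1})+1-\alpha_t}{\sqrt{\alpha_t}(1-\bar{\alpha}_t)}=\frac{1}{\sqrt{\alpha_t}},$$ while the coefficient of $\pmb{\epsilon}$ becomes $-\frac{(1-\alpha_t)\sqrt{1-\bar{\alpha}_t}}{\sqrt{\alpha_t}(1-\bar{\alpha}_t)}=-\frac{1-\alpha_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}$; together they give Equation $\eqref{eq5}$.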

Training Process #

In this blog, we derived that optimizing a VDM is equivalent to minimizing the difference between the means of two Gaussian distributions. According to Equation $\eqref{eq5}$, this further boils down to learning a neural network to predict the source noise that turns $\pmb{x}_0$ into $\pmb{x}_t$. Here, a U-Net ($\pmb{\epsilon}_\theta(\pmb{x}_t,t)$) is adopted to predict the unknown noise $\pmb{\epsilon}$, trained with a mean squared error loss between the true and predicted noise.

The specific pseudocode of the training procedure is illustrated in Fig. 2.

Figure 2: Complete training procedure of the DDPM (Ho et al., 2020).
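
For readers who prefer code to pseudocode, a single training step of Fig. 2 can be condensed into the minimal, self-contained sketch below. The toy convolution is only a stand-in for the U-Net noise predictor $\pmb{\epsilon}_\theta(\pmb{x}_t,t)$ (it ignores $t$), and the fake image batch is purely illustrative; the full implementation with data loading and logging appears in ddpm.py later in this post.

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  T = 1000
  beta = torch.linspace(1e-4, 0.02, T)
  alpha_bar = torch.cumprod(1. - beta, dim=0)

  model = nn.Conv2d(3, 3, kernel_size=3, padding=1)       # stand-in for the U-Net eps_theta(x_t, t)
  optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

  x_0 = torch.rand(4, 3, 64, 64) * 2 - 1                  # a fake batch of images scaled to [-1, 1]
  t = torch.randint(low=1, high=T, size=(x_0.shape[0],))  # sample one timestep per image
  eps = torch.randn_like(x_0)                             # the noise "ground truth"
  x_t = torch.sqrt(alpha_bar[t])[:, None, None, None] * x_0 \
      + torch.sqrt(1 - alpha_bar[t])[:, None, None, None] * eps  # diffuse x_0 to x_t via Eq. (3)

  loss = F.mse_loss(model(x_t), eps)                      # the real U-Net also receives t
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()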

Sampling Process #

From Equation $\eqref{eq4}$, we find that $\pmb{x}_{t-1}\sim\mathcal{N}\left(\pmb{x}_{t-1};\mu_q(\pmb{x}_t,\pmb{x}_0),\sigma_q^2(t)\,\pmb{\mathrm{I}}\right)$. Again, using the reparameterization trick, we have $$\pmb{x}_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(\pmb{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\pmb{\epsilon}_\theta(\pmb{x}_t,t)\right)+\sigma_q(t)\pmb{\epsilon}, \tag{6} \label{eq6}$$ where $\pmb{\epsilon}_\theta(\pmb{x}_t,t)$ is the neural network (a U-Net) trained to match the noise term $\pmb{\epsilon}$ in Equation $\eqref{eq5}$, and the final $\pmb{\epsilon}\sim\mathcal{N}(\pmb{\epsilon};\pmb{0},\pmb{\mathrm{I}})$ is fresh Gaussian noise drawn at each reverse step.

Fig. 3 shows the specific pseudocode of the sampling procedure of the DDPM.

Figure 3: Complete sampling procedure of the DDPM (Ho et al., 2020).
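
The core of Fig. 3 is the per-step update in Equation $\eqref{eq6}$. A minimal sketch of one reverse step is given below, assuming a scalar timestep t, precomputed alpha, alpha_bar, and beta tensors, and a noise prediction from the trained network; the complete loop (including the final step where no fresh noise is added) lives in Diffusion.sample in ddpm.py.

  import torch

  @torch.no_grad()
  def reverse_step(x_t, t, predicted_noise, alpha, alpha_bar, beta):
      """One application of Eq. (6): compute x_{t-1} from x_t and the predicted noise."""
      mean = (x_t - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * predicted_noise) / torch.sqrt(alpha[t])
      noise = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)  # no fresh noise at the last step
      return mean + torch.sqrt(beta[t]) * noise  # here sigma_t^2 = beta_t, matching the code below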

Specific Code Implementation #

The complete PyTorch code for the image diffusion model described above is shown below, divided into three main modules. The ddpm.py file contains the construction of the diffusion process, training, and sampling; the modules.py file focuses on the U-Net architecture; and the utils.py file contains functions related to plotting, data loading, and file saving. Detailed explanations of the specific code and related functions are provided as comments for reference by interested readers.

ddpm.py Module:

  import os
  import torch
  import torch.nn as nn
  from matplotlib import pyplot as plt
  from tqdm import tqdm
  from torch import optim
  from utils import *
  from modules import UNet
  import logging
  from torch.utils.tensorboard import SummaryWriter

  logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

  class Diffusion:
      def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
          self.noise_steps = noise_steps
          self.beta_start = beta_start
          self.beta_end = beta_end
          self.img_size = img_size
          self.device = device

          self.beta = self.prepare_noise_schedule().to(device)  # .to -- dtype cast or device conversion
          self.alpha = 1. - self.beta
          self.alpha_bar = torch.cumprod(self.alpha, dim=0)  # cumulative product

      def prepare_noise_schedule(self):
          return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

      def noise_images(self, x, t):  # diffuse the original image x
          sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t])[:, None, None, None]  # add three singleton dims so the per-sample scalar broadcasts over (C, H, W)
          sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t])[:, None, None, None]
          Ɛ = torch.randn_like(x)  # standard Gaussian distribution, noise ground truth
          return sqrt_alpha_bar * x + sqrt_one_minus_alpha_bar * Ɛ, Ɛ

      def sample_timesteps(self, n):
          return torch.randint(low=1, high=self.noise_steps, size=(n,))  # generate a tensor filled with random integers

      def sample(self, model, n):  # model evaluation/prediction
          logging.info(f"Sampling {n} new images....")
          model.eval()  # setting 'Dropout' and 'Batch Normalization' to evaluation mode
          with torch.no_grad():  # run forward process without recording gradient information in model evaluation
              x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)  # random number in standard Gaussian distribution
              for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                  t = (torch.ones(n) * i).long().to(self.device)
                  predicted_noise = model(x, t)  # predicted by UNet
                  alpha = self.alpha[t][:, None, None, None]
                  alpha_bar = self.alpha_bar[t][:, None, None, None]
                  beta = self.beta[t][:, None, None, None]  # sampling variance sigma_t^2 = beta_t (one of the two choices discussed in Ho et al., 2020)
                  if i > 1:
                      noise = torch.randn_like(x)
                  else:
                      noise = torch.zeros_like(x)
                  # reparameterization trick
                  x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_bar))) * predicted_noise) + torch.sqrt(beta) * noise
          model.train()  # setting 'Dropout' and 'Batch Normalization' to training mode
          x = (x.clamp(-1, 1) + 1) / 2  # clamp to [-1, 1], then rescale to [0, 1]
          x = (x * 255).type(torch.uint8)  # bring the predicted value to valid pixel range [0,255]
          return x

  def train(args):
      setup_logging(args.run_name)
      device = args.device
      dataloader = get_data(args)
      model = UNet().to(device)
      optimizer = optim.AdamW(model.parameters(), lr=args.lr)
      mse = nn.MSELoss()
      diffusion = Diffusion(img_size=args.image_size, device=device)
      logger = SummaryWriter(os.path.join("runs", args.run_name))
      l = len(dataloader)  # len(dataloader)=total number of samples/batch size, return the number of batches in the dataset

      for epoch in range(args.epochs):  # the total number of times the model will go through the entire dataset
          logging.info(f"Starting epoch {epoch}:")
          pbar = tqdm(dataloader)  # "tqdm" is a Python library used to display a progress bar for loops, showing how many batches have been processed
          # This iterates over the batches of data provided by the "dataloader". Each batch consists of images and their corresponding labels, "i" is the index of the current batch
          for i, (images, _) in enumerate(pbar):
              images = images.to(device)  # moves the image tensor to the appropriate device
              t = diffusion.sample_timesteps(images.shape[0]).to(device)  # generates a sequence of timesteps for the current batch of images based on the batch size (images.shape[0] is the number of images in the batch)
              x_t, noise = diffusion.noise_images(images, t)  # adds noise to the images according to the diffusion process
              predicted_noise = model(x_t, t)
              loss = mse(noise, predicted_noise)  # training to better predict the added noise

              optimizer.zero_grad()  # clears the gradients of all optimized tensors from the previous step before the new backward pass
              loss.backward()  # performs backpropagation to compute the gradients of the loss with respect to the model's parameters
              optimizer.step()  # updates the model's parameters using the computed gradients and the optimizer's update rule (e.g., Adam, SGD)

              pbar.set_postfix(MSE=loss.item())  # updates the progress bar with the current MSE loss value. "loss.item()" gets the scalar value of the loss (instead of the full tensor)
              logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

          sampled_images = diffusion.sample(model, n=images.shape[0])
          save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
          torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))

  # =============================================================================
  # Adopted Model Parameters
  # =============================================================================
  def launch():
      import argparse
      parser = argparse.ArgumentParser()
      args = parser.parse_args()
      args.run_name = "DDPM_Uncondtional"
      args.epochs = 500
      args.batch_size = 12
      args.image_size = 64
      args.dataset_path = r"E:\datasets"
      args.device = "cuda"
      args.lr = 3e-4
      train(args)

  if __name__ == '__main__':
      launch()
      # device = "cuda"
      # model = UNet().to(device)
      # ckpt = torch.load("./working/orig/ckpt.pt")
      # model.load_state_dict(ckpt)
      # diffusion = Diffusion(img_size=64, device=device)
      # x = diffusion.sample(model, 8)
      # print(x.shape)
      # plt.figure(figsize=(32, 32))
      # plt.imshow(torch.cat([
      #     torch.cat([i for i in x.cpu()], dim=-1),
      # ], dim=-2).permute(1, 2, 0).cpu())
      # plt.show()

modules.py Module:

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  # =============================================================================
  # Exponential Moving Average Module
  # commonly used for maintaining a moving average of model parameters, which
  # helps in stabilizing training and improving generalization during model optimization
  # =============================================================================
  class EMA:
      def __init__(self, beta):
          super().__init__()
          self.beta = beta
          self.step = 0

      def update_model_average(self, ma_model, current_model):
          for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
              old_weight, up_weight = ma_params.data, current_params.data
              ma_params.data = self.update_average(old_weight, up_weight)

      def update_average(self, old, new):
          if old is None:
              return new
          return old * self.beta + (1 - self.beta) * new

      def step_ema(self, ema_model, model, step_start_ema=2000):
          if self.step < step_start_ema:
              self.reset_parameters(ema_model, model)
              self.step += 1
              return
          self.update_model_average(ema_model, model)
          self.step += 1

      def reset_parameters(self, ema_model, model):
          ema_model.load_state_dict(model.state_dict())
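
  # Illustrative usage of EMA (defined here but not wired into the train() shown in ddpm.py):
  #     ema = EMA(beta=0.995)
  #     ema_model = copy.deepcopy(model).eval().requires_grad_(False)
  #     # ...inside the training loop, after each optimizer.step():
  #     ema.step_ema(ema_model, model)
  # Sampling from ema_model instead of model typically yields smoother, more stable results.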

  # =============================================================================
  # Self Attention Module
  # allow the model to focus on different spatial regions of the input image,
  # enabling it to capture long-range dependencies between pixels
  # =============================================================================
  class SelfAttention(nn.Module):
      def __init__(self, channels, size):
          super(SelfAttention, self).__init__()
          self.channels = channels
          self.size = size
          self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)  # applies multi-head self-attention to the input; 4: the number of attention heads
          self.ln = nn.LayerNorm([channels])  # normalizes the input tensor across the channels dimension
          # Feedforward Network: a small MLP applied after attention to enhance the feature representation
          self.ff_self = nn.Sequential(
              nn.LayerNorm([channels]),
              nn.Linear(channels, channels),
              nn.GELU(),
              nn.Linear(channels, channels),
          )

      # The following transformations reshape the input tensor "x" to a sequence of patches (like how transformers process image data)
      def forward(self, x):
          # flattens the spatial dimensions (height and width) into one dimension (self.size * self.size). Then swaps the second and third dimensions, making it compatible for processing with the Multihead Attention layer
          x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
          x_ln = self.ln(x)
          attention_value, _ = self.mha(x_ln, x_ln, x_ln)  # _ is used to discard the second value (the attention weights and scores)
          attention_value = attention_value + x
          attention_value = self.ff_self(attention_value) + attention_value
          return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)  # restoring the original shape

  # =============================================================================
  # Double Convolution Module
  # =============================================================================
  class DoubleConv(nn.Module):
      def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
          super().__init__()  # calls the constructor of nn.Module to initialize the parent class
          self.residual = residual
          if not mid_channels:  # both convolutional layers will have the same number of channels
              mid_channels = out_channels
          self.double_conv = nn.Sequential(
              nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),  # Conv layer
              nn.GroupNorm(1, mid_channels),  # applies Group Normalization over a mini-batch of inputs
              nn.GELU(),  # Gaussian Error Linear Unit (GELU) activation function
              nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),  # Another Conv layer
              nn.GroupNorm(1, out_channels),
          )  # defines a stack of layers

      def forward(self, x):  # defines the actual computation that happens when input x is passed through the module
          if self.residual:  # Residual/Skip Connection, which helps with training deep networks by allowing gradients to flow more easily through the network during backpropagation
              return F.gelu(x + self.double_conv(x))
          else:
              return self.double_conv(x)

  # =============================================================================
  # Downsampling with Maxpool then Double Convolution
  # =============================================================================
  class Down(nn.Module):  # inherits from nn.Module
      def __init__(self, in_channels, out_channels, emb_dim=256):
          super().__init__()
          self.maxpool_conv = nn.Sequential(
              nn.MaxPool2d(2),  # 2D max-pooling layer
              DoubleConv(in_channels, in_channels, residual=True),  # the number of feature channels is unchanged, and this layer includes a skip connection
              DoubleConv(in_channels, out_channels),  # this part changes the feature dimensionality, downsampling and expanding the representation
          )
          # nn.Sequential() --- A container that applies modules in a sequential order
          # All modules inside this Sequential block will be executed in the given order when you call forward() on this maxpool_conv layer

          #  Embedding: mapping high-dimensional discrete data to low-dimensional continuous vector
          self.emb_layer = nn.Sequential(
              nn.SiLU(),  # Sigmoid-weighted Linear Unit (SiLU) activation function
              nn.Linear(
                  emb_dim,
                  out_channels
              ),  # a fully connected (linear) layer that maps the "emb_dim"-dimensional input "t" to an output with "out_channels" dimensions
          )

      def forward(self, x, t):
          x = self.maxpool_conv(x)  # max-pooling and convolution, downsampling the input image and extracting features
          # the "None" indexing creates two new dimensions to match the spatial dimensions of the input "x", i.e., (batch_size, out_channels, 1, 1)
          # repeats the "emb" tensor along the spatial dimensions to match the height and width of the input "x", resulting in (batch_size, out_channels, H, W)
          emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
          # skip connection allows the model to incorporate both the feature map and the embedding, which helps the network propagate information across layers without losing critical details
          return x + emb

  # =============================================================================
  # Upsampling and Feature Fusion
  # =============================================================================
  class Up(nn.Module):
      def __init__(self, in_channels, out_channels, emb_dim=256):
          super().__init__()

          # A bilinear upsampling operation that doubles the spatial resolution of the input tensor. It works by interpolating pixel values using a bilinear method
          # align_corners=True: This ensures that the corner pixels of the input and output are aligned during the upsampling process
          self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
          self.conv = nn.Sequential(
              DoubleConv(in_channels, in_channels, residual=True),  # applies convolutions with skip connections to preserve information across layers (number of skip channels are "in_channels")
              # in_channels // 2: reduce the number of channels after concatenation "[skip_x, x]", making the network more computationally efficient and helping it focus on the most relevant features
              DoubleConv(in_channels, out_channels, in_channels // 2),
          )

          self.emb_layer = nn.Sequential(
              nn.SiLU(),
              nn.Linear(
                  emb_dim,
                  out_channels
              ),
          )

      def forward(self, x, skip_x, t):
          x = self.up(x)  # the input is upsampled using bilinear interpolation, and the spatial resolution is increased
          x = torch.cat([skip_x, x], dim=1)  # the skip connection "skip_x" is concatenated with the upsampled "x" along the channel dimension (dim=1)
          x = self.conv(x)
          emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
          return x + emb

  # =============================================================================
  # Specific UNet Architecture
  # =============================================================================
  class UNet(nn.Module):
      def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
          super().__init__()  # call the "init" of the parent class
          self.device = device
          self.time_dim = time_dim
          self.inc = DoubleConv(c_in, 64)  # two consecutive convolution layers, increasing input channels from c_in=3 to 64
          self.down1 = Down(64, 128)  # a downsampling block that reduces spatial dimensions while increasing the depth (channels)
          self.sa1 = SelfAttention(128, 32)  # self-attention to capture long-range dependencies; the second argument is the feature-map resolution (32 for a 64x64 input after one downsampling)
          self.down2 = Down(128, 256)
          self.sa2 = SelfAttention(256, 16)
          self.down3 = Down(256, 256)
          self.sa3 = SelfAttention(256, 8)

          self.bot1 = DoubleConv(256, 512)  # additional convolutional layers in the middle of the network
          self.bot2 = DoubleConv(512, 512)  # it serves as the bottleneck of the network where the feature map is at its smallest spatial resolution (highest depth)
          self.bot3 = DoubleConv(512, 256)

          # an upsampling block that increases the spatial dimensions while decreasing the depth
          self.up1 = Up(512, 128)  # concatenates the upsampled output with the corresponding feature map from the encoder (skip connection from down3 + output of bot3 = 512 channels)
          self.sa4 = SelfAttention(128, 16)
          self.up2 = Up(256, 64)  # input channels 256 = 128 from the skip connection (down2) + 128 from the output of up1
          self.sa5 = SelfAttention(64, 32)
          self.up3 = Up(128, 64)
          self.sa6 = SelfAttention(64, 64)
          self.outc = nn.Conv2d(64, c_out, kernel_size=1)  # a final convolution layer that reduces the feature map back to the desired output channels (an RGB image)

      # it generates positional encodings for time steps in the model and injects information about the position or timing of each element in a sequence, so the model can capture the order of elements
      def pos_encoding(self, t, channels):
          # torch.arrange(0,channels,2): creates a tensor of values from 0 to channels with a step size of 2 (the positional encoding is split into two parts: one for the sine function and one for the cosine function)
          # ... / channels: normalizes the values to the range [0, 1]
          # 10000 ** (normalized values): applies an exponential scaling factor, effectively spreading out the frequencies (the factor 10000 ensures that the encoding spans a wide range of frequencies)
          # 1.0 / ...: takes the inverse of the frequency values, we get lower values corresponding to a longer wavelength. This allows the sine and cosine functions to cover different spatial scales in the encoding
          inv_freq = 1.0 / (
              10000
              ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
          )
          # t.repeat(1, channels // 2): repeats the time tensor t across the second dimension to match the size of the frequency tensor "inv_freq" (e.g., (batch_size, 1), channels=256 --> (batch_size, 128))
          # * inv_freq: each time step t is scaled by different frequency values, which ensures the sine wave has different periods depending on the position in the sequence
          # torch.sin(...): sine waves with different frequencies are used to encode the position in the sequence. The resulting tensor "pos_enc_a" will have the same shape as "t.repeat(1, channels // 2)"
          # torch.cat(...): concatenates "pos_enc_a" and "pos_enc_b" along the last dimension (the feature/channel dimension), resulting a tensor of shape (batch_size, channels)
          # Why Sine and Cosine? They generate periodic patterns with different wavelengths allowing the model to distinguish between different positions or time steps
          #                      By using both, you ensure that the positional encoding captures both local (how close are two time steps?) and global (where is this time step in the entire sequence?) patterns
          pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
          pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
          pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
          return pos_enc

      def forward(self, x, t):
          t = t.unsqueeze(-1).type(torch.float)  # t.unsqueeze(-1): adds a new dimension at the last position of the tensor t, ensuring that t becomes a 2D tensor (batch_size, 1) for later dimension match
          t = self.pos_encoding(t, self.time_dim)

          x1 = self.inc(x)
          x2 = self.down1(x1, t)
          x2 = self.sa1(x2)
          x3 = self.down2(x2, t)
          x3 = self.sa2(x3)
          x4 = self.down3(x3, t)
          x4 = self.sa3(x4)

          x4 = self.bot1(x4)
          x4 = self.bot2(x4)
          x4 = self.bot3(x4)

          x = self.up1(x4, x3, t)
          x = self.sa4(x)
          x = self.up2(x, x2, t)
          x = self.sa5(x)
          x = self.up3(x, x1, t)
          x = self.sa6(x)
          output = self.outc(x)
          return output

  if __name__ == '__main__':
      # quick shape check on CPU
      net = UNet(device="cpu")
      print(sum([p.numel() for p in net.parameters()]))  # total number of parameters
      x = torch.randn(3, 3, 64, 64)
      t = x.new_tensor([500] * x.shape[0]).long()
      print(net(x, t).shape)  # expected: torch.Size([3, 3, 64, 64])

utils.py Module:

  import os
  import torch
  import torchvision
  from PIL import Image
  from matplotlib import pyplot as plt
  from torch.utils.data import DataLoader

  # PyTorch image tensors are typically in the format [batch_size, channels, height, width]
  # For visualization with matplotlib, the format needs to be [height, width, channels] (i.e., HWC format), which is the format expected by "imshow"
  # permute(1, 2, 0) swaps the dimensions of the tensor, moving the channels dimension from the first axis to the last
  def plot_images(images):
      plt.figure(figsize=(32, 32))
      plt.imshow(torch.cat([
          torch.cat([i for i in images.cpu()], dim=-1),
      ], dim=-2).permute(1, 2, 0).cpu())
      plt.show()

  # "images" is a tensor, and "make_grid" arranges the images into a grid, by default horizontally and vertically
  # the output, 'grid', will have the shape [channels, height, total_width], where total_width depends on how many images are in the grid
  # .numpy(): Converts the PyTorch tensor into a NumPy array, as the "PIL.Image.fromarray()" function requires a NumPy array to create an image
  def save_images(images, path, **kwargs):
      grid = torchvision.utils.make_grid(images, **kwargs)
      ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
      im = Image.fromarray(ndarr)
      im.save(path)

  # designed to load and preprocess an image dataset using PyTorch's "torchvision" library
  # torchvision.transforms.Compose([...]): a function that chains together several image transformation operations. Each operation is applied sequentially to each image in the dataset
  # .RandomResizedCrop(...): applies a random crop to the resized image and resizes it to args.image_size
  # scale=(0.8, 1.0) means the random crop will take a random area of the image such that the area of the crop is between 80% to 100% of the original image area
  # Random-crop operation helps with data augmentation, making the model more robust to slight variations in the image content and positioning
  # .ToTensor(): converts the image to a PyTorch tensor [channels, height, width], which is required for neural network training
  # .Normalize(): normalizes the image tensor by subtracting the mean (0.5, 0.5, 0.5) and dividing by the standard deviation (0.5, 0.5, 0.5), ensuring that the pixel values are in the range [-1, 1]
  # torchvision.datasets.ImageFolder(): a standard dataset class for loading images organized in directories, where each directory represents a different class
  # "transform=transforms" applies the previously defined transformations to each image as it is loaded
  # "DataLoader" is a PyTorch class that provides an efficient way to load and batch the dataset. "shuffle=True": the dataset is shuffled before each epoch, preventing the model from learning any unintended patterns in the order of the data
  # "return dataloader", which can be used to iterate through the dataset in batches during training or evaluation
  def get_data(args):
      transforms = torchvision.transforms.Compose([
          torchvision.transforms.Resize(80),  # 80 = args.image_size (64) + args.image_size // 4, leaving room for the random crop below
          torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
          torchvision.transforms.ToTensor(),
          torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ])
      dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
      dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
      return dataloader

  # designed to create the necessary directories for saving model files and results during a machine learning run
  # os.makedirs(): a function used to create directories. It will create all the intermediate directories in the specified path if they do not already exist
  # "models": the path to the directory where model files (such as checkpoints, weights, etc.) will be stored
  # "results": the directory where evaluation results, logs, or any other output from the model might be saved
  # "exist_ok=True": ensures that if the directory already exists, no error will be raised
  # os.path.join("models", run_name): combines the "models" directory path with the run_name to create a path like "models/run_name", where run_name will be the unique identifier for this particular run
  # os.path.join("results", run_name): creates a subdirectory inside "results" with the name run_name (specific to the current run)
  def setup_logging(run_name):
      os.makedirs("models", exist_ok=True)
      os.makedirs("results", exist_ok=True)
      os.makedirs(os.path.join("models", run_name), exist_ok=True)
      os.makedirs(os.path.join("results", run_name), exist_ok=True)

Example: Generated Images #

For the example section, we selected two datasets: one containing natural landscape images and the other containing flower images. Fig. 4 shows the results of diffusion models trained to generate the corresponding landscape and flower images.

Figure 4: Images generated by the trained diffusion model. The left panel corresponds to the natural landscape dataset, while the right one corresponds to the flower dataset.

#diffusion models #machine learning #Python

Last modified on 2025-03-17