# Class-Conditioned UNet for Fashion MNIST

A diffusion model trained from scratch to generate class-conditioned images from the Fashion MNIST dataset. Built on top of diffusers' `UNet2DModel`, it accepts a class label (0–9) at inference time and generates the corresponding clothing category.


## Model Architecture

The model wraps `UNet2DModel` with a learned class embedding that is concatenated to the noisy input as extra channels, a simple but effective conditioning strategy that avoids cross-attention overhead.

| Property | Value |
| --- | --- |
| Base architecture | `UNet2DModel` |
| Input resolution | 28 × 28 |
| Image channels | 1 (grayscale) |
| Conditioning | Class embedding (10 classes, embedding size 4) |
| Effective input channels | 5 (1 image + 4 class embedding) |
| Noise scheduler | `DDPMScheduler` (`squaredcos_cap_v2`) |
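The effective-input-channels figure above can be reproduced in isolation. A minimal sketch of the embed–broadcast–concatenate step, standalone and using only torch (the example labels are illustrative):

```python
import torch
import torch.nn as nn

# Embed a class label, broadcast it across the 28x28 grid, and concatenate
# it to the image channels -- the conditioning trick used by this model.
emb = nn.Embedding(10, 4)          # 10 classes, embedding size 4
x = torch.randn(2, 1, 28, 28)      # batch of 2 grayscale images
labels = torch.tensor([3, 7])      # "Dress" and "Sneaker"

class_cond = emb(labels)                                       # (2, 4)
class_cond = class_cond.view(2, 4, 1, 1).expand(2, 4, 28, 28)  # (2, 4, 28, 28)
net_input = torch.cat((x, class_cond), dim=1)

print(net_input.shape)  # torch.Size([2, 5, 28, 28]) -> 5 effective channels
```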

## Class Labels

| Label | Category |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
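When titling or logging generated samples, a plain Python mapping of the table above can be handy. A small helper (category names copied from the table; the names `CLASS_NAMES` and `label_to_name` are illustrative):

```python
# Fashion MNIST category names, indexed by class label 0-9.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_to_name(label: int) -> str:
    """Return the category name for a class label in 0-9."""
    return CLASS_NAMES[label]

print(label_to_name(9))  # Ankle boot
```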

## How to Use

### 1. Define `ClassConditionedUnet`

This wrapper must be defined before loading the model, as the pretrained weights cover only the inner `UNet2DModel`.

```python
import torch
import torch.nn as nn
from diffusers import UNet2DModel

class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=28,
            in_channels=1 + class_emb_size,  # image + class embedding
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256),
            down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        )

    def forward(self, x, t, class_labels):
        bs, ch, h, w = x.shape  # PyTorch images are (batch, channels, height, width)
        # Embed class labels and broadcast to spatial dimensions
        class_cond = self.class_emb(class_labels)                        # (bs, emb_size)
        class_cond = class_cond.view(bs, -1, 1, 1).expand(bs, -1, h, w)  # (bs, emb_size, h, w)
        net_input = torch.cat((x, class_cond), dim=1)                    # (bs, 1+emb_size, h, w)
        return self.model(net_input, t).sample
```

### 2. Load the Model

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

net = ClassConditionedUnet(num_classes=10, class_emb_size=4)
net.model = UNet2DModel.from_pretrained(
    "andreagemelli/UNet2DModel-fashion_mnist",
    use_safetensors=True,
)
net = net.to(device)
net.eval()
```

### 3. Generate Images

The snippet below generates 8 samples per class (80 images total) and displays them in a grid.

```python
import torchvision
from diffusers import DDPMScheduler
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# 8 samples per class, all 10 classes → 80 images
n_per_class = 8
x = torch.randn(n_per_class * 10, 1, 28, 28).to(device)
y = torch.tensor([[i] * n_per_class for i in range(10)]).flatten().to(device)

# Scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
noise_scheduler.set_timesteps(1000)

# Reverse diffusion loop
for t in tqdm(noise_scheduler.timesteps, desc="Sampling"):
    with torch.no_grad():
        residual = net(x, t, y)
    x = noise_scheduler.step(residual, t, x).prev_sample

# Visualize
grid = torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=n_per_class, normalize=True)
fig, ax = plt.subplots(figsize=(12, 6))
ax.imshow(grid.permute(1, 2, 0))  # grid is already 3-channel, so no cmap is needed
ax.set_title("Generated Fashion MNIST Images (rows = classes 0–9)")
ax.axis("off")
plt.tight_layout()
plt.show()
```

## Training Details

| Hyperparameter | Value |
| --- | --- |
| Dataset | Fashion MNIST |
| Epochs | 5 |
| Batch size | 128 |
| Learning rate | 2e-4 |
| Optimizer | Adam |
| Loss function | MSE (noise prediction) |
| Noise scheduler | `DDPMScheduler` (`squaredcos_cap_v2`) |
| Timesteps | 1000 |

## Intended Use & Limitations

This model is intended for educational purposes: specifically, to demonstrate class-conditioned image generation using the diffusers library.

Limitations:

- Output quality reflects the simplicity of Fashion MNIST (28×28 grayscale); this is not a high-fidelity generative model.
- Trained for only 5 epochs; longer training would improve sample quality and class coherence.
- The class embedding approach is intentionally minimal. More sophisticated conditioning (e.g. cross-attention or classifier-free guidance) would yield sharper class separation.

## Credits

Training inspired by the Hugging Face Diffusion Models Course.
