# Class-Conditioned UNet for Fashion MNIST

A diffusion model trained from scratch to generate class-conditioned images from the Fashion MNIST dataset. Built on top of diffusers' `UNet2DModel`, it accepts a class label (0–9) at inference time and generates an image of the corresponding clothing category.
## Model Architecture

The model wraps `UNet2DModel` with a learned class embedding that is concatenated to the noisy input as extra channels, a simple but effective conditioning strategy that avoids the overhead of cross-attention.
| Property | Value |
|---|---|
| Base architecture | UNet2DModel |
| Input resolution | 28 × 28 |
| Image channels | 1 (grayscale) |
| Conditioning | Class embedding (10 classes, embedding size 4) |
| Effective input channels | 5 (1 image + 4 class embedding) |
| Noise scheduler | DDPMScheduler (squaredcos_cap_v2) |
## Class Labels
| Label | Category |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
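When scripting against the model, the table above can be mirrored as a plain Python list for label-to-name lookup (a convenience helper for illustration, not part of the model's code):

```python
# Fashion MNIST label -> category name, mirroring the table above
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

print(FASHION_MNIST_CLASSES[9])  # Ankle boot
```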
## How to Use

### 1. Define `ClassConditionedUnet`

This wrapper must be defined before loading the model, since the pretrained weights cover only the inner `UNet2DModel`.
```python
import torch
import torch.nn as nn
from diffusers import UNet2DModel


class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=28,
            in_channels=1 + class_emb_size,  # image + class embedding
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256),
            down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        )

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        # Embed class labels and broadcast to spatial dimensions
        class_cond = self.class_emb(class_labels)  # (bs, emb_size)
        class_cond = class_cond.view(bs, -1, 1, 1).expand(bs, -1, w, h)  # (bs, emb_size, w, h)
        net_input = torch.cat((x, class_cond), dim=1)  # (bs, 1+emb_size, w, h)
        return self.model(net_input, t).sample
```
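The shape arithmetic in `forward` can be sanity-checked in isolation with plain PyTorch (no diffusers needed); this reproduces the broadcast-and-concatenate step on a toy batch:

```python
import torch
import torch.nn as nn

# Same broadcast-and-concatenate trick as in forward(), on a toy batch
bs, emb_size = 2, 4
class_emb = nn.Embedding(10, emb_size)
x = torch.randn(bs, 1, 28, 28)                 # noisy images
labels = torch.tensor([3, 7])                  # "Dress", "Sneaker"
cond = class_emb(labels).view(bs, -1, 1, 1).expand(bs, -1, 28, 28)
net_input = torch.cat((x, cond), dim=1)
print(net_input.shape)  # torch.Size([2, 5, 28, 28])
```

This confirms the "effective input channels = 5" figure from the architecture table: 1 image channel plus 4 embedding channels.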
### 2. Load the Model
```python
device = "cuda" if torch.cuda.is_available() else "cpu"

net = ClassConditionedUnet(num_classes=10, class_emb_size=4)
net.model = UNet2DModel.from_pretrained(
    "andreagemelli/UNet2DModel-fashion_mnist",
    use_safetensors=True,
)
net = net.to(device)
net.eval()
```
### 3. Generate Images

The snippet below generates 8 samples per class (80 images total) and displays them in a grid.
```python
import torchvision
from diffusers import DDPMScheduler
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# 8 samples per class, all 10 classes -> 80 images
n_per_class = 8
x = torch.randn(n_per_class * 10, 1, 28, 28).to(device)
y = torch.tensor([[i] * n_per_class for i in range(10)]).flatten().to(device)

# Scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
noise_scheduler.set_timesteps(1000)

# Reverse diffusion loop
for t in tqdm(noise_scheduler.timesteps, desc="Sampling"):
    with torch.no_grad():
        residual = net(x, t, y)
    x = noise_scheduler.step(residual, t, x).prev_sample

# Visualize
grid = torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=n_per_class, normalize=True)
fig, ax = plt.subplots(figsize=(12, 6))
ax.imshow(grid.permute(1, 2, 0), cmap="Greys")
ax.set_title("Generated Fashion MNIST Images (rows = classes 0-9)")
ax.axis("off")
plt.tight_layout()
plt.show()
```
## Training Details
| Hyperparameter | Value |
|---|---|
| Dataset | Fashion MNIST |
| Epochs | 5 |
| Batch size | 128 |
| Learning rate | 2e-4 |
| Optimizer | Adam |
| Loss function | MSE (noise prediction) |
| Noise scheduler | DDPMScheduler (squaredcos_cap_v2) |
| Timesteps | 1000 |
## Intended Use & Limitations

This model is intended for educational purposes: specifically, to demonstrate class-conditioned image generation using the diffusers library.

**Limitations:**

- Output quality reflects the simplicity of Fashion MNIST (28×28 grayscale); this is not a high-fidelity generative model.
- Trained for only 5 epochs; longer training would improve sample quality and class coherence.
- The class embedding approach is intentionally minimal. More sophisticated conditioning (e.g. cross-attention or classifier-free guidance) would yield sharper class separation.
## Credits
Training inspired by the Hugging Face Diffusion Models Course.