#!/usr/bin/env python3
"""Sample class-conditional MNIST digits from a pretrained DDPM and save a grid.

The script loads a class-conditional denoising model and its ``DDPMScheduler``
from ``MODEL_ID``, runs the reverse diffusion loop once per requested digit,
and writes all samples to ``OUTPUT_IMAGE`` as one image grid (one row per digit).
"""

from __future__ import annotations

from contextlib import nullcontext
from pathlib import Path
from typing import Iterable

import torch
from diffusers import DDPMScheduler
from torchvision.utils import make_grid, save_image
from transformers import AutoModel

# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------
MODEL_ID = "Harley-ml/MNIST-IMG-390k"
OUTPUT_IMAGE = "./digits.png"

USE_MULTIPLE_DIGITS = False  # True: sample every label in DIGITS; False: only DIGIT
DIGIT = 1
DIGITS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
IMAGES_PER_DIGIT = 4  # samples per digit; also the number of grid columns

NUM_INFERENCE_STEPS = 1000
IMAGE_SIZE = 32  # height/width of each single-channel sample
SEED = 42

USE_AMP = torch.cuda.is_available()  # fp16 autocast; only honoured on CUDA devices
ALLOW_TF32 = False  # applied to both CUDA matmuls and cuDNN convolutions


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _selected_digits() -> list[int]:
    """Return the digit labels to sample, according to the config flags.

    Raises:
        ValueError: if ``USE_MULTIPLE_DIGITS`` is set but ``DIGITS`` is empty.
    """
    if USE_MULTIPLE_DIGITS:
        if not DIGITS:
            raise ValueError("`DIGITS` must not be empty when `USE_MULTIPLE_DIGITS=True`")
        return [int(d) for d in DIGITS]
    return [int(DIGIT)]


def _load_model(model_id: str, device: torch.device):
    """Load the denoising model, move it to *device*, and switch to eval mode.

    ``trust_remote_code=True`` executes Python code shipped with the model
    repository, so only point ``model_id`` at repositories you trust.
    """
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    model.to(device)
    model.eval()
    return model


def _load_scheduler(model_id: str) -> DDPMScheduler:
    """Load the DDPM scheduler configuration from the same repository as the model."""
    return DDPMScheduler.from_pretrained(model_id)


def _to_display_range(x: torch.Tensor) -> torch.Tensor:
    """Map samples from ``[-1, 1]`` to float32 values in ``[0, 1]`` on the CPU.

    Out-of-range values are clamped first so the result is a valid image.
    """
    return ((x.float().clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()


def _extract_noise_pred(output) -> torch.Tensor:
    """Return the noise prediction from a model forward-pass result.

    Accepts a raw tensor, an object exposing ``.sample`` (diffusers-style
    output), or a tuple/list whose first element is the prediction. A raw
    tensor must be returned unchanged: indexing it with ``[0]`` would select
    the first batch element rather than the prediction for the whole batch.
    """
    if isinstance(output, torch.Tensor):
        return output
    if hasattr(output, "sample"):
        return output.sample
    return output[0]


def _autocast_context(device: torch.device, use_amp: bool):
    """Return an fp16 autocast context on CUDA when AMP is requested, else a no-op."""
    if use_amp and device.type == "cuda":
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return nullcontext()


@torch.inference_mode()
def generate_grid(
    model,
    scheduler: DDPMScheduler,
    device: torch.device,
    digits: Iterable[int],
    images_per_digit: int,
    num_inference_steps: int,
    image_size: int,
    use_amp: bool | None = None,
) -> torch.Tensor:
    """Generate a single grid image containing all requested digits.

    For every digit, ``images_per_digit`` Gaussian latents of shape
    ``(1, image_size, image_size)`` are denoised through every timestep of
    *scheduler*. The model is conditioned on that digit's class label. Each
    digit fills exactly one row of the returned grid.

    Args:
        model: denoiser called as
            ``model(noisy_images=..., timesteps=..., class_labels=...)``.
        scheduler: DDPM scheduler; its timesteps are reset by this call.
        device: device on which sampling runs.
        digits: class labels to sample, one grid row each.
        images_per_digit: number of samples per digit (grid columns).
        num_inference_steps: number of denoising steps passed to the scheduler.
        image_size: height and width of each sample.
        use_amp: run the model forward pass under fp16 autocast. ``None``
            (the default) uses the module-level ``USE_AMP``. AMP only takes
            effect on CUDA devices.

    Returns:
        The ``torchvision.utils.make_grid`` output: a CPU image tensor with
        values in ``[0, 1]``.

    Raises:
        ValueError: if *digits* is empty or *images_per_digit* is not positive.
    """
    # Materialise once: the iterable may be a one-shot generator.
    digit_list = [int(d) for d in digits]
    if not digit_list:
        raise ValueError("`digits` must not be empty")
    if images_per_digit <= 0:
        raise ValueError("`images_per_digit` must be positive")
    if use_amp is None:
        use_amp = USE_AMP

    scheduler.set_timesteps(num_inference_steps, device=device)

    rows: list[torch.Tensor] = []
    for digit in digit_list:
        latents = torch.randn(
            images_per_digit,
            1,
            image_size,
            image_size,
            device=device,
        )
        class_labels = torch.full(
            (images_per_digit,),
            digit,
            device=device,
            dtype=torch.long,
        )

        with _autocast_context(device, use_amp):
            for t in scheduler.timesteps:
                t_batch = torch.full(
                    (images_per_digit,),
                    int(t),
                    device=device,
                    dtype=torch.long,
                )
                output = model(
                    noisy_images=latents,
                    timesteps=t_batch,
                    class_labels=class_labels,
                )
                # Under autocast the prediction may be fp16. Cast it to the
                # latents' dtype so the scheduler update (and the variance
                # noise it draws with the prediction's dtype) runs at full precision.
                noise_pred = _extract_noise_pred(output).to(latents.dtype)
                latents = scheduler.step(noise_pred, t, latents).prev_sample

        rows.append(_to_display_range(latents))

    all_images = torch.cat(rows, dim=0)
    # nrow is the number of images per grid row, so each digit occupies one row.
    return make_grid(all_images, nrow=images_per_digit)


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def main() -> None:
    """Configure torch, sample the configured digits, and save the grid image."""
    if torch.cuda.is_available():
        # Set both flags explicitly. cuDNN convolutions allow TF32 by default,
        # so toggling only the matmul flag would leave ALLOW_TF32=False
        # ineffective for conv layers.
        torch.backends.cuda.matmul.allow_tf32 = ALLOW_TF32
        torch.backends.cudnn.allow_tf32 = ALLOW_TF32

    # torch.manual_seed seeds the RNGs of all devices, including CUDA.
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    digits = _selected_digits()

    print(f"[info] model_id = {MODEL_ID}")
    print(f"[info] device = {device}")
    print(f"[info] digits = {digits}")
    print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
    print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
    print(f"[info] output_image = {OUTPUT_IMAGE}")

    model = _load_model(MODEL_ID, device)
    scheduler = _load_scheduler(MODEL_ID)

    grid = generate_grid(
        model=model,
        scheduler=scheduler,
        device=device,
        digits=digits,
        images_per_digit=IMAGES_PER_DIGIT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        image_size=IMAGE_SIZE,
    )

    out_path = Path(OUTPUT_IMAGE)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_image(grid, out_path)
    print(f"[done] saved to {out_path.resolve()}")


if __name__ == "__main__":
    main()