| |
| from __future__ import annotations |
|
|
| from contextlib import nullcontext |
| from pathlib import Path |
| from typing import Iterable |
|
|
| import torch |
| from diffusers import DDPMScheduler |
| from torchvision.utils import make_grid, save_image |
| from transformers import AutoModel |
|
|
# Hub repository that provides both the denoiser weights and the scheduler config.
MODEL_ID = "SupraLabs/SupraMNST-IMG-200k"
# Destination of the sampled grid; missing parent directories are created by main().
OUTPUT_IMAGE = "./digit_samples.png"

# Digit selection: only DIGIT is sampled unless USE_MULTIPLE_DIGITS is True,
# in which case every entry of DIGITS is sampled (one grid row each).
USE_MULTIPLE_DIGITS = False
DIGIT = 7
DIGITS = list(range(10))

# Sampling configuration.
IMAGES_PER_DIGIT = 4  # samples per digit, i.e. images per grid row
NUM_INFERENCE_STEPS = 200  # number of reverse-diffusion steps
IMAGE_SIZE = 32  # height and width of each generated image, in pixels
SEED = 42

# Numerics: fp16 autocast defaults on only when a CUDA device is present;
# ALLOW_TF32 opts in to TF32 math on CUDA.
USE_AMP = torch.cuda.is_available()
ALLOW_TF32 = False
|
|
|
|
def _selected_digits() -> list[int]:
    """Return the digit labels to sample, as ints.

    Yields ``[DIGIT]`` in single-digit mode. When ``USE_MULTIPLE_DIGITS`` is
    set, yields every entry of ``DIGITS`` converted to ``int``.

    Raises:
        ValueError: If multi-digit mode is enabled but ``DIGITS`` is empty.
    """
    if not USE_MULTIPLE_DIGITS:
        return [int(DIGIT)]

    if not DIGITS:
        raise ValueError("`DIGITS` must not be empty when `USE_MULTIPLE_DIGITS=True`")
    return list(map(int, DIGITS))
|
|
|
|
def _load_model(model_id: str, device: torch.device):
    """Load the denoiser from ``model_id``, move it to ``device``, and set eval mode.

    Security note: ``trust_remote_code=True`` lets ``AutoModel`` execute Python
    code shipped inside the repository. Only point this at repositories you
    trust; consider pinning a specific revision.
    """
    network = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    network.to(device)  # presumably an nn.Module, whose .to() moves parameters in place
    network.eval()  # disable training-only behavior such as dropout
    return network
|
|
|
|
def _load_scheduler(model_id: str) -> DDPMScheduler:
    """Build a ``DDPMScheduler`` from the configuration stored under ``model_id``.

    NOTE(review): it is unclear whether ``DDPMScheduler.from_pretrained`` makes
    any use of ``trust_remote_code``; it is forwarded unchanged.
    """
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, trust_remote_code=True)
    return noise_scheduler
|
|
|
|
def _to_display_range(x: torch.Tensor) -> torch.Tensor:
    """Convert samples from model range [-1, 1] to image range [0, 1] on the CPU.

    Values outside [-1, 1] are clamped before the affine rescale, so the result
    is always within [0, 1]. The returned tensor lives on the CPU.
    """
    bounded = torch.clamp(x, min=-1.0, max=1.0)
    unit_range = bounded.add(1.0).div(2.0)
    return unit_range.to("cpu")
|
|
|
|
def _noise_prediction(output) -> torch.Tensor:
    """Extract the predicted-noise tensor from a denoiser's return value.

    Accepts a bare tensor, an object exposing ``.sample`` (e.g. a model-output
    dataclass), or a tuple/list whose first item is the prediction.
    """
    if isinstance(output, torch.Tensor):
        # A bare tensor *is* the prediction. Indexing it with [0] would select
        # only the first batch element, which would then be broadcast across
        # the whole batch by the scheduler update.
        return output
    if hasattr(output, "sample"):
        return output.sample
    return output[0]


@torch.inference_mode()
def generate_grid(
    model,
    scheduler: DDPMScheduler,
    device: torch.device,
    digits: Iterable[int],
    images_per_digit: int,
    num_inference_steps: int,
    image_size: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample ``images_per_digit`` images per digit and tile them into one grid.

    For each digit, starts from Gaussian noise and runs the scheduler's
    timesteps. At every step the model predicts the noise and
    ``scheduler.step`` produces the next, less noisy sample.

    Args:
        model: Denoiser called as ``model(noisy_images=..., timesteps=...,
            class_labels=...)``. It may return a tensor, an object with a
            ``.sample`` attribute, or a tuple whose first item is the prediction.
        scheduler: Scheduler whose timesteps drive the sampling loop.
        device: Device on which latents, labels and timesteps are created.
        digits: Class labels to sample; each produces one grid row.
        images_per_digit: Number of samples per digit (images per grid row).
        num_inference_steps: Number of denoising steps.
        image_size: Height and width of each single-channel sample.
        generator: Optional RNG used for the initial noise and for the
            scheduler's per-step noise; it must live on ``device``. ``None``
            uses the global RNG, which was the only behavior before.

    Returns:
        The ``make_grid`` tensor of all samples, on the CPU, values in [0, 1].

    Raises:
        ValueError: If ``digits`` is empty or ``images_per_digit`` < 1.
    """
    digit_list = [int(d) for d in digits]
    if not digit_list:
        raise ValueError("`digits` must not be empty")
    if images_per_digit < 1:
        raise ValueError(f"`images_per_digit` must be >= 1, got {images_per_digit}")

    scheduler.set_timesteps(num_inference_steps, device=device)

    # Loop-invariant: whether the model forward pass runs under fp16 autocast.
    use_autocast = USE_AMP and device.type == "cuda"

    rows: list[torch.Tensor] = []
    for digit in digit_list:
        # A fresh context object per row: each instance is entered exactly once.
        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if use_autocast
            else nullcontext()
        )
        # Start from pure noise scaled the way the scheduler expects.
        latents = torch.randn(
            images_per_digit,
            1,
            image_size,
            image_size,
            device=device,
            generator=generator,
        ) * scheduler.init_noise_sigma
        class_labels = torch.full(
            (images_per_digit,),
            digit,
            device=device,
            dtype=torch.long,
        )

        with autocast_ctx:
            for t in scheduler.timesteps:
                t_batch = torch.full(
                    (images_per_digit,),
                    int(t),
                    device=device,
                    dtype=torch.long,
                )
                output = model(
                    noisy_images=latents,
                    timesteps=t_batch,
                    class_labels=class_labels,
                )
                # Keep the scheduler update in the latents' precision (fp32)
                # even when the model produced an fp16 prediction under autocast.
                noise_pred = _noise_prediction(output).to(latents.dtype)
                latents = scheduler.step(
                    noise_pred, t, latents, generator=generator
                ).prev_sample

        rows.append(_to_display_range(latents))

    all_images = torch.cat(rows, dim=0)
    # One row of `images_per_digit` samples per requested digit.
    return make_grid(all_images, nrow=images_per_digit)
|
|
|
|
def main() -> None:
    """Sample digit images with the configured model and save them as one grid.

    Every setting comes from the module-level constants (model id, digits,
    step count, seed, precision flags, output path). The function prints the
    effective configuration and runs ``generate_grid``. It writes the result to
    ``OUTPUT_IMAGE``, creating the parent directory if needed.
    """
    if torch.cuda.is_available():
        # Apply ALLOW_TF32 to matmuls *and* cuDNN convolutions. PyTorch allows
        # TF32 for cuDNN by default, so setting only the matmul flag would leave
        # convolutions on TF32 even when ALLOW_TF32 is False.
        torch.backends.cuda.matmul.allow_tf32 = ALLOW_TF32
        torch.backends.cudnn.allow_tf32 = ALLOW_TF32

    # torch.manual_seed seeds the CPU and all CUDA devices, so no separate
    # torch.cuda.manual_seed_all call is needed.
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    digits = _selected_digits()

    print(f"[info] model_id = {MODEL_ID}")
    print(f"[info] device = {device}")
    print(f"[info] digits = {digits}")
    print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
    print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
    print(f"[info] output_image = {OUTPUT_IMAGE}")

    model = _load_model(MODEL_ID, device)
    scheduler = _load_scheduler(MODEL_ID)

    grid = generate_grid(
        model=model,
        scheduler=scheduler,
        device=device,
        digits=digits,
        images_per_digit=IMAGES_PER_DIGIT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        image_size=IMAGE_SIZE,
    )

    out_path = Path(OUTPUT_IMAGE)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_image(grid, out_path)
    print(f"[done] saved to {out_path.resolve()}")
|
|
|
|
if __name__ == "__main__":
    # Sample and save only when run as a script, not when imported.
    main()