MNIST-IMG-390k / inference.py
Harley-ml's picture
Update inference.py
b9aba18 verified
#!/usr/bin/env python3
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
from typing import Iterable
import torch
from diffusers import DDPMScheduler
from torchvision.utils import make_grid, save_image
from transformers import AutoModel
# --- Configuration -----------------------------------------------------------
# Hub repository id used to load both the model and the scheduler config.
MODEL_ID = "Harley-ml/MNIST-IMG-390k"
# Destination of the generated grid image.
OUTPUT_IMAGE = "./digits.png"

# Digit selection: only DIGIT is sampled unless USE_MULTIPLE_DIGITS is True,
# in which case every entry of DIGITS is sampled (one grid row per digit).
USE_MULTIPLE_DIGITS = False
DIGIT = 1
DIGITS = list(range(10))

# Sampling parameters.
IMAGES_PER_DIGIT = 4
NUM_INFERENCE_STEPS = 1_000
IMAGE_SIZE = 32  # height and width of every generated sample
SEED = 42

# Precision controls: mixed precision is switched on whenever CUDA is present.
USE_AMP = torch.cuda.is_available()
ALLOW_TF32 = False
# Helpers
def _selected_digits() -> list[int]:
    """Resolve which digit labels to sample from the module configuration.

    Returns:
        ``[int(DIGIT)]`` when ``USE_MULTIPLE_DIGITS`` is falsy, otherwise every
        entry of ``DIGITS`` converted to ``int`` (order preserved).

    Raises:
        ValueError: If multi-digit mode is enabled but ``DIGITS`` is empty.
    """
    if not USE_MULTIPLE_DIGITS:
        return [int(DIGIT)]
    if not DIGITS:
        raise ValueError("`DIGITS` must not be empty when `USE_MULTIPLE_DIGITS=True`")
    return list(map(int, DIGITS))
def _load_model(model_id: str, device: torch.device, revision: str | None = None):
    """Load the model from ``model_id``, move it to ``device`` and set eval mode.

    Security note: ``trust_remote_code=True`` lets ``transformers`` download
    and execute the Python modelling code shipped with the repository. Pass
    ``revision`` (a commit hash, branch or tag) to pin exactly which code is
    executed; ``None`` keeps the library's default revision.

    Args:
        model_id: Hugging Face Hub repo id or local path of the model.
        device: Device the model is moved to via ``.to(device)``.
        revision: Optional revision to pin; omitted from the call when None.

    Returns:
        The loaded model after ``.to(device)`` and ``.eval()``.
    """
    # Only forward `revision` when given so the library default applies otherwise.
    load_kwargs = {} if revision is None else {"revision": revision}
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True, **load_kwargs)
    model.to(device)
    model.eval()
    return model
def _load_scheduler(model_id: str) -> DDPMScheduler:
    """Build the DDPM noise scheduler from the config found at ``model_id``.

    ``model_id`` is forwarded unchanged to ``DDPMScheduler.from_pretrained``.
    """
    scheduler = DDPMScheduler.from_pretrained(model_id)
    return scheduler
def _to_display_range(x: torch.Tensor) -> torch.Tensor:
    """Map values from ``[-1, 1]`` to ``[0, 1]`` and return the result on CPU.

    Inputs outside ``[-1, 1]`` are clamped first, so every output value lies
    in ``[0, 1]``.
    """
    clipped = x.clamp(-1.0, 1.0)
    shifted = clipped + 1.0
    return (shifted / 2.0).cpu()
def _noise_prediction(output) -> torch.Tensor:
    """Extract the predicted-noise tensor from a model forward output.

    Accepts a bare tensor (returned whole), an object exposing ``.sample``,
    or an indexable container whose first item is the prediction.
    """
    if isinstance(output, torch.Tensor):
        # `output[0]` on a bare batched tensor would keep only the first sample.
        return output
    if hasattr(output, "sample"):
        return output.sample
    return output[0]


@torch.inference_mode()
def generate_grid(
    model,
    scheduler: DDPMScheduler,
    device: torch.device,
    digits: Iterable[int],
    images_per_digit: int,
    num_inference_steps: int,
    image_size: int,
) -> torch.Tensor:
    """Generate a single grid image containing all requested digits.

    For each label in ``digits``, ``images_per_digit`` single-channel samples
    are drawn by running the full reverse loop over ``scheduler.timesteps``
    with ``model`` as the noise predictor. Each digit fills one grid row.

    Args:
        model: Class-conditional network, called as
            ``model(noisy_images=..., timesteps=..., class_labels=...)``.
        scheduler: Scheduler driving the reverse process; its timesteps are
            configured here via ``set_timesteps``.
        device: Device on which sampling runs.
        digits: Class labels to sample, in row order.
        images_per_digit: Samples per label (number of grid columns).
        num_inference_steps: Number of denoising steps.
        image_size: Height and width of each sample.

    Returns:
        The ``make_grid`` tensor built from CPU images with values in ``[0, 1]``.

    Raises:
        ValueError: If ``digits`` is empty or ``images_per_digit < 1``.
    """
    labels = [int(d) for d in digits]
    if not labels:
        raise ValueError("`digits` must contain at least one label")
    if images_per_digit < 1:
        raise ValueError(f"`images_per_digit` must be >= 1, got {images_per_digit}")

    scheduler.set_timesteps(num_inference_steps, device=device)

    # Built once and re-entered every step; it wraps only the network forward
    # pass so the scheduler update runs outside mixed precision.
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.float16)
        if USE_AMP and device.type == "cuda"
        else nullcontext()
    )

    rows: list[torch.Tensor] = []
    for digit in labels:
        latents = torch.randn(
            images_per_digit,
            1,
            image_size,
            image_size,
            device=device,
        )
        class_labels = torch.full(
            (images_per_digit,),
            digit,
            device=device,
            dtype=torch.long,
        )
        for t in scheduler.timesteps:
            t_batch = torch.full(
                (images_per_digit,),
                int(t),
                device=device,
                dtype=torch.long,
            )
            with amp_ctx:
                output = model(
                    noisy_images=latents,
                    timesteps=t_batch,
                    class_labels=class_labels,
                )
            # Cast the (possibly fp16) prediction back to the latents' dtype so
            # scheduler.step receives full-precision inputs.
            noise_pred = _noise_prediction(output).to(latents.dtype)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
        rows.append(_to_display_range(latents))

    all_images = torch.cat(rows, dim=0)
    return make_grid(all_images, nrow=images_per_digit)
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def main() -> None:
    """Generate the configured digit grid and save it to ``OUTPUT_IMAGE``.

    Everything is driven by the module-level constants: precision flags,
    seeding, device choice, model/scheduler loading, sampling and saving.
    """
    if torch.cuda.is_available():
        # Apply ALLOW_TF32 to both matmuls and cuDNN convolutions so the flag
        # really governs TF32 use; cuDNN allows TF32 by default otherwise.
        torch.backends.cuda.matmul.allow_tf32 = ALLOW_TF32
        torch.backends.cudnn.allow_tf32 = ALLOW_TF32

    # torch.manual_seed seeds the CPU generator and every CUDA device.
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    digits = _selected_digits()

    print(f"[info] model_id = {MODEL_ID}")
    print(f"[info] device = {device}")
    print(f"[info] digits = {digits}")
    print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
    print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
    print(f"[info] output_image = {OUTPUT_IMAGE}")

    model = _load_model(MODEL_ID, device)
    scheduler = _load_scheduler(MODEL_ID)

    grid = generate_grid(
        model=model,
        scheduler=scheduler,
        device=device,
        digits=digits,
        images_per_digit=IMAGES_PER_DIGIT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        image_size=IMAGE_SIZE,
    )

    out_path = Path(OUTPUT_IMAGE)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_image(grid, out_path)
    print(f"[done] saved to {out_path.resolve()}")


if __name__ == "__main__":
    main()