File size: 4,105 Bytes
3c96643 b3a6f56 3c96643 b3a6f56 3c96643 b3a6f56 3c96643 b3a6f56 3c96643 4e76fed 3c96643 4e76fed 3c96643 b3a6f56 3c96643 4e76fed | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | #!/usr/bin/env python3
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
from typing import Iterable
import torch
from diffusers import DDPMScheduler
from torchvision.utils import make_grid, save_image
from transformers import AutoModel
# --- Model / output -----------------------------------------------------------
# Hugging Face Hub id used for both the model and the scheduler config.
MODEL_ID = "SupraLabs/SupraMNST-IMG-200k"
# Destination of the saved sample grid (parent dirs are created by main()).
OUTPUT_IMAGE = "./digit_samples.png"

# --- Which digits to sample ---------------------------------------------------
# With USE_MULTIPLE_DIGITS off only DIGIT is sampled; with it on, every entry
# of DIGITS is sampled and gets its own row of IMAGES_PER_DIGIT images.
USE_MULTIPLE_DIGITS = False
DIGIT = 7
DIGITS = list(range(10))
IMAGES_PER_DIGIT = 4

# --- Sampling -----------------------------------------------------------------
NUM_INFERENCE_STEPS = 200
IMAGE_SIZE = 32  # height and width of each generated sample
SEED = 42

# --- Numerics -----------------------------------------------------------------
# fp16 autocast is only used when a CUDA device is present.
USE_AMP = torch.cuda.is_available()
ALLOW_TF32 = False
def _selected_digits() -> list[int]:
if USE_MULTIPLE_DIGITS:
if not DIGITS:
raise ValueError("`DIGITS` must not be empty when `USE_MULTIPLE_DIGITS=True`")
return [int(d) for d in DIGITS]
return [int(DIGIT)]
def _load_model(model_id: str, device: torch.device):
    """Load ``model_id`` via ``AutoModel``, move it to ``device``, set eval mode.

    NOTE(review): ``trust_remote_code=True`` runs Python code shipped with the
    Hub repository; only point ``model_id`` at repos you trust (consider
    pinning a revision).
    """
    net = AutoModel.from_pretrained(model_id, trust_remote_code=True)
    # Both calls mutate the module in place; their return values are unused.
    net.to(device)
    net.eval()
    return net
def _load_scheduler(model_id: str) -> DDPMScheduler:
    """Load a ``DDPMScheduler`` from the same Hub id as the model.

    NOTE(review): ``trust_remote_code`` is forwarded as an extra keyword;
    confirm the installed diffusers version accepts/ignores it.
    """
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, trust_remote_code=True)
    return noise_scheduler
def _to_display_range(x: torch.Tensor) -> torch.Tensor:
return ((x.clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
def _extract_noise_pred(output) -> torch.Tensor:
    """Return the noise-prediction tensor from a model forward result.

    Accepts a bare tensor (returned unchanged), an object exposing ``.sample``,
    or an indexable container whose first element is the prediction.
    """
    if isinstance(output, torch.Tensor):
        # A bare tensor must not go through ``output[0]``: that would drop the
        # leading (batch) dimension and the single-image prediction would then
        # silently broadcast across the whole batch in ``scheduler.step``.
        return output
    if hasattr(output, "sample"):
        return output.sample
    return output[0]


@torch.inference_mode()
def generate_grid(
    model,
    scheduler: DDPMScheduler,
    device: torch.device,
    digits: Iterable[int],
    images_per_digit: int,
    num_inference_steps: int,
    image_size: int,
) -> torch.Tensor:
    """Sample ``images_per_digit`` images per class in ``digits`` and tile them.

    For each digit, Gaussian noise of shape
    ``(images_per_digit, 1, image_size, image_size)`` is denoised over
    ``num_inference_steps`` scheduler steps with the model conditioned on that
    digit's class label. Each digit's samples become one row of the grid, in
    the order given by ``digits``.

    Args:
        model: Callable taking ``noisy_images``, ``timesteps`` and
            ``class_labels`` keyword arguments and returning either a tensor,
            an object with ``.sample``, or a tuple whose first item is the
            noise prediction.
        scheduler: Diffusers scheduler; ``set_timesteps`` is called on it, so
            it is mutated.
        device: Device on which noise, labels and timesteps are created.
        digits: Class labels to sample; iterated once.
        images_per_digit: Samples per digit, which is also the grid row length.
        num_inference_steps: Number of denoising steps.
        image_size: Height and width of every sample.

    Returns:
        The tiled image produced by ``torchvision.utils.make_grid`` (on the
        CPU), with sample values in [0, 1].

    Raises:
        ValueError: if ``digits`` is empty or ``images_per_digit`` < 1.
    """
    labels = [int(d) for d in digits]
    if not labels:
        raise ValueError("`digits` must contain at least one class label")
    if images_per_digit < 1:
        raise ValueError(f"`images_per_digit` must be >= 1, got {images_per_digit}")

    scheduler.set_timesteps(num_inference_steps, device=device)
    # Decided once: fp16 autocast only on CUDA, and only around the forward.
    use_amp = USE_AMP and device.type == "cuda"

    rows: list[torch.Tensor] = []
    for digit in labels:
        latents = torch.randn(
            images_per_digit, 1, image_size, image_size, device=device
        )
        # Standard diffusers loop: scale the initial noise (init_noise_sigma is
        # 1.0 for DDPMScheduler) so other schedulers also work here.
        latents = latents * scheduler.init_noise_sigma
        class_labels = torch.full(
            (images_per_digit,), digit, device=device, dtype=torch.long
        )
        for t in scheduler.timesteps:
            t_batch = torch.full(
                (images_per_digit,), int(t), device=device, dtype=torch.long
            )
            # No-op for DDPM; required by some other diffusers schedulers.
            model_input = scheduler.scale_model_input(latents, t)
            amp_ctx = (
                torch.autocast(device_type="cuda", dtype=torch.float16)
                if use_amp
                else nullcontext()
            )
            with amp_ctx:
                output = model(
                    noisy_images=model_input,
                    timesteps=t_batch,
                    class_labels=class_labels,
                )
            # Run the scheduler update (and its variance noise) in the
            # latents' precision rather than the fp16 autocast output dtype.
            noise_pred = _extract_noise_pred(output).to(latents.dtype)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
        rows.append(_to_display_range(latents))

    return make_grid(torch.cat(rows, dim=0), nrow=images_per_digit)
def main() -> None:
    """Configure torch, load model and scheduler, sample digits, save the grid.

    All settings come from the module-level constants. The resulting image is
    written to ``OUTPUT_IMAGE``; its parent directory is created if missing.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # Apply ALLOW_TF32 to cuBLAS matmuls *and* cuDNN convolutions. PyTorch
        # enables TF32 for cuDNN by default, so toggling only the matmul flag
        # would leave convolutions in TF32 even with ALLOW_TF32 = False.
        torch.backends.cuda.matmul.allow_tf32 = ALLOW_TF32
        torch.backends.cudnn.allow_tf32 = ALLOW_TF32
    # torch.manual_seed seeds the CPU generator and every CUDA device, so a
    # separate torch.cuda.manual_seed_all call is unnecessary.
    torch.manual_seed(SEED)
    device = torch.device("cuda" if cuda_available else "cpu")
    digits = _selected_digits()

    print(f"[info] model_id = {MODEL_ID}")
    print(f"[info] device = {device}")
    print(f"[info] digits = {digits}")
    print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
    print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
    print(f"[info] output_image = {OUTPUT_IMAGE}")

    model = _load_model(MODEL_ID, device)
    scheduler = _load_scheduler(MODEL_ID)
    grid = generate_grid(
        model=model,
        scheduler=scheduler,
        device=device,
        digits=digits,
        images_per_digit=IMAGES_PER_DIGIT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        image_size=IMAGE_SIZE,
    )

    out_path = Path(OUTPUT_IMAGE)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    save_image(grid, out_path)
    print(f"[done] saved to {out_path.resolve()}")
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()