Upload inference.py
Browse files- inference.py +184 -0
inference.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Simple Hugging Face inference script for the digit diffusion model.
|
| 3 |
+
|
| 4 |
+
No command-line arguments. Edit the values in the CONFIG section below.
|
| 5 |
+
|
| 6 |
+
What it does:
|
| 7 |
+
- loads the model from the Hugging Face Hub (or a local HF cache/path)
|
| 8 |
+
- loads the DDPM scheduler from the same repo
|
| 9 |
+
- generates one or more images for one digit or several digits
|
| 10 |
+
- saves everything as a single PNG grid
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from contextlib import nullcontext
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Iterable
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DDPMScheduler
|
| 21 |
+
from torchvision.utils import make_grid, save_image
|
| 22 |
+
from transformers import AutoModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# -----------------------------------------------------------------------------
|
| 26 |
+
# CONFIG — edit these values only
|
| 27 |
+
# -----------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
MODEL_ID = "your-hf-username/your-digit-diffusion-repo"
|
| 30 |
+
OUTPUT_IMAGE = "./digit_samples.png"
|
| 31 |
+
|
| 32 |
+
# Choose either a single digit or multiple digits.
|
| 33 |
+
USE_MULTIPLE_DIGITS = False
|
| 34 |
+
DIGIT = 7
|
| 35 |
+
DIGITS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
| 36 |
+
|
| 37 |
+
# How many images to generate for each selected digit.
|
| 38 |
+
IMAGES_PER_DIGIT = 4
|
| 39 |
+
|
| 40 |
+
# Number of denoising steps.
|
| 41 |
+
NUM_INFERENCE_STEPS = 1000
|
| 42 |
+
|
| 43 |
+
# Output image size should match training.
|
| 44 |
+
IMAGE_SIZE = 32
|
| 45 |
+
|
| 46 |
+
# Reproducibility.
|
| 47 |
+
SEED = 42
|
| 48 |
+
|
| 49 |
+
# Optional performance knobs.
|
| 50 |
+
USE_AMP = torch.cuda.is_available()
|
| 51 |
+
ALLOW_TF32 = True
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# -----------------------------------------------------------------------------
|
| 55 |
+
# Helpers
|
| 56 |
+
# -----------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
def _selected_digits() -> list[int]:
|
| 59 |
+
if USE_MULTIPLE_DIGITS:
|
| 60 |
+
if not DIGITS:
|
| 61 |
+
raise ValueError("DIGITS must not be empty when USE_MULTIPLE_DIGITS=True")
|
| 62 |
+
return [int(d) for d in DIGITS]
|
| 63 |
+
return [int(DIGIT)]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _load_model(model_id: str, device: torch.device):
|
| 67 |
+
"""Load the custom HF model without defining any local model classes."""
|
| 68 |
+
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
|
| 69 |
+
model.to(device)
|
| 70 |
+
model.eval()
|
| 71 |
+
return model
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _load_scheduler(model_id: str) -> DDPMScheduler:
|
| 75 |
+
return DDPMScheduler.from_pretrained(model_id)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _to_display_range(x: torch.Tensor) -> torch.Tensor:
|
| 79 |
+
"""Map tensors from [-1, 1] to [0, 1]."""
|
| 80 |
+
return ((x.clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@torch.inference_mode()
|
| 84 |
+
def generate_grid(
|
| 85 |
+
model,
|
| 86 |
+
scheduler: DDPMScheduler,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
digits: Iterable[int],
|
| 89 |
+
images_per_digit: int,
|
| 90 |
+
num_inference_steps: int,
|
| 91 |
+
image_size: int,
|
| 92 |
+
) -> torch.Tensor:
|
| 93 |
+
"""Generate a single grid image containing all requested digits."""
|
| 94 |
+
scheduler.set_timesteps(num_inference_steps, device=device)
|
| 95 |
+
|
| 96 |
+
rows: list[torch.Tensor] = []
|
| 97 |
+
for digit in digits:
|
| 98 |
+
autocast_ctx = (
|
| 99 |
+
torch.autocast(device_type="cuda", dtype=torch.float16)
|
| 100 |
+
if USE_AMP and device.type == "cuda"
|
| 101 |
+
else nullcontext()
|
| 102 |
+
)
|
| 103 |
+
latents = torch.randn(
|
| 104 |
+
images_per_digit,
|
| 105 |
+
1,
|
| 106 |
+
image_size,
|
| 107 |
+
image_size,
|
| 108 |
+
device=device,
|
| 109 |
+
)
|
| 110 |
+
class_labels = torch.full(
|
| 111 |
+
(images_per_digit,),
|
| 112 |
+
int(digit),
|
| 113 |
+
device=device,
|
| 114 |
+
dtype=torch.long,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
with autocast_ctx:
|
| 118 |
+
for t in scheduler.timesteps:
|
| 119 |
+
t_batch = torch.full(
|
| 120 |
+
(images_per_digit,),
|
| 121 |
+
int(t),
|
| 122 |
+
device=device,
|
| 123 |
+
dtype=torch.long,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
output = model(
|
| 127 |
+
noisy_images=latents,
|
| 128 |
+
timesteps=t_batch,
|
| 129 |
+
class_labels=class_labels,
|
| 130 |
+
)
|
| 131 |
+
noise_pred = output.sample if hasattr(output, "sample") else output[0]
|
| 132 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
| 133 |
+
|
| 134 |
+
rows.append(_to_display_range(latents))
|
| 135 |
+
|
| 136 |
+
all_images = torch.cat(rows, dim=0)
|
| 137 |
+
nrow = images_per_digit
|
| 138 |
+
grid = make_grid(all_images, nrow=nrow)
|
| 139 |
+
return grid
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# -----------------------------------------------------------------------------
|
| 143 |
+
# Main
|
| 144 |
+
# -----------------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
+
def main() -> None:
|
| 147 |
+
if ALLOW_TF32 and torch.cuda.is_available():
|
| 148 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 149 |
+
|
| 150 |
+
torch.manual_seed(SEED)
|
| 151 |
+
if torch.cuda.is_available():
|
| 152 |
+
torch.cuda.manual_seed_all(SEED)
|
| 153 |
+
|
| 154 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 155 |
+
digits = _selected_digits()
|
| 156 |
+
|
| 157 |
+
print(f"[info] model_id = {MODEL_ID}")
|
| 158 |
+
print(f"[info] device = {device}")
|
| 159 |
+
print(f"[info] digits = {digits}")
|
| 160 |
+
print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
|
| 161 |
+
print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
|
| 162 |
+
print(f"[info] output_image = {OUTPUT_IMAGE}")
|
| 163 |
+
|
| 164 |
+
model = _load_model(MODEL_ID, device)
|
| 165 |
+
scheduler = _load_scheduler(MODEL_ID)
|
| 166 |
+
|
| 167 |
+
grid = generate_grid(
|
| 168 |
+
model=model,
|
| 169 |
+
scheduler=scheduler,
|
| 170 |
+
device=device,
|
| 171 |
+
digits=digits,
|
| 172 |
+
images_per_digit=IMAGES_PER_DIGIT,
|
| 173 |
+
num_inference_steps=NUM_INFERENCE_STEPS,
|
| 174 |
+
image_size=IMAGE_SIZE,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
out_path = Path(OUTPUT_IMAGE)
|
| 178 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 179 |
+
save_image(grid, out_path)
|
| 180 |
+
print(f"[done] saved -> {out_path.resolve()}")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|