Harley-ml commited on
Commit
32272db
·
verified ·
1 Parent(s): f886944

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +184 -0
inference.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Simple Hugging Face inference script for the digit diffusion model.
3
+
4
+ No command-line arguments. Edit the values in the CONFIG section below.
5
+
6
+ What it does:
7
+ - loads the model from the Hugging Face Hub (or a local HF cache/path)
8
+ - loads the DDPM scheduler from the same repo
9
+ - generates one or more images for one digit or several digits
10
+ - saves everything as a single PNG grid
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from contextlib import nullcontext
16
+ from pathlib import Path
17
+ from typing import Iterable
18
+
19
+ import torch
20
+ from diffusers import DDPMScheduler
21
+ from torchvision.utils import make_grid, save_image
22
+ from transformers import AutoModel
23
+
24
+
25
+ # -----------------------------------------------------------------------------
26
+ # CONFIG — edit these values only
27
+ # -----------------------------------------------------------------------------
28
+
29
+ MODEL_ID = "your-hf-username/your-digit-diffusion-repo"
30
+ OUTPUT_IMAGE = "./digit_samples.png"
31
+
32
+ # Choose either a single digit or multiple digits.
33
+ USE_MULTIPLE_DIGITS = False
34
+ DIGIT = 7
35
+ DIGITS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
36
+
37
+ # How many images to generate for each selected digit.
38
+ IMAGES_PER_DIGIT = 4
39
+
40
+ # Number of denoising steps.
41
+ NUM_INFERENCE_STEPS = 1000
42
+
43
+ # Output image size should match training.
44
+ IMAGE_SIZE = 32
45
+
46
+ # Reproducibility.
47
+ SEED = 42
48
+
49
+ # Optional performance knobs.
50
+ USE_AMP = torch.cuda.is_available()
51
+ ALLOW_TF32 = True
52
+
53
+
54
+ # -----------------------------------------------------------------------------
55
+ # Helpers
56
+ # -----------------------------------------------------------------------------
57
+
58
+ def _selected_digits() -> list[int]:
59
+ if USE_MULTIPLE_DIGITS:
60
+ if not DIGITS:
61
+ raise ValueError("DIGITS must not be empty when USE_MULTIPLE_DIGITS=True")
62
+ return [int(d) for d in DIGITS]
63
+ return [int(DIGIT)]
64
+
65
+
66
+ def _load_model(model_id: str, device: torch.device):
67
+ """Load the custom HF model without defining any local model classes."""
68
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
69
+ model.to(device)
70
+ model.eval()
71
+ return model
72
+
73
+
74
+ def _load_scheduler(model_id: str) -> DDPMScheduler:
75
+ return DDPMScheduler.from_pretrained(model_id)
76
+
77
+
78
+ def _to_display_range(x: torch.Tensor) -> torch.Tensor:
79
+ """Map tensors from [-1, 1] to [0, 1]."""
80
+ return ((x.clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
81
+
82
+
83
+ @torch.inference_mode()
84
+ def generate_grid(
85
+ model,
86
+ scheduler: DDPMScheduler,
87
+ device: torch.device,
88
+ digits: Iterable[int],
89
+ images_per_digit: int,
90
+ num_inference_steps: int,
91
+ image_size: int,
92
+ ) -> torch.Tensor:
93
+ """Generate a single grid image containing all requested digits."""
94
+ scheduler.set_timesteps(num_inference_steps, device=device)
95
+
96
+ rows: list[torch.Tensor] = []
97
+ for digit in digits:
98
+ autocast_ctx = (
99
+ torch.autocast(device_type="cuda", dtype=torch.float16)
100
+ if USE_AMP and device.type == "cuda"
101
+ else nullcontext()
102
+ )
103
+ latents = torch.randn(
104
+ images_per_digit,
105
+ 1,
106
+ image_size,
107
+ image_size,
108
+ device=device,
109
+ )
110
+ class_labels = torch.full(
111
+ (images_per_digit,),
112
+ int(digit),
113
+ device=device,
114
+ dtype=torch.long,
115
+ )
116
+
117
+ with autocast_ctx:
118
+ for t in scheduler.timesteps:
119
+ t_batch = torch.full(
120
+ (images_per_digit,),
121
+ int(t),
122
+ device=device,
123
+ dtype=torch.long,
124
+ )
125
+
126
+ output = model(
127
+ noisy_images=latents,
128
+ timesteps=t_batch,
129
+ class_labels=class_labels,
130
+ )
131
+ noise_pred = output.sample if hasattr(output, "sample") else output[0]
132
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
133
+
134
+ rows.append(_to_display_range(latents))
135
+
136
+ all_images = torch.cat(rows, dim=0)
137
+ nrow = images_per_digit
138
+ grid = make_grid(all_images, nrow=nrow)
139
+ return grid
140
+
141
+
142
+ # -----------------------------------------------------------------------------
143
+ # Main
144
+ # -----------------------------------------------------------------------------
145
+
146
+ def main() -> None:
147
+ if ALLOW_TF32 and torch.cuda.is_available():
148
+ torch.backends.cuda.matmul.allow_tf32 = True
149
+
150
+ torch.manual_seed(SEED)
151
+ if torch.cuda.is_available():
152
+ torch.cuda.manual_seed_all(SEED)
153
+
154
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
+ digits = _selected_digits()
156
+
157
+ print(f"[info] model_id = {MODEL_ID}")
158
+ print(f"[info] device = {device}")
159
+ print(f"[info] digits = {digits}")
160
+ print(f"[info] images_per_digit = {IMAGES_PER_DIGIT}")
161
+ print(f"[info] num_steps = {NUM_INFERENCE_STEPS}")
162
+ print(f"[info] output_image = {OUTPUT_IMAGE}")
163
+
164
+ model = _load_model(MODEL_ID, device)
165
+ scheduler = _load_scheduler(MODEL_ID)
166
+
167
+ grid = generate_grid(
168
+ model=model,
169
+ scheduler=scheduler,
170
+ device=device,
171
+ digits=digits,
172
+ images_per_digit=IMAGES_PER_DIGIT,
173
+ num_inference_steps=NUM_INFERENCE_STEPS,
174
+ image_size=IMAGE_SIZE,
175
+ )
176
+
177
+ out_path = Path(OUTPUT_IMAGE)
178
+ out_path.parent.mkdir(parents=True, exist_ok=True)
179
+ save_image(grid, out_path)
180
+ print(f"[done] saved -> {out_path.resolve()}")
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()