Harley-ml commited on
Commit
837c003
·
verified ·
1 Parent(s): 32272db

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -37
inference.py CHANGED
@@ -1,15 +1,4 @@
1
  #!/usr/bin/env python3
2
- """Simple Hugging Face inference script for the digit diffusion model.
3
-
4
- No command-line arguments. Edit the values in the CONFIG section below.
5
-
6
- What it does:
7
- - loads the model from the Hugging Face Hub (or a local HF cache/path)
8
- - loads the DDPM scheduler from the same repo
9
- - generates one or more images for one digit or several digits
10
- - saves everything as a single PNG grid
11
- """
12
-
13
  from __future__ import annotations
14
 
15
  from contextlib import nullcontext
@@ -22,49 +11,31 @@ from torchvision.utils import make_grid, save_image
22
  from transformers import AutoModel
23
 
24
 
25
- # -----------------------------------------------------------------------------
26
- # CONFIG — edit these values only
27
- # -----------------------------------------------------------------------------
28
-
29
- MODEL_ID = "your-hf-username/your-digit-diffusion-repo"
30
- OUTPUT_IMAGE = "./digit_samples.png"
31
 
32
- # Choose either a single digit or multiple digits.
 
33
  USE_MULTIPLE_DIGITS = False
34
- DIGIT = 7
35
  DIGITS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
36
-
37
- # How many images to generate for each selected digit.
38
  IMAGES_PER_DIGIT = 4
39
-
40
- # Number of denoising steps.
41
  NUM_INFERENCE_STEPS = 1000
42
-
43
- # Output image size should match training.
44
  IMAGE_SIZE = 32
45
-
46
- # Reproducibility.
47
  SEED = 42
48
-
49
- # Optional performance knobs.
50
  USE_AMP = torch.cuda.is_available()
51
- ALLOW_TF32 = True
52
 
53
-
54
- # -----------------------------------------------------------------------------
55
  # Helpers
56
- # -----------------------------------------------------------------------------
57
 
58
  def _selected_digits() -> list[int]:
59
  if USE_MULTIPLE_DIGITS:
60
  if not DIGITS:
61
- raise ValueError("DIGITS must not be empty when USE_MULTIPLE_DIGITS=True")
62
  return [int(d) for d in DIGITS]
63
  return [int(DIGIT)]
64
 
65
 
66
  def _load_model(model_id: str, device: torch.device):
67
- """Load the custom HF model without defining any local model classes."""
68
  model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
69
  model.to(device)
70
  model.eval()
@@ -76,7 +47,6 @@ def _load_scheduler(model_id: str) -> DDPMScheduler:
76
 
77
 
78
  def _to_display_range(x: torch.Tensor) -> torch.Tensor:
79
- """Map tensors from [-1, 1] to [0, 1]."""
80
  return ((x.clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
81
 
82
 
@@ -177,7 +147,7 @@ def main() -> None:
177
  out_path = Path(OUTPUT_IMAGE)
178
  out_path.parent.mkdir(parents=True, exist_ok=True)
179
  save_image(grid, out_path)
180
- print(f"[done] saved -> {out_path.resolve()}")
181
 
182
 
183
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
 
 
 
 
 
 
 
 
 
 
 
2
  from __future__ import annotations
3
 
4
  from contextlib import nullcontext
 
11
  from transformers import AutoModel
12
 
13
 
14
+ # config
 
 
 
 
 
15
 
16
+ MODEL_ID = "Harley-ml/MNIST-IMG-390k"
17
+ OUTPUT_IMAGE = "./digits.png"
18
  USE_MULTIPLE_DIGITS = False
19
+ DIGIT = 1
20
  DIGITS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 
 
21
  IMAGES_PER_DIGIT = 4
 
 
22
  NUM_INFERENCE_STEPS = 1000
 
 
23
  IMAGE_SIZE = 32
 
 
24
  SEED = 42
 
 
25
  USE_AMP = torch.cuda.is_available()
26
+ ALLOW_TF32 = False
27
 
 
 
28
  # Helpers
 
29
 
30
  def _selected_digits() -> list[int]:
31
  if USE_MULTIPLE_DIGITS:
32
  if not DIGITS:
33
+ raise ValueError("`DIGITS` must not be empty when `USE_MULTIPLE_DIGITS=True`")
34
  return [int(d) for d in DIGITS]
35
  return [int(DIGIT)]
36
 
37
 
38
  def _load_model(model_id: str, device: torch.device):
 
39
  model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
40
  model.to(device)
41
  model.eval()
 
47
 
48
 
49
  def _to_display_range(x: torch.Tensor) -> torch.Tensor:
 
50
  return ((x.clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
51
 
52
 
 
147
  out_path = Path(OUTPUT_IMAGE)
148
  out_path.parent.mkdir(parents=True, exist_ok=True)
149
  save_image(grid, out_path)
150
+ print(f"[done] saved to {out_path.resolve()}")
151
 
152
 
153
  if __name__ == "__main__":