Harley-ml commited on
Commit
4e76fed
·
verified ·
1 Parent(s): 67d5b01

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -7
inference.py CHANGED
@@ -23,10 +23,6 @@ USE_AMP = torch.cuda.is_available()
23
  ALLOW_TF32 = False
24
 
25
 
26
- # -----------------------------------------------------------------------------
27
- # Helpers
28
- # -----------------------------------------------------------------------------
29
-
30
  def _selected_digits() -> list[int]:
31
  if USE_MULTIPLE_DIGITS:
32
  if not DIGITS:
@@ -43,7 +39,7 @@ def _load_model(model_id: str, device: torch.device):
43
 
44
 
45
  def _load_scheduler(model_id: str) -> DDPMScheduler:
46
- return DDPMScheduler.from_pretrained(model_id)
47
 
48
 
49
  def _to_display_range(x: torch.Tensor) -> torch.Tensor:
@@ -60,7 +56,6 @@ def generate_grid(
60
  num_inference_steps: int,
61
  image_size: int,
62
  ) -> torch.Tensor:
63
- # Generate a single grid image containing all requested digits.
64
  scheduler.set_timesteps(num_inference_steps, device=device)
65
 
66
  rows: list[torch.Tensor] = []
@@ -108,6 +103,7 @@ def generate_grid(
108
  grid = make_grid(all_images, nrow=nrow)
109
  return grid
110
 
 
111
  def main() -> None:
112
  if ALLOW_TF32 and torch.cuda.is_available():
113
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -146,4 +142,4 @@ def main() -> None:
146
 
147
 
148
  if __name__ == "__main__":
149
- main()
 
23
  ALLOW_TF32 = False
24
 
25
 
 
 
 
 
26
  def _selected_digits() -> list[int]:
27
  if USE_MULTIPLE_DIGITS:
28
  if not DIGITS:
 
39
 
40
 
41
  def _load_scheduler(model_id: str) -> DDPMScheduler:
42
+ return DDPMScheduler.from_pretrained(model_id, trust_remote_code=True)
43
 
44
 
45
  def _to_display_range(x: torch.Tensor) -> torch.Tensor:
 
56
  num_inference_steps: int,
57
  image_size: int,
58
  ) -> torch.Tensor:
 
59
  scheduler.set_timesteps(num_inference_steps, device=device)
60
 
61
  rows: list[torch.Tensor] = []
 
103
  grid = make_grid(all_images, nrow=nrow)
104
  return grid
105
 
106
+
107
  def main() -> None:
108
  if ALLOW_TF32 and torch.cuda.is_available():
109
  torch.backends.cuda.matmul.allow_tf32 = True
 
142
 
143
 
144
  if __name__ == "__main__":
145
+ main()