SynLayers commited on
Commit
130c3c9
·
verified ·
1 Parent(s): 43bd07d

Upload infer/common_infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer/common_infer.py +13 -3
infer/common_infer.py CHANGED
@@ -29,6 +29,14 @@ def resolve_local_config_path(path: str | None) -> str | None:
29
  return os.path.join(PROJECT_ROOT, path)
30
 
31
 
 
 
 
 
 
 
 
 
32
  def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
33
  """Scale an xyxy box from source_size to target_size."""
34
  scale = target_size / source_size
@@ -92,7 +100,9 @@ def initialize_pipeline(config):
92
  single_layer_decoder=config.get("single_layer_decoder", None),
93
  )
94
  transp_vae = CustomVAE(vae_args)
95
- transp_vae_weights = torch.load(transp_vae_path, map_location=torch.device("cuda"))
 
 
96
  missing_keys, unexpected_keys = transp_vae.load_state_dict(
97
  transp_vae_weights["model"], strict=False
98
  )
@@ -141,7 +151,7 @@ def initialize_pipeline(config):
141
  layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
142
  if os.path.exists(layer_pe_path):
143
  print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
144
- layer_pe = torch.load(layer_pe_path)
145
  transformer.load_state_dict(layer_pe, strict=False)
146
 
147
  print("[INFO] Loading MultiLayer-Adapter...", flush=True)
@@ -174,4 +184,4 @@ def initialize_pipeline(config):
174
  print(f"[INFO] Loading trained LoRA from {lora_ckpt}...", flush=True)
175
  pipeline.load_lora_weights(lora_ckpt, adapter_name="layer")
176
 
177
- return pipeline, transp_vae
 
29
  return os.path.join(PROJECT_ROOT, path)
30
 
31
 
32
+ def load_trusted_checkpoint(path: str, map_location=None):
33
+ """Load SynLayers-owned checkpoints saved before PyTorch changed weights_only defaults."""
34
+ try:
35
+ return torch.load(path, map_location=map_location, weights_only=False)
36
+ except TypeError:
37
+ return torch.load(path, map_location=map_location)
38
+
39
+
40
  def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
41
  """Scale an xyxy box from source_size to target_size."""
42
  scale = target_size / source_size
 
100
  single_layer_decoder=config.get("single_layer_decoder", None),
101
  )
102
  transp_vae = CustomVAE(vae_args)
103
+ transp_vae_weights = load_trusted_checkpoint(
104
+ transp_vae_path, map_location=torch.device("cuda")
105
+ )
106
  missing_keys, unexpected_keys = transp_vae.load_state_dict(
107
  transp_vae_weights["model"], strict=False
108
  )
 
151
  layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
152
  if os.path.exists(layer_pe_path):
153
  print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
154
+ layer_pe = load_trusted_checkpoint(layer_pe_path, map_location=torch.device("cuda"))
155
  transformer.load_state_dict(layer_pe, strict=False)
156
 
157
  print("[INFO] Loading MultiLayer-Adapter...", flush=True)
 
184
  print(f"[INFO] Loading trained LoRA from {lora_ckpt}...", flush=True)
185
  pipeline.load_lora_weights(lora_ckpt, adapter_name="layer")
186
 
187
+ return pipeline, transp_vae