SynLayers commited on
Commit
7e06453
·
1 Parent(s): 5962c55

update common_infer.py

Browse files
Files changed (1) hide show
  1. infer/common_infer.py +12 -2
infer/common_infer.py CHANGED
@@ -29,6 +29,14 @@ def resolve_local_config_path(path: str | None) -> str | None:
29
  return os.path.join(PROJECT_ROOT, path)
30
 
31
 
 
 
 
 
 
 
 
 
32
  def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
33
  """Scale an xyxy box from source_size to target_size."""
34
  scale = target_size / source_size
@@ -92,7 +100,9 @@ def initialize_pipeline(config):
92
  single_layer_decoder=config.get("single_layer_decoder", None),
93
  )
94
  transp_vae = CustomVAE(vae_args)
95
- transp_vae_weights = torch.load(transp_vae_path, map_location=torch.device("cuda"))
 
 
96
  missing_keys, unexpected_keys = transp_vae.load_state_dict(
97
  transp_vae_weights["model"], strict=False
98
  )
@@ -141,7 +151,7 @@ def initialize_pipeline(config):
141
  layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
142
  if os.path.exists(layer_pe_path):
143
  print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
144
- layer_pe = torch.load(layer_pe_path)
145
  transformer.load_state_dict(layer_pe, strict=False)
146
 
147
  print("[INFO] Loading MultiLayer-Adapter...", flush=True)
 
29
  return os.path.join(PROJECT_ROOT, path)
30
 
31
 
32
+ def load_trusted_checkpoint(path: str, map_location=None):
33
+ """Load SynLayers-owned checkpoints saved before PyTorch changed weights_only defaults."""
34
+ try:
35
+ return torch.load(path, map_location=map_location, weights_only=False)
36
+ except TypeError:
37
+ return torch.load(path, map_location=map_location)
38
+
39
+
40
  def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
41
  """Scale an xyxy box from source_size to target_size."""
42
  scale = target_size / source_size
 
100
  single_layer_decoder=config.get("single_layer_decoder", None),
101
  )
102
  transp_vae = CustomVAE(vae_args)
103
+ transp_vae_weights = load_trusted_checkpoint(
104
+ transp_vae_path, map_location=torch.device("cuda")
105
+ )
106
  missing_keys, unexpected_keys = transp_vae.load_state_dict(
107
  transp_vae_weights["model"], strict=False
108
  )
 
151
  layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
152
  if os.path.exists(layer_pe_path):
153
  print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
154
+ layer_pe = load_trusted_checkpoint(layer_pe_path, map_location=torch.device("cuda"))
155
  transformer.load_state_dict(layer_pe, strict=False)
156
 
157
  print("[INFO] Loading MultiLayer-Adapter...", flush=True)