Spaces:
Running on Zero
Running on Zero
update common_infer.py
Browse files- infer/common_infer.py +12 -2
infer/common_infer.py
CHANGED
|
@@ -29,6 +29,14 @@ def resolve_local_config_path(path: str | None) -> str | None:
|
|
| 29 |
return os.path.join(PROJECT_ROOT, path)
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
|
| 33 |
"""Scale an xyxy box from source_size to target_size."""
|
| 34 |
scale = target_size / source_size
|
|
@@ -92,7 +100,9 @@ def initialize_pipeline(config):
|
|
| 92 |
single_layer_decoder=config.get("single_layer_decoder", None),
|
| 93 |
)
|
| 94 |
transp_vae = CustomVAE(vae_args)
|
| 95 |
-
transp_vae_weights =
|
|
|
|
|
|
|
| 96 |
missing_keys, unexpected_keys = transp_vae.load_state_dict(
|
| 97 |
transp_vae_weights["model"], strict=False
|
| 98 |
)
|
|
@@ -141,7 +151,7 @@ def initialize_pipeline(config):
|
|
| 141 |
layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
|
| 142 |
if os.path.exists(layer_pe_path):
|
| 143 |
print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
|
| 144 |
-
layer_pe = torch.
|
| 145 |
transformer.load_state_dict(layer_pe, strict=False)
|
| 146 |
|
| 147 |
print("[INFO] Loading MultiLayer-Adapter...", flush=True)
|
|
|
|
| 29 |
return os.path.join(PROJECT_ROOT, path)
|
| 30 |
|
| 31 |
|
| 32 |
+
def load_trusted_checkpoint(path: str, map_location=None):
|
| 33 |
+
"""Load SynLayers-owned checkpoints saved before PyTorch changed weights_only defaults."""
|
| 34 |
+
try:
|
| 35 |
+
return torch.load(path, map_location=map_location, weights_only=False)
|
| 36 |
+
except TypeError:
|
| 37 |
+
return torch.load(path, map_location=map_location)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
|
| 41 |
"""Scale an xyxy box from source_size to target_size."""
|
| 42 |
scale = target_size / source_size
|
|
|
|
| 100 |
single_layer_decoder=config.get("single_layer_decoder", None),
|
| 101 |
)
|
| 102 |
transp_vae = CustomVAE(vae_args)
|
| 103 |
+
transp_vae_weights = load_trusted_checkpoint(
|
| 104 |
+
transp_vae_path, map_location=torch.device("cuda")
|
| 105 |
+
)
|
| 106 |
missing_keys, unexpected_keys = transp_vae.load_state_dict(
|
| 107 |
transp_vae_weights["model"], strict=False
|
| 108 |
)
|
|
|
|
| 151 |
layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
|
| 152 |
if os.path.exists(layer_pe_path):
|
| 153 |
print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
|
| 154 |
+
layer_pe = load_trusted_checkpoint(layer_pe_path, map_location=torch.device("cuda"))
|
| 155 |
transformer.load_state_dict(layer_pe, strict=False)
|
| 156 |
|
| 157 |
print("[INFO] Loading MultiLayer-Adapter...", flush=True)
|