Update app.py
Browse files
app.py
CHANGED
|
@@ -423,6 +423,9 @@ def apply_current_loras_to_transformer(
|
|
| 423 |
):
|
| 424 |
global _transformer
|
| 425 |
|
|
|
|
|
|
|
|
|
|
| 426 |
# Collect non-zero strength LoRAs
|
| 427 |
lora_specs = _collect_lora_specs(
|
| 428 |
pose_strength, general_strength, motion_strength, dreamlay_strength,
|
|
@@ -465,7 +468,7 @@ def apply_current_loras_to_transformer(
|
|
| 465 |
# Load fused state dict into transformer
|
| 466 |
with torch.no_grad():
|
| 467 |
fused_state_cuda = {
|
| 468 |
-
k: (v.to(
|
| 469 |
for k, v in fused_state.items()
|
| 470 |
}
|
| 471 |
missing, unexpected = _transformer.load_state_dict(fused_state_cuda, strict=False)
|
|
|
|
| 423 |
):
|
| 424 |
global _transformer
|
| 425 |
|
| 426 |
+
# Get device from the pipeline (fallback: first parameter of transformer)
|
| 427 |
+
device = getattr(pipeline, 'device', None) or next(_transformer.parameters()).device
|
| 428 |
+
|
| 429 |
# Collect non-zero strength LoRAs
|
| 430 |
lora_specs = _collect_lora_specs(
|
| 431 |
pose_strength, general_strength, motion_strength, dreamlay_strength,
|
|
|
|
| 468 |
# Load fused state dict into transformer
|
| 469 |
with torch.no_grad():
|
| 470 |
fused_state_cuda = {
|
| 471 |
+
k: (v.to(device) if v.device == torch.device("cpu") else v)
|
| 472 |
for k, v in fused_state.items()
|
| 473 |
}
|
| 474 |
missing, unexpected = _transformer.load_state_dict(fused_state_cuda, strict=False)
|