dagloop5 committed on
Commit
999ba54
·
verified ·
1 Parent(s): 73f01da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -423,6 +423,9 @@ def apply_current_loras_to_transformer(
423
  ):
424
  global _transformer
425
 
 
 
 
426
  # Collect non-zero strength LoRAs
427
  lora_specs = _collect_lora_specs(
428
  pose_strength, general_strength, motion_strength, dreamlay_strength,
@@ -465,7 +468,7 @@ def apply_current_loras_to_transformer(
465
  # Load fused state dict into transformer
466
  with torch.no_grad():
467
  fused_state_cuda = {
468
- k: (v.to(_transformer.device) if v.device == torch.device("cpu") else v)
469
  for k, v in fused_state.items()
470
  }
471
  missing, unexpected = _transformer.load_state_dict(fused_state_cuda, strict=False)
 
423
  ):
424
  global _transformer
425
 
426
+ # Get device from the pipeline (fallback: first parameter of transformer)
427
+ device = getattr(pipeline, 'device', None) or next(_transformer.parameters()).device
428
+
429
  # Collect non-zero strength LoRAs
430
  lora_specs = _collect_lora_specs(
431
  pose_strength, general_strength, motion_strength, dreamlay_strength,
 
468
  # Load fused state dict into transformer
469
  with torch.no_grad():
470
  fused_state_cuda = {
471
+ k: (v.to(device) if v.device == torch.device("cpu") else v)
472
  for k, v in fused_state.items()
473
  }
474
  missing, unexpected = _transformer.load_state_dict(fused_state_cuda, strict=False)