dagloop5 committed on
Commit
d3e2661
·
verified ·
1 Parent(s): dac6133

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -47,23 +47,7 @@ from safetensors.torch import load_file
47
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
48
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
49
 
50
- try:
51
- from ltx_core.loader.fuse_loras import apply_loras
52
- except ImportError:
53
- from ltx_core.loader.fuse_loras import fuse_lora_weights
54
-
55
- def apply_loras(model_sd, loras, dtype=None):
56
- # fuse_lora_weights is the lower-level helper the repo uses internally;
57
- # this wrapper turns its output into a regular state_dict.
58
- return {
59
- k: v
60
- for k, v in fuse_lora_weights(
61
- model_sd,
62
- loras,
63
- dtype=dtype,
64
- preserve_input_device=False,
65
- )
66
- }
67
 
68
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
69
  from ltx_core.components.noisers import GaussianNoiser
@@ -459,11 +443,17 @@ def apply_current_loras_to_transformer(
459
  )
460
  ]
461
 
462
- fused_state = apply_loras(
463
- BASE_TRANSFORMER_STATE,
 
 
 
 
464
  loras,
465
  dtype=pipeline.model_ledger.dtype,
466
  )
 
 
467
  LORA_STATE_CACHE[key] = fused_state
468
 
469
  with torch.no_grad():
@@ -499,7 +489,9 @@ BASE_TRANSFORMER_STATE = {
499
  k: v.detach().cpu().contiguous()
500
  for k, v in _transformer.state_dict().items()
501
  }
502
-
 
 
503
  ACTIVE_LORA_KEY: str | None = None
504
  LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
505
  _video_encoder = _orig_video_encoder_factory()
 
47
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
48
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
49
 
50
+ from ltx_core.loader.fuse_loras import apply_loras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
53
  from ltx_core.components.noisers import GaussianNoiser
 
443
  )
444
  ]
445
 
446
+ base_model_sd = _StateDictModel(
447
+ {k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
448
+ )
449
+
450
+ fused_model_sd = apply_loras(
451
+ base_model_sd,
452
  loras,
453
  dtype=pipeline.model_ledger.dtype,
454
  )
455
+
456
+ fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
457
  LORA_STATE_CACHE[key] = fused_state
458
 
459
  with torch.no_grad():
 
489
  k: v.detach().cpu().contiguous()
490
  for k, v in _transformer.state_dict().items()
491
  }
492
+ class _StateDictModel:
493
+ def __init__(self, sd: dict[str, torch.Tensor]):
494
+ self.sd = sd
495
  ACTIVE_LORA_KEY: str | None = None
496
  LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
497
  _video_encoder = _orig_video_encoder_factory()