Update app.py
Browse files
app.py
CHANGED
|
@@ -47,23 +47,7 @@ from safetensors.torch import load_file
|
|
| 47 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 48 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 49 |
|
| 50 |
-
|
| 51 |
-
from ltx_core.loader.fuse_loras import apply_loras
|
| 52 |
-
except ImportError:
|
| 53 |
-
from ltx_core.loader.fuse_loras import fuse_lora_weights
|
| 54 |
-
|
| 55 |
-
def apply_loras(model_sd, loras, dtype=None):
|
| 56 |
-
# fuse_lora_weights is the lower-level helper the repo uses internally;
|
| 57 |
-
# this wrapper turns its output into a regular state_dict.
|
| 58 |
-
return {
|
| 59 |
-
k: v
|
| 60 |
-
for k, v in fuse_lora_weights(
|
| 61 |
-
model_sd,
|
| 62 |
-
loras,
|
| 63 |
-
dtype=dtype,
|
| 64 |
-
preserve_input_device=False,
|
| 65 |
-
)
|
| 66 |
-
}
|
| 67 |
|
| 68 |
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 69 |
from ltx_core.components.noisers import GaussianNoiser
|
|
@@ -459,11 +443,17 @@ def apply_current_loras_to_transformer(
|
|
| 459 |
)
|
| 460 |
]
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
loras,
|
| 465 |
dtype=pipeline.model_ledger.dtype,
|
| 466 |
)
|
|
|
|
|
|
|
| 467 |
LORA_STATE_CACHE[key] = fused_state
|
| 468 |
|
| 469 |
with torch.no_grad():
|
|
@@ -499,7 +489,9 @@ BASE_TRANSFORMER_STATE = {
|
|
| 499 |
k: v.detach().cpu().contiguous()
|
| 500 |
for k, v in _transformer.state_dict().items()
|
| 501 |
}
|
| 502 |
-
|
|
|
|
|
|
|
| 503 |
ACTIVE_LORA_KEY: str | None = None
|
| 504 |
LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
|
| 505 |
_video_encoder = _orig_video_encoder_factory()
|
|
|
|
| 47 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 48 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 49 |
|
| 50 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 53 |
from ltx_core.components.noisers import GaussianNoiser
|
|
|
|
| 443 |
)
|
| 444 |
]
|
| 445 |
|
| 446 |
+
base_model_sd = _StateDictModel(
|
| 447 |
+
{k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
fused_model_sd = apply_loras(
|
| 451 |
+
base_model_sd,
|
| 452 |
loras,
|
| 453 |
dtype=pipeline.model_ledger.dtype,
|
| 454 |
)
|
| 455 |
+
|
| 456 |
+
fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
|
| 457 |
LORA_STATE_CACHE[key] = fused_state
|
| 458 |
|
| 459 |
with torch.no_grad():
|
|
|
|
| 489 |
k: v.detach().cpu().contiguous()
|
| 490 |
for k, v in _transformer.state_dict().items()
|
| 491 |
}
|
| 492 |
+
class _StateDictModel:
|
| 493 |
+
def __init__(self, sd: dict[str, torch.Tensor]):
|
| 494 |
+
self.sd = sd
|
| 495 |
ACTIVE_LORA_KEY: str | None = None
|
| 496 |
LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
|
| 497 |
_video_encoder = _orig_video_encoder_factory()
|