Update app.py
Browse files
app.py
CHANGED
|
@@ -44,8 +44,13 @@ import gradio as gr
|
|
| 44 |
import numpy as np
|
| 45 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 46 |
from safetensors.torch import load_file
|
| 47 |
-
from ltx_core.loader.primitives import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
|
|
|
| 49 |
|
| 50 |
from ltx_core.loader.fuse_loras import apply_loras
|
| 51 |
|
|
@@ -344,6 +349,21 @@ pipeline = LTX23DistilledA2VPipeline(
|
|
| 344 |
)
|
| 345 |
# ----------------------------------------------------------------
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
|
| 348 |
rp = round(float(pose_strength), 2)
|
| 349 |
rg = round(float(general_strength), 2)
|
|
@@ -434,15 +454,30 @@ def apply_current_loras_to_transformer(
|
|
| 434 |
fused_state = LORA_STATE_CACHE[key]
|
| 435 |
else:
|
| 436 |
loras = [
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
| 438 |
for path, strength in _collect_lora_specs(
|
| 439 |
-
pose_strength,
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
)
|
|
|
|
| 444 |
]
|
| 445 |
-
|
| 446 |
base_model_sd = _StateDictModel(
|
| 447 |
{k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
|
| 448 |
)
|
|
@@ -453,6 +488,12 @@ def apply_current_loras_to_transformer(
|
|
| 453 |
dtype=pipeline.model_ledger.dtype,
|
| 454 |
)
|
| 455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
|
| 457 |
LORA_STATE_CACHE[key] = fused_state
|
| 458 |
|
|
|
|
| 44 |
import numpy as np
|
| 45 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 46 |
from safetensors.torch import load_file
|
| 47 |
+
from ltx_core.loader.primitives import (
|
| 48 |
+
StateDict,
|
| 49 |
+
LoraPathStrengthAndSDOps,
|
| 50 |
+
LoraStateDictWithStrength,
|
| 51 |
+
)
|
| 52 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 53 |
+
from safetensors import safe_open
|
| 54 |
|
| 55 |
from ltx_core.loader.fuse_loras import apply_loras
|
| 56 |
|
|
|
|
| 349 |
)
|
| 350 |
# ----------------------------------------------------------------
|
| 351 |
|
| 352 |
+
LORA_STATE_DICT_CACHE: dict[str, StateDict] = {}
|
| 353 |
+
|
| 354 |
+
def _load_lora_state_dict(path: str) -> StateDict:
|
| 355 |
+
if path in LORA_STATE_DICT_CACHE:
|
| 356 |
+
return LORA_STATE_DICT_CACHE[path]
|
| 357 |
+
|
| 358 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
| 359 |
+
tensors = {k: f.get_tensor(k).contiguous() for k in f.keys()}
|
| 360 |
+
|
| 361 |
+
size = sum(t.numel() * t.element_size() for t in tensors.values())
|
| 362 |
+
dtypes = {t.dtype for t in tensors.values()}
|
| 363 |
+
sd = StateDict(sd=tensors, device=torch.device("cpu"), size=size, dtype=dtypes)
|
| 364 |
+
LORA_STATE_DICT_CACHE[path] = sd
|
| 365 |
+
return sd
|
| 366 |
+
|
| 367 |
def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
|
| 368 |
rp = round(float(pose_strength), 2)
|
| 369 |
rg = round(float(general_strength), 2)
|
|
|
|
| 454 |
fused_state = LORA_STATE_CACHE[key]
|
| 455 |
else:
|
| 456 |
loras = [
|
| 457 |
+
LoraStateDictWithStrength(
|
| 458 |
+
state_dict=_load_lora_state_dict(path),
|
| 459 |
+
strength=strength,
|
| 460 |
+
)
|
| 461 |
for path, strength in _collect_lora_specs(
|
| 462 |
+
pose_strength,
|
| 463 |
+
general_strength,
|
| 464 |
+
motion_strength,
|
| 465 |
+
dreamlay_strength,
|
| 466 |
+
mself_strength,
|
| 467 |
+
dramatic_strength,
|
| 468 |
+
fluid_strength,
|
| 469 |
+
liquid_strength,
|
| 470 |
+
demopose_strength,
|
| 471 |
+
voice_strength,
|
| 472 |
+
realism_strength,
|
| 473 |
+
transition_strength,
|
| 474 |
+
physics_strength,
|
| 475 |
+
reasoning_strength,
|
| 476 |
+
twostep_strength,
|
| 477 |
)
|
| 478 |
+
if strength != 0.0
|
| 479 |
]
|
| 480 |
+
|
| 481 |
base_model_sd = _StateDictModel(
|
| 482 |
{k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
|
| 483 |
)
|
|
|
|
| 488 |
dtype=pipeline.model_ledger.dtype,
|
| 489 |
)
|
| 490 |
|
| 491 |
+
fused_state = (
|
| 492 |
+
fused_model_sd.sd
|
| 493 |
+
if hasattr(fused_model_sd, "sd")
|
| 494 |
+
else fused_model_sd
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
|
| 498 |
LORA_STATE_CACHE[key] = fused_state
|
| 499 |
|