dagloop5 commited on
Commit
fcd3e09
·
verified ·
1 Parent(s): d3e2661

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -7
app.py CHANGED
@@ -44,8 +44,13 @@ import gradio as gr
44
  import numpy as np
45
  from huggingface_hub import hf_hub_download, snapshot_download
46
  from safetensors.torch import load_file
47
- from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
 
 
 
 
48
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
 
49
 
50
  from ltx_core.loader.fuse_loras import apply_loras
51
 
@@ -344,6 +349,21 @@ pipeline = LTX23DistilledA2VPipeline(
344
  )
345
  # ----------------------------------------------------------------
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
348
  rp = round(float(pose_strength), 2)
349
  rg = round(float(general_strength), 2)
@@ -434,15 +454,30 @@ def apply_current_loras_to_transformer(
434
  fused_state = LORA_STATE_CACHE[key]
435
  else:
436
  loras = [
437
- LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
 
 
 
438
  for path, strength in _collect_lora_specs(
439
- pose_strength, general_strength, motion_strength, dreamlay_strength,
440
- mself_strength, dramatic_strength, fluid_strength, liquid_strength,
441
- demopose_strength, voice_strength, realism_strength, transition_strength,
442
- physics_strength, reasoning_strength, twostep_strength,
 
 
 
 
 
 
 
 
 
 
 
443
  )
 
444
  ]
445
-
446
  base_model_sd = _StateDictModel(
447
  {k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
448
  )
@@ -453,6 +488,12 @@ def apply_current_loras_to_transformer(
453
  dtype=pipeline.model_ledger.dtype,
454
  )
455
 
 
 
 
 
 
 
456
  fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
457
  LORA_STATE_CACHE[key] = fused_state
458
 
 
44
  import numpy as np
45
  from huggingface_hub import hf_hub_download, snapshot_download
46
  from safetensors.torch import load_file
47
+ from ltx_core.loader.primitives import (
48
+ StateDict,
49
+ LoraPathStrengthAndSDOps,
50
+ LoraStateDictWithStrength,
51
+ )
52
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
53
+ from safetensors import safe_open
54
 
55
  from ltx_core.loader.fuse_loras import apply_loras
56
 
 
349
  )
350
  # ----------------------------------------------------------------
351
 
352
+ LORA_STATE_DICT_CACHE: dict[str, StateDict] = {}
353
+
354
+ def _load_lora_state_dict(path: str) -> StateDict:
355
+ if path in LORA_STATE_DICT_CACHE:
356
+ return LORA_STATE_DICT_CACHE[path]
357
+
358
+ with safe_open(path, framework="pt", device="cpu") as f:
359
+ tensors = {k: f.get_tensor(k).contiguous() for k in f.keys()}
360
+
361
+ size = sum(t.numel() * t.element_size() for t in tensors.values())
362
+ dtypes = {t.dtype for t in tensors.values()}
363
+ sd = StateDict(sd=tensors, device=torch.device("cpu"), size=size, dtype=dtypes)
364
+ LORA_STATE_DICT_CACHE[path] = sd
365
+ return sd
366
+
367
  def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
368
  rp = round(float(pose_strength), 2)
369
  rg = round(float(general_strength), 2)
 
454
  fused_state = LORA_STATE_CACHE[key]
455
  else:
456
  loras = [
457
+ LoraStateDictWithStrength(
458
+ state_dict=_load_lora_state_dict(path),
459
+ strength=strength,
460
+ )
461
  for path, strength in _collect_lora_specs(
462
+ pose_strength,
463
+ general_strength,
464
+ motion_strength,
465
+ dreamlay_strength,
466
+ mself_strength,
467
+ dramatic_strength,
468
+ fluid_strength,
469
+ liquid_strength,
470
+ demopose_strength,
471
+ voice_strength,
472
+ realism_strength,
473
+ transition_strength,
474
+ physics_strength,
475
+ reasoning_strength,
476
+ twostep_strength,
477
  )
478
+ if strength != 0.0
479
  ]
480
+
481
  base_model_sd = _StateDictModel(
482
  {k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
483
  )
 
488
  dtype=pipeline.model_ledger.dtype,
489
  )
490
 
491
+ fused_state = (
492
+ fused_model_sd.sd
493
+ if hasattr(fused_model_sd, "sd")
494
+ else fused_model_sd
495
+ )
496
+
497
  fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
498
  LORA_STATE_CACHE[key] = fused_state
499