dagloop5 committed
Commit 73f01da · verified · 1 parent: 81a7aff

Update app.py

Files changed (1): app.py (+56 -89)
app.py CHANGED
@@ -349,40 +349,21 @@ pipeline = LTX23DistilledA2VPipeline(
 )
 # ----------------------------------------------------------------

-LORA_STATE_DICT_CACHE: dict[str, StateDict] = {}
-
 def _load_lora_state_dict(path: str) -> StateDict:
-    if path in LORA_STATE_DICT_CACHE:
-        return LORA_STATE_DICT_CACHE[path]
-
+    # Note: Per-request LoRA loading (no caching).
+    # If performance becomes an issue, add caching back with correct StateDict handling.
     with safe_open(path, framework="pt", device="cpu") as f:
-        tensors = {k: f.get_tensor(k).contiguous() for k in f.keys()}
+        tensors = {}
+        for key in f.keys():
+            # Apply ComfyUI→base-model key renaming so LoRA weights match transformer keys
+            renamed_key = LTXV_LORA_COMFY_RENAMING_MAP.apply_to_key(key)
+            if renamed_key is None:
+                renamed_key = key  # Keep original if no renaming match
+            tensors[renamed_key] = f.get_tensor(key).contiguous()

     size = sum(t.numel() * t.element_size() for t in tensors.values())
     dtypes = {t.dtype for t in tensors.values()}
-    sd = StateDict(sd=tensors, device=torch.device("cpu"), size=size, dtype=dtypes)
-    LORA_STATE_DICT_CACHE[path] = sd
-    return sd
-
-def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
-    rp = round(float(pose_strength), 2)
-    rg = round(float(general_strength), 2)
-    rm = round(float(motion_strength), 2)
-    rd = round(float(dreamlay_strength), 2)
-    rs = round(float(mself_strength), 2)
-    rr = round(float(dramatic_strength), 2)
-    rf = round(float(fluid_strength), 2)
-    rl = round(float(liquid_strength), 2)
-    ro = round(float(demopose_strength), 2)
-    rv = round(float(voice_strength), 2)
-    re = round(float(realism_strength), 2)
-    rt = round(float(transition_strength), 2)
-    ry = round(float(physics_strength), 2)
-    ri = round(float(reasoning_strength), 2)
-    rw = round(float(twostep_strength), 2)
-    key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}|{dreamlay_lora_path}:{rd}|{mself_lora_path}:{rs}|{dramatic_lora_path}:{rr}|{fluid_lora_path}:{rf}|{liquid_lora_path}:{rl}|{demopose_lora_path}:{ro}|{voice_lora_path}:{rv}|{realism_lora_path}:{re}|{transition_lora_path}:{rt}|{physics_lora_path}:{ry}|{reasoning_lora_path}:{ri}|{twostep_lora_path}:{rw}"
-    key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
-    return key, key_str
+    return StateDict(sd=tensors, device=torch.device("cpu"), size=size, dtype=dtypes)

 def _collect_lora_specs(
     pose_strength: float,
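The rewritten loader leans on `LTXV_LORA_COMFY_RENAMING_MAP.apply_to_key`, whose contract is visible above: return the renamed base-model key, or `None` when no rule applies. A minimal stand-in for that contract, with a hypothetical prefix rule for illustration (the real map's rules are not shown in this diff):

```python
import re

class RenamingMap:
    """Hypothetical stand-in mirroring the apply_to_key contract used above."""

    def __init__(self, rules: list[tuple[re.Pattern, str]]):
        self.rules = rules

    def apply_to_key(self, key: str) -> str | None:
        # First matching rule wins; None tells the caller no renaming applies.
        for pattern, replacement in self.rules:
            if pattern.match(key):
                return pattern.sub(replacement, key, count=1)
        return None

# Illustrative rule only: strip a ComfyUI-style "diffusion_model." prefix.
demo = RenamingMap([(re.compile(r"^diffusion_model\."), "")])
assert demo.apply_to_key("diffusion_model.blocks.0.attn.to_q.weight") == "blocks.0.attn.to_q.weight"
assert demo.apply_to_key("unrelated.weight") is None  # caller falls back to the original key
```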
@@ -401,8 +382,8 @@ def _collect_lora_specs(
     reasoning_strength: float,
     twostep_strength: float,
 ):
-    # Keep all 14 adapters in the active list; zero strength means no effect.
-    return [
+    """Collect (path, strength) pairs for all LoRAs with non-zero strength."""
+    specs = [
         (pose_lora_path, round(float(pose_strength), 2)),
         (general_lora_path, round(float(general_strength), 2)),
         (motion_lora_path, round(float(motion_strength), 2)),
@@ -419,6 +400,8 @@ def _collect_lora_specs(
         (reasoning_lora_path, round(float(reasoning_strength), 2)),
         (twostep_lora_path, round(float(twostep_strength), 2)),
     ]
+    # Filter out zero-strength LoRAs
+    return [(path, strength) for path, strength in specs if strength != 0.0]


 def apply_current_loras_to_transformer(
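Net effect of the two `_collect_lora_specs` hunks: strengths are rounded to two decimals first, then zero-strength entries are dropped, so downstream code only sees adapters that actually contribute. A small sketch of that semantics with placeholder paths (not the app's real `*_lora_path` values):

```python
# Placeholder (path, strength) pairs; rounding happens before the zero filter.
specs = [
    ("pose.safetensors", round(0.8, 2)),
    ("general.safetensors", round(0.004, 2)),  # rounds to 0.0 -> dropped
    ("motion.safetensors", round(0.25, 2)),
]
active = [(path, strength) for path, strength in specs if strength != 0.0]
assert active == [("pose.safetensors", 0.8), ("motion.safetensors", 0.25)]
```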
@@ -438,74 +421,60 @@ def apply_current_loras_to_transformer(
     reasoning_strength: float,
     twostep_strength: float,
 ):
-    global ACTIVE_LORA_KEY
+    global _transformer

-    key, _ = _make_lora_key(
+    # Collect non-zero strength LoRAs
+    lora_specs = _collect_lora_specs(
         pose_strength, general_strength, motion_strength, dreamlay_strength,
         mself_strength, dramatic_strength, fluid_strength, liquid_strength,
         demopose_strength, voice_strength, realism_strength, transition_strength,
-        physics_strength, reasoning_strength, twostep_strength
+        physics_strength, reasoning_strength, twostep_strength,
     )

-    if key == ACTIVE_LORA_KEY:
-        return "LoRAs already active."
+    # No LoRAs to apply — skip
+    if not lora_specs:
+        return "No LoRAs (all zero strength)."

-    if key in LORA_STATE_CACHE:
-        fused_state = LORA_STATE_CACHE[key]
-    else:
-        loras = [
-            LoraStateDictWithStrength(
-                state_dict=_load_lora_state_dict(path),
-                strength=strength,
-            )
-            for path, strength in _collect_lora_specs(
-                pose_strength,
-                general_strength,
-                motion_strength,
-                dreamlay_strength,
-                mself_strength,
-                dramatic_strength,
-                fluid_strength,
-                liquid_strength,
-                demopose_strength,
-                voice_strength,
-                realism_strength,
-                transition_strength,
-                physics_strength,
-                reasoning_strength,
-                twostep_strength,
-            )
-            if strength != 0.0
-        ]
-
-        base_model_sd = _StateDictModel(
-            {k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()}
-        )
-
-        fused_model_sd = apply_loras(
-            base_model_sd,
-            loras,
-            dtype=pipeline.model_ledger.dtype,
-        )
-
-        fused_state = (
-            fused_model_sd.sd
-            if hasattr(fused_model_sd, "sd")
-            else fused_model_sd
+    # Build base model StateDict (proper type for apply_loras)
+    base_model_sd = StateDict(
+        sd={k: v.clone() for k, v in BASE_TRANSFORMER_STATE.items()},
+        device=torch.device("cpu"),
+        size=sum(v.numel() * v.element_size() for v in BASE_TRANSFORMER_STATE.values()),
+        dtype={v.dtype for v in BASE_TRANSFORMER_STATE.values()},
+    )
+
+    # Build LoraStateDictWithStrength objects
+    loras = [
+        LoraStateDictWithStrength(
+            state_dict=_load_lora_state_dict(path),
+            strength=strength,
         )
-
-        fused_state = fused_model_sd.sd if hasattr(fused_model_sd, "sd") else fused_model_sd
-        LORA_STATE_CACHE[key] = fused_state
+        for path, strength in lora_specs
+    ]

+    # Fuse LoRAs into base model
+    fused_model_sd = apply_loras(
+        base_model_sd,
+        loras,
+        dtype=pipeline.model_ledger.dtype,
+    )
+
+    # Extract plain dict from StateDict for load_state_dict
+    fused_state = fused_model_sd.sd
+
+    # Load fused state dict into transformer
     with torch.no_grad():
-        missing, unexpected = _transformer.load_state_dict(fused_state, strict=False)
+        fused_state_cuda = {
+            k: (v.to(_transformer.device) if v.device == torch.device("cpu") else v)
+            for k, v in fused_state.items()
+        }
+        missing, unexpected = _transformer.load_state_dict(fused_state_cuda, strict=False)
         if missing or unexpected:
-            print(
-                f"[LoRA] state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}"
-            )
+            print(f"[LoRA] state_dict load: missing={len(missing)}, unexpected={len(unexpected)}")
+            if missing:
+                print(f"    Missing keys (first 5): {missing[:5]}")

-    ACTIVE_LORA_KEY = key
-    return f"Applied LoRAs: {key[:12]}"
+    return f"Applied {len(lora_specs)} LoRA(s)."

 # ---- REPLACE PRELOAD BLOCK START ----
 # Preload all models for ZeroGPU tensor packing.
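This hunk deletes the `ACTIVE_LORA_KEY` / `LORA_STATE_CACHE` short-circuit along with `_make_lora_key`, so every call re-fuses from `BASE_TRANSFORMER_STATE`. If caching ever comes back (as the comment in `_load_lora_state_dict` suggests), the filtered `lora_specs` list is already a stable cache key; a sketch under that assumption (cache name and helper are hypothetical):

```python
import hashlib

# Hypothetical cache of fused transformer state dicts, keyed like the removed
# _make_lora_key: a stable "path:strength" string, hashed for brevity.
FUSED_STATE_CACHE: dict[str, dict] = {}

def fused_cache_key(lora_specs: list[tuple[str, float]]) -> str:
    key_str = "|".join(f"{path}:{strength}" for path, strength in lora_specs)
    return hashlib.sha256(key_str.encode("utf-8")).hexdigest()

# Inside apply_current_loras_to_transformer this would wrap the fuse step:
#     key = fused_cache_key(lora_specs)
#     if key not in FUSED_STATE_CACHE:
#         FUSED_STATE_CACHE[key] = ...  # fuse as in the hunk above
#     fused_state = FUSED_STATE_CACHE[key]
# Caveat: each entry holds a full transformer state dict, so bound the cache size.
```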
@@ -533,8 +502,6 @@ BASE_TRANSFORMER_STATE = {
 class _StateDictModel:
     def __init__(self, sd: dict[str, torch.Tensor]):
         self.sd = sd
-ACTIVE_LORA_KEY: str | None = None
-LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
 _video_encoder = _orig_video_encoder_factory()
 _video_decoder = _orig_video_decoder_factory()
 _audio_encoder = _orig_audio_encoder_factory()
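With the cache gone, the function is stateless per call, so a UI handler can pass the current slider values straight through on every request. An illustrative call using the function from the diff (strength values are made up; this assumes the app's module context, so it is not standalone-runnable):

```python
status = apply_current_loras_to_transformer(
    pose_strength=0.8, general_strength=0.0, motion_strength=0.3,
    dreamlay_strength=0.0, mself_strength=0.0, dramatic_strength=0.0,
    fluid_strength=0.0, liquid_strength=0.0, demopose_strength=0.0,
    voice_strength=0.0, realism_strength=0.5, transition_strength=0.0,
    physics_strength=0.0, reasoning_strength=0.0, twostep_strength=0.0,
)
print(status)  # "Applied 3 LoRA(s)." here; "No LoRAs (all zero strength)." if every slider is 0
```

The trade-off is that identical repeated requests now redo the full fuse that the removed cache used to skip.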