"""LoRA file validation and apply/revert context manager.""" from __future__ import annotations import json import struct from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Any ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.", "diffusion_model.") class LoRAValidationError(ValueError): """Raised when a LoRA safetensors file doesn't match Z-Image's key layout.""" @dataclass(frozen=True) class LoRAInfo: path: Path rank: int target: str # which submodule it applies to ("transformer" for Z-Image) size_bytes: int def sniff(path: Path | str) -> LoRAInfo: """Read just the safetensors header to verify and infer rank + target. Doesn't load tensors. Doesn't allocate GPU memory. Cheap enough to call before @spaces.GPU fires. """ path = Path(path) raw = path.read_bytes() if len(raw) < 8: raise LoRAValidationError(f"{path.name}: file too short to be safetensors") (header_len,) = struct.unpack(" len(raw): raise LoRAValidationError(f"{path.name}: not a valid safetensors header") try: header = json.loads(raw[8 : 8 + header_len]) except json.JSONDecodeError as e: raise LoRAValidationError(f"{path.name}: safetensors header is not JSON ({e})") from e tensor_keys = [k for k in header.keys() if not k.startswith("__")] if not tensor_keys: raise LoRAValidationError(f"{path.name}: no tensors in file") bad = [k for k in tensor_keys if not k.startswith(ZIMAGE_LORA_PREFIXES)] if bad: sample = bad[0] raise LoRAValidationError( f"{path.name}: unexpected key '{sample}' — Z-Image LoRAs must target " f"{ZIMAGE_LORA_PREFIXES} (got {len(bad)}/{len(tensor_keys)} mismatched keys)" ) meta = header.get("__metadata__") or {} rank = int(meta.get("rank", 0)) if not rank: # Infer from any A/B tensor pair shape for k, v in header.items(): if "lora_A" in k or "lora_down" in k: shape = v.get("shape") or [] if shape: rank = int(min(shape)) break return LoRAInfo( path=path, rank=rank, target="transformer", size_bytes=path.stat().st_size, ) @contextmanager def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterator[None]: """Apply a LoRA to the pipeline's dit for the duration of the context. Reverts on exit (including exception path) so the cached GPU model is left clean. If ``path`` is ``None``, this is a no-op. Validates the LoRA file with :func:`sniff` before touching the pipeline so a bad file is rejected before any GPU work begins. """ if path is None: yield return sniff(path) # raises LoRAValidationError on bad input _apply_lora_impl(pipe, path, strength) try: yield finally: _revert_lora_impl(pipe) def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None: """Apply a LoRA to ``pipe.dit`` using DiffSynth's ``load_lora`` (hotload mode). ``GeneralLoRALoader.convert_state_dict`` normalises CivitAI-style ``diffusion_model.*`` keys into the bare module-path keys DiffSynth's AutoWrappedLinear modules consume, so we don't need to remap ourselves. """ pipe.load_lora(module=pipe.dit, lora_config=str(path), alpha=float(strength), verbose=0) def _revert_lora_impl(pipe: Any) -> None: """Clear the hotloaded LoRA so the cached transformer is left clean.""" pipe.clear_lora(verbose=0)