Spaces:
Running on Zero
Running on Zero
File size: 4,832 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | from dataclasses import dataclass, replace
from typing import NamedTuple, Protocol
import torch
@dataclass(frozen=True, slots=True)
class ContentReplacement:
"""
Represents a content replacement operation.
Used to replace a specific content with a replacement in a state dict key.
"""
content: str
replacement: str
@dataclass(frozen=True, slots=True)
class ContentMatching:
"""
Represents a content matching operation.
Used to match a specific prefix and suffix in a state dict key.
"""
prefix: str = ""
suffix: str = ""
class KeyValueOperationResult(NamedTuple):
    """One ``(key, value)`` entry produced by a key-value operation.

    Because it is a ``NamedTuple``, it can be unpacked like a plain
    ``(new_key, new_value)`` tuple.
    """

    new_key: str  # state dict key after the operation
    new_value: torch.Tensor  # tensor to store under ``new_key``
class KeyValueOperation(Protocol):
    """Structural type for callables that transform a single state dict entry.

    Any callable taking ``(tensor_key, tensor_value)`` and returning a list of
    ``KeyValueOperationResult`` satisfies this protocol. The list return type
    allows one input entry to yield several output entries.
    """

    def __call__(
        self,
        tensor_key: str,
        tensor_value: torch.Tensor,
    ) -> list[KeyValueOperationResult]:
        """Return the entries that take the place of ``(tensor_key, tensor_value)``."""
        ...
@dataclass(slots=True, frozen=True)
class SDKeyValueOperation:
    """Pairs a key filter with the key/value transform to run on keys it accepts.

    Attributes:
        key_matcher: Prefix/suffix filter selecting which keys the operation targets.
        kv_operation: Callable invoked with a matching key and its value.
    """

    key_matcher: ContentMatching  # decides whether the operation applies to a key
    kv_operation: KeyValueOperation  # produces the replacement (key, value) results
@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable, fluently built description of state dict key operations.

    Each ``with_*`` method returns a new ``SDOps`` with one more entry appended
    to ``mapping`` (or with ``allowed_keys`` widened). The receiving instance is
    never modified.

    Attributes:
        name: Identifier for this set of operations.
        mapping: Ordered operations. ``ContentMatching`` entries decide which keys
            ``apply_to_key`` accepts. ``ContentReplacement`` entries rewrite accepted
            keys in insertion order. ``SDKeyValueOperation`` entries transform
            key/value pairs in ``apply_to_key_value``.
        allowed_keys: Optional whitelist of post-replacement keys. ``None`` means
            no whitelist is applied.
    """

    name: str
    # Ordered tuple of operation objects (not (key, value) pairs).
    mapping: tuple[
        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
    ] = ()
    allowed_keys: frozenset[str] | None = None

    @staticmethod
    def _matches(matcher: ContentMatching, key: str) -> bool:
        """Return True if *key* starts with ``matcher.prefix`` and ends with ``matcher.suffix``."""
        return key.startswith(matcher.prefix) and key.endswith(matcher.suffix)

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Return a copy with a ``ContentReplacement(content, replacement)`` appended to ``mapping``."""
        new_mapping = (*self.mapping, ContentReplacement(content, replacement))
        return replace(self, mapping=new_mapping)

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Return a copy with a ``ContentMatching(prefix, suffix)`` appended to ``mapping``.

        With the default empty prefix and suffix, the added matcher accepts every key.
        """
        new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
        return replace(self, mapping=new_mapping)

    def with_additional_allowed_keys(self, keys: frozenset[str]) -> "SDOps":
        """Return a copy whose ``allowed_keys`` whitelist includes *keys* (post-replacement).

        If ``allowed_keys`` is already set, the two sets are merged via union.
        Otherwise the whitelist becomes exactly *keys*. Passing an empty set to
        an instance without a whitelist therefore makes ``apply_to_key`` reject
        every key.
        """
        new_keys = frozenset(keys)
        merged = new_keys if self.allowed_keys is None else new_keys | self.allowed_keys
        return replace(self, allowed_keys=merged)

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Return a copy with a key/value operation appended to ``mapping``.

        Args:
            operation: Callable applied by ``apply_to_key_value`` to matching entries.
            key_prefix: Required key prefix ("" matches any).
            key_suffix: Required key suffix ("" matches any).
        """
        key_matcher = ContentMatching(key_prefix, key_suffix)
        sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
        new_mapping = (*self.mapping, sd_kv_operation)
        return replace(self, mapping=new_mapping)

    def apply_to_key(self, key: str) -> str | None:
        """Filter and rename a single state dict key.

        Args:
            key: Original state dict key.

        Returns:
            The renamed key, or ``None`` if the key is rejected. A key is rejected
            when it matches no ``ContentMatching`` entry (an instance without any
            matcher rejects every key), or when its renamed form is not in
            ``allowed_keys``.
        """
        accepted = any(
            self._matches(op, key) for op in self.mapping if isinstance(op, ContentMatching)
        )
        if not accepted:
            return None
        # Replacements run in insertion order, each on the previous result.
        for op in self.mapping:
            if isinstance(op, ContentReplacement):
                # str.replace is already a no-op when op.content is absent.
                key = key.replace(op.content, op.replacement)
        if self.allowed_keys is not None and key not in self.allowed_keys:
            return None
        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Apply the first key/value operation whose matcher accepts *key*.

        Args:
            key: State dict key used for matching and passed to the operation.
            value: Tensor associated with *key*.

        Returns:
            The operation's results. If no operation matches, a single-element
            list containing the unchanged ``(key, value)`` pair.
        """
        for op in self.mapping:
            if isinstance(op, SDKeyValueOperation) and self._matches(op.key_matcher, key):
                # Only the first matching operation is applied; later ones are ignored.
                return op.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]
# Predefined SDOps instances.

# Accepts every key (the default ContentMatching has empty prefix and suffix) and
# removes every occurrence of "diffusion_model." from it.
# NOTE(review): the instance name "LTXV_LORA_COMFY_PREFIX_MAP" differs from the
# variable name. It is kept unchanged in case something depends on it; confirm.
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
)

# Same renaming, then folds ".lora_A.weight" and ".lora_B.weight" onto ".weight"
# (presumably to address the base weight a LoRA pair targets; confirm).
# Replacement order is significant, so keep it as is.
LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
|