Spaces:
Running on Zero
Running on Zero
feat(models): walk_workflow_for_models scans loader nodes
Browse files- models.py +50 -0
- tests/test_models.py +14 -0
models.py
CHANGED
|
@@ -94,3 +94,53 @@ MODEL_REGISTRY: dict[str, ModelEntry] = {
|
|
| 94 |
)
|
| 95 |
},
|
| 96 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
},
|
| 96 |
}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
LOADER_NODE_TYPES: tuple[str, ...] = (
|
| 100 |
+
"CheckpointLoaderSimple",
|
| 101 |
+
"UNETLoader",
|
| 102 |
+
"UnetLoaderGGUF",
|
| 103 |
+
"VAELoader",
|
| 104 |
+
"VAELoaderKJ",
|
| 105 |
+
"LoraLoader",
|
| 106 |
+
"Power Lora Loader (rgthree)",
|
| 107 |
+
"LTXVGemmaCLIPModelLoader",
|
| 108 |
+
"LatentUpscaleModelLoader",
|
| 109 |
+
"DualCLIPLoader",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def walk_workflow_for_models(workflow: dict) -> set[str]:
|
| 114 |
+
"""Return the set of model filenames referenced by loader nodes in the workflow.
|
| 115 |
+
|
| 116 |
+
Pulls filenames from nodes whose `type` matches a known loader. Filenames are
|
| 117 |
+
typically in `widgets_values[0]` (CheckpointLoaderSimple) or in nested rows
|
| 118 |
+
(Power Lora Loader). Falls back to scanning all string-valued widget entries
|
| 119 |
+
for `*.safetensors` / `*.gguf`.
|
| 120 |
+
"""
|
| 121 |
+
needed: set[str] = set()
|
| 122 |
+
for node in workflow.get("nodes", []):
|
| 123 |
+
if node.get("type") not in LOADER_NODE_TYPES:
|
| 124 |
+
continue
|
| 125 |
+
widgets = node.get("widgets_values") or []
|
| 126 |
+
for value in _flatten_widget_values(widgets):
|
| 127 |
+
if isinstance(value, str) and (
|
| 128 |
+
value.endswith(".safetensors") or value.endswith(".gguf")
|
| 129 |
+
or value == "tokenizer.model" or value.endswith(".json")
|
| 130 |
+
):
|
| 131 |
+
needed.add(value)
|
| 132 |
+
return needed
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _flatten_widget_values(values):
|
| 136 |
+
"""Walk nested list/dict widget structures, yielding leaf values."""
|
| 137 |
+
if isinstance(values, dict):
|
| 138 |
+
yield from _flatten_widget_values(list(values.values()))
|
| 139 |
+
return
|
| 140 |
+
for v in values:
|
| 141 |
+
if isinstance(v, (list, tuple)):
|
| 142 |
+
yield from _flatten_widget_values(v)
|
| 143 |
+
elif isinstance(v, dict):
|
| 144 |
+
yield from _flatten_widget_values(list(v.values()))
|
| 145 |
+
else:
|
| 146 |
+
yield v
|
tests/test_models.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
|
| 2 |
import models
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def test_model_registry_resolves_known_files():
|
|
@@ -12,3 +13,16 @@ def test_model_registry_includes_gemma_shards():
|
|
| 12 |
key = f"model-{i:05d}-of-00005.safetensors"
|
| 13 |
assert key in models.MODEL_REGISTRY
|
| 14 |
assert "gemma-3-12b-it" in models.MODEL_REGISTRY[key].repo_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
|
| 2 |
import models
|
| 3 |
+
import workflow
|
| 4 |
|
| 5 |
|
| 6 |
def test_model_registry_resolves_known_files():
|
|
|
|
| 13 |
key = f"model-{i:05d}-of-00005.safetensors"
|
| 14 |
assert key in models.MODEL_REGISTRY
|
| 15 |
assert "gemma-3-12b-it" in models.MODEL_REGISTRY[key].repo_id
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_walk_workflow_for_models_finds_t2v_loaders():
|
| 19 |
+
wf = workflow.load_template("t2v")
|
| 20 |
+
needed = models.walk_workflow_for_models(wf)
|
| 21 |
+
# T2V needs at minimum a transformer (distilled, dev fp8, or GGUF Q4) and a gemma encoder
|
| 22 |
+
assert any(
|
| 23 |
+
name.endswith(".gguf")
|
| 24 |
+
or "distilled.safetensors" in name
|
| 25 |
+
or "transformer_only" in name
|
| 26 |
+
for name in needed
|
| 27 |
+
)
|
| 28 |
+
assert any("gemma" in name.lower() for name in needed)
|