techfreakworm commited on
Commit
55a3bb4
·
unverified ·
1 Parent(s): c80a8b9

feat(models): walk_workflow_for_models scans loader nodes

Browse files
Files changed (2) hide show
  1. models.py +50 -0
  2. tests/test_models.py +14 -0
models.py CHANGED
@@ -94,3 +94,53 @@ MODEL_REGISTRY: dict[str, ModelEntry] = {
94
  )
95
  },
96
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  )
95
  },
96
  }
97
+
98
+
99
+ LOADER_NODE_TYPES: tuple[str, ...] = (
100
+ "CheckpointLoaderSimple",
101
+ "UNETLoader",
102
+ "UnetLoaderGGUF",
103
+ "VAELoader",
104
+ "VAELoaderKJ",
105
+ "LoraLoader",
106
+ "Power Lora Loader (rgthree)",
107
+ "LTXVGemmaCLIPModelLoader",
108
+ "LatentUpscaleModelLoader",
109
+ "DualCLIPLoader",
110
+ )
111
+
112
+
113
+ def walk_workflow_for_models(workflow: dict) -> set[str]:
114
+ """Return the set of model filenames referenced by loader nodes in the workflow.
115
+
116
+ Pulls filenames from nodes whose `type` matches a known loader. Filenames are
117
+ typically in `widgets_values[0]` (CheckpointLoaderSimple) or in nested rows
118
+ (Power Lora Loader). Falls back to scanning all string-valued widget entries
119
+ for `*.safetensors` / `*.gguf`.
120
+ """
121
+ needed: set[str] = set()
122
+ for node in workflow.get("nodes", []):
123
+ if node.get("type") not in LOADER_NODE_TYPES:
124
+ continue
125
+ widgets = node.get("widgets_values") or []
126
+ for value in _flatten_widget_values(widgets):
127
+ if isinstance(value, str) and (
128
+ value.endswith(".safetensors") or value.endswith(".gguf")
129
+ or value == "tokenizer.model" or value.endswith(".json")
130
+ ):
131
+ needed.add(value)
132
+ return needed
133
+
134
+
135
+ def _flatten_widget_values(values):
136
+ """Walk nested list/dict widget structures, yielding leaf values."""
137
+ if isinstance(values, dict):
138
+ yield from _flatten_widget_values(list(values.values()))
139
+ return
140
+ for v in values:
141
+ if isinstance(v, (list, tuple)):
142
+ yield from _flatten_widget_values(v)
143
+ elif isinstance(v, dict):
144
+ yield from _flatten_widget_values(list(v.values()))
145
+ else:
146
+ yield v
tests/test_models.py CHANGED
@@ -1,5 +1,6 @@
1
  """Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
2
  import models
 
3
 
4
 
5
  def test_model_registry_resolves_known_files():
@@ -12,3 +13,16 @@ def test_model_registry_includes_gemma_shards():
12
  key = f"model-{i:05d}-of-00005.safetensors"
13
  assert key in models.MODEL_REGISTRY
14
  assert "gemma-3-12b-it" in models.MODEL_REGISTRY[key].repo_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
2
  import models
3
+ import workflow
4
 
5
 
6
  def test_model_registry_resolves_known_files():
 
13
  key = f"model-{i:05d}-of-00005.safetensors"
14
  assert key in models.MODEL_REGISTRY
15
  assert "gemma-3-12b-it" in models.MODEL_REGISTRY[key].repo_id
16
+
17
+
18
+ def test_walk_workflow_for_models_finds_t2v_loaders():
19
+ wf = workflow.load_template("t2v")
20
+ needed = models.walk_workflow_for_models(wf)
21
+ # T2V needs at minimum a transformer (distilled, dev fp8, or GGUF Q4) and a gemma encoder
22
+ assert any(
23
+ name.endswith(".gguf")
24
+ or "distilled.safetensors" in name
25
+ or "transformer_only" in name
26
+ for name in needed
27
+ )
28
+ assert any("gemma" in name.lower() for name in needed)