techfreakworm commited on
Commit
caca25f
·
unverified ·
1 Parent(s): 5926879

feat(tools): extract six mode templates from master workflow JSON

Browse files
tests/test_extract_modes.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the workflow-mode extractor."""
2
+ import json
3
+ import subprocess
4
+ import sys
5
+
6
+ from tests.conftest import REPO_ROOT
7
+
8
+
9
+ def test_extract_creates_six_mode_files(master_workflow, tmp_path):
10
+ """extract_modes.py emits six valid mode-specific JSON templates."""
11
+ out_dir = tmp_path / "workflows"
12
+ master_path = tmp_path / "master.json"
13
+ master_path.write_text(json.dumps(master_workflow))
14
+
15
+ result = subprocess.run(
16
+ [
17
+ sys.executable,
18
+ str(REPO_ROOT / "tools" / "extract_modes.py"),
19
+ "--master",
20
+ str(master_path),
21
+ "--out",
22
+ str(out_dir),
23
+ ],
24
+ check=False,
25
+ capture_output=True,
26
+ text=True,
27
+ )
28
+
29
+ assert result.returncode == 0, result.stderr
30
+ expected = {"t2v.json", "a2v.json", "i2v.json", "lipsync.json", "keyframe.json", "style.json"}
31
+ actual = {p.name for p in out_dir.iterdir()}
32
+ assert actual == expected
33
+
34
+ # Each file must be valid JSON with at least one node.
35
+ for path in out_dir.iterdir():
36
+ wf = json.loads(path.read_text())
37
+ assert "nodes" in wf
38
+ assert len(wf["nodes"]) > 0
tools/__init__.py ADDED
File without changes
tools/extract_modes.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Extract six mode-specific workflow templates from the master LTX 2.3 All-In-One workflow.
2
+
3
+ Each ComfyUI group whose title starts with a number (e.g. "01 Text to Video") becomes
4
+ a mode template containing only that group's nodes plus shared scaffolding (Models,
5
+ Lora, Setting, Prompt, Load Audio/Image/Video, Output groups).
6
+
7
+ Group title -> output filename mapping:
8
+ 01 -> t2v.json
9
+ 02 -> a2v.json
10
+ 03 -> i2v.json
11
+ 04 -> lipsync.json
12
+ 05 -> keyframe.json
13
+ 06 -> style.json
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import pathlib
20
+ import sys
21
+ from collections.abc import Iterable
22
+
23
+ GROUP_TO_FILENAME: dict[str, str] = {
24
+ "01": "t2v.json",
25
+ "02": "a2v.json",
26
+ "03": "i2v.json",
27
+ "04": "lipsync.json",
28
+ "05": "keyframe.json",
29
+ "06": "style.json",
30
+ }
31
+
32
+ SHARED_GROUP_PREFIXES: tuple[str, ...] = (
33
+ "Models",
34
+ "Lora",
35
+ "Setting",
36
+ "Prompt",
37
+ "Load Audio",
38
+ "Load Image",
39
+ "Load Images",
40
+ "Load Video",
41
+ "Output",
42
+ )
43
+
44
+
45
+ def _node_in_group(node: dict, group: dict) -> bool:
46
+ """Test whether a node's position lies inside a group's bounding box."""
47
+ if "pos" not in node or "bounding" not in group:
48
+ return False
49
+ nx, ny = node["pos"][0], node["pos"][1]
50
+ gx, gy, gw, gh = group["bounding"]
51
+ return (gx <= nx <= gx + gw) and (gy <= ny <= gy + gh)
52
+
53
+
54
+ def _select_groups(master: dict, mode_prefix: str) -> list[dict]:
55
+ """Pick the mode group plus all shared groups."""
56
+ selected: list[dict] = []
57
+ for g in master.get("groups", []):
58
+ title = (g.get("title") or "").strip()
59
+ if title.startswith(mode_prefix + " "):
60
+ selected.append(g)
61
+ elif any(title.startswith(p) for p in SHARED_GROUP_PREFIXES):
62
+ selected.append(g)
63
+ return selected
64
+
65
+
66
+ def _collect_nodes(master: dict, groups: Iterable[dict]) -> list[dict]:
67
+ """Return all nodes lying inside any of the given groups."""
68
+ groups_list = list(groups)
69
+ keep: list[dict] = []
70
+ for node in master.get("nodes", []):
71
+ if any(_node_in_group(node, g) for g in groups_list):
72
+ keep.append(node)
73
+ return keep
74
+
75
+
76
+ def _collect_links(master: dict, kept_node_ids: set[int]) -> list[list]:
77
+ """Keep only links where both endpoints are in the surviving node set."""
78
+ return [
79
+ link
80
+ for link in master.get("links", [])
81
+ # ComfyUI link tuple format: [link_id, src_node_id, src_out, dst_node_id, dst_in, type]
82
+ if link[1] in kept_node_ids and link[3] in kept_node_ids
83
+ ]
84
+
85
+
86
+ def extract_mode(master: dict, mode_prefix: str) -> dict:
87
+ """Build a focused workflow JSON for the given mode group prefix."""
88
+ groups = _select_groups(master, mode_prefix)
89
+ nodes = _collect_nodes(master, groups)
90
+ kept_ids = {n["id"] for n in nodes}
91
+ links = _collect_links(master, kept_ids)
92
+
93
+ return {
94
+ "id": f"ltx23-aio-{mode_prefix}",
95
+ "revision": 0,
96
+ "last_node_id": max(kept_ids, default=0),
97
+ "last_link_id": max((l[0] for l in links), default=0),
98
+ "nodes": nodes,
99
+ "links": links,
100
+ "groups": groups,
101
+ "definitions": master.get("definitions", {}),
102
+ "config": master.get("config", {}),
103
+ "extra": master.get("extra", {}),
104
+ "version": master.get("version", 0.4),
105
+ }
106
+
107
+
108
+ def main(argv: list[str] | None = None) -> int:
109
+ parser = argparse.ArgumentParser(description=__doc__)
110
+ parser.add_argument("--master", type=pathlib.Path, required=True)
111
+ parser.add_argument("--out", type=pathlib.Path, required=True)
112
+ args = parser.parse_args(argv)
113
+
114
+ master = json.loads(args.master.read_text())
115
+ args.out.mkdir(parents=True, exist_ok=True)
116
+
117
+ for prefix, filename in GROUP_TO_FILENAME.items():
118
+ wf = extract_mode(master, prefix)
119
+ out_path = args.out / filename
120
+ out_path.write_text(json.dumps(wf, indent=2))
121
+ print(f" -> wrote {out_path} ({len(wf['nodes'])} nodes, {len(wf['links'])} links)")
122
+
123
+ return 0
124
+
125
+
126
+ if __name__ == "__main__":
127
+ sys.exit(main())