File size: 4,019 Bytes
caca25f
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea399a
caca25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea399a
caca25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Extract six mode-specific workflow templates from the master LTX 2.3 All-In-One workflow.

Each ComfyUI group whose title starts with a number (e.g. "01 Text to Video") becomes
a mode template containing only that group's nodes plus shared scaffolding (Models,
Lora, Setting, Prompt, Load Audio/Image/Video, Output groups).

Group title -> output filename mapping:
    01 -> t2v.json
    02 -> a2v.json
    03 -> i2v.json
    04 -> lipsync.json
    05 -> keyframe.json
    06 -> style.json
"""

from __future__ import annotations

import argparse
import json
import pathlib
import sys
from collections.abc import Iterable

# Two-digit group-title prefix ("01".."06") -> template filename, in output order.
GROUP_TO_FILENAME: dict[str, str] = {
    f"{number:02d}": f"{stem}.json"
    for number, stem in enumerate(
        ("t2v", "a2v", "i2v", "lipsync", "keyframe", "style"),
        start=1,
    )
}

# Title prefixes (matched with str.startswith) of the scaffolding groups that are
# copied into every mode template alongside that mode's own group.
SHARED_GROUP_PREFIXES: tuple[str, ...] = (
    "Models", "Lora", "Setting", "Prompt",
    "Load Audio", "Load Image", "Load Images", "Load Video",
    "Output",
)


def _node_in_group(node: dict, group: dict) -> bool:
    """Test whether a node's position lies inside a group's bounding box."""
    if "pos" not in node or "bounding" not in group:
        return False
    nx, ny = node["pos"][0], node["pos"][1]
    gx, gy, gw, gh = group["bounding"]
    return (gx <= nx <= gx + gw) and (gy <= ny <= gy + gh)


def _select_groups(master: dict, mode_prefix: str) -> list[dict]:
    """Pick the mode group plus all shared groups."""
    selected: list[dict] = []
    for g in master.get("groups", []):
        title = (g.get("title") or "").strip()
        if title.startswith(mode_prefix + " "):
            selected.append(g)
        elif any(title.startswith(p) for p in SHARED_GROUP_PREFIXES):
            selected.append(g)
    return selected


def _collect_nodes(master: dict, groups: Iterable[dict]) -> list[dict]:
    """Return all nodes lying inside any of the given groups."""
    groups_list = list(groups)
    keep: list[dict] = []
    for node in master.get("nodes", []):
        if any(_node_in_group(node, g) for g in groups_list):
            keep.append(node)
    return keep


def _collect_links(master: dict, kept_node_ids: set[int]) -> list[list]:
    """Keep only links where both endpoints are in the surviving node set."""
    return [
        link
        for link in master.get("links", [])
        # ComfyUI link tuple format: [link_id, src_node_id, src_out, dst_node_id, dst_in, type]
        if link[1] in kept_node_ids and link[3] in kept_node_ids
    ]


def _prune_dangling_link_refs(nodes: list[dict], kept_link_ids: set[int]) -> None:
    """Drop slot references to links outside *kept_link_ids* (mutates *nodes* in place)."""
    for node in nodes:
        for slot in node.get("inputs") or ():
            if isinstance(slot, dict) and slot.get("link") is not None and slot["link"] not in kept_link_ids:
                # A disconnected input is serialised with "link": null.
                slot["link"] = None
        for slot in node.get("outputs") or ():
            if isinstance(slot, dict) and slot.get("links"):
                slot["links"] = [link_id for link_id in slot["links"] if link_id in kept_link_ids]


def extract_mode(master: dict, mode_prefix: str) -> dict:
    """Build a focused workflow JSON for the given mode group prefix.

    Keeps the mode's own group plus the shared scaffolding groups, every node
    positioned inside one of them, and only the links whose both endpoints
    survive. Node input/output slots that referenced a dropped link are
    disconnected, so the template carries no dangling link IDs.

    Args:
        master: Parsed master workflow; it is not modified.
        mode_prefix: Group-title prefix such as ``"01"``.

    Returns:
        A new workflow dict. Nodes are deep copies. ``groups``, ``links``,
        ``definitions``, ``config`` and ``extra`` are shared with *master*.
    """
    import copy

    groups = _select_groups(master, mode_prefix)
    # Deep-copy: the master is reused for every mode, and pruning below edits slots.
    nodes = copy.deepcopy(_collect_nodes(master, groups))
    kept_ids = {n["id"] for n in nodes}
    links = _collect_links(master, kept_ids)
    _prune_dangling_link_refs(nodes, {link[0] for link in links})

    return {
        # NOTE(review): ComfyUI workflow IDs are usually UUIDs; confirm this
        # non-UUID id is accepted by the target frontend version.
        "id": f"ltx23-aio-{mode_prefix}",
        "revision": 0,
        "last_node_id": max(kept_ids, default=0),
        "last_link_id": max((link[0] for link in links), default=0),
        "nodes": nodes,
        "links": links,
        "groups": groups,
        "definitions": master.get("definitions", {}),
        "config": master.get("config", {}),
        "extra": master.get("extra", {}),
        "version": master.get("version", 0.4),
    }


def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments and write one template per entry in GROUP_TO_FILENAME.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 once every template has been written.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--master",
        type=pathlib.Path,
        required=True,
        help="master all-in-one workflow JSON to split",
    )
    parser.add_argument(
        "--out",
        type=pathlib.Path,
        required=True,
        help="directory that receives the per-mode templates (created if missing)",
    )
    args = parser.parse_args(argv)

    # Parse from bytes so json detects UTF-8/16/32 (and a UTF-8 BOM) itself instead
    # of decoding with the platform locale encoding (e.g. cp1252 on Windows).
    master = json.loads(args.master.read_bytes())
    args.out.mkdir(parents=True, exist_ok=True)

    for prefix, filename in GROUP_TO_FILENAME.items():
        wf = extract_mode(master, prefix)
        mode_marker = prefix + " "
        if not any((g.get("title") or "").strip().startswith(mode_marker) for g in wf["groups"]):
            # The template is still written (unchanged behaviour), but a missing
            # mode group usually means a renamed/mistyped group title in the master.
            print(
                f"  !! no group titled '{mode_marker}...' in {args.master}; "
                f"{filename} contains shared groups only",
                file=sys.stderr,
            )
        out_path = args.out / filename
        out_path.write_text(json.dumps(wf, indent=2), encoding="utf-8")
        print(f"  -> wrote {out_path} ({len(wf['nodes'])} nodes, {len(wf['links'])} links)")

    return 0


if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit status.
    raise SystemExit(main())