techfreakworm commited on
Commit
81f5ac0
·
unverified ·
1 Parent(s): d65b6b1

feat(workflow): set_input + validate over node graph

Browse files
Files changed (2) hide show
  1. tests/test_workflow.py +32 -0
  2. workflow.py +37 -0
tests/test_workflow.py CHANGED
@@ -22,3 +22,35 @@ def test_load_template_returns_independent_copy():
22
  a["nodes"].append({"id": -999})
23
  b = workflow.load_template("t2v")
24
  assert {-999} & {n.get("id") for n in b["nodes"]} == set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  a["nodes"].append({"id": -999})
23
  b = workflow.load_template("t2v")
24
  assert {-999} & {n.get("id") for n in b["nodes"]} == set()
25
+
26
+
27
+ def test_set_input_patches_widgets_values_in_place():
28
+ wf = workflow.load_template("t2v")
29
+ target_node = next(n for n in wf["nodes"] if n["type"] == "CLIPTextEncode")
30
+ workflow.set_input(wf, target_node["id"], 0, "new prompt text")
31
+ refetched = next(n for n in wf["nodes"] if n["id"] == target_node["id"])
32
+ assert refetched["widgets_values"][0] == "new prompt text"
33
+
34
+
35
+ def test_set_input_raises_for_unknown_node():
36
+ wf = workflow.load_template("t2v")
37
+ with pytest.raises(KeyError, match="node id"):
38
+ workflow.set_input(wf, 999_999_999, 0, "x")
39
+
40
+
41
+ def test_validate_accepts_canonical_template():
42
+ wf = workflow.load_template("t2v")
43
+ workflow.validate(wf) # must not raise
44
+
45
+
46
+ def test_validate_rejects_workflow_with_no_nodes():
47
+ wf = {"nodes": [], "links": []}
48
+ with pytest.raises(ValueError, match="no nodes"):
49
+ workflow.validate(wf)
50
+
51
+
52
+ def test_validate_rejects_orphan_link():
53
+ wf = workflow.load_template("t2v")
54
+ wf["links"].append([99999, 1, 0, 999_999_999, 0, "INT"]) # destination doesn't exist
55
+ with pytest.raises(ValueError, match="orphan link"):
56
+ workflow.validate(wf)
workflow.py CHANGED
@@ -17,3 +17,40 @@ def load_template(mode: str) -> dict[str, Any]:
17
  raise ValueError(f"unknown mode {mode!r}; expected one of {VALID_MODES}")
18
  path = WORKFLOWS_DIR / f"{mode}.json"
19
  return copy.deepcopy(json.loads(path.read_text()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  raise ValueError(f"unknown mode {mode!r}; expected one of {VALID_MODES}")
18
  path = WORKFLOWS_DIR / f"{mode}.json"
19
  return copy.deepcopy(json.loads(path.read_text()))
20
+
21
+
22
+ def set_input(workflow: dict[str, Any], node_id: int, widget_index: int, value: Any) -> None:
23
+ """Patch a node's widgets_values in place.
24
+
25
+ Args:
26
+ workflow: A workflow dict (must have a "nodes" list).
27
+ node_id: The id of the node to patch.
28
+ widget_index: Position within the node's widgets_values list.
29
+ value: New value.
30
+
31
+ Raises:
32
+ KeyError: If no node with the given id exists.
33
+ """
34
+ for node in workflow["nodes"]:
35
+ if node.get("id") == node_id:
36
+ widgets = node.setdefault("widgets_values", [])
37
+ while len(widgets) <= widget_index:
38
+ widgets.append(None)
39
+ widgets[widget_index] = value
40
+ return
41
+ raise KeyError(f"node id {node_id} not found in workflow")
42
+
43
+
44
+ def validate(workflow: dict[str, Any]) -> None:
45
+ """Static schema validation. Raises ValueError on the first problem found."""
46
+ nodes = workflow.get("nodes")
47
+ if not isinstance(nodes, list) or len(nodes) == 0:
48
+ raise ValueError("workflow has no nodes")
49
+
50
+ node_ids = {n.get("id") for n in nodes if "id" in n}
51
+ for link in workflow.get("links", []):
52
+ if not isinstance(link, list) or len(link) < 6:
53
+ raise ValueError(f"malformed link {link}")
54
+ _, src, _, dst, _, _ = link
55
+ if src not in node_ids or dst not in node_ids:
56
+ raise ValueError(f"orphan link {link}")