Spaces:
Running on Zero
Running on Zero
feat(workflow): set_input + validate over node graph
Browse files- tests/test_workflow.py +32 -0
- workflow.py +37 -0
tests/test_workflow.py
CHANGED
|
@@ -22,3 +22,35 @@ def test_load_template_returns_independent_copy():
|
|
| 22 |
a["nodes"].append({"id": -999})
|
| 23 |
b = workflow.load_template("t2v")
|
| 24 |
assert {-999} & {n.get("id") for n in b["nodes"]} == set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
a["nodes"].append({"id": -999})
|
| 23 |
b = workflow.load_template("t2v")
|
| 24 |
assert {-999} & {n.get("id") for n in b["nodes"]} == set()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_set_input_patches_widgets_values_in_place():
|
| 28 |
+
wf = workflow.load_template("t2v")
|
| 29 |
+
target_node = next(n for n in wf["nodes"] if n["type"] == "CLIPTextEncode")
|
| 30 |
+
workflow.set_input(wf, target_node["id"], 0, "new prompt text")
|
| 31 |
+
refetched = next(n for n in wf["nodes"] if n["id"] == target_node["id"])
|
| 32 |
+
assert refetched["widgets_values"][0] == "new prompt text"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_set_input_raises_for_unknown_node():
|
| 36 |
+
wf = workflow.load_template("t2v")
|
| 37 |
+
with pytest.raises(KeyError, match="node id"):
|
| 38 |
+
workflow.set_input(wf, 999_999_999, 0, "x")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_validate_accepts_canonical_template():
|
| 42 |
+
wf = workflow.load_template("t2v")
|
| 43 |
+
workflow.validate(wf) # must not raise
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_validate_rejects_workflow_with_no_nodes():
|
| 47 |
+
wf = {"nodes": [], "links": []}
|
| 48 |
+
with pytest.raises(ValueError, match="no nodes"):
|
| 49 |
+
workflow.validate(wf)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_validate_rejects_orphan_link():
|
| 53 |
+
wf = workflow.load_template("t2v")
|
| 54 |
+
wf["links"].append([99999, 1, 0, 999_999_999, 0, "INT"]) # destination doesn't exist
|
| 55 |
+
with pytest.raises(ValueError, match="orphan link"):
|
| 56 |
+
workflow.validate(wf)
|
workflow.py
CHANGED
|
@@ -17,3 +17,40 @@ def load_template(mode: str) -> dict[str, Any]:
|
|
| 17 |
raise ValueError(f"unknown mode {mode!r}; expected one of {VALID_MODES}")
|
| 18 |
path = WORKFLOWS_DIR / f"{mode}.json"
|
| 19 |
return copy.deepcopy(json.loads(path.read_text()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
raise ValueError(f"unknown mode {mode!r}; expected one of {VALID_MODES}")
|
| 18 |
path = WORKFLOWS_DIR / f"{mode}.json"
|
| 19 |
return copy.deepcopy(json.loads(path.read_text()))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def set_input(workflow: dict[str, Any], node_id: int, widget_index: int, value: Any) -> None:
|
| 23 |
+
"""Patch a node's widgets_values in place.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
workflow: A workflow dict (must have a "nodes" list).
|
| 27 |
+
node_id: The id of the node to patch.
|
| 28 |
+
widget_index: Position within the node's widgets_values list.
|
| 29 |
+
value: New value.
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
KeyError: If no node with the given id exists.
|
| 33 |
+
"""
|
| 34 |
+
for node in workflow["nodes"]:
|
| 35 |
+
if node.get("id") == node_id:
|
| 36 |
+
widgets = node.setdefault("widgets_values", [])
|
| 37 |
+
while len(widgets) <= widget_index:
|
| 38 |
+
widgets.append(None)
|
| 39 |
+
widgets[widget_index] = value
|
| 40 |
+
return
|
| 41 |
+
raise KeyError(f"node id {node_id} not found in workflow")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def validate(workflow: dict[str, Any]) -> None:
|
| 45 |
+
"""Static schema validation. Raises ValueError on the first problem found."""
|
| 46 |
+
nodes = workflow.get("nodes")
|
| 47 |
+
if not isinstance(nodes, list) or len(nodes) == 0:
|
| 48 |
+
raise ValueError("workflow has no nodes")
|
| 49 |
+
|
| 50 |
+
node_ids = {n.get("id") for n in nodes if "id" in n}
|
| 51 |
+
for link in workflow.get("links", []):
|
| 52 |
+
if not isinstance(link, list) or len(link) < 6:
|
| 53 |
+
raise ValueError(f"malformed link {link}")
|
| 54 |
+
_, src, _, dst, _, _ = link
|
| 55 |
+
if src not in node_ids or dst not in node_ids:
|
| 56 |
+
raise ValueError(f"orphan link {link}")
|