techfreakworm committed on
Commit
0256245
·
unverified ·
1 Parent(s): 0314079

feat(modes): T2V + I2V parameterize_fn with stage maps

Browse files
Files changed (2) hide show
  1. modes.py +105 -0
  2. tests/test_modes.py +32 -0
modes.py CHANGED
@@ -41,3 +41,108 @@ class Mode:
41
 
42
  # Filled in by tasks 11–12.
43
  MODE_REGISTRY: dict[str, Mode] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Filled in by tasks 11–12.
43
  MODE_REGISTRY: dict[str, Mode] = {}
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Node-id constants — captured from workflows/{t2v,i2v}.json on 2026-04-30.
48
+ #
49
+ # The master workflow uses rgthree's GetNode/SetNode for indirection. SetNodes
50
+ # named "pos"/"neg" expose the *outputs* of CLIPTextEncode, not the prompt
51
+ # strings. So the canonical place to set the prompt text is the CLIPTextEncode
52
+ # node itself.
53
+ #
54
+ # Width/Height/FPS are INTConstant nodes whose values feed downstream Set_*
55
+ # variables. Clip length comes from a mxSlider (in seconds, then multiplied by
56
+ # FPS via a MathExpression to compute frames). No SetNode for "noise"/seed
57
+ # survived the extraction, so seed is intentionally NOT patched here — the
58
+ # template's hard-coded value is used until we wire RandomNoise injection in
59
+ # Task 12+.
60
+ #
61
+ # LoRA rows live inside a single Power Lora Loader (rgthree) node whose
62
+ # widgets_values is a list of dicts. Patching a specific row requires knowing
63
+ # the index, and the canonical mapping (camera_lora value -> row index) belongs
64
+ # in models.py once camera-LoRA selection lands. Deferred for now.
65
+ # ---------------------------------------------------------------------------
66
+
67
+ T2V_NODE_PROMPT = 5536 # CLIPTextEncode positive — wv[0] = prompt
68
+ T2V_NODE_NEG_PROMPT = 5537 # CLIPTextEncode negative — wv[0] = negative prompt
69
+ T2V_NODE_WIDTH = 5383 # INTConstant "Width" — wv[0]
70
+ T2V_NODE_HEIGHT = 5382 # INTConstant "Height" — wv[0]
71
+ T2V_NODE_FPS = 5445 # INTConstant "FPS" — wv[0]
72
+ T2V_NODE_CLIP_LENGTH = 196 # mxSlider "Clip Length ( in seconds )" — wv[0]
73
+
74
+ I2V_NODE_PROMPT = 5536
75
+ I2V_NODE_NEG_PROMPT = 5537
76
+ I2V_NODE_WIDTH = 5383
77
+ I2V_NODE_HEIGHT = 5382
78
+ I2V_NODE_FPS = 5445
79
+ I2V_NODE_CLIP_LENGTH = 196
80
+ I2V_NODE_IMAGE = 149 # LoadImage "Load Image1" — wv[0] = filename
81
+
82
+
83
+ def _frames_to_seconds(frames: int, fps: int) -> int:
84
+ """Convert (frames, fps) to integer seconds for the mxSlider clip-length widget.
85
+
86
+ The downstream MathExpression is `a*b+1` (a=seconds, b=fps -> total frames),
87
+ so for a target frame count F at fps R we need seconds = ceil((F - 1) / R).
88
+ Round up so the slider is never short of the requested frames.
89
+ """
90
+ if fps <= 0:
91
+ return 1
92
+ return max(1, -(-(frames - 1) // fps))
93
+
94
+
95
+ def _t2v_parameterize(inp: dict[str, Any]) -> list[Patch]:
96
+ return [
97
+ (T2V_NODE_PROMPT, 0, inp["prompt"]),
98
+ (T2V_NODE_NEG_PROMPT, 0, inp.get("negative_prompt", "")),
99
+ (T2V_NODE_WIDTH, 0, int(inp["width"])),
100
+ (T2V_NODE_HEIGHT, 0, int(inp["height"])),
101
+ (T2V_NODE_FPS, 0, int(inp["fps"])),
102
+ (T2V_NODE_CLIP_LENGTH, 0, _frames_to_seconds(int(inp["frames"]), int(inp["fps"]))),
103
+ ]
104
+
105
+
106
+ def _i2v_parameterize(inp: dict[str, Any]) -> list[Patch]:
107
+ return [
108
+ (I2V_NODE_PROMPT, 0, inp["prompt"]),
109
+ (I2V_NODE_NEG_PROMPT, 0, inp.get("negative_prompt", "")),
110
+ (I2V_NODE_IMAGE, 0, inp["image"]),
111
+ (I2V_NODE_WIDTH, 0, int(inp["width"])),
112
+ (I2V_NODE_HEIGHT, 0, int(inp["height"])),
113
+ (I2V_NODE_FPS, 0, int(inp["fps"])),
114
+ (I2V_NODE_CLIP_LENGTH, 0, _frames_to_seconds(int(inp["frames"]), int(inp["fps"]))),
115
+ ]
116
+
117
+
118
+ _T2V_STAGES = [
119
+ Stage("Encode prompt", 5),
120
+ Stage("Diffusion (Stage 1)", 60),
121
+ Stage("Spatial upscale", 7),
122
+ Stage("Diffusion (Stage 2)", 18),
123
+ Stage("Decode video", 10),
124
+ ]
125
+
126
+ _I2V_STAGES = [
127
+ Stage("Encode prompt", 5),
128
+ Stage("Encode image", 3),
129
+ Stage("Diffusion (Stage 1)", 55),
130
+ Stage("Spatial upscale", 7),
131
+ Stage("Diffusion (Stage 2)", 20),
132
+ Stage("Decode video", 10),
133
+ ]
134
+
135
+ MODE_REGISTRY["t2v"] = Mode(
136
+ name="t2v",
137
+ label="Text → Video",
138
+ icon="📝",
139
+ parameterize_fn=_t2v_parameterize,
140
+ stage_map=_T2V_STAGES,
141
+ )
142
+ MODE_REGISTRY["i2v"] = Mode(
143
+ name="i2v",
144
+ label="Image → Video",
145
+ icon="🖼",
146
+ parameterize_fn=_i2v_parameterize,
147
+ stage_map=_I2V_STAGES,
148
+ )
tests/test_modes.py CHANGED
@@ -2,6 +2,7 @@
2
  import pytest
3
 
4
  import modes
 
5
 
6
 
7
  def test_mode_dataclass_has_expected_fields():
@@ -14,3 +15,34 @@ def test_mode_dataclass_has_expected_fields():
14
  def test_mode_registry_is_a_dict():
15
  """MODE_REGISTRY exists and is a dict (entries added in Tasks 11–12)."""
16
  assert isinstance(modes.MODE_REGISTRY, dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pytest
3
 
4
  import modes
5
+ import workflow
6
 
7
 
8
  def test_mode_dataclass_has_expected_fields():
 
15
  def test_mode_registry_is_a_dict():
16
  """MODE_REGISTRY exists and is a dict (entries added in Tasks 11–12)."""
17
  assert isinstance(modes.MODE_REGISTRY, dict)
18
+
19
+
20
+ def test_t2v_parameterize_produces_valid_patches(canonical_inputs):
21
+ inputs = canonical_inputs["t2v"]
22
+ mode = modes.MODE_REGISTRY["t2v"]
23
+ patches = mode.parameterize_fn(inputs)
24
+
25
+ # All patches must be (node_id: int, widget_index: int, value: Any)
26
+ for node_id, widget_index, value in patches:
27
+ assert isinstance(node_id, int)
28
+ assert isinstance(widget_index, int)
29
+
30
+ # Apply patches to a real template; result must validate.
31
+ wf = workflow.load_template("t2v")
32
+ for patch in patches:
33
+ workflow.set_input(wf, *patch)
34
+ workflow.validate(wf)
35
+
36
+
37
+ def test_i2v_parameterize_uses_image_path(canonical_inputs):
38
+ inputs = canonical_inputs["i2v"]
39
+ mode = modes.MODE_REGISTRY["i2v"]
40
+ patches = mode.parameterize_fn(inputs)
41
+ values = [p[2] for p in patches]
42
+ assert inputs["image"] in values
43
+
44
+
45
+ def test_t2v_and_i2v_in_registry():
46
+ """T2V and I2V exist in MODE_REGISTRY (full completeness in Task 12)."""
47
+ assert "t2v" in modes.MODE_REGISTRY
48
+ assert "i2v" in modes.MODE_REGISTRY