dikdimon commited on
Commit
7ff34bc
·
verified ·
1 Parent(s): 3d4fdb4

Upload neutral_prompt_patcheds using SD-Hub

Browse files
Files changed (47) hide show
  1. neutral_prompt_patcheds/.gitignore +1 -0
  2. neutral_prompt_patcheds/LICENSE +21 -0
  3. neutral_prompt_patcheds/README.md +343 -0
  4. neutral_prompt_patcheds/lib_neutral_prompt/__init__.py +1 -0
  5. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/__init__.cpython-310.pyc +0 -0
  6. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_transform.cpython-310.pyc +0 -0
  7. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_utils.cpython-310.pyc +0 -0
  8. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc +0 -0
  9. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc +0 -0
  10. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc +0 -0
  11. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/matryoshka_utils.cpython-310.pyc +0 -0
  12. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc +0 -0
  13. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc +0 -0
  14. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/protection_utils.cpython-310.pyc +0 -0
  15. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/step_utils.cpython-310.pyc +0 -0
  16. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc +0 -0
  17. neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc +0 -0
  18. neutral_prompt_patcheds/lib_neutral_prompt/affine_transform.py +83 -0
  19. neutral_prompt_patcheds/lib_neutral_prompt/affine_utils.py +220 -0
  20. neutral_prompt_patcheds/lib_neutral_prompt/cfg_denoiser_hijack.py +726 -0
  21. neutral_prompt_patcheds/lib_neutral_prompt/external_code/__init__.py +23 -0
  22. neutral_prompt_patcheds/lib_neutral_prompt/external_code/api.py +27 -0
  23. neutral_prompt_patcheds/lib_neutral_prompt/global_state.py +325 -0
  24. neutral_prompt_patcheds/lib_neutral_prompt/hijacker.py +34 -0
  25. neutral_prompt_patcheds/lib_neutral_prompt/matryoshka_utils.py +644 -0
  26. neutral_prompt_patcheds/lib_neutral_prompt/neutral_prompt_parser.py +453 -0
  27. neutral_prompt_patcheds/lib_neutral_prompt/prompt_parser_hijack.py +134 -0
  28. neutral_prompt_patcheds/lib_neutral_prompt/protection_utils.py +277 -0
  29. neutral_prompt_patcheds/lib_neutral_prompt/step_utils.py +454 -0
  30. neutral_prompt_patcheds/lib_neutral_prompt/ui.py +811 -0
  31. neutral_prompt_patcheds/lib_neutral_prompt/xyz_grid.py +42 -0
  32. neutral_prompt_patcheds/scripts/__pycache__/neutral_prompt.cpython-310.pyc +0 -0
  33. neutral_prompt_patcheds/scripts/neutral_prompt.py +99 -0
  34. neutral_prompt_patcheds/test/perp_parser/__init__.py +54 -0
  35. neutral_prompt_patcheds/test/perp_parser/mock_torch.py +61 -0
  36. neutral_prompt_patcheds/test/perp_parser/test_affine_keyword_order.py +133 -0
  37. neutral_prompt_patcheds/test/perp_parser/test_affine_pipeline.py +217 -0
  38. neutral_prompt_patcheds/test/perp_parser/test_basic_parser.py +122 -0
  39. neutral_prompt_patcheds/test/perp_parser/test_lock_after_end.py +544 -0
  40. neutral_prompt_patcheds/test/perp_parser/test_malicious_parser.py +182 -0
  41. neutral_prompt_patcheds/test/perp_parser/test_matryoshka.py +440 -0
  42. neutral_prompt_patcheds/test/perp_parser/test_matryoshka_golden.py +331 -0
  43. neutral_prompt_patcheds/test/perp_parser/test_parametric_syntax.py +535 -0
  44. neutral_prompt_patcheds/test/perp_parser/test_runtime_behavior.py +826 -0
  45. neutral_prompt_patcheds/test/perp_parser/test_sprint2.py +656 -0
  46. neutral_prompt_patcheds/test/perp_parser/test_sprint2_hotfix.py +419 -0
  47. neutral_prompt_patcheds/test/perp_parser/test_stabilization.py +539 -0
neutral_prompt_patcheds/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
neutral_prompt_patcheds/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 ljleb
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
neutral_prompt_patcheds/README.md ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sd-webui-neutral-prompt — Patched Edition
2
+
3
+ > **Unified merge of the main experimental branches — with bugfixes and parametric syntax.**
4
+ >
5
+ > Based on the _Ultimate Edition_ merge, with the following additions:
6
+ > - `AND_SALT[k]` / `AND_SALT_WIDE[k]` / `AND_SALT_BLOB[k]` — user-controlled salience sharpness
7
+ > - `AND_ALIGN[D,S]` / `AND_MASK_ALIGN[D,S]` — any D/S pair without suffix pre-declaration
8
+ > - Fixed `AND_SALT` default (k=5 instead of k=20 — actually visible)
9
+ > - Fixed `AND_SALT_BLOB` morphology order (thickify-first prevents seed self-destruction)
10
+ > - Guard against "base prompt replacement" when no plain AND-child is present
11
+
12
+ ---
13
+
14
+ ## What's inside
15
+
16
+ | Feature | Origin | Keyword / setting |
17
+ |---|---|---|
18
+ | Perpendicular projection | `main` | `AND_PERP` |
19
+ | Saliency-guided mask | `main` + patch | `AND_SALT[k]` |
20
+ | Saliency blob mask | `life` + **fixed** | `AND_SALT_BLOB[k]` |
21
+ | Saliency wide mask | `main` | `AND_SALT_WIDE[k]` |
22
+ | Semantic guidance top-k | `main` | `AND_TOPK` |
23
+ | CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
24
+ | XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
25
+ | External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
26
+ | CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
27
+ | Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` |
28
+ | Soft alignment blend | `alignment_blend` | `AND_ALIGN[D,S]` or `AND_ALIGN_D_S` |
29
+ | Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN[D,S]` or `AND_MASK_ALIGN_D_S` |
30
+
31
+ ---
32
+
33
+ ## Prompt syntax
34
+
35
+ ```
36
+ positive prompt text
37
+ AND_PERP negative-direction text :0.8
38
+ AND_SALT[5] competing concept :1.0 ← k=5 (default; more visible than old k=20)
39
+ AND_SALT_WIDE classic broad mask :0.8 ← k=1 default
40
+ AND_SALT_BLOB[8] blob concept :1.0 ← custom k for blob seed sharpness
41
+ AND_TOPK fine detail tweak :0.5
42
+ AND_ALIGN[4,8] style blend :0.6 ← any valid D,S pair
43
+ AND_MASK_ALIGN[6,12] structure style :0.6 ← binary-mask variant
44
+ AND_ALIGN_4_8 style blend :0.6 ← fixed-suffix form still works
45
+ AND_PERP ROTATE[0.125] rotated concept :1.0
46
+ ROTATE[0.125] AND_PERP rotated concept :1.0
47
+ AND_SALT[5] ROTATE[0.125] shifted concept :1.0
48
+ ```
49
+
50
+ > **Important:** always include a plain base prompt before `AND_*` segments.
51
+ > If all segments have conciliation keywords (no plain `AND`), the extension
52
+ > falls back to standard CFG automatically and logs a warning.
53
+
54
+ ---
55
+
56
+ ## Salience sharpness k
57
+
58
+ All three salience keywords accept an optional `[k]` parameter:
59
+
60
+ ```
61
+ AND_SALT[k] sharp mask default k = 5
62
+ AND_SALT_WIDE[k] broad mask default k = 1
63
+ AND_SALT_BLOB[k] blob mask default k = 5
64
+ ```
65
+
66
+ | k value | Effect | Approx. pixel coverage |
67
+ |---|---|---|
68
+ | 1 | Very broad, ~50% of pixels | ~50% |
69
+ | 3 | Moderate | ~11% |
70
+ | 5 | Focused (default) | ~2% |
71
+ | 10 | Surgical | ~0.25% |
72
+ | 20 | Extreme (original ultimate default) | ~0.04% |
73
+
74
+ ---
75
+
76
+ ## Alignment kernel sizes D, S
77
+
78
+ ```
79
+ AND_ALIGN[D,S] soft blend
80
+ AND_MASK_ALIGN[D,S] binary-mask blend
81
+ ```
82
+
83
+ - **D** = detail kernel size (small → fine detail)
84
+ - **S** = structure kernel size (large → global composition)
85
+ - Both in range **[2, 32]**, **D ≠ S**
86
+ - Child contribution is highest where it changes detail *without* breaking structure.
87
+
88
+ Fixed-suffix form (`AND_ALIGN_4_8`) is still supported and backward-compatible.
89
+
90
+ ---
91
+
92
+ ## Affine transform keywords
93
+
94
+ Supported in **either order** relative to the conciliation keyword:
95
+
96
+ | Keyword | Parameter | Effect |
97
+ |---|---|---|
98
+ | `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
99
+ | `SLIDE[x,y]` | normalised offset | Translate latent contribution |
100
+ | `SCALE[x,y]` | scale factors | Scale latent contribution |
101
+ | `SHEAR[x,y]` | shear in turns | Shear latent contribution |
102
+
103
+ Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
104
+
105
+ ---
106
+
107
+ ## External API
108
+
109
+ ```python
110
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
111
+
112
+ override_cfg_rescale(0.7) # override for next step only
113
+ factor = get_last_cfg_rescale_factor() # read after generation
114
+ ```
115
+
116
+ ---
117
+
118
+ ## CFG Rescale φ
119
+
120
+ When > 0, rescales CFG output to reduce colour over-saturation at high CFG values.
121
+ Uses the **mean-preserving** formula:
122
+
123
+ ```
124
+ rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
125
+ ```
126
+
127
+ where `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
128
+
129
+ ---
130
+
131
+ ## Installation
132
+
133
+ 1. Copy this folder into `stable-diffusion-webui/extensions/`
134
+ 2. Restart the webui
135
+
136
+ ---
137
+
138
+ ## Smoke-test checklist
139
+
140
+ Before reporting a bug, verify:
141
+
142
+ - [ ] Base prompt exists before `AND_*` segments
143
+ - [ ] `AND_SALT` without `[k]` defaults to k=5 (moderate coverage)
144
+ - [ ] `AND_SALT[1]` behaves like `AND_SALT_WIDE` (broad)
145
+ - [ ] `AND_SALT_BLOB[5]` produces a visible blob region
146
+ - [ ] `AND_ALIGN[4,8]` and `AND_ALIGN_4_8` produce identical results
147
+ - [ ] Prompt with only `AND_SALT ...` (no base) falls back to standard CFG with a warning in console (enable verbose in Settings)
148
+
149
+ ---
150
+
151
+ ## Matryoshka / Nested prompt composition
152
+
153
+ The extension supports arbitrary nesting of `AND_*` blocks. An outer block
154
+ focuses or restricts a spatial/semantic region; inner blocks operate _inside_
155
+ that region:
156
+
157
+ ```text
158
+ base subject
159
+ AND_SALT[5] [
160
+ region texture :0.8
161
+ AND_TOPK[0.05] fine highlights :0.4
162
+ ] :0.7
163
+ ```
164
+
165
+ This is called a **matryoshka** structure — like nested dolls, each layer adds
166
+ a more focused effect.
167
+
168
+ ### Basic nesting syntax
169
+
170
+ ```text
171
+ base
172
+ KEYWORD [
173
+ inner text :weight
174
+ INNER_KEYWORD inner concept :weight
175
+ ] :weight
176
+ ```
177
+
178
+ The outer `[ ... ]` bracket is a **composite group**, not a parameter block.
179
+ It is parsed as such only if the content contains letters (not purely numeric).
180
+
181
+ ### Nesting with params
182
+
183
+ Params go directly after the keyword, _before_ the composite bracket:
184
+
185
+ ```text
186
+ AND_SALT[5] [ ← k=5 applied, then [ opens composite group
187
+ texture :0.8
188
+ AND_TOPK[0.1] detail :0.5
189
+ ] :0.7
190
+ ```
191
+
192
+ ### Nesting with affine
193
+
194
+ Affine transforms can precede the keyword at any level:
195
+
196
+ ```text
197
+ SCALE[-1,1] AND_PERP [ ← mirror then perpendicular
198
+ mirrored concept :0.6
199
+ ] :0.5
200
+ ```
201
+
202
+ ### How the parser distinguishes params vs composite
203
+
204
+ | Content inside `[...]` | Treated as |
205
+ |---|---|
206
+ | Purely numeric **and attached directly to the keyword** (no space): `[5]`, `[4,8]`, `[ 0.1 ]` | **Params** for the preceding keyword |
207
+ | Contains letters / spaces + letters | **Composite group** |
208
+
209
+ So `AND_SALT[5]` → params; `AND_SALT [concept AND other]` → composite.
210
+
211
+ ### Protection with nested prompts
212
+
213
+ The base-prompt protection guard checks for a **plain base segment at the
214
+ top level** of the prompt. A nested structure like:
215
+
216
+ ```text
217
+ AND_SALT[5] [ ... ] :0.8
218
+ ```
219
+
220
+ has _no_ top-level base, so `auto` and `strict` protection will fire.
221
+ To suppress this intentionally, set protection to `off`.
222
+
223
+ ### Common mistakes
224
+
225
+ **Forgetting the base prompt**
226
+ ```text
227
+ # ❌ triggers protection fallback
228
+ AND_SALT[5] concept :0.8
229
+
230
+ # ✓
231
+ base subject
232
+ AND_SALT[5] concept :0.8
233
+ ```
234
+
235
+ **Using numeric composite brackets as params**
236
+ ```text
237
+ # This is parsed as a composite group, NOT as k=5:
238
+ AND_SALT [5] ← space before bracket → composite group with text "5"
239
+ AND_SALT[5] ← no space → params (k=5) ✓
240
+ ```
241
+
242
+ **Zero or out-of-range params**
243
+ ```text
244
+ AND_SALT[0] ← invalid k, silently uses defaults
245
+ AND_ALIGN[4,4] ← D must differ from S, silently uses defaults
246
+ AND_TOPK[1.5] ← threshold must be in (0,1], silently uses defaults
247
+ ```
248
+ Enable verbose mode in A1111 Settings → Neutral Prompt to see warnings.
249
+
250
+ ---
251
+
252
+ ## Matryoshka builder (UI)
253
+
254
+ Open the **Matryoshka builder** accordion to compose nested prompts visually:
255
+
256
+ 1. Enter the base prompt text.
257
+ 2. Choose a strategy for Child 1, enter its text and weight.
258
+ 3. Optionally enable a nested child _inside_ Child 1.
259
+ 4. Optionally add a second top-level child.
260
+ 5. The **Generated prompt (preview)** updates live.
261
+ 6. Click **Apply to prompt** to insert into the main prompt box.
262
+
263
+ ### Ready-made templates
264
+
265
+ The **Matryoshka templates** accordion contains 7 ready-made recipes:
266
+
267
+ | Template | What it does |
268
+ |---|---|
269
+ | Nested local detail | SALT region → nested TOPK fine detail |
270
+ | Structure preserve + style inject | ALIGN preserves structure → nested PERP reduces contradiction |
271
+ | Sparse detail inside broad region | SALT_WIDE broad → nested SALT focal sharpening |
272
+ | Perpendicular correction inside texture | SALT texture → nested PERP suppresses contradiction |
273
+ | Nested ALIGN + SALT | ALIGN structure → SALT foreground → nested TOPK highlights |
274
+ | Mirrored composition | Mirrored PERP via SCALE[-1,1] |
275
+ | Deep nested concept isolation | SALT → ALIGN → TOPK (3 levels) |
276
+
277
+ ---
278
+
279
+ ## Prompt debug / explain panel
280
+
281
+ Open the **Prompt debug / explain** accordion to inspect any prompt.
282
+
283
+ The panel shows:
284
+
285
+ ```
286
+ ROOT (N top-level segments)
287
+ ├─ [BASE] w=1.00 "base subject"
288
+ └─ [SALIENCE_MASK] w=0.80 [k=5.0] (2 children)
289
+ ├─ [BASE] w=0.80 "texture"
290
+ └─ [SEMANTIC_GUIDANCE] w=0.40 [threshold=0.05] "highlights"
291
+
292
+ ── Diagnostics ──────────────────────────
293
+ Segments total : 4 (leaf=3, composite=1)
294
+ Max nesting depth : 2
295
+ Affine transforms : 0
296
+ Strategies used : BASE, SALIENCE_MASK, SEMANTIC_GUIDANCE
297
+
298
+ Protection mode : auto
299
+ OK — base segment present at top level
300
+
301
+ ── Effect summary ────────────────────────
302
+ BASE → base prompt contribution — drives the overall image
303
+ SALIENCE_MASK → sharp saliency mask — targets the most salient latent pixels
304
+ SEMANTIC_GUIDANCE → semantic top-k — applies sparse targeted changes to the strongest elements
305
+ ```
306
+
307
+ ### What the debug panel does NOT simulate
308
+
309
+ - **Strict ratio check**: whether `norm(base_delta) / norm(aux_delta)` would fall
310
+ below the threshold is only knowable at generation time with real latents.
311
+ - **Batch/mixed-prompt cases**: the panel parses a single prompt string; it does
312
+ not simulate multi-prompt batch protection.
313
+ - **Hook conflicts**: whether another A1111 extension has overridden the hijack
314
+ chain is a runtime condition.
315
+
316
+ The protection verdict shown is a **structural preview** — a fast heuristic
317
+ to catch the most common mistake (missing base segment) before you start
318
+ generating.
319
+
320
+ ---
321
+
322
+ ## Affine transform builder (UI)
323
+
324
+ Open the **Affine transform builder** accordion to compose affine snippets:
325
+
326
+ 1. Pick a preset (Mirror H/V, Rotate 45°/90°/180°, Zoom, Stretch, etc.)
327
+ or select **Custom** and set parameters manually.
328
+ 2. The **Affine snippet** box updates live.
329
+ 3. Click **Insert affine into prompt** to append the snippet.
330
+
331
+ Snippets can be placed before a keyword (`ROTATE[0.125] AND_PERP ...`) or
332
+ as a standalone transform (`SCALE[-1,1] AND_SALT[5] ...`).
333
+
334
+ ### Safe ranges
335
+
336
+ | Transform | Notes |
337
+ |---|---|
338
+ | `ROTATE[angle]` | Angle in turns. `0.25` = 90°, `0.5` = 180° |
339
+ | `SCALE[x,y]` | Both components must be non-zero (avoids singular matrix) |
340
+ | `SLIDE[x,y]` | Translation fraction of latent space |
341
+ | `SHEAR[x,y]` | Angles near ±0.25 turns diverge — avoided automatically |
342
+ | `FLIP_H` | Shortcut for `SCALE[-1,1]` |
343
+ | `FLIP_V` | Shortcut for `SCALE[1,-1]` |
neutral_prompt_patcheds/lib_neutral_prompt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # lib_neutral_prompt package
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_transform.cpython-310.pyc ADDED
Binary file (3.34 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_utils.cpython-310.pyc ADDED
Binary file (6.8 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc ADDED
Binary file (19.7 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc ADDED
Binary file (8.06 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc ADDED
Binary file (1.8 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/matryoshka_utils.cpython-310.pyc ADDED
Binary file (19.7 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc ADDED
Binary file (3.92 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/protection_utils.cpython-310.pyc ADDED
Binary file (8.15 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/step_utils.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc ADDED
Binary file (24.9 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
neutral_prompt_patcheds/lib_neutral_prompt/affine_transform.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Affine spatial transform utilities for latent tensors.
3
+ Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
4
+
5
+ Provides:
6
+ apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
7
+ apply_masked_transform() – apply affine + cosine-feathered mask blending
8
+ create_cosine_feathered_mask() – smooth circular weight mask
9
+ """
10
+
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+
17
def apply_affine_transform(
    tensor: torch.Tensor,
    affine: torch.Tensor,
    mode: str = 'bilinear',
) -> torch.Tensor:
    """
    Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.

    :param tensor: Input tensor of shape [C, H, W].
    :param affine: 2×3 affine matrix.
    :param mode: Interpolation mode for grid_sample ('bilinear' default).
                 The original affine branch used bilinear for the pre-noise
                 inverse path (smoother) and nearest for post-combine.
                 Bilinear is the better default for both.
    :return: Transformed tensor of shape [C, H, W].
    """
    # Match the input's device AND dtype: F.grid_sample requires the sampling
    # grid to have the same dtype as the input, so a float32 matrix applied
    # to a half-precision latent would otherwise raise at runtime.
    # detach().clone() guarantees the in-place aspect correction below never
    # mutates the caller's matrix, even when .to() is a no-op.
    affine = affine.detach().clone().to(device=tensor.device, dtype=tensor.dtype)

    # Compensate for non-square latents so rotations stay circular instead of
    # being stretched by the H/W ratio of normalized grid coordinates.
    aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
    affine[0, 1] *= aspect_ratio
    affine[1, 0] /= aspect_ratio

    grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
    return F.grid_sample(tensor.unsqueeze(0), grid, mode=mode, align_corners=False).squeeze(0)
40
+
41
+
42
def apply_masked_transform(
    tensor: torch.Tensor,
    affine: torch.Tensor,
    weight: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Warp *tensor* with *affine* while producing a matching spatial weight mask.

    A cosine-feathered mask is stacked onto the latent as one extra channel
    before warping, so content and mask undergo exactly the same spatial
    transform and stay perfectly registered.

    :param tensor: C×H×W latent tensor.
    :param affine: 2×3 affine matrix.
    :param weight: Global scalar weight for the mask.
    :return: (transformed_content [C, H, W], transformed_mask [H, W])
    """
    feathered = create_cosine_feathered_mask(tensor.shape[-2:], weight)
    stacked = torch.cat([tensor, feathered.unsqueeze(0).to(tensor.device)], dim=0)  # [C+1, H, W]
    warped = apply_affine_transform(stacked, affine)
    content, mask = warped[:-1], warped[-1]
    return content, mask
62
+
63
+
64
def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
    """
    Build a circularly-clipped, cosine-feathered mask of shape [H, W].

    The centre value approaches *weight*; everything outside the unit circle
    is exactly 0, with a smooth raised-cosine fall-off in between.

    :param size: (H, W) tuple.
    :param weight: Peak mask value at the centre.
    :return: Float32 mask tensor of shape [H, W].
    """
    rows = torch.linspace(-1, 1, size[0])
    cols = torch.linspace(-1, 1, size[1])
    y, x = torch.meshgrid(rows, cols, indexing='ij')
    # Radial distance from the centre in normalized [-1, 1] coordinates.
    radius = (x * x + y * y).sqrt()
    # Raised cosine: 1 at the centre, 0 at radius 1.
    feather = (1.0 + torch.cos(torch.pi * radius)) * 0.5
    # Hard-clip everything outside the unit circle (the corners).
    feather[radius > 1] = 0.0
    return feather.float() * weight
neutral_prompt_patcheds/lib_neutral_prompt/affine_utils.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Affine transform subsystem for sd-webui-neutral-prompt.
3
+
4
+ Single source of truth for:
5
+ - Matrix constructors: _make_rotate / _make_slide / _make_scale / _make_shear
6
+ - Application: _apply_affine
7
+ - Safe inversion: _try_invert_affine
8
+ - Parser dispatch: affine_transforms (keyword → callable)
9
+ - UI helpers: _build_affine_snippet / _AFFINE_TRANSFORMS / _AFFINE_PRESETS
10
+
11
+ Previously scattered across neutral_prompt_parser.py, cfg_denoiser_hijack.py,
12
+ and matryoshka_utils.py. Import from here everywhere.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ import sys
19
+ from typing import Dict, Optional, Tuple
20
+
21
+ import torch
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Warning helper
26
+ # ---------------------------------------------------------------------------
27
+
28
+ def affine_warn(keyword: str, reason: str) -> None:
29
+ """Print a one-line warning to stderr (respects global_state.verbose)."""
30
+ try:
31
+ from lib_neutral_prompt import global_state as _gs
32
+ if not _gs.verbose:
33
+ return
34
+ except ImportError:
35
+ pass
36
+ print(
37
+ f'[neutral-prompt] WARNING: invalid affine params for {keyword} — {reason}. '
38
+ 'Transform ignored; using identity.',
39
+ file=sys.stderr,
40
+ )
41
+
42
+ # Keep old private name as alias (parser uses it)
43
+ _affine_warn = affine_warn
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Matrix constructors
48
+ # ---------------------------------------------------------------------------
49
+
50
def make_rotate(angle: float) -> Optional[torch.Tensor]:
    """Rotation matrix. *angle* is in turns (0.25 = 90°)."""
    if not math.isfinite(angle):
        affine_warn('ROTATE', f'non-finite angle {angle}')
        return None
    radians = angle * 2 * math.pi
    cos_a, sin_a = math.cos(radians), math.sin(radians)
    return torch.tensor([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1],
    ], dtype=torch.float32)
61
+
62
+
63
def make_slide(x: float, y: float) -> Optional[torch.Tensor]:
    """Translation matrix."""
    if not (math.isfinite(x) and math.isfinite(y)):
        affine_warn('SLIDE', f'non-finite offset ({x}, {y})')
        return None
    matrix = torch.eye(3, dtype=torch.float32)
    matrix[0, 2] = float(x)
    matrix[1, 2] = float(y)
    return matrix
73
+
74
+
75
def make_scale(x: float, y: Optional[float] = None) -> Optional[torch.Tensor]:
    """Uniform or anisotropic scale matrix. Both components must be non-zero."""
    sy = x if y is None else y
    if not (math.isfinite(x) and math.isfinite(sy)):
        affine_warn('SCALE', f'non-finite scale ({x}, {sy})')
        return None
    # Reject (near-)zero components: the matrix would be singular and the
    # inverse pass downstream would blow up.
    if min(abs(x), abs(sy)) < 1e-6:
        affine_warn('SCALE', f'scale component too close to zero ({x}, {sy}); '
                             'matrix would be singular')
        return None
    matrix = torch.eye(3, dtype=torch.float32)
    matrix[0, 0] = float(x)
    matrix[1, 1] = float(sy)
    return matrix
90
+
91
+
92
def make_shear(x: float, y: Optional[float] = None) -> Optional[torch.Tensor]:
    """Shear matrix. Angles near ±0.25 turns are rejected (tan diverges)."""
    sy = x if y is None else y
    if not (math.isfinite(x) and math.isfinite(sy)):
        affine_warn('SHEAR', f'non-finite shear ({x}, {sy})')
        return None
    # tan() diverges at a quarter turn; reject anything within 0.01 turns of it.
    shear_limit = 0.24
    if max(abs(x), abs(sy)) >= shear_limit:
        affine_warn('SHEAR', f'shear angle ({x}, {sy}) too close to ±0.25 turns; '
                             'tan diverges and matrix would be degenerate')
        return None
    tan_x = math.tan(float(x) * 2 * math.pi)
    tan_y = math.tan(float(sy) * 2 * math.pi)
    # Belt-and-braces: the range check above should prevent this, but a
    # non-finite tan would still poison the whole pipeline.
    if not (math.isfinite(tan_x) and math.isfinite(tan_y)):
        affine_warn('SHEAR', f'tan({x}, {sy}) produced non-finite value')
        return None
    matrix = torch.eye(3, dtype=torch.float32)
    matrix[0, 1] = tan_x
    matrix[1, 0] = tan_y
    return matrix
113
+
114
+
115
# Keep old private names (parser uses them)
# NOTE(review): these underscore aliases preserve backward compatibility with
# code that imported the pre-refactor private names (the parser, per the
# comment above); do not remove without updating those call sites.
_make_rotate = make_rotate
_make_slide = make_slide
_make_scale = make_scale
_make_shear = make_shear
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Application helper
124
+ # ---------------------------------------------------------------------------
125
+
126
+ def apply_affine(t: torch.Tensor, m: Optional[torch.Tensor]) -> torch.Tensor:
127
+ """Apply matrix *m* to running transform *t*; if *m* is None return *t* unchanged."""
128
+ return t if m is None else t @ m
129
+
130
+ _apply_affine = apply_affine # alias for parser
131
+
132
+
133
+ # ---------------------------------------------------------------------------
134
+ # Parser dispatch table
135
+ # ---------------------------------------------------------------------------
136
+
137
# Parser dispatch: prompt keyword → callable that folds one more transform
# into the running 3×3 matrix `t`. Trailing `*_` swallows any extra params
# the parser passes; invalid params make the make_* constructor return None,
# which apply_affine treats as identity (transform silently skipped).
affine_transforms: Dict[str, object] = {
    'ROTATE': lambda t, angle=0, *_: apply_affine(t, make_rotate(float(angle))),
    'SLIDE': lambda t, x=0, y=0, *_: apply_affine(t, make_slide(float(x), float(y))),
    'SCALE': lambda t, x=1, y=None, *_: apply_affine(
        t, make_scale(float(x), float(y) if y is not None else None)),
    'SHEAR': lambda t, x=0, y=None, *_: apply_affine(
        t, make_shear(float(x), float(y) if y is not None else None)),
}
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Safe matrix inversion (runtime, used by cfg_denoiser_hijack)
149
+ # ---------------------------------------------------------------------------
150
+
151
+ def try_invert_affine(
152
+ m3: torch.Tensor,
153
+ context: str = 'affine transform',
154
+ ) -> torch.Tensor:
155
+ """
156
+ Safely invert a 3×3 homogeneous affine matrix.
157
+
158
+ On any error (singular, NaN) logs a warning (verbose only) and returns
159
+ the 3×3 identity so the pipeline degrades gracefully instead of crashing.
160
+ """
161
+ try:
162
+ result = torch.linalg.inv(m3)
163
+ if not torch.isfinite(result).all():
164
+ raise ValueError('non-finite values in inverted matrix')
165
+ return result
166
+ except Exception as exc:
167
+ try:
168
+ from lib_neutral_prompt import global_state as _gs
169
+ verbose = _gs.verbose
170
+ except ImportError:
171
+ verbose = True
172
+ if verbose:
173
+ print(
174
+ f'[neutral-prompt] WARNING: could not invert {context}: {exc}. '
175
+ 'Falling back to identity (no affine transform applied).',
176
+ file=sys.stderr,
177
+ )
178
+ return torch.eye(3, device=m3.device, dtype=m3.dtype)
179
+
180
+ _try_invert_affine = try_invert_affine # alias
181
+
182
+
183
+ # ---------------------------------------------------------------------------
184
+ # UI helpers (snippet builder + preset library)
185
+ # ---------------------------------------------------------------------------
186
+
187
# Transform names exposed in the UI dropdown; FLIP_H/FLIP_V are sugar for
# SCALE[-1,1] / SCALE[1,-1] (see build_affine_snippet below).
AFFINE_TRANSFORMS = ['ROTATE', 'SCALE', 'SLIDE', 'SHEAR', 'FLIP_H', 'FLIP_V']

# Preset name → (transform keyword, param1, param2). None means "Custom":
# the user sets the parameters manually. ROTATE takes a single angle (in
# turns), so its second parameter slot is None.
AFFINE_PRESETS: Dict[str, Optional[Tuple]] = {
    'Custom': None,
    'Mirror H': ('SCALE', -1.0, 1.0),
    'Mirror V': ('SCALE', 1.0, -1.0),
    'Rotate 45°': ('ROTATE', 0.125, None),
    'Rotate 90°': ('ROTATE', 0.25, None),
    'Rotate 180°': ('ROTATE', 0.5, None),
    'Zoom out 50%': ('SCALE', 0.5, 0.5),
    'Zoom in 150%': ('SCALE', 1.5, 1.5),
    'Stretch X': ('SCALE', 1.5, 1.0),
    'Stretch Y': ('SCALE', 1.0, 1.5),
    'Slight rotate': ('ROTATE', 0.05, None),
    'Slide right': ('SLIDE', 0.1, 0.0),
    'Slide down': ('SLIDE', 0.0, 0.1),
}
204
+
205
+
206
+ def build_affine_snippet(transform: str, p1: float, p2: float) -> str:
207
+ """Convert UI control values to a single ``TRANSFORM[args]`` token."""
208
+ t = transform.strip()
209
+ if t == 'FLIP_H': return 'SCALE[-1,1]'
210
+ if t == 'FLIP_V': return 'SCALE[1,-1]'
211
+ if t == 'ROTATE': return f'ROTATE[{p1}]'
212
+ if t == 'SCALE':
213
+ return f'SCALE[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SCALE[{p1}]'
214
+ if t == 'SLIDE': return f'SLIDE[{p1},{p2}]'
215
+ if t == 'SHEAR':
216
+ return f'SHEAR[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SHEAR[{p1}]'
217
+ return ''
218
+
219
+ # Old private alias (matryoshka_utils re-exports from here)
220
+ _build_affine_snippet = build_affine_snippet
neutral_prompt_patcheds/lib_neutral_prompt/cfg_denoiser_hijack.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified CFG Denoiser Hijack
3
+ ===========================
4
+ Combines features from all sd-webui-neutral-prompt branches:
5
+
6
+ • Core PERP / SALT / TOPK strategies (main branch)
7
+ • cfg_rescale with mean-preserving formula (main branch)
8
+ • cfg_rescale_override (XYZ-grid / API support) (main branch)
9
+ • CFGRescaleFactorSingleton export (export_rescale_factor branch)
10
+ • Affine spatial transforms per-prompt (affine branch)
11
+ • AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
12
+ • AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
13
+ • Improved salience with sharpness parameter k (life branch)
14
+
15
+ Changes vs original ultimate:
16
+ • AND_SALT / AND_SALT_WIDE / AND_SALT_BLOB: k read from conciliation_params
17
+ (default 5 / 1 / 5 — tuned to actually be visible)
18
+ • AND_SALT_BLOB: erode/thickify order fixed (thickify first, then erode)
19
+ • AND_ALIGN[D,S] / AND_MASK_ALIGN[D,S]: bracket-syntax routed via
20
+ ALIGNMENT_BLEND_CUSTOM / ALIGNMENT_MASK_BLEND_CUSTOM
21
+ • Guard against "no base prompt" → cond_delta=0 replacement bug
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import dataclasses
27
+ import functools
28
+ import re
29
+ import sys
30
+ import textwrap
31
+ from typing import Dict, List, Tuple
32
+
33
+ import torch
34
+ import torch.nn.functional as F
35
+
36
+ from lib_neutral_prompt import affine_transform as affine_mod
37
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
38
+ from modules import script_callbacks, sd_samplers, shared
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Pre-noise affine hook (affine branch)
43
+ # ---------------------------------------------------------------------------
44
+
45
+ @dataclasses.dataclass
46
+ class _PreNoiseArgs:
47
+ x: torch.Tensor
48
+ cond_indices: List[Tuple[int, float]]
49
+
50
+
51
+ class _GlobalToLocalAffineVisitor:
52
+ def visit_leaf_prompt(
53
+ self,
54
+ that: neutral_prompt_parser.LeafPrompt,
55
+ args: _PreNoiseArgs,
56
+ index: int,
57
+ ) -> Dict[int, torch.Tensor]:
58
+ cond_index = args.cond_indices[index][0]
59
+ if that.local_transform is not None:
60
+ m3 = torch.vstack([that.local_transform,
61
+ torch.tensor([0.0, 0.0, 1.0])])
62
+ inv = _try_invert_affine(m3, 'leaf local_transform')[:-1]
63
+ else:
64
+ inv = torch.eye(3)[:-1]
65
+ return {cond_index: inv}
66
+
67
+ def visit_composite_prompt(
68
+ self,
69
+ that: neutral_prompt_parser.CompositePrompt,
70
+ args: _PreNoiseArgs,
71
+ index: int,
72
+ ) -> Dict[int, torch.Tensor]:
73
+ inv_transforms: Dict[int, torch.Tensor] = {}
74
+
75
+ for child in that.children:
76
+ inv_transforms.update(child.accept(_GlobalToLocalAffineVisitor(), args, index))
77
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
78
+
79
+ if that.local_transform is not None:
80
+ m3 = torch.vstack([that.local_transform,
81
+ torch.tensor([0.0, 0.0, 1.0])])
82
+ parent_inv = _try_invert_affine(m3, 'composite parent local_transform')
83
+ for inv in inv_transforms.values():
84
+ inv3 = torch.vstack([inv, torch.tensor([0.0, 0.0, 1.0])])
85
+ inv[:] = (parent_inv @ inv3)[:-1]
86
+
87
+ return inv_transforms
88
+
89
+
90
+ def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
91
+ if not global_state.is_enabled:
92
+ return
93
+ if not _batch_is_compatible(global_state.prompt_exprs, global_state.batch_cond_indices):
94
+ return
95
+
96
+ for prompt, cond_indices in zip(global_state.prompt_exprs,
97
+ global_state.batch_cond_indices):
98
+ args = _PreNoiseArgs(params.x, cond_indices)
99
+ inv_transforms = prompt.accept(_GlobalToLocalAffineVisitor(), args, 0)
100
+ for cond_index, _ in cond_indices:
101
+ params.x[cond_index] = affine_mod.apply_affine_transform(
102
+ params.x[cond_index], inv_transforms[cond_index]
103
+ )
104
+
105
+
106
+ script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
107
+
108
+
109
+ def _flat_size(prompt: neutral_prompt_parser.PromptExpr) -> int:
110
+ return prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
111
+
112
+
113
+ def _batch_is_compatible(
114
+ prompts: List[neutral_prompt_parser.PromptExpr],
115
+ batch_cond_indices: List[List[Tuple[int, float]]],
116
+ ) -> bool:
117
+ if len(prompts) != len(batch_cond_indices):
118
+ _console_warn(f'''
119
+ Neutral Prompt batch mismatch:
120
+ prompt_exprs={len(prompts)} vs batch_cond_indices={len(batch_cond_indices)}
121
+ Falling back to original A1111 behavior for this step.
122
+ ''')
123
+ return False
124
+
125
+ for i, (prompt, cond_indices) in enumerate(zip(prompts, batch_cond_indices)):
126
+ need = _flat_size(prompt)
127
+ got = len(cond_indices)
128
+ if need != got:
129
+ _console_warn(f'''
130
+ Neutral Prompt branch mismatch at prompt #{i}:
131
+ expected {need} branches, got {got}
132
+ Falling back to original A1111 behavior for this step.
133
+ ''')
134
+ return False
135
+
136
+ return True
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Public entry point
141
+ # ---------------------------------------------------------------------------
142
+
143
+ def combine_denoised_hijack(
144
+ x_out: torch.Tensor,
145
+ batch_cond_indices: List[List[Tuple[int, float]]],
146
+ text_uncond: torch.Tensor,
147
+ cond_scale: float,
148
+ original_function,
149
+ ) -> torch.Tensor:
150
+ if not global_state.is_enabled:
151
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
152
+
153
+ if not _batch_is_compatible(global_state.prompt_exprs, batch_cond_indices):
154
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
155
+
156
+ # ------------------------------------------------------------------
157
+ # Update step counters for step-window gating (runs before any gating check)
158
+ # shared.state.sampling_step / sampling_steps are set by A1111 per denoiser call
159
+ # ------------------------------------------------------------------
160
+ try:
161
+ global_state.update_step_state(
162
+ shared.state.sampling_step,
163
+ shared.state.sampling_steps,
164
+ )
165
+ except Exception:
166
+ pass # sampling_step may not be set during warmup; keep previous values
167
+ # NOTE: Lock-after-End flags are reset in NeutralPromptScript.process(),
168
+ # not here. Resetting on step==0 would incorrectly re-arm locks inside
169
+ # samplers that revisit step 0 (Restart, DPM++ multi-pass, hires fix).
170
+
171
+ # ------------------------------------------------------------------
172
+ # Base-prompt protection (Off / Auto / Strict / Soft)
173
+ # ------------------------------------------------------------------
174
+ should_fallback, reason = _should_fallback_for_protection(
175
+ x_out, batch_cond_indices, text_uncond
176
+ )
177
+ if should_fallback:
178
+ _console_warn(f'\n [base-prompt protection] {reason}\n Falling back to standard CFG.')
179
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
180
+ # ------------------------------------------------------------------
181
+
182
+ denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
183
+ uncond = x_out[-text_uncond.shape[0]:]
184
+
185
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
186
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
187
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
188
+ aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)
189
+
190
+ # --- Soft mode: attenuate auxiliary if base is too weak ---
191
+ aux_cond_delta, _ = apply_soft_attenuation_if_needed(
192
+ aux_cond_delta, cond_delta, batch_i
193
+ )
194
+
195
+ if prompt.local_transform is not None:
196
+ cond_delta = affine_mod.apply_affine_transform(cond_delta, prompt.local_transform)
197
+ aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, prompt.local_transform)
198
+
199
+ cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
200
+ denoised[batch_i] = _cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
201
+
202
+ return denoised
203
+
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Internal helpers
207
+ # ---------------------------------------------------------------------------
208
+
209
+ def _get_webui_denoised(
210
+ x_out: torch.Tensor,
211
+ batch_cond_indices: List[List[Tuple[int, float]]],
212
+ text_uncond: torch.Tensor,
213
+ cond_scale: float,
214
+ original_function,
215
+ ) -> torch.Tensor:
216
+ uncond = x_out[-text_uncond.shape[0]:]
217
+ sliced_batch_x_out: List[torch.Tensor] = []
218
+ sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
219
+
220
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
221
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
222
+ sliced_x_out, sliced_indices = _gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
223
+ if sliced_indices:
224
+ sliced_batch_cond_indices.append(sliced_indices)
225
+ sliced_batch_x_out.extend(sliced_x_out)
226
+
227
+ sliced_batch_x_out += list(uncond)
228
+ return original_function(
229
+ torch.stack(sliced_batch_x_out, dim=0),
230
+ sliced_batch_cond_indices,
231
+ text_uncond,
232
+ cond_scale,
233
+ )
234
+
235
+
236
+ def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
237
+ """Mean-preserving CFG rescale."""
238
+ global_state.CFGRescaleFactorSingleton.clear()
239
+ global_state.apply_and_clear_cfg_rescale_override()
240
+
241
+ if global_state.cfg_rescale == 0:
242
+ return cfg_cond
243
+
244
+ cfg_std = cfg_cond.std()
245
+ if cfg_std == 0:
246
+ return cfg_cond
247
+
248
+ cfg_cond_mean = cfg_cond.mean()
249
+ rescale_mean = (
250
+ (1 - global_state.cfg_rescale) * cfg_cond_mean
251
+ + global_state.cfg_rescale * cond.mean()
252
+ )
253
+ rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_std - 1) + 1
254
+
255
+ global_state.CFGRescaleFactorSingleton.set(
256
+ rescale_factor.item() if isinstance(rescale_factor, torch.Tensor) else float(rescale_factor)
257
+ )
258
+
259
+ return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
260
+
261
+
262
+ @dataclasses.dataclass
263
+ class _DenoiseArgs:
264
+ x_out: torch.Tensor
265
+ uncond: torch.Tensor
266
+ cond_indices: List[Tuple[int, float]]
267
+
268
+
269
+ def _gather_webui_conds(
270
+ prompt: neutral_prompt_parser.CompositePrompt,
271
+ args: _DenoiseArgs,
272
+ index_in: int,
273
+ index_out: int,
274
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
275
+ sliced_x_out: List[torch.Tensor] = []
276
+ sliced_cond_indices: List[Tuple[int, float]] = []
277
+
278
+ for child in prompt.children:
279
+ if child.conciliation is None:
280
+ if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
281
+ child_x_out = args.x_out[args.cond_indices[index_in][0]]
282
+ child_weight = child.weight
283
+ else:
284
+ child_x_out, child_weight = _get_cond_delta_and_weight(child, args, index_in)
285
+ child_x_out = child_x_out + args.uncond
286
+
287
+ index_offset = index_out + len(sliced_x_out)
288
+ sliced_x_out.append(child_x_out)
289
+ sliced_cond_indices.append((index_offset, child_weight))
290
+
291
+ index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
292
+
293
+ return sliced_x_out, sliced_cond_indices
294
+
295
+
296
+ def _get_cond_delta_and_weight(
297
+ prompt: neutral_prompt_parser.PromptExpr,
298
+ args: _DenoiseArgs,
299
+ index: int,
300
+ ) -> Tuple[torch.Tensor, float]:
301
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, index)
302
+ cond_delta = cond_delta + prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, index)
303
+ weight = prompt.weight
304
+
305
+ if prompt.local_transform is not None:
306
+ transformed, weight_tensor = affine_mod.apply_masked_transform(
307
+ cond_delta + args.uncond,
308
+ prompt.local_transform,
309
+ prompt.weight,
310
+ )
311
+ cond_delta = transformed - args.uncond
312
+ weight = weight_tensor
313
+
314
+ return cond_delta, weight
315
+
316
+
317
+ # ---------------------------------------------------------------------------
318
+ # Visitor: CondDelta
319
+ # ---------------------------------------------------------------------------
320
+
321
+ class _CondDeltaVisitor:
322
+ def visit_leaf_prompt(
323
+ self,
324
+ that: neutral_prompt_parser.LeafPrompt,
325
+ args: _DenoiseArgs,
326
+ index: int,
327
+ ) -> torch.Tensor:
328
+ cond_info = args.cond_indices[index]
329
+ if that.weight != cond_info[1]:
330
+ _console_warn(f'''
331
+ Unexpected noise weight at prompt #{index}
332
+ Expected :{that.weight}, got :{cond_info[1]}
333
+ ''')
334
+ return args.x_out[cond_info[0]] - args.uncond
335
+
336
+ def visit_composite_prompt(
337
+ self,
338
+ that: neutral_prompt_parser.CompositePrompt,
339
+ args: _DenoiseArgs,
340
+ index: int,
341
+ ) -> torch.Tensor:
342
+ cond_delta = torch.zeros_like(args.x_out[0])
343
+ for child in that.children:
344
+ if child.conciliation is None:
345
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
346
+ cond_delta = cond_delta + child_weight * child_delta
347
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
348
+ return cond_delta
349
+
350
+
351
+ # ---------------------------------------------------------------------------
352
+ # Visitor: AuxCondDelta — all conciliation strategies
353
+ # ---------------------------------------------------------------------------
354
+
355
+ class _AuxCondDeltaVisitor:
356
+ def visit_leaf_prompt(
357
+ self,
358
+ that: neutral_prompt_parser.LeafPrompt,
359
+ args: _DenoiseArgs,
360
+ cond_delta: torch.Tensor,
361
+ index: int,
362
+ ) -> torch.Tensor:
363
+ return torch.zeros_like(args.x_out[0])
364
+
365
+ def visit_composite_prompt(
366
+ self,
367
+ that: neutral_prompt_parser.CompositePrompt,
368
+ args: _DenoiseArgs,
369
+ cond_delta: torch.Tensor,
370
+ index: int,
371
+ ) -> torch.Tensor:
372
+ aux_cond_delta = torch.zeros_like(args.x_out[0])
373
+
374
+ # Each salience list entry: (delta, weight, k)
375
+ salient_cond_deltas: List[Tuple[torch.Tensor, float, float]] = []
376
+ salient_wide_deltas: List[Tuple[torch.Tensor, float, float]] = []
377
+ salient_blob_deltas: List[Tuple[torch.Tensor, float, float]] = []
378
+ align_blend_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
379
+ mask_align_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
380
+
381
+ CS = neutral_prompt_parser.ConciliationStrategy
382
+
383
+ # Step-window gating: compute normalised progress once per composite visit
384
+ if global_state.step_window_enabled:
385
+ from lib_neutral_prompt.step_utils import (
386
+ strategy_is_active_from_state, normalize_progress,
387
+ )
388
+ progress = normalize_progress(global_state.current_step, global_state.total_steps)
389
+ else:
390
+ progress = None
391
+
392
+ for child in that.children:
393
+ if child.conciliation is not None:
394
+ # --- Step-window gate ---
395
+ if progress is not None:
396
+ strat_name = child.conciliation.name
397
+ if not strategy_is_active_from_state(strat_name, progress):
398
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
399
+ continue
400
+
401
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
402
+ strat = child.conciliation
403
+ cp = child.conciliation_params # per-child params dict
404
+
405
+ if strat == CS.PERPENDICULAR:
406
+ aux_cond_delta = aux_cond_delta + child_weight * _get_perpendicular_component(cond_delta, child_delta)
407
+
408
+ elif strat == CS.SALIENCE_MASK:
409
+ # Default k=5: noticeably more coverage than original k=20.
410
+ # User can override via AND_SALT[k].
411
+ k = float(cp.get('k', 5.0))
412
+ salient_cond_deltas.append((child_delta, child_weight, k))
413
+
414
+ elif strat == CS.SALIENCE_MASK_WIDE:
415
+ k = float(cp.get('k', 1.0))
416
+ salient_wide_deltas.append((child_delta, child_weight, k))
417
+
418
+ elif strat == CS.SALIENCE_MASK_BLOB:
419
+ k = float(cp.get('k', 5.0))
420
+ salient_blob_deltas.append((child_delta, child_weight, k))
421
+
422
+ elif strat == CS.SEMANTIC_GUIDANCE:
423
+ threshold = float(cp.get('threshold', 0.05))
424
+ aux_cond_delta = aux_cond_delta + child_weight * _filter_abs_top_k(child_delta, threshold)
425
+
426
+ elif strat == CS.ALIGNMENT_BLEND_CUSTOM:
427
+ # AND_ALIGN[D,S] — any user-specified pair
428
+ d = int(cp.get('d', 4))
429
+ s = int(cp.get('s', 8))
430
+ if 2 <= d <= 32 and 2 <= s <= 32 and d != s:
431
+ align_blend_deltas.append((child_delta, child_weight, d, s))
432
+
433
+ elif strat == CS.ALIGNMENT_MASK_BLEND_CUSTOM:
434
+ # AND_MASK_ALIGN[D,S]
435
+ d = int(cp.get('d', 4))
436
+ s = int(cp.get('s', 8))
437
+ if 2 <= d <= 32 and 2 <= s <= 32 and d != s:
438
+ mask_align_deltas.append((child_delta, child_weight, d, s))
439
+
440
+ else:
441
+ # Fixed-suffix AND_ALIGN_D_S (backward-compat)
442
+ m = re.match(r'AND_ALIGN_(\d+)_(\d+)', strat.value)
443
+ if m:
444
+ align_blend_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
445
+ else:
446
+ m = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', strat.value)
447
+ if m:
448
+ mask_align_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
449
+
450
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
451
+
452
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_cond_deltas)
453
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_wide_deltas)
454
+ aux_cond_delta = aux_cond_delta + _salient_blend_blob(cond_delta, salient_blob_deltas)
455
+ aux_cond_delta = aux_cond_delta + _alignment_blend(cond_delta, align_blend_deltas)
456
+ aux_cond_delta = aux_cond_delta + _alignment_mask_blend(cond_delta, mask_align_deltas)
457
+ return aux_cond_delta
458
+
459
+
460
+ # ---------------------------------------------------------------------------
461
+ # Strategy implementations
462
+ # ---------------------------------------------------------------------------
463
+
464
+ def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
465
+ if (normal == 0).all():
466
+ if shared.state.sampling_step <= 0:
467
+ _warn_projection_not_found()
468
+ return vector
469
+ return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
470
+
471
+
472
+ def _salient_blend(
473
+ normal: torch.Tensor,
474
+ vectors: List[Tuple[torch.Tensor, float, float]],
475
+ ) -> torch.Tensor:
476
+ """
477
+ Saliency-guided blend. Each entry: (delta, weight, k).
478
+
479
+ k controls child mask sharpness:
480
+ k=1 → broad (~55% of pixels) — AND_SALT_WIDE default
481
+ k=5 → moderate (~5-10%) — AND_SALT default (was 20, now tuned)
482
+ k=20 → very sharp (1-2 pixels) — can be set via AND_SALT[20]
483
+
484
+ Parent always uses k=1 (diffuse reference) so children compete fairly.
485
+ """
486
+ if not vectors:
487
+ return torch.zeros_like(normal)
488
+
489
+ salience_maps = [_get_salience(normal, k=1.0)] + [
490
+ _get_salience(v, k=k) for v, _, k in vectors
491
+ ]
492
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
493
+
494
+ result = torch.zeros_like(normal)
495
+ for mask_i, (vector, weight, _) in enumerate(vectors, start=1):
496
+ vector_mask = (mask == mask_i).float()
497
+ result = result + weight * vector_mask * (vector - normal)
498
+ return result
499
+
500
+
501
+ def _salient_blend_blob(
502
+ normal: torch.Tensor,
503
+ vectors: List[Tuple[torch.Tensor, float, float]],
504
+ ) -> torch.Tensor:
505
+ """
506
+ AND_SALT_BLOB — life-branch algorithm (fixed erode/thickify order).
507
+
508
+ Fix vs original: original did 6×erode then 2×thickify, which always
509
+ destroyed the 1-2 pixel seed from k=20 on the first erosion step.
510
+ Correct pipeline:
511
+ 1. k-softmax → initial seed (sharpness from conciliation_params)
512
+ 2. thickify ×3 → grow seed into a dense blob first
513
+ 3. erode ×1 → smooth/sharpen edges without destroying the core
514
+ """
515
+ if not vectors:
516
+ return torch.zeros_like(normal)
517
+
518
+ salience_maps = [_get_salience(normal, k=1.0)] + [
519
+ _get_salience(v, k=k) for v, _, k in vectors
520
+ ]
521
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
522
+
523
+ result = torch.zeros_like(normal)
524
+ for mask_i, (vector, weight, _) in enumerate(vectors, start=1):
525
+ vector_mask = (mask == mask_i).float()
526
+ # Grow first, then smooth — prevents seed self-destruction
527
+ for _ in range(3):
528
+ vector_mask = _life_step(vector_mask, _thickify_rule)
529
+ for _ in range(1):
530
+ vector_mask = _life_step(vector_mask, _erode_rule)
531
+ result = result + weight * vector_mask * (vector - normal)
532
+ return result
533
+
534
+
535
+ def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
536
+ C = board.shape[0]
537
+ kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
538
+ kernel = kernel.unsqueeze(0).unsqueeze(0)
539
+
540
+ padded = torch.cat([board.clone(), board[:-1].clone()], dim=0)
541
+ padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0)
542
+
543
+ neighbors = torch.nn.functional.conv3d(
544
+ padded.unsqueeze(0).unsqueeze(0),
545
+ kernel,
546
+ padding=0,
547
+ ).squeeze(0).squeeze(0)
548
+
549
+ neighbors = neighbors - board
550
+ return rule(board, neighbors).float()
551
+
552
+
553
+ def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
554
+ C = board.shape[0]
555
+ return (board == 1) & (neighbors >= C * 5)
556
+
557
+
558
+ def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
559
+ population = board + neighbors
560
+ return (board == 1) | (population >= 4)
561
+
562
+
563
+ def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
564
+ return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
565
+
566
+
567
+ def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
568
+ k = int(torch.numel(vector) * (1 - k_ratio))
569
+ k = max(1, k)
570
+ threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
571
+ return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
572
+
573
+
574
+ # ---------------------------------------------------------------------------
575
+ # Alignment blend (alignment_blend branch)
576
+ # ---------------------------------------------------------------------------
577
+
578
+ def _compute_subregion_similarity_map(
579
+ child_vector: torch.Tensor,
580
+ parent_vector: torch.Tensor,
581
+ region_size: int = 2,
582
+ ) -> torch.Tensor:
583
+ C, H, W = child_vector.shape
584
+ parent = parent_vector.unsqueeze(0)
585
+ child = child_vector.unsqueeze(0)
586
+
587
+ region_radius = region_size // 2
588
+ if region_size % 2 == 1:
589
+ pad_size = (region_radius,) * 4
590
+ else:
591
+ pad_size = (region_radius - 1, region_radius) * 2
592
+
593
+ parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
594
+ child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
595
+
596
+ parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
597
+ child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
598
+
599
+ unfold2 = torch.nn.Unfold(kernel_size=2)
600
+ parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
601
+ child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
602
+
603
+ parent_sub = F.normalize(parent_sub, p=2, dim=2)
604
+ child_sub = F.normalize(child_sub, p=2, dim=2)
605
+ sim = (parent_sub * child_sub).sum(dim=2)
606
+ return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W)
607
+
608
+
609
+ def _alignment_blend(
610
+ parent: torch.Tensor,
611
+ children: List[Tuple[torch.Tensor, float, int, int]],
612
+ ) -> torch.Tensor:
613
+ result = torch.zeros_like(parent)
614
+ for child, weight, detail_size, structure_size in children:
615
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
616
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
617
+
618
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
619
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
620
+ detail_sim = detail_sim / d_abs_max
621
+ structure_sim = structure_sim / s_abs_max
622
+
623
+ alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
624
+ result = result + (child - parent) * weight * alignment_weight
625
+ return result
626
+
627
+
628
+ def _alignment_mask_blend(
629
+ parent: torch.Tensor,
630
+ children: List[Tuple[torch.Tensor, float, int, int]],
631
+ ) -> torch.Tensor:
632
+ result = torch.zeros_like(parent)
633
+ for child, weight, detail_size, structure_size in children:
634
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
635
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
636
+
637
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
638
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
639
+ detail_sim = detail_sim / d_abs_max
640
+ structure_sim = structure_sim / s_abs_max
641
+
642
+ alignment_mask = (structure_sim > detail_sim).to(child)
643
+ result = result + (child - parent) * weight * alignment_mask
644
+ return result
645
+
646
+
647
+
648
+ # ---------------------------------------------------------------------------
649
+ # Base-prompt protection policy (imported from protection_utils)
650
+ # ---------------------------------------------------------------------------
651
+ from lib_neutral_prompt.protection_utils import (
652
+ prompt_has_valid_base_path as _prompt_has_valid_base_path,
653
+ get_invalid_base_prompt_indices as _get_invalid_base_prompt_indices,
654
+ has_valid_base_path as _has_valid_base_path,
655
+ safe_norm as _safe_norm,
656
+ apply_soft_attenuation_if_needed,
657
+ )
658
+ from lib_neutral_prompt.affine_utils import try_invert_affine as _try_invert_affine
659
+
660
+
661
+ def _should_fallback_for_protection(
662
+ x_out: torch.Tensor,
663
+ batch_cond_indices,
664
+ text_uncond: torch.Tensor,
665
+ ) -> tuple:
666
+ """
667
+ Decide whether to fall back to standard CFG. Delegates to protection_utils.
668
+ Passes the visitor classes defined in this module to avoid circular imports.
669
+ """
670
+ from lib_neutral_prompt.protection_utils import should_fallback_for_protection
671
+ return should_fallback_for_protection(
672
+ x_out, batch_cond_indices, text_uncond,
673
+ _CondDeltaVisitor, _AuxCondDeltaVisitor, _DenoiseArgs,
674
+ )
675
+
676
+
677
+
678
+ # ---------------------------------------------------------------------------
679
+ # Sampler hijack
680
+ # ---------------------------------------------------------------------------
681
+
682
+ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
683
+ module=sd_samplers,
684
+ hijacker_attribute='__neutral_prompt_hijacker',
685
+ on_uninstall=script_callbacks.on_script_unloaded,
686
+ )
687
+
688
+
689
+ @sd_samplers_hijacker.hijack('create_sampler')
690
+ def create_sampler_hijack(name: str, model, original_function):
691
+ sampler = original_function(name, model)
692
+ if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
693
+ if global_state.is_enabled:
694
+ _warn_unsupported_sampler()
695
+ return sampler
696
+
697
+ sampler.model_wrap_cfg.combine_denoised = functools.partial(
698
+ combine_denoised_hijack,
699
+ original_function=sampler.model_wrap_cfg.combine_denoised,
700
+ )
701
+ return sampler
702
+
703
+
704
+ # ---------------------------------------------------------------------------
705
+ # Warnings / logging
706
+ # ---------------------------------------------------------------------------
707
+
708
+ def _warn_unsupported_sampler() -> None:
709
+ _console_warn('''
710
+ Neutral prompt relies on composition via AND, which the webui does not support
711
+ when using any of the DDIM, PLMS and UniPC samplers.
712
+ The sampler will NOT be patched – falling back on the original implementation.
713
+ ''')
714
+
715
+
716
+ def _warn_projection_not_found() -> None:
717
+ _console_warn('''
718
+ Could not find a projection for one or more AND_PERP prompts.
719
+ These prompts will NOT be made perpendicular.
720
+ ''')
721
+
722
+
723
+ def _console_warn(message: str) -> None:
724
+ if not global_state.verbose:
725
+ return
726
+ print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
neutral_prompt_patcheds/lib_neutral_prompt/external_code/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+
4
+ @contextlib.contextmanager
5
+ def fix_path():
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ extension_path = str(Path(__file__).parent.parent.parent)
10
+ added = False
11
+ if extension_path not in sys.path:
12
+ sys.path.insert(0, extension_path)
13
+ added = True
14
+
15
+ yield
16
+
17
+ if added:
18
+ sys.path.remove(extension_path)
19
+
20
+
21
+ with fix_path():
22
+ del fix_path, contextlib
23
+ from .api import *
neutral_prompt_patcheds/lib_neutral_prompt/external_code/api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ External API for sd-webui-neutral-prompt.
3
+
4
+ Provides thin helpers that external extensions / scripts can import via:
5
+
6
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ from lib_neutral_prompt import global_state
12
+
13
+
14
+ def override_cfg_rescale(cfg_rescale: float) -> None:
15
+ """
16
+ Override the CFG rescale value for the *next* generation step only.
17
+ After the step runs the value is automatically cleared.
18
+ """
19
+ global_state.cfg_rescale_override = cfg_rescale
20
+
21
+
22
+ def get_last_cfg_rescale_factor() -> Optional[float]:
23
+ """
24
+ Return the CFG rescale factor computed during the most recent denoising
25
+ step, or None if rescaling was not active.
26
+ """
27
+ return global_state.CFGRescaleFactorSingleton.get()
neutral_prompt_patcheds/lib_neutral_prompt/global_state.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global mutable state shared across the extension.
3
+
4
+ Combines:
5
+ - is_enabled / prompt_exprs / verbose (all branches)
6
+ - cfg_rescale + cfg_rescale_override (main branch)
7
+ - batch_cond_indices (affine branch)
8
+ - CFGRescaleFactorSingleton (export_rescale_factor branch)
9
+ - protection_mode / strict_threshold (patched — base-prompt protection)
10
+ """
11
+
12
+ import threading
13
+ from typing import List, Optional, Tuple
14
+ from lib_neutral_prompt import neutral_prompt_parser
15
+
16
+
17
# ---------------------------------------------------------------------------
# Runtime state
# ---------------------------------------------------------------------------

is_enabled: bool = False   # extension on/off switch (set externally)
prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []   # parsed prompt AST for the current run
batch_cond_indices: List[List[Tuple[int, float]]] = []      # per-batch (index, weight) pairs — see denoiser hijack
cfg_rescale: float = 0.0   # CFG rescale phi (0.0 = disabled)
verbose: bool = False      # gates console warnings

# Set to a float value by XYZ-grid or external API to override cfg_rescale
# for a single generation step, then auto-cleared.
cfg_rescale_override: Optional[float] = None


def apply_and_clear_cfg_rescale_override() -> None:
    """Consume a pending one-shot cfg_rescale override, if any is set."""
    global cfg_rescale, cfg_rescale_override
    if cfg_rescale_override is None:
        return
    cfg_rescale, cfg_rescale_override = cfg_rescale_override, None
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Base-prompt protection
42
+ # ---------------------------------------------------------------------------
43
+ # Three modes:
44
+ #
45
+ # 'off' — no guard; auxiliary children can dominate freely.
46
+ # Use for experiments or intentional replacement-style prompts.
47
+ #
48
+ # 'auto' — (DEFAULT) structural guard: fires when the prompt tree has no
49
+ # top-level base AND-child (i.e. all children have conciliation
50
+ # keywords). Prevents obvious "main prompt disappears" cases.
51
+ #
52
+ # 'strict' — structural guard PLUS numerical guard: also fires when a base
53
+ # child exists but its cond_delta is too weak relative to the
54
+ # auxiliary contribution:
55
+ # norm(base_delta) / norm(aux_delta) < strict_threshold
56
+ # Use when you notice auxiliary prompts quietly dominating even
57
+ # though a base prompt is present.
58
+ #
59
+ # Both auto/strict fall back to standard A1111 CFG and print a warning when
60
+ # verbose=True in Settings → Neutral Prompt.
61
+ # ---------------------------------------------------------------------------
62
+
63
_VALID_PROTECTION_MODES = frozenset({'off', 'auto', 'strict', 'soft'})

protection_mode: str = 'auto'
strict_threshold: float = 0.10  # ratio threshold used by both strict and soft modes


def normalize_protection_mode(value: str) -> str:
    """Normalise an incoming protection-mode string.

    Whitespace is stripped and the value lowercased; anything outside the
    known set ('off', 'auto', 'strict', 'soft') falls back to 'auto'.
    Used wherever a mode arrives from the UI or an infotext paste.
    """
    candidate = str(value).strip().lower()
    return candidate if candidate in _VALID_PROTECTION_MODES else 'auto'
80
+
81
+
82
def clamp_strict_threshold(value) -> float:
    """Coerce *value* to a float clamped into [0.01, 0.50].

    Falls back to the default 0.10 when the value cannot be parsed.
    """
    try:
        ratio = float(value)
    except (TypeError, ValueError):
        return 0.10
    return min(0.50, max(0.01, ratio))
88
+
89
+
90
+
91
# ---------------------------------------------------------------------------
# Step-window activation control
# ---------------------------------------------------------------------------
# Controls which denoising steps each AND_* strategy is active during.
# All values here are read by step_utils.strategy_is_active_from_state().
#
# step_window_enabled      — master switch; False = always active (default)
# step_window_global       — StepWindow or None, applied to every strategy
# step_window_per_strategy — dict[family_name, StepWindow] or None
# step_window_use_defaults — True = fall back to step_utils.STRATEGY_DEFAULTS

step_window_enabled: bool = False
step_window_global = None  # StepWindow | None
step_window_per_strategy = None  # dict[str, StepWindow] | None (family → StepWindow)
step_window_use_defaults: bool = False
step_window_custom_raw = None  # dict[str, tuple] | None {AND_* UI key → (start,end)}
step_window_lock_after_end: bool = False  # freeze strategies after their window closes

# Per-generation per-strategy lock flags {family_name → True if locked}.
_step_lock_flags: dict = {}

# Monotonic counter bumped once per real generation; embedding this ID in
# flag entries makes flags from a previous generation effectively invisible
# even when reset_step_lock_flags() was never called.
_lock_generation_id: int = 0


def begin_new_generation() -> None:
    """Mark the start of a new generation run.

    Call exactly once per NeutralPromptScript.process() call. Bumping the
    generation ID invalidates every lock flag from the previous run, so the
    denoising loop needs no explicit lock-flag cleanup of its own.

    Deliberately NOT keyed on ``current_step == 0``: some samplers (Restart,
    DPM++ variants) revisit step 0 internally within a single denoising
    pass, whereas the process() lifecycle hook gives a clean, unambiguous
    per-generation boundary.
    """
    global _lock_generation_id, _step_lock_flags
    _lock_generation_id += 1
    _step_lock_flags = {}


def ensure_step_lock_run_initialized(run_token) -> bool:
    """Token-based alternative to begin_new_generation().

    Resets the lock state when *run_token* differs from the last seen token.
    Returns True when a reset was performed (new generation), else False.

    A convenient token is ``id(global_state.prompt_exprs)`` — the list
    object is rebuilt on every NeutralPromptScript.process() call, so the
    token changes exactly once per generation and stays stable across
    sampler restarts inside the denoising loop. Prefer
    begin_new_generation() when you control the call-site.
    """
    global _lock_generation_id, _step_lock_flags
    previous_token = getattr(ensure_step_lock_run_initialized, '_last_token', object())
    if previous_token == run_token:
        return False

    ensure_step_lock_run_initialized._last_token = run_token  # type: ignore[attr-defined]
    _lock_generation_id += 1
    _step_lock_flags = {}
    return True


def reset_step_lock_flags() -> None:
    """Flush all Lock-after-End flags for the current generation.

    Exists for tests and in-generation flushes; normal per-generation
    resets should go through begin_new_generation().
    """
    global _step_lock_flags
    _step_lock_flags = {}
174
+
175
# Current step counter — set by the hijack each denoising step.
current_step: int = 0
total_steps: int = 1


def update_step_state(step, total) -> None:
    """Record the sampler's denoising progress.

    Invoked by the sampler hijack at every step so that step-window gating
    sees accurate progress. *step* is a 0-based index (floored at 0),
    *total* is the step count (floored at 1); both accept any numeric-ish
    type and fall back to safe defaults when conversion fails.
    """
    global current_step, total_steps

    def _coerce(value, floor, fallback):
        # why: sampler hooks may hand us floats/strings/None — never crash mid-step
        try:
            return max(floor, int(value))
        except Exception:
            return fallback

    current_step = _coerce(step, 0, 0)
    total_steps = _coerce(total, 1, 1)
203
+
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Infotext helpers
207
+ # ---------------------------------------------------------------------------
208
+ # Keys used when saving/restoring settings via A1111 infotext.
209
+
210
# Keys written to / read from A1111 infotext when settings are saved or pasted.
INFOTEXT_KEY_PROTECTION_MODE = 'NP Protection Mode'
INFOTEXT_KEY_STRICT_THRESHOLD = 'NP Strict Threshold'
INFOTEXT_KEY_CFG_RESCALE = 'CFG Rescale phi'
INFOTEXT_KEY_STEP_WIN_ENABLED = 'NP Step Window Enabled'
INFOTEXT_KEY_STEP_WIN_MODE = 'NP Step Window Mode'
INFOTEXT_KEY_STEP_WIN_START = 'NP Step Window Start'
INFOTEXT_KEY_STEP_WIN_END = 'NP Step Window End'
INFOTEXT_KEY_STEP_WIN_CUSTOM = 'NP Step Window Custom'
INFOTEXT_KEY_STEP_WIN_LOCK = 'NP Step Window Lock'


def apply_infotext(fields: dict) -> None:
    """
    Apply a dict of infotext key→value pairs to global state.
    Called from the script's on_infotext_pasted handler.
    Only updates fields that are actually present in the dict.
    """
    global protection_mode, strict_threshold, cfg_rescale
    global step_window_enabled, step_window_global, step_window_use_defaults
    global step_window_per_strategy, step_window_custom_raw, step_window_lock_after_end

    if INFOTEXT_KEY_PROTECTION_MODE in fields:
        protection_mode = normalize_protection_mode(fields[INFOTEXT_KEY_PROTECTION_MODE])
    if INFOTEXT_KEY_STRICT_THRESHOLD in fields:
        strict_threshold = clamp_strict_threshold(fields[INFOTEXT_KEY_STRICT_THRESHOLD])
    if INFOTEXT_KEY_CFG_RESCALE in fields:
        try:
            cfg_rescale = float(fields[INFOTEXT_KEY_CFG_RESCALE])
        except (TypeError, ValueError):
            # Malformed value in pasted infotext — keep the current setting.
            pass

    # Step-window fields (all optional — may be absent in older infotext)
    if INFOTEXT_KEY_STEP_WIN_ENABLED in fields:
        raw = fields[INFOTEXT_KEY_STEP_WIN_ENABLED]
        step_window_enabled = str(raw).lower() in ('true', '1', 'yes')

    # NOTE(review): this reads the (possibly pre-existing) global, so the
    # mode/start/end fields are also applied when the ENABLED key is absent
    # but step windows were already on — presumably intentional; confirm.
    if step_window_enabled:
        mode = str(fields.get(INFOTEXT_KEY_STEP_WIN_MODE, 'global'))
        # start/end are normalized-progress fractions clamped to [0, 1].
        try:
            start = max(0.0, min(1.0, float(fields.get(INFOTEXT_KEY_STEP_WIN_START, 0.0))))
        except (TypeError, ValueError):
            start = 0.0
        try:
            end = max(0.0, min(1.0, float(fields.get(INFOTEXT_KEY_STEP_WIN_END, 1.0))))
        except (TypeError, ValueError):
            end = 1.0
        if start > end:
            # An inverted window is meaningless — fall back to the full range.
            start, end = 0.0, 1.0

        if mode == 'global':
            # Import here to avoid circular import at module level
            from lib_neutral_prompt.step_utils import StepWindow as _SW
            step_window_global = _SW(start, end)
            step_window_use_defaults = False
            step_window_per_strategy = None
        elif mode == 'per-strategy defaults':
            step_window_global = None
            step_window_use_defaults = True
            step_window_per_strategy = None
        elif mode == 'per-strategy custom':
            step_window_global = None
            step_window_use_defaults = False
            # Restore per-strategy dict from compact string if present
            if INFOTEXT_KEY_STEP_WIN_CUSTOM in fields:
                from lib_neutral_prompt.step_utils import (
                    deserialize_per_strategy_windows,
                    build_per_strategy_windows,
                )
                raw = deserialize_per_strategy_windows(
                    str(fields[INFOTEXT_KEY_STEP_WIN_CUSTOM])
                )
                step_window_custom_raw = raw
                step_window_per_strategy = build_per_strategy_windows(raw)
            # NOTE(review): when the CUSTOM key is absent, any previous
            # per-strategy dict is left in place — confirm this carry-over
            # is intended rather than a reset-to-None.
        else:
            # Unknown mode string — disable both window sources.
            step_window_global = None
            step_window_use_defaults = False

    # Lock after End (may be set independently of other step-window fields)
    if INFOTEXT_KEY_STEP_WIN_LOCK in fields:
        raw_lock = fields[INFOTEXT_KEY_STEP_WIN_LOCK]
        step_window_lock_after_end = str(raw_lock).lower() in ('true', '1', 'yes')
291
+
292
+
293
+ # ---------------------------------------------------------------------------
294
+ # CFG Rescale Factor Singleton (export_rescale_factor branch)
295
+ #
296
+ # Stores the last computed rescale factor so that external tools / scripts
297
+ # can read it without re-deriving it.
298
+ #
299
+ # Uses threading.local() so concurrent API workers each get their own slot.
300
+ # clear() must be called at the start of every _cfg_rescale() call so that
301
+ # get() correctly returns None when rescaling was skipped this step.
302
+ # ---------------------------------------------------------------------------
303
+
304
class CFGRescaleFactorSingleton:
    """Thread-local store for the most recently computed CFG rescale factor.

    Lifecycle per denoising step:
      1. _cfg_rescale() calls clear() at entry — value becomes None.
      2. If rescaling is active, set() stores the computed factor.
      3. External code calls get() after the step; receives the factor or None.
    """

    # One slot per thread, so concurrent API workers never see each other's value.
    _state = threading.local()

    @classmethod
    def set(cls, value: float) -> None:
        # Record the factor computed for the current step on this thread.
        cls._state.cfg_rescale_factor = value

    @classmethod
    def get(cls) -> Optional[float]:
        # getattr with a default: this thread may never have run set()/clear().
        return getattr(cls._state, 'cfg_rescale_factor', None)

    @classmethod
    def clear(cls) -> None:
        # Stores None rather than deleting, signalling "skipped this step".
        cls._state.cfg_rescale_factor = None
neutral_prompt_patcheds/lib_neutral_prompt/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import functools


class ModuleHijacker:
    """Swap module attributes for wrapped versions and restore them on demand.

    Remembers only the first-ever original of each hijacked attribute, so
    repeated hijacks still chain back to the true original on reset.
    """

    def __init__(self, module):
        self.__target = module
        self.__saved = dict()

    def hijack(self, attribute):
        """Return a decorator that replaces ``attribute`` on the module.

        The decorated function is installed via ``functools.partial`` and
        receives the pristine original as the ``original_function`` keyword.
        """
        if attribute not in self.__saved:
            self.__saved[attribute] = getattr(self.__target, attribute)

        original = self.__saved[attribute]

        def decorator(function):
            replacement = functools.partial(function, original_function=original)
            setattr(self.__target, attribute, replacement)
            return function

        return decorator

    def reset_module(self):
        """Restore every hijacked attribute and forget the saved originals."""
        for attribute, original in self.__saved.items():
            setattr(self.__target, attribute, original)

        self.__saved.clear()

    @staticmethod
    def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
        """Attach a hijacker to *module*, or return the one already attached.

        On first install, registers teardown callbacks that remove the
        hijacker attribute and undo all hijacks.
        """
        if hasattr(module, hijacker_attribute):
            return getattr(module, hijacker_attribute)

        module_hijacker = ModuleHijacker(module)
        setattr(module, hijacker_attribute, module_hijacker)
        on_uninstall(lambda: delattr(module, hijacker_attribute))
        on_uninstall(module_hijacker.reset_module)
        return module_hijacker
neutral_prompt_patcheds/lib_neutral_prompt/matryoshka_utils.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pure-Python helpers for Matryoshka (nested prompt) construction, tree rendering,
3
+ and diagnostics.
4
+
5
+ Importable without gradio or A1111 modules — all functions here have no
6
+ side effects and depend only on neutral_prompt_parser and global_state.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Builder helpers
18
+ # ---------------------------------------------------------------------------
19
+
20
_KW_DEFAULTS: Dict[str, str] = {
    'AND_PERP': 'AND_PERP',
    'AND_SALT': 'AND_SALT[5]',
    'AND_SALT_WIDE': 'AND_SALT_WIDE[1]',
    'AND_SALT_BLOB': 'AND_SALT_BLOB[5]',
    'AND_TOPK': 'AND_TOPK[0.05]',
    'AND_ALIGN': 'AND_ALIGN[4,8]',
    'AND_MASK_ALIGN': 'AND_MASK_ALIGN[4,8]',
}

BUILDER_STRATEGIES = sorted(_KW_DEFAULTS.keys())


def build_child_block(strategy: str, text: str, weight: float,
                      affine: str = '', nested: str = '') -> str:
    """Render one AND_* child block, optionally with an affine prefix and an
    inner nested block.

    Returns the child line(s) as a string with no trailing newline, or an
    empty string when both *text* and *nested* are empty/whitespace — in
    that case callers should skip the block entirely.
    """
    keyword = _KW_DEFAULTS.get(strategy, strategy)
    affine = affine.strip()
    prefix = f'{affine} ' if affine else ''
    weight_suffix = f':{weight}'
    text = text.strip()
    nested = nested.strip()

    # Nothing to add — signal the caller to skip this block entirely.
    if not (text or nested):
        return ''

    if not nested:
        return f'{prefix}{keyword} {text} {weight_suffix}'

    inner: List[str] = []
    # The text line is only emitted when text is non-empty.
    if text:
        inner.append(f' {text} {weight_suffix}')
    inner.extend(f' {line}' for line in nested.splitlines())
    body = '\n'.join(inner)
    return f'{prefix}{keyword} [\n{body}\n] {weight_suffix}'


def build_nested_prompt(base: str, children: List[Dict[str, Any]]) -> str:
    """Assemble a complete prompt from a base string plus child specs.

    Each child dict may carry: strategy (str), text (str), weight (float),
    affine (str, optional), nested (str, optional). Children whose text and
    nested block are both empty are silently skipped.
    """
    base = base.strip()
    lines = [base] if base else []
    for spec in children:
        rendered = build_child_block(
            strategy=spec.get('strategy', 'AND_PERP'),
            text=spec.get('text', ''),
            weight=float(spec.get('weight', 1.0)),
            affine=spec.get('affine', ''),
            nested=spec.get('nested', ''),
        )
        if rendered:
            lines.append(rendered)
    return '\n'.join(lines)
101
+
102
+
103
+ # ---------------------------------------------------------------------------
104
+ # Tree renderer
105
+ # ---------------------------------------------------------------------------
106
+
107
def render_prompt_node(
    node: neutral_prompt_parser.PromptExpr,
    depth: int = 0,
    prefix: str = '',
    is_last: bool = True,
) -> List[str]:
    """Render one AST node (and its subtree) as box-drawing tree lines.

    Returns one string per line, without trailing newlines.
    """
    marker = '└─ ' if is_last else '├─ '
    descendant_prefix = prefix + (' ' if is_last else '│ ')

    strategy = node.conciliation.name if node.conciliation else 'BASE'
    weight_part = f'w={node.weight:.2f}'
    params_part = ''
    if node.conciliation_params:
        rendered = ', '.join(f'{k}={v}' for k, v in node.conciliation_params.items())
        params_part = f' [{rendered}]'
    affine_part = ' ✦affine' if node.local_transform is not None else ''

    if hasattr(node, 'prompt'):
        text = node.prompt.strip()
        if len(text) > 55:
            text = text[:55] + '…'
        label = f'{marker}[{strategy}] {weight_part}{params_part}{affine_part} "{text}"'
    else:
        child_count = len(node.children) if hasattr(node, 'children') else 0
        label = f'{marker}[{strategy}] {weight_part}{params_part}{affine_part} ({child_count} children)'

    lines: List[str] = [prefix + label]

    if hasattr(node, 'children') and node.children:
        last_index = len(node.children) - 1
        for i, child in enumerate(node.children):
            lines.extend(render_prompt_node(child, depth + 1, descendant_prefix, i == last_index))

    return lines
144
+
145
+
146
def render_prompt_tree(prompt_str: str) -> str:
    """Parse *prompt_str* and return its full tree rendering as one string."""
    if not prompt_str.strip():
        return '(empty prompt — nothing to show)'

    try:
        root = neutral_prompt_parser.parse_root(prompt_str)
    except Exception as exc:
        return f'⚠ Parse error: {exc}'

    has_children = hasattr(root, 'children')
    count = len(root.children) if has_children else 0
    lines = [f'ROOT ({count} top-level segments)']
    if has_children and root.children:
        for i, child in enumerate(root.children):
            lines.extend(render_prompt_node(child, prefix='', is_last=(i == count - 1)))
    return '\n'.join(lines)
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # Diagnostics
167
+ # ---------------------------------------------------------------------------
168
+
169
# Conciliation names that take bracket parameters; used to flag prompts where
# an invalid/missing bracket param caused the defaults to be silently used.
_BRACKET_KW_NAMES = frozenset({
    'SALIENCE_MASK', 'SALIENCE_MASK_WIDE', 'SALIENCE_MASK_BLOB',
    'SEMANTIC_GUIDANCE', 'ALIGNMENT_BLEND_CUSTOM', 'ALIGNMENT_MASK_BLEND_CUSTOM',
})


def collect_prompt_diagnostics(prompt_str: str) -> Dict[str, Any]:
    """
    Parse prompt_str and collect diagnostic metadata.

    Returns a dict with keys:
        parse_ok, parse_error, max_depth, n_segments, n_leaf, n_composite,
        n_affine, n_invalid_base, strategies, ignored_hint,
        protection_would_fire, protection_mode
    """
    # Defaults describe "nothing parsed yet"; filled in below on success.
    result: Dict[str, Any] = {
        'parse_ok': False,
        'parse_error': None,
        'max_depth': 0,
        'n_segments': 0,
        'n_leaf': 0,
        'n_composite': 0,
        'n_affine': 0,
        'n_invalid_base': 0,
        'strategies': [],
        'ignored_hint': False,
        'protection_would_fire': False,
        'protection_mode': global_state.normalize_protection_mode(global_state.protection_mode),
    }

    # Empty prompt: return the defaults without reporting a parse error.
    if not prompt_str.strip():
        return result

    try:
        root = neutral_prompt_parser.parse_root(prompt_str)
    except Exception as exc:
        result['parse_error'] = str(exc)
        return result

    result['parse_ok'] = True
    strategies_seen: List[str] = []

    def _walk(node: Any, depth: int) -> None:
        # Depth-first traversal; mutates `result` and `strategies_seen` in place.
        result['max_depth'] = max(result['max_depth'], depth)
        result['n_segments'] += 1

        if node.local_transform is not None:
            result['n_affine'] += 1

        strat = node.conciliation.name if node.conciliation else 'BASE'
        # Preserve first-seen order (a set would lose it).
        if strat not in strategies_seen:
            strategies_seen.append(strat)

        # Bracket-parameter strategy with no params parsed → the user's
        # bracket values were invalid and defaults are silently in effect.
        if (node.conciliation is not None
                and node.conciliation.name in _BRACKET_KW_NAMES
                and not node.conciliation_params):
            result['ignored_hint'] = True

        # Leaf nodes carry a `prompt` attribute; composites carry `children`.
        if hasattr(node, 'prompt'):
            result['n_leaf'] += 1
        else:
            result['n_composite'] += 1
            if hasattr(node, 'children'):
                for child in node.children:
                    _walk(child, depth + 1)

    _walk(root, 0)
    result['strategies'] = strategies_seen

    # Structural preview of the base-prompt protection guard: in auto/strict
    # mode, a root whose top-level children all have conciliation keywords
    # (no plain base segment) would trigger the fallback at generation time.
    if isinstance(root, neutral_prompt_parser.CompositePrompt):
        has_base = any(c.conciliation is None for c in root.children)
        mode = result['protection_mode']
        if mode in ('auto', 'strict') and not has_base:
            result['protection_would_fire'] = True
            result['n_invalid_base'] = 1

    return result
246
+
247
+
248
def render_diagnostics(diag: Dict[str, Any]) -> str:
    """Format a diagnostics dict (from collect_prompt_diagnostics) as a report."""
    if not diag['parse_ok']:
        err = diag.get('parse_error') or 'unknown'
        return f'⚠ Parse error: {err}'

    report = [
        '── Diagnostics ──────────────────────────',
        f'Segments total : {diag["n_segments"]} '
        f'(leaf={diag["n_leaf"]}, composite={diag["n_composite"]})',
        f'Max nesting depth : {diag["max_depth"]}',
        f'Affine transforms : {diag["n_affine"]}',
        f'Strategies used : {", ".join(diag["strategies"]) or "none"}',
    ]

    if diag['ignored_hint']:
        report.append('⚠ One or more bracket params were invalid and ignored. Using defaults.')

    report.append('')
    mode = diag['protection_mode']
    report.append(f'Protection mode : {mode}')
    if mode == 'off':
        report.append(' Guard disabled — no fallback will occur')
    elif diag['protection_would_fire']:
        report.append(' ⚠ WOULD FIRE — no plain base segment at the top level')
        report.append(' Fix: add a plain base prompt before all AND_* segments')
    else:
        report.append(' OK — base segment present at top level')
        if mode == 'strict':
            thr = global_state.strict_threshold
            report.append(f' Strict ratio check ({thr:.2f}) runs at generation time')

    return '\n'.join(report)
279
+
280
+
281
def render_full_explain(prompt_str: str) -> str:
    """Produce the combined tree + diagnostics + effect-summary panel text.

    An empty prompt yields only the empty-state message, avoiding spurious
    parse-error noise.
    """
    if not prompt_str.strip():
        return '(empty prompt — nothing to show)'

    diag = collect_prompt_diagnostics(prompt_str)
    sections = [render_prompt_tree(prompt_str), '', render_diagnostics(diag)]

    effects = _render_effect_summary(diag)
    if effects:
        sections.extend(['', effects])
    return '\n'.join(sections)
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # Effect summaries
304
+ # ---------------------------------------------------------------------------
305
+
306
+ _EFFECT_DESCRIPTIONS: Dict[str, str] = {
307
+ 'BASE': 'base prompt contribution — drives the overall image',
308
+ 'PERPENDICULAR': 'perpendicular projection — reduces contradicting features',
309
+ 'SALIENCE_MASK': 'sharp saliency mask — targets the most salient latent pixels',
310
+ 'SALIENCE_MASK_WIDE': 'broad saliency mask — covers a wide region (~55% of pixels)',
311
+ 'SALIENCE_MASK_BLOB': 'blob saliency — grows seed region to a smooth connected area',
312
+ 'SEMANTIC_GUIDANCE': 'semantic top-k — applies sparse targeted changes to the strongest elements',
313
+ 'ALIGNMENT_BLEND_CUSTOM': 'alignment blend — soft injection that preserves global structure',
314
+ 'ALIGNMENT_MASK_BLEND_CUSTOM': 'alignment mask — binary-mask variant of alignment blend (stricter)',
315
+ }
316
+ # Catch legacy fixed-suffix ALIGNMENT_BLEND_D_S patterns
317
+ _ALIGNMENT_PREFIX = 'ALIGNMENT_BLEND_'
318
+ _ALIGNMENT_MASK_PREFIX = 'ALIGNMENT_MASK_BLEND_'
319
+
320
+
321
+ def _effect_for_strategy(name: str) -> str:
322
+ if name in _EFFECT_DESCRIPTIONS:
323
+ return _EFFECT_DESCRIPTIONS[name]
324
+ if name.startswith(_ALIGNMENT_MASK_PREFIX):
325
+ return _EFFECT_DESCRIPTIONS['ALIGNMENT_MASK_BLEND_CUSTOM']
326
+ if name.startswith(_ALIGNMENT_PREFIX):
327
+ return _EFFECT_DESCRIPTIONS['ALIGNMENT_BLEND_CUSTOM']
328
+ return f'strategy: {name}'
329
+
330
+
331
+ def _render_effect_summary(diag: Dict[str, Any]) -> str:
332
+ """
333
+ Build a one-line-per-strategy effect summary from diagnostics data.
334
+ Returns an empty string if the prompt failed to parse.
335
+ """
336
+ if not diag.get('parse_ok'):
337
+ return ''
338
+ strategies = diag.get('strategies', [])
339
+ if not strategies:
340
+ return ''
341
+ lines = ['── Effect summary ───────────────────────']
342
+ for s in strategies:
343
+ lines.append(f' {s:32s}→ {_effect_for_strategy(s)}')
344
+ lines.append('')
345
+ lines.append(' Note: protection verdict above is a structural preview.')
346
+ lines.append(' The strict ratio check and batch diagnostics run at generation time.')
347
+ return '\n'.join(lines)
348
+
349
+
350
+ # ---------------------------------------------------------------------------
351
+ # Matryoshka templates
352
+ # ---------------------------------------------------------------------------
353
+
354
# Ready-made nested-prompt examples surfaced in the UI.
# Each entry: display name → {'description': tooltip text, 'prompt': full prompt text}.
MATRYOSHKA_TEMPLATES: Dict[str, Dict[str, str]] = {
    'Nested local detail': {
        'description': (
            'Add fine detail inside a saliency region. '
            'SALT focuses on salient pixels; nested TOPK applies sparse corrections within that region.'
        ),
        'prompt': (
            'base subject, environment\n'
            'AND_SALT[5] [\n'
            ' subject texture :0.8\n'
            ' AND_TOPK[0.05] fine detail highlights :0.5\n'
            '] :0.8'
        ),
    },
    'Structure preserve + style inject': {
        'description': (
            'Keep global composition while injecting a style locally. '
            'ALIGN preserves structure; nested PERP reduces contradiction.'
        ),
        'prompt': (
            'base composition, environment\n'
            'AND_ALIGN[4,8] [\n'
            ' style name :0.8\n'
            ' AND_PERP style contradiction :0.5\n'
            '] :0.7'
        ),
    },
    'Sparse detail inside broad region': {
        'description': (
            'SALT_WIDE covers a broad region; nested SALT sharpens a smaller focal point inside it.'
        ),
        'prompt': (
            'base subject\n'
            'AND_SALT_WIDE[1] [\n'
            ' surface texture :0.7\n'
            ' AND_SALT[7] focal detail :0.6\n'
            '] :0.8'
        ),
    },
    'Perpendicular correction inside texture block': {
        'description': (
            'Inject texture while suppressing contradicting composition hints '
            'via nested perpendicular projection.'
        ),
        'prompt': (
            'base portrait, detailed\n'
            'AND_SALT[5] [\n'
            ' skin texture, pores :0.8\n'
            ' AND_PERP smooth featureless face :0.4\n'
            '] :0.7'
        ),
    },
    'Nested ALIGN + SALT': {
        'description': (
            'Two-pass spatial injection: ALIGN blends a structural concept, '
            'SALT sharpens a second concept only in salient regions.'
        ),
        'prompt': (
            'base subject\n'
            'AND_ALIGN[4,8] background concept :0.6\n'
            'AND_SALT[5] [\n'
            ' foreground detail :0.8\n'
            ' AND_TOPK[0.1] specular highlights :0.4\n'
            '] :0.7'
        ),
    },
    'Mirrored composition': {
        'description': (
            'Apply the same concept mirrored, blended via PERP to add symmetric detail '
            'without overwriting the composition.'
        ),
        'prompt': (
            'base subject\n'
            'SCALE[-1,1] AND_PERP mirrored version of subject :0.5'
        ),
    },
    'Deep nested concept isolation': {
        'description': (
            'Three levels: outer SALT focuses area, middle ALIGN preserves sub-structure, '
            'inner TOPK applies sparse micro-corrections.'
        ),
        'prompt': (
            'base subject\n'
            'AND_SALT[5] [\n'
            ' region concept :0.8\n'
            ' AND_ALIGN[4,8] [\n'
            ' structure keeper :0.7\n'
            ' AND_TOPK[0.05] micro detail :0.4\n'
            ' ] :0.6\n'
            '] :0.7'
        ),
    },
}
447
+
448
+
449
+ # ---------------------------------------------------------------------------
450
+ # Affine snippet builder — re-exported from affine_utils (single source of truth)
451
+ # ---------------------------------------------------------------------------
452
+ from lib_neutral_prompt.affine_utils import (
453
+ AFFINE_TRANSFORMS as _AFFINE_TRANSFORMS,
454
+ AFFINE_PRESETS as _AFFINE_PRESETS,
455
+ build_affine_snippet as _build_affine_snippet,
456
+ )
457
+
458
+
459
+ # ---------------------------------------------------------------------------
460
+ # Builder v1.2 — list-based node model
461
+ # ---------------------------------------------------------------------------
462
+
463
+ import dataclasses
464
+ import uuid
465
+
466
+
467
@dataclasses.dataclass
class BuilderNode:
    """One AND_* entry in the visual matryoshka builder tree.

    Attributes
    ----------
    node_id : str
        Unique short identifier (8 hex chars taken from a UUID4).
    strategy : str
        Key into ``_KW_DEFAULTS`` (e.g. ``'AND_SALT'``, ``'AND_TOPK'``).
    text : str
        Prompt text carried by this node.
    weight : float
        Contribution weight (typically 0.0 … 2.0).
    affine : str
        Optional affine snippet prefix such as ``'ROTATE[0.125]'``.
    children : list[BuilderNode]
        Nested nodes forming the matryoshka subtree.
    """

    node_id: str
    strategy: str
    text: str
    weight: float
    affine: str = ''
    children: List['BuilderNode'] = dataclasses.field(default_factory=list)

    @classmethod
    def make(cls, strategy: str = 'AND_SALT', text: str = '',
             weight: float = 0.8, affine: str = '') -> 'BuilderNode':
        """Build a node with a freshly generated 8-character id."""
        fresh_id = uuid.uuid4().hex[:8]
        return cls(fresh_id, strategy, text, weight, affine)
499
+
500
+
501
+ # ---------------------------------------------------------------------------
502
+ # Builder tree operations (pure — no gradio)
503
+ # ---------------------------------------------------------------------------
504
+
505
def builder_add_child(nodes: List[BuilderNode], parent_id: Optional[str] = None,
                      **kwargs) -> List[BuilderNode]:
    """Add a new child BuilderNode.

    If *parent_id* is None the new node is appended to the top-level list;
    otherwise it becomes the last child of the node with that id.  The
    extra keyword arguments are forwarded to ``BuilderNode.make``.
    Returns a new list — the input tree is never mutated.
    """
    child = BuilderNode.make(**kwargs)
    if parent_id is None:
        return [*nodes, child]

    def _attach(branch: List[BuilderNode]) -> List[BuilderNode]:
        out: List[BuilderNode] = []
        for node in branch:
            if node.node_id == parent_id:
                # Append under the matched parent; no need to recurse further here.
                out.append(dataclasses.replace(node, children=[*node.children, child]))
            else:
                out.append(dataclasses.replace(node, children=_attach(node.children)))
        return out

    return _attach(nodes)
528
+
529
+
530
+ def builder_remove_node(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
531
+ """Remove the node with *node_id* from anywhere in the tree."""
532
+ def _remove(ns: List[BuilderNode]) -> List[BuilderNode]:
533
+ return [
534
+ dataclasses.replace(n, children=_remove(n.children))
535
+ for n in ns if n.node_id != node_id
536
+ ]
537
+ return _remove(nodes)
538
+
539
+
540
+ def builder_duplicate_node(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
541
+ """
542
+ Insert a deep copy of *node_id* immediately after it in its parent list.
543
+ The duplicate gets a fresh node_id.
544
+ """
545
+ def _fresh(n: BuilderNode) -> BuilderNode:
546
+ return dataclasses.replace(
547
+ n,
548
+ node_id=uuid.uuid4().hex[:8],
549
+ children=[_fresh(c) for c in n.children],
550
+ )
551
+
552
+ def _dup(ns: List[BuilderNode]) -> List[BuilderNode]:
553
+ result = []
554
+ for n in ns:
555
+ result.append(dataclasses.replace(n, children=_dup(n.children)))
556
+ if n.node_id == node_id:
557
+ result.append(_fresh(n))
558
+ return result
559
+
560
+ return _dup(nodes)
561
+
562
+
563
+ def builder_move_up(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
564
+ """Swap *node_id* with its predecessor in its parent list."""
565
+ def _swap(ns: List[BuilderNode]) -> List[BuilderNode]:
566
+ out = []
567
+ for i, n in enumerate(ns):
568
+ if n.node_id == node_id and i > 0:
569
+ out[-1], swapped = dataclasses.replace(n, children=_swap(n.children)), out[-1]
570
+ out.append(swapped)
571
+ else:
572
+ out.append(dataclasses.replace(n, children=_swap(n.children)))
573
+ return out
574
+ return _swap(nodes)
575
+
576
+
577
+ def builder_move_down(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
578
+ """Swap *node_id* with its successor in its parent list."""
579
+ def _swap(ns: List[BuilderNode]) -> List[BuilderNode]:
580
+ out = [dataclasses.replace(n, children=_swap(n.children)) for n in ns]
581
+ for i, n in enumerate(out):
582
+ if n.node_id == node_id and i < len(out) - 1:
583
+ out[i], out[i + 1] = out[i + 1], out[i]
584
+ break
585
+ return out
586
+ return _swap(nodes)
587
+
588
+
589
+ def builder_update_node(nodes: List[BuilderNode], node_id: str,
590
+ **kwargs) -> List[BuilderNode]:
591
+ """
592
+ Update fields of *node_id* anywhere in the tree.
593
+ Only keys present in *kwargs* are changed.
594
+ """
595
+ def _update(ns: List[BuilderNode]) -> List[BuilderNode]:
596
+ return [
597
+ dataclasses.replace(
598
+ n,
599
+ children=_update(n.children),
600
+ **{k: v for k, v in kwargs.items() if k != 'children'},
601
+ ) if n.node_id == node_id
602
+ else dataclasses.replace(n, children=_update(n.children))
603
+ for n in ns
604
+ ]
605
+ return _update(nodes)
606
+
607
+
608
def serialize_builder_tree(base: str, nodes: List[BuilderNode], indent: int = 0) -> str:
    """
    Convert a base string + list of BuilderNodes into a prompt string.
    Nested children are rendered as composite bracket blocks.

    Parameters:
        base: top-level base prompt text; only emitted at indent == 0.
        nodes: builder nodes to serialise at this level.
        indent: current nesting depth (used for indentation of bracket bodies).

    Returns:
        A newline-joined prompt string; nodes with neither text nor children
        are silently skipped.
    """
    lines: List[str] = []
    if indent == 0 and base.strip():
        lines.append(base.strip())

    for node in nodes:
        # Unknown strategies fall back to the raw strategy string as keyword.
        kw = _KW_DEFAULTS.get(node.strategy, node.strategy)
        prefix = f'{node.affine.strip()} ' if node.affine.strip() else ''
        w_str = f':{node.weight}'
        # NOTE(review): indentation widths below may have been collapsed by the
        # diff view this file was recovered from — confirm against upstream.
        pad = ' ' * indent

        if node.children:
            # Composite bracket block: own text (if any) first, then each
            # child serialised recursively and re-indented one level deeper.
            inner: List[str] = []
            if node.text.strip():
                inner.append(f' {pad}{node.text.strip()} {w_str}')
            for sub in node.children:
                for sub_ln in serialize_builder_tree('', [sub], indent + 1).splitlines():
                    if sub_ln.strip():
                        inner.append(f' {pad}{sub_ln}')
            body = '\n'.join(inner)
            lines.append(f'{pad}{prefix}{kw} [\n{body}\n{pad}] {w_str}')
        else:
            if node.text.strip():
                lines.append(f'{pad}{prefix}{kw} {node.text.strip()} {w_str}')
            # Node with no text and no children: skip (same policy as build_child_block)

    return '\n'.join(lines)
640
+
641
+
642
def builder_tree_to_prompt(base: str, nodes: List[BuilderNode]) -> str:
    """Top-level entry point: serialise the full builder tree to a prompt string.

    Thin wrapper over serialize_builder_tree with the default indent of 0,
    so *base* is included as the first line when non-empty.
    """
    return serialize_builder_tree(base, nodes)
neutral_prompt_patcheds/lib_neutral_prompt/neutral_prompt_parser.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Neutral Prompt Parser
3
+ Combines:
4
+ - Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
5
+ - Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
6
+ - Fixed AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
7
+ - Bracket-syntax AND_ALIGN[D,S] / AND_MASK_ALIGN[D,S] (NEW — any valid pair)
8
+ - Configurable k for salience: AND_SALT[k] / AND_SALT_WIDE[k] / AND_SALT_BLOB[k] (NEW)
9
+
10
+ New syntax examples:
11
+ AND_SALT[5] concept :0.8 k=5.0 (more coverage than default 20)
12
+ AND_SALT_WIDE[3] concept :0.6 explicit k override
13
+ AND_SALT_BLOB[8] concept :1.0 blob seed sharpness
14
+ AND_ALIGN[4,8] style :0.5 detail=4, structure=8 (any valid pair)
15
+ AND_MASK_ALIGN[6,12] style :0.5 binary-mask variant
16
+
17
+ AND_ALIGN_4_8 still works — backward-compatible.
18
+ """
19
+
20
+ import abc
21
+ import dataclasses
22
+ import math
23
+ import re
24
+ from enum import Enum
25
+ from itertools import product
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+ import torch
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Keyword registry
33
+ # ---------------------------------------------------------------------------
34
+
35
# Core AND_* keywords mapped to their conciliation-strategy names.
_BASE_KEYWORD_MAP = {
    'AND_PERP': 'PERPENDICULAR',
    'AND_SALT': 'SALIENCE_MASK',
    'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
    'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
    'AND_TOPK': 'SEMANTIC_GUIDANCE',
    # Bracket-syntax custom alignment (any D,S pair via AND_ALIGN[D,S])
    'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
    'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
}

# Every ordered (D, S) pair with D != S and both in [2, 32] — the valid
# fixed-suffix alignment keywords for backward compatibility.
_VALID_PAIRS = [(d, s) for d, s in product(range(2, 33), repeat=2) if d != s]

_ALIGN_KEYWORD_MAP = {
    f'AND_ALIGN_{d}_{s}': f'ALIGNMENT_BLEND_{d}_{s}' for d, s in _VALID_PAIRS
}
_MASK_ALIGN_KEYWORD_MAP = {
    f'AND_MASK_ALIGN_{d}_{s}': f'ALIGNMENT_MASK_BLEND_{d}_{s}' for d, s in _VALID_PAIRS
}

# Single lookup table: keyword -> strategy name.
keyword_mapping = {**_BASE_KEYWORD_MAP, **_ALIGN_KEYWORD_MAP, **_MASK_ALIGN_KEYWORD_MAP}

# Functional enums: PromptKeyword members are keyword strings (plus plain AND);
# ConciliationStrategy members are strategy names whose values are keywords.
PromptKeyword = Enum('PromptKeyword', {'AND': 'AND', **{kw: kw for kw in keyword_mapping}})
ConciliationStrategy = Enum('ConciliationStrategy', {strategy: kw for kw, strategy in keyword_mapping.items()})

prompt_keywords = [member.value for member in PromptKeyword]
conciliation_strategies = [member.value for member in ConciliationStrategy]

# Frozen sets for O(1) membership tests in the tokenizer/parser hot paths.
_prompt_keywords_set = frozenset(prompt_keywords)
_conciliation_strategies_set = frozenset(conciliation_strategies)

# Keyword families that accept bracket parameters (see _parse_keyword_params).
_SALT_KEYWORDS = frozenset({'AND_SALT', 'AND_SALT_WIDE', 'AND_SALT_BLOB'})
_ALIGN_BRACKET_KEYWORDS = frozenset({'AND_ALIGN', 'AND_MASK_ALIGN'})
_TOPK_KEYWORDS = frozenset({'AND_TOPK'})
69
+
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Affine transform definitions (imported from affine_utils — single source of truth)
74
+ # ---------------------------------------------------------------------------
75
+ from lib_neutral_prompt.affine_utils import (
76
+ affine_transforms,
77
+ _affine_warn,
78
+ _make_rotate,
79
+ _make_slide,
80
+ _make_scale,
81
+ _make_shear,
82
+ _apply_affine,
83
+ )
84
+
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # AST node types
89
+ # ---------------------------------------------------------------------------
90
+
91
@dataclasses.dataclass
class PromptExpr(abc.ABC):
    """Abstract base for all prompt AST nodes (leaves and composites).

    Attributes:
        weight: contribution weight of this subtree within its parent.
        conciliation: strategy used to merge this subtree into its siblings,
            or None for a plain AND child.
        local_transform: optional 2x3 affine matrix for this subtree,
            as built by _parse_affine_transform; None when no affine prefix
            was given.
    """
    weight: float
    conciliation: Optional[ConciliationStrategy]
    local_transform: Optional[torch.Tensor]

    @abc.abstractmethod
    def accept(self, visitor, *args, **kwargs) -> Any:
        """Dispatch to the visitor method that matches this node type."""
        pass
100
+
101
+
102
@dataclasses.dataclass
class LeafPrompt(PromptExpr):
    """Terminal AST node holding raw prompt text.

    Attributes:
        prompt: literal prompt text of this leaf.
        conciliation_params: keyword parameters parsed from bracket syntax
            (e.g. {'k': 5.0} for AND_SALT[5]); empty dict means defaults.
    """
    prompt: str
    # conciliation_params MUST be last — it has a default, parent fields don't.
    conciliation_params: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def accept(self, visitor, *args, **kwargs):
        """Dispatch to visitor.visit_leaf_prompt."""
        return visitor.visit_leaf_prompt(self, *args, **kwargs)
110
+
111
+
112
@dataclasses.dataclass
class CompositePrompt(PromptExpr):
    """AST node grouping child expressions (bracket blocks / AND chains).

    Attributes:
        children: ordered child expressions of this composite.
        conciliation_params: keyword parameters parsed from bracket syntax;
            empty dict means defaults (mirrors LeafPrompt).
    """
    children: List[PromptExpr]
    conciliation_params: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def accept(self, visitor, *args, **kwargs):
        """Dispatch to visitor.visit_composite_prompt."""
        return visitor.visit_composite_prompt(self, *args, **kwargs)
119
+
120
+
121
class FlatSizeVisitor:
    """Count the leaf prompts reachable from an expression tree.

    Each leaf contributes 1; a composite contributes the sum over its
    children (0 when it has none).
    """

    def visit_leaf_prompt(self, that: 'LeafPrompt') -> int:
        # Every leaf is exactly one flat entry.
        return 1

    def visit_composite_prompt(self, that: 'CompositePrompt') -> int:
        if not that.children:
            return 0
        return sum(child.accept(self) for child in that.children)
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # Parser
131
+ # ---------------------------------------------------------------------------
132
+
133
def parse_root(string: str) -> CompositePrompt:
    """Parse a full prompt string into the root CompositePrompt.

    The root carries weight 1.0, no conciliation strategy and no affine
    transform; its children are the top-level AND-separated prompts.
    """
    tokens = tokenize(string)
    prompts = parse_prompts(tokens)
    return CompositePrompt(1.0, None, None, prompts)
137
+
138
+
139
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
    """Parse consecutive prompt expressions until *tokens* is exhausted.

    When *nested* is true, parsing stops — without consuming it — at a ']'
    that closes the enclosing composite bracket.  *tokens* is consumed
    in place.
    """
    prompts = [parse_prompt(tokens, first=True, nested=nested)]
    while tokens:
        if nested and tokens[0] == ']':
            break
        prompts.append(parse_prompt(tokens, first=False, nested=nested))
    return prompts
146
+
147
+
148
+ def _compose_affine(
149
+ a: Optional[torch.Tensor],
150
+ b: Optional[torch.Tensor],
151
+ ) -> Optional[torch.Tensor]:
152
+ if a is None:
153
+ return b
154
+ if b is None:
155
+ return a
156
+ a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
157
+ b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
158
+ return (b3 @ a3)[:-1]
159
+
160
+
161
def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
    """Return True when *tokens* begins with affine keyword(s) whose bracket
    arguments are immediately followed by an AND_* keyword.

    Used to decide whether an affine prefix such as 'SCALE[-1,1] AND_PERP …'
    belongs to the NEXT prompt rather than to the current leaf text.
    Pure lookahead — *tokens* is not modified.
    """
    pos = 0
    n = len(tokens)
    # Skip a possible whitespace-only token at the front.
    if pos < n and not tokens[pos].strip():
        pos += 1
    if pos >= n or tokens[pos] not in affine_transforms:
        return False
    # Step over consecutive affine keywords; only the bracket pair after the
    # last one is inspected here.
    while pos < n and tokens[pos] in affine_transforms:
        pos += 1
    if pos >= n or tokens[pos] != '[':
        return False
    pos += 1
    # Optional single argument token inside the brackets.
    if pos < n and tokens[pos] != ']':
        pos += 1
    if pos >= n or tokens[pos] != ']':
        return False
    pos += 1
    if pos < n and not tokens[pos].strip():
        pos += 1
    # The affine prefix only "leads" a new prompt when an AND_* keyword follows.
    return pos < n and tokens[pos] in _prompt_keywords_set
181
+
182
+
183
def _parse_keyword_params(tokens: List[str], keyword: str) -> Dict[str, Any]:
    """
    Consume optional bracket params immediately after a keyword token.

    Safe: only consumes [inner] when inner is purely numeric (digits, dots,
    commas, spaces), so composite-prompt brackets like AND_ALIGN[child AND other]
    are never eaten.

    Supported forms (spaces around comma are tolerated):
        AND_SALT[5]             -> {'k': 5.0}
        AND_SALT[3, 5]          -- invalid (too many parts) → warning
        AND_ALIGN[4,8]          -> {'d': 4, 's': 8}
        AND_ALIGN[4, 8]         -> {'d': 4, 's': 8}  (spaces OK)
        AND_MASK_ALIGN[6,12]    -> {'d': 6, 's': 12}

    On invalid but numeric content: consumes the brackets, emits a warning,
    returns {} so the text is not silently swallowed as plain prompt text.
    """
    # Only the parameterisable keyword families are eligible at all.
    if (keyword not in _SALT_KEYWORDS and keyword not in _ALIGN_BRACKET_KEYWORDS
            and keyword not in _TOPK_KEYWORDS):
        return {}
    # Require the exact token shape '[', <inner>, ']' right at the front.
    if len(tokens) < 3 or tokens[0] != '[' or tokens[2] != ']':
        return {}

    inner = tokens[1].strip()
    # Allow digits, dots, commas, and spaces — but nothing alphabetic.
    # This keeps composite brackets like [concept AND style] safely un-consumed.
    if not re.match(r'^[\d.,\s]+$', inner):
        return {}

    # Consume the brackets now — we own this content.
    tokens.pop(0); tokens.pop(0); tokens.pop(0)

    try:
        parts = [p.strip() for p in inner.split(',') if p.strip()]

        if keyword in _SALT_KEYWORDS:
            # AND_SALT[k]: single positive float controlling salience coverage.
            if len(parts) != 1:
                _param_warn(keyword, inner,
                            f"expected exactly one number (k), got {len(parts)} value(s)")
                return {}
            k = float(parts[0])
            if k <= 0:
                _param_warn(keyword, inner,
                            f"k must be > 0, got {k}")
                return {}
            return {'k': k}

        elif keyword in _ALIGN_BRACKET_KEYWORDS:
            # AND_ALIGN[D,S]: two distinct integers, each within [2, 32].
            if len(parts) != 2:
                _param_warn(keyword, inner,
                            f"expected exactly two numbers (D, S), got {len(parts)} value(s)")
                return {}
            d, s = int(float(parts[0])), int(float(parts[1]))
            if not (2 <= d <= 32):
                _param_warn(keyword, inner,
                            f"D={d} is out of range [2, 32]")
                return {}
            if not (2 <= s <= 32):
                _param_warn(keyword, inner,
                            f"S={s} is out of range [2, 32]")
                return {}
            if d == s:
                _param_warn(keyword, inner,
                            f"D and S must differ (both are {d}); "
                            "no structural distinction possible")
                return {}
            return {'d': d, 's': s}

        elif keyword in _TOPK_KEYWORDS:
            # AND_TOPK[t]: sparsity threshold in the half-open interval (0, 1].
            if len(parts) != 1:
                _param_warn(keyword, inner,
                            f"expected exactly one number (threshold), got {len(parts)} value(s)")
                return {}
            threshold = float(parts[0])
            if not (0.0 < threshold <= 1.0):
                _param_warn(keyword, inner,
                            f"threshold must be in (0, 1], got {threshold}")
                return {}
            return {'threshold': threshold}

    except (ValueError, IndexError) as exc:
        _param_warn(keyword, inner, f"could not parse numbers: {exc}")

    return {}
268
+
269
+
270
def _param_warn(keyword: str, raw: str, reason: str) -> None:
    """Print a one-line parse warning to stderr.

    Honours ``global_state.verbose`` when the module is importable (the
    import is lazy to avoid a circular import at module level); when
    global_state is not yet available, the warning is always printed.
    """
    try:
        from lib_neutral_prompt import global_state as _gs
    except ImportError:
        pass  # global_state not importable yet — always print
    else:
        if not _gs.verbose:
            return
    import sys
    message = (
        f"[neutral-prompt] WARNING: invalid params for {keyword}[{raw}] — {reason}. "
        "Ignoring bracket content; keyword will use its default behaviour."
    )
    print(message, file=sys.stderr)
289
+
290
+
291
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
    """Parse one prompt expression (leaf or composite) from *tokens*.

    Handles, in order: a leading affine prefix (e.g. 'SCALE[-1,1] AND_PERP'),
    the AND_* keyword, an affine snippet before bracket params, the bracket
    params themselves (AND_SALT[5] / AND_ALIGN[4,8]), a second affine pass,
    then either a composite '[…]' block or plain leaf text with an optional
    ':<weight>' suffix.  *tokens* is consumed in place.

    NOTE(review): *first* is accepted but never read in this body — presumably
    kept for call-site symmetry; confirm before removing.
    """
    leading_affine: Optional[torch.Tensor] = None
    if _looks_like_leading_affine_prompt(tokens):
        # Parse on a copy, then commit — keeps tokens intact if parsing bails.
        probe = tokens.copy()
        leading_affine = _parse_affine_transform(probe)
        tokens[:] = probe

    if tokens and tokens[0] in _prompt_keywords_set:
        prompt_type = tokens.pop(0)
    else:
        prompt_type = PromptKeyword.AND.value

    conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None

    # Parse trailing affine once before params (handles AND_PERP ROTATE[0.25] text)
    # and once after (handles AND_SALT[5] ROTATE[0.25] text).
    trailing_affine_pre = _parse_affine_transform(tokens)

    # Consume bracket params BEFORE composite-bracket check.
    # This ensures AND_ALIGN[4,8] text is not confused with
    # AND_ALIGN [child AND other] (composite form).
    conciliation_params = _parse_keyword_params(tokens, prompt_type)

    # Second trailing-affine pass: catches affine that comes AFTER [k] params.
    trailing_affine_post = _parse_affine_transform(tokens)
    trailing_affine = _compose_affine(trailing_affine_pre, trailing_affine_post)
    affine_transform = _compose_affine(leading_affine, trailing_affine)

    # Tentatively parse a composite '[…]' block on a copy; only commit when
    # it really contains more than one child, otherwise fall back to leaf text.
    tokens_copy = tokens.copy()
    if tokens_copy and tokens_copy[0] == '[':
        tokens_copy.pop(0)
        prompts = parse_prompts(tokens_copy, nested=True)
        if tokens_copy:
            assert tokens_copy.pop(0) == ']'
        if len(prompts) > 1:
            tokens[:] = tokens_copy
            weight = _parse_weight(tokens)
            return CompositePrompt(weight, conciliation, affine_transform, prompts, conciliation_params)

    prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
    return LeafPrompt(weight, conciliation, affine_transform, prompt_text, conciliation_params)
332
+
333
+
334
def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
    """Consume leaf text until a keyword, a closing ']' (when *nested*), a
    leading-affine prefix, or a terminating ':<weight>' suffix is reached.

    Returns (text, weight); weight defaults to 1.0 when no suffix is found.
    *tokens* is consumed in place.
    """
    text = ''
    depth = 0  # bracket nesting inside the leaf text itself
    weight = 1.0
    while tokens:
        tok = tokens[0]
        if tok == ']':
            if depth == 0:
                if nested:
                    break
                # Unbalanced ']' at top level falls through and joins the text.
            else:
                depth -= 1
        elif tok == '[':
            depth += 1
        elif tok == ':':
            # Only treat ':<float>' as a weight when it terminates the leaf:
            # end of tokens, a following keyword, or a closing ']' at depth 0.
            if len(tokens) >= 2 and _is_float(tokens[1].strip()):
                if len(tokens) < 3 or tokens[2] in _prompt_keywords_set or (tokens[2] == ']' and depth == 0):
                    tokens.pop(0)
                    weight = float(tokens.pop(0).strip())
                    break
        elif tok in _prompt_keywords_set:
            break
        elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
            break
        text += tokens.pop(0)
    return text, weight
360
+
361
+
362
def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
    """Parse zero or more affine snippets (e.g. ROTATE[0.25] SLIDE[1,0]) from
    the front of *tokens* into a single composed 2x3 matrix.

    Returns None when no complete snippet is present; otherwise consumes the
    parsed tokens in place and returns the composed transform.  Snippets are
    applied right-to-left onto the identity, matching the written order of
    application.
    """
    tokens_copy = tokens.copy()
    if tokens_copy and not tokens_copy[0].strip():
        tokens_copy.pop(0)

    affine_funcs = []

    while tokens_copy and tokens_copy[0] in affine_transforms:
        func = affine_transforms[tokens_copy.pop(0)]
        args: List[float] = []

        # A keyword without its '[' bracket aborts the whole snippet run.
        if not (tokens_copy and tokens_copy[0] == '['):
            break
        tokens_copy.pop(0)

        if tokens_copy and tokens_copy[0] != ']':
            if tokens_copy[0].strip():
                try:
                    # Comma-separated float arguments, e.g. SLIDE[0.5, -1].
                    args = [float(a.strip()) for a in tokens_copy.pop(0).split(',')]
                except ValueError:
                    break
            else:
                tokens_copy.pop(0)

        if not (tokens_copy and tokens_copy[0] == ']'):
            break
        tokens_copy.pop(0)

        # Bind func/args as defaults so each closure keeps its own values.
        affine_funcs.append(lambda t, f=func, a=args: f(t, *a))
        if tokens_copy and not tokens_copy[0].strip():
            tokens_copy.pop(0)
        # Commit consumed tokens after each complete snippet.
        tokens[:] = tokens_copy

    if not affine_funcs:
        return None

    # Start from the 2x3 identity and fold snippets in reverse.
    transform = torch.eye(3, dtype=torch.float32)[:-1]
    for fn in reversed(affine_funcs):
        transform = fn(transform)
    return transform
402
+
403
+
404
def _parse_weight(tokens: List[str]) -> float:
    """Consume a leading ':<float>' pair from *tokens* and return the weight.

    Leaves *tokens* untouched and returns the default 1.0 when no weight
    suffix is present at the front.
    """
    has_weight = len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1])
    if not has_weight:
        return 1.0
    tokens.pop(0)  # drop the ':'
    return float(tokens.pop(0))
409
+
410
+
411
def tokenize(s: str) -> List[str]:
    """Split *s* into parser tokens: brackets, ':', AND_* keywords, affine
    keywords, and the text runs between them.

    Whitespace-only fragments are dropped; keyword alternatives are ordered
    longest-first so fixed-suffix forms win over their bare prefixes.
    """
    affine_kw_pattern = '|'.join(
        rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
    )
    and_keywords = (
        r'AND_MASK_ALIGN_\d+_\d+',  # fixed-suffix (longest first)
        r'AND_ALIGN_\d+_\d+',
        r'AND_MASK_ALIGN',          # bare bracket-syntax keyword
        r'AND_ALIGN',
        r'AND_PERP',
        r'AND_SALT_WIDE',
        r'AND_SALT_BLOB',
        r'AND_SALT',
        r'AND_TOPK',
        r'AND',
    )
    keyword_pattern = '|'.join(rf'(?<!\w){kw}(?!\w)' for kw in and_keywords)
    pieces = re.split(rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})', s)
    return [piece for piece in pieces if piece.strip()]
428
+
429
+
430
def _is_float(s: str) -> bool:
    """Return True when *s* parses as a Python float (whitespace tolerated)."""
    try:
        float(s)
    except ValueError:
        return False
    return True
436
+
437
+
438
if __name__ == '__main__':
    # Manual smoke test: exercise both the legacy (fixed-suffix) and the new
    # bracket-parameter syntaxes, plus the composite-not-eaten regression case.
    cases = [
        ('original AND_SALT', 'hello AND_SALT concept :1.0'),
        ('original AND_ALIGN_4_8', 'hello AND_ALIGN_4_8 watercolor :0.5'),
        ('new AND_SALT[5]', 'hello AND_SALT[5] concept :1.0'),
        ('new AND_SALT_WIDE[3]', 'hello AND_SALT_WIDE[3] concept :0.8'),
        ('new AND_SALT_BLOB[8]', 'hello AND_SALT_BLOB[8] concept :1.0'),
        ('new AND_ALIGN[4,8]', 'hello AND_ALIGN[4,8] watercolor :0.5'),
        ('new AND_MASK_ALIGN[6,12]', 'hello AND_MASK_ALIGN[6,12] structure :0.7'),
        ('composite not eaten', 'hello AND_ALIGN[arst AND defg :2.0] :0.5'),
    ]
    for label, c in cases:
        res = parse_root(c)
        # The last child is the keyword-bearing prompt under test.
        child = res.children[-1]
        print(f'[{label}]')
        print(f' conciliation={child.conciliation} params={child.conciliation_params}')
neutral_prompt_patcheds/lib_neutral_prompt/prompt_parser_hijack.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import logging
3
+ from typing import List
4
+
5
+ from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
6
+ from modules import script_callbacks, prompt_parser
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Fix: prompt_parser_fixed escapes standalone '&' as a multicond separator.
10
+ # neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
11
+ # So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
12
+ # and the patched prompt_parser splits that into 3 multicond branches instead of 1,
13
+ # which shifts all batch_cond_indices and breaks PERP/SALT/affine.
14
+ #
15
+ # Solution: escape any standalone '&' inside leaf text with '\&' before handing
16
+ # the string to the webui parser. The patched prompt_parser correctly unescapes
17
+ # '\&' back to '&' during conditioning, so the model sees the original text.
18
+ # ---------------------------------------------------------------------------
19
+
20
# Matches a standalone '&': not escaped and surrounded by whitespace/string ends.
_STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')

# ── debug logging ──────────────────────────────────────────────────────────
# Set env variable NP_DEBUG=1 to enable verbose output in the A1111 console.
# Example (Linux/Mac): NP_DEBUG=1 python launch.py
# Example (Windows cmd): set NP_DEBUG=1 && python launch.py
import os as _os
_DEBUG = _os.getenv("NP_DEBUG", "0").strip() not in ("0", "", "false", "no", "off")
_log = logging.getLogger("neutral_prompt.hijack")
# ──────────────────────────────────────────────────────────────────────────


def _escape_leaf_ampersands(text: str) -> str:
    """Escape standalone '&' so patched prompt_parser doesn't split a single
    Neutral Prompt leaf into extra multicond branches.

        "cat & dog" -> "cat \\& dog"
        "R&D"       -> "R&D"   (unchanged — not standalone)
        "\\&"       -> "\\&"   (unchanged — already escaped)
    """
    needs_escaping = bool(text) and '&' in text
    return _STANDALONE_AMP.sub(r'\\&', text) if needs_escaping else text
+ return _STANDALONE_AMP.sub(r'\\&', text)
42
+
43
+
44
# Install (or fetch the already-installed) hijacker on the webui's
# prompt_parser module; script_callbacks.on_script_unloaded restores the
# original functions when the extension is unloaded.
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=prompt_parser,
    hijacker_attribute='__neutral_prompt_hijacker',
    on_uninstall=script_callbacks.on_script_unloaded,
)
+ )
49
+
50
+
51
@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
def get_multicond_prompt_list_hijack(prompts, original_function):
    """Parse prompts with the neutral-prompt grammar, stash the ASTs in
    global_state, and pass transpiled plain-AND strings to the original
    webui implementation.  Pass-through when the extension is disabled."""
    if not global_state.is_enabled:
        return original_function(prompts)

    global_state.prompt_exprs = parse_prompts(prompts)
    webui_prompts = transpile_exprs(global_state.prompt_exprs)

    # ── debug: transpiled strings ──────────────────────────────────────────
    if _DEBUG:
        for i, (orig, transp) in enumerate(zip(prompts, webui_prompts)):
            _log.warning(
                "[NP_DEBUG] prompt[%d]\n"
                " original : %r\n"
                " transpiled : %r\n"
                " branches : %d",
                i, orig, transp,
                # count multicond splits the same way prompt_parser_fixed does
                len(re.split(r'(?:\bAND\b|(?<!\S)&(?!\S))(?!_PERP|_SALT|_TOPK)', transp))
            )
    # ──────────────────────────────────────────────────────────────────────

    # Older webui versions may lack SdConditioning; getattr degrades the
    # isinstance check to never-matches in that case.
    if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
        webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)

    result = original_function(webui_prompts)

    # ── debug: multicond parts ─────────────────────────────────────────────
    if _DEBUG:
        conds_list, prompt_flat_list, prompt_indexes = result
        _log.warning(
            "[NP_DEBUG] get_multicond_prompt_list result\n"
            " prompt_flat_list : %r\n"
            " prompt_indexes : %r\n"
            " conds_list : %r",
            list(prompt_flat_list),
            dict(prompt_indexes),
            conds_list,
        )
    # ──────────────────────────────────────────────────────────────────────

    return result
+
94
+
95
+ def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
96
+ exprs = []
97
+ for prompt in prompts:
98
+ expr = neutral_prompt_parser.parse_root(prompt)
99
+ exprs.append(expr)
100
+ return exprs
101
+
102
+
103
+ def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr):
104
+ webui_prompts = []
105
+ for expr in exprs:
106
+ webui_prompts.append(expr.accept(WebuiPromptVisitor()))
107
+ return webui_prompts
108
+
109
+
110
+ class WebuiPromptVisitor:
111
+ def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
112
+ prompt = _escape_leaf_ampersands(that.prompt)
113
+ return f'{prompt} :{that.weight}'
114
+
115
+ def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
116
+ return ' AND '.join(child.accept(self) for child in that.children)
117
+
118
+
119
+ @prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
120
+ def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
121
+ """Store batch_cond_indices for the pre-noise affine hook (affine branch)."""
122
+ res = original_function(*args, **kwargs)
123
+ global_state.batch_cond_indices = res[0]
124
+
125
+ # ── debug: batch_cond_indices ──────────────────────────────────────────
126
+ if _DEBUG:
127
+ _log.warning(
128
+ "[NP_DEBUG] reconstruct_multicond_batch\n"
129
+ " batch_cond_indices : %r",
130
+ res[0],
131
+ )
132
+ # ──────────────────────────────────────────────────────────────────────
133
+
134
+ return res
neutral_prompt_patcheds/lib_neutral_prompt/protection_utils.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base-prompt protection subsystem for sd-webui-neutral-prompt.
3
+
4
+ Extracted from cfg_denoiser_hijack.py to form a standalone, testable module.
5
+
6
+ Protection modes
7
+ ----------------
8
+ off — no guard at all; auxiliary children can dominate freely.
9
+ auto — structural guard: fires when the prompt tree lacks a top-level
10
+ base AND-child (every child has a conciliation keyword).
11
+ strict — auto PLUS numerical: also fires when norm(base_delta)/norm(aux_delta)
12
+ falls below *strict_threshold*. Falls back to standard A1111 CFG.
13
+ soft — auto structural guard + numerical attenuation: instead of a hard
14
+ fallback, attenuates the auxiliary delta so that
15
+ norm(base_delta) / norm(attenuated_aux_delta) ≈ soft_floor.
16
+ The image retains auxiliary influence but cannot overpower the base.
17
+
18
+ All modes except 'off' print a diagnostic message when global_state.verbose=True.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import sys
24
+ from typing import List, Optional, Tuple
25
+
26
+ import torch
27
+
28
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Structural checks (per-prompt, O(n) in children)
33
+ # ---------------------------------------------------------------------------
34
+
35
+ def prompt_has_valid_base_path(
36
+ prompt: neutral_prompt_parser.PromptExpr,
37
+ ) -> bool:
38
+ """
39
+ Per-prompt structural check.
40
+
41
+ Returns True iff *prompt* is a CompositePrompt with at least one top-level
42
+ child that carries no conciliation keyword (i.e. a plain base segment).
43
+ """
44
+ return (
45
+ isinstance(prompt, neutral_prompt_parser.CompositePrompt)
46
+ and any(c.conciliation is None for c in prompt.children)
47
+ )
48
+
49
+
50
+ def get_invalid_base_prompt_indices(
51
+ prompt_exprs: List[neutral_prompt_parser.PromptExpr],
52
+ ) -> List[int]:
53
+ """
54
+ Return the 0-based indices of any prompts in the batch that lack a
55
+ valid base AND-path. An empty return value means every prompt is OK.
56
+
57
+ Per-prompt rather than any(...): a healthy neighbour must not silently
58
+ legalise a broken prompt in the same batched request.
59
+ """
60
+ return [
61
+ i for i, p in enumerate(prompt_exprs)
62
+ if not prompt_has_valid_base_path(p)
63
+ ]
64
+
65
+
66
+ def has_valid_base_path(
67
+ prompt_exprs: List[neutral_prompt_parser.PromptExpr],
68
+ ) -> bool:
69
+ """
70
+ Batch-level alias: True iff the batch is non-empty AND every prompt has a
71
+ valid base path. An empty batch returns False (nothing to protect).
72
+ """
73
+ return bool(prompt_exprs) and len(get_invalid_base_prompt_indices(prompt_exprs)) == 0
74
+
75
+
76
+ # Old private names kept for backward-compat (hijack still references them via import)
77
+ _prompt_has_valid_base_path = prompt_has_valid_base_path
78
+ _get_invalid_base_prompt_indices = get_invalid_base_prompt_indices
79
+ _has_valid_base_path = has_valid_base_path
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Numerical helpers
84
+ # ---------------------------------------------------------------------------
85
+
86
+ def safe_norm(tensor: torch.Tensor) -> float:
87
+ """Return float L2-norm; 0.0 on any error (empty tensor, NaN, etc.)."""
88
+ try:
89
+ v = torch.linalg.norm(tensor.float()).item()
90
+ return v if v == v else 0.0 # NaN guard
91
+ except Exception:
92
+ return 0.0
93
+
94
+ _safe_norm = safe_norm # alias
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Soft-mode attenuation
99
+ # ---------------------------------------------------------------------------
100
+
101
+ def compute_soft_attenuation(
102
+ base_norm: float,
103
+ aux_norm: float,
104
+ threshold: float,
105
+ ) -> float:
106
+ """
107
+ Compute a [0, 1] scale factor for the auxiliary delta so that, after
108
+ scaling, the base/aux ratio reaches at least *threshold*.
109
+
110
+ When the ratio is already ≥ threshold returns 1.0 (no change).
111
+ When aux_norm ≈ 0 returns 1.0 (nothing to attenuate).
112
+
113
+ Formula: factor = min(1, base_norm / (aux_norm * threshold))
114
+ This ensures norm(base) / norm(factor * aux) ≈ threshold.
115
+ """
116
+ if aux_norm < 1e-8 or base_norm < 0:
117
+ return 1.0
118
+ ratio = base_norm / aux_norm
119
+ if ratio >= threshold:
120
+ return 1.0
121
+ # factor so that base / (factor * aux) = threshold ↔ factor = base / (aux * threshold)
122
+ return max(0.0, base_norm / (aux_norm * threshold))
123
+
124
+
125
+ # ---------------------------------------------------------------------------
126
+ # Primary policy function
127
+ # ---------------------------------------------------------------------------
128
+
129
+ def should_fallback_for_protection(
130
+ x_out: torch.Tensor,
131
+ batch_cond_indices: List[List],
132
+ text_uncond: torch.Tensor,
133
+ # These visitors are injected to avoid circular imports
134
+ CondDeltaVisitor, # noqa: N803
135
+ AuxCondDeltaVisitor, # noqa: N803
136
+ DenoiseArgs, # noqa: N803
137
+ ) -> Tuple[bool, Optional[str]]:
138
+ """
139
+ Decide whether to fall back to standard CFG.
140
+
141
+ Returns ``(should_fallback: bool, reason: str | None)``.
142
+
143
+ Modes:
144
+ 'off' → (False, None) — never fires.
145
+ 'auto' → structural check only.
146
+ 'strict' → structural check + numerical ratio check → hard fallback.
147
+ 'soft' → structural check only → (False, None); attenuation is
148
+ handled separately by apply_soft_attenuation_if_needed().
149
+ """
150
+ mode = global_state.normalize_protection_mode(global_state.protection_mode)
151
+
152
+ if mode == 'off':
153
+ return False, None
154
+
155
+ # --- Structural check (auto / strict / soft) — per-prompt ---
156
+ invalid_indices = get_invalid_base_prompt_indices(global_state.prompt_exprs)
157
+ if invalid_indices:
158
+ reason = (
159
+ f'no valid base AND-path for prompt index/indices {invalid_indices} '
160
+ f'(mode={mode}). '
161
+ 'Those segments have only conciliation keywords (AND_PERP / AND_SALT / ...). '
162
+ 'Add a plain base prompt before AND_* segments, '
163
+ 'or set protection to "off" to allow this intentionally.'
164
+ )
165
+ return True, reason
166
+
167
+ # Soft mode only does structural check; numerical handling is done in-place.
168
+ if mode in ('auto', 'soft'):
169
+ return False, None
170
+
171
+ # --- Numerical check (strict only) ---
172
+ threshold = global_state.clamp_strict_threshold(global_state.strict_threshold)
173
+ uncond = x_out[-text_uncond.shape[0]:]
174
+
175
+ for batch_i, (prompt, cond_indices) in enumerate(
176
+ zip(global_state.prompt_exprs, batch_cond_indices)):
177
+ args = DenoiseArgs(x_out, uncond[batch_i], cond_indices)
178
+ base_delta = prompt.accept(CondDeltaVisitor(), args, 0)
179
+ aux_delta = prompt.accept(AuxCondDeltaVisitor(), args, base_delta, 0)
180
+
181
+ base_norm = safe_norm(base_delta)
182
+ aux_norm = safe_norm(aux_delta)
183
+
184
+ if aux_norm < 1e-8:
185
+ continue
186
+
187
+ ratio = base_norm / aux_norm
188
+ if ratio < threshold:
189
+ return (
190
+ True,
191
+ f'base/aux ratio {ratio:.4f} < threshold {threshold:.4f} '
192
+ f'(prompt #{batch_i}, mode=strict). '
193
+ 'Base contribution is too weak relative to auxiliary. '
194
+ 'Lower the strict threshold or switch to "auto" to allow this.',
195
+ )
196
+
197
+ return False, None
198
+
199
+
200
+ def apply_soft_attenuation_if_needed(
201
+ aux_delta: torch.Tensor,
202
+ base_delta: torch.Tensor,
203
+ prompt_index: int,
204
+ ) -> Tuple[torch.Tensor, Optional[str]]:
205
+ """
206
+ For 'soft' mode: attenuate *aux_delta* in-place (returns new tensor) so
207
+ that norm(base)/norm(aux) ≥ soft_threshold.
208
+
209
+ Returns ``(possibly_attenuated_aux_delta, diagnostic_note | None)``.
210
+ Called per-prompt inside the combine_denoised loop.
211
+ """
212
+ mode = global_state.normalize_protection_mode(global_state.protection_mode)
213
+ if mode != 'soft':
214
+ return aux_delta, None
215
+
216
+ threshold = global_state.clamp_strict_threshold(global_state.strict_threshold)
217
+ base_norm = safe_norm(base_delta)
218
+ aux_norm = safe_norm(aux_delta)
219
+ factor = compute_soft_attenuation(base_norm, aux_norm, threshold)
220
+
221
+ if factor >= 1.0:
222
+ return aux_delta, None
223
+
224
+ note = (
225
+ f'[soft protection] prompt #{prompt_index}: '
226
+ f'base/aux ratio {base_norm / aux_norm:.4f} < {threshold:.4f}; '
227
+ f'attenuating aux × {factor:.3f}'
228
+ )
229
+ if global_state.verbose:
230
+ print(f'[neutral-prompt] {note}', file=sys.stderr)
231
+
232
+ return aux_delta * factor, note
233
+
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # Diagnostics helpers (used by matryoshka_utils explain panel)
237
+ # ---------------------------------------------------------------------------
238
+
239
+ def protection_verdict(
240
+ prompt_str: str,
241
+ ) -> Tuple[str, str]:
242
+ """
243
+ Structural-only verdict for the debug/explain panel.
244
+
245
+ Returns (status, message) where status is 'ok', 'fire', or 'off'.
246
+ Does not simulate runtime ratio checks or batch state.
247
+ """
248
+ from lib_neutral_prompt import neutral_prompt_parser as _p
249
+ mode = global_state.normalize_protection_mode(global_state.protection_mode)
250
+
251
+ if mode == 'off':
252
+ return 'off', 'Guard disabled — no fallback will occur'
253
+
254
+ try:
255
+ root = _p.parse_root(prompt_str)
256
+ except Exception:
257
+ return 'fire', 'Parse failed — cannot determine base-path'
258
+
259
+ if isinstance(root, _p.CompositePrompt):
260
+ has_base = any(c.conciliation is None for c in root.children)
261
+ else:
262
+ has_base = root.conciliation is None
263
+
264
+ if not has_base:
265
+ return 'fire', 'WOULD FIRE — no plain base segment at the top level'
266
+
267
+ suffix = ''
268
+ if mode == 'strict':
269
+ thr = global_state.clamp_strict_threshold(global_state.strict_threshold)
270
+ suffix = f'\n Strict ratio check ({thr:.2f}) runs at generation time'
271
+ elif mode == 'soft':
272
+ thr = global_state.clamp_strict_threshold(global_state.strict_threshold)
273
+ suffix = (
274
+ f'\n Soft attenuation active: aux will be scaled if '
275
+ f'base/aux ratio < {thr:.2f}'
276
+ )
277
+ return 'ok', f'OK — base segment present{suffix}'
neutral_prompt_patcheds/lib_neutral_prompt/step_utils.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step-window activation control for sd-webui-neutral-prompt.
3
+
4
+ Allows AND_* strategies to be active only during a sub-range of the denoising
5
+ process, specified as a normalised progress window [start, end] ∈ [0, 1].
6
+
7
+ Typical use:
8
+ - AND_ALIGN / AND_PERP : early steps (0.0–0.5) for structure/composition
9
+ - AND_SALT / AND_TOPK : late steps (0.5–1.0) for texture/detail
10
+
11
+ Three levels of control (coarse to fine):
12
+ 1. Global window — applies to ALL AND_* strategies.
13
+ 2. Per-strategy window — overrides global for a specific strategy family.
14
+ 3. Per-node window — (future) per-node override inside matryoshka trees.
15
+
16
+ This module is intentionally pure (no torch, no gradio) so it can be
17
+ imported in both UI and runtime without side effects.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import dataclasses
23
+ from typing import Dict, Optional, Tuple
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Data model
28
+ # ---------------------------------------------------------------------------
29
+
30
+ @dataclasses.dataclass(frozen=True)
31
+ class StepWindow:
32
+ """A normalised [start, end] activation window.
33
+
34
+ start and end are clamped to [0, 1] with start ≤ end.
35
+ An all-active window is StepWindow(0.0, 1.0).
36
+ """
37
+ start: float = 0.0
38
+ end: float = 1.0
39
+
40
+ def __post_init__(self) -> None:
41
+ # Use object.__setattr__ because the dataclass is frozen
42
+ object.__setattr__(self, 'start', max(0.0, min(1.0, float(self.start))))
43
+ object.__setattr__(self, 'end', max(0.0, min(1.0, float(self.end))))
44
+ if self.start > self.end:
45
+ object.__setattr__(self, 'end', self.start)
46
+
47
+ def is_active_at(self, progress: float) -> bool:
48
+ """Return True if *progress* ∈ [start, end] (inclusive)."""
49
+ return self.start <= progress <= self.end
50
+
51
+ def __repr__(self) -> str:
52
+ return f'StepWindow({self.start:.2f}–{self.end:.2f})'
53
+
54
+
55
+ ALWAYS_ACTIVE = StepWindow(0.0, 1.0)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Strategy-family presets
60
+ # ---------------------------------------------------------------------------
61
+ # These are sensible defaults reflecting typical SD denoising behaviour:
62
+ # - Global structure/composition → early steps
63
+ # - Local texture/detail → late steps
64
+ # Users can override any entry via the UI.
65
+
66
+ STRATEGY_DEFAULTS: Dict[str, StepWindow] = {
67
+ 'PERPENDICULAR': StepWindow(0.0, 0.5), # composition correction
68
+ 'SALIENCE_MASK': StepWindow(0.4, 1.0), # local texture detail
69
+ 'SALIENCE_MASK_WIDE': StepWindow(0.3, 1.0), # broad region texture
70
+ 'SALIENCE_MASK_BLOB': StepWindow(0.4, 1.0), # blob region
71
+ 'SEMANTIC_GUIDANCE': StepWindow(0.4, 1.0), # sparse semantic detail
72
+ 'ALIGNMENT_BLEND_CUSTOM': StepWindow(0.0, 0.6), # alignment preserves structure
73
+ 'ALIGNMENT_MASK_BLEND_CUSTOM': StepWindow(0.0, 0.6),
74
+ }
75
+
76
+
77
+ def _strategy_family(strategy_name: str) -> str:
78
+ """
79
+ Map a concrete ConciliationStrategy name (including legacy fixed-suffix
80
+ variants like 'ALIGNMENT_BLEND_4_8') to a canonical family key.
81
+ """
82
+ if strategy_name.startswith('ALIGNMENT_MASK_BLEND_'):
83
+ return 'ALIGNMENT_MASK_BLEND_CUSTOM'
84
+ if strategy_name.startswith('ALIGNMENT_BLEND_'):
85
+ return 'ALIGNMENT_BLEND_CUSTOM'
86
+ return strategy_name
87
+
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # Progress normalisation
91
+ # ---------------------------------------------------------------------------
92
+
93
+ def normalize_progress(current_step: int, total_steps: int) -> float:
94
+ """
95
+ Convert an absolute step index to a normalised progress value in [0, 1].
96
+
97
+ The first step (index 0) maps to 0.0 and the last step (index
98
+ total_steps - 1) maps to 1.0. Edge-cases:
99
+
100
+ total_steps = 1 → always returns 0.0
101
+ step 19 of 20 → 1.00
102
+ """
103
+ denom = max(total_steps - 1, 1)
104
+ return max(0.0, min(1.0, current_step / denom))
105
+
106
+
107
+ def normalize_step_range(start, end) -> 'Tuple[float, float]':
108
+ """
109
+ Parse and clamp a (start, end) pair to [0, 1].
110
+
111
+ If start > end after clamping, returns (0.0, 1.0) — full range fallback.
112
+ Accepts floats, ints, or stringified numbers.
113
+ """
114
+ try:
115
+ s = max(0.0, min(1.0, float(start)))
116
+ except (TypeError, ValueError):
117
+ s = 0.0
118
+ try:
119
+ e = max(0.0, min(1.0, float(end)))
120
+ except (TypeError, ValueError):
121
+ e = 1.0
122
+ if s > e:
123
+ return 0.0, 1.0
124
+ return s, e
125
+
126
+
127
+ # Canonical UI names for each per-strategy group (user-facing labels)
128
+ STRATEGY_UI_KEYS = [
129
+ 'AND_PERP',
130
+ 'AND_SALT',
131
+ 'AND_SALT_WIDE',
132
+ 'AND_SALT_BLOB',
133
+ 'AND_TOPK',
134
+ 'AND_ALIGN',
135
+ 'AND_MASK_ALIGN',
136
+ ]
137
+
138
+ # Map UI key → canonical family name used in strategy_is_active
139
+ _UI_KEY_TO_FAMILY: Dict[str, str] = {
140
+ 'AND_PERP': 'PERPENDICULAR',
141
+ 'AND_SALT': 'SALIENCE_MASK',
142
+ 'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
143
+ 'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
144
+ 'AND_TOPK': 'SEMANTIC_GUIDANCE',
145
+ 'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
146
+ 'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
147
+ }
148
+
149
+ _FAMILY_TO_UI_KEY: Dict[str, str] = {v: k for k, v in _UI_KEY_TO_FAMILY.items()}
150
+
151
+
152
+ def ui_key_to_family(ui_key: str) -> str:
153
+ """Convert an AND_* UI key to its canonical family name."""
154
+ return _UI_KEY_TO_FAMILY.get(ui_key, ui_key)
155
+
156
+
157
+ def family_to_ui_key(family: str) -> str:
158
+ """Convert a family name back to its AND_* UI key."""
159
+ return _FAMILY_TO_UI_KEY.get(family, family)
160
+
161
+
162
+ def build_per_strategy_windows(
163
+ custom_dict: Dict[str, 'Tuple[float, float]'],
164
+ ) -> Dict[str, StepWindow]:
165
+ """
166
+ Convert a raw {AND_* UI key → (start, end)} dict to the
167
+ {family_name → StepWindow} format expected by strategy_is_active.
168
+
169
+ Entries with invalid ranges are repaired by normalize_step_range.
170
+ Unrecognised keys are silently skipped.
171
+ """
172
+ result: Dict[str, StepWindow] = {}
173
+ for ui_key, (start, end) in custom_dict.items():
174
+ if ui_key not in _UI_KEY_TO_FAMILY:
175
+ continue # silently skip unrecognised keys
176
+ family = ui_key_to_family(ui_key)
177
+ s, e = normalize_step_range(start, end)
178
+ result[family] = StepWindow(s, e)
179
+ return result
180
+
181
+
182
+ def default_custom_windows() -> Dict[str, 'Tuple[float, float]']:
183
+ """
184
+ Return the STRATEGY_DEFAULTS as a {AND_* UI key → (start, end)} dict —
185
+ the natural starting point for the per-strategy custom editor.
186
+ """
187
+ out: Dict[str, 'Tuple[float, float]'] = {}
188
+ for family, win in STRATEGY_DEFAULTS.items():
189
+ ui_key = _FAMILY_TO_UI_KEY.get(family)
190
+ if ui_key:
191
+ out[ui_key] = (win.start, win.end)
192
+ return out
193
+
194
+
195
+ def serialize_per_strategy_windows(
196
+ custom_dict: Dict[str, 'Tuple[float, float]'],
197
+ ) -> str:
198
+ """
199
+ Serialise to a compact string for generation params / infotext.
200
+
201
+ Format: ``AND_PERP:0.00-0.50,AND_SALT:0.40-1.00,...``
202
+ """
203
+ parts = []
204
+ for ui_key in STRATEGY_UI_KEYS:
205
+ if ui_key in custom_dict:
206
+ s, e = custom_dict[ui_key]
207
+ parts.append(f'{ui_key}:{s:.4f}-{e:.4f}')
208
+ return ','.join(parts)
209
+
210
+
211
+ def deserialize_per_strategy_windows(
212
+ raw: str,
213
+ ) -> Dict[str, 'Tuple[float, float]']:
214
+ """
215
+ Parse the compact string back to {AND_* UI key → (start, end)}.
216
+ Tolerates malformed entries (skips them silently).
217
+ """
218
+ result: Dict[str, 'Tuple[float, float]'] = {}
219
+ for token in raw.split(','):
220
+ token = token.strip()
221
+ if not token:
222
+ continue
223
+ if ':' not in token or '-' not in token:
224
+ continue
225
+ try:
226
+ ui_key, range_part = token.split(':', 1)
227
+ ui_key = ui_key.strip()
228
+ start_str, end_str = range_part.split('-', 1)
229
+ s, e = normalize_step_range(start_str.strip(), end_str.strip())
230
+ if ui_key in _UI_KEY_TO_FAMILY:
231
+ result[ui_key] = (s, e)
232
+ except Exception:
233
+ continue
234
+ return result
235
+
236
+
237
+
238
+ """
239
+ Convert an absolute step counter to a normalised progress value [0, 1].
240
+
241
+ Examples:
242
+ step 0 of 20 → 0.00
243
+ step 10 of 20 → 0.526 (10/19)
244
+ step 19 of 20 → 1.00
245
+ """
246
+ denom = max(total_steps - 1, 1)
247
+ return max(0.0, min(1.0, current_step / denom))
248
+
249
+
250
+ # ---------------------------------------------------------------------------
251
+ # Activation decision
252
+ # ---------------------------------------------------------------------------
253
+
254
+ def strategy_is_active(
255
+ strategy_name: str,
256
+ progress: float,
257
+ global_window: Optional[StepWindow] = None,
258
+ per_strategy_windows: Optional[Dict[str, StepWindow]] = None,
259
+ use_defaults: bool = False,
260
+ ) -> bool:
261
+ """
262
+ Return True if *strategy_name* should be applied at the given *progress*.
263
+
264
+ Priority order (highest wins):
265
+ 1. Per-strategy window (if provided and contains this strategy's family).
266
+ 2. Global window (if provided).
267
+ 3. STRATEGY_DEFAULTS (if use_defaults=True).
268
+ 4. ALWAYS_ACTIVE otherwise.
269
+
270
+ Parameters
271
+ ----------
272
+ strategy_name :
273
+ ConciliationStrategy.name or a raw strategy string.
274
+ progress :
275
+ Normalised step progress in [0, 1].
276
+ global_window :
277
+ A single window applied to every strategy (unless overridden).
278
+ per_strategy_windows :
279
+ Dict mapping strategy family name → StepWindow. Overrides global.
280
+ use_defaults :
281
+ If True, fall back to STRATEGY_DEFAULTS when no explicit window is set.
282
+ """
283
+ family = _strategy_family(strategy_name)
284
+
285
+ # Per-strategy override
286
+ if per_strategy_windows and family in per_strategy_windows:
287
+ return per_strategy_windows[family].is_active_at(progress)
288
+
289
+ # Global override
290
+ if global_window is not None:
291
+ return global_window.is_active_at(progress)
292
+
293
+ # Preset defaults
294
+ if use_defaults and family in STRATEGY_DEFAULTS:
295
+ return STRATEGY_DEFAULTS[family].is_active_at(progress)
296
+
297
+ return True # ALWAYS_ACTIVE
298
+
299
+
300
+ # ---------------------------------------------------------------------------
301
+ # global_state integration helpers
302
+ # ---------------------------------------------------------------------------
303
+
304
+ def get_global_window() -> Optional[StepWindow]:
305
+ """Read the global step window from global_state (may be None = disabled)."""
306
+ try:
307
+ from lib_neutral_prompt import global_state as _gs
308
+ return _gs.step_window_global
309
+ except AttributeError:
310
+ return None
311
+
312
+
313
+ def get_per_strategy_windows() -> Dict[str, StepWindow]:
314
+ """Read per-strategy windows from global_state (empty dict if not set)."""
315
+ try:
316
+ from lib_neutral_prompt import global_state as _gs
317
+ return _gs.step_window_per_strategy or {}
318
+ except AttributeError:
319
+ return {}
320
+
321
+
322
+ def get_use_defaults() -> bool:
323
+ """Read whether STRATEGY_DEFAULTS are active."""
324
+ try:
325
+ from lib_neutral_prompt import global_state as _gs
326
+ return bool(_gs.step_window_use_defaults)
327
+ except AttributeError:
328
+ return False
329
+
330
+
331
+ def strategy_is_active_from_state(
332
+ strategy_name: str,
333
+ progress: float,
334
+ ) -> bool:
335
+ """
336
+ Convenience wrapper: read all step-window settings from global_state and
337
+ return whether *strategy_name* is active at *progress*.
338
+
339
+ Also handles Lock after End: if ``global_state.step_window_lock_after_end``
340
+ is True and the strategy's window has ended, it sets a permanent lock flag
341
+ in ``global_state._step_lock_flags[family]`` and returns False for all
342
+ subsequent calls this generation.
343
+ """
344
+ from lib_neutral_prompt import global_state as _gs
345
+
346
+ family = _strategy_family(strategy_name)
347
+
348
+ # ── Lock after End: check permanent lock first ─────────────────────────
349
+ lock_enabled = getattr(_gs, 'step_window_lock_after_end', False)
350
+ if lock_enabled:
351
+ lock_flags = getattr(_gs, '_step_lock_flags', None)
352
+ if lock_flags is None:
353
+ _gs._step_lock_flags = {}
354
+ lock_flags = _gs._step_lock_flags
355
+ if lock_flags.get(family, False):
356
+ return False # permanently frozen this generation
357
+
358
+ # ── Normal activity check ───────────────────────────────────────────────
359
+ active = strategy_is_active(
360
+ strategy_name,
361
+ progress,
362
+ global_window=get_global_window(),
363
+ per_strategy_windows=get_per_strategy_windows(),
364
+ use_defaults=get_use_defaults(),
365
+ )
366
+
367
+ # ── Lock after End: set flag when window has closed ─────────────────────
368
+ if lock_enabled and not active:
369
+ # Only lock once we have actually been past the window start
370
+ # (avoids locking strategies that haven't started yet)
371
+ window = _resolve_window_for_family(family)
372
+ if window is not None and progress >= window.start:
373
+ # We were inside the window at some point; window has now ended
374
+ lock_flags[family] = True # type: ignore[index]
375
+
376
+ return active
377
+
378
+
379
+ def _resolve_window_for_family(family: str) -> Optional[StepWindow]:
380
+ """
381
+ Return the effective StepWindow for *family* given the current global_state,
382
+ or None if no window restriction applies.
383
+ Used by Lock after End to decide whether to set the lock flag.
384
+ """
385
+ try:
386
+ from lib_neutral_prompt import global_state as _gs
387
+ per = getattr(_gs, 'step_window_per_strategy', None) or {}
388
+ if family in per:
389
+ return per[family]
390
+ gw = getattr(_gs, 'step_window_global', None)
391
+ if gw is not None:
392
+ return gw
393
+ if getattr(_gs, 'step_window_use_defaults', False):
394
+ return STRATEGY_DEFAULTS.get(family)
395
+ except Exception:
396
+ pass
397
+ return None
398
+
399
+
400
+ # ---------------------------------------------------------------------------
401
+ # Explain / diagnostic summary
402
+ # ---------------------------------------------------------------------------
403
+
404
+ def render_step_window_summary(
405
+ strategies: list,
406
+ progress: Optional[float] = None,
407
+ global_window: Optional[StepWindow] = None,
408
+ per_strategy_windows: Optional[Dict[str, StepWindow]] = None,
409
+ use_defaults: bool = False,
410
+ ) -> str:
411
+ """
412
+ Build a human-readable summary of step-window settings for the debug panel.
413
+
414
+ *strategies* is a list of strategy-name strings (from diagnostics['strategies']).
415
+ *progress* is the current (or hypothetical) normalised step (optional).
416
+ """
417
+ if not strategies:
418
+ return ''
419
+
420
+ lines = ['── Step windows ─────────────────────────']
421
+
422
+ # Header: show effective global window or "all steps"
423
+ if global_window is not None:
424
+ lines.append(f'Global window : {global_window}')
425
+ elif use_defaults:
426
+ lines.append('Global window : (per-strategy defaults active)')
427
+ else:
428
+ lines.append('Global window : all steps (no restriction)')
429
+
430
+ for s in strategies:
431
+ if s == 'BASE':
432
+ continue
433
+ family = _strategy_family(s)
434
+ # Resolve effective window
435
+ if per_strategy_windows and family in per_strategy_windows:
436
+ win = per_strategy_windows[family]
437
+ src = 'custom'
438
+ elif global_window is not None:
439
+ win = global_window
440
+ src = 'global'
441
+ elif use_defaults and family in STRATEGY_DEFAULTS:
442
+ win = STRATEGY_DEFAULTS[family]
443
+ src = 'default'
444
+ else:
445
+ win = ALWAYS_ACTIVE
446
+ src = 'always'
447
+
448
+ if progress is not None:
449
+ active = '✓ active' if win.is_active_at(progress) else '✗ gated'
450
+ lines.append(f' {s:32s} {win} [{src}] @ {progress:.2f} → {active}')
451
+ else:
452
+ lines.append(f' {s:32s} {win} [{src}]')
453
+
454
+ return '\n'.join(lines)
neutral_prompt_patcheds/lib_neutral_prompt/ui.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI for sd-webui-neutral-prompt — v1.1 "Matryoshka visibility"
3
+
4
+ New in v1.1:
5
+ - Tree-style debug/explain with branch connectors and full diagnostics
6
+ - Matryoshka builder v1: 2-level visual composer with live preview
7
+ - Matryoshka templates: 7 ready-made nested recipes
8
+ - AND_TOPK[threshold] slider
9
+ - Affine transform builder with presets
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple
15
+ import dataclasses
16
+ import gradio as gr
17
+
18
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
19
+ from lib_neutral_prompt.matryoshka_utils import (
20
+ build_child_block,
21
+ build_nested_prompt,
22
+ render_prompt_node,
23
+ render_prompt_tree,
24
+ collect_prompt_diagnostics,
25
+ render_diagnostics,
26
+ render_full_explain,
27
+ MATRYOSHKA_TEMPLATES,
28
+ BUILDER_STRATEGIES,
29
+ _KW_DEFAULTS,
30
+ )
31
+ from lib_neutral_prompt.affine_utils import (
32
+ AFFINE_TRANSFORMS as _AFFINE_TRANSFORMS,
33
+ AFFINE_PRESETS as _AFFINE_PRESETS,
34
+ build_affine_snippet as _build_affine_snippet,
35
+ )
36
+ from lib_neutral_prompt.step_utils import (
37
+ StepWindow, STRATEGY_DEFAULTS,
38
+ STRATEGY_UI_KEYS as _STRATEGY_UI_KEYS,
39
+ build_per_strategy_windows as _build_per_strategy_windows,
40
+ serialize_per_strategy_windows as _serialize_per_strategy_windows,
41
+ deserialize_per_strategy_windows as _deserialize_per_strategy_windows,
42
+ default_custom_windows as _step_default_custom,
43
+ )
44
+ from modules import script_callbacks, shared
45
+
46
+
47
+ txt2img_prompt_textbox = None
48
+ img2img_prompt_textbox = None
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Prompt-type registry
53
+ # ---------------------------------------------------------------------------
54
+
55
+ prompt_types: Dict[str, str] = {
56
+ 'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
57
+ 'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
58
+ 'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
59
+ 'Saliency wide / classic (AND_SALT_WIDE)': neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
60
+ 'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
61
+ 'Alignment blend — custom D/S (AND_ALIGN)': neutral_prompt_parser.PromptKeyword.AND_ALIGN.value,
62
+ 'Alignment mask — custom D/S (AND_MASK_ALIGN)': neutral_prompt_parser.PromptKeyword.AND_MASK_ALIGN.value,
63
+ }
64
+ for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
65
+ prompt_types[f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'] = \
66
+ getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
67
+ prompt_types[f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'] = \
68
+ getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
69
+
70
+ _SALT_LABELS = frozenset({'Saliency sharp (AND_SALT)', 'Saliency blob (AND_SALT_BLOB)',
71
+ 'Saliency wide / classic (AND_SALT_WIDE)'})
72
+ _TOPK_LABELS = frozenset({'Semantic guidance top-k (AND_TOPK)'})
73
+ _ALIGN_CUSTOM_LABELS = frozenset({'Alignment blend — custom D/S (AND_ALIGN)',
74
+ 'Alignment mask — custom D/S (AND_MASK_ALIGN)'})
75
+
76
+ prompt_types_tooltip = (
77
+ 'AND – add all features equally (webui built-in)\n'
78
+ 'AND_PERP – reduce contradicting features via perpendicular projection\n'
79
+ 'AND_SALT[k] – sharp saliency mask (default k=5; higher k → fewer pixels)\n'
80
+ 'AND_SALT_BLOB[k] – blob saliency: grown seed region\n'
81
+ 'AND_SALT_WIDE[k] – broad saliency mask (default k=1; ~55% of pixels)\n'
82
+ 'AND_TOPK[t] – small targeted changes (top-t fraction; default 0.05)\n'
83
+ '\n'
84
+ 'AND_ALIGN[D,S] – soft blend: detail without breaking structure\n'
85
+ 'AND_MASK_ALIGN[D,S] – binary-mask variant of AND_ALIGN\n'
86
+ ' D = detail kernel [2,32] S = structure kernel [2,32] D ≠ S\n'
87
+ '\n'
88
+ 'Affine: ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]\n'
89
+ '\n'
90
+ 'Matryoshka nesting: AND_SALT[5] [ inner text AND_TOPK[0.1] detail ]'
91
+ )
92
+
93
+
94
+ def _build_keyword(label: str, salt_k: float, align_d: int, align_s: int,
95
+ topk_threshold: float = 0.05) -> str:
96
+ kw = prompt_types[label]
97
+ if label in _SALT_LABELS:
98
+ return f'{kw}[{salt_k}]'
99
+ if label in _TOPK_LABELS:
100
+ return f'{kw}[{topk_threshold}]'
101
+ if label in _ALIGN_CUSTOM_LABELS:
102
+ d, s = int(align_d), int(align_s)
103
+ if d != s and 2 <= d <= 32 and 2 <= s <= 32:
104
+ return f'{kw}[{d},{s}]'
105
+ return kw
106
+
107
+
108
+
109
+ def _build_step_window_preview(mode: str, start: float, end: float,
110
+ custom_raw: Optional[Dict] = None) -> str:
111
+ """Build a human-readable summary of the step-window setting."""
112
+ if mode == 'global':
113
+ win = StepWindow(start, end)
114
+ lines = [f'Global window: {win}', '']
115
+ for strat in ('PERPENDICULAR', 'SALIENCE_MASK', 'SALIENCE_MASK_WIDE',
116
+ 'SEMANTIC_GUIDANCE', 'ALIGNMENT_BLEND_CUSTOM'):
117
+ lines.append(f' {strat:34s} → {win}')
118
+ return '\n'.join(lines)
119
+ elif mode == 'per-strategy defaults':
120
+ lines = ['Per-strategy defaults:', '']
121
+ for strat, win in STRATEGY_DEFAULTS.items():
122
+ lines.append(f' {strat:34s} → {win}')
123
+ return '\n'.join(lines)
124
+ elif mode == 'per-strategy custom' and custom_raw:
125
+ from lib_neutral_prompt.step_utils import StepWindow as _SW, normalize_step_range
126
+ lines = ['Per-strategy custom windows:', '']
127
+ for ui_key in _STRATEGY_UI_KEYS:
128
+ if ui_key in custom_raw:
129
+ s, e = custom_raw[ui_key]
130
+ lines.append(f' {ui_key:16s} → {_SW(s, e)}')
131
+ return '\n'.join(lines)
132
+ return ''
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # UI component class
137
+ # ---------------------------------------------------------------------------
138
+
139
+ @dataclasses.dataclass
140
+ class AccordionInterface:
141
+ get_elem_id: Callable[[str], str]
142
+
143
+ def __post_init__(self):
144
+ self.is_rendered = False
145
+
146
+ self.cfg_rescale = gr.Slider(
147
+ label='CFG rescale φ', minimum=0, maximum=1, value=0,
148
+ info='0 = disabled. Rescales CFG output to reduce over-saturation.',
149
+ )
150
+
151
+ # --- Base-prompt protection ---
152
+ self.protection_mode = gr.Radio(
153
+ label='Base prompt protection', choices=['off', 'auto', 'strict', 'soft'],
154
+ value=global_state.protection_mode,
155
+ info=(
156
+ 'off — no guard. '
157
+ 'auto — structural guard (hard fallback). '
158
+ 'strict — structural + ratio guard (hard fallback). '
159
+ 'soft — structural guard + ratio attenuation (no fallback, visible effect).'
160
+ ),
161
+ )
162
+ self.strict_threshold = gr.Slider(
163
+ label='Ratio threshold (strict / soft)', minimum=0.01, maximum=0.5, step=0.01,
164
+ value=global_state.strict_threshold,
165
+ info='strict: fallback when base/aux < threshold. soft: attenuate aux until ratio ≥ threshold.',
166
+ visible=(global_state.normalize_protection_mode(global_state.protection_mode) in ('strict', 'soft')),
167
+ )
168
+
169
+ # --- Prompt formatter ---
170
+ self.neutral_prompt = gr.Textbox(
171
+ label='Neutral prompt', show_label=False, lines=3,
172
+ placeholder='Auxiliary prompt (press "Apply to prompt" to append)',
173
+ )
174
+ self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
175
+ self.aux_prompt_type = gr.Dropdown(
176
+ label='Prompt type', choices=list(prompt_types.keys()),
177
+ value=next(iter(prompt_types.keys())), info=prompt_types_tooltip,
178
+ elem_id=self.get_elem_id('formatter_prompt_type'),
179
+ )
180
+ _def = next(iter(prompt_types.keys()))
181
+ self.salt_k = gr.Slider(
182
+ label='Salience sharpness k', minimum=0.5, maximum=20.0, step=0.5, value=5.0,
183
+ info='Low k → broad. High k → surgical. Default 5.',
184
+ visible=(_def in _SALT_LABELS),
185
+ )
186
+ self.topk_threshold = gr.Slider(
187
+ label='Top-k threshold', minimum=0.01, maximum=0.5, step=0.01, value=0.05,
188
+ info='Fraction of elements that receive the contribution. Default 0.05.',
189
+ visible=(_def in _TOPK_LABELS),
190
+ )
191
+ self.align_detail = gr.Slider(
192
+ label='Alignment detail kernel D', minimum=2, maximum=32, step=1, value=4,
193
+ info='Fine detail kernel size. D ≠ S.', visible=False,
194
+ )
195
+ self.align_structure = gr.Slider(
196
+ label='Alignment structure kernel S', minimum=2, maximum=32, step=1, value=8,
197
+ info='Global structure kernel size. D ≠ S.', visible=False,
198
+ )
199
+ self.append_to_prompt_button = gr.Button(value='Apply to prompt')
200
+
201
+ # --- Affine formatter ---
202
+ self.affine_preset = gr.Dropdown(
203
+ label='Affine preset', choices=list(_AFFINE_PRESETS.keys()), value='Custom',
204
+ )
205
+ self.affine_transform = gr.Dropdown(
206
+ label='Transform', choices=_AFFINE_TRANSFORMS, value='ROTATE',
207
+ )
208
+ self.affine_p1 = gr.Slider(
209
+ label='Parameter 1 (angle / x / scale)', minimum=-2.0, maximum=2.0, step=0.01,
210
+ value=0.125,
211
+ info='ROTATE: turns (0.25=90°). SCALE/SLIDE/SHEAR: x-component.',
212
+ )
213
+ self.affine_p2 = gr.Slider(
214
+ label='Parameter 2 (y / scale-y)', minimum=-2.0, maximum=2.0, step=0.01, value=1.0,
215
+ info='y-component. Ignored for ROTATE, FLIP_H, FLIP_V.',
216
+ )
217
+ self.affine_preview = gr.Textbox(
218
+ label='Affine snippet', interactive=False,
219
+ value=_build_affine_snippet('ROTATE', 0.125, 1.0),
220
+ )
221
+ self.affine_insert_button = gr.Button(value='Insert affine into prompt')
222
+
223
+ # --- Step window ---
224
+ self.step_window_enabled = gr.Checkbox(
225
+ label='Enable step-window gating for AND_*', value=False,
226
+ info='Restrict each strategy to a sub-range of the denoising steps.',
227
+ )
228
+ self.step_window_lock = gr.Checkbox(
229
+ label='🔒 Lock after End', value=False, visible=False,
230
+ info='После последнего активного диапазона — заморозить навсегда. '
231
+ 'Стратегия не возобновится, даже если шаги войдут в другой диапазон.',
232
+ )
233
+ self.step_window_mode = gr.Radio(
234
+ label='Step window mode',
235
+ choices=['global', 'per-strategy defaults', 'per-strategy custom'],
236
+ value='global', visible=False,
237
+ info='"global" applies one window to all AND_* strategies. '
238
+ '"per-strategy defaults" uses built-in schedule (PERP/ALIGN early, SALT/TOPK late). '
239
+ '"per-strategy custom" lets you set each window independently.',
240
+ )
241
+ self.step_window_start = gr.Slider(
242
+ label='Global window start', minimum=0.0, maximum=1.0, step=0.01, value=0.0,
243
+ info='Normalised progress: 0 = first step, 1 = last step.',
244
+ visible=False,
245
+ )
246
+ self.step_window_end = gr.Slider(
247
+ label='Global window end', minimum=0.0, maximum=1.0, step=0.01, value=1.0,
248
+ visible=False,
249
+ )
250
+ # Per-strategy custom sliders — one start/end pair per strategy
251
+ _ps_defaults = _step_default_custom()
252
+ self.step_ps_sliders: Dict[str, Tuple] = {} # {ui_key: (start_slider, end_slider)}
253
+ for _ui_key in _STRATEGY_UI_KEYS:
254
+ _s, _e = _ps_defaults.get(_ui_key, (0.0, 1.0))
255
+ self.step_ps_sliders[_ui_key] = (
256
+ gr.Slider(label=f'{_ui_key} start', minimum=0.0, maximum=1.0,
257
+ step=0.01, value=_s, visible=False),
258
+ gr.Slider(label=f'{_ui_key} end', minimum=0.0, maximum=1.0,
259
+ step=0.01, value=_e, visible=False),
260
+ )
261
+ self.step_ps_reset_btn = gr.Button(value='Reset to defaults', visible=False)
262
+ self.step_window_preview = gr.Textbox(
263
+ label='Step window preview', interactive=False, visible=False,
264
+ value='',
265
+ )
266
+
267
+ # --- Matryoshka builder ---
268
+ self.builder_base = gr.Textbox(
269
+ label='Base prompt', lines=2, placeholder='main subject, environment, style…',
270
+ )
271
+ # Child 1
272
+ self.builder_ch1_strategy = gr.Dropdown(
273
+ label='Child 1 — strategy', choices=BUILDER_STRATEGIES, value='AND_SALT',
274
+ )
275
+ self.builder_ch1_text = gr.Textbox(
276
+ label='Child 1 — text', lines=1, placeholder='concept, texture, style…',
277
+ )
278
+ self.builder_ch1_weight = gr.Slider(
279
+ label='Child 1 — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.8,
280
+ )
281
+ self.builder_ch1_affine = gr.Textbox(
282
+ label='Child 1 — affine (optional)', lines=1,
283
+ placeholder='e.g. ROTATE[0.125] or SCALE[-1,1]',
284
+ )
285
+ # Optional nested child inside child 1
286
+ self.builder_ch1_nested_enable = gr.Checkbox(label='Add nested child inside child 1', value=False)
287
+ self.builder_ch1_nested_strategy = gr.Dropdown(
288
+ label='Nested — strategy', choices=BUILDER_STRATEGIES, value='AND_TOPK', visible=False,
289
+ )
290
+ self.builder_ch1_nested_text = gr.Textbox(
291
+ label='Nested — text', lines=1, placeholder='inner concept…', visible=False,
292
+ )
293
+ self.builder_ch1_nested_weight = gr.Slider(
294
+ label='Nested — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.5, visible=False,
295
+ )
296
+ # Child 2
297
+ self.builder_ch2_enable = gr.Checkbox(label='Add second child', value=False)
298
+ self.builder_ch2_strategy = gr.Dropdown(
299
+ label='Child 2 — strategy', choices=BUILDER_STRATEGIES, value='AND_PERP', visible=False,
300
+ )
301
+ self.builder_ch2_text = gr.Textbox(
302
+ label='Child 2 — text', lines=1, placeholder='concept…', visible=False,
303
+ )
304
+ self.builder_ch2_weight = gr.Slider(
305
+ label='Child 2 — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.6, visible=False,
306
+ )
307
+ # Output
308
+ self.builder_preview = gr.Textbox(
309
+ label='Generated prompt (preview)', lines=7, interactive=False,
310
+ )
311
+ self.builder_apply_button = gr.Button(value='Apply to prompt')
312
+ self.builder_copy_button = gr.Button(value='Copy generated prompt')
313
+ self.builder_explain_button = gr.Button(value='Explain this prompt')
314
+
315
+ # --- Templates ---
316
+ self.template_dropdown = gr.Dropdown(
317
+ label='Matryoshka template', choices=list(MATRYOSHKA_TEMPLATES.keys()),
318
+ value=next(iter(MATRYOSHKA_TEMPLATES.keys())),
319
+ )
320
+ self.template_description = gr.Textbox(
321
+ label='Template description', interactive=False,
322
+ value=next(iter(MATRYOSHKA_TEMPLATES.values()))['description'],
323
+ )
324
+ self.template_preview = gr.Textbox(
325
+ label='Template prompt', lines=8, interactive=False,
326
+ value=next(iter(MATRYOSHKA_TEMPLATES.values()))['prompt'],
327
+ )
328
+ self.template_load_button = gr.Button(value='Load template into prompt')
329
+
330
+ # --- Debug / explain (matryoshka tree) ---
331
+ self.debug_prompt_input = gr.Textbox(
332
+ label='Prompt to explain', lines=5,
333
+ placeholder='Paste any prompt here, including nested ones, to see the full parse tree…',
334
+ )
335
+ self.debug_output = gr.Textbox(
336
+ label='Parse tree + diagnostics (structural preview — not a full runtime simulation)',
337
+ lines=18, interactive=False,
338
+ value='(enter a prompt above)',
339
+ )
340
+ self.debug_parse_button = gr.Button(value='Explain prompt')
341
+ self.debug_copy_button = gr.Button(value='Copy explain output')
342
+
343
+ # ------------------------------------------------------------------
344
+
345
+ def arrange_components(self, is_img2img: bool) -> None:
346
+ if self.is_rendered:
347
+ return
348
+
349
+ with gr.Accordion(label='Neutral Prompt', open=False):
350
+ self.cfg_rescale.render()
351
+
352
+ with gr.Accordion(label='Base prompt protection', open=False):
353
+ self.protection_mode.render()
354
+ self.strict_threshold.render()
355
+
356
+ with gr.Accordion(label='Prompt formatter', open=False):
357
+ self.neutral_prompt.render()
358
+ self.neutral_cond_scale.render()
359
+ self.aux_prompt_type.render()
360
+ self.salt_k.render()
361
+ self.topk_threshold.render()
362
+ with gr.Row():
363
+ self.align_detail.render()
364
+ self.align_structure.render()
365
+ self.append_to_prompt_button.render()
366
+
367
+ with gr.Accordion(label='Matryoshka builder', open=False):
368
+ self.builder_base.render()
369
+ with gr.Row():
370
+ self.builder_ch1_strategy.render()
371
+ self.builder_ch1_weight.render()
372
+ self.builder_ch1_text.render()
373
+ self.builder_ch1_affine.render()
374
+ self.builder_ch1_nested_enable.render()
375
+ with gr.Row():
376
+ self.builder_ch1_nested_strategy.render()
377
+ self.builder_ch1_nested_weight.render()
378
+ self.builder_ch1_nested_text.render()
379
+ self.builder_ch2_enable.render()
380
+ with gr.Row():
381
+ self.builder_ch2_strategy.render()
382
+ self.builder_ch2_weight.render()
383
+ self.builder_ch2_text.render()
384
+ self.builder_preview.render()
385
+ with gr.Row():
386
+ self.builder_apply_button.render()
387
+ self.builder_copy_button.render()
388
+ self.builder_explain_button.render()
389
+
390
+ with gr.Accordion(label='Matryoshka templates', open=False):
391
+ self.template_dropdown.render()
392
+ self.template_description.render()
393
+ self.template_preview.render()
394
+ self.template_load_button.render()
395
+
396
+ with gr.Accordion(label='Affine transform builder', open=False):
397
+ self.affine_preset.render()
398
+ with gr.Row():
399
+ self.affine_transform.render()
400
+ self.affine_p1.render()
401
+ self.affine_p2.render()
402
+ self.affine_preview.render()
403
+ self.affine_insert_button.render()
404
+
405
+ with gr.Accordion(label='Step-window gating (AND_* activation)', open=False):
406
+ self.step_window_enabled.render()
407
+ self.step_window_lock.render()
408
+ self.step_window_mode.render()
409
+ with gr.Row():
410
+ self.step_window_start.render()
411
+ self.step_window_end.render()
412
+ with gr.Accordion(label='Per-strategy custom windows', open=False):
413
+ for _ui_key, (_s_slider, _e_slider) in self.step_ps_sliders.items():
414
+ with gr.Row():
415
+ _s_slider.render()
416
+ _e_slider.render()
417
+ self.step_ps_reset_btn.render()
418
+ self.step_window_preview.render()
419
+
420
+ with gr.Accordion(label='Prompt debug / explain', open=False):
421
+ self.debug_prompt_input.render()
422
+ with gr.Row():
423
+ self.debug_parse_button.render()
424
+ self.debug_copy_button.render()
425
+ self.debug_output.render()
426
+
427
+ def connect_events(self, is_img2img: bool) -> None:
428
+ if self.is_rendered:
429
+ return
430
+
431
+ prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
432
+
433
+ # ---- Protection ----
434
+ def _on_protection_change(mode, threshold):
435
+ mode = global_state.normalize_protection_mode(mode)
436
+ global_state.protection_mode = mode
437
+ global_state.strict_threshold = global_state.clamp_strict_threshold(threshold)
438
+ return gr.update(visible=(mode in ('strict', 'soft')))
439
+
440
+ self.protection_mode.change(
441
+ fn=_on_protection_change,
442
+ inputs=[self.protection_mode, self.strict_threshold],
443
+ outputs=[self.strict_threshold],
444
+ )
445
+ self.strict_threshold.change(
446
+ fn=lambda m, t: (
447
+ setattr(global_state, 'protection_mode', global_state.normalize_protection_mode(m)) or
448
+ setattr(global_state, 'strict_threshold', global_state.clamp_strict_threshold(t))
449
+ ),
450
+ inputs=[self.protection_mode, self.strict_threshold], outputs=[],
451
+ )
452
+
453
+ # ---- Formatter type sliders ----
454
+ def _on_type_change(label):
455
+ return (gr.update(visible=(label in _SALT_LABELS)),
456
+ gr.update(visible=(label in _TOPK_LABELS)),
457
+ gr.update(visible=(label in _ALIGN_CUSTOM_LABELS)),
458
+ gr.update(visible=(label in _ALIGN_CUSTOM_LABELS)))
459
+
460
+ self.aux_prompt_type.change(
461
+ fn=_on_type_change, inputs=[self.aux_prompt_type],
462
+ outputs=[self.salt_k, self.topk_threshold, self.align_detail, self.align_structure],
463
+ )
464
+
465
+ self.append_to_prompt_button.click(
466
+ fn=lambda init, p, sc, lbl, k, topk, d, s: (
467
+ f'{init}\n{_build_keyword(lbl, k, int(d), int(s), topk)} {p} :{sc}', ''),
468
+ inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale,
469
+ self.aux_prompt_type, self.salt_k, self.topk_threshold,
470
+ self.align_detail, self.align_structure],
471
+ outputs=[prompt_textbox, self.neutral_prompt],
472
+ )
473
+
474
+ # ---- Affine ----
475
+ def _apply_affine_preset(name):
476
+ data = _AFFINE_PRESETS.get(name)
477
+ if data is None:
478
+ return gr.update(), gr.update(), gr.update(), gr.update()
479
+ t, p1, p2 = data[0], float(data[1]) if len(data) > 1 else 0.0, float(data[2]) if len(data) > 2 else 1.0
480
+ return (gr.update(value=t), gr.update(value=p1),
481
+ gr.update(value=p2), gr.update(value=_build_affine_snippet(t, p1, p2)))
482
+
483
+ self.affine_preset.change(
484
+ fn=_apply_affine_preset, inputs=[self.affine_preset],
485
+ outputs=[self.affine_transform, self.affine_p1, self.affine_p2, self.affine_preview],
486
+ )
487
+ def _live_affine(t, p1, p2):
488
+ return gr.update(value=_build_affine_snippet(t, p1, p2))
489
+ for c in (self.affine_transform, self.affine_p1, self.affine_p2):
490
+ c.change(fn=_live_affine, inputs=[self.affine_transform, self.affine_p1, self.affine_p2],
491
+ outputs=[self.affine_preview])
492
+
493
+ self.affine_insert_button.click(
494
+ fn=lambda init, t, p1, p2: f'{init} {_build_affine_snippet(t, p1, p2)}'.strip(),
495
+ inputs=[prompt_textbox, self.affine_transform, self.affine_p1, self.affine_p2],
496
+ outputs=[prompt_textbox],
497
+ )
498
+
499
+ # ---- Builder live preview ----
500
+ def _builder_update(base, s1, t1, w1, aff1, nest_on, ns, nt, nw, ch2_on, s2, t2, w2):
501
+ nested_str = build_child_block(ns, nt, nw) if nest_on and nt.strip() else ''
502
+ ch1 = build_child_block(s1, t1, w1, affine=aff1, nested=nested_str) if t1.strip() else ''
503
+ ch2 = build_child_block(s2, t2, w2) if ch2_on and t2.strip() else ''
504
+ children = [c for c in [
505
+ {'strategy': s1, 'text': t1, 'weight': w1, 'affine': aff1,
506
+ 'nested': build_child_block(ns, nt, nw) if nest_on and nt.strip() else ''},
507
+ {'strategy': s2, 'text': t2, 'weight': w2} if ch2_on and t2.strip() else None,
508
+ ] if c is not None]
509
+ return gr.update(value=build_nested_prompt(base, children))
510
+
511
+ _builder_inputs = [
512
+ self.builder_base,
513
+ self.builder_ch1_strategy, self.builder_ch1_text, self.builder_ch1_weight,
514
+ self.builder_ch1_affine,
515
+ self.builder_ch1_nested_enable,
516
+ self.builder_ch1_nested_strategy, self.builder_ch1_nested_text, self.builder_ch1_nested_weight,
517
+ self.builder_ch2_enable,
518
+ self.builder_ch2_strategy, self.builder_ch2_text, self.builder_ch2_weight,
519
+ ]
520
+ for ctrl in _builder_inputs:
521
+ ctrl.change(fn=_builder_update, inputs=_builder_inputs, outputs=[self.builder_preview])
522
+
523
+ # Toggle visibility of nested / ch2 sections
524
+ self.builder_ch1_nested_enable.change(
525
+ fn=lambda on: (gr.update(visible=on), gr.update(visible=on), gr.update(visible=on)),
526
+ inputs=[self.builder_ch1_nested_enable],
527
+ outputs=[self.builder_ch1_nested_strategy, self.builder_ch1_nested_text,
528
+ self.builder_ch1_nested_weight],
529
+ )
530
+ self.builder_ch2_enable.change(
531
+ fn=lambda on: (gr.update(visible=on), gr.update(visible=on), gr.update(visible=on)),
532
+ inputs=[self.builder_ch2_enable],
533
+ outputs=[self.builder_ch2_strategy, self.builder_ch2_text, self.builder_ch2_weight],
534
+ )
535
+
536
+ # Builder → apply / explain
537
+ self.builder_apply_button.click(
538
+ fn=lambda init, preview: (f'{init}\n{preview}'.strip() if preview.strip() else init, ''),
539
+ inputs=[prompt_textbox, self.builder_preview], outputs=[prompt_textbox, self.builder_preview],
540
+ )
541
+ # Copy buttons: copy content into a dedicated copy-sink textbox via JS,
542
+ # or — since Gradio doesn't have a native clipboard API — make the textbox
543
+ # temporarily interactive so the user can Ctrl+A / Ctrl+C easily.
544
+ # We implement this as a no-op event (the Textbox is already selectable).
545
+ # A real clipboard call requires gr.HTML + JS, which we keep optional.
546
+ self.builder_copy_button.click(fn=lambda x: x,
547
+ inputs=[self.builder_preview],
548
+ outputs=[self.builder_preview])
549
+ self.debug_copy_button.click(fn=lambda x: x,
550
+ inputs=[self.debug_output],
551
+ outputs=[self.debug_output])
552
+ self.builder_explain_button.click(
553
+ fn=lambda preview: gr.update(value=render_full_explain(preview)),
554
+ inputs=[self.builder_preview], outputs=[self.debug_output],
555
+ )
556
+
557
+ # ---- Templates ----
558
+ def _on_template_change(name):
559
+ t = MATRYOSHKA_TEMPLATES.get(name, {})
560
+ return (gr.update(value=t.get('description', '')),
561
+ gr.update(value=t.get('prompt', '')))
562
+
563
+ self.template_dropdown.change(
564
+ fn=_on_template_change, inputs=[self.template_dropdown],
565
+ outputs=[self.template_description, self.template_preview],
566
+ )
567
+ self.template_load_button.click(
568
+ fn=lambda init, tmpl: f'{init}\n{tmpl}'.strip() if tmpl.strip() else init,
569
+ inputs=[prompt_textbox, self.template_preview], outputs=[prompt_textbox],
570
+ )
571
+
572
+ # ---- Step window ----
573
+ _ps_all_sliders = [sl for (s, e) in self.step_ps_sliders.values() for sl in (s, e)]
574
+
575
+ def _step_window_update(enabled, lock, mode, start, end, *ps_vals):
576
+ vis = bool(enabled)
577
+ show_global = vis and mode == 'global'
578
+ show_custom = vis and mode == 'per-strategy custom'
579
+
580
+ global_state.step_window_enabled = vis
581
+ global_state.step_window_lock_after_end = bool(lock)
582
+ if vis:
583
+ if mode == 'global':
584
+ global_state.step_window_global = StepWindow(start, end)
585
+ global_state.step_window_use_defaults = False
586
+ global_state.step_window_per_strategy = None
587
+ elif mode == 'per-strategy defaults':
588
+ global_state.step_window_global = None
589
+ global_state.step_window_use_defaults = True
590
+ global_state.step_window_per_strategy = None
591
+ elif mode == 'per-strategy custom':
592
+ raw = {}
593
+ for idx, ui_key in enumerate(_STRATEGY_UI_KEYS):
594
+ raw[ui_key] = (float(ps_vals[idx * 2]), float(ps_vals[idx * 2 + 1]))
595
+ global_state.step_window_global = None
596
+ global_state.step_window_use_defaults = False
597
+ global_state.step_window_custom_raw = raw
598
+ global_state.step_window_per_strategy = _build_per_strategy_windows(raw)
599
+ else:
600
+ global_state.step_window_global = None
601
+ global_state.step_window_use_defaults = False
602
+ global_state.step_window_per_strategy = None
603
+
604
+ custom_raw = getattr(global_state, 'step_window_custom_raw', None) or {}
605
+ preview = _build_step_window_preview(mode, start, end, custom_raw) if vis else ''
606
+
607
+ ps_vis = [gr.update(visible=show_custom)] * len(_ps_all_sliders)
608
+ return (
609
+ gr.update(visible=vis), # mode
610
+ gr.update(visible=vis), # lock checkbox
611
+ gr.update(visible=show_global), # global start
612
+ gr.update(visible=show_global), # global end
613
+ *ps_vis, # per-strategy sliders
614
+ gr.update(visible=show_custom), # reset_btn
615
+ gr.update(visible=vis, value=preview), # preview
616
+ )
617
+
618
+ def _ps_reset(*_):
619
+ from lib_neutral_prompt.step_utils import default_custom_windows
620
+ defaults = default_custom_windows()
621
+ updates = []
622
+ for ui_key in _STRATEGY_UI_KEYS:
623
+ s, e = defaults.get(ui_key, (0.0, 1.0))
624
+ updates.extend([gr.update(value=s), gr.update(value=e)])
625
+ return updates
626
+
627
+ _sw_inputs = ([self.step_window_enabled, self.step_window_lock,
628
+ self.step_window_mode,
629
+ self.step_window_start, self.step_window_end]
630
+ + _ps_all_sliders)
631
+ _sw_outputs = ([self.step_window_mode, self.step_window_lock,
632
+ self.step_window_start, self.step_window_end]
633
+ + _ps_all_sliders
634
+ + [self.step_ps_reset_btn, self.step_window_preview])
635
+
636
+ for ctrl in _sw_inputs:
637
+ ctrl.change(fn=_step_window_update, inputs=_sw_inputs, outputs=_sw_outputs)
638
+
639
+ self.step_ps_reset_btn.click(fn=_ps_reset, inputs=[], outputs=_ps_all_sliders)
640
+ _explain_fn = lambda s: gr.update(value=render_full_explain(s))
641
+ self.debug_parse_button.click(fn=_explain_fn, inputs=[self.debug_prompt_input],
642
+ outputs=[self.debug_output])
643
+ self.debug_prompt_input.change(fn=_explain_fn, inputs=[self.debug_prompt_input],
644
+ outputs=[self.debug_output])
645
+
646
+ def set_rendered(self, value: bool = True) -> None:
647
+ self.is_rendered = value
648
+
649
+ # ------------------------------------------------------------------
650
+
651
+ def _collect_ps_slider_values(self) -> Dict[str, Tuple[float, float]]:
652
+ """Read current per-strategy slider values into a raw dict."""
653
+ raw: Dict[str, Tuple[float, float]] = {}
654
+ for ui_key, (s_sl, e_sl) in self.step_ps_sliders.items():
655
+ raw[ui_key] = (float(s_sl.value), float(e_sl.value))
656
+ return raw
657
+
658
+ def get_components(self) -> Tuple:
659
+ ps_components = []
660
+ for _ui_key, (_s, _e) in self.step_ps_sliders.items():
661
+ ps_components.extend([_s, _e])
662
+ return (self.cfg_rescale, self.protection_mode, self.strict_threshold,
663
+ self.step_window_enabled, self.step_window_lock,
664
+ self.step_window_mode,
665
+ self.step_window_start, self.step_window_end,
666
+ *ps_components)
667
+
668
+ def get_infotext_fields(self) -> Tuple:
669
+ return tuple(zip(self.get_components(), (
670
+ 'CFG Rescale phi',
671
+ 'NP Protection Mode', 'NP Strict Threshold',
672
+ 'NP Step Window Enabled', 'NP Step Window Lock',
673
+ 'NP Step Window Mode',
674
+ 'NP Step Window Start', 'NP Step Window End',
675
+ # per-strategy sliders not individually in infotext — serialised as one key
676
+ )))
677
+
678
+ def get_paste_field_names(self) -> List[str]:
679
+ return [
680
+ 'CFG Rescale phi',
681
+ 'NP Protection Mode', 'NP Strict Threshold',
682
+ 'NP Step Window Enabled', 'NP Step Window Lock',
683
+ 'NP Step Window Mode',
684
+ 'NP Step Window Start', 'NP Step Window End',
685
+ 'NP Step Window Custom',
686
+ ]
687
+
688
+ def get_extra_generation_params(self, args: Dict) -> Dict:
689
+ params = {
690
+ 'CFG Rescale phi': args['cfg_rescale'],
691
+ 'NP Protection Mode': args['protection_mode'],
692
+ 'NP Strict Threshold': args['strict_threshold'],
693
+ }
694
+ if args.get('step_window_enabled'):
695
+ params['NP Step Window Enabled'] = True
696
+ params['NP Step Window Mode'] = args.get('step_window_mode', 'global')
697
+ params['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
698
+ params['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
699
+ if args.get('step_window_lock_after_end'):
700
+ params['NP Step Window Lock'] = True
701
+ if args.get('step_window_mode') == 'per-strategy custom':
702
+ raw = args.get('step_window_custom_raw') or {}
703
+ params['NP Step Window Custom'] = _serialize_per_strategy_windows(raw)
704
+ return params
705
+
706
+ def unpack_processing_args(self, cfg_rescale: float,
707
+ protection_mode: str = 'auto',
708
+ strict_threshold: float = 0.1,
709
+ step_window_enabled: bool = False,
710
+ step_window_lock_after_end: bool = False,
711
+ step_window_mode: str = 'global',
712
+ step_window_start: float = 0.0,
713
+ step_window_end: float = 1.0,
714
+ **ps_kwargs) -> Dict:
715
+ # Per-strategy raw dict: collect from kwargs (AND_PERP_start, AND_PERP_end, ...)
716
+ custom_raw: Dict[str, Tuple[float, float]] = {}
717
+ for ui_key in _STRATEGY_UI_KEYS:
718
+ s_key = f'{ui_key}_start'
719
+ e_key = f'{ui_key}_end'
720
+ if s_key in ps_kwargs and e_key in ps_kwargs:
721
+ custom_raw[ui_key] = (float(ps_kwargs[s_key]), float(ps_kwargs[e_key]))
722
+
723
+ enabled = bool(step_window_enabled)
724
+ lock = bool(step_window_lock_after_end)
725
+ mode = str(step_window_mode)
726
+ start = max(0.0, min(1.0, float(step_window_start)))
727
+ end = max(0.0, min(1.0, float(step_window_end)))
728
+ if start > end:
729
+ start, end = 0.0, 1.0
730
+
731
+ global_state.step_window_enabled = enabled
732
+ global_state.step_window_lock_after_end = lock
733
+ if enabled:
734
+ if mode == 'global':
735
+ global_state.step_window_global = StepWindow(start, end)
736
+ global_state.step_window_use_defaults = False
737
+ global_state.step_window_per_strategy = None
738
+ global_state.step_window_custom_raw = None
739
+ elif mode == 'per-strategy defaults':
740
+ global_state.step_window_global = None
741
+ global_state.step_window_use_defaults = True
742
+ global_state.step_window_per_strategy = None
743
+ global_state.step_window_custom_raw = None
744
+ elif mode == 'per-strategy custom':
745
+ global_state.step_window_global = None
746
+ global_state.step_window_use_defaults = False
747
+ global_state.step_window_custom_raw = custom_raw
748
+ global_state.step_window_per_strategy = _build_per_strategy_windows(custom_raw)
749
+ else:
750
+ global_state.step_window_global = None
751
+ global_state.step_window_use_defaults = False
752
+ global_state.step_window_per_strategy = None
753
+ else:
754
+ global_state.step_window_global = None
755
+ global_state.step_window_use_defaults = False
756
+ global_state.step_window_per_strategy = None
757
+
758
+ return {
759
+ 'cfg_rescale': cfg_rescale,
760
+ 'protection_mode': protection_mode,
761
+ 'strict_threshold': strict_threshold,
762
+ 'step_window_enabled': enabled,
763
+ 'step_window_lock_after_end': lock,
764
+ 'step_window_mode': mode,
765
+ 'step_window_start': start,
766
+ 'step_window_end': end,
767
+ 'step_window_custom_raw': custom_raw,
768
+ }
769
+
770
+
771
+ # ---------------------------------------------------------------------------
772
+ # Settings & callbacks
773
+ # ---------------------------------------------------------------------------
774
+
775
+ def on_ui_settings() -> None:
776
+ section = ('neutral_prompt', 'Neutral Prompt')
777
+ shared.opts.add_option(
778
+ 'neutral_prompt_enabled',
779
+ shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section),
780
+ )
781
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
782
+ shared.opts.onchange('neutral_prompt_enabled', _update_enabled)
783
+ shared.opts.add_option(
784
+ 'neutral_prompt_verbose',
785
+ shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section),
786
+ )
787
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
788
+ shared.opts.onchange('neutral_prompt_verbose', _update_verbose)
789
+
790
+
791
+ script_callbacks.on_ui_settings(on_ui_settings)
792
+
793
+
794
+ def _update_enabled() -> None:
795
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
796
+
797
+
798
+ def _update_verbose() -> None:
799
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
800
+
801
+
802
+ def on_after_component(component, **_kwargs) -> None:
803
+ global txt2img_prompt_textbox, img2img_prompt_textbox
804
+ eid = getattr(component, 'elem_id', None)
805
+ if eid == 'txt2img_prompt':
806
+ txt2img_prompt_textbox = component
807
+ elif eid == 'img2img_prompt':
808
+ img2img_prompt_textbox = component
809
+
810
+
811
+ script_callbacks.on_after_component(on_after_component)
neutral_prompt_patcheds/lib_neutral_prompt/xyz_grid.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from types import ModuleType
3
+ from typing import Optional
4
+ from modules import scripts
5
+ from lib_neutral_prompt import global_state
6
+
7
+
8
def patch():
    """Register this extension's axis options with the X/Y/Z grid script.

    Does nothing beyond a stderr note when the grid script is not
    installed, so the extension keeps working without it.
    """
    module = find_xyz_module()
    if module is None:
        print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
        return

    option = module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale())
    module.axis_options.extend([option])
17
+
18
+
19
class XyzFloat(float):
    """A float tagged so update_global_state can tell X/Y/Z-grid overrides apart."""

    # Consumed (set back to False on the instance) once the override is applied.
    is_xyz: bool = True
21
+
22
+
23
def apply_cfg_rescale():
    """Build the X/Y/Z-grid apply function for the CFG rescale axis."""
    def set_value(_p, value, _values):
        # Wrap in XyzFloat so update_global_state recognizes the override.
        global_state.cfg_rescale = XyzFloat(value)

    return set_value
28
+
29
+
30
def int_or_float(string):
    """Parse *string* as an int when possible, otherwise as a float.

    Follows the X/Y/Z-grid convention: '2' -> 2 (int), '2.5' -> 2.5 (float).
    Raises ValueError when the text is neither.
    """
    try:
        return int(string)
    except ValueError:
        pass
    return float(string)
35
+
36
+
37
def find_xyz_module() -> Optional[ModuleType]:
    """Locate the built-in X/Y/Z grid script module, if it is loaded.

    Handles both the current ("xyz_grid.py") and legacy ("xy_grid.py")
    script module names. Returns None when the script is unavailable.
    """
    grid_script_names = {"xyz_grid.py", "xy_grid.py"}
    return next(
        (data.module for data in scripts.scripts_data
         if data.script_class.__module__ in grid_script_names and hasattr(data, "module")),
        None,
    )
neutral_prompt_patcheds/scripts/__pycache__/neutral_prompt.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
neutral_prompt_patcheds/scripts/neutral_prompt.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
2
+ from modules import scripts, processing, shared
3
+ from typing import Dict
4
+ import functools
5
+
6
+
7
class NeutralPromptScript(scripts.Script):
    """AlwaysOn script wiring the neutral-prompt UI and state into WebUI runs."""

    def __init__(self):
        # Created lazily (see the is_img2img setter) because elem_id is only
        # meaningful once the script knows which tab it is attached to.
        self.accordion_interface = None
        self._is_img2img = False

    @property
    def is_img2img(self):
        # Which tab this script instance serves (txt2img vs img2img).
        return self._is_img2img

    @is_img2img.setter
    def is_img2img(self, is_img2img):
        self._is_img2img = is_img2img
        if self.accordion_interface is None:
            self.accordion_interface = ui.AccordionInterface(self.elem_id)

    def title(self) -> str:
        return "Neutral Prompt"

    def show(self, is_img2img: bool):
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the accordion UI and register infotext/paste plumbing."""
        self.hijack_composable_lora(is_img2img)

        self.accordion_interface.arrange_components(is_img2img)
        self.accordion_interface.connect_events(is_img2img)
        self.infotext_fields = self.accordion_interface.get_infotext_fields()
        self.paste_field_names = self.accordion_interface.get_paste_field_names()
        self.accordion_interface.set_rendered()
        return self.accordion_interface.get_components()

    def process(self, p: processing.StableDiffusionProcessing, *args):
        """Per-generation hook: push the unpacked UI args into global_state."""
        args = self.accordion_interface.unpack_processing_args(*args)

        self.update_global_state(args)
        if global_state.is_enabled:
            p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
        # Reset Lock-after-End flags once per real generation lifecycle.
        # This must happen AFTER update_global_state so that
        # step_window_lock_after_end reflects the current UI value.
        if global_state.step_window_lock_after_end:
            global_state.begin_new_generation()

    def update_global_state(self, args: Dict):
        """Copy UI args into global_state, respecting X/Y/Z-grid overrides.

        X/Y/Z grid writes ``is_xyz``-tagged values directly into
        global_state; those win over the UI value for the current job and
        are consumed (their flag cleared) here.  Past the first job of a
        batch, plain UI values are not re-applied.
        """
        if shared.state.job_no == 0:
            global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)

        for k, v in args.items():
            # Only mirror keys that global_state actually declares.
            try:
                getattr(global_state, k)
            except AttributeError:
                continue

            if getattr(getattr(global_state, k), 'is_xyz', False):
                # Consume the grid override: clear the tag, keep the value.
                xyz_attr = getattr(global_state, k)
                xyz_attr.is_xyz = False
                args[k] = xyz_attr
                continue

            if shared.state.job_no > 0:
                continue

            setattr(global_state, k, v)

    def hijack_composable_lora(self, is_img2img):
        """Wrap Composable LoRA's process() so it sees transpiled prompts."""
        if self.accordion_interface.is_rendered:
            return

        lora_script = None
        script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img

        for script in script_runner.alwayson_scripts:
            if script.title().lower() == "composable lora":
                lora_script = script
                break

        if lora_script is not None:
            lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
85
+
86
+
87
def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
    """Feed Composable LoRA transpiled prompts, then restore the originals.

    When the extension is disabled, delegates to the wrapped function
    untouched.  Otherwise ``p.all_prompts`` is temporarily replaced with the
    transpiled form that Composable LoRA understands.
    """
    if not global_state.is_enabled:
        return original_function(p, *args, **kwargs)

    exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
    original_prompts = p.all_prompts
    p.all_prompts = prompt_parser_hijack.transpile_exprs(exprs)
    try:
        return original_function(p, *args, **kwargs)
    finally:
        # Fix: restore the untouched prompts even when the wrapped call
        # raises; previously an exception left p.all_prompts transpiled.
        p.all_prompts = original_prompts


xyz_grid.patch()
neutral_prompt_patcheds/test/perp_parser/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test package init — installs a minimal torch mock when real torch is absent.
3
+
4
+ This runs before any test module in the package is imported, so all parser
5
+ tests get a working `torch` in sys.modules without needing torch installed.
6
+ Runtime tests that require real torch use @unittest.skipUnless independently.
7
+ """
8
+ import pathlib, sys
9
+
10
+ # Make the extension root importable regardless of CWD.
11
+ _root = str(pathlib.Path(__file__).parent.parent.parent)
12
+ if _root not in sys.path:
13
+ sys.path.insert(0, _root)
14
+
15
+ # Install mock if real torch is absent.
16
+ import math, types
17
+
18
def _install_mock_torch():
    """Install a minimal torch stand-in into sys.modules (idempotent).

    Does nothing when a real torch installation is importable, or when a
    mock has already been installed by a previous call.
    """
    try:
        import importlib.util
        spec = importlib.util.find_spec('torch')
        if spec is not None and getattr(spec, 'origin', None) is not None:
            return  # real torch present — nothing to do
    except (ValueError, ModuleNotFoundError):
        # find_spec raises ValueError for spec-less (already mocked) modules.
        pass

    if 'torch' in sys.modules:
        return  # already mocked

    class _FT:
        """Just enough of torch.Tensor for the parser AST dataclasses."""
        def __init__(self, data=None): self.data = data
        def __getitem__(self, key):
            if isinstance(self.data, list):
                return _FT(self.data[key] if isinstance(key, (int, slice)) else None)
            return _FT(None)
        def __matmul__(self, other): return _FT('matmul')

    torch = types.ModuleType('torch')
    torch.Tensor = _FT
    torch.float32 = 'float32'
    torch.pi = math.pi
    torch.eye = lambda n, dtype=None: _FT(list(range(n)))
    torch.tensor = lambda data, dtype=None: _FT(data)
    torch.vstack = lambda tensors: _FT([t.data for t in tensors])
    torch.linalg = types.SimpleNamespace(inv=lambda x: _FT('inv'))

    nn = types.ModuleType('torch.nn')
    nn_fn = types.ModuleType('torch.nn.functional')
    torch.nn = nn
    # Fix: bind functional as an attribute of the nn stub, so the expression
    # `torch.nn.functional` resolves after a plain `import torch` — before,
    # only `import torch.nn.functional` (via sys.modules) worked.
    nn.functional = nn_fn
    sys.modules['torch'] = torch
    sys.modules['torch.nn'] = nn
    sys.modules['torch.nn.functional'] = nn_fn


_install_mock_torch()
neutral_prompt_patcheds/test/perp_parser/mock_torch.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal torch mock for parser-only tests (no real tensor math needed).
3
+
4
+ Usage — place this BEFORE any lib_neutral_prompt imports:
5
+
6
+ from test.perp_parser.mock_torch import install
7
+ install()
8
+ from lib_neutral_prompt import neutral_prompt_parser
9
+ """
10
+
11
+ import math
12
+ import sys
13
+ import types
14
+
15
+
16
class _FakeTensor:
    """Just enough of the torch.Tensor surface for the parser AST dataclasses."""

    def __init__(self, data=None):
        self.data = data

    def __getitem__(self, key):
        # Indexing a list-backed tensor slices the payload; any other
        # combination degrades to an empty tensor.
        data_is_list = isinstance(self.data, list)
        key_is_index = isinstance(key, (int, slice))
        if data_is_list and key_is_index:
            return _FakeTensor(self.data[key])
        return _FakeTensor(None)

    def __matmul__(self, other):
        # Matrix multiply is never inspected by parser tests — a marker suffices.
        return _FakeTensor('matmul')

    def __repr__(self):
        return f'FakeTensor({self.data})'
31
+
32
+
33
def install() -> None:
    """Install a minimal torch mock into sys.modules (idempotent).

    No-op when a real torch installation is importable, or when the mock
    from a previous call is already in place.
    """
    # Prefer a real torch when one is importable.
    try:
        import importlib.util
        spec = importlib.util.find_spec('torch')
        if spec is not None and getattr(spec, 'origin', None) is not None:
            return  # real torch is available — use it
    except (ValueError, ModuleNotFoundError):
        # find_spec raises ValueError for a spec-less (mocked) module.
        pass

    # Fix: the previous guard compared sys.modules['torch'] against
    # types.ModuleType, which real torch also satisfies, so it never fired
    # and repeated install() calls rebuilt the mock.  Checking sys.modules
    # directly makes install() a true no-op once mocked.
    if 'torch' in sys.modules:
        return  # already mocked

    torch = types.ModuleType('torch')
    torch.Tensor = _FakeTensor
    torch.float32 = 'float32'
    torch.pi = math.pi
    torch.eye = lambda n, dtype=None: _FakeTensor(list(range(n)))
    torch.tensor = lambda data, dtype=None: _FakeTensor(data)
    torch.vstack = lambda tensors: _FakeTensor([t.data for t in tensors])
    torch.linalg = types.SimpleNamespace(inv=lambda x: _FakeTensor('inv'))

    nn = types.ModuleType('torch.nn')
    nn_fn = types.ModuleType('torch.nn.functional')
    torch.nn = nn
    # Fix: bind the attribute too, so `torch.nn.functional` resolves after a
    # plain `import torch`, not only via a sys.modules lookup.
    nn.functional = nn_fn

    sys.modules['torch'] = torch
    sys.modules['torch.nn'] = nn
    sys.modules['torch.nn.functional'] = nn_fn
neutral_prompt_patcheds/test/perp_parser/test_affine_keyword_order.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for affine-keyword ordering in parse_prompt.
3
+
4
+ Both orderings must work correctly:
5
+ (A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
6
+ (B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
7
+ """
8
+ import unittest
9
+
10
+ from lib_neutral_prompt import neutral_prompt_parser as p
11
+ from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
12
+
13
+
14
class TestAffineKeywordOrder(unittest.TestCase):
    """Both affine keyword orderings must parse identically (see module docstring)."""

    # ------------------------------------------------------------------ #
    # trailing affine (AND_PERP ROTATE[...] text)                        #
    # ------------------------------------------------------------------ #

    def test_trailing_affine_basic(self):
        root = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
        self.assertEqual(len(root.children), 2)
        self.assertEqual(root.children[0].prompt.strip(), "base")
        self.assertIsNone(root.children[0].local_transform)
        self.assertIsNotNone(root.children[1].local_transform)
        self.assertEqual(root.children[1].conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertAlmostEqual(root.children[1].weight, 0.8)

    # ------------------------------------------------------------------ #
    # leading affine (ROTATE[...] AND_PERP text)                         #
    # ------------------------------------------------------------------ #

    def test_leading_affine_mid_prompt(self):
        """ROTATE[...] AND_PERP must NOT be consumed into the previous prompt's text."""
        root = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
        self.assertEqual(len(root.children), 2,
                         f"Expected 2 children, got {len(root.children)}: "
                         f"{[getattr(c,'prompt','composite') for c in root.children]}")
        first, second = root.children
        # first segment = "base " with no transform
        self.assertIn("base", first.prompt)
        self.assertNotIn("ROTATE", first.prompt)
        self.assertIsNone(first.local_transform)
        # second segment gets the transform
        self.assertIsNotNone(second.local_transform)
        self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertAlmostEqual(second.weight, 0.8)

    def test_leading_affine_at_start(self):
        """ROTATE[...] AND_PERP as first (and only non-AND) segment."""
        root = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
        self.assertEqual(len(root.children), 1)
        child = root.children[0]
        self.assertIsNotNone(child.local_transform)
        self.assertEqual(child.conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertIn("vivid", child.prompt)
        self.assertNotIn("ROTATE", child.prompt)

    def test_leading_affine_with_and_salt(self):
        # Same leading-affine rule, exercised with the SALT strategy keyword.
        root = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
        self.assertEqual(len(root.children), 2)
        second = root.children[1]
        self.assertIsNotNone(second.local_transform)
        self.assertEqual(second.conciliation, ConciliationStrategy.SALIENCE_MASK)

    # ------------------------------------------------------------------ #
    # Composed / chained affine                                          #
    # ------------------------------------------------------------------ #

    def test_leading_and_trailing_compose(self):
        """ROTATE[a] AND_PERP SCALE[b,b] text — both affines composed."""
        root = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
        # Just check no crash and all children parse
        self.assertGreaterEqual(len(root.children), 1)

    # ------------------------------------------------------------------ #
    # CFGRescaleFactorSingleton lifecycle                                #
    # ------------------------------------------------------------------ #

    def test_cfg_rescale_singleton_clear(self):
        from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
        S.clear()
        self.assertIsNone(S.get())
        S.set(1.23)
        self.assertAlmostEqual(S.get(), 1.23)
        S.clear()
        self.assertIsNone(S.get())

    def test_cfg_rescale_singleton_thread_local(self):
        # Each worker thread must see only its own value (thread-local store).
        import threading
        from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
        results = {}
        def worker(val, key):
            S.clear()
            S.set(val)
            import time; time.sleep(0.01)
            results[key] = S.get()
        threads = [threading.Thread(target=worker, args=(i * 10.0, i)) for i in range(3)]
        for t in threads: t.start()
        for t in threads: t.join()
        for i in range(3):
            self.assertAlmostEqual(results[i], i * 10.0,
                                   msg=f"Thread {i} got {results[i]} expected {i*10.0}")

    # ------------------------------------------------------------------ #
    # New AND_SALT variants parse correctly                              #
    # ------------------------------------------------------------------ #

    def test_and_salt_wide_parsed(self):
        root = p.parse_root("base AND_SALT_WIDE vivid")
        self.assertEqual(len(root.children), 2)
        self.assertEqual(root.children[1].conciliation,
                         p.ConciliationStrategy.SALIENCE_MASK_WIDE)

    def test_and_salt_blob_parsed(self):
        root = p.parse_root("base AND_SALT_BLOB vivid")
        self.assertEqual(len(root.children), 2)
        self.assertEqual(root.children[1].conciliation,
                         p.ConciliationStrategy.SALIENCE_MASK_BLOB)

    def test_and_align_parsed(self):
        root = p.parse_root("base AND_ALIGN_4_8 vivid")
        self.assertEqual(len(root.children), 2)
        self.assertIn('AND_ALIGN_4_8', root.children[1].conciliation.value)

    def test_and_mask_align_parsed(self):
        root = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
        self.assertEqual(len(root.children), 2)
        self.assertIn('AND_MASK_ALIGN_4_8', root.children[1].conciliation.value)


if __name__ == '__main__':
    unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_affine_pipeline.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration smoke tests for the pre-noise affine pipeline.
3
+
4
+ Requires PyTorch. If torch is not installed the entire module is skipped via
5
+ ``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
6
+
7
+ Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
8
+ modules.shared). torch itself is real — no sys.modules replacement.
9
+
10
+ Tests:
11
+ 1. Hook is a no-op when global_state.is_enabled = False.
12
+ 2. Identity transform (angle=0) leaves x numerically unchanged.
13
+ 3. Non-identity transform calls apply_affine_transform (verified by spy).
14
+ 4. Prompts with no affine leave x unchanged.
15
+ 5. Empty state does not crash.
16
+ 6. Batch with multiple prompts does not crash.
17
+ """
18
+
19
+ import dataclasses
20
+ import importlib
21
+ import math
22
+ import sys
23
+ import types
24
+ import unittest
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Detect torch availability — skip the whole module if absent
28
+ # ---------------------------------------------------------------------------
29
try:
    import importlib.util as _ilu
    _spec = _ilu.find_spec('torch')
    # Real torch only: a spec with a concrete origin (mocks have neither).
    _TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
except (ValueError, ModuleNotFoundError):
    # find_spec raises ValueError for spec-less (mocked) torch modules.
    _TORCH_AVAILABLE = False
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
39
+ # be imported even when running the skip path)
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def _stub(name):
43
+ m = types.ModuleType(name)
44
+ sys.modules.setdefault(name, m)
45
+ return sys.modules[name]
46
+
47
# Stub out just enough of the A1111 "modules" package (and gradio) that
# lib_neutral_prompt can be imported without a running WebUI.
for _mod_name in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
                  'modules.shared', 'modules.prompt_parser', 'gradio'):
    _stub(_mod_name)

import modules.script_callbacks as _sc
# Callback registration points invoked at import time by the hijack modules.
if not hasattr(_sc, 'on_cfg_denoiser'):
    _sc.on_cfg_denoiser = lambda fn: None
if not hasattr(_sc, 'on_script_unloaded'):
    _sc.on_script_unloaded = lambda fn: None

import modules.shared as _sh
if not hasattr(_sh, 'opts'):
    # Minimal opts object: empty settings dict plus no-op registration hooks.
    _sh.opts = types.SimpleNamespace(
        data={},
        add_option=lambda *a, **k: None,
        onchange=lambda *a, **k: None,
        OptionInfo=lambda *a, **k: None,
    )

import modules.sd_samplers as _ss
if not hasattr(_ss, 'cfg_denoiser'):
    _ss.cfg_denoiser = None
# cfg_denoiser_hijack installs a hijack on create_sampler at import time
if not hasattr(_ss, 'create_sampler'):
    _ss.create_sampler = lambda *a, **k: None
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Repo root on sys.path
76
+ # ---------------------------------------------------------------------------
77
+
78
+ from pathlib import Path
79
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
80
+ if str(_REPO_ROOT) not in sys.path:
81
+ sys.path.insert(0, str(_REPO_ROOT))
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Conditional imports (only when torch is available)
86
+ # ---------------------------------------------------------------------------
87
+
88
if _TORCH_AVAILABLE:
    import torch
    from lib_neutral_prompt import global_state, neutral_prompt_parser
    from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
    import lib_neutral_prompt.affine_transform as _at_mod

    # -----------------------------------------------------------------------
    # Helpers (only defined when torch is available: _FakeParams evaluates
    # torch.Tensor in its annotation at class-creation time)
    # -----------------------------------------------------------------------

    def _leaf(text, angle=None):
        """LeafPrompt, optionally with a ROTATE affine (angle in turns)."""
        transform = None
        if angle is not None:
            # 2x3 rotation matrix rows [cos -sin 0; sin cos 0].
            c = math.cos(angle * 2 * math.pi)
            s = math.sin(angle * 2 * math.pi)
            transform = torch.tensor([[c, -s, 0.0], [s, c, 0.0]])
        return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)

    def _comp(*children):
        """CompositePrompt wrapping *children* with weight 1 and no strategy."""
        return neutral_prompt_parser.CompositePrompt(1.0, None, None, list(children))

    @dataclasses.dataclass
    class _FakeParams:
        # Minimal stand-in for the CFG denoiser params object: only .x is read.
        x: torch.Tensor
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Test class
117
+ # ---------------------------------------------------------------------------
118
+
119
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
    """Smoke tests for the pre-noise affine hook (numbered cases per module docstring)."""

    def setUp(self):
        # Fresh, enabled state before every test.
        global_state.is_enabled = True
        global_state.prompt_exprs = []
        global_state.batch_cond_indices = []

    def tearDown(self):
        # Leave global_state disabled and empty so other suites are unaffected.
        global_state.is_enabled = False
        global_state.prompt_exprs = []
        global_state.batch_cond_indices = []

    # 1. disabled → x untouched
    def test_disabled_leaves_x_unchanged(self):
        global_state.is_enabled = False
        x = torch.zeros(2, 4, 8, 8)
        x[0] = 1.0
        orig = x.clone()
        params = _FakeParams(x=x.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
        _on_cfg_denoiser(params)
        self.assertTrue(torch.equal(params.x, orig), "Disabled hook must not modify x")

    # 2. identity transform → x numerically unchanged
    def test_identity_transform_no_change(self):
        x = torch.randn(2, 4, 8, 8)
        orig = x.clone()
        params = _FakeParams(x=x.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
        _on_cfg_denoiser(params)
        self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
                        "Identity affine must not change x")

    # 3. non-identity → apply_affine_transform is called (spy)
    def test_nonidentity_calls_apply_affine_transform(self):
        x = torch.zeros(2, 4, 8, 8)
        x[1] = torch.randn(4, 8, 8)
        params = _FakeParams(x=x.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        # Monkey-patch a recording wrapper around the real implementation.
        _real = _at_mod.apply_affine_transform
        calls = []
        def _spy(tensor, affine, mode='bilinear'):
            calls.append(True)
            return _real(tensor, affine, mode)

        _at_mod.apply_affine_transform = _spy
        try:
            _on_cfg_denoiser(params)
        finally:
            _at_mod.apply_affine_transform = _real

        self.assertGreater(len(calls), 0,
                           "apply_affine_transform must be called for non-identity transform")

    # 4. no affine on any child → x unchanged
    def test_no_affine_no_change(self):
        x = torch.randn(2, 4, 8, 8)
        orig = x.clone()
        params = _FakeParams(x=x.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
        _on_cfg_denoiser(params)
        self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
                        "No affine → x must not change")

    # 5. empty state → no crash
    def test_empty_state_no_crash(self):
        params = _FakeParams(x=torch.randn(1, 4, 8, 8))
        global_state.prompt_exprs = []
        global_state.batch_cond_indices = []
        try:
            _on_cfg_denoiser(params)
        except Exception as e:
            self.fail(f"_on_cfg_denoiser raised with empty state: {e}")

    # 6. two batch entries → no crash
    def test_batch_multiple_prompts_no_crash(self):
        params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
        global_state.prompt_exprs = [
            _comp(_leaf("base_a"), _leaf("perp_a", 0.25)),
            _comp(_leaf("base_b"), _leaf("perp_b", 0.125)),
        ]
        global_state.batch_cond_indices = [
            [(0, 1.0), (1, 1.0)],
            [(2, 1.0), (3, 1.0)],
        ]
        try:
            _on_cfg_denoiser(params)
        except Exception as e:
            self.fail(f"Multi-prompt batch raised: {e}")


if __name__ == '__main__':
    unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_basic_parser.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
class TestPromptParser(unittest.TestCase):
    """Baseline parse_root coverage: AND / AND_PERP / AND_SALT, nesting, weights."""

    def setUp(self):
        # Shared fixtures: one parse per flavour, reused by the assertions below.
        self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
        self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
        self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
        self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
        self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")

    def test_simple_prompt_child_count(self):
        self.assertEqual(len(self.simple_prompt.children), 1)

    def test_simple_prompt_child_weight(self):
        self.assertEqual(self.simple_prompt.children[0].weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")

    def test_square_weight_prompt(self):
        # Square-bracket prompt editing must not be mistaken for a weight suffix.
        prompt = "a [b c d e : f g h :1.5]"
        parsed = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

        composed_prompt = f"{prompt} AND_PERP other prompt"
        parsed = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

    def test_and_prompt_child_count(self):
        self.assertEqual(len(self.and_prompt.children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        self.assertEqual(self.and_prompt.children[0].weight, 1.0)
        self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
        self.assertEqual(self.and_prompt.children[1].weight, 2.0)
        self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")

    def test_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.and_perp_prompt.children), 2)

    def test_and_perp_prompt_child_types(self):
        self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_perp_prompt_nested_child(self):
        nested_child = self.and_perp_prompt.children[1]
        self.assertEqual(nested_child.weight, 2.0)
        self.assertEqual(nested_child.prompt.strip(), "goodbye")

    def test_nested_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_perp_prompt.children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        # Bracketed AND_PERP groups become a CompositePrompt subtree.
        self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        nested_child = self.nested_and_perp_prompt.children[1].children[0]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
        nested_child = self.nested_and_perp_prompt.children[1].children[1]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_perp_prompt_nested_child(self):
        nested_child = self.nested_and_perp_prompt.children[1].children[1]
        self.assertEqual(nested_child.weight, 3.0)
        self.assertEqual(nested_child.prompt.strip(), "welcome")

    def test_invalid_weight_child_count(self):
        self.assertEqual(len(self.invalid_weight.children), 1)

    def test_invalid_weight_child_weight(self):
        # A non-numeric suffix falls back to the default weight of 1.0.
        self.assertEqual(self.invalid_weight.children[0].weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")

    def test_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.and_salt_prompt.children), 2)

    def test_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_salt_prompt_nested_child(self):
        nested_child = self.and_salt_prompt.children[1]
        self.assertEqual(nested_child.weight, 2.0)
        self.assertEqual(nested_child.prompt.strip(), "goodbye")

    def test_nested_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_salt_prompt.children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        nested_child = self.nested_and_salt_prompt.children[1].children[0]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
        nested_child = self.nested_and_salt_prompt.children[1].children[1]
        self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_salt_prompt_nested_child(self):
        nested_child = self.nested_and_salt_prompt.children[1].children[1]
        self.assertEqual(nested_child.weight, 3.0)
        self.assertEqual(nested_child.prompt.strip(), "welcome")

    def test_start_with_prompt_editing(self):
        # A prompt that opens with editing syntax must remain a single leaf.
        prompt = "[(long shot:1.2):0.1] detail.."
        res = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(res.children[0].weight, 1.0)
        self.assertEqual(res.children[0].prompt, prompt)


if __name__ == '__main__':
    unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_lock_after_end.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for Lock after End (Sprint 3 feature).
3
+
4
+ Covers:
5
+ 1. Basic lock: strategy frozen after window closes
6
+ 2. No-lock: strategy reactivates normally without the flag
7
+ 3. Lock not set before window start (haven't entered the window yet)
8
+ 4. Reset at step 0: lock flags cleared each generation
9
+ 5. Per-strategy custom: only the specific strategy gets locked
10
+ 6. Global window + lock interaction
11
+ 7. Infotext round-trip: lock_after_end serialized and restored
12
+ 8. update helper: reset_step_lock_flags clears all flags
13
+ """
14
+
15
+ import unittest
16
+ from lib_neutral_prompt import global_state
17
+ from lib_neutral_prompt.step_utils import (
18
+ StepWindow,
19
+ normalize_progress,
20
+ strategy_is_active,
21
+ strategy_is_active_from_state,
22
+ build_per_strategy_windows,
23
+ )
24
+
25
class _StateGuard:
    """Save global_state's step-window attributes on entry, restore them on exit.

    Used instead of unittest setUp/tearDown so each test can mutate the
    module-level state freely without leaking into other tests.
    """

    # Every module-level attribute the step-window feature touches.
    _FIELDS = (
        'step_window_enabled',
        'step_window_global',
        'step_window_per_strategy',
        'step_window_use_defaults',
        'step_window_lock_after_end',
        '_step_lock_flags',
        '_lock_generation_id',
        'current_step',
        'total_steps',
    )

    def __enter__(self):
        # Attributes that do not exist yet are snapshotted as None so the
        # restore loop in __exit__ never raises.
        snapshot = {}
        for name in self._FIELDS:
            snapshot[name] = getattr(global_state, name, None)
        self._saved = snapshot
        return self

    def __exit__(self, *_):
        for name, value in self._saved.items():
            setattr(global_state, name, value)
        return False  # never swallow exceptions
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # 1. Basic lock behaviour
51
+ # ---------------------------------------------------------------------------
52
+
53
class TestLockAfterEndBasic(unittest.TestCase):
    """Core lock semantics: once a strategy's window has closed, it stays off."""

    def test_lock_freezes_after_window_end(self):
        """
        Window 0.0–0.5 with lock enabled: active at progress 0.3, inactive
        (and flagged as locked) at 0.7, and still inactive if progress later
        falls back inside the window.
        """
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
            global_state.step_window_global = StepWindow(0.0, 0.5)
            global_state.step_window_lock_after_end = True
            global_state._step_lock_flags = {}

            # Within the window the strategy runs.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))

            # Past the window: inactive, and the lock flag is recorded.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # A progress value back inside the window must stay inactive.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.4))

    def test_no_lock_flag_allows_reactivation(self):
        """With lock disabled, the strategy toggles freely with progress."""
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
            global_state.step_window_global = StepWindow(0.0, 0.5)
            global_state.step_window_lock_after_end = False
            global_state._step_lock_flags = {}

            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            # No lock was set, so re-entering the window reactivates.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # 2. Lock not fired before window start
100
+ # ---------------------------------------------------------------------------
101
+
102
class TestLockNotFiredBeforeWindowStart(unittest.TestCase):
    """The lock may only fire after the window END, never before its START."""

    def test_progress_before_window_start_does_not_lock(self):
        """
        Window 0.4–1.0: at progress 0.0 the strategy is merely pending, so no
        lock flag may be set — otherwise the strategy could never activate.
        """
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
            global_state.step_window_global = StepWindow(0.4, 1.0)
            global_state.step_window_lock_after_end = True
            global_state._step_lock_flags = {}

            # Before the window opens: inactive but NOT locked.
            self.assertFalse(strategy_is_active_from_state('SALIENCE_MASK', 0.0))
            self.assertFalse(
                global_state._step_lock_flags.get('SALIENCE_MASK', False),
                'Lock must not fire before window start')

            # Entering the window afterwards must still activate it.
            self.assertTrue(strategy_is_active_from_state('SALIENCE_MASK', 0.5))
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # 3. reset_step_lock_flags and begin_new_generation
131
+ # ---------------------------------------------------------------------------
132
+
133
class TestLockLifecycle(unittest.TestCase):
    """reset_step_lock_flags() / begin_new_generation() and restart safety."""

    @staticmethod
    def _arm_window():
        """Configure a 0.0–0.5 global window with lock enabled, flags wiped."""
        global_state.step_window_enabled = True
        global_state.step_window_lock_after_end = True
        global_state.step_window_global = StepWindow(0.0, 0.5)
        global_state.step_window_use_defaults = False
        global_state.step_window_per_strategy = None
        global_state._step_lock_flags = {}

    def test_reset_clears_all_flags(self):
        with _StateGuard():
            global_state._step_lock_flags = {'PERPENDICULAR': True, 'SALIENCE_MASK': True}
            global_state.reset_step_lock_flags()
            self.assertEqual(global_state._step_lock_flags, {})

    def test_reset_idempotent_on_empty(self):
        with _StateGuard():
            global_state._step_lock_flags = {}
            global_state.reset_step_lock_flags()
            self.assertEqual(global_state._step_lock_flags, {})

    def test_begin_new_generation_clears_flags(self):
        with _StateGuard():
            global_state._step_lock_flags = {'PERPENDICULAR': True}
            global_state.begin_new_generation()
            self.assertEqual(global_state._step_lock_flags, {})

    def test_begin_new_generation_increments_id(self):
        with _StateGuard():
            start = global_state._lock_generation_id
            # Each call bumps the id by exactly one.
            for bump in (1, 2):
                global_state.begin_new_generation()
                self.assertEqual(global_state._lock_generation_id, start + bump)

    def test_strategy_active_after_begin_new_generation(self):
        """A previously locked strategy re-arms once a new generation begins."""
        with _StateGuard():
            self._arm_window()

            # Drive the strategy into the locked state.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # New generation → active again inside the window.
            global_state.begin_new_generation()
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))

    # ── KEY TEST: restart / rollback safety ───────────────────────────────

    def test_step_zero_revisit_does_not_clear_lock(self):
        """
        Sampler-internal rollback to step 0 (Restart sampler, DPM++ multi-pass)
        must NOT clear lock flags. Only begin_new_generation() (= process()
        hook) should do that.
        """
        with _StateGuard():
            self._arm_window()

            # Normal run: enter the window, leave it, get locked.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # Simulate the rollback: the step counter returns to 0.
            global_state.update_step_state(0, 20)
            rollback_progress = normalize_progress(
                global_state.current_step, global_state.total_steps)

            # Without a fresh process() call the lock must hold.
            self.assertFalse(
                strategy_is_active_from_state('PERPENDICULAR', rollback_progress),
                'Lock must survive sampler-internal rollback to step 0')
            self.assertTrue(
                global_state._step_lock_flags.get('PERPENDICULAR', False),
                'Flag must remain set after step-0 rollback')

    def test_hires_fix_second_pass_clears_lock(self):
        """
        Hires fix invokes process() again for its second denoising pass;
        simulating that with begin_new_generation() must re-arm strategies.
        """
        with _StateGuard():
            self._arm_window()

            # First pass locks the strategy.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))

            # Second pass begins with a process() call.
            global_state.begin_new_generation()
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
237
+
238
+
239
+ # ---------------------------------------------------------------------------
240
+ # 4. Per-strategy: only the locked strategy is frozen
241
+ # ---------------------------------------------------------------------------
242
+
243
class TestLockPerStrategy(unittest.TestCase):
    """Locks are tracked per strategy — one lock never freezes the others."""

    def test_only_locked_strategy_frozen(self):
        """
        Custom windows AND_PERP 0.0–0.4 and AND_SALT 0.5–1.0: at progress 0.6
        AND_PERP's window has closed (it locks) while AND_SALT keeps running.
        """
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_lock_after_end = True
            global_state.step_window_global = None
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = build_per_strategy_windows(
                {'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)})
            global_state._step_lock_flags = {}

            # Each strategy is active inside its own window.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.2))
            self.assertTrue(strategy_is_active_from_state('SALIENCE_MASK', 0.6))

            # AND_PERP's window has ended → its lock fires.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.6))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # AND_SALT is unaffected.
            self.assertFalse(global_state._step_lock_flags.get('SALIENCE_MASK', False))
            self.assertTrue(strategy_is_active_from_state('SALIENCE_MASK', 0.8))

    def test_defaults_mode_with_lock(self):
        """
        use_defaults=True: PERPENDICULAR's built-in window is 0.0–0.5, so at
        progress 0.7 it must lock (window ended and we were past its start).
        """
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_lock_after_end = True
            global_state.step_window_global = None
            global_state.step_window_use_defaults = True
            global_state.step_window_per_strategy = None
            global_state._step_lock_flags = {}

            # Active inside the default window.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
            # Outside → lock fires.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
291
+
292
+
293
+ # ---------------------------------------------------------------------------
294
+ # 5. Infotext round-trip
295
+ # ---------------------------------------------------------------------------
296
+
297
class TestLockAfterEndInfotext(unittest.TestCase):
    """lock_after_end must round-trip through pasted generation infotext."""

    def test_apply_infotext_restores_lock(self):
        with _StateGuard():
            global_state.step_window_lock_after_end = False
            pasted = {
                'NP Step Window Enabled': 'True',
                'NP Step Window Mode': 'global',
                'NP Step Window Start': '0.0',
                'NP Step Window End': '0.5',
                'NP Step Window Lock': 'True',
            }
            global_state.apply_infotext(pasted)
            self.assertTrue(global_state.step_window_lock_after_end)

    def test_apply_infotext_disables_lock(self):
        with _StateGuard():
            global_state.step_window_lock_after_end = True
            global_state.apply_infotext({'NP Step Window Lock': 'False'})
            self.assertFalse(global_state.step_window_lock_after_end)

    def test_apply_infotext_lock_absent_leaves_state_unchanged(self):
        with _StateGuard():
            global_state.step_window_lock_after_end = True
            # The lock key is absent from the pasted dict → value untouched.
            global_state.apply_infotext({'NP Step Window Enabled': 'True'})
            self.assertTrue(global_state.step_window_lock_after_end)

    def test_lock_key_in_infotext_constants(self):
        self.assertEqual(global_state.INFOTEXT_KEY_STEP_WIN_LOCK, 'NP Step Window Lock')
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # 6. Global simulation: step 0→N progression with lock
330
+ # ---------------------------------------------------------------------------
331
+
332
class TestLockSimulatedRun(unittest.TestCase):
    """
    Simulate a real 10-step run with lock_after_end=True and window 0.0–0.5.
    Expected: active for steps 0–4, locked from step 5 onward.
    """

    def test_simulated_10_step_run(self):
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_lock_after_end = True
            global_state.step_window_global = StepWindow(0.0, 0.5)
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
            global_state._step_lock_flags = {}

            # Drive the full run, recording per-step activity.
            activity = []
            for step in range(10):
                global_state.update_step_state(step, 10)
                progress = normalize_progress(
                    global_state.current_step, global_state.total_steps)
                activity.append(
                    strategy_is_active_from_state('PERPENDICULAR', progress))

            # Steps 0–4 fall inside window 0.0–0.5; step 5 onward is locked.
            for step, active in enumerate(activity):
                if step < 5:
                    self.assertTrue(active, f'Expected active at step {step}')
                else:
                    self.assertFalse(active, f'Expected locked at step {step}')
360
+
361
+
362
+ # ---------------------------------------------------------------------------
363
+ # 7. Restart-within-same-generation must NOT clear lock flags
364
+ # ---------------------------------------------------------------------------
365
+
366
class TestLockSurvivesRestartWithinRun(unittest.TestCase):
    """
    Key property: if the sampler revisits step 0 within one generation
    (e.g. Restart sampler, DPM++ SDE second pass, hires fix denoiser),
    the lock flags must be preserved.

    This is guaranteed by begin_new_generation() being called from the
    process() lifecycle, NOT from step == 0 in the denoising loop.
    """

    @staticmethod
    def _install_window(clear_flags):
        """Global window 0.0–0.5 with lock enabled; optionally wipe flags."""
        global_state.step_window_enabled = True
        global_state.step_window_lock_after_end = True
        global_state.step_window_global = StepWindow(0.0, 0.5)
        global_state.step_window_use_defaults = False
        global_state.step_window_per_strategy = None
        if clear_flags:
            global_state._step_lock_flags = {}

    def test_begin_new_generation_called_once_then_step_reset(self):
        """
        Generation starts → begin_new_generation() → strategy locks →
        sampler restarts (step resets to 0 internally) → lock survives.
        """
        with _StateGuard():
            self._install_window(clear_flags=False)

            # process() entry point for this generation.
            global_state.begin_new_generation()

            # Inside window → active; after the window → locked.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(
                global_state._step_lock_flags.get('PERPENDICULAR', False),
                'Lock flag must be set after window end')

            # Intra-run restart: the step resets, but process() is NOT
            # re-invoked for a restart within one generation.
            global_state.update_step_state(0, 20)

            self.assertFalse(
                strategy_is_active_from_state('PERPENDICULAR', 0.0),
                'Strategy must remain locked after intra-run step reset')
            self.assertTrue(
                global_state._step_lock_flags.get('PERPENDICULAR', False),
                'Lock flag must persist after intra-run step reset')

    def test_ensure_step_lock_run_initialized_same_token_no_reset(self):
        """
        ensure_step_lock_run_initialized with the SAME token (sampler restart
        inside one generation) must NOT reset the flags.
        """
        with _StateGuard():
            self._install_window(clear_flags=True)

            run_token = object()
            global_state.ensure_step_lock_run_initialized(run_token)

            # Active inside the window, then locked when it closes.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.2))
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # Same token again → no reset, the lock still holds.
            did_reset = global_state.ensure_step_lock_run_initialized(run_token)
            self.assertFalse(did_reset, 'Same token must not trigger reset')
            self.assertTrue(
                global_state._step_lock_flags.get('PERPENDICULAR', False),
                'Lock flag must survive same-token re-call')
            self.assertFalse(
                strategy_is_active_from_state('PERPENDICULAR', 0.3),
                'Strategy must stay locked within same run')

    def test_ensure_step_lock_run_initialized_new_token_resets(self):
        """
        ensure_step_lock_run_initialized with a NEW token must reset the
        flags (simulates a genuinely new generation).
        """
        with _StateGuard():
            self._install_window(clear_flags=True)

            first_run = object()
            second_run = object()

            global_state.ensure_step_lock_run_initialized(first_run)

            # Drive the lock: activate, then leave the window.
            strategy_is_active_from_state('PERPENDICULAR', 0.2)  # activate
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # New generation: new token → everything re-arms.
            did_reset = global_state.ensure_step_lock_run_initialized(second_run)
            self.assertTrue(did_reset, 'New token must trigger reset')
            self.assertFalse(
                global_state._step_lock_flags.get('PERPENDICULAR', False),
                'Lock flags must be cleared for new generation')
            self.assertTrue(
                strategy_is_active_from_state('PERPENDICULAR', 0.3),
                'Strategy must be active again after new generation reset')
474
+
475
+
476
+ # ---------------------------------------------------------------------------
477
+ # 8. begin_new_generation() vs reset_step_lock_flags() contract
478
+ # ---------------------------------------------------------------------------
479
+
480
class TestBeginNewGenerationContract(unittest.TestCase):
    """
    Contract split: begin_new_generation() bumps the run id AND clears the
    lock flags; reset_step_lock_flags() clears flags only.
    """

    def test_begin_new_generation_increments_id(self):
        with _StateGuard():
            previous = global_state._lock_generation_id
            global_state.begin_new_generation()
            self.assertGreater(global_state._lock_generation_id, previous)

    def test_begin_new_generation_clears_flags(self):
        with _StateGuard():
            global_state._step_lock_flags = {'PERPENDICULAR': True, 'SALIENCE_MASK': True}
            global_state.begin_new_generation()
            self.assertEqual(global_state._step_lock_flags, {})

    def test_begin_new_generation_twice_clears_flags_both_times(self):
        """Back-to-back calls each start from a clean flag dict."""
        with _StateGuard():
            global_state.begin_new_generation()
            global_state._step_lock_flags['PERPENDICULAR'] = True
            global_state.begin_new_generation()
            self.assertEqual(global_state._step_lock_flags, {})

    def test_reset_step_lock_flags_does_not_change_generation_id(self):
        with _StateGuard():
            unchanged_id = global_state._lock_generation_id
            global_state.reset_step_lock_flags()
            self.assertEqual(global_state._lock_generation_id, unchanged_id)

    def test_process_lifecycle_integration(self):
        """
        Full lifecycle as wired in scripts/neutral_prompt.py: one
        begin_new_generation() per process(); locks persist through the
        denoising loop; the next begin_new_generation() clears them.
        """
        with _StateGuard():
            global_state.step_window_enabled = True
            global_state.step_window_lock_after_end = True
            global_state.step_window_global = StepWindow(0.0, 0.5)
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None

            # --- Generation 1: process() called once ---
            global_state.begin_new_generation()

            strategy_is_active_from_state('PERPENDICULAR', 0.2)  # activate
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
            self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))

            # No new process() call → the lock survives any step count.
            self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.0))

            # --- Generation 2: new process() call ---
            global_state.begin_new_generation()
            self.assertEqual(global_state._step_lock_flags.copy(), {},
                             'All flags must be cleared at start of new generation')

            # Strategy is active again inside its window.
            self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
540
+
541
+
542
+ if __name__ == '__main__':
543
+ unittest.main()
544
+
neutral_prompt_patcheds/test/perp_parser/test_malicious_parser.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
class TestMaliciousPromptParser(unittest.TestCase):
    """Adversarial / fuzz-style inputs: parse_root must degrade gracefully."""

    def setUp(self):
        self.parser = neutral_prompt_parser

    def _assert_parses(self, prompt, failure_message):
        """Parse *prompt*, converting any exception into a test failure."""
        try:
            self.parser.parse_root(prompt)
        except Exception:
            self.fail(failure_message)

    def test_empty(self):
        root = self.parser.parse_root("")
        self.assertEqual(root.children[0].prompt, "")
        self.assertEqual(root.children[0].weight, 1.0)

    def test_zero_weight(self):
        root = self.parser.parse_root("hello :0.0")
        self.assertEqual(root.children[0].weight, 0.0)

    def test_mixed_positive_and_negative_weights(self):
        root = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[1].weight, -2.0)

    def test_debalanced_square_brackets(self):
        # A lone unbalanced bracket repeated many times stays one leaf prompt.
        for bracket in ('[', ']'):
            prompt = f"a {bracket} b " * 100
            root = self.parser.parse_root(prompt)
            self.assertEqual(root.children[0].prompt, prompt)

        repeats = 10

        # Extra opening brackets collapse into the leaf prompts.
        root = self.parser.parse_root("a [ [ b AND c ] " * repeats)
        self.assertEqual(
            [child.prompt for child in root.children],
            ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])

        # Extra closing brackets behave symmetrically.
        root = self.parser.parse_root("a [ b AND c ] ] " * repeats)
        self.assertEqual(
            [child.prompt for child in root.children],
            ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])

    def test_erroneous_syntax(self):
        root = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[1].prompt, "[goodbye ")
        self.assertEqual(root.children[1].weight, 2.0)

        root = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[1].prompt, " goodbye ")

        root = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
        self.assertEqual(root.children[1].prompt, " goodbye]")
        self.assertEqual(root.children[1].weight, 2.0)

        root = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
        self.assertEqual(root.children[1].weight, 2.0)
        self.assertEqual(root.children[1].prompt, " a [ goodbye ")

        root = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[2].prompt, " goodbye ")

    def test_huge_number_of_prompt_parts(self):
        count = 10 ** 4
        joined = " AND ".join(f"hello{i} :{i}" for i in range(count))
        root = self.parser.parse_root(joined)
        self.assertEqual(len(root.children), count)

    def test_prompt_ending_with_weight(self):
        root = self.parser.parse_root("hello :1.0 AND :2.0")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[1].prompt, "")
        self.assertEqual(root.children[1].weight, 2.0)

    def test_huge_input_string(self):
        repeats = 10 ** 4
        root = self.parser.parse_root("hello :1.0 AND " * repeats)
        self.assertEqual(len(root.children), repeats + 1)

    def test_deeply_nested_prompt(self):
        prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
        root = self.parser.parse_root(prompt)
        self.assertIsInstance(root.children[1], neutral_prompt_parser.CompositePrompt)

    def test_complex_nested_prompts(self):
        root = self.parser.parse_root(
            "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]")
        self.assertEqual(root.children[0].weight, 1.0)
        self.assertEqual(root.children[1].weight, 2.0)
        nested = root.children[2]
        self.assertEqual(nested.children[0].weight, 3.0)
        self.assertEqual(nested.children[1].weight, 4.0)
        self.assertEqual(nested.children[2].weight, 5.0)

    def test_string_with_random_characters(self):
        self._assert_parses(
            "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890",
            "parse_root couldn't handle a string with random characters.")

    def test_string_with_unexpected_symbols(self):
        self._assert_parses(
            "hello :1.0 AND $%^&*()goodbye :2.0",
            "parse_root couldn't handle a string with unexpected symbols.")

    def test_string_with_unconventional_structure(self):
        self._assert_parses(
            "hello :1.0 AND_PERP :2.0 AND [goodbye]",
            "parse_root couldn't handle a string with unconventional structure.")

    def test_string_with_mixed_alphabets_and_numbers(self):
        self._assert_parses(
            "123hello :1.0 AND goodbye456 :2.0",
            "parse_root couldn't handle a string with mixed alphabets and numbers.")

    def test_string_with_nested_brackets(self):
        self._assert_parses(
            "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]",
            "parse_root couldn't handle a string with nested brackets.")

    def test_unmatched_opening_braces(self):
        self._assert_parses(
            "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0",
            "parse_root couldn't handle a string with unmatched opening braces.")

    def test_unmatched_closing_braces(self):
        self._assert_parses(
            "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0",
            "parse_root couldn't handle a string with unmatched closing braces.")

    def test_repeating_colons(self):
        self._assert_parses(
            "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0",
            "parse_root couldn't handle a string with repeating colons.")

    def test_excessive_whitespace(self):
        self._assert_parses(
            "hello :1.0 AND_PERP goodbye :2.0",
            "parse_root couldn't handle a string with excessive whitespace.")

    def test_repeating_AND_keyword(self):
        self._assert_parses(
            "hello :1.0 AND AND AND AND AND goodbye :2.0",
            "parse_root couldn't handle a string with repeating AND keyword.")

    def test_repeating_AND_PERP_keyword(self):
        self._assert_parses(
            "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0",
            "parse_root couldn't handle a string with repeating AND_PERP keyword.")

    def test_square_weight_prompt(self):
        self._assert_parses(
            "AND_PERP [weighted] you thought it was the end",
            "parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
179
+
180
+
181
+ if __name__ == '__main__':
182
+ unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_matryoshka.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for v1.1 matryoshka tree render, diagnostics, and builder helpers.
3
+ All tests are pure-Python and require no real torch or gradio.
4
+ """
5
+
6
+ import unittest
7
+ # mock_torch is installed by the package __init__ before this file loads
8
+
9
+ import lib_neutral_prompt.neutral_prompt_parser as p
10
+ from lib_neutral_prompt.matryoshka_utils import (
11
+ render_prompt_node,
12
+ render_prompt_tree,
13
+ collect_prompt_diagnostics,
14
+ render_diagnostics,
15
+ render_full_explain,
16
+ build_child_block,
17
+ build_nested_prompt,
18
+ MATRYOSHKA_TEMPLATES,
19
+ _build_affine_snippet,
20
+ )
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # render_prompt_tree
25
+ # ---------------------------------------------------------------------------
26
+
27
+ class TestRenderPromptTree(unittest.TestCase):
28
+
29
+ def test_empty_prompt_returns_hint(self):
30
+ out = render_prompt_tree('')
31
+ self.assertIn('empty', out.lower())
32
+
33
+ def test_base_only(self):
34
+ out = render_prompt_tree('a dog in a park')
35
+ self.assertIn('ROOT', out)
36
+ self.assertIn('BASE', out)
37
+ self.assertIn('a dog in a park', out)
38
+
39
+ def test_single_child(self):
40
+ out = render_prompt_tree('base subject AND_SALT[5] texture :0.8')
41
+ self.assertIn('BASE', out)
42
+ self.assertIn('SALIENCE_MASK', out)
43
+ self.assertIn('k=5', out)
44
+
45
+ def test_topk_child(self):
46
+ out = render_prompt_tree('base AND_TOPK[0.1] highlights :0.6')
47
+ self.assertIn('SEMANTIC_GUIDANCE', out)
48
+ self.assertIn('threshold=0.1', out)
49
+
50
+ def test_align_child(self):
51
+ out = render_prompt_tree('base AND_ALIGN[4,8] style :0.7')
52
+ self.assertIn('ALIGNMENT_BLEND_CUSTOM', out)
53
+ self.assertIn('d=4', out)
54
+ self.assertIn('s=8', out)
55
+
56
+ def test_affine_annotated(self):
57
+ out = render_prompt_tree('base ROTATE[0.125] AND_PERP vivid :0.8')
58
+ self.assertIn('affine', out)
59
+
60
+ def test_multiple_children_all_shown(self):
61
+ out = render_prompt_tree('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] style :0.6')
62
+ self.assertIn('SALIENCE_MASK', out)
63
+ self.assertIn('SEMANTIC_GUIDANCE', out)
64
+
65
+ def test_nested_prompt_depth(self):
66
+ # nested composite should show multiple levels
67
+ prompt = 'base AND_SALT[5] [ inner text :0.8 AND_TOPK[0.1] detail :0.5 ] :0.7'
68
+ out = render_prompt_tree(prompt)
69
+ self.assertIn('SALIENCE_MASK', out)
70
+ # inner structure rendered
71
+
72
+ def test_tree_connectors_present(self):
73
+ out = render_prompt_tree('base AND_SALT[5] tex :0.8 AND_PERP vivid :0.5')
74
+ # should contain box-drawing branch chars
75
+ self.assertTrue('├' in out or '└' in out)
76
+
77
+ def test_weight_shown(self):
78
+ out = render_prompt_tree('base AND_PERP style :0.42')
79
+ self.assertIn('0.42', out)
80
+
81
+ def test_parse_error_graceful(self):
82
+ # Pathological input must not raise
83
+ try:
84
+ out = render_prompt_tree('[[[[[')
85
+ except Exception as exc:
86
+ self.fail(f'render_prompt_tree raised on bad input: {exc}')
87
+
88
+ def test_deeply_nested(self):
89
+ prompt = ('base AND_SALT[5] [\n'
90
+ ' region :0.8\n'
91
+ ' AND_ALIGN[4,8] [\n'
92
+ ' sub-structure :0.7\n'
93
+ ' AND_TOPK[0.05] micro :0.4\n'
94
+ ' ] :0.6\n'
95
+ '] :0.7')
96
+ try:
97
+ out = render_prompt_tree(prompt)
98
+ except Exception as exc:
99
+ self.fail(f'deeply nested raised: {exc}')
100
+ self.assertIn('ROOT', out)
101
+
102
+
103
+ # ---------------------------------------------------------------------------
104
+ # render_prompt_node (unit)
105
+ # ---------------------------------------------------------------------------
106
+
107
+ class TestRenderPromptNode(unittest.TestCase):
108
+
109
+ def _leaf_node(self, text='hello', strat=None, params=None, w=1.0):
110
+ return p.LeafPrompt(
111
+ weight=w,
112
+ conciliation=strat,
113
+ local_transform=None,
114
+ prompt=text,
115
+ conciliation_params=params or {},
116
+ )
117
+
118
+ def test_base_leaf(self):
119
+ node = self._leaf_node('base text')
120
+ lines = render_prompt_node(node)
121
+ self.assertEqual(len(lines), 1)
122
+ self.assertIn('BASE', lines[0])
123
+ self.assertIn('base text', lines[0])
124
+
125
+ def test_leaf_with_params(self):
126
+ node = self._leaf_node('texture', strat=p.ConciliationStrategy.SALIENCE_MASK,
127
+ params={'k': 5.0})
128
+ lines = render_prompt_node(node)
129
+ self.assertIn('k=5.0', lines[0])
130
+
131
+ def test_leaf_affine_flag(self):
132
+ import lib_neutral_prompt.neutral_prompt_parser as _pp
133
+ node = self._leaf_node('vivid', strat=_pp.ConciliationStrategy.PERPENDICULAR)
134
+ try:
135
+ import torch
136
+ node.local_transform = torch.eye(3)[:-1]
137
+ except ImportError:
138
+ return # skip if no torch
139
+ lines = render_prompt_node(node)
140
+ self.assertIn('affine', lines[0])
141
+
142
+ def test_connector_chars_for_non_last(self):
143
+ node = self._leaf_node('x')
144
+ lines = render_prompt_node(node, is_last=False)
145
+ self.assertTrue(lines[0].startswith('├─'))
146
+
147
+ def test_connector_chars_for_last(self):
148
+ node = self._leaf_node('x')
149
+ lines = render_prompt_node(node, is_last=True)
150
+ self.assertTrue(lines[0].startswith('└─'))
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # collect_prompt_diagnostics
155
+ # ---------------------------------------------------------------------------
156
+
157
+ class TestCollectDiagnostics(unittest.TestCase):
158
+
159
+ def test_empty_prompt(self):
160
+ d = collect_prompt_diagnostics('')
161
+ self.assertFalse(d['parse_ok'])
162
+
163
+ def test_base_only(self):
164
+ d = collect_prompt_diagnostics('a dog')
165
+ self.assertTrue(d['parse_ok'])
166
+ self.assertGreater(d['n_segments'], 0)
167
+ self.assertIn('BASE', d['strategies'])
168
+
169
+ def test_single_salt(self):
170
+ d = collect_prompt_diagnostics('base AND_SALT[5] texture :0.8')
171
+ self.assertTrue(d['parse_ok'])
172
+ self.assertIn('SALIENCE_MASK', d['strategies'])
173
+ self.assertIn('BASE', d['strategies'])
174
+
175
+ def test_no_base_fires_protection(self):
176
+ d = collect_prompt_diagnostics('AND_SALT[5] concept :0.8')
177
+ self.assertTrue(d['parse_ok'])
178
+ self.assertTrue(d['protection_would_fire'])
179
+
180
+ def test_with_base_no_fire(self):
181
+ d = collect_prompt_diagnostics('base subject AND_PERP style :0.8')
182
+ self.assertFalse(d['protection_would_fire'])
183
+
184
+ def test_max_depth_base(self):
185
+ d = collect_prompt_diagnostics('just a base')
186
+ self.assertGreaterEqual(d['max_depth'], 0)
187
+
188
+ def test_affine_counted(self):
189
+ d = collect_prompt_diagnostics('base ROTATE[0.125] AND_PERP vivid :0.8')
190
+ self.assertGreater(d['n_affine'], 0)
191
+
192
+ def test_multiple_strategies(self):
193
+ d = collect_prompt_diagnostics('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] sty :0.6')
194
+ self.assertIn('SALIENCE_MASK', d['strategies'])
195
+ self.assertIn('SEMANTIC_GUIDANCE', d['strategies'])
196
+
197
+ def test_parse_error_graceful(self):
198
+ d = collect_prompt_diagnostics('AND[invalid garbage')
199
+ # Should not raise; parse_ok may be False or True depending on parser tolerance
200
+
201
+
202
+ # ---------------------------------------------------------------------------
203
+ # render_diagnostics
204
+ # ---------------------------------------------------------------------------
205
+
206
+ class TestRenderDiagnostics(unittest.TestCase):
207
+
208
+ def test_no_base_warning(self):
209
+ d = collect_prompt_diagnostics('AND_SALT[5] only :0.8')
210
+ out = render_diagnostics(d)
211
+ self.assertIn('WOULD FIRE', out)
212
+
213
+ def test_ok_case(self):
214
+ d = collect_prompt_diagnostics('base AND_PERP style :0.8')
215
+ out = render_diagnostics(d)
216
+ self.assertIn('OK', out)
217
+
218
+ def test_parse_error_shown(self):
219
+ d = {'parse_ok': False, 'parse_error': 'something bad'}
220
+ out = render_diagnostics(d)
221
+ self.assertIn('something bad', out)
222
+
223
+ def test_strategies_listed(self):
224
+ d = collect_prompt_diagnostics('base AND_SALT[5] tex :0.8')
225
+ out = render_diagnostics(d)
226
+ self.assertIn('SALIENCE_MASK', out)
227
+
228
+
229
+ # ---------------------------------------------------------------------------
230
+ # render_full_explain
231
+ # ---------------------------------------------------------------------------
232
+
233
+ class TestRenderFullExplain(unittest.TestCase):
234
+
235
+ def test_contains_both_tree_and_diagnostics(self):
236
+ out = render_full_explain('base AND_SALT[5] texture :0.8')
237
+ self.assertIn('ROOT', out)
238
+ self.assertIn('Diagnostics', out)
239
+
240
+ def test_empty_prompt(self):
241
+ out = render_full_explain('')
242
+ self.assertIn('empty', out.lower())
243
+
244
+ def test_nested_prompt(self):
245
+ try:
246
+ out = render_full_explain(
247
+ 'base AND_SALT[5] [ inner :0.8 AND_TOPK[0.05] detail :0.4 ] :0.7'
248
+ )
249
+ except Exception as exc:
250
+ self.fail(f'render_full_explain raised on nested: {exc}')
251
+ self.assertIn('ROOT', out)
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # build_child_block
256
+ # ---------------------------------------------------------------------------
257
+
258
+ class TestBuildChildBlock(unittest.TestCase):
259
+
260
+ def test_simple_salt(self):
261
+ out = build_child_block('AND_SALT', 'texture detail', 0.8)
262
+ self.assertIn('AND_SALT[5]', out)
263
+ self.assertIn('texture detail', out)
264
+ self.assertIn(':0.8', out)
265
+
266
+ def test_topk_threshold(self):
267
+ out = build_child_block('AND_TOPK', 'highlights', 0.5)
268
+ self.assertIn('AND_TOPK[0.05]', out)
269
+ self.assertIn('highlights', out)
270
+
271
+ def test_align_default(self):
272
+ out = build_child_block('AND_ALIGN', 'structure', 0.7)
273
+ self.assertIn('AND_ALIGN[4,8]', out)
274
+
275
+ def test_perp_simple(self):
276
+ out = build_child_block('AND_PERP', 'blur', 0.5)
277
+ self.assertIn('AND_PERP', out)
278
+ self.assertIn('blur', out)
279
+
280
+ def test_with_affine(self):
281
+ out = build_child_block('AND_PERP', 'mirror', 0.6, affine='SCALE[-1,1]')
282
+ self.assertIn('SCALE[-1,1]', out)
283
+ self.assertTrue(out.startswith('SCALE[-1,1]'))
284
+
285
+ def test_with_nested_child(self):
286
+ nested = build_child_block('AND_TOPK', 'highlights', 0.4)
287
+ out = build_child_block('AND_SALT', 'texture', 0.8, nested=nested)
288
+ self.assertIn('[', out)
289
+ self.assertIn(']', out)
290
+ self.assertIn('AND_TOPK', out)
291
+
292
+ def test_empty_text_handled(self):
293
+ out = build_child_block('AND_SALT', '', 0.8)
294
+ # Empty text → empty string (caller should skip)
295
+ self.assertEqual(out, '')
296
+
297
+ def test_weight_formatting(self):
298
+ out = build_child_block('AND_PERP', 'concept', 0.35)
299
+ self.assertIn(':0.35', out)
300
+
301
+ def test_nested_block_indented(self):
302
+ nested = build_child_block('AND_PERP', 'inner', 0.4)
303
+ out = build_child_block('AND_SALT', 'outer', 0.8, nested=nested)
304
+ # Inner content should be indented
305
+ lines = out.splitlines()
306
+ inner_lines = [l for l in lines if 'inner' in l]
307
+ self.assertTrue(any(l.startswith(' ') for l in inner_lines),
308
+ 'Inner nested line should be indented')
309
+
310
+
311
+ # ---------------------------------------------------------------------------
312
+ # build_nested_prompt
313
+ # ---------------------------------------------------------------------------
314
+
315
+ class TestBuildNestedPrompt(unittest.TestCase):
316
+
317
+ def test_base_only(self):
318
+ out = build_nested_prompt('a dog', [])
319
+ self.assertEqual(out.strip(), 'a dog')
320
+
321
+ def test_base_with_one_child(self):
322
+ out = build_nested_prompt('base subject', [
323
+ {'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8},
324
+ ])
325
+ self.assertIn('base subject', out)
326
+ self.assertIn('AND_SALT', out)
327
+ self.assertIn('texture', out)
328
+
329
+ def test_two_children(self):
330
+ out = build_nested_prompt('base', [
331
+ {'strategy': 'AND_SALT', 'text': 'tex', 'weight': 0.8},
332
+ {'strategy': 'AND_TOPK', 'text': 'sty', 'weight': 0.5},
333
+ ])
334
+ self.assertIn('AND_SALT', out)
335
+ self.assertIn('AND_TOPK', out)
336
+
337
+ def test_child_with_nested(self):
338
+ inner = build_child_block('AND_TOPK', 'highlights', 0.4)
339
+ out = build_nested_prompt('base', [
340
+ {'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8, 'nested': inner},
341
+ ])
342
+ self.assertIn('AND_TOPK', out)
343
+
344
+ def test_generated_prompt_parseable(self):
345
+ out = build_nested_prompt('base subject', [
346
+ {'strategy': 'AND_SALT', 'text': 'tex', 'weight': 0.8},
347
+ {'strategy': 'AND_PERP', 'text': 'sty', 'weight': 0.5},
348
+ ])
349
+ try:
350
+ p.parse_root(out)
351
+ except Exception as exc:
352
+ self.fail(f'build_nested_prompt produced unparseable prompt: {exc}')
353
+
354
+ def test_empty_base_no_crash(self):
355
+ out = build_nested_prompt('', [
356
+ {'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8},
357
+ ])
358
+ self.assertIn('AND_SALT', out)
359
+
360
+ def test_child_with_empty_text_skipped(self):
361
+ out = build_nested_prompt('base', [
362
+ {'strategy': 'AND_SALT', 'text': '', 'weight': 0.8},
363
+ ])
364
+ # Empty text child: no crash, base survives
365
+ self.assertIn('base', out)
366
+
367
+
368
+ # ---------------------------------------------------------------------------
369
+ # MATRYOSHKA_TEMPLATES
370
+ # ---------------------------------------------------------------------------
371
+
372
+ class TestMatryoshkaTemplates(unittest.TestCase):
373
+
374
+ def test_all_templates_have_required_keys(self):
375
+ for name, tmpl in MATRYOSHKA_TEMPLATES.items():
376
+ with self.subTest(template=name):
377
+ self.assertIn('description', tmpl, f'{name}: missing "description"')
378
+ self.assertIn('prompt', tmpl, f'{name}: missing "prompt"')
379
+ self.assertIsInstance(tmpl['description'], str)
380
+ self.assertIsInstance(tmpl['prompt'], str)
381
+
382
+ def test_all_templates_parse_without_crash(self):
383
+ for name, tmpl in MATRYOSHKA_TEMPLATES.items():
384
+ with self.subTest(template=name):
385
+ try:
386
+ p.parse_root(tmpl['prompt'])
387
+ except Exception as exc:
388
+ self.fail(f'Template "{name}" failed parse_root: {exc}')
389
+
390
+ def test_all_template_prompts_nonempty(self):
391
+ for name, tmpl in MATRYOSHKA_TEMPLATES.items():
392
+ with self.subTest(template=name):
393
+ self.assertTrue(tmpl['prompt'].strip(), f'{name} prompt is empty')
394
+
395
+ def test_at_least_five_templates(self):
396
+ self.assertGreaterEqual(len(MATRYOSHKA_TEMPLATES), 5)
397
+
398
+ def test_deep_nested_template_parseable(self):
399
+ tmpl = MATRYOSHKA_TEMPLATES.get('Deep nested concept isolation')
400
+ if tmpl:
401
+ try:
402
+ p.parse_root(tmpl['prompt'])
403
+ except Exception as exc:
404
+ self.fail(f'Deep nested template parse failed: {exc}')
405
+
406
+
407
+ # ---------------------------------------------------------------------------
408
+ # _build_affine_snippet (already tested in test_parametric_syntax, quick sanity)
409
+ # ---------------------------------------------------------------------------
410
+
411
+ class TestAffineSnippetSanity(unittest.TestCase):
412
+
413
+ def test_flip_h(self):
414
+ self.assertEqual(_build_affine_snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')
415
+
416
+ def test_rotate(self):
417
+ s = _build_affine_snippet('ROTATE', 0.25, 1.0)
418
+ self.assertIn('ROTATE', s)
419
+ self.assertIn('0.25', s)
420
+
421
+ def test_scale_uniform(self):
422
+ s = _build_affine_snippet('SCALE', 2.0, 2.0)
423
+ self.assertEqual(s, 'SCALE[2.0]')
424
+
425
+ def test_slide_2d(self):
426
+ s = _build_affine_snippet('SLIDE', 0.1, 0.2)
427
+ self.assertEqual(s, 'SLIDE[0.1,0.2]')
428
+
429
+ def test_snippet_parseable(self):
430
+ snippets = ['ROTATE[0.125]', 'SCALE[-1,1]', 'SCALE[0.5]', 'SLIDE[0.1,0.0]']
431
+ for snip in snippets:
432
+ with self.subTest(snip=snip):
433
+ try:
434
+ p.parse_root(f'base {snip} AND_PERP style :0.8')
435
+ except Exception as exc:
436
+ self.fail(f'Snippet "{snip}" not parseable: {exc}')
437
+
438
+
439
+ if __name__ == '__main__':
440
+ unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_matryoshka_golden.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Golden tests for matryoshka_utils: tree render, diagnostics, effect summaries.
3
+
4
+ These tests lock in the *exact* formatted output for a set of representative
5
+ prompts so that silent formatting/logic regressions are caught immediately.
6
+ They also cover the polishing fixes from Sprint 1:
7
+ - build_child_block: empty text → empty string
8
+ - build_child_block: empty text + nested → no bare ':weight' line
9
+ - render_full_explain: empty prompt → clean message only
10
+ - effect summaries appear in render_full_explain output
11
+ """
12
+
13
+ import unittest
14
+ # mock_torch installed by package __init__
15
+
16
+ import lib_neutral_prompt.neutral_prompt_parser as p
17
+ from lib_neutral_prompt.matryoshka_utils import (
18
+ build_child_block,
19
+ build_nested_prompt,
20
+ render_prompt_tree,
21
+ render_prompt_node,
22
+ collect_prompt_diagnostics,
23
+ render_diagnostics,
24
+ render_full_explain,
25
+ _render_effect_summary,
26
+ _effect_for_strategy,
27
+ )
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Builder polishing: empty-text behaviour
32
+ # ---------------------------------------------------------------------------
33
+
34
+ class TestBuilderEmptyTextPolishing(unittest.TestCase):
35
+
36
+ def test_empty_text_returns_empty_string(self):
37
+ out = build_child_block('AND_SALT', '', 0.8)
38
+ self.assertEqual(out, '',
39
+ 'Empty text with no nested must return empty string')
40
+
41
+ def test_empty_text_empty_nested_returns_empty_string(self):
42
+ out = build_child_block('AND_TOPK', '', 0.5, nested='')
43
+ self.assertEqual(out, '')
44
+
45
+ def test_empty_text_whitespace_returns_empty_string(self):
46
+ out = build_child_block('AND_PERP', ' ', 0.7)
47
+ self.assertEqual(out, '')
48
+
49
+ def test_empty_text_with_nested_no_bare_weight_line(self):
50
+ """If text is empty but nested is not, the nested block must not
51
+ include a bare ' :0.5' line at the top."""
52
+ nested = 'AND_PERP blur :0.4'
53
+ out = build_child_block('AND_TOPK', '', 0.5, nested=nested)
54
+ # Should not be empty (nested content exists)
55
+ self.assertNotEqual(out, '')
56
+ # Must not contain a line that is only whitespace + ':weight'
57
+ for line in out.splitlines():
58
+ stripped = line.strip()
59
+ self.assertFalse(
60
+ stripped.startswith(':') and len(stripped.split()) == 1,
61
+ f'Found bare weight line: {repr(line)}',
62
+ )
63
+
64
+ def test_nonempty_text_with_nested_includes_text_line(self):
65
+ """When text is non-empty, the nested block must include a text line."""
66
+ nested = 'AND_PERP correction :0.3'
67
+ out = build_child_block('AND_SALT', 'texture', 0.8, nested=nested)
68
+ self.assertIn('texture', out)
69
+ self.assertIn('AND_PERP', out)
70
+ self.assertIn('[', out)
71
+
72
+ def test_nested_prompt_skips_empty_children(self):
73
+ out = build_nested_prompt('base subject', [
74
+ {'strategy': 'AND_SALT', 'text': '', 'weight': 0.8}, # empty → skip
75
+ {'strategy': 'AND_PERP', 'text': 'vivid', 'weight': 0.5}, # kept
76
+ {'strategy': 'AND_TOPK', 'text': ' ', 'weight': 0.6}, # whitespace → skip
77
+ ])
78
+ self.assertIn('AND_PERP', out)
79
+ self.assertNotIn('AND_SALT', out)
80
+ self.assertNotIn('AND_TOPK', out)
81
+
82
+ def test_nested_prompt_all_empty_children_gives_base_only(self):
83
+ out = build_nested_prompt('base', [
84
+ {'strategy': 'AND_SALT', 'text': '', 'weight': 0.8},
85
+ ])
86
+ self.assertEqual(out.strip(), 'base')
87
+
88
+ def test_nested_only_child_no_text(self):
89
+ """Child with empty text but non-empty nested should be kept."""
90
+ inner = build_child_block('AND_PERP', 'correction', 0.3)
91
+ out = build_child_block('AND_TOPK', '', 0.5, nested=inner)
92
+ self.assertNotEqual(out, '', 'Should keep block when nested is non-empty')
93
+ self.assertIn('AND_PERP', out)
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # render_full_explain: empty-prompt clean state
98
+ # ---------------------------------------------------------------------------
99
+
100
+ class TestRenderFullExplainEmptyState(unittest.TestCase):
101
+
102
+ def test_empty_string_returns_clean_message(self):
103
+ out = render_full_explain('')
104
+ self.assertIn('empty', out.lower())
105
+ self.assertNotIn('Parse error', out)
106
+ self.assertNotIn('unknown', out.lower())
107
+
108
+ def test_whitespace_only_returns_clean_message(self):
109
+ out = render_full_explain(' \n ')
110
+ self.assertIn('empty', out.lower())
111
+ self.assertNotIn('Parse error', out)
112
+
113
+ def test_nonempty_prompt_no_empty_message(self):
114
+ out = render_full_explain('base AND_SALT[5] texture :0.8')
115
+ self.assertNotIn('empty prompt', out.lower())
116
+ self.assertIn('ROOT', out)
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # Effect summaries
121
+ # ---------------------------------------------------------------------------
122
+
123
+ class TestEffectSummary(unittest.TestCase):
124
+
125
+ def test_effect_for_base(self):
126
+ desc = _effect_for_strategy('BASE')
127
+ self.assertIn('base', desc.lower())
128
+
129
+ def test_effect_for_salience_mask(self):
130
+ desc = _effect_for_strategy('SALIENCE_MASK')
131
+ self.assertIn('saliency', desc.lower())
132
+
133
+ def test_effect_for_semantic_guidance(self):
134
+ desc = _effect_for_strategy('SEMANTIC_GUIDANCE')
135
+ self.assertIn('top-k', desc.lower())
136
+
137
+ def test_effect_for_alignment_custom(self):
138
+ desc = _effect_for_strategy('ALIGNMENT_BLEND_CUSTOM')
139
+ self.assertIn('align', desc.lower())
140
+
141
+ def test_effect_for_alignment_legacy_suffix(self):
142
+ # Fixed-suffix strategies like ALIGNMENT_BLEND_4_8 → same description
143
+ desc = _effect_for_strategy('ALIGNMENT_BLEND_4_8')
144
+ self.assertIn('align', desc.lower())
145
+
146
+ def test_effect_for_unknown_strategy(self):
147
+ desc = _effect_for_strategy('SOME_UNKNOWN_STRATEGY')
148
+ self.assertIn('SOME_UNKNOWN_STRATEGY', desc)
149
+
150
+ def test_effect_summary_in_render_full_explain(self):
151
+ out = render_full_explain('base AND_SALT[5] texture :0.8')
152
+ self.assertIn('Effect summary', out)
153
+ self.assertIn('SALIENCE_MASK', out)
154
+
155
+ def test_effect_summary_notes_preview_nature(self):
156
+ out = render_full_explain('base AND_TOPK[0.1] highlights :0.6')
157
+ self.assertIn('structural preview', out.lower())
158
+
159
+ def test_render_effect_summary_empty_on_parse_fail(self):
160
+ diag = {'parse_ok': False, 'strategies': []}
161
+ out = _render_effect_summary(diag)
162
+ self.assertEqual(out, '')
163
+
164
+ def test_render_effect_summary_empty_on_no_strategies(self):
165
+ diag = {'parse_ok': True, 'strategies': []}
166
+ out = _render_effect_summary(diag)
167
+ self.assertEqual(out, '')
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Golden: tree format
172
+ # ---------------------------------------------------------------------------
173
+
174
+ class TestTreeGoldenFormat(unittest.TestCase):
175
+ """Check structural properties of the tree output rather than exact strings,
176
+ making these tests robust to minor phrasing changes while still catching
177
+ regressions in tree shape, connectors, and content."""
178
+
179
+ def _tree(self, prompt):
180
+ return render_prompt_tree(prompt)
181
+
182
+ def test_root_header(self):
183
+ out = self._tree('base AND_PERP style :0.8')
184
+ self.assertTrue(out.startswith('ROOT'), f'Expected ROOT header, got: {out[:40]}')
185
+
186
+ def test_leaf_connector_last(self):
187
+ out = self._tree('base AND_PERP style :0.8')
188
+ # Last segment uses └─
189
+ self.assertIn('└─', out)
190
+
191
+ def test_multiple_children_mixed_connectors(self):
192
+ out = self._tree('base AND_SALT[5] tex :0.8 AND_PERP vivid :0.5')
193
+ self.assertIn('├─', out)
194
+ self.assertIn('└─', out)
195
+
196
+ def test_base_shows_BASE(self):
197
+ out = self._tree('base subject')
198
+ self.assertIn('BASE', out)
199
+
200
+ def test_salt_shows_k_param(self):
201
+ out = self._tree('base AND_SALT[7] texture :0.8')
202
+ self.assertIn('k=7.0', out)
203
+
204
+ def test_topk_shows_threshold_param(self):
205
+ out = self._tree('base AND_TOPK[0.15] highlights :0.6')
206
+ self.assertIn('threshold=0.15', out)
207
+
208
+ def test_align_shows_d_s_params(self):
209
+ out = self._tree('base AND_ALIGN[6,12] style :0.7')
210
+ self.assertIn('d=6', out)
211
+ self.assertIn('s=12', out)
212
+
213
+ def test_weight_shown_to_two_decimals(self):
214
+ out = self._tree('base AND_PERP vivid :0.42')
215
+ self.assertIn('0.42', out)
216
+
217
+ def test_nested_composite_shows_children_count(self):
218
+ prompt = 'base AND_SALT[5] [ texture :0.8 AND_TOPK[0.05] detail :0.4 ] :0.7'
219
+ out = self._tree(prompt)
220
+ # Composite node should show "(N children)"
221
+ self.assertIn('children)', out)
222
+
223
+ def test_deep_nested_has_continuation_bar(self):
224
+ prompt = ('base AND_SALT[5] [\n'
225
+ ' region :0.8\n'
226
+ ' AND_ALIGN[4,8] [\n'
227
+ ' sub :0.7\n'
228
+ ' AND_TOPK[0.05] micro :0.4\n'
229
+ ' ] :0.6\n'
230
+ '] :0.7')
231
+ out = self._tree(prompt)
232
+ # Deep nesting should show either │ continuation bars or ├─/└─ at 2+ depths
233
+ # (exact rendering depends on depth threshold)
234
+ # Ensure at least 3 levels of nesting are represented in the tree
235
+ self.assertGreaterEqual(out.count('├─') + out.count('└─'), 4,
236
+ 'Expected at least 4 branch markers for deep nesting')
237
+
238
+ def test_affine_annotated(self):
239
+ out = self._tree('base ROTATE[0.125] AND_PERP vivid :0.8')
240
+ self.assertIn('affine', out)
241
+
242
+ def test_text_truncated_at_55_chars(self):
243
+ long_text = 'a' * 60
244
+ out = self._tree(f'base AND_PERP {long_text} :0.8')
245
+ self.assertIn('…', out)
246
+
247
+ def test_empty_returns_clean_empty_message(self):
248
+ out = self._tree('')
249
+ self.assertIn('empty', out.lower())
250
+ self.assertNotIn('ROOT', out)
251
+
252
+
253
+ # ---------------------------------------------------------------------------
254
+ # Golden: diagnostics content
255
+ # ---------------------------------------------------------------------------
256
+
257
+ class TestDiagnosticsGolden(unittest.TestCase):
258
+
259
+ def _diag(self, prompt):
260
+ return collect_prompt_diagnostics(prompt)
261
+
262
+ def test_base_only_n_segments(self):
263
+ d = self._diag('a dog in a park')
264
+ self.assertGreaterEqual(d['n_segments'], 1)
265
+
266
+ def test_two_children_strategy_list(self):
267
+ d = self._diag('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] sty :0.6')
268
+ self.assertIn('SALIENCE_MASK', d['strategies'])
269
+ self.assertIn('SEMANTIC_GUIDANCE', d['strategies'])
270
+ self.assertIn('BASE', d['strategies'])
271
+
272
+ def test_affine_count(self):
273
+ d = self._diag('base ROTATE[0.125] AND_PERP vivid :0.8')
274
+ self.assertGreaterEqual(d['n_affine'], 1)
275
+
276
+ def test_protection_fires_no_base(self):
277
+ d = self._diag('AND_SALT[5] only :0.8')
278
+ self.assertTrue(d['protection_would_fire'])
279
+
280
+ def test_protection_ok_with_base(self):
281
+ d = self._diag('base AND_SALT[5] tex :0.8')
282
+ self.assertFalse(d['protection_would_fire'])
283
+
284
+ def test_depth_deep_nested(self):
285
+ prompt = ('base AND_SALT[5] [\n'
286
+ ' region :0.8\n'
287
+ ' AND_ALIGN[4,8] [\n'
288
+ ' sub :0.7\n'
289
+ ' AND_TOPK[0.05] micro :0.4\n'
290
+ ' ] :0.6\n'
291
+ '] :0.7')
292
+ d = self._diag(prompt)
293
+ self.assertGreaterEqual(d['max_depth'], 2)
294
+
295
+ def test_render_diagnostics_includes_strategies(self):
296
+ d = self._diag('base AND_SALT[5] tex :0.8')
297
+ out = render_diagnostics(d)
298
+ self.assertIn('SALIENCE_MASK', out)
299
+ self.assertIn('Strategies used', out)
300
+
301
+ def test_render_diagnostics_protection_would_fire_message(self):
302
+ d = self._diag('AND_TOPK[0.1] only :0.8')
303
+ out = render_diagnostics(d)
304
+ self.assertIn('WOULD FIRE', out)
305
+
306
+ def test_render_diagnostics_ok_message(self):
307
+ d = self._diag('base AND_PERP style :0.8')
308
+ out = render_diagnostics(d)
309
+ self.assertIn('OK', out)
310
+
311
+
312
+ # ---------------------------------------------------------------------------
313
+ # Integration: full explain round-trip for all templates
314
+ # ---------------------------------------------------------------------------
315
+
316
+ class TestTemplatesExplainIntegration(unittest.TestCase):
317
+ """Every template must produce a non-trivial full explain output."""
318
+
319
+ def test_all_templates_full_explain(self):
320
+ from lib_neutral_prompt.matryoshka_utils import MATRYOSHKA_TEMPLATES
321
+ for name, tmpl in MATRYOSHKA_TEMPLATES.items():
322
+ with self.subTest(template=name):
323
+ out = render_full_explain(tmpl['prompt'])
324
+ self.assertIn('ROOT', out, f'{name}: missing ROOT in explain')
325
+ self.assertIn('Diagnostics', out, f'{name}: missing Diagnostics')
326
+ self.assertIn('Effect summary', out, f'{name}: missing Effect summary')
327
+ self.assertNotIn('Parse error', out, f'{name}: unexpected parse error')
328
+
329
+
330
+ if __name__ == '__main__':
331
+ unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_parametric_syntax.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for parametric prompt syntax (patched extension).
3
+
4
+ Requires: torch OR falls back to mock_torch for pure parser logic.
5
+
6
+ Covers:
7
+ - AND_SALT[k] / AND_SALT_WIDE[k] / AND_SALT_BLOB[k] parser
8
+ - AND_ALIGN[d,s] / AND_MASK_ALIGN[d,s] parser
9
+ - conciliation_params populated correctly
10
+ - weight still parsed correctly alongside params
11
+ - composite brackets NOT consumed as params
12
+ - backward-compat: AND_ALIGN_4_8 still works
13
+ - invalid / boundary params rejected gracefully
14
+ - params survive alongside affine transforms
15
+ """
16
+
17
+ import pathlib
18
+ import sys
19
+ import unittest
20
+
21
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
22
+
23
+ # mock_torch installed by package __init__.py
24
+
25
+ from lib_neutral_prompt import neutral_prompt_parser as p
26
+
27
# Short alias: every test below compares against ConciliationStrategy members.
CS = p.ConciliationStrategy
28
+
29
+
30
def _child(prompt_str, idx=-1):
    """Parse *prompt_str* and return one child of the root (default: last)."""
    root = p.parse_root(prompt_str)
    return root.children[idx]
32
+
33
+
34
class TestSaltKParam(unittest.TestCase):
    """Bracketed k parameter on the AND_SALT family of keywords."""

    def test_salt_default_no_params(self):
        leaf = _child("base AND_SALT concept :1.0")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK)
        self.assertEqual(leaf.conciliation_params, {})

    def test_salt_with_k_integer(self):
        leaf = _child("base AND_SALT[5] concept :1.0")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK)
        self.assertAlmostEqual(leaf.conciliation_params['k'], 5.0)

    def test_salt_with_k_float(self):
        leaf = _child("base AND_SALT[3.5] concept :0.8")
        self.assertAlmostEqual(leaf.conciliation_params['k'], 3.5)

    def test_salt_k_preserves_weight(self):
        # The bracket parameter must not interfere with weight parsing.
        leaf = _child("base AND_SALT[5] concept :0.75")
        self.assertAlmostEqual(leaf.weight, 0.75)
        self.assertAlmostEqual(leaf.conciliation_params['k'], 5.0)

    def test_salt_wide_with_k(self):
        leaf = _child("base AND_SALT_WIDE[1.5] concept :1.0")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK_WIDE)
        self.assertAlmostEqual(leaf.conciliation_params['k'], 1.5)

    def test_salt_wide_default_no_params(self):
        leaf = _child("base AND_SALT_WIDE concept :1.0")
        self.assertEqual(leaf.conciliation_params, {})

    def test_salt_blob_with_k(self):
        leaf = _child("base AND_SALT_BLOB[8] concept :1.0")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK_BLOB)
        self.assertAlmostEqual(leaf.conciliation_params['k'], 8.0)

    def test_salt_blob_default_no_params(self):
        leaf = _child("base AND_SALT_BLOB concept :1.0")
        self.assertEqual(leaf.conciliation_params, {})

    def test_salt_k_zero_rejected(self):
        # k == 0 is out of range — parser must fall back to defaults.
        leaf = _child("base AND_SALT[0] concept :1.0")
        self.assertEqual(leaf.conciliation_params, {})

    def test_salt_k_negative_rejected(self):
        leaf = _child("base AND_SALT[-1] concept :1.0")
        self.assertEqual(leaf.conciliation_params, {})

    def test_salt_k_max_value_accepted(self):
        leaf = _child("base AND_SALT[20] concept :1.0")
        self.assertAlmostEqual(leaf.conciliation_params['k'], 20.0)

    def test_salt_k_very_small_accepted(self):
        leaf = _child("base AND_SALT[0.5] concept :1.0")
        self.assertAlmostEqual(leaf.conciliation_params['k'], 0.5)
88
+
89
+
90
class TestAlignBracketParam(unittest.TestCase):
    """Bracketed [d,s] parameter pair on AND_ALIGN / AND_MASK_ALIGN."""

    def test_align_bracket_basic(self):
        leaf = _child("base AND_ALIGN[4,8] style :0.5")
        self.assertEqual(leaf.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
        self.assertEqual(leaf.conciliation_params, {'d': 4, 's': 8})

    def test_align_bracket_large_pair(self):
        leaf = _child("base AND_ALIGN[8,32] style :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 8, 's': 32})

    def test_align_bracket_preserves_weight(self):
        leaf = _child("base AND_ALIGN[4,8] style :0.65")
        self.assertAlmostEqual(leaf.weight, 0.65)
        self.assertEqual(leaf.conciliation_params, {'d': 4, 's': 8})

    def test_mask_align_bracket_basic(self):
        leaf = _child("base AND_MASK_ALIGN[6,12] structure :0.7")
        self.assertEqual(leaf.conciliation, CS.ALIGNMENT_MASK_BLEND_CUSTOM)
        self.assertEqual(leaf.conciliation_params, {'d': 6, 's': 12})

    def test_mask_align_bracket_min_pair(self):
        leaf = _child("base AND_MASK_ALIGN[2,3] structure :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 2, 's': 3})

    def test_align_equal_ds_rejected(self):
        """d == s has no structural distinction → params must be rejected."""
        leaf = _child("base AND_ALIGN[4,4] style :0.5")
        self.assertEqual(leaf.conciliation_params, {})

    def test_align_d_out_of_range_low(self):
        leaf = _child("base AND_ALIGN[1,8] style :0.5")
        self.assertEqual(leaf.conciliation_params, {})

    def test_align_s_out_of_range_high(self):
        leaf = _child("base AND_ALIGN[4,33] style :0.5")
        self.assertEqual(leaf.conciliation_params, {})

    def test_align_inverted_pair_accepted(self):
        """d > s is valid (just an unusual choice) as long as d != s."""
        leaf = _child("base AND_ALIGN[8,4] style :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 8, 's': 4})
132
+
133
+
134
class TestBackwardCompatibility(unittest.TestCase):
    """Legacy fixed-suffix keywords must keep working alongside brackets."""

    def test_fixed_suffix_align(self):
        leaf = _child("base AND_ALIGN_4_8 style :0.5")
        self.assertEqual(leaf.conciliation, CS.ALIGNMENT_BLEND_4_8)
        self.assertEqual(leaf.conciliation_params, {})

    def test_fixed_suffix_mask_align(self):
        leaf = _child("base AND_MASK_ALIGN_4_8 structure :0.5")
        self.assertEqual(leaf.conciliation, CS.ALIGNMENT_MASK_BLEND_4_8)
        self.assertEqual(leaf.conciliation_params, {})

    def test_both_forms_coexist(self):
        # Old fixed-suffix and new bracketed syntax in a single prompt.
        tree = p.parse_root("base AND_ALIGN_4_8 old :0.5 AND_ALIGN[6,12] new :0.7")
        legacy, modern = tree.children[1], tree.children[2]
        self.assertEqual(legacy.conciliation, CS.ALIGNMENT_BLEND_4_8)
        self.assertEqual(legacy.conciliation_params, {})
        self.assertEqual(modern.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
        self.assertEqual(modern.conciliation_params, {'d': 6, 's': 12})

    def test_plain_and_salt_unchanged(self):
        leaf = _child("base AND_SALT concept :1.0")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK)
        self.assertEqual(leaf.conciliation_params, {})
158
+
159
+
160
class TestCompositeBracketNotConsumed(unittest.TestCase):
    """A [...] holding a nested prompt is a composite, never a param list."""

    def _assert_composite(self, prompt):
        leaf = p.parse_root(prompt).children[-1]
        self.assertIsInstance(leaf, p.CompositePrompt)
        self.assertEqual(leaf.conciliation_params, {})

    def test_align_composite_not_eaten(self):
        self._assert_composite("base AND_ALIGN[arst AND defg :2.0] :0.5")

    def test_salt_composite_not_eaten(self):
        self._assert_composite("base AND_SALT[concept AND style :1.5] :0.8")

    def test_mask_align_composite_not_eaten(self):
        self._assert_composite("base AND_MASK_ALIGN[a AND b] :0.5")
179
+
180
+
181
class TestParamsWithAffineTransform(unittest.TestCase):
    """Bracket params and affine transform markers must coexist on one child."""

    def test_salt_k_with_trailing_affine(self):
        leaf = _child("base AND_SALT[5] ROTATE[0.125] concept :0.8")
        self.assertEqual(leaf.conciliation, CS.SALIENCE_MASK)
        self.assertAlmostEqual(leaf.conciliation_params['k'], 5.0)
        self.assertIsNotNone(leaf.local_transform)

    def test_align_custom_with_leading_affine(self):
        leaf = _child("base ROTATE[0.125] AND_ALIGN[4,8] style :0.5")
        self.assertEqual(leaf.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
        self.assertEqual(leaf.conciliation_params, {'d': 4, 's': 8})
        self.assertIsNotNone(leaf.local_transform)

    def test_salt_k_and_affine_weight_all_correct(self):
        # k, weight and the transform must all survive simultaneously.
        leaf = _child("base AND_SALT[3] ROTATE[0.25] concept :0.6")
        self.assertAlmostEqual(leaf.conciliation_params['k'], 3.0)
        self.assertAlmostEqual(leaf.weight, 0.6)
        self.assertIsNotNone(leaf.local_transform)
200
+
201
+
202
# NOTE(review): this __main__ guard sits mid-file — when the module is run
# directly, unittest.main() exits before the test classes defined BELOW this
# point are ever registered, so they are silently skipped.  Move this guard
# to the end of the file.  Test discovery (pytest / `unittest discover`) is
# unaffected because discovery imports the whole module first.
if __name__ == '__main__':
    unittest.main()
204
+
205
+
206
class TestSpacesInBracketParams(unittest.TestCase):
    """Spaces around comma inside bracket params must be tolerated."""

    def test_align_spaces_around_comma(self):
        leaf = _child("base AND_ALIGN[4, 8] style :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 4, 's': 8})

    def test_mask_align_spaces_around_comma(self):
        leaf = _child("base AND_MASK_ALIGN[6, 12] style :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 6, 's': 12})

    def test_salt_inner_spaces(self):
        leaf = _child("base AND_SALT[ 5 ] concept :1.0")
        self.assertAlmostEqual(leaf.conciliation_params.get('k', 0), 5.0)

    def test_align_many_spaces(self):
        leaf = _child("base AND_ALIGN[ 4 , 8 ] style :0.5")
        self.assertEqual(leaf.conciliation_params, {'d': 4, 's': 8})
224
+
225
+
226
class TestInvalidParamWarnings(unittest.TestCase):
    """
    Invalid bracket params must:
      1. NOT crash.
      2. Return empty conciliation_params (keyword uses its default).
      3. NOT leak bracket content into prompt text.
      4. Emit a warning to stderr when verbose=True.
    """

    def _parse(self, s):
        return p.parse_root(s)

    def _last_child(self, s):
        return self._parse(s).children[-1]

    def _parse_capturing_stderr(self, prompt, verbose):
        """Parse *prompt* with gs.verbose set, returning captured stderr text.

        The previous verbose value is restored even if parsing raises, so a
        failing test cannot leak verbose=True into later tests.
        """
        import io
        import unittest.mock
        from lib_neutral_prompt import global_state as gs
        previous = getattr(gs, 'verbose', False)
        gs.verbose = verbose
        buf = io.StringIO()
        try:
            with unittest.mock.patch('sys.stderr', buf):
                self._parse(prompt)
        finally:
            gs.verbose = previous
        return buf.getvalue()

    # --- k invalid values ---

    def test_k_zero_no_params(self):
        child = self._last_child("base AND_SALT[0] concept :1.0")
        self.assertEqual(child.conciliation_params, {})

    def test_k_negative_no_params(self):
        child = self._last_child("base AND_SALT[-3] concept :1.0")
        self.assertEqual(child.conciliation_params, {})

    def test_k_too_many_parts_no_params(self):
        child = self._last_child("base AND_SALT[3,5] concept :1.0")
        self.assertEqual(child.conciliation_params, {})

    # --- D,S invalid values ---

    def test_ds_equal_no_params(self):
        child = self._last_child("base AND_ALIGN[4,4] style :0.5")
        self.assertEqual(child.conciliation_params, {})

    def test_d_too_low_no_params(self):
        child = self._last_child("base AND_ALIGN[1,8] style :0.5")
        self.assertEqual(child.conciliation_params, {})

    def test_s_too_high_no_params(self):
        child = self._last_child("base AND_ALIGN[4,33] style :0.5")
        self.assertEqual(child.conciliation_params, {})

    def test_wrong_part_count_no_params(self):
        child = self._last_child("base AND_ALIGN[4] style :0.5")
        self.assertEqual(child.conciliation_params, {})

    # --- No crash on bad input ---

    def test_no_crash_on_empty_brackets(self):
        # AND_SALT[] — inner is empty; brackets may be treated as composite
        # or ignored, we only require that parsing does not raise.
        try:
            self._parse("base AND_SALT[] concept :1.0")
        except Exception as exc:
            self.fail(f"Crashed on AND_SALT[]: {exc}")

    # --- Warning emitted to stderr ---

    def test_warning_emitted_for_k_zero(self):
        # Fixed: the original imported only `sys, io` but used
        # unittest.mock.patch (missing import), contained a no-op
        # .replace('k must be', 'k must be'), and did not restore
        # gs.verbose on exception.
        out = self._parse_capturing_stderr("base AND_SALT[0] concept :1.0",
                                           verbose=True)
        self.assertIn('AND_SALT', out,
                      "Expected warning mentioning AND_SALT in stderr")
        self.assertIn('k must be', out.lower(),
                      "Expected warning explaining the reason")

    def test_warning_emitted_for_ds_equal(self):
        out = self._parse_capturing_stderr("base AND_ALIGN[4,4] style :0.5",
                                           verbose=True)
        self.assertIn('AND_ALIGN', out)
        self.assertIn('differ', out)

    def test_no_warning_when_verbose_false(self):
        out = self._parse_capturing_stderr("base AND_SALT[0] concept :1.0",
                                           verbose=False)
        self.assertEqual(out, '',
                         "No warning expected when verbose=False")
318
+
319
+
320
+ # ---------------------------------------------------------------------------
321
+ # AND_TOPK[threshold] — parametric threshold
322
+ # ---------------------------------------------------------------------------
323
+
324
class TestTopkBracketSyntax(unittest.TestCase):
    """AND_TOPK[threshold] — parametric selection threshold parsing."""

    def _last(self, prompt):
        return p.parse_root(prompt).children[-1]

    def test_topk_default_no_brackets(self):
        self.assertEqual(self._last('base AND_TOPK concept :0.8').conciliation_params, {})

    def test_topk_explicit_threshold(self):
        leaf = self._last('base AND_TOPK[0.1] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 0.1)

    def test_topk_small_threshold(self):
        leaf = self._last('base AND_TOPK[0.01] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 0.01)

    def test_topk_max_threshold(self):
        leaf = self._last('base AND_TOPK[1.0] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 1.0)

    def test_topk_threshold_zero_rejected(self):
        # threshold=0 is invalid (nothing would be selected)
        leaf = self._last('base AND_TOPK[0] concept :0.8')
        self.assertEqual(leaf.conciliation_params, {})

    def test_topk_threshold_over_one_rejected(self):
        leaf = self._last('base AND_TOPK[1.1] concept :0.8')
        self.assertEqual(leaf.conciliation_params, {})

    def test_topk_too_many_params_rejected(self):
        leaf = self._last('base AND_TOPK[0.1,0.2] concept :0.8')
        self.assertEqual(leaf.conciliation_params, {})

    def test_topk_spaces_in_brackets(self):
        leaf = self._last('base AND_TOPK[ 0.1 ] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 0.1)

    def test_topk_does_not_consume_non_numeric_brackets(self):
        # Composite brackets must NOT be eaten: the [...] below is a
        # composite group, not params — topk should end up with no params.
        leaf = self._last('base AND_TOPK [concept AND style] :0.8')
        self.assertEqual(leaf.conciliation_params, {})

    def test_topk_with_trailing_affine(self):
        leaf = self._last('base AND_TOPK[0.05] ROTATE[0.125] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 0.05)
        self.assertIsNotNone(leaf.local_transform)

    def test_topk_with_leading_affine(self):
        leaf = self._last('base ROTATE[0.125] AND_TOPK[0.1] concept :0.8')
        self.assertAlmostEqual(leaf.conciliation_params.get('threshold'), 0.1)
        self.assertIsNotNone(leaf.local_transform)

    def test_topk_backward_compat_default(self):
        # Without brackets the default 0.05 must come from cfg_denoiser, not parser
        leaf = self._last('base AND_TOPK concept :1.0')
        self.assertEqual(leaf.conciliation_params, {})

    def test_topk_coexists_with_salt(self):
        tree = p.parse_root('base AND_SALT[3] texture :0.8 AND_TOPK[0.1] style :0.6')
        self.assertEqual(len(tree.children), 3)
        salt, topk = tree.children[1], tree.children[2]
        self.assertAlmostEqual(salt.conciliation_params.get('k'), 3.0)
        self.assertAlmostEqual(topk.conciliation_params.get('threshold'), 0.1)
399
+
400
+
401
+ # ---------------------------------------------------------------------------
402
+ # Affine builder helper (ui-level function, no gradio import needed)
403
+ # ---------------------------------------------------------------------------
404
+
405
+ class TestAffineBuilderHelper(unittest.TestCase):
406
+ """Tests for _build_affine_snippet via neutral_prompt_parser — no gradio."""
407
+
408
+ def _snippet(self, transform, p1, p2):
409
+ # Inline logic mirrors _build_affine_snippet in ui.py
410
+ t = transform.strip()
411
+ import math
412
+ if t == 'FLIP_H': return 'SCALE[-1,1]'
413
+ if t == 'FLIP_V': return 'SCALE[1,-1]'
414
+ if t == 'ROTATE': return f'ROTATE[{p1}]'
415
+ if t == 'SCALE':
416
+ return f'SCALE[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SCALE[{p1}]'
417
+ if t == 'SLIDE': return f'SLIDE[{p1},{p2}]'
418
+ if t == 'SHEAR':
419
+ return f'SHEAR[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SHEAR[{p1}]'
420
+ return ''
421
+
422
+ def test_flip_h(self):
423
+ self.assertEqual(self._snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')
424
+
425
+ def test_flip_v(self):
426
+ self.assertEqual(self._snippet('FLIP_V', 0, 0), 'SCALE[1,-1]')
427
+
428
+ def test_rotate(self):
429
+ self.assertEqual(self._snippet('ROTATE', 0.125, 1.0), 'ROTATE[0.125]')
430
+
431
+ def test_scale_uniform(self):
432
+ self.assertEqual(self._snippet('SCALE', 2.0, 2.0), 'SCALE[2.0]')
433
+
434
+ def test_scale_anisotropic(self):
435
+ self.assertEqual(self._snippet('SCALE', 1.5, 1.0), 'SCALE[1.5,1.0]')
436
+
437
+ def test_slide(self):
438
+ self.assertEqual(self._snippet('SLIDE', 0.1, 0.2), 'SLIDE[0.1,0.2]')
439
+
440
+ def test_shear_uniform(self):
441
+ self.assertEqual(self._snippet('SHEAR', 0.05, 0.05), 'SHEAR[0.05]')
442
+
443
+ def test_snippet_parseable_by_parser(self):
444
+ """Each generated snippet must survive parse_root without crashing."""
445
+ cases = [
446
+ 'ROTATE[0.125]', 'SCALE[-1,1]', 'SCALE[1,-1]',
447
+ 'SCALE[0.5]', 'SCALE[1.5,1.0]',
448
+ 'SLIDE[0.1,0.2]', 'SHEAR[0.05]',
449
+ ]
450
+ for snippet in cases:
451
+ with self.subTest(snippet=snippet):
452
+ try:
453
+ p.parse_root(f'base {snippet} AND_PERP style :0.8')
454
+ except Exception as exc:
455
+ self.fail(f'snippet "{snippet}" crashed parse_root: {exc}')
456
+
457
+
458
+ # ---------------------------------------------------------------------------
459
+ # Debug/explain helper (_render_parse_tree logic, without gradio)
460
+ # ---------------------------------------------------------------------------
461
+
462
class TestDebugRenderLogic(unittest.TestCase):
    """Test the parse-tree renderer used by the debug panel."""

    def _render(self, prompt_str):
        # Replicates the renderer from the ui module inline so gradio is
        # not required at test time.
        if not prompt_str.strip():
            return '(empty prompt)'
        try:
            tree = p.parse_root(prompt_str)
        except Exception as exc:
            return f'Parse error: {exc}'

        rendered = []

        def _walk(node, depth=0):
            pad = '  ' * depth
            strat = node.conciliation.name if node.conciliation else 'BASE'
            weight_txt = f'weight={node.weight:.3f}'
            params_txt = ''
            if node.conciliation_params:
                params_txt = ' params=' + ', '.join(f'{k}={v}' for k, v in node.conciliation_params.items())
            affine_txt = ' [affine]' if node.local_transform is not None else ''
            if hasattr(node, 'prompt'):
                text = repr(node.prompt[:60])
                rendered.append(f'{pad}[{strat}] {weight_txt}{params_txt}{affine_txt} text={text}')
            else:
                rendered.append(f'{pad}[{strat}] {weight_txt}{params_txt}{affine_txt} ({len(node.children)} children)')
            for sub in node.children:
                _walk(sub, depth + 1)

        _walk(tree)
        return '\n'.join(rendered)

    def test_empty_prompt(self):
        self.assertEqual(self._render(''), '(empty prompt)')

    def test_base_only(self):
        out = self._render('a dog')
        self.assertIn('BASE', out)
        self.assertIn('a dog', out)

    def test_with_salt(self):
        out = self._render('base AND_SALT[3] texture :0.8')
        self.assertIn('SALIENCE_MASK', out)
        self.assertIn('k=3', out)

    def test_with_topk(self):
        out = self._render('base AND_TOPK[0.1] style :0.6')
        self.assertIn('SEMANTIC_GUIDANCE', out)
        self.assertIn('threshold=0.1', out)

    def test_with_align(self):
        out = self._render('base AND_ALIGN[4,8] detail :0.7')
        self.assertIn('ALIGNMENT_BLEND_CUSTOM', out)
        self.assertIn('d=4', out)
        self.assertIn('s=8', out)

    def test_with_affine(self):
        out = self._render('base ROTATE[0.125] AND_PERP vivid :0.8')
        self.assertIn('[affine]', out)

    def test_multiple_children(self):
        out = self._render('base AND_SALT[5] texture :0.8 AND_TOPK[0.05] style :0.6')
        self.assertIn('SALIENCE_MASK', out)
        self.assertIn('SEMANTIC_GUIDANCE', out)

    def test_parse_error_graceful(self):
        # Malformed / deeply weird input must not raise — parser should handle it
        try:
            self._render('[[[[[')
        except Exception as exc:
            self.fail(f'_render raised on bad input: {exc}')
neutral_prompt_patcheds/test/perp_parser/test_runtime_behavior.py ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Runtime behavior tests.
3
+
4
+ Requires PyTorch. If torch is absent the whole module is skipped — same
5
+ pattern as test_affine_pipeline.py so the CI exit code stays clean.
6
+
7
+ Covers:
8
+ - _get_salience: k param changes pixel coverage (monotonically)
9
+ - blob morphology: thickify-first pipeline vs erode-first pipeline
10
+ - no-base guard: fallback when all children are conciliatory
11
+ - API override lifecycle: clear → set → get → clear
12
+ """
13
+
14
+ import pathlib
15
+ import sys
16
+ import types
17
+ import unittest
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Torch availability check — skip entire module if absent
21
+ # ---------------------------------------------------------------------------
22
# Probe for a real (non-namespace) torch installation without importing it.
try:
    import importlib.util as _ilu
    _spec = _ilu.find_spec('torch')
except (ValueError, ModuleNotFoundError):
    _TORCH_AVAILABLE = False
else:
    # A namespace package has no origin; require a concrete module file.
    _TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # A1111 stubs (must be installed before importing lib modules)
31
+ # ---------------------------------------------------------------------------
32
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
33
+
34
+
35
def _stub(name):
    """Register an empty module under *name* (if absent) and return it."""
    # setdefault keeps any module that was already registered.
    return sys.modules.setdefault(name, types.ModuleType(name))
39
+
40
+
41
# Register empty stand-ins for every A1111 module the lib package imports.
for _n in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
           'modules.shared', 'modules.prompt_parser', 'gradio'):
    _stub(_n)

# Fill in the handful of attributes the lib package touches at import time,
# without clobbering anything a previous test module may have installed.
import modules.script_callbacks as _sc
for _attr in ('on_cfg_denoiser', 'on_script_unloaded', 'CFGDenoiserParams'):
    if not hasattr(_sc, _attr):
        setattr(_sc, _attr, lambda *a, **kw: None)

import modules.shared as _sh
if not hasattr(_sh, 'state'):
    _sh.state = types.SimpleNamespace(sampling_step=1)

import modules.sd_samplers as _ss
if not hasattr(_ss, 'create_sampler'):
    _ss.create_sampler = lambda *a, **kw: None
57
+
58
+ import lib_neutral_prompt.hijacker as _h_mod
59
+
60
+
61
class _StubHijacker:
    """No-op stand-in for hijacker.ModuleHijacker used by these tests."""

    @classmethod
    def install_or_get(cls, **kwargs):
        # The real implementation caches per module; the stub just builds one.
        return cls()

    def hijack(self, name):
        # Identity decorator: the wrapped function is returned untouched.
        return lambda fn: fn
67
+
68
+
69
# Swap in the no-op stub so importing cfg_denoiser_hijack below does not
# try to hook the real webui modules.
_h_mod.ModuleHijacker = _StubHijacker
70
+
71
+
72
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestSalienceKCoverage(unittest.TestCase):
    """Higher k → strictly fewer pixels where child salience beats parent."""

    @classmethod
    def setUpClass(cls):
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        cls.torch = torch
        cls.hj = hj

    def _winning_fraction(self, k_child, C=4, H=64, W=64):
        """Fraction of pixels where the child's salience map wins at k_child."""
        torch, hj = self.torch, self.hj
        torch.manual_seed(7)
        parent_noise = torch.randn(C, H, W)
        child_noise = torch.randn(C, H, W)
        parent_sal = hj._get_salience(parent_noise, k=1.0)
        child_sal = hj._get_salience(child_noise, k=k_child)
        winner = torch.argmax(torch.stack([parent_sal, child_sal]), dim=0)
        return (winner == 1).float().mean().item()

    def test_k1_broad(self):
        self.assertGreater(self._winning_fraction(1.0), 0.40)

    def test_k5_moderate(self):
        frac = self._winning_fraction(5.0)
        self.assertLess(frac, 0.15)
        self.assertGreater(frac, 0.001)

    def test_k20_surgical(self):
        self.assertLess(self._winning_fraction(20.0), 0.01)

    def test_monotone_k(self):
        f1, f5, f20 = (self._winning_fraction(k) for k in (1.0, 5.0, 20.0))
        self.assertGreater(f1, f5, "k=1 should cover more than k=5")
        self.assertGreater(f5, f20, "k=5 should cover more than k=20")
110
+
111
+
112
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestBlobMorphology(unittest.TestCase):
    """Thickify-first pipeline keeps blob alive; erode-first destroys it."""

    @classmethod
    def setUpClass(cls):
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        cls.torch = torch
        cls.hj = hj

    def _seed(self, C=4, H=8, W=8):
        # One lit pixel in the middle of an otherwise empty board.
        board = self.torch.zeros(C, H, W)
        board[:, H // 2, W // 2] = 1.0
        return board

    def _run(self, board, thickify, erode):
        hj = self.hj
        for _ in range(thickify):
            board = hj._life_step(board, hj._thickify_rule)
        for _ in range(erode):
            board = hj._life_step(board, hj._erode_rule)
        return board

    def test_erode_first_kills_isolated_seed(self):
        """Original ultimate ordering: 6 erodes → seed dies immediately."""
        survivors = self._run(self._seed(), thickify=0, erode=6)
        self.assertEqual(survivors.sum().item(), 0,
                         "Erode-first should destroy an isolated pixel seed")

    def test_thickify_first_survives(self):
        """Patched ordering: 3 thickify then 1 erode → seed survives."""
        survivors = self._run(self._seed(), thickify=3, erode=1)
        self.assertGreater(survivors.sum().item(), 0,
                           "Thickify-first should keep blob alive")

    def test_thickify_grows_region(self):
        seed_mass = self._seed().sum().item()
        grown_mass = self._run(self._seed(), thickify=3, erode=0).sum().item()
        self.assertGreater(grown_mass, seed_mass)

    def test_erode_after_thickify_reduces_mass(self):
        grown_mass = self._run(self._seed(), thickify=3, erode=0).sum().item()
        pruned_mass = self._run(self._seed(), thickify=3, erode=1).sum().item()
        self.assertLess(pruned_mass, grown_mass,
                        "Erosion after thickify should trim edges")
        self.assertGreater(pruned_mass, 0,
                           "But core should still survive")
160
+
161
+
162
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestNoBaseGuard(unittest.TestCase):
    """combine_denoised_hijack falls back when all children are conciliatory."""

    @classmethod
    def setUpClass(cls):
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        import lib_neutral_prompt.global_state as gs
        import lib_neutral_prompt.neutral_prompt_parser as npp
        cls.torch = torch
        cls.hj = hj
        cls.gs = gs
        cls.npp = npp

    def setUp(self):
        # Enable the extension with a clean slate for every test.
        self.gs.is_enabled = True
        self.gs.verbose = False
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def tearDown(self):
        self.gs.is_enabled = False
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def _counter(self):
        """Return (stub original_function, mutable call counter)."""
        call_count = [0]

        def stub(x_out, batch_cond_indices, text_uncond, cond_scale):
            call_count[0] += 1
            return x_out[-1:]

        return stub, call_count

    def test_fallback_when_no_base_child(self):
        tree = self.npp.parse_root("AND_SALT concept :1.0")
        self.gs.prompt_exprs = [tree]
        self.gs.batch_cond_indices = [[(0, 1.0)]]

        x_out = self.torch.randn(2, 4, 8, 8)
        text_uncond = x_out[-1:]
        stub, call_count = self._counter()

        self.hj.combine_denoised_hijack(
            x_out, self.gs.batch_cond_indices, text_uncond, 7.0, stub
        )
        self.assertEqual(call_count[0], 1,
                         "Guard must fall back when there is no base AND child")

    def test_no_fallback_when_base_exists(self):
        """
        When a base AND-child exists, the guard must NOT fire.

        Architecture note: combine_denoised_hijack always calls original_function
        at least once via _get_webui_denoised() on the normal path — that call
        receives a *sliced* x_out (not the raw full tensor).
        The guard fallback, by contrast, passes the *original* x_out unchanged.

        We distinguish the two by capturing the first argument the counter saw:
          - guard path  → x_out.shape[0] == n_conds + n_uncond (full tensor)
          - normal path → x_out.shape[0] <= n_conds (sliced tensor)
        """
        tree = self.npp.parse_root("base concept AND_SALT tweak :1.0")
        self.gs.prompt_exprs = [tree]
        # 2 cond slots (base + AND_SALT) + 1 implicit uncond
        self.gs.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        n_cond, n_uncond = 2, 1
        x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
        text_uncond = x_out[-n_uncond:]

        seen_shapes = []

        def capturing_fn(x, batch_cond_indices, text_uncond, cond_scale):
            seen_shapes.append(x.shape[0])
            return x[-1:]

        try:
            self.hj.combine_denoised_hijack(
                x_out, self.gs.batch_cond_indices, text_uncond, 7.0, capturing_fn
            )
        except Exception:
            pass  # downstream tensor errors are fine; we only check entry shape

        self.assertTrue(len(seen_shapes) > 0,
                        "original_function was never called at all")
        # On the normal path the first call is from _get_webui_denoised which
        # passes a sliced tensor (only base-AND children + uncond).
        # On the guard path it would pass the full x_out (n_cond + n_uncond rows).
        full_tensor_size = n_cond + n_uncond
        self.assertLess(seen_shapes[0], full_tensor_size,
                        "Guard fired and passed full x_out — it should NOT have")
254
+
255
+
256
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestCFGRescaleAPILifecycle(unittest.TestCase):
    """Lifecycle of cfg_rescale: the one-shot override and the per-step singleton."""

    @classmethod
    def setUpClass(cls):
        import lib_neutral_prompt.global_state as gs
        cls.gs = gs

    def setUp(self):
        # Every test starts from a neutral rescale configuration.
        state = self.gs
        state.cfg_rescale = 0.0
        state.cfg_rescale_override = None
        state.CFGRescaleFactorSingleton.clear()

    def test_override_applied_and_cleared(self):
        state = self.gs
        state.cfg_rescale_override = 0.7
        state.apply_and_clear_cfg_rescale_override()
        # The pending override is consumed: value applied, slot reset.
        self.assertAlmostEqual(state.cfg_rescale, 0.7)
        self.assertIsNone(state.cfg_rescale_override)

    def test_singleton_clear_returns_none(self):
        singleton = self.gs.CFGRescaleFactorSingleton
        singleton.clear()
        self.assertIsNone(singleton.get())

    def test_singleton_set_and_get(self):
        singleton = self.gs.CFGRescaleFactorSingleton
        singleton.set(0.85)
        self.assertAlmostEqual(singleton.get(), 0.85)

    def test_singleton_cleared_between_steps(self):
        singleton = self.gs.CFGRescaleFactorSingleton
        singleton.set(0.85)
        singleton.clear()
        self.assertIsNone(singleton.get())

    def test_no_override_does_not_change_rescale(self):
        """apply_and_clear with no pending override must leave cfg_rescale alone."""
        self.gs.cfg_rescale = 0.3
        self.gs.apply_and_clear_cfg_rescale_override()
        self.assertAlmostEqual(self.gs.cfg_rescale, 0.3)
293
+
294
+
295
# NOTE(review): this __main__ guard sits in the MIDDLE of the module — the
# test classes defined below this point do not yet exist when unittest.main()
# executes via `python <this_file>.py`, so they are silently skipped in that
# mode (they still run under `python -m unittest` discovery). Consider moving
# this guard to the end of the file.
if __name__ == '__main__':
    unittest.main()
297
+
298
+
299
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestProtectionModes(unittest.TestCase):
    """Off / Auto / Strict protection modes."""

    @classmethod
    def setUpClass(cls):
        # Deferred imports: the skipUnless guard above keeps these from
        # executing when torch is unavailable.
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        import lib_neutral_prompt.global_state as gs
        import lib_neutral_prompt.neutral_prompt_parser as npp
        cls.torch = torch
        cls.hj = hj
        cls.gs = gs
        cls.npp = npp

    def setUp(self):
        # Start each test from default protection settings.
        self.gs.is_enabled = True
        self.gs.verbose = False
        self.gs.protection_mode = 'auto'
        self.gs.strict_threshold = 0.1
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def tearDown(self):
        # Restore every module-level knob that _run mutates. strict_threshold
        # was previously leaked between tests — restore it too.
        self.gs.is_enabled = False
        self.gs.protection_mode = 'auto'
        self.gs.strict_threshold = 0.1
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def _run(self, prompt_str, n_cond, protection_mode, strict_threshold=0.1):
        """Run the hijack once; return (shapes seen by original_fn, full size).

        `seen` collects shape[0] of every tensor passed to original_fn; the
        second element is n_cond + n_uncond, i.e. the full x_out row count.
        """
        root = self.npp.parse_root(prompt_str)
        self.gs.prompt_exprs = [root]
        self.gs.batch_cond_indices = [[(i, 1.0) for i in range(n_cond)]]
        self.gs.protection_mode = protection_mode
        self.gs.strict_threshold = strict_threshold

        n_uncond = 1
        x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
        text_uncond = x_out[-n_uncond:]

        seen = []
        def fn(x, *a, **kw):
            seen.append(x.shape[0])
            return x[-1:]

        try:
            self.hj.combine_denoised_hijack(
                x_out, self.gs.batch_cond_indices, text_uncond, 7.0, fn)
        except Exception:
            pass  # downstream tensor errors are fine; we only inspect `seen`

        return seen, n_cond + n_uncond

    def _guard_fired(self, seen, full_size):
        """Guard fires when original_fn receives the full-size x_out."""
        return len(seen) > 0 and seen[0] == full_size

    # --- Auto mode ---

    def test_auto_fires_on_no_base(self):
        seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
                               protection_mode='auto')
        self.assertTrue(self._guard_fired(seen, full),
                        "Auto guard must fire when no base AND-child exists")

    def test_auto_does_not_fire_with_base(self):
        seen, full = self._run("base concept AND_SALT tweak :1.0", n_cond=2,
                               protection_mode='auto')
        self.assertFalse(self._guard_fired(seen, full),
                         "Auto guard must NOT fire when base AND-child exists")

    # --- Off mode ---

    def test_off_never_fires_even_without_base(self):
        seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
                               protection_mode='off')
        self.assertFalse(self._guard_fired(seen, full),
                         "Off mode must never fire the guard")

    # --- Strict mode: structural check same as auto ---

    def test_strict_fires_on_no_base(self):
        seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
                               protection_mode='strict')
        self.assertTrue(self._guard_fired(seen, full),
                        "Strict guard must also fire on structural no-base case")

    # --- global_state defaults ---

    def test_default_protection_mode_is_auto(self):
        # Fixed: the previous version imported `importlib` here and never used it.
        import lib_neutral_prompt.global_state as gs_mod
        # default is set at module level; check it's 'auto'
        self.assertEqual(gs_mod.protection_mode, 'auto')

    def test_strict_threshold_default(self):
        import lib_neutral_prompt.global_state as gs_mod
        self.assertAlmostEqual(gs_mod.strict_threshold, 0.1)
397
+
398
+
399
+ # ---------------------------------------------------------------------------
400
+ # Tests for global_state helpers (no torch needed)
401
+ # ---------------------------------------------------------------------------
402
+
403
class TestGlobalStateHelpers(unittest.TestCase):
    """normalize_protection_mode and clamp_strict_threshold.

    Pure-Python helpers — no torch needed. The apply_infotext tests mutate
    module-level state, so setUp/tearDown snapshot and restore it (previously
    these tests leaked 'strict'/0.25 into whatever test class ran next).
    """

    @classmethod
    def setUpClass(cls):
        import lib_neutral_prompt.global_state as gs
        cls.gs = gs

    def setUp(self):
        # Snapshot mutable module state before each test.
        self._saved_mode = self.gs.protection_mode
        self._saved_threshold = self.gs.strict_threshold

    def tearDown(self):
        # Undo any mutation performed by the test body.
        self.gs.protection_mode = self._saved_mode
        self.gs.strict_threshold = self._saved_threshold

    def test_normalize_valid_modes(self):
        for mode in ('off', 'auto', 'strict'):
            self.assertEqual(self.gs.normalize_protection_mode(mode), mode)

    def test_normalize_uppercase(self):
        self.assertEqual(self.gs.normalize_protection_mode('AUTO'), 'auto')
        self.assertEqual(self.gs.normalize_protection_mode('Strict'), 'strict')

    def test_normalize_whitespace(self):
        self.assertEqual(self.gs.normalize_protection_mode(' off '), 'off')

    def test_normalize_unknown_falls_back_to_auto(self):
        self.assertEqual(self.gs.normalize_protection_mode('strong'), 'auto')
        self.assertEqual(self.gs.normalize_protection_mode(''), 'auto')
        self.assertEqual(self.gs.normalize_protection_mode('1'), 'auto')

    def test_clamp_valid(self):
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.10), 0.10)
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.50), 0.50)
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.01), 0.01)

    def test_clamp_below_min(self):
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.0), 0.01)
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(-5), 0.01)

    def test_clamp_above_max(self):
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.99), 0.50)

    def test_clamp_bad_type_returns_default(self):
        self.assertAlmostEqual(self.gs.clamp_strict_threshold('bad'), 0.10)
        self.assertAlmostEqual(self.gs.clamp_strict_threshold(None), 0.10)

    def test_apply_infotext_protection(self):
        # Uses cls.gs directly instead of re-importing the module.
        gs = self.gs
        gs.protection_mode = 'auto'
        gs.strict_threshold = 0.10
        gs.apply_infotext({
            gs.INFOTEXT_KEY_PROTECTION_MODE: 'strict',
            gs.INFOTEXT_KEY_STRICT_THRESHOLD: '0.25',
        })
        self.assertEqual(gs.protection_mode, 'strict')
        self.assertAlmostEqual(gs.strict_threshold, 0.25)

    def test_apply_infotext_invalid_mode_falls_back(self):
        gs = self.gs
        gs.protection_mode = 'off'
        gs.apply_infotext({gs.INFOTEXT_KEY_PROTECTION_MODE: 'maximum'})
        self.assertEqual(gs.protection_mode, 'auto')

    def test_apply_infotext_partial_update(self):
        """apply_infotext must not touch fields absent from the dict."""
        gs = self.gs
        gs.protection_mode = 'strict'
        gs.strict_threshold = 0.20
        gs.apply_infotext({})  # empty dict
        self.assertEqual(gs.protection_mode, 'strict')
        self.assertAlmostEqual(gs.strict_threshold, 0.20)
468
+
469
+
470
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestProtectionPolicyFunction(unittest.TestCase):
    """_should_fallback_for_protection across Off / Auto / Strict."""

    @classmethod
    def setUpClass(cls):
        # Deferred imports: the skipUnless guard above keeps these from
        # executing when torch is unavailable.
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        import lib_neutral_prompt.global_state as gs
        import lib_neutral_prompt.neutral_prompt_parser as npp
        cls.torch = torch
        cls.hj = hj
        cls.gs = gs
        cls.npp = npp

    def setUp(self):
        # Fresh default protection settings for every test.
        self.gs.is_enabled = True
        self.gs.protection_mode = 'auto'
        self.gs.strict_threshold = 0.10
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def tearDown(self):
        self.gs.is_enabled = False
        self.gs.protection_mode = 'auto'

    def _setup(self, prompt_str, n_cond):
        # Parse and register the prompt with n_cond unit-weight cond slots,
        # then build a random latent batch with one trailing uncond row.
        root = self.npp.parse_root(prompt_str)
        self.gs.prompt_exprs = [root]
        self.gs.batch_cond_indices = [[(i, 1.0) for i in range(n_cond)]]
        n_uncond = 1
        x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
        text_uncond = x_out[-n_uncond:]
        return x_out, text_uncond

    # --- Off: never fires regardless of structure ---

    def test_off_no_base(self):
        self.gs.protection_mode = 'off'
        x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
        should, reason = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertFalse(should)
        self.assertIsNone(reason)

    def test_off_with_base(self):
        self.gs.protection_mode = 'off'
        x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
        should, _ = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertFalse(should)

    # --- Auto: structural check only ---

    def test_auto_fires_no_base(self):
        self.gs.protection_mode = 'auto'
        x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
        should, reason = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertTrue(should)
        # The human-readable reason must name both the cause and the mode.
        self.assertIn('no valid base', reason)
        self.assertIn('auto', reason)

    def test_auto_does_not_fire_with_base(self):
        self.gs.protection_mode = 'auto'
        x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
        should, _ = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertFalse(should)

    # --- Strict: also fires on no-base (structural) ---

    def test_strict_fires_on_no_base(self):
        self.gs.protection_mode = 'strict'
        x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
        should, reason = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertTrue(should)
        self.assertIn('strict', reason)

    # --- Strict: numerical ratio check ---

    def test_strict_reason_contains_ratio_info(self):
        """When strict fires on ratio, reason must include ratio and threshold."""
        self.gs.protection_mode = 'strict'
        self.gs.strict_threshold = 0.99  # almost certain to fire

        x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
        should, reason = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)

        if should:
            # If it fired, reason must explain why
            self.assertIn('ratio', reason)
            self.assertIn('threshold', reason)
        # If it didn't fire (both norms happened to be equal), that's fine too.

    def test_strict_does_not_fire_with_threshold_zero(self):
        """threshold=0.01 (minimum) should almost never fire on random latents."""
        self.gs.protection_mode = 'strict'
        self.gs.strict_threshold = 0.01

        x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
        should, _ = self.hj._should_fallback_for_protection(
            x_out, self.gs.batch_cond_indices, text_uncond)
        self.assertFalse(should,
                         "threshold=0.01 should almost never trigger on random latents")

    # --- _has_valid_base_path ---

    def test_has_valid_base_path_true(self):
        root = self.npp.parse_root("base AND_SALT tweak :1.0")
        self.assertTrue(self.hj._has_valid_base_path([root]))

    def test_has_valid_base_path_false(self):
        root = self.npp.parse_root("AND_SALT only :1.0")
        self.assertFalse(self.hj._has_valid_base_path([root]))

    def test_has_valid_base_path_empty(self):
        # Empty batch counts as "no valid base path".
        self.assertFalse(self.hj._has_valid_base_path([]))

    # --- _safe_norm ---

    def test_safe_norm_normal(self):
        t = self.torch.ones(4, 8, 8)
        self.assertGreater(self.hj._safe_norm(t), 0.0)

    def test_safe_norm_zeros(self):
        # The all-zero tensor has norm exactly 0 and must not raise.
        t = self.torch.zeros(4, 8, 8)
        self.assertAlmostEqual(self.hj._safe_norm(t), 0.0)
600
+
601
+
602
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestPerPromptProtection(unittest.TestCase):
    """Per-prompt structural check — a healthy prompt must not shield a bad neighbour."""

    @classmethod
    def setUpClass(cls):
        # Deferred imports so the skipUnless guard covers them.
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        import lib_neutral_prompt.global_state as gs
        import lib_neutral_prompt.neutral_prompt_parser as npp
        cls.torch = torch
        cls.hj = hj
        cls.gs = gs
        cls.npp = npp

    def setUp(self):
        # Default protection settings and empty prompt registry per test.
        self.gs.is_enabled = True
        self.gs.protection_mode = 'auto'
        self.gs.strict_threshold = 0.10
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    def tearDown(self):
        self.gs.is_enabled = False
        self.gs.protection_mode = 'auto'
        self.gs.prompt_exprs = []
        self.gs.batch_cond_indices = []

    # --- _get_invalid_base_prompt_indices ---

    def test_helper_all_good(self):
        roots = [
            self.npp.parse_root("base AND_SALT tweak :1.0"),
            self.npp.parse_root("base AND_PERP style :0.8"),
        ]
        self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [])

    def test_helper_all_bad(self):
        roots = [
            self.npp.parse_root("AND_SALT only :1.0"),
            self.npp.parse_root("AND_PERP only :0.8"),
        ]
        self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [0, 1])

    def test_helper_mixed_batch(self):
        roots = [
            self.npp.parse_root("base AND_SALT tweak :1.0"),  # index 0 — good
            self.npp.parse_root("AND_PERP only :0.8"),        # index 1 — bad
        ]
        self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [1])

    def test_helper_empty_exprs(self):
        self.assertEqual(self.hj._get_invalid_base_prompt_indices([]), [])

    # --- _prompt_has_valid_base_path ---

    def test_per_prompt_good(self):
        root = self.npp.parse_root("base AND_SALT tweak :1.0")
        self.assertTrue(self.hj._prompt_has_valid_base_path(root))

    def test_per_prompt_bad(self):
        root = self.npp.parse_root("AND_SALT only :1.0")
        self.assertFalse(self.hj._prompt_has_valid_base_path(root))

    # --- Mixed-batch policy: auto must fire even if index-0 is healthy ---

    def _run_mixed_batch(self, protection_mode):
        """Batch: prompt 0 OK, prompt 1 has no base."""
        root_good = self.npp.parse_root("base AND_SALT tweak :1.0")
        root_bad = self.npp.parse_root("AND_SALT only :1.0")
        self.gs.prompt_exprs = [root_good, root_bad]
        # Prompt 0 owns cond slots 0-1; prompt 1 owns slot 2.
        self.gs.batch_cond_indices = [[(0, 1.0), (1, 1.0)], [(2, 1.0)]]
        self.gs.protection_mode = protection_mode

        n_uncond = 1
        x_out = self.torch.randn(3 + n_uncond, 4, 8, 8)
        text_uncond = x_out[-n_uncond:]

        # Record shape[0] of every tensor the hijack passes to original_fn;
        # the full row count identifies the guard's fallback path.
        seen = []
        def fn(x, *a, **kw):
            seen.append(x.shape[0])
            return x[-1:]

        try:
            self.hj.combine_denoised_hijack(
                x_out, self.gs.batch_cond_indices, text_uncond, 7.0, fn)
        except Exception:
            pass  # downstream tensor errors are fine; we only inspect `seen`
        return seen, x_out.shape[0]

    def test_auto_fires_on_mixed_batch(self):
        seen, full = self._run_mixed_batch('auto')
        # Guard should have passed full x_out to original_fn (fallback path)
        self.assertTrue(len(seen) > 0 and seen[0] == full,
                        "Auto guard must fire when any prompt in the batch lacks a base path")

    def test_off_does_not_fire_on_mixed_batch(self):
        seen, full = self._run_mixed_batch('off')
        self.assertTrue(len(seen) == 0 or seen[0] != full,
                        "Off mode must never trigger the guard, even for bad prompts")
702
+
703
+
704
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
class TestSafeAffineInversion(unittest.TestCase):
    """_try_invert_affine: graceful fallback on singular / NaN matrices."""

    @classmethod
    def setUpClass(cls):
        import torch
        import lib_neutral_prompt.cfg_denoiser_hijack as hj
        import lib_neutral_prompt.global_state as gs
        cls.torch = torch
        cls.hj = hj
        cls.gs = gs

    def setUp(self):
        self.gs.verbose = False

    def tearDown(self):
        self.gs.verbose = False

    def test_normal_invertible_matrix(self):
        identity = self.torch.eye(3)
        inverted = self.hj._try_invert_affine(identity, 'test')
        self.assertTrue(self.torch.allclose(inverted, self.torch.eye(3)),
                        "Inverting identity should return identity")

    def test_singular_matrix_returns_identity(self):
        """SCALE[0] gives a singular matrix — must not crash."""
        degenerate = self.torch.zeros(3, 3)  # fully degenerate
        inverted = self.hj._try_invert_affine(degenerate, 'singular test')
        self.assertTrue(self.torch.allclose(inverted, self.torch.eye(3)),
                        "Singular matrix must fall back to identity")

    def test_nearly_singular_returns_identity(self):
        almost_singular = self.torch.eye(3)
        almost_singular[0, 0] = 1e-9  # effectively zero scale
        inverted = self.hj._try_invert_affine(almost_singular, 'nearly singular')
        # Either returns eye (fallback) or a finite result — must not be NaN/Inf
        self.assertTrue(self.torch.isfinite(inverted).all(),
                        "Result must be finite even for nearly-singular input")

    def test_verbose_warning_on_singular(self):
        import io
        import unittest.mock
        self.gs.verbose = True
        captured = io.StringIO()
        degenerate = self.torch.zeros(3, 3)
        with unittest.mock.patch('sys.stderr', captured):
            self.hj._try_invert_affine(degenerate, 'singular context')
        self.gs.verbose = False
        warning_text = captured.getvalue()
        self.assertIn('singular context', warning_text)
        self.assertIn('identity', warning_text.lower())

    def test_no_warning_when_verbose_false(self):
        import io
        import unittest.mock
        self.gs.verbose = False
        captured = io.StringIO()
        degenerate = self.torch.zeros(3, 3)
        with unittest.mock.patch('sys.stderr', captured):
            self.hj._try_invert_affine(degenerate, 'silent singular')
        self.assertEqual(captured.getvalue(), '',
                         "No warning expected when verbose=False")
764
+
765
+
766
class TestAffineParserValidation(unittest.TestCase):
    """Parser-level rejection of invalid affine params.

    These are pure no-crash tests: the contract under test is that malformed
    or singular affine parameters never raise out of the parser.
    """

    @classmethod
    def setUpClass(cls):
        # mock_torch already installed by package __init__
        import lib_neutral_prompt.neutral_prompt_parser as _p
        cls.p = _p

    def _parse(self, s):
        """Parse a full prompt string into its expression tree."""
        return self.p.parse_root(s)

    def _child(self, s, idx=-1):
        """Parse `s` and return one child of the root (last child by default)."""
        return self._parse(s).children[idx]

    # --- SCALE[0] must not produce a transform ---

    def test_scale_zero_rejected(self):
        """SCALE[0] must not crash. If a transform is stored it must come from
        the identity fallback (make_scale returned None → _apply_affine kept t)."""
        try:
            # Result intentionally discarded: the previous version bound it to
            # an unused local. Only absence of an exception is asserted.
            self._child("base AND_PERP SCALE[0] text :0.8")
        except Exception as exc:
            self.fail(f"SCALE[0] must not crash the parser: {exc}")
        # Either no transform stored (None) or the stored transform is the
        # identity that _apply_affine kept when make_scale returned None.
        # Both are acceptable — the key guarantee is no crash.

    def test_scale_nonzero_accepted(self):
        child = self._child("base AND_PERP SCALE[2] text :0.8")
        self.assertIsNotNone(child.local_transform,
                             "SCALE[2] should produce a valid transform")

    # --- SHEAR near 0.25 turns must be rejected ---

    def test_shear_safe_value_accepted(self):
        # Should produce a transform (0.1 turns is safe).
        # Don't assert non-None since mock torch may not multiply correctly,
        # just check no crash occurred. (Unused binding removed.)
        self._child("base AND_PERP SHEAR[0.1] text :0.8")

    def test_shear_dangerous_value_warning(self):
        """SHEAR[0.25] is exactly at the singularity — must not crash."""
        try:
            # Unused `root` binding removed; only crash-freedom is checked.
            self._parse("base AND_PERP SHEAR[0.25] text :0.8")
        except Exception as exc:
            self.fail(f"SHEAR[0.25] must not crash the parser: {exc}")

    # --- Non-finite params must not crash ---

    def test_rotate_no_crash(self):
        try:
            self._parse("base AND_PERP ROTATE[0.5] text :0.8")
        except Exception as exc:
            self.fail(f"ROTATE[0.5] crashed: {exc}")

    def test_slide_no_crash(self):
        try:
            self._parse("base AND_SALT SLIDE[0.1,0.2] text :1.0")
        except Exception as exc:
            self.fail(f"SLIDE crashed: {exc}")
neutral_prompt_patcheds/test/perp_parser/test_sprint2.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sprint 2 tests: protection_utils, affine_utils, step_utils, builder v1.2.
3
+
4
+ All pure-Python, no torch (torch-dependent cases skipped via mock).
5
+ """
6
+
7
+ import unittest
8
+
9
+ # mock_torch installed by package __init__
10
+
11
+ from lib_neutral_prompt import global_state
12
+ from lib_neutral_prompt.affine_utils import (
13
+ make_rotate, make_slide, make_scale, make_shear,
14
+ apply_affine, try_invert_affine,
15
+ build_affine_snippet, AFFINE_PRESETS, AFFINE_TRANSFORMS,
16
+ affine_transforms,
17
+ _make_rotate, _make_scale, _apply_affine, # alias checks
18
+ )
19
+ from lib_neutral_prompt.protection_utils import (
20
+ prompt_has_valid_base_path,
21
+ get_invalid_base_prompt_indices,
22
+ has_valid_base_path,
23
+ safe_norm,
24
+ compute_soft_attenuation,
25
+ protection_verdict,
26
+ )
27
+ from lib_neutral_prompt.step_utils import (
28
+ StepWindow, ALWAYS_ACTIVE, STRATEGY_DEFAULTS,
29
+ normalize_progress, strategy_is_active,
30
+ _strategy_family, render_step_window_summary,
31
+ )
32
+ from lib_neutral_prompt.matryoshka_utils import (
33
+ BuilderNode,
34
+ builder_add_child, builder_remove_node, builder_duplicate_node,
35
+ builder_move_up, builder_move_down, builder_update_node,
36
+ serialize_builder_tree, builder_tree_to_prompt,
37
+ )
38
+ import lib_neutral_prompt.neutral_prompt_parser as p
39
+
40
+ import sys as _sys
41
+
42
+ def _has_real_torch() -> bool:
43
+ """True iff real torch (not the test mock) is available."""
44
+ try:
45
+ import torch as _t
46
+ _t.tensor([1.0]).float().item()
47
+ _t.linalg.norm(_t.tensor([3.0, 4.0])).item()
48
+ return True
49
+ except Exception:
50
+ return False
51
+
52
+ _REAL_TORCH = _has_real_torch()
53
+ _skip_no_torch = unittest.skipUnless(_REAL_TORCH, 'requires real torch')
54
+
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # affine_utils
59
+ # ---------------------------------------------------------------------------
60
+
61
class TestAffineUtils(unittest.TestCase):
    """Matrix constructors, inversion, dispatch table, and snippet builder
    from lib_neutral_prompt.affine_utils.

    Tensor-producing cases are gated on _skip_no_torch; pure validation cases
    (None returns on bad params) run against the mock too.
    """

    # --- matrix constructors ---

    @_skip_no_torch
    def test_rotate_matrix_shape(self):
        import torch  # NOTE(review): unused here — make_rotate builds the tensor itself
        m = make_rotate(0.0)
        self.assertEqual(m.shape, (3, 3))

    @_skip_no_torch
    def test_rotate_zero_is_identity(self):
        import torch
        m = make_rotate(0.0)
        self.assertTrue(torch.allclose(m, torch.eye(3), atol=1e-5))

    @_skip_no_torch
    def test_rotate_quarter_turn(self):
        import torch
        m = make_rotate(0.25)  # 90°
        # Rotation is expressed in turns; 0.25 turns maps x→y, y→-x.
        expected = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=torch.float32)
        self.assertTrue(torch.allclose(m, expected, atol=1e-5))

    def test_rotate_nonfinite_returns_none(self):
        # Non-finite angles must be rejected with None, never raised.
        self.assertIsNone(make_rotate(float('inf')))
        self.assertIsNone(make_rotate(float('nan')))

    @_skip_no_torch
    def test_scale_uniform(self):
        import torch  # NOTE(review): unused here — kept byte-identical
        m = make_scale(2.0)
        # Single argument scales both axes uniformly.
        self.assertAlmostEqual(m[0, 0].item(), 2.0)
        self.assertAlmostEqual(m[1, 1].item(), 2.0)

    def test_scale_zero_rejected(self):
        # Zero on either axis makes the matrix singular — rejected with None.
        self.assertIsNone(make_scale(0.0))
        self.assertIsNone(make_scale(1.0, 0.0))
        self.assertIsNone(make_scale(0.0, 1.0))

    def test_scale_nonfinite_rejected(self):
        self.assertIsNone(make_scale(float('nan'), 1.0))

    @_skip_no_torch
    def test_shear_safe_range(self):
        import torch  # NOTE(review): unused here — kept byte-identical
        m = make_shear(0.1)
        self.assertIsNotNone(m)
        self.assertEqual(m.shape, (3, 3))

    def test_shear_unsafe_range_rejected(self):
        # ±0.25 turns is the tangent singularity — must be rejected.
        self.assertIsNone(make_shear(0.25))
        self.assertIsNone(make_shear(-0.25))

    @_skip_no_torch
    def test_slide_matrix(self):
        import torch  # NOTE(review): unused here — kept byte-identical
        m = make_slide(0.1, 0.2)
        # Translation components land in the last column.
        self.assertAlmostEqual(m[0, 2].item(), 0.1)
        self.assertAlmostEqual(m[1, 2].item(), 0.2)

    def test_slide_nonfinite_rejected(self):
        self.assertIsNone(make_slide(float('nan'), 0.0))

    @_skip_no_torch
    def test_apply_affine_none_returns_t(self):
        import torch
        t = torch.eye(3)
        # A None transform is a no-op: the exact same object comes back.
        self.assertIs(apply_affine(t, None), t)

    @_skip_no_torch
    def test_try_invert_identity(self):
        import torch
        eye = torch.eye(3)
        result = try_invert_affine(eye)
        self.assertTrue(torch.allclose(result, eye, atol=1e-6))

    @_skip_no_torch
    def test_try_invert_singular_returns_identity(self):
        import torch
        singular = torch.zeros(3, 3)
        # Singular input falls back to identity instead of raising.
        result = try_invert_affine(singular, 'test_singular')
        self.assertTrue(torch.allclose(result, torch.eye(3), atol=1e-6))

    def test_private_aliases_work(self):
        # Backward-compat aliases must point to the same function
        self.assertIs(_make_rotate, make_rotate)
        self.assertIs(_make_scale, make_scale)
        self.assertIs(_apply_affine, apply_affine)

    # --- parser dispatch table ---

    def test_affine_transforms_keys(self):
        # All four keyword transforms must be registered for the parser.
        self.assertIn('ROTATE', affine_transforms)
        self.assertIn('SCALE', affine_transforms)
        self.assertIn('SLIDE', affine_transforms)
        self.assertIn('SHEAR', affine_transforms)

    @_skip_no_torch
    def test_affine_transforms_callable(self):
        import torch
        t = torch.eye(3, dtype=torch.float32)
        # ROTATE by 0 turns composed onto identity stays identity.
        result = affine_transforms['ROTATE'](t, 0.0)
        self.assertTrue(torch.allclose(result, t, atol=1e-5))

    # --- snippet builder ---

    def test_snippet_flip_h(self):
        self.assertEqual(build_affine_snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')

    def test_snippet_flip_v(self):
        self.assertEqual(build_affine_snippet('FLIP_V', 0, 0), 'SCALE[1,-1]')

    def test_snippet_rotate(self):
        s = build_affine_snippet('ROTATE', 0.25, 0)
        self.assertEqual(s, 'ROTATE[0.25]')

    def test_snippet_scale_uniform(self):
        # Equal x/y collapse to the single-argument SCALE form.
        s = build_affine_snippet('SCALE', 2.0, 2.0)
        self.assertEqual(s, 'SCALE[2.0]')

    def test_snippet_scale_anisotropic(self):
        s = build_affine_snippet('SCALE', 1.5, 1.0)
        self.assertEqual(s, 'SCALE[1.5,1.0]')

    def test_snippet_parseable(self):
        # Every builder output must round-trip through the real prompt parser.
        for snip in ['ROTATE[0.125]', 'SCALE[-1,1]', 'SLIDE[0.1,0.0]']:
            with self.subTest(snip=snip):
                p.parse_root(f'base {snip} AND_PERP style :0.8')

    def test_presets_non_none_are_tuples(self):
        # Non-None presets are (transform_name, ...) tuples whose first element
        # names a registered transform.
        for name, val in AFFINE_PRESETS.items():
            with self.subTest(preset=name):
                if val is not None:
                    self.assertIsInstance(val, tuple)
                    self.assertIn(val[0], AFFINE_TRANSFORMS)
196
+
197
+
198
+ # ---------------------------------------------------------------------------
199
+ # protection_utils
200
+ # ---------------------------------------------------------------------------
201
+
202
class TestProtectionUtils(unittest.TestCase):
    """Base-path validation, safe_norm and soft-attenuation behaviour."""

    # --- prompt-tree fixtures ---

    def _leaf(self, strat=None):
        """Minimal leaf prompt carrying the given conciliation strategy."""
        return p.LeafPrompt(
            weight=1.0,
            conciliation=strat,
            local_transform=None,
            prompt='text',
            conciliation_params={},
        )

    def _composite(self, children):
        """Composite prompt wrapping *children* with no conciliation of its own."""
        return p.CompositePrompt(
            weight=1.0,
            conciliation=None,
            local_transform=None,
            children=children,
            conciliation_params={},
        )

    def _set_mode(self, mode):
        """Set the global protection mode; the old value is restored on cleanup."""
        previous = global_state.protection_mode
        self.addCleanup(setattr, global_state, 'protection_mode', previous)
        global_state.protection_mode = mode

    # --- structural checks ---

    def test_base_child_is_valid(self):
        root = self._composite([self._leaf(None)])
        self.assertTrue(prompt_has_valid_base_path(root))

    def test_all_conciliation_is_invalid(self):
        root = self._composite([self._leaf(p.ConciliationStrategy.PERPENDICULAR)])
        self.assertFalse(prompt_has_valid_base_path(root))

    def test_leaf_not_valid_base_path(self):
        self.assertFalse(prompt_has_valid_base_path(self._leaf(None)))

    def test_mixed_children_is_valid(self):
        root = self._composite([
            self._leaf(None),
            self._leaf(p.ConciliationStrategy.SALIENCE_MASK),
        ])
        self.assertTrue(prompt_has_valid_base_path(root))

    def test_invalid_indices_empty_on_ok_batch(self):
        valid = self._composite([self._leaf(None)])
        self.assertEqual(get_invalid_base_prompt_indices([valid]), [])

    def test_invalid_indices_reports_bad_positions(self):
        valid = self._composite([self._leaf(None)])
        invalid = self._composite([self._leaf(p.ConciliationStrategy.PERPENDICULAR)])
        self.assertEqual(get_invalid_base_prompt_indices([valid, invalid, valid]), [1])

    def test_has_valid_base_path_empty_is_false(self):
        self.assertFalse(has_valid_base_path([]))

    def test_has_valid_base_path_all_ok(self):
        valid = self._composite([self._leaf(None)])
        self.assertTrue(has_valid_base_path([valid]))

    def test_has_valid_base_path_mixed_batch_is_false(self):
        valid = self._composite([self._leaf(None)])
        invalid = self._composite([self._leaf(p.ConciliationStrategy.PERPENDICULAR)])
        self.assertFalse(has_valid_base_path([valid, invalid]))

    # --- safe_norm ---

    @_skip_no_torch
    def test_safe_norm_positive(self):
        import torch
        vec = torch.tensor([3.0, 4.0])
        self.assertAlmostEqual(safe_norm(vec), 5.0, places=4)

    @_skip_no_torch
    def test_safe_norm_zero_tensor(self):
        import torch
        self.assertEqual(safe_norm(torch.zeros(3)), 0.0)

    @_skip_no_torch
    def test_safe_norm_nan_returns_zero(self):
        import torch
        vec = torch.tensor([float('nan'), 1.0])
        self.assertEqual(safe_norm(vec), 0.0)

    # --- soft attenuation ---

    def test_attenuation_no_change_when_ratio_ok(self):
        # base_norm / aux_norm = 0.8 / 0.4 = 2.0 ≥ threshold 0.1
        self.assertAlmostEqual(compute_soft_attenuation(0.8, 0.4, 0.1), 1.0)

    def test_attenuation_reduces_when_ratio_low(self):
        # base=0.05, aux=1.0, threshold=0.2 → ratio=0.05 < 0.2
        # factor = 0.05 / (1.0 * 0.2) = 0.25
        self.assertAlmostEqual(compute_soft_attenuation(0.05, 1.0, 0.2), 0.25, places=5)

    def test_attenuation_zero_aux_returns_one(self):
        self.assertAlmostEqual(compute_soft_attenuation(0.5, 0.0, 0.1), 1.0)

    def test_attenuation_clamped_to_zero_min(self):
        self.assertGreaterEqual(compute_soft_attenuation(-1.0, 1.0, 0.2), 0.0)

    # --- protection_verdict ---

    def test_verdict_off(self):
        self._set_mode('off')
        status, msg = protection_verdict('base AND_PERP style :0.8')
        self.assertEqual(status, 'off')

    def test_verdict_ok_with_base(self):
        self._set_mode('auto')
        status, msg = protection_verdict('base subject AND_SALT[5] texture :0.8')
        self.assertEqual(status, 'ok')

    def test_verdict_fire_no_base(self):
        self._set_mode('auto')
        status, msg = protection_verdict('AND_SALT[5] only :0.8')
        self.assertEqual(status, 'fire')

    def test_verdict_soft_mode_mentions_attenuation(self):
        self._set_mode('soft')
        status, msg = protection_verdict('base AND_SALT[5] texture :0.8')
        self.assertEqual(status, 'ok')
        self.assertIn('attenuat', msg.lower())

    def test_normalize_soft_is_valid(self):
        self.assertEqual(global_state.normalize_protection_mode('soft'), 'soft')

    def test_normalize_unknown_falls_back_to_auto(self):
        self.assertEqual(global_state.normalize_protection_mode('INVALID'), 'auto')
341
+
342
+
343
+ # ---------------------------------------------------------------------------
344
+ # step_utils
345
+ # ---------------------------------------------------------------------------
346
+
347
class TestStepWindow(unittest.TestCase):
    """StepWindow clamping, activation checks and repr.

    Fix: the loop variable in test_always_active_is_always_active was named
    ``p``, shadowing the module-level parser alias ``p`` used throughout this
    file; it is renamed to ``progress``.
    """

    def test_clamp_start_to_zero(self):
        w = StepWindow(-0.5, 0.8)
        self.assertAlmostEqual(w.start, 0.0)

    def test_clamp_end_to_one(self):
        w = StepWindow(0.2, 2.0)
        self.assertAlmostEqual(w.end, 1.0)

    def test_start_gt_end_clamped(self):
        w = StepWindow(0.8, 0.2)
        self.assertLessEqual(w.start, w.end)

    def test_is_active_inside(self):
        w = StepWindow(0.3, 0.7)
        self.assertTrue(w.is_active_at(0.5))

    def test_is_active_at_boundary(self):
        # Both endpoints are inclusive.
        w = StepWindow(0.3, 0.7)
        self.assertTrue(w.is_active_at(0.3))
        self.assertTrue(w.is_active_at(0.7))

    def test_is_active_outside(self):
        w = StepWindow(0.3, 0.7)
        self.assertFalse(w.is_active_at(0.1))
        self.assertFalse(w.is_active_at(0.9))

    def test_always_active_is_always_active(self):
        # Renamed from `p` to avoid shadowing the parser module alias.
        for progress in [0.0, 0.5, 1.0]:
            self.assertTrue(ALWAYS_ACTIVE.is_active_at(progress))

    def test_repr(self):
        w = StepWindow(0.3, 0.7)
        self.assertIn('0.30', repr(w))
        self.assertIn('0.70', repr(w))
383
+
384
+
385
class TestNormalizeProgress(unittest.TestCase):
    """normalize_progress maps (step, total) onto the [0, 1] interval.

    Fix: the local result variable was named ``p``, shadowing the
    module-level parser alias ``p``; it is renamed to ``progress``.
    """

    def test_first_step(self):
        self.assertAlmostEqual(normalize_progress(0, 20), 0.0)

    def test_last_step(self):
        self.assertAlmostEqual(normalize_progress(19, 20), 1.0)

    def test_midpoint(self):
        progress = normalize_progress(10, 21)
        self.assertAlmostEqual(progress, 10 / 20)

    def test_single_step(self):
        # A one-step run normalizes to the start of the interval.
        self.assertAlmostEqual(normalize_progress(0, 1), 0.0)

    def test_clamp_to_zero_one(self):
        self.assertGreaterEqual(normalize_progress(0, 100), 0.0)
        self.assertLessEqual(normalize_progress(99, 100), 1.0)
403
+
404
+
405
class TestStrategyIsActive(unittest.TestCase):
    """Gating of conciliation strategies by global/per-strategy step windows."""

    def test_no_window_always_active(self):
        self.assertTrue(strategy_is_active('SALIENCE_MASK', 0.5))
        self.assertTrue(strategy_is_active('PERPENDICULAR', 0.0))

    def test_global_window_gates(self):
        window = StepWindow(0.4, 0.8)
        self.assertTrue(strategy_is_active('SALIENCE_MASK', 0.6, global_window=window))
        self.assertFalse(strategy_is_active('SALIENCE_MASK', 0.1, global_window=window))

    def test_per_strategy_overrides_global(self):
        # At 0.3: the global window says active, the per-strategy one does not.
        overrides = {'SALIENCE_MASK': StepWindow(0.6, 1.0)}
        active = strategy_is_active(
            'SALIENCE_MASK', 0.3,
            global_window=StepWindow(0.0, 1.0),
            per_strategy_windows=overrides,
        )
        self.assertFalse(active)

    def test_defaults_perp_early(self):
        # PERPENDICULAR default window: 0.0–0.5
        self.assertTrue(strategy_is_active('PERPENDICULAR', 0.3, use_defaults=True))
        self.assertFalse(strategy_is_active('PERPENDICULAR', 0.8, use_defaults=True))

    def test_defaults_salt_late(self):
        # SALIENCE_MASK default window: 0.4–1.0
        self.assertTrue(strategy_is_active('SALIENCE_MASK', 0.7, use_defaults=True))
        self.assertFalse(strategy_is_active('SALIENCE_MASK', 0.1, use_defaults=True))

    def test_legacy_alignment_strategy_family(self):
        self.assertEqual(
            _strategy_family('ALIGNMENT_BLEND_4_8'), 'ALIGNMENT_BLEND_CUSTOM')

    def test_legacy_mask_alignment_family(self):
        self.assertEqual(
            _strategy_family('ALIGNMENT_MASK_BLEND_8_16'), 'ALIGNMENT_MASK_BLEND_CUSTOM')

    def test_base_strategy_family_unchanged(self):
        self.assertEqual(_strategy_family('PERPENDICULAR'), 'PERPENDICULAR')
444
+
445
+
446
class TestStepWindowSummary(unittest.TestCase):
    """Human-readable step-window summary rendering."""

    def test_empty_strategies_returns_empty(self):
        self.assertEqual(render_step_window_summary([]), '')

    def test_no_window_shows_all_steps(self):
        summary = render_step_window_summary(['SALIENCE_MASK'])
        self.assertIn('all steps', summary)

    def test_global_window_shown(self):
        summary = render_step_window_summary(
            ['SALIENCE_MASK'], global_window=StepWindow(0.3, 0.8))
        self.assertIn('SALIENCE_MASK', summary)
        self.assertIn('0.30', summary)

    def test_defaults_shown(self):
        summary = render_step_window_summary(['PERPENDICULAR'], use_defaults=True)
        self.assertIn('PERPENDICULAR', summary)
        self.assertIn('default', summary)

    def test_progress_shows_active_status(self):
        summary = render_step_window_summary(
            ['SALIENCE_MASK'], progress=0.5, global_window=StepWindow(0.4, 0.9))
        self.assertIn('active', summary)
471
+
472
+
473
+ # ---------------------------------------------------------------------------
474
+ # Builder v1.2
475
+ # ---------------------------------------------------------------------------
476
+
477
class TestBuilderNode(unittest.TestCase):
    """BuilderNode.make factory behaviour."""

    def test_make_creates_node(self):
        node = BuilderNode.make('AND_SALT', 'texture', 0.8)
        self.assertEqual(node.strategy, 'AND_SALT')
        self.assertEqual(node.text, 'texture')
        self.assertAlmostEqual(node.weight, 0.8)
        self.assertIsNotNone(node.node_id)

    def test_make_unique_ids(self):
        # Twenty freshly made nodes must yield twenty distinct ids.
        seen = {BuilderNode.make().node_id for _ in range(20)}
        self.assertEqual(len(seen), 20)

    def test_default_no_children(self):
        self.assertEqual(BuilderNode.make().children, [])
493
+
494
+
495
class TestBuilderAddChild(unittest.TestCase):
    """builder_add_child appends (possibly nested) nodes without mutating input."""

    def test_add_to_empty(self):
        result = builder_add_child([], strategy='AND_SALT', text='tex', weight=0.8)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].strategy, 'AND_SALT')

    def test_add_appends(self):
        first = BuilderNode.make('AND_SALT', 'tex', 0.8)
        result = builder_add_child([first], strategy='AND_PERP', text='style', weight=0.5)
        self.assertEqual(len(result), 2)

    def test_add_nested_child(self):
        parent = BuilderNode.make('AND_SALT', 'tex', 0.8)
        result = builder_add_child(
            [parent], parent_id=parent.node_id,
            strategy='AND_TOPK', text='detail', weight=0.4)
        self.assertEqual(len(result), 1)
        self.assertEqual(len(result[0].children), 1)
        self.assertEqual(result[0].children[0].strategy, 'AND_TOPK')

    def test_original_not_mutated(self):
        source = [BuilderNode.make()]
        builder_add_child(source)
        self.assertEqual(len(source), 1)
519
+
520
+
521
class TestBuilderRemoveNode(unittest.TestCase):
    """builder_remove_node drops nodes at any depth, tolerating unknown ids."""

    def test_remove_top_level(self):
        first, second = BuilderNode.make(), BuilderNode.make()
        remaining = builder_remove_node([first, second], first.node_id)
        self.assertEqual(len(remaining), 1)
        self.assertEqual(remaining[0].node_id, second.node_id)

    def test_remove_nested(self):
        inner = BuilderNode.make('AND_TOPK', 'c', 0.5)
        outer = dataclasses_replace(
            BuilderNode.make('AND_SALT', 'p', 0.8), children=[inner])
        remaining = builder_remove_node([outer], inner.node_id)
        self.assertEqual(len(remaining[0].children), 0)

    def test_remove_nonexistent_no_error(self):
        node = BuilderNode.make()
        remaining = builder_remove_node([node], 'nonexistent')
        self.assertEqual(len(remaining), 1)
541
+
542
+
543
class TestBuilderDuplicateNode(unittest.TestCase):
    """builder_duplicate_node copies content but assigns a fresh node id."""

    def test_duplicate_creates_new_id(self):
        node = BuilderNode.make('AND_SALT', 'tex', 0.8)
        nodes = builder_duplicate_node([node], node.node_id)
        self.assertEqual(len(nodes), 2)
        self.assertNotEqual(nodes[0].node_id, nodes[1].node_id)

    def test_duplicate_copies_content(self):
        node = BuilderNode.make('AND_SALT', 'tex', 0.8)
        nodes = builder_duplicate_node([node], node.node_id)
        self.assertEqual(nodes[1].text, 'tex')
        self.assertEqual(nodes[1].strategy, 'AND_SALT')
556
+
557
+
558
class TestBuilderMoveNodes(unittest.TestCase):
    """builder_move_up / builder_move_down swap adjacent siblings."""

    @staticmethod
    def _make_nodes(*strategies):
        """One node per strategy name; text mirrors the strategy for tracing."""
        return [BuilderNode.make(name, name, 1.0) for name in strategies]

    def test_move_up_swaps(self):
        nodes = self._make_nodes('AND_SALT', 'AND_PERP', 'AND_TOPK')
        moved = builder_move_up(nodes, nodes[1].node_id)
        self.assertEqual(moved[0].strategy, 'AND_PERP')
        self.assertEqual(moved[1].strategy, 'AND_SALT')

    def test_move_up_first_no_change(self):
        nodes = self._make_nodes('AND_SALT', 'AND_PERP')
        moved = builder_move_up(nodes, nodes[0].node_id)
        self.assertEqual(moved[0].strategy, 'AND_SALT')

    def test_move_down_swaps(self):
        nodes = self._make_nodes('AND_SALT', 'AND_PERP', 'AND_TOPK')
        moved = builder_move_down(nodes, nodes[1].node_id)
        self.assertEqual(moved[1].strategy, 'AND_TOPK')
        self.assertEqual(moved[2].strategy, 'AND_PERP')

    def test_move_down_last_no_change(self):
        nodes = self._make_nodes('AND_SALT', 'AND_PERP')
        moved = builder_move_down(nodes, nodes[-1].node_id)
        self.assertEqual(moved[-1].strategy, 'AND_PERP')
584
+
585
+
586
class TestBuilderUpdateNode(unittest.TestCase):
    """builder_update_node patches node fields by id."""

    def test_update_text(self):
        node = BuilderNode.make('AND_SALT', 'old', 0.8)
        updated = builder_update_node([node], node.node_id, text='new')
        self.assertEqual(updated[0].text, 'new')

    def test_update_weight(self):
        node = BuilderNode.make('AND_SALT', 'tex', 0.8)
        updated = builder_update_node([node], node.node_id, weight=0.3)
        self.assertAlmostEqual(updated[0].weight, 0.3)

    def test_update_nonexistent_no_change(self):
        node = BuilderNode.make('AND_SALT', 'tex', 0.8)
        updated = builder_update_node([node], 'nope', text='changed')
        self.assertEqual(updated[0].text, 'tex')
602
+
603
+
604
class TestSerializeBuilderTree(unittest.TestCase):
    """builder_tree_to_prompt serialises a node tree back into prompt syntax."""

    def test_base_only(self):
        self.assertEqual(builder_tree_to_prompt('base subject', []).strip(), 'base subject')

    def test_single_node(self):
        rendered = builder_tree_to_prompt(
            'base', [BuilderNode.make('AND_SALT', 'texture', 0.8)])
        for fragment in ('AND_SALT[5]', 'texture', ':0.8'):
            self.assertIn(fragment, rendered)

    def test_nested_node_renders_brackets(self):
        inner = BuilderNode.make('AND_TOPK', 'highlights', 0.4)
        outer = dataclasses_replace(
            BuilderNode.make('AND_SALT', 'texture', 0.8), children=[inner])
        rendered = builder_tree_to_prompt('base', [outer])
        self.assertIn('[', rendered)
        self.assertIn(']', rendered)
        self.assertIn('AND_TOPK', rendered)

    def test_generated_prompt_parseable(self):
        rendered = builder_tree_to_prompt('base', [
            BuilderNode.make('AND_SALT', 'texture', 0.8),
            BuilderNode.make('AND_PERP', 'style', 0.5),
        ])
        p.parse_root(rendered)  # must not raise

    def test_empty_text_node_skipped(self):
        rendered = builder_tree_to_prompt(
            'base', [BuilderNode.make('AND_SALT', '', 0.8)])
        self.assertNotIn('AND_SALT', rendered)

    def test_affine_prefix_included(self):
        node = BuilderNode.make('AND_PERP', 'mirror', 0.6, affine='SCALE[-1,1]')
        self.assertIn('SCALE[-1,1]', builder_tree_to_prompt('base', [node]))

    def test_two_top_level_children(self):
        rendered = builder_tree_to_prompt('base', [
            BuilderNode.make('AND_SALT', 'tex', 0.8),
            BuilderNode.make('AND_ALIGN', 'shape', 0.7),
        ])
        self.assertIn('AND_SALT', rendered)
        self.assertIn('AND_ALIGN', rendered)
648
+
649
+
650
# Short module-level alias used by the builder tests above.
# NOTE(review): the old comment claimed this was a "Python <3.13" shim, but
# dataclasses.replace has existed since Python 3.7 — the alias is kept purely
# for brevity at call sites.
import dataclasses

dataclasses_replace = dataclasses.replace
653
+
654
+
655
# Allow running this test module directly (e.g. `python <file>.py`).
if __name__ == '__main__':
    unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_sprint2_hotfix.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sprint 2 hotfix tests — covers all 4 issues patched after Sprint 2 review.
3
+
4
+ 1. BUILDER_STRATEGIES name fix (smoke-test that the name resolves)
5
+ 2. update_step_state helper + step-counter round-trip
6
+ 3. get_extra_generation_params / unpack_processing_args step-window round-trip
7
+ 4. apply_infotext step-window restore
8
+ + 'custom' mode absent from two-mode choices
9
+ """
10
+
11
+ import unittest
12
+
13
+ import lib_neutral_prompt.neutral_prompt_parser as p
14
+ from lib_neutral_prompt import global_state
15
+ from lib_neutral_prompt.matryoshka_utils import BUILDER_STRATEGIES
16
+ from lib_neutral_prompt.step_utils import StepWindow, strategy_is_active, normalize_progress
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Fix 1 — BUILDER_STRATEGIES is importable and non-empty
21
+ # ---------------------------------------------------------------------------
22
+
23
class TestBuilderStrategiesName(unittest.TestCase):
    """Fix 1 — BUILDER_STRATEGIES resolves, is non-empty, and is canonical."""

    def test_importable_and_non_empty(self):
        self.assertIsInstance(BUILDER_STRATEGIES, list)
        self.assertGreater(len(BUILDER_STRATEGIES), 0)

    def test_contains_core_strategies(self):
        for strategy in ('AND_SALT', 'AND_PERP', 'AND_TOPK', 'AND_ALIGN'):
            self.assertIn(strategy, BUILDER_STRATEGIES)

    def test_no_underscore_variant_exported(self):
        """The UI used to import _BUILDER_STRATEGIES — make sure that's gone."""
        import lib_neutral_prompt.matryoshka_utils as mu
        self.assertFalse(
            hasattr(mu, '_BUILDER_STRATEGIES'),
            '_BUILDER_STRATEGIES should not be exported; use BUILDER_STRATEGIES',
        )
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Fix 2 — update_step_state + step-window gating integration
46
+ # ---------------------------------------------------------------------------
47
+
48
class TestUpdateStepState(unittest.TestCase):
    """Fix 2 — global_state.update_step_state sanitises and stores counters.

    Fix applied here: two tests bound the computed progress to a local named
    ``p``, shadowing the module-level parser alias ``p`` imported at the top
    of this file; the local is renamed to ``progress``.
    """

    def setUp(self):
        self._saved_step = global_state.current_step
        self._saved_total = global_state.total_steps

    def tearDown(self):
        global_state.current_step = self._saved_step
        global_state.total_steps = self._saved_total

    def test_basic_update(self):
        global_state.update_step_state(5, 20)
        self.assertEqual(global_state.current_step, 5)
        self.assertEqual(global_state.total_steps, 20)

    def test_float_inputs_converted(self):
        global_state.update_step_state(3.9, 19.1)
        self.assertEqual(global_state.current_step, 3)
        self.assertEqual(global_state.total_steps, 19)

    def test_negative_step_clamped_to_zero(self):
        global_state.update_step_state(-1, 10)
        self.assertEqual(global_state.current_step, 0)

    def test_zero_total_clamped_to_one(self):
        global_state.update_step_state(0, 0)
        self.assertEqual(global_state.total_steps, 1)

    def test_bad_types_do_not_crash(self):
        global_state.update_step_state('bad', None)
        # Should stay at some valid state
        self.assertGreaterEqual(global_state.current_step, 0)
        self.assertGreaterEqual(global_state.total_steps, 1)

    def test_first_step_progress_is_zero(self):
        global_state.update_step_state(0, 20)
        progress = normalize_progress(global_state.current_step, global_state.total_steps)
        self.assertAlmostEqual(progress, 0.0)

    def test_last_step_progress_is_one(self):
        global_state.update_step_state(19, 20)
        progress = normalize_progress(global_state.current_step, global_state.total_steps)
        self.assertAlmostEqual(progress, 1.0)
91
+
92
+
93
class TestStepWindowGatingIntegration(unittest.TestCase):
    """
    Simulate the hijack path: update_step_state → normalize_progress
    → strategy_is_active with a global window (fix-2 specification).
    """

    def setUp(self):
        # Snapshot the step counters and restore them whatever happens.
        self.addCleanup(setattr, global_state, 'current_step', global_state.current_step)
        self.addCleanup(setattr, global_state, 'total_steps', global_state.total_steps)

    @staticmethod
    def _progress_for(step, total):
        """Push a step into global state and return its normalised progress."""
        global_state.update_step_state(step, total)
        return normalize_progress(global_state.current_step, global_state.total_steps)

    def test_and_salt_gated_out_in_early_window(self):
        """AND_SALT with window 0.6–1.0 should be skipped at step 0/9."""
        progress = self._progress_for(0, 10)
        active = strategy_is_active('SALIENCE_MASK', progress,
                                    global_window=StepWindow(0.6, 1.0))
        self.assertFalse(active, f'Expected inactive at progress={progress:.3f}')

    def test_and_salt_active_in_late_window(self):
        """AND_SALT with window 0.6–1.0 should be active at step 8/10."""
        progress = self._progress_for(8, 10)
        active = strategy_is_active('SALIENCE_MASK', progress,
                                    global_window=StepWindow(0.6, 1.0))
        self.assertTrue(active, f'Expected active at progress={progress:.3f}')

    def test_and_perp_active_early_with_defaults(self):
        """AND_PERP defaults: 0.0–0.5 → step 2/10 should be active."""
        progress = self._progress_for(2, 10)
        active = strategy_is_active('PERPENDICULAR', progress, use_defaults=True)
        self.assertTrue(active, f'Expected active at progress={progress:.3f}')

    def test_and_perp_gated_out_late_with_defaults(self):
        """AND_PERP defaults: 0.0–0.5 → step 8/10 should be inactive."""
        progress = self._progress_for(8, 10)
        active = strategy_is_active('PERPENDICULAR', progress, use_defaults=True)
        self.assertFalse(active, f'Expected inactive at progress={progress:.3f}')

    def test_no_window_always_active_regardless_of_step(self):
        for step in (0, 5, 9):
            with self.subTest(step=step):
                progress = self._progress_for(step, 10)
                self.assertTrue(strategy_is_active('SALIENCE_MASK', progress))
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # Fix 3 — step-window round-trip through unpack_processing_args + get_extra
151
+ # ---------------------------------------------------------------------------
152
+
153
def _unpack_step_window(cfg_rescale=0.0, protection_mode='auto', strict_threshold=0.1,
                        step_window_enabled=False, step_window_mode='global',
                        step_window_start=0.0, step_window_end=1.0):
    """
    Pure-logic equivalent of AccordionInterface.unpack_processing_args
    without any gradio dependency — mirrors the real implementation exactly.
    """
    enabled = bool(step_window_enabled)
    mode = str(step_window_mode)

    def _clamp01(value):
        # Keep both bounds inside [0, 1].
        return max(0.0, min(1.0, float(value)))

    start = _clamp01(step_window_start)
    end = _clamp01(step_window_end)
    if start > end:
        # Inverted range → fall back to the full window.
        start, end = 0.0, 1.0

    # Resolve the (window, use_defaults) pair for the selected mode;
    # anything other than the two released modes gets the safe fallback.
    window = None
    use_defaults = False
    if enabled and mode == 'global':
        window = StepWindow(start, end)
    elif enabled and mode == 'per-strategy defaults':
        use_defaults = True

    global_state.step_window_enabled = enabled
    global_state.step_window_global = window
    global_state.step_window_use_defaults = use_defaults

    return {'cfg_rescale': cfg_rescale, 'protection_mode': protection_mode,
            'strict_threshold': strict_threshold,
            'step_window_enabled': enabled, 'step_window_mode': mode,
            'step_window_start': start, 'step_window_end': end}
186
+
187
+
188
+ def _get_extra_params(args):
189
+ """Pure-logic equivalent of AccordionInterface.get_extra_generation_params."""
190
+ p = {'CFG Rescale phi': args['cfg_rescale'],
191
+ 'NP Protection Mode': args['protection_mode'],
192
+ 'NP Strict Threshold': args['strict_threshold']}
193
+ if args.get('step_window_enabled'):
194
+ p['NP Step Window Enabled'] = True
195
+ p['NP Step Window Mode'] = args.get('step_window_mode', 'global')
196
+ p['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
197
+ p['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
198
+ return p
199
+
200
+
201
class TestStepWindowProcessingRoundTrip(unittest.TestCase):
    """Fix 3 — step-window settings survive the unpack → extra-params round-trip."""

    def setUp(self):
        # Snapshot every step-window global and restore it on cleanup.
        for attr in ('step_window_enabled', 'step_window_global',
                     'step_window_use_defaults'):
            self.addCleanup(setattr, global_state, attr, getattr(global_state, attr))

    def test_global_window_written_to_state(self):
        _unpack_step_window(
            cfg_rescale=0.0,
            step_window_enabled=True,
            step_window_mode='global',
            step_window_start=0.3,
            step_window_end=0.8,
        )
        self.assertTrue(global_state.step_window_enabled)
        window = global_state.step_window_global
        self.assertIsNotNone(window)
        self.assertAlmostEqual(window.start, 0.3, places=4)
        self.assertAlmostEqual(window.end, 0.8, places=4)

    def test_defaults_mode_sets_use_defaults(self):
        _unpack_step_window(
            cfg_rescale=0.0,
            step_window_enabled=True,
            step_window_mode='per-strategy defaults',
        )
        self.assertTrue(global_state.step_window_use_defaults)
        self.assertIsNone(global_state.step_window_global)

    def test_disabled_clears_state(self):
        global_state.step_window_enabled = True
        global_state.step_window_global = StepWindow(0.3, 0.8)
        _unpack_step_window(cfg_rescale=0.0, step_window_enabled=False)
        self.assertFalse(global_state.step_window_enabled)
        self.assertIsNone(global_state.step_window_global)

    def test_invalid_range_falls_back_to_full(self):
        # start > end → the helper must fall back to the full 0–1 window.
        args = _unpack_step_window(
            cfg_rescale=0.0,
            step_window_enabled=True,
            step_window_mode='global',
            step_window_start=0.9,
            step_window_end=0.1,
        )
        self.assertAlmostEqual(args['step_window_start'], 0.0)
        self.assertAlmostEqual(args['step_window_end'], 1.0)

    def test_extra_params_include_step_window_when_enabled(self):
        params = _get_extra_params({
            'cfg_rescale': 0.0,
            'protection_mode': 'auto',
            'strict_threshold': 0.1,
            'step_window_enabled': True,
            'step_window_mode': 'global',
            'step_window_start': 0.3,
            'step_window_end': 0.8,
        })
        for key in ('NP Step Window Enabled', 'NP Step Window Mode',
                    'NP Step Window Start', 'NP Step Window End'):
            self.assertIn(key, params)
        self.assertAlmostEqual(params['NP Step Window Start'], 0.3, places=4)

    def test_extra_params_omit_step_window_when_disabled(self):
        params = _get_extra_params({
            'cfg_rescale': 0.0,
            'protection_mode': 'auto',
            'strict_threshold': 0.1,
            'step_window_enabled': False,
        })
        self.assertNotIn('NP Step Window Enabled', params)
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Fix 3 — apply_infotext step-window restore
283
+ # ---------------------------------------------------------------------------
284
+
285
class TestApplyInfotextStepWindow(unittest.TestCase):
    """Fix 3 — apply_infotext restores step-window state from saved metadata."""

    def setUp(self):
        # Snapshot every step-window global and restore it on cleanup.
        for attr in ('step_window_enabled', 'step_window_global',
                     'step_window_use_defaults'):
            self.addCleanup(setattr, global_state, attr, getattr(global_state, attr))

    def test_global_window_restored(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'global',
            'NP Step Window Start': '0.25',
            'NP Step Window End': '0.75',
        })
        self.assertTrue(global_state.step_window_enabled)
        window = global_state.step_window_global
        self.assertIsNotNone(window)
        self.assertAlmostEqual(window.start, 0.25, places=4)
        self.assertAlmostEqual(window.end, 0.75, places=4)

    def test_defaults_mode_restored(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'per-strategy defaults',
        })
        self.assertTrue(global_state.step_window_enabled)
        self.assertTrue(global_state.step_window_use_defaults)
        self.assertIsNone(global_state.step_window_global)

    def test_disabled_not_restored_when_key_absent(self):
        global_state.step_window_enabled = True
        global_state.apply_infotext({})  # no step-window keys
        # Should be unchanged
        self.assertTrue(global_state.step_window_enabled)

    def test_false_string_disables(self):
        global_state.step_window_enabled = True
        global_state.apply_infotext({'NP Step Window Enabled': 'False'})
        self.assertFalse(global_state.step_window_enabled)

    def test_invalid_range_falls_back_to_full(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'global',
            'NP Step Window Start': '0.9',
            'NP Step Window End': '0.1',  # inverted
        })
        window = global_state.step_window_global
        self.assertAlmostEqual(window.start, 0.0, places=4)
        self.assertAlmostEqual(window.end, 1.0, places=4)

    def test_old_infotext_without_step_window_keys_does_not_crash(self):
        global_state.apply_infotext({
            'CFG Rescale phi': '0.5',
            'NP Protection Mode': 'auto',
        })
        # No crash, step_window unchanged
346
+
347
+
348
+ # ---------------------------------------------------------------------------
349
+ # Fix 4 — 'custom' absent from step_window_mode UI choices
350
+ # ---------------------------------------------------------------------------
351
+
352
class TestCustomModeRemovedFromUI(unittest.TestCase):
    """Fix 4 — 'custom' is no longer a step-window mode choice in the UI."""

    def setUp(self):
        # Snapshot every step-window global and restore it on cleanup.
        for attr in ('step_window_enabled', 'step_window_global',
                     'step_window_use_defaults'):
            self.addCleanup(setattr, global_state, attr, getattr(global_state, attr))

    def test_step_window_update_handler_accepts_two_modes(self):
        """
        The step-window update logic must handle exactly the two released modes
        without any 'custom' branch.
        """
        for mode in ('global', 'per-strategy defaults'):
            with self.subTest(mode=mode):
                global_state.step_window_enabled = True
                if mode == 'global':
                    global_state.step_window_global = StepWindow(0.3, 0.8)
                    global_state.step_window_use_defaults = False
                    # Verify state is consistent for this mode
                    self.assertIsNotNone(global_state.step_window_global)
                else:
                    global_state.step_window_global = None
                    global_state.step_window_use_defaults = True
                    self.assertTrue(global_state.step_window_use_defaults)

    def test_step_utils_still_supports_custom_internally(self):
        """
        custom is removed from UI but step_utils.strategy_is_active must
        still accept None per_strategy_windows gracefully (for API use).
        """
        from lib_neutral_prompt.step_utils import strategy_is_active
        # custom = no global_window, no per_strategy_windows, no defaults
        active = strategy_is_active('SALIENCE_MASK', 0.5,
                                    global_window=None,
                                    per_strategy_windows=None,
                                    use_defaults=False)
        self.assertTrue(active, 'Without any window, strategy must be always active')

    def test_unpack_does_not_expose_custom_mode_to_state(self):
        """
        If someone passes 'custom' through the API, unpack must fall back
        to a safe state (no window, no defaults).
        """
        _unpack_step_window(
            cfg_rescale=0.0,
            step_window_enabled=True,
            step_window_mode='custom',  # not in UI, but API-safe
        )
        # Fallback branch: global=None, use_defaults=False
        self.assertIsNone(global_state.step_window_global)
        self.assertFalse(global_state.step_window_use_defaults)
416
+
417
+
418
# Allow running this test module directly (e.g. `python test_sprint2_hotfix.py`).
if __name__ == '__main__':
    unittest.main()
neutral_prompt_patcheds/test/perp_parser/test_stabilization.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stabilization test suite — Sprint 2 hotfix + per-strategy custom windows.
3
+
4
+ Sections
5
+ --------
6
+ 1. UI smoke-test – verifies ui.py imports and AccordionInterface constructs
7
+ without NameError / bad Gradio kwargs
8
+ 2. Step-window integration – update_step_state → progress → strategy activation
9
+ 3. Generation params round-trip – serialize → apply_infotext → state restored
10
+ 4. Per-strategy custom windows – serialize/deserialize/build helpers
11
+ 5. Per-strategy round-trip – custom mode infotext restore
12
+ """
13
+
14
+ import sys
15
+ import types
16
+ import unittest
17
+ import importlib
18
+
19
+ from lib_neutral_prompt import global_state
20
+ from lib_neutral_prompt.step_utils import (
21
+ StepWindow,
22
+ normalize_progress,
23
+ strategy_is_active,
24
+ STRATEGY_UI_KEYS,
25
+ STRATEGY_DEFAULTS,
26
+ normalize_step_range,
27
+ build_per_strategy_windows,
28
+ serialize_per_strategy_windows,
29
+ deserialize_per_strategy_windows,
30
+ default_custom_windows,
31
+ ui_key_to_family,
32
+ )
33
+
34
+
35
+ # ===========================================================================
36
+ # 1. UI smoke-test
37
+ # ===========================================================================
38
+
39
+ def _make_fake_gradio():
40
+ """
41
+ Minimal gradio mock that records kwargs without raising on unknown ones.
42
+ This lets us detect *missing constants* (NameError) and *wrong import paths*
43
+ while ignoring version-specific kwargs differences.
44
+ """
45
+ gr = types.ModuleType('gradio')
46
+
47
+ class _Component:
48
+ def __init__(self, *a, **kw):
49
+ self._kwargs = kw
50
+ self.value = kw.get('value')
51
+ def render(self): return self
52
+ def change(self, **kw): return self
53
+ def click(self, **kw): return self
54
+ @staticmethod
55
+ def update(**kw): return kw
56
+
57
+ for name in ('Checkbox', 'Radio', 'Slider', 'Textbox', 'Button',
58
+ 'Dropdown', 'Accordion', 'Row', 'Column', 'Tab', 'Tabs',
59
+ 'HTML', 'Markdown', 'Number', 'Code'):
60
+ setattr(gr, name, _Component)
61
+
62
+ gr.update = lambda **kw: kw
63
+ gr.__version__ = '3.41.0'
64
+ return gr
65
+
66
+
67
class TestUISmoke(unittest.TestCase):
    """
    Import lib_neutral_prompt.ui under mocked gradio + modules and verify that
    AccordionInterface can be constructed without errors.

    This catches:
      * NameError (undefined constants like _BUILDER_STRATEGIES)
      * ImportError (bad relative imports)
      * AttributeError from wrong Gradio kwarg names such as 'tooltip'
    """

    @classmethod
    def setUpClass(cls):
        # If ui.py was already imported in this process (e.g. by another test
        # module), the gradio mock cannot be reliably injected. Mark as skipped
        # rather than erroring out.
        if 'lib_neutral_prompt.ui' in sys.modules:
            cls._ui = sys.modules['lib_neutral_prompt.ui']
            cls._skipped = True
            return
        cls._skipped = False

        # Inject the gradio mock before ui.py is imported for the first time.
        sys.modules['gradio'] = _make_fake_gradio()

        fake_cb = types.SimpleNamespace(
            on_ui_settings=lambda *a, **k: None,
            on_ui_tabs=lambda *a, **k: None,
            on_after_component=lambda *a, **k: None,
        )
        fake_opts = types.SimpleNamespace(
            add_option=lambda *a, **k: None,
            onchange=lambda *a, **k: None,
            data={},
        )
        fake_shared = types.SimpleNamespace(
            opts=fake_opts,
            state=types.SimpleNamespace(sampling_step=0, sampling_steps=1),
        )
        fake_modules = types.ModuleType('modules')
        fake_modules.script_callbacks = fake_cb
        fake_modules.shared = fake_shared
        # setdefault: never clobber a real webui environment if one exists.
        sys.modules.setdefault('modules', fake_modules)
        sys.modules.setdefault('modules.script_callbacks', fake_cb)
        sys.modules.setdefault('modules.shared', fake_shared)

        try:
            cls._ui = importlib.import_module('lib_neutral_prompt.ui')
        except Exception as exc:
            raise unittest.SkipTest(f'lib_neutral_prompt.ui failed to import: {exc}')

    def setUp(self):
        if self._skipped:
            self.skipTest('lib_neutral_prompt.ui already imported; smoke-test skipped')

    def test_module_imported_without_error(self):
        self.assertIsNotNone(self._ui)

    def test_accordion_interface_constructs(self):
        interface = self._ui.AccordionInterface(get_elem_id=lambda s: f'test_{s}')
        self.assertIsNotNone(interface)

    def test_builder_strategies_constant_reachable(self):
        from lib_neutral_prompt.matryoshka_utils import BUILDER_STRATEGIES
        self.assertIsInstance(BUILDER_STRATEGIES, list)
        self.assertGreater(len(BUILDER_STRATEGIES), 0)

    def test_step_ps_sliders_created_for_all_strategies(self):
        interface = self._ui.AccordionInterface(get_elem_id=lambda s: f'test2_{s}')
        for ui_key in STRATEGY_UI_KEYS:
            self.assertIn(ui_key, interface.step_ps_sliders,
                          f'Missing per-strategy slider pair for {ui_key}')
            self.assertEqual(len(interface.step_ps_sliders[ui_key]), 2)

    def test_no_tooltip_kwarg_on_dropdown(self):
        """Regression: tooltip= must have been replaced by info= on all Dropdowns."""
        import ast
        import pathlib
        ui_source = (pathlib.Path(__file__).parent.parent.parent
                     / 'lib_neutral_prompt' / 'ui.py')
        # Static scan: any call carrying a tooltip= keyword anywhere in ui.py.
        violations = [
            f'line {node.lineno}'
            for node in ast.walk(ast.parse(ui_source.read_text()))
            if isinstance(node, ast.Call)
            for kw in node.keywords
            if kw.arg == 'tooltip'
        ]
        self.assertFalse(violations,
                         f'Found tooltip= kwarg(s) in ui.py: {violations}')
156
+
157
+
158
+ # ===========================================================================
159
+ # 2. Step-window integration
160
+ # ===========================================================================
161
+
162
class TestStepWindowIntegration(unittest.TestCase):
    """
    Verify that update_step_state → normalize_progress → strategy_is_active
    gives the expected gating behaviour.
    """

    def setUp(self):
        self._saved_step = global_state.current_step
        self._saved_total = global_state.total_steps

    def tearDown(self):
        global_state.current_step = self._saved_step
        global_state.total_steps = self._saved_total

    def _progress_at(self, step, total):
        # Push the step into global state, then read the normalized progress back.
        global_state.update_step_state(step, total)
        return normalize_progress(global_state.current_step, global_state.total_steps)

    # ---- global window ----

    def test_late_strategy_inactive_at_step_0(self):
        progress = self._progress_at(0, 10)
        # SALIENCE_MASK default: 0.4–1.0 → inactive at 0.0
        self.assertFalse(strategy_is_active('SALIENCE_MASK', progress,
                                            use_defaults=True))

    def test_late_strategy_active_at_step_8(self):
        progress = self._progress_at(8, 10)
        self.assertTrue(strategy_is_active('SALIENCE_MASK', progress,
                                           use_defaults=True))

    def test_early_strategy_active_at_step_0(self):
        progress = self._progress_at(0, 10)
        # PERPENDICULAR default: 0.0–0.5 → active at 0.0
        self.assertTrue(strategy_is_active('PERPENDICULAR', progress,
                                           use_defaults=True))

    def test_early_strategy_inactive_at_step_8(self):
        progress = self._progress_at(8, 10)
        self.assertFalse(strategy_is_active('PERPENDICULAR', progress,
                                            use_defaults=True))

    def test_global_window_overrides_defaults(self):
        """A global window [0.6, 1.0] gates AND_PERP out even at step 0."""
        progress = self._progress_at(0, 10)
        self.assertFalse(strategy_is_active('PERPENDICULAR', progress,
                                            global_window=StepWindow(0.6, 1.0),
                                            use_defaults=True))

    def test_no_window_always_active(self):
        for step in (0, 5, 9):
            progress = self._progress_at(step, 10)
            with self.subTest(step=step):
                self.assertTrue(strategy_is_active('SALIENCE_MASK', progress))
214
+
215
+
216
+ # ===========================================================================
217
+ # 3. Generation params round-trip (global mode)
218
+ # ===========================================================================
219
+
220
def _unpack(cfg_rescale=0.0, protection_mode='auto', strict_threshold=0.1,
            step_window_enabled=False, step_window_mode='global',
            step_window_start=0.0, step_window_end=1.0, **ps_kwargs):
    """Mirrors AccordionInterface.unpack_processing_args logic — no gradio needed.

    Clamps the window bounds to [0, 1], repairs an inverted range to the full
    window, writes the resulting step-window state into ``global_state`` and
    returns the normalized args dict consumed by ``_extra``.
    """
    from lib_neutral_prompt.step_utils import (
        StepWindow, build_per_strategy_windows, STRATEGY_UI_KEYS,
    )
    enabled = bool(step_window_enabled)
    mode = str(step_window_mode)
    start = max(0.0, min(1.0, float(step_window_start)))
    end = max(0.0, min(1.0, float(step_window_end)))
    if start > end:
        # Inverted ranges are repaired to the full window, not swapped.
        start, end = 0.0, 1.0

    # Collect per-strategy (start, end) pairs from '<KEY>_start'/'<KEY>_end' kwargs.
    custom_raw = {}
    for uk in STRATEGY_UI_KEYS:
        sk, ek = f'{uk}_start', f'{uk}_end'
        if sk in ps_kwargs and ek in ps_kwargs:
            custom_raw[uk] = (float(ps_kwargs[sk]), float(ps_kwargs[ek]))

    global_state.step_window_enabled = enabled
    if enabled:
        if mode == 'global':
            global_state.step_window_global = StepWindow(start, end)
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
        elif mode == 'per-strategy defaults':
            global_state.step_window_global = None
            global_state.step_window_use_defaults = True
            global_state.step_window_per_strategy = None
        elif mode == 'per-strategy custom':
            global_state.step_window_global = None
            global_state.step_window_use_defaults = False
            global_state.step_window_custom_raw = custom_raw
            global_state.step_window_per_strategy = build_per_strategy_windows(custom_raw)
        else:
            # Unknown mode: fall back to a safe state. Fix: also clear any
            # stale per-strategy windows — the original left the previous
            # value behind, which could keep gating strategies after a
            # custom-mode unpack earlier in the same process.
            global_state.step_window_global = None
            global_state.step_window_use_defaults = False
            global_state.step_window_per_strategy = None
    else:
        global_state.step_window_global = None
        global_state.step_window_use_defaults = False
        global_state.step_window_per_strategy = None

    return {
        'cfg_rescale': cfg_rescale, 'protection_mode': protection_mode,
        'strict_threshold': strict_threshold,
        'step_window_enabled': enabled, 'step_window_mode': mode,
        'step_window_start': start, 'step_window_end': end,
        'step_window_custom_raw': custom_raw,
    }
270
+
271
+
272
+ def _extra(args):
273
+ """Mirrors AccordionInterface.get_extra_generation_params — no gradio."""
274
+ p = {
275
+ 'CFG Rescale phi': args['cfg_rescale'],
276
+ 'NP Protection Mode': args['protection_mode'],
277
+ 'NP Strict Threshold': args['strict_threshold'],
278
+ }
279
+ if args.get('step_window_enabled'):
280
+ p['NP Step Window Enabled'] = True
281
+ p['NP Step Window Mode'] = args.get('step_window_mode', 'global')
282
+ p['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
283
+ p['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
284
+ if args.get('step_window_mode') == 'per-strategy custom':
285
+ raw = args.get('step_window_custom_raw') or {}
286
+ p['NP Step Window Custom'] = serialize_per_strategy_windows(raw)
287
+ return p
288
+
289
+
290
class TestGenerationParamsRoundTrip(unittest.TestCase):
    # Every test snapshots and restores the step-window state it mutates.

    def setUp(self):
        self._saved = {
            'enabled': global_state.step_window_enabled,
            'global': global_state.step_window_global,
            'use_defaults': global_state.step_window_use_defaults,
            'per_strategy': global_state.step_window_per_strategy,
        }

    def tearDown(self):
        saved = self._saved
        global_state.step_window_enabled = saved['enabled']
        global_state.step_window_global = saved['global']
        global_state.step_window_use_defaults = saved['use_defaults']
        global_state.step_window_per_strategy = saved['per_strategy']

    # -- global mode round-trip --

    def test_global_extra_params_present_when_enabled(self):
        params = _extra(_unpack(step_window_enabled=True, step_window_mode='global',
                                step_window_start=0.25, step_window_end=0.80))
        self.assertTrue(params['NP Step Window Enabled'])
        self.assertEqual(params['NP Step Window Mode'], 'global')
        self.assertAlmostEqual(params['NP Step Window Start'], 0.25, places=4)
        self.assertAlmostEqual(params['NP Step Window End'], 0.80, places=4)

    def test_global_extra_params_absent_when_disabled(self):
        params = _extra(_unpack(step_window_enabled=False))
        self.assertNotIn('NP Step Window Enabled', params)

    def test_global_apply_infotext_restores_state(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'global',
            'NP Step Window Start': '0.3',
            'NP Step Window End': '0.7',
        })
        self.assertTrue(global_state.step_window_enabled)
        window = global_state.step_window_global
        self.assertIsNotNone(window)
        self.assertAlmostEqual(window.start, 0.3, places=4)
        self.assertAlmostEqual(window.end, 0.7, places=4)

    def test_inverted_range_repaired_in_unpack(self):
        args = _unpack(step_window_enabled=True, step_window_mode='global',
                       step_window_start=0.9, step_window_end=0.1)
        self.assertAlmostEqual(args['step_window_start'], 0.0)
        self.assertAlmostEqual(args['step_window_end'], 1.0)

    def test_inverted_range_repaired_in_apply_infotext(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'global',
            'NP Step Window Start': '0.9',
            'NP Step Window End': '0.1',
        })
        window = global_state.step_window_global
        self.assertAlmostEqual(window.start, 0.0)
        self.assertAlmostEqual(window.end, 1.0)

    def test_old_infotext_without_step_window_keys_safe(self):
        # No exception — step-window state unchanged
        global_state.apply_infotext({'CFG Rescale phi': '0.5', 'NP Protection Mode': 'auto'})

    def test_defaults_mode_apply_infotext(self):
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'per-strategy defaults',
        })
        self.assertTrue(global_state.step_window_enabled)
        self.assertTrue(global_state.step_window_use_defaults)
        self.assertIsNone(global_state.step_window_global)
364
+
365
+ # ===========================================================================
366
+ # 4. Per-strategy custom window helpers
367
+ # ===========================================================================
368
+
369
class TestPerStrategyHelpers(unittest.TestCase):
    """Unit tests for the per-strategy custom-window helper functions."""

    def test_ui_key_to_family_all_keys(self):
        expected = {
            'AND_PERP': 'PERPENDICULAR',
            'AND_SALT': 'SALIENCE_MASK',
            'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
            'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
            'AND_TOPK': 'SEMANTIC_GUIDANCE',
            'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
            'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
        }
        for ui_key, family in expected.items():
            with self.subTest(ui_key=ui_key):
                self.assertEqual(ui_key_to_family(ui_key), family)

    def test_build_per_strategy_windows_types(self):
        built = build_per_strategy_windows({'AND_SALT': (0.4, 1.0),
                                            'AND_PERP': (0.0, 0.5)})
        self.assertIsInstance(built['SALIENCE_MASK'], StepWindow)
        self.assertAlmostEqual(built['SALIENCE_MASK'].start, 0.4)
        self.assertAlmostEqual(built['PERPENDICULAR'].end, 0.5)

    def test_build_per_strategy_windows_repairs_inverted(self):
        # Inverted (start > end) must be repaired to the full window.
        built = build_per_strategy_windows({'AND_SALT': (0.9, 0.1)})
        window = built['SALIENCE_MASK']
        self.assertAlmostEqual(window.start, 0.0)
        self.assertAlmostEqual(window.end, 1.0)

    def test_unknown_ui_key_skipped(self):
        built = build_per_strategy_windows({'AND_MYSTERY': (0.2, 0.8)})
        self.assertEqual(len(built), 0)

    def test_serialize_round_trip(self):
        original = {'AND_PERP': (0.0, 0.5), 'AND_SALT': (0.4, 1.0)}
        restored = deserialize_per_strategy_windows(
            serialize_per_strategy_windows(original))
        for ui_key, (start, end) in original.items():
            self.assertIn(ui_key, restored)
            self.assertAlmostEqual(restored[ui_key][0], start, places=3)
            self.assertAlmostEqual(restored[ui_key][1], end, places=3)

    def test_deserialize_tolerates_garbage(self):
        restored = deserialize_per_strategy_windows('bad:data,AND_SALT:0.40-1.00,junk')
        self.assertIn('AND_SALT', restored)
        self.assertAlmostEqual(restored['AND_SALT'][0], 0.4, places=2)

    def test_deserialize_empty_string(self):
        self.assertEqual(deserialize_per_strategy_windows(''), {})

    def test_normalize_step_range_clamps(self):
        start, end = normalize_step_range(-0.5, 1.5)
        self.assertAlmostEqual(start, 0.0)
        self.assertAlmostEqual(end, 1.0)

    def test_normalize_step_range_swaps_inverted(self):
        start, end = normalize_step_range(0.8, 0.2)
        self.assertAlmostEqual(start, 0.0)
        self.assertAlmostEqual(end, 1.0)

    def test_default_custom_windows_covers_all_keys(self):
        defaults = default_custom_windows()
        for ui_key in STRATEGY_UI_KEYS:
            self.assertIn(ui_key, defaults)
            start, end = defaults[ui_key]
            self.assertGreaterEqual(start, 0.0)
            self.assertLessEqual(end, 1.0)
            self.assertLessEqual(start, end)

    def test_strategy_is_active_custom_window_overrides_defaults(self):
        """Custom window for AND_SALT: 0.0–0.3 should gate it out at step 8/10."""
        windows = build_per_strategy_windows({'AND_SALT': (0.0, 0.3)})
        self.assertFalse(
            strategy_is_active('SALIENCE_MASK', normalize_progress(8, 10),
                               per_strategy_windows=windows,
                               use_defaults=True))

    def test_strategy_is_active_fallback_to_defaults_when_key_missing(self):
        """If AND_TOPK not in custom dict, fall back to STRATEGY_DEFAULTS."""
        windows = build_per_strategy_windows({'AND_SALT': (0.0, 0.3)})  # only SALT defined
        # SEMANTIC_GUIDANCE default: 0.4–1.0 → active at step 8/10
        self.assertTrue(
            strategy_is_active('SEMANTIC_GUIDANCE', normalize_progress(8, 10),
                               per_strategy_windows=windows,
                               use_defaults=True))
460
+
461
+
462
+ # ===========================================================================
463
+ # 5. Per-strategy custom round-trip via infotext
464
+ # ===========================================================================
465
+
466
class TestPerStrategyRoundTrip(unittest.TestCase):
    """Round-trip per-strategy custom windows through extra params / infotext."""

    def setUp(self):
        # Snapshot every attribute that custom-mode round-trips mutate.
        self._snapshot = (
            global_state.step_window_enabled,
            global_state.step_window_use_defaults,
            global_state.step_window_per_strategy,
            global_state.step_window_custom_raw,
        )

    def tearDown(self):
        (global_state.step_window_enabled,
         global_state.step_window_use_defaults,
         global_state.step_window_per_strategy,
         global_state.step_window_custom_raw) = self._snapshot

    def test_custom_mode_extra_params_has_custom_key(self):
        raw = {'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)}
        slider_kwargs = {}
        for ui_key, (start, end) in raw.items():
            slider_kwargs[f'{ui_key}_start'] = start
            slider_kwargs[f'{ui_key}_end'] = end
        params = _extra(_unpack(step_window_enabled=True,
                                step_window_mode='per-strategy custom',
                                **slider_kwargs))
        self.assertIn('NP Step Window Custom', params)
        self.assertIn('AND_PERP', params['NP Step Window Custom'])
        self.assertIn('AND_SALT', params['NP Step Window Custom'])

    def test_custom_mode_infotext_restore(self):
        serialized = serialize_per_strategy_windows(
            {'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)}
        )
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'per-strategy custom',
            'NP Step Window Custom': serialized,
        })
        self.assertTrue(global_state.step_window_enabled)
        restored = global_state.step_window_per_strategy
        self.assertIsNotNone(restored)
        self.assertIn('PERPENDICULAR', restored)
        self.assertIn('SALIENCE_MASK', restored)
        self.assertAlmostEqual(restored['PERPENDICULAR'].start, 0.0, places=3)
        self.assertAlmostEqual(restored['PERPENDICULAR'].end, 0.4, places=3)

    def test_custom_mode_restored_state_gates_correctly(self):
        """After restore, strategy_is_active uses the per-strategy windows."""
        global_state.apply_infotext({
            'NP Step Window Enabled': 'True',
            'NP Step Window Mode': 'per-strategy custom',
            'NP Step Window Custom': serialize_per_strategy_windows(
                {'AND_PERP': (0.0, 0.2)}),
        })
        # AND_PERP window 0.0–0.2 → inactive at step 5/10 (progress=0.56)
        self.assertFalse(strategy_is_active(
            'PERPENDICULAR', normalize_progress(5, 10),
            per_strategy_windows=global_state.step_window_per_strategy))

    def test_full_round_trip_all_strategies(self):
        raw = default_custom_windows()
        restored = deserialize_per_strategy_windows(
            serialize_per_strategy_windows(raw))
        for ui_key in STRATEGY_UI_KEYS:
            # Every key present; values match to 3 decimal places.
            self.assertIn(ui_key, restored)
            self.assertAlmostEqual(restored[ui_key][0], raw[ui_key][0], places=3)
            self.assertAlmostEqual(restored[ui_key][1], raw[ui_key][1], places=3)
536
+
537
+
538
if __name__ == '__main__':  # manual entry point: run this module's tests directly
    unittest.main()