Upload neutral_prompt_patcheds using SD-Hub
Browse files- neutral_prompt_patcheds/.gitignore +1 -0
- neutral_prompt_patcheds/LICENSE +21 -0
- neutral_prompt_patcheds/README.md +343 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__init__.py +1 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/__init__.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_transform.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_utils.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/matryoshka_utils.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/protection_utils.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/step_utils.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/lib_neutral_prompt/affine_transform.py +83 -0
- neutral_prompt_patcheds/lib_neutral_prompt/affine_utils.py +220 -0
- neutral_prompt_patcheds/lib_neutral_prompt/cfg_denoiser_hijack.py +726 -0
- neutral_prompt_patcheds/lib_neutral_prompt/external_code/__init__.py +23 -0
- neutral_prompt_patcheds/lib_neutral_prompt/external_code/api.py +27 -0
- neutral_prompt_patcheds/lib_neutral_prompt/global_state.py +325 -0
- neutral_prompt_patcheds/lib_neutral_prompt/hijacker.py +34 -0
- neutral_prompt_patcheds/lib_neutral_prompt/matryoshka_utils.py +644 -0
- neutral_prompt_patcheds/lib_neutral_prompt/neutral_prompt_parser.py +453 -0
- neutral_prompt_patcheds/lib_neutral_prompt/prompt_parser_hijack.py +134 -0
- neutral_prompt_patcheds/lib_neutral_prompt/protection_utils.py +277 -0
- neutral_prompt_patcheds/lib_neutral_prompt/step_utils.py +454 -0
- neutral_prompt_patcheds/lib_neutral_prompt/ui.py +811 -0
- neutral_prompt_patcheds/lib_neutral_prompt/xyz_grid.py +42 -0
- neutral_prompt_patcheds/scripts/__pycache__/neutral_prompt.cpython-310.pyc +0 -0
- neutral_prompt_patcheds/scripts/neutral_prompt.py +99 -0
- neutral_prompt_patcheds/test/perp_parser/__init__.py +54 -0
- neutral_prompt_patcheds/test/perp_parser/mock_torch.py +61 -0
- neutral_prompt_patcheds/test/perp_parser/test_affine_keyword_order.py +133 -0
- neutral_prompt_patcheds/test/perp_parser/test_affine_pipeline.py +217 -0
- neutral_prompt_patcheds/test/perp_parser/test_basic_parser.py +122 -0
- neutral_prompt_patcheds/test/perp_parser/test_lock_after_end.py +544 -0
- neutral_prompt_patcheds/test/perp_parser/test_malicious_parser.py +182 -0
- neutral_prompt_patcheds/test/perp_parser/test_matryoshka.py +440 -0
- neutral_prompt_patcheds/test/perp_parser/test_matryoshka_golden.py +331 -0
- neutral_prompt_patcheds/test/perp_parser/test_parametric_syntax.py +535 -0
- neutral_prompt_patcheds/test/perp_parser/test_runtime_behavior.py +826 -0
- neutral_prompt_patcheds/test/perp_parser/test_sprint2.py +656 -0
- neutral_prompt_patcheds/test/perp_parser/test_sprint2_hotfix.py +419 -0
- neutral_prompt_patcheds/test/perp_parser/test_stabilization.py +539 -0
neutral_prompt_patcheds/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
neutral_prompt_patcheds/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 ljleb
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
neutral_prompt_patcheds/README.md
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sd-webui-neutral-prompt — Patched Edition
|
| 2 |
+
|
| 3 |
+
> **Unified merge of the main experimental branches — with bugfixes and parametric syntax.**
|
| 4 |
+
>
|
| 5 |
+
> Based on the _Ultimate Edition_ merge, with the following additions:
|
| 6 |
+
> - `AND_SALT[k]` / `AND_SALT_WIDE[k]` / `AND_SALT_BLOB[k]` — user-controlled salience sharpness
|
| 7 |
+
> - `AND_ALIGN[D,S]` / `AND_MASK_ALIGN[D,S]` — any D/S pair without suffix pre-declaration
|
| 8 |
+
> - Fixed `AND_SALT` default (k=5 instead of k=20 — actually visible)
|
| 9 |
+
> - Fixed `AND_SALT_BLOB` morphology order (thickify-first prevents seed self-destruction)
|
| 10 |
+
> - Guard against "base prompt replacement" when no plain AND-child is present
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## What's inside
|
| 15 |
+
|
| 16 |
+
| Feature | Origin | Keyword / setting |
|
| 17 |
+
|---|---|---|
|
| 18 |
+
| Perpendicular projection | `main` | `AND_PERP` |
|
| 19 |
+
| Saliency-guided mask | `main` + patch | `AND_SALT[k]` |
|
| 20 |
+
| Saliency blob mask | `life` + **fixed** | `AND_SALT_BLOB[k]` |
|
| 21 |
+
| Saliency wide mask | `main` | `AND_SALT_WIDE[k]` |
|
| 22 |
+
| Semantic guidance top-k | `main` | `AND_TOPK` |
|
| 23 |
+
| CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
|
| 24 |
+
| XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
|
| 25 |
+
| External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
|
| 26 |
+
| CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
|
| 27 |
+
| Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` |
|
| 28 |
+
| Soft alignment blend | `alignment_blend` | `AND_ALIGN[D,S]` or `AND_ALIGN_D_S` |
|
| 29 |
+
| Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN[D,S]` or `AND_MASK_ALIGN_D_S` |
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Prompt syntax
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
positive prompt text
|
| 37 |
+
AND_PERP negative-direction text :0.8
|
| 38 |
+
AND_SALT[5] competing concept :1.0 ← k=5 (default; more visible than old k=20)
|
| 39 |
+
AND_SALT_WIDE classic broad mask :0.8 ← k=1 default
|
| 40 |
+
AND_SALT_BLOB[8] blob concept :1.0 ← custom k for blob seed sharpness
|
| 41 |
+
AND_TOPK fine detail tweak :0.5
|
| 42 |
+
AND_ALIGN[4,8] style blend :0.6 ← any valid D,S pair
|
| 43 |
+
AND_MASK_ALIGN[6,12] structure style :0.6 ← binary-mask variant
|
| 44 |
+
AND_ALIGN_4_8 style blend :0.6 ← fixed-suffix form still works
|
| 45 |
+
AND_PERP ROTATE[0.125] rotated concept :1.0
|
| 46 |
+
ROTATE[0.125] AND_PERP rotated concept :1.0
|
| 47 |
+
AND_SALT[5] ROTATE[0.125] shifted concept :1.0
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
> **Important:** always include a plain base prompt before `AND_*` segments.
|
| 51 |
+
> If all segments have conciliation keywords (no plain `AND`), the extension
|
| 52 |
+
> falls back to standard CFG automatically and logs a warning.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## Salience sharpness k
|
| 57 |
+
|
| 58 |
+
All three salience keywords accept an optional `[k]` parameter:
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
AND_SALT[k] sharp mask default k = 5
|
| 62 |
+
AND_SALT_WIDE[k] broad mask default k = 1
|
| 63 |
+
AND_SALT_BLOB[k] blob mask default k = 5
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
| k value | Effect | Approx. pixel coverage |
|
| 67 |
+
|---|---|---|
|
| 68 |
+
| 1 | Very broad, ~50% of pixels | ~50% |
|
| 69 |
+
| 3 | Moderate | ~11% |
|
| 70 |
+
| 5 | Focused (default) | ~2% |
|
| 71 |
+
| 10 | Surgical | ~0.25% |
|
| 72 |
+
| 20 | Extreme (original ultimate default) | ~0.04% |
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## Alignment kernel sizes D, S
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
AND_ALIGN[D,S] soft blend
|
| 80 |
+
AND_MASK_ALIGN[D,S] binary-mask blend
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
- **D** = detail kernel size (small → fine detail)
|
| 84 |
+
- **S** = structure kernel size (large → global composition)
|
| 85 |
+
- Both in range **[2, 32]**, **D ≠ S**
|
| 86 |
+
- Child contribution is highest where it changes detail *without* breaking structure.
|
| 87 |
+
|
| 88 |
+
Fixed-suffix form (`AND_ALIGN_4_8`) is still supported and backward-compatible.
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
## Affine transform keywords
|
| 93 |
+
|
| 94 |
+
Supported in **either order** relative to the conciliation keyword:
|
| 95 |
+
|
| 96 |
+
| Keyword | Parameter | Effect |
|
| 97 |
+
|---|---|---|
|
| 98 |
+
| `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
|
| 99 |
+
| `SLIDE[x,y]` | normalised offset | Translate latent contribution |
|
| 100 |
+
| `SCALE[x,y]` | scale factors | Scale latent contribution |
|
| 101 |
+
| `SHEAR[x,y]` | shear in turns | Shear latent contribution |
|
| 102 |
+
|
| 103 |
+
Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
## External API
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 111 |
+
|
| 112 |
+
override_cfg_rescale(0.7) # override for next step only
|
| 113 |
+
factor = get_last_cfg_rescale_factor() # read after generation
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
## CFG Rescale φ
|
| 119 |
+
|
| 120 |
+
When > 0, rescales CFG output to reduce colour over-saturation at high CFG values.
|
| 121 |
+
Uses the **mean-preserving** formula:
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
where `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Installation
|
| 132 |
+
|
| 133 |
+
1. Copy this folder into `stable-diffusion-webui/extensions/`
|
| 134 |
+
2. Restart the webui
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## Smoke-test checklist
|
| 139 |
+
|
| 140 |
+
Before reporting a bug, verify:
|
| 141 |
+
|
| 142 |
+
- [ ] Base prompt exists before `AND_*` segments
|
| 143 |
+
- [ ] `AND_SALT` without `[k]` defaults to k=5 (moderate coverage)
|
| 144 |
+
- [ ] `AND_SALT[1]` behaves like `AND_SALT_WIDE` (broad)
|
| 145 |
+
- [ ] `AND_SALT_BLOB[5]` produces a visible blob region
|
| 146 |
+
- [ ] `AND_ALIGN[4,8]` and `AND_ALIGN_4_8` produce identical results
|
| 147 |
+
- [ ] Prompt with only `AND_SALT ...` (no base) falls back to standard CFG with a warning in console (enable verbose in Settings)
|
| 148 |
+
|
| 149 |
+
---
|
| 150 |
+
|
| 151 |
+
## Matryoshka / Nested prompt composition
|
| 152 |
+
|
| 153 |
+
The extension supports arbitrary nesting of `AND_*` blocks. An outer block
|
| 154 |
+
focuses or restricts a spatial/semantic region; inner blocks operate _inside_
|
| 155 |
+
that region:
|
| 156 |
+
|
| 157 |
+
```text
|
| 158 |
+
base subject
|
| 159 |
+
AND_SALT[5] [
|
| 160 |
+
region texture :0.8
|
| 161 |
+
AND_TOPK[0.05] fine highlights :0.4
|
| 162 |
+
] :0.7
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
This is called a **matryoshka** structure — like nested dolls, each layer adds
|
| 166 |
+
a more focused effect.
|
| 167 |
+
|
| 168 |
+
### Basic nesting syntax
|
| 169 |
+
|
| 170 |
+
```text
|
| 171 |
+
base
|
| 172 |
+
KEYWORD [
|
| 173 |
+
inner text :weight
|
| 174 |
+
INNER_KEYWORD inner concept :weight
|
| 175 |
+
] :weight
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
The outer `[ ... ]` bracket is a **composite group**, not a parameter block.
|
| 179 |
+
It is parsed as such only if the content contains letters (not purely numeric).
|
| 180 |
+
|
| 181 |
+
### Nesting with params
|
| 182 |
+
|
| 183 |
+
Params go directly after the keyword, _before_ the composite bracket:
|
| 184 |
+
|
| 185 |
+
```text
|
| 186 |
+
AND_SALT[5] [ ← k=5 applied, then [ opens composite group
|
| 187 |
+
texture :0.8
|
| 188 |
+
AND_TOPK[0.1] detail :0.5
|
| 189 |
+
] :0.7
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
### Nesting with affine
|
| 193 |
+
|
| 194 |
+
Affine transforms can precede the keyword at any level:
|
| 195 |
+
|
| 196 |
+
```text
|
| 197 |
+
SCALE[-1,1] AND_PERP [ ← mirror then perpendicular
|
| 198 |
+
mirrored concept :0.6
|
| 199 |
+
] :0.5
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
### How the parser distinguishes params vs composite
|
| 203 |
+
|
| 204 |
+
| Content inside `[...]` | Treated as |
|
| 205 |
+
|---|---|
|
| 206 |
+
| Purely numeric: `[5]`, `[4,8]`, `[ 0.1 ]` | **Params** for the preceding keyword |
|
| 207 |
+
| Contains letters / spaces + letters | **Composite group** |
|
| 208 |
+
|
| 209 |
+
So `AND_SALT[5]` → params; `AND_SALT [concept AND other]` → composite.
|
| 210 |
+
|
| 211 |
+
### Protection with nested prompts
|
| 212 |
+
|
| 213 |
+
The base-prompt protection guard checks for a **plain base segment at the
|
| 214 |
+
top level** of the prompt. A nested structure like:
|
| 215 |
+
|
| 216 |
+
```text
|
| 217 |
+
AND_SALT[5] [ ... ] :0.8
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
has _no_ top-level base, so `auto` and `strict` protection will fire.
|
| 221 |
+
To suppress this intentionally, set protection to `off`.
|
| 222 |
+
|
| 223 |
+
### Common mistakes
|
| 224 |
+
|
| 225 |
+
**Forgetting the base prompt**
|
| 226 |
+
```text
|
| 227 |
+
# ❌ triggers protection fallback
|
| 228 |
+
AND_SALT[5] concept :0.8
|
| 229 |
+
|
| 230 |
+
# ✓
|
| 231 |
+
base subject
|
| 232 |
+
AND_SALT[5] concept :0.8
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
**Using numeric composite brackets as params**
|
| 236 |
+
```text
|
| 237 |
+
# This is parsed as a composite group, NOT as k=5:
|
| 238 |
+
AND_SALT [5] ← space before bracket → composite group with text "5"
|
| 239 |
+
AND_SALT[5] ← no space → params (k=5) ✓
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
**Zero or out-of-range params**
|
| 243 |
+
```text
|
| 244 |
+
AND_SALT[0] ← invalid k, silently uses defaults
|
| 245 |
+
AND_ALIGN[4,4] ← D must differ from S, silently uses defaults
|
| 246 |
+
AND_TOPK[1.5] ← threshold must be in (0,1], silently uses defaults
|
| 247 |
+
```
|
| 248 |
+
Enable verbose mode in A1111 Settings → Neutral Prompt to see warnings.
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
## Matryoshka builder (UI)
|
| 253 |
+
|
| 254 |
+
Open the **Matryoshka builder** accordion to compose nested prompts visually:
|
| 255 |
+
|
| 256 |
+
1. Enter the base prompt text.
|
| 257 |
+
2. Choose a strategy for Child 1, enter its text and weight.
|
| 258 |
+
3. Optionally enable a nested child _inside_ Child 1.
|
| 259 |
+
4. Optionally add a second top-level child.
|
| 260 |
+
5. The **Generated prompt (preview)** updates live.
|
| 261 |
+
6. Click **Apply to prompt** to insert into the main prompt box.
|
| 262 |
+
|
| 263 |
+
### Ready-made templates
|
| 264 |
+
|
| 265 |
+
The **Matryoshka templates** accordion contains 7 ready-made recipes:
|
| 266 |
+
|
| 267 |
+
| Template | What it does |
|
| 268 |
+
|---|---|
|
| 269 |
+
| Nested local detail | SALT region → nested TOPK fine detail |
|
| 270 |
+
| Structure preserve + style inject | ALIGN preserves structure → nested PERP reduces contradiction |
|
| 271 |
+
| Sparse detail inside broad region | SALT_WIDE broad → nested SALT focal sharpening |
|
| 272 |
+
| Perpendicular correction inside texture | SALT texture → nested PERP suppresses contradiction |
|
| 273 |
+
| Nested ALIGN + SALT | ALIGN structure → SALT foreground → nested TOPK highlights |
|
| 274 |
+
| Mirrored composition | Mirrored PERP via SCALE[-1,1] |
|
| 275 |
+
| Deep nested concept isolation | SALT → ALIGN → TOPK (3 levels) |
|
| 276 |
+
|
| 277 |
+
---
|
| 278 |
+
|
| 279 |
+
## Prompt debug / explain panel
|
| 280 |
+
|
| 281 |
+
Open the **Prompt debug / explain** accordion to inspect any prompt.
|
| 282 |
+
|
| 283 |
+
The panel shows:
|
| 284 |
+
|
| 285 |
+
```
|
| 286 |
+
ROOT (N top-level segments)
|
| 287 |
+
├─ [BASE] w=1.00 "base subject"
|
| 288 |
+
└─ [SALIENCE_MASK] w=0.80 [k=5.0] (2 children)
|
| 289 |
+
├─ [BASE] w=0.80 "texture"
|
| 290 |
+
└─ [SEMANTIC_GUIDANCE] w=0.40 [threshold=0.05] "highlights"
|
| 291 |
+
|
| 292 |
+
── Diagnostics ──────────────────────────
|
| 293 |
+
Segments total : 4 (leaf=3, composite=1)
|
| 294 |
+
Max nesting depth : 2
|
| 295 |
+
Affine transforms : 0
|
| 296 |
+
Strategies used : BASE, SALIENCE_MASK, SEMANTIC_GUIDANCE
|
| 297 |
+
|
| 298 |
+
Protection mode : auto
|
| 299 |
+
OK — base segment present at top level
|
| 300 |
+
|
| 301 |
+
── Effect summary ────────────────────────
|
| 302 |
+
BASE → base prompt contribution — drives the overall image
|
| 303 |
+
SALIENCE_MASK → sharp saliency mask — targets the most salient latent pixels
|
| 304 |
+
SEMANTIC_GUIDANCE → semantic top-k — applies sparse targeted changes to the strongest elements
|
| 305 |
+
```
|
| 306 |
+
|
| 307 |
+
### What the debug panel does NOT simulate
|
| 308 |
+
|
| 309 |
+
- **Strict ratio check**: whether `norm(base_delta) / norm(aux_delta)` would fall
|
| 310 |
+
below the threshold is only knowable at generation time with real latents.
|
| 311 |
+
- **Batch/mixed-prompt cases**: the panel parses a single prompt string; it does
|
| 312 |
+
not simulate multi-prompt batch protection.
|
| 313 |
+
- **Hook conflicts**: whether another A1111 extension has overridden the hijack
|
| 314 |
+
chain is a runtime condition.
|
| 315 |
+
|
| 316 |
+
The protection verdict shown is a **structural preview** — a fast heuristic
|
| 317 |
+
to catch the most common mistake (missing base segment) before you start
|
| 318 |
+
generating.
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## Affine transform builder (UI)
|
| 323 |
+
|
| 324 |
+
Open the **Affine transform builder** accordion to compose affine snippets:
|
| 325 |
+
|
| 326 |
+
1. Pick a preset (Mirror H/V, Rotate 45°/90°/180°, Zoom, Stretch, etc.)
|
| 327 |
+
or select **Custom** and set parameters manually.
|
| 328 |
+
2. The **Affine snippet** box updates live.
|
| 329 |
+
3. Click **Insert affine into prompt** to append the snippet.
|
| 330 |
+
|
| 331 |
+
Snippets can be placed before a keyword (`ROTATE[0.125] AND_PERP ...`) or
|
| 332 |
+
as a standalone transform (`SCALE[-1,1] AND_SALT[5] ...`).
|
| 333 |
+
|
| 334 |
+
### Safe ranges
|
| 335 |
+
|
| 336 |
+
| Transform | Notes |
|
| 337 |
+
|---|---|
|
| 338 |
+
| `ROTATE[angle]` | Angle in turns. `0.25` = 90°, `0.5` = 180° |
|
| 339 |
+
| `SCALE[x,y]` | Both components must be non-zero (avoids singular matrix) |
|
| 340 |
+
| `SLIDE[x,y]` | Translation fraction of latent space |
|
| 341 |
+
| `SHEAR[x,y]` | Angles near ±0.25 turns diverge — avoided automatically |
|
| 342 |
+
| `FLIP_H` | Shortcut for `SCALE[-1,1]` |
|
| 343 |
+
| `FLIP_V` | Shortcut for `SCALE[1,-1]` |
|
neutral_prompt_patcheds/lib_neutral_prompt/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# lib_neutral_prompt package
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_transform.cpython-310.pyc
ADDED
|
Binary file (3.34 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/affine_utils.cpython-310.pyc
ADDED
|
Binary file (6.8 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc
ADDED
|
Binary file (8.06 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc
ADDED
|
Binary file (1.8 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/matryoshka_utils.cpython-310.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/protection_utils.cpython-310.pyc
ADDED
|
Binary file (8.15 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/step_utils.cpython-310.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
neutral_prompt_patcheds/lib_neutral_prompt/affine_transform.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Affine spatial transform utilities for latent tensors.
|
| 3 |
+
Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
|
| 4 |
+
|
| 5 |
+
Provides:
|
| 6 |
+
apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
|
| 7 |
+
apply_masked_transform() – apply affine + cosine-feathered mask blending
|
| 8 |
+
create_cosine_feathered_mask() – smooth circular weight mask
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def apply_affine_transform(
|
| 18 |
+
tensor: torch.Tensor,
|
| 19 |
+
affine: torch.Tensor,
|
| 20 |
+
mode: str = 'bilinear',
|
| 21 |
+
) -> torch.Tensor:
|
| 22 |
+
"""
|
| 23 |
+
Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.
|
| 24 |
+
|
| 25 |
+
:param tensor: Input tensor of shape [C, H, W].
|
| 26 |
+
:param affine: 2×3 float32 affine matrix.
|
| 27 |
+
:param mode: Interpolation mode for grid_sample ('bilinear' default).
|
| 28 |
+
The original affine branch used bilinear for the pre-noise
|
| 29 |
+
inverse path (smoother) and nearest for post-combine.
|
| 30 |
+
Bilinear is the better default for both.
|
| 31 |
+
:return: Transformed tensor of shape [C, H, W].
|
| 32 |
+
"""
|
| 33 |
+
affine = affine.clone().to(tensor.device)
|
| 34 |
+
aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
|
| 35 |
+
affine[0, 1] *= aspect_ratio
|
| 36 |
+
affine[1, 0] /= aspect_ratio
|
| 37 |
+
|
| 38 |
+
grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
|
| 39 |
+
return F.grid_sample(tensor.unsqueeze(0), grid, mode=mode, align_corners=False).squeeze(0)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def apply_masked_transform(
|
| 43 |
+
tensor: torch.Tensor,
|
| 44 |
+
affine: torch.Tensor,
|
| 45 |
+
weight: float,
|
| 46 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 47 |
+
"""
|
| 48 |
+
Apply an affine transform with a cosine-feathered spatial weight mask.
|
| 49 |
+
|
| 50 |
+
The mask channel is appended to the tensor before transformation so that
|
| 51 |
+
the same spatial warp is applied to both content and mask consistently.
|
| 52 |
+
|
| 53 |
+
:param tensor: C×H×W latent tensor.
|
| 54 |
+
:param affine: 2×3 affine matrix.
|
| 55 |
+
:param weight: Global scalar weight for the mask.
|
| 56 |
+
:return: (transformed_content [C, H, W], transformed_mask [H, W])
|
| 57 |
+
"""
|
| 58 |
+
mask = create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)
|
| 59 |
+
tensor_with_mask = torch.cat([tensor, mask], dim=0) # [C+1, H, W]
|
| 60 |
+
transformed = apply_affine_transform(tensor_with_mask, affine)
|
| 61 |
+
return transformed[:-1], transformed[-1] # content, mask
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
|
| 65 |
+
"""
|
| 66 |
+
Create a circularly-clipped cosine-feathered mask of shape [H, W].
|
| 67 |
+
|
| 68 |
+
Values at the centre approach `weight`; values outside the unit circle
|
| 69 |
+
are exactly 0, with a smooth cosine fall-off in between.
|
| 70 |
+
|
| 71 |
+
:param size: (H, W) tuple.
|
| 72 |
+
:param weight: Peak mask value at the centre.
|
| 73 |
+
:return: Float32 mask tensor of shape [H, W].
|
| 74 |
+
"""
|
| 75 |
+
y, x = torch.meshgrid(
|
| 76 |
+
torch.linspace(-1, 1, size[0]),
|
| 77 |
+
torch.linspace(-1, 1, size[1]),
|
| 78 |
+
indexing='ij',
|
| 79 |
+
)
|
| 80 |
+
dist = torch.sqrt(x ** 2 + y ** 2)
|
| 81 |
+
mask = 0.5 * (1.0 + torch.cos(torch.pi * dist))
|
| 82 |
+
mask[dist > 1] = 0.0
|
| 83 |
+
return mask.float() * weight
|
neutral_prompt_patcheds/lib_neutral_prompt/affine_utils.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Affine transform subsystem for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Single source of truth for:
|
| 5 |
+
- Matrix constructors: _make_rotate / _make_slide / _make_scale / _make_shear
|
| 6 |
+
- Application: _apply_affine
|
| 7 |
+
- Safe inversion: _try_invert_affine
|
| 8 |
+
- Parser dispatch: affine_transforms (keyword → callable)
|
| 9 |
+
- UI helpers: _build_affine_snippet / _AFFINE_TRANSFORMS / _AFFINE_PRESETS
|
| 10 |
+
|
| 11 |
+
Previously scattered across neutral_prompt_parser.py, cfg_denoiser_hijack.py,
|
| 12 |
+
and matryoshka_utils.py. Import from here everywhere.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import sys
|
| 19 |
+
from typing import Dict, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Warning helper
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def affine_warn(keyword: str, reason: str) -> None:
|
| 29 |
+
"""Print a one-line warning to stderr (respects global_state.verbose)."""
|
| 30 |
+
try:
|
| 31 |
+
from lib_neutral_prompt import global_state as _gs
|
| 32 |
+
if not _gs.verbose:
|
| 33 |
+
return
|
| 34 |
+
except ImportError:
|
| 35 |
+
pass
|
| 36 |
+
print(
|
| 37 |
+
f'[neutral-prompt] WARNING: invalid affine params for {keyword} — {reason}. '
|
| 38 |
+
'Transform ignored; using identity.',
|
| 39 |
+
file=sys.stderr,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Keep old private name as alias (parser uses it)
|
| 43 |
+
_affine_warn = affine_warn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Matrix constructors
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def make_rotate(angle: float) -> Optional[torch.Tensor]:
|
| 51 |
+
"""Rotation matrix. *angle* is in turns (0.25 = 90°)."""
|
| 52 |
+
if not math.isfinite(angle):
|
| 53 |
+
affine_warn('ROTATE', f'non-finite angle {angle}')
|
| 54 |
+
return None
|
| 55 |
+
a = angle * 2 * math.pi
|
| 56 |
+
return torch.tensor([
|
| 57 |
+
[math.cos(a), -math.sin(a), 0],
|
| 58 |
+
[math.sin(a), math.cos(a), 0],
|
| 59 |
+
[0, 0, 1],
|
| 60 |
+
], dtype=torch.float32)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def make_slide(x: float, y: float) -> Optional[torch.Tensor]:
|
| 64 |
+
"""Translation matrix."""
|
| 65 |
+
if not math.isfinite(x) or not math.isfinite(y):
|
| 66 |
+
affine_warn('SLIDE', f'non-finite offset ({x}, {y})')
|
| 67 |
+
return None
|
| 68 |
+
return torch.tensor([
|
| 69 |
+
[1, 0, float(x)],
|
| 70 |
+
[0, 1, float(y)],
|
| 71 |
+
[0, 0, 1],
|
| 72 |
+
], dtype=torch.float32)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def make_scale(x: float, y: Optional[float] = None) -> Optional[torch.Tensor]:
|
| 76 |
+
"""Uniform or anisotropic scale matrix. Both components must be non-zero."""
|
| 77 |
+
sy = y if y is not None else x
|
| 78 |
+
if not math.isfinite(x) or not math.isfinite(sy):
|
| 79 |
+
affine_warn('SCALE', f'non-finite scale ({x}, {sy})')
|
| 80 |
+
return None
|
| 81 |
+
if abs(x) < 1e-6 or abs(sy) < 1e-6:
|
| 82 |
+
affine_warn('SCALE', f'scale component too close to zero ({x}, {sy}); '
|
| 83 |
+
'matrix would be singular')
|
| 84 |
+
return None
|
| 85 |
+
return torch.tensor([
|
| 86 |
+
[float(x), 0, 0],
|
| 87 |
+
[0, float(sy), 0],
|
| 88 |
+
[0, 0, 1],
|
| 89 |
+
], dtype=torch.float32)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_shear(x: float, y: Optional[float] = None) -> Optional[torch.Tensor]:
|
| 93 |
+
"""Shear matrix. Angles near ±0.25 turns are rejected (tan diverges)."""
|
| 94 |
+
sy = y if y is not None else x
|
| 95 |
+
if not math.isfinite(x) or not math.isfinite(sy):
|
| 96 |
+
affine_warn('SHEAR', f'non-finite shear ({x}, {sy})')
|
| 97 |
+
return None
|
| 98 |
+
_SHEAR_MAX = 0.24
|
| 99 |
+
if abs(x) >= _SHEAR_MAX or abs(sy) >= _SHEAR_MAX:
|
| 100 |
+
affine_warn('SHEAR', f'shear angle ({x}, {sy}) too close to ±0.25 turns; '
|
| 101 |
+
'tan diverges and matrix would be degenerate')
|
| 102 |
+
return None
|
| 103 |
+
tx = math.tan(float(x) * 2 * math.pi)
|
| 104 |
+
ty = math.tan(float(sy) * 2 * math.pi)
|
| 105 |
+
if not math.isfinite(tx) or not math.isfinite(ty):
|
| 106 |
+
affine_warn('SHEAR', f'tan({x}, {sy}) produced non-finite value')
|
| 107 |
+
return None
|
| 108 |
+
return torch.tensor([
|
| 109 |
+
[1, tx, 0],
|
| 110 |
+
[ty, 1, 0],
|
| 111 |
+
[0, 0, 1],
|
| 112 |
+
], dtype=torch.float32)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Keep old private names (parser uses them)
|
| 116 |
+
_make_rotate = make_rotate
|
| 117 |
+
_make_slide = make_slide
|
| 118 |
+
_make_scale = make_scale
|
| 119 |
+
_make_shear = make_shear
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# Application helper
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
def apply_affine(t: torch.Tensor, m: Optional[torch.Tensor]) -> torch.Tensor:
|
| 127 |
+
"""Apply matrix *m* to running transform *t*; if *m* is None return *t* unchanged."""
|
| 128 |
+
return t if m is None else t @ m
|
| 129 |
+
|
| 130 |
+
_apply_affine = apply_affine # alias for parser
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
# Parser dispatch table
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
|
| 137 |
+
affine_transforms: Dict[str, object] = {
|
| 138 |
+
'ROTATE': lambda t, angle=0, *_: apply_affine(t, make_rotate(float(angle))),
|
| 139 |
+
'SLIDE': lambda t, x=0, y=0, *_: apply_affine(t, make_slide(float(x), float(y))),
|
| 140 |
+
'SCALE': lambda t, x=1, y=None, *_: apply_affine(
|
| 141 |
+
t, make_scale(float(x), float(y) if y is not None else None)),
|
| 142 |
+
'SHEAR': lambda t, x=0, y=None, *_: apply_affine(
|
| 143 |
+
t, make_shear(float(x), float(y) if y is not None else None)),
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
# Safe matrix inversion (runtime, used by cfg_denoiser_hijack)
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
|
| 151 |
+
def try_invert_affine(
|
| 152 |
+
m3: torch.Tensor,
|
| 153 |
+
context: str = 'affine transform',
|
| 154 |
+
) -> torch.Tensor:
|
| 155 |
+
"""
|
| 156 |
+
Safely invert a 3×3 homogeneous affine matrix.
|
| 157 |
+
|
| 158 |
+
On any error (singular, NaN) logs a warning (verbose only) and returns
|
| 159 |
+
the 3×3 identity so the pipeline degrades gracefully instead of crashing.
|
| 160 |
+
"""
|
| 161 |
+
try:
|
| 162 |
+
result = torch.linalg.inv(m3)
|
| 163 |
+
if not torch.isfinite(result).all():
|
| 164 |
+
raise ValueError('non-finite values in inverted matrix')
|
| 165 |
+
return result
|
| 166 |
+
except Exception as exc:
|
| 167 |
+
try:
|
| 168 |
+
from lib_neutral_prompt import global_state as _gs
|
| 169 |
+
verbose = _gs.verbose
|
| 170 |
+
except ImportError:
|
| 171 |
+
verbose = True
|
| 172 |
+
if verbose:
|
| 173 |
+
print(
|
| 174 |
+
f'[neutral-prompt] WARNING: could not invert {context}: {exc}. '
|
| 175 |
+
'Falling back to identity (no affine transform applied).',
|
| 176 |
+
file=sys.stderr,
|
| 177 |
+
)
|
| 178 |
+
return torch.eye(3, device=m3.device, dtype=m3.dtype)
|
| 179 |
+
|
| 180 |
+
_try_invert_affine = try_invert_affine # alias
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ---------------------------------------------------------------------------
|
| 184 |
+
# UI helpers (snippet builder + preset library)
|
| 185 |
+
# ---------------------------------------------------------------------------
|
| 186 |
+
|
| 187 |
+
AFFINE_TRANSFORMS = ['ROTATE', 'SCALE', 'SLIDE', 'SHEAR', 'FLIP_H', 'FLIP_V']
|
| 188 |
+
|
| 189 |
+
AFFINE_PRESETS: Dict[str, Optional[Tuple]] = {
|
| 190 |
+
'Custom': None,
|
| 191 |
+
'Mirror H': ('SCALE', -1.0, 1.0),
|
| 192 |
+
'Mirror V': ('SCALE', 1.0, -1.0),
|
| 193 |
+
'Rotate 45°': ('ROTATE', 0.125, None),
|
| 194 |
+
'Rotate 90°': ('ROTATE', 0.25, None),
|
| 195 |
+
'Rotate 180°': ('ROTATE', 0.5, None),
|
| 196 |
+
'Zoom out 50%': ('SCALE', 0.5, 0.5),
|
| 197 |
+
'Zoom in 150%': ('SCALE', 1.5, 1.5),
|
| 198 |
+
'Stretch X': ('SCALE', 1.5, 1.0),
|
| 199 |
+
'Stretch Y': ('SCALE', 1.0, 1.5),
|
| 200 |
+
'Slight rotate': ('ROTATE', 0.05, None),
|
| 201 |
+
'Slide right': ('SLIDE', 0.1, 0.0),
|
| 202 |
+
'Slide down': ('SLIDE', 0.0, 0.1),
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def build_affine_snippet(transform: str, p1: float, p2: float) -> str:
|
| 207 |
+
"""Convert UI control values to a single ``TRANSFORM[args]`` token."""
|
| 208 |
+
t = transform.strip()
|
| 209 |
+
if t == 'FLIP_H': return 'SCALE[-1,1]'
|
| 210 |
+
if t == 'FLIP_V': return 'SCALE[1,-1]'
|
| 211 |
+
if t == 'ROTATE': return f'ROTATE[{p1}]'
|
| 212 |
+
if t == 'SCALE':
|
| 213 |
+
return f'SCALE[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SCALE[{p1}]'
|
| 214 |
+
if t == 'SLIDE': return f'SLIDE[{p1},{p2}]'
|
| 215 |
+
if t == 'SHEAR':
|
| 216 |
+
return f'SHEAR[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SHEAR[{p1}]'
|
| 217 |
+
return ''
|
| 218 |
+
|
| 219 |
+
# Old private alias (matryoshka_utils re-exports from here)
|
| 220 |
+
_build_affine_snippet = build_affine_snippet
|
neutral_prompt_patcheds/lib_neutral_prompt/cfg_denoiser_hijack.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified CFG Denoiser Hijack
|
| 3 |
+
===========================
|
| 4 |
+
Combines features from all sd-webui-neutral-prompt branches:
|
| 5 |
+
|
| 6 |
+
• Core PERP / SALT / TOPK strategies (main branch)
|
| 7 |
+
• cfg_rescale with mean-preserving formula (main branch)
|
| 8 |
+
• cfg_rescale_override (XYZ-grid / API support) (main branch)
|
| 9 |
+
• CFGRescaleFactorSingleton export (export_rescale_factor branch)
|
| 10 |
+
• Affine spatial transforms per-prompt (affine branch)
|
| 11 |
+
• AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
|
| 12 |
+
• AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
|
| 13 |
+
• Improved salience with sharpness parameter k (life branch)
|
| 14 |
+
|
| 15 |
+
Changes vs original ultimate:
|
| 16 |
+
• AND_SALT / AND_SALT_WIDE / AND_SALT_BLOB: k read from conciliation_params
|
| 17 |
+
(default 5 / 1 / 5 — tuned to actually be visible)
|
| 18 |
+
• AND_SALT_BLOB: erode/thickify order fixed (thickify first, then erode)
|
| 19 |
+
• AND_ALIGN[D,S] / AND_MASK_ALIGN[D,S]: bracket-syntax routed via
|
| 20 |
+
ALIGNMENT_BLEND_CUSTOM / ALIGNMENT_MASK_BLEND_CUSTOM
|
| 21 |
+
• Guard against "no base prompt" → cond_delta=0 replacement bug
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import dataclasses
|
| 27 |
+
import functools
|
| 28 |
+
import re
|
| 29 |
+
import sys
|
| 30 |
+
import textwrap
|
| 31 |
+
from typing import Dict, List, Tuple
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
|
| 36 |
+
from lib_neutral_prompt import affine_transform as affine_mod
|
| 37 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
|
| 38 |
+
from modules import script_callbacks, sd_samplers, shared
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# Pre-noise affine hook (affine branch)
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
@dataclasses.dataclass
|
| 46 |
+
class _PreNoiseArgs:
|
| 47 |
+
x: torch.Tensor
|
| 48 |
+
cond_indices: List[Tuple[int, float]]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _GlobalToLocalAffineVisitor:
|
| 52 |
+
def visit_leaf_prompt(
|
| 53 |
+
self,
|
| 54 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 55 |
+
args: _PreNoiseArgs,
|
| 56 |
+
index: int,
|
| 57 |
+
) -> Dict[int, torch.Tensor]:
|
| 58 |
+
cond_index = args.cond_indices[index][0]
|
| 59 |
+
if that.local_transform is not None:
|
| 60 |
+
m3 = torch.vstack([that.local_transform,
|
| 61 |
+
torch.tensor([0.0, 0.0, 1.0])])
|
| 62 |
+
inv = _try_invert_affine(m3, 'leaf local_transform')[:-1]
|
| 63 |
+
else:
|
| 64 |
+
inv = torch.eye(3)[:-1]
|
| 65 |
+
return {cond_index: inv}
|
| 66 |
+
|
| 67 |
+
def visit_composite_prompt(
|
| 68 |
+
self,
|
| 69 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 70 |
+
args: _PreNoiseArgs,
|
| 71 |
+
index: int,
|
| 72 |
+
) -> Dict[int, torch.Tensor]:
|
| 73 |
+
inv_transforms: Dict[int, torch.Tensor] = {}
|
| 74 |
+
|
| 75 |
+
for child in that.children:
|
| 76 |
+
inv_transforms.update(child.accept(_GlobalToLocalAffineVisitor(), args, index))
|
| 77 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 78 |
+
|
| 79 |
+
if that.local_transform is not None:
|
| 80 |
+
m3 = torch.vstack([that.local_transform,
|
| 81 |
+
torch.tensor([0.0, 0.0, 1.0])])
|
| 82 |
+
parent_inv = _try_invert_affine(m3, 'composite parent local_transform')
|
| 83 |
+
for inv in inv_transforms.values():
|
| 84 |
+
inv3 = torch.vstack([inv, torch.tensor([0.0, 0.0, 1.0])])
|
| 85 |
+
inv[:] = (parent_inv @ inv3)[:-1]
|
| 86 |
+
|
| 87 |
+
return inv_transforms
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
|
| 91 |
+
if not global_state.is_enabled:
|
| 92 |
+
return
|
| 93 |
+
if not _batch_is_compatible(global_state.prompt_exprs, global_state.batch_cond_indices):
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
for prompt, cond_indices in zip(global_state.prompt_exprs,
|
| 97 |
+
global_state.batch_cond_indices):
|
| 98 |
+
args = _PreNoiseArgs(params.x, cond_indices)
|
| 99 |
+
inv_transforms = prompt.accept(_GlobalToLocalAffineVisitor(), args, 0)
|
| 100 |
+
for cond_index, _ in cond_indices:
|
| 101 |
+
params.x[cond_index] = affine_mod.apply_affine_transform(
|
| 102 |
+
params.x[cond_index], inv_transforms[cond_index]
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _flat_size(prompt: neutral_prompt_parser.PromptExpr) -> int:
|
| 110 |
+
return prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _batch_is_compatible(
|
| 114 |
+
prompts: List[neutral_prompt_parser.PromptExpr],
|
| 115 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 116 |
+
) -> bool:
|
| 117 |
+
if len(prompts) != len(batch_cond_indices):
|
| 118 |
+
_console_warn(f'''
|
| 119 |
+
Neutral Prompt batch mismatch:
|
| 120 |
+
prompt_exprs={len(prompts)} vs batch_cond_indices={len(batch_cond_indices)}
|
| 121 |
+
Falling back to original A1111 behavior for this step.
|
| 122 |
+
''')
|
| 123 |
+
return False
|
| 124 |
+
|
| 125 |
+
for i, (prompt, cond_indices) in enumerate(zip(prompts, batch_cond_indices)):
|
| 126 |
+
need = _flat_size(prompt)
|
| 127 |
+
got = len(cond_indices)
|
| 128 |
+
if need != got:
|
| 129 |
+
_console_warn(f'''
|
| 130 |
+
Neutral Prompt branch mismatch at prompt #{i}:
|
| 131 |
+
expected {need} branches, got {got}
|
| 132 |
+
Falling back to original A1111 behavior for this step.
|
| 133 |
+
''')
|
| 134 |
+
return False
|
| 135 |
+
|
| 136 |
+
return True
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Public entry point
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
def combine_denoised_hijack(
|
| 144 |
+
x_out: torch.Tensor,
|
| 145 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 146 |
+
text_uncond: torch.Tensor,
|
| 147 |
+
cond_scale: float,
|
| 148 |
+
original_function,
|
| 149 |
+
) -> torch.Tensor:
|
| 150 |
+
if not global_state.is_enabled:
|
| 151 |
+
return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
|
| 152 |
+
|
| 153 |
+
if not _batch_is_compatible(global_state.prompt_exprs, batch_cond_indices):
|
| 154 |
+
return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
|
| 155 |
+
|
| 156 |
+
# ------------------------------------------------------------------
|
| 157 |
+
# Update step counters for step-window gating (runs before any gating check)
|
| 158 |
+
# shared.state.sampling_step / sampling_steps are set by A1111 per denoiser call
|
| 159 |
+
# ------------------------------------------------------------------
|
| 160 |
+
try:
|
| 161 |
+
global_state.update_step_state(
|
| 162 |
+
shared.state.sampling_step,
|
| 163 |
+
shared.state.sampling_steps,
|
| 164 |
+
)
|
| 165 |
+
except Exception:
|
| 166 |
+
pass # sampling_step may not be set during warmup; keep previous values
|
| 167 |
+
# NOTE: Lock-after-End flags are reset in NeutralPromptScript.process(),
|
| 168 |
+
# not here. Resetting on step==0 would incorrectly re-arm locks inside
|
| 169 |
+
# samplers that revisit step 0 (Restart, DPM++ multi-pass, hires fix).
|
| 170 |
+
|
| 171 |
+
# ------------------------------------------------------------------
|
| 172 |
+
# Base-prompt protection (Off / Auto / Strict / Soft)
|
| 173 |
+
# ------------------------------------------------------------------
|
| 174 |
+
should_fallback, reason = _should_fallback_for_protection(
|
| 175 |
+
x_out, batch_cond_indices, text_uncond
|
| 176 |
+
)
|
| 177 |
+
if should_fallback:
|
| 178 |
+
_console_warn(f'\n [base-prompt protection] {reason}\n Falling back to standard CFG.')
|
| 179 |
+
return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
|
| 180 |
+
# ------------------------------------------------------------------
|
| 181 |
+
|
| 182 |
+
denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
|
| 183 |
+
uncond = x_out[-text_uncond.shape[0]:]
|
| 184 |
+
|
| 185 |
+
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
| 186 |
+
args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
| 187 |
+
cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
|
| 188 |
+
aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)
|
| 189 |
+
|
| 190 |
+
# --- Soft mode: attenuate auxiliary if base is too weak ---
|
| 191 |
+
aux_cond_delta, _ = apply_soft_attenuation_if_needed(
|
| 192 |
+
aux_cond_delta, cond_delta, batch_i
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if prompt.local_transform is not None:
|
| 196 |
+
cond_delta = affine_mod.apply_affine_transform(cond_delta, prompt.local_transform)
|
| 197 |
+
aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, prompt.local_transform)
|
| 198 |
+
|
| 199 |
+
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
|
| 200 |
+
denoised[batch_i] = _cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
|
| 201 |
+
|
| 202 |
+
return denoised
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
# Internal helpers
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
|
| 209 |
+
def _get_webui_denoised(
|
| 210 |
+
x_out: torch.Tensor,
|
| 211 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 212 |
+
text_uncond: torch.Tensor,
|
| 213 |
+
cond_scale: float,
|
| 214 |
+
original_function,
|
| 215 |
+
) -> torch.Tensor:
|
| 216 |
+
uncond = x_out[-text_uncond.shape[0]:]
|
| 217 |
+
sliced_batch_x_out: List[torch.Tensor] = []
|
| 218 |
+
sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
|
| 219 |
+
|
| 220 |
+
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
| 221 |
+
args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
| 222 |
+
sliced_x_out, sliced_indices = _gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
|
| 223 |
+
if sliced_indices:
|
| 224 |
+
sliced_batch_cond_indices.append(sliced_indices)
|
| 225 |
+
sliced_batch_x_out.extend(sliced_x_out)
|
| 226 |
+
|
| 227 |
+
sliced_batch_x_out += list(uncond)
|
| 228 |
+
return original_function(
|
| 229 |
+
torch.stack(sliced_batch_x_out, dim=0),
|
| 230 |
+
sliced_batch_cond_indices,
|
| 231 |
+
text_uncond,
|
| 232 |
+
cond_scale,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 237 |
+
"""Mean-preserving CFG rescale."""
|
| 238 |
+
global_state.CFGRescaleFactorSingleton.clear()
|
| 239 |
+
global_state.apply_and_clear_cfg_rescale_override()
|
| 240 |
+
|
| 241 |
+
if global_state.cfg_rescale == 0:
|
| 242 |
+
return cfg_cond
|
| 243 |
+
|
| 244 |
+
cfg_std = cfg_cond.std()
|
| 245 |
+
if cfg_std == 0:
|
| 246 |
+
return cfg_cond
|
| 247 |
+
|
| 248 |
+
cfg_cond_mean = cfg_cond.mean()
|
| 249 |
+
rescale_mean = (
|
| 250 |
+
(1 - global_state.cfg_rescale) * cfg_cond_mean
|
| 251 |
+
+ global_state.cfg_rescale * cond.mean()
|
| 252 |
+
)
|
| 253 |
+
rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_std - 1) + 1
|
| 254 |
+
|
| 255 |
+
global_state.CFGRescaleFactorSingleton.set(
|
| 256 |
+
rescale_factor.item() if isinstance(rescale_factor, torch.Tensor) else float(rescale_factor)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@dataclasses.dataclass
|
| 263 |
+
class _DenoiseArgs:
|
| 264 |
+
x_out: torch.Tensor
|
| 265 |
+
uncond: torch.Tensor
|
| 266 |
+
cond_indices: List[Tuple[int, float]]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _gather_webui_conds(
|
| 270 |
+
prompt: neutral_prompt_parser.CompositePrompt,
|
| 271 |
+
args: _DenoiseArgs,
|
| 272 |
+
index_in: int,
|
| 273 |
+
index_out: int,
|
| 274 |
+
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
| 275 |
+
sliced_x_out: List[torch.Tensor] = []
|
| 276 |
+
sliced_cond_indices: List[Tuple[int, float]] = []
|
| 277 |
+
|
| 278 |
+
for child in prompt.children:
|
| 279 |
+
if child.conciliation is None:
|
| 280 |
+
if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
|
| 281 |
+
child_x_out = args.x_out[args.cond_indices[index_in][0]]
|
| 282 |
+
child_weight = child.weight
|
| 283 |
+
else:
|
| 284 |
+
child_x_out, child_weight = _get_cond_delta_and_weight(child, args, index_in)
|
| 285 |
+
child_x_out = child_x_out + args.uncond
|
| 286 |
+
|
| 287 |
+
index_offset = index_out + len(sliced_x_out)
|
| 288 |
+
sliced_x_out.append(child_x_out)
|
| 289 |
+
sliced_cond_indices.append((index_offset, child_weight))
|
| 290 |
+
|
| 291 |
+
index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 292 |
+
|
| 293 |
+
return sliced_x_out, sliced_cond_indices
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _get_cond_delta_and_weight(
|
| 297 |
+
prompt: neutral_prompt_parser.PromptExpr,
|
| 298 |
+
args: _DenoiseArgs,
|
| 299 |
+
index: int,
|
| 300 |
+
) -> Tuple[torch.Tensor, float]:
|
| 301 |
+
cond_delta = prompt.accept(_CondDeltaVisitor(), args, index)
|
| 302 |
+
cond_delta = cond_delta + prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, index)
|
| 303 |
+
weight = prompt.weight
|
| 304 |
+
|
| 305 |
+
if prompt.local_transform is not None:
|
| 306 |
+
transformed, weight_tensor = affine_mod.apply_masked_transform(
|
| 307 |
+
cond_delta + args.uncond,
|
| 308 |
+
prompt.local_transform,
|
| 309 |
+
prompt.weight,
|
| 310 |
+
)
|
| 311 |
+
cond_delta = transformed - args.uncond
|
| 312 |
+
weight = weight_tensor
|
| 313 |
+
|
| 314 |
+
return cond_delta, weight
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ---------------------------------------------------------------------------
|
| 318 |
+
# Visitor: CondDelta
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
|
| 321 |
+
class _CondDeltaVisitor:
|
| 322 |
+
def visit_leaf_prompt(
|
| 323 |
+
self,
|
| 324 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 325 |
+
args: _DenoiseArgs,
|
| 326 |
+
index: int,
|
| 327 |
+
) -> torch.Tensor:
|
| 328 |
+
cond_info = args.cond_indices[index]
|
| 329 |
+
if that.weight != cond_info[1]:
|
| 330 |
+
_console_warn(f'''
|
| 331 |
+
Unexpected noise weight at prompt #{index}
|
| 332 |
+
Expected :{that.weight}, got :{cond_info[1]}
|
| 333 |
+
''')
|
| 334 |
+
return args.x_out[cond_info[0]] - args.uncond
|
| 335 |
+
|
| 336 |
+
def visit_composite_prompt(
|
| 337 |
+
self,
|
| 338 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 339 |
+
args: _DenoiseArgs,
|
| 340 |
+
index: int,
|
| 341 |
+
) -> torch.Tensor:
|
| 342 |
+
cond_delta = torch.zeros_like(args.x_out[0])
|
| 343 |
+
for child in that.children:
|
| 344 |
+
if child.conciliation is None:
|
| 345 |
+
child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
|
| 346 |
+
cond_delta = cond_delta + child_weight * child_delta
|
| 347 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 348 |
+
return cond_delta
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# ---------------------------------------------------------------------------
|
| 352 |
+
# Visitor: AuxCondDelta — all conciliation strategies
|
| 353 |
+
# ---------------------------------------------------------------------------
|
| 354 |
+
|
| 355 |
+
class _AuxCondDeltaVisitor:
|
| 356 |
+
def visit_leaf_prompt(
|
| 357 |
+
self,
|
| 358 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 359 |
+
args: _DenoiseArgs,
|
| 360 |
+
cond_delta: torch.Tensor,
|
| 361 |
+
index: int,
|
| 362 |
+
) -> torch.Tensor:
|
| 363 |
+
return torch.zeros_like(args.x_out[0])
|
| 364 |
+
|
| 365 |
+
def visit_composite_prompt(
|
| 366 |
+
self,
|
| 367 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 368 |
+
args: _DenoiseArgs,
|
| 369 |
+
cond_delta: torch.Tensor,
|
| 370 |
+
index: int,
|
| 371 |
+
) -> torch.Tensor:
|
| 372 |
+
aux_cond_delta = torch.zeros_like(args.x_out[0])
|
| 373 |
+
|
| 374 |
+
# Each salience list entry: (delta, weight, k)
|
| 375 |
+
salient_cond_deltas: List[Tuple[torch.Tensor, float, float]] = []
|
| 376 |
+
salient_wide_deltas: List[Tuple[torch.Tensor, float, float]] = []
|
| 377 |
+
salient_blob_deltas: List[Tuple[torch.Tensor, float, float]] = []
|
| 378 |
+
align_blend_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
|
| 379 |
+
mask_align_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
|
| 380 |
+
|
| 381 |
+
CS = neutral_prompt_parser.ConciliationStrategy
|
| 382 |
+
|
| 383 |
+
# Step-window gating: compute normalised progress once per composite visit
|
| 384 |
+
if global_state.step_window_enabled:
|
| 385 |
+
from lib_neutral_prompt.step_utils import (
|
| 386 |
+
strategy_is_active_from_state, normalize_progress,
|
| 387 |
+
)
|
| 388 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 389 |
+
else:
|
| 390 |
+
progress = None
|
| 391 |
+
|
| 392 |
+
for child in that.children:
|
| 393 |
+
if child.conciliation is not None:
|
| 394 |
+
# --- Step-window gate ---
|
| 395 |
+
if progress is not None:
|
| 396 |
+
strat_name = child.conciliation.name
|
| 397 |
+
if not strategy_is_active_from_state(strat_name, progress):
|
| 398 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 399 |
+
continue
|
| 400 |
+
|
| 401 |
+
child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
|
| 402 |
+
strat = child.conciliation
|
| 403 |
+
cp = child.conciliation_params # per-child params dict
|
| 404 |
+
|
| 405 |
+
if strat == CS.PERPENDICULAR:
|
| 406 |
+
aux_cond_delta = aux_cond_delta + child_weight * _get_perpendicular_component(cond_delta, child_delta)
|
| 407 |
+
|
| 408 |
+
elif strat == CS.SALIENCE_MASK:
|
| 409 |
+
# Default k=5: noticeably more coverage than original k=20.
|
| 410 |
+
# User can override via AND_SALT[k].
|
| 411 |
+
k = float(cp.get('k', 5.0))
|
| 412 |
+
salient_cond_deltas.append((child_delta, child_weight, k))
|
| 413 |
+
|
| 414 |
+
elif strat == CS.SALIENCE_MASK_WIDE:
|
| 415 |
+
k = float(cp.get('k', 1.0))
|
| 416 |
+
salient_wide_deltas.append((child_delta, child_weight, k))
|
| 417 |
+
|
| 418 |
+
elif strat == CS.SALIENCE_MASK_BLOB:
|
| 419 |
+
k = float(cp.get('k', 5.0))
|
| 420 |
+
salient_blob_deltas.append((child_delta, child_weight, k))
|
| 421 |
+
|
| 422 |
+
elif strat == CS.SEMANTIC_GUIDANCE:
|
| 423 |
+
threshold = float(cp.get('threshold', 0.05))
|
| 424 |
+
aux_cond_delta = aux_cond_delta + child_weight * _filter_abs_top_k(child_delta, threshold)
|
| 425 |
+
|
| 426 |
+
elif strat == CS.ALIGNMENT_BLEND_CUSTOM:
|
| 427 |
+
# AND_ALIGN[D,S] — any user-specified pair
|
| 428 |
+
d = int(cp.get('d', 4))
|
| 429 |
+
s = int(cp.get('s', 8))
|
| 430 |
+
if 2 <= d <= 32 and 2 <= s <= 32 and d != s:
|
| 431 |
+
align_blend_deltas.append((child_delta, child_weight, d, s))
|
| 432 |
+
|
| 433 |
+
elif strat == CS.ALIGNMENT_MASK_BLEND_CUSTOM:
|
| 434 |
+
# AND_MASK_ALIGN[D,S]
|
| 435 |
+
d = int(cp.get('d', 4))
|
| 436 |
+
s = int(cp.get('s', 8))
|
| 437 |
+
if 2 <= d <= 32 and 2 <= s <= 32 and d != s:
|
| 438 |
+
mask_align_deltas.append((child_delta, child_weight, d, s))
|
| 439 |
+
|
| 440 |
+
else:
|
| 441 |
+
# Fixed-suffix AND_ALIGN_D_S (backward-compat)
|
| 442 |
+
m = re.match(r'AND_ALIGN_(\d+)_(\d+)', strat.value)
|
| 443 |
+
if m:
|
| 444 |
+
align_blend_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
|
| 445 |
+
else:
|
| 446 |
+
m = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', strat.value)
|
| 447 |
+
if m:
|
| 448 |
+
mask_align_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
|
| 449 |
+
|
| 450 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 451 |
+
|
| 452 |
+
aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_cond_deltas)
|
| 453 |
+
aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_wide_deltas)
|
| 454 |
+
aux_cond_delta = aux_cond_delta + _salient_blend_blob(cond_delta, salient_blob_deltas)
|
| 455 |
+
aux_cond_delta = aux_cond_delta + _alignment_blend(cond_delta, align_blend_deltas)
|
| 456 |
+
aux_cond_delta = aux_cond_delta + _alignment_mask_blend(cond_delta, mask_align_deltas)
|
| 457 |
+
return aux_cond_delta
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# ---------------------------------------------------------------------------
|
| 461 |
+
# Strategy implementations
|
| 462 |
+
# ---------------------------------------------------------------------------
|
| 463 |
+
|
| 464 |
+
def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
| 465 |
+
if (normal == 0).all():
|
| 466 |
+
if shared.state.sampling_step <= 0:
|
| 467 |
+
_warn_projection_not_found()
|
| 468 |
+
return vector
|
| 469 |
+
return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _salient_blend(
|
| 473 |
+
normal: torch.Tensor,
|
| 474 |
+
vectors: List[Tuple[torch.Tensor, float, float]],
|
| 475 |
+
) -> torch.Tensor:
|
| 476 |
+
"""
|
| 477 |
+
Saliency-guided blend. Each entry: (delta, weight, k).
|
| 478 |
+
|
| 479 |
+
k controls child mask sharpness:
|
| 480 |
+
k=1 → broad (~55% of pixels) — AND_SALT_WIDE default
|
| 481 |
+
k=5 → moderate (~5-10%) — AND_SALT default (was 20, now tuned)
|
| 482 |
+
k=20 → very sharp (1-2 pixels) — can be set via AND_SALT[20]
|
| 483 |
+
|
| 484 |
+
Parent always uses k=1 (diffuse reference) so children compete fairly.
|
| 485 |
+
"""
|
| 486 |
+
if not vectors:
|
| 487 |
+
return torch.zeros_like(normal)
|
| 488 |
+
|
| 489 |
+
salience_maps = [_get_salience(normal, k=1.0)] + [
|
| 490 |
+
_get_salience(v, k=k) for v, _, k in vectors
|
| 491 |
+
]
|
| 492 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 493 |
+
|
| 494 |
+
result = torch.zeros_like(normal)
|
| 495 |
+
for mask_i, (vector, weight, _) in enumerate(vectors, start=1):
|
| 496 |
+
vector_mask = (mask == mask_i).float()
|
| 497 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 498 |
+
return result
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def _salient_blend_blob(
|
| 502 |
+
normal: torch.Tensor,
|
| 503 |
+
vectors: List[Tuple[torch.Tensor, float, float]],
|
| 504 |
+
) -> torch.Tensor:
|
| 505 |
+
"""
|
| 506 |
+
AND_SALT_BLOB — life-branch algorithm (fixed erode/thickify order).
|
| 507 |
+
|
| 508 |
+
Fix vs original: original did 6×erode then 2×thickify, which always
|
| 509 |
+
destroyed the 1-2 pixel seed from k=20 on the first erosion step.
|
| 510 |
+
Correct pipeline:
|
| 511 |
+
1. k-softmax → initial seed (sharpness from conciliation_params)
|
| 512 |
+
2. thickify ×3 → grow seed into a dense blob first
|
| 513 |
+
3. erode ×1 → smooth/sharpen edges without destroying the core
|
| 514 |
+
"""
|
| 515 |
+
if not vectors:
|
| 516 |
+
return torch.zeros_like(normal)
|
| 517 |
+
|
| 518 |
+
salience_maps = [_get_salience(normal, k=1.0)] + [
|
| 519 |
+
_get_salience(v, k=k) for v, _, k in vectors
|
| 520 |
+
]
|
| 521 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 522 |
+
|
| 523 |
+
result = torch.zeros_like(normal)
|
| 524 |
+
for mask_i, (vector, weight, _) in enumerate(vectors, start=1):
|
| 525 |
+
vector_mask = (mask == mask_i).float()
|
| 526 |
+
# Grow first, then smooth — prevents seed self-destruction
|
| 527 |
+
for _ in range(3):
|
| 528 |
+
vector_mask = _life_step(vector_mask, _thickify_rule)
|
| 529 |
+
for _ in range(1):
|
| 530 |
+
vector_mask = _life_step(vector_mask, _erode_rule)
|
| 531 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 532 |
+
return result
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
|
| 536 |
+
C = board.shape[0]
|
| 537 |
+
kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
|
| 538 |
+
kernel = kernel.unsqueeze(0).unsqueeze(0)
|
| 539 |
+
|
| 540 |
+
padded = torch.cat([board.clone(), board[:-1].clone()], dim=0)
|
| 541 |
+
padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0)
|
| 542 |
+
|
| 543 |
+
neighbors = torch.nn.functional.conv3d(
|
| 544 |
+
padded.unsqueeze(0).unsqueeze(0),
|
| 545 |
+
kernel,
|
| 546 |
+
padding=0,
|
| 547 |
+
).squeeze(0).squeeze(0)
|
| 548 |
+
|
| 549 |
+
neighbors = neighbors - board
|
| 550 |
+
return rule(board, neighbors).float()
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 554 |
+
C = board.shape[0]
|
| 555 |
+
return (board == 1) & (neighbors >= C * 5)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 559 |
+
population = board + neighbors
|
| 560 |
+
return (board == 1) | (population >= 4)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
|
| 564 |
+
return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
|
| 568 |
+
k = int(torch.numel(vector) * (1 - k_ratio))
|
| 569 |
+
k = max(1, k)
|
| 570 |
+
threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
|
| 571 |
+
return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
# ---------------------------------------------------------------------------
|
| 575 |
+
# Alignment blend (alignment_blend branch)
|
| 576 |
+
# ---------------------------------------------------------------------------
|
| 577 |
+
|
| 578 |
+
def _compute_subregion_similarity_map(
|
| 579 |
+
child_vector: torch.Tensor,
|
| 580 |
+
parent_vector: torch.Tensor,
|
| 581 |
+
region_size: int = 2,
|
| 582 |
+
) -> torch.Tensor:
|
| 583 |
+
C, H, W = child_vector.shape
|
| 584 |
+
parent = parent_vector.unsqueeze(0)
|
| 585 |
+
child = child_vector.unsqueeze(0)
|
| 586 |
+
|
| 587 |
+
region_radius = region_size // 2
|
| 588 |
+
if region_size % 2 == 1:
|
| 589 |
+
pad_size = (region_radius,) * 4
|
| 590 |
+
else:
|
| 591 |
+
pad_size = (region_radius - 1, region_radius) * 2
|
| 592 |
+
|
| 593 |
+
parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
|
| 594 |
+
child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
|
| 595 |
+
|
| 596 |
+
parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 597 |
+
child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 598 |
+
|
| 599 |
+
unfold2 = torch.nn.Unfold(kernel_size=2)
|
| 600 |
+
parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
|
| 601 |
+
child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
|
| 602 |
+
|
| 603 |
+
parent_sub = F.normalize(parent_sub, p=2, dim=2)
|
| 604 |
+
child_sub = F.normalize(child_sub, p=2, dim=2)
|
| 605 |
+
sim = (parent_sub * child_sub).sum(dim=2)
|
| 606 |
+
return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def _alignment_blend(
|
| 610 |
+
parent: torch.Tensor,
|
| 611 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 612 |
+
) -> torch.Tensor:
|
| 613 |
+
result = torch.zeros_like(parent)
|
| 614 |
+
for child, weight, detail_size, structure_size in children:
|
| 615 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 616 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 617 |
+
|
| 618 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 619 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 620 |
+
detail_sim = detail_sim / d_abs_max
|
| 621 |
+
structure_sim = structure_sim / s_abs_max
|
| 622 |
+
|
| 623 |
+
alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
|
| 624 |
+
result = result + (child - parent) * weight * alignment_weight
|
| 625 |
+
return result
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def _alignment_mask_blend(
|
| 629 |
+
parent: torch.Tensor,
|
| 630 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 631 |
+
) -> torch.Tensor:
|
| 632 |
+
result = torch.zeros_like(parent)
|
| 633 |
+
for child, weight, detail_size, structure_size in children:
|
| 634 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 635 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 636 |
+
|
| 637 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 638 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 639 |
+
detail_sim = detail_sim / d_abs_max
|
| 640 |
+
structure_sim = structure_sim / s_abs_max
|
| 641 |
+
|
| 642 |
+
alignment_mask = (structure_sim > detail_sim).to(child)
|
| 643 |
+
result = result + (child - parent) * weight * alignment_mask
|
| 644 |
+
return result
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
# ---------------------------------------------------------------------------
|
| 649 |
+
# Base-prompt protection policy (imported from protection_utils)
|
| 650 |
+
# ---------------------------------------------------------------------------
|
| 651 |
+
from lib_neutral_prompt.protection_utils import (
|
| 652 |
+
prompt_has_valid_base_path as _prompt_has_valid_base_path,
|
| 653 |
+
get_invalid_base_prompt_indices as _get_invalid_base_prompt_indices,
|
| 654 |
+
has_valid_base_path as _has_valid_base_path,
|
| 655 |
+
safe_norm as _safe_norm,
|
| 656 |
+
apply_soft_attenuation_if_needed,
|
| 657 |
+
)
|
| 658 |
+
from lib_neutral_prompt.affine_utils import try_invert_affine as _try_invert_affine
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def _should_fallback_for_protection(
|
| 662 |
+
x_out: torch.Tensor,
|
| 663 |
+
batch_cond_indices,
|
| 664 |
+
text_uncond: torch.Tensor,
|
| 665 |
+
) -> tuple:
|
| 666 |
+
"""
|
| 667 |
+
Decide whether to fall back to standard CFG. Delegates to protection_utils.
|
| 668 |
+
Passes the visitor classes defined in this module to avoid circular imports.
|
| 669 |
+
"""
|
| 670 |
+
from lib_neutral_prompt.protection_utils import should_fallback_for_protection
|
| 671 |
+
return should_fallback_for_protection(
|
| 672 |
+
x_out, batch_cond_indices, text_uncond,
|
| 673 |
+
_CondDeltaVisitor, _AuxCondDeltaVisitor, _DenoiseArgs,
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# ---------------------------------------------------------------------------
|
| 679 |
+
# Sampler hijack
|
| 680 |
+
# ---------------------------------------------------------------------------
|
| 681 |
+
|
| 682 |
+
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 683 |
+
module=sd_samplers,
|
| 684 |
+
hijacker_attribute='__neutral_prompt_hijacker',
|
| 685 |
+
on_uninstall=script_callbacks.on_script_unloaded,
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
@sd_samplers_hijacker.hijack('create_sampler')
|
| 690 |
+
def create_sampler_hijack(name: str, model, original_function):
|
| 691 |
+
sampler = original_function(name, model)
|
| 692 |
+
if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
|
| 693 |
+
if global_state.is_enabled:
|
| 694 |
+
_warn_unsupported_sampler()
|
| 695 |
+
return sampler
|
| 696 |
+
|
| 697 |
+
sampler.model_wrap_cfg.combine_denoised = functools.partial(
|
| 698 |
+
combine_denoised_hijack,
|
| 699 |
+
original_function=sampler.model_wrap_cfg.combine_denoised,
|
| 700 |
+
)
|
| 701 |
+
return sampler
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
# ---------------------------------------------------------------------------
|
| 705 |
+
# Warnings / logging
|
| 706 |
+
# ---------------------------------------------------------------------------
|
| 707 |
+
|
| 708 |
+
def _warn_unsupported_sampler() -> None:
|
| 709 |
+
_console_warn('''
|
| 710 |
+
Neutral prompt relies on composition via AND, which the webui does not support
|
| 711 |
+
when using any of the DDIM, PLMS and UniPC samplers.
|
| 712 |
+
The sampler will NOT be patched – falling back on the original implementation.
|
| 713 |
+
''')
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def _warn_projection_not_found() -> None:
|
| 717 |
+
_console_warn('''
|
| 718 |
+
Could not find a projection for one or more AND_PERP prompts.
|
| 719 |
+
These prompts will NOT be made perpendicular.
|
| 720 |
+
''')
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def _console_warn(message: str) -> None:
|
| 724 |
+
if not global_state.verbose:
|
| 725 |
+
return
|
| 726 |
+
print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
|
neutral_prompt_patcheds/lib_neutral_prompt/external_code/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@contextlib.contextmanager
|
| 5 |
+
def fix_path():
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
extension_path = str(Path(__file__).parent.parent.parent)
|
| 10 |
+
added = False
|
| 11 |
+
if extension_path not in sys.path:
|
| 12 |
+
sys.path.insert(0, extension_path)
|
| 13 |
+
added = True
|
| 14 |
+
|
| 15 |
+
yield
|
| 16 |
+
|
| 17 |
+
if added:
|
| 18 |
+
sys.path.remove(extension_path)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
with fix_path():
|
| 22 |
+
del fix_path, contextlib
|
| 23 |
+
from .api import *
|
neutral_prompt_patcheds/lib_neutral_prompt/external_code/api.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
External API for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Provides thin helpers that external extensions / scripts can import via:
|
| 5 |
+
|
| 6 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from lib_neutral_prompt import global_state
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def override_cfg_rescale(cfg_rescale: float) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Override the CFG rescale value for the *next* generation step only.
|
| 17 |
+
After the step runs the value is automatically cleared.
|
| 18 |
+
"""
|
| 19 |
+
global_state.cfg_rescale_override = cfg_rescale
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_last_cfg_rescale_factor() -> Optional[float]:
|
| 23 |
+
"""
|
| 24 |
+
Return the CFG rescale factor computed during the most recent denoising
|
| 25 |
+
step, or None if rescaling was not active.
|
| 26 |
+
"""
|
| 27 |
+
return global_state.CFGRescaleFactorSingleton.get()
|
neutral_prompt_patcheds/lib_neutral_prompt/global_state.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global mutable state shared across the extension.
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- is_enabled / prompt_exprs / verbose (all branches)
|
| 6 |
+
- cfg_rescale + cfg_rescale_override (main branch)
|
| 7 |
+
- batch_cond_indices (affine branch)
|
| 8 |
+
- CFGRescaleFactorSingleton (export_rescale_factor branch)
|
| 9 |
+
- protection_mode / strict_threshold (patched — base-prompt protection)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import threading
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Runtime state
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
is_enabled: bool = False
|
| 22 |
+
prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
|
| 23 |
+
batch_cond_indices: List[List[Tuple[int, float]]] = []
|
| 24 |
+
cfg_rescale: float = 0.0
|
| 25 |
+
verbose: bool = False
|
| 26 |
+
|
| 27 |
+
# Set to a float value by XYZ-grid or external API to override cfg_rescale
|
| 28 |
+
# for a single generation step, then auto-cleared.
|
| 29 |
+
cfg_rescale_override: Optional[float] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def apply_and_clear_cfg_rescale_override() -> None:
|
| 33 |
+
"""Apply a one-shot cfg_rescale override and immediately clear it."""
|
| 34 |
+
global cfg_rescale, cfg_rescale_override
|
| 35 |
+
if cfg_rescale_override is not None:
|
| 36 |
+
cfg_rescale = cfg_rescale_override
|
| 37 |
+
cfg_rescale_override = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Base-prompt protection
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Three modes:
|
| 44 |
+
#
|
| 45 |
+
# 'off' — no guard; auxiliary children can dominate freely.
|
| 46 |
+
# Use for experiments or intentional replacement-style prompts.
|
| 47 |
+
#
|
| 48 |
+
# 'auto' — (DEFAULT) structural guard: fires when the prompt tree has no
|
| 49 |
+
# top-level base AND-child (i.e. all children have conciliation
|
| 50 |
+
# keywords). Prevents obvious "main prompt disappears" cases.
|
| 51 |
+
#
|
| 52 |
+
# 'strict' — structural guard PLUS numerical guard: also fires when a base
|
| 53 |
+
# child exists but its cond_delta is too weak relative to the
|
| 54 |
+
# auxiliary contribution:
|
| 55 |
+
# norm(base_delta) / norm(aux_delta) < strict_threshold
|
| 56 |
+
# Use when you notice auxiliary prompts quietly dominating even
|
| 57 |
+
# though a base prompt is present.
|
| 58 |
+
#
|
| 59 |
+
# Both auto/strict fall back to standard A1111 CFG and print a warning when
|
| 60 |
+
# verbose=True in Settings → Neutral Prompt.
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
_VALID_PROTECTION_MODES = frozenset({'off', 'auto', 'strict', 'soft'})
|
| 64 |
+
|
| 65 |
+
protection_mode: str = 'auto'
|
| 66 |
+
strict_threshold: float = 0.10 # ratio threshold used by both strict and soft modes
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def normalize_protection_mode(value: str) -> str:
|
| 70 |
+
"""
|
| 71 |
+
Normalise an incoming protection-mode string.
|
| 72 |
+
Strips whitespace, lowercases, and falls back to 'auto' for unknown values.
|
| 73 |
+
Used wherever mode arrives from the UI or infotext paste.
|
| 74 |
+
Valid modes: 'off', 'auto', 'strict', 'soft'.
|
| 75 |
+
"""
|
| 76 |
+
normalised = str(value).strip().lower()
|
| 77 |
+
if normalised not in _VALID_PROTECTION_MODES:
|
| 78 |
+
return 'auto'
|
| 79 |
+
return normalised
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def clamp_strict_threshold(value) -> float:
|
| 83 |
+
"""Clamp threshold to [0.01, 0.50]; returns 0.10 on parse failure."""
|
| 84 |
+
try:
|
| 85 |
+
return max(0.01, min(0.50, float(value)))
|
| 86 |
+
except (TypeError, ValueError):
|
| 87 |
+
return 0.10
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
# Step-window activation control
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Controls which denoising steps each AND_* strategy is active during.
|
| 95 |
+
# All values here are read by step_utils.strategy_is_active_from_state().
|
| 96 |
+
#
|
| 97 |
+
# step_window_enabled — master switch; False = always active (default)
|
| 98 |
+
# step_window_global — StepWindow or None, applied to every strategy
|
| 99 |
+
# step_window_per_strategy — dict[family_name, StepWindow] or None
|
| 100 |
+
# step_window_use_defaults — True = fall back to step_utils.STRATEGY_DEFAULTS
|
| 101 |
+
|
| 102 |
+
step_window_enabled: bool = False
|
| 103 |
+
step_window_global = None # StepWindow | None
|
| 104 |
+
step_window_per_strategy = None # dict[str, StepWindow] | None (family → StepWindow)
|
| 105 |
+
step_window_use_defaults: bool = False
|
| 106 |
+
step_window_custom_raw = None # dict[str, tuple] | None {AND_* UI key → (start,end)}
|
| 107 |
+
step_window_lock_after_end: bool = False # freeze strategies after their window closes
|
| 108 |
+
|
| 109 |
+
# Per-generation per-strategy lock flags {family_name → True if locked}
|
| 110 |
+
# Keyed by _lock_generation_id: flags from a previous generation are never
|
| 111 |
+
# visible to the current one even if reset_step_lock_flags() was not called.
|
| 112 |
+
_step_lock_flags: dict = {}
|
| 113 |
+
|
| 114 |
+
# Monotonically increasing counter bumped once per real generation
|
| 115 |
+
# (i.e. once per NeutralPromptScript.process() call).
|
| 116 |
+
# strategy_is_active_from_state embeds this ID in each flag entry so that
|
| 117 |
+
# flags from a previous generation are treated as absent.
|
| 118 |
+
_lock_generation_id: int = 0
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def begin_new_generation() -> None:
|
| 122 |
+
"""
|
| 123 |
+
Mark the start of a new generation run.
|
| 124 |
+
|
| 125 |
+
Call this exactly once per NeutralPromptScript.process() call.
|
| 126 |
+
Bumping the generation ID effectively invalidates all lock flags from
|
| 127 |
+
the previous run — without requiring any lock-flag cleanup in the
|
| 128 |
+
denoising loop itself.
|
| 129 |
+
|
| 130 |
+
This is intentionally NOT triggered by current_step == 0 because
|
| 131 |
+
some samplers (Restart, DPM++ variants) revisit step 0 internally
|
| 132 |
+
within a single denoising pass. Using the process() lifecycle hook
|
| 133 |
+
instead gives a clean, unambiguous per-generation boundary.
|
| 134 |
+
"""
|
| 135 |
+
global _lock_generation_id, _step_lock_flags
|
| 136 |
+
_lock_generation_id += 1
|
| 137 |
+
_step_lock_flags = {}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def ensure_step_lock_run_initialized(run_token) -> bool:
|
| 141 |
+
"""
|
| 142 |
+
Token-based alternative to begin_new_generation().
|
| 143 |
+
|
| 144 |
+
Initialises lock flags when *run_token* differs from the last seen token.
|
| 145 |
+
Returns True if a reset was performed (new generation), False otherwise.
|
| 146 |
+
|
| 147 |
+
Typical token: ``id(global_state.prompt_exprs)`` — a new list object is
|
| 148 |
+
created on every NeutralPromptScript.process() call, so the token changes
|
| 149 |
+
exactly once per generation and is stable across sampler restarts that
|
| 150 |
+
happen inside the denoising loop.
|
| 151 |
+
|
| 152 |
+
Prefer begin_new_generation() when you control the call-site (scripts/);
|
| 153 |
+
use this helper when you only have access to an existing context object.
|
| 154 |
+
"""
|
| 155 |
+
global _lock_generation_id, _step_lock_flags
|
| 156 |
+
current_token = getattr(ensure_step_lock_run_initialized, '_last_token', object())
|
| 157 |
+
if current_token != run_token:
|
| 158 |
+
ensure_step_lock_run_initialized._last_token = run_token # type: ignore[attr-defined]
|
| 159 |
+
_lock_generation_id += 1
|
| 160 |
+
_step_lock_flags = {}
|
| 161 |
+
return True
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def reset_step_lock_flags() -> None:
|
| 166 |
+
"""
|
| 167 |
+
Clear all Lock-after-End flags for the current generation.
|
| 168 |
+
|
| 169 |
+
Prefer begin_new_generation() for per-generation resets; this helper
|
| 170 |
+
exists for testing and for callers that need an in-generation flush.
|
| 171 |
+
"""
|
| 172 |
+
global _step_lock_flags
|
| 173 |
+
_step_lock_flags = {}
|
| 174 |
+
|
| 175 |
+
# Current step counter — set by the hijack each denoising step.
|
| 176 |
+
current_step: int = 0
|
| 177 |
+
total_steps: int = 1
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def update_step_state(step, total) -> None:
|
| 181 |
+
"""
|
| 182 |
+
Update the current denoising step counters.
|
| 183 |
+
|
| 184 |
+
Called by the sampler hijack at each step so that step-window gating
|
| 185 |
+
has accurate progress information.
|
| 186 |
+
|
| 187 |
+
Parameters
|
| 188 |
+
----------
|
| 189 |
+
step : current 0-based step index (any numeric type)
|
| 190 |
+
total : total number of steps (any numeric type, must be ≥ 1)
|
| 191 |
+
"""
|
| 192 |
+
global current_step, total_steps
|
| 193 |
+
try:
|
| 194 |
+
current = max(0, int(step))
|
| 195 |
+
except Exception:
|
| 196 |
+
current = 0
|
| 197 |
+
try:
|
| 198 |
+
total_ = max(1, int(total))
|
| 199 |
+
except Exception:
|
| 200 |
+
total_ = 1
|
| 201 |
+
current_step = current
|
| 202 |
+
total_steps = total_
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
# Infotext helpers
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
# Keys used when saving/restoring settings via A1111 infotext.
|
| 209 |
+
|
| 210 |
+
INFOTEXT_KEY_PROTECTION_MODE = 'NP Protection Mode'
|
| 211 |
+
INFOTEXT_KEY_STRICT_THRESHOLD = 'NP Strict Threshold'
|
| 212 |
+
INFOTEXT_KEY_CFG_RESCALE = 'CFG Rescale phi'
|
| 213 |
+
INFOTEXT_KEY_STEP_WIN_ENABLED = 'NP Step Window Enabled'
|
| 214 |
+
INFOTEXT_KEY_STEP_WIN_MODE = 'NP Step Window Mode'
|
| 215 |
+
INFOTEXT_KEY_STEP_WIN_START = 'NP Step Window Start'
|
| 216 |
+
INFOTEXT_KEY_STEP_WIN_END = 'NP Step Window End'
|
| 217 |
+
INFOTEXT_KEY_STEP_WIN_CUSTOM = 'NP Step Window Custom'
|
| 218 |
+
INFOTEXT_KEY_STEP_WIN_LOCK = 'NP Step Window Lock'
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def apply_infotext(fields: dict) -> None:
|
| 222 |
+
"""
|
| 223 |
+
Apply a dict of infotext key→value pairs to global state.
|
| 224 |
+
Called from the script's on_infotext_pasted handler.
|
| 225 |
+
Only updates fields that are actually present in the dict.
|
| 226 |
+
"""
|
| 227 |
+
global protection_mode, strict_threshold, cfg_rescale
|
| 228 |
+
global step_window_enabled, step_window_global, step_window_use_defaults
|
| 229 |
+
global step_window_per_strategy, step_window_custom_raw, step_window_lock_after_end
|
| 230 |
+
|
| 231 |
+
if INFOTEXT_KEY_PROTECTION_MODE in fields:
|
| 232 |
+
protection_mode = normalize_protection_mode(fields[INFOTEXT_KEY_PROTECTION_MODE])
|
| 233 |
+
if INFOTEXT_KEY_STRICT_THRESHOLD in fields:
|
| 234 |
+
strict_threshold = clamp_strict_threshold(fields[INFOTEXT_KEY_STRICT_THRESHOLD])
|
| 235 |
+
if INFOTEXT_KEY_CFG_RESCALE in fields:
|
| 236 |
+
try:
|
| 237 |
+
cfg_rescale = float(fields[INFOTEXT_KEY_CFG_RESCALE])
|
| 238 |
+
except (TypeError, ValueError):
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
# Step-window fields (all optional — may be absent in older infotext)
|
| 242 |
+
if INFOTEXT_KEY_STEP_WIN_ENABLED in fields:
|
| 243 |
+
raw = fields[INFOTEXT_KEY_STEP_WIN_ENABLED]
|
| 244 |
+
step_window_enabled = str(raw).lower() in ('true', '1', 'yes')
|
| 245 |
+
|
| 246 |
+
if step_window_enabled:
|
| 247 |
+
mode = str(fields.get(INFOTEXT_KEY_STEP_WIN_MODE, 'global'))
|
| 248 |
+
try:
|
| 249 |
+
start = max(0.0, min(1.0, float(fields.get(INFOTEXT_KEY_STEP_WIN_START, 0.0))))
|
| 250 |
+
except (TypeError, ValueError):
|
| 251 |
+
start = 0.0
|
| 252 |
+
try:
|
| 253 |
+
end = max(0.0, min(1.0, float(fields.get(INFOTEXT_KEY_STEP_WIN_END, 1.0))))
|
| 254 |
+
except (TypeError, ValueError):
|
| 255 |
+
end = 1.0
|
| 256 |
+
if start > end:
|
| 257 |
+
start, end = 0.0, 1.0
|
| 258 |
+
|
| 259 |
+
if mode == 'global':
|
| 260 |
+
# Import here to avoid circular import at module level
|
| 261 |
+
from lib_neutral_prompt.step_utils import StepWindow as _SW
|
| 262 |
+
step_window_global = _SW(start, end)
|
| 263 |
+
step_window_use_defaults = False
|
| 264 |
+
step_window_per_strategy = None
|
| 265 |
+
elif mode == 'per-strategy defaults':
|
| 266 |
+
step_window_global = None
|
| 267 |
+
step_window_use_defaults = True
|
| 268 |
+
step_window_per_strategy = None
|
| 269 |
+
elif mode == 'per-strategy custom':
|
| 270 |
+
step_window_global = None
|
| 271 |
+
step_window_use_defaults = False
|
| 272 |
+
# Restore per-strategy dict from compact string if present
|
| 273 |
+
if INFOTEXT_KEY_STEP_WIN_CUSTOM in fields:
|
| 274 |
+
from lib_neutral_prompt.step_utils import (
|
| 275 |
+
deserialize_per_strategy_windows,
|
| 276 |
+
build_per_strategy_windows,
|
| 277 |
+
)
|
| 278 |
+
raw = deserialize_per_strategy_windows(
|
| 279 |
+
str(fields[INFOTEXT_KEY_STEP_WIN_CUSTOM])
|
| 280 |
+
)
|
| 281 |
+
step_window_custom_raw = raw
|
| 282 |
+
step_window_per_strategy = build_per_strategy_windows(raw)
|
| 283 |
+
else:
|
| 284 |
+
step_window_global = None
|
| 285 |
+
step_window_use_defaults = False
|
| 286 |
+
|
| 287 |
+
# Lock after End (may be set independently of other step-window fields)
|
| 288 |
+
if INFOTEXT_KEY_STEP_WIN_LOCK in fields:
|
| 289 |
+
raw_lock = fields[INFOTEXT_KEY_STEP_WIN_LOCK]
|
| 290 |
+
step_window_lock_after_end = str(raw_lock).lower() in ('true', '1', 'yes')
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
# CFG Rescale Factor Singleton (export_rescale_factor branch)
|
| 295 |
+
#
|
| 296 |
+
# Stores the last computed rescale factor so that external tools / scripts
|
| 297 |
+
# can read it without re-deriving it.
|
| 298 |
+
#
|
| 299 |
+
# Uses threading.local() so concurrent API workers each get their own slot.
|
| 300 |
+
# clear() must be called at the start of every _cfg_rescale() call so that
|
| 301 |
+
# get() correctly returns None when rescaling was skipped this step.
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
+
class CFGRescaleFactorSingleton:
|
| 305 |
+
"""Thread-local store for the most recently computed CFG rescale factor.
|
| 306 |
+
|
| 307 |
+
Lifecycle per denoising step:
|
| 308 |
+
1. _cfg_rescale() calls clear() at entry — value becomes None.
|
| 309 |
+
2. If rescaling is active, set() stores the computed factor.
|
| 310 |
+
3. External code calls get() after the step; receives the factor or None.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
_state = threading.local()
|
| 314 |
+
|
| 315 |
+
@classmethod
|
| 316 |
+
def set(cls, value: float) -> None:
|
| 317 |
+
cls._state.cfg_rescale_factor = value
|
| 318 |
+
|
| 319 |
+
@classmethod
|
| 320 |
+
def get(cls) -> Optional[float]:
|
| 321 |
+
return getattr(cls._state, 'cfg_rescale_factor', None)
|
| 322 |
+
|
| 323 |
+
@classmethod
|
| 324 |
+
def clear(cls) -> None:
|
| 325 |
+
cls._state.cfg_rescale_factor = None
|
neutral_prompt_patcheds/lib_neutral_prompt/hijacker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ModuleHijacker:
|
| 5 |
+
def __init__(self, module):
|
| 6 |
+
self.__module = module
|
| 7 |
+
self.__original_functions = dict()
|
| 8 |
+
|
| 9 |
+
def hijack(self, attribute):
|
| 10 |
+
if attribute not in self.__original_functions:
|
| 11 |
+
self.__original_functions[attribute] = getattr(self.__module, attribute)
|
| 12 |
+
|
| 13 |
+
def decorator(function):
|
| 14 |
+
setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
|
| 15 |
+
return function
|
| 16 |
+
|
| 17 |
+
return decorator
|
| 18 |
+
|
| 19 |
+
def reset_module(self):
|
| 20 |
+
for attribute, original_function in self.__original_functions.items():
|
| 21 |
+
setattr(self.__module, attribute, original_function)
|
| 22 |
+
|
| 23 |
+
self.__original_functions.clear()
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
|
| 27 |
+
if not hasattr(module, hijacker_attribute):
|
| 28 |
+
module_hijacker = ModuleHijacker(module)
|
| 29 |
+
setattr(module, hijacker_attribute, module_hijacker)
|
| 30 |
+
on_uninstall(lambda: delattr(module, hijacker_attribute))
|
| 31 |
+
on_uninstall(module_hijacker.reset_module)
|
| 32 |
+
return module_hijacker
|
| 33 |
+
else:
|
| 34 |
+
return getattr(module, hijacker_attribute)
|
neutral_prompt_patcheds/lib_neutral_prompt/matryoshka_utils.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pure-Python helpers for Matryoshka (nested prompt) construction, tree rendering,
|
| 3 |
+
and diagnostics.
|
| 4 |
+
|
| 5 |
+
Importable without gradio or A1111 modules — all functions here have no
|
| 6 |
+
side effects and depend only on neutral_prompt_parser and global_state.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Builder helpers
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
_KW_DEFAULTS: Dict[str, str] = {
|
| 21 |
+
'AND_PERP': 'AND_PERP',
|
| 22 |
+
'AND_SALT': 'AND_SALT[5]',
|
| 23 |
+
'AND_SALT_WIDE': 'AND_SALT_WIDE[1]',
|
| 24 |
+
'AND_SALT_BLOB': 'AND_SALT_BLOB[5]',
|
| 25 |
+
'AND_TOPK': 'AND_TOPK[0.05]',
|
| 26 |
+
'AND_ALIGN': 'AND_ALIGN[4,8]',
|
| 27 |
+
'AND_MASK_ALIGN': 'AND_MASK_ALIGN[4,8]',
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
BUILDER_STRATEGIES = sorted(_KW_DEFAULTS.keys())
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build_child_block(strategy: str, text: str, weight: float,
|
| 34 |
+
affine: str = '', nested: str = '') -> str:
|
| 35 |
+
"""
|
| 36 |
+
Build a single AND_* child block, optionally with an affine prefix
|
| 37 |
+
and an inner nested block.
|
| 38 |
+
|
| 39 |
+
Returns the child line(s) as a string (no trailing newline).
|
| 40 |
+
Returns an empty string if both `text` and `nested` are empty/whitespace —
|
| 41 |
+
callers should check for this and skip the block.
|
| 42 |
+
|
| 43 |
+
Examples::
|
| 44 |
+
|
| 45 |
+
build_child_block('AND_SALT', 'texture', 0.8)
|
| 46 |
+
# → 'AND_SALT[5] texture :0.8'
|
| 47 |
+
|
| 48 |
+
build_child_block('AND_TOPK', 'highlights', 0.5,
|
| 49 |
+
nested='AND_PERP blur :0.4')
|
| 50 |
+
# → 'AND_TOPK[0.05] [\\n highlights :0.5\\n AND_PERP blur :0.4\\n] :0.5'
|
| 51 |
+
|
| 52 |
+
build_child_block('AND_SALT', '', 0.8)
|
| 53 |
+
# → '' (empty — nothing to insert)
|
| 54 |
+
"""
|
| 55 |
+
kw = _KW_DEFAULTS.get(strategy, strategy)
|
| 56 |
+
prefix = f'{affine.strip()} ' if affine.strip() else ''
|
| 57 |
+
w_str = f':{weight}'
|
| 58 |
+
text = text.strip()
|
| 59 |
+
nested = nested.strip()
|
| 60 |
+
|
| 61 |
+
# Nothing to add — signal the caller to skip this block entirely.
|
| 62 |
+
if not text and not nested:
|
| 63 |
+
return ''
|
| 64 |
+
|
| 65 |
+
if nested:
|
| 66 |
+
inner_lines: List[str] = []
|
| 67 |
+
# Only include the text line when text is non-empty
|
| 68 |
+
if text:
|
| 69 |
+
inner_lines.append(f' {text} {w_str}')
|
| 70 |
+
for ln in nested.splitlines():
|
| 71 |
+
inner_lines.append(f' {ln}')
|
| 72 |
+
body = '\n'.join(inner_lines)
|
| 73 |
+
return f'{prefix}{kw} [\n{body}\n] {w_str}'
|
| 74 |
+
else:
|
| 75 |
+
return f'{prefix}{kw} {text} {w_str}'
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_nested_prompt(base: str, children: List[Dict[str, Any]]) -> str:
|
| 79 |
+
"""
|
| 80 |
+
Build a complete prompt from a base string and a list of child dicts.
|
| 81 |
+
|
| 82 |
+
Each child dict may contain:
|
| 83 |
+
strategy (str), text (str), weight (float),
|
| 84 |
+
affine (str, optional), nested (str, optional)
|
| 85 |
+
|
| 86 |
+
Children whose `text` and `nested` are both empty are silently skipped.
|
| 87 |
+
Returns the full prompt as a single string.
|
| 88 |
+
"""
|
| 89 |
+
parts = [base.strip()] if base.strip() else []
|
| 90 |
+
for ch in children:
|
| 91 |
+
block = build_child_block(
|
| 92 |
+
strategy=ch.get('strategy', 'AND_PERP'),
|
| 93 |
+
text=ch.get('text', ''),
|
| 94 |
+
weight=float(ch.get('weight', 1.0)),
|
| 95 |
+
affine=ch.get('affine', ''),
|
| 96 |
+
nested=ch.get('nested', ''),
|
| 97 |
+
)
|
| 98 |
+
if block: # empty string → skip
|
| 99 |
+
parts.append(block)
|
| 100 |
+
return '\n'.join(parts)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# Tree renderer
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
def render_prompt_node(
|
| 108 |
+
node: neutral_prompt_parser.PromptExpr,
|
| 109 |
+
depth: int = 0,
|
| 110 |
+
prefix: str = '',
|
| 111 |
+
is_last: bool = True,
|
| 112 |
+
) -> List[str]:
|
| 113 |
+
"""
|
| 114 |
+
Recursively render a single AST node as tree lines using box-drawing chars.
|
| 115 |
+
|
| 116 |
+
Returns a list of strings (one per line), without trailing newline.
|
| 117 |
+
"""
|
| 118 |
+
connector = '└─ ' if is_last else '├─ '
|
| 119 |
+
child_pfx = prefix + (' ' if is_last else '│ ')
|
| 120 |
+
|
| 121 |
+
strat = node.conciliation.name if node.conciliation else 'BASE'
|
| 122 |
+
w = f'w={node.weight:.2f}'
|
| 123 |
+
params_str = ''
|
| 124 |
+
if node.conciliation_params:
|
| 125 |
+
params_str = ' [' + ', '.join(f'{k}={v}' for k, v in node.conciliation_params.items()) + ']'
|
| 126 |
+
aff_str = ' ✦affine' if node.local_transform is not None else ''
|
| 127 |
+
|
| 128 |
+
if hasattr(node, 'prompt'):
|
| 129 |
+
txt = node.prompt.strip()
|
| 130 |
+
short = (txt[:55] + '…') if len(txt) > 55 else txt
|
| 131 |
+
label = f'{connector}[{strat}] {w}{params_str}{aff_str} "{short}"'
|
| 132 |
+
else:
|
| 133 |
+
n = len(node.children) if hasattr(node, 'children') else 0
|
| 134 |
+
label = f'{connector}[{strat}] {w}{params_str}{aff_str} ({n} children)'
|
| 135 |
+
|
| 136 |
+
lines: List[str] = [prefix + label]
|
| 137 |
+
|
| 138 |
+
if hasattr(node, 'children') and node.children:
|
| 139 |
+
for i, child in enumerate(node.children):
|
| 140 |
+
last = (i == len(node.children) - 1)
|
| 141 |
+
lines.extend(render_prompt_node(child, depth + 1, child_pfx, last))
|
| 142 |
+
|
| 143 |
+
return lines
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def render_prompt_tree(prompt_str: str) -> str:
|
| 147 |
+
"""Full tree render. Returns a multi-line string ready for display."""
|
| 148 |
+
if not prompt_str.strip():
|
| 149 |
+
return '(empty prompt — nothing to show)'
|
| 150 |
+
|
| 151 |
+
try:
|
| 152 |
+
root = neutral_prompt_parser.parse_root(prompt_str)
|
| 153 |
+
except Exception as exc:
|
| 154 |
+
return f'⚠ Parse error: {exc}'
|
| 155 |
+
|
| 156 |
+
n = len(root.children) if hasattr(root, 'children') else 0
|
| 157 |
+
lines = [f'ROOT ({n} top-level segments)']
|
| 158 |
+
if hasattr(root, 'children') and root.children:
|
| 159 |
+
for i, child in enumerate(root.children):
|
| 160 |
+
last = (i == len(root.children) - 1)
|
| 161 |
+
lines.extend(render_prompt_node(child, prefix='', is_last=last))
|
| 162 |
+
return '\n'.join(lines)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# Diagnostics
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
_BRACKET_KW_NAMES = frozenset({
|
| 170 |
+
'SALIENCE_MASK', 'SALIENCE_MASK_WIDE', 'SALIENCE_MASK_BLOB',
|
| 171 |
+
'SEMANTIC_GUIDANCE', 'ALIGNMENT_BLEND_CUSTOM', 'ALIGNMENT_MASK_BLEND_CUSTOM',
|
| 172 |
+
})
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def collect_prompt_diagnostics(prompt_str: str) -> Dict[str, Any]:
|
| 176 |
+
"""
|
| 177 |
+
Parse prompt_str and collect diagnostic metadata.
|
| 178 |
+
|
| 179 |
+
Returns a dict with keys:
|
| 180 |
+
parse_ok, parse_error, max_depth, n_segments, n_leaf, n_composite,
|
| 181 |
+
n_affine, n_invalid_base, strategies, ignored_hint,
|
| 182 |
+
protection_would_fire, protection_mode
|
| 183 |
+
"""
|
| 184 |
+
result: Dict[str, Any] = {
|
| 185 |
+
'parse_ok': False,
|
| 186 |
+
'parse_error': None,
|
| 187 |
+
'max_depth': 0,
|
| 188 |
+
'n_segments': 0,
|
| 189 |
+
'n_leaf': 0,
|
| 190 |
+
'n_composite': 0,
|
| 191 |
+
'n_affine': 0,
|
| 192 |
+
'n_invalid_base': 0,
|
| 193 |
+
'strategies': [],
|
| 194 |
+
'ignored_hint': False,
|
| 195 |
+
'protection_would_fire': False,
|
| 196 |
+
'protection_mode': global_state.normalize_protection_mode(global_state.protection_mode),
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
if not prompt_str.strip():
|
| 200 |
+
return result
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
root = neutral_prompt_parser.parse_root(prompt_str)
|
| 204 |
+
except Exception as exc:
|
| 205 |
+
result['parse_error'] = str(exc)
|
| 206 |
+
return result
|
| 207 |
+
|
| 208 |
+
result['parse_ok'] = True
|
| 209 |
+
strategies_seen: List[str] = []
|
| 210 |
+
|
| 211 |
+
def _walk(node: Any, depth: int) -> None:
|
| 212 |
+
result['max_depth'] = max(result['max_depth'], depth)
|
| 213 |
+
result['n_segments'] += 1
|
| 214 |
+
|
| 215 |
+
if node.local_transform is not None:
|
| 216 |
+
result['n_affine'] += 1
|
| 217 |
+
|
| 218 |
+
strat = node.conciliation.name if node.conciliation else 'BASE'
|
| 219 |
+
if strat not in strategies_seen:
|
| 220 |
+
strategies_seen.append(strat)
|
| 221 |
+
|
| 222 |
+
if (node.conciliation is not None
|
| 223 |
+
and node.conciliation.name in _BRACKET_KW_NAMES
|
| 224 |
+
and not node.conciliation_params):
|
| 225 |
+
result['ignored_hint'] = True
|
| 226 |
+
|
| 227 |
+
if hasattr(node, 'prompt'):
|
| 228 |
+
result['n_leaf'] += 1
|
| 229 |
+
else:
|
| 230 |
+
result['n_composite'] += 1
|
| 231 |
+
if hasattr(node, 'children'):
|
| 232 |
+
for child in node.children:
|
| 233 |
+
_walk(child, depth + 1)
|
| 234 |
+
|
| 235 |
+
_walk(root, 0)
|
| 236 |
+
result['strategies'] = strategies_seen
|
| 237 |
+
|
| 238 |
+
if isinstance(root, neutral_prompt_parser.CompositePrompt):
|
| 239 |
+
has_base = any(c.conciliation is None for c in root.children)
|
| 240 |
+
mode = result['protection_mode']
|
| 241 |
+
if mode in ('auto', 'strict') and not has_base:
|
| 242 |
+
result['protection_would_fire'] = True
|
| 243 |
+
result['n_invalid_base'] = 1
|
| 244 |
+
|
| 245 |
+
return result
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def render_diagnostics(diag: Dict[str, Any]) -> str:
|
| 249 |
+
"""Turn a diagnostics dict into a readable report string."""
|
| 250 |
+
if not diag['parse_ok']:
|
| 251 |
+
err = diag.get('parse_error') or 'unknown'
|
| 252 |
+
return f'⚠ Parse error: {err}'
|
| 253 |
+
|
| 254 |
+
lines = ['── Diagnostics ──────────────────────────']
|
| 255 |
+
lines.append(f'Segments total : {diag["n_segments"]} '
|
| 256 |
+
f'(leaf={diag["n_leaf"]}, composite={diag["n_composite"]})')
|
| 257 |
+
lines.append(f'Max nesting depth : {diag["max_depth"]}')
|
| 258 |
+
lines.append(f'Affine transforms : {diag["n_affine"]}')
|
| 259 |
+
lines.append(f'Strategies used : {", ".join(diag["strategies"]) or "none"}')
|
| 260 |
+
|
| 261 |
+
if diag['ignored_hint']:
|
| 262 |
+
lines.append('⚠ One or more bracket params were invalid and ignored. Using defaults.')
|
| 263 |
+
|
| 264 |
+
lines.append('')
|
| 265 |
+
mode = diag['protection_mode']
|
| 266 |
+
lines.append(f'Protection mode : {mode}')
|
| 267 |
+
if mode == 'off':
|
| 268 |
+
lines.append(' Guard disabled — no fallback will occur')
|
| 269 |
+
elif diag['protection_would_fire']:
|
| 270 |
+
lines.append(' ⚠ WOULD FIRE — no plain base segment at the top level')
|
| 271 |
+
lines.append(' Fix: add a plain base prompt before all AND_* segments')
|
| 272 |
+
else:
|
| 273 |
+
lines.append(' OK — base segment present at top level')
|
| 274 |
+
if mode == 'strict':
|
| 275 |
+
thr = global_state.strict_threshold
|
| 276 |
+
lines.append(f' Strict ratio check ({thr:.2f}) runs at generation time')
|
| 277 |
+
|
| 278 |
+
return '\n'.join(lines)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def render_full_explain(prompt_str: str) -> str:
|
| 282 |
+
"""
|
| 283 |
+
Combine tree + diagnostics + effect summaries into a single panel output.
|
| 284 |
+
|
| 285 |
+
For an empty prompt returns only the empty-state message (no spurious
|
| 286 |
+
parse-error noise).
|
| 287 |
+
"""
|
| 288 |
+
if not prompt_str.strip():
|
| 289 |
+
return '(empty prompt — nothing to show)'
|
| 290 |
+
|
| 291 |
+
tree = render_prompt_tree(prompt_str)
|
| 292 |
+
diag = collect_prompt_diagnostics(prompt_str)
|
| 293 |
+
diagr = render_diagnostics(diag)
|
| 294 |
+
summary = _render_effect_summary(diag)
|
| 295 |
+
|
| 296 |
+
parts = [tree, '', diagr]
|
| 297 |
+
if summary:
|
| 298 |
+
parts += ['', summary]
|
| 299 |
+
return '\n'.join(parts)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
# Effect summaries
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
|
| 306 |
+
_EFFECT_DESCRIPTIONS: Dict[str, str] = {
|
| 307 |
+
'BASE': 'base prompt contribution — drives the overall image',
|
| 308 |
+
'PERPENDICULAR': 'perpendicular projection — reduces contradicting features',
|
| 309 |
+
'SALIENCE_MASK': 'sharp saliency mask — targets the most salient latent pixels',
|
| 310 |
+
'SALIENCE_MASK_WIDE': 'broad saliency mask — covers a wide region (~55% of pixels)',
|
| 311 |
+
'SALIENCE_MASK_BLOB': 'blob saliency — grows seed region to a smooth connected area',
|
| 312 |
+
'SEMANTIC_GUIDANCE': 'semantic top-k — applies sparse targeted changes to the strongest elements',
|
| 313 |
+
'ALIGNMENT_BLEND_CUSTOM': 'alignment blend — soft injection that preserves global structure',
|
| 314 |
+
'ALIGNMENT_MASK_BLEND_CUSTOM': 'alignment mask — binary-mask variant of alignment blend (stricter)',
|
| 315 |
+
}
|
| 316 |
+
# Catch legacy fixed-suffix ALIGNMENT_BLEND_D_S patterns
|
| 317 |
+
_ALIGNMENT_PREFIX = 'ALIGNMENT_BLEND_'
|
| 318 |
+
_ALIGNMENT_MASK_PREFIX = 'ALIGNMENT_MASK_BLEND_'
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _effect_for_strategy(name: str) -> str:
|
| 322 |
+
if name in _EFFECT_DESCRIPTIONS:
|
| 323 |
+
return _EFFECT_DESCRIPTIONS[name]
|
| 324 |
+
if name.startswith(_ALIGNMENT_MASK_PREFIX):
|
| 325 |
+
return _EFFECT_DESCRIPTIONS['ALIGNMENT_MASK_BLEND_CUSTOM']
|
| 326 |
+
if name.startswith(_ALIGNMENT_PREFIX):
|
| 327 |
+
return _EFFECT_DESCRIPTIONS['ALIGNMENT_BLEND_CUSTOM']
|
| 328 |
+
return f'strategy: {name}'
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _render_effect_summary(diag: Dict[str, Any]) -> str:
|
| 332 |
+
"""
|
| 333 |
+
Build a one-line-per-strategy effect summary from diagnostics data.
|
| 334 |
+
Returns an empty string if the prompt failed to parse.
|
| 335 |
+
"""
|
| 336 |
+
if not diag.get('parse_ok'):
|
| 337 |
+
return ''
|
| 338 |
+
strategies = diag.get('strategies', [])
|
| 339 |
+
if not strategies:
|
| 340 |
+
return ''
|
| 341 |
+
lines = ['── Effect summary ───────────────────────']
|
| 342 |
+
for s in strategies:
|
| 343 |
+
lines.append(f' {s:32s}→ {_effect_for_strategy(s)}')
|
| 344 |
+
lines.append('')
|
| 345 |
+
lines.append(' Note: protection verdict above is a structural preview.')
|
| 346 |
+
lines.append(' The strict ratio check and batch diagnostics run at generation time.')
|
| 347 |
+
return '\n'.join(lines)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ---------------------------------------------------------------------------
|
| 351 |
+
# Matryoshka templates
|
| 352 |
+
# ---------------------------------------------------------------------------
|
| 353 |
+
|
| 354 |
+
MATRYOSHKA_TEMPLATES: Dict[str, Dict[str, str]] = {
|
| 355 |
+
'Nested local detail': {
|
| 356 |
+
'description': (
|
| 357 |
+
'Add fine detail inside a saliency region. '
|
| 358 |
+
'SALT focuses on salient pixels; nested TOPK applies sparse corrections within that region.'
|
| 359 |
+
),
|
| 360 |
+
'prompt': (
|
| 361 |
+
'base subject, environment\n'
|
| 362 |
+
'AND_SALT[5] [\n'
|
| 363 |
+
' subject texture :0.8\n'
|
| 364 |
+
' AND_TOPK[0.05] fine detail highlights :0.5\n'
|
| 365 |
+
'] :0.8'
|
| 366 |
+
),
|
| 367 |
+
},
|
| 368 |
+
'Structure preserve + style inject': {
|
| 369 |
+
'description': (
|
| 370 |
+
'Keep global composition while injecting a style locally. '
|
| 371 |
+
'ALIGN preserves structure; nested PERP reduces contradiction.'
|
| 372 |
+
),
|
| 373 |
+
'prompt': (
|
| 374 |
+
'base composition, environment\n'
|
| 375 |
+
'AND_ALIGN[4,8] [\n'
|
| 376 |
+
' style name :0.8\n'
|
| 377 |
+
' AND_PERP style contradiction :0.5\n'
|
| 378 |
+
'] :0.7'
|
| 379 |
+
),
|
| 380 |
+
},
|
| 381 |
+
'Sparse detail inside broad region': {
|
| 382 |
+
'description': (
|
| 383 |
+
'SALT_WIDE covers a broad region; nested SALT sharpens a smaller focal point inside it.'
|
| 384 |
+
),
|
| 385 |
+
'prompt': (
|
| 386 |
+
'base subject\n'
|
| 387 |
+
'AND_SALT_WIDE[1] [\n'
|
| 388 |
+
' surface texture :0.7\n'
|
| 389 |
+
' AND_SALT[7] focal detail :0.6\n'
|
| 390 |
+
'] :0.8'
|
| 391 |
+
),
|
| 392 |
+
},
|
| 393 |
+
'Perpendicular correction inside texture block': {
|
| 394 |
+
'description': (
|
| 395 |
+
'Inject texture while suppressing contradicting composition hints '
|
| 396 |
+
'via nested perpendicular projection.'
|
| 397 |
+
),
|
| 398 |
+
'prompt': (
|
| 399 |
+
'base portrait, detailed\n'
|
| 400 |
+
'AND_SALT[5] [\n'
|
| 401 |
+
' skin texture, pores :0.8\n'
|
| 402 |
+
' AND_PERP smooth featureless face :0.4\n'
|
| 403 |
+
'] :0.7'
|
| 404 |
+
),
|
| 405 |
+
},
|
| 406 |
+
'Nested ALIGN + SALT': {
|
| 407 |
+
'description': (
|
| 408 |
+
'Two-pass spatial injection: ALIGN blends a structural concept, '
|
| 409 |
+
'SALT sharpens a second concept only in salient regions.'
|
| 410 |
+
),
|
| 411 |
+
'prompt': (
|
| 412 |
+
'base subject\n'
|
| 413 |
+
'AND_ALIGN[4,8] background concept :0.6\n'
|
| 414 |
+
'AND_SALT[5] [\n'
|
| 415 |
+
' foreground detail :0.8\n'
|
| 416 |
+
' AND_TOPK[0.1] specular highlights :0.4\n'
|
| 417 |
+
'] :0.7'
|
| 418 |
+
),
|
| 419 |
+
},
|
| 420 |
+
'Mirrored composition': {
|
| 421 |
+
'description': (
|
| 422 |
+
'Apply the same concept mirrored, blended via PERP to add symmetric detail '
|
| 423 |
+
'without overwriting the composition.'
|
| 424 |
+
),
|
| 425 |
+
'prompt': (
|
| 426 |
+
'base subject\n'
|
| 427 |
+
'SCALE[-1,1] AND_PERP mirrored version of subject :0.5'
|
| 428 |
+
),
|
| 429 |
+
},
|
| 430 |
+
'Deep nested concept isolation': {
|
| 431 |
+
'description': (
|
| 432 |
+
'Three levels: outer SALT focuses area, middle ALIGN preserves sub-structure, '
|
| 433 |
+
'inner TOPK applies sparse micro-corrections.'
|
| 434 |
+
),
|
| 435 |
+
'prompt': (
|
| 436 |
+
'base subject\n'
|
| 437 |
+
'AND_SALT[5] [\n'
|
| 438 |
+
' region concept :0.8\n'
|
| 439 |
+
' AND_ALIGN[4,8] [\n'
|
| 440 |
+
' structure keeper :0.7\n'
|
| 441 |
+
' AND_TOPK[0.05] micro detail :0.4\n'
|
| 442 |
+
' ] :0.6\n'
|
| 443 |
+
'] :0.7'
|
| 444 |
+
),
|
| 445 |
+
},
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
# ---------------------------------------------------------------------------
|
| 450 |
+
# Affine snippet builder — re-exported from affine_utils (single source of truth)
|
| 451 |
+
# ---------------------------------------------------------------------------
|
| 452 |
+
from lib_neutral_prompt.affine_utils import (
|
| 453 |
+
AFFINE_TRANSFORMS as _AFFINE_TRANSFORMS,
|
| 454 |
+
AFFINE_PRESETS as _AFFINE_PRESETS,
|
| 455 |
+
build_affine_snippet as _build_affine_snippet,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# ---------------------------------------------------------------------------
|
| 460 |
+
# Builder v1.2 — list-based node model
|
| 461 |
+
# ---------------------------------------------------------------------------
|
| 462 |
+
|
| 463 |
+
import dataclasses
|
| 464 |
+
import uuid
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
@dataclasses.dataclass
|
| 468 |
+
class BuilderNode:
|
| 469 |
+
"""
|
| 470 |
+
A single AND_* node in the visual matryoshka builder tree.
|
| 471 |
+
|
| 472 |
+
Fields
|
| 473 |
+
------
|
| 474 |
+
node_id : unique identifier (auto-generated UUID4 short)
|
| 475 |
+
strategy : key into _KW_DEFAULTS (e.g. 'AND_SALT', 'AND_TOPK')
|
| 476 |
+
text : text content for this node
|
| 477 |
+
weight : contribution weight (0.0 … 2.0 typical)
|
| 478 |
+
affine : optional affine snippet prefix (e.g. 'ROTATE[0.125]')
|
| 479 |
+
children : list of BuilderNode — forms the nested (matryoshka) subtree
|
| 480 |
+
"""
|
| 481 |
+
node_id: str
|
| 482 |
+
strategy: str
|
| 483 |
+
text: str
|
| 484 |
+
weight: float
|
| 485 |
+
affine: str = dataclasses.field(default='')
|
| 486 |
+
children: List['BuilderNode'] = dataclasses.field(default_factory=list)
|
| 487 |
+
|
| 488 |
+
@classmethod
|
| 489 |
+
def make(cls, strategy: str = 'AND_SALT', text: str = '',
|
| 490 |
+
weight: float = 0.8, affine: str = '') -> 'BuilderNode':
|
| 491 |
+
"""Create a new BuilderNode with a fresh UUID."""
|
| 492 |
+
return cls(
|
| 493 |
+
node_id=uuid.uuid4().hex[:8],
|
| 494 |
+
strategy=strategy,
|
| 495 |
+
text=text,
|
| 496 |
+
weight=weight,
|
| 497 |
+
affine=affine,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
# ---------------------------------------------------------------------------
|
| 502 |
+
# Builder tree operations (pure — no gradio)
|
| 503 |
+
# ---------------------------------------------------------------------------
|
| 504 |
+
|
| 505 |
+
def builder_add_child(nodes: List[BuilderNode], parent_id: Optional[str] = None,
|
| 506 |
+
**kwargs) -> List[BuilderNode]:
|
| 507 |
+
"""
|
| 508 |
+
Add a new child BuilderNode.
|
| 509 |
+
|
| 510 |
+
If *parent_id* is None: append to the top-level list.
|
| 511 |
+
If *parent_id* matches a node id: append as nested child of that node.
|
| 512 |
+
Returns a *new* list (original is not mutated).
|
| 513 |
+
"""
|
| 514 |
+
new_node = BuilderNode.make(**kwargs)
|
| 515 |
+
if parent_id is None:
|
| 516 |
+
return list(nodes) + [new_node]
|
| 517 |
+
|
| 518 |
+
def _add(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 519 |
+
result = []
|
| 520 |
+
for n in ns:
|
| 521 |
+
if n.node_id == parent_id:
|
| 522 |
+
result.append(dataclasses.replace(n, children=list(n.children) + [new_node]))
|
| 523 |
+
else:
|
| 524 |
+
result.append(dataclasses.replace(n, children=_add(n.children)))
|
| 525 |
+
return result
|
| 526 |
+
|
| 527 |
+
return _add(nodes)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def builder_remove_node(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
|
| 531 |
+
"""Remove the node with *node_id* from anywhere in the tree."""
|
| 532 |
+
def _remove(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 533 |
+
return [
|
| 534 |
+
dataclasses.replace(n, children=_remove(n.children))
|
| 535 |
+
for n in ns if n.node_id != node_id
|
| 536 |
+
]
|
| 537 |
+
return _remove(nodes)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def builder_duplicate_node(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
|
| 541 |
+
"""
|
| 542 |
+
Insert a deep copy of *node_id* immediately after it in its parent list.
|
| 543 |
+
The duplicate gets a fresh node_id.
|
| 544 |
+
"""
|
| 545 |
+
def _fresh(n: BuilderNode) -> BuilderNode:
|
| 546 |
+
return dataclasses.replace(
|
| 547 |
+
n,
|
| 548 |
+
node_id=uuid.uuid4().hex[:8],
|
| 549 |
+
children=[_fresh(c) for c in n.children],
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
def _dup(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 553 |
+
result = []
|
| 554 |
+
for n in ns:
|
| 555 |
+
result.append(dataclasses.replace(n, children=_dup(n.children)))
|
| 556 |
+
if n.node_id == node_id:
|
| 557 |
+
result.append(_fresh(n))
|
| 558 |
+
return result
|
| 559 |
+
|
| 560 |
+
return _dup(nodes)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def builder_move_up(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
|
| 564 |
+
"""Swap *node_id* with its predecessor in its parent list."""
|
| 565 |
+
def _swap(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 566 |
+
out = []
|
| 567 |
+
for i, n in enumerate(ns):
|
| 568 |
+
if n.node_id == node_id and i > 0:
|
| 569 |
+
out[-1], swapped = dataclasses.replace(n, children=_swap(n.children)), out[-1]
|
| 570 |
+
out.append(swapped)
|
| 571 |
+
else:
|
| 572 |
+
out.append(dataclasses.replace(n, children=_swap(n.children)))
|
| 573 |
+
return out
|
| 574 |
+
return _swap(nodes)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def builder_move_down(nodes: List[BuilderNode], node_id: str) -> List[BuilderNode]:
|
| 578 |
+
"""Swap *node_id* with its successor in its parent list."""
|
| 579 |
+
def _swap(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 580 |
+
out = [dataclasses.replace(n, children=_swap(n.children)) for n in ns]
|
| 581 |
+
for i, n in enumerate(out):
|
| 582 |
+
if n.node_id == node_id and i < len(out) - 1:
|
| 583 |
+
out[i], out[i + 1] = out[i + 1], out[i]
|
| 584 |
+
break
|
| 585 |
+
return out
|
| 586 |
+
return _swap(nodes)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def builder_update_node(nodes: List[BuilderNode], node_id: str,
|
| 590 |
+
**kwargs) -> List[BuilderNode]:
|
| 591 |
+
"""
|
| 592 |
+
Update fields of *node_id* anywhere in the tree.
|
| 593 |
+
Only keys present in *kwargs* are changed.
|
| 594 |
+
"""
|
| 595 |
+
def _update(ns: List[BuilderNode]) -> List[BuilderNode]:
|
| 596 |
+
return [
|
| 597 |
+
dataclasses.replace(
|
| 598 |
+
n,
|
| 599 |
+
children=_update(n.children),
|
| 600 |
+
**{k: v for k, v in kwargs.items() if k != 'children'},
|
| 601 |
+
) if n.node_id == node_id
|
| 602 |
+
else dataclasses.replace(n, children=_update(n.children))
|
| 603 |
+
for n in ns
|
| 604 |
+
]
|
| 605 |
+
return _update(nodes)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def serialize_builder_tree(base: str, nodes: List[BuilderNode], indent: int = 0) -> str:
|
| 609 |
+
"""
|
| 610 |
+
Convert a base string + list of BuilderNodes into a prompt string.
|
| 611 |
+
Nested children are rendered as composite bracket blocks.
|
| 612 |
+
"""
|
| 613 |
+
lines: List[str] = []
|
| 614 |
+
if indent == 0 and base.strip():
|
| 615 |
+
lines.append(base.strip())
|
| 616 |
+
|
| 617 |
+
for node in nodes:
|
| 618 |
+
kw = _KW_DEFAULTS.get(node.strategy, node.strategy)
|
| 619 |
+
prefix = f'{node.affine.strip()} ' if node.affine.strip() else ''
|
| 620 |
+
w_str = f':{node.weight}'
|
| 621 |
+
pad = ' ' * indent
|
| 622 |
+
|
| 623 |
+
if node.children:
|
| 624 |
+
# Composite bracket block
|
| 625 |
+
inner: List[str] = []
|
| 626 |
+
if node.text.strip():
|
| 627 |
+
inner.append(f' {pad}{node.text.strip()} {w_str}')
|
| 628 |
+
for sub in node.children:
|
| 629 |
+
for sub_ln in serialize_builder_tree('', [sub], indent + 1).splitlines():
|
| 630 |
+
if sub_ln.strip():
|
| 631 |
+
inner.append(f' {pad}{sub_ln}')
|
| 632 |
+
body = '\n'.join(inner)
|
| 633 |
+
lines.append(f'{pad}{prefix}{kw} [\n{body}\n{pad}] {w_str}')
|
| 634 |
+
else:
|
| 635 |
+
if node.text.strip():
|
| 636 |
+
lines.append(f'{pad}{prefix}{kw} {node.text.strip()} {w_str}')
|
| 637 |
+
# Node with no text and no children: skip (same policy as build_child_block)
|
| 638 |
+
|
| 639 |
+
return '\n'.join(lines)
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def builder_tree_to_prompt(base: str, nodes: List[BuilderNode]) -> str:
|
| 643 |
+
"""Top-level entry point: serialise the full builder tree to a prompt string."""
|
| 644 |
+
return serialize_builder_tree(base, nodes)
|
neutral_prompt_patcheds/lib_neutral_prompt/neutral_prompt_parser.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Neutral Prompt Parser
|
| 3 |
+
Combines:
|
| 4 |
+
- Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
|
| 5 |
+
- Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
|
| 6 |
+
- Fixed AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
|
| 7 |
+
- Bracket-syntax AND_ALIGN[D,S] / AND_MASK_ALIGN[D,S] (NEW — any valid pair)
|
| 8 |
+
- Configurable k for salience: AND_SALT[k] / AND_SALT_WIDE[k] / AND_SALT_BLOB[k] (NEW)
|
| 9 |
+
|
| 10 |
+
New syntax examples:
|
| 11 |
+
AND_SALT[5] concept :0.8 k=5.0 (more coverage than default 20)
|
| 12 |
+
AND_SALT_WIDE[3] concept :0.6 explicit k override
|
| 13 |
+
AND_SALT_BLOB[8] concept :1.0 blob seed sharpness
|
| 14 |
+
AND_ALIGN[4,8] style :0.5 detail=4, structure=8 (any valid pair)
|
| 15 |
+
AND_MASK_ALIGN[6,12] style :0.5 binary-mask variant
|
| 16 |
+
|
| 17 |
+
AND_ALIGN_4_8 still works — backward-compatible.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import abc
|
| 21 |
+
import dataclasses
|
| 22 |
+
import math
|
| 23 |
+
import re
|
| 24 |
+
from enum import Enum
|
| 25 |
+
from itertools import product
|
| 26 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Keyword registry
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
_BASE_KEYWORD_MAP = {
|
| 36 |
+
'AND_PERP': 'PERPENDICULAR',
|
| 37 |
+
'AND_SALT': 'SALIENCE_MASK',
|
| 38 |
+
'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
|
| 39 |
+
'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
|
| 40 |
+
'AND_TOPK': 'SEMANTIC_GUIDANCE',
|
| 41 |
+
# Bracket-syntax custom alignment (any D,S pair via AND_ALIGN[D,S])
|
| 42 |
+
'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
|
| 43 |
+
'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
_ALIGN_KEYWORD_MAP = {
|
| 47 |
+
f'AND_ALIGN_{i}_{j}': f'ALIGNMENT_BLEND_{i}_{j}'
|
| 48 |
+
for i, j in product(range(2, 33), repeat=2) if i != j
|
| 49 |
+
}
|
| 50 |
+
_MASK_ALIGN_KEYWORD_MAP = {
|
| 51 |
+
f'AND_MASK_ALIGN_{i}_{j}': f'ALIGNMENT_MASK_BLEND_{i}_{j}'
|
| 52 |
+
for i, j in product(range(2, 33), repeat=2) if i != j
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
keyword_mapping = _BASE_KEYWORD_MAP | _ALIGN_KEYWORD_MAP | _MASK_ALIGN_KEYWORD_MAP
|
| 56 |
+
|
| 57 |
+
PromptKeyword = Enum('PromptKeyword', {'AND': 'AND', **{k: k for k in keyword_mapping}})
|
| 58 |
+
ConciliationStrategy = Enum('ConciliationStrategy', {v: k for k, v in keyword_mapping.items()})
|
| 59 |
+
|
| 60 |
+
prompt_keywords = [e.value for e in PromptKeyword]
|
| 61 |
+
conciliation_strategies = [e.value for e in ConciliationStrategy]
|
| 62 |
+
|
| 63 |
+
_prompt_keywords_set = frozenset(prompt_keywords)
|
| 64 |
+
_conciliation_strategies_set = frozenset(conciliation_strategies)
|
| 65 |
+
|
| 66 |
+
_SALT_KEYWORDS = frozenset({'AND_SALT', 'AND_SALT_WIDE', 'AND_SALT_BLOB'})
|
| 67 |
+
_ALIGN_BRACKET_KEYWORDS = frozenset({'AND_ALIGN', 'AND_MASK_ALIGN'})
|
| 68 |
+
_TOPK_KEYWORDS = frozenset({'AND_TOPK'})
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Affine transform definitions (imported from affine_utils — single source of truth)
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
from lib_neutral_prompt.affine_utils import (
|
| 76 |
+
affine_transforms,
|
| 77 |
+
_affine_warn,
|
| 78 |
+
_make_rotate,
|
| 79 |
+
_make_slide,
|
| 80 |
+
_make_scale,
|
| 81 |
+
_make_shear,
|
| 82 |
+
_apply_affine,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# AST node types
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
@dataclasses.dataclass
|
| 92 |
+
class PromptExpr(abc.ABC):
|
| 93 |
+
weight: float
|
| 94 |
+
conciliation: Optional[ConciliationStrategy]
|
| 95 |
+
local_transform: Optional[torch.Tensor]
|
| 96 |
+
|
| 97 |
+
@abc.abstractmethod
|
| 98 |
+
def accept(self, visitor, *args, **kwargs) -> Any:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclasses.dataclass
|
| 103 |
+
class LeafPrompt(PromptExpr):
|
| 104 |
+
prompt: str
|
| 105 |
+
# conciliation_params MUST be last — it has a default, parent fields don't.
|
| 106 |
+
conciliation_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
| 107 |
+
|
| 108 |
+
def accept(self, visitor, *args, **kwargs):
|
| 109 |
+
return visitor.visit_leaf_prompt(self, *args, **kwargs)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclasses.dataclass
|
| 113 |
+
class CompositePrompt(PromptExpr):
|
| 114 |
+
children: List[PromptExpr]
|
| 115 |
+
conciliation_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
| 116 |
+
|
| 117 |
+
def accept(self, visitor, *args, **kwargs):
|
| 118 |
+
return visitor.visit_composite_prompt(self, *args, **kwargs)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FlatSizeVisitor:
|
| 122 |
+
def visit_leaf_prompt(self, that: 'LeafPrompt') -> int:
|
| 123 |
+
return 1
|
| 124 |
+
|
| 125 |
+
def visit_composite_prompt(self, that: 'CompositePrompt') -> int:
|
| 126 |
+
return sum(child.accept(self) for child in that.children) if that.children else 0
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# Parser
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
def parse_root(string: str) -> CompositePrompt:
|
| 134 |
+
tokens = tokenize(string)
|
| 135 |
+
prompts = parse_prompts(tokens)
|
| 136 |
+
return CompositePrompt(1.0, None, None, prompts)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
|
| 140 |
+
prompts = [parse_prompt(tokens, first=True, nested=nested)]
|
| 141 |
+
while tokens:
|
| 142 |
+
if nested and tokens[0] == ']':
|
| 143 |
+
break
|
| 144 |
+
prompts.append(parse_prompt(tokens, first=False, nested=nested))
|
| 145 |
+
return prompts
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _compose_affine(
|
| 149 |
+
a: Optional[torch.Tensor],
|
| 150 |
+
b: Optional[torch.Tensor],
|
| 151 |
+
) -> Optional[torch.Tensor]:
|
| 152 |
+
if a is None:
|
| 153 |
+
return b
|
| 154 |
+
if b is None:
|
| 155 |
+
return a
|
| 156 |
+
a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 157 |
+
b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 158 |
+
return (b3 @ a3)[:-1]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
|
| 162 |
+
pos = 0
|
| 163 |
+
n = len(tokens)
|
| 164 |
+
if pos < n and not tokens[pos].strip():
|
| 165 |
+
pos += 1
|
| 166 |
+
if pos >= n or tokens[pos] not in affine_transforms:
|
| 167 |
+
return False
|
| 168 |
+
while pos < n and tokens[pos] in affine_transforms:
|
| 169 |
+
pos += 1
|
| 170 |
+
if pos >= n or tokens[pos] != '[':
|
| 171 |
+
return False
|
| 172 |
+
pos += 1
|
| 173 |
+
if pos < n and tokens[pos] != ']':
|
| 174 |
+
pos += 1
|
| 175 |
+
if pos >= n or tokens[pos] != ']':
|
| 176 |
+
return False
|
| 177 |
+
pos += 1
|
| 178 |
+
if pos < n and not tokens[pos].strip():
|
| 179 |
+
pos += 1
|
| 180 |
+
return pos < n and tokens[pos] in _prompt_keywords_set
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _parse_keyword_params(tokens: List[str], keyword: str) -> Dict[str, Any]:
|
| 184 |
+
"""
|
| 185 |
+
Consume optional bracket params immediately after a keyword token.
|
| 186 |
+
|
| 187 |
+
Safe: only consumes [inner] when inner is purely numeric (digits, dots,
|
| 188 |
+
commas, spaces), so composite-prompt brackets like AND_ALIGN[child AND other]
|
| 189 |
+
are never eaten.
|
| 190 |
+
|
| 191 |
+
Supported forms (spaces around comma are tolerated):
|
| 192 |
+
AND_SALT[5] -> {'k': 5.0}
|
| 193 |
+
AND_SALT[3, 5] -- invalid (too many parts) → warning
|
| 194 |
+
AND_ALIGN[4,8] -> {'d': 4, 's': 8}
|
| 195 |
+
AND_ALIGN[4, 8] -> {'d': 4, 's': 8} (spaces OK)
|
| 196 |
+
AND_MASK_ALIGN[6,12] -> {'d': 6, 's': 12}
|
| 197 |
+
|
| 198 |
+
On invalid but numeric content: consumes the brackets, emits a warning,
|
| 199 |
+
returns {} so the text is not silently swallowed as plain prompt text.
|
| 200 |
+
"""
|
| 201 |
+
if (keyword not in _SALT_KEYWORDS and keyword not in _ALIGN_BRACKET_KEYWORDS
|
| 202 |
+
and keyword not in _TOPK_KEYWORDS):
|
| 203 |
+
return {}
|
| 204 |
+
if len(tokens) < 3 or tokens[0] != '[' or tokens[2] != ']':
|
| 205 |
+
return {}
|
| 206 |
+
|
| 207 |
+
inner = tokens[1].strip()
|
| 208 |
+
# Allow digits, dots, commas, and spaces — but nothing alphabetic.
|
| 209 |
+
# This keeps composite brackets like [concept AND style] safely un-consumed.
|
| 210 |
+
if not re.match(r'^[\d.,\s]+$', inner):
|
| 211 |
+
return {}
|
| 212 |
+
|
| 213 |
+
# Consume the brackets now — we own this content.
|
| 214 |
+
tokens.pop(0); tokens.pop(0); tokens.pop(0)
|
| 215 |
+
|
| 216 |
+
try:
|
| 217 |
+
parts = [p.strip() for p in inner.split(',') if p.strip()]
|
| 218 |
+
|
| 219 |
+
if keyword in _SALT_KEYWORDS:
|
| 220 |
+
if len(parts) != 1:
|
| 221 |
+
_param_warn(keyword, inner,
|
| 222 |
+
f"expected exactly one number (k), got {len(parts)} value(s)")
|
| 223 |
+
return {}
|
| 224 |
+
k = float(parts[0])
|
| 225 |
+
if k <= 0:
|
| 226 |
+
_param_warn(keyword, inner,
|
| 227 |
+
f"k must be > 0, got {k}")
|
| 228 |
+
return {}
|
| 229 |
+
return {'k': k}
|
| 230 |
+
|
| 231 |
+
elif keyword in _ALIGN_BRACKET_KEYWORDS:
|
| 232 |
+
if len(parts) != 2:
|
| 233 |
+
_param_warn(keyword, inner,
|
| 234 |
+
f"expected exactly two numbers (D, S), got {len(parts)} value(s)")
|
| 235 |
+
return {}
|
| 236 |
+
d, s = int(float(parts[0])), int(float(parts[1]))
|
| 237 |
+
if not (2 <= d <= 32):
|
| 238 |
+
_param_warn(keyword, inner,
|
| 239 |
+
f"D={d} is out of range [2, 32]")
|
| 240 |
+
return {}
|
| 241 |
+
if not (2 <= s <= 32):
|
| 242 |
+
_param_warn(keyword, inner,
|
| 243 |
+
f"S={s} is out of range [2, 32]")
|
| 244 |
+
return {}
|
| 245 |
+
if d == s:
|
| 246 |
+
_param_warn(keyword, inner,
|
| 247 |
+
f"D and S must differ (both are {d}); "
|
| 248 |
+
"no structural distinction possible")
|
| 249 |
+
return {}
|
| 250 |
+
return {'d': d, 's': s}
|
| 251 |
+
|
| 252 |
+
elif keyword in _TOPK_KEYWORDS:
|
| 253 |
+
if len(parts) != 1:
|
| 254 |
+
_param_warn(keyword, inner,
|
| 255 |
+
f"expected exactly one number (threshold), got {len(parts)} value(s)")
|
| 256 |
+
return {}
|
| 257 |
+
threshold = float(parts[0])
|
| 258 |
+
if not (0.0 < threshold <= 1.0):
|
| 259 |
+
_param_warn(keyword, inner,
|
| 260 |
+
f"threshold must be in (0, 1], got {threshold}")
|
| 261 |
+
return {}
|
| 262 |
+
return {'threshold': threshold}
|
| 263 |
+
|
| 264 |
+
except (ValueError, IndexError) as exc:
|
| 265 |
+
_param_warn(keyword, inner, f"could not parse numbers: {exc}")
|
| 266 |
+
|
| 267 |
+
return {}
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _param_warn(keyword: str, raw: str, reason: str) -> None:
|
| 271 |
+
"""
|
| 272 |
+
Print a one-line parse warning to stderr.
|
| 273 |
+
|
| 274 |
+
Controlled by global_state.verbose so it doesn't spam quiet setups.
|
| 275 |
+
Imported lazily to avoid a circular import at module level.
|
| 276 |
+
"""
|
| 277 |
+
try:
|
| 278 |
+
from lib_neutral_prompt import global_state as _gs
|
| 279 |
+
if not _gs.verbose:
|
| 280 |
+
return
|
| 281 |
+
except ImportError:
|
| 282 |
+
pass # if global_state is not yet available, always print
|
| 283 |
+
import sys
|
| 284 |
+
print(
|
| 285 |
+
f"[neutral-prompt] WARNING: invalid params for {keyword}[{raw}] — {reason}. "
|
| 286 |
+
"Ignoring bracket content; keyword will use its default behaviour.",
|
| 287 |
+
file=sys.stderr,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
|
| 292 |
+
leading_affine: Optional[torch.Tensor] = None
|
| 293 |
+
if _looks_like_leading_affine_prompt(tokens):
|
| 294 |
+
probe = tokens.copy()
|
| 295 |
+
leading_affine = _parse_affine_transform(probe)
|
| 296 |
+
tokens[:] = probe
|
| 297 |
+
|
| 298 |
+
if tokens and tokens[0] in _prompt_keywords_set:
|
| 299 |
+
prompt_type = tokens.pop(0)
|
| 300 |
+
else:
|
| 301 |
+
prompt_type = PromptKeyword.AND.value
|
| 302 |
+
|
| 303 |
+
conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None
|
| 304 |
+
|
| 305 |
+
# Parse trailing affine once before params (handles AND_PERP ROTATE[0.25] text)
|
| 306 |
+
# and once after (handles AND_SALT[5] ROTATE[0.25] text).
|
| 307 |
+
trailing_affine_pre = _parse_affine_transform(tokens)
|
| 308 |
+
|
| 309 |
+
# Consume bracket params BEFORE composite-bracket check.
|
| 310 |
+
# This ensures AND_ALIGN[4,8] text is not confused with
|
| 311 |
+
# AND_ALIGN [child AND other] (composite form).
|
| 312 |
+
conciliation_params = _parse_keyword_params(tokens, prompt_type)
|
| 313 |
+
|
| 314 |
+
# Second trailing-affine pass: catches affine that comes AFTER [k] params.
|
| 315 |
+
trailing_affine_post = _parse_affine_transform(tokens)
|
| 316 |
+
trailing_affine = _compose_affine(trailing_affine_pre, trailing_affine_post)
|
| 317 |
+
affine_transform = _compose_affine(leading_affine, trailing_affine)
|
| 318 |
+
|
| 319 |
+
tokens_copy = tokens.copy()
|
| 320 |
+
if tokens_copy and tokens_copy[0] == '[':
|
| 321 |
+
tokens_copy.pop(0)
|
| 322 |
+
prompts = parse_prompts(tokens_copy, nested=True)
|
| 323 |
+
if tokens_copy:
|
| 324 |
+
assert tokens_copy.pop(0) == ']'
|
| 325 |
+
if len(prompts) > 1:
|
| 326 |
+
tokens[:] = tokens_copy
|
| 327 |
+
weight = _parse_weight(tokens)
|
| 328 |
+
return CompositePrompt(weight, conciliation, affine_transform, prompts, conciliation_params)
|
| 329 |
+
|
| 330 |
+
prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
|
| 331 |
+
return LeafPrompt(weight, conciliation, affine_transform, prompt_text, conciliation_params)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
|
| 335 |
+
text = ''
|
| 336 |
+
depth = 0
|
| 337 |
+
weight = 1.0
|
| 338 |
+
while tokens:
|
| 339 |
+
tok = tokens[0]
|
| 340 |
+
if tok == ']':
|
| 341 |
+
if depth == 0:
|
| 342 |
+
if nested:
|
| 343 |
+
break
|
| 344 |
+
else:
|
| 345 |
+
depth -= 1
|
| 346 |
+
elif tok == '[':
|
| 347 |
+
depth += 1
|
| 348 |
+
elif tok == ':':
|
| 349 |
+
if len(tokens) >= 2 and _is_float(tokens[1].strip()):
|
| 350 |
+
if len(tokens) < 3 or tokens[2] in _prompt_keywords_set or (tokens[2] == ']' and depth == 0):
|
| 351 |
+
tokens.pop(0)
|
| 352 |
+
weight = float(tokens.pop(0).strip())
|
| 353 |
+
break
|
| 354 |
+
elif tok in _prompt_keywords_set:
|
| 355 |
+
break
|
| 356 |
+
elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
|
| 357 |
+
break
|
| 358 |
+
text += tokens.pop(0)
|
| 359 |
+
return text, weight
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
|
| 363 |
+
tokens_copy = tokens.copy()
|
| 364 |
+
if tokens_copy and not tokens_copy[0].strip():
|
| 365 |
+
tokens_copy.pop(0)
|
| 366 |
+
|
| 367 |
+
affine_funcs = []
|
| 368 |
+
|
| 369 |
+
while tokens_copy and tokens_copy[0] in affine_transforms:
|
| 370 |
+
func = affine_transforms[tokens_copy.pop(0)]
|
| 371 |
+
args: List[float] = []
|
| 372 |
+
|
| 373 |
+
if not (tokens_copy and tokens_copy[0] == '['):
|
| 374 |
+
break
|
| 375 |
+
tokens_copy.pop(0)
|
| 376 |
+
|
| 377 |
+
if tokens_copy and tokens_copy[0] != ']':
|
| 378 |
+
if tokens_copy[0].strip():
|
| 379 |
+
try:
|
| 380 |
+
args = [float(a.strip()) for a in tokens_copy.pop(0).split(',')]
|
| 381 |
+
except ValueError:
|
| 382 |
+
break
|
| 383 |
+
else:
|
| 384 |
+
tokens_copy.pop(0)
|
| 385 |
+
|
| 386 |
+
if not (tokens_copy and tokens_copy[0] == ']'):
|
| 387 |
+
break
|
| 388 |
+
tokens_copy.pop(0)
|
| 389 |
+
|
| 390 |
+
affine_funcs.append(lambda t, f=func, a=args: f(t, *a))
|
| 391 |
+
if tokens_copy and not tokens_copy[0].strip():
|
| 392 |
+
tokens_copy.pop(0)
|
| 393 |
+
tokens[:] = tokens_copy
|
| 394 |
+
|
| 395 |
+
if not affine_funcs:
|
| 396 |
+
return None
|
| 397 |
+
|
| 398 |
+
transform = torch.eye(3, dtype=torch.float32)[:-1]
|
| 399 |
+
for fn in reversed(affine_funcs):
|
| 400 |
+
transform = fn(transform)
|
| 401 |
+
return transform
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def _parse_weight(tokens: List[str]) -> float:
|
| 405 |
+
if len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1]):
|
| 406 |
+
tokens.pop(0)
|
| 407 |
+
return float(tokens.pop(0))
|
| 408 |
+
return 1.0
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def tokenize(s: str) -> List[str]:
|
| 412 |
+
affine_kw_pattern = '|'.join(
|
| 413 |
+
rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
|
| 414 |
+
)
|
| 415 |
+
keyword_pattern = (
|
| 416 |
+
r'(?<!\w)AND_MASK_ALIGN_\d+_\d+(?!\w)' # fixed-suffix (longest first)
|
| 417 |
+
r'|(?<!\w)AND_ALIGN_\d+_\d+(?!\w)'
|
| 418 |
+
r'|(?<!\w)AND_MASK_ALIGN(?!\w)' # bare bracket-syntax keyword
|
| 419 |
+
r'|(?<!\w)AND_ALIGN(?!\w)'
|
| 420 |
+
r'|(?<!\w)AND_PERP(?!\w)'
|
| 421 |
+
r'|(?<!\w)AND_SALT_WIDE(?!\w)'
|
| 422 |
+
r'|(?<!\w)AND_SALT_BLOB(?!\w)'
|
| 423 |
+
r'|(?<!\w)AND_SALT(?!\w)'
|
| 424 |
+
r'|(?<!\w)AND_TOPK(?!\w)'
|
| 425 |
+
r'|(?<!\w)AND(?!\w)'
|
| 426 |
+
)
|
| 427 |
+
return [t for t in re.split(rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})', s) if t.strip()]
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _is_float(s: str) -> bool:
|
| 431 |
+
try:
|
| 432 |
+
float(s)
|
| 433 |
+
return True
|
| 434 |
+
except ValueError:
|
| 435 |
+
return False
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if __name__ == '__main__':
|
| 439 |
+
cases = [
|
| 440 |
+
('original AND_SALT', 'hello AND_SALT concept :1.0'),
|
| 441 |
+
('original AND_ALIGN_4_8', 'hello AND_ALIGN_4_8 watercolor :0.5'),
|
| 442 |
+
('new AND_SALT[5]', 'hello AND_SALT[5] concept :1.0'),
|
| 443 |
+
('new AND_SALT_WIDE[3]', 'hello AND_SALT_WIDE[3] concept :0.8'),
|
| 444 |
+
('new AND_SALT_BLOB[8]', 'hello AND_SALT_BLOB[8] concept :1.0'),
|
| 445 |
+
('new AND_ALIGN[4,8]', 'hello AND_ALIGN[4,8] watercolor :0.5'),
|
| 446 |
+
('new AND_MASK_ALIGN[6,12]', 'hello AND_MASK_ALIGN[6,12] structure :0.7'),
|
| 447 |
+
('composite not eaten', 'hello AND_ALIGN[arst AND defg :2.0] :0.5'),
|
| 448 |
+
]
|
| 449 |
+
for label, c in cases:
|
| 450 |
+
res = parse_root(c)
|
| 451 |
+
child = res.children[-1]
|
| 452 |
+
print(f'[{label}]')
|
| 453 |
+
print(f' conciliation={child.conciliation} params={child.conciliation_params}')
|
neutral_prompt_patcheds/lib_neutral_prompt/prompt_parser_hijack.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
|
| 6 |
+
from modules import script_callbacks, prompt_parser
|
| 7 |
+
|
| 8 |
+
# ---------------------------------------------------------------------------
|
| 9 |
+
# Fix: prompt_parser_fixed escapes standalone '&' as a multicond separator.
|
| 10 |
+
# neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
|
| 11 |
+
# So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
|
| 12 |
+
# and the patched prompt_parser splits that into 3 multicond branches instead of 1,
|
| 13 |
+
# which shifts all batch_cond_indices and breaks PERP/SALT/affine.
|
| 14 |
+
#
|
| 15 |
+
# Solution: escape any standalone '&' inside leaf text with '\&' before handing
|
| 16 |
+
# the string to the webui parser. The patched prompt_parser correctly unescapes
|
| 17 |
+
# '\&' back to '&' during conditioning, so the model sees the original text.
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
_STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')
|
| 21 |
+
|
| 22 |
+
# ── debug logging ──────────────────────────────────────────────────────────
|
| 23 |
+
# Set env variable NP_DEBUG=1 to enable verbose output in the A1111 console.
|
| 24 |
+
# Example (Linux/Mac): NP_DEBUG=1 python launch.py
|
| 25 |
+
# Example (Windows cmd): set NP_DEBUG=1 && python launch.py
|
| 26 |
+
import os as _os
|
| 27 |
+
_DEBUG = _os.getenv("NP_DEBUG", "0").strip() not in ("0", "", "false", "no", "off")
|
| 28 |
+
_log = logging.getLogger("neutral_prompt.hijack")
|
| 29 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _escape_leaf_ampersands(text: str) -> str:
|
| 33 |
+
"""Escape standalone '&' so patched prompt_parser doesn't split a single
|
| 34 |
+
Neutral Prompt leaf into extra multicond branches.
|
| 35 |
+
"cat & dog" -> "cat \\& dog"
|
| 36 |
+
"R&D" -> "R&D" (unchanged — not standalone)
|
| 37 |
+
"\\&" -> "\\&" (unchanged — already escaped)
|
| 38 |
+
"""
|
| 39 |
+
if not text or '&' not in text:
|
| 40 |
+
return text
|
| 41 |
+
return _STANDALONE_AMP.sub(r'\\&', text)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 45 |
+
module=prompt_parser,
|
| 46 |
+
hijacker_attribute='__neutral_prompt_hijacker',
|
| 47 |
+
on_uninstall=script_callbacks.on_script_unloaded,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
|
| 52 |
+
def get_multicond_prompt_list_hijack(prompts, original_function):
|
| 53 |
+
if not global_state.is_enabled:
|
| 54 |
+
return original_function(prompts)
|
| 55 |
+
|
| 56 |
+
global_state.prompt_exprs = parse_prompts(prompts)
|
| 57 |
+
webui_prompts = transpile_exprs(global_state.prompt_exprs)
|
| 58 |
+
|
| 59 |
+
# ── debug: transpiled strings ──────────────────────────────────────────
|
| 60 |
+
if _DEBUG:
|
| 61 |
+
for i, (orig, transp) in enumerate(zip(prompts, webui_prompts)):
|
| 62 |
+
_log.warning(
|
| 63 |
+
"[NP_DEBUG] prompt[%d]\n"
|
| 64 |
+
" original : %r\n"
|
| 65 |
+
" transpiled : %r\n"
|
| 66 |
+
" branches : %d",
|
| 67 |
+
i, orig, transp,
|
| 68 |
+
# count multicond splits the same way prompt_parser_fixed does
|
| 69 |
+
len(re.split(r'(?:\bAND\b|(?<!\S)&(?!\S))(?!_PERP|_SALT|_TOPK)', transp))
|
| 70 |
+
)
|
| 71 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 72 |
+
|
| 73 |
+
if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
|
| 74 |
+
webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)
|
| 75 |
+
|
| 76 |
+
result = original_function(webui_prompts)
|
| 77 |
+
|
| 78 |
+
# ── debug: multicond parts ─────────────────────────────────────────────
|
| 79 |
+
if _DEBUG:
|
| 80 |
+
conds_list, prompt_flat_list, prompt_indexes = result
|
| 81 |
+
_log.warning(
|
| 82 |
+
"[NP_DEBUG] get_multicond_prompt_list result\n"
|
| 83 |
+
" prompt_flat_list : %r\n"
|
| 84 |
+
" prompt_indexes : %r\n"
|
| 85 |
+
" conds_list : %r",
|
| 86 |
+
list(prompt_flat_list),
|
| 87 |
+
dict(prompt_indexes),
|
| 88 |
+
conds_list,
|
| 89 |
+
)
|
| 90 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 91 |
+
|
| 92 |
+
return result
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
|
| 96 |
+
exprs = []
|
| 97 |
+
for prompt in prompts:
|
| 98 |
+
expr = neutral_prompt_parser.parse_root(prompt)
|
| 99 |
+
exprs.append(expr)
|
| 100 |
+
return exprs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr):
|
| 104 |
+
webui_prompts = []
|
| 105 |
+
for expr in exprs:
|
| 106 |
+
webui_prompts.append(expr.accept(WebuiPromptVisitor()))
|
| 107 |
+
return webui_prompts
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class WebuiPromptVisitor:
|
| 111 |
+
def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
|
| 112 |
+
prompt = _escape_leaf_ampersands(that.prompt)
|
| 113 |
+
return f'{prompt} :{that.weight}'
|
| 114 |
+
|
| 115 |
+
def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
|
| 116 |
+
return ' AND '.join(child.accept(self) for child in that.children)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
|
| 120 |
+
def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
|
| 121 |
+
"""Store batch_cond_indices for the pre-noise affine hook (affine branch)."""
|
| 122 |
+
res = original_function(*args, **kwargs)
|
| 123 |
+
global_state.batch_cond_indices = res[0]
|
| 124 |
+
|
| 125 |
+
# ── debug: batch_cond_indices ──────────────────────────────────────────
|
| 126 |
+
if _DEBUG:
|
| 127 |
+
_log.warning(
|
| 128 |
+
"[NP_DEBUG] reconstruct_multicond_batch\n"
|
| 129 |
+
" batch_cond_indices : %r",
|
| 130 |
+
res[0],
|
| 131 |
+
)
|
| 132 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 133 |
+
|
| 134 |
+
return res
|
neutral_prompt_patcheds/lib_neutral_prompt/protection_utils.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base-prompt protection subsystem for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Extracted from cfg_denoiser_hijack.py to form a standalone, testable module.
|
| 5 |
+
|
| 6 |
+
Protection modes
|
| 7 |
+
----------------
|
| 8 |
+
off — no guard at all; auxiliary children can dominate freely.
|
| 9 |
+
auto — structural guard: fires when the prompt tree lacks a top-level
|
| 10 |
+
base AND-child (every child has a conciliation keyword).
|
| 11 |
+
strict — auto PLUS numerical: also fires when norm(base_delta)/norm(aux_delta)
|
| 12 |
+
falls below *strict_threshold*. Falls back to standard A1111 CFG.
|
| 13 |
+
soft — auto structural guard + numerical attenuation: instead of a hard
|
| 14 |
+
fallback, attenuates the auxiliary delta so that
|
| 15 |
+
norm(base_delta) / norm(attenuated_aux_delta) ≈ soft_floor.
|
| 16 |
+
The image retains auxiliary influence but cannot overpower the base.
|
| 17 |
+
|
| 18 |
+
All modes except 'off' print a diagnostic message when global_state.verbose=True.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import sys
|
| 24 |
+
from typing import List, Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Structural checks (per-prompt, O(n) in children)
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
def prompt_has_valid_base_path(
|
| 36 |
+
prompt: neutral_prompt_parser.PromptExpr,
|
| 37 |
+
) -> bool:
|
| 38 |
+
"""
|
| 39 |
+
Per-prompt structural check.
|
| 40 |
+
|
| 41 |
+
Returns True iff *prompt* is a CompositePrompt with at least one top-level
|
| 42 |
+
child that carries no conciliation keyword (i.e. a plain base segment).
|
| 43 |
+
"""
|
| 44 |
+
return (
|
| 45 |
+
isinstance(prompt, neutral_prompt_parser.CompositePrompt)
|
| 46 |
+
and any(c.conciliation is None for c in prompt.children)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_invalid_base_prompt_indices(
|
| 51 |
+
prompt_exprs: List[neutral_prompt_parser.PromptExpr],
|
| 52 |
+
) -> List[int]:
|
| 53 |
+
"""
|
| 54 |
+
Return the 0-based indices of any prompts in the batch that lack a
|
| 55 |
+
valid base AND-path. An empty return value means every prompt is OK.
|
| 56 |
+
|
| 57 |
+
Per-prompt rather than any(...): a healthy neighbour must not silently
|
| 58 |
+
legalise a broken prompt in the same batched request.
|
| 59 |
+
"""
|
| 60 |
+
return [
|
| 61 |
+
i for i, p in enumerate(prompt_exprs)
|
| 62 |
+
if not prompt_has_valid_base_path(p)
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def has_valid_base_path(
|
| 67 |
+
prompt_exprs: List[neutral_prompt_parser.PromptExpr],
|
| 68 |
+
) -> bool:
|
| 69 |
+
"""
|
| 70 |
+
Batch-level alias: True iff the batch is non-empty AND every prompt has a
|
| 71 |
+
valid base path. An empty batch returns False (nothing to protect).
|
| 72 |
+
"""
|
| 73 |
+
return bool(prompt_exprs) and len(get_invalid_base_prompt_indices(prompt_exprs)) == 0
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Old private names kept for backward-compat (hijack still references them via import)
|
| 77 |
+
_prompt_has_valid_base_path = prompt_has_valid_base_path
|
| 78 |
+
_get_invalid_base_prompt_indices = get_invalid_base_prompt_indices
|
| 79 |
+
_has_valid_base_path = has_valid_base_path
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# Numerical helpers
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def safe_norm(tensor: torch.Tensor) -> float:
|
| 87 |
+
"""Return float L2-norm; 0.0 on any error (empty tensor, NaN, etc.)."""
|
| 88 |
+
try:
|
| 89 |
+
v = torch.linalg.norm(tensor.float()).item()
|
| 90 |
+
return v if v == v else 0.0 # NaN guard
|
| 91 |
+
except Exception:
|
| 92 |
+
return 0.0
|
| 93 |
+
|
| 94 |
+
_safe_norm = safe_norm # alias
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Soft-mode attenuation
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
def compute_soft_attenuation(
|
| 102 |
+
base_norm: float,
|
| 103 |
+
aux_norm: float,
|
| 104 |
+
threshold: float,
|
| 105 |
+
) -> float:
|
| 106 |
+
"""
|
| 107 |
+
Compute a [0, 1] scale factor for the auxiliary delta so that, after
|
| 108 |
+
scaling, the base/aux ratio reaches at least *threshold*.
|
| 109 |
+
|
| 110 |
+
When the ratio is already ≥ threshold returns 1.0 (no change).
|
| 111 |
+
When aux_norm ≈ 0 returns 1.0 (nothing to attenuate).
|
| 112 |
+
|
| 113 |
+
Formula: factor = min(1, base_norm / (aux_norm * threshold))
|
| 114 |
+
This ensures norm(base) / norm(factor * aux) ≈ threshold.
|
| 115 |
+
"""
|
| 116 |
+
if aux_norm < 1e-8 or base_norm < 0:
|
| 117 |
+
return 1.0
|
| 118 |
+
ratio = base_norm / aux_norm
|
| 119 |
+
if ratio >= threshold:
|
| 120 |
+
return 1.0
|
| 121 |
+
# factor so that base / (factor * aux) = threshold ↔ factor = base / (aux * threshold)
|
| 122 |
+
return max(0.0, base_norm / (aux_norm * threshold))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ---------------------------------------------------------------------------
|
| 126 |
+
# Primary policy function
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
|
| 129 |
+
def should_fallback_for_protection(
|
| 130 |
+
x_out: torch.Tensor,
|
| 131 |
+
batch_cond_indices: List[List],
|
| 132 |
+
text_uncond: torch.Tensor,
|
| 133 |
+
# These visitors are injected to avoid circular imports
|
| 134 |
+
CondDeltaVisitor, # noqa: N803
|
| 135 |
+
AuxCondDeltaVisitor, # noqa: N803
|
| 136 |
+
DenoiseArgs, # noqa: N803
|
| 137 |
+
) -> Tuple[bool, Optional[str]]:
|
| 138 |
+
"""
|
| 139 |
+
Decide whether to fall back to standard CFG.
|
| 140 |
+
|
| 141 |
+
Returns ``(should_fallback: bool, reason: str | None)``.
|
| 142 |
+
|
| 143 |
+
Modes:
|
| 144 |
+
'off' → (False, None) — never fires.
|
| 145 |
+
'auto' → structural check only.
|
| 146 |
+
'strict' → structural check + numerical ratio check → hard fallback.
|
| 147 |
+
'soft' → structural check only → (False, None); attenuation is
|
| 148 |
+
handled separately by apply_soft_attenuation_if_needed().
|
| 149 |
+
"""
|
| 150 |
+
mode = global_state.normalize_protection_mode(global_state.protection_mode)
|
| 151 |
+
|
| 152 |
+
if mode == 'off':
|
| 153 |
+
return False, None
|
| 154 |
+
|
| 155 |
+
# --- Structural check (auto / strict / soft) — per-prompt ---
|
| 156 |
+
invalid_indices = get_invalid_base_prompt_indices(global_state.prompt_exprs)
|
| 157 |
+
if invalid_indices:
|
| 158 |
+
reason = (
|
| 159 |
+
f'no valid base AND-path for prompt index/indices {invalid_indices} '
|
| 160 |
+
f'(mode={mode}). '
|
| 161 |
+
'Those segments have only conciliation keywords (AND_PERP / AND_SALT / ...). '
|
| 162 |
+
'Add a plain base prompt before AND_* segments, '
|
| 163 |
+
'or set protection to "off" to allow this intentionally.'
|
| 164 |
+
)
|
| 165 |
+
return True, reason
|
| 166 |
+
|
| 167 |
+
# Soft mode only does structural check; numerical handling is done in-place.
|
| 168 |
+
if mode in ('auto', 'soft'):
|
| 169 |
+
return False, None
|
| 170 |
+
|
| 171 |
+
# --- Numerical check (strict only) ---
|
| 172 |
+
threshold = global_state.clamp_strict_threshold(global_state.strict_threshold)
|
| 173 |
+
uncond = x_out[-text_uncond.shape[0]:]
|
| 174 |
+
|
| 175 |
+
for batch_i, (prompt, cond_indices) in enumerate(
|
| 176 |
+
zip(global_state.prompt_exprs, batch_cond_indices)):
|
| 177 |
+
args = DenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
| 178 |
+
base_delta = prompt.accept(CondDeltaVisitor(), args, 0)
|
| 179 |
+
aux_delta = prompt.accept(AuxCondDeltaVisitor(), args, base_delta, 0)
|
| 180 |
+
|
| 181 |
+
base_norm = safe_norm(base_delta)
|
| 182 |
+
aux_norm = safe_norm(aux_delta)
|
| 183 |
+
|
| 184 |
+
if aux_norm < 1e-8:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
ratio = base_norm / aux_norm
|
| 188 |
+
if ratio < threshold:
|
| 189 |
+
return (
|
| 190 |
+
True,
|
| 191 |
+
f'base/aux ratio {ratio:.4f} < threshold {threshold:.4f} '
|
| 192 |
+
f'(prompt #{batch_i}, mode=strict). '
|
| 193 |
+
'Base contribution is too weak relative to auxiliary. '
|
| 194 |
+
'Lower the strict threshold or switch to "auto" to allow this.',
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
return False, None
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def apply_soft_attenuation_if_needed(
|
| 201 |
+
aux_delta: torch.Tensor,
|
| 202 |
+
base_delta: torch.Tensor,
|
| 203 |
+
prompt_index: int,
|
| 204 |
+
) -> Tuple[torch.Tensor, Optional[str]]:
|
| 205 |
+
"""
|
| 206 |
+
For 'soft' mode: attenuate *aux_delta* in-place (returns new tensor) so
|
| 207 |
+
that norm(base)/norm(aux) ≥ soft_threshold.
|
| 208 |
+
|
| 209 |
+
Returns ``(possibly_attenuated_aux_delta, diagnostic_note | None)``.
|
| 210 |
+
Called per-prompt inside the combine_denoised loop.
|
| 211 |
+
"""
|
| 212 |
+
mode = global_state.normalize_protection_mode(global_state.protection_mode)
|
| 213 |
+
if mode != 'soft':
|
| 214 |
+
return aux_delta, None
|
| 215 |
+
|
| 216 |
+
threshold = global_state.clamp_strict_threshold(global_state.strict_threshold)
|
| 217 |
+
base_norm = safe_norm(base_delta)
|
| 218 |
+
aux_norm = safe_norm(aux_delta)
|
| 219 |
+
factor = compute_soft_attenuation(base_norm, aux_norm, threshold)
|
| 220 |
+
|
| 221 |
+
if factor >= 1.0:
|
| 222 |
+
return aux_delta, None
|
| 223 |
+
|
| 224 |
+
note = (
|
| 225 |
+
f'[soft protection] prompt #{prompt_index}: '
|
| 226 |
+
f'base/aux ratio {base_norm / aux_norm:.4f} < {threshold:.4f}; '
|
| 227 |
+
f'attenuating aux × {factor:.3f}'
|
| 228 |
+
)
|
| 229 |
+
if global_state.verbose:
|
| 230 |
+
print(f'[neutral-prompt] {note}', file=sys.stderr)
|
| 231 |
+
|
| 232 |
+
return aux_delta * factor, note
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ---------------------------------------------------------------------------
|
| 236 |
+
# Diagnostics helpers (used by matryoshka_utils explain panel)
|
| 237 |
+
# ---------------------------------------------------------------------------
|
| 238 |
+
|
| 239 |
+
def protection_verdict(
|
| 240 |
+
prompt_str: str,
|
| 241 |
+
) -> Tuple[str, str]:
|
| 242 |
+
"""
|
| 243 |
+
Structural-only verdict for the debug/explain panel.
|
| 244 |
+
|
| 245 |
+
Returns (status, message) where status is 'ok', 'fire', or 'off'.
|
| 246 |
+
Does not simulate runtime ratio checks or batch state.
|
| 247 |
+
"""
|
| 248 |
+
from lib_neutral_prompt import neutral_prompt_parser as _p
|
| 249 |
+
mode = global_state.normalize_protection_mode(global_state.protection_mode)
|
| 250 |
+
|
| 251 |
+
if mode == 'off':
|
| 252 |
+
return 'off', 'Guard disabled — no fallback will occur'
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
root = _p.parse_root(prompt_str)
|
| 256 |
+
except Exception:
|
| 257 |
+
return 'fire', 'Parse failed — cannot determine base-path'
|
| 258 |
+
|
| 259 |
+
if isinstance(root, _p.CompositePrompt):
|
| 260 |
+
has_base = any(c.conciliation is None for c in root.children)
|
| 261 |
+
else:
|
| 262 |
+
has_base = root.conciliation is None
|
| 263 |
+
|
| 264 |
+
if not has_base:
|
| 265 |
+
return 'fire', 'WOULD FIRE — no plain base segment at the top level'
|
| 266 |
+
|
| 267 |
+
suffix = ''
|
| 268 |
+
if mode == 'strict':
|
| 269 |
+
thr = global_state.clamp_strict_threshold(global_state.strict_threshold)
|
| 270 |
+
suffix = f'\n Strict ratio check ({thr:.2f}) runs at generation time'
|
| 271 |
+
elif mode == 'soft':
|
| 272 |
+
thr = global_state.clamp_strict_threshold(global_state.strict_threshold)
|
| 273 |
+
suffix = (
|
| 274 |
+
f'\n Soft attenuation active: aux will be scaled if '
|
| 275 |
+
f'base/aux ratio < {thr:.2f}'
|
| 276 |
+
)
|
| 277 |
+
return 'ok', f'OK — base segment present{suffix}'
|
neutral_prompt_patcheds/lib_neutral_prompt/step_utils.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Step-window activation control for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Allows AND_* strategies to be active only during a sub-range of the denoising
|
| 5 |
+
process, specified as a normalised progress window [start, end] ∈ [0, 1].
|
| 6 |
+
|
| 7 |
+
Typical use:
|
| 8 |
+
- AND_ALIGN / AND_PERP : early steps (0.0–0.5) for structure/composition
|
| 9 |
+
- AND_SALT / AND_TOPK : late steps (0.5–1.0) for texture/detail
|
| 10 |
+
|
| 11 |
+
Three levels of control (coarse to fine):
|
| 12 |
+
1. Global window — applies to ALL AND_* strategies.
|
| 13 |
+
2. Per-strategy window — overrides global for a specific strategy family.
|
| 14 |
+
3. Per-node window — (future) per-node override inside matryoshka trees.
|
| 15 |
+
|
| 16 |
+
This module is intentionally pure (no torch, no gradio) so it can be
|
| 17 |
+
imported in both UI and runtime without side effects.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import dataclasses
|
| 23 |
+
from typing import Dict, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Data model
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass(frozen=True)
|
| 31 |
+
class StepWindow:
|
| 32 |
+
"""A normalised [start, end] activation window.
|
| 33 |
+
|
| 34 |
+
start and end are clamped to [0, 1] with start ≤ end.
|
| 35 |
+
An all-active window is StepWindow(0.0, 1.0).
|
| 36 |
+
"""
|
| 37 |
+
start: float = 0.0
|
| 38 |
+
end: float = 1.0
|
| 39 |
+
|
| 40 |
+
def __post_init__(self) -> None:
|
| 41 |
+
# Use object.__setattr__ because the dataclass is frozen
|
| 42 |
+
object.__setattr__(self, 'start', max(0.0, min(1.0, float(self.start))))
|
| 43 |
+
object.__setattr__(self, 'end', max(0.0, min(1.0, float(self.end))))
|
| 44 |
+
if self.start > self.end:
|
| 45 |
+
object.__setattr__(self, 'end', self.start)
|
| 46 |
+
|
| 47 |
+
def is_active_at(self, progress: float) -> bool:
|
| 48 |
+
"""Return True if *progress* ∈ [start, end] (inclusive)."""
|
| 49 |
+
return self.start <= progress <= self.end
|
| 50 |
+
|
| 51 |
+
def __repr__(self) -> str:
|
| 52 |
+
return f'StepWindow({self.start:.2f}–{self.end:.2f})'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
ALWAYS_ACTIVE = StepWindow(0.0, 1.0)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Strategy-family presets
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# These are sensible defaults reflecting typical SD denoising behaviour:
|
| 62 |
+
# - Global structure/composition → early steps
|
| 63 |
+
# - Local texture/detail → late steps
|
| 64 |
+
# Users can override any entry via the UI.
|
| 65 |
+
|
| 66 |
+
STRATEGY_DEFAULTS: Dict[str, StepWindow] = {
|
| 67 |
+
'PERPENDICULAR': StepWindow(0.0, 0.5), # composition correction
|
| 68 |
+
'SALIENCE_MASK': StepWindow(0.4, 1.0), # local texture detail
|
| 69 |
+
'SALIENCE_MASK_WIDE': StepWindow(0.3, 1.0), # broad region texture
|
| 70 |
+
'SALIENCE_MASK_BLOB': StepWindow(0.4, 1.0), # blob region
|
| 71 |
+
'SEMANTIC_GUIDANCE': StepWindow(0.4, 1.0), # sparse semantic detail
|
| 72 |
+
'ALIGNMENT_BLEND_CUSTOM': StepWindow(0.0, 0.6), # alignment preserves structure
|
| 73 |
+
'ALIGNMENT_MASK_BLEND_CUSTOM': StepWindow(0.0, 0.6),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _strategy_family(strategy_name: str) -> str:
|
| 78 |
+
"""
|
| 79 |
+
Map a concrete ConciliationStrategy name (including legacy fixed-suffix
|
| 80 |
+
variants like 'ALIGNMENT_BLEND_4_8') to a canonical family key.
|
| 81 |
+
"""
|
| 82 |
+
if strategy_name.startswith('ALIGNMENT_MASK_BLEND_'):
|
| 83 |
+
return 'ALIGNMENT_MASK_BLEND_CUSTOM'
|
| 84 |
+
if strategy_name.startswith('ALIGNMENT_BLEND_'):
|
| 85 |
+
return 'ALIGNMENT_BLEND_CUSTOM'
|
| 86 |
+
return strategy_name
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
# Progress normalisation
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
def normalize_progress(current_step: int, total_steps: int) -> float:
|
| 94 |
+
"""
|
| 95 |
+
Convert an absolute step index to a normalised progress value in [0, 1].
|
| 96 |
+
|
| 97 |
+
The first step (index 0) maps to 0.0 and the last step (index
|
| 98 |
+
total_steps - 1) maps to 1.0. Edge-cases:
|
| 99 |
+
|
| 100 |
+
total_steps = 1 → always returns 0.0
|
| 101 |
+
step 19 of 20 → 1.00
|
| 102 |
+
"""
|
| 103 |
+
denom = max(total_steps - 1, 1)
|
| 104 |
+
return max(0.0, min(1.0, current_step / denom))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def normalize_step_range(start, end) -> 'Tuple[float, float]':
|
| 108 |
+
"""
|
| 109 |
+
Parse and clamp a (start, end) pair to [0, 1].
|
| 110 |
+
|
| 111 |
+
If start > end after clamping, returns (0.0, 1.0) — full range fallback.
|
| 112 |
+
Accepts floats, ints, or stringified numbers.
|
| 113 |
+
"""
|
| 114 |
+
try:
|
| 115 |
+
s = max(0.0, min(1.0, float(start)))
|
| 116 |
+
except (TypeError, ValueError):
|
| 117 |
+
s = 0.0
|
| 118 |
+
try:
|
| 119 |
+
e = max(0.0, min(1.0, float(end)))
|
| 120 |
+
except (TypeError, ValueError):
|
| 121 |
+
e = 1.0
|
| 122 |
+
if s > e:
|
| 123 |
+
return 0.0, 1.0
|
| 124 |
+
return s, e
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Canonical UI names for each per-strategy group (user-facing labels)
|
| 128 |
+
STRATEGY_UI_KEYS = [
|
| 129 |
+
'AND_PERP',
|
| 130 |
+
'AND_SALT',
|
| 131 |
+
'AND_SALT_WIDE',
|
| 132 |
+
'AND_SALT_BLOB',
|
| 133 |
+
'AND_TOPK',
|
| 134 |
+
'AND_ALIGN',
|
| 135 |
+
'AND_MASK_ALIGN',
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
# Map UI key → canonical family name used in strategy_is_active
|
| 139 |
+
_UI_KEY_TO_FAMILY: Dict[str, str] = {
|
| 140 |
+
'AND_PERP': 'PERPENDICULAR',
|
| 141 |
+
'AND_SALT': 'SALIENCE_MASK',
|
| 142 |
+
'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
|
| 143 |
+
'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
|
| 144 |
+
'AND_TOPK': 'SEMANTIC_GUIDANCE',
|
| 145 |
+
'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
|
| 146 |
+
'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
_FAMILY_TO_UI_KEY: Dict[str, str] = {v: k for k, v in _UI_KEY_TO_FAMILY.items()}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def ui_key_to_family(ui_key: str) -> str:
|
| 153 |
+
"""Convert an AND_* UI key to its canonical family name."""
|
| 154 |
+
return _UI_KEY_TO_FAMILY.get(ui_key, ui_key)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def family_to_ui_key(family: str) -> str:
|
| 158 |
+
"""Convert a family name back to its AND_* UI key."""
|
| 159 |
+
return _FAMILY_TO_UI_KEY.get(family, family)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def build_per_strategy_windows(
|
| 163 |
+
custom_dict: Dict[str, 'Tuple[float, float]'],
|
| 164 |
+
) -> Dict[str, StepWindow]:
|
| 165 |
+
"""
|
| 166 |
+
Convert a raw {AND_* UI key → (start, end)} dict to the
|
| 167 |
+
{family_name → StepWindow} format expected by strategy_is_active.
|
| 168 |
+
|
| 169 |
+
Entries with invalid ranges are repaired by normalize_step_range.
|
| 170 |
+
Unrecognised keys are silently skipped.
|
| 171 |
+
"""
|
| 172 |
+
result: Dict[str, StepWindow] = {}
|
| 173 |
+
for ui_key, (start, end) in custom_dict.items():
|
| 174 |
+
if ui_key not in _UI_KEY_TO_FAMILY:
|
| 175 |
+
continue # silently skip unrecognised keys
|
| 176 |
+
family = ui_key_to_family(ui_key)
|
| 177 |
+
s, e = normalize_step_range(start, end)
|
| 178 |
+
result[family] = StepWindow(s, e)
|
| 179 |
+
return result
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def default_custom_windows() -> Dict[str, 'Tuple[float, float]']:
|
| 183 |
+
"""
|
| 184 |
+
Return the STRATEGY_DEFAULTS as a {AND_* UI key → (start, end)} dict —
|
| 185 |
+
the natural starting point for the per-strategy custom editor.
|
| 186 |
+
"""
|
| 187 |
+
out: Dict[str, 'Tuple[float, float]'] = {}
|
| 188 |
+
for family, win in STRATEGY_DEFAULTS.items():
|
| 189 |
+
ui_key = _FAMILY_TO_UI_KEY.get(family)
|
| 190 |
+
if ui_key:
|
| 191 |
+
out[ui_key] = (win.start, win.end)
|
| 192 |
+
return out
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def serialize_per_strategy_windows(
|
| 196 |
+
custom_dict: Dict[str, 'Tuple[float, float]'],
|
| 197 |
+
) -> str:
|
| 198 |
+
"""
|
| 199 |
+
Serialise to a compact string for generation params / infotext.
|
| 200 |
+
|
| 201 |
+
Format: ``AND_PERP:0.00-0.50,AND_SALT:0.40-1.00,...``
|
| 202 |
+
"""
|
| 203 |
+
parts = []
|
| 204 |
+
for ui_key in STRATEGY_UI_KEYS:
|
| 205 |
+
if ui_key in custom_dict:
|
| 206 |
+
s, e = custom_dict[ui_key]
|
| 207 |
+
parts.append(f'{ui_key}:{s:.4f}-{e:.4f}')
|
| 208 |
+
return ','.join(parts)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def deserialize_per_strategy_windows(
|
| 212 |
+
raw: str,
|
| 213 |
+
) -> Dict[str, 'Tuple[float, float]']:
|
| 214 |
+
"""
|
| 215 |
+
Parse the compact string back to {AND_* UI key → (start, end)}.
|
| 216 |
+
Tolerates malformed entries (skips them silently).
|
| 217 |
+
"""
|
| 218 |
+
result: Dict[str, 'Tuple[float, float]'] = {}
|
| 219 |
+
for token in raw.split(','):
|
| 220 |
+
token = token.strip()
|
| 221 |
+
if not token:
|
| 222 |
+
continue
|
| 223 |
+
if ':' not in token or '-' not in token:
|
| 224 |
+
continue
|
| 225 |
+
try:
|
| 226 |
+
ui_key, range_part = token.split(':', 1)
|
| 227 |
+
ui_key = ui_key.strip()
|
| 228 |
+
start_str, end_str = range_part.split('-', 1)
|
| 229 |
+
s, e = normalize_step_range(start_str.strip(), end_str.strip())
|
| 230 |
+
if ui_key in _UI_KEY_TO_FAMILY:
|
| 231 |
+
result[ui_key] = (s, e)
|
| 232 |
+
except Exception:
|
| 233 |
+
continue
|
| 234 |
+
return result
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
"""
|
| 239 |
+
Convert an absolute step counter to a normalised progress value [0, 1].
|
| 240 |
+
|
| 241 |
+
Examples:
|
| 242 |
+
step 0 of 20 → 0.00
|
| 243 |
+
step 10 of 20 → 0.526 (10/19)
|
| 244 |
+
step 19 of 20 → 1.00
|
| 245 |
+
"""
|
| 246 |
+
denom = max(total_steps - 1, 1)
|
| 247 |
+
return max(0.0, min(1.0, current_step / denom))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
# Activation decision
|
| 252 |
+
# ---------------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
def strategy_is_active(
|
| 255 |
+
strategy_name: str,
|
| 256 |
+
progress: float,
|
| 257 |
+
global_window: Optional[StepWindow] = None,
|
| 258 |
+
per_strategy_windows: Optional[Dict[str, StepWindow]] = None,
|
| 259 |
+
use_defaults: bool = False,
|
| 260 |
+
) -> bool:
|
| 261 |
+
"""
|
| 262 |
+
Return True if *strategy_name* should be applied at the given *progress*.
|
| 263 |
+
|
| 264 |
+
Priority order (highest wins):
|
| 265 |
+
1. Per-strategy window (if provided and contains this strategy's family).
|
| 266 |
+
2. Global window (if provided).
|
| 267 |
+
3. STRATEGY_DEFAULTS (if use_defaults=True).
|
| 268 |
+
4. ALWAYS_ACTIVE otherwise.
|
| 269 |
+
|
| 270 |
+
Parameters
|
| 271 |
+
----------
|
| 272 |
+
strategy_name :
|
| 273 |
+
ConciliationStrategy.name or a raw strategy string.
|
| 274 |
+
progress :
|
| 275 |
+
Normalised step progress in [0, 1].
|
| 276 |
+
global_window :
|
| 277 |
+
A single window applied to every strategy (unless overridden).
|
| 278 |
+
per_strategy_windows :
|
| 279 |
+
Dict mapping strategy family name → StepWindow. Overrides global.
|
| 280 |
+
use_defaults :
|
| 281 |
+
If True, fall back to STRATEGY_DEFAULTS when no explicit window is set.
|
| 282 |
+
"""
|
| 283 |
+
family = _strategy_family(strategy_name)
|
| 284 |
+
|
| 285 |
+
# Per-strategy override
|
| 286 |
+
if per_strategy_windows and family in per_strategy_windows:
|
| 287 |
+
return per_strategy_windows[family].is_active_at(progress)
|
| 288 |
+
|
| 289 |
+
# Global override
|
| 290 |
+
if global_window is not None:
|
| 291 |
+
return global_window.is_active_at(progress)
|
| 292 |
+
|
| 293 |
+
# Preset defaults
|
| 294 |
+
if use_defaults and family in STRATEGY_DEFAULTS:
|
| 295 |
+
return STRATEGY_DEFAULTS[family].is_active_at(progress)
|
| 296 |
+
|
| 297 |
+
return True # ALWAYS_ACTIVE
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# ---------------------------------------------------------------------------
|
| 301 |
+
# global_state integration helpers
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
+
def get_global_window() -> Optional[StepWindow]:
|
| 305 |
+
"""Read the global step window from global_state (may be None = disabled)."""
|
| 306 |
+
try:
|
| 307 |
+
from lib_neutral_prompt import global_state as _gs
|
| 308 |
+
return _gs.step_window_global
|
| 309 |
+
except AttributeError:
|
| 310 |
+
return None
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_per_strategy_windows() -> Dict[str, StepWindow]:
|
| 314 |
+
"""Read per-strategy windows from global_state (empty dict if not set)."""
|
| 315 |
+
try:
|
| 316 |
+
from lib_neutral_prompt import global_state as _gs
|
| 317 |
+
return _gs.step_window_per_strategy or {}
|
| 318 |
+
except AttributeError:
|
| 319 |
+
return {}
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_use_defaults() -> bool:
|
| 323 |
+
"""Read whether STRATEGY_DEFAULTS are active."""
|
| 324 |
+
try:
|
| 325 |
+
from lib_neutral_prompt import global_state as _gs
|
| 326 |
+
return bool(_gs.step_window_use_defaults)
|
| 327 |
+
except AttributeError:
|
| 328 |
+
return False
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def strategy_is_active_from_state(
|
| 332 |
+
strategy_name: str,
|
| 333 |
+
progress: float,
|
| 334 |
+
) -> bool:
|
| 335 |
+
"""
|
| 336 |
+
Convenience wrapper: read all step-window settings from global_state and
|
| 337 |
+
return whether *strategy_name* is active at *progress*.
|
| 338 |
+
|
| 339 |
+
Also handles Lock after End: if ``global_state.step_window_lock_after_end``
|
| 340 |
+
is True and the strategy's window has ended, it sets a permanent lock flag
|
| 341 |
+
in ``global_state._step_lock_flags[family]`` and returns False for all
|
| 342 |
+
subsequent calls this generation.
|
| 343 |
+
"""
|
| 344 |
+
from lib_neutral_prompt import global_state as _gs
|
| 345 |
+
|
| 346 |
+
family = _strategy_family(strategy_name)
|
| 347 |
+
|
| 348 |
+
# ── Lock after End: check permanent lock first ─────────────────────────
|
| 349 |
+
lock_enabled = getattr(_gs, 'step_window_lock_after_end', False)
|
| 350 |
+
if lock_enabled:
|
| 351 |
+
lock_flags = getattr(_gs, '_step_lock_flags', None)
|
| 352 |
+
if lock_flags is None:
|
| 353 |
+
_gs._step_lock_flags = {}
|
| 354 |
+
lock_flags = _gs._step_lock_flags
|
| 355 |
+
if lock_flags.get(family, False):
|
| 356 |
+
return False # permanently frozen this generation
|
| 357 |
+
|
| 358 |
+
# ── Normal activity check ───────────────────────────────────────────────
|
| 359 |
+
active = strategy_is_active(
|
| 360 |
+
strategy_name,
|
| 361 |
+
progress,
|
| 362 |
+
global_window=get_global_window(),
|
| 363 |
+
per_strategy_windows=get_per_strategy_windows(),
|
| 364 |
+
use_defaults=get_use_defaults(),
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# ── Lock after End: set flag when window has closed ─────────────────────
|
| 368 |
+
if lock_enabled and not active:
|
| 369 |
+
# Only lock once we have actually been past the window start
|
| 370 |
+
# (avoids locking strategies that haven't started yet)
|
| 371 |
+
window = _resolve_window_for_family(family)
|
| 372 |
+
if window is not None and progress >= window.start:
|
| 373 |
+
# We were inside the window at some point; window has now ended
|
| 374 |
+
lock_flags[family] = True # type: ignore[index]
|
| 375 |
+
|
| 376 |
+
return active
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def _resolve_window_for_family(family: str) -> Optional[StepWindow]:
|
| 380 |
+
"""
|
| 381 |
+
Return the effective StepWindow for *family* given the current global_state,
|
| 382 |
+
or None if no window restriction applies.
|
| 383 |
+
Used by Lock after End to decide whether to set the lock flag.
|
| 384 |
+
"""
|
| 385 |
+
try:
|
| 386 |
+
from lib_neutral_prompt import global_state as _gs
|
| 387 |
+
per = getattr(_gs, 'step_window_per_strategy', None) or {}
|
| 388 |
+
if family in per:
|
| 389 |
+
return per[family]
|
| 390 |
+
gw = getattr(_gs, 'step_window_global', None)
|
| 391 |
+
if gw is not None:
|
| 392 |
+
return gw
|
| 393 |
+
if getattr(_gs, 'step_window_use_defaults', False):
|
| 394 |
+
return STRATEGY_DEFAULTS.get(family)
|
| 395 |
+
except Exception:
|
| 396 |
+
pass
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# ---------------------------------------------------------------------------
|
| 401 |
+
# Explain / diagnostic summary
|
| 402 |
+
# ---------------------------------------------------------------------------
|
| 403 |
+
|
| 404 |
+
def render_step_window_summary(
|
| 405 |
+
strategies: list,
|
| 406 |
+
progress: Optional[float] = None,
|
| 407 |
+
global_window: Optional[StepWindow] = None,
|
| 408 |
+
per_strategy_windows: Optional[Dict[str, StepWindow]] = None,
|
| 409 |
+
use_defaults: bool = False,
|
| 410 |
+
) -> str:
|
| 411 |
+
"""
|
| 412 |
+
Build a human-readable summary of step-window settings for the debug panel.
|
| 413 |
+
|
| 414 |
+
*strategies* is a list of strategy-name strings (from diagnostics['strategies']).
|
| 415 |
+
*progress* is the current (or hypothetical) normalised step (optional).
|
| 416 |
+
"""
|
| 417 |
+
if not strategies:
|
| 418 |
+
return ''
|
| 419 |
+
|
| 420 |
+
lines = ['── Step windows ─────────────────────────']
|
| 421 |
+
|
| 422 |
+
# Header: show effective global window or "all steps"
|
| 423 |
+
if global_window is not None:
|
| 424 |
+
lines.append(f'Global window : {global_window}')
|
| 425 |
+
elif use_defaults:
|
| 426 |
+
lines.append('Global window : (per-strategy defaults active)')
|
| 427 |
+
else:
|
| 428 |
+
lines.append('Global window : all steps (no restriction)')
|
| 429 |
+
|
| 430 |
+
for s in strategies:
|
| 431 |
+
if s == 'BASE':
|
| 432 |
+
continue
|
| 433 |
+
family = _strategy_family(s)
|
| 434 |
+
# Resolve effective window
|
| 435 |
+
if per_strategy_windows and family in per_strategy_windows:
|
| 436 |
+
win = per_strategy_windows[family]
|
| 437 |
+
src = 'custom'
|
| 438 |
+
elif global_window is not None:
|
| 439 |
+
win = global_window
|
| 440 |
+
src = 'global'
|
| 441 |
+
elif use_defaults and family in STRATEGY_DEFAULTS:
|
| 442 |
+
win = STRATEGY_DEFAULTS[family]
|
| 443 |
+
src = 'default'
|
| 444 |
+
else:
|
| 445 |
+
win = ALWAYS_ACTIVE
|
| 446 |
+
src = 'always'
|
| 447 |
+
|
| 448 |
+
if progress is not None:
|
| 449 |
+
active = '✓ active' if win.is_active_at(progress) else '✗ gated'
|
| 450 |
+
lines.append(f' {s:32s} {win} [{src}] @ {progress:.2f} → {active}')
|
| 451 |
+
else:
|
| 452 |
+
lines.append(f' {s:32s} {win} [{src}]')
|
| 453 |
+
|
| 454 |
+
return '\n'.join(lines)
|
neutral_prompt_patcheds/lib_neutral_prompt/ui.py
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI for sd-webui-neutral-prompt — v1.1 "Matryoshka visibility"
|
| 3 |
+
|
| 4 |
+
New in v1.1:
|
| 5 |
+
- Tree-style debug/explain with branch connectors and full diagnostics
|
| 6 |
+
- Matryoshka builder v1: 2-level visual composer with live preview
|
| 7 |
+
- Matryoshka templates: 7 ready-made nested recipes
|
| 8 |
+
- AND_TOPK[threshold] slider
|
| 9 |
+
- Affine transform builder with presets
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 15 |
+
import dataclasses
|
| 16 |
+
import gradio as gr
|
| 17 |
+
|
| 18 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 19 |
+
from lib_neutral_prompt.matryoshka_utils import (
|
| 20 |
+
build_child_block,
|
| 21 |
+
build_nested_prompt,
|
| 22 |
+
render_prompt_node,
|
| 23 |
+
render_prompt_tree,
|
| 24 |
+
collect_prompt_diagnostics,
|
| 25 |
+
render_diagnostics,
|
| 26 |
+
render_full_explain,
|
| 27 |
+
MATRYOSHKA_TEMPLATES,
|
| 28 |
+
BUILDER_STRATEGIES,
|
| 29 |
+
_KW_DEFAULTS,
|
| 30 |
+
)
|
| 31 |
+
from lib_neutral_prompt.affine_utils import (
|
| 32 |
+
AFFINE_TRANSFORMS as _AFFINE_TRANSFORMS,
|
| 33 |
+
AFFINE_PRESETS as _AFFINE_PRESETS,
|
| 34 |
+
build_affine_snippet as _build_affine_snippet,
|
| 35 |
+
)
|
| 36 |
+
from lib_neutral_prompt.step_utils import (
|
| 37 |
+
StepWindow, STRATEGY_DEFAULTS,
|
| 38 |
+
STRATEGY_UI_KEYS as _STRATEGY_UI_KEYS,
|
| 39 |
+
build_per_strategy_windows as _build_per_strategy_windows,
|
| 40 |
+
serialize_per_strategy_windows as _serialize_per_strategy_windows,
|
| 41 |
+
deserialize_per_strategy_windows as _deserialize_per_strategy_windows,
|
| 42 |
+
default_custom_windows as _step_default_custom,
|
| 43 |
+
)
|
| 44 |
+
from modules import script_callbacks, shared
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
txt2img_prompt_textbox = None
|
| 48 |
+
img2img_prompt_textbox = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Prompt-type registry
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
prompt_types: Dict[str, str] = {
|
| 56 |
+
'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
|
| 57 |
+
'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
|
| 58 |
+
'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
|
| 59 |
+
'Saliency wide / classic (AND_SALT_WIDE)': neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
|
| 60 |
+
'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
|
| 61 |
+
'Alignment blend — custom D/S (AND_ALIGN)': neutral_prompt_parser.PromptKeyword.AND_ALIGN.value,
|
| 62 |
+
'Alignment mask — custom D/S (AND_MASK_ALIGN)': neutral_prompt_parser.PromptKeyword.AND_MASK_ALIGN.value,
|
| 63 |
+
}
|
| 64 |
+
for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
|
| 65 |
+
prompt_types[f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'] = \
|
| 66 |
+
getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
|
| 67 |
+
prompt_types[f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'] = \
|
| 68 |
+
getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
|
| 69 |
+
|
| 70 |
+
_SALT_LABELS = frozenset({'Saliency sharp (AND_SALT)', 'Saliency blob (AND_SALT_BLOB)',
|
| 71 |
+
'Saliency wide / classic (AND_SALT_WIDE)'})
|
| 72 |
+
_TOPK_LABELS = frozenset({'Semantic guidance top-k (AND_TOPK)'})
|
| 73 |
+
_ALIGN_CUSTOM_LABELS = frozenset({'Alignment blend — custom D/S (AND_ALIGN)',
|
| 74 |
+
'Alignment mask — custom D/S (AND_MASK_ALIGN)'})
|
| 75 |
+
|
| 76 |
+
prompt_types_tooltip = (
|
| 77 |
+
'AND – add all features equally (webui built-in)\n'
|
| 78 |
+
'AND_PERP – reduce contradicting features via perpendicular projection\n'
|
| 79 |
+
'AND_SALT[k] – sharp saliency mask (default k=5; higher k → fewer pixels)\n'
|
| 80 |
+
'AND_SALT_BLOB[k] – blob saliency: grown seed region\n'
|
| 81 |
+
'AND_SALT_WIDE[k] – broad saliency mask (default k=1; ~55% of pixels)\n'
|
| 82 |
+
'AND_TOPK[t] – small targeted changes (top-t fraction; default 0.05)\n'
|
| 83 |
+
'\n'
|
| 84 |
+
'AND_ALIGN[D,S] – soft blend: detail without breaking structure\n'
|
| 85 |
+
'AND_MASK_ALIGN[D,S] – binary-mask variant of AND_ALIGN\n'
|
| 86 |
+
' D = detail kernel [2,32] S = structure kernel [2,32] D ≠ S\n'
|
| 87 |
+
'\n'
|
| 88 |
+
'Affine: ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]\n'
|
| 89 |
+
'\n'
|
| 90 |
+
'Matryoshka nesting: AND_SALT[5] [ inner text AND_TOPK[0.1] detail ]'
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _build_keyword(label: str, salt_k: float, align_d: int, align_s: int,
|
| 95 |
+
topk_threshold: float = 0.05) -> str:
|
| 96 |
+
kw = prompt_types[label]
|
| 97 |
+
if label in _SALT_LABELS:
|
| 98 |
+
return f'{kw}[{salt_k}]'
|
| 99 |
+
if label in _TOPK_LABELS:
|
| 100 |
+
return f'{kw}[{topk_threshold}]'
|
| 101 |
+
if label in _ALIGN_CUSTOM_LABELS:
|
| 102 |
+
d, s = int(align_d), int(align_s)
|
| 103 |
+
if d != s and 2 <= d <= 32 and 2 <= s <= 32:
|
| 104 |
+
return f'{kw}[{d},{s}]'
|
| 105 |
+
return kw
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _build_step_window_preview(mode: str, start: float, end: float,
|
| 110 |
+
custom_raw: Optional[Dict] = None) -> str:
|
| 111 |
+
"""Build a human-readable summary of the step-window setting."""
|
| 112 |
+
if mode == 'global':
|
| 113 |
+
win = StepWindow(start, end)
|
| 114 |
+
lines = [f'Global window: {win}', '']
|
| 115 |
+
for strat in ('PERPENDICULAR', 'SALIENCE_MASK', 'SALIENCE_MASK_WIDE',
|
| 116 |
+
'SEMANTIC_GUIDANCE', 'ALIGNMENT_BLEND_CUSTOM'):
|
| 117 |
+
lines.append(f' {strat:34s} → {win}')
|
| 118 |
+
return '\n'.join(lines)
|
| 119 |
+
elif mode == 'per-strategy defaults':
|
| 120 |
+
lines = ['Per-strategy defaults:', '']
|
| 121 |
+
for strat, win in STRATEGY_DEFAULTS.items():
|
| 122 |
+
lines.append(f' {strat:34s} → {win}')
|
| 123 |
+
return '\n'.join(lines)
|
| 124 |
+
elif mode == 'per-strategy custom' and custom_raw:
|
| 125 |
+
from lib_neutral_prompt.step_utils import StepWindow as _SW, normalize_step_range
|
| 126 |
+
lines = ['Per-strategy custom windows:', '']
|
| 127 |
+
for ui_key in _STRATEGY_UI_KEYS:
|
| 128 |
+
if ui_key in custom_raw:
|
| 129 |
+
s, e = custom_raw[ui_key]
|
| 130 |
+
lines.append(f' {ui_key:16s} → {_SW(s, e)}')
|
| 131 |
+
return '\n'.join(lines)
|
| 132 |
+
return ''
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
# UI component class
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
@dataclasses.dataclass
|
| 140 |
+
class AccordionInterface:
|
| 141 |
+
get_elem_id: Callable[[str], str]
|
| 142 |
+
|
| 143 |
+
def __post_init__(self):
|
| 144 |
+
self.is_rendered = False
|
| 145 |
+
|
| 146 |
+
self.cfg_rescale = gr.Slider(
|
| 147 |
+
label='CFG rescale φ', minimum=0, maximum=1, value=0,
|
| 148 |
+
info='0 = disabled. Rescales CFG output to reduce over-saturation.',
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# --- Base-prompt protection ---
|
| 152 |
+
self.protection_mode = gr.Radio(
|
| 153 |
+
label='Base prompt protection', choices=['off', 'auto', 'strict', 'soft'],
|
| 154 |
+
value=global_state.protection_mode,
|
| 155 |
+
info=(
|
| 156 |
+
'off — no guard. '
|
| 157 |
+
'auto — structural guard (hard fallback). '
|
| 158 |
+
'strict — structural + ratio guard (hard fallback). '
|
| 159 |
+
'soft — structural guard + ratio attenuation (no fallback, visible effect).'
|
| 160 |
+
),
|
| 161 |
+
)
|
| 162 |
+
self.strict_threshold = gr.Slider(
|
| 163 |
+
label='Ratio threshold (strict / soft)', minimum=0.01, maximum=0.5, step=0.01,
|
| 164 |
+
value=global_state.strict_threshold,
|
| 165 |
+
info='strict: fallback when base/aux < threshold. soft: attenuate aux until ratio ≥ threshold.',
|
| 166 |
+
visible=(global_state.normalize_protection_mode(global_state.protection_mode) in ('strict', 'soft')),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# --- Prompt formatter ---
|
| 170 |
+
self.neutral_prompt = gr.Textbox(
|
| 171 |
+
label='Neutral prompt', show_label=False, lines=3,
|
| 172 |
+
placeholder='Auxiliary prompt (press "Apply to prompt" to append)',
|
| 173 |
+
)
|
| 174 |
+
self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
|
| 175 |
+
self.aux_prompt_type = gr.Dropdown(
|
| 176 |
+
label='Prompt type', choices=list(prompt_types.keys()),
|
| 177 |
+
value=next(iter(prompt_types.keys())), info=prompt_types_tooltip,
|
| 178 |
+
elem_id=self.get_elem_id('formatter_prompt_type'),
|
| 179 |
+
)
|
| 180 |
+
_def = next(iter(prompt_types.keys()))
|
| 181 |
+
self.salt_k = gr.Slider(
|
| 182 |
+
label='Salience sharpness k', minimum=0.5, maximum=20.0, step=0.5, value=5.0,
|
| 183 |
+
info='Low k → broad. High k → surgical. Default 5.',
|
| 184 |
+
visible=(_def in _SALT_LABELS),
|
| 185 |
+
)
|
| 186 |
+
self.topk_threshold = gr.Slider(
|
| 187 |
+
label='Top-k threshold', minimum=0.01, maximum=0.5, step=0.01, value=0.05,
|
| 188 |
+
info='Fraction of elements that receive the contribution. Default 0.05.',
|
| 189 |
+
visible=(_def in _TOPK_LABELS),
|
| 190 |
+
)
|
| 191 |
+
self.align_detail = gr.Slider(
|
| 192 |
+
label='Alignment detail kernel D', minimum=2, maximum=32, step=1, value=4,
|
| 193 |
+
info='Fine detail kernel size. D ≠ S.', visible=False,
|
| 194 |
+
)
|
| 195 |
+
self.align_structure = gr.Slider(
|
| 196 |
+
label='Alignment structure kernel S', minimum=2, maximum=32, step=1, value=8,
|
| 197 |
+
info='Global structure kernel size. D ≠ S.', visible=False,
|
| 198 |
+
)
|
| 199 |
+
self.append_to_prompt_button = gr.Button(value='Apply to prompt')
|
| 200 |
+
|
| 201 |
+
# --- Affine formatter ---
|
| 202 |
+
self.affine_preset = gr.Dropdown(
|
| 203 |
+
label='Affine preset', choices=list(_AFFINE_PRESETS.keys()), value='Custom',
|
| 204 |
+
)
|
| 205 |
+
self.affine_transform = gr.Dropdown(
|
| 206 |
+
label='Transform', choices=_AFFINE_TRANSFORMS, value='ROTATE',
|
| 207 |
+
)
|
| 208 |
+
self.affine_p1 = gr.Slider(
|
| 209 |
+
label='Parameter 1 (angle / x / scale)', minimum=-2.0, maximum=2.0, step=0.01,
|
| 210 |
+
value=0.125,
|
| 211 |
+
info='ROTATE: turns (0.25=90°). SCALE/SLIDE/SHEAR: x-component.',
|
| 212 |
+
)
|
| 213 |
+
self.affine_p2 = gr.Slider(
|
| 214 |
+
label='Parameter 2 (y / scale-y)', minimum=-2.0, maximum=2.0, step=0.01, value=1.0,
|
| 215 |
+
info='y-component. Ignored for ROTATE, FLIP_H, FLIP_V.',
|
| 216 |
+
)
|
| 217 |
+
self.affine_preview = gr.Textbox(
|
| 218 |
+
label='Affine snippet', interactive=False,
|
| 219 |
+
value=_build_affine_snippet('ROTATE', 0.125, 1.0),
|
| 220 |
+
)
|
| 221 |
+
self.affine_insert_button = gr.Button(value='Insert affine into prompt')
|
| 222 |
+
|
| 223 |
+
# --- Step window ---
|
| 224 |
+
self.step_window_enabled = gr.Checkbox(
|
| 225 |
+
label='Enable step-window gating for AND_*', value=False,
|
| 226 |
+
info='Restrict each strategy to a sub-range of the denoising steps.',
|
| 227 |
+
)
|
| 228 |
+
self.step_window_lock = gr.Checkbox(
|
| 229 |
+
label='🔒 Lock after End', value=False, visible=False,
|
| 230 |
+
info='После последнего активного диапазона — заморозить навсегда. '
|
| 231 |
+
'Стратегия не возобновится, даже если шаги войдут в другой диапазон.',
|
| 232 |
+
)
|
| 233 |
+
self.step_window_mode = gr.Radio(
|
| 234 |
+
label='Step window mode',
|
| 235 |
+
choices=['global', 'per-strategy defaults', 'per-strategy custom'],
|
| 236 |
+
value='global', visible=False,
|
| 237 |
+
info='"global" applies one window to all AND_* strategies. '
|
| 238 |
+
'"per-strategy defaults" uses built-in schedule (PERP/ALIGN early, SALT/TOPK late). '
|
| 239 |
+
'"per-strategy custom" lets you set each window independently.',
|
| 240 |
+
)
|
| 241 |
+
self.step_window_start = gr.Slider(
|
| 242 |
+
label='Global window start', minimum=0.0, maximum=1.0, step=0.01, value=0.0,
|
| 243 |
+
info='Normalised progress: 0 = first step, 1 = last step.',
|
| 244 |
+
visible=False,
|
| 245 |
+
)
|
| 246 |
+
self.step_window_end = gr.Slider(
|
| 247 |
+
label='Global window end', minimum=0.0, maximum=1.0, step=0.01, value=1.0,
|
| 248 |
+
visible=False,
|
| 249 |
+
)
|
| 250 |
+
# Per-strategy custom sliders — one start/end pair per strategy
|
| 251 |
+
_ps_defaults = _step_default_custom()
|
| 252 |
+
self.step_ps_sliders: Dict[str, Tuple] = {} # {ui_key: (start_slider, end_slider)}
|
| 253 |
+
for _ui_key in _STRATEGY_UI_KEYS:
|
| 254 |
+
_s, _e = _ps_defaults.get(_ui_key, (0.0, 1.0))
|
| 255 |
+
self.step_ps_sliders[_ui_key] = (
|
| 256 |
+
gr.Slider(label=f'{_ui_key} start', minimum=0.0, maximum=1.0,
|
| 257 |
+
step=0.01, value=_s, visible=False),
|
| 258 |
+
gr.Slider(label=f'{_ui_key} end', minimum=0.0, maximum=1.0,
|
| 259 |
+
step=0.01, value=_e, visible=False),
|
| 260 |
+
)
|
| 261 |
+
self.step_ps_reset_btn = gr.Button(value='Reset to defaults', visible=False)
|
| 262 |
+
self.step_window_preview = gr.Textbox(
|
| 263 |
+
label='Step window preview', interactive=False, visible=False,
|
| 264 |
+
value='',
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# --- Matryoshka builder ---
|
| 268 |
+
self.builder_base = gr.Textbox(
|
| 269 |
+
label='Base prompt', lines=2, placeholder='main subject, environment, style…',
|
| 270 |
+
)
|
| 271 |
+
# Child 1
|
| 272 |
+
self.builder_ch1_strategy = gr.Dropdown(
|
| 273 |
+
label='Child 1 — strategy', choices=BUILDER_STRATEGIES, value='AND_SALT',
|
| 274 |
+
)
|
| 275 |
+
self.builder_ch1_text = gr.Textbox(
|
| 276 |
+
label='Child 1 — text', lines=1, placeholder='concept, texture, style…',
|
| 277 |
+
)
|
| 278 |
+
self.builder_ch1_weight = gr.Slider(
|
| 279 |
+
label='Child 1 — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.8,
|
| 280 |
+
)
|
| 281 |
+
self.builder_ch1_affine = gr.Textbox(
|
| 282 |
+
label='Child 1 — affine (optional)', lines=1,
|
| 283 |
+
placeholder='e.g. ROTATE[0.125] or SCALE[-1,1]',
|
| 284 |
+
)
|
| 285 |
+
# Optional nested child inside child 1
|
| 286 |
+
self.builder_ch1_nested_enable = gr.Checkbox(label='Add nested child inside child 1', value=False)
|
| 287 |
+
self.builder_ch1_nested_strategy = gr.Dropdown(
|
| 288 |
+
label='Nested — strategy', choices=BUILDER_STRATEGIES, value='AND_TOPK', visible=False,
|
| 289 |
+
)
|
| 290 |
+
self.builder_ch1_nested_text = gr.Textbox(
|
| 291 |
+
label='Nested — text', lines=1, placeholder='inner concept…', visible=False,
|
| 292 |
+
)
|
| 293 |
+
self.builder_ch1_nested_weight = gr.Slider(
|
| 294 |
+
label='Nested — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.5, visible=False,
|
| 295 |
+
)
|
| 296 |
+
# Child 2
|
| 297 |
+
self.builder_ch2_enable = gr.Checkbox(label='Add second child', value=False)
|
| 298 |
+
self.builder_ch2_strategy = gr.Dropdown(
|
| 299 |
+
label='Child 2 — strategy', choices=BUILDER_STRATEGIES, value='AND_PERP', visible=False,
|
| 300 |
+
)
|
| 301 |
+
self.builder_ch2_text = gr.Textbox(
|
| 302 |
+
label='Child 2 — text', lines=1, placeholder='concept…', visible=False,
|
| 303 |
+
)
|
| 304 |
+
self.builder_ch2_weight = gr.Slider(
|
| 305 |
+
label='Child 2 — weight', minimum=0.0, maximum=2.0, step=0.05, value=0.6, visible=False,
|
| 306 |
+
)
|
| 307 |
+
# Output
|
| 308 |
+
self.builder_preview = gr.Textbox(
|
| 309 |
+
label='Generated prompt (preview)', lines=7, interactive=False,
|
| 310 |
+
)
|
| 311 |
+
self.builder_apply_button = gr.Button(value='Apply to prompt')
|
| 312 |
+
self.builder_copy_button = gr.Button(value='Copy generated prompt')
|
| 313 |
+
self.builder_explain_button = gr.Button(value='Explain this prompt')
|
| 314 |
+
|
| 315 |
+
# --- Templates ---
|
| 316 |
+
self.template_dropdown = gr.Dropdown(
|
| 317 |
+
label='Matryoshka template', choices=list(MATRYOSHKA_TEMPLATES.keys()),
|
| 318 |
+
value=next(iter(MATRYOSHKA_TEMPLATES.keys())),
|
| 319 |
+
)
|
| 320 |
+
self.template_description = gr.Textbox(
|
| 321 |
+
label='Template description', interactive=False,
|
| 322 |
+
value=next(iter(MATRYOSHKA_TEMPLATES.values()))['description'],
|
| 323 |
+
)
|
| 324 |
+
self.template_preview = gr.Textbox(
|
| 325 |
+
label='Template prompt', lines=8, interactive=False,
|
| 326 |
+
value=next(iter(MATRYOSHKA_TEMPLATES.values()))['prompt'],
|
| 327 |
+
)
|
| 328 |
+
self.template_load_button = gr.Button(value='Load template into prompt')
|
| 329 |
+
|
| 330 |
+
# --- Debug / explain (matryoshka tree) ---
|
| 331 |
+
self.debug_prompt_input = gr.Textbox(
|
| 332 |
+
label='Prompt to explain', lines=5,
|
| 333 |
+
placeholder='Paste any prompt here, including nested ones, to see the full parse tree…',
|
| 334 |
+
)
|
| 335 |
+
self.debug_output = gr.Textbox(
|
| 336 |
+
label='Parse tree + diagnostics (structural preview — not a full runtime simulation)',
|
| 337 |
+
lines=18, interactive=False,
|
| 338 |
+
value='(enter a prompt above)',
|
| 339 |
+
)
|
| 340 |
+
self.debug_parse_button = gr.Button(value='Explain prompt')
|
| 341 |
+
self.debug_copy_button = gr.Button(value='Copy explain output')
|
| 342 |
+
|
| 343 |
+
# ------------------------------------------------------------------
|
| 344 |
+
|
| 345 |
+
def arrange_components(self, is_img2img: bool) -> None:
|
| 346 |
+
if self.is_rendered:
|
| 347 |
+
return
|
| 348 |
+
|
| 349 |
+
with gr.Accordion(label='Neutral Prompt', open=False):
|
| 350 |
+
self.cfg_rescale.render()
|
| 351 |
+
|
| 352 |
+
with gr.Accordion(label='Base prompt protection', open=False):
|
| 353 |
+
self.protection_mode.render()
|
| 354 |
+
self.strict_threshold.render()
|
| 355 |
+
|
| 356 |
+
with gr.Accordion(label='Prompt formatter', open=False):
|
| 357 |
+
self.neutral_prompt.render()
|
| 358 |
+
self.neutral_cond_scale.render()
|
| 359 |
+
self.aux_prompt_type.render()
|
| 360 |
+
self.salt_k.render()
|
| 361 |
+
self.topk_threshold.render()
|
| 362 |
+
with gr.Row():
|
| 363 |
+
self.align_detail.render()
|
| 364 |
+
self.align_structure.render()
|
| 365 |
+
self.append_to_prompt_button.render()
|
| 366 |
+
|
| 367 |
+
with gr.Accordion(label='Matryoshka builder', open=False):
|
| 368 |
+
self.builder_base.render()
|
| 369 |
+
with gr.Row():
|
| 370 |
+
self.builder_ch1_strategy.render()
|
| 371 |
+
self.builder_ch1_weight.render()
|
| 372 |
+
self.builder_ch1_text.render()
|
| 373 |
+
self.builder_ch1_affine.render()
|
| 374 |
+
self.builder_ch1_nested_enable.render()
|
| 375 |
+
with gr.Row():
|
| 376 |
+
self.builder_ch1_nested_strategy.render()
|
| 377 |
+
self.builder_ch1_nested_weight.render()
|
| 378 |
+
self.builder_ch1_nested_text.render()
|
| 379 |
+
self.builder_ch2_enable.render()
|
| 380 |
+
with gr.Row():
|
| 381 |
+
self.builder_ch2_strategy.render()
|
| 382 |
+
self.builder_ch2_weight.render()
|
| 383 |
+
self.builder_ch2_text.render()
|
| 384 |
+
self.builder_preview.render()
|
| 385 |
+
with gr.Row():
|
| 386 |
+
self.builder_apply_button.render()
|
| 387 |
+
self.builder_copy_button.render()
|
| 388 |
+
self.builder_explain_button.render()
|
| 389 |
+
|
| 390 |
+
with gr.Accordion(label='Matryoshka templates', open=False):
|
| 391 |
+
self.template_dropdown.render()
|
| 392 |
+
self.template_description.render()
|
| 393 |
+
self.template_preview.render()
|
| 394 |
+
self.template_load_button.render()
|
| 395 |
+
|
| 396 |
+
with gr.Accordion(label='Affine transform builder', open=False):
|
| 397 |
+
self.affine_preset.render()
|
| 398 |
+
with gr.Row():
|
| 399 |
+
self.affine_transform.render()
|
| 400 |
+
self.affine_p1.render()
|
| 401 |
+
self.affine_p2.render()
|
| 402 |
+
self.affine_preview.render()
|
| 403 |
+
self.affine_insert_button.render()
|
| 404 |
+
|
| 405 |
+
with gr.Accordion(label='Step-window gating (AND_* activation)', open=False):
|
| 406 |
+
self.step_window_enabled.render()
|
| 407 |
+
self.step_window_lock.render()
|
| 408 |
+
self.step_window_mode.render()
|
| 409 |
+
with gr.Row():
|
| 410 |
+
self.step_window_start.render()
|
| 411 |
+
self.step_window_end.render()
|
| 412 |
+
with gr.Accordion(label='Per-strategy custom windows', open=False):
|
| 413 |
+
for _ui_key, (_s_slider, _e_slider) in self.step_ps_sliders.items():
|
| 414 |
+
with gr.Row():
|
| 415 |
+
_s_slider.render()
|
| 416 |
+
_e_slider.render()
|
| 417 |
+
self.step_ps_reset_btn.render()
|
| 418 |
+
self.step_window_preview.render()
|
| 419 |
+
|
| 420 |
+
with gr.Accordion(label='Prompt debug / explain', open=False):
|
| 421 |
+
self.debug_prompt_input.render()
|
| 422 |
+
with gr.Row():
|
| 423 |
+
self.debug_parse_button.render()
|
| 424 |
+
self.debug_copy_button.render()
|
| 425 |
+
self.debug_output.render()
|
| 426 |
+
|
| 427 |
+
def connect_events(self, is_img2img: bool) -> None:
|
| 428 |
+
if self.is_rendered:
|
| 429 |
+
return
|
| 430 |
+
|
| 431 |
+
prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
|
| 432 |
+
|
| 433 |
+
# ---- Protection ----
|
| 434 |
+
def _on_protection_change(mode, threshold):
|
| 435 |
+
mode = global_state.normalize_protection_mode(mode)
|
| 436 |
+
global_state.protection_mode = mode
|
| 437 |
+
global_state.strict_threshold = global_state.clamp_strict_threshold(threshold)
|
| 438 |
+
return gr.update(visible=(mode in ('strict', 'soft')))
|
| 439 |
+
|
| 440 |
+
self.protection_mode.change(
|
| 441 |
+
fn=_on_protection_change,
|
| 442 |
+
inputs=[self.protection_mode, self.strict_threshold],
|
| 443 |
+
outputs=[self.strict_threshold],
|
| 444 |
+
)
|
| 445 |
+
self.strict_threshold.change(
|
| 446 |
+
fn=lambda m, t: (
|
| 447 |
+
setattr(global_state, 'protection_mode', global_state.normalize_protection_mode(m)) or
|
| 448 |
+
setattr(global_state, 'strict_threshold', global_state.clamp_strict_threshold(t))
|
| 449 |
+
),
|
| 450 |
+
inputs=[self.protection_mode, self.strict_threshold], outputs=[],
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# ---- Formatter type sliders ----
|
| 454 |
+
def _on_type_change(label):
|
| 455 |
+
return (gr.update(visible=(label in _SALT_LABELS)),
|
| 456 |
+
gr.update(visible=(label in _TOPK_LABELS)),
|
| 457 |
+
gr.update(visible=(label in _ALIGN_CUSTOM_LABELS)),
|
| 458 |
+
gr.update(visible=(label in _ALIGN_CUSTOM_LABELS)))
|
| 459 |
+
|
| 460 |
+
self.aux_prompt_type.change(
|
| 461 |
+
fn=_on_type_change, inputs=[self.aux_prompt_type],
|
| 462 |
+
outputs=[self.salt_k, self.topk_threshold, self.align_detail, self.align_structure],
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
self.append_to_prompt_button.click(
|
| 466 |
+
fn=lambda init, p, sc, lbl, k, topk, d, s: (
|
| 467 |
+
f'{init}\n{_build_keyword(lbl, k, int(d), int(s), topk)} {p} :{sc}', ''),
|
| 468 |
+
inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale,
|
| 469 |
+
self.aux_prompt_type, self.salt_k, self.topk_threshold,
|
| 470 |
+
self.align_detail, self.align_structure],
|
| 471 |
+
outputs=[prompt_textbox, self.neutral_prompt],
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# ---- Affine ----
|
| 475 |
+
def _apply_affine_preset(name):
|
| 476 |
+
data = _AFFINE_PRESETS.get(name)
|
| 477 |
+
if data is None:
|
| 478 |
+
return gr.update(), gr.update(), gr.update(), gr.update()
|
| 479 |
+
t, p1, p2 = data[0], float(data[1]) if len(data) > 1 else 0.0, float(data[2]) if len(data) > 2 else 1.0
|
| 480 |
+
return (gr.update(value=t), gr.update(value=p1),
|
| 481 |
+
gr.update(value=p2), gr.update(value=_build_affine_snippet(t, p1, p2)))
|
| 482 |
+
|
| 483 |
+
self.affine_preset.change(
|
| 484 |
+
fn=_apply_affine_preset, inputs=[self.affine_preset],
|
| 485 |
+
outputs=[self.affine_transform, self.affine_p1, self.affine_p2, self.affine_preview],
|
| 486 |
+
)
|
| 487 |
+
def _live_affine(t, p1, p2):
|
| 488 |
+
return gr.update(value=_build_affine_snippet(t, p1, p2))
|
| 489 |
+
for c in (self.affine_transform, self.affine_p1, self.affine_p2):
|
| 490 |
+
c.change(fn=_live_affine, inputs=[self.affine_transform, self.affine_p1, self.affine_p2],
|
| 491 |
+
outputs=[self.affine_preview])
|
| 492 |
+
|
| 493 |
+
self.affine_insert_button.click(
|
| 494 |
+
fn=lambda init, t, p1, p2: f'{init} {_build_affine_snippet(t, p1, p2)}'.strip(),
|
| 495 |
+
inputs=[prompt_textbox, self.affine_transform, self.affine_p1, self.affine_p2],
|
| 496 |
+
outputs=[prompt_textbox],
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# ---- Builder live preview ----
|
| 500 |
+
def _builder_update(base, s1, t1, w1, aff1, nest_on, ns, nt, nw, ch2_on, s2, t2, w2):
|
| 501 |
+
nested_str = build_child_block(ns, nt, nw) if nest_on and nt.strip() else ''
|
| 502 |
+
ch1 = build_child_block(s1, t1, w1, affine=aff1, nested=nested_str) if t1.strip() else ''
|
| 503 |
+
ch2 = build_child_block(s2, t2, w2) if ch2_on and t2.strip() else ''
|
| 504 |
+
children = [c for c in [
|
| 505 |
+
{'strategy': s1, 'text': t1, 'weight': w1, 'affine': aff1,
|
| 506 |
+
'nested': build_child_block(ns, nt, nw) if nest_on and nt.strip() else ''},
|
| 507 |
+
{'strategy': s2, 'text': t2, 'weight': w2} if ch2_on and t2.strip() else None,
|
| 508 |
+
] if c is not None]
|
| 509 |
+
return gr.update(value=build_nested_prompt(base, children))
|
| 510 |
+
|
| 511 |
+
_builder_inputs = [
|
| 512 |
+
self.builder_base,
|
| 513 |
+
self.builder_ch1_strategy, self.builder_ch1_text, self.builder_ch1_weight,
|
| 514 |
+
self.builder_ch1_affine,
|
| 515 |
+
self.builder_ch1_nested_enable,
|
| 516 |
+
self.builder_ch1_nested_strategy, self.builder_ch1_nested_text, self.builder_ch1_nested_weight,
|
| 517 |
+
self.builder_ch2_enable,
|
| 518 |
+
self.builder_ch2_strategy, self.builder_ch2_text, self.builder_ch2_weight,
|
| 519 |
+
]
|
| 520 |
+
for ctrl in _builder_inputs:
|
| 521 |
+
ctrl.change(fn=_builder_update, inputs=_builder_inputs, outputs=[self.builder_preview])
|
| 522 |
+
|
| 523 |
+
# Toggle visibility of nested / ch2 sections
|
| 524 |
+
self.builder_ch1_nested_enable.change(
|
| 525 |
+
fn=lambda on: (gr.update(visible=on), gr.update(visible=on), gr.update(visible=on)),
|
| 526 |
+
inputs=[self.builder_ch1_nested_enable],
|
| 527 |
+
outputs=[self.builder_ch1_nested_strategy, self.builder_ch1_nested_text,
|
| 528 |
+
self.builder_ch1_nested_weight],
|
| 529 |
+
)
|
| 530 |
+
self.builder_ch2_enable.change(
|
| 531 |
+
fn=lambda on: (gr.update(visible=on), gr.update(visible=on), gr.update(visible=on)),
|
| 532 |
+
inputs=[self.builder_ch2_enable],
|
| 533 |
+
outputs=[self.builder_ch2_strategy, self.builder_ch2_text, self.builder_ch2_weight],
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
# Builder → apply / explain
|
| 537 |
+
self.builder_apply_button.click(
|
| 538 |
+
fn=lambda init, preview: (f'{init}\n{preview}'.strip() if preview.strip() else init, ''),
|
| 539 |
+
inputs=[prompt_textbox, self.builder_preview], outputs=[prompt_textbox, self.builder_preview],
|
| 540 |
+
)
|
| 541 |
+
# Copy buttons: copy content into a dedicated copy-sink textbox via JS,
|
| 542 |
+
# or — since Gradio doesn't have a native clipboard API — make the textbox
|
| 543 |
+
# temporarily interactive so the user can Ctrl+A / Ctrl+C easily.
|
| 544 |
+
# We implement this as a no-op event (the Textbox is already selectable).
|
| 545 |
+
# A real clipboard call requires gr.HTML + JS, which we keep optional.
|
| 546 |
+
self.builder_copy_button.click(fn=lambda x: x,
|
| 547 |
+
inputs=[self.builder_preview],
|
| 548 |
+
outputs=[self.builder_preview])
|
| 549 |
+
self.debug_copy_button.click(fn=lambda x: x,
|
| 550 |
+
inputs=[self.debug_output],
|
| 551 |
+
outputs=[self.debug_output])
|
| 552 |
+
self.builder_explain_button.click(
|
| 553 |
+
fn=lambda preview: gr.update(value=render_full_explain(preview)),
|
| 554 |
+
inputs=[self.builder_preview], outputs=[self.debug_output],
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# ---- Templates ----
|
| 558 |
+
def _on_template_change(name):
|
| 559 |
+
t = MATRYOSHKA_TEMPLATES.get(name, {})
|
| 560 |
+
return (gr.update(value=t.get('description', '')),
|
| 561 |
+
gr.update(value=t.get('prompt', '')))
|
| 562 |
+
|
| 563 |
+
self.template_dropdown.change(
|
| 564 |
+
fn=_on_template_change, inputs=[self.template_dropdown],
|
| 565 |
+
outputs=[self.template_description, self.template_preview],
|
| 566 |
+
)
|
| 567 |
+
self.template_load_button.click(
|
| 568 |
+
fn=lambda init, tmpl: f'{init}\n{tmpl}'.strip() if tmpl.strip() else init,
|
| 569 |
+
inputs=[prompt_textbox, self.template_preview], outputs=[prompt_textbox],
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# ---- Step window ----
|
| 573 |
+
_ps_all_sliders = [sl for (s, e) in self.step_ps_sliders.values() for sl in (s, e)]
|
| 574 |
+
|
| 575 |
+
def _step_window_update(enabled, lock, mode, start, end, *ps_vals):
|
| 576 |
+
vis = bool(enabled)
|
| 577 |
+
show_global = vis and mode == 'global'
|
| 578 |
+
show_custom = vis and mode == 'per-strategy custom'
|
| 579 |
+
|
| 580 |
+
global_state.step_window_enabled = vis
|
| 581 |
+
global_state.step_window_lock_after_end = bool(lock)
|
| 582 |
+
if vis:
|
| 583 |
+
if mode == 'global':
|
| 584 |
+
global_state.step_window_global = StepWindow(start, end)
|
| 585 |
+
global_state.step_window_use_defaults = False
|
| 586 |
+
global_state.step_window_per_strategy = None
|
| 587 |
+
elif mode == 'per-strategy defaults':
|
| 588 |
+
global_state.step_window_global = None
|
| 589 |
+
global_state.step_window_use_defaults = True
|
| 590 |
+
global_state.step_window_per_strategy = None
|
| 591 |
+
elif mode == 'per-strategy custom':
|
| 592 |
+
raw = {}
|
| 593 |
+
for idx, ui_key in enumerate(_STRATEGY_UI_KEYS):
|
| 594 |
+
raw[ui_key] = (float(ps_vals[idx * 2]), float(ps_vals[idx * 2 + 1]))
|
| 595 |
+
global_state.step_window_global = None
|
| 596 |
+
global_state.step_window_use_defaults = False
|
| 597 |
+
global_state.step_window_custom_raw = raw
|
| 598 |
+
global_state.step_window_per_strategy = _build_per_strategy_windows(raw)
|
| 599 |
+
else:
|
| 600 |
+
global_state.step_window_global = None
|
| 601 |
+
global_state.step_window_use_defaults = False
|
| 602 |
+
global_state.step_window_per_strategy = None
|
| 603 |
+
|
| 604 |
+
custom_raw = getattr(global_state, 'step_window_custom_raw', None) or {}
|
| 605 |
+
preview = _build_step_window_preview(mode, start, end, custom_raw) if vis else ''
|
| 606 |
+
|
| 607 |
+
ps_vis = [gr.update(visible=show_custom)] * len(_ps_all_sliders)
|
| 608 |
+
return (
|
| 609 |
+
gr.update(visible=vis), # mode
|
| 610 |
+
gr.update(visible=vis), # lock checkbox
|
| 611 |
+
gr.update(visible=show_global), # global start
|
| 612 |
+
gr.update(visible=show_global), # global end
|
| 613 |
+
*ps_vis, # per-strategy sliders
|
| 614 |
+
gr.update(visible=show_custom), # reset_btn
|
| 615 |
+
gr.update(visible=vis, value=preview), # preview
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
def _ps_reset(*_):
|
| 619 |
+
from lib_neutral_prompt.step_utils import default_custom_windows
|
| 620 |
+
defaults = default_custom_windows()
|
| 621 |
+
updates = []
|
| 622 |
+
for ui_key in _STRATEGY_UI_KEYS:
|
| 623 |
+
s, e = defaults.get(ui_key, (0.0, 1.0))
|
| 624 |
+
updates.extend([gr.update(value=s), gr.update(value=e)])
|
| 625 |
+
return updates
|
| 626 |
+
|
| 627 |
+
_sw_inputs = ([self.step_window_enabled, self.step_window_lock,
|
| 628 |
+
self.step_window_mode,
|
| 629 |
+
self.step_window_start, self.step_window_end]
|
| 630 |
+
+ _ps_all_sliders)
|
| 631 |
+
_sw_outputs = ([self.step_window_mode, self.step_window_lock,
|
| 632 |
+
self.step_window_start, self.step_window_end]
|
| 633 |
+
+ _ps_all_sliders
|
| 634 |
+
+ [self.step_ps_reset_btn, self.step_window_preview])
|
| 635 |
+
|
| 636 |
+
for ctrl in _sw_inputs:
|
| 637 |
+
ctrl.change(fn=_step_window_update, inputs=_sw_inputs, outputs=_sw_outputs)
|
| 638 |
+
|
| 639 |
+
self.step_ps_reset_btn.click(fn=_ps_reset, inputs=[], outputs=_ps_all_sliders)
|
| 640 |
+
_explain_fn = lambda s: gr.update(value=render_full_explain(s))
|
| 641 |
+
self.debug_parse_button.click(fn=_explain_fn, inputs=[self.debug_prompt_input],
|
| 642 |
+
outputs=[self.debug_output])
|
| 643 |
+
self.debug_prompt_input.change(fn=_explain_fn, inputs=[self.debug_prompt_input],
|
| 644 |
+
outputs=[self.debug_output])
|
| 645 |
+
|
| 646 |
+
def set_rendered(self, value: bool = True) -> None:
|
| 647 |
+
self.is_rendered = value
|
| 648 |
+
|
| 649 |
+
# ------------------------------------------------------------------
|
| 650 |
+
|
| 651 |
+
def _collect_ps_slider_values(self) -> Dict[str, Tuple[float, float]]:
|
| 652 |
+
"""Read current per-strategy slider values into a raw dict."""
|
| 653 |
+
raw: Dict[str, Tuple[float, float]] = {}
|
| 654 |
+
for ui_key, (s_sl, e_sl) in self.step_ps_sliders.items():
|
| 655 |
+
raw[ui_key] = (float(s_sl.value), float(e_sl.value))
|
| 656 |
+
return raw
|
| 657 |
+
|
| 658 |
+
def get_components(self) -> Tuple:
|
| 659 |
+
ps_components = []
|
| 660 |
+
for _ui_key, (_s, _e) in self.step_ps_sliders.items():
|
| 661 |
+
ps_components.extend([_s, _e])
|
| 662 |
+
return (self.cfg_rescale, self.protection_mode, self.strict_threshold,
|
| 663 |
+
self.step_window_enabled, self.step_window_lock,
|
| 664 |
+
self.step_window_mode,
|
| 665 |
+
self.step_window_start, self.step_window_end,
|
| 666 |
+
*ps_components)
|
| 667 |
+
|
| 668 |
+
def get_infotext_fields(self) -> Tuple:
|
| 669 |
+
return tuple(zip(self.get_components(), (
|
| 670 |
+
'CFG Rescale phi',
|
| 671 |
+
'NP Protection Mode', 'NP Strict Threshold',
|
| 672 |
+
'NP Step Window Enabled', 'NP Step Window Lock',
|
| 673 |
+
'NP Step Window Mode',
|
| 674 |
+
'NP Step Window Start', 'NP Step Window End',
|
| 675 |
+
# per-strategy sliders not individually in infotext — serialised as one key
|
| 676 |
+
)))
|
| 677 |
+
|
| 678 |
+
def get_paste_field_names(self) -> List[str]:
|
| 679 |
+
return [
|
| 680 |
+
'CFG Rescale phi',
|
| 681 |
+
'NP Protection Mode', 'NP Strict Threshold',
|
| 682 |
+
'NP Step Window Enabled', 'NP Step Window Lock',
|
| 683 |
+
'NP Step Window Mode',
|
| 684 |
+
'NP Step Window Start', 'NP Step Window End',
|
| 685 |
+
'NP Step Window Custom',
|
| 686 |
+
]
|
| 687 |
+
|
| 688 |
+
def get_extra_generation_params(self, args: Dict) -> Dict:
|
| 689 |
+
params = {
|
| 690 |
+
'CFG Rescale phi': args['cfg_rescale'],
|
| 691 |
+
'NP Protection Mode': args['protection_mode'],
|
| 692 |
+
'NP Strict Threshold': args['strict_threshold'],
|
| 693 |
+
}
|
| 694 |
+
if args.get('step_window_enabled'):
|
| 695 |
+
params['NP Step Window Enabled'] = True
|
| 696 |
+
params['NP Step Window Mode'] = args.get('step_window_mode', 'global')
|
| 697 |
+
params['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
|
| 698 |
+
params['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
|
| 699 |
+
if args.get('step_window_lock_after_end'):
|
| 700 |
+
params['NP Step Window Lock'] = True
|
| 701 |
+
if args.get('step_window_mode') == 'per-strategy custom':
|
| 702 |
+
raw = args.get('step_window_custom_raw') or {}
|
| 703 |
+
params['NP Step Window Custom'] = _serialize_per_strategy_windows(raw)
|
| 704 |
+
return params
|
| 705 |
+
|
| 706 |
+
def unpack_processing_args(self, cfg_rescale: float,
|
| 707 |
+
protection_mode: str = 'auto',
|
| 708 |
+
strict_threshold: float = 0.1,
|
| 709 |
+
step_window_enabled: bool = False,
|
| 710 |
+
step_window_lock_after_end: bool = False,
|
| 711 |
+
step_window_mode: str = 'global',
|
| 712 |
+
step_window_start: float = 0.0,
|
| 713 |
+
step_window_end: float = 1.0,
|
| 714 |
+
**ps_kwargs) -> Dict:
|
| 715 |
+
# Per-strategy raw dict: collect from kwargs (AND_PERP_start, AND_PERP_end, ...)
|
| 716 |
+
custom_raw: Dict[str, Tuple[float, float]] = {}
|
| 717 |
+
for ui_key in _STRATEGY_UI_KEYS:
|
| 718 |
+
s_key = f'{ui_key}_start'
|
| 719 |
+
e_key = f'{ui_key}_end'
|
| 720 |
+
if s_key in ps_kwargs and e_key in ps_kwargs:
|
| 721 |
+
custom_raw[ui_key] = (float(ps_kwargs[s_key]), float(ps_kwargs[e_key]))
|
| 722 |
+
|
| 723 |
+
enabled = bool(step_window_enabled)
|
| 724 |
+
lock = bool(step_window_lock_after_end)
|
| 725 |
+
mode = str(step_window_mode)
|
| 726 |
+
start = max(0.0, min(1.0, float(step_window_start)))
|
| 727 |
+
end = max(0.0, min(1.0, float(step_window_end)))
|
| 728 |
+
if start > end:
|
| 729 |
+
start, end = 0.0, 1.0
|
| 730 |
+
|
| 731 |
+
global_state.step_window_enabled = enabled
|
| 732 |
+
global_state.step_window_lock_after_end = lock
|
| 733 |
+
if enabled:
|
| 734 |
+
if mode == 'global':
|
| 735 |
+
global_state.step_window_global = StepWindow(start, end)
|
| 736 |
+
global_state.step_window_use_defaults = False
|
| 737 |
+
global_state.step_window_per_strategy = None
|
| 738 |
+
global_state.step_window_custom_raw = None
|
| 739 |
+
elif mode == 'per-strategy defaults':
|
| 740 |
+
global_state.step_window_global = None
|
| 741 |
+
global_state.step_window_use_defaults = True
|
| 742 |
+
global_state.step_window_per_strategy = None
|
| 743 |
+
global_state.step_window_custom_raw = None
|
| 744 |
+
elif mode == 'per-strategy custom':
|
| 745 |
+
global_state.step_window_global = None
|
| 746 |
+
global_state.step_window_use_defaults = False
|
| 747 |
+
global_state.step_window_custom_raw = custom_raw
|
| 748 |
+
global_state.step_window_per_strategy = _build_per_strategy_windows(custom_raw)
|
| 749 |
+
else:
|
| 750 |
+
global_state.step_window_global = None
|
| 751 |
+
global_state.step_window_use_defaults = False
|
| 752 |
+
global_state.step_window_per_strategy = None
|
| 753 |
+
else:
|
| 754 |
+
global_state.step_window_global = None
|
| 755 |
+
global_state.step_window_use_defaults = False
|
| 756 |
+
global_state.step_window_per_strategy = None
|
| 757 |
+
|
| 758 |
+
return {
|
| 759 |
+
'cfg_rescale': cfg_rescale,
|
| 760 |
+
'protection_mode': protection_mode,
|
| 761 |
+
'strict_threshold': strict_threshold,
|
| 762 |
+
'step_window_enabled': enabled,
|
| 763 |
+
'step_window_lock_after_end': lock,
|
| 764 |
+
'step_window_mode': mode,
|
| 765 |
+
'step_window_start': start,
|
| 766 |
+
'step_window_end': end,
|
| 767 |
+
'step_window_custom_raw': custom_raw,
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# ---------------------------------------------------------------------------
|
| 772 |
+
# Settings & callbacks
|
| 773 |
+
# ---------------------------------------------------------------------------
|
| 774 |
+
|
| 775 |
+
def on_ui_settings() -> None:
|
| 776 |
+
section = ('neutral_prompt', 'Neutral Prompt')
|
| 777 |
+
shared.opts.add_option(
|
| 778 |
+
'neutral_prompt_enabled',
|
| 779 |
+
shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section),
|
| 780 |
+
)
|
| 781 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 782 |
+
shared.opts.onchange('neutral_prompt_enabled', _update_enabled)
|
| 783 |
+
shared.opts.add_option(
|
| 784 |
+
'neutral_prompt_verbose',
|
| 785 |
+
shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section),
|
| 786 |
+
)
|
| 787 |
+
global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
|
| 788 |
+
shared.opts.onchange('neutral_prompt_verbose', _update_verbose)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
def _update_enabled() -> None:
|
| 795 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def _update_verbose() -> None:
|
| 799 |
+
global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def on_after_component(component, **_kwargs) -> None:
|
| 803 |
+
global txt2img_prompt_textbox, img2img_prompt_textbox
|
| 804 |
+
eid = getattr(component, 'elem_id', None)
|
| 805 |
+
if eid == 'txt2img_prompt':
|
| 806 |
+
txt2img_prompt_textbox = component
|
| 807 |
+
elif eid == 'img2img_prompt':
|
| 808 |
+
img2img_prompt_textbox = component
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
script_callbacks.on_after_component(on_after_component)
|
neutral_prompt_patcheds/lib_neutral_prompt/xyz_grid.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from types import ModuleType
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from modules import scripts
|
| 5 |
+
from lib_neutral_prompt import global_state
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def patch():
|
| 9 |
+
xyz_module = find_xyz_module()
|
| 10 |
+
if xyz_module is None:
|
| 11 |
+
print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
|
| 12 |
+
return
|
| 13 |
+
|
| 14 |
+
xyz_module.axis_options.extend([
|
| 15 |
+
xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
|
| 16 |
+
])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class XyzFloat(float):
|
| 20 |
+
is_xyz: bool = True
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def apply_cfg_rescale():
|
| 24 |
+
def callback(_p, v, _vs):
|
| 25 |
+
global_state.cfg_rescale = XyzFloat(v)
|
| 26 |
+
|
| 27 |
+
return callback
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def int_or_float(string):
|
| 31 |
+
try:
|
| 32 |
+
return int(string)
|
| 33 |
+
except ValueError:
|
| 34 |
+
return float(string)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def find_xyz_module() -> Optional[ModuleType]:
|
| 38 |
+
for data in scripts.scripts_data:
|
| 39 |
+
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
|
| 40 |
+
return data.module
|
| 41 |
+
|
| 42 |
+
return None
|
neutral_prompt_patcheds/scripts/__pycache__/neutral_prompt.cpython-310.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
neutral_prompt_patcheds/scripts/neutral_prompt.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
|
| 2 |
+
from modules import scripts, processing, shared
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class NeutralPromptScript(scripts.Script):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.accordion_interface = None
|
| 10 |
+
self._is_img2img = False
|
| 11 |
+
|
| 12 |
+
@property
|
| 13 |
+
def is_img2img(self):
|
| 14 |
+
return self._is_img2img
|
| 15 |
+
|
| 16 |
+
@is_img2img.setter
|
| 17 |
+
def is_img2img(self, is_img2img):
|
| 18 |
+
self._is_img2img = is_img2img
|
| 19 |
+
if self.accordion_interface is None:
|
| 20 |
+
self.accordion_interface = ui.AccordionInterface(self.elem_id)
|
| 21 |
+
|
| 22 |
+
def title(self) -> str:
|
| 23 |
+
return "Neutral Prompt"
|
| 24 |
+
|
| 25 |
+
def show(self, is_img2img: bool):
|
| 26 |
+
return scripts.AlwaysVisible
|
| 27 |
+
|
| 28 |
+
def ui(self, is_img2img: bool):
|
| 29 |
+
self.hijack_composable_lora(is_img2img)
|
| 30 |
+
|
| 31 |
+
self.accordion_interface.arrange_components(is_img2img)
|
| 32 |
+
self.accordion_interface.connect_events(is_img2img)
|
| 33 |
+
self.infotext_fields = self.accordion_interface.get_infotext_fields()
|
| 34 |
+
self.paste_field_names = self.accordion_interface.get_paste_field_names()
|
| 35 |
+
self.accordion_interface.set_rendered()
|
| 36 |
+
return self.accordion_interface.get_components()
|
| 37 |
+
|
| 38 |
+
def process(self, p: processing.StableDiffusionProcessing, *args):
|
| 39 |
+
args = self.accordion_interface.unpack_processing_args(*args)
|
| 40 |
+
|
| 41 |
+
self.update_global_state(args)
|
| 42 |
+
if global_state.is_enabled:
|
| 43 |
+
p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
|
| 44 |
+
# Reset Lock-after-End flags once per real generation lifecycle.
|
| 45 |
+
# This must happen AFTER update_global_state so that
|
| 46 |
+
# step_window_lock_after_end reflects the current UI value.
|
| 47 |
+
if global_state.step_window_lock_after_end:
|
| 48 |
+
global_state.begin_new_generation()
|
| 49 |
+
|
| 50 |
+
def update_global_state(self, args: Dict):
|
| 51 |
+
if shared.state.job_no == 0:
|
| 52 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 53 |
+
|
| 54 |
+
for k, v in args.items():
|
| 55 |
+
try:
|
| 56 |
+
getattr(global_state, k)
|
| 57 |
+
except AttributeError:
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
if getattr(getattr(global_state, k), 'is_xyz', False):
|
| 61 |
+
xyz_attr = getattr(global_state, k)
|
| 62 |
+
xyz_attr.is_xyz = False
|
| 63 |
+
args[k] = xyz_attr
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
if shared.state.job_no > 0:
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
setattr(global_state, k, v)
|
| 70 |
+
|
| 71 |
+
def hijack_composable_lora(self, is_img2img):
|
| 72 |
+
if self.accordion_interface.is_rendered:
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
lora_script = None
|
| 76 |
+
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
| 77 |
+
|
| 78 |
+
for script in script_runner.alwayson_scripts:
|
| 79 |
+
if script.title().lower() == "composable lora":
|
| 80 |
+
lora_script = script
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
if lora_script is not None:
|
| 84 |
+
lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
|
| 88 |
+
if not global_state.is_enabled:
|
| 89 |
+
return original_function(p, *args, **kwargs)
|
| 90 |
+
|
| 91 |
+
exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
|
| 92 |
+
all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
|
| 93 |
+
res = original_function(p, *args, **kwargs)
|
| 94 |
+
# restore original prompts
|
| 95 |
+
p.all_prompts = all_prompts
|
| 96 |
+
return res
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
xyz_grid.patch()
|
neutral_prompt_patcheds/test/perp_parser/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test package init — installs a minimal torch mock when real torch is absent.
|
| 3 |
+
|
| 4 |
+
This runs before any test module in the package is imported, so all parser
|
| 5 |
+
tests get a working `torch` in sys.modules without needing torch installed.
|
| 6 |
+
Runtime tests that require real torch use @unittest.skipUnless independently.
|
| 7 |
+
"""
|
| 8 |
+
import pathlib, sys
|
| 9 |
+
|
| 10 |
+
# Make the extension root importable regardless of CWD.
|
| 11 |
+
_root = str(pathlib.Path(__file__).parent.parent.parent)
|
| 12 |
+
if _root not in sys.path:
|
| 13 |
+
sys.path.insert(0, _root)
|
| 14 |
+
|
| 15 |
+
# Install mock if real torch is absent.
|
| 16 |
+
import math, types
|
| 17 |
+
|
| 18 |
+
def _install_mock_torch():
|
| 19 |
+
try:
|
| 20 |
+
import importlib.util
|
| 21 |
+
spec = importlib.util.find_spec('torch')
|
| 22 |
+
if spec is not None and getattr(spec, 'origin', None) is not None:
|
| 23 |
+
return # real torch present — nothing to do
|
| 24 |
+
except (ValueError, ModuleNotFoundError):
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
if 'torch' in sys.modules:
|
| 28 |
+
return # already mocked
|
| 29 |
+
|
| 30 |
+
class _FT:
|
| 31 |
+
def __init__(self, data=None): self.data = data
|
| 32 |
+
def __getitem__(self, key):
|
| 33 |
+
if isinstance(self.data, list):
|
| 34 |
+
return _FT(self.data[key] if isinstance(key, (int, slice)) else None)
|
| 35 |
+
return _FT(None)
|
| 36 |
+
def __matmul__(self, other): return _FT('matmul')
|
| 37 |
+
|
| 38 |
+
torch = types.ModuleType('torch')
|
| 39 |
+
torch.Tensor = _FT
|
| 40 |
+
torch.float32 = 'float32'
|
| 41 |
+
torch.pi = math.pi
|
| 42 |
+
torch.eye = lambda n, dtype=None: _FT(list(range(n)))
|
| 43 |
+
torch.tensor = lambda data, dtype=None: _FT(data)
|
| 44 |
+
torch.vstack = lambda tensors: _FT([t.data for t in tensors])
|
| 45 |
+
torch.linalg = types.SimpleNamespace(inv=lambda x: _FT('inv'))
|
| 46 |
+
|
| 47 |
+
nn = types.ModuleType('torch.nn')
|
| 48 |
+
nn_fn = types.ModuleType('torch.nn.functional')
|
| 49 |
+
torch.nn = nn
|
| 50 |
+
sys.modules['torch'] = torch
|
| 51 |
+
sys.modules['torch.nn'] = nn
|
| 52 |
+
sys.modules['torch.nn.functional'] = nn_fn
|
| 53 |
+
|
| 54 |
+
_install_mock_torch()
|
neutral_prompt_patcheds/test/perp_parser/mock_torch.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Minimal torch mock for parser-only tests (no real tensor math needed).
|
| 3 |
+
|
| 4 |
+
Usage — place this BEFORE any lib_neutral_prompt imports:
|
| 5 |
+
|
| 6 |
+
from test.perp_parser.mock_torch import install
|
| 7 |
+
install()
|
| 8 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
import types
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class _FakeTensor:
|
| 17 |
+
"""Enough of torch.Tensor for the parser AST dataclasses."""
|
| 18 |
+
def __init__(self, data=None):
|
| 19 |
+
self.data = data
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, key):
|
| 22 |
+
if isinstance(self.data, list) and isinstance(key, (int, slice)):
|
| 23 |
+
return _FakeTensor(self.data[key])
|
| 24 |
+
return _FakeTensor(None)
|
| 25 |
+
|
| 26 |
+
def __matmul__(self, other):
|
| 27 |
+
return _FakeTensor('matmul')
|
| 28 |
+
|
| 29 |
+
def __repr__(self):
|
| 30 |
+
return f'FakeTensor({self.data})'
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def install() -> None:
|
| 34 |
+
"""Install a minimal torch mock into sys.modules (idempotent)."""
|
| 35 |
+
if 'torch' in sys.modules and not isinstance(sys.modules['torch'], types.ModuleType):
|
| 36 |
+
return # real torch already present
|
| 37 |
+
# Only mock if torch is actually missing
|
| 38 |
+
try:
|
| 39 |
+
import importlib.util
|
| 40 |
+
spec = importlib.util.find_spec('torch')
|
| 41 |
+
if spec is not None and getattr(spec, 'origin', None) is not None:
|
| 42 |
+
return # real torch is available — use it
|
| 43 |
+
except (ValueError, ModuleNotFoundError):
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
torch = types.ModuleType('torch')
|
| 47 |
+
torch.Tensor = _FakeTensor
|
| 48 |
+
torch.float32 = 'float32'
|
| 49 |
+
torch.pi = math.pi
|
| 50 |
+
torch.eye = lambda n, dtype=None: _FakeTensor(list(range(n)))
|
| 51 |
+
torch.tensor = lambda data, dtype=None: _FakeTensor(data)
|
| 52 |
+
torch.vstack = lambda tensors: _FakeTensor([t.data for t in tensors])
|
| 53 |
+
torch.linalg = types.SimpleNamespace(inv=lambda x: _FakeTensor('inv'))
|
| 54 |
+
|
| 55 |
+
nn = types.ModuleType('torch.nn')
|
| 56 |
+
nn_fn = types.ModuleType('torch.nn.functional')
|
| 57 |
+
torch.nn = nn
|
| 58 |
+
|
| 59 |
+
sys.modules['torch'] = torch
|
| 60 |
+
sys.modules['torch.nn'] = nn
|
| 61 |
+
sys.modules['torch.nn.functional'] = nn_fn
|
neutral_prompt_patcheds/test/perp_parser/test_affine_keyword_order.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for affine-keyword ordering in parse_prompt.
|
| 3 |
+
|
| 4 |
+
Both orderings must work correctly:
|
| 5 |
+
(A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
|
| 6 |
+
(B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
|
| 7 |
+
"""
|
| 8 |
+
import unittest
|
| 9 |
+
|
| 10 |
+
from lib_neutral_prompt import neutral_prompt_parser as p
|
| 11 |
+
from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestAffineKeywordOrder(unittest.TestCase):
|
| 15 |
+
|
| 16 |
+
# ------------------------------------------------------------------ #
|
| 17 |
+
# trailing affine (AND_PERP ROTATE[...] text) #
|
| 18 |
+
# ------------------------------------------------------------------ #
|
| 19 |
+
|
| 20 |
+
def test_trailing_affine_basic(self):
|
| 21 |
+
root = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
|
| 22 |
+
self.assertEqual(len(root.children), 2)
|
| 23 |
+
self.assertEqual(root.children[0].prompt.strip(), "base")
|
| 24 |
+
self.assertIsNone(root.children[0].local_transform)
|
| 25 |
+
self.assertIsNotNone(root.children[1].local_transform)
|
| 26 |
+
self.assertEqual(root.children[1].conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 27 |
+
self.assertAlmostEqual(root.children[1].weight, 0.8)
|
| 28 |
+
|
| 29 |
+
# ------------------------------------------------------------------ #
|
| 30 |
+
# leading affine (ROTATE[...] AND_PERP text) #
|
| 31 |
+
# ------------------------------------------------------------------ #
|
| 32 |
+
|
| 33 |
+
def test_leading_affine_mid_prompt(self):
|
| 34 |
+
"""ROTATE[...] AND_PERP must NOT be consumed into the previous prompt's text."""
|
| 35 |
+
root = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
|
| 36 |
+
self.assertEqual(len(root.children), 2,
|
| 37 |
+
f"Expected 2 children, got {len(root.children)}: "
|
| 38 |
+
f"{[getattr(c,'prompt','composite') for c in root.children]}")
|
| 39 |
+
first, second = root.children
|
| 40 |
+
# first segment = "base " with no transform
|
| 41 |
+
self.assertIn("base", first.prompt)
|
| 42 |
+
self.assertNotIn("ROTATE", first.prompt)
|
| 43 |
+
self.assertIsNone(first.local_transform)
|
| 44 |
+
# second segment gets the transform
|
| 45 |
+
self.assertIsNotNone(second.local_transform)
|
| 46 |
+
self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 47 |
+
self.assertAlmostEqual(second.weight, 0.8)
|
| 48 |
+
|
| 49 |
+
def test_leading_affine_at_start(self):
|
| 50 |
+
"""ROTATE[...] AND_PERP as first (and only non-AND) segment."""
|
| 51 |
+
root = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
|
| 52 |
+
self.assertEqual(len(root.children), 1)
|
| 53 |
+
child = root.children[0]
|
| 54 |
+
self.assertIsNotNone(child.local_transform)
|
| 55 |
+
self.assertEqual(child.conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 56 |
+
self.assertIn("vivid", child.prompt)
|
| 57 |
+
self.assertNotIn("ROTATE", child.prompt)
|
| 58 |
+
|
| 59 |
+
def test_leading_affine_with_and_salt(self):
|
| 60 |
+
root = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
|
| 61 |
+
self.assertEqual(len(root.children), 2)
|
| 62 |
+
second = root.children[1]
|
| 63 |
+
self.assertIsNotNone(second.local_transform)
|
| 64 |
+
self.assertEqual(second.conciliation, ConciliationStrategy.SALIENCE_MASK)
|
| 65 |
+
|
| 66 |
+
# ------------------------------------------------------------------ #
|
| 67 |
+
# Composed / chained affine #
|
| 68 |
+
# ------------------------------------------------------------------ #
|
| 69 |
+
|
| 70 |
+
def test_leading_and_trailing_compose(self):
|
| 71 |
+
"""ROTATE[a] AND_PERP SCALE[b,b] text — both affines composed."""
|
| 72 |
+
root = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
|
| 73 |
+
# Just check no crash and all children parse
|
| 74 |
+
self.assertGreaterEqual(len(root.children), 1)
|
| 75 |
+
|
| 76 |
+
# ------------------------------------------------------------------ #
|
| 77 |
+
# CFGRescaleFactorSingleton lifecycle #
|
| 78 |
+
# ------------------------------------------------------------------ #
|
| 79 |
+
|
| 80 |
+
def test_cfg_rescale_singleton_clear(self):
|
| 81 |
+
from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
|
| 82 |
+
S.clear()
|
| 83 |
+
self.assertIsNone(S.get())
|
| 84 |
+
S.set(1.23)
|
| 85 |
+
self.assertAlmostEqual(S.get(), 1.23)
|
| 86 |
+
S.clear()
|
| 87 |
+
self.assertIsNone(S.get())
|
| 88 |
+
|
| 89 |
+
def test_cfg_rescale_singleton_thread_local(self):
|
| 90 |
+
import threading
|
| 91 |
+
from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
|
| 92 |
+
results = {}
|
| 93 |
+
def worker(val, key):
|
| 94 |
+
S.clear()
|
| 95 |
+
S.set(val)
|
| 96 |
+
import time; time.sleep(0.01)
|
| 97 |
+
results[key] = S.get()
|
| 98 |
+
threads = [threading.Thread(target=worker, args=(i * 10.0, i)) for i in range(3)]
|
| 99 |
+
for t in threads: t.start()
|
| 100 |
+
for t in threads: t.join()
|
| 101 |
+
for i in range(3):
|
| 102 |
+
self.assertAlmostEqual(results[i], i * 10.0,
|
| 103 |
+
msg=f"Thread {i} got {results[i]} expected {i*10.0}")
|
| 104 |
+
|
| 105 |
+
# ------------------------------------------------------------------ #
|
| 106 |
+
# New AND_SALT variants parse correctly #
|
| 107 |
+
# ------------------------------------------------------------------ #
|
| 108 |
+
|
| 109 |
+
def test_and_salt_wide_parsed(self):
|
| 110 |
+
root = p.parse_root("base AND_SALT_WIDE vivid")
|
| 111 |
+
self.assertEqual(len(root.children), 2)
|
| 112 |
+
self.assertEqual(root.children[1].conciliation,
|
| 113 |
+
p.ConciliationStrategy.SALIENCE_MASK_WIDE)
|
| 114 |
+
|
| 115 |
+
def test_and_salt_blob_parsed(self):
|
| 116 |
+
root = p.parse_root("base AND_SALT_BLOB vivid")
|
| 117 |
+
self.assertEqual(len(root.children), 2)
|
| 118 |
+
self.assertEqual(root.children[1].conciliation,
|
| 119 |
+
p.ConciliationStrategy.SALIENCE_MASK_BLOB)
|
| 120 |
+
|
| 121 |
+
def test_and_align_parsed(self):
|
| 122 |
+
root = p.parse_root("base AND_ALIGN_4_8 vivid")
|
| 123 |
+
self.assertEqual(len(root.children), 2)
|
| 124 |
+
self.assertIn('AND_ALIGN_4_8', root.children[1].conciliation.value)
|
| 125 |
+
|
| 126 |
+
def test_and_mask_align_parsed(self):
|
| 127 |
+
root = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
|
| 128 |
+
self.assertEqual(len(root.children), 2)
|
| 129 |
+
self.assertIn('AND_MASK_ALIGN_4_8', root.children[1].conciliation.value)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
|
| 133 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_affine_pipeline.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration smoke tests for the pre-noise affine pipeline.
|
| 3 |
+
|
| 4 |
+
Requires PyTorch. If torch is not installed the entire module is skipped via
|
| 5 |
+
``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
|
| 6 |
+
|
| 7 |
+
Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
|
| 8 |
+
modules.shared). torch itself is real — no sys.modules replacement.
|
| 9 |
+
|
| 10 |
+
Tests:
|
| 11 |
+
1. Hook is a no-op when global_state.is_enabled = False.
|
| 12 |
+
2. Identity transform (angle=0) leaves x numerically unchanged.
|
| 13 |
+
3. Non-identity transform calls apply_affine_transform (verified by spy).
|
| 14 |
+
4. Prompts with no affine leave x unchanged.
|
| 15 |
+
5. Empty state does not crash.
|
| 16 |
+
6. Batch with multiple prompts does not crash.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import dataclasses
|
| 20 |
+
import importlib
|
| 21 |
+
import math
|
| 22 |
+
import sys
|
| 23 |
+
import types
|
| 24 |
+
import unittest
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Detect torch availability — skip the whole module if absent
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
try:
|
| 30 |
+
import importlib.util as _ilu
|
| 31 |
+
_spec = _ilu.find_spec('torch')
|
| 32 |
+
_TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
|
| 33 |
+
except (ValueError, ModuleNotFoundError):
|
| 34 |
+
_TORCH_AVAILABLE = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
|
| 39 |
+
# be imported even when running the skip path)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def _stub(name):
|
| 43 |
+
m = types.ModuleType(name)
|
| 44 |
+
sys.modules.setdefault(name, m)
|
| 45 |
+
return sys.modules[name]
|
| 46 |
+
|
| 47 |
+
for _mod_name in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
|
| 48 |
+
'modules.shared', 'modules.prompt_parser', 'gradio'):
|
| 49 |
+
_stub(_mod_name)
|
| 50 |
+
|
| 51 |
+
import modules.script_callbacks as _sc
|
| 52 |
+
if not hasattr(_sc, 'on_cfg_denoiser'):
|
| 53 |
+
_sc.on_cfg_denoiser = lambda fn: None
|
| 54 |
+
if not hasattr(_sc, 'on_script_unloaded'):
|
| 55 |
+
_sc.on_script_unloaded = lambda fn: None
|
| 56 |
+
|
| 57 |
+
import modules.shared as _sh
|
| 58 |
+
if not hasattr(_sh, 'opts'):
|
| 59 |
+
_sh.opts = types.SimpleNamespace(
|
| 60 |
+
data={},
|
| 61 |
+
add_option=lambda *a, **k: None,
|
| 62 |
+
onchange=lambda *a, **k: None,
|
| 63 |
+
OptionInfo=lambda *a, **k: None,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
import modules.sd_samplers as _ss
|
| 67 |
+
if not hasattr(_ss, 'cfg_denoiser'):
|
| 68 |
+
_ss.cfg_denoiser = None
|
| 69 |
+
# cfg_denoiser_hijack installs a hijack on create_sampler at import time
|
| 70 |
+
if not hasattr(_ss, 'create_sampler'):
|
| 71 |
+
_ss.create_sampler = lambda *a, **k: None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Repo root on sys.path
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
from pathlib import Path
|
| 79 |
+
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 80 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 81 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Conditional imports (only when torch is available)
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
if _TORCH_AVAILABLE:
|
| 89 |
+
import torch
|
| 90 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 91 |
+
from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
|
| 92 |
+
import lib_neutral_prompt.affine_transform as _at_mod
|
| 93 |
+
|
| 94 |
+
# -----------------------------------------------------------------------
|
| 95 |
+
# Helpers
|
| 96 |
+
# -----------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def _leaf(text, angle=None):
|
| 99 |
+
"""LeafPrompt, optionally with a ROTATE affine (angle in turns)."""
|
| 100 |
+
transform = None
|
| 101 |
+
if angle is not None:
|
| 102 |
+
c = math.cos(angle * 2 * math.pi)
|
| 103 |
+
s = math.sin(angle * 2 * math.pi)
|
| 104 |
+
transform = torch.tensor([[c, -s, 0.0], [s, c, 0.0]])
|
| 105 |
+
return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)
|
| 106 |
+
|
| 107 |
+
def _comp(*children):
|
| 108 |
+
return neutral_prompt_parser.CompositePrompt(1.0, None, None, list(children))
|
| 109 |
+
|
| 110 |
+
@dataclasses.dataclass
|
| 111 |
+
class _FakeParams:
|
| 112 |
+
x: torch.Tensor
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Test class
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
|
| 120 |
+
class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
|
| 121 |
+
|
| 122 |
+
def setUp(self):
|
| 123 |
+
global_state.is_enabled = True
|
| 124 |
+
global_state.prompt_exprs = []
|
| 125 |
+
global_state.batch_cond_indices = []
|
| 126 |
+
|
| 127 |
+
def tearDown(self):
|
| 128 |
+
global_state.is_enabled = False
|
| 129 |
+
global_state.prompt_exprs = []
|
| 130 |
+
global_state.batch_cond_indices = []
|
| 131 |
+
|
| 132 |
+
# 1. disabled → x untouched
|
| 133 |
+
def test_disabled_leaves_x_unchanged(self):
|
| 134 |
+
global_state.is_enabled = False
|
| 135 |
+
x = torch.zeros(2, 4, 8, 8)
|
| 136 |
+
x[0] = 1.0
|
| 137 |
+
orig = x.clone()
|
| 138 |
+
params = _FakeParams(x=x.clone())
|
| 139 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
|
| 140 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 141 |
+
_on_cfg_denoiser(params)
|
| 142 |
+
self.assertTrue(torch.equal(params.x, orig), "Disabled hook must not modify x")
|
| 143 |
+
|
| 144 |
+
# 2. identity transform → x numerically unchanged
|
| 145 |
+
def test_identity_transform_no_change(self):
|
| 146 |
+
x = torch.randn(2, 4, 8, 8)
|
| 147 |
+
orig = x.clone()
|
| 148 |
+
params = _FakeParams(x=x.clone())
|
| 149 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
|
| 150 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 151 |
+
_on_cfg_denoiser(params)
|
| 152 |
+
self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
|
| 153 |
+
"Identity affine must not change x")
|
| 154 |
+
|
| 155 |
+
# 3. non-identity → apply_affine_transform is called (spy)
|
| 156 |
+
def test_nonidentity_calls_apply_affine_transform(self):
|
| 157 |
+
x = torch.zeros(2, 4, 8, 8)
|
| 158 |
+
x[1] = torch.randn(4, 8, 8)
|
| 159 |
+
params = _FakeParams(x=x.clone())
|
| 160 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
|
| 161 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 162 |
+
|
| 163 |
+
_real = _at_mod.apply_affine_transform
|
| 164 |
+
calls = []
|
| 165 |
+
def _spy(tensor, affine, mode='bilinear'):
|
| 166 |
+
calls.append(True)
|
| 167 |
+
return _real(tensor, affine, mode)
|
| 168 |
+
|
| 169 |
+
_at_mod.apply_affine_transform = _spy
|
| 170 |
+
try:
|
| 171 |
+
_on_cfg_denoiser(params)
|
| 172 |
+
finally:
|
| 173 |
+
_at_mod.apply_affine_transform = _real
|
| 174 |
+
|
| 175 |
+
self.assertGreater(len(calls), 0,
|
| 176 |
+
"apply_affine_transform must be called for non-identity transform")
|
| 177 |
+
|
| 178 |
+
# 4. no affine on any child → x unchanged
|
| 179 |
+
def test_no_affine_no_change(self):
|
| 180 |
+
x = torch.randn(2, 4, 8, 8)
|
| 181 |
+
orig = x.clone()
|
| 182 |
+
params = _FakeParams(x=x.clone())
|
| 183 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
|
| 184 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 185 |
+
_on_cfg_denoiser(params)
|
| 186 |
+
self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
|
| 187 |
+
"No affine → x must not change")
|
| 188 |
+
|
| 189 |
+
# 5. empty state → no crash
|
| 190 |
+
def test_empty_state_no_crash(self):
|
| 191 |
+
params = _FakeParams(x=torch.randn(1, 4, 8, 8))
|
| 192 |
+
global_state.prompt_exprs = []
|
| 193 |
+
global_state.batch_cond_indices = []
|
| 194 |
+
try:
|
| 195 |
+
_on_cfg_denoiser(params)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
self.fail(f"_on_cfg_denoiser raised with empty state: {e}")
|
| 198 |
+
|
| 199 |
+
# 6. two batch entries → no crash
|
| 200 |
+
def test_batch_multiple_prompts_no_crash(self):
|
| 201 |
+
params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
|
| 202 |
+
global_state.prompt_exprs = [
|
| 203 |
+
_comp(_leaf("base_a"), _leaf("perp_a", 0.25)),
|
| 204 |
+
_comp(_leaf("base_b"), _leaf("perp_b", 0.125)),
|
| 205 |
+
]
|
| 206 |
+
global_state.batch_cond_indices = [
|
| 207 |
+
[(0, 1.0), (1, 1.0)],
|
| 208 |
+
[(2, 1.0), (3, 1.0)],
|
| 209 |
+
]
|
| 210 |
+
try:
|
| 211 |
+
_on_cfg_denoiser(params)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
self.fail(f"Multi-prompt batch raised: {e}")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == '__main__':
|
| 217 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_basic_parser.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestPromptParser(unittest.TestCase):
|
| 9 |
+
def setUp(self):
|
| 10 |
+
self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
|
| 11 |
+
self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
|
| 12 |
+
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
|
| 13 |
+
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
|
| 14 |
+
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
|
| 15 |
+
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
|
| 16 |
+
self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
|
| 17 |
+
|
| 18 |
+
def test_simple_prompt_child_count(self):
|
| 19 |
+
self.assertEqual(len(self.simple_prompt.children), 1)
|
| 20 |
+
|
| 21 |
+
def test_simple_prompt_child_weight(self):
|
| 22 |
+
self.assertEqual(self.simple_prompt.children[0].weight, 1.0)
|
| 23 |
+
|
| 24 |
+
def test_simple_prompt_child_prompt(self):
|
| 25 |
+
self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
|
| 26 |
+
|
| 27 |
+
def test_square_weight_prompt(self):
|
| 28 |
+
prompt = "a [b c d e : f g h :1.5]"
|
| 29 |
+
parsed = neutral_prompt_parser.parse_root(prompt)
|
| 30 |
+
self.assertEqual(parsed.children[0].prompt, prompt)
|
| 31 |
+
|
| 32 |
+
composed_prompt = f"{prompt} AND_PERP other prompt"
|
| 33 |
+
parsed = neutral_prompt_parser.parse_root(composed_prompt)
|
| 34 |
+
self.assertEqual(parsed.children[0].prompt, prompt)
|
| 35 |
+
|
| 36 |
+
def test_and_prompt_child_count(self):
|
| 37 |
+
self.assertEqual(len(self.and_prompt.children), 2)
|
| 38 |
+
|
| 39 |
+
def test_and_prompt_child_weights_and_prompts(self):
|
| 40 |
+
self.assertEqual(self.and_prompt.children[0].weight, 1.0)
|
| 41 |
+
self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
|
| 42 |
+
self.assertEqual(self.and_prompt.children[1].weight, 2.0)
|
| 43 |
+
self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")
|
| 44 |
+
|
| 45 |
+
def test_and_perp_prompt_child_count(self):
|
| 46 |
+
self.assertEqual(len(self.and_perp_prompt.children), 2)
|
| 47 |
+
|
| 48 |
+
def test_and_perp_prompt_child_types(self):
|
| 49 |
+
self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 50 |
+
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
| 51 |
+
|
| 52 |
+
def test_and_perp_prompt_nested_child(self):
|
| 53 |
+
nested_child = self.and_perp_prompt.children[1]
|
| 54 |
+
self.assertEqual(nested_child.weight, 2.0)
|
| 55 |
+
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
| 56 |
+
|
| 57 |
+
def test_nested_and_perp_prompt_child_count(self):
|
| 58 |
+
self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
|
| 59 |
+
|
| 60 |
+
def test_nested_and_perp_prompt_child_types(self):
|
| 61 |
+
self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 62 |
+
self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
| 63 |
+
|
| 64 |
+
def test_nested_and_perp_prompt_nested_child_types(self):
|
| 65 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[0]
|
| 66 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 67 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
| 68 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 69 |
+
|
| 70 |
+
def test_nested_and_perp_prompt_nested_child(self):
|
| 71 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
| 72 |
+
self.assertEqual(nested_child.weight, 3.0)
|
| 73 |
+
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
| 74 |
+
|
| 75 |
+
def test_invalid_weight_child_count(self):
|
| 76 |
+
self.assertEqual(len(self.invalid_weight.children), 1)
|
| 77 |
+
|
| 78 |
+
def test_invalid_weight_child_weight(self):
|
| 79 |
+
self.assertEqual(self.invalid_weight.children[0].weight, 1.0)
|
| 80 |
+
|
| 81 |
+
def test_invalid_weight_child_prompt(self):
|
| 82 |
+
self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")
|
| 83 |
+
|
| 84 |
+
def test_and_salt_prompt_child_count(self):
|
| 85 |
+
self.assertEqual(len(self.and_salt_prompt.children), 2)
|
| 86 |
+
|
| 87 |
+
def test_and_salt_prompt_child_types(self):
|
| 88 |
+
self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 89 |
+
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
| 90 |
+
|
| 91 |
+
def test_and_salt_prompt_nested_child(self):
|
| 92 |
+
nested_child = self.and_salt_prompt.children[1]
|
| 93 |
+
self.assertEqual(nested_child.weight, 2.0)
|
| 94 |
+
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
| 95 |
+
|
| 96 |
+
def test_nested_and_salt_prompt_child_count(self):
|
| 97 |
+
self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
|
| 98 |
+
|
| 99 |
+
def test_nested_and_salt_prompt_child_types(self):
|
| 100 |
+
self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 101 |
+
self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
| 102 |
+
|
| 103 |
+
def test_nested_and_salt_prompt_nested_child_types(self):
|
| 104 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[0]
|
| 105 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 106 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
| 107 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 108 |
+
|
| 109 |
+
def test_nested_and_salt_prompt_nested_child(self):
|
| 110 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
| 111 |
+
self.assertEqual(nested_child.weight, 3.0)
|
| 112 |
+
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
| 113 |
+
|
| 114 |
+
def test_start_with_prompt_editing(self):
|
| 115 |
+
prompt = "[(long shot:1.2):0.1] detail.."
|
| 116 |
+
res = neutral_prompt_parser.parse_root(prompt)
|
| 117 |
+
self.assertEqual(res.children[0].weight, 1.0)
|
| 118 |
+
self.assertEqual(res.children[0].prompt, prompt)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == '__main__':
|
| 122 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_lock_after_end.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for Lock after End (Sprint 3 feature).
|
| 3 |
+
|
| 4 |
+
Covers:
|
| 5 |
+
1. Basic lock: strategy frozen after window closes
|
| 6 |
+
2. No-lock: strategy reactivates normally without the flag
|
| 7 |
+
3. Lock not set before window start (haven't entered the window yet)
|
| 8 |
+
4. Reset at step 0: lock flags cleared each generation
|
| 9 |
+
5. Per-strategy custom: only the specific strategy gets locked
|
| 10 |
+
6. Global window + lock interaction
|
| 11 |
+
7. Infotext round-trip: lock_after_end serialized and restored
|
| 12 |
+
8. update helper: reset_step_lock_flags clears all flags
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import unittest
|
| 16 |
+
from lib_neutral_prompt import global_state
|
| 17 |
+
from lib_neutral_prompt.step_utils import (
|
| 18 |
+
StepWindow,
|
| 19 |
+
normalize_progress,
|
| 20 |
+
strategy_is_active,
|
| 21 |
+
strategy_is_active_from_state,
|
| 22 |
+
build_per_strategy_windows,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
class _StateGuard:
|
| 26 |
+
"""Context manager that saves and restores global_state step-window vars."""
|
| 27 |
+
_FIELDS = (
|
| 28 |
+
'step_window_enabled',
|
| 29 |
+
'step_window_global',
|
| 30 |
+
'step_window_per_strategy',
|
| 31 |
+
'step_window_use_defaults',
|
| 32 |
+
'step_window_lock_after_end',
|
| 33 |
+
'_step_lock_flags',
|
| 34 |
+
'_lock_generation_id',
|
| 35 |
+
'current_step',
|
| 36 |
+
'total_steps',
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def __enter__(self):
|
| 40 |
+
self._saved = {f: getattr(global_state, f, None) for f in self._FIELDS}
|
| 41 |
+
return self
|
| 42 |
+
|
| 43 |
+
def __exit__(self, *_):
|
| 44 |
+
for f, v in self._saved.items():
|
| 45 |
+
setattr(global_state, f, v)
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# 1. Basic lock behaviour
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class TestLockAfterEndBasic(unittest.TestCase):
|
| 54 |
+
|
| 55 |
+
def test_lock_freezes_after_window_end(self):
|
| 56 |
+
"""
|
| 57 |
+
Window 0.0–0.5: active at progress 0.3, inactive (and locked) at 0.7.
|
| 58 |
+
Once locked, stays False even if progress would re-enter the range.
|
| 59 |
+
"""
|
| 60 |
+
with _StateGuard():
|
| 61 |
+
global_state.step_window_enabled = True
|
| 62 |
+
global_state.step_window_lock_after_end = True
|
| 63 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 64 |
+
global_state.step_window_use_defaults = False
|
| 65 |
+
global_state.step_window_per_strategy = None
|
| 66 |
+
global_state._step_lock_flags = {}
|
| 67 |
+
|
| 68 |
+
# Inside window — must be active
|
| 69 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 70 |
+
|
| 71 |
+
# After window ends at progress 0.7 → inactive and locked
|
| 72 |
+
result_late = strategy_is_active_from_state('PERPENDICULAR', 0.7)
|
| 73 |
+
self.assertFalse(result_late)
|
| 74 |
+
|
| 75 |
+
# Flag must now be set
|
| 76 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 77 |
+
|
| 78 |
+
# Subsequent call at 0.4 (would normally be inside window) still False
|
| 79 |
+
result_rewound = strategy_is_active_from_state('PERPENDICULAR', 0.4)
|
| 80 |
+
self.assertFalse(result_rewound)
|
| 81 |
+
|
| 82 |
+
def test_no_lock_flag_allows_reactivation(self):
|
| 83 |
+
"""Without lock_after_end the strategy returns True when re-entering range."""
|
| 84 |
+
with _StateGuard():
|
| 85 |
+
global_state.step_window_enabled = True
|
| 86 |
+
global_state.step_window_lock_after_end = False
|
| 87 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 88 |
+
global_state.step_window_use_defaults = False
|
| 89 |
+
global_state.step_window_per_strategy = None
|
| 90 |
+
global_state._step_lock_flags = {}
|
| 91 |
+
|
| 92 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 93 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 94 |
+
# No lock set → a re-query at 0.3 returns True
|
| 95 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
# 2. Lock not fired before window start
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
class TestLockNotFiredBeforeWindowStart(unittest.TestCase):
|
| 103 |
+
|
| 104 |
+
def test_progress_before_window_start_does_not_lock(self):
|
| 105 |
+
"""
|
| 106 |
+
If progress is before the window start (strategy hasn't started yet),
|
| 107 |
+
the lock must NOT be set — otherwise the strategy could never activate.
|
| 108 |
+
Window 0.4–1.0: at progress 0.1 strategy is inactive but lock must NOT fire.
|
| 109 |
+
"""
|
| 110 |
+
with _StateGuard():
|
| 111 |
+
global_state.step_window_enabled = True
|
| 112 |
+
global_state.step_window_lock_after_end = True
|
| 113 |
+
global_state.step_window_global = StepWindow(0.4, 1.0)
|
| 114 |
+
global_state.step_window_use_defaults = False
|
| 115 |
+
global_state.step_window_per_strategy = None
|
| 116 |
+
global_state._step_lock_flags = {}
|
| 117 |
+
|
| 118 |
+
# Step 0 of 10 → progress 0.0 → before window start
|
| 119 |
+
result = strategy_is_active_from_state('SALIENCE_MASK', 0.0)
|
| 120 |
+
self.assertFalse(result)
|
| 121 |
+
self.assertFalse(global_state._step_lock_flags.get('SALIENCE_MASK', False),
|
| 122 |
+
'Lock must not fire before window start')
|
| 123 |
+
|
| 124 |
+
# Now enter the window at 0.5 → must be active
|
| 125 |
+
result_in = strategy_is_active_from_state('SALIENCE_MASK', 0.5)
|
| 126 |
+
self.assertTrue(result_in)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# 3. reset_step_lock_flags and begin_new_generation
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
class TestLockLifecycle(unittest.TestCase):
|
| 134 |
+
|
| 135 |
+
def test_reset_clears_all_flags(self):
|
| 136 |
+
with _StateGuard():
|
| 137 |
+
global_state._step_lock_flags = {
|
| 138 |
+
'PERPENDICULAR': True,
|
| 139 |
+
'SALIENCE_MASK': True,
|
| 140 |
+
}
|
| 141 |
+
global_state.reset_step_lock_flags()
|
| 142 |
+
self.assertEqual(global_state._step_lock_flags, {})
|
| 143 |
+
|
| 144 |
+
def test_reset_idempotent_on_empty(self):
|
| 145 |
+
with _StateGuard():
|
| 146 |
+
global_state._step_lock_flags = {}
|
| 147 |
+
global_state.reset_step_lock_flags()
|
| 148 |
+
self.assertEqual(global_state._step_lock_flags, {})
|
| 149 |
+
|
| 150 |
+
def test_begin_new_generation_clears_flags(self):
|
| 151 |
+
with _StateGuard():
|
| 152 |
+
global_state._step_lock_flags = {'PERPENDICULAR': True}
|
| 153 |
+
global_state.begin_new_generation()
|
| 154 |
+
self.assertEqual(global_state._step_lock_flags, {})
|
| 155 |
+
|
| 156 |
+
def test_begin_new_generation_increments_id(self):
|
| 157 |
+
with _StateGuard():
|
| 158 |
+
before = global_state._lock_generation_id
|
| 159 |
+
global_state.begin_new_generation()
|
| 160 |
+
self.assertEqual(global_state._lock_generation_id, before + 1)
|
| 161 |
+
global_state.begin_new_generation()
|
| 162 |
+
self.assertEqual(global_state._lock_generation_id, before + 2)
|
| 163 |
+
|
| 164 |
+
def test_strategy_active_after_begin_new_generation(self):
|
| 165 |
+
"""After begin_new_generation, a previously locked strategy re-arms."""
|
| 166 |
+
with _StateGuard():
|
| 167 |
+
global_state.step_window_enabled = True
|
| 168 |
+
global_state.step_window_lock_after_end = True
|
| 169 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 170 |
+
global_state.step_window_use_defaults = False
|
| 171 |
+
global_state.step_window_per_strategy = None
|
| 172 |
+
global_state._step_lock_flags = {}
|
| 173 |
+
|
| 174 |
+
# Lock it
|
| 175 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 176 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 177 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 178 |
+
|
| 179 |
+
# New generation → re-arms
|
| 180 |
+
global_state.begin_new_generation()
|
| 181 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 182 |
+
|
| 183 |
+
# ── KEY TEST: restart / rollback safety ───────────────────────────────
|
| 184 |
+
|
| 185 |
+
def test_step_zero_revisit_does_not_clear_lock(self):
|
| 186 |
+
"""
|
| 187 |
+
Sampler-internal rollback to step 0 (Restart sampler, DPM++ multi-pass)
|
| 188 |
+
must NOT clear lock flags. Only begin_new_generation() (= process() hook)
|
| 189 |
+
should do that.
|
| 190 |
+
"""
|
| 191 |
+
with _StateGuard():
|
| 192 |
+
global_state.step_window_enabled = True
|
| 193 |
+
global_state.step_window_lock_after_end = True
|
| 194 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 195 |
+
global_state.step_window_use_defaults = False
|
| 196 |
+
global_state.step_window_per_strategy = None
|
| 197 |
+
global_state._step_lock_flags = {}
|
| 198 |
+
|
| 199 |
+
# Normal run: enters window, exits, gets locked
|
| 200 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 201 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 202 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 203 |
+
|
| 204 |
+
# Simulate sampler rollback: step counter goes back to 0
|
| 205 |
+
global_state.update_step_state(0, 20)
|
| 206 |
+
progress_after_rollback = normalize_progress(
|
| 207 |
+
global_state.current_step, global_state.total_steps
|
| 208 |
+
)
|
| 209 |
+
# Without a new process() call, lock must still be active
|
| 210 |
+
result = strategy_is_active_from_state('PERPENDICULAR', progress_after_rollback)
|
| 211 |
+
self.assertFalse(result,
|
| 212 |
+
'Lock must survive sampler-internal rollback to step 0')
|
| 213 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False),
|
| 214 |
+
'Flag must remain set after step-0 rollback')
|
| 215 |
+
|
| 216 |
+
def test_hires_fix_second_pass_clears_lock(self):
|
| 217 |
+
"""
|
| 218 |
+
Hires fix calls process() again for its second denoising pass.
|
| 219 |
+
Simulating that with begin_new_generation() must re-arm strategies.
|
| 220 |
+
"""
|
| 221 |
+
with _StateGuard():
|
| 222 |
+
global_state.step_window_enabled = True
|
| 223 |
+
global_state.step_window_lock_after_end = True
|
| 224 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 225 |
+
global_state.step_window_use_defaults = False
|
| 226 |
+
global_state.step_window_per_strategy = None
|
| 227 |
+
global_state._step_lock_flags = {}
|
| 228 |
+
|
| 229 |
+
# First pass: gets locked
|
| 230 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 231 |
+
|
| 232 |
+
# Hires fix: process() is called → begin_new_generation()
|
| 233 |
+
global_state.begin_new_generation()
|
| 234 |
+
|
| 235 |
+
# Second pass must start fresh
|
| 236 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ---------------------------------------------------------------------------
|
| 240 |
+
# 4. Per-strategy: only the locked strategy is frozen
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
|
| 243 |
+
class TestLockPerStrategy(unittest.TestCase):
|
| 244 |
+
|
| 245 |
+
def test_only_locked_strategy_frozen(self):
|
| 246 |
+
"""
|
| 247 |
+
Custom windows: AND_PERP 0.0–0.4, AND_SALT 0.5–1.0.
|
| 248 |
+
After AND_PERP window closes (progress=0.6), AND_PERP is locked.
|
| 249 |
+
AND_SALT is still active.
|
| 250 |
+
"""
|
| 251 |
+
with _StateGuard():
|
| 252 |
+
raw = {'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)}
|
| 253 |
+
global_state.step_window_enabled = True
|
| 254 |
+
global_state.step_window_lock_after_end = True
|
| 255 |
+
global_state.step_window_global = None
|
| 256 |
+
global_state.step_window_use_defaults = False
|
| 257 |
+
global_state.step_window_per_strategy = build_per_strategy_windows(raw)
|
| 258 |
+
global_state._step_lock_flags = {}
|
| 259 |
+
|
| 260 |
+
# Inside AND_PERP window
|
| 261 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.2))
|
| 262 |
+
# Inside AND_SALT window
|
| 263 |
+
self.assertTrue(strategy_is_active_from_state('SALIENCE_MASK', 0.6))
|
| 264 |
+
|
| 265 |
+
# After AND_PERP ends → lock fires
|
| 266 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.6))
|
| 267 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 268 |
+
|
| 269 |
+
# AND_SALT unaffected
|
| 270 |
+
self.assertFalse(global_state._step_lock_flags.get('SALIENCE_MASK', False))
|
| 271 |
+
self.assertTrue(strategy_is_active_from_state('SALIENCE_MASK', 0.8))
|
| 272 |
+
|
| 273 |
+
def test_defaults_mode_with_lock(self):
|
| 274 |
+
"""
|
| 275 |
+
use_defaults=True: PERPENDICULAR default 0.0–0.5.
|
| 276 |
+
At progress 0.7 it should lock (window has ended and we were past start).
|
| 277 |
+
"""
|
| 278 |
+
with _StateGuard():
|
| 279 |
+
global_state.step_window_enabled = True
|
| 280 |
+
global_state.step_window_lock_after_end = True
|
| 281 |
+
global_state.step_window_global = None
|
| 282 |
+
global_state.step_window_use_defaults = True
|
| 283 |
+
global_state.step_window_per_strategy = None
|
| 284 |
+
global_state._step_lock_flags = {}
|
| 285 |
+
|
| 286 |
+
# Active inside default window
|
| 287 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 288 |
+
# Outside → lock
|
| 289 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 290 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
# 5. Infotext round-trip
|
| 295 |
+
# ---------------------------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
class TestLockAfterEndInfotext(unittest.TestCase):
|
| 298 |
+
|
| 299 |
+
def test_apply_infotext_restores_lock(self):
|
| 300 |
+
with _StateGuard():
|
| 301 |
+
global_state.step_window_lock_after_end = False
|
| 302 |
+
global_state.apply_infotext({
|
| 303 |
+
'NP Step Window Enabled': 'True',
|
| 304 |
+
'NP Step Window Mode': 'global',
|
| 305 |
+
'NP Step Window Start': '0.0',
|
| 306 |
+
'NP Step Window End': '0.5',
|
| 307 |
+
'NP Step Window Lock': 'True',
|
| 308 |
+
})
|
| 309 |
+
self.assertTrue(global_state.step_window_lock_after_end)
|
| 310 |
+
|
| 311 |
+
def test_apply_infotext_disables_lock(self):
|
| 312 |
+
with _StateGuard():
|
| 313 |
+
global_state.step_window_lock_after_end = True
|
| 314 |
+
global_state.apply_infotext({'NP Step Window Lock': 'False'})
|
| 315 |
+
self.assertFalse(global_state.step_window_lock_after_end)
|
| 316 |
+
|
| 317 |
+
def test_apply_infotext_lock_absent_leaves_state_unchanged(self):
|
| 318 |
+
with _StateGuard():
|
| 319 |
+
global_state.step_window_lock_after_end = True
|
| 320 |
+
global_state.apply_infotext({'NP Step Window Enabled': 'True'})
|
| 321 |
+
# Key absent → unchanged
|
| 322 |
+
self.assertTrue(global_state.step_window_lock_after_end)
|
| 323 |
+
|
| 324 |
+
def test_lock_key_in_infotext_constants(self):
|
| 325 |
+
self.assertEqual(global_state.INFOTEXT_KEY_STEP_WIN_LOCK, 'NP Step Window Lock')
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
# 6. Global simulation: step 0→N progression with lock
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
class TestLockSimulatedRun(unittest.TestCase):
|
| 333 |
+
"""
|
| 334 |
+
Simulate a real 10-step run with lock_after_end=True and window 0.0–0.5.
|
| 335 |
+
Verify: active steps 0–4, locked from step 5 onward.
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def test_simulated_10_step_run(self):
|
| 339 |
+
with _StateGuard():
|
| 340 |
+
global_state.step_window_enabled = True
|
| 341 |
+
global_state.step_window_lock_after_end = True
|
| 342 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 343 |
+
global_state.step_window_use_defaults = False
|
| 344 |
+
global_state.step_window_per_strategy = None
|
| 345 |
+
global_state._step_lock_flags = {}
|
| 346 |
+
|
| 347 |
+
results = []
|
| 348 |
+
for step in range(10):
|
| 349 |
+
global_state.update_step_state(step, 10)
|
| 350 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 351 |
+
active = strategy_is_active_from_state('PERPENDICULAR', p)
|
| 352 |
+
results.append(active)
|
| 353 |
+
|
| 354 |
+
# Steps 0–4: progress 0.0–0.44 → inside window 0.0–0.5 → active
|
| 355 |
+
for i in range(5):
|
| 356 |
+
self.assertTrue(results[i], f'Expected active at step {i}')
|
| 357 |
+
# Step 5 onward: progress 0.56+ → outside → locked
|
| 358 |
+
for i in range(5, 10):
|
| 359 |
+
self.assertFalse(results[i], f'Expected locked at step {i}')
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# ---------------------------------------------------------------------------
|
| 363 |
+
# 7. Restart-within-same-generation must NOT clear lock flags
|
| 364 |
+
# ---------------------------------------------------------------------------
|
| 365 |
+
|
| 366 |
+
class TestLockSurvivesRestartWithinRun(unittest.TestCase):
|
| 367 |
+
"""
|
| 368 |
+
Key property: if the sampler revisits step 0 within one generation
|
| 369 |
+
(e.g. Restart sampler, DPM++ SDE second pass, hires fix denoiser),
|
| 370 |
+
the lock flags must be preserved.
|
| 371 |
+
|
| 372 |
+
This is guaranteed by begin_new_generation() being called from process()
|
| 373 |
+
lifecycle, NOT from step == 0 in the denoising loop.
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
def test_begin_new_generation_called_once_then_step_reset(self):
|
| 377 |
+
"""
|
| 378 |
+
Simulate: generation starts → begin_new_generation() → strategy locks →
|
| 379 |
+
sampler restarts (step resets to 0 internally) → lock survives.
|
| 380 |
+
"""
|
| 381 |
+
with _StateGuard():
|
| 382 |
+
global_state.step_window_enabled = True
|
| 383 |
+
global_state.step_window_lock_after_end = True
|
| 384 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 385 |
+
global_state.step_window_use_defaults = False
|
| 386 |
+
global_state.step_window_per_strategy = None
|
| 387 |
+
|
| 388 |
+
# --- Generation starts: process() calls begin_new_generation() ---
|
| 389 |
+
global_state.begin_new_generation()
|
| 390 |
+
|
| 391 |
+
# Inside window → active
|
| 392 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 393 |
+
|
| 394 |
+
# Window ends → locked
|
| 395 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 396 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False),
|
| 397 |
+
'Lock flag must be set after window end')
|
| 398 |
+
|
| 399 |
+
# Sampler restart: step goes back to 0 — but NO new begin_new_generation()
|
| 400 |
+
# (process() is NOT called again for a restart within one generation)
|
| 401 |
+
global_state.update_step_state(0, 20)
|
| 402 |
+
|
| 403 |
+
# Lock must survive the step reset
|
| 404 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.0),
|
| 405 |
+
'Strategy must remain locked after intra-run step reset')
|
| 406 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False),
|
| 407 |
+
'Lock flag must persist after intra-run step reset')
|
| 408 |
+
|
| 409 |
+
def test_ensure_step_lock_run_initialized_same_token_no_reset(self):
|
| 410 |
+
"""
|
| 411 |
+
ensure_step_lock_run_initialized with the same token must NOT reset flags.
|
| 412 |
+
"""
|
| 413 |
+
with _StateGuard():
|
| 414 |
+
global_state.step_window_enabled = True
|
| 415 |
+
global_state.step_window_lock_after_end = True
|
| 416 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 417 |
+
global_state.step_window_use_defaults = False
|
| 418 |
+
global_state.step_window_per_strategy = None
|
| 419 |
+
global_state._step_lock_flags = {}
|
| 420 |
+
|
| 421 |
+
RUN_TOKEN = object()
|
| 422 |
+
global_state.ensure_step_lock_run_initialized(RUN_TOKEN)
|
| 423 |
+
|
| 424 |
+
# Active inside window
|
| 425 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.2))
|
| 426 |
+
# Window ends → lock fires
|
| 427 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 428 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 429 |
+
|
| 430 |
+
# Same token called again (sampler restart within same generation)
|
| 431 |
+
reset_happened = global_state.ensure_step_lock_run_initialized(RUN_TOKEN)
|
| 432 |
+
|
| 433 |
+
self.assertFalse(reset_happened, 'Same token must not trigger reset')
|
| 434 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False),
|
| 435 |
+
'Lock flag must survive same-token re-call')
|
| 436 |
+
|
| 437 |
+
# Strategy must still be locked
|
| 438 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.3),
|
| 439 |
+
'Strategy must stay locked within same run')
|
| 440 |
+
|
| 441 |
+
def test_ensure_step_lock_run_initialized_new_token_resets(self):
|
| 442 |
+
"""
|
| 443 |
+
ensure_step_lock_run_initialized with a NEW token must reset flags
|
| 444 |
+
(simulates a genuinely new generation).
|
| 445 |
+
"""
|
| 446 |
+
with _StateGuard():
|
| 447 |
+
global_state.step_window_enabled = True
|
| 448 |
+
global_state.step_window_lock_after_end = True
|
| 449 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 450 |
+
global_state.step_window_use_defaults = False
|
| 451 |
+
global_state.step_window_per_strategy = None
|
| 452 |
+
global_state._step_lock_flags = {}
|
| 453 |
+
|
| 454 |
+
TOKEN_A = object()
|
| 455 |
+
TOKEN_B = object()
|
| 456 |
+
|
| 457 |
+
global_state.ensure_step_lock_run_initialized(TOKEN_A)
|
| 458 |
+
|
| 459 |
+
# Lock fires
|
| 460 |
+
strategy_is_active_from_state('PERPENDICULAR', 0.2) # activate
|
| 461 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 462 |
+
self.assertTrue(global_state._step_lock_flags.get('PERPENDICULAR', False))
|
| 463 |
+
|
| 464 |
+
# New generation: new token
|
| 465 |
+
reset_happened = global_state.ensure_step_lock_run_initialized(TOKEN_B)
|
| 466 |
+
|
| 467 |
+
self.assertTrue(reset_happened, 'New token must trigger reset')
|
| 468 |
+
self.assertFalse(global_state._step_lock_flags.get('PERPENDICULAR', False),
|
| 469 |
+
'Lock flags must be cleared for new generation')
|
| 470 |
+
|
| 471 |
+
# Strategy must now be active again (inside window)
|
| 472 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3),
|
| 473 |
+
'Strategy must be active again after new generation reset')
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# ---------------------------------------------------------------------------
|
| 477 |
+
# 8. begin_new_generation() vs reset_step_lock_flags() contract
|
| 478 |
+
# ---------------------------------------------------------------------------
|
| 479 |
+
|
| 480 |
+
class TestBeginNewGenerationContract(unittest.TestCase):
|
| 481 |
+
|
| 482 |
+
def test_begin_new_generation_increments_id(self):
|
| 483 |
+
with _StateGuard():
|
| 484 |
+
before = global_state._lock_generation_id
|
| 485 |
+
global_state.begin_new_generation()
|
| 486 |
+
self.assertGreater(global_state._lock_generation_id, before)
|
| 487 |
+
|
| 488 |
+
def test_begin_new_generation_clears_flags(self):
|
| 489 |
+
with _StateGuard():
|
| 490 |
+
global_state._step_lock_flags = {'PERPENDICULAR': True, 'SALIENCE_MASK': True}
|
| 491 |
+
global_state.begin_new_generation()
|
| 492 |
+
self.assertEqual(global_state._step_lock_flags, {})
|
| 493 |
+
|
| 494 |
+
def test_begin_new_generation_twice_clears_flags_both_times(self):
|
| 495 |
+
"""Two consecutive begin_new_generation() calls each start with clean flags."""
|
| 496 |
+
with _StateGuard():
|
| 497 |
+
global_state.begin_new_generation()
|
| 498 |
+
global_state._step_lock_flags['PERPENDICULAR'] = True
|
| 499 |
+
global_state.begin_new_generation()
|
| 500 |
+
self.assertEqual(global_state._step_lock_flags, {})
|
| 501 |
+
|
| 502 |
+
def test_reset_step_lock_flags_does_not_change_generation_id(self):
|
| 503 |
+
with _StateGuard():
|
| 504 |
+
gen_id_before = global_state._lock_generation_id
|
| 505 |
+
global_state.reset_step_lock_flags()
|
| 506 |
+
self.assertEqual(global_state._lock_generation_id, gen_id_before)
|
| 507 |
+
|
| 508 |
+
def test_process_lifecycle_integration(self):
|
| 509 |
+
"""
|
| 510 |
+
Verify the full lifecycle as used in scripts/neutral_prompt.py:
|
| 511 |
+
begin_new_generation() once per process() → locks survive denoising loop →
|
| 512 |
+
next begin_new_generation() clears them.
|
| 513 |
+
"""
|
| 514 |
+
with _StateGuard():
|
| 515 |
+
global_state.step_window_enabled = True
|
| 516 |
+
global_state.step_window_lock_after_end = True
|
| 517 |
+
global_state.step_window_global = StepWindow(0.0, 0.5)
|
| 518 |
+
global_state.step_window_use_defaults = False
|
| 519 |
+
global_state.step_window_per_strategy = None
|
| 520 |
+
|
| 521 |
+
# --- Generation 1: process() called once ---
|
| 522 |
+
global_state.begin_new_generation()
|
| 523 |
+
|
| 524 |
+
strategy_is_active_from_state('PERPENDICULAR', 0.2) # activate
|
| 525 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.7))
|
| 526 |
+
locked_gen1 = global_state._step_lock_flags.get('PERPENDICULAR', False)
|
| 527 |
+
self.assertTrue(locked_gen1)
|
| 528 |
+
|
| 529 |
+
# No new process() call → lock survives any step count
|
| 530 |
+
self.assertFalse(strategy_is_active_from_state('PERPENDICULAR', 0.0))
|
| 531 |
+
|
| 532 |
+
# --- Generation 2: new process() call ---
|
| 533 |
+
global_state.begin_new_generation()
|
| 534 |
+
flags_after_new_gen = global_state._step_lock_flags.copy()
|
| 535 |
+
self.assertEqual(flags_after_new_gen, {},
|
| 536 |
+
'All flags must be cleared at start of new generation')
|
| 537 |
+
|
| 538 |
+
# Strategy active again inside its window
|
| 539 |
+
self.assertTrue(strategy_is_active_from_state('PERPENDICULAR', 0.3))
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
if __name__ == '__main__':
|
| 543 |
+
unittest.main()
|
| 544 |
+
|
neutral_prompt_patcheds/test/perp_parser/test_malicious_parser.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestMaliciousPromptParser(unittest.TestCase):
|
| 9 |
+
def setUp(self):
|
| 10 |
+
self.parser = neutral_prompt_parser
|
| 11 |
+
|
| 12 |
+
def test_empty(self):
|
| 13 |
+
result = self.parser.parse_root("")
|
| 14 |
+
self.assertEqual(result.children[0].prompt, "")
|
| 15 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 16 |
+
|
| 17 |
+
def test_zero_weight(self):
|
| 18 |
+
result = self.parser.parse_root("hello :0.0")
|
| 19 |
+
self.assertEqual(result.children[0].weight, 0.0)
|
| 20 |
+
|
| 21 |
+
def test_mixed_positive_and_negative_weights(self):
|
| 22 |
+
result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
|
| 23 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 24 |
+
self.assertEqual(result.children[1].weight, -2.0)
|
| 25 |
+
|
| 26 |
+
def test_debalanced_square_brackets(self):
|
| 27 |
+
prompt = "a [ b " * 100
|
| 28 |
+
result = self.parser.parse_root(prompt)
|
| 29 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 30 |
+
|
| 31 |
+
prompt = "a ] b " * 100
|
| 32 |
+
result = self.parser.parse_root(prompt)
|
| 33 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 34 |
+
|
| 35 |
+
repeats = 10
|
| 36 |
+
prompt = "a [ [ b AND c ] " * repeats
|
| 37 |
+
result = self.parser.parse_root(prompt)
|
| 38 |
+
self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
|
| 39 |
+
|
| 40 |
+
repeats = 10
|
| 41 |
+
prompt = "a [ b AND c ] ] " * repeats
|
| 42 |
+
result = self.parser.parse_root(prompt)
|
| 43 |
+
self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
|
| 44 |
+
|
| 45 |
+
def test_erroneous_syntax(self):
|
| 46 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
|
| 47 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 48 |
+
self.assertEqual(result.children[1].prompt, "[goodbye ")
|
| 49 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 50 |
+
|
| 51 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
|
| 52 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 53 |
+
self.assertEqual(result.children[1].prompt, " goodbye ")
|
| 54 |
+
|
| 55 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
|
| 56 |
+
self.assertEqual(result.children[1].prompt, " goodbye]")
|
| 57 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 58 |
+
|
| 59 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
|
| 60 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 61 |
+
self.assertEqual(result.children[1].prompt, " a [ goodbye ")
|
| 62 |
+
|
| 63 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
|
| 64 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 65 |
+
self.assertEqual(result.children[2].prompt, " goodbye ")
|
| 66 |
+
|
| 67 |
+
def test_huge_number_of_prompt_parts(self):
|
| 68 |
+
result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
|
| 69 |
+
self.assertEqual(len(result.children), 10**4)
|
| 70 |
+
|
| 71 |
+
def test_prompt_ending_with_weight(self):
|
| 72 |
+
result = self.parser.parse_root("hello :1.0 AND :2.0")
|
| 73 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 74 |
+
self.assertEqual(result.children[1].prompt, "")
|
| 75 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 76 |
+
|
| 77 |
+
def test_huge_input_string(self):
|
| 78 |
+
big_string = "hello :1.0 AND " * 10**4
|
| 79 |
+
result = self.parser.parse_root(big_string)
|
| 80 |
+
self.assertEqual(len(result.children), 10**4 + 1)
|
| 81 |
+
|
| 82 |
+
def test_deeply_nested_prompt(self):
|
| 83 |
+
deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
|
| 84 |
+
result = self.parser.parse_root(deeply_nested_prompt)
|
| 85 |
+
self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
|
| 86 |
+
|
| 87 |
+
def test_complex_nested_prompts(self):
|
| 88 |
+
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
|
| 89 |
+
result = self.parser.parse_root(complex_prompt)
|
| 90 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 91 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 92 |
+
self.assertEqual(result.children[2].children[0].weight, 3.0)
|
| 93 |
+
self.assertEqual(result.children[2].children[1].weight, 4.0)
|
| 94 |
+
self.assertEqual(result.children[2].children[2].weight, 5.0)
|
| 95 |
+
|
| 96 |
+
def test_string_with_random_characters(self):
|
| 97 |
+
random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
|
| 98 |
+
try:
|
| 99 |
+
self.parser.parse_root(random_chars)
|
| 100 |
+
except Exception:
|
| 101 |
+
self.fail("parse_root couldn't handle a string with random characters.")
|
| 102 |
+
|
| 103 |
+
def test_string_with_unexpected_symbols(self):
|
| 104 |
+
unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
|
| 105 |
+
try:
|
| 106 |
+
self.parser.parse_root(unexpected_symbols)
|
| 107 |
+
except Exception:
|
| 108 |
+
self.fail("parse_root couldn't handle a string with unexpected symbols.")
|
| 109 |
+
|
| 110 |
+
def test_string_with_unconventional_structure(self):
|
| 111 |
+
unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
|
| 112 |
+
try:
|
| 113 |
+
self.parser.parse_root(unconventional_structure)
|
| 114 |
+
except Exception:
|
| 115 |
+
self.fail("parse_root couldn't handle a string with unconventional structure.")
|
| 116 |
+
|
| 117 |
+
def test_string_with_mixed_alphabets_and_numbers(self):
|
| 118 |
+
mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
|
| 119 |
+
try:
|
| 120 |
+
self.parser.parse_root(mixed_alphabets_and_numbers)
|
| 121 |
+
except Exception:
|
| 122 |
+
self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
|
| 123 |
+
|
| 124 |
+
def test_string_with_nested_brackets(self):
|
| 125 |
+
nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
|
| 126 |
+
try:
|
| 127 |
+
self.parser.parse_root(nested_brackets)
|
| 128 |
+
except Exception:
|
| 129 |
+
self.fail("parse_root couldn't handle a string with nested brackets.")
|
| 130 |
+
|
| 131 |
+
def test_unmatched_opening_braces(self):
|
| 132 |
+
unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
|
| 133 |
+
try:
|
| 134 |
+
self.parser.parse_root(unmatched_opening_braces)
|
| 135 |
+
except Exception:
|
| 136 |
+
self.fail("parse_root couldn't handle a string with unmatched opening braces.")
|
| 137 |
+
|
| 138 |
+
def test_unmatched_closing_braces(self):
|
| 139 |
+
unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
|
| 140 |
+
try:
|
| 141 |
+
self.parser.parse_root(unmatched_closing_braces)
|
| 142 |
+
except Exception:
|
| 143 |
+
self.fail("parse_root couldn't handle a string with unmatched closing braces.")
|
| 144 |
+
|
| 145 |
+
def test_repeating_colons(self):
|
| 146 |
+
repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
|
| 147 |
+
try:
|
| 148 |
+
self.parser.parse_root(repeating_colons)
|
| 149 |
+
except Exception:
|
| 150 |
+
self.fail("parse_root couldn't handle a string with repeating colons.")
|
| 151 |
+
|
| 152 |
+
def test_excessive_whitespace(self):
|
| 153 |
+
excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
|
| 154 |
+
try:
|
| 155 |
+
self.parser.parse_root(excessive_whitespace)
|
| 156 |
+
except Exception:
|
| 157 |
+
self.fail("parse_root couldn't handle a string with excessive whitespace.")
|
| 158 |
+
|
| 159 |
+
def test_repeating_AND_keyword(self):
|
| 160 |
+
repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
|
| 161 |
+
try:
|
| 162 |
+
self.parser.parse_root(repeating_AND_keyword)
|
| 163 |
+
except Exception:
|
| 164 |
+
self.fail("parse_root couldn't handle a string with repeating AND keyword.")
|
| 165 |
+
|
| 166 |
+
def test_repeating_AND_PERP_keyword(self):
|
| 167 |
+
repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
|
| 168 |
+
try:
|
| 169 |
+
self.parser.parse_root(repeating_AND_PERP_keyword)
|
| 170 |
+
except Exception:
|
| 171 |
+
self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
|
| 172 |
+
|
| 173 |
+
def test_square_weight_prompt(self):
|
| 174 |
+
prompt = "AND_PERP [weighted] you thought it was the end"
|
| 175 |
+
try:
|
| 176 |
+
self.parser.parse_root(prompt)
|
| 177 |
+
except Exception:
|
| 178 |
+
self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == '__main__':
|
| 182 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_matryoshka.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for v1.1 matryoshka tree render, diagnostics, and builder helpers.
|
| 3 |
+
All tests are pure-Python and require no real torch or gradio.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import unittest
|
| 7 |
+
# mock_torch is installed by the package __init__ before this file loads
|
| 8 |
+
|
| 9 |
+
import lib_neutral_prompt.neutral_prompt_parser as p
|
| 10 |
+
from lib_neutral_prompt.matryoshka_utils import (
|
| 11 |
+
render_prompt_node,
|
| 12 |
+
render_prompt_tree,
|
| 13 |
+
collect_prompt_diagnostics,
|
| 14 |
+
render_diagnostics,
|
| 15 |
+
render_full_explain,
|
| 16 |
+
build_child_block,
|
| 17 |
+
build_nested_prompt,
|
| 18 |
+
MATRYOSHKA_TEMPLATES,
|
| 19 |
+
_build_affine_snippet,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# render_prompt_tree
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
class TestRenderPromptTree(unittest.TestCase):
|
| 28 |
+
|
| 29 |
+
def test_empty_prompt_returns_hint(self):
|
| 30 |
+
out = render_prompt_tree('')
|
| 31 |
+
self.assertIn('empty', out.lower())
|
| 32 |
+
|
| 33 |
+
def test_base_only(self):
|
| 34 |
+
out = render_prompt_tree('a dog in a park')
|
| 35 |
+
self.assertIn('ROOT', out)
|
| 36 |
+
self.assertIn('BASE', out)
|
| 37 |
+
self.assertIn('a dog in a park', out)
|
| 38 |
+
|
| 39 |
+
def test_single_child(self):
|
| 40 |
+
out = render_prompt_tree('base subject AND_SALT[5] texture :0.8')
|
| 41 |
+
self.assertIn('BASE', out)
|
| 42 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 43 |
+
self.assertIn('k=5', out)
|
| 44 |
+
|
| 45 |
+
def test_topk_child(self):
|
| 46 |
+
out = render_prompt_tree('base AND_TOPK[0.1] highlights :0.6')
|
| 47 |
+
self.assertIn('SEMANTIC_GUIDANCE', out)
|
| 48 |
+
self.assertIn('threshold=0.1', out)
|
| 49 |
+
|
| 50 |
+
def test_align_child(self):
|
| 51 |
+
out = render_prompt_tree('base AND_ALIGN[4,8] style :0.7')
|
| 52 |
+
self.assertIn('ALIGNMENT_BLEND_CUSTOM', out)
|
| 53 |
+
self.assertIn('d=4', out)
|
| 54 |
+
self.assertIn('s=8', out)
|
| 55 |
+
|
| 56 |
+
def test_affine_annotated(self):
|
| 57 |
+
out = render_prompt_tree('base ROTATE[0.125] AND_PERP vivid :0.8')
|
| 58 |
+
self.assertIn('affine', out)
|
| 59 |
+
|
| 60 |
+
def test_multiple_children_all_shown(self):
|
| 61 |
+
out = render_prompt_tree('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] style :0.6')
|
| 62 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 63 |
+
self.assertIn('SEMANTIC_GUIDANCE', out)
|
| 64 |
+
|
| 65 |
+
def test_nested_prompt_depth(self):
|
| 66 |
+
# nested composite should show multiple levels
|
| 67 |
+
prompt = 'base AND_SALT[5] [ inner text :0.8 AND_TOPK[0.1] detail :0.5 ] :0.7'
|
| 68 |
+
out = render_prompt_tree(prompt)
|
| 69 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 70 |
+
# inner structure rendered
|
| 71 |
+
|
| 72 |
+
def test_tree_connectors_present(self):
|
| 73 |
+
out = render_prompt_tree('base AND_SALT[5] tex :0.8 AND_PERP vivid :0.5')
|
| 74 |
+
# should contain box-drawing branch chars
|
| 75 |
+
self.assertTrue('├' in out or '└' in out)
|
| 76 |
+
|
| 77 |
+
def test_weight_shown(self):
|
| 78 |
+
out = render_prompt_tree('base AND_PERP style :0.42')
|
| 79 |
+
self.assertIn('0.42', out)
|
| 80 |
+
|
| 81 |
+
def test_parse_error_graceful(self):
|
| 82 |
+
# Pathological input must not raise
|
| 83 |
+
try:
|
| 84 |
+
out = render_prompt_tree('[[[[[')
|
| 85 |
+
except Exception as exc:
|
| 86 |
+
self.fail(f'render_prompt_tree raised on bad input: {exc}')
|
| 87 |
+
|
| 88 |
+
def test_deeply_nested(self):
|
| 89 |
+
prompt = ('base AND_SALT[5] [\n'
|
| 90 |
+
' region :0.8\n'
|
| 91 |
+
' AND_ALIGN[4,8] [\n'
|
| 92 |
+
' sub-structure :0.7\n'
|
| 93 |
+
' AND_TOPK[0.05] micro :0.4\n'
|
| 94 |
+
' ] :0.6\n'
|
| 95 |
+
'] :0.7')
|
| 96 |
+
try:
|
| 97 |
+
out = render_prompt_tree(prompt)
|
| 98 |
+
except Exception as exc:
|
| 99 |
+
self.fail(f'deeply nested raised: {exc}')
|
| 100 |
+
self.assertIn('ROOT', out)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# render_prompt_node (unit)
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
class TestRenderPromptNode(unittest.TestCase):
|
| 108 |
+
|
| 109 |
+
def _leaf_node(self, text='hello', strat=None, params=None, w=1.0):
|
| 110 |
+
return p.LeafPrompt(
|
| 111 |
+
weight=w,
|
| 112 |
+
conciliation=strat,
|
| 113 |
+
local_transform=None,
|
| 114 |
+
prompt=text,
|
| 115 |
+
conciliation_params=params or {},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def test_base_leaf(self):
|
| 119 |
+
node = self._leaf_node('base text')
|
| 120 |
+
lines = render_prompt_node(node)
|
| 121 |
+
self.assertEqual(len(lines), 1)
|
| 122 |
+
self.assertIn('BASE', lines[0])
|
| 123 |
+
self.assertIn('base text', lines[0])
|
| 124 |
+
|
| 125 |
+
def test_leaf_with_params(self):
|
| 126 |
+
node = self._leaf_node('texture', strat=p.ConciliationStrategy.SALIENCE_MASK,
|
| 127 |
+
params={'k': 5.0})
|
| 128 |
+
lines = render_prompt_node(node)
|
| 129 |
+
self.assertIn('k=5.0', lines[0])
|
| 130 |
+
|
| 131 |
+
def test_leaf_affine_flag(self):
|
| 132 |
+
import lib_neutral_prompt.neutral_prompt_parser as _pp
|
| 133 |
+
node = self._leaf_node('vivid', strat=_pp.ConciliationStrategy.PERPENDICULAR)
|
| 134 |
+
try:
|
| 135 |
+
import torch
|
| 136 |
+
node.local_transform = torch.eye(3)[:-1]
|
| 137 |
+
except ImportError:
|
| 138 |
+
return # skip if no torch
|
| 139 |
+
lines = render_prompt_node(node)
|
| 140 |
+
self.assertIn('affine', lines[0])
|
| 141 |
+
|
| 142 |
+
def test_connector_chars_for_non_last(self):
|
| 143 |
+
node = self._leaf_node('x')
|
| 144 |
+
lines = render_prompt_node(node, is_last=False)
|
| 145 |
+
self.assertTrue(lines[0].startswith('├─'))
|
| 146 |
+
|
| 147 |
+
def test_connector_chars_for_last(self):
|
| 148 |
+
node = self._leaf_node('x')
|
| 149 |
+
lines = render_prompt_node(node, is_last=True)
|
| 150 |
+
self.assertTrue(lines[0].startswith('└─'))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
# collect_prompt_diagnostics
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
|
| 157 |
+
class TestCollectDiagnostics(unittest.TestCase):
|
| 158 |
+
|
| 159 |
+
def test_empty_prompt(self):
|
| 160 |
+
d = collect_prompt_diagnostics('')
|
| 161 |
+
self.assertFalse(d['parse_ok'])
|
| 162 |
+
|
| 163 |
+
def test_base_only(self):
|
| 164 |
+
d = collect_prompt_diagnostics('a dog')
|
| 165 |
+
self.assertTrue(d['parse_ok'])
|
| 166 |
+
self.assertGreater(d['n_segments'], 0)
|
| 167 |
+
self.assertIn('BASE', d['strategies'])
|
| 168 |
+
|
| 169 |
+
def test_single_salt(self):
|
| 170 |
+
d = collect_prompt_diagnostics('base AND_SALT[5] texture :0.8')
|
| 171 |
+
self.assertTrue(d['parse_ok'])
|
| 172 |
+
self.assertIn('SALIENCE_MASK', d['strategies'])
|
| 173 |
+
self.assertIn('BASE', d['strategies'])
|
| 174 |
+
|
| 175 |
+
def test_no_base_fires_protection(self):
|
| 176 |
+
d = collect_prompt_diagnostics('AND_SALT[5] concept :0.8')
|
| 177 |
+
self.assertTrue(d['parse_ok'])
|
| 178 |
+
self.assertTrue(d['protection_would_fire'])
|
| 179 |
+
|
| 180 |
+
def test_with_base_no_fire(self):
|
| 181 |
+
d = collect_prompt_diagnostics('base subject AND_PERP style :0.8')
|
| 182 |
+
self.assertFalse(d['protection_would_fire'])
|
| 183 |
+
|
| 184 |
+
def test_max_depth_base(self):
|
| 185 |
+
d = collect_prompt_diagnostics('just a base')
|
| 186 |
+
self.assertGreaterEqual(d['max_depth'], 0)
|
| 187 |
+
|
| 188 |
+
def test_affine_counted(self):
|
| 189 |
+
d = collect_prompt_diagnostics('base ROTATE[0.125] AND_PERP vivid :0.8')
|
| 190 |
+
self.assertGreater(d['n_affine'], 0)
|
| 191 |
+
|
| 192 |
+
def test_multiple_strategies(self):
|
| 193 |
+
d = collect_prompt_diagnostics('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] sty :0.6')
|
| 194 |
+
self.assertIn('SALIENCE_MASK', d['strategies'])
|
| 195 |
+
self.assertIn('SEMANTIC_GUIDANCE', d['strategies'])
|
| 196 |
+
|
| 197 |
+
def test_parse_error_graceful(self):
|
| 198 |
+
d = collect_prompt_diagnostics('AND[invalid garbage')
|
| 199 |
+
# Should not raise; parse_ok may be False or True depending on parser tolerance
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
# render_diagnostics
|
| 204 |
+
# ---------------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
+
class TestRenderDiagnostics(unittest.TestCase):
|
| 207 |
+
|
| 208 |
+
def test_no_base_warning(self):
|
| 209 |
+
d = collect_prompt_diagnostics('AND_SALT[5] only :0.8')
|
| 210 |
+
out = render_diagnostics(d)
|
| 211 |
+
self.assertIn('WOULD FIRE', out)
|
| 212 |
+
|
| 213 |
+
def test_ok_case(self):
|
| 214 |
+
d = collect_prompt_diagnostics('base AND_PERP style :0.8')
|
| 215 |
+
out = render_diagnostics(d)
|
| 216 |
+
self.assertIn('OK', out)
|
| 217 |
+
|
| 218 |
+
def test_parse_error_shown(self):
|
| 219 |
+
d = {'parse_ok': False, 'parse_error': 'something bad'}
|
| 220 |
+
out = render_diagnostics(d)
|
| 221 |
+
self.assertIn('something bad', out)
|
| 222 |
+
|
| 223 |
+
def test_strategies_listed(self):
|
| 224 |
+
d = collect_prompt_diagnostics('base AND_SALT[5] tex :0.8')
|
| 225 |
+
out = render_diagnostics(d)
|
| 226 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# ---------------------------------------------------------------------------
|
| 230 |
+
# render_full_explain
|
| 231 |
+
# ---------------------------------------------------------------------------
|
| 232 |
+
|
| 233 |
+
class TestRenderFullExplain(unittest.TestCase):
|
| 234 |
+
|
| 235 |
+
def test_contains_both_tree_and_diagnostics(self):
|
| 236 |
+
out = render_full_explain('base AND_SALT[5] texture :0.8')
|
| 237 |
+
self.assertIn('ROOT', out)
|
| 238 |
+
self.assertIn('Diagnostics', out)
|
| 239 |
+
|
| 240 |
+
def test_empty_prompt(self):
|
| 241 |
+
out = render_full_explain('')
|
| 242 |
+
self.assertIn('empty', out.lower())
|
| 243 |
+
|
| 244 |
+
def test_nested_prompt(self):
|
| 245 |
+
try:
|
| 246 |
+
out = render_full_explain(
|
| 247 |
+
'base AND_SALT[5] [ inner :0.8 AND_TOPK[0.05] detail :0.4 ] :0.7'
|
| 248 |
+
)
|
| 249 |
+
except Exception as exc:
|
| 250 |
+
self.fail(f'render_full_explain raised on nested: {exc}')
|
| 251 |
+
self.assertIn('ROOT', out)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---------------------------------------------------------------------------
|
| 255 |
+
# build_child_block
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
|
| 258 |
+
class TestBuildChildBlock(unittest.TestCase):
|
| 259 |
+
|
| 260 |
+
def test_simple_salt(self):
|
| 261 |
+
out = build_child_block('AND_SALT', 'texture detail', 0.8)
|
| 262 |
+
self.assertIn('AND_SALT[5]', out)
|
| 263 |
+
self.assertIn('texture detail', out)
|
| 264 |
+
self.assertIn(':0.8', out)
|
| 265 |
+
|
| 266 |
+
def test_topk_threshold(self):
|
| 267 |
+
out = build_child_block('AND_TOPK', 'highlights', 0.5)
|
| 268 |
+
self.assertIn('AND_TOPK[0.05]', out)
|
| 269 |
+
self.assertIn('highlights', out)
|
| 270 |
+
|
| 271 |
+
def test_align_default(self):
|
| 272 |
+
out = build_child_block('AND_ALIGN', 'structure', 0.7)
|
| 273 |
+
self.assertIn('AND_ALIGN[4,8]', out)
|
| 274 |
+
|
| 275 |
+
def test_perp_simple(self):
|
| 276 |
+
out = build_child_block('AND_PERP', 'blur', 0.5)
|
| 277 |
+
self.assertIn('AND_PERP', out)
|
| 278 |
+
self.assertIn('blur', out)
|
| 279 |
+
|
| 280 |
+
def test_with_affine(self):
|
| 281 |
+
out = build_child_block('AND_PERP', 'mirror', 0.6, affine='SCALE[-1,1]')
|
| 282 |
+
self.assertIn('SCALE[-1,1]', out)
|
| 283 |
+
self.assertTrue(out.startswith('SCALE[-1,1]'))
|
| 284 |
+
|
| 285 |
+
def test_with_nested_child(self):
|
| 286 |
+
nested = build_child_block('AND_TOPK', 'highlights', 0.4)
|
| 287 |
+
out = build_child_block('AND_SALT', 'texture', 0.8, nested=nested)
|
| 288 |
+
self.assertIn('[', out)
|
| 289 |
+
self.assertIn(']', out)
|
| 290 |
+
self.assertIn('AND_TOPK', out)
|
| 291 |
+
|
| 292 |
+
def test_empty_text_handled(self):
|
| 293 |
+
out = build_child_block('AND_SALT', '', 0.8)
|
| 294 |
+
# Empty text → empty string (caller should skip)
|
| 295 |
+
self.assertEqual(out, '')
|
| 296 |
+
|
| 297 |
+
def test_weight_formatting(self):
|
| 298 |
+
out = build_child_block('AND_PERP', 'concept', 0.35)
|
| 299 |
+
self.assertIn(':0.35', out)
|
| 300 |
+
|
| 301 |
+
def test_nested_block_indented(self):
|
| 302 |
+
nested = build_child_block('AND_PERP', 'inner', 0.4)
|
| 303 |
+
out = build_child_block('AND_SALT', 'outer', 0.8, nested=nested)
|
| 304 |
+
# Inner content should be indented
|
| 305 |
+
lines = out.splitlines()
|
| 306 |
+
inner_lines = [l for l in lines if 'inner' in l]
|
| 307 |
+
self.assertTrue(any(l.startswith(' ') for l in inner_lines),
|
| 308 |
+
'Inner nested line should be indented')
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ---------------------------------------------------------------------------
|
| 312 |
+
# build_nested_prompt
|
| 313 |
+
# ---------------------------------------------------------------------------
|
| 314 |
+
|
| 315 |
+
class TestBuildNestedPrompt(unittest.TestCase):
|
| 316 |
+
|
| 317 |
+
def test_base_only(self):
|
| 318 |
+
out = build_nested_prompt('a dog', [])
|
| 319 |
+
self.assertEqual(out.strip(), 'a dog')
|
| 320 |
+
|
| 321 |
+
def test_base_with_one_child(self):
|
| 322 |
+
out = build_nested_prompt('base subject', [
|
| 323 |
+
{'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8},
|
| 324 |
+
])
|
| 325 |
+
self.assertIn('base subject', out)
|
| 326 |
+
self.assertIn('AND_SALT', out)
|
| 327 |
+
self.assertIn('texture', out)
|
| 328 |
+
|
| 329 |
+
def test_two_children(self):
|
| 330 |
+
out = build_nested_prompt('base', [
|
| 331 |
+
{'strategy': 'AND_SALT', 'text': 'tex', 'weight': 0.8},
|
| 332 |
+
{'strategy': 'AND_TOPK', 'text': 'sty', 'weight': 0.5},
|
| 333 |
+
])
|
| 334 |
+
self.assertIn('AND_SALT', out)
|
| 335 |
+
self.assertIn('AND_TOPK', out)
|
| 336 |
+
|
| 337 |
+
def test_child_with_nested(self):
|
| 338 |
+
inner = build_child_block('AND_TOPK', 'highlights', 0.4)
|
| 339 |
+
out = build_nested_prompt('base', [
|
| 340 |
+
{'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8, 'nested': inner},
|
| 341 |
+
])
|
| 342 |
+
self.assertIn('AND_TOPK', out)
|
| 343 |
+
|
| 344 |
+
def test_generated_prompt_parseable(self):
|
| 345 |
+
out = build_nested_prompt('base subject', [
|
| 346 |
+
{'strategy': 'AND_SALT', 'text': 'tex', 'weight': 0.8},
|
| 347 |
+
{'strategy': 'AND_PERP', 'text': 'sty', 'weight': 0.5},
|
| 348 |
+
])
|
| 349 |
+
try:
|
| 350 |
+
p.parse_root(out)
|
| 351 |
+
except Exception as exc:
|
| 352 |
+
self.fail(f'build_nested_prompt produced unparseable prompt: {exc}')
|
| 353 |
+
|
| 354 |
+
def test_empty_base_no_crash(self):
|
| 355 |
+
out = build_nested_prompt('', [
|
| 356 |
+
{'strategy': 'AND_SALT', 'text': 'texture', 'weight': 0.8},
|
| 357 |
+
])
|
| 358 |
+
self.assertIn('AND_SALT', out)
|
| 359 |
+
|
| 360 |
+
def test_child_with_empty_text_skipped(self):
|
| 361 |
+
out = build_nested_prompt('base', [
|
| 362 |
+
{'strategy': 'AND_SALT', 'text': '', 'weight': 0.8},
|
| 363 |
+
])
|
| 364 |
+
# Empty text child: no crash, base survives
|
| 365 |
+
self.assertIn('base', out)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ---------------------------------------------------------------------------
|
| 369 |
+
# MATRYOSHKA_TEMPLATES
|
| 370 |
+
# ---------------------------------------------------------------------------
|
| 371 |
+
|
| 372 |
+
class TestMatryoshkaTemplates(unittest.TestCase):
|
| 373 |
+
|
| 374 |
+
def test_all_templates_have_required_keys(self):
|
| 375 |
+
for name, tmpl in MATRYOSHKA_TEMPLATES.items():
|
| 376 |
+
with self.subTest(template=name):
|
| 377 |
+
self.assertIn('description', tmpl, f'{name}: missing "description"')
|
| 378 |
+
self.assertIn('prompt', tmpl, f'{name}: missing "prompt"')
|
| 379 |
+
self.assertIsInstance(tmpl['description'], str)
|
| 380 |
+
self.assertIsInstance(tmpl['prompt'], str)
|
| 381 |
+
|
| 382 |
+
def test_all_templates_parse_without_crash(self):
|
| 383 |
+
for name, tmpl in MATRYOSHKA_TEMPLATES.items():
|
| 384 |
+
with self.subTest(template=name):
|
| 385 |
+
try:
|
| 386 |
+
p.parse_root(tmpl['prompt'])
|
| 387 |
+
except Exception as exc:
|
| 388 |
+
self.fail(f'Template "{name}" failed parse_root: {exc}')
|
| 389 |
+
|
| 390 |
+
def test_all_template_prompts_nonempty(self):
|
| 391 |
+
for name, tmpl in MATRYOSHKA_TEMPLATES.items():
|
| 392 |
+
with self.subTest(template=name):
|
| 393 |
+
self.assertTrue(tmpl['prompt'].strip(), f'{name} prompt is empty')
|
| 394 |
+
|
| 395 |
+
def test_at_least_five_templates(self):
|
| 396 |
+
self.assertGreaterEqual(len(MATRYOSHKA_TEMPLATES), 5)
|
| 397 |
+
|
| 398 |
+
def test_deep_nested_template_parseable(self):
|
| 399 |
+
tmpl = MATRYOSHKA_TEMPLATES.get('Deep nested concept isolation')
|
| 400 |
+
if tmpl:
|
| 401 |
+
try:
|
| 402 |
+
p.parse_root(tmpl['prompt'])
|
| 403 |
+
except Exception as exc:
|
| 404 |
+
self.fail(f'Deep nested template parse failed: {exc}')
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
# ---------------------------------------------------------------------------
|
| 408 |
+
# _build_affine_snippet (already tested in test_parametric_syntax, quick sanity)
|
| 409 |
+
# ---------------------------------------------------------------------------
|
| 410 |
+
|
| 411 |
+
class TestAffineSnippetSanity(unittest.TestCase):
|
| 412 |
+
|
| 413 |
+
def test_flip_h(self):
|
| 414 |
+
self.assertEqual(_build_affine_snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')
|
| 415 |
+
|
| 416 |
+
def test_rotate(self):
|
| 417 |
+
s = _build_affine_snippet('ROTATE', 0.25, 1.0)
|
| 418 |
+
self.assertIn('ROTATE', s)
|
| 419 |
+
self.assertIn('0.25', s)
|
| 420 |
+
|
| 421 |
+
def test_scale_uniform(self):
|
| 422 |
+
s = _build_affine_snippet('SCALE', 2.0, 2.0)
|
| 423 |
+
self.assertEqual(s, 'SCALE[2.0]')
|
| 424 |
+
|
| 425 |
+
def test_slide_2d(self):
|
| 426 |
+
s = _build_affine_snippet('SLIDE', 0.1, 0.2)
|
| 427 |
+
self.assertEqual(s, 'SLIDE[0.1,0.2]')
|
| 428 |
+
|
| 429 |
+
def test_snippet_parseable(self):
|
| 430 |
+
snippets = ['ROTATE[0.125]', 'SCALE[-1,1]', 'SCALE[0.5]', 'SLIDE[0.1,0.0]']
|
| 431 |
+
for snip in snippets:
|
| 432 |
+
with self.subTest(snip=snip):
|
| 433 |
+
try:
|
| 434 |
+
p.parse_root(f'base {snip} AND_PERP style :0.8')
|
| 435 |
+
except Exception as exc:
|
| 436 |
+
self.fail(f'Snippet "{snip}" not parseable: {exc}')
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if __name__ == '__main__':
|
| 440 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_matryoshka_golden.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Golden tests for matryoshka_utils: tree render, diagnostics, effect summaries.
|
| 3 |
+
|
| 4 |
+
These tests lock in the *exact* formatted output for a set of representative
|
| 5 |
+
prompts so that silent formatting/logic regressions are caught immediately.
|
| 6 |
+
They also cover the polishing fixes from Sprint 1:
|
| 7 |
+
- build_child_block: empty text → empty string
|
| 8 |
+
- build_child_block: empty text + nested → no bare ':weight' line
|
| 9 |
+
- render_full_explain: empty prompt → clean message only
|
| 10 |
+
- effect summaries appear in render_full_explain output
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import unittest
|
| 14 |
+
# mock_torch installed by package __init__
|
| 15 |
+
|
| 16 |
+
import lib_neutral_prompt.neutral_prompt_parser as p
|
| 17 |
+
from lib_neutral_prompt.matryoshka_utils import (
|
| 18 |
+
build_child_block,
|
| 19 |
+
build_nested_prompt,
|
| 20 |
+
render_prompt_tree,
|
| 21 |
+
render_prompt_node,
|
| 22 |
+
collect_prompt_diagnostics,
|
| 23 |
+
render_diagnostics,
|
| 24 |
+
render_full_explain,
|
| 25 |
+
_render_effect_summary,
|
| 26 |
+
_effect_for_strategy,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Builder polishing: empty-text behaviour
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
class TestBuilderEmptyTextPolishing(unittest.TestCase):
|
| 35 |
+
|
| 36 |
+
def test_empty_text_returns_empty_string(self):
|
| 37 |
+
out = build_child_block('AND_SALT', '', 0.8)
|
| 38 |
+
self.assertEqual(out, '',
|
| 39 |
+
'Empty text with no nested must return empty string')
|
| 40 |
+
|
| 41 |
+
def test_empty_text_empty_nested_returns_empty_string(self):
|
| 42 |
+
out = build_child_block('AND_TOPK', '', 0.5, nested='')
|
| 43 |
+
self.assertEqual(out, '')
|
| 44 |
+
|
| 45 |
+
def test_empty_text_whitespace_returns_empty_string(self):
|
| 46 |
+
out = build_child_block('AND_PERP', ' ', 0.7)
|
| 47 |
+
self.assertEqual(out, '')
|
| 48 |
+
|
| 49 |
+
def test_empty_text_with_nested_no_bare_weight_line(self):
|
| 50 |
+
"""If text is empty but nested is not, the nested block must not
|
| 51 |
+
include a bare ' :0.5' line at the top."""
|
| 52 |
+
nested = 'AND_PERP blur :0.4'
|
| 53 |
+
out = build_child_block('AND_TOPK', '', 0.5, nested=nested)
|
| 54 |
+
# Should not be empty (nested content exists)
|
| 55 |
+
self.assertNotEqual(out, '')
|
| 56 |
+
# Must not contain a line that is only whitespace + ':weight'
|
| 57 |
+
for line in out.splitlines():
|
| 58 |
+
stripped = line.strip()
|
| 59 |
+
self.assertFalse(
|
| 60 |
+
stripped.startswith(':') and len(stripped.split()) == 1,
|
| 61 |
+
f'Found bare weight line: {repr(line)}',
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def test_nonempty_text_with_nested_includes_text_line(self):
|
| 65 |
+
"""When text is non-empty, the nested block must include a text line."""
|
| 66 |
+
nested = 'AND_PERP correction :0.3'
|
| 67 |
+
out = build_child_block('AND_SALT', 'texture', 0.8, nested=nested)
|
| 68 |
+
self.assertIn('texture', out)
|
| 69 |
+
self.assertIn('AND_PERP', out)
|
| 70 |
+
self.assertIn('[', out)
|
| 71 |
+
|
| 72 |
+
def test_nested_prompt_skips_empty_children(self):
|
| 73 |
+
out = build_nested_prompt('base subject', [
|
| 74 |
+
{'strategy': 'AND_SALT', 'text': '', 'weight': 0.8}, # empty → skip
|
| 75 |
+
{'strategy': 'AND_PERP', 'text': 'vivid', 'weight': 0.5}, # kept
|
| 76 |
+
{'strategy': 'AND_TOPK', 'text': ' ', 'weight': 0.6}, # whitespace → skip
|
| 77 |
+
])
|
| 78 |
+
self.assertIn('AND_PERP', out)
|
| 79 |
+
self.assertNotIn('AND_SALT', out)
|
| 80 |
+
self.assertNotIn('AND_TOPK', out)
|
| 81 |
+
|
| 82 |
+
def test_nested_prompt_all_empty_children_gives_base_only(self):
|
| 83 |
+
out = build_nested_prompt('base', [
|
| 84 |
+
{'strategy': 'AND_SALT', 'text': '', 'weight': 0.8},
|
| 85 |
+
])
|
| 86 |
+
self.assertEqual(out.strip(), 'base')
|
| 87 |
+
|
| 88 |
+
def test_nested_only_child_no_text(self):
|
| 89 |
+
"""Child with empty text but non-empty nested should be kept."""
|
| 90 |
+
inner = build_child_block('AND_PERP', 'correction', 0.3)
|
| 91 |
+
out = build_child_block('AND_TOPK', '', 0.5, nested=inner)
|
| 92 |
+
self.assertNotEqual(out, '', 'Should keep block when nested is non-empty')
|
| 93 |
+
self.assertIn('AND_PERP', out)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
# render_full_explain: empty-prompt clean state
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
class TestRenderFullExplainEmptyState(unittest.TestCase):
|
| 101 |
+
|
| 102 |
+
def test_empty_string_returns_clean_message(self):
|
| 103 |
+
out = render_full_explain('')
|
| 104 |
+
self.assertIn('empty', out.lower())
|
| 105 |
+
self.assertNotIn('Parse error', out)
|
| 106 |
+
self.assertNotIn('unknown', out.lower())
|
| 107 |
+
|
| 108 |
+
def test_whitespace_only_returns_clean_message(self):
|
| 109 |
+
out = render_full_explain(' \n ')
|
| 110 |
+
self.assertIn('empty', out.lower())
|
| 111 |
+
self.assertNotIn('Parse error', out)
|
| 112 |
+
|
| 113 |
+
def test_nonempty_prompt_no_empty_message(self):
|
| 114 |
+
out = render_full_explain('base AND_SALT[5] texture :0.8')
|
| 115 |
+
self.assertNotIn('empty prompt', out.lower())
|
| 116 |
+
self.assertIn('ROOT', out)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
# Effect summaries
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
class TestEffectSummary(unittest.TestCase):
|
| 124 |
+
|
| 125 |
+
def test_effect_for_base(self):
|
| 126 |
+
desc = _effect_for_strategy('BASE')
|
| 127 |
+
self.assertIn('base', desc.lower())
|
| 128 |
+
|
| 129 |
+
def test_effect_for_salience_mask(self):
|
| 130 |
+
desc = _effect_for_strategy('SALIENCE_MASK')
|
| 131 |
+
self.assertIn('saliency', desc.lower())
|
| 132 |
+
|
| 133 |
+
def test_effect_for_semantic_guidance(self):
|
| 134 |
+
desc = _effect_for_strategy('SEMANTIC_GUIDANCE')
|
| 135 |
+
self.assertIn('top-k', desc.lower())
|
| 136 |
+
|
| 137 |
+
def test_effect_for_alignment_custom(self):
|
| 138 |
+
desc = _effect_for_strategy('ALIGNMENT_BLEND_CUSTOM')
|
| 139 |
+
self.assertIn('align', desc.lower())
|
| 140 |
+
|
| 141 |
+
def test_effect_for_alignment_legacy_suffix(self):
|
| 142 |
+
# Fixed-suffix strategies like ALIGNMENT_BLEND_4_8 → same description
|
| 143 |
+
desc = _effect_for_strategy('ALIGNMENT_BLEND_4_8')
|
| 144 |
+
self.assertIn('align', desc.lower())
|
| 145 |
+
|
| 146 |
+
def test_effect_for_unknown_strategy(self):
|
| 147 |
+
desc = _effect_for_strategy('SOME_UNKNOWN_STRATEGY')
|
| 148 |
+
self.assertIn('SOME_UNKNOWN_STRATEGY', desc)
|
| 149 |
+
|
| 150 |
+
def test_effect_summary_in_render_full_explain(self):
|
| 151 |
+
out = render_full_explain('base AND_SALT[5] texture :0.8')
|
| 152 |
+
self.assertIn('Effect summary', out)
|
| 153 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 154 |
+
|
| 155 |
+
def test_effect_summary_notes_preview_nature(self):
|
| 156 |
+
out = render_full_explain('base AND_TOPK[0.1] highlights :0.6')
|
| 157 |
+
self.assertIn('structural preview', out.lower())
|
| 158 |
+
|
| 159 |
+
def test_render_effect_summary_empty_on_parse_fail(self):
|
| 160 |
+
diag = {'parse_ok': False, 'strategies': []}
|
| 161 |
+
out = _render_effect_summary(diag)
|
| 162 |
+
self.assertEqual(out, '')
|
| 163 |
+
|
| 164 |
+
def test_render_effect_summary_empty_on_no_strategies(self):
|
| 165 |
+
diag = {'parse_ok': True, 'strategies': []}
|
| 166 |
+
out = _render_effect_summary(diag)
|
| 167 |
+
self.assertEqual(out, '')
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# Golden: tree format
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
class TestTreeGoldenFormat(unittest.TestCase):
|
| 175 |
+
"""Check structural properties of the tree output rather than exact strings,
|
| 176 |
+
making these tests robust to minor phrasing changes while still catching
|
| 177 |
+
regressions in tree shape, connectors, and content."""
|
| 178 |
+
|
| 179 |
+
def _tree(self, prompt):
|
| 180 |
+
return render_prompt_tree(prompt)
|
| 181 |
+
|
| 182 |
+
def test_root_header(self):
|
| 183 |
+
out = self._tree('base AND_PERP style :0.8')
|
| 184 |
+
self.assertTrue(out.startswith('ROOT'), f'Expected ROOT header, got: {out[:40]}')
|
| 185 |
+
|
| 186 |
+
def test_leaf_connector_last(self):
|
| 187 |
+
out = self._tree('base AND_PERP style :0.8')
|
| 188 |
+
# Last segment uses └─
|
| 189 |
+
self.assertIn('└─', out)
|
| 190 |
+
|
| 191 |
+
def test_multiple_children_mixed_connectors(self):
|
| 192 |
+
out = self._tree('base AND_SALT[5] tex :0.8 AND_PERP vivid :0.5')
|
| 193 |
+
self.assertIn('├─', out)
|
| 194 |
+
self.assertIn('└─', out)
|
| 195 |
+
|
| 196 |
+
def test_base_shows_BASE(self):
|
| 197 |
+
out = self._tree('base subject')
|
| 198 |
+
self.assertIn('BASE', out)
|
| 199 |
+
|
| 200 |
+
def test_salt_shows_k_param(self):
|
| 201 |
+
out = self._tree('base AND_SALT[7] texture :0.8')
|
| 202 |
+
self.assertIn('k=7.0', out)
|
| 203 |
+
|
| 204 |
+
def test_topk_shows_threshold_param(self):
|
| 205 |
+
out = self._tree('base AND_TOPK[0.15] highlights :0.6')
|
| 206 |
+
self.assertIn('threshold=0.15', out)
|
| 207 |
+
|
| 208 |
+
def test_align_shows_d_s_params(self):
|
| 209 |
+
out = self._tree('base AND_ALIGN[6,12] style :0.7')
|
| 210 |
+
self.assertIn('d=6', out)
|
| 211 |
+
self.assertIn('s=12', out)
|
| 212 |
+
|
| 213 |
+
def test_weight_shown_to_two_decimals(self):
|
| 214 |
+
out = self._tree('base AND_PERP vivid :0.42')
|
| 215 |
+
self.assertIn('0.42', out)
|
| 216 |
+
|
| 217 |
+
def test_nested_composite_shows_children_count(self):
|
| 218 |
+
prompt = 'base AND_SALT[5] [ texture :0.8 AND_TOPK[0.05] detail :0.4 ] :0.7'
|
| 219 |
+
out = self._tree(prompt)
|
| 220 |
+
# Composite node should show "(N children)"
|
| 221 |
+
self.assertIn('children)', out)
|
| 222 |
+
|
| 223 |
+
def test_deep_nested_has_continuation_bar(self):
|
| 224 |
+
prompt = ('base AND_SALT[5] [\n'
|
| 225 |
+
' region :0.8\n'
|
| 226 |
+
' AND_ALIGN[4,8] [\n'
|
| 227 |
+
' sub :0.7\n'
|
| 228 |
+
' AND_TOPK[0.05] micro :0.4\n'
|
| 229 |
+
' ] :0.6\n'
|
| 230 |
+
'] :0.7')
|
| 231 |
+
out = self._tree(prompt)
|
| 232 |
+
# Deep nesting should show either │ continuation bars or ├─/└─ at 2+ depths
|
| 233 |
+
# (exact rendering depends on depth threshold)
|
| 234 |
+
# Ensure at least 3 levels of nesting are represented in the tree
|
| 235 |
+
self.assertGreaterEqual(out.count('├─') + out.count('└─'), 4,
|
| 236 |
+
'Expected at least 4 branch markers for deep nesting')
|
| 237 |
+
|
| 238 |
+
def test_affine_annotated(self):
|
| 239 |
+
out = self._tree('base ROTATE[0.125] AND_PERP vivid :0.8')
|
| 240 |
+
self.assertIn('affine', out)
|
| 241 |
+
|
| 242 |
+
def test_text_truncated_at_55_chars(self):
|
| 243 |
+
long_text = 'a' * 60
|
| 244 |
+
out = self._tree(f'base AND_PERP {long_text} :0.8')
|
| 245 |
+
self.assertIn('…', out)
|
| 246 |
+
|
| 247 |
+
def test_empty_returns_clean_empty_message(self):
|
| 248 |
+
out = self._tree('')
|
| 249 |
+
self.assertIn('empty', out.lower())
|
| 250 |
+
self.assertNotIn('ROOT', out)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# ---------------------------------------------------------------------------
|
| 254 |
+
# Golden: diagnostics content
|
| 255 |
+
# ---------------------------------------------------------------------------
|
| 256 |
+
|
| 257 |
+
class TestDiagnosticsGolden(unittest.TestCase):
|
| 258 |
+
|
| 259 |
+
def _diag(self, prompt):
|
| 260 |
+
return collect_prompt_diagnostics(prompt)
|
| 261 |
+
|
| 262 |
+
def test_base_only_n_segments(self):
|
| 263 |
+
d = self._diag('a dog in a park')
|
| 264 |
+
self.assertGreaterEqual(d['n_segments'], 1)
|
| 265 |
+
|
| 266 |
+
def test_two_children_strategy_list(self):
|
| 267 |
+
d = self._diag('base AND_SALT[5] tex :0.8 AND_TOPK[0.1] sty :0.6')
|
| 268 |
+
self.assertIn('SALIENCE_MASK', d['strategies'])
|
| 269 |
+
self.assertIn('SEMANTIC_GUIDANCE', d['strategies'])
|
| 270 |
+
self.assertIn('BASE', d['strategies'])
|
| 271 |
+
|
| 272 |
+
def test_affine_count(self):
|
| 273 |
+
d = self._diag('base ROTATE[0.125] AND_PERP vivid :0.8')
|
| 274 |
+
self.assertGreaterEqual(d['n_affine'], 1)
|
| 275 |
+
|
| 276 |
+
def test_protection_fires_no_base(self):
|
| 277 |
+
d = self._diag('AND_SALT[5] only :0.8')
|
| 278 |
+
self.assertTrue(d['protection_would_fire'])
|
| 279 |
+
|
| 280 |
+
def test_protection_ok_with_base(self):
|
| 281 |
+
d = self._diag('base AND_SALT[5] tex :0.8')
|
| 282 |
+
self.assertFalse(d['protection_would_fire'])
|
| 283 |
+
|
| 284 |
+
def test_depth_deep_nested(self):
|
| 285 |
+
prompt = ('base AND_SALT[5] [\n'
|
| 286 |
+
' region :0.8\n'
|
| 287 |
+
' AND_ALIGN[4,8] [\n'
|
| 288 |
+
' sub :0.7\n'
|
| 289 |
+
' AND_TOPK[0.05] micro :0.4\n'
|
| 290 |
+
' ] :0.6\n'
|
| 291 |
+
'] :0.7')
|
| 292 |
+
d = self._diag(prompt)
|
| 293 |
+
self.assertGreaterEqual(d['max_depth'], 2)
|
| 294 |
+
|
| 295 |
+
def test_render_diagnostics_includes_strategies(self):
|
| 296 |
+
d = self._diag('base AND_SALT[5] tex :0.8')
|
| 297 |
+
out = render_diagnostics(d)
|
| 298 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 299 |
+
self.assertIn('Strategies used', out)
|
| 300 |
+
|
| 301 |
+
def test_render_diagnostics_protection_would_fire_message(self):
|
| 302 |
+
d = self._diag('AND_TOPK[0.1] only :0.8')
|
| 303 |
+
out = render_diagnostics(d)
|
| 304 |
+
self.assertIn('WOULD FIRE', out)
|
| 305 |
+
|
| 306 |
+
def test_render_diagnostics_ok_message(self):
|
| 307 |
+
d = self._diag('base AND_PERP style :0.8')
|
| 308 |
+
out = render_diagnostics(d)
|
| 309 |
+
self.assertIn('OK', out)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# ---------------------------------------------------------------------------
|
| 313 |
+
# Integration: full explain round-trip for all templates
|
| 314 |
+
# ---------------------------------------------------------------------------
|
| 315 |
+
|
| 316 |
+
class TestTemplatesExplainIntegration(unittest.TestCase):
|
| 317 |
+
"""Every template must produce a non-trivial full explain output."""
|
| 318 |
+
|
| 319 |
+
def test_all_templates_full_explain(self):
|
| 320 |
+
from lib_neutral_prompt.matryoshka_utils import MATRYOSHKA_TEMPLATES
|
| 321 |
+
for name, tmpl in MATRYOSHKA_TEMPLATES.items():
|
| 322 |
+
with self.subTest(template=name):
|
| 323 |
+
out = render_full_explain(tmpl['prompt'])
|
| 324 |
+
self.assertIn('ROOT', out, f'{name}: missing ROOT in explain')
|
| 325 |
+
self.assertIn('Diagnostics', out, f'{name}: missing Diagnostics')
|
| 326 |
+
self.assertIn('Effect summary', out, f'{name}: missing Effect summary')
|
| 327 |
+
self.assertNotIn('Parse error', out, f'{name}: unexpected parse error')
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
if __name__ == '__main__':
|
| 331 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_parametric_syntax.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for parametric prompt syntax (patched extension).
|
| 3 |
+
|
| 4 |
+
Requires: torch OR falls back to mock_torch for pure parser logic.
|
| 5 |
+
|
| 6 |
+
Covers:
|
| 7 |
+
- AND_SALT[k] / AND_SALT_WIDE[k] / AND_SALT_BLOB[k] parser
|
| 8 |
+
- AND_ALIGN[d,s] / AND_MASK_ALIGN[d,s] parser
|
| 9 |
+
- conciliation_params populated correctly
|
| 10 |
+
- weight still parsed correctly alongside params
|
| 11 |
+
- composite brackets NOT consumed as params
|
| 12 |
+
- backward-compat: AND_ALIGN_4_8 still works
|
| 13 |
+
- invalid / boundary params rejected gracefully
|
| 14 |
+
- params survive alongside affine transforms
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import pathlib
|
| 18 |
+
import sys
|
| 19 |
+
import unittest
|
| 20 |
+
|
| 21 |
+
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
|
| 22 |
+
|
| 23 |
+
# mock_torch installed by package __init__.py
|
| 24 |
+
|
| 25 |
+
from lib_neutral_prompt import neutral_prompt_parser as p
|
| 26 |
+
|
| 27 |
+
CS = p.ConciliationStrategy
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _child(prompt_str, idx=-1):
|
| 31 |
+
return p.parse_root(prompt_str).children[idx]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TestSaltKParam(unittest.TestCase):
|
| 35 |
+
|
| 36 |
+
def test_salt_default_no_params(self):
|
| 37 |
+
child = _child("base AND_SALT concept :1.0")
|
| 38 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK)
|
| 39 |
+
self.assertEqual(child.conciliation_params, {})
|
| 40 |
+
|
| 41 |
+
def test_salt_with_k_integer(self):
|
| 42 |
+
child = _child("base AND_SALT[5] concept :1.0")
|
| 43 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK)
|
| 44 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 5.0)
|
| 45 |
+
|
| 46 |
+
def test_salt_with_k_float(self):
|
| 47 |
+
child = _child("base AND_SALT[3.5] concept :0.8")
|
| 48 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 3.5)
|
| 49 |
+
|
| 50 |
+
def test_salt_k_preserves_weight(self):
|
| 51 |
+
child = _child("base AND_SALT[5] concept :0.75")
|
| 52 |
+
self.assertAlmostEqual(child.weight, 0.75)
|
| 53 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 5.0)
|
| 54 |
+
|
| 55 |
+
def test_salt_wide_with_k(self):
|
| 56 |
+
child = _child("base AND_SALT_WIDE[1.5] concept :1.0")
|
| 57 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK_WIDE)
|
| 58 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 1.5)
|
| 59 |
+
|
| 60 |
+
def test_salt_wide_default_no_params(self):
|
| 61 |
+
child = _child("base AND_SALT_WIDE concept :1.0")
|
| 62 |
+
self.assertEqual(child.conciliation_params, {})
|
| 63 |
+
|
| 64 |
+
def test_salt_blob_with_k(self):
|
| 65 |
+
child = _child("base AND_SALT_BLOB[8] concept :1.0")
|
| 66 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK_BLOB)
|
| 67 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 8.0)
|
| 68 |
+
|
| 69 |
+
def test_salt_blob_default_no_params(self):
|
| 70 |
+
child = _child("base AND_SALT_BLOB concept :1.0")
|
| 71 |
+
self.assertEqual(child.conciliation_params, {})
|
| 72 |
+
|
| 73 |
+
def test_salt_k_zero_rejected(self):
|
| 74 |
+
child = _child("base AND_SALT[0] concept :1.0")
|
| 75 |
+
self.assertEqual(child.conciliation_params, {})
|
| 76 |
+
|
| 77 |
+
def test_salt_k_negative_rejected(self):
|
| 78 |
+
child = _child("base AND_SALT[-1] concept :1.0")
|
| 79 |
+
self.assertEqual(child.conciliation_params, {})
|
| 80 |
+
|
| 81 |
+
def test_salt_k_max_value_accepted(self):
|
| 82 |
+
child = _child("base AND_SALT[20] concept :1.0")
|
| 83 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 20.0)
|
| 84 |
+
|
| 85 |
+
def test_salt_k_very_small_accepted(self):
|
| 86 |
+
child = _child("base AND_SALT[0.5] concept :1.0")
|
| 87 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 0.5)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class TestAlignBracketParam(unittest.TestCase):
|
| 91 |
+
|
| 92 |
+
def test_align_bracket_basic(self):
|
| 93 |
+
child = _child("base AND_ALIGN[4,8] style :0.5")
|
| 94 |
+
self.assertEqual(child.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
|
| 95 |
+
self.assertEqual(child.conciliation_params, {'d': 4, 's': 8})
|
| 96 |
+
|
| 97 |
+
def test_align_bracket_large_pair(self):
|
| 98 |
+
child = _child("base AND_ALIGN[8,32] style :0.5")
|
| 99 |
+
self.assertEqual(child.conciliation_params, {'d': 8, 's': 32})
|
| 100 |
+
|
| 101 |
+
def test_align_bracket_preserves_weight(self):
|
| 102 |
+
child = _child("base AND_ALIGN[4,8] style :0.65")
|
| 103 |
+
self.assertAlmostEqual(child.weight, 0.65)
|
| 104 |
+
self.assertEqual(child.conciliation_params, {'d': 4, 's': 8})
|
| 105 |
+
|
| 106 |
+
def test_mask_align_bracket_basic(self):
|
| 107 |
+
child = _child("base AND_MASK_ALIGN[6,12] structure :0.7")
|
| 108 |
+
self.assertEqual(child.conciliation, CS.ALIGNMENT_MASK_BLEND_CUSTOM)
|
| 109 |
+
self.assertEqual(child.conciliation_params, {'d': 6, 's': 12})
|
| 110 |
+
|
| 111 |
+
def test_mask_align_bracket_min_pair(self):
|
| 112 |
+
child = _child("base AND_MASK_ALIGN[2,3] structure :0.5")
|
| 113 |
+
self.assertEqual(child.conciliation_params, {'d': 2, 's': 3})
|
| 114 |
+
|
| 115 |
+
def test_align_equal_ds_rejected(self):
|
| 116 |
+
"""d == s has no structural distinction → params must be rejected."""
|
| 117 |
+
child = _child("base AND_ALIGN[4,4] style :0.5")
|
| 118 |
+
self.assertEqual(child.conciliation_params, {})
|
| 119 |
+
|
| 120 |
+
def test_align_d_out_of_range_low(self):
|
| 121 |
+
child = _child("base AND_ALIGN[1,8] style :0.5")
|
| 122 |
+
self.assertEqual(child.conciliation_params, {})
|
| 123 |
+
|
| 124 |
+
def test_align_s_out_of_range_high(self):
|
| 125 |
+
child = _child("base AND_ALIGN[4,33] style :0.5")
|
| 126 |
+
self.assertEqual(child.conciliation_params, {})
|
| 127 |
+
|
| 128 |
+
def test_align_inverted_pair_accepted(self):
|
| 129 |
+
"""d > s is valid (just an unusual choice) as long as d != s."""
|
| 130 |
+
child = _child("base AND_ALIGN[8,4] style :0.5")
|
| 131 |
+
self.assertEqual(child.conciliation_params, {'d': 8, 's': 4})
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class TestBackwardCompatibility(unittest.TestCase):
|
| 135 |
+
|
| 136 |
+
def test_fixed_suffix_align(self):
|
| 137 |
+
child = _child("base AND_ALIGN_4_8 style :0.5")
|
| 138 |
+
self.assertEqual(child.conciliation, CS.ALIGNMENT_BLEND_4_8)
|
| 139 |
+
self.assertEqual(child.conciliation_params, {})
|
| 140 |
+
|
| 141 |
+
def test_fixed_suffix_mask_align(self):
|
| 142 |
+
child = _child("base AND_MASK_ALIGN_4_8 structure :0.5")
|
| 143 |
+
self.assertEqual(child.conciliation, CS.ALIGNMENT_MASK_BLEND_4_8)
|
| 144 |
+
self.assertEqual(child.conciliation_params, {})
|
| 145 |
+
|
| 146 |
+
def test_both_forms_coexist(self):
|
| 147 |
+
root = p.parse_root("base AND_ALIGN_4_8 old :0.5 AND_ALIGN[6,12] new :0.7")
|
| 148 |
+
old, new = root.children[1], root.children[2]
|
| 149 |
+
self.assertEqual(old.conciliation, CS.ALIGNMENT_BLEND_4_8)
|
| 150 |
+
self.assertEqual(old.conciliation_params, {})
|
| 151 |
+
self.assertEqual(new.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
|
| 152 |
+
self.assertEqual(new.conciliation_params, {'d': 6, 's': 12})
|
| 153 |
+
|
| 154 |
+
def test_plain_and_salt_unchanged(self):
|
| 155 |
+
child = _child("base AND_SALT concept :1.0")
|
| 156 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK)
|
| 157 |
+
self.assertEqual(child.conciliation_params, {})
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TestCompositeBracketNotConsumed(unittest.TestCase):
|
| 161 |
+
|
| 162 |
+
def test_align_composite_not_eaten(self):
|
| 163 |
+
root = p.parse_root("base AND_ALIGN[arst AND defg :2.0] :0.5")
|
| 164 |
+
child = root.children[-1]
|
| 165 |
+
self.assertIsInstance(child, p.CompositePrompt)
|
| 166 |
+
self.assertEqual(child.conciliation_params, {})
|
| 167 |
+
|
| 168 |
+
def test_salt_composite_not_eaten(self):
|
| 169 |
+
root = p.parse_root("base AND_SALT[concept AND style :1.5] :0.8")
|
| 170 |
+
child = root.children[-1]
|
| 171 |
+
self.assertIsInstance(child, p.CompositePrompt)
|
| 172 |
+
self.assertEqual(child.conciliation_params, {})
|
| 173 |
+
|
| 174 |
+
def test_mask_align_composite_not_eaten(self):
|
| 175 |
+
root = p.parse_root("base AND_MASK_ALIGN[a AND b] :0.5")
|
| 176 |
+
child = root.children[-1]
|
| 177 |
+
self.assertIsInstance(child, p.CompositePrompt)
|
| 178 |
+
self.assertEqual(child.conciliation_params, {})
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class TestParamsWithAffineTransform(unittest.TestCase):
|
| 182 |
+
|
| 183 |
+
def test_salt_k_with_trailing_affine(self):
|
| 184 |
+
child = _child("base AND_SALT[5] ROTATE[0.125] concept :0.8")
|
| 185 |
+
self.assertEqual(child.conciliation, CS.SALIENCE_MASK)
|
| 186 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 5.0)
|
| 187 |
+
self.assertIsNotNone(child.local_transform)
|
| 188 |
+
|
| 189 |
+
def test_align_custom_with_leading_affine(self):
|
| 190 |
+
child = _child("base ROTATE[0.125] AND_ALIGN[4,8] style :0.5")
|
| 191 |
+
self.assertEqual(child.conciliation, CS.ALIGNMENT_BLEND_CUSTOM)
|
| 192 |
+
self.assertEqual(child.conciliation_params, {'d': 4, 's': 8})
|
| 193 |
+
self.assertIsNotNone(child.local_transform)
|
| 194 |
+
|
| 195 |
+
def test_salt_k_and_affine_weight_all_correct(self):
|
| 196 |
+
child = _child("base AND_SALT[3] ROTATE[0.25] concept :0.6")
|
| 197 |
+
self.assertAlmostEqual(child.conciliation_params['k'], 3.0)
|
| 198 |
+
self.assertAlmostEqual(child.weight, 0.6)
|
| 199 |
+
self.assertIsNotNone(child.local_transform)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == '__main__':
|
| 203 |
+
unittest.main()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class TestSpacesInBracketParams(unittest.TestCase):
|
| 207 |
+
"""Spaces around comma inside bracket params must be tolerated."""
|
| 208 |
+
|
| 209 |
+
def test_align_spaces_around_comma(self):
|
| 210 |
+
child = _child("base AND_ALIGN[4, 8] style :0.5")
|
| 211 |
+
self.assertEqual(child.conciliation_params, {'d': 4, 's': 8})
|
| 212 |
+
|
| 213 |
+
def test_mask_align_spaces_around_comma(self):
|
| 214 |
+
child = _child("base AND_MASK_ALIGN[6, 12] style :0.5")
|
| 215 |
+
self.assertEqual(child.conciliation_params, {'d': 6, 's': 12})
|
| 216 |
+
|
| 217 |
+
def test_salt_inner_spaces(self):
|
| 218 |
+
child = _child("base AND_SALT[ 5 ] concept :1.0")
|
| 219 |
+
self.assertAlmostEqual(child.conciliation_params.get('k', 0), 5.0)
|
| 220 |
+
|
| 221 |
+
def test_align_many_spaces(self):
|
| 222 |
+
child = _child("base AND_ALIGN[ 4 , 8 ] style :0.5")
|
| 223 |
+
self.assertEqual(child.conciliation_params, {'d': 4, 's': 8})
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class TestInvalidParamWarnings(unittest.TestCase):
|
| 227 |
+
"""
|
| 228 |
+
Invalid bracket params must:
|
| 229 |
+
1. NOT crash.
|
| 230 |
+
2. Return empty conciliation_params (keyword uses its default).
|
| 231 |
+
3. NOT leak bracket content into prompt text.
|
| 232 |
+
4. Emit a warning to stderr when verbose=True.
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
def _parse(self, s):
|
| 236 |
+
return p.parse_root(s)
|
| 237 |
+
|
| 238 |
+
def _last_child(self, s):
|
| 239 |
+
return self._parse(s).children[-1]
|
| 240 |
+
|
| 241 |
+
# --- k invalid values ---
|
| 242 |
+
|
| 243 |
+
def test_k_zero_no_params(self):
|
| 244 |
+
child = self._last_child("base AND_SALT[0] concept :1.0")
|
| 245 |
+
self.assertEqual(child.conciliation_params, {})
|
| 246 |
+
|
| 247 |
+
def test_k_negative_no_params(self):
|
| 248 |
+
child = self._last_child("base AND_SALT[-3] concept :1.0")
|
| 249 |
+
self.assertEqual(child.conciliation_params, {})
|
| 250 |
+
|
| 251 |
+
def test_k_too_many_parts_no_params(self):
|
| 252 |
+
child = self._last_child("base AND_SALT[3,5] concept :1.0")
|
| 253 |
+
self.assertEqual(child.conciliation_params, {})
|
| 254 |
+
|
| 255 |
+
# --- D,S invalid values ---
|
| 256 |
+
|
| 257 |
+
def test_ds_equal_no_params(self):
|
| 258 |
+
child = self._last_child("base AND_ALIGN[4,4] style :0.5")
|
| 259 |
+
self.assertEqual(child.conciliation_params, {})
|
| 260 |
+
|
| 261 |
+
def test_d_too_low_no_params(self):
|
| 262 |
+
child = self._last_child("base AND_ALIGN[1,8] style :0.5")
|
| 263 |
+
self.assertEqual(child.conciliation_params, {})
|
| 264 |
+
|
| 265 |
+
def test_s_too_high_no_params(self):
|
| 266 |
+
child = self._last_child("base AND_ALIGN[4,33] style :0.5")
|
| 267 |
+
self.assertEqual(child.conciliation_params, {})
|
| 268 |
+
|
| 269 |
+
def test_wrong_part_count_no_params(self):
|
| 270 |
+
child = self._last_child("base AND_ALIGN[4] style :0.5")
|
| 271 |
+
self.assertEqual(child.conciliation_params, {})
|
| 272 |
+
|
| 273 |
+
# --- No crash on bad input ---
|
| 274 |
+
|
| 275 |
+
def test_no_crash_on_empty_brackets(self):
|
| 276 |
+
# AND_SALT[] — inner is empty, should not crash
|
| 277 |
+
try:
|
| 278 |
+
root = self._parse("base AND_SALT[] concept :1.0")
|
| 279 |
+
# brackets may be treated as composite or ignored; just no crash
|
| 280 |
+
except Exception as exc:
|
| 281 |
+
self.fail(f"Crashed on AND_SALT[]: {exc}")
|
| 282 |
+
|
| 283 |
+
# --- Warning emitted to stderr ---
|
| 284 |
+
|
| 285 |
+
def test_warning_emitted_for_k_zero(self):
|
| 286 |
+
import sys, io
|
| 287 |
+
from lib_neutral_prompt import global_state as gs
|
| 288 |
+
gs.verbose = True
|
| 289 |
+
buf = io.StringIO()
|
| 290 |
+
with unittest.mock.patch('sys.stderr', buf):
|
| 291 |
+
self._parse("base AND_SALT[0] concept :1.0")
|
| 292 |
+
gs.verbose = False
|
| 293 |
+
self.assertIn('AND_SALT', buf.getvalue(),
|
| 294 |
+
"Expected warning mentioning AND_SALT in stderr")
|
| 295 |
+
self.assertIn('k must be', buf.getvalue().lower().replace('k must be', 'k must be'),
|
| 296 |
+
"Expected warning explaining the reason")
|
| 297 |
+
|
| 298 |
+
def test_warning_emitted_for_ds_equal(self):
|
| 299 |
+
import io, unittest.mock
|
| 300 |
+
from lib_neutral_prompt import global_state as gs
|
| 301 |
+
gs.verbose = True
|
| 302 |
+
buf = io.StringIO()
|
| 303 |
+
with unittest.mock.patch('sys.stderr', buf):
|
| 304 |
+
self._parse("base AND_ALIGN[4,4] style :0.5")
|
| 305 |
+
gs.verbose = False
|
| 306 |
+
self.assertIn('AND_ALIGN', buf.getvalue())
|
| 307 |
+
self.assertIn('differ', buf.getvalue())
|
| 308 |
+
|
| 309 |
+
def test_no_warning_when_verbose_false(self):
|
| 310 |
+
import io, unittest.mock
|
| 311 |
+
from lib_neutral_prompt import global_state as gs
|
| 312 |
+
gs.verbose = False
|
| 313 |
+
buf = io.StringIO()
|
| 314 |
+
with unittest.mock.patch('sys.stderr', buf):
|
| 315 |
+
self._parse("base AND_SALT[0] concept :1.0")
|
| 316 |
+
self.assertEqual(buf.getvalue(), '',
|
| 317 |
+
"No warning expected when verbose=False")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ---------------------------------------------------------------------------
|
| 321 |
+
# AND_TOPK[threshold] — parametric threshold
|
| 322 |
+
# ---------------------------------------------------------------------------
|
| 323 |
+
|
| 324 |
+
class TestTopkBracketSyntax(unittest.TestCase):
|
| 325 |
+
|
| 326 |
+
def test_topk_default_no_brackets(self):
|
| 327 |
+
root = p.parse_root('base AND_TOPK concept :0.8')
|
| 328 |
+
child = root.children[-1]
|
| 329 |
+
self.assertEqual(child.conciliation_params, {})
|
| 330 |
+
|
| 331 |
+
def test_topk_explicit_threshold(self):
|
| 332 |
+
root = p.parse_root('base AND_TOPK[0.1] concept :0.8')
|
| 333 |
+
child = root.children[-1]
|
| 334 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 0.1)
|
| 335 |
+
|
| 336 |
+
def test_topk_small_threshold(self):
|
| 337 |
+
root = p.parse_root('base AND_TOPK[0.01] concept :0.8')
|
| 338 |
+
child = root.children[-1]
|
| 339 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 0.01)
|
| 340 |
+
|
| 341 |
+
def test_topk_max_threshold(self):
|
| 342 |
+
root = p.parse_root('base AND_TOPK[1.0] concept :0.8')
|
| 343 |
+
child = root.children[-1]
|
| 344 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 1.0)
|
| 345 |
+
|
| 346 |
+
def test_topk_threshold_zero_rejected(self):
|
| 347 |
+
# threshold=0 is invalid (nothing would be selected)
|
| 348 |
+
root = p.parse_root('base AND_TOPK[0] concept :0.8')
|
| 349 |
+
child = root.children[-1]
|
| 350 |
+
self.assertEqual(child.conciliation_params, {})
|
| 351 |
+
|
| 352 |
+
def test_topk_threshold_over_one_rejected(self):
|
| 353 |
+
root = p.parse_root('base AND_TOPK[1.1] concept :0.8')
|
| 354 |
+
child = root.children[-1]
|
| 355 |
+
self.assertEqual(child.conciliation_params, {})
|
| 356 |
+
|
| 357 |
+
def test_topk_too_many_params_rejected(self):
|
| 358 |
+
root = p.parse_root('base AND_TOPK[0.1,0.2] concept :0.8')
|
| 359 |
+
child = root.children[-1]
|
| 360 |
+
self.assertEqual(child.conciliation_params, {})
|
| 361 |
+
|
| 362 |
+
def test_topk_spaces_in_brackets(self):
|
| 363 |
+
root = p.parse_root('base AND_TOPK[ 0.1 ] concept :0.8')
|
| 364 |
+
child = root.children[-1]
|
| 365 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 0.1)
|
| 366 |
+
|
| 367 |
+
def test_topk_does_not_consume_non_numeric_brackets(self):
|
| 368 |
+
# Composite brackets must NOT be eaten
|
| 369 |
+
root = p.parse_root('base AND_TOPK [concept AND style] :0.8')
|
| 370 |
+
# The [...] is a composite group, not params — topk should have no params
|
| 371 |
+
child = root.children[-1]
|
| 372 |
+
self.assertEqual(child.conciliation_params, {})
|
| 373 |
+
|
| 374 |
+
def test_topk_with_trailing_affine(self):
|
| 375 |
+
root = p.parse_root('base AND_TOPK[0.05] ROTATE[0.125] concept :0.8')
|
| 376 |
+
child = root.children[-1]
|
| 377 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 0.05)
|
| 378 |
+
self.assertIsNotNone(child.local_transform)
|
| 379 |
+
|
| 380 |
+
def test_topk_with_leading_affine(self):
|
| 381 |
+
root = p.parse_root('base ROTATE[0.125] AND_TOPK[0.1] concept :0.8')
|
| 382 |
+
child = root.children[-1]
|
| 383 |
+
self.assertAlmostEqual(child.conciliation_params.get('threshold'), 0.1)
|
| 384 |
+
self.assertIsNotNone(child.local_transform)
|
| 385 |
+
|
| 386 |
+
def test_topk_backward_compat_default(self):
|
| 387 |
+
# Without brackets the default 0.05 must come from cfg_denoiser, not parser
|
| 388 |
+
root = p.parse_root('base AND_TOPK concept :1.0')
|
| 389 |
+
child = root.children[-1]
|
| 390 |
+
self.assertEqual(child.conciliation_params, {})
|
| 391 |
+
|
| 392 |
+
def test_topk_coexists_with_salt(self):
|
| 393 |
+
root = p.parse_root('base AND_SALT[3] texture :0.8 AND_TOPK[0.1] style :0.6')
|
| 394 |
+
self.assertEqual(len(root.children), 3)
|
| 395 |
+
salt = root.children[1]
|
| 396 |
+
topk = root.children[2]
|
| 397 |
+
self.assertAlmostEqual(salt.conciliation_params.get('k'), 3.0)
|
| 398 |
+
self.assertAlmostEqual(topk.conciliation_params.get('threshold'), 0.1)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# ---------------------------------------------------------------------------
|
| 402 |
+
# Affine builder helper (ui-level function, no gradio import needed)
|
| 403 |
+
# ---------------------------------------------------------------------------
|
| 404 |
+
|
| 405 |
+
class TestAffineBuilderHelper(unittest.TestCase):
|
| 406 |
+
"""Tests for _build_affine_snippet via neutral_prompt_parser — no gradio."""
|
| 407 |
+
|
| 408 |
+
def _snippet(self, transform, p1, p2):
|
| 409 |
+
# Inline logic mirrors _build_affine_snippet in ui.py
|
| 410 |
+
t = transform.strip()
|
| 411 |
+
import math
|
| 412 |
+
if t == 'FLIP_H': return 'SCALE[-1,1]'
|
| 413 |
+
if t == 'FLIP_V': return 'SCALE[1,-1]'
|
| 414 |
+
if t == 'ROTATE': return f'ROTATE[{p1}]'
|
| 415 |
+
if t == 'SCALE':
|
| 416 |
+
return f'SCALE[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SCALE[{p1}]'
|
| 417 |
+
if t == 'SLIDE': return f'SLIDE[{p1},{p2}]'
|
| 418 |
+
if t == 'SHEAR':
|
| 419 |
+
return f'SHEAR[{p1},{p2}]' if abs(p1 - p2) > 1e-9 else f'SHEAR[{p1}]'
|
| 420 |
+
return ''
|
| 421 |
+
|
| 422 |
+
def test_flip_h(self):
|
| 423 |
+
self.assertEqual(self._snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')
|
| 424 |
+
|
| 425 |
+
def test_flip_v(self):
|
| 426 |
+
self.assertEqual(self._snippet('FLIP_V', 0, 0), 'SCALE[1,-1]')
|
| 427 |
+
|
| 428 |
+
def test_rotate(self):
|
| 429 |
+
self.assertEqual(self._snippet('ROTATE', 0.125, 1.0), 'ROTATE[0.125]')
|
| 430 |
+
|
| 431 |
+
def test_scale_uniform(self):
|
| 432 |
+
self.assertEqual(self._snippet('SCALE', 2.0, 2.0), 'SCALE[2.0]')
|
| 433 |
+
|
| 434 |
+
def test_scale_anisotropic(self):
|
| 435 |
+
self.assertEqual(self._snippet('SCALE', 1.5, 1.0), 'SCALE[1.5,1.0]')
|
| 436 |
+
|
| 437 |
+
def test_slide(self):
|
| 438 |
+
self.assertEqual(self._snippet('SLIDE', 0.1, 0.2), 'SLIDE[0.1,0.2]')
|
| 439 |
+
|
| 440 |
+
def test_shear_uniform(self):
|
| 441 |
+
self.assertEqual(self._snippet('SHEAR', 0.05, 0.05), 'SHEAR[0.05]')
|
| 442 |
+
|
| 443 |
+
def test_snippet_parseable_by_parser(self):
|
| 444 |
+
"""Each generated snippet must survive parse_root without crashing."""
|
| 445 |
+
cases = [
|
| 446 |
+
'ROTATE[0.125]', 'SCALE[-1,1]', 'SCALE[1,-1]',
|
| 447 |
+
'SCALE[0.5]', 'SCALE[1.5,1.0]',
|
| 448 |
+
'SLIDE[0.1,0.2]', 'SHEAR[0.05]',
|
| 449 |
+
]
|
| 450 |
+
for snippet in cases:
|
| 451 |
+
with self.subTest(snippet=snippet):
|
| 452 |
+
try:
|
| 453 |
+
p.parse_root(f'base {snippet} AND_PERP style :0.8')
|
| 454 |
+
except Exception as exc:
|
| 455 |
+
self.fail(f'snippet "{snippet}" crashed parse_root: {exc}')
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# ---------------------------------------------------------------------------
|
| 459 |
+
# Debug/explain helper (_render_parse_tree logic, without gradio)
|
| 460 |
+
# ---------------------------------------------------------------------------
|
| 461 |
+
|
| 462 |
+
class TestDebugRenderLogic(unittest.TestCase):
|
| 463 |
+
"""Test the parse-tree renderer used by the debug panel."""
|
| 464 |
+
|
| 465 |
+
def _render(self, prompt_str):
|
| 466 |
+
# Import the function directly from ui module logic — replicate inline
|
| 467 |
+
# so we don't need gradio at test time.
|
| 468 |
+
if not prompt_str.strip():
|
| 469 |
+
return '(empty prompt)'
|
| 470 |
+
try:
|
| 471 |
+
root = p.parse_root(prompt_str)
|
| 472 |
+
except Exception as exc:
|
| 473 |
+
return f'Parse error: {exc}'
|
| 474 |
+
|
| 475 |
+
lines = []
|
| 476 |
+
|
| 477 |
+
def _node(node, depth=0):
|
| 478 |
+
indent = ' ' * depth
|
| 479 |
+
strat = node.conciliation.name if node.conciliation else 'BASE'
|
| 480 |
+
w = f'weight={node.weight:.3f}'
|
| 481 |
+
params = ''
|
| 482 |
+
if node.conciliation_params:
|
| 483 |
+
params = ' params=' + ', '.join(f'{k}={v}' for k, v in node.conciliation_params.items())
|
| 484 |
+
aff = ' [affine]' if node.local_transform is not None else ''
|
| 485 |
+
if hasattr(node, 'prompt'):
|
| 486 |
+
text = repr(node.prompt[:60])
|
| 487 |
+
lines.append(f'{indent}[{strat}] {w}{params}{aff} text={text}')
|
| 488 |
+
else:
|
| 489 |
+
lines.append(f'{indent}[{strat}] {w}{params}{aff} ({len(node.children)} children)')
|
| 490 |
+
for child in node.children:
|
| 491 |
+
_node(child, depth + 1)
|
| 492 |
+
|
| 493 |
+
_node(root)
|
| 494 |
+
return '\n'.join(lines)
|
| 495 |
+
|
| 496 |
+
def test_empty_prompt(self):
|
| 497 |
+
out = self._render('')
|
| 498 |
+
self.assertEqual(out, '(empty prompt)')
|
| 499 |
+
|
| 500 |
+
def test_base_only(self):
|
| 501 |
+
out = self._render('a dog')
|
| 502 |
+
self.assertIn('BASE', out)
|
| 503 |
+
self.assertIn('a dog', out)
|
| 504 |
+
|
| 505 |
+
def test_with_salt(self):
|
| 506 |
+
out = self._render('base AND_SALT[3] texture :0.8')
|
| 507 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 508 |
+
self.assertIn('k=3', out)
|
| 509 |
+
|
| 510 |
+
def test_with_topk(self):
|
| 511 |
+
out = self._render('base AND_TOPK[0.1] style :0.6')
|
| 512 |
+
self.assertIn('SEMANTIC_GUIDANCE', out)
|
| 513 |
+
self.assertIn('threshold=0.1', out)
|
| 514 |
+
|
| 515 |
+
def test_with_align(self):
|
| 516 |
+
out = self._render('base AND_ALIGN[4,8] detail :0.7')
|
| 517 |
+
self.assertIn('ALIGNMENT_BLEND_CUSTOM', out)
|
| 518 |
+
self.assertIn('d=4', out)
|
| 519 |
+
self.assertIn('s=8', out)
|
| 520 |
+
|
| 521 |
+
def test_with_affine(self):
|
| 522 |
+
out = self._render('base ROTATE[0.125] AND_PERP vivid :0.8')
|
| 523 |
+
self.assertIn('[affine]', out)
|
| 524 |
+
|
| 525 |
+
def test_multiple_children(self):
|
| 526 |
+
out = self._render('base AND_SALT[5] texture :0.8 AND_TOPK[0.05] style :0.6')
|
| 527 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 528 |
+
self.assertIn('SEMANTIC_GUIDANCE', out)
|
| 529 |
+
|
| 530 |
+
def test_parse_error_graceful(self):
|
| 531 |
+
# Malformed / deeply weird input must not raise — parser should handle it
|
| 532 |
+
try:
|
| 533 |
+
out = self._render('[[[[[')
|
| 534 |
+
except Exception as exc:
|
| 535 |
+
self.fail(f'_render raised on bad input: {exc}')
|
neutral_prompt_patcheds/test/perp_parser/test_runtime_behavior.py
ADDED
|
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Runtime behavior tests.
|
| 3 |
+
|
| 4 |
+
Requires PyTorch. If torch is absent the whole module is skipped — same
|
| 5 |
+
pattern as test_affine_pipeline.py so the CI exit code stays clean.
|
| 6 |
+
|
| 7 |
+
Covers:
|
| 8 |
+
- _get_salience: k param changes pixel coverage (monotonically)
|
| 9 |
+
- blob morphology: thickify-first pipeline vs erode-first pipeline
|
| 10 |
+
- no-base guard: fallback when all children are conciliatory
|
| 11 |
+
- API override lifecycle: clear → set → get → clear
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import pathlib
|
| 15 |
+
import sys
|
| 16 |
+
import types
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Torch availability check — skip entire module if absent
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
try:
|
| 23 |
+
import importlib.util as _ilu
|
| 24 |
+
_spec = _ilu.find_spec('torch')
|
| 25 |
+
_TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
|
| 26 |
+
except (ValueError, ModuleNotFoundError):
|
| 27 |
+
_TORCH_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# A1111 stubs (must be installed before importing lib modules)
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _stub(name):
|
| 36 |
+
m = types.ModuleType(name)
|
| 37 |
+
sys.modules.setdefault(name, m)
|
| 38 |
+
return sys.modules[name]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
for _n in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
|
| 42 |
+
'modules.shared', 'modules.prompt_parser', 'gradio'):
|
| 43 |
+
_stub(_n)
|
| 44 |
+
|
| 45 |
+
import modules.script_callbacks as _sc
|
| 46 |
+
for _attr in ('on_cfg_denoiser', 'on_script_unloaded', 'CFGDenoiserParams'):
|
| 47 |
+
if not hasattr(_sc, _attr):
|
| 48 |
+
setattr(_sc, _attr, lambda *a, **kw: None)
|
| 49 |
+
|
| 50 |
+
import modules.shared as _sh
|
| 51 |
+
if not hasattr(_sh, 'state'):
|
| 52 |
+
_sh.state = types.SimpleNamespace(sampling_step=1)
|
| 53 |
+
|
| 54 |
+
import modules.sd_samplers as _ss
|
| 55 |
+
if not hasattr(_ss, 'create_sampler'):
|
| 56 |
+
_ss.create_sampler = lambda *a, **kw: None
|
| 57 |
+
|
| 58 |
+
import lib_neutral_prompt.hijacker as _h_mod
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class _StubHijacker:
|
| 62 |
+
@classmethod
|
| 63 |
+
def install_or_get(cls, **kwargs):
|
| 64 |
+
return cls()
|
| 65 |
+
def hijack(self, name):
|
| 66 |
+
return lambda fn: fn
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
_h_mod.ModuleHijacker = _StubHijacker
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 73 |
+
class TestSalienceKCoverage(unittest.TestCase):
|
| 74 |
+
"""Higher k → strictly fewer pixels where child salience beats parent."""
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def setUpClass(cls):
|
| 78 |
+
import torch
|
| 79 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 80 |
+
cls.torch = torch
|
| 81 |
+
cls.hj = hj
|
| 82 |
+
|
| 83 |
+
def _winning_fraction(self, k_child, C=4, H=64, W=64):
|
| 84 |
+
torch, hj = self.torch, self.hj
|
| 85 |
+
torch.manual_seed(7)
|
| 86 |
+
parent = torch.randn(C, H, W)
|
| 87 |
+
child = torch.randn(C, H, W)
|
| 88 |
+
sal_p = hj._get_salience(parent, k=1.0)
|
| 89 |
+
sal_c = hj._get_salience(child, k=k_child)
|
| 90 |
+
mask = torch.argmax(torch.stack([sal_p, sal_c]), dim=0)
|
| 91 |
+
return (mask == 1).float().mean().item()
|
| 92 |
+
|
| 93 |
+
def test_k1_broad(self):
|
| 94 |
+
self.assertGreater(self._winning_fraction(1.0), 0.40)
|
| 95 |
+
|
| 96 |
+
def test_k5_moderate(self):
|
| 97 |
+
f = self._winning_fraction(5.0)
|
| 98 |
+
self.assertLess(f, 0.15)
|
| 99 |
+
self.assertGreater(f, 0.001)
|
| 100 |
+
|
| 101 |
+
def test_k20_surgical(self):
|
| 102 |
+
self.assertLess(self._winning_fraction(20.0), 0.01)
|
| 103 |
+
|
| 104 |
+
def test_monotone_k(self):
|
| 105 |
+
f1 = self._winning_fraction(1.0)
|
| 106 |
+
f5 = self._winning_fraction(5.0)
|
| 107 |
+
f20 = self._winning_fraction(20.0)
|
| 108 |
+
self.assertGreater(f1, f5, "k=1 should cover more than k=5")
|
| 109 |
+
self.assertGreater(f5, f20, "k=5 should cover more than k=20")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 113 |
+
class TestBlobMorphology(unittest.TestCase):
|
| 114 |
+
"""Thickify-first pipeline keeps blob alive; erode-first destroys it."""
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def setUpClass(cls):
|
| 118 |
+
import torch
|
| 119 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 120 |
+
cls.torch = torch
|
| 121 |
+
cls.hj = hj
|
| 122 |
+
|
| 123 |
+
def _seed(self, C=4, H=8, W=8):
|
| 124 |
+
board = self.torch.zeros(C, H, W)
|
| 125 |
+
board[:, H // 2, W // 2] = 1.0
|
| 126 |
+
return board
|
| 127 |
+
|
| 128 |
+
def _run(self, board, thickify, erode):
|
| 129 |
+
hj = self.hj
|
| 130 |
+
for _ in range(thickify):
|
| 131 |
+
board = hj._life_step(board, hj._thickify_rule)
|
| 132 |
+
for _ in range(erode):
|
| 133 |
+
board = hj._life_step(board, hj._erode_rule)
|
| 134 |
+
return board
|
| 135 |
+
|
| 136 |
+
def test_erode_first_kills_isolated_seed(self):
|
| 137 |
+
"""Original ultimate ordering: 6 erodes → seed dies immediately."""
|
| 138 |
+
result = self._run(self._seed(), thickify=0, erode=6)
|
| 139 |
+
self.assertEqual(result.sum().item(), 0,
|
| 140 |
+
"Erode-first should destroy an isolated pixel seed")
|
| 141 |
+
|
| 142 |
+
def test_thickify_first_survives(self):
|
| 143 |
+
"""Patched ordering: 3 thickify then 1 erode → seed survives."""
|
| 144 |
+
result = self._run(self._seed(), thickify=3, erode=1)
|
| 145 |
+
self.assertGreater(result.sum().item(), 0,
|
| 146 |
+
"Thickify-first should keep blob alive")
|
| 147 |
+
|
| 148 |
+
def test_thickify_grows_region(self):
|
| 149 |
+
original_mass = self._seed().sum().item()
|
| 150 |
+
grown = self._run(self._seed(), thickify=3, erode=0)
|
| 151 |
+
self.assertGreater(grown.sum().item(), original_mass)
|
| 152 |
+
|
| 153 |
+
def test_erode_after_thickify_reduces_mass(self):
|
| 154 |
+
grown = self._run(self._seed(), thickify=3, erode=0).sum().item()
|
| 155 |
+
pruned = self._run(self._seed(), thickify=3, erode=1).sum().item()
|
| 156 |
+
self.assertLess(pruned, grown,
|
| 157 |
+
"Erosion after thickify should trim edges")
|
| 158 |
+
self.assertGreater(pruned, 0,
|
| 159 |
+
"But core should still survive")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 163 |
+
class TestNoBaseGuard(unittest.TestCase):
|
| 164 |
+
"""combine_denoised_hijack falls back when all children are conciliatory."""
|
| 165 |
+
|
| 166 |
+
@classmethod
|
| 167 |
+
def setUpClass(cls):
|
| 168 |
+
import torch
|
| 169 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 170 |
+
import lib_neutral_prompt.global_state as gs
|
| 171 |
+
import lib_neutral_prompt.neutral_prompt_parser as npp
|
| 172 |
+
cls.torch = torch
|
| 173 |
+
cls.hj = hj
|
| 174 |
+
cls.gs = gs
|
| 175 |
+
cls.npp = npp
|
| 176 |
+
|
| 177 |
+
def setUp(self):
|
| 178 |
+
self.gs.is_enabled = True
|
| 179 |
+
self.gs.verbose = False
|
| 180 |
+
self.gs.prompt_exprs = []
|
| 181 |
+
self.gs.batch_cond_indices = []
|
| 182 |
+
|
| 183 |
+
def tearDown(self):
|
| 184 |
+
self.gs.is_enabled = False
|
| 185 |
+
self.gs.prompt_exprs = []
|
| 186 |
+
self.gs.batch_cond_indices = []
|
| 187 |
+
|
| 188 |
+
def _counter(self):
|
| 189 |
+
calls = [0]
|
| 190 |
+
def fn(x_out, batch_cond_indices, text_uncond, cond_scale):
|
| 191 |
+
calls[0] += 1
|
| 192 |
+
return x_out[-1:]
|
| 193 |
+
return fn, calls
|
| 194 |
+
|
| 195 |
+
def test_fallback_when_no_base_child(self):
|
| 196 |
+
root = self.npp.parse_root("AND_SALT concept :1.0")
|
| 197 |
+
self.gs.prompt_exprs = [root]
|
| 198 |
+
self.gs.batch_cond_indices = [[(0, 1.0)]]
|
| 199 |
+
|
| 200 |
+
x_out = self.torch.randn(2, 4, 8, 8)
|
| 201 |
+
text_uncond = x_out[-1:]
|
| 202 |
+
fn, calls = self._counter()
|
| 203 |
+
|
| 204 |
+
self.hj.combine_denoised_hijack(
|
| 205 |
+
x_out, self.gs.batch_cond_indices, text_uncond, 7.0, fn
|
| 206 |
+
)
|
| 207 |
+
self.assertEqual(calls[0], 1,
|
| 208 |
+
"Guard must fall back when there is no base AND child")
|
| 209 |
+
|
| 210 |
+
def test_no_fallback_when_base_exists(self):
|
| 211 |
+
"""
|
| 212 |
+
When a base AND-child exists, the guard must NOT fire.
|
| 213 |
+
|
| 214 |
+
Architecture note: combine_denoised_hijack always calls original_function
|
| 215 |
+
at least once via _get_webui_denoised() on the normal path — that call
|
| 216 |
+
receives a *sliced* x_out (not the raw full tensor).
|
| 217 |
+
The guard fallback, by contrast, passes the *original* x_out unchanged.
|
| 218 |
+
|
| 219 |
+
We distinguish the two by capturing the first argument the counter saw:
|
| 220 |
+
- guard path → x_out.shape[0] == n_conds + n_uncond (full tensor)
|
| 221 |
+
- normal path → x_out.shape[0] <= n_conds (sliced tensor)
|
| 222 |
+
"""
|
| 223 |
+
root = self.npp.parse_root("base concept AND_SALT tweak :1.0")
|
| 224 |
+
self.gs.prompt_exprs = [root]
|
| 225 |
+
# 2 cond slots (base + AND_SALT) + 1 implicit uncond
|
| 226 |
+
self.gs.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 227 |
+
|
| 228 |
+
n_cond = 2
|
| 229 |
+
n_uncond = 1
|
| 230 |
+
x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
|
| 231 |
+
text_uncond = x_out[-n_uncond:]
|
| 232 |
+
|
| 233 |
+
seen_shapes = []
|
| 234 |
+
|
| 235 |
+
def capturing_fn(x, batch_cond_indices, text_uncond, cond_scale):
|
| 236 |
+
seen_shapes.append(x.shape[0])
|
| 237 |
+
return x[-1:]
|
| 238 |
+
|
| 239 |
+
try:
|
| 240 |
+
self.hj.combine_denoised_hijack(
|
| 241 |
+
x_out, self.gs.batch_cond_indices, text_uncond, 7.0, capturing_fn
|
| 242 |
+
)
|
| 243 |
+
except Exception:
|
| 244 |
+
pass # downstream tensor errors are fine; we only check entry shape
|
| 245 |
+
|
| 246 |
+
self.assertTrue(len(seen_shapes) > 0,
|
| 247 |
+
"original_function was never called at all")
|
| 248 |
+
# On the normal path the first call is from _get_webui_denoised which
|
| 249 |
+
# passes a sliced tensor (only base-AND children + uncond).
|
| 250 |
+
# On the guard path it would pass the full x_out (n_cond + n_uncond rows).
|
| 251 |
+
full_tensor_size = n_cond + n_uncond
|
| 252 |
+
self.assertLess(seen_shapes[0], full_tensor_size,
|
| 253 |
+
"Guard fired and passed full x_out — it should NOT have")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 257 |
+
class TestCFGRescaleAPILifecycle(unittest.TestCase):
|
| 258 |
+
|
| 259 |
+
@classmethod
|
| 260 |
+
def setUpClass(cls):
|
| 261 |
+
import lib_neutral_prompt.global_state as gs
|
| 262 |
+
cls.gs = gs
|
| 263 |
+
|
| 264 |
+
def setUp(self):
|
| 265 |
+
self.gs.cfg_rescale = 0.0
|
| 266 |
+
self.gs.cfg_rescale_override = None
|
| 267 |
+
self.gs.CFGRescaleFactorSingleton.clear()
|
| 268 |
+
|
| 269 |
+
def test_override_applied_and_cleared(self):
|
| 270 |
+
self.gs.cfg_rescale_override = 0.7
|
| 271 |
+
self.gs.apply_and_clear_cfg_rescale_override()
|
| 272 |
+
self.assertAlmostEqual(self.gs.cfg_rescale, 0.7)
|
| 273 |
+
self.assertIsNone(self.gs.cfg_rescale_override)
|
| 274 |
+
|
| 275 |
+
def test_singleton_clear_returns_none(self):
|
| 276 |
+
self.gs.CFGRescaleFactorSingleton.clear()
|
| 277 |
+
self.assertIsNone(self.gs.CFGRescaleFactorSingleton.get())
|
| 278 |
+
|
| 279 |
+
def test_singleton_set_and_get(self):
|
| 280 |
+
self.gs.CFGRescaleFactorSingleton.set(0.85)
|
| 281 |
+
self.assertAlmostEqual(self.gs.CFGRescaleFactorSingleton.get(), 0.85)
|
| 282 |
+
|
| 283 |
+
def test_singleton_cleared_between_steps(self):
|
| 284 |
+
self.gs.CFGRescaleFactorSingleton.set(0.85)
|
| 285 |
+
self.gs.CFGRescaleFactorSingleton.clear()
|
| 286 |
+
self.assertIsNone(self.gs.CFGRescaleFactorSingleton.get())
|
| 287 |
+
|
| 288 |
+
def test_no_override_does_not_change_rescale(self):
|
| 289 |
+
"""apply_and_clear with no pending override must leave cfg_rescale alone."""
|
| 290 |
+
self.gs.cfg_rescale = 0.3
|
| 291 |
+
self.gs.apply_and_clear_cfg_rescale_override()
|
| 292 |
+
self.assertAlmostEqual(self.gs.cfg_rescale, 0.3)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == '__main__':
|
| 296 |
+
unittest.main()
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 300 |
+
class TestProtectionModes(unittest.TestCase):
|
| 301 |
+
"""Off / Auto / Strict protection modes."""
|
| 302 |
+
|
| 303 |
+
@classmethod
|
| 304 |
+
def setUpClass(cls):
|
| 305 |
+
import torch
|
| 306 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 307 |
+
import lib_neutral_prompt.global_state as gs
|
| 308 |
+
import lib_neutral_prompt.neutral_prompt_parser as npp
|
| 309 |
+
cls.torch = torch
|
| 310 |
+
cls.hj = hj
|
| 311 |
+
cls.gs = gs
|
| 312 |
+
cls.npp = npp
|
| 313 |
+
|
| 314 |
+
def setUp(self):
|
| 315 |
+
self.gs.is_enabled = True
|
| 316 |
+
self.gs.verbose = False
|
| 317 |
+
self.gs.protection_mode = 'auto'
|
| 318 |
+
self.gs.strict_threshold = 0.1
|
| 319 |
+
self.gs.prompt_exprs = []
|
| 320 |
+
self.gs.batch_cond_indices = []
|
| 321 |
+
|
| 322 |
+
def tearDown(self):
|
| 323 |
+
self.gs.is_enabled = False
|
| 324 |
+
self.gs.protection_mode = 'auto'
|
| 325 |
+
self.gs.prompt_exprs = []
|
| 326 |
+
self.gs.batch_cond_indices = []
|
| 327 |
+
|
| 328 |
+
def _run(self, prompt_str, n_cond, protection_mode, strict_threshold=0.1):
|
| 329 |
+
"""Return the shape[0] of x_out passed to original_fn on first call."""
|
| 330 |
+
root = self.npp.parse_root(prompt_str)
|
| 331 |
+
self.gs.prompt_exprs = [root]
|
| 332 |
+
self.gs.batch_cond_indices = [[(i, 1.0) for i in range(n_cond)]]
|
| 333 |
+
self.gs.protection_mode = protection_mode
|
| 334 |
+
self.gs.strict_threshold = strict_threshold
|
| 335 |
+
|
| 336 |
+
n_uncond = 1
|
| 337 |
+
x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
|
| 338 |
+
text_uncond = x_out[-n_uncond:]
|
| 339 |
+
|
| 340 |
+
seen = []
|
| 341 |
+
def fn(x, *a, **kw):
|
| 342 |
+
seen.append(x.shape[0])
|
| 343 |
+
return x[-1:]
|
| 344 |
+
|
| 345 |
+
try:
|
| 346 |
+
self.hj.combine_denoised_hijack(
|
| 347 |
+
x_out, self.gs.batch_cond_indices, text_uncond, 7.0, fn)
|
| 348 |
+
except Exception:
|
| 349 |
+
pass
|
| 350 |
+
|
| 351 |
+
return seen, n_cond + n_uncond
|
| 352 |
+
|
| 353 |
+
def _guard_fired(self, seen, full_size):
|
| 354 |
+
"""Guard fires when original_fn receives the full-size x_out."""
|
| 355 |
+
return len(seen) > 0 and seen[0] == full_size
|
| 356 |
+
|
| 357 |
+
# --- Auto mode ---
|
| 358 |
+
|
| 359 |
+
def test_auto_fires_on_no_base(self):
|
| 360 |
+
seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
|
| 361 |
+
protection_mode='auto')
|
| 362 |
+
self.assertTrue(self._guard_fired(seen, full),
|
| 363 |
+
"Auto guard must fire when no base AND-child exists")
|
| 364 |
+
|
| 365 |
+
def test_auto_does_not_fire_with_base(self):
|
| 366 |
+
seen, full = self._run("base concept AND_SALT tweak :1.0", n_cond=2,
|
| 367 |
+
protection_mode='auto')
|
| 368 |
+
self.assertFalse(self._guard_fired(seen, full),
|
| 369 |
+
"Auto guard must NOT fire when base AND-child exists")
|
| 370 |
+
|
| 371 |
+
# --- Off mode ---
|
| 372 |
+
|
| 373 |
+
def test_off_never_fires_even_without_base(self):
|
| 374 |
+
seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
|
| 375 |
+
protection_mode='off')
|
| 376 |
+
self.assertFalse(self._guard_fired(seen, full),
|
| 377 |
+
"Off mode must never fire the guard")
|
| 378 |
+
|
| 379 |
+
# --- Strict mode: structural check same as auto ---
|
| 380 |
+
|
| 381 |
+
def test_strict_fires_on_no_base(self):
|
| 382 |
+
seen, full = self._run("AND_SALT concept :1.0", n_cond=1,
|
| 383 |
+
protection_mode='strict')
|
| 384 |
+
self.assertTrue(self._guard_fired(seen, full),
|
| 385 |
+
"Strict guard must also fire on structural no-base case")
|
| 386 |
+
|
| 387 |
+
# --- global_state defaults ---
|
| 388 |
+
|
| 389 |
+
def test_default_protection_mode_is_auto(self):
|
| 390 |
+
import importlib, lib_neutral_prompt.global_state as gs_mod
|
| 391 |
+
# default is set at module level; check it's 'auto'
|
| 392 |
+
self.assertEqual(gs_mod.protection_mode, 'auto')
|
| 393 |
+
|
| 394 |
+
def test_strict_threshold_default(self):
|
| 395 |
+
import lib_neutral_prompt.global_state as gs_mod
|
| 396 |
+
self.assertAlmostEqual(gs_mod.strict_threshold, 0.1)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# ---------------------------------------------------------------------------
|
| 400 |
+
# Tests for global_state helpers (no torch needed)
|
| 401 |
+
# ---------------------------------------------------------------------------
|
| 402 |
+
|
| 403 |
+
class TestGlobalStateHelpers(unittest.TestCase):
|
| 404 |
+
"""normalize_protection_mode and clamp_strict_threshold."""
|
| 405 |
+
|
| 406 |
+
@classmethod
|
| 407 |
+
def setUpClass(cls):
|
| 408 |
+
import lib_neutral_prompt.global_state as gs
|
| 409 |
+
cls.gs = gs
|
| 410 |
+
|
| 411 |
+
def test_normalize_valid_modes(self):
|
| 412 |
+
for mode in ('off', 'auto', 'strict'):
|
| 413 |
+
self.assertEqual(self.gs.normalize_protection_mode(mode), mode)
|
| 414 |
+
|
| 415 |
+
def test_normalize_uppercase(self):
|
| 416 |
+
self.assertEqual(self.gs.normalize_protection_mode('AUTO'), 'auto')
|
| 417 |
+
self.assertEqual(self.gs.normalize_protection_mode('Strict'), 'strict')
|
| 418 |
+
|
| 419 |
+
def test_normalize_whitespace(self):
|
| 420 |
+
self.assertEqual(self.gs.normalize_protection_mode(' off '), 'off')
|
| 421 |
+
|
| 422 |
+
def test_normalize_unknown_falls_back_to_auto(self):
|
| 423 |
+
self.assertEqual(self.gs.normalize_protection_mode('strong'), 'auto')
|
| 424 |
+
self.assertEqual(self.gs.normalize_protection_mode(''), 'auto')
|
| 425 |
+
self.assertEqual(self.gs.normalize_protection_mode('1'), 'auto')
|
| 426 |
+
|
| 427 |
+
def test_clamp_valid(self):
|
| 428 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.10), 0.10)
|
| 429 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.50), 0.50)
|
| 430 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.01), 0.01)
|
| 431 |
+
|
| 432 |
+
def test_clamp_below_min(self):
|
| 433 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.0), 0.01)
|
| 434 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(-5), 0.01)
|
| 435 |
+
|
| 436 |
+
def test_clamp_above_max(self):
|
| 437 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(0.99), 0.50)
|
| 438 |
+
|
| 439 |
+
def test_clamp_bad_type_returns_default(self):
|
| 440 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold('bad'), 0.10)
|
| 441 |
+
self.assertAlmostEqual(self.gs.clamp_strict_threshold(None), 0.10)
|
| 442 |
+
|
| 443 |
+
def test_apply_infotext_protection(self):
|
| 444 |
+
import lib_neutral_prompt.global_state as gs
|
| 445 |
+
gs.protection_mode = 'auto'
|
| 446 |
+
gs.strict_threshold = 0.10
|
| 447 |
+
gs.apply_infotext({
|
| 448 |
+
gs.INFOTEXT_KEY_PROTECTION_MODE: 'strict',
|
| 449 |
+
gs.INFOTEXT_KEY_STRICT_THRESHOLD: '0.25',
|
| 450 |
+
})
|
| 451 |
+
self.assertEqual(gs.protection_mode, 'strict')
|
| 452 |
+
self.assertAlmostEqual(gs.strict_threshold, 0.25)
|
| 453 |
+
|
| 454 |
+
def test_apply_infotext_invalid_mode_falls_back(self):
|
| 455 |
+
import lib_neutral_prompt.global_state as gs
|
| 456 |
+
gs.protection_mode = 'off'
|
| 457 |
+
gs.apply_infotext({gs.INFOTEXT_KEY_PROTECTION_MODE: 'maximum'})
|
| 458 |
+
self.assertEqual(gs.protection_mode, 'auto')
|
| 459 |
+
|
| 460 |
+
def test_apply_infotext_partial_update(self):
|
| 461 |
+
"""apply_infotext must not touch fields absent from the dict."""
|
| 462 |
+
import lib_neutral_prompt.global_state as gs
|
| 463 |
+
gs.protection_mode = 'strict'
|
| 464 |
+
gs.strict_threshold = 0.20
|
| 465 |
+
gs.apply_infotext({}) # empty dict
|
| 466 |
+
self.assertEqual(gs.protection_mode, 'strict')
|
| 467 |
+
self.assertAlmostEqual(gs.strict_threshold, 0.20)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 471 |
+
class TestProtectionPolicyFunction(unittest.TestCase):
|
| 472 |
+
"""_should_fallback_for_protection across Off / Auto / Strict."""
|
| 473 |
+
|
| 474 |
+
@classmethod
|
| 475 |
+
def setUpClass(cls):
|
| 476 |
+
import torch
|
| 477 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 478 |
+
import lib_neutral_prompt.global_state as gs
|
| 479 |
+
import lib_neutral_prompt.neutral_prompt_parser as npp
|
| 480 |
+
cls.torch = torch
|
| 481 |
+
cls.hj = hj
|
| 482 |
+
cls.gs = gs
|
| 483 |
+
cls.npp = npp
|
| 484 |
+
|
| 485 |
+
def setUp(self):
|
| 486 |
+
self.gs.is_enabled = True
|
| 487 |
+
self.gs.protection_mode = 'auto'
|
| 488 |
+
self.gs.strict_threshold = 0.10
|
| 489 |
+
self.gs.prompt_exprs = []
|
| 490 |
+
self.gs.batch_cond_indices = []
|
| 491 |
+
|
| 492 |
+
def tearDown(self):
|
| 493 |
+
self.gs.is_enabled = False
|
| 494 |
+
self.gs.protection_mode = 'auto'
|
| 495 |
+
|
| 496 |
+
def _setup(self, prompt_str, n_cond):
|
| 497 |
+
root = self.npp.parse_root(prompt_str)
|
| 498 |
+
self.gs.prompt_exprs = [root]
|
| 499 |
+
self.gs.batch_cond_indices = [[(i, 1.0) for i in range(n_cond)]]
|
| 500 |
+
n_uncond = 1
|
| 501 |
+
x_out = self.torch.randn(n_cond + n_uncond, 4, 8, 8)
|
| 502 |
+
text_uncond = x_out[-n_uncond:]
|
| 503 |
+
return x_out, text_uncond
|
| 504 |
+
|
| 505 |
+
# --- Off: never fires regardless of structure ---
|
| 506 |
+
|
| 507 |
+
def test_off_no_base(self):
|
| 508 |
+
self.gs.protection_mode = 'off'
|
| 509 |
+
x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
|
| 510 |
+
should, reason = self.hj._should_fallback_for_protection(
|
| 511 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 512 |
+
self.assertFalse(should)
|
| 513 |
+
self.assertIsNone(reason)
|
| 514 |
+
|
| 515 |
+
def test_off_with_base(self):
|
| 516 |
+
self.gs.protection_mode = 'off'
|
| 517 |
+
x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
|
| 518 |
+
should, _ = self.hj._should_fallback_for_protection(
|
| 519 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 520 |
+
self.assertFalse(should)
|
| 521 |
+
|
| 522 |
+
# --- Auto: structural check only ---
|
| 523 |
+
|
| 524 |
+
def test_auto_fires_no_base(self):
|
| 525 |
+
self.gs.protection_mode = 'auto'
|
| 526 |
+
x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
|
| 527 |
+
should, reason = self.hj._should_fallback_for_protection(
|
| 528 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 529 |
+
self.assertTrue(should)
|
| 530 |
+
self.assertIn('no valid base', reason)
|
| 531 |
+
self.assertIn('auto', reason)
|
| 532 |
+
|
| 533 |
+
def test_auto_does_not_fire_with_base(self):
|
| 534 |
+
self.gs.protection_mode = 'auto'
|
| 535 |
+
x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
|
| 536 |
+
should, _ = self.hj._should_fallback_for_protection(
|
| 537 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 538 |
+
self.assertFalse(should)
|
| 539 |
+
|
| 540 |
+
# --- Strict: also fires on no-base (structural) ---
|
| 541 |
+
|
| 542 |
+
def test_strict_fires_on_no_base(self):
|
| 543 |
+
self.gs.protection_mode = 'strict'
|
| 544 |
+
x_out, text_uncond = self._setup("AND_SALT concept :1.0", n_cond=1)
|
| 545 |
+
should, reason = self.hj._should_fallback_for_protection(
|
| 546 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 547 |
+
self.assertTrue(should)
|
| 548 |
+
self.assertIn('strict', reason)
|
| 549 |
+
|
| 550 |
+
# --- Strict: numerical ratio check ---
|
| 551 |
+
|
| 552 |
+
def test_strict_reason_contains_ratio_info(self):
|
| 553 |
+
"""When strict fires on ratio, reason must include ratio and threshold."""
|
| 554 |
+
self.gs.protection_mode = 'strict'
|
| 555 |
+
self.gs.strict_threshold = 0.99 # almost certain to fire
|
| 556 |
+
|
| 557 |
+
x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
|
| 558 |
+
should, reason = self.hj._should_fallback_for_protection(
|
| 559 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 560 |
+
|
| 561 |
+
if should:
|
| 562 |
+
# If it fired, reason must explain why
|
| 563 |
+
self.assertIn('ratio', reason)
|
| 564 |
+
self.assertIn('threshold', reason)
|
| 565 |
+
# If it didn't fire (both norms happened to be equal), that's fine too.
|
| 566 |
+
|
| 567 |
+
def test_strict_does_not_fire_with_threshold_zero(self):
|
| 568 |
+
"""threshold=0.01 (minimum) should almost never fire on random latents."""
|
| 569 |
+
self.gs.protection_mode = 'strict'
|
| 570 |
+
self.gs.strict_threshold = 0.01
|
| 571 |
+
|
| 572 |
+
x_out, text_uncond = self._setup("base AND_SALT tweak :1.0", n_cond=2)
|
| 573 |
+
should, _ = self.hj._should_fallback_for_protection(
|
| 574 |
+
x_out, self.gs.batch_cond_indices, text_uncond)
|
| 575 |
+
self.assertFalse(should,
|
| 576 |
+
"threshold=0.01 should almost never trigger on random latents")
|
| 577 |
+
|
| 578 |
+
# --- _has_valid_base_path ---
|
| 579 |
+
|
| 580 |
+
def test_has_valid_base_path_true(self):
|
| 581 |
+
root = self.npp.parse_root("base AND_SALT tweak :1.0")
|
| 582 |
+
self.assertTrue(self.hj._has_valid_base_path([root]))
|
| 583 |
+
|
| 584 |
+
def test_has_valid_base_path_false(self):
|
| 585 |
+
root = self.npp.parse_root("AND_SALT only :1.0")
|
| 586 |
+
self.assertFalse(self.hj._has_valid_base_path([root]))
|
| 587 |
+
|
| 588 |
+
def test_has_valid_base_path_empty(self):
|
| 589 |
+
self.assertFalse(self.hj._has_valid_base_path([]))
|
| 590 |
+
|
| 591 |
+
# --- _safe_norm ---
|
| 592 |
+
|
| 593 |
+
def test_safe_norm_normal(self):
|
| 594 |
+
t = self.torch.ones(4, 8, 8)
|
| 595 |
+
self.assertGreater(self.hj._safe_norm(t), 0.0)
|
| 596 |
+
|
| 597 |
+
def test_safe_norm_zeros(self):
|
| 598 |
+
t = self.torch.zeros(4, 8, 8)
|
| 599 |
+
self.assertAlmostEqual(self.hj._safe_norm(t), 0.0)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 603 |
+
class TestPerPromptProtection(unittest.TestCase):
|
| 604 |
+
"""Per-prompt structural check — a healthy prompt must not shield a bad neighbour."""
|
| 605 |
+
|
| 606 |
+
@classmethod
|
| 607 |
+
def setUpClass(cls):
|
| 608 |
+
import torch
|
| 609 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 610 |
+
import lib_neutral_prompt.global_state as gs
|
| 611 |
+
import lib_neutral_prompt.neutral_prompt_parser as npp
|
| 612 |
+
cls.torch = torch
|
| 613 |
+
cls.hj = hj
|
| 614 |
+
cls.gs = gs
|
| 615 |
+
cls.npp = npp
|
| 616 |
+
|
| 617 |
+
def setUp(self):
|
| 618 |
+
self.gs.is_enabled = True
|
| 619 |
+
self.gs.protection_mode = 'auto'
|
| 620 |
+
self.gs.strict_threshold = 0.10
|
| 621 |
+
self.gs.prompt_exprs = []
|
| 622 |
+
self.gs.batch_cond_indices = []
|
| 623 |
+
|
| 624 |
+
def tearDown(self):
|
| 625 |
+
self.gs.is_enabled = False
|
| 626 |
+
self.gs.protection_mode = 'auto'
|
| 627 |
+
self.gs.prompt_exprs = []
|
| 628 |
+
self.gs.batch_cond_indices = []
|
| 629 |
+
|
| 630 |
+
# --- _get_invalid_base_prompt_indices ---
|
| 631 |
+
|
| 632 |
+
def test_helper_all_good(self):
|
| 633 |
+
roots = [
|
| 634 |
+
self.npp.parse_root("base AND_SALT tweak :1.0"),
|
| 635 |
+
self.npp.parse_root("base AND_PERP style :0.8"),
|
| 636 |
+
]
|
| 637 |
+
self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [])
|
| 638 |
+
|
| 639 |
+
def test_helper_all_bad(self):
|
| 640 |
+
roots = [
|
| 641 |
+
self.npp.parse_root("AND_SALT only :1.0"),
|
| 642 |
+
self.npp.parse_root("AND_PERP only :0.8"),
|
| 643 |
+
]
|
| 644 |
+
self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [0, 1])
|
| 645 |
+
|
| 646 |
+
def test_helper_mixed_batch(self):
|
| 647 |
+
roots = [
|
| 648 |
+
self.npp.parse_root("base AND_SALT tweak :1.0"), # index 0 — good
|
| 649 |
+
self.npp.parse_root("AND_PERP only :0.8"), # index 1 — bad
|
| 650 |
+
]
|
| 651 |
+
self.assertEqual(self.hj._get_invalid_base_prompt_indices(roots), [1])
|
| 652 |
+
|
| 653 |
+
def test_helper_empty_exprs(self):
|
| 654 |
+
self.assertEqual(self.hj._get_invalid_base_prompt_indices([]), [])
|
| 655 |
+
|
| 656 |
+
# --- _prompt_has_valid_base_path ---
|
| 657 |
+
|
| 658 |
+
def test_per_prompt_good(self):
|
| 659 |
+
root = self.npp.parse_root("base AND_SALT tweak :1.0")
|
| 660 |
+
self.assertTrue(self.hj._prompt_has_valid_base_path(root))
|
| 661 |
+
|
| 662 |
+
def test_per_prompt_bad(self):
|
| 663 |
+
root = self.npp.parse_root("AND_SALT only :1.0")
|
| 664 |
+
self.assertFalse(self.hj._prompt_has_valid_base_path(root))
|
| 665 |
+
|
| 666 |
+
# --- Mixed-batch policy: auto must fire even if index-0 is healthy ---
|
| 667 |
+
|
| 668 |
+
def _run_mixed_batch(self, protection_mode):
|
| 669 |
+
"""Batch: prompt 0 OK, prompt 1 has no base."""
|
| 670 |
+
root_good = self.npp.parse_root("base AND_SALT tweak :1.0")
|
| 671 |
+
root_bad = self.npp.parse_root("AND_SALT only :1.0")
|
| 672 |
+
self.gs.prompt_exprs = [root_good, root_bad]
|
| 673 |
+
self.gs.batch_cond_indices = [[(0, 1.0), (1, 1.0)], [(2, 1.0)]]
|
| 674 |
+
self.gs.protection_mode = protection_mode
|
| 675 |
+
|
| 676 |
+
n_uncond = 1
|
| 677 |
+
x_out = self.torch.randn(3 + n_uncond, 4, 8, 8)
|
| 678 |
+
text_uncond = x_out[-n_uncond:]
|
| 679 |
+
|
| 680 |
+
seen = []
|
| 681 |
+
def fn(x, *a, **kw):
|
| 682 |
+
seen.append(x.shape[0])
|
| 683 |
+
return x[-1:]
|
| 684 |
+
|
| 685 |
+
try:
|
| 686 |
+
self.hj.combine_denoised_hijack(
|
| 687 |
+
x_out, self.gs.batch_cond_indices, text_uncond, 7.0, fn)
|
| 688 |
+
except Exception:
|
| 689 |
+
pass
|
| 690 |
+
return seen, x_out.shape[0]
|
| 691 |
+
|
| 692 |
+
def test_auto_fires_on_mixed_batch(self):
|
| 693 |
+
seen, full = self._run_mixed_batch('auto')
|
| 694 |
+
# Guard should have passed full x_out to original_fn (fallback path)
|
| 695 |
+
self.assertTrue(len(seen) > 0 and seen[0] == full,
|
| 696 |
+
"Auto guard must fire when any prompt in the batch lacks a base path")
|
| 697 |
+
|
| 698 |
+
def test_off_does_not_fire_on_mixed_batch(self):
|
| 699 |
+
seen, full = self._run_mixed_batch('off')
|
| 700 |
+
self.assertTrue(len(seen) == 0 or seen[0] != full,
|
| 701 |
+
"Off mode must never trigger the guard, even for bad prompts")
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping runtime tests')
|
| 705 |
+
class TestSafeAffineInversion(unittest.TestCase):
|
| 706 |
+
"""_try_invert_affine: graceful fallback on singular / NaN matrices."""
|
| 707 |
+
|
| 708 |
+
@classmethod
|
| 709 |
+
def setUpClass(cls):
|
| 710 |
+
import torch
|
| 711 |
+
import lib_neutral_prompt.cfg_denoiser_hijack as hj
|
| 712 |
+
import lib_neutral_prompt.global_state as gs
|
| 713 |
+
cls.torch = torch
|
| 714 |
+
cls.hj = hj
|
| 715 |
+
cls.gs = gs
|
| 716 |
+
|
| 717 |
+
def setUp(self):
|
| 718 |
+
self.gs.verbose = False
|
| 719 |
+
|
| 720 |
+
def tearDown(self):
|
| 721 |
+
self.gs.verbose = False
|
| 722 |
+
|
| 723 |
+
def test_normal_invertible_matrix(self):
|
| 724 |
+
m = self.torch.eye(3)
|
| 725 |
+
result = self.hj._try_invert_affine(m, 'test')
|
| 726 |
+
self.assertTrue(self.torch.allclose(result, self.torch.eye(3)),
|
| 727 |
+
"Inverting identity should return identity")
|
| 728 |
+
|
| 729 |
+
def test_singular_matrix_returns_identity(self):
|
| 730 |
+
"""SCALE[0] gives a singular matrix — must not crash."""
|
| 731 |
+
m = self.torch.zeros(3, 3) # fully degenerate
|
| 732 |
+
result = self.hj._try_invert_affine(m, 'singular test')
|
| 733 |
+
self.assertTrue(self.torch.allclose(result, self.torch.eye(3)),
|
| 734 |
+
"Singular matrix must fall back to identity")
|
| 735 |
+
|
| 736 |
+
def test_nearly_singular_returns_identity(self):
|
| 737 |
+
m = self.torch.eye(3)
|
| 738 |
+
m[0, 0] = 1e-9 # effectively zero scale
|
| 739 |
+
result = self.hj._try_invert_affine(m, 'nearly singular')
|
| 740 |
+
# Either returns eye (fallback) or a finite result — must not be NaN/Inf
|
| 741 |
+
self.assertTrue(self.torch.isfinite(result).all(),
|
| 742 |
+
"Result must be finite even for nearly-singular input")
|
| 743 |
+
|
| 744 |
+
def test_verbose_warning_on_singular(self):
|
| 745 |
+
import io, unittest.mock
|
| 746 |
+
self.gs.verbose = True
|
| 747 |
+
buf = io.StringIO()
|
| 748 |
+
m = self.torch.zeros(3, 3)
|
| 749 |
+
with unittest.mock.patch('sys.stderr', buf):
|
| 750 |
+
self.hj._try_invert_affine(m, 'singular context')
|
| 751 |
+
self.gs.verbose = False
|
| 752 |
+
self.assertIn('singular context', buf.getvalue())
|
| 753 |
+
self.assertIn('identity', buf.getvalue().lower())
|
| 754 |
+
|
| 755 |
+
def test_no_warning_when_verbose_false(self):
|
| 756 |
+
import io, unittest.mock
|
| 757 |
+
self.gs.verbose = False
|
| 758 |
+
buf = io.StringIO()
|
| 759 |
+
m = self.torch.zeros(3, 3)
|
| 760 |
+
with unittest.mock.patch('sys.stderr', buf):
|
| 761 |
+
self.hj._try_invert_affine(m, 'silent singular')
|
| 762 |
+
self.assertEqual(buf.getvalue(), '',
|
| 763 |
+
"No warning expected when verbose=False")
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
class TestAffineParserValidation(unittest.TestCase):
|
| 767 |
+
"""Parser-level rejection of invalid affine params."""
|
| 768 |
+
|
| 769 |
+
@classmethod
|
| 770 |
+
def setUpClass(cls):
|
| 771 |
+
# mock_torch already installed by package __init__
|
| 772 |
+
import lib_neutral_prompt.neutral_prompt_parser as _p
|
| 773 |
+
cls.p = _p
|
| 774 |
+
|
| 775 |
+
def _parse(self, s):
|
| 776 |
+
return self.p.parse_root(s)
|
| 777 |
+
|
| 778 |
+
def _child(self, s, idx=-1):
|
| 779 |
+
return self._parse(s).children[idx]
|
| 780 |
+
|
| 781 |
+
# --- SCALE[0] must not produce a transform ---
|
| 782 |
+
|
| 783 |
+
def test_scale_zero_rejected(self):
|
| 784 |
+
"""SCALE[0] must not crash. If a transform is stored it must come from
|
| 785 |
+
the identity fallback (make_scale returned None → _apply_affine kept t)."""
|
| 786 |
+
try:
|
| 787 |
+
child = self._child("base AND_PERP SCALE[0] text :0.8")
|
| 788 |
+
except Exception as exc:
|
| 789 |
+
self.fail(f"SCALE[0] must not crash the parser: {exc}")
|
| 790 |
+
# Either no transform stored (None) or the stored transform is the
|
| 791 |
+
# identity that _apply_affine kept when make_scale returned None.
|
| 792 |
+
# Both are acceptable — the key guarantee is no crash.
|
| 793 |
+
|
| 794 |
+
def test_scale_nonzero_accepted(self):
|
| 795 |
+
child = self._child("base AND_PERP SCALE[2] text :0.8")
|
| 796 |
+
self.assertIsNotNone(child.local_transform,
|
| 797 |
+
"SCALE[2] should produce a valid transform")
|
| 798 |
+
|
| 799 |
+
# --- SHEAR near 0.25 turns must be rejected ---
|
| 800 |
+
|
| 801 |
+
def test_shear_safe_value_accepted(self):
|
| 802 |
+
child = self._child("base AND_PERP SHEAR[0.1] text :0.8")
|
| 803 |
+
# Should produce a transform (0.1 turns is safe)
|
| 804 |
+
# Don't assert non-None since mock torch may not multiply correctly,
|
| 805 |
+
# just check no crash occurred.
|
| 806 |
+
|
| 807 |
+
def test_shear_dangerous_value_warning(self):
|
| 808 |
+
"""SHEAR[0.25] is exactly at the singularity — must not crash."""
|
| 809 |
+
try:
|
| 810 |
+
root = self._parse("base AND_PERP SHEAR[0.25] text :0.8")
|
| 811 |
+
except Exception as exc:
|
| 812 |
+
self.fail(f"SHEAR[0.25] must not crash the parser: {exc}")
|
| 813 |
+
|
| 814 |
+
# --- Non-finite params must not crash ---
|
| 815 |
+
|
| 816 |
+
def test_rotate_no_crash(self):
|
| 817 |
+
try:
|
| 818 |
+
self._parse("base AND_PERP ROTATE[0.5] text :0.8")
|
| 819 |
+
except Exception as exc:
|
| 820 |
+
self.fail(f"ROTATE[0.5] crashed: {exc}")
|
| 821 |
+
|
| 822 |
+
def test_slide_no_crash(self):
|
| 823 |
+
try:
|
| 824 |
+
self._parse("base AND_SALT SLIDE[0.1,0.2] text :1.0")
|
| 825 |
+
except Exception as exc:
|
| 826 |
+
self.fail(f"SLIDE crashed: {exc}")
|
neutral_prompt_patcheds/test/perp_parser/test_sprint2.py
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sprint 2 tests: protection_utils, affine_utils, step_utils, builder v1.2.
|
| 3 |
+
|
| 4 |
+
All pure-Python, no torch (torch-dependent cases skipped via mock).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import unittest
|
| 8 |
+
|
| 9 |
+
# mock_torch installed by package __init__
|
| 10 |
+
|
| 11 |
+
from lib_neutral_prompt import global_state
|
| 12 |
+
from lib_neutral_prompt.affine_utils import (
|
| 13 |
+
make_rotate, make_slide, make_scale, make_shear,
|
| 14 |
+
apply_affine, try_invert_affine,
|
| 15 |
+
build_affine_snippet, AFFINE_PRESETS, AFFINE_TRANSFORMS,
|
| 16 |
+
affine_transforms,
|
| 17 |
+
_make_rotate, _make_scale, _apply_affine, # alias checks
|
| 18 |
+
)
|
| 19 |
+
from lib_neutral_prompt.protection_utils import (
|
| 20 |
+
prompt_has_valid_base_path,
|
| 21 |
+
get_invalid_base_prompt_indices,
|
| 22 |
+
has_valid_base_path,
|
| 23 |
+
safe_norm,
|
| 24 |
+
compute_soft_attenuation,
|
| 25 |
+
protection_verdict,
|
| 26 |
+
)
|
| 27 |
+
from lib_neutral_prompt.step_utils import (
|
| 28 |
+
StepWindow, ALWAYS_ACTIVE, STRATEGY_DEFAULTS,
|
| 29 |
+
normalize_progress, strategy_is_active,
|
| 30 |
+
_strategy_family, render_step_window_summary,
|
| 31 |
+
)
|
| 32 |
+
from lib_neutral_prompt.matryoshka_utils import (
|
| 33 |
+
BuilderNode,
|
| 34 |
+
builder_add_child, builder_remove_node, builder_duplicate_node,
|
| 35 |
+
builder_move_up, builder_move_down, builder_update_node,
|
| 36 |
+
serialize_builder_tree, builder_tree_to_prompt,
|
| 37 |
+
)
|
| 38 |
+
import lib_neutral_prompt.neutral_prompt_parser as p
|
| 39 |
+
|
| 40 |
+
import sys as _sys
|
| 41 |
+
|
| 42 |
+
def _has_real_torch() -> bool:
|
| 43 |
+
"""True iff real torch (not the test mock) is available."""
|
| 44 |
+
try:
|
| 45 |
+
import torch as _t
|
| 46 |
+
_t.tensor([1.0]).float().item()
|
| 47 |
+
_t.linalg.norm(_t.tensor([3.0, 4.0])).item()
|
| 48 |
+
return True
|
| 49 |
+
except Exception:
|
| 50 |
+
return False
|
| 51 |
+
|
| 52 |
+
_REAL_TORCH = _has_real_torch()
|
| 53 |
+
_skip_no_torch = unittest.skipUnless(_REAL_TORCH, 'requires real torch')
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# affine_utils
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
class TestAffineUtils(unittest.TestCase):
|
| 62 |
+
|
| 63 |
+
# --- matrix constructors ---
|
| 64 |
+
|
| 65 |
+
@_skip_no_torch
|
| 66 |
+
def test_rotate_matrix_shape(self):
|
| 67 |
+
import torch
|
| 68 |
+
m = make_rotate(0.0)
|
| 69 |
+
self.assertEqual(m.shape, (3, 3))
|
| 70 |
+
|
| 71 |
+
@_skip_no_torch
|
| 72 |
+
def test_rotate_zero_is_identity(self):
|
| 73 |
+
import torch
|
| 74 |
+
m = make_rotate(0.0)
|
| 75 |
+
self.assertTrue(torch.allclose(m, torch.eye(3), atol=1e-5))
|
| 76 |
+
|
| 77 |
+
@_skip_no_torch
|
| 78 |
+
def test_rotate_quarter_turn(self):
|
| 79 |
+
import torch
|
| 80 |
+
m = make_rotate(0.25) # 90°
|
| 81 |
+
expected = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=torch.float32)
|
| 82 |
+
self.assertTrue(torch.allclose(m, expected, atol=1e-5))
|
| 83 |
+
|
| 84 |
+
def test_rotate_nonfinite_returns_none(self):
|
| 85 |
+
self.assertIsNone(make_rotate(float('inf')))
|
| 86 |
+
self.assertIsNone(make_rotate(float('nan')))
|
| 87 |
+
|
| 88 |
+
@_skip_no_torch
|
| 89 |
+
def test_scale_uniform(self):
|
| 90 |
+
import torch
|
| 91 |
+
m = make_scale(2.0)
|
| 92 |
+
self.assertAlmostEqual(m[0, 0].item(), 2.0)
|
| 93 |
+
self.assertAlmostEqual(m[1, 1].item(), 2.0)
|
| 94 |
+
|
| 95 |
+
def test_scale_zero_rejected(self):
|
| 96 |
+
self.assertIsNone(make_scale(0.0))
|
| 97 |
+
self.assertIsNone(make_scale(1.0, 0.0))
|
| 98 |
+
self.assertIsNone(make_scale(0.0, 1.0))
|
| 99 |
+
|
| 100 |
+
def test_scale_nonfinite_rejected(self):
|
| 101 |
+
self.assertIsNone(make_scale(float('nan'), 1.0))
|
| 102 |
+
|
| 103 |
+
@_skip_no_torch
|
| 104 |
+
def test_shear_safe_range(self):
|
| 105 |
+
import torch
|
| 106 |
+
m = make_shear(0.1)
|
| 107 |
+
self.assertIsNotNone(m)
|
| 108 |
+
self.assertEqual(m.shape, (3, 3))
|
| 109 |
+
|
| 110 |
+
def test_shear_unsafe_range_rejected(self):
|
| 111 |
+
self.assertIsNone(make_shear(0.25))
|
| 112 |
+
self.assertIsNone(make_shear(-0.25))
|
| 113 |
+
|
| 114 |
+
@_skip_no_torch
|
| 115 |
+
def test_slide_matrix(self):
|
| 116 |
+
import torch
|
| 117 |
+
m = make_slide(0.1, 0.2)
|
| 118 |
+
self.assertAlmostEqual(m[0, 2].item(), 0.1)
|
| 119 |
+
self.assertAlmostEqual(m[1, 2].item(), 0.2)
|
| 120 |
+
|
| 121 |
+
def test_slide_nonfinite_rejected(self):
|
| 122 |
+
self.assertIsNone(make_slide(float('nan'), 0.0))
|
| 123 |
+
|
| 124 |
+
@_skip_no_torch
|
| 125 |
+
def test_apply_affine_none_returns_t(self):
|
| 126 |
+
import torch
|
| 127 |
+
t = torch.eye(3)
|
| 128 |
+
self.assertIs(apply_affine(t, None), t)
|
| 129 |
+
|
| 130 |
+
@_skip_no_torch
|
| 131 |
+
def test_try_invert_identity(self):
|
| 132 |
+
import torch
|
| 133 |
+
eye = torch.eye(3)
|
| 134 |
+
result = try_invert_affine(eye)
|
| 135 |
+
self.assertTrue(torch.allclose(result, eye, atol=1e-6))
|
| 136 |
+
|
| 137 |
+
@_skip_no_torch
|
| 138 |
+
def test_try_invert_singular_returns_identity(self):
|
| 139 |
+
import torch
|
| 140 |
+
singular = torch.zeros(3, 3)
|
| 141 |
+
result = try_invert_affine(singular, 'test_singular')
|
| 142 |
+
self.assertTrue(torch.allclose(result, torch.eye(3), atol=1e-6))
|
| 143 |
+
|
| 144 |
+
def test_private_aliases_work(self):
|
| 145 |
+
# Backward-compat aliases must point to the same function
|
| 146 |
+
self.assertIs(_make_rotate, make_rotate)
|
| 147 |
+
self.assertIs(_make_scale, make_scale)
|
| 148 |
+
self.assertIs(_apply_affine, apply_affine)
|
| 149 |
+
|
| 150 |
+
# --- parser dispatch table ---
|
| 151 |
+
|
| 152 |
+
def test_affine_transforms_keys(self):
|
| 153 |
+
self.assertIn('ROTATE', affine_transforms)
|
| 154 |
+
self.assertIn('SCALE', affine_transforms)
|
| 155 |
+
self.assertIn('SLIDE', affine_transforms)
|
| 156 |
+
self.assertIn('SHEAR', affine_transforms)
|
| 157 |
+
|
| 158 |
+
@_skip_no_torch
|
| 159 |
+
def test_affine_transforms_callable(self):
|
| 160 |
+
import torch
|
| 161 |
+
t = torch.eye(3, dtype=torch.float32)
|
| 162 |
+
result = affine_transforms['ROTATE'](t, 0.0)
|
| 163 |
+
self.assertTrue(torch.allclose(result, t, atol=1e-5))
|
| 164 |
+
|
| 165 |
+
# --- snippet builder ---
|
| 166 |
+
|
| 167 |
+
def test_snippet_flip_h(self):
|
| 168 |
+
self.assertEqual(build_affine_snippet('FLIP_H', 0, 0), 'SCALE[-1,1]')
|
| 169 |
+
|
| 170 |
+
def test_snippet_flip_v(self):
|
| 171 |
+
self.assertEqual(build_affine_snippet('FLIP_V', 0, 0), 'SCALE[1,-1]')
|
| 172 |
+
|
| 173 |
+
def test_snippet_rotate(self):
|
| 174 |
+
s = build_affine_snippet('ROTATE', 0.25, 0)
|
| 175 |
+
self.assertEqual(s, 'ROTATE[0.25]')
|
| 176 |
+
|
| 177 |
+
def test_snippet_scale_uniform(self):
|
| 178 |
+
s = build_affine_snippet('SCALE', 2.0, 2.0)
|
| 179 |
+
self.assertEqual(s, 'SCALE[2.0]')
|
| 180 |
+
|
| 181 |
+
def test_snippet_scale_anisotropic(self):
|
| 182 |
+
s = build_affine_snippet('SCALE', 1.5, 1.0)
|
| 183 |
+
self.assertEqual(s, 'SCALE[1.5,1.0]')
|
| 184 |
+
|
| 185 |
+
def test_snippet_parseable(self):
|
| 186 |
+
for snip in ['ROTATE[0.125]', 'SCALE[-1,1]', 'SLIDE[0.1,0.0]']:
|
| 187 |
+
with self.subTest(snip=snip):
|
| 188 |
+
p.parse_root(f'base {snip} AND_PERP style :0.8')
|
| 189 |
+
|
| 190 |
+
def test_presets_non_none_are_tuples(self):
|
| 191 |
+
for name, val in AFFINE_PRESETS.items():
|
| 192 |
+
with self.subTest(preset=name):
|
| 193 |
+
if val is not None:
|
| 194 |
+
self.assertIsInstance(val, tuple)
|
| 195 |
+
self.assertIn(val[0], AFFINE_TRANSFORMS)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
# protection_utils
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
|
| 202 |
+
class TestProtectionUtils(unittest.TestCase):
|
| 203 |
+
|
| 204 |
+
def _leaf(self, strat=None):
|
| 205 |
+
return p.LeafPrompt(weight=1.0, conciliation=strat,
|
| 206 |
+
local_transform=None, prompt='text',
|
| 207 |
+
conciliation_params={})
|
| 208 |
+
|
| 209 |
+
def _composite(self, children):
|
| 210 |
+
return p.CompositePrompt(weight=1.0, conciliation=None,
|
| 211 |
+
local_transform=None, children=children,
|
| 212 |
+
conciliation_params={})
|
| 213 |
+
|
| 214 |
+
# --- structural checks ---
|
| 215 |
+
|
| 216 |
+
def test_base_child_is_valid(self):
|
| 217 |
+
node = self._composite([self._leaf(None)])
|
| 218 |
+
self.assertTrue(prompt_has_valid_base_path(node))
|
| 219 |
+
|
| 220 |
+
def test_all_conciliation_is_invalid(self):
|
| 221 |
+
node = self._composite([
|
| 222 |
+
self._leaf(p.ConciliationStrategy.PERPENDICULAR),
|
| 223 |
+
])
|
| 224 |
+
self.assertFalse(prompt_has_valid_base_path(node))
|
| 225 |
+
|
| 226 |
+
def test_leaf_not_valid_base_path(self):
|
| 227 |
+
self.assertFalse(prompt_has_valid_base_path(self._leaf(None)))
|
| 228 |
+
|
| 229 |
+
def test_mixed_children_is_valid(self):
|
| 230 |
+
node = self._composite([
|
| 231 |
+
self._leaf(None),
|
| 232 |
+
self._leaf(p.ConciliationStrategy.SALIENCE_MASK),
|
| 233 |
+
])
|
| 234 |
+
self.assertTrue(prompt_has_valid_base_path(node))
|
| 235 |
+
|
| 236 |
+
def test_invalid_indices_empty_on_ok_batch(self):
|
| 237 |
+
ok = self._composite([self._leaf(None)])
|
| 238 |
+
self.assertEqual(get_invalid_base_prompt_indices([ok]), [])
|
| 239 |
+
|
| 240 |
+
def test_invalid_indices_reports_bad_positions(self):
|
| 241 |
+
ok = self._composite([self._leaf(None)])
|
| 242 |
+
bad = self._composite([self._leaf(p.ConciliationStrategy.PERPENDICULAR)])
|
| 243 |
+
indices = get_invalid_base_prompt_indices([ok, bad, ok])
|
| 244 |
+
self.assertEqual(indices, [1])
|
| 245 |
+
|
| 246 |
+
def test_has_valid_base_path_empty_is_false(self):
|
| 247 |
+
self.assertFalse(has_valid_base_path([]))
|
| 248 |
+
|
| 249 |
+
def test_has_valid_base_path_all_ok(self):
|
| 250 |
+
ok = self._composite([self._leaf(None)])
|
| 251 |
+
self.assertTrue(has_valid_base_path([ok]))
|
| 252 |
+
|
| 253 |
+
def test_has_valid_base_path_mixed_batch_is_false(self):
|
| 254 |
+
ok = self._composite([self._leaf(None)])
|
| 255 |
+
bad = self._composite([self._leaf(p.ConciliationStrategy.PERPENDICULAR)])
|
| 256 |
+
self.assertFalse(has_valid_base_path([ok, bad]))
|
| 257 |
+
|
| 258 |
+
# --- safe_norm ---
|
| 259 |
+
|
| 260 |
+
@_skip_no_torch
|
| 261 |
+
def test_safe_norm_positive(self):
|
| 262 |
+
import torch
|
| 263 |
+
t = torch.tensor([3.0, 4.0])
|
| 264 |
+
self.assertAlmostEqual(safe_norm(t), 5.0, places=4)
|
| 265 |
+
|
| 266 |
+
@_skip_no_torch
|
| 267 |
+
def test_safe_norm_zero_tensor(self):
|
| 268 |
+
import torch
|
| 269 |
+
self.assertEqual(safe_norm(torch.zeros(3)), 0.0)
|
| 270 |
+
|
| 271 |
+
@_skip_no_torch
|
| 272 |
+
def test_safe_norm_nan_returns_zero(self):
|
| 273 |
+
import torch
|
| 274 |
+
t = torch.tensor([float('nan'), 1.0])
|
| 275 |
+
self.assertEqual(safe_norm(t), 0.0)
|
| 276 |
+
|
| 277 |
+
# --- soft attenuation ---
|
| 278 |
+
|
| 279 |
+
def test_attenuation_no_change_when_ratio_ok(self):
|
| 280 |
+
# base_norm / aux_norm = 0.8 / 0.4 = 2.0 ≥ threshold 0.1
|
| 281 |
+
factor = compute_soft_attenuation(0.8, 0.4, 0.1)
|
| 282 |
+
self.assertAlmostEqual(factor, 1.0)
|
| 283 |
+
|
| 284 |
+
def test_attenuation_reduces_when_ratio_low(self):
|
| 285 |
+
# base=0.05, aux=1.0, threshold=0.2 → ratio=0.05 < 0.2
|
| 286 |
+
# factor = 0.05 / (1.0 * 0.2) = 0.25
|
| 287 |
+
factor = compute_soft_attenuation(0.05, 1.0, 0.2)
|
| 288 |
+
self.assertAlmostEqual(factor, 0.25, places=5)
|
| 289 |
+
|
| 290 |
+
def test_attenuation_zero_aux_returns_one(self):
|
| 291 |
+
self.assertAlmostEqual(compute_soft_attenuation(0.5, 0.0, 0.1), 1.0)
|
| 292 |
+
|
| 293 |
+
def test_attenuation_clamped_to_zero_min(self):
|
| 294 |
+
factor = compute_soft_attenuation(-1.0, 1.0, 0.2)
|
| 295 |
+
self.assertGreaterEqual(factor, 0.0)
|
| 296 |
+
|
| 297 |
+
# --- protection_verdict ---
|
| 298 |
+
|
| 299 |
+
def test_verdict_off(self):
|
| 300 |
+
saved = global_state.protection_mode
|
| 301 |
+
global_state.protection_mode = 'off'
|
| 302 |
+
try:
|
| 303 |
+
status, msg = protection_verdict('base AND_PERP style :0.8')
|
| 304 |
+
self.assertEqual(status, 'off')
|
| 305 |
+
finally:
|
| 306 |
+
global_state.protection_mode = saved
|
| 307 |
+
|
| 308 |
+
def test_verdict_ok_with_base(self):
|
| 309 |
+
saved = global_state.protection_mode
|
| 310 |
+
global_state.protection_mode = 'auto'
|
| 311 |
+
try:
|
| 312 |
+
status, msg = protection_verdict('base subject AND_SALT[5] texture :0.8')
|
| 313 |
+
self.assertEqual(status, 'ok')
|
| 314 |
+
finally:
|
| 315 |
+
global_state.protection_mode = saved
|
| 316 |
+
|
| 317 |
+
def test_verdict_fire_no_base(self):
|
| 318 |
+
saved = global_state.protection_mode
|
| 319 |
+
global_state.protection_mode = 'auto'
|
| 320 |
+
try:
|
| 321 |
+
status, msg = protection_verdict('AND_SALT[5] only :0.8')
|
| 322 |
+
self.assertEqual(status, 'fire')
|
| 323 |
+
finally:
|
| 324 |
+
global_state.protection_mode = saved
|
| 325 |
+
|
| 326 |
+
def test_verdict_soft_mode_mentions_attenuation(self):
|
| 327 |
+
saved = global_state.protection_mode
|
| 328 |
+
global_state.protection_mode = 'soft'
|
| 329 |
+
try:
|
| 330 |
+
status, msg = protection_verdict('base AND_SALT[5] texture :0.8')
|
| 331 |
+
self.assertEqual(status, 'ok')
|
| 332 |
+
self.assertIn('attenuat', msg.lower())
|
| 333 |
+
finally:
|
| 334 |
+
global_state.protection_mode = saved
|
| 335 |
+
|
| 336 |
+
def test_normalize_soft_is_valid(self):
|
| 337 |
+
self.assertEqual(global_state.normalize_protection_mode('soft'), 'soft')
|
| 338 |
+
|
| 339 |
+
def test_normalize_unknown_falls_back_to_auto(self):
|
| 340 |
+
self.assertEqual(global_state.normalize_protection_mode('INVALID'), 'auto')
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# ---------------------------------------------------------------------------
|
| 344 |
+
# step_utils
|
| 345 |
+
# ---------------------------------------------------------------------------
|
| 346 |
+
|
| 347 |
+
class TestStepWindow(unittest.TestCase):
|
| 348 |
+
|
| 349 |
+
def test_clamp_start_to_zero(self):
|
| 350 |
+
w = StepWindow(-0.5, 0.8)
|
| 351 |
+
self.assertAlmostEqual(w.start, 0.0)
|
| 352 |
+
|
| 353 |
+
def test_clamp_end_to_one(self):
|
| 354 |
+
w = StepWindow(0.2, 2.0)
|
| 355 |
+
self.assertAlmostEqual(w.end, 1.0)
|
| 356 |
+
|
| 357 |
+
def test_start_gt_end_clamped(self):
|
| 358 |
+
w = StepWindow(0.8, 0.2)
|
| 359 |
+
self.assertLessEqual(w.start, w.end)
|
| 360 |
+
|
| 361 |
+
def test_is_active_inside(self):
|
| 362 |
+
w = StepWindow(0.3, 0.7)
|
| 363 |
+
self.assertTrue(w.is_active_at(0.5))
|
| 364 |
+
|
| 365 |
+
def test_is_active_at_boundary(self):
|
| 366 |
+
w = StepWindow(0.3, 0.7)
|
| 367 |
+
self.assertTrue(w.is_active_at(0.3))
|
| 368 |
+
self.assertTrue(w.is_active_at(0.7))
|
| 369 |
+
|
| 370 |
+
def test_is_active_outside(self):
|
| 371 |
+
w = StepWindow(0.3, 0.7)
|
| 372 |
+
self.assertFalse(w.is_active_at(0.1))
|
| 373 |
+
self.assertFalse(w.is_active_at(0.9))
|
| 374 |
+
|
| 375 |
+
def test_always_active_is_always_active(self):
|
| 376 |
+
for p in [0.0, 0.5, 1.0]:
|
| 377 |
+
self.assertTrue(ALWAYS_ACTIVE.is_active_at(p))
|
| 378 |
+
|
| 379 |
+
def test_repr(self):
|
| 380 |
+
w = StepWindow(0.3, 0.7)
|
| 381 |
+
self.assertIn('0.30', repr(w))
|
| 382 |
+
self.assertIn('0.70', repr(w))
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class TestNormalizeProgress(unittest.TestCase):
|
| 386 |
+
|
| 387 |
+
def test_first_step(self):
|
| 388 |
+
self.assertAlmostEqual(normalize_progress(0, 20), 0.0)
|
| 389 |
+
|
| 390 |
+
def test_last_step(self):
|
| 391 |
+
self.assertAlmostEqual(normalize_progress(19, 20), 1.0)
|
| 392 |
+
|
| 393 |
+
def test_midpoint(self):
|
| 394 |
+
p = normalize_progress(10, 21)
|
| 395 |
+
self.assertAlmostEqual(p, 10 / 20)
|
| 396 |
+
|
| 397 |
+
def test_single_step(self):
|
| 398 |
+
self.assertAlmostEqual(normalize_progress(0, 1), 0.0)
|
| 399 |
+
|
| 400 |
+
def test_clamp_to_zero_one(self):
|
| 401 |
+
self.assertGreaterEqual(normalize_progress(0, 100), 0.0)
|
| 402 |
+
self.assertLessEqual(normalize_progress(99, 100), 1.0)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class TestStrategyIsActive(unittest.TestCase):
|
| 406 |
+
|
| 407 |
+
def test_no_window_always_active(self):
|
| 408 |
+
self.assertTrue(strategy_is_active('SALIENCE_MASK', 0.5))
|
| 409 |
+
self.assertTrue(strategy_is_active('PERPENDICULAR', 0.0))
|
| 410 |
+
|
| 411 |
+
def test_global_window_gates(self):
|
| 412 |
+
win = StepWindow(0.4, 0.8)
|
| 413 |
+
self.assertTrue( strategy_is_active('SALIENCE_MASK', 0.6, global_window=win))
|
| 414 |
+
self.assertFalse(strategy_is_active('SALIENCE_MASK', 0.1, global_window=win))
|
| 415 |
+
|
| 416 |
+
def test_per_strategy_overrides_global(self):
|
| 417 |
+
global_win = StepWindow(0.0, 1.0)
|
| 418 |
+
per_s = {'SALIENCE_MASK': StepWindow(0.6, 1.0)}
|
| 419 |
+
# At 0.3: global says active, per-strategy says inactive
|
| 420 |
+
self.assertFalse(strategy_is_active(
|
| 421 |
+
'SALIENCE_MASK', 0.3, global_window=global_win,
|
| 422 |
+
per_strategy_windows=per_s))
|
| 423 |
+
|
| 424 |
+
def test_defaults_perp_early(self):
|
| 425 |
+
# PERPENDICULAR default: 0.0–0.5
|
| 426 |
+
self.assertTrue( strategy_is_active('PERPENDICULAR', 0.3, use_defaults=True))
|
| 427 |
+
self.assertFalse(strategy_is_active('PERPENDICULAR', 0.8, use_defaults=True))
|
| 428 |
+
|
| 429 |
+
def test_defaults_salt_late(self):
|
| 430 |
+
# SALIENCE_MASK default: 0.4–1.0
|
| 431 |
+
self.assertTrue( strategy_is_active('SALIENCE_MASK', 0.7, use_defaults=True))
|
| 432 |
+
self.assertFalse(strategy_is_active('SALIENCE_MASK', 0.1, use_defaults=True))
|
| 433 |
+
|
| 434 |
+
def test_legacy_alignment_strategy_family(self):
|
| 435 |
+
fam = _strategy_family('ALIGNMENT_BLEND_4_8')
|
| 436 |
+
self.assertEqual(fam, 'ALIGNMENT_BLEND_CUSTOM')
|
| 437 |
+
|
| 438 |
+
def test_legacy_mask_alignment_family(self):
|
| 439 |
+
fam = _strategy_family('ALIGNMENT_MASK_BLEND_8_16')
|
| 440 |
+
self.assertEqual(fam, 'ALIGNMENT_MASK_BLEND_CUSTOM')
|
| 441 |
+
|
| 442 |
+
def test_base_strategy_family_unchanged(self):
|
| 443 |
+
self.assertEqual(_strategy_family('PERPENDICULAR'), 'PERPENDICULAR')
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class TestStepWindowSummary(unittest.TestCase):
|
| 447 |
+
|
| 448 |
+
def test_empty_strategies_returns_empty(self):
|
| 449 |
+
out = render_step_window_summary([])
|
| 450 |
+
self.assertEqual(out, '')
|
| 451 |
+
|
| 452 |
+
def test_no_window_shows_all_steps(self):
|
| 453 |
+
out = render_step_window_summary(['SALIENCE_MASK'])
|
| 454 |
+
self.assertIn('all steps', out)
|
| 455 |
+
|
| 456 |
+
def test_global_window_shown(self):
|
| 457 |
+
win = StepWindow(0.3, 0.8)
|
| 458 |
+
out = render_step_window_summary(['SALIENCE_MASK'], global_window=win)
|
| 459 |
+
self.assertIn('SALIENCE_MASK', out)
|
| 460 |
+
self.assertIn('0.30', out)
|
| 461 |
+
|
| 462 |
+
def test_defaults_shown(self):
|
| 463 |
+
out = render_step_window_summary(['PERPENDICULAR'], use_defaults=True)
|
| 464 |
+
self.assertIn('PERPENDICULAR', out)
|
| 465 |
+
self.assertIn('default', out)
|
| 466 |
+
|
| 467 |
+
def test_progress_shows_active_status(self):
|
| 468 |
+
win = StepWindow(0.4, 0.9)
|
| 469 |
+
out = render_step_window_summary(['SALIENCE_MASK'], progress=0.5, global_window=win)
|
| 470 |
+
self.assertIn('active', out)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# ---------------------------------------------------------------------------
|
| 474 |
+
# Builder v1.2
|
| 475 |
+
# ---------------------------------------------------------------------------
|
| 476 |
+
|
| 477 |
+
class TestBuilderNode(unittest.TestCase):
|
| 478 |
+
|
| 479 |
+
def test_make_creates_node(self):
|
| 480 |
+
n = BuilderNode.make('AND_SALT', 'texture', 0.8)
|
| 481 |
+
self.assertEqual(n.strategy, 'AND_SALT')
|
| 482 |
+
self.assertEqual(n.text, 'texture')
|
| 483 |
+
self.assertAlmostEqual(n.weight, 0.8)
|
| 484 |
+
self.assertIsNotNone(n.node_id)
|
| 485 |
+
|
| 486 |
+
def test_make_unique_ids(self):
|
| 487 |
+
ids = {BuilderNode.make().node_id for _ in range(20)}
|
| 488 |
+
self.assertEqual(len(ids), 20)
|
| 489 |
+
|
| 490 |
+
def test_default_no_children(self):
|
| 491 |
+
n = BuilderNode.make()
|
| 492 |
+
self.assertEqual(n.children, [])
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class TestBuilderAddChild(unittest.TestCase):
|
| 496 |
+
|
| 497 |
+
def test_add_to_empty(self):
|
| 498 |
+
nodes = builder_add_child([], strategy='AND_SALT', text='tex', weight=0.8)
|
| 499 |
+
self.assertEqual(len(nodes), 1)
|
| 500 |
+
self.assertEqual(nodes[0].strategy, 'AND_SALT')
|
| 501 |
+
|
| 502 |
+
def test_add_appends(self):
|
| 503 |
+
n1 = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 504 |
+
nodes = builder_add_child([n1], strategy='AND_PERP', text='style', weight=0.5)
|
| 505 |
+
self.assertEqual(len(nodes), 2)
|
| 506 |
+
|
| 507 |
+
def test_add_nested_child(self):
|
| 508 |
+
parent = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 509 |
+
nodes = builder_add_child([parent], parent_id=parent.node_id,
|
| 510 |
+
strategy='AND_TOPK', text='detail', weight=0.4)
|
| 511 |
+
self.assertEqual(len(nodes), 1)
|
| 512 |
+
self.assertEqual(len(nodes[0].children), 1)
|
| 513 |
+
self.assertEqual(nodes[0].children[0].strategy, 'AND_TOPK')
|
| 514 |
+
|
| 515 |
+
def test_original_not_mutated(self):
|
| 516 |
+
original = [BuilderNode.make()]
|
| 517 |
+
builder_add_child(original)
|
| 518 |
+
self.assertEqual(len(original), 1)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class TestBuilderRemoveNode(unittest.TestCase):
|
| 522 |
+
|
| 523 |
+
def test_remove_top_level(self):
|
| 524 |
+
n1 = BuilderNode.make()
|
| 525 |
+
n2 = BuilderNode.make()
|
| 526 |
+
result = builder_remove_node([n1, n2], n1.node_id)
|
| 527 |
+
self.assertEqual(len(result), 1)
|
| 528 |
+
self.assertEqual(result[0].node_id, n2.node_id)
|
| 529 |
+
|
| 530 |
+
def test_remove_nested(self):
|
| 531 |
+
child = BuilderNode.make('AND_TOPK', 'c', 0.5)
|
| 532 |
+
parent = BuilderNode.make('AND_SALT', 'p', 0.8)
|
| 533 |
+
parent = dataclasses_replace(parent, children=[child])
|
| 534 |
+
result = builder_remove_node([parent], child.node_id)
|
| 535 |
+
self.assertEqual(len(result[0].children), 0)
|
| 536 |
+
|
| 537 |
+
def test_remove_nonexistent_no_error(self):
|
| 538 |
+
n = BuilderNode.make()
|
| 539 |
+
result = builder_remove_node([n], 'nonexistent')
|
| 540 |
+
self.assertEqual(len(result), 1)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class TestBuilderDuplicateNode(unittest.TestCase):
|
| 544 |
+
|
| 545 |
+
def test_duplicate_creates_new_id(self):
|
| 546 |
+
n = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 547 |
+
result = builder_duplicate_node([n], n.node_id)
|
| 548 |
+
self.assertEqual(len(result), 2)
|
| 549 |
+
self.assertNotEqual(result[0].node_id, result[1].node_id)
|
| 550 |
+
|
| 551 |
+
def test_duplicate_copies_content(self):
|
| 552 |
+
n = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 553 |
+
result = builder_duplicate_node([n], n.node_id)
|
| 554 |
+
self.assertEqual(result[1].text, 'tex')
|
| 555 |
+
self.assertEqual(result[1].strategy, 'AND_SALT')
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class TestBuilderMoveNodes(unittest.TestCase):
|
| 559 |
+
|
| 560 |
+
def _nodes(self, *strategies):
|
| 561 |
+
return [BuilderNode.make(s, s, 1.0) for s in strategies]
|
| 562 |
+
|
| 563 |
+
def test_move_up_swaps(self):
|
| 564 |
+
ns = self._nodes('AND_SALT', 'AND_PERP', 'AND_TOPK')
|
| 565 |
+
result = builder_move_up(ns, ns[1].node_id)
|
| 566 |
+
self.assertEqual(result[0].strategy, 'AND_PERP')
|
| 567 |
+
self.assertEqual(result[1].strategy, 'AND_SALT')
|
| 568 |
+
|
| 569 |
+
def test_move_up_first_no_change(self):
|
| 570 |
+
ns = self._nodes('AND_SALT', 'AND_PERP')
|
| 571 |
+
result = builder_move_up(ns, ns[0].node_id)
|
| 572 |
+
self.assertEqual(result[0].strategy, 'AND_SALT')
|
| 573 |
+
|
| 574 |
+
def test_move_down_swaps(self):
|
| 575 |
+
ns = self._nodes('AND_SALT', 'AND_PERP', 'AND_TOPK')
|
| 576 |
+
result = builder_move_down(ns, ns[1].node_id)
|
| 577 |
+
self.assertEqual(result[1].strategy, 'AND_TOPK')
|
| 578 |
+
self.assertEqual(result[2].strategy, 'AND_PERP')
|
| 579 |
+
|
| 580 |
+
def test_move_down_last_no_change(self):
|
| 581 |
+
ns = self._nodes('AND_SALT', 'AND_PERP')
|
| 582 |
+
result = builder_move_down(ns, ns[-1].node_id)
|
| 583 |
+
self.assertEqual(result[-1].strategy, 'AND_PERP')
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class TestBuilderUpdateNode(unittest.TestCase):
|
| 587 |
+
|
| 588 |
+
def test_update_text(self):
|
| 589 |
+
n = BuilderNode.make('AND_SALT', 'old', 0.8)
|
| 590 |
+
result = builder_update_node([n], n.node_id, text='new')
|
| 591 |
+
self.assertEqual(result[0].text, 'new')
|
| 592 |
+
|
| 593 |
+
def test_update_weight(self):
|
| 594 |
+
n = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 595 |
+
result = builder_update_node([n], n.node_id, weight=0.3)
|
| 596 |
+
self.assertAlmostEqual(result[0].weight, 0.3)
|
| 597 |
+
|
| 598 |
+
def test_update_nonexistent_no_change(self):
|
| 599 |
+
n = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 600 |
+
result = builder_update_node([n], 'nope', text='changed')
|
| 601 |
+
self.assertEqual(result[0].text, 'tex')
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class TestSerializeBuilderTree(unittest.TestCase):
|
| 605 |
+
|
| 606 |
+
def test_base_only(self):
|
| 607 |
+
out = builder_tree_to_prompt('base subject', [])
|
| 608 |
+
self.assertEqual(out.strip(), 'base subject')
|
| 609 |
+
|
| 610 |
+
def test_single_node(self):
|
| 611 |
+
n = BuilderNode.make('AND_SALT', 'texture', 0.8)
|
| 612 |
+
out = builder_tree_to_prompt('base', [n])
|
| 613 |
+
self.assertIn('AND_SALT[5]', out)
|
| 614 |
+
self.assertIn('texture', out)
|
| 615 |
+
self.assertIn(':0.8', out)
|
| 616 |
+
|
| 617 |
+
def test_nested_node_renders_brackets(self):
|
| 618 |
+
child = BuilderNode.make('AND_TOPK', 'highlights', 0.4)
|
| 619 |
+
parent_node = BuilderNode.make('AND_SALT', 'texture', 0.8)
|
| 620 |
+
parent_node = dataclasses_replace(parent_node, children=[child])
|
| 621 |
+
out = builder_tree_to_prompt('base', [parent_node])
|
| 622 |
+
self.assertIn('[', out)
|
| 623 |
+
self.assertIn(']', out)
|
| 624 |
+
self.assertIn('AND_TOPK', out)
|
| 625 |
+
|
| 626 |
+
def test_generated_prompt_parseable(self):
|
| 627 |
+
n1 = BuilderNode.make('AND_SALT', 'texture', 0.8)
|
| 628 |
+
n2 = BuilderNode.make('AND_PERP', 'style', 0.5)
|
| 629 |
+
out = builder_tree_to_prompt('base', [n1, n2])
|
| 630 |
+
p.parse_root(out) # must not raise
|
| 631 |
+
|
| 632 |
+
def test_empty_text_node_skipped(self):
|
| 633 |
+
n = BuilderNode.make('AND_SALT', '', 0.8)
|
| 634 |
+
out = builder_tree_to_prompt('base', [n])
|
| 635 |
+
self.assertNotIn('AND_SALT', out)
|
| 636 |
+
|
| 637 |
+
def test_affine_prefix_included(self):
|
| 638 |
+
n = BuilderNode.make('AND_PERP', 'mirror', 0.6, affine='SCALE[-1,1]')
|
| 639 |
+
out = builder_tree_to_prompt('base', [n])
|
| 640 |
+
self.assertIn('SCALE[-1,1]', out)
|
| 641 |
+
|
| 642 |
+
def test_two_top_level_children(self):
|
| 643 |
+
n1 = BuilderNode.make('AND_SALT', 'tex', 0.8)
|
| 644 |
+
n2 = BuilderNode.make('AND_ALIGN', 'shape', 0.7)
|
| 645 |
+
out = builder_tree_to_prompt('base', [n1, n2])
|
| 646 |
+
self.assertIn('AND_SALT', out)
|
| 647 |
+
self.assertIn('AND_ALIGN', out)
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# dataclasses.replace alias for Python <3.13
|
| 651 |
+
import dataclasses
|
| 652 |
+
dataclasses_replace = dataclasses.replace
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
if __name__ == '__main__':
|
| 656 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_sprint2_hotfix.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sprint 2 hotfix tests — covers all 4 issues patched after Sprint 2 review.
|
| 3 |
+
|
| 4 |
+
1. BUILDER_STRATEGIES name fix (smoke-test that the name resolves)
|
| 5 |
+
2. update_step_state helper + step-counter round-trip
|
| 6 |
+
3. get_extra_generation_params / unpack_processing_args step-window round-trip
|
| 7 |
+
4. apply_infotext step-window restore
|
| 8 |
+
+ 'custom' mode absent from two-mode choices
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import unittest
|
| 12 |
+
|
| 13 |
+
import lib_neutral_prompt.neutral_prompt_parser as p
|
| 14 |
+
from lib_neutral_prompt import global_state
|
| 15 |
+
from lib_neutral_prompt.matryoshka_utils import BUILDER_STRATEGIES
|
| 16 |
+
from lib_neutral_prompt.step_utils import StepWindow, strategy_is_active, normalize_progress
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Fix 1 — BUILDER_STRATEGIES is importable and non-empty
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class TestBuilderStrategiesName(unittest.TestCase):
|
| 24 |
+
|
| 25 |
+
def test_importable_and_non_empty(self):
|
| 26 |
+
self.assertIsInstance(BUILDER_STRATEGIES, list)
|
| 27 |
+
self.assertGreater(len(BUILDER_STRATEGIES), 0)
|
| 28 |
+
|
| 29 |
+
def test_contains_core_strategies(self):
|
| 30 |
+
self.assertIn('AND_SALT', BUILDER_STRATEGIES)
|
| 31 |
+
self.assertIn('AND_PERP', BUILDER_STRATEGIES)
|
| 32 |
+
self.assertIn('AND_TOPK', BUILDER_STRATEGIES)
|
| 33 |
+
self.assertIn('AND_ALIGN', BUILDER_STRATEGIES)
|
| 34 |
+
|
| 35 |
+
def test_no_underscore_variant_exported(self):
|
| 36 |
+
"""The UI used to import _BUILDER_STRATEGIES — make sure that's gone."""
|
| 37 |
+
import lib_neutral_prompt.matryoshka_utils as mu
|
| 38 |
+
self.assertFalse(
|
| 39 |
+
hasattr(mu, '_BUILDER_STRATEGIES'),
|
| 40 |
+
'_BUILDER_STRATEGIES should not be exported; use BUILDER_STRATEGIES',
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Fix 2 — update_step_state + step-window gating integration
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
class TestUpdateStepState(unittest.TestCase):
|
| 49 |
+
|
| 50 |
+
def setUp(self):
|
| 51 |
+
self._saved_step = global_state.current_step
|
| 52 |
+
self._saved_total = global_state.total_steps
|
| 53 |
+
|
| 54 |
+
def tearDown(self):
|
| 55 |
+
global_state.current_step = self._saved_step
|
| 56 |
+
global_state.total_steps = self._saved_total
|
| 57 |
+
|
| 58 |
+
def test_basic_update(self):
|
| 59 |
+
global_state.update_step_state(5, 20)
|
| 60 |
+
self.assertEqual(global_state.current_step, 5)
|
| 61 |
+
self.assertEqual(global_state.total_steps, 20)
|
| 62 |
+
|
| 63 |
+
def test_float_inputs_converted(self):
|
| 64 |
+
global_state.update_step_state(3.9, 19.1)
|
| 65 |
+
self.assertEqual(global_state.current_step, 3)
|
| 66 |
+
self.assertEqual(global_state.total_steps, 19)
|
| 67 |
+
|
| 68 |
+
def test_negative_step_clamped_to_zero(self):
|
| 69 |
+
global_state.update_step_state(-1, 10)
|
| 70 |
+
self.assertEqual(global_state.current_step, 0)
|
| 71 |
+
|
| 72 |
+
def test_zero_total_clamped_to_one(self):
|
| 73 |
+
global_state.update_step_state(0, 0)
|
| 74 |
+
self.assertEqual(global_state.total_steps, 1)
|
| 75 |
+
|
| 76 |
+
def test_bad_types_do_not_crash(self):
|
| 77 |
+
global_state.update_step_state('bad', None)
|
| 78 |
+
# Should stay at some valid state
|
| 79 |
+
self.assertGreaterEqual(global_state.current_step, 0)
|
| 80 |
+
self.assertGreaterEqual(global_state.total_steps, 1)
|
| 81 |
+
|
| 82 |
+
def test_first_step_progress_is_zero(self):
|
| 83 |
+
global_state.update_step_state(0, 20)
|
| 84 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 85 |
+
self.assertAlmostEqual(p, 0.0)
|
| 86 |
+
|
| 87 |
+
def test_last_step_progress_is_one(self):
|
| 88 |
+
global_state.update_step_state(19, 20)
|
| 89 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 90 |
+
self.assertAlmostEqual(p, 1.0)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestStepWindowGatingIntegration(unittest.TestCase):
|
| 94 |
+
"""
|
| 95 |
+
Simulate the hijack path: update_step_state → normalize_progress
|
| 96 |
+
→ strategy_is_active with a global window.
|
| 97 |
+
This is the scenario described in the fix-2 specification.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def setUp(self):
|
| 101 |
+
self._saved_step = global_state.current_step
|
| 102 |
+
self._saved_total = global_state.total_steps
|
| 103 |
+
|
| 104 |
+
def tearDown(self):
|
| 105 |
+
global_state.current_step = self._saved_step
|
| 106 |
+
global_state.total_steps = self._saved_total
|
| 107 |
+
|
| 108 |
+
def _set_step(self, step, total):
|
| 109 |
+
global_state.update_step_state(step, total)
|
| 110 |
+
|
| 111 |
+
def test_and_salt_gated_out_in_early_window(self):
|
| 112 |
+
"""AND_SALT with window 0.6–1.0 should be skipped at step 0/9."""
|
| 113 |
+
self._set_step(0, 10)
|
| 114 |
+
win = StepWindow(0.6, 1.0)
|
| 115 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 116 |
+
active = strategy_is_active('SALIENCE_MASK', progress, global_window=win)
|
| 117 |
+
self.assertFalse(active, f'Expected inactive at progress={progress:.3f}')
|
| 118 |
+
|
| 119 |
+
def test_and_salt_active_in_late_window(self):
|
| 120 |
+
"""AND_SALT with window 0.6–1.0 should be active at step 8/10."""
|
| 121 |
+
self._set_step(8, 10)
|
| 122 |
+
win = StepWindow(0.6, 1.0)
|
| 123 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 124 |
+
active = strategy_is_active('SALIENCE_MASK', progress, global_window=win)
|
| 125 |
+
self.assertTrue(active, f'Expected active at progress={progress:.3f}')
|
| 126 |
+
|
| 127 |
+
def test_and_perp_active_early_with_defaults(self):
|
| 128 |
+
"""AND_PERP defaults: 0.0–0.5 → step 2/10 should be active."""
|
| 129 |
+
self._set_step(2, 10)
|
| 130 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 131 |
+
active = strategy_is_active('PERPENDICULAR', progress, use_defaults=True)
|
| 132 |
+
self.assertTrue(active, f'Expected active at progress={progress:.3f}')
|
| 133 |
+
|
| 134 |
+
def test_and_perp_gated_out_late_with_defaults(self):
|
| 135 |
+
"""AND_PERP defaults: 0.0–0.5 → step 8/10 should be inactive."""
|
| 136 |
+
self._set_step(8, 10)
|
| 137 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 138 |
+
active = strategy_is_active('PERPENDICULAR', progress, use_defaults=True)
|
| 139 |
+
self.assertFalse(active, f'Expected inactive at progress={progress:.3f}')
|
| 140 |
+
|
| 141 |
+
def test_no_window_always_active_regardless_of_step(self):
|
| 142 |
+
for step in (0, 5, 9):
|
| 143 |
+
with self.subTest(step=step):
|
| 144 |
+
self._set_step(step, 10)
|
| 145 |
+
progress = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 146 |
+
self.assertTrue(strategy_is_active('SALIENCE_MASK', progress))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# Fix 3 — step-window round-trip through unpack_processing_args + get_extra
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
def _unpack_step_window(cfg_rescale=0.0, protection_mode='auto', strict_threshold=0.1,
|
| 154 |
+
step_window_enabled=False, step_window_mode='global',
|
| 155 |
+
step_window_start=0.0, step_window_end=1.0):
|
| 156 |
+
"""
|
| 157 |
+
Pure-logic equivalent of AccordionInterface.unpack_processing_args
|
| 158 |
+
without any gradio dependency — mirrors the real implementation exactly.
|
| 159 |
+
"""
|
| 160 |
+
enabled = bool(step_window_enabled)
|
| 161 |
+
mode = str(step_window_mode)
|
| 162 |
+
start = max(0.0, min(1.0, float(step_window_start)))
|
| 163 |
+
end = max(0.0, min(1.0, float(step_window_end)))
|
| 164 |
+
if start > end:
|
| 165 |
+
start, end = 0.0, 1.0
|
| 166 |
+
|
| 167 |
+
global_state.step_window_enabled = enabled
|
| 168 |
+
if enabled:
|
| 169 |
+
if mode == 'global':
|
| 170 |
+
global_state.step_window_global = StepWindow(start, end)
|
| 171 |
+
global_state.step_window_use_defaults = False
|
| 172 |
+
elif mode == 'per-strategy defaults':
|
| 173 |
+
global_state.step_window_global = None
|
| 174 |
+
global_state.step_window_use_defaults = True
|
| 175 |
+
else:
|
| 176 |
+
global_state.step_window_global = None
|
| 177 |
+
global_state.step_window_use_defaults = False
|
| 178 |
+
else:
|
| 179 |
+
global_state.step_window_global = None
|
| 180 |
+
global_state.step_window_use_defaults = False
|
| 181 |
+
|
| 182 |
+
return {'cfg_rescale': cfg_rescale, 'protection_mode': protection_mode,
|
| 183 |
+
'strict_threshold': strict_threshold,
|
| 184 |
+
'step_window_enabled': enabled, 'step_window_mode': mode,
|
| 185 |
+
'step_window_start': start, 'step_window_end': end}
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _get_extra_params(args):
|
| 189 |
+
"""Pure-logic equivalent of AccordionInterface.get_extra_generation_params."""
|
| 190 |
+
p = {'CFG Rescale phi': args['cfg_rescale'],
|
| 191 |
+
'NP Protection Mode': args['protection_mode'],
|
| 192 |
+
'NP Strict Threshold': args['strict_threshold']}
|
| 193 |
+
if args.get('step_window_enabled'):
|
| 194 |
+
p['NP Step Window Enabled'] = True
|
| 195 |
+
p['NP Step Window Mode'] = args.get('step_window_mode', 'global')
|
| 196 |
+
p['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
|
| 197 |
+
p['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
|
| 198 |
+
return p
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class TestStepWindowProcessingRoundTrip(unittest.TestCase):
|
| 202 |
+
|
| 203 |
+
def setUp(self):
|
| 204 |
+
self._saved_enabled = global_state.step_window_enabled
|
| 205 |
+
self._saved_global = global_state.step_window_global
|
| 206 |
+
self._saved_use_defaults = global_state.step_window_use_defaults
|
| 207 |
+
|
| 208 |
+
def tearDown(self):
|
| 209 |
+
global_state.step_window_enabled = self._saved_enabled
|
| 210 |
+
global_state.step_window_global = self._saved_global
|
| 211 |
+
global_state.step_window_use_defaults = self._saved_use_defaults
|
| 212 |
+
|
| 213 |
+
def test_global_window_written_to_state(self):
|
| 214 |
+
args = _unpack_step_window(
|
| 215 |
+
cfg_rescale=0.0,
|
| 216 |
+
step_window_enabled=True,
|
| 217 |
+
step_window_mode='global',
|
| 218 |
+
step_window_start=0.3,
|
| 219 |
+
step_window_end=0.8,
|
| 220 |
+
)
|
| 221 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 222 |
+
self.assertIsNotNone(global_state.step_window_global)
|
| 223 |
+
self.assertAlmostEqual(global_state.step_window_global.start, 0.3, places=4)
|
| 224 |
+
self.assertAlmostEqual(global_state.step_window_global.end, 0.8, places=4)
|
| 225 |
+
|
| 226 |
+
def test_defaults_mode_sets_use_defaults(self):
|
| 227 |
+
_unpack_step_window(
|
| 228 |
+
cfg_rescale=0.0,
|
| 229 |
+
step_window_enabled=True,
|
| 230 |
+
step_window_mode='per-strategy defaults',
|
| 231 |
+
)
|
| 232 |
+
self.assertTrue(global_state.step_window_use_defaults)
|
| 233 |
+
self.assertIsNone(global_state.step_window_global)
|
| 234 |
+
|
| 235 |
+
def test_disabled_clears_state(self):
|
| 236 |
+
global_state.step_window_enabled = True
|
| 237 |
+
global_state.step_window_global = StepWindow(0.3, 0.8)
|
| 238 |
+
_unpack_step_window(cfg_rescale=0.0, step_window_enabled=False)
|
| 239 |
+
self.assertFalse(global_state.step_window_enabled)
|
| 240 |
+
self.assertIsNone(global_state.step_window_global)
|
| 241 |
+
|
| 242 |
+
def test_invalid_range_falls_back_to_full(self):
|
| 243 |
+
args = _unpack_step_window(
|
| 244 |
+
cfg_rescale=0.0,
|
| 245 |
+
step_window_enabled=True,
|
| 246 |
+
step_window_mode='global',
|
| 247 |
+
step_window_start=0.9,
|
| 248 |
+
step_window_end=0.1, # start > end → fallback
|
| 249 |
+
)
|
| 250 |
+
self.assertAlmostEqual(args['step_window_start'], 0.0)
|
| 251 |
+
self.assertAlmostEqual(args['step_window_end'], 1.0)
|
| 252 |
+
|
| 253 |
+
def test_extra_params_include_step_window_when_enabled(self):
|
| 254 |
+
args = {
|
| 255 |
+
'cfg_rescale': 0.0,
|
| 256 |
+
'protection_mode': 'auto',
|
| 257 |
+
'strict_threshold': 0.1,
|
| 258 |
+
'step_window_enabled': True,
|
| 259 |
+
'step_window_mode': 'global',
|
| 260 |
+
'step_window_start': 0.3,
|
| 261 |
+
'step_window_end': 0.8,
|
| 262 |
+
}
|
| 263 |
+
params = _get_extra_params(args)
|
| 264 |
+
self.assertIn('NP Step Window Enabled', params)
|
| 265 |
+
self.assertIn('NP Step Window Mode', params)
|
| 266 |
+
self.assertIn('NP Step Window Start', params)
|
| 267 |
+
self.assertIn('NP Step Window End', params)
|
| 268 |
+
self.assertAlmostEqual(params['NP Step Window Start'], 0.3, places=4)
|
| 269 |
+
|
| 270 |
+
def test_extra_params_omit_step_window_when_disabled(self):
|
| 271 |
+
args = {
|
| 272 |
+
'cfg_rescale': 0.0,
|
| 273 |
+
'protection_mode': 'auto',
|
| 274 |
+
'strict_threshold': 0.1,
|
| 275 |
+
'step_window_enabled': False,
|
| 276 |
+
}
|
| 277 |
+
params = _get_extra_params(args)
|
| 278 |
+
self.assertNotIn('NP Step Window Enabled', params)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ---------------------------------------------------------------------------
|
| 282 |
+
# Fix 3 — apply_infotext step-window restore
|
| 283 |
+
# ---------------------------------------------------------------------------
|
| 284 |
+
|
| 285 |
+
class TestApplyInfotextStepWindow(unittest.TestCase):
|
| 286 |
+
|
| 287 |
+
def setUp(self):
|
| 288 |
+
self._saved_enabled = global_state.step_window_enabled
|
| 289 |
+
self._saved_global = global_state.step_window_global
|
| 290 |
+
self._saved_use_defaults = global_state.step_window_use_defaults
|
| 291 |
+
|
| 292 |
+
def tearDown(self):
|
| 293 |
+
global_state.step_window_enabled = self._saved_enabled
|
| 294 |
+
global_state.step_window_global = self._saved_global
|
| 295 |
+
global_state.step_window_use_defaults = self._saved_use_defaults
|
| 296 |
+
|
| 297 |
+
def test_global_window_restored(self):
|
| 298 |
+
global_state.apply_infotext({
|
| 299 |
+
'NP Step Window Enabled': 'True',
|
| 300 |
+
'NP Step Window Mode': 'global',
|
| 301 |
+
'NP Step Window Start': '0.25',
|
| 302 |
+
'NP Step Window End': '0.75',
|
| 303 |
+
})
|
| 304 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 305 |
+
self.assertIsNotNone(global_state.step_window_global)
|
| 306 |
+
self.assertAlmostEqual(global_state.step_window_global.start, 0.25, places=4)
|
| 307 |
+
self.assertAlmostEqual(global_state.step_window_global.end, 0.75, places=4)
|
| 308 |
+
|
| 309 |
+
def test_defaults_mode_restored(self):
|
| 310 |
+
global_state.apply_infotext({
|
| 311 |
+
'NP Step Window Enabled': 'True',
|
| 312 |
+
'NP Step Window Mode': 'per-strategy defaults',
|
| 313 |
+
})
|
| 314 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 315 |
+
self.assertTrue(global_state.step_window_use_defaults)
|
| 316 |
+
self.assertIsNone(global_state.step_window_global)
|
| 317 |
+
|
| 318 |
+
def test_disabled_not_restored_when_key_absent(self):
|
| 319 |
+
global_state.step_window_enabled = True
|
| 320 |
+
global_state.apply_infotext({}) # no step-window keys
|
| 321 |
+
# Should be unchanged
|
| 322 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 323 |
+
|
| 324 |
+
def test_false_string_disables(self):
|
| 325 |
+
global_state.step_window_enabled = True
|
| 326 |
+
global_state.apply_infotext({'NP Step Window Enabled': 'False'})
|
| 327 |
+
self.assertFalse(global_state.step_window_enabled)
|
| 328 |
+
|
| 329 |
+
def test_invalid_range_falls_back_to_full(self):
|
| 330 |
+
global_state.apply_infotext({
|
| 331 |
+
'NP Step Window Enabled': 'True',
|
| 332 |
+
'NP Step Window Mode': 'global',
|
| 333 |
+
'NP Step Window Start': '0.9',
|
| 334 |
+
'NP Step Window End': '0.1', # inverted
|
| 335 |
+
})
|
| 336 |
+
win = global_state.step_window_global
|
| 337 |
+
self.assertAlmostEqual(win.start, 0.0, places=4)
|
| 338 |
+
self.assertAlmostEqual(win.end, 1.0, places=4)
|
| 339 |
+
|
| 340 |
+
def test_old_infotext_without_step_window_keys_does_not_crash(self):
|
| 341 |
+
global_state.apply_infotext({
|
| 342 |
+
'CFG Rescale phi': '0.5',
|
| 343 |
+
'NP Protection Mode': 'auto',
|
| 344 |
+
})
|
| 345 |
+
# No crash, step_window unchanged
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# ---------------------------------------------------------------------------
|
| 349 |
+
# Fix 4 — 'custom' absent from step_window_mode UI choices
|
| 350 |
+
# ---------------------------------------------------------------------------
|
| 351 |
+
|
| 352 |
+
class TestCustomModeRemovedFromUI(unittest.TestCase):
|
| 353 |
+
|
| 354 |
+
def test_step_window_update_handler_accepts_two_modes(self):
|
| 355 |
+
"""
|
| 356 |
+
The step-window update logic must handle exactly the two released modes
|
| 357 |
+
without any 'custom' branch.
|
| 358 |
+
"""
|
| 359 |
+
saved = {
|
| 360 |
+
'enabled': global_state.step_window_enabled,
|
| 361 |
+
'global': global_state.step_window_global,
|
| 362 |
+
'use_defaults': global_state.step_window_use_defaults,
|
| 363 |
+
}
|
| 364 |
+
try:
|
| 365 |
+
for mode in ('global', 'per-strategy defaults'):
|
| 366 |
+
with self.subTest(mode=mode):
|
| 367 |
+
global_state.step_window_enabled = True
|
| 368 |
+
if mode == 'global':
|
| 369 |
+
global_state.step_window_global = StepWindow(0.3, 0.8)
|
| 370 |
+
global_state.step_window_use_defaults = False
|
| 371 |
+
else:
|
| 372 |
+
global_state.step_window_global = None
|
| 373 |
+
global_state.step_window_use_defaults = True
|
| 374 |
+
# Verify state is consistent for both valid modes
|
| 375 |
+
if mode == 'global':
|
| 376 |
+
self.assertIsNotNone(global_state.step_window_global)
|
| 377 |
+
else:
|
| 378 |
+
self.assertTrue(global_state.step_window_use_defaults)
|
| 379 |
+
finally:
|
| 380 |
+
global_state.step_window_enabled = saved['enabled']
|
| 381 |
+
global_state.step_window_global = saved['global']
|
| 382 |
+
global_state.step_window_use_defaults = saved['use_defaults']
|
| 383 |
+
|
| 384 |
+
def test_step_utils_still_supports_custom_internally(self):
|
| 385 |
+
"""
|
| 386 |
+
custom is removed from UI but step_utils.strategy_is_active must
|
| 387 |
+
still accept None per_strategy_windows gracefully (for API use).
|
| 388 |
+
"""
|
| 389 |
+
from lib_neutral_prompt.step_utils import strategy_is_active
|
| 390 |
+
# custom = no global_window, no per_strategy_windows, no defaults
|
| 391 |
+
active = strategy_is_active('SALIENCE_MASK', 0.5,
|
| 392 |
+
global_window=None,
|
| 393 |
+
per_strategy_windows=None,
|
| 394 |
+
use_defaults=False)
|
| 395 |
+
self.assertTrue(active, 'Without any window, strategy must be always active')
|
| 396 |
+
|
| 397 |
+
def test_unpack_does_not_expose_custom_mode_to_state(self):
|
| 398 |
+
"""
|
| 399 |
+
If someone passes 'custom' through the API, unpack must fall back
|
| 400 |
+
to a safe state (no window, no defaults).
|
| 401 |
+
"""
|
| 402 |
+
saved_global = global_state.step_window_global
|
| 403 |
+
saved_defaults = global_state.step_window_use_defaults
|
| 404 |
+
try:
|
| 405 |
+
_unpack_step_window(
|
| 406 |
+
cfg_rescale=0.0,
|
| 407 |
+
step_window_enabled=True,
|
| 408 |
+
step_window_mode='custom', # not in UI, but API-safe
|
| 409 |
+
)
|
| 410 |
+
# In the 'else' branch: global=None, use_defaults=False
|
| 411 |
+
self.assertIsNone(global_state.step_window_global)
|
| 412 |
+
self.assertFalse(global_state.step_window_use_defaults)
|
| 413 |
+
finally:
|
| 414 |
+
global_state.step_window_global = saved_global
|
| 415 |
+
global_state.step_window_use_defaults = saved_defaults
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
if __name__ == '__main__':
|
| 419 |
+
unittest.main()
|
neutral_prompt_patcheds/test/perp_parser/test_stabilization.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Stabilization test suite — Sprint 2 hotfix + per-strategy custom windows.
|
| 3 |
+
|
| 4 |
+
Sections
|
| 5 |
+
--------
|
| 6 |
+
1. UI smoke-test – verifies ui.py imports and AccordionInterface constructs
|
| 7 |
+
without NameError / bad Gradio kwargs
|
| 8 |
+
2. Step-window integration – update_step_state → progress → strategy activation
|
| 9 |
+
3. Generation params round-trip – serialize → apply_infotext → state restored
|
| 10 |
+
4. Per-strategy custom windows – serialize/deserialize/build helpers
|
| 11 |
+
5. Per-strategy round-trip – custom mode infotext restore
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
import types
|
| 16 |
+
import unittest
|
| 17 |
+
import importlib
|
| 18 |
+
|
| 19 |
+
from lib_neutral_prompt import global_state
|
| 20 |
+
from lib_neutral_prompt.step_utils import (
|
| 21 |
+
StepWindow,
|
| 22 |
+
normalize_progress,
|
| 23 |
+
strategy_is_active,
|
| 24 |
+
STRATEGY_UI_KEYS,
|
| 25 |
+
STRATEGY_DEFAULTS,
|
| 26 |
+
normalize_step_range,
|
| 27 |
+
build_per_strategy_windows,
|
| 28 |
+
serialize_per_strategy_windows,
|
| 29 |
+
deserialize_per_strategy_windows,
|
| 30 |
+
default_custom_windows,
|
| 31 |
+
ui_key_to_family,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ===========================================================================
|
| 36 |
+
# 1. UI smoke-test
|
| 37 |
+
# ===========================================================================
|
| 38 |
+
|
| 39 |
+
def _make_fake_gradio():
|
| 40 |
+
"""
|
| 41 |
+
Minimal gradio mock that records kwargs without raising on unknown ones.
|
| 42 |
+
This lets us detect *missing constants* (NameError) and *wrong import paths*
|
| 43 |
+
while ignoring version-specific kwargs differences.
|
| 44 |
+
"""
|
| 45 |
+
gr = types.ModuleType('gradio')
|
| 46 |
+
|
| 47 |
+
class _Component:
|
| 48 |
+
def __init__(self, *a, **kw):
|
| 49 |
+
self._kwargs = kw
|
| 50 |
+
self.value = kw.get('value')
|
| 51 |
+
def render(self): return self
|
| 52 |
+
def change(self, **kw): return self
|
| 53 |
+
def click(self, **kw): return self
|
| 54 |
+
@staticmethod
|
| 55 |
+
def update(**kw): return kw
|
| 56 |
+
|
| 57 |
+
for name in ('Checkbox', 'Radio', 'Slider', 'Textbox', 'Button',
|
| 58 |
+
'Dropdown', 'Accordion', 'Row', 'Column', 'Tab', 'Tabs',
|
| 59 |
+
'HTML', 'Markdown', 'Number', 'Code'):
|
| 60 |
+
setattr(gr, name, _Component)
|
| 61 |
+
|
| 62 |
+
gr.update = lambda **kw: kw
|
| 63 |
+
gr.__version__ = '3.41.0'
|
| 64 |
+
return gr
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TestUISmoke(unittest.TestCase):
|
| 68 |
+
"""
|
| 69 |
+
Import lib_neutral_prompt.ui under mocked gradio + modules and verify that
|
| 70 |
+
AccordionInterface can be constructed without errors.
|
| 71 |
+
|
| 72 |
+
This catches:
|
| 73 |
+
* NameError (undefined constants like _BUILDER_STRATEGIES)
|
| 74 |
+
* ImportError (bad relative imports)
|
| 75 |
+
* AttributeError from wrong Gradio kwarg names such as 'tooltip'
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
@classmethod
|
| 79 |
+
def setUpClass(cls):
|
| 80 |
+
# If ui.py was already imported in this process (e.g. by another test
|
| 81 |
+
# module), the gradio mock cannot be reliably injected. Mark as skipped
|
| 82 |
+
# rather than erroring out.
|
| 83 |
+
if 'lib_neutral_prompt.ui' in sys.modules:
|
| 84 |
+
cls._ui = sys.modules['lib_neutral_prompt.ui']
|
| 85 |
+
cls._skipped = True
|
| 86 |
+
return
|
| 87 |
+
cls._skipped = False
|
| 88 |
+
|
| 89 |
+
gr = _make_fake_gradio()
|
| 90 |
+
sys.modules['gradio'] = gr
|
| 91 |
+
|
| 92 |
+
fake_cb = types.SimpleNamespace(
|
| 93 |
+
on_ui_settings=lambda *a, **k: None,
|
| 94 |
+
on_ui_tabs=lambda *a, **k: None,
|
| 95 |
+
on_after_component=lambda *a, **k: None,
|
| 96 |
+
)
|
| 97 |
+
fake_shared = types.SimpleNamespace(
|
| 98 |
+
opts=types.SimpleNamespace(
|
| 99 |
+
add_option=lambda *a, **k: None,
|
| 100 |
+
onchange=lambda *a, **k: None,
|
| 101 |
+
data={},
|
| 102 |
+
),
|
| 103 |
+
state=types.SimpleNamespace(sampling_step=0, sampling_steps=1),
|
| 104 |
+
)
|
| 105 |
+
fake_modules = types.ModuleType('modules')
|
| 106 |
+
fake_modules.script_callbacks = fake_cb
|
| 107 |
+
fake_modules.shared = fake_shared
|
| 108 |
+
sys.modules.setdefault('modules', fake_modules)
|
| 109 |
+
sys.modules.setdefault('modules.script_callbacks', fake_cb)
|
| 110 |
+
sys.modules.setdefault('modules.shared', fake_shared)
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
cls._ui = importlib.import_module('lib_neutral_prompt.ui')
|
| 114 |
+
except Exception as exc:
|
| 115 |
+
raise unittest.SkipTest(f'lib_neutral_prompt.ui failed to import: {exc}')
|
| 116 |
+
|
| 117 |
+
def setUp(self):
|
| 118 |
+
if self._skipped:
|
| 119 |
+
self.skipTest('lib_neutral_prompt.ui already imported; smoke-test skipped')
|
| 120 |
+
|
| 121 |
+
def test_module_imported_without_error(self):
|
| 122 |
+
self.assertIsNotNone(self._ui)
|
| 123 |
+
|
| 124 |
+
def test_accordion_interface_constructs(self):
|
| 125 |
+
obj = self._ui.AccordionInterface(get_elem_id=lambda s: f'test_{s}')
|
| 126 |
+
self.assertIsNotNone(obj)
|
| 127 |
+
|
| 128 |
+
def test_builder_strategies_constant_reachable(self):
|
| 129 |
+
from lib_neutral_prompt.matryoshka_utils import BUILDER_STRATEGIES
|
| 130 |
+
self.assertIsInstance(BUILDER_STRATEGIES, list)
|
| 131 |
+
self.assertGreater(len(BUILDER_STRATEGIES), 0)
|
| 132 |
+
|
| 133 |
+
def test_step_ps_sliders_created_for_all_strategies(self):
|
| 134 |
+
obj = self._ui.AccordionInterface(get_elem_id=lambda s: f'test2_{s}')
|
| 135 |
+
for ui_key in STRATEGY_UI_KEYS:
|
| 136 |
+
self.assertIn(ui_key, obj.step_ps_sliders,
|
| 137 |
+
f'Missing per-strategy slider pair for {ui_key}')
|
| 138 |
+
pair = obj.step_ps_sliders[ui_key]
|
| 139 |
+
self.assertEqual(len(pair), 2)
|
| 140 |
+
|
| 141 |
+
def test_no_tooltip_kwarg_on_dropdown(self):
|
| 142 |
+
"""Regression: tooltip= must have been replaced by info= on all Dropdowns."""
|
| 143 |
+
import ast
|
| 144 |
+
import pathlib
|
| 145 |
+
src = pathlib.Path(
|
| 146 |
+
__file__).parent.parent.parent / 'lib_neutral_prompt' / 'ui.py'
|
| 147 |
+
tree = ast.parse(src.read_text())
|
| 148 |
+
violations = []
|
| 149 |
+
for node in ast.walk(tree):
|
| 150 |
+
if isinstance(node, ast.Call):
|
| 151 |
+
for kw in node.keywords:
|
| 152 |
+
if kw.arg == 'tooltip':
|
| 153 |
+
violations.append(f'line {node.lineno}')
|
| 154 |
+
self.assertFalse(violations,
|
| 155 |
+
f'Found tooltip= kwarg(s) in ui.py: {violations}')
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ===========================================================================
|
| 159 |
+
# 2. Step-window integration
|
| 160 |
+
# ===========================================================================
|
| 161 |
+
|
| 162 |
+
class TestStepWindowIntegration(unittest.TestCase):
|
| 163 |
+
"""
|
| 164 |
+
Verify that update_step_state → normalize_progress → strategy_is_active
|
| 165 |
+
gives the expected gating behaviour.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
def setUp(self):
|
| 169 |
+
self._step = global_state.current_step
|
| 170 |
+
self._total = global_state.total_steps
|
| 171 |
+
|
| 172 |
+
def tearDown(self):
|
| 173 |
+
global_state.current_step = self._step
|
| 174 |
+
global_state.total_steps = self._total
|
| 175 |
+
|
| 176 |
+
# ---- global window ----
|
| 177 |
+
|
| 178 |
+
def test_late_strategy_inactive_at_step_0(self):
|
| 179 |
+
global_state.update_step_state(0, 10)
|
| 180 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 181 |
+
# SALIENCE_MASK default: 0.4–1.0 → inactive at 0.0
|
| 182 |
+
self.assertFalse(strategy_is_active('SALIENCE_MASK', p, use_defaults=True))
|
| 183 |
+
|
| 184 |
+
def test_late_strategy_active_at_step_8(self):
|
| 185 |
+
global_state.update_step_state(8, 10)
|
| 186 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 187 |
+
self.assertTrue(strategy_is_active('SALIENCE_MASK', p, use_defaults=True))
|
| 188 |
+
|
| 189 |
+
def test_early_strategy_active_at_step_0(self):
|
| 190 |
+
global_state.update_step_state(0, 10)
|
| 191 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 192 |
+
# PERPENDICULAR default: 0.0–0.5 → active at 0.0
|
| 193 |
+
self.assertTrue(strategy_is_active('PERPENDICULAR', p, use_defaults=True))
|
| 194 |
+
|
| 195 |
+
def test_early_strategy_inactive_at_step_8(self):
|
| 196 |
+
global_state.update_step_state(8, 10)
|
| 197 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 198 |
+
self.assertFalse(strategy_is_active('PERPENDICULAR', p, use_defaults=True))
|
| 199 |
+
|
| 200 |
+
def test_global_window_overrides_defaults(self):
|
| 201 |
+
"""A global window [0.6, 1.0] gates AND_PERP out even at step 0."""
|
| 202 |
+
global_state.update_step_state(0, 10)
|
| 203 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 204 |
+
win = StepWindow(0.6, 1.0)
|
| 205 |
+
self.assertFalse(strategy_is_active('PERPENDICULAR', p,
|
| 206 |
+
global_window=win, use_defaults=True))
|
| 207 |
+
|
| 208 |
+
def test_no_window_always_active(self):
|
| 209 |
+
for step in (0, 5, 9):
|
| 210 |
+
global_state.update_step_state(step, 10)
|
| 211 |
+
p = normalize_progress(global_state.current_step, global_state.total_steps)
|
| 212 |
+
with self.subTest(step=step):
|
| 213 |
+
self.assertTrue(strategy_is_active('SALIENCE_MASK', p))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ===========================================================================
|
| 217 |
+
# 3. Generation params round-trip (global mode)
|
| 218 |
+
# ===========================================================================
|
| 219 |
+
|
| 220 |
+
def _unpack(cfg_rescale=0.0, protection_mode='auto', strict_threshold=0.1,
|
| 221 |
+
step_window_enabled=False, step_window_mode='global',
|
| 222 |
+
step_window_start=0.0, step_window_end=1.0, **ps_kwargs):
|
| 223 |
+
"""Mirrors AccordionInterface.unpack_processing_args logic — no gradio needed."""
|
| 224 |
+
from lib_neutral_prompt.step_utils import (
|
| 225 |
+
StepWindow, build_per_strategy_windows, STRATEGY_UI_KEYS,
|
| 226 |
+
)
|
| 227 |
+
enabled = bool(step_window_enabled)
|
| 228 |
+
mode = str(step_window_mode)
|
| 229 |
+
start = max(0.0, min(1.0, float(step_window_start)))
|
| 230 |
+
end = max(0.0, min(1.0, float(step_window_end)))
|
| 231 |
+
if start > end:
|
| 232 |
+
start, end = 0.0, 1.0
|
| 233 |
+
|
| 234 |
+
custom_raw = {}
|
| 235 |
+
for uk in STRATEGY_UI_KEYS:
|
| 236 |
+
sk, ek = f'{uk}_start', f'{uk}_end'
|
| 237 |
+
if sk in ps_kwargs and ek in ps_kwargs:
|
| 238 |
+
custom_raw[uk] = (float(ps_kwargs[sk]), float(ps_kwargs[ek]))
|
| 239 |
+
|
| 240 |
+
global_state.step_window_enabled = enabled
|
| 241 |
+
if enabled:
|
| 242 |
+
if mode == 'global':
|
| 243 |
+
global_state.step_window_global = StepWindow(start, end)
|
| 244 |
+
global_state.step_window_use_defaults = False
|
| 245 |
+
global_state.step_window_per_strategy = None
|
| 246 |
+
elif mode == 'per-strategy defaults':
|
| 247 |
+
global_state.step_window_global = None
|
| 248 |
+
global_state.step_window_use_defaults = True
|
| 249 |
+
global_state.step_window_per_strategy = None
|
| 250 |
+
elif mode == 'per-strategy custom':
|
| 251 |
+
global_state.step_window_global = None
|
| 252 |
+
global_state.step_window_use_defaults = False
|
| 253 |
+
global_state.step_window_custom_raw = custom_raw
|
| 254 |
+
global_state.step_window_per_strategy = build_per_strategy_windows(custom_raw)
|
| 255 |
+
else:
|
| 256 |
+
global_state.step_window_global = None
|
| 257 |
+
global_state.step_window_use_defaults = False
|
| 258 |
+
else:
|
| 259 |
+
global_state.step_window_global = None
|
| 260 |
+
global_state.step_window_use_defaults = False
|
| 261 |
+
global_state.step_window_per_strategy = None
|
| 262 |
+
|
| 263 |
+
return {
|
| 264 |
+
'cfg_rescale': cfg_rescale, 'protection_mode': protection_mode,
|
| 265 |
+
'strict_threshold': strict_threshold,
|
| 266 |
+
'step_window_enabled': enabled, 'step_window_mode': mode,
|
| 267 |
+
'step_window_start': start, 'step_window_end': end,
|
| 268 |
+
'step_window_custom_raw': custom_raw,
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _extra(args):
|
| 273 |
+
"""Mirrors AccordionInterface.get_extra_generation_params — no gradio."""
|
| 274 |
+
p = {
|
| 275 |
+
'CFG Rescale phi': args['cfg_rescale'],
|
| 276 |
+
'NP Protection Mode': args['protection_mode'],
|
| 277 |
+
'NP Strict Threshold': args['strict_threshold'],
|
| 278 |
+
}
|
| 279 |
+
if args.get('step_window_enabled'):
|
| 280 |
+
p['NP Step Window Enabled'] = True
|
| 281 |
+
p['NP Step Window Mode'] = args.get('step_window_mode', 'global')
|
| 282 |
+
p['NP Step Window Start'] = round(float(args.get('step_window_start', 0.0)), 4)
|
| 283 |
+
p['NP Step Window End'] = round(float(args.get('step_window_end', 1.0)), 4)
|
| 284 |
+
if args.get('step_window_mode') == 'per-strategy custom':
|
| 285 |
+
raw = args.get('step_window_custom_raw') or {}
|
| 286 |
+
p['NP Step Window Custom'] = serialize_per_strategy_windows(raw)
|
| 287 |
+
return p
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class TestGenerationParamsRoundTrip(unittest.TestCase):
|
| 291 |
+
|
| 292 |
+
def setUp(self):
|
| 293 |
+
self._state = {
|
| 294 |
+
'enabled': global_state.step_window_enabled,
|
| 295 |
+
'global': global_state.step_window_global,
|
| 296 |
+
'use_defaults': global_state.step_window_use_defaults,
|
| 297 |
+
'per_strategy': global_state.step_window_per_strategy,
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
def tearDown(self):
|
| 301 |
+
global_state.step_window_enabled = self._state['enabled']
|
| 302 |
+
global_state.step_window_global = self._state['global']
|
| 303 |
+
global_state.step_window_use_defaults = self._state['use_defaults']
|
| 304 |
+
global_state.step_window_per_strategy = self._state['per_strategy']
|
| 305 |
+
|
| 306 |
+
# -- global mode round-trip --
|
| 307 |
+
|
| 308 |
+
def test_global_extra_params_present_when_enabled(self):
|
| 309 |
+
args = _unpack(step_window_enabled=True, step_window_mode='global',
|
| 310 |
+
step_window_start=0.25, step_window_end=0.80)
|
| 311 |
+
p = _extra(args)
|
| 312 |
+
self.assertTrue(p['NP Step Window Enabled'])
|
| 313 |
+
self.assertEqual(p['NP Step Window Mode'], 'global')
|
| 314 |
+
self.assertAlmostEqual(p['NP Step Window Start'], 0.25, places=4)
|
| 315 |
+
self.assertAlmostEqual(p['NP Step Window End'], 0.80, places=4)
|
| 316 |
+
|
| 317 |
+
def test_global_extra_params_absent_when_disabled(self):
|
| 318 |
+
args = _unpack(step_window_enabled=False)
|
| 319 |
+
p = _extra(args)
|
| 320 |
+
self.assertNotIn('NP Step Window Enabled', p)
|
| 321 |
+
|
| 322 |
+
def test_global_apply_infotext_restores_state(self):
|
| 323 |
+
global_state.apply_infotext({
|
| 324 |
+
'NP Step Window Enabled': 'True',
|
| 325 |
+
'NP Step Window Mode': 'global',
|
| 326 |
+
'NP Step Window Start': '0.3',
|
| 327 |
+
'NP Step Window End': '0.7',
|
| 328 |
+
})
|
| 329 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 330 |
+
self.assertIsNotNone(global_state.step_window_global)
|
| 331 |
+
self.assertAlmostEqual(global_state.step_window_global.start, 0.3, places=4)
|
| 332 |
+
self.assertAlmostEqual(global_state.step_window_global.end, 0.7, places=4)
|
| 333 |
+
|
| 334 |
+
def test_inverted_range_repaired_in_unpack(self):
|
| 335 |
+
args = _unpack(step_window_enabled=True, step_window_mode='global',
|
| 336 |
+
step_window_start=0.9, step_window_end=0.1)
|
| 337 |
+
self.assertAlmostEqual(args['step_window_start'], 0.0)
|
| 338 |
+
self.assertAlmostEqual(args['step_window_end'], 1.0)
|
| 339 |
+
|
| 340 |
+
def test_inverted_range_repaired_in_apply_infotext(self):
|
| 341 |
+
global_state.apply_infotext({
|
| 342 |
+
'NP Step Window Enabled': 'True',
|
| 343 |
+
'NP Step Window Mode': 'global',
|
| 344 |
+
'NP Step Window Start': '0.9',
|
| 345 |
+
'NP Step Window End': '0.1',
|
| 346 |
+
})
|
| 347 |
+
win = global_state.step_window_global
|
| 348 |
+
self.assertAlmostEqual(win.start, 0.0)
|
| 349 |
+
self.assertAlmostEqual(win.end, 1.0)
|
| 350 |
+
|
| 351 |
+
def test_old_infotext_without_step_window_keys_safe(self):
|
| 352 |
+
global_state.apply_infotext({'CFG Rescale phi': '0.5', 'NP Protection Mode': 'auto'})
|
| 353 |
+
# No exception — step-window state unchanged
|
| 354 |
+
|
| 355 |
+
def test_defaults_mode_apply_infotext(self):
|
| 356 |
+
global_state.apply_infotext({
|
| 357 |
+
'NP Step Window Enabled': 'True',
|
| 358 |
+
'NP Step Window Mode': 'per-strategy defaults',
|
| 359 |
+
})
|
| 360 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 361 |
+
self.assertTrue(global_state.step_window_use_defaults)
|
| 362 |
+
self.assertIsNone(global_state.step_window_global)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ===========================================================================
|
| 366 |
+
# 4. Per-strategy custom window helpers
|
| 367 |
+
# ===========================================================================
|
| 368 |
+
|
| 369 |
+
class TestPerStrategyHelpers(unittest.TestCase):
|
| 370 |
+
|
| 371 |
+
def test_ui_key_to_family_all_keys(self):
|
| 372 |
+
mapping = {
|
| 373 |
+
'AND_PERP': 'PERPENDICULAR',
|
| 374 |
+
'AND_SALT': 'SALIENCE_MASK',
|
| 375 |
+
'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',
|
| 376 |
+
'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',
|
| 377 |
+
'AND_TOPK': 'SEMANTIC_GUIDANCE',
|
| 378 |
+
'AND_ALIGN': 'ALIGNMENT_BLEND_CUSTOM',
|
| 379 |
+
'AND_MASK_ALIGN': 'ALIGNMENT_MASK_BLEND_CUSTOM',
|
| 380 |
+
}
|
| 381 |
+
for ui_key, expected_family in mapping.items():
|
| 382 |
+
with self.subTest(ui_key=ui_key):
|
| 383 |
+
self.assertEqual(ui_key_to_family(ui_key), expected_family)
|
| 384 |
+
|
| 385 |
+
def test_build_per_strategy_windows_types(self):
|
| 386 |
+
raw = {'AND_SALT': (0.4, 1.0), 'AND_PERP': (0.0, 0.5)}
|
| 387 |
+
result = build_per_strategy_windows(raw)
|
| 388 |
+
self.assertIsInstance(result['SALIENCE_MASK'], StepWindow)
|
| 389 |
+
self.assertAlmostEqual(result['SALIENCE_MASK'].start, 0.4)
|
| 390 |
+
self.assertAlmostEqual(result['PERPENDICULAR'].end, 0.5)
|
| 391 |
+
|
| 392 |
+
def test_build_per_strategy_windows_repairs_inverted(self):
|
| 393 |
+
raw = {'AND_SALT': (0.9, 0.1)} # inverted
|
| 394 |
+
result = build_per_strategy_windows(raw)
|
| 395 |
+
win = result['SALIENCE_MASK']
|
| 396 |
+
self.assertAlmostEqual(win.start, 0.0)
|
| 397 |
+
self.assertAlmostEqual(win.end, 1.0)
|
| 398 |
+
|
| 399 |
+
def test_unknown_ui_key_skipped(self):
|
| 400 |
+
raw = {'AND_MYSTERY': (0.2, 0.8)}
|
| 401 |
+
result = build_per_strategy_windows(raw)
|
| 402 |
+
self.assertEqual(len(result), 0)
|
| 403 |
+
|
| 404 |
+
def test_serialize_round_trip(self):
|
| 405 |
+
raw = {'AND_PERP': (0.0, 0.5), 'AND_SALT': (0.4, 1.0)}
|
| 406 |
+
serialized = serialize_per_strategy_windows(raw)
|
| 407 |
+
restored = deserialize_per_strategy_windows(serialized)
|
| 408 |
+
for ui_key in raw:
|
| 409 |
+
self.assertIn(ui_key, restored)
|
| 410 |
+
self.assertAlmostEqual(restored[ui_key][0], raw[ui_key][0], places=3)
|
| 411 |
+
self.assertAlmostEqual(restored[ui_key][1], raw[ui_key][1], places=3)
|
| 412 |
+
|
| 413 |
+
def test_deserialize_tolerates_garbage(self):
|
| 414 |
+
restored = deserialize_per_strategy_windows('bad:data,AND_SALT:0.40-1.00,junk')
|
| 415 |
+
self.assertIn('AND_SALT', restored)
|
| 416 |
+
self.assertAlmostEqual(restored['AND_SALT'][0], 0.4, places=2)
|
| 417 |
+
|
| 418 |
+
def test_deserialize_empty_string(self):
|
| 419 |
+
self.assertEqual(deserialize_per_strategy_windows(''), {})
|
| 420 |
+
|
| 421 |
+
def test_normalize_step_range_clamps(self):
|
| 422 |
+
s, e = normalize_step_range(-0.5, 1.5)
|
| 423 |
+
self.assertAlmostEqual(s, 0.0)
|
| 424 |
+
self.assertAlmostEqual(e, 1.0)
|
| 425 |
+
|
| 426 |
+
def test_normalize_step_range_swaps_inverted(self):
|
| 427 |
+
s, e = normalize_step_range(0.8, 0.2)
|
| 428 |
+
self.assertAlmostEqual(s, 0.0)
|
| 429 |
+
self.assertAlmostEqual(e, 1.0)
|
| 430 |
+
|
| 431 |
+
def test_default_custom_windows_covers_all_keys(self):
|
| 432 |
+
defaults = default_custom_windows()
|
| 433 |
+
for uk in STRATEGY_UI_KEYS:
|
| 434 |
+
self.assertIn(uk, defaults)
|
| 435 |
+
s, e = defaults[uk]
|
| 436 |
+
self.assertGreaterEqual(s, 0.0)
|
| 437 |
+
self.assertLessEqual(e, 1.0)
|
| 438 |
+
self.assertLessEqual(s, e)
|
| 439 |
+
|
| 440 |
+
def test_strategy_is_active_custom_window_overrides_defaults(self):
|
| 441 |
+
"""Custom window for AND_SALT: 0.0–0.3 should gate it out at step 8/10."""
|
| 442 |
+
raw = {'AND_SALT': (0.0, 0.3)}
|
| 443 |
+
per_strategy = build_per_strategy_windows(raw)
|
| 444 |
+
progress = normalize_progress(8, 10)
|
| 445 |
+
active = strategy_is_active('SALIENCE_MASK', progress,
|
| 446 |
+
per_strategy_windows=per_strategy,
|
| 447 |
+
use_defaults=True)
|
| 448 |
+
self.assertFalse(active)
|
| 449 |
+
|
| 450 |
+
def test_strategy_is_active_fallback_to_defaults_when_key_missing(self):
|
| 451 |
+
"""If AND_TOPK not in custom dict, fall back to STRATEGY_DEFAULTS."""
|
| 452 |
+
raw = {'AND_SALT': (0.0, 0.3)} # only SALT defined
|
| 453 |
+
per_strategy = build_per_strategy_windows(raw)
|
| 454 |
+
# SEMANTIC_GUIDANCE default: 0.4–1.0 → active at step 8/10
|
| 455 |
+
progress = normalize_progress(8, 10)
|
| 456 |
+
active = strategy_is_active('SEMANTIC_GUIDANCE', progress,
|
| 457 |
+
per_strategy_windows=per_strategy,
|
| 458 |
+
use_defaults=True)
|
| 459 |
+
self.assertTrue(active)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# ===========================================================================
|
| 463 |
+
# 5. Per-strategy custom round-trip via infotext
|
| 464 |
+
# ===========================================================================
|
| 465 |
+
|
| 466 |
+
class TestPerStrategyRoundTrip(unittest.TestCase):
|
| 467 |
+
|
| 468 |
+
def setUp(self):
|
| 469 |
+
self._state = {
|
| 470 |
+
'enabled': global_state.step_window_enabled,
|
| 471 |
+
'use_defaults': global_state.step_window_use_defaults,
|
| 472 |
+
'per_strategy': global_state.step_window_per_strategy,
|
| 473 |
+
'custom_raw': global_state.step_window_custom_raw,
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
def tearDown(self):
|
| 477 |
+
global_state.step_window_enabled = self._state['enabled']
|
| 478 |
+
global_state.step_window_use_defaults = self._state['use_defaults']
|
| 479 |
+
global_state.step_window_per_strategy = self._state['per_strategy']
|
| 480 |
+
global_state.step_window_custom_raw = self._state['custom_raw']
|
| 481 |
+
|
| 482 |
+
def test_custom_mode_extra_params_has_custom_key(self):
|
| 483 |
+
raw = {'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)}
|
| 484 |
+
args = _unpack(step_window_enabled=True, step_window_mode='per-strategy custom',
|
| 485 |
+
**{f'{k}_start': v[0] for k, v in raw.items()},
|
| 486 |
+
**{f'{k}_end': v[1] for k, v in raw.items()})
|
| 487 |
+
params = _extra(args)
|
| 488 |
+
self.assertIn('NP Step Window Custom', params)
|
| 489 |
+
self.assertIn('AND_PERP', params['NP Step Window Custom'])
|
| 490 |
+
self.assertIn('AND_SALT', params['NP Step Window Custom'])
|
| 491 |
+
|
| 492 |
+
def test_custom_mode_infotext_restore(self):
|
| 493 |
+
raw_str = serialize_per_strategy_windows(
|
| 494 |
+
{'AND_PERP': (0.0, 0.4), 'AND_SALT': (0.5, 1.0)}
|
| 495 |
+
)
|
| 496 |
+
global_state.apply_infotext({
|
| 497 |
+
'NP Step Window Enabled': 'True',
|
| 498 |
+
'NP Step Window Mode': 'per-strategy custom',
|
| 499 |
+
'NP Step Window Custom': raw_str,
|
| 500 |
+
})
|
| 501 |
+
self.assertTrue(global_state.step_window_enabled)
|
| 502 |
+
self.assertIsNotNone(global_state.step_window_per_strategy)
|
| 503 |
+
self.assertIn('PERPENDICULAR', global_state.step_window_per_strategy)
|
| 504 |
+
self.assertIn('SALIENCE_MASK', global_state.step_window_per_strategy)
|
| 505 |
+
perp_win = global_state.step_window_per_strategy['PERPENDICULAR']
|
| 506 |
+
self.assertAlmostEqual(perp_win.start, 0.0, places=3)
|
| 507 |
+
self.assertAlmostEqual(perp_win.end, 0.4, places=3)
|
| 508 |
+
|
| 509 |
+
def test_custom_mode_restored_state_gates_correctly(self):
|
| 510 |
+
"""After restore, strategy_is_active uses the per-strategy windows."""
|
| 511 |
+
raw_str = serialize_per_strategy_windows({'AND_PERP': (0.0, 0.2)})
|
| 512 |
+
global_state.apply_infotext({
|
| 513 |
+
'NP Step Window Enabled': 'True',
|
| 514 |
+
'NP Step Window Mode': 'per-strategy custom',
|
| 515 |
+
'NP Step Window Custom': raw_str,
|
| 516 |
+
})
|
| 517 |
+
per_strategy = global_state.step_window_per_strategy
|
| 518 |
+
# AND_PERP window 0.0–0.2 → inactive at step 5/10 (progress=0.56)
|
| 519 |
+
progress = normalize_progress(5, 10)
|
| 520 |
+
self.assertFalse(
|
| 521 |
+
strategy_is_active('PERPENDICULAR', progress,
|
| 522 |
+
per_strategy_windows=per_strategy)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
def test_full_round_trip_all_strategies(self):
|
| 526 |
+
raw = default_custom_windows()
|
| 527 |
+
serialized = serialize_per_strategy_windows(raw)
|
| 528 |
+
restored = deserialize_per_strategy_windows(serialized)
|
| 529 |
+
# Every key present
|
| 530 |
+
for uk in STRATEGY_UI_KEYS:
|
| 531 |
+
self.assertIn(uk, restored)
|
| 532 |
+
# Values match to 3 decimal places
|
| 533 |
+
for uk in STRATEGY_UI_KEYS:
|
| 534 |
+
self.assertAlmostEqual(restored[uk][0], raw[uk][0], places=3)
|
| 535 |
+
self.assertAlmostEqual(restored[uk][1], raw[uk][1], places=3)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if __name__ == '__main__':
|
| 539 |
+
unittest.main()
|