dikdimon committed
Commit e0cb5cd · verified · 1 Parent(s): cd6c22c

Upload 12 files
mega_freeu_a1111/README.md ADDED
@@ -0,0 +1,30 @@
+ # ⚑ Mega FreeU — A1111 Extension
+
+ **Mega FreeU** combines the best features from 5 FreeU implementations into a single
+ production-ready A1111 extension.
+
+ ## Sources & Features
+
+ | Source | Features taken |
+ |--------|---------------|
+ | **sd-webui-freeu** | `th.cat` hijack (the only correct A1111 approach), V1/V2 backbone, box filter (BUG FIXED), schedule start/stop/smoothness, backbone region (width + offset), presets JSON, PNG metadata, XYZ grid, ControlNet patch |
+ | **WAS FreeU_Advanced** | 9 blending modes (lerp/inject/bislerp/colorize/cosine/cuberp/hslerp/stable_slerp/linear_dodge), 13 multi-scale FFT presets, override_scales textarea, Post-CFG Shift (ported to A1111). Note: `target_block` / `input_block` / `middle_block` / `slice_b1/b2` not ported — A1111's `th.cat` hijack operates on output-side skip connections only. |
+ | **ComfyUI_FreeU_V2_Advanced** | Gaussian FFT filter (smooth, no ringing), Adaptive Cap loop (MAX_CAP_ITER=3), independent B/S timestep ranges per stage, channel_threshold matching |
+ | **FreeU_V2_timestepadd** | b_start/b_end %, s_start/s_end % per-stage step-fraction gating (note: the original ComfyUI node used `percent_to_sigma`; this port uses `current_step / total_steps` — conceptually equivalent for typical schedulers) |
+ | **nrs_kohaku_v3.5** | hf_boost parameter, Gaussian on output, on_cpu fallback tracker |
+
+ ## Bug Fixed
+ - `sdwebui-freeU-extension` applied the Fourier mask to only ONE quadrant:
+   `mask[..., crow-t:crow, ccol-t:ccol]`
+   **Fixed**: `mask[..., crow-t:crow+t, ccol-t:ccol+t]` (symmetric center)
+
+ ## Installation
+ Copy `mega_freeu_a1111/` into `stable-diffusion-webui/extensions/` and restart.
+
+ ## Recommended Settings (SD1.5 V2 Gaussian + Independent B/S)
+ - Stage 1: B=1.2, S=0.9, FFT=gaussian, B End%=0.35, S Start%=0.35
+ - Stage 2: B=1.4, S=0.2, FFT=gaussian, B End%=0.35, S Start%=0.35
+
+ ## Key Concept: Independent B/S Timestep Ranges
+ - **B-scaling** is most effective in the structure phase (early steps, B Start%=0.0, B End%=0.35)
+ - **S-filtering** is most effective in the detail phase (late steps, S Start%=0.35, S End%=1.0)
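The quadrant bug called out in the README is easiest to see on a tiny mask. A minimal sketch with numpy as a stand-in for the torch tensors the extension actually uses (sizes and names here are illustrative only):

```python
import numpy as np

H = W = 6
crow, ccol = H // 2, W // 2  # frequency-domain center after fftshift
t = 1                        # low-frequency window half-size
scale = 0.5                  # attenuation applied inside the window

# Buggy slicing: only a t x t patch ABOVE-LEFT of the center gets scaled
buggy = np.ones((H, W))
buggy[crow - t:crow, ccol - t:ccol] = scale          # 1 element, off-center

# Fixed slicing: a symmetric 2t x 2t window AROUND the center gets scaled
fixed = np.ones((H, W))
fixed[crow - t:crow + t, ccol - t:ccol + t] = scale  # 4 elements, centered

print(int((buggy == scale).sum()))  # 1
print(int((fixed == scale).sum()))  # 4
```

With the buggy slice, three quarters of the intended low-frequency band is left untouched, so the skip-connection attenuation is both too weak and asymmetric; the fixed slice scales the full centered window.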
mega_freeu_a1111/install.py ADDED
@@ -0,0 +1,3 @@
+ # install.py — auto-run by the A1111 extension system (no external deps needed)
+ import launch
+ # All deps (torch, gradio, modules) are already in the A1111 environment
mega_freeu_a1111/lib_mega_freeu/__init__.py ADDED
@@ -0,0 +1 @@
+ # lib_mega_freeu — Mega FreeU for A1111
mega_freeu_a1111/lib_mega_freeu/global_state.py ADDED
@@ -0,0 +1,285 @@
+ """
+ lib_mega_freeu/global_state.py
+ Runtime state, data structures, presets for ⚑ Mega FreeU.
+
+ Sources:
+     sd-webui-freeu/lib_free_u/global_state.py  -- StageInfo layout, State, preset JSON, XYZ
+     WAS FreeU_Advanced/nodes.py                -- BLEND_MODE_NAMES, MSCALES
+     ComfyUI_FreeU_V2_Advanced/FreeU_B1B2.py    -- b_start/b_end, channel_threshold
+     ComfyUI_FreeU_V2_Advanced/FreeU_S1S2.py    -- s_start/s_end, adaptive cap
+     nrs_kohaku_enhanced_v3_5.py                -- hf_boost, gaussian standalone
+ """
+ import dataclasses
+ import json
+ import math
+ import pathlib
+ import re
+ import sys
+ from typing import Any, Dict, List, Optional, Union
+
+ # ─── Blending modes (WAS FreeU_Advanced/nodes.py blending_modes keys) ─────────
+ BLEND_MODE_NAMES: List[str] = [
+     "lerp", "inject", "bislerp", "colorize",
+     "cosine interp", "cuberp", "hslerp", "stable_slerp", "linear dodge",
+ ]
+
+ # ─── Multi-scale presets (WAS nodes.py mscales dict -- exact) ─────────────────
+ MSCALES: Dict[str, Optional[list]] = {
+     "Default": None,
+     "Low-Pass": [(10, 1.0)],
+     "Pass-Through": [(10, 1.0)],
+     "Gaussian-Blur": [(10, 0.5)],
+     "Edge-Enhancement": [(10, 2.0)],
+     "Sharpen": [(10, 1.5)],
+     "Multi-Bandpass": [[(5, 0.0), (15, 1.0), (25, 0.0)]],
+     "Multi-Low-Pass": [[(5, 1.0), (10, 0.5), (15, 0.2)]],
+     "Multi-High-Pass": [[(5, 0.0), (10, 0.5), (15, 0.8)]],
+     "Multi-Pass-Through": [[(5, 1.0), (10, 1.0), (15, 1.0)]],
+     "Multi-Gaussian-Blur": [[(5, 0.5), (10, 0.8), (15, 0.2)]],
+     "Multi-Edge-Enhancement": [[(5, 1.2), (10, 1.5), (15, 2.0)]],
+     "Multi-Sharpen": [[(5, 1.5), (10, 2.0), (15, 2.5)]],
+ }
+
+ ALL_VERSIONS: Dict[str, str] = {"Version 1": "1", "Version 2": "2"}
+ REVERSED_VERSIONS: Dict[str, str] = {v: k for k, v in ALL_VERSIONS.items()}
+ FFT_TYPES: List[str] = ["gaussian", "box"]
+ STAGES_COUNT: int = 3
+
+ _shorthand_re = re.compile(r"^([a-z]{1,3})(\d+)$")
+
+ # ─── StageInfo ─────────────────────────────────────────────────────────────────
+ @dataclasses.dataclass
+ class StageInfo:
+     """
+     All per-stage parameters.
+     Fields 1-6: same order as sd-webui-freeu for PNG backwards compat.
+     New fields appended at end.
+     """
+     # sd-webui-freeu compat (DO NOT REORDER first 6)
+     backbone_factor: float = 1.0
+     skip_factor: float = 1.0
+     backbone_offset: float = 0.0
+     backbone_width: float = 0.5
+     skip_cutoff: float = 0.0
+     skip_high_end_factor: float = 1.0
+     # WAS blending
+     backbone_blend_mode: str = "lerp"
+     backbone_blend: float = 1.0
+     # ComfyUI V2 independent timestep ranges
+     b_start_ratio: float = 0.0
+     b_end_ratio: float = 1.0
+     s_start_ratio: float = 0.0
+     s_end_ratio: float = 1.0
+     # FFT
+     fft_type: str = "box"
+     fft_radius_ratio: float = 0.07
+     hf_boost: float = 1.0
+     # Adaptive Cap (FreeU_S1S2)
+     enable_adaptive_cap: bool = False
+     cap_threshold: float = 0.35
+     cap_factor: float = 0.6
+     adaptive_cap_mode: str = "adaptive"
+
+     def to_dict(self, include_default=False):
+         default = StageInfo()
+         d = dataclasses.asdict(self)
+         if not include_default:
+             d = {k: v for k, v in d.items() if v != getattr(default, k)}
+         return d
+
+     def copy(self):
+         return StageInfo(**dataclasses.asdict(self))
+
+ STAGE_FIELD_NAMES = [f.name for f in dataclasses.fields(StageInfo)]
+ STAGE_FIELD_COUNT = len(STAGE_FIELD_NAMES)
+
+ # ─── State ─────────────────────────────────────────────────────────────────────
+ @dataclasses.dataclass
+ class State:
+     enable: bool = True
+     start_ratio: Any = 0.0
+     stop_ratio: Any = 1.0
+     transition_smoothness: float = 0.0
+     version: str = "1"
+     multiscale_mode: str = "Default"
+     multiscale_strength: float = 1.0
+     override_scales: str = ""
+     channel_threshold: int = 96
+     stage_infos: List[Any] = dataclasses.field(
+         default_factory=lambda: [StageInfo() for _ in range(STAGES_COUNT)]
+     )
+     # Post-CFG Shift (WAS_PostCFGShift) — stored in presets & PNG
+     pcfg_enabled: bool = False
+     pcfg_steps: int = 20
+     pcfg_mode: str = "inject"
+     pcfg_blend: float = 1.0
+     pcfg_b: float = 1.1
+     pcfg_fourier: bool = False
+     pcfg_ms_mode: str = "Default"
+     pcfg_ms_str: float = 1.0
+     pcfg_threshold: int = 1
+     pcfg_s: float = 0.5
+     pcfg_gain: float = 1.0
+     verbose: bool = False
+
+     def __post_init__(self):
+         self.stage_infos = self._coerce_stages()
+         self.version = ALL_VERSIONS.get(self.version, self.version)
+
+     def _coerce_stages(self):
+         result, raw = [], list(self.stage_infos)
+         i = 0
+         while i < len(raw) and len(result) < STAGES_COUNT:
+             item = raw[i]
+             if isinstance(item, StageInfo):
+                 result.append(item); i += 1
+             elif isinstance(item, dict):
+                 known = {k: v for k, v in item.items() if k in STAGE_FIELD_NAMES}
+                 result.append(StageInfo(**known)); i += 1
+             else:
+                 chunk = raw[i:i+STAGE_FIELD_COUNT]
+                 result.append(StageInfo(*chunk))
+                 i += STAGE_FIELD_COUNT
+         while len(result) < STAGES_COUNT:
+             result.append(StageInfo())
+         return result
+
+     def to_dict(self):
+         d = dataclasses.asdict(self)
+         d["stage_infos"] = [si.to_dict() for si in self.stage_infos]
+         del d["enable"]
+         return d
+
+     def copy(self):
+         d = dataclasses.asdict(self)
+         d["stage_infos"] = [StageInfo(**s) for s in d["stage_infos"]]
+         return State(**d)
+
+     def update_attr(self, key, value):
+         if m := _shorthand_re.match(key):
+             char, idx = m.group(1), int(m.group(2))
+             if 0 <= idx < STAGES_COUNT:
+                 si = self.stage_infos[idx]
+                 _MAP = {
+                     "b": "backbone_factor", "s": "skip_factor", "o": "backbone_offset",
+                     "w": "backbone_width", "t": "skip_cutoff", "h": "skip_high_end_factor",
+                     "bm": "backbone_blend_mode", "bb": "backbone_blend",
+                     "bs": "b_start_ratio", "be": "b_end_ratio",
+                     "ss": "s_start_ratio", "se": "s_end_ratio",
+                     "ft": "fft_type", "fr": "fft_radius_ratio", "hfb": "hf_boost",
+                     "cap": "enable_adaptive_cap", "ct": "cap_threshold", "cf": "cap_factor", "acm": "adaptive_cap_mode",
+                 }
+                 if char in _MAP:
+                     setattr(si, _MAP[char], value); return
+         if hasattr(self, key):
+             setattr(self, key, value)
+
+ # ─── Singletons ────────────────────────────────────────────────────────────────
+ instance: State = State()
+ xyz_attrs: Dict[str, Any] = {}
+ current_sampling_step: int = 0
+
+ # ─── Preset builders ───────────────────────────────────────────────────────────
+ def _v1(*pairs):
+     infos = [StageInfo(backbone_factor=b, skip_factor=s) for b, s in pairs]
+     while len(infos) < STAGES_COUNT: infos.append(StageInfo())
+     return State(version="1", stage_infos=infos)
+
+ def _v2g(pairs):
+     infos = []
+     for b, s, r, hfb, bs, be, ss, se in pairs:
+         infos.append(StageInfo(
+             backbone_factor=b, skip_factor=s,
+             fft_type="gaussian", fft_radius_ratio=r, hf_boost=hfb,
+             b_start_ratio=bs, b_end_ratio=be,
+             s_start_ratio=ss, s_end_ratio=se,
+         ))
+     while len(infos) < STAGES_COUNT: infos.append(StageInfo())
+     return State(version="2", stage_infos=infos)
+
+ default_presets: Dict[str, State] = {
+     "SD1.4 Recommendations": _v1((1.2,0.9),(1.4,0.2),(1.0,1.0)),
+     "SD2.1 Recommendations": _v1((1.1,0.9),(1.2,0.2),(1.0,1.0)),
+     "SDXL Recommendations": _v1((1.1,0.6),(1.2,0.4),(1.0,1.0)),
+     "SD1.5 V2 Gaussian": _v2g([
+         (1.2,0.9,0.07,1.0, 0.0,0.35, 0.35,1.0),
+         (1.4,0.2,0.07,1.0, 0.0,0.35, 0.35,1.0),
+     ]),
+     "SD1.5 V2 High Detail": _v2g([
+         (1.4,0.8,0.08,1.2, 0.0,0.35, 0.35,1.0),
+         (1.6,0.1,0.06,1.0, 0.0,0.35, 0.35,1.0),
+     ]),
+     "SDXL V2 Gaussian": _v2g([
+         (1.1,0.6,0.05,1.1, 0.0,0.35, 0.35,1.0),
+         (1.2,0.4,0.05,1.1, 0.0,0.35, 0.35,1.0),
+     ]),
+     "SD1.5 Adaptive Cap": State(version="2", stage_infos=[
+         StageInfo(backbone_factor=1.3, skip_factor=0.9,
+                   fft_type="gaussian", fft_radius_ratio=0.08, hf_boost=1.2,
+                   b_start_ratio=0.0, b_end_ratio=0.35,
+                   s_start_ratio=0.35, s_end_ratio=1.0,
+                   enable_adaptive_cap=True, cap_threshold=0.35,
+                   cap_factor=0.6, adaptive_cap_mode="adaptive"),
+         StageInfo(backbone_factor=1.4, skip_factor=0.2,
+                   fft_type="gaussian", fft_radius_ratio=0.06, hf_boost=1.0,
+                   b_start_ratio=0.0, b_end_ratio=0.35,
+                   s_start_ratio=0.35, s_end_ratio=1.0,
+                   enable_adaptive_cap=True, cap_threshold=0.70,
+                   cap_factor=0.6, adaptive_cap_mode="adaptive"),
+         StageInfo(),
+     ]),
+     "Independent B/S (SD1.5)": _v2g([
+         (1.2,0.9,0.07,1.0, 0.0,0.35, 0.35,1.0),
+         (1.4,0.2,0.06,1.0, 0.0,0.35, 0.35,1.0),
+     ]),
+ }
+
+ all_presets: Dict[str, State] = {}
+ PRESETS_PATH = pathlib.Path(__file__).parent.parent / "presets.json"
+
+ def reload_presets():
+     all_presets.clear()
+     all_presets.update(default_presets)
+     all_presets.update(_load_user_presets())
+
+ def _load_user_presets():
+     if not PRESETS_PATH.exists(): return {}
+     try:
+         with open(PRESETS_PATH, encoding="utf-8") as f:
+             raw = json.load(f)
+     except Exception as e:
+         print(f"[MegaFreeU] preset load error: {e}", file=sys.stderr)
+         return {}
+     result = {}
+     _state_fields = {f.name for f in dataclasses.fields(State)}
+     for k, v in raw.items():
+         try:
+             # Filter unknown keys so future/old fields don't crash State(**v)
+             known = {fk: fv for fk, fv in v.items() if fk in _state_fields}
+             result[k] = State(**known)
+         except Exception as e:
+             print(f"[MegaFreeU] skipping preset {k!r}: {e}", file=sys.stderr)
+     return result
+
+ def save_presets(custom=None):
+     if custom is None: custom = get_user_presets()
+     try:
+         PRESETS_PATH.parent.mkdir(parents=True, exist_ok=True)
+         with open(PRESETS_PATH, "w", encoding="utf-8") as f:
+             json.dump({k: v.to_dict() for k, v in custom.items()}, f, indent=4)
+     except Exception as e:
+         print(f"[MegaFreeU] preset save error: {e}", file=sys.stderr)
+
+ def get_user_presets():
+     return {k: v for k, v in all_presets.items() if k not in default_presets}
+
+ def apply_xyz():
+     global instance
+     if pk := xyz_attrs.get("preset"):
+         if p := all_presets.get(pk):
+             instance = p.copy()
+         elif pk != "UI Settings":
+             print(f"[MegaFreeU] XYZ preset '{pk}' not found", file=sys.stderr)
+     for k, v in xyz_attrs.items():
+         if k != "preset":
+             instance.update_attr(k, v)
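The XYZ shorthand keys that `State.update_attr` accepts (`b0`, `s1`, `hfb2`, ...) follow a simple letters-plus-stage-index pattern. A standalone sketch of just the parse step, with the regex copied from `global_state.py` and a reduced `_MAP` for illustration:

```python
import re

# Regex from global_state.py: 1-3 lowercase letters, then a stage index
_shorthand_re = re.compile(r"^([a-z]{1,3})(\d+)$")

# Subset of the full shorthand -> field map, for demonstration only
_MAP = {"b": "backbone_factor", "s": "skip_factor", "hfb": "hf_boost"}

def parse_shorthand(key):
    """Return (field_name, stage_index) for a shorthand key, else None."""
    if m := _shorthand_re.match(key):
        char, idx = m.group(1), int(m.group(2))
        if char in _MAP:
            return _MAP[char], idx
    return None

print(parse_shorthand("b0"))       # ('backbone_factor', 0)
print(parse_shorthand("hfb2"))     # ('hf_boost', 2)
print(parse_shorthand("version"))  # None -> falls through to a top-level State attr
```

In the real `update_attr`, a successful parse sets the field on `stage_infos[idx]`; anything that does not match the regex (or maps to no stage field) is treated as a top-level `State` attribute.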
mega_freeu_a1111/lib_mega_freeu/unet.py ADDED
@@ -0,0 +1,559 @@
+ """
+ lib_mega_freeu/unet.py — Math engine + A1111 th.cat patch
+
+ BUGS FIXED vs sdwebui-freeU-extension/scripts/freeunet_hijack.py:
+   BUG 1 dtype:    mask = torch.ones(..., dtype=torch.bool)
+                   bool * float = NOOP, scale always 1.0
+                   Fix: torch.full(..., float(scale_high))
+   BUG 2 quadrant: mask[..., crow-t:crow, ccol-t:ccol] (top-left only)
+                   Fix: mask[..., crow-t:crow+t, ccol-t:ccol+t] (symmetric center)
+
+ Sources:
+     sd-webui-freeu/lib_free_u/unet.py        patch(), free_u_cat_hijack(),
+                                              get_backbone_scale(), ratio_to_region(), filter_skip()[box],
+                                              get_schedule_ratio(), is_gpu_complex_supported(), lerp()
+     WAS FreeU_Advanced/nodes.py              9 blending modes, Fourier_filter() multiscale
+     ComfyUI_FreeU_V2_advanced/utils.py       Fourier_filter_gauss(), get_band_energy_stats()
+     ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py  Adaptive Cap loop MAX_CAP_ITER=3
+     ComfyUI_FreeU_V2_advanced/FreeU_B1B2.py  channel_threshold, model_channels*4/2/1
+     FreeU_V2_timestepadd.py                  step-fraction timestep gating concept
+     nrs_kohaku_enhanced_v3_5.py              _freeu_b_scale_h, _freeu_fourier_filter_gaussian,
+                                              hf_boost param, on_cpu_devices dict
+ """
+ import dataclasses
+ import functools
+ import logging
+ import math
+ import pathlib
+ import sys
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ from lib_mega_freeu import global_state
+
+ # ── GPU complex support (sd-webui-freeu exact) ────────────────────────────────
+ _gpu_complex_support: Optional[bool] = None
+
+ def is_gpu_complex_supported(x: torch.Tensor) -> bool:
+     global _gpu_complex_support
+     if x.is_cpu:
+         return True
+     if _gpu_complex_support is not None:
+         return _gpu_complex_support
+     mps_avail = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+     try:
+         import torch_directml
+     except ImportError:
+         dml_avail = False
+     else:
+         dml_avail = torch_directml.is_available()
+     _gpu_complex_support = not (mps_avail or dml_avail)
+     if _gpu_complex_support:
+         try: torch.fft.fftn(x.float(), dim=(-2, -1))
+         except RuntimeError: _gpu_complex_support = False
+     return _gpu_complex_support
+
+ _on_cpu_devices: Dict = {}
+
+ # ── Blending modes (WAS nodes.py exact) ───────────────────────────────────────
+ def _normalize(t):
+     mn, mx = t.min(), t.max()
+     return (t - mn) / (mx - mn + 1e-8)
+
+ def _hslerp(a, b, t):
+     nc = a.size(1)
+     iv = torch.zeros(1, nc, 1, 1, device=a.device, dtype=a.dtype)
+     iv[0, 0, 0, 0] = 1.0
+     result = (1 - t) * a + t * b
+     if t < 0.5:
+         result += (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
+     else:
+         result -= (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
+     return result
+
+ def _stable_slerp(a, b, t, eps=1e-6):
+     an = a / torch.linalg.norm(a, dim=1, keepdim=True).clamp_min(eps)
+     bn = b / torch.linalg.norm(b, dim=1, keepdim=True).clamp_min(eps)
+     dot = (an * bn).sum(dim=1, keepdim=True).clamp(-1.0 + eps, 1.0 - eps)
+     theta = torch.acos(dot)
+     sin_t = torch.sin(theta).clamp_min(eps)
+     s0 = torch.sin((1.0 - t) * theta) / sin_t
+     s1 = torch.sin(t * theta) / sin_t
+     slerp_out = s0 * a + s1 * b
+     lerp_out = (1.0 - t) * a + t * b
+     use_lerp = (theta < 1e-3).squeeze(1)
+     return torch.where(use_lerp.unsqueeze(1), lerp_out, slerp_out)
+
+ BLENDING_MODES = {
+     "bislerp": lambda a, b, t: _normalize((1 - t) * a + t * b),
+     "colorize": lambda a, b, t: a + (b - a) * t,
+     "cosine interp": lambda a, b, t: (
+         a + b - (a - b) * torch.cos(
+             torch.tensor(math.pi, device=a.device, dtype=a.dtype) * t)) / 2,
+     "cuberp": lambda a, b, t: a + (b - a) * (3 * t**2 - 2 * t**3),
+     "hslerp": _hslerp,
+     "stable_slerp": _stable_slerp,
+     "inject": lambda a, b, t: a + b * t,
+     "lerp": lambda a, b, t: (1 - t) * a + t * b,
+     "linear dodge": lambda a, b, t: _normalize(a + b * t),
+ }
+
+ def lerp(a, b, r):
+     return (1 - r) * a + r * b
+
+ # ── Backbone scaling ──────────────────────────────────────────────────────────
+ def get_backbone_scale(h: torch.Tensor, backbone_factor: float, version: str):
+     if version == "1":
+         return backbone_factor
+     # V2: adaptive hidden_mean (FreeU_B1B2.py + kohaku _freeu_b_scale_h exact)
+     features_mean = h.mean(1, keepdim=True)
+     B = features_mean.shape[0]
+     hidden_max, _ = torch.max(features_mean.view(B, -1), dim=-1, keepdim=True)
+     hidden_min, _ = torch.min(features_mean.view(B, -1), dim=-1, keepdim=True)
+     denom = (hidden_max - hidden_min).clamp_min(1e-6)
+     hidden_mean = (features_mean - hidden_min.unsqueeze(2).unsqueeze(3)) \
+         / denom.unsqueeze(2).unsqueeze(3)
+     return 1 + (backbone_factor - 1) * hidden_mean
+
+ def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
+     """sd-webui-freeu ratio_to_region exact."""
+     if width < 0:
+         offset += width; width = -width
+     width = min(width, 1.0)
+     if offset < 0:
+         offset = 1 + offset - int(offset)
+     offset = math.fmod(offset, 1.0)
+     if width + offset <= 1:
+         return round(offset * n), round((width + offset) * n), False
+     else:
+         return round((width + offset - 1) * n), round(offset * n), True
+
+ # ── Box FFT (BUGS FIXED: symmetric center + float dtype) ─────────────────────
+ def filter_skip_box(x: torch.Tensor, cutoff: float,
+                     scale: float, scale_high: float = 1.0) -> torch.Tensor:
+     """
+     FreeU box filter with TWO BUGS FIXED from sdwebui-freeU-extension:
+       BUG 1 (dtype):  was a torch.bool mask -> scale multiplication was a NOOP
+       BUG 2 (region): was [crow-t:crow, ccol-t:ccol] -> single top-left quadrant
+     Both fixed: torch.full float + symmetric [crow-t:crow+t, ccol-t:ccol+t].
+     sd-webui-freeu already has these correct; we match its implementation.
+     """
+     if scale == 1.0 and scale_high == 1.0:
+         return x
+     fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
+     x_freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
+     x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
+     B, C, H, W = x_freq.shape
+     mask = torch.full((B, C, H, W), float(scale_high), device=fft_dev)  # FIX: float, not bool
+     crow, ccol = H // 2, W // 2
+     tr = max(1, math.floor(crow * cutoff)) if cutoff > 0 else 1
+     tc = max(1, math.floor(ccol * cutoff)) if cutoff > 0 else 1
+     mask[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = scale  # FIX: symmetric center
+     x_freq *= mask
+     x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
+     return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
+
+ # ── Box + WAS multiscale overlay (WAS nodes.py Fourier_filter exact) ─────────
+ def filter_skip_box_multiscale(x: torch.Tensor, cutoff: float, scale: float,
+                                scales_preset: Optional[list],
+                                strength: float = 1.0,
+                                scale_high: float = 1.0) -> torch.Tensor:
+     """
+     WAS FreeU_Advanced/nodes.py Fourier_filter(x, threshold, scale, scales, strength).
+     threshold = cutoff: a float ratio in [0, 1] or int pixels (WAS uses int, default=1).
+     scales: None, a list of (radius_px, val) tuples (single-scale), or a list of lists (multi-scale).
+     """
+     if scale == 1.0 and scale_high == 1.0 and scales_preset is None:
+         return x
+     fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
+     x_freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
+     x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
+     B, C, H, W = x_freq.shape
+     crow, ccol = H // 2, W // 2
+     if isinstance(cutoff, float) and 0 < cutoff <= 1.0:
+         tr = max(1, math.floor(crow * cutoff)); tc = max(1, math.floor(ccol * cutoff))
+     else:
+         t = max(1, int(cutoff)) if cutoff > 0 else 1; tr = tc = t
+     mask = torch.ones((B, C, H, W), device=fft_dev)
+     mask[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = scale
+     if scale_high != 1.0:
+         hfm = torch.full((B, C, H, W), float(scale_high), device=fft_dev)
+         hfm[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = 1.0
+         mask = mask * hfm
+     if scales_preset:
+         if isinstance(scales_preset[0], tuple):
+             # WAS single-scale mode
+             for scale_threshold, scale_value in scales_preset:
+                 sv = scale_value * strength
+                 sm = torch.ones((B, C, H, W), device=fft_dev)
+                 st = max(1, int(scale_threshold))
+                 sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
+                 mask = mask + (sm - mask) * strength
+         else:
+             # WAS multi-scale mode
+             for scale_params in scales_preset:
+                 if isinstance(scale_params, list):
+                     for scale_threshold, scale_value in scale_params:
+                         sv = scale_value * strength
+                         sm = torch.ones((B, C, H, W), device=fft_dev)
+                         st = max(1, int(scale_threshold))
+                         sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
+                         mask = mask + (sm - mask) * strength
+     x_freq = x_freq * mask
+     x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
+     return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
+
+ # ── Gaussian FFT (ComfyUI utils.py exact) ────────────────────────────────────
+ def fourier_filter_gauss(x: torch.Tensor, radius_ratio: float,
+                          scale: float, hf_boost: float = 1.0) -> torch.Tensor:
+     """
+     ComfyUI_FreeU_V2_advanced/utils.py Fourier_filter_gauss() exact.
+     Also matches kohaku _freeu_fourier_filter_gaussian().
+       R = max(1, int(min(H,W)*radius_ratio))
+       sigma_f = R^2/2
+       center = exp(-dist2/sigma_f)
+       mask = scale*center + hf_boost*(1-center)
+     """
+     x_f = torch.fft.fftn(x.float(), dim=(-2, -1))
+     x_f = torch.fft.fftshift(x_f, dim=(-2, -1))
+     B, C, H, W = x_f.shape
+     R = max(1, int(min(H, W) * radius_ratio))
+     sigma_f = max(1e-6, (R * R) / 2.0)
+     yy, xx = torch.meshgrid(
+         torch.arange(H, device=x.device, dtype=torch.float32) - H // 2,
+         torch.arange(W, device=x.device, dtype=torch.float32) - W // 2,
+         indexing="ij")
+     center = torch.exp(-(yy**2 + xx**2) / sigma_f)
+     mask = (scale * center + hf_boost * (1.0 - center)).view(1, 1, H, W)
+     x_f = x_f * mask
+     x_f = torch.fft.ifftshift(x_f, dim=(-2, -1))
+     return torch.fft.ifftn(x_f, dim=(-2, -1)).real.to(x.dtype)
+
+ # ── Band energy stats (ComfyUI utils.py exact) ────────────────────────────────
+ def get_band_energy_stats(x: torch.Tensor, R: int) -> Tuple[float, float, float]:
+     """ComfyUI_FreeU_V2_advanced/utils.py get_band_energy_stats() exact."""
+     xf = torch.fft.fftn(x.float(), dim=(-2, -1))
+     xf = torch.fft.fftshift(xf, dim=(-2, -1))
+     B, C, H, W = xf.shape
+     yy, xx = torch.meshgrid(
+         torch.arange(H, device=x.device, dtype=torch.float32) - H // 2,
+         torch.arange(W, device=x.device, dtype=torch.float32) - W // 2,
+         indexing="ij")
+     lf_mask = (yy**2 + xx**2) <= (R * R)
+     mag2 = xf.real**2 + xf.imag**2
+     # FIX: expand_as requires same ndim; use a 2D mask on the last dims
+     lf_e = mag2[:, :, lf_mask].mean().item() if lf_mask.any() else 0.0
+     hf_e = mag2[:, :, ~lf_mask].mean().item() if (~lf_mask).any() else 0.0
+     cover = lf_mask.sum().item() / (H * W) * 100.0
+     return lf_e, hf_e, cover
+
+ # ── Adaptive Cap Gaussian (FreeU_S1S2.py MAX_CAP_ITER=3 exact) ───────────────
+ def filter_skip_gaussian_adaptive(hsp: torch.Tensor,
+                                   si: "global_state.StageInfo",
+                                   verbose: bool = False) -> torch.Tensor:
+     """
+     ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py exact algorithm:
+       1. Compute the LF/HF ratio before filtering.
+       2. Apply the Gaussian filter.
+       3. If enable_adaptive_cap and drop > cap_threshold: loop up to MAX_CAP_ITER=3.
+          adaptive mode: eff_factor = cap_factor * (cap_threshold / drop)
+          fixed mode:    eff_factor = cap_factor
+          capped_s = 1 - eff_factor*(1-s_scale)        [interpolate FROM ORIGINAL]
+          capped_s = max(capped_s, current_s*(1+1e-4))
+          Re-apply from original_hsp with capped_s.
+     hf_boost combined = max(si.hf_boost, si.skip_high_end_factor)  [kohaku pattern]
+     """
+     s_scale = si.skip_factor
+     radius_r = si.fft_radius_ratio
+     hf_boost = max(si.hf_boost, si.skip_high_end_factor)
+     orig_dev = hsp.device
+     H, W = hsp.shape[-2:]
+     R_eff = max(1, int(min(H, W) * radius_r))
+
+     # CRITICAL ORDER: init the cpu-fallback flag and helpers BEFORE any FFT call
+     use_cpu = _on_cpu_devices.get(orig_dev, not is_gpu_complex_supported(hsp))
+     if use_cpu:
+         _on_cpu_devices[orig_dev] = True
+
+     def _tod(t):  # to FFT-safe device
+         return t.cpu() if use_cpu else t
+
+     def _fromd(t):  # back to original device
+         return t.to(orig_dev) if use_cpu else t
+
+     def _energy(t):
+         return get_band_energy_stats(_tod(t), R_eff)
+
+     def _filt(inp, sc):
+         nonlocal use_cpu
+         try:
+             out = fourier_filter_gauss(_tod(inp), radius_r, sc, hf_boost)
+             return _fromd(out)
+         except Exception:
+             if not use_cpu:
+                 logging.warning(f"[MegaFreeU] {orig_dev} -> CPU fallback for FFT")
+                 _on_cpu_devices[orig_dev] = True
+                 use_cpu = True
+                 return fourier_filter_gauss(inp.cpu(), radius_r, sc, hf_boost).to(orig_dev)
+             return inp
+
+     # Pre-filter energy (now safe on all devices)
+     lf_b, hf_b, cover = _energy(hsp)
+     ratio_b = lf_b / hf_b if hf_b > 1e-6 else float("inf")
+     if verbose:
+         logging.info(f"[MegaFreeU] Gauss {H}x{W} R={R_eff}px cov={cover:.1f}% "
+                      f"LF={lf_b:.3e} HF={hf_b:.3e} ratio_b={ratio_b:.4f}")
+
+     hsp_filt = _filt(hsp, s_scale)
+     if not si.enable_adaptive_cap:
+         return hsp_filt
+
+     MAX_CAP_ITER = 3
+     original_hsp = hsp
+     current_s = s_scale
+     lf_a, hf_a, _ = _energy(hsp_filt)
+     ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
+     drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
+     orig_drop = drop
+     iters = 0
+     hsp_cur = hsp_filt
+
+     while (si.enable_adaptive_cap
+            and drop > si.cap_threshold
+            and current_s < 0.999
+            and iters < MAX_CAP_ITER):
+
+         if iters == 0:
+             logging.warning(f"[MegaFreeU] Over-attenuation: drop={drop*100:.1f}% > "
+                             f"{si.cap_threshold*100:.1f}% s={s_scale:.4f}")
+
+         eff_f = si.cap_factor
+         if si.adaptive_cap_mode == "adaptive":
+             eff_f = si.cap_factor * (si.cap_threshold / max(drop, 1e-8))
+
+         capped_s = 1.0 - eff_f * (1.0 - s_scale)            # interpolate from ORIGINAL s
+         capped_s = max(capped_s, current_s * (1.0 + 1e-4))  # only ever relax
+         if abs(capped_s - current_s) < 1e-4:
+             if verbose: logging.info("  Cap converged.")
+             break
+
+         if verbose:
+             logging.info(f"  Cap iter {iters+1}: s {current_s:.4f}->{capped_s:.4f} eff={eff_f:.4f}")
+
+         try:
+             hsp_new = _filt(original_hsp, capped_s)
+         except Exception as e:
+             logging.error(f"[MegaFreeU] cap re-apply error: {e}")
+             hsp_cur = original_hsp  # restore to original on error (ComfyUI FreeU_S1S2.py pattern)
+             break
+
+         hsp_cur = hsp_new
+         lf_a, hf_a, _ = _energy(hsp_cur)
+         ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
+         drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
+         current_s = capped_s
+         iters += 1
+
+     if iters > 0 or verbose:
+         logging.info(f"[MegaFreeU] Cap done: {orig_drop*100:.1f}%->{drop*100:.1f}% "
+                      f"({iters} iters s_final={current_s:.4f})")
+     return hsp_cur
+
+ # ── Schedule (sd-webui-freeu exact) ──────────────────────────────────────────
+ def get_schedule_ratio() -> float:
+     from modules import shared
+     st = global_state.instance
+     steps = shared.state.sampling_steps or 20
+     cur = global_state.current_sampling_step
+     start = _to_step(st.start_ratio, steps)
+     stop = _to_step(st.stop_ratio, steps)
+     if start == stop:
+         smooth = 0.0
+     elif cur < start:
+         smooth = min(1.0, max(0.0, cur / (start + 1e-8)))
+     else:
+         smooth = min(1.0, max(0.0, 1 + (cur - start) / (start - stop + 1e-8)))
+     flat = 1.0 if start <= cur < stop else 0.0
+     return lerp(flat, smooth, st.transition_smoothness)
+
+ def get_stage_bsratio(b_start: float, b_end: float) -> float:
+     """Independent B/S timestep range gate (FreeU_V2_timestepadd concept -> step fraction)."""
+     from modules import shared
+     steps = max(shared.state.sampling_steps or 20, 1)
+     cur = global_state.current_sampling_step
+     pct = cur / (steps - 1) if steps > 1 else 0.0
+     return 1.0 if b_start <= pct <= b_end else 0.0
+
+ def _to_step(v, steps):
+     return int(v * steps) if isinstance(v, float) else int(v)
+
+ # ── Stage auto-detection (FreeU_B1B2.py + kohaku exact) ──────────────────────
+ _stage_channels: Tuple[int, int, int] = (1280, 640, 320)
+
+ def detect_model_channels():
+     global _stage_channels
+     try:
+         from modules import shared
+         mc = int(shared.sd_model.model.diffusion_model.model_channels)
+         _stage_channels = (mc * 4, mc * 2, mc * 1)
+     except Exception:
+         _stage_channels = (1280, 640, 320)
+
+ def get_stage_index(dims: int, channel_threshold: int = 96) -> Optional[int]:
+     """FreeU_B1B2.py abs(ch - target) <= channel_threshold proximity match."""
+     for i, target in enumerate(_stage_channels):
+         if abs(dims - target) <= channel_threshold:
+             return i
+     return None
+
+ # ── Override scales parser (WAS nodes.py format exact) ───────────────────────
+ def parse_override_scales(text: str) -> Optional[List]:
+     if not text or not text.strip():
+         return None
+     result = []
+     for line in text.strip().splitlines():
+         line = line.strip()
+         if not line or line.startswith(("#", "!", "//")):
+             continue
+         parts = line.split(",")
+         if len(parts) == 2:
+             try:
+                 result.append((int(parts[0].strip()), float(parts[1].strip())))
+             except ValueError:
+                 pass
+     return result if result else None
+
+ class _VerboseRef:
+     value: bool = False
+ verbose_ref = _VerboseRef()
+
+ # ── Core th.cat hijack (sd-webui-freeu exact + extended) ─────────────────────
+ def free_u_cat_hijack(hs, *args, original_function, **kwargs):
+     """
+     Intercepts torch.cat([h, h_skip], dim=1) in UNet output_blocks.
+     Signature check: kwargs == {"dim": 1} and len(hs) == 2 (sd-webui-freeu exact check).
+
+     Why th.cat over alternatives:
+       - sdwebui-freeU-extension CondFunc(UNetModel.forward): rewrites the full forward,
+         incompatible with other extensions, plus 2 bugs in the fourier mask.
439
+ - kohaku register_forward_hook: output already concatenated,
440
+ can't cleanly separate h from h_skip for independent filtering.
441
+ - th.cat hijack: intercepts exactly [h, h_skip] before concatenation. CORRECT.
442
+ """
443
+ st = global_state.instance
444
+ if not st.enable:
445
+ return original_function(hs, *args, **kwargs)
446
+
447
+ sched = get_schedule_ratio()
448
+ if sched == 0:
449
+ return original_function(hs, *args, **kwargs)
450
+
451
+ try:
452
+ h, h_skip = hs
453
+ if list(kwargs.keys()) != ["dim"] or kwargs.get("dim", -1) != 1:
454
+ return original_function(hs, *args, **kwargs)
455
+ except (ValueError, TypeError):
456
+ return original_function(hs, *args, **kwargs)
457
+
458
+ dims = int(h.shape[1])
459
+ stage_idx = get_stage_index(dims, st.channel_threshold)
460
+ if stage_idx is None:
461
+ return original_function(hs, *args, **kwargs)
462
+
463
+ si = st.stage_infos[stage_idx]
464
+ version = st.version
465
+ verbose = verbose_ref.value
466
+
467
+ # ── BACKBONE ─────────────────────────────────────────────────────────────
468
+ b_gate = get_stage_bsratio(si.b_start_ratio, si.b_end_ratio)
469
+ eff_b = sched * b_gate
470
+
471
+ if eff_b > 0.0 and abs(si.backbone_factor - 1.0) > 1e-6:
472
+ try:
473
+ rbegin, rend, rinv = ratio_to_region(si.backbone_width, si.backbone_offset, dims)
474
+ ch_idx = torch.arange(dims, device=h.device)
475
+ mask = (rbegin <= ch_idx) & (ch_idx <= rend)
476
+ if rinv: mask = ~mask
477
+ mask = mask.reshape(1, -1, 1, 1).to(h.dtype)
478
+
479
+ eff_factor = float(lerp(1.0, si.backbone_factor, eff_b))
480
+ scale = get_backbone_scale(h, eff_factor, version)
481
+ # h_scaled_full: full h with mask region scaled, rest unchanged
482
+ # This matches original: h *= mask*scale + (1-mask)
483
+ h_scaled_full = h * (mask * scale + (1.0 - mask))
484
+
485
+ bmode = si.backbone_blend_mode
486
+ if bmode in BLENDING_MODES and abs(si.backbone_blend - 1.0) > 1e-6:
487
+ # Blend on FULL tensors so modes like slerp/hslerp see proper norms.
488
+ # Then restore unmasked channels to original h.
489
+ h_blended = BLENDING_MODES[bmode](h, h_scaled_full, si.backbone_blend)
490
+ h = h * (1.0 - mask) + h_blended * mask
491
+ else:
492
+ h = h_scaled_full
493
+ except Exception as e:
494
+ logging.warning(f"[MegaFreeU] B-scaling stage {stage_idx}: {e}")
495
+
496
+ # ── SKIP / FOURIER ────────────────────────────────────────────────────────
497
+ s_gate = get_stage_bsratio(si.s_start_ratio, si.s_end_ratio)
498
+ eff_s = sched * s_gate
499
+
500
+ if eff_s > 0.0 and (abs(si.skip_factor - 1.0) > 1e-6
501
+ or abs(si.hf_boost - 1.0) > 1e-6
502
+ or abs(si.skip_high_end_factor - 1.0) > 1e-6):
503
+ try:
504
+ s_scale = float(lerp(1.0, si.skip_factor, eff_s))
505
+ s_high = float(lerp(1.0, si.skip_high_end_factor, eff_s))
506
+
507
+ if si.fft_type == "gaussian":
508
+ hf_b = float(lerp(1.0, si.hf_boost, eff_s))
509
+ si_eff = dataclasses.replace(si, skip_factor=s_scale, skip_high_end_factor=s_high, hf_boost=hf_b)
510
+ h_skip = filter_skip_gaussian_adaptive(h_skip, si_eff, verbose)
511
+ else:
512
+ override = parse_override_scales(st.override_scales)
513
+ ms_preset = override or global_state.MSCALES.get(st.multiscale_mode)
514
+ if ms_preset is not None:
515
+ h_skip = filter_skip_box_multiscale(
516
+ h_skip, si.skip_cutoff, s_scale, ms_preset,
517
+ st.multiscale_strength, s_high)
518
+ else:
519
+ h_skip = filter_skip_box(h_skip, si.skip_cutoff, s_scale, s_high)
520
+ except Exception as e:
521
+ logging.warning(f"[MegaFreeU] skip filter stage {stage_idx}: {e}")
522
+
523
+ return original_function([h, h_skip], *args, **kwargs)
524
+
525
+ # ── Patch (sd-webui-freeu exact + ControlNet) ─────────────────────────────────
526
+ _patched = False # guard against double-patch on hot-reload
527
+
528
+ def patch():
529
+ global _patched
530
+ try:
531
+ from modules.sd_hijack_unet import th
532
+ except ImportError:
533
+ print("[MegaFreeU] sd_hijack_unet not available", file=sys.stderr); return
534
+
535
+ if _patched or (hasattr(th.cat, "func") and getattr(th.cat.func, "__name__", "") == "free_u_cat_hijack"):
536
+ return # already patched (by name; handles module reload)
537
+ th.cat = functools.partial(free_u_cat_hijack, original_function=th.cat)
538
+ _patched = True
539
+
540
+ cn_status = "enabled"
541
+ try:
542
+ from modules import scripts
543
+ cn_paths = [
544
+ str(pathlib.Path(scripts.basedir()).parent.parent / "extensions-builtin" / "sd-webui-controlnet"),
545
+ str(pathlib.Path(scripts.basedir()).parent / "sd-webui-controlnet"),
546
+ ]
547
+ sys.path[0:0] = cn_paths
548
+ try:
549
+ import scripts.hook as cn_hook
550
+ cn_hook.th.cat = functools.partial(free_u_cat_hijack, original_function=cn_hook.th.cat)
551
+ except ImportError:
552
+ cn_status = "disabled"
553
+ finally:
554
+ for p in cn_paths:
555
+ if p in sys.path: sys.path.remove(p)
556
+ except Exception:
557
+ cn_status = "error"
558
+
559
+ print(f"[MegaFreeU] th.cat patched ControlNet: *{cn_status}*")
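The `parse_override_scales` helper above is plain Python with no webui dependencies, so its WAS-style format (`radius_px, scale` per line; `#`, `!`, or `//` comments; malformed lines skipped; `None` when nothing usable remains) can be checked standalone. A minimal sketch, restating the same parser outside the extension:

```python
from typing import List, Optional, Tuple

def parse_override_scales(text: str) -> Optional[List[Tuple[int, float]]]:
    """Parse WAS-style override scales: one 'radius_px, scale' pair per line.

    Comment lines start with '#', '!' or '//'; malformed lines are skipped.
    Returns None when nothing usable was found (the caller then falls back
    to the multiscale presets).
    """
    if not text or not text.strip():
        return None
    result = []
    for line in text.strip().splitlines():
        line = line.strip()
        if not line or line.startswith(("#", "!", "//")):
            continue
        parts = line.split(",")
        if len(parts) == 2:
            try:
                result.append((int(parts[0].strip()), float(parts[1].strip())))
            except ValueError:
                pass  # skip lines that don't parse as (int, float)
    return result if result else None

print(parse_override_scales("# custom scales\n10, 1.5\nnot a pair\n20, 0.8"))
# -> [(10, 1.5), (20, 0.8)]
```

Note that a non-`None` result short-circuits the `multiscale_mode` preset lookup in the skip/Fourier branch of `free_u_cat_hijack`.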
mega_freeu_a1111/lib_mega_freeu/xyz_grid.py ADDED
@@ -0,0 +1,97 @@
+ """
+ lib_mega_freeu/xyz_grid.py — XYZ/XY grid axes
+
+ Source: sd-webui-freeu/lib_free_u/xyz_grid.py (exact find_xyz_module check)
+ Extended with all Mega FreeU per-stage params.
+ """
+ import sys
+ from types import ModuleType
+ from typing import Optional
+ from modules import scripts
+ from lib_mega_freeu import global_state
+
+
+ def patch():
+     xyz_module = _find_xyz_module()
+     if xyz_module is None:
+         print("[MegaFreeU] xyz_grid.py not found — XYZ disabled", file=sys.stderr)
+         return
+
+     def _apply(k, key_map=None):
+         def cb(_p, v, _vs):
+             if key_map is not None:
+                 v = key_map.get(v, v)
+             global_state.xyz_attrs[k] = v
+         return cb
+
+     opts = [
+         xyz_module.AxisOption("[MegaFreeU] Enabled", _bool, _apply("enable"),
+                               choices=lambda: ["False", "True"]),
+         xyz_module.AxisOption("[MegaFreeU] Version", str,
+                               _apply("version", key_map=global_state.ALL_VERSIONS),
+                               choices=lambda: list(global_state.ALL_VERSIONS.keys())),
+         xyz_module.AxisOption("[MegaFreeU] Preset", str, _apply("preset"),
+                               choices=_choices_preset),
+         xyz_module.AxisOption("[MegaFreeU] Start At Step", _num, _apply("start_ratio")),
+         xyz_module.AxisOption("[MegaFreeU] Stop At Step", _num, _apply("stop_ratio")),
+         xyz_module.AxisOption("[MegaFreeU] Smoothness", float, _apply("transition_smoothness")),
+         xyz_module.AxisOption("[MegaFreeU] Multiscale Mode", str, _apply("multiscale_mode"),
+                               choices=lambda: list(global_state.MSCALES.keys())),
+         xyz_module.AxisOption("[MegaFreeU] Multiscale Str", float, _apply("multiscale_strength")),
+         xyz_module.AxisOption("[MegaFreeU] Ch Threshold", int, _apply("channel_threshold")),
+     ]
+
+     for i in range(global_state.STAGES_COUNT):
+         n = i + 1
+         opts += [
+             xyz_module.AxisOption(f"[MFU] S{n} B Scale", float, _apply(f"b{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} B Offset", float, _apply(f"o{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} B Width", float, _apply(f"w{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Blend Mode", str, _apply(f"bm{i}"),
+                                   choices=lambda: global_state.BLEND_MODE_NAMES),
+             xyz_module.AxisOption(f"[MFU] S{n} Blend Str", float, _apply(f"bb{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} S Scale", float, _apply(f"s{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Cutoff", float, _apply(f"t{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} High-End", float, _apply(f"h{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} HF Boost", float, _apply(f"hfb{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} B Start%", float, _apply(f"bs{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} B End%", float, _apply(f"be{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} S Start%", float, _apply(f"ss{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} S End%", float, _apply(f"se{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} FFT Type", str, _apply(f"ft{i}"),
+                                   choices=lambda: global_state.FFT_TYPES),
+             xyz_module.AxisOption(f"[MFU] S{n} Radius", float, _apply(f"fr{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Cap Enable", _bool, _apply(f"cap{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Cap Thresh", float, _apply(f"ct{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Cap Factor", float, _apply(f"cf{i}")),
+             xyz_module.AxisOption(f"[MFU] S{n} Cap Mode", str, _apply(f"acm{i}"),
+                                   choices=lambda: ["adaptive", "fixed"]),
+         ]
+
+     xyz_module.axis_options.extend(opts)
+
+
+ def _find_xyz_module() -> Optional[ModuleType]:
+     """Exact check from sd-webui-freeu/lib_free_u/xyz_grid.py"""
+     for data in scripts.scripts_data:
+         if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
+             return data.module
+     return None
+
+
+ def _choices_preset():
+     presets = list(global_state.all_presets.keys())
+     presets.insert(0, "UI Settings")
+     return presets
+
+
+ def _bool(s):
+     s = str(s).lower()
+     if s in ("true", "1", "yes"):
+         return True
+     if s in ("false", "0", "no"):
+         return False
+     return bool(s)
+
+
+ def _num(s):
+     try:
+         return int(s)
+     except (ValueError, TypeError):
+         return float(s)
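The `B Start%` / `B End%` / `S Start%` / `S End%` axes above drive the step-fraction gating ported from FreeU_V2_timestepadd (`get_stage_bsratio` in `unet.py`). A minimal standalone sketch of that window, with hypothetical local names, shows how a fractional range maps onto concrete sampler steps:

```python
def stage_gate(cur_step: int, total_steps: int, start: float, end: float) -> float:
    """Step-fraction gate: 1.0 while cur_step/(total_steps-1) lies in
    [start, end], else 0.0. This is the cur/total port of ComfyUI's
    percent_to_sigma gating, mirroring get_stage_bsratio in unet.py."""
    steps = max(total_steps, 1)
    pct = cur_step / (steps - 1) if steps > 1 else 0.0
    return 1.0 if start <= pct <= end else 0.0

# With 20 steps, a 0.0-0.35 window covers steps 0..6 (pct = step/19):
print([stage_gate(s, 20, 0.0, 0.35) for s in (0, 6, 7, 19)])
# -> [1.0, 1.0, 0.0, 0.0]
```

In the hijack this gate is multiplied into the global schedule ratio, so the per-stage window can only narrow, never extend, the global Start/Stop range.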
mega_freeu_a1111/scripts/mega_freeu.py ADDED
@@ -0,0 +1,765 @@
+ """
+ scripts/mega_freeu.py - Mega FreeU for A1111 / Forge
+
+ Combined from 5 sources:
+   1. sd-webui-freeu          th.cat hijack, V1/V2 backbone, box filter, schedule,
+                              presets JSON, PNG metadata, XYZ, ControlNet, region masking,
+                              dict-API compat (alwayson_scripts legacy)
+   2. WAS FreeU_Advanced      9 blending modes, 13 multi-scale FFT presets, override_scales,
+                              Post-CFG Shift (WAS_PostCFGShift ported to A1111 callback)
+                              NOTE: target_block / input_block / middle_block / slice_b1/b2
+                              were not ported — the th.cat hijack works on output-side skip concat.
+   3. ComfyUI_FreeU_V2_Adv    Gaussian filter, Adaptive Cap (MAX_CAP_ITER=3),
+                              independent B/S timestep ranges per-stage, channel_threshold
+   4. FreeU_V2_timestepadd    b_start/b_end%, s_start/s_end% per-stage gating
+                              NOTE: gating uses step-fraction (cur/total), not percent_to_sigma
+                              as in the original ComfyUI sources. Conceptually equivalent.
+   5. nrs_kohaku_v3.5         hf_boost param, on_cpu_devices dict, gaussian standalone
+
+ BUGS FIXED vs sdwebui-freeU-extension:
+   BUG 1: bool mask in Fourier filter (scale multiplication was a no-op)
+   BUG 2: single-quadrant mask instead of symmetric center
+ """
+ import dataclasses
+ import json
+ from typing import List
+
+ import gradio as gr
+ from modules import script_callbacks, scripts, shared, processing
+
+ from lib_mega_freeu import global_state, unet, xyz_grid
+
+ _steps_comps = {"txt2img": None, "img2img": None}
+ _steps_cbs = {"txt2img": [], "img2img": []}
+
+ _SF = [f.name for f in dataclasses.fields(global_state.StageInfo)]
+ _SN = len(_SF)  # 19 fields per stage
+
+
+ def _stage_ui(idx, si, elem_id_fn):
+     n = idx + 1
+     ch = {0: "~1280ch (deep)", 1: "~640ch (mid)", 2: "~320ch (shallow)"}.get(idx, f"stage{n}")
+
+     with gr.Accordion(open=(idx < 2), label=f"Stage {n} ({ch})"):
+
+         # Backbone
+         gr.HTML("<p style='margin:4px 0;font-size:.82em;color:#aaa;'>Backbone h (B)</p>")
+         with gr.Row():
+             bf = gr.Slider(label=f"B{n} Scale", minimum=-1, maximum=3, step=0.001,
+                            value=si.backbone_factor,
+                            info=">1 strengthens backbone features. V2: adaptive per-region.")
+             bo = gr.Slider(label=f"B{n} Offset", minimum=0, maximum=1, step=0.001,
+                            value=si.backbone_offset, info="Channel region start [0-1].")
+             bw = gr.Slider(label=f"B{n} Width", minimum=-1, maximum=1, step=0.001,
+                            value=si.backbone_width, info="Channel region width. Negative=invert.")
+         with gr.Row():
+             bm = gr.Dropdown(label=f"B{n} Blend Mode",
+                              choices=global_state.BLEND_MODE_NAMES,
+                              value=si.backbone_blend_mode,
+                              info="lerp=default, stable_slerp=quality, inject=additive")
+             bb = gr.Slider(label=f"B{n} Blend Str", minimum=0, maximum=2, step=0.001,
+                            value=si.backbone_blend)
+         gr.HTML("<p style='font-size:.75em;color:#888;margin:2px 0;'>B timestep range (ComfyUI V2)</p>")
+         with gr.Row():
+             bsr = gr.Slider(label=f"B{n} Start%", minimum=0, maximum=1, step=0.001,
+                             value=si.b_start_ratio, info="B activates at this step fraction.")
+             ber = gr.Slider(label=f"B{n} End%", minimum=0, maximum=1, step=0.001,
+                             value=si.b_end_ratio, info="B stops. 0.35=structure phase only.")
+
+         # Skip / FFT
+         gr.HTML("<p style='margin:8px 0 4px;font-size:.82em;color:#aaa;'>Skip h_skip (S) - Fourier Filter</p>")
+         with gr.Row():
+             sf = gr.Slider(label=f"S{n} LF Scale", minimum=-1, maximum=3, step=0.001,
+                            value=si.skip_factor,
+                            info="<1 suppresses LF components. 0.2=strong suppression.")
+             she = gr.Slider(label=f"S{n} HF (Box)", minimum=-1, maximum=3, step=0.001,
+                             value=si.skip_high_end_factor,
+                             info="HF scale outside the LF region (box filter). >1=boost HF.")
+             hfb = gr.Slider(label=f"S{n} HF Boost (Gauss)", minimum=0, maximum=3, step=0.001,
+                             value=si.hf_boost,
+                             info="Gaussian explicit HF multiplier. Combined as max(hf_boost, high_end).")
+         with gr.Row():
+             ft = gr.Radio(label=f"S{n} FFT Type",
+                           choices=global_state.FFT_TYPES, value=si.fft_type,
+                           info="gaussian=smooth, no ringing. box=original FreeU (both bugs fixed).")
+             sco = gr.Slider(label=f"S{n} Cutoff (Box)", minimum=0, maximum=1, step=0.001,
+                             value=si.skip_cutoff, info="Box: LF cutoff fraction. 0=1px default.")
+             srr = gr.Slider(label=f"S{n} Radius (Gauss)", minimum=0.01, maximum=0.5, step=0.001,
+                             value=si.fft_radius_ratio,
+                             info="Gaussian R=ratio*min(H,W). 0.07=moderate LF.")
+         gr.HTML("<p style='font-size:.75em;color:#888;margin:2px 0;'>S timestep range (ComfyUI V2)</p>")
+         with gr.Row():
+             ssr = gr.Slider(label=f"S{n} Start%", minimum=0, maximum=1, step=0.001,
+                             value=si.s_start_ratio,
+                             info="S activates. Tip: set = B End% for clean phase separation.")
+             ser = gr.Slider(label=f"S{n} End%", minimum=0, maximum=1, step=0.001,
+                             value=si.s_end_ratio, info="S stops. 1.0=to last step.")
+
+         # Adaptive Cap
+         gr.HTML("<p style='font-size:.75em;color:#888;margin:4px 0;'>Adaptive Cap - prevents LF over-attenuation (FreeU_S1S2.py)</p>")
+         with gr.Row():
+             eac = gr.Checkbox(label=f"S{n} Enable Cap", value=si.enable_adaptive_cap,
+                               info="Iteratively weakens the Gaussian if the LF/HF drop exceeds the threshold.")
+             ct = gr.Slider(label="Threshold", minimum=0, maximum=1, step=0.001,
+                            value=si.cap_threshold, info="Max allowed LF/HF ratio drop. 0.35=35%.")
+             cf = gr.Slider(label="Factor", minimum=0, maximum=1, step=0.001,
+                            value=si.cap_factor, info="Relaxation factor. 0.6=moderate.")
+             cm = gr.Radio(label="Mode", choices=["adaptive", "fixed"],
+                           value=si.adaptive_cap_mode,
+                           info="adaptive: scales the factor with over-attenuation. fixed: always cap_factor.")
+
+     # Return exactly in _SF field order
+     return [bf, sf, bo, bw, sco, she, bm, bb, bsr, ber, ssr, ser, ft, srr, hfb, eac, ct, cf, cm]
+
+
+ class MegaFreeUScript(scripts.Script):
+
+     def title(self): return "Mega FreeU"
+     def show(self, is_img2img): return scripts.AlwaysVisible
+
+     def ui(self, is_img2img):
+         global_state.reload_presets()
+         pnames = list(global_state.all_presets.keys())
+         def_sis = global_state.all_presets[pnames[0]].stage_infos
+
+         with gr.Accordion(open=False, label="Mega FreeU"):
+
+             # Top bar
+             with gr.Row():
+                 enabled = gr.Checkbox(label="Enable Mega FreeU", value=False)
+                 version = gr.Dropdown(
+                     label="Version",
+                     choices=list(global_state.ALL_VERSIONS.keys()),
+                     value="Version 2",
+                     elem_id=self.elem_id("version"),
+                     info="V2=adaptive hidden-mean backbone. V1=flat multiplier.")
+
+             with gr.Row():
+                 preset_dd = gr.Dropdown(
+                     label="Preset", choices=pnames, value=pnames[0],
+                     allow_custom_value=True,
+                     elem_id=self.elem_id("preset_name"),
+                     info="Apply loads settings. A custom name enables Save. Delete auto-saves.")
+                 btn_apply = gr.Button("Apply", size="sm", elem_classes="tool")
+                 btn_save = gr.Button("Save", size="sm", elem_classes="tool")
+                 btn_refresh = gr.Button("Refresh", size="sm", elem_classes="tool")
+                 btn_delete = gr.Button("Delete", size="sm", elem_classes="tool")
+
+             # Global schedule
+             gr.HTML("<p style='font-size:.82em;color:#aaa;margin:6px 0 2px;'>Global Schedule</p>")
+             with gr.Row():
+                 start_r = gr.Slider(label="Start At", elem_id=self.elem_id("start_at_step"),
+                                     minimum=0, maximum=1, step=0.001, value=0)
+                 stop_r = gr.Slider(label="Stop At", elem_id=self.elem_id("stop_at_step"),
+                                    minimum=0, maximum=1, step=0.001, value=1)
+                 smooth = gr.Slider(label="Transition Smoothness",
+                                    elem_id=self.elem_id("transition_smoothness"),
+                                    minimum=0, maximum=1, step=0.001, value=0,
+                                    info="0=hard on/off. 1=smooth fade.")
+
+             # Box Multi-Scale (WAS FreeU_Advanced)
+             with gr.Accordion(open=False, label="Box Multi-Scale FFT (WAS FreeU_Advanced)"):
+                 gr.HTML("<p style='font-size:.8em;color:#888;'>Applied on top of the Box filter. Ignored in Gaussian mode.</p>")
+                 with gr.Row():
+                     ms_mode = gr.Dropdown(label="Multiscale Mode",
+                                           choices=list(global_state.MSCALES.keys()),
+                                           value="Default")
+                     ms_str = gr.Slider(label="Strength", minimum=0, maximum=1,
+                                        step=0.001, value=1.0)
+                 ov_scales = gr.Textbox(
+                     label="Override Scales (WAS format: radius_px, scale per line, # comments)",
+                     lines=3,
+                     placeholder="# Example custom scales:\n10, 1.5\n20, 0.8",
+                     value="")
+
+             with gr.Row():
+                 ch_thresh = gr.Slider(
+                     label="Channel Match Threshold (+-)",
+                     elem_id=self.elem_id("ch_thresh"),
+                     minimum=0, maximum=256, step=1, value=96,
+                     info="Stage channel tolerance. 96=standard (FreeU_B1B2.py default).")
+
+             # Per-stage accordions
+             flat_comps: List = []
+             for i in range(global_state.STAGES_COUNT):
+                 si = def_sis[i] if i < len(def_sis) else global_state.StageInfo()
+                 flat_comps.extend(_stage_ui(i, si, self.elem_id))
+
+             # Post-CFG Shift (WAS_PostCFGShift -> A1111)
+             with gr.Accordion(open=False, label="Post-CFG Shift (WAS_PostCFGShift -> A1111 callback)"):
+                 gr.HTML("<p style='font-size:.8em;color:#888;'>Runs after combine_denoised. Blends denoised*b into the output via the on_cfg_after_cfg callback.</p>")
+                 with gr.Row():
+                     pcfg_en = gr.Checkbox(label="Enable Post-CFG Shift", value=False)
+                     pcfg_steps = gr.Slider(label="Max Steps", minimum=1, maximum=200,
+                                            step=1, value=20,
+                                            info="Apply only to the first N steps.")
+                 with gr.Row():
+                     pcfg_mode = gr.Dropdown(label="Blend Mode",
+                                             choices=global_state.BLEND_MODE_NAMES,
+                                             value="inject")
+                     pcfg_bl = gr.Slider(label="Blend", minimum=0, maximum=5,
+                                         step=0.001, value=1.0)
+                     pcfg_b = gr.Slider(label="B Factor", minimum=0, maximum=5,
+                                        step=0.001, value=1.1,
+                                        info=">1 amplifies the shift.")
+                 with gr.Row():
+                     pcfg_fou = gr.Checkbox(label="Apply Fourier Filter", value=False)
+                     pcfg_mmd = gr.Dropdown(label="Fourier Multiscale",
+                                            choices=list(global_state.MSCALES.keys()),
+                                            value="Default")
+                     pcfg_mst = gr.Slider(label="Fourier Strength", minimum=0, maximum=1,
+                                          step=0.001, value=1.0)
+                 with gr.Row():
+                     pcfg_thr = gr.Slider(label="Threshold (px)", minimum=1, maximum=20,
+                                          step=1, value=1,
+                                          info="Box filter LF radius in pixels.")
+                     pcfg_s = gr.Slider(label="S Scale", minimum=0, maximum=3,
+                                        step=0.001, value=0.5)
+                     pcfg_gain = gr.Slider(label="Force Gain", minimum=0, maximum=5,
+                                           step=0.01, value=1.0,
+                                           info="Final output multiplier.")
+
+             verbose = gr.Checkbox(label="Verbose Logging (Adaptive Cap, energy stats)", value=False)
+
+             # Hidden PNG infotext components
+             sched_info = gr.HTML(visible=False)
+             stages_info = gr.HTML(visible=False)
+             version_info = gr.HTML(visible=False)
+             ms_mode_info = gr.HTML(visible=False)
+             ms_str_info = gr.HTML(visible=False)
+             ov_scales_info = gr.HTML(visible=False)
+             ch_thresh_info = gr.HTML(visible=False)
+             postcfg_info = gr.HTML(visible=False)
+             verbose_info = gr.HTML(visible=False)
+             # Legacy sd-webui-freeu keys for backward compat
+             legacy_sched_info = gr.HTML(visible=False)
+             legacy_stages_info = gr.HTML(visible=False)
+             legacy_version_info = gr.HTML(visible=False)
+
+             # Preset buttons
+             def _btn_upd(name):
+                 ex = name in global_state.all_presets
+                 usr = name not in global_state.default_presets
+                 return (gr.update(interactive=ex),
+                         gr.update(interactive=usr),
+                         gr.update(interactive=usr and ex))
+
+             preset_dd.change(fn=_btn_upd, inputs=[preset_dd],
+                              outputs=[btn_apply, btn_save, btn_delete])
+
+             def _apply_p(name):
+                 p = global_state.all_presets.get(name)
+                 n_extras = 20  # 8 main + 11 Post-CFG + 1 verbose
+                 if p is None:
+                     return [gr.skip()] * (n_extras + len(flat_comps))
+                 flat = []
+                 for si in p.stage_infos:
+                     for f in _SF:
+                         flat.append(getattr(si, f))
+                 vlabel = global_state.REVERSED_VERSIONS.get(p.version, "Version 2")
+                 return (
+                     gr.update(value=p.start_ratio),
+                     gr.update(value=p.stop_ratio),
+                     gr.update(value=p.transition_smoothness),
+                     gr.update(value=vlabel),
+                     gr.update(value=p.multiscale_mode),
+                     gr.update(value=p.multiscale_strength),
+                     gr.update(value=p.override_scales),
+                     gr.update(value=p.channel_threshold),
+                     gr.update(value=p.pcfg_enabled),
+                     gr.update(value=p.pcfg_steps),
+                     gr.update(value=p.pcfg_mode),
+                     gr.update(value=p.pcfg_blend),
+                     gr.update(value=p.pcfg_b),
+                     gr.update(value=p.pcfg_fourier),
+                     gr.update(value=p.pcfg_ms_mode),
+                     gr.update(value=p.pcfg_ms_str),
+                     gr.update(value=p.pcfg_threshold),
+                     gr.update(value=p.pcfg_s),
+                     gr.update(value=p.pcfg_gain),
+                     gr.update(value=p.verbose),
+                     *[gr.update(value=v) for v in flat],
+                 )
+
+             btn_apply.click(
+                 fn=_apply_p,
+                 inputs=[preset_dd],
+                 outputs=[
+                     start_r, stop_r, smooth, version,
+                     ms_mode, ms_str, ov_scales, ch_thresh,
+                     pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
+                     pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
+                     verbose,
+                     *flat_comps,
+                 ]
+             )
+
+             def _save_p(
+                     name, sr, sp, sm, ver, msm, mss, ovs, cht,
+                     p_en, p_steps, p_mode, p_bl, p_b,
+                     p_four, p_mmd, p_mst, p_thr, p_s, p_gain,
+                     v_log,
+                     *flat
+             ):
+                 sis = _flat_to_sis(flat)
+                 vc = global_state.ALL_VERSIONS.get(ver, "1")
+                 global_state.all_presets[name] = global_state.State(
+                     start_ratio=sr, stop_ratio=sp, transition_smoothness=sm,
+                     version=vc,
+                     multiscale_mode=msm,
+                     multiscale_strength=float(mss),
+                     override_scales=ovs or "",
+                     channel_threshold=int(cht),
+                     stage_infos=sis,
+                     pcfg_enabled=bool(p_en),
+                     pcfg_steps=int(p_steps),
+                     pcfg_mode=str(p_mode),
+                     pcfg_blend=float(p_bl),
+                     pcfg_b=float(p_b),
+                     pcfg_fourier=bool(p_four),
+                     pcfg_ms_mode=str(p_mmd),
+                     pcfg_ms_str=float(p_mst),
+                     pcfg_threshold=int(p_thr),
+                     pcfg_s=float(p_s),
+                     pcfg_gain=float(p_gain),
+                     verbose=bool(v_log),
+                 )
+                 global_state.save_presets()
+                 return (
+                     gr.update(choices=list(global_state.all_presets.keys())),
+                     gr.update(interactive=True),
+                     gr.update(interactive=True),
+                 )
+
+             btn_save.click(
+                 fn=_save_p,
+                 inputs=[
+                     preset_dd, start_r, stop_r, smooth, version,
+                     ms_mode, ms_str, ov_scales, ch_thresh,
+                     pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
+                     pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
+                     verbose,
+                     *flat_comps,
+                 ],
+                 outputs=[preset_dd, btn_apply, btn_delete]
+             )
+
+             def _refresh_p(name):
+                 global_state.reload_presets()
+                 ex = name in global_state.all_presets
+                 usr = name not in global_state.default_presets
+                 ch = list(global_state.all_presets.keys())
+                 return (gr.update(choices=ch, value=name),
+                         gr.update(interactive=ex), gr.update(interactive=usr),
+                         gr.update(interactive=usr and ex))
+
+             btn_refresh.click(fn=_refresh_p, inputs=[preset_dd],
+                               outputs=[preset_dd, btn_apply, btn_save, btn_delete])
+
+             def _delete_p(name):
+                 if name in global_state.all_presets and name not in global_state.default_presets:
+                     idx = list(global_state.all_presets.keys()).index(name)
+                     del global_state.all_presets[name]
+                     global_state.save_presets()
+                     names = list(global_state.all_presets.keys())
+                     name = names[min(idx, len(names) - 1)]
+                 ex = name in global_state.all_presets
+                 usr = name not in global_state.default_presets
+                 return (gr.update(choices=list(global_state.all_presets.keys()), value=name),
+                         gr.update(interactive=ex), gr.update(interactive=usr),
+                         gr.update(interactive=usr and ex))
+
+             btn_delete.click(fn=_delete_p, inputs=[preset_dd],
+                              outputs=[preset_dd, btn_apply, btn_save, btn_delete])
+
+             # PNG schedule restore
+             def _restore_sched(info, steps):
+                 if not info:
+                     return [gr.skip()] * 4
+                 try:
+                     parts = info.split(", ")
+                     sr, sp, sm = parts[0], parts[1], parts[2]
+                     total = max(int(float(steps)), 1)
+                     def _r(v):
+                         n = float(v.strip())
+                         return n / total if n > 1.0 else n
+                     return (gr.update(value=""), gr.update(value=_r(sr)),
+                             gr.update(value=_r(sp)), gr.update(value=float(sm)))
+                 except Exception:
+                     return [gr.skip()] * 4
+
+             def _reg_sched_cb(steps_comp):
+                 sched_info.change(fn=_restore_sched,
+                                   inputs=[sched_info, steps_comp],
+                                   outputs=[sched_info, start_r, stop_r, smooth])
+
+             mode_key = "img2img" if is_img2img else "txt2img"
+             if _steps_comps[mode_key] is None:
+                 _steps_cbs[mode_key].append(_reg_sched_cb)
+             else:
+                 _reg_sched_cb(_steps_comps[mode_key])
+
+             def _restore_stages(info):
+                 n_out = 2 + len(flat_comps)
+                 if not info:
+                     return [gr.skip()] * n_out
+                 try:
+                     raw_list = json.loads(info)
+                     sis = []
+                     for d in raw_list:
+                         known = {k: v for k, v in d.items()
+                                  if k in global_state.STAGE_FIELD_NAMES}
+                         sis.append(global_state.StageInfo(**known))
+                     while len(sis) < global_state.STAGES_COUNT:
+                         sis.append(global_state.StageInfo())
+                 except Exception:
+                     return [gr.skip()] * n_out
+                 flat = []
+                 for si in sis:
+                     for f in _SF:
+                         flat.append(getattr(si, f, getattr(global_state.StageInfo(), f)))
+                 auto_en = shared.opts.data.get("mega_freeu_png_auto_enable", True)
+                 return (gr.update(value=""), gr.update(value=auto_en),
+                         *[gr.update(value=v) for v in flat])
+
+             stages_info.change(fn=_restore_stages, inputs=[stages_info],
+                                outputs=[stages_info, enabled, *flat_comps])
+
+             def _restore_ver(info):
+                 if not info:
+                     return [gr.skip()] * 2
+                 lbl = global_state.REVERSED_VERSIONS.get(info.strip(), info.strip())
+                 return gr.update(value=""), gr.update(value=lbl)
+
+             version_info.change(fn=_restore_ver, inputs=[version_info],
+                                 outputs=[version_info, version])
+
+             # ── New extended PNG restore callbacks ─────────────────────────────
+             def _restore_ms_mode(info):
+                 if not info:
+                     return gr.skip(), gr.skip()
+                 return gr.update(value=""), gr.update(value=info.strip())
+
+             def _restore_ms_str(info):
+                 if not info:
+                     return gr.skip(), gr.skip()
+                 try:
+                     return gr.update(value=""), gr.update(value=float(info.strip()))
+                 except Exception:
+                     return gr.skip(), gr.skip()
+
+             def _restore_ov_scales(info):
+                 if info is None:
+                     return gr.skip(), gr.skip()
+                 return gr.update(value=""), gr.update(value=info)
+
+             def _restore_ch_thresh(info):
+                 if not info:
+                     return gr.skip(), gr.skip()
+                 try:
+                     return gr.update(value=""), gr.update(value=int(float(info.strip())))
+                 except Exception:
+                     return gr.skip(), gr.skip()
+
+             def _restore_verbose(info):
+                 if not info:
+                     return gr.skip(), gr.skip()
+                 return gr.update(value=""), gr.update(value=(info.strip().lower() == "true"))
+
+             def _restore_postcfg(info):
+                 n = 12
+                 if not info:
+                     return [gr.skip()] * n
+                 try:
+                     d = json.loads(info)
+                     return (
+                         gr.update(value=""),
+                         gr.update(value=bool(d.get("enabled", False))),
+                         gr.update(value=int(d.get("steps", 20))),
+                         gr.update(value=str(d.get("mode", "inject"))),
+                         gr.update(value=float(d.get("blend", 1.0))),
+                         gr.update(value=float(d.get("b", 1.1))),
+                         gr.update(value=bool(d.get("fourier", False))),
+                         gr.update(value=str(d.get("ms_mode", "Default"))),
+                         gr.update(value=float(d.get("ms_str", 1.0))),
+                         gr.update(value=int(d.get("threshold", 1))),
+                         gr.update(value=float(d.get("s", 0.5))),
+                         gr.update(value=float(d.get("gain", 1.0))),
+                     )
+                 except Exception:
+                     return [gr.skip()] * n
+
+             ms_mode_info.change(fn=_restore_ms_mode, inputs=[ms_mode_info],
+                                 outputs=[ms_mode_info, ms_mode])
+             ms_str_info.change(fn=_restore_ms_str, inputs=[ms_str_info],
+                                outputs=[ms_str_info, ms_str])
+             ov_scales_info.change(fn=_restore_ov_scales, inputs=[ov_scales_info],
+                                   outputs=[ov_scales_info, ov_scales])
+             ch_thresh_info.change(fn=_restore_ch_thresh, inputs=[ch_thresh_info],
+                                   outputs=[ch_thresh_info, ch_thresh])
+             verbose_info.change(fn=_restore_verbose, inputs=[verbose_info],
+                                 outputs=[verbose_info, verbose])
+             postcfg_info.change(fn=_restore_postcfg, inputs=[postcfg_info],
+                                 outputs=[postcfg_info,
+                                          pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
+                                          pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain])
+
+             # Legacy sd-webui-freeu keys — reuse the same restore logic
+             legacy_sched_info.change(fn=lambda info, steps: _restore_sched(info, steps),
+                                      inputs=[legacy_sched_info, _steps_comps.get(mode_key) or legacy_sched_info],
497
+ outputs=[legacy_sched_info, start_r, stop_r, smooth])
498
+ legacy_stages_info.change(fn=_restore_stages, inputs=[legacy_stages_info],
499
+ outputs=[legacy_stages_info, enabled, *flat_comps])
500
+ legacy_version_info.change(fn=_restore_ver, inputs=[legacy_version_info],
501
+ outputs=[legacy_version_info, version])
502
+
503
+ self.infotext_fields = [
504
+ (sched_info, "MegaFreeU Schedule"),
505
+ (stages_info, "MegaFreeU Stages"),
506
+ (version_info, "MegaFreeU Version"),
507
+ (ms_mode_info, "MegaFreeU Multiscale Mode"),
508
+ (ms_str_info, "MegaFreeU Multiscale Strength"),
509
+ (ov_scales_info, "MegaFreeU Override Scales"),
510
+ (ch_thresh_info, "MegaFreeU Channel Threshold"),
511
+ (postcfg_info, "MegaFreeU PostCFG"),
512
+ (verbose_info, "MegaFreeU Verbose"),
513
+ # Backward compat with sd-webui-freeu generated PNGs
514
+ (legacy_sched_info, "FreeU Schedule"),
515
+ (legacy_stages_info, "FreeU Stages"),
516
+ (legacy_version_info,"FreeU Version"),
517
+ ]
518
+ self.paste_field_names = [f for _, f in self.infotext_fields]
519
+
520
+ return [
521
+ enabled, version, preset_dd,
522
+ start_r, stop_r, smooth,
523
+ ms_mode, ms_str, ov_scales,
524
+ ch_thresh,
525
+ *flat_comps,
526
+ pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
527
+ pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
528
+ verbose,
529
+ ]
530
+
531
+ def process(self, p: processing.StableDiffusionProcessing, *args):
532
+ # ── Branch 1: old sd-webui-freeu API (dict passed as first arg) ───────
533
+ if args and isinstance(args[0], dict):
534
+ global_state.instance = global_state.State(**{
535
+ k: v for k, v in args[0].items()
536
+ if k in {f.name for f in dataclasses.fields(global_state.State)}
537
+ })
538
+ global_state.apply_xyz()
539
+ global_state.xyz_attrs.clear()
540
+ st = global_state.instance
541
+ unet.verbose_ref.value = bool(getattr(st, "verbose", False))
542
+ if getattr(st, "pcfg_enabled", False):
543
+ p._mega_pcfg = {
544
+ "enabled": True,
545
+ "steps": st.pcfg_steps,
546
+ "mode": st.pcfg_mode,
547
+ "blend": st.pcfg_blend,
548
+ "b": st.pcfg_b,
549
+ "fourier": st.pcfg_fourier,
550
+ "ms_mode": st.pcfg_ms_mode,
551
+ "ms_str": st.pcfg_ms_str,
552
+ "threshold": st.pcfg_threshold,
553
+ "s": st.pcfg_s,
554
+ "gain": st.pcfg_gain,
555
+ "step": 0,
556
+ }
557
+ else:
558
+ p._mega_pcfg = {"enabled": False}
559
+ if st.enable:
560
+ unet.detect_model_channels()
561
+ unet._on_cpu_devices.clear()
562
+ _write_generation_params(p, st)
563
+ return
564
+
565
+ # ── Branch 2: normal UI call ───────────────────────────────────────────
566
+ (enabled, version, preset_dd,
567
+ start_r, stop_r, smooth,
568
+ ms_mode, ms_str, ov_scales,
569
+ ch_thresh, *rest) = args
570
+
571
+ n_sv = _SN * global_state.STAGES_COUNT
572
+ flat_stage = rest[:n_sv]
573
+ post = rest[n_sv:] # 11 pcfg params + verbose
574
+
575
+ verbose = bool(post[11]) if len(post) > 11 else False
576
+ unet.verbose_ref.value = verbose
577
+
578
+ # Write UI values into instance BEFORE apply_xyz so XYZ can override any of them
579
+ inst = global_state.instance
580
+ inst.enable = bool(enabled)
581
+ inst.start_ratio = start_r
582
+ inst.stop_ratio = stop_r
583
+ inst.transition_smoothness = smooth
584
+ inst.version = global_state.ALL_VERSIONS.get(version, "1")
585
+ inst.multiscale_mode = ms_mode
586
+ inst.multiscale_strength = float(ms_str)
587
+ inst.override_scales = ov_scales or ""
588
+ inst.channel_threshold = int(ch_thresh)
589
+ inst.stage_infos = _flat_to_sis(flat_stage)
590
+
591
+ # Sync Post-CFG into instance state so presets/PNG capture it
592
+ pcfg = post[:11]
593
+ if len(pcfg) >= 11:
594
+ inst.pcfg_enabled = bool(pcfg[0])
595
+ inst.pcfg_steps = int(pcfg[1])
596
+ inst.pcfg_mode = str(pcfg[2])
597
+ inst.pcfg_blend = float(pcfg[3])
598
+ inst.pcfg_b = float(pcfg[4])
599
+ inst.pcfg_fourier = bool(pcfg[5])
600
+ inst.pcfg_ms_mode = str(pcfg[6])
601
+ inst.pcfg_ms_str = float(pcfg[7])
602
+ inst.pcfg_threshold = int(pcfg[8])
603
+ inst.pcfg_s = float(pcfg[9])
604
+ inst.pcfg_gain = float(pcfg[10])
605
+ inst.verbose = verbose
606
+
607
+ # apply_xyz() may replace global_state.instance with a preset copy;
608
+ # take the fresh reference AFTER so PNG metadata / verbose use the final state.
609
+ global_state.apply_xyz()
610
+ global_state.xyz_attrs.clear()
611
+ st = global_state.instance # ← fresh ref post-XYZ
612
+
613
+ # ── Post-CFG: set up ALWAYS (independent of main Enable) ──────────────
614
+ if st.pcfg_enabled:
615
+ p._mega_pcfg = {
616
+ "enabled": True,
617
+ "steps": st.pcfg_steps,
618
+ "mode": st.pcfg_mode,
619
+ "blend": st.pcfg_blend,
620
+ "b": st.pcfg_b,
621
+ "fourier": st.pcfg_fourier,
622
+ "ms_mode": st.pcfg_ms_mode,
623
+ "ms_str": st.pcfg_ms_str,
624
+ "threshold": st.pcfg_threshold,
625
+ "s": st.pcfg_s,
626
+ "gain": st.pcfg_gain,
627
+ "step": 0,
628
+ }
629
+ else:
630
+ p._mega_pcfg = {"enabled": False}
631
+
632
+ if not st.enable:
633
+ # Write partial params so PNG records the session even when disabled
634
+ _write_generation_params(p, st)
635
+ return
636
+
637
+ unet.detect_model_channels()
638
+ unet._on_cpu_devices.clear()
639
+
640
+ _write_generation_params(p, st)
641
+
642
+ if unet.verbose_ref.value:
643
+ print(f"[MegaFreeU] v{st.version} "
644
+ f"start={st.start_ratio:.3f} stop={st.stop_ratio:.3f} "
645
+ f"smooth={st.transition_smoothness:.3f} "
646
+ f"ch_thresh=+-{st.channel_threshold}")
647
+ for i, si in enumerate(st.stage_infos):
648
+ ch = unet._stage_channels[i] if i < len(unet._stage_channels) else "?"
649
+ print(f" Stage {i+1} ({ch}ch): "
650
+ f"b={si.backbone_factor:.3f} [{si.b_start_ratio:.2f}-{si.b_end_ratio:.2f}] "
651
+ f"{si.backbone_blend_mode}:{si.backbone_blend:.2f} "
652
+ f"s={si.skip_factor:.3f} [{si.s_start_ratio:.2f}-{si.s_end_ratio:.2f}] "
653
+ f"fft={si.fft_type} r={si.fft_radius_ratio:.3f} "
654
+ f"hfe={si.skip_high_end_factor:.2f} hfb={si.hf_boost:.2f} "
655
+ f"cap={'ON' if si.enable_adaptive_cap else 'off'} "
656
+ f"({si.cap_threshold:.2f}/{si.cap_factor:.2f} {si.adaptive_cap_mode})")
657
+
658
+ def process_batch(self, p, *args, **kwargs):
659
+ global_state.current_sampling_step = 0
660
+ # FIX: reset PostCFG step counter for each image in batch
661
+ if hasattr(p, "_mega_pcfg"):
662
+ p._mega_pcfg["step"] = 0
663
+
664
+ def postprocess(self, p, processed, *args, **kwargs):
665
+ """Clean up per-image state after generation."""
666
+ if hasattr(p, "_mega_pcfg"):
667
+ p._mega_pcfg = {"enabled": False}
668
+
669
+
670
+ def _write_generation_params(p, st):
671
+ """Write full Mega FreeU state into PNG extra_generation_params."""
672
+ p.extra_generation_params["MegaFreeU Schedule"] = (
673
+ f"{st.start_ratio}, {st.stop_ratio}, {st.transition_smoothness}")
674
+ p.extra_generation_params["MegaFreeU Stages"] = (
675
+ json.dumps([si.to_dict() for si in st.stage_infos]))
676
+ p.extra_generation_params["MegaFreeU Version"] = st.version
677
+ p.extra_generation_params["MegaFreeU Multiscale Mode"] = st.multiscale_mode
678
+ p.extra_generation_params["MegaFreeU Multiscale Strength"] = str(st.multiscale_strength)
679
+ p.extra_generation_params["MegaFreeU Override Scales"] = st.override_scales or ""
680
+ p.extra_generation_params["MegaFreeU Channel Threshold"] = str(st.channel_threshold)
681
+ p.extra_generation_params["MegaFreeU Verbose"] = str(st.verbose)
682
+ if st.pcfg_enabled:
683
+ p.extra_generation_params["MegaFreeU PostCFG"] = json.dumps({
684
+ "enabled": st.pcfg_enabled,
685
+ "steps": st.pcfg_steps,
686
+ "mode": st.pcfg_mode,
687
+ "blend": st.pcfg_blend,
688
+ "b": st.pcfg_b,
689
+ "fourier": st.pcfg_fourier,
690
+ "ms_mode": st.pcfg_ms_mode,
691
+ "ms_str": st.pcfg_ms_str,
692
+ "threshold": st.pcfg_threshold,
693
+ "s": st.pcfg_s,
694
+ "gain": st.pcfg_gain,
695
+ })
696
+
697
+
698
+ def _flat_to_sis(flat) -> List[global_state.StageInfo]:
699
+ result = []
700
+ for i in range(global_state.STAGES_COUNT):
701
+ chunk = flat[i * _SN:(i + 1) * _SN]
702
+ si = global_state.StageInfo()
703
+ for j, fname in enumerate(_SF):
704
+ if j < len(chunk):
705
+ setattr(si, fname, chunk[j])
706
+ result.append(si)
707
+ return result
708
+
709
+
710
+ # Callbacks
711
+ def _on_cfg_step(*_args, **_kwargs):
712
+ global_state.current_sampling_step += 1
713
+
714
+ def _on_cfg_post(params):
715
+ """WAS_PostCFGShift ported to A1111 on_cfg_after_cfg callback (exact algorithm)."""
716
+ p = getattr(params, "p", None)
717
+ if p is None:
718
+ p = getattr(getattr(params, "denoiser", None), "p", None)
719
+ if p is None: return
720
+ cfg = getattr(p, "_mega_pcfg", None)
721
+ if not cfg or not cfg.get("enabled"): return
722
+ cfg["step"] = cfg.get("step", 0) + 1
723
+ if cfg["step"] > cfg["steps"]: return
724
+ x = params.x
725
+ fn = unet.BLENDING_MODES.get(cfg["mode"], unet.BLENDING_MODES["inject"])
726
+ y = fn(x, x * cfg["b"], cfg["blend"])
727
+ if cfg["fourier"]:
728
+ ms = global_state.MSCALES.get(cfg["ms_mode"])
729
+ y = unet.filter_skip_box_multiscale(
730
+ y, cfg["threshold"], cfg["s"], ms, cfg["ms_str"])
731
+ if cfg["gain"] != 1.0:
732
+ y = y * float(cfg["gain"])
733
+ params.x = y
734
+
735
+ try:
736
+ script_callbacks.on_cfg_after_cfg(_on_cfg_step)
737
+ script_callbacks.on_cfg_after_cfg(_on_cfg_post)
738
+ except AttributeError:
739
+ # webui < 1.6.0 (sd-webui-freeu compatibility note)
740
+ script_callbacks.on_cfg_denoised(_on_cfg_step)
741
+ script_callbacks.on_cfg_denoised(_on_cfg_post)
742
+
743
+ def _on_after_component(component, **kwargs):
744
+ eid = kwargs.get("elem_id", "")
745
+ for key, sid in [("txt2img", "txt2img_steps"), ("img2img", "img2img_steps")]:
746
+ if eid == sid:
747
+ _steps_comps[key] = component
748
+ for cb in _steps_cbs[key]: cb(component)
749
+ _steps_cbs[key].clear()
750
+
751
+ script_callbacks.on_after_component(_on_after_component)
752
+
753
+ def _on_ui_settings():
754
+ shared.opts.add_option(
755
+ "mega_freeu_png_auto_enable",
756
+ shared.OptionInfo(
757
+ default=True,
758
+ label="Auto-enable Mega FreeU when loading PNG info from a FreeU generation",
759
+ section=("mega_freeu", "Mega FreeU")))
760
+
761
+ script_callbacks.on_ui_settings(_on_ui_settings)
762
+ script_callbacks.on_before_ui(xyz_grid.patch)
763
+
764
+ # Install th.cat patch at import (sd-webui-freeu pattern)
765
+ unet.patch()
mega_freeu_a1111/tests/README.md ADDED
@@ -0,0 +1,31 @@
+ # Mega FreeU — Test Suite
+
+ Tests run without A1111 or PyTorch installed. A NumPy-based torch mock is used instead.
+
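The mock works by registering a module object under `sys.modules['torch']` before any library code is imported, so a later `import torch` resolves to the NumPy shim. A minimal standalone sketch of that pattern (names simplified; the real shim is `tests/mock_torch.py`):

```python
import sys, types
import numpy as np

# Build a throwaway module object with just enough torch-like surface for the demo.
fake_torch = types.ModuleType("torch")
fake_torch.tensor = lambda v: np.asarray(v, dtype=np.float32)

# Register it BEFORE the code under test runs; `import torch` now resolves here.
sys.modules["torch"] = fake_torch

import torch  # resolves to the mock above

x = torch.tensor([1.0, 2.0])
print(type(x).__name__)  # ndarray
```

Because `sys.modules` is consulted before the import machinery searches the filesystem, the assignment must happen before the first `import torch` anywhere in the process.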
+ ## Requirements
+
+ ```
+ pip install numpy
+ ```
+
+ ## Running
+
+ ```bash
+ # From the extension root:
+ python tests/test_core.py        # 144 tests — math, filters, blending, schedule, state
+ python tests/test_fixes.py       # 36 tests — dict-API, PNG metadata, Post-CFG, XYZ
+ python tests/test_preset_pcfg.py # 32 tests — full preset save/apply round-trip
+
+ # Or run all at once:
+ for f in tests/test_core.py tests/test_fixes.py tests/test_preset_pcfg.py; do
+     python "$f" && echo "--- $f PASSED ---" || echo "--- $f FAILED ---"
+ done
+ ```
+
+ ## What is covered
+
+ | File | Tests | Coverage |
+ |------|-------|----------|
+ | `test_core.py` | 144 | `ratio_to_region`, `lerp`, `get_backbone_scale` (V1+V2), `filter_skip_box`, `fourier_filter_gauss`, `get_band_energy_stats`, `filter_skip_box_multiscale`, all 9 blending modes, `parse_override_scales`, `_normalize`, `get_schedule_ratio`, `get_stage_bsratio`, `StageInfo`/`State` dataclasses, `update_attr`, `_load_user_presets`, `filter_skip_gaussian_adaptive` (no-cap/cap/fixed/aggressive), backbone blend math, `apply_xyz`, `detect_model_channels`, `_flat_to_sis`, PostCFG step counter |
+ | `test_fixes.py` | 36 | `State` pcfg/verbose fields + round-trip, `_load_user_presets` with pcfg, `_write_generation_params` PNG keys, Post-CFG independent of Enable, dict-API compat |
+ | `test_preset_pcfg.py` | 32 | Full preset save/apply with all 20 fields, pcfg disabled case, unknown preset, dict-API pcfg propagation |
mega_freeu_a1111/tests/mock_torch.py ADDED
@@ -0,0 +1,229 @@
+ import math, types, sys
+ import numpy as np
+
+ class Tensor:
+     def __init__(self, data, device='cpu', dtype=None):
+         if isinstance(data, np.ndarray):
+             self._d = data
+         elif isinstance(data, Tensor):
+             self._d = data._d.copy()
+         else:
+             self._d = np.array(data, dtype=np.float32)
+         self.device = device
+         self.dtype = dtype or 'float32'
+
+     @property
+     def shape(self): return self._d.shape
+     def size(self, dim=None): return self.shape[dim] if dim is not None else self.shape
+     @property
+     def is_cpu(self): return True
+
+     # dtype / device
+     def float(self):
+         return Tensor(self._d.astype(np.float32), self.device, 'float32')
+     def to(self, *a, **kw): return self
+     def cpu(self): return Tensor(self._d.copy(), 'cpu', self.dtype)
+     def item(self): return float(self._d.flat[0])
+
+     # shape ops
+     def view(self, *shape):
+         s = shape[0] if len(shape)==1 and isinstance(shape[0],(tuple,list)) else shape
+         return Tensor(self._d.reshape(s), self.device, self.dtype)
+     def reshape(self, *s): return self.view(*s)
+     def unsqueeze(self, dim):
+         return Tensor(np.expand_dims(self._d, dim), self.device, self.dtype)
+     def squeeze(self, dim=None):
+         d = self._d.squeeze() if dim is None else self._d.squeeze(dim)
+         return Tensor(d, self.device, self.dtype)
+     def expand_as(self, other):
+         return Tensor(np.broadcast_to(self._d, other.shape).copy(), self.device, self.dtype)
+
+     # .real / .imag — real tensors just return themselves
+     @property
+     def real(self): return Tensor(self._d.real.astype(np.float32), self.device, self.dtype)
+     @property
+     def imag(self): return Tensor(np.zeros_like(self._d.real, dtype=np.float32), self.device, self.dtype)
+
+     # reductions
+     def mean(self, dim=None, keepdim=False):
+         if dim is None: return Tensor(np.array(self._d.mean(), np.float32), self.device)
+         return Tensor(self._d.mean(axis=dim, keepdims=keepdim).astype(np.float32), self.device)
+     def sum(self, dim=None, keepdim=False):
+         if dim is None: return Tensor(np.array(self._d.sum(), np.float32), self.device)
+         return Tensor(self._d.sum(axis=dim, keepdims=keepdim).astype(np.float32), self.device)
+     def any(self): return bool(self._d.any())
+     def clamp(self, lo=None, hi=None):
+         return Tensor(np.clip(self._d, lo, hi).astype(np.float32), self.device, self.dtype)
+     def clamp_min(self, v):
+         return Tensor(np.maximum(self._d, v).astype(np.float32), self.device, self.dtype)
+     def abs(self): return Tensor(np.abs(self._d).astype(np.float32), self.device, self.dtype)
+     def min(self, dim=None, keepdim=False):
+         if dim is None: return Tensor(np.array(self._d.min(), np.float32), self.device)
+         v=self._d.min(axis=dim,keepdims=keepdim); i=self._d.argmin(axis=dim)
+         return Tensor(v.astype(np.float32),self.device), Tensor(i.astype(np.float32),self.device)
+     def max(self, dim=None, keepdim=False):
+         if dim is None: return Tensor(np.array(self._d.max(), np.float32), self.device)
+         v=self._d.max(axis=dim,keepdims=keepdim); i=self._d.argmax(axis=dim)
+         return Tensor(v.astype(np.float32),self.device), Tensor(i.astype(np.float32),self.device)
+
+     # indexing
+     def __getitem__(self, idx):
+         if isinstance(idx, Tensor):
+             idx = idx._d.astype(bool)
+         elif isinstance(idx, tuple):
+             idx = tuple(i._d.astype(bool) if isinstance(i, Tensor) else i for i in idx)
+         return Tensor(self._d[idx], self.device, self.dtype)
+     def __setitem__(self, idx, val):
+         self._d[idx] = val._d if isinstance(val, Tensor) else val
+
+     # arithmetic helpers
+     def _v(self, o): return o._d if isinstance(o, Tensor) else o
+     def __add__(self, o): return Tensor((self._d + self._v(o)).astype(np.float32), self.device)
+     def __radd__(self, o): return self.__add__(o)
+     def __sub__(self, o): return Tensor((self._d - self._v(o)).astype(np.float32), self.device)
+     def __rsub__(self, o): return Tensor((self._v(o) - self._d).astype(np.float32), self.device)
+     def __mul__(self, o): return Tensor((self._d * self._v(o)).astype(np.float32), self.device)
+     def __rmul__(self, o): return self.__mul__(o)
+     def __truediv__(self, o): return Tensor((self._d / self._v(o)).astype(np.float32), self.device)
+     def __rtruediv__(self, o): return Tensor((self._v(o) / self._d).astype(np.float32), self.device)
+     def __neg__(self): return Tensor(-self._d, self.device, self.dtype)
+     def __pow__(self, n): return Tensor(self._d**n, self.device, self.dtype)
+     def __imul__(self, o): self._d = (self._d * self._v(o)).astype(np.float32); return self
+     def __abs__(self): return self.abs()
+
+     # comparisons → float32 0/1 tensor (bool-compatible)
+     def __le__(self, o): return Tensor((self._d <= self._v(o)).astype(np.float32), self.device)
+     def __lt__(self, o): return Tensor((self._d < self._v(o)).astype(np.float32), self.device)
+     def __ge__(self, o): return Tensor((self._d >= self._v(o)).astype(np.float32), self.device)
+     def __gt__(self, o): return Tensor((self._d > self._v(o)).astype(np.float32), self.device)
+     def __eq__(self, o): return bool(np.array_equal(self._d, self._v(o))) if isinstance(o, (Tensor,np.ndarray)) else Tensor((self._d == o).astype(np.float32), self.device)
+     def __invert__(self):  # ~bool_tensor
+         return Tensor((~self._d.astype(bool)).astype(np.float32), self.device)
+     def __bool__(self): return bool(self._d.flat[0])
+     def __repr__(self): return f"Tensor({self._d.shape})"
+
+
+ # ComplexTensor — returned by FFT ops
+ class CTensor(Tensor):
+     def __init__(self, data, device='cpu', dtype=None):
+         super().__init__(data, device, dtype)
+     @property
+     def real(self):
+         return Tensor(self._d.real.astype(np.float32), self.device, self.dtype)
+     @property
+     def imag(self):
+         return Tensor(self._d.imag.astype(np.float32), self.device, self.dtype)
+     def __mul__(self, o):
+         v = o._d if isinstance(o, Tensor) else o
+         return CTensor((self._d * v), self.device)
+     def __imul__(self, o):
+         v = o._d if isinstance(o, Tensor) else o
+         self._d = self._d * v
+         return self
+
+
+ # ── torch namespace ──────────────────────────────────────────────────────────
+ torch = types.ModuleType('torch')
+ torch.Tensor = Tensor
+ torch.float32 = 'float32'
+ torch.device = lambda s: s
+
+ torch.full = lambda shape, v, device='cpu', **kw: Tensor(np.full(shape, float(v), np.float32), device)
+ torch.ones = lambda *sh, device='cpu', **kw: Tensor(np.ones(sh[0] if len(sh)==1 else sh, np.float32), device)
+ torch.zeros = lambda *sh, device='cpu', **kw: Tensor(np.zeros(sh[0] if len(sh)==1 else sh, np.float32), device)
+ torch.tensor = lambda v, device='cpu', dtype=None, **kw: Tensor(np.array(v, dtype=np.float32), device)
+ torch.arange = lambda n, device='cpu', dtype=None, **kw: Tensor(np.arange(n, dtype=np.float32), device)
+
+ def _meshgrid(y, x, indexing='ij'):
+     yy, xx = np.meshgrid(y._d, x._d, indexing=indexing)
+     return Tensor(yy.astype(np.float32), y.device), Tensor(xx.astype(np.float32), x.device)
+ torch.meshgrid = _meshgrid
+
+ torch.exp = lambda x: Tensor(np.exp(np.clip(x._d,-500,500)).astype(np.float32), x.device)
+ torch.sin = lambda x: Tensor(np.sin(x._d).astype(np.float32), x.device)
+ torch.cos = lambda x: Tensor(np.cos(x._d).astype(np.float32), x.device)
+ torch.acos = lambda x: Tensor(np.arccos(np.clip(x._d,-1+1e-7,1-1e-7)).astype(np.float32), x.device)
+ torch.abs = lambda x: Tensor(np.abs(x._d).astype(np.float32), x.device)
+ torch.sqrt = lambda x: Tensor(np.sqrt(np.maximum(x._d,0)).astype(np.float32), x.device)
+ torch.norm = lambda x, dim=None, keepdim=False, **kw: Tensor(
+     np.linalg.norm(x._d, axis=dim, keepdims=keepdim).astype(np.float32), x.device)
+
+ def _max(x, dim=None, keepdim=False):
+     if dim is None: return float(x._d.max())
+     v = x._d.max(axis=dim, keepdims=keepdim)
+     i = x._d.argmax(axis=dim)
+     return Tensor(v.astype(np.float32), x.device), Tensor(i.astype(np.float32), x.device)
+ def _min(x, dim=None, keepdim=False):
+     if dim is None: return float(x._d.min())
+     v = x._d.min(axis=dim, keepdims=keepdim)
+     i = x._d.argmin(axis=dim)
+     return Tensor(v.astype(np.float32), x.device), Tensor(i.astype(np.float32), x.device)
+ torch.max = _max
+ torch.min = _min
+
+ def _where(c, a, b):
+     cd = c._d.astype(bool) if isinstance(c, Tensor) else np.array(c, bool)
+     ad = a._d if isinstance(a, Tensor) else a
+     bd = b._d if isinstance(b, Tensor) else b
+     return Tensor(np.where(cd, ad, bd).astype(np.float32),
+                   (a if isinstance(a,Tensor) else b).device)
+ torch.where = _where
+
+ class _Linalg:
+     @staticmethod
+     def norm(x, dim=None, keepdim=False, **kw):
+         return Tensor(
+             np.linalg.norm(x._d, axis=dim, keepdims=keepdim).astype(np.float32),
+             x.device)
+ torch.linalg = _Linalg
+
+ class _FFT:
+     @staticmethod
+     def fftn(x, dim=None):
+         ax = tuple(dim) if dim is not None else None
+         return CTensor(np.fft.fftn(x._d.astype(complex), axes=ax), x.device)
+     @staticmethod
+     def ifftn(x, dim=None):
+         ax = tuple(dim) if dim is not None else None
+         r = np.fft.ifftn(x._d, axes=ax).real.astype(np.float32)
+         # Return a plain Tensor but with .real property (already float32)
+         return Tensor(r, x.device)
+     @staticmethod
+     def fftshift(x, dim=None):
+         ax = tuple(dim) if dim is not None else None
+         d = np.fft.fftshift(x._d, axes=ax)
+         return CTensor(d, x.device) if isinstance(x, CTensor) else Tensor(d, x.device)
+     @staticmethod
+     def ifftshift(x, dim=None):
+         ax = tuple(dim) if dim is not None else None
+         d = np.fft.ifftshift(x._d, axes=ax)
+         return CTensor(d, x.device) if isinstance(x, CTensor) else Tensor(d, x.device)
+ torch.fft = _FFT
+
+ class _Backends:
+     class mps:
+         @staticmethod
+         def is_available(): return False
+ torch.backends = _Backends
+
+ sys.modules['torch'] = torch
+
+ # ── modules mock ─────────────────────────────────────────────────────────────
+ class _State:
+     sampling_steps = 20
+ class _SharedObj:
+     state = _State()
+     class opts:
+         data = {}
+
+ mmod = types.ModuleType('modules')
+ mmod.shared = _SharedObj()
+ sys.modules['modules'] = mmod
+ for sub in ['modules.shared','modules.scripts','modules.processing',
+             'modules.script_callbacks','gradio']:
+     sys.modules[sub] = types.ModuleType(sub)
+
+ import logging
+ sys.modules['logging'] = logging
+ # expose modules.shared in the right place
+ import types as _t
mega_freeu_a1111/tests/test_core.py ADDED
@@ -0,0 +1,534 @@
1
+ import pathlib, sys
2
+ exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
3
+
4
+ import sys, math, types, dataclasses, json, tempfile, os, pathlib
5
+ import numpy as np
6
+ import pathlib; sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
7
+
8
+ # ── load ──────────────────────────────────────────────────────────────────────
9
+ import importlib.util
10
+
11
+ def load_mod(name, path):
12
+ spec = importlib.util.spec_from_file_location(name, path)
13
+ m = importlib.util.module_from_spec(spec)
14
+ sys.modules[name] = m
15
+ spec.loader.exec_module(m)
16
+ return m
17
+
18
+ lib_pkg = types.ModuleType('lib_mega_freeu')
19
+ sys.modules['lib_mega_freeu'] = lib_pkg
20
+
21
+ GS = load_mod('lib_mega_freeu.global_state',
22
+ str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
23
+ UN = load_mod('lib_mega_freeu.unet',
24
+ str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
25
+ print("Loaded OK\n")
26
+
27
+ # ── test helpers ──────────────────────────────────────────────────────────────
28
+ P=0; F=0; ERRS=[]
29
+
30
+ def ok(t): global P; P+=1; print(f" βœ“ {t}")
31
+ def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f" βœ— {t} {m}")
32
+ def chk(t, cond, m=""): ok(t) if cond else ng(t, m)
33
+ def near(t, got, want, tol=1e-4):
34
+ g = float(got) if hasattr(got,'item') else float(got)
35
+ (ok(t+f" ({g:.5g})") if abs(g-want)<=tol else ng(t, f"got {g:.5g} want {want}"))
36
+ def shp(t, tensor, expected):
37
+ chk(t+f" shape={tensor.shape}", tuple(tensor.shape)==tuple(expected), f"β‰ {expected}")
38
+
39
+ # ══════════════════════════════════════════════════════════════════════════════
40
+ print("═"*54)
41
+ print("1. RATIO_TO_REGION")
42
+ print("═"*54)
43
+ R = UN.ratio_to_region
44
+ s,e,inv = R(0.5, 0.0, 100)
45
+ chk("[0,50] no-inv", s==0 and e==50 and not inv, f"s={s},e={e},inv={inv}")
46
+ s,e,inv = R(0.5, 0.5, 100)
47
+ chk("[50,100] no-inv", s==50 and e==100 and not inv)
48
+ s,e,inv = R(0.7, 0.5, 100)
49
+ chk("width+offset>1 β†’ inverted", inv and s==20 and e==50, f"s={s},e={e},inv={inv}")
50
+ s,e,inv = R(-0.3, 0.5, 100) # negative width
51
+ chk("neg width no crash", True)
52
+ s,e,inv = R(0.0, 0.0, 100)
53
+ chk("zero width: s==e", s==e)
54
+ s,e,inv = R(1.0, 0.0, 100) # full width β†’ [0,100]
55
+ chk("full width [0,100]", s==0 and e==100 and not inv)
56
+
57
+ # ══════════════════════════════════════════════════════════════════════════════
58
+ print("\n2. LERP")
59
+ near("lerp(0,1,.5)", UN.lerp(0,1,.5), .5)
60
+ near("lerp t=0 β†’ a", UN.lerp(5,10,0), 5)
61
+ near("lerp t=1 β†’ b", UN.lerp(5,10,1), 10)
62
+ near("lerp(2,4,.25)", UN.lerp(2,4,.25), 2.5)
63
+ a=Tensor(np.array([[[[2.,4.]]]])); b=Tensor(np.array([[[[6.,8.]]]]))
64
+ r=UN.lerp(a,b,0.5)
65
+ chk("tensor lerp values", np.allclose(r._d, [[[[4.,6.]]]]))
66
+
67
+ # ══════════════════════════════════════════════════════════════════════════════
68
+ print("\n3. GET_BACKBONE_SCALE")
69
+ h = Tensor(np.random.randn(2,8,4,4).astype(np.float32))
70
+ # V1
71
+ v1 = UN.get_backbone_scale(h, 1.3, "1")
72
+ near("V1=1.3", v1 if isinstance(v1,(int,float)) else v1.item(), 1.3)
73
+ # V2
74
+ v2 = UN.get_backbone_scale(h, 1.4, "2")
75
+ shp("V2 shape=(2,1,4,4)", v2, (2,1,4,4))
76
+ chk("V2 vals in (0,2.5)", v2._d.min()>0 and v2._d.max()<2.5,
77
+ f"[{v2._d.min():.3f},{v2._d.max():.3f}]")
78
+ # V2 factor=1.0 β†’ all 1s
79
+ v2id = UN.get_backbone_scale(h, 1.0, "2")
80
+ chk("V2 factor=1β†’all-ones", np.allclose(v2id._d, 1.0, atol=1e-5))
81
+ # V2 factor=2.0 β†’ range [1,2]
82
+ v2f2 = UN.get_backbone_scale(h, 2.0, "2")
83
+ chk("V2 factor=2 β†’ [1,2]",
84
+ v2f2._d.min()>=1.0-1e-5 and v2f2._d.max()<=2.0+1e-5,
85
+ f"[{v2f2._d.min():.4f},{v2f2._d.max():.4f}]")
86
+ # zero input β†’ no nan/crash
87
+ hz = Tensor(np.zeros((1,8,4,4),np.float32))
88
+ v2z = UN.get_backbone_scale(hz, 1.5, "2")
89
+ chk("V2 zero input: no nan", not np.isnan(v2z._d).any())
90
+
91
+ # ══════════════════════════════════════════════════════════════════════════════
92
+ print("\n4. FILTER_SKIP_BOX")
93
+ x = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
94
+ # identity
95
+ out = UN.filter_skip_box(x, 0.5, 1.0, 1.0)
96
+ chk("box identity", np.allclose(out._d, x._d, atol=1e-4))
97
+ # scale=0
98
+ out0 = UN.filter_skip_box(x, 0.5, 0.0, 1.0)
99
+ shp("box s=0 shape", out0, (1,4,16,16))
100
+ chk("box s=0 β‰  input", not np.allclose(out0._d, x._d, atol=1e-3))
101
+ # scale_high=0
102
+ outh = UN.filter_skip_box(x, 0.5, 1.0, 0.0)
103
+ chk("box h=0 β‰  input", not np.allclose(outh._d, x._d, atol=1e-3))
104
+ # cutoff=0 edge
105
+ outc0 = UN.filter_skip_box(x, 0.0, 0.5, 1.0)
106
+ shp("box cutoff=0 shape", outc0, (1,4,16,16))
107
+ # int cutoff
108
+ outi = UN.filter_skip_box(x, 2, 0.7, 1.0)
109
+ shp("box int cutoff", outi, (1,4,16,16))
110
+ # batch > 1
111
+ xb = Tensor(np.random.randn(3,8,16,16).astype(np.float32))
112
+ outb = UN.filter_skip_box(xb, 0.3, 0.7, 1.0)
113
+ shp("box batch=3", outb, (3,8,16,16))
114
+ # no nan
115
+ chk("box no nan", not np.isnan(out0._d).any())
116
+
117
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n5. FOURIER_FILTER_GAUSS")
+ xg = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
+ outg = UN.fourier_filter_gauss(xg, 0.1, 0.8, 1.2)
+ shp("gauss shape", outg, (1,4,16,16))
+ chk("gauss s=.8 ≠ input", not np.allclose(outg._d, xg._d, atol=1e-3))
+ # identity
+ outgi = UN.fourier_filter_gauss(xg, 0.1, 1.0, 1.0)
+ chk("gauss identity ≈ input", np.allclose(outgi._d, xg._d, atol=1e-4),
+     f"maxdiff={abs(outgi._d-xg._d).max():.2e}")
+ # no nan
+ chk("gauss no nan", not np.isnan(outg._d).any())
+ # batch=3
+ xg3 = Tensor(np.random.randn(3,8,16,16).astype(np.float32))
+ outg3 = UN.fourier_filter_gauss(xg3, 0.08, 0.9, 1.0)
+ shp("gauss batch=3", outg3, (3,8,16,16))
+ # tiny radius (R=1)
+ outgt = UN.fourier_filter_gauss(xg, 0.01, 0.5, 1.0)
+ shp("gauss tiny R=1 shape", outgt, (1,4,16,16))
+ # large radius (R=max)
+ outgl = UN.fourier_filter_gauss(xg, 0.49, 0.5, 1.0)
+ shp("gauss large radius shape", outgl, (1,4,16,16))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n6. GET_BAND_ENERGY_STATS")
+ xe = Tensor(np.random.randn(2,4,16,16).astype(np.float32))
+ lf,hf,cov = UN.get_band_energy_stats(xe, 3)
+ chk("lf>0", lf>0)
+ chk("hf>0", hf>0)
+ chk("0<cov<100", 0<cov<100, f"cov={cov:.1f}")
+ # zeros → energies=0
+ xze = Tensor(np.zeros((1,4,8,8),np.float32))
+ lf0,hf0,_ = UN.get_band_energy_stats(xze, 2)
+ chk("zeros → lf=hf=0", lf0==0 and hf0==0)
+ # R=1 minimal → some coverage
+ lf1,hf1,cov1 = UN.get_band_energy_stats(xe, 1)
+ chk("R=1 cover>0", cov1>0)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n7. FILTER_SKIP_BOX_MULTISCALE")
+ xm = Tensor(np.random.randn(1,4,32,32).astype(np.float32))
+ # identity
+ outmi = UN.filter_skip_box_multiscale(xm, 0.3, 1.0, None, 1.0, 1.0)
+ chk("ms identity", np.allclose(outmi._d, xm._d, atol=1e-4))
+ # single-scale preset
+ outss = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, [(8,1.5)], 1.0)
+ shp("ms single-scale shape", outss, (1,4,32,32))
+ chk("ms single differs", not np.allclose(outss._d, xm._d, atol=1e-3))
+ # multi-scale preset
+ outms = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, [[(5,0.0),(15,1.0)]], 1.0)
+ shp("ms multi-scale shape", outms, (1,4,32,32))
+ # scale_high
+ outsh = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, None, 1.0, 1.5)
+ chk("ms scale_high=1.5 differs", not np.allclose(outsh._d, xm._d, atol=1e-3))
+ # no nan
+ chk("ms no nan", not np.isnan(outss._d).any())
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n8. BLENDING MODES")
+ a = Tensor(np.full((1,4,4,4), 2.0, np.float32))
+ b = Tensor(np.full((1,4,4,4), 4.0, np.float32))
+ for mname, fn in UN.BLENDING_MODES.items():
+     try:
+         out = fn(a, b, 0.5)
+         shp(f" {mname}", out, (1,4,4,4))
+         chk(f" {mname} no nan", not np.isnan(out._d).any())
+     except Exception as e:
+         ng(f" {mname} CRASHED", str(e))
+
+ lerp_fn = UN.BLENDING_MODES['lerp']
+ near("lerp t=0→a", lerp_fn(a,b,0)._d.mean(), 2.0)
+ near("lerp t=1→b", lerp_fn(a,b,1)._d.mean(), 4.0)
+ near("lerp t=.5→3", lerp_fn(a,b,.5)._d.mean(), 3.0)
+
+ inj = UN.BLENDING_MODES['inject']
+ near("inject a+b*.5=4", inj(a,b,.5)._d.mean(), 4.0)
+ near("inject t=0→a", inj(a,b,0)._d.mean(), 2.0)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n9. PARSE_OVERRIDE_SCALES")
+ P2 = UN.parse_override_scales
+ chk("None→None", P2(None) is None)
+ chk("empty→None", P2("") is None)
+ chk("comment-only→None", P2("# x\n! y\n// z") is None)
+ r = P2("10, 1.5\n20, 0.8")
+ chk("2 lines → 2 entries", r is not None and len(r)==2)
+ chk("entry[0]=(10,1.5)", r and r[0]==(10,1.5), str(r))
+ r2 = P2("# comment\n5, 2.0\n! skip\n15, 0.5")
+ chk("with comments → 2 entries", r2 and len(r2)==2)
+ r3 = P2("10,1.5\nbad_line\n20,0.8")  # malformed line skipped
+ chk("malformed skipped", r3 and len(r3)==2, str(r3))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n10. _NORMALIZE")
+ n = UN._normalize
+ chk("norm [0,3]: min=0 max=1",
+     abs(n(Tensor(np.array([[[[0.,1.,2.,3.]]]])))._d.min())<1e-5 and
+     abs(n(Tensor(np.array([[[[0.,1.,2.,3.]]]])))._d.max()-1.0)<1e-5)
+ chk("norm const: no crash/nan",
+     not np.isnan(n(Tensor(np.ones((1,1,4,4),np.float32)*5))._d).any())
+ chk("norm negatives: min=0 max=1",
+     abs(n(Tensor(np.array([[[[-2.,-1.,0.,1.,2.]]]])))._d.max()-1.0)<1e-5)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n11. GET_SCHEDULE_RATIO")
+ mmod.shared.state.sampling_steps = 20
+ GS.instance = GS.State(start_ratio=0.0, stop_ratio=1.0, transition_smoothness=0.0)
+ GS.current_sampling_step = 0
+ near("full range step=0 → 1", UN.get_schedule_ratio(), 1.0)
+ GS.current_sampling_step = 10
+ near("full range step=10 → 1", UN.get_schedule_ratio(), 1.0)
+ GS.current_sampling_step = 25
+ near("past stop → 0", UN.get_schedule_ratio(), 0.0)
+ GS.instance = GS.State(start_ratio=0.5, stop_ratio=1.0, transition_smoothness=0.0)
+ GS.current_sampling_step = 5
+ near("before start → 0", UN.get_schedule_ratio(), 0.0)
+ GS.current_sampling_step = 15
+ near("in range → 1", UN.get_schedule_ratio(), 1.0)
+ GS.instance = GS.State(start_ratio=0.5, stop_ratio=0.5, transition_smoothness=0.0)
+ near("start==stop → 0", UN.get_schedule_ratio(), 0.0)
+ GS.instance = GS.State(start_ratio=0.0, stop_ratio=1.0, transition_smoothness=1.0)
+ GS.current_sampling_step = 0
+ r = UN.get_schedule_ratio()
+ chk("smoothness=1 step=0: 0≤r≤1", 0.0<=r<=1.0)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n12. GET_STAGE_BSRATIO")
+ bsr = UN.get_stage_bsratio
+ mmod.shared.state.sampling_steps = 20
+ GS.current_sampling_step = 5  # pct≈0.26
+ near("bsr [0,1]→1", bsr(0.0,1.0), 1.0)
+ near("bsr [0.5,1]→0 (pct<0.5)", bsr(0.5,1.0), 0.0)
+ near("bsr [0,0.5]→1 (pct<0.5)", bsr(0.0,0.5), 1.0)
+ GS.current_sampling_step = 18  # pct≈0.95
+ near("bsr [0,0.5]→0 (pct>0.5)", bsr(0.0,0.5), 0.0)
+ near("bsr [0.5,1]→1 (pct>0.5)", bsr(0.5,1.0), 1.0)
+ mmod.shared.state.sampling_steps = 1
+ GS.current_sampling_step = 0
+ near("steps=1 pct=0 in [0,1]→1", bsr(0.0,1.0), 1.0)
+ mmod.shared.state.sampling_steps = 20
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n13. STAGEINFO + STATE")
+ si = GS.StageInfo()
+ chk("default bf=1", si.backbone_factor==1.0)
+ chk("default fft=box", si.fft_type=="box")
+ chk("default cap=off", not si.enable_adaptive_cap)
+ chk("19 STAGE_FIELD_NAMES", len(GS.STAGE_FIELD_NAMES)==19)
+
+ st = GS.State(version="Version 2")
+ chk("version coerced '2'", st.version=="2")
+ st2 = GS.State(version="1")
+ chk("version '1' stays '1'", st2.version=="1")
+
+ # unknown dict key in stage_infos
+ st3 = GS.State(stage_infos=[{'backbone_factor':1.5,'UNKNOWN':999}])
+ chk("dict unknown key ignored", st3.stage_infos[0].backbone_factor==1.5)
+
+ # padding to STAGES_COUNT
+ st4 = GS.State(stage_infos=[GS.StageInfo(backbone_factor=2.0)])
+ chk("pads to 3 stages", len(st4.stage_infos)==3)
+ chk("pad default bf=1", st4.stage_infos[1].backbone_factor==1.0)
+
+ # round-trip
+ st5 = GS.State(version="2", stage_infos=[GS.StageInfo(backbone_factor=1.7)])
+ d = st5.to_dict()
+ chk("to_dict no 'enable'", 'enable' not in d)
+ chk("to_dict has stages", 'stage_infos' in d)
+ fields = {f.name for f in dataclasses.fields(GS.State)}
+ st6 = GS.State(**{k:v for k,v in d.items() if k in fields})
+ chk("round-trip version", st6.version=="2")
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n14. UPDATE_ATTR (XYZ shorthands)")
+ st = GS.State()
+ st.update_attr("b0", 1.5); chk("b0→bf stage0", st.stage_infos[0].backbone_factor==1.5)
+ st.update_attr("s1", 0.3); chk("s1→sf stage1", st.stage_infos[1].skip_factor==0.3)
+ st.update_attr("ft2","gaussian"); chk("ft2→fft_type stage2", st.stage_infos[2].fft_type=="gaussian")
+ st.update_attr("acm0","fixed"); chk("acm0→cap_mode stage0", st.stage_infos[0].adaptive_cap_mode=="fixed")
+ st.update_attr("cap1", True); chk("cap1→enable_adaptive_cap", st.stage_infos[1].enable_adaptive_cap==True)
+ st.update_attr("ct0", 0.4); chk("ct0→cap_threshold", st.stage_infos[0].cap_threshold==0.4)
+ st.update_attr("start_ratio", 0.2); chk("start_ratio direct", st.start_ratio==0.2)
+ st.update_attr("enable", True); chk("enable direct", st.enable==True)
+ # unknown key → no crash
+ try:
+     st.update_attr("UNKNOWN_KEY", 99)
+     chk("unknown key: no crash", True)
+ except Exception as e:
+     chk("unknown key: no crash", False, str(e))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n15. _LOAD_USER_PRESETS robustness")
+ # Good + bad preset in same file
+ pdata = {
+     "good": {"start_ratio":0.0,"stop_ratio":1.0,"transition_smoothness":0.0,
+              "version":"2","multiscale_mode":"Default","multiscale_strength":1.0,
+              "override_scales":"","channel_threshold":96,
+              "stage_infos":[{"backbone_factor":1.3}]},
+     "with_unknown": {"start_ratio":0.0,"stop_ratio":1.0,"transition_smoothness":0.0,
+              "version":"1","multiscale_mode":"Default","multiscale_strength":1.0,
+              "override_scales":"","channel_threshold":96,
+              "FUTURE_FIELD":"ignored","stage_infos":[]}
+ }
+ with tempfile.NamedTemporaryFile(mode='w',suffix='.json',delete=False) as f:
+     json.dump(pdata, f); tmp = f.name
+ GS.PRESETS_PATH = pathlib.Path(tmp)
+ res = GS._load_user_presets()
+ chk("good preset loaded", "good" in res)
+ chk("good preset bf=1.3", res.get("good") and res["good"].stage_infos[0].backbone_factor==1.3)
+ chk("no crash on unknown field preset", True)
+ os.unlink(tmp)
+
+ # Invalid JSON → {}
+ with tempfile.NamedTemporaryFile(mode='w',suffix='.json',delete=False) as f:
+     f.write("{bad json!!!"); tmp2 = f.name
+ GS.PRESETS_PATH = pathlib.Path(tmp2)
+ chk("invalid JSON → {}", GS._load_user_presets() == {})
+ os.unlink(tmp2)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n16. FILTER_SKIP_GAUSSIAN_ADAPTIVE")
+ hs = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
+
+ # no cap
+ si_nc = GS.StageInfo(skip_factor=0.8, fft_type='gaussian',
+                      fft_radius_ratio=0.1, hf_boost=1.2,
+                      skip_high_end_factor=1.1, enable_adaptive_cap=False)
+ out_nc = UN.filter_skip_gaussian_adaptive(hs, si_nc)
+ shp("no-cap shape", out_nc, (1,4,16,16))
+ chk("no-cap differs from input", not np.allclose(out_nc._d, hs._d, atol=1e-3))
+ chk("no-cap no nan", not np.isnan(out_nc._d).any())
+
+ # identity (scale=1, hfb=1)
+ si_id = GS.StageInfo(skip_factor=1.0, fft_type='gaussian',
+                      fft_radius_ratio=0.1, hf_boost=1.0,
+                      skip_high_end_factor=1.0, enable_adaptive_cap=False)
+ out_id = UN.filter_skip_gaussian_adaptive(hs, si_id)
+ chk("gauss identity ≈ input", np.allclose(out_id._d, hs._d, atol=1e-4),
+     f"maxdiff={abs(out_id._d-hs._d).max():.2e}")
+
+ # with cap
+ si_cap = GS.StageInfo(skip_factor=0.3, fft_type='gaussian',
+                       fft_radius_ratio=0.15, hf_boost=1.0,
+                       skip_high_end_factor=1.0, enable_adaptive_cap=True,
+                       cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='adaptive')
+ out_cap = UN.filter_skip_gaussian_adaptive(hs, si_cap)
+ shp("cap shape", out_cap, (1,4,16,16))
+ chk("cap no nan", not np.isnan(out_cap._d).any())
+
+ # fixed cap mode
+ si_fixed = GS.StageInfo(skip_factor=0.3, fft_type='gaussian',
+                         fft_radius_ratio=0.15, hf_boost=1.0,
+                         skip_high_end_factor=1.0, enable_adaptive_cap=True,
+                         cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='fixed')
+ out_fixed = UN.filter_skip_gaussian_adaptive(hs, si_fixed)
+ shp("fixed-cap shape", out_fixed, (1,4,16,16))
+ chk("fixed-cap no nan", not np.isnan(out_fixed._d).any())
+
+ # very aggressive scale (s=0.0) with cap
+ si_agg = GS.StageInfo(skip_factor=0.0, fft_type='gaussian',
+                       fft_radius_ratio=0.1, hf_boost=1.0,
+                       skip_high_end_factor=1.0, enable_adaptive_cap=True,
+                       cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='adaptive')
+ out_agg = UN.filter_skip_gaussian_adaptive(hs, si_agg)
+ shp("aggressive cap shape", out_agg, (1,4,16,16))
+ chk("aggressive cap no nan", not np.isnan(out_agg._d).any())
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n17. BACKBONE BLEND MATH (unit test)")
+ h = Tensor(np.full((1,8,4,4), 3.0, np.float32))
+ dims = 8
+ rbegin, rend, rinv = UN.ratio_to_region(0.5, 0.0, dims)
+ mask_np = np.zeros(dims, np.float32)
+ if not rinv: mask_np[rbegin:rend] = 1.0
+ else: mask_np[:rend]=1.0; mask_np[rbegin:]=1.0
+ mask_t = Tensor(mask_np.reshape(1,-1,1,1))
+
+ # V1 scale=2.0: masked → 6, unmasked → 3
+ scale = 2.0
+ h_scaled = h * (mask_t * scale + (1.0 - mask_t))
+ masked = h_scaled._d[0, :rend, 0, 0]
+ unmasked = h_scaled._d[0, rend:, 0, 0]
+ chk("masked ch → 6.0", np.allclose(masked, 6.0, atol=1e-4), str(masked))
+ chk("unmasked ch → 3.0", np.allclose(unmasked, 3.0, atol=1e-4), str(unmasked))
+
+ # blend lerp 0.5: masked → lerp(3,6,0.5)=4.5
+ lerp_fn = UN.BLENDING_MODES['lerp']
+ h_scaled_full = h * (mask_t * scale + (1.0 - mask_t))
+ h_blended = lerp_fn(h, h_scaled_full, 0.5)
+ h_out = h * (1.0 - mask_t) + h_blended * mask_t
+ chk("blend lerp 0.5 masked→4.5",
+     np.allclose(h_out._d[0,:rend,0,0], 4.5, atol=1e-4),
+     str(h_out._d[0,:rend,0,0]))
+ chk("blend lerp 0.5 unmasked→3.0",
+     np.allclose(h_out._d[0,rend:,0,0], 3.0, atol=1e-4))
+
+ # inject blend t=0.5: h + h_scaled*0.5 = 3 + 6*0.5 = 6
+ inj_fn = UN.BLENDING_MODES['inject']
+ h_inj = inj_fn(h, h_scaled_full, 0.5)
+ h_out2 = h * (1.0 - mask_t) + h_inj * mask_t
+ chk("blend inject masked→h+h_scaled*.5",
+     np.allclose(h_out2._d[0,:rend,0,0], 3+6*0.5, atol=1e-4),
+     str(h_out2._d[0,:rend,0,0]))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n18. APPLY_XYZ")
+ GS.instance = GS.State(); GS.xyz_attrs.clear()
+ orig = GS.instance.stage_infos[0].backbone_factor
+ GS.apply_xyz()
+ chk("empty attrs: unchanged", GS.instance.stage_infos[0].backbone_factor==orig)
+
+ GS.xyz_attrs['b0'] = 2.5; GS.apply_xyz()
+ chk("b0=2.5 applied", GS.instance.stage_infos[0].backbone_factor==2.5)
+ GS.xyz_attrs.clear()
+
+ # Preset
+ GS.reload_presets()
+ pname = list(GS.all_presets.keys())[0]
+ pbf = GS.all_presets[pname].stage_infos[0].backbone_factor
+ GS.xyz_attrs['preset'] = pname; GS.apply_xyz()
+ chk("preset applied", abs(GS.instance.stage_infos[0].backbone_factor-pbf)<1e-5)
+ GS.xyz_attrs.clear()
+
+ # Unknown preset → warning, instance unchanged
+ GS.instance = GS.State()
+ bf_before = GS.instance.stage_infos[0].backbone_factor
+ GS.xyz_attrs['preset'] = 'NO_SUCH_PRESET_XYZ'; GS.apply_xyz()
+ chk("unknown preset: no crash", True)
+ GS.xyz_attrs.clear()
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n19. DETECT_MODEL_CHANNELS")
+ UN.detect_model_channels()
+ chk("fallback=(1280,640,320)", UN._stage_channels==(1280,640,320))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n20. _FLAT_TO_SIS (simulated)")
+ _SF = [f.name for f in dataclasses.fields(GS.StageInfo)]
+ _SN = len(_SF)
+ def flat_to_sis(flat):
+     res=[]
+     for i in range(GS.STAGES_COUNT):
+         chunk=flat[i*_SN:(i+1)*_SN]
+         si_new=GS.StageInfo()
+         for j,fname in enumerate(_SF):
+             if j<len(chunk): setattr(si_new,fname,chunk[j])
+         res.append(si_new)
+     return res
+
+ si0=GS.StageInfo(backbone_factor=1.7, skip_factor=0.3, fft_type='gaussian')
+ flat=[]
+ for si_x in [si0,GS.StageInfo(),GS.StageInfo()]:
+     for f in _SF: flat.append(getattr(si_x,f))
+ sis=flat_to_sis(flat)
+ chk("flat bf=1.7", sis[0].backbone_factor==1.7)
+ chk("flat sf=0.3", sis[0].skip_factor==0.3)
+ chk("flat fft=gaussian", sis[0].fft_type=='gaussian')
+ chk("flat 3 stages", len(sis)==3)
+ chk("flat short: no crash", len(flat_to_sis(flat[:20]))==3)
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print("\n21. POSTCFG STEP COUNTER LOGIC")
+ # Simulate _on_cfg_post step counting
+ cfg = {"enabled":True,"steps":3,"mode":"lerp","blend":0.5,"b":1.1,
+        "fourier":False,"ms_mode":"Default","ms_str":1.0,
+        "threshold":1,"s":0.5,"gain":1.0,"step":0}
+
+ class FakeParams:
+     def __init__(self): self.x = Tensor(np.ones((1,4,8,8),np.float32)*2.0)
+
+ class FakeP:
+     _mega_pcfg = cfg
+
+ class FakeDenoiser:
+     p = FakeP()
+
+ # simulate _on_cfg_post inline
+ def run_post(params):
+     p = getattr(params, "p", None)
+     if p is None:
+         p = getattr(getattr(params,"denoiser",None),"p",None)
+     if p is None: return False
+     c = getattr(p,"_mega_pcfg",None)
+     if not c or not c.get("enabled"): return False
+     c["step"] = c.get("step",0)+1
+     if c["step"] > c["steps"]: return False
+     x = params.x
+     fn = UN.BLENDING_MODES.get(c["mode"], UN.BLENDING_MODES["inject"])
+     params.x = fn(x, x*c["b"], c["blend"])
+     return True
+
+ # via p attribute
+ fp1 = FakeParams(); fp1.p = FakeP(); fp1.p._mega_pcfg = {"enabled":True,"steps":2,"mode":"lerp","blend":0.5,"b":1.0,"fourier":False,"step":0,"gain":1.0}
+ ran1 = run_post(fp1); chk("postcfg step1 ran", ran1)
+ ran2 = run_post(fp1); chk("postcfg step2 ran", ran2)
+ ran3 = run_post(fp1); chk("postcfg step3 → past limit, skipped", not ran3)
+
+ # via denoiser.p
+ fp2 = FakeParams()
+ class D2:
+     class p2:
+         _mega_pcfg = {"enabled":True,"steps":1,"mode":"inject","blend":0.5,"b":1.1,"fourier":False,"step":0,"gain":1.0}
+     p = p2
+ fp2.denoiser = D2
+ chk("postcfg via denoiser.p: no crash", run_post(fp2))
+
+ # disabled → skip
+ fp3 = FakeParams(); fp3.p = type('P',(),{'_mega_pcfg':{"enabled":False}})()
+ chk("postcfg disabled → skip", not run_post(fp3))
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ print(f"\n{'═'*54}")
+ print(f"TOTAL: {P} PASS {F} FAIL")
+ if ERRS:
+     print("\nFailed:")
+     for e in ERRS: print(f"  • {e}")
+ else:
+     print("ALL TESTS PASSED ✓")
mega_freeu_a1111/tests/test_fixes.py ADDED
@@ -0,0 +1,166 @@
+ import pathlib, sys
+ exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
+ import math, types, dataclasses, json
+ import numpy as np
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
+ import importlib.util
+
+ def load_mod(name, path):
+     spec = importlib.util.spec_from_file_location(name, path)
+     m = importlib.util.module_from_spec(spec); sys.modules[name]=m; spec.loader.exec_module(m); return m
+
+ lib_pkg = types.ModuleType('lib_mega_freeu'); sys.modules['lib_mega_freeu'] = lib_pkg
+ GS = load_mod('lib_mega_freeu.global_state', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
+ UN = load_mod('lib_mega_freeu.unet', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
+
+ P=0; F=0; ERRS=[]
+ def ok(t): global P; P+=1; print(f"  ✓ {t}")
+ def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f"  ✗ {t} {m}")
+ def chk(t, c, m=""): ok(t) if c else ng(t, m)
+
+ print("═"*52)
+ print("FIX 1: State has pcfg_ and verbose fields")
+ print("═"*52)
+ st = GS.State()
+ chk("pcfg_enabled default False", not st.pcfg_enabled)
+ chk("pcfg_steps default 20", st.pcfg_steps == 20)
+ chk("pcfg_mode default inject", st.pcfg_mode == "inject")
+ chk("verbose default False", not st.verbose)
+
+ st2 = GS.State(pcfg_enabled=True, pcfg_b=1.5, pcfg_steps=10, verbose=True)
+ chk("pcfg_enabled=True", st2.pcfg_enabled)
+ chk("pcfg_b=1.5", st2.pcfg_b == 1.5)
+ chk("verbose=True", st2.verbose)
+
+ print("\n" + "═"*52)
+ print("FIX 2: to_dict() round-trip includes pcfg fields")
+ d = st2.to_dict()
+ chk("to_dict has pcfg_enabled", "pcfg_enabled" in d)
+ chk("to_dict has pcfg_b", "pcfg_b" in d)
+ chk("to_dict has verbose", "verbose" in d)
+ chk("to_dict no 'enable'", "enable" not in d)
+ st3 = GS.State(**{k:v for k,v in d.items() if k in {f.name for f in dataclasses.fields(GS.State)}})
+ chk("round-trip pcfg_b=1.5", st3.pcfg_b == 1.5)
+ chk("round-trip verbose=True", st3.verbose)
+
+ print("\n" + "═"*52)
+ print("FIX 3: _load_user_presets saves/restores pcfg")
+ import tempfile, os
+ preset_data = {
+     "my_pcfg_preset": {
+         "start_ratio": 0.0, "stop_ratio": 1.0, "transition_smoothness": 0.0,
+         "version": "2", "multiscale_mode": "Default", "multiscale_strength": 1.0,
+         "override_scales": "", "channel_threshold": 96,
+         "pcfg_enabled": True, "pcfg_b": 1.8, "pcfg_steps": 5,
+         "pcfg_mode": "lerp", "pcfg_blend": 0.7, "verbose": True,
+         "stage_infos": []
+     }
+ }
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+     json.dump(preset_data, f); tmp = f.name
+ GS.PRESETS_PATH = pathlib.Path(tmp)
+ res = GS._load_user_presets()
+ chk("preset loaded", "my_pcfg_preset" in res)
+ p = res.get("my_pcfg_preset")
+ chk("pcfg_enabled restored", p and p.pcfg_enabled == True)
+ chk("pcfg_b restored", p and p.pcfg_b == 1.8)
+ chk("pcfg_steps restored", p and p.pcfg_steps == 5)
+ chk("verbose restored", p and p.verbose == True)
+ os.unlink(tmp)
+
+ print("\n" + "═"*52)
+ print("FIX 4: _write_generation_params (simulate)")
+ class FakePNG:
+     extra_generation_params = {}
+
+ st4 = GS.State(
+     start_ratio=0.1, stop_ratio=0.9, transition_smoothness=0.5,
+     version="2", multiscale_mode="Multi-Bandpass", multiscale_strength=0.8,
+     override_scales="10, 1.5\n20, 0.8", channel_threshold=64,
+     pcfg_enabled=True, pcfg_b=1.2, pcfg_steps=15, pcfg_mode="inject",
+     pcfg_blend=1.0, pcfg_fourier=False, pcfg_ms_mode="Default",
+     pcfg_ms_str=1.0, pcfg_threshold=1, pcfg_s=0.5, pcfg_gain=1.0,
+     verbose=True,
+     stage_infos=[GS.StageInfo(backbone_factor=1.3)]
+ )
+
+ # Simulate _write_generation_params
+ fp = FakePNG()
+ fp.extra_generation_params["MegaFreeU Schedule"] = f"{st4.start_ratio}, {st4.stop_ratio}, {st4.transition_smoothness}"
+ fp.extra_generation_params["MegaFreeU Stages"] = json.dumps([si.to_dict() for si in st4.stage_infos])
+ fp.extra_generation_params["MegaFreeU Version"] = st4.version
+ fp.extra_generation_params["MegaFreeU Multiscale Mode"] = st4.multiscale_mode
+ fp.extra_generation_params["MegaFreeU Multiscale Strength"] = str(st4.multiscale_strength)
+ fp.extra_generation_params["MegaFreeU Override Scales"] = st4.override_scales
+ fp.extra_generation_params["MegaFreeU Channel Threshold"] = str(st4.channel_threshold)
+ fp.extra_generation_params["MegaFreeU Verbose"] = str(st4.verbose)
+ if st4.pcfg_enabled:
+     fp.extra_generation_params["MegaFreeU PostCFG"] = json.dumps({
+         "enabled": st4.pcfg_enabled, "steps": st4.pcfg_steps,
+         "mode": st4.pcfg_mode, "blend": st4.pcfg_blend,
+         "b": st4.pcfg_b, "fourier": st4.pcfg_fourier,
+         "ms_mode": st4.pcfg_ms_mode, "ms_str": st4.pcfg_ms_str,
+         "threshold": st4.pcfg_threshold, "s": st4.pcfg_s, "gain": st4.pcfg_gain,
+     })
+
+ eg = fp.extra_generation_params
+ chk("PNG has Schedule", "MegaFreeU Schedule" in eg)
+ chk("PNG has Stages", "MegaFreeU Stages" in eg)
+ chk("PNG has Version", "MegaFreeU Version" in eg)
+ chk("PNG has Multiscale Mode", "MegaFreeU Multiscale Mode" in eg)
+ chk("PNG has Multiscale Strength", "MegaFreeU Multiscale Strength" in eg)
+ chk("PNG has Override Scales", "MegaFreeU Override Scales" in eg)
+ chk("PNG has Channel Threshold", "MegaFreeU Channel Threshold" in eg)
+ chk("PNG has Verbose", "MegaFreeU Verbose" in eg)
+ chk("PNG has PostCFG", "MegaFreeU PostCFG" in eg)
+
+ # Verify PostCFG round-trip
+ pcfg_d = json.loads(eg["MegaFreeU PostCFG"])
+ chk("PostCFG b=1.2", pcfg_d["b"] == 1.2)
+ chk("PostCFG steps=15", pcfg_d["steps"] == 15)
+
+ # Verify multiscale restore
+ chk("ms_mode=Multi-Bandpass", eg["MegaFreeU Multiscale Mode"] == "Multi-Bandpass")
+ chk("ch_thresh=64", eg["MegaFreeU Channel Threshold"] == "64")
+
+ print("\n" + "═"*52)
+ print("FIX 5: Post-CFG independent of Enable (simulate process logic)")
+ # The fix: pcfg is set BEFORE checking st.enable
+ # Test: when enabled=False but pcfg_enabled=True, pcfg still created
+
+ class FakeP2:
+     extra_generation_params = {}
+     _mega_pcfg = None
+
+ fp2 = FakeP2()
+ # Simulate new process() logic for disabled main FreeU + enabled Post-CFG
+ st_disabled = GS.State(enable=False, pcfg_enabled=True, pcfg_b=1.5, pcfg_steps=10)
+ # Post-CFG created regardless
+ if st_disabled.pcfg_enabled:
+     fp2._mega_pcfg = {"enabled": True, "b": st_disabled.pcfg_b, "steps": st_disabled.pcfg_steps, "step": 0}
+ else:
+     fp2._mega_pcfg = {"enabled": False}
+
+ chk("pcfg set even when main disabled", fp2._mega_pcfg["enabled"] == True)
+ chk("pcfg_b=1.5 propagated", fp2._mega_pcfg["b"] == 1.5)
+
+ print("\n" + "═"*52)
+ print("FIX 6: dict-API compat (old sd-webui-freeu alwayson_scripts)")
+ # Simulate the dict branch of process()
+ dict_args = {
+     "enable": True, "start_ratio": 0.2, "stop_ratio": 0.8,
+     "version": "2", "multiscale_mode": "Default"
+ }
+ fields = {f.name for f in dataclasses.fields(GS.State)}
+ GS.instance = GS.State(**{k:v for k,v in dict_args.items() if k in fields})
+ chk("dict API: start_ratio=0.2", GS.instance.start_ratio == 0.2)
+ chk("dict API: version coerced '2'", GS.instance.version == "2")
+ chk("dict API: pcfg defaults", not GS.instance.pcfg_enabled)
+
+ print(f"\n{'═'*52}")
+ print(f"NEW FIXES: {P} PASS {F} FAIL")
+ if ERRS:
+     print("\nFailed:")
+     for e in ERRS: print(f"  • {e}")
+ else:
+     print("ALL NEW FIX TESTS PASSED ✓")
mega_freeu_a1111/tests/test_preset_pcfg.py ADDED
@@ -0,0 +1,160 @@
+ import pathlib, sys
+ exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
+ import types, dataclasses, json
+ import numpy as np
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
+ import importlib.util
+
+ def load(name, path):
+     spec = importlib.util.spec_from_file_location(name, path)
+     m = importlib.util.module_from_spec(spec); sys.modules[name]=m; spec.loader.exec_module(m); return m
+
+ lib = types.ModuleType('lib_mega_freeu'); sys.modules['lib_mega_freeu'] = lib
+ GS = load('lib_mega_freeu.global_state', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
+ UN = load('lib_mega_freeu.unet', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
+
+ P=0; F=0; ERRS=[]
+ def ok(t): global P; P+=1; print(f"  ✓ {t}")
+ def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f"  ✗ {t} {m}")
+ def chk(t, c, m=""): ok(t) if c else ng(t, m)
+
+ _SF = [f.name for f in dataclasses.fields(GS.StageInfo)]
+ _SN = len(_SF)
+
+ def _flat_to_sis(flat):
+     result = []
+     for i in range(GS.STAGES_COUNT):
+         chunk = flat[i*_SN:(i+1)*_SN]
+         si = GS.StageInfo()
+         for j, fname in enumerate(_SF):
+             if j < len(chunk): setattr(si, fname, chunk[j])
+         result.append(si)
+     return result
+
+ # Simulate _save_p with all fields
+ def _save_p(name, sr, sp, sm, ver, msm, mss, ovs, cht,
+             p_en, p_steps, p_mode, p_bl, p_b,
+             p_four, p_mmd, p_mst, p_thr, p_s, p_gain,
+             v_log, *flat):
+     sis = _flat_to_sis(flat)
+     vc = GS.ALL_VERSIONS.get(ver, "1")
+     GS.all_presets[name] = GS.State(
+         start_ratio=sr, stop_ratio=sp, transition_smoothness=sm,
+         version=vc, multiscale_mode=msm, multiscale_strength=float(mss),
+         override_scales=ovs or "", channel_threshold=int(cht),
+         stage_infos=sis,
+         pcfg_enabled=bool(p_en), pcfg_steps=int(p_steps),
+         pcfg_mode=str(p_mode), pcfg_blend=float(p_bl), pcfg_b=float(p_b),
+         pcfg_fourier=bool(p_four), pcfg_ms_mode=str(p_mmd),
+         pcfg_ms_str=float(p_mst), pcfg_threshold=int(p_thr),
+         pcfg_s=float(p_s), pcfg_gain=float(p_gain), verbose=bool(v_log),
+     )
+
+ # Simulate _apply_p
+ def _apply_p(name):
+     p = GS.all_presets.get(name)
+     if p is None: return None
+     flat = []
+     for si in p.stage_infos:
+         for f in _SF: flat.append(getattr(si, f))
+     return {
+         "start_ratio": p.start_ratio, "stop_ratio": p.stop_ratio,
+         "smooth": p.transition_smoothness,
+         "version": GS.REVERSED_VERSIONS.get(p.version, "Version 2"),
+         "ms_mode": p.multiscale_mode, "ms_str": p.multiscale_strength,
+         "ov_scales": p.override_scales, "ch_thresh": p.channel_threshold,
+         "pcfg_en": p.pcfg_enabled, "pcfg_steps": p.pcfg_steps,
+         "pcfg_mode": p.pcfg_mode, "pcfg_bl": p.pcfg_blend,
+         "pcfg_b": p.pcfg_b, "pcfg_fou": p.pcfg_fourier,
+         "pcfg_mmd": p.pcfg_ms_mode, "pcfg_mst": p.pcfg_ms_str,
+         "pcfg_thr": p.pcfg_threshold, "pcfg_s": p.pcfg_s,
+         "pcfg_gain": p.pcfg_gain, "verbose": p.verbose,
+         "flat": flat,
+     }
+
+ print("═"*52)
+ print("TEST: Full preset save/apply round-trip")
+ print("═"*52)
+
+ # Build flat for 3 stages
+ si0 = GS.StageInfo(backbone_factor=1.3, skip_factor=0.8)
+ flat_in = []
+ for si in [si0, GS.StageInfo(), GS.StageInfo()]:
+     for f in _SF: flat_in.append(getattr(si, f))
+
+ # Save
+ _save_p(
+     "full_preset",
+     0.1, 0.9, 0.5, "Version 2",
+     "Multi-Bandpass", 0.7, "10,1.5", 64,
+     True, 15, "lerp", 0.8, 1.4,
+     True, "Default", 0.9, 3, 0.6, 1.2,
+     True,
+     *flat_in
+ )
+
+ chk("preset saved", "full_preset" in GS.all_presets)
+ saved = GS.all_presets["full_preset"]
+ chk("saved pcfg_enabled=True", saved.pcfg_enabled == True)
+ chk("saved pcfg_steps=15", saved.pcfg_steps == 15)
+ chk("saved pcfg_mode=lerp", saved.pcfg_mode == "lerp")
+ chk("saved pcfg_blend=0.8", abs(saved.pcfg_blend - 0.8) < 1e-5)
+ chk("saved pcfg_b=1.4", abs(saved.pcfg_b - 1.4) < 1e-5)
+ chk("saved pcfg_fourier=True", saved.pcfg_fourier == True)
+ chk("saved pcfg_ms_mode=Default", saved.pcfg_ms_mode == "Default")
+ chk("saved pcfg_ms_str=0.9", abs(saved.pcfg_ms_str - 0.9) < 1e-5)
+ chk("saved pcfg_threshold=3", saved.pcfg_threshold == 3)
+ chk("saved pcfg_s=0.6", abs(saved.pcfg_s - 0.6) < 1e-5)
+ chk("saved pcfg_gain=1.2", abs(saved.pcfg_gain - 1.2) < 1e-5)
+ chk("saved verbose=True", saved.verbose == True)
+ chk("saved ms_mode=Multi-Bandpass", saved.multiscale_mode == "Multi-Bandpass")
+ chk("saved ms_str=0.7", abs(saved.multiscale_strength - 0.7) < 1e-5)
+ chk("saved ch_thresh=64", saved.channel_threshold == 64)
+ chk("saved bf=1.3", abs(saved.stage_infos[0].backbone_factor - 1.3) < 1e-5)
114
+
+# Apply
+restored = _apply_p("full_preset")
+chk("apply: pcfg_en=True", restored["pcfg_en"] == True)
+chk("apply: pcfg_steps=15", restored["pcfg_steps"] == 15)
+chk("apply: pcfg_mode=lerp", restored["pcfg_mode"] == "lerp")
+chk("apply: pcfg_b=1.4", abs(restored["pcfg_b"] - 1.4) < 1e-5)
+chk("apply: pcfg_fourier=True", restored["pcfg_fou"] == True)
+chk("apply: verbose=True", restored["verbose"] == True)
+chk("apply: ms_mode restored", restored["ms_mode"] == "Multi-Bandpass")
+chk("apply: ch_thresh=64", restored["ch_thresh"] == 64)
+chk("apply: bf=1.3 in flat", abs(restored["flat"][0] - 1.3) < 1e-5)
126
+
+# Save with pcfg disabled β†’ defaults
+_save_p(
+    "no_pcfg",
+    0.0, 1.0, 0.0, "Version 1",
+    "Default", 1.0, "", 96,
+    False, 20, "inject", 1.0, 1.1,
+    False, "Default", 1.0, 1, 0.5, 1.0,
+    False,
+    *flat_in
+)
+r2 = _apply_p("no_pcfg")
+chk("no_pcfg: pcfg_en=False", r2["pcfg_en"] == False)
+chk("no_pcfg: verbose=False", r2["verbose"] == False)
+
+# Unknown preset β†’ None
+r3 = _apply_p("does_not_exist")
+chk("unknown preset β†’ None", r3 is None)
145
+
+# dict-API branch: pcfg_enabled passed through dict
+GS.instance = GS.State()
+d_api = {"enable": True, "pcfg_enabled": True, "pcfg_b": 1.9, "verbose": True}
+fields = {f.name for f in dataclasses.fields(GS.State)}
+GS.instance = GS.State(**{k: v for k, v in d_api.items() if k in fields})
+chk("dict-API: pcfg_enabled propagated", GS.instance.pcfg_enabled == True)
+chk("dict-API: pcfg_b=1.9", abs(GS.instance.pcfg_b - 1.9) < 1e-5)
+chk("dict-API: verbose=True", GS.instance.verbose == True)
153
+
+print(f"\n{'═'*52}")
+print(f"PRESET PCFG ROUND-TRIP: {P} PASS {F} FAIL")
+if ERRS:
+    print("\nFailed:")
+    for e in ERRS: print(f" β€’ {e}")
+else:
+    print("ALL PASSED βœ“")