Upload 12 files
Browse files- mega_freeu_a1111/README.md +30 -0
- mega_freeu_a1111/install.py +3 -0
- mega_freeu_a1111/lib_mega_freeu/__init__.py +1 -0
- mega_freeu_a1111/lib_mega_freeu/global_state.py +285 -0
- mega_freeu_a1111/lib_mega_freeu/unet.py +559 -0
- mega_freeu_a1111/lib_mega_freeu/xyz_grid.py +97 -0
- mega_freeu_a1111/scripts/mega_freeu.py +765 -0
- mega_freeu_a1111/tests/README.md +31 -0
- mega_freeu_a1111/tests/mock_torch.py +229 -0
- mega_freeu_a1111/tests/test_core.py +534 -0
- mega_freeu_a1111/tests/test_fixes.py +166 -0
- mega_freeu_a1111/tests/test_preset_pcfg.py +160 -0
mega_freeu_a1111/README.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# β‘ Mega FreeU β A1111 Extension
|
| 2 |
+
|
| 3 |
+
**Mega FreeU** combines the best features from 5 FreeU implementations into a single
|
| 4 |
+
production-ready A1111 extension.
|
| 5 |
+
|
| 6 |
+
## Sources & Features
|
| 7 |
+
|
| 8 |
+
| Source | Features taken |
|
| 9 |
+
|--------|---------------|
|
| 10 |
+
| **sd-webui-freeu** | `th.cat` hijack (only correct A1111 approach), V1/V2 backbone, box filter (BUG FIXED), schedule start/stop/smoothness, backbone region (width+offset), presets JSON, PNG metadata, XYZ grid, ControlNet patch |
|
| 11 |
+
| **WAS FreeU_Advanced** | 9 blending modes (lerp/inject/bislerp/colorize/cosine/cuberp/hslerp/stable_slerp/linear_dodge), 13 multi-scale FFT presets, override_scales textarea, Post-CFG Shift (ported to A1111). Note: `target_block` / `input_block` / `middle_block` / `slice_b1/b2` not ported β A1111's `th.cat` hijack operates on output-side skip connections only. |
|
| 12 |
+
| **ComfyUI_FreeU_V2_Advanced** | Gaussian FFT filter (smooth, no ringing), Adaptive Cap loop (MAX_CAP_ITER=3), independent B/S timestep ranges per stage, channel_threshold matching |
|
| 13 |
+
| **FreeU_V2_timestepadd** | b_start/b_end%, s_start/s_end% per-stage step-fraction gating (note: original ComfyUI used `percent_to_sigma`; this port uses `current_step / total_steps` β conceptually equivalent for typical schedulers) |
|
| 14 |
+
| **nrs_kohaku_v3.5** | hf_boost parameter, gaussian on output, on_cpu fallback tracker |
|
| 15 |
+
|
| 16 |
+
## Bug Fixed
|
| 17 |
+
- `sdwebui-freeU-extension` had Fourier mask applied to ONE quadrant:
|
| 18 |
+
`mask[..., crow-t:crow, ccol-t:ccol]`
|
| 19 |
+
**Fixed**: `mask[..., crow-t:crow+t, ccol-t:ccol+t]` (symmetric center)
|
| 20 |
+
|
| 21 |
+
## Installation
|
| 22 |
+
Copy `mega_freeu_a1111/` into `stable-diffusion-webui/extensions/` and restart.
|
| 23 |
+
|
| 24 |
+
## Recommended Settings (SD1.5 V2 Gaussian + Independent B/S)
|
| 25 |
+
- Stage 1: B=1.2, S=0.9, FFT=gaussian, B End%=0.35, S Start%=0.35
|
| 26 |
+
- Stage 2: B=1.4, S=0.2, FFT=gaussian, B End%=0.35, S Start%=0.35
|
| 27 |
+
|
| 28 |
+
## Key Concept: Independent B/S Timestep Ranges
|
| 29 |
+
- **B-scaling** is most effective in structure phase (early steps, B Start%=0.0, B End%=0.35)
|
| 30 |
+
- **S-filtering** is most effective in detail phase (late steps, S Start%=0.35, S End%=1.0)
|
mega_freeu_a1111/install.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# install.py β auto-run by A1111 extension system (no external deps needed)
|
| 2 |
+
import launch
|
| 3 |
+
# All deps (torch, gradio, modules) are already in A1111 environment
|
mega_freeu_a1111/lib_mega_freeu/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# lib_mega_freeu β Mega FreeU for A1111
|
mega_freeu_a1111/lib_mega_freeu/global_state.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
lib_mega_freeu/global_state.py
|
| 3 |
+
Runtime state, data structures, presets for β‘ Mega FreeU.
|
| 4 |
+
|
| 5 |
+
Sources:
|
| 6 |
+
sd-webui-freeu/lib_free_u/global_state.py -- StageInfo layout, State, preset JSON, XYZ
|
| 7 |
+
WAS FreeU_Advanced/nodes.py -- BLEND_MODE_NAMES, MSCALES
|
| 8 |
+
ComfyUI_FreeU_V2_Advanced/FreeU_B1B2.py -- b_start/b_end, channel_threshold
|
| 9 |
+
ComfyUI_FreeU_V2_Advanced/FreeU_S1S2.py -- s_start/s_end, adaptive cap
|
| 10 |
+
nrs_kohaku_enhanced_v3_5.py -- hf_boost, gaussian standalone
|
| 11 |
+
"""
|
| 12 |
+
import dataclasses
|
| 13 |
+
import json
|
| 14 |
+
import math
|
| 15 |
+
import pathlib
|
| 16 |
+
import re
|
| 17 |
+
import sys
|
| 18 |
+
from typing import Any, Dict, List, Optional, Union
|
| 19 |
+
|
| 20 |
+
# βββ Blending modes (WAS FreeU_Advanced/nodes.py blending_modes keys) βββββββββ
|
| 21 |
+
BLEND_MODE_NAMES: List[str] = [
|
| 22 |
+
"lerp", "inject", "bislerp", "colorize",
|
| 23 |
+
"cosine interp", "cuberp", "hslerp", "stable_slerp", "linear dodge",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
# βββ Multi-scale presets (WAS nodes.py mscales dict -- exact) βββββββββββββββββ
|
| 27 |
+
MSCALES: Dict[str, Optional[list]] = {
|
| 28 |
+
"Default": None,
|
| 29 |
+
"Low-Pass": [(10, 1.0)],
|
| 30 |
+
"Pass-Through": [(10, 1.0)],
|
| 31 |
+
"Gaussian-Blur": [(10, 0.5)],
|
| 32 |
+
"Edge-Enhancement": [(10, 2.0)],
|
| 33 |
+
"Sharpen": [(10, 1.5)],
|
| 34 |
+
"Multi-Bandpass": [[(5, 0.0), (15, 1.0), (25, 0.0)]],
|
| 35 |
+
"Multi-Low-Pass": [[(5, 1.0), (10, 0.5), (15, 0.2)]],
|
| 36 |
+
"Multi-High-Pass": [[(5, 0.0), (10, 0.5), (15, 0.8)]],
|
| 37 |
+
"Multi-Pass-Through": [[(5, 1.0), (10, 1.0), (15, 1.0)]],
|
| 38 |
+
"Multi-Gaussian-Blur": [[(5, 0.5), (10, 0.8), (15, 0.2)]],
|
| 39 |
+
"Multi-Edge-Enhancement": [[(5, 1.2), (10, 1.5), (15, 2.0)]],
|
| 40 |
+
"Multi-Sharpen": [[(5, 1.5), (10, 2.0), (15, 2.5)]],
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
ALL_VERSIONS: Dict[str, str] = {"Version 1": "1", "Version 2": "2"}
|
| 44 |
+
REVERSED_VERSIONS: Dict[str, str] = {v: k for k, v in ALL_VERSIONS.items()}
|
| 45 |
+
FFT_TYPES: List[str] = ["gaussian", "box"]
|
| 46 |
+
STAGES_COUNT: int = 3
|
| 47 |
+
|
| 48 |
+
_shorthand_re = re.compile(r"^([a-z]{1,3})(\d+)$")
|
| 49 |
+
|
| 50 |
+
# βββ StageInfo βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
@dataclasses.dataclass
|
| 52 |
+
class StageInfo:
|
| 53 |
+
"""
|
| 54 |
+
All per-stage parameters.
|
| 55 |
+
Fields 1-6: same order as sd-webui-freeu for PNG backwards compat.
|
| 56 |
+
New fields appended at end.
|
| 57 |
+
"""
|
| 58 |
+
# sd-webui-freeu compat (DO NOT REORDER first 6)
|
| 59 |
+
backbone_factor: float = 1.0
|
| 60 |
+
skip_factor: float = 1.0
|
| 61 |
+
backbone_offset: float = 0.0
|
| 62 |
+
backbone_width: float = 0.5
|
| 63 |
+
skip_cutoff: float = 0.0
|
| 64 |
+
skip_high_end_factor: float = 1.0
|
| 65 |
+
# WAS blending
|
| 66 |
+
backbone_blend_mode: str = "lerp"
|
| 67 |
+
backbone_blend: float = 1.0
|
| 68 |
+
# ComfyUI V2 independent timestep ranges
|
| 69 |
+
b_start_ratio: float = 0.0
|
| 70 |
+
b_end_ratio: float = 1.0
|
| 71 |
+
s_start_ratio: float = 0.0
|
| 72 |
+
s_end_ratio: float = 1.0
|
| 73 |
+
# FFT
|
| 74 |
+
fft_type: str = "box"
|
| 75 |
+
fft_radius_ratio: float = 0.07
|
| 76 |
+
hf_boost: float = 1.0
|
| 77 |
+
# Adaptive Cap (FreeU_S1S2)
|
| 78 |
+
enable_adaptive_cap: bool = False
|
| 79 |
+
cap_threshold: float = 0.35
|
| 80 |
+
cap_factor: float = 0.6
|
| 81 |
+
adaptive_cap_mode: str = "adaptive"
|
| 82 |
+
|
| 83 |
+
def to_dict(self, include_default=False):
|
| 84 |
+
default = StageInfo()
|
| 85 |
+
d = dataclasses.asdict(self)
|
| 86 |
+
if not include_default:
|
| 87 |
+
d = {k: v for k, v in d.items() if v != getattr(default, k)}
|
| 88 |
+
return d
|
| 89 |
+
|
| 90 |
+
def copy(self):
|
| 91 |
+
return StageInfo(**dataclasses.asdict(self))
|
| 92 |
+
|
| 93 |
+
STAGE_FIELD_NAMES = [f.name for f in dataclasses.fields(StageInfo)]
|
| 94 |
+
STAGE_FIELD_COUNT = len(STAGE_FIELD_NAMES)
|
| 95 |
+
|
| 96 |
+
# βββ State βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 97 |
+
@dataclasses.dataclass
|
| 98 |
+
class State:
|
| 99 |
+
enable: bool = True
|
| 100 |
+
start_ratio: Any = 0.0
|
| 101 |
+
stop_ratio: Any = 1.0
|
| 102 |
+
transition_smoothness: float = 0.0
|
| 103 |
+
version: str = "1"
|
| 104 |
+
multiscale_mode: str = "Default"
|
| 105 |
+
multiscale_strength: float = 1.0
|
| 106 |
+
override_scales: str = ""
|
| 107 |
+
channel_threshold: int = 96
|
| 108 |
+
stage_infos: List[Any] = dataclasses.field(
|
| 109 |
+
default_factory=lambda: [StageInfo() for _ in range(STAGES_COUNT)]
|
| 110 |
+
)
|
| 111 |
+
# Post-CFG Shift (WAS_PostCFGShift) β stored in presets & PNG
|
| 112 |
+
pcfg_enabled: bool = False
|
| 113 |
+
pcfg_steps: int = 20
|
| 114 |
+
pcfg_mode: str = "inject"
|
| 115 |
+
pcfg_blend: float = 1.0
|
| 116 |
+
pcfg_b: float = 1.1
|
| 117 |
+
pcfg_fourier: bool = False
|
| 118 |
+
pcfg_ms_mode: str = "Default"
|
| 119 |
+
pcfg_ms_str: float = 1.0
|
| 120 |
+
pcfg_threshold: int = 1
|
| 121 |
+
pcfg_s: float = 0.5
|
| 122 |
+
pcfg_gain: float = 1.0
|
| 123 |
+
verbose: bool = False
|
| 124 |
+
|
| 125 |
+
def __post_init__(self):
|
| 126 |
+
self.stage_infos = self._coerce_stages()
|
| 127 |
+
self.version = ALL_VERSIONS.get(self.version, self.version)
|
| 128 |
+
|
| 129 |
+
def _coerce_stages(self):
|
| 130 |
+
result, raw = [], list(self.stage_infos)
|
| 131 |
+
i = 0
|
| 132 |
+
while i < len(raw) and len(result) < STAGES_COUNT:
|
| 133 |
+
item = raw[i]
|
| 134 |
+
if isinstance(item, StageInfo):
|
| 135 |
+
result.append(item); i += 1
|
| 136 |
+
elif isinstance(item, dict):
|
| 137 |
+
known = {k: v for k, v in item.items() if k in STAGE_FIELD_NAMES}
|
| 138 |
+
result.append(StageInfo(**known)); i += 1
|
| 139 |
+
else:
|
| 140 |
+
chunk = raw[i:i+STAGE_FIELD_COUNT]
|
| 141 |
+
result.append(StageInfo(*chunk))
|
| 142 |
+
i += STAGE_FIELD_COUNT
|
| 143 |
+
while len(result) < STAGES_COUNT:
|
| 144 |
+
result.append(StageInfo())
|
| 145 |
+
return result
|
| 146 |
+
|
| 147 |
+
def to_dict(self):
|
| 148 |
+
d = dataclasses.asdict(self)
|
| 149 |
+
d["stage_infos"] = [si.to_dict() for si in self.stage_infos]
|
| 150 |
+
del d["enable"]
|
| 151 |
+
return d
|
| 152 |
+
|
| 153 |
+
def copy(self):
|
| 154 |
+
d = dataclasses.asdict(self)
|
| 155 |
+
d["stage_infos"] = [StageInfo(**s) for s in d["stage_infos"]]
|
| 156 |
+
return State(**d)
|
| 157 |
+
|
| 158 |
+
def update_attr(self, key, value):
|
| 159 |
+
if m := _shorthand_re.match(key):
|
| 160 |
+
char, idx = m.group(1), int(m.group(2))
|
| 161 |
+
if 0 <= idx < STAGES_COUNT:
|
| 162 |
+
si = self.stage_infos[idx]
|
| 163 |
+
_MAP = {
|
| 164 |
+
"b":"backbone_factor","s":"skip_factor","o":"backbone_offset",
|
| 165 |
+
"w":"backbone_width","t":"skip_cutoff","h":"skip_high_end_factor",
|
| 166 |
+
"bm":"backbone_blend_mode","bb":"backbone_blend",
|
| 167 |
+
"bs":"b_start_ratio","be":"b_end_ratio",
|
| 168 |
+
"ss":"s_start_ratio","se":"s_end_ratio",
|
| 169 |
+
"ft":"fft_type","fr":"fft_radius_ratio","hfb":"hf_boost",
|
| 170 |
+
"cap":"enable_adaptive_cap","ct":"cap_threshold","cf":"cap_factor","acm":"adaptive_cap_mode",
|
| 171 |
+
}
|
| 172 |
+
if char in _MAP:
|
| 173 |
+
setattr(si, _MAP[char], value); return
|
| 174 |
+
if hasattr(self, key):
|
| 175 |
+
setattr(self, key, value)
|
| 176 |
+
|
| 177 |
+
# βββ Singletons ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 178 |
+
instance: State = State()
|
| 179 |
+
xyz_attrs: Dict[str, Any] = {}
|
| 180 |
+
current_sampling_step: int = 0
|
| 181 |
+
|
| 182 |
+
# βββ Preset builders βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 183 |
+
def _v1(*pairs):
|
| 184 |
+
infos = [StageInfo(backbone_factor=b, skip_factor=s) for b,s in pairs]
|
| 185 |
+
while len(infos) < STAGES_COUNT: infos.append(StageInfo())
|
| 186 |
+
return State(version="1", stage_infos=infos)
|
| 187 |
+
|
| 188 |
+
def _v2g(pairs):
|
| 189 |
+
infos = []
|
| 190 |
+
for b, s, r, hfb, bs, be, ss, se in pairs:
|
| 191 |
+
infos.append(StageInfo(
|
| 192 |
+
backbone_factor=b, skip_factor=s,
|
| 193 |
+
fft_type="gaussian", fft_radius_ratio=r, hf_boost=hfb,
|
| 194 |
+
b_start_ratio=bs, b_end_ratio=be,
|
| 195 |
+
s_start_ratio=ss, s_end_ratio=se,
|
| 196 |
+
))
|
| 197 |
+
while len(infos) < STAGES_COUNT: infos.append(StageInfo())
|
| 198 |
+
return State(version="2", stage_infos=infos)
|
| 199 |
+
|
| 200 |
+
default_presets: Dict[str, State] = {
|
| 201 |
+
"SD1.4 Recommendations": _v1((1.2,0.9),(1.4,0.2),(1.0,1.0)),
|
| 202 |
+
"SD2.1 Recommendations": _v1((1.1,0.9),(1.2,0.2),(1.0,1.0)),
|
| 203 |
+
"SDXL Recommendations": _v1((1.1,0.6),(1.2,0.4),(1.0,1.0)),
|
| 204 |
+
"SD1.5 V2 Gaussian": _v2g([
|
| 205 |
+
(1.2,0.9,0.07,1.0, 0.0,0.35, 0.35,1.0),
|
| 206 |
+
(1.4,0.2,0.07,1.0, 0.0,0.35, 0.35,1.0),
|
| 207 |
+
]),
|
| 208 |
+
"SD1.5 V2 High Detail": _v2g([
|
| 209 |
+
(1.4,0.8,0.08,1.2, 0.0,0.35, 0.35,1.0),
|
| 210 |
+
(1.6,0.1,0.06,1.0, 0.0,0.35, 0.35,1.0),
|
| 211 |
+
]),
|
| 212 |
+
"SDXL V2 Gaussian": _v2g([
|
| 213 |
+
(1.1,0.6,0.05,1.1, 0.0,0.35, 0.35,1.0),
|
| 214 |
+
(1.2,0.4,0.05,1.1, 0.0,0.35, 0.35,1.0),
|
| 215 |
+
]),
|
| 216 |
+
"SD1.5 Adaptive Cap": State(version="2", stage_infos=[
|
| 217 |
+
StageInfo(backbone_factor=1.3,skip_factor=0.9,
|
| 218 |
+
fft_type="gaussian",fft_radius_ratio=0.08,hf_boost=1.2,
|
| 219 |
+
b_start_ratio=0.0,b_end_ratio=0.35,
|
| 220 |
+
s_start_ratio=0.35,s_end_ratio=1.0,
|
| 221 |
+
enable_adaptive_cap=True,cap_threshold=0.35,
|
| 222 |
+
cap_factor=0.6,adaptive_cap_mode="adaptive"),
|
| 223 |
+
StageInfo(backbone_factor=1.4,skip_factor=0.2,
|
| 224 |
+
fft_type="gaussian",fft_radius_ratio=0.06,hf_boost=1.0,
|
| 225 |
+
b_start_ratio=0.0,b_end_ratio=0.35,
|
| 226 |
+
s_start_ratio=0.35,s_end_ratio=1.0,
|
| 227 |
+
enable_adaptive_cap=True,cap_threshold=0.70,
|
| 228 |
+
cap_factor=0.6,adaptive_cap_mode="adaptive"),
|
| 229 |
+
StageInfo(),
|
| 230 |
+
]),
|
| 231 |
+
"Independent B/S (SD1.5)": _v2g([
|
| 232 |
+
(1.2,0.9,0.07,1.0, 0.0,0.35, 0.35,1.0),
|
| 233 |
+
(1.4,0.2,0.06,1.0, 0.0,0.35, 0.35,1.0),
|
| 234 |
+
]),
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
all_presets: Dict[str, State] = {}
|
| 238 |
+
PRESETS_PATH = pathlib.Path(__file__).parent.parent / "presets.json"
|
| 239 |
+
|
| 240 |
+
def reload_presets():
|
| 241 |
+
all_presets.clear()
|
| 242 |
+
all_presets.update(default_presets)
|
| 243 |
+
all_presets.update(_load_user_presets())
|
| 244 |
+
|
| 245 |
+
def _load_user_presets():
|
| 246 |
+
if not PRESETS_PATH.exists(): return {}
|
| 247 |
+
try:
|
| 248 |
+
with open(PRESETS_PATH, encoding="utf-8") as f:
|
| 249 |
+
raw = json.load(f)
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"[MegaFreeU] preset load error: {e}", file=sys.stderr)
|
| 252 |
+
return {}
|
| 253 |
+
result = {}
|
| 254 |
+
_state_fields = {f.name for f in dataclasses.fields(State)}
|
| 255 |
+
for k, v in raw.items():
|
| 256 |
+
try:
|
| 257 |
+
# Filter unknown keys so future/old fields don't crash State(**v)
|
| 258 |
+
known = {fk: fv for fk, fv in v.items() if fk in _state_fields}
|
| 259 |
+
result[k] = State(**known)
|
| 260 |
+
except Exception as e:
|
| 261 |
+
print(f"[MegaFreeU] skipping preset {k!r}: {e}", file=sys.stderr)
|
| 262 |
+
return result
|
| 263 |
+
|
| 264 |
+
def save_presets(custom=None):
|
| 265 |
+
if custom is None: custom = get_user_presets()
|
| 266 |
+
try:
|
| 267 |
+
PRESETS_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 268 |
+
with open(PRESETS_PATH, "w", encoding="utf-8") as f:
|
| 269 |
+
json.dump({k: v.to_dict() for k,v in custom.items()}, f, indent=4)
|
| 270 |
+
except Exception as e:
|
| 271 |
+
print(f"[MegaFreeU] preset save error: {e}", file=sys.stderr)
|
| 272 |
+
|
| 273 |
+
def get_user_presets():
|
| 274 |
+
return {k: v for k,v in all_presets.items() if k not in default_presets}
|
| 275 |
+
|
| 276 |
+
def apply_xyz():
|
| 277 |
+
global instance
|
| 278 |
+
if pk := xyz_attrs.get("preset"):
|
| 279 |
+
if p := all_presets.get(pk):
|
| 280 |
+
instance = p.copy()
|
| 281 |
+
elif pk != "UI Settings":
|
| 282 |
+
print(f"[MegaFreeU] XYZ preset '{pk}' not found", file=sys.stderr)
|
| 283 |
+
for k, v in xyz_attrs.items():
|
| 284 |
+
if k != "preset":
|
| 285 |
+
instance.update_attr(k, v)
|
mega_freeu_a1111/lib_mega_freeu/unet.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
lib_mega_freeu/unet.py β Math engine + A1111 th.cat patch
|
| 3 |
+
|
| 4 |
+
BUGS FIXED vs sdwebui-freeU-extension/scripts/freeunet_hijack.py:
|
| 5 |
+
BUG 1 dtype: mask = torch.ones(..., dtype=torch.bool)
|
| 6 |
+
bool*float = NOOP, scale always 1.0
|
| 7 |
+
Fix: torch.full(..., float(scale_high))
|
| 8 |
+
BUG 2 quadrant: mask[..., crow-t:crow, ccol-t:ccol] (top-left only)
|
| 9 |
+
Fix: mask[..., crow-t:crow+t, ccol-t:ccol+t] (symmetric center)
|
| 10 |
+
|
| 11 |
+
Sources:
|
| 12 |
+
sd-webui-freeu/lib_free_u/unet.py patch(), free_u_cat_hijack(),
|
| 13 |
+
get_backbone_scale(), ratio_to_region(), filter_skip()[box],
|
| 14 |
+
get_schedule_ratio(), is_gpu_complex_supported(), lerp()
|
| 15 |
+
WAS FreeU_Advanced/nodes.py 9 blending modes, Fourier_filter() multiscale
|
| 16 |
+
ComfyUI_FreeU_V2_advanced/utils.py Fourier_filter_gauss(), get_band_energy_stats()
|
| 17 |
+
ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py Adaptive Cap loop MAX_CAP_ITER=3
|
| 18 |
+
ComfyUI_FreeU_V2_advanced/FreeU_B1B2.py channel_threshold, model_channels*4/2/1
|
| 19 |
+
FreeU_V2_timestepadd.py step-fraction timestep gating concept
|
| 20 |
+
nrs_kohaku_enhanced_v3_5.py _freeu_b_scale_h, _freeu_fourier_filter_gaussian,
|
| 21 |
+
hf_boost param, on_cpu_devices dict
|
| 22 |
+
"""
|
| 23 |
+
import dataclasses
|
| 24 |
+
import functools
|
| 25 |
+
import logging
|
| 26 |
+
import math
|
| 27 |
+
import pathlib
|
| 28 |
+
import sys
|
| 29 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from lib_mega_freeu import global_state
|
| 33 |
+
|
| 34 |
+
# ββ GPU complex support (sd-webui-freeu exact) ββββββββββββββββββββββββββββββββ
|
| 35 |
+
_gpu_complex_support: Optional[bool] = None
|
| 36 |
+
|
| 37 |
+
def is_gpu_complex_supported(x: torch.Tensor) -> bool:
|
| 38 |
+
global _gpu_complex_support
|
| 39 |
+
if x.is_cpu:
|
| 40 |
+
return True
|
| 41 |
+
if _gpu_complex_support is not None:
|
| 42 |
+
return _gpu_complex_support
|
| 43 |
+
mps_avail = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 44 |
+
try:
|
| 45 |
+
import torch_directml
|
| 46 |
+
except ImportError:
|
| 47 |
+
dml_avail = False
|
| 48 |
+
else:
|
| 49 |
+
dml_avail = torch_directml.is_available()
|
| 50 |
+
_gpu_complex_support = not (mps_avail or dml_avail)
|
| 51 |
+
if _gpu_complex_support:
|
| 52 |
+
try: torch.fft.fftn(x.float(), dim=(-2, -1))
|
| 53 |
+
except RuntimeError: _gpu_complex_support = False
|
| 54 |
+
return _gpu_complex_support
|
| 55 |
+
|
| 56 |
+
_on_cpu_devices: Dict = {}
|
| 57 |
+
|
| 58 |
+
# ββ Blending modes (WAS nodes.py exact) βββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
def _normalize(t):
|
| 60 |
+
mn, mx = t.min(), t.max()
|
| 61 |
+
return (t - mn) / (mx - mn + 1e-8)
|
| 62 |
+
|
| 63 |
+
def _hslerp(a, b, t):
|
| 64 |
+
nc = a.size(1)
|
| 65 |
+
iv = torch.zeros(1, nc, 1, 1, device=a.device, dtype=a.dtype)
|
| 66 |
+
iv[0, 0, 0, 0] = 1.0
|
| 67 |
+
result = (1 - t) * a + t * b
|
| 68 |
+
if t < 0.5:
|
| 69 |
+
result += (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
|
| 70 |
+
else:
|
| 71 |
+
result -= (torch.norm(b - a, dim=1, keepdim=True) / 6) * iv
|
| 72 |
+
return result
|
| 73 |
+
|
| 74 |
+
def _stable_slerp(a, b, t, eps=1e-6):
|
| 75 |
+
an = a / torch.linalg.norm(a, dim=1, keepdim=True).clamp_min(eps)
|
| 76 |
+
bn = b / torch.linalg.norm(b, dim=1, keepdim=True).clamp_min(eps)
|
| 77 |
+
dot = (an * bn).sum(dim=1, keepdim=True).clamp(-1.0 + eps, 1.0 - eps)
|
| 78 |
+
theta = torch.acos(dot)
|
| 79 |
+
sin_t = torch.sin(theta).clamp_min(eps)
|
| 80 |
+
s0 = torch.sin((1.0 - t) * theta) / sin_t
|
| 81 |
+
s1 = torch.sin(t * theta) / sin_t
|
| 82 |
+
slerp_out = s0 * a + s1 * b
|
| 83 |
+
lerp_out = (1.0 - t) * a + t * b
|
| 84 |
+
use_lerp = (theta < 1e-3).squeeze(1)
|
| 85 |
+
return torch.where(use_lerp.unsqueeze(1), lerp_out, slerp_out)
|
| 86 |
+
|
| 87 |
+
BLENDING_MODES = {
|
| 88 |
+
"bislerp": lambda a, b, t: _normalize((1 - t) * a + t * b),
|
| 89 |
+
"colorize": lambda a, b, t: a + (b - a) * t,
|
| 90 |
+
"cosine interp": lambda a, b, t: (
|
| 91 |
+
a + b - (a - b) * torch.cos(
|
| 92 |
+
torch.tensor(math.pi, device=a.device, dtype=a.dtype) * t)) / 2,
|
| 93 |
+
"cuberp": lambda a, b, t: a + (b - a) * (3 * t**2 - 2 * t**3),
|
| 94 |
+
"hslerp": _hslerp,
|
| 95 |
+
"stable_slerp": _stable_slerp,
|
| 96 |
+
"inject": lambda a, b, t: a + b * t,
|
| 97 |
+
"lerp": lambda a, b, t: (1 - t) * a + t * b,
|
| 98 |
+
"linear dodge": lambda a, b, t: _normalize(a + b * t),
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
def lerp(a, b, r):
|
| 102 |
+
return (1 - r) * a + r * b
|
| 103 |
+
|
| 104 |
+
# ββ Backbone scaling ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
+
def get_backbone_scale(h: torch.Tensor, backbone_factor: float, version: str):
|
| 106 |
+
if version == "1":
|
| 107 |
+
return backbone_factor
|
| 108 |
+
# V2: adaptive hidden_mean (FreeU_B1B2.py + kohaku _freeu_b_scale_h exact)
|
| 109 |
+
features_mean = h.mean(1, keepdim=True)
|
| 110 |
+
B = features_mean.shape[0]
|
| 111 |
+
hidden_max, _ = torch.max(features_mean.view(B, -1), dim=-1, keepdim=True)
|
| 112 |
+
hidden_min, _ = torch.min(features_mean.view(B, -1), dim=-1, keepdim=True)
|
| 113 |
+
denom = (hidden_max - hidden_min).clamp_min(1e-6)
|
| 114 |
+
hidden_mean = (features_mean - hidden_min.unsqueeze(2).unsqueeze(3)) \
|
| 115 |
+
/ denom.unsqueeze(2).unsqueeze(3)
|
| 116 |
+
return 1 + (backbone_factor - 1) * hidden_mean
|
| 117 |
+
|
| 118 |
+
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
|
| 119 |
+
"""sd-webui-freeu ratio_to_region exact."""
|
| 120 |
+
if width < 0:
|
| 121 |
+
offset += width; width = -width
|
| 122 |
+
width = min(width, 1.0)
|
| 123 |
+
if offset < 0:
|
| 124 |
+
offset = 1 + offset - int(offset)
|
| 125 |
+
offset = math.fmod(offset, 1.0)
|
| 126 |
+
if width + offset <= 1:
|
| 127 |
+
return round(offset * n), round((width + offset) * n), False
|
| 128 |
+
else:
|
| 129 |
+
return round((width + offset - 1) * n), round(offset * n), True
|
| 130 |
+
|
| 131 |
+
# ββ Box FFT (BUGS FIXED symmetric center + float dtype) βββββββββββββββββββββ
|
| 132 |
+
def filter_skip_box(x: torch.Tensor, cutoff: float,
|
| 133 |
+
scale: float, scale_high: float = 1.0) -> torch.Tensor:
|
| 134 |
+
"""
|
| 135 |
+
FreeU box filter with TWO BUGS FIXED from sdwebui-freeU-extension:
|
| 136 |
+
BUG 1 (dtype): was torch.bool mask -> scale multiplication was NOOP
|
| 137 |
+
BUG 2 (region): was [crow-t:crow, ccol-t:ccol] -> single quadrant top-left
|
| 138 |
+
Both fixed: torch.full float + symmetric [crow-t:crow+t, ccol-t:ccol+t].
|
| 139 |
+
sd-webui-freeu has these correct already, we match their implementation.
|
| 140 |
+
"""
|
| 141 |
+
if scale == 1.0 and scale_high == 1.0:
|
| 142 |
+
return x
|
| 143 |
+
fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
|
| 144 |
+
x_freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
|
| 145 |
+
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
|
| 146 |
+
B, C, H, W = x_freq.shape
|
| 147 |
+
mask = torch.full((B, C, H, W), float(scale_high), device=fft_dev) # FIX: float, not bool
|
| 148 |
+
crow, ccol = H // 2, W // 2
|
| 149 |
+
tr = max(1, math.floor(crow * cutoff)) if cutoff > 0 else 1
|
| 150 |
+
tc = max(1, math.floor(ccol * cutoff)) if cutoff > 0 else 1
|
| 151 |
+
mask[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = scale # FIX: symmetric center
|
| 152 |
+
x_freq *= mask
|
| 153 |
+
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
|
| 154 |
+
return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
|
| 155 |
+
|
| 156 |
+
# ββ Box + WAS multiscale overlay (WAS nodes.py Fourier_filter exact) βββββββββ
|
| 157 |
+
def filter_skip_box_multiscale(x: torch.Tensor, cutoff: float, scale: float,
|
| 158 |
+
scales_preset: Optional[list],
|
| 159 |
+
strength: float = 1.0,
|
| 160 |
+
scale_high: float = 1.0) -> torch.Tensor:
|
| 161 |
+
"""
|
| 162 |
+
WAS FreeU_Advanced/nodes.py Fourier_filter(x, threshold, scale, scales, strength).
|
| 163 |
+
threshold = cutoff: float ratio [0-1] or int pixels (WAS uses int default=1).
|
| 164 |
+
scales: None, list of (radius_px, val) single-scale, or list of lists multi-scale.
|
| 165 |
+
"""
|
| 166 |
+
if scale == 1.0 and scale_high == 1.0 and scales_preset is None:
|
| 167 |
+
return x
|
| 168 |
+
fft_dev = x.device if is_gpu_complex_supported(x) else torch.device("cpu")
|
| 169 |
+
x_freq = torch.fft.fftn(x.to(fft_dev, dtype=torch.float32), dim=(-2, -1))
|
| 170 |
+
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
|
| 171 |
+
B, C, H, W = x_freq.shape
|
| 172 |
+
crow, ccol = H // 2, W // 2
|
| 173 |
+
if isinstance(cutoff, float) and 0 < cutoff <= 1.0:
|
| 174 |
+
tr = max(1, math.floor(crow * cutoff)); tc = max(1, math.floor(ccol * cutoff))
|
| 175 |
+
else:
|
| 176 |
+
t = max(1, int(cutoff)) if cutoff > 0 else 1; tr = tc = t
|
| 177 |
+
mask = torch.ones((B, C, H, W), device=fft_dev)
|
| 178 |
+
mask[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = scale
|
| 179 |
+
if scale_high != 1.0:
|
| 180 |
+
hfm = torch.full((B, C, H, W), float(scale_high), device=fft_dev)
|
| 181 |
+
hfm[..., crow - tr:crow + tr, ccol - tc:ccol + tc] = 1.0
|
| 182 |
+
mask = mask * hfm
|
| 183 |
+
if scales_preset:
|
| 184 |
+
if isinstance(scales_preset[0], tuple):
|
| 185 |
+
# WAS single-scale mode
|
| 186 |
+
for scale_threshold, scale_value in scales_preset:
|
| 187 |
+
sv = scale_value * strength
|
| 188 |
+
sm = torch.ones((B, C, H, W), device=fft_dev)
|
| 189 |
+
st = max(1, int(scale_threshold))
|
| 190 |
+
sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
|
| 191 |
+
mask = mask + (sm - mask) * strength
|
| 192 |
+
else:
|
| 193 |
+
# WAS multi-scale mode
|
| 194 |
+
for scale_params in scales_preset:
|
| 195 |
+
if isinstance(scale_params, list):
|
| 196 |
+
for scale_threshold, scale_value in scale_params:
|
| 197 |
+
sv = scale_value * strength
|
| 198 |
+
sm = torch.ones((B, C, H, W), device=fft_dev)
|
| 199 |
+
st = max(1, int(scale_threshold))
|
| 200 |
+
sm[..., crow - st:crow + st, ccol - st:ccol + st] = sv
|
| 201 |
+
mask = mask + (sm - mask) * strength
|
| 202 |
+
x_freq = x_freq * mask
|
| 203 |
+
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
|
| 204 |
+
return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(device=x.device, dtype=x.dtype)
|
| 205 |
+
|
| 206 |
+
# ββ Gaussian FFT (ComfyUI utils.py exact) ββββββββββββββββββββββββββββββββββββ
|
| 207 |
+
def fourier_filter_gauss(x: torch.Tensor, radius_ratio: float,
|
| 208 |
+
scale: float, hf_boost: float = 1.0) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
ComfyUI_FreeU_V2_advanced/utils.py Fourier_filter_gauss() exact.
|
| 211 |
+
Also matches kohaku _freeu_fourier_filter_gaussian().
|
| 212 |
+
R = max(1, int(min(H,W)*radius_ratio))
|
| 213 |
+
sigma_f = R^2/2
|
| 214 |
+
center = exp(-dist2/sigma_f)
|
| 215 |
+
mask = scale*center + hf_boost*(1-center)
|
| 216 |
+
"""
|
| 217 |
+
x_f = torch.fft.fftn(x.float(), dim=(-2, -1))
|
| 218 |
+
x_f = torch.fft.fftshift(x_f, dim=(-2, -1))
|
| 219 |
+
B, C, H, W = x_f.shape
|
| 220 |
+
R = max(1, int(min(H, W) * radius_ratio))
|
| 221 |
+
sigma_f = max(1e-6, (R * R) / 2.0)
|
| 222 |
+
yy, xx = torch.meshgrid(
|
| 223 |
+
torch.arange(H, device=x.device, dtype=torch.float32) - H // 2,
|
| 224 |
+
torch.arange(W, device=x.device, dtype=torch.float32) - W // 2,
|
| 225 |
+
indexing="ij")
|
| 226 |
+
center = torch.exp(-(yy**2 + xx**2) / sigma_f)
|
| 227 |
+
mask = (scale * center + hf_boost * (1.0 - center)).view(1, 1, H, W)
|
| 228 |
+
x_f = x_f * mask
|
| 229 |
+
x_f = torch.fft.ifftshift(x_f, dim=(-2, -1))
|
| 230 |
+
return torch.fft.ifftn(x_f, dim=(-2, -1)).real.to(x.dtype)
|
| 231 |
+
|
| 232 |
+
# ββ Band energy stats (ComfyUI utils.py exact) ββββββββββββββββββββββββββββββββ
|
| 233 |
+
def get_band_energy_stats(x: torch.Tensor, R: int) -> Tuple[float, float, float]:
|
| 234 |
+
"""ComfyUI_FreeU_V2_advanced/utils.py get_band_energy_stats() exact."""
|
| 235 |
+
xf = torch.fft.fftn(x.float(), dim=(-2, -1))
|
| 236 |
+
xf = torch.fft.fftshift(xf, dim=(-2, -1))
|
| 237 |
+
B, C, H, W = xf.shape
|
| 238 |
+
yy, xx = torch.meshgrid(
|
| 239 |
+
torch.arange(H, device=x.device, dtype=torch.float32) - H // 2,
|
| 240 |
+
torch.arange(W, device=x.device, dtype=torch.float32) - W // 2,
|
| 241 |
+
indexing="ij")
|
| 242 |
+
lf_mask = (yy**2 + xx**2) <= (R * R)
|
| 243 |
+
mag2 = xf.real**2 + xf.imag**2
|
| 244 |
+
# FIX: expand_as requires same ndim; use 2D mask on last dims
|
| 245 |
+
lf_e = mag2[:, :, lf_mask].mean().item() if lf_mask.any() else 0.0
|
| 246 |
+
hf_e = mag2[:, :, ~lf_mask].mean().item() if (~lf_mask).any() else 0.0
|
| 247 |
+
cover = lf_mask.sum().item() / (H * W) * 100.0
|
| 248 |
+
return lf_e, hf_e, cover
|
| 249 |
+
|
| 250 |
+
# ββ Adaptive Cap Gaussian (FreeU_S1S2.py MAX_CAP_ITER=3 exact) βββββββββββββββ
|
| 251 |
+
def filter_skip_gaussian_adaptive(hsp: torch.Tensor,
|
| 252 |
+
si: "global_state.StageInfo",
|
| 253 |
+
verbose: bool = False) -> torch.Tensor:
|
| 254 |
+
"""
|
| 255 |
+
ComfyUI_FreeU_V2_advanced/FreeU_S1S2.py exact algorithm:
|
| 256 |
+
1. Compute LF/HF ratio before.
|
| 257 |
+
2. Apply Gaussian filter.
|
| 258 |
+
3. If enable_adaptive_cap and drop > cap_threshold: loop up to MAX_CAP_ITER=3.
|
| 259 |
+
adaptive mode: eff_factor = cap_factor * (cap_threshold / drop)
|
| 260 |
+
fixed mode: eff_factor = cap_factor
|
| 261 |
+
capped_s = 1 - eff_factor*(1-s_scale) [interpolate FROM ORIGINAL]
|
| 262 |
+
capped_s = max(capped_s, current_s*(1+1e-4))
|
| 263 |
+
Re-apply from original_hsp with capped_s.
|
| 264 |
+
hf_boost combined = max(si.hf_boost, si.skip_high_end_factor) [kohaku pattern]
|
| 265 |
+
"""
|
| 266 |
+
s_scale = si.skip_factor
|
| 267 |
+
radius_r = si.fft_radius_ratio
|
| 268 |
+
hf_boost = max(si.hf_boost, si.skip_high_end_factor)
|
| 269 |
+
orig_dev = hsp.device
|
| 270 |
+
H, W = hsp.shape[-2:]
|
| 271 |
+
R_eff = max(1, int(min(H, W) * radius_r))
|
| 272 |
+
|
| 273 |
+
# CRITICAL ORDER: init cpu-fallback flag and helpers BEFORE any FFT call
|
| 274 |
+
use_cpu = _on_cpu_devices.get(orig_dev, not is_gpu_complex_supported(hsp))
|
| 275 |
+
if use_cpu:
|
| 276 |
+
_on_cpu_devices[orig_dev] = True
|
| 277 |
+
|
| 278 |
+
def _tod(t): # to FFT-safe device
|
| 279 |
+
return t.cpu() if use_cpu else t
|
| 280 |
+
|
| 281 |
+
def _fromd(t): # back to original device
|
| 282 |
+
return t.to(orig_dev) if use_cpu else t
|
| 283 |
+
|
| 284 |
+
def _energy(t):
|
| 285 |
+
return get_band_energy_stats(_tod(t), R_eff)
|
| 286 |
+
|
| 287 |
+
def _filt(inp, sc):
|
| 288 |
+
nonlocal use_cpu
|
| 289 |
+
try:
|
| 290 |
+
out = fourier_filter_gauss(_tod(inp), radius_r, sc, hf_boost)
|
| 291 |
+
return _fromd(out)
|
| 292 |
+
except Exception:
|
| 293 |
+
if not use_cpu:
|
| 294 |
+
logging.warning(f"[MegaFreeU] {orig_dev} -> CPU fallback for FFT")
|
| 295 |
+
_on_cpu_devices[orig_dev] = True
|
| 296 |
+
use_cpu = True
|
| 297 |
+
return fourier_filter_gauss(inp.cpu(), radius_r, sc, hf_boost).to(orig_dev)
|
| 298 |
+
return inp
|
| 299 |
+
|
| 300 |
+
# Pre-filter energy (now safe on all devices)
|
| 301 |
+
lf_b, hf_b, cover = _energy(hsp)
|
| 302 |
+
ratio_b = lf_b / hf_b if hf_b > 1e-6 else float("inf")
|
| 303 |
+
if verbose:
|
| 304 |
+
logging.info(f"[MegaFreeU] Gauss {H}x{W} R={R_eff}px cov={cover:.1f}% "
|
| 305 |
+
f"LF={lf_b:.3e} HF={hf_b:.3e} ratio_b={ratio_b:.4f}")
|
| 306 |
+
|
| 307 |
+
hsp_filt = _filt(hsp, s_scale)
|
| 308 |
+
if not si.enable_adaptive_cap:
|
| 309 |
+
return hsp_filt
|
| 310 |
+
|
| 311 |
+
MAX_CAP_ITER = 3
|
| 312 |
+
original_hsp = hsp
|
| 313 |
+
current_s = s_scale
|
| 314 |
+
lf_a, hf_a, _ = _energy(hsp_filt)
|
| 315 |
+
ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
|
| 316 |
+
drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
|
| 317 |
+
orig_drop = drop
|
| 318 |
+
iters = 0
|
| 319 |
+
hsp_cur = hsp_filt
|
| 320 |
+
|
| 321 |
+
while (si.enable_adaptive_cap
|
| 322 |
+
and drop > si.cap_threshold
|
| 323 |
+
and current_s < 0.999
|
| 324 |
+
and iters < MAX_CAP_ITER):
|
| 325 |
+
|
| 326 |
+
if iters == 0:
|
| 327 |
+
logging.warning(f"[MegaFreeU] Over-attenuation: drop={drop*100:.1f}% > "
|
| 328 |
+
f"{si.cap_threshold*100:.1f}% s={s_scale:.4f}")
|
| 329 |
+
|
| 330 |
+
eff_f = si.cap_factor
|
| 331 |
+
if si.adaptive_cap_mode == "adaptive":
|
| 332 |
+
eff_f = si.cap_factor * (si.cap_threshold / max(drop, 1e-8))
|
| 333 |
+
|
| 334 |
+
capped_s = 1.0 - eff_f * (1.0 - s_scale) # interpolate from ORIGINAL s
|
| 335 |
+
capped_s = max(capped_s, current_s * (1.0 + 1e-4)) # only ever relax
|
| 336 |
+
if abs(capped_s - current_s) < 1e-4:
|
| 337 |
+
if verbose: logging.info(" Cap converged.")
|
| 338 |
+
break
|
| 339 |
+
|
| 340 |
+
if verbose:
|
| 341 |
+
logging.info(f" Cap iter {iters+1}: s {current_s:.4f}->{capped_s:.4f} eff={eff_f:.4f}")
|
| 342 |
+
|
| 343 |
+
try:
|
| 344 |
+
hsp_new = _filt(original_hsp, capped_s)
|
| 345 |
+
except Exception as e:
|
| 346 |
+
logging.error(f"[MegaFreeU] cap re-apply error: {e}")
|
| 347 |
+
hsp_cur = original_hsp # restore to original on error (ComfyUI FreeU_S1S2.py pattern)
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
hsp_cur = hsp_new
|
| 351 |
+
lf_a, hf_a, _ = _energy(hsp_cur)
|
| 352 |
+
ratio_a = lf_a / hf_a if hf_a > 1e-6 else float("inf")
|
| 353 |
+
drop = 1.0 - (ratio_a / ratio_b) if ratio_b > 1e-6 else 0.0
|
| 354 |
+
current_s = capped_s
|
| 355 |
+
iters += 1
|
| 356 |
+
|
| 357 |
+
if iters > 0 or verbose:
|
| 358 |
+
logging.info(f"[MegaFreeU] Cap done: {orig_drop*100:.1f}%->{drop*100:.1f}% "
|
| 359 |
+
f"({iters} iters s_final={current_s:.4f})")
|
| 360 |
+
return hsp_cur
|
| 361 |
+
|
| 362 |
+
# ββ Schedule (sd-webui-freeu exact) ββββββββββββββββββββββββββββββββββββββββββ
|
| 363 |
+
def get_schedule_ratio() -> float:
|
| 364 |
+
from modules import shared
|
| 365 |
+
st = global_state.instance
|
| 366 |
+
steps = shared.state.sampling_steps or 20
|
| 367 |
+
cur = global_state.current_sampling_step
|
| 368 |
+
start = _to_step(st.start_ratio, steps)
|
| 369 |
+
stop = _to_step(st.stop_ratio, steps)
|
| 370 |
+
if start == stop:
|
| 371 |
+
smooth = 0.0
|
| 372 |
+
elif cur < start:
|
| 373 |
+
smooth = min(1.0, max(0.0, cur / (start + 1e-8)))
|
| 374 |
+
else:
|
| 375 |
+
smooth = min(1.0, max(0.0, 1 + (cur - start) / (start - stop + 1e-8)))
|
| 376 |
+
flat = 1.0 if start <= cur < stop else 0.0
|
| 377 |
+
return lerp(flat, smooth, st.transition_smoothness)
|
| 378 |
+
|
| 379 |
+
def get_stage_bsratio(b_start: float, b_end: float) -> float:
|
| 380 |
+
"""Independent B/S timestep range gate (FreeU_V2_timestepadd concept -> step fraction)."""
|
| 381 |
+
from modules import shared
|
| 382 |
+
steps = max(shared.state.sampling_steps or 20, 1)
|
| 383 |
+
cur = global_state.current_sampling_step
|
| 384 |
+
pct = cur / (steps - 1) if steps > 1 else 0.0
|
| 385 |
+
return 1.0 if b_start <= pct <= b_end else 0.0
|
| 386 |
+
|
| 387 |
+
def _to_step(v, steps):
|
| 388 |
+
return int(v * steps) if isinstance(v, float) else int(v)
|
| 389 |
+
|
| 390 |
+
# ββ Stage auto-detection (FreeU_B1B2.py + kohaku exact) ββββββββββββββββββββββ
|
| 391 |
+
_stage_channels: Tuple[int, int, int] = (1280, 640, 320)
|
| 392 |
+
|
| 393 |
+
def detect_model_channels():
|
| 394 |
+
global _stage_channels
|
| 395 |
+
try:
|
| 396 |
+
from modules import shared
|
| 397 |
+
mc = int(shared.sd_model.model.diffusion_model.model_channels)
|
| 398 |
+
_stage_channels = (mc * 4, mc * 2, mc * 1)
|
| 399 |
+
except Exception:
|
| 400 |
+
_stage_channels = (1280, 640, 320)
|
| 401 |
+
|
| 402 |
+
def get_stage_index(dims: int, channel_threshold: int = 96) -> Optional[int]:
|
| 403 |
+
"""FreeU_B1B2.py abs(ch - target) <= channel_threshold proximity match."""
|
| 404 |
+
for i, target in enumerate(_stage_channels):
|
| 405 |
+
if abs(dims - target) <= channel_threshold:
|
| 406 |
+
return i
|
| 407 |
+
return None
|
| 408 |
+
|
| 409 |
+
# ββ Override scales parser (WAS nodes.py format exact) βββββββββββββββββββββββ
|
| 410 |
+
def parse_override_scales(text: str) -> Optional[List]:
|
| 411 |
+
if not text or not text.strip():
|
| 412 |
+
return None
|
| 413 |
+
result = []
|
| 414 |
+
for line in text.strip().splitlines():
|
| 415 |
+
line = line.strip()
|
| 416 |
+
if not line or line.startswith(("#", "!", "//")):
|
| 417 |
+
continue
|
| 418 |
+
parts = line.split(",")
|
| 419 |
+
if len(parts) == 2:
|
| 420 |
+
try:
|
| 421 |
+
result.append((int(parts[0].strip()), float(parts[1].strip())))
|
| 422 |
+
except ValueError:
|
| 423 |
+
pass
|
| 424 |
+
return result if result else None
|
| 425 |
+
|
| 426 |
+
class _VerboseRef:
|
| 427 |
+
value: bool = False
|
| 428 |
+
verbose_ref = _VerboseRef()
|
| 429 |
+
|
| 430 |
+
# ββ Core th.cat hijack (sd-webui-freeu exact + extended) βββββββββββββββββββββ
|
| 431 |
+
def free_u_cat_hijack(hs, *args, original_function, **kwargs):
|
| 432 |
+
"""
|
| 433 |
+
Intercepts torch.cat([h, h_skip], dim=1) in UNet output_blocks.
|
| 434 |
+
Signature: kwargs=={"dim":1} and len(hs)==2 (sd-webui-freeu exact check).
|
| 435 |
+
|
| 436 |
+
Why th.cat over alternatives:
|
| 437 |
+
- sdwebui-freeU-extension CondFunc(UNetModel.forward): rewrites full forward,
|
| 438 |
+
incompatible with other extensions, plus 2 bugs in fourier mask.
|
| 439 |
+
- kohaku register_forward_hook: output already concatenated,
|
| 440 |
+
can't cleanly separate h from h_skip for independent filtering.
|
| 441 |
+
- th.cat hijack: intercepts exactly [h, h_skip] before concatenation. CORRECT.
|
| 442 |
+
"""
|
| 443 |
+
st = global_state.instance
|
| 444 |
+
if not st.enable:
|
| 445 |
+
return original_function(hs, *args, **kwargs)
|
| 446 |
+
|
| 447 |
+
sched = get_schedule_ratio()
|
| 448 |
+
if sched == 0:
|
| 449 |
+
return original_function(hs, *args, **kwargs)
|
| 450 |
+
|
| 451 |
+
try:
|
| 452 |
+
h, h_skip = hs
|
| 453 |
+
if list(kwargs.keys()) != ["dim"] or kwargs.get("dim", -1) != 1:
|
| 454 |
+
return original_function(hs, *args, **kwargs)
|
| 455 |
+
except (ValueError, TypeError):
|
| 456 |
+
return original_function(hs, *args, **kwargs)
|
| 457 |
+
|
| 458 |
+
dims = int(h.shape[1])
|
| 459 |
+
stage_idx = get_stage_index(dims, st.channel_threshold)
|
| 460 |
+
if stage_idx is None:
|
| 461 |
+
return original_function(hs, *args, **kwargs)
|
| 462 |
+
|
| 463 |
+
si = st.stage_infos[stage_idx]
|
| 464 |
+
version = st.version
|
| 465 |
+
verbose = verbose_ref.value
|
| 466 |
+
|
| 467 |
+
# ββ BACKBONE βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 468 |
+
b_gate = get_stage_bsratio(si.b_start_ratio, si.b_end_ratio)
|
| 469 |
+
eff_b = sched * b_gate
|
| 470 |
+
|
| 471 |
+
if eff_b > 0.0 and abs(si.backbone_factor - 1.0) > 1e-6:
|
| 472 |
+
try:
|
| 473 |
+
rbegin, rend, rinv = ratio_to_region(si.backbone_width, si.backbone_offset, dims)
|
| 474 |
+
ch_idx = torch.arange(dims, device=h.device)
|
| 475 |
+
mask = (rbegin <= ch_idx) & (ch_idx <= rend)
|
| 476 |
+
if rinv: mask = ~mask
|
| 477 |
+
mask = mask.reshape(1, -1, 1, 1).to(h.dtype)
|
| 478 |
+
|
| 479 |
+
eff_factor = float(lerp(1.0, si.backbone_factor, eff_b))
|
| 480 |
+
scale = get_backbone_scale(h, eff_factor, version)
|
| 481 |
+
# h_scaled_full: full h with mask region scaled, rest unchanged
|
| 482 |
+
# This matches original: h *= mask*scale + (1-mask)
|
| 483 |
+
h_scaled_full = h * (mask * scale + (1.0 - mask))
|
| 484 |
+
|
| 485 |
+
bmode = si.backbone_blend_mode
|
| 486 |
+
if bmode in BLENDING_MODES and abs(si.backbone_blend - 1.0) > 1e-6:
|
| 487 |
+
# Blend on FULL tensors so modes like slerp/hslerp see proper norms.
|
| 488 |
+
# Then restore unmasked channels to original h.
|
| 489 |
+
h_blended = BLENDING_MODES[bmode](h, h_scaled_full, si.backbone_blend)
|
| 490 |
+
h = h * (1.0 - mask) + h_blended * mask
|
| 491 |
+
else:
|
| 492 |
+
h = h_scaled_full
|
| 493 |
+
except Exception as e:
|
| 494 |
+
logging.warning(f"[MegaFreeU] B-scaling stage {stage_idx}: {e}")
|
| 495 |
+
|
| 496 |
+
# ββ SKIP / FOURIER ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 497 |
+
s_gate = get_stage_bsratio(si.s_start_ratio, si.s_end_ratio)
|
| 498 |
+
eff_s = sched * s_gate
|
| 499 |
+
|
| 500 |
+
if eff_s > 0.0 and (abs(si.skip_factor - 1.0) > 1e-6
|
| 501 |
+
or abs(si.hf_boost - 1.0) > 1e-6
|
| 502 |
+
or abs(si.skip_high_end_factor - 1.0) > 1e-6):
|
| 503 |
+
try:
|
| 504 |
+
s_scale = float(lerp(1.0, si.skip_factor, eff_s))
|
| 505 |
+
s_high = float(lerp(1.0, si.skip_high_end_factor, eff_s))
|
| 506 |
+
|
| 507 |
+
if si.fft_type == "gaussian":
|
| 508 |
+
hf_b = float(lerp(1.0, si.hf_boost, eff_s))
|
| 509 |
+
si_eff = dataclasses.replace(si, skip_factor=s_scale, skip_high_end_factor=s_high, hf_boost=hf_b)
|
| 510 |
+
h_skip = filter_skip_gaussian_adaptive(h_skip, si_eff, verbose)
|
| 511 |
+
else:
|
| 512 |
+
override = parse_override_scales(st.override_scales)
|
| 513 |
+
ms_preset = override or global_state.MSCALES.get(st.multiscale_mode)
|
| 514 |
+
if ms_preset is not None:
|
| 515 |
+
h_skip = filter_skip_box_multiscale(
|
| 516 |
+
h_skip, si.skip_cutoff, s_scale, ms_preset,
|
| 517 |
+
st.multiscale_strength, s_high)
|
| 518 |
+
else:
|
| 519 |
+
h_skip = filter_skip_box(h_skip, si.skip_cutoff, s_scale, s_high)
|
| 520 |
+
except Exception as e:
|
| 521 |
+
logging.warning(f"[MegaFreeU] skip filter stage {stage_idx}: {e}")
|
| 522 |
+
|
| 523 |
+
return original_function([h, h_skip], *args, **kwargs)
|
| 524 |
+
|
| 525 |
+
# ββ Patch (sd-webui-freeu exact + ControlNet) βββββββββββββββββββββββββββββββββ
|
| 526 |
+
_patched = False # guard against double-patch on hot-reload
|
| 527 |
+
|
| 528 |
+
def patch():
|
| 529 |
+
global _patched
|
| 530 |
+
try:
|
| 531 |
+
from modules.sd_hijack_unet import th
|
| 532 |
+
except ImportError:
|
| 533 |
+
print("[MegaFreeU] sd_hijack_unet not available", file=sys.stderr); return
|
| 534 |
+
|
| 535 |
+
if _patched or (hasattr(th.cat, "func") and getattr(th.cat.func, "__name__", "") == "free_u_cat_hijack"):
|
| 536 |
+
return # already patched (by name; handles module reload)
|
| 537 |
+
th.cat = functools.partial(free_u_cat_hijack, original_function=th.cat)
|
| 538 |
+
_patched = True
|
| 539 |
+
|
| 540 |
+
cn_status = "enabled"
|
| 541 |
+
try:
|
| 542 |
+
from modules import scripts
|
| 543 |
+
cn_paths = [
|
| 544 |
+
str(pathlib.Path(scripts.basedir()).parent.parent / "extensions-builtin" / "sd-webui-controlnet"),
|
| 545 |
+
str(pathlib.Path(scripts.basedir()).parent / "sd-webui-controlnet"),
|
| 546 |
+
]
|
| 547 |
+
sys.path[0:0] = cn_paths
|
| 548 |
+
try:
|
| 549 |
+
import scripts.hook as cn_hook
|
| 550 |
+
cn_hook.th.cat = functools.partial(free_u_cat_hijack, original_function=cn_hook.th.cat)
|
| 551 |
+
except ImportError:
|
| 552 |
+
cn_status = "disabled"
|
| 553 |
+
finally:
|
| 554 |
+
for p in cn_paths:
|
| 555 |
+
if p in sys.path: sys.path.remove(p)
|
| 556 |
+
except Exception:
|
| 557 |
+
cn_status = "error"
|
| 558 |
+
|
| 559 |
+
print(f"[MegaFreeU] th.cat patched ControlNet: *{cn_status}*")
|
mega_freeu_a1111/lib_mega_freeu/xyz_grid.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
lib_mega_freeu/xyz_grid.py β XYZ/XY grid axes
|
| 3 |
+
|
| 4 |
+
Source: sd-webui-freeu/lib_free_u/xyz_grid.py (exact find_xyz_module check)
|
| 5 |
+
Extended with all Mega FreeU per-stage params.
|
| 6 |
+
"""
|
| 7 |
+
import sys
|
| 8 |
+
from types import ModuleType
|
| 9 |
+
from typing import Optional
|
| 10 |
+
from modules import scripts
|
| 11 |
+
from lib_mega_freeu import global_state
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def patch():
|
| 15 |
+
xyz_module = _find_xyz_module()
|
| 16 |
+
if xyz_module is None:
|
| 17 |
+
print("[MegaFreeU] xyz_grid.py not found β XYZ disabled", file=sys.stderr)
|
| 18 |
+
return
|
| 19 |
+
|
| 20 |
+
def _apply(k, key_map=None):
|
| 21 |
+
def cb(_p, v, _vs):
|
| 22 |
+
if key_map is not None:
|
| 23 |
+
v = key_map.get(v, v)
|
| 24 |
+
global_state.xyz_attrs[k] = v
|
| 25 |
+
return cb
|
| 26 |
+
|
| 27 |
+
opts = [
|
| 28 |
+
xyz_module.AxisOption("[MegaFreeU] Enabled", _bool, _apply("enable"),
|
| 29 |
+
choices=lambda: ["False","True"]),
|
| 30 |
+
xyz_module.AxisOption("[MegaFreeU] Version", str, _apply("version",
|
| 31 |
+
key_map=global_state.ALL_VERSIONS),
|
| 32 |
+
choices=lambda: list(global_state.ALL_VERSIONS.keys())),
|
| 33 |
+
xyz_module.AxisOption("[MegaFreeU] Preset", str, _apply("preset"),
|
| 34 |
+
choices=_choices_preset),
|
| 35 |
+
xyz_module.AxisOption("[MegaFreeU] Start At Step", _num, _apply("start_ratio")),
|
| 36 |
+
xyz_module.AxisOption("[MegaFreeU] Stop At Step", _num, _apply("stop_ratio")),
|
| 37 |
+
xyz_module.AxisOption("[MegaFreeU] Smoothness", float, _apply("transition_smoothness")),
|
| 38 |
+
xyz_module.AxisOption("[MegaFreeU] Multiscale Mode", str, _apply("multiscale_mode"),
|
| 39 |
+
choices=lambda: list(global_state.MSCALES.keys())),
|
| 40 |
+
xyz_module.AxisOption("[MegaFreeU] Multiscale Str", float, _apply("multiscale_strength")),
|
| 41 |
+
xyz_module.AxisOption("[MegaFreeU] Ch Threshold", int, _apply("channel_threshold")),
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
for i in range(global_state.STAGES_COUNT):
|
| 45 |
+
n = i + 1
|
| 46 |
+
opts += [
|
| 47 |
+
xyz_module.AxisOption(f"[MFU] S{n} B Scale", float, _apply(f"b{i}")),
|
| 48 |
+
xyz_module.AxisOption(f"[MFU] S{n} B Offset", float, _apply(f"o{i}")),
|
| 49 |
+
xyz_module.AxisOption(f"[MFU] S{n} B Width", float, _apply(f"w{i}")),
|
| 50 |
+
xyz_module.AxisOption(f"[MFU] S{n} Blend Mode", str, _apply(f"bm{i}"),
|
| 51 |
+
choices=lambda: global_state.BLEND_MODE_NAMES),
|
| 52 |
+
xyz_module.AxisOption(f"[MFU] S{n} Blend Str", float, _apply(f"bb{i}")),
|
| 53 |
+
xyz_module.AxisOption(f"[MFU] S{n} S Scale", float, _apply(f"s{i}")),
|
| 54 |
+
xyz_module.AxisOption(f"[MFU] S{n} Cutoff", float, _apply(f"t{i}")),
|
| 55 |
+
xyz_module.AxisOption(f"[MFU] S{n} High-End", float, _apply(f"h{i}")),
|
| 56 |
+
xyz_module.AxisOption(f"[MFU] S{n} HF Boost", float, _apply(f"hfb{i}")),
|
| 57 |
+
xyz_module.AxisOption(f"[MFU] S{n} B Start%", float, _apply(f"bs{i}")),
|
| 58 |
+
xyz_module.AxisOption(f"[MFU] S{n} B End%", float, _apply(f"be{i}")),
|
| 59 |
+
xyz_module.AxisOption(f"[MFU] S{n} S Start%", float, _apply(f"ss{i}")),
|
| 60 |
+
xyz_module.AxisOption(f"[MFU] S{n} S End%", float, _apply(f"se{i}")),
|
| 61 |
+
xyz_module.AxisOption(f"[MFU] S{n} FFT Type", str, _apply(f"ft{i}"),
|
| 62 |
+
choices=lambda: global_state.FFT_TYPES),
|
| 63 |
+
xyz_module.AxisOption(f"[MFU] S{n} Radius", float, _apply(f"fr{i}")),
|
| 64 |
+
xyz_module.AxisOption(f"[MFU] S{n} Cap Enable", _bool, _apply(f"cap{i}")),
|
| 65 |
+
xyz_module.AxisOption(f"[MFU] S{n} Cap Thresh", float, _apply(f"ct{i}")),
|
| 66 |
+
xyz_module.AxisOption(f"[MFU] S{n} Cap Factor", float, _apply(f"cf{i}")),
|
| 67 |
+
xyz_module.AxisOption(f"[MFU] S{n} Cap Mode", str, _apply(f"acm{i}"),
|
| 68 |
+
choices=lambda: ["adaptive", "fixed"]),
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
xyz_module.axis_options.extend(opts)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _find_xyz_module() -> Optional[ModuleType]:
|
| 75 |
+
"""Exact check from sd-webui-freeu/lib_free_u/xyz_grid.py"""
|
| 76 |
+
for data in scripts.scripts_data:
|
| 77 |
+
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
|
| 78 |
+
return data.module
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _choices_preset():
|
| 83 |
+
presets = list(global_state.all_presets.keys())
|
| 84 |
+
presets.insert(0, "UI Settings")
|
| 85 |
+
return presets
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _bool(s):
|
| 89 |
+
s = str(s).lower()
|
| 90 |
+
if s in ("true","1","yes"): return True
|
| 91 |
+
if s in ("false","0","no"): return False
|
| 92 |
+
return bool(s)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _num(s):
|
| 96 |
+
try: return int(s)
|
| 97 |
+
except (ValueError, TypeError): return float(s)
|
mega_freeu_a1111/scripts/mega_freeu.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scripts/mega_freeu.py - Mega FreeU for A1111 / Forge
|
| 3 |
+
|
| 4 |
+
Combined from 5 sources:
|
| 5 |
+
1. sd-webui-freeu th.cat hijack, V1/V2 backbone, box filter, schedule,
|
| 6 |
+
presets JSON, PNG metadata, XYZ, ControlNet, region masking,
|
| 7 |
+
dict-API compat (alwayson_scripts legacy)
|
| 8 |
+
2. WAS FreeU_Advanced 9 blending modes, 13 multi-scale FFT presets, override_scales,
|
| 9 |
+
Post-CFG Shift (WAS_PostCFGShift ported to A1111 callback)
|
| 10 |
+
NOTE: target_block / input_block / middle_block / slice_b1/b2
|
| 11 |
+
were not ported β th.cat hijack works on output-side skip concat.
|
| 12 |
+
3. ComfyUI_FreeU_V2_Adv Gaussian filter, Adaptive Cap (MAX_CAP_ITER=3),
|
| 13 |
+
independent B/S timestep ranges per-stage, channel_threshold
|
| 14 |
+
4. FreeU_V2_timestepadd b_start/b_end%, s_start/s_end% per-stage gating
|
| 15 |
+
NOTE: gating uses step-fraction (cur/total), not percent_to_sigma
|
| 16 |
+
as in original ComfyUI sources. Conceptually equivalent.
|
| 17 |
+
5. nrs_kohaku_v3.5 hf_boost param, on_cpu_devices dict, gaussian standalone
|
| 18 |
+
|
| 19 |
+
BUGS FIXED vs sdwebui-freeU-extension:
|
| 20 |
+
BUG 1: bool mask in Fourier filter (scale multiplication was NOOP)
|
| 21 |
+
BUG 2: single-quadrant mask instead of symmetric center
|
| 22 |
+
"""
|
| 23 |
+
import dataclasses
|
| 24 |
+
import json
|
| 25 |
+
from typing import List
|
| 26 |
+
|
| 27 |
+
import gradio as gr
|
| 28 |
+
from modules import script_callbacks, scripts, shared, processing
|
| 29 |
+
|
| 30 |
+
from lib_mega_freeu import global_state, unet, xyz_grid
|
| 31 |
+
|
| 32 |
+
_steps_comps = {"txt2img": None, "img2img": None}
|
| 33 |
+
_steps_cbs = {"txt2img": [], "img2img": []}
|
| 34 |
+
|
| 35 |
+
_SF = [f.name for f in dataclasses.fields(global_state.StageInfo)]
|
| 36 |
+
_SN = len(_SF) # 19 fields per stage
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _stage_ui(idx, si, elem_id_fn):
|
| 40 |
+
n = idx + 1
|
| 41 |
+
ch = {0: "~1280ch (deep)", 1: "~640ch (mid)", 2: "~320ch (shallow)"}.get(idx, f"stage{n}")
|
| 42 |
+
|
| 43 |
+
with gr.Accordion(open=(idx < 2), label=f"Stage {n} ({ch})"):
|
| 44 |
+
|
| 45 |
+
# Backbone
|
| 46 |
+
gr.HTML(f"<p style=\'margin:4px 0;font-size:.82em;color:#aaa;\'>Backbone h (B)</p>")
|
| 47 |
+
with gr.Row():
|
| 48 |
+
bf = gr.Slider(label=f"B{n} Scale", minimum=-1, maximum=3, step=0.001,
|
| 49 |
+
value=si.backbone_factor,
|
| 50 |
+
info=">1 strengthens backbone features. V2: adaptive per-region.")
|
| 51 |
+
bo = gr.Slider(label=f"B{n} Offset", minimum=0, maximum=1, step=0.001,
|
| 52 |
+
value=si.backbone_offset, info="Channel region start [0-1].")
|
| 53 |
+
bw = gr.Slider(label=f"B{n} Width", minimum=-1, maximum=1, step=0.001,
|
| 54 |
+
value=si.backbone_width, info="Channel region width. Negative=invert.")
|
| 55 |
+
with gr.Row():
|
| 56 |
+
bm = gr.Dropdown(label=f"B{n} Blend Mode",
|
| 57 |
+
choices=global_state.BLEND_MODE_NAMES,
|
| 58 |
+
value=si.backbone_blend_mode,
|
| 59 |
+
info="lerp=default, stable_slerp=quality, inject=additive")
|
| 60 |
+
bb = gr.Slider(label=f"B{n} Blend Str", minimum=0, maximum=2, step=0.001,
|
| 61 |
+
value=si.backbone_blend)
|
| 62 |
+
gr.HTML("<p style=\'font-size:.75em;color:#888;margin:2px 0;\'>B timestep range (ComfyUI V2)</p>")
|
| 63 |
+
with gr.Row():
|
| 64 |
+
bsr = gr.Slider(label=f"B{n} Start%", minimum=0, maximum=1, step=0.001,
|
| 65 |
+
value=si.b_start_ratio, info="B activates at this step fraction.")
|
| 66 |
+
ber = gr.Slider(label=f"B{n} End%", minimum=0, maximum=1, step=0.001,
|
| 67 |
+
value=si.b_end_ratio, info="B stops. 0.35=structure phase only.")
|
| 68 |
+
|
| 69 |
+
# Skip / FFT
|
| 70 |
+
gr.HTML(f"<p style=\'margin:8px 0 4px;font-size:.82em;color:#aaa;\'>Skip h_skip (S) - Fourier Filter</p>")
|
| 71 |
+
with gr.Row():
|
| 72 |
+
sf = gr.Slider(label=f"S{n} LF Scale", minimum=-1, maximum=3, step=0.001,
|
| 73 |
+
value=si.skip_factor,
|
| 74 |
+
info="<1 suppresses LF components. 0.2=strong suppression.")
|
| 75 |
+
she = gr.Slider(label=f"S{n} HF (Box)", minimum=-1, maximum=3, step=0.001,
|
| 76 |
+
value=si.skip_high_end_factor,
|
| 77 |
+
info="HF scale outside LF region (box filter). >1=boost HF.")
|
| 78 |
+
hfb = gr.Slider(label=f"S{n} HF Boost (Gauss)", minimum=0, maximum=3, step=0.001,
|
| 79 |
+
value=si.hf_boost,
|
| 80 |
+
info="Gaussian explicit HF multiplier. Combined as max(hf_boost, high_end).")
|
| 81 |
+
with gr.Row():
|
| 82 |
+
ft = gr.Radio(label=f"S{n} FFT Type",
|
| 83 |
+
choices=global_state.FFT_TYPES, value=si.fft_type,
|
| 84 |
+
info="gaussian=smooth no-ringing. box=original FreeU (both bugs fixed).")
|
| 85 |
+
sco = gr.Slider(label=f"S{n} Cutoff (Box)", minimum=0, maximum=1, step=0.001,
|
| 86 |
+
value=si.skip_cutoff, info="Box: LF cutoff fraction. 0=1px default.")
|
| 87 |
+
srr = gr.Slider(label=f"S{n} Radius (Gauss)", minimum=0.01, maximum=0.5, step=0.001,
|
| 88 |
+
value=si.fft_radius_ratio,
|
| 89 |
+
info="Gaussian R=ratio*min(H,W). 0.07=moderate LF.")
|
| 90 |
+
gr.HTML("<p style=\'font-size:.75em;color:#888;margin:2px 0;\'>S timestep range (ComfyUI V2)</p>")
|
| 91 |
+
with gr.Row():
|
| 92 |
+
ssr = gr.Slider(label=f"S{n} Start%", minimum=0, maximum=1, step=0.001,
|
| 93 |
+
value=si.s_start_ratio,
|
| 94 |
+
info="S activates. Tip: set = B End% for clean phase separation.")
|
| 95 |
+
ser = gr.Slider(label=f"S{n} End%", minimum=0, maximum=1, step=0.001,
|
| 96 |
+
value=si.s_end_ratio, info="S stops. 1.0=to last step.")
|
| 97 |
+
|
| 98 |
+
# Adaptive Cap
|
| 99 |
+
gr.HTML("<p style=\'font-size:.75em;color:#888;margin:4px 0;\'>Adaptive Cap - prevents LF over-attenuation (FreeU_S1S2.py)</p>")
|
| 100 |
+
with gr.Row():
|
| 101 |
+
eac = gr.Checkbox(label=f"S{n} Enable Cap", value=si.enable_adaptive_cap,
|
| 102 |
+
info="Iteratively weakens Gaussian if LF/HF drop exceeds threshold.")
|
| 103 |
+
ct = gr.Slider(label="Threshold", minimum=0, maximum=1, step=0.001,
|
| 104 |
+
value=si.cap_threshold, info="Max allowed LF/HF ratio drop. 0.35=35%.")
|
| 105 |
+
cf = gr.Slider(label="Factor", minimum=0, maximum=1, step=0.001,
|
| 106 |
+
value=si.cap_factor, info="Relaxation factor. 0.6=moderate.")
|
| 107 |
+
cm = gr.Radio(label="Mode", choices=["adaptive", "fixed"],
|
| 108 |
+
value=si.adaptive_cap_mode,
|
| 109 |
+
info="adaptive: scales factor with over-attenuation. fixed: always cap_factor.")
|
| 110 |
+
|
| 111 |
+
# Return exactly in _SF field order
|
| 112 |
+
return [bf, sf, bo, bw, sco, she, bm, bb, bsr, ber, ssr, ser, ft, srr, hfb, eac, ct, cf, cm]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class MegaFreeUScript(scripts.Script):
|
| 116 |
+
|
| 117 |
+
def title(self): return "Mega FreeU"
|
| 118 |
+
def show(self, is_img2img): return scripts.AlwaysVisible
|
| 119 |
+
|
| 120 |
+
def ui(self, is_img2img):
|
| 121 |
+
global_state.reload_presets()
|
| 122 |
+
pnames = list(global_state.all_presets.keys())
|
| 123 |
+
def_sis = global_state.all_presets[pnames[0]].stage_infos
|
| 124 |
+
|
| 125 |
+
with gr.Accordion(open=False, label="Mega FreeU"):
|
| 126 |
+
|
| 127 |
+
# Top bar
|
| 128 |
+
with gr.Row():
|
| 129 |
+
enabled = gr.Checkbox(label="Enable Mega FreeU", value=False)
|
| 130 |
+
version = gr.Dropdown(
|
| 131 |
+
label="Version",
|
| 132 |
+
choices=list(global_state.ALL_VERSIONS.keys()),
|
| 133 |
+
value="Version 2",
|
| 134 |
+
elem_id=self.elem_id("version"),
|
| 135 |
+
info="V2=adaptive hidden-mean backbone. V1=flat multiplier.")
|
| 136 |
+
|
| 137 |
+
with gr.Row():
|
| 138 |
+
preset_dd = gr.Dropdown(
|
| 139 |
+
label="Preset", choices=pnames, value=pnames[0],
|
| 140 |
+
allow_custom_value=True,
|
| 141 |
+
elem_id=self.elem_id("preset_name"),
|
| 142 |
+
info="Apply loads settings. Custom name enables Save. Delete auto-saves.")
|
| 143 |
+
btn_apply = gr.Button("Apply", size="sm", elem_classes="tool")
|
| 144 |
+
btn_save = gr.Button("Save", size="sm", elem_classes="tool")
|
| 145 |
+
btn_refresh = gr.Button("Refresh", size="sm", elem_classes="tool")
|
| 146 |
+
btn_delete = gr.Button("Delete", size="sm", elem_classes="tool")
|
| 147 |
+
|
| 148 |
+
# Global schedule
|
| 149 |
+
gr.HTML("<p style=\'font-size:.82em;color:#aaa;margin:6px 0 2px;\'>Global Schedule</p>")
|
| 150 |
+
with gr.Row():
|
| 151 |
+
start_r = gr.Slider(label="Start At", elem_id=self.elem_id("start_at_step"),
|
| 152 |
+
minimum=0, maximum=1, step=0.001, value=0)
|
| 153 |
+
stop_r = gr.Slider(label="Stop At", elem_id=self.elem_id("stop_at_step"),
|
| 154 |
+
minimum=0, maximum=1, step=0.001, value=1)
|
| 155 |
+
smooth = gr.Slider(label="Transition Smoothness",
|
| 156 |
+
elem_id=self.elem_id("transition_smoothness"),
|
| 157 |
+
minimum=0, maximum=1, step=0.001, value=0,
|
| 158 |
+
info="0=hard on/off. 1=smooth fade.")
|
| 159 |
+
|
| 160 |
+
# Box Multi-Scale (WAS FreeU_Advanced)
|
| 161 |
+
with gr.Accordion(open=False, label="Box Multi-Scale FFT (WAS FreeU_Advanced)"):
|
| 162 |
+
gr.HTML("<p style=\'font-size:.8em;color:#888;\'>Applied on top of Box filter. Ignored in Gaussian mode.</p>")
|
| 163 |
+
with gr.Row():
|
| 164 |
+
ms_mode = gr.Dropdown(label="Multiscale Mode",
|
| 165 |
+
choices=list(global_state.MSCALES.keys()),
|
| 166 |
+
value="Default")
|
| 167 |
+
ms_str = gr.Slider(label="Strength", minimum=0, maximum=1,
|
| 168 |
+
step=0.001, value=1.0)
|
| 169 |
+
ov_scales = gr.Textbox(
|
| 170 |
+
label="Override Scales (WAS format: radius_px, scale per line, # comments)",
|
| 171 |
+
lines=3,
|
| 172 |
+
placeholder="# Example custom scales:\n10, 1.5\n20, 0.8",
|
| 173 |
+
value="")
|
| 174 |
+
|
| 175 |
+
with gr.Row():
|
| 176 |
+
ch_thresh = gr.Slider(
|
| 177 |
+
label="Channel Match Threshold (+-)",
|
| 178 |
+
elem_id=self.elem_id("ch_thresh"),
|
| 179 |
+
minimum=0, maximum=256, step=1, value=96,
|
| 180 |
+
info="Stage channel tolerance. 96=standard (FreeU_B1B2.py default).")
|
| 181 |
+
|
| 182 |
+
# Per-stage accordions
|
| 183 |
+
flat_comps: List = []
|
| 184 |
+
for i in range(global_state.STAGES_COUNT):
|
| 185 |
+
si = def_sis[i] if i < len(def_sis) else global_state.StageInfo()
|
| 186 |
+
flat_comps.extend(_stage_ui(i, si, self.elem_id))
|
| 187 |
+
|
| 188 |
+
# Post-CFG Shift (WAS_PostCFGShift -> A1111)
|
| 189 |
+
with gr.Accordion(open=False, label="Post-CFG Shift (WAS_PostCFGShift -> A1111 callback)"):
|
| 190 |
+
gr.HTML("<p style=\'font-size:.8em;color:#888;\'>Runs after combine_denoised. Blends denoised*b into output via on_cfg_after_cfg callback.</p>")
|
| 191 |
+
with gr.Row():
|
| 192 |
+
pcfg_en = gr.Checkbox(label="Enable Post-CFG Shift", value=False)
|
| 193 |
+
pcfg_steps = gr.Slider(label="Max Steps", minimum=1, maximum=200,
|
| 194 |
+
step=1, value=20,
|
| 195 |
+
info="Apply only to first N steps.")
|
| 196 |
+
with gr.Row():
|
| 197 |
+
pcfg_mode = gr.Dropdown(label="Blend Mode",
|
| 198 |
+
choices=global_state.BLEND_MODE_NAMES,
|
| 199 |
+
value="inject")
|
| 200 |
+
pcfg_bl = gr.Slider(label="Blend", minimum=0, maximum=5,
|
| 201 |
+
step=0.001, value=1.0)
|
| 202 |
+
pcfg_b = gr.Slider(label="B Factor", minimum=0, maximum=5,
|
| 203 |
+
step=0.001, value=1.1,
|
| 204 |
+
info=">1 amplifies shift.")
|
| 205 |
+
with gr.Row():
|
| 206 |
+
pcfg_fou = gr.Checkbox(label="Apply Fourier Filter", value=False)
|
| 207 |
+
pcfg_mmd = gr.Dropdown(label="Fourier Multiscale",
|
| 208 |
+
choices=list(global_state.MSCALES.keys()),
|
| 209 |
+
value="Default")
|
| 210 |
+
pcfg_mst = gr.Slider(label="Fourier Strength", minimum=0, maximum=1,
|
| 211 |
+
step=0.001, value=1.0)
|
| 212 |
+
with gr.Row():
|
| 213 |
+
pcfg_thr = gr.Slider(label="Threshold (px)", minimum=1, maximum=20,
|
| 214 |
+
step=1, value=1,
|
| 215 |
+
info="Box filter LF radius in pixels.")
|
| 216 |
+
pcfg_s = gr.Slider(label="S Scale", minimum=0, maximum=3,
|
| 217 |
+
step=0.001, value=0.5)
|
| 218 |
+
pcfg_gain = gr.Slider(label="Force Gain", minimum=0, maximum=5,
|
| 219 |
+
step=0.01, value=1.0,
|
| 220 |
+
info="Final output multiplier.")
|
| 221 |
+
|
| 222 |
+
verbose = gr.Checkbox(label="Verbose Logging (Adaptive Cap, energy stats)", value=False)
|
| 223 |
+
|
| 224 |
+
# Hidden PNG infotext components
|
| 225 |
+
sched_info = gr.HTML(visible=False)
|
| 226 |
+
stages_info = gr.HTML(visible=False)
|
| 227 |
+
version_info = gr.HTML(visible=False)
|
| 228 |
+
ms_mode_info = gr.HTML(visible=False)
|
| 229 |
+
ms_str_info = gr.HTML(visible=False)
|
| 230 |
+
ov_scales_info = gr.HTML(visible=False)
|
| 231 |
+
ch_thresh_info = gr.HTML(visible=False)
|
| 232 |
+
postcfg_info = gr.HTML(visible=False)
|
| 233 |
+
verbose_info = gr.HTML(visible=False)
|
| 234 |
+
# Legacy sd-webui-freeu keys for backward compat
|
| 235 |
+
legacy_sched_info = gr.HTML(visible=False)
|
| 236 |
+
legacy_stages_info = gr.HTML(visible=False)
|
| 237 |
+
legacy_version_info = gr.HTML(visible=False)
|
| 238 |
+
|
| 239 |
+
# Preset buttons
|
| 240 |
+
def _btn_upd(name):
|
| 241 |
+
ex = name in global_state.all_presets
|
| 242 |
+
usr = name not in global_state.default_presets
|
| 243 |
+
return (gr.update(interactive=ex),
|
| 244 |
+
gr.update(interactive=usr),
|
| 245 |
+
gr.update(interactive=usr and ex))
|
| 246 |
+
|
| 247 |
+
preset_dd.change(fn=_btn_upd, inputs=[preset_dd],
|
| 248 |
+
outputs=[btn_apply, btn_save, btn_delete])
|
| 249 |
+
|
| 250 |
+
def _apply_p(name):
|
| 251 |
+
p = global_state.all_presets.get(name)
|
| 252 |
+
n_extras = 20 # 8 main + 11 Post-CFG + 1 verbose
|
| 253 |
+
if p is None:
|
| 254 |
+
return [gr.skip()] * (n_extras + len(flat_comps))
|
| 255 |
+
flat = []
|
| 256 |
+
for si in p.stage_infos:
|
| 257 |
+
for f in _SF:
|
| 258 |
+
flat.append(getattr(si, f))
|
| 259 |
+
vlabel = global_state.REVERSED_VERSIONS.get(p.version, "Version 2")
|
| 260 |
+
return (
|
| 261 |
+
gr.update(value=p.start_ratio),
|
| 262 |
+
gr.update(value=p.stop_ratio),
|
| 263 |
+
gr.update(value=p.transition_smoothness),
|
| 264 |
+
gr.update(value=vlabel),
|
| 265 |
+
gr.update(value=p.multiscale_mode),
|
| 266 |
+
gr.update(value=p.multiscale_strength),
|
| 267 |
+
gr.update(value=p.override_scales),
|
| 268 |
+
gr.update(value=p.channel_threshold),
|
| 269 |
+
gr.update(value=p.pcfg_enabled),
|
| 270 |
+
gr.update(value=p.pcfg_steps),
|
| 271 |
+
gr.update(value=p.pcfg_mode),
|
| 272 |
+
gr.update(value=p.pcfg_blend),
|
| 273 |
+
gr.update(value=p.pcfg_b),
|
| 274 |
+
gr.update(value=p.pcfg_fourier),
|
| 275 |
+
gr.update(value=p.pcfg_ms_mode),
|
| 276 |
+
gr.update(value=p.pcfg_ms_str),
|
| 277 |
+
gr.update(value=p.pcfg_threshold),
|
| 278 |
+
gr.update(value=p.pcfg_s),
|
| 279 |
+
gr.update(value=p.pcfg_gain),
|
| 280 |
+
gr.update(value=p.verbose),
|
| 281 |
+
*[gr.update(value=v) for v in flat],
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
btn_apply.click(
|
| 285 |
+
fn=_apply_p,
|
| 286 |
+
inputs=[preset_dd],
|
| 287 |
+
outputs=[
|
| 288 |
+
start_r, stop_r, smooth, version,
|
| 289 |
+
ms_mode, ms_str, ov_scales, ch_thresh,
|
| 290 |
+
pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
|
| 291 |
+
pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
|
| 292 |
+
verbose,
|
| 293 |
+
*flat_comps,
|
| 294 |
+
]
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def _save_p(
|
| 298 |
+
name, sr, sp, sm, ver, msm, mss, ovs, cht,
|
| 299 |
+
p_en, p_steps, p_mode, p_bl, p_b,
|
| 300 |
+
p_four, p_mmd, p_mst, p_thr, p_s, p_gain,
|
| 301 |
+
v_log,
|
| 302 |
+
*flat
|
| 303 |
+
):
|
| 304 |
+
sis = _flat_to_sis(flat)
|
| 305 |
+
vc = global_state.ALL_VERSIONS.get(ver, "1")
|
| 306 |
+
global_state.all_presets[name] = global_state.State(
|
| 307 |
+
start_ratio=sr, stop_ratio=sp, transition_smoothness=sm,
|
| 308 |
+
version=vc,
|
| 309 |
+
multiscale_mode=msm,
|
| 310 |
+
multiscale_strength=float(mss),
|
| 311 |
+
override_scales=ovs or "",
|
| 312 |
+
channel_threshold=int(cht),
|
| 313 |
+
stage_infos=sis,
|
| 314 |
+
pcfg_enabled=bool(p_en),
|
| 315 |
+
pcfg_steps=int(p_steps),
|
| 316 |
+
pcfg_mode=str(p_mode),
|
| 317 |
+
pcfg_blend=float(p_bl),
|
| 318 |
+
pcfg_b=float(p_b),
|
| 319 |
+
pcfg_fourier=bool(p_four),
|
| 320 |
+
pcfg_ms_mode=str(p_mmd),
|
| 321 |
+
pcfg_ms_str=float(p_mst),
|
| 322 |
+
pcfg_threshold=int(p_thr),
|
| 323 |
+
pcfg_s=float(p_s),
|
| 324 |
+
pcfg_gain=float(p_gain),
|
| 325 |
+
verbose=bool(v_log),
|
| 326 |
+
)
|
| 327 |
+
global_state.save_presets()
|
| 328 |
+
return (
|
| 329 |
+
gr.update(choices=list(global_state.all_presets.keys())),
|
| 330 |
+
gr.update(interactive=True),
|
| 331 |
+
gr.update(interactive=True),
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
btn_save.click(
|
| 335 |
+
fn=_save_p,
|
| 336 |
+
inputs=[
|
| 337 |
+
preset_dd, start_r, stop_r, smooth, version,
|
| 338 |
+
ms_mode, ms_str, ov_scales, ch_thresh,
|
| 339 |
+
pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
|
| 340 |
+
pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
|
| 341 |
+
verbose,
|
| 342 |
+
*flat_comps,
|
| 343 |
+
],
|
| 344 |
+
outputs=[preset_dd, btn_apply, btn_delete]
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def _refresh_p(name):
|
| 348 |
+
global_state.reload_presets()
|
| 349 |
+
ex = name in global_state.all_presets
|
| 350 |
+
usr = name not in global_state.default_presets
|
| 351 |
+
ch = list(global_state.all_presets.keys())
|
| 352 |
+
return (gr.update(choices=ch, value=name),
|
| 353 |
+
gr.update(interactive=ex), gr.update(interactive=usr),
|
| 354 |
+
gr.update(interactive=usr and ex))
|
| 355 |
+
|
| 356 |
+
btn_refresh.click(fn=_refresh_p, inputs=[preset_dd],
|
| 357 |
+
outputs=[preset_dd, btn_apply, btn_save, btn_delete])
|
| 358 |
+
|
| 359 |
+
def _delete_p(name):
|
| 360 |
+
if name in global_state.all_presets and name not in global_state.default_presets:
|
| 361 |
+
idx = list(global_state.all_presets.keys()).index(name)
|
| 362 |
+
del global_state.all_presets[name]
|
| 363 |
+
global_state.save_presets()
|
| 364 |
+
names = list(global_state.all_presets.keys())
|
| 365 |
+
name = names[min(idx, len(names) - 1)]
|
| 366 |
+
ex = name in global_state.all_presets
|
| 367 |
+
usr = name not in global_state.default_presets
|
| 368 |
+
return (gr.update(choices=list(global_state.all_presets.keys()), value=name),
|
| 369 |
+
gr.update(interactive=ex), gr.update(interactive=usr),
|
| 370 |
+
gr.update(interactive=usr and ex))
|
| 371 |
+
|
| 372 |
+
btn_delete.click(fn=_delete_p, inputs=[preset_dd],
|
| 373 |
+
outputs=[preset_dd, btn_apply, btn_save, btn_delete])
|
| 374 |
+
|
| 375 |
+
# PNG schedule restore
|
| 376 |
+
def _restore_sched(info, steps):
|
| 377 |
+
if not info: return [gr.skip()] * 4
|
| 378 |
+
try:
|
| 379 |
+
parts = info.split(", ")
|
| 380 |
+
sr, sp, sm = parts[0], parts[1], parts[2]
|
| 381 |
+
total = max(int(float(steps)), 1)
|
| 382 |
+
def _r(v):
|
| 383 |
+
n = float(v.strip())
|
| 384 |
+
return n / total if n > 1.0 else n
|
| 385 |
+
return (gr.update(value=""), gr.update(value=_r(sr)),
|
| 386 |
+
gr.update(value=_r(sp)), gr.update(value=float(sm)))
|
| 387 |
+
except Exception:
|
| 388 |
+
return [gr.skip()] * 4
|
| 389 |
+
|
| 390 |
+
def _reg_sched_cb(steps_comp):
|
| 391 |
+
sched_info.change(fn=_restore_sched,
|
| 392 |
+
inputs=[sched_info, steps_comp],
|
| 393 |
+
outputs=[sched_info, start_r, stop_r, smooth])
|
| 394 |
+
|
| 395 |
+
mode_key = "img2img" if is_img2img else "txt2img"
|
| 396 |
+
if _steps_comps[mode_key] is None:
|
| 397 |
+
_steps_cbs[mode_key].append(_reg_sched_cb)
|
| 398 |
+
else:
|
| 399 |
+
_reg_sched_cb(_steps_comps[mode_key])
|
| 400 |
+
|
| 401 |
+
def _restore_stages(info):
|
| 402 |
+
n_out = 2 + len(flat_comps)
|
| 403 |
+
if not info: return [gr.skip()] * n_out
|
| 404 |
+
try:
|
| 405 |
+
raw_list = json.loads(info)
|
| 406 |
+
sis = []
|
| 407 |
+
for d in raw_list:
|
| 408 |
+
known = {k: v for k, v in d.items()
|
| 409 |
+
if k in global_state.STAGE_FIELD_NAMES}
|
| 410 |
+
sis.append(global_state.StageInfo(**known))
|
| 411 |
+
while len(sis) < global_state.STAGES_COUNT:
|
| 412 |
+
sis.append(global_state.StageInfo())
|
| 413 |
+
except Exception:
|
| 414 |
+
return [gr.skip()] * n_out
|
| 415 |
+
flat = []
|
| 416 |
+
for si in sis:
|
| 417 |
+
for f in _SF:
|
| 418 |
+
flat.append(getattr(si, f, getattr(global_state.StageInfo(), f)))
|
| 419 |
+
auto_en = shared.opts.data.get("mega_freeu_png_auto_enable", True)
|
| 420 |
+
return (gr.update(value=""), gr.update(value=auto_en),
|
| 421 |
+
*[gr.update(value=v) for v in flat])
|
| 422 |
+
|
| 423 |
+
stages_info.change(fn=_restore_stages, inputs=[stages_info],
|
| 424 |
+
outputs=[stages_info, enabled, *flat_comps])
|
| 425 |
+
|
| 426 |
+
def _restore_ver(info):
|
| 427 |
+
if not info: return [gr.skip()] * 2
|
| 428 |
+
lbl = global_state.REVERSED_VERSIONS.get(info.strip(), info.strip())
|
| 429 |
+
return gr.update(value=""), gr.update(value=lbl)
|
| 430 |
+
|
| 431 |
+
version_info.change(fn=_restore_ver, inputs=[version_info],
|
| 432 |
+
outputs=[version_info, version])
|
| 433 |
+
|
| 434 |
+
# ββ New extended PNG restore callbacks βββββββββββββββββββββββββββββ
|
| 435 |
+
def _restore_ms_mode(info):
|
| 436 |
+
if not info: return gr.skip(), gr.skip()
|
| 437 |
+
return gr.update(value=""), gr.update(value=info.strip())
|
| 438 |
+
|
| 439 |
+
def _restore_ms_str(info):
|
| 440 |
+
if not info: return gr.skip(), gr.skip()
|
| 441 |
+
try: return gr.update(value=""), gr.update(value=float(info.strip()))
|
| 442 |
+
except Exception: return gr.skip(), gr.skip()
|
| 443 |
+
|
| 444 |
+
def _restore_ov_scales(info):
|
| 445 |
+
if info is None: return gr.skip(), gr.skip()
|
| 446 |
+
return gr.update(value=""), gr.update(value=info)
|
| 447 |
+
|
| 448 |
+
def _restore_ch_thresh(info):
|
| 449 |
+
if not info: return gr.skip(), gr.skip()
|
| 450 |
+
try: return gr.update(value=""), gr.update(value=int(float(info.strip())))
|
| 451 |
+
except Exception: return gr.skip(), gr.skip()
|
| 452 |
+
|
| 453 |
+
def _restore_verbose(info):
|
| 454 |
+
if not info: return gr.skip(), gr.skip()
|
| 455 |
+
return gr.update(value=""), gr.update(value=(info.strip().lower() == "true"))
|
| 456 |
+
|
| 457 |
+
def _restore_postcfg(info):
|
| 458 |
+
n = 12
|
| 459 |
+
if not info: return [gr.skip()] * n
|
| 460 |
+
try:
|
| 461 |
+
d = json.loads(info)
|
| 462 |
+
return (
|
| 463 |
+
gr.update(value=""),
|
| 464 |
+
gr.update(value=bool(d.get("enabled", False))),
|
| 465 |
+
gr.update(value=int(d.get("steps", 20))),
|
| 466 |
+
gr.update(value=str(d.get("mode", "inject"))),
|
| 467 |
+
gr.update(value=float(d.get("blend", 1.0))),
|
| 468 |
+
gr.update(value=float(d.get("b", 1.1))),
|
| 469 |
+
gr.update(value=bool(d.get("fourier", False))),
|
| 470 |
+
gr.update(value=str(d.get("ms_mode", "Default"))),
|
| 471 |
+
gr.update(value=float(d.get("ms_str", 1.0))),
|
| 472 |
+
gr.update(value=int(d.get("threshold", 1))),
|
| 473 |
+
gr.update(value=float(d.get("s", 0.5))),
|
| 474 |
+
gr.update(value=float(d.get("gain", 1.0))),
|
| 475 |
+
)
|
| 476 |
+
except Exception:
|
| 477 |
+
return [gr.skip()] * n
|
| 478 |
+
|
| 479 |
+
ms_mode_info.change(fn=_restore_ms_mode, inputs=[ms_mode_info],
|
| 480 |
+
outputs=[ms_mode_info, ms_mode])
|
| 481 |
+
ms_str_info.change(fn=_restore_ms_str, inputs=[ms_str_info],
|
| 482 |
+
outputs=[ms_str_info, ms_str])
|
| 483 |
+
ov_scales_info.change(fn=_restore_ov_scales, inputs=[ov_scales_info],
|
| 484 |
+
outputs=[ov_scales_info, ov_scales])
|
| 485 |
+
ch_thresh_info.change(fn=_restore_ch_thresh, inputs=[ch_thresh_info],
|
| 486 |
+
outputs=[ch_thresh_info, ch_thresh])
|
| 487 |
+
verbose_info.change(fn=_restore_verbose, inputs=[verbose_info],
|
| 488 |
+
outputs=[verbose_info, verbose])
|
| 489 |
+
postcfg_info.change(fn=_restore_postcfg, inputs=[postcfg_info],
|
| 490 |
+
outputs=[postcfg_info,
|
| 491 |
+
pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
|
| 492 |
+
pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain])
|
| 493 |
+
|
| 494 |
+
# Legacy sd-webui-freeu keys β reuse same restore logic
|
| 495 |
+
legacy_sched_info.change(fn=lambda info, steps: _restore_sched(info, steps),
|
| 496 |
+
inputs=[legacy_sched_info, _steps_comps.get(mode_key) or legacy_sched_info],
|
| 497 |
+
outputs=[legacy_sched_info, start_r, stop_r, smooth])
|
| 498 |
+
legacy_stages_info.change(fn=_restore_stages, inputs=[legacy_stages_info],
|
| 499 |
+
outputs=[legacy_stages_info, enabled, *flat_comps])
|
| 500 |
+
legacy_version_info.change(fn=_restore_ver, inputs=[legacy_version_info],
|
| 501 |
+
outputs=[legacy_version_info, version])
|
| 502 |
+
|
| 503 |
+
self.infotext_fields = [
|
| 504 |
+
(sched_info, "MegaFreeU Schedule"),
|
| 505 |
+
(stages_info, "MegaFreeU Stages"),
|
| 506 |
+
(version_info, "MegaFreeU Version"),
|
| 507 |
+
(ms_mode_info, "MegaFreeU Multiscale Mode"),
|
| 508 |
+
(ms_str_info, "MegaFreeU Multiscale Strength"),
|
| 509 |
+
(ov_scales_info, "MegaFreeU Override Scales"),
|
| 510 |
+
(ch_thresh_info, "MegaFreeU Channel Threshold"),
|
| 511 |
+
(postcfg_info, "MegaFreeU PostCFG"),
|
| 512 |
+
(verbose_info, "MegaFreeU Verbose"),
|
| 513 |
+
# Backward compat with sd-webui-freeu generated PNGs
|
| 514 |
+
(legacy_sched_info, "FreeU Schedule"),
|
| 515 |
+
(legacy_stages_info, "FreeU Stages"),
|
| 516 |
+
(legacy_version_info,"FreeU Version"),
|
| 517 |
+
]
|
| 518 |
+
self.paste_field_names = [f for _, f in self.infotext_fields]
|
| 519 |
+
|
| 520 |
+
return [
|
| 521 |
+
enabled, version, preset_dd,
|
| 522 |
+
start_r, stop_r, smooth,
|
| 523 |
+
ms_mode, ms_str, ov_scales,
|
| 524 |
+
ch_thresh,
|
| 525 |
+
*flat_comps,
|
| 526 |
+
pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
|
| 527 |
+
pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain,
|
| 528 |
+
verbose,
|
| 529 |
+
]
|
| 530 |
+
|
| 531 |
+
def process(self, p: processing.StableDiffusionProcessing, *args):
|
| 532 |
+
# ββ Branch 1: old sd-webui-freeu API (dict passed as first arg) βββββββ
|
| 533 |
+
if args and isinstance(args[0], dict):
|
| 534 |
+
global_state.instance = global_state.State(**{
|
| 535 |
+
k: v for k, v in args[0].items()
|
| 536 |
+
if k in {f.name for f in dataclasses.fields(global_state.State)}
|
| 537 |
+
})
|
| 538 |
+
global_state.apply_xyz()
|
| 539 |
+
global_state.xyz_attrs.clear()
|
| 540 |
+
st = global_state.instance
|
| 541 |
+
unet.verbose_ref.value = bool(getattr(st, "verbose", False))
|
| 542 |
+
if getattr(st, "pcfg_enabled", False):
|
| 543 |
+
p._mega_pcfg = {
|
| 544 |
+
"enabled": True,
|
| 545 |
+
"steps": st.pcfg_steps,
|
| 546 |
+
"mode": st.pcfg_mode,
|
| 547 |
+
"blend": st.pcfg_blend,
|
| 548 |
+
"b": st.pcfg_b,
|
| 549 |
+
"fourier": st.pcfg_fourier,
|
| 550 |
+
"ms_mode": st.pcfg_ms_mode,
|
| 551 |
+
"ms_str": st.pcfg_ms_str,
|
| 552 |
+
"threshold": st.pcfg_threshold,
|
| 553 |
+
"s": st.pcfg_s,
|
| 554 |
+
"gain": st.pcfg_gain,
|
| 555 |
+
"step": 0,
|
| 556 |
+
}
|
| 557 |
+
else:
|
| 558 |
+
p._mega_pcfg = {"enabled": False}
|
| 559 |
+
if st.enable:
|
| 560 |
+
unet.detect_model_channels()
|
| 561 |
+
unet._on_cpu_devices.clear()
|
| 562 |
+
_write_generation_params(p, st)
|
| 563 |
+
return
|
| 564 |
+
|
| 565 |
+
# ββ Branch 2: normal UI call βββββββββββββββββββββββββββββββββββββββββββ
|
| 566 |
+
(enabled, version, preset_dd,
|
| 567 |
+
start_r, stop_r, smooth,
|
| 568 |
+
ms_mode, ms_str, ov_scales,
|
| 569 |
+
ch_thresh, *rest) = args
|
| 570 |
+
|
| 571 |
+
n_sv = _SN * global_state.STAGES_COUNT
|
| 572 |
+
flat_stage = rest[:n_sv]
|
| 573 |
+
post = rest[n_sv:] # 11 pcfg params + verbose
|
| 574 |
+
|
| 575 |
+
verbose = bool(post[11]) if len(post) > 11 else False
|
| 576 |
+
unet.verbose_ref.value = verbose
|
| 577 |
+
|
| 578 |
+
# Write UI values into instance BEFORE apply_xyz so XYZ can override any of them
|
| 579 |
+
inst = global_state.instance
|
| 580 |
+
inst.enable = bool(enabled)
|
| 581 |
+
inst.start_ratio = start_r
|
| 582 |
+
inst.stop_ratio = stop_r
|
| 583 |
+
inst.transition_smoothness = smooth
|
| 584 |
+
inst.version = global_state.ALL_VERSIONS.get(version, "1")
|
| 585 |
+
inst.multiscale_mode = ms_mode
|
| 586 |
+
inst.multiscale_strength = float(ms_str)
|
| 587 |
+
inst.override_scales = ov_scales or ""
|
| 588 |
+
inst.channel_threshold = int(ch_thresh)
|
| 589 |
+
inst.stage_infos = _flat_to_sis(flat_stage)
|
| 590 |
+
|
| 591 |
+
# Sync Post-CFG into instance state so presets/PNG capture it
|
| 592 |
+
pcfg = post[:11]
|
| 593 |
+
if len(pcfg) >= 11:
|
| 594 |
+
inst.pcfg_enabled = bool(pcfg[0])
|
| 595 |
+
inst.pcfg_steps = int(pcfg[1])
|
| 596 |
+
inst.pcfg_mode = str(pcfg[2])
|
| 597 |
+
inst.pcfg_blend = float(pcfg[3])
|
| 598 |
+
inst.pcfg_b = float(pcfg[4])
|
| 599 |
+
inst.pcfg_fourier = bool(pcfg[5])
|
| 600 |
+
inst.pcfg_ms_mode = str(pcfg[6])
|
| 601 |
+
inst.pcfg_ms_str = float(pcfg[7])
|
| 602 |
+
inst.pcfg_threshold = int(pcfg[8])
|
| 603 |
+
inst.pcfg_s = float(pcfg[9])
|
| 604 |
+
inst.pcfg_gain = float(pcfg[10])
|
| 605 |
+
inst.verbose = verbose
|
| 606 |
+
|
| 607 |
+
# apply_xyz() may replace global_state.instance with a preset copy;
|
| 608 |
+
# take the fresh reference AFTER so PNG metadata / verbose use the final state.
|
| 609 |
+
global_state.apply_xyz()
|
| 610 |
+
global_state.xyz_attrs.clear()
|
| 611 |
+
st = global_state.instance # β fresh ref post-XYZ
|
| 612 |
+
|
| 613 |
+
# ββ Post-CFG: set up ALWAYS (independent of main Enable) ββββββββββββββ
|
| 614 |
+
if st.pcfg_enabled:
|
| 615 |
+
p._mega_pcfg = {
|
| 616 |
+
"enabled": True,
|
| 617 |
+
"steps": st.pcfg_steps,
|
| 618 |
+
"mode": st.pcfg_mode,
|
| 619 |
+
"blend": st.pcfg_blend,
|
| 620 |
+
"b": st.pcfg_b,
|
| 621 |
+
"fourier": st.pcfg_fourier,
|
| 622 |
+
"ms_mode": st.pcfg_ms_mode,
|
| 623 |
+
"ms_str": st.pcfg_ms_str,
|
| 624 |
+
"threshold": st.pcfg_threshold,
|
| 625 |
+
"s": st.pcfg_s,
|
| 626 |
+
"gain": st.pcfg_gain,
|
| 627 |
+
"step": 0,
|
| 628 |
+
}
|
| 629 |
+
else:
|
| 630 |
+
p._mega_pcfg = {"enabled": False}
|
| 631 |
+
|
| 632 |
+
if not st.enable:
|
| 633 |
+
# Write partial params so PNG records the session even when disabled
|
| 634 |
+
_write_generation_params(p, st)
|
| 635 |
+
return
|
| 636 |
+
|
| 637 |
+
unet.detect_model_channels()
|
| 638 |
+
unet._on_cpu_devices.clear()
|
| 639 |
+
|
| 640 |
+
_write_generation_params(p, st)
|
| 641 |
+
|
| 642 |
+
if unet.verbose_ref.value:
|
| 643 |
+
print(f"[MegaFreeU] v{st.version} "
|
| 644 |
+
f"start={st.start_ratio:.3f} stop={st.stop_ratio:.3f} "
|
| 645 |
+
f"smooth={st.transition_smoothness:.3f} "
|
| 646 |
+
f"ch_thresh=+-{st.channel_threshold}")
|
| 647 |
+
for i, si in enumerate(st.stage_infos):
|
| 648 |
+
ch = unet._stage_channels[i] if i < len(unet._stage_channels) else "?"
|
| 649 |
+
print(f" Stage {i+1} ({ch}ch): "
|
| 650 |
+
f"b={si.backbone_factor:.3f} [{si.b_start_ratio:.2f}-{si.b_end_ratio:.2f}] "
|
| 651 |
+
f"{si.backbone_blend_mode}:{si.backbone_blend:.2f} "
|
| 652 |
+
f"s={si.skip_factor:.3f} [{si.s_start_ratio:.2f}-{si.s_end_ratio:.2f}] "
|
| 653 |
+
f"fft={si.fft_type} r={si.fft_radius_ratio:.3f} "
|
| 654 |
+
f"hfe={si.skip_high_end_factor:.2f} hfb={si.hf_boost:.2f} "
|
| 655 |
+
f"cap={'ON' if si.enable_adaptive_cap else 'off'} "
|
| 656 |
+
f"({si.cap_threshold:.2f}/{si.cap_factor:.2f} {si.adaptive_cap_mode})")
|
| 657 |
+
|
| 658 |
+
def process_batch(self, p, *args, **kwargs):
|
| 659 |
+
global_state.current_sampling_step = 0
|
| 660 |
+
# FIX: reset PostCFG step counter for each image in batch
|
| 661 |
+
if hasattr(p, "_mega_pcfg"):
|
| 662 |
+
p._mega_pcfg["step"] = 0
|
| 663 |
+
|
| 664 |
+
def postprocess(self, p, processed, *args, **kwargs):
|
| 665 |
+
"""Clean up per-image state after generation."""
|
| 666 |
+
if hasattr(p, "_mega_pcfg"):
|
| 667 |
+
p._mega_pcfg = {"enabled": False}
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def _write_generation_params(p, st):
|
| 671 |
+
"""Write full Mega FreeU state into PNG extra_generation_params."""
|
| 672 |
+
p.extra_generation_params["MegaFreeU Schedule"] = (
|
| 673 |
+
f"{st.start_ratio}, {st.stop_ratio}, {st.transition_smoothness}")
|
| 674 |
+
p.extra_generation_params["MegaFreeU Stages"] = (
|
| 675 |
+
json.dumps([si.to_dict() for si in st.stage_infos]))
|
| 676 |
+
p.extra_generation_params["MegaFreeU Version"] = st.version
|
| 677 |
+
p.extra_generation_params["MegaFreeU Multiscale Mode"] = st.multiscale_mode
|
| 678 |
+
p.extra_generation_params["MegaFreeU Multiscale Strength"] = str(st.multiscale_strength)
|
| 679 |
+
p.extra_generation_params["MegaFreeU Override Scales"] = st.override_scales or ""
|
| 680 |
+
p.extra_generation_params["MegaFreeU Channel Threshold"] = str(st.channel_threshold)
|
| 681 |
+
p.extra_generation_params["MegaFreeU Verbose"] = str(st.verbose)
|
| 682 |
+
if st.pcfg_enabled:
|
| 683 |
+
p.extra_generation_params["MegaFreeU PostCFG"] = json.dumps({
|
| 684 |
+
"enabled": st.pcfg_enabled,
|
| 685 |
+
"steps": st.pcfg_steps,
|
| 686 |
+
"mode": st.pcfg_mode,
|
| 687 |
+
"blend": st.pcfg_blend,
|
| 688 |
+
"b": st.pcfg_b,
|
| 689 |
+
"fourier": st.pcfg_fourier,
|
| 690 |
+
"ms_mode": st.pcfg_ms_mode,
|
| 691 |
+
"ms_str": st.pcfg_ms_str,
|
| 692 |
+
"threshold": st.pcfg_threshold,
|
| 693 |
+
"s": st.pcfg_s,
|
| 694 |
+
"gain": st.pcfg_gain,
|
| 695 |
+
})
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def _flat_to_sis(flat) -> List[global_state.StageInfo]:
|
| 699 |
+
result = []
|
| 700 |
+
for i in range(global_state.STAGES_COUNT):
|
| 701 |
+
chunk = flat[i * _SN:(i + 1) * _SN]
|
| 702 |
+
si = global_state.StageInfo()
|
| 703 |
+
for j, fname in enumerate(_SF):
|
| 704 |
+
if j < len(chunk):
|
| 705 |
+
setattr(si, fname, chunk[j])
|
| 706 |
+
result.append(si)
|
| 707 |
+
return result
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# Callbacks
|
| 711 |
+
def _on_cfg_step(*_args, **_kwargs):
|
| 712 |
+
global_state.current_sampling_step += 1
|
| 713 |
+
|
| 714 |
+
def _on_cfg_post(params):
|
| 715 |
+
"""WAS_PostCFGShift ported to A1111 on_cfg_after_cfg callback (exact algorithm)."""
|
| 716 |
+
p = getattr(params, "p", None)
|
| 717 |
+
if p is None:
|
| 718 |
+
p = getattr(getattr(params, "denoiser", None), "p", None)
|
| 719 |
+
if p is None: return
|
| 720 |
+
cfg = getattr(p, "_mega_pcfg", None)
|
| 721 |
+
if not cfg or not cfg.get("enabled"): return
|
| 722 |
+
cfg["step"] = cfg.get("step", 0) + 1
|
| 723 |
+
if cfg["step"] > cfg["steps"]: return
|
| 724 |
+
x = params.x
|
| 725 |
+
fn = unet.BLENDING_MODES.get(cfg["mode"], unet.BLENDING_MODES["inject"])
|
| 726 |
+
y = fn(x, x * cfg["b"], cfg["blend"])
|
| 727 |
+
if cfg["fourier"]:
|
| 728 |
+
ms = global_state.MSCALES.get(cfg["ms_mode"])
|
| 729 |
+
y = unet.filter_skip_box_multiscale(
|
| 730 |
+
y, cfg["threshold"], cfg["s"], ms, cfg["ms_str"])
|
| 731 |
+
if cfg["gain"] != 1.0:
|
| 732 |
+
y = y * float(cfg["gain"])
|
| 733 |
+
params.x = y
|
| 734 |
+
|
| 735 |
+
try:
|
| 736 |
+
script_callbacks.on_cfg_after_cfg(_on_cfg_step)
|
| 737 |
+
script_callbacks.on_cfg_after_cfg(_on_cfg_post)
|
| 738 |
+
except AttributeError:
|
| 739 |
+
# webui < 1.6.0 (sd-webui-freeu compatibility note)
|
| 740 |
+
script_callbacks.on_cfg_denoised(_on_cfg_step)
|
| 741 |
+
script_callbacks.on_cfg_denoised(_on_cfg_post)
|
| 742 |
+
|
| 743 |
+
def _on_after_component(component, **kwargs):
|
| 744 |
+
eid = kwargs.get("elem_id", "")
|
| 745 |
+
for key, sid in [("txt2img", "txt2img_steps"), ("img2img", "img2img_steps")]:
|
| 746 |
+
if eid == sid:
|
| 747 |
+
_steps_comps[key] = component
|
| 748 |
+
for cb in _steps_cbs[key]: cb(component)
|
| 749 |
+
_steps_cbs[key].clear()
|
| 750 |
+
|
| 751 |
+
script_callbacks.on_after_component(_on_after_component)
|
| 752 |
+
|
| 753 |
+
def _on_ui_settings():
|
| 754 |
+
shared.opts.add_option(
|
| 755 |
+
"mega_freeu_png_auto_enable",
|
| 756 |
+
shared.OptionInfo(
|
| 757 |
+
default=True,
|
| 758 |
+
label="Auto-enable Mega FreeU when loading PNG info from a FreeU generation",
|
| 759 |
+
section=("mega_freeu", "Mega FreeU")))
|
| 760 |
+
|
| 761 |
+
script_callbacks.on_ui_settings(_on_ui_settings)
|
| 762 |
+
script_callbacks.on_before_ui(xyz_grid.patch)
|
| 763 |
+
|
| 764 |
+
# Install th.cat patch at import (sd-webui-freeu pattern)
|
| 765 |
+
unet.patch()
|
mega_freeu_a1111/tests/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mega FreeU β Test Suite
|
| 2 |
+
|
| 3 |
+
Tests run without A1111 or PyTorch installed. A NumPy-based torch mock is used instead.
|
| 4 |
+
|
| 5 |
+
## Requirements
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
pip install numpy
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
## Running
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
# From the extension root:
|
| 15 |
+
python tests/test_core.py # 144 tests β math, filters, blending, schedule, state
|
| 16 |
+
python tests/test_fixes.py # 36 tests β dict-API, PNG metadata, Post-CFG, XYZ
|
| 17 |
+
python tests/test_preset_pcfg.py # 32 tests β full preset save/apply round-trip
|
| 18 |
+
|
| 19 |
+
# Or run all at once:
|
| 20 |
+
for f in tests/test_core.py tests/test_fixes.py tests/test_preset_pcfg.py; do
|
| 21 |
+
python "$f" && echo "--- $f PASSED ---" || echo "--- $f FAILED ---"
|
| 22 |
+
done
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## What is covered
|
| 26 |
+
|
| 27 |
+
| File | Tests | Coverage |
|
| 28 |
+
|------|-------|----------|
|
| 29 |
+
| `test_core.py` | 144 | `ratio_to_region`, `lerp`, `get_backbone_scale` (V1+V2), `filter_skip_box`, `fourier_filter_gauss`, `get_band_energy_stats`, `filter_skip_box_multiscale`, all 9 blending modes, `parse_override_scales`, `_normalize`, `get_schedule_ratio`, `get_stage_bsratio`, `StageInfo`/`State` dataclasses, `update_attr`, `_load_user_presets`, `filter_skip_gaussian_adaptive` (no-cap/cap/fixed/aggressive), backbone blend math, `apply_xyz`, `detect_model_channels`, `_flat_to_sis`, PostCFG step counter |
|
| 30 |
+
| `test_fixes.py` | 36 | `State` pcfg/verbose fields + round-trip, `_load_user_presets` with pcfg, `_write_generation_params` PNG keys, Post-CFG independent of Enable, dict-API compat |
|
| 31 |
+
| `test_preset_pcfg.py` | 32 | Full preset save/apply with all 20 fields, pcfg disabled case, unknown preset, dict-API pcfg propagation |
|
mega_freeu_a1111/tests/mock_torch.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math, types, sys
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class Tensor:
|
| 5 |
+
def __init__(self, data, device='cpu', dtype=None):
|
| 6 |
+
if isinstance(data, np.ndarray):
|
| 7 |
+
self._d = data
|
| 8 |
+
elif isinstance(data, Tensor):
|
| 9 |
+
self._d = data._d.copy()
|
| 10 |
+
else:
|
| 11 |
+
self._d = np.array(data, dtype=np.float32)
|
| 12 |
+
self.device = device
|
| 13 |
+
self.dtype = dtype or 'float32'
|
| 14 |
+
|
| 15 |
+
@property
|
| 16 |
+
def shape(self): return self._d.shape
|
| 17 |
+
def size(self, dim=None): return self.shape[dim] if dim is not None else self.shape
|
| 18 |
+
@property
|
| 19 |
+
def is_cpu(self): return True
|
| 20 |
+
|
| 21 |
+
# dtype / device
|
| 22 |
+
def float(self):
|
| 23 |
+
return Tensor(self._d.astype(np.float32), self.device, 'float32')
|
| 24 |
+
def to(self, *a, **kw): return self
|
| 25 |
+
def cpu(self): return Tensor(self._d.copy(), 'cpu', self.dtype)
|
| 26 |
+
def item(self): return float(self._d.flat[0])
|
| 27 |
+
|
| 28 |
+
# shape ops
|
| 29 |
+
def view(self, *shape):
|
| 30 |
+
s = shape[0] if len(shape)==1 and isinstance(shape[0],(tuple,list)) else shape
|
| 31 |
+
return Tensor(self._d.reshape(s), self.device, self.dtype)
|
| 32 |
+
def reshape(self, *s): return self.view(*s)
|
| 33 |
+
def unsqueeze(self, dim):
|
| 34 |
+
return Tensor(np.expand_dims(self._d, dim), self.device, self.dtype)
|
| 35 |
+
def squeeze(self, dim=None):
|
| 36 |
+
d = self._d.squeeze() if dim is None else self._d.squeeze(dim)
|
| 37 |
+
return Tensor(d, self.device, self.dtype)
|
| 38 |
+
def expand_as(self, other):
|
| 39 |
+
return Tensor(np.broadcast_to(self._d, other.shape).copy(), self.device, self.dtype)
|
| 40 |
+
|
| 41 |
+
# .real / .imag β real tensors just return themselves
|
| 42 |
+
@property
|
| 43 |
+
def real(self): return Tensor(self._d.real.astype(np.float32), self.device, self.dtype)
|
| 44 |
+
@property
|
| 45 |
+
def imag(self): return Tensor(np.zeros_like(self._d.real, dtype=np.float32), self.device, self.dtype)
|
| 46 |
+
|
| 47 |
+
# reductions
|
| 48 |
+
def mean(self, dim=None, keepdim=False):
|
| 49 |
+
if dim is None: return Tensor(np.array(self._d.mean(), np.float32), self.device)
|
| 50 |
+
return Tensor(self._d.mean(axis=dim, keepdims=keepdim).astype(np.float32), self.device)
|
| 51 |
+
def sum(self, dim=None, keepdim=False):
|
| 52 |
+
if dim is None: return Tensor(np.array(self._d.sum(), np.float32), self.device)
|
| 53 |
+
return Tensor(self._d.sum(axis=dim, keepdims=keepdim).astype(np.float32), self.device)
|
| 54 |
+
def any(self): return bool(self._d.any())
|
| 55 |
+
def clamp(self, lo=None, hi=None):
|
| 56 |
+
return Tensor(np.clip(self._d, lo, hi).astype(np.float32), self.device, self.dtype)
|
| 57 |
+
def clamp_min(self, v):
|
| 58 |
+
return Tensor(np.maximum(self._d, v).astype(np.float32), self.device, self.dtype)
|
| 59 |
+
def abs(self): return Tensor(np.abs(self._d).astype(np.float32), self.device, self.dtype)
|
| 60 |
+
def min(self, dim=None, keepdim=False):
|
| 61 |
+
if dim is None: return Tensor(np.array(self._d.min(), np.float32), self.device)
|
| 62 |
+
v=self._d.min(axis=dim,keepdims=keepdim); i=self._d.argmin(axis=dim)
|
| 63 |
+
return Tensor(v.astype(np.float32),self.device), Tensor(i.astype(np.float32),self.device)
|
| 64 |
+
def max(self, dim=None, keepdim=False):
|
| 65 |
+
if dim is None: return Tensor(np.array(self._d.max(), np.float32), self.device)
|
| 66 |
+
v=self._d.max(axis=dim,keepdims=keepdim); i=self._d.argmax(axis=dim)
|
| 67 |
+
return Tensor(v.astype(np.float32),self.device), Tensor(i.astype(np.float32),self.device)
|
| 68 |
+
|
| 69 |
+
# indexing
|
| 70 |
+
def __getitem__(self, idx):
|
| 71 |
+
if isinstance(idx, Tensor):
|
| 72 |
+
idx = idx._d.astype(bool)
|
| 73 |
+
elif isinstance(idx, tuple):
|
| 74 |
+
idx = tuple(i._d.astype(bool) if isinstance(i, Tensor) else i for i in idx)
|
| 75 |
+
return Tensor(self._d[idx], self.device, self.dtype)
|
| 76 |
+
def __setitem__(self, idx, val):
|
| 77 |
+
self._d[idx] = val._d if isinstance(val, Tensor) else val
|
| 78 |
+
|
| 79 |
+
# arithmetic helpers
|
| 80 |
+
def _v(self, o): return o._d if isinstance(o, Tensor) else o
|
| 81 |
+
def __add__(self, o): return Tensor((self._d + self._v(o)).astype(np.float32), self.device)
|
| 82 |
+
def __radd__(self, o): return self.__add__(o)
|
| 83 |
+
def __sub__(self, o): return Tensor((self._d - self._v(o)).astype(np.float32), self.device)
|
| 84 |
+
def __rsub__(self, o): return Tensor((self._v(o) - self._d).astype(np.float32), self.device)
|
| 85 |
+
def __mul__(self, o): return Tensor((self._d * self._v(o)).astype(np.float32), self.device)
|
| 86 |
+
def __rmul__(self, o): return self.__mul__(o)
|
| 87 |
+
def __truediv__(self, o): return Tensor((self._d / self._v(o)).astype(np.float32), self.device)
|
| 88 |
+
def __rtruediv__(self, o): return Tensor((self._v(o) / self._d).astype(np.float32), self.device)
|
| 89 |
+
def __neg__(self): return Tensor(-self._d, self.device, self.dtype)
|
| 90 |
+
def __pow__(self, n): return Tensor(self._d**n, self.device, self.dtype)
|
| 91 |
+
def __imul__(self, o): self._d = (self._d * self._v(o)).astype(np.float32); return self
|
| 92 |
+
def __abs__(self): return self.abs()
|
| 93 |
+
|
| 94 |
+
# comparisons β float32 0/1 tensor (bool-compatible)
|
| 95 |
+
def __le__(self, o): return Tensor((self._d <= self._v(o)).astype(np.float32), self.device)
|
| 96 |
+
def __lt__(self, o): return Tensor((self._d < self._v(o)).astype(np.float32), self.device)
|
| 97 |
+
def __ge__(self, o): return Tensor((self._d >= self._v(o)).astype(np.float32), self.device)
|
| 98 |
+
def __gt__(self, o): return Tensor((self._d > self._v(o)).astype(np.float32), self.device)
|
| 99 |
+
def __eq__(self, o): return bool(np.array_equal(self._d, self._v(o))) if isinstance(o, (Tensor,np.ndarray)) else Tensor((self._d == o).astype(np.float32), self.device)
|
| 100 |
+
def __invert__(self): # ~bool_tensor
|
| 101 |
+
return Tensor((~self._d.astype(bool)).astype(np.float32), self.device)
|
| 102 |
+
def __bool__(self): return bool(self._d.flat[0])
|
| 103 |
+
def __repr__(self): return f"Tensor({self._d.shape})"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ComplexTensor β returned by FFT ops
|
| 107 |
+
class CTensor(Tensor):
|
| 108 |
+
def __init__(self, data, device='cpu', dtype=None):
|
| 109 |
+
super().__init__(data, device, dtype)
|
| 110 |
+
@property
|
| 111 |
+
def real(self):
|
| 112 |
+
return Tensor(self._d.real.astype(np.float32), self.device, self.dtype)
|
| 113 |
+
@property
|
| 114 |
+
def imag(self):
|
| 115 |
+
return Tensor(self._d.imag.astype(np.float32), self.device, self.dtype)
|
| 116 |
+
def __mul__(self, o):
|
| 117 |
+
v = o._d if isinstance(o, Tensor) else o
|
| 118 |
+
return CTensor((self._d * v), self.device)
|
| 119 |
+
def __imul__(self, o):
|
| 120 |
+
v = o._d if isinstance(o, Tensor) else o
|
| 121 |
+
self._d = self._d * v
|
| 122 |
+
return self
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ββ torch namespace ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
torch = types.ModuleType('torch')
|
| 127 |
+
torch.Tensor = Tensor
|
| 128 |
+
torch.float32 = 'float32'
|
| 129 |
+
torch.device = lambda s: s
|
| 130 |
+
|
| 131 |
+
torch.full = lambda shape, v, device='cpu', **kw: Tensor(np.full(shape, float(v), np.float32), device)
|
| 132 |
+
torch.ones = lambda *sh, device='cpu', **kw: Tensor(np.ones(sh[0] if len(sh)==1 else sh, np.float32), device)
|
| 133 |
+
torch.zeros = lambda *sh, device='cpu', **kw: Tensor(np.zeros(sh[0] if len(sh)==1 else sh, np.float32), device)
|
| 134 |
+
torch.tensor = lambda v, device='cpu', dtype=None, **kw: Tensor(np.array(v, dtype=np.float32), device)
|
| 135 |
+
torch.arange = lambda n, device='cpu', dtype=None, **kw: Tensor(np.arange(n, dtype=np.float32), device)
|
| 136 |
+
|
| 137 |
+
def _meshgrid(y, x, indexing='ij'):
|
| 138 |
+
yy, xx = np.meshgrid(y._d, x._d, indexing=indexing)
|
| 139 |
+
return Tensor(yy.astype(np.float32), y.device), Tensor(xx.astype(np.float32), x.device)
|
| 140 |
+
torch.meshgrid = _meshgrid
|
| 141 |
+
|
| 142 |
+
torch.exp = lambda x: Tensor(np.exp(np.clip(x._d,-500,500)).astype(np.float32), x.device)
|
| 143 |
+
torch.sin = lambda x: Tensor(np.sin(x._d).astype(np.float32), x.device)
|
| 144 |
+
torch.cos = lambda x: Tensor(np.cos(x._d).astype(np.float32), x.device)
|
| 145 |
+
torch.acos = lambda x: Tensor(np.arccos(np.clip(x._d,-1+1e-7,1-1e-7)).astype(np.float32), x.device)
|
| 146 |
+
torch.abs = lambda x: Tensor(np.abs(x._d).astype(np.float32), x.device)
|
| 147 |
+
torch.sqrt = lambda x: Tensor(np.sqrt(np.maximum(x._d,0)).astype(np.float32), x.device)
|
| 148 |
+
torch.norm = lambda x, dim=None, keepdim=False, **kw: Tensor(
|
| 149 |
+
np.linalg.norm(x._d, axis=dim, keepdims=keepdim).astype(np.float32), x.device)
|
| 150 |
+
|
| 151 |
+
def _max(x, dim=None, keepdim=False):
|
| 152 |
+
if dim is None: return float(x._d.max())
|
| 153 |
+
v = x._d.max(axis=dim, keepdims=keepdim)
|
| 154 |
+
i = x._d.argmax(axis=dim)
|
| 155 |
+
return Tensor(v.astype(np.float32), x.device), Tensor(i.astype(np.float32), x.device)
|
| 156 |
+
def _min(x, dim=None, keepdim=False):
|
| 157 |
+
if dim is None: return float(x._d.min())
|
| 158 |
+
v = x._d.min(axis=dim, keepdims=keepdim)
|
| 159 |
+
i = x._d.argmin(axis=dim)
|
| 160 |
+
return Tensor(v.astype(np.float32), x.device), Tensor(i.astype(np.float32), x.device)
|
| 161 |
+
torch.max = _max
|
| 162 |
+
torch.min = _min
|
| 163 |
+
|
| 164 |
+
def _where(c, a, b):
|
| 165 |
+
cd = c._d.astype(bool) if isinstance(c, Tensor) else np.array(c, bool)
|
| 166 |
+
ad = a._d if isinstance(a, Tensor) else a
|
| 167 |
+
bd = b._d if isinstance(b, Tensor) else b
|
| 168 |
+
return Tensor(np.where(cd, ad, bd).astype(np.float32),
|
| 169 |
+
(a if isinstance(a,Tensor) else b).device)
|
| 170 |
+
torch.where = _where
|
| 171 |
+
|
| 172 |
+
class _Linalg:
|
| 173 |
+
@staticmethod
|
| 174 |
+
def norm(x, dim=None, keepdim=False, **kw):
|
| 175 |
+
return Tensor(
|
| 176 |
+
np.linalg.norm(x._d, axis=dim, keepdims=keepdim).astype(np.float32),
|
| 177 |
+
x.device)
|
| 178 |
+
torch.linalg = _Linalg
|
| 179 |
+
|
| 180 |
+
class _FFT:
|
| 181 |
+
@staticmethod
|
| 182 |
+
def fftn(x, dim=None):
|
| 183 |
+
ax = tuple(dim) if dim is not None else None
|
| 184 |
+
return CTensor(np.fft.fftn(x._d.astype(complex), axes=ax), x.device)
|
| 185 |
+
@staticmethod
|
| 186 |
+
def ifftn(x, dim=None):
|
| 187 |
+
ax = tuple(dim) if dim is not None else None
|
| 188 |
+
r = np.fft.ifftn(x._d, axes=ax).real.astype(np.float32)
|
| 189 |
+
# Return a plain Tensor but with .real property (already float32)
|
| 190 |
+
return Tensor(r, x.device)
|
| 191 |
+
@staticmethod
|
| 192 |
+
def fftshift(x, dim=None):
|
| 193 |
+
ax = tuple(dim) if dim is not None else None
|
| 194 |
+
d = np.fft.fftshift(x._d, axes=ax)
|
| 195 |
+
return CTensor(d, x.device) if isinstance(x, CTensor) else Tensor(d, x.device)
|
| 196 |
+
@staticmethod
|
| 197 |
+
def ifftshift(x, dim=None):
|
| 198 |
+
ax = tuple(dim) if dim is not None else None
|
| 199 |
+
d = np.fft.ifftshift(x._d, axes=ax)
|
| 200 |
+
return CTensor(d, x.device) if isinstance(x, CTensor) else Tensor(d, x.device)
|
| 201 |
+
torch.fft = _FFT
|
| 202 |
+
|
| 203 |
+
class _Backends:
|
| 204 |
+
class mps:
|
| 205 |
+
@staticmethod
|
| 206 |
+
def is_available(): return False
|
| 207 |
+
torch.backends = _Backends
|
| 208 |
+
|
| 209 |
+
sys.modules['torch'] = torch
|
| 210 |
+
|
| 211 |
+
# ββ modules mock βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
+
class _State:
|
| 213 |
+
sampling_steps = 20
|
| 214 |
+
class _SharedObj:
|
| 215 |
+
state = _State()
|
| 216 |
+
class opts:
|
| 217 |
+
data = {}
|
| 218 |
+
|
| 219 |
+
mmod = types.ModuleType('modules')
|
| 220 |
+
mmod.shared = _SharedObj()
|
| 221 |
+
sys.modules['modules'] = mmod
|
| 222 |
+
for sub in ['modules.shared','modules.scripts','modules.processing',
|
| 223 |
+
'modules.script_callbacks','gradio']:
|
| 224 |
+
sys.modules[sub] = types.ModuleType(sub)
|
| 225 |
+
|
| 226 |
+
import logging
|
| 227 |
+
sys.modules['logging'] = logging
|
| 228 |
+
# expose modules.shared in the right place
|
| 229 |
+
import types as _t
|
mega_freeu_a1111/tests/test_core.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib, sys
|
| 2 |
+
exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
|
| 3 |
+
|
| 4 |
+
import sys, math, types, dataclasses, json, tempfile, os, pathlib
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pathlib; sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
|
| 7 |
+
|
| 8 |
+
# ββ load ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 9 |
+
import importlib.util
|
| 10 |
+
|
| 11 |
+
def load_mod(name, path):
|
| 12 |
+
spec = importlib.util.spec_from_file_location(name, path)
|
| 13 |
+
m = importlib.util.module_from_spec(spec)
|
| 14 |
+
sys.modules[name] = m
|
| 15 |
+
spec.loader.exec_module(m)
|
| 16 |
+
return m
|
| 17 |
+
|
| 18 |
+
lib_pkg = types.ModuleType('lib_mega_freeu')
|
| 19 |
+
sys.modules['lib_mega_freeu'] = lib_pkg
|
| 20 |
+
|
| 21 |
+
GS = load_mod('lib_mega_freeu.global_state',
|
| 22 |
+
str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
|
| 23 |
+
UN = load_mod('lib_mega_freeu.unet',
|
| 24 |
+
str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
|
| 25 |
+
print("Loaded OK\n")
|
| 26 |
+
|
| 27 |
+
# ββ test helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
P=0; F=0; ERRS=[]
|
| 29 |
+
|
| 30 |
+
def ok(t): global P; P+=1; print(f" β {t}")
|
| 31 |
+
def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f" β {t} {m}")
|
| 32 |
+
def chk(t, cond, m=""): ok(t) if cond else ng(t, m)
|
| 33 |
+
def near(t, got, want, tol=1e-4):
|
| 34 |
+
g = float(got) if hasattr(got,'item') else float(got)
|
| 35 |
+
(ok(t+f" ({g:.5g})") if abs(g-want)<=tol else ng(t, f"got {g:.5g} want {want}"))
|
| 36 |
+
def shp(t, tensor, expected):
|
| 37 |
+
chk(t+f" shape={tensor.shape}", tuple(tensor.shape)==tuple(expected), f"β {expected}")
|
| 38 |
+
|
| 39 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
print("β"*54)
|
| 41 |
+
print("1. RATIO_TO_REGION")
|
| 42 |
+
print("β"*54)
|
| 43 |
+
R = UN.ratio_to_region
|
| 44 |
+
s,e,inv = R(0.5, 0.0, 100)
|
| 45 |
+
chk("[0,50] no-inv", s==0 and e==50 and not inv, f"s={s},e={e},inv={inv}")
|
| 46 |
+
s,e,inv = R(0.5, 0.5, 100)
|
| 47 |
+
chk("[50,100] no-inv", s==50 and e==100 and not inv)
|
| 48 |
+
s,e,inv = R(0.7, 0.5, 100)
|
| 49 |
+
chk("width+offset>1 β inverted", inv and s==20 and e==50, f"s={s},e={e},inv={inv}")
|
| 50 |
+
s,e,inv = R(-0.3, 0.5, 100) # negative width
|
| 51 |
+
chk("neg width no crash", True)
|
| 52 |
+
s,e,inv = R(0.0, 0.0, 100)
|
| 53 |
+
chk("zero width: s==e", s==e)
|
| 54 |
+
s,e,inv = R(1.0, 0.0, 100) # full width β [0,100]
|
| 55 |
+
chk("full width [0,100]", s==0 and e==100 and not inv)
|
| 56 |
+
|
| 57 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
print("\n2. LERP")
|
| 59 |
+
near("lerp(0,1,.5)", UN.lerp(0,1,.5), .5)
|
| 60 |
+
near("lerp t=0 β a", UN.lerp(5,10,0), 5)
|
| 61 |
+
near("lerp t=1 β b", UN.lerp(5,10,1), 10)
|
| 62 |
+
near("lerp(2,4,.25)", UN.lerp(2,4,.25), 2.5)
|
| 63 |
+
a=Tensor(np.array([[[[2.,4.]]]])); b=Tensor(np.array([[[[6.,8.]]]]))
|
| 64 |
+
r=UN.lerp(a,b,0.5)
|
| 65 |
+
chk("tensor lerp values", np.allclose(r._d, [[[[4.,6.]]]]))
|
| 66 |
+
|
| 67 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
print("\n3. GET_BACKBONE_SCALE")
|
| 69 |
+
h = Tensor(np.random.randn(2,8,4,4).astype(np.float32))
|
| 70 |
+
# V1
|
| 71 |
+
v1 = UN.get_backbone_scale(h, 1.3, "1")
|
| 72 |
+
near("V1=1.3", v1 if isinstance(v1,(int,float)) else v1.item(), 1.3)
|
| 73 |
+
# V2
|
| 74 |
+
v2 = UN.get_backbone_scale(h, 1.4, "2")
|
| 75 |
+
shp("V2 shape=(2,1,4,4)", v2, (2,1,4,4))
|
| 76 |
+
chk("V2 vals in (0,2.5)", v2._d.min()>0 and v2._d.max()<2.5,
|
| 77 |
+
f"[{v2._d.min():.3f},{v2._d.max():.3f}]")
|
| 78 |
+
# V2 factor=1.0 β all 1s
|
| 79 |
+
v2id = UN.get_backbone_scale(h, 1.0, "2")
|
| 80 |
+
chk("V2 factor=1βall-ones", np.allclose(v2id._d, 1.0, atol=1e-5))
|
| 81 |
+
# V2 factor=2.0 β range [1,2]
|
| 82 |
+
v2f2 = UN.get_backbone_scale(h, 2.0, "2")
|
| 83 |
+
chk("V2 factor=2 β [1,2]",
|
| 84 |
+
v2f2._d.min()>=1.0-1e-5 and v2f2._d.max()<=2.0+1e-5,
|
| 85 |
+
f"[{v2f2._d.min():.4f},{v2f2._d.max():.4f}]")
|
| 86 |
+
# zero input β no nan/crash
|
| 87 |
+
hz = Tensor(np.zeros((1,8,4,4),np.float32))
|
| 88 |
+
v2z = UN.get_backbone_scale(hz, 1.5, "2")
|
| 89 |
+
chk("V2 zero input: no nan", not np.isnan(v2z._d).any())
|
| 90 |
+
|
| 91 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
+
print("\n4. FILTER_SKIP_BOX")
|
| 93 |
+
x = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
|
| 94 |
+
# identity
|
| 95 |
+
out = UN.filter_skip_box(x, 0.5, 1.0, 1.0)
|
| 96 |
+
chk("box identity", np.allclose(out._d, x._d, atol=1e-4))
|
| 97 |
+
# scale=0
|
| 98 |
+
out0 = UN.filter_skip_box(x, 0.5, 0.0, 1.0)
|
| 99 |
+
shp("box s=0 shape", out0, (1,4,16,16))
|
| 100 |
+
chk("box s=0 β input", not np.allclose(out0._d, x._d, atol=1e-3))
|
| 101 |
+
# scale_high=0
|
| 102 |
+
outh = UN.filter_skip_box(x, 0.5, 1.0, 0.0)
|
| 103 |
+
chk("box h=0 β input", not np.allclose(outh._d, x._d, atol=1e-3))
|
| 104 |
+
# cutoff=0 edge
|
| 105 |
+
outc0 = UN.filter_skip_box(x, 0.0, 0.5, 1.0)
|
| 106 |
+
shp("box cutoff=0 shape", outc0, (1,4,16,16))
|
| 107 |
+
# int cutoff
|
| 108 |
+
outi = UN.filter_skip_box(x, 2, 0.7, 1.0)
|
| 109 |
+
shp("box int cutoff", outi, (1,4,16,16))
|
| 110 |
+
# batch > 1
|
| 111 |
+
xb = Tensor(np.random.randn(3,8,16,16).astype(np.float32))
|
| 112 |
+
outb = UN.filter_skip_box(xb, 0.3, 0.7, 1.0)
|
| 113 |
+
shp("box batch=3", outb, (3,8,16,16))
|
| 114 |
+
# no nan
|
| 115 |
+
chk("box no nan", not np.isnan(out0._d).any())
|
| 116 |
+
|
| 117 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 118 |
+
print("\n5. FOURIER_FILTER_GAUSS")
|
| 119 |
+
xg = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
|
| 120 |
+
outg = UN.fourier_filter_gauss(xg, 0.1, 0.8, 1.2)
|
| 121 |
+
shp("gauss shape", outg, (1,4,16,16))
|
| 122 |
+
chk("gauss s=.8 β input", not np.allclose(outg._d, xg._d, atol=1e-3))
|
| 123 |
+
# identity
|
| 124 |
+
outgi = UN.fourier_filter_gauss(xg, 0.1, 1.0, 1.0)
|
| 125 |
+
chk("gauss identity β input", np.allclose(outgi._d, xg._d, atol=1e-4),
|
| 126 |
+
f"maxdiff={abs(outgi._d-xg._d).max():.2e}")
|
| 127 |
+
# no nan
|
| 128 |
+
chk("gauss no nan", not np.isnan(outg._d).any())
|
| 129 |
+
# batch=3
|
| 130 |
+
xg3 = Tensor(np.random.randn(3,8,16,16).astype(np.float32))
|
| 131 |
+
outg3 = UN.fourier_filter_gauss(xg3, 0.08, 0.9, 1.0)
|
| 132 |
+
shp("gauss batch=3", outg3, (3,8,16,16))
|
| 133 |
+
# tiny radius (R=1)
|
| 134 |
+
outgt = UN.fourier_filter_gauss(xg, 0.01, 0.5, 1.0)
|
| 135 |
+
shp("gauss tiny R=1 shape", outgt, (1,4,16,16))
|
| 136 |
+
# large radius (R=max)
|
| 137 |
+
outgl = UN.fourier_filter_gauss(xg, 0.49, 0.5, 1.0)
|
| 138 |
+
shp("gauss large radius shape", outgl, (1,4,16,16))
|
| 139 |
+
|
| 140 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 141 |
+
print("\n6. GET_BAND_ENERGY_STATS")
|
| 142 |
+
xe = Tensor(np.random.randn(2,4,16,16).astype(np.float32))
|
| 143 |
+
lf,hf,cov = UN.get_band_energy_stats(xe, 3)
|
| 144 |
+
chk("lf>0", lf>0)
|
| 145 |
+
chk("hf>0", hf>0)
|
| 146 |
+
chk("0<cov<100", 0<cov<100, f"cov={cov:.1f}")
|
| 147 |
+
# zeros β energies=0
|
| 148 |
+
xze = Tensor(np.zeros((1,4,8,8),np.float32))
|
| 149 |
+
lf0,hf0,_ = UN.get_band_energy_stats(xze, 2)
|
| 150 |
+
chk("zeros β lf=hf=0", lf0==0 and hf0==0)
|
| 151 |
+
# R=1 minimal β some coverage
|
| 152 |
+
lf1,hf1,cov1 = UN.get_band_energy_stats(xe, 1)
|
| 153 |
+
chk("R=1 cover>0", cov1>0)
|
| 154 |
+
|
| 155 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 156 |
+
print("\n7. FILTER_SKIP_BOX_MULTISCALE")
|
| 157 |
+
xm = Tensor(np.random.randn(1,4,32,32).astype(np.float32))
|
| 158 |
+
# identity
|
| 159 |
+
outmi = UN.filter_skip_box_multiscale(xm, 0.3, 1.0, None, 1.0, 1.0)
|
| 160 |
+
chk("ms identity", np.allclose(outmi._d, xm._d, atol=1e-4))
|
| 161 |
+
# single-scale preset
|
| 162 |
+
outss = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, [(8,1.5)], 1.0)
|
| 163 |
+
shp("ms single-scale shape", outss, (1,4,32,32))
|
| 164 |
+
chk("ms single differs", not np.allclose(outss._d, xm._d, atol=1e-3))
|
| 165 |
+
# multi-scale preset
|
| 166 |
+
outms = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, [[(5,0.0),(15,1.0)]], 1.0)
|
| 167 |
+
shp("ms multi-scale shape", outms, (1,4,32,32))
|
| 168 |
+
# scale_high
|
| 169 |
+
outsh = UN.filter_skip_box_multiscale(xm, 0.3, 0.7, None, 1.0, 1.5)
|
| 170 |
+
chk("ms scale_high=1.5 differs", not np.allclose(outsh._d, xm._d, atol=1e-3))
|
| 171 |
+
# no nan
|
| 172 |
+
chk("ms no nan", not np.isnan(outss._d).any())
|
| 173 |
+
|
| 174 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
print("\n8. BLENDING MODES")
|
| 176 |
+
a = Tensor(np.full((1,4,4,4), 2.0, np.float32))
|
| 177 |
+
b = Tensor(np.full((1,4,4,4), 4.0, np.float32))
|
| 178 |
+
for mname, fn in UN.BLENDING_MODES.items():
|
| 179 |
+
try:
|
| 180 |
+
out = fn(a, b, 0.5)
|
| 181 |
+
shp(f" {mname}", out, (1,4,4,4))
|
| 182 |
+
chk(f" {mname} no nan", not np.isnan(out._d).any())
|
| 183 |
+
except Exception as e:
|
| 184 |
+
ng(f" {mname} CRASHED", str(e))
|
| 185 |
+
|
| 186 |
+
lerp_fn = UN.BLENDING_MODES['lerp']
|
| 187 |
+
near("lerp t=0βa", lerp_fn(a,b,0)._d.mean(), 2.0)
|
| 188 |
+
near("lerp t=1βb", lerp_fn(a,b,1)._d.mean(), 4.0)
|
| 189 |
+
near("lerp t=.5β3", lerp_fn(a,b,.5)._d.mean(), 3.0)
|
| 190 |
+
|
| 191 |
+
inj = UN.BLENDING_MODES['inject']
|
| 192 |
+
near("inject a+b*.5=4", inj(a,b,.5)._d.mean(), 4.0)
|
| 193 |
+
near("inject t=0βa", inj(a,b,0)._d.mean(), 2.0)
|
| 194 |
+
|
| 195 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 196 |
+
print("\n9. PARSE_OVERRIDE_SCALES")
|
| 197 |
+
P2 = UN.parse_override_scales
|
| 198 |
+
chk("NoneβNone", P2(None) is None)
|
| 199 |
+
chk("emptyβNone", P2("") is None)
|
| 200 |
+
chk("comment-onlyβNone", P2("# x\n! y\n// z") is None)
|
| 201 |
+
r = P2("10, 1.5\n20, 0.8")
|
| 202 |
+
chk("2 lines β 2 entries", r is not None and len(r)==2)
|
| 203 |
+
chk("entry[0]=(10,1.5)", r and r[0]==(10,1.5), str(r))
|
| 204 |
+
r2 = P2("# comment\n5, 2.0\n! skip\n15, 0.5")
|
| 205 |
+
chk("with comments β 2 entries", r2 and len(r2)==2)
|
| 206 |
+
r3 = P2("10,1.5\nbad_line\n20,0.8") # malformed line skipped
|
| 207 |
+
chk("malformed skipped", r3 and len(r3)==2, str(r3))
|
| 208 |
+
|
| 209 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 210 |
+
print("\n10. _NORMALIZE")
|
| 211 |
+
n = UN._normalize
|
| 212 |
+
chk("norm [0,3]: min=0 max=1",
|
| 213 |
+
abs(n(Tensor(np.array([[[[0.,1.,2.,3.]]]])))._d.min())<1e-5 and
|
| 214 |
+
abs(n(Tensor(np.array([[[[0.,1.,2.,3.]]]])))._d.max()-1.0)<1e-5)
|
| 215 |
+
chk("norm const: no crash/nan",
|
| 216 |
+
not np.isnan(n(Tensor(np.ones((1,1,4,4),np.float32)*5))._d).any())
|
| 217 |
+
chk("norm negatives: min=0 max=1",
|
| 218 |
+
abs(n(Tensor(np.array([[[[-2.,-1.,0.,1.,2.]]]])))._d.max()-1.0)<1e-5)
|
| 219 |
+
|
| 220 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 221 |
+
print("\n11. GET_SCHEDULE_RATIO")
|
| 222 |
+
mmod.shared.state.sampling_steps = 20
|
| 223 |
+
GS.instance = GS.State(start_ratio=0.0, stop_ratio=1.0, transition_smoothness=0.0)
|
| 224 |
+
GS.current_sampling_step = 0
|
| 225 |
+
near("full range step=0 β 1", UN.get_schedule_ratio(), 1.0)
|
| 226 |
+
GS.current_sampling_step = 10
|
| 227 |
+
near("full range step=10 β 1", UN.get_schedule_ratio(), 1.0)
|
| 228 |
+
GS.current_sampling_step = 25
|
| 229 |
+
near("past stop β 0", UN.get_schedule_ratio(), 0.0)
|
| 230 |
+
GS.instance = GS.State(start_ratio=0.5, stop_ratio=1.0, transition_smoothness=0.0)
|
| 231 |
+
GS.current_sampling_step = 5
|
| 232 |
+
near("before start β 0", UN.get_schedule_ratio(), 0.0)
|
| 233 |
+
GS.current_sampling_step = 15
|
| 234 |
+
near("in range β 1", UN.get_schedule_ratio(), 1.0)
|
| 235 |
+
GS.instance = GS.State(start_ratio=0.5, stop_ratio=0.5, transition_smoothness=0.0)
|
| 236 |
+
near("start==stop β 0", UN.get_schedule_ratio(), 0.0)
|
| 237 |
+
GS.instance = GS.State(start_ratio=0.0, stop_ratio=1.0, transition_smoothness=1.0)
|
| 238 |
+
GS.current_sampling_step = 0
|
| 239 |
+
r = UN.get_schedule_ratio()
|
| 240 |
+
chk("smoothness=1 step=0: 0β€rβ€1", 0.0<=r<=1.0)
|
| 241 |
+
|
| 242 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 243 |
+
print("\n12. GET_STAGE_BSRATIO")
|
| 244 |
+
bsr = UN.get_stage_bsratio
|
| 245 |
+
mmod.shared.state.sampling_steps = 20
|
| 246 |
+
GS.current_sampling_step = 5 # pctβ0.26
|
| 247 |
+
near("bsr [0,1]β1", bsr(0.0,1.0), 1.0)
|
| 248 |
+
near("bsr [0.5,1]β0 (pct<0.5)", bsr(0.5,1.0), 0.0)
|
| 249 |
+
near("bsr [0,0.5]β1 (pct<0.5)", bsr(0.0,0.5), 1.0)
|
| 250 |
+
GS.current_sampling_step = 18 # pctβ0.95
|
| 251 |
+
near("bsr [0,0.5]β0 (pct>0.5)", bsr(0.0,0.5), 0.0)
|
| 252 |
+
near("bsr [0.5,1]β1 (pct>0.5)", bsr(0.5,1.0), 1.0)
|
| 253 |
+
mmod.shared.state.sampling_steps = 1
|
| 254 |
+
GS.current_sampling_step = 0
|
| 255 |
+
near("steps=1 pct=0 in [0,1]β1", bsr(0.0,1.0), 1.0)
|
| 256 |
+
mmod.shared.state.sampling_steps = 20
|
| 257 |
+
|
| 258 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 259 |
+
print("\n13. STAGEINFO + STATE")
|
| 260 |
+
si = GS.StageInfo()
|
| 261 |
+
chk("default bf=1", si.backbone_factor==1.0)
|
| 262 |
+
chk("default fft=box", si.fft_type=="box")
|
| 263 |
+
chk("default cap=off", not si.enable_adaptive_cap)
|
| 264 |
+
chk("19 STAGE_FIELD_NAMES", len(GS.STAGE_FIELD_NAMES)==19)
|
| 265 |
+
|
| 266 |
+
st = GS.State(version="Version 2")
|
| 267 |
+
chk("version coerced '2'", st.version=="2")
|
| 268 |
+
st2 = GS.State(version="1")
|
| 269 |
+
chk("version '1' stays '1'", st2.version=="1")
|
| 270 |
+
|
| 271 |
+
# unknown dict key in stage_infos
|
| 272 |
+
st3 = GS.State(stage_infos=[{'backbone_factor':1.5,'UNKNOWN':999}])
|
| 273 |
+
chk("dict unknown key ignored", st3.stage_infos[0].backbone_factor==1.5)
|
| 274 |
+
|
| 275 |
+
# padding to STAGES_COUNT
|
| 276 |
+
st4 = GS.State(stage_infos=[GS.StageInfo(backbone_factor=2.0)])
|
| 277 |
+
chk("pads to 3 stages", len(st4.stage_infos)==3)
|
| 278 |
+
chk("pad default bf=1", st4.stage_infos[1].backbone_factor==1.0)
|
| 279 |
+
|
| 280 |
+
# round-trip
|
| 281 |
+
st5 = GS.State(version="2", stage_infos=[GS.StageInfo(backbone_factor=1.7)])
|
| 282 |
+
d = st5.to_dict()
|
| 283 |
+
chk("to_dict no 'enable'", 'enable' not in d)
|
| 284 |
+
chk("to_dict has stages", 'stage_infos' in d)
|
| 285 |
+
fields = {f.name for f in dataclasses.fields(GS.State)}
|
| 286 |
+
st6 = GS.State(**{k:v for k,v in d.items() if k in fields})
|
| 287 |
+
chk("round-trip version", st6.version=="2")
|
| 288 |
+
|
| 289 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
+
print("\n14. UPDATE_ATTR (XYZ shorthands)")
|
| 291 |
+
st = GS.State()
|
| 292 |
+
st.update_attr("b0", 1.5); chk("b0βbf stage0", st.stage_infos[0].backbone_factor==1.5)
|
| 293 |
+
st.update_attr("s1", 0.3); chk("s1βsf stage1", st.stage_infos[1].skip_factor==0.3)
|
| 294 |
+
st.update_attr("ft2","gaussian"); chk("ft2βfft_type stage2", st.stage_infos[2].fft_type=="gaussian")
|
| 295 |
+
st.update_attr("acm0","fixed"); chk("acm0βcap_mode stage0", st.stage_infos[0].adaptive_cap_mode=="fixed")
|
| 296 |
+
st.update_attr("cap1", True); chk("cap1βenable_adaptive_cap", st.stage_infos[1].enable_adaptive_cap==True)
|
| 297 |
+
st.update_attr("ct0", 0.4); chk("ct0βcap_threshold", st.stage_infos[0].cap_threshold==0.4)
|
| 298 |
+
st.update_attr("start_ratio", 0.2); chk("start_ratio direct", st.start_ratio==0.2)
|
| 299 |
+
st.update_attr("enable", True); chk("enable direct", st.enable==True)
|
| 300 |
+
# unknown key β no crash
|
| 301 |
+
try:
|
| 302 |
+
st.update_attr("UNKNOWN_KEY", 99)
|
| 303 |
+
chk("unknown key: no crash", True)
|
| 304 |
+
except Exception as e:
|
| 305 |
+
chk("unknown key: no crash", False, str(e))
|
| 306 |
+
|
| 307 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 308 |
+
print("\n15. _LOAD_USER_PRESETS robustness")
|
| 309 |
+
# Good + bad preset in same file
|
| 310 |
+
pdata = {
|
| 311 |
+
"good": {"start_ratio":0.0,"stop_ratio":1.0,"transition_smoothness":0.0,
|
| 312 |
+
"version":"2","multiscale_mode":"Default","multiscale_strength":1.0,
|
| 313 |
+
"override_scales":"","channel_threshold":96,
|
| 314 |
+
"stage_infos":[{"backbone_factor":1.3}]},
|
| 315 |
+
"with_unknown": {"start_ratio":0.0,"stop_ratio":1.0,"transition_smoothness":0.0,
|
| 316 |
+
"version":"1","multiscale_mode":"Default","multiscale_strength":1.0,
|
| 317 |
+
"override_scales":"","channel_threshold":96,
|
| 318 |
+
"FUTURE_FIELD":"ignored","stage_infos":[]}
|
| 319 |
+
}
|
| 320 |
+
with tempfile.NamedTemporaryFile(mode='w',suffix='.json',delete=False) as f:
|
| 321 |
+
json.dump(pdata, f); tmp = f.name
|
| 322 |
+
GS.PRESETS_PATH = pathlib.Path(tmp)
|
| 323 |
+
res = GS._load_user_presets()
|
| 324 |
+
chk("good preset loaded", "good" in res)
|
| 325 |
+
chk("good preset bf=1.3", res.get("good") and res["good"].stage_infos[0].backbone_factor==1.3)
|
| 326 |
+
chk("no crash on unknown field preset", True)
|
| 327 |
+
os.unlink(tmp)
|
| 328 |
+
|
| 329 |
+
# Invalid JSON β {}
|
| 330 |
+
with tempfile.NamedTemporaryFile(mode='w',suffix='.json',delete=False) as f:
|
| 331 |
+
f.write("{bad json!!!"); tmp2 = f.name
|
| 332 |
+
GS.PRESETS_PATH = pathlib.Path(tmp2)
|
| 333 |
+
chk("invalid JSON β {}", GS._load_user_presets() == {})
|
| 334 |
+
os.unlink(tmp2)
|
| 335 |
+
|
| 336 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 337 |
+
print("\n16. FILTER_SKIP_GAUSSIAN_ADAPTIVE")
|
| 338 |
+
hs = Tensor(np.random.randn(1,4,16,16).astype(np.float32))
|
| 339 |
+
|
| 340 |
+
# no cap
|
| 341 |
+
si_nc = GS.StageInfo(skip_factor=0.8, fft_type='gaussian',
|
| 342 |
+
fft_radius_ratio=0.1, hf_boost=1.2,
|
| 343 |
+
skip_high_end_factor=1.1, enable_adaptive_cap=False)
|
| 344 |
+
out_nc = UN.filter_skip_gaussian_adaptive(hs, si_nc)
|
| 345 |
+
shp("no-cap shape", out_nc, (1,4,16,16))
|
| 346 |
+
chk("no-cap differs from input", not np.allclose(out_nc._d, hs._d, atol=1e-3))
|
| 347 |
+
chk("no-cap no nan", not np.isnan(out_nc._d).any())
|
| 348 |
+
|
| 349 |
+
# identity (scale=1, hfb=1)
|
| 350 |
+
si_id = GS.StageInfo(skip_factor=1.0, fft_type='gaussian',
|
| 351 |
+
fft_radius_ratio=0.1, hf_boost=1.0,
|
| 352 |
+
skip_high_end_factor=1.0, enable_adaptive_cap=False)
|
| 353 |
+
out_id = UN.filter_skip_gaussian_adaptive(hs, si_id)
|
| 354 |
+
chk("gauss identity β input", np.allclose(out_id._d, hs._d, atol=1e-4),
|
| 355 |
+
f"maxdiff={abs(out_id._d-hs._d).max():.2e}")
|
| 356 |
+
|
| 357 |
+
# with cap
|
| 358 |
+
si_cap = GS.StageInfo(skip_factor=0.3, fft_type='gaussian',
|
| 359 |
+
fft_radius_ratio=0.15, hf_boost=1.0,
|
| 360 |
+
skip_high_end_factor=1.0, enable_adaptive_cap=True,
|
| 361 |
+
cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='adaptive')
|
| 362 |
+
out_cap = UN.filter_skip_gaussian_adaptive(hs, si_cap)
|
| 363 |
+
shp("cap shape", out_cap, (1,4,16,16))
|
| 364 |
+
chk("cap no nan", not np.isnan(out_cap._d).any())
|
| 365 |
+
|
| 366 |
+
# fixed cap mode
|
| 367 |
+
si_fixed = GS.StageInfo(skip_factor=0.3, fft_type='gaussian',
|
| 368 |
+
fft_radius_ratio=0.15, hf_boost=1.0,
|
| 369 |
+
skip_high_end_factor=1.0, enable_adaptive_cap=True,
|
| 370 |
+
cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='fixed')
|
| 371 |
+
out_fixed = UN.filter_skip_gaussian_adaptive(hs, si_fixed)
|
| 372 |
+
shp("fixed-cap shape", out_fixed, (1,4,16,16))
|
| 373 |
+
chk("fixed-cap no nan", not np.isnan(out_fixed._d).any())
|
| 374 |
+
|
| 375 |
+
# very aggressive scale (s=0.0) with cap
|
| 376 |
+
si_agg = GS.StageInfo(skip_factor=0.0, fft_type='gaussian',
|
| 377 |
+
fft_radius_ratio=0.1, hf_boost=1.0,
|
| 378 |
+
skip_high_end_factor=1.0, enable_adaptive_cap=True,
|
| 379 |
+
cap_threshold=0.35, cap_factor=0.6, adaptive_cap_mode='adaptive')
|
| 380 |
+
out_agg = UN.filter_skip_gaussian_adaptive(hs, si_agg)
|
| 381 |
+
shp("aggressive cap shape", out_agg, (1,4,16,16))
|
| 382 |
+
chk("aggressive cap no nan", not np.isnan(out_agg._d).any())
|
| 383 |
+
|
| 384 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 385 |
+
print("\n17. BACKBONE BLEND MATH (unit test)")
|
| 386 |
+
h = Tensor(np.full((1,8,4,4), 3.0, np.float32))
|
| 387 |
+
dims = 8
|
| 388 |
+
rbegin, rend, rinv = UN.ratio_to_region(0.5, 0.0, dims)
|
| 389 |
+
mask_np = np.zeros(dims, np.float32)
|
| 390 |
+
if not rinv: mask_np[rbegin:rend] = 1.0
|
| 391 |
+
else: mask_np[:rend]=1.0; mask_np[rbegin:]=1.0
|
| 392 |
+
mask_t = Tensor(mask_np.reshape(1,-1,1,1))
|
| 393 |
+
|
| 394 |
+
# V1 scale=2.0: masked β 6, unmasked β 3
|
| 395 |
+
scale = 2.0
|
| 396 |
+
h_scaled = h * (mask_t * scale + (1.0 - mask_t))
|
| 397 |
+
masked = h_scaled._d[0, :rend, 0, 0]
|
| 398 |
+
unmasked = h_scaled._d[0, rend:, 0, 0]
|
| 399 |
+
chk("masked ch β 6.0", np.allclose(masked, 6.0, atol=1e-4), str(masked))
|
| 400 |
+
chk("unmasked ch β 3.0", np.allclose(unmasked, 3.0, atol=1e-4), str(unmasked))
|
| 401 |
+
|
| 402 |
+
# blend lerp 0.5: masked β lerp(3,6,0.5)=4.5
|
| 403 |
+
lerp_fn = UN.BLENDING_MODES['lerp']
|
| 404 |
+
h_scaled_full = h * (mask_t * scale + (1.0 - mask_t))
|
| 405 |
+
h_blended = lerp_fn(h, h_scaled_full, 0.5)
|
| 406 |
+
h_out = h * (1.0 - mask_t) + h_blended * mask_t
|
| 407 |
+
chk("blend lerp 0.5 maskedβ4.5",
|
| 408 |
+
np.allclose(h_out._d[0,:rend,0,0], 4.5, atol=1e-4),
|
| 409 |
+
str(h_out._d[0,:rend,0,0]))
|
| 410 |
+
chk("blend lerp 0.5 unmaskedβ3.0",
|
| 411 |
+
np.allclose(h_out._d[0,rend:,0,0], 3.0, atol=1e-4))
|
| 412 |
+
|
| 413 |
+
# inject blend t=0.5: h + h_scaled*0.5 = 3 + 6*0.5 = 6
|
| 414 |
+
inj_fn = UN.BLENDING_MODES['inject']
|
| 415 |
+
h_inj = inj_fn(h, h_scaled_full, 0.5)
|
| 416 |
+
h_out2 = h * (1.0 - mask_t) + h_inj * mask_t
|
| 417 |
+
chk("blend inject maskedβh+h_scaled*.5",
|
| 418 |
+
np.allclose(h_out2._d[0,:rend,0,0], 3+6*0.5, atol=1e-4),
|
| 419 |
+
str(h_out2._d[0,:rend,0,0]))
|
| 420 |
+
|
| 421 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 422 |
+
print("\n18. APPLY_XYZ")
|
| 423 |
+
GS.instance = GS.State(); GS.xyz_attrs.clear()
|
| 424 |
+
orig = GS.instance.stage_infos[0].backbone_factor
|
| 425 |
+
GS.apply_xyz()
|
| 426 |
+
chk("empty attrs: unchanged", GS.instance.stage_infos[0].backbone_factor==orig)
|
| 427 |
+
|
| 428 |
+
GS.xyz_attrs['b0'] = 2.5; GS.apply_xyz()
|
| 429 |
+
chk("b0=2.5 applied", GS.instance.stage_infos[0].backbone_factor==2.5)
|
| 430 |
+
GS.xyz_attrs.clear()
|
| 431 |
+
|
| 432 |
+
# Preset
|
| 433 |
+
GS.reload_presets()
|
| 434 |
+
pname = list(GS.all_presets.keys())[0]
|
| 435 |
+
pbf = GS.all_presets[pname].stage_infos[0].backbone_factor
|
| 436 |
+
GS.xyz_attrs['preset'] = pname; GS.apply_xyz()
|
| 437 |
+
chk("preset applied", abs(GS.instance.stage_infos[0].backbone_factor-pbf)<1e-5)
|
| 438 |
+
GS.xyz_attrs.clear()
|
| 439 |
+
|
| 440 |
+
# Unknown preset β warning, instance unchanged
|
| 441 |
+
GS.instance = GS.State()
|
| 442 |
+
bf_before = GS.instance.stage_infos[0].backbone_factor
|
| 443 |
+
GS.xyz_attrs['preset'] = 'NO_SUCH_PRESET_XYZ'; GS.apply_xyz()
|
| 444 |
+
chk("unknown preset: no crash", True)
|
| 445 |
+
GS.xyz_attrs.clear()
|
| 446 |
+
|
| 447 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 448 |
+
print("\n19. DETECT_MODEL_CHANNELS")
|
| 449 |
+
UN.detect_model_channels()
|
| 450 |
+
chk("fallback=(1280,640,320)", UN._stage_channels==(1280,640,320))
|
| 451 |
+
|
| 452 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 453 |
+
print("\n20. _FLAT_TO_SIS (simulated)")
|
| 454 |
+
_SF = [f.name for f in dataclasses.fields(GS.StageInfo)]
|
| 455 |
+
_SN = len(_SF)
|
| 456 |
+
def flat_to_sis(flat):
|
| 457 |
+
res=[]
|
| 458 |
+
for i in range(GS.STAGES_COUNT):
|
| 459 |
+
chunk=flat[i*_SN:(i+1)*_SN]
|
| 460 |
+
si_new=GS.StageInfo()
|
| 461 |
+
for j,fname in enumerate(_SF):
|
| 462 |
+
if j<len(chunk): setattr(si_new,fname,chunk[j])
|
| 463 |
+
res.append(si_new)
|
| 464 |
+
return res
|
| 465 |
+
|
| 466 |
+
si0=GS.StageInfo(backbone_factor=1.7, skip_factor=0.3, fft_type='gaussian')
|
| 467 |
+
flat=[]
|
| 468 |
+
for si_x in [si0,GS.StageInfo(),GS.StageInfo()]:
|
| 469 |
+
for f in _SF: flat.append(getattr(si_x,f))
|
| 470 |
+
sis=flat_to_sis(flat)
|
| 471 |
+
chk("flat bf=1.7", sis[0].backbone_factor==1.7)
|
| 472 |
+
chk("flat sf=0.3", sis[0].skip_factor==0.3)
|
| 473 |
+
chk("flat fft=gaussian", sis[0].fft_type=='gaussian')
|
| 474 |
+
chk("flat 3 stages", len(sis)==3)
|
| 475 |
+
chk("flat short: no crash", len(flat_to_sis(flat[:20]))==3)
|
| 476 |
+
|
| 477 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 478 |
+
print("\n21. POSTCFG STEP COUNTER LOGIC")
|
| 479 |
+
# Simulate _on_cfg_post step counting
|
| 480 |
+
cfg = {"enabled":True,"steps":3,"mode":"lerp","blend":0.5,"b":1.1,
|
| 481 |
+
"fourier":False,"ms_mode":"Default","ms_str":1.0,
|
| 482 |
+
"threshold":1,"s":0.5,"gain":1.0,"step":0}
|
| 483 |
+
|
| 484 |
+
class FakeParams:
|
| 485 |
+
def __init__(self): self.x = Tensor(np.ones((1,4,8,8),np.float32)*2.0)
|
| 486 |
+
|
| 487 |
+
class FakeP:
|
| 488 |
+
_mega_pcfg = cfg
|
| 489 |
+
|
| 490 |
+
class FakeDenoiser:
|
| 491 |
+
p = FakeP()
|
| 492 |
+
|
| 493 |
+
# simulate _on_cfg_post inline
|
| 494 |
+
def run_post(params):
|
| 495 |
+
p = getattr(params, "p", None)
|
| 496 |
+
if p is None:
|
| 497 |
+
p = getattr(getattr(params,"denoiser",None),"p",None)
|
| 498 |
+
if p is None: return False
|
| 499 |
+
c = getattr(p,"_mega_pcfg",None)
|
| 500 |
+
if not c or not c.get("enabled"): return False
|
| 501 |
+
c["step"] = c.get("step",0)+1
|
| 502 |
+
if c["step"] > c["steps"]: return False
|
| 503 |
+
x = params.x
|
| 504 |
+
fn = UN.BLENDING_MODES.get(c["mode"], UN.BLENDING_MODES["inject"])
|
| 505 |
+
params.x = fn(x, x*c["b"], c["blend"])
|
| 506 |
+
return True
|
| 507 |
+
|
| 508 |
+
# via p attribute
|
| 509 |
+
fp1 = FakeParams(); fp1.p = FakeP(); fp1.p._mega_pcfg = {"enabled":True,"steps":2,"mode":"lerp","blend":0.5,"b":1.0,"fourier":False,"step":0,"gain":1.0}
|
| 510 |
+
ran1 = run_post(fp1); chk("postcfg step1 ran", ran1)
|
| 511 |
+
ran2 = run_post(fp1); chk("postcfg step2 ran", ran2)
|
| 512 |
+
ran3 = run_post(fp1); chk("postcfg step3 β past limit, skipped", not ran3)
|
| 513 |
+
|
| 514 |
+
# via denoiser.p
|
| 515 |
+
fp2 = FakeParams()
|
| 516 |
+
class D2:
|
| 517 |
+
class p2:
|
| 518 |
+
_mega_pcfg = {"enabled":True,"steps":1,"mode":"inject","blend":0.5,"b":1.1,"fourier":False,"step":0,"gain":1.0}
|
| 519 |
+
p = p2
|
| 520 |
+
fp2.denoiser = D2
|
| 521 |
+
chk("postcfg via denoiser.p: no crash", run_post(fp2))
|
| 522 |
+
|
| 523 |
+
# disabled β skip
|
| 524 |
+
fp3 = FakeParams(); fp3.p = type('P',(),{'_mega_pcfg':{"enabled":False}})()
|
| 525 |
+
chk("postcfg disabled β skip", not run_post(fp3))
|
| 526 |
+
|
| 527 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 528 |
+
print(f"\n{'β'*54}")
|
| 529 |
+
print(f"TOTAL: {P} PASS {F} FAIL")
|
| 530 |
+
if ERRS:
|
| 531 |
+
print("\nFailed:")
|
| 532 |
+
for e in ERRS: print(f" β’ {e}")
|
| 533 |
+
else:
|
| 534 |
+
print("ALL TESTS PASSED β")
|
mega_freeu_a1111/tests/test_fixes.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib, sys
|
| 2 |
+
exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
|
| 3 |
+
import sys, math, types, dataclasses, json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pathlib; sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
|
| 6 |
+
import importlib.util
|
| 7 |
+
|
| 8 |
+
def load_mod(name, path):
|
| 9 |
+
spec = importlib.util.spec_from_file_location(name, path)
|
| 10 |
+
m = importlib.util.module_from_spec(spec); sys.modules[name]=m; spec.loader.exec_module(m); return m
|
| 11 |
+
|
| 12 |
+
lib_pkg = types.ModuleType('lib_mega_freeu'); sys.modules['lib_mega_freeu'] = lib_pkg
|
| 13 |
+
GS = load_mod('lib_mega_freeu.global_state', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
|
| 14 |
+
UN = load_mod('lib_mega_freeu.unet', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
|
| 15 |
+
|
| 16 |
+
P=0; F=0; ERRS=[]
|
| 17 |
+
def ok(t): global P; P+=1; print(f" β {t}")
|
| 18 |
+
def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f" β {t} {m}")
|
| 19 |
+
def chk(t, c, m=""): ok(t) if c else ng(t, m)
|
| 20 |
+
|
| 21 |
+
print("β"*52)
|
| 22 |
+
print("FIX 1: State has pcfg_ and verbose fields")
|
| 23 |
+
print("β"*52)
|
| 24 |
+
st = GS.State()
|
| 25 |
+
chk("pcfg_enabled default False", not st.pcfg_enabled)
|
| 26 |
+
chk("pcfg_steps default 20", st.pcfg_steps == 20)
|
| 27 |
+
chk("pcfg_mode default inject", st.pcfg_mode == "inject")
|
| 28 |
+
chk("verbose default False", not st.verbose)
|
| 29 |
+
|
| 30 |
+
st2 = GS.State(pcfg_enabled=True, pcfg_b=1.5, pcfg_steps=10, verbose=True)
|
| 31 |
+
chk("pcfg_enabled=True", st2.pcfg_enabled)
|
| 32 |
+
chk("pcfg_b=1.5", st2.pcfg_b == 1.5)
|
| 33 |
+
chk("verbose=True", st2.verbose)
|
| 34 |
+
|
| 35 |
+
print("\nβ"*27)
|
| 36 |
+
print("FIX 2: to_dict() round-trip includes pcfg fields")
|
| 37 |
+
d = st2.to_dict()
|
| 38 |
+
chk("to_dict has pcfg_enabled", "pcfg_enabled" in d)
|
| 39 |
+
chk("to_dict has pcfg_b", "pcfg_b" in d)
|
| 40 |
+
chk("to_dict has verbose", "verbose" in d)
|
| 41 |
+
chk("to_dict no 'enable'", "enable" not in d)
|
| 42 |
+
st3 = GS.State(**{k:v for k,v in d.items() if k in {f.name for f in dataclasses.fields(GS.State)}})
|
| 43 |
+
chk("round-trip pcfg_b=1.5", st3.pcfg_b == 1.5)
|
| 44 |
+
chk("round-trip verbose=True", st3.verbose)
|
| 45 |
+
|
| 46 |
+
print("\nβ"*27)
|
| 47 |
+
print("FIX 3: _load_user_presets saves/restores pcfg")
|
| 48 |
+
import json, tempfile, os, pathlib
|
| 49 |
+
preset_data = {
|
| 50 |
+
"my_pcfg_preset": {
|
| 51 |
+
"start_ratio": 0.0, "stop_ratio": 1.0, "transition_smoothness": 0.0,
|
| 52 |
+
"version": "2", "multiscale_mode": "Default", "multiscale_strength": 1.0,
|
| 53 |
+
"override_scales": "", "channel_threshold": 96,
|
| 54 |
+
"pcfg_enabled": True, "pcfg_b": 1.8, "pcfg_steps": 5,
|
| 55 |
+
"pcfg_mode": "lerp", "pcfg_blend": 0.7, "verbose": True,
|
| 56 |
+
"stage_infos": []
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
| 60 |
+
json.dump(preset_data, f); tmp = f.name
|
| 61 |
+
GS.PRESETS_PATH = pathlib.Path(tmp)
|
| 62 |
+
res = GS._load_user_presets()
|
| 63 |
+
chk("preset loaded", "my_pcfg_preset" in res)
|
| 64 |
+
p = res.get("my_pcfg_preset")
|
| 65 |
+
chk("pcfg_enabled restored", p and p.pcfg_enabled == True)
|
| 66 |
+
chk("pcfg_b restored", p and p.pcfg_b == 1.8)
|
| 67 |
+
chk("pcfg_steps restored", p and p.pcfg_steps == 5)
|
| 68 |
+
chk("verbose restored", p and p.verbose == True)
|
| 69 |
+
os.unlink(tmp)
|
| 70 |
+
|
| 71 |
+
print("\nβ"*27)
|
| 72 |
+
print("FIX 4: _write_generation_params (simulate)")
|
| 73 |
+
class FakePNG:
|
| 74 |
+
extra_generation_params = {}
|
| 75 |
+
|
| 76 |
+
st4 = GS.State(
|
| 77 |
+
start_ratio=0.1, stop_ratio=0.9, transition_smoothness=0.5,
|
| 78 |
+
version="2", multiscale_mode="Multi-Bandpass", multiscale_strength=0.8,
|
| 79 |
+
override_scales="10, 1.5\n20, 0.8", channel_threshold=64,
|
| 80 |
+
pcfg_enabled=True, pcfg_b=1.2, pcfg_steps=15, pcfg_mode="inject",
|
| 81 |
+
pcfg_blend=1.0, pcfg_fourier=False, pcfg_ms_mode="Default",
|
| 82 |
+
pcfg_ms_str=1.0, pcfg_threshold=1, pcfg_s=0.5, pcfg_gain=1.0,
|
| 83 |
+
verbose=True,
|
| 84 |
+
stage_infos=[GS.StageInfo(backbone_factor=1.3)]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Simulate _write_generation_params
|
| 88 |
+
fp = FakePNG()
|
| 89 |
+
fp.extra_generation_params["MegaFreeU Schedule"] = f"{st4.start_ratio}, {st4.stop_ratio}, {st4.transition_smoothness}"
|
| 90 |
+
fp.extra_generation_params["MegaFreeU Stages"] = json.dumps([si.to_dict() for si in st4.stage_infos])
|
| 91 |
+
fp.extra_generation_params["MegaFreeU Version"] = st4.version
|
| 92 |
+
fp.extra_generation_params["MegaFreeU Multiscale Mode"] = st4.multiscale_mode
|
| 93 |
+
fp.extra_generation_params["MegaFreeU Multiscale Strength"] = str(st4.multiscale_strength)
|
| 94 |
+
fp.extra_generation_params["MegaFreeU Override Scales"] = st4.override_scales
|
| 95 |
+
fp.extra_generation_params["MegaFreeU Channel Threshold"] = str(st4.channel_threshold)
|
| 96 |
+
fp.extra_generation_params["MegaFreeU Verbose"] = str(st4.verbose)
|
| 97 |
+
if st4.pcfg_enabled:
|
| 98 |
+
fp.extra_generation_params["MegaFreeU PostCFG"] = json.dumps({
|
| 99 |
+
"enabled": st4.pcfg_enabled, "steps": st4.pcfg_steps,
|
| 100 |
+
"mode": st4.pcfg_mode, "blend": st4.pcfg_blend,
|
| 101 |
+
"b": st4.pcfg_b, "fourier": st4.pcfg_fourier,
|
| 102 |
+
"ms_mode": st4.pcfg_ms_mode, "ms_str": st4.pcfg_ms_str,
|
| 103 |
+
"threshold": st4.pcfg_threshold, "s": st4.pcfg_s, "gain": st4.pcfg_gain,
|
| 104 |
+
})
|
| 105 |
+
|
| 106 |
+
eg = fp.extra_generation_params
|
| 107 |
+
chk("PNG has Schedule", "MegaFreeU Schedule" in eg)
|
| 108 |
+
chk("PNG has Stages", "MegaFreeU Stages" in eg)
|
| 109 |
+
chk("PNG has Version", "MegaFreeU Version" in eg)
|
| 110 |
+
chk("PNG has Multiscale Mode", "MegaFreeU Multiscale Mode" in eg)
|
| 111 |
+
chk("PNG has Multiscale Strength","MegaFreeU Multiscale Strength" in eg)
|
| 112 |
+
chk("PNG has Override Scales", "MegaFreeU Override Scales" in eg)
|
| 113 |
+
chk("PNG has Channel Threshold", "MegaFreeU Channel Threshold" in eg)
|
| 114 |
+
chk("PNG has Verbose", "MegaFreeU Verbose" in eg)
|
| 115 |
+
chk("PNG has PostCFG", "MegaFreeU PostCFG" in eg)
|
| 116 |
+
|
| 117 |
+
# Verify PostCFG round-trip
|
| 118 |
+
pcfg_d = json.loads(eg["MegaFreeU PostCFG"])
|
| 119 |
+
chk("PostCFG b=1.2", pcfg_d["b"] == 1.2)
|
| 120 |
+
chk("PostCFG steps=15",pcfg_d["steps"] == 15)
|
| 121 |
+
|
| 122 |
+
# Verify multiscale restore
|
| 123 |
+
chk("ms_mode=Multi-Bandpass", eg["MegaFreeU Multiscale Mode"] == "Multi-Bandpass")
|
| 124 |
+
chk("ch_thresh=64", eg["MegaFreeU Channel Threshold"] == "64")
|
| 125 |
+
|
| 126 |
+
print("\nβ"*27)
|
| 127 |
+
print("FIX 5: Post-CFG independent of Enable (simulate process logic)")
|
| 128 |
+
# The fix: pcfg is set BEFORE checking st.enable
|
| 129 |
+
# Test: when enabled=False but pcfg_enabled=True, pcfg still created
|
| 130 |
+
|
| 131 |
+
class FakeP2:
|
| 132 |
+
extra_generation_params = {}
|
| 133 |
+
_mega_pcfg = None
|
| 134 |
+
|
| 135 |
+
fp2 = FakeP2()
|
| 136 |
+
# Simulate new process() logic for disabled main FreeU + enabled Post-CFG
|
| 137 |
+
st_disabled = GS.State(enable=False, pcfg_enabled=True, pcfg_b=1.5, pcfg_steps=10)
|
| 138 |
+
# Post-CFG created regardless
|
| 139 |
+
if st_disabled.pcfg_enabled:
|
| 140 |
+
fp2._mega_pcfg = {"enabled": True, "b": st_disabled.pcfg_b, "steps": st_disabled.pcfg_steps, "step": 0}
|
| 141 |
+
else:
|
| 142 |
+
fp2._mega_pcfg = {"enabled": False}
|
| 143 |
+
|
| 144 |
+
chk("pcfg set even when main disabled", fp2._mega_pcfg["enabled"] == True)
|
| 145 |
+
chk("pcfg_b=1.5 propagated", fp2._mega_pcfg["b"] == 1.5)
|
| 146 |
+
|
| 147 |
+
print("\nβ"*27)
|
| 148 |
+
print("FIX 6: dict-API compat (old sd-webui-freeu alwayson_scripts)")
|
| 149 |
+
# Simulate the dict branch of process()
|
| 150 |
+
dict_args = {
|
| 151 |
+
"enable": True, "start_ratio": 0.2, "stop_ratio": 0.8,
|
| 152 |
+
"version": "2", "multiscale_mode": "Default"
|
| 153 |
+
}
|
| 154 |
+
fields = {f.name for f in dataclasses.fields(GS.State)}
|
| 155 |
+
GS.instance = GS.State(**{k:v for k,v in dict_args.items() if k in fields})
|
| 156 |
+
chk("dict API: start_ratio=0.2", GS.instance.start_ratio == 0.2)
|
| 157 |
+
chk("dict API: version coerced '2'", GS.instance.version == "2")
|
| 158 |
+
chk("dict API: pcfg defaults", not GS.instance.pcfg_enabled)
|
| 159 |
+
|
| 160 |
+
print(f"\n{'β'*52}")
|
| 161 |
+
print(f"NEW FIXES: {P} PASS {F} FAIL")
|
| 162 |
+
if ERRS:
|
| 163 |
+
print("\nFailed:")
|
| 164 |
+
for e in ERRS: print(f" β’ {e}")
|
| 165 |
+
else:
|
| 166 |
+
print("ALL NEW FIX TESTS PASSED β")
|
mega_freeu_a1111/tests/test_preset_pcfg.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib, sys
|
| 2 |
+
exec(open(str(pathlib.Path(__file__).parent / 'mock_torch.py')).read())
|
| 3 |
+
import sys, types, dataclasses, json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pathlib; sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
|
| 6 |
+
import importlib.util
|
| 7 |
+
|
| 8 |
+
def load(name, path):
|
| 9 |
+
spec = importlib.util.spec_from_file_location(name, path)
|
| 10 |
+
m = importlib.util.module_from_spec(spec); sys.modules[name]=m; spec.loader.exec_module(m); return m
|
| 11 |
+
|
| 12 |
+
lib = types.ModuleType('lib_mega_freeu'); sys.modules['lib_mega_freeu'] = lib
|
| 13 |
+
GS = load('lib_mega_freeu.global_state', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'global_state.py'))
|
| 14 |
+
UN = load('lib_mega_freeu.unet', str(pathlib.Path(__file__).parent.parent / 'lib_mega_freeu' / 'unet.py'))
|
| 15 |
+
|
| 16 |
+
P=0; F=0; ERRS=[]
|
| 17 |
+
def ok(t): global P; P+=1; print(f" β {t}")
|
| 18 |
+
def ng(t,m=""): global F; F+=1; ERRS.append(f"{t}: {m}"); print(f" β {t} {m}")
|
| 19 |
+
def chk(t, c, m=""): ok(t) if c else ng(t, m)
|
| 20 |
+
|
| 21 |
+
_SF = [f.name for f in dataclasses.fields(GS.StageInfo)]
|
| 22 |
+
_SN = len(_SF)
|
| 23 |
+
|
| 24 |
+
def _flat_to_sis(flat):
|
| 25 |
+
result = []
|
| 26 |
+
for i in range(GS.STAGES_COUNT):
|
| 27 |
+
chunk = flat[i*_SN:(i+1)*_SN]
|
| 28 |
+
si = GS.StageInfo()
|
| 29 |
+
for j, fname in enumerate(_SF):
|
| 30 |
+
if j < len(chunk): setattr(si, fname, chunk[j])
|
| 31 |
+
result.append(si)
|
| 32 |
+
return result
|
| 33 |
+
|
| 34 |
+
# Simulate _save_p with all fields
|
| 35 |
+
def _save_p(name, sr, sp, sm, ver, msm, mss, ovs, cht,
|
| 36 |
+
p_en, p_steps, p_mode, p_bl, p_b,
|
| 37 |
+
p_four, p_mmd, p_mst, p_thr, p_s, p_gain,
|
| 38 |
+
v_log, *flat):
|
| 39 |
+
sis = _flat_to_sis(flat)
|
| 40 |
+
vc = GS.ALL_VERSIONS.get(ver, "1")
|
| 41 |
+
GS.all_presets[name] = GS.State(
|
| 42 |
+
start_ratio=sr, stop_ratio=sp, transition_smoothness=sm,
|
| 43 |
+
version=vc, multiscale_mode=msm, multiscale_strength=float(mss),
|
| 44 |
+
override_scales=ovs or "", channel_threshold=int(cht),
|
| 45 |
+
stage_infos=sis,
|
| 46 |
+
pcfg_enabled=bool(p_en), pcfg_steps=int(p_steps),
|
| 47 |
+
pcfg_mode=str(p_mode), pcfg_blend=float(p_bl), pcfg_b=float(p_b),
|
| 48 |
+
pcfg_fourier=bool(p_four), pcfg_ms_mode=str(p_mmd),
|
| 49 |
+
pcfg_ms_str=float(p_mst), pcfg_threshold=int(p_thr),
|
| 50 |
+
pcfg_s=float(p_s), pcfg_gain=float(p_gain), verbose=bool(v_log),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Simulate _apply_p
|
| 54 |
+
def _apply_p(name):
|
| 55 |
+
p = GS.all_presets.get(name)
|
| 56 |
+
if p is None: return None
|
| 57 |
+
flat = []
|
| 58 |
+
for si in p.stage_infos:
|
| 59 |
+
for f in _SF: flat.append(getattr(si, f))
|
| 60 |
+
return {
|
| 61 |
+
"start_ratio": p.start_ratio, "stop_ratio": p.stop_ratio,
|
| 62 |
+
"smooth": p.transition_smoothness,
|
| 63 |
+
"version": GS.REVERSED_VERSIONS.get(p.version, "Version 2"),
|
| 64 |
+
"ms_mode": p.multiscale_mode, "ms_str": p.multiscale_strength,
|
| 65 |
+
"ov_scales": p.override_scales, "ch_thresh": p.channel_threshold,
|
| 66 |
+
"pcfg_en": p.pcfg_enabled, "pcfg_steps": p.pcfg_steps,
|
| 67 |
+
"pcfg_mode": p.pcfg_mode, "pcfg_bl": p.pcfg_blend,
|
| 68 |
+
"pcfg_b": p.pcfg_b, "pcfg_fou": p.pcfg_fourier,
|
| 69 |
+
"pcfg_mmd": p.pcfg_ms_mode, "pcfg_mst": p.pcfg_ms_str,
|
| 70 |
+
"pcfg_thr": p.pcfg_threshold, "pcfg_s": p.pcfg_s,
|
| 71 |
+
"pcfg_gain": p.pcfg_gain, "verbose": p.verbose,
|
| 72 |
+
"flat": flat,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
print("β"*52)
|
| 76 |
+
print("TEST: Full preset save/apply round-trip")
|
| 77 |
+
print("β"*52)
|
| 78 |
+
|
| 79 |
+
# Build flat for 3 stages
|
| 80 |
+
si0 = GS.StageInfo(backbone_factor=1.3, skip_factor=0.8)
|
| 81 |
+
flat_in = []
|
| 82 |
+
for si in [si0, GS.StageInfo(), GS.StageInfo()]:
|
| 83 |
+
for f in _SF: flat_in.append(getattr(si, f))
|
| 84 |
+
|
| 85 |
+
# Save
|
| 86 |
+
_save_p(
|
| 87 |
+
"full_preset",
|
| 88 |
+
0.1, 0.9, 0.5, "Version 2",
|
| 89 |
+
"Multi-Bandpass", 0.7, "10,1.5", 64,
|
| 90 |
+
True, 15, "lerp", 0.8, 1.4,
|
| 91 |
+
True, "Default", 0.9, 3, 0.6, 1.2,
|
| 92 |
+
True,
|
| 93 |
+
*flat_in
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
chk("preset saved", "full_preset" in GS.all_presets)
|
| 97 |
+
saved = GS.all_presets["full_preset"]
|
| 98 |
+
chk("saved pcfg_enabled=True", saved.pcfg_enabled == True)
|
| 99 |
+
chk("saved pcfg_steps=15", saved.pcfg_steps == 15)
|
| 100 |
+
chk("saved pcfg_mode=lerp", saved.pcfg_mode == "lerp")
|
| 101 |
+
chk("saved pcfg_blend=0.8", abs(saved.pcfg_blend - 0.8) < 1e-5)
|
| 102 |
+
chk("saved pcfg_b=1.4", abs(saved.pcfg_b - 1.4) < 1e-5)
|
| 103 |
+
chk("saved pcfg_fourier=True", saved.pcfg_fourier == True)
|
| 104 |
+
chk("saved pcfg_ms_mode=Default", saved.pcfg_ms_mode == "Default")
|
| 105 |
+
chk("saved pcfg_ms_str=0.9", abs(saved.pcfg_ms_str - 0.9) < 1e-5)
|
| 106 |
+
chk("saved pcfg_threshold=3", saved.pcfg_threshold == 3)
|
| 107 |
+
chk("saved pcfg_s=0.6", abs(saved.pcfg_s - 0.6) < 1e-5)
|
| 108 |
+
chk("saved pcfg_gain=1.2", abs(saved.pcfg_gain - 1.2) < 1e-5)
|
| 109 |
+
chk("saved verbose=True", saved.verbose == True)
|
| 110 |
+
chk("saved ms_mode=Multi-Bandpass", saved.multiscale_mode == "Multi-Bandpass")
|
| 111 |
+
chk("saved ms_str=0.7", abs(saved.multiscale_strength - 0.7) < 1e-5)
|
| 112 |
+
chk("saved ch_thresh=64", saved.channel_threshold == 64)
|
| 113 |
+
chk("saved bf=1.3", abs(saved.stage_infos[0].backbone_factor - 1.3) < 1e-5)
|
| 114 |
+
|
| 115 |
+
# Apply
|
| 116 |
+
restored = _apply_p("full_preset")
|
| 117 |
+
chk("apply: pcfg_en=True", restored["pcfg_en"] == True)
|
| 118 |
+
chk("apply: pcfg_steps=15", restored["pcfg_steps"] == 15)
|
| 119 |
+
chk("apply: pcfg_mode=lerp", restored["pcfg_mode"] == "lerp")
|
| 120 |
+
chk("apply: pcfg_b=1.4", abs(restored["pcfg_b"] - 1.4) < 1e-5)
|
| 121 |
+
chk("apply: pcfg_fourier=True", restored["pcfg_fou"] == True)
|
| 122 |
+
chk("apply: verbose=True", restored["verbose"] == True)
|
| 123 |
+
chk("apply: ms_mode restored", restored["ms_mode"] == "Multi-Bandpass")
|
| 124 |
+
chk("apply: ch_thresh=64", restored["ch_thresh"] == 64)
|
| 125 |
+
chk("apply: bf=1.3 in flat", abs(restored["flat"][0] - 1.3) < 1e-5)
|
| 126 |
+
|
| 127 |
+
# Save with pcfg disabled β defaults
|
| 128 |
+
_save_p(
|
| 129 |
+
"no_pcfg",
|
| 130 |
+
0.0, 1.0, 0.0, "Version 1",
|
| 131 |
+
"Default", 1.0, "", 96,
|
| 132 |
+
False, 20, "inject", 1.0, 1.1,
|
| 133 |
+
False, "Default", 1.0, 1, 0.5, 1.0,
|
| 134 |
+
False,
|
| 135 |
+
*flat_in
|
| 136 |
+
)
|
| 137 |
+
r2 = _apply_p("no_pcfg")
|
| 138 |
+
chk("no_pcfg: pcfg_en=False", r2["pcfg_en"] == False)
|
| 139 |
+
chk("no_pcfg: verbose=False", r2["verbose"] == False)
|
| 140 |
+
|
| 141 |
+
# Unknown preset β None
|
| 142 |
+
r3 = _apply_p("does_not_exist")
|
| 143 |
+
chk("unknown preset β None", r3 is None)
|
| 144 |
+
|
| 145 |
+
# dict-API branch: pcfg_enabled passed through dict
|
| 146 |
+
GS.instance = GS.State()
|
| 147 |
+
d_api = {"enable": True, "pcfg_enabled": True, "pcfg_b": 1.9, "verbose": True}
|
| 148 |
+
fields = {f.name for f in dataclasses.fields(GS.State)}
|
| 149 |
+
GS.instance = GS.State(**{k:v for k,v in d_api.items() if k in fields})
|
| 150 |
+
chk("dict-API: pcfg_enabled propagated", GS.instance.pcfg_enabled == True)
|
| 151 |
+
chk("dict-API: pcfg_b=1.9", abs(GS.instance.pcfg_b - 1.9) < 1e-5)
|
| 152 |
+
chk("dict-API: verbose=True", GS.instance.verbose == True)
|
| 153 |
+
|
| 154 |
+
print(f"\n{'β'*52}")
|
| 155 |
+
print(f"PRESET PCFG ROUND-TRIP: {P} PASS {F} FAIL")
|
| 156 |
+
if ERRS:
|
| 157 |
+
print("\nFailed:")
|
| 158 |
+
for e in ERRS: print(f" β’ {e}")
|
| 159 |
+
else:
|
| 160 |
+
print("ALL PASSED β")
|