diff --git "a/adept-sampler-v5/scripts/adept_sampler_v5.py" "b/adept-sampler-v5/scripts/adept_sampler_v5.py" new file mode 100644--- /dev/null +++ "b/adept-sampler-v5/scripts/adept_sampler_v5.py" @@ -0,0 +1,3671 @@ +""" +Adept Sampler v5 for Automatic1111 WebUI +Complete port with ALL custom samplers from ComfyUI + +Version: 5.0 +""" + +import torch +import numpy as np +import math +from tqdm import trange +from modules import scripts, shared, script_callbacks +import gradio as gr +import k_diffusion.sampling + +# Try import torchvision for detail enhancement +try: + from torchvision.transforms.functional import gaussian_blur + TORCHVISION_AVAILABLE = True +except ImportError: + TORCHVISION_AVAILABLE = False + print("⚠️ torchvision not available - detail enhancement disabled") + +# ============================================================================ +# GLOBAL STATE +# ============================================================================ +ADEPT_STATE = { + "enabled": False, + "scale": 1.0, + "shift": 0.0, + "start_pct": 0.0, + "end_pct": 1.0, + "eta": 1.0, + "s_noise": 1.0, + "adaptive_eta": False, + "scheduler": "Standard", + "vae_reflection": False, + + # Custom sampler settings + "use_custom_sampler": False, + "custom_sampler": "Akashic Solver v2", + "tau": 0.5, + "phase_strength": 0.5, + "solver_order": 2, + "use_corrector": True, + "phase_noise": False, + "enhanced_derivative": False, + "smea_strength": 0.0, + "ndb_strength": 0.0, + "eqvae_mode": "Off", + + # Mirror Correction Euler controls + "mirror_correction_phase": 0.5, + "mirror_smooth_phase": False, + + # CFG enhancement settings + "cfg_drift_enabled": False, + "cfg_drift_method": "mean", + "cfg_drift_intensity": 0.5, + "spectral_cfg_enabled": False, + "spectral_multiplier": 1.0, + "spectral_percentile": 5.0, + "phase_cfg_enabled": False, + "phase_cfg_alpha": 2.0, + "phase_cfg_beta": 2.0, + "cfg_runtime_mode": "off", # off | a1111-postcfg | a1111-monkeypatch | native-hook + + # Internal 
def to_d(x, sigma, denoised):
    """Finite-difference conversion of a denoised prediction to an ODE derivative.

    Computes (x - denoised) / sigma with sigma floored at 1e-4, then clamps
    the result only when its peak magnitude exceeds a sigma-scaled safety
    threshold of 1000 * (1 + sigma / 10).
    """
    residual = x - denoised
    d = residual / torch.clamp(sigma, min=1e-4)

    # Sigma-proportional guard rail against exploding derivatives.
    limit = 1000.0 * (1.0 + sigma / 10.0)
    if torch.abs(d).max() > limit:
        d = torch.clamp(d, -limit, limit)
    return d


def get_ancestral_step(sigma, sigma_next, eta=1.0):
    """Split one sigma step into deterministic (down) and stochastic (up) parts.

    Returns (0.0, 0.0) at the final step (sigma_next == 0); otherwise
    sigma_up is capped at sigma_next and sigma_down completes the step so
    that sigma_down^2 + sigma_up^2 == sigma_next^2.
    """
    if sigma_next == 0:
        return 0.0, 0.0
    variance = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2
    sigma_up = min(sigma_next, eta * variance ** 0.5)
    sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
    """Per-step weight envelope with smooth ramps at both window edges.

    Outside [start_pct, end_pct] the scale is a no-op (1.0).  Inside, the
    scale fades from 1.0 up to ``base_scale`` over a ramp sized at 20 % of
    the window width (capped at 0.05 of total progress), holds at
    ``base_scale`` through the middle, then fades back out symmetrically.
    A zero-width window is treated as fully active for its single step.
    Raw UI values are safe to pass: inputs are clamped/validated here.
    """
    start_pct = max(0.0, min(float(start_pct), 1.0))
    end_pct = max(start_pct, min(float(end_pct), 1.0))
    total_steps = max(total_steps, 1)
    progress = step_idx / max(total_steps - 1, 1)

    if not (start_pct <= progress <= end_pct):
        return 1.0

    width = end_pct - start_pct
    if width < 1e-6:
        # Degenerate window — fully active for that single step.
        return float(base_scale)

    ramp = min(0.20 * width, 0.05)
    if ramp > 0 and progress < start_pct + ramp:
        t = (progress - start_pct) / ramp          # fade in
    elif ramp > 0 and progress > end_pct - ramp:
        t = (end_pct - progress) / ramp            # fade out
    else:
        return float(base_scale)                   # plateau
    return 1.0 + (base_scale - 1.0) * t


def default_noise_sampler(x):
    """Fallback noise source: fresh standard normals shaped like ``x``."""
    return lambda sigma, sigma_next: torch.randn_like(x)


def get_noise_sampler(x):
    """Return the noise sampler used for ``x`` (currently always the default)."""
    return default_noise_sampler(x)
def apply_dynamic_thresholding(x, percentile=0.995, clamp_range=1.0):
    """Soft-clamp extreme latent values that appear at very high CFG.

    No-op when ``percentile`` >= 1.0 or when no sample exceeds an absolute
    value of 5.0.  Otherwise, values beyond 2.5x the per-sample
    top-(1 - percentile) magnitude (floored at 1.0) are clamped and the
    whole tensor is damped by 0.98.  ``clamp_range`` is accepted for API
    compatibility but unused.  Any internal failure returns ``x`` unchanged.
    """
    if percentile >= 1.0:
        return x
    try:
        flat = x.view(x.shape[0], -1)
        if torch.abs(flat).max(dim=1, keepdim=True)[0].max() < 5.0:
            return x

        k = max(1, int(flat.shape[1] * (1.0 - percentile)))
        kth = torch.topk(torch.abs(flat), k=k, dim=1, largest=True)[0][:, -1:]
        limit = kth.clamp(min=1.0) * 2.5

        over = torch.abs(flat) > limit
        flat = torch.where(over, torch.sign(flat) * limit, flat) * 0.98
        return flat.view(x.shape)
    except Exception:
        return x
def compute_compensation_ratio(r, step_idx, total_steps, base_ratio=1.0):
    """DC-Solver style compensation ratio, weighted by sampling phase."""
    progress = step_idx / max(total_steps - 1, 1)
    if progress < 0.3:
        weight = 1.5        # early: aggressive correction
    elif progress < 0.7:
        weight = 1.0        # middle: neutral
    else:
        weight = 1.3        # late: moderate boost
    return base_ratio * weight * (1.0 + 0.1 * math.tanh(r - 1.0))


def compute_tau_eqvae(progress, base_tau=0.5, phase_strength=0.5):
    """Phase-aware tau schedule for a standard VAE, clamped to [0, 1]."""
    if progress < 0.30:
        factor = 1.0 + 0.2 * phase_strength
    elif progress < 0.60:
        factor = 1.0 - 0.15 * phase_strength
    else:
        factor = 1.0 - 0.3 * phase_strength
    return min(1.0, max(0.0, base_tau * factor))


def compute_eqvae_tau(progress, base_tau, phase_strength):
    """EQ-VAE tau schedule: same idea as compute_tau_eqvae with shifted phases."""
    if progress < 0.25:
        factor = 1.0 + 0.10 * phase_strength
    elif progress < 0.55:
        factor = 1.0 - 0.10 * phase_strength
    else:
        factor = 1.0 - 0.20 * phase_strength
    return min(1.0, max(0.0, base_tau * factor))


def compute_eqvae_noise_scale(base_s_noise, progress):
    """EQ-VAE noise scale: a 0.88 base factor shaped mildly by phase."""
    if progress < 0.25:
        factor = 1.0 + 0.05 * (1.0 - progress / 0.25)
    elif progress < 0.60:
        factor = 1.0 - 0.05 * ((progress - 0.25) / 0.35)
    else:
        factor = 0.95
    return base_s_noise * 0.88 * factor


def compute_eqvae_ndb(progress, ndb_strength):
    """Native Detail Boost schedule for EQ-VAE.

    Returns ``(blur_sigma, high_freq_boost)``; the boost grows piecewise
    (0 → 0.03·s → 0.10·s → 0.20·s) with progress.  Disabled strength
    returns (0.5, 0.0).
    """
    if ndb_strength <= 0:
        return 0.5, 0.0
    if progress < 0.30:
        boost = 0.03 * ndb_strength * (progress / 0.30)
    elif progress < 0.60:
        boost = (0.03 + 0.07 * ((progress - 0.30) / 0.30)) * ndb_strength
    else:
        boost = (0.10 + 0.10 * ((progress - 0.60) / 0.40)) * ndb_strength
    return 0.6, boost
def compute_native_detail_boost(progress, ndb_strength=0.0):
    """Native Detail Boost schedule for a standard VAE.

    Returns ``(1.0, high_freq_boost)``; the boost grows piecewise with
    progress (0 → 0.03·s → 0.10·s → 0.18·s at progress 1.0).  Disabled
    strength yields (1.0, 0.0).
    """
    if ndb_strength <= 0:
        return 1.0, 0.0
    if progress < 0.30:
        boost = 0.03 * ndb_strength * (progress / 0.30)
    elif progress < 0.60:
        boost = (0.03 + 0.07 * ((progress - 0.30) / 0.30)) * ndb_strength
    else:
        boost = (0.10 + 0.08 * ((progress - 0.60) / 0.40)) * ndb_strength
    return 1.0, boost


def compute_smea_factor(progress, smea_strength=0.5):
    """SMEA coherency factor: eases from damped toward 1.0 along a sine curve."""
    if smea_strength <= 0:
        return 1.0
    eased = 0.5 * (1 + math.sin(math.pi * (progress - 0.5)))
    return 1.0 - smea_strength * (1.0 - eased)
def apply_spectral_modulation_clybius(noise_pred, multiplier=1.0, percentile=5.0):
    """Clybius-style frequency-domain reshaping of a noise prediction.

    Based on ComfyUI-Latent-Modifiers; must be applied to the noise
    prediction (cond - uncond), never to the denoised latent.  Low
    log-amplitude frequency bins are boosted (mask up to 1.5x) and high
    ones damped (mask down to 0.5x), with ``multiplier`` as the exponent
    on the combined mask.  Returns the input untouched when disabled
    (multiplier == 0 or percentile <= 0) or on any internal failure.
    """
    if multiplier == 0 or percentile <= 0:
        return noise_pred
    try:
        spectrum = torch.fft.fft2(noise_pred, dim=(-2, -1))
        # Log amplitude; epsilon keeps log() finite for zero bins.
        log_amp = torch.log(torch.sqrt(spectrum.real ** 2 + spectrum.imag ** 2) + 1e-8)

        # Per-(batch, channel) quantiles of |log amplitude|.
        flat = log_amp.abs().flatten(2)
        q_lo = torch.quantile(flat, percentile * 0.01, dim=2)
        q_hi = torch.quantile(flat, 1 - percentile * 0.01, dim=2)
        q_lo = q_lo.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
        q_hi = q_hi.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)

        # Below q_lo → boost toward 1.5; above q_hi → damp toward 0.5.
        boost = ((log_amp < q_lo).float() + 1).clamp_(max=1.5)
        damp = (log_amp < q_hi).float().clamp_(min=0.5)

        shaped = spectrum * ((boost * damp) ** multiplier)
        return torch.fft.ifft2(shaped, dim=(-2, -1)).real
    except Exception as e:
        print(f"⚠️ Spectral modulation failed: {e}")
        return noise_pred
def create_spectral_modulation_cfg_hook(multiplier=1.0, percentile=5.0):
    """Build a sampler-CFG hook applying Clybius spectral modulation.

    Intended for backends exposing ``set_model_sampler_cfg_function``
    (Forge/reForge).  The hook converts to v-prediction space (RescaleCFG
    reference math), spectrally modulates ``cond - uncond``, performs the
    CFG combine there, then converts back to the original space.
    """
    def hook(args):
        cond, uncond = args["cond"], args["uncond"]
        scale, x_orig = args["cond_scale"], args["input"]
        # Broadcast sigma over all non-batch dims.
        sigma = args["sigma"].view(args["sigma"].shape[:1] + (1,) * (cond.ndim - 1))

        # To v-pred space.
        x = x_orig / (sigma * sigma + 1.0)
        cond_v = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / sigma
        uncond_v = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / sigma

        # Modulate the noise prediction, then combine.
        shaped = apply_spectral_modulation_clybius(cond_v - uncond_v, multiplier, percentile)
        x_cfg = uncond_v + scale * shaped

        # Back out of v-pred space.
        return x_orig - (x - x_cfg * sigma / (sigma * sigma + 1.0) ** 0.5)

    return hook
def apply_combat_cfg_drift(latent, method='mean', intensity=1.0):
    """Re-center a latent whose mean has drifted under high CFG.

    Subtracts ``intensity`` times the per-sample center — the global mean
    over channels+spatial dims (matching ComfyUI's PostCFGsubtractMean),
    or the global median when ``method == 'median'``.  ``intensity <= 0``
    and any internal failure both return the latent unchanged.
    """
    if intensity <= 0:
        return latent
    try:
        if method == 'median':
            center = latent.view(latent.shape[0], -1).median(dim=-1, keepdim=True)[0]
            center = center.view(latent.shape[0], 1, 1, 1)
        else:
            center = latent.mean(dim=(1, 2, 3), keepdim=True)
        # intensity=1.0 removes all drift; 0.5 removes half.
        return latent - intensity * center
    except Exception as e:
        print(f"⚠️ Combat CFG drift failed: {e}")
        return latent
def compute_phase_aware_cfg_scale(base_scale, progress, alpha=2.0, beta=2.0):
    """Phase-aware CFG scale, inspired by β-CFG (arXiv:2502.10574).

    With the default Beta(2,2) shape the multiplier sweeps 0.7 → 1.3 → 0.7,
    peaking mid-sampling (manifold exploration early, prompt adherence in
    the middle, staying on-manifold late).  Other (alpha, beta) use a
    linear falloff around the Beta mode, clamped to [0.7, 1.3].  Falls
    back to ``base_scale`` on any error.
    """
    try:
        if alpha == 2.0 and beta == 2.0:
            # 4·p·(1-p) peaks at 1.0 when progress == 0.5.
            peak = 4.0 * progress * (1.0 - progress)
            factor = 0.7 + 0.6 * peak
        else:
            # Mode of Beta(a, b) is (a-1)/(a+b-2); linear falloff around it.
            mode = (alpha - 1.0) / (alpha + beta - 2.0) if (alpha + beta) > 2 else 0.5
            factor = 1.0 - 0.6 * abs(progress - mode)
            factor = max(0.7, min(1.3, factor))
        return base_scale * factor
    except Exception as e:
        print(f"⚠️ Phase-aware CFG scaling failed: {e}")
        return base_scale
+ """ + def hook(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = float(args["cond_scale"]) + sigma = args["sigma"] + x_orig = args["input"] + + total_steps = max(int(ADEPT_STATE.get("_cfg_total_steps", 1)), 1) + step_idx = int(ADEPT_STATE.get("_cfg_step_idx", 0)) + progress = min(max(step_idx / max(total_steps - 1, 1), 0.0), 1.0) + + phased_scale = compute_phase_aware_cfg_scale(cond_scale, progress, + alpha=alpha, beta=beta) + patched_args = dict(args) + patched_args["cond_scale"] = phased_scale + + if base_hook is not None: + return base_hook(patched_args) + + # Vanilla CFG combine with phased scale + sigma_b = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) + x = x_orig / (sigma_b * sigma_b + 1.0) + cond_v = ((x - (x_orig - cond)) * (sigma_b ** 2 + 1.0) ** 0.5) / sigma_b + uncond_v = ((x - (x_orig - uncond)) * (sigma_b ** 2 + 1.0) ** 0.5) / sigma_b + x_cfg = uncond_v + phased_scale * (cond_v - uncond_v) + return x_orig - (x - x_cfg * sigma_b / (sigma_b * sigma_b + 1.0) ** 0.5) + return hook + + +def create_combined_native_cfg_hook(): + """ + Build one composite native hook from whatever CFG features are enabled. + Layer order: phase-aware scale → spectral modulation. + Returns None if nothing is enabled (caller should clear the hook). + """ + base_hook = None + if ADEPT_STATE.get("spectral_cfg_enabled", False): + base_hook = create_spectral_modulation_cfg_hook( + multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0), + percentile=ADEPT_STATE.get("spectral_percentile", 5.0), + ) + if ADEPT_STATE.get("phase_cfg_enabled", False): + return create_phase_aware_native_cfg_hook( + base_hook=base_hook, + alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0), + beta=ADEPT_STATE.get("phase_cfg_beta", 2.0), + ) + return base_hook + + +def adept_cfg_denoiser_callback(params): + """ + Official A1111 on_cfg_denoiser callback. 
def adept_cfg_denoiser_callback(params):
    """A1111 ``on_cfg_denoiser`` callback: record step counters only.

    The public callback params expose no cond/uncond tensors, so this is
    pure bookkeeping for phase-aware progress tracking.
    """
    ADEPT_STATE["_cfg_step_idx"] = int(getattr(params, "sampling_step", 0))
    ADEPT_STATE["_cfg_total_steps"] = int(getattr(params, "total_sampling_steps", 1))


def adept_after_cfg_callback(params):
    """A1111 ``on_cfg_after_cfg`` callback: apply Combat CFG Drift.

    Drift correction is the only technique that maps cleanly here —
    AfterCFGCallbackParams carries just (x, sampling_step,
    total_sampling_steps), no raw cond/uncond tensors.
    """
    if not (ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("cfg_drift_enabled", False)):
        return
    try:
        params.x = apply_combat_cfg_drift(
            params.x,
            method=ADEPT_STATE.get("cfg_drift_method", "mean"),
            intensity=ADEPT_STATE.get("cfg_drift_intensity", 0.5),
        )
    except Exception as e:
        print(f"⚠️ Adept post-CFG drift callback failed: {e}")


def uninstall_a1111_cfg_callbacks():
    """Detach any previously registered Adept CFG callbacks (best-effort)."""
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB
    for callback in (_ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB):
        if callback is None:
            continue
        try:
            script_callbacks.remove_callbacks_for_function(callback)
        except Exception:
            pass
    _ADEPT_CFG_AFTER_CB = None
    _ADEPT_CFG_DENOISER_CB = None


def install_a1111_cfg_callbacks():
    """(Re)register the step-tracking and post-CFG drift callbacks."""
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB
    uninstall_a1111_cfg_callbacks()
    _ADEPT_CFG_DENOISER_CB = adept_cfg_denoiser_callback
    _ADEPT_CFG_AFTER_CB = adept_after_cfg_callback
    script_callbacks.on_cfg_denoiser(_ADEPT_CFG_DENOISER_CB, name="adept_cfg_denoiser")
    script_callbacks.on_cfg_after_cfg(_ADEPT_CFG_AFTER_CB, name="adept_after_cfg")
+ """ + sd = getattr(shared, "sd_model", None) + candidates = [ + sd, + getattr(sd, "forge_objects", None), + getattr(sd, "model", None), + getattr(getattr(sd, "model", None), "model", None) if sd else None, + ] + for obj in candidates: + if obj is not None and hasattr(obj, "set_model_sampler_cfg_function"): + return obj + return None + + +def uninstall_native_cfg_hook(): + global _ADEPT_NATIVE_CFG_HOOK_ACTIVE + target = _get_native_cfg_hook_target() + if target is not None: + try: + target.set_model_sampler_cfg_function(None) + except Exception: + pass + _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False + + +def install_native_cfg_hook(): + global _ADEPT_NATIVE_CFG_HOOK_ACTIVE + target = _get_native_cfg_hook_target() + if target is None: + _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False + return False + hook = create_combined_native_cfg_hook() + try: + target.set_model_sampler_cfg_function(hook) # None clears it if nothing enabled + _ADEPT_NATIVE_CFG_HOOK_ACTIVE = (hook is not None) + return True + except Exception as e: + print(f"⚠️ Adept native CFG hook install failed: {e}") + _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False + return False + + +# ============================================================================ +# CFGDenoiser MONKEY-PATCH (stock A1111 fallback for spectral/phase CFG) +# ============================================================================ + +def _adept_cfg_progress_from_denoiser(denoiser): + """Compute sampling progress [0,1] from CFGDenoiser step counters.""" + total_steps = max(int(getattr(denoiser, "total_steps", 1) or 1), 1) + step_idx = int(getattr(denoiser, "step", 0)) + return min(max(step_idx / max(total_steps - 1, 1), 0.0), 1.0) + + +def _adept_nativeish_cfg_term(x_i, sigma_i, cond_i, uncond_i, scale): + """ + Approximate native hook behavior for one cond/uncond pair. + Feeds x/sigma/cond/uncond into the same composite hook builder used in + native-hook mode so stock A1111 gets as close as possible to Forge parity. 
def _adept_nativeish_cfg_term(x_i, sigma_i, cond_i, uncond_i, scale):
    """
    Approximate native hook behavior for one cond/uncond pair.
    Feeds x/sigma/cond/uncond into the same composite hook builder used in
    native-hook mode so stock A1111 gets as close as possible to Forge parity.
    Falls back to plain weighted delta if hook is unavailable or errors.
    Returns the additive CFG term (i.e. combined minus uncond), not the
    combined prediction itself.
    """
    # Zero scale contributes nothing — skip the hook entirely.
    if abs(float(scale)) < 1e-12:
        return torch.zeros_like(uncond_i)
    hook = create_combined_native_cfg_hook()
    if hook is None:
        # No CFG feature enabled: plain weighted delta.
        return (cond_i - uncond_i) * float(scale)
    try:
        combined = hook({
            "cond": cond_i,
            "uncond": uncond_i,
            "cond_scale": float(scale),
            "sigma": sigma_i,
            "input": x_i,
        })
        return combined - uncond_i
    except Exception as e:
        print(f"⚠️ Adept native-ish CFG term fallback: {e}")
        return (cond_i - uncond_i) * float(scale)


def patch_cfg_denoiser():
    """
    Stock A1111 fallback for Spectral Modulation + Phase-Aware CFG.

    Strategy:
    1. Thin forward() wrapper that stashes x / sigma on the instance so the
       combine methods can reach them — original forward logic is untouched.
    2. Patched combine_denoised() uses those values to call the same composite
       hook builder as native-hook mode, giving near-parity behaviour.
    3. Patched combine_denoised_for_edit_model() does the same for pix2pix.

    This is intentionally safer than a full forward() rewrite: upstream A1111
    changes to refiner/masking/skip-uncond logic remain unaffected.

    Returns True on success, False when the import or patch fails.
    """
    global _CFGD_ORIG_COMBINE, _CFGD_ORIG_COMBINE_EDIT, _CFGD_ORIG_FORWARD, _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch import failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False

    cls = sd_cfg.CFGDenoiser

    # Capture originals only once so repeated patching never captures our
    # own wrappers as "originals".
    if _CFGD_ORIG_COMBINE is None:
        _CFGD_ORIG_COMBINE = cls.combine_denoised
    if _CFGD_ORIG_COMBINE_EDIT is None:
        _CFGD_ORIG_COMBINE_EDIT = cls.combine_denoised_for_edit_model
    if _CFGD_ORIG_FORWARD is None:
        _CFGD_ORIG_FORWARD = cls.forward

    # --- forward wrapper: stash x/sigma, then run original ---
    def adept_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        setattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, {
            "x": x,
            "sigma": sigma,
            "uncond": uncond,
            "cond": cond,
            "cond_scale": float(cond_scale),
        })
        try:
            return _CFGD_ORIG_FORWARD(self, x, sigma, uncond, cond,
                                      cond_scale, s_min_uncond, image_cond)
        finally:
            # Always clear the stash so stale tensors never leak across steps.
            try:
                delattr(self, _ADEPT_CFGDENOISER_CTX_ATTR)
            except AttributeError:
                pass

    # --- combine_denoised: per-cond native-ish or plain path ---
    def adept_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        # Last uncond.shape[0] entries of x_out are the uncond predictions
        # (A1111 layout); clone them as the accumulation base.
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None)
        x_ctx = None if ctx is None else ctx.get("x", None)
        sigma_ctx = None if ctx is None else ctx.get("sigma", None)
        progress = _adept_cfg_progress_from_denoiser(self)

        # Phase-aware CFG rescales the effective scale by sampling progress.
        eff_scale = float(cond_scale)
        if ADEPT_STATE.get("phase_cfg_enabled", False):
            eff_scale = compute_phase_aware_cfg_scale(
                eff_scale, progress,
                alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
                beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
            )

        # Native-ish path needs the stashed forward() context to exist.
        use_nativeish = (
            (ADEPT_STATE.get("spectral_cfg_enabled", False) or
             ADEPT_STATE.get("phase_cfg_enabled", False))
            and x_ctx is not None and sigma_ctx is not None
        )

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                cond_i = x_out[cond_index:cond_index + 1]
                uncond_i = denoised_uncond[i:i + 1]

                if use_nativeish:
                    term = _adept_nativeish_cfg_term(
                        x_i=x_ctx[i:i + 1],
                        sigma_i=sigma_ctx[i:i + 1],
                        cond_i=cond_i,
                        uncond_i=uncond_i,
                        scale=float(weight) * eff_scale,
                    )
                    denoised[i:i + 1] += term
                else:
                    # Plain path: optional spectral shaping of the delta only.
                    delta = cond_i - uncond_i
                    if ADEPT_STATE.get("spectral_cfg_enabled", False):
                        delta = apply_spectral_modulation_clybius(
                            delta,
                            multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
                            percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
                        )
                    denoised[i:i + 1] += delta * (float(weight) * eff_scale)

        return denoised

    # --- combine_denoised_for_edit_model: pix2pix / instruct path ---
    def adept_combine_denoised_for_edit_model(self, x_out, cond_scale):
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None)
        x_ctx = None if ctx is None else ctx.get("x", None)
        sigma_ctx = None if ctx is None else ctx.get("sigma", None)
        progress = _adept_cfg_progress_from_denoiser(self)

        eff_scale = float(cond_scale)
        if ADEPT_STATE.get("phase_cfg_enabled", False):
            eff_scale = compute_phase_aware_cfg_scale(
                eff_scale, progress,
                alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
                beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
            )

        # Native-ish path when context is available.
        # NOTE(review): out_img_cond plays the "uncond" role here, mirroring
        # the instruct-pix2pix combine (cond vs image-cond delta) — confirm
        # against upstream combine_denoised_for_edit_model.
        if (ADEPT_STATE.get("spectral_cfg_enabled", False) or
                ADEPT_STATE.get("phase_cfg_enabled", False)):
            if x_ctx is not None and sigma_ctx is not None:
                try:
                    hook = create_combined_native_cfg_hook()
                    if hook is not None:
                        base = hook({
                            "cond": out_cond,
                            "uncond": out_img_cond,
                            "cond_scale": eff_scale,
                            "sigma": sigma_ctx,
                            "input": x_ctx,
                        })
                        return base + self.image_cfg_scale * (out_img_cond - out_uncond)
                except Exception as e:
                    print(f"⚠️ Adept edit-model native-ish fallback: {e}")

        # Plain path (no context or hook failed).
        delta = out_cond - out_img_cond
        if ADEPT_STATE.get("spectral_cfg_enabled", False):
            delta = apply_spectral_modulation_clybius(
                delta,
                multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
                percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
            )
        return out_uncond + eff_scale * delta + self.image_cfg_scale * (out_img_cond - out_uncond)

    try:
        cls.forward = adept_forward
        cls.combine_denoised = adept_combine_denoised
        cls.combine_denoised_for_edit_model = adept_combine_denoised_for_edit_model
        _CFGD_MONKEYPATCH_ACTIVE = True
        return True
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
def unpatch_cfg_denoiser():
    """Restore the original CFGDenoiser methods; True on success, False otherwise."""
    global _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception:
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
    denoiser_cls = sd_cfg.CFGDenoiser
    try:
        # Only restore what was actually captured at patch time.
        if _CFGD_ORIG_FORWARD is not None:
            denoiser_cls.forward = _CFGD_ORIG_FORWARD
        if _CFGD_ORIG_COMBINE is not None:
            denoiser_cls.combine_denoised = _CFGD_ORIG_COMBINE
        if _CFGD_ORIG_COMBINE_EDIT is not None:
            denoiser_cls.combine_denoised_for_edit_model = _CFGD_ORIG_COMBINE_EDIT
        _CFGD_MONKEYPATCH_ACTIVE = False
        return True
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser unpatch failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
def configure_cfg_runtime():
    """
    Select and activate the right CFG runtime mode:
      off               – nothing enabled; all hooks/callbacks/patches cleared
      a1111-postcfg     – stock A1111; Combat CFG Drift only via official callback
      a1111-monkeypatch – stock A1111; Spectral + Phase-Aware via CFGDenoiser patch
      native-hook       – Forge/reForge-like backend; all three via sampler CFG hook

    Returns the mode string so process() can log it.
    """
    # If the extension is globally disabled, always tear down and return off.
    if not ADEPT_STATE.get("enabled", False):
        uninstall_a1111_cfg_callbacks()
        uninstall_native_cfg_hook()
        unpatch_cfg_denoiser()
        ADEPT_STATE["cfg_runtime_mode"] = "off"
        return "off"

    drift = ADEPT_STATE.get("cfg_drift_enabled", False)
    spectral = ADEPT_STATE.get("spectral_cfg_enabled", False)
    phase = ADEPT_STATE.get("phase_cfg_enabled", False)

    # Always tear down everything first for a clean slate.
    uninstall_a1111_cfg_callbacks()
    uninstall_native_cfg_hook()
    unpatch_cfg_denoiser()

    if not (drift or spectral or phase):
        ADEPT_STATE["cfg_runtime_mode"] = "off"
        return "off"

    # Prefer native hook if backend supports it.
    native_target = _get_native_cfg_hook_target()
    if native_target is not None:
        install_a1111_cfg_callbacks()  # keeps drift working in native mode too
        install_native_cfg_hook()
        ADEPT_STATE["cfg_runtime_mode"] = "native-hook"
        return "native-hook"

    # Stock A1111: always install callbacks (drift + step tracking).
    install_a1111_cfg_callbacks()

    if spectral or phase:
        ok = patch_cfg_denoiser()
        if ok:
            ADEPT_STATE["cfg_runtime_mode"] = "a1111-monkeypatch"
            print("✅ Adept: stock A1111 CFGDenoiser monkey-patch active (spectral/phase + drift enabled)")
            return "a1111-monkeypatch"
        print("⚠️ Adept: CFGDenoiser monkey-patch failed; falling back to post-CFG drift only")

    ADEPT_STATE["cfg_runtime_mode"] = "a1111-postcfg"
    return "a1111-postcfg"


def sa_solver_step(x, d_history, sigma, sigma_next, tau, s_noise=1.0, noise_sampler=None,
                   order=2, ndb_strength=0.0, progress=0.0, eqvae_mode=False, eqvae_blur_sigma=None):
    """SA-Solver step - CRITICAL for Akashic Solver.

    Multi-step Adams-Bashforth-style update blended toward plain Euler by
    ``tau`` (tau also scales the stochastic/ancestral noise portion).
    ``d_history`` is a list of (sigma, derivative) pairs, newest last.
    Returns (x_next, sigma_up).
    """
    dt = sigma_next - sigma

    if len(d_history) >= 2 and order >= 2:
        sigma_cur, d_cur = d_history[-1]
        sigma_prev, d_prev = d_history[-2]

        # Step-size ratio between the upcoming step and the previous one,
        # capped at 2.0 to keep extrapolation coefficients tame.
        h_prev = sigma_cur - sigma_prev
        r = abs(dt / (h_prev + 1e-8)) if abs(h_prev) > 1e-8 else 1.0
        r = min(r, 2.0)

        if len(d_history) >= 3 and order >= 3:
            sigma_0, d_0 = d_history[-3]
            h_0 = sigma_prev - sigma_0
            h_1 = h_prev

            if abs(h_0) > 1e-6 and abs(h_1) > 1e-6:
                # Third-order: AB3 coefficients blended with Euler by tau,
                # then renormalized so the weights sum to 1.
                r0 = min(abs(h_1 / h_0), 2.0)
                r1 = min(abs(dt / (h_1 + 1e-8)), 2.0)

                tau_blend = 1.0 - tau
                c0_ab3 = 1.0 + (1.0 + r0) * r1 / 2.0
                c1_ab3 = -(1.0 + r0) * r1 / 2.0
                c2_ab3 = r0 * r1 / 2.0
                c0 = tau_blend * c0_ab3 + (1.0 - tau_blend) * 1.0
                c1 = tau_blend * c1_ab3
                c2 = tau_blend * c2_ab3

                c_sum = c0 + c1 + c2
                if abs(c_sum) > 1e-8:
                    c0 /= c_sum
                    c1 /= c_sum
                    c2 /= c_sum
                else:
                    # Degenerate coefficients: fall back to newest derivative.
                    c0, c1, c2 = 1.0, 0.0, 0.0

                d_interp = c0 * d_cur + c1 * d_prev + c2 * d_0
            else:
                # History spacing too small for AB3 — second-order fallback.
                tau_blend = 1.0 - tau
                c1_ab2 = 1.0 + 0.5 * r
                c2_ab2 = -0.5 * r
                c1 = tau_blend * c1_ab2 + (1.0 - tau_blend) * 1.0
                c2 = tau_blend * c2_ab2
                c_sum = c1 + c2
                if abs(c_sum) > 1e-8:
                    c1 /= c_sum
                    c2 /= c_sum
                d_interp = c1 * d_cur + c2 * d_prev
        else:
            # Second-order path (AB2 blended with Euler, renormalized).
            tau_blend = 1.0 - tau
            c1_ab2 = 1.0 + 0.5 * r
            c2_ab2 = -0.5 * r
            c1 = tau_blend * c1_ab2 + (1.0 - tau_blend) * 1.0
            c2 = tau_blend * c2_ab2
            c_sum = c1 + c2
            if abs(c_sum) > 1e-8:
                c1 /= c_sum
                c2 /= c_sum
            d_interp = c1 * d_cur + c2 * d_prev
    elif len(d_history) >= 1:
        # Not enough history / order 1: plain Euler on the latest derivative.
        d_interp = d_history[-1][1]
    else:
        d_interp = torch.zeros_like(x)

    # Compute sigma_up based on tau (controls stochasticity)
    sigma_up = 0.0
    if tau > 0 and sigma_next > 0 and noise_sampler is not None:
        sigma_ancestral_sq = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / (sigma ** 2 + 1e-8)
        sigma_ancestral = sigma_ancestral_sq ** 0.5 if sigma_ancestral_sq > 0 else 0.0
        sigma_up = tau * sigma_ancestral

    sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
    dt_adjusted = sigma_down - sigma

    x_det = x + d_interp * dt_adjusted
    # NOTE(review): noise_sampler is called unconditionally here; a None
    # noise_sampler with tau > 0 callers would crash — confirm all call
    # sites always pass a sampler.
    noise = noise_sampler(sigma, sigma_next) * s_noise * sigma_up

    # Apply Native Detail Boost if enabled
    if ndb_strength > 0 and TORCHVISION_AVAILABLE:
        # Use EQ-VAE optimized NDB parameters if in EQ-VAE mode
        if eqvae_mode:
            blur_sigma, high_freq_boost = compute_eqvae_ndb(progress, ndb_strength)
        else:
            _, high_freq_boost = compute_native_detail_boost(progress, ndb_strength)
            blur_sigma = 0.5  # Default blur sigma

        # Override blur_sigma if explicitly provided
        if eqvae_blur_sigma is not None:
            blur_sigma = eqvae_blur_sigma

        # Extract high-frequency component from noise using Gaussian blur
        try:
            low_freq_noise = gaussian_blur(noise, kernel_size=3, sigma=blur_sigma)
            high_freq_noise = noise - low_freq_noise
            noise = noise + high_freq_noise * high_freq_boost
        except Exception:
            pass  # Fallback: use original noise if blur fails

        x_next = x_det + noise
    else:
        # NOTE(review): when NDB is disabled this branch uses the
        # unadjusted dt and drops the ancestral noise entirely, even though
        # sigma_up/sigma_down were computed above — this looks like a bug
        # (stochasticity only applies with NDB on); confirm intent before
        # changing.
        x_next = x + d_interp * dt

    return x_next, sigma_up
compute_native_detail_boost(progress, ndb_strength) + blur_sigma = 0.5 # Default blur sigma + + # Override blur_sigma if explicitly provided + if eqvae_blur_sigma is not None: + blur_sigma = eqvae_blur_sigma + + # Extract high-frequency component from noise using Gaussian blur + try: + low_freq_noise = gaussian_blur(noise, kernel_size=3, sigma=blur_sigma) + high_freq_noise = noise - low_freq_noise + noise = noise + high_freq_noise * high_freq_boost + except Exception: + pass # Fallback: use original noise if blur fails + + x_next = x_det + noise + else: + x_next = x + d_interp * dt + + return x_next, sigma_up + + +def create_detail_enhanced_model(model, x, sigmas, settings): + # NOTE: Detail Enhancement is currently an internal/experimental path. + # It is not wired into the UI and callers always pass + # use_detail_enhancement=False, so this function is never invoked at + # runtime. Kept for future re-integration; do not rely on it. + """Detail enhancement wrapper.""" + if not TORCHVISION_AVAILABLE: + return model + + base_strength = settings.get('detail_enhancement_strength', 0.05) + radius = settings.get('detail_separation_radius', 0.5) + total_steps = len(sigmas) - 1 + + class DetailEnhancer: + def __init__(self): + self.current_step = 0 + + def __call__(self, x_current, sigma, **kwargs): + denoised = model(x_current, sigma, **kwargs) + + try: + low_freq = gaussian_blur(denoised, kernel_size=3, sigma=radius) + high_freq = denoised - low_freq + + progress = min(self.current_step / max(total_steps, 1), 1.0) + strength = base_strength * (0.5 + progress) + + enhanced = denoised + high_freq * strength + self.current_step += 1 + return enhanced + except Exception: + return denoised + + return DetailEnhancer() + + +# ============================================================================ +# ============================================================================ +# ============================================================================ +# CUSTOM ADVANCED 
SAMPLERS (Complete port from ComfyUI) +# ============================================================================ + +@torch.no_grad() +def sample_adept_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, + order=2, use_corrector=True, use_detail_enhancement=False, settings=None): + """ + Adept Solver: A unified training-free diffusion solver synthesizing improvements from: + - DPM-Solver++ (data prediction, dynamic thresholding) + - UniPC (unified predictor-corrector framework) + - DEIS (exponential integrator) + - DC-Solver (dynamic compensation) + """ + extra_args = {} if extra_args is None else extra_args + settings = settings or {} + s_in = x.new_ones([x.shape[0]]) + + order = max(1, min(order, 3)) + + print(f"🚀 Adept Solver active (Order: {order}, Corrector: {'On' if use_corrector else 'Off'})") + + active_model = model + # use_detail_enhancement is always False from current call-sites; + # the block below is preserved for future re-integration but is not + # currently reachable via the UI. + if use_detail_enhancement and TORCHVISION_AVAILABLE: + active_model = create_detail_enhanced_model(model, x, sigmas, settings) + + model_outputs = [] + + for i in range(len(sigmas) - 1): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + denoised = active_model(x, sigma * s_in, **extra_args) + + if extra_args.get('cond_scale', 1.0) > 7.0: + denoised = apply_dynamic_thresholding(denoised, percentile=0.995) + + d = to_d(x, sigma, denoised) + + derivative_max = torch.abs(d).max() + sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) + if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: + print(f"⚠️ Extreme derivative detected at step {i}/{len(sigmas)-1}. 
Clamping for stability.") + d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) + if torch.isnan(d).any() or torch.isinf(d).any(): + d = torch.zeros_like(d) + + model_outputs.append((sigma, d)) + if len(model_outputs) > order: + model_outputs.pop(0) + + dt = sigma_next - sigma + + if len(model_outputs) == 1 or order == 1: + x_pred = x + d * dt + elif len(model_outputs) == 2 and order >= 2: + sigma_prev, d_prev = model_outputs[-2] + d_cur = model_outputs[-1][1] + + h = sigma - sigma_prev + compensation_ratio = compute_compensation_ratio(h.item() if torch.is_tensor(h) else float(h), i, len(sigmas)) + + d_interp = d_cur + compensation_ratio * (d_cur - d_prev) + x_pred = x + d_interp * dt + else: + sigma_0, d_0 = model_outputs[-3] + sigma_1, d_1 = model_outputs[-2] + sigma_2, d_2 = model_outputs[-1] + + h_0 = sigma_2 - sigma_1 + h_1 = sigma_1 - sigma_0 + + h_0_val = h_0.item() if torch.is_tensor(h_0) else float(h_0) + h_1_val = h_1.item() if torch.is_tensor(h_1) else float(h_1) + + if abs(h_1_val) < 1e-6: + compensation_ratio = compute_compensation_ratio(h_0_val, i, len(sigmas)) + d_interp = d_2 + compensation_ratio * (d_2 - d_1) + else: + r0 = h_0_val / h_1_val + c0 = 1.0 + r0 / 2.0 + c1 = -r0 / 2.0 + c2 = 0.0 + + c_sum = c0 + c1 + c2 + c0 /= c_sum + c1 /= c_sum + c2 = 1.0 - c0 - c1 + + d_interp = c0 * d_2 + c1 * d_1 + c2 * d_0 + + x_pred = x + d_interp * dt + + if use_corrector and i < len(sigmas) - 2: + denoised_pred = active_model(x_pred, sigma_next * s_in, **extra_args) + + if extra_args.get('cond_scale', 1.0) > 7.0: + denoised_pred = apply_dynamic_thresholding(denoised_pred, percentile=0.995) + + d_pred = to_d(x_pred, sigma_next, denoised_pred) + + if torch.isnan(d_pred).any() or torch.isinf(d_pred).any() or torch.abs(d_pred).max() > 1000.0: + d_pred = torch.clamp(d_pred, -100.0, 100.0) + if torch.isnan(d_pred).any() or torch.isinf(d_pred).any(): + d_pred = torch.zeros_like(d_pred) + + dt = sigma_next - sigma + x = x + (d + d_pred) * dt * 0.5 
+ else: + x = x_pred + + if torch.isnan(x).any() or torch.isinf(x).any(): + print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!") + if i == 0: + raise RuntimeError("NaN/Inf on first step - check model/inputs") + + print(" Attempting recovery with conservative Euler step...") + denoised_safe = active_model(x, sigma * s_in, **extra_args) + if torch.isnan(denoised_safe).any(): + raise RuntimeError("Model producing NaN - check CFG scale and model") + + d_safe = to_d(x, sigma, denoised_safe) + dt_safe = (sigma_next - sigma) * 0.5 + x = x + d_safe * dt_safe + use_corrector = False + print(" Recovery successful. Corrector disabled for stability.") + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) + + return x + + +@torch.no_grad() +def sample_adept_ancestral_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, + eta=1.0, s_noise=1.0, adaptive_eta=False, phase_noise=False, + phase_strength=0.5, enhanced_derivative=False, + use_detail_enhancement=False, settings=None): + """ + Enhanced Adept Ancestral Solver: Advanced ancestral sampling with phase-aware adaptations. + + Key innovations: + 1. Adaptive ancestral step sizing that changes throughout sampling phases + 2. Phase-aware noise injection (more noise early, less noise late) + 3. Enhanced derivative computation with ancestral-specific corrections + 4. Dynamic eta scheduling for better control + """ + extra_args = {} if extra_args is None else extra_args + settings = settings or {} + s_in = x.new_ones([x.shape[0]]) + + print(f"🚀 Enhanced Adept Ancestral Solver active (η: {eta:.2f}, s_noise: {s_noise:.2f})") + print(f" Adaptive Eta: {adaptive_eta}, Phase Noise: {phase_noise}, Enhanced Derivative: {enhanced_derivative}") + + active_model = model + # use_detail_enhancement is always False from current call-sites. 
+ if use_detail_enhancement and TORCHVISION_AVAILABLE: + active_model = create_detail_enhanced_model(model, x, sigmas, settings) + + noise_sampler = get_noise_sampler(x) + for i in range(len(sigmas) - 1): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + progress = i / max(len(sigmas) - 1, 1) + + if adaptive_eta: + if progress < 0.3: + current_eta = eta * 1.08 + elif progress < 0.7: + current_eta = eta * 0.95 + else: + current_eta = eta * 1.02 + else: + current_eta = eta + + denoised = active_model(x, sigma * s_in, **extra_args) + + if extra_args.get('cond_scale', 1.0) > 7.0: + denoised = apply_dynamic_thresholding(denoised, percentile=0.995) + + if enhanced_derivative: + d = to_d_enhanced_ancestral(x, sigma, denoised, current_eta, progress, None) + else: + d = to_d(x, sigma, denoised) + + derivative_max = torch.abs(d).max() + sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) + if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: + d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) + if torch.isnan(d).any() or torch.isinf(d).any(): + d = torch.zeros_like(d) + + if sigma_next > 0: + sigma_up = min(sigma_next, current_eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5) + sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5 + else: + sigma_up = 0.0 + sigma_down = 0.0 + + dt = sigma_down - sigma + x_pred = x + d * dt + + if sigma_next > 0: + if phase_noise: + if progress < 0.25: + target_multiplier = 1.0 + (0.05 * min(progress / 0.25, 1.0)) + elif progress < 0.6: + target_multiplier = 1.0 - (0.02 * min((progress - 0.25) / 0.35, 1.0)) + else: + target_multiplier = 1.0 - (0.05 * min((progress - 0.6) / 0.4, 1.0)) + + noise_multiplier = 1.0 + (target_multiplier - 1.0) * phase_strength + adaptive_s_noise = s_noise * noise_multiplier + else: + adaptive_s_noise = s_noise + + noise = noise_sampler(sigma, sigma_next) * adaptive_s_noise * sigma_up + x = x_pred + noise + else: + x = 
x_pred

        if torch.isnan(x).any() or torch.isinf(x).any():
            print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")

            print(" Attempting recovery...")
            denoised_safe = active_model(x, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - check CFG scale and model")

            # Recovery: take a conservative half-length Euler step from a fresh
            # denoised prediction instead of the corrupted update.
            d_safe = to_d(x, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x + d_safe * dt_safe
            print(" Recovery successful.")

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x


@torch.no_grad()
def sample_mirror_correction_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                   eta=1.0, s_noise=1.0, correction_phase=0.5, smooth_phase=False):
    """
    Mirror Correction Euler: Euler Ancestral with a semantic reflection probe.

    In the first `correction_phase` fraction of steps, uses a 3-call Heun correction:
    x_probe = 2*D(x) - x (reflection of x through its own denoised prediction)
    The probe lies on the denoising trajectory, giving a curvature estimate for the
    Heun correction. Remaining steps: standard 1-call Euler Ancestral.

    Args:
        eta: Ancestral noise coefficient. 0=deterministic, 1=full ancestral. Default: 1.0
        s_noise: Noise scale multiplier. Default: 1.0
        correction_phase: Fraction of steps that receive the 3-call correction. Default: 0.5
        smooth_phase: Use continuous log-sigma weighting instead of a binary cutoff. Default: False
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    print(f"🔮 Mirror Correction Euler active (η: {eta:.2f}, s_noise: {s_noise:.2f})")
    print(f" Correction Phase: {correction_phase:.2f}, Smooth Phase: {smooth_phase}")

    noise_sampler = get_noise_sampler(x)
    n_steps = len(sigmas) - 1

    # Precompute the log-sigma window [log_sigma_phase, log_sigma_max] used to
    # fade the correction weight when smooth_phase is on. smooth_denom guards
    # against a zero-width window.
    log_sigma_phase = None
    log_sigma_max = None
    smooth_denom = 1e-6
    if smooth_phase and n_steps > 0:
        sigma_max_val = sigmas[0].clamp(min=1e-6)
        phase_idx = min(int(correction_phase * n_steps), n_steps - 1)
        sigma_phase_val = sigmas[phase_idx].clamp(min=1e-6)
        log_sigma_max = torch.log(sigma_max_val).item()
        log_sigma_phase = torch.log(sigma_phase_val).item()
        smooth_denom = max(log_sigma_max - log_sigma_phase, 1e-6)

    for i in range(n_steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        progress = i / max(n_steps - 1, 1)

        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

        d = to_d(x, sigma, denoised)

        # Standard ancestral split of sigma_next into a deterministic part
        # (sigma_down) and an injected-noise part (sigma_up).
        if sigma_next > 0:
            sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
            sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        else:
            sigma_up = 0.0
            sigma_down = 0.0
        dt = sigma_down - sigma

        if smooth_phase and log_sigma_phase is not None:
            # Continuous weight: 1 at sigma_max, 0 at the phase boundary,
            # square-rooted so the correction fades out gently.
            log_sig = torch.log(sigma.clamp(min=1e-6)).item()
            t = max(0.0, min(1.0, (log_sig - log_sigma_phase) / smooth_denom))
            correction_weight = t ** 0.5

            if correction_weight > 1e-3 and sigma_next > 0:
                # Reflection probe: mirror x through its own denoised prediction.
                x_probe = 2 * denoised - x
                d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args))

                # Gate the correction by how much the probe derivative agrees
                # with d; disagreement shrinks the effective weight toward 0.
                d_diff_norm = (d - d_probe).norm()
                d_scale = (d.norm() + d_probe.norm()) / 2 + 1e-6
                gradient_agreement = max(0.0, 1.0 - (d_diff_norm / d_scale).item())
                effective_weight = correction_weight * gradient_agreement

                if effective_weight > 1e-3:
                    # Heun-style midpoint evaluation using the averaged slope.
                    x3 = x + ((d + d_probe) / 2) * dt
                    d3 = to_d(x3,
sigma, model(x3, sigma * s_in, **extra_args)) + d_heun = (d + d3) / 2 + if not (torch.isnan(d_heun).any() or torch.isinf(d_heun).any()): + d = d + effective_weight * (d_heun - d) + else: + if progress < correction_phase and sigma_next > 0: + x_probe = 2 * denoised - x + d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args)) + x3 = x + ((d + d_probe) / 2) * dt + d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args)) + d = (d + d3) / 2 + if torch.isnan(d).any() or torch.isinf(d).any(): + d = torch.zeros_like(d) + + x = x + d * dt + if sigma_next > 0: + x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up + + return x + + +@torch.no_grad() +def sample_akashic_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, + tau=0.5, eta=1.0, s_noise=1.0, adaptive_eta=True, phase_strength=0.5, + order=2, smea_strength=0.0, ndb_strength=0.0, + use_detail_enhancement=False, settings=None, eqvae_mode='Off'): + """ + AkashicSolver v2 [EXPERIMENTAL]: Advanced sampler optimized for EQ-VAE models. + + Combines: + 1. SA-SOLVER BASE: Multi-step Adams-Bashforth integration with tau function + 2. PHASE-AWARE SAMPLING: Three-phase approach with adaptive parameters + 3. 
SMEA COHERENCY: Sine-based interpolation for high-resolution coherency + + Args: + eqvae_mode: EQ-VAE optimization mode ('Off' or 'Balanced') + """ + extra_args = {} if extra_args is None else extra_args + settings = settings or {} + s_in = x.new_ones([x.shape[0]]) + + if isinstance(eqvae_mode, bool): + eqvae_enabled = eqvae_mode + else: + eqvae_enabled = eqvae_mode == 'Balanced' + + if eqvae_enabled: + print(f"🌀 AkashicSolver v2 [EQ-VAE BALANCED] active") + print(f" Optimized for EQ-VAE's cleaner latent space") + else: + print(f"🌀 AkashicSolver v2 [EXPERIMENTAL] active") + print(f" τ (tau): {tau:.2f}, η (eta): {eta:.2f}, s_noise: {s_noise:.2f}") + print(f" Order: {order}, Adaptive Eta: {adaptive_eta}, Phase Strength: {phase_strength:.2f}") + if smea_strength > 0: + print(f" SMEA: {smea_strength:.2f} (high-res coherency)") + if ndb_strength > 0: + print(f" Native Detail Boost: {ndb_strength:.2f} (detail enhancement)") + if not eqvae_enabled: + print(f" ⚠️ Use external rescaleCFG (e.g., 0.7) for EQ-VAE models") + + active_model = model + # use_detail_enhancement is always False from current call-sites. 
+ if use_detail_enhancement and TORCHVISION_AVAILABLE: + active_model = create_detail_enhanced_model(model, x, sigmas, settings) + + noise_sampler = get_noise_sampler(x) + total_steps = len(sigmas) - 1 + d_history = [] + + for i in range(total_steps): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + progress = i / max(total_steps - 1, 1) + + if adaptive_eta: + if eqvae_enabled: + current_tau = compute_eqvae_tau(progress, tau, phase_strength) + else: + current_tau = compute_tau_eqvae(progress, tau, phase_strength) + else: + current_tau = tau + + if adaptive_eta: + if eqvae_enabled: + if progress < 0.25: + current_eta = eta * (1.0 + 0.03 * phase_strength) + elif progress < 0.55: + current_eta = eta * (1.0 - 0.03 * phase_strength) + else: + current_eta = eta * (1.0 + 0.02 * phase_strength) + else: + if progress < 0.30: + current_eta = eta * (1.0 + 0.08 * phase_strength) + elif progress < 0.60: + current_eta = eta * (1.0 - 0.05 * phase_strength) + else: + current_eta = eta * (1.0 + 0.02 * phase_strength) + else: + current_eta = eta + + smea_factor = compute_smea_factor(progress, smea_strength) + + denoised = active_model(x, sigma * s_in, **extra_args) + + cfg_scale = extra_args.get('cond_scale', 1.0) + if cfg_scale > 7.0: + denoised = apply_dynamic_thresholding(denoised, percentile=0.995) + + d = to_d(x, sigma, denoised) + + derivative_max = torch.abs(d).max() + sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) + if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: + d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) + if torch.isnan(d).any() or torch.isinf(d).any(): + d = torch.zeros_like(d) + + d_history.append((sigma, d)) + if len(d_history) > order: + d_history.pop(0) + + effective_tau = current_tau + + if eqvae_enabled: + effective_s_noise = compute_eqvae_noise_scale(s_noise * current_eta, progress) * smea_factor + else: + effective_s_noise = s_noise * current_eta * smea_factor + + if 
progress < 0.30: + noise_multiplier = 1.0 + 0.03 * phase_strength + elif progress < 0.60: + noise_multiplier = 1.0 - 0.01 * phase_strength + else: + noise_multiplier = 1.0 - 0.02 * phase_strength + + effective_s_noise *= noise_multiplier + + if eqvae_enabled and ndb_strength > 0: + eqvae_blur_sigma, _ = compute_eqvae_ndb(progress, ndb_strength) + else: + eqvae_blur_sigma = None + + x, sigma_up = sa_solver_step( + x=x, + d_history=d_history, + sigma=sigma, + sigma_next=sigma_next, + tau=effective_tau, + s_noise=effective_s_noise, + noise_sampler=noise_sampler, + order=order, + ndb_strength=ndb_strength, + progress=progress, + eqvae_mode=eqvae_enabled, + eqvae_blur_sigma=eqvae_blur_sigma + ) + + if torch.isnan(x).any() or torch.isinf(x).any(): + print(f"❌ AkashicSolver v2: NaN/Inf detected at step {i}/{total_steps}!") + + if i == 0: + raise RuntimeError("NaN/Inf on first step - check model/inputs") + + print(" Attempting recovery...") + denoised_safe = active_model(x, sigma * s_in, **extra_args) + if torch.isnan(denoised_safe).any(): + raise RuntimeError("Model producing NaN - reduce CFG scale or check model") + + d_safe = to_d(x, sigma, denoised_safe) + dt_safe = (sigma_next - sigma) * 0.5 + x = x + d_safe * dt_safe + + d_history.clear() + print(" Recovery successful. 
Multi-step history cleared.")

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x


# ============================================================================
# ============================================================================
# ============================================================================
# This file continues from PART1 + PART1B
# ============================================================================

# ============================================================================
# WEIGHT PATCHER (from v3 - unchanged)
# ============================================================================

def should_patch_weights(unet_model, scale, shift):
    """Return True if weight patching is actually needed for these parameters."""
    # Patching is a no-op when scale ≈ 1 and shift ≈ 0, so callers can skip
    # the full module walk in that case.
    return (
        unet_model is not None and
        (abs(scale - 1.0) > 1e-6 or abs(shift) > 1e-6)
    )


class AdeptWeightPatcher:
    """Temporary weight scaling for UNet.

    Context manager that rewrites every Conv2d/Linear weight as
    `w * scale + shift` on entry and restores the saved originals on exit.
    Backups are full clones, so the memory cost is one extra copy of every
    patched weight for the lifetime of the `with` block.
    """

    def __init__(self, unet_model, scale=1.0, shift=0.0):
        self.unet_model = unet_model
        self.scale = scale
        self.shift = shift
        # module name -> cloned original weight; (name, module) pairs patched.
        self.backups = {}
        self.target_layers = []

    def __enter__(self):
        # Fast path: nothing to patch, or identity parameters.
        if self.unet_model is None or (abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6):
            return self

        self.target_layers.clear()
        self.backups.clear()

        try:
            for name, module in self.unet_model.named_modules():
                if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                    if hasattr(module, 'weight') and module.weight is not None:
                        # Record before mutating so a mid-walk failure can be
                        # rolled back by __exit__.
                        self.target_layers.append((name, module))
                        self.backups[name] = module.weight.data.clone()
                        module.weight.data = module.weight.data * self.scale + self.shift
        except Exception as e:
            print(f"❌ Weight patcher failed: {e}")
            # Roll back whatever was already patched, then continue normally.
            self.__exit__(None, None, None)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            for name, module in self.target_layers:
                if name in self.backups:
                    module.weight.data.copy_(self.backups[name])
            self.backups.clear()
            self.target_layers.clear()
        except Exception as e:
            print(f"❌ CRITICAL: Failed to restore weights: {e}")
            # Best-effort second pass over any backups that were not restored.
            for name, backup_data in self.backups.items():
                try:
                    for n, m in self.target_layers:
                        if n == name:
                            m.weight.data.copy_(backup_data)
                except:
                    pass

        # Never suppress exceptions from the managed block.
        return False


# ============================================================================
# VAE REFLECTION PATCHER (from v3 - unchanged)
# ============================================================================

class VAEReflectionPatcher:
    """Context manager for VAE reflection padding.

    Switches every Conv2d in the VAE to 'reflect' padding on entry and
    restores the recorded padding modes on exit. The module-level globals
    (_vae_reflection_active / _vae_original_padding_modes) guard against
    double-patching and let force_restore_vae_reflection() clean up if the
    manager is never exited.
    """

    def __init__(self, vae_model):
        self.vae_model = vae_model
        # Not used by this class; padding modes are kept in the module-level
        # dict so the emergency restore path can reach them.
        self.backups = {}

    def __enter__(self):
        global _vae_reflection_active, _vae_original_padding_modes

        # Re-entrancy guard: never patch twice.
        if _vae_reflection_active or self.vae_model is None:
            return self

        _vae_original_padding_modes.clear()
        patched_count = 0

        try:
            for name, module in self.vae_model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    _vae_original_padding_modes[name] = module.padding_mode
                    module.padding_mode = 'reflect'
                    patched_count += 1

            _vae_reflection_active = True
            print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers")
        except Exception as e:
            print(f"❌ VAE Reflection failed: {e}")
            # Restore whatever was patched before the failure.
            self.__exit__(None, None, None)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _vae_reflection_active, _vae_original_padding_modes

        if self.vae_model is None:
            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            return False

        restored_count = 0
        try:
            for name, module in self.vae_model.named_modules():
                if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes:
                    module.padding_mode = _vae_original_padding_modes[name]
                    restored_count += 1

            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            print(f"🔄 VAE Reflection: Restored {restored_count} layers")
except Exception as e: + print(f"⚠️ VAE Reflection restore warning: {e}") + + return False + + +def force_restore_vae_reflection(): + """ + Emergency / unload-path restore for VAE padding modes. + Safe to call at any time — does nothing if VAE reflection was not active. + """ + global _vae_reflection_active, _vae_original_padding_modes + if not _vae_reflection_active and not _vae_original_padding_modes: + return + try: + sd = getattr(shared, "sd_model", None) + vae = getattr(sd, "first_stage_model", None) if sd else None + if vae is not None: + restored = 0 + for name, module in vae.named_modules(): + if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes: + module.padding_mode = _vae_original_padding_modes[name] + restored += 1 + if restored: + print(f"🔄 VAE Reflection: force-restored {restored} layers") + except Exception as e: + print(f"⚠️ VAE Reflection force-restore warning: {e}") + finally: + _vae_reflection_active = False + _vae_original_padding_modes.clear() + + +# ALL SCHEDULERS (18 types) +# ============================================================================ + +def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """AOS-V (Anime-Optimized Schedule for v-prediction models).""" + rho = 7.0 + + p1_steps = int(num_steps * 0.2) + p2_steps = int(num_steps * 0.6) + + ramp = torch.empty(num_steps, device=device, dtype=torch.float32) + + if p1_steps > 0: + torch.linspace(0, 1, p1_steps, out=ramp[:p1_steps]) + ramp[:p1_steps].pow_(0.5).mul_(0.6) + + if p2_steps > p1_steps: + torch.linspace(0.6, 0.9, p2_steps - p1_steps, out=ramp[p1_steps:p2_steps]) + + if num_steps > p2_steps: + torch.linspace(0, 1, num_steps - p2_steps, out=ramp[p2_steps:]) + ramp[p2_steps:].pow_(3).mul_(0.1).add_(0.9) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + ramp.mul_(min_inv_rho - max_inv_rho).add_(max_inv_rho).pow_(rho) + + return torch.cat([ramp, torch.zeros(1, device=device)]) + + +def 
create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """AOS-ε (Anime-Optimized Schedule for epsilon-prediction models).""" + rho = 7.0 + + p1_frac, p2_frac = 0.35, 0.7 + ramp_p1_val, ramp_p2_val = 0.4, 0.75 + + p1_steps = int(num_steps * p1_frac) + p2_steps = int(num_steps * p2_frac) + + phase1_ramp = torch.linspace(0, 1, p1_steps, device=device) ** 1.5 * ramp_p1_val + phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, p2_steps - p1_steps, device=device) + phase3_base = torch.linspace(0, 1, num_steps - p2_steps, device=device) ** 0.7 + phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val + + if p1_steps == 0: phase1_ramp = torch.empty(0, device=device) + if p2_steps - p1_steps == 0: phase2_ramp = torch.empty(0, device=device) + if num_steps - p2_steps == 0: phase3_ramp = torch.empty(0, device=device) + + ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp]) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models.""" + rho = 7.0 + + u = torch.linspace(0, 1, num_steps, device=device) + + detail_power = 0.85 + u_progressive = u ** detail_power + + mid_boost_strength = 0.08 + mid_boost = mid_boost_strength * torch.sin(math.pi * u) * (1 - u * 0.5) + + u_modulated = u_progressive + mid_boost + + u_min, u_max = u_modulated.min(), u_modulated.max() + if u_max - u_min > 1e-8: + u_modulated = (u_modulated - u_min) / (u_max - u_min) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho + + for i in range(1, len(sigmas)): + if sigmas[i] >= sigmas[i-1]: + sigmas[i] = sigmas[i-1] * 0.995 + max_ratio = 1.5 + if i > 0 and sigmas[i-1] / sigmas[i] 
> max_ratio: + sigmas[i] = sigmas[i-1] / max_ratio + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'): + """Entropic power schedule.""" + rho = 7.0 + + linear_ramp = torch.linspace(0, 1, num_steps, device=device) + power_ramp = 1 - torch.linspace(1, 0, num_steps, device=device) ** power + + ramp = (linear_ramp + power_ramp) / 2.0 + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Schedule optimized around log SNR = 0 region.""" + rho = 7.0 + + log_snr_max = 2 * torch.log(sigma_max) + log_snr_min = 2 * torch.log(sigma_min) + + t = torch.linspace(0, 1, num_steps, device=device) + + concentration_power = 3.0 + sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5)) + + linear_t = t + blend_factor = 0.7 + combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * linear_t + + log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max) + sigmas = torch.exp(log_snr / 2) + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Constant rate of distributional change.""" + rho = 7.0 + + t = torch.linspace(0, 1, num_steps, device=device) + corrected_t = t + 0.3 * torch.sin(math.pi * t) * (1 - t) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + corrected_t * (min_inv_rho - max_inv_rho)) ** rho + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Adaptive schedule combining multiple strategies.""" + rho = 7.0 + + base_t = torch.linspace(0, 1, num_steps, device=device) + + 
strategies = [ + lambda t: t, + lambda t: t ** 0.8, + lambda t: t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t), + lambda t: 1 / (1 + torch.exp(-3 * (t - 0.5))), + ] + + weights = [0.2, 0.3, 0.2, 0.3] + combined_t = sum(w * s(base_t) for w, s in zip(weights, strategies)) + + if (combined_t.max() - combined_t.min()) > 1e-6: + combined_t = (combined_t - combined_t.min()) / (combined_t.max() - combined_t.min()) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + combined_t * (min_inv_rho - max_inv_rho)) ** rho + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Cosine-annealed schedule.""" + rho = 7.0 + u = torch.linspace(0, 1, num_steps, device=device) + t = (1 - torch.cos(math.pi * u)) / 2 + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Uniform in log-SNR space.""" + u = torch.linspace(0, 1, num_steps, device=device) + log_snr_max = 2 * torch.log(sigma_max) + log_snr_min = 2 * torch.log(sigma_min) + log_snr = log_snr_max + u * (log_snr_min - log_snr_max) + sigmas = torch.exp(log_snr / 2) + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0): + """Concentrate steps near mid-range sigmas.""" + rho = 7.0 + u = torch.linspace(0, 1, num_steps, device=device) + k_tensor = torch.tensor(k, device=device, dtype=u.dtype) + t = 0.5 * (torch.tanh(k_tensor * (u - 0.5)) / torch.tanh(k_tensor / 2) + 1.0) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho + return torch.cat([sigmas, torch.zeros(1, 
def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0):
    """Faster early lock-in with extra resolution in final steps.

    Before ``pivot`` the grid is warped by a power curve; after it, by a
    saturating exponential that packs steps into the tail.
    """
    rho = 7.0
    grid = torch.linspace(0, 1, num_steps, device=device)

    # Both branch values are computed elementwise; torch.where selects
    # the power-curve side below the pivot and the exponential side above.
    early_t = (grid / pivot) ** gamma * pivot
    late_t = pivot + (1 - pivot) * (1 - torch.exp(-beta * (grid - pivot) / (1 - pivot)))
    warped = torch.where(grid < pivot, early_t, late_t)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + warped * (lo - hi)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])


def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Karras schedule with controlled jitter.

    A fixed-frequency sine supplies deterministic pseudo-jitter; its
    amplitude (0.35 / span) is small enough that the grid stays monotone.
    """
    if num_steps <= 0:
        return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])

    rho = 7.0
    step_ids = torch.arange(num_steps, device=device, dtype=torch.float32)
    span = max(1, num_steps - 1)

    centers = (step_ids + 0.5) / span
    wobble = torch.sin((step_ids + 1) * 2.3999632) * 0.35 / span
    u = torch.clamp(centers + wobble, 0.0, 1.0)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + u * (lo - hi)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])


def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'):
    """Stochastic scheduler with controlled randomness.

    A deterministic base grid is perturbed by random noise, transformed,
    then re-sorted so the schedule is always noise-to-clean.
    """
    rho = 7.0

    # Deterministic base grid.
    if base_schedule == 'karras':
        ids = torch.arange(num_steps, device=device, dtype=torch.float32)
        u = (ids / max(1, num_steps - 1)) ** (1 / rho)
    elif base_schedule == 'cosine':
        lin = torch.linspace(0, 1, num_steps, device=device)
        u = (1 - torch.cos(math.pi * lin)) / 2
    else:  # uniform
        u = torch.linspace(0, 1, num_steps, device=device)

    # Random perturbation.
    if noise_type == 'brownian':
        walk = torch.randn(num_steps, device=device).cumsum(0)
        perturbation = walk / walk.std()
    elif noise_type == 'uniform':
        perturbation = torch.rand(num_steps, device=device) * 2 - 1
    else:  # normal
        perturbation = torch.randn(num_steps, device=device)

    u = torch.clamp(u + perturbation * noise_scale / num_steps, 0, 1)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + u * (lo - hi)) ** rho

    # Sort the final sigmas descending AFTER the transform: the Karras
    # mapping is monotone-decreasing in u, so sorting u first would give
    # the wrong order.
    sigmas = torch.sort(sigmas, descending=True).values

    return torch.cat([sigmas, torch.zeros(1, device=device)])


def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    JYS (Jump Your Steps) schedule using dynamically computed timestep sequences.
    Strategy: Large jumps early, dense clustering in detail region, fine steps at end.
    Ported from ComfyUI reference implementation.
    """
    # _compute_jys_timesteps yields num_steps entries plus a trailing 0;
    # drop that 0 here — the explicit zeros(1) terminator is appended below.
    timesteps = _compute_jys_timesteps(num_steps)
    if timesteps and timesteps[-1] == 0:
        timesteps = timesteps[:-1]

    rho = 7.0
    t_tensor = torch.tensor([(1000 - step) / 1000.0 for step in timesteps],
                            device=device, dtype=torch.float32)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = torch.sort((hi + t_tensor * (lo - hi)) ** rho, descending=True).values

    return torch.cat([sigmas, torch.zeros(1, device=device)])
+ jys_timesteps = _compute_jys_timesteps(num_steps) + if jys_timesteps and jys_timesteps[-1] == 0: + jys_timesteps = jys_timesteps[:-1] + + rho = 7.0 + + normalized_timesteps = [(1000 - t) / 1000.0 for t in jys_timesteps] + t_tensor = torch.tensor(normalized_timesteps, device=device, dtype=torch.float32) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + t_tensor * (min_inv_rho - max_inv_rho)) ** rho + + sigmas, _ = torch.sort(sigmas, descending=True) + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def _compute_jys_timesteps(num_steps): + """Dynamically compute optimised JYS timestep sequence (0..1000 scale).""" + if num_steps <= 0: + return [0] + if num_steps == 1: + return [1000, 0] + elif num_steps == 2: + return [1000, 500, 0] + elif num_steps == 3: + return [1000, 600, 200, 0] + + early_steps = max(1, int(num_steps * 0.2)) + final_steps = max(1, int(num_steps * 0.2)) + middle_steps = max(1, num_steps - early_steps - final_steps) + + early_jump_size = max(50, (1000 - 600) // early_steps) + early_timesteps = [] + current_t = 1000 + for _ in range(early_steps): + early_timesteps.append(int(current_t)) + current_t = max(600, current_t - early_jump_size) + + middle_timesteps = [] + structure_steps = max(1, middle_steps // 2) + structure_jump_size = max(10, (600 - 300) // structure_steps) + current_t = 600 + for _ in range(structure_steps): + middle_timesteps.append(int(current_t)) + current_t = max(300, current_t - structure_jump_size) + + detail_steps = middle_steps - structure_steps + if detail_steps > 0: + detail_jump_size = max(5, (300 - 200) // detail_steps) + current_t = 300 + for _ in range(detail_steps): + middle_timesteps.append(int(current_t)) + current_t = max(200, current_t - detail_jump_size) + + final_start = min(middle_timesteps) if middle_timesteps else 200 + final_jump_size = max(5, final_start // final_steps) + final_timesteps = [] + current_t = final_start + for _ in 
range(final_steps): + final_timesteps.append(int(current_t)) + current_t = max(0, current_t - final_jump_size) + + all_timesteps = early_timesteps + middle_timesteps + final_timesteps + unique_timesteps = list(dict.fromkeys(all_timesteps)) + unique_timesteps.sort(reverse=True) + + while len(unique_timesteps) < num_steps: + for i in range(len(unique_timesteps) - 1): + mid_point = (unique_timesteps[i] + unique_timesteps[i + 1]) // 2 + if mid_point not in unique_timesteps: + unique_timesteps.insert(i + 1, mid_point) + if len(unique_timesteps) >= num_steps: + break + + if len(unique_timesteps) > num_steps: + unique_timesteps = unique_timesteps[:num_steps] + + if unique_timesteps[-1] != 0: + unique_timesteps.append(0) + + return unique_timesteps + + +def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """Hybrid: JYS mid-phase with Karras locks.""" + if num_steps <= 0: + return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)]) + + rho = 7.0 + + jys_sigmas = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1] + + indices = torch.arange(num_steps, device=device, dtype=torch.float32) + denom = max(1, num_steps - 1) + base = (indices + 0.5) / denom + jitter_seed = torch.sin((indices + 1) * 2.3999632) + jitter_strength = 0.35 + jitter = jitter_seed * jitter_strength / denom + u = torch.clamp(base + jitter, 0.0, 1.0) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + karras_sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho + + positions = torch.linspace(0, 1, num_steps, device=device) + jys_weight = torch.empty_like(positions) + early_mask = positions < 0.3 + mid_mask = (positions >= 0.3) & (positions < 0.8) + late_mask = positions >= 0.8 + jys_weight[early_mask] = 0.2 + 0.4 * (positions[early_mask] / 0.3) + jys_weight[mid_mask] = 0.6 + 0.3 * ((positions[mid_mask] - 0.3) / 0.5) + jys_weight[late_mask] = 0.9 + jys_weight = jys_weight.clamp(0.2, 0.9) + + log_jys = 
def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AYS (Align Your Steps) optimized for SDXL.

    Uses hand-tuned normalized schedules for 10/15/20/25/30 steps and
    log-space interpolation of the nearest table otherwise. Always returns
    ``num_steps + 1`` entries ending in a zero terminator, matching every
    other scheduler in this file.
    """

    AYS_SCHEDULES = {
        10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000],
        15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336,
             0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000],
        20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000,
             0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156,
             0.0039, 0.0000],
        25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000,
             0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500,
             0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000],
        30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667,
             0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917,
             0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081,
             0.0041, 0.0010, 0.0000],
    }

    if num_steps in AYS_SCHEDULES:
        normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32)
    else:
        available_steps = sorted(AYS_SCHEDULES.keys())

        # Pick the nearest reference table at or above the requested count.
        if num_steps < available_steps[0]:
            ref_steps = available_steps[0]
        elif num_steps > available_steps[-1]:
            ref_steps = available_steps[-1]
        else:
            ref_steps = min([s for s in available_steps if s >= num_steps], default=available_steps[-1])

        ref_schedule = np.array(AYS_SCHEDULES[ref_steps])

        t_ref = np.linspace(0, 1, len(ref_schedule))
        # BUGFIX: this used num_steps + 1 points, so the interpolated branch
        # returned num_steps + 2 entries after the terminator append below,
        # inconsistent with the exact-table branch (num_steps + 1) and with
        # every sibling scheduler.
        t_new = np.linspace(0, 1, num_steps)

        # Interpolate in log space; the table's terminal 0 is replaced with
        # a finite log value so np.interp stays well-defined.
        log_ref = np.log(ref_schedule + 1e-8)
        log_ref[-1] = log_ref[-2] - 3.0

        log_interp = np.interp(t_new, t_ref, log_ref)
        normalized_np = np.exp(log_interp)
        normalized_np[-1] = 0.0

        normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32)

    sigma_range = sigma_max - sigma_min
    sigmas = normalized * sigma_range + sigma_min

    sigmas[0] = sigma_max
    sigmas[-1] = 0.0

    # Enforce strictly decreasing interior values.
    for i in range(1, len(sigmas) - 1):
        if sigmas[i] >= sigmas[i-1]:
            sigmas[i] = sigmas[i-1] * 0.999

    # Append zero-terminator so output has num_steps+1 entries like all other schedulers.
    return torch.cat([sigmas, torch.zeros(1, device=device)])


def create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    AkashicAOS Alt: Karras-based schedule with EQ-VAE-tuned warping.
    Stronger detail-progressive bias (power=0.78) and shifted tanh crossover at t=0.55.
    Adaptive rho scales with step count for multi-step solver stability.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    # rho grows for short schedules (bounded to [7, 11]).
    rho = min(11.0, max(7.0, 7.0 + 2.0 * (20.0 / max(num_steps, 10))))
    u = torch.linspace(0, 1, num_steps, device=device)

    # Detail-progressive power bias.
    detail_power = 0.78
    u_detail = u ** detail_power

    # Tanh crossover shifts density around t=0.55.
    t_center = 0.55
    beta = 0.07
    gamma = 4.0
    crossover = beta * torch.tanh(gamma * (u - t_center))
    u_modulated = u_detail + crossover

    # Renormalize the warped grid back to [0, 1].
    u_min, u_max = u_modulated.min(), u_modulated.max()
    if u_max - u_min > 1e-8:
        u_modulated = (u_modulated - u_min) / (u_max - u_min)

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho

    # Enforce strict descent and cap the per-step ratio for solver stability.
    max_ratio = 1.5
    for i in range(1, len(sigmas)):
        if sigmas[i] >= sigmas[i - 1]:
            sigmas[i] = sigmas[i - 1] * 0.995
        if sigmas[i - 1] / sigmas[i].clamp(min=1e-10) > max_ratio:
            sigmas[i] = sigmas[i - 1] / max_ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
+ """ + if num_steps <= 0: + return torch.zeros(1, device=device) + + rho = min(11.0, max(7.0, 7.0 + 2.0 * (20.0 / max(num_steps, 10)))) + u = torch.linspace(0, 1, num_steps, device=device) + + detail_power = 0.78 + u_detail = u ** detail_power + + t_center = 0.55 + beta = 0.07 + gamma = 4.0 + crossover = beta * torch.tanh(gamma * (u - t_center)) + u_modulated = u_detail + crossover + + u_min, u_max = u_modulated.min(), u_modulated.max() + if u_max - u_min > 1e-8: + u_modulated = (u_modulated - u_min) / (u_max - u_min) + + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho + + max_ratio = 1.5 + for i in range(1, len(sigmas)): + if sigmas[i] >= sigmas[i - 1]: + sigmas[i] = sigmas[i - 1] * 0.995 + if sigmas[i - 1] / sigmas[i].clamp(min=1e-10) > max_ratio: + sigmas[i] = sigmas[i - 1] / max_ratio + + return torch.cat([sigmas, torch.zeros(1, device=device)]) + + +def create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): + """ + AkashicEQFlow: Robust crossover-focused log-SNR schedule for EQ-VAE models. + Concentrates steps around the structure-to-detail transition in logSNR space, + blended with a Karras prior. Adaptive density width + ratio slew-rate limiting. 
+ """ + if num_steps <= 0: + return torch.zeros(1, device=device) + + lambda_min = -2.0 * math.log(max(float(sigma_max), 1e-10)) + lambda_max = -2.0 * math.log(max(float(sigma_min), 1e-10)) + lambda_range = max(lambda_max - lambda_min, 1e-8) + + step_factor = min(1.0, max(0.0, (num_steps - 16) / 30.0)) + lambda_center = 0.20 + 0.15 * step_factor + u_center = (lambda_center - lambda_min) / lambda_range + u_center = float(min(0.88, max(0.12, u_center))) + + concentration = min(3.2, max(1.35, 1.1 + num_steps / 16.0)) + base_width = min(0.30, max(0.18, 0.31 - 0.0028 * num_steps)) + width_left = base_width * 1.06 + width_right = base_width * 0.94 + detail_side_gain = 1.08 + 0.04 * step_factor + + N = 1200 + t = torch.linspace(0, 1, N, device=device) + delta = t - u_center + left_core = torch.exp(-((delta / width_left) ** 2) / 2.0) + right_core = detail_side_gain * torch.exp(-((delta / width_right) ** 2) / 2.0) + crossover_core = torch.where(delta <= 0, left_core, right_core) + + detail_floor = 0.08 * (t ** 1.4) + composition_floor = 0.05 * ((1 - t) ** 1.7) + density = 1.0 + concentration * crossover_core + detail_floor + composition_floor + + dt_val = 1.0 / (N - 1) + cdf = torch.zeros(N, device=device) + cdf[1:] = torch.cumsum((density[:-1] + density[1:]) * 0.5 * dt_val, dim=0) + cdf = cdf / cdf[-1].clamp(min=1e-12) + + targets = torch.linspace(0, 1, num_steps, device=device) + indices = torch.searchsorted(cdf, targets).clamp(1, N - 1) + lo = indices - 1 + hi = indices + frac = (targets - cdf[lo]) / (cdf[hi] - cdf[lo]).clamp(min=1e-12) + u_steps = t[lo] + frac * (t[hi] - t[lo]) + + lambdas_eqflow = lambda_min + u_steps * lambda_range + + rho = min(10.0, max(7.0, 7.0 + 1.5 * (22.0 / max(num_steps, 12)))) + u_karras = torch.linspace(0, 1, num_steps, device=device) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas_karras = (max_inv_rho + u_karras * (min_inv_rho - max_inv_rho)) ** rho + lambdas_karras = -2.0 * 
def apply_custom_scheduler(sigmas, scheduler_type="Standard"):
    """
    Apply a custom sigma schedule.

    sigma_min uses sigmas[-2] (last non-zero step), never the zero-terminator.
    Each scheduler is invoked via a lambda so keyword args with non-standard
    defaults (e.g. Entropic's `power`) are always passed correctly.

    Returns the rebuilt schedule, or the input unchanged when the scheduler
    is "Standard", unknown, raises, or produces an empty result.
    """
    if scheduler_type == "Standard" or len(sigmas) < 2:
        return sigmas

    sigma_max = sigmas[0]
    # Use the last non-zero sigma as sigma_min; sigmas[-1] is always 0.
    # (The guard above already returned for len < 2, so sigmas[-2] is safe.)
    sigma_min = sigmas[-2]
    if sigma_min <= 0:
        sigma_min = sigma_max * 1e-3
    num_steps = len(sigmas) - 1
    device = sigmas.device

    scheduler_map = {
        "AOS-V": lambda: create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device),
        "AOS-Epsilon": lambda: create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device),
        "AkashicAOS": lambda: create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device),
        "Entropic": lambda: create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device=device),
        "SNR-Optimized": lambda: create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device),
        "Constant-Rate": lambda: create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device),
        "Adaptive-Optimized": lambda: create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device),
        "Cosine-Annealed": lambda: create_cosine_sigmas(sigma_max, sigma_min, num_steps, device),
        "LogSNR-Uniform": lambda: create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device),
        "Tanh Mid-Boost": lambda: create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device),
        "Exponential Tail": lambda: create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device),
        "Jittered-Karras": lambda: create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device),
        "Stochastic": lambda: create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device=device),
        "JYS (Dynamic)": lambda: create_jys_sigmas(sigma_max, sigma_min, num_steps, device),
        "Hybrid JYS-Karras": lambda: create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device),
        "AYS-SDXL": lambda: create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device),
        "AkashicAOS Alt": lambda: create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device),
        "AkashicEQFlow": lambda: create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device),
    }

    fn = scheduler_map.get(scheduler_type)
    if fn is not None:
        try:
            result = fn()
            if result is not None and len(result) > 1:
                return result
            print(f"⚠️ Scheduler {scheduler_type} returned empty/None, using standard")
        except Exception as e:
            # Best-effort: any scheduler failure falls back to the input schedule.
            print(f"⚠️ Scheduler {scheduler_type} failed: {e}, using standard")

    return sigmas
# ============================================================================
# ============================================================================
# K-DIFFUSION SAMPLERS with Custom Sampler Integration
# ============================================================================

@torch.no_grad()
def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler sampler with Adept weight scaling OR custom sampler.

    Dispatch order:
      1. Custom-sampler mode: forwards to the selected solver.
      2. Adept disabled: falls back to the stock k-diffusion Euler.
      3. Adept enabled: Euler loop with per-step weight patching.
    """

    # CUSTOM SAMPLER INTEGRATION
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5),
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0),
                ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False,
                settings={},
                eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2),
                use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False),
                phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    # STANDARD K-DIFFUSION MODE (from v3 - unchanged)
    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'euler' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_euler(model, x, sigmas, extra_args, callback, disable)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # UNet handle for weight patching; None if the model layout differs.
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1

    for i in range(total_steps):
        sigma = sigmas[i]
        # s_churn adds noise (classic Karras churn) within [s_tmin, s_tmax].
        gamma = min(s_churn / total_steps, 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Patch UNet weights for the whole step (churn + eval + update).
        with AdeptWeightPatcher(unet_model, current_scale, shift):
            eps = torch.randn_like(x) * s_noise if gamma > 0 else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5

            denoised = model(x, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_hat, 'denoised': denoised})

    return x
# BUGFIX: decorated with @torch.no_grad() for consistency with the sibling
# samplers (sample_adept_euler / sample_adept_heun / sample_adept_dpmpp_2m);
# without it, autograd history accumulates across the sampling loop.
@torch.no_grad()
def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None,
                                 disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """Euler Ancestral with Adept weight scaling.

    Dispatch order:
      1. Custom-sampler mode: forwards to the selected solver.
      2. Adept disabled: falls back to the stock k-diffusion Euler Ancestral.
      3. Adept enabled: ancestral loop with per-step weight patching and an
         optional phase-dependent (adaptive) eta.
    """

    # CUSTOM SAMPLER INTEGRATION
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'euler_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False)
    # NOTE(review): the state eta read here is only used for the log line
    # below; the loop recomputes current_eta from the `eta` parameter each
    # step, so the UI eta never reaches the ancestral step — confirm intent.
    current_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={current_eta:.2f}")

    for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler A"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        progress = i / max(total_steps, 1)

        # Adaptive eta: slightly more noise early/late, less mid-run.
        if use_adaptive_eta:
            if progress < 0.3:
                current_eta = eta * 1.08
            elif progress < 0.7:
                current_eta = eta * 0.95
            else:
                current_eta = eta * 1.02
        else:
            current_eta = eta

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Evaluate model with weight patching
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        # Euler Ancestral step
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
        d = to_d(x, sigma, denoised)

        # Guard against NaN/Inf derivatives from unstable model outputs.
        if torch.isnan(d).any() or torch.isinf(d).any():
            d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)

        dt = sigma_down - sigma
        x = x + d * dt

        # Re-inject noise for the ancestral part of the step.
        if sigma_up > 0:
            noise = noise_sampler(sigma, sigma_next) * current_s_noise
            x = x + noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Fallback basic Euler (used when ORIGINAL_SAMPLERS has no 'euler' key)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * s_in, **extra_args)
        d = to_d(x, sigma, denoised)
        dt = sigmas[i + 1] - sigma
        x = x + d * dt
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x


def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
    """Fallback basic Euler Ancestral."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = default_noise_sampler(x)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta)
        d = to_d(x, sigmas[i], denoised)
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        # Ancestral noise re-injection.
        if sigma_up > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x


@torch.no_grad()

def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun sampler with Adept weight scaling.

    Dispatch order:
      1. Custom-sampler mode: forwards to the selected solver.
      2. Adept disabled: falls back to the stock k-diffusion Heun.
      3. Adept enabled: two-stage Heun loop with per-step weight patching.
    """

    # CUSTOM SAMPLER INTEGRATION
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'heun' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_heun(model, x, sigmas, extra_args, callback, disable)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Heun active: scale={base_scale:.2f}")

    for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Heun"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # First evaluation
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        d = to_d(x, sigma, denoised)

        # Guard against NaN/Inf derivatives from unstable model outputs.
        if torch.isnan(d).any() or torch.isinf(d).any():
            d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)

        dt = sigma_next - sigma

        if sigma_next == 0:
            # Last step: plain Euler (no second evaluation at sigma=0).
            x = x + d * dt
        else:
            # Heun's method: two-stage
            x_2 = x + d * dt

            # Second evaluation
            if should_patch_weights(unet_model, current_scale, shift):
                with AdeptWeightPatcher(unet_model, current_scale, shift):
                    denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
            else:
                denoised_2 = model(x_2, sigma_next * s_in, **extra_args)

            d_2 = to_d(x_2, sigma_next, denoised_2)

            if torch.isnan(d_2).any() or torch.isinf(d_2).any():
                d_2 = torch.nan_to_num(d_2, nan=0.0, posinf=1.0, neginf=-1.0)

            # Average
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x


def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Fallback basic Heun."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        dt = sigmas[i + 1] - sigmas[i]

        if sigmas[i + 1] == 0:
            # Final step: Euler only.
            x = x + d * dt
        else:
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x


# NOTE(review): the definition below continues past the end of this chunk;
# left byte-identical up to where the chunk ends.
@torch.no_grad()

def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM++ 2M sampler with Adept weight scaling."""

    # CUSTOM SAMPLER INTEGRATION
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    # Extension disabled: delegate to the captured original, or to the
    # basic fallback when the original was never stored.
    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'dpmpp_2m' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable)
        return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable)
    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = 
apply_custom_scheduler(sigmas, _sched) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + # Get settings + base_scale = ADEPT_STATE.get('scale', 1.0) + shift = ADEPT_STATE.get('shift', 0.0) + start_pct = ADEPT_STATE.get('start_pct', 0.0) + end_pct = ADEPT_STATE.get('end_pct', 1.0) + + # Get UNet + try: + unet_model = shared.sd_model.model.diffusion_model + except AttributeError: + unet_model = None + + total_steps = len(sigmas) - 1 + print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}") + + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2M"): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + # Dynamic scale + current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) + + # Evaluate model with weight patching + if should_patch_weights(unet_model, current_scale, shift): + with AdeptWeightPatcher(unet_model, current_scale, shift): + denoised = model(x, sigma * s_in, **extra_args) + else: + denoised = model(x, sigma * s_in, **extra_args) + + # DPM++ 2M step + t, t_next = sigma, sigma_next + h = t_next - t + + if old_denoised is None or sigma_next == 0: + # First step (Euler) + x = (sigma_next / sigma) * x - (-h).expm1() * denoised + else: + # Second order + h_last = t - sigmas[i - 1] + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d + + old_denoised = denoised + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) + + return x + + +def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Fallback basic DPM++ 2M.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + t, t_next = sigmas[i], sigmas[i + 1] + h = 
t_next - t + + if old_denoised is None or sigmas[i + 1] == 0: + x = (t_next / t) * x - (-h).expm1() * denoised + else: + h_last = t - sigmas[i - 1] + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (t_next / t) * x - (-h).expm1() * denoised_d + + old_denoised = denoised + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) + + return x + + +@torch.no_grad() + +def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None): + """DPM++ 2S Ancestral with Adept weight scaling.""" + + # CUSTOM SAMPLER INTEGRATION + if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): + custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') + print(f"🌀 Redirecting to {custom_type}") + + # Apply scheduler to sigmas for custom samplers + scheduler = ADEPT_STATE.get('scheduler', 'Standard') + if scheduler != "Standard": + sigmas = apply_custom_scheduler(sigmas, scheduler) + print(f" 📊 Applied {scheduler} scheduler") + + if custom_type == "Akashic Solver v2": + return sample_akashic_solver( + model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, + tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), + s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), + phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), + smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), + use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') + ) + elif custom_type == "Adept Solver": + return sample_adept_solver( + model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, + order=ADEPT_STATE.get('solver_order', 2), 
use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    # Extension disabled: delegate to the captured original, or to the
    # basic fallback when the original was never stored.
    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)
    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings (UI eta / s_noise override the function defaults)
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    current_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    # 
Get UNet + try: + unet_model = shared.sd_model.model.diffusion_model + except AttributeError: + unet_model = None + + if noise_sampler is None: + noise_sampler = default_noise_sampler(x) + + total_steps = len(sigmas) - 1 + print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}") + + for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2S A"): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + # Dynamic scale + current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) + + # First evaluation + if should_patch_weights(unet_model, current_scale, shift): + with AdeptWeightPatcher(unet_model, current_scale, shift): + denoised = model(x, sigma * s_in, **extra_args) + else: + denoised = model(x, sigma * s_in, **extra_args) + + # DPM++ 2S step with ancestral noise + sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta) + + if sigma_down == 0: + d = to_d(x, sigma, denoised) + x = x + d * (sigma_down - sigma) + else: + # Midpoint method + t, t_next = sigma, sigma_down + h = t_next - t + s = t + h * 0.5 + + # Step to midpoint + x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised + + # Evaluate at midpoint + if should_patch_weights(unet_model, current_scale, shift): + with AdeptWeightPatcher(unet_model, current_scale, shift): + denoised_mid = model(x_mid, s * s_in, **extra_args) + else: + denoised_mid = model(x_mid, s * s_in, **extra_args) + + # Full step using midpoint + x = (t_next / t) * x - (-h).expm1() * denoised_mid + + # Add ancestral noise + if sigma_up > 0: + noise = noise_sampler(sigma, sigma_next) * current_s_noise + x = x + noise * sigma_up + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) + + return x + + +def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0): + """Fallback basic DPM++ 2S Ancestral.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) 
+ noise_sampler = default_noise_sampler(x) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta) + + if sigma_down == 0: + d = to_d(x, sigmas[i], denoised) + x = x + d * (sigma_down - sigmas[i]) + else: + t, t_next = sigmas[i], sigma_down + h = t_next - t + s = t + h * 0.5 + x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised + denoised_mid = model(x_mid, s * s_in, **extra_args) + x = (t_next / t) * x - (-h).expm1() * denoised_mid + + if sigma_up > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) + + return x + + +@torch.no_grad() + +def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + """LMS sampler with Adept weight scaling.""" + + # CUSTOM SAMPLER INTEGRATION + if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): + custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') + print(f"🌀 Redirecting to {custom_type}") + + # Apply scheduler to sigmas for custom samplers + scheduler = ADEPT_STATE.get('scheduler', 'Standard') + if scheduler != "Standard": + sigmas = apply_custom_scheduler(sigmas, scheduler) + print(f" 📊 Applied {scheduler} scheduler") + + if custom_type == "Akashic Solver v2": + return sample_akashic_solver( + model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, + tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), + s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), + phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), + smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), + use_detail_enhancement=False, settings={}, 
eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    # Extension disabled: delegate to the captured original, or to the
    # basic fallback when the original was never stored.
    if not ADEPT_STATE.get('enabled', False):
        global ORIGINAL_SAMPLERS
        if 'lms' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order)
        return _basic_lms(model, x, sigmas, extra_args, callback, disable, order)
    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = 
ADEPT_STATE.get('start_pct', 0.0) + end_pct = ADEPT_STATE.get('end_pct', 1.0) + + # Get UNet + try: + unet_model = shared.sd_model.model.diffusion_model + except AttributeError: + unet_model = None + + total_steps = len(sigmas) - 1 + print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}") + + ds = [] + + for i in trange(len(sigmas) - 1, disable=disable, desc="Adept LMS"): + sigma = sigmas[i] + + # Dynamic scale + current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) + + # Evaluate model with weight patching + if should_patch_weights(unet_model, current_scale, shift): + with AdeptWeightPatcher(unet_model, current_scale, shift): + denoised = model(x, sigma * s_in, **extra_args) + else: + denoised = model(x, sigma * s_in, **extra_args) + + d = to_d(x, sigma, denoised) + ds.append(d) + + if len(ds) > order: + ds.pop(0) + + # Linear multistep coefficients + cur_order = min(i + 1, order) + coeffs = [1.0] + + for j in range(1, cur_order): + prod = 1.0 + for k in range(cur_order): + if k != j: + prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k]) + coeffs.append(prod) + + # Apply multistep + d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:]))) + + dt = sigmas[i + 1] - sigma + x = x + d_multistep * dt + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) + + return x + + +def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + """Fallback basic LMS.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + ds = [] + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + + if len(ds) > order: + ds.pop(0) + + cur_order = min(i + 1, order) + coeffs = [1.0] + for j in range(1, cur_order): + prod = 1.0 + for k in range(cur_order): + if k != j: + prod *= (sigmas[i] - 
sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k]) + coeffs.append(prod) + + d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:]))) + dt = sigmas[i + 1] - sigmas[i] + x = x + d_multistep * dt + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) + + return x + + +# ============================================================================ +# MONKEY PATCHING +# ============================================================================ + +def patch_k_diffusion(): + """Apply monkey patches to ALL k-diffusion samplers.""" + global ORIGINAL_SAMPLERS + + samplers_to_patch = { + 'sample_euler': sample_adept_euler, + 'sample_euler_ancestral': sample_adept_euler_ancestral, + 'sample_heun': sample_adept_heun, + 'sample_dpmpp_2m': sample_adept_dpmpp_2m, + 'sample_dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral, + 'sample_lms': sample_adept_lms, + } + + patched_count = 0 + for original_name, adept_func in samplers_to_patch.items(): + if not hasattr(k_diffusion.sampling, original_name): + continue + key = original_name.replace('sample_', '') + current_func = getattr(k_diffusion.sampling, original_name) + # Only save original if we haven't stored it yet (avoid saving already-patched func) + if key not in ORIGINAL_SAMPLERS: + ORIGINAL_SAMPLERS[key] = current_func + # Always (re-)apply our patch if it isn't already there + if current_func is not adept_func: + setattr(k_diffusion.sampling, original_name, adept_func) + patched_count += 1 + + print(f"✅ Adept Sampler v5: Patched {patched_count} samplers") + print(f" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS") + print(f" Schedulers: 18 types available") + + +def unpatch_k_diffusion(): + """ + Restore original k-diffusion samplers. + + Safe-unpatch strategy: before restoring we check whether the live slot + still holds *our* wrapper. If another extension has wrapped us on top + (i.e. 
live_func is not our adept_func but also not the original we
    saved), blindly restoring would silently remove *their* wrapper too.
    In that case we skip the restore for that slot and log a warning so the
    operator knows the coexistence situation.
    """
    global ORIGINAL_SAMPLERS

    # Map ORIGINAL_SAMPLERS key -> k_diffusion.sampling attribute name
    adept_wrappers = {
        'euler': 'sample_euler',
        'euler_ancestral': 'sample_euler_ancestral',
        'heun': 'sample_heun',
        'dpmpp_2m': 'sample_dpmpp_2m',
        'dpmpp_2s_ancestral': 'sample_dpmpp_2s_ancestral',
        'lms': 'sample_lms',
    }
    # Map key -> our wrapper function, to test current slot ownership
    adept_funcs = {
        'euler': sample_adept_euler,
        'euler_ancestral': sample_adept_euler_ancestral,
        'heun': sample_adept_heun,
        'dpmpp_2m': sample_adept_dpmpp_2m,
        'dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral,
        'lms': sample_adept_lms,
    }

    restored_count = 0
    skipped_count = 0
    for key, attr_name in adept_wrappers.items():
        if key not in ORIGINAL_SAMPLERS:
            continue
        live_func = getattr(k_diffusion.sampling, attr_name, None)
        our_func = adept_funcs[key]
        saved_original = ORIGINAL_SAMPLERS[key]

        if live_func is our_func:
            # Normal case: we still own the slot — safe to restore.
            setattr(k_diffusion.sampling, attr_name, saved_original)
            restored_count += 1
        elif live_func is saved_original:
            # Already restored somehow — nothing to do.
            restored_count += 1
        else:
            # Another extension wrapped us. Restoring would silently
            # remove their wrapper; skip and warn instead.
            print(f"⚠️ Adept unpatch: {attr_name} is currently owned by another "
                  f"extension ({live_func!r}). 
Skipping restore to avoid breaking " + f"their wrapper — you may need to reload the UI to fully unload.") + skipped_count += 1 + + ORIGINAL_SAMPLERS.clear() + print(f"🔄 Adept Sampler: Restored {restored_count} samplers" + + (f", skipped {skipped_count} (foreign wrappers)" if skipped_count else "")) + + +# ============================================================================ +# A1111 EXTENSION SCRIPT +# ============================================================================ + +class AdeptSamplerScript(scripts.Script): + def title(self): + return "Adept Sampler v5" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + with gr.Accordion("Adept Sampler v5", open=False): + enabled = gr.Checkbox(label="Enable Adept Sampler", value=False, elem_id="adept_enabled") + + with gr.Row(): + scale = gr.Slider(minimum=0.5, maximum=2.0, step=0.05, value=1.0, label="Weight Scale") + shift = gr.Slider(minimum=-0.5, maximum=0.5, step=0.01, value=0.0, label="Weight Shift") + + with gr.Row(): + start_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Start Percent") + end_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Percent") + + gr.HTML("
" + "⚠️ Weight Scale / Shift / Start–End apply to the 6 patched k-diffusion samplers only. " + "Custom samplers (Akashic, Adept, Mirror) use their own internal parameters.
") + + with gr.Row(): + eta = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Eta (Ancestral samplers)") + s_noise = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="S-Noise") + + adaptive_eta = gr.Checkbox(label="Adaptive Eta (dynamic eta during sampling)", value=False) + + scheduler = gr.Dropdown( + choices=["Standard", "AOS-V", "AOS-Epsilon", "AkashicAOS", "Entropic", "SNR-Optimized", + "Constant-Rate", "Adaptive-Optimized", "Cosine-Annealed", "LogSNR-Uniform", + "Tanh Mid-Boost", "Exponential Tail", "Jittered-Karras", "Stochastic", + "JYS (Dynamic)", "Hybrid JYS-Karras", "AYS-SDXL", + "AkashicAOS Alt", "AkashicEQFlow"], + value="Standard", label="Scheduler Type" + ) + + vae_reflection = gr.Checkbox(label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)", value=False) + + gr.HTML("Enable to use Akashic/Adept/Ancestral samplers instead of k-diffusion
") + + use_custom = gr.Checkbox(label="Use Custom Sampler (overrides k-diffusion)", value=False) + custom_type = gr.Dropdown( + choices=["Akashic Solver v2", "Adept Solver", "Adept Ancestral Solver", "Mirror Correction Euler"], + value="Akashic Solver v2", label="Custom Sampler Type" + ) + + with gr.Accordion("⚙️ Akashic Solver Settings", open=False): + tau = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Tau (0=ODE, 1=SDE)") + phase_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Phase Strength") + smea = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="SMEA (high-res coherency)") + ndb = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="NDB (detail boost)") + eqvae = gr.Dropdown(choices=["Off", "Balanced"], value="Off", label="EQ-VAE Mode") + + with gr.Accordion("⚙️ Adept Solver Settings", open=False): + solver_order = gr.Slider(minimum=1, maximum=3, step=1, value=2, label="Order (1-3)") + use_corrector = gr.Checkbox(value=True, label="Use Corrector") + + with gr.Accordion("⚙️ Ancestral Solver Settings", open=False): + phase_noise = gr.Checkbox(value=False, label="Phase-Aware Noise") + enhanced_deriv = gr.Checkbox(value=False, label="Enhanced Derivative") + + with gr.Accordion("⚙️ Mirror Correction Euler Settings", open=False): + gr.HTML("Active only when Custom Sampler = Mirror Correction Euler
") + mirror_correction_phase = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.5, + label="Correction Phase (fraction of steps with 3-call Heun correction)" + ) + mirror_smooth_phase = gr.Checkbox( + value=False, + label="Smooth Phase (log-sigma blend instead of binary cutoff)" + ) + + gr.HTML("" + "Combat CFG Drift works in stock A1111 via official callback. " + "Spectral Modulation & Phase-Aware CFG use a native sampler hook on " + "Forge/reForge-like backends, or a CFGDenoiser monkey-patch on stock A1111 " + "(near-parity; active mode logged to console).
") + + with gr.Accordion("⚙️ CFG Enhancement Settings", open=False): + cfg_drift_enabled = gr.Checkbox(value=False, label="Enable Combat CFG Drift") + with gr.Row(): + cfg_drift_method = gr.Dropdown( + choices=["mean", "median"], value="mean", + label="Drift Method" + ) + cfg_drift_intensity = gr.Slider( + minimum=0.0, maximum=1.0, step=0.05, value=0.5, + label="Drift Intensity" + ) + spectral_cfg_enabled = gr.Checkbox( + value=False, label="Enable Spectral Modulation (native hook or A1111 monkey-patch)" + ) + with gr.Row(): + spectral_multiplier = gr.Slider( + minimum=0.0, maximum=2.0, step=0.05, value=1.0, + label="Spectral Multiplier" + ) + spectral_percentile = gr.Slider( + minimum=1.0, maximum=25.0, step=0.5, value=5.0, + label="Spectral Percentile" + ) + phase_cfg_enabled = gr.Checkbox( + value=False, label="Enable Phase-Aware CFG (native hook or A1111 monkey-patch)" + ) + with gr.Row(): + phase_cfg_alpha = gr.Slider( + minimum=1.1, maximum=4.0, step=0.1, value=2.0, + label="Phase CFG Alpha" + ) + phase_cfg_beta = gr.Slider( + minimum=1.1, maximum=4.0, step=0.1, value=2.0, + label="Phase CFG Beta" + ) + + return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection, + use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector, + phase_noise, enhanced_deriv, + mirror_correction_phase, mirror_smooth_phase, + cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity, + spectral_cfg_enabled, spectral_multiplier, spectral_percentile, + phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta] + + def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection, + use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector, + phase_noise, enhanced_deriv, + mirror_correction_phase, mirror_smooth_phase, + cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity, + spectral_cfg_enabled, spectral_multiplier, 
spectral_percentile,
                phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta):
        global ADEPT_STATE

        # Gate all sub-features through the master enabled switch.
        # This prevents CFG hooks, native patches, and VAE Reflection from
        # activating when the extension is globally toggled off.
        ADEPT_STATE.update({
            "enabled": enabled,
            "scale": scale,
            "shift": shift,
            "start_pct": start_pct,
            "end_pct": end_pct,
            "eta": eta,
            "s_noise": s_noise,
            "adaptive_eta": adaptive_eta,
            "scheduler": scheduler,
            "vae_reflection": enabled and vae_reflection,  # gated
            "use_custom_sampler": use_custom,
            "custom_sampler": custom_type,
            "tau": tau,
            "phase_strength": phase_strength,
            "smea_strength": smea,
            "ndb_strength": ndb,
            "eqvae_mode": eqvae,
            "solver_order": int(solver_order),
            "use_corrector": use_corrector,
            "phase_noise": phase_noise,
            "enhanced_derivative": enhanced_deriv,
            # Mirror Correction Euler
            "mirror_correction_phase": mirror_correction_phase,
            "mirror_smooth_phase": mirror_smooth_phase,
            # CFG enhancements — all gated through enabled
            "cfg_drift_enabled": enabled and cfg_drift_enabled,
            "cfg_drift_method": cfg_drift_method,
            "cfg_drift_intensity": cfg_drift_intensity,
            "spectral_cfg_enabled": enabled and spectral_cfg_enabled,
            "spectral_multiplier": spectral_multiplier,
            "spectral_percentile": spectral_percentile,
            "phase_cfg_enabled": enabled and phase_cfg_enabled,
            "phase_cfg_alpha": phase_cfg_alpha,
            "phase_cfg_beta": phase_cfg_beta,
        })

        # Scheduler is now applied inside each patched sampler function,
        # so p.sampler.model_wrap patching is no longer needed here.

        # Always reconfigure CFG runtime — even when disabled — so any previously
        # installed native hook or A1111 callbacks get cleanly removed.
        runtime_mode = configure_cfg_runtime()

        if enabled:
            # Record reproducibility metadata in the generation info text
            info = {
                "Adept Sampler": "v5",
                "Adept Scheduler": scheduler,
                "CFG Runtime": runtime_mode,
            }
            if use_custom:
                info["Adept Custom"] = custom_type
                if custom_type == "Akashic Solver v2":
                    info["Adept Tau"] = tau
                    info["Adept EQ-VAE"] = eqvae
            p.extra_generation_params.update(info)

    def process_batch(self, p, *args, **kwargs):
        """Apply VAE Reflection before batch processing."""
        if ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("vae_reflection", False):
            try:
                vae_model = shared.sd_model.first_stage_model
                patcher = VAEReflectionPatcher(vae_model)
                patcher.__enter__()
                # Stash the live patcher on p so postprocess_batch can undo it
                p.adept_vae_patcher = patcher
            except Exception as e:
                print(f"⚠️ VAE Reflection error: {e}")

    def postprocess_batch(self, p, *args, **kwargs):
        """Restore VAE padding modes after batch processing."""
        if hasattr(p, 'adept_vae_patcher'):
            try:
                p.adept_vae_patcher.__exit__(None, None, None)
                delattr(p, 'adept_vae_patcher')
            except Exception as e:
                print(f"⚠️ VAE Reflection restore error: {e}")
        # Safety net: force-restore even if the patcher context failed
        force_restore_vae_reflection()


# ============================================================================
# INITIALIZATION
# ============================================================================
#
# k-diffusion wrappers are installed via on_before_ui (fires after all
# extensions are imported) rather than at bare module import time.
# This reduces the risk of interacting badly with other extensions that
# also wrap k_diffusion.sampling functions, because our wrappers are put
# on last and therefore sit outermost in the call chain.
# Uninstall happens in on_script_unloaded() as before.

def _adept_deferred_init():
    patch_k_diffusion()

try:
    script_callbacks.on_before_ui(_adept_deferred_init)
except Exception:
    # Fallback: if on_before_ui isn't available (older A1111), patch immediately.
    patch_k_diffusion()

def on_script_unloaded():
    # Best-effort teardown: each restore step is independent, so a failure in
    # one must not prevent the others from running.
    try:
        force_restore_vae_reflection()
    except Exception:
        pass
    try:
        uninstall_a1111_cfg_callbacks()
    except Exception:
        pass
    try:
        uninstall_native_cfg_hook()
    except Exception:
        pass
    try:
        unpatch_cfg_denoiser()
    except Exception:
        pass
    try:
        unpatch_k_diffusion()
    except Exception:
        pass

try:
    script_callbacks.on_script_unloaded(on_script_unloaded)
except AttributeError:
    print("⚠️ Script unload callback not available")

print("🚀 Adept Sampler v5 loaded!")
print(" ✨ 4 Custom Samplers: Akashic v2, Adept Solver, Adept Ancestral, Mirror Correction Euler")
print(" ⚡ 6 k-diffusion Samplers with weight scaling")
print(" 📅 18 Schedulers (including AkashicAOS Alt, AkashicEQFlow)")
print(" 🎨 VAE Reflection")
print(" ✅ A1111 port of ComfyUI-Adept-Sampler")