| from __future__ import annotations |
|
|
| import math |
| import torch |
| import torch.nn.functional as F |
|
|
| import comfy.utils as _utils |
| import comfy.sample as _sample |
| import comfy.samplers as _samplers |
| from comfy.k_diffusion import sampling as _kds |
|
|
| import nodes |
|
|
|
|
def _smoothstep01(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the cubic smoothstep polynomial ``3x^2 - 2x^3`` element-wise.

    The input is not clamped here; callers pass values already limited to
    ``[0, 1]``, for which the result is also in ``[0, 1]``.
    """
    squared = x * x
    return squared * (3.0 - 2.0 * x)
|
|
|
|
def _align_sigma_tails(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Trim two sigma schedules to a common length, keeping the tail of the longer one."""
    if a.shape[0] == b.shape[0]:
        return a, b
    m = min(a.shape[0], b.shape[0])
    return a[-m:], b[-m:]


def _hybrid_tail_weights(n: int, mix: float, device, dtype) -> torch.Tensor:
    """Return ``n`` karras->beta blend weights: 0 at the head, smoothly rising to 1 over the last `mix` fraction."""
    t = torch.linspace(0.0, 1.0, n, device=device, dtype=dtype)
    m = float(max(0.0, min(1.0, mix)))
    # Avoid dividing by zero when the blended fraction is (almost) empty.
    eps = 1e-6 if m < 1e-6 else m
    return _smoothstep01(torch.clamp((t - (1.0 - m)) / eps, 0.0, 1.0))


def _blend_schedules(sig_k: torch.Tensor, sig_b: torch.Tensor, mode: str, mix: float) -> torch.Tensor:
    """Pick the karras or beta schedule, or blend them for any other mode."""
    if mode == "karras":
        return sig_k
    if mode == "beta":
        return sig_b
    w = _hybrid_tail_weights(sig_k.shape[0], mix, sig_k.device, sig_k.dtype)
    return sig_k * (1.0 - w) + sig_b * w


def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
                         mix: float, denoise: float, jitter: float, seed: int,
                         _debug: bool = False, tail_smooth: float = 0.0,
                         auto_hybrid_tail: bool = True,
                         auto_tail_strength: float = 0.35) -> torch.Tensor:
    """Build a 1D sigma schedule from the "karras" and "beta" schedulers.

    The result is sorted high-to-low and its last entry is forced to 0. Its
    length is whatever ``_samplers.calculate_sigmas`` yields (presumably
    ``steps + 1`` -- TODO confirm), minus one entry when `base_sampler` is in
    ``KSampler.DISCARD_PENULTIMATE_SIGMA_SAMPLERS``.

    Args:
        model: object exposing ``get_model_object("model_sampling")`` and
            ``load_device``.
        steps: number of sampling steps; must be >= 1.
        base_sampler: sampler name, used only for the discard-penultimate rule.
        mode: ``'karras'`` | ``'beta'`` | ``'hybrid'`` (case-insensitive). Any
            other value is blended like ``'hybrid'`` in the base schedule but
            gets no manual beta weight in the auto-tail pass.
        mix: fraction of the tail blended toward beta, clamped to [0, 1].
        denoise: values strictly between 0 and 0.999999 compute a longer
            schedule of ``int(steps / denoise)`` steps and keep only its last
            ``steps + 1`` entries. Any other value (including 0.0 and None)
            leaves the full schedule untouched.
        jitter: clamped to [0, 0.1]; scales seeded Gaussian noise added to the
            sigmas, after which they are re-sorted.
        seed: seed for the jitter generator (XOR-ed with a fixed constant).
        _debug: print a short summary of the final schedule.
        tail_smooth: clamped to [0, 1]; strength of the backward smoothing
            pass that pulls each sigma toward its successor.
        auto_hybrid_tail: additionally blend toward beta on the last 30% of
            the schedule, weighted by the normalized relative step size.
        auto_tail_strength: clamped to [0, 1]; gain of that automatic blend.

    Returns:
        The sigma tensor, moved to ``model.load_device``.

    Raises:
        ValueError: if `steps` is smaller than 1.
    """
    ms = model.get_model_object("model_sampling")
    steps = int(steps)
    if steps < 1:
        # Explicit check instead of `assert`, which disappears under `python -O`.
        raise ValueError(f"steps must be >= 1, got {steps}")
    mode = str(mode).lower()

    # With partial denoise, build a longer schedule and keep only its tail.
    partial = denoise is not None and 0.0 < float(denoise) < 0.999999
    total_steps = max(1, int(steps / max(1e-6, float(denoise)))) if partial else steps

    sig_k_base, sig_b_base = _align_sigma_tails(
        _samplers.calculate_sigmas(ms, "karras", total_steps),
        _samplers.calculate_sigmas(ms, "beta", total_steps),
    )
    sig = _blend_schedules(sig_k_base, sig_b_base, mode, mix)
    if partial:
        keep = min(steps + 1, sig.shape[0])
        sig = sig[-keep:]
        sig_k_base = sig_k_base[-keep:]
        sig_b_base = sig_b_base[-keep:]

    # Auto tail: re-blend the two base schedules with the manual weight w_m
    # combined (probabilistic OR) with an automatic weight that grows where
    # the relative drop between consecutive sigmas is largest.
    if bool(auto_hybrid_tail) and sig.numel() > 2:
        n = sig.shape[0]
        t = torch.linspace(0.0, 1.0, n, device=sig.device, dtype=sig.dtype)
        if mode == "hybrid":
            w_m = _hybrid_tail_weights(n, mix, sig.device, sig.dtype)
        elif mode == "beta":
            w_m = torch.ones_like(t)
        else:
            w_m = torch.zeros_like(t)
        dif = (sig[1:] - sig[:-1]).abs() / sig[:-1].abs().clamp_min(1e-8)
        dif = torch.cat([dif, dif[-1:]], dim=0)  # pad back to length n
        dif = (dif - dif.min()) / (dif.max() - dif.min() + 1e-8)
        ramp = _smoothstep01(torch.clamp((t - 0.7) / 0.3, 0.0, 1.0))
        w_a = dif * ramp
        g = float(max(0.0, min(1.0, auto_tail_strength)))
        u = w_m + g * w_a - w_m * g * w_a
        sig = sig_k_base * (1.0 - u) + sig_b_base * u

    # Seeded jitter: noise is drawn on the CPU, then moved to the sigma
    # device. The re-sort restores the descending order.
    j = float(max(0.0, min(0.1, float(jitter))))
    if j > 0.0 and sig.numel() > 1:
        gen = torch.Generator(device='cpu')
        gen.manual_seed(int(seed) ^ 0x5EEDCAFE)
        noise = torch.randn(sig.shape, generator=gen, device='cpu').to(sig.device, sig.dtype)
        amp = j * float(sig[0].item() - sig[-1].item()) * 1e-3
        sig = sig + noise * amp
        sig, _ = torch.sort(sig, descending=True)

    # The schedule must end at sigma == 0.
    if sig[-1].abs() > 1e-12:
        sig = torch.cat([sig[:-1], sig.new_zeros(1)], dim=0)

    # Tail smoothing: walk backwards so each sigma is pulled toward its
    # already-smoothed successor; the pull grows quadratically toward the end.
    ts = float(max(0.0, min(1.0, tail_smooth)))
    if ts > 0.0 and sig.numel() > 2:
        s = sig.clone()
        n = int(s.shape[0])
        t = torch.linspace(0.0, 1.0, n, device=s.device, dtype=s.dtype)
        # One host transfer instead of one `.item()` call per element.
        pull = (t.pow(2) * ts).clamp(0.0, 1.0).tolist()
        for i in range(n - 2, -1, -1):
            a = float(min(0.5, 0.5 * pull[i]))
            s[i] = (1.0 - a) * s[i] + a * s[i + 1]
        sig = s

    # Drop the penultimate sigma only when a real penultimate entry exists:
    # on a two-entry schedule [sigma, 0] that would leave [0], i.e. no steps.
    if base_sampler in _samplers.KSampler.DISCARD_PENULTIMATE_SIGMA_SAMPLERS and sig.numel() > 2:
        sig = torch.cat([sig[:-2], sig[-1:]], dim=0)

    sig = sig.to(model.load_device)

    if _debug:
        # Best-effort diagnostics: a failure here must never abort sampling,
        # but it is reported instead of being swallowed silently.
        try:
            desc_ok = bool((sig[:-1] > sig[1:]).all().item()) if sig.numel() > 1 else True
            head = ", ".join(f"{float(v):.4g}" for v in sig[:3].tolist())
            tail_txt = ", ".join(f"{float(v):.4g}" for v in sig[-3:].tolist())
            print(f"[ZeSmart][dbg] sigmas len={sig.numel()} desc={desc_ok} first={float(sig[0]):.6g} last={float(sig[-1]):.6g}")
            print(f"[ZeSmart][dbg] head: [{head}] tail: [{tail_txt}]")
        except Exception as exc:
            print(f"[ZeSmart][dbg] sigma probe failed: {exc}")

    return sig
|
|
|
|
class MG_ZeSmartSampler:
    """ComfyUI node that samples a latent with a custom karras/beta sigma schedule.

    The schedule is produced by ``_build_hybrid_sigmas`` and handed to
    ``comfy.sample.sample_custom`` together with the chosen base sampler.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs."""
        required_inputs = {
            "model": ("MODEL", {}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 2**63 - 1, "control_after_generate": True}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 4096}),
            "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 50.0, "step": 0.1}),
            "base_sampler": (_samplers.KSampler.SAMPLERS, {"default": "dpmpp_2m"}),
            "schedule": (
                ["karras", "beta", "hybrid"],
                {"default": "hybrid", "tooltip": "Sigma curve: karras — soft start; beta — stable tail; hybrid — their mix."},
            ),
            "positive": ("CONDITIONING", {}),
            "negative": ("CONDITIONING", {}),
            "latent": ("LATENT", {}),
        }
        optional_inputs = {
            "denoise": (
                "FLOAT",
                {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Path shortening: 1.0 = full; <1.0 = take the last steps only. Useful for inpaint/mixing."},
            ),
            "hybrid_mix": (
                "FLOAT",
                {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For schedule=hybrid: tail fraction blended toward beta (0=karras, 1=beta)."},
            ),
            "jitter_sigma": (
                "FLOAT",
                {"default": 0.01, "min": 0.0, "max": 0.1, "step": 0.001, "tooltip": "Tiny sigma jitter to kill moiré/banding on backgrounds. 0–0.02 is usually enough."},
            ),
            "tail_smooth": (
                "FLOAT",
                {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Smooth the sigma tail — reduces wobble/banding. Too high may soften details."},
            ),
            "auto_hybrid_tail": (
                "BOOLEAN",
                {"default": True, "tooltip": "Auto‑blend beta on the tail when steps become brittle."},
            ),
            "auto_tail_strength": (
                "FLOAT",
                {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Strength of auto beta‑mix on the tail (0=off, 1=max)."},
            ),
            "debug_probe": (
                "BOOLEAN",
                {"default": False, "tooltip": "Print sigma summary (length, first/last, head/tail)."},
            ),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "apply"
    CATEGORY = "MagicNodes/Experimental"

    # NOTE(review): the keyword defaults below (jitter_sigma, tail_smooth,
    # auto_tail_strength) differ from the widget defaults declared in
    # INPUT_TYPES; they only take effect when an optional input is omitted.
    def apply(self, model, seed, steps, cfg, base_sampler, schedule,
              positive, negative, latent, denoise=1.0, hybrid_mix=0.5,
              jitter_sigma=0.02, tail_smooth=0.07,
              auto_hybrid_tail=True, auto_tail_strength=0.3,
              debug_probe=False):
        """Sample `latent` with the custom schedule and return the updated latent.

        Returns a one-element tuple holding a shallow copy of `latent` whose
        ``"samples"`` entry is replaced by the sampler output.
        """
        # Prepare the latent tensor, its noise and the optional mask.
        latent_samples = _sample.fix_empty_latent_channels(model, latent["samples"])
        noise = _sample.prepare_noise(latent_samples, seed, latent.get("batch_index", None))
        mask = latent.get("noise_mask", None)

        # Build the sigma schedule from the node settings.
        sigmas = _build_hybrid_sigmas(
            model,
            steps=int(steps),
            base_sampler=str(base_sampler),
            mode=str(schedule),
            mix=float(hybrid_mix),
            denoise=float(denoise),
            jitter=float(jitter_sigma),
            seed=int(seed),
            _debug=bool(debug_probe),
            tail_smooth=float(tail_smooth),
            auto_hybrid_tail=bool(auto_hybrid_tail),
            auto_tail_strength=float(auto_tail_strength),
        )

        # Run the base sampler over the custom schedule.
        sampler = _samplers.sampler_object(str(base_sampler))
        preview_cb = nodes.latent_preview.prepare_callback(model, int(steps))
        hide_progress = not _utils.PROGRESS_BAR_ENABLED
        sampled = _sample.sample_custom(
            model, noise, float(cfg), sampler, sigmas,
            positive, negative, latent_samples,
            noise_mask=mask, callback=preview_cb,
            disable_pbar=hide_progress, seed=seed,
        )
        return ({**latent, "samples": sampled},)
|
|