diff --git "a/asds/scripts/asymmetric_tiling_UNIFIED (69).py" "b/asds/scripts/asymmetric_tiling_UNIFIED (69).py" new file mode 100644--- /dev/null +++ "b/asds/scripts/asymmetric_tiling_UNIFIED (69).py" @@ -0,0 +1,6828 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import gradio as gr +from modules import scripts, shared, sd_samplers, sd_samplers_common, sd_samplers_kdiffusion +from modules.script_callbacks import on_cfg_denoiser +import k_diffusion.sampling +from k_diffusion.sampling import to_d, default_noise_sampler, get_ancestral_step +from tqdm.auto import trange +import math +import numpy as np +from collections import OrderedDict +import os +import sys + +from PIL import Image + +# === ИМПОРТ HELPER-МОДУЛЕЙ: fallback на scripts/ и libs/ === +# Поддерживает обе раскладки файлов: +# A) helper'ы лежат рядом со скриптом (.../scripts/) ← текущая раскладка +# B) helper'ы лежат в .../libs/ ← будущая раскладка +current_script_path = os.path.abspath(__file__) +_scripts_dir = os.path.dirname(current_script_path) # .../scripts +_extension_root = os.path.dirname(_scripts_dir) # .../extensions/asd +_libs_dir = os.path.join(_extension_root, "libs") # .../extensions/asd/libs + +for _candidate in (_scripts_dir, _libs_dir): + if os.path.isdir(_candidate) and _candidate not in sys.path: + sys.path.append(_candidate) +# ============================================================ + +# Теперь импорты заработают корректно из папки libs +from improved_tiling_functions import ( + compute_advanced_blend_padding, + compute_blend_fade_to_black, + apply_multires_blend, + validate_blend_params, + validate_multires_params, + BlendStrategy, + MultiResStrategy +) +from advanced_zoom_extension import ( + validate_zoom_params, +) +# ======================================================================== +# КОНСТАНТЫ И КОНФИГУРАЦИЯ +# ======================================================================== +MODE_OFF = "Default (Off)" +MODE_CIRCULAR = "Circular" 
+MODE_MIRROR = "Mirror (Reflect)" +MODE_HEXAGONAL = "Hexagonal (Staggered)" +MODE_PANORAMA = "Panorama 360°" +MODE_CUBEMAP = "Cubemap (3D)" +MODE_BLEND = "Soft Blend Edges" # legacy / unused as a mode value; + # blend logic is driven by use_blend checkbox, not this constant +MODE_ANISOTROPIC = "Anisotropic (Directional)" +MODE_POLAR = "Polar (Sphere Correct)" + +# NEW MODES - Advanced Features from ENHANCED +MODE_VORONOI = "Voronoi (Organic)" +MODE_PERLIN = "Perlin Noise Distortion" +MODE_FRACTAL = "Fractal Recursive" +MODE_ADAPTIVE = "Adaptive Smart" + +# Глобальное хранилище +_ORIGINAL_METHODS_CACHE = {} +_MASK_CACHE = {} +_CUBEMAP_GRID_CACHE = {} +_SAMPLER_REGISTERED = False + + +# Panorama Live caches +_PANO_GRID_CACHE = {} +_BLUR_KERNEL_CACHE = {} +# ======================================================================== +# KOHAKU LONYU YOG SAMPLER IMPLEMENTATION +# ======================================================================== +# ИСПРАВЛЕННАЯ ВЕРСИЯ — вставить вместо обоих дублей +def get_safe_epsilon(tensor_or_dtype): + """Float16-safe epsilon""" + if isinstance(tensor_or_dtype, torch.Tensor): + dtype = tensor_or_dtype.dtype + else: + dtype = tensor_or_dtype + + if dtype in (torch.float16, torch.bfloat16): + return 1e-3 # безопасно для half precision + elif dtype == torch.float32: + return 1e-6 # ← ИСПРАВЛЕНО: прямое значение, не рекурсивный вызов + else: + return 1e-12 # float64 + + +# ======================================================================== +# LATENT NOISE INIT — Calibrated Gaussian Noise (F13/F14) +# Источник: SD-Advanced-Noise-main/latent_noise_generator.py +# ======================================================================== + +LATENT_RANGES = { + "v1": { # SD 1.5 — диапазоны откалиброваны на реальных VAE экспериментах + "min": [-5.5618, -17.1368, -10.3445, -8.6218], + "max": [13.5369, 11.1997, 16.3043, 10.6343], + "null": [-5.3870, -14.2931, 6.2738, 7.1220], # закодированное чёрное + }, + "xl": { # SDXL + "min": 
[-22.2127, -20.0131, -17.7673, -14.9434], + "max": [ 17.9334, 26.3043, 33.1648, 8.9380], + "null": [-21.9287, 3.8783, 2.5879, 2.5435], # закодированное чёрное + }, +} + + +@torch.no_grad() +def gaussian_latent_noise(h, w, ver="v1", fac=0.6, nul=0.0, + srnd=True, seed=-1, device="cpu", dtype=None): + """ + Генерирует шум в реальных диапазонах латентного пространства. + + srnd=True — shared random: все 4 канала из одного rand тензора → coherent + nul=0.25 — 25% смешивания с null-латентом → нейтральный серый ст��рт. + Рекомендуется для panorama init: предотвращает цветовые пятна. + + Источник: SD-Advanced-Noise-main/latent_noise_generator.py + """ + # BUG FIX 9: use a local Generator so we never mutate the global RNG + # state; other parts of the pipeline (samplers, etc.) keep their own + # reproducible sequences regardless of whether this feature is on. + gen = None + if seed >= 0: + gen = torch.Generator(device=device) + gen.manual_seed(int(seed)) + ver = ver if ver in LATENT_RANGES else "v1" + lims = LATENT_RANGES[ver] + mn, mx, nl = lims["min"], lims["max"], lims["null"] + + if srnd: + rand = torch.rand([h, w], device=device, generator=gen) + lat = torch.stack([ + rand.clone() * (mx[i] - mn[i]) + mn[i] + for i in range(4) + ]) + else: + lat = torch.stack([ + torch.rand([h, w], device=device, generator=gen) * (mx[i] - mn[i]) + mn[i] + for i in range(4) + ]) + + null_lat = torch.stack([ + torch.ones([h, w], device=device) * nl[i] + for i in range(4) + ]) + + result = ((lat * fac) * (1.0 - nul) + null_lat * nul) / 2.0 + if dtype is not None: + result = result.to(dtype=dtype) + return result + + +# ======================================================================== +# DIFFUSION CG — Color Grading / Recenter / Normalize +# Источник: sdsw/diffusion_cg.py + sdsw/sdxl_latent_tweak.py +# ======================================================================== + +@torch.no_grad() +def center_tensor(x, per_channel_shift=1.0, full_tensor_shift=1.0, channels=None): + """ 
+ Вычитает среднее по цветовым каналам и всему тензору. + Убирает цветовой дрейф (bias к красному/зелёному и т.д.) + Источник: sdxl_latent_tweak.py / Timothy Alexis Vass + """ + if channels is None: + channels = [1, 2] # ch1=G, ch2=R для SD1.5 (цветовые каналы) + for c in channels: + x[:, c] -= x[:, c].mean() * per_channel_shift + return x - x.mean() * full_tensor_shift + + +@torch.no_grad() +def maximize_tensor_v2(x, boundary=4.0, channels=None): + """ + Нормализует тензор к диапазону [-boundary, +boundary] per-batch. + Аналог Levels в Photoshop: сохраняет среднее, растягивает min/max. + Источник: sdxl_latent_tweak.py / SLAPaper + """ + if channels is None: + channels = [0, 1, 2] + for i in range(x.size(0)): + ch = x[i, channels, :, :] + if torch.any(ch < 0) and torch.any(ch > 0): + mean = ch.mean() + neg_min = ch[ch < mean].min() if (ch < mean).any() else ch.new_tensor(-1.0) + pos_max = ch[ch > mean].max() if (ch > mean).any() else ch.new_tensor(1.0) + ch = torch.where(ch < mean, -boundary * (ch / neg_min), ch) + ch = torch.where(ch > mean, boundary * (ch / pos_max), ch) + else: + max_abs = ch.abs().max() + if max_abs > 0: + ch = boundary * ch / max_abs + x[i, channels, :, :] = ch + return x + + +@torch.no_grad() +def apply_diffusion_cg(latent, cur_step, total_steps, + recenter=0.0, normalization=0.0): + """ + Применяет Color Grading к латенту с убывающим sin-расписанием. 
+ + Расписание: strength = 1 - sin(ratio * pi/2) + При ratio=0 (начало): strength=1.0 (полный эффект) + При ratio=1 (конец): strength=0.0 (нет эффекта) + + recenter: 0.0–1.0 → балансировка цвета (убирает цветовой дрейф) + normalization: 0.0–1.0 → нормализация контраста (растяжка до ±boundary) + + Источник: sdsw/diffusion_cg.py / DiffusionCG v2.0.0 + """ + if recenter <= 0.0 and normalization <= 0.0: + return latent + if total_steps <= 1: + return latent + try: + import math + ratio = cur_step / total_steps + strength = 1.0 - math.sin(ratio * math.pi / 2.0) + + if recenter > 0.0 and strength > 0.0: + lat = latent.clone() + B, C = lat.shape[0], lat.shape[1] + for b in range(B): + _std = float(lat[b].std()) or 1.0 + for c in range(C): + bias = float(lat[b][c].std()) / _std * strength + lat[b][c] += (0.0 - lat[b][c].mean()) * bias * recenter + latent = lat + + if normalization > 0.0 and strength > 0.0: + lat = latent.clone() + B, C = lat.shape[0], lat.shape[1] + for b in range(B): + for c in range(C): + magnitude = float(lat[b][c].max() - lat[b][c].min()) + if magnitude > 0: + factor = 1.0 / 0.13025 + scale = max(factor / magnitude - 1.0, 0.0) + lat[b][c] *= scale * normalization * strength + 1.0 + latent = lat + + return latent + except Exception as e: + print(f"[DiffusionCG] ошибка: {e}") + return latent + + +# ======================================================================== +# BAND-PASS GRAIN SYSTEM (F5 / F11 / F12) +# Источник: ComfyUI-LatentDetailer-main / LatentDetailer.py +# ======================================================================== + +@torch.no_grad() +def _lowpass_avgpool(x: torch.Tensor, radius: int) -> torch.Tensor: + """Lowpass blur via avg_pool2d (zero padding — intentionally preserved).""" + r = int(max(0, radius)) + if r <= 0: + return x + k = 2 * r + 1 + return F.avg_pool2d(x, kernel_size=k, stride=1, padding=r) + + +@torch.no_grad() +def _rms_norm_(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """In-place RMS 
normalisation so noise_scale stays comparable.""" + var = x.pow(2).mean(dim=(2, 3), keepdim=True) + return x.mul_(torch.rsqrt(var + eps)) + + +@torch.no_grad() +def _randn_like_grain(x: torch.Tensor, seed: int) -> torch.Tensor: + """Seeded or fully-random noise, same shape/device/dtype as x.""" + if seed < 0: + return torch.randn_like(x) + # BUG FIX 9 (continued): always use a local Generator; avoid the global + # torch.manual_seed fallback that was here before. + try: + g = torch.Generator(device=x.device) + g.manual_seed(int(seed)) + return torch.randn_like(x, generator=g) + except Exception: + # CPU fallback: still use a local generator, not the global one + g_cpu = torch.Generator() + g_cpu.manual_seed(int(seed)) + return torch.randn(x.shape, dtype=x.dtype, generator=g_cpu).to(x.device) + + +@torch.no_grad() +def _shape_noise_tail(n: torch.Tensor, noise_tail: float) -> torch.Tensor: + """ + Reshape noise tails while keeping RMS comparable. + noise_tail 0 = Gaussian, +1 = heavier tails, -1 = lighter tails. + """ + t = float(max(-1.0, min(1.0, noise_tail))) + if abs(t) < 1e-6: + return n + p = 2.0 ** t + x = torch.sign(n) * torch.pow(torch.abs(n) + 1e-12, p) + return _rms_norm_(x) + + +@torch.no_grad() +def _bandpass_grain(noise: torch.Tensor, r: int) -> torch.Tensor: + """ + Band-pass filter on noise via double lowpass difference. + lp1 - lp2 = band centred around radius r. + r=0 → broadband (no filter), r>0 → textured mid-frequency grain. + """ + r = int(max(0, r)) + if r == 0: + return _rms_norm_(noise) + lp1 = _lowpass_avgpool(noise, r) + lp2 = _lowpass_avgpool(noise, r * 2) + band = lp1 - lp2 + return _rms_norm_(band) + + +@torch.no_grad() +def _grain_luma_weight(x: torch.Tensor, grain_luma: float, + noise_radius: int) -> torch.Tensor: + """ + Luma-dependent grain weight (B,1,H,W). + Darker regions get more grain, highlights get less. + grain_luma=0 → uniform; =1 → maximum shadow/highlight split. 
+ """ + gl = float(max(0.0, min(1.0, grain_luma))) + if gl <= 0.0 or x.ndim != 4: + return x.new_ones((x.shape[0], 1, x.shape[-2], x.shape[-1])) + c = int(min(4, x.shape[1])) + lum = x[:, :c].mean(dim=1, keepdim=True) + blur_r = int(min(24, max(2, noise_radius * 4))) + lum = _lowpass_avgpool(lum, blur_r) + mean = lum.mean(dim=(2, 3), keepdim=True) + std = lum.std(dim=(2, 3), keepdim=True) + 1e-6 + z = (lum - mean) / std + w = torch.clamp(torch.exp(-1.75 * z), 0.25, 4.0) + return (1.0 - gl) + gl * w + + +@torch.no_grad() +def apply_latent_grain(x: torch.Tensor, + noise_scale: float = 0.02, + noise_radius: int = 3, + noise_tail: float = 0.0, + grain_luma: float = 0.0, + seed: int = -1) -> torch.Tensor: + """ + Adds band-pass, RMS-normalised grain to latent tensor. + + noise_scale: amplitude (0.02–0.08 рекомендуется для тайлинга) + noise_radius: размер зерна в латентных пикселях (1=мелкое, 8=крупное) + noise_tail: форма распределения (0=Гаусс, +1=тяжёлые хвосты) + grain_luma: bias в тёмные области (0=равномерно, 1=максимум) + seed: -1=случайный каждый шаг, ≥0=воспроизводимый + """ + ns = float(max(0.0, noise_scale)) + if ns <= 0.0: + return x + try: + n = _randn_like_grain(x, seed) + g = _bandpass_grain(n, int(noise_radius)) + if abs(float(noise_tail)) > 1e-6: + g = _shape_noise_tail(g, float(noise_tail)) + if float(grain_luma) > 0.0: + w = _grain_luma_weight(x, float(grain_luma), int(noise_radius)) + g = g * w + return x + g * ns + except Exception as e: + print(f"[Grain] ошибка: {e}") + return x + + +# ======================================================================== +# FOURIER FILTER — GAUSSIAN FFT +# Источник: ComfyUI_FreeU_V2_advanced-main/utils.py +# ======================================================================== + +@torch.no_grad() +def fourier_filter_gauss(x, radius_ratio=0.08, scale=1.0, hf_boost=1.0): + """ + Gaussian FFT без артефактов Гиббса. + radius_ratio: доля от min(H,W), адаптируется к размеру тайла. 
+ scale<1.0 ослабляет НЧ, hf_boost>1.0 усиливает ВЧ. + """ + orig_dtype = x.dtype + x_f = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_f = torch.fft.fftshift(x_f, dim=(-2, -1)) + _B, _C, H, W = x_f.shape + R = max(1, int(min(H, W) * float(radius_ratio))) + yy = torch.arange(H, device=x.device, dtype=torch.float32) - H // 2 + xx = torch.arange(W, device=x.device, dtype=torch.float32) - W // 2 + yy, xx = torch.meshgrid(yy, xx, indexing='ij') + sigma_f = max(1e-6, float(R ** 2) / 2.0) + center = torch.exp(-(yy ** 2 + xx ** 2) / sigma_f).view(1, 1, H, W) + mask = float(scale) * center + float(hf_boost) * (1.0 - center) + result = torch.fft.ifftn( + torch.fft.ifftshift(x_f * mask, dim=(-2, -1)), dim=(-2, -1) + ).real + return result.to(dtype=orig_dtype) + + +# ======================================================================== +# ADVANCED ZOOM — двухбэкендная система +# geometry: grid_sample warp без noise (grid_warp / spiral_zoom / convergence_shift) +# canvas: outpaint-style latent fill с coherent noise (outpaint_zoom / blend_transition / hybrid) +# ======================================================================== + +ZOOM_BACKEND_GEOMETRY = "geometry" +ZOOM_BACKEND_CANVAS = "canvas" +GEOMETRY_ZOOM_MODES = {"grid_warp", "spiral_zoom", "convergence_shift"} +CANVAS_ZOOM_MODES = {"outpaint_zoom", "blend_transition", "hybrid"} + + +def resolve_zoom_backend(zoom_mode: str, zoom_engine: str = "auto") -> str: + """Определяет backend по mode и явному выбору engine. + + При явном задании engine проверяет совместимость с mode: + - spiral_zoom в canvas-engine: нет полной реализации → принудительно geometry + - outpaint/blend/hybrid в geometry-engine: не будет outpaint-fill → warn + canvas + Логика: явный выбор пользователя уважается где это безопасно; несовместимая + комбинация логируется и нормализуется к рабочему backend-у. 
+ """ + zm = str(zoom_mode or "grid_warp").lower() + ze = str(zoom_engine or "auto").lower() + + if ze == ZOOM_BACKEND_GEOMETRY: + if zm in CANVAS_ZOOM_MODES: + # outpaint/blend/hybrid требуют canvas fill; в geometry они деградируют + # до обычного warp без fill — предупреждаем и переключаемся + print(f"[Zoom] ⚠ zoom_mode='{zm}' несовместим с engine='geometry' " + f"(нет canvas fill). Принудительно переключаю на engine='canvas'.") + return ZOOM_BACKEND_CANVAS + return ZOOM_BACKEND_GEOMETRY + + if ze == ZOOM_BACKEND_CANVAS: + if zm in ("spiral_zoom", "convergence_shift"): + # spiral_zoom реализован только в geometry; + # convergence_shift требует depth-curve warp, которого нет в canvas + print(f"[Zoom] ⚠ zoom_mode='{zm}' не реализован для engine='canvas'. " + f"Принудительно переключаю на engine='geometry'.") + return ZOOM_BACKEND_GEOMETRY + return ZOOM_BACKEND_CANVAS + + # auto: выбираем по mode без предупреждений + return ZOOM_BACKEND_GEOMETRY if zm in GEOMETRY_ZOOM_MODES else ZOOM_BACKEND_CANVAS + + +@torch.no_grad() +def apply_geometry_zoom_latent( + x: torch.Tensor, + zoom_factor: float = 0.0, + convergence_point: float = 0.5, + convergence_y: float = 0.5, + pan_x: float = 0.0, + pan_y: float = 0.0, + interp_mode: str = "bilinear", + zoom_mode: str = "grid_warp", + spiral_rotation: float = 0.5, + spiral_direction: float = 1.0, + depth_power: float = 1.0, + auto_clamp_pan: bool = True, +) -> torch.Tensor: + """ + Чистый latent-space geometric warp через F.grid_sample. + Сохраняет форму [B,C,H,W], не добавляет шум. + Поддерживает: grid_warp, spiral_zoom, convergence_shift. 
    """
    if x is None or x.ndim != 4:
        return x
    zf = float(zoom_factor)
    pax = float(pan_x)
    pay = float(pan_y)
    zm = str(zoom_mode).lower()
    # No-op fast path: nothing to warp.
    if abs(zf) < 1e-6 and abs(pax) < 1e-6 and abs(pay) < 1e-6:
        return x

    orig_dtype = x.dtype
    xf = x.float()  # grid_sample math in float32; cast back at the end
    b, c, h, w = xf.shape
    device = xf.device

    # zoom_factor is mapped to a sampling scale around 1.0 (±10% per unit),
    # clamped to a sane range.
    scale = max(0.05, min(10.0, 1.0 + zf * 0.1))

    if auto_clamp_pan and abs(zf) > 1e-4:
        # Only clamp pan when zoom is active — at scale≈1.0 the formula gives
        # max_pan≈0 which would zero out pan even though it's perfectly valid.
        max_pan = max(0.0, 1.0 - 1.0 / scale) * 0.95
        pax = max(-max_pan, min(max_pan, pax))
        pay = max(-max_pan, min(max_pan, pay))

    # NDC sampling grid in [-1, 1] for F.grid_sample.
    yy = torch.linspace(-1.0, 1.0, h, device=device, dtype=torch.float32)
    xx = torch.linspace(-1.0, 1.0, w, device=device, dtype=torch.float32)
    y_grid, x_grid = torch.meshgrid(yy, xx, indexing="ij")

    # Convergence point and pan offsets, [0..1] → NDC [-1..1].
    cx = (float(convergence_point) - 0.5) * 2.0
    cy = (float(convergence_y) - 0.5) * 2.0
    ox = pax * 2.0
    oy = pay * 2.0

    if zm == "spiral_zoom":
        # Polar warp around the convergence point; angle grows with radius.
        # NOTE(review): coerce_spiral_direction is not defined in this part
        # of the file — presumably declared later; verify it exists.
        dx = x_grid - cx
        dy = y_grid - cy
        r = torch.sqrt(dx * dx + dy * dy + 1e-6)
        a = torch.atan2(dy, dx) + coerce_spiral_direction(spiral_direction) * float(spiral_rotation) * r
        x_new = (r * torch.cos(a)) / scale + cx - ox
        y_new = (r * torch.sin(a)) / scale + cy - oy
    elif zm == "convergence_shift":
        # Depth-weighted blend between zoomed and unzoomed coordinates:
        # w_mask→0 near the convergence point (full zoom), →1 far away (no zoom).
        dp = max(0.1, float(depth_power))
        dx = x_grid - cx
        dy = y_grid - cy
        r = torch.sqrt(dx * dx + dy * dy + 1e-6).clamp(0.0, 2.0)
        w_mask = torch.pow(r / 2.0, dp)
        x_new = cx + dx / scale * (1.0 - w_mask) + dx * w_mask - ox
        y_new = cy + dy / scale * (1.0 - w_mask) + dy * w_mask - oy
    else:  # grid_warp (default): uniform scale about the convergence point
        x_new = (x_grid - cx) / scale + cx - ox
        y_new = (y_grid - cy) / scale + cy - oy

    grid = torch.stack((x_new, y_new), dim=-1).unsqueeze(0).expand(b, -1, -1, -1)

    _mode = str(interp_mode).lower()
    if _mode == "bicubic":
        # bicubic grid_sample is not available on every torch build —
        # fall back to bilinear if it raises.
        try:
            out = F.grid_sample(xf, grid, mode="bicubic",
                                padding_mode="zeros", align_corners=True)
        except Exception:
            out = F.grid_sample(xf, grid, mode="bilinear",
                                padding_mode="zeros", align_corners=True)
    elif _mode == "nearest":
        out = F.grid_sample(xf, grid, mode="nearest",
                            padding_mode="zeros", align_corners=True)
    else:
        out = F.grid_sample(xf, grid, mode="bilinear",
                            padding_mode="zeros", align_corners=True)

    return out.to(dtype=orig_dtype)


@torch.no_grad()
def _latent_valid_mask(grid: torch.Tensor) -> torch.Tensor:
    """[B,H,W,2] sampling grid → [B,1,H,W] float mask of in-range pixels."""
    gx = grid[..., 0]
    gy = grid[..., 1]
    return ((gx >= -1.0) & (gx <= 1.0) & (gy >= -1.0) & (gy <= 1.0)).float().unsqueeze(1)


@torch.no_grad()
def _distance_from_valid(valid_mask: torch.Tensor, blur_r: int = 4) -> torch.Tensor:
    """Approximate distance from the nearest valid pixel (0=inside content, 1=far)."""
    inv = 1.0 - valid_mask
    if blur_r > 0:
        k = blur_r * 2 + 1
        inv = F.avg_pool2d(inv, kernel_size=k, stride=1, padding=blur_r)
    return inv.clamp(0.0, 1.0)


@torch.no_grad()
def _make_canvas_fill(
    x: torch.Tensor,
    valid_mask: torch.Tensor,
    noise_strength: float = 1.0,
    adaptive_noise_scale: bool = True,
    seed: int = -1,
) -> torch.Tensor:
    """Coherent latent fill for the empty zones produced by canvas zoom."""
    b, c, h, w = x.shape
    fills = []
    for bi in range(b):
        # Heuristic model pick: h >= 128 latent rows is treated as SDXL —
        # NOTE(review): confirm this threshold matches the intended models.
        fill = gaussian_latent_noise(
            h, w,
            ver="xl" if (c == 4 and h >= 128) else "v1",
            fac=max(0.05, min(2.0, float(noise_strength))),
            nul=0.15, srnd=True,
            seed=(seed + bi) if seed >= 0 else -1,  # per-batch-item seed offset
            device=x.device, dtype=x.dtype,
        ).unsqueeze(0)
        fills.append(fill)
    fill = torch.cat(fills, dim=0)
    if adaptive_noise_scale:
        # Scale the fill up with distance from valid content (0.35 → 1.0).
        dist = _distance_from_valid(valid_mask, blur_r=6)
        fill = fill * (0.35 + 0.65 * dist)
    return fill


@torch.no_grad()
def _apply_variance_correction(
    out: torch.Tensor,
    ref: torch.Tensor,
    valid_mask: torch.Tensor,
    strength: float = 1.0,
) -> torch.Tensor:
    """Pull the output tensor's std toward the reference std — removes grey seams."""
    try:
        # Clamp strength to [0,1]; below ~1e-4 the correction is a no-op.
        s = float(max(0.0, min(1.0, strength)))
        if s < 1e-4:
            return out
        vm = valid_mask.expand_as(ref)
        n = vm.sum().clamp(min=1.0)  # avoid div-by-zero on an empty mask
        # RMS of the masked (content) region, reference vs output.
        ref_std = ((ref * vm).pow(2).sum() / n).sqrt()
        out_std = ((out * vm).pow(2).sum() / n).sqrt()
        if out_std > 1e-6:
            # Blend toward the reference RMS; gain clamped to a safe 0.5–2.0.
            scale = float((1.0 + s * (ref_std / out_std - 1.0)).clamp(0.5, 2.0))
            return out * scale
    except Exception:
        pass  # best-effort: fall through to the uncorrected output
    return out


@torch.no_grad()
def apply_canvas_zoom_latent(
    x: torch.Tensor,
    zoom_mode: str = "outpaint_zoom",
    zoom_factor: float = 0.0,
    convergence_point: float = 0.5,
    convergence_y: float = 0.5,
    pan_x: float = 0.0,
    pan_y: float = 0.0,
    interp_mode: str = "bilinear",
    depth_power: float = 1.0,
    blend_mode: str = "circular_reflect",
    noise_strength: float = 1.0,
    edge_fade: bool = True,  # formerly zoom_in_fade; fades the fill-zone edges
    variance_correction: bool = True,
    auto_clamp_pan: bool = True,
    adaptive_noise_scale: bool = True,
    fade_to_black: bool = False,
    fade_strength: float = 0.15,
    # gradient_radial params — used when blend_mode="gradient_radial"
    gradient_center_x: float = 0.5,  # radial gradient centre [0..1]
    gradient_center_y: float = 0.5,
    gradient_radius: float = 1.0,    # normalised radius (1.0 = to the frame corner)
    # noise_blend params — used when blend_mode="noise_blend"
    noise_scale: float = 5.0,        # frequency of the procedural noise mask (1..20)
    noise_octaves: int = 2,          # number of noise-mask octaves
    seed: int = -1,
) -> torch.Tensor:
    """
    Outpaint-style latent zoom: warped content + coherent fill + parametric blend.
    Supports: outpaint_zoom, blend_transition, hybrid.

    blend_mode="gradient_radial": radial gradient centred at
    gradient_center_x/y with radius gradient_radius — full content at the
    centre, fill toward the edges.
    blend_mode="noise_blend": procedural noise mask driven by noise_scale /
    noise_octaves — grain-style mixing of content and fill.

    edge_fade: fades the fill zone near the frame borders (reduces
    zoom-out artefacts where the fill reaches the edges; independent of the
    zoom_factor sign).
    """
    if x is None or x.ndim != 4:
        return x

    zf = float(zoom_factor)
    pax = float(pan_x)
    pay = float(pan_y)
    zm = str(zoom_mode).lower()
    orig_dtype = x.dtype
    xf = x.float()  # all blending math in float32; cast back at the end
    b, c, h, w = xf.shape
    device = xf.device

    # Same zoom_factor → scale mapping as the geometry backend.
    scale = max(0.05, min(10.0, 1.0 + zf * 0.1))

    if auto_clamp_pan and abs(zf) > 1e-4:
        # Only clamp when zoom is active — at scale=1.0 max_pan=0 kills all pan.
        max_pan = max(0.0, 1.0 - 1.0 / scale) * 0.95
        pax = max(-max_pan, min(max_pan, pax))
        pay = max(-max_pan, min(max_pan, pay))

    cx = (float(convergence_point) - 0.5) * 2.0
    cy = (float(convergence_y) - 0.5) * 2.0
    ox = pax * 2.0
    oy = pay * 2.0

    # NDC sampling grid for F.grid_sample.
    yy = torch.linspace(-1.0, 1.0, h, device=device, dtype=torch.float32)
    xx = torch.linspace(-1.0, 1.0, w, device=device, dtype=torch.float32)
    y_grid, x_grid = torch.meshgrid(yy, xx, indexing="ij")

    x_new = (x_grid - cx) / scale + cx - ox
    y_new = (y_grid - cy) / scale + cy - oy
    grid = torch.stack((x_new, y_new), dim=-1).unsqueeze(0).expand(b, -1, -1, -1)

    # hybrid: light geometry pre-warp before the canvas fill.
    if zm == "hybrid":
        x_src = apply_geometry_zoom_latent(
            xf, zoom_factor=zf * 0.4,
            convergence_point=convergence_point, convergence_y=convergence_y,
            pan_x=pax * 0.5, pan_y=pay * 0.5,
            interp_mode=interp_mode, zoom_mode="grid_warp",
            depth_power=depth_power, auto_clamp_pan=False,
        )
    else:
        x_src = xf

    _mode = str(interp_mode).lower()
    if _mode not in ("bilinear", "bicubic", "nearest"):
        _mode = "bilinear"
    try:
        content = F.grid_sample(x_src, grid, mode=_mode,
                                padding_mode="zeros", align_corners=True)
    except Exception:
        # e.g. bicubic unavailable on this torch build → bilinear fallback
        content = F.grid_sample(x_src, grid, mode="bilinear",
                                padding_mode="zeros", align_corners=True)

    valid_mask = _latent_valid_mask(grid)
    # blend_transition halves the fill noise strength by design.
    ns = float(noise_strength) * (0.5 if zm == "blend_transition" else 1.0)

    # depth_power: power curve for the fill-strength distribution.
    # dp > 1 → fill dominates only near the far edges (content holds longer)
    # dp < 1 → fill is spread more evenly over the whole empty zone.
    # Applied to the distance map before adaptive_noise_scale and alpha.
    dp = max(0.1, float(depth_power))
    if abs(dp - 1.0) > 0.05:
        # Curve the distance map; this affects adaptive fill strength and
        # the soft-blend alpha in blend_transition.
        dist_raw = _distance_from_valid(valid_mask, blur_r=0).clamp(1e-6, 1.0)
        dist_curved = dist_raw.pow(dp)
        # Recompute valid_mask as 1 - curved distance (harder/softer border).
        _vm_curved = (1.0 - dist_curved).clamp(0.0, 1.0)
    else:
        _vm_curved = valid_mask

    fill = _make_canvas_fill(xf, _vm_curved, noise_strength=ns,
                             adaptive_noise_scale=adaptive_noise_scale, seed=seed)

    # ── Build alpha (content weight) ────────────────────────────────────
    # Every branch uses _vm_curved so depth_power affects the whole canvas
    # pipeline: gradient_radial, noise_blend, blend_transition,
    # outpaint/hybrid, edge_fade and variance correction.
    bm = str(blend_mode).lower()

    if bm == "gradient_radial":
        # Radial gradient with configurable centre and radius,
        # gradient_center_x/y in [0..1] → NDC [-1..1].
        gcx = (float(gradient_center_x) - 0.5) * 2.0
        gcy = (float(gradient_center_y) - 0.5) * 2.0
        gr_ = max(0.05, float(gradient_radius))
        dx_g = x_grid - gcx
        dy_g = y_grid - gcy
        r_g = (dx_g.pow(2) + dy_g.pow(2)).sqrt()
        radial = (1.0 - (r_g / gr_).clamp(0.0, 1.0))  # 1=centre, 0=edge
        alpha = (_vm_curved * radial.unsqueeze(0).unsqueeze(0)).clamp(0.0, 1.0)

    elif bm == "noise_blend":
        # Procedural noise mask with noise_scale and noise_octaves; several
        # octaves are averaged → grain-style transition.
        ns_k = max(1.0, min(50.0, float(noise_scale)))
        n_oct = max(1, min(8, int(noise_octaves)))
        g_gen = torch.Generator(device=device)
        g_gen.manual_seed(seed if seed >= 0 else 42)  # fixed default seed
        na_acc = torch.zeros(b, 1, h, w, device=device, dtype=torch.float32)
        amplitude = 1.0
        freq = 1.0
        amp_sum = 0.0
        for _oct in range(n_oct):
            raw = torch.rand(b, 1, h, w, device=device, dtype=torch.float32,
                             generator=g_gen)
            # Smooth at the current frequency (odd kernel, capped by H/W).
            k_size = max(3, min(h, w, int(ns_k / freq)))
            k_size = k_size if k_size % 2 == 1 else k_size + 1
            raw_sm = F.avg_pool2d(raw, k_size, stride=1, padding=k_size // 2)
            na_acc += raw_sm * amplitude
            amp_sum += amplitude
            amplitude *= 0.5
            freq *= 2.0
        na_norm = (na_acc / amp_sum).clamp(0.0, 1.0)
        alpha = (_vm_curved * na_norm).clamp(0.0, 1.0)

    elif zm == "blend_transition" or bm not in (
        "circular_reflect", "circular_constant", "reflect_constant",
        "polar_circular", "mirror_circular", "aniso_circular", "custom",
    ):
        # blend_transition + any remaining blend modes: soft edge via distance.
        # _vm_curved already carries depth_power, so:
        #   dp > 1 → content "holds" longer, the soft zone narrows;
        #   dp < 1 → softer transition, fill reaches further in.
        dist = _distance_from_valid(_vm_curved, blur_r=max(2, min(h, w) // 16))
        alpha = (_vm_curved * (1.0 - dist)).clamp(0.0, 1.0)

    else:
        # outpaint_zoom + hybrid + standard modes.
        # _vm_curved: with dp != 1 the content border becomes softer/harder.
        alpha = _vm_curved

    # ── Mix content + fill ─────────────────────────────────────────────
    alpha_c = alpha.expand(b, c, h, w)
    out = content * alpha_c + fill * (1.0 - alpha_c)

    # ── Edge fade: attenuate the fill zone near the frame borders ──────
    # Uses _vm_curved — depth_power shapes the fade zone the same way it
    # shapes alpha: dp > 1 → fade only at the far edges, dp < 1 → more even.
    if edge_fade and float(fade_strength) > 0.0:
        has_fill = (valid_mask.mean() < 0.999)  # check the REAL valid_mask
        if has_fill:
            dist_edge = _distance_from_valid(_vm_curved, blur_r=4)
            fade_w = (1.0 - dist_edge * float(fade_strength)).clamp(0.0, 1.0)
            out = out * fade_w.expand(b, c, h, w)

    # ── Fade to black ──────────────────────────────────────────────────
    if fade_to_black and float(fade_strength) > 0.0:
        # Radial distance from frame centre, capped at sqrt(2) (the corner).
        r_ = (x_grid.pow(2) + y_grid.pow(2)).sqrt().clamp(0.0, 1.414)
        fm = (1.0 - (r_ * float(fade_strength)).clamp(0.0, 1.0)).view(1, 1, h, w)
        out = out * fm.expand(b, c, h, w)

    # ── Variance correction ────────────────────────────────────────────
    # _vm_curved as the reference mask: std is measured over the
    # depth-weighted content zone.
    if variance_correction:
        out = _apply_variance_correction(out, xf, _vm_curved)

    return out.to(dtype=orig_dtype)


# ═══════════════════════════════════════════════════════════════════════════════
# ZOOM COMPOSE HELPERS — Pinned Subject & Dual-Transform ROI
# ═══════════════════════════════════════════════════════════════════════════════

@torch.no_grad()
def make_soft_roi_mask(
    x,
    cx: float = 0.5, cy: float = 0.5,
    rx: float = 0.18, ry: float = 0.18,
    feather: float = 0.08,
    shape: str = "ellipse",
) -> torch.Tensor:
    """
    Build a soft ROI mask in latent space [B, 1, H, W].
    mask=1 inside ROI, mask=0 outside, smooth transition on boundary.
    cx/cy/rx/ry are fractions [0..1] relative to spatial dims.
+ """ + b, c, h, w = x.shape + device, dtype = x.device, x.dtype + yy = torch.linspace(0.0, 1.0, h, device=device, dtype=dtype) + xx = torch.linspace(0.0, 1.0, w, device=device, dtype=dtype) + y_grid, x_grid = torch.meshgrid(yy, xx, indexing="ij") + dx = (x_grid - float(cx)) / max(float(rx), 1e-6) + dy = (y_grid - float(cy)) / max(float(ry), 1e-6) + if str(shape).lower() == "box": + d = torch.maximum(dx.abs(), dy.abs()) + else: + d = torch.sqrt(dx * dx + dy * dy + 1e-8) + # smoothstep from d=1.0 (inner edge) to d=1+feather (outer edge) + f = max(float(feather), 1e-6) + t = ((d - 1.0) / f).clamp(0.0, 1.0) + smooth = t * t * (3.0 - 2.0 * t) # smoothstep + mask = 1.0 - smooth # inside=1, outside=0 + return mask.unsqueeze(0).unsqueeze(0).expand(b, 1, h, w).contiguous() + + +@torch.no_grad() +def compose_pinned_subject( + x_orig: torch.Tensor, + x_bg: torch.Tensor, + mask: torch.Tensor, + preserve_strength: float = 1.0, +) -> torch.Tensor: + """ + Pinned Subject: ROI area preserved from original, outside gets bg zoom. + out = x_bg * (1 - mask*s) + x_orig * (mask*s) + """ + a = mask.clamp(0.0, 1.0) * float(max(0.0, min(1.0, preserve_strength))) + return x_bg * (1.0 - a) + x_orig * a + + +@torch.no_grad() +def apply_zoom_compose_latent( + src: torch.Tensor, + base_zoom_params: dict, + compose_mode: str = "Global", + roi_shape: str = "ellipse", + roi_center_x: float = 0.5, + roi_center_y: float = 0.5, + roi_radius_x: float = 0.18, + roi_radius_y: float = 0.18, + roi_feather: float = 0.08, + roi_preserve_strength: float = 1.0, + roi_fg_zoom_factor: float = 0.0, + roi_bg_zoom_factor: float = 0.0, + seed: int = -1, +) -> torch.Tensor: + """ + Zoom composition router: + Global → standard apply_advanced_zoom_latent (unchanged) + Pinned Subject → bg zooms, ROI area stays near original + Dual-Transform → bg and fg get independent zoom factors + Geometry backend is enforced for compose modes to avoid canvas fill artefacts. 
+ """ + mode = str(compose_mode or "Global") + + if mode == "Global": + return apply_advanced_zoom_latent(src, base_zoom_params, seed=seed) + + # Build ROI mask — on same device/dtype as src + mask = make_soft_roi_mask( + src, + cx=roi_center_x, cy=roi_center_y, + rx=roi_radius_x, ry=roi_radius_y, + feather=roi_feather, + shape=roi_shape, + ) + + if mode == "Pinned Subject": + # Force geometry backend for clean results + zp_bg = dict(base_zoom_params or {}) + zp_bg["zoom_engine"] = "geometry" + zp_bg.setdefault("zoom_mode", "grid_warp") + x_bg = apply_advanced_zoom_latent(src, zp_bg, seed=seed) + return compose_pinned_subject(src, x_bg, mask, preserve_strength=roi_preserve_strength) + + if mode == "Dual-Transform ROI": + zp_bg = dict(base_zoom_params or {}) + zp_fg = dict(base_zoom_params or {}) + # Dual mode: each gets independent zoom_factor, geometry only. + # NOTE: roi_bg/fg_zoom_factor are FIXED values set in UI and override + # any animated _zoom_factor_override from the range interpolation. + # This is intentional — Dual-Transform users control fg/bg zoom directly. + zp_bg["zoom_factor"] = float(roi_bg_zoom_factor) + zp_fg["zoom_factor"] = float(roi_fg_zoom_factor) + zp_bg["zoom_engine"] = "geometry" + zp_fg["zoom_engine"] = "geometry" + zp_bg.setdefault("zoom_mode", "grid_warp") + zp_fg.setdefault("zoom_mode", "grid_warp") + x_bg = apply_advanced_zoom_latent(src, zp_bg, seed=seed) + x_fg = apply_advanced_zoom_latent(src, zp_fg, seed=seed) + return x_fg * mask + x_bg * (1.0 - mask) + + # Fallback + return apply_advanced_zoom_latent(src, base_zoom_params, seed=seed) + + +@torch.no_grad() +def apply_advanced_zoom_latent( + x: torch.Tensor, + zp: dict, + seed: int = -1, +) -> torch.Tensor: + """ + Главный роутер: geometry или canvas backend по zoom_mode / zoom_engine. + zp — словарь всех zoom-параметров (validate_zoom_params-compatible). 
+ """ + if x is None or x.ndim != 4: + return x + + zm = zp.get("zoom_mode", "grid_warp") + if hasattr(zm, "value"): + zm = str(zm.value) + else: + zm = str(zm).lower() + + backend = resolve_zoom_backend(zm, zp.get("zoom_engine", "auto")) + + bm = zp.get("blend_mode", "circular_reflect") + if hasattr(bm, "value"): + bm = str(bm.value) + else: + bm = str(bm).lower() + + edge_fade_val = bool(zp.get("edge_fade", zp.get("zoom_in_fade", True))) + fade_to_black_val = bool(zp.get("zoom_fade_to_black", zp.get("fade_to_black", False))) + fade_strength_val = float(zp.get("zoom_fade_strength", zp.get("fade_strength", 0.15))) + + if backend == ZOOM_BACKEND_GEOMETRY: + return apply_geometry_zoom_latent( + x, + zoom_factor = float(zp.get("zoom_factor", 0.0)), + convergence_point = float(zp.get("convergence_point", 0.5)), + convergence_y = float(zp.get("convergence_y", 0.5)), + pan_x = float(zp.get("pan_x", 0.0)), + pan_y = float(zp.get("pan_y", 0.0)), + interp_mode = str(zp.get("interp_mode", "bilinear")), + zoom_mode = zm, + spiral_rotation = float(zp.get("spiral_rotation", 0.5)), + spiral_direction = coerce_spiral_direction(zp.get("spiral_direction", 1.0)), + depth_power = float(zp.get("depth_power", 1.0)), + auto_clamp_pan = bool(zp.get("auto_clamp_pan", True)), + ) + return apply_canvas_zoom_latent( + x, + zoom_mode = zm, + zoom_factor = float(zp.get("zoom_factor", 0.0)), + convergence_point = float(zp.get("convergence_point", 0.5)), + convergence_y = float(zp.get("convergence_y", 0.5)), + pan_x = float(zp.get("pan_x", 0.0)), + pan_y = float(zp.get("pan_y", 0.0)), + interp_mode = str(zp.get("interp_mode", "bilinear")), + depth_power = float(zp.get("depth_power", 1.0)), + blend_mode = bm, + noise_strength = float(zp.get("noise_strength", 1.0)), + edge_fade = edge_fade_val, + variance_correction = bool(zp.get("variance_correction", True)), + auto_clamp_pan = bool(zp.get("auto_clamp_pan", True)), + adaptive_noise_scale= bool(zp.get("adaptive_noise_scale",True)), + fade_to_black 
= fade_to_black_val, + fade_strength = fade_strength_val, + gradient_center_x = float(zp.get("gradient_center_x", 0.5)), + gradient_center_y = float(zp.get("gradient_center_y", 0.5)), + gradient_radius = float(zp.get("gradient_radius", 1.0)), + noise_scale = float(zp.get("noise_scale", 5.0)), + noise_octaves = int(zp.get("noise_octaves", 2)), + seed = int(seed), + ) + + +@torch.no_grad() +def sample_kohaku_lonyu_yog(model, x, sigmas, extra_args=None, callback=None, + disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), + s_noise=1., noise_sampler=None, eta=1.): + """ + Kohaku_LoNyu_Yog Sampler - Geometric Second-Order Method + """ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + + steps_total = len(sigmas) - 1 + halfway_point = steps_total // 2 + + for i in trange(steps_total, disable=disable, desc="Kohaku Sampling"): + gamma = min(s_churn / steps_total, 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
        sigma_hat = sigmas[i] * (gamma + 1)

        if gamma > 0:
            # Re-noise x up to the inflated sigma_hat level.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5

        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)

        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        dt = sigma_down - sigmas[i]

        if i <= halfway_point:
            # Second-order phase (first half of the schedule): probe the
            # antipode -x, average its slope with d to take a trial step,
            # then average the trial slope back in (Heun-like correction).
            # Costs 3 model calls per step vs 1 for the plain branch.
            x_antipode = -x

            denoised2 = model(x_antipode, sigma_hat * s_in, **extra_args)
            d2 = to_d(x_antipode, sigma_hat, denoised2)

            v_down = (d + d2) / 2
            x_closer = x + v_down * dt

            denoised3 = model(x_closer, sigma_hat * s_in, **extra_args)
            d3 = to_d(x_closer, sigma_hat, denoised3)

            real_d = (d + d3) / 2
            x = x + real_d * dt

            if sigma_up > 0:
                # Ancestral re-noising after the deterministic step.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        else:
            # Plain ancestral Euler for the remaining steps (cheaper).
            x = x + d * dt
            if sigma_up > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

        if callback is not None:
            # 'denoised' is the prediction from the FIRST model call of
            # this step (not the corrector's), matching k-diffusion samplers.
            callback({
                'x': x,
                'i': i,
                'sigma': sigmas[i],
                'sigma_hat': sigma_hat,
                'denoised': denoised
            })

    return x

# ========================================================================
# LRU CACHE SYSTEM - Enhanced memory management
# ========================================================================

class LRUCache:
    """LRU cache with entry-count and memory limits.

    NOTE(review): the original docstring claimed "Thread-safe", but there is
    no locking anywhere in this class — confirm callers are single-threaded.
    """
    def __init__(self, max_size=50, max_memory_mb=500):
        self.cache = OrderedDict()          # insertion order == recency order
        self.max_size = max_size            # max number of entries
        self.max_memory_mb = max_memory_mb  # soft cap on tracked tensor memory
        self.current_memory_mb = 0          # running tally of cached tensor MB

    def _estimate_size_mb(self, tensor):
        """Estimate tensor size in MB (0 for non-tensor values)."""
        if isinstance(tensor, torch.Tensor):
            return tensor.element_size() * tensor.nelement() / (1024 * 1024)
        return 0

    def get(self, key):
        """Get cached value, move to end (most recent)"""
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        """Set value with automatic eviction if needed"""
size_mb = self._estimate_size_mb(value) + + # Evict old entries if needed + while (len(self.cache) >= self.max_size or + self.current_memory_mb + size_mb > self.max_memory_mb): + if len(self.cache) == 0: + break + old_key, old_value = self.cache.popitem(last=False) + self.current_memory_mb -= self._estimate_size_mb(old_value) + if isinstance(old_value, torch.Tensor) and old_value.is_cuda: + del old_value + + self.cache[key] = value + self.cache.move_to_end(key) + self.current_memory_mb += size_mb + + def clear(self): + """Clear all cache""" + for value in self.cache.values(): + if isinstance(value, torch.Tensor) and value.is_cuda: + del value + self.cache.clear() + self.current_memory_mb = 0 + +# Enhanced cache instances (keep compatibility with old dict-based caches) +_VORONOI_CACHE = LRUCache(max_size=20, max_memory_mb=100) + +# ======================================================================== +# ZOOM EASING HELPERS +# ======================================================================== + +def _zoom_easing_fn(t: float, mode: str) -> float: + """ + Map t ∈ [0,1] through an easing curve. 
+ instant — hard cut: 0 everywhere, 1 at t=1 + linear — no curve + ease_in — starts slow, accelerates + ease_out — starts fast, decelerates + ease_in_out — S-curve (default) + """ + t = max(0.0, min(1.0, t)) + if mode == "instant": + return 1.0 if t >= 1.0 else 0.0 + if mode == "linear": + return t + if mode == "ease_in": + return t * t + if mode == "ease_out": + return 1.0 - (1.0 - t) * (1.0 - t) + # ease_in_out (default / fallback) + if t < 0.5: + return 2.0 * t * t + return 1.0 - (-2.0 * t + 2.0) ** 2 / 2.0 + + +# ======================================================================== +# VALIDATION UTILITIES - Protect against NaN/Inf +# ======================================================================== + +def validate_float(value, min_val=None, max_val=None, default=0.0, name="parameter"): + """Safely validate and clamp float values""" + try: + val = float(value) + if math.isnan(val) or math.isinf(val): + print(f"⚠ Warning: {name} is NaN/Inf, using default {default}") + return default + if min_val is not None and val < min_val: + return min_val + if max_val is not None and val > max_val: + return max_val + return val + except (ValueError, TypeError): + print(f"⚠ Warning: Invalid {name}, using default {default}") + return default + +def validate_int(value, min_val=None, max_val=None, default=0, name="parameter"): + """Safely validate and clamp integer values""" + try: + val = int(value) + if min_val is not None and val < min_val: + return min_val + if max_val is not None and val > max_val: + return max_val + return val + except (ValueError, TypeError): + print(f"⚠ Warning: Invalid {name}, using default {default}") + return default + +def coerce_spiral_direction(value, default=1.0): + """ + Normalise spiral_direction to ±1.0. + Accepts floats/ints OR the label strings that Gradio Radio may return + when choices are defined as (label, value) tuples. 
    """
    if isinstance(value, str):
        # Radio-label strings: map spelled-out directions to ±1.0.
        s = value.strip().lower()
        if s in {"clockwise", "cw", "+1", "1", "1.0"}:
            return 1.0
        if s in {"counter-clockwise", "counterclockwise",
                 "anti-clockwise", "anticlockwise", "ccw", "-1", "-1.0"}:
            return -1.0
        return default
    try:
        # Numeric input: only the sign matters.
        v = float(value)
        return 1.0 if v >= 0 else -1.0
    except (TypeError, ValueError):
        return default

def validate_tensor(tensor, name="tensor"):
    """Check tensor for NaN/Inf values and sanitise them if found."""
    if isinstance(tensor, torch.Tensor):
        if torch.isnan(tensor).any():
            print(f"⚠ Warning: {name} contains NaN values!")
            return torch.nan_to_num(tensor, nan=0.0)
        if torch.isinf(tensor).any():
            print(f"⚠ Warning: {name} contains Inf values!")
            return torch.nan_to_num(tensor, posinf=1.0, neginf=-1.0)
    return tensor

# ========================================================================
# EXTENDED PADDING FUNCTIONS
# ========================================================================

def get_or_create_mask(h, w, device):
    """Cache odd-row boolean masks (h, w, device keyed) for reuse."""
    key = (h, w, str(device))
    if key not in _MASK_CACHE:
        row_indices = torch.arange(h, device=device).view(1, 1, h, 1)
        _MASK_CACHE[key] = (row_indices % 2 == 1)
    return _MASK_CACHE[key]

def compute_anisotropic_padding(input_tensor, pad_h, pad_w, angle_deg=45, angle_deg2=None, angle_mix=1.0):
    """
    Anisotropic padding — different behaviour along the diagonals.
    Emulates directional materials (wood, metal, fibres).
    """
    b, c, h, w = input_tensor.shape

    # Base paddings: circular and reflect
    padded = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
    padded_reflect = _safe_pad4d(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

    # Sizes already include the padding
    _, H, W = padded.shape[1:]

    # Convert the angle to radians
    angle_rad = math.radians(angle_deg)

    # Normalised [-1, 1] coordinates over the whole padded tensor
    y_coords = torch.linspace(-1.0, 1.0, steps=H, device=input_tensor.device, dtype=input_tensor.dtype).view(1, 1, H, 1)
    x_coords = torch.linspace(-1.0, 1.0, steps=W, device=input_tensor.device, dtype=input_tensor.dtype).view(1, 1, 1, W)

    # Projection onto the fibre direction
    directional_component = x_coords * math.cos(angle_rad) + y_coords * math.sin(angle_rad)
    directional_strength = directional_component.abs()

    # Optional second direction (advanced): blend two angle fields
    if angle_deg2 is not None:
        angle_mix = float(max(0.0, min(float(angle_mix), 1.0)))
        angle_rad2 = math.radians(float(angle_deg2))
        directional_component2 = x_coords * math.cos(angle_rad2) + y_coords * math.sin(angle_rad2)
        directional_strength2 = directional_component2.abs()
        directional_strength = directional_strength * angle_mix + directional_strength2 * (1.0 - angle_mix)

    # Alpha blending: mostly circular along the direction, mostly reflect across it
    alpha = directional_strength.clamp(0.0, 1.0)
    result = padded * alpha + padded_reflect * (1.0 - alpha)

    return result

def compute_polar_padding(input_tensor, pad_h, pad_w):
    """
    Polar padding for spherical projections.
+ """ + b, c, h, w = input_tensor.shape + + # X-axis: стандартный circular (долгота замыкается) + x = F.pad(input_tensor, (pad_w, pad_w, 0, 0), mode='circular') + + # Y-axis: полярная коррекция (широта через полюса) + shift = w // 2 + + # Верхний паддинг (Северный полюс) + top_strip = x[:, :, :pad_h, :] + top_pad = torch.roll(top_strip, shifts=shift, dims=3) + top_pad = torch.flip(top_pad, dims=[2]) + + # Нижний паддинг (Южный полюс) + bot_strip = x[:, :, -pad_h:, :] + bot_pad = torch.roll(bot_strip, shifts=shift, dims=3) + bot_pad = torch.flip(bot_pad, dims=[2]) + + result = torch.cat([top_pad, x, bot_pad], dim=2) + return result + + +# ======================================================================== +# NEW ENHANCED MODES - Voronoi, Perlin, Fractal, Adaptive +# ======================================================================== + +# ===== EDGE DETECTION for Adaptive Mode ===== + +def detect_edges_sobel(tensor): + """ + Sobel edge detection для адаптивного блендинга + ✅ FLOAT16 FIX: Безопасная нормализация + """ + # Convert to grayscale if needed + if tensor.shape[1] > 1: + gray = tensor.mean(dim=1, keepdim=True) + else: + gray = tensor + + # Sobel kernels + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=tensor.dtype, device=tensor.device).view(1, 1, 3, 3) + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=tensor.dtype, device=tensor.device).view(1, 1, 3, 3) + + # Apply Sobel + edge_x = F.conv2d(gray, sobel_x, padding=1) + edge_y = F.conv2d(gray, sobel_y, padding=1) + + # ✅ FIX: Адаптивный epsilon для sqrt и нормализации + eps_val = get_safe_epsilon(tensor.dtype) + + # Magnitude с защитой + edges = torch.sqrt(edge_x ** 2 + edge_y ** 2 + eps_val) + + # Normalize с защитой от деления на ноль + edges_max = torch.clamp(edges.max(), min=eps_val) + edges = edges / edges_max + + return edges + +# ===== VORONOI ORGANIC TILING ===== + +def compute_voronoi_padding(input_tensor, pad_h, pad_w, num_cells=8, seed=42): + """ 
+ Voronoi diagram-based organic tiling + ✅ FLOAT16 FIX: Защищенный sqrt и деление + """ + b, c, h, w = input_tensor.shape + device = input_tensor.device + dtype = input_tensor.dtype # ✅ FIX: Получаем dtype + + # Validate parameters + num_cells = validate_int(num_cells, min_val=4, max_val=32, default=8, name="voronoi_cells") + seed = validate_int(seed, min_val=0, max_val=99999, default=42, name="voronoi_seed") + + # ✅ FIX: Адаптивный epsilon + eps_val = get_safe_epsilon(dtype) + + # Cache key + cache_key = (h, w, pad_h, pad_w, num_cells, seed, str(device)) + cached_map = _VORONOI_CACHE.get(cache_key) + + if cached_map is None: + # Create extended canvas + H_ext = h + 2 * pad_h + W_ext = w + 2 * pad_w + + # Generate random Voronoi cell centers + # BUG FIX 9 (continued): local Generator keeps Voronoi reproducible + # without touching the global RNG state. + _vg = torch.Generator(device=device) + _vg.manual_seed(seed) + centers_y = torch.rand(num_cells, device=device, generator=_vg) * H_ext + centers_x = torch.rand(num_cells, device=device, generator=_vg) * W_ext + + # Create coordinate grids + y_grid = torch.arange(H_ext, device=device, dtype=torch.float32).view(-1, 1) + x_grid = torch.arange(W_ext, device=device, dtype=torch.float32).view(1, -1) + + # Compute distances to all centers + distances = [] + for i in range(num_cells): + # ✅ FIX: Добавлен eps_val для защиты sqrt + dist = torch.sqrt((y_grid - centers_y[i])**2 + (x_grid - centers_x[i])**2 + eps_val) + distances.append(dist) + + distances = torch.stack(distances, dim=0) + + # Find nearest cell for each pixel + nearest_cell = torch.argmin(distances, dim=0) + + # Create mapping from extended to source coordinates + voronoi_map = torch.zeros(2, H_ext, W_ext, device=device, dtype=torch.float32) + + for cell_id in range(num_cells): + mask = (nearest_cell == cell_id) + # Map this region to corresponding source region (with wrapping) + y_offset = (centers_y[cell_id] % h) - h // 2 + x_offset = (centers_x[cell_id] % w) 
- w // 2 + + y_coords_masked = y_grid.expand_as(mask)[mask] + x_coords_masked = x_grid.expand_as(mask)[mask] + + voronoi_map[0][mask] = (y_coords_masked - y_offset) % h + voronoi_map[1][mask] = (x_coords_masked - x_offset) % w + + _VORONOI_CACHE.set(cache_key, voronoi_map) + cached_map = voronoi_map + + # Apply mapping using grid_sample + # Normalize coordinates to [-1, 1] + grid = cached_map.permute(1, 2, 0).unsqueeze(0) # [1, H_ext, W_ext, 2] + grid = grid.to(dtype=torch.float32) # fp16-safe: grid всегда float32 + + # ✅ FIX: Защита от деления на 0 + h_safe = max(h - 1, 1) + w_safe = max(w - 1, 1) + grid[..., 0] = (grid[..., 0] / h_safe) * 2 - 1 # y + grid[..., 1] = (grid[..., 1] / w_safe) * 2 - 1 # x + + # Sample from input + result = F.grid_sample( + input_tensor, + grid, + mode='bilinear', + padding_mode='border', + align_corners=False + ) + + return result + +# ===== PERLIN NOISE DISTORTION ===== + +def generate_perlin_noise(height, width, scale=10.0, octaves=4, persistence=0.5, device='cuda'): + """ + Generate Perlin-like noise using multiple octaves of smoothed noise + """ + def smooth_noise(noise): + return F.avg_pool2d( + F.pad(noise, (1, 1, 1, 1), mode='reflect'), + kernel_size=3, stride=1, padding=0 + ) + + total_noise = torch.zeros(1, 1, height, width, device=device) + amplitude = 1.0 + frequency = 1.0 + max_value = 0.0 + + for _ in range(octaves): + # Generate random noise at this frequency + noise_h = int(height / scale * frequency) + 1 + noise_w = int(width / scale * frequency) + 1 + + noise = torch.rand(1, 1, noise_h, noise_w, device=device) * 2 - 1 + + # Smooth and upscale + for _ in range(3): # Multiple smoothing passes + noise = smooth_noise(noise) + + # Resize to target size + noise = F.interpolate(noise, size=(height, width), mode='bilinear', align_corners=False) + + total_noise += noise * amplitude + max_value += amplitude + + amplitude *= persistence + frequency *= 2.0 + + # Normalize + total_noise /= max_value + return total_noise + +def 
compute_perlin_padding(input_tensor, pad_h, pad_w, strength=0.3, scale=10.0): + """ + Apply Perlin noise distortion for natural-looking tiling + """ + # Validate parameters + strength = validate_float(strength, min_val=0.0, max_val=1.0, default=0.3, name="perlin_strength") + scale = validate_float(scale, min_val=1.0, max_val=50.0, default=10.0, name="perlin_scale") + + b, c, h, w = input_tensor.shape + device = input_tensor.device + + # BUG FIX 8: Compute extended canvas dimensions directly instead of + # allocating a full padded tensor that is used only to read .shape. + # This saves one full-size allocation. + H_ext = h + 2 * pad_h + W_ext = w + 2 * pad_w + + # Generate Perlin noise for displacement + noise_y = generate_perlin_noise(H_ext, W_ext, scale=scale, device=device) + noise_x = generate_perlin_noise(H_ext, W_ext, scale=scale * 1.3, device=device) + + # Create coordinate grid + y_coords = torch.linspace(0, h - 1, H_ext, device=device, dtype=torch.float32).view(1, 1, H_ext, 1).expand(1, 1, H_ext, W_ext) + x_coords = torch.linspace(0, w - 1, W_ext, device=device, dtype=torch.float32).view(1, 1, 1, W_ext).expand(1, 1, H_ext, W_ext) + + # Apply distortion + distortion_strength = strength * max(pad_h, pad_w) + y_distorted = (y_coords + noise_y * distortion_strength) % h + x_distorted = (x_coords + noise_x * distortion_strength) % w + + # Create sampling grid + # FIX V3.6: Убираем лишнюю размерность канала .squeeze(1) + grid = torch.stack([x_distorted, y_distorted], dim=-1).squeeze(1) # Становится (1, H_ext, W_ext, 2) + + # BUG FIX 7: protect against 1-pixel inputs (w-1 == 0 → div-by-zero / NaN). + # Voronoi already uses this pattern; apply it here consistently. 
+ h_safe = max(h - 1, 1) + w_safe = max(w - 1, 1) + grid[..., 0] = (grid[..., 0] / w_safe) * 2 - 1 # Normalize x to [-1, 1] + grid[..., 1] = (grid[..., 1] / h_safe) * 2 - 1 # Normalize y to [-1, 1] + + # Sample with distortion (from input_tensor directly, not from padded) + result = F.grid_sample( + input_tensor, + grid.expand(b, -1, -1, -1), + mode='bilinear', + padding_mode='border', + align_corners=True + ) + + return result + +# ===== FRACTAL RECURSIVE TILING ===== + +def compute_fractal_padding(input_tensor, pad_h, pad_w, iterations=2, scale_factor=0.6): + """ + Fractal recursive tiling - creates self-similar patterns + Each iteration adds scaled version of the pattern + """ + iterations = validate_int(iterations, min_val=1, max_val=4, default=2, name="fractal_iterations") + scale_factor = validate_float(scale_factor, min_val=0.3, max_val=0.9, default=0.6, name="fractal_scale") + + # Base padding + result = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular') + + b, c, h_ext, w_ext = result.shape + + # Add fractal details + for i in range(1, iterations + 1): + scale = scale_factor ** i + + # Scale down the pattern + scaled_h = max(8, int(h_ext * scale)) + scaled_w = max(8, int(w_ext * scale)) + + scaled_pattern = F.interpolate( + result, + size=(scaled_h, scaled_w), + mode='bilinear', + align_corners=False + ) + + # Tile it across the result + tiles_y = (h_ext + scaled_h - 1) // scaled_h + tiles_x = (w_ext + scaled_w - 1) // scaled_w + + tiled = scaled_pattern.repeat(1, 1, tiles_y, tiles_x) + tiled = tiled[:, :, :h_ext, :w_ext] + + # Blend with decreasing strength + alpha = 0.3 / (i + 1) + result = result * (1 - alpha) + tiled * alpha + + return result + +# ===== ADAPTIVE SMART MODE ===== + +def compute_adaptive_padding(input_tensor, pad_h, pad_w, edge_threshold=0.1): + """ + Adaptive padding that analyzes image content and chooses best mode + Uses edge detection to decide between circular, reflect, or blend + """ + edge_threshold = 
validate_float(edge_threshold, min_val=0.0, max_val=1.0, default=0.1, name="adaptive_threshold") + + # Detect edges + edges = detect_edges_sobel(input_tensor) + + # Analyze edge intensity at borders + top_edge = edges[:, :, :min(pad_h, edges.shape[2]), :].mean() + bottom_edge = edges[:, :, -min(pad_h, edges.shape[2]):, :].mean() + left_edge = edges[:, :, :, :min(pad_w, edges.shape[3])].mean() + right_edge = edges[:, :, :, -min(pad_w, edges.shape[3]):].mean() + + avg_edge = (top_edge + bottom_edge + left_edge + right_edge) / 4 + + # Decide mode based on edge strength + if avg_edge < edge_threshold: + # Low edge content -> circular works well + return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular') + elif avg_edge > 0.5: + # High edge content -> use blend mode + return compute_blend_padding(input_tensor, pad_h, pad_w, strength=0.7) + else: + # Medium edge content -> adaptive blend + blend_strength = float(avg_edge) * 1.5 # Scale to [0, 0.75] + return compute_blend_padding(input_tensor, pad_h, pad_w, strength=blend_strength) + +# ===== ENHANCED BLEND MODE (improved version) ===== + +def create_edge_blend_mask(h, w, blend_width, device, dtype=torch.float32): + """ + Create gradient mask for soft edge blending + Returns mask with smooth falloff at borders + """ + mask = torch.ones(h, w, device=device, dtype=dtype) + + if blend_width <= 0: + return mask.unsqueeze(0).unsqueeze(0) + + blend_w = min(blend_width, w // 4) + blend_h = min(blend_width, h // 4) + + # Horizontal edges (left and right) + for i in range(blend_w): + alpha = (i + 1) / (blend_w + 1) + alpha = alpha ** 0.5 # Gamma correction + mask[:, i] = alpha + mask[:, -(i + 1)] = alpha + + # Vertical edges (top and bottom) + for i in range(blend_h): + alpha = (i + 1) / (blend_h + 1) + alpha = alpha ** 0.5 + mask[i, :] = torch.minimum(mask[i, :], torch.tensor(alpha, device=device, dtype=dtype)) + mask[-(i + 1), :] = torch.minimum(mask[-(i + 1), :], torch.tensor(alpha, device=device, dtype=dtype)) 
    return mask.unsqueeze(0).unsqueeze(0)

def compute_blend_padding(input_tensor, pad_h, pad_w, strength=0.5, blend_width=None):
    """
    Enhanced Soft Blend Edges implementation.
    Blends between circular and reflect padding modes at the borders.
    """
    strength = validate_float(strength, min_val=0.0, max_val=1.0, default=0.5, name="blend_strength")

    # Auto blend width if not specified
    if blend_width is None:
        blend_width = max(pad_h, pad_w)

    # Create both padding modes
    circular = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
    reflect = _safe_pad4d(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

    # Create blend mask
    h, w = circular.shape[-2:]
    mask = create_edge_blend_mask(h, w, blend_width, input_tensor.device, input_tensor.dtype)

    # Apply blending with strength control.
    # NOTE: at strength=0 this yields a constant 50/50 mix of circular and
    # reflect, not pure circular — apparently by design.
    alpha = mask * strength + (1 - strength) * 0.5
    alpha = alpha.expand_as(circular)

    result = circular * alpha + reflect * (1 - alpha)
    return result


# ===================================
# CUBEMAP (3D) — Engine A (Fast) + Engine B (Seam-Blend)
# ===================================

def _safe_pad4d(x, pad, mode='reflect', value=0.0):
    """
    Safe wrapper around F.pad for 4D tensors.
    - For mode='reflect', PyTorch requires pad < input_size.
      If invalid, we fall back to 'replicate' to avoid runtime errors.
    pad: (left, right, top, bottom)
    """
    if not isinstance(pad, (tuple, list)) or len(pad) != 4:
        # Non-4-tuple pads: delegate to F.pad unchanged.
        return F.pad(x, pad, mode=mode, value=value) if mode == 'constant' else F.pad(x, pad, mode=mode)

    l, r, t, b = pad
    if mode == 'reflect':
        h = int(x.shape[-2])
        w = int(x.shape[-1])
        if (l >= w) or (r >= w) or (t >= h) or (b >= h):
            mode = 'replicate'

    if mode == 'constant':
        return F.pad(x, (l, r, t, b), mode=mode, value=value)
    return F.pad(x, (l, r, t, b), mode=mode)


def _cubemap_split_faces(x):
    """
    Splits a 3x2 cubemap net into faces.
    Layout expected (top row / bottom row):
        S | E | N
        B | T | W
    Returns tuple (S, E, N, B, T, W), each (B,C,h,w)
    """
    B, C, H, W = x.shape
    if H % 2 != 0 or W % 3 != 0:
        raise ValueError("Cubemap expects H%2==0 and W%3==0 (3x2 net).")
    h, w = H // 2, W // 3
    S = x[:, :, 0:h, 0:w]
    E = x[:, :, 0:h, w:2*w]
    N = x[:, :, 0:h, 2*w:3*w]
    Bm = x[:, :, h:2*h, 0:w]
    T = x[:, :, h:2*h, w:2*w]
    Wf = x[:, :, h:2*h, 2*w:3*w]
    return S, E, N, Bm, T, Wf


def _cubemap_stitch_faces(S, E, N, Bm, T, Wf):
    """Stitches faces back into a 3x2 net (S/E/N over B/T/W)."""
    B, C, h, w = S.shape
    out = torch.zeros((B, C, h * 2, w * 3), device=S.device, dtype=S.dtype)
    out[:, :, 0:h, 0:w] = S
    out[:, :, 0:h, w:2*w] = E
    out[:, :, 0:h, 2*w:3*w] = N
    out[:, :, h:2*h, 0:w] = Bm
    out[:, :, h:2*h, w:2*w] = T
    out[:, :, h:2*h, 2*w:3*w] = Wf
    return out


def _cubemap_pad_with_adjoint(O, L, R, U, D, pL, pR, pU, pD, pad_mode='replicate',
                              seam_strength=0.0, seam_width=0):
    """
    Pads a face O with neighbor strips L/R/U/D (already extracted from adjacent faces).
    Supports optional seam blending (Engine B) by mixing neighbor padding with O edge.
    """
    B, C, h, w = O.shape
    Hp = h + pU + pD
    Wp = w + pL + pR
    # Zero canvas with O in the centre; pads are filled below.
    Z = torch.zeros((B, C, Hp, Wp), device=O.device, dtype=O.dtype)
    Z[:, :, pU:pU + h, pL:pL + w] = O

    if pL == 0 and pR == 0 and pU == 0 and pD == 0:
        return Z

    # Helper: create ramp for seam_width (0 at boundary, 1 at outer pad)
    def _make_ramp(n, seam_w, device, dtype):
        if n <= 0:
            return None
        seam_w = int(max(0, min(seam_w, n)))
        if seam_w == 0:
            return torch.ones((n,), device=device, dtype=dtype)
        if seam_w == 1:
            ramp = torch.ones((n,), device=device, dtype=dtype)
            ramp[0] = 0.0
            return ramp
        ramp = torch.ones((n,), device=device, dtype=dtype)
        ramp[:seam_w] = torch.linspace(0.0, 1.0, steps=seam_w, device=device, dtype=dtype)
        return ramp

    # Fill left/right strips: neighbor strip, optionally blended with O's
    # own edge near the seam (Engine B).
    if pL > 0:
        Lp = _safe_pad4d(L, (0, 0, pU, pD), mode=pad_mode)
        strip = Lp
        if seam_strength > 0.0:
            Oedge = O[:, :, :, :min(pL, w)]
            Oedge = _safe_pad4d(Oedge, (0, max(0, pL - Oedge.shape[-1]), pU, pD), mode='replicate')
            ramp = _make_ramp(pL, seam_width, O.device, O.dtype).view(1, 1, 1, pL)
            blend_scheme = Oedge * (1.0 - ramp) + strip * ramp
            strip = strip * (1.0 - seam_strength) + blend_scheme * seam_strength
        Z[:, :, :, :pL] = strip

    if pR > 0:
        Rp = _safe_pad4d(R, (0, 0, pU, pD), mode=pad_mode)
        strip = Rp
        if seam_strength > 0.0:
            Oedge = O[:, :, :, max(0, w - pR):w]
            need = pR - Oedge.shape[-1]
            Oedge = _safe_pad4d(Oedge, (max(0, need), 0, pU, pD), mode='replicate')
            # Ramp flipped: the seam sits on the inner side of the right pad.
            ramp = _make_ramp(pR, seam_width, O.device, O.dtype).view(1, 1, 1, pR).flip(-1)
            blend_scheme = Oedge * (1.0 - ramp) + strip * ramp
            strip = strip * (1.0 - seam_strength) + blend_scheme * seam_strength
        Z[:, :, :, -pR:] = strip

    # Fill top/bottom strips (note: these overwrite the corner regions
    # written by the left/right strips above).
    if pU > 0:
        Up = _safe_pad4d(U, (pL, pR, 0, 0), mode=pad_mode)
        strip = Up
        if seam_strength > 0.0:
            Oedge = O[:, :, :min(pU, h), :]
            Oedge = _safe_pad4d(Oedge, (pL, pR, 0, max(0, pU - Oedge.shape[-2])), mode='replicate')
            ramp = _make_ramp(pU, seam_width, O.device, O.dtype).view(1, 1, pU, 1)
            blend_scheme = Oedge * (1.0 - ramp) + strip * ramp
            strip = strip * (1.0 - seam_strength) + blend_scheme * seam_strength
        Z[:, :, :pU, :] = strip

    if pD > 0:
        Dp = _safe_pad4d(D, (pL, pR, 0, 0), mode=pad_mode)
        strip = Dp
        if seam_strength > 0.0:
            Oedge = O[:, :, max(0, h - pD):h, :]
            need = pD - Oedge.shape[-2]
            Oedge = _safe_pad4d(Oedge, (pL, pR, max(0, need), 0), mode='replicate')
            ramp = _make_ramp(pD, seam_width, O.device, O.dtype).view(1, 1, pD, 1).flip(-2)
            blend_scheme = Oedge * (1.0 - ramp) + strip * ramp
            strip = strip * (1.0 - seam_strength) + blend_scheme * seam_strength
        Z[:, :, -pD:, :] = strip

    # Fix corners overlapping (same as cubemap(3).py logic).
    # NOTE(review): corners were overwritten (not accumulated) by the
    # top/bottom strips, so /= 2 halves the last-written strip rather than
    # averaging two contributions — verify against cubemap(3).py.
    if pU and pL:
        Z[:, :, :pU, :pL] /= 2
    if pU and pR:
        Z[:, :, :pU, -pR:] /= 2
    if pD and pL:
        Z[:, :, -pD:, :pL] /= 2
    if pD and pR:
        Z[:, :, -pD:, -pR:] /= 2

    return Z


def conv2d_cubemap_batched(input_tensor, weight, bias, stride, dilation, groups,
                           pad_h, pad_w, pad_mode='replicate',
                           engine='A (Fast)', seam_width=0, seam_strength=0.0):
    """
    Cubemap convolution for a 3x2 cubemap net (S/E/N over B/T/W), using 1 conv call:
    - Engine A: neighbor padding (fast, like cubemap(3).py but batched)
    - Engine B: same, but with seam-aware blending inside padding regions.
    NOTE: Requires square faces (h == w) to keep rotations consistent.
+ """ + if pad_h != pad_w: + return F.conv2d(input_tensor, weight, bias, stride, (pad_h, pad_w), dilation, groups) + + if pad_h == 0 and pad_w == 0: + return F.conv2d(input_tensor, weight, bias, stride, (0, 0), dilation, groups) + + try: + S, E, N, Bm, T, Wf = _cubemap_split_faces(input_tensor) + except Exception: + return F.conv2d(input_tensor, weight, bias, stride, (pad_h, pad_w), dilation, groups) + + B, C, h, w = S.shape + if h != w: + return F.conv2d(input_tensor, weight, bias, stride, (pad_h, pad_w), dilation, groups) + + p = int(pad_h) + pL = pR = pU = pD = p + + seam_strength = float(max(0.0, min(seam_strength, 1.0))) if (engine or '').startswith('B') else 0.0 + seam_width = int(max(0, seam_width)) + + ZS = _cubemap_pad_with_adjoint( + S, + L=Wf[:, :, :, -pL:], + R=E[:, :, :, :pR], + U=T[:, :, -pU:, :], + D=Bm[:, :, :pD, :], + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + ZE = _cubemap_pad_with_adjoint( + E, + L=S[:, :, :, -pL:], + R=N[:, :, :, :pR], + U=torch.rot90(T[:, :, :, -pU:], k=-1, dims=[2, 3]), + D=torch.rot90(Bm[:, :, :, -pD:], k=+1, dims=[2, 3]), + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + ZN = _cubemap_pad_with_adjoint( + N, + L=E[:, :, :, -pL:], + R=Wf[:, :, :, :pR], + U=T[:, :, :pU, :].flip(-1), + D=Bm[:, :, -pD:, :].flip(-1), + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + ZB = _cubemap_pad_with_adjoint( + Bm, + L=torch.rot90(Wf[:, :, -pL:, :], k=+1, dims=[2, 3]), + R=torch.rot90(E[:, :, -pR:, :], k=-1, dims=[2, 3]), + U=S[:, :, -pU:, :], + D=N[:, :, -pD:, :].flip(-1), + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + ZT = _cubemap_pad_with_adjoint( + T, + L=torch.rot90(Wf[:, :, :pL, :], k=-1, dims=[2, 3]), + R=torch.rot90(E[:, :, :pR, :], k=+1, dims=[2, 3]), + U=N[:, :, :pU, 
:].flip(-1), + D=S[:, :, :pD, :], + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + ZW = _cubemap_pad_with_adjoint( + Wf, + L=N[:, :, :, -pL:], + R=S[:, :, :, :pR], + U=torch.rot90(T[:, :, :, :pL], k=+1, dims=[2, 3]), + D=torch.rot90(Bm[:, :, :, :pD], k=-1, dims=[2, 3]), + pL=pL, pR=pR, pU=pU, pD=pD, + pad_mode=pad_mode, + seam_strength=seam_strength, + seam_width=seam_width + ) + + Z = torch.cat([ZS, ZE, ZN, ZB, ZT, ZW], dim=0) + Y = F.conv2d(Z, weight, bias, stride, (0, 0), dilation, groups) + YS, YE, YN, YB, YT, YW = Y.chunk(6, dim=0) + + return _cubemap_stitch_faces(YS, YE, YN, YB, YT, YW) + +# =================================== +# CUBEMAP (3D) — Engine C (GridSample / True 3D mapping) +# =================================== + +def _ypr_rotation_matrix(yaw_deg: float, pitch_deg: float, roll_deg: float, device, dtype): + """ + Builds a rotation matrix from yaw/pitch/roll angles (degrees). + Convention: + - yaw around +Y axis + - pitch around +X axis + - roll around +Z axis + Applied as: R = Rz(roll) @ Rx(pitch) @ Ry(yaw) + """ + yaw = math.radians(float(yaw_deg)) + pitch = math.radians(float(pitch_deg)) + roll = math.radians(float(roll_deg)) + + cy, sy = math.cos(yaw), math.sin(yaw) + cp, sp = math.cos(pitch), math.sin(pitch) + cr, sr = math.cos(roll), math.sin(roll) + + # Ry (yaw) + Ry = torch.tensor([[cy, 0.0, sy], + [0.0, 1.0, 0.0], + [-sy, 0.0, cy]], device=device, dtype=dtype) + + # Rx (pitch) + Rx = torch.tensor([[1.0, 0.0, 0.0], + [0.0, cp, -sp], + [0.0, sp, cp]], device=device, dtype=dtype) + + # Rz (roll) + Rz = torch.tensor([[cr, -sr, 0.0], + [sr, cr, 0.0], + [0.0, 0.0, 1.0]], device=device, dtype=dtype) + + return (Rz @ Rx @ Ry) + + +def _cubemap_dirs_from_face_uv(face_id: int, u, v): + """ + Maps face-local (u,v) to 3D direction vectors BEFORE normalization. 
+ Faces in our atlas mapping: + 0: Front (+Z) -> S + 1: Right (+X) -> E + 2: Back (-Z) -> N + 3: Bottom (-Y) -> Bm + 4: Top (+Y) -> T + 5: Left (-X) -> Wf + u, v are broadcastable tensors, typically shaped (Hp, Wp) or (1,1,Hp,Wp) + """ + if face_id == 0: # +Z (Front) + x, y, z = u, -v, torch.ones_like(u) + elif face_id == 1: # +X (Right) + x, y, z = torch.ones_like(u), -v, -u + elif face_id == 2: # -Z (Back) + x, y, z = -u, -v, -torch.ones_like(u) + elif face_id == 3: # -Y (Bottom) + x, y, z = u, -torch.ones_like(u), -v + elif face_id == 4: # +Y (Top) + x, y, z = u, torch.ones_like(u), v + elif face_id == 5: # -X (Left) + x, y, z = -torch.ones_like(u), -v, u + else: + raise ValueError("Invalid face_id for cubemap.") + return x, y, z + + +def _cubemap_dir_to_atlas_grid(x, y, z, face_h: int, face_w: int, device, dtype): + """ + Converts 3D direction vectors to a single atlas (3x2 net) sampling grid in [-1,1]. + Returns grid shaped (..., 2) with last dim [x_norm, y_norm]. + + ✅ FLOAT16 FIX: Использует адаптивный epsilon! + 🔧 FIX #5: Координаты приводятся к float32 для точности вычислений + """ + # 🔧 FIX #5: Приведение к float32 для точных вычислений (избегаем артефактов в fp16) + x = x.to(torch.float32) + y = y.to(torch.float32) + z = z.to(torch.float32) + + # ✅ FIX: Адаптивный epsilon для float32 (т.к. 
теперь все в float32) + eps_val = get_safe_epsilon(torch.float32) + eps = torch.tensor(eps_val, device=device, dtype=torch.float32) + + # Normalize directions (avoid divide-by-zero) + inv_len = torch.rsqrt(torch.clamp(x * x + y * y + z * z, min=eps_val)) + x = x * inv_len + y = y * inv_len + z = z * inv_len + + ax = x.abs() + ay = y.abs() + az = z.abs() + + # Major axis selection + is_x = (ax >= ay) & (ax >= az) + is_y = (ay >= ax) & (ay >= az) + is_z = ~(is_x | is_y) + + # Face index map: 0..5 + face_idx = torch.empty_like(x, dtype=torch.int64) + + # Defaults (placeholders) + u = torch.zeros_like(x) + v = torch.zeros_like(x) + + # +X / -X + mask = is_x & (x >= 0) + face_idx[mask] = 1 + u[mask] = -z[mask] / (ax[mask] + eps) # ✅ eps теперь безопасен для float16 + v[mask] = -y[mask] / (ax[mask] + eps) + + mask = is_x & (x < 0) + face_idx[mask] = 5 + u[mask] = z[mask] / (ax[mask] + eps) + v[mask] = -y[mask] / (ax[mask] + eps) + + # +Y / -Y + mask = is_y & (y >= 0) + face_idx[mask] = 4 + u[mask] = x[mask] / (ay[mask] + eps) + v[mask] = z[mask] / (ay[mask] + eps) + + mask = is_y & (y < 0) + face_idx[mask] = 3 + u[mask] = x[mask] / (ay[mask] + eps) + v[mask] = -z[mask] / (ay[mask] + eps) + + # +Z / -Z + mask = is_z & (z >= 0) + face_idx[mask] = 0 + u[mask] = x[mask] / (az[mask] + eps) + v[mask] = -y[mask] / (az[mask] + eps) + + mask = is_z & (z < 0) + face_idx[mask] = 2 + u[mask] = -x[mask] / (az[mask] + eps) + v[mask] = -y[mask] / (az[mask] + eps) + + # Atlas tile offsets (col,row) for each face_idx + # 0:F -> (0,0), 1:R -> (1,0), 2:B -> (2,0), 3:Bo -> (0,1), 4:T -> (1,1), 5:L -> (2,1) + col = torch.zeros_like(u) + row = torch.zeros_like(v) + + col = torch.where(face_idx == 0, torch.tensor(0.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 0, torch.tensor(0.0, device=device, dtype=dtype), row) + + col = torch.where(face_idx == 1, torch.tensor(1.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 1, torch.tensor(0.0, device=device, 
dtype=dtype), row) + + col = torch.where(face_idx == 2, torch.tensor(2.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 2, torch.tensor(0.0, device=device, dtype=dtype), row) + + col = torch.where(face_idx == 3, torch.tensor(0.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 3, torch.tensor(1.0, device=device, dtype=dtype), row) + + col = torch.where(face_idx == 4, torch.tensor(1.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 4, torch.tensor(1.0, device=device, dtype=dtype), row) + + col = torch.where(face_idx == 5, torch.tensor(2.0, device=device, dtype=dtype), col) + row = torch.where(face_idx == 5, torch.tensor(1.0, device=device, dtype=dtype), row) + + # Convert (u,v) [-1,1] -> atlas pixel coords -> normalized coords [-1,1] + H_atlas = int(face_h * 2) + W_atlas = int(face_w * 3) + + # align_corners=True mapping uses (W-1)/(H-1) + x_pix = col * face_w + (u + 1.0) * 0.5 * (face_w - 1) + y_pix = row * face_h + (v + 1.0) * 0.5 * (face_h - 1) + + x_norm = (x_pix / max(W_atlas - 1, 1)) * 2.0 - 1.0 + y_norm = (y_pix / max(H_atlas - 1, 1)) * 2.0 - 1.0 + + # 🔧 FIX #5: Возвращаем в исходный dtype после вычислений + grid = torch.stack([x_norm, y_norm], dim=-1).to(dtype) + return grid + + +def _build_cubemap_engine_c_grids(face_h: int, face_w: int, pad: int, + yaw: float, pitch: float, roll: float, + coord_mode: str = "Cartesian (Face UV)", + twist_deg: float = 0.0, + polar_scale: float = 1.0, + polar_power: float = 1.0, + swirl_deg: float = 0.0, + swirl_power: float = 1.0, + device=None, dtype=None, + antipode: bool = False, + angle_quant: float = 0.5): + """ + Builds and caches per-face sampling grids (Engine C) for cubemap atlas. + Grids map each pixel in a padded face to the correct location in the 3x2 atlas. 
def _build_cubemap_engine_c_grids(face_h: int, face_w: int, pad: int,
                                  yaw: float, pitch: float, roll: float,
                                  coord_mode: str = "Cartesian (Face UV)",
                                  twist_deg: float = 0.0,
                                  polar_scale: float = 1.0,
                                  polar_power: float = 1.0,
                                  swirl_deg: float = 0.0,
                                  swirl_power: float = 1.0,
                                  device=None, dtype=None,
                                  antipode: bool = False,
                                  angle_quant: float = 0.5):
    """
    Build and cache per-face sampling grids (Engine C) for the cubemap atlas.

    Each grid maps every pixel of a padded face to its source location in
    the 3x2 atlas, using 3D direction mapping with optional yaw/pitch/roll
    rotation and UV warps (twist / polar / swirl). Angles are quantized by
    `angle_quant` so cache keys stay stable.

    Returns a (6, Hp, Wp, 2) tensor of grids, or None for degenerate faces.
    """
    if face_h <= 1 or face_w <= 1:
        return None

    # FIX: previously q = float(angle_quant) was used unguarded, so
    # angle_quant <= 0 raised ZeroDivisionError (the panorama builder
    # guards this). Fall back to the default step; `not (q > 0)` also
    # catches NaN.
    q = float(angle_quant)
    if not (q > 0.0):
        q = 0.5
    q_milli = int(round(q * 1000.0))
    if q_milli <= 0:
        q_milli = 1

    # Quantize angles to stabilize caching.
    yaw_t = int(round(float(yaw) / q))
    pitch_t = int(round(float(pitch) / q))
    roll_t = int(round(float(roll) / q))
    twist_t = int(round(float(twist_deg) / q))
    swirl_t = int(round(float(swirl_deg) / q))
    yaw_q = float(yaw_t) * q
    pitch_q = float(pitch_t) * q
    roll_q = float(roll_t) * q
    twist_q = float(twist_t) * q
    swirl_q = float(swirl_t) * q

    # Quantize continuous params a bit for caching.
    polar_scale_q = round(float(polar_scale) * 100.0) / 100.0
    polar_power_q = round(float(polar_power) * 100.0) / 100.0
    swirl_power_q = round(float(swirl_power) * 100.0) / 100.0

    dev_type = getattr(device, "type", None)
    dev_index = getattr(device, "index", None)
    key = (
        str(dev_type) if dev_type is not None else str(device),
        int(dev_index) if dev_index is not None else -1,
        str(dtype), int(face_h), int(face_w), int(pad),
        str(coord_mode),
        int(yaw_t), int(pitch_t), int(roll_t),
        int(twist_t),
        int(round(float(polar_scale_q) * 100.0)), int(round(float(polar_power_q) * 100.0)),
        int(swirl_t), int(round(float(swirl_power_q) * 100.0)),
        bool(antipode), int(q_milli))
    cached = _CUBEMAP_GRID_CACHE.get(key, None)
    if cached is not None:
        return cached

    p = int(max(0, pad))
    Hp = int(face_h + 2 * p)
    Wp = int(face_w + 2 * p)

    # Face-local (u, v) over the padded canvas; the central region maps to
    # [-1, 1], padding pixels fall outside that range.
    j = torch.arange(Wp, device=device, dtype=dtype)
    i = torch.arange(Hp, device=device, dtype=dtype)
    denom_w = float(max(face_w - 1, 1))
    denom_h = float(max(face_h - 1, 1))
    u = 2.0 * ((j - p) / denom_w) - 1.0
    v = 2.0 * ((i - p) / denom_h) - 1.0

    # Broadcast to (Hp, Wp)
    u2 = u.view(1, Wp).expand(Hp, Wp)
    v2 = v.view(Hp, 1).expand(Hp, Wp)

    # Advanced UV transform (twist / polar warp / swirl)
    if coord_mode is None:
        coord_mode = "Cartesian (Face UV)"
    cm = str(coord_mode)
    twist_rad = float(twist_q) * (math.pi / 180.0)
    swirl_rad = float(swirl_q) * (math.pi / 180.0)
    do_polar = cm.startswith("Polar")
    if abs(twist_rad) > 1e-9 or abs(swirl_rad) > 1e-9 or do_polar:
        # eps under the sqrt protects against NaN in float16.
        eps_val = get_safe_epsilon(dtype)
        r = torch.sqrt(u2 * u2 + v2 * v2 + eps_val)
        # Clamp radius for numerical stability in padding areas.
        r_clamped = torch.clamp(r, 0.0, 2.0)
        theta = torch.atan2(v2, u2) + twist_rad
        if abs(swirl_rad) > 1e-9:
            sp = float(swirl_power_q)
            theta = theta + swirl_rad * torch.pow(r_clamped, sp)
        if do_polar:
            ps = float(polar_scale_q)
            pp = float(polar_power_q)
            r2 = torch.pow(torch.clamp(r_clamped * ps, min=0.0), pp)
        else:
            r2 = r
        u2 = r2 * torch.cos(theta)
        v2 = r2 * torch.sin(theta)

    R = _ypr_rotation_matrix(yaw_q, pitch_q, roll_q, device=device, dtype=dtype)

    grids = []
    for face_id in range(6):
        x, y, z = _cubemap_dirs_from_face_uv(face_id, u2, v2)

        # Rotate direction vectors: (Hp, Wp, 3) @ R^T
        dirs = torch.matmul(torch.stack([x, y, z], dim=-1), R.transpose(0, 1))

        if antipode:
            dirs = -dirs

        grids.append(_cubemap_dir_to_atlas_grid(
            dirs[..., 0], dirs[..., 1], dirs[..., 2],
            face_h=face_h, face_w=face_w,
            device=device, dtype=dtype,
        ))

    grids = torch.stack(grids, dim=0)  # (6, Hp, Wp, 2)
    _CUBEMAP_GRID_CACHE[key] = grids
    return grids
radius_px: float = 0.0, + mode: str = "bilinear", padding_mode: str = "border"): + """ + Optional geometric AA (multi-sampling) for Engine C. + - samples: 1..4 + - radius_px: pixel radius in atlas space (approx) + """ + samples = int(max(1, min(int(samples), 4))) + radius_px = float(max(0.0, radius_px)) + # sanitize grid_sample args + if mode not in ("bilinear", "nearest"): + mode = "bilinear" + if padding_mode not in ("border", "reflection", "zeros"): + padding_mode = "border" + + + if samples == 1 or radius_px <= 0.0: + return F.grid_sample(atlas, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + B, C, H, W = atlas.shape + # normalize radius to grid space (align_corners=True => 1px == 2/(W-1)) + dx = (radius_px * 2.0) / max(W - 1, 1) + dy = (radius_px * 2.0) / max(H - 1, 1) + + offsets = [(0.0, 0.0)] + if samples >= 2: + offsets.append((dx, dy)) + if samples >= 3: + offsets.append((-dx, dy)) + if samples >= 4: + offsets.append((dx, -dy)) + + acc = None + for ox, oy in offsets: + g = grid.clone() + g[..., 0] = (g[..., 0] + ox).clamp(-1.0, 1.0) + g[..., 1] = (g[..., 1] + oy).clamp(-1.0, 1.0) + y = F.grid_sample(atlas, g, mode=mode, padding_mode=padding_mode, align_corners=True) + acc = y if acc is None else (acc + y) + + return acc / float(len(offsets)) + + +def conv2d_cubemap_gridsample(input_tensor, weight, bias, stride, dilation, groups, + pad_h, pad_w, + yaw=0.0, pitch=0.0, roll=0.0, + coord_mode="Cartesian (Face UV)", twist_deg=0.0, + polar_scale=1.0, polar_power=1.0, + swirl_deg=0.0, swirl_power=1.0, + grid_interp="bilinear", grid_padding="border", + cache_angle_quant=0.5, + geoaa_samples=1, geoaa_radius_px=0.0, + antipode_strength=0.0): + """ + Engine C: True 3D cubemap mapping using grid_sample. + - Builds padded faces by sampling from the full 3x2 atlas via direction mapping. + - Supports yaw/pitch/roll rotation of the sampling directions. + - Optional geometric AA (multi-sampling) and Kohaku-inspired antipode mixing. 
def conv2d_cubemap_gridsample(input_tensor, weight, bias, stride, dilation, groups,
                              pad_h, pad_w,
                              yaw=0.0, pitch=0.0, roll=0.0,
                              coord_mode="Cartesian (Face UV)", twist_deg=0.0,
                              polar_scale=1.0, polar_power=1.0,
                              swirl_deg=0.0, swirl_power=1.0,
                              grid_interp="bilinear", grid_padding="border",
                              cache_angle_quant=0.5,
                              geoaa_samples=1, geoaa_radius_px=0.0,
                              antipode_strength=0.0):
    """
    Engine C: true 3D cubemap mapping using grid_sample.

    - Builds padded faces by sampling the full 3x2 atlas via direction mapping.
    - Supports yaw/pitch/roll rotation of the sampling directions.
    - Optional geometric AA (multi-sampling) and Kohaku-inspired antipode mixing.
    Falls back to a plain padded conv whenever the atlas layout does not
    match (asymmetric padding, wrong aspect ratio, non-square faces).
    """
    def _plain_conv():
        return F.conv2d(input_tensor, weight, bias, stride, (pad_h, pad_w), dilation, groups)

    if pad_h != pad_w:
        return _plain_conv()

    p = int(pad_h)
    if p <= 0:
        return F.conv2d(input_tensor, weight, bias, stride, (0, 0), dilation, groups)

    B, C, H, W = input_tensor.shape
    if H % 2 != 0 or W % 3 != 0:
        return _plain_conv()

    face_h = H // 2
    face_w = W // 3
    if face_h != face_w:
        return _plain_conv()

    device = input_tensor.device
    # grid_sample expects a float grid; prefer float32 when input is half precision.
    grid_dtype = input_tensor.dtype
    if grid_dtype in (torch.float16, torch.bfloat16):
        grid_dtype = torch.float32

    grids = _build_cubemap_engine_c_grids(face_h, face_w, p, yaw, pitch, roll,
                                          coord_mode, twist_deg, polar_scale, polar_power,
                                          swirl_deg, swirl_power,
                                          device, grid_dtype,
                                          antipode=False,
                                          angle_quant=cache_angle_quant)
    if grids is None:
        return _plain_conv()

    mix = float(max(0.0, min(float(antipode_strength), 1.0)))
    grids_anti = None
    if mix > 0.0:
        grids_anti = _build_cubemap_engine_c_grids(face_h, face_w, p, yaw, pitch, roll,
                                                   coord_mode, twist_deg, polar_scale, polar_power,
                                                   swirl_deg, swirl_power,
                                                   device, grid_dtype,
                                                   antipode=True,
                                                   angle_quant=cache_angle_quant)

    Hp = face_h + 2 * p
    Wp = face_w + 2 * p

    padded_faces = []
    for face_id in range(6):
        # grid_sample uses the input dtype; the grid dtype may differ — that's OK.
        g = grids[face_id].to(device=device).unsqueeze(0).expand(B, Hp, Wp, 2).contiguous()
        face = _grid_sample_geoaa(input_tensor, g, samples=geoaa_samples,
                                  radius_px=geoaa_radius_px,
                                  mode=grid_interp, padding_mode=grid_padding)

        if grids_anti is not None:
            ga = grids_anti[face_id].to(device=device).unsqueeze(0).expand(B, Hp, Wp, 2).contiguous()
            anti = _grid_sample_geoaa(input_tensor, ga, samples=geoaa_samples,
                                      radius_px=geoaa_radius_px,
                                      mode=grid_interp, padding_mode=grid_padding)
            face = face * (1.0 - mix) + anti * mix

        padded_faces.append(face)

    batched = torch.cat(padded_faces, dim=0)  # (6B, C, Hp, Wp)
    conv_out = F.conv2d(batched, weight, bias, stride, (0, 0), dilation, groups)
    f_s, f_e, f_n, f_b, f_t, f_w = conv_out.chunk(6, dim=0)

    return _cubemap_stitch_faces(f_s, f_e, f_n, f_b, f_t, f_w)
def _get_blur_kernel_1d(radius: int, device, dtype):
    """
    Return a cached depthwise 1D box-blur kernel for blurring along X (width).

    Kernel is shaped (1, 1, 1, k) with k = 2*radius + 1; returns None when
    radius <= 0. Cached per (k, device, dtype) in _BLUR_KERNEL_CACHE.
    """
    r = int(max(0, radius))
    if r <= 0:
        return None

    k = 2 * r + 1
    dev_type = getattr(device, "type", None)
    dev_index = getattr(device, "index", None)
    cache_key = (
        int(k),
        str(dev_type) if dev_type is not None else str(device),
        int(dev_index) if dev_index is not None else -1,
        str(dtype),
    )
    hit = _BLUR_KERNEL_CACHE.get(cache_key)
    if hit is not None:
        return hit

    # Simple box kernel (stable, cheap). You can swap to Gaussian if needed.
    kernel = (torch.ones((k,), device=device, dtype=dtype) / float(k)).view(1, 1, 1, k)
    _BLUR_KERNEL_CACHE[cache_key] = kernel
    return kernel
def _apply_pole_blur_smoothing(x, strength: float = 0.0, radius: int = 0, power: float = 1.0):
    """
    Blend a circular horizontal blur into the panorama's pole regions.

    x: (B, C, H, W). The blend mask is 0 at the equator and rises toward the
    top/bottom rows as (normalized distance)^power; `strength` scales the
    overall mix. Returns x unchanged when strength or radius is zero.
    """
    mix = float(max(0.0, min(float(strength), 1.0)))
    blur_r = int(max(0, int(radius)))
    curve = float(max(0.25, min(float(power), 4.0)))

    if mix <= 0.0 or blur_r <= 0:
        return x

    _, C, H, _ = x.shape
    kernel = _get_blur_kernel_1d(blur_r, x.device, x.dtype)
    if kernel is None:
        return x

    # Pole mask: 0 at the equator, 1 at top/bottom rows.
    rows = torch.linspace(0.0, 1.0, steps=H, device=x.device, dtype=x.dtype).view(1, 1, H, 1)
    dist_from_equator = torch.clamp(torch.abs(rows - 0.5) * 2.0, 0.0, 1.0)
    pole_mask = torch.pow(dist_from_equator, curve)  # (1,1,H,1)

    # Wrap horizontally (circular pad), then run a depthwise 1D box blur.
    wrapped = F.pad(x, (blur_r, blur_r, 0, 0), mode="circular")
    dw_weight = kernel.expand(C, 1, 1, kernel.shape[-1]).contiguous()
    blurred = F.conv2d(wrapped, dw_weight, bias=None, stride=1, padding=0, groups=C)

    m = pole_mask * mix
    return x * (1.0 - m) + blurred * m
def _build_panorama_engine_c_grid(H: int, W: int, pad_h: int, pad_w: int,
                                  yaw: float, pitch: float, roll: float,
                                  coord_mode: str = "Cartesian (lon/lat)",
                                  polar_scale: float = 1.0,
                                  polar_power: float = 1.0,
                                  twist_deg: float = 0.0,
                                  twist_power: float = 1.0,
                                  swirl_deg: float = 0.0,
                                  swirl_power: float = 1.0,
                                  pole_ease_power: float = 1.0,
                                  antipode: bool = False,
                                  angle_quant: float = 0.5,
                                  device=None, dtype=None):
    """
    Build (and cache) a grid_sample grid for equirectangular panoramas.

    Maps every pixel of the padded output canvas back into the source
    panorama via a true 3D spherical mapping (yaw/pitch/roll), with
    optional twist/swirl/polar warps applied in the (lon, lat) domain.
    Angles are quantized by `angle_quant` so the cache stays stable.
    Returns an (Hp, Wp, 2) grid in [-1, 1], or None for degenerate sizes.
    """
    if H <= 1 or W <= 1:
        return None

    ph = int(max(0, pad_h))
    pw = int(max(0, pad_w))
    Hp = int(H + 2 * ph)
    Wp = int(W + 2 * pw)

    # --- Quantize angles / shape params so cache keys stay stable ---
    q = float(max(0.1, float(angle_quant)))
    q_milli = int(round(q * 1000.0))
    if q_milli <= 0:
        q_milli = 1
    yaw_t = int(round(float(yaw) / q))
    pitch_t = int(round(float(pitch) / q))
    roll_t = int(round(float(roll) / q))
    twist_t = int(round(float(twist_deg) / q))
    swirl_t = int(round(float(swirl_deg) / q))
    yaw_q = float(yaw_t) * q
    pitch_q = float(pitch_t) * q
    roll_q = float(roll_t) * q
    twist_q = float(twist_t) * q
    swirl_q = float(swirl_t) * q

    polar_scale_q = round(float(polar_scale) * 100.0) / 100.0
    polar_power_q = round(float(polar_power) * 100.0) / 100.0
    twist_power_q = round(float(twist_power) * 100.0) / 100.0
    swirl_power_q = round(float(swirl_power) * 100.0) / 100.0
    pole_ease_q = round(float(pole_ease_power) * 100.0) / 100.0

    dev_type = getattr(device, "type", None)
    dev_index = getattr(device, "index", None)
    key = (
        str(dev_type) if dev_type is not None else str(device),
        int(dev_index) if dev_index is not None else -1,
        str(dtype), int(H), int(W), int(ph), int(pw),
        str(coord_mode),
        int(yaw_t), int(pitch_t), int(roll_t),
        int(twist_t), int(round(float(twist_power_q) * 100.0)),
        int(swirl_t), int(round(float(swirl_power_q) * 100.0)),
        int(round(float(polar_scale_q) * 100.0)), int(round(float(polar_power_q) * 100.0)),
        int(round(float(pole_ease_q) * 100.0)),
        bool(antipode), int(q_milli))
    cached = _PANO_GRID_CACHE.get(key, None)
    if cached is not None:
        return cached

    # Output pixel -> base (u, v) over the ORIGINAL image; padding pixels
    # fall outside [0, 1], which the spherical wrap handles naturally.
    cols = torch.arange(Wp, device=device, dtype=dtype)
    rows = torch.arange(Hp, device=device, dtype=dtype)
    u = (cols - pw) / float(max(W - 1, 1))
    v = (rows - ph) / float(max(H - 1, 1))
    u2 = u.view(1, Wp).expand(Hp, Wp)
    v2 = v.view(Hp, 1).expand(Hp, Wp)

    # lon in radians (wraps naturally via sin/cos); lat can exceed the poles.
    lon = (u2 - 0.5) * (2.0 * math.pi)
    lat = (0.5 - v2) * math.pi

    cm = str(coord_mode or "Cartesian (lon/lat)")
    do_polar = cm.startswith("Polar")

    # Twist: latitude-dependent longitude shift ("roll" feel along parallels).
    tr = float(twist_q) * (math.pi / 180.0)
    tp = float(max(0.25, min(float(twist_power_q), 4.0)))
    if abs(tr) > 1e-9:
        t = torch.clamp(torch.abs(lat) / (0.5 * math.pi), 0.0, 1.0)
        lon = lon + tr * torch.sign(lat) * torch.pow(t, tp)

    # Swirl: extra longitude shift, strongest near the poles by default.
    sr = float(swirl_q) * (math.pi / 180.0)
    sp = float(max(0.25, min(float(swirl_power_q), 4.0)))
    if abs(sr) > 1e-9:
        t = torch.clamp(torch.abs(lat) / (0.5 * math.pi), 0.0, 1.0)
        lon = lon + sr * torch.pow(t, sp)

    # Polar mode: radial warp around the poles via latitude remapping.
    if do_polar:
        ps = float(max(0.01, float(polar_scale_q)))
        pp = float(max(0.25, min(float(polar_power_q), 6.0)))
        # t = 0 at the equator, 1 at the poles
        t = torch.clamp(torch.abs(lat) / (0.5 * math.pi), 0.0, 1.0)
        r = 1.0 - t  # r = 1 at equator, 0 at poles
        r2 = torch.pow(torch.clamp(r * ps, min=0.0, max=1.0), pp)
        lat = torch.sign(lat) * (1.0 - r2) * (0.5 * math.pi)

    # (lon, lat) -> unit 3D direction
    cos_lat = torch.cos(lat)
    x = torch.sin(lon) * cos_lat
    y = torch.sin(lat)
    z = torch.cos(lon) * cos_lat

    # Global rotation (and optional antipode flip).
    R = _ypr_rotation_matrix(yaw_q, pitch_q, roll_q, device=device, dtype=dtype)
    dirs = torch.matmul(torch.stack([x, y, z], dim=-1), R.transpose(0, 1))
    if antipode:
        dirs = -dirs

    # Back to (lon, lat)
    lon2 = torch.atan2(dirs[..., 0], dirs[..., 2])            # [-pi, pi]
    lat2 = torch.asin(torch.clamp(dirs[..., 1], -1.0, 1.0))   # [-pi/2, pi/2]

    # Pole easing: power curve on |lat| to soften pole compression.
    pe = float(max(0.25, min(float(pole_ease_q), 6.0)))
    if abs(pe - 1.0) > get_safe_epsilon(torch.float16):
        t = torch.clamp(torch.abs(lat2) / (0.5 * math.pi), 0.0, 1.0)
        lat2 = torch.sign(lat2) * torch.pow(t, pe) * (0.5 * math.pi)

    # Source UV in [0, 1) with horizontal wrap, then grid_sample coords.
    u_src = torch.remainder((lon2 / (2.0 * math.pi)) + 0.5, 1.0)
    v_src = 0.5 - (lat2 / math.pi)

    # Cast back to the caller's dtype only at the very end.
    grid = torch.stack([u_src * 2.0 - 1.0, v_src * 2.0 - 1.0], dim=-1).to(dtype)
    _PANO_GRID_CACHE[key] = grid
    return grid
def conv2d_panorama_gridsample(input_tensor, weight, bias, stride, dilation, groups,
                               pad_h, pad_w,
                               yaw=0.0, pitch=0.0, roll=0.0,
                               coord_mode="Cartesian (lon/lat)",
                               polar_scale=1.0, polar_power=1.0,
                               twist_deg=0.0, twist_power=1.0,
                               swirl_deg=0.0, swirl_power=1.0,
                               pole_ease_power=1.0,
                               grid_interp="bilinear", grid_padding="border",
                               cache_angle_quant=0.5,
                               geoaa_samples=1, geoaa_radius_px=0.0,
                               antipode_strength=0.0,
                               pole_blur_strength=0.0, pole_blur_radius=0, pole_blur_power=1.0):
    """
    Panorama Live Engine C:
    - Builds a padded panorama by sampling the original via 3D spherical mapping.
    - Runs conv2d without extra padding.
    - Optional Kohaku-style antipode mixing and pole blur smoothing.
    """
    ph = int(max(0, int(pad_h)))
    pw = int(max(0, int(pad_w)))
    if ph <= 0 and pw <= 0:
        return F.conv2d(input_tensor, weight, bias, stride, (0, 0), dilation, groups)

    B, C, H, W = input_tensor.shape
    device = input_tensor.device
    # Prefer a float32 grid when the input is half precision.
    grid_dtype = input_tensor.dtype
    if grid_dtype in (torch.float16, torch.bfloat16):
        grid_dtype = torch.float32

    def _make_grid(flip):
        # Shared builder call; `flip` selects the antipodal mapping.
        return _build_panorama_engine_c_grid(
            H, W, ph, pw,
            yaw=yaw, pitch=pitch, roll=roll,
            coord_mode=coord_mode,
            polar_scale=polar_scale, polar_power=polar_power,
            twist_deg=twist_deg, twist_power=twist_power,
            swirl_deg=swirl_deg, swirl_power=swirl_power,
            pole_ease_power=pole_ease_power,
            antipode=flip,
            angle_quant=cache_angle_quant,
            device=device, dtype=grid_dtype,
        )

    grid = _make_grid(False)
    if grid is None:
        return F.conv2d(input_tensor, weight, bias, stride, (pad_h, pad_w), dilation, groups)

    Hp = H + 2 * ph
    Wp = W + 2 * pw

    def _sample(g):
        batched = g.unsqueeze(0).expand(B, Hp, Wp, 2).contiguous()
        return _grid_sample_geoaa(input_tensor, batched,
                                  samples=geoaa_samples, radius_px=geoaa_radius_px,
                                  mode=grid_interp, padding_mode=grid_padding)

    canvas = _sample(grid)

    mix = float(max(0.0, min(float(antipode_strength), 1.0)))
    if mix > 0.0:
        canvas = canvas * (1.0 - mix) + _sample(_make_grid(True)) * mix

    # Optional pole blur
    canvas = _apply_pole_blur_smoothing(canvas,
                                        strength=pole_blur_strength,
                                        radius=pole_blur_radius,
                                        power=pole_blur_power)

    return F.conv2d(canvas, weight, bias, stride, (0, 0), dilation, groups)
def compute_stereoscopic_padding(input_tensor, pad_h, pad_w, eye='left',
                                 convergence=0.05, separation=0.065):
    """
    Stereoscopic padding for 3D images.

    eye: 'left', 'right' or 'both' (anything else falls into the symmetric
    'both' branch). The tensor is horizontally rolled by ~separation*width
    and blended with the original using a depth proxy derived from the
    horizontal distance to the convergence point, then circularly padded.
    """
    b, c, h, w = input_tensor.shape

    eye = (eye or 'left').lower()
    shift_amount = int(w * separation)

    # Depth proxy: distance of each column from the convergence point,
    # clamped to [0, 1] and used as the blend weight toward the shifted view.
    x_coords = torch.linspace(
        0.0, 1.0, w,
        device=input_tensor.device,
        dtype=input_tensor.dtype
    ).view(1, 1, 1, w)
    depth_map = torch.abs(x_coords - convergence).expand(b, c, h, w)
    alpha = depth_map.clamp(0.0, 1.0)

    if eye == 'left':
        shifted = torch.roll(input_tensor, shifts=shift_amount, dims=3)
        stereo_adjusted = input_tensor * (1.0 - alpha) + shifted * alpha
    elif eye == 'right':
        shifted = torch.roll(input_tensor, shifts=-shift_amount, dims=3)
        stereo_adjusted = input_tensor * (1.0 - alpha) + shifted * alpha
    else:
        # 'both' — symmetric mode: average of left and right shifts
        shifted_left = torch.roll(input_tensor, shifts=shift_amount, dims=3)
        shifted_right = torch.roll(input_tensor, shifts=-shift_amount, dims=3)
        shifted_avg = 0.5 * (shifted_left + shifted_right)
        stereo_adjusted = input_tensor * (1.0 - alpha) + shifted_avg * alpha

    padded = F.pad(stereo_adjusted, (pad_w, pad_w, pad_h, pad_h), mode='circular')
    return padded

def compute_hex_padding_x(input_tensor, pad_l, pad_r):
    """
    Hexagonal (staggered) horizontal padding (from v2.0).

    Odd rows (per get_or_create_mask) are sampled half-width shifted so the
    pad produces a staggered/hex tiling at the left/right edges.
    """
    b, c, h, w = input_tensor.shape
    odd_mask = get_or_create_mask(h, w, input_tensor.device).expand(b, c, h, w)

    shift = w // 2
    input_shifted = torch.roll(input_tensor, shifts=shift, dims=3)
    source = torch.where(odd_mask, input_shifted, input_tensor)

    # FIX: a slice like source[..., -0:] returns the FULL tensor rather than
    # an empty pad, so zero-width pads must be skipped explicitly.
    pieces = []
    if pad_l > 0:
        pieces.append(source[:, :, :, -pad_l:])
    pieces.append(input_tensor)
    if pad_r > 0:
        pieces.append(source[:, :, :, :pad_r])

    return torch.cat(pieces, dim=3)
def ensure_output_size_match(output_tensor, expected_h, expected_w):
    """
    Force a 4D tensor (B, C, H, W) to exactly (expected_h, expected_w).

    Critical for VAE/UNet ResNet blocks with skip connections, where a
    one-pixel size drift raises "size of tensor a must match tensor b".
    Oversized tensors are center-cropped; undersized ones get symmetric
    replicate padding; any remaining mismatch falls back to bilinear
    interpolation. Returns None for None input, and the tensor unchanged
    when sizes already match.
    """
    if output_tensor is None:
        return None

    _, _, cur_h, cur_w = output_tensor.shape
    if cur_h == expected_h and cur_w == expected_w:
        return output_tensor

    extra_h = cur_h - expected_h
    extra_w = cur_w - expected_w

    if extra_h > 0 or extra_w > 0:
        # Too large in at least one dimension: centered crop, guarded
        # against running past the tensor bounds.
        top = max(0, extra_h // 2)
        left = max(0, extra_w // 2)
        bottom = min(top + expected_h, cur_h)
        right = min(left + expected_w, cur_w)
        result = output_tensor[:, :, top:bottom, left:right]
    else:
        # Too small: symmetric replicate padding.
        need_h = -extra_h
        need_w = -extra_w
        top_pad = need_h // 2
        left_pad = need_w // 2
        result = F.pad(
            output_tensor,
            (left_pad, need_w - left_pad, top_pad, need_h - top_pad),
            mode='replicate',
        )

    # Last resort: interpolate if crop/pad could not hit the target exactly
    # (e.g. one dim too big while the other is too small).
    if result.shape[2] != expected_h or result.shape[3] != expected_w:
        result = F.interpolate(result, size=(expected_h, expected_w),
                               mode='bilinear', align_corners=False)
    return result
def compute_expected_output_size(input_h, input_w, kernel_size, stride, padding, dilation):
    """
    Standard conv2d output-size formula:
        out = (in + 2*pad - dilation*(kernel-1) - 1) // stride + 1

    kernel_size / stride / dilation may be ints or (h, w) tuples; padding
    may be an int, an (h, w) tuple/list, or anything else (treated as zero
    padding, e.g. string modes handled by the caller). Returns (out_h, out_w).
    """
    def _pair(value):
        return value if isinstance(value, tuple) else (value, value)

    k_h, k_w = _pair(kernel_size)
    s_h, s_w = _pair(stride)
    d_h, d_w = _pair(dilation)

    if isinstance(padding, (tuple, list)):
        p_h, p_w = padding[0], padding[1]
    elif isinstance(padding, int):
        p_h = p_w = padding
    else:
        p_h = p_w = 0

    out_h = (input_h + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1
    out_w = (input_w + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1
    return out_h, out_w
Гарантируем совместимость с skip connections в UNet + """ + try: + # ===== ШАГ 1: Сохраняем оригинальные размеры ===== + b, c, orig_h, orig_w = input_tensor.shape + + # 🔧 FIX #1: Создаем унифицированную переменную для результата + tiled_output = None + + # ===== ШАГ 2: Вычисляем ожидаемый размер ВЫХОДНОГО тензора ===== + # Это размер, который должен получиться после свертки с оригинальным padding + k_h, k_w = weight.shape[2], weight.shape[3] + + # Обрабатываем разные форматы padding + if isinstance(padding, str): + if padding == 'same': + # Для 'same' padding выход должен совпадать с входом (при stride=1) + expected_out_h = orig_h + expected_out_w = orig_w + elif padding == 'valid': + # Для 'valid' padding + s_h, s_w = stride if isinstance(stride, tuple) else (stride, stride) + d_h, d_w = dilation if isinstance(dilation, tuple) else (dilation, dilation) + expected_out_h = ((orig_h - d_h*(k_h-1) - 1) // s_h) + 1 + expected_out_w = ((orig_w - d_w*(k_w-1) - 1) // s_w) + 1 + else: + # Fallback + expected_out_h = orig_h + expected_out_w = orig_w + else: + # Вычисляем для числового padding + p_val = padding if isinstance(padding, int) else padding[0] + expected_out_h, expected_out_w = compute_expected_output_size( + orig_h, orig_w, (k_h, k_w), stride, (p_val, p_val), dilation + ) + + # ===== ШАГ 3: Проверяем активность тайлинга ===== + current_step = getattr(shared.state, 'sampling_step', 0) if 'shared' in dir() else 0 + start_step = params.get('start_step', 0) + end_step = params.get('end_step', 9999) + + _in_range = (start_step <= current_step < end_step) + _past_end = (current_step >= end_step) + + if not _in_range: + # 🔒 Lock after End: keep tiling active past end_step + if _past_end and params.get('tiling_lock_after_end', False): + pass # fall through to tiling below + else: + # Тайлинг неактивен - используем стандартную свертку + return F.conv2d(input_tensor, weight, bias, stride, padding, dilation, groups) + + # 🎯 Pan One-Shot: apply torch.roll on the very first 
step of the active range + if params.get('tiling_pan_one_shot', False) and current_step == start_step: + script_ref = params.get('script_ref', None) + _already_panned = getattr(script_ref, '_tiling_pan_one_shot_done', False) if script_ref else True + if not _already_panned and script_ref is not None: + _px = params.get('pan_x', 0.0) + _py = params.get('pan_y', 0.0) + if _px or _py: + _h, _w = input_tensor.shape[-2], input_tensor.shape[-1] + _sh = int(round(_py * _h)) + _sw = int(round(_px * _w)) + input_tensor = torch.roll(input_tensor, shifts=(_sh, _sw), dims=(-2, -1)) + script_ref._tiling_pan_one_shot_done = True + + # Проверка "Disable Advanced Tiling during hires pass" + script = params.get('script_ref', None) + if params.get('tiling_disable_hr', False) and script is not None: + if getattr(script, 'tiling_enable_hr', False) and getattr(script, 'tiling_is_hires', False): + return F.conv2d(input_tensor, weight, bias, stride, padding, dilation, groups) + + # ===== ШАГ 4: Вычисляем требуемый padding ===== + d_h, d_w = (dilation, dilation) if isinstance(dilation, int) else dilation + + if isinstance(padding, str): + if padding == 'same': + req_pad_h = ((k_h - 1) * d_h) // 2 + req_pad_w = ((k_w - 1) * d_w) // 2 + elif padding == 'valid': + req_pad_h = req_pad_w = 0 + else: + req_pad_h = req_pad_w = 1 + elif isinstance(padding, int): + req_pad_h = req_pad_w = padding + elif isinstance(padding, (tuple, list)): + req_pad_h, req_pad_w = padding[0], padding[1] + else: + req_pad_h = req_pad_w = 0 + + # FIX V3.7: Не выходим рано если нет паддинга и нет других режимов + if req_pad_h == 0 and req_pad_w == 0: + return F.conv2d(input_tensor, weight, bias, stride, padding, dilation, groups) + + mode_x = params.get('mode_x', MODE_OFF) + mode_y = params.get('mode_y', MODE_OFF) + + if mode_x == MODE_CUBEMAP or mode_y == MODE_CUBEMAP: + cubemap_engine = params.get('cubemap_engine', 'A (Fast)') + cubemap_pad_mode = params.get('cubemap_pad_mode', 'replicate') + seam_w = 
int(params.get('cubemap_blend_width', 0) or 0) + seam_s = float(params.get('cubemap_blend_strength', 0.0) or 0.0) + + if str(cubemap_engine).startswith('C'): + # Engine C: grid_sample cubemap + yaw = float(params.get('cubemap_yaw', 0.0) or 0.0) + pitch = float(params.get('cubemap_pitch', 0.0) or 0.0) + roll = float(params.get('cubemap_roll', 0.0) or 0.0) + coord_mode = params.get('cubemap_coord_mode', 'Cartesian (Face UV)') + twist_deg = float(params.get('cubemap_twist_deg', 0.0) or 0.0) + polar_scale = float(params.get('cubemap_polar_scale', 1.0) or 1.0) + polar_power = float(params.get('cubemap_polar_power', 1.0) or 1.0) + swirl_deg = float(params.get('cubemap_swirl_deg', 0.0) or 0.0) + swirl_power = float(params.get('cubemap_swirl_power', 1.0) or 1.0) + grid_interp = params.get('cubemap_grid_interp', 'bilinear') + grid_padding = params.get('cubemap_grid_padding', 'border') + cache_angle_quant = float(params.get('cubemap_cache_angle_quant', 0.5) or 0.5) + geoaa_samples = int(params.get('cubemap_geoaa_samples', 1) or 1) + geoaa_radius = float(params.get('cubemap_geoaa_radius_px', 0.0) or 0.0) + antipode_strength = float(params.get('cubemap_antipode_strength', 0.0) or 0.0) + + # 🔧 FIX #1: Используем tiled_output + tiled_output = conv2d_cubemap_gridsample( + input_tensor, weight, bias, + stride, dilation, groups, + req_pad_h, req_pad_w, + yaw=yaw, pitch=pitch, roll=roll, + coord_mode=coord_mode, twist_deg=twist_deg, + polar_scale=polar_scale, polar_power=polar_power, + swirl_deg=swirl_deg, swirl_power=swirl_power, + grid_interp=grid_interp, grid_padding=grid_padding, + cache_angle_quant=cache_angle_quant, + geoaa_samples=geoaa_samples, geoaa_radius_px=geoaa_radius, + antipode_strength=antipode_strength + ) + else: + # Engine A/B: batched cubemap + # 🔧 FIX #1: Используем tiled_output + tiled_output = conv2d_cubemap_batched( + input_tensor, weight, bias, + stride, dilation, groups, + req_pad_h, req_pad_w, + pad_mode=cubemap_pad_mode, + engine=cubemap_engine, + 
seam_width=seam_w, + seam_strength=seam_s + ) + + # ✅ ИСПРАВЛЕНО: Multi-resolution для Cubemap (output-space) + if params.get('multires_enabled', False): + multires_params = validate_multires_params(params) + + # Создаем простую версию (output space) + x_default = F.pad(input_tensor, (req_pad_w, req_pad_w, req_pad_h, req_pad_h), + mode='replicate') + out_default = F.conv2d(x_default, weight, bias, stride, 0, dilation, groups) + + # Применяем смешивание к OUTPUTS + # 🔧 FIX #1: Используем tiled_output + + tiled_output = apply_multires_blend( + tensor_simple=out_default, + tensor_advanced=tiled_output, # ← ИСПРАВЛЕНО: используем tiled_output! + current_step=current_step, + start_step=params['start_step'], + end_step=params['end_step'], + strategy=multires_params['strategy'], + transition_start=multires_params['transition_start'], + transition_end=multires_params['transition_end'], + sharpness=multires_params['sharpness'], + enabled=True + ) + + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + tiled_output = ensure_output_size_match(tiled_output, expected_out_h, expected_out_w) + return tiled_output + + # ────────── PANORAMA ────────── + elif mode_x == MODE_PANORAMA or mode_y == MODE_PANORAMA: + pano_engine = params.get('panorama_engine', 'A (Legacy)') + pano_engine_str = str(pano_engine) + + # Engine C: true spherical mapping via grid_sample + if pano_engine_str.startswith('C'): + yaw = float(params.get('panorama_yaw', 0.0) or 0.0) + pitch = float(params.get('panorama_pitch', 0.0) or 0.0) + roll = float(params.get('panorama_roll', 0.0) or 0.0) + coord_mode = params.get('panorama_coord_mode', 'Cartesian (lon/lat)') + polar_scale = float(params.get('panorama_polar_scale', 1.0) or 1.0) + polar_power = float(params.get('panorama_polar_power', 1.0) or 1.0) + twist_deg = float(params.get('panorama_twist_deg', 0.0) or 0.0) + twist_power = float(params.get('panorama_twist_power', 1.0) or 1.0) + swirl_deg = float(params.get('panorama_swirl_deg', 0.0) or 0.0) + swirl_power = 
float(params.get('panorama_swirl_power', 1.0) or 1.0) + pole_ease_power = float(params.get('panorama_pole_ease_power', 1.0) or 1.0) + grid_interp = params.get('panorama_grid_interp', 'bilinear') + grid_padding = params.get('panorama_grid_padding', 'border') + cache_angle_quant = float(params.get('panorama_cache_angle_quant', 0.5) or 0.5) + geoaa_samples = int(params.get('panorama_geoaa_samples', 1) or 1) + geoaa_radius = float(params.get('panorama_geoaa_radius_px', 0.0) or 0.0) + antipode_strength = float(params.get('panorama_antipode_strength', 0.0) or 0.0) + pole_blur_strength = float(params.get('panorama_pole_blur_strength', 0.0) or 0.0) + pole_blur_radius = int(params.get('panorama_pole_blur_radius', 0) or 0) + pole_blur_power = float(params.get('panorama_pole_blur_power', 1.0) or 1.0) + + # 🔧 FIX #1: Используем tiled_output + tiled_output = conv2d_panorama_gridsample( + input_tensor, weight, bias, + stride, dilation, groups, + req_pad_h, req_pad_w, + yaw=yaw, pitch=pitch, roll=roll, + coord_mode=coord_mode, + polar_scale=polar_scale, polar_power=polar_power, + twist_deg=twist_deg, twist_power=twist_power, + swirl_deg=swirl_deg, swirl_power=swirl_power, + pole_ease_power=pole_ease_power, + grid_interp=grid_interp, grid_padding=grid_padding, + cache_angle_quant=cache_angle_quant, + geoaa_samples=geoaa_samples, geoaa_radius_px=geoaa_radius, + antipode_strength=antipode_strength, + pole_blur_strength=pole_blur_strength, pole_blur_radius=pole_blur_radius, + pole_blur_power=pole_blur_power + ) + + # ✅ УЖЕ ПРАВИЛЬНО: Multi-resolution для Panorama C (output-space) + if params.get('multires_enabled', False): + multires_params = validate_multires_params(params) + + x_default = F.pad(input_tensor, (req_pad_w, req_pad_w, req_pad_h, req_pad_h), + mode='replicate') + out_default = F.conv2d(x_default, weight, bias, stride, 0, dilation, groups) + + # 🔧 FIX #1: Используем tiled_output + + + tiled_output = apply_multires_blend( + tensor_simple=out_default, + 
tensor_advanced=tiled_output, + current_step=current_step, + start_step=params['start_step'], + end_step=params['end_step'], + strategy=multires_params['strategy'], + transition_start=multires_params['transition_start'], + transition_end=multires_params['transition_end'], + sharpness=multires_params['sharpness'], + enabled=True + ) + + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + tiled_output = ensure_output_size_match(tiled_output, expected_out_h, expected_out_w) + return tiled_output + + # Engine B: fast pole-correct padding + elif pano_engine_str.startswith('B'): + x = input_tensor + anti_s = float(params.get('panorama_antipode_strength', 0.0) or 0.0) + anti_s = float(max(0.0, min(anti_s, 1.0))) + if anti_s > 0.0: + anti = torch.roll(torch.flip(x, dims=[2]), shifts=x.shape[-1] // 2, dims=3) + x = x * (1.0 - anti_s) + anti * anti_s + + x = compute_polar_padding(x, req_pad_h, req_pad_w) + + pole_blur_strength = float(params.get('panorama_pole_blur_strength', 0.0) or 0.0) + pole_blur_radius = int(params.get('panorama_pole_blur_radius', 0) or 0) + pole_blur_power = float(params.get('panorama_pole_blur_power', 1.0) or 1.0) + x = _apply_pole_blur_smoothing(x, strength=pole_blur_strength, + radius=pole_blur_radius, power=pole_blur_power) + + # ✅ ИСПРАВЛЕНО: Применяем multires правильно + # 🔧 FIX #1: Используем tiled_output + + tiled_output = F.conv2d(x, weight, bias, stride, 0, dilation, groups) + + if params.get('multires_enabled', False): + multires_params = validate_multires_params(params) + + x_default = F.pad(input_tensor, (req_pad_w, req_pad_w, req_pad_h, req_pad_h), + mode='replicate') + out_default = F.conv2d(x_default, weight, bias, stride, 0, dilation, groups) + + # 🔧 FIX #1: Используем tiled_output + + + tiled_output = apply_multires_blend( + tensor_simple=out_default, + tensor_advanced=tiled_output, + current_step=current_step, + start_step=params['start_step'], + end_step=params['end_step'], + strategy=multires_params['strategy'], + 
transition_start=multires_params['transition_start'], + transition_end=multires_params['transition_end'], + sharpness=multires_params['sharpness'], + enabled=True + ) + + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + tiled_output = ensure_output_size_match(tiled_output, expected_out_h, expected_out_w) + return tiled_output + + # Engine A (legacy): simple circular wrap + else: + x = input_tensor + x = F.pad(x, (req_pad_w, req_pad_w, 0, 0), mode='circular') + x = F.pad(x, (0, 0, req_pad_h, req_pad_h), mode='circular') + + # 🔧 FIX #1: Используем tiled_output + + + tiled_output = F.conv2d(x, weight, bias, stride, 0, dilation, groups) + + if params.get('multires_enabled', False): + multires_params = validate_multires_params(params) + + x_default = F.pad(input_tensor, (req_pad_w, req_pad_w, req_pad_h, req_pad_h), + mode='replicate') + out_default = F.conv2d(x_default, weight, bias, stride, 0, dilation, groups) + + # 🔧 FIX #1: Используем tiled_output + + + tiled_output = apply_multires_blend( + tensor_simple=out_default, + tensor_advanced=tiled_output, + current_step=current_step, + start_step=params['start_step'], + end_step=params['end_step'], + strategy=multires_params['strategy'], + transition_start=multires_params['transition_start'], + transition_end=multires_params['transition_end'], + sharpness=multires_params['sharpness'], + enabled=True + ) + + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + tiled_output = ensure_output_size_match(tiled_output, expected_out_h, expected_out_w) + return tiled_output + + # ═══════════════════════════════════════════════════════════════════ + # ОБЫЧНЫЕ РЕЖИМЫ (Circular, Mirror, Hexagonal, Voronoi, и т.д.) 
+ # ═══════════════════════════════════════════════════════════════════ + + x = input_tensor + + # ────────── NEW ENHANCED MODES ────────── + if mode_x == MODE_VORONOI or mode_y == MODE_VORONOI: + voronoi_cells = int(params.get('voronoi_cells', 8) or 8) + voronoi_seed = int(params.get('voronoi_seed', 42) or 42) + x = compute_voronoi_padding(x, req_pad_h, req_pad_w, voronoi_cells, voronoi_seed) + + elif mode_x == MODE_PERLIN or mode_y == MODE_PERLIN: + perlin_strength = float(params.get('perlin_strength', 0.3) or 0.3) + perlin_scale = float(params.get('perlin_scale', 10.0) or 10.0) + x = compute_perlin_padding(x, req_pad_h, req_pad_w, perlin_strength, perlin_scale) + + elif mode_x == MODE_FRACTAL or mode_y == MODE_FRACTAL: + fractal_iterations = int(params.get('fractal_iterations', 2) or 2) + fractal_scale = float(params.get('fractal_scale', 0.6) or 0.6) + x = compute_fractal_padding(x, req_pad_h, req_pad_w, fractal_iterations, fractal_scale) + + elif mode_x == MODE_ADAPTIVE or mode_y == MODE_ADAPTIVE: + adaptive_threshold = float(params.get('adaptive_threshold', 0.1) or 0.1) + x = compute_adaptive_padding(x, req_pad_h, req_pad_w, adaptive_threshold) + + elif mode_x == MODE_POLAR or mode_y == MODE_POLAR: + # Polar: X=circular (longitude wraps), Y=pole-correct (latitude flips) + # compute_polar_padding handles both axes together + x = compute_polar_padding(x, req_pad_h, req_pad_w) + + elif mode_x == MODE_ANISOTROPIC or mode_y == MODE_ANISOTROPIC: + # Anisotropic: directional padding blending circular + reflect at angle + x = compute_anisotropic_padding( + x, + req_pad_h, + req_pad_w, + angle_deg=float(params.get('anisotropic_angle', 45) or 45), + angle_deg2=params.get('anisotropic_angle2', None), + angle_mix=float(params.get('anisotropic_angle_mix', 1.0) or 1.0), + ) + + # ────────── STANDARD MODES (раздельно по осям) ────────── + else: + # ✅ ИСПРАВЛЕНО: Blend mode применяется ВМЕСТО обычного padding + blend_enabled = params.get('blend_enabled', False) + + if 
blend_enabled and (req_pad_h > 0 or req_pad_w > 0): + # Blend mode ЗАМЕНЯЕТ обычный padding + blend_params = validate_blend_params(params) + + # Определяем какой advanced режим использовать. + # Polar и Anisotropic используют reflect как наиболее безопасный + # промежуточный fallback — он менее агрессивен, чем circular. + # Hexagonal по Y ведёт себя как circular, по X — смещённый wrap. + if mode_x == MODE_CIRCULAR or mode_y == MODE_CIRCULAR: + mode_advanced_str = 'circular' + elif mode_x == MODE_MIRROR or mode_y == MODE_MIRROR: + mode_advanced_str = 'reflect' + elif mode_x == MODE_HEXAGONAL or mode_y == MODE_HEXAGONAL: + mode_advanced_str = 'circular' # hex стагер близок к circular + elif mode_x == MODE_POLAR or mode_y == MODE_POLAR: + mode_advanced_str = 'reflect' # pole edges: reflect < circular artifacts + elif mode_x == MODE_ANISOTROPIC or mode_y == MODE_ANISOTROPIC: + mode_advanced_str = 'reflect' # directional: reflect более нейтрален + else: + mode_advanced_str = 'circular' # fallback: Voronoi/Perlin/Fractal/Adaptive + + # Применяем улучшенный blend mode + x = compute_advanced_blend_padding( + input_tensor, # ← Исходный тензор БЕЗ padding + pad_h=req_pad_h, + pad_w=req_pad_w, + mode_simple='replicate', + mode_advanced=mode_advanced_str, + blend_strength=blend_params['strength'], + blend_width=blend_params['width'], + falloff_curve=blend_params['falloff'], + edge_sharpness=blend_params['sharpness'] + ) + + else: + # Обычный padding без blend + # Ось Y + if mode_y == MODE_CIRCULAR: + x = F.pad(x, (0, 0, req_pad_h, req_pad_h), mode='circular') + elif mode_y == MODE_MIRROR: + x = _safe_pad4d(x, (0, 0, req_pad_h, req_pad_h), mode='reflect') + elif mode_y == MODE_HEXAGONAL: + x = F.pad(x, (0, 0, req_pad_h, req_pad_h), mode='circular') + else: + x = F.pad(x, (0, 0, req_pad_h, req_pad_h), mode='constant', value=0) + + # Ось X + if mode_x == MODE_CIRCULAR: + x = F.pad(x, (req_pad_w, req_pad_w, 0, 0), mode='circular') + elif mode_x == MODE_MIRROR: + x = _safe_pad4d(x, 
(req_pad_w, req_pad_w, 0, 0), mode='reflect') + elif mode_x == MODE_HEXAGONAL: + x = compute_hex_padding_x(x, req_pad_w, req_pad_w) + else: + x = F.pad(x, (req_pad_w, req_pad_w, 0, 0), mode='constant', value=0) + + # ═══════════════════════════════════════════════════════════════════ + # CONVOLUTION + MULTI-RESOLUTION (унифицированно для всех режимов) + # ═══════════════════════════════════════════════════════════════════ + + # x теперь содержит padded тензор (с blend или без) + + # 🔥 ЖЕЛЕЗНАЯ ЗАЩИТА ТИПОВ (Для Voronoi, Perlin, Fractal и остальных) + # Если веса модели в float16, а паддинг вернул float32 — конвертируем. + if x.dtype != weight.dtype: + x = x.to(dtype=weight.dtype) + + out_advanced = F.conv2d(x, weight, bias, stride, 0, dilation, groups) + + # ✅ ИСПРАВЛЕНО: Multi-resolution для обычных режимов + if params.get('multires_enabled', False): + multires_params = validate_multires_params(params) + + # Создаем простую версию (output space) + x_default = F.pad(input_tensor, (req_pad_w, req_pad_w, req_pad_h, req_pad_h), + mode='replicate') + out_default = F.conv2d(x_default, weight, bias, stride, 0, dilation, groups) + + # Смешиваем outputs + out_final = apply_multires_blend( + tensor_simple=out_default, + tensor_advanced=out_advanced, # ← ИСПРАВЛЕНО: правильная переменная! + current_step=current_step, + start_step=params['start_step'], + end_step=params['end_step'], + strategy=multires_params['strategy'], + transition_start=multires_params['transition_start'], + transition_end=multires_params['transition_end'], + sharpness=multires_params['sharpness'], + enabled=True + ) + + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + out_final = ensure_output_size_match(out_final, expected_out_h, expected_out_w) + return out_final # ← ИСПРАВЛЕНО: возвращаем с multires! 
+ else: + # ✅ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ: корректируем размеры + out_advanced = ensure_output_size_match(out_advanced, expected_out_h, expected_out_w) + return out_advanced # ← Без multires + + except Exception as e: + print(f"Advanced Tiling v3.1 Error: {e}") + import traceback + traceback.print_exc() + return F.conv2d(input_tensor, weight, bias, stride, padding, dilation, groups) + +# ======================================================================== +# РЕГИСТРАЦИЯ KOHAKU SAMPLER +# ======================================================================== + + +# ======================================================================== +# STEREOSCOPIC 3D (postprocess helper) +# ======================================================================== + +def _pil_to_float_tensor(img: Image.Image) -> torch.Tensor: + """PIL -> torch float tensor (1,C,H,W) in [0,1] on CPU.""" + if img.mode != "RGB": + img = img.convert("RGB") + arr = np.array(img).astype(np.float32) / 255.0 # (H,W,3) + t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) # (1,3,H,W) + return t + + +def _float_tensor_to_pil(t: torch.Tensor) -> Image.Image: + """torch float tensor (1,C,H,W) in [0,1] -> PIL RGB.""" + t = t.detach().clamp(0.0, 1.0) + arr = (t.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.0).round().astype(np.uint8) + return Image.fromarray(arr, mode="RGB") + + +def _stereo_warp_tensor(t: torch.Tensor, + eye: str, + engine: str, + separation: float, + convergence: float, + depth_power: float = 1.0, + pole_power: float = 1.0) -> torch.Tensor: + """ + Apply a lightweight stereo warp to an RGB tensor (1,3,H,W) in equirectangular space. 
+ - Engine A: legacy roll+blend + - Engine B: spherical warp via grid_sample with longitude wrap and pole attenuation + """ + eye = str(eye or "left").lower() + sign = 1.0 if eye == "left" else -1.0 + + separation = float(max(0.0, min(float(separation), 0.5))) + convergence = float(max(0.0, min(float(convergence), 1.0))) + depth_power = float(max(0.25, min(float(depth_power), 6.0))) + pole_power = float(max(0.25, min(float(pole_power), 6.0))) + + B, C, H, W = t.shape + if W <= 1 or H <= 1 or separation <= 0.0: + return t + + # base u,v in [0,1] + u = torch.linspace(0.0, 1.0, steps=W, dtype=torch.float32).view(1, 1, 1, W).expand(B, 1, H, W) + v = torch.linspace(0.0, 1.0, steps=H, dtype=torch.float32).view(1, 1, H, 1).expand(B, 1, H, W) + + depth = torch.abs(u - convergence) + depth = torch.pow(torch.clamp(depth, 0.0, 1.0), depth_power) + + if str(engine).startswith("A"): + # Legacy roll+blend + shift_px = int(round(W * separation)) + if shift_px == 0: + return t + rolled = torch.roll(t, shifts=int(round(sign * shift_px)), dims=3) + alpha = depth.expand_as(t) + return t * (1.0 - alpha) + rolled * alpha + + # Engine B: wrap-safe grid_sample warp + lat = (0.5 - v) * math.pi + attn = torch.pow(torch.cos(lat).clamp(0.0, 1.0), pole_power) + + shift_u = sign * separation * depth * attn + u2 = torch.remainder(u + shift_u, 1.0) + + # grid_sample normalized coords + x = u2 * 2.0 - 1.0 + y = v * 2.0 - 1.0 + grid = torch.cat([x, y], dim=1).permute(0, 2, 3, 1).contiguous() # (B,H,W,2) + + return F.grid_sample(t, grid, mode="bilinear", padding_mode="border", align_corners=True) + + +def _make_stereo_outputs(img: Image.Image, settings: dict): + """Returns (left,right) PIL images for the given input.""" + t = _pil_to_float_tensor(img) + engine = settings.get("engine", "B (Spherical warp)") + separation = settings.get("separation", 0.065) + convergence = settings.get("convergence", 0.5) + depth_power = settings.get("depth_power", 1.0) + pole_power = settings.get("pole_power", 1.0) + 
+ left_t = _stereo_warp_tensor(t, "left", engine, separation, convergence, depth_power, pole_power) + right_t = _stereo_warp_tensor(t, "right", engine, separation, convergence, depth_power, pole_power) + + return _float_tensor_to_pil(left_t), _float_tensor_to_pil(right_t) + +def register_kohaku_sampler(): + """Регистрирует Kohaku_LoNyu_Yog сэмплер в WebUI""" + global _SAMPLER_REGISTERED + + if _SAMPLER_REGISTERED: + return + + if any(s.name == 'Kohaku_LoNyu_Yog' for s in sd_samplers.all_samplers): + _SAMPLER_REGISTERED = True + return + + if not hasattr(k_diffusion.sampling, 'sample_kohaku_lonyu_yog'): + setattr(k_diffusion.sampling, 'sample_kohaku_lonyu_yog', sample_kohaku_lonyu_yog) + + sampler_data = sd_samplers_common.SamplerData( + name='Kohaku_LoNyu_Yog', + constructor=lambda model: sd_samplers_kdiffusion.KDiffusionSampler('sample_kohaku_lonyu_yog', model), + aliases=['kohaku', 'lonyu'], + options={'second_order': True} + ) + + sd_samplers.all_samplers.append(sampler_data) + sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers} + + _SAMPLER_REGISTERED = True + print("✓ Kohaku_LoNyu_Yog sampler registered successfully!") + +# ======================================================================== +# CFG DRIFT CORRECTION +# Источник: adept_sampler_v4_COMPLETE.py / ComfyUI-Latent-Modifiers +# ======================================================================== + +@torch.no_grad() +def apply_combat_cfg_drift(latent: torch.Tensor, + method: str = 'mean', + intensity: float = 1.0) -> torch.Tensor: + """ + Убирает CFG-induced mean drift из латента. + + При высоком CFG (> 7) среднее латента постепенно смещается от нуля, + что вызывает цветовые сдвиги и пересвет/недосвет. Эта функция + вычитает это смещение пропорционально intensity. 
+ + Args: + latent: Латентный тензор [B, C, H, W] + method: 'mean' — быстро, глобальное среднее по всему батчу + 'median' — устойчив к выбросам, чуть медленнее + intensity: 0.0 = ничего не делать, 1.0 = убрать весь дрейф + """ + if intensity <= 0.0: + return latent + try: + if method == 'median': + center = latent.view(latent.shape[0], -1).median(dim=-1, keepdim=True)[0] + center = center.view(-1, 1, 1, 1) + else: + center = latent.mean(dim=(1, 2, 3), keepdim=True) + return latent - center * float(intensity) + except Exception as e: + print(f"[CFG Drift] ошибка: {e}") + return latent + + +# ======================================================================== +# КЛАСС СКРИПТА (Advanced Tiling + Latent Mirroring) +# ======================================================================== + +class AdvancedTilingScriptV3(scripts.Script): + def title(self): + return "Advanced Tiling v3.0 PRO (Kohaku + Stereo + Aniso + Latent Mirror)" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + with gr.Accordion("🚀 Advanced Tiling v3.0 PRO Edition", open=False): + with gr.Row(): + enabled = gr.Checkbox(label="Enable Tiling", value=False) + use_kohaku = gr.Checkbox( + label="Use Kohaku_LoNyu_Yog Sampler", + value=False, + info="Geometric second-order method for better seamless quality" + + ) + + with gr.Tabs(): + with gr.Tab("🎨 Basic Modes"): + with gr.Row(): + mode_x = gr.Dropdown( + label="Mode X", + choices=[MODE_OFF, MODE_CIRCULAR, MODE_MIRROR, MODE_HEXAGONAL, + MODE_VORONOI, MODE_PERLIN, MODE_FRACTAL, MODE_ADAPTIVE], + value=MODE_CIRCULAR, + info="NEW: Voronoi/Perlin/Fractal/Adaptive modes available!" + + ) + mode_y = gr.Dropdown( + label="Mode Y", + choices=[MODE_OFF, MODE_CIRCULAR, MODE_MIRROR, MODE_HEXAGONAL, + MODE_VORONOI, MODE_PERLIN, MODE_FRACTAL, MODE_ADAPTIVE], + value=MODE_CIRCULAR, + info="NEW: Voronoi/Perlin/Fractal/Adaptive modes available!" 
+ ) + + # Shown when any special override is active (Panorama/Cubemap/Polar/Anisotropic) + special_mode_notice = gr.Markdown( + "ℹ️ Basic / Enhanced modes are active", + visible=True + ) + + with gr.Row(): + multires = gr.Checkbox( + label="Multi-Resolution Mode", + value=False, + info="Gradually transition from simple to advanced tiling" + + ) + # 🚀 ADVANCED ZOOM & PROXIMITY SYSTEM + with gr.Row(): + use_zoom = gr.Checkbox(label="🔍 Advanced Zoom System (NEW!)", value=False + ) + use_blend = gr.Checkbox(label="Advanced Blend Mode", value=False + ) + + with gr.Accordion("🔍 Advanced Zoom Settings", open=False): + with gr.Row(): + zoom_factor = gr.Slider(-9.0, 20.0, step=0.1, label="🎯 Zoom Factor", value=0.0 + ) + with gr.Row(): + zoom_mode = gr.Dropdown(label="Zoom Mode", + choices=["outpaint_zoom", "blend_transition", "convergence_shift", + "grid_warp", "hybrid", "spiral_zoom"], + value="grid_warp", + info="geometry: grid_warp / spiral_zoom / convergence_shift | canvas: outpaint_zoom / blend_transition / hybrid" + ) + zoom_engine = gr.Dropdown(label="Zoom Engine", + choices=["auto", "geometry", "canvas"], + value="auto", + info="auto = выбрать backend по Zoom Mode" + ) + with gr.Row(): + zoom_path = gr.Radio( + label="Zoom Path", + choices=["Latent", "Conv [Experimental]", "Both [Experimental]"], + value="Latent", + info=( + "Latent = варп params.x (стабильно, поддерживает все режимы). " + "Conv = scale+crop в каждом conv слое — поддерживает только zoom_factor + strength. " + "Both = оба пути одновременно." + ) + ) + with gr.Row(): + conv_zoom_strength = gr.Slider( + 0.0, 1.0, step=0.05, + label="Conv Zoom Strength", + value=0.5, + visible=False, + info="Общая сила conv-path зума. 1.0 = полная сила по stage-весам." 
+ ) + with gr.Row(): + blend_mode_type = gr.Dropdown(label="Blend Mode Type", + choices=["circular_reflect", "circular_constant", "reflect_constant", + "polar_circular", "mirror_circular", "aniso_circular", + "custom", "gradient_radial", "noise_blend"], + value="circular_reflect" + ) + interp_mode = gr.Dropdown(choices=["bilinear", "bicubic", "nearest"], + value="bilinear", + label="Interpolation", + info="bicubic = лучшее качество (требует PyTorch ≥1.10)" + ) + + # ── Geometry Settings ────────────────────────────── + with gr.Group() as zoom_geometry_group: + gr.Markdown("### 🔷 Geometry Settings *(grid_warp · spiral_zoom · convergence_shift)*") + with gr.Row(): + zoom_convergence = gr.Slider(0.0, 1.0, step=0.05, label="Convergence Point X", value=0.5) + zoom_depth_power = gr.Slider(0.25, 4.0, step=0.05, label="Depth Curve Power", value=1.0, + info="Кривая усиления эффекта convergence_shift") + with gr.Row(): + auto_clamp_pan = gr.Checkbox(label="Auto Clamp Pan", value=True, + info="Не даёт уйти контенту за кадр при сильном pan") + + # ── Canvas Settings ──────────────────────────────── + with gr.Group() as zoom_canvas_group: + gr.Markdown("### 🎨 Canvas Settings *(outpaint_zoom · blend_transition · hybrid)*") + with gr.Row(): + noise_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, + label="Noise Strength", + info="Сила coherent latent fill в пустых зонах") + adaptive_noise_scale = gr.Checkbox(label="Adaptive Noise", value=True, + info="Дальше от контента → сильнее fill") + with gr.Row(): + edge_fade = gr.Checkbox(label="Edge Fade (fill zones)", value=True, + info="Затухание fill-зоны у краёв кадра. 
Работает только когда есть fill (zoom-out / outpaint).") + variance_correction = gr.Checkbox(label="Variance Correction", value=True, + info="Убирает серость на швах") + with gr.Row(): + zoom_fade_to_black = gr.Checkbox(label="Fade to Black", value=False, + info="Радиальное затемнение краёв") + zoom_fade_strength = gr.Slider(minimum=0.0, maximum=0.5, step=0.01, value=0.15, + label="Fade Strength", + info="Размер зоны затемнения") + + # ── Common / Shared ──────────────────────────────── + with gr.Group(): + gr.Markdown("### ⚙️ Common Settings") + with gr.Row(): + debug_mode = gr.Checkbox(label="Debug Mode", value=False) + + with gr.Accordion("🗄️ Compatibility / Legacy", open=False): + zoom_force_original_size = gr.Checkbox( + label="🛡️ Force Original Size (legacy / compat only)", + value=True, + info="[Legacy] Этот флаг относился к conv-path zoom, который отключён. " + "Latent-path zoom всегда сохраняет исходный размер тензора. " + "Оставлен только для совместимости с сохранёнными пресетами." 
+ ) + + with gr.Accordion("🎯 Zoom Positioning", open=False): + gr.Markdown(""" + **Точное позиционирование для Zoom:** + - Независимое от Latent Mirroring + - Контроль положения оригинала при zoom + """) + with gr.Row(): + zoom_pan_x = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, + label='Zoom X Offset (Pan)', value=0.0, + info='Смещение по горизонтали: -1=влево, 0=центр, +1=вправо') + zoom_pan_y = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, + label='Zoom Y Offset (Pan)', value=0.0, + info='Смещение по вертикали: -1=вверх, 0=центр, +1=вниз') + gr.Markdown("**Быстрые пресеты:**") + with gr.Row(): + btn_center = gr.Button("📍 Center", size="sm") + btn_top = gr.Button("⬆️ Top", size="sm") + btn_bottom = gr.Button("⬇️ Bottom", size="sm") + btn_left = gr.Button("⬅️ Left", size="sm") + btn_right = gr.Button("➡️ Right", size="sm") + btn_center.click(lambda: (0.0, 0.0), outputs=[zoom_pan_x, zoom_pan_y]) + btn_top.click( lambda: (0.0, -0.5), outputs=[zoom_pan_x, zoom_pan_y]) + btn_bottom.click(lambda: (0.0, 0.5), outputs=[zoom_pan_x, zoom_pan_y]) + btn_left.click( lambda: (-0.5, 0.0), outputs=[zoom_pan_x, zoom_pan_y]) + btn_right.click( lambda: (0.5, 0.0), outputs=[zoom_pan_x, zoom_pan_y]) + + with gr.Accordion("🌀 Spiral Zoom Settings", open=False) as zoom_spiral_group: + gr.Markdown(""" + **Параметры спирального зума:** + - Работает только с Zoom Mode = spiral_zoom + """) + with gr.Row(): + spiral_rotation = gr.Slider(0.0, 2.0, step=0.05, + label="Rotation Strength", value=0.5, + info="0 = нет вращения, 1 = средне, 2 = максимум") + spiral_direction = gr.Radio( + label="Rotation Direction", + choices=["Clockwise", "Counter-clockwise"], + value="Clockwise", + info="Clockwise = +1.0, Counter-clockwise = -1.0") + + # ── Zoom Step Control ────────────────────────────── + with gr.Accordion("⏱ Zoom Step Control", open=False): + gr.Markdown(""" +**По умолчанию zoom применяется на шаге 0 (чистый шум).** +Включите, чтобы приме��ять zoom позже или как анимированный диапазон. 
+ +- **Step 0** = текущее поведение (лучше для canvas fill) +- **Custom fraction** = один раз, на выбранном шаге (0.10–0.25 = после структуры) +- **Animated range** = плавный «наезд камеры» по шагам *(только geometry backend)* + """) + with gr.Row(): + zoom_step_control_enabled = gr.Checkbox( + label="Enable Zoom Step Control", + value=False, + info="Включить. По умолчанию: step 0." + ) + zoom_lock_after_end = gr.Checkbox( + label="🔒 Lock after End", + value=False, + info=( + "После последнего zoom-шага блокирует любой повторный старт zoom " + "(например, при hires pass или сбросе состояния). " + "Sampler продолжает естественно от последнего warped-состояния — " + "latent не перезаписывается." + ) + ) + with gr.Row(): + zoom_apply_mode = gr.Radio( + label="Apply Mode", + choices=["Step 0 (default)", "Custom fraction", "Animated range"], + value="Step 0 (default)", + visible=False, + info="Step 0 = one-shot. Custom = один шаг по фракции. Animated = диапазон с интерполяцией." + ) + # ── Once: Custom fraction ────────────────────── + with gr.Row(): + zoom_apply_frac = gr.Slider( + 0.0, 1.0, step=0.01, + label="Apply at fraction of total steps", + value=0.0, + visible=False, + info="0.0 = шаг 0, 0.15 = 15% шагов, 1.0 = последний шаг" + ) + # ── Animated range controls ──────────────────── + with gr.Row(): + zoom_range_start_frac = gr.Slider( + 0.0, 1.0, step=0.01, + label="Range Start (fraction)", + value=0.0, + visible=False, + info="Начало диапазона (0.0 = шаг 0)" + ) + zoom_range_end_frac = gr.Slider( + 0.0, 1.0, step=0.01, + label="Range End (fraction)", + value=0.35, + visible=False, + info="Конец диапазона (0.35 = 35% шагов)" + ) + with gr.Row(): + zoom_factor_start = gr.Slider( + -9.0, 20.0, step=0.1, + label="Zoom Factor Start", + value=0.0, + visible=False, + info="Zoom factor в начале диапазона" + ) + zoom_factor_end = gr.Slider( + -9.0, 20.0, step=0.1, + label="Zoom Factor End", + value=5.0, + visible=False, + info="Zoom factor в конце диапазона" + ) + with 
gr.Row(): + zoom_easing = gr.Dropdown( + choices=["instant", "linear", "ease_in", "ease_out", "ease_in_out"], + value="ease_in_out", + label="Easing", + visible=False, + info="instant = скачок в начале, ease_in_out = медленный наезд" + ) + zoom_every_n_steps = gr.Slider( + 1, 10, step=1, + label="Apply every N steps", + value=1, + visible=False, + info="1 = каждый шаг, 3 = каждый третий" + ) + with gr.Row(): + zoom_from_origin = gr.Checkbox( + label="From origin (recommended)", + value=True, + visible=False, + info="Варпить от сохранённого латента начала диапазона, а не от текущего. Устраняет накопление blur." + ) + + with gr.Accordion("🎭 Zoom Compose / ROI", open=False): + with gr.Row(): + zoom_compose_mode = gr.Radio( + label="Zoom Compose Mode", + choices=["Global", "Pinned Subject", "Dual-Transform ROI"], + value="Global", + info=( + "Global = обычный zoom. " + "Pinned Subject = объект удерживается, фон зумится. " + "Dual-Transform ROI = объект и фон масштабируются по-разному." + ) + ) + with gr.Row(): + roi_shape = gr.Dropdown( + label="ROI Shape", + choices=["ellipse", "box"], + value="ellipse", + visible=False, + ) + roi_preserve_strength = gr.Slider( + 0.0, 1.0, step=0.05, + label="Preserve Strength", + value=1.0, + visible=False, + info="Сила удержания ROI-зоны в режиме Pinned Subject" + ) + with gr.Row(): + roi_center_x = gr.Slider( + 0.0, 1.0, step=0.01, + label="ROI Center X", value=0.5, + visible=False, + ) + roi_center_y = gr.Slider( + 0.0, 1.0, step=0.01, + label="ROI Center Y", value=0.5, + visible=False, + ) + with gr.Row(): + roi_radius_x = gr.Slider( + 0.01, 1.0, step=0.01, + label="ROI Radius X", value=0.18, + visible=False, + ) + roi_radius_y = gr.Slider( + 0.01, 1.0, step=0.01, + label="ROI Radius Y", value=0.18, + visible=False, + ) + with gr.Row(): + roi_feather = gr.Slider( + 0.0, 0.5, step=0.01, + label="ROI Feather", value=0.08, + visible=False, + info="Мягкость перехода на границе ROI" + ) + with gr.Row(): + roi_fg_zoom_factor = 
gr.Slider( + -9.0, 20.0, step=0.1, + label="ROI Foreground Zoom Factor", + value=0.0, + visible=False, + info="Zoom factor для объекта (Dual-Transform ROI)" + ) + roi_bg_zoom_factor = gr.Slider( + -9.0, 20.0, step=0.1, + label="ROI Background Zoom Factor", + value=0.0, + visible=False, + info="Zoom factor для фона (Dual-Transform ROI). Полностью заменяет основной Zoom Factor для фоновой зоны." + ) + + + with gr.Row(): + blend_strength = gr.Slider(0.0, 1.0, step=0.05, label="Blend Strength (legacy)", value=0.5 + ) + + # ═══ NEW: ADVANCED BLEND SETTINGS ═══ + with gr.Accordion("🎨 Advanced Blend Settings", open=False): + gr.Markdown(""" + **Blend Mode улучшен!** Теперь это настоящий слайдер приближения/отдаления: + - `0.0` = далеко (простой padding) + - `0.5` = средняя дистанция (смешивание) + - `1.0` = близко (полный advanced tiling) + """) + + with gr.Row(): + blend_falloff = gr.Dropdown( + label="Falloff Curve", + choices=["linear", "smoothstep", "cosine", "perceptual"], + value="smoothstep", + info="Shape of the transition gradient" + + ) + blend_sharpness = gr.Slider( + 0.1, 5.0, step=0.1, + label="Edge Sharpness", + value=1.0, + info="<1 = softer, >1 = sharper transitions" + + ) + + blend_width = gr.Slider( + 0, 128, step=8, + label="Blend Width (pixels, 0=auto)", + value=0, + info="Width of transition zone" + + ) + + + # ═══ NEW: ADVANCED MULTI-RESOLUTION SETTINGS ═══ + with gr.Accordion("🔬 Advanced Multi-Resolution Settings", open=False): + gr.Markdown(""" +**Multi-Resolution** смешивает обычный (replicate) и продвинутый (tiled) padding по шагам диффузии: +- **linear** — равномерный переход +- **cosine** — плавный S-образный ✅ рекомендуется +- **exponential** — быстрый старт, медленный финиш +- **sigmoid** — острая S-кривая +- **early_boost** — максимум в начале, быстрое угасание +- **late_smooth** — медленный нарастающий финиш + """) + + multires_strategy = gr.Dropdown( + label="Interpolation Strategy", + choices=["linear", "cosine", "exponential", "sigmoid", 
"early_boost", "late_smooth"], + value="cosine", + info="How to blend simple → advanced over diffusion steps" + ) + + with gr.Row(): + multires_start = gr.Slider( + 0.0, 1.0, step=0.05, + label="Transition Start", + value=0.0, + info="Fraction of steps where transition begins (0.0 = from start)" + ) + multires_end = gr.Slider( + 0.0, 1.0, step=0.05, + label="Transition End", + value=0.3, + info="Fraction of steps where transition ends (0.3 = at 30%) — auto-swapped if < Start" + ) + + multires_sharpness = gr.Slider( + 0.1, 5.0, step=0.1, + label="Transition Sharpness", + value=1.0, + info="Speed of transition (< 1 = slower / wider, > 1 = faster / sharper)" + ) + + # NEW ENHANCED MODES Parameters + with gr.Accordion("✨ NEW Enhanced Modes Parameters", open=False): + gr.Markdown("**Parameters for Voronoi, Perlin, Fractal, and Adaptive modes**") + + with gr.Row(): + voronoi_cells = gr.Slider( + label="Voronoi: Cells", + minimum=4, maximum=32, value=8, step=1, + info="Number of organic cells for Voronoi mode" + + ) + voronoi_seed = gr.Slider( + label="Voronoi: Seed", + minimum=0, maximum=9999, value=42, step=1, + info="Random seed for pattern reproducibility" + + ) + + with gr.Row(): + perlin_strength = gr.Slider( + label="Perlin: Distortion Strength", + minimum=0.0, maximum=1.0, value=0.3, step=0.05, + info="How much Perlin noise distorts tiling" + + ) + perlin_scale = gr.Slider( + label="Perlin: Noise Scale", + minimum=1.0, maximum=50.0, value=10.0, step=1.0, + info="Scale of Perlin noise pattern" + + ) + + with gr.Row(): + fractal_iterations = gr.Slider( + label="Fractal: Iterations", + minimum=1, maximum=4, value=2, step=1, + info="Number of fractal recursion levels" + + ) + fractal_scale = gr.Slider( + label="Fractal: Scale Factor", + minimum=0.3, maximum=0.9, value=0.6, step=0.05, + info="Size reduction per iteration" + + ) + + adaptive_threshold = gr.Slider( + label="Adaptive: Edge Threshold", + minimum=0.0, maximum=1.0, value=0.1, step=0.05, + info="Edge 
detection threshold for auto mode selection" + + ) + + with gr.Tab("🌐 Advanced Modes"): + gr.Markdown("**Специальные режимы для сложных топологий**") + + with gr.Row(): + use_panorama = gr.Checkbox(label="Panorama 360°", value=False + ) + use_polar = gr.Checkbox( + label="Polar (Sphere Correct)", + value=False, + info="Correct pole transitions for equirectangular" + + ) + + + with gr.Accordion("🧭 Panorama Live (panorama-tools integrated)", open=False): + gr.Markdown("**Live equirectangular panorama processing during generation**") + + with gr.Row(): + panorama_engine = gr.Dropdown( + label="Panorama Engine", + choices=["A (Legacy)", "B (Fast Pole-Correct)", "C (Live 3D GridSample)"], + value="C (Live 3D GridSample)", + info="A = legacy circular wrap. B = fast pole-correct padding (+ optional antipode + pole blur). C = true spherical mapping via grid_sample (best seams, slower)." + + ) + panorama_coord_mode = gr.Dropdown( + label="Engine C: Coord Mode", + choices=["Cartesian (lon/lat)", "Polar (radial warp around poles)"], + value="Cartesian (lon/lat)", + info="Advanced: Cartesian uses standard lon/lat. Polar re-parameterizes latitude near poles." + + ) + + with gr.Row(): + panorama_yaw = gr.Slider(-180, 180, step=1, label="Engine C: Yaw (°)", value=0 + ) + panorama_pitch = gr.Slider(-180, 180, step=1, label="Engine C: Pitch (°)", value=0 + ) + panorama_roll = gr.Slider(-180, 180, step=1, label="Engine C: Roll (°)", value=0 + ) + + with gr.Row(): + panorama_twist_deg = gr.Slider(-180, 180, step=1, label="Engine C: Twist (°)", value=0, + info="Latitude-dependent longitude twist (local roll feel)." + ) + panorama_twist_power = gr.Slider(0.25, 4.0, step=0.01, label="Engine C: Twist Power", value=1.0 + ) + + panorama_swirl_deg = gr.Slider(-180, 180, step=1, label="Engine C: Swirl (°)", value=0, + info="Extra rotation bias toward poles (0=off)." 
+ ) + panorama_swirl_power = gr.Slider(0.25, 4.0, step=0.01, label="Engine C: Swirl Power", value=1.0 + ) + + with gr.Row(): + panorama_polar_scale = gr.Slider(0.25, 2.0, step=0.01, label="Engine C: Polar Scale", value=1.0 + ) + panorama_polar_power = gr.Slider(0.25, 4.0, step=0.01, label="Engine C: Polar Power", value=1.0 + ) + panorama_pole_ease_power = gr.Slider(0.25, 4.0, step=0.01, label="Pole Easing Power", value=1.0, + info="Eases latitude magnitude near poles (1=off)." + ) + + with gr.Row(): + panorama_antipode_strength = gr.Slider(0.0, 1.0, step=0.05, label="Antipode Mix (anti-seam)", value=0.0, + info="Kohaku-inspired mixing with the antipode point (0=off)." + ) + panorama_pole_blur_strength = gr.Slider(0.0, 1.0, step=0.05, label="Pole Blur Strength", value=0.0 + ) + panorama_pole_blur_radius = gr.Slider(0, 16, step=1, label="Pole Blur Radius (px)", value=0 + ) + + with gr.Row(): + panorama_pole_blur_power = gr.Slider(0.25, 4.0, step=0.01, label="Pole Blur Mask Power", value=1.0 + ) + panorama_geoaa_samples = gr.Slider(1, 4, step=1, label="Engine C: Geo AA Samples", value=1 + ) + panorama_geoaa_radius = gr.Slider(0.0, 2.0, step=0.25, label="Engine C: Geo AA Radius (px)", value=0.0 + ) + + with gr.Row(): + panorama_grid_interp = gr.Dropdown(label="Engine C: grid_sample Interp", + choices=["bilinear", "nearest"], + value="bilinear" + ) + panorama_grid_padding = gr.Dropdown(label="Engine C: grid_sample Padding", + choices=["border", "reflection", "zeros"], + value="border" + ) + panorama_cache_angle_quant = gr.Slider(0.1, 5.0, step=0.1, label="Engine C: Cache Quant (°)", value=0.5 + ) + + with gr.Row(): + use_cubemap = gr.Checkbox( + label="Cubemap (3D)", + value=False, + info="3x2 cubemap net (S/E/N over B/T/W). Faces must be square: width/3 == height/2." + + ) + cubemap_engine = gr.Dropdown( + label="Cubemap Engine", + choices=["A (Fast)", "B (Advanced)", "C (3D GridSample)"], + value="A (Fast)", + info="A = fastest batched cubemap padding. 
B = seam-aware padding blend (reduces seams). C = true 3D grid_sample mapping (yaw/pitch/roll, slower but best seams)." + + ) + cubemap_resolution_mode = gr.Dropdown( + label="Cubemap Input Size", + choices=["Already 3x2 (Full image)", "Face size (auto scale to 3x2)"], + value="Already 3x2 (Full image)", + info="If you choose 'Face size', your base width/height will be interpreted as a single face." + + ) + + with gr.Row(): + cubemap_pad_mode = gr.Dropdown( + label="Cubemap Pad Mode", + choices=["replicate", "reflect", "circular"], + value="replicate", + info="Padding mode used when assembling neighbor strips." + + ) + cubemap_blend_width = gr.Slider( + 0, 16, step=1, + label="Cubemap Seam Width (px)", + value=2, + info="Only for Engine B: how many pixels inside padding are blended." + + ) + cubemap_blend_strength = gr.Slider( + 0.0, 1.0, step=0.05, + label="Cubemap Seam Strength", + value=0.75, + info="Only for Engine B: 0 = Engine A behavior, 1 = full seam blending." + + ) + + + + with gr.Row(): + cubemap_yaw = gr.Slider( + -180, 180, step=1, + label="Engine C: Yaw (°)", + value=0, + info="Only for Engine C: rotation around Y axis." + + ) + cubemap_pitch = gr.Slider( + -180, 180, step=1, + label="Engine C: Pitch (°)", + value=0, + info="Only for Engine C: rotation around X axis." + + ) + cubemap_roll = gr.Slider( + -180, 180, step=1, + label="Engine C: Roll (°)", + value=0, + info="Only for Engine C: rotation around Z axis." + + ) + + # --- Engine C advanced coordinate controls --- + with gr.Row(): + cubemap_coord_mode = gr.Dropdown( + choices=["Cartesian (Face UV)", "Polar (Radial Warp)"], + value="Cartesian (Face UV)", + label="Engine C: Coord Mode", + info="Advanced: how face UV is interpreted before cubemap projection." + + ) + cubemap_twist_deg = gr.Slider( + -180, 180, step=1, + label="Engine C: Face Twist (°)", + value=0, + info="Advanced: rotate face UV around face normal before projection." 
+ + ) + + with gr.Row(): + cubemap_polar_scale = gr.Slider( + 0.25, 2.0, step=0.01, + label="Engine C: Polar Radius Scale", + value=1.0, + info="Only for Polar mode: r' = (r*scale)^power." + + ) + cubemap_polar_power = gr.Slider( + 0.25, 4.0, step=0.01, + label="Engine C: Polar Power", + value=1.0, + info="Only for Polar mode: r' = (r*scale)^power." + + ) + + with gr.Row(): + cubemap_swirl_deg = gr.Slider( + -180, 180, step=1, + label="Engine C: Swirl (°)", + value=0, + info="Extra angle: adds radius-dependent twist (0=off)." + + ) + cubemap_swirl_power = gr.Slider( + 0.25, 4.0, step=0.01, + label="Engine C: Swirl Power", + value=1.0, + info="Swirl curve: twist ∝ r^power." + + ) + + with gr.Row(): + cubemap_grid_interp = gr.Dropdown( + choices=["bilinear", "nearest"], + value="bilinear", + label="Engine C: grid_sample Interp", + info="Bilinear=best quality. Nearest=faster/sharper." + + ) + cubemap_grid_padding = gr.Dropdown( + choices=["border", "reflection", "zeros"], + value="border", + label="Engine C: grid_sample Padding", + info="border recommended for seams." + + ) + cubemap_cache_angle_quant = gr.Slider( + 0.1, 5.0, step=0.1, + label="Engine C: Cache Quant (°)", + value=0.5, + info="Cache step for yaw/pitch/roll/twist/swirl. Lower=more precise, slower." + + ) + + with gr.Row(): + cubemap_geoaa_samples = gr.Slider( + 1, 4, step=1, + label="Engine C: Geo AA Samples", + value=1, + info="Only for Engine C: multi-sampling (1=off). Higher = smoother but slower." + + ) + cubemap_geoaa_radius = gr.Slider( + 0.0, 2.0, step=0.25, + label="Engine C: Geo AA Radius (px)", + value=0.0, + info="Only for Engine C: jitter radius in atlas pixels." + + ) + cubemap_antipode_strength = gr.Slider( + 0.0, 1.0, step=0.05, + label="Engine C: Antipode Mix", + value=0.0, + info="Only for Engine C: Kohaku-inspired antipode blending on the sphere." 
+ + ) + + with gr.Row(): + use_anisotropic = gr.Checkbox( + label="Anisotropic (Directional)", + value=False, + info="Different behavior along diagonals" + + ) + aniso_angle = gr.Slider( + 0, 360, step=15, + label="Anisotropic Angle", + value=45, + info="Direction of fibers/texture (degrees)" + + ) + + aniso_angle2 = gr.Slider( + 0, 360, step=15, + label="Second Angle (optional)", + value=0, + info="Advanced: enable by setting Angle Mix < 1.0" + + ) + aniso_angle_mix = gr.Slider( + 0.0, 1.0, step=0.05, + label="Angle Mix (Angle1 → Angle2)", + value=1.0, + info="1.0 = only Angle1, 0.0 = only Angle2." + + ) + + # ── LATENT NOISE INIT ───────────────────────────────── + with gr.Accordion("🎲 Latent Noise Init (Panorama)", open=False): + gr.Markdown(""" +Заменяет стандартный `randn` шум на откалиброванный при старте генерации. +Рекомендуется при **Panorama 360°** — убирает цветовые пятна на стыках. +`null_mix=0.25` = 25% нейтрального латента → нет случайного цветового дрейфа. + """) + latent_noise_init = gr.Checkbox( + label="Enable Calibrated Noise Init", + value=False, + info="Только при первом шаге (step=0)" + ) + null_mix = gr.Slider( + minimum=0.0, maximum=0.5, step=0.05, + label="Null Mix", + value=0.25, + info="Доля нейтрального латента (рек. 0.25 для панорам)" + ) + + + with gr.Tab("🎭 Stereoscopic 3D"): + gr.Markdown("**Генерация стереопар для 3D контента**") + + with gr.Row(): + stereo_enabled = gr.Checkbox(label="Enable Stereoscopic Mode", value=False) + stereo_eye = gr.Radio( + label="Eye", + choices=["left", "right", "both"], + value="left", + info="Which eye view to generate" + + ) + + + with gr.Row(): + stereo_engine = gr.Dropdown( + label="Stereo Engine", + choices=["A (Legacy shift)", "B (Spherical warp)"], + value="B (Spherical warp)", + info="B = smoother spherical warp (recommended). A = old-style shift/blend." 
+ + ) + stereo_output = gr.Dropdown( + label="Output", + choices=["Side-by-Side", "Left only", "Right only", "Both separate"], + value="Side-by-Side", + info="Default: create a single side-by-side image. You can also output only one eye or two separate images." + + ) + stereo_in_model = gr.Checkbox( + label="Apply stereo during denoise (experimental)", + value=False, + info="If enabled, stereo shift affects UNet/conv padding during generation. If disabled, stereo is applied as a postprocess (recommended)." + + ) + + with gr.Row(): + stereo_depth_power = gr.Slider( + 0.25, 4.0, step=0.05, + label="Depth Curve Power", + value=1.0, + info="Higher values push parallax more toward edges." + + ) + stereo_pole_power = gr.Slider( + 0.25, 4.0, step=0.05, + label="Pole Attenuation Power", + value=1.0, + info="Scales stereo shift by cos(latitude)^power to reduce pole stretching." + + ) + + with gr.Row(): + stereo_separation = gr.Slider( + 0.0, 0.15, step=0.005, + label="IPD (Inter-Pupillary Distance)", + value=0.065, + info="Eye separation as fraction of width" + + ) + stereo_convergence = gr.Slider( + 0.0, 1.0, step=0.05, + label="Convergence Point", + value=0.5, + info="Depth at which eyes converge" + + ) + + gr.Markdown(""" + **💡 Совет для стерео:** + - Генерируйте сначала левый глаз + - Затем правый с теми же параметрами + - Используйте Side-by-Side или Anaglyph компоновку + """) + + with gr.Tab("🪞 Latent Mirroring"): + # Главный переключатель латентного зеркалирования + enable_mirroring = gr.Checkbox( + label="Enable Latent Mirroring", + value=False + ) + + with gr.Group(): + mirror_mode = gr.Radio( + label='Latent Mirror mode', + choices=['None', 'Alternate Steps', 'Blend Average'], + value='None', + type="index" + ) + mirror_style = gr.Radio( + label='Latent Mirror style', + choices=[ + 'Horizontal Mirroring', + 'Vertical Mirroring', + 'Horizontal+Vertical Mirroring', + '90 Degree Rotation', + '180 Degree Rotation', + 'Roll Channels', + 'None' + ], + value='Horizontal 
Mirroring', + type="index" + ) + + # ── Step Control ────────────────────────────────────── + with gr.Accordion("⏱️ Step Control", open=False): + mirror_step_control_enabled = gr.Checkbox( + label="Enable Step Control", value=False, + info="Применять зеркало/паннинг только в указанном диапазоне шагов") + mirror_step_mode = gr.Radio( + label="Step mode", + choices=["Max fraction (original)", "Custom range"], + value="Max fraction (original)", + type="index" + ) + + mirroring_max_step_fraction = gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, + label='Maximum steps fraction to mirror at', + value=0.25, + visible=True, + info="Используется в режиме Max fraction" + ) + with gr.Row(): + mirror_start_frac = gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, + label="Mirror Start (fraction of total steps)", + value=0.0, + visible=False + ) + mirror_end_frac = gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, + label="Mirror End (fraction of total steps)", + value=0.25, + visible=False + ) + + gr.HTML("