Delete negative_rejection_steering/scripts/nrs_kohaku_enhanced_v2 (1).py

#2
negative_rejection_steering/scripts/nrs_kohaku_enhanced_v2 (1).py DELETED
@@ -1,1220 +0,0 @@
1
- import math
2
- import torch
3
- import gradio as gr
4
- from modules import scripts, script_callbacks, sd_samplers_cfg_denoiser, shared
5
-
6
- # ==============================================================================
7
- # NRS + KOHAKU ENHANCED — Version 2.0
8
- #
9
- # Improvements over v1:
10
- # 1. Midpoint Refinement (replaces flawed Antipodal — correct Kohaku principle)
11
- # 2. Curve Scheduling (12 curves: Constant/Linear/Cosine/Power/Repeating/Sawtooth)
12
- # 3. CADS Trapezoidal Schedule (tau1/tau2 annealing)
13
- # 4. Adaptive Phases (Euler → DPM++ → Detail, from adaptive_progressive.py)
14
- # 5. Per-Channel NRS (independent processing per latent channel)
15
- # 6. AD Normalization (Absolute Deviation, more robust than L2)
16
- # 7. Variance Preserving Rescale (phi blend)
17
- # 8. Interpolate Phi (NRS ↔ plain CFG blend)
18
- # 9. CFG Drift Correction (mean/median centering, from adept_sampler_v4)
19
- # 10. Momentum smoothing (from res_solver / clybius)
20
- # 11. GE-Gamma Extrapolation (from gradient_estimation.py)
21
- # 12. Native Detail Boost (Gaussian HF enhancement, from adept_sampler_v4)
22
- # 13. Spectral Modulation (FFT frequency correction, from adept_sampler_v4)
23
- # 14. Uncond Noise & Scale (from forge_condBlast)
24
- # 15. Output Clamp (adaptive sigma-based, from adept_sampler_v4)
25
- # ==============================================================================
26
-
27
- CURVE_CHOICES = [
28
- "Constant",
29
- "Linear Down", "Linear Up",
30
- "Cosine Down", "Cosine Up",
31
- "Half Cosine Down", "Half Cosine Up",
32
- "Power Down", "Power Up",
33
- "Linear Repeating", "Cosine Repeating",
34
- "Sawtooth",
35
- ]
36
-
37
- SCHED_MODES = ["Off", "Individual Curves", "CADS Anneal", "Adaptive Phases"]
38
- INTER_STEP_MODES = ["Off", "Momentum", "GE-Gamma"]
39
- DRIFT_METHODS = ["mean", "median"]
40
-
41
-
42
- # ==============================================================================
43
- # ЧАСТЬ 1: МАТЕМАТИЧЕСКОЕ ЯДРО
44
- # ==============================================================================
45
-
46
- def _nrs_core(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm=False):
47
- """
48
- Base NRS math kernel.
49
- use_ad_norm: use Absolute Deviation norm for squash (more robust to outliers).
50
- Source: dynthres_core (3).py variability_measure='AD'
51
- """
52
- is_v_pred = False
53
- if hasattr(shared.sd_model, 'parameterization'):
54
- is_v_pred = shared.sd_model.parameterization == "v"
55
-
56
- if isinstance(sigma, torch.Tensor):
57
- sig_tens = sigma[0]
58
- else:
59
- sig_tens = torch.tensor(sigma, device=cond.device, dtype=cond.dtype)
60
- if sig_tens.dtype != cond.dtype:
61
- sig_tens = sig_tens.to(dtype=cond.dtype)
62
-
63
- sig_tens = sig_tens.view(1, 1, 1, 1)
64
- sig_root = (sig_tens ** 2 + 1).sqrt()
65
-
66
- if is_v_pred:
67
- nrs_cond, nrs_uncond = cond, uncond
68
- x_div = None
69
- else:
70
- x_div = x_orig / (sig_tens ** 2 + 1)
71
- factor = sig_tens / sig_root
72
- nrs_cond = x_orig - (x_div - cond * factor)
73
- nrs_uncond = x_orig - (x_div - uncond * factor)
74
-
75
- def _dot(a, b):
76
- return (a * b).sum(dim=1, keepdim=True)
77
-
78
- def _nrm2(v):
79
- return _dot(v, v)
80
-
81
- eps_safe = 1e-6
82
-
83
- c_dot_c = _nrm2(nrs_cond) + eps_safe
84
- u_dot_c = _dot(nrs_uncond, nrs_cond)
85
- u_on_c = (u_dot_c / c_dot_c) * nrs_cond
86
-
87
- proj_diff = nrs_cond - u_on_c
88
- stretched = nrs_cond + (stretch * proj_diff)
89
-
90
- u_rej_c = nrs_uncond - u_on_c
91
- skewed = stretched - (skew * u_rej_c)
92
-
93
- if use_ad_norm:
94
- # AD: Mean Absolute Deviation per channel, then average across channels
95
- # Source: dynthres_core sep_feat_channels=True, variability_measure='AD'
96
- cond_len = nrs_cond.abs().mean(dim=(2, 3), keepdim=True).mean(dim=1, keepdim=True)
97
- nrs_len = skewed.abs().mean(dim=(2, 3), keepdim=True).mean(dim=1, keepdim=True) + eps_safe
98
- else:
99
- cond_len = nrs_cond.norm(dim=1, keepdim=True)
100
- nrs_len = skewed.norm(dim=1, keepdim=True) + eps_safe
101
-
102
- squash_scale = (1 - squash) + (squash * (cond_len / nrs_len))
103
- x_final = skewed * squash_scale
104
-
105
- if is_v_pred:
106
- return x_final
107
- else:
108
- return (x_div - (x_orig - x_final)) * (sig_root / sig_tens)
109
-
110
-
111
- def calc_nrs(x_orig, cond, uncond, sigma, skew, stretch, squash):
112
- """Backward-compatible wrapper."""
113
- return _nrs_core(x_orig, cond, uncond, sigma, skew, stretch, squash)
114
-
115
-
116
- def calc_nrs_per_channel(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm=False):
117
- """
118
- Per-channel NRS: process each latent channel independently.
119
- Source: dynthres_core (3).py sep_feat_channels=True
120
- Per-channel norms use dim=(2,3) spatial only, preventing cross-channel influence.
121
- """
122
- results = []
123
- for ch in range(cond.shape[1]):
124
- r = _nrs_core(
125
- x_orig[:, ch:ch+1],
126
- cond[:, ch:ch+1],
127
- uncond[:, ch:ch+1],
128
- sigma, skew, stretch, squash, use_ad_norm
129
- )
130
- results.append(r)
131
- return torch.cat(results, dim=1)
132
-
133
-
134
- def calc_nrs_midpoint_refined(x_orig, cond, uncond, sigma, skew, stretch, squash,
135
- refine_blend=0.0, first_half_only=True,
136
- current_step=0, total_steps=20,
137
- use_per_channel=False, use_ad_norm=False):
138
- """
139
- Correct Kohaku midpoint refinement for NRS.
140
-
141
- Kohaku_LoNyu_Yog (sampler, smea_sampling_46.py):
142
- d = to_d(x, sigma, model(x))
143
- x3 = x + (d + d2) / 2 * dt # midpoint from averaged direction
144
- d3 = to_d(x3, sigma, model(x3))
145
- real_d = (d + d3) / 2 # Runge-Kutta 2nd order average
146
-
147
- NRS adaptation (no extra model calls needed):
148
- nrs_direct = NRS(x_orig, cond, uncond)
149
- x_mid = x_orig + (nrs_direct - x_orig) * blend * 0.5 # shifted latent
150
- nrs_refined = NRS(x_mid, cond, uncond)
151
- result = (nrs_direct + nrs_refined) / 2
152
- """
153
- _fn = calc_nrs_per_channel if use_per_channel else _nrs_core
154
-
155
- nrs_direct = _fn(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm)
156
-
157
- if refine_blend <= 0.0:
158
- return nrs_direct
159
-
160
- if first_half_only and current_step > total_steps / 2:
161
- return nrs_direct
162
-
163
- # Midpoint in x-space (between noisy x_orig and denoised nrs_direct)
164
- x_mid = x_orig + (nrs_direct - x_orig) * (refine_blend * 0.5)
165
-
166
- nrs_refined = _fn(x_mid, cond, uncond, sigma, skew, stretch, squash, use_ad_norm)
167
-
168
- # Runge-Kutta style average (Kohaku: real_d = (d + d3) / 2)
169
- return (nrs_direct + nrs_refined) * 0.5
170
-
171
-
172
- # ==============================================================================
173
- # ЧАСТЬ 2: РАСПИСАНИЕ ПАРАМЕТРОВ (SCHEDULING)
174
- # ==============================================================================
175
-
176
- def nrs_schedule_value(base_value, step, total_steps, curve="Constant",
177
- min_value=0.0, sched_val=2.0):
178
- """
179
- Apply curve to parameter over sampling steps.
180
- Source: dynthres_core (3).py interpret_scale() + khrfix (26).py curve_progress()
181
- """
182
- if curve == "Constant":
183
- return base_value
184
-
185
- frac = step / max(total_steps - 1, 1)
186
- frac = max(0.0, min(1.0, frac))
187
- scale = base_value - min_value
188
-
189
- if curve == "Linear Down":
190
- val = 1.0 - frac
191
- elif curve == "Linear Up":
192
- val = frac
193
- elif curve == "Cosine Down":
194
- # Source: dynthres_core cos(frac * pi/2) — от 1.0 до ~0
195
- val = math.cos(frac * 1.5707963)
196
- elif curve == "Cosine Up":
197
- # Source: dynthres_core 1 - cos(frac * pi/2)
198
- val = 1.0 - math.cos(frac * 1.5707963)
199
- elif curve == "Half Cosine Down":
200
- # Source: dynthres_core + khrfix → cos(frac), НЕ cos(frac*pi/2)
201
- val = math.cos(frac)
202
- elif curve == "Half Cosine Up":
203
- # Source: dynthres_core + khrfix → 1 - cos(frac)
204
- val = 1.0 - math.cos(frac)
205
- elif curve == "Power Down":
206
- val = 1.0 - math.pow(frac, max(sched_val, 0.1))
207
- elif curve == "Power Up":
208
- val = math.pow(frac, max(sched_val, 0.1))
209
- elif curve == "Linear Repeating":
210
- sv = max(sched_val, 0.1)
211
- portion = (frac * sv) % 1.0
212
- val = 1.0 - abs(2.0 * portion - 1.0)
213
- elif curve == "Cosine Repeating":
214
- sv = max(sched_val, 0.1)
215
- val = math.cos(2.0 * math.pi * frac * sv) * 0.5 + 0.5
216
- elif curve == "Sawtooth":
217
- sv = max(sched_val, 0.1)
218
- val = (frac * sv) % 1.0
219
- else:
220
- val = 1.0
221
-
222
- return min_value + max(0.0, scale * val)
223
-
224
-
225
- def nrs_cads_schedule(step, total_steps, tau1=0.6, tau2=0.9):
226
- """
227
- CADS trapezoidal NRS strength schedule.
228
- Source: cads__6__fixed.py cads_linear_schedule(), "Hold after full" mode.
229
-
230
- t = 1 - step/total (descends from 1 to 0 during sampling)
231
- - t > tau2 (early steps): gamma = 0 (NRS off)
232
- - tau1 < t < tau2 (ramp): gamma linearly rises 0→1
233
- - t <= tau1 (late steps): gamma = 1 (NRS full strength)
234
-
235
- Defaults tau1=0.6, tau2=0.9 → NRS activates at ~10% of steps,
236
- reaches full strength at ~40%, stays full for remainder.
237
- """
238
- t = 1.0 - step / max(total_steps - 1, 1)
239
- t = max(0.0, min(1.0, t))
240
- tau1 = max(0.0, min(1.0, tau1))
241
- tau2 = max(0.0, min(1.0, tau2))
242
-
243
- if tau1 >= tau2:
244
- return 1.0 if t <= tau1 else 0.0
245
- if t >= tau2:
246
- return 0.0
247
- if t <= tau1:
248
- return 1.0
249
- return (tau2 - t) / (tau2 - tau1)
250
-
251
-
252
- def calc_adaptive_nrs_params(base_skew, base_stretch, base_squash, progress,
253
- euler_end=0.35, dpm_end=0.70):
254
- """
255
- Phase-based parameter adjustment.
256
- Source: adaptive_progressive.py calc_phase_bounds() + phase weight logic.
257
-
258
- Phase 1 (0 → euler_end): Structural — max skew, moderate stretch
259
- Phase 2 (euler_end → dpm_end): Transition — decreasing skew, rising squash
260
- Phase 3 (dpm_end → 1.0): Detail — minimal skew, max squash
261
- """
262
- euler_end = max(0.0, min(1.0, euler_end))
263
- dpm_end = max(euler_end + 0.05, min(1.0, dpm_end))
264
-
265
- if progress < euler_end:
266
- skew_f, stretch_f, squash_add = 1.0, 0.8, 0.0
267
- elif progress < dpm_end:
268
- ph = (progress - euler_end) / max(dpm_end - euler_end, 1e-8)
269
- w_euler = max(0.0, 1.0 - ph * 2.5)
270
- skew_f = w_euler + (1.0 - w_euler) * 0.3
271
- stretch_f = w_euler * 0.8 + (1.0 - w_euler) * 1.0
272
- squash_add = (1.0 - w_euler) * 0.3
273
- else:
274
- ph = (progress - dpm_end) / max(1.0 - dpm_end, 1e-8)
275
- skew_f = max(0.0, 0.3 - ph * 0.3)
276
- stretch_f = 0.8
277
- squash_add = 0.3 + ph * 0.4
278
-
279
- return (
280
- base_skew * skew_f,
281
- base_stretch * stretch_f,
282
- min(1.0, base_squash + (1.0 - base_squash) * squash_add),
283
- )
284
-
285
-
286
- # ==============================================================================
287
- # ЧАСТЬ 3: POST-PROCESSING ФУНКЦИИ
288
- # ==============================================================================
289
-
290
- def apply_nrs_drift_correction(tensor, intensity=0.0, method='mean'):
291
- """
292
- Remove CFG mean/median drift.
293
- Source: adept_sampler_v4_COMPLETE (2).py apply_combat_cfg_drift()
294
- Based on ComfyUI-Latent-Modifiers.
295
- """
296
- if intensity <= 0.0:
297
- return tensor
298
- try:
299
- if method == 'median':
300
- center = tensor.view(tensor.shape[0], -1).median(dim=-1, keepdim=True)[0]
301
- center = center.view(tensor.shape[0], 1, 1, 1)
302
- else:
303
- center = tensor.mean(dim=(1, 2, 3), keepdim=True)
304
- return tensor - center * intensity
305
- except Exception:
306
- return tensor
307
-
308
-
309
- def apply_variance_preserving_rescale(nrs_result, cond_reference, phi=0.0):
310
- """
311
- Scale NRS result to match std of reference cond.
312
- Source: dynthres_core (3).py interpolation logic + variance concept.
313
- """
314
- if phi <= 0.0:
315
- return nrs_result
316
- try:
317
- std_ref = cond_reference.std()
318
- std_nrs = nrs_result.std()
319
- if std_nrs < 1e-8:
320
- return nrs_result
321
- rescaled = nrs_result * (std_ref / std_nrs)
322
- return phi * rescaled + (1.0 - phi) * nrs_result
323
- except Exception:
324
- return nrs_result
325
-
326
-
327
- def apply_blend_phi(nrs_result, plain_cfg_result, phi=1.0):
328
- """
329
- Blend NRS output with standard CFG output.
330
- Source: dynthres_core (3).py interpolate_phi.
331
- phi=1.0 → pure NRS, phi=0.0 → pure CFG.
332
- """
333
- if phi >= 1.0:
334
- return nrs_result
335
- if phi <= 0.0:
336
- return plain_cfg_result
337
- return phi * nrs_result + (1.0 - phi) * plain_cfg_result
338
-
339
-
340
- def apply_nrs_momentum(nrs_result, prev_result, prev_vel, momentum=0.0):
341
- """
342
- Full momentum smoothing between steps.
343
- Source: res_solver (11).py + clybius_dpmpp_4m_sde (7).py momentum_func()
344
- Formula: vel = m*(1-m/2)*prev_vel + (1-m*(1-m/2))*curr_diff
345
- result = prev_result + vel
346
- Returns: (smoothed_result, new_vel)
347
- """
348
- if momentum <= 0.0 or prev_result is None:
349
- curr_diff = nrs_result - prev_result if prev_result is not None else None
350
- return nrs_result, curr_diff
351
- try:
352
- curr_diff = nrs_result - prev_result
353
- eff_m = momentum * (1.0 - momentum * 0.5)
354
- if prev_vel is None:
355
- # First step: velocity = current diff (no history)
356
- vel = curr_diff
357
- else:
358
- # Full RES/Clybius formula: blend prev velocity with current diff
359
- vel = eff_m * prev_vel + (1.0 - eff_m) * curr_diff
360
- smoothed = prev_result + vel
361
- return smoothed, vel
362
- except Exception:
363
- return nrs_result, None
364
-
365
-
366
- def apply_nrs_ge_extrapolation(nrs_result, prev_result, prev_diff, ge_gamma=1.0):
367
- """
368
- Gradient Estimation extrapolation between steps.
369
- Source: gradient_estimation (5).py
370
- Formula: d_bar = ge_gamma * (d - old_d) + old_d
371
- ge_gamma=1.0 → standard, >1.0 → extrapolation, <1.0 → smoothing.
372
- """
373
- if ge_gamma == 1.0 or prev_result is None or prev_diff is None:
374
- return nrs_result
375
- try:
376
- d = nrs_result - prev_result
377
- d_bar = ge_gamma * (d - prev_diff) + prev_diff
378
- return prev_result + d_bar
379
- except Exception:
380
- return nrs_result
381
-
382
-
383
- def apply_nrs_detail_boost(nrs_result, progress, boost_strength=0.0):
384
- """
385
- Progressive high-frequency detail enhancement.
386
- Source: adept_sampler_v4_COMPLETE (2).py compute_native_detail_boost()
387
- Three phases: early (gentle) → mid → late (strong).
388
- """
389
- if boost_strength <= 0.0:
390
- return nrs_result
391
- try:
392
- import torch.nn.functional as F
393
-
394
- if progress < 0.30:
395
- hf_boost = 0.03 * boost_strength * (progress / 0.30)
396
- elif progress < 0.60:
397
- hf_boost = (0.03 + 0.07 * (progress - 0.30) / 0.30) * boost_strength
398
- else:
399
- hf_boost = (0.10 + 0.08 * (progress - 0.60) / 0.40) * boost_strength
400
-
401
- if hf_boost <= 1e-6:
402
- return nrs_result
403
-
404
- # Gaussian kernel for low-freq extraction
405
- sigma_g = 0.5
406
- ks = 3
407
- x_k = torch.linspace(-(ks - 1) / 2, (ks - 1) / 2, ks,
408
- device=nrs_result.device, dtype=nrs_result.dtype)
409
- gauss = torch.exp(-0.5 * (x_k / sigma_g) ** 2)
410
- gauss = gauss / gauss.sum()
411
- kernel = torch.mm(gauss[:, None], gauss[None, :])
412
- # .contiguous() required: .expand() creates non-contiguous view, F.conv2d needs contiguous weight
413
- kernel = kernel.expand(nrs_result.shape[1], 1, ks, ks).contiguous()
414
-
415
- padded = F.pad(nrs_result, (1, 1, 1, 1), mode='reflect')
416
- low_freq = F.conv2d(padded, kernel, groups=nrs_result.shape[1])
417
- high_freq = nrs_result - low_freq
418
- return nrs_result + high_freq * hf_boost
419
- except Exception:
420
- return nrs_result
421
-
422
-
423
- def apply_spectral_modulation(noise_pred, multiplier=0.0, percentile=5.0):
424
- """
425
- Clybius spectral modulation on noise_pred = (cond_x0 - uncond_x0).
426
- Source: adept_sampler_v4_COMPLETE (2).py apply_spectral_modulation_clybius()
427
- Boosts low-freq, suppresses extreme high-freq outliers.
428
- Applied BEFORE NRS computation.
429
- """
430
- if multiplier == 0.0 or percentile <= 0:
431
- return noise_pred
432
- try:
433
- fourier = torch.fft.fft2(noise_pred, dim=(-2, -1))
434
- log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2) + 1e-8)
435
- flat = log_amp.abs().flatten(2) # [B, C, H*W]
436
-
437
- q_lo = torch.quantile(flat, percentile * 0.01, dim=2)
438
- q_hi = torch.quantile(flat, 1.0 - percentile * 0.01, dim=2)
439
-
440
- # Expand to [B, C, H, W]
441
- q_lo = q_lo.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
442
- q_hi = q_hi.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
443
-
444
- # mask_low: boost frequencies below lower threshold (1.0–1.5 range)
445
- # mask_high: reduce frequencies above upper threshold (0.5–1.0 range)
446
- mask_low = ((log_amp < q_lo).float() + 1.0).clamp_(max=1.5)
447
- mask_high = (log_amp < q_hi).float().clamp_(min=0.5)
448
-
449
- filtered = fourier * ((mask_low * mask_high) ** multiplier)
450
- return torch.fft.ifft2(filtered, dim=(-2, -1)).real
451
- except Exception:
452
- return noise_pred
453
-
454
-
455
- def apply_uncond_modifications(uncond, noise_strength=0.0, uncond_scale=1.0):
456
- """
457
- Add noise to uncond and/or scale it.
458
- Source: forge_condBlast (6).py
459
- noise: lerp(uncond, randn*uncond.std(), strength)
460
- scale: lerp(zeros, uncond, scale)
461
- """
462
- if noise_strength <= 0.0 and uncond_scale == 1.0:
463
- return uncond
464
- try:
465
- result = uncond.clone()
466
- if noise_strength > 0.0:
467
- noise = torch.randn_like(result) * result.std()
468
- result = torch.lerp(result, noise, noise_strength)
469
- if uncond_scale != 1.0:
470
- result = torch.lerp(torch.zeros_like(result), result, uncond_scale)
471
- return result
472
- except Exception:
473
- return uncond
474
-
475
-
476
- def apply_nrs_output_clamp(nrs_result, sigma, clamp_multiplier=0.0):
477
- """
478
- Adaptive output clamping based on sigma.
479
- Source: adept_sampler_v4_COMPLETE (2).py apply_dynamic_thresholding().
480
- threshold = clamp * (1 + sigma/10)
481
- """
482
- if clamp_multiplier <= 0.0:
483
- return nrs_result
484
- try:
485
- sigma_val = sigma[0].item() if isinstance(sigma, torch.Tensor) else float(sigma)
486
- threshold = clamp_multiplier * (1.0 + sigma_val / 10.0)
487
- return torch.clamp(nrs_result, -threshold, threshold)
488
- except Exception:
489
- return nrs_result
490
-
491
-
492
- # ==============================================================================
493
- # ЧАСТЬ 4: STEP CONTROL (сохранено из оригинала)
494
- # ==============================================================================
495
-
496
- def should_apply_at_step(current_step, total_steps, start_step, end_step,
497
- start_frac, end_frac, step_mode):
498
- if step_mode == "Absolute Steps":
499
- eff_start = max(0, start_step)
500
- eff_end = min(total_steps, end_step) if end_step > 0 else total_steps
501
- return eff_start <= current_step < eff_end
502
- else:
503
- eff_start = int(total_steps * max(0.0, min(1.0, start_frac)))
504
- eff_end = int(total_steps * max(0.0, min(1.0, end_frac)))
505
- if eff_end == 0:
506
- eff_end = total_steps
507
- return eff_start <= current_step < eff_end
508
-
509
-
510
- def get_param_value_at_step(base_value, current_step, total_steps, start_step, end_step,
511
- start_frac, end_frac, step_mode, enabled):
512
- if not enabled:
513
- return base_value
514
- if should_apply_at_step(current_step, total_steps, start_step, end_step,
515
- start_frac, end_frac, step_mode):
516
- return base_value
517
- return 0.0
518
-
519
-
520
- # ==============================================================================
521
- # ЧАСТЬ 5: HOOKS
522
- # ==============================================================================
523
-
524
- def hook_cfg_denoiser_params(params):
525
- if hasattr(params.denoiser, 'p') and getattr(params.denoiser.p, '_nrs_enabled', False):
526
- params.denoiser.p._nrs_current_sigma = params.sigma
527
- params.denoiser.p._nrs_current_x_in = params.x
528
- if hasattr(params, 'sampling_step'):
529
- params.denoiser.p._nrs_current_step = params.sampling_step
530
- elif hasattr(params.denoiser, 'step'):
531
- params.denoiser.p._nrs_current_step = params.denoiser.step
532
- else:
533
- params.denoiser.p._nrs_current_step = getattr(
534
- params.denoiser.p, '_nrs_current_step', 0)
535
-
536
-
537
- script_callbacks.on_cfg_denoiser(hook_cfg_denoiser_params)
538
-
539
- if not hasattr(sd_samplers_cfg_denoiser.CFGDenoiser, 'original_combine_denoised_nrs_backup'):
540
- sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup = \
541
- sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised
542
-
543
-
544
- def hijacked_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
545
- _orig = sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup
546
-
547
- if not getattr(self, 'p', None) or not getattr(self.p, '_nrs_enabled', False):
548
- return _orig(self, x_out, conds_list, uncond, cond_scale)
549
-
550
- if not hasattr(self.p, '_nrs_current_sigma') or not hasattr(self.p, '_nrs_current_x_in'):
551
- return _orig(self, x_out, conds_list, uncond, cond_scale)
552
-
553
- try:
554
- p = self.p
555
- sigma = p._nrs_current_sigma
556
- base_skew, base_stretch, base_squash = p._nrs_params
557
-
558
- # ── Step Control ──────────────────────────────────────────────────────
559
- current_step = getattr(p, '_nrs_current_step', 0)
560
- total_steps = getattr(p, 'steps', 20)
561
- step_ctrl = getattr(p, '_nrs_step_control_enabled', False)
562
- step_mode_global = getattr(p, '_nrs_step_control_mode', 'Global')
563
-
564
- if step_ctrl:
565
- if step_mode_global == 'Global':
566
- gs = getattr(p, '_nrs_global_step_settings', {})
567
- if not should_apply_at_step(
568
- current_step, total_steps,
569
- gs.get('start_step', 0), gs.get('end_step', total_steps),
570
- gs.get('start_frac', 0.0), gs.get('end_frac', 1.0),
571
- gs.get('step_mode', 'Absolute Steps')):
572
- return _orig(self, x_out, conds_list, uncond, cond_scale)
573
- skew, stretch, squash = base_skew, base_stretch, base_squash
574
- else:
575
- ind = getattr(p, '_nrs_individual_step_settings', {})
576
- sk = ind.get('skew', {})
577
- st = ind.get('stretch', {})
578
- sq = ind.get('squash', {})
579
- skew = get_param_value_at_step(
580
- base_skew, current_step, total_steps,
581
- sk.get('start_step', 0), sk.get('end_step', total_steps),
582
- sk.get('start_frac', 0.0), sk.get('end_frac', 1.0),
583
- sk.get('step_mode', 'Absolute Steps'), sk.get('enabled', True))
584
- stretch = get_param_value_at_step(
585
- base_stretch, current_step, total_steps,
586
- st.get('start_step', 0), st.get('end_step', total_steps),
587
- st.get('start_frac', 0.0), st.get('end_frac', 1.0),
588
- st.get('step_mode', 'Absolute Steps'), st.get('enabled', True))
589
- squash = get_param_value_at_step(
590
- base_squash, current_step, total_steps,
591
- sq.get('start_step', 0), sq.get('end_step', total_steps),
592
- sq.get('start_frac', 0.0), sq.get('end_frac', 1.0),
593
- sq.get('step_mode', 'Absolute Steps'), sq.get('enabled', True))
594
- else:
595
- skew, stretch, squash = base_skew, base_stretch, base_squash
596
-
597
- # ── Scheduling ────────────────────────────────────────────────────────
598
- sched_mode = getattr(p, '_nrs_sched_mode', 'Off')
599
- progress = current_step / max(total_steps - 1, 1)
600
-
601
- if sched_mode == 'Individual Curves':
602
- sched_val = getattr(p, '_nrs_sched_val', 2.0)
603
- skew = nrs_schedule_value(
604
- skew, current_step, total_steps,
605
- getattr(p, '_nrs_skew_curve', 'Constant'),
606
- getattr(p, '_nrs_skew_curve_min', 0.0), sched_val)
607
- stretch = nrs_schedule_value(
608
- stretch, current_step, total_steps,
609
- getattr(p, '_nrs_stretch_curve', 'Constant'),
610
- getattr(p, '_nrs_stretch_curve_min', 0.0), sched_val)
611
- squash = nrs_schedule_value(
612
- squash, current_step, total_steps,
613
- getattr(p, '_nrs_squash_curve', 'Constant'),
614
- getattr(p, '_nrs_squash_curve_min', 0.0), sched_val)
615
-
616
- elif sched_mode == 'CADS Anneal':
617
- tau1 = getattr(p, '_nrs_cads_tau1', 0.6)
618
- tau2 = getattr(p, '_nrs_cads_tau2', 0.9)
619
- cads_scale = nrs_cads_schedule(current_step, total_steps, tau1, tau2)
620
- skew *= cads_scale
621
- stretch *= cads_scale
622
- # squash stays at base — CADS doesn't affect the clamp
623
-
624
- elif sched_mode == 'Adaptive Phases':
625
- euler_end = getattr(p, '_nrs_adaptive_euler_end', 0.35)
626
- dpm_end = getattr(p, '_nrs_adaptive_dpm_end', 0.70)
627
- skew, stretch, squash = calc_adaptive_nrs_params(
628
- skew, stretch, squash, progress, euler_end, dpm_end)
629
-
630
- # ── Feature flags ─────────────────────────────────────────────────────
631
- per_channel = getattr(p, '_nrs_per_channel', False)
632
- use_ad_norm = getattr(p, '_nrs_ad_norm', False)
633
- refine_blend = getattr(p, '_nrs_refine_blend', 0.0)
634
- refine_first_half = getattr(p, '_nrs_refine_first_half', True)
635
- blend_phi = getattr(p, '_nrs_blend_phi', 1.0)
636
- variance_phi = getattr(p, '_nrs_variance_phi', 0.0)
637
- drift_intensity = getattr(p, '_nrs_drift_intensity', 0.0)
638
- drift_method = getattr(p, '_nrs_drift_method', 'mean')
639
- output_clamp = getattr(p, '_nrs_output_clamp', 0.0)
640
- inter_step_mode = getattr(p, '_nrs_inter_step_mode', 'Off')
641
- momentum = getattr(p, '_nrs_momentum', 0.0)
642
- ge_gamma = getattr(p, '_nrs_ge_gamma', 1.0)
643
- detail_boost = getattr(p, '_nrs_detail_boost', 0.0)
644
- spectral_mod = getattr(p, '_nrs_spectral_mod', 0.0)
645
- spectral_pct = getattr(p, '_nrs_spectral_pct', 5.0)
646
- uncond_noise = getattr(p, '_nrs_uncond_noise', 0.0)
647
- uncond_scale = getattr(p, '_nrs_uncond_scale', 1.0)
648
-
649
- # Inter-step state
650
- prev_results = getattr(p, '_nrs_prev_results', {})
651
- prev_diffs = getattr(p, '_nrs_prev_diffs', {})
652
-
653
- # ── Prepare tensors ───────────────────────────────────────────────────
654
- denoised_uncond = x_out[-uncond.shape[0]:]
655
- denoised = torch.clone(denoised_uncond)
656
- x_orig_uncond = p._nrs_current_x_in[-uncond.shape[0]:]
657
-
658
- # ── Main per-item loop ────────────────────────────────────────────────
659
- for i, conds in enumerate(conds_list):
660
- for idx, (cond_index, weight) in enumerate(conds):
661
- current_cond = x_out[cond_index]
662
- if idx != 0:
663
- denoised[i] += (current_cond - denoised_uncond[i]) * (weight * cond_scale)
664
- continue
665
-
666
- x_orig_i = x_orig_uncond[i].unsqueeze(0)
667
- c_in = current_cond.unsqueeze(0) # original, before any modifications
668
- u_in = denoised_uncond[i].unsqueeze(0)
669
-
670
- # 1. Uncond modifications
671
- if uncond_noise > 0.0 or uncond_scale != 1.0:
672
- u_in = apply_uncond_modifications(u_in, uncond_noise, uncond_scale)
673
-
674
- # 2. Spectral modulation on noise_pred BEFORE NRS
675
- # Applied to c_in_mod only — original c_in kept for blend_phi and variance_phi
676
- c_in_for_nrs = c_in
677
- if spectral_mod > 0.0:
678
- noise_pred = c_in - u_in
679
- noise_pred_mod = apply_spectral_modulation(noise_pred, spectral_mod, spectral_pct)
680
- c_in_for_nrs = u_in + noise_pred_mod
681
-
682
- # 3. Core NRS computation
683
- nrs_result = calc_nrs_midpoint_refined(
684
- x_orig_i, c_in_for_nrs, u_in, sigma,
685
- skew, stretch, squash,
686
- refine_blend=refine_blend,
687
- first_half_only=refine_first_half,
688
- current_step=current_step,
689
- total_steps=total_steps,
690
- use_per_channel=per_channel,
691
- use_ad_norm=use_ad_norm,
692
- )
693
-
694
- # 4. Variance preserving rescale — use ORIGINAL c_in as reference
695
- if variance_phi > 0.0:
696
- nrs_result = apply_variance_preserving_rescale(nrs_result, c_in, variance_phi)
697
-
698
- # 5. Drift correction
699
- if drift_intensity > 0.0:
700
- nrs_result = apply_nrs_drift_correction(nrs_result, drift_intensity, drift_method)
701
-
702
- # 6. Blend phi (NRS ↔ plain CFG) — plain_cfg uses ORIGINAL c_in
703
- if blend_phi < 1.0:
704
- plain_cfg = u_in + (c_in - u_in) * cond_scale
705
- nrs_result = apply_blend_phi(nrs_result, plain_cfg, blend_phi)
706
-
707
- # 7. Inter-step: Momentum or GE-Gamma
708
- prev_r = prev_results.get(i, None)
709
- if inter_step_mode == 'Momentum' and momentum > 0.0:
710
- # Full RES/Clybius momentum with velocity tracking
711
- prev_vel = prev_diffs.get(i, None)
712
- nrs_result, new_vel = apply_nrs_momentum(nrs_result, prev_r, prev_vel, momentum)
713
- if new_vel is not None:
714
- prev_diffs[i] = new_vel.detach().clone()
715
- elif inter_step_mode == 'GE-Gamma' and ge_gamma != 1.0:
716
- prev_d = prev_diffs.get(i, None)
717
- # Save RAW diff BEFORE extrapolation (this is old_d for next step)
718
- if prev_r is not None:
719
- raw_diff = (nrs_result - prev_r).detach().clone()
720
- nrs_result = apply_nrs_ge_extrapolation(nrs_result, prev_r, prev_d, ge_gamma)
721
- if prev_r is not None:
722
- prev_diffs[i] = raw_diff # store pre-extrapolation diff
723
-
724
- # Update inter-step state
725
- prev_results[i] = nrs_result.detach().clone()
726
-
727
- # 8. Detail boost
728
- if detail_boost > 0.0:
729
- nrs_result = apply_nrs_detail_boost(nrs_result, progress, detail_boost)
730
-
731
- # 9. Output clamp
732
- if output_clamp > 0.0:
733
- nrs_result = apply_nrs_output_clamp(nrs_result, sigma, output_clamp)
734
-
735
- # Write result
736
- if len(conds) == 1:
737
- denoised[i] = nrs_result.squeeze(0)
738
- else:
739
- delta = nrs_result.squeeze(0) - denoised_uncond[i]
740
- denoised[i] += delta * weight
741
-
742
- # Save inter-step state
743
- p._nrs_prev_results = prev_results
744
- p._nrs_prev_diffs = prev_diffs
745
-
746
- return denoised
747
-
748
- except Exception as e:
749
- print(f"!!! NRS Enhanced Error (Fallback): {e}")
750
- import traceback
751
- traceback.print_exc()
752
- return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(
753
- self, x_out, conds_list, uncond, cond_scale)
754
-
755
-
756
- # ==============================================================================
757
- # ЧАСТЬ 6: UI
758
- # ==============================================================================
759
-
760
- class NRSScript(scripts.Script):
761
- def title(self):
762
- return "NRS + Kohaku Enhanced"
763
-
764
- def show(self, is_img2img):
765
- return scripts.AlwaysVisible
766
-
767
- def ui(self, is_img2img):
768
- with gr.Accordion("NRS + Kohaku Enhanced", open=False):
769
- enabled = gr.Checkbox(label="Включить NRS (Enable)", value=False)
770
-
771
- # ── Инструкция ────────────────────────────────────────────────────
772
- with gr.Accordion("❓ Инструкция / Help", open=False):
773
- gr.Markdown("""
774
- ### NRS + Kohaku Enhanced v2.0
775
-
776
- **NRS (Negative Rejection Steering)** — замена стандартному CFG с 3 параметрами:
777
- - **Skew** — отталкивание от Negative prompt (аналог силы CFG для структуры). Старт: 3–5
778
- - **Stretch** — притяжение к Positive prompt (усиление цветов/стиля). Старт: 2–7
779
- - **Squash** — ограничитель (0=максимум, 1=мягко+детали). Старт: 0.0
780
-
781
- ### 🔮 Midpoint Refinement (исправленный Kohaku)
782
- Правильная адаптация Kohaku_LoNyu_Yog: вычисляет NRS в промежуточной точке и усредняет результаты (Runge-Kutta 2-го порядка). Даёт более точное направление к целевой области.
783
-
784
- ### 📐 Scheduling
785
- - **Individual Curves**: каждый параметр меняется по своей кривой (Linear/Cosine/Power/...)
786
- - **CADS Anneal**: NRS нарастает через несколько шагов (tau1/tau2 трапеция)
787
- - **Adaptive Phases**: автоматические фазы Euler→DPM→Detail
788
-
789
- ### 🔬 Advanced Math
790
- - **Per-Channel**: независимая обработка каждого латентного канала
791
- - **AD Norm**: Absolute Deviation вместо L2 (устойчивее к выбросам)
792
- - **Blend Phi**: смешение NRS↔CFG (1.0=чистый NRS, 0.0=чистый CFG)
793
- - **Variance Phi**: сохранение дисперсии после NRS
794
-
795
- ### 🔁 Inter-Step
796
- - **Momentum**: сглаживание NRS-векторов между шагами
797
- - **GE-Gamma**: экстраполяция направления (>1 усиливает тренд)
798
- """)
799
-
800
- # ── Основные параметры ────────────────────────────────────────────
801
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
802
- " font-size:0.9em; opacity:0.8;'>Основные параметры</div>")
803
- with gr.Row():
804
- skew = gr.Slider(label="Skew (Композиция)", minimum=-30.0, maximum=30.0,
805
- step=0.05, value=4.0,
806
- info="Отклонение от Neg prompt. Рекомендуется: 3–5")
807
- stretch = gr.Slider(label="Stretch (Цвета/Стиль)", minimum=-30.0, maximum=30.0,
808
- step=0.05, value=2.0,
809
- info="Притяжение к Pos prompt. Рекомендуется: 2–7")
810
- squash = gr.Slider(label="Squash (Защита от пережарки)", minimum=0.0, maximum=1.0,
811
- step=0.01, value=0.0,
812
- info="0=максимальный эффект, 1=больше деталей/мягче")
813
-
814
- # ── Midpoint Refinement ───────────────────────────────────────────
815
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
816
- " font-size:0.9em; opacity:0.8;'>🔮 Midpoint Refinement (Kohaku)</div>")
817
- refine_blend = gr.Slider(
818
- label="Refinement Blend", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
819
- info="0=выкл, 0.5=рекомендуется. Runge-Kutta уточнение NRS-вектора")
820
- refine_first_half = gr.Checkbox(
821
- label="Only first half of steps (как в оригинале Kohaku)",
822
- value=True,
823
- info="Применять refinement только на первой половине шагов")
824
-
825
- # ── Scheduling ────────────────────────────────────────────────────
826
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
827
- " font-size:0.9em; opacity:0.8;'>📐 Parameter Scheduling</div>")
828
- sched_mode = gr.Radio(
829
- label="Режим расписания", choices=SCHED_MODES, value="Off")
830
-
831
- with gr.Group(visible=False) as curves_group:
832
- gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
833
- "Кривые применяются к базовым значениям независимо</div>")
834
- with gr.Row():
835
- skew_curve = gr.Dropdown(label="Skew Curve", choices=CURVE_CHOICES,
836
- value="Constant")
837
- skew_curve_min = gr.Slider(label="Skew Min", minimum=-30.0, maximum=30.0,
838
- step=0.05, value=0.0)
839
- with gr.Row():
840
- stretch_curve = gr.Dropdown(label="Stretch Curve", choices=CURVE_CHOICES,
841
- value="Constant")
842
- stretch_curve_min = gr.Slider(label="Stretch Min", minimum=-30.0, maximum=30.0,
843
- step=0.05, value=0.0)
844
- with gr.Row():
845
- squash_curve = gr.Dropdown(label="Squash Curve", choices=CURVE_CHOICES,
846
- value="Constant")
847
- squash_curve_min = gr.Slider(label="Squash Min", minimum=0.0, maximum=1.0,
848
- step=0.01, value=0.0)
849
- sched_val = gr.Slider(
850
- label="Sched Value (для Power/Repeating кривых)",
851
- minimum=0.1, maximum=8.0, step=0.1, value=2.0)
852
-
853
- with gr.Group(visible=False) as cads_group:
854
- gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
855
- "Трапецеидальное нарастание силы NRS. "
856
- "tau1=0.6, tau2=0.9: NRS включается на ~10% шагов, "
857
- "полная сила с ~40%</div>")
858
- with gr.Row():
859
- cads_tau1 = gr.Slider(label="Tau 1 (полная сила)", minimum=0.0, maximum=1.0,
860
- step=0.05, value=0.6)
861
- cads_tau2 = gr.Slider(label="Tau 2 (начало нарастания)", minimum=0.0, maximum=1.0,
862
- step=0.05, value=0.9)
863
-
864
- with gr.Group(visible=False) as adaptive_group:
865
- gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
866
- "Euler Phase: макс. Skew. "
867
- "DPM Phase: переход. "
868
- "Detail Phase: минимум Skew, максимум Squash</div>")
869
- with gr.Row():
870
- adaptive_euler_end = gr.Slider(label="Euler Phase End", minimum=0.0, maximum=1.0,
871
- step=0.05, value=0.35)
872
- adaptive_dpm_end = gr.Slider(label="DPM Phase End", minimum=0.0, maximum=1.0,
873
- step=0.05, value=0.70)
874
-
875
- def update_sched_groups(mode):
876
- return {
877
- curves_group: gr.update(visible=(mode == "Individual Curves")),
878
- cads_group: gr.update(visible=(mode == "CADS Anneal")),
879
- adaptive_group: gr.update(visible=(mode == "Adaptive Phases")),
880
- }
881
-
882
- sched_mode.change(fn=update_sched_groups, inputs=[sched_mode],
883
- outputs=[curves_group, cads_group, adaptive_group])
884
-
885
- # ── Advanced Math ─────────────────────────────────────────────────
886
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
887
- " font-size:0.9em; opacity:0.8;'>🔬 Advanced Math</div>")
888
- with gr.Row():
889
- per_channel = gr.Checkbox(
890
- label="Per-Channel Processing",
891
- value=False,
892
- info="Обрабатывать каждый латентный канал независимо")
893
- ad_norm = gr.Checkbox(
894
- label="AD Normalization",
895
- value=False,
896
- info="Absolute Deviation вместо L2 (устойчивее к выбросам)")
897
- with gr.Row():
898
- blend_phi = gr.Slider(
899
- label="Blend Phi (NRS↔CFG)", minimum=0.0, maximum=1.0, step=0.01, value=1.0,
900
- info="1.0=чистый NRS, 0.0=чистый CFG, между — смесь")
901
- variance_phi = gr.Slider(
902
- label="Variance Rescale Phi", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
903
- info="0=выкл. Нормирует дисперсию NRS-результата к дисперсии cond")
904
-
905
- # ── Post-Processing ───────────────────────────────────────────────
906
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
907
- " font-size:0.9em; opacity:0.8;'>📡 Post-Processing</div>")
908
- with gr.Row():
909
- drift_intensity = gr.Slider(
910
- label="Drift Correction", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
911
- info="Убирает смещение mean/median от высокого CFG")
912
- drift_method = gr.Radio(
913
- label="Метод", choices=DRIFT_METHODS, value="mean")
914
- output_clamp = gr.Slider(
915
- label="Output Clamp (0=выкл)", minimum=0.0, maximum=200.0, step=0.5, value=0.0,
916
- info="Адаптивное ограничение экстремальных значений. threshold = clamp*(1+sigma/10)")
917
-
918
- # ── Inter-Step ────────────────────────────────────────────────────
919
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
920
- " font-size:0.9em; opacity:0.8;'>🔁 Inter-Step</div>")
921
- inter_step_mode = gr.Radio(
922
- label="Режим", choices=INTER_STEP_MODES, value="Off")
923
- with gr.Row():
924
- momentum_slider = gr.Slider(
925
- label="Momentum", minimum=0.0, maximum=0.95, step=0.01, value=0.5,
926
- visible=False,
927
- info="Сглаживание NRS между шагами (0=выкл, 0.5=рекомендуется)")
928
- ge_gamma_slider = gr.Slider(
929
- label="GE Gamma", minimum=0.1, maximum=4.0, step=0.05, value=1.5,
930
- visible=False,
931
- info=">1=экстраполяция тренда, 1=стандарт, <1=сглаживание")
932
-
933
- def update_inter_step(mode):
934
- return {
935
- momentum_slider: gr.update(visible=(mode == "Momentum")),
936
- ge_gamma_slider: gr.update(visible=(mode == "GE-Gamma")),
937
- }
938
-
939
- inter_step_mode.change(fn=update_inter_step, inputs=[inter_step_mode],
940
- outputs=[momentum_slider, ge_gamma_slider])
941
-
942
- # ── Enhancements ──────────────────────────────────────────────────
943
- gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
944
- " font-size:0.9em; opacity:0.8;'>✨ Enhancements</div>")
945
- with gr.Row():
946
- detail_boost = gr.Slider(
947
- label="Detail Boost (0=выкл)", minimum=0.0, maximum=3.0, step=0.05, value=0.0,
948
- info="Усиление высокочастотных деталей на поздних шагах")
949
- spectral_mod = gr.Slider(
950
- label="Spectral Modulation (0=выкл)", minimum=0.0, maximum=2.0, step=0.05, value=0.0,
951
- info="FFT-коррекция частот noise_pred перед NRS")
952
- spectral_pct = gr.Slider(
953
- label="Spectral Percentile", minimum=1.0, maximum=20.0, step=0.5, value=5.0,
954
- info="Процентиль для частотной маски (меньше = агрессивнее)")
955
- with gr.Row():
956
- uncond_noise = gr.Slider(
957
- label="Uncond Noise (0=выкл)", minimum=0.0, maximum=0.5, step=0.01, value=0.0,
958
- info="Добавить шум к uncond (увеличивает разнообразие)")
959
- uncond_scale = gr.Slider(
960
- label="Uncond Scale", minimum=0.1, maximum=2.0, step=0.01, value=1.0,
961
- info="Масштаб uncond (1.0=стандарт, <1=ослабить neg)")
962
-
963
- # ── Step Control ──────────────────────────────────────────────────
964
- with gr.Accordion("⏱️ Step Control", open=False):
965
- with gr.Row():
966
- step_control_enabled = gr.Checkbox(label="Включить Step Control", value=False)
967
- step_control_mode = gr.Radio(
968
- label="Режим", choices=["Global", "Individual"], value="Global")
969
-
970
- with gr.Group(visible=True) as global_group:
971
- gr.HTML("<div style='font-weight:bold; margin:0.4em 0;'>Глобальные настройки</div>")
972
- global_step_mode = gr.Radio(
973
- label="Режим шагов",
974
- choices=["Absolute Steps", "Fraction of Steps"],
975
- value="Absolute Steps")
976
- with gr.Row():
977
- global_start_step = gr.Slider(label="Start Step", minimum=0,
978
- maximum=150, step=1, value=0, visible=True)
979
- global_end_step = gr.Slider(label="End Step (0=конец)", minimum=0,
980
- maximum=150, step=1, value=0, visible=True)
981
- with gr.Row():
982
- global_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
983
- maximum=1.0, step=0.01, value=0.0, visible=False)
984
- global_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
985
- maximum=1.0, step=0.01, value=1.0, visible=False)
986
-
987
- with gr.Group(visible=False) as individual_group:
988
- gr.HTML("<div style='font-weight:bold; margin:0.4em 0;'>Индивидуальные настройки</div>")
989
- with gr.Accordion("Skew — Step Settings", open=False):
990
- skew_step_enabled = gr.Checkbox(label="Включить для Skew", value=True)
991
- skew_step_mode = gr.Radio(label="Режим",
992
- choices=["Absolute Steps", "Fraction of Steps"],
993
- value="Absolute Steps")
994
- with gr.Row():
995
- skew_start_step = gr.Slider(label="Start Step", minimum=0,
996
- maximum=150, step=1, value=0, visible=True)
997
- skew_end_step = gr.Slider(label="End Step", minimum=0,
998
- maximum=150, step=1, value=0, visible=True)
999
- with gr.Row():
1000
- skew_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1001
- maximum=1.0, step=0.01, value=0.0, visible=False)
1002
- skew_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1003
- maximum=1.0, step=0.01, value=1.0, visible=False)
1004
- with gr.Accordion("Stretch — Step Settings", open=False):
1005
- stretch_step_enabled = gr.Checkbox(label="Включить для Stretch", value=True)
1006
- stretch_step_mode = gr.Radio(label="Режим",
1007
- choices=["Absolute Steps", "Fraction of Steps"],
1008
- value="Absolute Steps")
1009
- with gr.Row():
1010
- stretch_start_step = gr.Slider(label="Start Step", minimum=0,
1011
- maximum=150, step=1, value=0, visible=True)
1012
- stretch_end_step = gr.Slider(label="End Step", minimum=0,
1013
- maximum=150, step=1, value=0, visible=True)
1014
- with gr.Row():
1015
- stretch_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1016
- maximum=1.0, step=0.01, value=0.0, visible=False)
1017
- stretch_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1018
- maximum=1.0, step=0.01, value=1.0, visible=False)
1019
- with gr.Accordion("Squash — Step Settings", open=False):
1020
- squash_step_enabled = gr.Checkbox(label="Включить для Squash", value=True)
1021
- squash_step_mode = gr.Radio(label="Режим",
1022
- choices=["Absolute Steps", "Fraction of Steps"],
1023
- value="Absolute Steps")
1024
- with gr.Row():
1025
- squash_start_step = gr.Slider(label="Start Step", minimum=0,
1026
- maximum=150, step=1, value=0, visible=True)
1027
- squash_end_step = gr.Slider(label="End Step", minimum=0,
1028
- maximum=150, step=1, value=0, visible=True)
1029
- with gr.Row():
1030
- squash_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1031
- maximum=1.0, step=0.01, value=0.0, visible=False)
1032
- squash_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1033
- maximum=1.0, step=0.01, value=1.0, visible=False)
1034
-
1035
- def update_sc_groups(mode):
1036
- return {
1037
- global_group: gr.update(visible=(mode == "Global")),
1038
- individual_group: gr.update(visible=(mode == "Individual")),
1039
- }
1040
-
1041
- step_control_mode.change(fn=update_sc_groups, inputs=[step_control_mode],
1042
- outputs=[global_group, individual_group])
1043
-
1044
- def _tog(mode):
1045
- a = mode == "Absolute Steps"
1046
- return (gr.update(visible=a), gr.update(visible=a),
1047
- gr.update(visible=not a), gr.update(visible=not a))
1048
-
1049
- global_step_mode.change(fn=_tog, inputs=[global_step_mode],
1050
- outputs=[global_start_step, global_end_step,
1051
- global_start_frac, global_end_frac])
1052
- skew_step_mode.change(fn=_tog, inputs=[skew_step_mode],
1053
- outputs=[skew_start_step, skew_end_step,
1054
- skew_start_frac, skew_end_frac])
1055
- stretch_step_mode.change(fn=_tog, inputs=[stretch_step_mode],
1056
- outputs=[stretch_start_step, stretch_end_step,
1057
- stretch_start_frac, stretch_end_frac])
1058
- squash_step_mode.change(fn=_tog, inputs=[squash_step_mode],
1059
- outputs=[squash_start_step, squash_end_step,
1060
- squash_start_frac, squash_end_frac])
1061
-
1062
- return [
1063
- # Core
1064
- enabled, skew, stretch, squash,
1065
- # Midpoint Refinement
1066
- refine_blend, refine_first_half,
1067
- # Scheduling
1068
- sched_mode,
1069
- skew_curve, skew_curve_min,
1070
- stretch_curve, stretch_curve_min,
1071
- squash_curve, squash_curve_min,
1072
- sched_val,
1073
- cads_tau1, cads_tau2,
1074
- adaptive_euler_end, adaptive_dpm_end,
1075
- # Advanced Math
1076
- per_channel, ad_norm,
1077
- blend_phi, variance_phi,
1078
- # Post-Processing
1079
- drift_intensity, drift_method,
1080
- output_clamp,
1081
- # Inter-Step
1082
- inter_step_mode, momentum_slider, ge_gamma_slider,
1083
- # Enhancements
1084
- detail_boost,
1085
- spectral_mod, spectral_pct,
1086
- uncond_noise, uncond_scale,
1087
- # Step Control
1088
- step_control_enabled, step_control_mode,
1089
- global_step_mode, global_start_step, global_end_step,
1090
- global_start_frac, global_end_frac,
1091
- skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step,
1092
- skew_start_frac, skew_end_frac,
1093
- stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step,
1094
- stretch_start_frac, stretch_end_frac,
1095
- squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step,
1096
- squash_start_frac, squash_end_frac,
1097
- ]
1098
-
1099
- def process(self, p,
1100
- # Core
1101
- enabled, skew, stretch, squash,
1102
- # Midpoint Refinement
1103
- refine_blend, refine_first_half,
1104
- # Scheduling
1105
- sched_mode,
1106
- skew_curve, skew_curve_min,
1107
- stretch_curve, stretch_curve_min,
1108
- squash_curve, squash_curve_min,
1109
- sched_val,
1110
- cads_tau1, cads_tau2,
1111
- adaptive_euler_end, adaptive_dpm_end,
1112
- # Advanced Math
1113
- per_channel, ad_norm,
1114
- blend_phi, variance_phi,
1115
- # Post-Processing
1116
- drift_intensity, drift_method,
1117
- output_clamp,
1118
- # Inter-Step
1119
- inter_step_mode, momentum, ge_gamma,
1120
- # Enhancements
1121
- detail_boost,
1122
- spectral_mod, spectral_pct,
1123
- uncond_noise, uncond_scale,
1124
- # Step Control
1125
- step_control_enabled, step_control_mode,
1126
- global_step_mode, global_start_step, global_end_step,
1127
- global_start_frac, global_end_frac,
1128
- skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step,
1129
- skew_start_frac, skew_end_frac,
1130
- stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step,
1131
- stretch_start_frac, stretch_end_frac,
1132
- squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step,
1133
- squash_start_frac, squash_end_frac):
1134
-
1135
- p._nrs_enabled = enabled
1136
- if not enabled:
1137
- return
1138
-
1139
- # Core params
1140
- p._nrs_params = (skew, stretch, squash)
1141
-
1142
- # Midpoint Refinement
1143
- p._nrs_refine_blend = refine_blend
1144
- p._nrs_refine_first_half = refine_first_half
1145
-
1146
- # Scheduling
1147
- p._nrs_sched_mode = sched_mode
1148
- p._nrs_skew_curve = skew_curve
1149
- p._nrs_skew_curve_min = skew_curve_min
1150
- p._nrs_stretch_curve = stretch_curve
1151
- p._nrs_stretch_curve_min = stretch_curve_min
1152
- p._nrs_squash_curve = squash_curve
1153
- p._nrs_squash_curve_min = squash_curve_min
1154
- p._nrs_sched_val = sched_val
1155
- p._nrs_cads_tau1 = cads_tau1
1156
- p._nrs_cads_tau2 = cads_tau2
1157
- p._nrs_adaptive_euler_end = adaptive_euler_end
1158
- p._nrs_adaptive_dpm_end = adaptive_dpm_end
1159
-
1160
- # Advanced Math
1161
- p._nrs_per_channel = per_channel
1162
- p._nrs_ad_norm = ad_norm
1163
- p._nrs_blend_phi = blend_phi
1164
- p._nrs_variance_phi = variance_phi
1165
-
1166
- # Post-Processing
1167
- p._nrs_drift_intensity = drift_intensity
1168
- p._nrs_drift_method = drift_method
1169
- p._nrs_output_clamp = output_clamp
1170
-
1171
- # Inter-Step
1172
- p._nrs_inter_step_mode = inter_step_mode
1173
- p._nrs_momentum = momentum
1174
- p._nrs_ge_gamma = ge_gamma
1175
- p._nrs_prev_results = {}
1176
- p._nrs_prev_diffs = {}
1177
-
1178
- # Enhancements
1179
- p._nrs_detail_boost = detail_boost
1180
- p._nrs_spectral_mod = spectral_mod
1181
- p._nrs_spectral_pct = spectral_pct
1182
- p._nrs_uncond_noise = uncond_noise
1183
- p._nrs_uncond_scale = uncond_scale
1184
-
1185
- # Step Control
1186
- p._nrs_step_control_enabled = step_control_enabled
1187
- p._nrs_step_control_mode = step_control_mode
1188
- p._nrs_global_step_settings = {
1189
- 'step_mode': global_step_mode,
1190
- 'start_step': global_start_step,
1191
- 'end_step': global_end_step,
1192
- 'start_frac': global_start_frac,
1193
- 'end_frac': global_end_frac,
1194
- }
1195
- p._nrs_individual_step_settings = {
1196
- 'skew': {
1197
- 'enabled': skew_step_enabled, 'step_mode': skew_step_mode,
1198
- 'start_step': skew_start_step, 'end_step': skew_end_step,
1199
- 'start_frac': skew_start_frac, 'end_frac': skew_end_frac,
1200
- },
1201
- 'stretch': {
1202
- 'enabled': stretch_step_enabled, 'step_mode': stretch_step_mode,
1203
- 'start_step': stretch_start_step, 'end_step': stretch_end_step,
1204
- 'start_frac': stretch_start_frac, 'end_frac': stretch_end_frac,
1205
- },
1206
- 'squash': {
1207
- 'enabled': squash_step_enabled, 'step_mode': squash_step_mode,
1208
- 'start_step': squash_start_step, 'end_step': squash_end_step,
1209
- 'start_frac': squash_start_frac, 'end_frac': squash_end_frac,
1210
- },
1211
- }
1212
-
1213
- p._nrs_current_step = 0
1214
- sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised = hijacked_combine_denoised
1215
-
1216
- def postprocess(self, p, processed, *args):
1217
- # Clean up inter-step state to avoid memory leaks between generations
1218
- for attr in ('_nrs_prev_results', '_nrs_prev_diffs'):
1219
- if hasattr(p, attr):
1220
- delattr(p, attr)