from __future__ import annotations

import math
import random
from typing import Any

from src import config
from src.config import GenerationParams
from src.errors import UserFacingError

# Maximum length (in characters, measured after stripping) of the prompt and
# of the negative prompt.
_MAX_PROMPT_CHARS = 20_000


def _parse_number(value: Any) -> int | float:
    """Parse *value* as a number, keeping integral inputs exact.

    Python ``int`` values (including ``bool``) and strings that parse as
    integer literals are returned as ``int`` without a float round-trip, so
    integers above 2**53 (e.g. large seeds) keep every digit. Anything else
    goes through ``float()`` and may come back as NaN or +/-inf; callers are
    responsible for rejecting non-finite results.

    Raises:
        TypeError, ValueError, OverflowError: *value* cannot be converted.
    """
    if isinstance(value, int):
        return int(value)  # int() also normalizes True/False to 1/0
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass  # not an integer literal ("12.6", "1e3", "abc"): try float below
    return float(value)


def _rounded_int(name: str, number: int | float) -> int:
    """Return *number* as an int, rounding floats half-to-even like ``round``.

    Raises:
        UserFacingError: *number* is NaN or infinite (``round`` would otherwise
            raise ValueError / OverflowError).
    """
    if isinstance(number, int):
        return number
    if not math.isfinite(number):
        raise UserFacingError(f"Invalid {name!r}: must be a finite number.")
    return int(round(number))


def _as_int(name: str, value: Any, lo: int, hi: int) -> int:
    """Strictly parse *value* as an int in ``[lo, hi]``.

    Non-integral numbers are rounded half-to-even. Out-of-range values are
    rejected (not clamped; see ``_clamp_int`` for the clamping variant).

    Raises:
        UserFacingError: *value* is None, not numeric, not finite, or outside
            ``[lo, hi]``.
    """
    if value is None:
        raise UserFacingError(f"Missing value for {name!r}.")
    try:
        number = _parse_number(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise UserFacingError(
            f"Invalid {name!r}: expected a number in [{lo}, {hi}].", details=str(value)
        ) from err
    n = _rounded_int(name, number)
    if n < lo or n > hi:
        raise UserFacingError(
            f"Invalid {name!r}: {n} is outside the allowed range [{lo}, {hi}]."
        )
    return n


def _as_float(name: str, value: Any, lo: float, hi: float) -> float:
    """Strictly parse *value* as a finite float in ``[lo, hi]``.

    Raises:
        UserFacingError: *value* is None, not numeric, not finite, or outside
            ``[lo, hi]``.
    """
    if value is None:
        raise UserFacingError(f"Missing value for {name!r}.")
    try:
        n = float(value)
    except (TypeError, ValueError, OverflowError) as err:
        # OverflowError: e.g. float(10**400).
        raise UserFacingError(
            f"Invalid {name!r}: expected a number in [{lo}, {hi}].", details=str(value)
        ) from err
    if not math.isfinite(n):
        raise UserFacingError(f"Invalid {name!r}: must be a finite number.")
    if n < lo or n > hi:
        raise UserFacingError(
            f"Invalid {name!r}: {n} is outside the allowed range [{lo}, {hi}]."
        )
    return n


def _clamp_int(name: str, value: Any, lo: int, hi: int) -> tuple[int, str | None]:
    """Return clamped int and optional warning if clamped from out-of-range input.

    Non-integral numbers are rounded half-to-even; integral inputs (ints and
    integer strings) are used exactly, without a float round-trip.

    Returns:
        ``(value, warning)`` where *warning* is None unless *value* was moved
        to *lo* or *hi*.

    Raises:
        UserFacingError: *value* is not numeric (including None) or not finite.
    """
    try:
        number = _parse_number(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise UserFacingError(
            f"Invalid {name!r}: expected a number; got {value!r}.", details=repr(value)
        ) from err
    n = _rounded_int(name, number)
    if n < lo:
        return lo, f"{name} was {n} (below minimum {lo}); using {lo}."
    if n > hi:
        return hi, f"{name} was {n} (above maximum {hi}); using {hi}."
    return n, None


def _clamp_float(
    name: str,
    value: Any,
    lo: float,
    hi: float,
    *,
    step: float | None = None,
    decimals: int | None = None,
) -> tuple[float, str | None]:
    """Parse *value* as a float, clamp it into ``[lo, hi]``, then optionally snap.

    Args:
        name: Parameter name used in messages.
        value: Raw user input.
        lo, hi: Inclusive bounds.
        step: If positive, snap to the nearest point of the grid
            ``lo + k * step``; the result is re-clamped because snapping can
            land above *hi* when *hi* is not on the grid.
        decimals: If given, round the final value to this many decimal places
            (this also removes float noise introduced by snapping).

    Returns:
        ``(value, warning)`` where *warning* is set only when the input was
        outside ``[lo, hi]``.

    Raises:
        UserFacingError: *value* is not numeric or not finite.
    """
    try:
        n = float(value)
    except (TypeError, ValueError, OverflowError) as err:
        # OverflowError: e.g. float(10**400).
        raise UserFacingError(
            f"Invalid {name!r}: expected a number; got {value!r}.", details=repr(value)
        ) from err
    if not math.isfinite(n):
        raise UserFacingError(f"Invalid {name!r}: must be a finite number.")

    warn: str | None = None
    if n < lo:
        warn = f"{name} was {n} (below minimum {lo}); using {lo}."
        n = lo
    elif n > hi:
        warn = f"{name} was {n} (above maximum {hi}); using {hi}."
        n = hi

    if step is not None and step > 0:
        # Snap to grid relative to lo, then keep the result inside the bounds.
        n = lo + round((n - lo) / step) * step
        n = min(hi, max(lo, n))
    if decimals is not None:
        n = round(n, decimals)
    return n, warn


def _normalize_sampler(name: str) -> tuple[str, str | None]:
    """Map a user-supplied sampler name to a supported sampler id.

    Returns:
        ``(sampler_id, warning)``. Exact matches against
        ``config.SAMPLER_CHOICES`` return no warning. Known "Euler A" spellings
        (case-insensitive) are remapped to ``euler_ancestral`` when that id is
        supported. Anything else falls back to ``config.DEFAULT_SAMPLER``.
    """
    s = (name or "").strip()
    if not s:
        return config.DEFAULT_SAMPLER, f"Empty sampler: using default {config.DEFAULT_SAMPLER!r}."
    if s in config.SAMPLER_CHOICES:
        return s, None
    # Common alias from Z-Image / Z-Anime docs. Compared case-insensitively so
    # "Euler A" also matches; only remap when the target is actually supported,
    # otherwise we would hand back an id outside the supported set.
    if s.lower() in ("euler_a", "euler-a", "euler a") and (
        "euler_ancestral" in config.SAMPLER_CHOICES
    ):
        return "euler_ancestral", (
            f"Sampler {name!r} is not a known id; remapped to 'euler_ancestral'."
        )
    return config.DEFAULT_SAMPLER, (
        f"Sampler {name!r} is not in the supported set for this Space; using "
        f"{config.DEFAULT_SAMPLER!r}. Supported examples: {', '.join(config.SAMPLER_CHOICES[:6])}…"
    )


def _normalize_scheduler(name: str) -> tuple[str, str | None]:
    """Map a user-supplied scheduler name to a supported scheduler id.

    Returns:
        ``(scheduler_id, warning)``. Exact matches against
        ``config.SCHEDULER_CHOICES`` return no warning; empty or unknown names
        fall back to ``config.DEFAULT_SCHEDULER`` with a warning.
    """
    s = (name or "").strip()
    if not s:
        return config.DEFAULT_SCHEDULER, f"Empty scheduler: using default {config.DEFAULT_SCHEDULER!r}."
    if s in config.SCHEDULER_CHOICES:
        return s, None
    return config.DEFAULT_SCHEDULER, (
        f"Scheduler {name!r} is not in the supported set for this Space; using "
        f"{config.DEFAULT_SCHEDULER!r}."
    )


def validate_and_clamp(
    *,
    prompt: str,
    negative_prompt: str | None,
    width: Any,
    height: Any,
    steps: Any,
    cfg: Any,
    batch_size: Any,
    sampler_name: str | None,
    scheduler: str | None,
    denoise: Any,
    seed: Any = None,
    randomize_seed: bool = True,
) -> GenerationParams:
    """
    Validate and clamp all user parameters; collect non-fatal warnings.

    Rejects unparseable types with clear messages.

    Numeric fields outside their configured ``config.MIN_*`` / ``config.MAX_*``
    range are clamped with a warning rather than rejected. ``cfg`` is snapped
    to 0.1 steps and ``denoise`` to 0.01 steps. When *randomize_seed* is true
    *seed* is ignored and a fresh seed is drawn from
    ``[config.MIN_SEED, config.MAX_SEED]``.

    Returns:
        A ``GenerationParams`` whose ``warnings`` tuple lists every adjustment
        made, in the order width, height, steps, cfg, denoise, batch_size,
        sampler, scheduler, seed.

    Raises:
        UserFacingError: empty or over-long prompt, over-long negative prompt,
            or a non-numeric / non-finite numeric field.
    """
    warnings: list[str] = []

    def collect(result: tuple[Any, str | None]) -> Any:
        """Unpack a ``(value, warning)`` pair, recording the warning if any."""
        value, warning = result
        if warning:
            warnings.append(warning)
        return value

    p = (prompt or "").strip()
    if not p:
        raise UserFacingError("Prompt must not be empty.")
    if len(p) > _MAX_PROMPT_CHARS:
        raise UserFacingError(f"Prompt is too long (max {_MAX_PROMPT_CHARS:,} characters).")
    neg = (negative_prompt or "").strip()
    if len(neg) > _MAX_PROMPT_CHARS:
        raise UserFacingError(
            f"Negative prompt is too long (max {_MAX_PROMPT_CHARS:,} characters)."
        )

    # Order of these calls determines the order of the collected warnings.
    w = collect(_clamp_int("width", width, config.MIN_WH, config.MAX_WH))
    h = collect(_clamp_int("height", height, config.MIN_WH, config.MAX_WH))
    st = collect(_clamp_int("steps", steps, config.MIN_STEPS, config.MAX_STEPS))
    cfg_v = collect(
        _clamp_float("cfg", cfg, config.MIN_CFG, config.MAX_CFG, step=0.1, decimals=1)
    )
    d_v = collect(
        _clamp_float(
            "denoise", denoise, config.MIN_DENOISE, config.MAX_DENOISE, step=0.01, decimals=2
        )
    )
    bs = collect(_clamp_int("batch_size", batch_size, config.MIN_BATCH, config.MAX_BATCH))
    sampler = collect(_normalize_sampler(sampler_name or ""))
    sched = collect(_normalize_scheduler(scheduler or ""))

    if randomize_seed:
        seed_value = random.randint(config.MIN_SEED, config.MAX_SEED)
    else:
        # Integral seeds are kept exact (no float round-trip) by _clamp_int.
        seed_value = collect(_clamp_int("seed", seed, config.MIN_SEED, config.MAX_SEED))

    return GenerationParams(
        prompt=p,
        negative_prompt=neg,
        width=w,
        height=h,
        steps=st,
        cfg=cfg_v,
        batch_size=bs,
        sampler_name=sampler,
        scheduler=sched,
        denoise=d_v,
        seed=seed_value,
        warnings=tuple(warnings),
    )