from __future__ import annotations import pytest from src import config from src.errors import UserFacingError from src.validation import validate_and_clamp def test_defaults_match_distill8() -> None: p = validate_and_clamp( prompt="a test prompt", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, ) assert p.width == 1024 assert p.height == 1024 assert p.steps == 8 assert p.cfg == 1.0 assert p.sampler_name == "euler_ancestral" assert p.scheduler == "beta" assert p.denoise == 1.0 assert p.negative_prompt == "" def test_clamp_width_height() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=100, height=3000, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, ) assert p.width == 512 assert p.height == 2048 assert p.warnings def test_euler_a_alias() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_a", scheduler="beta", denoise=1.0, ) assert p.sampler_name == "euler_ancestral" assert any("euler" in w.lower() for w in p.warnings) def test_empty_prompt_rejected() -> None: with pytest.raises(UserFacingError): validate_and_clamp( prompt=" ", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, ) def test_cfg_respects_step() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.23, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, ) assert abs(p.cfg - 1.2) < 0.01 def test_cfg_clamped_to_distill_max() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=2.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, ) assert p.cfg == 1.5 assert any("cfg" in w.lower() for w in p.warnings) def test_fixed_seed_is_used() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, seed=12345, randomize_seed=False, ) assert p.seed == 12345 def test_randomized_seed_is_in_range() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, seed=12345, randomize_seed=True, ) assert config.MIN_SEED <= p.seed <= config.MAX_SEED def test_seed_clamps_when_locked() -> None: p = validate_and_clamp( prompt="x", negative_prompt="", width=1024, height=1024, steps=8, cfg=1.0, batch_size=1, sampler_name="euler_ancestral", scheduler="beta", denoise=1.0, seed=config.MAX_SEED + 1, randomize_seed=False, ) assert p.seed == config.MAX_SEED assert any("seed" in w for w in p.warnings)