# Source: z-anime-distill8-gradio-zerogpu / tests / test_validation.py
# Last commit 610a02a (verified) by JSCPPProgrammer:
#   "Upload Z-Anime Distill-8 FP8 Gradio ZeroGPU Space"
from __future__ import annotations
import pytest
from src import config
from src.errors import UserFacingError
from src.validation import validate_and_clamp
def test_defaults_match_distill8() -> None:
    """Baseline Distill-8 settings should come back from validation unchanged."""
    request = {
        "prompt": "a test prompt",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "steps": 8,
        "cfg": 1.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "beta",
        "denoise": 1.0,
    }
    validated = validate_and_clamp(**request)

    # Each checked field must equal the value that was submitted for it.
    checked_fields = (
        "width",
        "height",
        "steps",
        "cfg",
        "sampler_name",
        "scheduler",
        "denoise",
        "negative_prompt",
    )
    for field_name in checked_fields:
        assert getattr(validated, field_name) == request[field_name]
def test_clamp_width_height() -> None:
    """Out-of-range dimensions are expected to be clamped and to emit a warning."""
    too_small_width = 100
    too_large_height = 3000
    clamped = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=too_small_width,
        height=too_large_height,
        steps=8,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="beta",
        denoise=1.0,
    )

    # The test expects 100 to be raised to 512 and 3000 to be lowered to 2048.
    assert clamped.width == 512
    assert clamped.height == 2048
    # At least one warning must accompany the adjustment.
    assert clamped.warnings
def test_euler_a_alias() -> None:
    """The ``euler_a`` sampler name should resolve to ``euler_ancestral`` with a warning."""
    request = dict(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=8,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_a",
        scheduler="beta",
        denoise=1.0,
    )
    resolved = validate_and_clamp(**request)

    assert resolved.sampler_name == "euler_ancestral"
    # Some warning has to mention the sampler, compared case-insensitively.
    assert any("euler" in message.lower() for message in resolved.warnings)
def test_empty_prompt_rejected() -> None:
    """A prompt consisting of a single space must raise ``UserFacingError``."""
    blank_prompt_request = {
        "prompt": " ",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "steps": 8,
        "cfg": 1.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "beta",
        "denoise": 1.0,
    }
    # Only the validation call sits inside the context manager.
    with pytest.raises(UserFacingError):
        validate_and_clamp(**blank_prompt_request)
def test_cfg_respects_step() -> None:
    """A CFG of 1.23 is expected to come back as approximately 1.2."""
    request = dict(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=8,
        cfg=1.23,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="beta",
        denoise=1.0,
    )
    snapped = validate_and_clamp(**request)

    # Compare with a tolerance rather than exact float equality.
    deviation = abs(snapped.cfg - 1.2)
    assert deviation < 0.01
def test_cfg_clamped_to_distill_max() -> None:
    """A CFG of 2.0 is expected to be reduced to 1.5 and reported in a warning."""
    request = {
        "prompt": "x",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "steps": 8,
        "cfg": 2.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "beta",
        "denoise": 1.0,
    }
    limited = validate_and_clamp(**request)

    assert limited.cfg == 1.5
    # Some warning has to mention cfg, compared case-insensitively.
    assert any("cfg" in message.lower() for message in limited.warnings)
def test_fixed_seed_is_used() -> None:
    """With ``randomize_seed=False`` the supplied seed must be kept as-is."""
    requested_seed = 12345
    request = dict(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=8,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="beta",
        denoise=1.0,
        seed=requested_seed,
        randomize_seed=False,
    )
    outcome = validate_and_clamp(**request)

    assert outcome.seed == 12345
def test_randomized_seed_is_in_range() -> None:
    """With ``randomize_seed=True`` the resulting seed must lie within the config bounds."""
    request = {
        "prompt": "x",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "steps": 8,
        "cfg": 1.0,
        "batch_size": 1,
        "sampler_name": "euler_ancestral",
        "scheduler": "beta",
        "denoise": 1.0,
        "seed": 12345,
        "randomize_seed": True,
    }
    outcome = validate_and_clamp(**request)

    # Inclusive check against both configured limits.
    assert config.MIN_SEED <= outcome.seed <= config.MAX_SEED
def test_seed_clamps_when_locked() -> None:
    """A fixed seed one above ``config.MAX_SEED`` must be clamped and warned about.

    The warning check lowercases each message before matching, as the sampler
    and cfg warning tests in this module do, so the test does not depend on
    how the warning text is capitalised.
    """
    p = validate_and_clamp(
        prompt="x",
        negative_prompt="",
        width=1024,
        height=1024,
        steps=8,
        cfg=1.0,
        batch_size=1,
        sampler_name="euler_ancestral",
        scheduler="beta",
        denoise=1.0,
        seed=config.MAX_SEED + 1,
        randomize_seed=False,
    )
    assert p.seed == config.MAX_SEED
    # Case-insensitive match: a message beginning "Seed ..." must also count.
    assert any("seed" in w.lower() for w in p.warnings)