import pytest
from pydantic import ValidationError

from server.schemas import (
    ActiveModelStatus,
    ErrorBody,
    GenerateParams,
    HealthResponse,
    Lang,
    ModelInfo,
    ParamSpec,
)


def test_param_spec_float_with_bounds():
    """A float spec built with min/max/step exposes the default it was given."""
    spec = ParamSpec(
        name="exaggeration",
        label="Exaggeration",
        type="float",
        default=0.5,
        min=0.0,
        max=2.0,
        step=0.05,
    )
    assert spec.default == 0.5


def test_param_spec_enum_requires_choices():
    """Building an enum spec without ``choices`` must raise ValidationError."""
    with pytest.raises(ValidationError):
        ParamSpec(
            name="lang",
            label="Lang",
            type="enum",
            default="en",
        )


def test_param_spec_enum_default_must_be_in_choices():
    """An enum spec whose default is not among its choices must raise."""
    allowed = ["en", "fr"]
    with pytest.raises(ValidationError):
        ParamSpec(
            name="lang",
            label="Lang",
            type="enum",
            default="zz",  # deliberately absent from ``allowed``
            choices=allowed,
        )


def test_param_spec_float_default_within_bounds():
    """A float default outside the [min, max] range given must raise."""
    with pytest.raises(ValidationError):
        ParamSpec(
            name="x",
            label="X",
            type="float",
            default=99.0,  # deliberately above max=1.0
            min=0.0,
            max=1.0,
        )


def test_model_info_round_trip():
    """A fully populated ModelInfo keeps its ``id`` through ``model_dump()``."""
    english = Lang(code="en", label="English")
    cfg_weight = ParamSpec(
        name="cfg_weight",
        label="CFG",
        type="float",
        default=0.5,
        min=0.0,
        max=1.0,
    )
    model = ModelInfo(
        id="chatterbox-en",
        label="Chatterbox English",
        description="English voice cloning",
        languages=[english],
        paralinguistic_tags=[],
        supports_voice_clone=True,
        params=[cfg_weight],
    )
    as_dict = model.model_dump()
    assert as_dict["id"] == "chatterbox-en"


def test_active_model_status_idle():
    """An ActiveModelStatus with no model id and no error reports ``idle``."""
    status = ActiveModelStatus(id=None, status="idle", last_error=None)
    assert status.status == "idle"


def test_health_response_minimal():
    """HealthResponse accepts just device, torch_version and model_status."""
    health = HealthResponse(
        device="cpu",
        torch_version="2.4.1",
        model_status="idle",
    )
    assert health.device == "cpu"


def test_error_body_serializable():
    """ErrorBody keeps the error mapping it was given, addressable by key."""
    payload = {"code": "model_not_found", "message": "x", "detail": None}
    body = ErrorBody(error=payload)
    assert body.error["code"] == "model_not_found"


def test_generate_params_accepts_arbitrary_dict():
    """GenerateParams stores a free-form ``values`` mapping unchanged."""
    overrides = {"temperature": 0.8, "cfg_weight": 0.5}
    params = GenerateParams(values=overrides)
    assert params.values["temperature"] == 0.8


def test_param_spec_default_group_is_basic():
    """A spec created without ``group`` reports the group ``basic``."""
    spec = ParamSpec(
        name="t",
        label="T",
        type="float",
        default=0.5,
        min=0.0,
        max=1.0,
    )
    assert spec.group == "basic"


def test_param_spec_advanced_group_round_trips():
    """An explicit ``advanced`` group is kept on the model and in its dump."""
    spec = ParamSpec(
        name="seed",
        label="Seed",
        type="int",
        default=-1,
        min=-1,
        group="advanced",
    )
    assert spec.group == "advanced"
    assert spec.model_dump()["group"] == "advanced"