# chatterbox-voice-studio/tests/test_schemas.py
# Last commit by techfreakworm (a75ec91, unverified):
#   feat(schemas): add group field to ParamSpec (basic/advanced)
import pytest
from pydantic import ValidationError
from server.schemas import (
ActiveModelStatus,
ErrorBody,
GenerateParams,
HealthResponse,
Lang,
ModelInfo,
ParamSpec,
)
def test_param_spec_float_with_bounds():
    """A float ParamSpec given min/max/step reports the default it was built with."""
    fields = {
        "name": "exaggeration",
        "label": "Exaggeration",
        "type": "float",
        "default": 0.5,
        "min": 0.0,
        "max": 2.0,
        "step": 0.05,
    }
    spec = ParamSpec(**fields)
    assert spec.default == 0.5
def test_param_spec_enum_requires_choices():
    """Building an enum ParamSpec without ``choices`` must raise ValidationError."""
    fields = {"name": "lang", "label": "Lang", "type": "enum", "default": "en"}
    with pytest.raises(ValidationError):
        ParamSpec(**fields)
def test_param_spec_enum_default_must_be_in_choices():
    """An enum ParamSpec whose default is not one of its choices must raise."""
    fields = {
        "name": "lang",
        "label": "Lang",
        "type": "enum",
        "default": "zz",  # deliberately absent from the choices below
        "choices": ["en", "fr"],
    }
    with pytest.raises(ValidationError):
        ParamSpec(**fields)
def test_param_spec_float_default_within_bounds():
    """A float ParamSpec whose default lies above ``max`` must raise ValidationError."""
    fields = {
        "name": "x",
        "label": "X",
        "type": "float",
        "default": 99.0,  # outside the [0.0, 1.0] range given below
        "min": 0.0,
        "max": 1.0,
    }
    with pytest.raises(ValidationError):
        ParamSpec(**fields)
def test_model_info_round_trip():
    """A fully populated ModelInfo dumps to a mapping that still carries its id."""
    english = Lang(code="en", label="English")
    cfg_param = ParamSpec(
        name="cfg_weight",
        label="CFG",
        type="float",
        default=0.5,
        min=0.0,
        max=1.0,
    )
    model = ModelInfo(
        id="chatterbox-en",
        label="Chatterbox English",
        description="English voice cloning",
        languages=[english],
        paralinguistic_tags=[],
        supports_voice_clone=True,
        params=[cfg_param],
    )
    assert model.model_dump()["id"] == "chatterbox-en"
def test_active_model_status_idle():
    """ActiveModelStatus built with no id and no error reports the idle status."""
    fields = {"id": None, "status": "idle", "last_error": None}
    active = ActiveModelStatus(**fields)
    assert active.status == "idle"
def test_health_response_minimal():
    """HealthResponse built from device, torch_version and model_status keeps the device."""
    health = HealthResponse(
        device="cpu",
        torch_version="2.4.1",
        model_status="idle",
    )
    assert health.device == "cpu"
def test_error_body_serializable():
    """ErrorBody accepts an error mapping and exposes its ``code`` entry."""
    payload = {"code": "model_not_found", "message": "x", "detail": None}
    body = ErrorBody(error=payload)
    assert body.error["code"] == "model_not_found"
def test_generate_params_accepts_arbitrary_dict():
    """GenerateParams stores a free-form ``values`` mapping and returns entries by key."""
    raw_values = {"temperature": 0.8, "cfg_weight": 0.5}
    params = GenerateParams(values=raw_values)
    assert params.values["temperature"] == 0.8
def test_param_spec_default_group_is_basic():
    """A ParamSpec created without an explicit ``group`` reports "basic"."""
    spec = ParamSpec(
        name="t",
        label="T",
        type="float",
        default=0.5,
        min=0.0,
        max=1.0,
    )
    assert spec.group == "basic"
def test_param_spec_advanced_group_round_trips():
    """An explicit "advanced" group is kept on the model and in its dumped form."""
    fields = {
        "name": "seed",
        "label": "Seed",
        "type": "int",
        "default": -1,
        "min": -1,
        "group": "advanced",
    }
    spec = ParamSpec(**fields)
    assert spec.group == "advanced"
    serialized = spec.model_dump()
    assert serialized["group"] == "advanced"