File size: 2,576 Bytes
88f8bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a75ec91
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytest
from pydantic import ValidationError

from server.schemas import (
    ActiveModelStatus,
    ErrorBody,
    GenerateParams,
    HealthResponse,
    Lang,
    ModelInfo,
    ParamSpec,
)


def test_param_spec_float_with_bounds():
    """A float ParamSpec with min/max/step and an in-range default is accepted."""
    bounded = dict(
        name="exaggeration",
        label="Exaggeration",
        type="float",
        default=0.5,
        min=0.0,
        max=2.0,
        step=0.05,
    )
    spec = ParamSpec(**bounded)
    assert spec.default == 0.5


def test_param_spec_enum_requires_choices():
    """An enum ParamSpec declared without ``choices`` must fail validation."""
    missing_choices = {"name": "lang", "label": "Lang", "type": "enum", "default": "en"}
    with pytest.raises(ValidationError):
        ParamSpec(**missing_choices)


def test_param_spec_enum_default_must_be_in_choices():
    """An enum default that is not one of ``choices`` is rejected."""
    allowed = ["en", "fr"]
    with pytest.raises(ValidationError):
        ParamSpec(name="lang", label="Lang", type="enum", default="zz", choices=allowed)


def test_param_spec_float_default_within_bounds():
    """A float default outside ``[min, max]`` is rejected."""
    with pytest.raises(ValidationError):
        ParamSpec(
            name="x",
            label="X",
            type="float",
            default=99.0,  # well above max=1.0
            min=0.0,
            max=1.0,
        )


def test_model_info_round_trip():
    """ModelInfo survives a dump -> validate -> dump cycle unchanged.

    The dumped dict is fed back through ``ModelInfo.model_validate`` (re-running
    all field and nested-model validation), and the re-dumped result must equal
    the first dump. Dumps are compared instead of model instances so the check
    does not depend on how pydantic defines model equality.
    """
    info = ModelInfo(
        id="chatterbox-en",
        label="Chatterbox English",
        description="English voice cloning",
        languages=[Lang(code="en", label="English")],
        paralinguistic_tags=[],
        supports_voice_clone=True,
        params=[
            ParamSpec(name="cfg_weight", label="CFG", type="float", default=0.5, min=0.0, max=1.0)
        ],
    )
    dumped = info.model_dump()
    assert dumped["id"] == "chatterbox-en"
    # Nested models must come out of model_dump as plain dicts carrying their fields.
    assert dumped["languages"][0]["code"] == "en"
    assert dumped["params"][0]["name"] == "cfg_weight"

    restored = ModelInfo.model_validate(dumped)
    assert restored.model_dump() == dumped


def test_active_model_status_idle():
    """An idle ActiveModelStatus with no model id and no error constructs cleanly."""
    status = ActiveModelStatus(status="idle", id=None, last_error=None)
    assert status.status == "idle"


def test_health_response_minimal():
    """HealthResponse builds from only device, torch_version and model_status."""
    fields = dict(device="cpu", torch_version="2.4.1", model_status="idle")
    health = HealthResponse(**fields)
    assert health.device == fields["device"]


def test_error_body_serializable():
    """ErrorBody serializes to JSON and parses back without losing its payload.

    Checks attribute access on the constructed model, then performs a real JSON
    round trip (``model_dump_json`` -> ``model_validate_json``). The round trip
    covers the ``detail=None`` entry, which must come back as ``None``.
    """
    payload = {"code": "model_not_found", "message": "x", "detail": None}
    e = ErrorBody(error=payload)
    assert e.error["code"] == "model_not_found"

    restored = ErrorBody.model_validate_json(e.model_dump_json())
    assert restored.error == payload


def test_generate_params_accepts_arbitrary_dict():
    """GenerateParams keeps a free-form name -> value mapping in ``values``."""
    raw = {"temperature": 0.8, "cfg_weight": 0.5}
    params = GenerateParams(values=raw)
    assert params.values["temperature"] == raw["temperature"]


def test_param_spec_default_group_is_basic():
    """A ParamSpec built without ``group`` reports the "basic" group."""
    spec = ParamSpec(
        name="t", label="T", type="float", default=0.5, min=0.0, max=1.0
    )
    assert spec.group == "basic"


def test_param_spec_advanced_group_round_trips():
    """An explicit group="advanced" is kept on the model and in its dump."""
    expected_group = "advanced"
    spec = ParamSpec(
        name="seed",
        label="Seed",
        type="int",
        default=-1,
        min=-1,
        group=expected_group,
    )
    assert spec.group == expected_group
    assert spec.model_dump()["group"] == expected_group