techfreakworm commited on
Commit
88f8bd9
·
unverified ·
1 Parent(s): ca78147

feat(schemas): pydantic models for ParamSpec/ModelInfo/Health/Errors

Browse files
Files changed (2) hide show
  1. server/schemas.py +75 -0
  2. tests/test_schemas.py +82 -0
server/schemas.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for the public API."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Literal
5
+
6
+ from pydantic import BaseModel, Field, model_validator
7
+
8
+
9
+ ParamType = Literal["float", "int", "bool", "enum"]
10
+ ModelStatus = Literal["idle", "loading", "loaded", "error"]
11
+
12
+
13
+ class Lang(BaseModel):
14
+ code: str
15
+ label: str
16
+
17
+
18
+ class ParamSpec(BaseModel):
19
+ name: str
20
+ label: str
21
+ type: ParamType
22
+ default: float | int | bool | str
23
+ min: float | int | None = None
24
+ max: float | int | None = None
25
+ step: float | int | None = None
26
+ choices: list[str] | None = None
27
+ help: str = ""
28
+
29
+ @model_validator(mode="after")
30
+ def _validate(self) -> "ParamSpec":
31
+ if self.type == "enum":
32
+ if not self.choices:
33
+ raise ValueError("enum params must define `choices`")
34
+ if self.default not in self.choices:
35
+ raise ValueError("enum default must appear in `choices`")
36
+ if self.type in {"float", "int"}:
37
+ if self.min is not None and isinstance(self.default, (int, float)) and self.default < self.min:
38
+ raise ValueError("default below min")
39
+ if self.max is not None and isinstance(self.default, (int, float)) and self.default > self.max:
40
+ raise ValueError("default above max")
41
+ return self
42
+
43
+
44
+ class ModelInfo(BaseModel):
45
+ id: str
46
+ label: str
47
+ description: str
48
+ languages: list[Lang]
49
+ paralinguistic_tags: list[str]
50
+ supports_voice_clone: bool
51
+ params: list[ParamSpec]
52
+
53
+
54
+ class ActiveModelStatus(BaseModel):
55
+ id: str | None
56
+ status: ModelStatus
57
+ last_error: str | None = None
58
+
59
+
60
+ class HealthResponse(BaseModel):
61
+ device: str
62
+ torch_version: str
63
+ model_status: ModelStatus
64
+
65
+
66
+ class ErrorBody(BaseModel):
67
+ error: dict[str, Any] = Field(
68
+ ...,
69
+ description="{code, message, detail?}",
70
+ )
71
+
72
+
73
+ class GenerateParams(BaseModel):
74
+ """Free-form param bag — adapter-specific."""
75
+ values: dict[str, Any] = {}
tests/test_schemas.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from pydantic import ValidationError
3
+
4
+ from server.schemas import (
5
+ ActiveModelStatus,
6
+ ErrorBody,
7
+ GenerateParams,
8
+ HealthResponse,
9
+ Lang,
10
+ ModelInfo,
11
+ ParamSpec,
12
+ )
13
+
14
+
15
+ def test_param_spec_float_with_bounds():
16
+ p = ParamSpec(
17
+ name="exaggeration",
18
+ label="Exaggeration",
19
+ type="float",
20
+ default=0.5,
21
+ min=0.0,
22
+ max=2.0,
23
+ step=0.05,
24
+ )
25
+ assert p.default == 0.5
26
+
27
+
28
+ def test_param_spec_enum_requires_choices():
29
+ with pytest.raises(ValidationError):
30
+ ParamSpec(name="lang", label="Lang", type="enum", default="en")
31
+
32
+
33
+ def test_param_spec_enum_default_must_be_in_choices():
34
+ with pytest.raises(ValidationError):
35
+ ParamSpec(
36
+ name="lang",
37
+ label="Lang",
38
+ type="enum",
39
+ default="zz",
40
+ choices=["en", "fr"],
41
+ )
42
+
43
+
44
+ def test_param_spec_float_default_within_bounds():
45
+ with pytest.raises(ValidationError):
46
+ ParamSpec(name="x", label="X", type="float", default=99.0, min=0.0, max=1.0)
47
+
48
+
49
+ def test_model_info_round_trip():
50
+ info = ModelInfo(
51
+ id="chatterbox-en",
52
+ label="Chatterbox English",
53
+ description="English voice cloning",
54
+ languages=[Lang(code="en", label="English")],
55
+ paralinguistic_tags=[],
56
+ supports_voice_clone=True,
57
+ params=[
58
+ ParamSpec(name="cfg_weight", label="CFG", type="float", default=0.5, min=0.0, max=1.0)
59
+ ],
60
+ )
61
+ dumped = info.model_dump()
62
+ assert dumped["id"] == "chatterbox-en"
63
+
64
+
65
+ def test_active_model_status_idle():
66
+ s = ActiveModelStatus(id=None, status="idle", last_error=None)
67
+ assert s.status == "idle"
68
+
69
+
70
+ def test_health_response_minimal():
71
+ h = HealthResponse(device="cpu", torch_version="2.4.1", model_status="idle")
72
+ assert h.device == "cpu"
73
+
74
+
75
+ def test_error_body_serializable():
76
+ e = ErrorBody(error={"code": "model_not_found", "message": "x", "detail": None})
77
+ assert e.error["code"] == "model_not_found"
78
+
79
+
80
+ def test_generate_params_accepts_arbitrary_dict():
81
+ g = GenerateParams(values={"temperature": 0.8, "cfg_weight": 0.5})
82
+ assert g.values["temperature"] == 0.8