techfreakworm commited on
Commit
829be0a
·
unverified ·
1 Parent(s): e6b3389

feat(api,models): FastAPI app with /api/health, /api/models, activate, /api/generate + chatterbox-en adapter

Browse files
server/main.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application factory."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ import tempfile
7
+ from contextlib import asynccontextmanager
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.responses import JSONResponse, Response
14
+ from fastapi.staticfiles import StaticFiles
15
+ from sse_starlette.sse import EventSourceResponse
16
+
17
+ from server.audio import AudioValidationError, validate_reference_clip
18
+ from server.device import select_device
19
+ from server.registry import Registry
20
+ from server.zerogpu import decorate
21
+
22
+
23
+ STATIC_DIR = Path(__file__).parent / "static"
24
+
25
+
26
+ def _discover_adapter_classes() -> dict[str, type]:
27
+ """Lazily import adapter modules. Empty dict during early scaffolding."""
28
+ classes: dict[str, type] = {}
29
+ for module_name in ("chatterbox_en", "chatterbox_turbo", "chatterbox_mtl"):
30
+ try:
31
+ mod = __import__(f"server.models.{module_name}", fromlist=["Adapter"])
32
+ except ImportError:
33
+ continue
34
+ cls = getattr(mod, "Adapter", None)
35
+ if cls is not None:
36
+ classes[cls.id] = cls
37
+ return classes
38
+
39
+
40
+ def build_app() -> FastAPI:
41
+ @asynccontextmanager
42
+ async def lifespan(app: FastAPI):
43
+ device = select_device()
44
+ app.state.registry = Registry(
45
+ adapter_classes=_discover_adapter_classes(),
46
+ device=device,
47
+ )
48
+ yield
49
+
50
+ app = FastAPI(title="Chatterbox Voice Studio", lifespan=lifespan)
51
+
52
+ origins = os.getenv(
53
+ "CORS_ORIGINS",
54
+ "http://localhost:5173,http://127.0.0.1:5173",
55
+ ).split(",")
56
+ app.add_middleware(
57
+ CORSMiddleware,
58
+ allow_origins=origins,
59
+ allow_methods=["*"],
60
+ allow_headers=["*"],
61
+ )
62
+
63
+ @app.get("/api/health")
64
+ def health() -> dict:
65
+ registry = app.state.registry
66
+ return {
67
+ "device": registry.device,
68
+ "torch_version": torch.__version__,
69
+ "model_status": registry.status()["status"],
70
+ }
71
+
72
+ @app.get("/api/models")
73
+ def list_models() -> list[dict]:
74
+ return app.state.registry.list_models()
75
+
76
+ @app.get("/api/models/active")
77
+ def active_model() -> dict:
78
+ return app.state.registry.status()
79
+
80
+ @app.post("/api/models/{model_id}/activate")
81
+ async def activate_model(model_id: str):
82
+ try:
83
+ await app.state.registry.get_or_load(model_id)
84
+ except KeyError:
85
+ raise HTTPException(
86
+ status_code=404,
87
+ detail={"error": {"code": "model_not_found", "message": model_id}},
88
+ )
89
+ except Exception as exc:
90
+ return JSONResponse(
91
+ status_code=503,
92
+ content={"error": {"code": "model_load_failed", "message": str(exc)}},
93
+ )
94
+ return {"ok": True}
95
+
96
+ @app.get("/api/models/active/events")
97
+ async def active_events():
98
+ async def gen():
99
+ async for evt in app.state.registry.stream_events():
100
+ yield {"data": json.dumps(evt)}
101
+
102
+ return EventSourceResponse(gen())
103
+
104
+ @app.post("/api/generate")
105
+ async def generate(
106
+ text: str = Form(...),
107
+ model_id: str = Form(...),
108
+ params: str = Form("{}"),
109
+ language: str | None = Form(None),
110
+ reference_wav: UploadFile | None = File(None),
111
+ ):
112
+ try:
113
+ adapter = await app.state.registry.get_or_load(model_id)
114
+ except KeyError:
115
+ raise HTTPException(
116
+ status_code=404,
117
+ detail={"error": {"code": "model_not_found", "message": model_id}},
118
+ )
119
+
120
+ ref_path: str | None = None
121
+ if reference_wav is not None:
122
+ data = await reference_wav.read()
123
+ try:
124
+ validate_reference_clip(data)
125
+ except AudioValidationError as exc:
126
+ return JSONResponse(
127
+ status_code=400,
128
+ content={"error": {"code": "reference_invalid", "message": str(exc)}},
129
+ )
130
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
131
+ tmp.write(data)
132
+ tmp.flush()
133
+ tmp.close()
134
+ ref_path = tmp.name
135
+
136
+ gen_fn = decorate(adapter.generate)
137
+ try:
138
+ wav_bytes, _sr = gen_fn(text, ref_path, language, json.loads(params or "{}"))
139
+ except Exception as exc:
140
+ return JSONResponse(
141
+ status_code=500,
142
+ content={"error": {"code": "generation_failed", "message": str(exc)}},
143
+ )
144
+ return Response(content=wav_bytes, media_type="audio/wav")
145
+
146
+ @app.exception_handler(HTTPException)
147
+ async def _http_exc(request, exc: HTTPException):
148
+ if isinstance(exc.detail, dict) and "error" in exc.detail:
149
+ return JSONResponse(status_code=exc.status_code, content=exc.detail)
150
+ return JSONResponse(
151
+ status_code=exc.status_code,
152
+ content={"error": {"code": "http_error", "message": str(exc.detail)}},
153
+ )
154
+
155
+ if STATIC_DIR.exists():
156
+ app.mount("/", StaticFiles(directory=str(STATIC_DIR), html=True), name="static")
157
+
158
+ return app
159
+
160
+
161
+ app = build_app()
server/models/chatterbox_en.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chatterbox English adapter (ResembleAI/chatterbox)."""
2
+ from __future__ import annotations
3
+
4
+ import io
5
+ from typing import Any, ClassVar
6
+
7
+ import soundfile as sf
8
+
9
+ from server.schemas import Lang, ParamSpec
10
+
11
+
12
+ class Adapter:
13
+ id: ClassVar[str] = "chatterbox-en"
14
+ label: ClassVar[str] = "Chatterbox (English)"
15
+ description: ClassVar[str] = (
16
+ "Original Chatterbox English voice cloning with CFG and exaggeration controls."
17
+ )
18
+ languages: ClassVar[list[Lang]] = [Lang(code="en", label="English")]
19
+ paralinguistic_tags: ClassVar[list[str]] = []
20
+ supports_voice_clone: ClassVar[bool] = True
21
+ params: ClassVar[list[ParamSpec]] = [
22
+ ParamSpec(
23
+ name="exaggeration", label="Exaggeration", type="float",
24
+ default=0.5, min=0.0, max=2.0, step=0.05,
25
+ help="Higher = more expressive prosody.",
26
+ ),
27
+ ParamSpec(
28
+ name="cfg_weight", label="CFG weight", type="float",
29
+ default=0.5, min=0.0, max=1.0, step=0.05,
30
+ ),
31
+ ParamSpec(
32
+ name="temperature", label="Temperature", type="float",
33
+ default=0.8, min=0.1, max=1.5, step=0.05,
34
+ ),
35
+ ]
36
+
37
+ def __init__(self, device: str) -> None:
38
+ self.device = device
39
+ self._model = None
40
+
41
+ def load(self) -> None:
42
+ from chatterbox.tts import ChatterboxTTS
43
+
44
+ self._model = ChatterboxTTS.from_pretrained(device=self.device)
45
+
46
+ def unload(self) -> None:
47
+ self._model = None
48
+
49
+ def generate(
50
+ self,
51
+ text: str,
52
+ reference_wav_path: str | None,
53
+ language: str | None,
54
+ params: dict[str, Any],
55
+ ) -> tuple[bytes, int]:
56
+ if self._model is None:
57
+ raise RuntimeError("model not loaded")
58
+ wav = self._model.generate(
59
+ text,
60
+ audio_prompt_path=reference_wav_path,
61
+ exaggeration=float(params.get("exaggeration", 0.5)),
62
+ cfg_weight=float(params.get("cfg_weight", 0.5)),
63
+ temperature=float(params.get("temperature", 0.8)),
64
+ )
65
+ import numpy as np
66
+ import torch
67
+
68
+ if hasattr(wav, "detach"):
69
+ wav = wav.detach().cpu().numpy()
70
+ if isinstance(wav, torch.Tensor): # pragma: no cover
71
+ wav = wav.numpy()
72
+ arr = np.asarray(wav).squeeze()
73
+ sr = getattr(self._model, "sr", 24000)
74
+ buf = io.BytesIO()
75
+ sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
76
+ return buf.getvalue(), sr
tests/conftest.py CHANGED
@@ -1,7 +1,7 @@
1
  """Shared test fixtures."""
2
  from __future__ import annotations
3
 
4
- import asyncio
5
 
6
  import pytest
7
 
@@ -49,3 +49,10 @@ class FakeAdapterB(FakeAdapter):
49
  def fake_classes():
50
  FakeAdapter.instances.clear()
51
  return {FakeAdapter.id: FakeAdapter, FakeAdapterB.id: FakeAdapterB}
 
 
 
 
 
 
 
 
1
  """Shared test fixtures."""
2
  from __future__ import annotations
3
 
4
+ from contextlib import asynccontextmanager
5
 
6
  import pytest
7
 
 
49
  def fake_classes():
50
  FakeAdapter.instances.clear()
51
  return {FakeAdapter.id: FakeAdapter, FakeAdapterB.id: FakeAdapterB}
52
+
53
+
54
+ @asynccontextmanager
55
+ async def lifespan_ctx(app):
56
+ """Run an ASGI app's lifespan startup/shutdown around an `httpx.AsyncClient`."""
57
+ async with app.router.lifespan_context(app):
58
+ yield
tests/test_adapter_contract.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import pytest
4
+
5
+ from server.models.base import is_valid_adapter
6
+ from server.schemas import ParamSpec
7
+
8
+
9
+ ADAPTER_MODULES = [
10
+ "server.models.chatterbox_en",
11
+ ]
12
+
13
+
14
+ @pytest.mark.parametrize("module_name", ADAPTER_MODULES)
15
+ def test_adapter_class_attributes_valid(module_name):
16
+ mod = importlib.import_module(module_name)
17
+ cls = getattr(mod, "Adapter")
18
+ assert is_valid_adapter(cls)
19
+ assert cls.id
20
+ for p in cls.params:
21
+ assert isinstance(p, ParamSpec)
tests/test_main_activate.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import httpx
4
+ import pytest
5
+
6
+ from server.main import build_app
7
+
8
+
9
+ pytestmark = pytest.mark.asyncio
10
+
11
+
12
+ async def test_activate_then_status_loaded(monkeypatch, fake_classes):
13
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
14
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
15
+ app = build_app()
16
+ from tests.conftest import lifespan_ctx
17
+ transport = httpx.ASGITransport(app=app)
18
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
19
+ r = await c.post("/api/models/fake/activate")
20
+ assert r.status_code in (200, 202)
21
+ for _ in range(20):
22
+ s = (await c.get("/api/models/active")).json()
23
+ if s["status"] == "loaded":
24
+ break
25
+ await asyncio.sleep(0.05)
26
+ assert s["id"] == "fake"
27
+ assert s["status"] == "loaded"
28
+
29
+
30
+ async def test_activate_unknown_returns_404(monkeypatch, fake_classes):
31
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
32
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
33
+ app = build_app()
34
+ from tests.conftest import lifespan_ctx
35
+ transport = httpx.ASGITransport(app=app)
36
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
37
+ r = await c.post("/api/models/nope/activate")
38
+ assert r.status_code == 404
39
+ assert r.json()["error"]["code"] == "model_not_found"
40
+
41
+
42
+ # Note: integration test for /api/models/active/events SSE stream is omitted.
43
+ # Registry event emission is unit-tested in tests/test_registry.py
44
+ # (test_emits_loading_then_loaded_events). The /api/models/active/events
45
+ # endpoint is a thin sse-starlette wrapper around that generator.
tests/test_main_generate.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ import pytest
3
+
4
+ from server.main import build_app
5
+
6
+
7
+ pytestmark = pytest.mark.asyncio
8
+
9
+
10
+ async def test_generate_returns_wav_bytes(monkeypatch, fake_classes):
11
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
12
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
13
+ app = build_app()
14
+ from tests.conftest import lifespan_ctx
15
+ transport = httpx.ASGITransport(app=app)
16
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
17
+ r = await c.post(
18
+ "/api/generate",
19
+ data={
20
+ "text": "hello world",
21
+ "model_id": "fake",
22
+ "params": "{}",
23
+ },
24
+ )
25
+ assert r.status_code == 200
26
+ assert r.headers["content-type"].startswith("audio/wav")
27
+ assert r.content == b"FAKEWAV"
28
+
29
+
30
+ async def test_generate_unknown_model_404(monkeypatch, fake_classes):
31
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
32
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
33
+ app = build_app()
34
+ from tests.conftest import lifespan_ctx
35
+ transport = httpx.ASGITransport(app=app)
36
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
37
+ r = await c.post(
38
+ "/api/generate",
39
+ data={"text": "x", "model_id": "nope", "params": "{}"},
40
+ )
41
+ assert r.status_code == 404
42
+ assert r.json()["error"]["code"] == "model_not_found"
43
+
44
+
45
+ async def test_generate_invalid_reference_returns_400(monkeypatch, fake_classes):
46
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
47
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
48
+ app = build_app()
49
+ from tests.conftest import lifespan_ctx
50
+ transport = httpx.ASGITransport(app=app)
51
+ bad = b"not a wav"
52
+ async with lifespan_ctx(app), httpx.AsyncClient(transport=transport, base_url="http://t") as c:
53
+ r = await c.post(
54
+ "/api/generate",
55
+ data={"text": "x", "model_id": "fake", "params": "{}"},
56
+ files={"reference_wav": ("ref.wav", bad, "audio/wav")},
57
+ )
58
+ assert r.status_code == 400
59
+ assert r.json()["error"]["code"] == "reference_invalid"
tests/test_main_health.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+
3
+ from server.main import build_app
4
+
5
+
6
+ def test_health_returns_device_and_status(monkeypatch, fake_classes):
7
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
8
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
9
+ app = build_app()
10
+ with TestClient(app) as client:
11
+ r = client.get("/api/health")
12
+ assert r.status_code == 200
13
+ data = r.json()
14
+ assert data["device"] == "cpu"
15
+ assert data["model_status"] == "idle"
16
+ assert "torch_version" in data
tests/test_main_models.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+
3
+ from server.main import build_app
4
+
5
+
6
+ def test_models_list_returns_registered(monkeypatch, fake_classes):
7
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
8
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
9
+ app = build_app()
10
+ with TestClient(app) as client:
11
+ r = client.get("/api/models")
12
+ assert r.status_code == 200
13
+ items = r.json()
14
+ ids = sorted(m["id"] for m in items)
15
+ assert ids == ["fake", "fake-b"]
16
+ fake = next(m for m in items if m["id"] == "fake")
17
+ assert fake["paralinguistic_tags"] == ["[laugh]"]
18
+ assert fake["params"][0]["name"] == "t"
19
+
20
+
21
+ def test_active_model_initially_idle(monkeypatch, fake_classes):
22
+ monkeypatch.setattr("server.main._discover_adapter_classes", lambda: fake_classes)
23
+ monkeypatch.setattr("server.main.select_device", lambda: "cpu")
24
+ app = build_app()
25
+ with TestClient(app) as client:
26
+ r = client.get("/api/models/active")
27
+ assert r.status_code == 200
28
+ body = r.json()
29
+ assert body["id"] is None
30
+ assert body["status"] == "idle"