# techfreakworm — commit 1d7b5cd (unverified):
# "refactor: remove dead ZeroGPU shim (Docker Spaces don't support ZeroGPU)"
"""FastAPI application factory."""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from sse_starlette.sse import EventSourceResponse
from server.audio import AudioValidationError, validate_reference_clip
from server.device import select_device
from server.dialog import (
DialogParseError,
DialogReferenceError,
generate_dialog,
parse_dialog,
)
from server.progress import get_bus
from server.registry import Registry
# Directory of static files served at "/"; build_app() mounts it only if it exists.
STATIC_DIR = Path(__file__).parent.joinpath("static")
def _discover_adapter_classes() -> dict[str, type]:
"""Lazily import adapter modules. Empty dict during early scaffolding."""
classes: dict[str, type] = {}
for module_name in ("chatterbox_en", "chatterbox_turbo", "chatterbox_mtl"):
try:
mod = __import__(f"server.models.{module_name}", fromlist=["Adapter"])
except ImportError:
continue
cls = getattr(mod, "Adapter", None)
if cls is not None:
classes[cls.id] = cls
return classes
def _cors_origins() -> list[str]:
    """Return the allowed CORS origins from ``CORS_ORIGINS`` (comma-separated).

    The default is the local dev-server origins. Whitespace around entries is
    stripped and empty entries are dropped. For example, ``"a, b,"`` yields
    ``["a", "b"]`` rather than origins that could never match.
    """
    raw = os.getenv(
        "CORS_ORIGINS",
        "http://localhost:5173,http://127.0.0.1:5173",
    )
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


def _parse_params(raw: str | None) -> dict:
    """Decode the ``params`` form field. An empty value means no overrides.

    Raises:
        ValueError: if *raw* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or does not decode to a JSON object.
    """
    decoded = json.loads(raw or "{}")
    if not isinstance(decoded, dict):
        raise ValueError("params must be a JSON object")
    return decoded


def _error_response(
    status_code: int, code: str, message: str, detail: dict | None = None
) -> JSONResponse:
    """Build the ``{"error": {"code", "message"[, "detail"]}}`` body used by every route."""
    error: dict = {"code": code, "message": message}
    if detail is not None:
        error["detail"] = detail
    return JSONResponse(status_code=status_code, content={"error": error})


def _model_not_found(model_id: str) -> HTTPException:
    """404 exception whose ``detail`` is already in the API error shape."""
    return HTTPException(
        status_code=404,
        detail={"error": {"code": "model_not_found", "message": model_id}},
    )


def _wav_response(wav_bytes: bytes, seed_used) -> Response:
    """Return WAV audio with the seed in ``X-Seed-Used``.

    The header is listed in ``Access-Control-Expose-Headers`` so that
    cross-origin browser clients are allowed to read it.
    """
    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={
            "X-Seed-Used": str(seed_used),
            "Access-Control-Expose-Headers": "X-Seed-Used",
        },
    )


def _remove_quietly(path: str) -> None:
    """Best-effort delete of a temp file. Cleanup must never replace the response."""
    try:
        os.unlink(path)
    except OSError as exc:
        print(f"[tempfile_cleanup_failed] {path}: {exc}", flush=True)


def build_app() -> FastAPI:
    """Create the FastAPI application with all API routes and the static UI.

    The model registry is created in the lifespan hook, at startup rather
    than at import time. It is stored on ``app.state.registry``, and every
    route reads it from there. Error responses share the shape
    ``{"error": {"code": ..., "message": ...}}``.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Device selection and adapter discovery run once, at application startup.
        device = select_device()
        app.state.registry = Registry(
            adapter_classes=_discover_adapter_classes(),
            device=device,
        )
        yield

    app = FastAPI(title="Chatterbox Voice Studio", lifespan=lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_cors_origins(),
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/api/health")
    def health() -> dict:
        """Report the selected device, the torch version and the registry's model status."""
        registry = app.state.registry
        return {
            "device": registry.device,
            "torch_version": torch.__version__,
            "model_status": registry.status()["status"],
        }

    @app.get("/api/models")
    def list_models() -> list[dict]:
        """List the models known to the registry."""
        return app.state.registry.list_models()

    @app.get("/api/models/active")
    def active_model() -> dict:
        """Return the registry's current status payload."""
        return app.state.registry.status()

    @app.post("/api/models/{model_id}/activate")
    async def activate_model(model_id: str):
        """Load *model_id* through the registry.

        Returns 404 for an unknown id and 503 when loading fails.
        """
        try:
            await app.state.registry.get_or_load(model_id)
        except KeyError:
            raise _model_not_found(model_id) from None
        except Exception as exc:
            return _error_response(503, "model_load_failed", str(exc))
        return {"ok": True}

    @app.get("/api/models/active/events")
    async def active_events():
        """Stream registry events to the client as server-sent events (JSON payloads)."""

        async def gen():
            async for evt in app.state.registry.stream_events():
                yield {"data": json.dumps(evt)}

        return EventSourceResponse(gen())

    @app.get("/api/progress")
    async def progress_events():
        """Stream progress-bus events as SSE until the client disconnects."""
        bus = get_bus()

        async def gen():
            async with bus.subscribe() as q:
                while True:
                    evt = await q.get()
                    yield {"data": json.dumps(evt.to_dict())}

        return EventSourceResponse(gen())

    @app.post("/api/generate")
    async def generate(
        text: str = Form(...),
        model_id: str = Form(...),
        params: str = Form("{}"),
        language: str | None = Form(None),
        reference_wav: UploadFile | None = File(None),
    ):
        """Synthesize *text* with *model_id* and return the audio as WAV bytes.

        An optional reference clip is validated and then written to a
        temporary ``.wav`` file that the adapter reads by path. That file is
        removed once generation finishes. The seed the adapter reports is
        returned in the ``X-Seed-Used`` header.
        """
        try:
            adapter = await app.state.registry.get_or_load(model_id)
        except KeyError:
            raise _model_not_found(model_id) from None
        except Exception as exc:
            # Same mapping as /activate: a load failure is a 503, not an unformatted 500.
            return _error_response(503, "model_load_failed", str(exc))

        try:
            gen_params = _parse_params(params)
        except ValueError as exc:
            # Malformed client input is a 400, not a generation failure.
            return _error_response(400, "params_invalid", str(exc))

        ref_path: str | None = None
        if reference_wav is not None:
            data = await reference_wav.read()
            try:
                validate_reference_clip(data)
            except AudioValidationError as exc:
                detail = {
                    "size_bytes": len(data),
                    "first_4": data[:4].decode("latin-1", errors="replace"),
                    "filename": reference_wav.filename,
                    "content_type": reference_wav.content_type,
                }
                print(
                    f"[reference_invalid] {exc} | {detail}",
                    flush=True,
                )
                return _error_response(400, "reference_invalid", str(exc), detail)
            # delete=False keeps the file after it is closed, so the adapter can
            # open it by path. It is removed in the `finally` below.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(data)
            ref_path = tmp.name

        bus = get_bus()
        try:
            async with bus.session("single", total_turns=1) as sess:
                # adapter.generate is synchronous; run it off the event loop.
                wav_bytes, _sr, seed_used = await asyncio.to_thread(
                    adapter.generate, text, ref_path, language, gen_params,
                )
                sess.set_seed(seed_used)
        except Exception as exc:
            return _error_response(500, "generation_failed", str(exc))
        finally:
            if ref_path is not None:
                _remove_quietly(ref_path)
        return _wav_response(wav_bytes, seed_used)

    @app.post("/api/generate/dialog")
    async def generate_dialog_route(
        text: str = Form(...),
        engine_id: str = Form(...),
        params: str = Form("{}"),
        language: str | None = Form(None),
        reference_wav_a: UploadFile | None = File(None),
        reference_wav_b: UploadFile | None = File(None),
        reference_wav_c: UploadFile | None = File(None),
        reference_wav_d: UploadFile | None = File(None),
    ):
        """Render a multi-speaker dialog script with *engine_id* as a single WAV.

        Up to four reference clips may be uploaded, for speakers A-D. Every
        clip, the dialog script and *params* are validated before generation
        starts. The progress session is sized by the number of parsed turns.
        """
        speaker_clips: dict[str, bytes] = {}
        upload_map = {
            "A": reference_wav_a,
            "B": reference_wav_b,
            "C": reference_wav_c,
            "D": reference_wav_d,
        }
        for letter, upload in upload_map.items():
            if upload is None:
                continue
            data = await upload.read()
            try:
                validate_reference_clip(data)
            except AudioValidationError as exc:
                return _error_response(
                    400, "reference_invalid", f"speaker {letter}: {exc}"
                )
            speaker_clips[letter] = data

        try:
            # Parsed here only to size the progress session and to reject bad scripts early.
            total_turns = len(parse_dialog(text))
        except DialogParseError as exc:
            return _error_response(400, "dialog_format_invalid", str(exc))

        try:
            gen_params = _parse_params(params)
        except ValueError as exc:
            return _error_response(400, "params_invalid", str(exc))

        bus = get_bus()
        try:
            async with bus.session("dialog", total_turns=total_turns) as sess:
                wav_bytes, _sr, seed_used = await generate_dialog(
                    registry=app.state.registry,
                    engine_id=engine_id,
                    text=text,
                    language=language,
                    params=gen_params,
                    speaker_clips=speaker_clips,
                    session=sess,
                )
                sess.set_seed(seed_used)
        except KeyError:
            # NOTE(review): presumably raised by the registry for an unknown
            # engine_id. Any other KeyError escaping generate_dialog is also
            # reported as model_not_found — confirm that is intended.
            raise _model_not_found(engine_id) from None
        except DialogReferenceError as exc:
            return _error_response(400, "dialog_missing_reference", str(exc))
        except Exception as exc:
            return _error_response(500, "generation_failed", str(exc))
        return _wav_response(wav_bytes, seed_used)

    @app.exception_handler(HTTPException)
    async def _http_exc(request, exc: HTTPException):
        """Render HTTPExceptions in the API's ``{"error": {...}}`` shape.

        A ``detail`` that is already in that shape is passed through
        unchanged. Any other detail is wrapped under the ``http_error`` code.
        """
        if isinstance(exc.detail, dict) and "error" in exc.detail:
            return JSONResponse(status_code=exc.status_code, content=exc.detail)
        return _error_response(exc.status_code, "http_error", str(exc.detail))

    # Mounted last so the /api routes registered above are matched first.
    if STATIC_DIR.exists():
        app.mount("/", StaticFiles(directory=str(STATIC_DIR), html=True), name="static")
    return app
# Module-level application instance, built at import time. Presumably this is
# the target an ASGI server loads (e.g. `uvicorn <module>:app`) — confirm.
# Device selection and registry creation are deferred to the lifespan hook.
app = build_app()