"""FastAPI application factory.""" from __future__ import annotations import asyncio import json import os import tempfile from contextlib import asynccontextmanager from pathlib import Path import torch from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from fastapi.staticfiles import StaticFiles from sse_starlette.sse import EventSourceResponse from server.audio import AudioValidationError, validate_reference_clip from server.device import select_device from server.dialog import ( DialogParseError, DialogReferenceError, generate_dialog, parse_dialog, ) from server.progress import get_bus from server.registry import Registry STATIC_DIR = Path(__file__).parent / "static" def _discover_adapter_classes() -> dict[str, type]: """Lazily import adapter modules. Empty dict during early scaffolding.""" classes: dict[str, type] = {} for module_name in ("chatterbox_en", "chatterbox_turbo", "chatterbox_mtl"): try: mod = __import__(f"server.models.{module_name}", fromlist=["Adapter"]) except ImportError: continue cls = getattr(mod, "Adapter", None) if cls is not None: classes[cls.id] = cls return classes def build_app() -> FastAPI: @asynccontextmanager async def lifespan(app: FastAPI): device = select_device() app.state.registry = Registry( adapter_classes=_discover_adapter_classes(), device=device, ) yield app = FastAPI(title="Chatterbox Voice Studio", lifespan=lifespan) origins = os.getenv( "CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173", ).split(",") app.add_middleware( CORSMiddleware, allow_origins=origins, allow_methods=["*"], allow_headers=["*"], ) @app.get("/api/health") def health() -> dict: registry = app.state.registry return { "device": registry.device, "torch_version": torch.__version__, "model_status": registry.status()["status"], } @app.get("/api/models") def list_models() -> list[dict]: return app.state.registry.list_models() @app.get("/api/models/active") def active_model() -> dict: return app.state.registry.status() @app.post("/api/models/{model_id}/activate") async def activate_model(model_id: str): try: await app.state.registry.get_or_load(model_id) except KeyError: raise HTTPException( status_code=404, detail={"error": {"code": "model_not_found", "message": model_id}}, ) except Exception as exc: return JSONResponse( status_code=503, content={"error": {"code": "model_load_failed", "message": str(exc)}}, ) return {"ok": True} @app.get("/api/models/active/events") async def active_events(): async def gen(): async for evt in app.state.registry.stream_events(): yield {"data": json.dumps(evt)} return EventSourceResponse(gen()) @app.get("/api/progress") async def progress_events(): bus = get_bus() async def gen(): async with bus.subscribe() as q: while True: evt = await q.get() yield {"data": json.dumps(evt.to_dict())} return EventSourceResponse(gen()) @app.post("/api/generate") async def generate( text: str = Form(...), model_id: str = Form(...), params: str = Form("{}"), language: str | None = Form(None), reference_wav: UploadFile | None = File(None), ): try: adapter = await app.state.registry.get_or_load(model_id) except KeyError: raise HTTPException( status_code=404, detail={"error": {"code": "model_not_found", "message": model_id}}, ) ref_path: str | None = None if reference_wav is not None: data = await reference_wav.read() try: validate_reference_clip(data) except AudioValidationError as exc: detail = { "size_bytes": len(data), "first_4": data[:4].decode("latin-1", errors="replace"), "filename": reference_wav.filename, "content_type": reference_wav.content_type, } print( f"[reference_invalid] {exc} | {detail}", flush=True, ) return JSONResponse( status_code=400, content={ "error": { "code": "reference_invalid", "message": str(exc), "detail": detail, } }, ) tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") tmp.write(data) tmp.flush() tmp.close() ref_path = tmp.name bus = get_bus() try: async with bus.session("single", total_turns=1) as sess: wav_bytes, _sr, seed_used = await asyncio.to_thread( adapter.generate, text, ref_path, language, json.loads(params or "{}"), ) sess.set_seed(seed_used) except Exception as exc: return JSONResponse( status_code=500, content={"error": {"code": "generation_failed", "message": str(exc)}}, ) return Response( content=wav_bytes, media_type="audio/wav", headers={"X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used"}, ) @app.post("/api/generate/dialog") async def generate_dialog_route( text: str = Form(...), engine_id: str = Form(...), params: str = Form("{}"), language: str | None = Form(None), reference_wav_a: UploadFile | None = File(None), reference_wav_b: UploadFile | None = File(None), reference_wav_c: UploadFile | None = File(None), reference_wav_d: UploadFile | None = File(None), ): speaker_clips: dict[str, bytes] = {} upload_map = { "A": reference_wav_a, "B": reference_wav_b, "C": reference_wav_c, "D": reference_wav_d, } for letter, upload in upload_map.items(): if upload is None: continue data = await upload.read() try: validate_reference_clip(data) except AudioValidationError as exc: return JSONResponse( status_code=400, content={ "error": { "code": "reference_invalid", "message": f"speaker {letter}: {exc}", } }, ) speaker_clips[letter] = data bus = get_bus() try: turns_preview = parse_dialog(text) total_turns = len(turns_preview) except DialogParseError as exc: return JSONResponse( status_code=400, content={ "error": {"code": "dialog_format_invalid", "message": str(exc)} }, ) try: async with bus.session("dialog", total_turns=total_turns) as sess: wav_bytes, _sr, seed_used = await generate_dialog( registry=app.state.registry, engine_id=engine_id, text=text, language=language, params=json.loads(params or "{}"), speaker_clips=speaker_clips, session=sess, ) sess.set_seed(seed_used) except KeyError: raise HTTPException( status_code=404, detail={"error": {"code": "model_not_found", "message": engine_id}}, ) except DialogReferenceError as exc: return JSONResponse( status_code=400, content={ "error": {"code": "dialog_missing_reference", "message": str(exc)} }, ) except Exception as exc: return JSONResponse( status_code=500, content={ "error": {"code": "generation_failed", "message": str(exc)} }, ) return Response( content=wav_bytes, media_type="audio/wav", headers={ "X-Seed-Used": str(seed_used), "Access-Control-Expose-Headers": "X-Seed-Used", }, ) @app.exception_handler(HTTPException) async def _http_exc(request, exc: HTTPException): if isinstance(exc.detail, dict) and "error" in exc.detail: return JSONResponse(status_code=exc.status_code, content=exc.detail) return JSONResponse( status_code=exc.status_code, content={"error": {"code": "http_error", "message": str(exc.detail)}}, ) if STATIC_DIR.exists(): app.mount("/", StaticFiles(directory=str(STATIC_DIR), html=True), name="static") return app app = build_app()