Spaces:
Sleeping
Sleeping
| """FastAPI application factory.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import tempfile | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import torch | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, Response | |
| from fastapi.staticfiles import StaticFiles | |
| from sse_starlette.sse import EventSourceResponse | |
| from server.audio import AudioValidationError, validate_reference_clip | |
| from server.device import select_device | |
| from server.dialog import ( | |
| DialogParseError, | |
| DialogReferenceError, | |
| generate_dialog, | |
| parse_dialog, | |
| ) | |
| from server.progress import get_bus | |
| from server.registry import Registry | |
# Directory named "static" sitting beside this module; build_app() mounts it
# at "/" when it exists on disk.
STATIC_DIR = Path(__file__).parent.joinpath("static")
| def _discover_adapter_classes() -> dict[str, type]: | |
| """Lazily import adapter modules. Empty dict during early scaffolding.""" | |
| classes: dict[str, type] = {} | |
| for module_name in ("chatterbox_en", "chatterbox_turbo", "chatterbox_mtl"): | |
| try: | |
| mod = __import__(f"server.models.{module_name}", fromlist=["Adapter"]) | |
| except ImportError: | |
| continue | |
| cls = getattr(mod, "Adapter", None) | |
| if cls is not None: | |
| classes[cls.id] = cls | |
| return classes | |
def build_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Sets up, in order: a lifespan hook that attaches a ``Registry`` to
    ``app.state``, CORS middleware driven by the ``CORS_ORIGINS`` environment
    variable (comma-separated), the request handlers defined below, a handler
    that renders ``HTTPException`` in the ``{"error": {...}}`` envelope, and a
    static-file mount at ``/`` when ``STATIC_DIR`` exists.

    NOTE(review): no route registration (``@app.get`` / ``@app.post`` /
    ``app.add_api_route``) for the handlers below is visible in this file, so
    as written they are not reachable over HTTP. The decorator lines were
    presumably lost; restore them with the paths the client expects.

    Returns:
        The configured ``FastAPI`` instance.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Attach the model registry to ``app.state`` before serving starts."""
        device = select_device()
        app.state.registry = Registry(
            adapter_classes=_discover_adapter_classes(),
            device=device,
        )
        yield

    app = FastAPI(title="Chatterbox Voice Studio", lifespan=lifespan)

    raw_origins = os.getenv(
        "CORS_ORIGINS",
        "http://localhost:5173,http://127.0.0.1:5173",
    )
    # Tolerate "a, b" style values and stray commas: an origin with
    # surrounding whitespace or an empty entry can never match a request.
    origins = [entry.strip() for entry in raw_origins.split(",") if entry.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    def _error(status_code: int, code: str, message: str, **extra) -> JSONResponse:
        """Build the ``{"error": {"code", "message", ...}}`` JSON response."""
        return JSONResponse(
            status_code=status_code,
            content={"error": {"code": code, "message": message, **extra}},
        )

    def _model_not_found(model_id: str) -> HTTPException:
        """Build the 404 raised when the registry does not know ``model_id``."""
        return HTTPException(
            status_code=404,
            detail={"error": {"code": "model_not_found", "message": model_id}},
        )

    def health() -> dict:
        """Report the registry device, torch version and model status."""
        registry = app.state.registry
        return {
            "device": registry.device,
            "torch_version": torch.__version__,
            "model_status": registry.status()["status"],
        }

    def list_models() -> list[dict]:
        """Return the registry's model listing."""
        return app.state.registry.list_models()

    def active_model() -> dict:
        """Return the registry's current status."""
        return app.state.registry.status()

    async def activate_model(model_id: str):
        """Load ``model_id`` through the registry.

        Returns ``{"ok": True}`` on success, raises a 404 ``HTTPException``
        when the registry raises ``KeyError``, and returns a 503
        ``model_load_failed`` response for any other failure.
        """
        try:
            await app.state.registry.get_or_load(model_id)
        except KeyError:
            raise _model_not_found(model_id) from None
        except Exception as exc:
            return _error(503, "model_load_failed", str(exc))
        return {"ok": True}

    async def active_events():
        """Stream registry events as server-sent events (JSON payloads)."""

        async def gen():
            async for evt in app.state.registry.stream_events():
                yield {"data": json.dumps(evt)}

        return EventSourceResponse(gen())

    async def progress_events():
        """Stream progress-bus events as server-sent events (JSON payloads)."""
        bus = get_bus()

        async def gen():
            async with bus.subscribe() as q:
                # Runs until the response is torn down; leaving the
                # ``async with`` ends the subscription.
                while True:
                    evt = await q.get()
                    yield {"data": json.dumps(evt.to_dict())}

        return EventSourceResponse(gen())

    async def generate(
        text: str = Form(...),
        model_id: str = Form(...),
        params: str = Form("{}"),
        language: str | None = Form(None),
        reference_wav: UploadFile | None = File(None),
    ):
        """Synthesize ``text`` with one model and return WAV bytes.

        Responses:
            200 ``audio/wav`` with an ``X-Seed-Used`` header on success.
            400 ``reference_invalid`` when the uploaded clip fails validation.
            400 ``params_invalid`` when ``params`` is not valid JSON.
            404 ``model_not_found`` when the registry raises ``KeyError``.
            503 ``model_load_failed`` when loading the model fails otherwise.
            500 ``generation_failed`` when generation itself raises.
        """
        try:
            adapter = await app.state.registry.get_or_load(model_id)
        except KeyError:
            raise _model_not_found(model_id) from None
        except Exception as exc:
            # Same mapping as activate_model, instead of an unhandled error.
            return _error(503, "model_load_failed", str(exc))

        clip: bytes | None = None
        if reference_wav is not None:
            clip = await reference_wav.read()
            try:
                validate_reference_clip(clip)
            except AudioValidationError as exc:
                detail = {
                    "size_bytes": len(clip),
                    "first_4": clip[:4].decode("latin-1", errors="replace"),
                    "filename": reference_wav.filename,
                    "content_type": reference_wav.content_type,
                }
                print(
                    f"[reference_invalid] {exc} | {detail}",
                    flush=True,
                )
                return _error(400, "reference_invalid", str(exc), detail=detail)

        # Malformed params are a client error; parse them before any work so
        # they are not reported as a 500 generation failure.
        try:
            gen_params = json.loads(params or "{}")
        except ValueError as exc:
            return _error(400, "params_invalid", f"params is not valid JSON: {exc}")

        ref_path: str | None = None
        try:
            if clip is not None:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    # Record the name first so cleanup runs even if the
                    # write fails.
                    ref_path = tmp.name
                    tmp.write(clip)

            bus = get_bus()
            try:
                async with bus.session("single", total_turns=1) as sess:
                    wav_bytes, _sr, seed_used = await asyncio.to_thread(
                        adapter.generate, text, ref_path, language, gen_params,
                    )
                    sess.set_seed(seed_used)
            except Exception as exc:
                return _error(500, "generation_failed", str(exc))
        finally:
            # The temp file is created with delete=False, so it must be
            # removed here or it leaks on every request.
            # NOTE(review): assumes the adapter does not need the reference
            # file after ``generate`` returns — confirm.
            if ref_path is not None:
                try:
                    os.unlink(ref_path)
                except OSError:
                    pass

        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Seed-Used": str(seed_used),
                "Access-Control-Expose-Headers": "X-Seed-Used",
            },
        )

    async def generate_dialog_route(
        text: str = Form(...),
        engine_id: str = Form(...),
        params: str = Form("{}"),
        language: str | None = Form(None),
        reference_wav_a: UploadFile | None = File(None),
        reference_wav_b: UploadFile | None = File(None),
        reference_wav_c: UploadFile | None = File(None),
        reference_wav_d: UploadFile | None = File(None),
    ):
        """Synthesize a multi-speaker dialog and return WAV bytes.

        Up to four reference clips (speakers A-D) are validated and passed to
        ``generate_dialog`` as raw bytes.

        Responses:
            200 ``audio/wav`` with an ``X-Seed-Used`` header on success.
            400 ``reference_invalid`` when a speaker clip fails validation.
            400 ``dialog_format_invalid`` when ``parse_dialog`` rejects ``text``.
            400 ``params_invalid`` when ``params`` is not valid JSON.
            400 ``dialog_missing_reference`` on ``DialogReferenceError``.
            404 ``model_not_found`` when generation raises ``KeyError``.
            500 ``generation_failed`` for any other failure.
        """
        uploads = {
            "A": reference_wav_a,
            "B": reference_wav_b,
            "C": reference_wav_c,
            "D": reference_wav_d,
        }
        speaker_clips: dict[str, bytes] = {}
        for letter, upload in uploads.items():
            if upload is None:
                continue
            data = await upload.read()
            try:
                validate_reference_clip(data)
            except AudioValidationError as exc:
                return _error(400, "reference_invalid", f"speaker {letter}: {exc}")
            speaker_clips[letter] = data

        bus = get_bus()
        try:
            # Parsed here only to learn the turn count for the session.
            total_turns = len(parse_dialog(text))
        except DialogParseError as exc:
            return _error(400, "dialog_format_invalid", str(exc))

        try:
            gen_params = json.loads(params or "{}")
        except ValueError as exc:
            return _error(400, "params_invalid", f"params is not valid JSON: {exc}")

        try:
            async with bus.session("dialog", total_turns=total_turns) as sess:
                wav_bytes, _sr, seed_used = await generate_dialog(
                    registry=app.state.registry,
                    engine_id=engine_id,
                    text=text,
                    language=language,
                    params=gen_params,
                    speaker_clips=speaker_clips,
                    session=sess,
                )
                sess.set_seed(seed_used)
        except KeyError:
            raise _model_not_found(engine_id) from None
        except DialogReferenceError as exc:
            return _error(400, "dialog_missing_reference", str(exc))
        except Exception as exc:
            return _error(500, "generation_failed", str(exc))

        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Seed-Used": str(seed_used),
                "Access-Control-Expose-Headers": "X-Seed-Used",
            },
        )

    async def _http_exc(request, exc: HTTPException):
        """Render ``HTTPException`` in the ``{"error": {...}}`` envelope."""
        if isinstance(exc.detail, dict) and "error" in exc.detail:
            # Detail already has the envelope shape; pass it through.
            return JSONResponse(status_code=exc.status_code, content=exc.detail)
        return _error(exc.status_code, "http_error", str(exc.detail))

    # Without this registration the handler above is dead code and the
    # HTTPExceptions raised by the handlers would not use the error envelope.
    app.add_exception_handler(HTTPException, _http_exc)

    if STATIC_DIR.exists():
        app.mount("/", StaticFiles(directory=str(STATIC_DIR), html=True), name="static")
    return app
# Module-level application instance, constructed once when this module is
# imported.
app: FastAPI = build_app()