Spaces:
Running on Zero
Running on Zero
| """ | |
| APIRouter for /api/tools/* endpoints. | |
| Each endpoint is sync request-response (no SSE, no job state). Input files | |
| land in a fresh per-run directory, outputs are returned as a download URL | |
| to GET /api/tools/file/{run_id}/{filename}. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile | |
| from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse | |
| from pydantic import BaseModel | |
| from server import limiter, _download_url, _is_allowed_video_host | |
| from . import audio_cleanup, dramabox, scene_writer, subtitles, voice_clone | |
| from .storage import ( | |
| file_url, | |
| new_run_dir, | |
| reap_old_runs, | |
| run_dir, | |
| safe_filename, | |
| ) | |
# Router for the /api/tools/* endpoints described in the module docstring.
router = APIRouter(tags=["tools"], prefix="/api/tools")

# Upload size ceiling, in bytes, applied by the tools endpoints' upload helper.
# This is a separate limit from the pipeline's MAX_UPLOAD_BYTES check.
TOOLS_MAX_BYTES = 50 << 20  # 50 MB
| # ── Helpers ────────────────────────────────────────────────────────── | |
async def _save_upload(file: UploadFile, dest_dir: Path, default_name: str) -> Path:
    """Stream an uploaded file into ``dest_dir``, enforcing the tools size cap.

    The on-disk name is ``safe_filename(file.filename, default_name)``.
    The upload is copied in 1 MiB chunks.

    Args:
        file: The uploaded file to copy.
        dest_dir: Directory to write into (a fresh run directory at every
            call site in this module).
        default_name: Name passed to ``safe_filename`` as the fallback.

    Returns:
        Path of the written file.

    Raises:
        HTTPException(413): the upload exceeds ``TOOLS_MAX_BYTES``.
        HTTPException(400): the upload contains no bytes.

    On any failure, the partially written file is removed so that a failed run
    does not leave a truncated input behind.

    NOTE(review): the framework has presumably already received the whole
    multipart body before this coroutine runs. If so, this cap bounds what
    is copied into the run dir, not what the server accepts. Confirm that a
    request-size limit exists upstream.
    """
    dest = dest_dir / safe_filename(file.filename, default_name)
    written = 0
    with open(dest, "wb") as fh:
        try:
            while chunk := await file.read(1024 * 1024):
                written += len(chunk)
                if written > TOOLS_MAX_BYTES:
                    raise HTTPException(413, f"File too large (max {TOOLS_MAX_BYTES // (1024*1024)} MB).")
                fh.write(chunk)
        except BaseException:
            # Covers the 413 above, write errors (e.g. disk full) and
            # cancellation of the request task. Close first so the unlink
            # also works on platforms that refuse to delete open files.
            fh.close()
            dest.unlink(missing_ok=True)
            raise
    if written == 0:
        # An empty file is never valid media for these tools. Reject it
        # here instead of letting the downstream tool fail with a 500.
        dest.unlink(missing_ok=True)
        raise HTTPException(400, "Uploaded file is empty.")
    return dest
| def _ext_to_media_type(filename: str) -> str: | |
| ext = Path(filename).suffix.lower() | |
| return { | |
| ".mp4": "video/mp4", | |
| ".mov": "video/quicktime", | |
| ".webm": "video/webm", | |
| ".mp3": "audio/mpeg", | |
| ".wav": "audio/wav", | |
| ".srt": "application/x-subrip", | |
| ".vtt": "text/vtt", | |
| ".txt": "text/plain", | |
| }.get(ext, "application/octet-stream") | |
| # ── Subtitles ──────────────────────────────────────────────────────── | |
async def subtitles_endpoint(
    request: Request,
    file: Optional[UploadFile] = File(None),
    url: Optional[str] = Form(None),
    source_lang: str = Form("Auto-detect"),
    target_lang: str = Form("Same as source"),
    fmt: str = Form("srt"),
    style: str = Form("tiktok"),
    position: str = Form("bottom"),
    h_align: str = Form("center"),
    font_size: Optional[int] = Form(None),
    margin_v: Optional[int] = Form(None),
):
    """Generate subtitles for an uploaded video or a supported video URL.

    Exactly one of ``file`` / ``url`` must be supplied. The input is saved (or
    downloaded) into a fresh run directory. ``subtitles.generate_subtitles``
    then transcribes it, optionally translates it, and for ``fmt="mp4"``
    optionally burns the subtitles in with ffmpeg.

    ``request`` is not used in the body. It is presumably required by a
    rate-limit decorator (``limiter`` is imported by this module); confirm.

    Returns:
        JSON with ``run_id``, ``format``, ``filename``, a download ``url``,
        ``segments`` and ``translated``.

    Raises:
        HTTPException(400): invalid option, missing or duplicate input,
            unsupported URL host, failed download, or a ValueError from
            subtitle generation.
        HTTPException(413): the upload exceeds the tools size cap.
        HTTPException(500): any other subtitle-generation failure.
    """
    if fmt not in ("srt", "vtt", "txt", "mp4"):
        raise HTTPException(400, "fmt must be one of: srt, vtt, txt, mp4")
    if style not in ("tiktok", "youtube", "minimal"):
        raise HTTPException(400, "style must be one of: tiktok, youtube, minimal")
    if position not in ("top", "middle", "bottom"):
        raise HTTPException(400, "position must be one of: top, middle, bottom")
    if h_align not in ("left", "center", "right"):
        raise HTTPException(400, "h_align must be one of: left, center, right")

    url = (url or "").strip()
    # A form whose file input was left empty can still send a file part with
    # an empty filename. Treat that as "no file", as dramabox_endpoint does;
    # otherwise it counts as an upload alongside the URL.
    has_file = file is not None and bool(file.filename)
    if not has_file and not url:
        raise HTTPException(400, "Provide either a file upload or a video URL.")
    if has_file and url:
        raise HTTPException(400, "Send a file OR a URL, not both.")
    # Check the host before allocating a run directory, so that rejected
    # requests do not leave empty run dirs behind.
    if not has_file and not _is_allowed_video_host(url):
        raise HTTPException(400, "URL host not supported. Use TikTok, YouTube, or Instagram.")

    run_id, dest_dir = new_run_dir()
    if has_file:
        input_path = await _save_upload(file, dest_dir, "input.mp4")
    else:
        input_path = Path(dest_dir) / "input.mp4"
        try:
            await asyncio.to_thread(_download_url, url, str(input_path))
        except Exception as e:  # noqa: BLE001
            raise HTTPException(400, f"Couldn't fetch the video URL: {e}") from e

    try:
        # Heavy: transcribe + (optional) translate + (optional) ffmpeg burn-in.
        # Run off the event loop so concurrent requests don't starve.
        info = await asyncio.to_thread(
            subtitles.generate_subtitles,
            input_path=input_path,
            out_dir=dest_dir,
            source_lang_name=source_lang,
            target_lang_name=target_lang,
            fmt=fmt,  # type: ignore[arg-type]
            style=style,  # type: ignore[arg-type]
            position=position,  # type: ignore[arg-type]
            h_align=h_align,  # type: ignore[arg-type]
            font_size=font_size,
            margin_v=margin_v,
        )
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(500, f"Subtitle generation failed: {e}") from e

    return JSONResponse({
        "run_id": run_id,
        "format": info["format"],
        "filename": info["filename"],
        "url": file_url(run_id, info["filename"]),
        "segments": info["segments"],
        "translated": info["translated"],
    })
| # ── Voice clone ────────────────────────────────────────────────────── | |
async def voice_clone_endpoint(
    request: Request,
    sample: UploadFile = File(...),
    text: str = Form(...),
    language_id: str = Form("en"),
):
    """Synthesize ``text`` with ``voice_clone.clone_voice`` from an uploaded sample.

    ``text`` is stripped and must be 1-1000 characters. The sample is saved
    into a fresh run directory ("sample.wav" is the fallback name). Cloning
    runs in a worker thread.

    ``request`` is not used in the body. It is presumably required by a
    rate-limit decorator; confirm.

    Returns:
        JSON with ``run_id``, ``engine``, ``chunks``, ``filename`` and a
        download ``url``.

    Raises:
        HTTPException(400): empty/oversized text, or a ValueError from cloning.
        HTTPException(413): the sample exceeds the tools size cap.
        HTTPException(500): any other cloning failure.
    """
    text = (text or "").strip()
    if not text:
        raise HTTPException(400, "text is required")
    if len(text) > 1000:
        raise HTTPException(400, "text exceeds 1000 char limit")

    run_id, dest_dir = new_run_dir()
    sample_path = await _save_upload(sample, dest_dir, "sample.wav")
    try:
        info = await asyncio.to_thread(
            voice_clone.clone_voice,
            sample_path=sample_path,
            text=text,
            out_dir=dest_dir,
            language_id=language_id,
        )
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(500, f"Voice clone failed: {e}") from e

    return JSONResponse({
        "run_id": run_id,
        "engine": info["engine"],
        "chunks": info["chunks"],
        "filename": info["filename"],
        "url": file_url(run_id, info["filename"]),
    })
| # ── Dramabox ───────────────────────────────────────────────────────── | |
async def dramabox_endpoint(
    request: Request,
    prompt: str = Form(...),
    audio_ref: Optional[UploadFile] = File(None),
    cfg: float = Form(2.5),
    stg: float = Form(1.5),
    dur_mult: float = Form(1.1),
    gen_dur: float = Form(0.0),
    ref_dur: float = Form(10.0),
    seed: int = Form(42),
):
    """Generate a scene with ``dramabox.generate_scene`` from a text prompt.

    An optional reference upload (``audio_ref``) is saved into the run
    directory ("voice_ref.wav" is the fallback name) and passed through. A file
    part with an empty filename is treated as "no reference". All inputs are
    validated before a run directory is created.

    ``request`` is not used in the body. It is presumably required by a
    rate-limit decorator; confirm.

    Returns:
        JSON with ``run_id``, ``filename``, a download ``url``, ``elapsed``
        and ``settings``.

    Raises:
        HTTPException(400): empty/oversized prompt, out-of-range setting, or
            a ValueError from generation.
        HTTPException(413): the reference upload exceeds the tools size cap.
        HTTPException(503): a RuntimeError from generation.
        HTTPException(500): any other generation failure.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise HTTPException(400, "prompt is required")
    if len(prompt) > 2000:
        raise HTTPException(400, "prompt exceeds 2000 char limit")

    # Range guards mirror the upstream Dramabox sliders. They are checked in
    # this order, and the first failure is reported. Chained comparisons are
    # False for NaN, so NaN values are rejected as well.
    for value, low, high, message in (
        (cfg, 1.0, 10.0, "cfg must be between 1 and 10"),
        (stg, 0.0, 5.0, "stg must be between 0 and 5"),
        (dur_mult, 0.8, 2.0, "dur_mult must be between 0.8 and 2.0"),
        (gen_dur, 0.0, 60.0, "gen_dur must be between 0 and 60"),
        (ref_dur, 3.0, 30.0, "ref_dur must be between 3 and 30"),
    ):
        if not (low <= value <= high):
            raise HTTPException(400, message)

    run_id, dest_dir = new_run_dir()
    ref_path: Optional[Path] = None
    if audio_ref is not None and audio_ref.filename:
        ref_path = await _save_upload(audio_ref, dest_dir, "voice_ref.wav")

    try:
        info = await asyncio.to_thread(
            dramabox.generate_scene,
            prompt=prompt,
            out_dir=dest_dir,
            audio_ref=ref_path,
            cfg=cfg,
            stg=stg,
            dur_mult=dur_mult,
            gen_dur=gen_dur,
            ref_dur=ref_dur,
            seed=seed,
        )
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    except RuntimeError as e:
        # Raised by dramabox._ensure_server() on Spaces that don't ship the
        # vendored model. Surface clearly so the frontend can fall back.
        # NOTE(review): RuntimeError is broad. Any RuntimeError raised during
        # generation (not only the one from _ensure_server) becomes a 503
        # here. Consider a dedicated exception type for "model unavailable".
        raise HTTPException(503, str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(500, f"Dramabox generation failed: {e}") from e

    return JSONResponse({
        "run_id": run_id,
        "filename": info["filename"],
        "url": file_url(run_id, info["filename"]),
        "elapsed": info["elapsed"],
        "settings": info["settings"],
    })
| # ── Audio cleanup ──────────────────────────────────────────────────── | |
async def audio_cleanup_endpoint(
    request: Request,
    file: UploadFile = File(...),
    mode: str = Form("vocals-only"),
):
    """Split an uploaded audio file with ``audio_cleanup.separate``.

    ``mode`` must be "vocals-only", "instrumental-only" or "stems". The
    upload is saved into a fresh run directory ("input.wav" is the fallback
    name). Separation runs in a worker thread.

    ``request`` is not used in the body. It is presumably required by a
    rate-limit decorator; confirm.

    Returns:
        JSON with ``run_id``, ``mode`` and ``stems``. Each stem is the dict
        returned by ``separate`` plus a download ``url`` built from its
        ``filename``.

    Raises:
        HTTPException(400): unknown mode, or a ValueError from separation.
        HTTPException(413): the upload exceeds the tools size cap.
        HTTPException(500): any other separation failure.
    """
    if mode not in ("vocals-only", "instrumental-only", "stems"):
        raise HTTPException(400, "mode must be one of: vocals-only, instrumental-only, stems")

    run_id, dest_dir = new_run_dir()
    input_path = await _save_upload(file, dest_dir, "input.wav")
    try:
        stems = await asyncio.to_thread(
            audio_cleanup.separate,
            input_path=input_path,
            out_dir=dest_dir,
            mode=mode,  # type: ignore[arg-type]
        )
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:  # noqa: BLE001
        raise HTTPException(500, f"Audio separation failed: {e}") from e

    return JSONResponse({
        "run_id": run_id,
        "mode": mode,
        "stems": [
            {**stem, "url": file_url(run_id, stem["filename"])}
            for stem in stems
        ],
    })
| # ── Scene writer ───────────────────────────────────────────────────── | |
class SceneRequest(BaseModel):
    # JSON body accepted by scene_endpoint. The comments here are deliberate:
    # a class docstring would become the model's schema description.
    # Field order is kept as-is because it fixes the schema/serialization
    # order.
    hook_pattern: str  # required (no default); the viral hook to build on
    hook_title: str = ""  # optional, defaults to empty
    register: str = ""  # optional, defaults to empty
    register_label: str = ""  # optional, defaults to empty
    niche: str  # required (no default); the creator's niche
async def scene_endpoint(request: Request, body: SceneRequest):
    """Draft a directed DramaBox scene from a viral hook + the creator's niche.

    ``niche`` and ``hook_pattern`` are stripped and must be non-empty.
    ``niche`` is capped at 400 characters. The remaining fields are stripped
    and passed through. ``scene_writer.write_scene`` runs in a worker thread.

    NOTE(review): only ``niche`` is length-checked. ``hook_pattern``,
    ``hook_title``, ``register`` and ``register_label`` are forwarded
    unbounded. Confirm whether they need caps too.

    Returns:
        JSON ``{"scene": <write_scene result>}``.

    Raises:
        HTTPException(400): missing niche/hook_pattern, or niche too long.
        HTTPException(500): any failure from ``write_scene``.
    """
    niche = body.niche.strip()
    hook_pattern = body.hook_pattern.strip()
    if not niche:
        raise HTTPException(400, "niche is required")
    if not hook_pattern:
        raise HTTPException(400, "hook_pattern is required")
    if len(niche) > 400:
        raise HTTPException(400, "niche exceeds 400 char limit")

    try:
        scene = await asyncio.to_thread(
            scene_writer.write_scene,
            hook_pattern=hook_pattern,
            hook_title=body.hook_title.strip(),
            register=body.register.strip(),
            register_label=body.register_label.strip(),
            niche=niche,
        )
    except Exception as e:  # noqa: BLE001
        raise HTTPException(500, f"Scene generation failed: {e}") from e

    return JSONResponse({"scene": scene})
| # ── File download ──────────────────────────────────────────────────── | |
async def tools_file(run_id: str, filename: str):
    """Return one file from a run directory as a download response.

    Run dirs auto-expire after RUN_TTL_SECONDS. An unknown or expired run
    yields 404.

    Raises:
        HTTPException(400): ``filename`` is changed by ``safe_filename``.
        HTTPException(404): the run is unknown/expired, or ``filename`` is
            not an existing regular file inside it.
    """
    # Only names that safe_filename leaves unchanged are served. Presumably
    # this blocks path components / traversal in the URL segment; confirm
    # against the storage module.
    if safe_filename(filename) != filename:
        raise HTTPException(400, "Invalid filename")

    run_path = run_dir(run_id)
    if run_path is None:
        raise HTTPException(404, "Run not found or expired")

    artifact = run_path / filename
    # is_file() is already False for missing paths, so no separate exists().
    if not artifact.is_file():
        raise HTTPException(404, "File not found")

    return FileResponse(
        path=str(artifact),
        media_type=_ext_to_media_type(filename),
        filename=filename,
    )
| # ── Cleanup hook ───────────────────────────────────────────────────── | |
async def _reap():
    """Trigger ``reap_old_runs`` by hand and report its result.

    Mostly for testing. Automatic reaping runs on a timer.
    Returns ``{"removed": <reap_old_runs() result>}``.
    """
    # Run in a worker thread so the reap (presumably filesystem work) does
    # not block the event loop.
    removed_count = await asyncio.to_thread(reap_old_runs)
    return dict(removed=removed_count)