| """OpenAI-compatible proxy that fans out to two local vLLM servers. | |
| Why a custom proxy and not nginx: | |
| * nginx routing on a JSON body field requires lua-nginx or njs (build | |
| pain on a CUDA base image), and we need to PEEK at the body without | |
| consuming it for streaming requests. | |
| * httpx async streaming + FastAPI is ~80 LoC, debuggable in plain Python, | |
| and reuses the same connection pool across requests. | |
| Endpoints exposed on :7860 (matches OpenAI spec): | |
| * GET /v1/models — lists both registered model ids | |
| * GET /v1/models/{model_id} — single model lookup | |
| * POST /v1/chat/completions — main route. Reads `model` from body | |
| and forwards to whichever vLLM owns it. | |
| * POST /v1/completions — same routing, kept for old clients. | |
| * GET /health — 200 iff both upstreams are healthy. | |
| HF's container monitor uses the Docker | |
| HEALTHCHECK from the Dockerfile, but | |
| we expose this for the demo's frontend | |
| so it can show a "warming up..." badge | |
| during cold starts. | |
| * GET / — friendly landing page so the bare | |
| Space URL doesn't 404. | |
| Streaming is forwarded byte-for-byte (StreamingResponse over the upstream's | |
| chunks) so SSE `data: {...}\n\n` framing survives intact. | |
| """ | |

from __future__ import annotations

import json
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse

logger = logging.getLogger("physix-infer-proxy")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

QWEN_MODEL = os.environ.get("QWEN_MODEL", "Qwen/Qwen2.5-3B-Instruct")
PHYSIX_MODEL = os.environ.get("PHYSIX_MODEL", "Pratyush-01/physix-3b-rl")

QWEN_UPSTREAM = "http://127.0.0.1:8001"
PHYSIX_UPSTREAM = "http://127.0.0.1:8002"

ROUTING: dict[str, str] = {
    QWEN_MODEL: QWEN_UPSTREAM,
    PHYSIX_MODEL: PHYSIX_UPSTREAM,
}

# Generous timeout — first request after a cold start can sit on the
# upstream for ~30 s while CUDA graphs warm up. Streaming tokens come
# back fast once that's done.
TIMEOUT = httpx.Timeout(connect=10.0, read=600.0, write=60.0, pool=5.0)

@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Open one shared httpx client for the proxy's lifetime.

    Keep-alive across requests matters: every chat completion otherwise
    pays a TCP+HTTP/1.1 handshake (~1-2 ms on localhost, but it adds up
    under autoplay loops that fire 8 turns/episode).
    """
    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        _app.state.http = client
        yield

app = FastAPI(
    title="PhysiX-Infer",
    description="Dual-model OpenAI-compatible inference (Qwen 2.5 3B + physix-3b-rl).",
    lifespan=lifespan,
)

def _resolve_upstream(model: str | None) -> str:
    if not model:
        raise HTTPException(
            status_code=400,
            detail="Missing 'model' field. Pass either "
            f"'{QWEN_MODEL}' or '{PHYSIX_MODEL}'.",
        )
    upstream = ROUTING.get(model)
    if upstream is None:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Model '{model}' is not served by this Space. "
                f"Available: {list(ROUTING.keys())}."
            ),
        )
    return upstream

async def _proxy_json(request: Request, path: str) -> JSONResponse | StreamingResponse:
    """Read body, route on `model`, forward, stream back if `stream=true`."""
    raw = await request.body()
    try:
        payload = json.loads(raw) if raw else {}
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") from exc

    upstream = _resolve_upstream(payload.get("model"))
    is_stream = bool(payload.get("stream"))

    # Forward only an allowlist of headers (content type, accept, auth,
    # request id); everything else, including hop-by-hop headers, is dropped.
    fwd_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() in {"content-type", "accept", "authorization", "x-request-id"}
    }
    fwd_headers.setdefault("content-type", "application/json")

    client: httpx.AsyncClient = request.app.state.http
    upstream_url = f"{upstream}{path}"

    if not is_stream:
        try:
            resp = await client.post(upstream_url, content=raw, headers=fwd_headers)
        except httpx.HTTPError as exc:
            logger.exception("upstream %s failed", upstream_url)
            raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc
        # Non-streaming answers from vLLM come back as application/json; the
        # streaming case is handled separately below. Fall back to wrapping
        # raw text if the upstream sends anything else.
        return JSONResponse(
            status_code=resp.status_code,
            content=resp.json()
            if resp.headers.get("content-type", "").startswith("application/json")
            else {"raw": resp.text},
        )

    # Streaming path: open the upstream as a streaming request and pump
    # chunks straight to the client. Note the `async with` lives INSIDE
    # the generator so it stays open until StreamingResponse is done.
    async def _gen() -> AsyncIterator[bytes]:
        try:
            async with client.stream(
                "POST", upstream_url, content=raw, headers=fwd_headers
            ) as upstream_resp:
                if upstream_resp.status_code >= 400:
                    body = await upstream_resp.aread()
                    yield body
                    return
                async for chunk in upstream_resp.aiter_raw():
                    if chunk:
                        yield chunk
        except httpx.HTTPError as exc:
            logger.exception("upstream stream %s failed", upstream_url)
            err = json.dumps({"error": {"message": str(exc), "type": "upstream_error"}})
            yield f"data: {err}\n\n".encode()

    return StreamingResponse(_gen(), media_type="text/event-stream")
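
# For reference: on the streaming path the relayed frames are OpenAI-style SSE
# chunks. Their exact shape depends on the upstream vLLM version, but they
# typically look like:
#     data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Hi"}}],...}
#     ...
#     data: [DONE]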

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    return await _proxy_json(request, "/v1/chat/completions")


@app.post("/v1/completions")
async def completions(request: Request):
    return await _proxy_json(request, "/v1/completions")

@app.get("/v1/models")
async def list_models():
    """Static listing — vLLM exposes the same shape per-upstream, but we
    union them here so a single GET covers both. `created` and `owned_by`
    are filled with sensible placeholders since neither field is load-bearing
    for any client we know of."""
    return {
        "object": "list",
        "data": [
            {
                "id": QWEN_MODEL,
                "object": "model",
                "created": 0,
                "owned_by": "Qwen",
            },
            {
                "id": PHYSIX_MODEL,
                "object": "model",
                "created": 0,
                "owned_by": "Pratyush-01",
            },
        ],
    }

# `:path` converter so model ids containing a slash (both of ours do) still
# match the route.
@app.get("/v1/models/{model_id:path}")
async def get_model(model_id: str):
    if model_id not in ROUTING:
        raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found.")
    owner = "Qwen" if model_id == QWEN_MODEL else "Pratyush-01"
    return {"id": model_id, "object": "model", "created": 0, "owned_by": owner}

@app.get("/health")
async def health(request: Request):
    """Both upstreams must answer /health — the demo frontend uses this
    to decide whether to show a 'warming up' notice on cold start."""
    client: httpx.AsyncClient = request.app.state.http
    statuses = {}
    overall_ok = True
    for name, base in (("qwen", QWEN_UPSTREAM), ("physix", PHYSIX_UPSTREAM)):
        try:
            r = await client.get(f"{base}/health", timeout=5.0)
            statuses[name] = "ok" if r.status_code == 200 else f"status={r.status_code}"
            overall_ok = overall_ok and r.status_code == 200
        except httpx.HTTPError as exc:
            statuses[name] = f"unreachable: {exc.__class__.__name__}"
            overall_ok = False
    return JSONResponse(
        status_code=200 if overall_ok else 503,
        content={"status": "ok" if overall_ok else "starting", "upstreams": statuses},
    )
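
# The Dockerfile HEALTHCHECK mentioned in the module docstring is not part of
# this file; a sketch of what it might look like (flags and URL are assumptions):
#
#     HEALTHCHECK --interval=30s --timeout=5s --start-period=120s \
#         CMD curl -fsS http://127.0.0.1:7860/health || exit 1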

@app.get("/", response_class=HTMLResponse)
async def root():
    """Landing page so the bare Space URL doesn't 404. Plain HTML — no
    framework, no static dir to manage."""
    return f"""<!doctype html>
<html><head><meta charset="utf-8"><title>PhysiX-Infer</title>
<style>
body{{font-family:system-ui,sans-serif;max-width:680px;margin:3em auto;padding:0 1em;color:#222}}
code,pre{{background:#f4f4f4;padding:.2em .4em;border-radius:4px;font-size:.95em}}
pre{{padding:1em;overflow-x:auto}}
h1{{margin-bottom:.2em}}
.muted{{color:#777}}
</style>
</head><body>
<h1>PhysiX-Infer</h1>
<p class="muted">OpenAI-compatible inference proxy for two 3B Qwen2.5 checkpoints.</p>
<h3>Models served</h3>
<ul>
<li><code>{QWEN_MODEL}</code> — baseline (no GRPO fine-tuning)</li>
<li><code>{PHYSIX_MODEL}</code> — GRPO-trained variant</li>
</ul>
<h3>Endpoints</h3>
<ul>
<li><code>GET /v1/models</code></li>
<li><code>POST /v1/chat/completions</code> (set <code>model</code> to one of the ids above)</li>
<li><code>GET /health</code></li>
</ul>
<h3>Example</h3>
<pre>curl -X POST https://&lt;this-space&gt;.hf.space/v1/chat/completions \\
  -H 'content-type: application/json' \\
  -d '{{"model":"{PHYSIX_MODEL}","messages":[{{"role":"user","content":"hi"}}]}}'</pre>
<p class="muted">No auth, but the Space sleeps after a short idle window — the first request after sleep takes ~90 s while both vLLMs warm up.</p>
</body></html>"""
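

if __name__ == "__main__":
    # Local/dev entry point. This is an assumption: in the Space itself the
    # proxy may instead be started by the container CMD (e.g. uvicorn invoked
    # from the Dockerfile). Port 7860 matches the module docstring.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)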