File size: 2,226 Bytes
ba54ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Recap MI300X premium-mode backend. Runs on the AMD Developer Cloud droplet.

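Deploy (uvicorn must run from the parent directory of backend/ so that the
`backend` package is importable):
    cd backend
    pip install -r requirements.txt
    # ROCm torch is installed separately on the droplet image.
    cd ..
    uvicorn backend.server:app --host 0.0.0.0 --port 8080

Then expose the server to the public Space via ngrok or cloudflared, and set
RECAP_MI300X_URL in the Space's environment to the resulting public URL.

Models are loaded at startup by default; set RECAP_EAGER_LOAD=0 to skip that
for a faster boot while debugging.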
"""

from __future__ import annotations

import os
import traceback
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from backend import serve

# Eager model loading is enabled unless RECAP_EAGER_LOAD is set to something
# other than the exact string "1" (an unset variable counts as "1").
EAGER_LOAD = os.environ.get("RECAP_EAGER_LOAD", "1") == "1"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that optionally loads the models before serving.

    When ``EAGER_LOAD`` is true, ``serve._ensure_loaded()`` is called once at
    startup. Any exception it raises is reported and swallowed, so the server
    still starts; the failure is deliberately deferred to the first request.

    Args:
        app: The FastAPI application (required by the lifespan protocol,
            unused here).
    """
    if EAGER_LOAD:
        # Load models at startup so the first /medgemma request is fast.
        # Set RECAP_EAGER_LOAD=0 if you want a fast boot for debugging.
        try:
            serve._ensure_loaded()
        except Exception as e:  # noqa: BLE001 — defer the failure to first request
            print(f"[server] eager load failed: {e}", flush=True)
            # str(e) can be empty and never shows where the load failed, so
            # also emit the full traceback (type, message and stack).
            print(traceback.format_exc(), end="", flush=True)
    yield


# The ASGI application object; the deploy instructions launch it with
# ``uvicorn backend.server:app``. ``lifespan`` handles optional eager loading.
app = FastAPI(
    title="Recap Premium Backend",
    version="0.1.0",
    lifespan=lifespan,
)


class GenRequest(BaseModel):
    """Request body shared by the /medgemma and /qwen endpoints.

    The endpoints forward the fields to the matching ``serve`` function as
    ``(system, user, max_new_tokens)``.
    """

    system: str  # forwarded as the first positional argument
    user: str  # forwarded as the second positional argument
    # Generation budget. The endpoints are reachable from the public internet,
    # so reject zero/negative values at validation time (HTTP 422) instead of
    # passing them through to the model.
    max_new_tokens: int = Field(384, ge=1)


class GenResponse(BaseModel):
    """Response body returned by both the /medgemma and /qwen endpoints."""

    text: str  # the string produced by the corresponding serve function


@app.post("/medgemma", response_model=GenResponse)
def medgemma(req: GenRequest) -> GenResponse:
    """Run ``serve.medgemma_extract`` on the request and wrap its output.

    Declared as a plain ``def``, so FastAPI runs it in its worker threadpool
    rather than on the event loop.

    Raises:
        HTTPException: status 500 with ``str(e)`` as the detail if the serve
            call raises any exception.
    """
    try:
        text = serve.medgemma_extract(req.system, req.user, req.max_new_tokens)
    except Exception as e:  # noqa: BLE001
        # FastAPI turns an HTTPException into a response without logging the
        # underlying error, so record the cause server-side before raising.
        print(f"[server] /medgemma failed: {e}", flush=True)
        print(traceback.format_exc(), end="", flush=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return GenResponse(text=text)


@app.post("/qwen", response_model=GenResponse)
def qwen(req: GenRequest) -> GenResponse:
    """Run ``serve.qwen_synthesize`` on the request and wrap its output.

    Declared as a plain ``def``, so FastAPI runs it in its worker threadpool
    rather than on the event loop.

    Raises:
        HTTPException: status 500 with ``str(e)`` as the detail if the serve
            call raises any exception.
    """
    try:
        text = serve.qwen_synthesize(req.system, req.user, req.max_new_tokens)
    except Exception as e:  # noqa: BLE001
        # FastAPI turns an HTTPException into a response without logging the
        # underlying error, so record the cause server-side before raising.
        print(f"[server] /qwen failed: {e}", flush=True)
        print(traceback.format_exc(), end="", flush=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return GenResponse(text=text)


@app.get("/health")
def health() -> dict:
    """Report liveness, load state, memory stats and configured model IDs.

    ``loaded`` falls back to False when ``serve._state`` has no "loaded" key.
    """
    is_loaded = serve._state.get("loaded", False)
    mem = serve.memory_stats()
    model_ids = dict(medgemma_id=serve.MEDGEMMA_ID, qwen_id=serve.QWEN_ID)
    return {"ok": True, "loaded": is_loaded, "memory": mem, "models": model_ids}