# recap/backend/serve.py
"""Loads MedGemma-27B-MM + Qwen-32B co-resident on a single AMD MI300X.
Designed to run inside the FastAPI server on the droplet. Models are loaded
lazily (first request triggers load) so the health endpoint is responsive
even before the heavy weights touch GPU memory.
"""
from __future__ import annotations
import os
import time
from threading import Lock
from typing import Any
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
AutoTokenizer,
)
MEDGEMMA_ID = os.getenv("MEDGEMMA_ID", "google/medgemma-27b-it")
QWEN_ID = os.getenv("QWEN_ID", "Qwen/Qwen3-32B")
DEVICE = "cuda:0"
DTYPE = torch.bfloat16
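# Rough co-residency budget (back-of-envelope, not a measured figure): bf16
# weights cost ~2 bytes/param, so ~27B + ~32B params ≈ 54 GB + 64 GB ≈ 118 GB,
# well inside the MI300X's 192 GB of HBM3 with headroom left for activations
# and KV cache during generation.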
_state: dict[str, Any] = {"loaded": False}
_lock = Lock()
def _ensure_loaded() -> None:
"""Load both models into GPU memory once. Idempotent + thread-safe."""
if _state["loaded"]:
return
with _lock:
if _state["loaded"]: # double-checked
return
t0 = time.time()
print(f"[serve] loading MedGemma: {MEDGEMMA_ID}", flush=True)
_state["medgemma_proc"] = AutoProcessor.from_pretrained(MEDGEMMA_ID)
_state["medgemma"] = AutoModelForImageTextToText.from_pretrained(
MEDGEMMA_ID, torch_dtype=DTYPE, device_map=DEVICE,
)
torch.cuda.synchronize()
peak_after_mg = torch.cuda.max_memory_allocated() / 1e9
print(f"[serve] medgemma loaded in {time.time() - t0:.1f}s, peak {peak_after_mg:.1f} GB", flush=True)
t1 = time.time()
print(f"[serve] loading Qwen: {QWEN_ID}", flush=True)
_state["qwen_tok"] = AutoTokenizer.from_pretrained(QWEN_ID)
_state["qwen"] = AutoModelForCausalLM.from_pretrained(
QWEN_ID, torch_dtype=DTYPE, device_map=DEVICE,
)
torch.cuda.synchronize()
peak = torch.cuda.max_memory_allocated() / 1e9
print(f"[serve] qwen loaded in {time.time() - t1:.1f}s, total peak {peak:.1f} GB", flush=True)
_state["loaded"] = True
_state["peak_after_load_gb"] = peak
def memory_stats() -> dict[str, Any]:
    """Report current/peak HBM usage; cheap to call and safe before the models load."""
if not torch.cuda.is_available():
return {"available": False}
return {
"available": True,
"allocated_gb": torch.cuda.memory_allocated() / 1e9,
"reserved_gb": torch.cuda.memory_reserved() / 1e9,
"total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
"peak_after_load_gb": _state.get("peak_after_load_gb"),
"device_name": torch.cuda.get_device_name(0),
}
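# Illustrative shape of the returned dict on an MI300X (values are
# placeholders, not measurements; the exact device string is whatever ROCm reports):
#   {"available": True, "allocated_gb": ..., "reserved_gb": ...,
#    "total_gb": ~192, "peak_after_load_gb": ..., "device_name": "AMD Instinct MI300X"}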
def medgemma_extract(system: str, user: str, max_new_tokens: int = 384) -> str:
"""First stage of the two-stage reasoner: read records, surface relevant findings."""
_ensure_loaded()
msgs = [
{"role": "system", "content": [{"type": "text", "text": system}]},
{"role": "user", "content": [{"type": "text", "text": user}]},
]
inputs = _state["medgemma_proc"].apply_chat_template(
msgs, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt",
).to(DEVICE)
out = _state["medgemma"].generate(
**inputs, max_new_tokens=max_new_tokens, do_sample=False,
)
new_tokens = out[0][inputs["input_ids"].shape[-1]:]
return _state["medgemma_proc"].decode(new_tokens, skip_special_tokens=True)
def qwen_synthesize(system: str, user: str, max_new_tokens: int = 512) -> str:
"""Second stage: synthesize MedGemma's findings into the final cited answer."""
_ensure_loaded()
msgs = [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
text = _state["qwen_tok"].apply_chat_template(
msgs, add_generation_prompt=True, tokenize=False,
)
inputs = _state["qwen_tok"](text, return_tensors="pt").to(DEVICE)
out = _state["qwen"].generate(
**inputs, max_new_tokens=max_new_tokens, do_sample=False,
)
new_tokens = out[0][inputs["input_ids"].shape[-1]:]
return _state["qwen_tok"].decode(new_tokens, skip_special_tokens=True)
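# Minimal end-to-end sketch of the two-stage flow when run directly (the
# prompts and sample record below are illustrative placeholders, not the
# production prompts used by the API layer):
if __name__ == "__main__":
    record = "2019-03-04 ED note: chest pain, troponin negative, ECG unremarkable."
    findings = medgemma_extract(
        system="You read clinical records. List findings relevant to cardiac risk.",
        user=record,
    )
    print("[stage 1] findings:\n", findings)
    answer = qwen_synthesize(
        system="Synthesize the findings into a short, cited answer.",
        user=f"Question: any evidence of acute coronary syndrome?\n\nFindings:\n{findings}",
    )
    print("[stage 2] answer:\n", answer)
    print("[memory]", memory_stats())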