| """Loads MedGemma-27B-MM + Qwen-32B co-resident on a single AMD MI300X. |
| |
| Designed to run inside the FastAPI server on the droplet. Models are loaded |
| lazily (first request triggers load) so the health endpoint is responsive |
| even before the heavy weights touch GPU memory. |
| """ |

from __future__ import annotations

import os
import time
from threading import Lock
from typing import Any

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

MEDGEMMA_ID = os.getenv("MEDGEMMA_ID", "google/medgemma-27b-it")
QWEN_ID = os.getenv("QWEN_ID", "Qwen/Qwen3-32B")

DEVICE = "cuda:0"
DTYPE = torch.bfloat16
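
# On ROCm, PyTorch exposes the MI300X through the CUDA device API, so "cuda:0"
# is the right device string here. Rough capacity check (an estimate, not a
# measurement): two bf16 models at ~27B and ~32B parameters need about
# (27 + 32) * 2 bytes/param ≈ 118 GB for weights alone, leaving headroom in
# the MI300X's 192 GB of HBM for activations and KV cache.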

_state: dict[str, Any] = {"loaded": False}
_lock = Lock()
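
# `_state` holds the processor/model/tokenizer handles once loaded; `_lock`
# serializes the one-time load. `_ensure_loaded` below uses double-checked
# locking: an unlocked fast path for the common case, plus a re-check under
# the lock so concurrent first requests trigger exactly one load.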


def _ensure_loaded() -> None:
    """Load both models into GPU memory once. Idempotent + thread-safe."""
    if _state["loaded"]:
        return
    with _lock:
        if _state["loaded"]:
            return

        t0 = time.time()
        print(f"[serve] loading MedGemma: {MEDGEMMA_ID}", flush=True)
        _state["medgemma_proc"] = AutoProcessor.from_pretrained(MEDGEMMA_ID)
        _state["medgemma"] = AutoModelForImageTextToText.from_pretrained(
            MEDGEMMA_ID, torch_dtype=DTYPE, device_map=DEVICE,
        )
        torch.cuda.synchronize()
        peak_after_mg = torch.cuda.max_memory_allocated() / 1e9
        print(f"[serve] medgemma loaded in {time.time() - t0:.1f}s, peak {peak_after_mg:.1f} GB", flush=True)

        t1 = time.time()
        print(f"[serve] loading Qwen: {QWEN_ID}", flush=True)
        _state["qwen_tok"] = AutoTokenizer.from_pretrained(QWEN_ID)
        _state["qwen"] = AutoModelForCausalLM.from_pretrained(
            QWEN_ID, torch_dtype=DTYPE, device_map=DEVICE,
        )
        torch.cuda.synchronize()
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"[serve] qwen loaded in {time.time() - t1:.1f}s, total peak {peak:.1f} GB", flush=True)

        # Publish the stats before flipping the flag: the unlocked fast path
        # in this function reads `loaded` without holding `_lock`.
        _state["peak_after_load_gb"] = peak
        _state["loaded"] = True


def memory_stats() -> dict[str, Any]:
    """Report GPU memory usage; safe to call before the models are loaded."""
    if not torch.cuda.is_available():
        return {"available": False}
    return {
        "available": True,
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
        "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
        "peak_after_load_gb": _state.get("peak_after_load_gb"),  # None before load
        "device_name": torch.cuda.get_device_name(0),
    }
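
# Example: surfacing these stats from the health route (a sketch; the `app`
# object and route path are assumptions, not part of this module):
#
#     @app.get("/health")
#     def health() -> dict:
#         return {"ok": True, **memory_stats()}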


def medgemma_extract(system: str, user: str, max_new_tokens: int = 384) -> str:
    """First stage of the two-stage reasoner: read records, surface relevant findings."""
    _ensure_loaded()
    msgs = [
        {"role": "system", "content": [{"type": "text", "text": system}]},
        {"role": "user", "content": [{"type": "text", "text": user}]},
    ]
    inputs = _state["medgemma_proc"].apply_chat_template(
        msgs, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(DEVICE)
    out = _state["medgemma"].generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
    )
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return _state["medgemma_proc"].decode(new_tokens, skip_special_tokens=True)


def qwen_synthesize(system: str, user: str, max_new_tokens: int = 512) -> str:
    """Second stage: synthesize MedGemma's findings into the final cited answer."""
    _ensure_loaded()
    msgs = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    text = _state["qwen_tok"].apply_chat_template(
        msgs, add_generation_prompt=True, tokenize=False,
        # Qwen3 chat templates accept this flag; skip the <think> stage so
        # greedy decoding spends its token budget on the answer itself.
        enable_thinking=False,
    )
    inputs = _state["qwen_tok"](text, return_tensors="pt").to(DEVICE)
    out = _state["qwen"].generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
    )
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return _state["qwen_tok"].decode(new_tokens, skip_special_tokens=True)