"""Loads MedGemma-27B-MM + Qwen-32B co-resident on a single AMD MI300X.

Designed to run inside the FastAPI server on the droplet. Models are loaded
lazily (first request triggers load) so the health endpoint is responsive
even before the heavy weights touch GPU memory.
"""

from __future__ import annotations

import os
import time
from threading import Lock
from typing import Any

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

MEDGEMMA_ID = os.getenv("MEDGEMMA_ID", "google/medgemma-27b-it")
QWEN_ID = os.getenv("QWEN_ID", "Qwen/Qwen3-32B")

DEVICE = "cuda:0"
DTYPE = torch.bfloat16
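
# Rough weight-only VRAM budget (an estimate, assuming the default model IDs
# above at bf16, i.e. 2 bytes/param): ~54 GB for the 27B MedGemma plus ~64 GB
# for the 32B Qwen is ~118 GB combined, comfortably inside the MI300X's
# 192 GB of HBM, with headroom left for KV cache and activations.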

_state: dict[str, Any] = {"loaded": False}
_lock = Lock()


def _ensure_loaded() -> None:
    """Load both models into GPU memory once. Idempotent + thread-safe."""
    if _state["loaded"]:
        return
    with _lock:
        if _state["loaded"]:  # double-checked
            return

        t0 = time.time()
        print(f"[serve] loading MedGemma: {MEDGEMMA_ID}", flush=True)
        _state["medgemma_proc"] = AutoProcessor.from_pretrained(MEDGEMMA_ID)
        _state["medgemma"] = AutoModelForImageTextToText.from_pretrained(
            MEDGEMMA_ID, torch_dtype=DTYPE, device_map=DEVICE,
        )
        torch.cuda.synchronize()
        peak_after_mg = torch.cuda.max_memory_allocated() / 1e9
        print(f"[serve] medgemma loaded in {time.time() - t0:.1f}s, peak {peak_after_mg:.1f} GB", flush=True)

        t1 = time.time()
        print(f"[serve] loading Qwen: {QWEN_ID}", flush=True)
        _state["qwen_tok"] = AutoTokenizer.from_pretrained(QWEN_ID)
        _state["qwen"] = AutoModelForCausalLM.from_pretrained(
            QWEN_ID, torch_dtype=DTYPE, device_map=DEVICE,
        )
        torch.cuda.synchronize()
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"[serve] qwen loaded in {time.time() - t1:.1f}s, total peak {peak:.1f} GB", flush=True)

        _state["loaded"] = True
        _state["peak_after_load_gb"] = peak


def memory_stats() -> dict[str, Any]:
    if not torch.cuda.is_available():
        return {"available": False}
    return {
        "available": True,
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
        "total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
        "peak_after_load_gb": _state.get("peak_after_load_gb"),
        "device_name": torch.cuda.get_device_name(0),
    }
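

# A minimal sketch (an assumption, not part of this module) of how the FastAPI
# server mentioned in the module docstring might wire these helpers up. Route
# names and the request shape are illustrative. The key property: /health only
# calls memory_stats(), so it stays responsive until the first heavy request
# triggers the lazy load via _ensure_loaded().
#
#     from fastapi import FastAPI
#     from pydantic import BaseModel
#
#     app = FastAPI()
#
#     class StageRequest(BaseModel):
#         system: str
#         user: str
#
#     @app.get("/health")
#     def health():
#         return memory_stats()  # never touches the models
#
#     @app.post("/extract")
#     def extract(req: StageRequest):
#         return {"findings": medgemma_extract(req.system, req.user)}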


def medgemma_extract(system: str, user: str, max_new_tokens: int = 384) -> str:
    """First stage of the two-stage reasoner: read records, surface relevant findings."""
    _ensure_loaded()
    msgs = [
        {"role": "system", "content": [{"type": "text", "text": system}]},
        {"role": "user",   "content": [{"type": "text", "text": user}]},
    ]
    inputs = _state["medgemma_proc"].apply_chat_template(
        msgs, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(DEVICE)
    out = _state["medgemma"].generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
    )
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return _state["medgemma_proc"].decode(new_tokens, skip_special_tokens=True)


def qwen_synthesize(system: str, user: str, max_new_tokens: int = 512) -> str:
    """Second stage: synthesize MedGemma's findings into the final cited answer."""
    _ensure_loaded()
    msgs = [
        {"role": "system", "content": system},
        {"role": "user",   "content": user},
    ]
    text = _state["qwen_tok"].apply_chat_template(
        msgs, add_generation_prompt=True, tokenize=False,
    )
    inputs = _state["qwen_tok"](text, return_tensors="pt").to(DEVICE)
    out = _state["qwen"].generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
    )
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return _state["qwen_tok"].decode(new_tokens, skip_special_tokens=True)