# medivision-ai-agent · src/model_loader.py
"""
Inference backend: calls the vLLM server on AMD Developer Cloud via the
OpenAI-compatible API. No local model weights are loaded here.
"""
import base64
import mimetypes
import os
import time

import src.config as config
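
# Quick reference, derived from how `config` and the environment are used below:
#     config.VLLM_API_URL    - base URL of the vLLM server (OpenAI-compatible API)
#     config.MODEL_NAME      - model id requested in chat completions
#     config.MAX_NEW_TOKENS  - default completion length
#     config.TEMPERATURE     - default sampling temperature
# The optional VLLM_API_KEY environment variable is used for authentication.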

_client = None


def _get_client():
    """Lazily build a singleton OpenAI client pointed at the vLLM endpoint."""
    global _client
    if _client is None:
        from openai import OpenAI
        _client = OpenAI(
            base_url=f"{config.VLLM_API_URL}/v1",
            api_key=os.environ.get("VLLM_API_KEY", "not-required"),
        )
    return _client


def _encode_image(image_path: str) -> tuple[str, str]:
    """Return (base64-encoded bytes, mime type) for the image at image_path."""
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type:
        mime_type = "image/jpeg"  # fall back to JPEG when the type cannot be guessed
    with open(image_path, "rb") as f:
        data = base64.b64encode(f.read()).decode("utf-8")
    return data, mime_type
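
# Illustrative use of the return value downstream (the path is a hypothetical
# placeholder); the pair is turned into a data URL for the chat payload:
#     b64, mime = _encode_image("scan.png")
#     url = f"data:{mime};base64,{b64}"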


def check_connection() -> tuple[bool, str]:
    """
    Ping the vLLM server's /v1/models endpoint.
    Returns (is_connected: bool, status_message: str).
    """
    import requests as req

    url = f"{config.VLLM_API_URL}/v1/models"
    api_key = os.environ.get("VLLM_API_KEY", "not-required")
    print(f"[Connection] Checking AMD Cloud at {url} ...")
    try:
        r = req.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=5)
        if r.status_code == 200:
            models = [m.get("id", "?") for m in r.json().get("data", [])]
            print(f"[Connection] OK — models available: {models}")
            return True, f"Connected · {config.VLLM_API_URL}"
        print(f"[Connection] FAILED — HTTP {r.status_code}: {r.text[:200]}")
        return False, f"HTTP {r.status_code}"
    except req.exceptions.ConnectionError as exc:
        print(f"[Connection] FAILED — ConnectionError: {exc}")
        return False, f"ConnectionError: {exc}"
    except req.exceptions.Timeout:
        print("[Connection] FAILED — Timeout after 5s")
        return False, "Timeout (5s)"
    except Exception as exc:
        print(f"[Connection] FAILED — {type(exc).__name__}: {exc}")
        return False, f"{type(exc).__name__}: {exc}"


def generate_response(
    system_prompt: str,
    user_prompt: str,
    image_path: str | None = None,
    image_path_2: str | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """
    Send a chat completion to the vLLM endpoint with proper system/user separation.
        system_prompt → role: system
        user_prompt   → role: user (may include 0, 1, or 2 images)
    Returns (text_output, metrics).
    metrics keys: latency_ms, total_tokens, tokens_per_sec
    """
    try:
        client = _get_client()
        _max_tokens = max_tokens if max_tokens is not None else config.MAX_NEW_TOKENS
        _temperature = temperature if temperature is not None else config.TEMPERATURE

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if image_path or image_path_2:
            content = []
            if image_path:
                b64, mime = _encode_image(image_path)
                content.append({"type": "image_url",
                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
            if image_path_2:
                b64, mime = _encode_image(image_path_2)
                content.append({"type": "image_url",
                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
            content.append({"type": "text", "text": user_prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": user_prompt})

        kwargs = dict(
            model=config.MODEL_NAME,
            messages=messages,
            max_tokens=_max_tokens,
            temperature=_temperature,
        )
        if force_json:
            kwargs["response_format"] = {"type": "json_object"}

        t0 = time.perf_counter()
        response = client.chat.completions.create(**kwargs)
        latency_ms = (time.perf_counter() - t0) * 1000

        usage = getattr(response, "usage", None)
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        total_tokens = getattr(usage, "total_tokens", 0) or 0
        tokens_per_sec = (
            completion_tokens / (latency_ms / 1000)
            if latency_ms > 0 and completion_tokens > 0
            else 0
        )
        metrics = {
            "latency_ms": round(latency_ms),
            "total_tokens": total_tokens,
            "tokens_per_sec": round(tokens_per_sec, 1),
        }
        return response.choices[0].message.content, metrics
    except Exception as exc:
        raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc
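
# Illustrative call (the prompts and file path below are hypothetical placeholders):
#     text, metrics = generate_response(
#         system_prompt="You are a radiology assistant.",
#         user_prompt="Describe the key findings in this image.",
#         image_path="example_scan.jpg",
#     )
#     print(metrics["latency_ms"], metrics["tokens_per_sec"])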


def generate_text(
    system_prompt: str,
    user_prompt: str,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """Text-only call — no image encoding."""
    return generate_response(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        force_json=force_json,
    )
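

# Minimal smoke-test sketch (not part of the original module; assumes
# config.VLLM_API_URL points at a reachable vLLM server and that VLLM_API_KEY
# is set if the server requires authentication):
if __name__ == "__main__":
    ok, status = check_connection()
    print(status)
    if ok:
        text, metrics = generate_text(
            system_prompt="You are a helpful assistant.",
            user_prompt="Reply with the single word: ready",
            max_tokens=16,
        )
        print(text)
        print(metrics)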