""" Inference backend: calls the vLLM server on AMD Developer Cloud via the OpenAI-compatible API. No local model weights are loaded here. """ import base64 import mimetypes import os import time import src.config as config _client = None def _get_client(): global _client if _client is None: from openai import OpenAI _client = OpenAI( base_url=f"{config.VLLM_API_URL}/v1", api_key=os.environ.get("VLLM_API_KEY", "not-required"), ) return _client def _encode_image(image_path: str) -> tuple[str, str]: mime_type, _ = mimetypes.guess_type(image_path) if not mime_type: mime_type = "image/jpeg" with open(image_path, "rb") as f: data = base64.b64encode(f.read()).decode("utf-8") return data, mime_type def check_connection() -> tuple[bool, str]: """ Ping the vLLM server's /v1/models endpoint. Returns (is_connected: bool, status_message: str). """ import requests as req url = f"{config.VLLM_API_URL}/v1/models" api_key = os.environ.get("VLLM_API_KEY", "not-required") print(f"[Connection] Checking AMD Cloud at {url} ...") try: r = req.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=5) if r.status_code == 200: models = [m.get("id", "?") for m in r.json().get("data", [])] print(f"[Connection] OK — models available: {models}") return True, f"Connected · {config.VLLM_API_URL}" print(f"[Connection] FAILED — HTTP {r.status_code}: {r.text[:200]}") return False, f"HTTP {r.status_code}" except req.exceptions.ConnectionError as exc: print(f"[Connection] FAILED — ConnectionError: {exc}") return False, f"ConnectionError: {exc}" except req.exceptions.Timeout: print(f"[Connection] FAILED — Timeout after 5s") return False, "Timeout (5s)" except Exception as exc: print(f"[Connection] FAILED — {type(exc).__name__}: {exc}") return False, f"{type(exc).__name__}: {exc}" def generate_response( system_prompt: str, user_prompt: str, image_path: str = None, image_path_2: str = None, max_tokens: int = None, temperature: float = None, force_json: bool = False, ) -> tuple[str, dict]: """ Send a chat completion to the vLLM endpoint with proper system/user separation. system_prompt → role: system user_prompt → role: user (may include 0, 1, or 2 images) Returns (text_output, metrics). 
metrics keys: latency_ms, total_tokens, tokens_per_sec """ try: client = _get_client() _max_tokens = max_tokens if max_tokens is not None else config.MAX_NEW_TOKENS _temperature = temperature if temperature is not None else config.TEMPERATURE messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) if image_path or image_path_2: content = [] if image_path: b64, mime = _encode_image(image_path) content.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) if image_path_2: b64, mime = _encode_image(image_path_2) content.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) content.append({"type": "text", "text": user_prompt}) messages.append({"role": "user", "content": content}) else: messages.append({"role": "user", "content": user_prompt}) kwargs = dict( model=config.MODEL_NAME, messages=messages, max_tokens=_max_tokens, temperature=_temperature, ) if force_json: kwargs["response_format"] = {"type": "json_object"} t0 = time.perf_counter() response = client.chat.completions.create(**kwargs) latency_ms = (time.perf_counter() - t0) * 1000 usage = getattr(response, "usage", None) completion_tokens = getattr(usage, "completion_tokens", 0) or 0 total_tokens = getattr(usage, "total_tokens", 0) or 0 tokens_per_sec = (completion_tokens / (latency_ms / 1000)) if latency_ms > 0 and completion_tokens > 0 else 0 metrics = { "latency_ms": round(latency_ms), "total_tokens": total_tokens, "tokens_per_sec": round(tokens_per_sec, 1), } return response.choices[0].message.content, metrics except Exception as exc: raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc def generate_text( system_prompt: str, user_prompt: str, max_tokens: int = None, temperature: float = None, force_json: bool = False, ) -> tuple[str, dict]: """Text-only call — no image encoding.""" return generate_response( system_prompt=system_prompt, user_prompt=user_prompt, max_tokens=max_tokens, temperature=temperature, force_json=force_json, )
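

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module API): assumes src.config is
# populated with a reachable VLLM_API_URL and a served MODEL_NAME, and that
# VLLM_API_KEY is set if the server requires auth. The prompts below are
# illustrative placeholders, not values from this project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ok, status = check_connection()
    print(status)
    if ok:
        text, metrics = generate_text(
            system_prompt="You are a concise assistant.",
            user_prompt="Reply with one short sentence confirming you are up.",
            max_tokens=32,
        )
        print(text)
        print(metrics)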