"""
Inference backend: calls the vLLM server on AMD Developer Cloud via the
OpenAI-compatible API. No local model weights are loaded here.
"""
import base64
import mimetypes
import os
import time

import src.config as config

_client = None


def _get_client():
    """Build the OpenAI-compatible client once and reuse it across calls."""
    global _client
    if _client is None:
        from openai import OpenAI  # deferred so the module imports without the SDK
        _client = OpenAI(
            base_url=f"{config.VLLM_API_URL}/v1",
            api_key=os.environ.get("VLLM_API_KEY", "not-required"),
        )
    return _client


def _encode_image(image_path: str) -> tuple[str, str]:
    """Read an image file and return (base64_data, mime_type)."""
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type:
        mime_type = "image/jpeg"  # fall back when the extension is unrecognized
    with open(image_path, "rb") as f:
        data = base64.b64encode(f.read()).decode("utf-8")
    return data, mime_type
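
# Sketch of how the helper's output is consumed (hypothetical file name; the
# data-URL shape matches what generate_response builds below):
#   b64, mime = _encode_image("scan_01.png")
#   image_url = f"data:{mime};base64,{b64}"   # -> "data:image/png;base64,iVBOR..."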


def check_connection() -> tuple[bool, str]:
    """
    Ping the vLLM server's /v1/models endpoint.
    Returns (is_connected: bool, status_message: str).
    """
    import requests as req

    url = f"{config.VLLM_API_URL}/v1/models"
    api_key = os.environ.get("VLLM_API_KEY", "not-required")
    print(f"[Connection] Checking AMD Cloud at {url} ...")
    try:
        r = req.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=5)
        if r.status_code == 200:
            models = [m.get("id", "?") for m in r.json().get("data", [])]
            print(f"[Connection] OK — models available: {models}")
            return True, f"Connected · {config.VLLM_API_URL}"
        print(f"[Connection] FAILED — HTTP {r.status_code}: {r.text[:200]}")
        return False, f"HTTP {r.status_code}"
    except req.exceptions.ConnectionError as exc:
        print(f"[Connection] FAILED — ConnectionError: {exc}")
        return False, f"ConnectionError: {exc}"
    except req.exceptions.Timeout:
        print("[Connection] FAILED — Timeout after 5s")
        return False, "Timeout (5s)"
    except Exception as exc:
        print(f"[Connection] FAILED — {type(exc).__name__}: {exc}")
        return False, f"{type(exc).__name__}: {exc}"


def generate_response(
    system_prompt: str,
    user_prompt: str,
    image_path: str | None = None,
    image_path_2: str | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
"""
Send a chat completion to the vLLM endpoint with proper system/user separation.
system_prompt → role: system
user_prompt → role: user (may include 0, 1, or 2 images)
Returns (text_output, metrics).
metrics keys: latency_ms, total_tokens, tokens_per_sec
"""
    try:
        client = _get_client()
        _max_tokens = max_tokens if max_tokens is not None else config.MAX_NEW_TOKENS
        _temperature = temperature if temperature is not None else config.TEMPERATURE
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if image_path or image_path_2:
            # Images are sent inline as base64 data URLs using the
            # OpenAI-style multimodal content-part format.
            content = []
            if image_path:
                b64, mime = _encode_image(image_path)
                content.append({"type": "image_url",
                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
            if image_path_2:
                b64, mime = _encode_image(image_path_2)
                content.append({"type": "image_url",
                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
            content.append({"type": "text", "text": user_prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": user_prompt})
        kwargs = dict(
            model=config.MODEL_NAME,
            messages=messages,
            max_tokens=_max_tokens,
            temperature=_temperature,
        )
        if force_json:
            kwargs["response_format"] = {"type": "json_object"}
        t0 = time.perf_counter()
        response = client.chat.completions.create(**kwargs)
        latency_ms = (time.perf_counter() - t0) * 1000
        # Throughput is derived from completion tokens only; total_tokens
        # (prompt + completion) is what gets reported back to the caller.
        usage = getattr(response, "usage", None)
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        total_tokens = getattr(usage, "total_tokens", 0) or 0
        tokens_per_sec = (
            completion_tokens / (latency_ms / 1000)
            if latency_ms > 0 and completion_tokens > 0
            else 0
        )
        metrics = {
            "latency_ms": round(latency_ms),
            "total_tokens": total_tokens,
            "tokens_per_sec": round(tokens_per_sec, 1),
        }
        return response.choices[0].message.content, metrics
    except Exception as exc:
        raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc
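
# Illustrative call (prompts and path are hypothetical, not part of the
# application; force_json requires a server/model that honors
# response_format={"type": "json_object"}):
#   text, metrics = generate_response(
#       system_prompt="You are a radiology assistant. Answer in JSON.",
#       user_prompt="Describe the findings in this image.",
#       image_path="scan_01.png",
#       force_json=True,
#   )
#   print(metrics)  # {"latency_ms": ..., "total_tokens": ..., "tokens_per_sec": ...}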


def generate_text(
    system_prompt: str,
    user_prompt: str,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """Text-only call — no image encoding."""
    return generate_response(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        force_json=force_json,
    )
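

if __name__ == "__main__":
    # Minimal smoke test, assuming the vLLM server configured in src.config is
    # reachable. The prompt below is illustrative only.
    ok, status = check_connection()
    print(status)
    if ok:
        text, metrics = generate_text(
            system_prompt="You are a concise assistant.",
            user_prompt="Reply with the single word: pong",
            max_tokens=16,
        )
        print(text)
        print(metrics)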