"""
Inference backend: calls the vLLM server on AMD Developer Cloud via the
OpenAI-compatible API.  No local model weights are loaded here.
"""
import base64
import mimetypes
import os
import time

import src.config as config
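
# NOTE: src.config is expected to expose the settings referenced below.
# The values shown here are illustrative assumptions, not taken from the repo:
#   VLLM_API_URL   = "http://<amd-cloud-host>:8000"  # base URL of the vLLM server
#   MODEL_NAME     = "<served-model-id>"             # model id served by vLLM
#   MAX_NEW_TOKENS = 512                             # default completion budget
#   TEMPERATURE    = 0.2                             # default sampling temperature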

_client = None


def _get_client():
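    """Lazily construct the OpenAI-compatible client pointed at the vLLM server."""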
    global _client
    if _client is None:
        from openai import OpenAI
        _client = OpenAI(
            base_url=f"{config.VLLM_API_URL}/v1",
            api_key=os.environ.get("VLLM_API_KEY", "not-required"),
        )
    return _client


def _encode_image(image_path: str) -> tuple[str, str]:
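    """Read an image file and return (base64_data, mime_type) for a data: URL."""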
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type:
        mime_type = "image/jpeg"
    with open(image_path, "rb") as f:
        data = base64.b64encode(f.read()).decode("utf-8")
    return data, mime_type


def check_connection() -> tuple[bool, str]:
    """
    Ping the vLLM server's /v1/models endpoint.
    Returns (is_connected: bool, status_message: str).
    """
    import requests as req

    url = f"{config.VLLM_API_URL}/v1/models"
    api_key = os.environ.get("VLLM_API_KEY", "not-required")
    print(f"[Connection] Checking AMD Cloud at {url} ...")

    try:
        r = req.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=5)
        if r.status_code == 200:
            models = [m.get("id", "?") for m in r.json().get("data", [])]
            print(f"[Connection] OK — models available: {models}")
            return True, f"Connected · {config.VLLM_API_URL}"
        print(f"[Connection] FAILED — HTTP {r.status_code}: {r.text[:200]}")
        return False, f"HTTP {r.status_code}"
    except req.exceptions.ConnectionError as exc:
        print(f"[Connection] FAILED — ConnectionError: {exc}")
        return False, f"ConnectionError: {exc}"
    except req.exceptions.Timeout:
        print(f"[Connection] FAILED — Timeout after 5s")
        return False, "Timeout (5s)"
    except Exception as exc:
        print(f"[Connection] FAILED — {type(exc).__name__}: {exc}")
        return False, f"{type(exc).__name__}: {exc}"


def generate_response(
    system_prompt: str,
    user_prompt: str,
    image_path: str | None = None,
    image_path_2: str | None = None,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """
    Send a chat completion to the vLLM endpoint with proper system/user separation.

    system_prompt → role: system
    user_prompt   → role: user (may include 0, 1, or 2 images)

    Returns (text_output, metrics).
    metrics keys: latency_ms, total_tokens, tokens_per_sec
    """
    try:
        client = _get_client()
        _max_tokens = max_tokens if max_tokens is not None else config.MAX_NEW_TOKENS
        _temperature = temperature if temperature is not None else config.TEMPERATURE

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        if image_path or image_path_2:
            content = []
            if image_path:
                b64, mime = _encode_image(image_path)
                content.append({"type": "image_url",
                                 "image_url": {"url": f"data:{mime};base64,{b64}"}})
            if image_path_2:
                b64, mime = _encode_image(image_path_2)
                content.append({"type": "image_url",
                                 "image_url": {"url": f"data:{mime};base64,{b64}"}})
            content.append({"type": "text", "text": user_prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": user_prompt})

        kwargs = dict(
            model=config.MODEL_NAME,
            messages=messages,
            max_tokens=_max_tokens,
            temperature=_temperature,
        )
        if force_json:
            kwargs["response_format"] = {"type": "json_object"}

        t0 = time.perf_counter()
        response = client.chat.completions.create(**kwargs)
        latency_ms = (time.perf_counter() - t0) * 1000

        usage = getattr(response, "usage", None)
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        total_tokens = getattr(usage, "total_tokens", 0) or 0
        tokens_per_sec = (
            completion_tokens / (latency_ms / 1000)
            if latency_ms > 0 and completion_tokens > 0
            else 0
        )

        metrics = {
            "latency_ms": round(latency_ms),
            "total_tokens": total_tokens,
            "tokens_per_sec": round(tokens_per_sec, 1),
        }
        return response.choices[0].message.content, metrics

    except Exception as exc:
        raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc


def generate_text(
    system_prompt: str,
    user_prompt: str,
    max_tokens: int | None = None,
    temperature: float | None = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """Text-only call — no image encoding."""
    return generate_response(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        force_json=force_json,
    )
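

# Minimal smoke test (illustrative sketch, not part of the original API surface):
# ping the configured vLLM endpoint, then issue a small text-only completion.
# Assumes src.config points at a reachable server serving MODEL_NAME.
if __name__ == "__main__":
    ok, status = check_connection()
    print(status)
    if ok:
        text, metrics = generate_text(
            system_prompt="You are a concise assistant.",
            user_prompt="Reply with the single word: ready",
            max_tokens=16,
        )
        print(text)
        print(metrics)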