# mediagent/core/llm.py
"""
Production-grade LLM client wrapper for MediAgent.
Handles text and multimodal (vision) completions against local Qwen model.
Implements retry logic, error handling, response parsing, and OpenAI-compatible API calls.
"""

import json
import logging
import re
import time
from typing import Any, Callable, Dict, List, Optional, Union

import openai

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Lightweight, framework-agnostic LLM client wrapping OpenAI Python SDK.
    Designed for local inference endpoints (vLLM, Ollama, TensorRT-LLM) 
    running at http://localhost:8000/v1 with model path "/model".
    """

    DEFAULT_BASE_URL = "http://localhost:8000/v1"
    DEFAULT_MODEL = "/model"
    DEFAULT_API_KEY = "none"

    def __init__(
        self,
        base_url: str = DEFAULT_BASE_URL,
        model: str = DEFAULT_MODEL,
        max_retries: int = 3,
        timeout: float = 90.0,
        temperature: float = 0.0
    ):
        self.model = model
        self.max_retries = max_retries
        self.default_temperature = temperature
        self.timeout = timeout

        self.client = openai.OpenAI(
            base_url=base_url,
            api_key=self.DEFAULT_API_KEY,
            timeout=timeout
        )
        logger.info(f"LLMClient initialized | Model: {self.model} | Endpoint: {base_url}")

    # ─────────────────────────────────────────────────────────────────────────
    # CORE GENERATION METHODS
    # ─────────────────────────────────────────────────────────────────────────

    def generate_text(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        force_json: bool = False,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """
        Send a text-only completion request to the LLM.
        Returns standardized response dict with content, usage, success flag, and error.
        """
        messages = self._build_messages(system_prompt, prompt)
        response_format = {"type": "json_object"} if force_json else None
        return self._execute_with_retry(
            messages=messages,
            temperature=temperature,
            response_format=response_format,
            call_type="TEXT",
            max_tokens=max_tokens,
            extra_body=extra_body,
        )

    def generate_text_streaming(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        on_token: Optional[Callable[[str], None]] = None,
    ) -> Dict[str, Any]:
        """
        Text completion with optional token-level streaming callback.
        When on_token is provided, calls on_token(chunk: str) for every token chunk
        as it arrives from the model. Returns the full response dict at the end.
        Falls back to standard generate_text if streaming fails.
        """
        if on_token is None:
            return self.generate_text(prompt, system_prompt, temperature)

        messages = self._build_messages(system_prompt, prompt)
        temp = temperature if temperature is not None else self.default_temperature

        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temp,
                stream=True,
            )
            full_content = ""
            for chunk in stream:
                delta = (chunk.choices[0].delta.content or "") if chunk.choices else ""
                if delta:
                    full_content += delta
                    try:
                        on_token(delta)
                    except Exception:
                        pass  # callback errors must not break generation

            logger.debug("Streaming TEXT generation completed | chars=%d", len(full_content))
            return {
                "success": True,
                "content": full_content,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                "model": self.model,
                "error": None,
            }
        except Exception as e:
            logger.warning("Streaming failed (%s), falling back to standard call", e)
            return self.generate_text(prompt, system_prompt, temperature)

    def generate_vision(
        self,
        base64_image: str,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Send a multimodal completion request with a base64 encoded medical image.
        Automatically detects image MIME type and formats per OpenAI vision spec.
        """
        img_url = self._format_image_url(base64_image)
        user_content = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": img_url}}
        ]
        messages = self._build_messages(system_prompt, user_content)
        return self._execute_with_retry(
            messages=messages,
            temperature=temperature,
            response_format=None,
            call_type="VISION",
            max_tokens=max_tokens
        )

    # ─────────────────────────────────────────────────────────────────────────
    # INTERNAL HELPERS
    # ─────────────────────────────────────────────────────────────────────────

    def _build_messages(
        self,
        system_prompt: str,
        user_content: Union[str, List[Dict]]
    ) -> List[Dict]:
        """Construct OpenAI-compatible message array."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Both plain strings and structured content lists are valid user
        # message payloads in the OpenAI chat format, so no branching is needed.
        messages.append({"role": "user", "content": user_content})
        return messages

    def _format_image_url(self, base64_data: str) -> str:
        """Normalize base64 image data into OpenAI vision-compatible URL format."""
        if base64_data.startswith(("data:image/png;base64,", "data:image/jpeg;base64,", "data:image/jpg;base64,")):
            return base64_data
        # Default to JPEG if no MIME prefix is present
        return f"data:image/jpeg;base64,{base64_data}"

    def _attempt_call(
        self,
        messages: List[Dict],
        temperature: Optional[float],
        response_format: Optional[Dict],
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Execute a single API call with the OpenAI client."""
        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature if temperature is not None else self.default_temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens
        if response_format:
            kwargs["response_format"] = response_format
        if extra_body:
            kwargs["extra_body"] = extra_body

        response = self.client.chat.completions.create(**kwargs)
        content = response.choices[0].message.content or ""
        
        usage = response.usage
        return {
            "success": True,
            "content": content,
            "raw_response": response,
            "usage": {
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
            "model": response.model,
            "error": None
        }

    def _execute_with_retry(
        self,
        messages: List[Dict],
        temperature: Optional[float],
        response_format: Optional[Dict],
        call_type: str,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Retry wrapper with exponential backoff for robust local inference."""
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                result = self._attempt_call(messages, temperature, response_format, max_tokens, extra_body)
                if result["success"]:
                    logger.debug(f"{call_type} generation successful on attempt {attempt}")
                    return result
            except Exception as e:
                last_error = str(e)
                logger.warning(f"{call_type} generation failed on attempt {attempt}/{self.max_retries}: {e}")
                if attempt < self.max_retries:
                    # Short fixed backoff; local inference gains little from exponential waits
                    backoff = 1.0
                    logger.info(f"Retrying in {backoff}s...")
                    time.sleep(backoff)

        logger.error(f"{call_type} generation failed permanently after {self.max_retries} attempts.")
        return {
            "success": False,
            "content": "",
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            "model": self.model,
            "error": last_error or f"{call_type} endpoint unreachable or max retries exceeded."
        }

    # ─────────────────────────────────────────────────────────────────────────
    # RESPONSE PARSING UTILITIES
    # ─────────────────────────────────────────────────────────────────────────

    @staticmethod
    def extract_json_from_response(content: str) -> Optional[Dict[str, Any]]:
        """
        Safely extract JSON from LLM output, stripping markdown formatting
        and handling partial/comma-separated JSON arrays if necessary.
        """
        if not content:
            return None
        try:
            # Strip markdown code fences if present
            cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", content.strip(), flags=re.MULTILINE)
            # First try direct JSON decode
            return LLMClient._safe_json_decode(cleaned)
        except Exception:
            logger.debug("Direct JSON extraction failed. Attempting fallback parsing...")
            return LLMClient._fallback_json_parse(cleaned)

    @staticmethod
    def _safe_json_decode(text: str) -> Any:
        """Decode strict JSON, raising json.JSONDecodeError on malformed input."""
        return json.loads(text)

    @staticmethod
    def _fallback_json_parse(text: str) -> Optional[Dict[str, Any]]:
        """
        Fallback: scan for first valid JSON object or array in the text.
        Handles cases where the LLM adds conversational padding.
        """
        brace_depth = 0
        start_idx = None
        for i, char in enumerate(text):
            if char == "{":
                if brace_depth == 0:
                    start_idx = i
                brace_depth += 1
            elif char == "}":
                if brace_depth == 0:
                    continue  # stray closer before any opener; ignore it
                brace_depth -= 1
                if brace_depth == 0 and start_idx is not None:
                    candidate = text[start_idx:i + 1]
                    try:
                        return json.loads(candidate)
                    except json.JSONDecodeError:
                        continue
        # Try array fallback
        bracket_depth = 0
        start_idx = None
        for i, char in enumerate(text):
            if char == "[":
                if bracket_depth == 0:
                    start_idx = i
                bracket_depth += 1
            elif char == "]":
                if bracket_depth == 0:
                    continue  # stray closer before any opener; ignore it
                bracket_depth -= 1
                if bracket_depth == 0 and start_idx is not None:
                    candidate = text[start_idx:i + 1]
                    try:
                        return json.loads(candidate)
                    except json.JSONDecodeError:
                        continue
        return None
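

# ─────────────────────────────────────────────────────────────────────────────
# Minimal usage sketch. This assumes an OpenAI-compatible server is already
# serving the model at the default endpoint (http://localhost:8000/v1); the
# prompts below are illustrative only.
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = LLMClient()

    # Plain text completion.
    result = client.generate_text(
        prompt="List three common causes of chest pain.",
        system_prompt="You are a concise clinical assistant.",
    )
    print(result["content"] if result["success"] else f"Error: {result['error']}")

    # Streaming completion: on_token receives each chunk as it arrives.
    client.generate_text_streaming(
        prompt="Summarize the symptoms of pneumonia in one sentence.",
        on_token=lambda chunk: print(chunk, end="", flush=True),
    )
    print()

    # JSON-constrained completion plus tolerant parsing of the output.
    result = client.generate_text(
        prompt='Return {"triage_level": "..."} for a patient with a mild headache.',
        force_json=True,
    )
    print(LLMClient.extract_json_from_response(result["content"]))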