| """ |
| Production-grade LLM client wrapper for MediAgent. |
| Handles text and multimodal (vision) completions against local Qwen model. |
| Implements retry logic, error handling, response parsing, and OpenAI-compatible API calls. |
| """ |

import logging
import re
import time
from typing import Any, Callable, Dict, List, Optional, Union

import openai

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Lightweight, framework-agnostic LLM client wrapping the OpenAI Python SDK.

    Designed for local inference endpoints (vLLM, Ollama, TensorRT-LLM)
    running at http://localhost:8000/v1 with model path "/model".
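
    Example (a sketch; assumes such a server is already running)::

        client = LLMClient()
        result = client.generate_text("Hello")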
| """ |
|
|
| DEFAULT_BASE_URL = "http://localhost:8000/v1" |
| DEFAULT_MODEL = "/model" |
| DEFAULT_API_KEY = "none" |
|
|

    def __init__(
        self,
        base_url: str = DEFAULT_BASE_URL,
        model: str = DEFAULT_MODEL,
        max_retries: int = 3,
        timeout: float = 90.0,
        temperature: float = 0.0,
    ):
        self.model = model
        self.max_retries = max_retries
        self.default_temperature = temperature
        self.timeout = timeout

        self.client = openai.OpenAI(
            base_url=base_url,
            api_key=self.DEFAULT_API_KEY,
            timeout=timeout,
        )
        logger.info(f"LLMClient initialized | Model: {self.model} | Endpoint: {base_url}")

    # ------------------------------------------------------------------
    # Public generation API
    # ------------------------------------------------------------------

    def generate_text(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        force_json: bool = False,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
| """ |
| Send a text-only completion request to the LLM. |
| Returns standardized response dict with content, usage, success flag, and error. |
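
        Example (hypothetical prompt; assumes a reachable local endpoint)::

            client = LLMClient()
            result = client.generate_text("Summarize the patient note.", force_json=False)
            if result["success"]:
                print(result["content"])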
| """ |
| messages = self._build_messages(system_prompt, prompt) |
| response_format = {"type": "json_object"} if force_json else None |
| return self._execute_with_retry( |
| messages=messages, |
| temperature=temperature, |
| response_format=response_format, |
| call_type="TEXT", |
| max_tokens=max_tokens, |
| extra_body=extra_body, |
| ) |
|
|
    def generate_text_streaming(
        self,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        on_token: Optional[Callable[[str], None]] = None,
    ) -> Dict[str, Any]:
| """ |
| Text completion with optional token-level streaming callback. |
| When on_token is provided, calls on_token(chunk: str) for every token chunk |
| as it arrives from the model. Returns the full response dict at the end. |
| Falls back to standard generate_text if streaming fails. |
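
        Example (a sketch; assumes a streaming-capable local server)::

            client.generate_text_streaming(
                "Explain this lab result in plain language.",
                on_token=lambda chunk: print(chunk, end="", flush=True),
            )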
| """ |
| if on_token is None: |
| return self.generate_text(prompt, system_prompt, temperature) |
|
|
| messages = self._build_messages(system_prompt, prompt) |
| temp = temperature if temperature is not None else self.default_temperature |
|
|
| try: |
| stream = self.client.chat.completions.create( |
| model=self.model, |
| messages=messages, |
| temperature=temp, |
| stream=True, |
| ) |
| full_content = "" |
| for chunk in stream: |
| delta = (chunk.choices[0].delta.content or "") if chunk.choices else "" |
| if delta: |
| full_content += delta |
| try: |
| on_token(delta) |
| except Exception: |
| pass |
|
|
| logger.debug("Streaming TEXT generation completed | chars=%d", len(full_content)) |
| return { |
| "success": True, |
| "content": full_content, |
| "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, |
| "model": self.model, |
| "error": None, |
| } |
| except Exception as e: |
| logger.warning("Streaming failed (%s), falling back to standard call", e) |
| return self.generate_text(prompt, system_prompt, temperature) |
|
|
    def generate_vision(
        self,
        base64_image: str,
        prompt: str,
        system_prompt: str = "",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
| """ |
| Send a multimodal completion request with a base64 encoded medical image. |
| Automatically detects image MIME type and formats per OpenAI vision spec. |
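
        Example (a sketch; ``b64_png`` is a hypothetical base64-encoded image)::

            result = client.generate_vision(b64_png, "Describe any abnormalities.")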
| """ |
| img_url = self._format_image_url(base64_image) |
| user_content = [ |
| {"type": "text", "text": prompt}, |
| {"type": "image_url", "image_url": {"url": img_url}} |
| ] |
| messages = self._build_messages(system_prompt, user_content) |
| return self._execute_with_retry( |
| messages=messages, |
| temperature=temperature, |
| response_format=None, |
| call_type="VISION", |
| max_tokens=max_tokens |
| ) |
|

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_messages(
        self,
        system_prompt: str,
        user_content: Union[str, List[Dict]],
    ) -> List[Dict]:
        """Construct an OpenAI-compatible message array."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Both a plain string and a structured multimodal list are valid
        # "content" values, so no branching on the type is needed.
        messages.append({"role": "user", "content": user_content})
        return messages

    def _format_image_url(self, base64_data: str) -> str:
        """Normalize base64 image data into an OpenAI vision-compatible data URL."""
        if base64_data.startswith("data:image/"):
            return base64_data
        # No data-URL prefix present: default to JPEG, the most common case.
        return f"data:image/jpeg;base64,{base64_data}"

    def _attempt_call(
        self,
        messages: List[Dict],
        temperature: Optional[float],
        response_format: Optional[Dict],
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Execute a single API call with the OpenAI client."""
        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature if temperature is not None else self.default_temperature,
        }
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        if response_format:
            kwargs["response_format"] = response_format
        if extra_body:
            kwargs["extra_body"] = extra_body

        response = self.client.chat.completions.create(**kwargs)
        content = response.choices[0].message.content or ""

        usage = response.usage
        return {
            "success": True,
            "content": content,
            "raw_response": response,
            "usage": {
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
            "model": response.model,
            "error": None,
        }

    def _execute_with_retry(
        self,
        messages: List[Dict],
        temperature: Optional[float],
        response_format: Optional[Dict],
        call_type: str,
        max_tokens: Optional[int] = None,
        extra_body: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Retry wrapper with exponential backoff for robust local inference."""
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                # _attempt_call either returns a success payload or raises.
                result = self._attempt_call(messages, temperature, response_format, max_tokens, extra_body)
                logger.debug(f"{call_type} generation successful on attempt {attempt}")
                return result
            except Exception as e:
                last_error = str(e)
                logger.warning(f"{call_type} generation failed on attempt {attempt}/{self.max_retries}: {e}")
                if attempt < self.max_retries:
                    # Exponential backoff: 1s, 2s, 4s, ...
                    backoff = 2.0 ** (attempt - 1)
                    logger.info(f"Retrying in {backoff}s...")
                    time.sleep(backoff)

        logger.error(f"{call_type} generation failed permanently after {self.max_retries} attempts.")
        return {
            "success": False,
            "content": "",
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            "model": self.model,
            "error": last_error or f"{call_type} endpoint unreachable or max retries exceeded.",
        }

    # ------------------------------------------------------------------
    # JSON extraction utilities
    # ------------------------------------------------------------------

    @staticmethod
    def extract_json_from_response(content: str) -> Optional[Dict[str, Any]]:
        """
        Safely extract JSON from LLM output, stripping markdown code fences and
        falling back to scanning for an embedded JSON object or array.
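
        Example: a response like '```json {"dx": "flu"} ```' yields {"dx": "flu"}.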
| """ |
| if not content: |
| return None |
| try: |
| |
| cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", content.strip(), flags=re.MULTILINE) |
| |
| return LLMClient._safe_json_decode(cleaned) |
| except Exception: |
| logger.debug("Direct JSON extraction failed. Attempting fallback parsing...") |
| return LLMClient._fallback_json_parse(cleaned) |
|
|
    @staticmethod
    def _safe_json_decode(text: str) -> Any:
        """Import json lazily and decode, raising cleanly on failure."""
        import json

        return json.loads(text)

    @staticmethod
    def _fallback_json_parse(text: str) -> Optional[Dict[str, Any]]:
        """
        Fallback: scan for the first valid JSON object or array in the text.
        Handles cases where the LLM adds conversational padding.
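
        Example: 'Sure! Here you go: {"a": 1}' -> {"a": 1}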
| """ |
| import json |
| brace_depth = 0 |
| start_idx = None |
| for i, char in enumerate(text): |
| if char == "{": |
| if brace_depth == 0: |
| start_idx = i |
| brace_depth += 1 |
| elif char == "}": |
| brace_depth -= 1 |
| if brace_depth == 0 and start_idx is not None: |
| candidate = text[start_idx:i+1] |
| try: |
| return json.loads(candidate) |
| except json.JSONDecodeError: |
| continue |
| |
| bracket_depth = 0 |
| start_idx = None |
| for i, char in enumerate(text): |
| if char == "[": |
| if bracket_depth == 0: |
| start_idx = i |
| bracket_depth += 1 |
| elif char == "]": |
| bracket_depth -= 1 |
| if bracket_depth == 0 and start_idx is not None: |
| candidate = text[start_idx:i+1] |
| try: |
| return json.loads(candidate) |
| except json.JSONDecodeError: |
| continue |
| return None |
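

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the module's public surface).
    # Assumes an OpenAI-compatible server is already running at the default
    # endpoint; adjust base_url/model for your deployment.
    client = LLMClient()
    reply = client.generate_text(
        'Return a JSON object with a single key "status" set to "ok".',
        force_json=True,
    )
    print("success:", reply["success"])
    print("parsed:", LLMClient.extract_json_from_response(reply["content"]))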