import glob
import json
import logging
import os
import re
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import backoff
import httpx
import requests
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
| |
# Model under evaluation and its sampling temperature (kept low so the
# single-letter answers are close to deterministic).
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
temperature = 0.2

# One timestamped log file per run so runs never overwrite each other.
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
# format="%(message)s" keeps each record a bare JSON object (the code logs
# json.dumps(...) strings), producing a JSON-lines usage log.
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
|
|
|
|
def verify_dns() -> bool:
    """Verify DNS resolution for the OpenRouter API host.

    Tries to resolve ``openrouter.ai``; on failure, attempts to point the
    system resolver at Google DNS (8.8.8.8) and then re-checks resolution.

    Returns:
        bool: True if DNS resolution succeeds (possibly after the fallback),
        False otherwise.
    """
    try:
        socket.gethostbyname("openrouter.ai")
        return True
    except socket.gaierror:
        print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")

        # NOTE(review): requires root and only works on resolv.conf-based
        # systems; it also clobbers any existing resolver configuration.
        try:
            with open("/etc/resolv.conf", "w") as f:
                f.write("nameserver 8.8.8.8\n")
        except OSError as e:
            print(f"Failed to update DNS settings: {e}")
            return False

        # Bug fix: the original returned True right after writing the file
        # without confirming the fallback actually restored resolution.
        try:
            socket.gethostbyname("openrouter.ai")
            return True
        except socket.gaierror:
            print("DNS resolution still failing after fallback.")
            return False
|
|
|
|
def verify_connection() -> bool:
    """Check that the OpenRouter API status endpoint is reachable.

    Returns:
        bool: True if the status endpoint answers with HTTP 200, False otherwise.
    """
    status_url = "https://openrouter.ai/api/v1/status"
    try:
        status_response = requests.get(status_url, timeout=10)
    except Exception as e:
        # Any failure (DNS, TLS, timeout, ...) is reported as unreachable.
        print(f"Connection test failed: {e}")
        return False
    return status_response.status_code == 200
|
|
|
|
def initialize_client() -> OpenAI:
    """Initialize the OpenRouter client with timeouts and connectivity checks.

    Returns:
        OpenAI: Configured OpenAI client pointed at the OpenRouter base URL.

    Raises:
        ValueError: If OPENROUTER_API_KEY environment variable is not set.
        ConnectionError: If DNS verification or the connection test fails.
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable is not set.")

    # Generous timeout (seconds): vision requests can be slow to complete.
    request_timeout = 120

    # Fail fast with a clear message before the first real API call.
    if not verify_dns():
        raise ConnectionError("DNS verification failed. Please check your network settings.")
    if not verify_connection():
        raise ConnectionError(
            "Cannot connect to OpenRouter. Please check your internet connection."
        )

    # Transport-level retries handle transient connection drops.
    transport = httpx.HTTPTransport(retries=3)
    http_client = httpx.Client(timeout=request_timeout, transport=transport)

    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=request_timeout,
        http_client=http_client,
    )
|
|
|
|
@backoff.on_exception(
    backoff.expo,
    (ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
    max_tries=5,
    max_time=300,
)
def create_multimodal_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    client: OpenAI,
) -> Optional[Any]:
    """Create and send a multimodal request to the model.

    Builds a chat request pairing the question text with every figure image
    referenced by the question, sends it, validates the reply down to a
    single letter, and writes one JSON log record for every outcome
    (answered / skipped / error). Connection-level failures are retried
    with exponential backoff (up to 5 tries / 300 s) via the decorator.

    Args:
        question_data: Dictionary containing question details (the
            "question", "figures", "explanation" and "answer" keys are read).
        case_details: Dictionary containing case information; image URLs are
            looked up in its "figures" list.
        case_id: ID of the medical case
        question_id: ID of the specific question
        client: OpenAI client instance

    Returns:
        Optional[Any]: Model response if successful, None if skipped
        (no matching images found for the referenced figures).

    Raises:
        ConnectionError: If connection fails (retried by the backoff decorator)
        TimeoutError: If request times out (retried by the backoff decorator)
        Exception: For other errors (logged with context, then re-raised)
    """
    # Strict output contract so validate_answer() can grade mechanically.
    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
    Rules:
    1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
    2. Do not add periods, explanations, or any other text
    3. Do not use markdown or formatting
    4. Do not restate the question
    5. Do not explain your reasoning

    Examples of valid responses:
    A
    B
    C

    Examples of invalid responses:
    "A."
    "Answer: B"
    "C) This shows..."
    "The answer is D"
    """

    prompt = f"""Given the following medical case:
    Please answer this multiple choice question:
    {question_data['question']}
    Base your answer only on the provided images and case information."""

    # "figures" may arrive as a JSON-encoded list, a bare string, a real
    # list, or something else entirely — normalize everything to a list.
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    # Normalize each entry to the "Figure <n>" spelling used by case_details.
    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # content[0] is the text prompt; matching images are appended after it.
    content = [{"type": "text", "text": prompt}]
    image_urls = []
    image_captions = []  # NOTE(review): collected but never logged or sent

    for figure in required_figures:
        # Split e.g. "Figure 1a" into number "1" and optional letter "a".
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            subfigures = []
            if figure_letter:
                # Select subfigures whose number ends with the letter
                # ("...1a") or whose label equals it ("a").
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                # No letter given: include every subfigure of that figure.
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "url" in subfig:
                    content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
                    image_urls.append(subfig["url"])
                    image_captions.append(subfig.get("caption", ""))

    # Only the text element present -> no usable images; log and skip.
    if len(content) == 1:
        print(f"No images found for case {case_id}, question {question_id}")

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "skipped",
            "reason": "no_images",
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": image_urls,
            },
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()

        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=temperature,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
        )
        duration = time.time() - start_time

        raw_answer = response.choices[0].message.content

        # Reduce the reply to a single letter (None when unparseable).
        clean_answer = validate_answer(raw_answer)

        if not clean_answer:
            print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
            print(f"Raw response: {raw_answer}")

        # Overwrite in place so callers always see the cleaned answer.
        response.choices[0].message.content = clean_answer

        # Success record: timing, token usage, and the graded answer.
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            "model_answer": response.choices[0].message.content,
            "correct_answer": question_data["answer"],
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": image_urls,
            },
        }
        logging.info(json.dumps(log_entry))
        return response

    except ConnectionError as e:
        print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying after a longer delay...")
        # Extra pause on top of the decorator's exponential backoff.
        time.sleep(30)
        raise
    except TimeoutError as e:
        print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
        print("Retrying with increased timeout...")
        raise
    except Exception as e:
        # Unexpected failure: log the error with full context, then re-raise.
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": image_urls,
            },
        }
        logging.info(json.dumps(log_entry))
        raise
|
|
|
|
def extract_answer(response_text: str) -> Optional[str]:
    """Extract a single-letter answer (A-F) from a model response.

    Args:
        response_text: Raw text response from model

    Returns:
        Optional[str]: Single letter answer if found, None otherwise
    """
    # Uppercase and drop periods so "a." and "Answer: b." normalize cleanly.
    normalized = response_text.upper().replace(".", "")

    # Most explicit formats first; bare standalone letter as a last resort.
    answer_patterns = (
        r"ANSWER:\s*([A-F])",
        r"OPTION\s*([A-F])",
        r"([A-F])\)",
        r"\b([A-F])\b",
    )

    for pattern in answer_patterns:
        match = re.search(pattern, normalized)
        if match:
            return match.group(1)

    return None
|
|
|
|
def validate_answer(response_text: str) -> Optional[str]:
    """Enforce strict single-letter response format.

    Args:
        response_text: Raw text response from model

    Returns:
        Optional[str]: Valid single letter answer if found, None otherwise
    """
    if not response_text:
        return None

    cleaned = response_text.strip().upper()

    # Exact single-letter reply — the format the system prompt demands.
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned

    # Fallback: first *standalone* letter A-F. The original matched any A-F
    # character, so a reply like "Answer: B" wrongly validated to "A" (the
    # first letter of the word ANSWER); word-boundary anchors restrict the
    # match to an isolated letter such as "B" or the "A" in "A.".
    match = re.search(r"\b([A-F])\b", cleaned)
    return match.group(1) if match else None
|
|
|
|
def load_benchmark_questions(case_id: str) -> List[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: ID of the medical case

    Returns:
        List[str]: Paths of every ``<case_id>_*.json`` file under the
        case's question directory (empty when the case has none).
    """
    questions_root = "../benchmark/questions"
    pattern = f"{questions_root}/{case_id}/{case_id}_*.json"
    return glob.glob(pattern)
|
|
|
|
def count_total_questions() -> Tuple[int, int]:
    """Count total number of cases and questions.

    Returns:
        Tuple[int, int]: ``(total_cases, total_questions)``; ``(0, 0)`` when
        the benchmark directory does not exist.
    """
    questions_dir = "../benchmark/questions"
    # Robustness fix: os.listdir raised FileNotFoundError when the benchmark
    # tree was absent. (The Tuple annotation also needed importing — the
    # original NameError'd at module import time.)
    if not os.path.isdir(questions_dir):
        return 0, 0

    total_cases = len(glob.glob(f"{questions_dir}/*"))
    total_questions = sum(
        len(glob.glob(f"{questions_dir}/{case_id}/*.json"))
        for case_id in os.listdir(questions_dir)
    )
    return total_cases, total_questions
|
|
|
|
def main():
    """Run the benchmark: send every case's questions to the model and print progress."""
    with open("../data/eurorad_metadata.json", "r") as metadata_file:
        cases = json.load(metadata_file)

    client = initialize_client()
    total_cases, total_questions = count_total_questions()
    case_count = 0
    question_count = 0
    skipped_count = 0

    print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")

    for case_id, case_details in cases.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            # Cases without question files don't count toward progress.
            continue
        case_count += 1

        for question_file in question_files:
            # Question ID is the filename stem, e.g. "case1_q2.json" -> "case1_q2".
            question_id = os.path.basename(question_file).split(".")[0]
            with open(question_file, "r") as qf:
                question_data = json.load(qf)

            question_count += 1
            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )

            # None means the question was skipped (no images found).
            if response is None:
                skipped_count += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {case_count}/{total_cases}, Question {question_count}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")

    print(f"\nBenchmark Summary:")
    print(f"Total Cases Processed: {case_count}")
    print(f"Total Questions Processed: {question_count}")
    print(f"Total Questions Skipped: {skipped_count}")


if __name__ == "__main__":
    main()
|
|