import argparse
import json
import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple


# Display order for per-category accuracy reporting.
CATEGORY_ORDER = [
    "detection",
    "classification",
    "localization",
    "comparison",
    "relationship",
    "diagnosis",
    "characterization",
]


def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from

    Returns:
        str: The extracted letter in uppercase; if no A-F letter is found,
            the uppercased input is returned so the comparison can still run
            (empty input yields an empty string)
    """
    if not answer:
        return ""

    answer = str(answer).strip()

    # Case 1: the answer is already a bare letter, e.g. "b".
    if len(answer) == 1 and answer.upper() in "ABCDEF":
        return answer.upper()

    # Case 2: the answer starts with a letter and a delimiter, e.g. "B) pneumonia".
    match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Case 3: a standalone letter appears somewhere in the text,
    # e.g. "The answer is C."
    matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
    if matches:
        return matches[0].upper()

    # Case 4: fall back to the first A-F character anywhere in the string.
    letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
    if letters:
        return letters[0].upper()

    # No letter found: return the normalized text (already stripped above).
    return answer.upper()
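
# Illustrative behavior of the extraction chain above (comments only):
#   extract_letter_answer("b")                 -> "B"        (bare letter)
#   extract_letter_answer("B) pneumonia")      -> "B"        (letter + delimiter)
#   extract_letter_answer("The answer is C.")  -> "C"        (standalone letter)
#   extract_letter_answer("unknown")           -> "UNKNOWN"  (no A-F letter at all)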


def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse a prediction file (aggregate JSON or JSON Lines) into valid predictions.

    Args:
        file_path: Path to the prediction file to parse

    Returns:
        Tuple containing:
            - str: Model name, or the file path if no model name was found
            - List[Dict[str, Any]]: List of valid prediction entries
    """
    valid_predictions = []
    model_name = None

    # First, try to read the file as a single aggregate JSON document
    # (the llava-med result format: {"model": ..., "results": [...]}).
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data.get("model") == "llava-med-v1.5-mistral-7b":
                model_name = data["model"]
                for result in data.get("results", []):
                    if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
                        # Prefer the parsed answer, then the validated one, then raw output.
                        model_answer = (
                            result.get("model_answer")
                            or result.get("validated_answer")
                            or result.get("raw_output", "")
                        )

                        # This format carries no per-question categories, so tag each
                        # entry with the full category list.
                        prediction = {
                            "case_id": result["case_id"],
                            "question_id": result["question_id"],
                            "model_answer": model_answer,
                            "correct_answer": result["correct_answer"],
                            "input": {
                                "question_data": {
                                    "metadata": {"categories": list(CATEGORY_ORDER)}
                                }
                            },
                        }
                        valid_predictions.append(prediction)
                return model_name, valid_predictions
    except (json.JSONDecodeError, KeyError):
        pass

    # Otherwise, fall back to JSON Lines: one prediction object per line.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("HTTP Request:"):
                continue
            try:
                data = json.loads(line.strip())
                if "model" in data:
                    model_name = data["model"]
                if all(
                    k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
                ):
                    valid_predictions.append(data)
            except json.JSONDecodeError:
                continue

    return model_name if model_name else file_path, valid_predictions
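
# The parser accepts two input shapes (illustrative sketches, not real data):
#
#   1. One aggregate JSON document (llava-med result format):
#        {"model": "llava-med-v1.5-mistral-7b",
#         "results": [{"case_id": 1, "question_id": "q1",
#                      "correct_answer": "A", "model_answer": "A) ..."}]}
#
#   2. JSON Lines, one prediction object per line, each carrying at least
#      model_answer, correct_answer, case_id, and question_id.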


def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]],
) -> List[List[Dict[str, Any]]]:
    """Ensure only questions that exist across all models are evaluated.

    Args:
        predictions_list: List of prediction lists from different models

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
    """
    # A question is identified by its (case_id, question_id) pair.
    question_sets = [
        {(p["case_id"], p["question_id"]) for p in preds} for preds in predictions_list
    ]
    common_questions = set.intersection(*question_sets)

    return [
        [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
        for preds in predictions_list
    ]
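
# Minimal sketch of the filtering (hypothetical IDs, comments only):
#   model A answered {(1, "q1"), (1, "q2")}
#   model B answered {(1, "q2"), (1, "q3")}
#   -> only (1, "q2") is kept for both models.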


def calculate_accuracy(
    predictions: List[Dict[str, Any]],
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and category-level accuracy.

    Args:
        predictions: List of prediction entries to analyze

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage
            - int: Number of correct predictions
            - int: Total number of scored predictions
            - Dict[str, Dict[str, float]]: Category-level accuracy statistics
    """
    if not predictions:
        return 0.0, 0, 0, {}

    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    correct = 0
    total = 0

    # Print a few randomly sampled predictions so the answer extraction
    # can be sanity-checked by eye.
    sample_size = min(5, len(predictions))
    sampled_indices = random.sample(range(len(predictions)), sample_size)

    print("\nSample extracted answers:")
    for i in sampled_indices:
        pred = predictions[i]
        model_ans = extract_letter_answer(pred["model_answer"])
        correct_ans = extract_letter_answer(pred["correct_answer"])
        print(f"QID: {pred['question_id']}")
        print(f"  Raw Model Answer: {pred['model_answer']}")
        print(f"  Extracted Model Answer: {model_ans}")
        print(f"  Raw Correct Answer: {pred['correct_answer']}")
        print(f"  Extracted Correct Answer: {correct_ans}")
        print("-" * 80)

    for pred in predictions:
        try:
            model_ans = extract_letter_answer(pred["model_answer"])
            correct_ans = extract_letter_answer(pred["correct_answer"])
            categories = (
                pred.get("input", {})
                .get("question_data", {})
                .get("metadata", {})
                .get("categories", [])
            )

            # Only score predictions where both answers could be extracted.
            if model_ans and correct_ans:
                total += 1
                is_correct = model_ans == correct_ans
                if is_correct:
                    correct += 1

                for category in categories:
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except KeyError:
            continue

    category_accuracies = {
        category: {
            "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
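
# Sketch of the return value (hypothetical numbers):
#   acc, correct, total, by_cat = calculate_accuracy(preds)
#   # -> (75.0, 3, 4, {"diagnosis": {"accuracy": 50.0, "total": 2, "correct": 1}, ...})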


def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare
    """
    parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
    model_names, predictions_list = zip(*parsed_results)

    # Accuracy on each model's full prediction set.
    print("\n📊 **Initial Accuracy**:")
    results = []
    category_results = []

    for preds, name in zip(predictions_list, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        results.append((acc, correct, total, name))
        category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Restrict every model to the questions they all answered.
    filtered_predictions = filter_common_questions(predictions_list)
    print(
        "\nQuestions per model after ensuring common questions: "
        f"{[len(p) for p in filtered_predictions]}"
    )

    # Accuracy on the shared question set only.
    print("\n📊 **Accuracy on Common Questions**:")
    filtered_results = []
    filtered_category_results = []

    for preds, name in zip(filtered_predictions, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        filtered_results.append((acc, correct, total, name))
        filtered_category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Per-category breakdown on the shared question set.
    print("\nCategory Performance (Common Questions):")
    for category in CATEGORY_ORDER:
        print(f"\n{category.capitalize()}:")
        for model_name, category_acc in zip(model_names, filtered_category_results):
            stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
            print(f"  {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")
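
# Illustrative console output (hypothetical model names and numbers):
#
#   📊 **Initial Accuracy**:
#   model-a: Accuracy = 72.50% (58/80 correct)
#   model-b: Accuracy = 61.25% (49/80 correct)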


def main():
    """Parse command-line arguments and run the comparison."""
    parser = argparse.ArgumentParser(
        description="Compare accuracy across multiple model prediction files"
    )
    parser.add_argument("files", nargs="+", help="Paths to model prediction files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")

    args = parser.parse_args()
    random.seed(args.seed)

    compare_models(args.files)
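
# Example invocation (hypothetical file names, assuming this script is saved
# as compare_accuracy.py):
#   python compare_accuracy.py results_model_a.json results_model_b.jsonl --seed 42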


if __name__ == "__main__":
    main()