| """ |
| SAHELI Medical Benchmark Evaluation |
| Evaluates Gemma 4 E4B (base and fine-tuned) on: |
| 1. MedQA-USMLE (4-option MCQ) |
| 2. MedMCQA (Indian clinical MCQ) |
| 3. PubMedQA (biomedical yes/no/maybe) |
| 4. AfriMed-QA v2 (African medical context) |
| |
| Outputs accuracy scores and per-category breakdowns. |
| """ |

import ast
import json
import os
import re
from datetime import datetime

import torch
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForCausalLM

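# Configuration: both values can be overridden via environment variables.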
MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-4-E4B-it")
MAX_SAMPLES = int(os.environ.get("EVAL_SAMPLES", "200"))
OUTPUT_FILE = "benchmark_results.json"

print(f"Evaluating: {MODEL_ID}")
print(f"Samples per benchmark: {MAX_SAMPLES}")

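# Load the processor and model once. device_map="auto" places weights on the
# available accelerator(s); dtype="auto" keeps the checkpoint's native precision.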
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
model.eval()

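# Single-prompt generation helper. Near-greedy sampling (temperature 0.1)
# keeps multiple-choice answers stable across runs.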
def generate_answer(prompt, max_tokens=64):
    """Generate a short answer from the model."""
    text = processor.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.1,
            top_p=0.95,
            do_sample=True,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()

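# Models often answer "The answer is B." despite being told to reply with a
# bare letter, so extraction tries several increasingly loose patterns.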
def extract_answer_letter(response):
    """Extract a single-letter answer (A/B/C/D) from a model response."""
    response = response.strip().upper()

    # Ideal case: the response starts with the letter itself.
    if response and response[0] in "ABCD":
        return response[0]

    # Phrasing such as "the answer is B" or "option: C".
    match = re.search(r'(?:ANSWER|OPTION)\s*(?:IS|:)\s*([A-D])', response)
    if match:
        return match.group(1)

    # A letter followed by punctuation, e.g. "B." or "C)".
    match = re.search(r'\b([A-D])[.):]', response)
    if match:
        return match.group(1)

    # Fallback: first character, or a sentinel that scores as incorrect.
    return response[0] if response else "X"

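# Benchmark 1: MedQA-USMLE (GBaker/MedQA-USMLE-4-options, test split).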
def eval_medqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedQA-USMLE (4-option)")
    print("=" * 60)

    ds = load_dataset("GBaker/MedQA-USMLE-4-options", split="test")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0

    for i, row in enumerate(ds):
        question = row["question"]
        options = row["options"]
        answer = row["answer_idx"]  # gold label is a letter, e.g. "A"

        prompt = f"""You are a medical expert. Answer the following USMLE-style question by selecting the single best answer.

Question: {question}

A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Answer with ONLY the letter (A, B, C, or D):"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)

        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f" MedQA-USMLE Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedQA-USMLE", "accuracy": accuracy, "correct": correct, "total": total}

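# Benchmark 2: MedMCQA (openlifescienceai/medmcqa, validation split; labels
# for the test split are not public).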
def eval_medmcqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedMCQA (Indian Clinical)")
    print("=" * 60)

    ds = load_dataset("openlifescienceai/medmcqa", split="validation")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0
    idx_to_letter = {0: "A", 1: "B", 2: "C", 3: "D"}

    for i, row in enumerate(ds):
        question = row["question"]
        options = [row["opa"], row["opb"], row["opc"], row["opd"]]
        answer_idx = row["cop"]  # correct option as a 0-based index
        answer = idx_to_letter.get(answer_idx, "A")

        prompt = f"""You are a medical expert. Answer this clinical question.

Question: {question}

A. {options[0]}
B. {options[1]}
C. {options[2]}
D. {options[3]}

Answer with ONLY the letter (A, B, C, or D):"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)

        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f" MedMCQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedMCQA", "accuracy": accuracy, "correct": correct, "total": total}

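# Benchmark 3: PubMedQA (qiaojin/PubMedQA, pqa_labeled config; the 1k
# expert-labeled subset only ships a train split).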
def eval_pubmedqa():
    print("\n" + "=" * 60)
    print("Benchmark: PubMedQA (Labeled)")
    print("=" * 60)

    ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0

    for i, row in enumerate(ds):
        question = row["question"]
        context = " ".join(row["context"]["contexts"][:3]) if row.get("context") else ""
        answer = row["final_decision"]

        prompt = f"""Based on the following biomedical context, answer the question with exactly one word: yes, no, or maybe.

Context: {context[:1000]}

Question: {question}

Answer (yes/no/maybe):"""

        response = generate_answer(prompt, max_tokens=10)
        response_lower = response.lower().strip()

        # Match whole words only: a plain substring test for "no" would also
        # fire on "not", "unknown", or "cannot".
        match = re.search(r'\b(yes|no|maybe)\b', response_lower)
        predicted = match.group(1) if match else response_lower

        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f" PubMedQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "PubMedQA", "accuracy": accuracy, "correct": correct, "total": total}

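# Benchmark 4: AfriMed-QA v2. Field types and label formats may vary across
# releases of this dataset, so the handling below stays deliberately defensive.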
def eval_afrimedqa():
    print("\n" + "=" * 60)
    print("Benchmark: AfriMed-QA v2 (African Medical)")
    print("=" * 60)

    try:
        ds = load_dataset("afrimedqa/afrimedqa_v2", split="train")
    except Exception:
        print(" AfriMed-QA v2 not available, skipping")
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0, "error": "Dataset not available"}

    # Keep only rows that look like answerable MCQs.
    mcq_data = [row for row in ds if row.get("answer_options") and row.get("correct_answer")]
    if len(mcq_data) > MAX_SAMPLES:
        mcq_data = mcq_data[:MAX_SAMPLES]

    if not mcq_data:
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0, "error": "No MCQ data"}

    correct = 0
    total = 0

    for i, row in enumerate(mcq_data):
        question = row["question"]
        options = row["answer_options"]
        answer = row["correct_answer"]

        # Options may arrive as a dict, a list, or (in some dumps) a
        # serialized dict string; try to parse the string form.
        if isinstance(options, str):
            try:
                options = ast.literal_eval(options)
            except (ValueError, SyntaxError):
                options = None

        options_text = ""
        if isinstance(options, dict):
            # Keys may be letters or names like "option1"; relabel as A, B, C, ...
            for j, v in enumerate(options.values()):
                options_text += f"{chr(65 + j)}. {v}\n"
        elif isinstance(options, list):
            for j, opt in enumerate(options):
                letter = chr(65 + j)
                options_text += f"{letter}. {opt}\n"

        prompt = f"""You are a medical expert working in an African healthcare context. Answer this clinical question.

Question: {question}

{options_text}
Answer with ONLY the letter:"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)

        # The gold label may be a bare letter ("C") or a key such as
        # "option3"; normalize both to a letter before comparing.
        answer_str = str(answer).strip()
        opt_match = re.match(r'option\s*(\d+)', answer_str, re.IGNORECASE)
        answer_letter = chr(64 + int(opt_match.group(1))) if opt_match else answer_str.upper()[:1]

        if predicted == answer_letter:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(mcq_data)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f" AfriMed-QA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "AfriMed-QA", "accuracy": accuracy, "correct": correct, "total": total}

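# Run every benchmark, tolerate individual failures, then summarize and save.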
if __name__ == "__main__":
    print("\nStarting SAHELI Medical Benchmark Evaluation")
    print(f"Model: {MODEL_ID}")
    print(f"Date: {datetime.now().isoformat()}")
    print(f"Samples per benchmark: {MAX_SAMPLES}")

    results = {
        "model": MODEL_ID,
        "date": datetime.now().isoformat(),
        "max_samples_per_benchmark": MAX_SAMPLES,
        "benchmarks": []
    }

    for eval_fn in [eval_medqa, eval_medmcqa, eval_pubmedqa, eval_afrimedqa]:
        try:
            result = eval_fn()
            results["benchmarks"].append(result)
        except Exception as e:
            print(f" Error in {eval_fn.__name__}: {e}")
            results["benchmarks"].append({"benchmark": eval_fn.__name__, "error": str(e)})

    print("\n" + "=" * 60)
    print("SAHELI BENCHMARK RESULTS SUMMARY")
    print("=" * 60)
    for r in results["benchmarks"]:
        if "error" not in r:
            print(f" {r['benchmark']:20s}: {r['accuracy']:.1%} ({r['correct']}/{r['total']})")

    with open(OUTPUT_FILE, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {OUTPUT_FILE}")

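    # Best-effort upload to the Hugging Face Hub; needs a token with write
    # access to the repo (e.g. HF_TOKEN or a cached huggingface-cli login).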
    try:
        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_file(
            path_or_fileobj=OUTPUT_FILE,
            path_in_repo="benchmark_results.json",
            repo_id="muthuk1/saheli-gemma4-e4b-medical",
            repo_type="model",
        )
        print("Results uploaded to https://huggingface.co/muthuk1/saheli-gemma4-e4b-medical")
    except Exception as e:
        print(f"Could not upload to Hub: {e}")