""" SAHELI Medical Benchmark Evaluation Evaluates Gemma 4 E4B (base and fine-tuned) on: 1. MedQA-USMLE (4-option MCQ) 2. MedMCQA (Indian clinical MCQ) 3. PubMedQA (biomedical yes/no/maybe) 4. AfriMed-QA v2 (African medical context) Outputs accuracy scores and per-category breakdowns. """ import os import json import torch import re from datetime import datetime from datasets import load_dataset from transformers import AutoProcessor, AutoModelForCausalLM # ============================================================ # Configuration # ============================================================ MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-4-E4B-it") MAX_SAMPLES = int(os.environ.get("EVAL_SAMPLES", "200")) # Per benchmark OUTPUT_FILE = "benchmark_results.json" print(f"Evaluating: {MODEL_ID}") print(f"Samples per benchmark: {MAX_SAMPLES}") # ============================================================ # Load Model # ============================================================ processor = AutoProcessor.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype="auto", device_map="auto", ) model.eval() def generate_answer(prompt, max_tokens=64): """Generate a short answer from the model.""" text = processor.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True, enable_thinking=False, ) inputs = processor(text=text, return_tensors="pt").to(model.device) input_len = inputs["input_ids"].shape[-1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_tokens, temperature=0.1, # Low temp for deterministic eval top_p=0.95, do_sample=True, ) response = processor.decode(outputs[0][input_len:], skip_special_tokens=True) return response.strip() def extract_answer_letter(response): """Extract single letter answer (A/B/C/D) from model response.""" response = response.strip().upper() # Direct letter if response and response[0] in "ABCD": return response[0] # "The answer is X" pattern match = re.search(r'(?:answer|option)\s*(?:is|:)\s*([A-D])', response, re.IGNORECASE) if match: return match.group(1).upper() # Letter followed by period or parenthesis match = re.search(r'\b([A-D])[.):]', response) if match: return match.group(1).upper() return response[0] if response else "X" # ============================================================ # Benchmark 1: MedQA-USMLE # ============================================================ def eval_medqa(): print("\n" + "=" * 60) print("Benchmark: MedQA-USMLE (4-option)") print("=" * 60) ds = load_dataset("GBaker/MedQA-USMLE-4-options", split="test") if len(ds) > MAX_SAMPLES: ds = ds.select(range(MAX_SAMPLES)) correct = 0 total = 0 for i, row in enumerate(ds): question = row["question"] options = row["options"] answer = row["answer_idx"] # "A", "B", "C", or "D" prompt = f"""You are a medical expert. Answer the following USMLE-style question by selecting the single best answer. Question: {question} A. {options['A']} B. {options['B']} C. {options['C']} D. 
# ============================================================
# Benchmark 1: MedQA-USMLE
# ============================================================

def eval_medqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedQA-USMLE (4-option)")
    print("=" * 60)

    ds = load_dataset("GBaker/MedQA-USMLE-4-options", split="test")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0
    for i, row in enumerate(ds):
        question = row["question"]
        options = row["options"]
        answer = row["answer_idx"]  # "A", "B", "C", or "D"

        prompt = f"""You are a medical expert. Answer the following USMLE-style question by selecting the single best answer.

Question: {question}

A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Answer with ONLY the letter (A, B, C, or D):"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f"  MedQA-USMLE Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedQA-USMLE", "accuracy": accuracy, "correct": correct, "total": total}


# ============================================================
# Benchmark 2: MedMCQA
# ============================================================

def eval_medmcqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedMCQA (Indian Clinical)")
    print("=" * 60)

    ds = load_dataset("openlifescienceai/medmcqa", split="validation")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0
    idx_to_letter = {0: "A", 1: "B", 2: "C", 3: "D"}
    for i, row in enumerate(ds):
        question = row["question"]
        options = [row["opa"], row["opb"], row["opc"], row["opd"]]
        answer_idx = row["cop"]  # 0-3
        answer = idx_to_letter.get(answer_idx, "A")

        prompt = f"""You are a medical expert. Answer this clinical question.

Question: {question}

A. {options[0]}
B. {options[1]}
C. {options[2]}
D. {options[3]}

Answer with ONLY the letter (A, B, C, or D):"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f"  MedMCQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedMCQA", "accuracy": accuracy, "correct": correct, "total": total}
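# ------------------------------------------------------------
# Alternative MCQ scoring (sketch only, not wired into the benchmarks
# above): instead of free-form generation plus regex parsing, compare
# the model's next-token logits for each candidate letter and take the
# argmax. This eliminates parsing failures at the cost of ignoring any
# reasoning the model might emit. `score_mcq_letters` is our own
# helper name; it assumes each letter encodes to a single leading
# token and that the processor exposes a `tokenizer` attribute.
def score_mcq_letters(prompt, letters="ABCD"):
    text = processor.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]  # next-token distribution
    ids = [processor.tokenizer.encode(l, add_special_tokens=False)[0] for l in letters]
    best = max(range(len(letters)), key=lambda j: logits[ids[j]].item())
    return letters[best]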
# ============================================================
# Benchmark 3: PubMedQA
# ============================================================

def eval_pubmedqa():
    print("\n" + "=" * 60)
    print("Benchmark: PubMedQA (Labeled)")
    print("=" * 60)

    ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))

    correct = 0
    total = 0
    for i, row in enumerate(ds):
        question = row["question"]
        context = " ".join(row["context"]["contexts"][:3]) if row.get("context") else ""
        answer = row["final_decision"]  # "yes", "no", "maybe"

        prompt = f"""Based on the following biomedical context, answer the question with exactly one word: yes, no, or maybe.

Context: {context[:1000]}

Question: {question}

Answer (yes/no/maybe):"""

        response = generate_answer(prompt, max_tokens=10)
        # Match whole words only, so outputs like "unknown" or "not
        # enough information" are not misread as "no".
        match = re.search(r"\b(yes|no|maybe)\b", response.lower())
        predicted = match.group(1) if match else response.lower().strip()

        if predicted == answer:
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f"  PubMedQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "PubMedQA", "accuracy": accuracy, "correct": correct, "total": total}
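# ------------------------------------------------------------
# With a few hundred samples per benchmark, point accuracies are
# noisy. Sketch of a bootstrap 95% confidence interval that could be
# reported alongside each score; `bootstrap_ci` is our own helper,
# not part of any benchmark's official protocol.
import random

def bootstrap_ci(correct, total, n_boot=1000, seed=0):
    """Return (low, high) 95% CI for accuracy = correct / total."""
    rng = random.Random(seed)
    outcomes = [1] * correct + [0] * (total - correct)
    means = sorted(
        sum(rng.choices(outcomes, k=total)) / total for _ in range(n_boot)
    )
    return means[int(0.025 * n_boot)], means[int(0.975 * n_boot)]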
# ============================================================
# Benchmark 4: AfriMed-QA v2
# ============================================================

def eval_afrimedqa():
    print("\n" + "=" * 60)
    print("Benchmark: AfriMed-QA v2 (African Medical)")
    print("=" * 60)

    try:
        ds = load_dataset("afrimedqa/afrimedqa_v2", split="train")
    except Exception:
        print("  AfriMed-QA v2 not available, skipping")
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0,
                "error": "Dataset not available"}

    # Filter to MCQ with valid answers
    mcq_data = [row for row in ds if row.get("answer_options") and row.get("correct_answer")]
    if len(mcq_data) > MAX_SAMPLES:
        mcq_data = mcq_data[:MAX_SAMPLES]
    if not mcq_data:
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0,
                "error": "No MCQ data"}

    correct = 0
    total = 0
    for i, row in enumerate(mcq_data):
        question = row["question"]
        options = row["answer_options"]
        answer = row["correct_answer"]

        # Format options (stored as either a dict or a list)
        options_text = ""
        if isinstance(options, dict):
            for k, v in options.items():
                options_text += f"{k}. {v}\n"
        elif isinstance(options, list):
            for j, opt in enumerate(options):
                letter = chr(65 + j)  # 0 -> "A", 1 -> "B", ...
                options_text += f"{letter}. {opt}\n"

        prompt = f"""You are a medical expert working in an African healthcare context. Answer this clinical question.

Question: {question}

{options_text}
Answer with ONLY the letter:"""

        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == str(answer).upper():
            correct += 1
        total += 1

        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{len(mcq_data)}, Accuracy: {correct/total:.1%}")

    accuracy = correct / total if total > 0 else 0
    print(f"  AfriMed-QA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "AfriMed-QA", "accuracy": accuracy, "correct": correct, "total": total}


# ============================================================
# Run All Benchmarks
# ============================================================

if __name__ == "__main__":
    print("\nStarting SAHELI Medical Benchmark Evaluation")
    print(f"Model: {MODEL_ID}")
    print(f"Date: {datetime.now().isoformat()}")
    print(f"Samples per benchmark: {MAX_SAMPLES}")

    results = {
        "model": MODEL_ID,
        "date": datetime.now().isoformat(),
        "max_samples_per_benchmark": MAX_SAMPLES,
        "benchmarks": [],
    }

    for eval_fn in [eval_medqa, eval_medmcqa, eval_pubmedqa, eval_afrimedqa]:
        try:
            result = eval_fn()
            results["benchmarks"].append(result)
        except Exception as e:
            print(f"  Error: {e}")
            results["benchmarks"].append({"error": str(e)})

    # Summary
    print("\n" + "=" * 60)
    print("SAHELI BENCHMARK RESULTS SUMMARY")
    print("=" * 60)
    for r in results["benchmarks"]:
        if "error" not in r:
            print(f"  {r['benchmark']:20s}: {r['accuracy']:.1%} ({r['correct']}/{r['total']})")

    # Save results
    with open(OUTPUT_FILE, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {OUTPUT_FILE}")

    # Upload results to Hub
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_file(
            path_or_fileobj=OUTPUT_FILE,
            path_in_repo="benchmark_results.json",
            repo_id="muthuk1/saheli-gemma4-e4b-medical",
            repo_type="model",
        )
        print("Results uploaded to https://huggingface.co/muthuk1/saheli-gemma4-e4b-medical")
    except Exception as e:
        print(f"Could not upload to Hub: {e}")
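# ------------------------------------------------------------
# Example invocations (the script filename below is a placeholder;
# substitute whatever this file is saved as):
#   MODEL_ID=google/gemma-4-E4B-it EVAL_SAMPLES=200 python eval_benchmarks.py
#   MODEL_ID=muthuk1/saheli-gemma4-e4b-medical python eval_benchmarks.py
# Each run overwrites benchmark_results.json, so move or rename the
# file between base and fine-tuned runs to keep both sets of scores.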