# saheli-gemma4-e4b-medical / eval_benchmarks.py
"""
SAHELI Medical Benchmark Evaluation
Evaluates Gemma 4 E4B (base and fine-tuned) on:
1. MedQA-USMLE (4-option MCQ)
2. MedMCQA (Indian clinical MCQ)
3. PubMedQA (biomedical yes/no/maybe)
4. AfriMed-QA v2 (African medical context)
Outputs per-benchmark accuracy scores to benchmark_results.json.
"""
import os
import json
import torch
import re
from datetime import datetime
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForCausalLM
# ============================================================
# Configuration
# ============================================================
MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-4-E4B-it")
MAX_SAMPLES = int(os.environ.get("EVAL_SAMPLES", "200")) # Per benchmark
OUTPUT_FILE = "benchmark_results.json"
print(f"Evaluating: {MODEL_ID}")
print(f"Samples per benchmark: {MAX_SAMPLES}")
# ============================================================
# Load Model
# ============================================================
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map="auto",
)
model.eval()
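# Note: device_map="auto" places model shards across the available devices and
# requires the accelerate package to be installed.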
def generate_answer(prompt, max_tokens=64):
    """Generate a short answer from the model."""
    text = processor.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,  # Greedy decoding so the eval is deterministic
        )
    # Decode only the newly generated tokens, not the echoed prompt
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()
def extract_answer_letter(response):
    """Extract single letter answer (A/B/C/D) from model response."""
    response = response.strip().upper()
    # Direct letter
    if response and response[0] in "ABCD":
        return response[0]
    # "The answer is X" pattern
    match = re.search(r'(?:answer|option)\s*(?:is|:)\s*([A-D])', response, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    # Letter followed by period, parenthesis, or colon
    match = re.search(r'\b([A-D])[.):]', response)
    if match:
        return match.group(1).upper()
    # Fallback: first character, or "X" for an empty response (scored as wrong)
    return response[0] if response else "X"
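# Illustrative behavior of the heuristics above (not exhaustive):
#   extract_answer_letter("B. Aspirin")      -> "B"  (leading letter)
#   extract_answer_letter("The answer is C") -> "C"  (pattern match)
#   extract_answer_letter("Unsure")          -> "U"  (fallback; scored as wrong)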
# ============================================================
# Benchmark 1: MedQA-USMLE
# ============================================================
def eval_medqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedQA-USMLE (4-option)")
    print("=" * 60)
    ds = load_dataset("GBaker/MedQA-USMLE-4-options", split="test")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))
    correct = 0
    total = 0
    for i, row in enumerate(ds):
        question = row["question"]
        options = row["options"]
        answer = row["answer_idx"]  # "A", "B", "C", or "D"
        prompt = f"""You are a medical expert. Answer the following USMLE-style question by selecting the single best answer.
Question: {question}
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}
Answer with ONLY the letter (A, B, C, or D):"""
        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == answer:
            correct += 1
        total += 1
        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")
    accuracy = correct / total if total > 0 else 0
    print(f" MedQA-USMLE Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedQA-USMLE", "accuracy": accuracy, "correct": correct, "total": total}
# ============================================================
# Benchmark 2: MedMCQA
# ============================================================
def eval_medmcqa():
    print("\n" + "=" * 60)
    print("Benchmark: MedMCQA (Indian Clinical)")
    print("=" * 60)
    ds = load_dataset("openlifescienceai/medmcqa", split="validation")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))
    correct = 0
    total = 0
    idx_to_letter = {0: "A", 1: "B", 2: "C", 3: "D"}
    for i, row in enumerate(ds):
        question = row["question"]
        options = [row["opa"], row["opb"], row["opc"], row["opd"]]
        answer_idx = row["cop"]  # 0-3
        answer = idx_to_letter.get(answer_idx, "A")
        prompt = f"""You are a medical expert. Answer this clinical question.
Question: {question}
A. {options[0]}
B. {options[1]}
C. {options[2]}
D. {options[3]}
Answer with ONLY the letter (A, B, C, or D):"""
        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == answer:
            correct += 1
        total += 1
        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")
    accuracy = correct / total if total > 0 else 0
    print(f" MedMCQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "MedMCQA", "accuracy": accuracy, "correct": correct, "total": total}
# ============================================================
# Benchmark 3: PubMedQA
# ============================================================
def eval_pubmedqa():
    print("\n" + "=" * 60)
    print("Benchmark: PubMedQA (Labeled)")
    print("=" * 60)
    # pqa_labeled ships only a "train" split (1,000 expert-annotated examples)
    ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
    if len(ds) > MAX_SAMPLES:
        ds = ds.select(range(MAX_SAMPLES))
    correct = 0
    total = 0
    for i, row in enumerate(ds):
        question = row["question"]
        context = " ".join(row["context"]["contexts"][:3]) if row.get("context") else ""
        answer = row["final_decision"]  # "yes", "no", "maybe"
        prompt = f"""Based on the following biomedical context, answer the question with exactly one word: yes, no, or maybe.
Context: {context[:1000]}
Question: {question}
Answer (yes/no/maybe):"""
        response = generate_answer(prompt, max_tokens=10)
        response_lower = response.lower().strip()
        # Match whole words so that e.g. "cannot" is not misread as "no"
        match = re.search(r"\b(yes|no|maybe)\b", response_lower)
        predicted = match.group(1) if match else response_lower
        if predicted == answer:
            correct += 1
        total += 1
        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(ds)}, Accuracy: {correct/total:.1%}")
    accuracy = correct / total if total > 0 else 0
    print(f" PubMedQA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "PubMedQA", "accuracy": accuracy, "correct": correct, "total": total}
# ============================================================
# Benchmark 4: AfriMed-QA v2
# ============================================================
def eval_afrimedqa():
    print("\n" + "=" * 60)
    print("Benchmark: AfriMed-QA v2 (African Medical)")
    print("=" * 60)
    try:
        ds = load_dataset("afrimedqa/afrimedqa_v2", split="train")
    except Exception:
        print(" AfriMed-QA v2 not available, skipping")
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0, "error": "Dataset not available"}
    # Filter to MCQ rows with valid answers
    mcq_data = [row for row in ds if row.get("answer_options") and row.get("correct_answer")]
    if len(mcq_data) > MAX_SAMPLES:
        mcq_data = mcq_data[:MAX_SAMPLES]
    if not mcq_data:
        return {"benchmark": "AfriMed-QA", "accuracy": 0, "correct": 0, "total": 0, "error": "No MCQ data"}
    correct = 0
    total = 0
    for i, row in enumerate(mcq_data):
        question = row["question"]
        options = row["answer_options"]
        answer = row["correct_answer"]
        # Format options, whether they arrive as a dict or a list
        options_text = ""
        if isinstance(options, dict):
            for k, v in options.items():
                options_text += f"{k}. {v}\n"
        elif isinstance(options, list):
            for j, opt in enumerate(options):
                letter = chr(65 + j)  # 0 -> "A", 1 -> "B", ...
                options_text += f"{letter}. {opt}\n"
        prompt = f"""You are a medical expert working in an African healthcare context. Answer this clinical question.
Question: {question}
{options_text}
Answer with ONLY the letter:"""
        response = generate_answer(prompt)
        predicted = extract_answer_letter(response)
        if predicted == str(answer).upper():
            correct += 1
        total += 1
        if (i + 1) % 50 == 0:
            print(f" Progress: {i+1}/{len(mcq_data)}, Accuracy: {correct/total:.1%}")
    accuracy = correct / total if total > 0 else 0
    print(f" AfriMed-QA Accuracy: {accuracy:.1%} ({correct}/{total})")
    return {"benchmark": "AfriMed-QA", "accuracy": accuracy, "correct": correct, "total": total}
# ============================================================
# Run All Benchmarks
# ============================================================
if __name__ == "__main__":
print(f"\nStarting SAHELI Medical Benchmark Evaluation")
print(f"Model: {MODEL_ID}")
print(f"Date: {datetime.now().isoformat()}")
print(f"Samples per benchmark: {MAX_SAMPLES}")
results = {
"model": MODEL_ID,
"date": datetime.now().isoformat(),
"max_samples_per_benchmark": MAX_SAMPLES,
"benchmarks": []
}
for eval_fn in [eval_medqa, eval_medmcqa, eval_pubmedqa, eval_afrimedqa]:
try:
result = eval_fn()
results["benchmarks"].append(result)
except Exception as e:
print(f" Error: {e}")
results["benchmarks"].append({"error": str(e)})
# Summary
print("\n" + "=" * 60)
print("SAHELI BENCHMARK RESULTS SUMMARY")
print("=" * 60)
for r in results["benchmarks"]:
if "error" not in r:
print(f" {r['benchmark']:20s}: {r['accuracy']:.1%} ({r['correct']}/{r['total']})")
# Save results
with open(OUTPUT_FILE, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {OUTPUT_FILE}")
# Upload results to Hub
try:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=OUTPUT_FILE,
path_in_repo="benchmark_results.json",
repo_id="muthuk1/saheli-gemma4-e4b-medical",
repo_type="model",
)
print(f"Results uploaded to https://huggingface.co/muthuk1/saheli-gemma4-e4b-medical")
except Exception as e:
print(f"Could not upload to Hub: {e}")