# sakhi / scripts/evaluate.py
# Tushar9802 — HF Space deploy — initial (745f62a)
"""
MedScribe v2 — Evaluation
Evaluates the fine-tuned model on the validation set.
Measures:
1. Form extraction accuracy (field-by-field)
2. Danger sign precision / recall
3. Hallucination rate (danger signs without evidence)
4. Referral decision accuracy
5. JSON validity rate
Usage:
    python scripts/evaluate.py --model models/checkpoints/final
    python scripts/evaluate.py --model ollama:medscribe-v2
"""
import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
# Turn off torch.compile / TorchDynamo. These variables are set before
# `import torch` (presumably so torch sees them at import time); keep this order.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
import torch
# If Dynamo is still invoked somewhere, fall back to eager execution instead
# of raising.
torch._dynamo.config.suppress_errors = True
def load_val_data(path: str) -> list[dict]:
    """Read a JSONL file and return one decoded object per non-blank line.

    Args:
        path: Path to the validation JSONL file (one JSON document per line,
            UTF-8 encoded).

    Returns:
        The decoded records in file order. Lines that are empty or contain
        only whitespace are skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def field_accuracy(predicted: dict, ground_truth: dict, prefix: str = "") -> dict:
    """
    Compare predicted vs ground truth field-by-field.

    Nested dicts are flattened into dotted field names (e.g. ``"a.b"``).

    * Scalars score 1/1 on equality, else 0/1. ``None == None`` counts as a
      match.
    * Non-empty lists score the number of distinct ground-truth items (compared
      via ``str()``) that also appear in the predicted list, out of the number
      of distinct ground-truth items. Extra predicted items are not penalised.
    * An empty ground-truth list scores 1/1 when the prediction is missing or
      an empty list, and 0/1 otherwise.

    Args:
        predicted: Parsed model output. A non-dict is treated as having no
            fields, so every ground-truth field scores as missing.
        ground_truth: Reference output. A non-dict yields no fields.
        prefix: Dotted path of the enclosing field (used during recursion).

    Returns:
        ``{field_name: {"correct": int, "total": int}}``. Accuracy
        (``correct / total``) is left to the caller.
    """
    results = {}
    if not isinstance(ground_truth, dict):
        return results
    pred_dict = predicted if isinstance(predicted, dict) else {}

    for key, gt_val in ground_truth.items():
        field_name = f"{prefix}.{key}" if prefix else key
        pred_val = pred_dict.get(key)

        if isinstance(gt_val, dict):
            results.update(field_accuracy(pred_val, gt_val, field_name))
        elif isinstance(gt_val, list):
            if gt_val:
                gt_set = {str(x) for x in gt_val}
                pred_set = {str(x) for x in pred_val} if isinstance(pred_val, list) else set()
                results[field_name] = {"correct": len(gt_set & pred_set), "total": len(gt_set)}
            else:
                # Empty reference list: credit an absent/empty prediction
                # instead of scoring a guaranteed 0/1.
                empty_pred = pred_val is None or (isinstance(pred_val, list) and not pred_val)
                results[field_name] = {"correct": 1 if empty_pred else 0, "total": 1}
        else:
            results[field_name] = {"correct": 1 if pred_val == gt_val else 0, "total": 1}
    return results
def danger_sign_metrics(predicted: dict, ground_truth: dict) -> dict:
    """
    Compute precision, recall and F1 for danger sign detection on one sample.

    Signs are compared as sets of their ``"sign"`` values. When there are no
    predictions, precision is 1.0 if the ground truth is also empty, else 0.0.
    Recall is 1.0 when the ground truth has no signs.

    Malformed model output is scored as wrong rather than raising, so a single
    bad sample cannot abort the evaluation run. Malformed output here means:

    * JSON ``null`` or a non-dict top level;
    * ``danger_signs`` that is not a list;
    * entries that are not objects (each counts as a prediction with an empty
      sign and no evidence);
    * ``referral_decision`` that is not an object.

    Args:
        predicted: Parsed model output.
        ground_truth: Parsed reference output. Every ``danger_signs`` entry is
            expected to carry a ``"sign"`` key; a missing key raises KeyError,
            because it indicates broken reference data.

    Returns:
        Dict with ``true_positives``, ``false_positives``, ``false_negatives``,
        ``precision``, ``recall``, ``f1``, ``hallucinations`` (predicted
        entries whose ``utterance_evidence`` is missing or empty),
        ``total_predicted`` and ``referral_correct`` (bool).
    """
    def _as_dict(value) -> dict:
        # Coerce null / wrong-typed JSON values to an empty object.
        return value if isinstance(value, dict) else {}

    def _as_list(value) -> list:
        # Coerce null / wrong-typed JSON values to an empty array.
        return value if isinstance(value, list) else []

    gt = _as_dict(ground_truth)
    pred = _as_dict(predicted)

    gt_signs = {s["sign"] for s in _as_list(gt.get("danger_signs"))}
    pred_signs_list = [s if isinstance(s, dict) else {} for s in _as_list(pred.get("danger_signs"))]
    pred_signs = {s.get("sign", "") for s in pred_signs_list}

    tp = len(gt_signs & pred_signs)
    fp = len(pred_signs - gt_signs)
    fn = len(gt_signs - pred_signs)
    precision = tp / (tp + fp) if (tp + fp) > 0 else (1.0 if len(gt_signs) == 0 else 0.0)
    recall = tp / (tp + fn) if (tp + fn) > 0 else (1.0 if len(gt_signs) == 0 else 0.0)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    # Hallucination check: predicted signs without utterance_evidence.
    hallucinations = sum(1 for s in pred_signs_list if not s.get("utterance_evidence"))

    # Referral decision accuracy ("" when absent on either side).
    gt_decision = _as_dict(gt.get("referral_decision")).get("decision", "")
    pred_decision = _as_dict(pred.get("referral_decision")).get("decision", "")
    referral_correct = gt_decision == pred_decision

    return {
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "hallucinations": hallucinations,
        "total_predicted": len(pred_signs_list),
        "referral_correct": referral_correct,
    }
def run_inference_ollama(transcript: str, system_prompt: str, model: str) -> str:
    """Send a single system + user exchange to an Ollama model.

    Args:
        transcript: Content of the user turn.
        system_prompt: Content of the system turn.
        model: Ollama model name (the part after ``ollama:`` on the CLI).

    Returns:
        The content of the assistant message in the chat response.
    """
    import ollama

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": transcript},
    ]
    reply = ollama.chat(model=model, messages=chat_messages)
    return reply.message.content
def run_inference_transformers(transcript: str, system_prompt: str, model_path: str) -> str:
    """Run inference via an Unsloth-loaded model (handles LoRA + 4-bit).

    The model and tokenizer are loaded once and cached as attributes on this
    function. The cache is keyed on ``model_path``: a call with a different
    path reloads instead of silently reusing the first model.

    Args:
        transcript: Content of the user turn.
        system_prompt: Content of the system turn.
        model_path: Checkpoint path passed to
            ``FastLanguageModel.from_pretrained``.

    Returns:
        The greedy-decoded completion (prompt tokens removed, special tokens
        skipped), at most 2048 new tokens.
    """
    from unsloth import FastLanguageModel

    if getattr(run_inference_transformers, "_model_path", None) != model_path:
        print(" Loading model via Unsloth...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path,
            max_seq_length=4096,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
        run_inference_transformers._model = model
        run_inference_transformers._tokenizer = tokenizer
        run_inference_transformers._model_path = model_path

    model = run_inference_transformers._model
    tokenizer = run_inference_transformers._tokenizer

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": transcript},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Put the inputs wherever the model was loaded instead of assuming "cuda".
    inputs = tokenizer(text=[text], return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
    # generate() returns prompt + completion; decode only the new tokens.
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
def _strip_code_fences(text: str) -> str:
    """Strip whitespace and, if the reply starts with a fence, drop every ``` line."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Remove ```json ... ``` wrapper lines.
        cleaned = "\n".join(
            line for line in cleaned.split("\n") if not line.strip().startswith("```")
        )
    return cleaned


def _pct(part: float, whole: float) -> float:
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def main():
    """Evaluate a model on the validation JSONL and write a metrics summary.

    Each sample's ``messages`` are read as [system, user, reference]. The model
    is prompted with the first two turns, and its reply is parsed as JSON and
    compared with the third. Metrics are printed and written as JSON to
    ``--output``.
    """
    parser = argparse.ArgumentParser(description="MedScribe v2 — Evaluation")
    parser.add_argument("--model", required=True, help="Model path or ollama:<name>")
    parser.add_argument("--val-file", default="data/processed/val.jsonl")
    # NOTE(review): --raw-file is parsed but never read below; danger-sign
    # ground truth is taken from each sample's reference message instead.
    parser.add_argument("--raw-file", default="data/processed/training_data_raw.jsonl",
                        help="Raw data with ground truth for danger sign eval")
    parser.add_argument("--limit", type=int, default=0, help="Limit eval to N samples")
    parser.add_argument("--output", default="data/processed/eval_results.json",
                        help="Path of the JSON results file")
    args = parser.parse_args()

    is_ollama = args.model.startswith("ollama:")
    model_ref = args.model.split(":", 1)[1] if is_ollama else args.model

    print("=" * 60)
    print(f"MedScribe v2 — Evaluation")
    print(f"Model: {args.model}")
    print("=" * 60)

    # Load validation data
    val_data = load_val_data(args.val_file)
    if args.limit > 0:
        val_data = val_data[:args.limit]
    print(f"Evaluating on {len(val_data)} samples")

    # Metrics accumulators
    json_valid = 0
    json_invalid = 0  # also counts samples whose inference raised
    bad_ground_truth = 0
    all_field_results = defaultdict(lambda: {"correct": 0, "total": 0})
    all_danger_metrics = []

    for i, sample in enumerate(val_data):
        messages = sample["messages"]
        system_msg = messages[0]["content"]
        user_msg = messages[1]["content"]
        gt_response = messages[2]["content"]

        # Parse the reference first. Broken reference data is a dataset
        # problem, not a model failure, so it is counted separately.
        try:
            gt_data = json.loads(gt_response)
        except json.JSONDecodeError as e:
            bad_ground_truth += 1
            print(f" [{i+1}] Skipped: ground-truth JSON is invalid ({e})")
            continue

        # Run inference (best-effort: one failing sample must not abort the run)
        try:
            if is_ollama:
                pred_text = run_inference_ollama(user_msg, system_msg, model_ref)
            else:
                pred_text = run_inference_transformers(user_msg, system_msg, model_ref)
        except Exception as e:
            print(f" [{i+1}] Inference error: {e}")
            json_invalid += 1
            continue

        # Parse JSON — strip markdown code fences if present. A None reply is
        # treated as empty text (and so counted as invalid JSON).
        pred_clean = _strip_code_fences(pred_text or "")
        try:
            pred_data = json.loads(pred_clean)
        except json.JSONDecodeError:
            json_invalid += 1
            print(f" [{i+1}] Invalid JSON: {pred_clean[:100]}...")
            continue
        json_valid += 1

        # Field accuracy
        for field, res in field_accuracy(pred_data, gt_data).items():
            all_field_results[field]["correct"] += res["correct"]
            all_field_results[field]["total"] += res["total"]

        # Danger sign metrics (if this is a danger sign task)
        task = (sample.get("metadata") or {}).get("task", "")
        if task == "danger_signs":
            all_danger_metrics.append(danger_sign_metrics(pred_data, gt_data))

        if (i + 1) % 10 == 0:
            print(f" [{i+1}/{len(val_data)}] processed")

    # ── Results ──
    total = json_valid + json_invalid
    print(f"\n{'=' * 60}")
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"\n JSON validity: {json_valid}/{total} ({_pct(json_valid, total):.0f}%)")
    if bad_ground_truth:
        print(f" Skipped (invalid ground truth): {bad_ground_truth}")

    # Field accuracy
    if all_field_results:
        total_correct = sum(v["correct"] for v in all_field_results.values())
        total_fields = sum(v["total"] for v in all_field_results.values())
        print(f"\n Overall field accuracy: {total_correct}/{total_fields} "
              f"({_pct(total_correct, total_fields):.1f}%)")
        # Sorted ascending by accuracy, so these are the weakest fields.
        print(f"\n Per-field accuracy (lowest 20):")
        sorted_fields = sorted(all_field_results.items(),
                               key=lambda x: x[1]["correct"] / max(x[1]["total"], 1))
        for field, res in sorted_fields[:20]:
            acc = _pct(res["correct"], res["total"])
            print(f" {field}: {acc:.0f}% ({res['correct']}/{res['total']})")

    # Danger sign metrics (macro-averaged over danger-sign samples)
    if all_danger_metrics:
        n_danger = len(all_danger_metrics)
        avg_precision = sum(m["precision"] for m in all_danger_metrics) / n_danger
        avg_recall = sum(m["recall"] for m in all_danger_metrics) / n_danger
        avg_f1 = sum(m["f1"] for m in all_danger_metrics) / n_danger
        total_hallucinations = sum(m["hallucinations"] for m in all_danger_metrics)
        total_predicted = sum(m["total_predicted"] for m in all_danger_metrics)
        referral_correct = sum(1 for m in all_danger_metrics if m["referral_correct"])
        hallucination_rate = _pct(total_hallucinations, total_predicted)
        print(f"\n Danger Sign Detection:")
        print(f" Precision: {avg_precision:.2f}")
        print(f" Recall: {avg_recall:.2f}")
        print(f" F1: {avg_f1:.2f}")
        print(f" Hallucination rate: {hallucination_rate:.1f}% ({total_hallucinations}/{total_predicted})")
        print(f" Referral accuracy: {referral_correct}/{n_danger} ({_pct(referral_correct, n_danger):.0f}%)")

    # Save results
    output = {
        "json_validity": {"valid": json_valid, "invalid": json_invalid, "rate": json_valid / max(total, 1)},
        "field_accuracy": {k: {**v, "accuracy": v["correct"] / max(v["total"], 1)}
                           for k, v in all_field_results.items()},
        "skipped_invalid_ground_truth": bad_ground_truth,
    }
    if all_danger_metrics:
        output["danger_signs"] = {
            "precision": avg_precision, "recall": avg_recall, "f1": avg_f1,
            "hallucination_rate": hallucination_rate,
            "referral_accuracy": referral_correct / n_danger,
        }
    eval_path = args.output
    Path(eval_path).parent.mkdir(parents=True, exist_ok=True)
    with open(eval_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\n Results saved to {eval_path}")
    print("=" * 60)
if __name__ == "__main__":
    # Script entry point: evaluation runs only when executed directly,
    # never when this module is imported.
    main()