""" MedScribe v2 — Evaluation Evaluates the fine-tuned model on the validation set. Measures: 1. Form extraction accuracy (field-by-field) 2. Danger sign precision / recall 3. Hallucination rate (danger signs without evidence) 4. Referral decision accuracy 5. JSON validity rate Usage: python scripts/06_evaluate.py --model models/checkpoints/final python scripts/06_evaluate.py --model ollama:medscribe-v2 """ import argparse import json import os import sys from collections import defaultdict from pathlib import Path os.environ["TORCH_COMPILE_DISABLE"] = "1" os.environ["TORCHDYNAMO_DISABLE"] = "1" import torch torch._dynamo.config.suppress_errors = True def load_val_data(path: str) -> list[dict]: """Load validation JSONL (raw format with ground truth).""" samples = [] with open(path, "r", encoding="utf-8") as f: for line in f: if line.strip(): samples.append(json.loads(line)) return samples def field_accuracy(predicted: dict, ground_truth: dict, prefix: str = "") -> dict: """ Compare predicted vs ground truth field-by-field. Returns {field: {correct, total, accuracy}}. """ results = {} if not isinstance(ground_truth, dict): return results for key, gt_val in ground_truth.items(): field_name = f"{prefix}.{key}" if prefix else key pred_val = predicted.get(key) if isinstance(predicted, dict) else None if isinstance(gt_val, dict): sub = field_accuracy(pred_val, gt_val, field_name) results.update(sub) elif isinstance(gt_val, list): # For arrays, check set overlap gt_set = set(str(x) for x in gt_val) if gt_val else set() pred_set = set(str(x) for x in (pred_val or [])) if isinstance(pred_val, list) else set() overlap = len(gt_set & pred_set) total = max(len(gt_set), 1) results[field_name] = {"correct": overlap, "total": total} else: match = (pred_val == gt_val) or (pred_val is None and gt_val is None) results[field_name] = {"correct": 1 if match else 0, "total": 1} return results def danger_sign_metrics(predicted: dict, ground_truth: dict) -> dict: """ Compute precision, recall, F1 for danger sign detection. Also checks hallucination rate (signs without evidence). """ gt_signs = {s["sign"] for s in ground_truth.get("danger_signs", [])} pred_signs_list = predicted.get("danger_signs", []) if isinstance(predicted, dict) else [] pred_signs = {s.get("sign", "") for s in pred_signs_list} tp = len(gt_signs & pred_signs) fp = len(pred_signs - gt_signs) fn = len(gt_signs - pred_signs) precision = tp / (tp + fp) if (tp + fp) > 0 else (1.0 if len(gt_signs) == 0 else 0.0) recall = tp / (tp + fn) if (tp + fn) > 0 else (1.0 if len(gt_signs) == 0 else 0.0) f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 # Hallucination check: predicted signs without utterance_evidence hallucinations = 0 for s in pred_signs_list: if not s.get("utterance_evidence"): hallucinations += 1 # Referral decision accuracy gt_decision = ground_truth.get("referral_decision", {}).get("decision", "") pred_decision = "" if isinstance(predicted, dict): pred_decision = predicted.get("referral_decision", {}).get("decision", "") referral_correct = gt_decision == pred_decision return { "true_positives": tp, "false_positives": fp, "false_negatives": fn, "precision": precision, "recall": recall, "f1": f1, "hallucinations": hallucinations, "total_predicted": len(pred_signs_list), "referral_correct": referral_correct, } def run_inference_ollama(transcript: str, system_prompt: str, model: str) -> str: """Run inference via Ollama.""" import ollama response = ollama.chat( model=model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": transcript}, ], ) return response.message.content def run_inference_transformers(transcript: str, system_prompt: str, model_path: str) -> str: """Run inference via Unsloth-loaded model (handles LoRA + 4-bit).""" from unsloth import FastLanguageModel # Cache model loading if not hasattr(run_inference_transformers, "_model"): print(" Loading model via Unsloth...") model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_path, max_seq_length=4096, load_in_4bit=True, ) FastLanguageModel.for_inference(model) run_inference_transformers._model = model run_inference_transformers._tokenizer = tokenizer model = run_inference_transformers._model tokenizer = run_inference_transformers._tokenizer messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": transcript}, ] text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tokenizer(text=[text], return_tensors="pt").to("cuda") with torch.no_grad(): output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False) return tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True) def main(): parser = argparse.ArgumentParser(description="MedScribe v2 — Evaluation") parser.add_argument("--model", required=True, help="Model path or ollama:") parser.add_argument("--val-file", default="data/processed/val.jsonl") parser.add_argument("--raw-file", default="data/processed/training_data_raw.jsonl", help="Raw data with ground truth for danger sign eval") parser.add_argument("--limit", type=int, default=0, help="Limit eval to N samples") args = parser.parse_args() is_ollama = args.model.startswith("ollama:") model_ref = args.model.split(":", 1)[1] if is_ollama else args.model print("=" * 60) print(f"MedScribe v2 — Evaluation") print(f"Model: {args.model}") print("=" * 60) # Load validation data val_data = load_val_data(args.val_file) if args.limit > 0: val_data = val_data[:args.limit] print(f"Evaluating on {len(val_data)} samples") # Metrics accumulators json_valid = 0 json_invalid = 0 all_field_results = defaultdict(lambda: {"correct": 0, "total": 0}) all_danger_metrics = [] for i, sample in enumerate(val_data): messages = sample["messages"] system_msg = messages[0]["content"] user_msg = messages[1]["content"] gt_response = messages[2]["content"] # Run inference try: if is_ollama: pred_text = run_inference_ollama(user_msg, system_msg, model_ref) else: pred_text = run_inference_transformers(user_msg, system_msg, model_ref) except Exception as e: print(f" [{i+1}] Inference error: {e}") json_invalid += 1 continue # Parse JSON — strip markdown code fences if present pred_clean = pred_text.strip() if pred_clean.startswith("```"): # Remove ```json ... ``` wrapper lines = pred_clean.split("\n") lines = [l for l in lines if not l.strip().startswith("```")] pred_clean = "\n".join(lines) try: pred_data = json.loads(pred_clean) gt_data = json.loads(gt_response) json_valid += 1 except json.JSONDecodeError: json_invalid += 1 print(f" [{i+1}] Invalid JSON: {pred_clean[:100]}...") continue # Field accuracy field_results = field_accuracy(pred_data, gt_data) for field, res in field_results.items(): all_field_results[field]["correct"] += res["correct"] all_field_results[field]["total"] += res["total"] # Danger sign metrics (if this is a danger sign task) task = sample.get("metadata", {}).get("task", "") if task == "danger_signs": dm = danger_sign_metrics(pred_data, gt_data) all_danger_metrics.append(dm) if (i + 1) % 10 == 0: print(f" [{i+1}/{len(val_data)}] processed") # ── Results ── total = json_valid + json_invalid print(f"\n{'=' * 60}") print("EVALUATION RESULTS") print("=" * 60) print(f"\n JSON validity: {json_valid}/{total} ({json_valid/total*100:.0f}%)") # Field accuracy if all_field_results: total_correct = sum(v["correct"] for v in all_field_results.values()) total_fields = sum(v["total"] for v in all_field_results.values()) print(f"\n Overall field accuracy: {total_correct}/{total_fields} ({total_correct/total_fields*100:.1f}%)") print(f"\n Per-field accuracy (top 20):") sorted_fields = sorted(all_field_results.items(), key=lambda x: x[1]["correct"]/max(x[1]["total"],1)) for field, res in sorted_fields[:20]: acc = res["correct"] / max(res["total"], 1) * 100 print(f" {field}: {acc:.0f}% ({res['correct']}/{res['total']})") # Danger sign metrics if all_danger_metrics: avg_precision = sum(m["precision"] for m in all_danger_metrics) / len(all_danger_metrics) avg_recall = sum(m["recall"] for m in all_danger_metrics) / len(all_danger_metrics) avg_f1 = sum(m["f1"] for m in all_danger_metrics) / len(all_danger_metrics) total_hallucinations = sum(m["hallucinations"] for m in all_danger_metrics) total_predicted = sum(m["total_predicted"] for m in all_danger_metrics) referral_correct = sum(1 for m in all_danger_metrics if m["referral_correct"]) hallucination_rate = total_hallucinations / max(total_predicted, 1) * 100 print(f"\n Danger Sign Detection:") print(f" Precision: {avg_precision:.2f}") print(f" Recall: {avg_recall:.2f}") print(f" F1: {avg_f1:.2f}") print(f" Hallucination rate: {hallucination_rate:.1f}% ({total_hallucinations}/{total_predicted})") print(f" Referral accuracy: {referral_correct}/{len(all_danger_metrics)} ({referral_correct/len(all_danger_metrics)*100:.0f}%)") # Save results output = { "json_validity": {"valid": json_valid, "invalid": json_invalid, "rate": json_valid / max(total, 1)}, "field_accuracy": {k: {**v, "accuracy": v["correct"]/max(v["total"],1)} for k, v in all_field_results.items()}, } if all_danger_metrics: output["danger_signs"] = { "precision": avg_precision, "recall": avg_recall, "f1": avg_f1, "hallucination_rate": hallucination_rate, "referral_accuracy": referral_correct / len(all_danger_metrics), } eval_path = "data/processed/eval_results.json" with open(eval_path, "w") as f: json.dump(output, f, indent=2) print(f"\n Results saved to {eval_path}") print("=" * 60) if __name__ == "__main__": main()