"""Checkpoint QA evaluation: answer quality + batch evidence coverage. Two dimensions: 1. Answer evaluation: Correct / Hallucination / Omission (1 LLM call) 2. Evidence coverage: how many gold evidence points are covered by the memories the model actually *cited* when answering? (1 LLM call) """ from __future__ import annotations from eval_framework.judges import evaluate_evidence_batch, evaluate_qa_llm from eval_framework.pipeline.records import PipelineCheckpointQARecord def evaluate_checkpoint_qa( record: PipelineCheckpointQARecord, **_kwargs: object, ) -> dict[str, object]: """LLM-judged QA evaluation: answer correctness + evidence coverage.""" # --- Build cited-memories text (what the model actually used) --- if record.cited_memories: cited_lines = [f"[{i + 1}] {m}" for i, m in enumerate(record.cited_memories)] cited_str = "\n".join(cited_lines) else: # Fallback: use full retrieval (legacy records without cited_memories) cited_lines = [f"[{item.rank}] {item.text}" for item in record.retrieval.items] cited_str = "\n".join(cited_lines) if cited_lines else "" # --- Answer evaluation (1 LLM call, unchanged) --- gold_evidence_str = ( "\n".join(record.gold_evidence_contents) if record.gold_evidence_contents else "No evidence available." ) answer_result = evaluate_qa_llm( question=record.question, reference_answer=record.gold_answer, key_memory_points=gold_evidence_str, system_response=record.generated_answer, ) answer_label = answer_result.get("evaluation_result") # --- Evidence coverage (1 LLM call, batch) --- # Only check against cited memories, not the full retrieval gold_contents = list(record.gold_evidence_contents) evidence_result: dict[str, object] = { "covered_count": 0, "total": len(gold_contents), "reasoning": "" } if gold_contents and cited_str.strip(): evidence_result = evaluate_evidence_batch(cited_str, gold_contents) covered = evidence_result.get("covered_count") total_ev = evidence_result.get("total", len(gold_contents)) if covered is not None and total_ev: evidence_hit_rate = float(covered) / float(total_ev) else: evidence_hit_rate = 0.0 return { "answer_label": answer_label, "answer_reasoning": answer_result.get("reasoning", ""), "answer_is_valid": answer_label in ("Correct", "Hallucination", "Omission"), "evidence_hit_rate": evidence_hit_rate, "evidence_covered_count": covered, "num_evidence": total_ev, "evidence_reasoning": evidence_result.get("reasoning", ""), "num_cited_memories": len(record.cited_memories), }