"""Unified session evaluation: recall + correctness (includes update & interference). Per session, 2 LLM calls — both scoped to THIS SESSION's memory delta only: Call 1 — Recall: how many of this session's gold points are covered by the session's memory delta (add/update ops)? Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant? (reference = this session's gold points + interference) Aggregate: per-session recall/correctness averaged across sessions. """ from __future__ import annotations from eval_framework.judges import ( evaluate_correctness_batch, evaluate_interference_single, evaluate_recall_batch, evaluate_update_single, ) from eval_framework.pipeline.records import PipelineSessionRecord def _delta_to_text(session: PipelineSessionRecord) -> str: """Only the memories added or updated in THIS session (not the full snapshot).""" lines: list[str] = [] idx = 0 for d in session.memory_delta: if d.op in ("add", "update"): idx += 1 lines.append(f"[{idx}] {d.text}") return "\n".join(lines) def _delta_texts(session: PipelineSessionRecord) -> list[str]: """Text list of memories added or updated in THIS session.""" return [d.text for d in session.memory_delta if d.op in ("add", "update")] def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]: """Current session's new + update gold points only (NOT cumulative).""" out: list[str] = [] for g in session.gold_state.session_new_memories: out.append(f"[normal] {g.memory_content}") for g in session.gold_state.session_update_memories: out.append(f"[update] {g.memory_content}") return out def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]: """Current session's new + update + interference gold points as reference.""" out: list[str] = [] for g in session.gold_state.session_new_memories: out.append(f"[normal] {g.memory_content}") for g in session.gold_state.session_update_memories: out.append(f"[update] {g.memory_content}") for g in session.gold_state.session_interference_memories: out.append(f"[interference] {g.memory_content}") return out def evaluate_extraction( session: PipelineSessionRecord, **_kwargs: object, ) -> dict[str, object]: """Unified session evaluation: recall + correctness in 2 LLM calls. Uses only THIS session's new gold points for recall and correctness, not the cumulative history. Aggregate averages per-session scores. """ delta_str = _delta_to_text(session) delta_texts = _delta_texts(session) interference_total = len(session.gold_state.session_interference_memories) # --- Call 1: Recall (this session's gold points vs this session's delta) --- recall_gold = _build_recall_gold_points(session) if not recall_gold: recall = None update_recall = None recall_result: dict[str, object] = { "covered_count": 0, "update_covered_count": 0, "total": 0, "update_total": 0, "reasoning": "No new gold points in this session.", } elif not delta_str.strip(): recall = 0.0 update_recall = 0.0 update_total = sum(1 for p in recall_gold if p.startswith("[update]")) recall_result = { "covered_count": 0, "update_covered_count": 0, "total": len(recall_gold), "update_total": update_total, "reasoning": "No add/update memories in this session's delta.", } else: recall_result = evaluate_recall_batch(delta_str, recall_gold) covered = recall_result.get("covered_count") upd_covered = recall_result.get("update_covered_count") total_gold = recall_result.get("total", len(recall_gold)) upd_total = recall_result.get("update_total", 0) if recall_gold: recall = float(covered) / float(total_gold) if covered is not None and total_gold else None update_recall = float(upd_covered) / float(upd_total) if upd_covered is not None and upd_total else None # --- Call 2: Correctness (this session's delta memories, reference = this session's golds) --- correctness_gold = _build_correctness_gold_points(session) correctness_result = evaluate_correctness_batch(delta_texts, correctness_gold, interference_total) correctness_records = correctness_result.get("results", []) num_correct = sum(1 for r in correctness_records if r.get("label") == "correct") num_hallucination = sum(1 for r in correctness_records if r.get("label") == "hallucination") num_irrelevant = sum(1 for r in correctness_records if r.get("label") == "irrelevant") num_memories = len(delta_texts) correctness_rate = float(num_correct) / float(num_memories) if num_memories else 0.0 # --- Call 3+: Update handling (one LLM call per update gold point) --- update_records: list[dict[str, object]] = [] for g in session.gold_state.session_update_memories: res = evaluate_update_single( delta_str, new_content=g.memory_content, old_contents=list(g.original_memories), ) update_records.append({ "memory_id": g.memory_id, "label": res["label"], "reasoning": res["reasoning"], }) num_updated = sum(1 for r in update_records if r["label"] == "updated") num_both = sum(1 for r in update_records if r["label"] == "both") num_outdated = sum(1 for r in update_records if r["label"] == "outdated") update_total_items = len(update_records) # Score: updated=1.0, both=0.5, outdated=0.0 update_score = ( (num_updated * 1.0 + num_both * 0.5) / update_total_items if update_total_items else None ) # --- Call 4+: Interference rejection (one LLM call per interference gold point) --- interference_records: list[dict[str, object]] = [] for g in session.gold_state.session_interference_memories: res = evaluate_interference_single( delta_str, interference_content=g.memory_content, ) interference_records.append({ "memory_id": g.memory_id, "label": res["label"], "reasoning": res["reasoning"], }) num_rejected = sum(1 for r in interference_records if r["label"] == "rejected") num_memorized = sum(1 for r in interference_records if r["label"] == "memorized") interference_total_items = len(interference_records) # Score: rejected=1.0, memorized=0.0 interference_score = ( float(num_rejected) / interference_total_items if interference_total_items else None ) return { "session_id": session.session_id, "recall": recall, "covered_count": covered, "num_gold": total_gold, "update_recall": update_recall, "update_covered_count": upd_covered, "update_total": upd_total, "recall_reasoning": recall_result.get("reasoning", ""), "correctness_rate": correctness_rate, "num_memories": num_memories, "num_correct": num_correct, "num_hallucination": num_hallucination, "num_irrelevant": num_irrelevant, "correctness_reasoning": correctness_result.get("reasoning", ""), "correctness_records": correctness_records, # Update handling "update_score": update_score, "update_num_updated": num_updated, "update_num_both": num_both, "update_num_outdated": num_outdated, "update_total_items": update_total_items, "update_records": update_records, # Interference rejection "interference_score": interference_score, "interference_num_rejected": num_rejected, "interference_num_memorized": num_memorized, "interference_total_items": interference_total_items, "interference_records": interference_records, }