"""Roll up per-session and per-QA evaluations into baseline-level summaries.

Recall & correctness: per-session average (not pooled cumulative).
Update handling & interference: pooled across sessions.
QA & evidence: pooled across questions.
"""

from __future__ import annotations

from collections.abc import Mapping, Sequence

# Answer labels that count as a valid QA judgement; any other label is only
# reflected in ``num_total``.
_VALID_QA_LABELS = ("Correct", "Hallucination", "Omission")

# Session-level counters that are summed across all sessions unconditionally.
_POOLED_SESSION_COUNTERS = (
    "num_memories",
    "num_correct",
    "num_hallucination",
    "num_irrelevant",
    "update_num_updated",
    "update_num_both",
    "update_num_outdated",
    "update_total_items",
    "interference_num_rejected",
    "interference_num_memorized",
    "interference_total_items",
)


def _safe_div(a: float, b: float) -> float:
    """Return ``a / b``, or ``0.0`` when ``b`` is zero (or otherwise falsy)."""
    return a / b if b else 0.0


def _or_zero(record: Mapping[str, object], key: str) -> object:
    """Return ``record[key]``, treating a missing key or explicit ``None`` as 0.

    Score fields already tolerate ``None``; counters go through this helper so
    that a ``None`` counter (e.g. JSON ``null``) does not crash ``int()``.
    """
    value = record.get(key)
    return 0 if value is None else value


def _mean(scores: Sequence[float]) -> float:
    """Return the arithmetic mean of ``scores``, or ``0.0`` when it is empty."""
    return _safe_div(sum(scores), len(scores))


def _summarize_sessions(
    session_evaluations: Sequence[Mapping[str, object]],
) -> dict[str, object]:
    """Build the recall, correctness, update and interference summaries."""
    recall_scores: list[float] = []
    update_recall_scores: list[float] = []
    correctness_scores: list[float] = []
    hallucination_rates: list[float] = []
    irrelevant_rates: list[float] = []

    # Scores that are averaged over the sessions which actually report them.
    optional_scores = (
        ("recall", recall_scores),
        ("update_recall", update_recall_scores),
        ("correctness_rate", correctness_scores),
    )

    pooled = dict.fromkeys(_POOLED_SESSION_COUNTERS, 0)
    total_covered = 0
    total_gold = 0

    for session in session_evaluations:
        for key, bucket in optional_scores:
            score = session.get(key)
            if score is not None:
                bucket.append(float(score))

        # Hallucination / irrelevant rates are only defined for sessions that
        # produced at least one memory.
        num_memories = int(_or_zero(session, "num_memories"))
        if num_memories > 0:
            hallucination_rates.append(
                float(_or_zero(session, "num_hallucination")) / num_memories
            )
            irrelevant_rates.append(
                float(_or_zero(session, "num_irrelevant")) / num_memories
            )

        # NOTE(review): num_gold is only pooled for sessions that report
        # covered_count, which keeps the two totals comparable -- confirm this
        # pairing is intended.
        covered = session.get("covered_count")
        if covered is not None:
            total_covered += int(covered)
            total_gold += int(_or_zero(session, "num_gold"))

        for key in _POOLED_SESSION_COUNTERS:
            pooled[key] += int(_or_zero(session, key))

    # An item where both old and new memories are kept earns half credit.
    update_credit = (
        pooled["update_num_updated"] * 1.0 + pooled["update_num_both"] * 0.5
    )

    return {
        "memory_recall": {
            "avg_recall": _mean(recall_scores),
            "avg_update_recall": _mean(update_recall_scores),
            "num_sessions_with_recall": len(recall_scores),
            "num_sessions_with_update": len(update_recall_scores),
            "total_covered": total_covered,
            "total_gold": total_gold,
        },
        "memory_correctness": {
            "avg_correctness": _mean(correctness_scores),
            "avg_hallucination": _mean(hallucination_rates),
            "avg_irrelevant": _mean(irrelevant_rates),
            # Counts sessions that reported a correctness_rate.
            "num_sessions": len(correctness_scores),
            "total_memories": pooled["num_memories"],
            "total_correct": pooled["num_correct"],
            "total_hallucination": pooled["num_hallucination"],
            "total_irrelevant": pooled["num_irrelevant"],
        },
        "update_handling": {
            "score": _safe_div(update_credit, pooled["update_total_items"]),
            "num_updated": pooled["update_num_updated"],
            "num_both": pooled["update_num_both"],
            "num_outdated": pooled["update_num_outdated"],
            "num_total": pooled["update_total_items"],
        },
        "interference_rejection": {
            "score": _safe_div(
                pooled["interference_num_rejected"],
                pooled["interference_total_items"],
            ),
            "num_rejected": pooled["interference_num_rejected"],
            "num_memorized": pooled["interference_num_memorized"],
            "num_total": pooled["interference_total_items"],
        },
    }


def _summarize_qa(
    qa_evaluations: Sequence[Mapping[str, object]],
) -> dict[str, object]:
    """Build the question-answering and evidence-coverage summaries."""
    label_counts = dict.fromkeys(_VALID_QA_LABELS, 0)
    num_questions = 0
    evidence_covered = 0
    evidence_total = 0

    for qa in qa_evaluations:
        num_questions += 1

        label = qa.get("answer_label")
        if label in _VALID_QA_LABELS:
            label_counts[label] += 1

        # NOTE(review): num_evidence is only pooled for questions that report
        # evidence_covered_count -- confirm this pairing is intended.
        covered = qa.get("evidence_covered_count")
        if covered is not None:
            evidence_covered += int(covered)
            evidence_total += int(_or_zero(qa, "num_evidence"))

    # Ratios are taken over validly-labelled questions, not all questions.
    num_valid = sum(label_counts.values())

    return {
        "question_answering": {
            "correct_ratio": _safe_div(label_counts["Correct"], num_valid),
            "hallucination_ratio": _safe_div(
                label_counts["Hallucination"], num_valid
            ),
            "omission_ratio": _safe_div(label_counts["Omission"], num_valid),
            "num_total": num_questions,
            "num_valid": num_valid,
        },
        "evidence_coverage": {
            "hit_rate": _safe_div(evidence_covered, evidence_total),
            "num_covered": evidence_covered,
            "num_total": evidence_total,
        },
    }


def aggregate_metrics(
    baseline_id: str,
    *,
    session_evaluations: Sequence[Mapping[str, object]] = (),
    qa_evaluations: Sequence[Mapping[str, object]] = (),
) -> dict[str, object]:
    """Aggregate all per-session and per-QA evaluations.

    Args:
        baseline_id: Identifier echoed back under ``"baseline_id"``.
        session_evaluations: Per-session evaluation mappings. Scores
            (``recall``, ``update_recall``, ``correctness_rate``) are averaged
            over the sessions that provide them; counters are summed.
        qa_evaluations: Per-question evaluation mappings, pooled across
            questions.

    Returns:
        A dict with ``baseline_id`` followed by the ``memory_recall``,
        ``memory_correctness``, ``update_handling``,
        ``interference_rejection``, ``question_answering`` and
        ``evidence_coverage`` sections. Every ratio is ``0.0`` when its
        denominator is zero.

    Missing keys and explicit ``None`` values are both treated as "not
    reported": scores are skipped and counters contribute 0.
    """
    return {
        "baseline_id": baseline_id,
        **_summarize_sessions(session_evaluations),
        **_summarize_qa(qa_evaluations),
    }