| import difflib |
| from typing import Any, Dict, List, Optional |
|
|
| from .bundle import load_bundle |
|
|
|
|
| def _normalize_for_compare(x: Any) -> Any: |
| if isinstance(x, dict): |
| return {k: _normalize_for_compare(x[k]) for k in sorted(x.keys())} |
| if isinstance(x, list): |
| return [_normalize_for_compare(v) for v in x] |
| return x |
|
|
|
|
def _event_core(ev: Dict[str, Any]) -> Any:
    """Project an event onto the fields used for comparison.

    Only ``kind``, ``step`` and ``payload`` are kept (a missing field becomes
    ``None``); the result is passed through ``_normalize_for_compare``.
    """
    core: Dict[str, Any] = {}
    for field in ("kind", "step", "payload"):
        core[field] = ev.get(field)
    return _normalize_for_compare(core)
|
|
|
|
def build_alignment(A_events: List[Dict[str, Any]], B_events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Pair the two event lists by position and label each pair.

    One row is produced per index up to the longer list's length. ``status``
    is ``"missing_in_A"`` / ``"missing_in_B"`` when that side has no event at
    the index, otherwise ``"same"`` or ``"diff"`` depending on whether the
    ``_event_core`` projections of both events are equal. Each row also
    carries ``kind`` and ``step`` of both sides (``None`` when absent).
    """

    def _field(event: Optional[Dict[str, Any]], name: str) -> Any:
        # Falsy events (None or an empty dict) contribute None.
        return event.get(name) if event else None

    len_a = len(A_events)
    len_b = len(B_events)
    aligned: List[Dict[str, Any]] = []

    for idx in range(max(len_a, len_b)):
        ev_a = A_events[idx] if idx < len_a else None
        ev_b = B_events[idx] if idx < len_b else None

        if ev_a is None:
            status = "missing_in_A"
        elif ev_b is None:
            status = "missing_in_B"
        elif _event_core(ev_a) == _event_core(ev_b):
            status = "same"
        else:
            status = "diff"

        aligned.append(
            {
                "i": idx,
                "status": status,
                "kind_a": _field(ev_a, "kind"),
                "step_a": _field(ev_a, "step"),
                "kind_b": _field(ev_b, "kind"),
                "step_b": _field(ev_b, "step"),
            }
        )

    return aligned
|
|
|
|
| def _json_diff(a: Any, b: Any, path: str = "") -> List[Dict[str, Any]]: |
| diffs: List[Dict[str, Any]] = [] |
|
|
| if type(a) != type(b): |
| diffs.append({"path": path or "$", "kind": "type", "a": str(type(a)), "b": str(type(b))}) |
| return diffs |
|
|
| if isinstance(a, dict): |
| akeys = set(a.keys()) |
| bkeys = set(b.keys()) |
| for k in sorted(akeys - bkeys): |
| diffs.append({"path": f"{path}.{k}" if path else k, "kind": "removed", "a": a[k], "b": None}) |
| for k in sorted(bkeys - akeys): |
| diffs.append({"path": f"{path}.{k}" if path else k, "kind": "added", "a": None, "b": b[k]}) |
| for k in sorted(akeys & bkeys): |
| diffs.extend(_json_diff(a[k], b[k], f"{path}.{k}" if path else k)) |
| return diffs |
|
|
| if isinstance(a, list): |
| n = max(len(a), len(b)) |
| for i in range(n): |
| pa = a[i] if i < len(a) else None |
| pb = b[i] if i < len(b) else None |
| if i >= len(a): |
| diffs.append({"path": f"{path}[{i}]", "kind": "added", "a": None, "b": pb}) |
| elif i >= len(b): |
| diffs.append({"path": f"{path}[{i}]", "kind": "removed", "a": pa, "b": None}) |
| else: |
| diffs.extend(_json_diff(pa, pb, f"{path}[{i}]")) |
| return diffs |
|
|
| if a != b: |
| diffs.append({"path": path or "$", "kind": "value", "a": a, "b": b}) |
| return diffs |
|
|
|
|
| def _classify_divergence(kind_a: Optional[str], kind_b: Optional[str]) -> str: |
| if kind_a != kind_b: |
| return "control-flow" |
| if kind_a in ("tool_call", "tool_result"): |
| return "tool" |
| if kind_a in ("memory_write", "memory_read"): |
| return "memory" |
| if kind_a in ("llm_sample", "llm_call"): |
| return "sampling" |
| if kind_a in ("guardrail",): |
| return "governance" |
| return "state" |
|
|
|
|
| def _text_delta(a: str, b: str) -> str: |
| a_lines = a.splitlines() |
| b_lines = b.splitlines() |
| diff = difflib.unified_diff(a_lines, b_lines, fromfile="A", tofile="B", lineterm="") |
| return "\n".join(diff) |
|
|
|
|
| def _extract_final_reward(events: List[Dict[str, Any]]) -> Optional[float]: |
| """ |
| Looks for last state_snapshot payload containing: |
| - payload.reward_total |
| - payload.metrics.reward_total |
| """ |
| for ev in reversed(events): |
| if ev.get("kind") != "state_snapshot": |
| continue |
| p = ev.get("payload", {}) or {} |
| if isinstance(p, dict): |
| rt = p.get("reward_total") |
| if isinstance(rt, (int, float)): |
| return float(rt) |
| m = p.get("metrics") |
| if isinstance(m, dict): |
| rt2 = m.get("reward_total") |
| if isinstance(rt2, (int, float)): |
| return float(rt2) |
| return None |
|
|
|
|
| def _event_link(manifest: Dict[str, Any], i: int) -> Optional[str]: |
| """ |
| Optional deep-link generation. |
| Supported: |
| - manifest.replay.base_url + manifest.replay.pattern with {run_id} and {i} |
| - manifest.run_url + ?i={i} |
| """ |
| run_id = manifest.get("run_id") |
| replay = manifest.get("replay") |
|
|
| if isinstance(replay, dict): |
| base = replay.get("base_url") |
| pattern = replay.get("pattern", "") |
| if isinstance(base, str) and isinstance(pattern, str) and run_id: |
| try: |
| return base.rstrip("/") + pattern.format(run_id=run_id, i=i) |
| except Exception: |
| return None |
|
|
| run_url = manifest.get("run_url") |
| if isinstance(run_url, str) and run_url: |
| |
| joiner = "&" if "?" in run_url else "?" |
| return f"{run_url}{joiner}i={i}" |
|
|
| return None |
|
|
|
|
def diff_bundles(zip_a: str, zip_b: str) -> Dict[str, Any]:
    """Load two bundles and report where their event streams differ.

    Args:
        zip_a: Passed to ``load_bundle`` to obtain run A.
        zip_b: Passed to ``load_bundle`` to obtain run B.

    Returns:
        A dict with four entries:

        - ``summary``: run metadata read from both manifests, event counts,
          the index of the first non-"same" alignment row (stored under both
          ``first_divergence_index`` and ``identical_until_index``; ``None``
          when every row is "same"), diff/missing row counts, final rewards
          with their delta (B minus A), and links for event 0.
        - ``class_counts``: number of differing events per divergence class,
          counted before the ``differences`` list is truncated.
        - ``alignment``: the full output of ``build_alignment``.
        - ``differences``: per-event details for the first 400 differing
          events among the indices both runs share. Each keeps at most 200
          field diffs, plus a unified diff of ``payload.text`` (cut to 20000
          characters) when both texts are strings and differ.
    """
    bundle_a = load_bundle(zip_a)
    bundle_b = load_bundle(zip_b)

    events_a = bundle_a.events
    events_b = bundle_b.events
    manifest_a = bundle_a.manifest
    manifest_b = bundle_b.manifest

    alignment = build_alignment(events_a, events_b)

    # Index of the first row that is not "same" (diff or missing), if any.
    first_div: Optional[int] = next(
        (row["i"] for row in alignment if row["status"] != "same"),
        None,
    )

    # Detailed diffs exist only for indices present in both runs.
    differences: List[Dict[str, Any]] = []
    for idx, (ev_a, ev_b) in enumerate(zip(events_a, events_b)):
        core_a = _event_core(ev_a)
        core_b = _event_core(ev_b)
        if core_a == core_b:
            continue

        kind_a = ev_a.get("kind")
        kind_b = ev_b.get("kind")
        entry: Dict[str, Any] = {
            "i": idx,
            "kind_a": kind_a,
            "kind_b": kind_b,
            "step_a": ev_a.get("step"),
            "step_b": ev_b.get("step"),
            "class": _classify_divergence(kind_a, kind_b),
            "diffs": _json_diff(core_a, core_b)[:200],
            "link_a": _event_link(manifest_a, idx),
            "link_b": _event_link(manifest_b, idx),
        }

        text_a = (ev_a.get("payload", {}) or {}).get("text")
        text_b = (ev_b.get("payload", {}) or {}).get("text")
        if isinstance(text_a, str) and isinstance(text_b, str) and text_a != text_b:
            entry["text_unified_diff"] = _text_delta(text_a, text_b)[:20000]

        differences.append(entry)

    statuses = [row["status"] for row in alignment]
    diff_count = statuses.count("diff")
    missing_count = statuses.count("missing_in_A") + statuses.count("missing_in_B")

    reward_a = _extract_final_reward(events_a)
    reward_b = _extract_final_reward(events_b)
    if reward_a is None or reward_b is None:
        reward_delta = None
    else:
        reward_delta = reward_b - reward_a

    # Tally every differing event, including those cut from the returned list.
    class_counts: Dict[str, int] = {}
    for entry in differences:
        label = entry["class"]
        class_counts[label] = class_counts.get(label, 0) + 1

    summary: Dict[str, Any] = {
        "run_a": manifest_a.get("run_id"),
        "run_b": manifest_b.get("run_id"),
        "framework_a": manifest_a.get("framework"),
        "framework_b": manifest_b.get("framework"),
        "model_a": manifest_a.get("model_id"),
        "model_b": manifest_b.get("model_id"),
        "events_a": len(events_a),
        "events_b": len(events_b),
        "first_divergence_index": first_div,
        "identical_until_index": first_div,
        "diff_event_count": diff_count,
        "missing_event_count": missing_count,
        "final_reward_a": reward_a,
        "final_reward_b": reward_b,
        "final_reward_delta": reward_delta,
        "run_link_a": _event_link(manifest_a, 0),
        "run_link_b": _event_link(manifest_b, 0),
    }

    return {
        "summary": summary,
        "class_counts": class_counts,
        "alignment": alignment,
        "differences": differences[:400],
    }