| |
| """ |
| Pre-training data quality inspector for Parlay JSONL episode files. |
| Read-only: loads JSONL and prints statistics and RED FLAGS. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import statistics |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
| def _safe_float(x: Any, default: float = 0.0) -> float: |
| try: |
| return float(x) |
| except (TypeError, ValueError): |
| return default |
|
|
|
|
| def _percentile_sorted(sorted_vals: list[float], p: float) -> float: |
| if not sorted_vals: |
| return 0.0 |
| k = (len(sorted_vals) - 1) * p / 100.0 |
| f = math.floor(k) |
| c = math.ceil(k) |
| if f == c: |
| return sorted_vals[int(k)] |
| return sorted_vals[f] * (c - k) + sorted_vals[c] * (k - f) |
|
|
|
|
| def _outcome_bucket(rec: dict[str, Any]) -> str: |
| tr = (rec.get("termination_reason") or "") or "" |
| tr_l = tr.lower() |
| if rec.get("deal_reached") is True or tr_l in ("deal_reached", "deal", "agreement"): |
| return "deal" |
| if "zopa_collapsed" in tr_l or tr_l == "zopa_collapsed": |
| return "zopa_collapsed" |
| if "walk" in tr_l or tr_l in ("walk_away", "walkaway"): |
| return "walk_away" |
| if tr_l in ("max_turns",) or (rec.get("deal_reached") is False and "max" in tr_l): |
| return "max_turns" |
| if tr_l: |
| return f"other:{tr_l}" |
| if rec.get("deal_reached") is False and rec.get("final_price") is None: |
| return "no_deal_or_unknown" |
| return "unknown" |
|
|
|
|
| def _utterance_lengths(conversation: Any) -> list[int]: |
| if not isinstance(conversation, list): |
| return [] |
| out: list[int] = [] |
| for turn in conversation: |
| if not isinstance(turn, dict): |
| continue |
| content = turn.get("content", "") |
| if isinstance(content, str) and content.strip(): |
| out.append(len(content)) |
| return out |
|
|
|
|
def main() -> None:
    """Load a Parlay episode JSONL file and print quality stats plus RED FLAGS.

    Read-only: never modifies the data file. Prints schema coverage,
    reward/efficiency/ToM statistics, outcome distribution, per-persona
    and per-scenario breakdowns, then a list of heuristic red flags.
    """
    parser = argparse.ArgumentParser(description="Inspect Parlay episode JSONL data quality")
    parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL file")
    args = parser.parse_args()

    path = Path(args.data)
    if not path.is_file():
        print(f"File not found: {path.resolve()}")
        print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
        return

    # Best-effort load: blank lines ignored, malformed JSON lines skipped
    # with a warning rather than aborting the whole inspection.
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"[WARN] Skipping line {line_no}: invalid JSON ({e})")

    n = len(records)
    print(f"=== Parlay data inspector: {path} ===")
    print(f"Total episode records: {n}\n")

    if n == 0:
        print("No records to analyze.")
        return

    # --- SCHEMA: how many records carry each expected top-level field ---
    missing_prompt = sum(1 for r in records if not str(r.get("prompt", "")).strip())
    missing_scenario = sum(1 for r in records if not str(r.get("scenario_id", "")).strip())
    missing_persona = sum(1 for r in records if not str(r.get("persona", "")).strip())
    missing_metadata = sum(1 for r in records if "metadata" not in r)
    print("SCHEMA")
    print(f" prompt present: {n - missing_prompt}/{n}")
    print(f" scenario_id present: {n - missing_scenario}/{n}")
    print(f" persona present: {n - missing_persona}/{n}")
    print(f" metadata key present: {n - missing_metadata}/{n} (audit checklist; generate_data may omit)")
    print()

    def cum_reward(r: dict[str, Any]) -> float:
        # Prefer 'cumulative_reward' when present, else fall back to 'reward'.
        if "cumulative_reward" in r:
            return _safe_float(r.get("cumulative_reward"))
        return _safe_float(r.get("reward"))

    rews = [cum_reward(r) for r in records]
    rews_sorted = sorted(rews)  # pre-sorted once for percentile lookups

    print("REWARD (total / cumulative - field 'reward' or 'cumulative_reward')")
    print(f" min: {min(rews):.4f}")
    print(f" max: {max(rews):.4f}")
    print(f" mean: {statistics.mean(rews):.4f}")
    # stdev needs >= 2 samples; report 0.0 for a single record.
    print(f" std: {statistics.stdev(rews) if len(rews) > 1 else 0.0:.4f}")
    print(f" p10: {_percentile_sorted(rews_sorted, 10):.4f}")
    print(f" p90: {_percentile_sorted(rews_sorted, 90):.4f}")
    print()

    # --- OUTCOMES: heuristic bucketing, most common first ---
    outcomes = [_outcome_bucket(r) for r in records]
    oc = Counter(outcomes)
    print("EPISODE OUTCOMES (best-effort from termination_reason + deal_reached)")
    for k, v in sorted(oc.items(), key=lambda x: -x[1]):
        print(f" {k}: {v} ({100.0 * v / n:.1f}%)")
    print()

    effs = [_safe_float(r.get("deal_efficiency"), 0.0) for r in records]
    toms = []
    for r in records:
        # Accept either field name; missing values count as 0.0.
        t = r.get("tom_accuracy_avg", r.get("tom_accuracy"))
        toms.append(_safe_float(t, 0.0))

    print("EFFICIENCY (deal_efficiency, 0-1)")
    if effs:
        print(f" mean: {statistics.mean(effs):.4f} min: {min(effs):.4f} max: {max(effs):.4f}")
    print()

    print("TOM (tom_accuracy_avg or tom_accuracy)")
    if toms:
        print(f" mean: {statistics.mean(toms):.4f} min: {min(toms):.4f} max: {max(toms):.4f}")
    print()

    # --- UTTERANCES: length stats; < 10 chars counts as degenerate ---
    all_lens: list[int] = []
    degenerate_turns = 0
    total_turns = 0
    for r in records:
        lens = _utterance_lengths(r.get("conversation"))
        all_lens.extend(lens)
        for L in lens:
            total_turns += 1
            if L < 10:
                degenerate_turns += 1

    print("UTTERANCE LENGTH (conversation[*].content)")
    if all_lens:
        print(f" mean chars/turn: {statistics.mean(all_lens):.1f}")
        print(f" turns < 10 chars: {degenerate_turns}/{total_turns} ({100.0 * degenerate_turns / max(1, total_turns):.1f}%)")
    else:
        print(" (no conversation utterances found)")
    print()

    # `or 0` guards against a present-but-None bluffs_caught value.
    bluff_pos = sum(1 for r in records if int(r.get("bluffs_caught", 0) or 0) > 0)
    drift_yes = sum(1 for r in records if r.get("drift_adapted") is True)

    print("BLUFF RATE: episodes with bluffs_caught > 0")
    print(f" {bluff_pos}/{n} ({100.0 * bluff_pos / n:.1f}%) (field may be missing in JSONL -> counted as 0)")
    print()
    print("DRIFT ADAPTATION: drift_adapted == True")
    print(f" {drift_yes}/{n} ({100.0 * drift_yes / n:.1f}%)")
    print()

    # --- GROUPING: bucket records by persona and scenario for breakdowns ---
    by_persona: dict[str, list[dict]] = defaultdict(list)
    by_scenario: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        p = str(r.get("persona", "??"))
        s = str(r.get("scenario_id", "??"))
        by_persona[p].append(r)
        by_scenario[s].append(r)

    print("=== PER-PERSONA ===")
    for p in sorted(by_persona.keys()):
        grp = by_persona[p]
        m = len(grp)
        pr = [cum_reward(x) for x in grp]
        pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
        pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / m
        print(
            f" {p}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
            f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
        )
    print()

    print("=== PER-SCENARIO ===")
    for s in sorted(by_scenario.keys()):
        grp = by_scenario[s]
        m = len(grp)
        pr = [cum_reward(x) for x in grp]
        pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
        pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / m
        print(
            f" {s}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
            f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
        )
    print()

    # --- RED FLAGS: threshold heuristics; each triggered one is printed ---
    print("=== RED FLAGS ===")
    flags: list[str] = []

    # Heavily negative rewards in a large share of episodes.
    bad_rew = sum(1 for x in rews if x < -50) / n
    if bad_rew > 0.30:
        flags.append(f"> 30% episodes with total reward < -50 ({100 * bad_rew:.1f}%)")

    # Too many episodes hitting the turn cap instead of terminating naturally.
    max_turns_rate = sum(1 for o in outcomes if o == "max_turns") / n
    if max_turns_rate > 0.40:
        flags.append(f"> 40% ending in max_turns ({100 * max_turns_rate:.1f}%)")

    drift_rate = drift_yes / n
    if drift_rate < 0.10:
        flags.append(f"< 10% drift_adapted ({100 * drift_rate:.1f}%)")

    # Per-group zero deal rate is only flagged with at least 3 samples.
    for p, grp in by_persona.items():
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
        if dr == 0.0 and len(grp) >= 3:
            flags.append(f"Persona {p!r} has 0% deal rate (n={len(grp)})")

    for s, grp in by_scenario.items():
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
        if dr == 0.0 and len(grp) >= 3:
            flags.append(f"Scenario {s!r} has 0% deal rate (n={len(grp)})")

    if all_lens and statistics.mean(all_lens) < 20.0:
        flags.append(f"Mean utterance length {statistics.mean(all_lens):.1f} chars < 20 (possibly degenerate)")

    # Suspiciously large single-episode reward — possible scale bug.
    if max(rews) > 400:
        flags.append(f"At least one episode with total reward > 400 (max={max(rews):.2f}) - check for scale bugs or rare combo")

    if missing_metadata == n:
        flags.append("No record has a top-level 'metadata' key (optional for training; audit asked for it)")

    if not flags:
        print(" (none triggered)")
    else:
        for f in flags:
            print(f" * {f}")
    print()
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|