cricket-captain-llm / scripts /curate_transitions.py
pratinavseth's picture
feat: align benchmark data and training roadmap
d45f009
"""Cricsheet -> adaptive cricket datasets.
Builds two artifacts:
1. Existing Markov transition table for simulator compatibility.
2. Rich ball-outcome records for adaptive T20 captaincy modeling.
Usage:
python scripts/curate_transitions.py --format t20
python scripts/curate_transitions.py --local-zip path/to/t20s_json.zip --format t20
python scripts/curate_transitions.py --validate-only --format t20
"""
import argparse
import json
import os
import pickle
import sys
import urllib.request
import zipfile
from collections import defaultdict
from pathlib import Path
# Repository root: two directory levels above this script file.
_ROOT = Path(__file__).parent.parent
# Directory that receives the generated pickle artifacts.
_OUT_DIR = _ROOT / "data" / "processed"
# Output path of the Markov transition table. The name carries no format
# placeholder, so the same file is used whichever format is processed.
_TRANSITION_PATH = _OUT_DIR / "cricket_transitions_v1.pkl"
# File-name template for the rich ball-outcome records; {fmt} is the match format.
_BALL_OUTCOMES_TEMPLATE = "ball_outcomes_{fmt}_v1.pkl"
# Cricsheet bulk-download archive URL for each supported match format.
CRICSHEET_URLS = {
    "odi": "https://cricsheet.org/downloads/odis_json.zip",
    "t20": "https://cricsheet.org/downloads/t20s_json.zip",
}
def over_to_phase(over: int, fmt: str = "odi") -> str:
    """Map an over number to its innings phase.

    Args:
        over: Over number as stored in the source data (presumably the
            0-based Cricsheet over index -- confirm against the caller).
        fmt: Match format; ``"t20"`` uses the short-format cut-off, every
            other value uses the 50-over cut-off.

    Returns:
        ``"powerplay"`` for overs up to and including 5, ``"middle"`` up to
        and including 15 (t20) or 35 (otherwise), and ``"death"`` after that.
    """
    if over <= 5:
        return "powerplay"
    last_middle_over = 15 if fmt == "t20" else 35
    if over <= last_middle_over:
        return "middle"
    return "death"
# Substrings whose presence in a bowling-style description marks it as spin.
_SPIN_KEYWORDS = {"off", "leg", "slow", "orthodox", "chinaman", "googly", "doosra", "spin"}


def _is_spin(bowling_style: str) -> bool:
    """Return True when the style description contains a spin keyword.

    Matching is a case-insensitive substring test against ``_SPIN_KEYWORDS``.
    An empty or missing description is treated as not spin.
    """
    if not bowling_style:
        return False
    lowered = bowling_style.lower()
    for keyword in _SPIN_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def _dismissal_kind(delivery: dict) -> str | None:
wickets = delivery.get("wickets") or []
if not wickets:
return None
return str(wickets[0].get("kind", "unknown"))
def _is_legal_delivery(delivery: dict) -> bool:
    """Return True unless the delivery's extras contain a wide or a no-ball.

    A missing or null ``"extras"`` entry counts as no extras.
    """
    extras = delivery.get("extras") or {}
    return not ("wides" in extras or "noballs" in extras)
def _match_id(name: str, data: dict) -> str:
    """Build a match identifier of the form ``<file stem>:<first date>:<people count>``.

    Args:
        name: Archive member name (or any path-like string) of the match file.
        data: Parsed match dict; ``info.dates`` and ``info.registry.people``
            are read when present.

    Returns:
        The identifier string; ``"unknown-date"`` stands in when no date is
        listed and the people count is 0 when the registry is absent.
    """
    info = data.get("info", {})
    people = info.get("registry", {}).get("people", {})
    dates = info.get("dates", [])
    first_date = dates[0] if dates else "unknown-date"
    stem = Path(name).stem
    return f"{stem}:{first_date}:{len(people)}"
def download_zip(fmt: str, dest_path: str) -> None:
    """Download the Cricsheet archive for ``fmt`` and store it at ``dest_path``.

    The archive is written to a sibling ``<dest_path>.part`` file first and
    only moved to ``dest_path`` once the transfer has finished. An interrupted
    or truncated download therefore never leaves a broken zip at the location
    where a later run would look for a complete one.

    Args:
        fmt: Key into ``CRICSHEET_URLS`` selecting the archive.
        dest_path: Final location of the downloaded zip file.

    Raises:
        KeyError: If ``fmt`` has no entry in ``CRICSHEET_URLS``.
        OSError: On network or file-system failure (includes
            ``urllib.error.URLError`` and its subclasses).
    """
    url = CRICSHEET_URLS[fmt]
    print(f"Downloading {url} ...", flush=True)
    partial_path = f"{dest_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Same-directory rename: the finished file appears in one step.
        os.replace(partial_path, dest_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    size_mb = os.path.getsize(dest_path) / 1e6
    print(f"Downloaded {size_mb:.1f} MB -> {dest_path}")
def parse_match(data: dict, fmt: str = "odi", match_name: str = "unknown") -> list[dict]:
    """Return rich delivery records for both innings.

    Walks the first two innings of a parsed match dict ball by ball and emits
    one record per delivery (legal or not). Each record holds the innings
    state *before* the ball (score, wickets, legal balls bowled, chase
    target) together with the outcome of the ball itself.

    Args:
        data: Parsed match JSON. ``innings[*].overs[*].deliveries`` is walked;
            ``info.bowling_style`` (bowler name -> style text) is used when
            present to classify bowlers.
        fmt: Match format; ``"t20"`` means 20 overs per innings, anything else 50.
        match_name: Name of the source file, used to build ``match_id``.

    Returns:
        A list of dicts with the keys ``match_id``, ``format``,
        ``innings_index``, ``innings_type``, ``batting_team``, ``over``,
        ``ball_in_over``, ``legal_ball_index``, ``legal_delivery``,
        ``wickets_before``, ``score_before``, ``score_band``, ``target``,
        ``runs_required``, ``required_rate``, ``phase``, ``batter``,
        ``non_striker``, ``bowler``, ``bowler_type``, ``runs_batter``,
        ``runs_extras``, ``runs_total``, ``wicket`` and ``dismissal_kind``.
        Empty when the match has no innings.
    """
    innings = data.get("innings", [])
    if not innings:
        return []

    # Bowler name -> "spin" / "pace"; bowlers without a listed style fall
    # back to "pace" at lookup time.
    bowling_styles = data.get("info", {}).get("bowling_style", {})
    bowler_registry: dict[str, str] = {
        bowler: "spin" if _is_spin(style) else "pace"
        for bowler, style in bowling_styles.items()
    }

    # The identifier is the same for every ball of the match, so build it
    # once here rather than once per delivery.
    match_id = _match_id(match_name, data)
    max_overs = 20 if fmt == "t20" else 50
    innings_balls = max_overs * 6
    first_innings_score: int | None = None
    records: list[dict] = []

    for innings_idx, inning in enumerate(innings[:2]):
        team = inning.get("team", "")
        innings_type = "first" if innings_idx == 0 else "second"
        # NOTE(review): the chase target is always first-innings score + 1 and
        # the full over allocation is assumed; matches with a revised target
        # or reduced overs are presumably not special-cased -- confirm.
        target = (
            first_innings_score + 1
            if innings_idx == 1 and first_innings_score is not None
            else None
        )
        wickets_lost = 0
        total_runs = 0
        legal_balls = 0

        for over_block in inning.get("overs", []):
            over_num = int(over_block.get("over", 0))
            phase = over_to_phase(over_num, fmt)

            for ball_idx, delivery in enumerate(over_block.get("deliveries", [])):
                bowler = delivery.get("bowler", "")
                legal = _is_legal_delivery(delivery)

                # Chase state before this ball; None throughout the first innings.
                balls_remaining = max(0, innings_balls - legal_balls)
                if target is not None:
                    runs_required = max(0, target - total_runs)
                    # The 1e-6 floor avoids division by zero once no balls remain.
                    required_rate = round(runs_required / max(balls_remaining / 6, 1e-6), 4)
                else:
                    runs_required = None
                    required_rate = None

                runs = delivery.get("runs", {})
                runs_batter = int(runs.get("batter", 0))
                runs_extras = int(runs.get("extras", 0))
                runs_total = int(runs.get("total", runs_batter + runs_extras))
                dismissal_kind = _dismissal_kind(delivery)
                is_wicket = dismissal_kind is not None

                records.append({
                    "match_id": match_id,
                    "format": fmt,
                    "innings_index": innings_idx,
                    "innings_type": innings_type,
                    "batting_team": team,
                    "over": over_num,
                    "ball_in_over": ball_idx,
                    "legal_ball_index": legal_balls,
                    "legal_delivery": legal,
                    "wickets_before": min(wickets_lost, 9),
                    "score_before": total_runs,
                    "score_band": min(total_runs // 10, 49),
                    "target": target,
                    "runs_required": runs_required,
                    "required_rate": required_rate,
                    "phase": phase,
                    "batter": delivery.get("batter", ""),
                    "non_striker": delivery.get("non_striker", ""),
                    "bowler": bowler,
                    "bowler_type": bowler_registry.get(bowler, "pace"),
                    "runs_batter": runs_batter,
                    "runs_extras": runs_extras,
                    "runs_total": runs_total,
                    "wicket": is_wicket,
                    "dismissal_kind": dismissal_kind,
                })

                # Advance the running state only after the record is stored.
                total_runs += runs_total
                if legal:
                    legal_balls += 1
                if is_wicket:
                    wickets_lost += 1

        if innings_idx == 0:
            first_innings_score = total_runs

    return records
def build_table(all_records: list[dict]) -> dict:
    """Return transition table compatible with MarkovCricketEngine.

    Only legal deliveries are counted (records without a ``legal_delivery``
    field count as legal). Every record feeds two buckets: one keyed by
    ``(over, wickets_before, score_band, phase, bowler_type)`` and one with
    ``None`` in place of the bowler type.

    Returns:
        Mapping of key tuple to ``{"wicket_prob", "run_dist", "sample_size"}``
        where ``run_dist`` is the distribution of batter runs over the
        non-wicket balls of the bucket (``{0: 1.0}`` if every ball was a wicket).
    """
    tallies: dict = defaultdict(lambda: {"wickets": 0, "runs": defaultdict(int), "total": 0})

    for rec in all_records:
        if not rec.get("legal_delivery", True):
            continue
        state = (rec["over"], rec["wickets_before"], rec["score_band"], rec["phase"])
        # Bowler-specific bucket first, then the bowler-agnostic one.
        for bowler_type in (rec["bowler_type"], None):
            bucket = tallies[state + (bowler_type,)]
            bucket["total"] += 1
            if rec["wicket"]:
                bucket["wickets"] += 1
            else:
                bucket["runs"][rec["runs_batter"]] += 1

    table = {}
    for state_key, bucket in tallies.items():
        sample_size = bucket["total"]
        if sample_size == 0:
            continue
        dismissals = bucket["wickets"]
        survived = sample_size - dismissals
        if survived == 0:
            run_dist = {0: 1.0}
        else:
            run_dist = {
                int(scored): round(count / survived, 6)
                for scored, count in bucket["runs"].items()
            }
        table[state_key] = {
            "wicket_prob": round(dismissals / sample_size, 6),
            "run_dist": run_dist,
            "sample_size": sample_size,
        }
    return table
def run(local_zip: str | None = None, fmt: str = "t20") -> None:
    """Parse a Cricsheet archive and write both output artifacts.

    Args:
        local_zip: Path to an already-downloaded archive. When omitted the
            default location under ``data/`` is used and the archive is
            downloaded if it is not there yet.
        fmt: Match format passed through to parsing and used in file names.

    Exits the process via ``sys.exit`` when ``local_zip`` is given but does
    not exist. Files that fail to parse are counted and skipped; only the
    first five failures are printed.
    """
    _OUT_DIR.mkdir(parents=True, exist_ok=True)

    if local_zip:
        zip_path = local_zip
    else:
        zip_path = str(_ROOT / "data" / f"{fmt}s_json.zip")

    if not os.path.exists(zip_path):
        if local_zip:
            sys.exit(f"Zip not found: {local_zip}")
        download_zip(fmt, zip_path)

    print(f"Parsing {zip_path} ...", flush=True)
    all_records: list[dict] = []
    parsed = 0
    failed = 0
    with zipfile.ZipFile(zip_path) as archive:
        json_names = [member for member in archive.namelist() if member.endswith(".json")]
        file_total = len(json_names)
        print(f" {file_total} match files found")
        for name in json_names:
            try:
                match_data = json.loads(archive.read(name))
                all_records.extend(parse_match(match_data, fmt=fmt, match_name=name))
                parsed += 1
                if parsed % 200 == 0:
                    print(f" Parsed {parsed}/{file_total} matches ...", flush=True)
            except Exception as exc:
                # Best effort: one bad file must not abort the whole import.
                failed += 1
                if failed <= 5:
                    print(f" WARN: {name}: {exc}")
    print(f"Parsed {parsed} matches, {len(all_records):,} deliveries ({failed} errors)")

    # Artifact 1: the rich per-delivery records.
    ball_path = _OUT_DIR / _BALL_OUTCOMES_TEMPLATE.format(fmt=fmt)
    with open(ball_path, "wb") as out:
        pickle.dump(all_records, out)
    print(f"Wrote {ball_path} ({os.path.getsize(ball_path) / 1024:.0f} KB)")

    # Artifact 2: the aggregated transition table.
    print("Building transition table ...", flush=True)
    table = build_table(all_records)
    key4_count = sum(key[4] is None for key in table)
    key5_count = len(table) - key4_count
    high_conf = sum(entry["sample_size"] >= 50 for entry in table.values())
    print(f" {key5_count} key5 entries, {key4_count} key4 entries, {high_conf} with n>=50")
    with open(_TRANSITION_PATH, "wb") as out:
        pickle.dump(table, out)
    print(f"Wrote {_TRANSITION_PATH} ({os.path.getsize(_TRANSITION_PATH) / 1024:.0f} KB)")
def validate(fmt: str = "t20", transition_path: str = str(_TRANSITION_PATH)) -> None:
    """Sanity-check the written transition table and, if present, the rich records.

    Prints summary statistics of the table, verifies the first five entries
    (``run_dist`` sums to 1 within 1e-3, ``wicket_prob`` lies in [0, 1]) and
    checks that the ball-outcome file for ``fmt`` is non-empty and that its
    first record carries the expected fields.

    Args:
        fmt: Match format used to locate the ball-outcome file.
        transition_path: Path of the pickled transition table.

    Raises:
        AssertionError: If any check fails, including an empty table. The
            checks are explicit raises rather than ``assert`` statements so
            they still run under ``python -O``.
        OSError: If ``transition_path`` cannot be opened.
    """

    def _require(condition: bool, message: str) -> None:
        # Fail a validation check without relying on the assert statement.
        if not condition:
            raise AssertionError(message)

    print(f"Validating {transition_path} ...")
    # NOTE: unpickling can execute arbitrary code; only load artifacts that
    # were produced by this script.
    with open(transition_path, "rb") as f:
        table = pickle.load(f)

    total = len(table)
    # An empty table would otherwise fail below with an opaque min() error.
    _require(total > 0, f"Transition table is empty: {transition_path}")

    high_conf = sum(1 for v in table.values() if v["sample_size"] >= 50)
    phases = {k[3] for k in table}
    bowler_types = {k[4] for k in table}
    min_over = min(k[0] for k in table)
    max_over = max(k[0] for k in table)
    print(f" Total keys: {total}")
    print(f" High-confidence (n>=50): {high_conf} ({100 * high_conf / total:.1f}%)")
    print(f" Phases: {phases}")
    print(f" Bowler types: {bowler_types}")
    print(f" Over range: {min_over}-{max_over}")

    # Spot-check the first five entries only.
    for position, (key, v) in enumerate(table.items()):
        if position >= 5:
            break
        run_sum = sum(v["run_dist"].values())
        _require(abs(run_sum - 1.0) < 1e-3, f"run_dist doesn't sum to 1: {key}")
        _require(0 <= v["wicket_prob"] <= 1, f"wicket_prob out of range: {key}")

    ball_path = _OUT_DIR / _BALL_OUTCOMES_TEMPLATE.format(fmt=fmt)
    if ball_path.exists():
        with open(ball_path, "rb") as f:
            records = pickle.load(f)
        _require(bool(records), f"No records in {ball_path}")
        sample = records[0]
        required = {"innings_type", "target", "batter", "bowler", "dismissal_kind", "legal_delivery"}
        missing = required - set(sample)
        _require(not missing, f"Rich records missing fields: {missing}")
        print(f" Rich records: {len(records):,} at {ball_path}")
    print(" All checks passed.")
if __name__ == "__main__":
    # Command-line entry point: rebuild the artifacts unless --validate-only
    # is given, then always run validation for the chosen format.
    cli = argparse.ArgumentParser()
    cli.add_argument("--format", default="t20", choices=["t20", "odi"])
    cli.add_argument(
        "--local-zip",
        default=None,
        help="Path to already-downloaded Cricsheet JSON zip",
    )
    cli.add_argument("--validate-only", action="store_true")
    options = cli.parse_args()

    skip_build = options.validate_only
    if not skip_build:
        run(local_zip=options.local_zip, fmt=options.format)
    validate(fmt=options.format)