"""
UndertriAI — Dataset Preparation

Converts indian_bail_judgments.csv -> structured JSONL episode files for
4 curriculum stages.

Usage:
    python prepare_dataset.py --csv /path/to/indian_bail_judgments.csv --output ./episodes
"""

import argparse
import ast
import csv
import json
import os
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple

# Deterministic shuffling so episode ordering is reproducible across runs.
random.seed(42)

# Regions whose episodes are eligible for schema drift (also curriculum stage 4).
# Single source of truth — used by both assign_stage() and build_episode().
SCHEMA_DRIFT_REGIONS = {"Assam", "Tamil Nadu", "Kerala", "Punjab", "Maharashtra"}

# IPC/BNS section number -> approximate maximum sentence in years.
# 99 stands in for life imprisonment / death-penalty sections.
_SECTION_MAX_YEARS = {
    "302": 99, "103": 99, "307": 10, "109": 10, "376": 14, "64": 14,
    "304B": 14, "80": 14, "395": 10, "310": 10, "392": 10, "309": 10,
    "420": 7, "318": 7, "498A": 3, "85": 3, "406": 3, "316": 3,
    "465": 2, "336": 2, "323": 1, "115": 1, "354": 2, "74": 2,
    "120B": 3, "61": 3, "506": 2, "351": 2,
}


def parse_list_field(val: str) -> List[str]:
    """Parse a CSV cell holding a Python-literal list into a list of strings.

    Empty/blank cells yield []. If the cell is not a valid Python literal,
    fall back to naive comma-splitting of the bracket-stripped text.
    """
    if not val or val.strip() in ("", "[]"):
        return []
    try:
        result = ast.literal_eval(val)
        # A bare scalar literal (e.g. "'abc'" or "42") is not a collection;
        # wrap it rather than iterating its characters.
        if not isinstance(result, (list, tuple, set)):
            result = [result]
        return [str(x).strip() for x in result if str(x).strip()]
    except Exception:
        # Best-effort fallback for malformed cells.
        return [s.strip() for s in val.strip("[]").split(",") if s.strip()]


def parse_bool(val: str) -> bool:
    """Interpret common CSV truthy spellings ("true", "1", "yes") as True."""
    return str(val).strip().lower() in ("true", "1", "yes")


def split_arguments(legal_issues: List[str]) -> Tuple[List[str], List[str]]:
    """Heuristically split legal issues into (prosecution, defence) arguments.

    Keyword matching assigns clearly-sided issues; issues matching neither
    side are alternated between the two lists to keep them balanced.
    """
    prosecution_keys = ("cancel", "reject", "deny", "gravity", "tamper",
                        "abscond", "repeat", "custodial", "investigate")
    defence_keys = ("grant", "parity", "cooperat", "local", "surety",
                    "eligible", "half", "undertrial")
    prosecution: List[str] = []
    defence: List[str] = []
    for i, issue in enumerate(legal_issues):
        low = issue.lower()
        if any(k in low for k in prosecution_keys):
            prosecution.append(issue)
        elif any(k in low for k in defence_keys):
            defence.append(issue)
        else:
            # Ambiguous issue: alternate sides by position.
            (prosecution if i % 2 == 0 else defence).append(issue)
    return prosecution, defence


def infer_custody_months(facts: str, bail_type: str) -> float:
    """Estimate months already spent in custody from the facts text.

    Looks for "N years/months in custody|prison|jail"; when the facts are
    silent, falls back to a per-bail-type prior (4.0 for unknown types).
    """
    patterns = [
        (r"(\d+)\s+years?\s+in\s+(?:custody|prison|jail)", 12),
        (r"(\d+)\s+months?\s+in\s+(?:custody|prison|jail)", 1),
    ]
    for pat, mult in patterns:
        m = re.search(pat, facts, re.IGNORECASE)
        if m:
            return float(int(m.group(1)) * mult)
    return {"Regular": 6.0, "Anticipatory": 0.0, "Interim": 1.0}.get(bail_type, 4.0)


def infer_max_sentence(sections: List[str]) -> float:
    """Return the highest known maximum sentence (in years) across sections.

    Unknown sections contribute 0; when nothing matches (or the list is
    empty), a neutral default of 5.0 years is returned.
    """
    best = max((_SECTION_MAX_YEARS.get(s.strip(), 0) for s in sections), default=0)
    return float(best) if best else 5.0


def assign_stage(row: Dict[str, Any]) -> int:
    """Assign a curriculum stage (1-4) to a case row.

    Stage 1: landmark cases that are not cancellations.
    Stage 3: bail-cancellation cases.
    Stage 4: schema-drift-eligible regions.
    Stage 2: everything else.
    """
    landmark = parse_bool(row.get("landmark_case", "False"))
    cancel = parse_bool(row.get("bail_cancellation_case", "False"))
    region = row.get("region", "")
    if landmark and not cancel:
        return 1
    if cancel:
        return 3
    if region in SCHEMA_DRIFT_REGIONS:
        return 4
    return 2


def build_episode(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one CSV row into a structured episode dict.

    Derives the ground-truth outcome, an implicit flight-risk label, the
    set of available documents, and custody/sentence estimates from the
    row's free-text fields.
    """
    sections = parse_list_field(row.get("ipc_sections", "[]"))
    issues = parse_list_field(row.get("legal_issues", "[]"))
    pros, def_ = split_arguments(issues)
    facts = row.get("facts", "")
    btype = row.get("bail_type", "Regular")
    reason = row.get("judgment_reason", "").lower()

    gt_outcome = ("Bail Granted"
                  if row.get("bail_outcome", "").lower() == "granted"
                  else "Bail Denied")

    # Infer an implicit flight-risk label from the judgment reasoning.
    # "Low" keywords are checked first, so a reason mentioning both sides
    # resolves to Low (matches original precedence).
    if any(k in reason for k in ["not a flight", "local ties", "cooperat",
                                 "permanent resident"]):
        gt_risk = "Low"
    elif any(k in reason for k in ["abscond", "tamper", "influential",
                                   "intimidat", "repeat offend", "serious"]):
        gt_risk = "High"
    else:
        gt_risk = "Medium"

    # Base documents are always present; extras are keyed off the facts text.
    docs = ["FIR Copy", "Charge Sheet"]
    facts_low = facts.lower()
    if "surety" in facts_low:
        docs.append("Surety Affidavit")
    if "medical" in facts_low:
        docs.append("Medical Report")
    if "prior" in facts_low:
        docs.append("Criminal History Record")

    return {
        "case_id": row.get("case_id", ""),
        "case_title": row.get("case_title", ""),
        "court": row.get("court", ""),
        "date": row.get("date", ""),
        "charge_sheet": facts,
        "ipc_sections": sections,
        "crime_type": row.get("crime_type", "Unknown"),
        "bail_type": btype,
        "prosecution_arguments": pros,
        "defence_arguments": def_,
        "legal_principles": parse_list_field(row.get("legal_principles_discussed", "[]")),
        "documents_available": docs,
        "summary": row.get("summary", ""),
        "accused_profile": {
            "name": row.get("accused_name", "Unknown"),
            "gender": row.get("accused_gender", "Unknown"),
            "occupation": None,
            "region": row.get("region", "Unknown"),
            "prior_cases": row.get("prior_cases", "Unknown"),
            "bail_type": btype,
        },
        "custody_months": infer_custody_months(facts, btype),
        "max_sentence_years": infer_max_sentence(sections),
        "ground_truth": {
            "outcome": gt_outcome,
            "implicit_flight_risk": gt_risk,
            "judgment_reason": row.get("judgment_reason", ""),
            "outcome_detail": row.get("bail_outcome_label_detailed", ""),
            "bias_flag": parse_bool(row.get("bias_flag", "False")),
            "parity_argument_used": parse_bool(row.get("parity_argument_used", "False")),
        },
        "curriculum_stage": assign_stage(row),
        "landmark_case": parse_bool(row.get("landmark_case", "False")),
        "bail_cancellation_case": parse_bool(row.get("bail_cancellation_case", "False")),
        "region": row.get("region", "Unknown"),
        "special_laws": row.get("special_laws", ""),
        "schema_drift_eligible": row.get("region", "") in SCHEMA_DRIFT_REGIONS,
    }


def prepare(csv_path: str, output_dir: str) -> None:
    """Read the CSV, build episodes, and write per-stage + combined JSONL files.

    Malformed rows are skipped with a warning rather than aborting the run.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stages: Dict[int, List[Dict[str, Any]]] = {1: [], 2: [], 3: [], 4: []}

    with open(csv_path, "r", encoding="utf-8", errors="replace") as f:
        for row in csv.DictReader(f):
            try:
                ep = build_episode(row)
                stages[ep["curriculum_stage"]].append(ep)
            except Exception as e:
                # Best-effort ingestion: one bad row must not kill the batch.
                print(f" [WARN] Skipping {row.get('case_id')}: {e}")

    all_eps: List[Dict[str, Any]] = []
    for stage, eps in stages.items():
        random.shuffle(eps)
        out = os.path.join(output_dir, f"episodes_stage_{stage}.jsonl")
        with open(out, "w", encoding="utf-8") as fh:
            for ep in eps:
                fh.write(json.dumps(ep, ensure_ascii=False) + "\n")
        print(f" Stage {stage}: {len(eps):4d} episodes -> {out}")
        all_eps.extend(eps)

    # Combined file mixes all stages in a fresh shuffled order.
    random.shuffle(all_eps)
    with open(os.path.join(output_dir, "episodes_all.jsonl"), "w", encoding="utf-8") as fh:
        for ep in all_eps:
            fh.write(json.dumps(ep, ensure_ascii=False) + "\n")
    print(f"\nDone. Total: {len(all_eps)} episodes | stages: { {k:len(v) for k,v in stages.items()} }")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--csv", required=True)
    p.add_argument("--output", default="./episodes")
    args = p.parse_args()
    prepare(args.csv, args.output)