# Source: undertrial-ai/data/prepare_dataset.py
# Draken1606, commit 4052d84 ("first commit")
"""
UndertriAI — Dataset Preparation
Converts indian_bail_judgments.csv into structured JSONL episode files: one file per curriculum stage (1-4) plus a combined episodes_all.jsonl.
Usage:
python prepare_dataset.py --csv /path/to/indian_bail_judgments.csv --output ./episodes
"""
import ast, json, os, argparse, random, re, csv
from pathlib import Path
from typing import Any, Dict, List, Tuple
# Seed the module-level (global) `random` state so that any use of it in this
# script is reproducible from run to run.
random.seed(42)
def parse_list_field(val: str) -> List[str]:
    """Parse a CSV cell holding a list of strings into a clean Python list.

    The cell is normally a Python list literal such as ``"['IPC 302', 'IPC 34']"``.
    A single quoted string literal is returned as one item (not split into
    characters). Anything that is not a list literal falls back to a plain
    comma split with surrounding brackets and quotes removed.

    Args:
        val: Raw cell text; ``None`` and blank cells are accepted.

    Returns:
        Non-empty, whitespace-stripped items; ``[]`` for empty input.
    """
    if not val or val.strip() in ("", "[]"):
        return []
    text = val.strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the exceptions literal_eval documents for malformed input.
        parsed = None
    if isinstance(parsed, (list, tuple, set)):
        return [str(x).strip() for x in parsed if str(x).strip()]
    if isinstance(parsed, str):
        # e.g. "'IPC 302'": keep the string whole instead of iterating its chars.
        item = parsed.strip()
        return [item] if item else []
    # Malformed literal, bare number, etc.: split manually and drop stray quotes
    # left over from a broken list literal such as "['a', 'b".
    items = (s.strip().strip("'\"").strip() for s in text.strip("[]").split(","))
    return [s for s in items if s]
def parse_bool(val: str) -> bool:
    """Interpret a CSV cell as a boolean flag.

    Case-insensitive ``"true"``, ``"1"`` and ``"yes"`` (surrounding whitespace
    ignored) are truthy; every other value, including ``None``, is ``False``.
    """
    token = str(val).strip().lower()
    return token in {"true", "1", "yes"}
def split_arguments(legal_issues: List[str]) -> Tuple[List[str], List[str]]:
    """Heuristically assign each legal issue to the prosecution or the defence.

    An issue containing a prosecution keyword goes to the prosecution;
    otherwise one containing a defence keyword goes to the defence. Issues
    matching neither are alternated by their index in ``legal_issues``
    (even index -> prosecution, odd index -> defence).

    Keywords are lowercase substrings used as word stems, so ``"investigat"``
    matches both "investigate" and "investigation", and ``"deni"`` matches
    "denied", "denial" and "denies".

    Returns:
        ``(prosecution_arguments, defence_arguments)``, each in input order.
    """
    prosecution_keys = (
        "cancel", "reject", "deny", "deni", "gravity", "tamper",
        "abscond", "repeat", "custodial", "investigat",
    )
    defence_keys = (
        "grant", "parity", "cooperat", "local", "surety",
        "eligible", "half", "undertrial",
    )
    prosecution: List[str] = []
    defence: List[str] = []
    for i, issue in enumerate(legal_issues):
        low = issue.lower()
        if any(k in low for k in prosecution_keys):
            prosecution.append(issue)
        elif any(k in low for k in defence_keys):
            defence.append(issue)
        else:
            # Ambiguous issue: alternate sides by position in the full list.
            (prosecution if i % 2 == 0 else defence).append(issue)
    return prosecution, defence
def infer_custody_months(facts: str, bail_type: str) -> float:
    """Estimate how long the accused has been in custody, in months.

    Searches ``facts`` for "<N> year(s) in [judicial] custody|prison|jail"
    first, then the same phrase with month(s); ``N`` may be a decimal such as
    ``1.5``. If neither is found, falls back to a per-bail-type default:
    Regular 6, Anticipatory 0, Interim 1, anything else 4.

    Args:
        facts: Free-text case facts; ``None`` is treated as empty.
        bail_type: Bail category used only for the fallback default.
    """
    text = facts or ""
    for pattern, months_per_unit in (
        (r"(\d+(?:\.\d+)?)\s+years?\s+in\s+(?:judicial\s+)?(?:custody|prison|jail)", 12),
        (r"(\d+(?:\.\d+)?)\s+months?\s+in\s+(?:judicial\s+)?(?:custody|prison|jail)", 1),
    ):
        # The optional decimal part stops "1.5 years" from matching as "5 years".
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return float(match.group(1)) * months_per_unit
    return {"Regular": 6.0, "Anticipatory": 0.0, "Interim": 1.0}.get(bail_type, 4.0)
# Maximum sentence in years keyed by section number. Both IPC and BNS numbers
# appear (e.g. "302" and "103" both map to 99). NOTE(review): 99 presumably
# stands for life imprisonment / death; confirm with the dataset author.
_MAX_SENTENCE_YEARS: Dict[str, int] = {
    "302": 99, "103": 99, "307": 10, "109": 10, "376": 14, "64": 14, "304B": 14, "80": 14,
    "395": 10, "310": 10, "392": 10, "309": 10, "420": 7, "318": 7, "498A": 3, "85": 3,
    "406": 3, "316": 3, "465": 2, "336": 2, "323": 1, "115": 1, "354": 2, "74": 2,
    "120B": 3, "61": 3, "506": 2, "351": 2,
}
# Removes a leading "IPC"/"BNS"/"Section"/"Sec."/"u/s" and a trailing "IPC"/"BNS"
# (after upper-casing), so "IPC 302" or "Section 498a IPC" reduce to "302"/"498A".
_SECTION_AFFIXES = re.compile(r"^(?:IPC|BNS|SECTION|SEC\.?|U/S)\s*|\s*(?:IPC|BNS)$")


def infer_max_sentence(sections: List[str]) -> float:
    """Return the largest known maximum sentence (years) among ``sections``.

    Each section string is normalised (stripped, upper-cased, common
    IPC/BNS/Section affixes removed) before lookup. Returns ``5.0`` when the
    list is empty or no section is recognised.
    """
    best = 0.0
    for raw in sections:
        key = _SECTION_AFFIXES.sub("", raw.strip().upper())
        best = max(best, float(_MAX_SENTENCE_YEARS.get(key, 0)))
    return best or 5.0
def assign_stage(row: Dict[str, Any]) -> int:
    """Pick the curriculum stage (1-4) for a CSV row.

    Bail-cancellation cases are stage 3. Landmark cases that are not
    cancellations are stage 1. Remaining rows from Assam, Tamil Nadu, Kerala,
    Punjab or Maharashtra are stage 4, and everything else is stage 2.
    """
    if parse_bool(row.get("bail_cancellation_case", "False")):
        return 3
    if parse_bool(row.get("landmark_case", "False")):
        return 1
    drift_regions = {"Assam", "Tamil Nadu", "Kerala", "Punjab", "Maharashtra"}
    return 4 if row.get("region", "") in drift_regions else 2
def build_episode(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one CSV row into a structured episode dict.

    ``row`` maps column name -> cell text (as produced by ``csv.DictReader``).
    A cell may be ``None``, e.g. when a line has fewer fields than the header.
    Such cells are treated like the column's default instead of crashing on
    ``.lower()`` / ``re.search``.

    Returns:
        A JSON-serialisable dict. Its key order is the field order written to
        the JSONL output.
    """

    def get(key: str, default: str = "") -> Any:
        # dict.get() only falls back for *absent* keys; also map None -> default.
        value = row.get(key)
        return default if value is None else value

    sections = parse_list_field(get("ipc_sections", "[]"))
    issues = parse_list_field(get("legal_issues", "[]"))
    prosecution_args, defence_args = split_arguments(issues)
    facts = get("facts")
    facts_low = facts.lower()
    btype = get("bail_type", "Regular")
    region = get("region", "Unknown")
    reason_text = get("judgment_reason")
    reason = reason_text.lower()

    gt_outcome = "Bail Granted" if get("bail_outcome").strip().lower() == "granted" else "Bail Denied"

    # Flight-risk label inferred from the judgment wording. "Low" cues are
    # checked first, so a reason containing both kinds resolves to "Low".
    if any(k in reason for k in ["not a flight", "local ties", "cooperat", "permanent resident"]):
        gt_risk = "Low"
    elif any(k in reason for k in ["abscond", "tamper", "influential", "intimidat", "repeat offend", "serious"]):
        gt_risk = "High"
    else:
        gt_risk = "Medium"

    docs = ["FIR Copy", "Charge Sheet"]
    if "surety" in facts_low:
        docs.append("Surety Affidavit")
    if "medical" in facts_low:
        docs.append("Medical Report")
    # Whole word "prior", excluding the ubiquitous "prior to ..." phrasing (and
    # words like "priority"), which say nothing about criminal history.
    if re.search(r"\bprior\b(?!\s+to\b)", facts_low):
        docs.append("Criminal History Record")

    return {
        "case_id": get("case_id"), "case_title": get("case_title"),
        "court": get("court"), "date": get("date"),
        "charge_sheet": facts, "ipc_sections": sections,
        "crime_type": get("crime_type", "Unknown"), "bail_type": btype,
        "prosecution_arguments": prosecution_args, "defence_arguments": defence_args,
        "legal_principles": parse_list_field(get("legal_principles_discussed", "[]")),
        "documents_available": docs, "summary": get("summary"),
        "accused_profile": {
            "name": get("accused_name", "Unknown"), "gender": get("accused_gender", "Unknown"),
            "occupation": None, "region": region,
            "prior_cases": get("prior_cases", "Unknown"), "bail_type": btype,
        },
        "custody_months": infer_custody_months(facts, btype),
        "max_sentence_years": infer_max_sentence(sections),
        "ground_truth": {
            "outcome": gt_outcome, "implicit_flight_risk": gt_risk,
            "judgment_reason": reason_text,
            "outcome_detail": get("bail_outcome_label_detailed"),
            "bias_flag": parse_bool(get("bias_flag", "False")),
            "parity_argument_used": parse_bool(get("parity_argument_used", "False")),
        },
        "curriculum_stage": assign_stage(row),
        "landmark_case": parse_bool(get("landmark_case", "False")),
        "bail_cancellation_case": parse_bool(get("bail_cancellation_case", "False")),
        "region": region,
        "special_laws": get("special_laws"),
        "schema_drift_eligible": region in {"Assam", "Tamil Nadu", "Kerala", "Punjab", "Maharashtra"},
    }
def _write_jsonl(path: str, episodes: List[Dict[str, Any]]) -> None:
    """Write ``episodes`` to ``path`` as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ep, ensure_ascii=False) + "\n" for ep in episodes)


def prepare(csv_path: str, output_dir: str, seed: int = 42) -> None:
    """Build the curriculum JSONL files from the bail-judgments CSV.

    Writes ``episodes_stage_1.jsonl`` ... ``episodes_stage_4.jsonl`` and a
    combined ``episodes_all.jsonl`` into ``output_dir``, creating it if needed.
    A row whose conversion raises is reported and skipped rather than aborting
    the run.

    Args:
        csv_path: Path to the input CSV.
        output_dir: Destination directory for the JSONL files.
        seed: Seed for the private RNG that shuffles episode order. It keeps
            output reproducible regardless of other use of the global
            ``random`` state.
    """
    rng = random.Random(seed)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stages: Dict[int, List[Dict[str, Any]]] = {1: [], 2: [], 3: [], 4: []}
    skipped = 0
    # newline="" is required by the csv module so that quoted fields containing
    # line breaks parse correctly. "utf-8-sig" drops a leading BOM, which would
    # otherwise corrupt the first header name (e.g. "\ufeffcase_id").
    with open(csv_path, "r", encoding="utf-8-sig", errors="replace", newline="") as f:
        for row in csv.DictReader(f):
            try:
                ep = build_episode(row)
                stages[ep["curriculum_stage"]].append(ep)
            except Exception as e:  # best-effort: one bad row must not abort the run
                skipped += 1
                print(f" [WARN] Skipping {row.get('case_id')}: {e}")

    all_eps: List[Dict[str, Any]] = []
    for stage, eps in stages.items():
        rng.shuffle(eps)
        out = os.path.join(output_dir, f"episodes_stage_{stage}.jsonl")
        _write_jsonl(out, eps)
        print(f" Stage {stage}: {len(eps):4d} episodes -> {out}")
        all_eps.extend(eps)
    rng.shuffle(all_eps)
    _write_jsonl(os.path.join(output_dir, "episodes_all.jsonl"), all_eps)
    print(f"\nDone. Total: {len(all_eps)} episodes | stages: { {k: len(v) for k, v in stages.items()} }")
    if skipped:
        print(f" [WARN] Skipped {skipped} row(s) in total")
if __name__ == "__main__":
    # Command-line entry point: parse arguments and build the episode files.
    parser = argparse.ArgumentParser(
        description="Convert indian_bail_judgments.csv into per-stage JSONL episode files."
    )
    parser.add_argument("--csv", required=True, help="Path to indian_bail_judgments.csv")
    parser.add_argument(
        "--output",
        default="./episodes",
        help="Directory for the generated JSONL files (default: %(default)s)",
    )
    args = parser.parse_args()
    prepare(args.csv, args.output)