Spaces:
Running
Running
File size: 6,772 Bytes
4052d84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | """
UndertriAI — Dataset Preparation
Converts indian_bail_judgments.csv -> structured JSONL episode files for 4 curriculum stages.
Usage:
python prepare_dataset.py --csv /path/to/indian_bail_judgments.csv --output ./episodes
"""
import ast, json, os, argparse, random, re, csv
from pathlib import Path
from typing import Any, Dict, List, Tuple
# Seed the module-level RNG once at import time so that the random.shuffle
# calls used when writing the episode files give the same order on every run.
random.seed(42)
def parse_list_field(val: str) -> List[str]:
    """Parse a CSV cell that holds a list into a list of clean strings.

    The cell is first read as a Python literal (e.g. ``"['302', '34']"``).
    If that is not possible, or the literal is not a sequence, the raw text
    is split on commas after removing surrounding square brackets.

    Args:
        val: Raw cell text. Empty/None, whitespace-only text and ``"[]"``
            all mean "no items".

    Returns:
        The non-empty items, each converted to ``str`` and stripped.
    """
    if not val or val.strip() in ("", "[]"):
        return []
    try:
        result = ast.literal_eval(val)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a valid Python literal: fall through to the comma split below.
        result = None
    if isinstance(result, str):
        # A quoted scalar is a single item; iterating it would yield one
        # entry per character.
        text = result.strip()
        return [text] if text else []
    if isinstance(result, (list, tuple, set)):
        return [str(x).strip() for x in result if str(x).strip()]
    # Anything else (numbers, None, malformed text) is handled as plain text.
    return [s.strip() for s in val.strip("[]").split(",") if s.strip()]
def parse_bool(val: str) -> bool:
    """Interpret a CSV cell as a boolean flag.

    The value is converted to text, stripped and lower-cased; only
    "true", "1" and "yes" count as True, everything else is False.
    """
    truthy_words = ("true", "1", "yes")
    normalized = str(val).strip().lower()
    return normalized in truthy_words
def split_arguments(legal_issues: List[str]) -> Tuple[List[str], List[str]]:
    """Divide legal issues into prosecution-side and defence-side lists.

    Each issue is matched case-insensitively against keyword fragments.
    Prosecution keywords are checked first, then defence keywords. An issue
    matching neither goes to the prosecution when its position in the input
    is even and to the defence when it is odd.

    Returns:
        A ``(prosecution, defence)`` tuple; input order is kept in each list.
    """
    prosecution_markers = ["cancel", "reject", "deny", "gravity", "tamper",
                           "abscond", "repeat", "custodial", "investigate"]
    defence_markers = ["grant", "parity", "cooperat", "local", "surety",
                       "eligible", "half", "undertrial"]
    for_prosecution: List[str] = []
    for_defence: List[str] = []
    for position, text in enumerate(legal_issues):
        lowered = text.lower()
        if any(marker in lowered for marker in prosecution_markers):
            bucket = for_prosecution
        elif any(marker in lowered for marker in defence_markers):
            bucket = for_defence
        elif position % 2 == 0:
            bucket = for_prosecution
        else:
            bucket = for_defence
        bucket.append(text)
    return for_prosecution, for_defence
def infer_custody_months(facts: str, bail_type: str) -> float:
    """Estimate how many months the accused has spent in custody.

    The facts text is searched (case-insensitively) for a duration followed
    by "in custody", "in prison" or "in jail". A combined "X years and
    Y months" phrase is checked first so the years are not dropped, then a
    years-only phrase, then a months-only phrase. When nothing matches, a
    default keyed on the bail type is returned.

    Args:
        facts: Free-text case facts; ``None`` or empty is treated as
            "no duration stated".
        bail_type: Bail type label used to choose the fallback value.

    Returns:
        Custody duration in months. Fallbacks are 6.0 for "Regular", 0.0 for
        "Anticipatory", 1.0 for "Interim" and 4.0 for any other bail type.
    """
    text = facts or ""
    place = r"\s+in\s+(?:custody|prison|jail)"

    # "2 years and 3 months in custody" would otherwise match only the
    # months pattern and report 3 months.
    combined = re.search(
        r"(\d+)\s+years?\s*(?:,|and|&)?\s*(\d+)\s+months?" + place,
        text,
        re.IGNORECASE,
    )
    if combined:
        return float(int(combined.group(1)) * 12 + int(combined.group(2)))

    for pat, mult in [(r"(\d+)\s+years?" + place, 12),
                      (r"(\d+)\s+months?" + place, 1)]:
        m = re.search(pat, text, re.IGNORECASE)
        if m:
            return float(int(m.group(1)) * mult)
    return {"Regular": 6.0, "Anticipatory": 0.0, "Interim": 1.0}.get(bail_type, 4.0)
def infer_max_sentence(sections: List[str]) -> float:
    """Return the largest maximum sentence (in years) among the given sections.

    Args:
        sections: Section identifiers such as ``"302"`` or ``"498A"``.
            Surrounding whitespace and letter case are ignored.

    Returns:
        The highest table value among the recognised sections, always as a
        ``float``. 5.0 is returned when the list is empty or no section is
        recognised.
    """
    # NOTE(review): the keys look like IPC section numbers paired with their
    # BNS counterparts, and 99 looks like a stand-in for life/death
    # sentences - confirm with the dataset owner.
    years_by_section = {
        "302": 99, "103": 99, "307": 10, "109": 10, "376": 14, "64": 14,
        "304B": 14, "80": 14, "395": 10, "310": 10, "392": 10, "309": 10,
        "420": 7, "318": 7, "498A": 3, "85": 3, "406": 3, "316": 3,
        "465": 2, "336": 2, "323": 1, "115": 1, "354": 2, "74": 2,
        "120B": 3, "61": 3, "506": 2, "351": 2,
    }
    # Upper-casing lets "304b" match the "304B" key.
    normalized = (str(s).strip().upper() for s in sections)
    known = [years_by_section[key] for key in normalized if key in years_by_section]
    # float() keeps the result consistent with the annotation.
    return float(max(known, default=5.0))
def assign_stage(row: Dict[str, Any]) -> int:
    """Pick the curriculum stage (1-4) for a CSV row.

    Rules, in priority order:
      * bail-cancellation case            -> stage 3
      * landmark case (not cancellation)  -> stage 1
      * region is one of the listed states -> stage 4
      * anything else                     -> stage 2
    """
    drift_regions = {"Assam", "Tamil Nadu", "Kerala", "Punjab", "Maharashtra"}
    is_cancellation = parse_bool(row.get("bail_cancellation_case", "False"))
    is_landmark = parse_bool(row.get("landmark_case", "False"))

    if is_cancellation:
        return 3
    if is_landmark:
        return 1
    return 4 if row.get("region", "") in drift_regions else 2
def build_episode(row: Dict[str, Any]) -> Dict[str, Any]:
    """Turn one CSV row into a structured episode dictionary.

    Besides copying columns across, this derives:
      * prosecution/defence arguments, split from ``legal_issues``;
      * the ground-truth outcome ("Bail Granted" only when ``bail_outcome``
        is "granted", ignoring case and surrounding whitespace; any other
        value becomes "Bail Denied");
      * an implicit flight-risk label from keywords in ``judgment_reason``
        (low-risk keywords are checked before high-risk ones, otherwise
        "Medium");
      * the list of available documents, extended when the facts mention
        "surety", "medical" or "prior";
      * custody months, maximum sentence and curriculum stage via the
        module's helper functions.

    Args:
        row: Mapping of CSV column name to cell text.

    Returns:
        A JSON-serialisable episode dictionary.

    Raises:
        AttributeError: If ``facts``, ``judgment_reason`` or ``bail_outcome``
            is present but ``None`` rather than text.
    """
    sections = parse_list_field(row.get("ipc_sections", "[]"))
    issues = parse_list_field(row.get("legal_issues", "[]"))
    pros, def_ = split_arguments(issues)
    facts = row.get("facts", "")
    facts_lower = facts.lower()  # lower-cased once for the keyword checks
    btype = row.get("bail_type", "Regular")
    reason = row.get("judgment_reason", "").lower()

    # CSV cells keep stray whitespace; without strip() a value such as
    # "Granted " would be labelled "Bail Denied".
    outcome_text = row.get("bail_outcome", "").strip().lower()
    gt_outcome = "Bail Granted" if outcome_text == "granted" else "Bail Denied"

    if any(k in reason for k in ["not a flight", "local ties", "cooperat", "permanent resident"]):
        gt_risk = "Low"
    elif any(k in reason for k in ["abscond", "tamper", "influential", "intimidat", "repeat offend", "serious"]):
        gt_risk = "High"
    else:
        gt_risk = "Medium"

    docs = ["FIR Copy", "Charge Sheet"]
    for keyword, document in (("surety", "Surety Affidavit"),
                              ("medical", "Medical Report"),
                              ("prior", "Criminal History Record")):
        if keyword in facts_lower:
            docs.append(document)

    return {
        "case_id": row.get("case_id", ""),
        "case_title": row.get("case_title", ""),
        "court": row.get("court", ""),
        "date": row.get("date", ""),
        "charge_sheet": facts,
        "ipc_sections": sections,
        "crime_type": row.get("crime_type", "Unknown"),
        "bail_type": btype,
        "prosecution_arguments": pros,
        "defence_arguments": def_,
        "legal_principles": parse_list_field(row.get("legal_principles_discussed", "[]")),
        "documents_available": docs,
        "summary": row.get("summary", ""),
        "accused_profile": {
            "name": row.get("accused_name", "Unknown"),
            "gender": row.get("accused_gender", "Unknown"),
            "occupation": None,
            "region": row.get("region", "Unknown"),
            "prior_cases": row.get("prior_cases", "Unknown"),
            "bail_type": btype,
        },
        "custody_months": infer_custody_months(facts, btype),
        "max_sentence_years": infer_max_sentence(sections),
        "ground_truth": {
            "outcome": gt_outcome,
            "implicit_flight_risk": gt_risk,
            "judgment_reason": row.get("judgment_reason", ""),
            "outcome_detail": row.get("bail_outcome_label_detailed", ""),
            "bias_flag": parse_bool(row.get("bias_flag", "False")),
            "parity_argument_used": parse_bool(row.get("parity_argument_used", "False")),
        },
        "curriculum_stage": assign_stage(row),
        "landmark_case": parse_bool(row.get("landmark_case", "False")),
        "bail_cancellation_case": parse_bool(row.get("bail_cancellation_case", "False")),
        "region": row.get("region", "Unknown"),
        "special_laws": row.get("special_laws", ""),
        "schema_drift_eligible": row.get("region", "") in {"Assam", "Tamil Nadu", "Kerala", "Punjab", "Maharashtra"},
    }
def _write_jsonl(path: str, episodes: List[Dict[str, Any]]) -> None:
    """Write each episode as one JSON object per line to ``path`` (UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        for ep in episodes:
            f.write(json.dumps(ep, ensure_ascii=False) + "\n")


def prepare(csv_path: str, output_dir: str) -> None:
    """Convert the judgments CSV into per-stage and combined JSONL files.

    Every CSV row is converted with ``build_episode`` and grouped by its
    ``curriculum_stage``. Each stage list is shuffled and written to
    ``episodes_stage_<n>.jsonl``; the combined list is shuffled again and
    written to ``episodes_all.jsonl``. Rows that fail to convert are
    reported on stdout and skipped.

    Args:
        csv_path: Path to the input CSV file.
        output_dir: Directory for the output files; created if missing.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stages: Dict[int, list] = {1: [], 2: [], 3: [], 4: []}

    # newline="" is what the csv module requires so that quoted fields with
    # embedded line breaks are parsed correctly. "utf-8-sig" drops a leading
    # BOM (if any) so it does not become part of the first column name.
    with open(csv_path, "r", encoding="utf-8-sig", errors="replace", newline="") as f:
        for row in csv.DictReader(f):
            try:
                ep = build_episode(row)
                stages[ep["curriculum_stage"]].append(ep)
            except Exception as e:
                # Best effort: one bad row must not abort the whole export.
                print(f" [WARN] Skipping {row.get('case_id')}: {e}")

    all_eps = []
    for stage, eps in stages.items():
        random.shuffle(eps)
        out = os.path.join(output_dir, f"episodes_stage_{stage}.jsonl")
        _write_jsonl(out, eps)
        print(f" Stage {stage}: {len(eps):4d} episodes -> {out}")
        all_eps.extend(eps)

    random.shuffle(all_eps)
    _write_jsonl(os.path.join(output_dir, "episodes_all.jsonl"), all_eps)
    counts = {k: len(v) for k, v in stages.items()}
    print(f"\nDone. Total: {len(all_eps)} episodes | stages: {counts}")
if __name__ == "__main__":
    # Command-line entry point: --csv (required) is the input file and
    # --output (default ./episodes) is the destination directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", required=True)
    parser.add_argument("--output", default="./episodes")
    cli = parser.parse_args()
    prepare(cli.csv, cli.output)
|