Shabista Sehar committed on
Commit
a085ad1
·
1 Parent(s): d8f8a45
client.py CHANGED
@@ -19,7 +19,7 @@ from .models import (
19
  AssessSuretyAction, ClassifyBailTypeAction,
20
  ReadSubmissionsAction, AssessFlightRiskAction,
21
  CheckCaseFactorsAction, ApplyProportionalityAction,
22
- SubmitMemoAction,
23
  StepResult,
24
  )
25
 
@@ -141,5 +141,6 @@ __all__ = [
141
  "AssessFlightRiskAction",
142
  "CheckCaseFactorsAction",
143
  "ApplyProportionalityAction",
 
144
  "SubmitMemoAction",
145
  ]
 
19
  AssessSuretyAction, ClassifyBailTypeAction,
20
  ReadSubmissionsAction, AssessFlightRiskAction,
21
  CheckCaseFactorsAction, ApplyProportionalityAction,
22
+ PullCriminalHistoryAction, SubmitMemoAction,
23
  StepResult,
24
  )
25
 
 
141
  "AssessFlightRiskAction",
142
  "CheckCaseFactorsAction",
143
  "ApplyProportionalityAction",
144
+ "PullCriminalHistoryAction",
145
  "SubmitMemoAction",
146
  ]
demo_comparison.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UndertriAI — Before/After Demo Comparison Script
3
+
4
+ Demonstrates the environment using DEMO001 (Ramesh Kumar — IPC 420 cheating case).
5
+ Shows two simulated agent trajectories on the SAME case:
6
+ 1. Naive agent: skips tools, guesses wrong
7
+ 2. Skilled agent: uses tools properly, reaches correct conclusion
8
+
9
+ This script does NOT require a trained model — it simulates both agent
10
+ behaviors programmatically to show the reward difference.
11
+
12
+ Usage:
13
+ python demo_comparison.py
14
+ """
15
+
16
+ import sys
17
+ import os
18
+ import json
19
+
20
+ # Add parent of project root so relative imports within the package work
21
+ _project_root = os.path.dirname(os.path.abspath(__file__))
22
+ _parent = os.path.dirname(_project_root)
23
+ _pkg_name = os.path.basename(_project_root)
24
+ if _parent not in sys.path:
25
+ sys.path.insert(0, _parent)
26
+
27
+ # Import via package name (needed for relative imports in server/)
28
+ _env_mod = __import__(f"{_pkg_name}.server.undertrial_environment", fromlist=["UndertriAIEnvironment"])
29
+ UndertriAIEnvironment = _env_mod.UndertriAIEnvironment
30
+
31
+ _models_mod = __import__(f"{_pkg_name}.models", fromlist=[
32
+ "ComputeStatutoryEligibilityAction", "AssessFlightRiskAction",
33
+ "ReadSubmissionsAction", "CheckCaseFactorsAction", "SubmitMemoAction",
34
+ ])
35
+ ComputeStatutoryEligibilityAction = _models_mod.ComputeStatutoryEligibilityAction
36
+ AssessFlightRiskAction = _models_mod.AssessFlightRiskAction
37
+ ReadSubmissionsAction = _models_mod.ReadSubmissionsAction
38
+ CheckCaseFactorsAction = _models_mod.CheckCaseFactorsAction
39
+ SubmitMemoAction = _models_mod.SubmitMemoAction
40
+
41
+
42
+ def run_demo():
43
+ """Run before/after comparison on DEMO001."""
44
+ print("=" * 65)
45
+ print(" UndertriAI — Before vs After Training Demo")
46
+ print(" Case: DEMO001 — Ramesh Kumar vs State of Delhi (IPC 420)")
47
+ print("=" * 65)
48
+
49
+ env = UndertriAIEnvironment()
50
+
51
+ # ================================================================
52
+ # NAIVE AGENT (simulates untrained model behavior)
53
+ # ================================================================
54
+ print("\n" + "─" * 65)
55
+ print(" NAIVE AGENT (before training)")
56
+ print("─" * 65)
57
+
58
+ obs = env.reset(stage=1, seed=0)
59
+ print(f" Case: {obs.case_title}")
60
+ print(f" Crime: {obs.crime_type} | Sections: {obs.ipc_sections}")
61
+ print(f" Custody: {env._episode.get('custody_months')} months")
62
+
63
+ # Naive agent: calls one tool minimally, then submits wrong answer
64
+ print("\n Step 1: Read submissions (both)")
65
+ result = env.step(ReadSubmissionsAction(
66
+ party="both",
67
+ ))
68
+ print(f" → {result.observation.action_result[:80]}...")
69
+
70
+ # Naive agent gets the outcome WRONG (denies bail when it should be granted)
71
+ print("\n Step 2: Submit memo (WRONG — denies bail)")
72
+ result = env.step(SubmitMemoAction(
73
+ flight_risk="High",
74
+ flight_risk_justification="Accused may flee",
75
+ statutory_eligible=False,
76
+ statutory_computation="Unknown sections, cannot determine",
77
+ grounds_for_bail=["None identified"],
78
+ grounds_against_bail=["Serious charge"],
79
+ recommended_outcome="Bail Denied",
80
+ recommended_conditions=[],
81
+ ))
82
+ naive_reward = result.reward
83
+ naive_info = result.info
84
+ print(f"\n NAIVE REWARD: {naive_reward:.4f}")
85
+ print(f" Outcome match: {naive_info.get('outcome_match', 'N/A')}")
86
+ print(f" Flight risk acc: {naive_info.get('flight_risk_accuracy', 'N/A')}")
87
+ print(f" Statutory acc: {naive_info.get('statutory_accuracy', 'N/A')}")
88
+ print(f" Condition score: {naive_info.get('condition_appropriateness', 'N/A')}")
89
+ print(f" Bias penalty: {naive_info.get('bias_penalty', 'N/A')}")
90
+ print(f" Ground truth: {naive_info.get('ground_truth_outcome', 'N/A')}")
91
+
92
+ # ================================================================
93
+ # SKILLED AGENT (simulates trained model behavior)
94
+ # ================================================================
95
+ print("\n" + "─" * 65)
96
+ print(" SKILLED AGENT (after training)")
97
+ print("─" * 65)
98
+
99
+ obs = env.reset(stage=1, seed=0) # Same case
100
+ print(f" Case: {obs.case_title}")
101
+
102
+ # Skilled agent: uses multiple relevant tools
103
+ print("\n Step 1: Read submissions (both)")
104
+ result = env.step(ReadSubmissionsAction(party="both"))
105
+ print(f" → {result.observation.action_result[:80]}...")
106
+
107
+ print("\n Step 2: Compute statutory eligibility")
108
+ result = env.step(ComputeStatutoryEligibilityAction(
109
+ sections_invoked=["420"],
110
+ max_sentence_years=7.0,
111
+ custody_months=8.0,
112
+ special_law_applicable=False,
113
+ ))
114
+ print(f" → {result.observation.action_result[:100]}...")
115
+
116
+ print("\n Step 3: Assess flight risk")
117
+ result = env.step(AssessFlightRiskAction(
118
+ severity_of_offence="moderate",
119
+ roots_in_community="Permanent resident of Delhi, family with minor children",
120
+ prior_absconding=False,
121
+ passport_status="unknown",
122
+ ))
123
+ print(f" → {result.observation.action_result[:100]}...")
124
+
125
+ print("\n Step 4: Check case factors")
126
+ result = env.step(CheckCaseFactorsAction(
127
+ factors_to_check=["nature_of_offence", "criminal_history", "evidence_tampering"],
128
+ ))
129
+ print(f" → {result.observation.action_result[:100]}...")
130
+
131
+ # Skilled agent: correct outcome with proper reasoning
132
+ print("\n Step 5: Submit memo (CORRECT — grants bail with conditions)")
133
+ result = env.step(SubmitMemoAction(
134
+ flight_risk="Low",
135
+ flight_risk_justification=(
136
+ "Accused is a permanent resident of Delhi with family ties including "
137
+ "two minor children. No prior criminal record. IPC 420 is a moderate "
138
+ "offence. No evidence of prior absconding. Prosecution has not cited "
139
+ "any flight risk. Community roots are strong."
140
+ ),
141
+ statutory_eligible=False,
142
+ statutory_computation=(
143
+ "IPC Section 420: max sentence 7 years (84 months). "
144
+ "BNSS 479 threshold = 42 months (50%). "
145
+ "Time served = 8 months (9.5%). "
146
+ "Threshold NOT yet met — not eligible for default bail. "
147
+ "However, bail sought on merits, not statutory default."
148
+ ),
149
+ grounds_for_bail=[
150
+ "No prior criminal record — first-time offender",
151
+ "Permanent resident of Delhi with strong family ties",
152
+ "Two minor children dependent on accused",
153
+ "No flight risk identified by prosecution",
154
+ "Offence is non-violent (cheating, not bodily harm)",
155
+ ],
156
+ grounds_against_bail=[
157
+ "Investigation still pending per prosecution",
158
+ "Alleged fraud of Rs. 50,000",
159
+ ],
160
+ recommended_outcome="Bail Granted",
161
+ recommended_conditions=[
162
+ "Personal bond of Rs. 25,000 with one local surety",
163
+ "Weekly reporting to the concerned police station",
164
+ "Surrender passport if held",
165
+ "Not to leave Delhi without court permission",
166
+ "Cooperate with ongoing investigation",
167
+ ],
168
+ ))
169
+ skilled_reward = result.reward
170
+ skilled_info = result.info
171
+ print(f"\n SKILLED REWARD: {skilled_reward:.4f}")
172
+ print(f" Outcome match: {skilled_info.get('outcome_match', 'N/A')}")
173
+ print(f" Flight risk acc: {skilled_info.get('flight_risk_accuracy', 'N/A')}")
174
+ print(f" Statutory acc: {skilled_info.get('statutory_accuracy', 'N/A')}")
175
+ print(f" Condition score: {skilled_info.get('condition_appropriateness', 'N/A')}")
176
+ print(f" Bias penalty: {skilled_info.get('bias_penalty', 'N/A')}")
177
+ print(f" Ground truth: {skilled_info.get('ground_truth_outcome', 'N/A')}")
178
+
179
+ # ================================================================
180
+ # COMPARISON
181
+ # ================================================================
182
+ print("\n" + "═" * 65)
183
+ print(" COMPARISON SUMMARY")
184
+ print("═" * 65)
185
+ delta = skilled_reward - naive_reward
186
+ print(f" Naive agent reward: {naive_reward:.4f}")
187
+ print(f" Skilled agent reward: {skilled_reward:.4f}")
188
+ print(f" Improvement: {delta:+.4f} ({delta/max(0.01, abs(naive_reward))*100:+.0f}%)")
189
+ print()
190
+
191
+ # Component-by-component comparison
192
+ components = [
193
+ ("Outcome Match", "outcome_match"),
194
+ ("Flight Risk", "flight_risk_accuracy"),
195
+ ("Statutory", "statutory_accuracy"),
196
+ ("Conditions", "condition_appropriateness"),
197
+ ("Bias Penalty", "bias_penalty"),
198
+ ]
199
+ print(f" {'Component':<20} {'Naive':>8} {'Skilled':>8} {'Delta':>8}")
200
+ print(f" {'─'*20} {'─'*8} {'─'*8} {'─'*8}")
201
+ for name, key in components:
202
+ n = naive_info.get(key, 0)
203
+ s = skilled_info.get(key, 0)
204
+ d = s - n
205
+ sign = "+" if d >= 0 else ""
206
+ print(f" {name:<20} {n:>8.3f} {s:>8.3f} {sign}{d:>7.3f}")
207
+
208
+ print()
209
+ print(f" Ground truth: {skilled_info.get('ground_truth_outcome', '?')}")
210
+ print(f" Naive agent: Bail Denied (WRONG)")
211
+ print(f" Skilled agent: Bail Granted (CORRECT)")
212
+ print("═" * 65)
213
+
214
+ return {
215
+ "naive_reward": naive_reward,
216
+ "skilled_reward": skilled_reward,
217
+ "delta": delta,
218
+ }
219
+
220
+
221
+ if __name__ == "__main__":
222
+ results = run_demo()
models.py CHANGED
@@ -200,6 +200,7 @@ BailAction = Union[
200
  AssessFlightRiskAction,
201
  CheckCaseFactorsAction,
202
  ApplyProportionalityAction,
 
203
  SubmitMemoAction,
204
  ]
205
 
 
200
  AssessFlightRiskAction,
201
  CheckCaseFactorsAction,
202
  ApplyProportionalityAction,
203
+ PullCriminalHistoryAction,
204
  SubmitMemoAction,
205
  ]
206
 
openenv.yaml CHANGED
@@ -131,7 +131,7 @@ endpoints:
131
  training:
132
  method: GRPO
133
  framework: TRL + Unsloth
134
- model: unsloth/Qwen2.5-7B-Instruct
135
  notebook: training/UndertriAI_GRPO_Training.ipynb
136
  script: training/train_grpo.py
137
  modes:
 
131
  training:
132
  method: GRPO
133
  framework: TRL + Unsloth
134
+ model: unsloth/Qwen2.5-3B-Instruct
135
  notebook: training/UndertriAI_GRPO_Training.ipynb
136
  script: training/train_grpo.py
137
  modes:
pyproject.toml CHANGED
@@ -32,6 +32,7 @@ train = [
32
  "torch>=2.1.0",
33
  "datasets>=2.18.0",
34
  "transformers>=4.40.0",
 
35
  ]
36
 
37
  [project.scripts]
 
32
  "torch>=2.1.0",
33
  "datasets>=2.18.0",
34
  "transformers>=4.40.0",
35
+ "matplotlib>=3.7.0",
36
  ]
37
 
38
  [project.scripts]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # UndertriAI — Server dependencies
2
+ fastapi>=0.110.0
3
+ uvicorn[standard]>=0.27.0
4
+ pydantic>=2.6.0
5
+ websockets>=12.0
6
+ openenv-core>=0.1.0
7
+ matplotlib>=3.7.0
8
+ httpx>=0.27.0
server/app.py CHANGED
@@ -4,6 +4,7 @@ Wraps UndertriAIEnvironment as an OpenEnv-compatible HTTP + WebSocket server.
4
  """
5
 
6
  import os
 
7
  from pathlib import Path
8
  from dataclasses import dataclass, field
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -13,6 +14,8 @@ import json
13
  import uuid
14
  from typing import List, Optional
15
 
 
 
16
  from .undertrial_environment import UndertriAIEnvironment
17
  from .performance_tracker import PerformanceTracker
18
  from .adaptive_selector import AdaptiveSelector
@@ -215,6 +218,11 @@ def step(payload: dict):
215
  v["curriculum_stage"] = stage
216
  env.dataset._episodes.setdefault(stage, []).append(v)
217
  session.synthetic_cases_generated += len(variants)
 
 
 
 
 
218
 
219
  return {
220
  "session_id": session_id,
@@ -255,6 +263,7 @@ def list_tools():
255
  {"name": "assess_flight_risk", "description": "Systematic flight risk assessment with scoring matrix"},
256
  {"name": "check_case_factors", "description": "Examine specific case factors (parity, evidence tampering, victim vulnerability)"},
257
  {"name": "apply_proportionality", "description": "Apply BNSS 479 proportionality: custody vs. max sentence vs. trial timeline"},
 
258
  {"name": "submit_memo", "description": "TERMINAL — Submit structured bail assessment memo"},
259
  ]
260
  }
@@ -337,7 +346,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
337
  AssessSuretyAction, ClassifyBailTypeAction,
338
  ReadSubmissionsAction, AssessFlightRiskAction,
339
  CheckCaseFactorsAction, ApplyProportionalityAction,
340
- SubmitMemoAction,
341
  )
342
  ACTION_MAP = {
343
  "request_document": RequestDocumentAction,
@@ -350,6 +359,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
350
  "assess_flight_risk": AssessFlightRiskAction,
351
  "check_case_factors": CheckCaseFactorsAction,
352
  "apply_proportionality": ApplyProportionalityAction,
 
353
  "submit_memo": SubmitMemoAction,
354
  }
355
  action_cls = ACTION_MAP.get(tool_name)
 
4
  """
5
 
6
  import os
7
+ import logging
8
  from pathlib import Path
9
  from dataclasses import dataclass, field
10
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
14
  import uuid
15
  from typing import List, Optional
16
 
17
+ logger = logging.getLogger("undertrial")
18
+
19
  from .undertrial_environment import UndertriAIEnvironment
20
  from .performance_tracker import PerformanceTracker
21
  from .adaptive_selector import AdaptiveSelector
 
218
  v["curriculum_stage"] = stage
219
  env.dataset._episodes.setdefault(stage, []).append(v)
220
  session.synthetic_cases_generated += len(variants)
221
+ for v in variants:
222
+ logger.info(
223
+ f"Synthetic case generated: {v['case_id']} "
224
+ f"({v.get('perturbation_type', 'unknown')})"
225
+ )
226
 
227
  return {
228
  "session_id": session_id,
 
263
  {"name": "assess_flight_risk", "description": "Systematic flight risk assessment with scoring matrix"},
264
  {"name": "check_case_factors", "description": "Examine specific case factors (parity, evidence tampering, victim vulnerability)"},
265
  {"name": "apply_proportionality", "description": "Apply BNSS 479 proportionality: custody vs. max sentence vs. trial timeline"},
266
+ {"name": "pull_criminal_history", "description": "Pull accused's prior criminal record, bail history, and conviction status"},
267
  {"name": "submit_memo", "description": "TERMINAL — Submit structured bail assessment memo"},
268
  ]
269
  }
 
346
  AssessSuretyAction, ClassifyBailTypeAction,
347
  ReadSubmissionsAction, AssessFlightRiskAction,
348
  CheckCaseFactorsAction, ApplyProportionalityAction,
349
+ PullCriminalHistoryAction, SubmitMemoAction,
350
  )
351
  ACTION_MAP = {
352
  "request_document": RequestDocumentAction,
 
359
  "assess_flight_risk": AssessFlightRiskAction,
360
  "check_case_factors": CheckCaseFactorsAction,
361
  "apply_proportionality": ApplyProportionalityAction,
362
+ "pull_criminal_history": PullCriminalHistoryAction,
363
  "submit_memo": SubmitMemoAction,
364
  }
365
  action_cls = ACTION_MAP.get(tool_name)
server/performance_tracker.py CHANGED
@@ -37,6 +37,9 @@ class PerformanceTracker:
37
 
38
  Thread-safe for single-session use (no locks needed).
39
  All public methods handle missing/malformed input gracefully.
 
 
 
40
  """
41
 
42
  def __init__(self, alpha: float = 0.1):
 
37
 
38
  Thread-safe for single-session use (no locks needed).
39
  All public methods handle missing/malformed input gracefully.
40
+
41
+ NOTE: Tracker state is in-memory only. Server restart clears history.
42
+ For production: persist via tracker.get_profile() → JSON file on /reset.
43
  """
44
 
45
  def __init__(self, alpha: float = 0.1):
server/reward.py CHANGED
@@ -208,6 +208,11 @@ def compute_statutory_accuracy(
208
  return 0.0
209
 
210
  # ── Standard IPC/BNSS statutory scoring ──────────────────────────────
 
 
 
 
 
211
  # Compute ground-truth eligibility for cases with known custody duration
212
  half_sent_months = (max_sent * 12) / 2
213
  truly_eligible = (custody_mo >= half_sent_months) and not special_laws
@@ -244,6 +249,10 @@ def compute_statutory_accuracy(
244
  elif has_numbers or has_time_ref:
245
  score += 0.15 if direction_correct else 0.05
246
 
 
 
 
 
247
  return min(1.0, score)
248
 
249
 
 
208
  return 0.0
209
 
210
  # ── Standard IPC/BNSS statutory scoring ──────────────────────────────
211
+ # D4: Detect unreliable custody_months=6.0 default on serious crimes.
212
+ # 74% of episodes have custody_months=6.0 which may be a dataset default.
213
+ # Cap score at 0.60 to avoid rewarding threshold arithmetic on unreliable data.
214
+ custody_unreliable = (custody_mo == 6.0 and max_sent > 3.0)
215
+
216
  # Compute ground-truth eligibility for cases with known custody duration
217
  half_sent_months = (max_sent * 12) / 2
218
  truly_eligible = (custody_mo >= half_sent_months) and not special_laws
 
249
  elif has_numbers or has_time_ref:
250
  score += 0.15 if direction_correct else 0.05
251
 
252
+ # D4: Cap score when custody data is unreliable (likely dataset default)
253
+ if custody_unreliable:
254
+ score = min(score, 0.60)
255
+
256
  return min(1.0, score)
257
 
258
 
server/undertrial_environment.py CHANGED
@@ -307,11 +307,12 @@ class UndertriAIEnvironment(Environment):
307
 
308
  elif isinstance(action, AssessSuretyAction):
309
  feasible = action.proposed_amount <= (action.income_estimate or 50000) * 3
 
310
  return (
311
  f"Surety Assessment:\n"
312
  f" Proposed Amount: ₹{action.proposed_amount:,}\n"
313
  f" Accused Occupation: {action.accused_occupation}\n"
314
- f" Income Estimate: {action.income_estimate:,}/month\n"
315
  f" → {'FINANCIALLY FEASIBLE ✓' if feasible else 'AMOUNT MAY BE EXCESSIVE — consider reduction'}"
316
  )
317
 
 
307
 
308
  elif isinstance(action, AssessSuretyAction):
309
  feasible = action.proposed_amount <= (action.income_estimate or 50000) * 3
310
+ income_str = f"₹{action.income_estimate:,}/month" if action.income_estimate is not None else "Not provided"
311
  return (
312
  f"Surety Assessment:\n"
313
  f" Proposed Amount: ₹{action.proposed_amount:,}\n"
314
  f" Accused Occupation: {action.accused_occupation}\n"
315
+ f" Income Estimate: {income_str}\n"
316
  f" → {'FINANCIALLY FEASIBLE ✓' if feasible else 'AMOUNT MAY BE EXCESSIVE — consider reduction'}"
317
  )
318
 
training/train_grpo.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  UndertriAI — GRPO Training Script
3
- Fine-tunes Qwen2.5-7B-Instruct using Group Relative Policy Optimization
4
  against the UndertriAI bail assessment environment.
5
 
6
  Run in Google Colab (T4 GPU recommended):
@@ -373,8 +373,13 @@ def reward_conditions(completions: List[str], episode_batch: List[Dict], **kwarg
373
  if kw in cond_text:
374
  score = min(1.0, score + 0.04)
375
  else:
376
- # Denial should have empty conditions
377
- score = 1.0 if len(conditions) == 0 else 0.5
 
 
 
 
 
378
  scores.append(min(1.0, score))
379
  return scores
380
 
@@ -508,12 +513,19 @@ def load_episodes(
508
  test = last test_fraction
509
  """
510
  path = Path(episodes_dir) / f"episodes_stage_{stage}.jsonl"
 
511
  if not path.exists():
512
  path = Path(episodes_dir) / "episodes_all.jsonl"
 
513
  if not path.exists():
514
  raise FileNotFoundError(f"No episodes found in {episodes_dir}.")
515
  with open(path, encoding="utf-8") as f:
516
  all_eps = [json.loads(l) for l in f if l.strip()]
 
 
 
 
 
517
 
518
  n = len(all_eps)
519
  n_test = max(1, int(n * test_fraction))
@@ -678,14 +690,14 @@ def train(
678
  ):
679
  print("=" * 60)
680
  print(" UndertriAI — GRPO Training with Unsloth")
681
- print(f" Model: Qwen2.5-7B-Instruct | Stage: {stage}")
682
  print("=" * 60)
683
 
684
  # ── Load model ──────────────────────────────────────────
685
  from unsloth import FastLanguageModel # type: ignore
686
 
687
  model, tokenizer = FastLanguageModel.from_pretrained(
688
- model_name = "unsloth/Qwen2.5-7B-Instruct",
689
  max_seq_length = max_seq_len,
690
  load_in_4bit = True,
691
  fast_inference = False,
@@ -791,23 +803,101 @@ def train(
791
  model.save_pretrained(output_dir, save_adapters_only=True)
792
  tokenizer.save_pretrained(output_dir)
793
  print(f"\nModel adapters saved to {output_dir}")
 
 
 
 
794
  return results
795
 
796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
  # ============================================================
798
  # CELL 7 — Evaluate baseline (before training)
799
  # ============================================================
800
 
801
  def evaluate_baseline(episodes_dir: str, n_samples: int = 20):
802
  """
803
- Quick evaluation of a zero-shot Qwen2.5-7B-Instruct on bail cases.
804
  Run this BEFORE training to get the baseline reward curve starting point.
805
  """
806
  print("\nEvaluating zero-shot baseline...")
807
  from unsloth import FastLanguageModel # type: ignore
808
 
809
  model, tokenizer = FastLanguageModel.from_pretrained(
810
- model_name = "unsloth/Qwen2.5-7B-Instruct",
811
  max_seq_length = 3072,
812
  load_in_4bit = True,
813
  )
@@ -970,7 +1060,7 @@ def train_curriculum(
970
 
971
  # Load model once — reused across all stages
972
  model, tokenizer = FastLanguageModel.from_pretrained(
973
- model_name="unsloth/Qwen2.5-7B-Instruct",
974
  max_seq_length=3072,
975
  load_in_4bit=True,
976
  fast_inference=False,
@@ -1131,6 +1221,12 @@ def train_curriculum(
1131
  }, indent=2))
1132
  print(f" Results saved: {results_path}")
1133
 
 
 
 
 
 
 
1134
  return stage_results
1135
 
1136
 
@@ -1175,7 +1271,7 @@ def train_adaptive(
1175
 
1176
  # Load model once
1177
  model, tokenizer = FastLanguageModel.from_pretrained(
1178
- model_name="unsloth/Qwen2.5-7B-Instruct",
1179
  max_seq_length=3072,
1180
  load_in_4bit=True,
1181
  fast_inference=False,
@@ -1372,6 +1468,11 @@ def train_adaptive(
1372
  tokenizer.save_pretrained(final_dir)
1373
  print(f" Final model saved: {final_dir}")
1374
 
 
 
 
 
 
1375
  return results
1376
 
1377
 
 
1
  """
2
  UndertriAI — GRPO Training Script
3
+ Fine-tunes Qwen2.5-3B-Instruct using Group Relative Policy Optimization
4
  against the UndertriAI bail assessment environment.
5
 
6
  Run in Google Colab (T4 GPU recommended):
 
373
  if kw in cond_text:
374
  score = min(1.0, score + 0.04)
375
  else:
376
+ # Denial: empty conditions is correct ONLY when GT also denied
377
+ gt_outcome = ep.get("ground_truth", {}).get("outcome", "").lower()
378
+ gt_denied = "deni" in gt_outcome
379
+ if len(conditions) == 0:
380
+ score = 1.0 if gt_denied else 0.3 # H3: 0.3 not 1.0 when GT=granted
381
+ else:
382
+ score = 0.5 # Denied but listed conditions — inconsistent
383
  scores.append(min(1.0, score))
384
  return scores
385
 
 
513
  test = last test_fraction
514
  """
515
  path = Path(episodes_dir) / f"episodes_stage_{stage}.jsonl"
516
+ use_all_fallback = False
517
  if not path.exists():
518
  path = Path(episodes_dir) / "episodes_all.jsonl"
519
+ use_all_fallback = True
520
  if not path.exists():
521
  raise FileNotFoundError(f"No episodes found in {episodes_dir}.")
522
  with open(path, encoding="utf-8") as f:
523
  all_eps = [json.loads(l) for l in f if l.strip()]
524
+ # H1: filter by curriculum_stage when falling back to episodes_all.jsonl
525
+ if use_all_fallback:
526
+ filtered = [ep for ep in all_eps if ep.get("curriculum_stage") == stage]
527
+ if filtered:
528
+ all_eps = filtered
529
 
530
  n = len(all_eps)
531
  n_test = max(1, int(n * test_fraction))
 
690
  ):
691
  print("=" * 60)
692
  print(" UndertriAI — GRPO Training with Unsloth")
693
+ print(f" Model: Qwen2.5-3B-Instruct | Stage: {stage}")
694
  print("=" * 60)
695
 
696
  # ── Load model ──────────────────────────────────────────
697
  from unsloth import FastLanguageModel # type: ignore
698
 
699
  model, tokenizer = FastLanguageModel.from_pretrained(
700
+ model_name = "unsloth/Qwen2.5-3B-Instruct",
701
  max_seq_length = max_seq_len,
702
  load_in_4bit = True,
703
  fast_inference = False,
 
803
  model.save_pretrained(output_dir, save_adapters_only=True)
804
  tokenizer.save_pretrained(output_dir)
805
  print(f"\nModel adapters saved to {output_dir}")
806
+
807
+ # Save training plots (C6)
808
+ save_training_plots(trainer.state.log_history, output_dir)
809
+
810
  return results
811
 
812
 
813
+
814
+ # ============================================================
815
+ # Plot saving utility (C6)
816
+ # ============================================================
817
+
818
+ def save_training_plots(log_history: list, output_dir: str) -> None:
819
+ """
820
+ Save training reward curve and loss plots.
821
+ Called at the end of train(), train_curriculum(), and train_adaptive().
822
+ """
823
+ try:
824
+ import matplotlib
825
+ matplotlib.use("Agg") # Non-interactive backend
826
+ import matplotlib.pyplot as plt
827
+ import numpy as np
828
+ except ImportError:
829
+ print("[WARNING] matplotlib not installed — skipping plot generation.")
830
+ return
831
+
832
+ plots_dir = Path(output_dir) / "plots"
833
+ plots_dir.mkdir(parents=True, exist_ok=True)
834
+
835
+ # Extract reward data from training log
836
+ steps = [e["step"] for e in log_history if "reward" in e]
837
+ rewards = [e["reward"] for e in log_history if "reward" in e]
838
+
839
+ if not steps:
840
+ print("[WARNING] No reward data in training log — skipping plots.")
841
+ return
842
+
843
+ # Plot 1: Reward curve
844
+ fig, ax = plt.subplots(figsize=(10, 5))
845
+ fig.patch.set_facecolor("#0a0d1a")
846
+ ax.set_facecolor("#0a0d1a")
847
+ ax.plot(steps, rewards, color="#6366f1", linewidth=1.5, alpha=0.6, label="Raw")
848
+ if len(rewards) > 5:
849
+ smooth = np.convolve(rewards, np.ones(5) / 5, mode="valid")
850
+ ax.plot(steps[2:-2], smooth, color="#14b8a6", linewidth=2, label="Smoothed")
851
+ ax.set_xlabel("Training Step", color="#94a3b8")
852
+ ax.set_ylabel("Reward", color="#94a3b8")
853
+ ax.set_title("UndertriAI — Training Reward Curve", color="#e2e8f0", pad=12)
854
+ ax.tick_params(colors="#94a3b8")
855
+ ax.grid(True, alpha=0.2)
856
+ ax.legend(facecolor="#111827", edgecolor="#1e2d45", labelcolor="#94a3b8")
857
+ for spine in ax.spines.values():
858
+ spine.set_color("#1e2d45")
859
+ fig.tight_layout()
860
+ reward_path = plots_dir / "reward_curve.png"
861
+ fig.savefig(str(reward_path), dpi=150, bbox_inches="tight", facecolor="#0a0d1a")
862
+ plt.close(fig)
863
+ print(f" Plot saved: {reward_path}")
864
+
865
+ # Plot 2: Loss curve (if available)
866
+ loss_steps = [e["step"] for e in log_history if "loss" in e]
867
+ loss_values = [e["loss"] for e in log_history if "loss" in e]
868
+ if loss_steps:
869
+ fig2, ax2 = plt.subplots(figsize=(10, 5))
870
+ fig2.patch.set_facecolor("#0a0d1a")
871
+ ax2.set_facecolor("#0a0d1a")
872
+ ax2.plot(loss_steps, loss_values, color="#f97316", linewidth=1.5)
873
+ ax2.set_xlabel("Training Step", color="#94a3b8")
874
+ ax2.set_ylabel("Loss", color="#94a3b8")
875
+ ax2.set_title("UndertriAI — Training Loss", color="#e2e8f0", pad=12)
876
+ ax2.tick_params(colors="#94a3b8")
877
+ ax2.grid(True, alpha=0.2)
878
+ for spine in ax2.spines.values():
879
+ spine.set_color("#1e2d45")
880
+ fig2.tight_layout()
881
+ loss_path = plots_dir / "training_loss.png"
882
+ fig2.savefig(str(loss_path), dpi=150, bbox_inches="tight", facecolor="#0a0d1a")
883
+ plt.close(fig2)
884
+ print(f" Plot saved: {loss_path}")
885
+
886
+
887
  # ============================================================
888
  # CELL 7 — Evaluate baseline (before training)
889
  # ============================================================
890
 
891
  def evaluate_baseline(episodes_dir: str, n_samples: int = 20):
892
  """
893
+ Quick evaluation of a zero-shot Qwen2.5-3B-Instruct on bail cases.
894
  Run this BEFORE training to get the baseline reward curve starting point.
895
  """
896
  print("\nEvaluating zero-shot baseline...")
897
  from unsloth import FastLanguageModel # type: ignore
898
 
899
  model, tokenizer = FastLanguageModel.from_pretrained(
900
+ model_name = "unsloth/Qwen2.5-3B-Instruct",
901
  max_seq_length = 3072,
902
  load_in_4bit = True,
903
  )
 
1060
 
1061
  # Load model once — reused across all stages
1062
  model, tokenizer = FastLanguageModel.from_pretrained(
1063
+ model_name="unsloth/Qwen2.5-3B-Instruct",
1064
  max_seq_length=3072,
1065
  load_in_4bit=True,
1066
  fast_inference=False,
 
1221
  }, indent=2))
1222
  print(f" Results saved: {results_path}")
1223
 
1224
+ # Save training plots (C6) — use last trainer's log
1225
+ try:
1226
+ save_training_plots(trainer.state.log_history, output_dir)
1227
+ except Exception:
1228
+ print(" [WARNING] Could not save training plots.")
1229
+
1230
  return stage_results
1231
 
1232
 
 
1271
 
1272
  # Load model once
1273
  model, tokenizer = FastLanguageModel.from_pretrained(
1274
+ model_name="unsloth/Qwen2.5-3B-Instruct",
1275
  max_seq_length=3072,
1276
  load_in_4bit=True,
1277
  fast_inference=False,
 
1468
  tokenizer.save_pretrained(final_dir)
1469
  print(f" Final model saved: {final_dir}")
1470
 
1471
+ # Save training plots (C6)
1472
+ # Build a synthetic log_history from reward_curve for adaptive mode
1473
+ adaptive_log = [{"step": s, "reward": r} for s, r in reward_curve]
1474
+ save_training_plots(adaptive_log, output_dir)
1475
+
1476
  return results
1477
 
1478