Shabista Sehar commited on
Commit
d8f8a45
·
1 Parent(s): 4855450

implemented

Browse files
README.md CHANGED
@@ -58,8 +58,11 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
58
  | Method | Endpoint | Description |
59
  |---|---|---|
60
  | `POST` | `/reset?stage=1` | Start a new episode (curriculum stage 1–4) |
 
61
  | `POST` | `/step` | Submit a tool call or final memo |
62
  | `GET` | `/state?session_id=...` | Inspect current episode state |
 
 
63
  | `GET` | `/health` | Health check |
64
  | `GET` | `/tools` | List available tools |
65
  | `WS` | `/ws/{session_id}` | WebSocket real-time feed |
@@ -74,6 +77,11 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
74
  | `classify_bail_type` | Determine regular / anticipatory / default bail |
75
  | `request_document` | Request additional case documents |
76
  | `flag_inconsistency` | Flag contradictions in the charge sheet |
 
 
 
 
 
77
  | `submit_memo` | **Terminal action** — submit final bail recommendation |
78
 
79
  ### 4-Stage Curriculum
@@ -87,13 +95,71 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
87
 
88
  ---
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ## Reward Function
91
 
92
  ```
93
- R = 0.4 × outcome_match
94
  + 0.2 × flight_risk_accuracy
95
  + 0.2 × statutory_accuracy
96
  + 0.2 × condition_appropriateness
 
 
97
  − 0.3 × bias_penalty
98
  ```
99
 
@@ -101,16 +167,20 @@ All components are **fully deterministic and rule-based** — no LLM-as-judge.
101
 
102
  | Component | Signal | Details |
103
  |---|---|---|
104
- | **Outcome Match** | 0.0 / 0.8 / 1.0 | Exact, directional, or wrong vs HC decision |
105
  | **Flight Risk** | 0–1 | Ordinal distance to ground-truth risk level |
106
- | **Statutory** | 0–1 | IPC/BNSS section, sentence threshold, custody duration |
107
  | **Conditions** | 0–1 | Appropriate bail conditions for crime/risk profile |
 
 
108
  | **Bias Penalty** | −0.3 | Fired if parity argument ignored in bias-flagged cases |
109
 
110
  ### Anti-Reward-Hacking Design
111
 
112
- - 5 independent reward signals (harder to simultaneously game all)
113
  - `GenerationInspectionCallback` prints raw completions every 25 training steps
 
 
114
  - Bias penalty operates as a separate signal, not folded into outcome
115
  - Schema drift (Stage 4) tests adaptability, not pattern memorisation
116
 
@@ -118,18 +188,110 @@ All components are **fully deterministic and rule-based** — no LLM-as-judge.
118
 
119
  ## Training
120
 
121
- Uses **GRPO** (Group Relative Policy Optimization) via TRL + Unsloth on `Qwen2.5-3B-Instruct`.
122
 
123
- ```bash
124
- # Run with before/after eval and results.json
125
- python training/train_grpo.py \
126
- --episodes_dir ./data/episodes \
127
- --stage 1 \
128
- --steps 200 \
129
- --eval_after
130
- ```
131
 
132
- Or use the Colab notebook: [`training/UndertriAI_GRPO_Training.ipynb`](training/UndertriAI_GRPO_Training.ipynb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  ### Training Architecture
135
 
@@ -146,7 +308,13 @@ Episode Dataset (JSONL)
146
 
147
  GRPO updates model weights
148
 
149
- GenerationInspectionCallback logs samples every 25 steps
 
 
 
 
 
 
150
  ```
151
 
152
  ---
@@ -178,21 +346,25 @@ env = from_hub("Draken1606/undertrial-ai")
178
  ```
179
  undertrial_ai/
180
  ├── server/
181
- │ ├── app.py # FastAPI routes
182
- │ ├── undertrial_environment.py # Environment logic
183
- │ ├── reward.py # 5-component deterministic reward
184
- │ ├── dataset.py # Curriculum-staged episode loader
185
- ── schema_drift.py # IPC → BNSS remapping (Stage 4)
 
 
 
186
  ├── training/
187
- │ ├── train_grpo.py # GRPO training script
188
  │ └── UndertriAI_GRPO_Training.ipynb # Colab notebook
189
  ├── data/
190
- │ └── episodes/ # 1,200 HC judgments across 4 stages
191
  ├── demo/
192
- │ └── index.html # Interactive demo UI
193
- ├── client.py # UndertriAIEnv HTTP client
194
- ├── models.py # Pydantic action/observation schemas
195
- ── Dockerfile # HF Spaces deployment
 
196
  ```
197
 
198
  ---
 
58
  | Method | Endpoint | Description |
59
  |---|---|---|
60
  | `POST` | `/reset?stage=1` | Start a new episode (curriculum stage 1–4) |
61
+ | `POST` | `/reset?adaptive=true&auto_stage=true` | Start episode with adaptive selection (Theme 4) |
62
  | `POST` | `/step` | Submit a tool call or final memo |
63
  | `GET` | `/state?session_id=...` | Inspect current episode state |
64
+ | `GET` | `/profile?session_id=...` | Agent performance profile (Theme 4) |
65
+ | `GET` | `/adaptive_status` | Adaptive mode capabilities & thresholds |
66
  | `GET` | `/health` | Health check |
67
  | `GET` | `/tools` | List available tools |
68
  | `WS` | `/ws/{session_id}` | WebSocket real-time feed |
 
77
  | `classify_bail_type` | Determine regular / anticipatory / default bail |
78
  | `request_document` | Request additional case documents |
79
  | `flag_inconsistency` | Flag contradictions in the charge sheet |
80
+ | `read_submissions` | Read prosecution/defence arguments on record |
81
+ | `assess_flight_risk` | Systematic flight risk scoring matrix |
82
+ | `check_case_factors` | Examine parity, evidence tampering, victim vulnerability |
83
+ | `apply_proportionality` | BNSS 479 custody vs. max sentence proportionality |
84
+ | `pull_criminal_history` | Prior record, bail history, conviction status |
85
  | `submit_memo` | **Terminal action** — submit final bail recommendation |
86
 
87
  ### 4-Stage Curriculum
 
95
 
96
  ---
97
 
98
+ ## Theme 4 — Self-Improvement
99
+
100
+ UndertriAI qualifies for Theme 4 through three mechanisms:
101
+
102
+ **1. Adaptive Curriculum Promotion**
103
+ The environment tracks per-domain and per-stage performance using exponential
104
+ moving averages. When the agent demonstrates consistent improvement
105
+ (Stage 1 mean reward ≥ 0.65 over 20 episodes), it automatically promotes
106
+ to the next curriculum stage. This is visible in training logs as:
107
+ ```
108
+ [SELF-IMPROVEMENT] Step 100: Promoted to Stage 2. Stage 1 mean reward: 0.710 → Stage 2 begins.
109
+ ```
110
+
111
+ **2. Weakness-Targeted Episode Selection**
112
+ In adaptive mode, the episode selector identifies the crime type where the
113
+ agent performs worst and serves proportionally more cases from that domain.
114
+ As the agent improves on weak domains, the selection distribution shifts —
115
+ the environment continuously finds and targets new weaknesses.
116
+
117
+ | Selection | Weight | Mechanism |
118
+ |---|---|---|
119
+ | Weakest domain | 60% | EMA-tracked per-crime-type reward |
120
+ | Failure replay | 30% | Re-serve cases with reward < 0.40 |
121
+ | Exploration | 10% | Uniform random (prevent overfitting) |
122
+
123
+ **3. Synthetic Case Generation**
124
+ When the agent masters a domain (mean reward > 0.70), the environment
125
+ generates harder synthetic variants using 5 perturbation types:
126
+
127
+ | Perturbation | What it tests |
128
+ |---|---|
129
+ | Custody escalation | Custody 2 months below threshold — forces careful statutory computation |
130
+ | Co-accused conflict | Opposite bail outcome for co-accused — tests parity reasoning |
131
+ | Section ambiguity | IPC ↔ BNSS section swap — tests schema drift adaptability |
132
+ | Evidence reversal | Key witness retracted — tests flight risk reassessment |
133
+ | Surety complexity | Non-resident surety — tests condition appropriateness |
134
+
135
+ **Live Demo — Self-Improvement in Action**
136
+ ```bash
137
+ # Start the server
138
+ python -m server.app
139
+
140
+ # In another terminal — start adaptive training
141
+ python training/train_grpo.py --adaptive --steps 50 --env_url http://localhost:8000
142
+ ```
143
+
144
+ Monitor progress via:
145
+ ```
146
+ GET /profile?session_id={id}
147
+ GET /adaptive_status
148
+ ```
149
+
150
+ Watch stage promotions in the training log.
151
+
152
+ ---
153
+
154
  ## Reward Function
155
 
156
  ```
157
+ R = 0.4 × outcome_match (gated by reasoning quality)
158
  + 0.2 × flight_risk_accuracy
159
  + 0.2 × statutory_accuracy
160
  + 0.2 × condition_appropriateness
161
+ + 0.1 × reasoning_quality (bonus)
162
+ + 0.05 × format_compliance (bonus)
163
  − 0.3 × bias_penalty
164
  ```
165
 
 
167
 
168
  | Component | Signal | Details |
169
  |---|---|---|
170
+ | **Outcome Match** | 0.0 / 0.8 / 1.0 | Exact, directional, or wrong vs HC decision — gated by `<think>` block |
171
  | **Flight Risk** | 0–1 | Ordinal distance to ground-truth risk level |
172
+ | **Statutory** | 0–1 | IPC/BNSS threshold computation, direction-gated, NDPS Section 37 aware |
173
  | **Conditions** | 0–1 | Appropriate bail conditions for crime/risk profile |
174
+ | **Reasoning Quality** | 0–1 | Anchoring + arithmetic + grounds specificity (10% bonus) |
175
+ | **Format Compliance** | 0–1 | XML tag adherence to system prompt (5% bonus) |
176
  | **Bias Penalty** | −0.3 | Fired if parity argument ignored in bias-flagged cases |
177
 
178
  ### Anti-Reward-Hacking Design
179
 
180
+ - 7 independent reward signals (harder to simultaneously game all)
181
  - `GenerationInspectionCallback` prints raw completions every 25 training steps
182
+ - Reasoning gate: no `<think>` block → outcome reward zeroed in Stage 2+
183
+ - Direction gate: wrong bail direction → statutory bonus capped
184
  - Bias penalty operates as a separate signal, not folded into outcome
185
  - Schema drift (Stage 4) tests adaptability, not pattern memorisation
186
 
 
188
 
189
  ## Training
190
 
191
+ Uses **GRPO** (Group Relative Policy Optimization) via TRL + Unsloth on `Qwen2.5-7B-Instruct`.
192
 
193
+ ### Training Modes
 
 
 
 
 
 
 
194
 
195
+ | Mode | Command | Description |
196
+ |---|---|---|
197
+ | Single stage | `python training/train_grpo.py --stage 1 --steps 200` | Train on one stage |
198
+ | Curriculum | `python training/train_grpo.py --curriculum --steps 150` | Sequential 4-stage with trace harvesting |
199
+ | **Adaptive** | `python training/train_grpo.py --adaptive --steps 50` | **Theme 4** — self-directed with auto-promotion |
200
+
201
+ ### Google Colab Training Walkthrough
202
+
203
+ ```python
204
+ # ============================================================
205
+ # STEP 1 — Install dependencies (run in first cell)
206
+ # ============================================================
207
+ !pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
208
+ !pip install -q --no-deps trl peft accelerate bitsandbytes xformers
209
+ !pip install -q openenv-core datasets
210
+
211
+ # ============================================================
212
+ # STEP 2 — Clone the repository
213
+ # ============================================================
214
+ !git clone https://github.com/Faiz-1606/Undertrial.git
215
+ %cd Undertrial
216
+
217
+ # ============================================================
218
+ # STEP 3 — Verify episodes are available
219
+ # ============================================================
220
+ import os
221
+ episodes_dir = "./data/episodes"
222
+ if not os.path.exists(episodes_dir):
223
+ print("No episodes directory — will use built-in demo episodes")
224
+ else:
225
+ for f in os.listdir(episodes_dir):
226
+ if f.endswith('.jsonl'):
227
+ count = sum(1 for _ in open(f"{episodes_dir}/{f}"))
228
+ print(f" {f}: {count} episodes")
229
+
230
+ # ============================================================
231
+ # STEP 4 — Option A: Single-stage training (quick, ~20 min on T4)
232
+ # ============================================================
233
+ !python training/train_grpo.py \
234
+ --episodes_dir ./data/episodes \
235
+ --stage 1 \
236
+ --steps 200 \
237
+ --batch_size 4 \
238
+ --eval_after
239
+
240
+ # ============================================================
241
+ # STEP 4 — Option B: Curriculum training (full, ~90 min on T4)
242
+ # ============================================================
243
+ !python training/train_grpo.py \
244
+ --episodes_dir ./data/episodes \
245
+ --curriculum \
246
+ --steps 150 \
247
+ --batch_size 4
248
+
249
+ # ============================================================
250
+ # STEP 4 — Option C: Adaptive training (Theme 4, ~60 min on T4)
251
+ # (Requires server running — start in a background cell first)
252
+ # ============================================================
253
+ # Background cell: start the server
254
+ import subprocess
255
+ server = subprocess.Popen(
256
+ ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"],
257
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE
258
+ )
259
+ import time; time.sleep(5) # Wait for server startup
260
+
261
+ # Then run adaptive training
262
+ !python training/train_grpo.py \
263
+ --adaptive \
264
+ --episodes_dir ./data/episodes \
265
+ --steps 50 \
266
+ --batch_size 4 \
267
+ --env_url http://localhost:8000
268
+
269
+ # ============================================================
270
+ # STEP 5 — View results
271
+ # ============================================================
272
+ import json
273
+ # For single/curriculum:
274
+ results = json.load(open("./output/undertrial_grpo/results.json"))
275
+ print(json.dumps(results, indent=2))
276
+
277
+ # For adaptive:
278
+ # results = json.load(open("./output/undertrial_grpo/results_adaptive.json"))
279
+
280
+ # ============================================================
281
+ # STEP 6 — (Optional) Merge LoRA adapters for inference
282
+ # ============================================================
283
+ from unsloth import FastLanguageModel
284
+ model, tokenizer = FastLanguageModel.from_pretrained(
285
+ "./output/undertrial_grpo/final",
286
+ max_seq_length=3072,
287
+ )
288
+ model.save_pretrained_merged(
289
+ "./output/undertrial_merged",
290
+ tokenizer,
291
+ save_method="merged_16bit",
292
+ )
293
+ print("Merged model saved to ./output/undertrial_merged")
294
+ ```
295
 
296
  ### Training Architecture
297
 
 
308
 
309
  GRPO updates model weights
310
 
311
+ [Theme 4] PerformanceTracker updates EMA per domain/stage
312
+
313
+ [Theme 4] AdaptiveSelector targets weakest domain
314
+
315
+ [Theme 4] CaseGenerator creates harder synthetic variants
316
+
317
+ [Theme 4] Auto-promote when stage EMA exceeds threshold
318
  ```
319
 
320
  ---
 
346
  ```
347
  undertrial_ai/
348
  ├── server/
349
+ │ ├── app.py # FastAPI routes + Theme 4 endpoints
350
+ │ ├── undertrial_environment.py # Environment logic
351
+ │ ├── reward.py # 7-component deterministic reward
352
+ │ ├── dataset.py # Curriculum-staged episode loader
353
+ ── schema_drift.py # IPC → BNSS remapping (Stage 4)
354
+ │ ├── performance_tracker.py # [Theme 4] EMA-based performance profiling
355
+ │ ├── adaptive_selector.py # [Theme 4] Weakness-targeted episode selection
356
+ │ └── case_generator.py # [Theme 4] Synthetic case perturbation
357
  ├── training/
358
+ │ ├── train_grpo.py # GRPO training (single/curriculum/adaptive)
359
  │ └── UndertriAI_GRPO_Training.ipynb # Colab notebook
360
  ├── data/
361
+ │ └── episodes/ # 1,200 HC judgments across 4 stages
362
  ├── demo/
363
+ │ └── index.html # Interactive demo UI
364
+ ├── client.py # UndertriAIEnv HTTP client
365
+ ├── models.py # Pydantic action/observation schemas
366
+ ── openenv.yaml # OpenEnv manifest
367
+ └── Dockerfile # HF Spaces deployment
368
  ```
369
 
370
  ---
openenv.yaml CHANGED
@@ -1,10 +1,12 @@
1
  name: undertrial-ai
2
- version: "1.0.0"
3
  description: >
4
- OpenEnv-compliant RL training environment for Indian bail decision support.
5
- An LLM agent reads High Court bail cases, invokes legal tools, and submits
6
- structured bail recommendations. Reward computed deterministically against
7
- real HC judgments with an explicit bias penalty (lambda=0.3).
 
 
8
 
9
  author: Draken1606
10
  license: MIT
@@ -19,6 +21,8 @@ tags:
19
  - world-modeling
20
  - bias-mitigation
21
  - bnss-2023
 
 
22
 
23
  environment:
24
  class: undertrial_ai.server.undertrial_environment.UndertriAIEnvironment
@@ -52,17 +56,18 @@ actions:
52
  description: "TERMINAL — Submit structured bail assessment memo"
53
 
54
  reward:
55
- formula: "0.3*outcome + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*reasoning_quality + 0.1*efficiency + 0.05*process_bonus - 0.3*bias"
56
  range: [-0.7, 1.15]
57
  terminal_action: submit_memo
58
  deterministic: true
59
  llm_as_judge: false
60
  components:
61
- - outcome_match: "Agreement with real High Court decision (30%)"
62
  - flight_risk_accuracy: "Flight risk classification accuracy (20%)"
63
- - statutory_accuracy: "IPC/BNSS threshold computation (20%)"
64
  - condition_appropriateness: "Bail condition quality (20%)"
65
- - reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10%)"
 
66
  - bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
67
 
68
  curriculum:
@@ -72,12 +77,70 @@ curriculum:
72
  stage_3: "Bias reversal / parity cases"
73
  stage_4: "Schema drift (IPC→BNSS, regional FIR formats)"
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  training:
76
  method: GRPO
77
  framework: TRL + Unsloth
78
  model: unsloth/Qwen2.5-7B-Instruct
79
  notebook: training/UndertriAI_GRPO_Training.ipynb
80
  script: training/train_grpo.py
 
 
 
 
 
 
 
81
 
82
  deployment:
83
  platform: huggingface-spaces
 
1
  name: undertrial-ai
2
+ version: "1.1.0"
3
  description: >
4
+ OpenEnv-compliant RL training environment for Indian bail decision support
5
+ with adaptive self-improvement (Theme 4). An LLM agent reads High Court bail
6
+ cases, invokes legal tools, and submits structured bail recommendations.
7
+ Reward computed deterministically against real HC judgments with an explicit
8
+ bias penalty (lambda=0.3). Features performance-aware episode selection,
9
+ stage-gated curriculum promotion, and synthetic case generation.
10
 
11
  author: Draken1606
12
  license: MIT
 
21
  - world-modeling
22
  - bias-mitigation
23
  - bnss-2023
24
+ - self-improvement
25
+ - adaptive-curriculum
26
 
27
  environment:
28
  class: undertrial_ai.server.undertrial_environment.UndertriAIEnvironment
 
56
  description: "TERMINAL — Submit structured bail assessment memo"
57
 
58
  reward:
59
+ formula: "0.4*outcome_gated + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*reasoning_quality + 0.05*format - 0.3*bias"
60
  range: [-0.7, 1.15]
61
  terminal_action: submit_memo
62
  deterministic: true
63
  llm_as_judge: false
64
  components:
65
+ - outcome_match: "Agreement with real High Court decision, gated by reasoning quality (40%)"
66
  - flight_risk_accuracy: "Flight risk classification accuracy (20%)"
67
+ - statutory_accuracy: "IPC/BNSS threshold computation with direction gate (20%)"
68
  - condition_appropriateness: "Bail condition quality (20%)"
69
+ - reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10% bonus)"
70
+ - format_compliance: "XML tag adherence matching system prompt structure (5% bonus)"
71
  - bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
72
 
73
  curriculum:
 
77
  stage_3: "Bias reversal / parity cases"
78
  stage_4: "Schema drift (IPC→BNSS, regional FIR formats)"
79
 
80
+ self_improvement:
81
+ adaptive_curriculum:
82
+ description: >
83
+ Performance-gated stage promotion using exponential moving averages.
84
+ Agent auto-promotes when per-stage EMA exceeds threshold.
85
+ thresholds:
86
+ stage_1_to_2: {min_reward: 0.65, min_episodes: 20}
87
+ stage_2_to_3: {min_reward: 0.55, min_episodes: 50}
88
+ stage_3_to_4: {min_reward: 0.50, min_episodes: 20}
89
+ weakness_targeting:
90
+ description: >
91
+ Adaptive episode selection identifies the crime type with lowest EMA
92
+ reward and serves proportionally more cases from that domain.
93
+ strategy: "60% weakest domain / 30% failure replay / 10% exploration"
94
+ synthetic_generation:
95
+ description: >
96
+ When agent masters a domain (EMA > 0.70), generates harder synthetic
97
+ variants using 5 perturbation types.
98
+ perturbation_types:
99
+ - custody_escalation
100
+ - co_accused_conflict
101
+ - section_ambiguity
102
+ - evidence_reversal
103
+ - surety_complexity
104
+
105
+ endpoints:
106
+ - path: /reset
107
+ method: POST
108
+ description: "Start a new episode. Supports adaptive=true and auto_stage=true for Theme 4."
109
+ - path: /step
110
+ method: POST
111
+ description: "Submit a tool call or final memo. Updates performance tracker when done."
112
+ - path: /state
113
+ method: GET
114
+ description: "Inspect current episode state."
115
+ - path: /health
116
+ method: GET
117
+ description: "Health check."
118
+ - path: /tools
119
+ method: GET
120
+ description: "List available tools."
121
+ - path: /profile
122
+ method: GET
123
+ description: "Get agent performance profile for a session (Theme 4)."
124
+ - path: /adaptive_status
125
+ method: GET
126
+ description: "Get global adaptive mode capabilities and thresholds."
127
+ - path: /ws/{session_id}
128
+ method: WS
129
+ description: "WebSocket real-time feed."
130
+
131
  training:
132
  method: GRPO
133
  framework: TRL + Unsloth
134
  model: unsloth/Qwen2.5-7B-Instruct
135
  notebook: training/UndertriAI_GRPO_Training.ipynb
136
  script: training/train_grpo.py
137
+ modes:
138
+ - name: single_stage
139
+ command: "python training/train_grpo.py --stage 1 --steps 200"
140
+ - name: curriculum
141
+ command: "python training/train_grpo.py --curriculum --steps 150"
142
+ - name: adaptive
143
+ command: "python training/train_grpo.py --adaptive --steps 50 --env_url http://localhost:8000"
144
 
145
  deployment:
146
  platform: huggingface-spaces
server/adaptive_selector.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UndertriAI — Adaptive Episode Selector (Theme 4: Self-Improvement)
3
+
4
+ Wraps the existing BailDataset to provide performance-aware episode
5
+ selection when adaptive mode is enabled. Falls back to uniform random
6
+ (identical to existing behavior) when adaptive=False.
7
+ """
8
+
9
+ import random
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ from .performance_tracker import PerformanceTracker
13
+
14
+
15
+ class AdaptiveSelector:
16
+ """
17
+ Performance-aware episode selector.
18
+
19
+ Selection strategy (applied in order when adaptive=True):
20
+ 60%: sample from the weakest crime-type domain in current_stage
21
+ 30%: replay cases where recent performance was poor (reward < 0.40)
22
+ 10%: uniform random from current_stage (exploration)
23
+
24
+ Always returns a valid episode dict. Never raises.
25
+ """
26
+
27
+ def __init__(self, dataset, tracker: PerformanceTracker):
28
+ """
29
+ Args:
30
+ dataset: BailDataset instance (has _episodes, sample_episode)
31
+ tracker: PerformanceTracker instance driving selection
32
+ """
33
+ self.dataset = dataset
34
+ self.tracker = tracker
35
+
36
+ # ------------------------------------------------------------------
37
+ # Public API
38
+ # ------------------------------------------------------------------
39
+
40
+ def select_episode(self, current_stage: int) -> Dict[str, Any]:
41
+ """
42
+ Performance-aware selection for adaptive mode.
43
+
44
+ 60% weakest domain → 30% failure replay → 10% exploration.
45
+ Falls back to uniform on any failure.
46
+ """
47
+ try:
48
+ roll = random.random()
49
+
50
+ if roll < 0.60:
51
+ # Try weakest domain
52
+ ep = self._select_weakest_domain(current_stage)
53
+ if ep is not None:
54
+ return ep
55
+
56
+ if roll < 0.90:
57
+ # Try failure replay
58
+ ep = self._select_failure_replay(current_stage)
59
+ if ep is not None:
60
+ return ep
61
+
62
+ # 10% exploration or fallback
63
+ return self.select_episode_uniform(current_stage)
64
+
65
+ except Exception:
66
+ # Absolute fallback — never crash
67
+ return self.select_episode_uniform(current_stage)
68
+
69
+ def select_episode_uniform(self, current_stage: int) -> Dict[str, Any]:
70
+ """
71
+ Pure random selection from current_stage.
72
+ Identical to existing BailDataset.sample_episode() behavior.
73
+ """
74
+ return self.dataset.sample_episode(stage=current_stage)
75
+
76
+ # ------------------------------------------------------------------
77
+ # Internal strategies
78
+ # ------------------------------------------------------------------
79
+
80
+ def _select_weakest_domain(
81
+ self, current_stage: int
82
+ ) -> Optional[Dict[str, Any]]:
83
+ """
84
+ Select an episode from the weakest crime-type domain.
85
+ Returns None if no weak domain identified or no matching episodes.
86
+ """
87
+ weak_domain = self.tracker.weakest_domain()
88
+ if weak_domain is None:
89
+ return None
90
+
91
+ # Find episodes matching this crime type in the current stage
92
+ episodes = self._get_stage_episodes(current_stage)
93
+ matches = [
94
+ ep for ep in episodes
95
+ if str(ep.get("crime_type", "")).strip() == weak_domain
96
+ ]
97
+
98
+ if not matches:
99
+ return None
100
+
101
+ return random.choice(matches)
102
+
103
+ def _select_failure_replay(
104
+ self, current_stage: int
105
+ ) -> Optional[Dict[str, Any]]:
106
+ """
107
+ Replay a case where the agent recently scored below 0.40.
108
+ Returns None if no recent failures or no matching episodes.
109
+ """
110
+ failed_ids = self.tracker.get_recent_failures(threshold=0.40)
111
+ if not failed_ids:
112
+ return None
113
+
114
+ # Find episodes matching failed case_ids in current stage
115
+ episodes = self._get_stage_episodes(current_stage)
116
+ matches = [
117
+ ep for ep in episodes
118
+ if ep.get("case_id", "") in failed_ids
119
+ ]
120
+
121
+ if not matches:
122
+ return None
123
+
124
+ return random.choice(matches)
125
+
126
+ def _get_stage_episodes(self, stage: int) -> List[Dict[str, Any]]:
127
+ """Get all episodes for a given stage from the dataset."""
128
+ try:
129
+ eps = self.dataset._episodes.get(stage, [])
130
+ if eps:
131
+ return eps
132
+ # Fallback chain matching BailDataset.sample_episode
133
+ for candidate in [stage - 1, stage + 1, 1, 2, 3, 4]:
134
+ if 1 <= candidate <= 4:
135
+ eps = self.dataset._episodes.get(candidate, [])
136
+ if eps:
137
+ return eps
138
+ except Exception:
139
+ pass
140
+ return []
server/app.py CHANGED
@@ -5,16 +5,40 @@ Wraps UndertriAIEnvironment as an OpenEnv-compatible HTTP + WebSocket server.
5
 
6
  import os
7
  from pathlib import Path
 
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, HTMLResponse
11
  import json
12
  import uuid
13
-
14
 
15
  from .undertrial_environment import UndertriAIEnvironment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Session store: episode_id → environment instance
 
18
  _sessions: dict = {}
19
 
20
  app = FastAPI(
@@ -33,14 +57,16 @@ app.add_middleware(
33
  EPISODES_DIR = os.environ.get("UNDERTRIAL_EPISODES_DIR", None)
34
 
35
 
36
- def get_or_create_env(session_id: str) -> UndertriAIEnvironment:
 
37
  if session_id not in _sessions:
38
- _sessions[session_id] = UndertriAIEnvironment(episodes_dir=EPISODES_DIR)
 
39
  return _sessions[session_id]
40
 
41
 
42
  # ------------------------------------------------------------------
43
- # REST endpoints
44
  # ------------------------------------------------------------------
45
 
46
  @app.get("/", response_class=HTMLResponse)
@@ -69,12 +95,46 @@ def health():
69
 
70
 
71
  @app.post("/reset")
72
- def reset(stage: int = 1, session_id: str = None, seed: int = None, episode_id: str = None):
 
 
 
 
 
 
 
73
  if session_id is None:
74
  session_id = str(uuid.uuid4())
75
- env = get_or_create_env(session_id)
76
- env.set_stage(stage)
77
- obs = env.reset(stage=stage, seed=seed, episode_id=episode_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return {
79
  "session_id": session_id,
80
  "observation": obs.model_dump(),
@@ -91,7 +151,8 @@ def step(payload: dict):
91
  if not session_id or session_id not in _sessions:
92
  return JSONResponse(status_code=400, content={"error": "Invalid session_id. Call /reset first."})
93
 
94
- env = _sessions[session_id]
 
95
 
96
  # Deserialize action by tool_name
97
  tool_name = action_data.get("tool_name", "")
@@ -126,7 +187,35 @@ def step(payload: dict):
126
  except Exception as e:
127
  return JSONResponse(status_code=422, content={"error": str(e)})
128
 
 
 
 
 
129
  result = env.step(action)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  return {
131
  "session_id": session_id,
132
  "observation": result.observation.model_dump(),
@@ -140,7 +229,7 @@ def step(payload: dict):
140
  def state(session_id: str):
141
  if session_id not in _sessions:
142
  return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
143
- return _sessions[session_id].state
144
 
145
 
146
  @app.get("/observation")
@@ -148,7 +237,7 @@ def observation(session_id: str):
148
  """OpenEnv spec alias for /state — returns current episode observation."""
149
  if session_id not in _sessions:
150
  return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
151
- return _sessions[session_id].state
152
 
153
 
154
  @app.get("/tools")
@@ -171,6 +260,48 @@ def list_tools():
171
  }
172
 
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  # ------------------------------------------------------------------
175
  # WebSocket endpoint (OpenEnv standard)
176
  # ------------------------------------------------------------------
@@ -178,7 +309,8 @@ def list_tools():
178
  @app.websocket("/ws/{session_id}")
179
  async def websocket_endpoint(websocket: WebSocket, session_id: str):
180
  await websocket.accept()
181
- env = get_or_create_env(session_id)
 
182
  try:
183
  while True:
184
  data = await websocket.receive_text()
 
5
 
6
  import os
7
  from pathlib import Path
8
+ from dataclasses import dataclass, field
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import JSONResponse, HTMLResponse
12
  import json
13
  import uuid
14
+ from typing import List, Optional
15
 
16
  from .undertrial_environment import UndertriAIEnvironment
17
+ from .performance_tracker import PerformanceTracker
18
+ from .adaptive_selector import AdaptiveSelector
19
+ from .case_generator import generate_variants
20
+
21
+
22
+ # ------------------------------------------------------------------
23
+ # Session state
24
+ # ------------------------------------------------------------------
25
+
26
+ @dataclass
27
+ class SessionState:
28
+ """Per-session state wrapping the environment + Theme 4 components."""
29
+ env: UndertriAIEnvironment
30
+ tracker: PerformanceTracker = field(default_factory=PerformanceTracker)
31
+ adaptive: bool = False
32
+ selector: Optional[AdaptiveSelector] = None
33
+ tools_used: List[str] = field(default_factory=list)
34
+ synthetic_cases_generated: int = 0
35
+
36
+ def __post_init__(self):
37
+ if self.selector is None:
38
+ self.selector = AdaptiveSelector(self.env.dataset, self.tracker)
39
 
40
+
41
+ # Session store: session_id → SessionState
42
  _sessions: dict = {}
43
 
44
  app = FastAPI(
 
57
  EPISODES_DIR = os.environ.get("UNDERTRIAL_EPISODES_DIR", None)
58
 
59
 
60
+ def get_or_create_session(session_id: str) -> SessionState:
61
+ """Get existing session or create new one with all Theme 4 components."""
62
  if session_id not in _sessions:
63
+ env = UndertriAIEnvironment(episodes_dir=EPISODES_DIR)
64
+ _sessions[session_id] = SessionState(env=env)
65
  return _sessions[session_id]
66
 
67
 
68
  # ------------------------------------------------------------------
69
+ # REST endpoints (existing — preserved exactly)
70
  # ------------------------------------------------------------------
71
 
72
  @app.get("/", response_class=HTMLResponse)
 
95
 
96
 
97
  @app.post("/reset")
98
+ def reset(
99
+ stage: int = 1,
100
+ session_id: str = None,
101
+ seed: int = None,
102
+ episode_id: str = None,
103
+ adaptive: bool = False,
104
+ auto_stage: bool = False,
105
+ ):
106
  if session_id is None:
107
  session_id = str(uuid.uuid4())
108
+
109
+ session = get_or_create_session(session_id)
110
+ env = session.env
111
+ session.adaptive = adaptive
112
+ session.tools_used = [] # Reset tools tracking
113
+
114
+ # Auto-stage: use tracker's suggestion
115
+ effective_stage = stage
116
+ if auto_stage:
117
+ effective_stage = session.tracker.suggest_next_stage()
118
+
119
+ env.set_stage(effective_stage)
120
+
121
+ # Adaptive episode selection
122
+ if adaptive and episode_id is None and seed is None:
123
+ # Use adaptive selector instead of uniform random
124
+ selected_ep = session.selector.select_episode(effective_stage)
125
+ # Inject the selected episode directly into the environment
126
+ env._episode = selected_ep
127
+ env._episode_id = str(uuid.uuid4())
128
+ env._step_count = 0
129
+ env._flags = []
130
+ env._retrieved_precedents = []
131
+ env._action_history = []
132
+ env._statutory_tool_called = False
133
+ env._tools_called = set()
134
+ obs = env._make_observation(action_result=None)
135
+ else:
136
+ obs = env.reset(stage=effective_stage, seed=seed, episode_id=episode_id)
137
+
138
  return {
139
  "session_id": session_id,
140
  "observation": obs.model_dump(),
 
151
  if not session_id or session_id not in _sessions:
152
  return JSONResponse(status_code=400, content={"error": "Invalid session_id. Call /reset first."})
153
 
154
+ session = _sessions[session_id]
155
+ env = session.env
156
 
157
  # Deserialize action by tool_name
158
  tool_name = action_data.get("tool_name", "")
 
187
  except Exception as e:
188
  return JSONResponse(status_code=422, content={"error": str(e)})
189
 
190
+ # Track tool usage for this session
191
+ if tool_name != "submit_memo":
192
+ session.tools_used.append(tool_name)
193
+
194
  result = env.step(action)
195
+
196
+ # Theme 4: Update tracker after terminal action (reward available)
197
+ if result.done and hasattr(result, "info") and isinstance(result.info, dict):
198
+ reward_components = result.info
199
+ episode = env._episode or {}
200
+ session.tracker.update(
201
+ episode=episode,
202
+ reward_components=reward_components,
203
+ tools_used=list(session.tools_used),
204
+ )
205
+
206
+ # Generate synthetic cases if agent mastered this domain
207
+ if session.adaptive:
208
+ crime_type = episode.get("crime_type", "")
209
+ if crime_type and session.tracker.should_generate_synthetic(crime_type):
210
+ variants = generate_variants(episode, n=3)
211
+ if variants:
212
+ # Inject synthetic cases into the dataset
213
+ stage = episode.get("curriculum_stage", 1)
214
+ for v in variants:
215
+ v["curriculum_stage"] = stage
216
+ env.dataset._episodes.setdefault(stage, []).append(v)
217
+ session.synthetic_cases_generated += len(variants)
218
+
219
  return {
220
  "session_id": session_id,
221
  "observation": result.observation.model_dump(),
 
229
  def state(session_id: str):
230
  if session_id not in _sessions:
231
  return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
232
+ return _sessions[session_id].env.state
233
 
234
 
235
  @app.get("/observation")
 
237
  """OpenEnv spec alias for /state — returns current episode observation."""
238
  if session_id not in _sessions:
239
  return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
240
+ return _sessions[session_id].env.state
241
 
242
 
243
  @app.get("/tools")
 
260
  }
261
 
262
 
263
+ # ------------------------------------------------------------------
264
+ # Theme 4: New API endpoints (additive — do not replace existing)
265
+ # ------------------------------------------------------------------
266
+
267
+ @app.get("/profile")
268
+ def get_profile(session_id: str):
269
+ """Returns the current PerformanceTracker profile for the session."""
270
+ if session_id not in _sessions:
271
+ return JSONResponse(
272
+ status_code=404,
273
+ content={"error": f"Session '{session_id}' not found. Call /reset first."},
274
+ )
275
+ session = _sessions[session_id]
276
+ return {
277
+ "session_id": session_id,
278
+ "profile": session.tracker.get_profile(),
279
+ "adaptive_mode": session.adaptive,
280
+ "synthetic_cases_generated": session.synthetic_cases_generated,
281
+ }
282
+
283
+
284
+ @app.get("/adaptive_status")
285
+ def adaptive_status():
286
+ """Returns global adaptive mode capabilities (not session-specific)."""
287
+ return {
288
+ "adaptive_available": True,
289
+ "description": "Performance-aware episode selection and synthetic case generation",
290
+ "promotion_thresholds": {
291
+ "stage_1_to_2": {"min_reward": 0.65, "min_episodes": 20},
292
+ "stage_2_to_3": {"min_reward": 0.55, "min_episodes": 50},
293
+ "stage_3_to_4": {"min_reward": 0.50, "min_episodes": 20},
294
+ },
295
+ "perturbation_types": [
296
+ "custody_escalation",
297
+ "co_accused_conflict",
298
+ "section_ambiguity",
299
+ "evidence_reversal",
300
+ "surety_complexity",
301
+ ],
302
+ }
303
+
304
+
305
  # ------------------------------------------------------------------
306
  # WebSocket endpoint (OpenEnv standard)
307
  # ------------------------------------------------------------------
 
309
  @app.websocket("/ws/{session_id}")
310
  async def websocket_endpoint(websocket: WebSocket, session_id: str):
311
  await websocket.accept()
312
+ session = get_or_create_session(session_id)
313
+ env = session.env
314
  try:
315
  while True:
316
  data = await websocket.receive_text()
server/case_generator.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UndertriAI — Synthetic Case Generator (Theme 4: Self-Improvement)
3
+
4
+ When the agent masters a domain, this generates harder synthetic variants
5
+ of existing cases. All generation is deterministic string manipulation —
6
+ no LLM calls.
7
+
8
+ 5 perturbation types:
9
+ 1. custody_escalation — custody just below statutory threshold
10
+ 2. co_accused_conflict — co-accused with opposite bail outcome
11
+ 3. section_ambiguity — IPC ↔ BNSS section swap
12
+ 4. evidence_reversal — retracted witness / unreliable evidence
13
+ 5. surety_complexity — non-resident surety complication
14
+ """
15
+
16
+ import copy
17
+ import re
18
+ from typing import Any, Dict, List, Optional
19
+
20
+
21
+ # IPC → BNSS mapping (subset used by the environment)
22
+ IPC_TO_BNSS = {
23
+ "302": "103", "307": "109", "376": "64", "304B": "80", "395": "310",
24
+ "392": "309", "420": "318", "498A": "85", "406": "316", "465": "336",
25
+ "323": "115", "354": "74", "120B": "61", "506": "351", "121": "147",
26
+ "379": "303", "324": "117", "354A": "75",
27
+ }
28
+ BNSS_TO_IPC = {v: k for k, v in IPC_TO_BNSS.items()}
29
+
30
+
31
+ # ── Required fields for schema validation ────────────────────────────
32
+
33
+ REQUIRED_FIELDS = {
34
+ "case_id": str,
35
+ "crime_type": str,
36
+ "ipc_sections": list,
37
+ "custody_months": (int, float),
38
+ "charge_sheet": str,
39
+ "ground_truth": dict,
40
+ "curriculum_stage": (int, float),
41
+ }
42
+
43
+
44
+ def is_schema_valid(episode: Dict[str, Any]) -> bool:
45
+ """
46
+ Check that all required fields are present and correct types.
47
+ Returns True/False — used to filter out malformed synthetic cases.
48
+ """
49
+ for field, expected_type in REQUIRED_FIELDS.items():
50
+ if field not in episode:
51
+ return False
52
+ if not isinstance(episode[field], expected_type):
53
+ return False
54
+
55
+ # ground_truth must have 'outcome'
56
+ gt = episode.get("ground_truth", {})
57
+ if "outcome" not in gt:
58
+ return False
59
+
60
+ return True
61
+
62
+
63
+ def generate_variants(
64
+ source_episode: Dict[str, Any],
65
+ n: int = 5,
66
+ ) -> List[Dict[str, Any]]:
67
+ """
68
+ Generate up to n synthetic harder variants of a real episode.
69
+ Each variant applies exactly ONE perturbation.
70
+
71
+ Returns only valid variants (may be fewer than n if some
72
+ perturbations can't be applied cleanly).
73
+ """
74
+ if not is_schema_valid(source_episode):
75
+ return []
76
+
77
+ perturbations = [
78
+ _custody_escalation,
79
+ _co_accused_conflict,
80
+ _section_ambiguity,
81
+ _evidence_reversal,
82
+ _surety_complexity,
83
+ ]
84
+
85
+ variants = []
86
+ for i, perturb_fn in enumerate(perturbations[:n]):
87
+ try:
88
+ variant = perturb_fn(source_episode)
89
+ if variant is not None and is_schema_valid(variant):
90
+ variants.append(variant)
91
+ except Exception:
92
+ # Skip perturbation on any error
93
+ continue
94
+
95
+ return variants
96
+
97
+
98
+ # ── Perturbation 1: Custody Escalation ───────────────────────────────
99
+
100
+ def _custody_escalation(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
101
+ """
102
+ Set custody_months to exactly 2 months below the statutory threshold.
103
+ Forces careful computation — case is NOT yet eligible for default bail.
104
+ """
105
+ ep = copy.deepcopy(episode)
106
+ max_sent = ep.get("max_sentence_years", 5.0)
107
+
108
+ # Threshold is 50% of max sentence in months
109
+ threshold_months = (max_sent * 12) / 2
110
+ new_custody = max(1.0, threshold_months - 2.0)
111
+
112
+ old_custody = ep.get("custody_months", 0)
113
+ ep["custody_months"] = round(new_custody, 1)
114
+
115
+ # Update charge sheet text if it mentions custody duration
116
+ charge = ep.get("charge_sheet", "")
117
+ if str(int(old_custody)) in charge:
118
+ charge = charge.replace(
119
+ f"{int(old_custody)} months",
120
+ f"{int(new_custody)} months",
121
+ )
122
+ ep["charge_sheet"] = charge
123
+
124
+ # Metadata
125
+ parent_id = ep.get("case_id", "UNKNOWN")
126
+ ep["case_id"] = f"SYN_{parent_id}_CUST"
127
+ ep["source"] = "synthetic"
128
+ ep["parent_case_id"] = parent_id
129
+ ep["perturbation_type"] = "custody_escalation"
130
+ ep["difficulty"] = "hard"
131
+
132
+ return ep
133
+
134
+
135
+ # ── Perturbation 2: Co-Accused Conflict ──────────────────────────────
136
+
137
+ def _co_accused_conflict(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
138
+ """
139
+ Add a co-accused with the OPPOSITE bail outcome.
140
+ Forces the agent to make a parity argument.
141
+ """
142
+ ep = copy.deepcopy(episode)
143
+ gt = ep.get("ground_truth", {})
144
+ gt_outcome = gt.get("outcome", "Bail Granted")
145
+
146
+ # Opposite outcome
147
+ if "grant" in gt_outcome.lower():
148
+ co_outcome = "Bail Denied"
149
+ else:
150
+ co_outcome = "Bail Granted"
151
+
152
+ ep["co_accused"] = [{
153
+ "name": "Co-Accused A",
154
+ "bail_outcome": co_outcome,
155
+ "sections": ep.get("ipc_sections", []),
156
+ }]
157
+
158
+ gt["parity_argument_used"] = True
159
+ ep["ground_truth"] = gt
160
+
161
+ # Add parity context to defence arguments
162
+ defence = ep.get("defence_arguments", [])
163
+ defence.append(
164
+ f"Co-accused was {'granted' if 'grant' in co_outcome.lower() else 'denied'} "
165
+ f"bail under identical charges — parity principle applies."
166
+ )
167
+ ep["defence_arguments"] = defence
168
+
169
+ # Metadata
170
+ parent_id = ep.get("case_id", "UNKNOWN")
171
+ ep["case_id"] = f"SYN_{parent_id}_COAC"
172
+ ep["source"] = "synthetic"
173
+ ep["parent_case_id"] = parent_id
174
+ ep["perturbation_type"] = "co_accused_conflict"
175
+ ep["difficulty"] = "hard"
176
+
177
+ return ep
178
+
179
+
180
+ # ── Perturbation 3: Section Ambiguity (IPC ↔ BNSS) ──────────────────
181
+
182
+ def _section_ambiguity(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
183
+ """
184
+ Swap IPC sections to BNSS equivalents (or vice versa).
185
+ Tests schema drift adaptability.
186
+ """
187
+ ep = copy.deepcopy(episode)
188
+ sections = ep.get("ipc_sections", [])
189
+
190
+ if not sections:
191
+ return None
192
+
193
+ new_sections = []
194
+ swapped = False
195
+ for sec in sections:
196
+ sec_clean = sec.strip()
197
+ if sec_clean in IPC_TO_BNSS:
198
+ new_sections.append(IPC_TO_BNSS[sec_clean])
199
+ swapped = True
200
+ elif sec_clean in BNSS_TO_IPC:
201
+ new_sections.append(BNSS_TO_IPC[sec_clean])
202
+ swapped = True
203
+ else:
204
+ new_sections.append(sec_clean)
205
+
206
+ if not swapped:
207
+ return None
208
+
209
+ ep["ipc_sections"] = new_sections
210
+
211
+ # Update charge sheet references
212
+ charge = ep.get("charge_sheet", "")
213
+ for old_sec, new_sec in zip(sections, new_sections):
214
+ if old_sec != new_sec:
215
+ charge = charge.replace(f"Section {old_sec}", f"Section {new_sec}")
216
+ charge = charge.replace(f"section {old_sec}", f"section {new_sec}")
217
+ ep["charge_sheet"] = charge
218
+
219
+ # Metadata
220
+ parent_id = ep.get("case_id", "UNKNOWN")
221
+ ep["case_id"] = f"SYN_{parent_id}_SECT"
222
+ ep["source"] = "synthetic"
223
+ ep["parent_case_id"] = parent_id
224
+ ep["perturbation_type"] = "section_ambiguity"
225
+ ep["difficulty"] = "hard"
226
+
227
+ return ep
228
+
229
+
230
+ # ── Perturbation 4: Evidence Reversal ────────────────────────────────
231
+
232
+ def _evidence_reversal(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
233
+ """
234
+ Add a contradicting element to the strongest evidence.
235
+ Tests whether the agent updates assessment when evidence weakens.
236
+ """
237
+ ep = copy.deepcopy(episode)
238
+
239
+ # Find the strongest evidence mention
240
+ evidence_keywords = ["witness", "evidence", "testimony", "eyewitness"]
241
+ pros_args = ep.get("prosecution_arguments", [])
242
+ charge = ep.get("charge_sheet", "")
243
+
244
+ # Check prosecution arguments first
245
+ target_arg = None
246
+ for arg in pros_args:
247
+ if any(kw in arg.lower() for kw in evidence_keywords):
248
+ target_arg = arg
249
+ break
250
+
251
+ if target_arg is None:
252
+ # Check charge sheet sentences
253
+ sentences = [s.strip() for s in charge.split('.') if s.strip()]
254
+ for sent in sentences:
255
+ if any(kw in sent.lower() for kw in evidence_keywords):
256
+ target_arg = sent
257
+ break
258
+
259
+ if target_arg is None:
260
+ return None # No evidence to reverse
261
+
262
+ # Add reversal to defence arguments
263
+ defence = ep.get("defence_arguments", [])
264
+ defence.append(
265
+ "However, the key prosecution evidence was subsequently found "
266
+ "unreliable — the primary witness retracted their statement and "
267
+ "forensic analysis raised doubts about the physical evidence."
268
+ )
269
+ ep["defence_arguments"] = defence
270
+
271
+ # Metadata
272
+ parent_id = ep.get("case_id", "UNKNOWN")
273
+ ep["case_id"] = f"SYN_{parent_id}_EVID"
274
+ ep["source"] = "synthetic"
275
+ ep["parent_case_id"] = parent_id
276
+ ep["perturbation_type"] = "evidence_reversal"
277
+ ep["difficulty"] = "hard"
278
+
279
+ return ep
280
+
281
+
282
+ # ── Perturbation 5: Surety Complexity ────────────────────────────────
283
+
284
+ def _surety_complexity(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
285
+ """
286
+ Add a surety complication forcing careful condition assessment.
287
+ """
288
+ ep = copy.deepcopy(episode)
289
+
290
+ # Add surety complication to defence arguments
291
+ defence = ep.get("defence_arguments", [])
292
+ defence.append(
293
+ "Proposed surety is a non-resident relative with no verifiable "
294
+ "local assets or employment in the jurisdiction. Surety bond "
295
+ "amount of Rs. 5,00,000 proposed."
296
+ )
297
+ ep["defence_arguments"] = defence
298
+
299
+ # Add surety info to accused profile
300
+ profile = ep.get("accused_profile", {})
301
+ profile["surety_status"] = "non-resident, unverified assets"
302
+ ep["accused_profile"] = profile
303
+
304
+ # Metadata
305
+ parent_id = ep.get("case_id", "UNKNOWN")
306
+ ep["case_id"] = f"SYN_{parent_id}_SURE"
307
+ ep["source"] = "synthetic"
308
+ ep["parent_case_id"] = parent_id
309
+ ep["perturbation_type"] = "surety_complexity"
310
+ ep["difficulty"] = "hard"
311
+
312
+ return ep
server/performance_tracker.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UndertriAI — Performance Tracker (Theme 4: Self-Improvement)
3
+
4
+ Tracks the agent's running performance profile across dimensions
5
+ and uses it to drive adaptive curriculum decisions.
6
+
7
+ Pure Python — no server/training/FastAPI dependencies.
8
+ """
9
+
10
+ import warnings
11
+ from collections import deque
12
+ from typing import Any, Dict, List, Optional
13
+
14
+
15
+ class ExponentialMean:
16
+ """Exponential moving average with configurable decay."""
17
+
18
+ __slots__ = ("alpha", "value", "count")
19
+
20
+ def __init__(self, alpha: float = 0.1, initial: float = 0.5):
21
+ self.alpha = alpha
22
+ self.value = initial
23
+ self.count = 0
24
+
25
+ def update(self, x: float) -> None:
26
+ self.value = self.alpha * x + (1 - self.alpha) * self.value
27
+ self.count += 1
28
+
29
+ def get(self) -> float:
30
+ return self.value
31
+
32
+
33
+ class PerformanceTracker:
34
+ """
35
+ Tracks agent performance across crime types, stages, and reward
36
+ components. Drives adaptive episode selection and stage promotion.
37
+
38
+ Thread-safe for single-session use (no locks needed).
39
+ All public methods handle missing/malformed input gracefully.
40
+ """
41
+
42
+ def __init__(self, alpha: float = 0.1):
43
+ self._alpha = alpha
44
+
45
+ # Per-crime-type EMA of total reward
46
+ self.per_crime_type: Dict[str, ExponentialMean] = {}
47
+
48
+ # Per-stage EMA of total reward
49
+ self.per_stage: Dict[int, ExponentialMean] = {
50
+ s: ExponentialMean(alpha=alpha) for s in range(1, 5)
51
+ }
52
+
53
+ # Last 50 total rewards (for stage promotion smoothing)
54
+ self.recent_rewards: deque = deque(maxlen=50)
55
+
56
+ # Bias fire rate: 1.0 when penalty fired, 0.0 when not
57
+ self.bias_fire_rate: ExponentialMean = ExponentialMean(alpha=alpha)
58
+
59
+ # Tool usage counts (cumulative per session)
60
+ self.tool_usage: Dict[str, int] = {}
61
+
62
+ # Episode counters
63
+ self.episodes_seen: int = 0
64
+ self.stage_episodes: Dict[int, int] = {1: 0, 2: 0, 3: 0, 4: 0}
65
+
66
+ # Recent case performance for failure-replay
67
+ self._recent_case_rewards: deque = deque(maxlen=30)
68
+
69
+ # ------------------------------------------------------------------
70
+ # Core update
71
+ # ------------------------------------------------------------------
72
+
73
+ def update(
74
+ self,
75
+ episode: Dict[str, Any],
76
+ reward_components: Dict[str, Any],
77
+ tools_used: Optional[List[str]] = None,
78
+ ) -> None:
79
+ """
80
+ Update all internal state from a completed episode.
81
+
82
+ Handles missing keys gracefully — never raises on malformed input.
83
+ """
84
+ try:
85
+ total = float(reward_components.get("total_reward",
86
+ reward_components.get("total", 0.0)))
87
+ except (TypeError, ValueError):
88
+ total = 0.0
89
+
90
+ # Update recent rewards
91
+ self.recent_rewards.append(total)
92
+ self.episodes_seen += 1
93
+
94
+ # Per-crime-type tracking
95
+ crime_type = ""
96
+ try:
97
+ crime_type = str(episode.get("crime_type", "")).strip()
98
+ except Exception:
99
+ pass
100
+
101
+ if crime_type:
102
+ if crime_type not in self.per_crime_type:
103
+ self.per_crime_type[crime_type] = ExponentialMean(
104
+ alpha=self._alpha
105
+ )
106
+ self.per_crime_type[crime_type].update(total)
107
+
108
+ # Per-stage tracking
109
+ stage = 1
110
+ try:
111
+ stage = int(episode.get("curriculum_stage", 1))
112
+ except (TypeError, ValueError):
113
+ stage = 1
114
+ if 1 <= stage <= 4:
115
+ self.per_stage[stage].update(total)
116
+ self.stage_episodes[stage] = self.stage_episodes.get(stage, 0) + 1
117
+
118
+ # Bias fire rate
119
+ try:
120
+ bias_val = float(reward_components.get("bias_penalty", 0.0))
121
+ self.bias_fire_rate.update(1.0 if bias_val > 0.01 else 0.0)
122
+ except (TypeError, ValueError):
123
+ pass
124
+
125
+ # Tool usage
126
+ if tools_used:
127
+ for tool in tools_used:
128
+ t = str(tool)
129
+ self.tool_usage[t] = self.tool_usage.get(t, 0) + 1
130
+
131
+ # Track case_id → reward for failure-replay
132
+ case_id = ""
133
+ try:
134
+ case_id = str(episode.get("case_id", ""))
135
+ except Exception:
136
+ pass
137
+ if case_id:
138
+ self._recent_case_rewards.append((case_id, total, stage))
139
+
140
+ # ------------------------------------------------------------------
141
+ # Queries
142
+ # ------------------------------------------------------------------
143
+
144
+ def weakest_domain(self) -> Optional[str]:
145
+ """
146
+ Returns the crime_type with the lowest EMA reward.
147
+ Returns None if fewer than 5 episodes seen total or no crime type
148
+ has at least 3 observations.
149
+ """
150
+ if self.episodes_seen < 5:
151
+ return None
152
+
153
+ candidates = [
154
+ (ct, ema.get())
155
+ for ct, ema in self.per_crime_type.items()
156
+ if ema.count >= 3
157
+ ]
158
+ if not candidates:
159
+ return None
160
+
161
+ return min(candidates, key=lambda x: x[1])[0]
162
+
163
+ def suggest_next_stage(self) -> int:
164
+ """
165
+ Returns the recommended stage (1-4) based on readiness thresholds.
166
+ Never demotes — returns highest eligible stage.
167
+ """
168
+ current = 1
169
+
170
+ # Stage 1 → 2: EMA >= 0.65 AND at least 20 episodes
171
+ if (self.per_stage[1].get() >= 0.65
172
+ and self.stage_episodes.get(1, 0) >= 20):
173
+ current = 2
174
+
175
+ # Stage 2 → 3: EMA >= 0.55 AND at least 50 episodes
176
+ if (current >= 2
177
+ and self.per_stage[2].get() >= 0.55
178
+ and self.stage_episodes.get(2, 0) >= 50):
179
+ current = 3
180
+
181
+ # Stage 3 → 4: EMA >= 0.50 AND at least 20 episodes
182
+ if (current >= 3
183
+ and self.per_stage[3].get() >= 0.50
184
+ and self.stage_episodes.get(3, 0) >= 20):
185
+ current = 4
186
+
187
+ return current
188
+
189
+ def should_generate_synthetic(self, crime_type: str) -> bool:
190
+ """
191
+ Returns True if the agent has mastered this crime type domain
192
+ (EMA > 0.70 with at least 10 observations).
193
+ """
194
+ ema = self.per_crime_type.get(crime_type)
195
+ if ema is None:
196
+ return False
197
+ return ema.get() > 0.70 and ema.count >= 10
198
+
199
+ def get_recent_failures(self, threshold: float = 0.40) -> List[str]:
200
+ """
201
+ Returns case_ids from recent episodes where reward was below threshold.
202
+ Used by AdaptiveSelector for failure-replay.
203
+ """
204
+ return [
205
+ case_id
206
+ for case_id, reward, _ in self._recent_case_rewards
207
+ if reward < threshold
208
+ ]
209
+
210
+ # ------------------------------------------------------------------
211
+ # Serialization
212
+ # ------------------------------------------------------------------
213
+
214
+ def get_profile(self) -> Dict[str, Any]:
215
+ """
216
+ Returns a fully JSON-serializable profile dict.
217
+ No class instances — all values are primitive types.
218
+ """
219
+ recent = list(self.recent_rewards)
220
+ recent_mean = sum(recent) / len(recent) if recent else 0.0
221
+
222
+ return {
223
+ "per_crime_type": {
224
+ ct: round(ema.get(), 4)
225
+ for ct, ema in self.per_crime_type.items()
226
+ },
227
+ "per_stage": {
228
+ str(s): round(ema.get(), 4)
229
+ for s, ema in self.per_stage.items()
230
+ },
231
+ "bias_fire_rate": round(self.bias_fire_rate.get(), 4),
232
+ "tool_usage": dict(self.tool_usage),
233
+ "episodes_seen": self.episodes_seen,
234
+ "stage_episodes": dict(self.stage_episodes),
235
+ "weakest_domain": self.weakest_domain(),
236
+ "suggested_stage": self.suggest_next_stage(),
237
+ "recent_mean_reward": round(recent_mean, 4),
238
+ }
239
+
240
+ # ------------------------------------------------------------------
241
+ # Session management
242
+ # ------------------------------------------------------------------
243
+
244
+ def reset_session(self) -> None:
245
+ """
246
+ Clears transient session state but preserves accumulated
247
+ per-crime-type and per-stage learning.
248
+ """
249
+ self.recent_rewards.clear()
250
+ self.tool_usage.clear()
251
+ self._recent_case_rewards.clear()
server/reward.py CHANGED
@@ -11,6 +11,27 @@ import re
11
  from typing import Any, Dict, List, Optional
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ---------------------------------------------------------------------------
15
  # 1. Outcome Match (40%)
16
  # ---------------------------------------------------------------------------
@@ -170,6 +191,23 @@ def compute_statutory_accuracy(
170
  if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
171
  special_laws = "INFERRED" # Treat as special-law-restricted for eligibility
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # Compute ground-truth eligibility for cases with known custody duration
174
  half_sent_months = (max_sent * 12) / 2
175
  truly_eligible = (custody_mo >= half_sent_months) and not special_laws
@@ -463,6 +501,72 @@ def compute_reasoning_quality(
463
  return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
464
 
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  # ---------------------------------------------------------------------------
467
  # Master reward function
468
  # ---------------------------------------------------------------------------
@@ -480,22 +584,28 @@ def compute_reward(
480
  agent_flight_risk_justification: str = "",
481
  agent_grounds_for: Optional[List[str]] = None,
482
  agent_grounds_against: Optional[List[str]] = None,
 
 
483
  ) -> Dict[str, float]:
484
  """
485
  Computes the full reward for a submitted bail assessment memo.
486
 
487
- Formula:
488
- R = 0.3*outcome_match (was 0.4 — reduced to reward reasoning)
489
  + 0.2*flight_risk_accuracy
490
  + 0.2*statutory_accuracy
491
  + 0.2*condition_appropriateness
492
- + 0.1*reasoning_quality (NEW — anchoring + arithmetic + specificity)
493
- + 0.1*efficiency_bonus (only when outcome is correct)
494
- + 0.05*process_bonus
 
495
  - 0.3*bias_penalty
496
 
 
 
 
 
497
  Returns a dict with all component scores + total_reward.
498
- Range: approx [-0.4, 1.1].
499
  """
500
  gt = episode["ground_truth"]
501
 
@@ -515,6 +625,18 @@ def compute_reward(
515
  episode = episode,
516
  )
517
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  # Efficiency bonus: reward finishing faster when the answer is correct.
519
  # Only fires on directionally-correct outcomes (om >= 0.8) to prevent
520
  # rewarding efficient-but-wrong agents.
@@ -526,15 +648,25 @@ def compute_reward(
526
  # Process reward: +0.05 if agent actually used the statutory tool.
527
  process_bonus = 0.05 if statutory_tool_used else 0.0
528
 
 
 
 
 
 
529
  lam = 0.3
530
- total = 0.3*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*rq + 0.1*efficiency + process_bonus - lam*bias
 
 
531
 
532
  return {
533
  "outcome_match": round(om, 4),
 
 
534
  "flight_risk_accuracy": round(fr, 4),
535
  "statutory_accuracy": round(sa, 4),
536
  "condition_appropriateness": round(ca, 4),
537
  "reasoning_quality": round(rq, 4),
 
538
  "efficiency_bonus": round(efficiency, 4),
539
  "process_bonus": round(process_bonus,4),
540
  "bias_penalty": round(bias, 4),
 
11
  from typing import Any, Dict, List, Optional
12
 
13
 
14
+ # ---------------------------------------------------------------------------
15
+ # Shared helper: NDPS detection (canonical definition — import this elsewhere)
16
+ # ---------------------------------------------------------------------------
17
+
18
+ def _is_ndps_case(episode: dict) -> bool:
19
+ """
20
+ Detect narcotics cases even when special_laws field is empty.
21
+ Checks ipc_sections and crime_type for NDPS indicators.
22
+
23
+ This is the SINGLE canonical definition — import from server.reward
24
+ in undertrial_environment.py and training/train_grpo.py.
25
+ """
26
+ sections = " ".join(str(s) for s in episode.get("ipc_sections", [])).lower()
27
+ crime = str(episode.get("crime_type", "")).lower()
28
+ narcotics_indicators = [
29
+ "ndps", "narcotic", "drug", "psychotropic",
30
+ "20(b)", "22(b)", "27a", "section 37",
31
+ ]
32
+ return any(ind in sections or ind in crime for ind in narcotics_indicators)
33
+
34
+
35
  # ---------------------------------------------------------------------------
36
  # 1. Outcome Match (40%)
37
  # ---------------------------------------------------------------------------
 
191
  if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
192
  special_laws = "INFERRED" # Treat as special-law-restricted for eligibility
193
 
194
+ # ── B9: NDPS-specific statutory scoring ──────────────────────────────
195
+ # NDPS Section 37 twin conditions override standard threshold logic.
196
+ # Reward the agent for recognizing NDPS applies, not for arithmetic.
197
+ if _is_ndps_case(episode):
198
+ gt_granted = "grant" in gt_outcome.lower()
199
+ direction_correct = (agent_eligible == gt_granted)
200
+ ndps_recognized = any(
201
+ t in comp for t in ["section 37", "twin condition", "ndps", "37(1)(b)"]
202
+ )
203
+ if ndps_recognized and direction_correct:
204
+ return 1.0
205
+ elif direction_correct:
206
+ return 0.5
207
+ else:
208
+ return 0.0
209
+
210
+ # ── Standard IPC/BNSS statutory scoring ──────────────────────────────
211
  # Compute ground-truth eligibility for cases with known custody duration
212
  half_sent_months = (max_sent * 12) / 2
213
  truly_eligible = (custody_mo >= half_sent_months) and not special_laws
 
501
  return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
502
 
503
 
504
+ # ---------------------------------------------------------------------------
505
+ # 7. Think-block reasoning gate (B6)
506
+ # ---------------------------------------------------------------------------
507
+
508
+ def compute_think_factor(completion: str, current_stage: int) -> float:
509
+ """
510
+ Gate outcome credit on reasoning quality.
511
+ Stage 1: soft floor of 0.3 minimum (model still learning format).
512
+ Stage 2+: hard gate — no reasoning = no outcome credit.
513
+ Threshold: 120 words for full credit.
514
+ """
515
+ if not completion:
516
+ return 0.3 if current_stage == 1 else 0.0
517
+
518
+ think_match = re.search(r'<think>(.*?)</think>', completion, re.DOTALL)
519
+ think_text = think_match.group(1).strip() if think_match else ""
520
+ think_len = len(think_text.split())
521
+ raw_factor = min(1.0, think_len / 120.0)
522
+
523
+ if current_stage == 1:
524
+ # Soft floor: minimum 0.3 credit even with no think block
525
+ # Ensures GRPO has non-zero gradient signal in early training
526
+ return 0.3 + 0.7 * raw_factor
527
+ else:
528
+ # Hard gate: must reason to earn outcome credit
529
+ return raw_factor
530
+
531
+
532
+ # ---------------------------------------------------------------------------
533
+ # 8. Format compliance (B8)
534
+ # ---------------------------------------------------------------------------
535
+
536
+ def reward_format(completion: str) -> float:
537
+ """
538
+ Score structural compliance of the bail memo.
539
+ Checks for required XML tags matching the system prompt and valid outcome.
540
+ Returns 0.0–1.0 (fraction of required elements present).
541
+ """
542
+ if not completion:
543
+ return 0.0
544
+
545
+ # Tags must match exactly what SYSTEM_PROMPT instructs the model to produce
546
+ required_tags = [
547
+ r'<think>',
548
+ r'<memo>',
549
+ r'<flight_risk>',
550
+ r'<statutory_eligible>',
551
+ r'<recommended_outcome>',
552
+ r'<statutory_computation>',
553
+ ]
554
+ valid_outcomes = [
555
+ 'bail granted', 'bail denied',
556
+ 'conditional bail', 'default bail',
557
+ ]
558
+
559
+ checks = [
560
+ bool(re.search(tag, completion, re.IGNORECASE))
561
+ for tag in required_tags
562
+ ]
563
+ checks.append(
564
+ any(outcome in completion.lower() for outcome in valid_outcomes)
565
+ )
566
+
567
+ return sum(checks) / len(checks)
568
+
569
+
570
  # ---------------------------------------------------------------------------
571
  # Master reward function
572
  # ---------------------------------------------------------------------------
 
584
  agent_flight_risk_justification: str = "",
585
  agent_grounds_for: Optional[List[str]] = None,
586
  agent_grounds_against: Optional[List[str]] = None,
587
+ completion_text: Optional[str] = None,
588
+ current_stage: int = 1,
589
  ) -> Dict[str, float]:
590
  """
591
  Computes the full reward for a submitted bail assessment memo.
592
 
593
+ Formula (B6/B8 update):
594
+ R = 0.4*outcome_gated (gated by think_factor)
595
  + 0.2*flight_risk_accuracy
596
  + 0.2*statutory_accuracy
597
  + 0.2*condition_appropriateness
598
+ + 0.1*reasoning_quality
599
+ + 0.05*efficiency_bonus
600
+ + 0.05*format_score
601
+ + process_bonus
602
  - 0.3*bias_penalty
603
 
604
+ Core components: 0.4+0.2+0.2+0.2 = 1.0
605
+ Bonuses: rq(0.1) + eff(0.05) + fmt(0.05) + process(0.05)
606
+ Penalty: -0.3*bias
607
+
608
  Returns a dict with all component scores + total_reward.
 
609
  """
610
  gt = episode["ground_truth"]
611
 
 
625
  episode = episode,
626
  )
627
 
628
+ # B6: Gate outcome credit on reasoning quality (think block)
629
+ # In server path, completion_text may be None (structured memo submission)
630
+ # — default to think_factor=1.0 (no gating; env already enforces min tools).
631
+ if completion_text:
632
+ think_factor = compute_think_factor(completion_text, current_stage)
633
+ else:
634
+ think_factor = 1.0
635
+ om_gated = om * think_factor
636
+
637
+ # B8: Format compliance score
638
+ fmt = reward_format(completion_text) if completion_text else 0.5
639
+
640
  # Efficiency bonus: reward finishing faster when the answer is correct.
641
  # Only fires on directionally-correct outcomes (om >= 0.8) to prevent
642
  # rewarding efficient-but-wrong agents.
 
648
  # Process reward: +0.05 if agent actually used the statutory tool.
649
  process_bonus = 0.05 if statutory_tool_used else 0.0
650
 
651
+ # Reward formula:
652
+ # Core (sum=1.0): 0.4*outcome_gated + 0.2*flight + 0.2*statutory + 0.2*conditions
653
+ # Bonuses: 0.1*reasoning_quality + 0.05*efficiency + 0.05*format
654
+ # Process: +0.05 if statutory tool used
655
+ # Penalty: -0.3*bias
656
  lam = 0.3
657
+ total = (0.4*om_gated + 0.2*fr + 0.2*sa + 0.2*ca
658
+ + 0.1*rq + 0.05*efficiency + 0.05*fmt
659
+ + process_bonus - lam*bias)
660
 
661
  return {
662
  "outcome_match": round(om, 4),
663
+ "outcome_match_gated": round(om_gated, 4),
664
+ "think_factor": round(think_factor, 4),
665
  "flight_risk_accuracy": round(fr, 4),
666
  "statutory_accuracy": round(sa, 4),
667
  "condition_appropriateness": round(ca, 4),
668
  "reasoning_quality": round(rq, 4),
669
+ "format_score": round(fmt, 4),
670
  "efficiency_bonus": round(efficiency, 4),
671
  "process_bonus": round(process_bonus,4),
672
  "bias_penalty": round(bias, 4),
server/undertrial_environment.py CHANGED
@@ -8,7 +8,7 @@ import uuid
8
  from typing import Any, Dict, List, Optional
9
 
10
  from .dataset import BailDataset
11
- from .reward import compute_reward
12
  from .schema_drift import maybe_apply_drift
13
 
14
  try:
@@ -277,6 +277,21 @@ class UndertriAIEnvironment(Environment):
277
  return "No directly applicable precedents found in database."
278
 
279
  elif isinstance(action, ComputeStatutoryEligibilityAction):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  half_months = (action.max_sentence_years * 12) / 2
281
  eligible = action.custody_months >= half_months and not action.special_law_applicable
282
  pct = round((action.custody_months / (action.max_sentence_years * 12)) * 100, 1) if action.max_sentence_years else 0
 
8
  from typing import Any, Dict, List, Optional
9
 
10
  from .dataset import BailDataset
11
+ from .reward import compute_reward, _is_ndps_case
12
  from .schema_drift import maybe_apply_drift
13
 
14
  try:
 
277
  return "No directly applicable precedents found in database."
278
 
279
  elif isinstance(action, ComputeStatutoryEligibilityAction):
280
+ # B9: NDPS cases get Section 37 response instead of threshold arithmetic
281
+ if _is_ndps_case(self._episode):
282
+ return (
283
+ f"Statutory Eligibility Analysis:\n"
284
+ f" Sections: {', '.join(action.sections_invoked)}\n"
285
+ f" Special Law: NDPS Act applies\n"
286
+ f" Section: Section 37 NDPS Act\n"
287
+ f" Message: NDPS Section 37 applies. Standard custody threshold not applicable. "
288
+ f"Bail requires twin conditions under Section 37(1)(b): "
289
+ f"(i) reasonable grounds to believe accused is not guilty, "
290
+ f"(ii) no reasonable opportunity to commit offence if released. "
291
+ f"These are matters for judicial discretion, not statutory calculation.\n"
292
+ f" → ELIGIBLE FOR DEFAULT BAIL: NOT APPLICABLE (NDPS twin conditions govern)"
293
+ )
294
+
295
  half_months = (action.max_sentence_years * 12) / 2
296
  eligible = action.custody_months >= half_months and not action.special_law_applicable
297
  pct = round((action.custody_months / (action.max_sentence_years * 12)) * 100, 1) if action.max_sentence_years else 0
training/train_grpo.py CHANGED
@@ -52,12 +52,41 @@ try:
52
  compute_condition_score,
53
  compute_bias_penalty as _server_bias,
54
  compute_reasoning_quality,
 
 
 
55
  )
56
  _USE_SERVER_REWARDS = True
57
  print("[reward] Using authoritative server/reward.py functions.")
58
  except ImportError:
59
  _USE_SERVER_REWARDS = False
60
  print("[reward] server/reward.py not found — using local fallback functions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  from datasets import Dataset
62
 
63
  # ============================================================
@@ -188,16 +217,39 @@ def parse_model_output(output: str) -> Dict[str, Any]:
188
 
189
 
190
  def reward_format(completions: List[str], **kwargs) -> List[float]:
191
- """Reward well-formed XML output structure."""
192
- scores = []
193
- for c in completions:
194
- score = 0.0
195
- if "<think>" in c and "</think>" in c: score += 0.15
196
- if "<memo>" in c and "</memo>" in c: score += 0.15
197
- for tag in ["flight_risk","statutory_eligible","recommended_outcome","statutory_computation"]:
198
- if f"<{tag}>" in c: score += 0.05
199
- scores.append(min(1.0, score))
200
- return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
 
203
  def reward_outcome_match(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
@@ -234,7 +286,12 @@ def reward_flight_risk(completions: List[str], episode_batch: List[Dict], **kwar
234
 
235
 
236
  def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
237
- """20% weight: correct statutory eligibility computation."""
 
 
 
 
 
238
  scores = []
239
  for comp, ep in zip(completions, episode_batch):
240
  parsed = parse_model_output(comp)
@@ -242,19 +299,61 @@ def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs
242
  sections = ep.get("ipc_sections", [])
243
  max_sent = ep.get("max_sentence_years", 5.0)
244
  custody = ep.get("custody_months", 0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  score = 0.0
247
- # Mentions relevant sections
248
- for sec in sections:
249
- if sec.strip().lower() in comp_text or sec.strip() in comp:
250
- score += 0.2
251
- score = min(0.4, score)
252
-
253
- # Mentions numbers
254
- if re.search(r'\d+', comp_text): score += 0.3
255
- # Mentions time-related words
256
- if any(w in comp_text for w in ["month","year","sentence","custody","half","served","threshold"]):
257
- score += 0.3
 
 
 
 
 
 
 
 
 
 
 
258
  scores.append(min(1.0, score))
259
  return scores
260
 
@@ -309,14 +408,24 @@ def reward_no_bias(completions: List[str], episode_batch: List[Dict], **kwargs)
309
  def combined_reward(
310
  completions: List[str],
311
  episode_batch: List[Dict],
 
312
  **kwargs
313
  ) -> List[float]:
314
  """
315
  Master reward combining all components.
316
- R = 0.4*outcome + 0.2*flight_risk + 0.2*statutory + 0.2*condition - 0.3*bias
 
 
 
 
 
 
 
 
317
 
318
  Uses server/reward.py functions when available (Fix 1).
319
- Condition appropriateness replaces format score (Fix 2).
 
320
  """
321
  rewards = []
322
 
@@ -359,12 +468,22 @@ def combined_reward(
359
  b = reward_no_bias([comp], [ep])[0]
360
  rq = 0.5 # Neutral when server functions unavailable
361
 
362
- # NOTE: Efficiency is NOT computed in GRPO training because step_count=1
363
- # always (single-shot generation), making eff=1.0 a constant non-signal.
364
- # Efficiency is preserved in the environment's compute_reward for live inference.
365
- eff = 0.0
366
 
367
- total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq - 0.3*b
 
 
 
 
 
 
 
 
 
 
 
368
  rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
369
  return rewards
370
 
@@ -593,7 +712,7 @@ def train(
593
  # Reward wrapper that unpacks the stored JSON episode
594
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
595
  ep_objs = [json.loads(e) for e in episode]
596
- return combined_reward(completions, ep_objs)
597
 
598
  # ── GRPO Config ──────────────────────────────────────────
599
  from trl import GRPOConfig, GRPOTrainer # type: ignore
@@ -711,7 +830,7 @@ def evaluate_baseline(episodes_dir: str, n_samples: int = 20):
711
  out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
712
 
713
  completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
714
- r = combined_reward([completion], [ep])[0]
715
  rewards.append(r)
716
  print(f" Case {ep['case_id']}: reward={r:.3f} | GT={ep['ground_truth']['outcome']}")
717
 
@@ -770,7 +889,7 @@ def evaluate_on_stage(
770
  out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
771
 
772
  completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
773
- r = combined_reward([completion], [ep])[0]
774
  rewards.append(r)
775
  results.append({"episode": ep, "completion": completion, "reward": r})
776
 
@@ -916,10 +1035,7 @@ def train_curriculum(
916
 
917
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
918
  ep_objs = [json.loads(e) for e in episode]
919
- # Pass step_count=1 for curriculum training (single-shot XML, no multi-step env loop)
920
- # This keeps efficiency contribution honest rather than silently 0.0
921
- step_counts = [1] * len(completions)
922
- return combined_reward(completions, ep_objs, step_counts=step_counts)
923
 
924
  stage_output = f"{output_dir}/stage_{stage}"
925
  config = GRPOConfig(
@@ -1019,7 +1135,248 @@ def train_curriculum(
1019
 
1020
 
1021
  # ============================================================
1022
- # CELL 9 — Entry point
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
  # ============================================================
1024
 
1025
  if __name__ == "__main__":
@@ -1035,6 +1392,10 @@ if __name__ == "__main__":
1035
  help="Run evaluation after training to measure improvement")
1036
  parser.add_argument("--curriculum", action="store_true",
1037
  help="Run self-improving curriculum training (all 4 stages)")
 
 
 
 
1038
 
1039
  args = parser.parse_args()
1040
 
@@ -1047,6 +1408,15 @@ if __name__ == "__main__":
1047
  max_steps_per_stage=args.steps,
1048
  batch_size=args.batch_size,
1049
  )
 
 
 
 
 
 
 
 
 
1050
  else:
1051
  train(
1052
  episodes_dir = args.episodes_dir,
 
52
  compute_condition_score,
53
  compute_bias_penalty as _server_bias,
54
  compute_reasoning_quality,
55
+ compute_think_factor,
56
+ reward_format as server_reward_format,
57
+ _is_ndps_case,
58
  )
59
  _USE_SERVER_REWARDS = True
60
  print("[reward] Using authoritative server/reward.py functions.")
61
  except ImportError:
62
  _USE_SERVER_REWARDS = False
63
  print("[reward] server/reward.py not found — using local fallback functions.")
64
+
65
+ # Local fallback definition of _is_ndps_case (mirrors server/reward.py)
66
+ def _is_ndps_case(episode: dict) -> bool:
67
+ sections = " ".join(str(s) for s in episode.get("ipc_sections", [])).lower()
68
+ crime = str(episode.get("crime_type", "")).lower()
69
+ narcotics_indicators = [
70
+ "ndps", "narcotic", "drug", "psychotropic",
71
+ "20(b)", "22(b)", "27a", "section 37",
72
+ ]
73
+ return any(ind in sections or ind in crime for ind in narcotics_indicators)
74
+
75
+ # Local fallback definition of compute_think_factor (mirrors server/reward.py)
76
+ def compute_think_factor(completion: str, current_stage: int) -> float:
77
+ if not completion:
78
+ return 0.3 if current_stage == 1 else 0.0
79
+ think_match = re.search(r'<think>(.*?)</think>', completion, re.DOTALL)
80
+ think_text = think_match.group(1).strip() if think_match else ""
81
+ think_len = len(think_text.split())
82
+ raw_factor = min(1.0, think_len / 120.0)
83
+ if current_stage == 1:
84
+ return 0.3 + 0.7 * raw_factor
85
+ else:
86
+ return raw_factor
87
+
88
+ # Local fallback server_reward_format
89
+ server_reward_format = None # Will use local reward_format below
90
  from datasets import Dataset
91
 
92
  # ============================================================
 
217
 
218
 
219
  def reward_format(completions: List[str], **kwargs) -> List[float]:
220
+ """Reward well-formed XML output structure (batch API for GRPO compatibility)."""
221
+ return [reward_format_single(c) for c in completions]
222
+
223
+
224
+ def reward_format_single(completion: str) -> float:
225
+ """
226
+ Score structural compliance of the bail memo.
227
+ Checks for required XML tags matching the system prompt and valid outcome.
228
+ Returns 0.0–1.0 (fraction of required elements present).
229
+ """
230
+ if not completion:
231
+ return 0.0
232
+ # Tags match exactly what SYSTEM_PROMPT instructs the model to produce
233
+ required_tags = [
234
+ r'<think>',
235
+ r'<memo>',
236
+ r'<flight_risk>',
237
+ r'<statutory_eligible>',
238
+ r'<recommended_outcome>',
239
+ r'<statutory_computation>',
240
+ ]
241
+ valid_outcomes = [
242
+ 'bail granted', 'bail denied',
243
+ 'conditional bail', 'default bail',
244
+ ]
245
+ checks = [
246
+ bool(re.search(tag, completion, re.IGNORECASE))
247
+ for tag in required_tags
248
+ ]
249
+ checks.append(
250
+ any(outcome in completion.lower() for outcome in valid_outcomes)
251
+ )
252
+ return sum(checks) / len(checks)
253
 
254
 
255
  def reward_outcome_match(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
 
286
 
287
 
288
  def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
289
+ """20% weight: correct statutory eligibility computation.
290
+
291
+ B3: Direction-gated computation bonus — wrong direction gets 0.10 not 0.30.
292
+ B9: NDPS cases use crime_type detection and reward Section 37 recognition.
293
+ """
294
+ TIME_WORDS = ["month", "year", "sentence", "custody", "half", "served", "threshold"]
295
  scores = []
296
  for comp, ep in zip(completions, episode_batch):
297
  parsed = parse_model_output(comp)
 
299
  sections = ep.get("ipc_sections", [])
300
  max_sent = ep.get("max_sentence_years", 5.0)
301
  custody = ep.get("custody_months", 0.0)
302
+ special_laws = ep.get("special_laws", "").strip()
303
+ gt_outcome = ep.get("ground_truth", {}).get("outcome", "")
304
+ agent_eligible = parsed["statutory_eligible"]
305
+
306
+ # B9: NDPS-specific scoring
307
+ if _is_ndps_case(ep):
308
+ gt_granted = "grant" in gt_outcome.lower()
309
+ direction_correct = (agent_eligible == gt_granted)
310
+ ndps_recognized = any(
311
+ t in comp_text for t in ["section 37", "twin condition", "ndps", "37(1)(b)"]
312
+ )
313
+ if ndps_recognized and direction_correct:
314
+ scores.append(1.0)
315
+ elif direction_correct:
316
+ scores.append(0.5)
317
+ else:
318
+ scores.append(0.0)
319
+ continue
320
+
321
+ # Infer special law from crime_type
322
+ CRIME_TYPE_SPECIAL_LAWS = [
323
+ "narcotics", "ndps", "pocso", "uapa", "pmla",
324
+ "terrorism", "organised crime", "money laundering",
325
+ ]
326
+ crime_type_lower = ep.get("crime_type", "").lower()
327
+ if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
328
+ special_laws = "INFERRED"
329
+
330
+ # Standard IPC/BNSS threshold computation
331
+ half_sent_months = (max_sent * 12) / 2
332
+ truly_eligible = (custody >= half_sent_months) and not special_laws
333
 
334
  score = 0.0
335
+
336
+ # 40%: eligibility direction
337
+ direction_correct = (agent_eligible == truly_eligible)
338
+ if direction_correct:
339
+ score += 0.4
340
+ elif (agent_eligible and "grant" in gt_outcome.lower()) or \
341
+ (not agent_eligible and "deni" in gt_outcome.lower()):
342
+ score += 0.2
343
+
344
+ # 30%: cited relevant sections
345
+ if sections:
346
+ hits = sum(1 for sec in sections if sec.strip().lower() in comp_text or sec.strip() in comp)
347
+ score += 0.3 * min(1.0, hits / len(sections))
348
+
349
+ # 30%: numeric computation (B3: direction-gated)
350
+ has_numbers = bool(re.search(r'\d+', comp_text))
351
+ has_time_ref = any(w in comp_text for w in TIME_WORDS)
352
+ if has_numbers and has_time_ref:
353
+ score += 0.3 if direction_correct else 0.10
354
+ elif has_numbers or has_time_ref:
355
+ score += 0.15 if direction_correct else 0.05
356
+
357
  scores.append(min(1.0, score))
358
  return scores
359
 
 
408
  def combined_reward(
409
  completions: List[str],
410
  episode_batch: List[Dict],
411
+ current_stage: int = 1,
412
  **kwargs
413
  ) -> List[float]:
414
  """
415
  Master reward combining all components.
416
+
417
+ Formula (B6/B8 update):
418
+ R = 0.4*outcome_gated + 0.2*flight_risk + 0.2*statutory + 0.2*condition
419
+ + 0.1*reasoning_quality + 0.05*format
420
+ - 0.3*bias
421
+
422
+ Core (sum=1.0): 0.4*om_gated + 0.2*fr + 0.2*s + 0.2*ca
423
+ Bonuses: 0.1*rq + 0.05*fmt
424
+ Penalty: -0.3*bias
425
 
426
  Uses server/reward.py functions when available (Fix 1).
427
+ B6: Outcome gated by think_factor (stage-aware).
428
+ B8: Format compliance score included with 0.05 weight.
429
  """
430
  rewards = []
431
 
 
468
  b = reward_no_bias([comp], [ep])[0]
469
  rq = 0.5 # Neutral when server functions unavailable
470
 
471
+ # B6: Gate outcome credit on reasoning quality (think block)
472
+ think_factor = compute_think_factor(comp, current_stage)
473
+ om_gated = o * think_factor
 
474
 
475
+ # B8: Format compliance score
476
+ if _USE_SERVER_REWARDS and server_reward_format is not None:
477
+ fmt = server_reward_format(comp)
478
+ else:
479
+ fmt = reward_format_single(comp)
480
+
481
+ # Reward formula:
482
+ # Core (sum=1.0): 0.4*outcome_gated + 0.2*flight + 0.2*statutory + 0.2*conditions
483
+ # Bonuses: 0.1*reasoning_quality + 0.05*format
484
+ # Penalty: -0.3*bias
485
+ total = (0.4*om_gated + 0.2*fr + 0.2*s + 0.2*ca
486
+ + 0.1*rq + 0.05*fmt - 0.3*b)
487
  rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
488
  return rewards
489
 
 
712
  # Reward wrapper that unpacks the stored JSON episode
713
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
714
  ep_objs = [json.loads(e) for e in episode]
715
+ return combined_reward(completions, ep_objs, current_stage=stage)
716
 
717
  # ── GRPO Config ──────────────────────────────────────────
718
  from trl import GRPOConfig, GRPOTrainer # type: ignore
 
830
  out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
831
 
832
  completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
833
+ r = combined_reward([completion], [ep], current_stage=1)[0]
834
  rewards.append(r)
835
  print(f" Case {ep['case_id']}: reward={r:.3f} | GT={ep['ground_truth']['outcome']}")
836
 
 
889
  out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
890
 
891
  completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
892
+ r = combined_reward([completion], [ep], current_stage=stage)[0]
893
  rewards.append(r)
894
  results.append({"episode": ep, "completion": completion, "reward": r})
895
 
 
1035
 
1036
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
1037
  ep_objs = [json.loads(e) for e in episode]
1038
+ return combined_reward(completions, ep_objs, current_stage=stage)
 
 
 
1039
 
1040
  stage_output = f"{output_dir}/stage_{stage}"
1041
  config = GRPOConfig(
 
1135
 
1136
 
1137
  # ============================================================
1138
+ # CELL 9 — Adaptive Training (Theme 4: Self-Improvement)
1139
+ # ============================================================
1140
+
1141
+ def train_adaptive(
1142
+ episodes_dir: str = "./data/episodes",
1143
+ output_dir: str = "./output/undertrial_adaptive",
1144
+ steps_per_assessment: int = 50,
1145
+ max_total_steps: int = 2000,
1146
+ batch_size: int = 4,
1147
+ grad_accum: int = 4,
1148
+ lr: float = 5e-6,
1149
+ base_url: str = "http://localhost:8000",
1150
+ ):
1151
+ """
1152
+ Self-directed curriculum training (Theme 4).
1153
+
1154
+ Uses the /profile endpoint to check stage readiness every
1155
+ steps_per_assessment steps and promotes automatically.
1156
+
1157
+ This function communicates with the server via HTTP — it does NOT
1158
+ import server internals. OpenEnv client/server separation is preserved.
1159
+
1160
+ Training loop:
1161
+ 1. Start at stage 1
1162
+ 2. Train for steps_per_assessment steps
1163
+ 3. Query /profile for suggested_stage
1164
+ 4. If suggested_stage > current_stage, promote
1165
+ 5. Repeat until max_total_steps or stage 4 mastered
1166
+ """
1167
+ print("=" * 60)
1168
+ print(" UndertriAI — Adaptive Self-Improvement Training")
1169
+ print(f" Assessment every {steps_per_assessment} steps | Max {max_total_steps} steps")
1170
+ print(f" Server: {base_url}")
1171
+ print("=" * 60)
1172
+
1173
+ from unsloth import FastLanguageModel # type: ignore
1174
+ from trl import GRPOConfig, GRPOTrainer # type: ignore
1175
+
1176
+ # Load model once
1177
+ model, tokenizer = FastLanguageModel.from_pretrained(
1178
+ model_name="unsloth/Qwen2.5-7B-Instruct",
1179
+ max_seq_length=3072,
1180
+ load_in_4bit=True,
1181
+ fast_inference=False,
1182
+ )
1183
+ model = FastLanguageModel.get_peft_model(
1184
+ model,
1185
+ r=16,
1186
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
1187
+ "gate_proj", "up_proj", "down_proj"],
1188
+ lora_alpha=16, lora_dropout=0, bias="none",
1189
+ use_gradient_checkpointing="unsloth", random_state=42,
1190
+ )
1191
+
1192
+ # HTTP helper for server communication
1193
+ def query_profile(session_id: str) -> Optional[Dict]:
1194
+ """Query the performance profile from the server via HTTP."""
1195
+ try:
1196
+ url = f"{base_url}/profile?session_id={urllib.parse.quote(session_id)}"
1197
+ req = urllib.request.Request(url)
1198
+ with urllib.request.urlopen(req, timeout=5.0) as resp:
1199
+ return json.loads(resp.read())
1200
+ except Exception as e:
1201
+ print(f" [adaptive] Could not reach server profile: {e}")
1202
+ return None
1203
+
1204
+ def notify_reset(session_id: str, stage: int) -> Optional[str]:
1205
+ """Call /reset with adaptive=true on the server."""
1206
+ try:
1207
+ url = f"{base_url}/reset?session_id={urllib.parse.quote(session_id)}&stage={stage}&adaptive=true"
1208
+ req = urllib.request.Request(url, method="POST")
1209
+ with urllib.request.urlopen(req, timeout=5.0) as resp:
1210
+ data = json.loads(resp.read())
1211
+ return data.get("session_id", session_id)
1212
+ except Exception:
1213
+ return None
1214
+
1215
+ current_stage = 1
1216
+ total_steps = 0
1217
+ session_id = f"adaptive_{uuid.uuid4().hex[:8]}" if 'uuid' in dir() else "adaptive_training"
1218
+
1219
+ # Try to initialise session on server
1220
+ import uuid as _uuid_mod
1221
+ session_id = f"adaptive_{_uuid_mod.uuid4().hex[:8]}"
1222
+ notify_reset(session_id, current_stage)
1223
+
1224
+ # Tracking
1225
+ stage_promotion_steps = []
1226
+ reward_curve = []
1227
+ stage_rewards = {1: [], 2: [], 3: [], 4: []}
1228
+
1229
+ while total_steps < max_total_steps:
1230
+ print(f"\n{'━' * 60}")
1231
+ print(f" ADAPTIVE BLOCK: Steps {total_steps}–{total_steps + steps_per_assessment}")
1232
+ print(f" Current Stage: {current_stage} — {STAGE_NAMES.get(current_stage, '?')}")
1233
+ print(f"{'━' * 60}")
1234
+
1235
+ # Load episodes for current stage
1236
+ try:
1237
+ episodes = load_episodes(episodes_dir, stage=current_stage, split="train")
1238
+ except FileNotFoundError:
1239
+ print(f" No episodes for stage {current_stage} — breaking")
1240
+ break
1241
+
1242
+ if not episodes:
1243
+ print(f" Empty episode list for stage {current_stage} — breaking")
1244
+ break
1245
+
1246
+ # Build dataset
1247
+ dataset = build_hf_dataset(episodes, tokenizer)
1248
+ stage_for_closure = current_stage # Capture for closure
1249
+
1250
+ def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
1251
+ ep_objs = [json.loads(e) for e in episode]
1252
+ return combined_reward(completions, ep_objs, current_stage=stage_for_closure)
1253
+
1254
+ block_output = f"{output_dir}/block_{total_steps}"
1255
+ config = GRPOConfig(
1256
+ output_dir=block_output,
1257
+ learning_rate=lr,
1258
+ per_device_train_batch_size=batch_size,
1259
+ gradient_accumulation_steps=grad_accum,
1260
+ num_train_epochs=1,
1261
+ max_steps=steps_per_assessment,
1262
+ num_generations=6,
1263
+ max_completion_length=1024,
1264
+ temperature=0.7,
1265
+ beta=0.01,
1266
+ logging_steps=5,
1267
+ save_steps=steps_per_assessment,
1268
+ report_to="none",
1269
+ remove_unused_columns=False,
1270
+ )
1271
+
1272
+ FastLanguageModel.for_training(model)
1273
+ trainer = GRPOTrainer(
1274
+ model=model,
1275
+ processing_class=tokenizer,
1276
+ config=config,
1277
+ train_dataset=dataset,
1278
+ reward_funcs=[reward_fn],
1279
+ )
1280
+ trainer.train()
1281
+ total_steps += steps_per_assessment
1282
+
1283
+ # Evaluate current performance
1284
+ eval_reward, _ = evaluate_on_stage(
1285
+ model, tokenizer, episodes_dir, stage=current_stage, n_samples=15
1286
+ )
1287
+ stage_rewards[current_stage].append(eval_reward)
1288
+ reward_curve.append((total_steps, round(eval_reward, 4)))
1289
+ print(f" Stage {current_stage} eval reward: {eval_reward:.4f}")
1290
+
1291
+ # Query server for stage promotion suggestion
1292
+ profile_data = query_profile(session_id)
1293
+ suggested_stage = current_stage
1294
+
1295
+ if profile_data and "profile" in profile_data:
1296
+ suggested_stage = profile_data["profile"].get(
1297
+ "suggested_stage", current_stage
1298
+ )
1299
+ else:
1300
+ # Fallback: use local heuristic
1301
+ if eval_reward >= 0.65 and current_stage == 1:
1302
+ suggested_stage = 2
1303
+ elif eval_reward >= 0.55 and current_stage == 2:
1304
+ suggested_stage = 3
1305
+ elif eval_reward >= 0.50 and current_stage == 3:
1306
+ suggested_stage = 4
1307
+
1308
+ if suggested_stage > current_stage:
1309
+ old_stage = current_stage
1310
+ old_reward = eval_reward
1311
+ current_stage = suggested_stage
1312
+ stage_promotion_steps.append(
1313
+ (total_steps, old_stage, current_stage, round(old_reward, 4))
1314
+ )
1315
+ print(
1316
+ f"[SELF-IMPROVEMENT] Step {total_steps}: "
1317
+ f"Promoted to Stage {current_stage}. "
1318
+ f"Stage {old_stage} mean reward: {old_reward:.3f} → "
1319
+ f"Stage {current_stage} begins."
1320
+ )
1321
+ # Notify server of promotion
1322
+ notify_reset(session_id, current_stage)
1323
+
1324
+ # Check completion
1325
+ if current_stage == 4:
1326
+ s4_rewards = stage_rewards.get(4, [])
1327
+ if s4_rewards and s4_rewards[-1] >= 0.50:
1328
+ print(
1329
+ f"\n[SELF-IMPROVEMENT] Stage 4 mastered at step {total_steps}! "
1330
+ f"Reward: {s4_rewards[-1]:.3f}"
1331
+ )
1332
+ break
1333
+
1334
+ # Save checkpoint
1335
+ model.save_pretrained(block_output, save_adapters_only=True)
1336
+ tokenizer.save_pretrained(block_output)
1337
+
1338
+ # ── Final summary ──
1339
+ print(f"\n{'═' * 60}")
1340
+ print(" ADAPTIVE TRAINING COMPLETE")
1341
+ print(f"{'═' * 60}")
1342
+ print(f" Total steps: {total_steps}")
1343
+ print(f" Stage promotions: {len(stage_promotion_steps)}")
1344
+ for step_n, from_s, to_s, reward in stage_promotion_steps:
1345
+ print(f" Step {step_n}: Stage {from_s} → {to_s} (reward {reward:.3f})")
1346
+ print(f" Final stage: {current_stage}")
1347
+
1348
+ # Compute final reward per stage
1349
+ final_reward_per_stage = {}
1350
+ for s, rewards_list in stage_rewards.items():
1351
+ if rewards_list:
1352
+ final_reward_per_stage[str(s)] = round(rewards_list[-1], 4)
1353
+
1354
+ # Save results
1355
+ results = {
1356
+ "stage_promotion_steps": [
1357
+ {"step": s, "from_stage": f, "to_stage": t, "reward": r}
1358
+ for s, f, t, r in stage_promotion_steps
1359
+ ],
1360
+ "final_reward_per_stage": final_reward_per_stage,
1361
+ "total_steps_completed": total_steps,
1362
+ "reward_curve": [{"step": s, "reward": r} for s, r in reward_curve],
1363
+ }
1364
+ results_path = Path(output_dir) / "results_adaptive.json"
1365
+ results_path.parent.mkdir(parents=True, exist_ok=True)
1366
+ results_path.write_text(json.dumps(results, indent=2))
1367
+ print(f"\n Results saved: {results_path}")
1368
+
1369
+ # Save final model
1370
+ final_dir = f"{output_dir}/final"
1371
+ model.save_pretrained(final_dir, save_adapters_only=True)
1372
+ tokenizer.save_pretrained(final_dir)
1373
+ print(f" Final model saved: {final_dir}")
1374
+
1375
+ return results
1376
+
1377
+
1378
+ # ============================================================
1379
+ # CELL 10 — Entry point
1380
  # ============================================================
1381
 
1382
  if __name__ == "__main__":
 
1392
  help="Run evaluation after training to measure improvement")
1393
  parser.add_argument("--curriculum", action="store_true",
1394
  help="Run self-improving curriculum training (all 4 stages)")
1395
+ parser.add_argument("--adaptive", action="store_true",
1396
+ help="Run adaptive self-improvement training (Theme 4)")
1397
+ parser.add_argument("--env_url", default="http://localhost:8000",
1398
+ help="Server URL for adaptive training")
1399
 
1400
  args = parser.parse_args()
1401
 
 
1408
  max_steps_per_stage=args.steps,
1409
  batch_size=args.batch_size,
1410
  )
1411
+ elif args.adaptive:
1412
+ train_adaptive(
1413
+ episodes_dir=args.episodes_dir,
1414
+ output_dir=args.output,
1415
+ steps_per_assessment=args.steps,
1416
+ max_total_steps=2000,
1417
+ batch_size=args.batch_size,
1418
+ base_url=args.env_url,
1419
+ )
1420
  else:
1421
  train(
1422
  episodes_dir = args.episodes_dir,