Spaces:
Sleeping
Sleeping
Shabista Sehar commited on
Commit ·
d8f8a45
1
Parent(s): 4855450
implemented
Browse files- README.md +198 -26
- openenv.yaml +72 -9
- server/adaptive_selector.py +140 -0
- server/app.py +145 -13
- server/case_generator.py +312 -0
- server/performance_tracker.py +251 -0
- server/reward.py +139 -7
- server/undertrial_environment.py +16 -1
- training/train_grpo.py +407 -37
README.md
CHANGED
|
@@ -58,8 +58,11 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
|
|
| 58 |
| Method | Endpoint | Description |
|
| 59 |
|---|---|---|
|
| 60 |
| `POST` | `/reset?stage=1` | Start a new episode (curriculum stage 1–4) |
|
|
|
|
| 61 |
| `POST` | `/step` | Submit a tool call or final memo |
|
| 62 |
| `GET` | `/state?session_id=...` | Inspect current episode state |
|
|
|
|
|
|
|
| 63 |
| `GET` | `/health` | Health check |
|
| 64 |
| `GET` | `/tools` | List available tools |
|
| 65 |
| `WS` | `/ws/{session_id}` | WebSocket real-time feed |
|
|
@@ -74,6 +77,11 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
|
|
| 74 |
| `classify_bail_type` | Determine regular / anticipatory / default bail |
|
| 75 |
| `request_document` | Request additional case documents |
|
| 76 |
| `flag_inconsistency` | Flag contradictions in the charge sheet |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
| `submit_memo` | **Terminal action** — submit final bail recommendation |
|
| 78 |
|
| 79 |
### 4-Stage Curriculum
|
|
@@ -87,13 +95,71 @@ UndertriAI is an **OpenEnv-compliant RL training environment** that teaches an L
|
|
| 87 |
|
| 88 |
---
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
## Reward Function
|
| 91 |
|
| 92 |
```
|
| 93 |
-
R = 0.4 × outcome_match
|
| 94 |
+ 0.2 × flight_risk_accuracy
|
| 95 |
+ 0.2 × statutory_accuracy
|
| 96 |
+ 0.2 × condition_appropriateness
|
|
|
|
|
|
|
| 97 |
− 0.3 × bias_penalty
|
| 98 |
```
|
| 99 |
|
|
@@ -101,16 +167,20 @@ All components are **fully deterministic and rule-based** — no LLM-as-judge.
|
|
| 101 |
|
| 102 |
| Component | Signal | Details |
|
| 103 |
|---|---|---|
|
| 104 |
-
| **Outcome Match** | 0.0 / 0.8 / 1.0 | Exact, directional, or wrong vs HC decision |
|
| 105 |
| **Flight Risk** | 0–1 | Ordinal distance to ground-truth risk level |
|
| 106 |
-
| **Statutory** | 0–1 | IPC/BNSS
|
| 107 |
| **Conditions** | 0–1 | Appropriate bail conditions for crime/risk profile |
|
|
|
|
|
|
|
| 108 |
| **Bias Penalty** | −0.3 | Fired if parity argument ignored in bias-flagged cases |
|
| 109 |
|
| 110 |
### Anti-Reward-Hacking Design
|
| 111 |
|
| 112 |
-
-
|
| 113 |
- `GenerationInspectionCallback` prints raw completions every 25 training steps
|
|
|
|
|
|
|
| 114 |
- Bias penalty operates as a separate signal, not folded into outcome
|
| 115 |
- Schema drift (Stage 4) tests adaptability, not pattern memorisation
|
| 116 |
|
|
@@ -118,18 +188,110 @@ All components are **fully deterministic and rule-based** — no LLM-as-judge.
|
|
| 118 |
|
| 119 |
## Training
|
| 120 |
|
| 121 |
-
Uses **GRPO** (Group Relative Policy Optimization) via TRL + Unsloth on `Qwen2.5-
|
| 122 |
|
| 123 |
-
|
| 124 |
-
# Run with before/after eval and results.json
|
| 125 |
-
python training/train_grpo.py \
|
| 126 |
-
--episodes_dir ./data/episodes \
|
| 127 |
-
--stage 1 \
|
| 128 |
-
--steps 200 \
|
| 129 |
-
--eval_after
|
| 130 |
-
```
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
### Training Architecture
|
| 135 |
|
|
@@ -146,7 +308,13 @@ Episode Dataset (JSONL)
|
|
| 146 |
↓
|
| 147 |
GRPO updates model weights
|
| 148 |
↓
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
```
|
| 151 |
|
| 152 |
---
|
|
@@ -178,21 +346,25 @@ env = from_hub("Draken1606/undertrial-ai")
|
|
| 178 |
```
|
| 179 |
undertrial_ai/
|
| 180 |
├── server/
|
| 181 |
-
│ ├── app.py
|
| 182 |
-
│ ├── undertrial_environment.py
|
| 183 |
-
│ ├── reward.py
|
| 184 |
-
│ ├── dataset.py
|
| 185 |
-
│
|
|
|
|
|
|
|
|
|
|
| 186 |
├── training/
|
| 187 |
-
│ ├── train_grpo.py
|
| 188 |
│ └── UndertriAI_GRPO_Training.ipynb # Colab notebook
|
| 189 |
├── data/
|
| 190 |
-
│ └── episodes/
|
| 191 |
├── demo/
|
| 192 |
-
│ └── index.html
|
| 193 |
-
├── client.py
|
| 194 |
-
├── models.py
|
| 195 |
-
|
|
|
|
| 196 |
```
|
| 197 |
|
| 198 |
---
|
|
|
|
| 58 |
| Method | Endpoint | Description |
|
| 59 |
|---|---|---|
|
| 60 |
| `POST` | `/reset?stage=1` | Start a new episode (curriculum stage 1–4) |
|
| 61 |
+
| `POST` | `/reset?adaptive=true&auto_stage=true` | Start episode with adaptive selection (Theme 4) |
|
| 62 |
| `POST` | `/step` | Submit a tool call or final memo |
|
| 63 |
| `GET` | `/state?session_id=...` | Inspect current episode state |
|
| 64 |
+
| `GET` | `/profile?session_id=...` | Agent performance profile (Theme 4) |
|
| 65 |
+
| `GET` | `/adaptive_status` | Adaptive mode capabilities & thresholds |
|
| 66 |
| `GET` | `/health` | Health check |
|
| 67 |
| `GET` | `/tools` | List available tools |
|
| 68 |
| `WS` | `/ws/{session_id}` | WebSocket real-time feed |
|
|
|
|
| 77 |
| `classify_bail_type` | Determine regular / anticipatory / default bail |
|
| 78 |
| `request_document` | Request additional case documents |
|
| 79 |
| `flag_inconsistency` | Flag contradictions in the charge sheet |
|
| 80 |
+
| `read_submissions` | Read prosecution/defence arguments on record |
|
| 81 |
+
| `assess_flight_risk` | Systematic flight risk scoring matrix |
|
| 82 |
+
| `check_case_factors` | Examine parity, evidence tampering, victim vulnerability |
|
| 83 |
+
| `apply_proportionality` | BNSS 479 custody vs. max sentence proportionality |
|
| 84 |
+
| `pull_criminal_history` | Prior record, bail history, conviction status |
|
| 85 |
| `submit_memo` | **Terminal action** — submit final bail recommendation |
|
| 86 |
|
| 87 |
### 4-Stage Curriculum
|
|
|
|
| 95 |
|
| 96 |
---
|
| 97 |
|
| 98 |
+
## Theme 4 — Self-Improvement
|
| 99 |
+
|
| 100 |
+
UndertriAI qualifies for Theme 4 through three mechanisms:
|
| 101 |
+
|
| 102 |
+
**1. Adaptive Curriculum Promotion**
|
| 103 |
+
The environment tracks per-domain and per-stage performance using exponential
|
| 104 |
+
moving averages. When the agent demonstrates consistent improvement
|
| 105 |
+
(Stage 1 mean reward ≥ 0.65 over 20 episodes), it automatically promotes
|
| 106 |
+
to the next curriculum stage. This is visible in training logs as:
|
| 107 |
+
```
|
| 108 |
+
[SELF-IMPROVEMENT] Step 100: Promoted to Stage 2. Stage 1 mean reward: 0.710 → Stage 2 begins.
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
**2. Weakness-Targeted Episode Selection**
|
| 112 |
+
In adaptive mode, the episode selector identifies the crime type where the
|
| 113 |
+
agent performs worst and serves proportionally more cases from that domain.
|
| 114 |
+
As the agent improves on weak domains, the selection distribution shifts —
|
| 115 |
+
the environment continuously finds and targets new weaknesses.
|
| 116 |
+
|
| 117 |
+
| Selection | Weight | Mechanism |
|
| 118 |
+
|---|---|---|
|
| 119 |
+
| Weakest domain | 60% | EMA-tracked per-crime-type reward |
|
| 120 |
+
| Failure replay | 30% | Re-serve cases with reward < 0.40 |
|
| 121 |
+
| Exploration | 10% | Uniform random (prevent overfitting) |
|
| 122 |
+
|
| 123 |
+
**3. Synthetic Case Generation**
|
| 124 |
+
When the agent masters a domain (mean reward > 0.70), the environment
|
| 125 |
+
generates harder synthetic variants using 5 perturbation types:
|
| 126 |
+
|
| 127 |
+
| Perturbation | What it tests |
|
| 128 |
+
|---|---|
|
| 129 |
+
| Custody escalation | Custody 2 months below threshold — forces careful statutory computation |
|
| 130 |
+
| Co-accused conflict | Opposite bail outcome for co-accused — tests parity reasoning |
|
| 131 |
+
| Section ambiguity | IPC ↔ BNSS section swap — tests schema drift adaptability |
|
| 132 |
+
| Evidence reversal | Key witness retracted — tests flight risk reassessment |
|
| 133 |
+
| Surety complexity | Non-resident surety — tests condition appropriateness |
|
| 134 |
+
|
| 135 |
+
**Live Demo — Self-Improvement in Action**
|
| 136 |
+
```bash
|
| 137 |
+
# Start the server
|
| 138 |
+
python -m server.app
|
| 139 |
+
|
| 140 |
+
# In another terminal — start adaptive training
|
| 141 |
+
python training/train_grpo.py --adaptive --steps 50 --env_url http://localhost:8000
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Monitor progress via:
|
| 145 |
+
```
|
| 146 |
+
GET /profile?session_id={id}
|
| 147 |
+
GET /adaptive_status
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Watch stage promotions in the training log.
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
## Reward Function
|
| 155 |
|
| 156 |
```
|
| 157 |
+
R = 0.4 × outcome_match (gated by reasoning quality)
|
| 158 |
+ 0.2 × flight_risk_accuracy
|
| 159 |
+ 0.2 × statutory_accuracy
|
| 160 |
+ 0.2 × condition_appropriateness
|
| 161 |
+
+ 0.1 × reasoning_quality (bonus)
|
| 162 |
+
+ 0.05 × format_compliance (bonus)
|
| 163 |
− 0.3 × bias_penalty
|
| 164 |
```
|
| 165 |
|
|
|
|
| 167 |
|
| 168 |
| Component | Signal | Details |
|
| 169 |
|---|---|---|
|
| 170 |
+
| **Outcome Match** | 0.0 / 0.8 / 1.0 | Exact, directional, or wrong vs HC decision — gated by `<think>` block |
|
| 171 |
| **Flight Risk** | 0–1 | Ordinal distance to ground-truth risk level |
|
| 172 |
+
| **Statutory** | 0–1 | IPC/BNSS threshold computation, direction-gated, NDPS Section 37 aware |
|
| 173 |
| **Conditions** | 0–1 | Appropriate bail conditions for crime/risk profile |
|
| 174 |
+
| **Reasoning Quality** | 0–1 | Anchoring + arithmetic + grounds specificity (10% bonus) |
|
| 175 |
+
| **Format Compliance** | 0–1 | XML tag adherence to system prompt (5% bonus) |
|
| 176 |
| **Bias Penalty** | −0.3 | Fired if parity argument ignored in bias-flagged cases |
|
| 177 |
|
| 178 |
### Anti-Reward-Hacking Design
|
| 179 |
|
| 180 |
+
- 7 independent reward signals (harder to simultaneously game all)
|
| 181 |
- `GenerationInspectionCallback` prints raw completions every 25 training steps
|
| 182 |
+
- Reasoning gate: no `<think>` block → outcome reward zeroed in Stage 2+
|
| 183 |
+
- Direction gate: wrong bail direction → statutory bonus capped
|
| 184 |
- Bias penalty operates as a separate signal, not folded into outcome
|
| 185 |
- Schema drift (Stage 4) tests adaptability, not pattern memorisation
|
| 186 |
|
|
|
|
| 188 |
|
| 189 |
## Training
|
| 190 |
|
| 191 |
+
Uses **GRPO** (Group Relative Policy Optimization) via TRL + Unsloth on `Qwen2.5-7B-Instruct`.
|
| 192 |
|
| 193 |
+
### Training Modes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
| Mode | Command | Description |
|
| 196 |
+
|---|---|---|
|
| 197 |
+
| Single stage | `python training/train_grpo.py --stage 1 --steps 200` | Train on one stage |
|
| 198 |
+
| Curriculum | `python training/train_grpo.py --curriculum --steps 150` | Sequential 4-stage with trace harvesting |
|
| 199 |
+
| **Adaptive** | `python training/train_grpo.py --adaptive --steps 50` | **Theme 4** — self-directed with auto-promotion |
|
| 200 |
+
|
| 201 |
+
### Google Colab Training Walkthrough
|
| 202 |
+
|
| 203 |
+
```python
|
| 204 |
+
# ============================================================
|
| 205 |
+
# STEP 1 — Install dependencies (run in first cell)
|
| 206 |
+
# ============================================================
|
| 207 |
+
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
|
| 208 |
+
!pip install -q --no-deps trl peft accelerate bitsandbytes xformers
|
| 209 |
+
!pip install -q openenv-core datasets
|
| 210 |
+
|
| 211 |
+
# ============================================================
|
| 212 |
+
# STEP 2 — Clone the repository
|
| 213 |
+
# ============================================================
|
| 214 |
+
!git clone https://github.com/Faiz-1606/Undertrial.git
|
| 215 |
+
%cd Undertrial
|
| 216 |
+
|
| 217 |
+
# ============================================================
|
| 218 |
+
# STEP 3 — Verify episodes are available
|
| 219 |
+
# ============================================================
|
| 220 |
+
import os
|
| 221 |
+
episodes_dir = "./data/episodes"
|
| 222 |
+
if not os.path.exists(episodes_dir):
|
| 223 |
+
print("No episodes directory — will use built-in demo episodes")
|
| 224 |
+
else:
|
| 225 |
+
for f in os.listdir(episodes_dir):
|
| 226 |
+
if f.endswith('.jsonl'):
|
| 227 |
+
count = sum(1 for _ in open(f"{episodes_dir}/{f}"))
|
| 228 |
+
print(f" {f}: {count} episodes")
|
| 229 |
+
|
| 230 |
+
# ============================================================
|
| 231 |
+
# STEP 4 — Option A: Single-stage training (quick, ~20 min on T4)
|
| 232 |
+
# ============================================================
|
| 233 |
+
!python training/train_grpo.py \
|
| 234 |
+
--episodes_dir ./data/episodes \
|
| 235 |
+
--stage 1 \
|
| 236 |
+
--steps 200 \
|
| 237 |
+
--batch_size 4 \
|
| 238 |
+
--eval_after
|
| 239 |
+
|
| 240 |
+
# ============================================================
|
| 241 |
+
# STEP 4 — Option B: Curriculum training (full, ~90 min on T4)
|
| 242 |
+
# ============================================================
|
| 243 |
+
!python training/train_grpo.py \
|
| 244 |
+
--episodes_dir ./data/episodes \
|
| 245 |
+
--curriculum \
|
| 246 |
+
--steps 150 \
|
| 247 |
+
--batch_size 4
|
| 248 |
+
|
| 249 |
+
# ============================================================
|
| 250 |
+
# STEP 4 — Option C: Adaptive training (Theme 4, ~60 min on T4)
|
| 251 |
+
# (Requires server running — start in a background cell first)
|
| 252 |
+
# ============================================================
|
| 253 |
+
# Background cell: start the server
|
| 254 |
+
import subprocess
|
| 255 |
+
server = subprocess.Popen(
|
| 256 |
+
["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"],
|
| 257 |
+
stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
| 258 |
+
)
|
| 259 |
+
import time; time.sleep(5) # Wait for server startup
|
| 260 |
+
|
| 261 |
+
# Then run adaptive training
|
| 262 |
+
!python training/train_grpo.py \
|
| 263 |
+
--adaptive \
|
| 264 |
+
--episodes_dir ./data/episodes \
|
| 265 |
+
--steps 50 \
|
| 266 |
+
--batch_size 4 \
|
| 267 |
+
--env_url http://localhost:8000
|
| 268 |
+
|
| 269 |
+
# ============================================================
|
| 270 |
+
# STEP 5 — View results
|
| 271 |
+
# ============================================================
|
| 272 |
+
import json
|
| 273 |
+
# For single/curriculum:
|
| 274 |
+
results = json.load(open("./output/undertrial_grpo/results.json"))
|
| 275 |
+
print(json.dumps(results, indent=2))
|
| 276 |
+
|
| 277 |
+
# For adaptive:
|
| 278 |
+
# results = json.load(open("./output/undertrial_grpo/results_adaptive.json"))
|
| 279 |
+
|
| 280 |
+
# ============================================================
|
| 281 |
+
# STEP 6 — (Optional) Merge LoRA adapters for inference
|
| 282 |
+
# ============================================================
|
| 283 |
+
from unsloth import FastLanguageModel
|
| 284 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 285 |
+
"./output/undertrial_grpo/final",
|
| 286 |
+
max_seq_length=3072,
|
| 287 |
+
)
|
| 288 |
+
model.save_pretrained_merged(
|
| 289 |
+
"./output/undertrial_merged",
|
| 290 |
+
tokenizer,
|
| 291 |
+
save_method="merged_16bit",
|
| 292 |
+
)
|
| 293 |
+
print("Merged model saved to ./output/undertrial_merged")
|
| 294 |
+
```
|
| 295 |
|
| 296 |
### Training Architecture
|
| 297 |
|
|
|
|
| 308 |
↓
|
| 309 |
GRPO updates model weights
|
| 310 |
↓
|
| 311 |
+
[Theme 4] PerformanceTracker updates EMA per domain/stage
|
| 312 |
+
↓
|
| 313 |
+
[Theme 4] AdaptiveSelector targets weakest domain
|
| 314 |
+
↓
|
| 315 |
+
[Theme 4] CaseGenerator creates harder synthetic variants
|
| 316 |
+
↓
|
| 317 |
+
[Theme 4] Auto-promote when stage EMA exceeds threshold
|
| 318 |
```
|
| 319 |
|
| 320 |
---
|
|
|
|
| 346 |
```
|
| 347 |
undertrial_ai/
|
| 348 |
├── server/
|
| 349 |
+
│ ├── app.py # FastAPI routes + Theme 4 endpoints
|
| 350 |
+
│ ├── undertrial_environment.py # Environment logic
|
| 351 |
+
│ ├── reward.py # 7-component deterministic reward
|
| 352 |
+
│ ├── dataset.py # Curriculum-staged episode loader
|
| 353 |
+
│ ├── schema_drift.py # IPC → BNSS remapping (Stage 4)
|
| 354 |
+
│ ├── performance_tracker.py # [Theme 4] EMA-based performance profiling
|
| 355 |
+
│ ├── adaptive_selector.py # [Theme 4] Weakness-targeted episode selection
|
| 356 |
+
│ └── case_generator.py # [Theme 4] Synthetic case perturbation
|
| 357 |
├── training/
|
| 358 |
+
│ ├── train_grpo.py # GRPO training (single/curriculum/adaptive)
|
| 359 |
│ └── UndertriAI_GRPO_Training.ipynb # Colab notebook
|
| 360 |
├── data/
|
| 361 |
+
│ └── episodes/ # 1,200 HC judgments across 4 stages
|
| 362 |
├── demo/
|
| 363 |
+
│ └── index.html # Interactive demo UI
|
| 364 |
+
├── client.py # UndertriAIEnv HTTP client
|
| 365 |
+
├── models.py # Pydantic action/observation schemas
|
| 366 |
+
├── openenv.yaml # OpenEnv manifest
|
| 367 |
+
└── Dockerfile # HF Spaces deployment
|
| 368 |
```
|
| 369 |
|
| 370 |
---
|
openenv.yaml
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
name: undertrial-ai
|
| 2 |
-
version: "1.
|
| 3 |
description: >
|
| 4 |
-
OpenEnv-compliant RL training environment for Indian bail decision support
|
| 5 |
-
An LLM agent reads High Court bail
|
| 6 |
-
|
| 7 |
-
real HC judgments with an explicit
|
|
|
|
|
|
|
| 8 |
|
| 9 |
author: Draken1606
|
| 10 |
license: MIT
|
|
@@ -19,6 +21,8 @@ tags:
|
|
| 19 |
- world-modeling
|
| 20 |
- bias-mitigation
|
| 21 |
- bnss-2023
|
|
|
|
|
|
|
| 22 |
|
| 23 |
environment:
|
| 24 |
class: undertrial_ai.server.undertrial_environment.UndertriAIEnvironment
|
|
@@ -52,17 +56,18 @@ actions:
|
|
| 52 |
description: "TERMINAL — Submit structured bail assessment memo"
|
| 53 |
|
| 54 |
reward:
|
| 55 |
-
formula: "0.
|
| 56 |
range: [-0.7, 1.15]
|
| 57 |
terminal_action: submit_memo
|
| 58 |
deterministic: true
|
| 59 |
llm_as_judge: false
|
| 60 |
components:
|
| 61 |
-
- outcome_match: "Agreement with real High Court decision (
|
| 62 |
- flight_risk_accuracy: "Flight risk classification accuracy (20%)"
|
| 63 |
-
- statutory_accuracy: "IPC/BNSS threshold computation (20%)"
|
| 64 |
- condition_appropriateness: "Bail condition quality (20%)"
|
| 65 |
-
- reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10%)"
|
|
|
|
| 66 |
- bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
|
| 67 |
|
| 68 |
curriculum:
|
|
@@ -72,12 +77,70 @@ curriculum:
|
|
| 72 |
stage_3: "Bias reversal / parity cases"
|
| 73 |
stage_4: "Schema drift (IPC→BNSS, regional FIR formats)"
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
training:
|
| 76 |
method: GRPO
|
| 77 |
framework: TRL + Unsloth
|
| 78 |
model: unsloth/Qwen2.5-7B-Instruct
|
| 79 |
notebook: training/UndertriAI_GRPO_Training.ipynb
|
| 80 |
script: training/train_grpo.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
deployment:
|
| 83 |
platform: huggingface-spaces
|
|
|
|
| 1 |
name: undertrial-ai
|
| 2 |
+
version: "1.1.0"
|
| 3 |
description: >
|
| 4 |
+
OpenEnv-compliant RL training environment for Indian bail decision support
|
| 5 |
+
with adaptive self-improvement (Theme 4). An LLM agent reads High Court bail
|
| 6 |
+
cases, invokes legal tools, and submits structured bail recommendations.
|
| 7 |
+
Reward computed deterministically against real HC judgments with an explicit
|
| 8 |
+
bias penalty (lambda=0.3). Features performance-aware episode selection,
|
| 9 |
+
stage-gated curriculum promotion, and synthetic case generation.
|
| 10 |
|
| 11 |
author: Draken1606
|
| 12 |
license: MIT
|
|
|
|
| 21 |
- world-modeling
|
| 22 |
- bias-mitigation
|
| 23 |
- bnss-2023
|
| 24 |
+
- self-improvement
|
| 25 |
+
- adaptive-curriculum
|
| 26 |
|
| 27 |
environment:
|
| 28 |
class: undertrial_ai.server.undertrial_environment.UndertriAIEnvironment
|
|
|
|
| 56 |
description: "TERMINAL — Submit structured bail assessment memo"
|
| 57 |
|
| 58 |
reward:
|
| 59 |
+
formula: "0.4*outcome_gated + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*reasoning_quality + 0.05*format - 0.3*bias"
|
| 60 |
range: [-0.7, 1.15]
|
| 61 |
terminal_action: submit_memo
|
| 62 |
deterministic: true
|
| 63 |
llm_as_judge: false
|
| 64 |
components:
|
| 65 |
+
- outcome_match: "Agreement with real High Court decision, gated by reasoning quality (40%)"
|
| 66 |
- flight_risk_accuracy: "Flight risk classification accuracy (20%)"
|
| 67 |
+
- statutory_accuracy: "IPC/BNSS threshold computation with direction gate (20%)"
|
| 68 |
- condition_appropriateness: "Bail condition quality (20%)"
|
| 69 |
+
- reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10% bonus)"
|
| 70 |
+
- format_compliance: "XML tag adherence matching system prompt structure (5% bonus)"
|
| 71 |
- bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
|
| 72 |
|
| 73 |
curriculum:
|
|
|
|
| 77 |
stage_3: "Bias reversal / parity cases"
|
| 78 |
stage_4: "Schema drift (IPC→BNSS, regional FIR formats)"
|
| 79 |
|
| 80 |
+
self_improvement:
|
| 81 |
+
adaptive_curriculum:
|
| 82 |
+
description: >
|
| 83 |
+
Performance-gated stage promotion using exponential moving averages.
|
| 84 |
+
Agent auto-promotes when per-stage EMA exceeds threshold.
|
| 85 |
+
thresholds:
|
| 86 |
+
stage_1_to_2: {min_reward: 0.65, min_episodes: 20}
|
| 87 |
+
stage_2_to_3: {min_reward: 0.55, min_episodes: 50}
|
| 88 |
+
stage_3_to_4: {min_reward: 0.50, min_episodes: 20}
|
| 89 |
+
weakness_targeting:
|
| 90 |
+
description: >
|
| 91 |
+
Adaptive episode selection identifies the crime type with lowest EMA
|
| 92 |
+
reward and serves proportionally more cases from that domain.
|
| 93 |
+
strategy: "60% weakest domain / 30% failure replay / 10% exploration"
|
| 94 |
+
synthetic_generation:
|
| 95 |
+
description: >
|
| 96 |
+
When agent masters a domain (EMA > 0.70), generates harder synthetic
|
| 97 |
+
variants using 5 perturbation types.
|
| 98 |
+
perturbation_types:
|
| 99 |
+
- custody_escalation
|
| 100 |
+
- co_accused_conflict
|
| 101 |
+
- section_ambiguity
|
| 102 |
+
- evidence_reversal
|
| 103 |
+
- surety_complexity
|
| 104 |
+
|
| 105 |
+
endpoints:
|
| 106 |
+
- path: /reset
|
| 107 |
+
method: POST
|
| 108 |
+
description: "Start a new episode. Supports adaptive=true and auto_stage=true for Theme 4."
|
| 109 |
+
- path: /step
|
| 110 |
+
method: POST
|
| 111 |
+
description: "Submit a tool call or final memo. Updates performance tracker when done."
|
| 112 |
+
- path: /state
|
| 113 |
+
method: GET
|
| 114 |
+
description: "Inspect current episode state."
|
| 115 |
+
- path: /health
|
| 116 |
+
method: GET
|
| 117 |
+
description: "Health check."
|
| 118 |
+
- path: /tools
|
| 119 |
+
method: GET
|
| 120 |
+
description: "List available tools."
|
| 121 |
+
- path: /profile
|
| 122 |
+
method: GET
|
| 123 |
+
description: "Get agent performance profile for a session (Theme 4)."
|
| 124 |
+
- path: /adaptive_status
|
| 125 |
+
method: GET
|
| 126 |
+
description: "Get global adaptive mode capabilities and thresholds."
|
| 127 |
+
- path: /ws/{session_id}
|
| 128 |
+
method: WS
|
| 129 |
+
description: "WebSocket real-time feed."
|
| 130 |
+
|
| 131 |
training:
|
| 132 |
method: GRPO
|
| 133 |
framework: TRL + Unsloth
|
| 134 |
model: unsloth/Qwen2.5-7B-Instruct
|
| 135 |
notebook: training/UndertriAI_GRPO_Training.ipynb
|
| 136 |
script: training/train_grpo.py
|
| 137 |
+
modes:
|
| 138 |
+
- name: single_stage
|
| 139 |
+
command: "python training/train_grpo.py --stage 1 --steps 200"
|
| 140 |
+
- name: curriculum
|
| 141 |
+
command: "python training/train_grpo.py --curriculum --steps 150"
|
| 142 |
+
- name: adaptive
|
| 143 |
+
command: "python training/train_grpo.py --adaptive --steps 50 --env_url http://localhost:8000"
|
| 144 |
|
| 145 |
deployment:
|
| 146 |
platform: huggingface-spaces
|
server/adaptive_selector.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UndertriAI — Adaptive Episode Selector (Theme 4: Self-Improvement)
|
| 3 |
+
|
| 4 |
+
Wraps the existing BailDataset to provide performance-aware episode
|
| 5 |
+
selection when adaptive mode is enabled. Falls back to uniform random
|
| 6 |
+
(identical to existing behavior) when adaptive=False.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
from .performance_tracker import PerformanceTracker
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AdaptiveSelector:
|
| 16 |
+
"""
|
| 17 |
+
Performance-aware episode selector.
|
| 18 |
+
|
| 19 |
+
Selection strategy (applied in order when adaptive=True):
|
| 20 |
+
60%: sample from the weakest crime-type domain in current_stage
|
| 21 |
+
30%: replay cases where recent performance was poor (reward < 0.40)
|
| 22 |
+
10%: uniform random from current_stage (exploration)
|
| 23 |
+
|
| 24 |
+
Always returns a valid episode dict. Never raises.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, dataset, tracker: PerformanceTracker):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
dataset: BailDataset instance (has _episodes, sample_episode)
|
| 31 |
+
tracker: PerformanceTracker instance driving selection
|
| 32 |
+
"""
|
| 33 |
+
self.dataset = dataset
|
| 34 |
+
self.tracker = tracker
|
| 35 |
+
|
| 36 |
+
# ------------------------------------------------------------------
|
| 37 |
+
# Public API
|
| 38 |
+
# ------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def select_episode(self, current_stage: int) -> Dict[str, Any]:
|
| 41 |
+
"""
|
| 42 |
+
Performance-aware selection for adaptive mode.
|
| 43 |
+
|
| 44 |
+
60% weakest domain → 30% failure replay → 10% exploration.
|
| 45 |
+
Falls back to uniform on any failure.
|
| 46 |
+
"""
|
| 47 |
+
try:
|
| 48 |
+
roll = random.random()
|
| 49 |
+
|
| 50 |
+
if roll < 0.60:
|
| 51 |
+
# Try weakest domain
|
| 52 |
+
ep = self._select_weakest_domain(current_stage)
|
| 53 |
+
if ep is not None:
|
| 54 |
+
return ep
|
| 55 |
+
|
| 56 |
+
if roll < 0.90:
|
| 57 |
+
# Try failure replay
|
| 58 |
+
ep = self._select_failure_replay(current_stage)
|
| 59 |
+
if ep is not None:
|
| 60 |
+
return ep
|
| 61 |
+
|
| 62 |
+
# 10% exploration or fallback
|
| 63 |
+
return self.select_episode_uniform(current_stage)
|
| 64 |
+
|
| 65 |
+
except Exception:
|
| 66 |
+
# Absolute fallback — never crash
|
| 67 |
+
return self.select_episode_uniform(current_stage)
|
| 68 |
+
|
| 69 |
+
def select_episode_uniform(self, current_stage: int) -> Dict[str, Any]:
|
| 70 |
+
"""
|
| 71 |
+
Pure random selection from current_stage.
|
| 72 |
+
Identical to existing BailDataset.sample_episode() behavior.
|
| 73 |
+
"""
|
| 74 |
+
return self.dataset.sample_episode(stage=current_stage)
|
| 75 |
+
|
| 76 |
+
# ------------------------------------------------------------------
|
| 77 |
+
# Internal strategies
|
| 78 |
+
# ------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
def _select_weakest_domain(
|
| 81 |
+
self, current_stage: int
|
| 82 |
+
) -> Optional[Dict[str, Any]]:
|
| 83 |
+
"""
|
| 84 |
+
Select an episode from the weakest crime-type domain.
|
| 85 |
+
Returns None if no weak domain identified or no matching episodes.
|
| 86 |
+
"""
|
| 87 |
+
weak_domain = self.tracker.weakest_domain()
|
| 88 |
+
if weak_domain is None:
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
# Find episodes matching this crime type in the current stage
|
| 92 |
+
episodes = self._get_stage_episodes(current_stage)
|
| 93 |
+
matches = [
|
| 94 |
+
ep for ep in episodes
|
| 95 |
+
if str(ep.get("crime_type", "")).strip() == weak_domain
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
if not matches:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
return random.choice(matches)
|
| 102 |
+
|
| 103 |
+
def _select_failure_replay(
|
| 104 |
+
self, current_stage: int
|
| 105 |
+
) -> Optional[Dict[str, Any]]:
|
| 106 |
+
"""
|
| 107 |
+
Replay a case where the agent recently scored below 0.40.
|
| 108 |
+
Returns None if no recent failures or no matching episodes.
|
| 109 |
+
"""
|
| 110 |
+
failed_ids = self.tracker.get_recent_failures(threshold=0.40)
|
| 111 |
+
if not failed_ids:
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
# Find episodes matching failed case_ids in current stage
|
| 115 |
+
episodes = self._get_stage_episodes(current_stage)
|
| 116 |
+
matches = [
|
| 117 |
+
ep for ep in episodes
|
| 118 |
+
if ep.get("case_id", "") in failed_ids
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
if not matches:
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
return random.choice(matches)
|
| 125 |
+
|
| 126 |
+
def _get_stage_episodes(self, stage: int) -> List[Dict[str, Any]]:
|
| 127 |
+
"""Get all episodes for a given stage from the dataset."""
|
| 128 |
+
try:
|
| 129 |
+
eps = self.dataset._episodes.get(stage, [])
|
| 130 |
+
if eps:
|
| 131 |
+
return eps
|
| 132 |
+
# Fallback chain matching BailDataset.sample_episode
|
| 133 |
+
for candidate in [stage - 1, stage + 1, 1, 2, 3, 4]:
|
| 134 |
+
if 1 <= candidate <= 4:
|
| 135 |
+
eps = self.dataset._episodes.get(candidate, [])
|
| 136 |
+
if eps:
|
| 137 |
+
return eps
|
| 138 |
+
except Exception:
|
| 139 |
+
pass
|
| 140 |
+
return []
|
server/app.py
CHANGED
|
@@ -5,16 +5,40 @@ Wraps UndertriAIEnvironment as an OpenEnv-compatible HTTP + WebSocket server.
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from fastapi.responses import JSONResponse, HTMLResponse
|
| 11 |
import json
|
| 12 |
import uuid
|
| 13 |
-
|
| 14 |
|
| 15 |
from .undertrial_environment import UndertriAIEnvironment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
| 18 |
_sessions: dict = {}
|
| 19 |
|
| 20 |
app = FastAPI(
|
|
@@ -33,14 +57,16 @@ app.add_middleware(
|
|
| 33 |
EPISODES_DIR = os.environ.get("UNDERTRIAL_EPISODES_DIR", None)
|
| 34 |
|
| 35 |
|
| 36 |
-
def
|
|
|
|
| 37 |
if session_id not in _sessions:
|
| 38 |
-
|
|
|
|
| 39 |
return _sessions[session_id]
|
| 40 |
|
| 41 |
|
| 42 |
# ------------------------------------------------------------------
|
| 43 |
-
# REST endpoints
|
| 44 |
# ------------------------------------------------------------------
|
| 45 |
|
| 46 |
@app.get("/", response_class=HTMLResponse)
|
|
@@ -69,12 +95,46 @@ def health():
|
|
| 69 |
|
| 70 |
|
| 71 |
@app.post("/reset")
|
| 72 |
-
def reset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if session_id is None:
|
| 74 |
session_id = str(uuid.uuid4())
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
return {
|
| 79 |
"session_id": session_id,
|
| 80 |
"observation": obs.model_dump(),
|
|
@@ -91,7 +151,8 @@ def step(payload: dict):
|
|
| 91 |
if not session_id or session_id not in _sessions:
|
| 92 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id. Call /reset first."})
|
| 93 |
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
# Deserialize action by tool_name
|
| 97 |
tool_name = action_data.get("tool_name", "")
|
|
@@ -126,7 +187,35 @@ def step(payload: dict):
|
|
| 126 |
except Exception as e:
|
| 127 |
return JSONResponse(status_code=422, content={"error": str(e)})
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
result = env.step(action)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
return {
|
| 131 |
"session_id": session_id,
|
| 132 |
"observation": result.observation.model_dump(),
|
|
@@ -140,7 +229,7 @@ def step(payload: dict):
|
|
| 140 |
def state(session_id: str):
|
| 141 |
if session_id not in _sessions:
|
| 142 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
|
| 143 |
-
return _sessions[session_id].state
|
| 144 |
|
| 145 |
|
| 146 |
@app.get("/observation")
|
|
@@ -148,7 +237,7 @@ def observation(session_id: str):
|
|
| 148 |
"""OpenEnv spec alias for /state — returns current episode observation."""
|
| 149 |
if session_id not in _sessions:
|
| 150 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
|
| 151 |
-
return _sessions[session_id].state
|
| 152 |
|
| 153 |
|
| 154 |
@app.get("/tools")
|
|
@@ -171,6 +260,48 @@ def list_tools():
|
|
| 171 |
}
|
| 172 |
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
# ------------------------------------------------------------------
|
| 175 |
# WebSocket endpoint (OpenEnv standard)
|
| 176 |
# ------------------------------------------------------------------
|
|
@@ -178,7 +309,8 @@ def list_tools():
|
|
| 178 |
@app.websocket("/ws/{session_id}")
|
| 179 |
async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
| 180 |
await websocket.accept()
|
| 181 |
-
|
|
|
|
| 182 |
try:
|
| 183 |
while True:
|
| 184 |
data = await websocket.receive_text()
|
|
|
|
| 5 |
|
| 6 |
import os
|
| 7 |
from pathlib import Path
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
from fastapi.responses import JSONResponse, HTMLResponse
|
| 12 |
import json
|
| 13 |
import uuid
|
| 14 |
+
from typing import List, Optional
|
| 15 |
|
| 16 |
from .undertrial_environment import UndertriAIEnvironment
|
| 17 |
+
from .performance_tracker import PerformanceTracker
|
| 18 |
+
from .adaptive_selector import AdaptiveSelector
|
| 19 |
+
from .case_generator import generate_variants
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ------------------------------------------------------------------
|
| 23 |
+
# Session state
|
| 24 |
+
# ------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class SessionState:
|
| 28 |
+
"""Per-session state wrapping the environment + Theme 4 components."""
|
| 29 |
+
env: UndertriAIEnvironment
|
| 30 |
+
tracker: PerformanceTracker = field(default_factory=PerformanceTracker)
|
| 31 |
+
adaptive: bool = False
|
| 32 |
+
selector: Optional[AdaptiveSelector] = None
|
| 33 |
+
tools_used: List[str] = field(default_factory=list)
|
| 34 |
+
synthetic_cases_generated: int = 0
|
| 35 |
+
|
| 36 |
+
def __post_init__(self):
|
| 37 |
+
if self.selector is None:
|
| 38 |
+
self.selector = AdaptiveSelector(self.env.dataset, self.tracker)
|
| 39 |
|
| 40 |
+
|
| 41 |
+
# Session store: session_id → SessionState
|
| 42 |
_sessions: dict = {}
|
| 43 |
|
| 44 |
app = FastAPI(
|
|
|
|
| 57 |
EPISODES_DIR = os.environ.get("UNDERTRIAL_EPISODES_DIR", None)
|
| 58 |
|
| 59 |
|
| 60 |
+
def get_or_create_session(session_id: str) -> SessionState:
|
| 61 |
+
"""Get existing session or create new one with all Theme 4 components."""
|
| 62 |
if session_id not in _sessions:
|
| 63 |
+
env = UndertriAIEnvironment(episodes_dir=EPISODES_DIR)
|
| 64 |
+
_sessions[session_id] = SessionState(env=env)
|
| 65 |
return _sessions[session_id]
|
| 66 |
|
| 67 |
|
| 68 |
# ------------------------------------------------------------------
|
| 69 |
+
# REST endpoints (existing — preserved exactly)
|
| 70 |
# ------------------------------------------------------------------
|
| 71 |
|
| 72 |
@app.get("/", response_class=HTMLResponse)
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
@app.post("/reset")
|
| 98 |
+
def reset(
|
| 99 |
+
stage: int = 1,
|
| 100 |
+
session_id: str = None,
|
| 101 |
+
seed: int = None,
|
| 102 |
+
episode_id: str = None,
|
| 103 |
+
adaptive: bool = False,
|
| 104 |
+
auto_stage: bool = False,
|
| 105 |
+
):
|
| 106 |
if session_id is None:
|
| 107 |
session_id = str(uuid.uuid4())
|
| 108 |
+
|
| 109 |
+
session = get_or_create_session(session_id)
|
| 110 |
+
env = session.env
|
| 111 |
+
session.adaptive = adaptive
|
| 112 |
+
session.tools_used = [] # Reset tools tracking
|
| 113 |
+
|
| 114 |
+
# Auto-stage: use tracker's suggestion
|
| 115 |
+
effective_stage = stage
|
| 116 |
+
if auto_stage:
|
| 117 |
+
effective_stage = session.tracker.suggest_next_stage()
|
| 118 |
+
|
| 119 |
+
env.set_stage(effective_stage)
|
| 120 |
+
|
| 121 |
+
# Adaptive episode selection
|
| 122 |
+
if adaptive and episode_id is None and seed is None:
|
| 123 |
+
# Use adaptive selector instead of uniform random
|
| 124 |
+
selected_ep = session.selector.select_episode(effective_stage)
|
| 125 |
+
# Inject the selected episode directly into the environment
|
| 126 |
+
env._episode = selected_ep
|
| 127 |
+
env._episode_id = str(uuid.uuid4())
|
| 128 |
+
env._step_count = 0
|
| 129 |
+
env._flags = []
|
| 130 |
+
env._retrieved_precedents = []
|
| 131 |
+
env._action_history = []
|
| 132 |
+
env._statutory_tool_called = False
|
| 133 |
+
env._tools_called = set()
|
| 134 |
+
obs = env._make_observation(action_result=None)
|
| 135 |
+
else:
|
| 136 |
+
obs = env.reset(stage=effective_stage, seed=seed, episode_id=episode_id)
|
| 137 |
+
|
| 138 |
return {
|
| 139 |
"session_id": session_id,
|
| 140 |
"observation": obs.model_dump(),
|
|
|
|
| 151 |
if not session_id or session_id not in _sessions:
|
| 152 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id. Call /reset first."})
|
| 153 |
|
| 154 |
+
session = _sessions[session_id]
|
| 155 |
+
env = session.env
|
| 156 |
|
| 157 |
# Deserialize action by tool_name
|
| 158 |
tool_name = action_data.get("tool_name", "")
|
|
|
|
| 187 |
except Exception as e:
|
| 188 |
return JSONResponse(status_code=422, content={"error": str(e)})
|
| 189 |
|
| 190 |
+
# Track tool usage for this session
|
| 191 |
+
if tool_name != "submit_memo":
|
| 192 |
+
session.tools_used.append(tool_name)
|
| 193 |
+
|
| 194 |
result = env.step(action)
|
| 195 |
+
|
| 196 |
+
# Theme 4: Update tracker after terminal action (reward available)
|
| 197 |
+
if result.done and hasattr(result, "info") and isinstance(result.info, dict):
|
| 198 |
+
reward_components = result.info
|
| 199 |
+
episode = env._episode or {}
|
| 200 |
+
session.tracker.update(
|
| 201 |
+
episode=episode,
|
| 202 |
+
reward_components=reward_components,
|
| 203 |
+
tools_used=list(session.tools_used),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Generate synthetic cases if agent mastered this domain
|
| 207 |
+
if session.adaptive:
|
| 208 |
+
crime_type = episode.get("crime_type", "")
|
| 209 |
+
if crime_type and session.tracker.should_generate_synthetic(crime_type):
|
| 210 |
+
variants = generate_variants(episode, n=3)
|
| 211 |
+
if variants:
|
| 212 |
+
# Inject synthetic cases into the dataset
|
| 213 |
+
stage = episode.get("curriculum_stage", 1)
|
| 214 |
+
for v in variants:
|
| 215 |
+
v["curriculum_stage"] = stage
|
| 216 |
+
env.dataset._episodes.setdefault(stage, []).append(v)
|
| 217 |
+
session.synthetic_cases_generated += len(variants)
|
| 218 |
+
|
| 219 |
return {
|
| 220 |
"session_id": session_id,
|
| 221 |
"observation": result.observation.model_dump(),
|
|
|
|
| 229 |
def state(session_id: str):
|
| 230 |
if session_id not in _sessions:
|
| 231 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
|
| 232 |
+
return _sessions[session_id].env.state
|
| 233 |
|
| 234 |
|
| 235 |
@app.get("/observation")
|
|
|
|
| 237 |
"""OpenEnv spec alias for /state — returns current episode observation."""
|
| 238 |
if session_id not in _sessions:
|
| 239 |
return JSONResponse(status_code=400, content={"error": "Invalid session_id."})
|
| 240 |
+
return _sessions[session_id].env.state
|
| 241 |
|
| 242 |
|
| 243 |
@app.get("/tools")
|
|
|
|
| 260 |
}
|
| 261 |
|
| 262 |
|
| 263 |
+
# ------------------------------------------------------------------
|
| 264 |
+
# Theme 4: New API endpoints (additive — do not replace existing)
|
| 265 |
+
# ------------------------------------------------------------------
|
| 266 |
+
|
| 267 |
+
@app.get("/profile")
|
| 268 |
+
def get_profile(session_id: str):
|
| 269 |
+
"""Returns the current PerformanceTracker profile for the session."""
|
| 270 |
+
if session_id not in _sessions:
|
| 271 |
+
return JSONResponse(
|
| 272 |
+
status_code=404,
|
| 273 |
+
content={"error": f"Session '{session_id}' not found. Call /reset first."},
|
| 274 |
+
)
|
| 275 |
+
session = _sessions[session_id]
|
| 276 |
+
return {
|
| 277 |
+
"session_id": session_id,
|
| 278 |
+
"profile": session.tracker.get_profile(),
|
| 279 |
+
"adaptive_mode": session.adaptive,
|
| 280 |
+
"synthetic_cases_generated": session.synthetic_cases_generated,
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@app.get("/adaptive_status")
|
| 285 |
+
def adaptive_status():
|
| 286 |
+
"""Returns global adaptive mode capabilities (not session-specific)."""
|
| 287 |
+
return {
|
| 288 |
+
"adaptive_available": True,
|
| 289 |
+
"description": "Performance-aware episode selection and synthetic case generation",
|
| 290 |
+
"promotion_thresholds": {
|
| 291 |
+
"stage_1_to_2": {"min_reward": 0.65, "min_episodes": 20},
|
| 292 |
+
"stage_2_to_3": {"min_reward": 0.55, "min_episodes": 50},
|
| 293 |
+
"stage_3_to_4": {"min_reward": 0.50, "min_episodes": 20},
|
| 294 |
+
},
|
| 295 |
+
"perturbation_types": [
|
| 296 |
+
"custody_escalation",
|
| 297 |
+
"co_accused_conflict",
|
| 298 |
+
"section_ambiguity",
|
| 299 |
+
"evidence_reversal",
|
| 300 |
+
"surety_complexity",
|
| 301 |
+
],
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
|
| 305 |
# ------------------------------------------------------------------
|
| 306 |
# WebSocket endpoint (OpenEnv standard)
|
| 307 |
# ------------------------------------------------------------------
|
|
|
|
| 309 |
@app.websocket("/ws/{session_id}")
|
| 310 |
async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
| 311 |
await websocket.accept()
|
| 312 |
+
session = get_or_create_session(session_id)
|
| 313 |
+
env = session.env
|
| 314 |
try:
|
| 315 |
while True:
|
| 316 |
data = await websocket.receive_text()
|
server/case_generator.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UndertriAI — Synthetic Case Generator (Theme 4: Self-Improvement)
|
| 3 |
+
|
| 4 |
+
When the agent masters a domain, this generates harder synthetic variants
|
| 5 |
+
of existing cases. All generation is deterministic string manipulation —
|
| 6 |
+
no LLM calls.
|
| 7 |
+
|
| 8 |
+
5 perturbation types:
|
| 9 |
+
1. custody_escalation — custody just below statutory threshold
|
| 10 |
+
2. co_accused_conflict — co-accused with opposite bail outcome
|
| 11 |
+
3. section_ambiguity — IPC ↔ BNSS section swap
|
| 12 |
+
4. evidence_reversal — retracted witness / unreliable evidence
|
| 13 |
+
5. surety_complexity — non-resident surety complication
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import re
|
| 18 |
+
from typing import Any, Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# IPC → BNSS mapping (subset used by the environment)
|
| 22 |
+
IPC_TO_BNSS = {
|
| 23 |
+
"302": "103", "307": "109", "376": "64", "304B": "80", "395": "310",
|
| 24 |
+
"392": "309", "420": "318", "498A": "85", "406": "316", "465": "336",
|
| 25 |
+
"323": "115", "354": "74", "120B": "61", "506": "351", "121": "147",
|
| 26 |
+
"379": "303", "324": "117", "354A": "75",
|
| 27 |
+
}
|
| 28 |
+
BNSS_TO_IPC = {v: k for k, v in IPC_TO_BNSS.items()}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Required fields for schema validation ────────────────────────────
|
| 32 |
+
|
| 33 |
+
REQUIRED_FIELDS = {
|
| 34 |
+
"case_id": str,
|
| 35 |
+
"crime_type": str,
|
| 36 |
+
"ipc_sections": list,
|
| 37 |
+
"custody_months": (int, float),
|
| 38 |
+
"charge_sheet": str,
|
| 39 |
+
"ground_truth": dict,
|
| 40 |
+
"curriculum_stage": (int, float),
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def is_schema_valid(episode: Dict[str, Any]) -> bool:
|
| 45 |
+
"""
|
| 46 |
+
Check that all required fields are present and correct types.
|
| 47 |
+
Returns True/False — used to filter out malformed synthetic cases.
|
| 48 |
+
"""
|
| 49 |
+
for field, expected_type in REQUIRED_FIELDS.items():
|
| 50 |
+
if field not in episode:
|
| 51 |
+
return False
|
| 52 |
+
if not isinstance(episode[field], expected_type):
|
| 53 |
+
return False
|
| 54 |
+
|
| 55 |
+
# ground_truth must have 'outcome'
|
| 56 |
+
gt = episode.get("ground_truth", {})
|
| 57 |
+
if "outcome" not in gt:
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
return True
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def generate_variants(
|
| 64 |
+
source_episode: Dict[str, Any],
|
| 65 |
+
n: int = 5,
|
| 66 |
+
) -> List[Dict[str, Any]]:
|
| 67 |
+
"""
|
| 68 |
+
Generate up to n synthetic harder variants of a real episode.
|
| 69 |
+
Each variant applies exactly ONE perturbation.
|
| 70 |
+
|
| 71 |
+
Returns only valid variants (may be fewer than n if some
|
| 72 |
+
perturbations can't be applied cleanly).
|
| 73 |
+
"""
|
| 74 |
+
if not is_schema_valid(source_episode):
|
| 75 |
+
return []
|
| 76 |
+
|
| 77 |
+
perturbations = [
|
| 78 |
+
_custody_escalation,
|
| 79 |
+
_co_accused_conflict,
|
| 80 |
+
_section_ambiguity,
|
| 81 |
+
_evidence_reversal,
|
| 82 |
+
_surety_complexity,
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
variants = []
|
| 86 |
+
for i, perturb_fn in enumerate(perturbations[:n]):
|
| 87 |
+
try:
|
| 88 |
+
variant = perturb_fn(source_episode)
|
| 89 |
+
if variant is not None and is_schema_valid(variant):
|
| 90 |
+
variants.append(variant)
|
| 91 |
+
except Exception:
|
| 92 |
+
# Skip perturbation on any error
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
return variants
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ── Perturbation 1: Custody Escalation ───────────────────────────────
|
| 99 |
+
|
| 100 |
+
def _custody_escalation(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 101 |
+
"""
|
| 102 |
+
Set custody_months to exactly 2 months below the statutory threshold.
|
| 103 |
+
Forces careful computation — case is NOT yet eligible for default bail.
|
| 104 |
+
"""
|
| 105 |
+
ep = copy.deepcopy(episode)
|
| 106 |
+
max_sent = ep.get("max_sentence_years", 5.0)
|
| 107 |
+
|
| 108 |
+
# Threshold is 50% of max sentence in months
|
| 109 |
+
threshold_months = (max_sent * 12) / 2
|
| 110 |
+
new_custody = max(1.0, threshold_months - 2.0)
|
| 111 |
+
|
| 112 |
+
old_custody = ep.get("custody_months", 0)
|
| 113 |
+
ep["custody_months"] = round(new_custody, 1)
|
| 114 |
+
|
| 115 |
+
# Update charge sheet text if it mentions custody duration
|
| 116 |
+
charge = ep.get("charge_sheet", "")
|
| 117 |
+
if str(int(old_custody)) in charge:
|
| 118 |
+
charge = charge.replace(
|
| 119 |
+
f"{int(old_custody)} months",
|
| 120 |
+
f"{int(new_custody)} months",
|
| 121 |
+
)
|
| 122 |
+
ep["charge_sheet"] = charge
|
| 123 |
+
|
| 124 |
+
# Metadata
|
| 125 |
+
parent_id = ep.get("case_id", "UNKNOWN")
|
| 126 |
+
ep["case_id"] = f"SYN_{parent_id}_CUST"
|
| 127 |
+
ep["source"] = "synthetic"
|
| 128 |
+
ep["parent_case_id"] = parent_id
|
| 129 |
+
ep["perturbation_type"] = "custody_escalation"
|
| 130 |
+
ep["difficulty"] = "hard"
|
| 131 |
+
|
| 132 |
+
return ep
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ── Perturbation 2: Co-Accused Conflict ──────────────────────────────
|
| 136 |
+
|
| 137 |
+
def _co_accused_conflict(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 138 |
+
"""
|
| 139 |
+
Add a co-accused with the OPPOSITE bail outcome.
|
| 140 |
+
Forces the agent to make a parity argument.
|
| 141 |
+
"""
|
| 142 |
+
ep = copy.deepcopy(episode)
|
| 143 |
+
gt = ep.get("ground_truth", {})
|
| 144 |
+
gt_outcome = gt.get("outcome", "Bail Granted")
|
| 145 |
+
|
| 146 |
+
# Opposite outcome
|
| 147 |
+
if "grant" in gt_outcome.lower():
|
| 148 |
+
co_outcome = "Bail Denied"
|
| 149 |
+
else:
|
| 150 |
+
co_outcome = "Bail Granted"
|
| 151 |
+
|
| 152 |
+
ep["co_accused"] = [{
|
| 153 |
+
"name": "Co-Accused A",
|
| 154 |
+
"bail_outcome": co_outcome,
|
| 155 |
+
"sections": ep.get("ipc_sections", []),
|
| 156 |
+
}]
|
| 157 |
+
|
| 158 |
+
gt["parity_argument_used"] = True
|
| 159 |
+
ep["ground_truth"] = gt
|
| 160 |
+
|
| 161 |
+
# Add parity context to defence arguments
|
| 162 |
+
defence = ep.get("defence_arguments", [])
|
| 163 |
+
defence.append(
|
| 164 |
+
f"Co-accused was {'granted' if 'grant' in co_outcome.lower() else 'denied'} "
|
| 165 |
+
f"bail under identical charges — parity principle applies."
|
| 166 |
+
)
|
| 167 |
+
ep["defence_arguments"] = defence
|
| 168 |
+
|
| 169 |
+
# Metadata
|
| 170 |
+
parent_id = ep.get("case_id", "UNKNOWN")
|
| 171 |
+
ep["case_id"] = f"SYN_{parent_id}_COAC"
|
| 172 |
+
ep["source"] = "synthetic"
|
| 173 |
+
ep["parent_case_id"] = parent_id
|
| 174 |
+
ep["perturbation_type"] = "co_accused_conflict"
|
| 175 |
+
ep["difficulty"] = "hard"
|
| 176 |
+
|
| 177 |
+
return ep
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ── Perturbation 3: Section Ambiguity (IPC ↔ BNSS) ──────────────────
|
| 181 |
+
|
| 182 |
+
def _section_ambiguity(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 183 |
+
"""
|
| 184 |
+
Swap IPC sections to BNSS equivalents (or vice versa).
|
| 185 |
+
Tests schema drift adaptability.
|
| 186 |
+
"""
|
| 187 |
+
ep = copy.deepcopy(episode)
|
| 188 |
+
sections = ep.get("ipc_sections", [])
|
| 189 |
+
|
| 190 |
+
if not sections:
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
new_sections = []
|
| 194 |
+
swapped = False
|
| 195 |
+
for sec in sections:
|
| 196 |
+
sec_clean = sec.strip()
|
| 197 |
+
if sec_clean in IPC_TO_BNSS:
|
| 198 |
+
new_sections.append(IPC_TO_BNSS[sec_clean])
|
| 199 |
+
swapped = True
|
| 200 |
+
elif sec_clean in BNSS_TO_IPC:
|
| 201 |
+
new_sections.append(BNSS_TO_IPC[sec_clean])
|
| 202 |
+
swapped = True
|
| 203 |
+
else:
|
| 204 |
+
new_sections.append(sec_clean)
|
| 205 |
+
|
| 206 |
+
if not swapped:
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
ep["ipc_sections"] = new_sections
|
| 210 |
+
|
| 211 |
+
# Update charge sheet references
|
| 212 |
+
charge = ep.get("charge_sheet", "")
|
| 213 |
+
for old_sec, new_sec in zip(sections, new_sections):
|
| 214 |
+
if old_sec != new_sec:
|
| 215 |
+
charge = charge.replace(f"Section {old_sec}", f"Section {new_sec}")
|
| 216 |
+
charge = charge.replace(f"section {old_sec}", f"section {new_sec}")
|
| 217 |
+
ep["charge_sheet"] = charge
|
| 218 |
+
|
| 219 |
+
# Metadata
|
| 220 |
+
parent_id = ep.get("case_id", "UNKNOWN")
|
| 221 |
+
ep["case_id"] = f"SYN_{parent_id}_SECT"
|
| 222 |
+
ep["source"] = "synthetic"
|
| 223 |
+
ep["parent_case_id"] = parent_id
|
| 224 |
+
ep["perturbation_type"] = "section_ambiguity"
|
| 225 |
+
ep["difficulty"] = "hard"
|
| 226 |
+
|
| 227 |
+
return ep
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ── Perturbation 4: Evidence Reversal ────────────────────────────────
|
| 231 |
+
|
| 232 |
+
def _evidence_reversal(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 233 |
+
"""
|
| 234 |
+
Add a contradicting element to the strongest evidence.
|
| 235 |
+
Tests whether the agent updates assessment when evidence weakens.
|
| 236 |
+
"""
|
| 237 |
+
ep = copy.deepcopy(episode)
|
| 238 |
+
|
| 239 |
+
# Find the strongest evidence mention
|
| 240 |
+
evidence_keywords = ["witness", "evidence", "testimony", "eyewitness"]
|
| 241 |
+
pros_args = ep.get("prosecution_arguments", [])
|
| 242 |
+
charge = ep.get("charge_sheet", "")
|
| 243 |
+
|
| 244 |
+
# Check prosecution arguments first
|
| 245 |
+
target_arg = None
|
| 246 |
+
for arg in pros_args:
|
| 247 |
+
if any(kw in arg.lower() for kw in evidence_keywords):
|
| 248 |
+
target_arg = arg
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
+
if target_arg is None:
|
| 252 |
+
# Check charge sheet sentences
|
| 253 |
+
sentences = [s.strip() for s in charge.split('.') if s.strip()]
|
| 254 |
+
for sent in sentences:
|
| 255 |
+
if any(kw in sent.lower() for kw in evidence_keywords):
|
| 256 |
+
target_arg = sent
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
if target_arg is None:
|
| 260 |
+
return None # No evidence to reverse
|
| 261 |
+
|
| 262 |
+
# Add reversal to defence arguments
|
| 263 |
+
defence = ep.get("defence_arguments", [])
|
| 264 |
+
defence.append(
|
| 265 |
+
"However, the key prosecution evidence was subsequently found "
|
| 266 |
+
"unreliable — the primary witness retracted their statement and "
|
| 267 |
+
"forensic analysis raised doubts about the physical evidence."
|
| 268 |
+
)
|
| 269 |
+
ep["defence_arguments"] = defence
|
| 270 |
+
|
| 271 |
+
# Metadata
|
| 272 |
+
parent_id = ep.get("case_id", "UNKNOWN")
|
| 273 |
+
ep["case_id"] = f"SYN_{parent_id}_EVID"
|
| 274 |
+
ep["source"] = "synthetic"
|
| 275 |
+
ep["parent_case_id"] = parent_id
|
| 276 |
+
ep["perturbation_type"] = "evidence_reversal"
|
| 277 |
+
ep["difficulty"] = "hard"
|
| 278 |
+
|
| 279 |
+
return ep
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ── Perturbation 5: Surety Complexity ────────────────────────────────
|
| 283 |
+
|
| 284 |
+
def _surety_complexity(episode: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 285 |
+
"""
|
| 286 |
+
Add a surety complication forcing careful condition assessment.
|
| 287 |
+
"""
|
| 288 |
+
ep = copy.deepcopy(episode)
|
| 289 |
+
|
| 290 |
+
# Add surety complication to defence arguments
|
| 291 |
+
defence = ep.get("defence_arguments", [])
|
| 292 |
+
defence.append(
|
| 293 |
+
"Proposed surety is a non-resident relative with no verifiable "
|
| 294 |
+
"local assets or employment in the jurisdiction. Surety bond "
|
| 295 |
+
"amount of Rs. 5,00,000 proposed."
|
| 296 |
+
)
|
| 297 |
+
ep["defence_arguments"] = defence
|
| 298 |
+
|
| 299 |
+
# Add surety info to accused profile
|
| 300 |
+
profile = ep.get("accused_profile", {})
|
| 301 |
+
profile["surety_status"] = "non-resident, unverified assets"
|
| 302 |
+
ep["accused_profile"] = profile
|
| 303 |
+
|
| 304 |
+
# Metadata
|
| 305 |
+
parent_id = ep.get("case_id", "UNKNOWN")
|
| 306 |
+
ep["case_id"] = f"SYN_{parent_id}_SURE"
|
| 307 |
+
ep["source"] = "synthetic"
|
| 308 |
+
ep["parent_case_id"] = parent_id
|
| 309 |
+
ep["perturbation_type"] = "surety_complexity"
|
| 310 |
+
ep["difficulty"] = "hard"
|
| 311 |
+
|
| 312 |
+
return ep
|
server/performance_tracker.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UndertriAI — Performance Tracker (Theme 4: Self-Improvement)
|
| 3 |
+
|
| 4 |
+
Tracks the agent's running performance profile across dimensions
|
| 5 |
+
and uses it to drive adaptive curriculum decisions.
|
| 6 |
+
|
| 7 |
+
Pure Python — no server/training/FastAPI dependencies.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import warnings
|
| 11 |
+
from collections import deque
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ExponentialMean:
|
| 16 |
+
"""Exponential moving average with configurable decay."""
|
| 17 |
+
|
| 18 |
+
__slots__ = ("alpha", "value", "count")
|
| 19 |
+
|
| 20 |
+
def __init__(self, alpha: float = 0.1, initial: float = 0.5):
|
| 21 |
+
self.alpha = alpha
|
| 22 |
+
self.value = initial
|
| 23 |
+
self.count = 0
|
| 24 |
+
|
| 25 |
+
def update(self, x: float) -> None:
|
| 26 |
+
self.value = self.alpha * x + (1 - self.alpha) * self.value
|
| 27 |
+
self.count += 1
|
| 28 |
+
|
| 29 |
+
def get(self) -> float:
|
| 30 |
+
return self.value
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class PerformanceTracker:
|
| 34 |
+
"""
|
| 35 |
+
Tracks agent performance across crime types, stages, and reward
|
| 36 |
+
components. Drives adaptive episode selection and stage promotion.
|
| 37 |
+
|
| 38 |
+
Thread-safe for single-session use (no locks needed).
|
| 39 |
+
All public methods handle missing/malformed input gracefully.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, alpha: float = 0.1):
|
| 43 |
+
self._alpha = alpha
|
| 44 |
+
|
| 45 |
+
# Per-crime-type EMA of total reward
|
| 46 |
+
self.per_crime_type: Dict[str, ExponentialMean] = {}
|
| 47 |
+
|
| 48 |
+
# Per-stage EMA of total reward
|
| 49 |
+
self.per_stage: Dict[int, ExponentialMean] = {
|
| 50 |
+
s: ExponentialMean(alpha=alpha) for s in range(1, 5)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# Last 50 total rewards (for stage promotion smoothing)
|
| 54 |
+
self.recent_rewards: deque = deque(maxlen=50)
|
| 55 |
+
|
| 56 |
+
# Bias fire rate: 1.0 when penalty fired, 0.0 when not
|
| 57 |
+
self.bias_fire_rate: ExponentialMean = ExponentialMean(alpha=alpha)
|
| 58 |
+
|
| 59 |
+
# Tool usage counts (cumulative per session)
|
| 60 |
+
self.tool_usage: Dict[str, int] = {}
|
| 61 |
+
|
| 62 |
+
# Episode counters
|
| 63 |
+
self.episodes_seen: int = 0
|
| 64 |
+
self.stage_episodes: Dict[int, int] = {1: 0, 2: 0, 3: 0, 4: 0}
|
| 65 |
+
|
| 66 |
+
# Recent case performance for failure-replay
|
| 67 |
+
self._recent_case_rewards: deque = deque(maxlen=30)
|
| 68 |
+
|
| 69 |
+
# ------------------------------------------------------------------
|
| 70 |
+
# Core update
|
| 71 |
+
# ------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
def update(
|
| 74 |
+
self,
|
| 75 |
+
episode: Dict[str, Any],
|
| 76 |
+
reward_components: Dict[str, Any],
|
| 77 |
+
tools_used: Optional[List[str]] = None,
|
| 78 |
+
) -> None:
|
| 79 |
+
"""
|
| 80 |
+
Update all internal state from a completed episode.
|
| 81 |
+
|
| 82 |
+
Handles missing keys gracefully — never raises on malformed input.
|
| 83 |
+
"""
|
| 84 |
+
try:
|
| 85 |
+
total = float(reward_components.get("total_reward",
|
| 86 |
+
reward_components.get("total", 0.0)))
|
| 87 |
+
except (TypeError, ValueError):
|
| 88 |
+
total = 0.0
|
| 89 |
+
|
| 90 |
+
# Update recent rewards
|
| 91 |
+
self.recent_rewards.append(total)
|
| 92 |
+
self.episodes_seen += 1
|
| 93 |
+
|
| 94 |
+
# Per-crime-type tracking
|
| 95 |
+
crime_type = ""
|
| 96 |
+
try:
|
| 97 |
+
crime_type = str(episode.get("crime_type", "")).strip()
|
| 98 |
+
except Exception:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
if crime_type:
|
| 102 |
+
if crime_type not in self.per_crime_type:
|
| 103 |
+
self.per_crime_type[crime_type] = ExponentialMean(
|
| 104 |
+
alpha=self._alpha
|
| 105 |
+
)
|
| 106 |
+
self.per_crime_type[crime_type].update(total)
|
| 107 |
+
|
| 108 |
+
# Per-stage tracking
|
| 109 |
+
stage = 1
|
| 110 |
+
try:
|
| 111 |
+
stage = int(episode.get("curriculum_stage", 1))
|
| 112 |
+
except (TypeError, ValueError):
|
| 113 |
+
stage = 1
|
| 114 |
+
if 1 <= stage <= 4:
|
| 115 |
+
self.per_stage[stage].update(total)
|
| 116 |
+
self.stage_episodes[stage] = self.stage_episodes.get(stage, 0) + 1
|
| 117 |
+
|
| 118 |
+
# Bias fire rate
|
| 119 |
+
try:
|
| 120 |
+
bias_val = float(reward_components.get("bias_penalty", 0.0))
|
| 121 |
+
self.bias_fire_rate.update(1.0 if bias_val > 0.01 else 0.0)
|
| 122 |
+
except (TypeError, ValueError):
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
# Tool usage
|
| 126 |
+
if tools_used:
|
| 127 |
+
for tool in tools_used:
|
| 128 |
+
t = str(tool)
|
| 129 |
+
self.tool_usage[t] = self.tool_usage.get(t, 0) + 1
|
| 130 |
+
|
| 131 |
+
# Track case_id → reward for failure-replay
|
| 132 |
+
case_id = ""
|
| 133 |
+
try:
|
| 134 |
+
case_id = str(episode.get("case_id", ""))
|
| 135 |
+
except Exception:
|
| 136 |
+
pass
|
| 137 |
+
if case_id:
|
| 138 |
+
self._recent_case_rewards.append((case_id, total, stage))
|
| 139 |
+
|
| 140 |
+
# ------------------------------------------------------------------
|
| 141 |
+
# Queries
|
| 142 |
+
# ------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
def weakest_domain(self) -> Optional[str]:
|
| 145 |
+
"""
|
| 146 |
+
Returns the crime_type with the lowest EMA reward.
|
| 147 |
+
Returns None if fewer than 5 episodes seen total or no crime type
|
| 148 |
+
has at least 3 observations.
|
| 149 |
+
"""
|
| 150 |
+
if self.episodes_seen < 5:
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
candidates = [
|
| 154 |
+
(ct, ema.get())
|
| 155 |
+
for ct, ema in self.per_crime_type.items()
|
| 156 |
+
if ema.count >= 3
|
| 157 |
+
]
|
| 158 |
+
if not candidates:
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
return min(candidates, key=lambda x: x[1])[0]
|
| 162 |
+
|
| 163 |
+
def suggest_next_stage(self) -> int:
|
| 164 |
+
"""
|
| 165 |
+
Returns the recommended stage (1-4) based on readiness thresholds.
|
| 166 |
+
Never demotes — returns highest eligible stage.
|
| 167 |
+
"""
|
| 168 |
+
current = 1
|
| 169 |
+
|
| 170 |
+
# Stage 1 → 2: EMA >= 0.65 AND at least 20 episodes
|
| 171 |
+
if (self.per_stage[1].get() >= 0.65
|
| 172 |
+
and self.stage_episodes.get(1, 0) >= 20):
|
| 173 |
+
current = 2
|
| 174 |
+
|
| 175 |
+
# Stage 2 → 3: EMA >= 0.55 AND at least 50 episodes
|
| 176 |
+
if (current >= 2
|
| 177 |
+
and self.per_stage[2].get() >= 0.55
|
| 178 |
+
and self.stage_episodes.get(2, 0) >= 50):
|
| 179 |
+
current = 3
|
| 180 |
+
|
| 181 |
+
# Stage 3 → 4: EMA >= 0.50 AND at least 20 episodes
|
| 182 |
+
if (current >= 3
|
| 183 |
+
and self.per_stage[3].get() >= 0.50
|
| 184 |
+
and self.stage_episodes.get(3, 0) >= 20):
|
| 185 |
+
current = 4
|
| 186 |
+
|
| 187 |
+
return current
|
| 188 |
+
|
| 189 |
+
def should_generate_synthetic(self, crime_type: str) -> bool:
|
| 190 |
+
"""
|
| 191 |
+
Returns True if the agent has mastered this crime type domain
|
| 192 |
+
(EMA > 0.70 with at least 10 observations).
|
| 193 |
+
"""
|
| 194 |
+
ema = self.per_crime_type.get(crime_type)
|
| 195 |
+
if ema is None:
|
| 196 |
+
return False
|
| 197 |
+
return ema.get() > 0.70 and ema.count >= 10
|
| 198 |
+
|
| 199 |
+
def get_recent_failures(self, threshold: float = 0.40) -> List[str]:
|
| 200 |
+
"""
|
| 201 |
+
Returns case_ids from recent episodes where reward was below threshold.
|
| 202 |
+
Used by AdaptiveSelector for failure-replay.
|
| 203 |
+
"""
|
| 204 |
+
return [
|
| 205 |
+
case_id
|
| 206 |
+
for case_id, reward, _ in self._recent_case_rewards
|
| 207 |
+
if reward < threshold
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
# ------------------------------------------------------------------
|
| 211 |
+
# Serialization
|
| 212 |
+
# ------------------------------------------------------------------
|
| 213 |
+
|
| 214 |
+
def get_profile(self) -> Dict[str, Any]:
|
| 215 |
+
"""
|
| 216 |
+
Returns a fully JSON-serializable profile dict.
|
| 217 |
+
No class instances — all values are primitive types.
|
| 218 |
+
"""
|
| 219 |
+
recent = list(self.recent_rewards)
|
| 220 |
+
recent_mean = sum(recent) / len(recent) if recent else 0.0
|
| 221 |
+
|
| 222 |
+
return {
|
| 223 |
+
"per_crime_type": {
|
| 224 |
+
ct: round(ema.get(), 4)
|
| 225 |
+
for ct, ema in self.per_crime_type.items()
|
| 226 |
+
},
|
| 227 |
+
"per_stage": {
|
| 228 |
+
str(s): round(ema.get(), 4)
|
| 229 |
+
for s, ema in self.per_stage.items()
|
| 230 |
+
},
|
| 231 |
+
"bias_fire_rate": round(self.bias_fire_rate.get(), 4),
|
| 232 |
+
"tool_usage": dict(self.tool_usage),
|
| 233 |
+
"episodes_seen": self.episodes_seen,
|
| 234 |
+
"stage_episodes": dict(self.stage_episodes),
|
| 235 |
+
"weakest_domain": self.weakest_domain(),
|
| 236 |
+
"suggested_stage": self.suggest_next_stage(),
|
| 237 |
+
"recent_mean_reward": round(recent_mean, 4),
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
# ------------------------------------------------------------------
|
| 241 |
+
# Session management
|
| 242 |
+
# ------------------------------------------------------------------
|
| 243 |
+
|
| 244 |
+
def reset_session(self) -> None:
|
| 245 |
+
"""
|
| 246 |
+
Clears transient session state but preserves accumulated
|
| 247 |
+
per-crime-type and per-stage learning.
|
| 248 |
+
"""
|
| 249 |
+
self.recent_rewards.clear()
|
| 250 |
+
self.tool_usage.clear()
|
| 251 |
+
self._recent_case_rewards.clear()
|
server/reward.py
CHANGED
|
@@ -11,6 +11,27 @@ import re
|
|
| 11 |
from typing import Any, Dict, List, Optional
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# ---------------------------------------------------------------------------
|
| 15 |
# 1. Outcome Match (40%)
|
| 16 |
# ---------------------------------------------------------------------------
|
|
@@ -170,6 +191,23 @@ def compute_statutory_accuracy(
|
|
| 170 |
if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
|
| 171 |
special_laws = "INFERRED" # Treat as special-law-restricted for eligibility
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
# Compute ground-truth eligibility for cases with known custody duration
|
| 174 |
half_sent_months = (max_sent * 12) / 2
|
| 175 |
truly_eligible = (custody_mo >= half_sent_months) and not special_laws
|
|
@@ -463,6 +501,72 @@ def compute_reasoning_quality(
|
|
| 463 |
return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
|
| 464 |
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
# ---------------------------------------------------------------------------
|
| 467 |
# Master reward function
|
| 468 |
# ---------------------------------------------------------------------------
|
|
@@ -480,22 +584,28 @@ def compute_reward(
|
|
| 480 |
agent_flight_risk_justification: str = "",
|
| 481 |
agent_grounds_for: Optional[List[str]] = None,
|
| 482 |
agent_grounds_against: Optional[List[str]] = None,
|
|
|
|
|
|
|
| 483 |
) -> Dict[str, float]:
|
| 484 |
"""
|
| 485 |
Computes the full reward for a submitted bail assessment memo.
|
| 486 |
|
| 487 |
-
Formula:
|
| 488 |
-
R = 0.
|
| 489 |
+ 0.2*flight_risk_accuracy
|
| 490 |
+ 0.2*statutory_accuracy
|
| 491 |
+ 0.2*condition_appropriateness
|
| 492 |
-
+ 0.1*reasoning_quality
|
| 493 |
-
+ 0.
|
| 494 |
-
+ 0.05*
|
|
|
|
| 495 |
- 0.3*bias_penalty
|
| 496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
Returns a dict with all component scores + total_reward.
|
| 498 |
-
Range: approx [-0.4, 1.1].
|
| 499 |
"""
|
| 500 |
gt = episode["ground_truth"]
|
| 501 |
|
|
@@ -515,6 +625,18 @@ def compute_reward(
|
|
| 515 |
episode = episode,
|
| 516 |
)
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
# Efficiency bonus: reward finishing faster when the answer is correct.
|
| 519 |
# Only fires on directionally-correct outcomes (om >= 0.8) to prevent
|
| 520 |
# rewarding efficient-but-wrong agents.
|
|
@@ -526,15 +648,25 @@ def compute_reward(
|
|
| 526 |
# Process reward: +0.05 if agent actually used the statutory tool.
|
| 527 |
process_bonus = 0.05 if statutory_tool_used else 0.0
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
lam = 0.3
|
| 530 |
-
total = 0.
|
|
|
|
|
|
|
| 531 |
|
| 532 |
return {
|
| 533 |
"outcome_match": round(om, 4),
|
|
|
|
|
|
|
| 534 |
"flight_risk_accuracy": round(fr, 4),
|
| 535 |
"statutory_accuracy": round(sa, 4),
|
| 536 |
"condition_appropriateness": round(ca, 4),
|
| 537 |
"reasoning_quality": round(rq, 4),
|
|
|
|
| 538 |
"efficiency_bonus": round(efficiency, 4),
|
| 539 |
"process_bonus": round(process_bonus,4),
|
| 540 |
"bias_penalty": round(bias, 4),
|
|
|
|
| 11 |
from typing import Any, Dict, List, Optional
|
| 12 |
|
| 13 |
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# Shared helper: NDPS detection (canonical definition — import this elsewhere)
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def _is_ndps_case(episode: dict) -> bool:
|
| 19 |
+
"""
|
| 20 |
+
Detect narcotics cases even when special_laws field is empty.
|
| 21 |
+
Checks ipc_sections and crime_type for NDPS indicators.
|
| 22 |
+
|
| 23 |
+
This is the SINGLE canonical definition — import from server.reward
|
| 24 |
+
in undertrial_environment.py and training/train_grpo.py.
|
| 25 |
+
"""
|
| 26 |
+
sections = " ".join(str(s) for s in episode.get("ipc_sections", [])).lower()
|
| 27 |
+
crime = str(episode.get("crime_type", "")).lower()
|
| 28 |
+
narcotics_indicators = [
|
| 29 |
+
"ndps", "narcotic", "drug", "psychotropic",
|
| 30 |
+
"20(b)", "22(b)", "27a", "section 37",
|
| 31 |
+
]
|
| 32 |
+
return any(ind in sections or ind in crime for ind in narcotics_indicators)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
# ---------------------------------------------------------------------------
|
| 36 |
# 1. Outcome Match (40%)
|
| 37 |
# ---------------------------------------------------------------------------
|
|
|
|
| 191 |
if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
|
| 192 |
special_laws = "INFERRED" # Treat as special-law-restricted for eligibility
|
| 193 |
|
| 194 |
+
# ── B9: NDPS-specific statutory scoring ──────────────────────────────
|
| 195 |
+
# NDPS Section 37 twin conditions override standard threshold logic.
|
| 196 |
+
# Reward the agent for recognizing NDPS applies, not for arithmetic.
|
| 197 |
+
if _is_ndps_case(episode):
|
| 198 |
+
gt_granted = "grant" in gt_outcome.lower()
|
| 199 |
+
direction_correct = (agent_eligible == gt_granted)
|
| 200 |
+
ndps_recognized = any(
|
| 201 |
+
t in comp for t in ["section 37", "twin condition", "ndps", "37(1)(b)"]
|
| 202 |
+
)
|
| 203 |
+
if ndps_recognized and direction_correct:
|
| 204 |
+
return 1.0
|
| 205 |
+
elif direction_correct:
|
| 206 |
+
return 0.5
|
| 207 |
+
else:
|
| 208 |
+
return 0.0
|
| 209 |
+
|
| 210 |
+
# ── Standard IPC/BNSS statutory scoring ──────────────────────────────
|
| 211 |
# Compute ground-truth eligibility for cases with known custody duration
|
| 212 |
half_sent_months = (max_sent * 12) / 2
|
| 213 |
truly_eligible = (custody_mo >= half_sent_months) and not special_laws
|
|
|
|
| 501 |
return round(max(0.0, min(1.0, base - consistency_deduction)), 4)
|
| 502 |
|
| 503 |
|
| 504 |
+
# ---------------------------------------------------------------------------
|
| 505 |
+
# 7. Think-block reasoning gate (B6)
|
| 506 |
+
# ---------------------------------------------------------------------------
|
| 507 |
+
|
| 508 |
+
def compute_think_factor(completion: str, current_stage: int) -> float:
|
| 509 |
+
"""
|
| 510 |
+
Gate outcome credit on reasoning quality.
|
| 511 |
+
Stage 1: soft floor of 0.3 minimum (model still learning format).
|
| 512 |
+
Stage 2+: hard gate — no reasoning = no outcome credit.
|
| 513 |
+
Threshold: 120 words for full credit.
|
| 514 |
+
"""
|
| 515 |
+
if not completion:
|
| 516 |
+
return 0.3 if current_stage == 1 else 0.0
|
| 517 |
+
|
| 518 |
+
think_match = re.search(r'<think>(.*?)</think>', completion, re.DOTALL)
|
| 519 |
+
think_text = think_match.group(1).strip() if think_match else ""
|
| 520 |
+
think_len = len(think_text.split())
|
| 521 |
+
raw_factor = min(1.0, think_len / 120.0)
|
| 522 |
+
|
| 523 |
+
if current_stage == 1:
|
| 524 |
+
# Soft floor: minimum 0.3 credit even with no think block
|
| 525 |
+
# Ensures GRPO has non-zero gradient signal in early training
|
| 526 |
+
return 0.3 + 0.7 * raw_factor
|
| 527 |
+
else:
|
| 528 |
+
# Hard gate: must reason to earn outcome credit
|
| 529 |
+
return raw_factor
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# ---------------------------------------------------------------------------
|
| 533 |
+
# 8. Format compliance (B8)
|
| 534 |
+
# ---------------------------------------------------------------------------
|
| 535 |
+
|
| 536 |
+
def reward_format(completion: str) -> float:
|
| 537 |
+
"""
|
| 538 |
+
Score structural compliance of the bail memo.
|
| 539 |
+
Checks for required XML tags matching the system prompt and valid outcome.
|
| 540 |
+
Returns 0.0–1.0 (fraction of required elements present).
|
| 541 |
+
"""
|
| 542 |
+
if not completion:
|
| 543 |
+
return 0.0
|
| 544 |
+
|
| 545 |
+
# Tags must match exactly what SYSTEM_PROMPT instructs the model to produce
|
| 546 |
+
required_tags = [
|
| 547 |
+
r'<think>',
|
| 548 |
+
r'<memo>',
|
| 549 |
+
r'<flight_risk>',
|
| 550 |
+
r'<statutory_eligible>',
|
| 551 |
+
r'<recommended_outcome>',
|
| 552 |
+
r'<statutory_computation>',
|
| 553 |
+
]
|
| 554 |
+
valid_outcomes = [
|
| 555 |
+
'bail granted', 'bail denied',
|
| 556 |
+
'conditional bail', 'default bail',
|
| 557 |
+
]
|
| 558 |
+
|
| 559 |
+
checks = [
|
| 560 |
+
bool(re.search(tag, completion, re.IGNORECASE))
|
| 561 |
+
for tag in required_tags
|
| 562 |
+
]
|
| 563 |
+
checks.append(
|
| 564 |
+
any(outcome in completion.lower() for outcome in valid_outcomes)
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
return sum(checks) / len(checks)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
# ---------------------------------------------------------------------------
|
| 571 |
# Master reward function
|
| 572 |
# ---------------------------------------------------------------------------
|
|
|
|
| 584 |
agent_flight_risk_justification: str = "",
|
| 585 |
agent_grounds_for: Optional[List[str]] = None,
|
| 586 |
agent_grounds_against: Optional[List[str]] = None,
|
| 587 |
+
completion_text: Optional[str] = None,
|
| 588 |
+
current_stage: int = 1,
|
| 589 |
) -> Dict[str, float]:
|
| 590 |
"""
|
| 591 |
Computes the full reward for a submitted bail assessment memo.
|
| 592 |
|
| 593 |
+
Formula (B6/B8 update):
|
| 594 |
+
R = 0.4*outcome_gated (gated by think_factor)
|
| 595 |
+ 0.2*flight_risk_accuracy
|
| 596 |
+ 0.2*statutory_accuracy
|
| 597 |
+ 0.2*condition_appropriateness
|
| 598 |
+
+ 0.1*reasoning_quality
|
| 599 |
+
+ 0.05*efficiency_bonus
|
| 600 |
+
+ 0.05*format_score
|
| 601 |
+
+ process_bonus
|
| 602 |
- 0.3*bias_penalty
|
| 603 |
|
| 604 |
+
Core components: 0.4+0.2+0.2+0.2 = 1.0
|
| 605 |
+
Bonuses: rq(0.1) + eff(0.05) + fmt(0.05) + process(0.05)
|
| 606 |
+
Penalty: -0.3*bias
|
| 607 |
+
|
| 608 |
Returns a dict with all component scores + total_reward.
|
|
|
|
| 609 |
"""
|
| 610 |
gt = episode["ground_truth"]
|
| 611 |
|
|
|
|
| 625 |
episode = episode,
|
| 626 |
)
|
| 627 |
|
| 628 |
+
# B6: Gate outcome credit on reasoning quality (think block)
|
| 629 |
+
# In server path, completion_text may be None (structured memo submission)
|
| 630 |
+
# — default to think_factor=1.0 (no gating; env already enforces min tools).
|
| 631 |
+
if completion_text:
|
| 632 |
+
think_factor = compute_think_factor(completion_text, current_stage)
|
| 633 |
+
else:
|
| 634 |
+
think_factor = 1.0
|
| 635 |
+
om_gated = om * think_factor
|
| 636 |
+
|
| 637 |
+
# B8: Format compliance score
|
| 638 |
+
fmt = reward_format(completion_text) if completion_text else 0.5
|
| 639 |
+
|
| 640 |
# Efficiency bonus: reward finishing faster when the answer is correct.
|
| 641 |
# Only fires on directionally-correct outcomes (om >= 0.8) to prevent
|
| 642 |
# rewarding efficient-but-wrong agents.
|
|
|
|
| 648 |
# Process reward: +0.05 if agent actually used the statutory tool.
|
| 649 |
process_bonus = 0.05 if statutory_tool_used else 0.0
|
| 650 |
|
| 651 |
+
# Reward formula:
|
| 652 |
+
# Core (sum=1.0): 0.4*outcome_gated + 0.2*flight + 0.2*statutory + 0.2*conditions
|
| 653 |
+
# Bonuses: 0.1*reasoning_quality + 0.05*efficiency + 0.05*format
|
| 654 |
+
# Process: +0.05 if statutory tool used
|
| 655 |
+
# Penalty: -0.3*bias
|
| 656 |
lam = 0.3
|
| 657 |
+
total = (0.4*om_gated + 0.2*fr + 0.2*sa + 0.2*ca
|
| 658 |
+
+ 0.1*rq + 0.05*efficiency + 0.05*fmt
|
| 659 |
+
+ process_bonus - lam*bias)
|
| 660 |
|
| 661 |
return {
|
| 662 |
"outcome_match": round(om, 4),
|
| 663 |
+
"outcome_match_gated": round(om_gated, 4),
|
| 664 |
+
"think_factor": round(think_factor, 4),
|
| 665 |
"flight_risk_accuracy": round(fr, 4),
|
| 666 |
"statutory_accuracy": round(sa, 4),
|
| 667 |
"condition_appropriateness": round(ca, 4),
|
| 668 |
"reasoning_quality": round(rq, 4),
|
| 669 |
+
"format_score": round(fmt, 4),
|
| 670 |
"efficiency_bonus": round(efficiency, 4),
|
| 671 |
"process_bonus": round(process_bonus,4),
|
| 672 |
"bias_penalty": round(bias, 4),
|
server/undertrial_environment.py
CHANGED
|
@@ -8,7 +8,7 @@ import uuid
|
|
| 8 |
from typing import Any, Dict, List, Optional
|
| 9 |
|
| 10 |
from .dataset import BailDataset
|
| 11 |
-
from .reward import compute_reward
|
| 12 |
from .schema_drift import maybe_apply_drift
|
| 13 |
|
| 14 |
try:
|
|
@@ -277,6 +277,21 @@ class UndertriAIEnvironment(Environment):
|
|
| 277 |
return "No directly applicable precedents found in database."
|
| 278 |
|
| 279 |
elif isinstance(action, ComputeStatutoryEligibilityAction):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
half_months = (action.max_sentence_years * 12) / 2
|
| 281 |
eligible = action.custody_months >= half_months and not action.special_law_applicable
|
| 282 |
pct = round((action.custody_months / (action.max_sentence_years * 12)) * 100, 1) if action.max_sentence_years else 0
|
|
|
|
| 8 |
from typing import Any, Dict, List, Optional
|
| 9 |
|
| 10 |
from .dataset import BailDataset
|
| 11 |
+
from .reward import compute_reward, _is_ndps_case
|
| 12 |
from .schema_drift import maybe_apply_drift
|
| 13 |
|
| 14 |
try:
|
|
|
|
| 277 |
return "No directly applicable precedents found in database."
|
| 278 |
|
| 279 |
elif isinstance(action, ComputeStatutoryEligibilityAction):
|
| 280 |
+
# B9: NDPS cases get Section 37 response instead of threshold arithmetic
|
| 281 |
+
if _is_ndps_case(self._episode):
|
| 282 |
+
return (
|
| 283 |
+
f"Statutory Eligibility Analysis:\n"
|
| 284 |
+
f" Sections: {', '.join(action.sections_invoked)}\n"
|
| 285 |
+
f" Special Law: NDPS Act applies\n"
|
| 286 |
+
f" Section: Section 37 NDPS Act\n"
|
| 287 |
+
f" Message: NDPS Section 37 applies. Standard custody threshold not applicable. "
|
| 288 |
+
f"Bail requires twin conditions under Section 37(1)(b): "
|
| 289 |
+
f"(i) reasonable grounds to believe accused is not guilty, "
|
| 290 |
+
f"(ii) no reasonable opportunity to commit offence if released. "
|
| 291 |
+
f"These are matters for judicial discretion, not statutory calculation.\n"
|
| 292 |
+
f" → ELIGIBLE FOR DEFAULT BAIL: NOT APPLICABLE (NDPS twin conditions govern)"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
half_months = (action.max_sentence_years * 12) / 2
|
| 296 |
eligible = action.custody_months >= half_months and not action.special_law_applicable
|
| 297 |
pct = round((action.custody_months / (action.max_sentence_years * 12)) * 100, 1) if action.max_sentence_years else 0
|
training/train_grpo.py
CHANGED
|
@@ -52,12 +52,41 @@ try:
|
|
| 52 |
compute_condition_score,
|
| 53 |
compute_bias_penalty as _server_bias,
|
| 54 |
compute_reasoning_quality,
|
|
|
|
|
|
|
|
|
|
| 55 |
)
|
| 56 |
_USE_SERVER_REWARDS = True
|
| 57 |
print("[reward] Using authoritative server/reward.py functions.")
|
| 58 |
except ImportError:
|
| 59 |
_USE_SERVER_REWARDS = False
|
| 60 |
print("[reward] server/reward.py not found — using local fallback functions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
from datasets import Dataset
|
| 62 |
|
| 63 |
# ============================================================
|
|
@@ -188,16 +217,39 @@ def parse_model_output(output: str) -> Dict[str, Any]:
|
|
| 188 |
|
| 189 |
|
| 190 |
def reward_format(completions: List[str], **kwargs) -> List[float]:
|
| 191 |
-
"""Reward well-formed XML output structure."""
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
|
| 203 |
def reward_outcome_match(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
|
|
@@ -234,7 +286,12 @@ def reward_flight_risk(completions: List[str], episode_batch: List[Dict], **kwar
|
|
| 234 |
|
| 235 |
|
| 236 |
def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
|
| 237 |
-
"""20% weight: correct statutory eligibility computation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
scores = []
|
| 239 |
for comp, ep in zip(completions, episode_batch):
|
| 240 |
parsed = parse_model_output(comp)
|
|
@@ -242,19 +299,61 @@ def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs
|
|
| 242 |
sections = ep.get("ipc_sections", [])
|
| 243 |
max_sent = ep.get("max_sentence_years", 5.0)
|
| 244 |
custody = ep.get("custody_months", 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
score = 0.0
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
scores.append(min(1.0, score))
|
| 259 |
return scores
|
| 260 |
|
|
@@ -309,14 +408,24 @@ def reward_no_bias(completions: List[str], episode_batch: List[Dict], **kwargs)
|
|
| 309 |
def combined_reward(
|
| 310 |
completions: List[str],
|
| 311 |
episode_batch: List[Dict],
|
|
|
|
| 312 |
**kwargs
|
| 313 |
) -> List[float]:
|
| 314 |
"""
|
| 315 |
Master reward combining all components.
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
Uses server/reward.py functions when available (Fix 1).
|
| 319 |
-
|
|
|
|
| 320 |
"""
|
| 321 |
rewards = []
|
| 322 |
|
|
@@ -359,12 +468,22 @@ def combined_reward(
|
|
| 359 |
b = reward_no_bias([comp], [ep])[0]
|
| 360 |
rq = 0.5 # Neutral when server functions unavailable
|
| 361 |
|
| 362 |
-
#
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
eff = 0.0
|
| 366 |
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 369 |
return rewards
|
| 370 |
|
|
@@ -593,7 +712,7 @@ def train(
|
|
| 593 |
# Reward wrapper that unpacks the stored JSON episode
|
| 594 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 595 |
ep_objs = [json.loads(e) for e in episode]
|
| 596 |
-
return combined_reward(completions, ep_objs)
|
| 597 |
|
| 598 |
# ── GRPO Config ──────────────────────────────────────────
|
| 599 |
from trl import GRPOConfig, GRPOTrainer # type: ignore
|
|
@@ -711,7 +830,7 @@ def evaluate_baseline(episodes_dir: str, n_samples: int = 20):
|
|
| 711 |
out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
|
| 712 |
|
| 713 |
completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
|
| 714 |
-
r = combined_reward([completion], [ep])[0]
|
| 715 |
rewards.append(r)
|
| 716 |
print(f" Case {ep['case_id']}: reward={r:.3f} | GT={ep['ground_truth']['outcome']}")
|
| 717 |
|
|
@@ -770,7 +889,7 @@ def evaluate_on_stage(
|
|
| 770 |
out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
|
| 771 |
|
| 772 |
completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
|
| 773 |
-
r = combined_reward([completion], [ep])[0]
|
| 774 |
rewards.append(r)
|
| 775 |
results.append({"episode": ep, "completion": completion, "reward": r})
|
| 776 |
|
|
@@ -916,10 +1035,7 @@ def train_curriculum(
|
|
| 916 |
|
| 917 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 918 |
ep_objs = [json.loads(e) for e in episode]
|
| 919 |
-
|
| 920 |
-
# This keeps efficiency contribution honest rather than silently 0.0
|
| 921 |
-
step_counts = [1] * len(completions)
|
| 922 |
-
return combined_reward(completions, ep_objs, step_counts=step_counts)
|
| 923 |
|
| 924 |
stage_output = f"{output_dir}/stage_{stage}"
|
| 925 |
config = GRPOConfig(
|
|
@@ -1019,7 +1135,248 @@ def train_curriculum(
|
|
| 1019 |
|
| 1020 |
|
| 1021 |
# ============================================================
|
| 1022 |
-
# CELL 9 —
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1023 |
# ============================================================
|
| 1024 |
|
| 1025 |
if __name__ == "__main__":
|
|
@@ -1035,6 +1392,10 @@ if __name__ == "__main__":
|
|
| 1035 |
help="Run evaluation after training to measure improvement")
|
| 1036 |
parser.add_argument("--curriculum", action="store_true",
|
| 1037 |
help="Run self-improving curriculum training (all 4 stages)")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
|
| 1039 |
args = parser.parse_args()
|
| 1040 |
|
|
@@ -1047,6 +1408,15 @@ if __name__ == "__main__":
|
|
| 1047 |
max_steps_per_stage=args.steps,
|
| 1048 |
batch_size=args.batch_size,
|
| 1049 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1050 |
else:
|
| 1051 |
train(
|
| 1052 |
episodes_dir = args.episodes_dir,
|
|
|
|
| 52 |
compute_condition_score,
|
| 53 |
compute_bias_penalty as _server_bias,
|
| 54 |
compute_reasoning_quality,
|
| 55 |
+
compute_think_factor,
|
| 56 |
+
reward_format as server_reward_format,
|
| 57 |
+
_is_ndps_case,
|
| 58 |
)
|
| 59 |
_USE_SERVER_REWARDS = True
|
| 60 |
print("[reward] Using authoritative server/reward.py functions.")
|
| 61 |
except ImportError:
|
| 62 |
_USE_SERVER_REWARDS = False
|
| 63 |
print("[reward] server/reward.py not found — using local fallback functions.")
|
| 64 |
+
|
| 65 |
+
# Local fallback definition of _is_ndps_case (mirrors server/reward.py)
|
| 66 |
+
def _is_ndps_case(episode: dict) -> bool:
|
| 67 |
+
sections = " ".join(str(s) for s in episode.get("ipc_sections", [])).lower()
|
| 68 |
+
crime = str(episode.get("crime_type", "")).lower()
|
| 69 |
+
narcotics_indicators = [
|
| 70 |
+
"ndps", "narcotic", "drug", "psychotropic",
|
| 71 |
+
"20(b)", "22(b)", "27a", "section 37",
|
| 72 |
+
]
|
| 73 |
+
return any(ind in sections or ind in crime for ind in narcotics_indicators)
|
| 74 |
+
|
| 75 |
+
# Local fallback definition of compute_think_factor (mirrors server/reward.py)
|
| 76 |
+
def compute_think_factor(completion: str, current_stage: int) -> float:
|
| 77 |
+
if not completion:
|
| 78 |
+
return 0.3 if current_stage == 1 else 0.0
|
| 79 |
+
think_match = re.search(r'<think>(.*?)</think>', completion, re.DOTALL)
|
| 80 |
+
think_text = think_match.group(1).strip() if think_match else ""
|
| 81 |
+
think_len = len(think_text.split())
|
| 82 |
+
raw_factor = min(1.0, think_len / 120.0)
|
| 83 |
+
if current_stage == 1:
|
| 84 |
+
return 0.3 + 0.7 * raw_factor
|
| 85 |
+
else:
|
| 86 |
+
return raw_factor
|
| 87 |
+
|
| 88 |
+
# Local fallback server_reward_format
|
| 89 |
+
server_reward_format = None # Will use local reward_format below
|
| 90 |
from datasets import Dataset
|
| 91 |
|
| 92 |
# ============================================================
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
def reward_format(completions: List[str], **kwargs) -> List[float]:
|
| 220 |
+
"""Reward well-formed XML output structure (batch API for GRPO compatibility)."""
|
| 221 |
+
return [reward_format_single(c) for c in completions]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def reward_format_single(completion: str) -> float:
|
| 225 |
+
"""
|
| 226 |
+
Score structural compliance of the bail memo.
|
| 227 |
+
Checks for required XML tags matching the system prompt and valid outcome.
|
| 228 |
+
Returns 0.0–1.0 (fraction of required elements present).
|
| 229 |
+
"""
|
| 230 |
+
if not completion:
|
| 231 |
+
return 0.0
|
| 232 |
+
# Tags match exactly what SYSTEM_PROMPT instructs the model to produce
|
| 233 |
+
required_tags = [
|
| 234 |
+
r'<think>',
|
| 235 |
+
r'<memo>',
|
| 236 |
+
r'<flight_risk>',
|
| 237 |
+
r'<statutory_eligible>',
|
| 238 |
+
r'<recommended_outcome>',
|
| 239 |
+
r'<statutory_computation>',
|
| 240 |
+
]
|
| 241 |
+
valid_outcomes = [
|
| 242 |
+
'bail granted', 'bail denied',
|
| 243 |
+
'conditional bail', 'default bail',
|
| 244 |
+
]
|
| 245 |
+
checks = [
|
| 246 |
+
bool(re.search(tag, completion, re.IGNORECASE))
|
| 247 |
+
for tag in required_tags
|
| 248 |
+
]
|
| 249 |
+
checks.append(
|
| 250 |
+
any(outcome in completion.lower() for outcome in valid_outcomes)
|
| 251 |
+
)
|
| 252 |
+
return sum(checks) / len(checks)
|
| 253 |
|
| 254 |
|
| 255 |
def reward_outcome_match(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
|
|
|
|
| 286 |
|
| 287 |
|
| 288 |
def reward_statutory(completions: List[str], episode_batch: List[Dict], **kwargs) -> List[float]:
|
| 289 |
+
"""20% weight: correct statutory eligibility computation.
|
| 290 |
+
|
| 291 |
+
B3: Direction-gated computation bonus — wrong direction gets 0.10 not 0.30.
|
| 292 |
+
B9: NDPS cases use crime_type detection and reward Section 37 recognition.
|
| 293 |
+
"""
|
| 294 |
+
TIME_WORDS = ["month", "year", "sentence", "custody", "half", "served", "threshold"]
|
| 295 |
scores = []
|
| 296 |
for comp, ep in zip(completions, episode_batch):
|
| 297 |
parsed = parse_model_output(comp)
|
|
|
|
| 299 |
sections = ep.get("ipc_sections", [])
|
| 300 |
max_sent = ep.get("max_sentence_years", 5.0)
|
| 301 |
custody = ep.get("custody_months", 0.0)
|
| 302 |
+
special_laws = ep.get("special_laws", "").strip()
|
| 303 |
+
gt_outcome = ep.get("ground_truth", {}).get("outcome", "")
|
| 304 |
+
agent_eligible = parsed["statutory_eligible"]
|
| 305 |
+
|
| 306 |
+
# B9: NDPS-specific scoring
|
| 307 |
+
if _is_ndps_case(ep):
|
| 308 |
+
gt_granted = "grant" in gt_outcome.lower()
|
| 309 |
+
direction_correct = (agent_eligible == gt_granted)
|
| 310 |
+
ndps_recognized = any(
|
| 311 |
+
t in comp_text for t in ["section 37", "twin condition", "ndps", "37(1)(b)"]
|
| 312 |
+
)
|
| 313 |
+
if ndps_recognized and direction_correct:
|
| 314 |
+
scores.append(1.0)
|
| 315 |
+
elif direction_correct:
|
| 316 |
+
scores.append(0.5)
|
| 317 |
+
else:
|
| 318 |
+
scores.append(0.0)
|
| 319 |
+
continue
|
| 320 |
+
|
| 321 |
+
# Infer special law from crime_type
|
| 322 |
+
CRIME_TYPE_SPECIAL_LAWS = [
|
| 323 |
+
"narcotics", "ndps", "pocso", "uapa", "pmla",
|
| 324 |
+
"terrorism", "organised crime", "money laundering",
|
| 325 |
+
]
|
| 326 |
+
crime_type_lower = ep.get("crime_type", "").lower()
|
| 327 |
+
if not special_laws and any(t in crime_type_lower for t in CRIME_TYPE_SPECIAL_LAWS):
|
| 328 |
+
special_laws = "INFERRED"
|
| 329 |
+
|
| 330 |
+
# Standard IPC/BNSS threshold computation
|
| 331 |
+
half_sent_months = (max_sent * 12) / 2
|
| 332 |
+
truly_eligible = (custody >= half_sent_months) and not special_laws
|
| 333 |
|
| 334 |
score = 0.0
|
| 335 |
+
|
| 336 |
+
# 40%: eligibility direction
|
| 337 |
+
direction_correct = (agent_eligible == truly_eligible)
|
| 338 |
+
if direction_correct:
|
| 339 |
+
score += 0.4
|
| 340 |
+
elif (agent_eligible and "grant" in gt_outcome.lower()) or \
|
| 341 |
+
(not agent_eligible and "deni" in gt_outcome.lower()):
|
| 342 |
+
score += 0.2
|
| 343 |
+
|
| 344 |
+
# 30%: cited relevant sections
|
| 345 |
+
if sections:
|
| 346 |
+
hits = sum(1 for sec in sections if sec.strip().lower() in comp_text or sec.strip() in comp)
|
| 347 |
+
score += 0.3 * min(1.0, hits / len(sections))
|
| 348 |
+
|
| 349 |
+
# 30%: numeric computation (B3: direction-gated)
|
| 350 |
+
has_numbers = bool(re.search(r'\d+', comp_text))
|
| 351 |
+
has_time_ref = any(w in comp_text for w in TIME_WORDS)
|
| 352 |
+
if has_numbers and has_time_ref:
|
| 353 |
+
score += 0.3 if direction_correct else 0.10
|
| 354 |
+
elif has_numbers or has_time_ref:
|
| 355 |
+
score += 0.15 if direction_correct else 0.05
|
| 356 |
+
|
| 357 |
scores.append(min(1.0, score))
|
| 358 |
return scores
|
| 359 |
|
|
|
|
| 408 |
def combined_reward(
|
| 409 |
completions: List[str],
|
| 410 |
episode_batch: List[Dict],
|
| 411 |
+
current_stage: int = 1,
|
| 412 |
**kwargs
|
| 413 |
) -> List[float]:
|
| 414 |
"""
|
| 415 |
Master reward combining all components.
|
| 416 |
+
|
| 417 |
+
Formula (B6/B8 update):
|
| 418 |
+
R = 0.4*outcome_gated + 0.2*flight_risk + 0.2*statutory + 0.2*condition
|
| 419 |
+
+ 0.1*reasoning_quality + 0.05*format
|
| 420 |
+
- 0.3*bias
|
| 421 |
+
|
| 422 |
+
Core (sum=1.0): 0.4*om_gated + 0.2*fr + 0.2*s + 0.2*ca
|
| 423 |
+
Bonuses: 0.1*rq + 0.05*fmt
|
| 424 |
+
Penalty: -0.3*bias
|
| 425 |
|
| 426 |
Uses server/reward.py functions when available (Fix 1).
|
| 427 |
+
B6: Outcome gated by think_factor (stage-aware).
|
| 428 |
+
B8: Format compliance score included with 0.05 weight.
|
| 429 |
"""
|
| 430 |
rewards = []
|
| 431 |
|
|
|
|
| 468 |
b = reward_no_bias([comp], [ep])[0]
|
| 469 |
rq = 0.5 # Neutral when server functions unavailable
|
| 470 |
|
| 471 |
+
# B6: Gate outcome credit on reasoning quality (think block)
|
| 472 |
+
think_factor = compute_think_factor(comp, current_stage)
|
| 473 |
+
om_gated = o * think_factor
|
|
|
|
| 474 |
|
| 475 |
+
# B8: Format compliance score
|
| 476 |
+
if _USE_SERVER_REWARDS and server_reward_format is not None:
|
| 477 |
+
fmt = server_reward_format(comp)
|
| 478 |
+
else:
|
| 479 |
+
fmt = reward_format_single(comp)
|
| 480 |
+
|
| 481 |
+
# Reward formula:
|
| 482 |
+
# Core (sum=1.0): 0.4*outcome_gated + 0.2*flight + 0.2*statutory + 0.2*conditions
|
| 483 |
+
# Bonuses: 0.1*reasoning_quality + 0.05*format
|
| 484 |
+
# Penalty: -0.3*bias
|
| 485 |
+
total = (0.4*om_gated + 0.2*fr + 0.2*s + 0.2*ca
|
| 486 |
+
+ 0.1*rq + 0.05*fmt - 0.3*b)
|
| 487 |
rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
|
| 488 |
return rewards
|
| 489 |
|
|
|
|
| 712 |
# Reward wrapper that unpacks the stored JSON episode
|
| 713 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 714 |
ep_objs = [json.loads(e) for e in episode]
|
| 715 |
+
return combined_reward(completions, ep_objs, current_stage=stage)
|
| 716 |
|
| 717 |
# ── GRPO Config ──────────────────────────────────────────
|
| 718 |
from trl import GRPOConfig, GRPOTrainer # type: ignore
|
|
|
|
| 830 |
out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
|
| 831 |
|
| 832 |
completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
|
| 833 |
+
r = combined_reward([completion], [ep], current_stage=1)[0]
|
| 834 |
rewards.append(r)
|
| 835 |
print(f" Case {ep['case_id']}: reward={r:.3f} | GT={ep['ground_truth']['outcome']}")
|
| 836 |
|
|
|
|
| 889 |
out = model.generate(inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
|
| 890 |
|
| 891 |
completion = tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True)
|
| 892 |
+
r = combined_reward([completion], [ep], current_stage=stage)[0]
|
| 893 |
rewards.append(r)
|
| 894 |
results.append({"episode": ep, "completion": completion, "reward": r})
|
| 895 |
|
|
|
|
| 1035 |
|
| 1036 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 1037 |
ep_objs = [json.loads(e) for e in episode]
|
| 1038 |
+
return combined_reward(completions, ep_objs, current_stage=stage)
|
|
|
|
|
|
|
|
|
|
| 1039 |
|
| 1040 |
stage_output = f"{output_dir}/stage_{stage}"
|
| 1041 |
config = GRPOConfig(
|
|
|
|
| 1135 |
|
| 1136 |
|
| 1137 |
# ============================================================
|
| 1138 |
+
# CELL 9 — Adaptive Training (Theme 4: Self-Improvement)
|
| 1139 |
+
# ============================================================
|
| 1140 |
+
|
| 1141 |
+
def train_adaptive(
|
| 1142 |
+
episodes_dir: str = "./data/episodes",
|
| 1143 |
+
output_dir: str = "./output/undertrial_adaptive",
|
| 1144 |
+
steps_per_assessment: int = 50,
|
| 1145 |
+
max_total_steps: int = 2000,
|
| 1146 |
+
batch_size: int = 4,
|
| 1147 |
+
grad_accum: int = 4,
|
| 1148 |
+
lr: float = 5e-6,
|
| 1149 |
+
base_url: str = "http://localhost:8000",
|
| 1150 |
+
):
|
| 1151 |
+
"""
|
| 1152 |
+
Self-directed curriculum training (Theme 4).
|
| 1153 |
+
|
| 1154 |
+
Uses the /profile endpoint to check stage readiness every
|
| 1155 |
+
steps_per_assessment steps and promotes automatically.
|
| 1156 |
+
|
| 1157 |
+
This function communicates with the server via HTTP — it does NOT
|
| 1158 |
+
import server internals. OpenEnv client/server separation is preserved.
|
| 1159 |
+
|
| 1160 |
+
Training loop:
|
| 1161 |
+
1. Start at stage 1
|
| 1162 |
+
2. Train for steps_per_assessment steps
|
| 1163 |
+
3. Query /profile for suggested_stage
|
| 1164 |
+
4. If suggested_stage > current_stage, promote
|
| 1165 |
+
5. Repeat until max_total_steps or stage 4 mastered
|
| 1166 |
+
"""
|
| 1167 |
+
print("=" * 60)
|
| 1168 |
+
print(" UndertriAI — Adaptive Self-Improvement Training")
|
| 1169 |
+
print(f" Assessment every {steps_per_assessment} steps | Max {max_total_steps} steps")
|
| 1170 |
+
print(f" Server: {base_url}")
|
| 1171 |
+
print("=" * 60)
|
| 1172 |
+
|
| 1173 |
+
from unsloth import FastLanguageModel # type: ignore
|
| 1174 |
+
from trl import GRPOConfig, GRPOTrainer # type: ignore
|
| 1175 |
+
|
| 1176 |
+
# Load model once
|
| 1177 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 1178 |
+
model_name="unsloth/Qwen2.5-7B-Instruct",
|
| 1179 |
+
max_seq_length=3072,
|
| 1180 |
+
load_in_4bit=True,
|
| 1181 |
+
fast_inference=False,
|
| 1182 |
+
)
|
| 1183 |
+
model = FastLanguageModel.get_peft_model(
|
| 1184 |
+
model,
|
| 1185 |
+
r=16,
|
| 1186 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 1187 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 1188 |
+
lora_alpha=16, lora_dropout=0, bias="none",
|
| 1189 |
+
use_gradient_checkpointing="unsloth", random_state=42,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
# HTTP helper for server communication
|
| 1193 |
+
def query_profile(session_id: str) -> Optional[Dict]:
|
| 1194 |
+
"""Query the performance profile from the server via HTTP."""
|
| 1195 |
+
try:
|
| 1196 |
+
url = f"{base_url}/profile?session_id={urllib.parse.quote(session_id)}"
|
| 1197 |
+
req = urllib.request.Request(url)
|
| 1198 |
+
with urllib.request.urlopen(req, timeout=5.0) as resp:
|
| 1199 |
+
return json.loads(resp.read())
|
| 1200 |
+
except Exception as e:
|
| 1201 |
+
print(f" [adaptive] Could not reach server profile: {e}")
|
| 1202 |
+
return None
|
| 1203 |
+
|
| 1204 |
+
def notify_reset(session_id: str, stage: int) -> Optional[str]:
|
| 1205 |
+
"""Call /reset with adaptive=true on the server."""
|
| 1206 |
+
try:
|
| 1207 |
+
url = f"{base_url}/reset?session_id={urllib.parse.quote(session_id)}&stage={stage}&adaptive=true"
|
| 1208 |
+
req = urllib.request.Request(url, method="POST")
|
| 1209 |
+
with urllib.request.urlopen(req, timeout=5.0) as resp:
|
| 1210 |
+
data = json.loads(resp.read())
|
| 1211 |
+
return data.get("session_id", session_id)
|
| 1212 |
+
except Exception:
|
| 1213 |
+
return None
|
| 1214 |
+
|
| 1215 |
+
current_stage = 1
|
| 1216 |
+
total_steps = 0
|
| 1217 |
+
session_id = f"adaptive_{uuid.uuid4().hex[:8]}" if 'uuid' in dir() else "adaptive_training"
|
| 1218 |
+
|
| 1219 |
+
# Try to initialise session on server
|
| 1220 |
+
import uuid as _uuid_mod
|
| 1221 |
+
session_id = f"adaptive_{_uuid_mod.uuid4().hex[:8]}"
|
| 1222 |
+
notify_reset(session_id, current_stage)
|
| 1223 |
+
|
| 1224 |
+
# Tracking
|
| 1225 |
+
stage_promotion_steps = []
|
| 1226 |
+
reward_curve = []
|
| 1227 |
+
stage_rewards = {1: [], 2: [], 3: [], 4: []}
|
| 1228 |
+
|
| 1229 |
+
while total_steps < max_total_steps:
|
| 1230 |
+
print(f"\n{'━' * 60}")
|
| 1231 |
+
print(f" ADAPTIVE BLOCK: Steps {total_steps}–{total_steps + steps_per_assessment}")
|
| 1232 |
+
print(f" Current Stage: {current_stage} — {STAGE_NAMES.get(current_stage, '?')}")
|
| 1233 |
+
print(f"{'━' * 60}")
|
| 1234 |
+
|
| 1235 |
+
# Load episodes for current stage
|
| 1236 |
+
try:
|
| 1237 |
+
episodes = load_episodes(episodes_dir, stage=current_stage, split="train")
|
| 1238 |
+
except FileNotFoundError:
|
| 1239 |
+
print(f" No episodes for stage {current_stage} — breaking")
|
| 1240 |
+
break
|
| 1241 |
+
|
| 1242 |
+
if not episodes:
|
| 1243 |
+
print(f" Empty episode list for stage {current_stage} — breaking")
|
| 1244 |
+
break
|
| 1245 |
+
|
| 1246 |
+
# Build dataset
|
| 1247 |
+
dataset = build_hf_dataset(episodes, tokenizer)
|
| 1248 |
+
stage_for_closure = current_stage # Capture for closure
|
| 1249 |
+
|
| 1250 |
+
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 1251 |
+
ep_objs = [json.loads(e) for e in episode]
|
| 1252 |
+
return combined_reward(completions, ep_objs, current_stage=stage_for_closure)
|
| 1253 |
+
|
| 1254 |
+
block_output = f"{output_dir}/block_{total_steps}"
|
| 1255 |
+
config = GRPOConfig(
|
| 1256 |
+
output_dir=block_output,
|
| 1257 |
+
learning_rate=lr,
|
| 1258 |
+
per_device_train_batch_size=batch_size,
|
| 1259 |
+
gradient_accumulation_steps=grad_accum,
|
| 1260 |
+
num_train_epochs=1,
|
| 1261 |
+
max_steps=steps_per_assessment,
|
| 1262 |
+
num_generations=6,
|
| 1263 |
+
max_completion_length=1024,
|
| 1264 |
+
temperature=0.7,
|
| 1265 |
+
beta=0.01,
|
| 1266 |
+
logging_steps=5,
|
| 1267 |
+
save_steps=steps_per_assessment,
|
| 1268 |
+
report_to="none",
|
| 1269 |
+
remove_unused_columns=False,
|
| 1270 |
+
)
|
| 1271 |
+
|
| 1272 |
+
FastLanguageModel.for_training(model)
|
| 1273 |
+
trainer = GRPOTrainer(
|
| 1274 |
+
model=model,
|
| 1275 |
+
processing_class=tokenizer,
|
| 1276 |
+
config=config,
|
| 1277 |
+
train_dataset=dataset,
|
| 1278 |
+
reward_funcs=[reward_fn],
|
| 1279 |
+
)
|
| 1280 |
+
trainer.train()
|
| 1281 |
+
total_steps += steps_per_assessment
|
| 1282 |
+
|
| 1283 |
+
# Evaluate current performance
|
| 1284 |
+
eval_reward, _ = evaluate_on_stage(
|
| 1285 |
+
model, tokenizer, episodes_dir, stage=current_stage, n_samples=15
|
| 1286 |
+
)
|
| 1287 |
+
stage_rewards[current_stage].append(eval_reward)
|
| 1288 |
+
reward_curve.append((total_steps, round(eval_reward, 4)))
|
| 1289 |
+
print(f" Stage {current_stage} eval reward: {eval_reward:.4f}")
|
| 1290 |
+
|
| 1291 |
+
# Query server for stage promotion suggestion
|
| 1292 |
+
profile_data = query_profile(session_id)
|
| 1293 |
+
suggested_stage = current_stage
|
| 1294 |
+
|
| 1295 |
+
if profile_data and "profile" in profile_data:
|
| 1296 |
+
suggested_stage = profile_data["profile"].get(
|
| 1297 |
+
"suggested_stage", current_stage
|
| 1298 |
+
)
|
| 1299 |
+
else:
|
| 1300 |
+
# Fallback: use local heuristic
|
| 1301 |
+
if eval_reward >= 0.65 and current_stage == 1:
|
| 1302 |
+
suggested_stage = 2
|
| 1303 |
+
elif eval_reward >= 0.55 and current_stage == 2:
|
| 1304 |
+
suggested_stage = 3
|
| 1305 |
+
elif eval_reward >= 0.50 and current_stage == 3:
|
| 1306 |
+
suggested_stage = 4
|
| 1307 |
+
|
| 1308 |
+
if suggested_stage > current_stage:
|
| 1309 |
+
old_stage = current_stage
|
| 1310 |
+
old_reward = eval_reward
|
| 1311 |
+
current_stage = suggested_stage
|
| 1312 |
+
stage_promotion_steps.append(
|
| 1313 |
+
(total_steps, old_stage, current_stage, round(old_reward, 4))
|
| 1314 |
+
)
|
| 1315 |
+
print(
|
| 1316 |
+
f"[SELF-IMPROVEMENT] Step {total_steps}: "
|
| 1317 |
+
f"Promoted to Stage {current_stage}. "
|
| 1318 |
+
f"Stage {old_stage} mean reward: {old_reward:.3f} → "
|
| 1319 |
+
f"Stage {current_stage} begins."
|
| 1320 |
+
)
|
| 1321 |
+
# Notify server of promotion
|
| 1322 |
+
notify_reset(session_id, current_stage)
|
| 1323 |
+
|
| 1324 |
+
# Check completion
|
| 1325 |
+
if current_stage == 4:
|
| 1326 |
+
s4_rewards = stage_rewards.get(4, [])
|
| 1327 |
+
if s4_rewards and s4_rewards[-1] >= 0.50:
|
| 1328 |
+
print(
|
| 1329 |
+
f"\n[SELF-IMPROVEMENT] Stage 4 mastered at step {total_steps}! "
|
| 1330 |
+
f"Reward: {s4_rewards[-1]:.3f}"
|
| 1331 |
+
)
|
| 1332 |
+
break
|
| 1333 |
+
|
| 1334 |
+
# Save checkpoint
|
| 1335 |
+
model.save_pretrained(block_output, save_adapters_only=True)
|
| 1336 |
+
tokenizer.save_pretrained(block_output)
|
| 1337 |
+
|
| 1338 |
+
# ── Final summary ──
|
| 1339 |
+
print(f"\n{'═' * 60}")
|
| 1340 |
+
print(" ADAPTIVE TRAINING COMPLETE")
|
| 1341 |
+
print(f"{'═' * 60}")
|
| 1342 |
+
print(f" Total steps: {total_steps}")
|
| 1343 |
+
print(f" Stage promotions: {len(stage_promotion_steps)}")
|
| 1344 |
+
for step_n, from_s, to_s, reward in stage_promotion_steps:
|
| 1345 |
+
print(f" Step {step_n}: Stage {from_s} → {to_s} (reward {reward:.3f})")
|
| 1346 |
+
print(f" Final stage: {current_stage}")
|
| 1347 |
+
|
| 1348 |
+
# Compute final reward per stage
|
| 1349 |
+
final_reward_per_stage = {}
|
| 1350 |
+
for s, rewards_list in stage_rewards.items():
|
| 1351 |
+
if rewards_list:
|
| 1352 |
+
final_reward_per_stage[str(s)] = round(rewards_list[-1], 4)
|
| 1353 |
+
|
| 1354 |
+
# Save results
|
| 1355 |
+
results = {
|
| 1356 |
+
"stage_promotion_steps": [
|
| 1357 |
+
{"step": s, "from_stage": f, "to_stage": t, "reward": r}
|
| 1358 |
+
for s, f, t, r in stage_promotion_steps
|
| 1359 |
+
],
|
| 1360 |
+
"final_reward_per_stage": final_reward_per_stage,
|
| 1361 |
+
"total_steps_completed": total_steps,
|
| 1362 |
+
"reward_curve": [{"step": s, "reward": r} for s, r in reward_curve],
|
| 1363 |
+
}
|
| 1364 |
+
results_path = Path(output_dir) / "results_adaptive.json"
|
| 1365 |
+
results_path.parent.mkdir(parents=True, exist_ok=True)
|
| 1366 |
+
results_path.write_text(json.dumps(results, indent=2))
|
| 1367 |
+
print(f"\n Results saved: {results_path}")
|
| 1368 |
+
|
| 1369 |
+
# Save final model
|
| 1370 |
+
final_dir = f"{output_dir}/final"
|
| 1371 |
+
model.save_pretrained(final_dir, save_adapters_only=True)
|
| 1372 |
+
tokenizer.save_pretrained(final_dir)
|
| 1373 |
+
print(f" Final model saved: {final_dir}")
|
| 1374 |
+
|
| 1375 |
+
return results
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
# ============================================================
|
| 1379 |
+
# CELL 10 — Entry point
|
| 1380 |
# ============================================================
|
| 1381 |
|
| 1382 |
if __name__ == "__main__":
|
|
|
|
| 1392 |
help="Run evaluation after training to measure improvement")
|
| 1393 |
parser.add_argument("--curriculum", action="store_true",
|
| 1394 |
help="Run self-improving curriculum training (all 4 stages)")
|
| 1395 |
+
parser.add_argument("--adaptive", action="store_true",
|
| 1396 |
+
help="Run adaptive self-improvement training (Theme 4)")
|
| 1397 |
+
parser.add_argument("--env_url", default="http://localhost:8000",
|
| 1398 |
+
help="Server URL for adaptive training")
|
| 1399 |
|
| 1400 |
args = parser.parse_args()
|
| 1401 |
|
|
|
|
| 1408 |
max_steps_per_stage=args.steps,
|
| 1409 |
batch_size=args.batch_size,
|
| 1410 |
)
|
| 1411 |
+
elif args.adaptive:
|
| 1412 |
+
train_adaptive(
|
| 1413 |
+
episodes_dir=args.episodes_dir,
|
| 1414 |
+
output_dir=args.output,
|
| 1415 |
+
steps_per_assessment=args.steps,
|
| 1416 |
+
max_total_steps=2000,
|
| 1417 |
+
batch_size=args.batch_size,
|
| 1418 |
+
base_url=args.env_url,
|
| 1419 |
+
)
|
| 1420 |
else:
|
| 1421 |
train(
|
| 1422 |
episodes_dir = args.episodes_dir,
|