Draken1606 committed
Commit 6218d9a · 1 Parent(s): d76d092

Fix 5 audit gaps: conditional bail, action history, efficiency reward, train/val split, env API routing
models.py CHANGED
@@ -121,8 +121,8 @@ class SubmitMemoAction(Action):
     )
 
     # Recommendation
-    recommended_outcome: Literal["Bail Granted", "Bail Denied"] = Field(
-        ..., description="Final recommendation"
+    recommended_outcome: Literal["Bail Granted", "Bail Denied", "Bail Conditional"] = Field(
+        ..., description="Final recommendation: Bail Granted | Bail Denied | Bail Conditional (strict conditions)"
     )
     recommended_conditions: Optional[List[str]] = Field(
         None,
@@ -186,6 +186,10 @@ class CaseObservation(Observation):
 
     # Episode state
     action_result: Optional[str] = None
+    action_history: List[str] = Field(
+        default_factory=list,
+        description="Ordered log of all tool results seen so far this episode",
+    )
     flags_raised: List[str] = Field(default_factory=list)
     precedents_retrieved: List[str] = Field(default_factory=list)
     memo_submitted: bool = False
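A minimal sketch of how the widened outcome Literal and the new action_history field are meant to be used. The import path and the full set of SubmitMemoAction fields are assumptions; only recommended_outcome and recommended_conditions are visible in this diff, and the rest mirror the submit_memo payload built later in training/train_grpo.py.

# Hedged sketch: field set assumed from the submit_memo payload in this commit.
from models import SubmitMemoAction

memo = SubmitMemoAction(
    flight_risk="Low",
    flight_risk_justification="Fixed address, steady employment, no prior absconding.",
    statutory_eligible=True,
    statutory_computation="Detention period already exceeds the statutory threshold.",
    grounds_for_bail=["Prolonged undertrial detention"],
    grounds_against_bail=["Pending forensic report"],
    recommended_outcome="Bail Conditional",  # now a valid value under the widened Literal
    recommended_conditions=["Surrender passport", "Report weekly to the local police station"],
    confidence="Medium",
)
print(memo.recommended_outcome)  # "Bail Conditional"

CaseObservation instances returned by the environment now also expose action_history, an ordered list of one summary string per prior step.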
server/reward.py CHANGED
@@ -20,17 +20,26 @@ def compute_outcome_match(agent_outcome: str, ground_truth: Dict[str, Any]) -> float:
     Checks if the agent's final recommendation matches the High Court decision.
 
     Scoring:
-        1.0 — Exact string match (e.g. "Bail Granted" == "Bail Granted")
-        0.8 — Directionally correct (agent says "granted", GT says "Bail Granted")
+        1.0 — Exact string match
+        0.9 — "Bail Conditional" vs "Bail Granted" (conditional IS bail)
+        0.8 — Directionally correct but loose string
         0.0 — Wrong direction (granted vs. denied, or vice versa)
     """
     gt = ground_truth["outcome"]
+    agent_norm = agent_outcome.strip().lower()
+    gt_norm = gt.strip().lower()
 
-    if agent_outcome.strip().lower() == gt.strip().lower():
+    if agent_norm == gt_norm:
         return 1.0
 
-    agent_granted = "grant" in agent_outcome.lower()
-    gt_granted = "grant" in gt.lower()
+    # Conditional bail counts almost as well as full bail
+    if "conditional" in agent_norm and "grant" in gt_norm:
+        return 0.9
+    if "grant" in agent_norm and "conditional" in gt_norm:
+        return 0.9
+
+    agent_granted = "grant" in agent_norm or "conditional" in agent_norm
+    gt_granted = "grant" in gt_norm or "conditional" in gt_norm
 
     return 0.8 if (agent_granted == gt_granted) else 0.0
 
@@ -334,6 +343,8 @@ def compute_reward(
     agent_computation: str,
     agent_conditions: List[str],
     episode: Dict[str, Any],
+    step_count: int = 0,
+    max_steps: int = 10,
 ) -> Dict[str, float]:
     """
     Computes the full reward for a submitted bail assessment memo.
@@ -343,10 +354,11 @@
         + 0.2*flight_risk_accuracy
         + 0.2*statutory_accuracy
         + 0.2*condition_appropriateness
+        + 0.1*efficiency_bonus (only when outcome is correct)
         - 0.3*bias_penalty
 
     Returns a dict with all component scores + total_reward.
-    Range: [-0.3, 1.0] (bias penalty can produce negative totals — this is intentional).
+    Range: [-0.3, 1.1] (efficiency can push above 1.0 slightly on perfect runs).
     """
     gt = episode["ground_truth"]
 
@@ -356,16 +368,26 @@
     ca = compute_condition_score(agent_outcome, agent_conditions, gt)
     bias = compute_bias_penalty(agent_outcome, episode)
 
+    # R4 — Efficiency bonus: reward finishing faster when the answer is correct.
+    # Only fires on directionally-correct outcomes (om >= 0.8) to prevent
+    # rewarding efficient-but-wrong agents.
+    efficiency = 0.0
+    if om >= 0.8 and max_steps > 1:
+        efficiency = round((1.0 - (step_count - 1) / (max_steps - 1)), 4)
+        efficiency = max(0.0, min(1.0, efficiency))
+
     lam = 0.3
-    total = 0.4*om + 0.2*fr + 0.2*sa + 0.2*ca - lam*bias
+    total = 0.4*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*efficiency - lam*bias
 
     return {
-        "outcome_match": round(om, 4),
-        "flight_risk_accuracy": round(fr, 4),
-        "statutory_accuracy": round(sa, 4),
-        "condition_appropriateness": round(ca, 4),
-        "bias_penalty": round(bias, 4),
-        "total_reward": round(total, 4),
+        "outcome_match": round(om, 4),
+        "flight_risk_accuracy": round(fr, 4),
+        "statutory_accuracy": round(sa, 4),
+        "condition_appropriateness": round(ca, 4),
+        "efficiency_bonus": round(efficiency, 4),
+        "bias_penalty": round(bias, 4),
+        "total_reward": round(total, 4),
         "ground_truth_outcome": gt["outcome"],
         "agent_outcome": agent_outcome,
+        "steps_used": step_count,
     }
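A worked example of the new outcome scoring and the efficiency term, assuming only what this diff shows: compute_outcome_match reads ground_truth["outcome"], and compute_reward now takes step_count/max_steps. The values in the comments follow directly from the code above.

from server.reward import compute_outcome_match

gt = {"outcome": "Bail Granted"}

compute_outcome_match("Bail Granted", gt)           # 1.0 (exact match after normalisation)
compute_outcome_match("Bail Conditional", gt)       # 0.9 (conditional counts almost as full bail)
compute_outcome_match("granted with sureties", gt)  # 0.8 (directionally correct, loose string)
compute_outcome_match("Bail Denied", gt)            # 0.0 (wrong direction)

# Efficiency term in compute_reward: with max_steps=10, submitting on step 3 gives
# efficiency = 1 - (3 - 1)/(10 - 1) ≈ 0.778, i.e. roughly +0.078 on the total after the
# 0.1 weight; submitting on step 10 gives 0.0, and a wrong outcome (om < 0.8) gives 0.0.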
server/undertrial_environment.py CHANGED
@@ -86,6 +86,7 @@ class UndertriAIEnvironment(Environment):
         self._step_count = 0
         self._flags = []
         self._retrieved_precedents = []
+        self._action_history: List[str] = []  # accumulated tool results (Gap 4)
         return self._make_observation(action_result=None)
 
     def step(
@@ -113,6 +114,8 @@ class UndertriAIEnvironment(Environment):
                 agent_computation = action.statutory_computation,
                 agent_conditions = action.recommended_conditions or [],
                 episode = self._episode,
+                step_count = self._step_count,  # Gap 5: efficiency reward
+                max_steps = self.MAX_STEPS,
             )
             # Apply skip penalty (can push total legitimately negative)
             reward_dict["total_reward"] = round(reward_dict["total_reward"] - no_tool_penalty, 4)
@@ -149,6 +152,10 @@ class UndertriAIEnvironment(Environment):
         else:
            result = self._dispatch_tool(action)
 
+        # Accumulate action history (Gap 4)
+        summary = f"[Step {self._step_count}] {type(action).__name__}: {result[:120]}..."
+        self._action_history.append(summary)
+
        # Force submit if max steps reached
        done = (self._step_count >= self.MAX_STEPS)
        reward = -0.1 if done else 0.0  # Small penalty for exhausting budget
@@ -277,6 +284,7 @@ class UndertriAIEnvironment(Environment):
            cited_precedents = init_precedents + self._retrieved_precedents,
            documents_available = ep.get("documents_available", []),
            action_result = action_result,
+           action_history = list(self._action_history),  # Gap 4
            flags_raised = list(self._flags),
            precedents_retrieved = list(self._retrieved_precedents),
            memo_submitted = memo_submitted,
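For illustration only, an action_history entry produced by the summary f-string above would look roughly like this; the action class name and tool result text are invented, and real Action subclass names in models.py may differ.

# Hypothetical step: values below are made up to show the summary format.
step_count = 2
action_name = "RetrievePrecedentsAction"  # assumption: real class names may differ
result = "Found 3 relevant precedents on prolonged undertrial detention"
summary = f"[Step {step_count}] {action_name}: {result[:120]}..."
# -> "[Step 2] RetrievePrecedentsAction: Found 3 relevant precedents on prolonged undertrial detention..."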
training/train_grpo.py CHANGED
@@ -25,12 +25,20 @@ INSTALL_COMMANDS = """
 # CELL 2 — Imports
 # ============================================================
 
-import os, sys, json, re, argparse, random
+import os, sys, json, re, argparse, random, time
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
+import urllib.request
+import urllib.parse
 
 import torch
 
+# ── Environment API (Gap 1) ─────────────────────────────────────────────────
+ENV_API_URL = os.environ.get(
+    "UNDERTRIAL_ENV_URL",
+    "https://draken1606-undertrial-ai.hf.space",
+)
+
 # ── Fix 1: Import authoritative reward functions from server/reward.py ──────
 # This ensures training optimises the SAME signal the deployed demo evaluates.
 try:
@@ -334,7 +342,15 @@ def combined_reward(
         ca = reward_conditions([comp], [ep])[0]  # condition score, not format
         b = reward_no_bias([comp], [ep])[0]
 
-        total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca - 0.3*b
+        # R4 efficiency bonus: reward fewer steps when outcome is correct
+        eff = 0.0
+        if o >= 0.8:
+            steps_taken = kwargs.get("step_counts", [None] * len(completions))
+            sc = steps_taken[completions.index(comp)] if comp in completions else None
+            if sc is not None:
+                eff = max(0.0, 1.0 - (sc - 1) / 9)
+
+        total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*eff - 0.3*b
         rewards.append(round(total, 4))  # No max(0.0) clamp — bias can go negative
     return rewards
 
@@ -343,15 +359,117 @@
 # CELL 5 — Dataset builder
 # ============================================================
 
-def load_episodes(episodes_dir: str, stage: int = 1) -> List[Dict]:
+def load_episodes(
+    episodes_dir: str,
+    stage: int = 1,
+    split: str = "train",
+    val_fraction: float = 0.15,
+    test_fraction: float = 0.10,
+) -> List[Dict]:
+    """
+    Load episodes for a given split (Gap 2: train/val/test split).
+
+    Split fractions (applied deterministically by index, no shuffle):
+        train = first (1 - val - test) fraction
+        val = next val_fraction
+        test = last test_fraction
+    """
     path = Path(episodes_dir) / f"episodes_stage_{stage}.jsonl"
     if not path.exists():
-        # Try the combined file
         path = Path(episodes_dir) / "episodes_all.jsonl"
         if not path.exists():
-            raise FileNotFoundError(f"No episodes found in {episodes_dir}. Run data/prepare_dataset.py first.")
+            raise FileNotFoundError(f"No episodes found in {episodes_dir}.")
     with open(path, encoding="utf-8") as f:
-        return [json.loads(l) for l in f if l.strip()]
+        all_eps = [json.loads(l) for l in f if l.strip()]
+
+    n = len(all_eps)
+    n_test = max(1, int(n * test_fraction))
+    n_val = max(1, int(n * val_fraction))
+    n_train = n - n_val - n_test
+
+    if split == "train":
+        return all_eps[:n_train]
+    elif split == "val":
+        return all_eps[n_train:n_train + n_val]
+    elif split == "test":
+        return all_eps[n_train + n_val:]
+    else:
+        return all_eps  # all: for backward compat
+
+
+def rollout_via_env_api(
+    completion: str,
+    episode: Dict,
+    env_url: str = ENV_API_URL,
+    session_id: Optional[str] = None,
+    timeout: float = 10.0,
+) -> float:
+    """
+    Gap 1: Route reward through the live deployed environment API.
+
+    Sends the model's completion to the environment server via HTTP,
+    replaying the parsed submit_memo action, and returns the official reward.
+    Falls back to local reward on any network error.
+    """
+    import urllib.error
+    try:
+        from server.reward import compute_reward as _local_reward
+    except ImportError:
+        _local_reward = None
+
+    parsed = parse_model_output(completion)
+    if not parsed["recommended_outcome"]:
+        return 0.0  # Malformed output
+
+    try:
+        # Step 1: Reset the environment with the correct episode
+        episode_stage = episode.get("curriculum_stage", 1)
+        reset_url = f"{env_url}/reset?stage={episode_stage}"
+        req = urllib.request.Request(reset_url, method="POST")
+        with urllib.request.urlopen(req, timeout=timeout) as resp:
+            reset_data = json.loads(resp.read())
+        sid = session_id or reset_data.get("session_id", "")
+
+        # Step 2: Submit the parsed memo
+        memo_payload = json.dumps({
+            "session_id": sid,
+            "action": {
+                "tool_name": "submit_memo",
+                "flight_risk": parsed["flight_risk"] or "Medium",
+                "flight_risk_justification": parsed["flight_risk_just"] or "Not specified",
+                "statutory_eligible": parsed["statutory_eligible"],
+                "statutory_computation": parsed["statutory_computation"] or "Not computed",
+                "grounds_for_bail": parsed["grounds_for"] or [],
+                "grounds_against_bail": parsed["grounds_against"] or [],
+                "recommended_outcome": parsed["recommended_outcome"],
+                "recommended_conditions": parsed["conditions"] or [],
+                "confidence": "Medium",
+            }
+        }).encode()
+        step_req = urllib.request.Request(
+            f"{env_url}/step",
+            data=memo_payload,
+            headers={"Content-Type": "application/json"},
+            method="POST",
+        )
+        with urllib.request.urlopen(step_req, timeout=timeout) as resp:
+            step_data = json.loads(resp.read())
+        return float(step_data.get("reward", 0.0))
+
+    except Exception as e:
+        # Network / parse error: fall back to local reward
+        print(f"[env_api] Falling back to local reward: {e}")
+        if _local_reward and episode:
+            rd = _local_reward(
+                agent_outcome=parsed["recommended_outcome"],
+                agent_flight_risk=parsed["flight_risk"] or "Medium",
+                agent_eligible=parsed["statutory_eligible"],
+                agent_computation=parsed["statutory_computation"] or "",
+                agent_conditions=parsed["conditions"] or [],
+                episode=episode,
+            )
+            return rd["total_reward"]
+        return 0.0
 
 
 def build_hf_dataset(episodes: List[Dict], tokenizer) -> Dataset:
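A quick usage sketch of the two new helpers. The episodes directory path, the episode count, and the placeholder completion are illustrative only; the split arithmetic follows the defaults (val_fraction=0.15, test_fraction=0.10).

# Illustrative usage; "data/episodes" is an example path, not a repo guarantee.
train_eps = load_episodes("data/episodes", stage=1, split="train")  # first ~75% by index
val_eps = load_episodes("data/episodes", stage=1, split="val")      # next ~15%
test_eps = load_episodes("data/episodes", stage=1, split="test")    # last ~10%
# e.g. with 100 episodes: n_test = 10, n_val = 15, n_train = 75

# Scoring a sampled completion against the live environment; any network or parse
# failure inside rollout_via_env_api falls back to the local server/reward.py signal.
completion_text = "..."  # a raw model completion produced during GRPO sampling
reward = rollout_via_env_api(completion_text, train_eps[0])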