Draken1606 commited on
Commit
898bc18
·
1 Parent(s): ca62faa

Fix 8 compliance gaps: repeat-action dedup+cache, min-steps hard block, criminal history tool (12th action), efficiency removed from training formula, circular import cleaned, yaml formula synced

Browse files
models.py CHANGED
@@ -135,6 +135,13 @@ class ApplyProportionalityAction(Action):
135
  )
136
 
137
 
 
 
 
 
 
 
 
138
  class SubmitMemoAction(Action):
139
  """
140
  TERMINAL ACTION — Submit the structured bail assessment memo.
 
135
  )
136
 
137
 
138
+ class PullCriminalHistoryAction(Action):
139
+ """Pull the accused's prior criminal record, bail history, and conviction status."""
140
+ tool_name: Literal["pull_criminal_history"] = "pull_criminal_history"
141
+ include_bail_history: bool = Field(
142
+ default=True, description="Whether to include prior bail applications and outcomes"
143
+ )
144
+
145
  class SubmitMemoAction(Action):
146
  """
147
  TERMINAL ACTION — Submit the structured bail assessment memo.
openenv.yaml CHANGED
@@ -46,20 +46,23 @@ actions:
46
  description: Examine specific case factors (parity, evidence tampering, victim vulnerability)
47
  - name: apply_proportionality
48
  description: Apply BNSS 479 proportionality — custody vs. max sentence vs. trial timeline
 
 
49
  - name: submit_memo
50
  description: "TERMINAL — Submit structured bail assessment memo"
51
 
52
  reward:
53
- formula: "0.4*outcome + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*efficiency + 0.05*process_bonus - 0.3*bias"
54
  range: [-0.7, 1.15]
55
  terminal_action: submit_memo
56
  deterministic: true
57
  llm_as_judge: false
58
  components:
59
- - outcome_match: "Agreement with real High Court decision (40%)"
60
  - flight_risk_accuracy: "Flight risk classification accuracy (20%)"
61
  - statutory_accuracy: "IPC/BNSS threshold computation (20%)"
62
  - condition_appropriateness: "Bail condition quality (20%)"
 
63
  - bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
64
 
65
  curriculum:
 
46
  description: Examine specific case factors (parity, evidence tampering, victim vulnerability)
47
  - name: apply_proportionality
48
  description: Apply BNSS 479 proportionality — custody vs. max sentence vs. trial timeline
49
+ - name: pull_criminal_history
50
+ description: Pull the accused's prior criminal record, bail history, and conviction status
51
  - name: submit_memo
52
  description: "TERMINAL — Submit structured bail assessment memo"
53
 
54
  reward:
55
+ formula: "0.3*outcome + 0.2*flight_risk + 0.2*statutory + 0.2*conditions + 0.1*reasoning_quality + 0.1*efficiency + 0.05*process_bonus - 0.3*bias"
56
  range: [-0.7, 1.15]
57
  terminal_action: submit_memo
58
  deterministic: true
59
  llm_as_judge: false
60
  components:
61
+ - outcome_match: "Agreement with real High Court decision (30%)"
62
  - flight_risk_accuracy: "Flight risk classification accuracy (20%)"
63
  - statutory_accuracy: "IPC/BNSS threshold computation (20%)"
64
  - condition_appropriateness: "Bail condition quality (20%)"
65
+ - reasoning_quality: "Justification anchoring + arithmetic verification + grounds specificity (10%)"
66
  - bias_penalty: "Penalty for ignoring parity in bias cases (-30%)"
67
 
68
  curriculum:
server/__init__.py CHANGED
@@ -1,6 +1,5 @@
1
- """UndertriAI server package."""
2
- try:
3
- from ..models import *
4
- from ..client import UndertriAIEnv
5
- except ImportError:
6
- pass # Standalone import (e.g., from train_grpo.py) — skip re-exports
 
1
+ """UndertriAI server package.
2
+
3
+ Server-side code only. Do not import from client.py here —
4
+ user-facing exports live in the root undertrial_ai/__init__.py.
5
+ """
 
server/app.py CHANGED
@@ -101,7 +101,7 @@ def step(payload: dict):
101
  AssessSuretyAction, ClassifyBailTypeAction,
102
  ReadSubmissionsAction, AssessFlightRiskAction,
103
  CheckCaseFactorsAction, ApplyProportionalityAction,
104
- SubmitMemoAction,
105
  )
106
  ACTION_MAP = {
107
  "request_document": RequestDocumentAction,
@@ -114,6 +114,7 @@ def step(payload: dict):
114
  "assess_flight_risk": AssessFlightRiskAction,
115
  "check_case_factors": CheckCaseFactorsAction,
116
  "apply_proportionality": ApplyProportionalityAction,
 
117
  "submit_memo": SubmitMemoAction,
118
  }
119
  action_cls = ACTION_MAP.get(tool_name)
 
101
  AssessSuretyAction, ClassifyBailTypeAction,
102
  ReadSubmissionsAction, AssessFlightRiskAction,
103
  CheckCaseFactorsAction, ApplyProportionalityAction,
104
+ PullCriminalHistoryAction, SubmitMemoAction,
105
  )
106
  ACTION_MAP = {
107
  "request_document": RequestDocumentAction,
 
114
  "assess_flight_risk": AssessFlightRiskAction,
115
  "check_case_factors": CheckCaseFactorsAction,
116
  "apply_proportionality": ApplyProportionalityAction,
117
+ "pull_criminal_history": PullCriminalHistoryAction,
118
  "submit_memo": SubmitMemoAction,
119
  }
120
  action_cls = ACTION_MAP.get(tool_name)
server/undertrial_environment.py CHANGED
@@ -37,6 +37,7 @@ from ..models import (
37
  ComputeStatutoryEligibilityAction, AssessSuretyAction, ClassifyBailTypeAction,
38
  ReadSubmissionsAction, AssessFlightRiskAction, CheckCaseFactorsAction,
39
  ApplyProportionalityAction,
 
40
  SubmitMemoAction,
41
  )
42
  from .precedent_db import PrecedentDB
@@ -97,7 +98,8 @@ class UndertriAIEnvironment(Environment):
97
  self._flags = []
98
  self._retrieved_precedents = []
99
  self._action_history: List[str] = []
100
- self._statutory_tool_called: bool = False # M2: process reward tracking
 
101
  return self._make_observation(action_result=None)
102
 
103
  def step(
@@ -114,8 +116,26 @@ class UndertriAIEnvironment(Environment):
114
 
115
  # ---- Terminal action: submit memo ----
116
  if isinstance(action, SubmitMemoAction):
117
- # Penalty for skipping all tool calls
118
- # Increased to 0.40 so instant-submit can never be profitable by chance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  no_tool_penalty = 0.40 if self._step_count == 1 else 0.0
120
 
121
  reward_dict = compute_reward(
@@ -147,9 +167,26 @@ class UndertriAIEnvironment(Environment):
147
  info=reward_dict,
148
  )
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # ---- Tool actions with optional timeout enforcement ----
151
  if isinstance(action, ComputeStatutoryEligibilityAction):
152
- self._statutory_tool_called = True # M2: track for process reward
153
 
154
  if timeout_s is not None:
155
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -342,6 +379,26 @@ class UndertriAIEnvironment(Environment):
342
  lines.append(" ⚠️ Projected total custody exceeds maximum sentence — strong proportionality argument for bail")
343
  return "\n".join(lines)
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  return f"Unknown action type: {type(action).__name__}"
346
 
347
  # ------------------------------------------------------------------
 
37
  ComputeStatutoryEligibilityAction, AssessSuretyAction, ClassifyBailTypeAction,
38
  ReadSubmissionsAction, AssessFlightRiskAction, CheckCaseFactorsAction,
39
  ApplyProportionalityAction,
40
+ PullCriminalHistoryAction,
41
  SubmitMemoAction,
42
  )
43
  from .precedent_db import PrecedentDB
 
98
  self._flags = []
99
  self._retrieved_precedents = []
100
  self._action_history: List[str] = []
101
+ self._statutory_tool_called: bool = False # process reward tracking
102
+ self._tools_called: set = set() # 5B.2: track unique tool types for repeat detection
103
  return self._make_observation(action_result=None)
104
 
105
  def step(
 
116
 
117
  # ---- Terminal action: submit memo ----
118
  if isinstance(action, SubmitMemoAction):
119
+ # 4.5 Hard minimum: agent must have called at least 1 distinct tool before submitting.
120
+ # This is a structural gate — even a skip-penalty can't compensate for zero information.
121
+ if len(self._tools_called) == 0:
122
+ obs = self._make_observation(
123
+ action_result=(
124
+ "[BLOCKED] You must call at least one legal tool before submitting a memo. "
125
+ "Use tools such as compute_statutory_eligibility, assess_flight_risk, "
126
+ "read_submissions, or check_case_factors first."
127
+ ),
128
+ memo_submitted=False,
129
+ )
130
+ return StepResult(
131
+ observation=obs,
132
+ reward=-0.15, # Stronger signal than just a penalty post-submission
133
+ done=False,
134
+ info={"blocked": "minimum_tools_not_met", "tools_called": 0},
135
+ )
136
+
137
+ # Skip penalty only if submitted on step 1 despite having called a tool
138
+ # (edge case where first action is somehow both a tool and submit)
139
  no_tool_penalty = 0.40 if self._step_count == 1 else 0.0
140
 
141
  reward_dict = compute_reward(
 
167
  info=reward_dict,
168
  )
169
 
170
+ # ---- Repeat-action deduplication (5B.2) ----
171
+ tool_key = type(action).__name__
172
+ if tool_key in self._tools_called:
173
+ # Return cached note — no re-execution, no reward gaming
174
+ obs = self._make_observation(
175
+ action_result=(
176
+ f"[CACHED] {tool_key} was already called this episode. "
177
+ "The result is already in your action history above. "
178
+ "Use a different tool or submit your memo."
179
+ ),
180
+ memo_submitted=False,
181
+ )
182
+ return StepResult(observation=obs, reward=-0.05, done=False,
183
+ info={"cached": True, "tool": tool_key})
184
+
185
+ self._tools_called.add(tool_key)
186
+
187
  # ---- Tool actions with optional timeout enforcement ----
188
  if isinstance(action, ComputeStatutoryEligibilityAction):
189
+ self._statutory_tool_called = True # track for process reward
190
 
191
  if timeout_s is not None:
192
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
 
379
  lines.append(" ⚠️ Projected total custody exceeds maximum sentence — strong proportionality argument for bail")
380
  return "\n".join(lines)
381
 
382
+ elif isinstance(action, PullCriminalHistoryAction):
383
+ ep = self._episode
384
+ profile = ep.get("accused_profile", {})
385
+ prior = profile.get("prior_cases", "No prior criminal record on file")
386
+ bail_type = profile.get("bail_type", "Unknown")
387
+ lines = [
388
+ "Criminal History Report:",
389
+ f" Prior cases: {prior}",
390
+ f" Bail type context: {bail_type}",
391
+ ]
392
+ if action.include_bail_history:
393
+ # Infer from parity flag and stage whether HC has dealt with bail before
394
+ parity = ep.get("ground_truth", {}).get("parity_argument_used", False)
395
+ lines.append(
396
+ f" Prior bail history: {'Co-accused parity argument on record — HC previously granted bail to similarly placed accused' if parity else 'No co-accused parity argument on record'}"
397
+ )
398
+ first_time = prior in ("None", "nil", "no prior", "No prior criminal record on file", None, "")
399
+ lines.append(f" → Classification: {'FIRST-TIME OFFENDER ✓' if first_time else 'HAS PRIOR RECORD — review above'}")
400
+ return "\n".join(lines)
401
+
402
  return f"Unknown action type: {type(action).__name__}"
403
 
404
  # ------------------------------------------------------------------
training/train_grpo.py CHANGED
@@ -359,15 +359,12 @@ def combined_reward(
359
  b = reward_no_bias([comp], [ep])[0]
360
  rq = 0.5 # Neutral when server functions unavailable
361
 
362
- # R4 efficiency bonus: reward fewer steps when outcome is correct
 
 
363
  eff = 0.0
364
- if o >= 0.8:
365
- steps_taken = kwargs.get("step_counts", [None] * len(completions))
366
- sc = steps_taken[completions.index(comp)] if comp in completions else None
367
- if sc is not None:
368
- eff = max(0.0, 1.0 - (sc - 1) / 9)
369
 
370
- total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq + 0.1*eff - 0.3*b
371
  rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
372
  return rewards
373
 
 
359
  b = reward_no_bias([comp], [ep])[0]
360
  rq = 0.5 # Neutral when server functions unavailable
361
 
362
+ # NOTE: Efficiency is NOT computed in GRPO training because step_count=1
363
+ # always (single-shot generation), making eff=1.0 a constant non-signal.
364
+ # Efficiency is preserved in the environment's compute_reward for live inference.
365
  eff = 0.0
 
 
 
 
 
366
 
367
+ total = 0.3*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*rq - 0.3*b
368
  rewards.append(round(total, 4)) # No max(0.0) clamp — bias can go negative
369
  return rewards
370