Commit 6218d9a
Parent(s): d76d092

Fix 5 audit gaps: conditional bail, action history, efficiency reward, train/val split, env API routing

Files changed:
- models.py (+6 -2)
- server/reward.py (+35 -13)
- server/undertrial_environment.py (+8 -0)
- training/train_grpo.py (+125 -7)
models.py
CHANGED

@@ -121,8 +121,8 @@ class SubmitMemoAction(Action):
     )

     # Recommendation
-    recommended_outcome: Literal["Bail Granted", "Bail Denied"] = Field(
-        ..., description="Final recommendation"
+    recommended_outcome: Literal["Bail Granted", "Bail Denied", "Bail Conditional"] = Field(
+        ..., description="Final recommendation: Bail Granted | Bail Denied | Bail Conditional (strict conditions)"
     )
     recommended_conditions: Optional[List[str]] = Field(
         None,

@@ -186,6 +186,10 @@ class CaseObservation(Observation):

     # Episode state
     action_result: Optional[str] = None
+    action_history: List[str] = Field(
+        default_factory=list,
+        description="Ordered log of all tool results seen so far this episode",
+    )
     flags_raised: List[str] = Field(default_factory=list)
     precedents_retrieved: List[str] = Field(default_factory=list)
     memo_submitted: bool = False
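For context on the widened Literal: Pydantic enforces the declared value set at validation time, so "Bail Conditional" becomes representable without loosening validation for any other value. A minimal standalone sketch of that behavior (the Memo model below is a hypothetical stand-in, not the repo's actual SubmitMemoAction):

from typing import List, Literal, Optional
from pydantic import BaseModel, ValidationError

class Memo(BaseModel):  # hypothetical stand-in for SubmitMemoAction
    recommended_outcome: Literal["Bail Granted", "Bail Denied", "Bail Conditional"]
    recommended_conditions: Optional[List[str]] = None

Memo(recommended_outcome="Bail Conditional")  # accepted after this commit

try:
    Memo(recommended_outcome="Released on Parole")
except ValidationError:
    print("values outside the Literal set are still rejected")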
server/reward.py
CHANGED

@@ -20,17 +20,26 @@ def compute_outcome_match(agent_outcome: str, ground_truth: Dict[str, Any]) -> float:
     Checks if the agent's final recommendation matches the High Court decision.

     Scoring:
-        1.0 — Exact string match
-        0.8 — Directionally correct but loose string
+        1.0 — Exact string match
+        0.9 — "Bail Conditional" vs "Bail Granted" (conditional IS bail)
+        0.8 — Directionally correct but loose string
         0.0 — Wrong direction (granted vs. denied, or vice versa)
     """
     gt = ground_truth["outcome"]
+    agent_norm = agent_outcome.strip().lower()
+    gt_norm = gt.strip().lower()

-    if agent_outcome == gt:
+    if agent_norm == gt_norm:
         return 1.0

-    agent_granted = "grant" in agent_outcome.lower()
-    gt_granted = "grant" in gt.lower()
+    # Conditional bail counts almost as well as full bail
+    if "conditional" in agent_norm and "grant" in gt_norm:
+        return 0.9
+    if "grant" in agent_norm and "conditional" in gt_norm:
+        return 0.9
+
+    agent_granted = "grant" in agent_norm or "conditional" in agent_norm
+    gt_granted = "grant" in gt_norm or "conditional" in gt_norm

     return 0.8 if (agent_granted == gt_granted) else 0.0

@@ -334,6 +343,8 @@ def compute_reward(
     agent_computation: str,
     agent_conditions: List[str],
     episode: Dict[str, Any],
+    step_count: int = 0,
+    max_steps: int = 10,
 ) -> Dict[str, float]:
     """
     Computes the full reward for a submitted bail assessment memo.

@@ -343,10 +354,11 @@ def compute_reward(
         + 0.2*flight_risk_accuracy
         + 0.2*statutory_accuracy
         + 0.2*condition_appropriateness
+        + 0.1*efficiency_bonus (only when outcome is correct)
         - 0.3*bias_penalty

     Returns a dict with all component scores + total_reward.
-    Range: [-0.3, 1.0]
+    Range: [-0.3, 1.1] (efficiency can push above 1.0 slightly on perfect runs).
     """
     gt = episode["ground_truth"]

@@ -356,16 +368,26 @@ def compute_reward(
     ca = compute_condition_score(agent_outcome, agent_conditions, gt)
     bias = compute_bias_penalty(agent_outcome, episode)

+    # R4 — Efficiency bonus: reward finishing faster when the answer is correct.
+    # Only fires on directionally-correct outcomes (om >= 0.8) to prevent
+    # rewarding efficient-but-wrong agents.
+    efficiency = 0.0
+    if om >= 0.8 and max_steps > 1:
+        efficiency = round((1.0 - (step_count - 1) / (max_steps - 1)), 4)
+        efficiency = max(0.0, min(1.0, efficiency))
+
     lam = 0.3
-    total = 0.4*om + 0.2*fr + 0.2*sa + 0.2*ca - lam*bias
+    total = 0.4*om + 0.2*fr + 0.2*sa + 0.2*ca + 0.1*efficiency - lam*bias

     return {
         "outcome_match": round(om, 4),
         "flight_risk_accuracy": round(fr, 4),
         "statutory_accuracy": round(sa, 4),
         "condition_appropriateness": round(ca, 4),
+        "efficiency_bonus": round(efficiency, 4),
         "bias_penalty": round(bias, 4),
         "total_reward": round(total, 4),
         "ground_truth_outcome": gt["outcome"],
         "agent_outcome": agent_outcome,
+        "steps_used": step_count,
     }
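To make the new scoring concrete, here is a short self-contained sketch of the matching and efficiency arithmetic defined above, reimplemented locally for illustration (the step count and max_steps values are assumptions, not repo defaults):

def outcome_match(agent: str, gt: str) -> float:
    # mirrors compute_outcome_match in the diff above
    a, g = agent.strip().lower(), gt.strip().lower()
    if a == g:
        return 1.0
    if ("conditional" in a and "grant" in g) or ("grant" in a and "conditional" in g):
        return 0.9
    a_grant = "grant" in a or "conditional" in a
    g_grant = "grant" in g or "conditional" in g
    return 0.8 if a_grant == g_grant else 0.0

print(outcome_match("Bail Conditional", "Bail Granted"))  # 0.9
print(outcome_match("bail granted", "Bail Granted"))      # 1.0 (case-insensitive)
print(outcome_match("Bail Denied", "Bail Granted"))       # 0.0

# Efficiency bonus: linear decay from 1.0 (finish in 1 step) to 0.0 (use all steps)
step_count, max_steps, om = 3, 10, 0.9
eff = max(0.0, min(1.0, 1.0 - (step_count - 1) / (max_steps - 1))) if om >= 0.8 else 0.0
print(round(eff, 4))  # 0.7778 — contributes 0.1 * 0.7778 ≈ 0.078 to the total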
server/undertrial_environment.py
CHANGED

@@ -86,6 +86,7 @@ class UndertriAIEnvironment(Environment):
        self._step_count = 0
        self._flags = []
        self._retrieved_precedents = []
+       self._action_history: List[str] = []  # accumulated tool results (Gap 4)
        return self._make_observation(action_result=None)

    def step(

@@ -113,6 +114,8 @@ class UndertriAIEnvironment(Environment):
            agent_computation = action.statutory_computation,
            agent_conditions = action.recommended_conditions or [],
            episode = self._episode,
+           step_count = self._step_count,  # Gap 5: efficiency reward
+           max_steps = self.MAX_STEPS,
        )
        # Apply skip penalty (can push total legitimately negative)
        reward_dict["total_reward"] = round(reward_dict["total_reward"] - no_tool_penalty, 4)

@@ -149,6 +152,10 @@ class UndertriAIEnvironment(Environment):
        else:
            result = self._dispatch_tool(action)

+       # Accumulate action history (Gap 4)
+       summary = f"[Step {self._step_count}] {type(action).__name__}: {result[:120]}..."
+       self._action_history.append(summary)
+
        # Force submit if max steps reached
        done = (self._step_count >= self.MAX_STEPS)
        reward = -0.1 if done else 0.0  # Small penalty for exhausting budget

@@ -277,6 +284,7 @@ class UndertriAIEnvironment(Environment):
            cited_precedents = init_precedents + self._retrieved_precedents,
            documents_available = ep.get("documents_available", []),
            action_result = action_result,
+           action_history = list(self._action_history),  # Gap 4
            flags_raised = list(self._flags),
            precedents_retrieved = list(self._retrieved_precedents),
            memo_submitted = memo_submitted,
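One behavioral note on the history line added above: result[:120] caps long tool outputs, but the literal "..." suffix is appended even when the result is shorter than 120 characters. A tiny standalone sketch of the summary format (the action class and result text below are made up for illustration):

class RetrievePrecedentAction:  # made-up action type for illustration
    pass

step_count = 2
action = RetrievePrecedentAction()
result = "Found 3 precedents matching 'default bail'"
summary = f"[Step {step_count}] {type(action).__name__}: {result[:120]}..."
print(summary)
# [Step 2] RetrievePrecedentAction: Found 3 precedents matching 'default bail'...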
training/train_grpo.py
CHANGED

@@ -25,12 +25,20 @@ INSTALL_COMMANDS = """
 # CELL 2 — Imports
 # ============================================================

-import os, sys, json, re, argparse, random
+import os, sys, json, re, argparse, random, time
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
+import urllib.request
+import urllib.parse

 import torch

+# ── Environment API (Gap 1) ─────────────────────────────────────────────────
+ENV_API_URL = os.environ.get(
+    "UNDERTRIAL_ENV_URL",
+    "https://draken1606-undertrial-ai.hf.space",
+)
+
 # ── Fix 1: Import authoritative reward functions from server/reward.py ──────
 # This ensures training optimises the SAME signal the deployed demo evaluates.
 try:

@@ -334,7 +342,15 @@ def combined_reward(
         ca = reward_conditions([comp], [ep])[0]  # condition score, not format
         b = reward_no_bias([comp], [ep])[0]

-        total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca - 0.3*b
+        # R4 efficiency bonus: reward fewer steps when outcome is correct
+        eff = 0.0
+        if o >= 0.8:
+            steps_taken = kwargs.get("step_counts", [None] * len(completions))
+            sc = steps_taken[completions.index(comp)] if comp in completions else None
+            if sc is not None:
+                eff = max(0.0, 1.0 - (sc - 1) / 9)
+
+        total = 0.4*o + 0.2*fr + 0.2*s + 0.2*ca + 0.1*eff - 0.3*b
         rewards.append(round(total, 4))  # No max(0.0) clamp — bias can go negative
     return rewards

@@ -343,15 +359,117 @@ def combined_reward(
 # CELL 5 — Dataset builder
 # ============================================================

-def load_episodes(episodes_dir: str, stage: int = 1) -> List[Dict]:
+def load_episodes(
+    episodes_dir: str,
+    stage: int = 1,
+    split: str = "train",
+    val_fraction: float = 0.15,
+    test_fraction: float = 0.10,
+) -> List[Dict]:
+    """
+    Load episodes for a given split (Gap 2: train/val/test split).
+
+    Split fractions (applied deterministically by index, no shuffle):
+        train = first (1 - val - test) fraction
+        val   = next val_fraction
+        test  = last test_fraction
+    """
     path = Path(episodes_dir) / f"episodes_stage_{stage}.jsonl"
     if not path.exists():
-        # Try the combined file
         path = Path(episodes_dir) / "episodes_all.jsonl"
     if not path.exists():
-        raise FileNotFoundError(f"No episodes found in {episodes_dir}. …")
+        raise FileNotFoundError(f"No episodes found in {episodes_dir}.")
     with open(path, encoding="utf-8") as f:
-        return [json.loads(l) for l in f if l.strip()]
+        all_eps = [json.loads(l) for l in f if l.strip()]
+
+    n = len(all_eps)
+    n_test = max(1, int(n * test_fraction))
+    n_val = max(1, int(n * val_fraction))
+    n_train = n - n_val - n_test
+
+    if split == "train":
+        return all_eps[:n_train]
+    elif split == "val":
+        return all_eps[n_train:n_train + n_val]
+    elif split == "test":
+        return all_eps[n_train + n_val:]
+    else:
+        return all_eps  # all: for backward compat
+
+
+def rollout_via_env_api(
+    completion: str,
+    episode: Dict,
+    env_url: str = ENV_API_URL,
+    session_id: Optional[str] = None,
+    timeout: float = 10.0,
+) -> float:
+    """
+    Gap 1: Route reward through the live deployed environment API.
+
+    Sends the model's completion to the environment server via HTTP,
+    replaying the parsed submit_memo action, and returns the official reward.
+    Falls back to the local reward on any network error.
+    """
+    import urllib.error
+    try:
+        from server.reward import compute_reward as _local_reward
+    except ImportError:
+        _local_reward = None
+
+    parsed = parse_model_output(completion)
+    if not parsed["recommended_outcome"]:
+        return 0.0  # Malformed output
+
+    try:
+        # Step 1: Reset the environment with the correct episode
+        episode_stage = episode.get("curriculum_stage", 1)
+        reset_url = f"{env_url}/reset?stage={episode_stage}"
+        req = urllib.request.Request(reset_url, method="POST")
+        with urllib.request.urlopen(req, timeout=timeout) as resp:
+            reset_data = json.loads(resp.read())
+        sid = session_id or reset_data.get("session_id", "")
+
+        # Step 2: Submit the parsed memo
+        memo_payload = json.dumps({
+            "session_id": sid,
+            "action": {
+                "tool_name": "submit_memo",
+                "flight_risk": parsed["flight_risk"] or "Medium",
+                "flight_risk_justification": parsed["flight_risk_just"] or "Not specified",
+                "statutory_eligible": parsed["statutory_eligible"],
+                "statutory_computation": parsed["statutory_computation"] or "Not computed",
+                "grounds_for_bail": parsed["grounds_for"] or [],
+                "grounds_against_bail": parsed["grounds_against"] or [],
+                "recommended_outcome": parsed["recommended_outcome"],
+                "recommended_conditions": parsed["conditions"] or [],
+                "confidence": "Medium",
+            }
+        }).encode()
+        step_req = urllib.request.Request(
+            f"{env_url}/step",
+            data=memo_payload,
+            headers={"Content-Type": "application/json"},
+            method="POST",
+        )
+        with urllib.request.urlopen(step_req, timeout=timeout) as resp:
+            step_data = json.loads(resp.read())
+        return float(step_data.get("reward", 0.0))
+
+    except Exception as e:
+        # Network / parse error: fall back to local reward
+        print(f"[env_api] Falling back to local reward: {e}")
+        if _local_reward and episode:
+            rd = _local_reward(
+                agent_outcome=parsed["recommended_outcome"],
+                agent_flight_risk=parsed["flight_risk"] or "Medium",
+                agent_eligible=parsed["statutory_eligible"],
+                agent_computation=parsed["statutory_computation"] or "",
+                agent_conditions=parsed["conditions"] or [],
+                episode=episode,
+            )
+            return rd["total_reward"]
+        return 0.0


 def build_hf_dataset(episodes: List[Dict], tokenizer) -> Dataset:
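As a sanity check on the split arithmetic above: with the default fractions (15% val, 10% test), a 100-episode file splits 75/15/10, and the max(1, ...) floors keep val and test non-empty even for very small files. A hedged usage sketch (the directory path, variable names, and episode count are hypothetical):

# Mirror of load_episodes' split arithmetic for n = 100 episodes
n = 100
n_test = max(1, int(n * 0.10))   # 10
n_val = max(1, int(n * 0.15))    # 15
n_train = n - n_val - n_test     # 75
assert (n_train, n_val, n_test) == (75, 15, 10)

# Hypothetical calls into the new loader and env-API reward path:
# train_eps = load_episodes("data/episodes", stage=1, split="train")
# val_eps = load_episodes("data/episodes", stage=1, split="val")
# reward = rollout_via_env_api(completion_text, episode=train_eps[0])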