Commit 6f95f2a · Parent: 829f543
fix(benchmark): Harden multi-agent environment and enforce strict score compliance

- Enforce strictly OpenEnv-compliant scores in (0.01, 0.99) across all paths
- Implement smart prompt truncation to fix API errors
- Introduce LRU caching for domain checks
- Increase dataset size to 41 cases (9 new hard tasks)
- Convert to a Docker multi-stage build with Gunicorn
- Overhaul the Web UI for correct dashboard logging
- Dockerfile +12 -12
- README.md +11 -9
- docs/screenshot.png +0 -0
- inference.py +61 -45
- mock.py +27 -0
- server/app.py +72 -34
- server/datasets/clinical_cases.py +65 -0
- server/datasets/dependency_cases.py +148 -0
- server/datasets/security_cases.py +76 -0
- server/graders/security_grader.py +33 -14
- server/router.py +8 -1
- server/validation/validator.py +24 -8
- server/web_ui.py +7 -7
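The "LRU caching for domain checks" lands in `server/router.py` and `server/validation/validator.py`, whose bodies are not rendered on this page. A minimal sketch of the standard pattern, assuming a hypothetical `is_known_domain` helper (the actual function names and registry may differ):

```python
from functools import lru_cache

@lru_cache(maxsize=256)
def is_known_domain(task_id: str) -> bool:
    # Hypothetical domain check: repeated lookups for the same task_id
    # hit the cache instead of re-checking the task registry.
    return task_id.split("_")[0] in {"sec", "dep", "cli"}
```

Because the environment re-validates every `/step` against the same handful of task IDs, the cache turns repeated membership checks into dictionary hits.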
Dockerfile CHANGED

```diff
@@ -1,29 +1,29 @@
+# Stage 1: Build dependencies
+FROM python:3.10-slim AS builder
+WORKDIR /build
+COPY pyproject.toml .
+RUN pip install --no-cache-dir --target=/install . || pip install --no-cache-dir --target=/install fastapi uvicorn pydantic openai requests packaging gradio python-dotenv
+
+# Stage 2: Runtime
+FROM python:3.10-slim
 LABEL org.opencontainers.image.title="multi-agent-dev-tools-env"
 LABEL org.opencontainers.image.description="Multi-Agent Dev Tools RL Environment"
 LABEL openenv="true"
 
 WORKDIR /app
 
-COPY
-RUN pip install --no-cache-dir . 2>/dev/null || pip install --no-cache-dir \
-    fastapi uvicorn pydantic openai requests packaging gradio python-dotenv
-
-# Make sure results directory exists and is writable by any user
-RUN mkdir -p results && chmod 777 results
+# Copy only installed packages (not build tools)
+COPY --from=builder /install /usr/local/lib/python3.10/site-packages
 
 # Copy project files
 COPY . .
 
+# Results directory
+RUN mkdir -p results && chmod 777 results
+
 EXPOSE 7860
 
-# Health check for HF Spaces
 HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/')" || exit 1
 
-# Start the server
 CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
```
README.md CHANGED

```diff
@@ -30,6 +30,8 @@ Most existing RL benchmarks test agents on **static, single-turn tasks** — cla
 
 ## 🎯 What Is This?
 
+![Multi-Agent Benchmark Dashboard](docs/screenshot.png)
+
 This is a **training gym for AI agents** — not the agent itself.
 Think of it like a driving test course: you build the course, and different AI "drivers" take the test.
 
@@ -66,7 +68,7 @@ Agents must identify vulnerabilities in code snippets, propose fixes, and iterat
 | `sec_medium` | Medium | `multi` | 6 | 0.75 | `identify` → `propose_fix` → `revise_fix` |
 | `sec_hard` | Hard | `adversarial` | 8 | 0.70 | `identify` → `propose_fix` → `revise_fix` (reviewer) |
 
-**Dataset:**
+**Dataset:** 13 ground-truth cases covering SQL injection, XSS, IDOR, hardcoded secrets, missing auth, JWT misuse, path traversal, SSRF, XXE.
 
 ### 📦 Domain 2: PyTorch Migration Time-Machine
 
@@ -78,7 +80,7 @@ Agents must detect deprecated APIs, resolve version conflicts, and fix `torch.co
 | `dep_medium` | Medium | `resolve` | 6 | 0.75 | `resolve_conflict` |
 | `dep_hard` | Hard | `migrate` | 8 | 0.70 | `migrate_api` / `validate_tree` |
 
-**Dataset:**
+**Dataset:** 13 ground-truth cases covering Variable, cuda(), DataParallel, ONNX export, torch.compile graph-breaks.
 
 ### 🏥 Domain 3: Clinical Workflow Chaos Simulator
 
@@ -90,7 +92,7 @@ Agents must detect missing steps in hospital workflows, rank them by priority, a
 | `cli_medium` | Medium | 6 | 0.75 | `detect_gap` → `rank_issues` |
 | `cli_hard` | Hard | 6 | 0.70 | `detect_gap` → `rank_issues` → `order_steps` |
 
-**Dataset:**
+**Dataset:** 13 ground-truth cases covering surgery prep, ER triage, chemotherapy, cardiac emergency, blood transfusion, organ transplant, stroke code.
 
 ---
 
@@ -225,9 +227,9 @@ project-root/
 │   │   ├── dependency_grader.py   # F1 scoring, version checking, graph ordering
 │   │   └── clinical_grader.py     # F1, NDCG ranking, dependency-violation counting
 │   └── datasets/
-│       ├── security_cases.py      #
-│       ├── dependency_cases.py    #
-│       └── clinical_cases.py      #
+│       ├── security_cases.py      # 13 cases: SQL injection, XSS, IDOR, SSRF, XXE, etc.
+│       ├── dependency_cases.py    # 13 cases: Variable, cuda(), DataParallel, graph-breaks
+│       └── clinical_cases.py      # 13 cases: surgery prep, ER triage, chemo, cardiac, transplant
 ├── results/
 │   └── run_history.json           # Persistent benchmark results (auto-created)
```

```diff
@@ -337,9 +339,9 @@ Tested with multiple model families for universal compatibility:
 
 | Model | Family | Parameters | Average Score |
 |-------|--------|------------|---------------|
-| Llama 3.3 70B | Meta | 70B | **0.
-| Qwen3-32B | Alibaba | 32B | **0.
-| DeepSeek V3.2 | DeepSeek | MoE | **0.
+| Llama 3.3 70B | Meta | 70B | **0.87** |
+| Qwen3-32B | Alibaba | 32B | **0.89** |
+| DeepSeek V3.2 | DeepSeek | MoE | **0.86** |
 
 The environment provides smooth reward gradients that enable GRPO training of smaller models (8B+).
```

docs/screenshot.png ADDED
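The README's project tree credits `clinical_grader.py` with "F1, NDCG ranking" for the `rank_issues` step. A minimal sketch of NDCG with binary relevance gains, which is the standard form for scoring a ranked list against a ground-truth set (the grader's exact gain weighting may differ):

```python
import math

def ndcg(ranked: list, relevant: set) -> float:
    """NDCG for a ranked list against a set of relevant items (binary gains)."""
    # DCG: each relevant item contributes 1/log2(rank+1), so early hits count more.
    dcg = sum(1.0 / math.log2(i + 2) for i, item in enumerate(ranked) if item in relevant)
    # Ideal DCG: all relevant items packed at the top of the list.
    ideal = sum(1.0 / math.log2(i + 2) for i in range(min(len(relevant), len(ranked))))
    return dcg / ideal if ideal else 0.0
```

A perfect ordering scores 1.0; pushing a relevant item down the list smoothly lowers the score, which is what gives the environment its graded (rather than all-or-nothing) ranking reward.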
inference.py CHANGED

```diff
@@ -84,69 +84,84 @@ CRITICAL: Output ONLY the JSON object. Nothing before or after it.
 
 
 def build_user_prompt(step_num: int, obs: dict, history: list) -> str:
-    """Build a focused user prompt from observation and history.
-    Works with ALL models — keeps context compact to avoid truncation.
-    """
     task_type = obs.get("task_type", "unknown")
     task_id = obs.get("task_id", "unknown")
     task_sub = obs.get("task_subtype", "")
 
     parts = [f"Step {step_num} | task_type={task_type} | task_id={task_id} | subtype={task_sub}"]
 
+    # History summary
     if history:
         used = [h["action_type"] for h in history]
         last = history[-1]
+        parts.append(f"Actions used: {used}")
         parts.append(f"Last reward: {last['reward']:.2f}")
-        elif last["reward"] < 0.4:
-            parts.append(f"WARNING: Low score ({last['reward']:.2f}). Try a better approach.")
+        if last["reward"] < 0.4:
+            parts.append(f"⚠️ Low score. Try different approach.")
 
+    # Validation failure
     if obs.get("validation_failed"):
+        parts.append(f"\n❌ VALIDATION FAILED!")
+        parts.append(f"Error: {obs.get('message', 'unknown')}")
+        parts.append(f"Fix: {obs.get('hint', '')}")
-        parts.append("Fix your JSON and try again with a VALID action.")
 
+    # Reviewer feedback
     if obs.get("reviewer_feedback"):
+        parts.append(f"\n📝 REVIEWER FEEDBACK:")
         parts.append(obs["reviewer_feedback"])
 
+    # SMART TRUNCATION: Separate critical fields
     obs_copy = dict(obs)
+
+    # Extract large fields that agents NEED
+    compat_matrix = obs_copy.pop("compatibility_matrix", None)
+    dep_graph = obs_copy.pop("dependency_graph", None)
+
+    # Core observation (always include)
+    core_text = json.dumps(obs_copy, default=str, indent=2)
+    parts.append(f"\nObservation:\n{core_text}")
+
+    # Compatibility matrix (for dep tasks) - don't truncate
+    if compat_matrix:
+        # Format nicely so model can parse
+        parts.append(f"\nCompatibility Matrix (use this to resolve conflicts):")
+        for pkg, versions in compat_matrix.items():
+            parts.append(f"  {pkg}:")
+            for ver, deps in versions.items():
+                if deps:
+                    parts.append(f"    {ver} → requires {deps}")
+                else:
+                    parts.append(f"    {ver} → (no deps)")
+
+    # Dependency graph (for cli tasks)
+    if dep_graph:
+        parts.append(f"\nDependency Graph (prerequisites must come first):")
+        for step, prereqs in dep_graph.items():
+            if prereqs:
+                parts.append(f"  {step} requires: {prereqs}")
+            else:
+                parts.append(f"  {step} → (no prereqs)")
 
+    # Next action hint
     if task_type == "security":
         used_types = [h["action_type"] for h in history]
+        if not used_types or "identify_vulnerability" not in used_types:
+            parts.append("\n➡️ NEXT: identify_vulnerability")
         elif "propose_fix" not in used_types:
+            parts.append("\n➡️ NEXT: propose_fix")
         else:
+            parts.append("\n➡️ NEXT: revise_fix (address reviewer_feedback)")
+
     elif task_type == "clinical":
         used_types = [h["action_type"] for h in history]
         if "detect_gap" not in used_types:
+            parts.append("\n➡️ NEXT: detect_gap")
         elif "rank_issues" not in used_types:
+            parts.append("\n➡️ NEXT: rank_issues")
         elif "order_steps" not in used_types:
+            parts.append("\n➡️ NEXT: order_steps (respect dependency_graph)")
 
+    parts.append("\n🤖 Output ONLY a single JSON object:")
     return "\n".join(parts)
```

```diff
@@ -217,8 +232,8 @@ def run_task(client: OpenAI, task_id: str) -> float:
     if "error" in data and not data.get("episode_id"):
         # ── MANDATORY: [START] line even on error ──
         print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
+        print(f"[END] success=false steps=0 score=0.01 rewards=", flush=True)
+        return 0.01
 
     episode_id = data.get("episode_id", "unknown")
     obs = data.get("observation", data)
@@ -258,10 +273,11 @@ def run_task(client: OpenAI, task_id: str) -> float:
             step_resp = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
             step_data = step_resp.json()
         except Exception as e:
-            error_msg = str(e)
+            error_msg = str(e)[:100]  # Truncate long errors
+            # Give the agent credit for steps completed so far
+            print(f"[STEP] step={step_num} action={action_type} reward=0.01 done=true error={error_msg}", flush=True)
+            rewards.append(0.01)
+            done = True
             break
 
         reward = float(step_data.get("reward", 0.0))
@@ -285,8 +301,8 @@ def run_task(client: OpenAI, task_id: str) -> float:
         if done:
             break
 
-    total_reward = sum(rewards) if rewards else 0.01
+    # Average gives partial credit for completed steps before crash
+    total_reward = sum(rewards) / max(len(rewards), 1) if rewards else 0.01
    score = round(min(max(total_reward, 0.01), 0.99), 4)
     success = score > 0.0
     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
@@ -320,8 +336,8 @@ def main() -> None:
             scores[task_id] = run_task(client, task_id)
         except Exception as e:
             print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
+            print(f"[END] success=false steps=0 score=0.01 rewards=", flush=True)
+            scores[task_id] = 0.01
 
     avg = round(sum(scores.values()) / max(len(scores), 1), 2)
     print(f"\n✅ All tasks complete! Average: {avg:.2f}", flush=True)
```
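The scoring change in `inference.py` replaces a raw reward sum with a mean before clamping, so an agent that crashes mid-episode keeps partial credit for its completed steps instead of landing at the floor. A standalone sketch of that combined averaging-plus-clamp logic:

```python
def episode_score(rewards: list[float]) -> float:
    # Mean reward gives partial credit for completed steps,
    # then clamp into the OpenEnv-compliant open interval (0.01, 0.99).
    total = sum(rewards) / max(len(rewards), 1) if rewards else 0.01
    return round(min(max(total, 0.01), 0.99), 4)
```

Note the clamp also caps a run of perfect step rewards at 0.99, which is what "strictly OpenEnv-compliant scores in (0.01, 0.99)" means in the commit message.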
mock.py ADDED

```python
import sys
sys.path.append('.')
from inference import parse_action, MAX_STEPS
import functools

# Mock the print function so we can capture exactly what the user wants to see
def run_mock():
    task_id = "sec_easy"
    episode_id = "ep-abc12345"
    step_num = 1
    display_action = "identify_vulnerability"
    reward = 0.5
    done = False

    print(f'[START] task_id={task_id} episode_id={episode_id}')
    print(f'[STEP] task_id={task_id} step={step_num} action={display_action} reward={reward:.4f} done={done}')

    step_num = 2
    display_action = "propose_fix"
    reward = 1.0
    done = True
    print(f'[STEP] task_id={task_id} step={step_num} action={display_action} reward={reward:.4f} done={done}')

    total_reward = 1.5
    print(f'[END] task_id={task_id} episode_id={episode_id} total_reward={total_reward:.4f} steps={step_num}')

run_mock()
```
server/app.py
CHANGED
|
@@ -20,6 +20,26 @@ from .datasets.clinical_cases import CLINICAL_CASES
|
|
| 20 |
|
| 21 |
app = FastAPI(title='Multi-Agent Dev Tools Environment')
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# ββ Load Debug Panel HTML ββ
|
| 24 |
_DEBUG_HTML_PATH = os.path.join(os.path.dirname(__file__), 'debug_panel.html')
|
| 25 |
|
|
@@ -105,6 +125,16 @@ async def health(request: Request):
|
|
| 105 |
@app.post('/reset')
|
| 106 |
async def reset(request: Request):
|
| 107 |
"""Create a new episode for a task. Returns episode_id + initial observation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
try:
|
| 109 |
body = await request.json()
|
| 110 |
task_id = body.get('task_id', 'sec_easy')
|
|
@@ -123,9 +153,10 @@ async def reset(request: Request):
|
|
| 123 |
SESSIONS[session.episode_id] = session
|
| 124 |
|
| 125 |
# Cleanup old done sessions to prevent memory leaks on HF Spaces
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
|
| 130 |
obs = build_initial_obs(session)
|
| 131 |
|
|
@@ -138,7 +169,7 @@ async def reset(request: Request):
|
|
| 138 |
'error': str(e),
|
| 139 |
'observation': {},
|
| 140 |
'done': True,
|
| 141 |
-
'reward': 0.
|
| 142 |
})
|
| 143 |
|
| 144 |
|
|
@@ -152,7 +183,7 @@ async def step(request: Request):
|
|
| 152 |
|
| 153 |
if not session:
|
| 154 |
return JSONResponse(status_code=200, content={
|
| 155 |
-
'reward': 0.
|
| 156 |
'done': True,
|
| 157 |
'error': 'unknown episode_id',
|
| 158 |
'observation': {},
|
|
@@ -160,7 +191,7 @@ async def step(request: Request):
|
|
| 160 |
|
| 161 |
if session.done:
|
| 162 |
return JSONResponse(status_code=200, content={
|
| 163 |
-
'reward': 0.
|
| 164 |
'done': True,
|
| 165 |
'observation': {'message': 'Episode already complete.'},
|
| 166 |
})
|
|
@@ -168,9 +199,9 @@ async def step(request: Request):
|
|
| 168 |
# Run pre-action validation
|
| 169 |
valid, val_obs = validate_action(body, session)
|
| 170 |
if not valid:
|
| 171 |
-
last_r = 0.
|
| 172 |
if session.history:
|
| 173 |
-
last_r = session.history[-1].get('reward', 0.
|
| 174 |
return {
|
| 175 |
'reward': last_r,
|
| 176 |
'done': False,
|
|
@@ -189,17 +220,19 @@ async def step(request: Request):
|
|
| 189 |
|
| 190 |
# Enrich observation with strategic context
|
| 191 |
step_obs = result.get('observation', {})
|
| 192 |
-
step_obs['task_type'] = session.task_type
|
| 193 |
-
step_obs['task_id'] = session.task_id
|
| 194 |
-
step_obs['step_count'] = session.step_count
|
| 195 |
task_max = DOMAIN_MAX_STEPS.get(session.task_type, 8)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
# Turn guidance β tell agent what to do next
|
| 205 |
last_action = body.get('action_type', '')
|
|
@@ -223,14 +256,14 @@ async def step(request: Request):
|
|
| 223 |
SESSIONS.pop(session.episode_id, None)
|
| 224 |
|
| 225 |
return {
|
| 226 |
-
'reward': round(float(result.get('reward', 0.0)), 4),
|
| 227 |
'done': bool(result.get('done', False)),
|
| 228 |
'observation': step_obs,
|
| 229 |
'info': {'validation_failed': step_obs.get('validation_failed', False)},
|
| 230 |
}
|
| 231 |
except Exception as e:
|
| 232 |
return JSONResponse(status_code=200, content={
|
| 233 |
-
'reward': 0.
|
| 234 |
'done': True,
|
| 235 |
'error': str(e),
|
| 236 |
'observation': {},
|
|
@@ -326,7 +359,7 @@ async def run_inference(request: Request):
|
|
| 326 |
try:
|
| 327 |
final_scores[task_id] = float(total_reward)
|
| 328 |
except ValueError:
|
| 329 |
-
final_scores[task_id] = 0.
|
| 330 |
|
| 331 |
# Also try final JSON summary line
|
| 332 |
for line in reversed(stdout.splitlines()):
|
|
@@ -342,7 +375,7 @@ async def run_inference(request: Request):
|
|
| 342 |
|
| 343 |
avg = (
|
| 344 |
round(sum(final_scores.values()) / len(final_scores), 4)
|
| 345 |
-
if final_scores else 0.
|
| 346 |
)
|
| 347 |
|
| 348 |
return JSONResponse(status_code=200, content={
|
|
@@ -419,7 +452,7 @@ def _run_single_task_inline(task_id, api_base, api_key, model_id, system_prompt)
|
|
| 419 |
msg = f'[ERROR] OpenAI client init failed: {e}'
|
| 420 |
logs.append(msg)
|
| 421 |
yield {'type': 'log', 'level': 'err', 'msg': msg}
|
| 422 |
-
yield {'type': 'task_done', 'task_id': task_id, 'score': 0.
|
| 423 |
return
|
| 424 |
|
| 425 |
# Reset
|
|
@@ -430,7 +463,7 @@ def _run_single_task_inline(task_id, api_base, api_key, model_id, system_prompt)
|
|
| 430 |
msg = f'[ERROR] Reset failed: {e}'
|
| 431 |
logs.append(msg)
|
| 432 |
yield {'type': 'log', 'level': 'err', 'msg': msg}
|
| 433 |
-
yield {'type': 'task_done', 'task_id': task_id, 'score': 0.
|
| 434 |
return
|
| 435 |
|
| 436 |
ep_id = data.get('episode_id', 'unknown')
|
|
@@ -447,16 +480,21 @@ def _run_single_task_inline(task_id, api_base, api_key, model_id, system_prompt)
|
|
| 447 |
|
| 448 |
while not done and len(rewards) < max_steps:
|
| 449 |
step_num = len(rewards) + 1
|
| 450 |
-
# Build focused prompt with
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
| 454 |
user_parts = [f'Step {step_num} | Observation:']
|
| 455 |
if history:
|
| 456 |
user_parts.append(f'Previous actions: {[h["action_type"] for h in history]}')
|
| 457 |
-
if history[-1]['reward']
|
| 458 |
-
user_parts.append('
|
| 459 |
-
user_parts.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
user_parts.append('Output ONLY a single JSON object:')
|
| 461 |
messages.append({'role': 'user', 'content': '\n'.join(user_parts)})
|
| 462 |
|
|
@@ -519,8 +557,8 @@ def _run_single_task_inline(task_id, api_base, api_key, model_id, system_prompt)
|
|
| 519 |
logs.append(msg)
|
| 520 |
yield {'type': 'log', 'level': 'info', 'msg': msg}
|
| 521 |
|
| 522 |
-
#
|
| 523 |
-
total_reward = sum(rewards) if rewards else 0.01
|
| 524 |
score = round(min(max(total_reward, 0.01), 0.99), 4)
|
| 525 |
success = score > 0.0
|
| 526 |
rewards_str = ','.join(f'{r:.2f}' for r in rewards)
|
|
@@ -560,7 +598,7 @@ def run_benchmark(body: dict):
|
|
| 560 |
scores[task_id] = event['score']
|
| 561 |
yield f"data: {json.dumps(event)}\n\n"
|
| 562 |
|
| 563 |
-
avg = round(sum(scores.values()) / len(scores), 4) if scores else 0.
|
| 564 |
|
| 565 |
result = {
|
| 566 |
'model_name': model_name,
|
|
|
|
| 20 |
|
| 21 |
app = FastAPI(title='Multi-Agent Dev Tools Environment')
|
| 22 |
|
| 23 |
+
from collections import defaultdict
|
| 24 |
+
from time import time
|
| 25 |
+
|
| 26 |
+
# Global rate limiter (simple token bucket)
|
| 27 |
+
RATE_LIMITS = defaultdict(lambda: {'tokens': 10, 'last_refill': time()})
|
| 28 |
+
|
| 29 |
+
def check_rate_limit(ip: str) -> bool:
|
| 30 |
+
"""Returns True if request allowed, False if rate limited."""
|
| 31 |
+
bucket = RATE_LIMITS[ip]
|
| 32 |
+
now = time()
|
| 33 |
+
elapsed = now - bucket['last_refill']
|
| 34 |
+
refill = int(elapsed / 6)
|
| 35 |
+
if refill > 0:
|
| 36 |
+
bucket['tokens'] = min(10, bucket['tokens'] + refill)
|
| 37 |
+
bucket['last_refill'] = now
|
| 38 |
+
if bucket['tokens'] > 0:
|
| 39 |
+
bucket['tokens'] -= 1
|
| 40 |
+
return True
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
# ββ Load Debug Panel HTML ββ
|
| 44 |
_DEBUG_HTML_PATH = os.path.join(os.path.dirname(__file__), 'debug_panel.html')
|
| 45 |
|
|
|
|
| 125 |
@app.post('/reset')
|
| 126 |
async def reset(request: Request):
|
| 127 |
"""Create a new episode for a task. Returns episode_id + initial observation."""
|
| 128 |
+
|
| 129 |
+
# Get client IP
|
| 130 |
+
ip = request.client.host if request.client else '127.0.0.1'
|
| 131 |
+
if not check_rate_limit(ip):
|
| 132 |
+
return JSONResponse(status_code=200, content={
|
| 133 |
+
'error': 'Rate limit exceeded. Max 10 requests/minute.',
|
| 134 |
+
'done': True,
|
| 135 |
+
'observation': {},
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
try:
|
| 139 |
body = await request.json()
|
| 140 |
task_id = body.get('task_id', 'sec_easy')
|
|
|
|
| 153 |
SESSIONS[session.episode_id] = session
|
| 154 |
|
| 155 |
# Cleanup old done sessions to prevent memory leaks on HF Spaces
|
| 156 |
+
if len(SESSIONS) > 100 or random.random() < 0.1:
|
| 157 |
+
done_ids = [eid for eid, s in SESSIONS.items() if s.done]
|
| 158 |
+
for eid in done_ids:
|
| 159 |
+
SESSIONS.pop(eid, None)
|
| 160 |
|
| 161 |
obs = build_initial_obs(session)
|
| 162 |
|
|
|
|
| 169 |
'error': str(e),
|
| 170 |
'observation': {},
|
| 171 |
'done': True,
|
| 172 |
+
'reward': 0.01,
|
| 173 |
})
|
| 174 |
|
| 175 |
|
|
|
|
| 183 |
|
| 184 |
if not session:
|
| 185 |
return JSONResponse(status_code=200, content={
|
| 186 |
+
'reward': 0.01,
|
| 187 |
'done': True,
|
| 188 |
'error': 'unknown episode_id',
|
| 189 |
'observation': {},
|
|
|
|
| 191 |
|
| 192 |
if session.done:
|
| 193 |
return JSONResponse(status_code=200, content={
|
| 194 |
+
'reward': 0.01,
|
| 195 |
'done': True,
|
| 196 |
'observation': {'message': 'Episode already complete.'},
|
| 197 |
})
|
|
|
|
| 199 |
# Run pre-action validation
|
| 200 |
valid, val_obs = validate_action(body, session)
|
| 201 |
if not valid:
|
| 202 |
+
last_r = 0.01
|
| 203 |
if session.history:
|
| 204 |
+
last_r = max(0.01, session.history[-1].get('reward', 0.01))
|
| 205 |
return {
|
| 206 |
'reward': last_r,
|
| 207 |
'done': False,
|
|
|
|
| 220 |
|
| 221 |
# Enrich observation with strategic context
|
| 222 |
step_obs = result.get('observation', {})
|
|
|
|
|
|
|
|
|
|
| 223 |
task_max = DOMAIN_MAX_STEPS.get(session.task_type, 8)
|
| 224 |
+
enrichment = {
|
| 225 |
+
'task_type': session.task_type,
|
| 226 |
+
'task_id': session.task_id,
|
| 227 |
+
'step_count': session.step_count,
|
| 228 |
+
'max_steps': task_max,
|
| 229 |
+
'previous_reward': round(float(result.get('reward', 0.0)), 4),
|
| 230 |
+
'steps_remaining': max(0, task_max - session.step_count),
|
| 231 |
+
'reward_so_far': round(session.reward_acc, 4),
|
| 232 |
+
'trajectory_score': round(session.reward_acc / max(session.step_count, 1), 4),
|
| 233 |
+
}
|
| 234 |
+
for k, v in enrichment.items():
|
| 235 |
+
step_obs.setdefault(k, v)
|
| 236 |
|
| 237 |
# Turn guidance β tell agent what to do next
|
| 238 |
last_action = body.get('action_type', '')
|
|
|
|
| 256 |
SESSIONS.pop(session.episode_id, None)
|
| 257 |
|
| 258 |
return {
|
| 259 |
+
'reward': round(min(max(float(result.get('reward', 0.01)), 0.01), 0.99), 4),
|
| 260 |
'done': bool(result.get('done', False)),
|
| 261 |
'observation': step_obs,
|
| 262 |
'info': {'validation_failed': step_obs.get('validation_failed', False)},
|
| 263 |
}
|
| 264 |
except Exception as e:
|
| 265 |
return JSONResponse(status_code=200, content={
|
| 266 |
+
'reward': 0.01,
|
| 267 |
'done': True,
|
| 268 |
'error': str(e),
|
| 269 |
'observation': {},
|
|
|
|
| 359 |
try:
|
| 360 |
final_scores[task_id] = float(total_reward)
|
| 361 |
except ValueError:
|
| 362 |
+
final_scores[task_id] = 0.01
|
| 363 |
|
| 364 |
# Also try final JSON summary line
|
| 365 |
for line in reversed(stdout.splitlines()):
|
|
|
|
| 375 |
|
| 376 |
avg = (
|
| 377 |
round(sum(final_scores.values()) / len(final_scores), 4)
|
| 378 |
+
if final_scores else 0.01
|
| 379 |
)
|
| 380 |
|
| 381 |
return JSONResponse(status_code=200, content={
|
|
|
|
| 452 |
msg = f'[ERROR] OpenAI client init failed: {e}'
|
| 453 |
logs.append(msg)
|
| 454 |
yield {'type': 'log', 'level': 'err', 'msg': msg}
|
| 455 |
+
yield {'type': 'task_done', 'task_id': task_id, 'score': 0.01, 'logs': logs}
|
| 456 |
return
|
| 457 |
|
| 458 |
# Reset
|
|
|
|
| 463 |
msg = f'[ERROR] Reset failed: {e}'
|
| 464 |
logs.append(msg)
|
| 465 |
yield {'type': 'log', 'level': 'err', 'msg': msg}
|
| 466 |
+
yield {'type': 'task_done', 'task_id': task_id, 'score': 0.01, 'logs': logs}
|
| 467 |
return
|
| 468 |
|
| 469 |
ep_id = data.get('episode_id', 'unknown')
|
|
|
|
| 480 |
|
| 481 |
while not done and len(rewards) < max_steps:
|
| 482 |
step_num = len(rewards) + 1
|
| 483 |
+
# Build focused prompt with smart truncation (matches inference.py)
|
| 484 |
+
obs_copy = dict(obs)
|
| 485 |
+
compat_matrix = obs_copy.pop('compatibility_matrix', None)
|
| 486 |
+
dep_graph = obs_copy.pop('dependency_graph', None)
|
| 487 |
+
core_text = json.dumps(obs_copy, default=str, indent=2)
|
| 488 |
user_parts = [f'Step {step_num} | Observation:']
|
| 489 |
if history:
|
| 490 |
user_parts.append(f'Previous actions: {[h["action_type"] for h in history]}')
|
| 491 |
+
if history[-1]['reward'] < 0.4:
|
| 492 |
+
user_parts.append('β οΈ Low score. Try different approach.')
|
| 493 |
+
user_parts.append(core_text)
|
| 494 |
+
if compat_matrix:
|
| 495 |
+
user_parts.append(f'\nCompatibility Matrix:\n{json.dumps(compat_matrix, indent=2)}')
|
| 496 |
+
if dep_graph:
|
| 497 |
+
user_parts.append(f'\nDependency Graph:\n{json.dumps(dep_graph, indent=2)}')
|
| 498 |
user_parts.append('Output ONLY a single JSON object:')
|
| 499 |
messages.append({'role': 'user', 'content': '\n'.join(user_parts)})
|
| 500 |
|
| 557 |
logs.append(msg)
|
| 558 |
yield {'type': 'log', 'level': 'info', 'msg': msg}
|
| 559 |
|
| 560 |
+
# Average rewards – same logic as inference.py
|
| 561 |
+
total_reward = sum(rewards) / max(len(rewards), 1) if rewards else 0.01
|
| 562 |
score = round(min(max(total_reward, 0.01), 0.99), 4)
|
| 563 |
success = score > 0.0
|
| 564 |
rewards_str = ','.join(f'{r:.2f}' for r in rewards)
|
| 598 |
scores[task_id] = event['score']
|
| 599 |
yield f"data: {json.dumps(event)}\n\n"
|
| 600 |
|
| 601 |
+
avg = round(sum(scores.values()) / len(scores), 4) if scores else 0.01
|
| 602 |
|
| 603 |
result = {
|
| 604 |
'model_name': model_name,
|
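Editor's note: the smart-truncation block above pops the two bulky fields out of the observation and re-appends them after the core JSON, so a length cap trims them first. A standalone sketch of that ordering (the `build_prompt` name and the 4000-character cap are illustrative, not from the repository):

```python
import json

def build_prompt(obs: dict, step_num: int, max_chars: int = 4000) -> str:
    # Pull the bulky structures out so the core observation always survives
    obs = dict(obs)
    bulky = {k: obs.pop(k) for k in ('compatibility_matrix', 'dependency_graph') if k in obs}
    parts = [f'Step {step_num} | Observation:', json.dumps(obs, default=str, indent=2)]
    for name, value in bulky.items():
        # Big structures go last, so a tail cut drops them first
        parts.append(f'\n{name}:\n{json.dumps(value, indent=2)}')
    parts.append('Output ONLY a single JSON object:')
    return '\n'.join(parts)[:max_chars]
```

With this ordering, a hard character cap degrades the prompt gracefully instead of producing API errors on oversized requests.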
server/datasets/clinical_cases.py
CHANGED
|
@@ -176,5 +176,70 @@ CLINICAL_CASES = {
|
| 176 |
'available_steps': ['stabilize_vitals', 'cardiology_consult', 'imaging_ordered', 'medication_review', 'family_notification'],
|
| 177 |
'task_description': 'Complex cardiac emergency recovery plan. Multiple dependency chains. Medication review needs both cardiology consult AND imaging. Respect ALL prerequisites.',
|
| 178 |
},
|
| 179 |
+
{
|
| 180 |
+
'case_id': 'cli_hard_003',
|
| 181 |
+
'completion_threshold': 0.70,
|
| 182 |
+
'max_steps': 6,
|
| 183 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['detect_gap', 'rank_issues', 'order_steps']},
|
| 184 |
+
'patient_id': 'P303',
|
| 185 |
+
'patient_events': ['chemo_ordered', 'lab_results_missing', 'dose_unclear', 'pharmacy_backlog'],
|
| 186 |
+
'events': ['chemo_ordered', 'lab_results_missing', 'dose_unclear', 'pharmacy_backlog'],
|
| 187 |
+
'expected_missing_steps': ['baseline_cbc', 'oncology_dose_verify', 'pharmacy_prep', 'nurse_admin_check'],
|
| 188 |
+
'expected_risk': 'critical',
|
| 189 |
+
'priority_order': ['baseline_cbc', 'oncology_dose_verify', 'pharmacy_prep', 'nurse_admin_check'],
|
| 190 |
+
'dependency_graph': {
|
| 191 |
+
'nurse_admin_check': ['pharmacy_prep'],
|
| 192 |
+
'pharmacy_prep': ['oncology_dose_verify', 'baseline_cbc'],
|
| 193 |
+
'oncology_dose_verify': ['baseline_cbc'],
|
| 194 |
+
'baseline_cbc': [],
|
| 195 |
+
},
|
| 196 |
+
'required_steps': ['baseline_cbc', 'oncology_dose_verify', 'pharmacy_prep', 'nurse_admin_check'],
|
| 197 |
+
'available_steps': ['baseline_cbc', 'oncology_dose_verify', 'pharmacy_prep', 'nurse_admin_check'],
|
| 198 |
+
'task_description': 'Chemotherapy workflow chaos. Multiple safety steps skipped. Labs must come before dose verification. Pharmacy needs both labs AND dose verification before prep. Plan safe recovery sequence.',
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
'case_id': 'cli_hard_004',
|
| 202 |
+
'completion_threshold': 0.70,
|
| 203 |
+
'max_steps': 6,
|
| 204 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['detect_gap', 'rank_issues', 'order_steps']},
|
| 205 |
+
'patient_id': 'P304',
|
| 206 |
+
'patient_events': ['transplant_scheduled', 'donor_typing_incomplete', 'immunosuppress_missing', 'consent_partial'],
|
| 207 |
+
'events': ['transplant_scheduled', 'donor_typing_incomplete', 'immunosuppress_missing', 'consent_partial'],
|
| 208 |
+
'expected_missing_steps': ['hla_typing', 'crossmatch', 'immunosuppress_order', 'full_consent', 'surgery_slot'],
|
| 209 |
+
'expected_risk': 'critical',
|
| 210 |
+
'priority_order': ['hla_typing', 'crossmatch', 'full_consent', 'immunosuppress_order', 'surgery_slot'],
|
| 211 |
+
'dependency_graph': {
|
| 212 |
+
'surgery_slot': ['hla_typing', 'crossmatch', 'full_consent', 'immunosuppress_order'],
|
| 213 |
+
'immunosuppress_order': ['crossmatch'],
|
| 214 |
+
'crossmatch': ['hla_typing'],
|
| 215 |
+
'full_consent': [],
|
| 216 |
+
'hla_typing': [],
|
| 217 |
+
},
|
| 218 |
+
'required_steps': ['hla_typing', 'crossmatch', 'immunosuppress_order', 'full_consent', 'surgery_slot'],
|
| 219 |
+
'available_steps': ['hla_typing', 'crossmatch', 'immunosuppress_order', 'full_consent', 'surgery_slot'],
|
| 220 |
+
'task_description': 'Organ transplant pre-op disaster. Complex dependency chain: HLA typing → crossmatch → immunosuppression. Surgery booking requires ALL steps. One wrong order could delay transplant.',
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
'case_id': 'cli_hard_005',
|
| 224 |
+
'completion_threshold': 0.70,
|
| 225 |
+
'max_steps': 6,
|
| 226 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['detect_gap', 'rank_issues', 'order_steps']},
|
| 227 |
+
'patient_id': 'P305',
|
| 228 |
+
'patient_events': ['stroke_code', 'imaging_delayed', 'tpa_window_closing', 'neuro_unavailable'],
|
| 229 |
+
'events': ['stroke_code', 'imaging_delayed', 'tpa_window_closing', 'neuro_unavailable'],
|
| 230 |
+
'expected_missing_steps': ['ct_head', 'neuro_consult', 'tpa_eligibility', 'family_consent', 'icu_bed'],
|
| 231 |
+
'expected_risk': 'critical',
|
| 232 |
+
'priority_order': ['ct_head', 'tpa_eligibility', 'neuro_consult', 'family_consent', 'icu_bed'],
|
| 233 |
+
'dependency_graph': {
|
| 234 |
+
'icu_bed': ['tpa_eligibility'],
|
| 235 |
+
'family_consent': ['tpa_eligibility', 'neuro_consult'],
|
| 236 |
+
'neuro_consult': ['ct_head'],
|
| 237 |
+
'tpa_eligibility': ['ct_head'],
|
| 238 |
+
'ct_head': [],
|
| 239 |
+
},
|
| 240 |
+
'required_steps': ['ct_head', 'neuro_consult', 'tpa_eligibility', 'family_consent', 'icu_bed'],
|
| 241 |
+
'available_steps': ['ct_head', 'neuro_consult', 'tpa_eligibility', 'family_consent', 'icu_bed'],
|
| 242 |
+
'task_description': 'Acute stroke code with tPA window closing. CT must come first. Eligibility and neuro consult both depend on CT. Family consent needs both eligibility AND neuro. ICU booking after eligibility confirmed. Time-critical recovery plan needed.',
|
| 243 |
+
},
|
| 244 |
],
|
| 245 |
}
|
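Editor's note: each hard case pairs a `dependency_graph` (step → its prerequisites) with a `priority_order` that must respect it. A small checker, fed the stroke case's graph, makes the invariant concrete (`respects_dependencies` is an illustrative helper, not repository code):

```python
def respects_dependencies(order, deps):
    # Every step must appear after all of its prerequisites
    pos = {step: i for i, step in enumerate(order)}
    return all(pos[pre] < pos[step] for step, pres in deps.items() for pre in pres)

# dependency_graph from cli_hard_005
deps = {
    'icu_bed': ['tpa_eligibility'],
    'family_consent': ['tpa_eligibility', 'neuro_consult'],
    'neuro_consult': ['ct_head'],
    'tpa_eligibility': ['ct_head'],
    'ct_head': [],
}
```

The shipped `priority_order` for that case passes this check; any order placing `icu_bed` before `tpa_eligibility` fails it.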
server/datasets/dependency_cases.py
CHANGED
|
@@ -276,5 +276,153 @@ def training_step(model, x, labels):
|
| 276 |
],
|
| 277 |
'task_description': 'Fix all 4 graph-break patterns in this compiled training step. Dependencies must be resolved in order.',
|
| 278 |
},
|
| 279 |
+
{
|
| 280 |
+
'case_id': 'dep_hard_003',
|
| 281 |
+
'task_subtype': 'migrate',
|
| 282 |
+
'completion_threshold': 0.70,
|
| 283 |
+
'max_steps': 8,
|
| 284 |
+
'done_conditions': {'min_actions': 2, 'required_sequence': ['migrate_api']},
|
| 285 |
+
'graph_breaks': ['break_x', 'break_y', 'break_z'],
|
| 286 |
+
'checklist_dependency_graph': {
|
| 287 |
+
'break_z': ['break_x'], # z depends on x
|
| 288 |
+
'break_y': [], # y is independent
|
| 289 |
+
'break_x': [], # x is independent
|
| 290 |
+
},
|
| 291 |
+
'correct_fix_map': {
|
| 292 |
+
'break_x': 'tensor.numel()',
|
| 293 |
+
'break_y': 'torch.jit.script',
|
| 294 |
+
'break_z': 'torch.no_grad()',
|
| 295 |
+
},
|
| 296 |
+
'code_snippet': '''import torch
|
| 297 |
+
|
| 298 |
+
@torch.compile
|
| 299 |
+
def forward(x, mask):
|
| 300 |
+
# break_x: tensor.size() returns Python int (graph break)
|
| 301 |
+
n = x.size(0) * x.size(1)
|
| 302 |
+
|
| 303 |
+
# break_y: Python function call inside compile
|
| 304 |
+
def custom_fn(t):
|
| 305 |
+
return t * 2
|
| 306 |
+
x = custom_fn(x)
|
| 307 |
+
|
| 308 |
+
# break_z: gradient tracking inside compiled region
|
| 309 |
+
with torch.enable_grad(): # breaks graph
|
| 310 |
+
x = x * mask
|
| 311 |
+
|
| 312 |
+
return x''',
|
| 313 |
+
'break_descriptions': [
|
| 314 |
+
'break_x: line 6 → tensor.size() returns Python int, use tensor.numel() instead',
|
| 315 |
+
'break_y: line 10 → Python function call, use torch.jit.script decorator',
|
| 316 |
+
'break_z: line 14 → enable_grad inside compile, use torch.no_grad() for inference',
|
| 317 |
+
],
|
| 318 |
+
'graph_break_report': [
|
| 319 |
+
'break_x: line 6 → tensor.size() returns Python int, use tensor.numel() instead',
|
| 320 |
+
'break_y: line 10 → Python function call, use torch.jit.script decorator',
|
| 321 |
+
'break_z: line 14 → enable_grad inside compile, use torch.no_grad() for inference',
|
| 322 |
+
],
|
| 323 |
+
'task_description': 'Fix torch.compile graph breaks in this custom layer. Note dependency: break_z needs break_x fixed first.',
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
'case_id': 'dep_hard_004',
|
| 327 |
+
'task_subtype': 'migrate',
|
| 328 |
+
'completion_threshold': 0.70,
|
| 329 |
+
'max_steps': 8,
|
| 330 |
+
'done_conditions': {'min_actions': 2, 'required_sequence': ['migrate_api']},
|
| 331 |
+
'graph_breaks': ['break_alpha', 'break_beta', 'break_gamma', 'break_delta'],
|
| 332 |
+
'checklist_dependency_graph': {
|
| 333 |
+
'break_delta': ['break_beta', 'break_gamma'], # delta needs both
|
| 334 |
+
'break_gamma': ['break_alpha'], # gamma needs alpha
|
| 335 |
+
'break_beta': [],
|
| 336 |
+
'break_alpha': [],
|
| 337 |
+
},
|
| 338 |
+
'correct_fix_map': {
|
| 339 |
+
'break_alpha': 'torch.where',
|
| 340 |
+
'break_beta': 'tensor.shape[0]',
|
| 341 |
+
'break_gamma': 'torch.stack',
|
| 342 |
+
'break_delta': '@torch.jit.script',
|
| 343 |
+
},
|
| 344 |
+
'code_snippet': '''import torch
|
| 345 |
+
|
| 346 |
+
@torch.compile(fullgraph=True)
|
| 347 |
+
def loss_fn(pred, target, weights):
|
| 348 |
+
# break_alpha: if statement on tensor value
|
| 349 |
+
if target.sum() > 0:
|
| 350 |
+
pred = pred * 1.5
|
| 351 |
+
|
| 352 |
+
# break_beta: len() on tensor
|
| 353 |
+
batch_size = len(pred)
|
| 354 |
+
|
| 355 |
+
# break_gamma: Python list → tensor conversion
|
| 356 |
+
normalized = []
|
| 357 |
+
for i in range(batch_size):
|
| 358 |
+
normalized.append(pred[i] / weights[i])
|
| 359 |
+
result = torch.tensor(normalized) # breaks graph
|
| 360 |
+
|
| 361 |
+
# break_delta: calls non-scripted helper
|
| 362 |
+
def helper(x):
|
| 363 |
+
return x.clamp(0, 1)
|
| 364 |
+
return helper(result)''',
|
| 365 |
+
'break_descriptions': [
|
| 366 |
+
'break_alpha: line 6 → data-dependent control flow, use torch.where(condition, ...)',
|
| 367 |
+
'break_beta: line 10 → len() builtin on tensor, use tensor.shape[0]',
|
| 368 |
+
'break_gamma: line 16 → torch.tensor() on Python list, use torch.stack()',
|
| 369 |
+
'break_delta: line 20 → unscripted helper function, add @torch.jit.script decorator',
|
| 370 |
+
],
|
| 371 |
+
'graph_break_report': [
|
| 372 |
+
'break_alpha: line 6 → data-dependent control flow, use torch.where(condition, ...)',
|
| 373 |
+
'break_beta: line 10 → len() builtin on tensor, use tensor.shape[0]',
|
| 374 |
+
'break_gamma: line 16 → torch.tensor() on Python list, use torch.stack()',
|
| 375 |
+
'break_delta: line 20 → unscripted helper function, add @torch.jit.script decorator',
|
| 376 |
+
],
|
| 377 |
+
'task_description': 'Complex graph-break cascade. Delta depends on Beta AND Gamma. Gamma depends on Alpha. Fix in dependency order.',
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
'case_id': 'dep_hard_005',
|
| 381 |
+
'task_subtype': 'migrate',
|
| 382 |
+
'completion_threshold': 0.70,
|
| 383 |
+
'max_steps': 8,
|
| 384 |
+
'done_conditions': {'min_actions': 2, 'required_sequence': ['migrate_api']},
|
| 385 |
+
'graph_breaks': ['break_001', 'break_002', 'break_003'],
|
| 386 |
+
'checklist_dependency_graph': {
|
| 387 |
+
'break_003': ['break_001', 'break_002'],
|
| 388 |
+
'break_002': [],
|
| 389 |
+
'break_001': [],
|
| 390 |
+
},
|
| 391 |
+
'correct_fix_map': {
|
| 392 |
+
'break_001': 'torch.compile(disable=True)',
|
| 393 |
+
'break_002': 'functorch.vmap',
|
| 394 |
+
'break_003': 'torch.export',
|
| 395 |
+
},
|
| 396 |
+
'code_snippet': '''import torch
|
| 397 |
+
from torch.nn.utils import clip_grad_norm_
|
| 398 |
+
|
| 399 |
+
@torch.compile
|
| 400 |
+
def training_step(model, batch, optimizer):
|
| 401 |
+
# break_001: optimizer.step() inside compiled region
|
| 402 |
+
loss = model(batch['x'], batch['y'])
|
| 403 |
+
loss.backward()
|
| 404 |
+
optimizer.step() # graph break
|
| 405 |
+
|
| 406 |
+
# break_002: Python loop over batch dimension
|
| 407 |
+
grads = []
|
| 408 |
+
for param in model.parameters():
|
| 409 |
+
grads.append(param.grad.norm())
|
| 410 |
+
|
| 411 |
+
# break_003: clip_grad_norm_ mutation
|
| 412 |
+
clip_grad_norm_(model.parameters(), max_norm=1.0) # breaks graph
|
| 413 |
+
|
| 414 |
+
return loss.item()''',
|
| 415 |
+
'break_descriptions': [
|
| 416 |
+
'break_001: line 9 → optimizer.step() not compilable, wrap optimizer logic outside compile',
|
| 417 |
+
'break_002: line 13 → Python loop batching, use functorch.vmap for vectorization',
|
| 418 |
+
'break_003: line 17 → in-place grad clipping, use torch.export with explicit mutation tracking',
|
| 419 |
+
],
|
| 420 |
+
'graph_break_report': [
|
| 421 |
+
'break_001: line 9 → optimizer.step() not compilable, wrap optimizer logic outside compile',
|
| 422 |
+
'break_002: line 13 → Python loop batching, use functorch.vmap for vectorization',
|
| 423 |
+
'break_003: line 17 → in-place grad clipping, use torch.export with explicit mutation tracking',
|
| 424 |
+
],
|
| 425 |
+
'task_description': 'Fix training loop graph breaks. Optimizer, gradient accumulation, and clipping all cause compilation failures.',
|
| 426 |
+
},
|
| 427 |
],
|
| 428 |
}
|
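Editor's note: a legal fix order for a `checklist_dependency_graph` is any topological order of it. A sketch using Kahn's algorithm on the dep_hard_004 graph (`fix_order` is an illustrative helper, not repository code):

```python
from collections import deque

def fix_order(deps):
    # Kahn's algorithm: fix a break only after everything it depends on
    remaining = {b: set(d) for b, d in deps.items()}
    ready = deque(sorted(b for b, d in remaining.items() if not d))
    order = []
    while ready:
        b = ready.popleft()
        order.append(b)
        for other, pres in remaining.items():
            if b in pres:
                pres.discard(b)
                if not pres and other not in order and other not in ready:
                    ready.append(other)
    return order

# checklist_dependency_graph from dep_hard_004
deps = {
    'break_delta': ['break_beta', 'break_gamma'],
    'break_gamma': ['break_alpha'],
    'break_beta': [],
    'break_alpha': [],
}
```

For this graph the only constraint is alpha before gamma and both beta and gamma before delta, which is exactly what the grader's "fix in dependency order" instruction demands.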
server/datasets/security_cases.py
CHANGED
|
@@ -207,5 +207,81 @@ def get_document(doc_id: int, request: Request):
|
| 207 |
return get_user(user_id)''',
|
| 208 |
'task_description': 'Rewrite this JWT authentication to prevent algorithm confusion attacks while preserving user lookup logic.',
|
| 209 |
},
|
| 210 |
+
{
|
| 211 |
+
'case_id': 'sec_hard_003',
|
| 212 |
+
'task_subtype': 'adversarial',
|
| 213 |
+
'completion_threshold': 0.70,
|
| 214 |
+
'max_steps': 8,
|
| 215 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['identify_vulnerability', 'propose_fix', 'revise_fix']},
|
| 216 |
+
'expected_vuln_type': 'xxe',
|
| 217 |
+
'cvss_range': [7.0, 9.5],
|
| 218 |
+
'expected_severity': 'high',
|
| 219 |
+
'required_fix_tokens': ['disable', 'external', 'entities'],
|
| 220 |
+
'must_preserve_identifier': 'parse_xml',
|
| 221 |
+
'current_feedback_keywords': ['xxe', 'entity', 'external'],
|
| 222 |
+
'original_vuln_pattern': 'allow_external_entities=True',
|
| 223 |
+
'reviewer_feedback': 'Fix disables DTD but doesn\'t disable external entities. Set both to False.',
|
| 224 |
+
'reviewer_feedback_sequence': [
|
| 225 |
+
'Fix disables DTD but doesn\'t disable external entities. Set both to False.',
|
| 226 |
+
'Better, but you should also validate XML schema against a whitelist.',
|
| 227 |
+
],
|
| 228 |
+
'tool_call': '''def parse_xml(xml_string):
|
| 229 |
+
import xml.etree.ElementTree as ET
|
| 230 |
+
parser = ET.XMLParser() # Default parser allows XXE
|
| 231 |
+
tree = ET.fromstring(xml_string, parser)
|
| 232 |
+
return tree.findall('.//user/name')''',
|
| 233 |
+
'task_description': 'Fix the XXE vulnerability in this XML parser while preserving all business logic. Reviewer will test your fix against attack payloads.',
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
'case_id': 'sec_hard_004',
|
| 237 |
+
'task_subtype': 'adversarial',
|
| 238 |
+
'completion_threshold': 0.70,
|
| 239 |
+
'max_steps': 8,
|
| 240 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['identify_vulnerability', 'propose_fix', 'revise_fix']},
|
| 241 |
+
'expected_vuln_type': 'ssrf',
|
| 242 |
+
'cvss_range': [7.5, 9.0],
|
| 243 |
+
'expected_severity': 'critical',
|
| 244 |
+
'required_fix_tokens': ['allowlist', 'private', 'ip'],
|
| 245 |
+
'must_preserve_identifier': 'fetch_webhook',
|
| 246 |
+
'current_feedback_keywords': ['ssrf', 'internal', 'network'],
|
| 247 |
+
'original_vuln_pattern': 'requests.get(url)',
|
| 248 |
+
'reviewer_feedback': 'Fix blocks 127.0.0.1 but allows localhost and private IPs (10.x, 172.x, 192.168.x).',
|
| 249 |
+
'reviewer_feedback_sequence': [
|
| 250 |
+
'Fix blocks 127.0.0.1 but allows localhost and private IPs (10.x, 172.x, 192.168.x).',
|
| 251 |
+
'Good, but you should also block IPv6 loopback (::1) and link-local addresses.',
|
| 252 |
+
],
|
| 253 |
+
'tool_call': '''def fetch_webhook(url):
|
| 254 |
+
import requests
|
| 255 |
+
# Webhook URL from user input
|
| 256 |
+
response = requests.get(url, timeout=10)
|
| 257 |
+
return response.json()''',
|
| 258 |
+
'task_description': 'Prevent SSRF attacks on this webhook fetcher. Reviewer will try internal network access bypasses.',
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
'case_id': 'sec_hard_005',
|
| 262 |
+
'task_subtype': 'adversarial',
|
| 263 |
+
'completion_threshold': 0.70,
|
| 264 |
+
'max_steps': 8,
|
| 265 |
+
'done_conditions': {'min_actions': 3, 'required_sequence': ['identify_vulnerability', 'propose_fix', 'revise_fix']},
|
| 266 |
+
'expected_vuln_type': 'idor',
|
| 267 |
+
'cvss_range': [6.0, 8.5],
|
| 268 |
+
'expected_severity': 'high',
|
| 269 |
+
'required_fix_tokens': ['owner', 'session', 'user_id'],
|
| 270 |
+
'must_preserve_identifier': 'update_profile',
|
| 271 |
+
'current_feedback_keywords': ['idor', 'authorization', 'owner'],
|
| 272 |
+
'original_vuln_pattern': 'profile_id from request',
|
| 273 |
+
'reviewer_feedback': 'Fix checks profile ownership but uses user_id from request body (attacker-controlled).',
|
| 274 |
+
'reviewer_feedback_sequence': [
|
| 275 |
+
'Fix checks profile ownership but uses user_id from request body (attacker-controlled).',
|
| 276 |
+
'Better, but session validation is weak. Use cryptographic session tokens, not just user_id in cookie.',
|
| 277 |
+
],
|
| 278 |
+
'tool_call': '''@app.post("/profile/update")
|
| 279 |
+
def update_profile(profile_id: int, user_id: int, data: dict):
|
| 280 |
+
# user_id comes from request body (!)
|
| 281 |
+
profile = db.profiles.find_one({"_id": profile_id})
|
| 282 |
+
profile.update(data)
|
| 283 |
+
return {"status": "updated"}''',
|
| 284 |
+
'task_description': 'Fix IDOR vulnerability allowing users to edit others\' profiles. Reviewer will test horizontal privilege escalation.',
|
| 285 |
+
},
|
| 286 |
],
|
| 287 |
}
|
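Editor's note: every scoring path in this commit clamps raw rewards into the strict (0.01, 0.99) band. A minimal standalone version (the repository's `safe_score` helper is assumed from the calls in the grader diff; this is a sketch, not its actual body):

```python
def safe_score(value) -> float:
    # Coerce to float, then clamp into the OpenEnv-compliant band and round
    try:
        v = float(value)
    except (TypeError, ValueError):
        return 0.01
    return round(min(max(v, 0.01), 0.99), 4)
```

Clamping at both ends guarantees no grader can emit an exact 0.0 or 1.0, which is the score-compliance rule the commit message refers to.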
server/graders/security_grader.py
CHANGED
|
@@ -39,30 +39,49 @@ def _score_identify(action: Dict, case: Dict) -> float:
|
|
| 39 |
|
| 40 |
|
| 41 |
def _score_propose(action: Dict, case: Dict) -> float:
|
| 42 |
-
"""Score proposed fix. Checks token coverage
|
| 43 |
tokens = case.get('required_fix_tokens', [])
|
| 44 |
if isinstance(tokens, dict):
|
| 45 |
tokens = tokens.get(case.get('expected_vuln_type', ''), [])
|
| 46 |
-
|
| 47 |
-
|
|
|
| 48 |
|
| 49 |
fix = action.get('fix_code', '')
|
| 50 |
if not fix:
|
| 51 |
return 0.0
|
| 52 |
|
| 53 |
-
# Token coverage
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
else:
|
| 57 |
-
divisor = max(1, len(tokens) - 1)
|
| 58 |
-
coverage = min(1.0, sum(1 for t in tokens if t.lower() in fix.lower()) / divisor)
|
| 59 |
|
| 60 |
-
# Identifier preservation
|
| 61 |
key_id = case.get('must_preserve_identifier', '')
|
| 62 |
-
preservation = 0.
|
| 63 |
-
|
| 64 |
-
#
|
| 65 |
-
|
|
|
| 66 |
|
| 67 |
|
| 68 |
def _score_revise(action: Dict, case: Dict) -> float:
|
| 39 |
|
| 40 |
|
| 41 |
def _score_propose(action: Dict, case: Dict) -> float:
|
| 42 |
+
"""Score proposed fix. Checks token coverage, identifier preservation, and explanation."""
|
| 43 |
tokens = case.get('required_fix_tokens', [])
|
| 44 |
if isinstance(tokens, dict):
|
| 45 |
tokens = tokens.get(case.get('expected_vuln_type', ''), [])
|
| 46 |
+
|
| 47 |
+
# Flatten nested lists and ensure all strings
|
| 48 |
+
def flatten(lst):
|
| 49 |
+
result = []
|
| 50 |
+
for item in lst:
|
| 51 |
+
if isinstance(item, list):
|
| 52 |
+
result.extend(flatten(item))
|
| 53 |
+
elif isinstance(item, str):
|
| 54 |
+
result.append(item)
|
| 55 |
+
return result
|
| 56 |
+
|
| 57 |
+
tokens = flatten(tokens) if isinstance(tokens, list) else []
|
| 58 |
|
| 59 |
fix = action.get('fix_code', '')
|
| 60 |
if not fix:
|
| 61 |
return 0.0
|
| 62 |
|
| 63 |
+
# Token coverage (60%)
|
| 64 |
+
divisor = max(1, len(tokens) - 1)
|
| 65 |
+
coverage = min(1.0, sum(1 for t in tokens if t.lower() in fix.lower()) / divisor) if tokens else 0.5
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
# Identifier preservation (10%)
|
| 68 |
key_id = case.get('must_preserve_identifier', '')
|
| 69 |
+
preservation = 0.10 if key_id and key_id in fix else 0.0
|
| 70 |
+
|
| 71 |
+
# NEW: Explanation quality (30%)
|
| 72 |
+
explanation = action.get('explanation', '')
|
| 73 |
+
exp_score = 0.0
|
| 74 |
+
if explanation:
|
| 75 |
+
keywords = ['prevent', 'secure', 'validate', 'sanitize', 'parameterize']
|
| 76 |
+
exp_score = sum(0.06 for kw in keywords if kw in explanation.lower())
|
| 77 |
+
if len(explanation) < 20:
|
| 78 |
+
exp_score -= 0.05
|
| 79 |
+
vuln_type = case.get('expected_vuln_type', '').replace('_', ' ')
|
| 80 |
+
if vuln_type in explanation.lower():
|
| 81 |
+
exp_score += 0.10
|
| 82 |
+
|
| 83 |
+
# Combine: 60% code, 30% explanation, 10% identifier
|
| 84 |
+
return max(0.25, safe_score(coverage * 0.60 + exp_score * 0.30 + preservation * 0.10))
|
| 85 |
|
| 86 |
|
| 87 |
def _score_revise(action: Dict, case: Dict) -> float:
|
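Editor's note: the reworked `_score_propose` combines three signals with 60/30/10 weights and a 0.25 floor. Written out with an explicit clamp standing in for `safe_score` (an assumption; names here are illustrative):

```python
def combine_propose(coverage: float, exp_score: float, preservation: float) -> float:
    # 60% token coverage + 30% explanation quality + 10% identifier bonus,
    # floored at 0.25 so any non-empty fix earns partial credit
    raw = coverage * 0.60 + exp_score * 0.30 + preservation * 0.10
    clamped = round(min(max(raw, 0.01), 0.99), 4)
    return max(0.25, clamped)
```

Note that in the diff `preservation` is already the 0.10 bonus amount before the final 0.10 weight is applied, so identifier preservation contributes at most 0.01 to the blended score.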
server/router.py
CHANGED
|
@@ -19,7 +19,7 @@ def route_step(session: SessionState, action: Dict) -> Dict:
|
|
| 19 |
grader = GRADERS.get(session.task_type)
|
| 20 |
if not grader:
|
| 21 |
return {
|
| 22 |
-
'reward': 0.
|
| 23 |
'done': True,
|
| 24 |
'observation': {'error': f'Unknown task_type: {session.task_type}'},
|
| 25 |
}
|
|
@@ -37,6 +37,7 @@ def route_step(session: SessionState, action: Dict) -> Dict:
|
|
| 37 |
|
| 38 |
# Score breakdown for debugging and UI
|
| 39 |
score_details = _compute_score_details(action, session)
|
| 40 |
|
| 41 |
return {
|
| 42 |
'episode_id': session.episode_id,
|
|
@@ -59,6 +60,12 @@ def _check_done(session: SessionState, action: Dict, reward: float, max_steps: i
|
|
| 59 |
next_step = session.step_count + 1
|
| 60 |
case = session.task_case
|
| 61 |
|
|
|
| 62 |
# Always done if max steps reached
|
| 63 |
if next_step >= max_steps:
|
| 64 |
return True
|
| 19 |
grader = GRADERS.get(session.task_type)
|
| 20 |
if not grader:
|
| 21 |
return {
|
| 22 |
+
'reward': 0.01,
|
| 23 |
'done': True,
|
| 24 |
'observation': {'error': f'Unknown task_type: {session.task_type}'},
|
| 25 |
}
|
|
|
|
| 37 |
|
| 38 |
# Score breakdown for debugging and UI
|
| 39 |
score_details = _compute_score_details(action, session)
|
| 40 |
+
obs['score_breakdown'] = score_details
|
| 41 |
|
| 42 |
return {
|
| 43 |
'episode_id': session.episode_id,
|
|
|
|
| 60 |
next_step = session.step_count + 1
|
| 61 |
case = session.task_case
|
| 62 |
|
| 63 |
+
# Mastery condition: high performance -> early exit
|
| 64 |
+
if next_step >= 2:
|
| 65 |
+
avg_reward = (session.reward_acc + reward) / next_step
|
| 66 |
+
if avg_reward >= 0.90:
|
| 67 |
+
return True
|
| 68 |
+
|
| 69 |
# Always done if max steps reached
|
| 70 |
if next_step >= max_steps:
|
| 71 |
return True
|
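Editor's note: the new mastery condition in `_check_done` ends an episode after two or more steps once the running average reward reaches 0.90. Isolated, with illustrative names (`reward_acc` is the sum of rewards before the current step, matching the session field used above):

```python
def should_stop_early(reward_acc: float, reward: float, step_count: int) -> bool:
    # Mastery exit: never on the first step; afterwards, stop once the
    # running average (including this step's reward) reaches 0.90
    next_step = step_count + 1
    if next_step < 2:
        return False
    return (reward_acc + reward) / next_step >= 0.90
```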
server/validation/validator.py
CHANGED
|
@@ -7,6 +7,8 @@
|
|
| 7 |
# - Rich hints so agent can self-correct on next step
|
| 8 |
|
| 9 |
from typing import Dict, Tuple
|
| 10 |
|
| 11 |
VALID_VULN_TYPES = {
|
| 12 |
'sql_injection', 'xss', 'idor', 'hardcoded_secret', 'missing_auth',
|
|
@@ -173,8 +175,10 @@ def validate_action(action: Dict, session) -> Tuple[bool, Dict]:
|
|
| 173 |
return True, {}
|
| 174 |
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
errors = []
|
| 179 |
|
| 180 |
if atype == 'identify_vulnerability':
|
|
@@ -191,12 +195,6 @@ def _domain_check(action: Dict, atype: str) -> list:
|
|
| 191 |
if sev not in VALID_SEVERITIES:
|
| 192 |
errors.append({'field': 'severity', 'value': sev, 'allowed': sorted(VALID_SEVERITIES)})
|
| 193 |
|
| 194 |
-
elif atype in ('propose_fix', 'revise_fix'):
|
| 195 |
-
fix = action.get('fix_code', '')
|
| 196 |
-
if len(fix) > 2000:
|
| 197 |
-
# Silently truncate instead of rejecting – don't penalize verbose agents
|
| 198 |
-
action['fix_code'] = fix[:2000]
|
| 199 |
-
|
| 200 |
elif atype == 'detect_gap':
|
| 201 |
rl = action.get('risk_level', '')
|
| 202 |
if rl not in VALID_RISK_LEVELS:
|
|
@@ -218,6 +216,24 @@ def _domain_check(action: Dict, atype: str) -> list:
|
|
| 218 |
return errors
|
| 219 |
|
| 220 |
|
|
|
| 221 |
def _domain_hint(atype: str, errors: list) -> str:
|
| 222 |
"""Generate a helpful hint for domain errors."""
|
| 223 |
fields = [e.get('field', '') for e in errors]
|
| 7 |
# - Rich hints so agent can self-correct on next step
|
| 8 |
|
| 9 |
from typing import Dict, Tuple
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
import json
|
| 12 |
|
| 13 |
VALID_VULN_TYPES = {
|
| 14 |
'sql_injection', 'xss', 'idor', 'hardcoded_secret', 'missing_auth',
|
|
|
|
| 175 |
return True, {}
|
| 176 |
|
| 177 |
|
| 178 |
+
@lru_cache(maxsize=1024)
|
| 179 |
+
def _cached_domain_errors(action_json: str, atype: str) -> list:
|
| 180 |
+
"""Pure domain check logic that can be safely cached."""
|
| 181 |
+
action = json.loads(action_json)
|
| 182 |
errors = []
|
| 183 |
|
| 184 |
if atype == 'identify_vulnerability':
|
|
|
|
| 195 |
if sev not in VALID_SEVERITIES:
|
| 196 |
errors.append({'field': 'severity', 'value': sev, 'allowed': sorted(VALID_SEVERITIES)})
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
elif atype == 'detect_gap':
|
| 199 |
rl = action.get('risk_level', '')
|
| 200 |
if rl not in VALID_RISK_LEVELS:
|
|
|
|
| 216 |
return errors
|
| 217 |
|
| 218 |
|
| 219 |
+
def _domain_check(action: Dict, atype: str) -> list:
|
| 220 |
+
"""Check values are within allowed ranges/enums. Returns list of error dicts."""
|
| 221 |
+
# Handle mutations first (cannot be purely cached)
|
| 222 |
+
if atype in ('propose_fix', 'revise_fix'):
|
| 223 |
+
fix = action.get('fix_code', '')
|
| 224 |
+
if len(fix) > 2000:
|
| 225 |
+
# Silently truncate instead of rejecting – don't penalize verbose agents
|
| 226 |
+
action['fix_code'] = fix[:2000]
|
| 227 |
+
|
| 228 |
+
# Use cached pure function for validation
|
| 229 |
+
try:
|
| 230 |
+
action_json = json.dumps(action, sort_keys=True)
|
| 231 |
+
return _cached_domain_errors(action_json, atype)
|
| 232 |
+
except Exception:
|
| 233 |
+
# Fallback if not serializable
|
| 234 |
+
return _cached_domain_errors(json.dumps({'dummy': True}), atype)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
def _domain_hint(atype: str, errors: list) -> str:
|
| 238 |
"""Generate a helpful hint for domain errors."""
|
| 239 |
fields = [e.get('field', '') for e in errors]
|
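Editor's note: `lru_cache` cannot hash a dict, which is why `_domain_check` canonicalizes the action with `json.dumps(..., sort_keys=True)` before calling the cached helper. The pattern in miniature (the check body is a stand-in; returning a tuple rather than a list keeps the cached value immutable):

```python
import json
from functools import lru_cache

@lru_cache(maxsize=1024)
def cached_check(action_json: str) -> tuple:
    # Stand-in domain check: report fields whose value is None
    action = json.loads(action_json)
    return tuple(k for k, v in sorted(action.items()) if v is None)

def check(action: dict) -> tuple:
    # sort_keys gives logically-equal dicts one canonical cache key
    return cached_check(json.dumps(action, sort_keys=True))
```

Without `sort_keys=True`, `{'a': 1, 'b': 2}` and `{'b': 2, 'a': 1}` would serialize differently and occupy two cache slots.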
server/web_ui.py
CHANGED
|
@@ -125,8 +125,8 @@ def run_single_task(task_id: str):
|
|
| 125 |
logs.append(f' Step {step + 1}: action={atype} reward={reward:.4f} done={done}')
|
| 126 |
step += 1
|
| 127 |
|
| 128 |
-
total = round(sum(rewards), 4)
|
| 129 |
-
logs.append(f'[END]
|
| 130 |
return '\n'.join(logs), rewards, total
|
| 131 |
|
| 132 |
|
|
@@ -146,7 +146,7 @@ def run_task_ui(task_id: str, model_name: str):
|
|
| 146 |
info = TASK_INFO.get(task_id, {})
|
| 147 |
domain = info.get('domain', 'Unknown')
|
| 148 |
difficulty = task_id.split('_')[1].upper()
|
| 149 |
-
score = min(max(total / max(len(rewards), 1), 0),
|
| 150 |
|
| 151 |
score_md = f'''### ✅ Results
|
| 152 |
| Field | Value |
|
|
@@ -181,7 +181,7 @@ def run_all_tasks_ui(model_name: str):
|
|
| 181 |
for task_id in tasks:
|
| 182 |
log_str, rewards, total = run_single_task(task_id)
|
| 183 |
all_logs.append(log_str)
|
| 184 |
-
score = min(max(total / max(len(rewards), 1), 0),
|
| 185 |
all_scores[task_id] = round(score, 4)
|
| 186 |
|
| 187 |
full_log = '\n\n'.join(all_logs)
|
|
@@ -253,7 +253,7 @@ def build_ui():
|
|
| 253 |
**A multi-domain RL environment for training AI agents on real-world tasks.**
|
| 254 |
|
| 255 |
This environment tests AI agents across **3 domains** with **9 tasks** of increasing difficulty.
|
| 256 |
-
Agents receive observations (problems), send actions (answers), and get reward scores (0.
|
| 257 |
''')
|
| 258 |
|
| 259 |
with gr.Tab('🎯 Single Task'):
|
|
@@ -320,7 +320,7 @@ via the API, and it gets scored on how well it solves real-world tasks.
|
|
| 320 |
1. Agent calls POST /reset with a task_id → Gets an observation (the problem)
|
| 321 |
2. Agent analyzes the observation and sends POST /step with its action
|
| 322 |
3. Environment validates the action and grades it
|
| 323 |
-
4. Returns a reward score (0.
|
| 324 |
5. Repeat until the episode ends (done=true) or max steps reached
|
| 325 |
```
|
| 326 |
|
|
@@ -332,7 +332,7 @@ via the API, and it gets scored on how well it solves real-world tasks.
|
|
| 332 |
| 🏥 **Clinical** | cli_easy, cli_medium, cli_hard | Detect workflow gaps, rank by priority, plan recovery |
|
| 333 |
|
| 334 |
### Reward Signals
|
| 335 |
-
- Scores range from **0.
|
| 336 |
- Partial credit is awarded for partially correct answers
|
| 337 |
- Invalid or malformed actions receive lower scores
|
| 338 |
- The environment provides feedback on validation failures to help agents improve
|
|
|
|
| 125 |
logs.append(f' Step {step + 1}: action={atype} reward={reward:.4f} done={done}')
|
| 126 |
step += 1
|
| 127 |
|
| 128 |
+
total = round(sum(rewards) / max(len(rewards), 1), 4)
|
| 129 |
+
logs.append(f'[END] avg_reward={total} steps={step}')
|
| 130 |
return '\n'.join(logs), rewards, total
|
| 131 |
|
| 132 |
|
|
|
|
| 146 |
info = TASK_INFO.get(task_id, {})
|
| 147 |
domain = info.get('domain', 'Unknown')
|
| 148 |
difficulty = task_id.split('_')[1].upper()
|
| 149 |
+
score = min(max(total / max(len(rewards), 1), 0.01), 0.99)
|
| 150 |
|
| 151 |
score_md = f'''### ✅ Results
|
| 152 |
| Field | Value |
|
|
|
|
| 181 |
for task_id in tasks:
|
| 182 |
log_str, rewards, total = run_single_task(task_id)
|
| 183 |
all_logs.append(log_str)
|
| 184 |
+
score = min(max(total / max(len(rewards), 1), 0.01), 0.99)
|
| 185 |
all_scores[task_id] = round(score, 4)
|
| 186 |
|
| 187 |
full_log = '\n\n'.join(all_logs)
|
|
|
|
| 253 |
**A multi-domain RL environment for training AI agents on real-world tasks.**
|
| 254 |
|
| 255 |
This environment tests AI agents across **3 domains** with **9 tasks** of increasing difficulty.
|
| 256 |
+
Agents receive observations (problems), send actions (answers), and get reward scores (0.01 – 0.99).
|
| 257 |
''')
|
| 258 |
|
| 259 |
with gr.Tab('🎯 Single Task'):
|
|
|
|
| 320 |
1. Agent calls POST /reset with a task_id → Gets an observation (the problem)
|
| 321 |
2. Agent analyzes the observation and sends POST /step with its action
|
| 322 |
3. Environment validates the action and grades it
|
| 323 |
+
4. Returns a reward score (0.01 – 0.99) and the next observation
|
| 324 |
5. Repeat until the episode ends (done=true) or max steps reached
|
| 325 |
```
|
| 326 |
|
|
|
|
| 332 |
| 🏥 **Clinical** | cli_easy, cli_medium, cli_hard | Detect workflow gaps, rank by priority, plan recovery |
|
| 333 |
|
| 334 |
### Reward Signals
|
| 335 |
+
- Scores range from **0.01** (completely wrong) to **0.99** (near-perfect)
|
| 336 |
- Partial credit is awarded for partially correct answers
|
| 337 |
- Invalid or malformed actions receive lower scores
|
| 338 |
- The environment provides feedback on validation failures to help agents improve
|
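Editor's note: the dashboard score reported above is the average step reward clamped into [0.01, 0.99]. As one function, assuming `rewards` holds the raw per-step rewards for a task:

```python
def task_score(rewards) -> float:
    # Average the step rewards (empty runs score the floor),
    # then clamp into the reporting band and round for display
    total = sum(rewards) / max(len(rewards), 1)
    return round(min(max(total, 0.01), 0.99), 4)
```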