Commit f3fd4ef
Parent(s): 1ecd7e1
Spec-compliance overhaul: remove difficulty_multiplier, weighted blend scoring, dep_hard fix, [END] format
FIXES:
1. inference.py: Remove score= from [END] lines (not in official spec)
2. inference.py: Scoring changed to 0.60*max + 0.40*mean (from pure average)
3. inference.py: All abort/error paths emit spec-compliant [END] with empty rewards
4. base_grader.py: difficulty_multiplier() REMOVED (uniform caps killed variance)
5. dependency_cases.py: dep_hard min_actions=1, seq=['migrate_api'] (was 2x, which forced a repetition penalty)
6. app.py /inference: Parse scores from final_scores JSON (not task_id/total_reward which didn't exist)
7. app.py /benchmark: Same weighted blend + no score= in [END]
8. README.md: Fix [END] example, update scoring description
- README.md +13 -27
- inference.py +158 -134
- server/app.py +63 -35
- server/datasets/dependency_cases.py +292 -298
- server/graders/base_grader.py +94 -51
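The weighted-blend scoring introduced in fix 2 can be sketched in isolation; this mirrors the `_compute_score` helper this commit adds to inference.py (illustrative standalone copy, not the file itself):

```python
def compute_score(rewards: list) -> float:
    """Blend of best and average step reward: 0.60*max + 0.40*mean, clamped to [0.01, 0.99]."""
    if not rewards:
        return 0.01  # empty episode scores the floor value
    max_r = max(rewards)
    mean_r = sum(rewards) / len(rewards)
    raw = 0.60 * max_r + 0.40 * mean_r
    return round(min(max(raw, 0.01), 0.99), 4)
```

A consistently good episode such as [0.80, 0.85, 0.80] now scores higher than a single lucky step followed by failures, which is the variance the previous pure average flattened.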
README.md
CHANGED
@@ -106,7 +106,7 @@ Agents detect missing steps in hospital workflows, rank them by clinical priority
 | **Multi-Turn Episodes** | Agents iterate through identify → act → revise workflows |
 | **3-Stage Validation** | Schema → Domain → Consistency checks with helpful error hints |
 | **Score Breakdown** | Per-component feedback in every step so agents learn *what* to improve |
-| **Fatal Error Handling** | Automatic 402/401 detection stops wasted API calls immediately |
 | **Universal LLM Support** | Works with any OpenAI-compatible model (Qwen, Llama, DeepSeek, Gemini, etc.) |
 | **Docker-Ready** | One-command deploy to Hugging Face Spaces |
 | **GRPO-Compatible** | Smooth reward gradients designed for policy optimization training |

@@ -229,33 +229,19 @@ entropyenv/
 ## Baseline Performance

-
 | Model | Provider | sec_easy | sec_med | sec_hard | dep_easy | dep_med | dep_hard | cli_easy | cli_med | cli_hard | **Avg** |
 |-------|----------|:--------:|:-------:|:--------:|:--------:|:-------:|:--------:|:--------:|:-------:|:--------:|:-------:|
-(several removed model rows are not recoverable from this rendering)
-| Llama 3.3 70B | Meta | 0.87 | 0.20 | 0.38 | 0.83 | 0.95 | 0.85 | 0.09 | 0.84 | 0.83 | **0.65** |
-| GPT-OSS-20B | OpenAI | 0.65 | 0.16 | 0.51 | 0.99 | 0.95 | 0.85 | 0.09 | 0.57 | 0.83 | **0.62** |
-| Llama 3.1 8B | Meta | 0.53 | 0.22 | 0.44 | 0.45 | 0.67 | 0.85 | 0.74 | 0.48 | 0.80 | **0.57** |
-| GPT-OSS-120B | OpenAI | 0.87 | 0.21 | 0.20 | 0.99 | 0.11 | 0.13 | 0.74 | 0.95 | 0.45 | **0.52** |
-| Qwen3.5-9B | Alibaba | 0.87 | 0.72 | 0.51 | 0.99 | 0.11 | 0.20 | 0.05 | 0.01 | 0.02 | **0.38** |
-| MiniMax M2.5 | MiniMax | 0.53 | 0.13 | 0.02 | 0.45 | 0.01 | 0.01 | 0.74 | 0.23 | 0.12 | **0.25** |
-| MiniMax M2.7 | MiniMax | 0.53 | 0.01 | 0.39 | 0.45 | 0.01 | 0.01 | 0.04 | 0.11 | 0.42 | **0.22** |
-| MiMo-v2 Pro | Xiaomi | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | **0.01** |
-
-**Key observations:**
-- **Clear difficulty progression:** Easy > Medium > Hard across all domains
-- **High variance:** Scores range from 0.01 (incompatible models) to 0.80 (DeepSeek R1)
-- **Security is hardest:** Even top models score < 0.61 on `sec_hard` (propose_fix/revise_fix are genuinely difficult)
-- **Model discrimination:** The benchmark clearly separates 70B+ reasoning models from smaller/weaker ones

 ---

@@ -264,10 +250,10 @@ Tested across 14 models from 9 providers. Scores range from **0.01 to 0.80**, de…
 The baseline `inference.py` emits structured logs matching the OpenEnv spec:

 ```
-[START] task=sec_easy env=…
 [STEP] step=1 action=identify_vulnerability reward=0.85 done=false error=null
 [STEP] step=2 action=propose_fix reward=0.92 done=true error=null
-[END] success=true steps=2
 ```

 ---
 | **Multi-Turn Episodes** | Agents iterate through identify → act → revise workflows |
 | **3-Stage Validation** | Schema → Domain → Consistency checks with helpful error hints |
 | **Score Breakdown** | Per-component feedback in every step so agents learn *what* to improve |
+| **Fatal Error Handling** | Automatic 402/401/403 detection stops wasted API calls immediately |
 | **Universal LLM Support** | Works with any OpenAI-compatible model (Qwen, Llama, DeepSeek, Gemini, etc.) |
 | **Docker-Ready** | One-command deploy to Hugging Face Spaces |
 | **GRPO-Compatible** | Smooth reward gradients designed for policy optimization training |

 ## Baseline Performance

+> **Note:** Scores below are from the latest grading revision (v3: weighted 0.60×max + 0.40×mean scoring, difficulty_multiplier removed, dep_hard done-condition fixed). Re-benchmarking across 14+ models in progress.

 | Model | Provider | sec_easy | sec_med | sec_hard | dep_easy | dep_med | dep_hard | cli_easy | cli_med | cli_hard | **Avg** |
 |-------|----------|:--------:|:-------:|:--------:|:--------:|:-------:|:--------:|:--------:|:-------:|:--------:|:-------:|
+| *Benchmarking in progress...* | | | | | | | | | | | |
+
+**Scoring formula:** `score = 0.60 × max(step_rewards) + 0.40 × mean(step_rewards)`, clamped to `[0.01, 0.99]`
+
+**Design principles:**
+- **No artificial difficulty caps** – scores reflect actual grader correctness
+- **Weighted blend** – rewards consistently good episodes over single-lucky-step flukes
+- **Spec-compliant** – `[END]` lines have NO `score=` field per official guidelines
+- **14+ model families tested** for universal compatibility

 ---

 The baseline `inference.py` emits structured logs matching the OpenEnv spec:

 ```
+[START] task=sec_easy env=EntropyEnv model=Qwen/Qwen2.5-72B-Instruct
 [STEP] step=1 action=identify_vulnerability reward=0.85 done=false error=null
 [STEP] step=2 action=propose_fix reward=0.92 done=true error=null
+[END] success=true steps=2 rewards=0.85,0.92
 ```

 ---
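Logs in the format above are easy to machine-check. The following is an illustrative sketch (not part of the repo) of a parser for the spec lines; it collects per-step rewards and the final `[END]` summary so the two can be cross-checked:

```python
import re

STEP_RE = re.compile(
    r"\[STEP\] step=(\d+) action=(\S+) reward=(\d\.\d{2}) done=(true|false) error=(.+)"
)
END_RE = re.compile(r"\[END\] success=(true|false) steps=(\d+) rewards=([\d.,]*)")

def parse_log(lines):
    """Collect step rewards and the [END] summary from spec-format log lines."""
    steps, end = [], None
    for line in lines:
        m = STEP_RE.match(line)
        if m:
            steps.append(float(m.group(3)))
        m = END_RE.match(line)
        if m:
            end = {"success": m.group(1) == "true",
                   "steps": int(m.group(2)),
                   "rewards": [float(r) for r in m.group(3).split(",") if r]}
    return steps, end
```

Comparing the `rewards=` list in `[END]` against the accumulated `[STEP]` rewards catches exactly the kind of `score=`/`total_reward=` drift this commit removes.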
|
inference.py
CHANGED
@@ -2,10 +2,18 @@
 # Mandatory baseline inference script for OpenEnv hackathon.
 # Uses OpenAI-compatible client for HuggingFace Inference API.
 #
-# STDOUT FORMAT (…
 # [START] task=<task_name> env=<benchmark> model=<model_name>
 # [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
-# [END] success=<true|false> steps=<n>

 import os
 import re
@@ -20,23 +28,21 @@
 except ImportError:
     pass

-# ── Mandatory environment variables ──
 API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
-MODEL_NAME = os.getenv("MODEL_NAME", …
 HF_TOKEN = os.getenv("HF_TOKEN")
-ENV_URL = os.getenv("ENV_URL", …

 MAX_STEPS = 8
 TEMPERATURE = 0.1
 MAX_TOKENS = 400
 BENCHMARK = "EntropyEnv"

-# …
-RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
-MAX_CONSECUTIVE_ERRORS = 3  # stop task after 3 consecutive API errors

 TASKS = [
     "sec_easy", "sec_medium", "sec_hard",
@@ -75,7 +81,7 @@ EXACT FORMAT EXAMPLES — copy field names exactly:
 {"action_type": "revise_fix", "fix_code": "cursor.execute(sql, values)", "addressed_feedback": "Used parameterized queries and added input validation"}
 {"action_type": "flag_outdated", "packages": {"torch": "1.9.0"}, "deprecated_api": "torch.autograd.Variable", "replacement": "plain tensor"}
 {"action_type": "resolve_conflict", "packages": {"torch": "2.1.0", "numpy": "1.24.0"}, "reasoning": "torch 2.1 requires numpy >=1.24"}
-{"action_type": "migrate_api", "completed_items": ["break_001", "break_002"…
 {"action_type": "detect_gap", "missing_steps": ["pre_op_consent"], "risk_level": "critical"}
 {"action_type": "rank_issues", "priority_order": ["resolve_insurance", "pre_op_consent", "book_specialist"]}
 {"action_type": "order_steps", "recovery_steps": ["resolve_insurance", "complete_pre_op", "book_specialist", "schedule_surgery"]}
@@ -85,12 +91,9 @@ CRITICAL: Output ONLY the JSON object. Nothing before or after it.


 def _extract_http_code(error_str: str) -> int:
-    """Extract HTTP status code from error message string. Returns 0 if not found."""
-    # Matches patterns like "Error code: 402" or "status_code=402" or "HTTP 402"
     match = re.search(r'(?:Error code:|status_code=|HTTP )\s*(\d{3})', str(error_str))
     if match:
         return int(match.group(1))
-    # Also check for bare 4xx/5xx at start of error
     match = re.search(r'\b(4\d{2}|5\d{2})\b', str(error_str))
     if match:
         return int(match.group(1))
@@ -98,31 +101,26 @@ def _extract_http_code(error_str: str) -> int:


 def _is_fatal_error(error_str: str) -> bool:
-    """Return True if this error means we should stop ALL tasks (not just this one)."""
     code = _extract_http_code(error_str)
     if code in FATAL_HTTP_CODES:
         return True
-    …
-    err_lower = str(error_str).lower()
-    return any(kw in err_lower for kw in fatal_keywords)


-def …
-    """Return True if this error means we should stop THIS task but try others."""
     code = _extract_http_code(error_str)
-    if code in …
         return True
-    …
-    return any(kw in err_lower for kw in task_fatal_keywords)


 def build_user_prompt(step_num: int, obs: dict, history: list) -> str:
     task_type = obs.get("task_type", "unknown")
-    task_id = obs.get("task_id", …
     task_sub = obs.get("task_subtype", "")

     parts = [f"Step {step_num} | task_type={task_type} | task_id={task_id} | subtype={task_sub}"]
@@ -132,8 +130,8 @@ def build_user_prompt(step_num: int, obs: dict, history: list) -> str:
     last = history[-1]
     parts.append(f"Actions used: {used}")
     parts.append(f"Last reward: {last['reward']:.2f}")
-    if last["reward"] < 0.…
-        parts.append(…

     if obs.get("validation_failed"):
         parts.append(f"\n❌ VALIDATION FAILED!")
@@ -145,39 +143,34 @@ def build_user_prompt(step_num: int, obs: dict, history: list) -> str:
     parts.append(obs["reviewer_feedback"])

     obs_copy = dict(obs)
-    …
-    …

-    core_text = json.dumps(obs_copy, default=str…
     parts.append(f"\nObservation:\n{core_text}")

-    if …
-        parts.append(…
-        for pkg, versions in …
-            parts.append(f"  {pkg}:")
             for ver, deps in versions.items():
-                if deps…
-    …
-        parts.append(f"\nDependency Graph (prerequisites must come first):")
-        for step, prereqs in dep_graph.items():
-            if prereqs:
-                parts.append(f"  {step} requires: {prereqs}")
-            else:
-                parts.append(f"  {step} → (no prereqs)")

     if task_type == "security":
         used_types = [h["action_type"] for h in history]
-        if …
             parts.append("\n➡️ NEXT: identify_vulnerability")
         elif "propose_fix" not in used_types:
             parts.append("\n➡️ NEXT: propose_fix")
         else:
             parts.append("\n➡️ NEXT: revise_fix (address reviewer_feedback)")
-
     elif task_type == "clinical":
         used_types = [h["action_type"] for h in history]
         if "detect_gap" not in used_types:
@@ -195,15 +188,13 @@ def parse_action(raw_text: str) -> dict:
     """Parse LLM response into action dict. Universal model compatibility."""
     text = raw_text.strip()

     for tag in ["think", "thinking", "reasoning", "reflection", "thought", "antThinking"]:
-        open_tag = f"<{tag}>"
-        close_tag = f"</{tag}>"
         if open_tag in text:
-            if close_tag in text…
-                text = text.split(close_tag)[-1].strip()
-            else:
-                text = text.split(open_tag)[-1].strip()

     if "```json" in text:
         text = text.split("```json")[1].split("```")[0].strip()
     elif "```" in text:
@@ -211,6 +202,7 @@ def parse_action(raw_text: str) -> dict:
     if len(parts) >= 3:
         text = parts[1].strip()

     if not text.startswith("{"):
         start = text.find("{")
         if start >= 0:
@@ -233,42 +225,72 @@ def parse_action(raw_text: str) -> dict:
     return {"action_type": "error", "raw": text[:100]}


-def …
-    """…
     (old run_task docstring body lost in this rendering)
     """
-    # Reset
     try:
         resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
         data = resp.json()
     except Exception as e:
         print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
-        print(f"[END] success=false steps=0 …
         return 0.01, False

     if "error" in data and not data.get("episode_id"):
         print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
-        print(f"[END] success=false steps=0 …
         return 0.01, False

     episode_id = data.get("episode_id", "unknown")
-    obs …

     print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)

-    rewards …
-    history …
-    step_num …

     for step_num in range(1, MAX_STEPS + 1):
         user_prompt = build_user_prompt(step_num, obs, history)

-        fatal_error = False
-        task_fatal = False
         try:
             reply = client.chat.completions.create(
                 model=MODEL_NAME,
@@ -280,80 +302,85 @@ def run_task(client: OpenAI, task_id: str) -> tuple:
                 max_tokens=MAX_TOKENS,
             )
             response_text = (reply.choices[0].message.content or "").strip()
-
         except Exception as e:
-            error_msg …
             response_text = '{"action_type": "error"}'

-            # Check if this is a fatal error (auth/payment) → stop everything
             if _is_fatal_error(error_msg):
                 fatal_error = True
                 rewards.append(0.01)
-                step_num_final = step_num
                 break

-            short_err = error_msg[:120].replace('\n', ' ')
-            print(f"[STEP] step={step_num} action=invalid reward=0.01 done=true error=TASK_STOP:{short_err}", flush=True)
             rewards.append(0.01)
-            step_num_final = step_num
             break

-        action …
         action_type = action.get("action_type", "unknown")
         action["episode_id"] = episode_id

         try:
             step_resp = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
             step_data = step_resp.json()
         except Exception as e:
-            …
-            print(f"[STEP] step={step_num} action={action_type} reward=0.01 done=true error={…
             rewards.append(0.01)
-            step_num_final = step_num
-            fatal_error = False
             break

-        reward …
-        done …
-        obs …
         step_error = step_data.get("error") or error_msg

         rewards.append(reward)
         history.append({"step": step_num, "action_type": action_type, "reward": reward, "done": done})

-        if obs.get("validation_failed")…
-            display_action = "invalid"
-
-        error_val = step_error if step_error else "null"
-        # Truncate long error messages in output
-        if error_val and error_val != "null" and len(str(error_val)) > 150:
-            error_val = str(error_val)[:150] + "..."

         if done:
-            fatal_error = False
             break
-        else:
-            step_num_final = step_num
-            fatal_error = False

-    score …
-    success = …
     rewards_str = ",".join(f"{r:.2f}" for r in rewards)

     return score, fatal_error
@@ -361,63 +388,60 @@ def run_task(client: OpenAI, task_id: str) -> tuple:
 def main() -> None:
     """Run all 9 tasks and report final scores."""
     if not HF_TOKEN:
-        print("ERROR: Set HF_TOKEN …
         return

     client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

-    # Health check
     try:
         health = requests.get(f"{ENV_URL}/", timeout=10, headers={"Accept": "application/json"})
         health_data = health.json()
-        print(…
     except Exception as e:
         print(f"ERROR: Cannot connect to environment at {ENV_URL}: {e}", flush=True)
         return

-    scores …
-
     for task_id in TASKS:
         try:
             score, is_fatal = run_task(client, task_id)
             scores[task_id] = score

-            # If we hit a fatal API error (402/401/403), stop ALL remaining tasks
             if is_fatal:
-                …
-                print(f"\n🚫 Fatal API error on {task_id}. Stopping …
-                print(f"   Likely cause: invalid token, no credits, or unauthorized access.", flush=True)
-                # Emit mandatory [START]/[END] lines for remaining tasks (spec compliance)
                 for remaining in TASKS:
                     if remaining not in scores:
                         scores[remaining] = 0.01
                         print(f"[START] task={remaining} env={BENCHMARK} model={MODEL_NAME}", flush=True)
-                        print(f"[END] success=false steps=0 …
                 break

         except Exception as e:
-            print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
-            print(f"[END] success=false steps=0 score=0.01 rewards=", flush=True)
             scores[task_id] = 0.01

-    avg = round(sum(scores.values()) / max(len(scores), 1), …
-    print(f"\n✅ All tasks complete! Average: {avg:.…
     print(json.dumps({"final_scores": scores}), flush=True)

-    …
-    …
-    # scores that would corrupt the benchmark history.
-    if had_fatal_error:
-        print(f"⚠️ Results NOT saved – run was aborted due to a fatal API error (invalid token / no credits).", flush=True)
-        print(f"   Fix your API key/credits and re-run to get valid scores.", flush=True)
     else:
         try:
             from server.benchmark_store import append_result
             append_result(MODEL_NAME, MODEL_NAME, scores)
             print(f"💾 Results saved (avg: {avg:.4f})", flush=True)
         except Exception as e:
-            print(f"⚠️ …

 if __name__ == "__main__":
 # Mandatory baseline inference script for OpenEnv hackathon.
 # Uses OpenAI-compatible client for HuggingFace Inference API.
 #
+# OFFICIAL STDOUT FORMAT (from Meta_OpenEnv_Hackathon__Guidelines.txt):
 # [START] task=<task_name> env=<benchmark> model=<model_name>
 # [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+# [END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
+#
+# KEY RULES FROM OFFICIAL SPEC:
+# - reward and rewards formatted to 2 decimal places ONLY
+# - done and success are lowercase booleans: true or false
+# - error is null when no error (the literal string "null")
+# - NO score= field in [END] – not in the official spec
+# - NO task_id=, NO episode_id=, NO total_reward= – none of these are in spec
+# - rewards= is a comma-separated list of step rewards with NO spaces

 import os
 import re
 except ImportError:
     pass

+# ── Mandatory environment variables (names exactly as spec requires) ──
 API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
 HF_TOKEN = os.getenv("HF_TOKEN")
+ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

 MAX_STEPS = 8
 TEMPERATURE = 0.1
 MAX_TOKENS = 400
 BENCHMARK = "EntropyEnv"

+# Fatal HTTP codes: stop ALL tasks immediately
+FATAL_HTTP_CODES = {402, 401, 403}
+RETRYABLE_CODES = {429, 500, 502, 503, 504}
+MAX_CONSEC_ERRORS = 3

 TASKS = [
     "sec_easy", "sec_medium", "sec_hard",
 {"action_type": "revise_fix", "fix_code": "cursor.execute(sql, values)", "addressed_feedback": "Used parameterized queries and added input validation"}
 {"action_type": "flag_outdated", "packages": {"torch": "1.9.0"}, "deprecated_api": "torch.autograd.Variable", "replacement": "plain tensor"}
 {"action_type": "resolve_conflict", "packages": {"torch": "2.1.0", "numpy": "1.24.0"}, "reasoning": "torch 2.1 requires numpy >=1.24"}
+{"action_type": "migrate_api", "completed_items": ["break_001", "break_002"], "code_changes": {"break_001": "use torch.where", "break_002": "use tensor.shape[0]"}}
 {"action_type": "detect_gap", "missing_steps": ["pre_op_consent"], "risk_level": "critical"}
 {"action_type": "rank_issues", "priority_order": ["resolve_insurance", "pre_op_consent", "book_specialist"]}
 {"action_type": "order_steps", "recovery_steps": ["resolve_insurance", "complete_pre_op", "book_specialist", "schedule_surgery"]}

 def _extract_http_code(error_str: str) -> int:
     match = re.search(r'(?:Error code:|status_code=|HTTP )\s*(\d{3})', str(error_str))
     if match:
         return int(match.group(1))
     match = re.search(r'\b(4\d{2}|5\d{2})\b', str(error_str))
     if match:
         return int(match.group(1))


 def _is_fatal_error(error_str: str) -> bool:
     code = _extract_http_code(error_str)
     if code in FATAL_HTTP_CODES:
         return True
+    fatal_kw = ['insufficient credits', 'unauthorized', 'invalid api key',
+                'authentication failed', 'no api key', 'forbidden']
+    return any(kw in str(error_str).lower() for kw in fatal_kw)


+def _is_task_fatal(error_str: str) -> bool:
     code = _extract_http_code(error_str)
+    if code in RETRYABLE_CODES:
         return True
+    task_kw = ['model not found', 'model unavailable', 'context length',
+               'maximum context', 'rate limit']
+    return any(kw in str(error_str).lower() for kw in task_kw)
|
| 120 |
|
| 121 |
def build_user_prompt(step_num: int, obs: dict, history: list) -> str:
|
| 122 |
task_type = obs.get("task_type", "unknown")
|
| 123 |
+
task_id = obs.get("task_id", "unknown")
|
| 124 |
task_sub = obs.get("task_subtype", "")
|
| 125 |
|
| 126 |
parts = [f"Step {step_num} | task_type={task_type} | task_id={task_id} | subtype={task_sub}"]
|
|
|
|
| 130 |
last = history[-1]
|
| 131 |
parts.append(f"Actions used: {used}")
|
| 132 |
parts.append(f"Last reward: {last['reward']:.2f}")
|
| 133 |
+
if last["reward"] < 0.40:
|
| 134 |
+
parts.append("β οΈ Low score. Try a different approach.")
|
| 135 |
|
| 136 |
if obs.get("validation_failed"):
|
| 137 |
parts.append(f"\nβ VALIDATION FAILED!")
|
|
|
|
| 143 |
parts.append(obs["reviewer_feedback"])
|
| 144 |
|
| 145 |
obs_copy = dict(obs)
|
| 146 |
+
compat = obs_copy.pop("compatibility_matrix", None)
|
| 147 |
+
dep_g = obs_copy.pop("dependency_graph", None)
|
| 148 |
|
| 149 |
+
core_text = json.dumps(obs_copy, default=str)
|
| 150 |
+
if len(core_text) > 1600:
|
| 151 |
+
core_text = core_text[:1600] + "..."
|
| 152 |
parts.append(f"\nObservation:\n{core_text}")
|
| 153 |
|
| 154 |
+
if compat:
|
| 155 |
+
parts.append("\nCompatibility Matrix (use this to resolve conflicts):")
|
| 156 |
+
for pkg, versions in compat.items():
|
|
|
|
| 157 |
for ver, deps in versions.items():
|
| 158 |
+
parts.append(f" {pkg} {ver} β {deps if deps else '(no constraints)'}")
|
| 159 |
+
|
| 160 |
+
if dep_g:
|
| 161 |
+
parts.append("\nDependency Graph (prerequisites must come first):")
|
| 162 |
+
for step, prereqs in dep_g.items():
|
| 163 |
+
parts.append(f" {step} requires: {prereqs}" if prereqs else f" {step} β (no prereqs)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
# Next-action hint β keeps all models on track
|
| 166 |
if task_type == "security":
|
| 167 |
used_types = [h["action_type"] for h in history]
|
| 168 |
+
if "identify_vulnerability" not in used_types:
|
| 169 |
parts.append("\nβ‘οΈ NEXT: identify_vulnerability")
|
| 170 |
elif "propose_fix" not in used_types:
|
| 171 |
parts.append("\nβ‘οΈ NEXT: propose_fix")
|
| 172 |
else:
|
| 173 |
parts.append("\nβ‘οΈ NEXT: revise_fix (address reviewer_feedback)")
|
|
|
|
| 174 |
elif task_type == "clinical":
|
| 175 |
used_types = [h["action_type"] for h in history]
|
| 176 |
if "detect_gap" not in used_types:
|
|
|
|
     """Parse LLM response into action dict. Universal model compatibility."""
     text = raw_text.strip()

+    # Strip reasoning/thinking blocks
     for tag in ["think", "thinking", "reasoning", "reflection", "thought", "antThinking"]:
+        open_tag, close_tag = f"<{tag}>", f"</{tag}>"
         if open_tag in text:
+            text = text.split(close_tag)[-1].strip() if close_tag in text else text.split(open_tag)[-1].strip()

+    # Strip markdown fences
     if "```json" in text:
         text = text.split("```json")[1].split("```")[0].strip()
     elif "```" in text:

     if len(parts) >= 3:
         text = parts[1].strip()

+    # Find JSON object in prose
     if not text.startswith("{"):
         start = text.find("{")
         if start >= 0:
     return {"action_type": "error", "raw": text[:100]}


+def _compute_score(rewards: list) -> float:
+    """
+    Compute the episode score from a list of step rewards.
+
+    DESIGN RATIONALE – why neither pure max nor pure mean is right:
+    - Pure max: agent scores 0.85 on step 1, then 0.01 on all later steps → score=0.85.
+      This rewards single-lucky-step behaviour and hides that later steps failed.
+    - Pure mean: agent scores 0.85 on step 1, 0.01 on 3 more → score=0.22.
+      This massively under-reports good episodes that have validation failures early.
+
+    SOLUTION – weighted blend of max and mean:
+        score = 0.60 * max(rewards) + 0.40 * mean(rewards)
+
+    WHY THIS WORKS:
+    - A great single-step performance (0.85) still shows up clearly (0.51 baseline contribution)
+    - A consistently good episode (0.80, 0.85, 0.80) gets full credit (≈0.84)
+    - A fluke-then-fail episode (0.85, 0.01, 0.01, 0.01) scores ≈0.60 – honestly mediocre
+    - A failed episode (all 0.01) scores 0.01 – correctly bad
+
+    Clamped to [0.01, 0.99] per Discord consensus on the (0, 1) exclusive range.
+    """
+    if not rewards:
+        return 0.01
+    max_r = max(rewards)
+    mean_r = sum(rewards) / len(rewards)
+    raw = 0.60 * max_r + 0.40 * mean_r
+    return round(min(max(raw, 0.01), 0.99), 4)


+def run_task(client: OpenAI, task_id: str) -> tuple:
+    """
+    Run a single task through the environment.
+    Returns (score: float, is_fatal_api_error: bool).
+    """
+    # ── Reset ──────────────────────────────────────────────────────────────
     try:
         resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
         data = resp.json()
     except Exception as e:
+        # Env unreachable – must still emit [START] and [END]
         print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
+        print(f"[END] success=false steps=0 rewards=", flush=True)
         return 0.01, False

     if "error" in data and not data.get("episode_id"):
         print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
+        print(f"[END] success=false steps=0 rewards=", flush=True)
         return 0.01, False

     episode_id = data.get("episode_id", "unknown")
+    obs = data.get("observation", data)

+    # ── Mandatory [START] line – exact official spec ────────────────────────
     print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)

+    rewards = []
+    history = []
+    step_num = 0
+    consec_errs = 0
+    fatal_error = False

     for step_num in range(1, MAX_STEPS + 1):
         user_prompt = build_user_prompt(step_num, obs, history)
+        error_msg = None

+        # ── LLM call ───────────────────────────────────────────────────────
         try:
             reply = client.chat.completions.create(
                 model=MODEL_NAME,
                 …
                 max_tokens=MAX_TOKENS,
             )
             response_text = (reply.choices[0].message.content or "").strip()
+            consec_errs = 0

         except Exception as e:
+            error_msg = str(e)
             response_text = '{"action_type": "error"}'
+            consec_errs += 1

             if _is_fatal_error(error_msg):
                 fatal_error = True
+                short = error_msg[:120].replace('\n', ' ')
+                # Emit mandatory [STEP] then break – [END] emitted below
+                print(f"[STEP] step={step_num} action=invalid reward=0.01 done=true error=FATAL:{short}", flush=True)
                 rewards.append(0.01)
                 break

+            if _is_task_fatal(error_msg) or consec_errs >= MAX_CONSEC_ERRORS:
+                short = error_msg[:120].replace('\n', ' ')
+                print(f"[STEP] step={step_num} action=invalid reward=0.01 done=true error=TASK_STOP:{short}", flush=True)
                 rewards.append(0.01)
                 break

+        action = parse_action(response_text)
         action_type = action.get("action_type", "unknown")
         action["episode_id"] = episode_id

+        # ── Env step ───────────────────────────────────────────────────────
         try:
             step_resp = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
             step_data = step_resp.json()
         except Exception as e:
+            short = str(e)[:100]
+            print(f"[STEP] step={step_num} action={action_type} reward=0.01 done=true error={short}", flush=True)
             rewards.append(0.01)
             break

+        reward = float(step_data.get("reward", 0.0))
+        done = bool(step_data.get("done", False))
+        obs = step_data.get("observation", step_data)
         step_error = step_data.get("error") or error_msg

         rewards.append(reward)
         history.append({"step": step_num, "action_type": action_type, "reward": reward, "done": done})

+        # Show 'invalid' in log when validation failed
+        display_action = "invalid" if obs.get("validation_failed") else action_type

+        # Format error value: null or truncated string
+        if step_error:
+            error_val = str(step_error)[:150].replace('\n', ' ')
+        else:
+            error_val = "null"

+        # ── Mandatory [STEP] line – exact official spec ────────────────────
+        # reward=<0.00> means 2 decimal places
+        # done=<true|false> means lowercase boolean string
+        print(
+            f"[STEP] step={step_num} action={display_action} reward={reward:.2f} "
+            f"done={str(done).lower()} error={error_val}",
+            flush=True
(diff rendering cuts off here)
+
)
|
| 365 |
|
| 366 |
if done:
|
|
|
|
| 367 |
break
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
+
# ββ Compute final score ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 370 |
+
score = _compute_score(rewards)
|
| 371 |
+
# success = at least one step scored meaningfully above the floor
|
| 372 |
+
success = any(r > 0.10 for r in rewards)
|
| 373 |
+
|
| 374 |
+
# rewards list: 2 decimal places, comma-separated, no spaces
|
| 375 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 376 |
|
| 377 |
+
# ββ Mandatory [END] line β exact official spec βββββββββββββββββββββββββ
|
| 378 |
+
# spec: success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
|
| 379 |
+
# NO score= field β not in the official spec
|
| 380 |
+
print(
|
| 381 |
+
f"[END] success={str(success).lower()} steps={step_num} rewards={rewards_str}",
|
| 382 |
+
flush=True
|
| 383 |
+
)
|
| 384 |
|
| 385 |
return score, fatal_error
|
| 386 |
|
|
|
|
| 388 |
def main() -> None:
|
| 389 |
"""Run all 9 tasks and report final scores."""
|
| 390 |
if not HF_TOKEN:
|
| 391 |
+
print("ERROR: Set HF_TOKEN environment variable.", flush=True)
|
| 392 |
return
|
| 393 |
|
| 394 |
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 395 |
|
|
|
|
| 396 |
try:
|
| 397 |
health = requests.get(f"{ENV_URL}/", timeout=10, headers={"Accept": "application/json"})
|
| 398 |
health_data = health.json()
|
| 399 |
+
print(
|
| 400 |
+
f"Environment: {health_data.get('env', 'unknown')} | "
|
| 401 |
+
f"Tasks: {health_data.get('tasks', 0)}",
|
| 402 |
+
flush=True
|
| 403 |
+
)
|
| 404 |
except Exception as e:
|
| 405 |
print(f"ERROR: Cannot connect to environment at {ENV_URL}: {e}", flush=True)
|
| 406 |
return
|
| 407 |
|
| 408 |
+
scores = {}
|
| 409 |
+
had_fatal = False
|
| 410 |
|
| 411 |
for task_id in TASKS:
|
| 412 |
try:
|
| 413 |
score, is_fatal = run_task(client, task_id)
|
| 414 |
scores[task_id] = score
|
| 415 |
|
|
|
|
| 416 |
if is_fatal:
|
| 417 |
+
had_fatal = True
|
| 418 |
+
print(f"\nπ« Fatal API error on {task_id}. Stopping remaining tasks.", flush=True)
|
|
|
|
|
|
|
| 419 |
for remaining in TASKS:
|
| 420 |
if remaining not in scores:
|
| 421 |
scores[remaining] = 0.01
|
| 422 |
print(f"[START] task={remaining} env={BENCHMARK} model={MODEL_NAME}", flush=True)
|
| 423 |
+
print(f"[END] success=false steps=0 rewards=", flush=True)
|
| 424 |
break
|
| 425 |
|
| 426 |
except Exception as e:
|
|
|
|
|
|
|
| 427 |
scores[task_id] = 0.01
|
| 428 |
+
print(f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}", flush=True)
|
| 429 |
+
print(f"[END] success=false steps=0 rewards=", flush=True)
|
| 430 |
|
| 431 |
+
avg = round(sum(scores.values()) / max(len(scores), 1), 4)
|
| 432 |
+
print(f"\nβ
All tasks complete! Average: {avg:.4f}", flush=True)
|
| 433 |
+
# Final JSON summary β evaluator may parse this
|
| 434 |
print(json.dumps({"final_scores": scores}), flush=True)
|
| 435 |
|
| 436 |
+
if had_fatal:
|
| 437 |
+
print("β οΈ Results NOT saved β fatal API error (invalid token / no credits).", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
else:
|
| 439 |
try:
|
| 440 |
from server.benchmark_store import append_result
|
| 441 |
append_result(MODEL_NAME, MODEL_NAME, scores)
|
| 442 |
print(f"πΎ Results saved (avg: {avg:.4f})", flush=True)
|
| 443 |
except Exception as e:
|
| 444 |
+
print(f"β οΈ Could not save results: {e}", flush=True)
|
| 445 |
|
| 446 |
|
| 447 |
if __name__ == "__main__":
|
server/app.py
CHANGED
@@ -289,7 +289,7 @@ async def run_inference(request: Request):
     env_vars = os.environ.copy()
     env_vars['ENV_URL'] = env_vars.get('ENV_URL', 'http://localhost:7860')

-    # …
+    # inference.py is at project root (one level up from server/)
     inference_path = os.path.join(
         os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
         'inference.py'
@@ -309,30 +309,19 @@ async def run_inference(request: Request):
     stdout = result.stdout or ''
     stderr = result.stderr or ''
+    logs = []

-    # …
-    logs = []
-    final_scores = {}
+    # Collect all log lines for display
     for line in stdout.splitlines():
         line = line.strip()
-        if …
-                parts[k] = v
-            task_id = parts.get('task_id', '')
-            total_reward = parts.get('total_reward', '0')
-            if task_id:
-                try:
-                    final_scores[task_id] = float(total_reward)
-                except ValueError:
-                    final_scores[task_id] = 0.01
-
-    # Also try final JSON summary line
+        if line:
+            logs.append(line)
+
+    # ── Parse final_scores from the JSON summary line ──────────────────
+    # This is authoritative — inference.py always prints:
+    #   {"final_scores": {"sec_easy": 0.85, ...}}
+    # at the end of main().
+    final_scores = {}
     for line in reversed(stdout.splitlines()):
         line = line.strip()
         if line.startswith('{') and 'final_scores' in line:
@@ -340,9 +329,43 @@ async def run_inference(request: Request):
             try:
                 parsed = json.loads(line)
                 if 'final_scores' in parsed:
                     final_scores = parsed['final_scores']
+                    break
             except Exception:
                 pass
+
+    # ── Fallback: parse [END] lines for any tasks missing from JSON ────
+    # Official [END] format: success=<bool> steps=<n> rewards=<r1,r2,...>
+    # We track which task we're in via the preceding [START] line.
+    if not final_scores:
+        current_task = None
+        for line in stdout.splitlines():
+            line = line.strip()
+            if line.startswith('[START]'):
+                # Extract task= field
+                for token in line.split():
+                    if token.startswith('task='):
+                        current_task = token.split('=', 1)[1]
+                        break
+            elif line.startswith('[END]') and current_task:
+                # Parse rewards= field and compute score from it
+                parts = {}
+                for token in line.split():
+                    if '=' in token:
+                        k, v = token.split('=', 1)
+                        parts[k] = v
+                rewards_str = parts.get('rewards', '')
+                if rewards_str:
+                    try:
+                        step_rewards = [float(r) for r in rewards_str.split(',') if r]
+                        if step_rewards:
+                            # Same weighted blend as inference.py _compute_score()
+                            max_r = max(step_rewards)
+                            mean_r = sum(step_rewards) / len(step_rewards)
+                            score = round(min(max(0.60 * max_r + 0.40 * mean_r, 0.01), 0.99), 4)
+                            final_scores[current_task] = score
+                    except (ValueError, TypeError):
+                        final_scores[current_task] = 0.01
+                current_task = None

     avg = (
         round(sum(final_scores.values()) / len(final_scores), 4)
@@ -350,22 +373,22 @@ async def run_inference(request: Request):
     )

     return JSONResponse(status_code=200, content={
-        'status': …,
-        'final_scores': …,
+        'status': 'ok' if result.returncode == 0 else 'completed_with_errors',
+        'final_scores': final_scores,
         'average_score': avg,
-        'logs': …,
-        'stderr': …,
-        'returncode': …,
+        'logs': logs[-50:],
+        'stderr': stderr[-500:] if stderr else '',
+        'returncode': result.returncode,
     })

 except subprocess.TimeoutExpired:
     return JSONResponse(status_code=200, content={
-        'error': …,
+        'error': 'inference.py timed out after 20 minutes',
         'final_scores': {},
     })
 except Exception as e:
     return JSONResponse(status_code=200, content={
-        'error': …,
+        'error': str(e),
         'final_scores': {},
     })
@@ -528,13 +551,18 @@ def _run_single_task_inline(task_id, api_base, api_key, model_id, system_prompt):
     logs.append(msg)
     yield {'type': 'log', 'level': 'info', 'msg': msg}

-    # …
+    # Weighted blend scoring — same as inference.py _compute_score()
+    if rewards:
+        max_r = max(rewards)
+        mean_r = sum(rewards) / len(rewards)
+        score = round(min(max(0.60 * max_r + 0.40 * mean_r, 0.01), 0.99), 4)
+    else:
+        score = 0.01
+    success = any(r > 0.10 for r in rewards)
     rewards_str = ','.join(f'{r:.2f}' for r in rewards)

-    …
+    # [END] line — NO score= field (not in official spec)
+    msg = f'[END] success={str(success).lower()} steps={len(rewards)} rewards={rewards_str}'
     logs.append(msg)
     yield {'type': 'log', 'level': 'ok', 'msg': msg}
     yield {'type': 'task_done', 'task_id': task_id, 'score': score, 'logs': logs}
server/datasets/dependency_cases.py
CHANGED
@@ -1,421 +1,415 @@
 # server/datasets/dependency_cases.py
 # Ground truth cases for PyTorch Migration Time-Machine tasks.
 #
-# …
+# CRITICAL FIX:
+# dep_hard previously had:
+#   done_conditions: {min_actions: 2, required_sequence: ['migrate_api', 'migrate_api']}
+#
+# This caused TWO bugs:
+#   1. The agent called migrate_api once. Router checked Counter: needs 2, has 1 → not done.
+#   2. Agent called migrate_api again → repetition_penalty fires (-0.20), tanking the score.
+#   3. Episode only ends at max_steps with a broken accumulated score.
+#
+# FIX: dep_hard now uses min_actions=1, required_sequence=['migrate_api'].
+# The task is already hard enough from the grader — complex checklist, ordering
+# constraints, and exact token matching in fix_quality. The done condition
+# should not add extra difficulty on top of this.
+#
+# ALL dep_easy, dep_medium, dep_hard done conditions verified below.

 DEPENDENCY_CASES = {
+
+    # ── DEP EASY ─────────────────────────────────────────────────────────
+    # Task: flag outdated packages and deprecated API usage.
+    # Done: after 1 flag_outdated action.
+    # Grader: F1 on packages (precision+recall) × 0.55 + deprecated_api_match × 0.45
+    # ─────────────────────────────────────────────────────────────────────
     'dep_easy': [
         {
             'case_id': 'dep_easy_001',
             'task_subtype': 'flag',
-            'completion_threshold': 0.…,
+            'completion_threshold': 0.75,
             'max_steps': 4,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
+            'task_description': (
+                'This codebase uses torch==1.9.0 and relies on torch.autograd.Variable. '
+                'Flag all outdated packages and the deprecated API.'
+            ),
+            'code_snippet': (
+                'import torch\n'
+                'from torch.autograd import Variable\n'
+                'x = Variable(torch.randn(3, 4))\n'
+                'model = torch.nn.Linear(4, 2)\n'
+                'out = model(x)'
+            ),
+            'requirements': {'torch': '1.9.0', 'torchvision': '0.10.0'},
+            'expected_outdated_packages': ['torch', 'torchvision'],
             'expected_deprecated_api': 'torch.autograd.Variable',
+            'expected_replacement': 'plain tensor with requires_grad=True',
-            'code_snippet': '''import torch
-from torch.autograd import Variable
-
-x = Variable(torch.randn(3, 4), requires_grad=True)
-y = Variable(torch.randn(3, 4))
-z = x + y''',
-            'task_description': 'Identify outdated PyTorch packages and deprecated APIs in this legacy training script. List the exact package name and deprecated API call.',
         },
         {
             'case_id': 'dep_easy_002',
             'task_subtype': 'flag',
-            'completion_threshold': 0.…,
+            'completion_threshold': 0.75,
             'max_steps': 4,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
+            'task_description': (
+                'This codebase uses torch==1.4.0 and calls .cuda() directly. '
+                'Flag outdated packages and the deprecated device assignment pattern.'
+            ),
+            'code_snippet': (
+                'import torch\n'
+                'model = MyModel()\n'
+                'model.cuda()  # deprecated — use .to(device)\n'
+                'tensor = torch.randn(2, 3).cuda()'
+            ),
+            'requirements': {'torch': '1.4.0'},
             'expected_outdated_packages': ['torch'],
-            'expected_deprecated_api': '…',
-            'code_snippet': '''import torch
-
-model = torch.nn.Linear(10, 5)
-x = torch.randn(1, 10)
-output = model(x)
-result = output.data.numpy()  # deprecated''',
-            'task_description': 'Find the exact deprecated tensor conversion API in this code. Provide the exact deprecated call.',
+            'expected_deprecated_api': '.cuda()',
+            'expected_replacement': '.to(device)',
         },
         {
             'case_id': 'dep_easy_003',
             'task_subtype': 'flag',
-            'completion_threshold': 0.…,
+            'completion_threshold': 0.75,
             'max_steps': 4,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
+            'task_description': (
+                'This codebase uses torch==1.7.0 with DataParallel. '
+                'Flag the outdated package and the deprecated multi-GPU API.'
+            ),
+            'code_snippet': (
+                'import torch\n'
+                'model = torch.nn.DataParallel(MyModel())\n'
+                'model.cuda()'
+            ),
+            'requirements': {'torch': '1.7.0', 'numpy': '1.18.0'},
+            'expected_outdated_packages': ['torch', 'numpy'],
+            'expected_deprecated_api': 'torch.nn.DataParallel',
+            'expected_replacement': 'DistributedDataParallel',
         },
         {
             'case_id': 'dep_easy_004',
             'task_subtype': 'flag',
-            'completion_threshold': 0.…,
             'max_steps': 4,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
             'expected_outdated_packages': ['torch'],
-            'expected_deprecated_api': '…',
-            'code_snippet': '''import torch
-
-model = torch.nn.Linear(10, 5)
-dummy = torch.randn(1, 10)
-torch.onnx.export(model, dummy, "model.onnx",
-                  opset_version=11)''',
-            'task_description': 'Find the deprecated ONNX export API. Specify the exact deprecated function.',
         },
         {
             'case_id': 'dep_easy_005',
             'task_subtype': 'flag',
-            'completion_threshold': 0.…,
             'max_steps': 4,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
-            'code_snippet': '''import torch
-…''',
         },
     ],
     'dep_medium': [
         {
             'case_id': 'dep_medium_001',
             'task_subtype': 'resolve',
-            'completion_threshold': 0.…,
-            'max_steps': …,
-            # FIX: min_actions=1 is correct for resolve (1 action needed)
-            # but now the grader is tighter so passing takes real work
             'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
             'compatibility_matrix': {
                 'torch': {
-                    '2.1.0': {'numpy': '>=1.…'},
-                    '2.0.0': {'numpy': '>=1.…'},
-                    '1.13.0': {'numpy': '>=1.19,<1.25'},
                 },
                 'numpy': {
-                    '1.26.0': {},
                     '1.24.0': {},
                 },
             },
-            'requirements': {'torch': '1.9.0', 'numpy': '1.16.0'},
-            'code_snippet': '''# requirements.txt
-torch==1.9.0
-numpy==1.16.0
-torchvision==0.10.0''',
-            'task_description': 'Resolve the version conflict between torch and numpy. Use the compatibility_matrix to find valid versions where ALL cross-constraints are satisfied.',
         },
         {
             'case_id': 'dep_medium_002',
             'task_subtype': 'resolve',
-            'completion_threshold': 0.…,
-            'max_steps': …,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
             'compatibility_matrix': {
-                'torch': {
-                    '2.0.0': {'numpy': '>=1.22,<1.26', 'torchvision': '>=0.15,<0.16'},
                 },
-                'numpy': {
-                    '1.22.0': {},
                 },
             },
-            'requirements': {'torch': '1.12.0', 'numpy': '1.21.0', 'torchvision': '0.13.0'},
-            'code_snippet': '''# requirements.txt
-torch==1.12.0
-numpy==1.21.0
-torchvision==0.13.0
-# CUDA 11.7''',
-            'task_description': 'Resolve three-way conflict between PyTorch, NumPy, and TorchVision. Note: torchvision 0.16 requires torch >=2.1 AND <2.2. Check ALL constraints carefully.',
         },
         {
             'case_id': 'dep_medium_003',
             'task_subtype': 'resolve',
-            'completion_threshold': 0.…,
-            'max_steps': …,
             'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
             'compatibility_matrix': {
-                'torch': {
-                    '2.1.0': {'transformers': '>=4.35,<4.38'},  # FIX: upper bound added
-                    '2.0.0': {'transformers': '>=4.30,<4.36'},
-                },
                 'transformers': {
-                    '4.…': {},
-                    '4.…': {},
                 },
             },
-            'requirements': {'torch': '1.11.0', 'transformers': '4.20.0'},
-            'code_snippet': '''# requirements.txt
-torch==1.11.0
-transformers==4.20.0''',
-            'task_description': 'Resolve conflict between PyTorch and Transformers. Note the upper bounds in the compatibility matrix — not all combinations work.',
         },
     ],
     'dep_hard': [
         {
             'case_id': 'dep_hard_001',
             'task_subtype': 'migrate',
-            'completion_threshold': 0.60,
-            'max_steps': …,
+            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
             'graph_breaks': ['break_001', 'break_002', 'break_003'],
             'checklist_dependency_graph': {
-                'break_003': ['break_001', 'break_002'],
-                'break_002': ['break_001'],
-                'break_001': [],
             },
             'correct_fix_map': {
                 'break_001': 'torch.where',
-                'break_002': '…',
-                'break_003': '.…',
             },
-            'code_snippet': '''import torch
-
-@torch.compile(fullgraph=True)
-def forward(x):
-    # break_001: data-dependent branch
-    if x.max().item() > 1.0:
-        x = x / x.max()
-    # break_002: Python len() on tensor
-    n = len(x)
-    # break_003: .data.numpy() deprecated
-    result = x.data.numpy()
-    return result''',
-            'break_descriptions': [
-                'break_001: data-dependent control flow — use torch.where()',
-                'break_002: len() on tensor — use tensor.shape[0]',
-                'break_003: .data.numpy() — use .detach().numpy()',
-            ],
-            'graph_break_report': [
-                'break_001: data-dependent control flow — use torch.where()',
-                'break_002: len() on tensor — use tensor.shape[0]',
-                'break_003: .data.numpy() — use .detach().numpy()',
-            ],
-            'task_description': 'Fix 3 graph-break patterns in this compiled forward pass. Break_002 depends on break_001. Break_003 depends on both. Fix in dependency order.',
         },
         {
             'case_id': 'dep_hard_002',
             'task_subtype': 'migrate',
             'completion_threshold': 0.60,
-            'max_steps': …,
-            'done_conditions': {'min_actions': 2, …},
+            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
             'checklist_dependency_graph': {
-                'break_d': ['break_b', 'break_c'],
-                'break_c': ['break_a'],
-                'break_b': [],
-                'break_a': [],
             },
             'correct_fix_map': {
-                'break_a': '…',
-                'break_b': '…',
-                'break_c': '…',
-                'break_d': '.detach()',
             },
-            'code_snippet': '''import torch
-
-@torch.compile(fullgraph=True)
-def training_step(model, x, labels):
-    # break_a: data-dependent branch
-    if x.max().item() > 1.0:
-        x = x / x.max()
-    # break_b: Python len() on tensor
-    n_samples = len(x)
-    # break_c: Python list to tensor inside compile
-    weights = torch.FloatTensor([1.0, 2.0, 3.0])
-    # break_d: in-place operation on leaf tensor
-    x += 0.1
-    output = model(x)
-    loss = torch.nn.functional.cross_entropy(output, labels)
-    return loss''',
-            'break_descriptions': [
-                'break_a: line 6 — data-dependent: if x.max().item() > 1.0',
-                'break_b: line 10 — Python builtin: len(x)',
-                'break_c: line 13 — legacy constructor: torch.FloatTensor()',
-                'break_d: line 16 — in-place op on leaf: x += 0.1',
-            ],
-            'graph_break_report': [
-                'break_a: line 6 — data-dependent: if x.max().item() > 1.0',
-                'break_b: line 10 — Python builtin: len(x)',
-                'break_c: line 13 — legacy constructor: torch.FloatTensor()',
-                'break_d: line 16 — in-place op on leaf: x += 0.1',
-            ],
-            'task_description': 'Fix all 4 graph-break patterns in this compiled training step. Break_d depends on break_b AND break_c. Break_c depends on break_a. Fix in dependency order.',
         },
         {
             'case_id': 'dep_hard_003',
             'task_subtype': 'migrate',
             'completion_threshold': 0.60,
-            'max_steps': …,
-            'done_conditions': {'min_actions': 2, …},
+            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
             'checklist_dependency_graph': {
-                'break_z': ['break_x'],
-                'break_y': […],
-                'break_x': [],
             },
             'correct_fix_map': {
-                'break_x': '…',
-                'break_y': '…',
-                'break_z': '…',
             },
-            'code_snippet': '''import torch
-
-@torch.compile
-def forward(x, mask):
-    # break_x: tensor.size() returns Python int (graph break)
-    n = x.size(0) * x.size(1)
-    # break_y: Python function call inside compile
-    def custom_fn(t):
-        return t * 2
-    x = custom_fn(x)
-    # break_z: gradient tracking inside compiled region
-    with torch.enable_grad():
-        x = x * mask
-    return x''',
-            'break_descriptions': [
-                'break_x: line 6 — tensor.size() returns Python int, use tensor.numel()',
-                'break_y: line 10 — Python function call, use torch.jit.script decorator',
-                'break_z: line 14 — enable_grad inside compile, use torch.no_grad()',
-            ],
-            'graph_break_report': [
-                'break_x: line 6 — tensor.size() returns Python int, use tensor.numel()',
-                'break_y: line 10 — Python function call, use torch.jit.script decorator',
-                'break_z: line 14 — enable_grad inside compile, use torch.no_grad()',
-            ],
-            'task_description': 'Fix torch.compile graph breaks. break_z needs break_x fixed first.',
         },
         {
             'case_id': 'dep_hard_004',
             'task_subtype': 'migrate',
             'completion_threshold': 0.60,
-            'max_steps': …,
-            'done_conditions': {'min_actions': 2, …},
+            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
             'checklist_dependency_graph': {
-                'break_delta': ['break_beta', 'break_gamma'],
-                'break_gamma': ['break_alpha'],
-                'break_beta': [],
-                'break_alpha': [],
             },
             'correct_fix_map': {
-                'break_alpha': '…',
-                'break_beta': '…',
-                'break_gamma': '…',
-                'break_delta': '…',
             },
-            'code_snippet': '''import torch
-
-@torch.compile(fullgraph=True)
-def loss_fn(pred, target, weights):
-    # break_alpha: if statement on tensor value
-    if target.sum() > 0:
-        pred = pred * 1.5
-    # break_beta: len() on tensor
-    batch_size = len(pred)
-    # break_gamma: Python list → tensor conversion
-    normalized = []
-    for i in range(batch_size):
-        normalized.append(pred[i] / weights[i])
-    result = torch.tensor(normalized)
-    # break_delta: calls non-scripted helper
-    def helper(x):
-        return x.clamp(0, 1)
-    return helper(result)''',
-            'break_descriptions': [
-                'break_alpha: line 6 — data-dependent control flow, use torch.where()',
-                'break_beta: line 10 — len() builtin on tensor, use tensor.shape[0]',
-                'break_gamma: line 16 — torch.tensor() on Python list, use torch.stack()',
-                'break_delta: line 20 — unscripted helper, add @torch.jit.script',
-            ],
-            'graph_break_report': [
-                'break_alpha: line 6 — data-dependent control flow, use torch.where()',
-                'break_beta: line 10 — len() builtin on tensor, use tensor.shape[0]',
-                'break_gamma: line 16 — torch.tensor() on Python list, use torch.stack()',
-                'break_delta: line 20 — unscripted helper, add @torch.jit.script',
-            ],
-            'task_description': 'Complex graph-break cascade. Delta depends on Beta AND Gamma. Gamma depends on Alpha. Fix in dependency order.',
         },
         {
             'case_id': 'dep_hard_005',
             'task_subtype': 'migrate',
             'completion_threshold': 0.60,
-            'max_steps': …,
-            'done_conditions': {'min_actions': 2, …},
+            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
             'graph_breaks': ['break_001', 'break_002', 'break_003'],
             'checklist_dependency_graph': {
-                'break_003': ['break_001', 'break_002'],
-                'break_002': [],
-                'break_001': [],
             },
             'correct_fix_map': {
-                'break_001': 'torch.…',
-                'break_002': '…',
-                'break_003': 'torch.…',
             },
-            'code_snippet': '''import torch
-from torch.nn.utils import clip_grad_norm_
-
-@torch.compile
-def training_step(model, batch, optimizer):
-    loss = model(batch['x'], batch['y'])
-    loss.backward()
-    optimizer.step()  # graph break
-    grads = []
-    for param in model.parameters():
-        grads.append(param.grad.norm())
-    clip_grad_norm_(model.parameters(), max_norm=1.0)
-    return loss.item()''',
-            'break_descriptions': [
-                'break_001: optimizer.step() not compilable, use torch.compile(disable=True)',
-                'break_002: Python loop batching, use functorch.vmap',
-                'break_003: in-place grad clipping, use torch.export',
-            ],
-            'graph_break_report': [
-                'break_001: optimizer.step() not compilable, use torch.compile(disable=True)',
-                'break_002: Python loop batching, use functorch.vmap',
-                'break_003: in-place grad clipping, use torch.export',
-            ],
-            'task_description': 'Fix training loop graph breaks. Optimizer, gradient accumulation, and clipping all cause compilation failures. Break_003 needs both others first.',
         },
     ],
 }
'expected_replacement': 'DistributedDataParallel',
|
| 90 |
},
|
| 91 |
{
|
| 92 |
'case_id': 'dep_easy_004',
|
| 93 |
'task_subtype': 'flag',
|
| 94 |
+
'completion_threshold': 0.75,
|
| 95 |
'max_steps': 4,
|
| 96 |
'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
|
| 97 |
+
'task_description': (
|
| 98 |
+
'Flag outdated packages and the deprecated ONNX export API in this code.'
|
| 99 |
+
),
|
| 100 |
+
'code_snippet': (
|
| 101 |
+
'import torch\n'
|
| 102 |
+
'torch.onnx.export(model, dummy_input, "model.onnx",\n'
|
| 103 |
+
' opset_version=9,\n'
|
| 104 |
+
' enable_onnx_checker=True) # deprecated kwarg'
|
| 105 |
+
),
|
| 106 |
+
'requirements': {'torch': '1.8.0'},
|
| 107 |
'expected_outdated_packages': ['torch'],
|
| 108 |
+
'expected_deprecated_api': 'enable_onnx_checker',
|
| 109 |
+
'expected_replacement': 'remove the kwarg (deprecated in 1.9, removed in 2.0)',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
},
|
| 111 |
{
|
| 112 |
'case_id': 'dep_easy_005',
|
| 113 |
'task_subtype': 'flag',
|
| 114 |
+
'completion_threshold': 0.75,
|
| 115 |
'max_steps': 4,
|
| 116 |
'done_conditions': {'min_actions': 1, 'required_sequence': ['flag_outdated']},
|
| 117 |
+
'task_description': (
|
| 118 |
+
'Flag outdated packages and the deprecated autocast API.'
|
| 119 |
+
),
|
| 120 |
+
'code_snippet': (
|
| 121 |
+
'import torch\n'
|
| 122 |
+
'from torch.cuda.amp import autocast\n'
|
| 123 |
+
'with autocast(): # deprecated import path\n'
|
| 124 |
+
' output = model(input)'
|
| 125 |
+
),
|
| 126 |
+
'requirements': {'torch': '1.6.0', 'torchaudio': '0.6.0'},
|
| 127 |
+
'expected_outdated_packages': ['torch', 'torchaudio'],
|
| 128 |
+
'expected_deprecated_api': 'torch.cuda.amp.autocast',
|
| 129 |
+
'expected_replacement': 'torch.amp.autocast',
|
| 130 |
},
|
| 131 |
],
    # ── DEP MEDIUM ────────────────────────────────────────────────────────
    # Task: resolve version conflicts using the compatibility_matrix.
    # Done: after 1 resolve_conflict action.
    # Grader: valid_pkgs/conflict_count + cross-constraint check - downgrade penalty
    # ──────────────────────────────────────────────────────────────────────
    'dep_medium': [
        {
            'case_id': 'dep_medium_001',
            'task_subtype': 'resolve',
            'completion_threshold': 0.70,
            'max_steps': 4,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
            'task_description': (
                'Resolve the version conflict between torch, numpy, and protobuf. '
                'Use the compatibility_matrix to find a compatible set of versions.'
            ),
            'code_snippet': 'requirements.txt with conflicting torch==2.0.0, numpy==1.20.0, protobuf==3.9.0',
            'requirements': {'torch': '2.0.0', 'numpy': '1.20.0', 'protobuf': '3.9.0'},
            'conflict_packages': ['torch', 'numpy', 'protobuf'],
            'compatibility_matrix': {
                'torch': {
                    '2.1.0': {'numpy': '>=1.21,<2.0', 'protobuf': '>=3.20,<5.0'},
                    '2.0.0': {'numpy': '>=1.20,<1.25', 'protobuf': '>=3.19,<4.0'},
                },
                'numpy': {
                    '1.24.0': {},
                    '1.21.0': {},
                    '1.20.0': {},
                },
                'protobuf': {
                    '4.23.0': {},
                    '3.20.0': {},
                    '3.9.0': {'torch': '<=1.13'},
                },
            },
        },
        {
            'case_id': 'dep_medium_002',
            'task_subtype': 'resolve',
            'completion_threshold': 0.70,
            'max_steps': 4,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
            'task_description': (
                'Resolve the version conflict between tensorflow, keras, and h5py.'
            ),
            'code_snippet': 'requirements.txt: tensorflow==2.10.0, keras==2.10.0, h5py==2.10.0',
            'requirements': {'tensorflow': '2.10.0', 'keras': '2.10.0', 'h5py': '2.10.0'},
            'conflict_packages': ['tensorflow', 'keras', 'h5py'],
            'compatibility_matrix': {
                'tensorflow': {
                    '2.13.0': {'keras': '>=2.13,<2.14', 'h5py': '>=3.7'},
                    '2.10.0': {'keras': '==2.10.0', 'h5py': '>=3.1'},
                },
                'keras': {
                    '2.13.0': {'tensorflow': '>=2.13,<2.14'},
                    '2.10.0': {'tensorflow': '==2.10.0'},
                },
                'h5py': {
                    '3.9.0': {},
                    '3.7.0': {},
                    '2.10.0': {'tensorflow': '<=2.3'},
                },
            },
        },
        {
            'case_id': 'dep_medium_003',
            'task_subtype': 'resolve',
            'completion_threshold': 0.70,
            'max_steps': 4,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['resolve_conflict']},
            'task_description': (
                'Resolve the conflict between transformers, tokenizers, and datasets packages.'
            ),
            'code_snippet': 'requirements: transformers==4.20.0, tokenizers==0.11.0, datasets==1.18.0',
            'requirements': {'transformers': '4.20.0', 'tokenizers': '0.11.0', 'datasets': '1.18.0'},
            'conflict_packages': ['transformers', 'tokenizers', 'datasets'],
            'compatibility_matrix': {
                'transformers': {
                    '4.35.0': {'tokenizers': '>=0.14,<0.19', 'datasets': '>=2.14'},
                    '4.20.0': {'tokenizers': '>=0.11,<0.14', 'datasets': '>=1.18'},
                },
                'tokenizers': {
                    '0.15.0': {'transformers': '>=4.28'},
                    '0.14.0': {'transformers': '>=4.25'},
                    '0.11.0': {},
                },
                'datasets': {
                    '2.14.0': {},
                    '2.10.0': {},
                    '1.18.0': {'tokenizers': '<=0.13'},
                },
            },
        },
    ],
    # ── DEP HARD ──────────────────────────────────────────────────────────
    # Task: fix torch.compile graph-break patterns.
    # Done: after 1 migrate_api action (FIXED from 2 -> 1).
    #
    # IMPORTANT: min_actions=1, required_sequence=['migrate_api']
    # The grader already makes this hard through:
    #   - Multiple graph_breaks to fix (3-5 per case)
    #   - Ordering constraints via checklist_dependency_graph
    #   - Exact token matching in fix_quality
    # We do NOT need the done condition to create artificial difficulty.
    # ──────────────────────────────────────────────────────────────────────
    'dep_hard': [
        {
            'case_id': 'dep_hard_001',
            'task_subtype': 'migrate',
            'completion_threshold': 0.60,
            'max_steps': 6,
            # FIXED: was min_actions=2, required_sequence=['migrate_api', 'migrate_api'],
            # which caused a repetition penalty on the 2nd call and never terminated cleanly
            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
            'task_description': (
                'Fix the torch.compile graph-break patterns in this training loop. '
                'Provide completed_items (list of break IDs) and code_changes (dict of fixes).'
            ),
            'code_snippet': (
                'import torch\n\n'
                'def train_step(model, x):\n'
                '    out = model(x)\n'
                '    if out.shape[0] != x.shape[0]:  # data-dependent branch [break_001]\n'
                '        out = torch.zeros_like(x)\n'
                '    idx = int(out.argmax())  # int() conversion [break_002]\n'
                '    mask = out > 0.5  # dynamic masking [break_003]\n'
                '    return out[mask].sum()\n'
            ),
            'graph_break_report': [
                'break_001: data-dependent control flow (if out.shape[0] != x.shape[0])',
                'break_002: Python int() call on tensor (int(out.argmax()))',
                'break_003: dynamic boolean indexing (out[mask])',
            ],
            'graph_breaks': ['break_001', 'break_002', 'break_003'],
            'checklist_dependency_graph': {
                'break_003': ['break_002'],  # must fix int() conversion before mask
            },
            'correct_fix_map': {
                'break_001': 'torch.where',
                'break_002': 'torch.argmax',
                'break_003': 'torch.masked_select',
            },
        },
        {
            'case_id': 'dep_hard_002',
            'task_subtype': 'migrate',
            'completion_threshold': 0.60,
            'max_steps': 6,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
            'task_description': (
                'Fix these torch.compile graph-breaks in a model forward pass.'
            ),
            'code_snippet': (
                'def forward(self, x):\n'
                '    x = self.conv(x)\n'
                '    size = x.size(0)  # .size() with int [break_001]\n'
                '    out = x.numpy()  # .numpy() call [break_002]\n'
                '    out = torch.from_numpy(out)\n'
                '    return out[:size//2]  # dynamic slice [break_003]\n'
            ),
            'graph_break_report': [
                'break_001: .size() call returning Python int',
                'break_002: .numpy() call breaks compilation boundary',
                'break_003: dynamic slicing with Python division',
            ],
            'graph_breaks': ['break_001', 'break_002', 'break_003'],
            'checklist_dependency_graph': {
                'break_003': ['break_001'],
            },
            'correct_fix_map': {
                'break_001': 'tensor.shape[0]',
                'break_002': 'detach',
                'break_003': 'torch.narrow',
            },
        },
        {
            'case_id': 'dep_hard_003',
            'task_subtype': 'migrate',
            'completion_threshold': 0.60,
            'max_steps': 6,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
            'task_description': (
                'Fix torch.compile graph-breaks in this attention implementation.'
            ),
            'code_snippet': (
                'def attention(q, k, v):\n'
                '    scores = torch.matmul(q, k.transpose(-2, -1))\n'
                '    if scores.max() > 100:  # data-dependent branch [break_001]\n'
                '        scores = scores / 100\n'
                '    weights = scores.numpy()  # numpy call [break_002]\n'
                '    weights = torch.softmax(torch.tensor(weights), dim=-1)\n'
                '    n = int(q.shape[0])  # Python int [break_003]\n'
                '    return weights[:n] @ v\n'
            ),
            'graph_break_report': [
                'break_001: data-dependent branch on scores.max()',
                'break_002: .numpy() breaks torch.compile boundary',
                'break_003: Python int() on tensor dimension',
            ],
            'graph_breaks': ['break_001', 'break_002', 'break_003'],
            'checklist_dependency_graph': {
                'break_003': ['break_001'],
                'break_002': ['break_001'],
            },
            'correct_fix_map': {
                'break_001': 'torch.clamp',
                'break_002': 'torch.softmax',
                'break_003': 'tensor.shape',
            },
        },
        {
            'case_id': 'dep_hard_004',
            'task_subtype': 'migrate',
            'completion_threshold': 0.60,
            'max_steps': 6,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
            'task_description': (
                'Fix four torch.compile graph-breaks in this training utility.'
            ),
            'code_snippet': (
                'def process_batch(batch):\n'
                '    lengths = [len(x) for x in batch]  # Python list comp [break_001]\n'
                '    max_len = max(lengths)  # Python max() [break_002]\n'
                '    padded = torch.zeros(len(batch), max_len)\n'
                '    for i, x in enumerate(batch):  # Python loop [break_003]\n'
                '        padded[i, :len(x)] = x\n'
                '    out = model(padded)\n'
                '    return out.cpu().numpy()  # .numpy() [break_004]\n'
            ),
            'graph_break_report': [
                'break_001: Python list comprehension over tensor data',
                'break_002: Python max() on list of tensor values',
                'break_003: Python for loop with tensor indexing',
                'break_004: .numpy() call at output',
            ],
            'graph_breaks': ['break_001', 'break_002', 'break_003', 'break_004'],
            'checklist_dependency_graph': {
                'break_002': ['break_001'],
                'break_003': ['break_002'],
            },
            'correct_fix_map': {
                'break_001': 'torch.tensor',
                'break_002': 'torch.max',
                'break_003': 'torch.nn.utils.rnn.pad_sequence',
                'break_004': 'detach',
            },
        },
        {
            'case_id': 'dep_hard_005',
            'task_subtype': 'migrate',
            'completion_threshold': 0.60,
            'max_steps': 6,
            'done_conditions': {'min_actions': 1, 'required_sequence': ['migrate_api']},
            'task_description': (
                'Fix torch.compile graph-breaks caused by vmap incompatibilities.'
            ),
            'code_snippet': (
                'from torch._vmap_internals import vmap  # deprecated [break_001]\n'
                'import functorch  # deprecated module [break_002]\n\n'
                'def batched_fn(x):\n'
                '    result = vmap(model)(x)\n'
                '    if result.isnan().any():  # data-dependent check [break_003]\n'
                '        result = torch.zeros_like(result)\n'
                '    return result\n'
            ),
            'graph_break_report': [
                'break_001: torch._vmap_internals.vmap is deprecated (use torch.vmap)',
                'break_002: functorch module is deprecated (merged into torch)',
                'break_003: data-dependent .any() check breaks compilation',
            ],
            'graph_breaks': ['break_001', 'break_002', 'break_003'],
            'checklist_dependency_graph': {
                'break_002': ['break_001'],
            },
            'correct_fix_map': {
                'break_001': 'torch.vmap',
                'break_002': 'torch.func',
                'break_003': 'torch.where',
            },
        },
    ],
}
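The Counter-based router check described in the dep_hard comments can be sketched as follows. The function name `meets_done_conditions` is an assumption for illustration, not the environment's actual router code; it shows why the old `min_actions=2` config forced a penalised repeat call:

```python
from collections import Counter

def meets_done_conditions(actions_taken, done_conditions):
    """Hypothetical router check: True once the episode may terminate.

    `actions_taken` is the ordered list of action_type strings the agent
    has emitted; `done_conditions` uses the case format defined above.
    """
    if len(actions_taken) < done_conditions.get('min_actions', 1):
        return False
    # Each required action must occur at least as often as it is listed.
    needed = Counter(done_conditions.get('required_sequence', []))
    seen = Counter(actions_taken)
    return all(seen[a] >= n for a, n in needed.items())

# Old dep_hard config: one migrate_api call is not enough, so the agent is
# forced into a second call that triggers the repetition penalty.
old = {'min_actions': 2, 'required_sequence': ['migrate_api', 'migrate_api']}
new = {'min_actions': 1, 'required_sequence': ['migrate_api']}
assert not meets_done_conditions(['migrate_api'], old)
assert meets_done_conditions(['migrate_api'], new)
```

Under the old condition the only clean exit was a repeated migrate_api call, which is exactly the pattern repetition_penalty punishes; the new condition removes that conflict.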
|
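For illustration, checking a proposed resolution against the dep_medium `compatibility_matrix` format could look like this. `satisfies` and `is_compatible` are hypothetical helpers with a naive dotted-integer version comparison, not the grader's actual implementation:

```python
def satisfies(version, spec):
    """Check a version string against a spec like '>=1.21,<2.0' or '==2.10.0'.

    Naive comparator: versions are dotted integers, specs are
    comma-separated clauses -- just enough for the matrices above.
    """
    def parse(s):
        return tuple(int(p) for p in s.split('.'))
    for clause in spec.split(','):
        clause = clause.strip()
        for op in ('>=', '<=', '==', '>', '<'):
            if clause.startswith(op):
                ver, bound = parse(version), parse(clause[len(op):])
                n = max(len(ver), len(bound))  # pad so '2.0' compares like '2.0.0'
                ver += (0,) * (n - len(ver))
                bound += (0,) * (n - len(bound))
                if not {'>=': ver >= bound, '<=': ver <= bound, '==': ver == bound,
                        '>': ver > bound, '<': ver < bound}[op]:
                    return False
                break
    return True

def is_compatible(proposed, matrix):
    """True if every proposed pin's constraints accept every other pin."""
    for pkg, ver in proposed.items():
        for other, spec in matrix.get(pkg, {}).get(ver, {}).items():
            if other in proposed and not satisfies(proposed[other], spec):
                return False
    return True

# dep_medium_001's matrix, abbreviated to the rows exercised here.
matrix = {
    'torch': {'2.1.0': {'numpy': '>=1.21,<2.0', 'protobuf': '>=3.20,<5.0'}},
    'protobuf': {'3.9.0': {'torch': '<=1.13'}},
}
assert is_compatible({'torch': '2.1.0', 'numpy': '1.24.0', 'protobuf': '4.23.0'}, matrix)
assert not is_compatible({'torch': '2.1.0', 'numpy': '1.20.0', 'protobuf': '3.9.0'}, matrix)
```

A production resolver would use PEP 440 semantics (e.g. `packaging.specifiers`) instead of this tuple comparison, but the cross-constraint check the grader comments describe has the same shape.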
server/graders/base_grader.py
CHANGED
|
@@ -1,38 +1,86 @@
 # server/graders/base_grader.py
 # Core grading utilities used by ALL domain graders.
-#
-#

 from typing import Dict, Any, List, Callable


 def safe_score(raw) -> float:
-    """..."""
     if raw is None:
         return 0.01
     try:
         val = float(raw)
-        # FIX: Don't round aggressively -- keep 4 decimal places so variance is visible
         return round(max(0.01, min(0.99, val)), 4)
     except (TypeError, ValueError):
         return 0.01


 def repetition_penalty(action_type: str, last_actions: List[str], window: int = 3) -> float:
-    """..."""
     count = last_actions[-window:].count(action_type)
-    # FIX: Increased penalty from -0.15 to -0.20 per repeat so it actually stings
     return -0.20 * count


 def invalid_action_penalty(action_type: str, valid_actions: List[str]) -> float:
-    """..."""
-
     return -0.40 if action_type not in valid_actions else 0.0


 def harmful_output_penalty(action: Dict, forbidden_patterns: List[str]) -> float:
-    """..."""
     action_str = str(action).lower()
     for p in forbidden_patterns:
         if p.lower() in action_str:

@@ -41,70 +89,65 @@ def harmful_output_penalty(action: Dict, forbidden_patterns: List[str]) -> float


 def efficiency_bonus(step_count: int, max_steps: int, done: bool) -> float:
-    """Small bonus for finishing early. FIX: reduced from 0.10 to 0.05 so it doesn't
-    inflate scores -- the correctness score should be the main signal."""
-    return 0.05 if done and step_count < max_steps // 2 else 0.0
-
-
-def difficulty_multiplier(task_id: str) -> float:
-    """
-    ...
-    - medium tasks: a perfect answer gets 0.90 max (10% cap)
-    - hard tasks: a perfect answer gets 0.80 max (20% cap) -- they're SUPPOSED to be hard
-
-    This ensures there's real spread between easy/medium/hard scores.
-    """
-    if 'hard' in task_id:
-        return 0.80
-    elif 'medium' in task_id:
-        return 0.90
-    else:
-        return 0.99  # easy -- allow near-perfect


 def grade_dynamic(
-    action: Dict[str, Any],
     session,
     compute_correctness_fn: Callable,
-    valid_actions: List[str],
-    forbidden_patterns: List[str] = None,
-    max_steps: int = 8,
 ) -> float:
-    """..."""
     if forbidden_patterns is None:
         forbidden_patterns = []

     action_type = action.get('action_type', 'unknown')

-    #
-    inv = invalid_action_penalty(action_type, valid_actions)
-    rep = repetition_penalty(action_type, session.last_actions)
-    harm = harmful_output_penalty(action, forbidden_patterns)
-
-    # If action type is invalid, skip the grader entirely
     if inv < 0:
         return safe_score(inv + rep)

-    #
     correctness = compute_correctness_fn(action, session.task_case)
-
     if correctness is None:
-        correctness = 0.

-    #
-    max_allowed = difficulty_multiplier(task_id)
-    correctness = min(correctness, max_allowed)

-    #
     eff = efficiency_bonus(session.step_count + 1, max_steps, correctness >= 0.75)

-    # Combine and clamp
     raw = correctness + rep + harm + eff
     return safe_score(raw)
# server/graders/base_grader.py
# Core grading utilities used by ALL domain graders.
#
# CHANGES FROM PREVIOUS VERSION:
# 1. difficulty_multiplier() -- REMOVED ENTIRELY.
#    The cap (hard -> 0.80, medium -> 0.90) made every hard task score identically
#    at 0.80 and every medium task at 0.90, regardless of agent quality.
#    This is exactly the wrong behaviour for an RL training environment:
#    GRPO needs variance WITHIN difficulty levels, not a uniform ceiling.
#    Task difficulty now comes from the grader logic and case design alone.
#
# 2. safe_score range: [0.01, 0.99]
#    The official spec says "strictly between 0 and 1".
#    Discord consensus from many participants confirmed 0.01/0.99 as the
#    correct interpretation. Do not change this back to [0.0, 1.0].
#
# 3. Penalty values kept as-is (increased in last revision):
#    - repetition_penalty: -0.20 per repeat (was -0.15)
#    - invalid_action_penalty: -0.40 for wrong domain action (was -0.20)
#    - harmful_output_penalty: -0.50 for destructive patterns
#    These are intentionally higher to create real signal.
#
# 4. efficiency_bonus reduced to 0.05 (was 0.10).
#    Small enough that it doesn't inflate scores, but still rewards
#    agents that solve tasks efficiently.

from typing import Dict, Any, List, Callable


def safe_score(raw) -> float:
    """
    Clamp score to [0.01, 0.99]. Never crash. Returns float.

    WHY [0.01, 0.99] NOT [0.0, 1.0]:
    - Official spec says scores must be strictly between 0 and 1
    - Discord confirmed 0.01/0.99 as the correct practical interpretation
    - A score of exactly 0.0 from a broken run looks like a crash
    - A score of exactly 1.0 means the grader is trivially solved

    WHY 4 DECIMAL PLACES:
    - Keeps variance visible (0.4500 vs 0.4750 are meaningfully different)
    - round() handles float precision artifacts
    """
    if raw is None:
        return 0.01
    try:
        val = float(raw)
        return round(max(0.01, min(0.99, val)), 4)
    except (TypeError, ValueError):
        return 0.01


def repetition_penalty(action_type: str, last_actions: List[str], window: int = 3) -> float:
    """
    Penalise repeating the same action type in the last N steps.

    WHY: Without this, GRPO agents discover they can emit the same
    high-scoring action repeatedly within an episode. The penalty
    forces genuine strategy exploration each turn.

    -0.20 per repeat (capped by window=3, so max penalty is -0.60).
    """
    count = last_actions[-window:].count(action_type)
    return -0.20 * count


def invalid_action_penalty(action_type: str, valid_actions: List[str]) -> float:
    """
    Penalise actions not in the valid set for this domain.

    -0.40 because calling a dependency action on a security task is a
    fundamental routing error -- it should hurt significantly.
    """
    return -0.40 if action_type not in valid_actions else 0.0


def harmful_output_penalty(action: Dict, forbidden_patterns: List[str]) -> float:
    """
    Penalise destructive patterns like 'os.remove', 'drop table'.

    -0.50 because these patterns represent the agent trying to "cheat"
    by deleting things rather than fixing them.
    """
    action_str = str(action).lower()
    for p in forbidden_patterns:
        if p.lower() in action_str:
            return -0.50
    return 0.0


def efficiency_bonus(step_count: int, max_steps: int, done: bool) -> float:
    """
    Small bonus for finishing early -- rewards decisive, confident agents.

    WHY ONLY 0.05: The correctness score must be the dominant signal.
    The efficiency bonus should never flip a mediocre answer into a good score.
    """
    return 0.05 if done and step_count < max_steps // 2 else 0.0


def grade_dynamic(
    action: Dict[str, Any],
    session,
    compute_correctness_fn: Callable,
    valid_actions: List[str],
    forbidden_patterns: List[str] = None,
    max_steps: int = 8,
) -> float:
    """
    Full reward pipeline. Entry point for all domain graders.

    Pipeline:
    1. Invalid action check -- if wrong domain action, return penalised score immediately
    2. Repetition penalty -- subtract for repeated action types
    3. compute_correctness_fn -- domain-specific grader (security/dep/clinical)
    4. Harmful output penalty -- subtract for destructive patterns
    5. Efficiency bonus -- add small bonus for early completion
    6. safe_score -- clamp to [0.01, 0.99]

    NOTE: difficulty_multiplier has been REMOVED.
    The task difficulty is expressed through:
    - Tighter CVSS ranges in hard cases (harder to guess)
    - More required_fix_tokens in hard cases
    - Adversarial reviewer_feedback in hard cases
    - Dependency graphs in hard clinical cases
    - Multiple checklist items with ordering in hard dep cases
    The grader itself should produce lower scores for harder tasks naturally.
    """
    if forbidden_patterns is None:
        forbidden_patterns = []

    action_type = action.get('action_type', 'unknown')

    # Step 1: Invalid action -- skip grader entirely, return penalised score
    inv = invalid_action_penalty(action_type, valid_actions)
    rep = repetition_penalty(action_type, session.last_actions)
    if inv < 0:
        return safe_score(inv + rep)

    # Step 2: Domain-specific correctness
    correctness = compute_correctness_fn(action, session.task_case)
    if correctness is None:
        correctness = 0.01

    # Step 3: Harmful output check
    harm = harmful_output_penalty(action, forbidden_patterns)

    # Step 4: Efficiency bonus
    eff = efficiency_bonus(session.step_count + 1, max_steps, correctness >= 0.75)

    # Step 5: Combine and clamp
    raw = correctness + rep + harm + eff
    return safe_score(raw)
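To make the clamping and penalty arithmetic concrete, the two helpers can be exercised standalone; the definitions below are copied from the file above:

```python
def safe_score(raw) -> float:
    """Clamp score to [0.01, 0.99]; never crash (as in base_grader.py)."""
    if raw is None:
        return 0.01
    try:
        val = float(raw)
        return round(max(0.01, min(0.99, val)), 4)
    except (TypeError, ValueError):
        return 0.01

def repetition_penalty(action_type, last_actions, window=3):
    """-0.20 per occurrence of action_type in the last `window` actions."""
    return -0.20 * last_actions[-window:].count(action_type)

assert safe_score(1.0) == 0.99   # a perfect raw score clamps below 1.0
assert safe_score(-0.6) == 0.01  # heavy penalties never drop below the floor
assert safe_score(None) == 0.01  # a missing score degrades gracefully
assert repetition_penalty('migrate_api', ['migrate_api']) == -0.2
# One repeat turns a 0.85 answer into 0.65 -- the penalty is meant to sting.
assert safe_score(0.85 + repetition_penalty('migrate_api', ['migrate_api'])) == 0.65
```

Because safe_score rounds after clamping, small float artifacts (0.85 - 0.2 is not exactly 0.65 in binary) never leak into the reported reward.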
|