| """ | |
| WhipStudio Environment with Agent Tools | |
| This module implements the ML debugging environment with multiple tool-calling | |
| capabilities for agents to debug code step-by-step before submitting a fix. | |
| """ | |
| import ast | |
| import difflib | |
| import math | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |

try:
    from ..models import MLDebugAction, MLDebugObservation
    from .sandbox import SAFE_ENV, execute_code, strip_markdown_code
    from .tasks import (
        task1_broken_loop,
        task2_nan_loss,
        task3_oom_leakage,
        task4_wrong_loss,
        task5_frozen_backbone,
        task6_io_mismatch,
    )
    from .tasks.graders import RunResult, parse_losses, parse_val_accs, score_task
except ImportError:
    from models import MLDebugAction, MLDebugObservation
    from server.sandbox import SAFE_ENV, execute_code, strip_markdown_code
    from server.tasks import (
        task1_broken_loop,
        task2_nan_loss,
        task3_oom_leakage,
        task4_wrong_loss,
        task5_frozen_backbone,
        task6_io_mismatch,
    )
    from server.tasks.graders import parse_losses, parse_val_accs, score_task

# ── Task Registry ────────────────────────────────────────────────────────────

TASKS = {
    "task1": task1_broken_loop,
    "task2": task2_nan_loss,
    "task3": task3_oom_leakage,
    "task4": task4_wrong_loss,
    "task5": task5_frozen_backbone,
    "task6": task6_io_mismatch,
}

# ── Configuration ────────────────────────────────────────────────────────────

MAX_TURNS_PER_EPISODE = int(os.environ.get("MAX_TURNS_PER_EPISODE", "10"))
TOOL_TIMEOUT_SECONDS = 30  # Increased for PyTorch initialization time
MAX_OUTPUT_BYTES = 50000  # Increased to show complete outputs

# Allowed imports for sandboxed execution
ALLOWED_PACKAGES = {
    "torch", "numpy", "sklearn", "pandas", "matplotlib", "scipy",
    "math", "random", "os", "sys", "collections", "itertools",
    "functools", "json", "re", "typing", "copy", "dataclasses",
    "torch.nn", "torch.optim", "torch.utils", "torch.utils.data",
    "numpy.random", "sklearn.datasets", "sklearn.model_selection",
}

# Banned imports for security
BANNED_IMPORTS = {"socket", "requests", "httpx", "urllib", "subprocess", "shutil"}

# ── Security Helpers ─────────────────────────────────────────────────────────

def check_code_security(code: str) -> tuple[bool, str]:
    """
    Analyze code for security violations using AST.

    Returns (is_safe, error_message).
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return True, ""  # Let it fail at runtime with proper error

    for node in ast.walk(tree):
        # Check imports
        if isinstance(node, ast.Import):
            for alias in node.names:
                module_root = alias.name.split(".")[0]
                if module_root in BANNED_IMPORTS:
                    return False, f"Import of '{alias.name}' is not allowed (network/system access)"
        if isinstance(node, ast.ImportFrom):
            if node.module:
                module_root = node.module.split(".")[0]
                if module_root in BANNED_IMPORTS:
                    return False, f"Import from '{node.module}' is not allowed (network/system access)"
        # Check file writes outside /tmp
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id == "open":
                if len(node.args) >= 1:
                    first_arg = node.args[0]
                    if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
                        path = first_arg.value
                        if not path.startswith("/tmp") and not path.startswith("tmp"):
                            # Check mode argument
                            mode = "r"
                            if len(node.args) >= 2 and isinstance(node.args[1], ast.Constant):
                                mode = str(node.args[1].value)
                            for kw in node.keywords:
                                if kw.arg == "mode" and isinstance(kw.value, ast.Constant):
                                    mode = str(kw.value.value)
                            if "w" in mode or "a" in mode or "+" in mode:
                                return False, f"File writes outside /tmp are not allowed: {path}"
    return True, ""
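
# Illustrative sketch (not part of the module): expected behaviour of the AST
# check above, given BANNED_IMPORTS as defined in this file.
#
#   check_code_security("import torch")                   # -> (True, "")
#   check_code_security("import socket")                  # -> (False, "Import of 'socket' is not allowed ...")
#   check_code_security("open('/etc/passwd', 'w')")       # -> (False, "File writes outside /tmp are not allowed ...")
#   check_code_security("open('/tmp/scratch.txt', 'w')")  # -> (True, "")
#
# Only literal paths and modes are inspected; dynamically built paths fall
# through to the sandboxed subprocess, which still runs with a restricted
# environment and a /tmp working directory.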

def run_sandboxed_code(code: str, timeout: int = TOOL_TIMEOUT_SECONDS) -> dict:
    """
    Run code in a sandboxed subprocess with timeout and security constraints.

    Returns dict with stdout, stderr, exit_code, timed_out.
    """
    # Security check
    is_safe, error = check_code_security(code)
    if not is_safe:
        return {
            "stdout": "",
            "stderr": f"Security violation: {error}",
            "exit_code": -1,
            "timed_out": False,
        }

    # Clean up markdown
    cleaned_code = strip_markdown_code(code)

    # Write to temp file
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, dir="/tmp") as f:
        f.write(cleaned_code)
        tmp_path = f.name

    start = time.time()
    try:
        # Run with restricted environment
        env = dict(SAFE_ENV)
        env["no_proxy"] = "*"  # Bypass proxies; network imports are already blocked by the AST check
        proc = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout,
            env=env,
            cwd="/tmp",
        )
        return {
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "exit_code": proc.returncode,
            "timed_out": False,
            "elapsed": round(time.time() - start, 2),
        }
    except subprocess.TimeoutExpired:
        return {
            "stdout": "",
            "stderr": f"Execution timed out after {timeout} seconds",
            "exit_code": -1,
            "timed_out": True,
            "elapsed": timeout,
        }
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass
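
# Illustrative sketch (not part of the module): shape of the dict returned by
# run_sandboxed_code for a trivial snippet; the exact "elapsed" value varies.
#
#   run_sandboxed_code("print('hello')")
#   # -> {"stdout": "hello\n", "stderr": "", "exit_code": 0,
#   #     "timed_out": False, "elapsed": 0.05}
#
# On a security violation the subprocess is never started: "exit_code" is -1
# and the violation message is returned in "stderr".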

# ── Tool Implementations ─────────────────────────────────────────────────────

def tool_execute_snippet(code: str, turn: int) -> MLDebugObservation:
    """Execute a Python snippet and return stdout/stderr/exit_code."""
    result = run_sandboxed_code(code, timeout=TOOL_TIMEOUT_SECONDS)
    return MLDebugObservation(
        action_type="execute_snippet",
        turn=turn,
        episode_done=False,
        reward=None,
        stdout=result["stdout"],
        stderr=result["stderr"],
        exit_code=result["exit_code"],
        timed_out=result["timed_out"],
    )
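
# Illustrative sketch (not part of the module): a typical agent probe through
# this tool; field names match MLDebugObservation as used above.
#
#   obs = tool_execute_snippet("import torch; print(torch.__version__)", turn=1)
#   # obs.stdout -> e.g. "2.x.y\n", obs.exit_code -> 0, obs.timed_out -> False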

def tool_inspect_tensor(setup_code: str, target_expression: str, turn: int) -> MLDebugObservation:
    """
    Inspect a tensor or module parameter.

    Returns shape, dtype, requires_grad, grad status, min/max/mean, nan/inf checks.
    """
    # Validate target_expression is a simple expression (no newlines, reasonable length)
    if not target_expression or not target_expression.strip():
        return MLDebugObservation(
            action_type="inspect_tensor",
            turn=turn,
            episode_done=False,
            reward=None,
            error="target_expression is required",
        )

    # Clean up target expression - strip whitespace, remove newlines
    target_expression = target_expression.strip().replace('\n', ' ').replace('\r', '')
    if len(target_expression) > 500:
        return MLDebugObservation(
            action_type="inspect_tensor",
            turn=turn,
            episode_done=False,
            reward=None,
            error="target_expression too long (max 500 chars)",
        )

    # Build inspection script - use exec/eval for safer expression handling
    inspection_code = f'''{setup_code}
import torch
import json
import math

def _inspect_target():
    try:
        # Evaluate target expression in the current scope
        _target = eval({repr(target_expression)})
    except Exception as e:
        return {{"error": "Failed to evaluate expression: " + str(e)}}

    result = {{}}

    # Check if it's a tensor
    if isinstance(_target, torch.Tensor):
        result["shape"] = list(_target.shape)
        result["dtype"] = str(_target.dtype)
        result["requires_grad"] = _target.requires_grad
        result["grad_is_none"] = _target.grad is None if _target.requires_grad else None
        # Convert to float for stats if needed
        try:
            _data = _target.detach().float()
            result["min_val"] = float(_data.min().item())
            result["max_val"] = float(_data.max().item())
            result["mean_val"] = float(_data.mean().item())
            result["is_nan"] = bool(torch.isnan(_data).any().item())
            result["is_inf"] = bool(torch.isinf(_data).any().item())
        except Exception as e:
            result["stats_error"] = str(e)
    elif hasattr(_target, 'weight'):
        # It's likely a module
        result["is_module"] = True
        result["type"] = type(_target).__name__
        if hasattr(_target.weight, 'shape'):
            result["shape"] = list(_target.weight.shape)
            result["dtype"] = str(_target.weight.dtype)
            result["requires_grad"] = _target.weight.requires_grad
            result["grad_is_none"] = _target.weight.grad is None if _target.weight.requires_grad else None
    else:
        result["error"] = "Target is not a tensor or module: " + type(_target).__name__

    return result

_result = _inspect_target()
print("##INSPECT_RESULT##")
print(json.dumps(_result))
print("##END_INSPECT##")
'''

    result = run_sandboxed_code(inspection_code, timeout=TOOL_TIMEOUT_SECONDS)
    obs = MLDebugObservation(
        action_type="inspect_tensor",
        turn=turn,
        episode_done=False,
        reward=None,
        stdout=result["stdout"],
        stderr=result["stderr"],
        exit_code=result["exit_code"],
        timed_out=result["timed_out"],
    )

    # Parse the result
    if "##INSPECT_RESULT##" in result["stdout"]:
        try:
            match = re.search(r"##INSPECT_RESULT##\s*(\{.*?\})\s*##END_INSPECT##", result["stdout"], re.DOTALL)
            if match:
                import json
                data = json.loads(match.group(1))
                if "error" in data:
                    obs.error = data["error"]
                else:
                    obs.shape = data.get("shape")
                    obs.dtype = data.get("dtype")
                    obs.requires_grad = data.get("requires_grad")
                    obs.grad_is_none = data.get("grad_is_none")
                    obs.min_val = data.get("min_val")
                    obs.max_val = data.get("max_val")
                    obs.mean_val = data.get("mean_val")
                    obs.is_nan = data.get("is_nan")
                    obs.is_inf = data.get("is_inf")
        except Exception as e:
            obs.error = f"Failed to parse inspection result: {e}"
    elif result["stderr"]:
        obs.error = result["stderr"][:500]

    return obs
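
# Illustrative sketch (not part of the module): inspecting a freshly created
# linear layer. `setup_code` runs first inside the sandbox, then the target
# expression is evaluated in that scope.
#
#   obs = tool_inspect_tensor(
#       setup_code="import torch.nn as nn\nmodel = nn.Linear(4, 2)",
#       target_expression="model.weight",
#       turn=2,
#   )
#   # obs.shape -> [2, 4], obs.requires_grad -> True, obs.grad_is_none -> True
#   # (no backward pass has run yet, so .grad is still None)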

def tool_run_training_probe(code: str, steps: int, turn: int) -> MLDebugObservation:
    """
    Run N steps of training and return loss curve + gradient norms.
    """
    # Cap steps at 10
    steps = min(steps, 10)

    # Wrap the training code to capture metrics
    probe_code = f'''
import torch
import torch.nn as nn
import json
import math

# Monkey-patch to capture gradient norms
_grad_norms = {{}}
_losses = []
_step_count = 0
_max_steps = {steps}

_original_backward = torch.Tensor.backward

def _patched_backward(self, *args, **kwargs):
    global _step_count, _losses
    result = _original_backward(self, *args, **kwargs)
    if _step_count < _max_steps:
        try:
            loss_val = self.item()
            _losses.append(loss_val)
        except Exception:
            pass
    _step_count += 1
    return result

torch.Tensor.backward = _patched_backward

# Run the user's code
try:
    exec("""
{code.replace(chr(34)*3, chr(39)*3)}
""")
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")

# Restore backward
torch.Tensor.backward = _original_backward

# Try to find model and optimizer in globals
_model = None
_optimizer = None
for _name, _obj in list(globals().items()):
    if isinstance(_obj, nn.Module) and _model is None:
        _model = _obj
    if hasattr(_obj, 'param_groups') and _optimizer is None:
        _optimizer = _obj

# Capture gradient norms
if _model is not None:
    for name, param in _model.named_parameters():
        if param.requires_grad and param.grad is not None:
            _grad_norms[name] = float(param.grad.norm().item())

# Capture optimizer param count
_opt_count = None
if _optimizer is not None:
    _opt_count = sum(p.numel() for g in _optimizer.param_groups for p in g['params'])

# Output results
print("##PROBE_RESULT##")
print(json.dumps({{
    "losses": _losses[:_max_steps],
    "grad_norms": _grad_norms,
    "optimizer_param_count": _opt_count,
    "final_loss": _losses[-1] if _losses else None,
    "loss_is_nan": any(math.isnan(l) if isinstance(l, float) else False for l in _losses),
    "loss_is_inf": any(math.isinf(l) if isinstance(l, float) else False for l in _losses),
}}))
print("##END_PROBE##")
'''

    result = run_sandboxed_code(probe_code, timeout=TOOL_TIMEOUT_SECONDS)
    obs = MLDebugObservation(
        action_type="run_training_probe",
        turn=turn,
        episode_done=False,
        reward=None,
        stdout=result["stdout"],
        stderr=result["stderr"],
        timed_out=result["timed_out"],
    )

    # Parse the probe result
    if "##PROBE_RESULT##" in result["stdout"]:
        try:
            match = re.search(r"##PROBE_RESULT##\s*(\{.*?\})\s*##END_PROBE##", result["stdout"], re.DOTALL)
            if match:
                import json
                data = json.loads(match.group(1))
                obs.losses = data.get("losses", [])
                obs.grad_norms = data.get("grad_norms", {})
                obs.optimizer_param_count = data.get("optimizer_param_count")
                obs.final_loss = data.get("final_loss")
                obs.loss_is_nan = data.get("loss_is_nan", False)
                obs.loss_is_inf = data.get("loss_is_inf", False)
        except Exception as e:
            obs.error = f"Failed to parse probe result: {e}"

    return obs
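
# Illustrative sketch (not part of the module): probing a training script. The
# probe monkey-patches Tensor.backward, so every loss passed to .backward() in
# the submitted script is recorded, and gradient norms are read off the first
# nn.Module found in the script's globals. `training_script` is a placeholder
# for the agent's candidate fix.
#
#   obs = tool_run_training_probe(training_script, steps=5, turn=3)
#   # obs.losses       -> e.g. [2.31, 2.28, 2.25, 2.21, 2.18]
#   # obs.grad_norms   -> e.g. {"fc.weight": 0.42, "fc.bias": 0.11}
#   # obs.loss_is_nan  -> False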

def tool_get_variable_state(setup_code: str, expressions: list[str], turn: int) -> MLDebugObservation:
    """
    Evaluate multiple expressions and return their repr, type, value, shape.
    """
    # Limit expressions
    expressions = expressions[:10]
    expr_list_str = repr(expressions)

    eval_code = f'''
{setup_code}
import torch
import json

_expressions = {expr_list_str}
_results = {{}}

for _expr in _expressions:
    try:
        _val = eval(_expr)
        _result = {{
            "repr": repr(_val)[:500],
            "type": type(_val).__name__,
            "value": None,
            "shape": None,
            "error": None,
        }}
        # Extract scalar value
        if isinstance(_val, (int, float, bool)):
            _result["value"] = _val
        elif isinstance(_val, str):
            _result["value"] = _val[:200]
        elif hasattr(_val, 'item') and _val.numel() == 1:
            _result["value"] = float(_val.item())
        # Extract shape
        if hasattr(_val, 'shape'):
            _result["shape"] = list(_val.shape)
        elif hasattr(_val, '__len__') and not isinstance(_val, str):
            _result["shape"] = [len(_val)]
        _results[_expr] = _result
    except Exception as e:
        _results[_expr] = {{
            "repr": "",
            "type": "",
            "value": None,
            "shape": None,
            "error": str(e)[:200],
        }}

print("##VAR_RESULT##")
print(json.dumps(_results))
print("##END_VAR##")
'''

    result = run_sandboxed_code(eval_code, timeout=TOOL_TIMEOUT_SECONDS)
    obs = MLDebugObservation(
        action_type="get_variable_state",
        turn=turn,
        episode_done=False,
        reward=None,
        stdout=result["stdout"],
        stderr=result["stderr"],
        timed_out=result["timed_out"],
    )

    # Parse results
    if "##VAR_RESULT##" in result["stdout"]:
        try:
            match = re.search(r"##VAR_RESULT##\s*(\{.*?\})\s*##END_VAR##", result["stdout"], re.DOTALL)
            if match:
                import json
                obs.results = json.loads(match.group(1))
        except Exception as e:
            obs.error = f"Failed to parse variable results: {e}"

    return obs
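
# Illustrative sketch (not part of the module): batch-checking a few properties
# of a model in one call.
#
#   obs = tool_get_variable_state(
#       setup_code="import torch.nn as nn\nmodel = nn.Linear(4, 2)",
#       expressions=["model.training", "model.weight.shape", "model.bias.numel()"],
#       turn=4,
#   )
#   # obs.results["model.training"]["value"]      -> True
#   # obs.results["model.weight.shape"]["repr"]   -> "torch.Size([2, 4])"
#   # obs.results["model.bias.numel()"]["value"]  -> 2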

def tool_inspect_diff(original_code: str, proposed_code: str, turn: int) -> MLDebugObservation:
    """
    Generate a unified diff between original and proposed code.
    """
    original_lines = original_code.strip().splitlines(keepends=True)
    proposed_lines = proposed_code.strip().splitlines(keepends=True)

    diff = list(difflib.unified_diff(
        original_lines,
        proposed_lines,
        fromfile="original.py",
        tofile="proposed.py",
        lineterm="",
    ))

    # Count changes
    additions = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
    deletions = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
    lines_changed = additions + deletions

    return MLDebugObservation(
        action_type="inspect_diff",
        turn=turn,
        episode_done=False,
        reward=None,
        diff="\n".join(diff),
        lines_changed=lines_changed,
        additions=additions,
        deletions=deletions,
    )
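
# Illustrative sketch (not part of the module): a one-line change produces one
# deletion and one addition in the unified diff.
#
#   obs = tool_inspect_diff("x = 1\ny = 2\n", "x = 1\ny = 3\n", turn=5)
#   # obs.additions -> 1, obs.deletions -> 1, obs.lines_changed -> 2
#   # obs.diff starts with the "--- original.py" / "+++ proposed.py" headers.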

# ── Global Session Store ─────────────────────────────────────────────────────
# OpenEnv creates a new environment instance for each HTTP request, so we need
# a global store to track state (step_count, submitted status) across requests.

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class SessionState:
    """State for a single episode session."""
    episode_id: str
    task_id: str
    step_count: int = 0
    submitted: bool = False
    best_reward: float = 0.0
    original_buggy_code: str = ""
    trajectory: list = field(default_factory=list)


# Global session store keyed by episode_id
_SESSION_STORE: Dict[str, SessionState] = {}


def _get_session(episode_id: str) -> Optional[SessionState]:
    """Get session state for an episode, or None if not found."""
    return _SESSION_STORE.get(episode_id)


def _create_session(episode_id: str, task_id: str, original_buggy_code: str) -> SessionState:
    """Create a new session and store it."""
    session = SessionState(
        episode_id=episode_id,
        task_id=task_id,
        original_buggy_code=original_buggy_code,
    )
    _SESSION_STORE[episode_id] = session
    return session


def _clear_session(episode_id: str) -> None:
    """Remove a session from the store."""
    _SESSION_STORE.pop(episode_id, None)
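
# Illustrative sketch (not part of the module): session lifecycle across
# stateless HTTP requests. "ep-123" and `buggy_code` are placeholders; real
# episode IDs are uuid4 strings generated in reset().
#
#   session = _create_session("ep-123", "task2", buggy_code)   # reset()
#   _get_session("ep-123").step_count += 1                     # each step()
#   _get_session("ep-123").submitted = True                    # submit_fix / max turns
#   _clear_session("ep-123")                                   # optional cleanup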

# ── Main Environment ─────────────────────────────────────────────────────────

class MLDebugEnvironment(Environment):
    """
    ML Debug Environment with global session store.

    OpenEnv creates a new instance for each HTTP request, so state is tracked
    in a global session store keyed by episode_id.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Instance state (transient per request)
        self._state = State(episode_id=str(uuid4()), step_count=0)

    def reset(self, task_id: str = "task1", **kwargs) -> MLDebugObservation:
        if task_id not in TASKS:
            task_id = "task1"

        # Create new episode ID and session
        episode_id = str(uuid4())
        task = TASKS[task_id]
        original_buggy_code = task.BUGGY_CODE.strip()

        # Create session in global store
        _create_session(episode_id, task_id, original_buggy_code)
        self._state = State(episode_id=episode_id, step_count=0)

        return MLDebugObservation(
            action_type="reset",
            task_id=task_id,
            task_description=task.TASK_DESCRIPTION.strip(),
            buggy_code=original_buggy_code,
            error_log="",
            last_reward=0.0,
            metrics={},
            done=False,
            episode_done=False,
            reward=None,
            turn=0,
            episode_id=episode_id,
        )

    def step(self, action: MLDebugAction) -> MLDebugObservation:
        # Get episode_id from action (passed from client)
        episode_id = getattr(action, 'episode_id', None) or ""

        # Look up session in global store
        session = _get_session(episode_id) if episode_id else None
        if session is None:
            # No valid session - return error
            return MLDebugObservation(
                action_type=action.action_type,
                error=f"No active session for episode_id '{episode_id}'. Call reset() first.",
                episode_done=True,
                done=True,
                reward=0.001,
                turn=0,
                task_id="task1",
                task_description="",
                buggy_code="",
                episode_id=episode_id or "",
            )

        # Check if episode already complete
        if session.submitted:
            return MLDebugObservation(
                action_type=action.action_type,
                error="Episode already complete. Call reset() to start a new episode.",
                episode_done=True,
                done=True,
                reward=0.001,
                turn=session.step_count,
                task_id=session.task_id,
                task_description=TASKS[session.task_id].TASK_DESCRIPTION.strip(),
                buggy_code=session.original_buggy_code,
                episode_id=episode_id,
            )

        # Increment turn in session store
        session.step_count += 1
        current_turn = session.step_count

        # Check max turns
        if current_turn > MAX_TURNS_PER_EPISODE:
            session.submitted = True
            return MLDebugObservation(
                action_type=action.action_type,
                error=f"Maximum turns ({MAX_TURNS_PER_EPISODE}) exceeded. Episode terminated.",
                episode_done=True,
                done=True,
                reward=0.001,
                turn=current_turn,
                task_id=session.task_id,
                task_description=TASKS[session.task_id].TASK_DESCRIPTION.strip(),
                buggy_code=session.original_buggy_code,
                episode_id=episode_id,
            )

        # Dispatch based on action type
        action_type = action.action_type
        task_id = session.task_id
        original_buggy_code = session.original_buggy_code

        if action_type == "execute_snippet":
            obs = tool_execute_snippet(action.code, current_turn)
            obs.task_id = task_id
            obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
            obs.buggy_code = original_buggy_code
            obs.episode_id = episode_id
            return obs

        elif action_type == "inspect_tensor":
            obs = tool_inspect_tensor(action.setup_code, action.target_expression, current_turn)
            obs.task_id = task_id
            obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
            obs.buggy_code = original_buggy_code
            obs.episode_id = episode_id
            return obs

        elif action_type == "run_training_probe":
            obs = tool_run_training_probe(action.code, action.steps, current_turn)
            obs.task_id = task_id
            obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
            obs.buggy_code = original_buggy_code
            obs.episode_id = episode_id
            return obs

        elif action_type == "get_variable_state":
            obs = tool_get_variable_state(action.setup_code, action.expressions, current_turn)
            obs.task_id = task_id
            obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
            obs.buggy_code = original_buggy_code
            obs.episode_id = episode_id
            return obs

        elif action_type == "inspect_diff":
            obs = tool_inspect_diff(
                original_buggy_code,
                action.proposed_code,
                current_turn,
            )
            obs.task_id = task_id
            obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
            obs.buggy_code = original_buggy_code
            obs.episode_id = episode_id
            return obs

        elif action_type == "submit_fix":
            return self._handle_submit_fix(action, session)

        else:
            # Unknown action type - treat as submit_fix for backward compat
            if action.fixed_code:
                return self._handle_submit_fix(action, session)
            else:
                return MLDebugObservation(
                    action_type=action_type,
                    error=f"Unknown action type: {action_type}",
                    episode_done=False,
                    reward=None,
                    turn=current_turn,
                    task_id=task_id,
                    task_description=TASKS[task_id].TASK_DESCRIPTION.strip(),
                    buggy_code=original_buggy_code,
                    episode_id=episode_id,
                )

    def _handle_submit_fix(self, action: MLDebugAction, session: SessionState) -> MLDebugObservation:
        """Handle the submit_fix action (terminal action)."""
        # Mark as submitted in session
        session.submitted = True
        turns_used = session.step_count
        episode_id = session.episode_id
        task_id = session.task_id
        original_buggy_code = session.original_buggy_code

        # Get the fixed code
        fixed_code = action.fixed_code or action.code or ""
        if not fixed_code or not fixed_code.strip():
            return MLDebugObservation(
                action_type="submit_fix",
                turn=turns_used,
                episode_done=True,
                done=True,
                reward=0.001,
                success=False,
                grader_details={"error": "empty code submitted"},
                turns_used=turns_used,
                task_id=task_id,
                task_description=TASKS[task_id].TASK_DESCRIPTION.strip(),
                buggy_code=original_buggy_code,
                error_log="empty code submitted",
                error="Empty code submitted",
                episode_id=episode_id,
            )

        # Execute and grade
        run_result1 = execute_code(fixed_code)
        reward1, breakdown1 = score_task(task_id, run_result1)

        # Consistency check for high scores
        consistency_flag = False
        reward_variance = 0.0
        final_reward = reward1
        final_breakdown = breakdown1
        run_result = run_result1

        if reward1 > 0.5:
            run_result2 = execute_code(fixed_code)
            reward2, breakdown2 = score_task(task_id, run_result2)
            reward_variance = abs(reward1 - reward2)
            if reward_variance > 0.15:
                consistency_flag = True
                final_reward = min(reward1, reward2)
                if reward2 < reward1:
                    final_breakdown = breakdown2
                    run_result = run_result2
            else:
                consistency_flag = False
                final_reward = (reward1 + reward2) / 2.0

        session.best_reward = max(session.best_reward, final_reward)

        # Parse metrics
        losses = parse_losses(run_result.stdout)
        val_accs = parse_val_accs(run_result.stdout)
        final_loss = None
        if losses:
            final_loss = losses[-1]
        else:
            match = re.search(r"FINAL_LOSS:([-\d.]+)", run_result.stdout)
            if match:
                final_loss = float(match.group(1))

        metrics = {
            "exit_code": run_result.exit_code,
            "elapsed_seconds": run_result.elapsed_seconds,
            "timed_out": run_result.timed_out,
            "step": session.step_count,
            "best_reward_so_far": session.best_reward,
            "final_loss": final_loss,
            "nan_count": sum(1 for x in losses if math.isnan(x) or math.isinf(x)) if losses else 0,
            "val_acc": val_accs[-1] if val_accs else None,
            "consistency_flag": consistency_flag,
            "reward_variance": round(reward_variance, 4),
            "reward_breakdown": final_breakdown,
        }

        session.trajectory.append({
            "step": session.step_count,
            "reward": final_reward,
            "best_reward": session.best_reward,
            "metrics": metrics,
            "done": True,
            "timestamp": time.time(),
        })

        task = TASKS[task_id]

        # Return complete outputs without truncation
        return MLDebugObservation(
            action_type="submit_fix",
            turn=turns_used,
            episode_done=True,
            done=True,
            reward=final_reward,
            success=final_reward >= 0.7,
            grader_details=final_breakdown,
            turns_used=turns_used,
            task_id=task_id,
            task_description=task.TASK_DESCRIPTION.strip(),
            buggy_code=original_buggy_code,
            error_log=(run_result.stdout + "\n" + run_result.stderr).strip(),
            last_reward=final_reward,
            metrics=metrics,
            stdout=run_result.stdout,
            stderr=run_result.stderr,
            exit_code=run_result.exit_code,
            timed_out=run_result.timed_out,
            losses=losses if losses else [],
            final_loss=final_loss,
            episode_id=episode_id,
        )

    def trajectory(self) -> list[dict]:
        """Return the recorded trajectory for the current episode, if any."""
        session = _get_session(self._state.episode_id)
        return list(session.trajectory) if session else []

    def state(self) -> State:
        return self._state
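
# Illustrative sketch (not part of the module): the request flow OpenEnv drives
# against this class. Each HTTP request gets a fresh MLDebugEnvironment, so all
# cross-request state lives in _SESSION_STORE.
#
#   reset()      -> new episode_id, new SessionState, returns task + buggy code
#   step(action) -> looks up SessionState by action.episode_id, runs one tool,
#                   or grades the fix when action_type == "submit_fix"
#   state()      -> the per-request State object (episode_id, step_count)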

# ── Tool Definitions for /tools endpoint ─────────────────────────────────────

TOOL_DEFINITIONS = [
    {
        "name": "execute_snippet",
        "description": "Execute a short Python code snippet and return stdout, stderr, and exit code. Use for quick probes and checks.",
        "action_fields": {
            "action_type": "execute_snippet",
            "code": "str - The Python snippet to run",
        },
        "observation_fields": ["stdout", "stderr", "exit_code", "timed_out", "turn"],
        "constraints": {
            "timeout": f"{TOOL_TIMEOUT_SECONDS} seconds",
            "network": "disabled",
            "file_writes": "/tmp only",
        },
    },
    {
        "name": "inspect_tensor",
        "description": "Inspect a tensor or module parameter. Returns shape, dtype, requires_grad, grad status, min/max/mean, and NaN/Inf checks.",
        "action_fields": {
            "action_type": "inspect_tensor",
            "setup_code": "str - Code that defines the tensor/module (imports + definitions)",
            "target_expression": "str - The expression to inspect (e.g., 'model.weight.grad')",
        },
        "observation_fields": ["shape", "dtype", "requires_grad", "grad_is_none", "min_val", "max_val", "mean_val", "is_nan", "is_inf", "error", "turn"],
    },
    {
        "name": "run_training_probe",
        "description": "Run N steps of training and return loss curve + gradient norms. Use to verify if a fix is working before submission.",
        "action_fields": {
            "action_type": "run_training_probe",
            "code": "str - The full training script to probe",
            "steps": "int - Number of training steps (1-10, default 5)",
        },
        "observation_fields": ["losses", "grad_norms", "optimizer_param_count", "final_loss", "loss_is_nan", "loss_is_inf", "stderr", "timed_out", "turn"],
    },
    {
        "name": "get_variable_state",
        "description": "Evaluate multiple expressions and return their repr, type, value, and shape. Useful for checking model.training, optimizer params, etc.",
        "action_fields": {
            "action_type": "get_variable_state",
            "setup_code": "str - Setup code (imports + definitions)",
            "expressions": "list[str] - List of expressions to evaluate (max 10)",
        },
        "observation_fields": ["results", "turn"],
    },
    {
        "name": "inspect_diff",
        "description": "Generate a unified diff between the original buggy code and your proposed fix. Review your changes before submitting.",
        "action_fields": {
            "action_type": "inspect_diff",
            "proposed_code": "str - Your proposed fix",
        },
        "observation_fields": ["diff", "lines_changed", "additions", "deletions", "turn"],
    },
    {
        "name": "submit_fix",
        "description": "Submit your final fix. This is the terminal action that ends the episode and returns the grader score.",
        "action_fields": {
            "action_type": "submit_fix",
            "fixed_code": "str - The complete fixed Python script",
        },
        "observation_fields": ["reward", "success", "grader_details", "turns_used", "episode_done", "turn"],
        "terminal": True,
    },
]
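
# Illustrative smoke test (assumption: the task packages are importable and the
# module is run directly). Only reset() is exercised; a full turn is sketched in
# comments because MLDebugAction's constructor lives in models.py.
if __name__ == "__main__":
    env = MLDebugEnvironment()
    first_obs = env.reset(task_id="task1")
    print(f"episode {first_obs.episode_id}: task {first_obs.task_id}")
    print(first_obs.task_description)
    # A full turn would then look like (assuming MLDebugAction accepts the same
    # keyword fields that step() reads above):
    #   action = MLDebugAction(action_type="execute_snippet",
    #                          code="print('probe')",
    #                          episode_id=first_obs.episode_id)
    #   obs = env.step(action)
    #   final = env.step(MLDebugAction(action_type="submit_fix",
    #                                  fixed_code=first_obs.buggy_code,
    #                                  episode_id=first_obs.episode_id))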