"""
WhipStudio Environment with Agent Tools
This module implements the ML debugging environment and the tool-calling
capabilities that let agents debug code step by step before submitting a fix.
"""
import ast
import difflib
import math
import os
import re
import subprocess
import sys
import tempfile
import time
from typing import Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import MLDebugAction, MLDebugObservation
from .sandbox import SAFE_ENV, execute_code, strip_markdown_code
from .tasks import (
task1_broken_loop,
task2_nan_loss,
task3_oom_leakage,
task4_wrong_loss,
task5_frozen_backbone,
task6_io_mismatch,
)
    from .tasks.graders import parse_losses, parse_val_accs, score_task
except ImportError:
from models import MLDebugAction, MLDebugObservation
from server.sandbox import SAFE_ENV, execute_code, strip_markdown_code
from server.tasks import (
task1_broken_loop,
task2_nan_loss,
task3_oom_leakage,
task4_wrong_loss,
task5_frozen_backbone,
task6_io_mismatch,
)
from server.tasks.graders import parse_losses, parse_val_accs, score_task
# ── Task Registry ──────────────────────────────────────────────────────────
TASKS = {
"task1": task1_broken_loop,
"task2": task2_nan_loss,
"task3": task3_oom_leakage,
"task4": task4_wrong_loss,
"task5": task5_frozen_backbone,
"task6": task6_io_mismatch,
}
# ── Configuration ──────────────────────────────────────────────────────────
MAX_TURNS_PER_EPISODE = int(os.environ.get("MAX_TURNS_PER_EPISODE", "10"))
TOOL_TIMEOUT_SECONDS = 30 # Increased for PyTorch initialization time
MAX_OUTPUT_BYTES = 50000 # Increased to show complete outputs
# Allowed imports for sandboxed execution
ALLOWED_PACKAGES = {
"torch", "numpy", "sklearn", "pandas", "matplotlib", "scipy",
"math", "random", "os", "sys", "collections", "itertools",
"functools", "json", "re", "typing", "copy", "dataclasses",
"torch.nn", "torch.optim", "torch.utils", "torch.utils.data",
"numpy.random", "sklearn.datasets", "sklearn.model_selection",
}
# Banned imports for security
BANNED_IMPORTS = {"socket", "requests", "httpx", "urllib", "subprocess", "shutil"}
# ── Security Helpers ───────────────────────────────────────────────────────
def check_code_security(code: str) -> tuple[bool, str]:
"""
Analyze code for security violations using AST.
Returns (is_safe, error_message).
"""
try:
tree = ast.parse(code)
except SyntaxError:
return True, "" # Let it fail at runtime with proper error
for node in ast.walk(tree):
# Check imports
if isinstance(node, ast.Import):
for alias in node.names:
module_root = alias.name.split(".")[0]
if module_root in BANNED_IMPORTS:
return False, f"Import of '{alias.name}' is not allowed (network/system access)"
if isinstance(node, ast.ImportFrom):
if node.module:
module_root = node.module.split(".")[0]
if module_root in BANNED_IMPORTS:
return False, f"Import from '{node.module}' is not allowed (network/system access)"
# Check file writes outside /tmp
if isinstance(node, ast.Call):
func = node.func
if isinstance(func, ast.Name) and func.id == "open":
if len(node.args) >= 1:
first_arg = node.args[0]
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
path = first_arg.value
if not path.startswith("/tmp") and not path.startswith("tmp"):
# Check mode argument
mode = "r"
if len(node.args) >= 2 and isinstance(node.args[1], ast.Constant):
mode = str(node.args[1].value)
for kw in node.keywords:
if kw.arg == "mode" and isinstance(kw.value, ast.Constant):
mode = str(kw.value.value)
if "w" in mode or "a" in mode or "+" in mode:
return False, f"File writes outside /tmp are not allowed: {path}"
return True, ""
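# A minimal doctest-style sketch of the check (illustrative; the exact error
# message text above is not a stable contract):
#
#   >>> ok, msg = check_code_security("import socket")
#   >>> ok
#   False
#   >>> check_code_security("import torch")
#   (True, '')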
def run_sandboxed_code(code: str, timeout: int = TOOL_TIMEOUT_SECONDS) -> dict:
"""
Run code in a sandboxed subprocess with timeout and security constraints.
Returns dict with stdout, stderr, exit_code, timed_out.
"""
# Security check
is_safe, error = check_code_security(code)
if not is_safe:
return {
"stdout": "",
"stderr": f"Security violation: {error}",
"exit_code": -1,
"timed_out": False,
}
# Clean up markdown
cleaned_code = strip_markdown_code(code)
# Write to temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, dir="/tmp") as f:
f.write(cleaned_code)
tmp_path = f.name
start = time.time()
try:
# Run with restricted environment
env = dict(SAFE_ENV)
env["no_proxy"] = "*" # Disable network
proc = subprocess.run(
[sys.executable, tmp_path],
capture_output=True,
text=True,
timeout=timeout,
env=env,
cwd="/tmp",
)
return {
"stdout": proc.stdout,
"stderr": proc.stderr,
"exit_code": proc.returncode,
"timed_out": False,
"elapsed": round(time.time() - start, 2),
}
except subprocess.TimeoutExpired:
return {
"stdout": "",
"stderr": f"Execution timed out after {timeout} seconds",
"exit_code": -1,
"timed_out": True,
"elapsed": timeout,
}
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
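# Usage sketch (assumes a POSIX host where sys.executable can write to /tmp):
#
#   >>> out = run_sandboxed_code("print('hello')")
#   >>> (out["exit_code"], out["timed_out"])
#   (0, False)
#   >>> out["stdout"]
#   'hello\n'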
# ── Tool Implementations ───────────────────────────────────────────────────
def tool_execute_snippet(code: str, turn: int) -> MLDebugObservation:
"""Execute a Python snippet and return stdout/stderr/exit_code."""
result = run_sandboxed_code(code, timeout=TOOL_TIMEOUT_SECONDS)
return MLDebugObservation(
action_type="execute_snippet",
turn=turn,
episode_done=False,
reward=None,
stdout=result["stdout"],
stderr=result["stderr"],
exit_code=result["exit_code"],
timed_out=result["timed_out"],
)
def tool_inspect_tensor(setup_code: str, target_expression: str, turn: int) -> MLDebugObservation:
"""
Inspect a tensor or module parameter.
Returns shape, dtype, requires_grad, grad status, min/max/mean, nan/inf checks.
"""
# Validate target_expression is a simple expression (no newlines, reasonable length)
if not target_expression or not target_expression.strip():
return MLDebugObservation(
action_type="inspect_tensor",
turn=turn,
episode_done=False,
reward=None,
error="target_expression is required",
)
# Clean up target expression - strip whitespace, remove newlines
target_expression = target_expression.strip().replace('\n', ' ').replace('\r', '')
if len(target_expression) > 500:
return MLDebugObservation(
action_type="inspect_tensor",
turn=turn,
episode_done=False,
reward=None,
error="target_expression too long (max 500 chars)",
)
    # Build the inspection script; the target expression is embedded via repr() and eval()'d inside the sandbox
inspection_code = f'''{setup_code}
import torch
import json
import math
def _inspect_target():
try:
# Evaluate target expression in the current scope
_target = eval({repr(target_expression)})
except Exception as e:
return {{"error": "Failed to evaluate expression: " + str(e)}}
result = {{}}
# Check if it's a tensor
if isinstance(_target, torch.Tensor):
result["shape"] = list(_target.shape)
result["dtype"] = str(_target.dtype)
result["requires_grad"] = _target.requires_grad
result["grad_is_none"] = _target.grad is None if _target.requires_grad else None
# Convert to float for stats if needed
try:
_data = _target.detach().float()
result["min_val"] = float(_data.min().item())
result["max_val"] = float(_data.max().item())
result["mean_val"] = float(_data.mean().item())
result["is_nan"] = bool(torch.isnan(_data).any().item())
result["is_inf"] = bool(torch.isinf(_data).any().item())
except Exception as e:
result["stats_error"] = str(e)
elif hasattr(_target, 'weight'):
# It's likely a module
result["is_module"] = True
result["type"] = type(_target).__name__
if hasattr(_target.weight, 'shape'):
result["shape"] = list(_target.weight.shape)
result["dtype"] = str(_target.weight.dtype)
result["requires_grad"] = _target.weight.requires_grad
result["grad_is_none"] = _target.weight.grad is None if _target.weight.requires_grad else None
else:
result["error"] = "Target is not a tensor or module: " + type(_target).__name__
return result
_result = _inspect_target()
print("##INSPECT_RESULT##")
print(json.dumps(_result))
print("##END_INSPECT##")
'''
result = run_sandboxed_code(inspection_code, timeout=TOOL_TIMEOUT_SECONDS)
obs = MLDebugObservation(
action_type="inspect_tensor",
turn=turn,
episode_done=False,
reward=None,
stdout=result["stdout"],
stderr=result["stderr"],
exit_code=result["exit_code"],
timed_out=result["timed_out"],
)
# Parse the result
if "##INSPECT_RESULT##" in result["stdout"]:
try:
match = re.search(r"##INSPECT_RESULT##\s*(\{.*?\})\s*##END_INSPECT##", result["stdout"], re.DOTALL)
if match:
import json
data = json.loads(match.group(1))
if "error" in data:
obs.error = data["error"]
else:
obs.shape = data.get("shape")
obs.dtype = data.get("dtype")
obs.requires_grad = data.get("requires_grad")
obs.grad_is_none = data.get("grad_is_none")
obs.min_val = data.get("min_val")
obs.max_val = data.get("max_val")
obs.mean_val = data.get("mean_val")
obs.is_nan = data.get("is_nan")
obs.is_inf = data.get("is_inf")
except Exception as e:
obs.error = f"Failed to parse inspection result: {e}"
elif result["stderr"]:
obs.error = result["stderr"][:500]
return obs
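# The sentinel protocol above looks like this on the child's stdout
# (values illustrative only):
#
#   ##INSPECT_RESULT##
#   {"shape": [4, 3], "dtype": "torch.float32", "requires_grad": true}
#   ##END_INSPECT##
#
# Anything outside the markers (user prints, warnings) is kept in obs.stdout
# but ignored by the parser.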
def tool_run_training_probe(code: str, steps: int, turn: int) -> MLDebugObservation:
"""
Run N steps of training and return loss curve + gradient norms.
"""
    # Clamp steps to the documented 1-10 range
    steps = max(1, min(steps, 10))
# Wrap the training code to capture metrics
probe_code = f'''
import torch
import torch.nn as nn
import json
import math
# Monkey-patch to capture gradient norms
_grad_norms = {{}}
_losses = []
_step_count = 0
_max_steps = {steps}
_original_backward = torch.Tensor.backward
def _patched_backward(self, *args, **kwargs):
global _step_count, _losses
result = _original_backward(self, *args, **kwargs)
if _step_count < _max_steps:
try:
loss_val = self.item()
_losses.append(loss_val)
        except Exception:
pass
_step_count += 1
return result
torch.Tensor.backward = _patched_backward
# Run the user's code; embedding it via repr() preserves quoting and
# newlines without the fragile triple-quote substitution
try:
    exec({repr(code)})
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")
# Restore backward
torch.Tensor.backward = _original_backward
# Try to find model and optimizer in globals
_model = None
_optimizer = None
for _name, _obj in list(globals().items()):
if isinstance(_obj, nn.Module) and _model is None:
_model = _obj
if hasattr(_obj, 'param_groups') and _optimizer is None:
_optimizer = _obj
# Capture gradient norms
if _model is not None:
for name, param in _model.named_parameters():
if param.requires_grad and param.grad is not None:
_grad_norms[name] = float(param.grad.norm().item())
# Capture optimizer param count
_opt_count = None
if _optimizer is not None:
_opt_count = sum(p.numel() for g in _optimizer.param_groups for p in g['params'])
# Output results
print("##PROBE_RESULT##")
print(json.dumps({{
"losses": _losses[:_max_steps],
"grad_norms": _grad_norms,
"optimizer_param_count": _opt_count,
"final_loss": _losses[-1] if _losses else None,
"loss_is_nan": any(math.isnan(l) if isinstance(l, float) else False for l in _losses),
"loss_is_inf": any(math.isinf(l) if isinstance(l, float) else False for l in _losses),
}}))
print("##END_PROBE##")
'''
result = run_sandboxed_code(probe_code, timeout=TOOL_TIMEOUT_SECONDS)
obs = MLDebugObservation(
action_type="run_training_probe",
turn=turn,
episode_done=False,
reward=None,
stdout=result["stdout"],
stderr=result["stderr"],
timed_out=result["timed_out"],
)
# Parse the probe result
if "##PROBE_RESULT##" in result["stdout"]:
try:
match = re.search(r"##PROBE_RESULT##\s*(\{.*?\})\s*##END_PROBE##", result["stdout"], re.DOTALL)
if match:
import json
data = json.loads(match.group(1))
obs.losses = data.get("losses", [])
obs.grad_norms = data.get("grad_norms", {})
obs.optimizer_param_count = data.get("optimizer_param_count")
obs.final_loss = data.get("final_loss")
obs.loss_is_nan = data.get("loss_is_nan", False)
obs.loss_is_inf = data.get("loss_is_inf", False)
except Exception as e:
obs.error = f"Failed to parse probe result: {e}"
return obs
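# Illustrative probe input (a sketch, assuming torch is installed in the
# sandbox; any self-contained training loop that calls loss.backward() works):
#
#   probe_script = (
#       "import torch\n"
#       "w = torch.nn.Parameter(torch.ones(1))\n"
#       "opt = torch.optim.SGD([w], lr=0.1)\n"
#       "for _ in range(3):\n"
#       "    loss = (w * 2).sum()\n"
#       "    opt.zero_grad(); loss.backward(); opt.step()\n"
#   )
#   obs = tool_run_training_probe(probe_script, steps=3, turn=1)
#   # obs.losses holds up to 3 captured loss values; obs.grad_norms maps
#   # parameter names to gradient norms for the first nn.Module found.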
def tool_get_variable_state(setup_code: str, expressions: list[str], turn: int) -> MLDebugObservation:
"""
Evaluate multiple expressions and return their repr, type, value, shape.
"""
# Limit expressions
expressions = expressions[:10]
expr_list_str = repr(expressions)
eval_code = f'''
{setup_code}
import torch
import json
_expressions = {expr_list_str}
_results = {{}}
for _expr in _expressions:
try:
_val = eval(_expr)
_result = {{
"repr": repr(_val)[:500],
"type": type(_val).__name__,
"value": None,
"shape": None,
"error": None,
}}
# Extract scalar value
if isinstance(_val, (int, float, bool)):
_result["value"] = _val
elif isinstance(_val, str):
_result["value"] = _val[:200]
        elif hasattr(_val, 'numel') and _val.numel() == 1:  # objects with .item() but no .numel() (e.g. numpy) fall through to shape handling
_result["value"] = float(_val.item())
# Extract shape
if hasattr(_val, 'shape'):
_result["shape"] = list(_val.shape)
elif hasattr(_val, '__len__') and not isinstance(_val, str):
_result["shape"] = [len(_val)]
_results[_expr] = _result
except Exception as e:
_results[_expr] = {{
"repr": "",
"type": "",
"value": None,
"shape": None,
"error": str(e)[:200],
}}
print("##VAR_RESULT##")
print(json.dumps(_results))
print("##END_VAR##")
'''
result = run_sandboxed_code(eval_code, timeout=TOOL_TIMEOUT_SECONDS)
obs = MLDebugObservation(
action_type="get_variable_state",
turn=turn,
episode_done=False,
reward=None,
stdout=result["stdout"],
stderr=result["stderr"],
timed_out=result["timed_out"],
)
# Parse results
if "##VAR_RESULT##" in result["stdout"]:
try:
match = re.search(r"##VAR_RESULT##\s*(\{.*?\})\s*##END_VAR##", result["stdout"], re.DOTALL)
if match:
import json
obs.results = json.loads(match.group(1))
except Exception as e:
obs.error = f"Failed to parse variable results: {e}"
return obs
def tool_inspect_diff(original_code: str, proposed_code: str, turn: int) -> MLDebugObservation:
"""
Generate a unified diff between original and proposed code.
"""
    # splitlines() without keepends pairs with lineterm="" below, so the
    # joined diff has no stray blank lines
    original_lines = original_code.strip().splitlines()
    proposed_lines = proposed_code.strip().splitlines()
diff = list(difflib.unified_diff(
original_lines,
proposed_lines,
fromfile="original.py",
tofile="proposed.py",
lineterm="",
))
# Count changes
additions = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
deletions = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
lines_changed = additions + deletions
return MLDebugObservation(
action_type="inspect_diff",
turn=turn,
episode_done=False,
reward=None,
diff="\n".join(diff),
lines_changed=lines_changed,
additions=additions,
deletions=deletions,
)
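# Doctest-style sketch (illustrative):
#
#   >>> obs = tool_inspect_diff("x = 1\n", "x = 2\n", turn=1)
#   >>> (obs.additions, obs.deletions, obs.lines_changed)
#   (1, 1, 2)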
# ── Global Session Store ───────────────────────────────────────────────────
# OpenEnv creates a new environment instance for each HTTP request, so we need
# a global store to track state (step_count, submitted status) across requests.
from dataclasses import dataclass, field
@dataclass
class SessionState:
"""State for a single episode session."""
episode_id: str
task_id: str
step_count: int = 0
submitted: bool = False
best_reward: float = 0.0
original_buggy_code: str = ""
trajectory: list = field(default_factory=list)
# Global session store keyed by episode_id
_SESSION_STORE: dict[str, SessionState] = {}
def _get_session(episode_id: str) -> Optional[SessionState]:
"""Get session state for an episode, or None if not found."""
return _SESSION_STORE.get(episode_id)
def _create_session(episode_id: str, task_id: str, original_buggy_code: str) -> SessionState:
"""Create a new session and store it."""
session = SessionState(
episode_id=episode_id,
task_id=task_id,
original_buggy_code=original_buggy_code,
)
_SESSION_STORE[episode_id] = session
return session
def _clear_session(episode_id: str) -> None:
"""Remove a session from the store."""
_SESSION_STORE.pop(episode_id, None)
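# Lifecycle sketch (illustrative; episode ids are normally minted by reset()):
#
#   >>> s = _create_session("ep-demo", "task1", "print('buggy')")
#   >>> _get_session("ep-demo") is s
#   True
#   >>> _clear_session("ep-demo")
#   >>> _get_session("ep-demo") is None
#   True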
# ── Main Environment ───────────────────────────────────────────────────────
class MLDebugEnvironment(Environment):
"""
ML Debug Environment with global session store.
OpenEnv creates a new instance for each HTTP request, so state is tracked
in a global session store keyed by episode_id.
"""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
def __init__(self):
# Instance state (transient per request)
self._state = State(episode_id=str(uuid4()), step_count=0)
def reset(self, task_id: str = "task1", **kwargs) -> MLDebugObservation:
if task_id not in TASKS:
task_id = "task1"
# Create new episode ID and session
episode_id = str(uuid4())
task = TASKS[task_id]
original_buggy_code = task.BUGGY_CODE.strip()
# Create session in global store
_create_session(episode_id, task_id, original_buggy_code)
self._state = State(episode_id=episode_id, step_count=0)
return MLDebugObservation(
action_type="reset",
task_id=task_id,
task_description=task.TASK_DESCRIPTION.strip(),
buggy_code=original_buggy_code,
error_log="",
last_reward=0.0,
metrics={},
done=False,
episode_done=False,
reward=None,
turn=0,
episode_id=episode_id,
)
def step(self, action: MLDebugAction) -> MLDebugObservation:
# Get episode_id from action (passed from client)
episode_id = getattr(action, 'episode_id', None) or ""
# Look up session in global store
session = _get_session(episode_id) if episode_id else None
if session is None:
# No valid session - return error
return MLDebugObservation(
action_type=action.action_type,
error=f"No active session for episode_id '{episode_id}'. Call reset() first.",
episode_done=True,
done=True,
reward=0.001,
turn=0,
task_id="task1",
task_description="",
buggy_code="",
episode_id=episode_id or "",
)
# Check if episode already complete
if session.submitted:
return MLDebugObservation(
action_type=action.action_type,
error="Episode already complete. Call reset() to start a new episode.",
episode_done=True,
done=True,
reward=0.001,
turn=session.step_count,
task_id=session.task_id,
task_description=TASKS[session.task_id].TASK_DESCRIPTION.strip(),
buggy_code=session.original_buggy_code,
episode_id=episode_id,
)
# Increment turn in session store
session.step_count += 1
current_turn = session.step_count
# Check max turns
if current_turn > MAX_TURNS_PER_EPISODE:
session.submitted = True
return MLDebugObservation(
action_type=action.action_type,
error=f"Maximum turns ({MAX_TURNS_PER_EPISODE}) exceeded. Episode terminated.",
episode_done=True,
done=True,
reward=0.001,
turn=current_turn,
task_id=session.task_id,
task_description=TASKS[session.task_id].TASK_DESCRIPTION.strip(),
buggy_code=session.original_buggy_code,
episode_id=episode_id,
)
        # Dispatch based on action type; task context is attached to every
        # tool observation after the branch
        action_type = action.action_type
        task_id = session.task_id
        original_buggy_code = session.original_buggy_code
        if action_type == "execute_snippet":
            obs = tool_execute_snippet(action.code, current_turn)
        elif action_type == "inspect_tensor":
            obs = tool_inspect_tensor(action.setup_code, action.target_expression, current_turn)
        elif action_type == "run_training_probe":
            obs = tool_run_training_probe(action.code, action.steps, current_turn)
        elif action_type == "get_variable_state":
            obs = tool_get_variable_state(action.setup_code, action.expressions, current_turn)
        elif action_type == "inspect_diff":
            obs = tool_inspect_diff(original_buggy_code, action.proposed_code, current_turn)
        elif action_type == "submit_fix":
            return self._handle_submit_fix(action, session)
        elif action.fixed_code:
            # Unknown action type with code attached - treat as submit_fix for backward compat
            return self._handle_submit_fix(action, session)
        else:
            return MLDebugObservation(
                action_type=action_type,
                error=f"Unknown action type: {action_type}",
                episode_done=False,
                reward=None,
                turn=current_turn,
                task_id=task_id,
                task_description=TASKS[task_id].TASK_DESCRIPTION.strip(),
                buggy_code=original_buggy_code,
                episode_id=episode_id,
            )
        obs.task_id = task_id
        obs.task_description = TASKS[task_id].TASK_DESCRIPTION.strip()
        obs.buggy_code = original_buggy_code
        obs.episode_id = episode_id
        return obs
def _handle_submit_fix(self, action: MLDebugAction, session: SessionState) -> MLDebugObservation:
"""Handle the submit_fix action (terminal action)."""
# Mark as submitted in session
session.submitted = True
turns_used = session.step_count
episode_id = session.episode_id
task_id = session.task_id
original_buggy_code = session.original_buggy_code
# Get the fixed code
fixed_code = action.fixed_code or action.code or ""
if not fixed_code or not fixed_code.strip():
return MLDebugObservation(
action_type="submit_fix",
turn=turns_used,
episode_done=True,
done=True,
reward=0.001,
success=False,
grader_details={"error": "empty code submitted"},
turns_used=turns_used,
task_id=task_id,
task_description=TASKS[task_id].TASK_DESCRIPTION.strip(),
buggy_code=original_buggy_code,
error_log="empty code submitted",
error="Empty code submitted",
episode_id=episode_id,
)
# Execute and grade
run_result1 = execute_code(fixed_code)
reward1, breakdown1 = score_task(task_id, run_result1)
# Consistency check for high scores
consistency_flag = False
reward_variance = 0.0
final_reward = reward1
final_breakdown = breakdown1
run_result = run_result1
if reward1 > 0.5:
run_result2 = execute_code(fixed_code)
reward2, breakdown2 = score_task(task_id, run_result2)
reward_variance = abs(reward1 - reward2)
if reward_variance > 0.15:
consistency_flag = True
final_reward = min(reward1, reward2)
if reward2 < reward1:
final_breakdown = breakdown2
run_result = run_result2
else:
consistency_flag = False
final_reward = (reward1 + reward2) / 2.0
session.best_reward = max(session.best_reward, final_reward)
# Parse metrics
losses = parse_losses(run_result.stdout)
val_accs = parse_val_accs(run_result.stdout)
final_loss = None
if losses:
final_loss = losses[-1]
else:
match = re.search(r"FINAL_LOSS:([-\d.]+)", run_result.stdout)
if match:
final_loss = float(match.group(1))
metrics = {
"exit_code": run_result.exit_code,
"elapsed_seconds": run_result.elapsed_seconds,
"timed_out": run_result.timed_out,
"step": session.step_count,
"best_reward_so_far": session.best_reward,
"final_loss": final_loss,
"nan_count": sum(1 for x in losses if math.isnan(x) or math.isinf(x)) if losses else 0,
"val_acc": val_accs[-1] if val_accs else None,
"consistency_flag": consistency_flag,
"reward_variance": round(reward_variance, 4),
"reward_breakdown": final_breakdown,
}
session.trajectory.append({
"step": session.step_count,
"reward": final_reward,
"best_reward": session.best_reward,
"metrics": metrics,
"done": True,
"timestamp": time.time(),
})
task = TASKS[task_id]
# Return complete outputs without truncation
return MLDebugObservation(
action_type="submit_fix",
turn=turns_used,
episode_done=True,
done=True,
reward=final_reward,
success=final_reward >= 0.7,
grader_details=final_breakdown,
turns_used=turns_used,
task_id=task_id,
task_description=task.TASK_DESCRIPTION.strip(),
buggy_code=original_buggy_code,
error_log=(run_result.stdout + "\n" + run_result.stderr).strip(),
last_reward=final_reward,
metrics=metrics,
stdout=run_result.stdout,
stderr=run_result.stderr,
exit_code=run_result.exit_code,
timed_out=run_result.timed_out,
losses=losses if losses else [],
final_loss=final_loss,
episode_id=episode_id,
)
    @property
    def trajectory(self) -> list[dict]:
        # The trajectory lives in the global session store, not on the
        # per-request instance (there is no self._trajectory)
        session = _get_session(self._state.episode_id)
        return list(session.trajectory) if session else []
@property
def state(self) -> State:
return self._state
# ── Tool Definitions for /tools endpoint ───────────────────────────────────
TOOL_DEFINITIONS = [
{
"name": "execute_snippet",
"description": "Execute a short Python code snippet and return stdout, stderr, and exit code. Use for quick probes and checks.",
"action_fields": {
"action_type": "execute_snippet",
"code": "str - The Python snippet to run",
},
"observation_fields": ["stdout", "stderr", "exit_code", "timed_out", "turn"],
"constraints": {
"timeout": "10 seconds",
"network": "disabled",
"file_writes": "/tmp only",
},
},
{
"name": "inspect_tensor",
"description": "Inspect a tensor or module parameter. Returns shape, dtype, requires_grad, grad status, min/max/mean, and NaN/Inf checks.",
"action_fields": {
"action_type": "inspect_tensor",
"setup_code": "str - Code that defines the tensor/module (imports + definitions)",
"target_expression": "str - The expression to inspect (e.g., 'model.weight.grad')",
},
"observation_fields": ["shape", "dtype", "requires_grad", "grad_is_none", "min_val", "max_val", "mean_val", "is_nan", "is_inf", "error", "turn"],
},
{
"name": "run_training_probe",
"description": "Run N steps of training and return loss curve + gradient norms. Use to verify if a fix is working before submission.",
"action_fields": {
"action_type": "run_training_probe",
"code": "str - The full training script to probe",
"steps": "int - Number of training steps (1-10, default 5)",
},
"observation_fields": ["losses", "grad_norms", "optimizer_param_count", "final_loss", "loss_is_nan", "loss_is_inf", "stderr", "timed_out", "turn"],
},
{
"name": "get_variable_state",
"description": "Evaluate multiple expressions and return their repr, type, value, and shape. Useful for checking model.training, optimizer params, etc.",
"action_fields": {
"action_type": "get_variable_state",
"setup_code": "str - Setup code (imports + definitions)",
"expressions": "list[str] - List of expressions to evaluate (max 10)",
},
"observation_fields": ["results", "turn"],
},
{
"name": "inspect_diff",
"description": "Generate a unified diff between the original buggy code and your proposed fix. Review your changes before submitting.",
"action_fields": {
"action_type": "inspect_diff",
"proposed_code": "str - Your proposed fix",
},
"observation_fields": ["diff", "lines_changed", "additions", "deletions", "turn"],
},
{
"name": "submit_fix",
"description": "Submit your final fix. This is the terminal action that ends the episode and returns the grader score.",
"action_fields": {
"action_type": "submit_fix",
"fixed_code": "str - The complete fixed Python script",
},
"observation_fields": ["reward", "success", "grader_details", "turns_used", "episode_done", "turn"],
"terminal": True,
},
]
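# ── Local smoke test ───────────────────────────────────────────────────────
# A minimal sketch for exercising the environment by hand. Hedged: the
# MLDebugAction keyword fields used here (action_type, code, episode_id) are
# inferred from their usage in step() above, not from a documented API.
if __name__ == "__main__":
    env = MLDebugEnvironment()
    first = env.reset(task_id="task1")
    print("task:", first.task_id, "episode:", first.episode_id)
    probe = env.step(MLDebugAction(
        action_type="execute_snippet",
        code="print('sandbox ok')",
        episode_id=first.episode_id,
    ))
    print("snippet stdout:", probe.stdout)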