# baseline_agent.py
import asyncio
import json
import os
import re
from dotenv import load_dotenv
load_dotenv()
import httpx
# ── Task difficulty configuration ──────────────────────────────────────────
TASK_CONFIG = {
"task1": {"difficulty": "easy", "max_turns": 8, "description": "Broken training loop"},
"task2": {"difficulty": "medium", "max_turns": 10, "description": "Silent NaN loss"},
"task3": {"difficulty": "medium", "max_turns": 12, "description": "Label inversion"},
"task4": {"difficulty": "medium", "max_turns": 10, "description": "Wrong loss function"},
"task5": {"difficulty": "medium", "max_turns": 10, "description": "Frozen backbone"},
"task6": {"difficulty": "hard", "max_turns": 15, "description": "Input-Output mismatch"},
}
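# Each entry bounds the tool-use loop in run_single_task_detailed; unknown task
# ids fall back to {"max_turns": 10, "difficulty": "medium"}. Example lookup:
#
#     TASK_CONFIG.get("task6", {"max_turns": 10})["max_turns"]  # -> 15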
# ── System prompts ─────────────────────────────────────────────────────────
SYSTEM_PROMPT_TOOLS = """
You are an expert PyTorch debugging agent with access to debugging tools.
You receive a broken training script and must systematically debug and fix ALL bugs.
AVAILABLE TOOLS:
1. execute_snippet - Run a quick Python snippet to test hypotheses
2. inspect_tensor - Check tensor shape, dtype, gradients, NaN/Inf
3. get_variable_state - Inspect multiple variable values
4. run_training_probe - Run a few training steps to see loss curve
5. inspect_diff - Review your proposed changes before submitting
6. submit_fix - Submit your final fix (terminal action)
DEBUGGING STRATEGY:
1. First, analyze the buggy code carefully
2. Use execute_snippet or get_variable_state to verify your hypotheses
3. Use inspect_tensor to check gradient flow and tensor properties
4. Use run_training_probe to test potential fixes
5. Use inspect_diff to review your changes
6. Only submit_fix when confident
RESPONSE FORMAT - CRITICAL:
You MUST respond with ONLY a valid JSON object. No markdown, no explanation outside JSON.
For submit_fix (fixed_code MUST BE AN ACTUAL PYTHON CODE STRING, NOT JSON):
{
"reasoning": "Why this fix should work",
"action_type": "submit_fix",
"action_params": {
"fixed_code": "import torch\\nimport torch.nn as nn\\n# Full Python script here\\nprint('LOSSES:', losses)"
}
}
For execute_snippet:
{
"reasoning": "Testing hypothesis",
"action_type": "execute_snippet",
"action_params": {"code": "print('test')"}
}
For inspect_tensor:
{
"reasoning": "Check gradients",
"action_type": "inspect_tensor",
"action_params": {
"setup_code": "import torch\\nmodel = ...",
"target_expression": "model.weight.grad"
}
}
For get_variable_state:
{
"reasoning": "Verify shapes",
"action_type": "get_variable_state",
"action_params": {
"setup_code": "import torch\\ndata = ...",
"expressions": ["data.shape", "data.dtype"]
}
}
For run_training_probe:
{
"reasoning": "Test my fix",
"action_type": "run_training_probe",
"action_params": {
"code": "import torch\\n# full script",
"steps": 5
}
}
For inspect_diff:
{
"reasoning": "Review changes",
"action_type": "inspect_diff",
"action_params": {"proposed_code": "import torch\\n# your fix"}
}
CRITICAL RULES:
- Respond ONLY with valid JSON
- For submit_fix: fixed_code = PYTHON STRING (use \\n for newlines), NOT nested JSON
- Fixed code must print: LOSSES:[v1, v2, ...]
- For task3: also print VAL_ACCS:[...] and FINAL_LOSS:X.XX
- Keep torch.manual_seed() intact
""".strip()
SYSTEM_PROMPT_SIMPLE = """
You are an expert PyTorch debugging agent.
You receive a broken training script and must fix ALL bugs in it.
Rules:
- Return ONLY the complete corrected Python code, nothing else.
- No markdown, no backticks, no explanation text.
- The script must print losses in format: LOSSES:[v1, v2, ...]
- For task3, also print: VAL_ACCS:[v1,...] and FINAL_LOSS:X.XX
- Keep all torch.manual_seed() calls intact.
""".strip()
SUPPORTED_MODEL_IDS = [
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
"Qwen/Qwen2.5-Coder-3B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen2.5-Coder-14B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct",
"mistralai/Mistral-7B-Instruct-v0.3",
]
def get_model(model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct"):
from smolagents import InferenceClientModel
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
raise RuntimeError(
"HF_TOKEN is not set. Set HF_TOKEN to run /baseline with InferenceClientModel."
)
if model_id not in SUPPORTED_MODEL_IDS:
raise ValueError(
f"Unsupported model_id '{model_id}'. Supported options: {SUPPORTED_MODEL_IDS}"
)
return InferenceClientModel(
model_id=model_id,
token=hf_token,
)
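# Minimal usage sketch (assumes HF_TOKEN is exported and the chosen model is
# reachable through the HF Inference API):
#
#     model = get_model("Qwen/Qwen2.5-Coder-7B-Instruct")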
def _extract_text(response) -> str:
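    """Best-effort extraction of plain text from a model response (string, object with .content, or dict)."""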
if isinstance(response, str):
return response
if hasattr(response, "content"):
content = getattr(response, "content")
if isinstance(content, str):
return content
if isinstance(content, list):
chunks = []
for item in content:
if isinstance(item, str):
chunks.append(item)
elif isinstance(item, dict):
text = item.get("text") or item.get("content")
if text:
chunks.append(str(text))
if chunks:
return "\n".join(chunks)
if isinstance(response, dict):
text = response.get("content") or response.get("text")
if isinstance(text, str):
return text
return str(response)
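# _extract_text accepts the response shapes different client libraries return:
# a plain string, an object with a .content attribute (string or list of text
# chunks), or a dict with "content"/"text" keys; anything else is stringified.
# Illustrative calls:
#
#     _extract_text("hi")                # -> "hi"
#     _extract_text({"content": "hi"})   # -> "hi"
#     _extract_text({"text": "hi"})      # -> "hi"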
def _generate_response(model, system_prompt: str, prompt: str) -> str:
"""Generate a response from the model."""
if hasattr(model, "generate"):
generate = getattr(model, "generate")
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
try:
return _extract_text(generate(messages=messages))
except TypeError:
return _extract_text(generate(messages))
if callable(model):
try:
return _extract_text(model(prompt, system_prompt=system_prompt))
except TypeError:
return _extract_text(model(prompt))
    raise AttributeError("Model exposes neither a generate() method nor a callable inference interface")
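# _generate_response first tries a generate() method (keyword then positional
# messages), then falls back to calling the model object directly. A minimal
# hypothetical stand-in that exercises the callable fallback, e.g. for offline
# testing, could look like:
#
#     class EchoModel:
#         def __call__(self, prompt, system_prompt=None):
#             return "print('LOSSES:[1.0, 0.5]')"
#
#     _generate_response(EchoModel(), SYSTEM_PROMPT_SIMPLE, "fix me")  # -> "print('LOSSES:[1.0, 0.5]')"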
def _parse_agent_response(response: str) -> dict:
"""Parse the agent's JSON response, handling potential markdown wrapping."""
response = response.strip()
# Remove markdown code blocks if present
if response.startswith("```json"):
response = response[7:]
elif response.startswith("```"):
response = response[3:]
if response.endswith("```"):
response = response[:-3]
response = response.strip()
try:
return json.loads(response)
except json.JSONDecodeError:
# Try to extract JSON from the response
match = re.search(r'\{[\s\S]*\}', response)
if match:
try:
return json.loads(match.group())
except json.JSONDecodeError:
pass
# Fallback: treat entire response as code for submit_fix
return {
"reasoning": "Fallback: treating response as code",
"action_type": "submit_fix",
"action_params": {"fixed_code": response}
}
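# Illustrative parsing behaviour (hypothetical model replies):
#
#     _parse_agent_response('```json\n{"action_type": "execute_snippet", "action_params": {"code": "print(1)"}}\n```')
#     # -> {"action_type": "execute_snippet", "action_params": {"code": "print(1)"}}
#
#     _parse_agent_response("import torch\nprint('no JSON here')")
#     # -> fallback: {"reasoning": "...", "action_type": "submit_fix",
#     #               "action_params": {"fixed_code": "import torch\nprint('no JSON here')"}}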
def _format_tool_result(obs: dict, action_type: str) -> str:
"""Format tool observation for the agent's context."""
turn = obs.get("turn", 0)
error = obs.get("error")
if error:
return f"Turn {turn} - {action_type}: ERROR - {error}"
if action_type == "execute_snippet":
stdout = obs.get("stdout", "")[:1500]
stderr = obs.get("stderr", "")[:500]
exit_code = obs.get("exit_code", 0)
timed_out = obs.get("timed_out", False)
result = f"Turn {turn} - execute_snippet (exit={exit_code}, timed_out={timed_out}):\n"
if stdout:
result += f"stdout:\n{stdout}\n"
if stderr:
result += f"stderr:\n{stderr}\n"
return result
elif action_type == "inspect_tensor":
parts = [f"Turn {turn} - inspect_tensor:"]
if obs.get("shape"): parts.append(f" shape: {obs['shape']}")
if obs.get("dtype"): parts.append(f" dtype: {obs['dtype']}")
if obs.get("requires_grad") is not None: parts.append(f" requires_grad: {obs['requires_grad']}")
if obs.get("grad_is_none") is not None: parts.append(f" grad_is_none: {obs['grad_is_none']}")
if obs.get("min_val") is not None: parts.append(f" min: {obs['min_val']:.6f}")
if obs.get("max_val") is not None: parts.append(f" max: {obs['max_val']:.6f}")
if obs.get("mean_val") is not None: parts.append(f" mean: {obs['mean_val']:.6f}")
if obs.get("is_nan"): parts.append(" ⚠️ CONTAINS NaN")
if obs.get("is_inf"): parts.append(" ⚠️ CONTAINS Inf")
return "\n".join(parts)
elif action_type == "run_training_probe":
losses = obs.get("losses", [])
grad_norms = obs.get("grad_norms", {})
final_loss = obs.get("final_loss")
loss_is_nan = obs.get("loss_is_nan", False)
loss_is_inf = obs.get("loss_is_inf", False)
timed_out = obs.get("timed_out", False)
result = f"Turn {turn} - run_training_probe:\n"
result += f" losses: {losses[:10]}\n"
result += f" final_loss: {final_loss}\n"
if grad_norms:
result += f" grad_norms: {dict(list(grad_norms.items())[:5])}\n"
if loss_is_nan: result += " ⚠️ NaN LOSS DETECTED\n"
if loss_is_inf: result += " ⚠️ Inf LOSS DETECTED\n"
if timed_out: result += " ⚠️ TIMED OUT\n"
return result
elif action_type == "get_variable_state":
results = obs.get("results", {})
lines = [f"Turn {turn} - get_variable_state:"]
for expr, res in results.items():
if res.get("error"):
lines.append(f" {expr}: ERROR - {res['error']}")
else:
val = res.get("repr", str(res.get("value", "?")))[:100]
typ = res.get("type", "?")
shape = res.get("shape")
shape_str = f" shape={shape}" if shape else ""
lines.append(f" {expr}: {val} ({typ}{shape_str})")
return "\n".join(lines)
elif action_type == "inspect_diff":
lines_changed = obs.get("lines_changed", 0)
additions = obs.get("additions", 0)
deletions = obs.get("deletions", 0)
diff = obs.get("diff", "")[:2000]
return f"Turn {turn} - inspect_diff: {lines_changed} lines changed (+{additions}/-{deletions})\n{diff}"
return f"Turn {turn} - {action_type}: {json.dumps(obs, default=str)[:500]}"
async def run_single_task(
task_id: str,
env_url: str = "http://localhost:7860",
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
) -> float:
"""Backwards-compatible wrapper that returns just the score."""
result = await run_single_task_detailed(task_id, env_url, model_id)
return result["score"]
async def run_single_task_detailed(
task_id: str,
env_url: str = "http://localhost:7860",
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
use_tools: bool = True,
) -> dict:
"""Run the baseline agent on a single task with optional tool use."""
model = get_model(model_id)
timeout = httpx.Timeout(900.0, connect=10.0)
task_config = TASK_CONFIG.get(task_id, {"max_turns": 10, "difficulty": "medium"})
max_turns = task_config["max_turns"]
tool_history = []
attempts_log = []
async with httpx.AsyncClient(timeout=timeout) as client:
# Reset environment
reset_resp = await client.post(f"{env_url}/reset", json={"task_id": task_id})
reset_resp.raise_for_status()
        reset_data = reset_resp.json()
        obs = reset_data.get("observation", reset_data)
buggy_code = obs.get("buggy_code", "")
task_description = obs.get("task_description", "")
best_reward = 0.0
best_code = ""
best_output = ""
turn = 0
if use_tools:
# Multi-step tool-using agent
while turn < max_turns:
turn += 1
# Build context with tool history
tool_context = "\n\n".join(tool_history[-5:]) if tool_history else "No previous tool calls."
prompt = f"""
Task: {task_description}
Buggy code:
```python
{buggy_code}
```
Turn {turn}/{max_turns} - Tool History:
{tool_context}
Best reward so far: {best_reward}
Analyze the buggy code and decide your next action. Remember:
- Use tools to understand the bugs before fixing
- You have {max_turns - turn} turns remaining
- Submit your fix when confident
Respond with a JSON object containing your reasoning and action.
""".strip()
                response = ""
                try:
response = _generate_response(model, SYSTEM_PROMPT_TOOLS, prompt)
parsed = _parse_agent_response(response)
action_type = parsed.get("action_type", "submit_fix")
action_params = parsed.get("action_params", {})
reasoning = parsed.get("reasoning", "")
# Build action payload
action = {"action_type": action_type}
if action_type == "execute_snippet":
code = action_params.get("code", "print('test')")
# Validate it's actual Python, not JSON
if code.strip().startswith("{") and '"action_type"' in code:
raise ValueError("Model returned nested JSON instead of Python code")
action["code"] = code
elif action_type == "inspect_tensor":
setup_code = action_params.get("setup_code", "")
# Truncate if too long (model may be exceeding context)
if len(setup_code) > 8000:
setup_code = setup_code[:8000] + "\n# ... truncated ..."
action["setup_code"] = setup_code
action["target_expression"] = action_params.get("target_expression", "")
elif action_type == "run_training_probe":
code = action_params.get("code", buggy_code)
if len(code) > 8000:
code = code[:8000]
if code.strip().startswith("{") and '"action_type"' in code:
raise ValueError("Model returned nested JSON instead of Python code")
action["code"] = code
action["steps"] = min(action_params.get("steps", 5), 10)
elif action_type == "get_variable_state":
setup_code = action_params.get("setup_code", "")
if len(setup_code) > 8000:
setup_code = setup_code[:8000]
action["setup_code"] = setup_code
action["expressions"] = action_params.get("expressions", [])[:10]
elif action_type == "inspect_diff":
action["proposed_code"] = action_params.get("proposed_code", "")
elif action_type == "submit_fix":
fixed_code = action_params.get("fixed_code", "")
# Clean up markdown if present
if "```python" in fixed_code:
fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
elif "```json" in fixed_code:
fixed_code = fixed_code.split("```json", 1)[1].split("```", 1)[0].strip()
elif "```" in fixed_code:
fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
# CRITICAL: Detect if model returned JSON instead of Python code
if fixed_code.strip().startswith("{") and ('"action_type"' in fixed_code or '"reasoning"' in fixed_code):
raise ValueError(
"Model returned nested JSON instead of Python code for submit_fix. "
"The fixed_code field must contain actual Python code, not JSON."
)
# Validate it's not empty
if not fixed_code or len(fixed_code) < 20:
raise ValueError("submit_fix received empty or too-short code")
action["fixed_code"] = fixed_code
else:
# Unknown action, treat as submit_fix
action = {"action_type": "submit_fix", "fixed_code": str(action_params)}
# Execute action
step_resp = await client.post(f"{env_url}/step", json={"action": action})
step_resp.raise_for_status()
result = step_resp.json()
obs = result.get("observation", {})
reward = float(result.get("reward", 0.0) or 0.0)
done = result.get("done", False) or obs.get("episode_done", False)
# Format and store tool result
tool_result = _format_tool_result(obs, action_type)
tool_history.append(f"[Turn {turn}] Reasoning: {reasoning[:200]}\n{tool_result}")
# Log attempt if it was a submit
if action_type == "submit_fix":
output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""
attempts_log.append({
"turn": turn,
"action": "submit_fix",
"code": action.get("fixed_code", "")[:2000],
"output": output_log[:2000],
"reward": reward,
})
if reward > best_reward:
best_reward = reward
best_code = action.get("fixed_code", "")
best_output = output_log
if reward >= 0.95 or done:
break
else:
attempts_log.append({
"turn": turn,
"action": action_type,
"params": {k: str(v)[:200] for k, v in action_params.items()},
"result": tool_result[:500],
})
if done:
break
except ValueError as ve:
# Model error (nested JSON, empty code, etc.)
tool_history.append(f"[Turn {turn}] MODEL ERROR: {str(ve)[:300]}")
attempts_log.append({
"turn": turn,
"action": "error",
"params": {"error": str(ve), "response_preview": response[:500]},
"result": "Model generated invalid response - skipping turn",
})
# Continue to next turn, give model another chance
continue
except httpx.HTTPError as he:
# API error
tool_history.append(f"[Turn {turn}] API ERROR: {str(he)[:200]}")
attempts_log.append({
"turn": turn,
"action": "error",
"params": {"error": str(he)},
"result": "API call failed",
})
# Continue to next turn
continue
except Exception as e:
tool_history.append(f"[Turn {turn}] UNEXPECTED ERROR: {str(e)[:200]}")
attempts_log.append({
"turn": turn,
"action": "error",
"params": {"error": str(e), "type": type(e).__name__},
"result": "Unexpected error",
})
# Continue to next turn
continue
            # If no submission has scored yet and turns remain, make one final plain submission
if best_reward == 0.0 and turn < max_turns:
# Generate a simple fix without tools
simple_prompt = f"""
Task: {task_description}
Buggy code:
{buggy_code}
Tool debugging history:
{chr(10).join(tool_history[-3:])}
Generate the complete fixed Python code. Return ONLY the code, no explanation.
""".strip()
try:
fixed_code = _generate_response(model, SYSTEM_PROMPT_SIMPLE, simple_prompt)
if "```python" in fixed_code:
fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
elif "```" in fixed_code:
fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
step_resp = await client.post(f"{env_url}/step", json={
"action": {"action_type": "submit_fix", "fixed_code": fixed_code}
})
result = step_resp.json()
reward = float(result.get("reward", 0.0) or 0.0)
obs = result.get("observation", {})
if reward > best_reward:
best_reward = reward
best_code = fixed_code
best_output = obs.get("error_log", "")
except Exception:
pass
else:
# Simple direct submission (fallback mode)
for attempt in range(1, 4):
prompt = f"""
Task: {task_description}
Buggy code:
{buggy_code}
Previous execution output (if any):
{obs.get('error_log', 'None')}
Previous score: {obs.get('last_reward', 0.0)}
""".strip()
fixed_code = _generate_response(model, SYSTEM_PROMPT_SIMPLE, prompt)
if "```python" in fixed_code:
fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
elif "```" in fixed_code:
fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
step_payload = {"action": {"action_type": "submit_fix", "fixed_code": fixed_code}}
step_resp = await client.post(f"{env_url}/step", json=step_payload)
step_resp.raise_for_status()
result = step_resp.json()
reward = float(result.get("reward", 0.0) or 0.0)
obs = result.get("observation", obs)
output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""
attempts_log.append({
"attempt": attempt,
"code": fixed_code,
"output": output_log[:3000],
"reward": reward,
})
if reward > best_reward:
best_reward = reward
best_code = fixed_code
best_output = output_log
if result.get("done") or reward >= 0.95:
break
return {
"score": best_reward,
"fixed_code": best_code,
"output": best_output[:3000],
"attempts": attempts_log,
"tool_history": tool_history,
}
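# Programmatic usage sketch (assumes the debugging environment from this repo
# is already serving /reset and /step on localhost:7860 and HF_TOKEN is set):
#
#     result = asyncio.run(run_single_task_detailed("task1", "http://localhost:7860"))
#     print(result["score"], len(result["attempts"]), len(result["tool_history"]))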
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--env-url", default="http://localhost:7860")
parser.add_argument("--no-tools", action="store_true", help="Disable tool use")
parser.add_argument("--task", default=None, help="Run single task")
args = parser.parse_args()
async def main():
tasks = [args.task] if args.task else ["task1", "task2", "task3", "task4", "task5", "task6"]
scores = {}
for tid in tasks:
try:
result = await asyncio.wait_for(
run_single_task_detailed(tid, args.env_url, use_tools=not args.no_tools),
timeout=900.0
)
s = result["score"]
print(f"{tid}: {s:.4f}")
if result.get("tool_history"):
print(f" Tool calls: {len(result['tool_history'])}")
            except (asyncio.TimeoutError, TimeoutError):
s = 0.0
print(f"{tid}: TIMEOUT")
scores[tid] = round(s, 4)
if len(scores) > 1:
print(f"Average: {sum(scores.values()) / len(scores):.4f}")
asyncio.run(main())
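# Typical command-line invocations (flags mirror the argparse options above):
#
#     python baseline_agent.py --task task1
#     python baseline_agent.py --env-url http://localhost:7860 --no-tools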