# whipstudio/examples/tool_agent.py — uploaded by Amogh-kal1 via huggingface_hub (commit ffd85e1, verified)
"""
Tool-using agent for WhipStudio.
This example demonstrates how to use WhipStudio's debugging tools
to iteratively analyze and fix bugs before submitting a final solution.
The agent uses:
1. execute_snippet - To test hypotheses about the code
2. inspect_tensor - To check tensor shapes, dtypes, and gradients
3. get_variable_state - To evaluate multiple expressions
4. run_training_probe - To test potential fixes
5. inspect_diff - To review changes before submission
6. submit_fix - Final submission
Usage:
python examples/tool_agent.py --env-url http://localhost:7860 --task task1
python examples/tool_agent.py --env-url https://your-space.hf.space --task task6
"""
import argparse
import json
import os
import re
import httpx
from openai import OpenAI
# System prompt sent as the "system" message on every turn of run_tool_agent().
# It lists the six tools and requires a single JSON object reply with
# "reasoning", "action_type" and "action_params" keys, which is what
# parse_agent_response() decodes and run_tool_agent() dispatches on.
# NOTE: "\\n" inside this (non-raw) string is an escaped backslash, so the model
# sees the two characters "\n" -- the JSON escape it must use inside fixed_code --
# rather than a real newline.
SYSTEM_PROMPT = """You are an expert PyTorch debugging agent that fixes buggy ML code.
You have debugging tools available, but your PRIMARY GOAL is to SUBMIT A FIX.
AVAILABLE TOOLS:
1. execute_snippet - Run Python code to test hypotheses
2. inspect_tensor - Check tensor shape, dtype, gradients, NaN/Inf
3. get_variable_state - Inspect multiple variable values
4. run_training_probe - Run training steps to see loss curve
5. inspect_diff - Preview your changes before submitting
6. submit_fix - Submit your final fix (ALWAYS DO THIS BEFORE RUNNING OUT OF TURNS)
RESPONSE FORMAT - ALWAYS respond with valid JSON only:
{
"reasoning": "Brief analysis",
"action_type": "tool_name",
"action_params": { ... }
}
For submit_fix, the fixed_code must be COMPLETE working Python code:
{
"reasoning": "Fix explanation",
"action_type": "submit_fix",
"action_params": {
"fixed_code": "import torch\\nimport torch.nn as nn\\n..."
}
}
CRITICAL RULES:
1. You MUST call submit_fix before your turns run out
2. If you have 2 or fewer turns remaining, IMMEDIATELY submit your fix
3. Don't waste turns - analyze, test once if needed, then SUBMIT
4. Fixed code must print: LOSSES:[v1, v2, ...] or similar metrics
5. Keep torch.manual_seed() calls intact for reproducibility
6. Use \\n for newlines in code strings
EFFICIENT DEBUGGING:
- Turn 1-2: Analyze bug, maybe one quick test
- Turn 3+: SUBMIT YOUR FIX - don't keep testing!""".strip()
def get_client():
    """Build an OpenAI-compatible chat client from environment variables.

    The endpoint comes from ``API_BASE_URL`` (default: the Hugging Face
    router). The credential is the first non-empty value of ``HF_TOKEN`` or
    ``OPENAI_API_KEY``.

    Raises:
        ValueError: if neither credential variable is set (or both are empty).
    """
    env = os.environ
    base_url = env.get("API_BASE_URL", "https://router.huggingface.co/v1")
    token = env.get("HF_TOKEN") or env.get("OPENAI_API_KEY")
    if token:
        return OpenAI(base_url=base_url, api_key=token)
    raise ValueError("Set HF_TOKEN or OPENAI_API_KEY environment variable")
def parse_agent_response(response: str) -> dict:
    """Parse the agent's JSON reply into an action dict.

    Tolerates a Markdown code fence around the reply and prose around the JSON
    object. Only a JSON *object* is accepted: a reply that decodes to a list,
    string, number or ``null`` is treated as unparseable, because callers read
    the result with ``dict.get``.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The decoded JSON object. If no object can be recovered, returns a
        fallback ``submit_fix`` action whose ``fixed_code`` is the
        fence-stripped response text.
    """
    text = response.strip()
    # Remove a Markdown code fence the model may have wrapped around the JSON.
    if text.startswith("```json"):
        text = text[7:]
    elif text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    # First try the whole text, then the outermost {...} span (reply with prose).
    candidates = [text]
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        candidates.append(match.group())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # Best effort: treat the raw reply as the proposed fix.
    return {
        "reasoning": "Fallback: could not parse response",
        "action_type": "submit_fix",
        "action_params": {"fixed_code": text},
    }
def _text_field(obs: dict, key: str, limit: int) -> str:
    """Return ``obs[key]`` as text truncated to *limit* characters.

    A missing key or an explicit ``None`` yields ``""`` so callers can slice
    and concatenate without type checks.
    """
    value = obs.get(key)
    return "" if value is None else str(value)[:limit]


def format_tool_result(obs: dict, action_type: str) -> str:
    """Render an environment observation as a compact, readable summary.

    The returned text is printed and also appended to the tool history that is
    fed back to the LLM, so long fields are truncated. Fields that the server
    sends as ``null`` are treated as empty instead of raising on slicing.

    Args:
        obs: The ``observation`` dict returned by the ``/step`` endpoint.
        action_type: The action that produced *obs*; selects the layout.

    Returns:
        A (possibly multi-line) string prefixed with ``[Turn N]``.
    """
    turn = obs.get("turn", 0)
    error = obs.get("error")
    if error:
        return f"[Turn {turn}] {action_type} ERROR: {error}"

    if action_type == "execute_snippet":
        stdout = _text_field(obs, "stdout", 2000)
        stderr = _text_field(obs, "stderr", 1000)
        exit_code = obs.get("exit_code", 0)
        result = f"[Turn {turn}] execute_snippet (exit={exit_code}):\n{stdout}"
        if stderr:
            result += f"\nSTDERR:\n{stderr}"
        return result

    if action_type == "inspect_tensor":
        return "\n".join([
            f"[Turn {turn}] inspect_tensor:",
            f"shape: {obs.get('shape')}",
            f"dtype: {obs.get('dtype')}",
            f"requires_grad: {obs.get('requires_grad')}",
            f"grad_is_none: {obs.get('grad_is_none')}",
            f"min/max/mean: {obs.get('min_val')}/{obs.get('max_val')}/{obs.get('mean_val')}",
            f"is_nan: {obs.get('is_nan')}, is_inf: {obs.get('is_inf')}",
        ])

    if action_type == "run_training_probe":
        losses = (obs.get("losses") or [])[:10]
        final_loss = obs.get("final_loss")
        grad_norms = obs.get("grad_norms", {})
        return f"[Turn {turn}] run_training_probe:\n losses: {losses}\n final_loss: {final_loss}\n grad_norms: {grad_norms}"

    if action_type == "get_variable_state":
        results = obs.get("results") or {}
        lines = [f"[Turn {turn}] get_variable_state:"]
        for expr, res in results.items():
            if not isinstance(res, dict):
                # Unexpected shape from the server: show it rather than crash.
                lines.append(f" {expr}: {str(res)[:200]}")
            elif res.get("error"):
                lines.append(f" {expr}: ERROR - {res['error']}")
            else:
                val = res.get("repr")
                if val is None:
                    val = res.get("value", "?")
                lines.append(f" {expr}: {str(val)[:200]}")
        return "\n".join(lines)

    if action_type == "inspect_diff":
        lines_changed = obs.get("lines_changed", 0)
        additions = obs.get("additions", 0)
        deletions = obs.get("deletions", 0)
        diff = _text_field(obs, "diff", 500)
        return f"[Turn {turn}] inspect_diff: {lines_changed} lines changed (+{additions}/-{deletions})\n{diff}"

    if action_type == "submit_fix":
        reward = obs.get("reward", 0.0)
        return f"[Turn {turn}] submit_fix: reward={reward}"

    # Unknown action: dump the raw observation (truncated).
    return f"[Turn {turn}] {action_type}: {json.dumps(obs, default=str)[:500]}"
def _strip_code_fences(code: str) -> str:
    """Return the contents of the first Markdown code fence in *code*, else *code*."""
    if "```python" in code:
        return code.split("```python", 1)[1].split("```", 1)[0].strip()
    if "```" in code:
        return code.split("```", 1)[1].split("```", 1)[0].strip()
    return code


def _str_param(params: dict, key: str, default: str) -> str:
    """Return ``params[key]`` if it is a string, otherwise *default*."""
    value = params.get(key)
    return value if isinstance(value, str) else default


def _pick_code(params: dict, default):
    """Return the first non-empty code-like param (fixed_code, code, proposed_code), else *default*."""
    for key in ("fixed_code", "code", "proposed_code"):
        value = params.get(key)
        if isinstance(value, str) and value:
            return value
    return default


def _build_action(action_type: str, params: dict, buggy_code, episode_id) -> dict:
    """Build the ``/step`` action payload for *action_type* from the agent's params.

    Code and setup strings for probes and inspections are truncated to 8000
    characters. Training probes are capped at 10 steps, and at most 10
    expressions are evaluated.
    """
    # Always include episode_id: it is needed for session persistence in HTTP mode.
    action = {"action_type": action_type, "episode_id": episode_id}
    if action_type == "execute_snippet":
        action["code"] = _str_param(params, "code", "print('test')")
    elif action_type == "inspect_tensor":
        action["setup_code"] = _str_param(params, "setup_code", "")[:8000]
        action["target_expression"] = _str_param(params, "target_expression", "")
    elif action_type == "run_training_probe":
        action["code"] = _str_param(params, "code", buggy_code)[:8000]
        # The model may send steps as a string ("5") or garbage; coerce safely.
        try:
            steps = int(params.get("steps", 5))
        except (TypeError, ValueError):
            steps = 5
        action["steps"] = min(steps, 10)
    elif action_type == "get_variable_state":
        action["setup_code"] = _str_param(params, "setup_code", "")[:8000]
        expressions = params.get("expressions", [])
        if isinstance(expressions, str):
            expressions = [expressions]
        elif not isinstance(expressions, list):
            expressions = []
        action["expressions"] = expressions[:10]
    elif action_type == "inspect_diff":
        action["proposed_code"] = _str_param(params, "proposed_code", "")
    elif action_type == "submit_fix":
        # Accept the same alternate keys as the forced last-turn submission.
        action["fixed_code"] = _strip_code_fences(_pick_code(params, ""))
    return action


def _urgency_note(turns_remaining: int) -> str:
    """Return the prompt suffix that nudges the model to submit as turns run out."""
    if turns_remaining == 0:
        return "\n⚠️ THIS IS YOUR LAST TURN! You MUST call submit_fix NOW with your best fix."
    if turns_remaining <= 2:
        return f"\n⚠️ ONLY {turns_remaining} TURN(S) LEFT! Submit your fix soon!"
    return ""


def _build_user_prompt(task_description, buggy_code, turn: int, max_turns: int,
                       best_reward: float, tool_history: list) -> str:
    """Compose the per-turn user message: task, buggy code, budget and recent tool results."""
    turns_remaining = max_turns - turn
    # Only the last five tool results are included to keep the prompt short.
    history_text = "\n".join(tool_history[-5:]) if tool_history else "No previous tool calls."
    lines = [
        f"Task: {task_description}",
        "Buggy Code:",
        "```python",
        f"{buggy_code}",
        "```",
        f"Turn {turn}/{max_turns}",
        f"Best reward so far: {best_reward}",
        f"Turns remaining: {turns_remaining}{_urgency_note(turns_remaining)}",
        "Tool History:",
        history_text,
        "Analyze the code and submit your fix. Don't waste turns on unnecessary testing.",
    ]
    return "\n".join(lines).strip()


def run_tool_agent(env_url: str, task_id: str, client, max_turns: int = 10) -> float:
    """Run the tool-using agent on one task and return the best reward seen.

    The function resets the environment for *task_id*. Then, for up to
    *max_turns* turns, it asks the LLM for a JSON action, sends the action to
    ``/step`` and feeds a summary of the result into the next prompt. On the
    final turn, any non-submit action is replaced by ``submit_fix``, provided
    the LLM call on that turn succeeded. LLM and ``/step`` failures are
    recorded in the tool history, and the loop moves on to the next turn.

    Args:
        env_url: Base URL of the WhipStudio environment server.
        task_id: Task to reset the environment to (e.g. ``"task1"``).
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        max_turns: Maximum number of LLM/tool round-trips.

    Returns:
        The highest reward returned by ``/step`` during the episode (0.0 if none).

    Raises:
        httpx.HTTPError: if the initial ``/reset`` request fails or returns an
            error status.
    """
    print(f"\n{'='*60}")
    print(f"Tool Agent: {task_id}")
    print(f"{'='*60}")
    model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-3B-Instruct")

    # Reset the environment and read the task.
    with httpx.Client(timeout=60.0) as http_client:
        resp = http_client.post(f"{env_url}/reset", json={"task_id": task_id})
        resp.raise_for_status()
        payload = resp.json()
    obs = payload.get("observation", payload)
    buggy_code = obs.get("buggy_code", "")
    task_description = obs.get("task_description", "")
    episode_id = obs.get("episode_id", "")  # Track episode_id for session persistence
    print(f"Task: {task_description[:100]}...")
    print(f"Episode ID: {episode_id[:16]}..." if episode_id else "No episode_id")

    tool_history = []
    best_reward = 0.0
    for turn in range(1, max_turns + 1):
        print(f"\n--- Turn {turn}/{max_turns} ---")
        turns_remaining = max_turns - turn
        user_prompt = _build_user_prompt(
            task_description, buggy_code, turn, max_turns, best_reward, tool_history
        )

        # Get LLM response.
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=4096,
                temperature=0.2,
            )
            response_text = response.choices[0].message.content.strip()
            parsed = parse_agent_response(response_text)
        except Exception as e:
            print(f"LLM Error: {e}")
            tool_history.append(f"[Turn {turn}] LLM ERROR: {e}")
            continue

        # The model's JSON is untrusted: normalise types before using it.
        if not isinstance(parsed, dict):
            parsed = {}
        action_type = str(parsed.get("action_type") or "submit_fix")
        action_params = parsed.get("action_params")
        if not isinstance(action_params, dict):
            action_params = {}
        reasoning = str(parsed.get("reasoning") or "")[:100]

        # Force submit_fix on the last turn if the agent didn't choose it.
        if turns_remaining == 0 and action_type != "submit_fix":
            print(f"[OVERRIDE] Last turn - forcing submit_fix instead of {action_type}")
            action_type = "submit_fix"
            # Reuse any code the agent proposed this turn, else the original code.
            action_params = {"fixed_code": _pick_code(action_params, buggy_code)}

        print(f"Action: {action_type}")
        print(f"Reasoning: {reasoning}...")

        action = _build_action(action_type, action_params, buggy_code, episode_id)

        # Execute action.
        try:
            with httpx.Client(timeout=120.0) as http_client:
                resp = http_client.post(f"{env_url}/step", json={"action": action})
                resp.raise_for_status()
                result = resp.json()
        except Exception as e:
            print(f"API Error: {e}")
            tool_history.append(f"[Turn {turn}] API ERROR: {e}")
            continue

        obs = result.get("observation") or {}
        reward = float(result.get("reward", 0.0) or 0.0)
        done = result.get("done", False) or obs.get("episode_done", False)

        tool_result = format_tool_result(obs, action_type)
        tool_history.append(tool_result)
        print(tool_result)

        best_reward = max(best_reward, reward)
        if action_type == "submit_fix":
            print(f"Reward: {reward:.4f}")
            # A near-perfect fix ends the episode early.
            if reward >= 0.95:
                break
        if done:
            break

    print(f"\nFinal reward for {task_id}: {best_reward:.4f}")
    return best_reward
def main():
    """Command-line entry point: run the agent on one task or on all six.

    When more than one task was run, prints a per-task summary and the average
    score. A task that raises is reported and scored 0.0, so the remaining
    tasks still run.
    """
    parser = argparse.ArgumentParser(description="Tool-using WhipStudio Agent")
    parser.add_argument("--env-url", default="http://localhost:7860", help="Environment URL")
    parser.add_argument("--task", default="task1", help="Task ID to run")
    parser.add_argument("--all-tasks", action="store_true", help="Run all tasks")
    parser.add_argument("--max-turns", type=int, default=10, help="Max turns per task")
    args = parser.parse_args()

    client = get_client()
    if args.all_tasks:
        task_ids = [f"task{i}" for i in range(1, 7)]
    else:
        task_ids = [args.task]

    scores = {}
    for task_id in task_ids:
        try:
            scores[task_id] = run_tool_agent(args.env_url, task_id, client, args.max_turns)
        except Exception as e:
            # Keep going: one failing task should not abort the whole sweep.
            print(f"Error on {task_id}: {e}")
            scores[task_id] = 0.0

    if len(scores) <= 1:
        return

    banner = "=" * 60
    print("\n" + banner)
    print("FINAL RESULTS")
    print(banner)
    for task_id, score in scores.items():
        if score >= 0.7:
            marker = "✅"
        elif score >= 0.3:
            marker = "📈"
        else:
            marker = "❌"
        print(f"{marker} {task_id}: {score:.4f}")
    average = sum(scores.values()) / len(scores)
    print(f"\nAverage: {average:.4f}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()