| """ | |
| Tool-using agent for WhipStudio. | |
| This example demonstrates how to use WhipStudio's debugging tools | |
| to iteratively analyze and fix bugs before submitting a final solution. | |
| The agent uses: | |
| 1. execute_snippet - To test hypotheses about the code | |
| 2. inspect_tensor - To check tensor shapes, dtypes, and gradients | |
| 3. get_variable_state - To evaluate multiple expressions | |
| 4. run_training_probe - To test potential fixes | |
| 5. inspect_diff - To review changes before submission | |
| 6. submit_fix - Final submission | |
| Usage: | |
| python examples/tool_agent.py --env-url http://localhost:7860 --task task1 | |
| python examples/tool_agent.py --env-url https://your-space.hf.space --task task6 | |
| """ | |
import argparse
import json
import os
import re

import httpx
from openai import OpenAI

| SYSTEM_PROMPT = """You are an expert PyTorch debugging agent that fixes buggy ML code. | |
| You have debugging tools available, but your PRIMARY GOAL is to SUBMIT A FIX. | |
| AVAILABLE TOOLS: | |
| 1. execute_snippet - Run Python code to test hypotheses | |
| 2. inspect_tensor - Check tensor shape, dtype, gradients, NaN/Inf | |
| 3. get_variable_state - Inspect multiple variable values | |
| 4. run_training_probe - Run training steps to see loss curve | |
| 5. inspect_diff - Preview your changes before submitting | |
| 6. submit_fix - Submit your final fix (ALWAYS DO THIS BEFORE RUNNING OUT OF TURNS) | |
| RESPONSE FORMAT - ALWAYS respond with valid JSON only: | |
| { | |
| "reasoning": "Brief analysis", | |
| "action_type": "tool_name", | |
| "action_params": { ... } | |
| } | |
| For submit_fix, the fixed_code must be COMPLETE working Python code: | |
| { | |
| "reasoning": "Fix explanation", | |
| "action_type": "submit_fix", | |
| "action_params": { | |
| "fixed_code": "import torch\\nimport torch.nn as nn\\n..." | |
| } | |
| } | |
| CRITICAL RULES: | |
| 1. You MUST call submit_fix before your turns run out | |
| 2. If you have 2 or fewer turns remaining, IMMEDIATELY submit your fix | |
| 3. Don't waste turns - analyze, test once if needed, then SUBMIT | |
| 4. Fixed code must print: LOSSES:[v1, v2, ...] or similar metrics | |
| 5. Keep torch.manual_seed() calls intact for reproducibility | |
| 6. Use \\n for newlines in code strings | |
| EFFICIENT DEBUGGING: | |
| - Turn 1-2: Analyze bug, maybe one quick test | |
| - Turn 3+: SUBMIT YOUR FIX - don't keep testing!""".strip() | |
def get_client():
    """Initialize OpenAI-compatible client."""
    api_base = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("Set HF_TOKEN or OPENAI_API_KEY environment variable")
    return OpenAI(base_url=api_base, api_key=api_key)

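# Shell setup sketch (placeholder values; API_BASE_URL and MODEL_NAME simply
# override the defaults read here and in run_tool_agent):
#   export HF_TOKEN=hf_xxx                                 # or OPENAI_API_KEY
#   export API_BASE_URL=https://router.huggingface.co/v1   # optional
#   export MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct       # optional
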
def parse_agent_response(response: str) -> dict:
    """Parse JSON response from agent, handling common formatting issues."""
    response = response.strip()

    # Remove markdown code blocks
    if response.startswith("```json"):
        response = response[7:]
    elif response.startswith("```"):
        response = response[3:]
    if response.endswith("```"):
        response = response[:-3]
    response = response.strip()

    try:
        return json.loads(response)
    except json.JSONDecodeError:
        # Try to extract JSON from response
        match = re.search(r'\{[\s\S]*\}', response)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                pass

    # Fallback
    return {
        "reasoning": "Fallback: could not parse response",
        "action_type": "submit_fix",
        "action_params": {"fixed_code": response},
    }

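# Illustrative model outputs this parser handles (hypothetical examples):
#   '{"reasoning": "...", "action_type": "inspect_diff", "action_params": {}}'  -> parsed directly
#   '```json\n{"action_type": "execute_snippet", ...}\n```'                     -> fences stripped first
#   'My plan: {"action_type": "submit_fix", ...}'                               -> first {...} extracted via regex
#   any other text                                                              -> wrapped as a raw submit_fix
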
def format_tool_result(obs: dict, action_type: str) -> str:
    """Format tool result for display."""
    turn = obs.get("turn", 0)
    error = obs.get("error")
    if error:
        return f"[Turn {turn}] {action_type} ERROR: {error}"

    if action_type == "execute_snippet":
        stdout = obs.get("stdout", "")[:2000]  # Increased from 500
        stderr = obs.get("stderr", "")[:1000]  # Increased from 200
        exit_code = obs.get("exit_code", 0)
        result = f"[Turn {turn}] execute_snippet (exit={exit_code}):\n{stdout}"
        if stderr:
            result += f"\nSTDERR:\n{stderr}"
        return result
    elif action_type == "inspect_tensor":
        return f"""[Turn {turn}] inspect_tensor:
 shape: {obs.get('shape')}
 dtype: {obs.get('dtype')}
 requires_grad: {obs.get('requires_grad')}
 grad_is_none: {obs.get('grad_is_none')}
 min/max/mean: {obs.get('min_val')}/{obs.get('max_val')}/{obs.get('mean_val')}
 is_nan: {obs.get('is_nan')}, is_inf: {obs.get('is_inf')}"""
    elif action_type == "run_training_probe":
        losses = obs.get("losses", [])[:10]
        final_loss = obs.get("final_loss")
        grad_norms = obs.get("grad_norms", {})
        return f"[Turn {turn}] run_training_probe:\n losses: {losses}\n final_loss: {final_loss}\n grad_norms: {grad_norms}"
    elif action_type == "get_variable_state":
        results = obs.get("results", {})
        lines = [f"[Turn {turn}] get_variable_state:"]
        for expr, res in results.items():
            if res.get("error"):
                lines.append(f" {expr}: ERROR - {res['error']}")
            else:
                val = res.get("repr", str(res.get("value", "?")))[:200]  # Increased from 80
                lines.append(f" {expr}: {val}")
        return "\n".join(lines)
    elif action_type == "inspect_diff":
        lines_changed = obs.get("lines_changed", 0)
        additions = obs.get("additions", 0)
        deletions = obs.get("deletions", 0)
        diff = obs.get("diff", "")[:500]
        return f"[Turn {turn}] inspect_diff: {lines_changed} lines changed (+{additions}/-{deletions})\n{diff}"
    elif action_type == "submit_fix":
        reward = obs.get("reward", 0.0)
        return f"[Turn {turn}] submit_fix: reward={reward}"

    return f"[Turn {turn}] {action_type}: {json.dumps(obs, default=str)[:500]}"

def run_tool_agent(env_url: str, task_id: str, client, max_turns: int = 10) -> float:
    """Run a tool-using agent on a single task."""
    print(f"\n{'='*60}")
    print(f"Tool Agent: {task_id}")
    print(f"{'='*60}")

    model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-3B-Instruct")

    # Reset environment
    with httpx.Client(timeout=60.0) as http_client:
        resp = http_client.post(f"{env_url}/reset", json={"task_id": task_id})
        resp.raise_for_status()
        obs = resp.json().get("observation", resp.json())

    buggy_code = obs.get("buggy_code", "")
    task_description = obs.get("task_description", "")
    episode_id = obs.get("episode_id", "")  # Track episode_id for session persistence

    print(f"Task: {task_description[:100]}...")
    print(f"Episode ID: {episode_id[:16]}..." if episode_id else "No episode_id")

    tool_history = []
    best_reward = 0.0

    for turn in range(1, max_turns + 1):
        print(f"\n--- Turn {turn}/{max_turns} ---")
        turns_remaining = max_turns - turn

        # Build context
        history_text = "\n".join(tool_history[-5:]) if tool_history else "No previous tool calls."

        # Force submission on last turn
        if turns_remaining == 0:
            urgency = "\n⚠️ THIS IS YOUR LAST TURN! You MUST call submit_fix NOW with your best fix."
        elif turns_remaining <= 2:
            urgency = f"\n⚠️ ONLY {turns_remaining} TURN(S) LEFT! Submit your fix soon!"
        else:
            urgency = ""

        user_prompt = f"""Task: {task_description}

Buggy Code:
```python
{buggy_code}
```

Turn {turn}/{max_turns}
Best reward so far: {best_reward}
Turns remaining: {turns_remaining}{urgency}

Tool History:
{history_text}

Analyze the code and submit your fix. Don't waste turns on unnecessary testing.""".strip()

        # Get LLM response
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=4096,
                temperature=0.2,
            )
            response_text = response.choices[0].message.content.strip()
            parsed = parse_agent_response(response_text)
        except Exception as e:
            print(f"LLM Error: {e}")
            tool_history.append(f"[Turn {turn}] LLM ERROR: {e}")
            continue

        action_type = parsed.get("action_type", "submit_fix")
        action_params = parsed.get("action_params", {})
        reasoning = parsed.get("reasoning", "")[:100]

        # Force submit_fix on last turn if agent didn't choose it
        if turns_remaining == 0 and action_type != "submit_fix":
            print(f"[OVERRIDE] Last turn - forcing submit_fix instead of {action_type}")
            action_type = "submit_fix"
            # Use fixed_code from action_params if available, otherwise use any code param
            fixed_code = (
                action_params.get("fixed_code")
                or action_params.get("code")
                or action_params.get("proposed_code")
                or buggy_code
            )
            action_params = {"fixed_code": fixed_code}

        print(f"Action: {action_type}")
        print(f"Reasoning: {reasoning}...")

        # Build action payload - ALWAYS include episode_id for session tracking
        action = {
            "action_type": action_type,
            "episode_id": episode_id,  # Critical for session persistence in HTTP mode
        }
        if action_type == "execute_snippet":
            action["code"] = action_params.get("code", "print('test')")
        elif action_type == "inspect_tensor":
            action["setup_code"] = action_params.get("setup_code", "")[:8000]
            action["target_expression"] = action_params.get("target_expression", "")
        elif action_type == "run_training_probe":
            action["code"] = action_params.get("code", buggy_code)[:8000]
            action["steps"] = min(action_params.get("steps", 5), 10)
        elif action_type == "get_variable_state":
            action["setup_code"] = action_params.get("setup_code", "")[:8000]
            action["expressions"] = action_params.get("expressions", [])[:10]
        elif action_type == "inspect_diff":
            action["proposed_code"] = action_params.get("proposed_code", "")
        elif action_type == "submit_fix":
            fixed_code = action_params.get("fixed_code", "")
            # Clean markdown
            if "```python" in fixed_code:
                fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
            elif "```" in fixed_code:
                fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
            action["fixed_code"] = fixed_code

        # Execute action
        try:
            with httpx.Client(timeout=120.0) as http_client:
                resp = http_client.post(f"{env_url}/step", json={"action": action})
                resp.raise_for_status()
                result = resp.json()
        except Exception as e:
            print(f"API Error: {e}")
            tool_history.append(f"[Turn {turn}] API ERROR: {e}")
            continue

        obs = result.get("observation", {})
        reward = float(result.get("reward", 0.0) or 0.0)
        done = result.get("done", False) or obs.get("episode_done", False)

        # Format and store result
        tool_result = format_tool_result(obs, action_type)
        tool_history.append(tool_result)
        print(tool_result)

        if reward > best_reward:
            best_reward = reward

        if action_type == "submit_fix":
            print(f"Reward: {reward:.4f}")
            if reward >= 0.95 or done:
                break
        if done:
            break

    print(f"\nFinal reward for {task_id}: {best_reward:.4f}")
    return best_reward

def main():
    parser = argparse.ArgumentParser(description="Tool-using WhipStudio Agent")
    parser.add_argument("--env-url", default="http://localhost:7860", help="Environment URL")
    parser.add_argument("--task", default="task1", help="Task ID to run")
    parser.add_argument("--all-tasks", action="store_true", help="Run all tasks")
    parser.add_argument("--max-turns", type=int, default=10, help="Max turns per task")
    args = parser.parse_args()

    client = get_client()
    tasks = ["task1", "task2", "task3", "task4", "task5", "task6"] if args.all_tasks else [args.task]

    results = {}
    for task_id in tasks:
        try:
            score = run_tool_agent(args.env_url, task_id, client, args.max_turns)
            results[task_id] = score
        except Exception as e:
            print(f"Error on {task_id}: {e}")
            results[task_id] = 0.0

    if len(results) > 1:
        print("\n" + "="*60)
        print("FINAL RESULTS")
        print("="*60)
        total = 0.0
        for task_id, score in results.items():
            emoji = "✅" if score >= 0.7 else ("📈" if score >= 0.3 else "❌")
            print(f"{emoji} {task_id}: {score:.4f}")
            total += score
        avg = total / len(results) if results else 0.0
        print(f"\nAverage: {avg:.4f}")


if __name__ == "__main__":
    main()