"""Baseline LLM debugging agent that talks to the PyTorch-debugging environment over HTTP."""

import asyncio
import json
import os
import re

import httpx
from dotenv import load_dotenv

load_dotenv()

# ── Task difficulty configuration ──────────────────────────────────────────
TASK_CONFIG = {
    "task1": {"difficulty": "easy", "max_turns": 8, "description": "Broken training loop"},
    "task2": {"difficulty": "medium", "max_turns": 10, "description": "Silent NaN loss"},
    "task3": {"difficulty": "medium", "max_turns": 12, "description": "Label inversion"},
    "task4": {"difficulty": "medium", "max_turns": 10, "description": "Wrong loss function"},
    "task5": {"difficulty": "medium", "max_turns": 10, "description": "Frozen backbone"},
    "task6": {"difficulty": "hard", "max_turns": 15, "description": "Input-output mismatch"},
}

# ── System prompts ──────────────────────────────────────────────────────────
SYSTEM_PROMPT_TOOLS = """
You are an expert PyTorch debugging agent with access to debugging tools.
You receive a broken training script and must systematically debug and fix ALL bugs.

AVAILABLE TOOLS:
1. execute_snippet - Run a quick Python snippet to test hypotheses
2. inspect_tensor - Check tensor shape, dtype, gradients, NaN/Inf
3. get_variable_state - Inspect multiple variable values
4. run_training_probe - Run a few training steps to see the loss curve
5. inspect_diff - Review your proposed changes before submitting
6. submit_fix - Submit your final fix (terminal action)

DEBUGGING STRATEGY:
1. First, analyze the buggy code carefully
2. Use execute_snippet or get_variable_state to verify your hypotheses
3. Use inspect_tensor to check gradient flow and tensor properties
4. Use run_training_probe to test potential fixes
5. Use inspect_diff to review your changes
6. Only submit_fix when confident

RESPONSE FORMAT - CRITICAL:
You MUST respond with ONLY a valid JSON object. No markdown, no explanation outside JSON.

For submit_fix (fixed_code MUST BE AN ACTUAL PYTHON CODE STRING, NOT JSON):
{
  "reasoning": "Why this fix should work",
  "action_type": "submit_fix",
  "action_params": {
    "fixed_code": "import torch\\nimport torch.nn as nn\\n# Full Python script here\\nprint('LOSSES:', losses)"
  }
}

For execute_snippet:
{
  "reasoning": "Testing hypothesis",
  "action_type": "execute_snippet",
  "action_params": {"code": "print('test')"}
}

For inspect_tensor:
{
  "reasoning": "Check gradients",
  "action_type": "inspect_tensor",
  "action_params": {
    "setup_code": "import torch\\nmodel = ...",
    "target_expression": "model.weight.grad"
  }
}

For get_variable_state:
{
  "reasoning": "Verify shapes",
  "action_type": "get_variable_state",
  "action_params": {
    "setup_code": "import torch\\ndata = ...",
    "expressions": ["data.shape", "data.dtype"]
  }
}

For run_training_probe:
{
  "reasoning": "Test my fix",
  "action_type": "run_training_probe",
  "action_params": {
    "code": "import torch\\n# full script",
    "steps": 5
  }
}

For inspect_diff:
{
  "reasoning": "Review changes",
  "action_type": "inspect_diff",
  "action_params": {"proposed_code": "import torch\\n# your fix"}
}

CRITICAL RULES:
- Respond ONLY with valid JSON
- For submit_fix: fixed_code = PYTHON STRING (use \\n for newlines), NOT nested JSON
- Fixed code must print: LOSSES:[v1, v2, ...]
- For task3: also print VAL_ACCS:[...] and FINAL_LOSS:X.XX
- Keep torch.manual_seed() intact
""".strip()

SYSTEM_PROMPT_SIMPLE = """
You are an expert PyTorch debugging agent.
You receive a broken training script and must fix ALL bugs in it.

Rules:
- Return ONLY the complete corrected Python code, nothing else.
- No markdown, no backticks, no explanation text.
- The script must print losses in the format: LOSSES:[v1, v2, ...]
- For task3, also print: VAL_ACCS:[v1, ...] and FINAL_LOSS:X.XX
- Keep all torch.manual_seed() calls intact.
""".strip()

# ── Model / inference helpers ───────────────────────────────────────────────
SUPPORTED_MODEL_IDS = [
    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    "Qwen/Qwen2.5-Coder-3B-Instruct",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
    "Qwen/Qwen2.5-Coder-14B-Instruct",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]


def get_model(model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct"):
    """Create an InferenceClientModel, validating the token and model id first."""
    from smolagents import InferenceClientModel

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(
            "HF_TOKEN is not set. Set HF_TOKEN to run /baseline with InferenceClientModel."
        )
    if model_id not in SUPPORTED_MODEL_IDS:
        raise ValueError(
            f"Unsupported model_id '{model_id}'. Supported options: {SUPPORTED_MODEL_IDS}"
        )
    return InferenceClientModel(
        model_id=model_id,
        token=hf_token,
    )


def _extract_text(response) -> str:
    """Best-effort extraction of plain text from a model response object."""
    if isinstance(response, str):
        return response
    if hasattr(response, "content"):
        content = getattr(response, "content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            chunks = []
            for item in content:
                if isinstance(item, str):
                    chunks.append(item)
                elif isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if text:
                        chunks.append(str(text))
            if chunks:
                return "\n".join(chunks)
    if isinstance(response, dict):
        text = response.get("content") or response.get("text")
        if isinstance(text, str):
            return text
    return str(response)


def _generate_response(model, system_prompt: str, prompt: str) -> str:
    """Generate a response from the model."""
    if hasattr(model, "generate"):
        generate = getattr(model, "generate")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            return _extract_text(generate(messages=messages))
        except TypeError:
            return _extract_text(generate(messages))
    if callable(model):
        try:
            return _extract_text(model(prompt, system_prompt=system_prompt))
        except TypeError:
            return _extract_text(model(prompt))
    raise AttributeError("Model does not support a callable or generate() inference API")


def _parse_agent_response(response: str) -> dict:
    """Parse the agent's JSON response, handling potential markdown wrapping."""
    response = response.strip()
    # Remove markdown code blocks if present
    if response.startswith("```json"):
        response = response[7:]
    elif response.startswith("```"):
        response = response[3:]
    if response.endswith("```"):
        response = response[:-3]
    response = response.strip()
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        # Try to extract a JSON object embedded in the response
        match = re.search(r"\{[\s\S]*\}", response)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                pass
    # Fallback: treat the entire response as code for submit_fix
    return {
        "reasoning": "Fallback: treating response as code",
        "action_type": "submit_fix",
        "action_params": {"fixed_code": response},
    }


def _format_tool_result(obs: dict, action_type: str) -> str:
    """Format a tool observation for the agent's context."""
    turn = obs.get("turn", 0)
    error = obs.get("error")
    if error:
        return f"Turn {turn} - {action_type}: ERROR - {error}"

    if action_type == "execute_snippet":
        stdout = obs.get("stdout", "")[:1500]
        stderr = obs.get("stderr", "")[:500]
        exit_code = obs.get("exit_code", 0)
        timed_out = obs.get("timed_out", False)
        result = f"Turn {turn} - execute_snippet (exit={exit_code}, timed_out={timed_out}):\n"
        if stdout:
            result += f"stdout:\n{stdout}\n"
        if stderr:
            result += f"stderr:\n{stderr}\n"
        return result

    elif action_type == "inspect_tensor":
        parts = [f"Turn {turn} - inspect_tensor:"]
        if obs.get("shape"):
            parts.append(f"  shape: {obs['shape']}")
        if obs.get("dtype"):
            parts.append(f"  dtype: {obs['dtype']}")
        if obs.get("requires_grad") is not None:
            parts.append(f"  requires_grad: {obs['requires_grad']}")
        if obs.get("grad_is_none") is not None:
            parts.append(f"  grad_is_none: {obs['grad_is_none']}")
        if obs.get("min_val") is not None:
            parts.append(f"  min: {obs['min_val']:.6f}")
        if obs.get("max_val") is not None:
            parts.append(f"  max: {obs['max_val']:.6f}")
        if obs.get("mean_val") is not None:
            parts.append(f"  mean: {obs['mean_val']:.6f}")
        if obs.get("is_nan"):
            parts.append("  ⚠️ CONTAINS NaN")
        if obs.get("is_inf"):
            parts.append("  ⚠️ CONTAINS Inf")
        return "\n".join(parts)

    elif action_type == "run_training_probe":
        losses = obs.get("losses", [])
        grad_norms = obs.get("grad_norms", {})
        final_loss = obs.get("final_loss")
        loss_is_nan = obs.get("loss_is_nan", False)
        loss_is_inf = obs.get("loss_is_inf", False)
        timed_out = obs.get("timed_out", False)
        result = f"Turn {turn} - run_training_probe:\n"
        result += f"  losses: {losses[:10]}\n"
        result += f"  final_loss: {final_loss}\n"
        if grad_norms:
            result += f"  grad_norms: {dict(list(grad_norms.items())[:5])}\n"
        if loss_is_nan:
            result += "  ⚠️ NaN LOSS DETECTED\n"
        if loss_is_inf:
            result += "  ⚠️ Inf LOSS DETECTED\n"
        if timed_out:
            result += "  ⚠️ TIMED OUT\n"
        return result

    elif action_type == "get_variable_state":
        results = obs.get("results", {})
        lines = [f"Turn {turn} - get_variable_state:"]
        for expr, res in results.items():
            if res.get("error"):
                lines.append(f"  {expr}: ERROR - {res['error']}")
            else:
                val = res.get("repr", str(res.get("value", "?")))[:100]
                typ = res.get("type", "?")
                shape = res.get("shape")
                shape_str = f" shape={shape}" if shape else ""
                lines.append(f"  {expr}: {val} ({typ}{shape_str})")
        return "\n".join(lines)

    elif action_type == "inspect_diff":
        lines_changed = obs.get("lines_changed", 0)
        additions = obs.get("additions", 0)
        deletions = obs.get("deletions", 0)
        diff = obs.get("diff", "")[:2000]
        return (
            f"Turn {turn} - inspect_diff: {lines_changed} lines changed "
            f"(+{additions}/-{deletions})\n{diff}"
        )

    return f"Turn {turn} - {action_type}: {json.dumps(obs, default=str)[:500]}"


# ── Baseline loop ────────────────────────────────────────────────────────────
async def run_single_task(
    task_id: str,
    env_url: str = "http://localhost:7860",
    model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
) -> float:
    """Backwards-compatible wrapper that returns just the score."""
    result = await run_single_task_detailed(task_id, env_url, model_id)
    return result["score"]


async def run_single_task_detailed(
    task_id: str,
    env_url: str = "http://localhost:7860",
    model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
    use_tools: bool = True,
) -> dict:
    """Run the baseline agent on a single task, with optional tool use."""
    model = get_model(model_id)
    timeout = httpx.Timeout(900.0, connect=10.0)
    task_config = TASK_CONFIG.get(task_id, {"max_turns": 10, "difficulty": "medium"})
    max_turns = task_config["max_turns"]

    tool_history = []
    attempts_log = []

    async with httpx.AsyncClient(timeout=timeout) as client:
        # Reset environment
        reset_resp = await client.post(f"{env_url}/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        reset_data = reset_resp.json()
        obs = reset_data.get("observation", reset_data)
        buggy_code = obs.get("buggy_code", "")
        task_description = obs.get("task_description", "")

        best_reward = 0.0
        best_code = ""
        best_output = ""
        turn = 0

        if use_tools:
            # Multi-step tool-using agent
            while turn < max_turns:
                turn += 1

                # Build context from the most recent tool calls
                tool_context = (
                    "\n\n".join(tool_history[-5:]) if tool_history else "No previous tool calls."
                )

                prompt = f"""
Task: {task_description}

Buggy code:
```python
{buggy_code}
```

Turn {turn}/{max_turns} - Tool History:
{tool_context}

Best reward so far: {best_reward}

Analyze the buggy code and decide your next action. Remember:
- Use tools to understand the bugs before fixing
- You have {max_turns - turn} turns remaining
- Submit your fix when confident

Respond with a JSON object containing your reasoning and action.
""".strip()

                try:
                    response = _generate_response(model, SYSTEM_PROMPT_TOOLS, prompt)
                    parsed = _parse_agent_response(response)
                    action_type = parsed.get("action_type", "submit_fix")
                    action_params = parsed.get("action_params", {})
                    reasoning = parsed.get("reasoning", "")

                    # Build the action payload for the environment
                    action = {"action_type": action_type}

                    if action_type == "execute_snippet":
                        code = action_params.get("code", "print('test')")
                        # Validate it's actual Python, not nested JSON
                        if code.strip().startswith("{") and '"action_type"' in code:
                            raise ValueError("Model returned nested JSON instead of Python code")
                        action["code"] = code

                    elif action_type == "inspect_tensor":
                        setup_code = action_params.get("setup_code", "")
                        # Truncate if too long (the model may be exceeding context)
                        if len(setup_code) > 8000:
                            setup_code = setup_code[:8000] + "\n# ... truncated ..."
                        action["setup_code"] = setup_code
                        action["target_expression"] = action_params.get("target_expression", "")

                    elif action_type == "run_training_probe":
                        code = action_params.get("code", buggy_code)
                        if len(code) > 8000:
                            code = code[:8000]
                        if code.strip().startswith("{") and '"action_type"' in code:
                            raise ValueError("Model returned nested JSON instead of Python code")
                        action["code"] = code
                        action["steps"] = min(action_params.get("steps", 5), 10)

                    elif action_type == "get_variable_state":
                        setup_code = action_params.get("setup_code", "")
                        if len(setup_code) > 8000:
                            setup_code = setup_code[:8000]
                        action["setup_code"] = setup_code
                        action["expressions"] = action_params.get("expressions", [])[:10]

                    elif action_type == "inspect_diff":
                        action["proposed_code"] = action_params.get("proposed_code", "")

                    elif action_type == "submit_fix":
                        fixed_code = action_params.get("fixed_code", "")
                        # Clean up markdown fences if present
                        if "```python" in fixed_code:
                            fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
                        elif "```json" in fixed_code:
                            fixed_code = fixed_code.split("```json", 1)[1].split("```", 1)[0].strip()
                        elif "```" in fixed_code:
                            fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
                        # CRITICAL: detect if the model returned JSON instead of Python code
                        if fixed_code.strip().startswith("{") and (
                            '"action_type"' in fixed_code or '"reasoning"' in fixed_code
                        ):
                            raise ValueError(
                                "Model returned nested JSON instead of Python code for submit_fix. "
                                "The fixed_code field must contain actual Python code, not JSON."
                            )
                        # Validate it's not empty or trivially short
                        if not fixed_code or len(fixed_code) < 20:
                            raise ValueError("submit_fix received empty or too-short code")
                        action["fixed_code"] = fixed_code

                    else:
                        # Unknown action: treat it as a submit_fix attempt
                        action = {"action_type": "submit_fix", "fixed_code": str(action_params)}

                    # Execute the action against the environment
                    step_resp = await client.post(f"{env_url}/step", json={"action": action})
                    step_resp.raise_for_status()
                    result = step_resp.json()
                    obs = result.get("observation", {})
                    reward = float(result.get("reward", 0.0) or 0.0)
                    done = result.get("done", False) or obs.get("episode_done", False)

                    # Format and store the tool result
                    tool_result = _format_tool_result(obs, action_type)
                    tool_history.append(f"[Turn {turn}] Reasoning: {reasoning[:200]}\n{tool_result}")

                    # Log the attempt if it was a submission
                    if action_type == "submit_fix":
                        output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""
                        attempts_log.append({
                            "turn": turn,
                            "action": "submit_fix",
                            "code": action.get("fixed_code", "")[:2000],
                            "output": output_log[:2000],
                            "reward": reward,
                        })
                        if reward > best_reward:
                            best_reward = reward
                            best_code = action.get("fixed_code", "")
                            best_output = output_log
                        if reward >= 0.95 or done:
                            break
                    else:
                        attempts_log.append({
                            "turn": turn,
                            "action": action_type,
                            "params": {k: str(v)[:200] for k, v in action_params.items()},
                            "result": tool_result[:500],
                        })
                        if done:
                            break

                except ValueError as ve:
                    # Model error (nested JSON, empty code, etc.)
                    tool_history.append(f"[Turn {turn}] MODEL ERROR: {str(ve)[:300]}")
                    attempts_log.append({
                        "turn": turn,
                        "action": "error",
                        "params": {"error": str(ve), "response_preview": response[:500]},
                        "result": "Model generated invalid response - skipping turn",
                    })
                    # Continue to the next turn and give the model another chance
                    continue
                except httpx.HTTPError as he:
                    # Environment API error
                    tool_history.append(f"[Turn {turn}] API ERROR: {str(he)[:200]}")
                    attempts_log.append({
                        "turn": turn,
                        "action": "error",
                        "params": {"error": str(he)},
                        "result": "API call failed",
                    })
                    continue
                except Exception as e:
                    tool_history.append(f"[Turn {turn}] UNEXPECTED ERROR: {str(e)[:200]}")
                    attempts_log.append({
                        "turn": turn,
                        "action": "error",
                        "params": {"error": str(e), "type": type(e).__name__},
                        "result": "Unexpected error",
                    })
                    continue

            # If no submission has scored yet and turns remain, make one final
            # tool-free submission using the debugging history gathered so far.
            if best_reward == 0.0 and turn < max_turns:
                simple_prompt = f"""
Task: {task_description}

Buggy code:
{buggy_code}

Tool debugging history:
{chr(10).join(tool_history[-3:])}

Generate the complete fixed Python code. Return ONLY the code, no explanation.
""".strip()
                try:
                    fixed_code = _generate_response(model, SYSTEM_PROMPT_SIMPLE, simple_prompt)
                    if "```python" in fixed_code:
                        fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
                    elif "```" in fixed_code:
                        fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()
                    step_resp = await client.post(f"{env_url}/step", json={
                        "action": {"action_type": "submit_fix", "fixed_code": fixed_code}
                    })
                    result = step_resp.json()
                    reward = float(result.get("reward", 0.0) or 0.0)
                    obs = result.get("observation", {})
                    if reward > best_reward:
                        best_reward = reward
                        best_code = fixed_code
                        best_output = obs.get("error_log", "")
                except Exception:
                    pass

        else:
            # Simple direct-submission fallback mode (no tools)
            for attempt in range(1, 4):
                prompt = f"""
Task: {task_description}

Buggy code:
{buggy_code}

Previous execution output (if any):
{obs.get('error_log', 'None')}

Previous score: {obs.get('last_reward', 0.0)}
""".strip()
                fixed_code = _generate_response(model, SYSTEM_PROMPT_SIMPLE, prompt)
                if "```python" in fixed_code:
                    fixed_code = fixed_code.split("```python", 1)[1].split("```", 1)[0].strip()
                elif "```" in fixed_code:
                    fixed_code = fixed_code.split("```", 1)[1].split("```", 1)[0].strip()

                step_payload = {"action": {"action_type": "submit_fix", "fixed_code": fixed_code}}
                step_resp = await client.post(f"{env_url}/step", json=step_payload)
                step_resp.raise_for_status()
                result = step_resp.json()
                reward = float(result.get("reward", 0.0) or 0.0)
                obs = result.get("observation", obs)
                output_log = obs.get("error_log", "") if isinstance(obs, dict) else ""

                attempts_log.append({
                    "attempt": attempt,
                    "code": fixed_code,
                    "output": output_log[:3000],
                    "reward": reward,
                })
                if reward > best_reward:
                    best_reward = reward
                    best_code = fixed_code
                    best_output = output_log
                if result.get("done") or reward >= 0.95:
                    break

    return {
        "score": best_reward,
        "fixed_code": best_code,
        "output": best_output[:3000],
        "attempts": attempts_log,
        "tool_history": tool_history,
    }


# ── CLI entry point ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://localhost:7860")
    parser.add_argument("--no-tools", action="store_true", help="Disable tool use")
    parser.add_argument("--task", default=None, help="Run a single task")
    args = parser.parse_args()

    async def main():
        tasks = [args.task] if args.task else ["task1", "task2", "task3", "task4", "task5", "task6"]
        scores = {}
        for tid in tasks:
            try:
                result = await asyncio.wait_for(
                    run_single_task_detailed(tid, args.env_url, use_tools=not args.no_tools),
                    timeout=900.0,
                )
                s = result["score"]
                print(f"{tid}: {s:.4f}")
                if result.get("tool_history"):
                    print(f"  Tool calls: {len(result['tool_history'])}")
            # asyncio.TimeoutError is an alias of TimeoutError on Python 3.11+,
            # but catching it explicitly also covers older interpreters.
            except asyncio.TimeoutError:
                s = 0.0
                print(f"{tid}: TIMEOUT")
            scores[tid] = round(s, 4)
        if len(scores) > 1:
            print(f"Average: {sum(scores.values()) / len(scores):.4f}")

    asyncio.run(main())