# NOTE: "Spaces: Sleeping" status banner removed — it was residue from the
# Hugging Face Space page this file was scraped from, not part of the source.
| """ | |
| Simple baseline agent for WhipStudio. | |
| This is a minimal example showing how to interact with the WhipStudio environment | |
| using direct code submission (no tool use). Good for understanding the basic API. | |
| Usage: | |
| python examples/simple_agent.py --env-url http://localhost:7860 | |
| python examples/simple_agent.py --env-url https://your-space.hf.space | |
| """ | |
| import argparse | |
| import os | |
| import httpx | |
| from openai import OpenAI | |
| SYSTEM_PROMPT = """You are an expert PyTorch debugging agent. | |
| You receive a broken training script and must fix ALL bugs in it. | |
| Rules: | |
| - Return ONLY the complete corrected Python code, nothing else. | |
| - No markdown, no backticks, no explanation text. | |
| - The script must print losses in format: LOSSES:[v1, v2, ...] | |
| - For tasks requiring validation metrics, also print: VAL_ACC:X.XX | |
| - Keep all torch.manual_seed() calls intact.""".strip() | |
def get_client():
    """Build an OpenAI-compatible client from environment variables.

    Reads the endpoint from ``API_BASE_URL`` (defaulting to the Hugging Face
    router) and the credential from ``HF_TOKEN``, falling back to
    ``OPENAI_API_KEY``.

    Returns:
        A configured ``OpenAI`` client.

    Raises:
        ValueError: if no non-empty API key is found in the environment.
    """
    base_url = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
    token = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    # `not token` (rather than `is None`) so an empty-string env var also raises.
    if not token:
        raise ValueError("Set HF_TOKEN or OPENAI_API_KEY environment variable")
    return OpenAI(base_url=base_url, api_key=token)
def generate_fix(client, buggy_code: str, task_description: str, error_log: str = "") -> str:
    """Use the LLM to generate a fixed version of ``buggy_code``.

    Args:
        client: OpenAI-compatible client (see ``get_client``).
        buggy_code: The broken training script to repair.
        task_description: Human-readable description of the task.
        error_log: Error output from a previous failed attempt; when non-empty
            it is appended to the prompt so the model can learn from it.

    Returns:
        The model's response text with surrounding whitespace stripped
        (empty string if the model returned no content).
    """
    # Model is overridable via env so the same script works on any provider.
    model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct")
    user_prompt = f"""Task: {task_description}
Buggy Code:
```python
{buggy_code}
```"""
    if error_log:
        user_prompt += f"\n\nPrevious Error:\n{error_log}"
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=4096,
        # Low temperature: debugging wants determinism, not creativity.
        temperature=0.2,
    )
    # The SDK's `content` may be None (e.g. refusal / empty completion);
    # guard so we return "" instead of raising AttributeError on .strip().
    content = response.choices[0].message.content or ""
    return content.strip()
def _extract_code(text: str) -> str:
    """Strip a surrounding markdown code fence from an LLM response, if any."""
    if "```python" in text:
        return text.split("```python", 1)[1].split("```", 1)[0].strip()
    if "```" in text:
        return text.split("```", 1)[1].split("```", 1)[0].strip()
    return text


def _reset_env(env_url: str, task_id: str, timeout: float) -> dict:
    """POST /reset for ``task_id`` and return the observation payload.

    Raises ``httpx.HTTPStatusError`` on a non-2xx response.
    """
    with httpx.Client(timeout=timeout) as http_client:
        resp = http_client.post(f"{env_url}/reset", json={"task_id": task_id})
        resp.raise_for_status()
        # Parse the body once; some servers nest the data under "observation",
        # others return it at the top level.
        payload = resp.json()
    return payload.get("observation", payload)


def run_task(env_url: str, task_id: str, client, max_attempts: int = 3) -> float:
    """Run a single task against the environment with multiple attempts.

    Each attempt resets the environment, asks the LLM for a fix (feeding back
    the previous attempt's error log), submits it via /step, and records the
    reward. Stops early once reward >= 0.95.

    Args:
        env_url: Base URL of the WhipStudio environment.
        task_id: Identifier of the task to run.
        client: OpenAI-compatible client used by ``generate_fix``.
        max_attempts: Maximum number of submission attempts.

    Returns:
        The best reward achieved across all attempts (0.0 if none succeeded).
    """
    print(f"\n{'='*60}")
    print(f"Starting {task_id}")
    print(f"{'='*60}")
    # Initial reset fetches the task definition (buggy code + description).
    obs = _reset_env(env_url, task_id, timeout=60.0)
    buggy_code = obs.get("buggy_code", "")
    task_description = obs.get("task_description", "")
    print(f"Task: {task_description[:100]}...")
    best_reward = 0.0
    error_log = ""
    for attempt in range(1, max_attempts + 1):
        # Reset for each retry (the first attempt was already reset above).
        if attempt > 1:
            try:
                obs = _reset_env(env_url, task_id, timeout=30.0)
                print(f"[Reset for attempt {attempt}]")
            except Exception as e:
                print(f"Reset Error: {e}")
                continue
        print(f"\n--- Attempt {attempt}/{max_attempts} ---")
        # Generate fix using LLM; strip any markdown fences it added despite
        # the system prompt forbidding them.
        try:
            fixed_code = _extract_code(
                generate_fix(client, buggy_code, task_description, error_log)
            )
        except Exception as e:
            print(f"LLM Error: {e}")
            continue
        # Submit fix
        action = {
            "action_type": "submit_fix",
            "fixed_code": fixed_code,
            "attempt_number": attempt,
        }
        try:
            with httpx.Client(timeout=120.0) as http_client:
                resp = http_client.post(f"{env_url}/step", json={"action": action})
                resp.raise_for_status()
                result = resp.json()
        except Exception as e:
            print(f"API Error: {e}")
            continue
        obs = result.get("observation", {})
        # `or 0.0` guards against an explicit null reward in the response.
        reward = float(result.get("reward", 0.0) or 0.0)
        print(f"Reward: {reward:.4f}")
        best_reward = max(best_reward, reward)
        # Feed the latest error log into the next attempt's prompt.
        error_log = obs.get("error_log", "")
        # Only stop if we got a great score.
        if reward >= 0.95:
            print("Task solved! Stopping attempts.")
            break
    print(f"\nBest reward for {task_id}: {best_reward:.4f}")
    return best_reward
def main():
    """CLI entry point: run every requested task and print a score summary."""
    parser = argparse.ArgumentParser(description="Simple WhipStudio Agent")
    parser.add_argument("--env-url", default="http://localhost:7860", help="Environment URL")
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["task1", "task2", "task3", "task4", "task5", "task6"],
        help="Task IDs to run",
    )
    parser.add_argument("--max-attempts", type=int, default=3, help="Max attempts per task")
    args = parser.parse_args()

    client = get_client()
    results = {}
    for tid in args.tasks:
        # A crash on one task must not abort the whole run; score it 0.0.
        try:
            results[tid] = run_task(args.env_url, tid, client, args.max_attempts)
        except Exception as exc:
            print(f"Error on {tid}: {exc}")
            results[tid] = 0.0

    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    total = 0.0
    for tid, score in results.items():
        if score >= 0.7:
            emoji = "✅"
        elif score >= 0.3:
            emoji = "📈"
        else:
            emoji = "❌"
        print(f"{emoji} {tid}: {score:.4f}")
        total += score
    average = total / len(results) if results else 0.0
    print(f"\nAverage: {average:.4f}")


if __name__ == "__main__":
    main()