# Page-capture residue (not program text): Spaces status "Sleeping",
# file size 6,156 bytes, ref ffd85e1.
"""
Simple baseline agent for WhipStudio.

This is a minimal example showing how to interact with the WhipStudio
environment using direct code submission (no tool use). It is a good starting
point for understanding the basic API.

Usage:
    python examples/simple_agent.py --env-url http://localhost:7860
    python examples/simple_agent.py --env-url https://your-space.hf.space
"""
import argparse
import os
import httpx
from openai import OpenAI
# System message that generate_fix() sends with each chat request. It tells the
# model to reply with the corrected script only (no markdown or explanation)
# and to keep the LOSSES:/VAL_ACC: output lines and the seeding calls.
SYSTEM_PROMPT = """You are an expert PyTorch debugging agent.
You receive a broken training script and must fix ALL bugs in it.
Rules:
- Return ONLY the complete corrected Python code, nothing else.
- No markdown, no backticks, no explanation text.
- The script must print losses in format: LOSSES:[v1, v2, ...]
- For tasks requiring validation metrics, also print: VAL_ACC:X.XX
- Keep all torch.manual_seed() calls intact.""".strip()
def get_client():
    """Build the OpenAI-compatible client used for LLM calls.

    The endpoint is read from ``API_BASE_URL`` (defaulting to the Hugging Face
    router URL). The key is read from ``HF_TOKEN``, falling back to
    ``OPENAI_API_KEY`` when the former is unset or empty.

    Raises:
        ValueError: If neither key variable holds a non-empty value.
    """
    env = os.environ
    token = env.get("HF_TOKEN") or env.get("OPENAI_API_KEY")
    if token:
        endpoint = env.get("API_BASE_URL", "https://router.huggingface.co/v1")
        return OpenAI(base_url=endpoint, api_key=token)
    raise ValueError("Set HF_TOKEN or OPENAI_API_KEY environment variable")
def generate_fix(client, buggy_code: str, task_description: str, error_log: str = "") -> str:
    """Ask the LLM for a corrected version of ``buggy_code``.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        buggy_code: Source of the broken training script.
        task_description: Task statement shown to the model.
        error_log: Error output from a previous attempt; appended to the
            prompt when non-empty.

    Returns:
        The model's reply with surrounding whitespace removed.

    Raises:
        ValueError: If the response has no choices or its first message
            carries no text content.
    """
    # The model is overridable through the environment.
    model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct")

    user_prompt = f"""Task: {task_description}
Buggy Code:
```python
{buggy_code}
```"""
    if error_log:
        user_prompt += f"\n\nPrevious Error:\n{error_log}"

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=4096,
        temperature=0.2,
    )

    # The message content is optional in the chat-completions schema; fail with
    # a clear message instead of an AttributeError/IndexError from `.strip()`.
    choices = response.choices
    content = choices[0].message.content if choices else None
    if content is None:
        raise ValueError("LLM response contained no text content")
    return content.strip()
def _post_json(url: str, payload: dict, timeout: float) -> dict:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON body."""
    with httpx.Client(timeout=timeout) as http_client:
        resp = http_client.post(url, json=payload)
        resp.raise_for_status()
        return resp.json()


def _strip_code_fences(text: str) -> str:
    """Return the code inside the first Markdown fence of ``text``, if any."""
    fence = "```"
    if fence not in text:
        return text
    # Prefer a block explicitly tagged as Python over an earlier untagged one.
    opening = "```python" if "```python" in text else fence
    body = text.split(opening, 1)[1]
    # Anything left on the opening-fence line that looks like a bare language
    # tag ("py", the "3" of "python3", or nothing) is not code; drop it so it
    # is not submitted as the first statement of the script.
    first_line, newline, rest = body.partition("\n")
    tag = first_line.strip()
    if newline and (not tag or tag.isalnum()):
        body = rest
    return body.split(fence, 1)[0].strip()


def run_task(env_url: str, task_id: str, client, max_attempts: int = 3) -> float:
    """Run one task, giving the LLM up to ``max_attempts`` tries to fix it.

    The environment is reset before every attempt. Each attempt asks the LLM
    for a fix (feeding back the previous attempt's error log), submits it to
    the ``/step`` endpoint, and records the reward. The loop stops early once
    a reward of at least 0.95 is reached.

    Args:
        env_url: Base URL of the environment server; a trailing slash is
            tolerated.
        task_id: Identifier sent to the ``/reset`` endpoint.
        client: LLM client passed through to ``generate_fix``.
        max_attempts: Maximum number of fix submissions.

    Returns:
        The highest reward observed across attempts (0.0 if none succeeded).

    Raises:
        Exception: Whatever the initial reset request raises (HTTP or JSON
            decoding errors); failures in later steps are printed and the
            attempt is skipped instead.
    """
    # Avoid "//reset" when the caller passes a URL ending in "/".
    base_url = env_url.rstrip("/")
    reset_url = f"{base_url}/reset"
    step_url = f"{base_url}/step"
    reset_payload = {"task_id": task_id}

    print(f"\n{'='*60}")
    print(f"Starting {task_id}")
    print(f"{'='*60}")

    # Initial reset: the observation may be nested under "observation" or be
    # the response body itself.
    reset_body = _post_json(reset_url, reset_payload, timeout=60.0)
    obs = reset_body.get("observation") or reset_body
    buggy_code = obs.get("buggy_code", "") or ""
    task_description = obs.get("task_description", "") or ""
    print(f"Task: {task_description[:100]}...")

    best_reward = 0.0
    error_log = ""

    for attempt in range(1, max_attempts + 1):
        # The first attempt reuses the reset above; later ones start fresh.
        if attempt > 1:
            try:
                _post_json(reset_url, reset_payload, timeout=30.0)
                print(f"[Reset for attempt {attempt}]")
            except Exception as e:
                print(f"Reset Error: {e}")
                continue

        print(f"\n--- Attempt {attempt}/{max_attempts} ---")

        # Generate a fix and unwrap it if the model used Markdown anyway.
        try:
            fixed_code = _strip_code_fences(
                generate_fix(client, buggy_code, task_description, error_log)
            )
        except Exception as e:
            print(f"LLM Error: {e}")
            continue

        action = {
            "action_type": "submit_fix",
            "fixed_code": fixed_code,
            "attempt_number": attempt,
        }
        try:
            result = _post_json(step_url, {"action": action}, timeout=120.0)
        except Exception as e:
            print(f"API Error: {e}")
            continue

        obs = result.get("observation") or {}
        # A missing or null reward counts as zero.
        reward = float(result.get("reward") or 0.0)
        print(f"Reward: {reward:.4f}")

        if reward > best_reward:
            best_reward = reward

        # Fed into the next attempt's prompt.
        error_log = obs.get("error_log", "") or ""

        # Only stop if we got a great score.
        if reward >= 0.95:
            print("Task solved! Stopping attempts.")
            break

    print(f"\nBest reward for {task_id}: {best_reward:.4f}")
    return best_reward
def main():
    """Parse CLI options, run every requested task, and print a score summary.

    A task that raises is reported and scored 0.0 so the remaining tasks
    still run.
    """
    cli = argparse.ArgumentParser(description="Simple WhipStudio Agent")
    cli.add_argument("--env-url", default="http://localhost:7860", help="Environment URL")
    cli.add_argument(
        "--tasks",
        nargs="+",
        default=["task1", "task2", "task3", "task4", "task5", "task6"],
        help="Task IDs to run",
    )
    cli.add_argument("--max-attempts", type=int, default=3, help="Max attempts per task")
    options = cli.parse_args()

    llm = get_client()

    scores = {}
    for name in options.tasks:
        try:
            scores[name] = run_task(options.env_url, name, llm, options.max_attempts)
        except Exception as exc:
            print(f"Error on {name}: {exc}")
            scores[name] = 0.0

    banner = "=" * 60
    print("\n" + banner)
    print("FINAL RESULTS")
    print(banner)

    for name, value in scores.items():
        # Marker thresholds: >= 0.7 pass, >= 0.3 partial, otherwise fail.
        if value >= 0.7:
            marker = "✅"
        elif value >= 0.3:
            marker = "📈"
        else:
            marker = "❌"
        print(f"{marker} {name}: {value:.4f}")

    mean = sum(scores.values(), 0.0) / len(scores) if scores else 0.0
    print(f"\nAverage: {mean:.4f}")


if __name__ == "__main__":
    main()
|