| """Smoke-test the trained Repair Agent locally on one episode. | |
| Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits | |
| the live ForgeEnv Space for a fresh broken script, asks the model to | |
| emit a unified diff, applies it, and prints the verifier breakdown. | |
| Usage:: | |
| python scripts/test_repair_agent.py --seed 7 | |
| python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct | |
| Requires GPU + transformers/peft. Skip this if you only want a quick | |
| demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import json | |
| from openenv.core import GenericAction, GenericEnvClient | |
| ENV_URL = "https://akhiilll-forgeenv.hf.space" | |
| LORA_REPO = "akhiilll/forgeenv-repair-agent" | |
| REPAIR_PROMPT = """\ | |
| You are a senior ML engineer fixing a HuggingFace training script that just broke. | |
| Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the | |
| breakage signaled by the error trace. No prose, no fences, no explanation. | |
| # Broken script | |
| ```python | |
| {script} | |
| ``` | |
| # Error trace | |
| ``` | |
| {error} | |
| ``` | |
| # Diff | |
| """ | |
async def fetch_broken_episode(seed: int):
    """Reset the live env at *seed* and trigger one breakage step.

    Returns the connected client plus the broken script text and the error
    trace the environment surfaced (either string may be empty).
    """
    client = GenericEnvClient(base_url=ENV_URL)
    reset_res = await client.reset(seed=seed, options={"difficulty": "medium"})
    # The reset observation names which breakage category to request next.
    category = reset_res.observation["target_category"]
    breakage_action = GenericAction(
        breakage={"action_type": "breakage", "primitive_type": category, "params": {}},
        repair=None,
    )
    step_res = await client.step(breakage_action)
    obs = step_res.observation
    # Different env versions expose the script under different keys.
    script = obs.get("script_content") or obs.get("broken_script") or ""
    trace = obs.get("error_trace", "")
    return client, script, trace
async def submit_repair(client: GenericEnvClient, diff: str):
    """Send *diff* back to the env as a repair action and return the step result."""
    repair_action = GenericAction(
        breakage=None,
        repair={"action_type": "repair", "unified_diff": diff},
    )
    return await client.step(repair_action)
def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str:
    """Greedy-decode a unified diff from the LoRA-adapted model.

    Heavy dependencies (torch/transformers/peft) are imported lazily so the
    module can be imported on machines without the GPU stack installed.

    Args:
        base_model: Hub id of the base causal LM.
        lora_repo: Hub id of the LoRA adapter to attach on top of it.
        prompt: Fully formatted repair prompt (see ``REPAIR_PROMPT``).

    Returns:
        The generated completion with the prompt tokens stripped and
        surrounding whitespace trimmed.
    """
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"loading base model: {base_model}")
    tok = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"attaching LoRA: {lora_repo}")
    model = PeftModel.from_pretrained(model, lora_repo)
    model.eval()
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=512,
            # Fix: the original also passed temperature=0.0 here. With
            # do_sample=False decoding is greedy and temperature is ignored;
            # transformers warns about the unused flag and recent versions
            # validate that a supplied temperature is strictly positive, so
            # the argument is dropped entirely.
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )
    # Slice off the prompt tokens so only the completion is decoded.
    text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return text.strip()
async def main(args) -> None:
    """Run one episode end-to-end: fetch breakage, generate a diff, submit, report."""
    print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}")
    client, broken_script, error_trace = await fetch_broken_episode(args.seed)
    if not broken_script:
        raise SystemExit("env returned empty script_content; pick a different seed")
    print(f"broken script length: {len(broken_script)} chars")
    # Show at most 200 chars of the trace so the console stays readable.
    trace_preview = error_trace if len(error_trace) <= 200 else error_trace[:200] + "..."
    print(f"error trace : {trace_preview}")
    prompt = REPAIR_PROMPT.format(
        script=broken_script,
        error=error_trace or "<env did not surface a trace>",
    )
    diff = generate_diff(args.base_model, args.lora_repo, prompt)
    print("\n=== model diff ===")
    print(diff)
    print("\n=== submitting diff to env ===")
    res = await submit_repair(client, diff)
    print(f"reward: {res.reward} done: {res.done}")
    obs = res.observation
    breakdown = obs.get("reward_breakdown") if isinstance(obs, dict) else None
    if breakdown:
        print("reward_breakdown:")
        print(json.dumps(breakdown, indent=2))
| if __name__ == "__main__": | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--seed", type=int, default=7) | |
| p.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct") | |
| p.add_argument("--lora-repo", default=LORA_REPO) | |
| args = p.parse_args() | |
| asyncio.run(main(args)) | |