# forgeenv-source/scripts/test_repair_agent.py
# Snapshot of the ForgeEnv training-job source (uploaded by akhiilll, rev a15535e).
"""Smoke-test the trained Repair Agent locally on one episode.
Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits
the live ForgeEnv Space for a fresh broken script, asks the model to
emit a unified diff, applies it, and prints the verifier breakdown.
Usage::
python scripts/test_repair_agent.py --seed 7
python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct
Requires GPU + transformers/peft. Skip this if you only want a quick
demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead.
"""
from __future__ import annotations
import argparse
import asyncio
import json
from openenv.core import GenericAction, GenericEnvClient
# Live ForgeEnv Space that serves broken-script episodes over HTTP.
ENV_URL = "https://akhiilll-forgeenv.hf.space"
# Default LoRA adapter repo produced by the training job (overridable via --lora-repo).
LORA_REPO = "akhiilll/forgeenv-repair-agent"
# Prompt template fed to the repair model; {script} and {error} are filled in main().
# The model is instructed to emit ONLY a unified diff — no prose or code fences.
REPAIR_PROMPT = """\
You are a senior ML engineer fixing a HuggingFace training script that just broke.
Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the
breakage signaled by the error trace. No prose, no fences, no explanation.
# Broken script
```python
{script}
```
# Error trace
```
{error}
```
# Diff
"""
async def fetch_broken_episode(seed: int):
    """Reset the live env, trigger one breakage, and return the episode pieces.

    Returns a ``(client, script_text, error_trace)`` tuple; the client is kept
    open so the caller can submit a repair on the same episode.
    """
    client = GenericEnvClient(base_url=ENV_URL)
    reset_res = await client.reset(seed=seed, options={"difficulty": "medium"})
    category = reset_res.observation["target_category"]
    action = GenericAction(
        breakage={"action_type": "breakage", "primitive_type": category, "params": {}},
        repair=None,
    )
    step_res = await client.step(action)
    obs = step_res.observation
    # Env versions differ on the key name for the broken script; try both.
    script = obs.get("script_content") or obs.get("broken_script") or ""
    trace = obs.get("error_trace", "")
    return client, script, trace
async def submit_repair(client: GenericEnvClient, diff: str):
    """Send the model's unified diff back to the env and return the step result."""
    action = GenericAction(
        breakage=None,
        repair={"action_type": "repair", "unified_diff": diff},
    )
    return await client.step(action)
def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str:
    """Greedy-decode a unified diff from the LoRA-adapted repair model.

    Args:
        base_model: HF repo id of the base causal LM.
        lora_repo: HF repo id of the LoRA adapter to attach on top.
        prompt: Fully formatted repair prompt (see ``REPAIR_PROMPT``).

    Returns:
        The generated completion with the prompt tokens stripped and
        surrounding whitespace trimmed.
    """
    # Heavy deps imported lazily so the module can be inspected without a GPU stack.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"loading base model: {base_model}")
    tok = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"attaching LoRA: {lora_repo}")
    model = PeftModel.from_pretrained(model, lora_repo)
    model.eval()

    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=512,
            # Greedy decoding. Do NOT also pass temperature=0.0: greedy mode
            # ignores it, and recent transformers releases warn (and with
            # strict GenerationConfig validation, reject) sampling params
            # supplied alongside do_sample=False.
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )
    # Slice off the prompt tokens so only the new completion is decoded.
    text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return text.strip()
async def main(args) -> None:
    """Run one repair episode end to end: fetch breakage, generate diff, submit, report."""
    print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}")
    client, broken_script, error_trace = await fetch_broken_episode(args.seed)
    if not broken_script:
        raise SystemExit("env returned empty script_content; pick a different seed")

    print(f"broken script length: {len(broken_script)} chars")
    # Truncate long traces so the console stays readable.
    trace_preview = error_trace if len(error_trace) <= 200 else error_trace[:200] + '...'
    print(f"error trace : {trace_preview}")

    prompt = REPAIR_PROMPT.format(
        script=broken_script,
        error=error_trace or "<env did not surface a trace>",
    )
    diff = generate_diff(args.base_model, args.lora_repo, prompt)
    print("\n=== model diff ===")
    print(diff)

    print("\n=== submitting diff to env ===")
    res = await submit_repair(client, diff)
    print(f"reward: {res.reward} done: {res.done}")

    obs = res.observation
    breakdown = obs.get("reward_breakdown") if isinstance(obs, dict) else None
    if breakdown:
        print("reward_breakdown:")
        print(json.dumps(breakdown, indent=2))
if __name__ == "__main__":
    # CLI: seed selects the episode; model/adapter repos are overridable.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct")
    parser.add_argument("--lora-repo", default=LORA_REPO)
    asyncio.run(main(parser.parse_args()))