File size: 4,156 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Smoke-test the trained Repair Agent locally on one episode.

Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits
the live ForgeEnv Space for a fresh broken script, asks the model to
emit a unified diff, applies it, and prints the verifier breakdown.

Usage::

    python scripts/test_repair_agent.py --seed 7
    python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct

Requires GPU + transformers/peft. Skip this if you only want a quick
demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead.
"""
from __future__ import annotations

import argparse
import asyncio
import json

from openenv.core import GenericAction, GenericEnvClient

# Live ForgeEnv Space that serves broken-script episodes.
ENV_URL = "https://akhiilll-forgeenv.hf.space"
# HF Hub repo holding the trained LoRA adapter (overridable via --lora-repo).
LORA_REPO = "akhiilll/forgeenv-repair-agent"

# Prompt template filled via .format(script=..., error=...); instructs the model
# to emit a bare unified diff with no surrounding prose or code fences.
REPAIR_PROMPT = """\
You are a senior ML engineer fixing a HuggingFace training script that just broke.
Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the
breakage signaled by the error trace. No prose, no fences, no explanation.

# Broken script
```python
{script}
```

# Error trace
```
{error}
```

# Diff
"""


async def fetch_broken_episode(seed: int):
    """Reset the live env, trigger one breakage, and return the episode.

    Returns a ``(client, broken_script, error_trace)`` tuple; the client is
    handed back so the caller can submit a repair on the same episode.
    """
    client = GenericEnvClient(base_url=ENV_URL)
    reset_res = await client.reset(seed=seed, options={"difficulty": "medium"})
    category = reset_res.observation["target_category"]
    break_action = GenericAction(
        breakage={"action_type": "breakage", "primitive_type": category, "params": {}},
        repair=None,
    )
    step_res = await client.step(break_action)
    obs = step_res.observation
    # The env has surfaced the script under either key depending on version.
    script = obs.get("script_content") or obs.get("broken_script") or ""
    trace = obs.get("error_trace", "")
    return client, script, trace


async def submit_repair(client: GenericEnvClient, diff: str):
    """Send *diff* back to the env as a repair action and return the step result."""
    repair_action = GenericAction(
        breakage=None,
        repair={"action_type": "repair", "unified_diff": diff},
    )
    return await client.step(repair_action)


def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str:
    """Greedy-decode a unified diff from the LoRA-adapted model.

    Args:
        base_model: HF Hub id of the base causal LM.
        lora_repo: HF Hub id of the LoRA adapter to attach on top of it.
        prompt: Fully formatted repair prompt.

    Returns:
        The model's completion with the echoed prompt removed and
        surrounding whitespace stripped.
    """
    # Heavy GPU deps are imported lazily so the module can be imported
    # (e.g. for --help) without transformers/peft installed.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"loading base model: {base_model}")
    tok = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"attaching LoRA: {lora_repo}")
    model = PeftModel.from_pretrained(model, lora_repo)
    model.eval()

    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=512,
            # Greedy decoding. The previous temperature=0.0 kwarg was
            # ignored under do_sample=False (transformers emits a warning
            # for it, and sampling with temperature 0 would raise), so it
            # has been dropped.
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = out[0, inputs["input_ids"].shape[1]:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()


async def main(args) -> None:
    """Run one end-to-end repair episode and print the verifier result."""
    print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}")
    client, broken_script, error_trace = await fetch_broken_episode(args.seed)
    if not broken_script:
        raise SystemExit("env returned empty script_content; pick a different seed")
    print(f"broken script length: {len(broken_script)} chars")
    # Truncate long traces so the console stays readable.
    trace_preview = error_trace if len(error_trace) <= 200 else error_trace[:200] + '...'
    print(f"error trace        : {trace_preview}")

    prompt = REPAIR_PROMPT.format(script=broken_script, error=error_trace or "<env did not surface a trace>")
    diff = generate_diff(args.base_model, args.lora_repo, prompt)

    print("\n=== model diff ===")
    print(diff)

    print("\n=== submitting diff to env ===")
    res = await submit_repair(client, diff)
    print(f"reward: {res.reward}    done: {res.done}")
    obs = res.observation
    # Observation type varies by env version; only dicts carry the breakdown.
    breakdown = obs.get("reward_breakdown") if isinstance(obs, dict) else None
    if breakdown:
        print("reward_breakdown:")
        print(json.dumps(breakdown, indent=2))


if __name__ == "__main__":
    # CLI: seed picks the episode; model/adapter ids are overridable.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct")
    parser.add_argument("--lora-repo", default=LORA_REPO)
    asyncio.run(main(parser.parse_args()))