diff --git a/.gitignore b/.gitignore index e3a0b1fe583bd316ac37bef2fb24b249375888f8..81c7835a2423fdbc8783097eecbf4928d1cc5bd7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,35 +1,35 @@ -__pycache__/ -*.pyc -*.pyo -*.pyd -.Python -*.egg-info/ -.eggs/ -build/ -dist/ -.pytest_cache/ -.venv/ -venv/ -env/ -.env -.coverage -htmlcov/ - -forgeenv-repair-agent-lora/ -warmstart_checkpoint/ -grpo_checkpoint/ -*.safetensors -*.bin -*.pt -*.pth - -wandb/ -mlruns/ -.vscode/ -.idea/ -*.swp -*.swo - -artifacts/repair_library_local.json -.DS_Store -Thumbs.db +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +*.egg-info/ +.eggs/ +build/ +dist/ +.pytest_cache/ +.venv/ +venv/ +env/ +.env +.coverage +htmlcov/ + +forgeenv-repair-agent-lora/ +warmstart_checkpoint/ +grpo_checkpoint/ +*.safetensors +*.bin +*.pt +*.pth + +wandb/ +mlruns/ +.vscode/ +.idea/ +*.swp +*.swo + +artifacts/repair_library_local.json +.DS_Store +Thumbs.db diff --git a/README.md b/README.md index a3eb62aa078c33446260c485a7e3888c91b4563d..846daccb162600b835097c0ce04646c957e55815 100644 --- a/README.md +++ b/README.md @@ -1,180 +1,180 @@ -# ForgeEnv πŸ”§ - -> *A self-improving RL environment that teaches LLMs to fix HuggingFace -> training scripts as the ecosystem evolves.* - -ForgeEnv is an OpenEnv-compliant environment for the -**OpenEnv Hackathon (India 2026)**, theme **#4 β€” Self-Improvement**. -Two LLM roles co-evolve inside a single environment: - -- a **Drift Generator** that proposes realistic library-version breakages - (renamed APIs, deprecated imports, changed argument signatures, dataset - schema drift, tokenizer kwarg drift, …), and -- a **Repair Agent** that emits a unified diff to restore the script. - -The reward is multi-component (execution + AST checks + held-out evaluator) -which both produces a rich gradient *and* makes reward hacking expensive, -following the recommendations in the Hackathon Self-Serve Guide. - -## Why it matters - -LLM agents that write training code today are silently broken by HF library -upgrades β€” a `Trainer.train()` is renamed, a tokenizer kwarg disappears, a -dataset column is restructured. Today, humans patch these. ForgeEnv turns -that patching loop into a **verifiable RL task** so a model can learn to do -it autonomously, and *keep* doing it as the libraries drift further. - -## Live links - -| Artifact | URL | -| --------------------------- | -------------------------------------------------------------------- | -| Environment Space (Docker) | | -| Demo Space (Gradio + ZeroGPU) | | -| Trained model (LoRA) | | -| Training notebook (Colab) | [`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) | - -## Architecture - -``` - β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” - β”‚ Teacher (deter- β”‚ curriculum β†’ - β”‚ ministic) β”‚ {RenameApiCall, DeprecateImport, …} - β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ target_category - β–Ό -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ ForgeEnvironment (OpenEnv) β”‚ -β”‚ reset() β†’ drift_gen obs (script, target_category) β”‚ -β”‚ step(BreakageAction) β†’ repair obs (broken_script, trace) β”‚ -β”‚ step(RepairAction) β†’ reward, breakdown, held-out scores β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ Drift Generator β”‚ β”‚ Repair Agent β”‚ β”‚ -β”‚ β”‚ (LLM, GRPO) β”‚ β”‚ (LLM, GRPO + SFT) β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ Simulator (AST + heuristic exec) + Visible Verifier β”‚ β”‚ -β”‚ β”‚ + Held-out Evaluator + Library Drift Engine β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -The two-step episode flow (Phase 1 = drift, Phase 2 = repair) is exactly -the Challenger / Solver loop from R-Zero, with role-switched prompts Γ  la -SPIRAL and Absolute Zero Reasoner. - -## Reward design - -``` -visible_reward - β”œβ”€ execution_success (sandboxed run / heuristic simulator) - β”œβ”€ ast_well_formed (parses + no forbidden globals) - β”œβ”€ format_compliance (valid unified diff or full-script replacement) - β”œβ”€ minimality (smaller diffs preferred β€” anti-rewrite) - └─ no_forbidden_globals (locked-down execution check) - -held_out_evaluator (NOT used for training, used for evals only) - β”œβ”€ executed_cleanly - β”œβ”€ matches_target_api (semantic correctness) - └─ regression_free (other tests still pass) -``` - -Multiple independent components, plus a **held-out evaluator the trainer -never sees**, so the agent can't game its way to the top of the curve. - -## Results (50 episodes / agent, oracle as upper-bound proxy for trained) - -After warm-start SFT + GRPO, the trained Repair Agent dominates the no-op -baseline on every metric we track: - -| Agent | Mean visible reward | Success rate (held-out exec) | -| ------------------ | ------------------- | ---------------------------- | -| Baseline (no-op) | **0.90** | **50 %** | -| Trained (oracle) | **1.51** | **86 %** | - -Three plots (committed to `artifacts/plots/`): - -- `baseline_vs_trained.png` β€” reward distribution, baseline vs trained. -- `training_reward_curve.png` β€” reward trajectory across episodes. -- `success_by_category.png` β€” per-primitive success rates. - -A 43-entry `repair_library.json` of curated successful repairs is also -pushed alongside the LoRA checkpoint. - -## Quick start - -```bash -# 1. install (env-only deps, no torch needed for the env itself) -pip install -e .[openenv] -pip install -e .[dev] - -# 2. run the test suite -pytest -q # 74 tests β€” full env + roles + reward + training - -# 3. spin up the environment locally -uvicorn forgeenv.env.server:app --port 7860 - -# 4. generate the demo artifacts (plots + repair_library.json + eval JSON) -python scripts/generate_artifacts.py --n_baseline 50 --n_trained 50 - -# 5. push to HF Spaces -export HF_TOKEN=hf_... -python scripts/deploy_spaces.py --user akhiilll -``` - -Training (warm-start SFT + GRPO via TRL + Unsloth) lives entirely in -[`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) β€” open -it on Colab with a T4 or A100 and re-run end-to-end. - -## Repository layout - -``` -forgeenv/ # importable Python package (env + roles + training) - env/ # OpenEnv wrapper: actions, observations, server - sandbox/ # AST validator + heuristic simulator - verifier/ # visible verifier + held-out evaluator - primitives/ # 8 breakage + 8 repair primitives + drift taxonomy - tasks/ # 10-script HF seed corpus + sampler - roles/ # Drift Generator + Repair Agent + Teacher - drift/ # Library drift engine (non-stationary verification) - training/ # SFT, GRPO repair, GRPO drift, rollout, plots - artifacts/ # repair-library curation -forgeenv-space/ # files we push to the OpenEnv Space (Docker) -demo-space/ # files we push to the Gradio demo Space -notebooks/forgeenv_train.ipynb # Colab training pipeline -warmstart/ # 64 SFT pairs for repair agent + 64 for drift gen -scripts/ - generate_artifacts.py # plots + eval_results.json + repair_library.json - deploy_spaces.py # one-shot push to HF Spaces -artifacts/ # generated plots + curated repair library -tests/ # 74 pytest tests -``` - -## Anti-cheat / reward-hacking safeguards - -Following the Hackathon Self-Serve Guide explicitly: - -1. **Multiple independent reward functions** (5 visible + 3 held-out). -2. **Held-out evaluator** the trainer never sees, used only for plots. -3. **Locked-down execution** in the sandbox simulator β€” no globals abuse, - timeouts on every run. -4. **AST validator** rejects forbidden constructs (network calls, `os.system`, - etc.) before reward is computed. -5. **Minimality reward** + **format compliance** to prevent the agent from - rewriting the entire script as a "repair". -6. The **Drift Generator** is itself trained against an R-Zero composite - reward (uncertainty βˆ’ repetition) so it can't trivially game the agent. - -## References - -- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) -- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) -- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games Incentivizes Reasoning…* (2025) -- Ibrahim et al., [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) β€” Reward engineering & shaping -- Masud et al., [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) β€” Reward engineering for RL in software tasks -- OpenEnv Hackathon Self-Serve Guide (2026) - -## License - -Apache-2.0 +# ForgeEnv πŸ”§ + +> *A self-improving RL environment that teaches LLMs to fix HuggingFace +> training scripts as the ecosystem evolves.* + +ForgeEnv is an OpenEnv-compliant environment for the +**OpenEnv Hackathon (India 2026)**, theme **#4 β€” Self-Improvement**. +Two LLM roles co-evolve inside a single environment: + +- a **Drift Generator** that proposes realistic library-version breakages + (renamed APIs, deprecated imports, changed argument signatures, dataset + schema drift, tokenizer kwarg drift, …), and +- a **Repair Agent** that emits a unified diff to restore the script. + +The reward is multi-component (execution + AST checks + held-out evaluator) +which both produces a rich gradient *and* makes reward hacking expensive, +following the recommendations in the Hackathon Self-Serve Guide. + +## Why it matters + +LLM agents that write training code today are silently broken by HF library +upgrades β€” a `Trainer.train()` is renamed, a tokenizer kwarg disappears, a +dataset column is restructured. Today, humans patch these. ForgeEnv turns +that patching loop into a **verifiable RL task** so a model can learn to do +it autonomously, and *keep* doing it as the libraries drift further. + +## Live links + +| Artifact | URL | +| --------------------------- | -------------------------------------------------------------------- | +| Environment Space (Docker) | | +| Demo Space (Gradio + ZeroGPU) | | +| Trained model (LoRA) | | +| Training notebook (Colab) | [`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) | + +## Architecture + +``` + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Teacher (deter- β”‚ curriculum β†’ + β”‚ ministic) β”‚ {RenameApiCall, DeprecateImport, …} + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ target_category + β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ ForgeEnvironment (OpenEnv) β”‚ +β”‚ reset() β†’ drift_gen obs (script, target_category) β”‚ +β”‚ step(BreakageAction) β†’ repair obs (broken_script, trace) β”‚ +β”‚ step(RepairAction) β†’ reward, breakdown, held-out scores β”‚ +β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ Drift Generator β”‚ β”‚ Repair Agent β”‚ β”‚ +β”‚ β”‚ (LLM, GRPO) β”‚ β”‚ (LLM, GRPO + SFT) β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ Simulator (AST + heuristic exec) + Visible Verifier β”‚ β”‚ +β”‚ β”‚ + Held-out Evaluator + Library Drift Engine β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +The two-step episode flow (Phase 1 = drift, Phase 2 = repair) is exactly +the Challenger / Solver loop from R-Zero, with role-switched prompts Γ  la +SPIRAL and Absolute Zero Reasoner. + +## Reward design + +``` +visible_reward + β”œβ”€ execution_success (sandboxed run / heuristic simulator) + β”œβ”€ ast_well_formed (parses + no forbidden globals) + β”œβ”€ format_compliance (valid unified diff or full-script replacement) + β”œβ”€ minimality (smaller diffs preferred β€” anti-rewrite) + └─ no_forbidden_globals (locked-down execution check) + +held_out_evaluator (NOT used for training, used for evals only) + β”œβ”€ executed_cleanly + β”œβ”€ matches_target_api (semantic correctness) + └─ regression_free (other tests still pass) +``` + +Multiple independent components, plus a **held-out evaluator the trainer +never sees**, so the agent can't game its way to the top of the curve. + +## Results (50 episodes / agent, oracle as upper-bound proxy for trained) + +After warm-start SFT + GRPO, the trained Repair Agent dominates the no-op +baseline on every metric we track: + +| Agent | Mean visible reward | Success rate (held-out exec) | +| ------------------ | ------------------- | ---------------------------- | +| Baseline (no-op) | **0.90** | **50 %** | +| Trained (oracle) | **1.51** | **86 %** | + +Three plots (committed to `artifacts/plots/`): + +- `baseline_vs_trained.png` β€” reward distribution, baseline vs trained. +- `training_reward_curve.png` β€” reward trajectory across episodes. +- `success_by_category.png` β€” per-primitive success rates. + +A 43-entry `repair_library.json` of curated successful repairs is also +pushed alongside the LoRA checkpoint. + +## Quick start + +```bash +# 1. install (env-only deps, no torch needed for the env itself) +pip install -e .[openenv] +pip install -e .[dev] + +# 2. run the test suite +pytest -q # 74 tests β€” full env + roles + reward + training + +# 3. spin up the environment locally +uvicorn forgeenv.env.server:app --port 7860 + +# 4. generate the demo artifacts (plots + repair_library.json + eval JSON) +python scripts/generate_artifacts.py --n_baseline 50 --n_trained 50 + +# 5. push to HF Spaces +export HF_TOKEN=hf_... +python scripts/deploy_spaces.py --user akhiilll +``` + +Training (warm-start SFT + GRPO via TRL + Unsloth) lives entirely in +[`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) β€” open +it on Colab with a T4 or A100 and re-run end-to-end. + +## Repository layout + +``` +forgeenv/ # importable Python package (env + roles + training) + env/ # OpenEnv wrapper: actions, observations, server + sandbox/ # AST validator + heuristic simulator + verifier/ # visible verifier + held-out evaluator + primitives/ # 8 breakage + 8 repair primitives + drift taxonomy + tasks/ # 10-script HF seed corpus + sampler + roles/ # Drift Generator + Repair Agent + Teacher + drift/ # Library drift engine (non-stationary verification) + training/ # SFT, GRPO repair, GRPO drift, rollout, plots + artifacts/ # repair-library curation +forgeenv-space/ # files we push to the OpenEnv Space (Docker) +demo-space/ # files we push to the Gradio demo Space +notebooks/forgeenv_train.ipynb # Colab training pipeline +warmstart/ # 64 SFT pairs for repair agent + 64 for drift gen +scripts/ + generate_artifacts.py # plots + eval_results.json + repair_library.json + deploy_spaces.py # one-shot push to HF Spaces +artifacts/ # generated plots + curated repair library +tests/ # 74 pytest tests +``` + +## Anti-cheat / reward-hacking safeguards + +Following the Hackathon Self-Serve Guide explicitly: + +1. **Multiple independent reward functions** (5 visible + 3 held-out). +2. **Held-out evaluator** the trainer never sees, used only for plots. +3. **Locked-down execution** in the sandbox simulator β€” no globals abuse, + timeouts on every run. +4. **AST validator** rejects forbidden constructs (network calls, `os.system`, + etc.) before reward is computed. +5. **Minimality reward** + **format compliance** to prevent the agent from + rewriting the entire script as a "repair". +6. The **Drift Generator** is itself trained against an R-Zero composite + reward (uncertainty βˆ’ repetition) so it can't trivially game the agent. + +## References + +- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) +- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) +- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games Incentivizes Reasoning…* (2025) +- Ibrahim et al., [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) β€” Reward engineering & shaping +- Masud et al., [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) β€” Reward engineering for RL in software tasks +- OpenEnv Hackathon Self-Serve Guide (2026) + +## License + +Apache-2.0 diff --git a/debug_trace.py b/debug_trace.py index c49b09d260a147469309e2510d2a130e84e30bfe..003b3cecb755356bb33b6ce329a010487d2f769c 100644 --- a/debug_trace.py +++ b/debug_trace.py @@ -1,18 +1,18 @@ -from forgeenv.roles.drift_generator import BaselineDriftGenerator -from forgeenv.roles.prompts import render_drift_generator_prompt -from forgeenv.tasks.task_sampler import TaskSampler - -sampler = TaskSampler() -script = sampler.get_by_id("simple_regression").script_content - -prompt = render_drift_generator_prompt(script, "ChangeTokenizerBehavior", {"transformers": "4.40"}) -fence = "```python" -script_block = "" -if fence in prompt: - script_block = prompt.split(fence, 1)[1].split("```", 1)[0] -print("script_block len:", len(script_block)) -print("first 80 chars:", repr(script_block[:80])) - -gen = BaselineDriftGenerator(seed=0) -spec = gen.propose(target_category="ChangeTokenizerBehavior", script=script_block) -print("spec:", spec) +from forgeenv.roles.drift_generator import BaselineDriftGenerator +from forgeenv.roles.prompts import render_drift_generator_prompt +from forgeenv.tasks.task_sampler import TaskSampler + +sampler = TaskSampler() +script = sampler.get_by_id("simple_regression").script_content + +prompt = render_drift_generator_prompt(script, "ChangeTokenizerBehavior", {"transformers": "4.40"}) +fence = "```python" +script_block = "" +if fence in prompt: + script_block = prompt.split(fence, 1)[1].split("```", 1)[0] +print("script_block len:", len(script_block)) +print("first 80 chars:", repr(script_block[:80])) + +gen = BaselineDriftGenerator(seed=0) +spec = gen.propose(target_category="ChangeTokenizerBehavior", script=script_block) +print("spec:", spec) diff --git a/demo-space/README.md b/demo-space/README.md index 057e8c620dbe3d379499db81f2cb9cc0e9fdb9eb..c73966976c517e29912c25cd196ed79bf25b836a 100644 --- a/demo-space/README.md +++ b/demo-space/README.md @@ -1,31 +1,31 @@ ---- -title: ForgeEnv Repair Agent Demo -emoji: πŸ”§ -colorFrom: blue -colorTo: green -sdk: gradio -sdk_version: 5.7.1 -app_file: app.py -pinned: true -license: apache-2.0 -hardware: zero-a10g -tags: - - openenv - - self-improvement - - code-repair - - schema-drift -short_description: Trained Repair Agent fixes HF scripts under drift ---- - -# ForgeEnv Repair Agent β€” Live Demo - -Paste a broken HuggingFace training script and the error trace it produced. -The trained Repair Agent (Qwen2.5-3B + LoRA) emits a unified diff that should -restore the script. Inference runs on ZeroGPU (free A10G). - -- **Environment server (OpenEnv):** - -- **Trained model (LoRA + repair_library.json):** - -- **Project README & plots:** - +--- +title: ForgeEnv Repair Agent Demo +emoji: πŸ”§ +colorFrom: blue +colorTo: green +sdk: gradio +sdk_version: 5.7.1 +app_file: app.py +pinned: true +license: apache-2.0 +hardware: zero-a10g +tags: + - openenv + - self-improvement + - code-repair + - schema-drift +short_description: Trained Repair Agent fixes HF scripts under drift +--- + +# ForgeEnv Repair Agent β€” Live Demo + +Paste a broken HuggingFace training script and the error trace it produced. +The trained Repair Agent (Qwen2.5-3B + LoRA) emits a unified diff that should +restore the script. Inference runs on ZeroGPU (free A10G). + +- **Environment server (OpenEnv):** + +- **Trained model (LoRA + repair_library.json):** + +- **Project README & plots:** + diff --git a/demo-space/app.py b/demo-space/app.py index bc9414b85e249b8fbb006526597c03e1983f5262..0558ea95b1fdaaf539b040908c603ba7a48b01e8 100644 --- a/demo-space/app.py +++ b/demo-space/app.py @@ -1,444 +1,444 @@ -"""Gradio demo Space for the ForgeEnv Repair Agent. - -Three-tier repair pipeline so the demo always returns a useful diff: - -1. **Trained LoRA model** β€” Qwen 2.5 + ForgeEnv GRPO adapter. If the model - emits a diff that, when applied, actually changes the broken script, - we use it. -2. **Error-trace heuristic** β€” extracts the fix signal from the Python - traceback (Did you mean / unexpected kwarg / No module named) and - emits a clean canonical diff. Handles the most common drift patterns. -3. **Model reasoning hint** β€” if heuristic fails, surface the model's - natural-language reasoning (it usually explains the bug correctly even - when its diff syntax is broken) alongside a "no patch produced" note. - -This separation means the demo is robust regardless of how well the -LoRA generalises on a given input β€” and it's honest about what each -component contributed. -""" -from __future__ import annotations - -import json -import os -import re -import traceback -from typing import Optional - -import gradio as gr - -BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct") -ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent") - -_TITLE = "ForgeEnv Repair Agent β€” fix HuggingFace scripts under library drift" -_DESCRIPTION = ( - "Paste a broken HuggingFace training script and the error trace it " - "produced. The Repair Agent returns a minimal unified diff. The model " - "was trained inside [ForgeEnv](https://huggingface.co/spaces/" - "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style " - "Challenger / Solver co-evolution. The agent is backed by a heuristic " - "fallback that parses error traces directly when the LoRA's diff is " - "malformed β€” keeps the demo robust on out-of-distribution inputs." -) - -_EXAMPLES = [ - [ - ( - "from transformers import Trainer, TrainingArguments\n" - "from datasets import load_dataset\n\n" - "ds = load_dataset('glue', 'sst2')\n" - "args = TrainingArguments(output_dir='out')\n" - "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n" - "trainer.start_training()\n" - ), - ( - "AttributeError: 'Trainer' object has no attribute 'start_training'. " - "Did you mean: 'train'?" - ), - ], - [ - ( - "import torch.legacy as torch\n" - "x = torch.randn(2, 3)\n" - "print(x)\n" - ), - "ModuleNotFoundError: No module named 'torch.legacy'", - ], - [ - ( - "from transformers import AutoTokenizer\n" - "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" - "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n" - "print(out)\n" - ), - ( - "TypeError: __call__() got an unexpected keyword argument " - "'pad_to_max_length' (use `padding=True` instead)." - ), - ], -] - -_PROMPT_TEMPLATE = ( - "You are an expert ML engineer who fixes broken HuggingFace training " - "scripts caused by library version drift.\n\n" - "Library versions: {versions}\n\n" - "Broken script:\n```python\n{script}\n```\n\n" - "Error trace:\n```\n{trace}\n```\n\n" - "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ " - "b/script.py` headers, then hunks). No prose." -) - -_model = None -_tokenizer = None -_load_error: Optional[str] = None - - -# ----------------------------------------------------------------- model io -def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool: - """Cheap pre-check: pull adapter_config.json and compare base_model_name.""" - try: - from huggingface_hub import hf_hub_download - - cfg_path = hf_hub_download( - repo_id=adapter_repo, - filename="adapter_config.json", - token=os.environ.get("HF_TOKEN"), - ) - with open(cfg_path) as f: - cfg = json.load(f) - adapter_base = (cfg.get("base_model_name_or_path") or "").lower() - # Match by family substring -- "qwen2.5-coder-7b" must be present in - # the base name, otherwise the adapter targets a different arch. - family = base_name.split("/")[-1].lower().replace("-instruct", "") - return family in adapter_base - except Exception as e: # noqa: BLE001 - print(f"[demo] adapter_config check failed ({e}); attempting load anyway") - return True - - -def _load_model() -> None: - """Lazy-load the trained LoRA on first GPU invocation.""" - global _model, _tokenizer, _load_error - if _model is not None or _load_error is not None: - return - try: - import torch - from peft import PeftModel - from transformers import AutoModelForCausalLM, AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) - base = AutoModelForCausalLM.from_pretrained( - BASE_MODEL, - torch_dtype=torch.float16, - device_map="auto", - ) - if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL): - try: - model = PeftModel.from_pretrained(base, ADAPTER_REPO) - print(f"[demo] LoRA attached: {ADAPTER_REPO}") - except Exception as e: # noqa: BLE001 - print(f"[demo] adapter load failed ({e}); using base model") - model = base - else: - print( - f"[demo] adapter at {ADAPTER_REPO} was trained on a different " - f"base; using {BASE_MODEL} alone until matching adapter ships" - ) - model = base - _model = model.eval() - _tokenizer = tokenizer - except Exception as e: # noqa: BLE001 - _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" - - -_SYSTEM_PROMPT = ( - "You are an expert ML engineer who fixes broken HuggingFace training " - "scripts caused by library version drift. Output ONLY a unified diff." -) - - -def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str: - """Greedy decode using the base model's chat template (Qwen ChatML).""" - import torch - - messages = [ - {"role": "system", "content": _SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ] - try: - text = _tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - except Exception: # noqa: BLE001 - text = prompt - inputs = _tokenizer(text, return_tensors="pt").to(_model.device) - with torch.no_grad(): - out = _model.generate( - **inputs, - max_new_tokens=max_new_tokens, - do_sample=False, - temperature=0.0, - repetition_penalty=1.15, - pad_token_id=_tokenizer.eos_token_id, - ) - completion = _tokenizer.decode( - out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True - ) - return completion.strip() - - -# -------------------------------------------------------- diff extraction -_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE) -_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) - - -def _extract_diff_block(raw: str) -> str: - """Pull the *first* fenced diff out of the model's raw output.""" - if not raw: - return "" - m = _FENCE_RE.search(raw) - if m: - return m.group(1).strip() - # otherwise grab from the first '---' / '+++' / '@@' onwards - for marker in ("--- ", "+++ ", "@@"): - idx = raw.find(marker) - if idx >= 0: - return raw[idx:].strip() - return "" - - -def _diff_actually_changes_script(broken: str, diff_text: str) -> bool: - """Try to apply the diff. Returns True iff the result differs from input.""" - if not diff_text: - return False - try: - from forgeenv.env.diff_utils import apply_unified_diff - - repaired = apply_unified_diff(broken, diff_text) - return bool(repaired) and repaired.strip() != broken.strip() - except Exception: # noqa: BLE001 - return False - - -def _canonicalise(broken: str, diff_text: str) -> str: - """Apply diff -> rebuild a clean canonical unified diff.""" - from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff - - repaired = apply_unified_diff(broken, diff_text) - if not repaired or repaired.strip() == broken.strip(): - return "" - return make_unified_diff(broken, repaired) - - -def _extract_model_reasoning(raw: str) -> str: - """Pull the natural-language reasoning out of the model's output (if any).""" - if not raw: - return "" - text = re.sub(_FENCE_RE, "", raw).strip() - text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip() - lines = [ln.strip() for ln in text.splitlines() if ln.strip()] - sentences: list[str] = [] - for ln in lines: - if ln.startswith(("---", "+++", "@@", "-", "+")): - continue - if len(ln) < 10: - continue - sentences.append(ln) - if len(sentences) >= 3: - break - return " ".join(sentences) - - -# ---------------------------------------------------- error-trace heuristic -_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE) -_NO_ATTR_RE = re.compile( - r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE -) -_NO_MODULE_RE = re.compile( - r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE -) -_BAD_KWARG_RE = re.compile( - r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE -) -_USE_INSTEAD_RE = re.compile( - r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE -) - - -def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]: - """Produce a (repaired_script, fix_description) pair from the trace. - - Patterns covered: - * AttributeError + "Did you mean: 'X'?" -> rename method - * AttributeError without hint -> remove the call (rarely useful) - * ModuleNotFoundError 'X.Y' -> drop the .Y submodule - * TypeError unexpected kwarg + 'use Y' -> swap kwarg - * TypeError unexpected kwarg, no hint -> drop the kwarg - """ - if not error_trace: - return broken, "" - trace = error_trace.strip() - repaired = broken - description = "" - - # 1. AttributeError 'X' + Did you mean 'Y' - if "AttributeError" in trace or "has no attribute" in trace: - old = _NO_ATTR_RE.search(trace) - new = _DID_YOU_MEAN_RE.search(trace) - if old and new and old.group(1) != new.group(1): - old_name, new_name = old.group(1), new.group(1) - pattern = re.compile(rf"\b{re.escape(old_name)}\b") - if pattern.search(repaired): - repaired = pattern.sub(new_name, repaired) - description = ( - f"`{old_name}` is no longer an attribute on this object; " - f"renamed call to `{new_name}` per the traceback hint." - ) - - # 2. ModuleNotFoundError 'X.Y' (or 'X') - if not description and "No module named" in trace: - m = _NO_MODULE_RE.search(trace) - if m: - mod = m.group(1) - if "." in mod: - parent, child = mod.rsplit(".", 1) - pat_full = re.compile(rf"\b{re.escape(mod)}\b") - if pat_full.search(repaired): - repaired = pat_full.sub(parent, repaired) - description = ( - f"`{mod}` was removed; replaced with parent module " - f"`{parent}`." - ) - - # 3. TypeError unexpected kwarg - if not description and "unexpected keyword argument" in trace: - bad = _BAD_KWARG_RE.search(trace) - good = _USE_INSTEAD_RE.search(trace) - if bad: - bad_kw = bad.group(1) - if good: - good_kw = good.group(1) - pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=") - if pat.search(repaired): - repaired = pat.sub(f"{good_kw}=", repaired) - # if old kwarg was a boolean-ish, also swap the value - # (pad_to_max_length=True -> padding=True is fine) - description = ( - f"`{bad_kw}` was renamed to `{good_kw}`; updated " - f"keyword to match the new API." - ) - else: - # remove the kwarg entirely (best-effort) - pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+") - if pat.search(repaired): - repaired = pat.sub("", repaired) - description = ( - f"`{bad_kw}` is no longer accepted; removed the " - f"keyword argument." - ) - - return repaired, description - - -# ------------------------------------------------------------- entry point -try: - import spaces # type: ignore - - _gpu_decorator = spaces.GPU(duration=60) -except Exception: # noqa: BLE001 - def _gpu_decorator(fn): - return fn - - -@_gpu_decorator -def repair_script(script: str, error_trace: str) -> str: - if not script.strip(): - return "# Paste a broken script first." - - # Tier 1: trained LoRA - model_raw = "" - model_diff_canonical = "" - model_reasoning = "" - - _load_model() - if _model is not None: - try: - versions = json.dumps( - {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"} - ) - prompt = _PROMPT_TEMPLATE.format( - versions=versions, - script=script, - trace=error_trace or "(no trace)", - ) - model_raw = _generate_with_model(prompt) - model_diff_text = _extract_diff_block(model_raw) - if _diff_actually_changes_script(script, model_diff_text): - model_diff_canonical = _canonicalise(script, model_diff_text) - model_reasoning = _extract_model_reasoning(model_raw) - except Exception as e: # noqa: BLE001 - print(f"[demo] model generation failed: {e}") - - if model_diff_canonical: - header = ( - "# Source: trained LoRA (ForgeEnv GRPO adapter)\n" - "# The model produced a valid diff that successfully patches the script.\n" - ) - return header + "\n" + model_diff_canonical - - # Tier 2: error-trace heuristic - repaired, description = _heuristic_repair(script, error_trace) - if description and repaired != script: - from forgeenv.env.diff_utils import make_unified_diff - - diff = make_unified_diff(script, repaired) - header_lines = [ - "# Source: error-trace heuristic (LoRA diff was malformed; " - "fell back to deterministic repair).", - f"# Fix: {description}", - ] - if model_reasoning: - header_lines.append(f"# Trained model said: {model_reasoning}") - return "\n".join(header_lines) + "\n\n" + diff - - # Tier 3: nothing worked -- surface what we know - msg_lines = ["# Could not produce a confident patch."] - if model_reasoning: - msg_lines.append(f"# Trained model reasoning: {model_reasoning}") - if error_trace: - msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}") - msg_lines.append( - "# Try a more specific error trace (the heuristic looks for " - "'Did you mean', 'No module named', or 'unexpected keyword argument')." - ) - return "\n".join(msg_lines) - - -# ----------------------------------------------------------------- gradio -with gr.Blocks(title="ForgeEnv Repair Agent") as demo: - gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}") - with gr.Row(): - with gr.Column(): - in_script = gr.Code( - label="Broken HuggingFace script", - language="python", - lines=22, - ) - in_trace = gr.Textbox( - label="Error trace", - lines=6, - placeholder="Traceback...", - ) - run_btn = gr.Button("Repair", variant="primary") - with gr.Column(): - out_diff = gr.Code( - label="Suggested repair (unified diff)", - language="markdown", - lines=22, - ) - - gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace]) - run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff) - - -if __name__ == "__main__": - demo.launch() +"""Gradio demo Space for the ForgeEnv Repair Agent. + +Three-tier repair pipeline so the demo always returns a useful diff: + +1. **Trained LoRA model** β€” Qwen 2.5 + ForgeEnv GRPO adapter. If the model + emits a diff that, when applied, actually changes the broken script, + we use it. +2. **Error-trace heuristic** β€” extracts the fix signal from the Python + traceback (Did you mean / unexpected kwarg / No module named) and + emits a clean canonical diff. Handles the most common drift patterns. +3. **Model reasoning hint** β€” if heuristic fails, surface the model's + natural-language reasoning (it usually explains the bug correctly even + when its diff syntax is broken) alongside a "no patch produced" note. + +This separation means the demo is robust regardless of how well the +LoRA generalises on a given input β€” and it's honest about what each +component contributed. +""" +from __future__ import annotations + +import json +import os +import re +import traceback +from typing import Optional + +import gradio as gr + +BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct") +ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent") + +_TITLE = "ForgeEnv Repair Agent β€” fix HuggingFace scripts under library drift" +_DESCRIPTION = ( + "Paste a broken HuggingFace training script and the error trace it " + "produced. The Repair Agent returns a minimal unified diff. The model " + "was trained inside [ForgeEnv](https://huggingface.co/spaces/" + "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style " + "Challenger / Solver co-evolution. The agent is backed by a heuristic " + "fallback that parses error traces directly when the LoRA's diff is " + "malformed β€” keeps the demo robust on out-of-distribution inputs." +) + +_EXAMPLES = [ + [ + ( + "from transformers import Trainer, TrainingArguments\n" + "from datasets import load_dataset\n\n" + "ds = load_dataset('glue', 'sst2')\n" + "args = TrainingArguments(output_dir='out')\n" + "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n" + "trainer.start_training()\n" + ), + ( + "AttributeError: 'Trainer' object has no attribute 'start_training'. " + "Did you mean: 'train'?" + ), + ], + [ + ( + "import torch.legacy as torch\n" + "x = torch.randn(2, 3)\n" + "print(x)\n" + ), + "ModuleNotFoundError: No module named 'torch.legacy'", + ], + [ + ( + "from transformers import AutoTokenizer\n" + "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" + "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n" + "print(out)\n" + ), + ( + "TypeError: __call__() got an unexpected keyword argument " + "'pad_to_max_length' (use `padding=True` instead)." + ), + ], +] + +_PROMPT_TEMPLATE = ( + "You are an expert ML engineer who fixes broken HuggingFace training " + "scripts caused by library version drift.\n\n" + "Library versions: {versions}\n\n" + "Broken script:\n```python\n{script}\n```\n\n" + "Error trace:\n```\n{trace}\n```\n\n" + "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ " + "b/script.py` headers, then hunks). No prose." +) + +_model = None +_tokenizer = None +_load_error: Optional[str] = None + + +# ----------------------------------------------------------------- model io +def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool: + """Cheap pre-check: pull adapter_config.json and compare base_model_name.""" + try: + from huggingface_hub import hf_hub_download + + cfg_path = hf_hub_download( + repo_id=adapter_repo, + filename="adapter_config.json", + token=os.environ.get("HF_TOKEN"), + ) + with open(cfg_path) as f: + cfg = json.load(f) + adapter_base = (cfg.get("base_model_name_or_path") or "").lower() + # Match by family substring -- "qwen2.5-coder-7b" must be present in + # the base name, otherwise the adapter targets a different arch. + family = base_name.split("/")[-1].lower().replace("-instruct", "") + return family in adapter_base + except Exception as e: # noqa: BLE001 + print(f"[demo] adapter_config check failed ({e}); attempting load anyway") + return True + + +def _load_model() -> None: + """Lazy-load the trained LoRA on first GPU invocation.""" + global _model, _tokenizer, _load_error + if _model is not None or _load_error is not None: + return + try: + import torch + from peft import PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + torch_dtype=torch.float16, + device_map="auto", + ) + if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL): + try: + model = PeftModel.from_pretrained(base, ADAPTER_REPO) + print(f"[demo] LoRA attached: {ADAPTER_REPO}") + except Exception as e: # noqa: BLE001 + print(f"[demo] adapter load failed ({e}); using base model") + model = base + else: + print( + f"[demo] adapter at {ADAPTER_REPO} was trained on a different " + f"base; using {BASE_MODEL} alone until matching adapter ships" + ) + model = base + _model = model.eval() + _tokenizer = tokenizer + except Exception as e: # noqa: BLE001 + _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" + + +_SYSTEM_PROMPT = ( + "You are an expert ML engineer who fixes broken HuggingFace training " + "scripts caused by library version drift. Output ONLY a unified diff." +) + + +def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str: + """Greedy decode using the base model's chat template (Qwen ChatML).""" + import torch + + messages = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] + try: + text = _tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: # noqa: BLE001 + text = prompt + inputs = _tokenizer(text, return_tensors="pt").to(_model.device) + with torch.no_grad(): + out = _model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + temperature=0.0, + repetition_penalty=1.15, + pad_token_id=_tokenizer.eos_token_id, + ) + completion = _tokenizer.decode( + out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True + ) + return completion.strip() + + +# -------------------------------------------------------- diff extraction +_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE) +_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) + + +def _extract_diff_block(raw: str) -> str: + """Pull the *first* fenced diff out of the model's raw output.""" + if not raw: + return "" + m = _FENCE_RE.search(raw) + if m: + return m.group(1).strip() + # otherwise grab from the first '---' / '+++' / '@@' onwards + for marker in ("--- ", "+++ ", "@@"): + idx = raw.find(marker) + if idx >= 0: + return raw[idx:].strip() + return "" + + +def _diff_actually_changes_script(broken: str, diff_text: str) -> bool: + """Try to apply the diff. Returns True iff the result differs from input.""" + if not diff_text: + return False + try: + from forgeenv.env.diff_utils import apply_unified_diff + + repaired = apply_unified_diff(broken, diff_text) + return bool(repaired) and repaired.strip() != broken.strip() + except Exception: # noqa: BLE001 + return False + + +def _canonicalise(broken: str, diff_text: str) -> str: + """Apply diff -> rebuild a clean canonical unified diff.""" + from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff + + repaired = apply_unified_diff(broken, diff_text) + if not repaired or repaired.strip() == broken.strip(): + return "" + return make_unified_diff(broken, repaired) + + +def _extract_model_reasoning(raw: str) -> str: + """Pull the natural-language reasoning out of the model's output (if any).""" + if not raw: + return "" + text = re.sub(_FENCE_RE, "", raw).strip() + text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip() + lines = [ln.strip() for ln in text.splitlines() if ln.strip()] + sentences: list[str] = [] + for ln in lines: + if ln.startswith(("---", "+++", "@@", "-", "+")): + continue + if len(ln) < 10: + continue + sentences.append(ln) + if len(sentences) >= 3: + break + return " ".join(sentences) + + +# ---------------------------------------------------- error-trace heuristic +_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE) +_NO_ATTR_RE = re.compile( + r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE +) +_NO_MODULE_RE = re.compile( + r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE +) +_BAD_KWARG_RE = re.compile( + r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE +) +_USE_INSTEAD_RE = re.compile( + r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE +) + + +def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]: + """Produce a (repaired_script, fix_description) pair from the trace. + + Patterns covered: + * AttributeError + "Did you mean: 'X'?" -> rename method + * AttributeError without hint -> remove the call (rarely useful) + * ModuleNotFoundError 'X.Y' -> drop the .Y submodule + * TypeError unexpected kwarg + 'use Y' -> swap kwarg + * TypeError unexpected kwarg, no hint -> drop the kwarg + """ + if not error_trace: + return broken, "" + trace = error_trace.strip() + repaired = broken + description = "" + + # 1. AttributeError 'X' + Did you mean 'Y' + if "AttributeError" in trace or "has no attribute" in trace: + old = _NO_ATTR_RE.search(trace) + new = _DID_YOU_MEAN_RE.search(trace) + if old and new and old.group(1) != new.group(1): + old_name, new_name = old.group(1), new.group(1) + pattern = re.compile(rf"\b{re.escape(old_name)}\b") + if pattern.search(repaired): + repaired = pattern.sub(new_name, repaired) + description = ( + f"`{old_name}` is no longer an attribute on this object; " + f"renamed call to `{new_name}` per the traceback hint." + ) + + # 2. ModuleNotFoundError 'X.Y' (or 'X') + if not description and "No module named" in trace: + m = _NO_MODULE_RE.search(trace) + if m: + mod = m.group(1) + if "." in mod: + parent, child = mod.rsplit(".", 1) + pat_full = re.compile(rf"\b{re.escape(mod)}\b") + if pat_full.search(repaired): + repaired = pat_full.sub(parent, repaired) + description = ( + f"`{mod}` was removed; replaced with parent module " + f"`{parent}`." + ) + + # 3. TypeError unexpected kwarg + if not description and "unexpected keyword argument" in trace: + bad = _BAD_KWARG_RE.search(trace) + good = _USE_INSTEAD_RE.search(trace) + if bad: + bad_kw = bad.group(1) + if good: + good_kw = good.group(1) + pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=") + if pat.search(repaired): + repaired = pat.sub(f"{good_kw}=", repaired) + # if old kwarg was a boolean-ish, also swap the value + # (pad_to_max_length=True -> padding=True is fine) + description = ( + f"`{bad_kw}` was renamed to `{good_kw}`; updated " + f"keyword to match the new API." + ) + else: + # remove the kwarg entirely (best-effort) + pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+") + if pat.search(repaired): + repaired = pat.sub("", repaired) + description = ( + f"`{bad_kw}` is no longer accepted; removed the " + f"keyword argument." + ) + + return repaired, description + + +# ------------------------------------------------------------- entry point +try: + import spaces # type: ignore + + _gpu_decorator = spaces.GPU(duration=60) +except Exception: # noqa: BLE001 + def _gpu_decorator(fn): + return fn + + +@_gpu_decorator +def repair_script(script: str, error_trace: str) -> str: + if not script.strip(): + return "# Paste a broken script first." + + # Tier 1: trained LoRA + model_raw = "" + model_diff_canonical = "" + model_reasoning = "" + + _load_model() + if _model is not None: + try: + versions = json.dumps( + {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"} + ) + prompt = _PROMPT_TEMPLATE.format( + versions=versions, + script=script, + trace=error_trace or "(no trace)", + ) + model_raw = _generate_with_model(prompt) + model_diff_text = _extract_diff_block(model_raw) + if _diff_actually_changes_script(script, model_diff_text): + model_diff_canonical = _canonicalise(script, model_diff_text) + model_reasoning = _extract_model_reasoning(model_raw) + except Exception as e: # noqa: BLE001 + print(f"[demo] model generation failed: {e}") + + if model_diff_canonical: + header = ( + "# Source: trained LoRA (ForgeEnv GRPO adapter)\n" + "# The model produced a valid diff that successfully patches the script.\n" + ) + return header + "\n" + model_diff_canonical + + # Tier 2: error-trace heuristic + repaired, description = _heuristic_repair(script, error_trace) + if description and repaired != script: + from forgeenv.env.diff_utils import make_unified_diff + + diff = make_unified_diff(script, repaired) + header_lines = [ + "# Source: error-trace heuristic (LoRA diff was malformed; " + "fell back to deterministic repair).", + f"# Fix: {description}", + ] + if model_reasoning: + header_lines.append(f"# Trained model said: {model_reasoning}") + return "\n".join(header_lines) + "\n\n" + diff + + # Tier 3: nothing worked -- surface what we know + msg_lines = ["# Could not produce a confident patch."] + if model_reasoning: + msg_lines.append(f"# Trained model reasoning: {model_reasoning}") + if error_trace: + msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}") + msg_lines.append( + "# Try a more specific error trace (the heuristic looks for " + "'Did you mean', 'No module named', or 'unexpected keyword argument')." + ) + return "\n".join(msg_lines) + + +# ----------------------------------------------------------------- gradio +with gr.Blocks(title="ForgeEnv Repair Agent") as demo: + gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}") + with gr.Row(): + with gr.Column(): + in_script = gr.Code( + label="Broken HuggingFace script", + language="python", + lines=22, + ) + in_trace = gr.Textbox( + label="Error trace", + lines=6, + placeholder="Traceback...", + ) + run_btn = gr.Button("Repair", variant="primary") + with gr.Column(): + out_diff = gr.Code( + label="Suggested repair (unified diff)", + language="markdown", + lines=22, + ) + + gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace]) + run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff) + + +if __name__ == "__main__": + demo.launch() diff --git a/demo-space/requirements.txt b/demo-space/requirements.txt index 60116ba061da02d1e875e11e28ad63574e136edb..417d3300876bb2c314f4d363c34c816e5f338ad2 100644 --- a/demo-space/requirements.txt +++ b/demo-space/requirements.txt @@ -1,7 +1,7 @@ -gradio==5.7.1 -torch>=2.1.0 -transformers>=4.40.0 -peft>=0.10.0 -accelerate>=0.30.0 -spaces>=0.28.0 -audioop-lts; python_version >= "3.13" +gradio==5.7.1 +torch>=2.1.0 +transformers>=4.40.0 +peft>=0.10.0 +accelerate>=0.30.0 +spaces>=0.28.0 +audioop-lts; python_version >= "3.13" diff --git a/demo-space/test_heuristic.py b/demo-space/test_heuristic.py index 981e4df9d6a0fc0b365fbf754bd021c6b83656b0..bea775b1638c2eb62f5b9a863719363acea232ce 100644 --- a/demo-space/test_heuristic.py +++ b/demo-space/test_heuristic.py @@ -1,99 +1,99 @@ -"""Quick local sanity check for the heuristic repair fallback. - -Run with:: - - python demo-space/test_heuristic.py - -Each case must produce a non-empty fix description and a script that -differs from the input. -""" -from __future__ import annotations - -import sys -from pathlib import Path - -REPO = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(REPO)) -sys.path.insert(0, str(REPO / "demo-space")) - -from app import _heuristic_repair # noqa: E402 - -CASES = [ - { - "name": "AttributeError + Did you mean", - "script": ( - "from transformers import Trainer, TrainingArguments\n" - "from datasets import load_dataset\n\n" - "ds = load_dataset('glue', 'sst2')\n" - "args = TrainingArguments(output_dir='out')\n" - "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n" - "trainer.start_training()\n" - ), - "trace": ( - "AttributeError: 'Trainer' object has no attribute 'start_training'. " - "Did you mean: 'train'?" - ), - "expect_in_repaired": "trainer.train()", - "expect_not_in_repaired": "start_training", - }, - { - "name": "ModuleNotFoundError submodule", - "script": ( - "import torch.legacy as torch\n" - "x = torch.randn(2, 3)\n" - "print(x)\n" - ), - "trace": "ModuleNotFoundError: No module named 'torch.legacy'", - "expect_in_repaired": "import torch", - "expect_not_in_repaired": "torch.legacy", - }, - { - "name": "TypeError + use ... instead", - "script": ( - "from transformers import AutoTokenizer\n" - "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" - "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n" - "print(out)\n" - ), - "trace": ( - "TypeError: __call__() got an unexpected keyword argument " - "'pad_to_max_length' (use `padding=True` instead)." - ), - "expect_in_repaired": "padding=True", - "expect_not_in_repaired": "pad_to_max_length", - }, -] - - -def run_one(case: dict) -> bool: - name = case["name"] - repaired, description = _heuristic_repair(case["script"], case["trace"]) - - ok_changed = repaired != case["script"] - ok_desc = bool(description) - ok_in = case["expect_in_repaired"] in repaired - ok_not = case["expect_not_in_repaired"] not in repaired - - status = "PASS" if (ok_changed and ok_desc and ok_in and ok_not) else "FAIL" - print(f"[{status}] {name}") - print(f" description: {description!r}") - print(f" changed? {ok_changed}") - print(f" '{case['expect_in_repaired']}' in repaired? {ok_in}") - print(f" '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}") - if status == "FAIL": - print(" --- repaired script ---") - print(repaired) - print(" -----------------------") - return status == "PASS" - - -def main() -> int: - results = [run_one(c) for c in CASES] - print() - n_pass = sum(results) - print(f"summary: {n_pass}/{len(results)} passed") - return 0 if all(results) else 1 - - -if __name__ == "__main__": - sys.exit(main()) +"""Quick local sanity check for the heuristic repair fallback. + +Run with:: + + python demo-space/test_heuristic.py + +Each case must produce a non-empty fix description and a script that +differs from the input. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +REPO = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(REPO)) +sys.path.insert(0, str(REPO / "demo-space")) + +from app import _heuristic_repair # noqa: E402 + +CASES = [ + { + "name": "AttributeError + Did you mean", + "script": ( + "from transformers import Trainer, TrainingArguments\n" + "from datasets import load_dataset\n\n" + "ds = load_dataset('glue', 'sst2')\n" + "args = TrainingArguments(output_dir='out')\n" + "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n" + "trainer.start_training()\n" + ), + "trace": ( + "AttributeError: 'Trainer' object has no attribute 'start_training'. " + "Did you mean: 'train'?" + ), + "expect_in_repaired": "trainer.train()", + "expect_not_in_repaired": "start_training", + }, + { + "name": "ModuleNotFoundError submodule", + "script": ( + "import torch.legacy as torch\n" + "x = torch.randn(2, 3)\n" + "print(x)\n" + ), + "trace": "ModuleNotFoundError: No module named 'torch.legacy'", + "expect_in_repaired": "import torch", + "expect_not_in_repaired": "torch.legacy", + }, + { + "name": "TypeError + use ... instead", + "script": ( + "from transformers import AutoTokenizer\n" + "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" + "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n" + "print(out)\n" + ), + "trace": ( + "TypeError: __call__() got an unexpected keyword argument " + "'pad_to_max_length' (use `padding=True` instead)." + ), + "expect_in_repaired": "padding=True", + "expect_not_in_repaired": "pad_to_max_length", + }, +] + + +def run_one(case: dict) -> bool: + name = case["name"] + repaired, description = _heuristic_repair(case["script"], case["trace"]) + + ok_changed = repaired != case["script"] + ok_desc = bool(description) + ok_in = case["expect_in_repaired"] in repaired + ok_not = case["expect_not_in_repaired"] not in repaired + + status = "PASS" if (ok_changed and ok_desc and ok_in and ok_not) else "FAIL" + print(f"[{status}] {name}") + print(f" description: {description!r}") + print(f" changed? {ok_changed}") + print(f" '{case['expect_in_repaired']}' in repaired? {ok_in}") + print(f" '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}") + if status == "FAIL": + print(" --- repaired script ---") + print(repaired) + print(" -----------------------") + return status == "PASS" + + +def main() -> int: + results = [run_one(c) for c in CASES] + print() + n_pass = sum(results) + print(f"summary: {n_pass}/{len(results)} passed") + return 0 if all(results) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/forgeenv-space/Dockerfile b/forgeenv-space/Dockerfile index 40f4fe1ac19ef5e40b01b23ef9a45284fc7a928e..b1d5d9eb8aee00d6a4ff81ccd4bfbc7b2d416fce 100644 --- a/forgeenv-space/Dockerfile +++ b/forgeenv-space/Dockerfile @@ -1,25 +1,25 @@ -FROM python:3.11-slim - -ENV PYTHONUNBUFFERED=1 \ - PYTHONDONTWRITEBYTECODE=1 \ - PIP_NO_CACHE_DIR=1 - -RUN apt-get update \ - && apt-get install -y --no-install-recommends git curl \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /app - -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt - -COPY forgeenv/ forgeenv/ -COPY openenv.yaml . - -ENV PORT=7860 -EXPOSE 7860 - -HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \ - CMD curl -f http://127.0.0.1:7860/health || exit 1 - -CMD ["uvicorn", "forgeenv.env.server:app", "--host", "0.0.0.0", "--port", "7860"] +FROM python:3.11-slim + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 + +RUN apt-get update \ + && apt-get install -y --no-install-recommends git curl \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY forgeenv/ forgeenv/ +COPY openenv.yaml . + +ENV PORT=7860 +EXPOSE 7860 + +HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \ + CMD curl -f http://127.0.0.1:7860/health || exit 1 + +CMD ["uvicorn", "forgeenv.env.server:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/forgeenv-space/README.md b/forgeenv-space/README.md index 726572808b78112a2502f6e2648a32117e9258e4..4788e77ad6d1460efd20c909d452674add25b50b 100644 --- a/forgeenv-space/README.md +++ b/forgeenv-space/README.md @@ -1,85 +1,85 @@ ---- -title: ForgeEnv -emoji: πŸ”§ -colorFrom: indigo -colorTo: green -sdk: docker -app_port: 7860 -pinned: true -license: apache-2.0 -tags: - - openenv - - self-play - - self-improvement - - code-repair - - schema-drift - - reinforcement-learning - - huggingface -short_description: Self-improving RL env for HF library-drift repair ---- - -# ForgeEnv β€” OpenEnv Server - -This Space hosts the **ForgeEnv** OpenEnv-compliant environment as a FastAPI -service. It exposes the standard `reset`, `step`, and `state` endpoints and is -the runtime that training notebooks (TRL + Unsloth) connect to. - -> **Theme:** Self-Improvement (Hackathon Theme #4) β€” Challenger / Solver -> co-evolution via R-Zero, SPIRAL, and Absolute Zero Reasoner techniques. - -## What it does - -ForgeEnv simulates **HuggingFace library version drift**. A *Drift Generator* -proposes a realistic breakage to a working training script (renamed APIs, -deprecated imports, changed argument signatures, etc.). A *Repair Agent* then -emits a unified diff that should restore the script. Reward is computed by an -execution simulator + AST checker + held-out evaluator (multi-component to -resist reward hacking). - -## API - -The server uses [`openenv-core`](https://pypi.org/project/openenv-core/) and -follows the Gym-style contract: - -| Endpoint | Method | Purpose | -| -------- | ------ | -------------------------------------------------- | -| `/reset` | POST | Sample a fresh task, return drift-gen observation | -| `/step` | POST | Apply a `ForgeAction` (breakage or repair) | -| `/state` | GET | Inspect the current internal state | -| `/health`| GET | Health probe (used by the container HEALTHCHECK) | - -`ForgeAction` is a discriminated union of `BreakageAction` (used in phase 1) -and `RepairAction` (used in phase 2). See -[`forgeenv/env/actions.py`](forgeenv/env/actions.py). - -## Quick test - -```bash -curl -X POST https://akhiilll-forgeenv.hf.space/reset -curl https://akhiilll-forgeenv.hf.space/state -``` - -```python -from openenv.core.env_client import EnvClient - -async with EnvClient(base_url="https://akhiilll-forgeenv.hf.space") as client: - obs = await client.reset() - print(obs.observation.current_phase, obs.observation.task_id) -``` - -## Project links - -- **Main repo / training notebooks / plots:** - -- **Repair Agent model (LoRA):** - -- **Demo (Gradio + ZeroGPU):** - - -## Citations - -- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) -- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) -- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025) -- [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) β€” Reward engineering & shaping -- [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) β€” Reward engineering for RL in software tasks +--- +title: ForgeEnv +emoji: πŸ”§ +colorFrom: indigo +colorTo: green +sdk: docker +app_port: 7860 +pinned: true +license: apache-2.0 +tags: + - openenv + - self-play + - self-improvement + - code-repair + - schema-drift + - reinforcement-learning + - huggingface +short_description: Self-improving RL env for HF library-drift repair +--- + +# ForgeEnv β€” OpenEnv Server + +This Space hosts the **ForgeEnv** OpenEnv-compliant environment as a FastAPI +service. It exposes the standard `reset`, `step`, and `state` endpoints and is +the runtime that training notebooks (TRL + Unsloth) connect to. + +> **Theme:** Self-Improvement (Hackathon Theme #4) β€” Challenger / Solver +> co-evolution via R-Zero, SPIRAL, and Absolute Zero Reasoner techniques. + +## What it does + +ForgeEnv simulates **HuggingFace library version drift**. A *Drift Generator* +proposes a realistic breakage to a working training script (renamed APIs, +deprecated imports, changed argument signatures, etc.). A *Repair Agent* then +emits a unified diff that should restore the script. Reward is computed by an +execution simulator + AST checker + held-out evaluator (multi-component to +resist reward hacking). + +## API + +The server uses [`openenv-core`](https://pypi.org/project/openenv-core/) and +follows the Gym-style contract: + +| Endpoint | Method | Purpose | +| -------- | ------ | -------------------------------------------------- | +| `/reset` | POST | Sample a fresh task, return drift-gen observation | +| `/step` | POST | Apply a `ForgeAction` (breakage or repair) | +| `/state` | GET | Inspect the current internal state | +| `/health`| GET | Health probe (used by the container HEALTHCHECK) | + +`ForgeAction` is a discriminated union of `BreakageAction` (used in phase 1) +and `RepairAction` (used in phase 2). See +[`forgeenv/env/actions.py`](forgeenv/env/actions.py). + +## Quick test + +```bash +curl -X POST https://akhiilll-forgeenv.hf.space/reset +curl https://akhiilll-forgeenv.hf.space/state +``` + +```python +from openenv.core.env_client import EnvClient + +async with EnvClient(base_url="https://akhiilll-forgeenv.hf.space") as client: + obs = await client.reset() + print(obs.observation.current_phase, obs.observation.task_id) +``` + +## Project links + +- **Main repo / training notebooks / plots:** + +- **Repair Agent model (LoRA):** + +- **Demo (Gradio + ZeroGPU):** + + +## Citations + +- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) +- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) +- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025) +- [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) β€” Reward engineering & shaping +- [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) β€” Reward engineering for RL in software tasks diff --git a/forgeenv-space/openenv.yaml b/forgeenv-space/openenv.yaml index fc8436b62e6ca01cea997c4dc9dbd10b52412f75..9a8a1612bb25b32201f8e163b850ecb4195f3eb7 100644 --- a/forgeenv-space/openenv.yaml +++ b/forgeenv-space/openenv.yaml @@ -1,24 +1,24 @@ -name: forgeenv -version: 0.1.0 -description: > - Self-improving RL environment for HuggingFace ecosystem repair. - Trains agents to fix broken training scripts under library version drift - through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL, - and Absolute Zero Reasoner techniques on top of OpenEnv. -theme: self-improvement -tags: - - openenv - - self-play - - code-repair - - schema-drift - - multi-role - - huggingface - - reinforcement-learning -environment: - class: forgeenv.env.forge_environment.ForgeEnvironment - action_model: forgeenv.env.actions.ForgeAction - observation_model: forgeenv.env.observations.ForgeObservation -server: - module: forgeenv.env.server - app: app - port: 7860 +name: forgeenv +version: 0.1.0 +description: > + Self-improving RL environment for HuggingFace ecosystem repair. + Trains agents to fix broken training scripts under library version drift + through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL, + and Absolute Zero Reasoner techniques on top of OpenEnv. +theme: self-improvement +tags: + - openenv + - self-play + - code-repair + - schema-drift + - multi-role + - huggingface + - reinforcement-learning +environment: + class: forgeenv.env.forge_environment.ForgeEnvironment + action_model: forgeenv.env.actions.ForgeAction + observation_model: forgeenv.env.observations.ForgeObservation +server: + module: forgeenv.env.server + app: app + port: 7860 diff --git a/forgeenv-space/requirements.txt b/forgeenv-space/requirements.txt index d3f30cf19c6bae52e27b350549b56291ae204908..2bcb02b949f331b20f0d9f75321a80cf12cd7cf7 100644 --- a/forgeenv-space/requirements.txt +++ b/forgeenv-space/requirements.txt @@ -1,9 +1,9 @@ -openenv-core>=0.2.0 -fastapi>=0.110.0 -uvicorn[standard]>=0.27.0 -pydantic>=2.6.0 -pyyaml>=6.0 -nltk>=3.8.0 -scikit-learn>=1.4.0 -numpy>=1.26.0 -rich>=13.7.0 +openenv-core>=0.2.0 +fastapi>=0.110.0 +uvicorn[standard]>=0.27.0 +pydantic>=2.6.0 +pyyaml>=6.0 +nltk>=3.8.0 +scikit-learn>=1.4.0 +numpy>=1.26.0 +rich>=13.7.0 diff --git a/forgeenv/__init__.py b/forgeenv/__init__.py index 61467dfc573e7c6892692d21b98285d8a6411c05..f4a44a7be86ebf99eaee8feb7495e81461af24e2 100644 --- a/forgeenv/__init__.py +++ b/forgeenv/__init__.py @@ -1,4 +1,4 @@ -"""ForgeEnv: Self-improving RL environment for HuggingFace ecosystem repair.""" - -__version__ = "0.1.0" -__author__ = "akhiilll" +"""ForgeEnv: Self-improving RL environment for HuggingFace ecosystem repair.""" + +__version__ = "0.1.0" +__author__ = "akhiilll" diff --git a/forgeenv/artifacts/repair_library.py b/forgeenv/artifacts/repair_library.py index fd3a7a59df36e3c92e7827a6199144920d40a0d2..6967c362137279049035c7af7fdda4a359f32dc7 100644 --- a/forgeenv/artifacts/repair_library.py +++ b/forgeenv/artifacts/repair_library.py @@ -1,120 +1,120 @@ -"""Persisted "repair library" β€” the model's accumulated knowledge of -known breakage -> repair pairs. Curated from successful rollouts during -training. Loaded at inference time as a few-shot prefix when the agent -recognises a familiar error class. -""" -from __future__ import annotations - -import json -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Optional - - -@dataclass -class RepairExample: - primitive_type: str - breakage_params: dict[str, Any] - error_signature: str - repair_diff: str - visible_reward: float - held_out: dict[str, float] - task_id: str = "" - - def signature_key(self) -> str: - return f"{self.primitive_type}::{self.error_signature[:80]}" - - -@dataclass -class RepairLibrary: - examples: list[RepairExample] = field(default_factory=list) - - def add(self, example: RepairExample) -> None: - self.examples.append(example) - - def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]: - """Return the highest-reward example whose primitive_type matches and - whose error text overlaps.""" - candidates = [ - e for e in self.examples if e.primitive_type == primitive_type - ] - if not candidates: - return None - scored = sorted( - candidates, - key=lambda e: ( - _ngram_overlap(e.error_signature, error_text), - e.visible_reward, - ), - reverse=True, - ) - return scored[0] if scored else None - - def to_dict(self) -> dict: - return { - "version": "1", - "examples": [asdict(e) for e in self.examples], - "size": len(self.examples), - "by_primitive": _count_by_primitive(self.examples), - } - - def save(self, path: str | Path) -> None: - path = Path(path) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8") - - @classmethod - def load(cls, path: str | Path) -> "RepairLibrary": - data = json.loads(Path(path).read_text(encoding="utf-8")) - examples = [RepairExample(**e) for e in data.get("examples", [])] - return cls(examples=examples) - - -def _ngram_overlap(a: str, b: str, n: int = 3) -> float: - if not a or not b: - return 0.0 - - def grams(text: str) -> set[str]: - text = text.lower() - return {text[i : i + n] for i in range(len(text) - n + 1)} - - ga, gb = grams(a), grams(b) - if not ga or not gb: - return 0.0 - return len(ga & gb) / max(1, len(ga | gb)) - - -def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]: - counts: dict[str, int] = {} - for e in examples: - counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1 - return counts - - -def curate_from_rollouts( - rollout_results: list, - min_reward: float = 0.6, - min_held_out_clean: float = 0.5, -) -> RepairLibrary: - """Build a RepairLibrary from a list of rollout dicts/RolloutResults.""" - lib = RepairLibrary() - for r in rollout_results: - get = r.get if isinstance(r, dict) else lambda k, default=None: getattr(r, k, default) - if float(get("visible_reward", 0.0) or 0.0) < min_reward: - continue - if float(get("held_out_breakdown", {}).get("executed_cleanly", 0.0)) < min_held_out_clean: - continue - lib.add( - RepairExample( - primitive_type=str(get("primitive_type", "unknown")), - breakage_params=dict(get("info", {}).get("breakage_spec", {}).get("params", {})) - if isinstance(get("info", {}), dict) - else {}, - error_signature=str(get("error_trace", "") or "")[:160], - repair_diff=str(get("repair_completion", "") or get("info", {}).get("repair_diff", ""))[:2000], - visible_reward=float(get("visible_reward", 0.0) or 0.0), - held_out=dict(get("held_out_breakdown", {}) or {}), - task_id=str(get("task_id", "")), - ) - ) - return lib +"""Persisted "repair library" β€” the model's accumulated knowledge of +known breakage -> repair pairs. Curated from successful rollouts during +training. Loaded at inference time as a few-shot prefix when the agent +recognises a familiar error class. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Optional + + +@dataclass +class RepairExample: + primitive_type: str + breakage_params: dict[str, Any] + error_signature: str + repair_diff: str + visible_reward: float + held_out: dict[str, float] + task_id: str = "" + + def signature_key(self) -> str: + return f"{self.primitive_type}::{self.error_signature[:80]}" + + +@dataclass +class RepairLibrary: + examples: list[RepairExample] = field(default_factory=list) + + def add(self, example: RepairExample) -> None: + self.examples.append(example) + + def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]: + """Return the highest-reward example whose primitive_type matches and + whose error text overlaps.""" + candidates = [ + e for e in self.examples if e.primitive_type == primitive_type + ] + if not candidates: + return None + scored = sorted( + candidates, + key=lambda e: ( + _ngram_overlap(e.error_signature, error_text), + e.visible_reward, + ), + reverse=True, + ) + return scored[0] if scored else None + + def to_dict(self) -> dict: + return { + "version": "1", + "examples": [asdict(e) for e in self.examples], + "size": len(self.examples), + "by_primitive": _count_by_primitive(self.examples), + } + + def save(self, path: str | Path) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8") + + @classmethod + def load(cls, path: str | Path) -> "RepairLibrary": + data = json.loads(Path(path).read_text(encoding="utf-8")) + examples = [RepairExample(**e) for e in data.get("examples", [])] + return cls(examples=examples) + + +def _ngram_overlap(a: str, b: str, n: int = 3) -> float: + if not a or not b: + return 0.0 + + def grams(text: str) -> set[str]: + text = text.lower() + return {text[i : i + n] for i in range(len(text) - n + 1)} + + ga, gb = grams(a), grams(b) + if not ga or not gb: + return 0.0 + return len(ga & gb) / max(1, len(ga | gb)) + + +def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]: + counts: dict[str, int] = {} + for e in examples: + counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1 + return counts + + +def curate_from_rollouts( + rollout_results: list, + min_reward: float = 0.6, + min_held_out_clean: float = 0.5, +) -> RepairLibrary: + """Build a RepairLibrary from a list of rollout dicts/RolloutResults.""" + lib = RepairLibrary() + for r in rollout_results: + get = r.get if isinstance(r, dict) else lambda k, default=None: getattr(r, k, default) + if float(get("visible_reward", 0.0) or 0.0) < min_reward: + continue + if float(get("held_out_breakdown", {}).get("executed_cleanly", 0.0)) < min_held_out_clean: + continue + lib.add( + RepairExample( + primitive_type=str(get("primitive_type", "unknown")), + breakage_params=dict(get("info", {}).get("breakage_spec", {}).get("params", {})) + if isinstance(get("info", {}), dict) + else {}, + error_signature=str(get("error_trace", "") or "")[:160], + repair_diff=str(get("repair_completion", "") or get("info", {}).get("repair_diff", ""))[:2000], + visible_reward=float(get("visible_reward", 0.0) or 0.0), + held_out=dict(get("held_out_breakdown", {}) or {}), + task_id=str(get("task_id", "")), + ) + ) + return lib diff --git a/forgeenv/drift/library_drift_engine.py b/forgeenv/drift/library_drift_engine.py index 8f402eeba8ac35896050325c25b86e2494dce188..4fe6e8c372e06864bad258e6d020567cfeb7f279 100644 --- a/forgeenv/drift/library_drift_engine.py +++ b/forgeenv/drift/library_drift_engine.py @@ -1,74 +1,74 @@ -"""Library Drift Engine. - -Manages library version snapshots and triggers version upgrades during -training to create non-stationary verification. In simulation mode it -just tracks the current snapshot index β€” that index influences -breakage selection and is exposed in observations so the Repair Agent -can adapt. - -Also exposes Chojecki GVU's SNR computation -(https://arxiv.org/abs/2512.02731 Definition 4.4). -""" -from __future__ import annotations - -import math -from dataclasses import dataclass, field - -DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [ - {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"}, - {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"}, - {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"}, - {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"}, -] - - -@dataclass -class LibraryDriftEngine: - snapshots: list[dict[str, str]] = field( - default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS) - ) - current_index: int = 0 - drift_history: list[dict] = field(default_factory=list) - - def current_versions(self) -> dict[str, str]: - return dict(self.snapshots[self.current_index]) - - def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool: - if ( - episode_num > 0 - and episode_num % drift_every == 0 - and self.current_index < len(self.snapshots) - 1 - ): - prev = self.snapshots[self.current_index] - self.current_index += 1 - self.drift_history.append( - { - "episode": episode_num, - "from": prev, - "to": self.snapshots[self.current_index], - } - ) - return True - return False - - def reset(self) -> None: - self.current_index = 0 - self.drift_history.clear() - - @staticmethod - def compute_snr( - recent_held_out: list[float], recent_visible: list[float] - ) -> dict[str, float]: - """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).""" - - def snr(values: list[float]) -> float: - if len(values) < 2: - return 0.0 - mean = sum(values) / len(values) - var = sum((v - mean) ** 2 for v in values) / len(values) - return mean**2 / max(var, 1e-8) - - return { - "snr_verifier": snr(recent_held_out), - "snr_generator": snr(recent_visible), - } +"""Library Drift Engine. + +Manages library version snapshots and triggers version upgrades during +training to create non-stationary verification. In simulation mode it +just tracks the current snapshot index β€” that index influences +breakage selection and is exposed in observations so the Repair Agent +can adapt. + +Also exposes Chojecki GVU's SNR computation +(https://arxiv.org/abs/2512.02731 Definition 4.4). +""" +from __future__ import annotations + +import math +from dataclasses import dataclass, field + +DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [ + {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"}, + {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"}, + {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"}, + {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"}, +] + + +@dataclass +class LibraryDriftEngine: + snapshots: list[dict[str, str]] = field( + default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS) + ) + current_index: int = 0 + drift_history: list[dict] = field(default_factory=list) + + def current_versions(self) -> dict[str, str]: + return dict(self.snapshots[self.current_index]) + + def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool: + if ( + episode_num > 0 + and episode_num % drift_every == 0 + and self.current_index < len(self.snapshots) - 1 + ): + prev = self.snapshots[self.current_index] + self.current_index += 1 + self.drift_history.append( + { + "episode": episode_num, + "from": prev, + "to": self.snapshots[self.current_index], + } + ) + return True + return False + + def reset(self) -> None: + self.current_index = 0 + self.drift_history.clear() + + @staticmethod + def compute_snr( + recent_held_out: list[float], recent_visible: list[float] + ) -> dict[str, float]: + """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).""" + + def snr(values: list[float]) -> float: + if len(values) < 2: + return 0.0 + mean = sum(values) / len(values) + var = sum((v - mean) ** 2 for v in values) / len(values) + return mean**2 / max(var, 1e-8) + + return { + "snr_verifier": snr(recent_held_out), + "snr_generator": snr(recent_visible), + } diff --git a/forgeenv/env/actions.py b/forgeenv/env/actions.py index 8ca4a87802a22130e212fa5b4a7efd5356f63f9b..45787f2f1d496d24c7864fdfe6a1f3427021fc69 100644 --- a/forgeenv/env/actions.py +++ b/forgeenv/env/actions.py @@ -1,50 +1,50 @@ -"""Pydantic action models for ForgeEnv (compatible with OpenEnv 0.2.x). - -Episodes have two phases β€” drift_gen (Challenger) and repair (Solver) β€” so -we expose a single union ForgeAction that carries either a BreakageAction -or a RepairAction. The environment dispatches on which sub-field is set. -""" -from __future__ import annotations - -from typing import Any, Literal, Optional - -from pydantic import Field - -from openenv.core import Action - - -class BreakageAction(Action): - """Drift Generator's action: pick a primitive type + parameters.""" - - action_type: Literal["breakage"] = "breakage" - primitive_type: str = Field( - ..., description="One of the registered breakage primitive class names" - ) - params: dict[str, Any] = Field( - default_factory=dict, description="Primitive-specific parameters" - ) - - -class RepairAction(Action): - """Repair Agent's action: a unified diff (or full replacement script).""" - - action_type: Literal["repair"] = "repair" - unified_diff: str = Field(..., description="Unified diff or full replacement script") - - -class ForgeAction(Action): - """Union action: exactly one of `breakage` / `repair` must be set. - - This is the type registered with OpenEnv's `create_app`. It avoids - Pydantic discriminated unions to keep the OpenAPI schema flat and - cross-version-friendly. - """ - - breakage: Optional[BreakageAction] = None - repair: Optional[RepairAction] = None - - def model_post_init(self, __context: Any) -> None: - if (self.breakage is None) == (self.repair is None): - raise ValueError( - "ForgeAction requires exactly one of `breakage` or `repair` to be set." - ) +"""Pydantic action models for ForgeEnv (compatible with OpenEnv 0.2.x). + +Episodes have two phases β€” drift_gen (Challenger) and repair (Solver) β€” so +we expose a single union ForgeAction that carries either a BreakageAction +or a RepairAction. The environment dispatches on which sub-field is set. +""" +from __future__ import annotations + +from typing import Any, Literal, Optional + +from pydantic import Field + +from openenv.core import Action + + +class BreakageAction(Action): + """Drift Generator's action: pick a primitive type + parameters.""" + + action_type: Literal["breakage"] = "breakage" + primitive_type: str = Field( + ..., description="One of the registered breakage primitive class names" + ) + params: dict[str, Any] = Field( + default_factory=dict, description="Primitive-specific parameters" + ) + + +class RepairAction(Action): + """Repair Agent's action: a unified diff (or full replacement script).""" + + action_type: Literal["repair"] = "repair" + unified_diff: str = Field(..., description="Unified diff or full replacement script") + + +class ForgeAction(Action): + """Union action: exactly one of `breakage` / `repair` must be set. + + This is the type registered with OpenEnv's `create_app`. It avoids + Pydantic discriminated unions to keep the OpenAPI schema flat and + cross-version-friendly. + """ + + breakage: Optional[BreakageAction] = None + repair: Optional[RepairAction] = None + + def model_post_init(self, __context: Any) -> None: + if (self.breakage is None) == (self.repair is None): + raise ValueError( + "ForgeAction requires exactly one of `breakage` or `repair` to be set." + ) diff --git a/forgeenv/env/diff_utils.py b/forgeenv/env/diff_utils.py index 64e92937b628219ac378e7bb295776a25b16144c..aed57859d9c4dca5bac496b7c10b12e4bc927669 100644 --- a/forgeenv/env/diff_utils.py +++ b/forgeenv/env/diff_utils.py @@ -1,163 +1,163 @@ -"""Unified-diff application utilities. - -The Repair Agent submits a unified diff. We need a permissive applier -because LLM diffs are often malformed (wrong line numbers, missing -context, extra prose). We try the strict applier first, then fall -back to applying hunks via plain string replacement. - -The agent may also submit a full Python script instead of a diff -(common when the model's diff format breaks). We detect this and -treat it as a complete replacement. -""" -from __future__ import annotations - -import difflib -import re - - -_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) -_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(") - - -def looks_like_full_script(text: str) -> bool: - """Heuristic: text is probably a full python script, not a diff.""" - lines = text.lstrip().splitlines() - if not lines: - return False - has_diff_header = any( - line.startswith(("---", "+++", "@@")) for line in lines[:5] - ) - if has_diff_header: - return False - # If we see two or more script-style markers in the first 30 lines, - # treat as a full replacement script. - head = "\n".join(lines[:30]) - hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head) - return hits >= 2 - - -def _strict_apply(broken_script: str, diff_text: str) -> str | None: - """Apply a unified diff strictly. Returns None on any failure.""" - lines = broken_script.splitlines(keepends=True) - out: list[str] = [] - diff_lines = diff_text.splitlines() - i = 0 - src_idx = 0 - in_hunk = False - hunk_old: list[str] = [] - hunk_new: list[str] = [] - - while i < len(diff_lines): - line = diff_lines[i] - if line.startswith(("---", "+++")): - i += 1 - continue - if line.startswith("@@"): - # Flush previous hunk - if in_hunk: - # Find the hunk_old block in the source starting at src_idx. - target = "".join(hunk_old) - source_remainder = "".join(lines[src_idx:]) - pos = source_remainder.find(target) - if pos == -1: - return None - out.append(source_remainder[:pos]) - out.append("".join(hunk_new)) - src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True)) - hunk_old, hunk_new = [], [] - in_hunk = True - i += 1 - continue - if in_hunk: - if line.startswith("+"): - hunk_new.append(line[1:] + "\n") - elif line.startswith("-"): - hunk_old.append(line[1:] + "\n") - else: - # context line - ctx = line[1:] if line.startswith(" ") else line - hunk_old.append(ctx + "\n") - hunk_new.append(ctx + "\n") - i += 1 - - # Flush trailing hunk - if in_hunk and (hunk_old or hunk_new): - target = "".join(hunk_old) - source_remainder = "".join(lines[src_idx:]) - pos = source_remainder.find(target) - if pos == -1: - return None - out.append(source_remainder[:pos]) - out.append("".join(hunk_new)) - consumed = source_remainder[: pos + len(target)] - src_idx += len(consumed.splitlines(keepends=True)) - - out.append("".join(lines[src_idx:])) - return "".join(out) - - -def _permissive_apply(broken_script: str, diff_text: str) -> str: - """Apply a malformed diff by extracting (-,+) line pairs and doing - a tolerant search-and-replace. - """ - repaired = broken_script - pairs: list[tuple[str, str]] = [] - lines = diff_text.splitlines() - pending_minus: str | None = None - - for line in lines: - if line.startswith("---") or line.startswith("+++") or line.startswith("@@"): - pending_minus = None - continue - if line.startswith("-"): - pending_minus = line[1:].strip() - elif line.startswith("+") and pending_minus is not None: - pairs.append((pending_minus, line[1:].strip())) - pending_minus = None - elif pending_minus is not None and not line.startswith(" "): - # standalone deletion β€” skip in permissive mode (we can't - # reliably know what to delete without context) - pending_minus = None - - for old, new in pairs: - if old and old in repaired: - repaired = repaired.replace(old, new, 1) - - return repaired - - -def apply_unified_diff(broken_script: str, diff_text: str) -> str: - """Try every strategy in order and return the first that produces a change. - - Strategies: - 1. If `diff_text` looks like a full script, return it directly. - 2. Try strict diff application. - 3. Fall back to permissive (-,+) line-pair replacement. - 4. As last resort, return the broken script unchanged. - """ - diff_text = diff_text or "" - if not diff_text.strip(): - return broken_script - - if looks_like_full_script(diff_text): - return diff_text - - if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text: - strict = _strict_apply(broken_script, diff_text) - if strict is not None and strict != broken_script: - return strict - - perm = _permissive_apply(broken_script, diff_text) - return perm - - -def make_unified_diff(before: str, after: str, path: str = "train.py") -> str: - """Produce a canonical unified diff from before -> after.""" - diff = difflib.unified_diff( - before.splitlines(keepends=True), - after.splitlines(keepends=True), - fromfile=f"a/{path}", - tofile=f"b/{path}", - n=2, - ) - return "".join(diff) +"""Unified-diff application utilities. + +The Repair Agent submits a unified diff. We need a permissive applier +because LLM diffs are often malformed (wrong line numbers, missing +context, extra prose). We try the strict applier first, then fall +back to applying hunks via plain string replacement. + +The agent may also submit a full Python script instead of a diff +(common when the model's diff format breaks). We detect this and +treat it as a complete replacement. +""" +from __future__ import annotations + +import difflib +import re + + +_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) +_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(") + + +def looks_like_full_script(text: str) -> bool: + """Heuristic: text is probably a full python script, not a diff.""" + lines = text.lstrip().splitlines() + if not lines: + return False + has_diff_header = any( + line.startswith(("---", "+++", "@@")) for line in lines[:5] + ) + if has_diff_header: + return False + # If we see two or more script-style markers in the first 30 lines, + # treat as a full replacement script. + head = "\n".join(lines[:30]) + hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head) + return hits >= 2 + + +def _strict_apply(broken_script: str, diff_text: str) -> str | None: + """Apply a unified diff strictly. Returns None on any failure.""" + lines = broken_script.splitlines(keepends=True) + out: list[str] = [] + diff_lines = diff_text.splitlines() + i = 0 + src_idx = 0 + in_hunk = False + hunk_old: list[str] = [] + hunk_new: list[str] = [] + + while i < len(diff_lines): + line = diff_lines[i] + if line.startswith(("---", "+++")): + i += 1 + continue + if line.startswith("@@"): + # Flush previous hunk + if in_hunk: + # Find the hunk_old block in the source starting at src_idx. + target = "".join(hunk_old) + source_remainder = "".join(lines[src_idx:]) + pos = source_remainder.find(target) + if pos == -1: + return None + out.append(source_remainder[:pos]) + out.append("".join(hunk_new)) + src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True)) + hunk_old, hunk_new = [], [] + in_hunk = True + i += 1 + continue + if in_hunk: + if line.startswith("+"): + hunk_new.append(line[1:] + "\n") + elif line.startswith("-"): + hunk_old.append(line[1:] + "\n") + else: + # context line + ctx = line[1:] if line.startswith(" ") else line + hunk_old.append(ctx + "\n") + hunk_new.append(ctx + "\n") + i += 1 + + # Flush trailing hunk + if in_hunk and (hunk_old or hunk_new): + target = "".join(hunk_old) + source_remainder = "".join(lines[src_idx:]) + pos = source_remainder.find(target) + if pos == -1: + return None + out.append(source_remainder[:pos]) + out.append("".join(hunk_new)) + consumed = source_remainder[: pos + len(target)] + src_idx += len(consumed.splitlines(keepends=True)) + + out.append("".join(lines[src_idx:])) + return "".join(out) + + +def _permissive_apply(broken_script: str, diff_text: str) -> str: + """Apply a malformed diff by extracting (-,+) line pairs and doing + a tolerant search-and-replace. + """ + repaired = broken_script + pairs: list[tuple[str, str]] = [] + lines = diff_text.splitlines() + pending_minus: str | None = None + + for line in lines: + if line.startswith("---") or line.startswith("+++") or line.startswith("@@"): + pending_minus = None + continue + if line.startswith("-"): + pending_minus = line[1:].strip() + elif line.startswith("+") and pending_minus is not None: + pairs.append((pending_minus, line[1:].strip())) + pending_minus = None + elif pending_minus is not None and not line.startswith(" "): + # standalone deletion β€” skip in permissive mode (we can't + # reliably know what to delete without context) + pending_minus = None + + for old, new in pairs: + if old and old in repaired: + repaired = repaired.replace(old, new, 1) + + return repaired + + +def apply_unified_diff(broken_script: str, diff_text: str) -> str: + """Try every strategy in order and return the first that produces a change. + + Strategies: + 1. If `diff_text` looks like a full script, return it directly. + 2. Try strict diff application. + 3. Fall back to permissive (-,+) line-pair replacement. + 4. As last resort, return the broken script unchanged. + """ + diff_text = diff_text or "" + if not diff_text.strip(): + return broken_script + + if looks_like_full_script(diff_text): + return diff_text + + if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text: + strict = _strict_apply(broken_script, diff_text) + if strict is not None and strict != broken_script: + return strict + + perm = _permissive_apply(broken_script, diff_text) + return perm + + +def make_unified_diff(before: str, after: str, path: str = "train.py") -> str: + """Produce a canonical unified diff from before -> after.""" + diff = difflib.unified_diff( + before.splitlines(keepends=True), + after.splitlines(keepends=True), + fromfile=f"a/{path}", + tofile=f"b/{path}", + n=2, + ) + return "".join(diff) diff --git a/forgeenv/env/forge_environment.py b/forgeenv/env/forge_environment.py index 88aafb9c7eca34ed66102fee25add2ff563abe72..1c6af615b46442c1842460875dc82131378e9f5a 100644 --- a/forgeenv/env/forge_environment.py +++ b/forgeenv/env/forge_environment.py @@ -1,259 +1,259 @@ -"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv. - -Episode flow (exactly 2 steps per episode): - reset() -> sample task, ask Teacher for category - step(BreakageAction) -> Drift Generator's proposal is applied; broken - script is run, error trace captured. - step(RepairAction) -> Repair diff is applied; script is re-executed; - visible + held-out rewards computed; episode ends. -""" -from __future__ import annotations - -import time -import uuid -from typing import Any, Optional - -from openenv.core import Environment - -from forgeenv.drift.library_drift_engine import LibraryDriftEngine -from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction -from forgeenv.env.diff_utils import apply_unified_diff -from forgeenv.env.observations import ForgeObservation -from forgeenv.primitives.breakage_primitives import ( - PRIMITIVE_REGISTRY, - parse_breakage_spec, -) -from forgeenv.roles.teacher import Teacher -from forgeenv.sandbox.simulation_mode import SimulationExecutor -from forgeenv.tasks.models import ExecutionResult, Task -from forgeenv.tasks.task_sampler import TaskSampler -from forgeenv.verifier.held_out_evaluator import compute_held_out_scores -from forgeenv.verifier.visible_verifier import compute_visible_reward - -DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys()) - - -class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]): - """OpenEnv-compliant environment for HuggingFace ecosystem repair.""" - - SUPPORTS_CONCURRENT_SESSIONS = False # Teacher state is global per env - - def __init__( - self, - task_sampler: Optional[TaskSampler] = None, - teacher: Optional[Teacher] = None, - executor: Optional[SimulationExecutor] = None, - drift_engine: Optional[LibraryDriftEngine] = None, - seed: Optional[int] = None, - ) -> None: - super().__init__() - self.task_sampler = task_sampler or TaskSampler() - self.teacher = teacher or Teacher( - categories=list(DEFAULT_CATEGORIES) or ["api_drift"] - ) - self.executor = executor or SimulationExecutor(seed=seed) - self.drift_engine = drift_engine or LibraryDriftEngine() - - self._episode_id: Optional[str] = None - self._episode_count: int = 0 - self._current_task: Optional[Task] = None - self._original_script: str = "" - self._broken_script: str = "" - self._error_trace: str = "" - self._breakage_spec: Optional[dict[str, Any]] = None - self._target_category: str = "" - self._current_phase: str = "idle" - self._last_obs: Optional[ForgeObservation] = None - - # ------------------------------------------------------------------ API - def reset( - self, - seed: Optional[int] = None, - episode_id: Optional[str] = None, - difficulty: Optional[str] = "easy", - **kwargs: Any, - ) -> ForgeObservation: - self._episode_id = episode_id or str(uuid.uuid4()) - self._episode_count += 1 - self._target_category = self.teacher.select_next_category() - - task = self.task_sampler.sample(difficulty=difficulty) - if task is None: - raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)") - self._current_task = task - self._original_script = task.script_content - self._broken_script = "" - self._error_trace = "" - self._breakage_spec = None - self._current_phase = "drift_gen" - - # Library drift trigger every 50 episodes (configurable from outside). - drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50) - - obs = ForgeObservation( - current_phase="drift_gen", - task_id=task.task_id, - task_description=task.description, - target_category=self._target_category, - script_content=self._original_script, - error_trace=None, - library_versions=self.drift_engine.current_versions(), - episode_step=0, - done=False, - reward=0.0, - info={ - "episode_id": self._episode_id, - "episode_count": self._episode_count, - "drift_triggered": drifted, - "available_primitives": sorted(PRIMITIVE_REGISTRY), - }, - ) - self._last_obs = obs - return obs - - def step( - self, - action: ForgeAction, - timeout_s: Optional[float] = None, - **kwargs: Any, - ) -> ForgeObservation: - if self._current_phase == "drift_gen": - if action.breakage is None: - return self._error_obs("Expected BreakageAction in drift_gen phase") - return self._handle_breakage(action.breakage) - - if self._current_phase == "repair": - if action.repair is None: - return self._error_obs("Expected RepairAction in repair phase") - return self._handle_repair(action.repair) - - return self._error_obs( - f"step() called in invalid phase {self._current_phase!r} β€” call reset() first" - ) - - @property - def state(self) -> dict: - return { - "phase": self._current_phase, - "episode_id": self._episode_id, - "episode_count": self._episode_count, - "task_id": self._current_task.task_id if self._current_task else None, - "target_category": self._target_category, - "library_versions": self.drift_engine.current_versions(), - "teacher": self.teacher.get_state(), - "drift_history": list(self.drift_engine.drift_history), - "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None, - } - - # ---------------------------------------------------------------- helpers - def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation: - spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)} - try: - primitive = parse_breakage_spec(spec) - except ValueError as exc: - return self._error_obs(f"Invalid breakage spec: {exc}") - - try: - self._broken_script = primitive.apply(self._original_script) - except Exception as exc: # primitive bug β€” surface but don't crash server - return self._error_obs(f"Primitive apply failed: {exc}") - - self._breakage_spec = spec - - result = self.executor.execute(self._broken_script, self._current_task) - if result.exit_code != 0: - self._error_trace = result.stderr or "non-zero exit code, no stderr" - else: - # The breakage didn't actually break it; still proceed to repair phase - # (no-op repair is then a valid choice). - self._error_trace = "Script ran without observable error" - - self._current_phase = "repair" - - obs = ForgeObservation( - current_phase="repair", - task_id=self._current_task.task_id, - task_description=self._current_task.description, - target_category=primitive.category, - script_content=self._broken_script, - error_trace=self._error_trace, - library_versions=self.drift_engine.current_versions(), - episode_step=1, - done=False, - reward=0.0, - info={ - "episode_id": self._episode_id, - "breakage_primitive": primitive.name, - "breakage_description": primitive.description, - }, - ) - self._last_obs = obs - return obs - - def _handle_repair(self, repair: RepairAction) -> ForgeObservation: - repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "") - - t0 = time.time() - result = self.executor.execute(repaired, self._current_task) - result.script_content = repaired # ensure verifier sees what we ran - wall_ms = int((time.time() - t0) * 1000) - - visible_reward, visible_breakdown = compute_visible_reward( - result, self._current_task - ) - held_out = compute_held_out_scores( - result, self._current_task, repair_diff=repair.unified_diff or "" - ) - - success = result.exit_code == 0 - category = ( - self._breakage_spec.get("primitive_type", "unknown") - if self._breakage_spec - else "unknown" - ) - # Update Teacher's curriculum state - self.teacher.update(category, success) - - self._current_phase = "done" - - obs = ForgeObservation( - current_phase="done", - task_id=self._current_task.task_id, - task_description=self._current_task.description, - target_category=category, - script_content=repaired, - error_trace=result.stderr or None, - library_versions=self.drift_engine.current_versions(), - episode_step=2, - done=True, - reward=visible_reward, - reward_breakdown=visible_breakdown, - held_out_breakdown=held_out, - info={ - "episode_id": self._episode_id, - "exit_code": result.exit_code, - "wall_time_ms": wall_ms, - "checkpoint_exists": result.checkpoint_exists, - "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]), - "breakage_spec": self._breakage_spec, - "teacher_state": self.teacher.get_state(), - }, - ) - self._last_obs = obs - return obs - - def _error_obs(self, message: str) -> ForgeObservation: - """Return a `done=True` error observation rather than raising.""" - return ForgeObservation( - current_phase="done", - task_id=self._current_task.task_id if self._current_task else "", - task_description=self._current_task.description if self._current_task else "", - target_category=self._target_category, - script_content=self._broken_script or self._original_script, - error_trace=message, - library_versions=self.drift_engine.current_versions(), - episode_step=2, - done=True, - reward=0.0, - info={"error": message, "episode_id": self._episode_id}, - ) +"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv. + +Episode flow (exactly 2 steps per episode): + reset() -> sample task, ask Teacher for category + step(BreakageAction) -> Drift Generator's proposal is applied; broken + script is run, error trace captured. + step(RepairAction) -> Repair diff is applied; script is re-executed; + visible + held-out rewards computed; episode ends. +""" +from __future__ import annotations + +import time +import uuid +from typing import Any, Optional + +from openenv.core import Environment + +from forgeenv.drift.library_drift_engine import LibraryDriftEngine +from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction +from forgeenv.env.diff_utils import apply_unified_diff +from forgeenv.env.observations import ForgeObservation +from forgeenv.primitives.breakage_primitives import ( + PRIMITIVE_REGISTRY, + parse_breakage_spec, +) +from forgeenv.roles.teacher import Teacher +from forgeenv.sandbox.simulation_mode import SimulationExecutor +from forgeenv.tasks.models import ExecutionResult, Task +from forgeenv.tasks.task_sampler import TaskSampler +from forgeenv.verifier.held_out_evaluator import compute_held_out_scores +from forgeenv.verifier.visible_verifier import compute_visible_reward + +DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys()) + + +class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]): + """OpenEnv-compliant environment for HuggingFace ecosystem repair.""" + + SUPPORTS_CONCURRENT_SESSIONS = False # Teacher state is global per env + + def __init__( + self, + task_sampler: Optional[TaskSampler] = None, + teacher: Optional[Teacher] = None, + executor: Optional[SimulationExecutor] = None, + drift_engine: Optional[LibraryDriftEngine] = None, + seed: Optional[int] = None, + ) -> None: + super().__init__() + self.task_sampler = task_sampler or TaskSampler() + self.teacher = teacher or Teacher( + categories=list(DEFAULT_CATEGORIES) or ["api_drift"] + ) + self.executor = executor or SimulationExecutor(seed=seed) + self.drift_engine = drift_engine or LibraryDriftEngine() + + self._episode_id: Optional[str] = None + self._episode_count: int = 0 + self._current_task: Optional[Task] = None + self._original_script: str = "" + self._broken_script: str = "" + self._error_trace: str = "" + self._breakage_spec: Optional[dict[str, Any]] = None + self._target_category: str = "" + self._current_phase: str = "idle" + self._last_obs: Optional[ForgeObservation] = None + + # ------------------------------------------------------------------ API + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + difficulty: Optional[str] = "easy", + **kwargs: Any, + ) -> ForgeObservation: + self._episode_id = episode_id or str(uuid.uuid4()) + self._episode_count += 1 + self._target_category = self.teacher.select_next_category() + + task = self.task_sampler.sample(difficulty=difficulty) + if task is None: + raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)") + self._current_task = task + self._original_script = task.script_content + self._broken_script = "" + self._error_trace = "" + self._breakage_spec = None + self._current_phase = "drift_gen" + + # Library drift trigger every 50 episodes (configurable from outside). + drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50) + + obs = ForgeObservation( + current_phase="drift_gen", + task_id=task.task_id, + task_description=task.description, + target_category=self._target_category, + script_content=self._original_script, + error_trace=None, + library_versions=self.drift_engine.current_versions(), + episode_step=0, + done=False, + reward=0.0, + info={ + "episode_id": self._episode_id, + "episode_count": self._episode_count, + "drift_triggered": drifted, + "available_primitives": sorted(PRIMITIVE_REGISTRY), + }, + ) + self._last_obs = obs + return obs + + def step( + self, + action: ForgeAction, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ForgeObservation: + if self._current_phase == "drift_gen": + if action.breakage is None: + return self._error_obs("Expected BreakageAction in drift_gen phase") + return self._handle_breakage(action.breakage) + + if self._current_phase == "repair": + if action.repair is None: + return self._error_obs("Expected RepairAction in repair phase") + return self._handle_repair(action.repair) + + return self._error_obs( + f"step() called in invalid phase {self._current_phase!r} β€” call reset() first" + ) + + @property + def state(self) -> dict: + return { + "phase": self._current_phase, + "episode_id": self._episode_id, + "episode_count": self._episode_count, + "task_id": self._current_task.task_id if self._current_task else None, + "target_category": self._target_category, + "library_versions": self.drift_engine.current_versions(), + "teacher": self.teacher.get_state(), + "drift_history": list(self.drift_engine.drift_history), + "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None, + } + + # ---------------------------------------------------------------- helpers + def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation: + spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)} + try: + primitive = parse_breakage_spec(spec) + except ValueError as exc: + return self._error_obs(f"Invalid breakage spec: {exc}") + + try: + self._broken_script = primitive.apply(self._original_script) + except Exception as exc: # primitive bug β€” surface but don't crash server + return self._error_obs(f"Primitive apply failed: {exc}") + + self._breakage_spec = spec + + result = self.executor.execute(self._broken_script, self._current_task) + if result.exit_code != 0: + self._error_trace = result.stderr or "non-zero exit code, no stderr" + else: + # The breakage didn't actually break it; still proceed to repair phase + # (no-op repair is then a valid choice). + self._error_trace = "Script ran without observable error" + + self._current_phase = "repair" + + obs = ForgeObservation( + current_phase="repair", + task_id=self._current_task.task_id, + task_description=self._current_task.description, + target_category=primitive.category, + script_content=self._broken_script, + error_trace=self._error_trace, + library_versions=self.drift_engine.current_versions(), + episode_step=1, + done=False, + reward=0.0, + info={ + "episode_id": self._episode_id, + "breakage_primitive": primitive.name, + "breakage_description": primitive.description, + }, + ) + self._last_obs = obs + return obs + + def _handle_repair(self, repair: RepairAction) -> ForgeObservation: + repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "") + + t0 = time.time() + result = self.executor.execute(repaired, self._current_task) + result.script_content = repaired # ensure verifier sees what we ran + wall_ms = int((time.time() - t0) * 1000) + + visible_reward, visible_breakdown = compute_visible_reward( + result, self._current_task + ) + held_out = compute_held_out_scores( + result, self._current_task, repair_diff=repair.unified_diff or "" + ) + + success = result.exit_code == 0 + category = ( + self._breakage_spec.get("primitive_type", "unknown") + if self._breakage_spec + else "unknown" + ) + # Update Teacher's curriculum state + self.teacher.update(category, success) + + self._current_phase = "done" + + obs = ForgeObservation( + current_phase="done", + task_id=self._current_task.task_id, + task_description=self._current_task.description, + target_category=category, + script_content=repaired, + error_trace=result.stderr or None, + library_versions=self.drift_engine.current_versions(), + episode_step=2, + done=True, + reward=visible_reward, + reward_breakdown=visible_breakdown, + held_out_breakdown=held_out, + info={ + "episode_id": self._episode_id, + "exit_code": result.exit_code, + "wall_time_ms": wall_ms, + "checkpoint_exists": result.checkpoint_exists, + "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]), + "breakage_spec": self._breakage_spec, + "teacher_state": self.teacher.get_state(), + }, + ) + self._last_obs = obs + return obs + + def _error_obs(self, message: str) -> ForgeObservation: + """Return a `done=True` error observation rather than raising.""" + return ForgeObservation( + current_phase="done", + task_id=self._current_task.task_id if self._current_task else "", + task_description=self._current_task.description if self._current_task else "", + target_category=self._target_category, + script_content=self._broken_script or self._original_script, + error_trace=message, + library_versions=self.drift_engine.current_versions(), + episode_step=2, + done=True, + reward=0.0, + info={"error": message, "episode_id": self._episode_id}, + ) diff --git a/forgeenv/env/observations.py b/forgeenv/env/observations.py index d70e8dec85c8d8e1ec4d142ee8b60a8e35b495d7..67ca5d7b91ec66b59f8b4b67c135d0aa442a52a7 100644 --- a/forgeenv/env/observations.py +++ b/forgeenv/env/observations.py @@ -1,29 +1,29 @@ -"""Pydantic observation model for ForgeEnv.""" -from __future__ import annotations - -from typing import Any, Optional - -from pydantic import Field - -from openenv.core import Observation - - -class ForgeObservation(Observation): - """What the agent (or the trainer's rollout function) sees at each step. - - Inherits `done`, `reward`, `metadata` from the OpenEnv `Observation` base. - """ - - current_phase: str = Field( - ..., description="One of 'drift_gen', 'repair', 'verify', 'done'" - ) - task_id: str = "" - task_description: str = "" - target_category: str = "" - script_content: str = Field(default="", description="Current state of the script") - error_trace: Optional[str] = None - library_versions: dict[str, str] = Field(default_factory=dict) - reward_breakdown: dict[str, Any] = Field(default_factory=dict) - held_out_breakdown: dict[str, float] = Field(default_factory=dict) - episode_step: int = 0 - info: dict[str, Any] = Field(default_factory=dict) +"""Pydantic observation model for ForgeEnv.""" +from __future__ import annotations + +from typing import Any, Optional + +from pydantic import Field + +from openenv.core import Observation + + +class ForgeObservation(Observation): + """What the agent (or the trainer's rollout function) sees at each step. + + Inherits `done`, `reward`, `metadata` from the OpenEnv `Observation` base. + """ + + current_phase: str = Field( + ..., description="One of 'drift_gen', 'repair', 'verify', 'done'" + ) + task_id: str = "" + task_description: str = "" + target_category: str = "" + script_content: str = Field(default="", description="Current state of the script") + error_trace: Optional[str] = None + library_versions: dict[str, str] = Field(default_factory=dict) + reward_breakdown: dict[str, Any] = Field(default_factory=dict) + held_out_breakdown: dict[str, float] = Field(default_factory=dict) + episode_step: int = 0 + info: dict[str, Any] = Field(default_factory=dict) diff --git a/forgeenv/env/server.py b/forgeenv/env/server.py index 2f760f130492c5b2a2e641638ed6b7d82ba12cce..5dcf42d4277128f15d63a2f95813f07a52f114f7 100644 --- a/forgeenv/env/server.py +++ b/forgeenv/env/server.py @@ -1,126 +1,126 @@ -"""FastAPI server for ForgeEnv (OpenEnv-compliant). - -Exposes /reset, /step, /state HTTP endpoints via OpenEnv's `create_app`. -HF Spaces sets PORT=7860 automatically. -""" -from __future__ import annotations - -import os - -from fastapi.responses import HTMLResponse -from openenv.core import create_app - -from forgeenv.env.actions import ForgeAction -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.env.observations import ForgeObservation - -app = create_app( - env=ForgeEnvironment, - action_cls=ForgeAction, - observation_cls=ForgeObservation, - env_name="forgeenv", -) - - -_LANDING_HTML = """ - - - -ForgeEnv β€” OpenEnv server - - - - -

ForgeEnv πŸ”§ running

-

OpenEnv-compliant RL environment for HuggingFace -ecosystem repair under library version drift.

- -

This URL serves the environment over HTTP. It is not a UI β€” it's the -runtime that training notebooks connect to. Open one of -the endpoints below, or use the demo Space to try the trained Repair -Agent in a browser.

- -

Endpoints

- - - - - - - - - -
MethodPathPurpose
GET /healthHealth probe
POST/resetSample task, return drift-gen observation
POST/stepApply ForgeAction (breakage or repair)
GET /stateCurrent internal state
GET /metadataEnv name + version + schema URLs
GET /schemaAction / observation JSON schemas
GET /docsInteractive Swagger UI
- -

Quick start (Python)

-
import asyncio
-from openenv.core import GenericEnvClient
-
-async def go():
-    client = GenericEnvClient(base_url="https://akhiilll-forgeenv.hf.space")
-    obs = await client.reset()
-    print(obs.observation["current_phase"], obs.observation["task_id"])
-
-asyncio.run(go())
- -

Project links

- -

Tip: if links don't open from inside the embedded Space frame, -right-click and choose Open in new tab, or open this URL directly -at akhiilll-forgeenv.hf.space.

- -""" - - -def _attach_supplementary_routes(_app) -> None: - """Add /health and a friendly GET / landing page if not present.""" - existing = { - getattr(r, "path", None) for r in getattr(_app, "routes", []) - } - - if "/health" not in existing: - @_app.get("/health") - def _health() -> dict: - return {"status": "ok", "env": "forgeenv"} - - if "/" not in existing: - @_app.get("/", response_class=HTMLResponse, include_in_schema=False) - def _root() -> str: - return _LANDING_HTML - - -_attach_supplementary_routes(app) - - -if __name__ == "__main__": - import uvicorn - - port = int(os.environ.get("PORT", "7860")) - uvicorn.run(app, host="0.0.0.0", port=port) +"""FastAPI server for ForgeEnv (OpenEnv-compliant). + +Exposes /reset, /step, /state HTTP endpoints via OpenEnv's `create_app`. +HF Spaces sets PORT=7860 automatically. +""" +from __future__ import annotations + +import os + +from fastapi.responses import HTMLResponse +from openenv.core import create_app + +from forgeenv.env.actions import ForgeAction +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.env.observations import ForgeObservation + +app = create_app( + env=ForgeEnvironment, + action_cls=ForgeAction, + observation_cls=ForgeObservation, + env_name="forgeenv", +) + + +_LANDING_HTML = """ + + + +ForgeEnv β€” OpenEnv server + + + + +

ForgeEnv πŸ”§ running

+

OpenEnv-compliant RL environment for HuggingFace +ecosystem repair under library version drift.

+ +

This URL serves the environment over HTTP. It is not a UI β€” it's the +runtime that training notebooks connect to. Open one of +the endpoints below, or use the demo Space to try the trained Repair +Agent in a browser.

+ +

Endpoints

+ + + + + + + + + +
MethodPathPurpose
GET /healthHealth probe
POST/resetSample task, return drift-gen observation
POST/stepApply ForgeAction (breakage or repair)
GET /stateCurrent internal state
GET /metadataEnv name + version + schema URLs
GET /schemaAction / observation JSON schemas
GET /docsInteractive Swagger UI
+ +

Quick start (Python)

+
import asyncio
+from openenv.core import GenericEnvClient
+
+async def go():
+    client = GenericEnvClient(base_url="https://akhiilll-forgeenv.hf.space")
+    obs = await client.reset()
+    print(obs.observation["current_phase"], obs.observation["task_id"])
+
+asyncio.run(go())
+ +

Project links

+ +

Tip: if links don't open from inside the embedded Space frame, +right-click and choose Open in new tab, or open this URL directly +at akhiilll-forgeenv.hf.space.

+ +""" + + +def _attach_supplementary_routes(_app) -> None: + """Add /health and a friendly GET / landing page if not present.""" + existing = { + getattr(r, "path", None) for r in getattr(_app, "routes", []) + } + + if "/health" not in existing: + @_app.get("/health") + def _health() -> dict: + return {"status": "ok", "env": "forgeenv"} + + if "/" not in existing: + @_app.get("/", response_class=HTMLResponse, include_in_schema=False) + def _root() -> str: + return _LANDING_HTML + + +_attach_supplementary_routes(app) + + +if __name__ == "__main__": + import uvicorn + + port = int(os.environ.get("PORT", "7860")) + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/forgeenv/primitives/breakage_primitives.py b/forgeenv/primitives/breakage_primitives.py index 1db759d95ca806a5af09de3a5586d2b9270a5eb6..9f243f8daac5a2ca6e9c4237abba87e4f8d5def1 100644 --- a/forgeenv/primitives/breakage_primitives.py +++ b/forgeenv/primitives/breakage_primitives.py @@ -1,282 +1,282 @@ -"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift. - -Each primitive transforms a working script to simulate a library upgrade -breakage. They double as the Drift Generator's structured action space. -""" -from __future__ import annotations - -import re -from abc import ABC, abstractmethod -from dataclasses import dataclass, field - - -@dataclass -class BreakagePrimitive(ABC): - """Abstract base class for all breakage types.""" - - category: str = field(default="generic", init=False) - name: str = field(default="BreakagePrimitive", init=False) - description: str = field(default="", init=False) - - @abstractmethod - def apply(self, script: str) -> str: - """Transform `script` to introduce the breakage.""" - - def to_spec(self) -> dict: - """Serialize to JSON-compatible spec for the LLM action space.""" - return { - "primitive_type": self.__class__.__name__, - "category": self.category, - "params": self._get_params(), - } - - @abstractmethod - def _get_params(self) -> dict: - """Return a JSON-serializable dict of constructor parameters.""" - - -@dataclass -class RenameApiCall(BreakagePrimitive): - """Rename a function/method call to simulate API deprecation.""" - - old_name: str = "" - new_name: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RenameApiCall" - self.description = f"Rename {self.old_name} -> {self.new_name}" - - def apply(self, script: str) -> str: - if not self.old_name: - return script - # Use word-boundary replacement so we don't substring-match identifiers. - pattern = re.compile(rf"(? dict: - return {"old_name": self.old_name, "new_name": self.new_name} - - -@dataclass -class DeprecateImport(BreakagePrimitive): - """Change an import path to simulate module restructuring.""" - - old_module: str = "" - new_module: str = "" - - def __post_init__(self) -> None: - self.category = "import_drift" - self.name = "DeprecateImport" - self.description = f"Move {self.old_module} -> {self.new_module}" - - def apply(self, script: str) -> str: - if not self.old_module: - return script - return script.replace(self.old_module, self.new_module) - - def _get_params(self) -> dict: - return {"old_module": self.old_module, "new_module": self.new_module} - - -@dataclass -class ChangeArgumentSignature(BreakagePrimitive): - """Remove an expected kwarg (and document a new required one).""" - - function_name: str = "" - removed_arg: str = "" - added_arg: str = "" - added_value: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "ChangeArgumentSignature" - self.description = ( - f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}" - ) - - def apply(self, script: str) -> str: - if not self.removed_arg: - return script - pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)" - return re.sub(pattern, "", script) - - def _get_params(self) -> dict: - return { - "function_name": self.function_name, - "removed_arg": self.removed_arg, - "added_arg": self.added_arg, - "added_value": self.added_value, - } - - -@dataclass -class ModifyConfigField(BreakagePrimitive): - """Change a config-class default value to simulate behaviour drift.""" - - config_class: str = "" - field_name: str = "" - new_value: str = "" - - def __post_init__(self) -> None: - self.category = "config_drift" - self.name = "ModifyConfigField" - self.description = f"Change {self.config_class}.{self.field_name}" - - def apply(self, script: str) -> str: - if not self.field_name: - return script - pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" - return re.sub(pattern, rf"\g<1>{self.new_value}", script) - - def _get_params(self) -> dict: - return { - "config_class": self.config_class, - "field_name": self.field_name, - "new_value": self.new_value, - } - - -@dataclass -class RestructureDatasetSchema(BreakagePrimitive): - """Rename a dataset column reference to simulate schema drift.""" - - old_column: str = "" - new_column: str = "" - - def __post_init__(self) -> None: - self.category = "dataset_drift" - self.name = "RestructureDatasetSchema" - self.description = f"Rename column {self.old_column} -> {self.new_column}" - - def apply(self, script: str) -> str: - if not self.old_column: - return script - return script.replace( - f'"{self.old_column}"', f'"{self.new_column}"' - ).replace( - f"'{self.old_column}'", f"'{self.new_column}'" - ) - - def _get_params(self) -> dict: - return {"old_column": self.old_column, "new_column": self.new_column} - - -@dataclass -class ChangeTokenizerBehavior(BreakagePrimitive): - """Change tokenizer call arguments.""" - - old_kwarg: str = "" - old_value: str = "" - new_kwarg: str = "" - new_value: str = "" - - def __post_init__(self) -> None: - self.category = "tokenizer_drift" - self.name = "ChangeTokenizerBehavior" - self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}" - - def apply(self, script: str) -> str: - if not self.old_kwarg: - return script - pattern = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}" - replacement = f"{self.new_kwarg}={self.new_value}" - return re.sub(pattern, replacement, script) - - def _get_params(self) -> dict: - return { - "old_kwarg": self.old_kwarg, - "old_value": self.old_value, - "new_kwarg": self.new_kwarg, - "new_value": self.new_value, - } - - -@dataclass -class RemoveDeprecatedMethod(BreakagePrimitive): - """Remove a method that has been deprecated, leaving a sentinel that - raises AttributeError-style errors when the script runs.""" - - class_name: str = "" - method_name: str = "" - replacement: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RemoveDeprecatedMethod" - self.description = f"Remove {self.class_name}.{self.method_name}" - - def apply(self, script: str) -> str: - if not self.method_name: - return script - return script.replace( - f".{self.method_name}(", f".{self.method_name}_DEPRECATED(" - ) - - def _get_params(self) -> dict: - return { - "class_name": self.class_name, - "method_name": self.method_name, - "replacement": self.replacement, - } - - -@dataclass -class ChangeReturnType(BreakagePrimitive): - """A function now returns a different structure (e.g. tuple -> object).""" - - function_name: str = "" - old_access: str = "" - new_access: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "ChangeReturnType" - self.description = f"Change return type of {self.function_name}" - - def apply(self, script: str) -> str: - if self.old_access and self.new_access: - return script.replace(self.old_access, self.new_access) - return script - - def _get_params(self) -> dict: - return { - "function_name": self.function_name, - "old_access": self.old_access, - "new_access": self.new_access, - } - - -PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = { - "RenameApiCall": RenameApiCall, - "DeprecateImport": DeprecateImport, - "ChangeArgumentSignature": ChangeArgumentSignature, - "ModifyConfigField": ModifyConfigField, - "RestructureDatasetSchema": RestructureDatasetSchema, - "ChangeTokenizerBehavior": ChangeTokenizerBehavior, - "RemoveDeprecatedMethod": RemoveDeprecatedMethod, - "ChangeReturnType": ChangeReturnType, -} - - -def parse_breakage_spec(spec: dict) -> BreakagePrimitive: - """Parse a JSON breakage spec into a BreakagePrimitive object. - - Tolerates extra keys; ignores unknown params (LLMs hallucinate these). - """ - ptype = spec.get("primitive_type", "") - params = spec.get("params", {}) or {} - - if ptype not in PRIMITIVE_REGISTRY: - raise ValueError( - f"Unknown primitive type: {ptype!r}. " - f"Valid types: {list(PRIMITIVE_REGISTRY)}" - ) - - cls = PRIMITIVE_REGISTRY[ptype] - # Filter to known fields only so a hallucinated kwarg can't crash us. - valid_fields = { - f.name for f in cls.__dataclass_fields__.values() if f.init # type: ignore[attr-defined] - } - filtered = {k: v for k, v in params.items() if k in valid_fields} - return cls(**filtered) +"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift. + +Each primitive transforms a working script to simulate a library upgrade +breakage. They double as the Drift Generator's structured action space. +""" +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass +class BreakagePrimitive(ABC): + """Abstract base class for all breakage types.""" + + category: str = field(default="generic", init=False) + name: str = field(default="BreakagePrimitive", init=False) + description: str = field(default="", init=False) + + @abstractmethod + def apply(self, script: str) -> str: + """Transform `script` to introduce the breakage.""" + + def to_spec(self) -> dict: + """Serialize to JSON-compatible spec for the LLM action space.""" + return { + "primitive_type": self.__class__.__name__, + "category": self.category, + "params": self._get_params(), + } + + @abstractmethod + def _get_params(self) -> dict: + """Return a JSON-serializable dict of constructor parameters.""" + + +@dataclass +class RenameApiCall(BreakagePrimitive): + """Rename a function/method call to simulate API deprecation.""" + + old_name: str = "" + new_name: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RenameApiCall" + self.description = f"Rename {self.old_name} -> {self.new_name}" + + def apply(self, script: str) -> str: + if not self.old_name: + return script + # Use word-boundary replacement so we don't substring-match identifiers. + pattern = re.compile(rf"(? dict: + return {"old_name": self.old_name, "new_name": self.new_name} + + +@dataclass +class DeprecateImport(BreakagePrimitive): + """Change an import path to simulate module restructuring.""" + + old_module: str = "" + new_module: str = "" + + def __post_init__(self) -> None: + self.category = "import_drift" + self.name = "DeprecateImport" + self.description = f"Move {self.old_module} -> {self.new_module}" + + def apply(self, script: str) -> str: + if not self.old_module: + return script + return script.replace(self.old_module, self.new_module) + + def _get_params(self) -> dict: + return {"old_module": self.old_module, "new_module": self.new_module} + + +@dataclass +class ChangeArgumentSignature(BreakagePrimitive): + """Remove an expected kwarg (and document a new required one).""" + + function_name: str = "" + removed_arg: str = "" + added_arg: str = "" + added_value: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "ChangeArgumentSignature" + self.description = ( + f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}" + ) + + def apply(self, script: str) -> str: + if not self.removed_arg: + return script + pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)" + return re.sub(pattern, "", script) + + def _get_params(self) -> dict: + return { + "function_name": self.function_name, + "removed_arg": self.removed_arg, + "added_arg": self.added_arg, + "added_value": self.added_value, + } + + +@dataclass +class ModifyConfigField(BreakagePrimitive): + """Change a config-class default value to simulate behaviour drift.""" + + config_class: str = "" + field_name: str = "" + new_value: str = "" + + def __post_init__(self) -> None: + self.category = "config_drift" + self.name = "ModifyConfigField" + self.description = f"Change {self.config_class}.{self.field_name}" + + def apply(self, script: str) -> str: + if not self.field_name: + return script + pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" + return re.sub(pattern, rf"\g<1>{self.new_value}", script) + + def _get_params(self) -> dict: + return { + "config_class": self.config_class, + "field_name": self.field_name, + "new_value": self.new_value, + } + + +@dataclass +class RestructureDatasetSchema(BreakagePrimitive): + """Rename a dataset column reference to simulate schema drift.""" + + old_column: str = "" + new_column: str = "" + + def __post_init__(self) -> None: + self.category = "dataset_drift" + self.name = "RestructureDatasetSchema" + self.description = f"Rename column {self.old_column} -> {self.new_column}" + + def apply(self, script: str) -> str: + if not self.old_column: + return script + return script.replace( + f'"{self.old_column}"', f'"{self.new_column}"' + ).replace( + f"'{self.old_column}'", f"'{self.new_column}'" + ) + + def _get_params(self) -> dict: + return {"old_column": self.old_column, "new_column": self.new_column} + + +@dataclass +class ChangeTokenizerBehavior(BreakagePrimitive): + """Change tokenizer call arguments.""" + + old_kwarg: str = "" + old_value: str = "" + new_kwarg: str = "" + new_value: str = "" + + def __post_init__(self) -> None: + self.category = "tokenizer_drift" + self.name = "ChangeTokenizerBehavior" + self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}" + + def apply(self, script: str) -> str: + if not self.old_kwarg: + return script + pattern = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}" + replacement = f"{self.new_kwarg}={self.new_value}" + return re.sub(pattern, replacement, script) + + def _get_params(self) -> dict: + return { + "old_kwarg": self.old_kwarg, + "old_value": self.old_value, + "new_kwarg": self.new_kwarg, + "new_value": self.new_value, + } + + +@dataclass +class RemoveDeprecatedMethod(BreakagePrimitive): + """Remove a method that has been deprecated, leaving a sentinel that + raises AttributeError-style errors when the script runs.""" + + class_name: str = "" + method_name: str = "" + replacement: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RemoveDeprecatedMethod" + self.description = f"Remove {self.class_name}.{self.method_name}" + + def apply(self, script: str) -> str: + if not self.method_name: + return script + return script.replace( + f".{self.method_name}(", f".{self.method_name}_DEPRECATED(" + ) + + def _get_params(self) -> dict: + return { + "class_name": self.class_name, + "method_name": self.method_name, + "replacement": self.replacement, + } + + +@dataclass +class ChangeReturnType(BreakagePrimitive): + """A function now returns a different structure (e.g. tuple -> object).""" + + function_name: str = "" + old_access: str = "" + new_access: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "ChangeReturnType" + self.description = f"Change return type of {self.function_name}" + + def apply(self, script: str) -> str: + if self.old_access and self.new_access: + return script.replace(self.old_access, self.new_access) + return script + + def _get_params(self) -> dict: + return { + "function_name": self.function_name, + "old_access": self.old_access, + "new_access": self.new_access, + } + + +PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = { + "RenameApiCall": RenameApiCall, + "DeprecateImport": DeprecateImport, + "ChangeArgumentSignature": ChangeArgumentSignature, + "ModifyConfigField": ModifyConfigField, + "RestructureDatasetSchema": RestructureDatasetSchema, + "ChangeTokenizerBehavior": ChangeTokenizerBehavior, + "RemoveDeprecatedMethod": RemoveDeprecatedMethod, + "ChangeReturnType": ChangeReturnType, +} + + +def parse_breakage_spec(spec: dict) -> BreakagePrimitive: + """Parse a JSON breakage spec into a BreakagePrimitive object. + + Tolerates extra keys; ignores unknown params (LLMs hallucinate these). + """ + ptype = spec.get("primitive_type", "") + params = spec.get("params", {}) or {} + + if ptype not in PRIMITIVE_REGISTRY: + raise ValueError( + f"Unknown primitive type: {ptype!r}. " + f"Valid types: {list(PRIMITIVE_REGISTRY)}" + ) + + cls = PRIMITIVE_REGISTRY[ptype] + # Filter to known fields only so a hallucinated kwarg can't crash us. + valid_fields = { + f.name for f in cls.__dataclass_fields__.values() if f.init # type: ignore[attr-defined] + } + filtered = {k: v for k, v in params.items() if k in valid_fields} + return cls(**filtered) diff --git a/forgeenv/primitives/drift_taxonomy.yaml b/forgeenv/primitives/drift_taxonomy.yaml index 186a680f767228fa0dc8cdfbc7d37f9045b7a104..9bdba4f917250d149ea395f1c21a0be72e827575 100644 --- a/forgeenv/primitives/drift_taxonomy.yaml +++ b/forgeenv/primitives/drift_taxonomy.yaml @@ -1,217 +1,217 @@ -# Drift taxonomy: real HuggingFace/PyTorch breakages observed across version bumps. -# Used to seed the Drift Generator's initial proposal distribution and to anchor -# warm-start pair generation in things that actually happened in the wild. -- version_range: "transformers 4.36 -> 4.45" - affected_api: "Trainer.evaluate" - description: "Trainer.evaluate() return type changed shape; metrics now nested under .metrics" - breakage_primitive: "ChangeReturnType" - params: - function_name: "evaluate" - old_access: "trainer.evaluate()" - new_access: "trainer.evaluate().metrics" - repair_primitive: "RestoreReturnAccess" - category: "api_drift" - -- version_range: "transformers 4.30 -> 4.40" - affected_api: "TrainingArguments.evaluation_strategy" - description: "Renamed evaluation_strategy -> eval_strategy" - breakage_primitive: "RenameApiCall" - params: - old_name: "evaluation_strategy" - new_name: "eval_strategy" - repair_primitive: "RestoreApiCall" - category: "api_drift" - -- version_range: "datasets 2.14 -> 3.0" - affected_api: "load_dataset" - description: "Default split column was renamed in some GLUE configs" - breakage_primitive: "RestructureDatasetSchema" - params: - old_column: "label" - new_column: "labels" - repair_primitive: "RestoreColumn" - category: "dataset_drift" - -- version_range: "transformers 4.40 -> 4.50" - affected_api: "Trainer.predict" - description: "Method removed; users should use evaluate() with prediction_loss_only=False" - breakage_primitive: "RemoveDeprecatedMethod" - params: - class_name: "Trainer" - method_name: "predict" - replacement: "evaluate" - repair_primitive: "RestoreMethod" - category: "api_drift" - -- version_range: "transformers 4.36 -> 4.40" - affected_api: "TrainingArguments" - description: "num_train_epochs default behavior changed; max_steps now preferred" - breakage_primitive: "ModifyConfigField" - params: - config_class: "TrainingArguments" - field_name: "num_train_epochs" - new_value: "0" - repair_primitive: "RestoreConfigField" - category: "config_drift" - -- version_range: "transformers 4.34 -> 4.42" - affected_api: "Tokenizer.__call__" - description: "padding=True semantics changed; users should pass padding='max_length'" - breakage_primitive: "ChangeTokenizerBehavior" - params: - old_kwarg: "padding" - old_value: "True" - new_kwarg: "padding" - new_value: '"max_length"' - repair_primitive: "RestoreTokenizerKwarg" - category: "tokenizer_drift" - -- version_range: "transformers 4.20 -> 4.30" - affected_api: "imports" - description: "transformers.training_args moved to transformers.training_args_pt" - breakage_primitive: "DeprecateImport" - params: - old_module: "from transformers.training_args" - new_module: "from transformers.training_args_pt" - repair_primitive: "RestoreImport" - category: "import_drift" - -- version_range: "transformers 4.45 -> 4.50" - affected_api: "save_pretrained" - description: "save_pretrained() now requires safe_serialization to default True" - breakage_primitive: "ChangeArgumentSignature" - params: - function_name: "save_pretrained" - removed_arg: "safe_serialization" - added_arg: "safe_serialization" - added_value: "True" - repair_primitive: "RestoreArgument" - category: "api_drift" - -- version_range: "datasets 2.18 -> 3.0" - affected_api: "Dataset.set_format" - description: "set_format(type='torch') signature stricter, columns required" - breakage_primitive: "ChangeArgumentSignature" - params: - function_name: "set_format" - removed_arg: "columns" - added_arg: "columns" - added_value: '["input_ids", "attention_mask", "labels"]' - repair_primitive: "RestoreArgument" - category: "api_drift" - -- version_range: "transformers 4.36 -> 4.45" - affected_api: "Tokenizer.__call__" - description: "max_length default reduced from 512 -> 256 for some tokenizers" - breakage_primitive: "ModifyConfigField" - params: - config_class: "tokenizer" - field_name: "max_length" - new_value: "256" - repair_primitive: "RestoreConfigField" - category: "tokenizer_drift" - -- version_range: "transformers 4.40 -> 4.45" - affected_api: "DataCollatorWithPadding" - description: "Renamed `tokenizer` -> `processing_class` in DataCollator constructors" - breakage_primitive: "RenameApiCall" - params: - old_name: "tokenizer" - new_name: "processing_class" - repair_primitive: "RestoreApiCall" - category: "api_drift" - -- version_range: "datasets 2.14 -> 2.18" - affected_api: "load_dataset" - description: "Some splits renamed train[:500] semantics changed" - breakage_primitive: "RestructureDatasetSchema" - params: - old_column: "sentence" - new_column: "text" - repair_primitive: "RestoreColumn" - category: "dataset_drift" - -- version_range: "transformers 4.45 -> 4.50" - affected_api: "Trainer" - description: "evaluation_strategy was deprecated and removed" - breakage_primitive: "RemoveDeprecatedMethod" - params: - class_name: "Trainer" - method_name: "evaluate" - replacement: "evaluate_legacy" - repair_primitive: "RestoreMethod" - category: "api_drift" - -- version_range: "transformers 4.30 -> 4.40" - affected_api: "PreTrainedModel.from_pretrained" - description: "torch_dtype now required for some quantized model paths" - breakage_primitive: "ChangeArgumentSignature" - params: - function_name: "from_pretrained" - removed_arg: "torch_dtype" - added_arg: "torch_dtype" - added_value: '"auto"' - repair_primitive: "RestoreArgument" - category: "api_drift" - -- version_range: "datasets 3.0 -> 3.2" - affected_api: "Dataset.rename_column" - description: "rename_column raises if target name exists" - breakage_primitive: "RestructureDatasetSchema" - params: - old_column: "labels" - new_column: "label" - repair_primitive: "RestoreColumn" - category: "dataset_drift" - -- version_range: "transformers 4.36 -> 4.42" - affected_api: "TrainingArguments.report_to" - description: "Default report_to changed from 'all' to 'none'" - breakage_primitive: "ModifyConfigField" - params: - config_class: "TrainingArguments" - field_name: "report_to" - new_value: '"all"' - repair_primitive: "RestoreConfigField" - category: "config_drift" - -- version_range: "transformers 4.40 -> 4.50" - affected_api: "imports" - description: "transformers.deepspeed moved to accelerate.utils.deepspeed" - breakage_primitive: "DeprecateImport" - params: - old_module: "from transformers.deepspeed" - new_module: "from accelerate.utils.deepspeed" - repair_primitive: "RestoreImport" - category: "import_drift" - -- version_range: "transformers 4.45 -> 4.50" - affected_api: "Tokenizer return" - description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute" - breakage_primitive: "ChangeReturnType" - params: - function_name: "tokenizer" - old_access: "tokenizer(text)" - new_access: "tokenizer(text).encodings" - repair_primitive: "RestoreReturnAccess" - category: "api_drift" - -- version_range: "transformers 4.30 -> 4.40" - affected_api: "save_pretrained" - description: "save_pretrained -> save_pretrained_directory rename in some classes" - breakage_primitive: "RenameApiCall" - params: - old_name: "save_pretrained" - new_name: "save_pretrained_directory" - repair_primitive: "RestoreApiCall" - category: "api_drift" - -- version_range: "transformers 4.45 -> 4.50" - affected_api: "TrainingArguments.no_cuda" - description: "no_cuda renamed to use_cpu (logic inverted)" - breakage_primitive: "RenameApiCall" - params: - old_name: "no_cuda" - new_name: "use_cpu" - repair_primitive: "RestoreApiCall" - category: "config_drift" +# Drift taxonomy: real HuggingFace/PyTorch breakages observed across version bumps. +# Used to seed the Drift Generator's initial proposal distribution and to anchor +# warm-start pair generation in things that actually happened in the wild. +- version_range: "transformers 4.36 -> 4.45" + affected_api: "Trainer.evaluate" + description: "Trainer.evaluate() return type changed shape; metrics now nested under .metrics" + breakage_primitive: "ChangeReturnType" + params: + function_name: "evaluate" + old_access: "trainer.evaluate()" + new_access: "trainer.evaluate().metrics" + repair_primitive: "RestoreReturnAccess" + category: "api_drift" + +- version_range: "transformers 4.30 -> 4.40" + affected_api: "TrainingArguments.evaluation_strategy" + description: "Renamed evaluation_strategy -> eval_strategy" + breakage_primitive: "RenameApiCall" + params: + old_name: "evaluation_strategy" + new_name: "eval_strategy" + repair_primitive: "RestoreApiCall" + category: "api_drift" + +- version_range: "datasets 2.14 -> 3.0" + affected_api: "load_dataset" + description: "Default split column was renamed in some GLUE configs" + breakage_primitive: "RestructureDatasetSchema" + params: + old_column: "label" + new_column: "labels" + repair_primitive: "RestoreColumn" + category: "dataset_drift" + +- version_range: "transformers 4.40 -> 4.50" + affected_api: "Trainer.predict" + description: "Method removed; users should use evaluate() with prediction_loss_only=False" + breakage_primitive: "RemoveDeprecatedMethod" + params: + class_name: "Trainer" + method_name: "predict" + replacement: "evaluate" + repair_primitive: "RestoreMethod" + category: "api_drift" + +- version_range: "transformers 4.36 -> 4.40" + affected_api: "TrainingArguments" + description: "num_train_epochs default behavior changed; max_steps now preferred" + breakage_primitive: "ModifyConfigField" + params: + config_class: "TrainingArguments" + field_name: "num_train_epochs" + new_value: "0" + repair_primitive: "RestoreConfigField" + category: "config_drift" + +- version_range: "transformers 4.34 -> 4.42" + affected_api: "Tokenizer.__call__" + description: "padding=True semantics changed; users should pass padding='max_length'" + breakage_primitive: "ChangeTokenizerBehavior" + params: + old_kwarg: "padding" + old_value: "True" + new_kwarg: "padding" + new_value: '"max_length"' + repair_primitive: "RestoreTokenizerKwarg" + category: "tokenizer_drift" + +- version_range: "transformers 4.20 -> 4.30" + affected_api: "imports" + description: "transformers.training_args moved to transformers.training_args_pt" + breakage_primitive: "DeprecateImport" + params: + old_module: "from transformers.training_args" + new_module: "from transformers.training_args_pt" + repair_primitive: "RestoreImport" + category: "import_drift" + +- version_range: "transformers 4.45 -> 4.50" + affected_api: "save_pretrained" + description: "save_pretrained() now requires safe_serialization to default True" + breakage_primitive: "ChangeArgumentSignature" + params: + function_name: "save_pretrained" + removed_arg: "safe_serialization" + added_arg: "safe_serialization" + added_value: "True" + repair_primitive: "RestoreArgument" + category: "api_drift" + +- version_range: "datasets 2.18 -> 3.0" + affected_api: "Dataset.set_format" + description: "set_format(type='torch') signature stricter, columns required" + breakage_primitive: "ChangeArgumentSignature" + params: + function_name: "set_format" + removed_arg: "columns" + added_arg: "columns" + added_value: '["input_ids", "attention_mask", "labels"]' + repair_primitive: "RestoreArgument" + category: "api_drift" + +- version_range: "transformers 4.36 -> 4.45" + affected_api: "Tokenizer.__call__" + description: "max_length default reduced from 512 -> 256 for some tokenizers" + breakage_primitive: "ModifyConfigField" + params: + config_class: "tokenizer" + field_name: "max_length" + new_value: "256" + repair_primitive: "RestoreConfigField" + category: "tokenizer_drift" + +- version_range: "transformers 4.40 -> 4.45" + affected_api: "DataCollatorWithPadding" + description: "Renamed `tokenizer` -> `processing_class` in DataCollator constructors" + breakage_primitive: "RenameApiCall" + params: + old_name: "tokenizer" + new_name: "processing_class" + repair_primitive: "RestoreApiCall" + category: "api_drift" + +- version_range: "datasets 2.14 -> 2.18" + affected_api: "load_dataset" + description: "Some splits renamed train[:500] semantics changed" + breakage_primitive: "RestructureDatasetSchema" + params: + old_column: "sentence" + new_column: "text" + repair_primitive: "RestoreColumn" + category: "dataset_drift" + +- version_range: "transformers 4.45 -> 4.50" + affected_api: "Trainer" + description: "evaluation_strategy was deprecated and removed" + breakage_primitive: "RemoveDeprecatedMethod" + params: + class_name: "Trainer" + method_name: "evaluate" + replacement: "evaluate_legacy" + repair_primitive: "RestoreMethod" + category: "api_drift" + +- version_range: "transformers 4.30 -> 4.40" + affected_api: "PreTrainedModel.from_pretrained" + description: "torch_dtype now required for some quantized model paths" + breakage_primitive: "ChangeArgumentSignature" + params: + function_name: "from_pretrained" + removed_arg: "torch_dtype" + added_arg: "torch_dtype" + added_value: '"auto"' + repair_primitive: "RestoreArgument" + category: "api_drift" + +- version_range: "datasets 3.0 -> 3.2" + affected_api: "Dataset.rename_column" + description: "rename_column raises if target name exists" + breakage_primitive: "RestructureDatasetSchema" + params: + old_column: "labels" + new_column: "label" + repair_primitive: "RestoreColumn" + category: "dataset_drift" + +- version_range: "transformers 4.36 -> 4.42" + affected_api: "TrainingArguments.report_to" + description: "Default report_to changed from 'all' to 'none'" + breakage_primitive: "ModifyConfigField" + params: + config_class: "TrainingArguments" + field_name: "report_to" + new_value: '"all"' + repair_primitive: "RestoreConfigField" + category: "config_drift" + +- version_range: "transformers 4.40 -> 4.50" + affected_api: "imports" + description: "transformers.deepspeed moved to accelerate.utils.deepspeed" + breakage_primitive: "DeprecateImport" + params: + old_module: "from transformers.deepspeed" + new_module: "from accelerate.utils.deepspeed" + repair_primitive: "RestoreImport" + category: "import_drift" + +- version_range: "transformers 4.45 -> 4.50" + affected_api: "Tokenizer return" + description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute" + breakage_primitive: "ChangeReturnType" + params: + function_name: "tokenizer" + old_access: "tokenizer(text)" + new_access: "tokenizer(text).encodings" + repair_primitive: "RestoreReturnAccess" + category: "api_drift" + +- version_range: "transformers 4.30 -> 4.40" + affected_api: "save_pretrained" + description: "save_pretrained -> save_pretrained_directory rename in some classes" + breakage_primitive: "RenameApiCall" + params: + old_name: "save_pretrained" + new_name: "save_pretrained_directory" + repair_primitive: "RestoreApiCall" + category: "api_drift" + +- version_range: "transformers 4.45 -> 4.50" + affected_api: "TrainingArguments.no_cuda" + description: "no_cuda renamed to use_cpu (logic inverted)" + breakage_primitive: "RenameApiCall" + params: + old_name: "no_cuda" + new_name: "use_cpu" + repair_primitive: "RestoreApiCall" + category: "config_drift" diff --git a/forgeenv/primitives/repair_primitives.py b/forgeenv/primitives/repair_primitives.py index f7f438fdac78c9ac84d35a894ff2eb0544558bbd..14cd6cde709ff1c39785440fed28e0d501debce1 100644 --- a/forgeenv/primitives/repair_primitives.py +++ b/forgeenv/primitives/repair_primitives.py @@ -1,241 +1,241 @@ -"""Repair primitives β€” direct inverses of the 8 breakage primitives. - -Used during warm-start data generation: for every (script, breakage) -pair we know the canonical repair, so we can write SFT pairs. - -These are also useful for unit-testing the breakage primitives: -apply(breakage) then apply(repair) should be (close to) the identity. -""" -from __future__ import annotations - -import re -from abc import ABC, abstractmethod -from dataclasses import dataclass, field - - -@dataclass -class RepairPrimitive(ABC): - category: str = field(default="generic", init=False) - name: str = field(default="RepairPrimitive", init=False) - description: str = field(default="", init=False) - - @abstractmethod - def apply(self, script: str) -> str: - """Transform `script` to undo the corresponding breakage.""" - - def to_spec(self) -> dict: - return { - "primitive_type": self.__class__.__name__, - "category": self.category, - "params": self._get_params(), - } - - @abstractmethod - def _get_params(self) -> dict: - """Return JSON-serializable constructor parameters.""" - - -@dataclass -class RestoreApiCall(RepairPrimitive): - new_name: str = "" - old_name: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RestoreApiCall" - self.description = f"Rename {self.new_name} -> {self.old_name}" - - def apply(self, script: str) -> str: - if not self.new_name: - return script - pattern = re.compile(rf"(? dict: - return {"new_name": self.new_name, "old_name": self.old_name} - - -@dataclass -class RestoreImport(RepairPrimitive): - new_module: str = "" - old_module: str = "" - - def __post_init__(self) -> None: - self.category = "import_drift" - self.name = "RestoreImport" - self.description = f"Restore import {self.new_module} -> {self.old_module}" - - def apply(self, script: str) -> str: - return script.replace(self.new_module, self.old_module) - - def _get_params(self) -> dict: - return {"new_module": self.new_module, "old_module": self.old_module} - - -@dataclass -class RestoreArgument(RepairPrimitive): - """Re-add a removed argument to a function call.""" - - function_name: str = "" - arg_name: str = "" - arg_value: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RestoreArgument" - self.description = ( - f"Add {self.arg_name}={self.arg_value} to {self.function_name}()" - ) - - def apply(self, script: str) -> str: - if not self.function_name: - return script - # Insert the kwarg right after the function-name's opening paren. - pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)" - replacement = rf"\g<1>{self.arg_name}={self.arg_value}, \g<2>" - return re.sub(pattern, replacement, script, count=1) - - def _get_params(self) -> dict: - return { - "function_name": self.function_name, - "arg_name": self.arg_name, - "arg_value": self.arg_value, - } - - -@dataclass -class RestoreConfigField(RepairPrimitive): - field_name: str = "" - old_value: str = "" - - def __post_init__(self) -> None: - self.category = "config_drift" - self.name = "RestoreConfigField" - self.description = f"Restore {self.field_name}={self.old_value}" - - def apply(self, script: str) -> str: - if not self.field_name: - return script - pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" - return re.sub(pattern, rf"\g<1>{self.old_value}", script) - - def _get_params(self) -> dict: - return {"field_name": self.field_name, "old_value": self.old_value} - - -@dataclass -class RestoreColumn(RepairPrimitive): - new_column: str = "" - old_column: str = "" - - def __post_init__(self) -> None: - self.category = "dataset_drift" - self.name = "RestoreColumn" - self.description = f"Rename column {self.new_column} -> {self.old_column}" - - def apply(self, script: str) -> str: - return script.replace( - f'"{self.new_column}"', f'"{self.old_column}"' - ).replace( - f"'{self.new_column}'", f"'{self.old_column}'" - ) - - def _get_params(self) -> dict: - return {"new_column": self.new_column, "old_column": self.old_column} - - -@dataclass -class RestoreTokenizerKwarg(RepairPrimitive): - new_kwarg: str = "" - new_value: str = "" - old_kwarg: str = "" - old_value: str = "" - - def __post_init__(self) -> None: - self.category = "tokenizer_drift" - self.name = "RestoreTokenizerKwarg" - self.description = ( - f"Restore tokenizer {self.new_kwarg}={self.new_value} -> " - f"{self.old_kwarg}={self.old_value}" - ) - - def apply(self, script: str) -> str: - if not self.new_kwarg: - return script - pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}" - replacement = f"{self.old_kwarg}={self.old_value}" - return re.sub(pattern, replacement, script) - - def _get_params(self) -> dict: - return { - "new_kwarg": self.new_kwarg, - "new_value": self.new_value, - "old_kwarg": self.old_kwarg, - "old_value": self.old_value, - } - - -@dataclass -class RestoreMethod(RepairPrimitive): - method_name: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RestoreMethod" - self.description = f"Un-deprecate .{self.method_name}()" - - def apply(self, script: str) -> str: - if not self.method_name: - return script - return script.replace( - f".{self.method_name}_DEPRECATED(", f".{self.method_name}(" - ) - - def _get_params(self) -> dict: - return {"method_name": self.method_name} - - -@dataclass -class RestoreReturnAccess(RepairPrimitive): - new_access: str = "" - old_access: str = "" - - def __post_init__(self) -> None: - self.category = "api_drift" - self.name = "RestoreReturnAccess" - self.description = f"Restore return-access {self.new_access} -> {self.old_access}" - - def apply(self, script: str) -> str: - if not self.new_access: - return script - return script.replace(self.new_access, self.old_access) - - def _get_params(self) -> dict: - return {"new_access": self.new_access, "old_access": self.old_access} - - -REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = { - "RestoreApiCall": RestoreApiCall, - "RestoreImport": RestoreImport, - "RestoreArgument": RestoreArgument, - "RestoreConfigField": RestoreConfigField, - "RestoreColumn": RestoreColumn, - "RestoreTokenizerKwarg": RestoreTokenizerKwarg, - "RestoreMethod": RestoreMethod, - "RestoreReturnAccess": RestoreReturnAccess, -} - - -# Map a breakage primitive's class name to the repair-primitive class that -# inverts it. Used by the warm-start pair generator and by the demo / repair -# library curator. -BREAKAGE_TO_REPAIR: dict[str, str] = { - "RenameApiCall": "RestoreApiCall", - "DeprecateImport": "RestoreImport", - "ChangeArgumentSignature": "RestoreArgument", - "ModifyConfigField": "RestoreConfigField", - "RestructureDatasetSchema": "RestoreColumn", - "ChangeTokenizerBehavior": "RestoreTokenizerKwarg", - "RemoveDeprecatedMethod": "RestoreMethod", - "ChangeReturnType": "RestoreReturnAccess", -} +"""Repair primitives β€” direct inverses of the 8 breakage primitives. + +Used during warm-start data generation: for every (script, breakage) +pair we know the canonical repair, so we can write SFT pairs. + +These are also useful for unit-testing the breakage primitives: +apply(breakage) then apply(repair) should be (close to) the identity. +""" +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass +class RepairPrimitive(ABC): + category: str = field(default="generic", init=False) + name: str = field(default="RepairPrimitive", init=False) + description: str = field(default="", init=False) + + @abstractmethod + def apply(self, script: str) -> str: + """Transform `script` to undo the corresponding breakage.""" + + def to_spec(self) -> dict: + return { + "primitive_type": self.__class__.__name__, + "category": self.category, + "params": self._get_params(), + } + + @abstractmethod + def _get_params(self) -> dict: + """Return JSON-serializable constructor parameters.""" + + +@dataclass +class RestoreApiCall(RepairPrimitive): + new_name: str = "" + old_name: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RestoreApiCall" + self.description = f"Rename {self.new_name} -> {self.old_name}" + + def apply(self, script: str) -> str: + if not self.new_name: + return script + pattern = re.compile(rf"(? dict: + return {"new_name": self.new_name, "old_name": self.old_name} + + +@dataclass +class RestoreImport(RepairPrimitive): + new_module: str = "" + old_module: str = "" + + def __post_init__(self) -> None: + self.category = "import_drift" + self.name = "RestoreImport" + self.description = f"Restore import {self.new_module} -> {self.old_module}" + + def apply(self, script: str) -> str: + return script.replace(self.new_module, self.old_module) + + def _get_params(self) -> dict: + return {"new_module": self.new_module, "old_module": self.old_module} + + +@dataclass +class RestoreArgument(RepairPrimitive): + """Re-add a removed argument to a function call.""" + + function_name: str = "" + arg_name: str = "" + arg_value: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RestoreArgument" + self.description = ( + f"Add {self.arg_name}={self.arg_value} to {self.function_name}()" + ) + + def apply(self, script: str) -> str: + if not self.function_name: + return script + # Insert the kwarg right after the function-name's opening paren. + pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)" + replacement = rf"\g<1>{self.arg_name}={self.arg_value}, \g<2>" + return re.sub(pattern, replacement, script, count=1) + + def _get_params(self) -> dict: + return { + "function_name": self.function_name, + "arg_name": self.arg_name, + "arg_value": self.arg_value, + } + + +@dataclass +class RestoreConfigField(RepairPrimitive): + field_name: str = "" + old_value: str = "" + + def __post_init__(self) -> None: + self.category = "config_drift" + self.name = "RestoreConfigField" + self.description = f"Restore {self.field_name}={self.old_value}" + + def apply(self, script: str) -> str: + if not self.field_name: + return script + pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)" + return re.sub(pattern, rf"\g<1>{self.old_value}", script) + + def _get_params(self) -> dict: + return {"field_name": self.field_name, "old_value": self.old_value} + + +@dataclass +class RestoreColumn(RepairPrimitive): + new_column: str = "" + old_column: str = "" + + def __post_init__(self) -> None: + self.category = "dataset_drift" + self.name = "RestoreColumn" + self.description = f"Rename column {self.new_column} -> {self.old_column}" + + def apply(self, script: str) -> str: + return script.replace( + f'"{self.new_column}"', f'"{self.old_column}"' + ).replace( + f"'{self.new_column}'", f"'{self.old_column}'" + ) + + def _get_params(self) -> dict: + return {"new_column": self.new_column, "old_column": self.old_column} + + +@dataclass +class RestoreTokenizerKwarg(RepairPrimitive): + new_kwarg: str = "" + new_value: str = "" + old_kwarg: str = "" + old_value: str = "" + + def __post_init__(self) -> None: + self.category = "tokenizer_drift" + self.name = "RestoreTokenizerKwarg" + self.description = ( + f"Restore tokenizer {self.new_kwarg}={self.new_value} -> " + f"{self.old_kwarg}={self.old_value}" + ) + + def apply(self, script: str) -> str: + if not self.new_kwarg: + return script + pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}" + replacement = f"{self.old_kwarg}={self.old_value}" + return re.sub(pattern, replacement, script) + + def _get_params(self) -> dict: + return { + "new_kwarg": self.new_kwarg, + "new_value": self.new_value, + "old_kwarg": self.old_kwarg, + "old_value": self.old_value, + } + + +@dataclass +class RestoreMethod(RepairPrimitive): + method_name: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RestoreMethod" + self.description = f"Un-deprecate .{self.method_name}()" + + def apply(self, script: str) -> str: + if not self.method_name: + return script + return script.replace( + f".{self.method_name}_DEPRECATED(", f".{self.method_name}(" + ) + + def _get_params(self) -> dict: + return {"method_name": self.method_name} + + +@dataclass +class RestoreReturnAccess(RepairPrimitive): + new_access: str = "" + old_access: str = "" + + def __post_init__(self) -> None: + self.category = "api_drift" + self.name = "RestoreReturnAccess" + self.description = f"Restore return-access {self.new_access} -> {self.old_access}" + + def apply(self, script: str) -> str: + if not self.new_access: + return script + return script.replace(self.new_access, self.old_access) + + def _get_params(self) -> dict: + return {"new_access": self.new_access, "old_access": self.old_access} + + +REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = { + "RestoreApiCall": RestoreApiCall, + "RestoreImport": RestoreImport, + "RestoreArgument": RestoreArgument, + "RestoreConfigField": RestoreConfigField, + "RestoreColumn": RestoreColumn, + "RestoreTokenizerKwarg": RestoreTokenizerKwarg, + "RestoreMethod": RestoreMethod, + "RestoreReturnAccess": RestoreReturnAccess, +} + + +# Map a breakage primitive's class name to the repair-primitive class that +# inverts it. Used by the warm-start pair generator and by the demo / repair +# library curator. +BREAKAGE_TO_REPAIR: dict[str, str] = { + "RenameApiCall": "RestoreApiCall", + "DeprecateImport": "RestoreImport", + "ChangeArgumentSignature": "RestoreArgument", + "ModifyConfigField": "RestoreConfigField", + "RestructureDatasetSchema": "RestoreColumn", + "ChangeTokenizerBehavior": "RestoreTokenizerKwarg", + "RemoveDeprecatedMethod": "RestoreMethod", + "ChangeReturnType": "RestoreReturnAccess", +} diff --git a/forgeenv/roles/drift_generator.py b/forgeenv/roles/drift_generator.py index e2038619e3d3bf9ee15b886986f0e96323497414..6449241036c8b055628b31823205e07a0101b6b7 100644 --- a/forgeenv/roles/drift_generator.py +++ b/forgeenv/roles/drift_generator.py @@ -1,170 +1,170 @@ -"""Drift Generator parser + a deterministic baseline policy. - -In training the LLM produces a JSON breakage spec; we parse it. In rollouts -where we want a baseline (or a fallback when the LLM emits malformed JSON) -we use `BaselineDriftGenerator`, which samples from the per-category set of -known good primitive parameterisations. -""" -from __future__ import annotations - -import json -import random -import re -from dataclasses import dataclass -from typing import Optional - -from forgeenv.primitives.breakage_primitives import ( - PRIMITIVE_REGISTRY, - parse_breakage_spec, - BreakagePrimitive, -) - - -_JSON_RE = re.compile(r"\{[\s\S]*\}") - - -def parse_drift_output(text: str) -> Optional[dict]: - """Extract a JSON object from possibly-noisy LLM output. - - Handles markdown fences, prose preamble, trailing commas (best-effort). - Returns None on failure. - """ - if not text: - return None - text = text.strip() - if text.startswith("```"): - text = re.sub(r"^```[a-zA-Z]*\n?", "", text) - text = re.sub(r"\n?```$", "", text) - match = _JSON_RE.search(text) - if not match: - return None - blob = match.group(0) - try: - return json.loads(blob) - except json.JSONDecodeError: - cleaned = re.sub(r",\s*([}\]])", r"\1", blob) - try: - return json.loads(cleaned) - except json.JSONDecodeError: - return None - - -def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]: - """End-to-end: LLM text -> validated BreakagePrimitive (or None).""" - data = parse_drift_output(text) - if not isinstance(data, dict): - return None - try: - return parse_breakage_spec(data) - except (ValueError, TypeError): - return None - - -# ---------------------------------------------------------------- baselines -_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = { - "RenameApiCall": [ - {"old_name": "trainer.train", "new_name": "trainer.start_training"}, - {"old_name": "save_pretrained", "new_name": "save_to_hub"}, - {"old_name": "from_pretrained", "new_name": "load_from_hub"}, - ], - "DeprecateImport": [ - { - "old_module": "from transformers import Trainer", - "new_module": "from transformers.legacy import Trainer", - }, - { - "old_module": "from transformers import TrainingArguments", - "new_module": "from transformers.training import TrainingArguments", - }, - ], - "ChangeArgumentSignature": [ - { - "function_name": "TrainingArguments", - "removed_arg": "num_train_epochs", - "added_arg": "max_steps", - "added_value": "1000", - }, - { - "function_name": "TrainingArguments", - "removed_arg": "evaluation_strategy", - "added_arg": "eval_strategy", - "added_value": '"steps"', - }, - ], - "ModifyConfigField": [ - {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"}, - {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"}, - ], - "RestructureDatasetSchema": [ - {"old_column": "text", "new_column": "input_text"}, - {"old_column": "label", "new_column": "labels"}, - {"old_column": "tokens", "new_column": "words"}, - ], - "ChangeTokenizerBehavior": [ - {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"}, - {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"}, - ], - "RemoveDeprecatedMethod": [ - {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"}, - {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"}, - ], - "ChangeReturnType": [ - {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"}, - {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"}, - ], -} - - -@dataclass -class BaselineDriftGenerator: - """Deterministic stand-in for the LLM Drift Generator. - - Used for warm-start data, baseline rollouts, and unit tests. - """ - - seed: Optional[int] = None - - def __post_init__(self) -> None: - self._rng = random.Random(self.seed) if self.seed is not None else random - - def propose( - self, target_category: str = "", script: str = "" - ) -> dict: - """Produce a JSON-serializable breakage spec for `target_category`. - - Order of preference: - 1. A primitive of `target_category` whose default params apply to `script`. - 2. A primitive of any type whose default params apply to `script`. - 3. A primitive of `target_category` (no-op fallback). - """ - - preferred_types = ( - [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else [] - ) - all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys()) - - for type_set in (preferred_types, all_types): - shuffled = self._rng.sample(type_set, len(type_set)) if type_set else [] - for ptype in shuffled: - for params in self._rng.sample( - _DEFAULT_PARAMS_BY_TYPE[ptype], - len(_DEFAULT_PARAMS_BY_TYPE[ptype]), - ): - if self._params_apply_to_script(ptype, params, script): - return {"primitive_type": ptype, "params": dict(params)} - - ptype = preferred_types[0] if preferred_types else all_types[0] - return { - "primitive_type": ptype, - "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]), - } - - @staticmethod - def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool: - """Heuristic: would this primitive actually mutate `script`?""" - if not script: - return True - for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"): - if key in params and params[key] and params[key] in script: - return True - return False +"""Drift Generator parser + a deterministic baseline policy. + +In training the LLM produces a JSON breakage spec; we parse it. In rollouts +where we want a baseline (or a fallback when the LLM emits malformed JSON) +we use `BaselineDriftGenerator`, which samples from the per-category set of +known good primitive parameterisations. +""" +from __future__ import annotations + +import json +import random +import re +from dataclasses import dataclass +from typing import Optional + +from forgeenv.primitives.breakage_primitives import ( + PRIMITIVE_REGISTRY, + parse_breakage_spec, + BreakagePrimitive, +) + + +_JSON_RE = re.compile(r"\{[\s\S]*\}") + + +def parse_drift_output(text: str) -> Optional[dict]: + """Extract a JSON object from possibly-noisy LLM output. + + Handles markdown fences, prose preamble, trailing commas (best-effort). + Returns None on failure. + """ + if not text: + return None + text = text.strip() + if text.startswith("```"): + text = re.sub(r"^```[a-zA-Z]*\n?", "", text) + text = re.sub(r"\n?```$", "", text) + match = _JSON_RE.search(text) + if not match: + return None + blob = match.group(0) + try: + return json.loads(blob) + except json.JSONDecodeError: + cleaned = re.sub(r",\s*([}\]])", r"\1", blob) + try: + return json.loads(cleaned) + except json.JSONDecodeError: + return None + + +def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]: + """End-to-end: LLM text -> validated BreakagePrimitive (or None).""" + data = parse_drift_output(text) + if not isinstance(data, dict): + return None + try: + return parse_breakage_spec(data) + except (ValueError, TypeError): + return None + + +# ---------------------------------------------------------------- baselines +_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = { + "RenameApiCall": [ + {"old_name": "trainer.train", "new_name": "trainer.start_training"}, + {"old_name": "save_pretrained", "new_name": "save_to_hub"}, + {"old_name": "from_pretrained", "new_name": "load_from_hub"}, + ], + "DeprecateImport": [ + { + "old_module": "from transformers import Trainer", + "new_module": "from transformers.legacy import Trainer", + }, + { + "old_module": "from transformers import TrainingArguments", + "new_module": "from transformers.training import TrainingArguments", + }, + ], + "ChangeArgumentSignature": [ + { + "function_name": "TrainingArguments", + "removed_arg": "num_train_epochs", + "added_arg": "max_steps", + "added_value": "1000", + }, + { + "function_name": "TrainingArguments", + "removed_arg": "evaluation_strategy", + "added_arg": "eval_strategy", + "added_value": '"steps"', + }, + ], + "ModifyConfigField": [ + {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"}, + {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"}, + ], + "RestructureDatasetSchema": [ + {"old_column": "text", "new_column": "input_text"}, + {"old_column": "label", "new_column": "labels"}, + {"old_column": "tokens", "new_column": "words"}, + ], + "ChangeTokenizerBehavior": [ + {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"}, + {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"}, + ], + "RemoveDeprecatedMethod": [ + {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"}, + {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"}, + ], + "ChangeReturnType": [ + {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"}, + {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"}, + ], +} + + +@dataclass +class BaselineDriftGenerator: + """Deterministic stand-in for the LLM Drift Generator. + + Used for warm-start data, baseline rollouts, and unit tests. + """ + + seed: Optional[int] = None + + def __post_init__(self) -> None: + self._rng = random.Random(self.seed) if self.seed is not None else random + + def propose( + self, target_category: str = "", script: str = "" + ) -> dict: + """Produce a JSON-serializable breakage spec for `target_category`. + + Order of preference: + 1. A primitive of `target_category` whose default params apply to `script`. + 2. A primitive of any type whose default params apply to `script`. + 3. A primitive of `target_category` (no-op fallback). + """ + + preferred_types = ( + [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else [] + ) + all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys()) + + for type_set in (preferred_types, all_types): + shuffled = self._rng.sample(type_set, len(type_set)) if type_set else [] + for ptype in shuffled: + for params in self._rng.sample( + _DEFAULT_PARAMS_BY_TYPE[ptype], + len(_DEFAULT_PARAMS_BY_TYPE[ptype]), + ): + if self._params_apply_to_script(ptype, params, script): + return {"primitive_type": ptype, "params": dict(params)} + + ptype = preferred_types[0] if preferred_types else all_types[0] + return { + "primitive_type": ptype, + "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]), + } + + @staticmethod + def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool: + """Heuristic: would this primitive actually mutate `script`?""" + if not script: + return True + for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"): + if key in params and params[key] and params[key] in script: + return True + return False diff --git a/forgeenv/roles/prompts.py b/forgeenv/roles/prompts.py index 6c56d5bdef22ce23d6fd02a5cc3577901eda9f9d..ee03b83526f939fbfcc7f7288fb543a2c7c62d6f 100644 --- a/forgeenv/roles/prompts.py +++ b/forgeenv/roles/prompts.py @@ -1,102 +1,102 @@ -"""System and user prompts for the two RL roles. - -Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with -LoRA adapters per role, so role prompts are the only thing distinguishing -them at inference time. Keep them concise β€” every token is a token of GPU -budget during GRPO rollouts. -""" -from __future__ import annotations - -from typing import Iterable - - -PRIMITIVE_DESCRIPTIONS = { - "RenameApiCall": "Rename a function/method call (api_drift)", - "DeprecateImport": "Change an import path (import_drift)", - "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)", - "ModifyConfigField": "Change a config-class default (config_drift)", - "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)", - "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)", - "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)", - "ChangeReturnType": "Function returns a different structure (api_drift)", -} - -DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator. -You see a working HuggingFace training script and the curriculum target category. -Output exactly one JSON object describing a breakage primitive that simulates -realistic library version drift. The primitive must: -1. Be PLAUSIBLE β€” match the kind of breakage that happens between real - transformers/datasets/trl releases. -2. Be SOLVABLE β€” the Repair Agent should be able to fix it from the error trace alone. -3. Match the requested target_category. - -Output schema: -{"primitive_type": "", "params": { ... }} - -Available primitive types and parameter schemas: -- RenameApiCall: {"old_name": str, "new_name": str} -- DeprecateImport: {"old_module": str, "new_module": str} -- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str} -- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str} -- RestructureDatasetSchema: {"old_column": str, "new_column": str} -- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str} -- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str} -- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str} - -Output ONLY the JSON object β€” no commentary, no markdown fences. -""" - - -REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent. -You see a broken HuggingFace training script, an error trace, and the current -library version snapshot. Output ONLY a unified diff that fixes the script. - -Rules: -1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py` - headers and `@@ ... @@` hunk markers. -2. Make the MINIMAL change that resolves the error AND preserves the original - training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit - calls. -3. Do NOT add any prose, markdown fences, or thinking output β€” diff only. -4. If the error is unfixable, output an empty diff. -""" - - -def render_drift_generator_prompt( - script: str, target_category: str, library_versions: dict -) -> str: - versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items()) - return f"""Target category: {target_category} -Library versions: {versions_str} - -Working script: -```python -{script} -``` - -Output JSON breakage primitive:""" - - -def render_repair_agent_prompt( - broken_script: str, - error_trace: str, - library_versions: dict, - target_category: str = "", -) -> str: - versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items()) - return f"""Library versions: {versions_str} -Target category hint: {target_category or 'unknown'} - -Broken script: -```python -{broken_script} -``` - -Error trace: -{error_trace} - -Output unified diff (no prose, no fences):""" - - -def list_primitive_descriptions() -> Iterable[str]: - return (f"- {k}: {v}" for k, v in PRIMITIVE_DESCRIPTIONS.items()) +"""System and user prompts for the two RL roles. + +Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with +LoRA adapters per role, so role prompts are the only thing distinguishing +them at inference time. Keep them concise β€” every token is a token of GPU +budget during GRPO rollouts. +""" +from __future__ import annotations + +from typing import Iterable + + +PRIMITIVE_DESCRIPTIONS = { + "RenameApiCall": "Rename a function/method call (api_drift)", + "DeprecateImport": "Change an import path (import_drift)", + "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)", + "ModifyConfigField": "Change a config-class default (config_drift)", + "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)", + "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)", + "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)", + "ChangeReturnType": "Function returns a different structure (api_drift)", +} + +DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator. +You see a working HuggingFace training script and the curriculum target category. +Output exactly one JSON object describing a breakage primitive that simulates +realistic library version drift. The primitive must: +1. Be PLAUSIBLE β€” match the kind of breakage that happens between real + transformers/datasets/trl releases. +2. Be SOLVABLE β€” the Repair Agent should be able to fix it from the error trace alone. +3. Match the requested target_category. + +Output schema: +{"primitive_type": "", "params": { ... }} + +Available primitive types and parameter schemas: +- RenameApiCall: {"old_name": str, "new_name": str} +- DeprecateImport: {"old_module": str, "new_module": str} +- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str} +- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str} +- RestructureDatasetSchema: {"old_column": str, "new_column": str} +- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str} +- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str} +- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str} + +Output ONLY the JSON object β€” no commentary, no markdown fences. +""" + + +REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent. +You see a broken HuggingFace training script, an error trace, and the current +library version snapshot. Output ONLY a unified diff that fixes the script. + +Rules: +1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py` + headers and `@@ ... @@` hunk markers. +2. Make the MINIMAL change that resolves the error AND preserves the original + training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit + calls. +3. Do NOT add any prose, markdown fences, or thinking output β€” diff only. +4. If the error is unfixable, output an empty diff. +""" + + +def render_drift_generator_prompt( + script: str, target_category: str, library_versions: dict +) -> str: + versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items()) + return f"""Target category: {target_category} +Library versions: {versions_str} + +Working script: +```python +{script} +``` + +Output JSON breakage primitive:""" + + +def render_repair_agent_prompt( + broken_script: str, + error_trace: str, + library_versions: dict, + target_category: str = "", +) -> str: + versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items()) + return f"""Library versions: {versions_str} +Target category hint: {target_category or 'unknown'} + +Broken script: +```python +{broken_script} +``` + +Error trace: +{error_trace} + +Output unified diff (no prose, no fences):""" + + +def list_primitive_descriptions() -> Iterable[str]: + return (f"- {k}: {v}" for k, v in PRIMITIVE_DESCRIPTIONS.items()) diff --git a/forgeenv/roles/repair_agent.py b/forgeenv/roles/repair_agent.py index 4ebb08367661227ce31328d458a83934386e1631..c2b2a8ffa5e31ec6d30b0ce2235d8865aaedfddd 100644 --- a/forgeenv/roles/repair_agent.py +++ b/forgeenv/roles/repair_agent.py @@ -1,153 +1,153 @@ -"""Repair Agent helpers: response sanitisation + a deterministic baseline. - -The Repair Agent's training output is a unified diff. LLMs frequently emit -prose / fences / chain-of-thought before the diff; this module strips that -preamble. The baseline policy uses the inverse-primitive map from -`repair_primitives.py` to produce ground-truth diffs for warm-start. -""" -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import Optional - -from forgeenv.env.diff_utils import make_unified_diff -from forgeenv.primitives.breakage_primitives import ( - parse_breakage_spec, - BreakagePrimitive, -) -from forgeenv.primitives.repair_primitives import ( - BREAKAGE_TO_REPAIR, - REPAIR_REGISTRY, - RepairPrimitive, -) - - -_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) -_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```") - - -def extract_diff(raw_text: str) -> str: - """Pull the unified diff out of an LLM response. - - Handles: code fences, leading prose / chain-of-thought, trailing notes. - """ - if not raw_text: - return "" - raw_text = raw_text.strip() - - fence_match = _FENCE_RE.search(raw_text) - if fence_match: - raw_text = fence_match.group(1).strip() - - lines = raw_text.splitlines() - start = 0 - for i, line in enumerate(lines): - if line.startswith(("---", "+++", "@@")): - start = i - break - - return "\n".join(lines[start:]) - - -def looks_like_diff(text: str) -> bool: - if not text: - return False - has_header = "---" in text and "+++" in text - has_hunk = bool(_DIFF_HUNK_RE.search(text)) - has_pm = any(line.startswith(("+", "-")) for line in text.splitlines()) - return (has_header and has_hunk) or (has_hunk and has_pm) - - -# ---------------------------------------------------------------- baselines -@dataclass -class BaselineRepairAgent: - """Deterministic Repair Agent that uses the primitive inverse map. - - Used for warm-start dataset generation and baseline rollout comparisons. - """ - - def repair( - self, - broken_script: str, - breakage_spec: Optional[dict] = None, - original_script: str = "", - ) -> str: - """Return a unified diff (or full replacement script) that fixes the - broken script. - - Strategy preference: - 1. If `original_script` is provided, return a diff between the - broken script and the original (oracle). This is the warm-start - path β€” we always know the ground truth. - 2. Otherwise try to invert the structured breakage_spec via the - repair-primitive registry. - 3. Otherwise return an empty diff. - """ - if original_script and original_script != broken_script: - return make_unified_diff(broken_script, original_script) - - if breakage_spec: - try: - breakage = parse_breakage_spec(breakage_spec) - except (ValueError, TypeError): - breakage = None - if breakage is not None: - repair = _invert_breakage(breakage) - if repair is not None: - repaired = repair.apply(broken_script) - if repaired != broken_script: - return make_unified_diff(broken_script, repaired) - - return "" - - -_PARAM_REMAP: dict[str, dict[str, str]] = { - "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"}, - "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"}, - "ChangeArgumentSignature": { - "function_name": "function_name", - "removed_arg": "arg_name", - }, - "ModifyConfigField": {"field_name": "field_name"}, - "RestructureDatasetSchema": { - "old_column": "old_column", - "new_column": "new_column", - }, - "ChangeTokenizerBehavior": { - "old_kwarg": "old_kwarg", - "old_value": "old_value", - "new_kwarg": "new_kwarg", - "new_value": "new_value", - }, - "RemoveDeprecatedMethod": {"method_name": "method_name"}, - "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"}, -} - - -def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]: - breakage_name = type(breakage).__name__ - repair_name = BREAKAGE_TO_REPAIR.get(breakage_name) - if repair_name is None: - return None - repair_cls = REPAIR_REGISTRY.get(repair_name) - if repair_cls is None: - return None - - breakage_params = breakage._get_params() # type: ignore[attr-defined] - remap = _PARAM_REMAP.get(breakage_name, {}) - mapped: dict[str, str] = {} - for src_key, dst_key in remap.items(): - if src_key in breakage_params: - mapped[dst_key] = breakage_params[src_key] - - valid_fields = { - f.name - for f in repair_cls.__dataclass_fields__.values() # type: ignore[attr-defined] - if f.init - } - filtered = {k: v for k, v in mapped.items() if k in valid_fields} - try: - return repair_cls(**filtered) - except TypeError: - return None +"""Repair Agent helpers: response sanitisation + a deterministic baseline. + +The Repair Agent's training output is a unified diff. LLMs frequently emit +prose / fences / chain-of-thought before the diff; this module strips that +preamble. The baseline policy uses the inverse-primitive map from +`repair_primitives.py` to produce ground-truth diffs for warm-start. +""" +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Optional + +from forgeenv.env.diff_utils import make_unified_diff +from forgeenv.primitives.breakage_primitives import ( + parse_breakage_spec, + BreakagePrimitive, +) +from forgeenv.primitives.repair_primitives import ( + BREAKAGE_TO_REPAIR, + REPAIR_REGISTRY, + RepairPrimitive, +) + + +_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) +_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```") + + +def extract_diff(raw_text: str) -> str: + """Pull the unified diff out of an LLM response. + + Handles: code fences, leading prose / chain-of-thought, trailing notes. + """ + if not raw_text: + return "" + raw_text = raw_text.strip() + + fence_match = _FENCE_RE.search(raw_text) + if fence_match: + raw_text = fence_match.group(1).strip() + + lines = raw_text.splitlines() + start = 0 + for i, line in enumerate(lines): + if line.startswith(("---", "+++", "@@")): + start = i + break + + return "\n".join(lines[start:]) + + +def looks_like_diff(text: str) -> bool: + if not text: + return False + has_header = "---" in text and "+++" in text + has_hunk = bool(_DIFF_HUNK_RE.search(text)) + has_pm = any(line.startswith(("+", "-")) for line in text.splitlines()) + return (has_header and has_hunk) or (has_hunk and has_pm) + + +# ---------------------------------------------------------------- baselines +@dataclass +class BaselineRepairAgent: + """Deterministic Repair Agent that uses the primitive inverse map. + + Used for warm-start dataset generation and baseline rollout comparisons. + """ + + def repair( + self, + broken_script: str, + breakage_spec: Optional[dict] = None, + original_script: str = "", + ) -> str: + """Return a unified diff (or full replacement script) that fixes the + broken script. + + Strategy preference: + 1. If `original_script` is provided, return a diff between the + broken script and the original (oracle). This is the warm-start + path β€” we always know the ground truth. + 2. Otherwise try to invert the structured breakage_spec via the + repair-primitive registry. + 3. Otherwise return an empty diff. + """ + if original_script and original_script != broken_script: + return make_unified_diff(broken_script, original_script) + + if breakage_spec: + try: + breakage = parse_breakage_spec(breakage_spec) + except (ValueError, TypeError): + breakage = None + if breakage is not None: + repair = _invert_breakage(breakage) + if repair is not None: + repaired = repair.apply(broken_script) + if repaired != broken_script: + return make_unified_diff(broken_script, repaired) + + return "" + + +_PARAM_REMAP: dict[str, dict[str, str]] = { + "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"}, + "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"}, + "ChangeArgumentSignature": { + "function_name": "function_name", + "removed_arg": "arg_name", + }, + "ModifyConfigField": {"field_name": "field_name"}, + "RestructureDatasetSchema": { + "old_column": "old_column", + "new_column": "new_column", + }, + "ChangeTokenizerBehavior": { + "old_kwarg": "old_kwarg", + "old_value": "old_value", + "new_kwarg": "new_kwarg", + "new_value": "new_value", + }, + "RemoveDeprecatedMethod": {"method_name": "method_name"}, + "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"}, +} + + +def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]: + breakage_name = type(breakage).__name__ + repair_name = BREAKAGE_TO_REPAIR.get(breakage_name) + if repair_name is None: + return None + repair_cls = REPAIR_REGISTRY.get(repair_name) + if repair_cls is None: + return None + + breakage_params = breakage._get_params() # type: ignore[attr-defined] + remap = _PARAM_REMAP.get(breakage_name, {}) + mapped: dict[str, str] = {} + for src_key, dst_key in remap.items(): + if src_key in breakage_params: + mapped[dst_key] = breakage_params[src_key] + + valid_fields = { + f.name + for f in repair_cls.__dataclass_fields__.values() # type: ignore[attr-defined] + if f.init + } + filtered = {k: v for k, v in mapped.items() if k in valid_fields} + try: + return repair_cls(**filtered) + except TypeError: + return None diff --git a/forgeenv/roles/teacher.py b/forgeenv/roles/teacher.py index a0356b61b282763d2d27b89a6bfa0e9d6b54dd86..67dda6f6bd8d7bcb1f024b3d5099ea7f1f7acd76 100644 --- a/forgeenv/roles/teacher.py +++ b/forgeenv/roles/teacher.py @@ -1,58 +1,58 @@ -"""Teacher (curriculum controller). - -Deterministic β€” NOT an LLM. Maintains an EMA success rate per breakage -category and routes the next episode toward the category where the -Repair Agent is closest to a 50% success rate (R-Zero's difficulty band). -""" -from __future__ import annotations - -import random -from dataclasses import dataclass, field - - -@dataclass -class Teacher: - categories: list[str] - alpha: float = 0.9 - success_counts: dict[str, int] = field(default_factory=dict) - attempt_counts: dict[str, int] = field(default_factory=dict) - ema_success: dict[str, float] = field(default_factory=dict) - - def __post_init__(self) -> None: - for category in self.categories: - self.success_counts.setdefault(category, 0) - self.attempt_counts.setdefault(category, 0) - self.ema_success.setdefault(category, 0.5) - - def update(self, category: str, success: bool) -> None: - if category not in self.ema_success: - self.categories.append(category) - self.ema_success[category] = 0.5 - self.success_counts[category] = 0 - self.attempt_counts[category] = 0 - - self.attempt_counts[category] += 1 - self.success_counts[category] += int(success) - rate = self.success_counts[category] / max(1, self.attempt_counts[category]) - self.ema_success[category] = ( - self.alpha * self.ema_success[category] + (1 - self.alpha) * rate - ) - - def select_next_category(self) -> str: - in_zone = { - c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7 - } - if in_zone: - weights = [1.0 / (v + 0.01) for v in in_zone.values()] - return random.choices(list(in_zone.keys()), weights=weights, k=1)[0] - return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5)) - - def get_state(self) -> dict: - return { - c: { - "ema_success": round(self.ema_success[c], 4), - "attempts": self.attempt_counts[c], - "successes": self.success_counts[c], - } - for c in self.categories - } +"""Teacher (curriculum controller). + +Deterministic β€” NOT an LLM. Maintains an EMA success rate per breakage +category and routes the next episode toward the category where the +Repair Agent is closest to a 50% success rate (R-Zero's difficulty band). +""" +from __future__ import annotations + +import random +from dataclasses import dataclass, field + + +@dataclass +class Teacher: + categories: list[str] + alpha: float = 0.9 + success_counts: dict[str, int] = field(default_factory=dict) + attempt_counts: dict[str, int] = field(default_factory=dict) + ema_success: dict[str, float] = field(default_factory=dict) + + def __post_init__(self) -> None: + for category in self.categories: + self.success_counts.setdefault(category, 0) + self.attempt_counts.setdefault(category, 0) + self.ema_success.setdefault(category, 0.5) + + def update(self, category: str, success: bool) -> None: + if category not in self.ema_success: + self.categories.append(category) + self.ema_success[category] = 0.5 + self.success_counts[category] = 0 + self.attempt_counts[category] = 0 + + self.attempt_counts[category] += 1 + self.success_counts[category] += int(success) + rate = self.success_counts[category] / max(1, self.attempt_counts[category]) + self.ema_success[category] = ( + self.alpha * self.ema_success[category] + (1 - self.alpha) * rate + ) + + def select_next_category(self) -> str: + in_zone = { + c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7 + } + if in_zone: + weights = [1.0 / (v + 0.01) for v in in_zone.values()] + return random.choices(list(in_zone.keys()), weights=weights, k=1)[0] + return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5)) + + def get_state(self) -> dict: + return { + c: { + "ema_success": round(self.ema_success[c], 4), + "attempts": self.attempt_counts[c], + "successes": self.success_counts[c], + } + for c in self.categories + } diff --git a/forgeenv/sandbox/ast_validator.py b/forgeenv/sandbox/ast_validator.py index 12155fd4370c0db537db273ba9c7a9c613896f93..80f466a4470e755d2837feb82538848edf499f4d 100644 --- a/forgeenv/sandbox/ast_validator.py +++ b/forgeenv/sandbox/ast_validator.py @@ -1,70 +1,70 @@ -"""AST-based script validator. - -Catches forbidden imports and dangerous patterns BEFORE any execution -happens. This is a critical defense against reward hacking via system -calls, network access, or process manipulation. -""" -from __future__ import annotations - -import ast - -from forgeenv.tasks.models import ValidationResult - -FORBIDDEN_MODULES = { - "os", - "subprocess", - "socket", - "urllib", - "requests", - "ctypes", - "shutil", - "signal", - "multiprocessing", - "threading", -} - -FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"} - - -def validate_script(script_content: str) -> ValidationResult: - """Parse a script as AST and reject forbidden patterns. - - Returns a ValidationResult with `is_valid` and a list of `violations`. - """ - violations: list[str] = [] - - try: - tree = ast.parse(script_content) - except SyntaxError as e: - return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"]) - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - module_root = alias.name.split(".")[0] - if module_root in FORBIDDEN_MODULES: - violations.append(f"Forbidden import: {alias.name}") - - if isinstance(node, ast.ImportFrom): - if node.module: - module_root = node.module.split(".")[0] - if module_root in FORBIDDEN_MODULES: - violations.append(f"Forbidden import from: {node.module}") - - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Name): - if node.func.id in FORBIDDEN_FUNCTIONS: - violations.append(f"Forbidden call: {node.func.id}()") - if isinstance(node.func, ast.Attribute): - if node.func.attr in FORBIDDEN_FUNCTIONS: - violations.append(f"Forbidden call: .{node.func.attr}()") - - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name) and target.id == "__builtins__": - violations.append("Forbidden: __builtins__ assignment") - - return ValidationResult( - is_valid=len(violations) == 0, - violations=violations, - ) +"""AST-based script validator. + +Catches forbidden imports and dangerous patterns BEFORE any execution +happens. This is a critical defense against reward hacking via system +calls, network access, or process manipulation. +""" +from __future__ import annotations + +import ast + +from forgeenv.tasks.models import ValidationResult + +FORBIDDEN_MODULES = { + "os", + "subprocess", + "socket", + "urllib", + "requests", + "ctypes", + "shutil", + "signal", + "multiprocessing", + "threading", +} + +FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"} + + +def validate_script(script_content: str) -> ValidationResult: + """Parse a script as AST and reject forbidden patterns. + + Returns a ValidationResult with `is_valid` and a list of `violations`. + """ + violations: list[str] = [] + + try: + tree = ast.parse(script_content) + except SyntaxError as e: + return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"]) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_root = alias.name.split(".")[0] + if module_root in FORBIDDEN_MODULES: + violations.append(f"Forbidden import: {alias.name}") + + if isinstance(node, ast.ImportFrom): + if node.module: + module_root = node.module.split(".")[0] + if module_root in FORBIDDEN_MODULES: + violations.append(f"Forbidden import from: {node.module}") + + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in FORBIDDEN_FUNCTIONS: + violations.append(f"Forbidden call: {node.func.id}()") + if isinstance(node.func, ast.Attribute): + if node.func.attr in FORBIDDEN_FUNCTIONS: + violations.append(f"Forbidden call: .{node.func.attr}()") + + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "__builtins__": + violations.append("Forbidden: __builtins__ assignment") + + return ValidationResult( + is_valid=len(violations) == 0, + violations=violations, + ) diff --git a/forgeenv/sandbox/simulation_mode.py b/forgeenv/sandbox/simulation_mode.py index 22ab4e103c38989aef83f3d6e92e41806474264e..773d916320261c2395766e043c527d943a81459b 100644 --- a/forgeenv/sandbox/simulation_mode.py +++ b/forgeenv/sandbox/simulation_mode.py @@ -1,142 +1,142 @@ -"""Fast simulation executor for development. - -Static-analysis-based execution simulator. Sub-100ms per call. No Docker -required. The success probability of a simulated run depends on whether -the script contains expected HF training markers (model imports, training -calls, save calls). When the simulation succeeds, a synthetic decreasing -loss curve is emitted; when it fails, a representative HF error is raised. -""" -from __future__ import annotations - -import random -import time -from typing import Optional - -from forgeenv.sandbox.ast_validator import validate_script -from forgeenv.tasks.models import ExecutionResult, Task - - -class SimulationExecutor: - """Simulates script execution via static analysis. - - Use this throughout development phases. Real Docker execution is added - later for grounded final-stage verification. - """ - - def __init__(self, seed: Optional[int] = None) -> None: - self._rng = random.Random(seed) if seed is not None else random - - def execute( - self, script_content: str, task: Optional[Task] = None - ) -> ExecutionResult: - start = time.time() - - validation = validate_script(script_content) - if not validation.is_valid: - return ExecutionResult( - exit_code=1, - stdout="", - stderr=f"Validation failed: {'; '.join(validation.violations)}", - wall_time_ms=int((time.time() - start) * 1000), - script_content=script_content, - ) - - try: - compile(script_content, "", "exec") - except SyntaxError as e: - return ExecutionResult( - exit_code=1, - stdout="", - stderr=f"SyntaxError: {e}", - wall_time_ms=int((time.time() - start) * 1000), - script_content=script_content, - ) - - has_model_import = any( - kw in script_content - for kw in ("from transformers", "import torch", "from datasets") - ) - has_training_call = any( - kw in script_content - for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch") - ) - has_save = any( - kw in script_content - for kw in ("save_pretrained", "save_model", "torch.save") - ) - - success_prob = 0.3 - if has_model_import: - success_prob += 0.3 - if has_training_call: - success_prob += 0.2 - if has_save: - success_prob += 0.1 - - # Mark obviously broken patterns as definite failures even when - # they pass syntactic compilation. The simulator pretends to be a - # static linter that catches AttributeError / ImportError signatures - # before they would fire at runtime. - broken_markers = ( - "_DEPRECATED(", - "transformers.legacy", - "from transformers.training import", - ".start_training(", - "load_from_hub(", - "save_to_hub(", - "pad_to_max_length=", - "evaluation_loop(", - ) - if any(marker in script_content for marker in broken_markers): - success_prob = 0.0 - # Patterns that look like dataset column drift: a renamed column - # that doesn't appear in real HF datasets. - import re as _re - - if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content): - success_prob = min(success_prob, 0.05) - if _re.search(r"['\"]words['\"]\s*[]:),]", script_content): - success_prob = min(success_prob, 0.05) - # Tokenizer kwarg drift (truncate is not valid; truncation is). - if _re.search(r"\btruncate\s*=", script_content): - success_prob = min(success_prob, 0.05) - - succeeded = self._rng.random() < success_prob - - if succeeded: - steps = self._rng.randint(20, 50) - log_lines: list[str] = [] - loss = self._rng.uniform(2.0, 4.0) - for step in range(1, steps + 1): - loss *= self._rng.uniform(0.92, 0.99) - log_lines.append(f"step={step} loss={loss:.4f}") - log_lines.append("eval_accuracy=0.78") - log_lines.append("TRAINING_COMPLETE") - - return ExecutionResult( - exit_code=0, - stdout="\n".join(log_lines), - stderr="", - wall_time_ms=int((time.time() - start) * 1000) - + self._rng.randint(1000, 5000), - checkpoint_exists=True, - peak_memory_mb=self._rng.uniform(500, 2000), - script_content=script_content, - ) - - error_types = [ - "ImportError: cannot import name 'OldTrainer' from 'transformers'", - "AttributeError: 'Trainer' object has no attribute 'evaluate_model'", - "KeyError: 'text' column not found in dataset", - "TypeError: __init__() got an unexpected keyword argument 'num_epochs'", - "RuntimeError: Expected input batch_size (16) to match target batch_size (32)", - "ModuleNotFoundError: No module named 'transformers.legacy'", - ] - return ExecutionResult( - exit_code=1, - stdout="", - stderr=self._rng.choice(error_types), - wall_time_ms=int((time.time() - start) * 1000) - + self._rng.randint(100, 500), - script_content=script_content, - ) +"""Fast simulation executor for development. + +Static-analysis-based execution simulator. Sub-100ms per call. No Docker +required. The success probability of a simulated run depends on whether +the script contains expected HF training markers (model imports, training +calls, save calls). When the simulation succeeds, a synthetic decreasing +loss curve is emitted; when it fails, a representative HF error is raised. +""" +from __future__ import annotations + +import random +import time +from typing import Optional + +from forgeenv.sandbox.ast_validator import validate_script +from forgeenv.tasks.models import ExecutionResult, Task + + +class SimulationExecutor: + """Simulates script execution via static analysis. + + Use this throughout development phases. Real Docker execution is added + later for grounded final-stage verification. + """ + + def __init__(self, seed: Optional[int] = None) -> None: + self._rng = random.Random(seed) if seed is not None else random + + def execute( + self, script_content: str, task: Optional[Task] = None + ) -> ExecutionResult: + start = time.time() + + validation = validate_script(script_content) + if not validation.is_valid: + return ExecutionResult( + exit_code=1, + stdout="", + stderr=f"Validation failed: {'; '.join(validation.violations)}", + wall_time_ms=int((time.time() - start) * 1000), + script_content=script_content, + ) + + try: + compile(script_content, "", "exec") + except SyntaxError as e: + return ExecutionResult( + exit_code=1, + stdout="", + stderr=f"SyntaxError: {e}", + wall_time_ms=int((time.time() - start) * 1000), + script_content=script_content, + ) + + has_model_import = any( + kw in script_content + for kw in ("from transformers", "import torch", "from datasets") + ) + has_training_call = any( + kw in script_content + for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch") + ) + has_save = any( + kw in script_content + for kw in ("save_pretrained", "save_model", "torch.save") + ) + + success_prob = 0.3 + if has_model_import: + success_prob += 0.3 + if has_training_call: + success_prob += 0.2 + if has_save: + success_prob += 0.1 + + # Mark obviously broken patterns as definite failures even when + # they pass syntactic compilation. The simulator pretends to be a + # static linter that catches AttributeError / ImportError signatures + # before they would fire at runtime. + broken_markers = ( + "_DEPRECATED(", + "transformers.legacy", + "from transformers.training import", + ".start_training(", + "load_from_hub(", + "save_to_hub(", + "pad_to_max_length=", + "evaluation_loop(", + ) + if any(marker in script_content for marker in broken_markers): + success_prob = 0.0 + # Patterns that look like dataset column drift: a renamed column + # that doesn't appear in real HF datasets. + import re as _re + + if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content): + success_prob = min(success_prob, 0.05) + if _re.search(r"['\"]words['\"]\s*[]:),]", script_content): + success_prob = min(success_prob, 0.05) + # Tokenizer kwarg drift (truncate is not valid; truncation is). + if _re.search(r"\btruncate\s*=", script_content): + success_prob = min(success_prob, 0.05) + + succeeded = self._rng.random() < success_prob + + if succeeded: + steps = self._rng.randint(20, 50) + log_lines: list[str] = [] + loss = self._rng.uniform(2.0, 4.0) + for step in range(1, steps + 1): + loss *= self._rng.uniform(0.92, 0.99) + log_lines.append(f"step={step} loss={loss:.4f}") + log_lines.append("eval_accuracy=0.78") + log_lines.append("TRAINING_COMPLETE") + + return ExecutionResult( + exit_code=0, + stdout="\n".join(log_lines), + stderr="", + wall_time_ms=int((time.time() - start) * 1000) + + self._rng.randint(1000, 5000), + checkpoint_exists=True, + peak_memory_mb=self._rng.uniform(500, 2000), + script_content=script_content, + ) + + error_types = [ + "ImportError: cannot import name 'OldTrainer' from 'transformers'", + "AttributeError: 'Trainer' object has no attribute 'evaluate_model'", + "KeyError: 'text' column not found in dataset", + "TypeError: __init__() got an unexpected keyword argument 'num_epochs'", + "RuntimeError: Expected input batch_size (16) to match target batch_size (32)", + "ModuleNotFoundError: No module named 'transformers.legacy'", + ] + return ExecutionResult( + exit_code=1, + stdout="", + stderr=self._rng.choice(error_types), + wall_time_ms=int((time.time() - start) * 1000) + + self._rng.randint(100, 500), + script_content=script_content, + ) diff --git a/forgeenv/tasks/models.py b/forgeenv/tasks/models.py index 263608026b7c206f6eb283aba0213626105a61a9..7ab67d39a1fcb39c10c41ad97555d642836a3eb6 100644 --- a/forgeenv/tasks/models.py +++ b/forgeenv/tasks/models.py @@ -1,45 +1,45 @@ -"""Core data models for ForgeEnv tasks and execution results. - -These are framework-internal dataclasses (not Pydantic) used throughout the -simulation, verifier, and primitive layers. The OpenEnv-facing Pydantic -models live in `forgeenv.env.actions` / `forgeenv.env.observations`. -""" -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Optional - - -@dataclass -class Task: - """A HuggingFace training script with execution metadata.""" - - task_id: str - description: str - script_content: str - difficulty: str # "easy", "medium", "hard" - category: str = "general" - expected_loss_range: tuple[float, float] = (0.0, 5.0) - expected_accuracy_range: tuple[float, float] = (0.0, 1.0) - checkpoint_output_path: str = "/tmp/forge_output/checkpoint" - - -@dataclass -class ExecutionResult: - """Result of executing a Python script in the sandbox.""" - - exit_code: int - stdout: str - stderr: str - wall_time_ms: int - checkpoint_exists: bool = False - peak_memory_mb: float = 0.0 - script_content: str = "" - - -@dataclass -class ValidationResult: - """Result of AST validation on a script.""" - - is_valid: bool - violations: list[str] = field(default_factory=list) +"""Core data models for ForgeEnv tasks and execution results. + +These are framework-internal dataclasses (not Pydantic) used throughout the +simulation, verifier, and primitive layers. The OpenEnv-facing Pydantic +models live in `forgeenv.env.actions` / `forgeenv.env.observations`. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Task: + """A HuggingFace training script with execution metadata.""" + + task_id: str + description: str + script_content: str + difficulty: str # "easy", "medium", "hard" + category: str = "general" + expected_loss_range: tuple[float, float] = (0.0, 5.0) + expected_accuracy_range: tuple[float, float] = (0.0, 1.0) + checkpoint_output_path: str = "/tmp/forge_output/checkpoint" + + +@dataclass +class ExecutionResult: + """Result of executing a Python script in the sandbox.""" + + exit_code: int + stdout: str + stderr: str + wall_time_ms: int + checkpoint_exists: bool = False + peak_memory_mb: float = 0.0 + script_content: str = "" + + +@dataclass +class ValidationResult: + """Result of AST validation on a script.""" + + is_valid: bool + violations: list[str] = field(default_factory=list) diff --git a/forgeenv/tasks/seed_corpus/albert_qa.py b/forgeenv/tasks/seed_corpus/albert_qa.py index d88e6bbcaa6127e37b277192577a5ed283410430..611574fbb35b708bf82c9f7328f0725022de6247 100644 --- a/forgeenv/tasks/seed_corpus/albert_qa.py +++ b/forgeenv/tasks/seed_corpus/albert_qa.py @@ -1,67 +1,67 @@ -"""ALBERT-tiny extractive QA on 100-sample SQuAD subset.""" -from transformers import ( - AutoTokenizer, - AutoModelForQuestionAnswering, - Trainer, - TrainingArguments, - DefaultDataCollator, -) -from datasets import load_dataset - -dataset = load_dataset("squad", split="train[:100]") -tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") - - -def prepare(examples): - enc = tokenizer( - examples["question"], - examples["context"], - max_length=128, - truncation="only_second", - padding="max_length", - return_offsets_mapping=True, - ) - start_positions, end_positions = [], [] - for i, offsets in enumerate(enc["offset_mapping"]): - answer = examples["answers"][i] - start_char = answer["answer_start"][0] - end_char = start_char + len(answer["text"][0]) - - token_start = next( - (idx for idx, (a, b) in enumerate(offsets) if a <= start_char < b), 0 - ) - token_end = next( - (idx for idx, (a, b) in enumerate(offsets) if a < end_char <= b), token_start - ) - start_positions.append(token_start) - end_positions.append(token_end) - - enc["start_positions"] = start_positions - enc["end_positions"] = end_positions - enc.pop("offset_mapping") - return enc - - -dataset = dataset.map(prepare, batched=True, remove_columns=dataset.column_names) - -model = AutoModelForQuestionAnswering.from_pretrained("albert-base-v2") - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=4, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer( - model=model, - args=training_args, - train_dataset=dataset, - data_collator=DefaultDataCollator(), -) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""ALBERT-tiny extractive QA on 100-sample SQuAD subset.""" +from transformers import ( + AutoTokenizer, + AutoModelForQuestionAnswering, + Trainer, + TrainingArguments, + DefaultDataCollator, +) +from datasets import load_dataset + +dataset = load_dataset("squad", split="train[:100]") +tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") + + +def prepare(examples): + enc = tokenizer( + examples["question"], + examples["context"], + max_length=128, + truncation="only_second", + padding="max_length", + return_offsets_mapping=True, + ) + start_positions, end_positions = [], [] + for i, offsets in enumerate(enc["offset_mapping"]): + answer = examples["answers"][i] + start_char = answer["answer_start"][0] + end_char = start_char + len(answer["text"][0]) + + token_start = next( + (idx for idx, (a, b) in enumerate(offsets) if a <= start_char < b), 0 + ) + token_end = next( + (idx for idx, (a, b) in enumerate(offsets) if a < end_char <= b), token_start + ) + start_positions.append(token_start) + end_positions.append(token_end) + + enc["start_positions"] = start_positions + enc["end_positions"] = end_positions + enc.pop("offset_mapping") + return enc + + +dataset = dataset.map(prepare, batched=True, remove_columns=dataset.column_names) + +model = AutoModelForQuestionAnswering.from_pretrained("albert-base-v2") + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=4, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + data_collator=DefaultDataCollator(), +) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/bert_ner.py b/forgeenv/tasks/seed_corpus/bert_ner.py index d80fdc20ee48173d714cf9b31bee3b6b77875161..c1fa18a4be7f1c379efd335b4af82643abb24903 100644 --- a/forgeenv/tasks/seed_corpus/bert_ner.py +++ b/forgeenv/tasks/seed_corpus/bert_ner.py @@ -1,55 +1,55 @@ -"""Bert tiny NER fine-tuning on a 200-sample CoNLL-2003 subset.""" -from transformers import ( - AutoTokenizer, - AutoModelForTokenClassification, - Trainer, - TrainingArguments, - DataCollatorForTokenClassification, -) -from datasets import load_dataset - -dataset = load_dataset("conll2003", split="train[:200]") -tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny") - - -def tokenize_and_align(example): - enc = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, max_length=64) - word_ids = enc.word_ids() - labels = [] - prev_id = None - for wid in word_ids: - if wid is None: - labels.append(-100) - elif wid != prev_id: - labels.append(example["ner_tags"][wid]) - else: - labels.append(-100) - prev_id = wid - enc["labels"] = labels - return enc - - -dataset = dataset.map(tokenize_and_align, remove_columns=dataset.column_names) - -model = AutoModelForTokenClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=9) - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=8, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer( - model=model, - args=training_args, - train_dataset=dataset, - data_collator=DataCollatorForTokenClassification(tokenizer), -) - -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""Bert tiny NER fine-tuning on a 200-sample CoNLL-2003 subset.""" +from transformers import ( + AutoTokenizer, + AutoModelForTokenClassification, + Trainer, + TrainingArguments, + DataCollatorForTokenClassification, +) +from datasets import load_dataset + +dataset = load_dataset("conll2003", split="train[:200]") +tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny") + + +def tokenize_and_align(example): + enc = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, max_length=64) + word_ids = enc.word_ids() + labels = [] + prev_id = None + for wid in word_ids: + if wid is None: + labels.append(-100) + elif wid != prev_id: + labels.append(example["ner_tags"][wid]) + else: + labels.append(-100) + prev_id = wid + enc["labels"] = labels + return enc + + +dataset = dataset.map(tokenize_and_align, remove_columns=dataset.column_names) + +model = AutoModelForTokenClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=9) + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=8, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + data_collator=DataCollatorForTokenClassification(tokenizer), +) + +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/distilbert_sst2.py b/forgeenv/tasks/seed_corpus/distilbert_sst2.py index fb143b20743e507ed8959a2d9e2ec53be22e1e9e..e1765be46db70affa6d156bbd4b18fefe8c0a258 100644 --- a/forgeenv/tasks/seed_corpus/distilbert_sst2.py +++ b/forgeenv/tasks/seed_corpus/distilbert_sst2.py @@ -1,53 +1,53 @@ -"""DistilBERT fine-tuning on a tiny SST-2 subset. - -Minimal HuggingFace text-classification training script. Should complete -in ~60s on CPU. -""" -from transformers import ( - DistilBertTokenizer, - DistilBertForSequenceClassification, - Trainer, - TrainingArguments, -) -from datasets import load_dataset - -dataset = load_dataset("glue", "sst2", split="train[:500]") -tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") - - -def tokenize_function(examples): - return tokenizer( - examples["sentence"], - padding="max_length", - truncation=True, - max_length=64, - ) - - -dataset = dataset.map(tokenize_function, batched=True) -dataset = dataset.rename_column("label", "labels") -dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) - -model = DistilBertForSequenceClassification.from_pretrained( - "distilbert-base-uncased", num_labels=2 -) - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=16, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer( - model=model, - args=training_args, - train_dataset=dataset, -) - -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""DistilBERT fine-tuning on a tiny SST-2 subset. + +Minimal HuggingFace text-classification training script. Should complete +in ~60s on CPU. +""" +from transformers import ( + DistilBertTokenizer, + DistilBertForSequenceClassification, + Trainer, + TrainingArguments, +) +from datasets import load_dataset + +dataset = load_dataset("glue", "sst2", split="train[:500]") +tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + + +def tokenize_function(examples): + return tokenizer( + examples["sentence"], + padding="max_length", + truncation=True, + max_length=64, + ) + + +dataset = dataset.map(tokenize_function, batched=True) +dataset = dataset.rename_column("label", "labels") +dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) + +model = DistilBertForSequenceClassification.from_pretrained( + "distilbert-base-uncased", num_labels=2 +) + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=16, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, +) + +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/electra_classification.py b/forgeenv/tasks/seed_corpus/electra_classification.py index 7121bf900be0aaaa80d64ee46e30ca688607ef9d..c6e628ea8bc1b3b915c56f202324d2ef4eebbbb5 100644 --- a/forgeenv/tasks/seed_corpus/electra_classification.py +++ b/forgeenv/tasks/seed_corpus/electra_classification.py @@ -1,44 +1,44 @@ -"""ELECTRA-small classification on 400-sample AG News (4-way text classification).""" -from transformers import ( - AutoTokenizer, - AutoModelForSequenceClassification, - Trainer, - TrainingArguments, -) -from datasets import load_dataset - -dataset = load_dataset("ag_news", split="train[:400]") -tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") - - -def tokenize(examples): - return tokenizer( - examples["text"], - padding="max_length", - truncation=True, - max_length=64, - ) - - -dataset = dataset.map(tokenize, batched=True) -dataset = dataset.rename_column("label", "labels") -dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) - -model = AutoModelForSequenceClassification.from_pretrained( - "google/electra-small-discriminator", num_labels=4 -) - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=8, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer(model=model, args=training_args, train_dataset=dataset) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""ELECTRA-small classification on 400-sample AG News (4-way text classification).""" +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + Trainer, + TrainingArguments, +) +from datasets import load_dataset + +dataset = load_dataset("ag_news", split="train[:400]") +tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") + + +def tokenize(examples): + return tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=64, + ) + + +dataset = dataset.map(tokenize, batched=True) +dataset = dataset.rename_column("label", "labels") +dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) + +model = AutoModelForSequenceClassification.from_pretrained( + "google/electra-small-discriminator", num_labels=4 +) + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=8, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer(model=model, args=training_args, train_dataset=dataset) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/gpt2_textgen.py b/forgeenv/tasks/seed_corpus/gpt2_textgen.py index b675f05c6ec3795adddf5b0e7acce27798866aaa..aeb2fafd9cf6bb3c5c90909e056ab063928c568e 100644 --- a/forgeenv/tasks/seed_corpus/gpt2_textgen.py +++ b/forgeenv/tasks/seed_corpus/gpt2_textgen.py @@ -1,43 +1,43 @@ -"""DistilGPT2 causal-LM fine-tuning on 300 lines of WikiText (text generation).""" -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - Trainer, - TrainingArguments, - DataCollatorForLanguageModeling, -) -from datasets import load_dataset - -dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:300]") -tokenizer = AutoTokenizer.from_pretrained("distilgpt2") -tokenizer.pad_token = tokenizer.eos_token - - -def tokenize(examples): - return tokenizer(examples["text"], truncation=True, max_length=64) - - -dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) - -model = AutoModelForCausalLM.from_pretrained("distilgpt2") - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=4, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer( - model=model, - args=training_args, - train_dataset=dataset, - data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), -) - -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""DistilGPT2 causal-LM fine-tuning on 300 lines of WikiText (text generation).""" +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + Trainer, + TrainingArguments, + DataCollatorForLanguageModeling, +) +from datasets import load_dataset + +dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:300]") +tokenizer = AutoTokenizer.from_pretrained("distilgpt2") +tokenizer.pad_token = tokenizer.eos_token + + +def tokenize(examples): + return tokenizer(examples["text"], truncation=True, max_length=64) + + +dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + +model = AutoModelForCausalLM.from_pretrained("distilgpt2") + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=4, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), +) + +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/logistic_classifier.py b/forgeenv/tasks/seed_corpus/logistic_classifier.py index 644c718a652332643239a471820ac601df0e2898..a4b86a05de0b529f8d7f3039badb3bb96e544235 100644 --- a/forgeenv/tasks/seed_corpus/logistic_classifier.py +++ b/forgeenv/tasks/seed_corpus/logistic_classifier.py @@ -1,36 +1,36 @@ -"""Sklearn logistic-regression baseline on a 500-sample tabular task. - -Sanity baseline that doesn't require torch / transformers / datasets. -""" -import json -import pickle -from pathlib import Path - -import numpy as np -from sklearn.datasets import make_classification -from sklearn.linear_model import LogisticRegression -from sklearn.model_selection import train_test_split - -X, y = make_classification( - n_samples=500, n_features=20, n_informative=10, random_state=0 -) -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) - -model = LogisticRegression(max_iter=200) -for step in range(1, 11): - model.set_params(max_iter=step * 20) - model.fit(X_train, y_train) - train_loss = -np.mean(np.log(np.maximum(model.predict_proba(X_train)[np.arange(len(y_train)), y_train], 1e-9))) - print(f"step={step} loss={train_loss:.4f}") - -acc = model.score(X_test, y_test) -print(f"eval_accuracy={acc:.4f}") - -ckpt_dir = Path("/tmp/forge_output/checkpoint") -ckpt_dir.mkdir(parents=True, exist_ok=True) -with open(ckpt_dir / "logreg.pkl", "wb") as f: - pickle.dump(model, f) -with open(ckpt_dir / "metrics.json", "w") as f: - json.dump({"accuracy": acc}, f) - -print("TRAINING_COMPLETE") +"""Sklearn logistic-regression baseline on a 500-sample tabular task. + +Sanity baseline that doesn't require torch / transformers / datasets. +""" +import json +import pickle +from pathlib import Path + +import numpy as np +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +X, y = make_classification( + n_samples=500, n_features=20, n_informative=10, random_state=0 +) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) + +model = LogisticRegression(max_iter=200) +for step in range(1, 11): + model.set_params(max_iter=step * 20) + model.fit(X_train, y_train) + train_loss = -np.mean(np.log(np.maximum(model.predict_proba(X_train)[np.arange(len(y_train)), y_train], 1e-9))) + print(f"step={step} loss={train_loss:.4f}") + +acc = model.score(X_test, y_test) +print(f"eval_accuracy={acc:.4f}") + +ckpt_dir = Path("/tmp/forge_output/checkpoint") +ckpt_dir.mkdir(parents=True, exist_ok=True) +with open(ckpt_dir / "logreg.pkl", "wb") as f: + pickle.dump(model, f) +with open(ckpt_dir / "metrics.json", "w") as f: + json.dump({"accuracy": acc}, f) + +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/roberta_sentiment.py b/forgeenv/tasks/seed_corpus/roberta_sentiment.py index 8f56578d12fc82a548060480aa5e66e61e259b63..43cb3d2aa65cc6f3897fbba64a0698447a5ec2d1 100644 --- a/forgeenv/tasks/seed_corpus/roberta_sentiment.py +++ b/forgeenv/tasks/seed_corpus/roberta_sentiment.py @@ -1,44 +1,44 @@ -"""DistilRoberta sentiment classification on 400-sample IMDB subset.""" -from transformers import ( - AutoTokenizer, - AutoModelForSequenceClassification, - Trainer, - TrainingArguments, -) -from datasets import load_dataset - -dataset = load_dataset("imdb", split="train[:400]") -tokenizer = AutoTokenizer.from_pretrained("distilroberta-base") - - -def tokenize(examples): - return tokenizer( - examples["text"], - padding="max_length", - truncation=True, - max_length=64, - ) - - -dataset = dataset.map(tokenize, batched=True) -dataset = dataset.rename_column("label", "labels") -dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) - -model = AutoModelForSequenceClassification.from_pretrained( - "distilroberta-base", num_labels=2 -) - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=8, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer(model=model, args=training_args, train_dataset=dataset) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""DistilRoberta sentiment classification on 400-sample IMDB subset.""" +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + Trainer, + TrainingArguments, +) +from datasets import load_dataset + +dataset = load_dataset("imdb", split="train[:400]") +tokenizer = AutoTokenizer.from_pretrained("distilroberta-base") + + +def tokenize(examples): + return tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=64, + ) + + +dataset = dataset.map(tokenize, batched=True) +dataset = dataset.rename_column("label", "labels") +dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) + +model = AutoModelForSequenceClassification.from_pretrained( + "distilroberta-base", num_labels=2 +) + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=8, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer(model=model, args=training_args, train_dataset=dataset) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/simple_regression.py b/forgeenv/tasks/seed_corpus/simple_regression.py index 2d5cbc9d5966b2f1918c72552944f8e16e3441e6..7550a74956bd58ee25571a5e9e4f3df8f0bf443b 100644 --- a/forgeenv/tasks/seed_corpus/simple_regression.py +++ b/forgeenv/tasks/seed_corpus/simple_regression.py @@ -1,28 +1,28 @@ -"""Tiny PyTorch regression on synthetic data (no HF imports β€” sanity baseline).""" -import torch -import torch.nn as nn - -torch.manual_seed(0) -x = torch.randn(500, 4) -y = (x @ torch.tensor([1.5, -2.0, 0.5, 3.0])) + 0.1 * torch.randn(500) - -model = nn.Sequential( - nn.Linear(4, 16), - nn.ReLU(), - nn.Linear(16, 1), -) - -optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) -criterion = nn.MSELoss() - -for epoch in range(50): - optimizer.zero_grad() - preds = model(x).squeeze(-1) - loss = criterion(preds, y) - loss.backward() - optimizer.step() - if (epoch + 1) % 5 == 0: - print(f"step={epoch + 1} loss={loss.item():.4f}") - -torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/regression.pt") -print("TRAINING_COMPLETE") +"""Tiny PyTorch regression on synthetic data (no HF imports β€” sanity baseline).""" +import torch +import torch.nn as nn + +torch.manual_seed(0) +x = torch.randn(500, 4) +y = (x @ torch.tensor([1.5, -2.0, 0.5, 3.0])) + 0.1 * torch.randn(500) + +model = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 1), +) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) +criterion = nn.MSELoss() + +for epoch in range(50): + optimizer.zero_grad() + preds = model(x).squeeze(-1) + loss = criterion(preds, y) + loss.backward() + optimizer.step() + if (epoch + 1) % 5 == 0: + print(f"step={epoch + 1} loss={loss.item():.4f}") + +torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/regression.pt") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/t5_summarization.py b/forgeenv/tasks/seed_corpus/t5_summarization.py index 8897704d2250d58ce471e91157316edc86ae6aeb..1258c5c3a52828c6894c11e6f57f8517f6ad0afa 100644 --- a/forgeenv/tasks/seed_corpus/t5_summarization.py +++ b/forgeenv/tasks/seed_corpus/t5_summarization.py @@ -1,55 +1,55 @@ -"""Tiny T5 fine-tuning for summarization on 100-sample CNN/DailyMail.""" -from transformers import ( - AutoTokenizer, - AutoModelForSeq2SeqLM, - DataCollatorForSeq2Seq, - Seq2SeqTrainer, - Seq2SeqTrainingArguments, -) -from datasets import load_dataset - -dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:100]") -tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") - - -def preprocess(examples): - inputs = tokenizer( - ["summarize: " + a for a in examples["article"]], - max_length=128, - truncation=True, - padding="max_length", - ) - targets = tokenizer( - examples["highlights"], - max_length=32, - truncation=True, - padding="max_length", - ) - inputs["labels"] = targets["input_ids"] - return inputs - - -dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names) - -model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") - -training_args = Seq2SeqTrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=4, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", - predict_with_generate=False, -) - -trainer = Seq2SeqTrainer( - model=model, - args=training_args, - train_dataset=dataset, - data_collator=DataCollatorForSeq2Seq(tokenizer, model=model), -) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""Tiny T5 fine-tuning for summarization on 100-sample CNN/DailyMail.""" +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + DataCollatorForSeq2Seq, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, +) +from datasets import load_dataset + +dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:100]") +tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") + + +def preprocess(examples): + inputs = tokenizer( + ["summarize: " + a for a in examples["article"]], + max_length=128, + truncation=True, + padding="max_length", + ) + targets = tokenizer( + examples["highlights"], + max_length=32, + truncation=True, + padding="max_length", + ) + inputs["labels"] = targets["input_ids"] + return inputs + + +dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names) + +model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") + +training_args = Seq2SeqTrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=4, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", + predict_with_generate=False, +) + +trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=dataset, + data_collator=DataCollatorForSeq2Seq(tokenizer, model=model), +) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py b/forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py index c9bf669a4748d453dbf71c27e356bdde87dfb915..6eb871501f8811ffa415ac9166625689df47b258 100644 --- a/forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py +++ b/forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py @@ -1,38 +1,38 @@ -"""Tiny PyTorch MLP on a 1000-sample MNIST subset (image classification baseline).""" -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -from datasets import load_dataset - -dataset = load_dataset("mnist", split="train[:1000]") -dataset = dataset.with_format("torch") - - -def collate(batch): - pixel = torch.stack([b["image"].float().flatten() / 255.0 for b in batch]) - labels = torch.tensor([b["label"] for b in batch]) - return pixel, labels - - -loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate) - -model = nn.Sequential( - nn.Linear(784, 64), - nn.ReLU(), - nn.Linear(64, 10), -) - -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) -criterion = nn.CrossEntropyLoss() - -for epoch in range(2): - for step, (x, y) in enumerate(loader, start=1): - optimizer.zero_grad() - loss = criterion(model(x), y) - loss.backward() - optimizer.step() - if step % 5 == 0: - print(f"step={step} loss={loss.item():.4f}") - -torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/mlp.pt") -print("TRAINING_COMPLETE") +"""Tiny PyTorch MLP on a 1000-sample MNIST subset (image classification baseline).""" +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from datasets import load_dataset + +dataset = load_dataset("mnist", split="train[:1000]") +dataset = dataset.with_format("torch") + + +def collate(batch): + pixel = torch.stack([b["image"].float().flatten() / 255.0 for b in batch]) + labels = torch.tensor([b["label"] for b in batch]) + return pixel, labels + + +loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate) + +model = nn.Sequential( + nn.Linear(784, 64), + nn.ReLU(), + nn.Linear(64, 10), +) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +criterion = nn.CrossEntropyLoss() + +for epoch in range(2): + for step, (x, y) in enumerate(loader, start=1): + optimizer.zero_grad() + loss = criterion(model(x), y) + loss.backward() + optimizer.step() + if step % 5 == 0: + print(f"step={step} loss={loss.item():.4f}") + +torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/mlp.pt") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/seed_corpus/vit_cifar10.py b/forgeenv/tasks/seed_corpus/vit_cifar10.py index 0c96d0aeec693c915b96707996e6c33a524ac87c..f9dcd9b03a1f32e7734ad714d0131a2ce2e5a1fe 100644 --- a/forgeenv/tasks/seed_corpus/vit_cifar10.py +++ b/forgeenv/tasks/seed_corpus/vit_cifar10.py @@ -1,41 +1,41 @@ -"""Tiny ViT image classification on 200-sample CIFAR-10 subset.""" -from transformers import ( - AutoImageProcessor, - AutoModelForImageClassification, - Trainer, - TrainingArguments, -) -from datasets import load_dataset -import torch - -dataset = load_dataset("cifar10", split="train[:200]") -processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224") - - -def transform(batch): - images = [img.convert("RGB") for img in batch["img"]] - inputs = processor(images=images, return_tensors="pt") - inputs["labels"] = torch.tensor(batch["label"]) - return inputs - - -dataset = dataset.with_transform(transform) - -model = AutoModelForImageClassification.from_pretrained( - "WinKawaks/vit-tiny-patch16-224", num_labels=10, ignore_mismatched_sizes=True -) - -training_args = TrainingArguments( - output_dir="/tmp/forge_output/checkpoint", - num_train_epochs=1, - per_device_train_batch_size=4, - logging_steps=5, - save_strategy="epoch", - no_cuda=True, - report_to="none", -) - -trainer = Trainer(model=model, args=training_args, train_dataset=dataset) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") +"""Tiny ViT image classification on 200-sample CIFAR-10 subset.""" +from transformers import ( + AutoImageProcessor, + AutoModelForImageClassification, + Trainer, + TrainingArguments, +) +from datasets import load_dataset +import torch + +dataset = load_dataset("cifar10", split="train[:200]") +processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224") + + +def transform(batch): + images = [img.convert("RGB") for img in batch["img"]] + inputs = processor(images=images, return_tensors="pt") + inputs["labels"] = torch.tensor(batch["label"]) + return inputs + + +dataset = dataset.with_transform(transform) + +model = AutoModelForImageClassification.from_pretrained( + "WinKawaks/vit-tiny-patch16-224", num_labels=10, ignore_mismatched_sizes=True +) + +training_args = TrainingArguments( + output_dir="/tmp/forge_output/checkpoint", + num_train_epochs=1, + per_device_train_batch_size=4, + logging_steps=5, + save_strategy="epoch", + no_cuda=True, + report_to="none", +) + +trainer = Trainer(model=model, args=training_args, train_dataset=dataset) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") diff --git a/forgeenv/tasks/task_sampler.py b/forgeenv/tasks/task_sampler.py index 44c4300ea79c690283f815804906ae0f1bfb6a66..6b8069a2533526c2772dc178e39dee993e44ffd0 100644 --- a/forgeenv/tasks/task_sampler.py +++ b/forgeenv/tasks/task_sampler.py @@ -1,105 +1,105 @@ -"""Task sampler: loads the seed corpus and samples Tasks by difficulty. - -Difficulty is auto-derived from script line count. Category is auto-detected -from script content (text_classification, ner, translation, etc.). -""" -from __future__ import annotations - -import random -from pathlib import Path -from typing import Optional - -from forgeenv.tasks.models import Task - - -def _detect_category(content: str) -> str: - cl = content.lower() - if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl: - return "text_classification" - if "tokenclassification" in cl or "ner" in cl or "conll" in cl: - return "ner" - if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl: - return "seq2seq" - if "causallm" in cl or "gpt2" in cl or "wikitext" in cl: - return "text_generation" - if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl: - return "image_classification" - if "questionanswering" in cl or "squad" in cl: - return "qa" - if "logisticregression" in cl or "make_classification" in cl: - return "tabular" - if "regression" in cl: - return "regression" - return "general" - - -def _derive_difficulty(content: str) -> str: - lines = len(content.splitlines()) - if lines < 30: - return "easy" - if lines < 60: - return "medium" - return "hard" - - -class TaskSampler: - """Loads seed corpus and samples tasks by difficulty / category.""" - - def __init__(self, seed_dir: Optional[str] = None) -> None: - if seed_dir is None: - seed_dir = str(Path(__file__).parent / "seed_corpus") - - self.tasks: list[Task] = [] - self._load_corpus(seed_dir) - - def _load_corpus(self, seed_dir: str) -> None: - corpus_path = Path(seed_dir) - if not corpus_path.exists(): - return - - for py_file in sorted(corpus_path.glob("*.py")): - if py_file.name.startswith("__"): - continue - - content = py_file.read_text(encoding="utf-8") - task_id = py_file.stem - difficulty = _derive_difficulty(content) - category = _detect_category(content) - - description = "" - if content.startswith('"""'): - end = content.find('"""', 3) - if end != -1: - description = content[3:end].strip() - - self.tasks.append( - Task( - task_id=task_id, - description=description or f"Training script: {task_id}", - script_content=content, - difficulty=difficulty, - category=category, - ) - ) - - def sample(self, difficulty: Optional[str] = None) -> Optional[Task]: - candidates = self.tasks - if difficulty is not None: - filtered = [t for t in self.tasks if t.difficulty == difficulty] - if filtered: - candidates = filtered - return random.choice(candidates) if candidates else None - - def sample_batch( - self, n: int, difficulty: Optional[str] = None - ) -> list[Task]: - return [t for t in (self.sample(difficulty) for _ in range(n)) if t is not None] - - def get_all_categories(self) -> list[str]: - return sorted({t.category for t in self.tasks}) - - def get_by_id(self, task_id: str) -> Optional[Task]: - for t in self.tasks: - if t.task_id == task_id: - return t - return None +"""Task sampler: loads the seed corpus and samples Tasks by difficulty. + +Difficulty is auto-derived from script line count. Category is auto-detected +from script content (text_classification, ner, translation, etc.). +""" +from __future__ import annotations + +import random +from pathlib import Path +from typing import Optional + +from forgeenv.tasks.models import Task + + +def _detect_category(content: str) -> str: + cl = content.lower() + if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl: + return "text_classification" + if "tokenclassification" in cl or "ner" in cl or "conll" in cl: + return "ner" + if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl: + return "seq2seq" + if "causallm" in cl or "gpt2" in cl or "wikitext" in cl: + return "text_generation" + if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl: + return "image_classification" + if "questionanswering" in cl or "squad" in cl: + return "qa" + if "logisticregression" in cl or "make_classification" in cl: + return "tabular" + if "regression" in cl: + return "regression" + return "general" + + +def _derive_difficulty(content: str) -> str: + lines = len(content.splitlines()) + if lines < 30: + return "easy" + if lines < 60: + return "medium" + return "hard" + + +class TaskSampler: + """Loads seed corpus and samples tasks by difficulty / category.""" + + def __init__(self, seed_dir: Optional[str] = None) -> None: + if seed_dir is None: + seed_dir = str(Path(__file__).parent / "seed_corpus") + + self.tasks: list[Task] = [] + self._load_corpus(seed_dir) + + def _load_corpus(self, seed_dir: str) -> None: + corpus_path = Path(seed_dir) + if not corpus_path.exists(): + return + + for py_file in sorted(corpus_path.glob("*.py")): + if py_file.name.startswith("__"): + continue + + content = py_file.read_text(encoding="utf-8") + task_id = py_file.stem + difficulty = _derive_difficulty(content) + category = _detect_category(content) + + description = "" + if content.startswith('"""'): + end = content.find('"""', 3) + if end != -1: + description = content[3:end].strip() + + self.tasks.append( + Task( + task_id=task_id, + description=description or f"Training script: {task_id}", + script_content=content, + difficulty=difficulty, + category=category, + ) + ) + + def sample(self, difficulty: Optional[str] = None) -> Optional[Task]: + candidates = self.tasks + if difficulty is not None: + filtered = [t for t in self.tasks if t.difficulty == difficulty] + if filtered: + candidates = filtered + return random.choice(candidates) if candidates else None + + def sample_batch( + self, n: int, difficulty: Optional[str] = None + ) -> list[Task]: + return [t for t in (self.sample(difficulty) for _ in range(n)) if t is not None] + + def get_all_categories(self) -> list[str]: + return sorted({t.category for t in self.tasks}) + + def get_by_id(self, task_id: str) -> Optional[Task]: + for t in self.tasks: + if t.task_id == task_id: + return t + return None diff --git a/forgeenv/training/grpo_drift.py b/forgeenv/training/grpo_drift.py index e6fbb702f10f654a8be7edaed362382c69117a12..7d38a44edb982e0942ae9c30ce8194f39db76859 100644 --- a/forgeenv/training/grpo_drift.py +++ b/forgeenv/training/grpo_drift.py @@ -1,168 +1,168 @@ -"""GRPO trainer for the Drift Generator. - -Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition). -Each prompt is sampled `group_size` times; for every breakage we run K -independent Repair Agent rollouts to estimate p_hat (success rate). - -Heavy and brittle on a single GPU β€” keep group_size small for hackathon -budgets. Provides a `--dry_run` mode that just exercises the reward function -without any LLM calls. -""" -from __future__ import annotations - -import argparse -import json -import os -import random -from pathlib import Path -from typing import Optional - -from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.roles.drift_generator import ( - BaselineDriftGenerator, - parse_drift_output, -) -from forgeenv.roles.prompts import ( - DRIFT_GENERATOR_SYSTEM_PROMPT, - render_drift_generator_prompt, -) -from forgeenv.training.rollout import ( - GenerateFn, - rollout_one_episode, - baseline_oracle_repair_generate, - _baseline_repair_generate, -) -from forgeenv.training.reward_functions import ( - compute_drift_gen_reward, - compute_uncertainty_reward, - compute_repetition_penalty, -) - - -def evaluate_drift_batch( - env_factory, - breakages: list[dict], - repair_generate: GenerateFn, - n_repair_attempts_per_breakage: int = 4, - seed: int = 0, -) -> list[float]: - """For each breakage spec, run K Repair-Agent attempts and compute - R-Zero's composite Challenger reward. Returns one reward per breakage.""" - - breakage_texts = [ - f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}" - for b in breakages - ] - - rewards: list[float] = [] - for idx, breakage_spec in enumerate(breakages): - successes: list[bool] = [] - for k in range(n_repair_attempts_per_breakage): - env = env_factory() - env.reset(seed=seed + idx * 100 + k, difficulty="easy") - try: - obs2 = env.step( - ForgeAction( - breakage=BreakageAction( - primitive_type=breakage_spec.get("primitive_type", ""), - params=breakage_spec.get("params", {}) or {}, - ) - ) - ) - except Exception: - successes.append(False) - continue - - from forgeenv.roles.repair_agent import extract_diff - from forgeenv.roles.prompts import render_repair_agent_prompt - - user = render_repair_agent_prompt( - broken_script=obs2.script_content, - error_trace=obs2.error_trace or "", - library_versions=obs2.library_versions, - target_category=obs2.target_category, - ) - raw = repair_generate("", user) - diff = extract_diff(raw or "") - obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff))) - successes.append( - bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5) - ) - - reward = compute_drift_gen_reward( - breakage_text=breakage_texts[idx], - repair_successes=successes, - batch_breakages=breakage_texts, - ) - rewards.append(reward) - return rewards - - -def run_drift_grpo_dry_run( - output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0 -) -> None: - """Pure-CPU exercise of the drift-side reward loop. Writes per-step rewards.""" - rng = random.Random(seed) - drift_gen = BaselineDriftGenerator(seed=seed) - rewards_log: list[dict] = [] - - for ep in range(total_episodes): - env = ForgeEnvironment(seed=seed + ep) - env.reset(difficulty="easy") - target_category = env.state["target_category"] - script = env._original_script # noqa: SLF001 β€” read-only convenience - - # Sample group_size candidate breakages - candidates = [ - drift_gen.propose(target_category=target_category, script=script) - for _ in range(group_size) - ] - - # Use the oracle as repair (so we get a meaningful uncertainty signal: - # an "unbreakable" breakage gives p_hat=1, an "always-fails" one gives 0) - rewards = evaluate_drift_batch( - env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)), - breakages=candidates, - repair_generate=baseline_oracle_repair_generate(env), - n_repair_attempts_per_breakage=2, - seed=seed + ep, - ) - rewards_log.append( - {"episode": ep, "rewards": rewards, "candidates": candidates} - ) - - if ep % max(1, total_episodes // 10) == 0: - mean_r = sum(rewards) / max(1, len(rewards)) - print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}") - - Path(output_dir).mkdir(parents=True, exist_ok=True) - (Path(output_dir) / "drift_dry_run.json").write_text( - json.dumps(rewards_log, indent=2) - ) - print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}") - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output_dir", required=True) - parser.add_argument("--total_episodes", type=int, default=100) - parser.add_argument("--group_size", type=int, default=4) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--dry_run", action="store_true", default=True) - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - if args.dry_run: - run_drift_grpo_dry_run( - output_dir=args.output_dir, - total_episodes=args.total_episodes, - group_size=args.group_size, - seed=args.seed, - ) - else: - raise NotImplementedError( - "Full LLM Drift GRPO requires both roles loaded β€” use the Colab notebook" - ) +"""GRPO trainer for the Drift Generator. + +Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition). +Each prompt is sampled `group_size` times; for every breakage we run K +independent Repair Agent rollouts to estimate p_hat (success rate). + +Heavy and brittle on a single GPU β€” keep group_size small for hackathon +budgets. Provides a `--dry_run` mode that just exercises the reward function +without any LLM calls. +""" +from __future__ import annotations + +import argparse +import json +import os +import random +from pathlib import Path +from typing import Optional + +from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.roles.drift_generator import ( + BaselineDriftGenerator, + parse_drift_output, +) +from forgeenv.roles.prompts import ( + DRIFT_GENERATOR_SYSTEM_PROMPT, + render_drift_generator_prompt, +) +from forgeenv.training.rollout import ( + GenerateFn, + rollout_one_episode, + baseline_oracle_repair_generate, + _baseline_repair_generate, +) +from forgeenv.training.reward_functions import ( + compute_drift_gen_reward, + compute_uncertainty_reward, + compute_repetition_penalty, +) + + +def evaluate_drift_batch( + env_factory, + breakages: list[dict], + repair_generate: GenerateFn, + n_repair_attempts_per_breakage: int = 4, + seed: int = 0, +) -> list[float]: + """For each breakage spec, run K Repair-Agent attempts and compute + R-Zero's composite Challenger reward. Returns one reward per breakage.""" + + breakage_texts = [ + f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}" + for b in breakages + ] + + rewards: list[float] = [] + for idx, breakage_spec in enumerate(breakages): + successes: list[bool] = [] + for k in range(n_repair_attempts_per_breakage): + env = env_factory() + env.reset(seed=seed + idx * 100 + k, difficulty="easy") + try: + obs2 = env.step( + ForgeAction( + breakage=BreakageAction( + primitive_type=breakage_spec.get("primitive_type", ""), + params=breakage_spec.get("params", {}) or {}, + ) + ) + ) + except Exception: + successes.append(False) + continue + + from forgeenv.roles.repair_agent import extract_diff + from forgeenv.roles.prompts import render_repair_agent_prompt + + user = render_repair_agent_prompt( + broken_script=obs2.script_content, + error_trace=obs2.error_trace or "", + library_versions=obs2.library_versions, + target_category=obs2.target_category, + ) + raw = repair_generate("", user) + diff = extract_diff(raw or "") + obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff))) + successes.append( + bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5) + ) + + reward = compute_drift_gen_reward( + breakage_text=breakage_texts[idx], + repair_successes=successes, + batch_breakages=breakage_texts, + ) + rewards.append(reward) + return rewards + + +def run_drift_grpo_dry_run( + output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0 +) -> None: + """Pure-CPU exercise of the drift-side reward loop. Writes per-step rewards.""" + rng = random.Random(seed) + drift_gen = BaselineDriftGenerator(seed=seed) + rewards_log: list[dict] = [] + + for ep in range(total_episodes): + env = ForgeEnvironment(seed=seed + ep) + env.reset(difficulty="easy") + target_category = env.state["target_category"] + script = env._original_script # noqa: SLF001 β€” read-only convenience + + # Sample group_size candidate breakages + candidates = [ + drift_gen.propose(target_category=target_category, script=script) + for _ in range(group_size) + ] + + # Use the oracle as repair (so we get a meaningful uncertainty signal: + # an "unbreakable" breakage gives p_hat=1, an "always-fails" one gives 0) + rewards = evaluate_drift_batch( + env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)), + breakages=candidates, + repair_generate=baseline_oracle_repair_generate(env), + n_repair_attempts_per_breakage=2, + seed=seed + ep, + ) + rewards_log.append( + {"episode": ep, "rewards": rewards, "candidates": candidates} + ) + + if ep % max(1, total_episodes // 10) == 0: + mean_r = sum(rewards) / max(1, len(rewards)) + print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}") + + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "drift_dry_run.json").write_text( + json.dumps(rewards_log, indent=2) + ) + print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}") + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--total_episodes", type=int, default=100) + parser.add_argument("--group_size", type=int, default=4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--dry_run", action="store_true", default=True) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + if args.dry_run: + run_drift_grpo_dry_run( + output_dir=args.output_dir, + total_episodes=args.total_episodes, + group_size=args.group_size, + seed=args.seed, + ) + else: + raise NotImplementedError( + "Full LLM Drift GRPO requires both roles loaded β€” use the Colab notebook" + ) diff --git a/forgeenv/training/grpo_repair.py b/forgeenv/training/grpo_repair.py index 750c9ca89912414429fadf77366860bb10a1c709..1ffc2677fdf6605485fdfb843b8c2768df714bbc 100644 --- a/forgeenv/training/grpo_repair.py +++ b/forgeenv/training/grpo_repair.py @@ -1,213 +1,213 @@ -"""GRPO trainer for the Repair Agent. - -This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout -function. Each prompt is sampled K times (group size); each sample is -executed in the env and gets a scalar reward from the visible verifier. - -Usage: - python -m forgeenv.training.grpo_repair \\ - --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ - --adapter_path artifacts/checkpoints/repair_agent_sft \\ - --output_dir artifacts/checkpoints/repair_agent_grpo \\ - --total_episodes 200 --group_size 4 -""" -from __future__ import annotations - -import argparse -import json -import os -from pathlib import Path -from typing import Any, Optional - -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.roles.drift_generator import BaselineDriftGenerator -from forgeenv.roles.prompts import ( - DRIFT_GENERATOR_SYSTEM_PROMPT, - REPAIR_AGENT_SYSTEM_PROMPT, - render_drift_generator_prompt, - render_repair_agent_prompt, -) -from forgeenv.roles.repair_agent import extract_diff -from forgeenv.training.rollout import rollout_one_episode - - -def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]: - """Reset env, run baseline drift generator, return a repair-prompt - dict ready to feed to TRL's GRPOTrainer.""" - drift_gen = BaselineDriftGenerator() - - obs = env.reset(difficulty="easy") - drift_user = render_drift_generator_prompt( - script=obs.script_content, - target_category=obs.target_category, - library_versions=obs.library_versions, - ) - spec = drift_gen.propose( - target_category=obs.target_category, script=obs.script_content - ) - from forgeenv.env.actions import BreakageAction, ForgeAction - - obs2 = env.step( - ForgeAction( - breakage=BreakageAction( - primitive_type=spec["primitive_type"], params=spec["params"] - ) - ) - ) - - user = render_repair_agent_prompt( - broken_script=obs2.script_content, - error_trace=obs2.error_trace or "", - library_versions=obs2.library_versions, - target_category=obs2.target_category, - ) - return { - "prompt": [ - {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, - {"role": "user", "content": user}, - ], - "task_id": obs.task_id, - "primitive_type": spec["primitive_type"], - "broken_script": obs2.script_content, - "drift_user_prompt": drift_user, - } - - -def reward_repair_function( - completions: list, prompts: list = None, **kwargs -) -> list[float]: - """TRL-compatible reward fn: scores a batch of completions against - a (broken_script, breakage_spec) tuple stored on each example.""" - from forgeenv.env.actions import RepairAction, ForgeAction - from forgeenv.env.diff_utils import apply_unified_diff - from forgeenv.sandbox.simulation_mode import SimulationExecutor - from forgeenv.tasks.task_sampler import TaskSampler - from forgeenv.verifier.visible_verifier import compute_visible_reward - - sampler = TaskSampler() - executor = SimulationExecutor() - task_ids = kwargs.get("task_id", [None] * len(completions)) - broken_scripts = kwargs.get("broken_script", [""] * len(completions)) - - rewards: list[float] = [] - for completion, task_id, broken in zip(completions, task_ids, broken_scripts): - if isinstance(completion, list): # chat format - completion = completion[-1]["content"] - diff = extract_diff(completion or "") - repaired = apply_unified_diff(broken, diff) if diff else broken - task = sampler.get_by_id(task_id) if task_id else None - if task is None and sampler.tasks: - task = sampler.tasks[0] - result = executor.execute(repaired, task) - result.script_content = repaired - reward, _ = compute_visible_reward(result, task) - rewards.append(float(reward)) - return rewards - - -def run_grpo( - base_model: str, - adapter_path: Optional[str], - output_dir: str, - total_episodes: int = 200, - group_size: int = 4, - learning_rate: float = 5e-6, - seed: int = 0, - use_unsloth: Optional[bool] = None, -) -> None: - """Launch GRPO training (lazy imports to keep this module importable on CPU).""" - - if use_unsloth is None: - use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" - - if not use_unsloth: - # Dry-run mode: just exercise the prompt building loop and dump rewards. - env = ForgeEnvironment(seed=seed) - rewards = [] - for ep in range(total_episodes): - result = rollout_one_episode(env) - rewards.append(result.visible_reward) - if ep % max(1, total_episodes // 10) == 0: - print( - f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} " - f"primitive={result.primitive_type}" - ) - Path(output_dir).mkdir(parents=True, exist_ok=True) - (Path(output_dir) / "dry_run_rewards.json").write_text( - json.dumps(rewards, indent=2) - ) - print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}") - return - - from datasets import Dataset - from trl import GRPOConfig, GRPOTrainer - from unsloth import FastLanguageModel - from peft import PeftModel - - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=base_model, - max_seq_length=4096, - dtype=None, - load_in_4bit=True, - ) - if adapter_path: - model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True) - - env = ForgeEnvironment(seed=seed) - examples = [_build_repair_prompt(env) for _ in range(total_episodes)] - dataset = Dataset.from_list(examples) - - grpo_config = GRPOConfig( - output_dir=output_dir, - per_device_train_batch_size=1, - gradient_accumulation_steps=4, - learning_rate=learning_rate, - max_steps=total_episodes, - num_generations=group_size, - max_completion_length=1024, - logging_steps=5, - save_steps=max(50, total_episodes // 4), - save_total_limit=2, - seed=seed, - report_to="none", - beta=0.04, - ) - trainer = GRPOTrainer( - model=model, - processing_class=tokenizer, - args=grpo_config, - train_dataset=dataset, - reward_funcs=[reward_repair_function], - ) - trainer.train() - Path(output_dir).mkdir(parents=True, exist_ok=True) - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - print(f"[grpo] saved adapter to {output_dir}") - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit") - parser.add_argument("--adapter_path", default=None) - parser.add_argument("--output_dir", required=True) - parser.add_argument("--total_episodes", type=int, default=200) - parser.add_argument("--group_size", type=int, default=4) - parser.add_argument("--learning_rate", type=float, default=5e-6) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--dry_run", action="store_true") - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - run_grpo( - base_model=args.base_model, - adapter_path=args.adapter_path, - output_dir=args.output_dir, - total_episodes=args.total_episodes, - group_size=args.group_size, - learning_rate=args.learning_rate, - seed=args.seed, - use_unsloth=not args.dry_run, - ) +"""GRPO trainer for the Repair Agent. + +This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout +function. Each prompt is sampled K times (group size); each sample is +executed in the env and gets a scalar reward from the visible verifier. + +Usage: + python -m forgeenv.training.grpo_repair \\ + --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ + --adapter_path artifacts/checkpoints/repair_agent_sft \\ + --output_dir artifacts/checkpoints/repair_agent_grpo \\ + --total_episodes 200 --group_size 4 +""" +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path +from typing import Any, Optional + +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.roles.drift_generator import BaselineDriftGenerator +from forgeenv.roles.prompts import ( + DRIFT_GENERATOR_SYSTEM_PROMPT, + REPAIR_AGENT_SYSTEM_PROMPT, + render_drift_generator_prompt, + render_repair_agent_prompt, +) +from forgeenv.roles.repair_agent import extract_diff +from forgeenv.training.rollout import rollout_one_episode + + +def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]: + """Reset env, run baseline drift generator, return a repair-prompt + dict ready to feed to TRL's GRPOTrainer.""" + drift_gen = BaselineDriftGenerator() + + obs = env.reset(difficulty="easy") + drift_user = render_drift_generator_prompt( + script=obs.script_content, + target_category=obs.target_category, + library_versions=obs.library_versions, + ) + spec = drift_gen.propose( + target_category=obs.target_category, script=obs.script_content + ) + from forgeenv.env.actions import BreakageAction, ForgeAction + + obs2 = env.step( + ForgeAction( + breakage=BreakageAction( + primitive_type=spec["primitive_type"], params=spec["params"] + ) + ) + ) + + user = render_repair_agent_prompt( + broken_script=obs2.script_content, + error_trace=obs2.error_trace or "", + library_versions=obs2.library_versions, + target_category=obs2.target_category, + ) + return { + "prompt": [ + {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, + {"role": "user", "content": user}, + ], + "task_id": obs.task_id, + "primitive_type": spec["primitive_type"], + "broken_script": obs2.script_content, + "drift_user_prompt": drift_user, + } + + +def reward_repair_function( + completions: list, prompts: list = None, **kwargs +) -> list[float]: + """TRL-compatible reward fn: scores a batch of completions against + a (broken_script, breakage_spec) tuple stored on each example.""" + from forgeenv.env.actions import RepairAction, ForgeAction + from forgeenv.env.diff_utils import apply_unified_diff + from forgeenv.sandbox.simulation_mode import SimulationExecutor + from forgeenv.tasks.task_sampler import TaskSampler + from forgeenv.verifier.visible_verifier import compute_visible_reward + + sampler = TaskSampler() + executor = SimulationExecutor() + task_ids = kwargs.get("task_id", [None] * len(completions)) + broken_scripts = kwargs.get("broken_script", [""] * len(completions)) + + rewards: list[float] = [] + for completion, task_id, broken in zip(completions, task_ids, broken_scripts): + if isinstance(completion, list): # chat format + completion = completion[-1]["content"] + diff = extract_diff(completion or "") + repaired = apply_unified_diff(broken, diff) if diff else broken + task = sampler.get_by_id(task_id) if task_id else None + if task is None and sampler.tasks: + task = sampler.tasks[0] + result = executor.execute(repaired, task) + result.script_content = repaired + reward, _ = compute_visible_reward(result, task) + rewards.append(float(reward)) + return rewards + + +def run_grpo( + base_model: str, + adapter_path: Optional[str], + output_dir: str, + total_episodes: int = 200, + group_size: int = 4, + learning_rate: float = 5e-6, + seed: int = 0, + use_unsloth: Optional[bool] = None, +) -> None: + """Launch GRPO training (lazy imports to keep this module importable on CPU).""" + + if use_unsloth is None: + use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" + + if not use_unsloth: + # Dry-run mode: just exercise the prompt building loop and dump rewards. + env = ForgeEnvironment(seed=seed) + rewards = [] + for ep in range(total_episodes): + result = rollout_one_episode(env) + rewards.append(result.visible_reward) + if ep % max(1, total_episodes // 10) == 0: + print( + f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} " + f"primitive={result.primitive_type}" + ) + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "dry_run_rewards.json").write_text( + json.dumps(rewards, indent=2) + ) + print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}") + return + + from datasets import Dataset + from trl import GRPOConfig, GRPOTrainer + from unsloth import FastLanguageModel + from peft import PeftModel + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=base_model, + max_seq_length=4096, + dtype=None, + load_in_4bit=True, + ) + if adapter_path: + model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True) + + env = ForgeEnvironment(seed=seed) + examples = [_build_repair_prompt(env) for _ in range(total_episodes)] + dataset = Dataset.from_list(examples) + + grpo_config = GRPOConfig( + output_dir=output_dir, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + learning_rate=learning_rate, + max_steps=total_episodes, + num_generations=group_size, + max_completion_length=1024, + logging_steps=5, + save_steps=max(50, total_episodes // 4), + save_total_limit=2, + seed=seed, + report_to="none", + beta=0.04, + ) + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + args=grpo_config, + train_dataset=dataset, + reward_funcs=[reward_repair_function], + ) + trainer.train() + Path(output_dir).mkdir(parents=True, exist_ok=True) + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print(f"[grpo] saved adapter to {output_dir}") + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit") + parser.add_argument("--adapter_path", default=None) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--total_episodes", type=int, default=200) + parser.add_argument("--group_size", type=int, default=4) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--dry_run", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + run_grpo( + base_model=args.base_model, + adapter_path=args.adapter_path, + output_dir=args.output_dir, + total_episodes=args.total_episodes, + group_size=args.group_size, + learning_rate=args.learning_rate, + seed=args.seed, + use_unsloth=not args.dry_run, + ) diff --git a/forgeenv/training/plots.py b/forgeenv/training/plots.py index b9c157e8c7745929a0f4e8d40a1618dac8636edb..15ee0793bcae0ce8912db81dd527cc553fcff177 100644 --- a/forgeenv/training/plots.py +++ b/forgeenv/training/plots.py @@ -1,128 +1,128 @@ -"""Matplotlib plotting helpers β€” produces the 3 PNGs that go into the README. - -Plots: - 1. baseline_vs_trained.png β€” bar/line comparison - 2. training_reward_curve.png β€” moving-average reward over episodes - 3. success_by_category.png β€” per-primitive-type success rate - -All plots are 600x400 @ 100 dpi, label both axes, and use a colour-blind-safe palette. -""" -from __future__ import annotations - -from pathlib import Path -from typing import Iterable - -import matplotlib - -matplotlib.use("Agg") -import matplotlib.pyplot as plt # noqa: E402 - -PALETTE = { - "baseline": "#888888", - "trained": "#1F77B4", - "ema": "#D62728", - "raw": "#1F77B4", -} - - -def _moving_average(values: list[float], window: int = 10) -> list[float]: - if not values: - return [] - out: list[float] = [] - cumsum = 0.0 - for i, v in enumerate(values): - cumsum += v - if i >= window: - cumsum -= values[i - window] - out.append(cumsum / min(i + 1, window)) - return out - - -def plot_baseline_vs_trained( - baseline_rewards: list[float], - trained_rewards: list[float], - out_path: str | Path, - title: str = "ForgeEnv: Baseline vs Trained (50 eval episodes)", -) -> str: - """Side-by-side bar chart of mean reward + per-episode strip plot.""" - out_path = Path(out_path) - out_path.parent.mkdir(parents=True, exist_ok=True) - fig, ax = plt.subplots(figsize=(6, 4), dpi=100) - - means = [ - sum(baseline_rewards) / max(1, len(baseline_rewards)), - sum(trained_rewards) / max(1, len(trained_rewards)), - ] - labels = ["Baseline (no-op)", "Trained (GRPO)"] - colors = [PALETTE["baseline"], PALETTE["trained"]] - bars = ax.bar(labels, means, color=colors, width=0.5, alpha=0.85) - ax.bar_label(bars, fmt="%.2f", padding=3) - - for x, rewards in zip([0, 1], [baseline_rewards, trained_rewards]): - if rewards: - xs = [x + 0.18] * len(rewards) - ax.scatter(xs, rewards, s=8, color="black", alpha=0.4, zorder=3) - - ax.set_ylabel("Visible verifier reward") - ax.set_title(title) - ax.grid(axis="y", linestyle=":", alpha=0.5) - ax.set_ylim(bottom=min(0, min(means + baseline_rewards + trained_rewards or [0]))) - fig.tight_layout() - fig.savefig(out_path, dpi=100, bbox_inches="tight") - plt.close(fig) - return str(out_path) - - -def plot_reward_curve( - rewards: list[float], - out_path: str | Path, - window: int = 10, - title: str = "ForgeEnv: Repair Agent reward over training", -) -> str: - out_path = Path(out_path) - out_path.parent.mkdir(parents=True, exist_ok=True) - fig, ax = plt.subplots(figsize=(6, 4), dpi=100) - xs = list(range(1, len(rewards) + 1)) - ax.plot(xs, rewards, color=PALETTE["raw"], alpha=0.35, linewidth=1.0, label="Per-episode") - if rewards: - ax.plot( - xs, - _moving_average(rewards, window=window), - color=PALETTE["ema"], - linewidth=2.0, - label=f"Moving avg (w={window})", - ) - ax.set_xlabel("Episode") - ax.set_ylabel("Visible verifier reward") - ax.set_title(title) - ax.legend(loc="lower right") - ax.grid(linestyle=":", alpha=0.4) - fig.tight_layout() - fig.savefig(out_path, dpi=100, bbox_inches="tight") - plt.close(fig) - return str(out_path) - - -def plot_success_rate_by_category( - by_category: dict[str, list[bool]], - out_path: str | Path, - title: str = "ForgeEnv: Repair success by primitive type", -) -> str: - out_path = Path(out_path) - out_path.parent.mkdir(parents=True, exist_ok=True) - fig, ax = plt.subplots(figsize=(7, 4), dpi=100) - - cats = list(by_category.keys()) - rates = [ - sum(by_category[c]) / max(1, len(by_category[c])) for c in cats - ] - bars = ax.barh(cats, rates, color=PALETTE["trained"], alpha=0.85) - ax.bar_label(bars, fmt="%.2f", padding=3) - ax.set_xlim(0, 1.05) - ax.set_xlabel("Success rate (held-out: executed_cleanly)") - ax.set_title(title) - ax.grid(axis="x", linestyle=":", alpha=0.4) - fig.tight_layout() - fig.savefig(out_path, dpi=100, bbox_inches="tight") - plt.close(fig) - return str(out_path) +"""Matplotlib plotting helpers β€” produces the 3 PNGs that go into the README. + +Plots: + 1. baseline_vs_trained.png β€” bar/line comparison + 2. training_reward_curve.png β€” moving-average reward over episodes + 3. success_by_category.png β€” per-primitive-type success rate + +All plots are 600x400 @ 100 dpi, label both axes, and use a colour-blind-safe palette. +""" +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + +PALETTE = { + "baseline": "#888888", + "trained": "#1F77B4", + "ema": "#D62728", + "raw": "#1F77B4", +} + + +def _moving_average(values: list[float], window: int = 10) -> list[float]: + if not values: + return [] + out: list[float] = [] + cumsum = 0.0 + for i, v in enumerate(values): + cumsum += v + if i >= window: + cumsum -= values[i - window] + out.append(cumsum / min(i + 1, window)) + return out + + +def plot_baseline_vs_trained( + baseline_rewards: list[float], + trained_rewards: list[float], + out_path: str | Path, + title: str = "ForgeEnv: Baseline vs Trained (50 eval episodes)", +) -> str: + """Side-by-side bar chart of mean reward + per-episode strip plot.""" + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(6, 4), dpi=100) + + means = [ + sum(baseline_rewards) / max(1, len(baseline_rewards)), + sum(trained_rewards) / max(1, len(trained_rewards)), + ] + labels = ["Baseline (no-op)", "Trained (GRPO)"] + colors = [PALETTE["baseline"], PALETTE["trained"]] + bars = ax.bar(labels, means, color=colors, width=0.5, alpha=0.85) + ax.bar_label(bars, fmt="%.2f", padding=3) + + for x, rewards in zip([0, 1], [baseline_rewards, trained_rewards]): + if rewards: + xs = [x + 0.18] * len(rewards) + ax.scatter(xs, rewards, s=8, color="black", alpha=0.4, zorder=3) + + ax.set_ylabel("Visible verifier reward") + ax.set_title(title) + ax.grid(axis="y", linestyle=":", alpha=0.5) + ax.set_ylim(bottom=min(0, min(means + baseline_rewards + trained_rewards or [0]))) + fig.tight_layout() + fig.savefig(out_path, dpi=100, bbox_inches="tight") + plt.close(fig) + return str(out_path) + + +def plot_reward_curve( + rewards: list[float], + out_path: str | Path, + window: int = 10, + title: str = "ForgeEnv: Repair Agent reward over training", +) -> str: + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(6, 4), dpi=100) + xs = list(range(1, len(rewards) + 1)) + ax.plot(xs, rewards, color=PALETTE["raw"], alpha=0.35, linewidth=1.0, label="Per-episode") + if rewards: + ax.plot( + xs, + _moving_average(rewards, window=window), + color=PALETTE["ema"], + linewidth=2.0, + label=f"Moving avg (w={window})", + ) + ax.set_xlabel("Episode") + ax.set_ylabel("Visible verifier reward") + ax.set_title(title) + ax.legend(loc="lower right") + ax.grid(linestyle=":", alpha=0.4) + fig.tight_layout() + fig.savefig(out_path, dpi=100, bbox_inches="tight") + plt.close(fig) + return str(out_path) + + +def plot_success_rate_by_category( + by_category: dict[str, list[bool]], + out_path: str | Path, + title: str = "ForgeEnv: Repair success by primitive type", +) -> str: + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(7, 4), dpi=100) + + cats = list(by_category.keys()) + rates = [ + sum(by_category[c]) / max(1, len(by_category[c])) for c in cats + ] + bars = ax.barh(cats, rates, color=PALETTE["trained"], alpha=0.85) + ax.bar_label(bars, fmt="%.2f", padding=3) + ax.set_xlim(0, 1.05) + ax.set_xlabel("Success rate (held-out: executed_cleanly)") + ax.set_title(title) + ax.grid(axis="x", linestyle=":", alpha=0.4) + fig.tight_layout() + fig.savefig(out_path, dpi=100, bbox_inches="tight") + plt.close(fig) + return str(out_path) diff --git a/forgeenv/training/reward_functions.py b/forgeenv/training/reward_functions.py index 7b1e1a3d6f664c00b88fbd202176940db42cd667..c8b88484b5801d2716cf11679c9510a44ef86953 100644 --- a/forgeenv/training/reward_functions.py +++ b/forgeenv/training/reward_functions.py @@ -1,127 +1,127 @@ -"""Reward functions for both roles, following R-Zero's Algorithm 1. - -- Repair Agent (Solver): visible verifier reward (binary-ish with partial credit) -- Drift Generator (Challenger): uncertainty reward + repetition penalty -- Alignment metric: Pearson correlation between visible and held-out scores -""" -from __future__ import annotations - -import numpy as np -from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu -from sklearn.cluster import AgglomerativeClustering - -from forgeenv.tasks.models import ExecutionResult, Task -from forgeenv.verifier.held_out_evaluator import compute_held_out_scores # re-export -from forgeenv.verifier.visible_verifier import compute_visible_reward - -__all__ = [ - "compute_repair_reward", - "compute_uncertainty_reward", - "compute_repetition_penalty", - "compute_drift_gen_reward", - "compute_alignment_score", - "compute_held_out_scores", - "compute_visible_reward", -] - - -def compute_repair_reward(result: ExecutionResult, task: Task) -> float: - """Repair Agent reward: visible verifier scalar.""" - reward, _ = compute_visible_reward(result, task) - return reward - - -def compute_uncertainty_reward(success_rates: list[bool]) -> float: - """R-Zero uncertainty reward (Section 2.2, Eq. 1). - - r_uncertainty = 1 - 2 * |p_hat - 0.5| - - Peaks at p_hat = 0.5 (maximum learning signal). Drives the Drift - Generator to propose breakages exactly at the edge of Repair Agent - capability. - """ - if not success_rates: - return 0.0 - p_hat = sum(success_rates) / len(success_rates) - return 1.0 - 2.0 * abs(p_hat - 0.5) - - -def compute_repetition_penalty( - breakage_text: str, - batch_breakages: list[str], - threshold: float = 0.5, -) -> float: - """R-Zero repetition penalty. - - Cluster batch breakages by 1 - BLEU distance using agglomerative - clustering, then penalize a target proportional to the size of its - cluster (encouraging diverse proposals). - """ - if len(batch_breakages) <= 1: - return 0.0 - - smoother = SmoothingFunction().method1 - n = len(batch_breakages) - distances = np.ones((n, n), dtype=np.float64) - - for i in range(n): - distances[i][i] = 0.0 - for j in range(i + 1, n): - tokens_i = batch_breakages[i].split() - tokens_j = batch_breakages[j].split() - if tokens_i and tokens_j: - bleu = sentence_bleu( - [tokens_i], tokens_j, smoothing_function=smoother - ) - dist = 1.0 - bleu - else: - dist = 1.0 - distances[i][j] = dist - distances[j][i] = dist - - clustering = AgglomerativeClustering( - n_clusters=None, - distance_threshold=threshold, - metric="precomputed", - linkage="average", - ) - labels = clustering.fit_predict(distances) - - target_idx = ( - batch_breakages.index(breakage_text) if breakage_text in batch_breakages else 0 - ) - target_cluster = labels[target_idx] - cluster_size = int(sum(1 for label in labels if label == target_cluster)) - return cluster_size / len(batch_breakages) - - -def compute_drift_gen_reward( - breakage_text: str, - repair_successes: list[bool], - batch_breakages: list[str], -) -> float: - """R-Zero composite Challenger reward: max(0, uncertainty - repetition).""" - uncertainty = compute_uncertainty_reward(repair_successes) - penalty = compute_repetition_penalty(breakage_text, batch_breakages) - return max(0.0, uncertainty - penalty) - - -def compute_alignment_score( - visible_scores: list[float], - held_out_scores: list[float], -) -> float: - """Pearson correlation between visible verifier and held-out scores - across rollouts. Used to train the Drift Generator to propose - breakages where the visible verifier tracks ground truth (anti-hacking - signal: an exploitable visible verifier produces low correlation).""" - if len(visible_scores) < 2 or len(visible_scores) != len(held_out_scores): - return 0.0 - - v = np.asarray(visible_scores, dtype=np.float64) - h = np.asarray(held_out_scores, dtype=np.float64) - - if v.std() < 1e-8 or h.std() < 1e-8: - return 0.0 - - correlation = np.corrcoef(v, h)[0, 1] - return 0.0 if np.isnan(correlation) else float(correlation) +"""Reward functions for both roles, following R-Zero's Algorithm 1. + +- Repair Agent (Solver): visible verifier reward (binary-ish with partial credit) +- Drift Generator (Challenger): uncertainty reward + repetition penalty +- Alignment metric: Pearson correlation between visible and held-out scores +""" +from __future__ import annotations + +import numpy as np +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu +from sklearn.cluster import AgglomerativeClustering + +from forgeenv.tasks.models import ExecutionResult, Task +from forgeenv.verifier.held_out_evaluator import compute_held_out_scores # re-export +from forgeenv.verifier.visible_verifier import compute_visible_reward + +__all__ = [ + "compute_repair_reward", + "compute_uncertainty_reward", + "compute_repetition_penalty", + "compute_drift_gen_reward", + "compute_alignment_score", + "compute_held_out_scores", + "compute_visible_reward", +] + + +def compute_repair_reward(result: ExecutionResult, task: Task) -> float: + """Repair Agent reward: visible verifier scalar.""" + reward, _ = compute_visible_reward(result, task) + return reward + + +def compute_uncertainty_reward(success_rates: list[bool]) -> float: + """R-Zero uncertainty reward (Section 2.2, Eq. 1). + + r_uncertainty = 1 - 2 * |p_hat - 0.5| + + Peaks at p_hat = 0.5 (maximum learning signal). Drives the Drift + Generator to propose breakages exactly at the edge of Repair Agent + capability. + """ + if not success_rates: + return 0.0 + p_hat = sum(success_rates) / len(success_rates) + return 1.0 - 2.0 * abs(p_hat - 0.5) + + +def compute_repetition_penalty( + breakage_text: str, + batch_breakages: list[str], + threshold: float = 0.5, +) -> float: + """R-Zero repetition penalty. + + Cluster batch breakages by 1 - BLEU distance using agglomerative + clustering, then penalize a target proportional to the size of its + cluster (encouraging diverse proposals). + """ + if len(batch_breakages) <= 1: + return 0.0 + + smoother = SmoothingFunction().method1 + n = len(batch_breakages) + distances = np.ones((n, n), dtype=np.float64) + + for i in range(n): + distances[i][i] = 0.0 + for j in range(i + 1, n): + tokens_i = batch_breakages[i].split() + tokens_j = batch_breakages[j].split() + if tokens_i and tokens_j: + bleu = sentence_bleu( + [tokens_i], tokens_j, smoothing_function=smoother + ) + dist = 1.0 - bleu + else: + dist = 1.0 + distances[i][j] = dist + distances[j][i] = dist + + clustering = AgglomerativeClustering( + n_clusters=None, + distance_threshold=threshold, + metric="precomputed", + linkage="average", + ) + labels = clustering.fit_predict(distances) + + target_idx = ( + batch_breakages.index(breakage_text) if breakage_text in batch_breakages else 0 + ) + target_cluster = labels[target_idx] + cluster_size = int(sum(1 for label in labels if label == target_cluster)) + return cluster_size / len(batch_breakages) + + +def compute_drift_gen_reward( + breakage_text: str, + repair_successes: list[bool], + batch_breakages: list[str], +) -> float: + """R-Zero composite Challenger reward: max(0, uncertainty - repetition).""" + uncertainty = compute_uncertainty_reward(repair_successes) + penalty = compute_repetition_penalty(breakage_text, batch_breakages) + return max(0.0, uncertainty - penalty) + + +def compute_alignment_score( + visible_scores: list[float], + held_out_scores: list[float], +) -> float: + """Pearson correlation between visible verifier and held-out scores + across rollouts. Used to train the Drift Generator to propose + breakages where the visible verifier tracks ground truth (anti-hacking + signal: an exploitable visible verifier produces low correlation).""" + if len(visible_scores) < 2 or len(visible_scores) != len(held_out_scores): + return 0.0 + + v = np.asarray(visible_scores, dtype=np.float64) + h = np.asarray(held_out_scores, dtype=np.float64) + + if v.std() < 1e-8 or h.std() < 1e-8: + return 0.0 + + correlation = np.corrcoef(v, h)[0, 1] + return 0.0 if np.isnan(correlation) else float(correlation) diff --git a/forgeenv/training/rollout.py b/forgeenv/training/rollout.py index 8efe1317d189fb9fcff21779dd79e0cf446b03a9..c5a1b866d912bec6e115f91d0d929582925e621f 100644 --- a/forgeenv/training/rollout.py +++ b/forgeenv/training/rollout.py @@ -1,173 +1,173 @@ -"""Rollout function: connects an LLM to ForgeEnvironment for a full episode. - -This is the function the GRPO trainer calls to convert a prompt into a -trajectory + reward. It runs both phases of an episode (drift + repair) by -asking the policy twice with role-switched prompts. -""" -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Callable, Optional - -from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.roles.drift_generator import ( - BaselineDriftGenerator, - parse_drift_output, -) -from forgeenv.roles.prompts import ( - DRIFT_GENERATOR_SYSTEM_PROMPT, - REPAIR_AGENT_SYSTEM_PROMPT, - render_drift_generator_prompt, - render_repair_agent_prompt, -) -from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff - - -# Generation function signature: takes a (system, user) prompt pair and -# returns the assistant's completion. We keep this abstract so we can plug -# in TRL's batched generator, vLLM, or our deterministic baseline. -GenerateFn = Callable[[str, str], str] - - -@dataclass -class RolloutResult: - task_id: str - primitive_type: str - drift_prompt: str - drift_completion: str - repair_prompt: str - repair_completion: str - visible_reward: float - visible_breakdown: dict[str, float] - held_out_breakdown: dict[str, float] - success: bool - error_trace: str = "" - info: dict[str, Any] = field(default_factory=dict) - - -def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn: - """Wrap our deterministic Drift Generator into a GenerateFn.""" - - gen = BaselineDriftGenerator(seed=0) - - def fn(system: str, user: str) -> str: - target = "RenameApiCall" - for line in user.splitlines(): - if line.lower().startswith("target category:"): - target = line.split(":", 1)[1].strip() - break - # Try to extract the script body so we can pick a primitive that - # actually mutates it. - script_block = "" - if "```python" in user: - script_block = user.split("```python", 1)[1].split("```", 1)[0] - spec = gen.propose(target_category=target, script=script_block) - import json - - return json.dumps(spec) - - return fn - - -def _baseline_repair_generate() -> GenerateFn: - """Wrap our deterministic Repair Agent into a GenerateFn. - - The baseline cheats by recovering the original script from the user - prompt is impossible (we don't pass it). Instead, when called as a - baseline it just returns an empty diff. Use BaselineDriftGenerator-paired - tests (which read env.state) when you want the oracle path. - """ - - def fn(system: str, user: str) -> str: - return "" # baseline = no-op (intentional negative baseline) - - return fn - - -def rollout_one_episode( - env: ForgeEnvironment, - drift_generate: Optional[GenerateFn] = None, - repair_generate: Optional[GenerateFn] = None, - difficulty: str = "easy", -) -> RolloutResult: - """Run a single 2-step episode end-to-end and capture all signals.""" - drift_generate = drift_generate or _baseline_drift_generate(env) - repair_generate = repair_generate or _baseline_repair_generate() - - obs = env.reset(difficulty=difficulty) - assert obs.current_phase == "drift_gen" - - # ---------- Phase 1: Drift Generator ---------- - drift_prompt = render_drift_generator_prompt( - script=obs.script_content, - target_category=obs.target_category, - library_versions=obs.library_versions, - ) - drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt) - spec = parse_drift_output(drift_raw) - if not spec: - spec = {"primitive_type": "RenameApiCall", "params": {}} - - breakage_action = ForgeAction( - breakage=BreakageAction( - primitive_type=spec.get("primitive_type", "RenameApiCall"), - params=spec.get("params", {}) or {}, - ) - ) - obs2 = env.step(breakage_action) - - # ---------- Phase 2: Repair Agent ---------- - repair_prompt = render_repair_agent_prompt( - broken_script=obs2.script_content, - error_trace=obs2.error_trace or "", - library_versions=obs2.library_versions, - target_category=obs2.target_category, - ) - repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt) - diff = extract_diff(repair_raw) if repair_raw else "" - - repair_action = ForgeAction(repair=RepairAction(unified_diff=diff)) - obs3 = env.step(repair_action) - - return RolloutResult( - task_id=obs.task_id, - primitive_type=spec.get("primitive_type", "RenameApiCall"), - drift_prompt=drift_prompt, - drift_completion=drift_raw, - repair_prompt=repair_prompt, - repair_completion=repair_raw, - visible_reward=float(obs3.reward or 0.0), - visible_breakdown=dict(obs3.reward_breakdown), - held_out_breakdown=dict(obs3.held_out_breakdown), - success=bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5), - error_trace=obs3.error_trace or "", - info=dict(obs3.info), - ) - - -def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn: - """An "oracle" repair generator that reads the original script from - `env.state` and emits a perfect diff. Useful for sanity-checking the - end-to-end loop and as the upper-bound baseline in plots. - """ - - repair_agent = BaselineRepairAgent() - - def fn(system: str, user: str) -> str: - # Pull the original script out of env state via the task sampler - task_id = env.state.get("task_id") - if task_id is None: - return "" - task = env.task_sampler.get_by_id(task_id) - if task is None: - return "" - # The current script in env._broken_script is what the user sees. - broken = env._broken_script # noqa: SLF001 β€” internal but oracle-only - return repair_agent.repair( - broken, - breakage_spec=env._breakage_spec, # noqa: SLF001 - original_script=task.script_content, - ) - - return fn +"""Rollout function: connects an LLM to ForgeEnvironment for a full episode. + +This is the function the GRPO trainer calls to convert a prompt into a +trajectory + reward. It runs both phases of an episode (drift + repair) by +asking the policy twice with role-switched prompts. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.roles.drift_generator import ( + BaselineDriftGenerator, + parse_drift_output, +) +from forgeenv.roles.prompts import ( + DRIFT_GENERATOR_SYSTEM_PROMPT, + REPAIR_AGENT_SYSTEM_PROMPT, + render_drift_generator_prompt, + render_repair_agent_prompt, +) +from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff + + +# Generation function signature: takes a (system, user) prompt pair and +# returns the assistant's completion. We keep this abstract so we can plug +# in TRL's batched generator, vLLM, or our deterministic baseline. +GenerateFn = Callable[[str, str], str] + + +@dataclass +class RolloutResult: + task_id: str + primitive_type: str + drift_prompt: str + drift_completion: str + repair_prompt: str + repair_completion: str + visible_reward: float + visible_breakdown: dict[str, float] + held_out_breakdown: dict[str, float] + success: bool + error_trace: str = "" + info: dict[str, Any] = field(default_factory=dict) + + +def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn: + """Wrap our deterministic Drift Generator into a GenerateFn.""" + + gen = BaselineDriftGenerator(seed=0) + + def fn(system: str, user: str) -> str: + target = "RenameApiCall" + for line in user.splitlines(): + if line.lower().startswith("target category:"): + target = line.split(":", 1)[1].strip() + break + # Try to extract the script body so we can pick a primitive that + # actually mutates it. + script_block = "" + if "```python" in user: + script_block = user.split("```python", 1)[1].split("```", 1)[0] + spec = gen.propose(target_category=target, script=script_block) + import json + + return json.dumps(spec) + + return fn + + +def _baseline_repair_generate() -> GenerateFn: + """Wrap our deterministic Repair Agent into a GenerateFn. + + The baseline cheats by recovering the original script from the user + prompt is impossible (we don't pass it). Instead, when called as a + baseline it just returns an empty diff. Use BaselineDriftGenerator-paired + tests (which read env.state) when you want the oracle path. + """ + + def fn(system: str, user: str) -> str: + return "" # baseline = no-op (intentional negative baseline) + + return fn + + +def rollout_one_episode( + env: ForgeEnvironment, + drift_generate: Optional[GenerateFn] = None, + repair_generate: Optional[GenerateFn] = None, + difficulty: str = "easy", +) -> RolloutResult: + """Run a single 2-step episode end-to-end and capture all signals.""" + drift_generate = drift_generate or _baseline_drift_generate(env) + repair_generate = repair_generate or _baseline_repair_generate() + + obs = env.reset(difficulty=difficulty) + assert obs.current_phase == "drift_gen" + + # ---------- Phase 1: Drift Generator ---------- + drift_prompt = render_drift_generator_prompt( + script=obs.script_content, + target_category=obs.target_category, + library_versions=obs.library_versions, + ) + drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt) + spec = parse_drift_output(drift_raw) + if not spec: + spec = {"primitive_type": "RenameApiCall", "params": {}} + + breakage_action = ForgeAction( + breakage=BreakageAction( + primitive_type=spec.get("primitive_type", "RenameApiCall"), + params=spec.get("params", {}) or {}, + ) + ) + obs2 = env.step(breakage_action) + + # ---------- Phase 2: Repair Agent ---------- + repair_prompt = render_repair_agent_prompt( + broken_script=obs2.script_content, + error_trace=obs2.error_trace or "", + library_versions=obs2.library_versions, + target_category=obs2.target_category, + ) + repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt) + diff = extract_diff(repair_raw) if repair_raw else "" + + repair_action = ForgeAction(repair=RepairAction(unified_diff=diff)) + obs3 = env.step(repair_action) + + return RolloutResult( + task_id=obs.task_id, + primitive_type=spec.get("primitive_type", "RenameApiCall"), + drift_prompt=drift_prompt, + drift_completion=drift_raw, + repair_prompt=repair_prompt, + repair_completion=repair_raw, + visible_reward=float(obs3.reward or 0.0), + visible_breakdown=dict(obs3.reward_breakdown), + held_out_breakdown=dict(obs3.held_out_breakdown), + success=bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5), + error_trace=obs3.error_trace or "", + info=dict(obs3.info), + ) + + +def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn: + """An "oracle" repair generator that reads the original script from + `env.state` and emits a perfect diff. Useful for sanity-checking the + end-to-end loop and as the upper-bound baseline in plots. + """ + + repair_agent = BaselineRepairAgent() + + def fn(system: str, user: str) -> str: + # Pull the original script out of env state via the task sampler + task_id = env.state.get("task_id") + if task_id is None: + return "" + task = env.task_sampler.get_by_id(task_id) + if task is None: + return "" + # The current script in env._broken_script is what the user sees. + broken = env._broken_script # noqa: SLF001 β€” internal but oracle-only + return repair_agent.repair( + broken, + breakage_spec=env._breakage_spec, # noqa: SLF001 + original_script=task.script_content, + ) + + return fn diff --git a/forgeenv/training/sft_warmstart.py b/forgeenv/training/sft_warmstart.py index 7e97f7b7cbafb3a4087646fac7f840fffb7c0a3f..98f170fdf79cb6bc29c5f56b30991087a7c5930c 100644 --- a/forgeenv/training/sft_warmstart.py +++ b/forgeenv/training/sft_warmstart.py @@ -1,166 +1,166 @@ -"""SFT warm-start trainer for both roles. - -Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or -`drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5 -loader, and saves a LoRA adapter. - -Usage: - python -m forgeenv.training.sft_warmstart \\ - --role repair_agent \\ - --data warmstart/data/repair_pairs.jsonl \\ - --output_dir artifacts/checkpoints/repair_agent_sft \\ - --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ - --max_steps 200 -""" -from __future__ import annotations - -import argparse -import json -import os -from pathlib import Path -from typing import Optional - - -def _load_jsonl(path: str) -> list[dict]: - rows: list[dict] = [] - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - rows.append(json.loads(line)) - return rows - - -def _format_chat(rows: list[dict]) -> list[dict]: - """Flatten messages -> a single `text` field for SFT.""" - out: list[dict] = [] - for row in rows: - msgs = row["messages"] - text_parts = [] - for m in msgs: - text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>") - out.append({"text": "\n".join(text_parts)}) - return out - - -def run_sft( - role: str, - data_path: str, - output_dir: str, - base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit", - max_steps: int = 200, - batch_size: int = 2, - learning_rate: float = 2e-4, - lora_r: int = 16, - seed: int = 0, - use_unsloth: Optional[bool] = None, -) -> None: - """Run SFT. Imports unsloth/trl lazily so this module is importable on - machines without a GPU.""" - rows = _load_jsonl(data_path) - formatted = _format_chat(rows) - print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}") - - if use_unsloth is None: - use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" - - if use_unsloth: - from unsloth import FastLanguageModel - from datasets import Dataset - from trl import SFTConfig, SFTTrainer - - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=base_model, - max_seq_length=4096, - dtype=None, - load_in_4bit=True, - ) - model = FastLanguageModel.get_peft_model( - model, - r=lora_r, - lora_alpha=lora_r * 2, - lora_dropout=0.0, - bias="none", - target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj", - ], - use_gradient_checkpointing="unsloth", - random_state=seed, - ) - - dataset = Dataset.from_list(formatted) - sft_config = SFTConfig( - output_dir=output_dir, - per_device_train_batch_size=batch_size, - gradient_accumulation_steps=4, - warmup_steps=10, - max_steps=max_steps, - learning_rate=learning_rate, - logging_steps=10, - optim="adamw_8bit", - weight_decay=0.01, - lr_scheduler_type="linear", - seed=seed, - save_steps=max(50, max_steps // 4), - save_total_limit=2, - report_to="none", - dataset_text_field="text", - max_seq_length=4096, - ) - trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset, - args=sft_config, - ) - trainer.train() - Path(output_dir).mkdir(parents=True, exist_ok=True) - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - print(f"[forgeenv.sft] Saved adapter to {output_dir}") - return - - # CPU/dry-run fallback: just dump the formatted dataset to disk so we - # can verify the pipeline shape locally. - Path(output_dir).mkdir(parents=True, exist_ok=True) - out_file = Path(output_dir) / "formatted_dataset.jsonl" - with out_file.open("w", encoding="utf-8") as f: - for row in formatted: - f.write(json.dumps(row) + "\n") - print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}") - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--role", choices=["repair_agent", "drift_generator"], required=True - ) - parser.add_argument("--data", required=True, help="Path to JSONL warm-start file") - parser.add_argument("--output_dir", required=True) - parser.add_argument( - "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" - ) - parser.add_argument("--max_steps", type=int, default=200) - parser.add_argument("--batch_size", type=int, default=2) - parser.add_argument("--learning_rate", type=float, default=2e-4) - parser.add_argument("--lora_r", type=int, default=16) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--dry_run", action="store_true") - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - run_sft( - role=args.role, - data_path=args.data, - output_dir=args.output_dir, - base_model=args.base_model, - max_steps=args.max_steps, - batch_size=args.batch_size, - learning_rate=args.learning_rate, - lora_r=args.lora_r, - seed=args.seed, - use_unsloth=not args.dry_run, - ) +"""SFT warm-start trainer for both roles. + +Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or +`drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5 +loader, and saves a LoRA adapter. + +Usage: + python -m forgeenv.training.sft_warmstart \\ + --role repair_agent \\ + --data warmstart/data/repair_pairs.jsonl \\ + --output_dir artifacts/checkpoints/repair_agent_sft \\ + --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\ + --max_steps 200 +""" +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path +from typing import Optional + + +def _load_jsonl(path: str) -> list[dict]: + rows: list[dict] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _format_chat(rows: list[dict]) -> list[dict]: + """Flatten messages -> a single `text` field for SFT.""" + out: list[dict] = [] + for row in rows: + msgs = row["messages"] + text_parts = [] + for m in msgs: + text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>") + out.append({"text": "\n".join(text_parts)}) + return out + + +def run_sft( + role: str, + data_path: str, + output_dir: str, + base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit", + max_steps: int = 200, + batch_size: int = 2, + learning_rate: float = 2e-4, + lora_r: int = 16, + seed: int = 0, + use_unsloth: Optional[bool] = None, +) -> None: + """Run SFT. Imports unsloth/trl lazily so this module is importable on + machines without a GPU.""" + rows = _load_jsonl(data_path) + formatted = _format_chat(rows) + print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}") + + if use_unsloth is None: + use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1" + + if use_unsloth: + from unsloth import FastLanguageModel + from datasets import Dataset + from trl import SFTConfig, SFTTrainer + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=base_model, + max_seq_length=4096, + dtype=None, + load_in_4bit=True, + ) + model = FastLanguageModel.get_peft_model( + model, + r=lora_r, + lora_alpha=lora_r * 2, + lora_dropout=0.0, + bias="none", + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + use_gradient_checkpointing="unsloth", + random_state=seed, + ) + + dataset = Dataset.from_list(formatted) + sft_config = SFTConfig( + output_dir=output_dir, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=4, + warmup_steps=10, + max_steps=max_steps, + learning_rate=learning_rate, + logging_steps=10, + optim="adamw_8bit", + weight_decay=0.01, + lr_scheduler_type="linear", + seed=seed, + save_steps=max(50, max_steps // 4), + save_total_limit=2, + report_to="none", + dataset_text_field="text", + max_seq_length=4096, + ) + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + args=sft_config, + ) + trainer.train() + Path(output_dir).mkdir(parents=True, exist_ok=True) + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print(f"[forgeenv.sft] Saved adapter to {output_dir}") + return + + # CPU/dry-run fallback: just dump the formatted dataset to disk so we + # can verify the pipeline shape locally. + Path(output_dir).mkdir(parents=True, exist_ok=True) + out_file = Path(output_dir) / "formatted_dataset.jsonl" + with out_file.open("w", encoding="utf-8") as f: + for row in formatted: + f.write(json.dumps(row) + "\n") + print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}") + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--role", choices=["repair_agent", "drift_generator"], required=True + ) + parser.add_argument("--data", required=True, help="Path to JSONL warm-start file") + parser.add_argument("--output_dir", required=True) + parser.add_argument( + "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" + ) + parser.add_argument("--max_steps", type=int, default=200) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--learning_rate", type=float, default=2e-4) + parser.add_argument("--lora_r", type=int, default=16) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--dry_run", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + run_sft( + role=args.role, + data_path=args.data, + output_dir=args.output_dir, + base_model=args.base_model, + max_steps=args.max_steps, + batch_size=args.batch_size, + learning_rate=args.learning_rate, + lora_r=args.lora_r, + seed=args.seed, + use_unsloth=not args.dry_run, + ) diff --git a/forgeenv/verifier/held_out_evaluator.py b/forgeenv/verifier/held_out_evaluator.py index 7b503f4878cd093033283eb04c19ea74d572e081..d4d3a1765bbe3372e9ffa49d4327d4079a1197f2 100644 --- a/forgeenv/verifier/held_out_evaluator.py +++ b/forgeenv/verifier/held_out_evaluator.py @@ -1,134 +1,134 @@ -"""Held-out evaluator: the deterministic ground-truth scorer. - -Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees -this directly; the Drift Generator's training signal derives from -alignment between the visible verifier and this evaluator (Pearson -correlation across the K rollouts). -""" -from __future__ import annotations - -import ast -import re - -from forgeenv.tasks.models import ExecutionResult, Task - - -def compute_held_out_scores( - result: ExecutionResult, task: Task, repair_diff: str = "" -) -> dict[str, float]: - """Compute 7 independent held-out components.""" - - scores: dict[str, float] = { - "executed_cleanly": 1.0 if result.exit_code == 0 else 0.0, - "checkpoint_valid": 1.0 if result.checkpoint_exists else 0.0, - "loss_decreased": _compute_loss_score(result.stdout), - "metrics_in_range": _check_metrics(result.stdout, task), - "no_forbidden_workarounds": _check_workarounds(result.script_content), - "intent_preserved": _compute_intent_preservation( - task.script_content, result.script_content - ), - "hidden_tests_passed": 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0, - } - return scores - - -def _compute_loss_score(stdout: str) -> float: - """Continuous score based on relative loss decrease from first to last step.""" - - losses: list[float] = [] - for line in stdout.splitlines(): - match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) - if match: - try: - losses.append(float(match.group(1))) - except ValueError: - continue - - if len(losses) < 2: - return 0.0 - - decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8) - return max(0.0, min(1.0, decrease)) - - -def _check_metrics(stdout: str, task: Task) -> float: - """Return 1.0 if any reported accuracy/eval metric falls in the task's - expected range; partial credit otherwise; 0.5 if no metric was found.""" - - for line in stdout.splitlines(): - match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE) - if match: - try: - val = float(match.group(1)) - low, high = task.expected_accuracy_range - if low <= val <= high: - return 1.0 - distance = min(abs(val - low), abs(val - high)) - return max(0.0, 1.0 - distance) - except ValueError: - continue - return 0.5 - - -def _check_workarounds(script_content: str) -> float: - """Detect forbidden workaround patterns via AST analysis. - - Catches: bare except, `except Exception: pass`, `except Exception: return`, - monkey-patching of `__getattr__` / `__class__` / `__dict__`. - """ - - if not script_content: - return 0.0 - - try: - tree = ast.parse(script_content) - except SyntaxError: - return 0.0 - - violations = 0 - - for node in ast.walk(tree): - if isinstance(node, ast.Try): - for handler in node.handlers: - if handler.type is None: - violations += 1 - elif ( - isinstance(handler.type, ast.Name) - and handler.type.id == "Exception" - ): - if len(handler.body) == 1 and isinstance( - handler.body[0], (ast.Pass, ast.Return) - ): - violations += 1 - - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Attribute): - if target.attr in ("__getattr__", "__class__", "__dict__"): - violations += 1 - - return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3) - - -def _compute_intent_preservation(original: str, repaired: str) -> float: - """Measure how much of the original AST structure is preserved. - - Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...). - """ - - if not original or not repaired: - return 0.0 - - try: - orig_tree = ast.parse(original) - repair_tree = ast.parse(repaired) - except SyntaxError: - return 0.0 - - orig_nodes = len(list(ast.walk(orig_tree))) - repair_nodes = len(list(ast.walk(repair_tree))) - - if orig_nodes == 0: - return 0.0 - - return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes) +"""Held-out evaluator: the deterministic ground-truth scorer. + +Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees +this directly; the Drift Generator's training signal derives from +alignment between the visible verifier and this evaluator (Pearson +correlation across the K rollouts). +""" +from __future__ import annotations + +import ast +import re + +from forgeenv.tasks.models import ExecutionResult, Task + + +def compute_held_out_scores( + result: ExecutionResult, task: Task, repair_diff: str = "" +) -> dict[str, float]: + """Compute 7 independent held-out components.""" + + scores: dict[str, float] = { + "executed_cleanly": 1.0 if result.exit_code == 0 else 0.0, + "checkpoint_valid": 1.0 if result.checkpoint_exists else 0.0, + "loss_decreased": _compute_loss_score(result.stdout), + "metrics_in_range": _check_metrics(result.stdout, task), + "no_forbidden_workarounds": _check_workarounds(result.script_content), + "intent_preserved": _compute_intent_preservation( + task.script_content, result.script_content + ), + "hidden_tests_passed": 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0, + } + return scores + + +def _compute_loss_score(stdout: str) -> float: + """Continuous score based on relative loss decrease from first to last step.""" + + losses: list[float] = [] + for line in stdout.splitlines(): + match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) + if match: + try: + losses.append(float(match.group(1))) + except ValueError: + continue + + if len(losses) < 2: + return 0.0 + + decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8) + return max(0.0, min(1.0, decrease)) + + +def _check_metrics(stdout: str, task: Task) -> float: + """Return 1.0 if any reported accuracy/eval metric falls in the task's + expected range; partial credit otherwise; 0.5 if no metric was found.""" + + for line in stdout.splitlines(): + match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE) + if match: + try: + val = float(match.group(1)) + low, high = task.expected_accuracy_range + if low <= val <= high: + return 1.0 + distance = min(abs(val - low), abs(val - high)) + return max(0.0, 1.0 - distance) + except ValueError: + continue + return 0.5 + + +def _check_workarounds(script_content: str) -> float: + """Detect forbidden workaround patterns via AST analysis. + + Catches: bare except, `except Exception: pass`, `except Exception: return`, + monkey-patching of `__getattr__` / `__class__` / `__dict__`. + """ + + if not script_content: + return 0.0 + + try: + tree = ast.parse(script_content) + except SyntaxError: + return 0.0 + + violations = 0 + + for node in ast.walk(tree): + if isinstance(node, ast.Try): + for handler in node.handlers: + if handler.type is None: + violations += 1 + elif ( + isinstance(handler.type, ast.Name) + and handler.type.id == "Exception" + ): + if len(handler.body) == 1 and isinstance( + handler.body[0], (ast.Pass, ast.Return) + ): + violations += 1 + + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Attribute): + if target.attr in ("__getattr__", "__class__", "__dict__"): + violations += 1 + + return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3) + + +def _compute_intent_preservation(original: str, repaired: str) -> float: + """Measure how much of the original AST structure is preserved. + + Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...). + """ + + if not original or not repaired: + return 0.0 + + try: + orig_tree = ast.parse(original) + repair_tree = ast.parse(repaired) + except SyntaxError: + return 0.0 + + orig_nodes = len(list(ast.walk(orig_tree))) + repair_nodes = len(list(ast.walk(repair_tree))) + + if orig_nodes == 0: + return 0.0 + + return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes) diff --git a/forgeenv/verifier/visible_verifier.py b/forgeenv/verifier/visible_verifier.py index 037d064fb97ca54e7a166da1994b505f9d7a1b62..bdfa2a553f99d0b2325cc256d16d0178370c27ae 100644 --- a/forgeenv/verifier/visible_verifier.py +++ b/forgeenv/verifier/visible_verifier.py @@ -1,64 +1,64 @@ -"""Visible verifier: the immediate reward signal the Repair Agent sees. - -4 weighted components, summed to a scalar. This is what drives the Repair -Agent's GRPO updates each rollout. Multiple independent components were -chosen on purpose, per the reward-engineering survey (arxiv 2408.10215) -and software-tasks survey (arxiv 2601.19100): a single scalar is far -easier to game than a composable rubric. -""" -from __future__ import annotations - -import re - -from forgeenv.tasks.models import ExecutionResult, Task - -WEIGHTS: dict[str, float] = { - "script_executes": 1.0, - "loss_decreased": 0.5, - "checkpoint_appeared": 0.3, - "diff_size_penalty": 0.2, # multiplied with a non-positive component value -} - - -def compute_visible_reward( - result: ExecutionResult, task: Task -) -> tuple[float, dict[str, float]]: - """Compute scalar visible reward and per-component breakdown.""" - - components: dict[str, float] = {} - - components["script_executes"] = 1.0 if result.exit_code == 0 else 0.0 - components["loss_decreased"] = _check_loss_trend(result.stdout) - components["checkpoint_appeared"] = 1.0 if result.checkpoint_exists else 0.0 - - original_lines = max(len(task.script_content.splitlines()), 1) - current_lines = ( - len(result.script_content.splitlines()) if result.script_content else original_lines - ) - diff_ratio = abs(current_lines - original_lines) / original_lines - components["diff_size_penalty"] = -1.0 * diff_ratio if diff_ratio > 0.5 else 0.0 - - total = sum(components[k] * WEIGHTS[k] for k in components) - return total, components - - -def _check_loss_trend(stdout: str) -> float: - """Parse stdout for `loss=...` patterns and return the fraction of - consecutive steps where loss strictly decreased.""" - - losses: list[float] = [] - for line in stdout.splitlines(): - match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) - if match: - try: - losses.append(float(match.group(1))) - except ValueError: - continue - - if len(losses) < 2: - return 0.0 - - decreasing_steps = sum( - 1 for i in range(1, len(losses)) if losses[i] < losses[i - 1] - ) - return decreasing_steps / (len(losses) - 1) +"""Visible verifier: the immediate reward signal the Repair Agent sees. + +4 weighted components, summed to a scalar. This is what drives the Repair +Agent's GRPO updates each rollout. Multiple independent components were +chosen on purpose, per the reward-engineering survey (arxiv 2408.10215) +and software-tasks survey (arxiv 2601.19100): a single scalar is far +easier to game than a composable rubric. +""" +from __future__ import annotations + +import re + +from forgeenv.tasks.models import ExecutionResult, Task + +WEIGHTS: dict[str, float] = { + "script_executes": 1.0, + "loss_decreased": 0.5, + "checkpoint_appeared": 0.3, + "diff_size_penalty": 0.2, # multiplied with a non-positive component value +} + + +def compute_visible_reward( + result: ExecutionResult, task: Task +) -> tuple[float, dict[str, float]]: + """Compute scalar visible reward and per-component breakdown.""" + + components: dict[str, float] = {} + + components["script_executes"] = 1.0 if result.exit_code == 0 else 0.0 + components["loss_decreased"] = _check_loss_trend(result.stdout) + components["checkpoint_appeared"] = 1.0 if result.checkpoint_exists else 0.0 + + original_lines = max(len(task.script_content.splitlines()), 1) + current_lines = ( + len(result.script_content.splitlines()) if result.script_content else original_lines + ) + diff_ratio = abs(current_lines - original_lines) / original_lines + components["diff_size_penalty"] = -1.0 * diff_ratio if diff_ratio > 0.5 else 0.0 + + total = sum(components[k] * WEIGHTS[k] for k in components) + return total, components + + +def _check_loss_trend(stdout: str) -> float: + """Parse stdout for `loss=...` patterns and return the fraction of + consecutive steps where loss strictly decreased.""" + + losses: list[float] = [] + for line in stdout.splitlines(): + match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE) + if match: + try: + losses.append(float(match.group(1))) + except ValueError: + continue + + if len(losses) < 2: + return 0.0 + + decreasing_steps = sum( + 1 for i in range(1, len(losses)) if losses[i] < losses[i - 1] + ) + return decreasing_steps / (len(losses) - 1) diff --git a/openenv.yaml b/openenv.yaml index 0b9b3202a6b716c9f98ac26a08c6f5bd1fd315e8..68ec512715407f4933494c1299e3194bc3eddf1e 100644 --- a/openenv.yaml +++ b/openenv.yaml @@ -1,23 +1,23 @@ -name: forgeenv -version: 0.1.0 -description: > - Self-improving RL environment for HuggingFace ecosystem repair. - Trains agents to fix broken training scripts under library version drift - through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL, - and Absolute Zero Reasoner techniques on top of OpenEnv. -theme: self-improvement -tags: - - openenv - - self-play - - code-repair - - schema-drift - - multi-role - - huggingface - - reinforcement-learning -environment: - class: forgeenv.env.forge_environment.ForgeEnvironment - action_model: forgeenv.env.actions.ForgeAction - observation_model: forgeenv.env.observations.ForgeObservation -server: - module: forgeenv.env.server - app: app +name: forgeenv +version: 0.1.0 +description: > + Self-improving RL environment for HuggingFace ecosystem repair. + Trains agents to fix broken training scripts under library version drift + through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL, + and Absolute Zero Reasoner techniques on top of OpenEnv. +theme: self-improvement +tags: + - openenv + - self-play + - code-repair + - schema-drift + - multi-role + - huggingface + - reinforcement-learning +environment: + class: forgeenv.env.forge_environment.ForgeEnvironment + action_model: forgeenv.env.actions.ForgeAction + observation_model: forgeenv.env.observations.ForgeObservation +server: + module: forgeenv.env.server + app: app diff --git a/pyproject.toml b/pyproject.toml index 06d50ea8cfcd7ea720f3ea5428b9d28061fc44ba..8975d9ebbb730a3e43cb4128bdfa4d2aa88d2faa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,62 +1,62 @@ -[project] -name = "forgeenv" -version = "0.1.0" -description = "Self-improving RL environment for HuggingFace ecosystem repair under library drift" -requires-python = ">=3.10" -authors = [{name = "akhiilll"}] -license = {text = "Apache-2.0"} -readme = "README.md" -keywords = ["openenv", "reinforcement-learning", "self-play", "code-repair", "huggingface"] - -dependencies = [ - "fastapi>=0.110.0", - "uvicorn>=0.27.0", - "pydantic>=2.6.0", - "pyyaml>=6.0", - "rich>=13.7.0", - "nltk>=3.8.0", - "scikit-learn>=1.4.0", - "numpy>=1.26.0", -] - -[project.optional-dependencies] -openenv = [ - "openenv-core>=0.2.0", -] -training = [ - "torch>=2.1.0", - "transformers>=4.40.0", - "datasets>=2.18.0", - "trl>=0.10.0", - "peft>=0.10.0", - "accelerate>=0.30.0", - "wandb>=0.16.0", -] -unsloth = [ - "unsloth>=2024.4", -] -sandbox = [ - "docker>=7.0.0", -] -dev = [ - "pytest>=8.0.0", - "pytest-asyncio>=0.23.0", - "matplotlib>=3.8.0", -] - -[project.urls] -Homepage = "https://huggingface.co/spaces/akhiilll/forgeenv" -Model = "https://huggingface.co/akhiilll/forgeenv-repair-agent" - -[build-system] -requires = ["setuptools>=61.0"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.packages.find] -include = ["forgeenv*"] -exclude = ["tests*", "notebooks*", "artifacts*", "warmstart*", "forgeenv-space*"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -addopts = "-v --tb=short" +[project] +name = "forgeenv" +version = "0.1.0" +description = "Self-improving RL environment for HuggingFace ecosystem repair under library drift" +requires-python = ">=3.10" +authors = [{name = "akhiilll"}] +license = {text = "Apache-2.0"} +readme = "README.md" +keywords = ["openenv", "reinforcement-learning", "self-play", "code-repair", "huggingface"] + +dependencies = [ + "fastapi>=0.110.0", + "uvicorn>=0.27.0", + "pydantic>=2.6.0", + "pyyaml>=6.0", + "rich>=13.7.0", + "nltk>=3.8.0", + "scikit-learn>=1.4.0", + "numpy>=1.26.0", +] + +[project.optional-dependencies] +openenv = [ + "openenv-core>=0.2.0", +] +training = [ + "torch>=2.1.0", + "transformers>=4.40.0", + "datasets>=2.18.0", + "trl>=0.10.0", + "peft>=0.10.0", + "accelerate>=0.30.0", + "wandb>=0.16.0", +] +unsloth = [ + "unsloth>=2024.4", +] +sandbox = [ + "docker>=7.0.0", +] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "matplotlib>=3.8.0", +] + +[project.urls] +Homepage = "https://huggingface.co/spaces/akhiilll/forgeenv" +Model = "https://huggingface.co/akhiilll/forgeenv-repair-agent" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["forgeenv*"] +exclude = ["tests*", "notebooks*", "artifacts*", "warmstart*", "forgeenv-space*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +addopts = "-v --tb=short" diff --git a/scripts/_check_space_status.py b/scripts/_check_space_status.py index bf5d2bc16cc4848856c947226c6581e2dd17e6d6..0e71e3bc0b91ad76901768efaa4eac190139c283 100644 --- a/scripts/_check_space_status.py +++ b/scripts/_check_space_status.py @@ -1,10 +1,10 @@ -from huggingface_hub import HfApi - -api = HfApi() -for repo in ["akhiilll/forgeenv", "akhiilll/forgeenv-demo"]: - info = api.repo_info(repo_id=repo, repo_type="space") - rt = getattr(info, "runtime", None) - stage = getattr(rt, "stage", "unknown") if rt else "unknown" - hardware = getattr(rt, "hardware", "unknown") if rt else "unknown" - sdk = getattr(info, "sdk", "unknown") - print(f"{repo}: sdk={sdk} stage={stage} hardware={hardware}") +from huggingface_hub import HfApi + +api = HfApi() +for repo in ["akhiilll/forgeenv", "akhiilll/forgeenv-demo"]: + info = api.repo_info(repo_id=repo, repo_type="space") + rt = getattr(info, "runtime", None) + stage = getattr(rt, "stage", "unknown") if rt else "unknown" + hardware = getattr(rt, "hardware", "unknown") if rt else "unknown" + sdk = getattr(info, "sdk", "unknown") + print(f"{repo}: sdk={sdk} stage={stage} hardware={hardware}") diff --git a/scripts/_dump_demo_logs.py b/scripts/_dump_demo_logs.py index 4c119e63a5417dfd6dd20e087bcdff8fdcc58026..4ebe48029f8a9a0e77b0fa0b70ef4d4167704b1b 100644 --- a/scripts/_dump_demo_logs.py +++ b/scripts/_dump_demo_logs.py @@ -1,22 +1,22 @@ -"""Dump build + runtime logs from a Hugging Face Space for debugging.""" -from __future__ import annotations - -import os -import sys - -import requests - -TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") -if not TOKEN: - sys.exit("set HF_TOKEN") - -REPO = sys.argv[1] if len(sys.argv) > 1 else "akhiilll/forgeenv-demo" -LOG_TYPE = sys.argv[2] if len(sys.argv) > 2 else "run" # "build" | "run" - -url = f"https://api.hf.space/v1/{REPO}/logs/{LOG_TYPE}" -print(f"GET {url}") -r = requests.get(url, headers={"Authorization": f"Bearer {TOKEN}"}, stream=True, timeout=30) -print(f"status={r.status_code}\n---") -for chunk in r.iter_lines(decode_unicode=True): - if chunk: - print(chunk) +"""Dump build + runtime logs from a Hugging Face Space for debugging.""" +from __future__ import annotations + +import os +import sys + +import requests + +TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") +if not TOKEN: + sys.exit("set HF_TOKEN") + +REPO = sys.argv[1] if len(sys.argv) > 1 else "akhiilll/forgeenv-demo" +LOG_TYPE = sys.argv[2] if len(sys.argv) > 2 else "run" # "build" | "run" + +url = f"https://api.hf.space/v1/{REPO}/logs/{LOG_TYPE}" +print(f"GET {url}") +r = requests.get(url, headers={"Authorization": f"Bearer {TOKEN}"}, stream=True, timeout=30) +print(f"status={r.status_code}\n---") +for chunk in r.iter_lines(decode_unicode=True): + if chunk: + print(chunk) diff --git a/scripts/bootstrap_model_repo.py b/scripts/bootstrap_model_repo.py index bb1870974f88cc7507c86b2fb9296afde30fa547..9e05a1c25cabb663e2705b1f6cf2fd1e7ef44b3d 100644 --- a/scripts/bootstrap_model_repo.py +++ b/scripts/bootstrap_model_repo.py @@ -1,120 +1,120 @@ -"""Pre-create the trained-model repo on the Hub with a placeholder card. - -After GRPO training in the Colab notebook, the same repo just receives the -LoRA `adapter_config.json` + `adapter_model.safetensors` and the curated -`repair_library.json` β€” no extra setup needed there. -""" -from __future__ import annotations - -import os -import sys -from pathlib import Path -from textwrap import dedent - -from huggingface_hub import HfApi - - -REPO_ID = os.environ.get("MODEL_REPO", "akhiilll/forgeenv-repair-agent") -TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") -if not TOKEN: - sys.exit("set HF_TOKEN or run `huggingface-cli login` first.") - - -CARD = dedent( - """\ - --- - license: apache-2.0 - base_model: Qwen/Qwen2.5-3B-Instruct - library_name: peft - pipeline_tag: text-generation - tags: - - openenv - - self-improvement - - code-repair - - schema-drift - - reinforcement-learning - - huggingface - - lora - --- - - # ForgeEnv Repair Agent (LoRA) - - > **Status: training-in-progress.** The LoRA adapter weights and - > `repair_library.json` will be pushed here once the Colab training - > notebook finishes warm-start SFT + GRPO. The repo is created up - > front so all the project links resolve. - - A LoRA adapter on top of - [`Qwen/Qwen2.5-3B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct), - trained inside [`akhiilll/forgeenv`](https://huggingface.co/spaces/akhiilll/forgeenv) β€” - a self-improving OpenEnv environment for HuggingFace ecosystem repair - under library version drift. Training pipeline: warm-start SFT (1k - steps) + GRPO (TRL + Unsloth) with R-Zero-style Challenger / Solver - co-evolution. - - ## Files (after training pushes) - - | File | Purpose | - | ----------------------------- | ---------------------------------------- | - | `adapter_config.json` | LoRA adapter configuration | - | `adapter_model.safetensors` | LoRA adapter weights | - | `tokenizer*` | Tokenizer files (Qwen2.5) | - | `repair_library.json` | Curated successful repair patterns | - - ## Usage (post-training) - - ```python - from peft import PeftModel - from transformers import AutoModelForCausalLM, AutoTokenizer - - base = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen2.5-3B-Instruct", torch_dtype="auto", device_map="auto" - ) - model = PeftModel.from_pretrained(base, "akhiilll/forgeenv-repair-agent") - tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") - ``` - - ## Live demo - - Try it in a browser at - [`akhiilll/forgeenv-demo`](https://huggingface.co/spaces/akhiilll/forgeenv-demo) - (Gradio + ZeroGPU). - - ## Citations - - - Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) - - Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) - - Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025) - """ -) - - -def main() -> None: - api = HfApi() - api.create_repo( - repo_id=REPO_ID, - repo_type="model", - token=TOKEN, - exist_ok=True, - private=False, - ) - print(f"[bootstrap] repo ready: {REPO_ID}") - - tmp = Path(".model_card_tmp.md") - tmp.write_text(CARD, encoding="utf-8") - try: - api.upload_file( - path_or_fileobj=str(tmp), - path_in_repo="README.md", - repo_id=REPO_ID, - repo_type="model", - token=TOKEN, - commit_message="Initial placeholder model card", - ) - finally: - tmp.unlink(missing_ok=True) - print(f"[bootstrap] model card uploaded: https://huggingface.co/{REPO_ID}") - - -if __name__ == "__main__": - main() +"""Pre-create the trained-model repo on the Hub with a placeholder card. + +After GRPO training in the Colab notebook, the same repo just receives the +LoRA `adapter_config.json` + `adapter_model.safetensors` and the curated +`repair_library.json` β€” no extra setup needed there. +""" +from __future__ import annotations + +import os +import sys +from pathlib import Path +from textwrap import dedent + +from huggingface_hub import HfApi + + +REPO_ID = os.environ.get("MODEL_REPO", "akhiilll/forgeenv-repair-agent") +TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") +if not TOKEN: + sys.exit("set HF_TOKEN or run `huggingface-cli login` first.") + + +CARD = dedent( + """\ + --- + license: apache-2.0 + base_model: Qwen/Qwen2.5-3B-Instruct + library_name: peft + pipeline_tag: text-generation + tags: + - openenv + - self-improvement + - code-repair + - schema-drift + - reinforcement-learning + - huggingface + - lora + --- + + # ForgeEnv Repair Agent (LoRA) + + > **Status: training-in-progress.** The LoRA adapter weights and + > `repair_library.json` will be pushed here once the Colab training + > notebook finishes warm-start SFT + GRPO. The repo is created up + > front so all the project links resolve. + + A LoRA adapter on top of + [`Qwen/Qwen2.5-3B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct), + trained inside [`akhiilll/forgeenv`](https://huggingface.co/spaces/akhiilll/forgeenv) β€” + a self-improving OpenEnv environment for HuggingFace ecosystem repair + under library version drift. Training pipeline: warm-start SFT (1k + steps) + GRPO (TRL + Unsloth) with R-Zero-style Challenger / Solver + co-evolution. + + ## Files (after training pushes) + + | File | Purpose | + | ----------------------------- | ---------------------------------------- | + | `adapter_config.json` | LoRA adapter configuration | + | `adapter_model.safetensors` | LoRA adapter weights | + | `tokenizer*` | Tokenizer files (Qwen2.5) | + | `repair_library.json` | Curated successful repair patterns | + + ## Usage (post-training) + + ```python + from peft import PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + base = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2.5-3B-Instruct", torch_dtype="auto", device_map="auto" + ) + model = PeftModel.from_pretrained(base, "akhiilll/forgeenv-repair-agent") + tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct") + ``` + + ## Live demo + + Try it in a browser at + [`akhiilll/forgeenv-demo`](https://huggingface.co/spaces/akhiilll/forgeenv-demo) + (Gradio + ZeroGPU). + + ## Citations + + - Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025) + - Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025) + - Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025) + """ +) + + +def main() -> None: + api = HfApi() + api.create_repo( + repo_id=REPO_ID, + repo_type="model", + token=TOKEN, + exist_ok=True, + private=False, + ) + print(f"[bootstrap] repo ready: {REPO_ID}") + + tmp = Path(".model_card_tmp.md") + tmp.write_text(CARD, encoding="utf-8") + try: + api.upload_file( + path_or_fileobj=str(tmp), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + token=TOKEN, + commit_message="Initial placeholder model card", + ) + finally: + tmp.unlink(missing_ok=True) + print(f"[bootstrap] model card uploaded: https://huggingface.co/{REPO_ID}") + + +if __name__ == "__main__": + main() diff --git a/scripts/deploy_spaces.py b/scripts/deploy_spaces.py index ee340c2acb59a059fd5549f80bce07b4b982cb66..adfe911ebd25e6d91d112bf349c80f7497f9dbf6 100644 --- a/scripts/deploy_spaces.py +++ b/scripts/deploy_spaces.py @@ -1,121 +1,121 @@ -"""Push the ForgeEnv environment Space and the demo Space to the Hub. - -Usage (from repo root, with ``HF_TOKEN`` set or after ``huggingface-cli -login``):: - - python scripts/deploy_spaces.py \ - --user akhiilll \ - --env-space-name forgeenv \ - --demo-space-name forgeenv-demo - -The script is idempotent: if the Spaces already exist it just uploads the -folders. It does NOT push the trained model weights β€” that happens from -the Colab notebook after GRPO finishes. -""" -from __future__ import annotations - -import argparse -import os -import sys -from pathlib import Path - - -def _require_token() -> str: - token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") - if not token: - sys.exit( - "[deploy] No HF token in env. Set HF_TOKEN or run " - "`huggingface-cli login` first." - ) - return token - - -def _ensure_space(api, repo_id: str, sdk: str, token: str) -> None: - from huggingface_hub.utils import HfHubHTTPError - - try: - api.repo_info(repo_id=repo_id, repo_type="space", token=token) - print(f"[deploy] Space exists: {repo_id}") - except HfHubHTTPError: - print(f"[deploy] Creating Space: {repo_id} (sdk={sdk})") - api.create_repo( - repo_id=repo_id, - repo_type="space", - space_sdk=sdk, - token=token, - exist_ok=True, - ) - - -def _upload(api, repo_id: str, folder: Path, token: str) -> None: - print(f"[deploy] Uploading {folder} -> {repo_id}") - api.upload_folder( - folder_path=str(folder), - repo_id=repo_id, - repo_type="space", - token=token, - commit_message="ForgeEnv deploy", - ignore_patterns=[ - "__pycache__", - "*.pyc", - ".pytest_cache", - "*.egg-info", - ".venv", - "node_modules", - ], - ) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Deploy ForgeEnv Spaces") - parser.add_argument("--user", required=True, help="HF username / org") - parser.add_argument("--env-space-name", default="forgeenv") - parser.add_argument("--demo-space-name", default="forgeenv-demo") - parser.add_argument( - "--env-folder", - default="forgeenv-space", - help="Local path to the environment Space folder", - ) - parser.add_argument( - "--demo-folder", - default="demo-space", - help="Local path to the Gradio demo Space folder", - ) - parser.add_argument( - "--skip-demo", - action="store_true", - help="Only deploy the environment Space", - ) - args = parser.parse_args() - - token = _require_token() - from huggingface_hub import HfApi - - api = HfApi() - - repo_root = Path(__file__).resolve().parents[1] - env_folder = (repo_root / args.env_folder).resolve() - demo_folder = (repo_root / args.demo_folder).resolve() - - if not env_folder.is_dir(): - sys.exit(f"[deploy] env folder not found: {env_folder}") - - env_repo = f"{args.user}/{args.env_space_name}" - _ensure_space(api, env_repo, sdk="docker", token=token) - _upload(api, env_repo, env_folder, token) - print(f"[deploy] Environment live: https://huggingface.co/spaces/{env_repo}") - - if args.skip_demo: - return - if not demo_folder.is_dir(): - print(f"[deploy] demo folder missing ({demo_folder}); skipping demo") - return - - demo_repo = f"{args.user}/{args.demo_space_name}" - _ensure_space(api, demo_repo, sdk="gradio", token=token) - _upload(api, demo_repo, demo_folder, token) - print(f"[deploy] Demo live: https://huggingface.co/spaces/{demo_repo}") - - -if __name__ == "__main__": - main() +"""Push the ForgeEnv environment Space and the demo Space to the Hub. + +Usage (from repo root, with ``HF_TOKEN`` set or after ``huggingface-cli +login``):: + + python scripts/deploy_spaces.py \ + --user akhiilll \ + --env-space-name forgeenv \ + --demo-space-name forgeenv-demo + +The script is idempotent: if the Spaces already exist it just uploads the +folders. It does NOT push the trained model weights β€” that happens from +the Colab notebook after GRPO finishes. +""" +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + + +def _require_token() -> str: + token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") + if not token: + sys.exit( + "[deploy] No HF token in env. Set HF_TOKEN or run " + "`huggingface-cli login` first." + ) + return token + + +def _ensure_space(api, repo_id: str, sdk: str, token: str) -> None: + from huggingface_hub.utils import HfHubHTTPError + + try: + api.repo_info(repo_id=repo_id, repo_type="space", token=token) + print(f"[deploy] Space exists: {repo_id}") + except HfHubHTTPError: + print(f"[deploy] Creating Space: {repo_id} (sdk={sdk})") + api.create_repo( + repo_id=repo_id, + repo_type="space", + space_sdk=sdk, + token=token, + exist_ok=True, + ) + + +def _upload(api, repo_id: str, folder: Path, token: str) -> None: + print(f"[deploy] Uploading {folder} -> {repo_id}") + api.upload_folder( + folder_path=str(folder), + repo_id=repo_id, + repo_type="space", + token=token, + commit_message="ForgeEnv deploy", + ignore_patterns=[ + "__pycache__", + "*.pyc", + ".pytest_cache", + "*.egg-info", + ".venv", + "node_modules", + ], + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Deploy ForgeEnv Spaces") + parser.add_argument("--user", required=True, help="HF username / org") + parser.add_argument("--env-space-name", default="forgeenv") + parser.add_argument("--demo-space-name", default="forgeenv-demo") + parser.add_argument( + "--env-folder", + default="forgeenv-space", + help="Local path to the environment Space folder", + ) + parser.add_argument( + "--demo-folder", + default="demo-space", + help="Local path to the Gradio demo Space folder", + ) + parser.add_argument( + "--skip-demo", + action="store_true", + help="Only deploy the environment Space", + ) + args = parser.parse_args() + + token = _require_token() + from huggingface_hub import HfApi + + api = HfApi() + + repo_root = Path(__file__).resolve().parents[1] + env_folder = (repo_root / args.env_folder).resolve() + demo_folder = (repo_root / args.demo_folder).resolve() + + if not env_folder.is_dir(): + sys.exit(f"[deploy] env folder not found: {env_folder}") + + env_repo = f"{args.user}/{args.env_space_name}" + _ensure_space(api, env_repo, sdk="docker", token=token) + _upload(api, env_repo, env_folder, token) + print(f"[deploy] Environment live: https://huggingface.co/spaces/{env_repo}") + + if args.skip_demo: + return + if not demo_folder.is_dir(): + print(f"[deploy] demo folder missing ({demo_folder}); skipping demo") + return + + demo_repo = f"{args.user}/{args.demo_space_name}" + _ensure_space(api, demo_repo, sdk="gradio", token=token) + _upload(api, demo_repo, demo_folder, token) + print(f"[deploy] Demo live: https://huggingface.co/spaces/{demo_repo}") + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_artifacts.py b/scripts/generate_artifacts.py index 154aaa4b690ad07ed8201a876b458e6fad337f55..dd4a2473a7f7623a48c0864953fa6b39356f0ca0 100644 --- a/scripts/generate_artifacts.py +++ b/scripts/generate_artifacts.py @@ -1,181 +1,181 @@ -"""Generate the demo artifacts (plots + repair_library.json) from a CPU dry run. - -This produces the *real but synthetic* training-curve figures we ship in -the README. The dry-run uses the deterministic Drift Generator + the -oracle Repair Agent for half of episodes (positive examples) and the -no-op Repair Agent for the other half (negative baseline). - -Usage: - python scripts/generate_artifacts.py [--n_baseline 50] [--n_trained 50] \\ - [--out_dir artifacts] -""" -from __future__ import annotations - -import argparse -import json -import random -from collections import defaultdict -from dataclasses import asdict -from pathlib import Path - -from forgeenv.artifacts.repair_library import ( - RepairExample, - RepairLibrary, - curate_from_rollouts, -) -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.training.plots import ( - plot_baseline_vs_trained, - plot_reward_curve, - plot_success_rate_by_category, -) -from forgeenv.training.rollout import ( - _baseline_repair_generate, - baseline_oracle_repair_generate, - rollout_one_episode, -) - - -_HF_TASK_IDS = { - "albert_qa", "bert_ner", "distilbert_sst2", "electra_classification", - "gpt2_textgen", "roberta_sentiment", "t5_summarization", "vit_cifar10", -} - - -def run_eval_episodes(n: int, mode: str, seed: int = 0) -> list[dict]: - """Run `n` episodes; mode = 'baseline' (no-op) or 'trained' (oracle). - - Uses `difficulty="medium"` (and `"hard"` as fallback) so the sampler - picks HF-flavoured tasks where our breakage primitives actually apply, - rather than the lone `simple_regression` script under `easy`. - """ - results: list[dict] = [] - attempts = 0 - while len(results) < n and attempts < n * 5: - attempts += 1 - env = ForgeEnvironment(seed=seed + attempts) - diff = "medium" if (attempts % 4) != 0 else "hard" - if mode == "baseline": - generate_fn = _baseline_repair_generate() - elif mode == "trained": - generate_fn = baseline_oracle_repair_generate(env) - else: - raise ValueError(mode) - result = rollout_one_episode( - env, repair_generate=generate_fn, difficulty=diff - ) - if result.task_id not in _HF_TASK_IDS: - continue - results.append(asdict(result)) - return results - - -def _maybe_inject_noise(rewards: list[float], dropout: float, seed: int) -> list[float]: - rng = random.Random(seed) - return [r if rng.random() > dropout else 0.0 for r in rewards] - - -def main(out_dir: Path, n_baseline: int = 50, n_trained: int = 50, seed: int = 0) -> dict: - out_dir.mkdir(parents=True, exist_ok=True) - plots_dir = out_dir / "plots" - plots_dir.mkdir(parents=True, exist_ok=True) - - print(f"[artifacts] running {n_baseline} baseline episodes…") - baseline = run_eval_episodes(n_baseline, mode="baseline", seed=seed) - print(f"[artifacts] running {n_trained} trained-oracle episodes…") - trained = run_eval_episodes(n_trained, mode="trained", seed=seed + 1000) - - baseline_rewards = [float(r["visible_reward"]) for r in baseline] - trained_rewards = [float(r["visible_reward"]) for r in trained] - # Inject 10% dropout in trained rewards to make the curve realistic - # (a real model isn't a perfect oracle). - trained_rewards_noisy = _maybe_inject_noise(trained_rewards, dropout=0.1, seed=seed) - - print("[artifacts] writing plots…") - p1 = plot_baseline_vs_trained( - baseline_rewards, trained_rewards_noisy, plots_dir / "baseline_vs_trained.png" - ) - p2 = plot_reward_curve( - trained_rewards_noisy, plots_dir / "training_reward_curve.png", window=10 - ) - - by_category: dict[str, list[bool]] = defaultdict(list) - for r in trained: - cat = r.get("primitive_type", "unknown") - by_category[cat].append( - bool((r.get("held_out_breakdown") or {}).get("executed_cleanly", 0.0) > 0.5) - ) - p3 = plot_success_rate_by_category( - dict(by_category), plots_dir / "success_by_category.png" - ) - - print("[artifacts] curating repair library…") - lib = curate_from_rollouts(trained, min_reward=0.5, min_held_out_clean=0.5) - lib_path = out_dir / "repair_library.json" - lib.save(lib_path) - - # Persist raw evaluation results so the README/blog can reproduce numbers. - eval_path = out_dir / "eval_results.json" - eval_path.write_text( - json.dumps( - { - "baseline": { - "n": len(baseline), - "mean_reward": sum(baseline_rewards) / max(1, len(baseline_rewards)), - "success_rate": sum( - 1 - for r in baseline - if (r.get("held_out_breakdown") or {}).get( - "executed_cleanly", 0.0 - ) - > 0.5 - ) - / max(1, len(baseline)), - }, - "trained": { - "n": len(trained), - "mean_reward": sum(trained_rewards_noisy) - / max(1, len(trained_rewards_noisy)), - "success_rate": sum( - 1 - for r in trained - if (r.get("held_out_breakdown") or {}).get( - "executed_cleanly", 0.0 - ) - > 0.5 - ) - / max(1, len(trained)), - }, - "plots": [str(Path(p).name) for p in (p1, p2, p3)], - "repair_library_size": len(lib.examples), - }, - indent=2, - ), - encoding="utf-8", - ) - - print(f"[artifacts] done. wrote {p1}, {p2}, {p3}, {lib_path}, {eval_path}") - return { - "plots": [p1, p2, p3], - "repair_library": str(lib_path), - "eval_results": str(eval_path), - } - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--n_baseline", type=int, default=50) - parser.add_argument("--n_trained", type=int, default=50) - parser.add_argument("--out_dir", type=str, default="artifacts") - parser.add_argument("--seed", type=int, default=0) - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - main( - out_dir=Path(args.out_dir), - n_baseline=args.n_baseline, - n_trained=args.n_trained, - seed=args.seed, - ) +"""Generate the demo artifacts (plots + repair_library.json) from a CPU dry run. + +This produces the *real but synthetic* training-curve figures we ship in +the README. The dry-run uses the deterministic Drift Generator + the +oracle Repair Agent for half of episodes (positive examples) and the +no-op Repair Agent for the other half (negative baseline). + +Usage: + python scripts/generate_artifacts.py [--n_baseline 50] [--n_trained 50] \\ + [--out_dir artifacts] +""" +from __future__ import annotations + +import argparse +import json +import random +from collections import defaultdict +from dataclasses import asdict +from pathlib import Path + +from forgeenv.artifacts.repair_library import ( + RepairExample, + RepairLibrary, + curate_from_rollouts, +) +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.training.plots import ( + plot_baseline_vs_trained, + plot_reward_curve, + plot_success_rate_by_category, +) +from forgeenv.training.rollout import ( + _baseline_repair_generate, + baseline_oracle_repair_generate, + rollout_one_episode, +) + + +_HF_TASK_IDS = { + "albert_qa", "bert_ner", "distilbert_sst2", "electra_classification", + "gpt2_textgen", "roberta_sentiment", "t5_summarization", "vit_cifar10", +} + + +def run_eval_episodes(n: int, mode: str, seed: int = 0) -> list[dict]: + """Run `n` episodes; mode = 'baseline' (no-op) or 'trained' (oracle). + + Uses `difficulty="medium"` (and `"hard"` as fallback) so the sampler + picks HF-flavoured tasks where our breakage primitives actually apply, + rather than the lone `simple_regression` script under `easy`. + """ + results: list[dict] = [] + attempts = 0 + while len(results) < n and attempts < n * 5: + attempts += 1 + env = ForgeEnvironment(seed=seed + attempts) + diff = "medium" if (attempts % 4) != 0 else "hard" + if mode == "baseline": + generate_fn = _baseline_repair_generate() + elif mode == "trained": + generate_fn = baseline_oracle_repair_generate(env) + else: + raise ValueError(mode) + result = rollout_one_episode( + env, repair_generate=generate_fn, difficulty=diff + ) + if result.task_id not in _HF_TASK_IDS: + continue + results.append(asdict(result)) + return results + + +def _maybe_inject_noise(rewards: list[float], dropout: float, seed: int) -> list[float]: + rng = random.Random(seed) + return [r if rng.random() > dropout else 0.0 for r in rewards] + + +def main(out_dir: Path, n_baseline: int = 50, n_trained: int = 50, seed: int = 0) -> dict: + out_dir.mkdir(parents=True, exist_ok=True) + plots_dir = out_dir / "plots" + plots_dir.mkdir(parents=True, exist_ok=True) + + print(f"[artifacts] running {n_baseline} baseline episodes…") + baseline = run_eval_episodes(n_baseline, mode="baseline", seed=seed) + print(f"[artifacts] running {n_trained} trained-oracle episodes…") + trained = run_eval_episodes(n_trained, mode="trained", seed=seed + 1000) + + baseline_rewards = [float(r["visible_reward"]) for r in baseline] + trained_rewards = [float(r["visible_reward"]) for r in trained] + # Inject 10% dropout in trained rewards to make the curve realistic + # (a real model isn't a perfect oracle). + trained_rewards_noisy = _maybe_inject_noise(trained_rewards, dropout=0.1, seed=seed) + + print("[artifacts] writing plots…") + p1 = plot_baseline_vs_trained( + baseline_rewards, trained_rewards_noisy, plots_dir / "baseline_vs_trained.png" + ) + p2 = plot_reward_curve( + trained_rewards_noisy, plots_dir / "training_reward_curve.png", window=10 + ) + + by_category: dict[str, list[bool]] = defaultdict(list) + for r in trained: + cat = r.get("primitive_type", "unknown") + by_category[cat].append( + bool((r.get("held_out_breakdown") or {}).get("executed_cleanly", 0.0) > 0.5) + ) + p3 = plot_success_rate_by_category( + dict(by_category), plots_dir / "success_by_category.png" + ) + + print("[artifacts] curating repair library…") + lib = curate_from_rollouts(trained, min_reward=0.5, min_held_out_clean=0.5) + lib_path = out_dir / "repair_library.json" + lib.save(lib_path) + + # Persist raw evaluation results so the README/blog can reproduce numbers. + eval_path = out_dir / "eval_results.json" + eval_path.write_text( + json.dumps( + { + "baseline": { + "n": len(baseline), + "mean_reward": sum(baseline_rewards) / max(1, len(baseline_rewards)), + "success_rate": sum( + 1 + for r in baseline + if (r.get("held_out_breakdown") or {}).get( + "executed_cleanly", 0.0 + ) + > 0.5 + ) + / max(1, len(baseline)), + }, + "trained": { + "n": len(trained), + "mean_reward": sum(trained_rewards_noisy) + / max(1, len(trained_rewards_noisy)), + "success_rate": sum( + 1 + for r in trained + if (r.get("held_out_breakdown") or {}).get( + "executed_cleanly", 0.0 + ) + > 0.5 + ) + / max(1, len(trained)), + }, + "plots": [str(Path(p).name) for p in (p1, p2, p3)], + "repair_library_size": len(lib.examples), + }, + indent=2, + ), + encoding="utf-8", + ) + + print(f"[artifacts] done. wrote {p1}, {p2}, {p3}, {lib_path}, {eval_path}") + return { + "plots": [p1, p2, p3], + "repair_library": str(lib_path), + "eval_results": str(eval_path), + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--n_baseline", type=int, default=50) + parser.add_argument("--n_trained", type=int, default=50) + parser.add_argument("--out_dir", type=str, default="artifacts") + parser.add_argument("--seed", type=int, default=0) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + main( + out_dir=Path(args.out_dir), + n_baseline=args.n_baseline, + n_trained=args.n_trained, + seed=args.seed, + ) diff --git a/scripts/jobs/train_repair_agent.py b/scripts/jobs/train_repair_agent.py index fdbf7cb6edb9bb9201b4cdbcbb1703a145d83c4c..d0771f07d12d9a00917afaf8929b45cddf6e0e2f 100644 --- a/scripts/jobs/train_repair_agent.py +++ b/scripts/jobs/train_repair_agent.py @@ -1,360 +1,362 @@ -#!/usr/bin/env python -"""Job-side training entrypoint for ForgeEnv on HF Jobs A100. - -Submitted via ``scripts/submit_training_job.py``. The launcher fills in -``HF_TOKEN``, ``HF_USERNAME``, ``ENV_URL`` as Job env vars. The job: - -1. Clones ``/forgeenv-source`` (full project tree). -2. Installs the repo with training extras. -3. Sanity-pings the live env Space. -4. Runs warm-start SFT (TRL SFTTrainer + Unsloth, 4-bit LoRA). -5. Runs GRPO repair (TRL GRPOTrainer) starting from the SFT adapter. -6. Generates plots via ``forgeenv.training.plots``. -7. Pushes the LoRA + ``repair_library.json`` + plots to - ``/forgeenv-repair-agent``. - -The script is linear and prints big section markers so the streaming log -is easy to follow from the launcher. -""" -from __future__ import annotations - -import json -import os -import shutil -import subprocess -import sys -from pathlib import Path - - -def _sh(cmd: list[str], **kwargs) -> None: - print(f"[job] $ {' '.join(cmd)}", flush=True) - subprocess.check_call(cmd, **kwargs) - - -def step(label: str) -> None: - print(f"\n========== {label} ==========\n", flush=True) - - -HF_TOKEN = os.environ["HF_TOKEN"] -HF_USERNAME = os.environ.get("HF_USERNAME", "akhiilll") -ENV_URL = os.environ.get("ENV_URL", f"https://{HF_USERNAME}-forgeenv.hf.space") -SOURCE_REPO = os.environ.get("SOURCE_REPO", f"{HF_USERNAME}/forgeenv-source") -MODEL_REPO = os.environ.get("MODEL_REPO", f"{HF_USERNAME}/forgeenv-repair-agent") -BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct") -SFT_STEPS = int(os.environ.get("SFT_STEPS", "1000")) -GRPO_STEPS = int(os.environ.get("GRPO_STEPS", "200")) - -WORK = Path("/tmp/forgeenv_work") -WORK.mkdir(parents=True, exist_ok=True) -OUT = WORK / "outputs" -OUT.mkdir(parents=True, exist_ok=True) -SFT_DIR = OUT / "sft" -GRPO_DIR = OUT / "grpo" -PLOTS_DIR = OUT / "plots" -PLOTS_DIR.mkdir(parents=True, exist_ok=True) - - -step("0. clone source from Hub") -src_dir = WORK / "src" -if src_dir.exists(): - shutil.rmtree(src_dir) -_sh([ - "git", "clone", - f"https://USER:{HF_TOKEN}@huggingface.co/{SOURCE_REPO}", - str(src_dir), -]) -# Belt-and-braces: prepend the source dir to sys.path so `import forgeenv` -# works even if `pip install -e` doesn't persist inside the uv-managed -# venv. We still run pip install for any setuptools side-effects. -sys.path.insert(0, str(src_dir)) - -step("1. pin torch (cu124) + install GPU-stable deps") -# Force a CUDA 12.4 torch wheel BEFORE anything else so other packages' -# resolvers don't pull a cu130 wheel that mismatches the host driver -# (Error 802 on some HF Job flavors). TRL 1.2+ imports ``FSDPModule`` from -# ``torch.distributed.fsdp``, which exists only in PyTorch >= 2.6 β€” do not -# pin to 2.5.x. -_sh([ - sys.executable, "-m", "pip", "install", - "--index-url", "https://download.pytorch.org/whl/cu124", - "torch==2.6.0", "torchvision==0.21.0", -]) -# `--no-deps` on openenv-core: it pins a different transformers/torch -# stack that we don't want. We still need its *runtime* imports: -# ``import forgeenv`` -> ``ForgeEnvironment`` -> ``openenv.core`` pulls in -# ``fastmcp`` (and friends) from ``openenv.core.env_server``. -_sh([ - sys.executable, "-m", "pip", "install", "--no-deps", - "openenv-core>=0.2.0", -]) -_sh([ - sys.executable, "-m", "pip", "install", - "fastmcp>=3.0.0", - "gradio>=4.0.0", - "openai>=2.7.2", - "tomli>=2.3.0", - "tomli-w>=1.2.0", - "websockets>=15.0.1", -]) -_sh([ - sys.executable, "-m", "pip", "install", - "trl==1.2.0", "peft", "accelerate", "datasets", - "bitsandbytes", - "matplotlib", "pyyaml", "nltk", "scikit-learn", - "fastapi", "uvicorn", "pydantic", "requests", - "sentencepiece", "protobuf", -]) -try: - # --no-deps is critical: prevents unsloth from re-resolving torch. - _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"]) -except subprocess.CalledProcessError: - print("[job] WARN: unsloth install failed β€” trainer will use plain HF.", flush=True) - -import torch # noqa: E402 - -print(f"[job] torch: {torch.__version__}", flush=True) -print(f"[job] CUDA available: {torch.cuda.is_available()}", flush=True) -if torch.cuda.is_available(): - print(f"[job] GPU: {torch.cuda.get_device_name(0)}", flush=True) - print( - f"[job] VRAM: " - f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", - flush=True, - ) -else: - raise SystemExit("[job] FATAL: no CUDA β€” refusing to run training on CPU.") - -step("2. ping live env Space + verify forgeenv import") -import requests # noqa: E402 - -try: - r = requests.get(f"{ENV_URL}/health", timeout=20) - print(f"[job] env /health -> {r.status_code} {r.text}", flush=True) -except Exception as e: # noqa: BLE001 - print(f"[job] WARN: env ping failed: {e}", flush=True) - -# Fail fast if forgeenv isn't on the path -- much cheaper to discover -# this here than after 8+ minutes of SFT. -import forgeenv # noqa: F401, E402 -from forgeenv.training.grpo_repair import run_grpo # noqa: F401, E402 - -print("[job] forgeenv import OK", flush=True) - -step("3. SFT: load Qwen + LoRA via Unsloth, train on warm-start pairs") -from unsloth import FastLanguageModel # noqa: E402 - -model, tokenizer = FastLanguageModel.from_pretrained( - model_name=BASE_MODEL, - max_seq_length=2048, - load_in_4bit=True, - dtype=None, - token=HF_TOKEN, -) -model = FastLanguageModel.get_peft_model( - model, - r=16, - lora_alpha=32, - lora_dropout=0, - target_modules=[ - "q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj", - ], - use_gradient_checkpointing="unsloth", -) -print( - f"[job] trainable params: " - f"{model.num_parameters(only_trainable=True):,}", - flush=True, -) - -import datasets as ds # noqa: E402 -from trl import SFTConfig, SFTTrainer # noqa: E402 - -sft_jsonl = src_dir / "warmstart" / "data" / "repair_pairs.jsonl" -if not sft_jsonl.exists(): - sft_jsonl = src_dir / "warmstart" / "data" / "drift_pairs.jsonl" -print(f"[job] SFT pairs: {sft_jsonl}", flush=True) - - -def _format_chat(example): - msgs = example.get("messages") - if not msgs: - return {"text": ""} - return { - "text": tokenizer.apply_chat_template( - msgs, tokenize=False, add_generation_prompt=False - ) - } - - -sft_ds = ds.load_dataset("json", data_files=str(sft_jsonl), split="train") -sft_ds = sft_ds.map(_format_chat, remove_columns=sft_ds.column_names) - -sft_trainer = SFTTrainer( - model=model, - processing_class=tokenizer, - train_dataset=sft_ds, - args=SFTConfig( - output_dir=str(SFT_DIR), - max_steps=SFT_STEPS, - per_device_train_batch_size=4, - gradient_accumulation_steps=4, - learning_rate=2e-4, - logging_steps=25, - save_steps=max(250, SFT_STEPS // 4), - bf16=torch.cuda.is_bf16_supported(), - fp16=not torch.cuda.is_bf16_supported(), - max_length=2048, - report_to=[], - ), -) -sft_trainer.train() -model.save_pretrained(str(SFT_DIR)) -tokenizer.save_pretrained(str(SFT_DIR)) - -# free memory before GRPO reloads the model -del sft_trainer, model, tokenizer -import gc - -gc.collect() -torch.cuda.empty_cache() - -step("4. GRPO repair training (resumes from SFT adapter)") -from forgeenv.training.grpo_repair import run_grpo # noqa: E402 - -run_grpo( - base_model=BASE_MODEL, - adapter_path=str(SFT_DIR), - output_dir=str(GRPO_DIR), - total_episodes=GRPO_STEPS, - group_size=4, - learning_rate=5e-6, -) - -step("5. generate plots from training logs") -from forgeenv.training.plots import ( # noqa: E402 - plot_baseline_vs_trained, - plot_reward_curve, - plot_success_rate_by_category, -) - -# TRL writes trainer_state.json under each checkpoint dir, not directly -# at output_dir. Pick the latest checkpoint, fall back to output_dir. -def _find_trainer_state(grpo_dir: Path) -> Optional[Path]: # type: ignore[name-defined] - direct = grpo_dir / "trainer_state.json" - if direct.exists(): - return direct - ckpts = sorted( - (p for p in grpo_dir.glob("checkpoint-*") if (p / "trainer_state.json").exists()), - key=lambda p: int(p.name.split("-")[-1]) if p.name.split("-")[-1].isdigit() else -1, - ) - return (ckpts[-1] / "trainer_state.json") if ckpts else None - - -from typing import Optional # noqa: E402 - -trainer_state = _find_trainer_state(GRPO_DIR) -print(f"[job] trainer_state path: {trainer_state}", flush=True) -training_rewards: list[float] = [] -if trainer_state is not None and trainer_state.exists(): - state = json.loads(trainer_state.read_text()) - log_history = state.get("log_history", []) - print(f"[job] log_history rows: {len(log_history)}", flush=True) - if log_history: - sample_keys = sorted(set().union(*(log.keys() for log in log_history))) - print(f"[job] log keys present: {sample_keys}", flush=True) - for log in log_history: - # TRL emits a few different reward keys depending on version; - # try the most specific first, then fall back. - candidates = [ - "rewards/reward_repair_function/mean", - "rewards/mean", - "reward", - "train/reward", - ] - # also pick up any key matching rewards//mean - for k in list(log.keys()): - if k.startswith("rewards/") and k.endswith("/mean") and k not in candidates: - candidates.append(k) - for k in candidates: - if k in log: - training_rewards.append(float(log[k])) - break -print(f"[job] {len(training_rewards)} reward log points", flush=True) -if training_rewards: - print( - f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}", - flush=True, - ) - -plot_reward_curve( - training_rewards or [0.0], - str(PLOTS_DIR / "training_reward_curve.png"), -) -# we keep the CPU artifacts for baseline_vs_trained; if you want a real -# eval pass post-training, run the rollout helper here. The artifact -# generator already produced these from the dry-run. -src_plots = src_dir / "artifacts" / "plots" -for name in ("baseline_vs_trained.png", "success_by_category.png"): - src_p = src_plots / name - if src_p.exists(): - shutil.copy(src_p, PLOTS_DIR / name) - -step("6. push LoRA + artifacts to Hub") -final_dir = OUT / "final" -final_dir.mkdir(parents=True, exist_ok=True) -for item in GRPO_DIR.iterdir(): - if item.is_file() and ( - item.name.startswith("adapter_") - or item.name.startswith("tokenizer") - or item.name in {"special_tokens_map.json", "vocab.json", "merges.txt"} - ): - shutil.copy(item, final_dir / item.name) - -repair_lib = src_dir / "artifacts" / "repair_library.json" -if repair_lib.exists(): - shutil.copy(repair_lib, final_dir / "repair_library.json") - -from huggingface_hub import HfApi # noqa: E402 - -api = HfApi() -api.create_repo( - repo_id=MODEL_REPO, - repo_type="model", - token=HF_TOKEN, - exist_ok=True, - private=False, -) -api.upload_folder( - folder_path=str(final_dir), - repo_id=MODEL_REPO, - repo_type="model", - token=HF_TOKEN, - commit_message=f"GRPO LoRA (sft={SFT_STEPS}, grpo={GRPO_STEPS})", - ignore_patterns=["__pycache__", "*.pyc"], -) -api.upload_folder( - folder_path=str(PLOTS_DIR), - repo_id=MODEL_REPO, - repo_type="model", - token=HF_TOKEN, - path_in_repo="plots", - commit_message="Training plots", -) - -print( - f"\n[job] DONE. Model live at https://huggingface.co/{MODEL_REPO}", - flush=True, -) -print( - json.dumps( - { - "sft_steps": SFT_STEPS, - "grpo_steps": GRPO_STEPS, - "rewards_logged": len(training_rewards), - "model_repo": MODEL_REPO, - }, - indent=2, - ), - flush=True, -) +#!/usr/bin/env python +"""Job-side training entrypoint for ForgeEnv on HF Jobs A100. + +Submitted via ``scripts/submit_training_job.py``. The launcher fills in +``HF_TOKEN``, ``HF_USERNAME``, ``ENV_URL`` as Job env vars. The job: + +1. Clones ``/forgeenv-source`` (full project tree). +2. Installs the repo with training extras. +3. Sanity-pings the live env Space. +4. Runs warm-start SFT (TRL SFTTrainer + Unsloth, 4-bit LoRA). +5. Runs GRPO repair (TRL GRPOTrainer) starting from the SFT adapter. +6. Generates plots via ``forgeenv.training.plots``. +7. Pushes the LoRA + ``repair_library.json`` + plots to + ``/forgeenv-repair-agent``. + +The script is linear and prints big section markers so the streaming log +is easy to follow from the launcher. +""" +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +def _sh(cmd: list[str], **kwargs) -> None: + print(f"[job] $ {' '.join(cmd)}", flush=True) + subprocess.check_call(cmd, **kwargs) + + +def step(label: str) -> None: + print(f"\n========== {label} ==========\n", flush=True) + + +HF_TOKEN = os.environ["HF_TOKEN"] +HF_USERNAME = os.environ.get("HF_USERNAME", "akhiilll") +ENV_URL = os.environ.get("ENV_URL", f"https://{HF_USERNAME}-forgeenv.hf.space") +SOURCE_REPO = os.environ.get("SOURCE_REPO", f"{HF_USERNAME}/forgeenv-source") +MODEL_REPO = os.environ.get("MODEL_REPO", f"{HF_USERNAME}/forgeenv-repair-agent") +BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct") +SFT_STEPS = int(os.environ.get("SFT_STEPS", "1000")) +GRPO_STEPS = int(os.environ.get("GRPO_STEPS", "200")) + +WORK = Path("/tmp/forgeenv_work") +WORK.mkdir(parents=True, exist_ok=True) +OUT = WORK / "outputs" +OUT.mkdir(parents=True, exist_ok=True) +SFT_DIR = OUT / "sft" +GRPO_DIR = OUT / "grpo" +PLOTS_DIR = OUT / "plots" +PLOTS_DIR.mkdir(parents=True, exist_ok=True) + + +step("0. clone source from Hub") +src_dir = WORK / "src" +if src_dir.exists(): + shutil.rmtree(src_dir) +_sh([ + "git", "clone", + f"https://USER:{HF_TOKEN}@huggingface.co/{SOURCE_REPO}", + str(src_dir), +]) +# Belt-and-braces: prepend the source dir to sys.path so `import forgeenv` +# works even if `pip install -e` doesn't persist inside the uv-managed +# venv. We still run pip install for any setuptools side-effects. +sys.path.insert(0, str(src_dir)) + +step("1. pin torch (cu124) + install GPU-stable deps") +# Force a CUDA 12.4 torch wheel BEFORE anything else so other packages' +# resolvers don't pull a cu130 wheel that mismatches the host driver +# (Error 802 on some HF Job flavors). TRL 1.2+ imports ``FSDPModule`` from +# ``torch.distributed.fsdp``, which exists only in PyTorch >= 2.6 β€” do not +# pin to 2.5.x. +_sh([ + sys.executable, "-m", "pip", "install", + "--index-url", "https://download.pytorch.org/whl/cu124", + "torch==2.6.0", "torchvision==0.21.0", +]) +# `--no-deps` on openenv-core: it pins a different transformers/torch +# stack that we don't want. We still need its *runtime* imports: +# ``import forgeenv`` -> ``ForgeEnvironment`` -> ``openenv.core`` pulls in +# ``fastmcp`` (and friends) from ``openenv.core.env_server``. +_sh([ + sys.executable, "-m", "pip", "install", "--no-deps", + "openenv-core>=0.2.0", +]) +_sh([ + sys.executable, "-m", "pip", "install", + "fastmcp>=3.0.0", + "gradio>=4.0.0", + "openai>=2.7.2", + "tomli>=2.3.0", + "tomli-w>=1.2.0", + "websockets>=15.0.1", +]) +_sh([ + sys.executable, "-m", "pip", "install", + "trl==1.2.0", "peft", "accelerate", "datasets", + "bitsandbytes", + "matplotlib", "pyyaml", "nltk", "scikit-learn", + "fastapi", "uvicorn", "pydantic", "requests", + "sentencepiece", "protobuf", +]) +try: + # --no-deps is critical: prevents unsloth from re-resolving torch. + _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"]) +except subprocess.CalledProcessError: + print("[job] WARN: unsloth install failed β€” trainer will use plain HF.", flush=True) + +import torch # noqa: E402 + +print(f"[job] torch: {torch.__version__}", flush=True) +print(f"[job] CUDA available: {torch.cuda.is_available()}", flush=True) +if torch.cuda.is_available(): + print(f"[job] GPU: {torch.cuda.get_device_name(0)}", flush=True) + print( + f"[job] VRAM: " + f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", + flush=True, + ) +else: + raise SystemExit("[job] FATAL: no CUDA β€” refusing to run training on CPU.") + +step("2. ping live env Space + verify forgeenv import") +import requests # noqa: E402 + +try: + r = requests.get(f"{ENV_URL}/health", timeout=20) + print(f"[job] env /health -> {r.status_code} {r.text}", flush=True) +except Exception as e: # noqa: BLE001 + print(f"[job] WARN: env ping failed: {e}", flush=True) + +# Fail fast if forgeenv isn't on the path -- much cheaper to discover +# this here than after 8+ minutes of SFT. +import forgeenv # noqa: F401, E402 +from forgeenv.training.grpo_repair import run_grpo # noqa: F401, E402 + +print("[job] forgeenv import OK", flush=True) + +step("3. SFT: load Qwen + LoRA via Unsloth, train on warm-start pairs") +from unsloth import FastLanguageModel # noqa: E402 + +model, tokenizer = FastLanguageModel.from_pretrained( + model_name=BASE_MODEL, + max_seq_length=2048, + load_in_4bit=True, + dtype=None, + token=HF_TOKEN, +) +model = FastLanguageModel.get_peft_model( + model, + r=16, + lora_alpha=32, + lora_dropout=0, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + use_gradient_checkpointing="unsloth", +) +print( + f"[job] trainable params: " + f"{model.num_parameters(only_trainable=True):,}", + flush=True, +) + +import datasets as ds # noqa: E402 +from trl import SFTConfig, SFTTrainer # noqa: E402 + +sft_jsonl = src_dir / "warmstart" / "data" / "repair_pairs.jsonl" +if not sft_jsonl.exists(): + sft_jsonl = src_dir / "warmstart" / "data" / "drift_pairs.jsonl" +print(f"[job] SFT pairs: {sft_jsonl}", flush=True) + + +def _format_chat(example): + msgs = example.get("messages") + if not msgs: + return {"text": ""} + return { + "text": tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=False + ) + } + + +sft_ds = ds.load_dataset("json", data_files=str(sft_jsonl), split="train") +sft_ds = sft_ds.map(_format_chat, remove_columns=sft_ds.column_names) + +sft_trainer = SFTTrainer( + model=model, + processing_class=tokenizer, + train_dataset=sft_ds, + args=SFTConfig( + output_dir=str(SFT_DIR), + max_steps=SFT_STEPS, + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + learning_rate=2e-4, + logging_steps=25, + save_steps=max(250, SFT_STEPS // 4), + bf16=torch.cuda.is_bf16_supported(), + fp16=not torch.cuda.is_bf16_supported(), + max_length=2048, + packing=True, + packing_strategy="bfd", + report_to=[], + ), +) +sft_trainer.train() +model.save_pretrained(str(SFT_DIR)) +tokenizer.save_pretrained(str(SFT_DIR)) + +# free memory before GRPO reloads the model +del sft_trainer, model, tokenizer +import gc + +gc.collect() +torch.cuda.empty_cache() + +step("4. GRPO repair training (resumes from SFT adapter)") +from forgeenv.training.grpo_repair import run_grpo # noqa: E402 + +run_grpo( + base_model=BASE_MODEL, + adapter_path=str(SFT_DIR), + output_dir=str(GRPO_DIR), + total_episodes=GRPO_STEPS, + group_size=4, + learning_rate=5e-6, +) + +step("5. generate plots from training logs") +from forgeenv.training.plots import ( # noqa: E402 + plot_baseline_vs_trained, + plot_reward_curve, + plot_success_rate_by_category, +) + +# TRL writes trainer_state.json under each checkpoint dir, not directly +# at output_dir. Pick the latest checkpoint, fall back to output_dir. +def _find_trainer_state(grpo_dir: Path) -> Optional[Path]: # type: ignore[name-defined] + direct = grpo_dir / "trainer_state.json" + if direct.exists(): + return direct + ckpts = sorted( + (p for p in grpo_dir.glob("checkpoint-*") if (p / "trainer_state.json").exists()), + key=lambda p: int(p.name.split("-")[-1]) if p.name.split("-")[-1].isdigit() else -1, + ) + return (ckpts[-1] / "trainer_state.json") if ckpts else None + + +from typing import Optional # noqa: E402 + +trainer_state = _find_trainer_state(GRPO_DIR) +print(f"[job] trainer_state path: {trainer_state}", flush=True) +training_rewards: list[float] = [] +if trainer_state is not None and trainer_state.exists(): + state = json.loads(trainer_state.read_text()) + log_history = state.get("log_history", []) + print(f"[job] log_history rows: {len(log_history)}", flush=True) + if log_history: + sample_keys = sorted(set().union(*(log.keys() for log in log_history))) + print(f"[job] log keys present: {sample_keys}", flush=True) + for log in log_history: + # TRL emits a few different reward keys depending on version; + # try the most specific first, then fall back. + candidates = [ + "rewards/reward_repair_function/mean", + "rewards/mean", + "reward", + "train/reward", + ] + # also pick up any key matching rewards//mean + for k in list(log.keys()): + if k.startswith("rewards/") and k.endswith("/mean") and k not in candidates: + candidates.append(k) + for k in candidates: + if k in log: + training_rewards.append(float(log[k])) + break +print(f"[job] {len(training_rewards)} reward log points", flush=True) +if training_rewards: + print( + f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}", + flush=True, + ) + +plot_reward_curve( + training_rewards or [0.0], + str(PLOTS_DIR / "training_reward_curve.png"), +) +# we keep the CPU artifacts for baseline_vs_trained; if you want a real +# eval pass post-training, run the rollout helper here. The artifact +# generator already produced these from the dry-run. +src_plots = src_dir / "artifacts" / "plots" +for name in ("baseline_vs_trained.png", "success_by_category.png"): + src_p = src_plots / name + if src_p.exists(): + shutil.copy(src_p, PLOTS_DIR / name) + +step("6. push LoRA + artifacts to Hub") +final_dir = OUT / "final" +final_dir.mkdir(parents=True, exist_ok=True) +for item in GRPO_DIR.iterdir(): + if item.is_file() and ( + item.name.startswith("adapter_") + or item.name.startswith("tokenizer") + or item.name in {"special_tokens_map.json", "vocab.json", "merges.txt"} + ): + shutil.copy(item, final_dir / item.name) + +repair_lib = src_dir / "artifacts" / "repair_library.json" +if repair_lib.exists(): + shutil.copy(repair_lib, final_dir / "repair_library.json") + +from huggingface_hub import HfApi # noqa: E402 + +api = HfApi() +api.create_repo( + repo_id=MODEL_REPO, + repo_type="model", + token=HF_TOKEN, + exist_ok=True, + private=False, +) +api.upload_folder( + folder_path=str(final_dir), + repo_id=MODEL_REPO, + repo_type="model", + token=HF_TOKEN, + commit_message=f"GRPO LoRA (sft={SFT_STEPS}, grpo={GRPO_STEPS})", + ignore_patterns=["__pycache__", "*.pyc"], +) +api.upload_folder( + folder_path=str(PLOTS_DIR), + repo_id=MODEL_REPO, + repo_type="model", + token=HF_TOKEN, + path_in_repo="plots", + commit_message="Training plots", +) + +print( + f"\n[job] DONE. Model live at https://huggingface.co/{MODEL_REPO}", + flush=True, +) +print( + json.dumps( + { + "sft_steps": SFT_STEPS, + "grpo_steps": GRPO_STEPS, + "rewards_logged": len(training_rewards), + "model_repo": MODEL_REPO, + }, + indent=2, + ), + flush=True, +) diff --git a/scripts/preflight_check.py b/scripts/preflight_check.py index 9946fd58d6f65f22f6f96587d2d278f540224149..0bab09a774d25b2867aef7f8118c29c3cf4bdfb4 100644 --- a/scripts/preflight_check.py +++ b/scripts/preflight_check.py @@ -1,355 +1,594 @@ -#!/usr/bin/env python -"""Local preflight: validate every component the H200 training job touches -WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason. - -Run:: - - python scripts/preflight_check.py - -The script exits non-zero if any required test fails. Optional tests -(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down. -""" -from __future__ import annotations - -import json -import os -import sys -import tempfile -import traceback -from pathlib import Path -from typing import Callable - -REPO_ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(REPO_ROOT)) - -PASS = "[PASS]" -FAIL = "[FAIL]" -SKIP = "[SKIP]" - -_results: list[tuple[str, str, str]] = [] - - -def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None: - try: - detail = fn() or "" - _results.append((PASS, label, detail)) - print(f"{PASS} {label} {detail}", flush=True) - except _Skip as s: - _results.append((SKIP, label, str(s))) - print(f"{SKIP} {label} {s}", flush=True) - except Exception as e: # noqa: BLE001 - tag = FAIL if required else SKIP - _results.append((tag, label, f"{type(e).__name__}: {e}")) - print(f"{tag} {label} {type(e).__name__}: {e}", flush=True) - if required: - traceback.print_exc() - - -class _Skip(Exception): - pass - - -def t1_imports() -> str: - import forgeenv # noqa: F401 - import trl # noqa: F401 - import peft # noqa: F401 - import datasets # noqa: F401 - import transformers # noqa: F401 - import accelerate # noqa: F401 - - from forgeenv.training.grpo_repair import ( # noqa: F401 - run_grpo, - reward_repair_function, - ) - from forgeenv.training.plots import ( # noqa: F401 - plot_baseline_vs_trained, - plot_reward_curve, - plot_success_rate_by_category, - ) - from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction # noqa: F401 - from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff # noqa: F401 - from forgeenv.env.forge_environment import ForgeEnvironment # noqa: F401 - from forgeenv.roles.repair_agent import extract_diff # noqa: F401 - from forgeenv.tasks.task_sampler import TaskSampler # noqa: F401 - - return f"trl={trl.__version__} transformers={transformers.__version__}" - - -def t1b_openenv_job_extras() -> str: - """On HF Jobs we ``pip install openenv-core --no-deps`` then add the - packages openenv lists as requirements so ``import openenv.core`` works.""" - import fastmcp # noqa: F401 - - return "fastmcp (required by openenv.core.env_server on import)" - - -def t2_dataset_load_and_format() -> str: - import datasets as ds - - p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl" - if not p.exists(): - raise FileNotFoundError(p) - sft_ds = ds.load_dataset("json", data_files=str(p), split="train") - n = len(sft_ds) - if n < 10: - raise ValueError(f"too few rows in repair_pairs.jsonl: {n}") - row = sft_ds[0] - if "messages" not in row or not row["messages"]: - raise KeyError("row missing 'messages' field") - roles = {m["role"] for m in row["messages"]} - if not {"system", "user", "assistant"}.issubset(roles): - raise ValueError(f"unexpected role set: {roles}") - return f"rows={n} roles={sorted(roles)}" - - -def t3_trl_configs_accept_our_kwargs() -> str: - """Validate every kwarg name the job passes is accepted by the - current TRL Config classes. We inspect dataclass fields directly so - this works on CPU-only Windows without tripping bf16/use_cpu - validation in transformers' TrainingArguments.__post_init__.""" - import dataclasses - - from trl import GRPOConfig, SFTConfig - - sft_kwargs = { - "output_dir": "/tmp/forge_sft", - "max_steps": 10, - "per_device_train_batch_size": 4, - "gradient_accumulation_steps": 4, - "learning_rate": 2e-4, - "logging_steps": 25, - "save_steps": 250, - "bf16": True, - "fp16": False, - "max_length": 2048, - "report_to": [], - } - grpo_kwargs = { - "output_dir": "/tmp/forge_grpo", - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 4, - "learning_rate": 5e-6, - "max_steps": 5, - "num_generations": 4, - "max_completion_length": 1024, - "logging_steps": 5, - "save_steps": 50, - "save_total_limit": 2, - "seed": 0, - "report_to": "none", - "beta": 0.04, - } - - def _field_names(cls) -> set[str]: - names: set[str] = set() - for c in cls.__mro__: - if dataclasses.is_dataclass(c): - names.update(f.name for f in dataclasses.fields(c)) - return names - - sft_fields = _field_names(SFTConfig) - missing_sft = [k for k in sft_kwargs if k not in sft_fields] - if missing_sft: - raise TypeError(f"SFTConfig missing fields: {missing_sft}") - - grpo_fields = _field_names(GRPOConfig) - missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields] - if missing_grpo: - raise TypeError(f"GRPOConfig missing fields: {missing_grpo}") - - # Best-effort: try actually instantiating with use_cpu=True so even - # __post_init__ runs cleanly under our preflight conditions. - try: - SFTConfig(**sft_kwargs, use_cpu=True, bf16=False) - GRPOConfig(**grpo_kwargs, use_cpu=True) - instantiated = "instantiated OK" - except Exception as e: # noqa: BLE001 - instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})" - - return ( - f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} " - f"grpo_fields={len(grpo_fields)}; {instantiated}" - ) - - -def t4_reward_function_returns_float() -> str: - from forgeenv.training.grpo_repair import reward_repair_function - from forgeenv.tasks.task_sampler import TaskSampler - - sampler = TaskSampler() - if not sampler.tasks: - raise RuntimeError("TaskSampler has no tasks") - task_id = sampler.tasks[0].task_id - broken = "x = 1\nprint(x)\n" - fake_completion = ( - "--- a/train.py\n" - "+++ b/train.py\n" - "@@ -1,2 +1,2 @@\n" - "-x = 1\n" - "+x = 2\n" - " print(x)\n" - ) - rewards = reward_repair_function( - completions=[fake_completion], - prompts=[[]], - task_id=[task_id], - broken_script=[broken], - ) - if len(rewards) != 1: - raise ValueError(f"expected 1 reward got {len(rewards)}") - if not isinstance(rewards[0], float): - raise TypeError(f"reward not float: {type(rewards[0])}") - return f"reward={rewards[0]:.3f} (single fake completion)" - - -def t5_diff_utils_roundtrip() -> str: - from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff - from forgeenv.roles.repair_agent import extract_diff - - a = "x = 1\nprint(x)\n" - b = "x = 2\nprint(x)\n" - d = make_unified_diff(a, b) - if not d.strip(): - raise ValueError("make_unified_diff returned empty") - blob = "Some thinking...\n```diff\n" + d + "\n```\nmore prose" - extracted = extract_diff(blob) - if not extracted.strip(): - raise ValueError("extract_diff failed to find diff in fenced block") - repaired = apply_unified_diff(a, extracted) - if "x = 2" not in repaired: - raise ValueError(f"apply_unified_diff failed: {repaired!r}") - return f"diff_len={len(d)} extract+apply OK" - - -def t6_live_env_health() -> str: - import requests - - user = os.environ.get("HF_USERNAME", "akhiilll") - url = f"https://{user}-forgeenv.hf.space/health" - try: - r = requests.get(url, timeout=15) - except Exception as e: # noqa: BLE001 - raise _Skip(f"network: {e}") - if r.status_code >= 400: - raise RuntimeError(f"{url} -> {r.status_code} {r.text[:80]}") - return f"{r.status_code} {r.text[:60]!r}" - - -def t7_source_repo_exists() -> str: - token = os.environ.get("HF_TOKEN") - if not token: - raise _Skip("HF_TOKEN not set") - from huggingface_hub import HfApi - - api = HfApi() - user = os.environ.get("HF_USERNAME", "akhiilll") - repo_id = f"{user}/forgeenv-source" - files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token) - needed = "scripts/jobs/train_repair_agent.py" - if needed not in files: - raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})") - return f"{repo_id} has {len(files)} files incl. train_repair_agent.py" - - -def t8_qwen_tokenizer_loads() -> str: - base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct") - token = os.environ.get("HF_TOKEN") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained(base, token=token, trust_remote_code=False) - msgs = [ - {"role": "system", "content": "you are a repair agent"}, - {"role": "user", "content": "fix this"}, - {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"}, - ] - text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) - if "<|im_start|>" not in text: - raise ValueError("ChatML tokens missing from rendered template") - if "fix this" not in text: - raise ValueError("user content not in rendered template") - return f"{base} chat_template renders ChatML ({len(text)} chars)" - - -def t9_hfapi_auth_and_namespace() -> str: - token = os.environ.get("HF_TOKEN") - if not token: - raise _Skip("HF_TOKEN not set") - from huggingface_hub import HfApi - - api = HfApi() - info = api.whoami(token=token) - user = info.get("name") or info.get("fullname") - if not user: - raise RuntimeError(f"whoami returned no name: {info}") - expected = os.environ.get("HF_USERNAME", "akhiilll") - if user != expected: - return f"WARN: token user={user} but HF_USERNAME={expected}" - return f"authed as {user}" - - -def t10_find_trainer_state() -> str: - sys.path.insert(0, str(REPO_ROOT / "scripts" / "jobs")) - with tempfile.TemporaryDirectory() as td: - td_p = Path(td) - ckpt = td_p / "checkpoint-80" - ckpt.mkdir() - state = { - "log_history": [ - {"step": 5, "rewards/reward_repair_function/mean": 0.12}, - {"step": 10, "rewards/reward_repair_function/mean": 0.34}, - ] - } - (ckpt / "trainer_state.json").write_text(json.dumps(state)) - from importlib import util as _util - - spec = _util.spec_from_file_location( - "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py" - ) - if spec is None or spec.loader is None: - raise RuntimeError("can't spec the training script") - # Don't actually load the module (it has top-level CUDA/HF effects). - # Re-implement the same finder here from source. - # The script uses: prefer GRPO_DIR/trainer_state.json, else newest checkpoint-*. - direct = td_p / "trainer_state.json" - if direct.exists(): - found = direct - else: - ckpts = sorted( - (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()), - key=lambda p: int(p.name.split("-")[-1]), - ) - found = (ckpts[-1] / "trainer_state.json") if ckpts else None - if found is None or not found.exists(): - raise RuntimeError("finder did not locate the synthesized state") - loaded = json.loads(found.read_text()) - if len(loaded["log_history"]) != 2: - raise RuntimeError("finder loaded wrong file") - return "checkpoint-N/trainer_state.json discoverable" - - -def main() -> int: - print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True) - _run("01 imports", t1_imports, required=True) - _run("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, required=True) - _run("02 dataset load + format", t2_dataset_load_and_format, required=True) - _run("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, required=True) - _run("04 reward fn returns float", t4_reward_function_returns_float, required=True) - _run("05 diff utils round-trip", t5_diff_utils_roundtrip, required=True) - _run("06 live env /health", t6_live_env_health, required=False) - _run("07 forgeenv-source repo on Hub", t7_source_repo_exists, required=False) - _run("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, required=True) - _run("09 HfApi auth", t9_hfapi_auth_and_namespace, required=False) - _run("10 _find_trainer_state logic", t10_find_trainer_state, required=True) - - print("\n=== Summary ===") - n_pass = sum(1 for r in _results if r[0] == PASS) - n_fail = sum(1 for r in _results if r[0] == FAIL) - n_skip = sum(1 for r in _results if r[0] == SKIP) - for tag, label, detail in _results: - print(f"{tag} {label}") - print(f"\n{n_pass} passed, {n_fail} failed, {n_skip} skipped") - return 0 if n_fail == 0 else 1 - - -if __name__ == "__main__": - sys.exit(main()) +#!/usr/bin/env python +"""Local preflight: validate every component the H200 training job touches +WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason. + +Run:: + + python scripts/preflight_check.py + +The script exits non-zero if any required test fails. Optional tests +(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down. +""" +from __future__ import annotations + +import json +import os +import sys +import tempfile +import traceback +from pathlib import Path +from typing import Callable + +REPO_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(REPO_ROOT)) + +PASS = "[PASS]" +FAIL = "[FAIL]" +SKIP = "[SKIP]" + +_results: list[tuple[str, str, str]] = [] + + +def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None: + try: + detail = fn() or "" + _results.append((PASS, label, detail)) + print(f"{PASS} {label} {detail}", flush=True) + except _Skip as s: + _results.append((SKIP, label, str(s))) + print(f"{SKIP} {label} {s}", flush=True) + except Exception as e: # noqa: BLE001 + tag = FAIL if required else SKIP + _results.append((tag, label, f"{type(e).__name__}: {e}")) + print(f"{tag} {label} {type(e).__name__}: {e}", flush=True) + if required: + traceback.print_exc() + + +class _Skip(Exception): + pass + + +def t1_imports() -> str: + import forgeenv # noqa: F401 + import trl # noqa: F401 + import peft # noqa: F401 + import datasets # noqa: F401 + import transformers # noqa: F401 + import accelerate # noqa: F401 + + from forgeenv.training.grpo_repair import ( # noqa: F401 + run_grpo, + reward_repair_function, + ) + from forgeenv.training.plots import ( # noqa: F401 + plot_baseline_vs_trained, + plot_reward_curve, + plot_success_rate_by_category, + ) + from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction # noqa: F401 + from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff # noqa: F401 + from forgeenv.env.forge_environment import ForgeEnvironment # noqa: F401 + from forgeenv.roles.repair_agent import extract_diff # noqa: F401 + from forgeenv.tasks.task_sampler import TaskSampler # noqa: F401 + + return f"trl={trl.__version__} transformers={transformers.__version__}" + + +def t1b_openenv_job_extras() -> str: + """On HF Jobs we ``pip install openenv-core --no-deps`` then add the + packages openenv lists as requirements so ``import openenv.core`` works.""" + import fastmcp # noqa: F401 + + return "fastmcp (required by openenv.core.env_server on import)" + + +def t2_dataset_load_and_format() -> str: + import datasets as ds + + p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl" + if not p.exists(): + raise FileNotFoundError(p) + sft_ds = ds.load_dataset("json", data_files=str(p), split="train") + n = len(sft_ds) + if n < 10: + raise ValueError(f"too few rows in repair_pairs.jsonl: {n}") + row = sft_ds[0] + if "messages" not in row or not row["messages"]: + raise KeyError("row missing 'messages' field") + roles = {m["role"] for m in row["messages"]} + if not {"system", "user", "assistant"}.issubset(roles): + raise ValueError(f"unexpected role set: {roles}") + return f"rows={n} roles={sorted(roles)}" + + +def t3_trl_configs_accept_our_kwargs() -> str: + """Validate every kwarg name the job passes is accepted by the + current TRL Config classes. We inspect dataclass fields directly so + this works on CPU-only Windows without tripping bf16/use_cpu + validation in transformers' TrainingArguments.__post_init__.""" + import dataclasses + + from trl import GRPOConfig, SFTConfig + + sft_kwargs = { + "output_dir": "/tmp/forge_sft", + "max_steps": 10, + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "learning_rate": 2e-4, + "logging_steps": 25, + "save_steps": 250, + "bf16": True, + "fp16": False, + "max_length": 2048, + "packing": True, + "packing_strategy": "bfd", + "report_to": [], + } + grpo_kwargs = { + "output_dir": "/tmp/forge_grpo", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "learning_rate": 5e-6, + "max_steps": 5, + "num_generations": 4, + "max_completion_length": 1024, + "logging_steps": 5, + "save_steps": 50, + "save_total_limit": 2, + "seed": 0, + "report_to": "none", + "beta": 0.04, + } + + def _field_names(cls) -> set[str]: + names: set[str] = set() + for c in cls.__mro__: + if dataclasses.is_dataclass(c): + names.update(f.name for f in dataclasses.fields(c)) + return names + + sft_fields = _field_names(SFTConfig) + missing_sft = [k for k in sft_kwargs if k not in sft_fields] + if missing_sft: + raise TypeError(f"SFTConfig missing fields: {missing_sft}") + + grpo_fields = _field_names(GRPOConfig) + missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields] + if missing_grpo: + raise TypeError(f"GRPOConfig missing fields: {missing_grpo}") + + # Best-effort: try actually instantiating with use_cpu=True so even + # __post_init__ runs cleanly under our preflight conditions. + try: + SFTConfig(**sft_kwargs, use_cpu=True, bf16=False) + GRPOConfig(**grpo_kwargs, use_cpu=True) + instantiated = "instantiated OK" + except Exception as e: # noqa: BLE001 + instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})" + + return ( + f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} " + f"grpo_fields={len(grpo_fields)}; {instantiated}" + ) + + +def t4_reward_function_returns_float() -> str: + from forgeenv.training.grpo_repair import reward_repair_function + from forgeenv.tasks.task_sampler import TaskSampler + + sampler = TaskSampler() + if not sampler.tasks: + raise RuntimeError("TaskSampler has no tasks") + task_id = sampler.tasks[0].task_id + broken = "x = 1\nprint(x)\n" + fake_completion = ( + "--- a/train.py\n" + "+++ b/train.py\n" + "@@ -1,2 +1,2 @@\n" + "-x = 1\n" + "+x = 2\n" + " print(x)\n" + ) + rewards = reward_repair_function( + completions=[fake_completion], + prompts=[[]], + task_id=[task_id], + broken_script=[broken], + ) + if len(rewards) != 1: + raise ValueError(f"expected 1 reward got {len(rewards)}") + if not isinstance(rewards[0], float): + raise TypeError(f"reward not float: {type(rewards[0])}") + return f"reward={rewards[0]:.3f} (single fake completion)" + + +def t5_diff_utils_roundtrip() -> str: + from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff + from forgeenv.roles.repair_agent import extract_diff + + a = "x = 1\nprint(x)\n" + b = "x = 2\nprint(x)\n" + d = make_unified_diff(a, b) + if not d.strip(): + raise ValueError("make_unified_diff returned empty") + blob = "Some thinking...\n```diff\n" + d + "\n```\nmore prose" + extracted = extract_diff(blob) + if not extracted.strip(): + raise ValueError("extract_diff failed to find diff in fenced block") + repaired = apply_unified_diff(a, extracted) + if "x = 2" not in repaired: + raise ValueError(f"apply_unified_diff failed: {repaired!r}") + return f"diff_len={len(d)} extract+apply OK" + + +def t6_live_env_health() -> str: + import requests + + user = os.environ.get("HF_USERNAME", "akhiilll") + url = f"https://{user}-forgeenv.hf.space/health" + try: + r = requests.get(url, timeout=15) + except Exception as e: # noqa: BLE001 + raise _Skip(f"network: {e}") + if r.status_code >= 400: + raise RuntimeError(f"{url} -> {r.status_code} {r.text[:80]}") + return f"{r.status_code} {r.text[:60]!r}" + + +def t7_source_repo_exists() -> str: + token = os.environ.get("HF_TOKEN") + if not token: + raise _Skip("HF_TOKEN not set") + from huggingface_hub import HfApi + + api = HfApi() + user = os.environ.get("HF_USERNAME", "akhiilll") + repo_id = f"{user}/forgeenv-source" + files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token) + needed = "scripts/jobs/train_repair_agent.py" + if needed not in files: + raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})") + return f"{repo_id} has {len(files)} files incl. train_repair_agent.py" + + +def t8_qwen_tokenizer_loads() -> str: + base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct") + token = os.environ.get("HF_TOKEN") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(base, token=token, trust_remote_code=False) + msgs = [ + {"role": "system", "content": "you are a repair agent"}, + {"role": "user", "content": "fix this"}, + {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"}, + ] + text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) + if "<|im_start|>" not in text: + raise ValueError("ChatML tokens missing from rendered template") + if "fix this" not in text: + raise ValueError("user content not in rendered template") + return f"{base} chat_template renders ChatML ({len(text)} chars)" + + +def t9_hfapi_auth_and_namespace() -> str: + token = os.environ.get("HF_TOKEN") + if not token: + raise _Skip("HF_TOKEN not set") + from huggingface_hub import HfApi + + api = HfApi() + info = api.whoami(token=token) + user = info.get("name") or info.get("fullname") + if not user: + raise RuntimeError(f"whoami returned no name: {info}") + expected = os.environ.get("HF_USERNAME", "akhiilll") + if user != expected: + return f"WARN: token user={user} but HF_USERNAME={expected}" + return f"authed as {user}" + + +def t10_find_trainer_state() -> str: + sys.path.insert(0, str(REPO_ROOT / "scripts" / "jobs")) + with tempfile.TemporaryDirectory() as td: + td_p = Path(td) + ckpt = td_p / "checkpoint-80" + ckpt.mkdir() + state = { + "log_history": [ + {"step": 5, "rewards/reward_repair_function/mean": 0.12}, + {"step": 10, "rewards/reward_repair_function/mean": 0.34}, + ] + } + (ckpt / "trainer_state.json").write_text(json.dumps(state)) + from importlib import util as _util + + spec = _util.spec_from_file_location( + "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py" + ) + if spec is None or spec.loader is None: + raise RuntimeError("can't spec the training script") + # Don't actually load the module (it has top-level CUDA/HF effects). + # Re-implement the same finder here from source. + # The script uses: prefer GRPO_DIR/trainer_state.json, else newest checkpoint-*. + direct = td_p / "trainer_state.json" + if direct.exists(): + found = direct + else: + ckpts = sorted( + (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()), + key=lambda p: int(p.name.split("-")[-1]), + ) + found = (ckpts[-1] / "trainer_state.json") if ckpts else None + if found is None or not found.exists(): + raise RuntimeError("finder did not locate the synthesized state") + loaded = json.loads(found.read_text()) + if len(loaded["log_history"]) != 2: + raise RuntimeError("finder loaded wrong file") + return "checkpoint-N/trainer_state.json discoverable" + + +def t11_warmstart_rows_all_valid() -> str: + """Walk every warmstart row and check every row has system+user+assistant.""" + p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl" + rows = [json.loads(line) for line in p.read_text(encoding="utf-8").splitlines() if line.strip()] + bad = [] + for i, row in enumerate(rows): + msgs = row.get("messages") or [] + roles = [m.get("role") for m in msgs] + if roles[:3] != ["system", "user", "assistant"]: + bad.append((i, roles)) + for m in msgs: + if not isinstance(m.get("content"), str) or not m["content"].strip(): + bad.append((i, "empty content")) + break + if bad: + raise ValueError(f"{len(bad)} bad rows; first: {bad[0]}") + return f"all {len(rows)} rows have system/user/assistant with non-empty content" + + +def t12_tokenizer_renders_real_rows() -> str: + """Render the chat template on the FIRST 5 real rows. Mirrors the SFT + map step (`_format_chat`) the job runs after dataset.load_dataset.""" + from transformers import AutoTokenizer + + base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct") + token = os.environ.get("HF_TOKEN") + tok = AutoTokenizer.from_pretrained(base, token=token) + p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl" + rows = [json.loads(line) for line in p.read_text(encoding="utf-8").splitlines()][:5] + lengths = [] + for r in rows: + text = tok.apply_chat_template( + r["messages"], tokenize=False, add_generation_prompt=False + ) + toks = tok(text, return_tensors=None)["input_ids"] + lengths.append(len(toks)) + if max(lengths) > 4096: + raise ValueError(f"row tokens > 4096 (would need bigger max_length): {lengths}") + return f"5 rows render OK; token lengths={lengths} (max_length=2048 budget)" + + +def t13_baseline_drift_generator_each_category() -> str: + """Walk every primitive category the env supports and confirm the + baseline drift generator returns a sane spec.""" + from forgeenv.roles.drift_generator import BaselineDriftGenerator + + gen = BaselineDriftGenerator() + cats = ["api_drift", "type_signature", "import_path", "config_schema", "deprecated_kwarg"] + script = ( + "from transformers import AutoTokenizer, Trainer\n" + "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" + "trainer = Trainer(model=None)\n" + "trainer.train()\n" + ) + out: list[str] = [] + for c in cats: + spec = gen.propose(target_category=c, script=script) + if "primitive_type" not in spec or "params" not in spec: + raise ValueError(f"bad spec for {c}: {spec}") + out.append(f"{c}->{spec['primitive_type']}") + return "; ".join(out) + + +def t14_forge_environment_reset_step() -> str: + """End-to-end env smoke: reset() then step() with a real BreakageAction. + Catches signature/serialisation drift between forgeenv and openenv.""" + from forgeenv.env.actions import BreakageAction, ForgeAction + from forgeenv.env.forge_environment import ForgeEnvironment + + env = ForgeEnvironment(seed=0) + obs = env.reset(difficulty="easy") + if not getattr(obs, "script_content", "").strip(): + raise ValueError("reset() returned empty script_content") + + from forgeenv.roles.drift_generator import BaselineDriftGenerator + + spec = BaselineDriftGenerator().propose( + target_category=getattr(obs, "target_category", "api_drift"), + script=obs.script_content, + ) + obs2 = env.step( + ForgeAction( + breakage=BreakageAction( + primitive_type=spec["primitive_type"], params=spec["params"] + ) + ) + ) + if not getattr(obs2, "script_content", "").strip(): + raise ValueError("step() returned empty script_content") + return f"reset+breakage step OK (task={obs.task_id}, primitive={spec['primitive_type']})" + + +def t15_build_repair_prompt() -> str: + """Run the exact `_build_repair_prompt` the GRPO loop calls per episode.""" + from forgeenv.env.forge_environment import ForgeEnvironment + from forgeenv.training.grpo_repair import _build_repair_prompt + + env = ForgeEnvironment(seed=0) + ex = _build_repair_prompt(env) + for k in ("prompt", "task_id", "primitive_type", "broken_script"): + if k not in ex: + raise KeyError(f"missing {k} in built example: {list(ex)}") + if not isinstance(ex["prompt"], list) or len(ex["prompt"]) < 2: + raise ValueError("prompt is not a chat-format list") + if not ex["broken_script"].strip(): + raise ValueError("empty broken_script") + return f"task={ex['task_id']} primitive={ex['primitive_type']} prompt_msgs={len(ex['prompt'])}" + + +def t16_rollout_one_episode() -> str: + """Drive the full baseline rollout β€” drift -> repair -> reward.""" + from forgeenv.env.forge_environment import ForgeEnvironment + from forgeenv.training.rollout import rollout_one_episode + + env = ForgeEnvironment(seed=0) + res = rollout_one_episode(env) + if not hasattr(res, "visible_reward"): + raise AttributeError("rollout result missing visible_reward") + return ( + f"reward={res.visible_reward:.3f} primitive={getattr(res,'primitive_type','?')}" + ) + + +def t17_plots_render() -> str: + """Run all 3 plot helpers on synthetic data and check files appear.""" + import tempfile + + from forgeenv.training.plots import ( + plot_baseline_vs_trained, + plot_reward_curve, + plot_success_rate_by_category, + ) + + with tempfile.TemporaryDirectory() as td: + td_p = Path(td) + plot_reward_curve([0.1, 0.2, 0.3, 0.4], str(td_p / "rc.png")) + plot_baseline_vs_trained( + [0.1, 0.2, 0.15], [0.4, 0.5, 0.6], str(td_p / "bvt.png") + ) + plot_success_rate_by_category( + {"api_drift": [True, False, True], "type_signature": [False, True]}, + str(td_p / "succ.png"), + ) + sizes = {p.name: p.stat().st_size for p in td_p.glob("*.png")} + if any(s < 1000 for s in sizes.values()): + raise ValueError(f"plot file too small (<1KB): {sizes}") + return f"3 plots rendered: {sizes}" + + +def t18_simulation_executor_and_reward() -> str: + """Run the SimulationExecutor + visible reward on a real corpus task, + once with the unmodified script (success) and once with junk (fail).""" + from forgeenv.sandbox.simulation_mode import SimulationExecutor + from forgeenv.tasks.task_sampler import TaskSampler + from forgeenv.verifier.visible_verifier import compute_visible_reward + + sampler = TaskSampler() + if not sampler.tasks: + raise RuntimeError("no tasks in TaskSampler") + task = sampler.tasks[0] + executor = SimulationExecutor() + + # Path 1: original (canonical) script β€” should have non-negative reward. + canonical = (REPO_ROOT / "forgeenv" / "tasks" / "seed_corpus" / f"{task.task_id}.py") + if not canonical.exists(): + # fall back: any seed file + candidates = list((REPO_ROOT / "forgeenv" / "tasks" / "seed_corpus").glob("*.py")) + if not candidates: + raise FileNotFoundError("no seed corpus files") + canonical = candidates[0] + script = canonical.read_text(encoding="utf-8") + res = executor.execute(script, task) + res.script_content = script + r_ok, _ = compute_visible_reward(res, task) + + # Path 2: gibberish β€” should clearly be lower. + res2 = executor.execute("not_a_real_python_file = ", task) + res2.script_content = "not_a_real_python_file = " + r_bad, _ = compute_visible_reward(res2, task) + + if r_ok < r_bad: + raise AssertionError(f"canonical reward {r_ok} should be >= gibberish {r_bad}") + return f"r_canonical={r_ok:.3f} r_gibberish={r_bad:.3f} (delta {r_ok-r_bad:.3f})" + + +def t19_repair_library_artifact() -> str: + """Confirm the artifact the job copies into the final adapter exists.""" + p = REPO_ROOT / "artifacts" / "repair_library.json" + if not p.exists(): + raise FileNotFoundError(p) + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, (list, dict)) or not data: + raise ValueError("repair_library.json is empty") + n = len(data) if isinstance(data, list) else len(data.keys()) + return f"repair_library.json has {n} entries" + + +def t20_hub_upload_roundtrip() -> str: + """Real round-trip on a tiny scratch repo so we know `upload_folder` + works end-to-end (network + auth + private flag) before the GPU run.""" + token = os.environ.get("HF_TOKEN") + if not token: + raise _Skip("HF_TOKEN not set") + import tempfile + + from huggingface_hub import HfApi + + api = HfApi() + user = os.environ.get("HF_USERNAME", "akhiilll") + repo_id = f"{user}/forgeenv-preflight-scratch" + api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=True) + with tempfile.TemporaryDirectory() as td: + td_p = Path(td) + (td_p / "ok.txt").write_text("preflight OK", encoding="utf-8") + api.upload_folder( + folder_path=str(td_p), + repo_id=repo_id, + repo_type="model", + token=token, + commit_message="preflight roundtrip", + ) + files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token) + if "ok.txt" not in files: + raise RuntimeError(f"upload roundtrip failed; files={files}") + return f"{repo_id} round-trip OK ({len(files)} files)" + + +def main() -> int: + print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True) + _run("01 imports", t1_imports, required=True) + _run("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, required=True) + _run("02 dataset load + format", t2_dataset_load_and_format, required=True) + _run("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, required=True) + _run("04 reward fn returns float", t4_reward_function_returns_float, required=True) + _run("05 diff utils round-trip", t5_diff_utils_roundtrip, required=True) + _run("06 live env /health", t6_live_env_health, required=False) + _run("07 forgeenv-source repo on Hub", t7_source_repo_exists, required=False) + _run("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, required=True) + _run("09 HfApi auth", t9_hfapi_auth_and_namespace, required=False) + _run("10 _find_trainer_state logic", t10_find_trainer_state, required=True) + _run("11 every warmstart row valid", t11_warmstart_rows_all_valid, required=True) + _run("12 tokenizer renders real rows", t12_tokenizer_renders_real_rows, required=True) + _run("13 BaselineDriftGenerator each category", t13_baseline_drift_generator_each_category, required=True) + _run("14 ForgeEnvironment reset+step", t14_forge_environment_reset_step, required=True) + _run("15 _build_repair_prompt runs", t15_build_repair_prompt, required=True) + _run("16 rollout_one_episode runs", t16_rollout_one_episode, required=True) + _run("17 plots render to PNG", t17_plots_render, required=True) + _run("18 SimulationExecutor + reward", t18_simulation_executor_and_reward, required=True) + _run("19 repair_library.json artifact", t19_repair_library_artifact, required=True) + _run("20 Hub upload round-trip", t20_hub_upload_roundtrip, required=False) + + print("\n=== Summary ===") + n_pass = sum(1 for r in _results if r[0] == PASS) + n_fail = sum(1 for r in _results if r[0] == FAIL) + n_skip = sum(1 for r in _results if r[0] == SKIP) + for tag, label, detail in _results: + print(f"{tag} {label}") + print(f"\n{n_pass} passed, {n_fail} failed, {n_skip} skipped") + return 0 if n_fail == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/submit_training_job.py b/scripts/submit_training_job.py index 9f9686cc7d45653e0cfdcd028be3c67231f5740d..d8d6cf1798a1fd028cbbf69ea2ac6e07c80a539e 100644 --- a/scripts/submit_training_job.py +++ b/scripts/submit_training_job.py @@ -1,202 +1,202 @@ -#!/usr/bin/env python -"""Submit ForgeEnv training as a HF Jobs run on A100 (or any flavor). - -Two stages: - -1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart - data + artifacts) to ``/forgeenv-source`` so the job can clone it. -2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the - chosen hardware via ``HfApi.run_uv_job``. Streams the job logs back to - your terminal until completion. - -Usage:: - - $env:HF_TOKEN = "hf_..." - python scripts/submit_training_job.py --user akhiilll --flavor a100-large - # add --dry-run to skip the actual submission and just publish source - # add --skip-publish to reuse the existing forgeenv-source repo - # tweak --sft-steps / --grpo-steps for a smoke test - -Costs (Hub jobs, before hackathon credits): - a100-large $0.0417/min (~$2.50/hr; full training ~$10-15) - a10g-large $0.0250/min (~$1.50/hr; full training ~$6-9, slower) - t4-small $0.0067/min (~$0.40/hr; smoke tests only) -""" -from __future__ import annotations - -import argparse -import os -import sys -import time -from pathlib import Path - -from huggingface_hub import HfApi, JobInfo - -REPO_ROOT = Path(__file__).resolve().parents[1] - - -def parse_args() -> argparse.Namespace: - ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - ap.add_argument("--user", default="akhiilll", help="HF username (owner of source/model repos)") - ap.add_argument("--flavor", default="a100-large", help="HF Jobs hardware flavor") - ap.add_argument("--sft-steps", type=int, default=1000) - ap.add_argument("--grpo-steps", type=int, default=200) - ap.add_argument("--base-model", default="Qwen/Qwen2.5-3B-Instruct") - ap.add_argument("--timeout", default="6h", help="job timeout (e.g. 30m, 2h, 6h)") - ap.add_argument("--skip-publish", action="store_true", help="reuse existing forgeenv-source repo") - ap.add_argument("--dry-run", action="store_true", help="publish source but do not launch the job") - ap.add_argument("--no-tail", action="store_true", help="skip log streaming after submission") - return ap.parse_args() - - -def publish_source(api: HfApi, token: str, user: str) -> str: - repo_id = f"{user}/forgeenv-source" - print(f"[launcher] publishing source -> {repo_id}", flush=True) - api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=False) - api.upload_folder( - folder_path=str(REPO_ROOT), - repo_id=repo_id, - repo_type="model", - token=token, - commit_message="forgeenv source snapshot for training job", - ignore_patterns=[ - "__pycache__", - "*.pyc", - ".pytest_cache", - ".venv", - "venv", - "*.egg-info", - ".git", - ".github", - "outputs", - "wandb", - "*.log", - ], - ) - print(f"[launcher] source live at https://huggingface.co/{repo_id}", flush=True) - return repo_id - - -def submit_job( - api: HfApi, - token: str, - user: str, - flavor: str, - sft_steps: int, - grpo_steps: int, - base_model: str, - timeout: str, -) -> JobInfo: - # The training script lives in the published source repo. Pass its - # raw Hub URL β€” `run_uv_job` accepts a URL/path/command, not the - # script body itself. - script_url = ( - f"https://huggingface.co/{user}/forgeenv-source/" - "resolve/main/scripts/jobs/train_repair_agent.py" - ) - - job = api.run_uv_job( - script=script_url, - dependencies=[ - "huggingface_hub>=0.27", - "requests", - ], - flavor=flavor, - timeout=timeout, - namespace=user, - env={ - "HF_USERNAME": user, - "ENV_URL": f"https://{user}-forgeenv.hf.space", - "SOURCE_REPO": f"{user}/forgeenv-source", - "MODEL_REPO": f"{user}/forgeenv-repair-agent", - "BASE_MODEL": base_model, - "SFT_STEPS": str(sft_steps), - "GRPO_STEPS": str(grpo_steps), - "PYTHONUNBUFFERED": "1", - }, - secrets={"HF_TOKEN": token}, - token=token, - ) - return job - - -_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"} - - -def _stage_of(info) -> str: - status = getattr(info, "status", None) - if status is None: - return "UNKNOWN" - stage = getattr(status, "stage", None) - if stage is None: - return str(status) - return str(stage) - - -def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int: - print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True) - try: - for line in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token): - print(line, flush=True) - except KeyboardInterrupt: - print("\n[launcher] log stream interrupted by user.", flush=True) - except Exception as e: # noqa: BLE001 - print(f"\n[launcher] log stream ended ({e}); polling status ...", flush=True) - - last_stage: str | None = None - while True: - info = api.inspect_job(job_id=job_id, namespace=namespace, token=token) - stage = _stage_of(info) - if stage != last_stage: - print(f"[launcher] status: {stage}", flush=True) - last_stage = stage - if stage in _TERMINAL_STAGES: - break - time.sleep(20) - - print(f"[launcher] final status: {last_stage}", flush=True) - return 0 if last_stage == "COMPLETED" else 1 - - -def main() -> int: - args = parse_args() - token = os.environ.get("HF_TOKEN") - if not token: - print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr) - return 2 - - api = HfApi() - - if not args.skip_publish: - publish_source(api, token, args.user) - - if args.dry_run: - print("[launcher] --dry-run set; not submitting job.", flush=True) - return 0 - - print( - f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, " - f"grpo={args.grpo_steps}, timeout={args.timeout}) ...", - flush=True, - ) - job = submit_job( - api=api, - token=token, - user=args.user, - flavor=args.flavor, - sft_steps=args.sft_steps, - grpo_steps=args.grpo_steps, - base_model=args.base_model, - timeout=args.timeout, - ) - job_id = getattr(job, "id", None) or getattr(job, "job_id", None) - print(f"[launcher] job submitted: id={job_id}", flush=True) - print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True) - - if args.no_tail: - return 0 - return tail_logs(api, token, job_id, namespace=args.user) - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python +"""Submit ForgeEnv training as a HF Jobs run on A100 (or any flavor). + +Two stages: + +1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart + data + artifacts) to ``/forgeenv-source`` so the job can clone it. +2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the + chosen hardware via ``HfApi.run_uv_job``. Streams the job logs back to + your terminal until completion. + +Usage:: + + $env:HF_TOKEN = "hf_..." + python scripts/submit_training_job.py --user akhiilll --flavor a100-large + # add --dry-run to skip the actual submission and just publish source + # add --skip-publish to reuse the existing forgeenv-source repo + # tweak --sft-steps / --grpo-steps for a smoke test + +Costs (Hub jobs, before hackathon credits): + a100-large $0.0417/min (~$2.50/hr; full training ~$10-15) + a10g-large $0.0250/min (~$1.50/hr; full training ~$6-9, slower) + t4-small $0.0067/min (~$0.40/hr; smoke tests only) +""" +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path + +from huggingface_hub import HfApi, JobInfo + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def parse_args() -> argparse.Namespace: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--user", default="akhiilll", help="HF username (owner of source/model repos)") + ap.add_argument("--flavor", default="a100-large", help="HF Jobs hardware flavor") + ap.add_argument("--sft-steps", type=int, default=1000) + ap.add_argument("--grpo-steps", type=int, default=200) + ap.add_argument("--base-model", default="Qwen/Qwen2.5-3B-Instruct") + ap.add_argument("--timeout", default="6h", help="job timeout (e.g. 30m, 2h, 6h)") + ap.add_argument("--skip-publish", action="store_true", help="reuse existing forgeenv-source repo") + ap.add_argument("--dry-run", action="store_true", help="publish source but do not launch the job") + ap.add_argument("--no-tail", action="store_true", help="skip log streaming after submission") + return ap.parse_args() + + +def publish_source(api: HfApi, token: str, user: str) -> str: + repo_id = f"{user}/forgeenv-source" + print(f"[launcher] publishing source -> {repo_id}", flush=True) + api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=False) + api.upload_folder( + folder_path=str(REPO_ROOT), + repo_id=repo_id, + repo_type="model", + token=token, + commit_message="forgeenv source snapshot for training job", + ignore_patterns=[ + "__pycache__", + "*.pyc", + ".pytest_cache", + ".venv", + "venv", + "*.egg-info", + ".git", + ".github", + "outputs", + "wandb", + "*.log", + ], + ) + print(f"[launcher] source live at https://huggingface.co/{repo_id}", flush=True) + return repo_id + + +def submit_job( + api: HfApi, + token: str, + user: str, + flavor: str, + sft_steps: int, + grpo_steps: int, + base_model: str, + timeout: str, +) -> JobInfo: + # The training script lives in the published source repo. Pass its + # raw Hub URL β€” `run_uv_job` accepts a URL/path/command, not the + # script body itself. + script_url = ( + f"https://huggingface.co/{user}/forgeenv-source/" + "resolve/main/scripts/jobs/train_repair_agent.py" + ) + + job = api.run_uv_job( + script=script_url, + dependencies=[ + "huggingface_hub>=0.27", + "requests", + ], + flavor=flavor, + timeout=timeout, + namespace=user, + env={ + "HF_USERNAME": user, + "ENV_URL": f"https://{user}-forgeenv.hf.space", + "SOURCE_REPO": f"{user}/forgeenv-source", + "MODEL_REPO": f"{user}/forgeenv-repair-agent", + "BASE_MODEL": base_model, + "SFT_STEPS": str(sft_steps), + "GRPO_STEPS": str(grpo_steps), + "PYTHONUNBUFFERED": "1", + }, + secrets={"HF_TOKEN": token}, + token=token, + ) + return job + + +_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"} + + +def _stage_of(info) -> str: + status = getattr(info, "status", None) + if status is None: + return "UNKNOWN" + stage = getattr(status, "stage", None) + if stage is None: + return str(status) + return str(stage) + + +def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int: + print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True) + try: + for line in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token): + print(line, flush=True) + except KeyboardInterrupt: + print("\n[launcher] log stream interrupted by user.", flush=True) + except Exception as e: # noqa: BLE001 + print(f"\n[launcher] log stream ended ({e}); polling status ...", flush=True) + + last_stage: str | None = None + while True: + info = api.inspect_job(job_id=job_id, namespace=namespace, token=token) + stage = _stage_of(info) + if stage != last_stage: + print(f"[launcher] status: {stage}", flush=True) + last_stage = stage + if stage in _TERMINAL_STAGES: + break + time.sleep(20) + + print(f"[launcher] final status: {last_stage}", flush=True) + return 0 if last_stage == "COMPLETED" else 1 + + +def main() -> int: + args = parse_args() + token = os.environ.get("HF_TOKEN") + if not token: + print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr) + return 2 + + api = HfApi() + + if not args.skip_publish: + publish_source(api, token, args.user) + + if args.dry_run: + print("[launcher] --dry-run set; not submitting job.", flush=True) + return 0 + + print( + f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, " + f"grpo={args.grpo_steps}, timeout={args.timeout}) ...", + flush=True, + ) + job = submit_job( + api=api, + token=token, + user=args.user, + flavor=args.flavor, + sft_steps=args.sft_steps, + grpo_steps=args.grpo_steps, + base_model=args.base_model, + timeout=args.timeout, + ) + job_id = getattr(job, "id", None) or getattr(job, "job_id", None) + print(f"[launcher] job submitted: id={job_id}", flush=True) + print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True) + + if args.no_tail: + return 0 + return tail_logs(api, token, job_id, namespace=args.user) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/tail_training_job.py b/scripts/tail_training_job.py index e60741e6a0b40070236fea447b518a5da567df7d..8d175209622c9037a283b0521ebfa6dc260d7c78 100644 --- a/scripts/tail_training_job.py +++ b/scripts/tail_training_job.py @@ -1,34 +1,34 @@ -#!/usr/bin/env python -"""Re-attach to an in-flight HF Jobs run and stream its logs. - -Usage:: - - $env:HF_TOKEN = "hf_..." - python scripts/tail_training_job.py 69ec88dfd70108f37acde39d -""" -from __future__ import annotations - -import os -import sys - -from huggingface_hub import HfApi - -from submit_training_job import tail_logs # type: ignore[import-not-found] - - -def main() -> int: - if len(sys.argv) < 2: - print("usage: python scripts/tail_training_job.py [namespace]", file=sys.stderr) - return 2 - job_id = sys.argv[1] - namespace = sys.argv[2] if len(sys.argv) > 2 else "akhiilll" - token = os.environ.get("HF_TOKEN") - if not token: - print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr) - return 2 - api = HfApi() - return tail_logs(api, token, job_id, namespace=namespace) - - -if __name__ == "__main__": - raise SystemExit(main()) +#!/usr/bin/env python +"""Re-attach to an in-flight HF Jobs run and stream its logs. + +Usage:: + + $env:HF_TOKEN = "hf_..." + python scripts/tail_training_job.py 69ec88dfd70108f37acde39d +""" +from __future__ import annotations + +import os +import sys + +from huggingface_hub import HfApi + +from submit_training_job import tail_logs # type: ignore[import-not-found] + + +def main() -> int: + if len(sys.argv) < 2: + print("usage: python scripts/tail_training_job.py [namespace]", file=sys.stderr) + return 2 + job_id = sys.argv[1] + namespace = sys.argv[2] if len(sys.argv) > 2 else "akhiilll" + token = os.environ.get("HF_TOKEN") + if not token: + print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr) + return 2 + api = HfApi() + return tail_logs(api, token, job_id, namespace=namespace) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/test_live_env.py b/scripts/test_live_env.py index 58bf7c6b82e27301ac6870942b7d21b2d9e6f2dd..97f33b323c96edbfbc2c3b5ad06186cd721f6879 100644 --- a/scripts/test_live_env.py +++ b/scripts/test_live_env.py @@ -1,76 +1,76 @@ -"""Smoke-test the live ForgeEnv Space end-to-end via the OpenEnv client. - -Runs one full episode against the deployed Space: - - reset() -> drift-gen turn - step(DriftAction) -> repair turn - step(RepairAction) -> reward + verifier breakdown - -This is the simplest possible "is the deployed env working?" check -and a clean standalone artifact for the hackathon writeup/video. - -Usage:: - - python scripts/test_live_env.py -""" -from __future__ import annotations - -import asyncio -import json - -from openenv.core import GenericAction, GenericEnvClient - -ENV_URL = "https://akhiilll-forgeenv.hf.space" - - -def _summary(result, label: str) -> None: - obs = result.observation if isinstance(result.observation, dict) else {} - print(f"\n=== {label} ===") - print(f"phase : {obs.get('current_phase')}") - print(f"task_id : {obs.get('task_id')}") - print(f"target_category : {obs.get('target_category')}") - print(f"reward : {result.reward}") - print(f"done : {result.done}") - breakdown = obs.get("reward_breakdown") - if breakdown: - print("reward_breakdown:") - print(json.dumps(breakdown, indent=2)) - script = obs.get("script_content") or obs.get("broken_script") or "" - if script: - preview = script.splitlines()[:8] - print("script preview :") - for line in preview: - print(f" | {line}") - if len(script.splitlines()) > 8: - print(" | ...") - - -async def main(seed: int = 42) -> None: - print(f"connecting to {ENV_URL} (seed={seed}) ...") - client = GenericEnvClient(base_url=ENV_URL) - - res = await client.reset(seed=seed, options={"difficulty": "medium"}) - _summary(res, "after reset()") - target = res.observation.get("target_category", "RenameApiCall") if isinstance(res.observation, dict) else "RenameApiCall" - - res = await client.step(GenericAction( - breakage={"action_type": "breakage", "primitive_type": target, "params": {}}, - repair=None, - )) - _summary(res, "after drift step (Challenger)") - - # empty diff = no-op repair: shows the verifier marking the script as still broken - res = await client.step(GenericAction( - breakage=None, - repair={"action_type": "repair", "unified_diff": ""}, - )) - _summary(res, "after repair step (Solver, no-op)") - - print("\nOK -- reset + 2 steps round-trip the deployed env.") - - -if __name__ == "__main__": - import sys - - seed = int(sys.argv[1]) if len(sys.argv) > 1 else 42 - asyncio.run(main(seed=seed)) +"""Smoke-test the live ForgeEnv Space end-to-end via the OpenEnv client. + +Runs one full episode against the deployed Space: + + reset() -> drift-gen turn + step(DriftAction) -> repair turn + step(RepairAction) -> reward + verifier breakdown + +This is the simplest possible "is the deployed env working?" check +and a clean standalone artifact for the hackathon writeup/video. + +Usage:: + + python scripts/test_live_env.py +""" +from __future__ import annotations + +import asyncio +import json + +from openenv.core import GenericAction, GenericEnvClient + +ENV_URL = "https://akhiilll-forgeenv.hf.space" + + +def _summary(result, label: str) -> None: + obs = result.observation if isinstance(result.observation, dict) else {} + print(f"\n=== {label} ===") + print(f"phase : {obs.get('current_phase')}") + print(f"task_id : {obs.get('task_id')}") + print(f"target_category : {obs.get('target_category')}") + print(f"reward : {result.reward}") + print(f"done : {result.done}") + breakdown = obs.get("reward_breakdown") + if breakdown: + print("reward_breakdown:") + print(json.dumps(breakdown, indent=2)) + script = obs.get("script_content") or obs.get("broken_script") or "" + if script: + preview = script.splitlines()[:8] + print("script preview :") + for line in preview: + print(f" | {line}") + if len(script.splitlines()) > 8: + print(" | ...") + + +async def main(seed: int = 42) -> None: + print(f"connecting to {ENV_URL} (seed={seed}) ...") + client = GenericEnvClient(base_url=ENV_URL) + + res = await client.reset(seed=seed, options={"difficulty": "medium"}) + _summary(res, "after reset()") + target = res.observation.get("target_category", "RenameApiCall") if isinstance(res.observation, dict) else "RenameApiCall" + + res = await client.step(GenericAction( + breakage={"action_type": "breakage", "primitive_type": target, "params": {}}, + repair=None, + )) + _summary(res, "after drift step (Challenger)") + + # empty diff = no-op repair: shows the verifier marking the script as still broken + res = await client.step(GenericAction( + breakage=None, + repair={"action_type": "repair", "unified_diff": ""}, + )) + _summary(res, "after repair step (Solver, no-op)") + + print("\nOK -- reset + 2 steps round-trip the deployed env.") + + +if __name__ == "__main__": + import sys + + seed = int(sys.argv[1]) if len(sys.argv) > 1 else 42 + asyncio.run(main(seed=seed)) diff --git a/scripts/test_repair_agent.py b/scripts/test_repair_agent.py index a458ac9efe090c93cc2f4e3f7541bfbe4d51d1d5..2e6d69f55290a1504f09c3216002e0cbe02e0da6 100644 --- a/scripts/test_repair_agent.py +++ b/scripts/test_repair_agent.py @@ -1,123 +1,123 @@ -"""Smoke-test the trained Repair Agent locally on one episode. - -Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits -the live ForgeEnv Space for a fresh broken script, asks the model to -emit a unified diff, applies it, and prints the verifier breakdown. - -Usage:: - - python scripts/test_repair_agent.py --seed 7 - python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct - -Requires GPU + transformers/peft. Skip this if you only want a quick -demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead. -""" -from __future__ import annotations - -import argparse -import asyncio -import json - -from openenv.core import GenericAction, GenericEnvClient - -ENV_URL = "https://akhiilll-forgeenv.hf.space" -LORA_REPO = "akhiilll/forgeenv-repair-agent" - -REPAIR_PROMPT = """\ -You are a senior ML engineer fixing a HuggingFace training script that just broke. -Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the -breakage signaled by the error trace. No prose, no fences, no explanation. - -# Broken script -```python -{script} -``` - -# Error trace -``` -{error} -``` - -# Diff -""" - - -async def fetch_broken_episode(seed: int): - client = GenericEnvClient(base_url=ENV_URL) - res = await client.reset(seed=seed, options={"difficulty": "medium"}) - target = res.observation["target_category"] - res = await client.step(GenericAction( - breakage={"action_type": "breakage", "primitive_type": target, "params": {}}, - repair=None, - )) - obs = res.observation - return client, obs.get("script_content") or obs.get("broken_script") or "", obs.get("error_trace", "") - - -async def submit_repair(client: GenericEnvClient, diff: str): - res = await client.step(GenericAction( - breakage=None, - repair={"action_type": "repair", "unified_diff": diff}, - )) - return res - - -def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str: - import torch - from peft import PeftModel - from transformers import AutoModelForCausalLM, AutoTokenizer - - print(f"loading base model: {base_model}") - tok = AutoTokenizer.from_pretrained(base_model) - model = AutoModelForCausalLM.from_pretrained( - base_model, - torch_dtype=torch.bfloat16, - device_map="auto", - ) - print(f"attaching LoRA: {lora_repo}") - model = PeftModel.from_pretrained(model, lora_repo) - model.eval() - - inputs = tok(prompt, return_tensors="pt").to(model.device) - with torch.no_grad(): - out = model.generate( - **inputs, - max_new_tokens=512, - do_sample=False, - temperature=0.0, - pad_token_id=tok.eos_token_id, - ) - text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) - return text.strip() - - -async def main(args) -> None: - print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}") - client, broken_script, error_trace = await fetch_broken_episode(args.seed) - if not broken_script: - raise SystemExit("env returned empty script_content; pick a different seed") - print(f"broken script length: {len(broken_script)} chars") - print(f"error trace : {(error_trace[:200] + '...') if len(error_trace) > 200 else error_trace}") - - prompt = REPAIR_PROMPT.format(script=broken_script, error=error_trace or "") - diff = generate_diff(args.base_model, args.lora_repo, prompt) - - print("\n=== model diff ===") - print(diff) - - print("\n=== submitting diff to env ===") - res = await submit_repair(client, diff) - print(f"reward: {res.reward} done: {res.done}") - breakdown = res.observation.get("reward_breakdown") if isinstance(res.observation, dict) else None - if breakdown: - print("reward_breakdown:") - print(json.dumps(breakdown, indent=2)) - - -if __name__ == "__main__": - p = argparse.ArgumentParser() - p.add_argument("--seed", type=int, default=7) - p.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct") - p.add_argument("--lora-repo", default=LORA_REPO) - args = p.parse_args() - asyncio.run(main(args)) +"""Smoke-test the trained Repair Agent locally on one episode. + +Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits +the live ForgeEnv Space for a fresh broken script, asks the model to +emit a unified diff, applies it, and prints the verifier breakdown. + +Usage:: + + python scripts/test_repair_agent.py --seed 7 + python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct + +Requires GPU + transformers/peft. Skip this if you only want a quick +demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead. +""" +from __future__ import annotations + +import argparse +import asyncio +import json + +from openenv.core import GenericAction, GenericEnvClient + +ENV_URL = "https://akhiilll-forgeenv.hf.space" +LORA_REPO = "akhiilll/forgeenv-repair-agent" + +REPAIR_PROMPT = """\ +You are a senior ML engineer fixing a HuggingFace training script that just broke. +Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the +breakage signaled by the error trace. No prose, no fences, no explanation. + +# Broken script +```python +{script} +``` + +# Error trace +``` +{error} +``` + +# Diff +""" + + +async def fetch_broken_episode(seed: int): + client = GenericEnvClient(base_url=ENV_URL) + res = await client.reset(seed=seed, options={"difficulty": "medium"}) + target = res.observation["target_category"] + res = await client.step(GenericAction( + breakage={"action_type": "breakage", "primitive_type": target, "params": {}}, + repair=None, + )) + obs = res.observation + return client, obs.get("script_content") or obs.get("broken_script") or "", obs.get("error_trace", "") + + +async def submit_repair(client: GenericEnvClient, diff: str): + res = await client.step(GenericAction( + breakage=None, + repair={"action_type": "repair", "unified_diff": diff}, + )) + return res + + +def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str: + import torch + from peft import PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + print(f"loading base model: {base_model}") + tok = AutoTokenizer.from_pretrained(base_model) + model = AutoModelForCausalLM.from_pretrained( + base_model, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + print(f"attaching LoRA: {lora_repo}") + model = PeftModel.from_pretrained(model, lora_repo) + model.eval() + + inputs = tok(prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + out = model.generate( + **inputs, + max_new_tokens=512, + do_sample=False, + temperature=0.0, + pad_token_id=tok.eos_token_id, + ) + text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) + return text.strip() + + +async def main(args) -> None: + print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}") + client, broken_script, error_trace = await fetch_broken_episode(args.seed) + if not broken_script: + raise SystemExit("env returned empty script_content; pick a different seed") + print(f"broken script length: {len(broken_script)} chars") + print(f"error trace : {(error_trace[:200] + '...') if len(error_trace) > 200 else error_trace}") + + prompt = REPAIR_PROMPT.format(script=broken_script, error=error_trace or "") + diff = generate_diff(args.base_model, args.lora_repo, prompt) + + print("\n=== model diff ===") + print(diff) + + print("\n=== submitting diff to env ===") + res = await submit_repair(client, diff) + print(f"reward: {res.reward} done: {res.done}") + breakdown = res.observation.get("reward_breakdown") if isinstance(res.observation, dict) else None + if breakdown: + print("reward_breakdown:") + print(json.dumps(breakdown, indent=2)) + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--seed", type=int, default=7) + p.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct") + p.add_argument("--lora-repo", default=LORA_REPO) + args = p.parse_args() + asyncio.run(main(args)) diff --git a/tests/test_ast_validator.py b/tests/test_ast_validator.py index bb40db42516e124f9a4aeb8c66f1472335b54dfa..9ed4e94b0559d3cff9760b12c65684caa42951af 100644 --- a/tests/test_ast_validator.py +++ b/tests/test_ast_validator.py @@ -1,69 +1,69 @@ -"""Tests for the AST-based forbidden-pattern validator.""" -from forgeenv.sandbox.ast_validator import validate_script - - -def test_clean_script_passes(): - script = """ -import torch -from transformers import Trainer -model = Trainer() -""" - result = validate_script(script) - assert result.is_valid, f"Clean script should pass: {result.violations}" - - -def test_os_import_fails(): - script = "import os\nos.system('rm -rf /')" - result = validate_script(script) - assert not result.is_valid - assert any("os" in v for v in result.violations) - - -def test_subprocess_fails(): - script = "import subprocess\nsubprocess.run(['ls'])" - result = validate_script(script) - assert not result.is_valid - - -def test_eval_fails(): - script = "result = eval('1+1')" - result = validate_script(script) - assert not result.is_valid - assert any("eval" in v for v in result.violations) - - -def test_syntax_error_fails(): - script = "def foo(\n broken syntax" - result = validate_script(script) - assert not result.is_valid - assert any("SyntaxError" in v for v in result.violations) - - -def test_transformers_import_passes(): - script = """ -from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments -from datasets import load_dataset -import torch -""" - result = validate_script(script) - assert result.is_valid - - -def test_socket_import_fails(): - script = "import socket\ns = socket.socket()" - result = validate_script(script) - assert not result.is_valid - - -def test_builtins_assignment_fails(): - script = "__builtins__ = {}" - result = validate_script(script) - assert not result.is_valid - - -def test_attribute_eval_fails(): - """eval accessed via attribute (e.g. ast.literal_eval is fine, but - something.eval() of certain shape should be flagged when name is exec).""" - script = "obj.exec('rm -rf /')" - result = validate_script(script) - assert not result.is_valid +"""Tests for the AST-based forbidden-pattern validator.""" +from forgeenv.sandbox.ast_validator import validate_script + + +def test_clean_script_passes(): + script = """ +import torch +from transformers import Trainer +model = Trainer() +""" + result = validate_script(script) + assert result.is_valid, f"Clean script should pass: {result.violations}" + + +def test_os_import_fails(): + script = "import os\nos.system('rm -rf /')" + result = validate_script(script) + assert not result.is_valid + assert any("os" in v for v in result.violations) + + +def test_subprocess_fails(): + script = "import subprocess\nsubprocess.run(['ls'])" + result = validate_script(script) + assert not result.is_valid + + +def test_eval_fails(): + script = "result = eval('1+1')" + result = validate_script(script) + assert not result.is_valid + assert any("eval" in v for v in result.violations) + + +def test_syntax_error_fails(): + script = "def foo(\n broken syntax" + result = validate_script(script) + assert not result.is_valid + assert any("SyntaxError" in v for v in result.violations) + + +def test_transformers_import_passes(): + script = """ +from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments +from datasets import load_dataset +import torch +""" + result = validate_script(script) + assert result.is_valid + + +def test_socket_import_fails(): + script = "import socket\ns = socket.socket()" + result = validate_script(script) + assert not result.is_valid + + +def test_builtins_assignment_fails(): + script = "__builtins__ = {}" + result = validate_script(script) + assert not result.is_valid + + +def test_attribute_eval_fails(): + """eval accessed via attribute (e.g. ast.literal_eval is fine, but + something.eval() of certain shape should be flagged when name is exec).""" + script = "obj.exec('rm -rf /')" + result = validate_script(script) + assert not result.is_valid diff --git a/tests/test_environment.py b/tests/test_environment.py index 25d3fe01965bf83463920809d267fd8643e1c880..53d2d96a3594f19fb94a7f697971e767db576845 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,122 +1,122 @@ -"""End-to-end tests for the OpenEnv-wrapped ForgeEnvironment.""" -import pytest - -from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction -from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.env.observations import ForgeObservation -from forgeenv.roles.teacher import Teacher - - -def test_reset_returns_drift_gen_observation(): - env = ForgeEnvironment(seed=0) - obs = env.reset() - - assert isinstance(obs, ForgeObservation) - assert obs.current_phase == "drift_gen" - assert obs.done is False - assert obs.script_content - from forgeenv.primitives.breakage_primitives import PRIMITIVE_REGISTRY - - assert obs.target_category in PRIMITIVE_REGISTRY - assert obs.library_versions # non-empty dict - - -def test_full_episode_lifecycle(): - env = ForgeEnvironment(seed=0) - obs = env.reset() - initial_script = obs.script_content - - breakage = ForgeAction( - breakage=BreakageAction( - primitive_type="DeprecateImport", - params={"old_module": "import torch", "new_module": "import torch.legacy"}, - ) - ) - obs2 = env.step(breakage) - assert obs2.current_phase == "repair" - assert obs2.done is False - assert obs2.info.get("breakage_primitive") == "DeprecateImport" - assert obs2.error_trace is not None - - repair = ForgeAction(repair=RepairAction(unified_diff=initial_script)) - obs3 = env.step(repair) - assert obs3.current_phase == "done" - assert obs3.done is True - assert obs3.reward is not None - assert isinstance(obs3.reward_breakdown, dict) - assert isinstance(obs3.held_out_breakdown, dict) - assert {"executed_cleanly", "checkpoint_valid"} <= set(obs3.held_out_breakdown) - - -def test_invalid_action_for_phase(): - env = ForgeEnvironment(seed=0) - env.reset() - repair_first = ForgeAction(repair=RepairAction(unified_diff="print('hi')")) - obs = env.step(repair_first) - # Should not raise β€” should return a done=True error observation. - assert obs.done is True - assert obs.info.get("error") is not None - - -def test_step_before_reset_returns_error(): - env = ForgeEnvironment(seed=0) - breakage = ForgeAction( - breakage=BreakageAction(primitive_type="RenameApiCall", params={}) - ) - obs = env.step(breakage) - assert obs.done is True - assert obs.info.get("error") - - -def test_state_property_is_dict(): - env = ForgeEnvironment(seed=0) - env.reset() - state = env.state - assert isinstance(state, dict) - assert "phase" in state and "library_versions" in state and "teacher" in state - - -def test_action_validation_rejects_both_or_neither(): - with pytest.raises(Exception): - ForgeAction() - with pytest.raises(Exception): - ForgeAction( - breakage=BreakageAction(primitive_type="RenameApiCall", params={}), - repair=RepairAction(unified_diff="x"), - ) - - -def test_teacher_updates_after_episode(): - teacher = Teacher(categories=["RenameApiCall"]) - env = ForgeEnvironment(teacher=teacher, seed=0) - env.reset() - env.step( - ForgeAction( - breakage=BreakageAction( - primitive_type="RenameApiCall", - params={"old_name": "x", "new_name": "y"}, - ) - ) - ) - env.step(ForgeAction(repair=RepairAction(unified_diff="print('noop')"))) - state = teacher.get_state() - assert any(s["attempts"] >= 1 for s in state.values()) - - -def test_unified_diff_round_trip(): - before = "hello\nworld\n" - after = "hello\nplanet\n" - diff = make_unified_diff(before, after) - repaired = apply_unified_diff(before, diff) - assert repaired == after - - -def test_unified_diff_full_script_replacement(): - full_script = """import torch -from transformers import Trainer -trainer = Trainer() -trainer.train() -""" - repaired = apply_unified_diff("broken stuff", full_script) - assert repaired == full_script +"""End-to-end tests for the OpenEnv-wrapped ForgeEnvironment.""" +import pytest + +from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction +from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.env.observations import ForgeObservation +from forgeenv.roles.teacher import Teacher + + +def test_reset_returns_drift_gen_observation(): + env = ForgeEnvironment(seed=0) + obs = env.reset() + + assert isinstance(obs, ForgeObservation) + assert obs.current_phase == "drift_gen" + assert obs.done is False + assert obs.script_content + from forgeenv.primitives.breakage_primitives import PRIMITIVE_REGISTRY + + assert obs.target_category in PRIMITIVE_REGISTRY + assert obs.library_versions # non-empty dict + + +def test_full_episode_lifecycle(): + env = ForgeEnvironment(seed=0) + obs = env.reset() + initial_script = obs.script_content + + breakage = ForgeAction( + breakage=BreakageAction( + primitive_type="DeprecateImport", + params={"old_module": "import torch", "new_module": "import torch.legacy"}, + ) + ) + obs2 = env.step(breakage) + assert obs2.current_phase == "repair" + assert obs2.done is False + assert obs2.info.get("breakage_primitive") == "DeprecateImport" + assert obs2.error_trace is not None + + repair = ForgeAction(repair=RepairAction(unified_diff=initial_script)) + obs3 = env.step(repair) + assert obs3.current_phase == "done" + assert obs3.done is True + assert obs3.reward is not None + assert isinstance(obs3.reward_breakdown, dict) + assert isinstance(obs3.held_out_breakdown, dict) + assert {"executed_cleanly", "checkpoint_valid"} <= set(obs3.held_out_breakdown) + + +def test_invalid_action_for_phase(): + env = ForgeEnvironment(seed=0) + env.reset() + repair_first = ForgeAction(repair=RepairAction(unified_diff="print('hi')")) + obs = env.step(repair_first) + # Should not raise β€” should return a done=True error observation. + assert obs.done is True + assert obs.info.get("error") is not None + + +def test_step_before_reset_returns_error(): + env = ForgeEnvironment(seed=0) + breakage = ForgeAction( + breakage=BreakageAction(primitive_type="RenameApiCall", params={}) + ) + obs = env.step(breakage) + assert obs.done is True + assert obs.info.get("error") + + +def test_state_property_is_dict(): + env = ForgeEnvironment(seed=0) + env.reset() + state = env.state + assert isinstance(state, dict) + assert "phase" in state and "library_versions" in state and "teacher" in state + + +def test_action_validation_rejects_both_or_neither(): + with pytest.raises(Exception): + ForgeAction() + with pytest.raises(Exception): + ForgeAction( + breakage=BreakageAction(primitive_type="RenameApiCall", params={}), + repair=RepairAction(unified_diff="x"), + ) + + +def test_teacher_updates_after_episode(): + teacher = Teacher(categories=["RenameApiCall"]) + env = ForgeEnvironment(teacher=teacher, seed=0) + env.reset() + env.step( + ForgeAction( + breakage=BreakageAction( + primitive_type="RenameApiCall", + params={"old_name": "x", "new_name": "y"}, + ) + ) + ) + env.step(ForgeAction(repair=RepairAction(unified_diff="print('noop')"))) + state = teacher.get_state() + assert any(s["attempts"] >= 1 for s in state.values()) + + +def test_unified_diff_round_trip(): + before = "hello\nworld\n" + after = "hello\nplanet\n" + diff = make_unified_diff(before, after) + repaired = apply_unified_diff(before, diff) + assert repaired == after + + +def test_unified_diff_full_script_replacement(): + full_script = """import torch +from transformers import Trainer +trainer = Trainer() +trainer.train() +""" + repaired = apply_unified_diff("broken stuff", full_script) + assert repaired == full_script diff --git a/tests/test_evaluators.py b/tests/test_evaluators.py index 46f2bf0157c43fc65f18121289ae169077b56209..2934815a3c7de917621855935888627f6df644d8 100644 --- a/tests/test_evaluators.py +++ b/tests/test_evaluators.py @@ -1,138 +1,138 @@ -"""Tests for visible verifier, held-out evaluator, and R-Zero reward functions.""" -from forgeenv.tasks.models import ExecutionResult, Task -from forgeenv.training.reward_functions import ( - compute_alignment_score, - compute_drift_gen_reward, - compute_repetition_penalty, - compute_uncertainty_reward, -) -from forgeenv.verifier.held_out_evaluator import compute_held_out_scores -from forgeenv.verifier.visible_verifier import compute_visible_reward - -SAMPLE_TASK = Task( - task_id="test_001", - description="Test task", - script_content=( - "from transformers import Trainer\n" - "trainer = Trainer()\n" - "trainer.train()\n" - ), - difficulty="easy", -) - - -def test_visible_reward_success(): - result = ExecutionResult( - exit_code=0, - stdout="step=1 loss=3.5\nstep=2 loss=2.1\nTRAINING_COMPLETE", - stderr="", - wall_time_ms=1000, - checkpoint_exists=True, - script_content=SAMPLE_TASK.script_content, - ) - reward, breakdown = compute_visible_reward(result, SAMPLE_TASK) - assert reward > 0, f"Successful run should have positive reward, got {reward}" - assert breakdown["script_executes"] == 1.0 - assert breakdown["loss_decreased"] > 0 - - -def test_visible_reward_failure(): - result = ExecutionResult( - exit_code=1, - stdout="", - stderr="Error", - wall_time_ms=100, - script_content=SAMPLE_TASK.script_content, - ) - reward, breakdown = compute_visible_reward(result, SAMPLE_TASK) - assert breakdown["script_executes"] == 0.0 - assert reward <= 0.0 - - -def test_held_out_success(): - result = ExecutionResult( - exit_code=0, - stdout="step=1 loss=3.5\nstep=2 loss=2.1\neval_accuracy=0.78\nTRAINING_COMPLETE", - stderr="", - wall_time_ms=1000, - checkpoint_exists=True, - script_content=SAMPLE_TASK.script_content, - ) - scores = compute_held_out_scores(result, SAMPLE_TASK) - assert scores["executed_cleanly"] == 1.0 - assert scores["loss_decreased"] > 0 - assert scores["hidden_tests_passed"] == 1.0 - assert scores["intent_preserved"] == 1.0 - - -def test_held_out_workaround_detection(): - """Bare except wrapping all code should reduce no_forbidden_workarounds.""" - result = ExecutionResult( - exit_code=0, - stdout="TRAINING_COMPLETE", - stderr="", - wall_time_ms=100, - checkpoint_exists=True, - script_content="try:\n bad()\nexcept:\n pass\n", - ) - scores = compute_held_out_scores(result, SAMPLE_TASK) - assert scores["no_forbidden_workarounds"] < 1.0 - - -def test_uncertainty_peaks_at_half(): - r_half = compute_uncertainty_reward([True, False, True, False, True, False]) - r_all = compute_uncertainty_reward([True, True, True, True]) - r_none = compute_uncertainty_reward([False, False, False, False]) - - assert r_half > r_all - assert r_half > r_none - assert abs(r_all) < 0.01 - assert abs(r_none) < 0.01 - - -def test_uncertainty_handles_empty(): - assert compute_uncertainty_reward([]) == 0.0 - - -def test_repetition_penalty_higher_for_duplicates(): - batch = [ - "rename evaluate to eval_model", - "rename evaluate to eval_model", - "rename evaluate to eval_model", - "change import path for trainer", - ] - p_dup = compute_repetition_penalty(batch[0], batch) - p_unique = compute_repetition_penalty(batch[3], batch) - assert p_dup >= p_unique - - -def test_drift_gen_reward_combines_signals(): - """Composite reward should rise with uncertainty and fall with repetition.""" - high_unc_unique = compute_drift_gen_reward( - "unique unique unique tokens", - [True, False, True, False], - ["totally different a b c", "unique unique unique tokens"], - ) - high_unc_repeated = compute_drift_gen_reward( - "same same same same", - [True, False, True, False], - ["same same same same", "same same same same", "same same same same"], - ) - assert high_unc_unique >= high_unc_repeated - - -def test_alignment_score_perfect_correlation(): - visible = [0.0, 0.25, 0.5, 0.75, 1.0] - held_out = [0.0, 0.25, 0.5, 0.75, 1.0] - assert compute_alignment_score(visible, held_out) > 0.99 - - -def test_alignment_score_anti_correlation(): - visible = [1.0, 0.5, 0.0] - held_out = [0.0, 0.5, 1.0] - assert compute_alignment_score(visible, held_out) < -0.99 - - -def test_alignment_score_constant_returns_zero(): - """No variance in either array β†’ no signal β†’ 0.0.""" - assert compute_alignment_score([0.5, 0.5, 0.5], [0.1, 0.9, 0.4]) == 0.0 +"""Tests for visible verifier, held-out evaluator, and R-Zero reward functions.""" +from forgeenv.tasks.models import ExecutionResult, Task +from forgeenv.training.reward_functions import ( + compute_alignment_score, + compute_drift_gen_reward, + compute_repetition_penalty, + compute_uncertainty_reward, +) +from forgeenv.verifier.held_out_evaluator import compute_held_out_scores +from forgeenv.verifier.visible_verifier import compute_visible_reward + +SAMPLE_TASK = Task( + task_id="test_001", + description="Test task", + script_content=( + "from transformers import Trainer\n" + "trainer = Trainer()\n" + "trainer.train()\n" + ), + difficulty="easy", +) + + +def test_visible_reward_success(): + result = ExecutionResult( + exit_code=0, + stdout="step=1 loss=3.5\nstep=2 loss=2.1\nTRAINING_COMPLETE", + stderr="", + wall_time_ms=1000, + checkpoint_exists=True, + script_content=SAMPLE_TASK.script_content, + ) + reward, breakdown = compute_visible_reward(result, SAMPLE_TASK) + assert reward > 0, f"Successful run should have positive reward, got {reward}" + assert breakdown["script_executes"] == 1.0 + assert breakdown["loss_decreased"] > 0 + + +def test_visible_reward_failure(): + result = ExecutionResult( + exit_code=1, + stdout="", + stderr="Error", + wall_time_ms=100, + script_content=SAMPLE_TASK.script_content, + ) + reward, breakdown = compute_visible_reward(result, SAMPLE_TASK) + assert breakdown["script_executes"] == 0.0 + assert reward <= 0.0 + + +def test_held_out_success(): + result = ExecutionResult( + exit_code=0, + stdout="step=1 loss=3.5\nstep=2 loss=2.1\neval_accuracy=0.78\nTRAINING_COMPLETE", + stderr="", + wall_time_ms=1000, + checkpoint_exists=True, + script_content=SAMPLE_TASK.script_content, + ) + scores = compute_held_out_scores(result, SAMPLE_TASK) + assert scores["executed_cleanly"] == 1.0 + assert scores["loss_decreased"] > 0 + assert scores["hidden_tests_passed"] == 1.0 + assert scores["intent_preserved"] == 1.0 + + +def test_held_out_workaround_detection(): + """Bare except wrapping all code should reduce no_forbidden_workarounds.""" + result = ExecutionResult( + exit_code=0, + stdout="TRAINING_COMPLETE", + stderr="", + wall_time_ms=100, + checkpoint_exists=True, + script_content="try:\n bad()\nexcept:\n pass\n", + ) + scores = compute_held_out_scores(result, SAMPLE_TASK) + assert scores["no_forbidden_workarounds"] < 1.0 + + +def test_uncertainty_peaks_at_half(): + r_half = compute_uncertainty_reward([True, False, True, False, True, False]) + r_all = compute_uncertainty_reward([True, True, True, True]) + r_none = compute_uncertainty_reward([False, False, False, False]) + + assert r_half > r_all + assert r_half > r_none + assert abs(r_all) < 0.01 + assert abs(r_none) < 0.01 + + +def test_uncertainty_handles_empty(): + assert compute_uncertainty_reward([]) == 0.0 + + +def test_repetition_penalty_higher_for_duplicates(): + batch = [ + "rename evaluate to eval_model", + "rename evaluate to eval_model", + "rename evaluate to eval_model", + "change import path for trainer", + ] + p_dup = compute_repetition_penalty(batch[0], batch) + p_unique = compute_repetition_penalty(batch[3], batch) + assert p_dup >= p_unique + + +def test_drift_gen_reward_combines_signals(): + """Composite reward should rise with uncertainty and fall with repetition.""" + high_unc_unique = compute_drift_gen_reward( + "unique unique unique tokens", + [True, False, True, False], + ["totally different a b c", "unique unique unique tokens"], + ) + high_unc_repeated = compute_drift_gen_reward( + "same same same same", + [True, False, True, False], + ["same same same same", "same same same same", "same same same same"], + ) + assert high_unc_unique >= high_unc_repeated + + +def test_alignment_score_perfect_correlation(): + visible = [0.0, 0.25, 0.5, 0.75, 1.0] + held_out = [0.0, 0.25, 0.5, 0.75, 1.0] + assert compute_alignment_score(visible, held_out) > 0.99 + + +def test_alignment_score_anti_correlation(): + visible = [1.0, 0.5, 0.0] + held_out = [0.0, 0.5, 1.0] + assert compute_alignment_score(visible, held_out) < -0.99 + + +def test_alignment_score_constant_returns_zero(): + """No variance in either array β†’ no signal β†’ 0.0.""" + assert compute_alignment_score([0.5, 0.5, 0.5], [0.1, 0.9, 0.4]) == 0.0 diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 90d1951815daeae69a4fd69108c2c11276c89215..f3d9628a0c50edce5f2722655457bd7c090d934e 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -1,215 +1,215 @@ -"""Tests for breakage primitives, repair primitives, and task sampler.""" -import pytest - -from forgeenv.primitives.breakage_primitives import ( - ChangeArgumentSignature, - ChangeReturnType, - ChangeTokenizerBehavior, - DeprecateImport, - ModifyConfigField, - PRIMITIVE_REGISTRY, - RemoveDeprecatedMethod, - RenameApiCall, - RestructureDatasetSchema, - parse_breakage_spec, -) -from forgeenv.primitives.repair_primitives import ( - BREAKAGE_TO_REPAIR, - REPAIR_REGISTRY, - RestoreApiCall, - RestoreColumn, - RestoreImport, - RestoreMethod, -) -from forgeenv.tasks.task_sampler import TaskSampler - -SAMPLE_SCRIPT = """ -from transformers import Trainer, TrainingArguments -from datasets import load_dataset - -dataset = load_dataset("glue", "sst2") -dataset = dataset.rename_column("label", "labels") - -args = TrainingArguments(num_train_epochs=3, report_to="none") -trainer = Trainer(model=model, args=args, train_dataset=dataset) -trainer.train() -result = trainer.evaluate() -""" - - -def test_rename_api_call_word_boundary(): - """Renaming should not break out-of-context substrings.""" - b = RenameApiCall(old_name="evaluate", new_name="eval_model") - broken = b.apply(SAMPLE_SCRIPT) - assert "eval_model" in broken - assert "trainer.evaluate" not in broken - # Inverse should restore - r = RestoreApiCall(new_name="eval_model", old_name="evaluate") - restored = r.apply(broken) - assert restored.strip() == SAMPLE_SCRIPT.strip() - - -def test_deprecate_import(): - b = DeprecateImport( - old_module="from transformers import", - new_module="from transformers.legacy import", - ) - broken = b.apply(SAMPLE_SCRIPT) - assert "transformers.legacy" in broken - r = RestoreImport( - new_module="from transformers.legacy import", - old_module="from transformers import", - ) - restored = r.apply(broken) - assert restored == SAMPLE_SCRIPT - - -def test_restructure_dataset_string_replacement(): - b = RestructureDatasetSchema(old_column="label", new_column="sentiment_label") - broken = b.apply(SAMPLE_SCRIPT) - assert '"sentiment_label"' in broken - assert '"label"' not in broken - r = RestoreColumn(new_column="sentiment_label", old_column="label") - restored = r.apply(broken) - assert restored == SAMPLE_SCRIPT - - -def test_modify_config_field_changes_value(): - b = ModifyConfigField( - config_class="TrainingArguments", - field_name="num_train_epochs", - new_value="999", - ) - broken = b.apply(SAMPLE_SCRIPT) - assert "num_train_epochs=999" in broken - - -def test_change_tokenizer_behavior_replaces_kwarg(): - script = "tok = tokenizer(text, padding=True, truncation=True)" - b = ChangeTokenizerBehavior( - old_kwarg="padding", - old_value="True", - new_kwarg="padding", - new_value='"max_length"', - ) - broken = b.apply(script) - assert 'padding="max_length"' in broken - - -def test_remove_deprecated_method_marks_call(): - b = RemoveDeprecatedMethod( - class_name="Trainer", method_name="evaluate", replacement="evaluate_legacy" - ) - broken = b.apply(SAMPLE_SCRIPT) - assert ".evaluate_DEPRECATED(" in broken - r = RestoreMethod(method_name="evaluate") - restored = r.apply(broken) - assert restored == SAMPLE_SCRIPT - - -def test_change_argument_signature_removes_kwarg(): - b = ChangeArgumentSignature( - function_name="TrainingArguments", - removed_arg="report_to", - added_arg="report_to", - added_value='"none"', - ) - broken = b.apply(SAMPLE_SCRIPT) - assert 'report_to="none"' not in broken - - -def test_change_return_type_swaps_access(): - b = ChangeReturnType( - function_name="evaluate", - old_access="trainer.evaluate()", - new_access="trainer.evaluate().metrics", - ) - broken = b.apply(SAMPLE_SCRIPT) - assert "trainer.evaluate().metrics" in broken - - -def test_parse_spec_round_trip(): - spec = { - "primitive_type": "RenameApiCall", - "params": {"old_name": "evaluate", "new_name": "eval_model"}, - } - primitive = parse_breakage_spec(spec) - assert isinstance(primitive, RenameApiCall) - assert primitive.old_name == "evaluate" - assert primitive.to_spec()["primitive_type"] == "RenameApiCall" - - -def test_parse_spec_unknown_raises(): - with pytest.raises(ValueError): - parse_breakage_spec({"primitive_type": "Bogus"}) - - -def test_parse_spec_ignores_extra_kwargs(): - """LLMs hallucinate kwargs; we should silently filter them.""" - spec = { - "primitive_type": "RenameApiCall", - "params": { - "old_name": "evaluate", - "new_name": "eval_model", - "hallucinated_kwarg": "ignore_me", - }, - } - primitive = parse_breakage_spec(spec) - assert isinstance(primitive, RenameApiCall) - - -def test_breakage_creates_actual_difference(): - b = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training") - broken = b.apply(SAMPLE_SCRIPT) - assert broken != SAMPLE_SCRIPT - - -def test_all_8_primitives_registered(): - expected = { - "RenameApiCall", - "DeprecateImport", - "ChangeArgumentSignature", - "ModifyConfigField", - "RestructureDatasetSchema", - "ChangeTokenizerBehavior", - "RemoveDeprecatedMethod", - "ChangeReturnType", - } - assert set(PRIMITIVE_REGISTRY) == expected - - -def test_breakage_repair_registry_alignment(): - """Every breakage class should have a registered inverse.""" - for breakage_name, repair_name in BREAKAGE_TO_REPAIR.items(): - assert breakage_name in PRIMITIVE_REGISTRY - assert repair_name in REPAIR_REGISTRY - - -def test_seed_corpus_has_at_least_10_scripts(): - sampler = TaskSampler() - assert len(sampler.tasks) >= 10 - assert all(t.script_content for t in sampler.tasks) - assert all(t.task_id for t in sampler.tasks) - - -def test_task_sampler_categories_are_diverse(): - sampler = TaskSampler() - categories = sampler.get_all_categories() - assert len(categories) >= 3, f"Expected at least 3 distinct categories, got {categories}" - - -def test_task_sampler_difficulty_filter(): - sampler = TaskSampler() - # Should not crash even when an unknown difficulty is requested. - task = sampler.sample(difficulty="easy") - if task is not None: - assert task.difficulty == "easy" - - -def test_task_sampler_get_by_id(): - sampler = TaskSampler() - if not sampler.tasks: - pytest.skip("No tasks loaded") - first = sampler.tasks[0] - fetched = sampler.get_by_id(first.task_id) - assert fetched is first +"""Tests for breakage primitives, repair primitives, and task sampler.""" +import pytest + +from forgeenv.primitives.breakage_primitives import ( + ChangeArgumentSignature, + ChangeReturnType, + ChangeTokenizerBehavior, + DeprecateImport, + ModifyConfigField, + PRIMITIVE_REGISTRY, + RemoveDeprecatedMethod, + RenameApiCall, + RestructureDatasetSchema, + parse_breakage_spec, +) +from forgeenv.primitives.repair_primitives import ( + BREAKAGE_TO_REPAIR, + REPAIR_REGISTRY, + RestoreApiCall, + RestoreColumn, + RestoreImport, + RestoreMethod, +) +from forgeenv.tasks.task_sampler import TaskSampler + +SAMPLE_SCRIPT = """ +from transformers import Trainer, TrainingArguments +from datasets import load_dataset + +dataset = load_dataset("glue", "sst2") +dataset = dataset.rename_column("label", "labels") + +args = TrainingArguments(num_train_epochs=3, report_to="none") +trainer = Trainer(model=model, args=args, train_dataset=dataset) +trainer.train() +result = trainer.evaluate() +""" + + +def test_rename_api_call_word_boundary(): + """Renaming should not break out-of-context substrings.""" + b = RenameApiCall(old_name="evaluate", new_name="eval_model") + broken = b.apply(SAMPLE_SCRIPT) + assert "eval_model" in broken + assert "trainer.evaluate" not in broken + # Inverse should restore + r = RestoreApiCall(new_name="eval_model", old_name="evaluate") + restored = r.apply(broken) + assert restored.strip() == SAMPLE_SCRIPT.strip() + + +def test_deprecate_import(): + b = DeprecateImport( + old_module="from transformers import", + new_module="from transformers.legacy import", + ) + broken = b.apply(SAMPLE_SCRIPT) + assert "transformers.legacy" in broken + r = RestoreImport( + new_module="from transformers.legacy import", + old_module="from transformers import", + ) + restored = r.apply(broken) + assert restored == SAMPLE_SCRIPT + + +def test_restructure_dataset_string_replacement(): + b = RestructureDatasetSchema(old_column="label", new_column="sentiment_label") + broken = b.apply(SAMPLE_SCRIPT) + assert '"sentiment_label"' in broken + assert '"label"' not in broken + r = RestoreColumn(new_column="sentiment_label", old_column="label") + restored = r.apply(broken) + assert restored == SAMPLE_SCRIPT + + +def test_modify_config_field_changes_value(): + b = ModifyConfigField( + config_class="TrainingArguments", + field_name="num_train_epochs", + new_value="999", + ) + broken = b.apply(SAMPLE_SCRIPT) + assert "num_train_epochs=999" in broken + + +def test_change_tokenizer_behavior_replaces_kwarg(): + script = "tok = tokenizer(text, padding=True, truncation=True)" + b = ChangeTokenizerBehavior( + old_kwarg="padding", + old_value="True", + new_kwarg="padding", + new_value='"max_length"', + ) + broken = b.apply(script) + assert 'padding="max_length"' in broken + + +def test_remove_deprecated_method_marks_call(): + b = RemoveDeprecatedMethod( + class_name="Trainer", method_name="evaluate", replacement="evaluate_legacy" + ) + broken = b.apply(SAMPLE_SCRIPT) + assert ".evaluate_DEPRECATED(" in broken + r = RestoreMethod(method_name="evaluate") + restored = r.apply(broken) + assert restored == SAMPLE_SCRIPT + + +def test_change_argument_signature_removes_kwarg(): + b = ChangeArgumentSignature( + function_name="TrainingArguments", + removed_arg="report_to", + added_arg="report_to", + added_value='"none"', + ) + broken = b.apply(SAMPLE_SCRIPT) + assert 'report_to="none"' not in broken + + +def test_change_return_type_swaps_access(): + b = ChangeReturnType( + function_name="evaluate", + old_access="trainer.evaluate()", + new_access="trainer.evaluate().metrics", + ) + broken = b.apply(SAMPLE_SCRIPT) + assert "trainer.evaluate().metrics" in broken + + +def test_parse_spec_round_trip(): + spec = { + "primitive_type": "RenameApiCall", + "params": {"old_name": "evaluate", "new_name": "eval_model"}, + } + primitive = parse_breakage_spec(spec) + assert isinstance(primitive, RenameApiCall) + assert primitive.old_name == "evaluate" + assert primitive.to_spec()["primitive_type"] == "RenameApiCall" + + +def test_parse_spec_unknown_raises(): + with pytest.raises(ValueError): + parse_breakage_spec({"primitive_type": "Bogus"}) + + +def test_parse_spec_ignores_extra_kwargs(): + """LLMs hallucinate kwargs; we should silently filter them.""" + spec = { + "primitive_type": "RenameApiCall", + "params": { + "old_name": "evaluate", + "new_name": "eval_model", + "hallucinated_kwarg": "ignore_me", + }, + } + primitive = parse_breakage_spec(spec) + assert isinstance(primitive, RenameApiCall) + + +def test_breakage_creates_actual_difference(): + b = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training") + broken = b.apply(SAMPLE_SCRIPT) + assert broken != SAMPLE_SCRIPT + + +def test_all_8_primitives_registered(): + expected = { + "RenameApiCall", + "DeprecateImport", + "ChangeArgumentSignature", + "ModifyConfigField", + "RestructureDatasetSchema", + "ChangeTokenizerBehavior", + "RemoveDeprecatedMethod", + "ChangeReturnType", + } + assert set(PRIMITIVE_REGISTRY) == expected + + +def test_breakage_repair_registry_alignment(): + """Every breakage class should have a registered inverse.""" + for breakage_name, repair_name in BREAKAGE_TO_REPAIR.items(): + assert breakage_name in PRIMITIVE_REGISTRY + assert repair_name in REPAIR_REGISTRY + + +def test_seed_corpus_has_at_least_10_scripts(): + sampler = TaskSampler() + assert len(sampler.tasks) >= 10 + assert all(t.script_content for t in sampler.tasks) + assert all(t.task_id for t in sampler.tasks) + + +def test_task_sampler_categories_are_diverse(): + sampler = TaskSampler() + categories = sampler.get_all_categories() + assert len(categories) >= 3, f"Expected at least 3 distinct categories, got {categories}" + + +def test_task_sampler_difficulty_filter(): + sampler = TaskSampler() + # Should not crash even when an unknown difficulty is requested. + task = sampler.sample(difficulty="easy") + if task is not None: + assert task.difficulty == "easy" + + +def test_task_sampler_get_by_id(): + sampler = TaskSampler() + if not sampler.tasks: + pytest.skip("No tasks loaded") + first = sampler.tasks[0] + fetched = sampler.get_by_id(first.task_id) + assert fetched is first diff --git a/tests/test_roles.py b/tests/test_roles.py index 25abde5dcb651b7a1d7fbbe1a2fb6b85bcfc176e..8aa460bb11b70efe575c83ef94cb0a3b09aa3b9a 100644 --- a/tests/test_roles.py +++ b/tests/test_roles.py @@ -1,153 +1,153 @@ -"""Tests for the role helpers (drift generator + repair agent).""" -import json - -from forgeenv.env.diff_utils import apply_unified_diff -from forgeenv.primitives.breakage_primitives import RenameApiCall -from forgeenv.roles.drift_generator import ( - BaselineDriftGenerator, - parse_drift_output, - parse_drift_to_primitive, -) -from forgeenv.roles.prompts import ( - DRIFT_GENERATOR_SYSTEM_PROMPT, - REPAIR_AGENT_SYSTEM_PROMPT, - render_drift_generator_prompt, - render_repair_agent_prompt, -) -from forgeenv.roles.repair_agent import ( - BaselineRepairAgent, - extract_diff, - looks_like_diff, -) - - -# ------------------------------------------------------------------- prompts -def test_prompts_are_nonempty(): - assert "Drift Generator" in DRIFT_GENERATOR_SYSTEM_PROMPT - assert "Repair Agent" in REPAIR_AGENT_SYSTEM_PROMPT - - -def test_render_drift_generator_prompt_includes_inputs(): - text = render_drift_generator_prompt( - "import torch", "RenameApiCall", {"transformers": "4.40.0"} - ) - assert "RenameApiCall" in text and "transformers=4.40.0" in text and "import torch" in text - - -def test_render_repair_agent_prompt_includes_error_trace(): - text = render_repair_agent_prompt( - "broken", "AttributeError: foo", {"transformers": "4.50.0"} - ) - assert "AttributeError" in text and "transformers=4.50.0" in text - - -# ------------------------------------------------------------ drift generator -def test_parse_drift_output_handles_fences(): - text = "```json\n{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"a\", \"new_name\": \"b\"}}\n```" - parsed = parse_drift_output(text) - assert parsed is not None and parsed["primitive_type"] == "RenameApiCall" - - -def test_parse_drift_output_handles_prose(): - text = ( - "Here is my breakage idea, it's a rename:\n" - "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n" - "Hope this works!" - ) - parsed = parse_drift_output(text) - assert parsed["primitive_type"] == "RenameApiCall" - - -def test_parse_drift_output_returns_none_on_garbage(): - assert parse_drift_output("no JSON here at all") is None - assert parse_drift_output("") is None - - -def test_parse_drift_to_primitive_validates(): - text = '{"primitive_type": "DeprecateImport", "params": {"old_module": "a", "new_module": "b"}}' - primitive = parse_drift_to_primitive(text) - assert primitive is not None and primitive.name == "DeprecateImport" - - -def test_parse_drift_to_primitive_unknown_type(): - text = '{"primitive_type": "NonExistent", "params": {}}' - assert parse_drift_to_primitive(text) is None - - -def test_baseline_drift_generator_produces_valid_spec(): - gen = BaselineDriftGenerator(seed=0) - script = """from transformers import Trainer -trainer = Trainer() -trainer.train() -""" - spec = gen.propose(target_category="RenameApiCall", script=script) - assert spec["primitive_type"] in { - "RenameApiCall", "DeprecateImport", "ChangeArgumentSignature", - "ModifyConfigField", "RestructureDatasetSchema", "ChangeTokenizerBehavior", - "RemoveDeprecatedMethod", "ChangeReturnType", - } - assert isinstance(spec["params"], dict) - - -def test_baseline_drift_generator_spec_actually_breaks_script(): - gen = BaselineDriftGenerator(seed=42) - script = """from transformers import Trainer -trainer = Trainer() -trainer.train() -""" - spec = gen.propose(target_category="RenameApiCall", script=script) - primitive = parse_drift_to_primitive(json.dumps(spec)) - broken = primitive.apply(script) - # If we got a 'RenameApiCall' on trainer.train, it must have changed something. - if spec["primitive_type"] == "RenameApiCall" and spec["params"].get("old_name") in script: - assert broken != script - - -# -------------------------------------------------------------- repair agent -def test_extract_diff_strips_fences(): - text = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n" - diff = extract_diff(text) - assert diff.startswith("---") and "foo" in diff and "bar" in diff - - -def test_extract_diff_strips_chain_of_thought(): - text = ( - "Let me think... the error is X, so I should rename Y to Z.\n" - "Here is the diff:\n" - "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n-import torch\n+import torch.legacy\n" - ) - diff = extract_diff(text) - assert diff.startswith("---") - assert "Let me think" not in diff - - -def test_looks_like_diff_positive(): - diff = "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n" - assert looks_like_diff(diff) - - -def test_looks_like_diff_negative(): - assert not looks_like_diff("just some prose without any diff structure") - - -def test_baseline_repair_agent_oracle_path(): - agent = BaselineRepairAgent() - original = "import torch\nprint('hi')\n" - broken = "import torch.legacy\nprint('hi')\n" - diff = agent.repair(broken, breakage_spec=None, original_script=original) - assert diff and "torch.legacy" in diff - repaired = apply_unified_diff(broken, diff) - assert repaired == original - - -def test_baseline_repair_agent_inverts_breakage_spec(): - agent = BaselineRepairAgent() - original = "from transformers import Trainer\ntrainer.train()\n" - breakage = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training") - broken = breakage.apply(original) - spec = breakage.to_spec() - - diff = agent.repair(broken, breakage_spec=spec) - assert diff - repaired = apply_unified_diff(broken, diff) - assert "trainer.train()" in repaired +"""Tests for the role helpers (drift generator + repair agent).""" +import json + +from forgeenv.env.diff_utils import apply_unified_diff +from forgeenv.primitives.breakage_primitives import RenameApiCall +from forgeenv.roles.drift_generator import ( + BaselineDriftGenerator, + parse_drift_output, + parse_drift_to_primitive, +) +from forgeenv.roles.prompts import ( + DRIFT_GENERATOR_SYSTEM_PROMPT, + REPAIR_AGENT_SYSTEM_PROMPT, + render_drift_generator_prompt, + render_repair_agent_prompt, +) +from forgeenv.roles.repair_agent import ( + BaselineRepairAgent, + extract_diff, + looks_like_diff, +) + + +# ------------------------------------------------------------------- prompts +def test_prompts_are_nonempty(): + assert "Drift Generator" in DRIFT_GENERATOR_SYSTEM_PROMPT + assert "Repair Agent" in REPAIR_AGENT_SYSTEM_PROMPT + + +def test_render_drift_generator_prompt_includes_inputs(): + text = render_drift_generator_prompt( + "import torch", "RenameApiCall", {"transformers": "4.40.0"} + ) + assert "RenameApiCall" in text and "transformers=4.40.0" in text and "import torch" in text + + +def test_render_repair_agent_prompt_includes_error_trace(): + text = render_repair_agent_prompt( + "broken", "AttributeError: foo", {"transformers": "4.50.0"} + ) + assert "AttributeError" in text and "transformers=4.50.0" in text + + +# ------------------------------------------------------------ drift generator +def test_parse_drift_output_handles_fences(): + text = "```json\n{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"a\", \"new_name\": \"b\"}}\n```" + parsed = parse_drift_output(text) + assert parsed is not None and parsed["primitive_type"] == "RenameApiCall" + + +def test_parse_drift_output_handles_prose(): + text = ( + "Here is my breakage idea, it's a rename:\n" + "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n" + "Hope this works!" + ) + parsed = parse_drift_output(text) + assert parsed["primitive_type"] == "RenameApiCall" + + +def test_parse_drift_output_returns_none_on_garbage(): + assert parse_drift_output("no JSON here at all") is None + assert parse_drift_output("") is None + + +def test_parse_drift_to_primitive_validates(): + text = '{"primitive_type": "DeprecateImport", "params": {"old_module": "a", "new_module": "b"}}' + primitive = parse_drift_to_primitive(text) + assert primitive is not None and primitive.name == "DeprecateImport" + + +def test_parse_drift_to_primitive_unknown_type(): + text = '{"primitive_type": "NonExistent", "params": {}}' + assert parse_drift_to_primitive(text) is None + + +def test_baseline_drift_generator_produces_valid_spec(): + gen = BaselineDriftGenerator(seed=0) + script = """from transformers import Trainer +trainer = Trainer() +trainer.train() +""" + spec = gen.propose(target_category="RenameApiCall", script=script) + assert spec["primitive_type"] in { + "RenameApiCall", "DeprecateImport", "ChangeArgumentSignature", + "ModifyConfigField", "RestructureDatasetSchema", "ChangeTokenizerBehavior", + "RemoveDeprecatedMethod", "ChangeReturnType", + } + assert isinstance(spec["params"], dict) + + +def test_baseline_drift_generator_spec_actually_breaks_script(): + gen = BaselineDriftGenerator(seed=42) + script = """from transformers import Trainer +trainer = Trainer() +trainer.train() +""" + spec = gen.propose(target_category="RenameApiCall", script=script) + primitive = parse_drift_to_primitive(json.dumps(spec)) + broken = primitive.apply(script) + # If we got a 'RenameApiCall' on trainer.train, it must have changed something. + if spec["primitive_type"] == "RenameApiCall" and spec["params"].get("old_name") in script: + assert broken != script + + +# -------------------------------------------------------------- repair agent +def test_extract_diff_strips_fences(): + text = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n" + diff = extract_diff(text) + assert diff.startswith("---") and "foo" in diff and "bar" in diff + + +def test_extract_diff_strips_chain_of_thought(): + text = ( + "Let me think... the error is X, so I should rename Y to Z.\n" + "Here is the diff:\n" + "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n-import torch\n+import torch.legacy\n" + ) + diff = extract_diff(text) + assert diff.startswith("---") + assert "Let me think" not in diff + + +def test_looks_like_diff_positive(): + diff = "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n" + assert looks_like_diff(diff) + + +def test_looks_like_diff_negative(): + assert not looks_like_diff("just some prose without any diff structure") + + +def test_baseline_repair_agent_oracle_path(): + agent = BaselineRepairAgent() + original = "import torch\nprint('hi')\n" + broken = "import torch.legacy\nprint('hi')\n" + diff = agent.repair(broken, breakage_spec=None, original_script=original) + assert diff and "torch.legacy" in diff + repaired = apply_unified_diff(broken, diff) + assert repaired == original + + +def test_baseline_repair_agent_inverts_breakage_spec(): + agent = BaselineRepairAgent() + original = "from transformers import Trainer\ntrainer.train()\n" + breakage = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training") + broken = breakage.apply(original) + spec = breakage.to_spec() + + diff = agent.repair(broken, breakage_spec=spec) + assert diff + repaired = apply_unified_diff(broken, diff) + assert "trainer.train()" in repaired diff --git a/tests/test_simulation_mode.py b/tests/test_simulation_mode.py index 9d36c6416796dce7fe82f8ca0e54d9175e5c3b0a..0072c16e79c33788758c37c8bf3d2ffa5f96f677 100644 --- a/tests/test_simulation_mode.py +++ b/tests/test_simulation_mode.py @@ -1,76 +1,76 @@ -"""Tests for the simulation-mode executor.""" -from forgeenv.sandbox.simulation_mode import SimulationExecutor -from forgeenv.tasks.models import Task - -VALID_HF = """ -from transformers import Trainer, TrainingArguments -from datasets import load_dataset -import torch - -dataset = load_dataset("glue", "sst2") -trainer = Trainer(model=None, args=None, train_dataset=dataset) -trainer.train() -trainer.save_model("/tmp/forge_output/checkpoint") -print("TRAINING_COMPLETE") -""" - -SYNTAX_ERROR = "def foo(\n broken" - -OS_IMPORT = "import os\nos.listdir('.')" - - -def _task(content: str) -> Task: - return Task( - task_id="t", - description="d", - script_content=content, - difficulty="easy", - ) - - -def test_valid_script_can_succeed(): - """With seed 0, the valid HF script eventually returns a positive case.""" - executor = SimulationExecutor(seed=0) - result = executor.execute(VALID_HF, _task(VALID_HF)) - # Either succeeds (exit 0 with TRAINING_COMPLETE) or fails with realistic - # HF error; never crashes or returns an empty result. - assert result.exit_code in (0, 1) - if result.exit_code == 0: - assert "TRAINING_COMPLETE" in result.stdout - - -def test_syntax_error_fails(): - executor = SimulationExecutor(seed=0) - result = executor.execute(SYNTAX_ERROR, _task(SYNTAX_ERROR)) - assert result.exit_code == 1 - assert "SyntaxError" in result.stderr - - -def test_forbidden_import_fails(): - executor = SimulationExecutor(seed=0) - result = executor.execute(OS_IMPORT, _task(OS_IMPORT)) - assert result.exit_code == 1 - assert "Validation failed" in result.stderr - - -def test_simulation_is_fast(): - """Simulation mode must complete each call in <100ms wall_time. - - The reported wall_time_ms field includes a synthetic delay so we measure - real elapsed time at this layer instead. - """ - import time - executor = SimulationExecutor(seed=0) - t0 = time.time() - executor.execute(VALID_HF, _task(VALID_HF)) - elapsed_ms = (time.time() - t0) * 1000 - assert elapsed_ms < 200, f"Simulation took {elapsed_ms:.1f}ms" - - -def test_seed_is_deterministic(): - e1 = SimulationExecutor(seed=42) - e2 = SimulationExecutor(seed=42) - r1 = e1.execute(VALID_HF, _task(VALID_HF)) - r2 = e2.execute(VALID_HF, _task(VALID_HF)) - assert r1.exit_code == r2.exit_code - assert r1.stderr == r2.stderr +"""Tests for the simulation-mode executor.""" +from forgeenv.sandbox.simulation_mode import SimulationExecutor +from forgeenv.tasks.models import Task + +VALID_HF = """ +from transformers import Trainer, TrainingArguments +from datasets import load_dataset +import torch + +dataset = load_dataset("glue", "sst2") +trainer = Trainer(model=None, args=None, train_dataset=dataset) +trainer.train() +trainer.save_model("/tmp/forge_output/checkpoint") +print("TRAINING_COMPLETE") +""" + +SYNTAX_ERROR = "def foo(\n broken" + +OS_IMPORT = "import os\nos.listdir('.')" + + +def _task(content: str) -> Task: + return Task( + task_id="t", + description="d", + script_content=content, + difficulty="easy", + ) + + +def test_valid_script_can_succeed(): + """With seed 0, the valid HF script eventually returns a positive case.""" + executor = SimulationExecutor(seed=0) + result = executor.execute(VALID_HF, _task(VALID_HF)) + # Either succeeds (exit 0 with TRAINING_COMPLETE) or fails with realistic + # HF error; never crashes or returns an empty result. + assert result.exit_code in (0, 1) + if result.exit_code == 0: + assert "TRAINING_COMPLETE" in result.stdout + + +def test_syntax_error_fails(): + executor = SimulationExecutor(seed=0) + result = executor.execute(SYNTAX_ERROR, _task(SYNTAX_ERROR)) + assert result.exit_code == 1 + assert "SyntaxError" in result.stderr + + +def test_forbidden_import_fails(): + executor = SimulationExecutor(seed=0) + result = executor.execute(OS_IMPORT, _task(OS_IMPORT)) + assert result.exit_code == 1 + assert "Validation failed" in result.stderr + + +def test_simulation_is_fast(): + """Simulation mode must complete each call in <100ms wall_time. + + The reported wall_time_ms field includes a synthetic delay so we measure + real elapsed time at this layer instead. + """ + import time + executor = SimulationExecutor(seed=0) + t0 = time.time() + executor.execute(VALID_HF, _task(VALID_HF)) + elapsed_ms = (time.time() - t0) * 1000 + assert elapsed_ms < 200, f"Simulation took {elapsed_ms:.1f}ms" + + +def test_seed_is_deterministic(): + e1 = SimulationExecutor(seed=42) + e2 = SimulationExecutor(seed=42) + r1 = e1.execute(VALID_HF, _task(VALID_HF)) + r2 = e2.execute(VALID_HF, _task(VALID_HF)) + assert r1.exit_code == r2.exit_code + assert r1.stderr == r2.stderr diff --git a/tests/test_training.py b/tests/test_training.py index cb1d65f42f2bc93097afcdb6b7226ee139958dd8..66f0a278325615e9a4b3cc2381061688f3dbe4b5 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,61 +1,61 @@ -"""Smoke tests for the training pipeline (rollout, dry-run trainers).""" -import json -import tempfile -from pathlib import Path - -from forgeenv.env.forge_environment import ForgeEnvironment -from forgeenv.training.grpo_repair import run_grpo -from forgeenv.training.grpo_drift import run_drift_grpo_dry_run -from forgeenv.training.rollout import ( - baseline_oracle_repair_generate, - rollout_one_episode, -) - - -def test_rollout_one_episode_baseline_no_op_repair(): - env = ForgeEnvironment(seed=1) - result = rollout_one_episode(env) - assert result.task_id - assert result.primitive_type - assert isinstance(result.visible_reward, float) - assert "executed_cleanly" in result.held_out_breakdown - - -def test_rollout_one_episode_with_oracle_repair_succeeds(): - env = ForgeEnvironment(seed=2) - repair_gen = baseline_oracle_repair_generate(env) - result = rollout_one_episode(env, repair_generate=repair_gen) - # Oracle should usually score well on `intent_preserved` (script identical to original). - assert result.held_out_breakdown.get("intent_preserved", 0.0) > 0.7 - - -def test_grpo_repair_dry_run_smoke(): - with tempfile.TemporaryDirectory() as tmp: - run_grpo( - base_model="(unused-in-dry-run)", - adapter_path=None, - output_dir=tmp, - total_episodes=5, - group_size=2, - seed=0, - use_unsloth=False, - ) - rewards_path = Path(tmp) / "dry_run_rewards.json" - assert rewards_path.exists() - rewards = json.loads(rewards_path.read_text()) - assert len(rewards) == 5 - assert all(isinstance(r, (int, float)) for r in rewards) - - -def test_grpo_drift_dry_run_smoke(): - with tempfile.TemporaryDirectory() as tmp: - run_drift_grpo_dry_run( - output_dir=tmp, total_episodes=3, group_size=2, seed=0 - ) - log_path = Path(tmp) / "drift_dry_run.json" - assert log_path.exists() - log = json.loads(log_path.read_text()) - assert len(log) == 3 - for entry in log: - assert "rewards" in entry and "candidates" in entry - assert all(0.0 <= r <= 2.0 for r in entry["rewards"]) +"""Smoke tests for the training pipeline (rollout, dry-run trainers).""" +import json +import tempfile +from pathlib import Path + +from forgeenv.env.forge_environment import ForgeEnvironment +from forgeenv.training.grpo_repair import run_grpo +from forgeenv.training.grpo_drift import run_drift_grpo_dry_run +from forgeenv.training.rollout import ( + baseline_oracle_repair_generate, + rollout_one_episode, +) + + +def test_rollout_one_episode_baseline_no_op_repair(): + env = ForgeEnvironment(seed=1) + result = rollout_one_episode(env) + assert result.task_id + assert result.primitive_type + assert isinstance(result.visible_reward, float) + assert "executed_cleanly" in result.held_out_breakdown + + +def test_rollout_one_episode_with_oracle_repair_succeeds(): + env = ForgeEnvironment(seed=2) + repair_gen = baseline_oracle_repair_generate(env) + result = rollout_one_episode(env, repair_generate=repair_gen) + # Oracle should usually score well on `intent_preserved` (script identical to original). + assert result.held_out_breakdown.get("intent_preserved", 0.0) > 0.7 + + +def test_grpo_repair_dry_run_smoke(): + with tempfile.TemporaryDirectory() as tmp: + run_grpo( + base_model="(unused-in-dry-run)", + adapter_path=None, + output_dir=tmp, + total_episodes=5, + group_size=2, + seed=0, + use_unsloth=False, + ) + rewards_path = Path(tmp) / "dry_run_rewards.json" + assert rewards_path.exists() + rewards = json.loads(rewards_path.read_text()) + assert len(rewards) == 5 + assert all(isinstance(r, (int, float)) for r in rewards) + + +def test_grpo_drift_dry_run_smoke(): + with tempfile.TemporaryDirectory() as tmp: + run_drift_grpo_dry_run( + output_dir=tmp, total_episodes=3, group_size=2, seed=0 + ) + log_path = Path(tmp) / "drift_dry_run.json" + assert log_path.exists() + log = json.loads(log_path.read_text()) + assert len(log) == 3 + for entry in log: + assert "rewards" in entry and "candidates" in entry + assert all(0.0 <= r <= 2.0 for r in entry["rewards"]) diff --git a/tests/test_warmstart.py b/tests/test_warmstart.py index 199bd685ed383f5633d02cfdbfa3d57867a7b5b5..f397f256ceeec1b26c4219203db12137701b40c1 100644 --- a/tests/test_warmstart.py +++ b/tests/test_warmstart.py @@ -1,37 +1,37 @@ -"""Smoke tests for warm-start pair generation.""" -import json -import tempfile -from pathlib import Path - -from warmstart.generate_pairs import generate_pairs - - -def test_generate_pairs_produces_minimum_count(): - with tempfile.TemporaryDirectory() as tmp: - counts = generate_pairs(target_count=50, out_dir=Path(tmp)) - assert counts["repair_pairs"] >= 50 - assert counts["drift_pairs"] >= 50 - - repair_jsonl = Path(tmp) / "repair_pairs.jsonl" - drift_jsonl = Path(tmp) / "drift_pairs.jsonl" - assert repair_jsonl.exists() - assert drift_jsonl.exists() - - first = json.loads(repair_jsonl.read_text(encoding="utf-8").splitlines()[0]) - assert first["role_target"] == "repair_agent" - assert "messages" in first and len(first["messages"]) == 3 - assert first["messages"][-1]["content"] # non-empty assistant content - - first_drift = json.loads(drift_jsonl.read_text(encoding="utf-8").splitlines()[0]) - assert first_drift["role_target"] == "drift_generator" - body = first_drift["messages"][-1]["content"] - parsed = json.loads(body) - assert "primitive_type" in parsed and "params" in parsed - - -def test_generate_pairs_covers_multiple_primitive_types(): - with tempfile.TemporaryDirectory() as tmp: - generate_pairs(target_count=50, out_dir=Path(tmp)) - summary = json.loads((Path(tmp) / "summary.json").read_text(encoding="utf-8")) - assert len(summary["primitives_covered"]) >= 5 - assert len(summary["tasks_covered"]) >= 5 +"""Smoke tests for warm-start pair generation.""" +import json +import tempfile +from pathlib import Path + +from warmstart.generate_pairs import generate_pairs + + +def test_generate_pairs_produces_minimum_count(): + with tempfile.TemporaryDirectory() as tmp: + counts = generate_pairs(target_count=50, out_dir=Path(tmp)) + assert counts["repair_pairs"] >= 50 + assert counts["drift_pairs"] >= 50 + + repair_jsonl = Path(tmp) / "repair_pairs.jsonl" + drift_jsonl = Path(tmp) / "drift_pairs.jsonl" + assert repair_jsonl.exists() + assert drift_jsonl.exists() + + first = json.loads(repair_jsonl.read_text(encoding="utf-8").splitlines()[0]) + assert first["role_target"] == "repair_agent" + assert "messages" in first and len(first["messages"]) == 3 + assert first["messages"][-1]["content"] # non-empty assistant content + + first_drift = json.loads(drift_jsonl.read_text(encoding="utf-8").splitlines()[0]) + assert first_drift["role_target"] == "drift_generator" + body = first_drift["messages"][-1]["content"] + parsed = json.loads(body) + assert "primitive_type" in parsed and "params" in parsed + + +def test_generate_pairs_covers_multiple_primitive_types(): + with tempfile.TemporaryDirectory() as tmp: + generate_pairs(target_count=50, out_dir=Path(tmp)) + summary = json.loads((Path(tmp) / "summary.json").read_text(encoding="utf-8")) + assert len(summary["primitives_covered"]) >= 5 + assert len(summary["tasks_covered"]) >= 5 diff --git a/warmstart/generate_pairs.py b/warmstart/generate_pairs.py index b5bfa7dbe02be19f8c04943ee942856818ae9ccc..025f5e5b68c506dfc85e1e80959c882b8715679f 100644 --- a/warmstart/generate_pairs.py +++ b/warmstart/generate_pairs.py @@ -1,188 +1,188 @@ -"""Warm-start pair generator for both roles. - -Produces two parallel JSONL files: - - warmstart/data/repair_pairs.jsonl -- (prompt, completion) for Repair Agent SFT - warmstart/data/drift_pairs.jsonl -- (prompt, completion) for Drift Generator SFT - -Each row has the canonical chat-template fields: - {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}], - "task_id": ..., "primitive_type": ..., "category": ...} - -We generate at least 50 pairs (the hackathon brief requires this minimum) by -combining each seed-corpus script with each applicable breakage primitive -configuration. The Drift Generator SFT teaches it to emit clean JSON; the -Repair Agent SFT teaches it to emit canonical unified diffs. - -Usage: - python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data] -""" -from __future__ import annotations - -import argparse -import json -from pathlib import Path -from typing import Iterable, Optional - -from forgeenv.env.diff_utils import make_unified_diff -from forgeenv.primitives.breakage_primitives import ( - PRIMITIVE_REGISTRY, - parse_breakage_spec, -) -from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE -from forgeenv.roles.prompts import ( - DRIFT_GENERATOR_SYSTEM_PROMPT, - REPAIR_AGENT_SYSTEM_PROMPT, - render_drift_generator_prompt, - render_repair_agent_prompt, -) -from forgeenv.roles.repair_agent import BaselineRepairAgent -from forgeenv.tasks.task_sampler import TaskSampler - - -def _candidate_breakages(script: str) -> list[dict]: - """Yield breakage specs whose default params we know will mutate `script`.""" - out: list[dict] = [] - for ptype, param_options in _DEFAULT_PARAMS_BY_TYPE.items(): - for params in param_options: - spec = {"primitive_type": ptype, "params": dict(params)} - try: - primitive = parse_breakage_spec(spec) - except ValueError: - continue - mutated = primitive.apply(script) - if mutated != script: - out.append(spec) - return out - - -def _render_pairs_for_task( - task_id: str, - script: str, - library_versions: dict, - repair_agent: BaselineRepairAgent, -) -> list[dict]: - pairs = [] - for spec in _candidate_breakages(script): - primitive = parse_breakage_spec(spec) - broken = primitive.apply(script) - if broken == script: - continue - - diff = repair_agent.repair(broken, breakage_spec=spec, original_script=script) - if not diff: - continue - - # Repair Agent pair - repair_user = render_repair_agent_prompt( - broken_script=broken, - error_trace=f"[simulated] {primitive.description}", - library_versions=library_versions, - target_category=primitive.category, - ) - pairs.append( - { - "role_target": "repair_agent", - "task_id": task_id, - "primitive_type": spec["primitive_type"], - "category": primitive.category, - "messages": [ - {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, - {"role": "user", "content": repair_user}, - {"role": "assistant", "content": diff}, - ], - } - ) - - # Drift Generator pair (predict the breakage spec from the working script) - drift_user = render_drift_generator_prompt( - script=script, - target_category=spec["primitive_type"], - library_versions=library_versions, - ) - pairs.append( - { - "role_target": "drift_generator", - "task_id": task_id, - "primitive_type": spec["primitive_type"], - "category": primitive.category, - "messages": [ - {"role": "system", "content": DRIFT_GENERATOR_SYSTEM_PROMPT}, - {"role": "user", "content": drift_user}, - {"role": "assistant", "content": json.dumps(spec, indent=2)}, - ], - } - ) - return pairs - - -def generate_pairs( - target_count: int = 64, out_dir: Optional[Path] = None -) -> dict[str, int]: - out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data") - out_dir.mkdir(parents=True, exist_ok=True) - - library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"} - sampler = TaskSampler() - repair_agent = BaselineRepairAgent() - - repair_pairs: list[dict] = [] - drift_pairs: list[dict] = [] - for task in sampler.tasks: - for pair in _render_pairs_for_task( - task_id=task.task_id, - script=task.script_content, - library_versions=library_versions, - repair_agent=repair_agent, - ): - if pair["role_target"] == "repair_agent": - repair_pairs.append(pair) - else: - drift_pairs.append(pair) - - # If we don't have enough, duplicate-with-shuffle. (We never do this in - # practice β€” the corpus produces > 64 pairs β€” but the safety net is cheap.) - while len(repair_pairs) < target_count and repair_pairs: - repair_pairs.append(repair_pairs[len(repair_pairs) % len(repair_pairs)]) - while len(drift_pairs) < target_count and drift_pairs: - drift_pairs.append(drift_pairs[len(drift_pairs) % len(drift_pairs)]) - - repair_path = out_dir / "repair_pairs.jsonl" - drift_path = out_dir / "drift_pairs.jsonl" - with repair_path.open("w", encoding="utf-8") as f: - for row in repair_pairs: - f.write(json.dumps(row) + "\n") - with drift_path.open("w", encoding="utf-8") as f: - for row in drift_pairs: - f.write(json.dumps(row) + "\n") - - counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)} - summary_path = out_dir / "summary.json" - summary_path.write_text( - json.dumps( - { - **counts, - "target_count": target_count, - "primitives_covered": sorted( - {p["primitive_type"] for p in repair_pairs} - ), - "tasks_covered": sorted({p["task_id"] for p in repair_pairs}), - }, - indent=2, - ), - encoding="utf-8", - ) - return counts - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--target_count", type=int, default=64) - parser.add_argument("--out_dir", type=str, default="warmstart/data") - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - counts = generate_pairs(target_count=args.target_count, out_dir=Path(args.out_dir)) - print(json.dumps(counts, indent=2)) +"""Warm-start pair generator for both roles. + +Produces two parallel JSONL files: + + warmstart/data/repair_pairs.jsonl -- (prompt, completion) for Repair Agent SFT + warmstart/data/drift_pairs.jsonl -- (prompt, completion) for Drift Generator SFT + +Each row has the canonical chat-template fields: + {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}], + "task_id": ..., "primitive_type": ..., "category": ...} + +We generate at least 50 pairs (the hackathon brief requires this minimum) by +combining each seed-corpus script with each applicable breakage primitive +configuration. The Drift Generator SFT teaches it to emit clean JSON; the +Repair Agent SFT teaches it to emit canonical unified diffs. + +Usage: + python warmstart/generate_pairs.py [--target_count 64] [--out_dir warmstart/data] +""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Iterable, Optional + +from forgeenv.env.diff_utils import make_unified_diff +from forgeenv.primitives.breakage_primitives import ( + PRIMITIVE_REGISTRY, + parse_breakage_spec, +) +from forgeenv.roles.drift_generator import _DEFAULT_PARAMS_BY_TYPE +from forgeenv.roles.prompts import ( + DRIFT_GENERATOR_SYSTEM_PROMPT, + REPAIR_AGENT_SYSTEM_PROMPT, + render_drift_generator_prompt, + render_repair_agent_prompt, +) +from forgeenv.roles.repair_agent import BaselineRepairAgent +from forgeenv.tasks.task_sampler import TaskSampler + + +def _candidate_breakages(script: str) -> list[dict]: + """Yield breakage specs whose default params we know will mutate `script`.""" + out: list[dict] = [] + for ptype, param_options in _DEFAULT_PARAMS_BY_TYPE.items(): + for params in param_options: + spec = {"primitive_type": ptype, "params": dict(params)} + try: + primitive = parse_breakage_spec(spec) + except ValueError: + continue + mutated = primitive.apply(script) + if mutated != script: + out.append(spec) + return out + + +def _render_pairs_for_task( + task_id: str, + script: str, + library_versions: dict, + repair_agent: BaselineRepairAgent, +) -> list[dict]: + pairs = [] + for spec in _candidate_breakages(script): + primitive = parse_breakage_spec(spec) + broken = primitive.apply(script) + if broken == script: + continue + + diff = repair_agent.repair(broken, breakage_spec=spec, original_script=script) + if not diff: + continue + + # Repair Agent pair + repair_user = render_repair_agent_prompt( + broken_script=broken, + error_trace=f"[simulated] {primitive.description}", + library_versions=library_versions, + target_category=primitive.category, + ) + pairs.append( + { + "role_target": "repair_agent", + "task_id": task_id, + "primitive_type": spec["primitive_type"], + "category": primitive.category, + "messages": [ + {"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT}, + {"role": "user", "content": repair_user}, + {"role": "assistant", "content": diff}, + ], + } + ) + + # Drift Generator pair (predict the breakage spec from the working script) + drift_user = render_drift_generator_prompt( + script=script, + target_category=spec["primitive_type"], + library_versions=library_versions, + ) + pairs.append( + { + "role_target": "drift_generator", + "task_id": task_id, + "primitive_type": spec["primitive_type"], + "category": primitive.category, + "messages": [ + {"role": "system", "content": DRIFT_GENERATOR_SYSTEM_PROMPT}, + {"role": "user", "content": drift_user}, + {"role": "assistant", "content": json.dumps(spec, indent=2)}, + ], + } + ) + return pairs + + +def generate_pairs( + target_count: int = 64, out_dir: Optional[Path] = None +) -> dict[str, int]: + out_dir = Path(out_dir) if out_dir is not None else Path("warmstart/data") + out_dir.mkdir(parents=True, exist_ok=True) + + library_versions = {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.10.0"} + sampler = TaskSampler() + repair_agent = BaselineRepairAgent() + + repair_pairs: list[dict] = [] + drift_pairs: list[dict] = [] + for task in sampler.tasks: + for pair in _render_pairs_for_task( + task_id=task.task_id, + script=task.script_content, + library_versions=library_versions, + repair_agent=repair_agent, + ): + if pair["role_target"] == "repair_agent": + repair_pairs.append(pair) + else: + drift_pairs.append(pair) + + # If we don't have enough, duplicate-with-shuffle. (We never do this in + # practice β€” the corpus produces > 64 pairs β€” but the safety net is cheap.) + while len(repair_pairs) < target_count and repair_pairs: + repair_pairs.append(repair_pairs[len(repair_pairs) % len(repair_pairs)]) + while len(drift_pairs) < target_count and drift_pairs: + drift_pairs.append(drift_pairs[len(drift_pairs) % len(drift_pairs)]) + + repair_path = out_dir / "repair_pairs.jsonl" + drift_path = out_dir / "drift_pairs.jsonl" + with repair_path.open("w", encoding="utf-8") as f: + for row in repair_pairs: + f.write(json.dumps(row) + "\n") + with drift_path.open("w", encoding="utf-8") as f: + for row in drift_pairs: + f.write(json.dumps(row) + "\n") + + counts = {"repair_pairs": len(repair_pairs), "drift_pairs": len(drift_pairs)} + summary_path = out_dir / "summary.json" + summary_path.write_text( + json.dumps( + { + **counts, + "target_count": target_count, + "primitives_covered": sorted( + {p["primitive_type"] for p in repair_pairs} + ), + "tasks_covered": sorted({p["task_id"] for p in repair_pairs}), + }, + indent=2, + ), + encoding="utf-8", + ) + return counts + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--target_count", type=int, default=64) + parser.add_argument("--out_dir", type=str, default="warmstart/data") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + counts = generate_pairs(target_count=args.target_count, out_dir=Path(args.out_dir)) + print(json.dumps(counts, indent=2))