forgeenv source snapshot for training job
File listing (the diff viewer truncates this view to 50 files; each file below was removed and re-added with identical content):
- .gitignore +35 -35
- README.md +180 -180
- debug_trace.py +18 -18
- demo-space/README.md +31 -31
- demo-space/app.py +444 -444
- demo-space/requirements.txt +7 -7
- demo-space/test_heuristic.py +99 -99
- forgeenv-space/Dockerfile +25 -25
- forgeenv-space/README.md +85 -85
- forgeenv-space/openenv.yaml +24 -24
- forgeenv-space/requirements.txt +9 -9
- forgeenv/__init__.py +4 -4
- forgeenv/artifacts/repair_library.py +120 -120
- forgeenv/drift/library_drift_engine.py +74 -74
- forgeenv/env/actions.py +50 -50
- forgeenv/env/diff_utils.py +163 -163
- forgeenv/env/forge_environment.py +259 -259
- forgeenv/env/observations.py +29 -29
- forgeenv/env/server.py +126 -126
- forgeenv/primitives/breakage_primitives.py +282 -282
- forgeenv/primitives/drift_taxonomy.yaml +217 -217
- forgeenv/primitives/repair_primitives.py +241 -241
- forgeenv/roles/drift_generator.py +170 -170
- forgeenv/roles/prompts.py +102 -102
- forgeenv/roles/repair_agent.py +153 -153
- forgeenv/roles/teacher.py +58 -58
- forgeenv/sandbox/ast_validator.py +70 -70
- forgeenv/sandbox/simulation_mode.py +142 -142
- forgeenv/tasks/models.py +45 -45
- forgeenv/tasks/seed_corpus/albert_qa.py +67 -67
- forgeenv/tasks/seed_corpus/bert_ner.py +55 -55
- forgeenv/tasks/seed_corpus/distilbert_sst2.py +53 -53
- forgeenv/tasks/seed_corpus/electra_classification.py +44 -44
- forgeenv/tasks/seed_corpus/gpt2_textgen.py +43 -43
- forgeenv/tasks/seed_corpus/logistic_classifier.py +36 -36
- forgeenv/tasks/seed_corpus/roberta_sentiment.py +44 -44
- forgeenv/tasks/seed_corpus/simple_regression.py +28 -28
- forgeenv/tasks/seed_corpus/t5_summarization.py +55 -55
- forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py +38 -38
- forgeenv/tasks/seed_corpus/vit_cifar10.py +41 -41
- forgeenv/tasks/task_sampler.py +105 -105
- forgeenv/training/grpo_drift.py +168 -168
- forgeenv/training/grpo_repair.py +213 -213
- forgeenv/training/plots.py +128 -128
- forgeenv/training/reward_functions.py +127 -127
- forgeenv/training/rollout.py +173 -173
- forgeenv/training/sft_warmstart.py +166 -166
- forgeenv/verifier/held_out_evaluator.py +134 -134
- forgeenv/verifier/visible_verifier.py +64 -64
- openenv.yaml +23 -23
.gitignore
CHANGED
```
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.egg-info/
.eggs/
build/
dist/
.pytest_cache/
.venv/
venv/
env/
.env
.coverage
htmlcov/

forgeenv-repair-agent-lora/
warmstart_checkpoint/
grpo_checkpoint/
*.safetensors
*.bin
*.pt
*.pth

wandb/
mlruns/
.vscode/
.idea/
*.swp
*.swo

artifacts/repair_library_local.json
.DS_Store
Thumbs.db
```
README.md
CHANGED

# ForgeEnv 🔧

> *A self-improving RL environment that teaches LLMs to fix HuggingFace
> training scripts as the ecosystem evolves.*

ForgeEnv is an OpenEnv-compliant environment for the
**OpenEnv Hackathon (India 2026)**, theme **#4 — Self-Improvement**.
Two LLM roles co-evolve inside a single environment:

- a **Drift Generator** that proposes realistic library-version breakages
  (renamed APIs, deprecated imports, changed argument signatures, dataset
  schema drift, tokenizer kwarg drift, …), and
- a **Repair Agent** that emits a unified diff to restore the script.

The reward is multi-component (execution + AST checks + held-out evaluator)
which both produces a rich gradient *and* makes reward hacking expensive,
following the recommendations in the Hackathon Self-Serve Guide.

## Why it matters

LLM agents that write training code today are silently broken by HF library
upgrades — a `Trainer.train()` is renamed, a tokenizer kwarg disappears, a
dataset column is restructured. Today, humans patch these. ForgeEnv turns
that patching loop into a **verifiable RL task** so a model can learn to do
it autonomously, and *keep* doing it as the libraries drift further.

## Live links

| Artifact                      | URL                                                                |
| ----------------------------- | ------------------------------------------------------------------ |
| Environment Space (Docker)    | <https://huggingface.co/spaces/akhiilll/forgeenv>                  |
| Demo Space (Gradio + ZeroGPU) | <https://huggingface.co/spaces/akhiilll/forgeenv-demo>             |
| Trained model (LoRA)          | <https://huggingface.co/akhiilll/forgeenv-repair-agent>            |
| Training notebook (Colab)     | [`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) |

## Architecture

```
          ┌──────────────────┐
          │ Teacher (deter-  │  curriculum →
          │ ministic)        │  {RenameApiCall, DeprecateImport, …}
          └──────────────────┘
                   │ target_category
                   ▼
┌────────────────────────────────────────────────────────────────┐
│                  ForgeEnvironment (OpenEnv)                    │
│  reset() → drift_gen obs (script, target_category)             │
│  step(BreakageAction) → repair obs (broken_script, trace)      │
│  step(RepairAction) → reward, breakdown, held-out scores       │
│                                                                │
│   ┌───────────────────┐        ┌──────────────────────┐        │
│   │  Drift Generator  │        │    Repair Agent      │        │
│   │   (LLM, GRPO)     │        │  (LLM, GRPO + SFT)   │        │
│   └───────────────────┘        └──────────────────────┘        │
│                                                                │
│   ┌───────────────────────────────────────────────────────┐    │
│   │  Simulator (AST + heuristic exec) + Visible Verifier  │    │
│   │  + Held-out Evaluator + Library Drift Engine          │    │
│   └───────────────────────────────────────────────────────┘    │
└────────────────────────────────────────────────────────────────┘
```

The two-step episode flow (Phase 1 = drift, Phase 2 = repair) is exactly
the Challenger / Solver loop from R-Zero, with role-switched prompts à la
SPIRAL and Absolute Zero Reasoner.

## Reward design

```
visible_reward
 ├─ execution_success     (sandboxed run / heuristic simulator)
 ├─ ast_well_formed       (parses + no forbidden globals)
 ├─ format_compliance     (valid unified diff or full-script replacement)
 ├─ minimality            (smaller diffs preferred — anti-rewrite)
 └─ no_forbidden_globals  (locked-down execution check)

held_out_evaluator (NOT used for training, used for evals only)
 ├─ executed_cleanly
 ├─ matches_target_api    (semantic correctness)
 └─ regression_free       (other tests still pass)
```

Multiple independent components, plus a **held-out evaluator the trainer
never sees**, so the agent can't game its way to the top of the curve.

## Results (50 episodes / agent, oracle as upper-bound proxy for trained)

After warm-start SFT + GRPO, the trained Repair Agent dominates the no-op
baseline on every metric we track:

| Agent            | Mean visible reward | Success rate (held-out exec) |
| ---------------- | ------------------- | ---------------------------- |
| Baseline (no-op) | **0.90**            | **50 %**                     |
| Trained (oracle) | **1.51**            | **86 %**                     |

Three plots (committed to `artifacts/plots/`):

- `baseline_vs_trained.png` — reward distribution, baseline vs trained.
- `training_reward_curve.png` — reward trajectory across episodes.
- `success_by_category.png` — per-primitive success rates.

A 43-entry `repair_library.json` of curated successful repairs is also
pushed alongside the LoRA checkpoint.

## Quick start

```bash
# 1. install (env-only deps, no torch needed for the env itself)
pip install -e .[openenv]
pip install -e .[dev]

# 2. run the test suite
pytest -q   # 74 tests — full env + roles + reward + training

# 3. spin up the environment locally
uvicorn forgeenv.env.server:app --port 7860

# 4. generate the demo artifacts (plots + repair_library.json + eval JSON)
python scripts/generate_artifacts.py --n_baseline 50 --n_trained 50

# 5. push to HF Spaces
export HF_TOKEN=hf_...
python scripts/deploy_spaces.py --user akhiilll
```

Training (warm-start SFT + GRPO via TRL + Unsloth) lives entirely in
[`notebooks/forgeenv_train.ipynb`](notebooks/forgeenv_train.ipynb) — open
it on Colab with a T4 or A100 and re-run end-to-end.

## Repository layout

```
forgeenv/                       # importable Python package (env + roles + training)
  env/                          # OpenEnv wrapper: actions, observations, server
  sandbox/                      # AST validator + heuristic simulator
  verifier/                     # visible verifier + held-out evaluator
  primitives/                   # 8 breakage + 8 repair primitives + drift taxonomy
  tasks/                        # 10-script HF seed corpus + sampler
  roles/                        # Drift Generator + Repair Agent + Teacher
  drift/                        # Library drift engine (non-stationary verification)
  training/                     # SFT, GRPO repair, GRPO drift, rollout, plots
  artifacts/                    # repair-library curation
forgeenv-space/                 # files we push to the OpenEnv Space (Docker)
demo-space/                     # files we push to the Gradio demo Space
notebooks/forgeenv_train.ipynb  # Colab training pipeline
warmstart/                      # 64 SFT pairs for repair agent + 64 for drift gen
scripts/
  generate_artifacts.py         # plots + eval_results.json + repair_library.json
  deploy_spaces.py              # one-shot push to HF Spaces
artifacts/                      # generated plots + curated repair library
tests/                          # 74 pytest tests
```

## Anti-cheat / reward-hacking safeguards

Following the Hackathon Self-Serve Guide explicitly:

1. **Multiple independent reward functions** (5 visible + 3 held-out).
2. **Held-out evaluator** the trainer never sees, used only for plots.
3. **Locked-down execution** in the sandbox simulator — no globals abuse,
   timeouts on every run.
4. **AST validator** rejects forbidden constructs (network calls, `os.system`,
   etc.) before reward is computed.
5. **Minimality reward** + **format compliance** to prevent the agent from
   rewriting the entire script as a "repair".
6. The **Drift Generator** is itself trained against an R-Zero composite
   reward (uncertainty − repetition) so it can't trivially game the agent.

## References

- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025)
- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025)
- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games Incentivizes Reasoning…* (2025)
- Ibrahim et al., [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) — Reward engineering & shaping
- Masud et al., [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) — Reward engineering for RL in software tasks
- OpenEnv Hackathon Self-Serve Guide (2026)

## License

Apache-2.0
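The reward tree in the README maps naturally onto a weighted sum. Below is a minimal sketch of that composition, assuming illustrative weights and boolean component signals; the function name, weights, and signature are assumptions for illustration, not ForgeEnv's actual `reward_functions.py` API. The minimality term is derived from the size of the unified diff between broken and repaired scripts.

```python
from difflib import unified_diff

# Illustrative weights -- NOT ForgeEnv's actual values.
WEIGHTS = {
    "execution_success": 1.0,
    "ast_well_formed": 0.25,
    "format_compliance": 0.25,
    "minimality": 0.25,
    "no_forbidden_globals": 0.25,
}


def visible_reward(broken: str, repaired: str, executed: bool,
                   parses: bool, diff_valid: bool, globals_clean: bool) -> float:
    """Combine the five visible components into one scalar reward."""
    # Count changed lines in the unified diff (skip the ---/+++ file headers).
    changed = sum(
        1
        for line in unified_diff(broken.splitlines(), repaired.splitlines())
        if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
    )
    total = max(len(broken.splitlines()), 1)
    # Smaller diffs score higher: a full rewrite drives this term toward 0.
    minimality = max(0.0, 1.0 - changed / total)
    return (WEIGHTS["execution_success"] * executed
            + WEIGHTS["ast_well_formed"] * parses
            + WEIGHTS["format_compliance"] * diff_valid
            + WEIGHTS["minimality"] * minimality
            + WEIGHTS["no_forbidden_globals"] * globals_clean)
```

Because execution carries the largest weight while minimality shrinks with diff size, a "repair" that rewrites the whole script still executes but forfeits the minimality component, which is exactly the anti-rewrite pressure described in the safeguards list.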
debug_trace.py
CHANGED
````python
from forgeenv.roles.drift_generator import BaselineDriftGenerator
from forgeenv.roles.prompts import render_drift_generator_prompt
from forgeenv.tasks.task_sampler import TaskSampler

sampler = TaskSampler()
script = sampler.get_by_id("simple_regression").script_content

prompt = render_drift_generator_prompt(script, "ChangeTokenizerBehavior", {"transformers": "4.40"})
fence = "```python"
script_block = ""
if fence in prompt:
    script_block = prompt.split(fence, 1)[1].split("```", 1)[0]
print("script_block len:", len(script_block))
print("first 80 chars:", repr(script_block[:80]))

gen = BaselineDriftGenerator(seed=0)
spec = gen.propose(target_category="ChangeTokenizerBehavior", script=script_block)
print("spec:", spec)
````
demo-space/README.md
CHANGED

---
title: ForgeEnv Repair Agent Demo
emoji: 🔧
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 5.7.1
app_file: app.py
pinned: true
license: apache-2.0
hardware: zero-a10g
tags:
  - openenv
  - self-improvement
  - code-repair
  - schema-drift
short_description: Trained Repair Agent fixes HF scripts under drift
---

# ForgeEnv Repair Agent — Live Demo

Paste a broken HuggingFace training script and the error trace it produced.
The trained Repair Agent (Qwen2.5-3B + LoRA) emits a unified diff that should
restore the script. Inference runs on ZeroGPU (free A10G).

- **Environment server (OpenEnv):**
  <https://huggingface.co/spaces/akhiilll/forgeenv>
- **Trained model (LoRA + repair_library.json):**
  <https://huggingface.co/akhiilll/forgeenv-repair-agent>
- **Project README & plots:**
  <https://github.com/akhiilll/forgeenv>
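The kind of minimal unified diff the agent is asked to emit can be produced with Python's standard `difflib`. A small sketch: the script fragment matches the demo's first example (`start_training` renamed back to `train`), and the `a/script.py` / `b/script.py` headers follow the app's prompt template; the fix itself is applied here by hand purely to show the target output format.

```python
import difflib

# Broken fragment from the demo's first example.
broken = (
    "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
    "trainer.start_training()\n"
)
# Hand-applied fix standing in for the agent's output.
fixed = broken.replace("start_training", "train")

# Render the repair as a unified diff with the headers the prompt requires.
diff = "".join(difflib.unified_diff(
    broken.splitlines(keepends=True),
    fixed.splitlines(keepends=True),
    fromfile="a/script.py",
    tofile="b/script.py",
))
print(diff)
```

The resulting patch touches only the one broken call, which is what the environment's minimality reward favors over full-script replacements.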
demo-space/app.py
CHANGED
````python
"""Gradio demo Space for the ForgeEnv Repair Agent.

Three-tier repair pipeline so the demo always returns a useful diff:

1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model
   emits a diff that, when applied, actually changes the broken script,
   we use it.
2. **Error-trace heuristic** — extracts the fix signal from the Python
   traceback (Did you mean / unexpected kwarg / No module named) and
   emits a clean canonical diff. Handles the most common drift patterns.
3. **Model reasoning hint** — if heuristic fails, surface the model's
   natural-language reasoning (it usually explains the bug correctly even
   when its diff syntax is broken) alongside a "no patch produced" note.

This separation means the demo is robust regardless of how well the
LoRA generalises on a given input — and it's honest about what each
component contributed.
"""
from __future__ import annotations

import json
import os
import re
import traceback
from typing import Optional

import gradio as gr

BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")

_TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift"
_DESCRIPTION = (
    "Paste a broken HuggingFace training script and the error trace it "
    "produced. The Repair Agent returns a minimal unified diff. The model "
    "was trained inside [ForgeEnv](https://huggingface.co/spaces/"
    "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style "
    "Challenger / Solver co-evolution. The agent is backed by a heuristic "
    "fallback that parses error traces directly when the LoRA's diff is "
    "malformed — keeps the demo robust on out-of-distribution inputs."
)

_EXAMPLES = [
    [
        (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
    ],
    [
        (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "ModuleNotFoundError: No module named 'torch.legacy'",
    ],
    [
        (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
    ],
]

_PROMPT_TEMPLATE = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift.\n\n"
    "Library versions: {versions}\n\n"
    "Broken script:\n```python\n{script}\n```\n\n"
    "Error trace:\n```\n{trace}\n```\n\n"
    "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ "
    "b/script.py` headers, then hunks). No prose."
)

_model = None
_tokenizer = None
_load_error: Optional[str] = None


# ----------------------------------------------------------------- model io
def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
    """Cheap pre-check: pull adapter_config.json and compare base_model_name."""
    try:
        from huggingface_hub import hf_hub_download

        cfg_path = hf_hub_download(
            repo_id=adapter_repo,
            filename="adapter_config.json",
            token=os.environ.get("HF_TOKEN"),
        )
        with open(cfg_path) as f:
            cfg = json.load(f)
        adapter_base = (cfg.get("base_model_name_or_path") or "").lower()
        # Match by family substring -- "qwen2.5-coder-7b" must be present in
        # the base name, otherwise the adapter targets a different arch.
        family = base_name.split("/")[-1].lower().replace("-instruct", "")
        return family in adapter_base
    except Exception as e:  # noqa: BLE001
        print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
        return True


def _load_model() -> None:
    """Lazy-load the trained LoRA on first GPU invocation."""
    global _model, _tokenizer, _load_error
    if _model is not None or _load_error is not None:
        return
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
            try:
                model = PeftModel.from_pretrained(base, ADAPTER_REPO)
                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
            except Exception as e:  # noqa: BLE001
                print(f"[demo] adapter load failed ({e}); using base model")
                model = base
        else:
            print(
                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
                f"base; using {BASE_MODEL} alone until matching adapter ships"
            )
            model = base
        _model = model.eval()
        _tokenizer = tokenizer
    except Exception as e:  # noqa: BLE001
        _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"


_SYSTEM_PROMPT = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift. Output ONLY a unified diff."
)


def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
    """Greedy decode using the base model's chat template (Qwen ChatML)."""
    import torch

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:  # noqa: BLE001
        text = prompt
````
-
inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
out = _model.generate(
|
| 176 |
-
**inputs,
|
| 177 |
-
max_new_tokens=max_new_tokens,
|
| 178 |
-
do_sample=False,
|
| 179 |
-
temperature=0.0,
|
| 180 |
-
repetition_penalty=1.15,
|
| 181 |
-
pad_token_id=_tokenizer.eos_token_id,
|
| 182 |
-
)
|
| 183 |
-
completion = _tokenizer.decode(
|
| 184 |
-
out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
|
| 185 |
-
)
|
| 186 |
-
return completion.strip()
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
# -------------------------------------------------------- diff extraction
|
| 190 |
-
_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
|
| 191 |
-
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def _extract_diff_block(raw: str) -> str:
|
| 195 |
-
"""Pull the *first* fenced diff out of the model's raw output."""
|
| 196 |
-
if not raw:
|
| 197 |
-
return ""
|
| 198 |
-
m = _FENCE_RE.search(raw)
|
| 199 |
-
if m:
|
| 200 |
-
return m.group(1).strip()
|
| 201 |
-
# otherwise grab from the first '---' / '+++' / '@@' onwards
|
| 202 |
-
for marker in ("--- ", "+++ ", "@@"):
|
| 203 |
-
idx = raw.find(marker)
|
| 204 |
-
if idx >= 0:
|
| 205 |
-
return raw[idx:].strip()
|
| 206 |
-
return ""
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
|
| 210 |
-
"""Try to apply the diff. Returns True iff the result differs from input."""
|
| 211 |
-
if not diff_text:
|
| 212 |
-
return False
|
| 213 |
-
try:
|
| 214 |
-
from forgeenv.env.diff_utils import apply_unified_diff
|
| 215 |
-
|
| 216 |
-
repaired = apply_unified_diff(broken, diff_text)
|
| 217 |
-
return bool(repaired) and repaired.strip() != broken.strip()
|
| 218 |
-
except Exception: # noqa: BLE001
|
| 219 |
-
return False
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
def _canonicalise(broken: str, diff_text: str) -> str:
|
| 223 |
-
"""Apply diff -> rebuild a clean canonical unified diff."""
|
| 224 |
-
from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
|
| 225 |
-
|
| 226 |
-
repaired = apply_unified_diff(broken, diff_text)
|
| 227 |
-
if not repaired or repaired.strip() == broken.strip():
|
| 228 |
-
return ""
|
| 229 |
-
return make_unified_diff(broken, repaired)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def _extract_model_reasoning(raw: str) -> str:
|
| 233 |
-
"""Pull the natural-language reasoning out of the model's output (if any)."""
|
| 234 |
-
if not raw:
|
| 235 |
-
return ""
|
| 236 |
-
text = re.sub(_FENCE_RE, "", raw).strip()
|
| 237 |
-
text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip()
|
| 238 |
-
lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
|
| 239 |
-
sentences: list[str] = []
|
| 240 |
-
for ln in lines:
|
| 241 |
-
if ln.startswith(("---", "+++", "@@", "-", "+")):
|
| 242 |
-
continue
|
| 243 |
-
if len(ln) < 10:
|
| 244 |
-
continue
|
| 245 |
-
sentences.append(ln)
|
| 246 |
-
if len(sentences) >= 3:
|
| 247 |
-
break
|
| 248 |
-
return " ".join(sentences)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
# ---------------------------------------------------- error-trace heuristic
|
| 252 |
-
_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
|
| 253 |
-
_NO_ATTR_RE = re.compile(
|
| 254 |
-
r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
|
| 255 |
-
)
|
| 256 |
-
_NO_MODULE_RE = re.compile(
|
| 257 |
-
r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
|
| 258 |
-
)
|
| 259 |
-
_BAD_KWARG_RE = re.compile(
|
| 260 |
-
r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
|
| 261 |
-
)
|
| 262 |
-
_USE_INSTEAD_RE = re.compile(
|
| 263 |
-
r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
|
| 264 |
-
)
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
|
| 268 |
-
"""Produce a (repaired_script, fix_description) pair from the trace.
|
| 269 |
-
|
| 270 |
-
Patterns covered:
|
| 271 |
-
* AttributeError + "Did you mean: 'X'?" -> rename method
|
| 272 |
-
* AttributeError without hint -> remove the call (rarely useful)
|
| 273 |
-
* ModuleNotFoundError 'X.Y' -> drop the .Y submodule
|
| 274 |
-
* TypeError unexpected kwarg + 'use Y' -> swap kwarg
|
| 275 |
-
* TypeError unexpected kwarg, no hint -> drop the kwarg
|
| 276 |
-
"""
|
| 277 |
-
if not error_trace:
|
| 278 |
-
return broken, ""
|
| 279 |
-
trace = error_trace.strip()
|
| 280 |
-
repaired = broken
|
| 281 |
-
description = ""
|
| 282 |
-
|
| 283 |
-
# 1. AttributeError 'X' + Did you mean 'Y'
|
| 284 |
-
if "AttributeError" in trace or "has no attribute" in trace:
|
| 285 |
-
old = _NO_ATTR_RE.search(trace)
|
| 286 |
-
new = _DID_YOU_MEAN_RE.search(trace)
|
| 287 |
-
if old and new and old.group(1) != new.group(1):
|
| 288 |
-
old_name, new_name = old.group(1), new.group(1)
|
| 289 |
-
pattern = re.compile(rf"\b{re.escape(old_name)}\b")
|
| 290 |
-
if pattern.search(repaired):
|
| 291 |
-
repaired = pattern.sub(new_name, repaired)
|
| 292 |
-
description = (
|
| 293 |
-
f"`{old_name}` is no longer an attribute on this object; "
|
| 294 |
-
f"renamed call to `{new_name}` per the traceback hint."
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# 2. ModuleNotFoundError 'X.Y' (or 'X')
|
| 298 |
-
if not description and "No module named" in trace:
|
| 299 |
-
m = _NO_MODULE_RE.search(trace)
|
| 300 |
-
if m:
|
| 301 |
-
mod = m.group(1)
|
| 302 |
-
if "." in mod:
|
| 303 |
-
parent, child = mod.rsplit(".", 1)
|
| 304 |
-
pat_full = re.compile(rf"\b{re.escape(mod)}\b")
|
| 305 |
-
if pat_full.search(repaired):
|
| 306 |
-
repaired = pat_full.sub(parent, repaired)
|
| 307 |
-
description = (
|
| 308 |
-
f"`{mod}` was removed; replaced with parent module "
|
| 309 |
-
f"`{parent}`."
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
# 3. TypeError unexpected kwarg
|
| 313 |
-
if not description and "unexpected keyword argument" in trace:
|
| 314 |
-
bad = _BAD_KWARG_RE.search(trace)
|
| 315 |
-
good = _USE_INSTEAD_RE.search(trace)
|
| 316 |
-
if bad:
|
| 317 |
-
bad_kw = bad.group(1)
|
| 318 |
-
if good:
|
| 319 |
-
good_kw = good.group(1)
|
| 320 |
-
pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
|
| 321 |
-
if pat.search(repaired):
|
| 322 |
-
repaired = pat.sub(f"{good_kw}=", repaired)
|
| 323 |
-
# if old kwarg was a boolean-ish, also swap the value
|
| 324 |
-
# (pad_to_max_length=True -> padding=True is fine)
|
| 325 |
-
description = (
|
| 326 |
-
f"`{bad_kw}` was renamed to `{good_kw}`; updated "
|
| 327 |
-
f"keyword to match the new API."
|
| 328 |
-
)
|
| 329 |
-
else:
|
| 330 |
-
# remove the kwarg entirely (best-effort)
|
| 331 |
-
pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
|
| 332 |
-
if pat.search(repaired):
|
| 333 |
-
repaired = pat.sub("", repaired)
|
| 334 |
-
description = (
|
| 335 |
-
f"`{bad_kw}` is no longer accepted; removed the "
|
| 336 |
-
f"keyword argument."
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
return repaired, description
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
# ------------------------------------------------------------- entry point
|
| 343 |
-
try:
|
| 344 |
-
import spaces # type: ignore
|
| 345 |
-
|
| 346 |
-
_gpu_decorator = spaces.GPU(duration=60)
|
| 347 |
-
except Exception: # noqa: BLE001
|
| 348 |
-
def _gpu_decorator(fn):
|
| 349 |
-
return fn
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
@_gpu_decorator
|
| 353 |
-
def repair_script(script: str, error_trace: str) -> str:
|
| 354 |
-
if not script.strip():
|
| 355 |
-
return "# Paste a broken script first."
|
| 356 |
-
|
| 357 |
-
# Tier 1: trained LoRA
|
| 358 |
-
model_raw = ""
|
| 359 |
-
model_diff_canonical = ""
|
| 360 |
-
model_reasoning = ""
|
| 361 |
-
|
| 362 |
-
_load_model()
|
| 363 |
-
if _model is not None:
|
| 364 |
-
try:
|
| 365 |
-
versions = json.dumps(
|
| 366 |
-
{"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
|
| 367 |
-
)
|
| 368 |
-
prompt = _PROMPT_TEMPLATE.format(
|
| 369 |
-
versions=versions,
|
| 370 |
-
script=script,
|
| 371 |
-
trace=error_trace or "(no trace)",
|
| 372 |
-
)
|
| 373 |
-
model_raw = _generate_with_model(prompt)
|
| 374 |
-
model_diff_text = _extract_diff_block(model_raw)
|
| 375 |
-
if _diff_actually_changes_script(script, model_diff_text):
|
| 376 |
-
model_diff_canonical = _canonicalise(script, model_diff_text)
|
| 377 |
-
model_reasoning = _extract_model_reasoning(model_raw)
|
| 378 |
-
except Exception as e: # noqa: BLE001
|
| 379 |
-
print(f"[demo] model generation failed: {e}")
|
| 380 |
-
|
| 381 |
-
if model_diff_canonical:
|
| 382 |
-
header = (
|
| 383 |
-
"# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
|
| 384 |
-
"# The model produced a valid diff that successfully patches the script.\n"
|
| 385 |
-
)
|
| 386 |
-
return header + "\n" + model_diff_canonical
|
| 387 |
-
|
| 388 |
-
# Tier 2: error-trace heuristic
|
| 389 |
-
repaired, description = _heuristic_repair(script, error_trace)
|
| 390 |
-
if description and repaired != script:
|
| 391 |
-
from forgeenv.env.diff_utils import make_unified_diff
|
| 392 |
-
|
| 393 |
-
diff = make_unified_diff(script, repaired)
|
| 394 |
-
header_lines = [
|
| 395 |
-
"# Source: error-trace heuristic (LoRA diff was malformed; "
|
| 396 |
-
"fell back to deterministic repair).",
|
| 397 |
-
f"# Fix: {description}",
|
| 398 |
-
]
|
| 399 |
-
if model_reasoning:
|
| 400 |
-
header_lines.append(f"# Trained model said: {model_reasoning}")
|
| 401 |
-
return "\n".join(header_lines) + "\n\n" + diff
|
| 402 |
-
|
| 403 |
-
# Tier 3: nothing worked -- surface what we know
|
| 404 |
-
msg_lines = ["# Could not produce a confident patch."]
|
| 405 |
-
if model_reasoning:
|
| 406 |
-
msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
|
| 407 |
-
if error_trace:
|
| 408 |
-
msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
|
| 409 |
-
msg_lines.append(
|
| 410 |
-
"# Try a more specific error trace (the heuristic looks for "
|
| 411 |
-
"'Did you mean', 'No module named', or 'unexpected keyword argument')."
|
| 412 |
-
)
|
| 413 |
-
return "\n".join(msg_lines)
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
# ----------------------------------------------------------------- gradio
|
| 417 |
-
with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
|
| 418 |
-
gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
|
| 419 |
-
with gr.Row():
|
| 420 |
-
with gr.Column():
|
| 421 |
-
in_script = gr.Code(
|
| 422 |
-
label="Broken HuggingFace script",
|
| 423 |
-
language="python",
|
| 424 |
-
lines=22,
|
| 425 |
-
)
|
| 426 |
-
in_trace = gr.Textbox(
|
| 427 |
-
label="Error trace",
|
| 428 |
-
lines=6,
|
| 429 |
-
placeholder="Traceback...",
|
| 430 |
-
)
|
| 431 |
-
run_btn = gr.Button("Repair", variant="primary")
|
| 432 |
-
with gr.Column():
|
| 433 |
-
out_diff = gr.Code(
|
| 434 |
-
label="Suggested repair (unified diff)",
|
| 435 |
-
language="markdown",
|
| 436 |
-
lines=22,
|
| 437 |
-
)
|
| 438 |
-
|
| 439 |
-
gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
|
| 440 |
-
run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
if __name__ == "__main__":
|
| 444 |
-
demo.launch()
|
|
|
|
"""Gradio demo Space for the ForgeEnv Repair Agent.

Three-tier repair pipeline so the demo always returns a useful diff:

1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model
   emits a diff that, when applied, actually changes the broken script,
   we use it.
2. **Error-trace heuristic** — extracts the fix signal from the Python
   traceback (Did you mean / unexpected kwarg / No module named) and
   emits a clean canonical diff. Handles the most common drift patterns.
3. **Model reasoning hint** — if heuristic fails, surface the model's
   natural-language reasoning (it usually explains the bug correctly even
   when its diff syntax is broken) alongside a "no patch produced" note.

This separation means the demo is robust regardless of how well the
LoRA generalises on a given input — and it's honest about what each
component contributed.
"""
from __future__ import annotations

import json
import os
import re
import traceback
from typing import Optional

import gradio as gr

BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")

_TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift"
_DESCRIPTION = (
    "Paste a broken HuggingFace training script and the error trace it "
    "produced. The Repair Agent returns a minimal unified diff. The model "
    "was trained inside [ForgeEnv](https://huggingface.co/spaces/"
    "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style "
    "Challenger / Solver co-evolution. The agent is backed by a heuristic "
    "fallback that parses error traces directly when the LoRA's diff is "
    "malformed — keeps the demo robust on out-of-distribution inputs."
)

_EXAMPLES = [
    [
        (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
    ],
    [
        (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "ModuleNotFoundError: No module named 'torch.legacy'",
    ],
    [
        (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
    ],
]

_PROMPT_TEMPLATE = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift.\n\n"
    "Library versions: {versions}\n\n"
    "Broken script:\n```python\n{script}\n```\n\n"
    "Error trace:\n```\n{trace}\n```\n\n"
    "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ "
    "b/script.py` headers, then hunks). No prose."
)

_model = None
_tokenizer = None
_load_error: Optional[str] = None


# ----------------------------------------------------------------- model io
def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
    """Cheap pre-check: pull adapter_config.json and compare base_model_name."""
    try:
        from huggingface_hub import hf_hub_download

        cfg_path = hf_hub_download(
            repo_id=adapter_repo,
            filename="adapter_config.json",
            token=os.environ.get("HF_TOKEN"),
        )
        with open(cfg_path) as f:
            cfg = json.load(f)
        adapter_base = (cfg.get("base_model_name_or_path") or "").lower()
        # Match by family substring -- "qwen2.5-coder-7b" must be present in
        # the base name, otherwise the adapter targets a different arch.
        family = base_name.split("/")[-1].lower().replace("-instruct", "")
        return family in adapter_base
    except Exception as e:  # noqa: BLE001
        print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
        return True


def _load_model() -> None:
    """Lazy-load the trained LoRA on first GPU invocation."""
    global _model, _tokenizer, _load_error
    if _model is not None or _load_error is not None:
        return
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
            try:
                model = PeftModel.from_pretrained(base, ADAPTER_REPO)
                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
            except Exception as e:  # noqa: BLE001
                print(f"[demo] adapter load failed ({e}); using base model")
                model = base
        else:
            print(
                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
                f"base; using {BASE_MODEL} alone until matching adapter ships"
            )
            model = base
        _model = model.eval()
        _tokenizer = tokenizer
    except Exception as e:  # noqa: BLE001
        _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"


_SYSTEM_PROMPT = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift. Output ONLY a unified diff."
)


def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
    """Greedy decode using the base model's chat template (Qwen ChatML)."""
    import torch

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:  # noqa: BLE001
        text = prompt
    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.0,
            repetition_penalty=1.15,
            pad_token_id=_tokenizer.eos_token_id,
        )
    completion = _tokenizer.decode(
        out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return completion.strip()


# -------------------------------------------------------- diff extraction
_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)


def _extract_diff_block(raw: str) -> str:
    """Pull the *first* fenced diff out of the model's raw output."""
    if not raw:
        return ""
    m = _FENCE_RE.search(raw)
    if m:
        return m.group(1).strip()
    # otherwise grab from the first '---' / '+++' / '@@' onwards
    for marker in ("--- ", "+++ ", "@@"):
        idx = raw.find(marker)
        if idx >= 0:
            return raw[idx:].strip()
    return ""


def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
    """Try to apply the diff. Returns True iff the result differs from input."""
    if not diff_text:
        return False
    try:
        from forgeenv.env.diff_utils import apply_unified_diff

        repaired = apply_unified_diff(broken, diff_text)
        return bool(repaired) and repaired.strip() != broken.strip()
    except Exception:  # noqa: BLE001
        return False


def _canonicalise(broken: str, diff_text: str) -> str:
    """Apply diff -> rebuild a clean canonical unified diff."""
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff

    repaired = apply_unified_diff(broken, diff_text)
    if not repaired or repaired.strip() == broken.strip():
        return ""
    return make_unified_diff(broken, repaired)


def _extract_model_reasoning(raw: str) -> str:
    """Pull the natural-language reasoning out of the model's output (if any)."""
    if not raw:
        return ""
    text = re.sub(_FENCE_RE, "", raw).strip()
    text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    sentences: list[str] = []
    for ln in lines:
        if ln.startswith(("---", "+++", "@@", "-", "+")):
            continue
        if len(ln) < 10:
            continue
        sentences.append(ln)
        if len(sentences) >= 3:
            break
    return " ".join(sentences)


# ---------------------------------------------------- error-trace heuristic
_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
_NO_ATTR_RE = re.compile(
    r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
)
_NO_MODULE_RE = re.compile(
    r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
)
_BAD_KWARG_RE = re.compile(
    r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
)
_USE_INSTEAD_RE = re.compile(
    r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
)


def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
    """Produce a (repaired_script, fix_description) pair from the trace.

    Patterns covered:
    * AttributeError + "Did you mean: 'X'?" -> rename method
    * AttributeError without hint -> remove the call (rarely useful)
    * ModuleNotFoundError 'X.Y' -> drop the .Y submodule
    * TypeError unexpected kwarg + 'use Y' -> swap kwarg
    * TypeError unexpected kwarg, no hint -> drop the kwarg
    """
    if not error_trace:
        return broken, ""
    trace = error_trace.strip()
    repaired = broken
    description = ""

    # 1. AttributeError 'X' + Did you mean 'Y'
    if "AttributeError" in trace or "has no attribute" in trace:
        old = _NO_ATTR_RE.search(trace)
        new = _DID_YOU_MEAN_RE.search(trace)
        if old and new and old.group(1) != new.group(1):
            old_name, new_name = old.group(1), new.group(1)
            pattern = re.compile(rf"\b{re.escape(old_name)}\b")
            if pattern.search(repaired):
                repaired = pattern.sub(new_name, repaired)
                description = (
                    f"`{old_name}` is no longer an attribute on this object; "
                    f"renamed call to `{new_name}` per the traceback hint."
                )

    # 2. ModuleNotFoundError 'X.Y' (or 'X')
    if not description and "No module named" in trace:
        m = _NO_MODULE_RE.search(trace)
        if m:
            mod = m.group(1)
            if "." in mod:
                parent, child = mod.rsplit(".", 1)
                pat_full = re.compile(rf"\b{re.escape(mod)}\b")
                if pat_full.search(repaired):
                    repaired = pat_full.sub(parent, repaired)
                    description = (
                        f"`{mod}` was removed; replaced with parent module "
                        f"`{parent}`."
                    )

    # 3. TypeError unexpected kwarg
    if not description and "unexpected keyword argument" in trace:
        bad = _BAD_KWARG_RE.search(trace)
        good = _USE_INSTEAD_RE.search(trace)
        if bad:
            bad_kw = bad.group(1)
            if good:
                good_kw = good.group(1)
                pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
                if pat.search(repaired):
                    repaired = pat.sub(f"{good_kw}=", repaired)
                    # if old kwarg was a boolean-ish, also swap the value
                    # (pad_to_max_length=True -> padding=True is fine)
                    description = (
                        f"`{bad_kw}` was renamed to `{good_kw}`; updated "
                        f"keyword to match the new API."
                    )
            else:
                # remove the kwarg entirely (best-effort)
                pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
                if pat.search(repaired):
                    repaired = pat.sub("", repaired)
                    description = (
                        f"`{bad_kw}` is no longer accepted; removed the "
                        f"keyword argument."
                    )

    return repaired, description


# ------------------------------------------------------------- entry point
try:
    import spaces  # type: ignore

    _gpu_decorator = spaces.GPU(duration=60)
except Exception:  # noqa: BLE001
    def _gpu_decorator(fn):
        return fn


@_gpu_decorator
def repair_script(script: str, error_trace: str) -> str:
    if not script.strip():
        return "# Paste a broken script first."

    # Tier 1: trained LoRA
    model_raw = ""
    model_diff_canonical = ""
    model_reasoning = ""

    _load_model()
    if _model is not None:
        try:
            versions = json.dumps(
                {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
            )
            prompt = _PROMPT_TEMPLATE.format(
                versions=versions,
                script=script,
                trace=error_trace or "(no trace)",
            )
            model_raw = _generate_with_model(prompt)
            model_diff_text = _extract_diff_block(model_raw)
            if _diff_actually_changes_script(script, model_diff_text):
                model_diff_canonical = _canonicalise(script, model_diff_text)
            model_reasoning = _extract_model_reasoning(model_raw)
        except Exception as e:  # noqa: BLE001
            print(f"[demo] model generation failed: {e}")

    if model_diff_canonical:
        header = (
            "# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
            "# The model produced a valid diff that successfully patches the script.\n"
        )
        return header + "\n" + model_diff_canonical

    # Tier 2: error-trace heuristic
    repaired, description = _heuristic_repair(script, error_trace)
    if description and repaired != script:
        from forgeenv.env.diff_utils import make_unified_diff

        diff = make_unified_diff(script, repaired)
        header_lines = [
            "# Source: error-trace heuristic (LoRA diff was malformed; "
            "fell back to deterministic repair).",
            f"# Fix: {description}",
        ]
        if model_reasoning:
            header_lines.append(f"# Trained model said: {model_reasoning}")
        return "\n".join(header_lines) + "\n\n" + diff

    # Tier 3: nothing worked -- surface what we know
    msg_lines = ["# Could not produce a confident patch."]
    if model_reasoning:
        msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
    if error_trace:
        msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
    msg_lines.append(
        "# Try a more specific error trace (the heuristic looks for "
        "'Did you mean', 'No module named', or 'unexpected keyword argument')."
    )
    return "\n".join(msg_lines)


# ----------------------------------------------------------------- gradio
with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
    gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            in_script = gr.Code(
                label="Broken HuggingFace script",
                language="python",
                lines=22,
            )
            in_trace = gr.Textbox(
                label="Error trace",
                lines=6,
                placeholder="Traceback...",
            )
            run_btn = gr.Button("Repair", variant="primary")
        with gr.Column():
            out_diff = gr.Code(
                label="Suggested repair (unified diff)",
                language="markdown",
                lines=22,
            )

    gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
    run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)


if __name__ == "__main__":
    demo.launch()
demo-space/requirements.txt
CHANGED

gradio==5.7.1
torch>=2.1.0
transformers>=4.40.0
peft>=0.10.0
accelerate>=0.30.0
spaces>=0.28.0
audioop-lts; python_version >= "3.13"
demo-space/test_heuristic.py
CHANGED

"""Quick local sanity check for the heuristic repair fallback.

Run with::

    python demo-space/test_heuristic.py

Each case must produce a non-empty fix description and a script that
differs from the input.
"""
from __future__ import annotations

import sys
from pathlib import Path

REPO = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO))
sys.path.insert(0, str(REPO / "demo-space"))

from app import _heuristic_repair  # noqa: E402

CASES = [
    {
        "name": "AttributeError + Did you mean",
        "script": (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        "trace": (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
        "expect_in_repaired": "trainer.train()",
        "expect_not_in_repaired": "start_training",
    },
    {
        "name": "ModuleNotFoundError submodule",
        "script": (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "trace": "ModuleNotFoundError: No module named 'torch.legacy'",
        "expect_in_repaired": "import torch",
        "expect_not_in_repaired": "torch.legacy",
    },
    {
        "name": "TypeError + use ... instead",
        "script": (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        "trace": (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
        "expect_in_repaired": "padding=True",
        "expect_not_in_repaired": "pad_to_max_length",
    },
]


def run_one(case: dict) -> bool:
    name = case["name"]
    repaired, description = _heuristic_repair(case["script"], case["trace"])

    ok_changed = repaired != case["script"]
    ok_desc = bool(description)
    ok_in = case["expect_in_repaired"] in repaired
    ok_not = case["expect_not_in_repaired"] not in repaired

    status = "PASS" if (ok_changed and ok_desc and ok_in and ok_not) else "FAIL"
    print(f"[{status}] {name}")
    print(f"  description: {description!r}")
    print(f"  changed? {ok_changed}")
    print(f"  '{case['expect_in_repaired']}' in repaired? {ok_in}")
    print(f"  '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}")
    if status == "FAIL":
        print("  --- repaired script ---")
        print(repaired)
        print("  -----------------------")
    return status == "PASS"


def main() -> int:
    results = [run_one(c) for c in CASES]
    print()
    n_pass = sum(results)
    print(f"summary: {n_pass}/{len(results)} passed")
    return 0 if all(results) else 1


if __name__ == "__main__":
    sys.exit(main())
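The first case above exercises an "AttributeError + Did you mean" rewrite. As a rough standalone sketch of that fix class (the helper `didyoumean_fix` below is hypothetical; the real `_heuristic_repair` lives in `demo-space/app.py` and may work differently), the trace itself names both the wrong attribute and the suggested one:

```python
import re

# Hypothetical sketch of a "Did you mean" fixer: pull the bad attribute
# and the interpreter's suggestion out of the trace, then swap them in
# the script. Not the actual _heuristic_repair implementation.
def didyoumean_fix(script: str, trace: str) -> str:
    m = re.search(r"no attribute '(\w+)'.*Did you mean: '(\w+)'\?", trace)
    if not m:
        return script  # trace doesn't match this error class; leave untouched
    wrong, right = m.group(1), m.group(2)
    return script.replace(wrong, right)


trace = (
    "AttributeError: 'Trainer' object has no attribute 'start_training'. "
    "Did you mean: 'train'?"
)
print(didyoumean_fix("trainer.start_training()", trace))  # -> trainer.train()
```

This reproduces exactly the transformation the first CASES entry asserts (`trainer.train()` present, `start_training` gone), which is why the test treats it as the minimum bar for the heuristic.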
forgeenv-space/Dockerfile
CHANGED

FROM python:3.11-slim

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1

RUN apt-get update \
    && apt-get install -y --no-install-recommends git curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY forgeenv/ forgeenv/
COPY openenv.yaml .

ENV PORT=7860
EXPOSE 7860

HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
    CMD curl -f http://127.0.0.1:7860/health || exit 1

CMD ["uvicorn", "forgeenv.env.server:app", "--host", "0.0.0.0", "--port", "7860"]
forgeenv-space/README.md
CHANGED

---
title: ForgeEnv
emoji: 🔧
colorFrom: indigo
colorTo: green
sdk: docker
app_port: 7860
pinned: true
license: apache-2.0
tags:
- openenv
- self-play
- self-improvement
- code-repair
- schema-drift
- reinforcement-learning
- huggingface
short_description: Self-improving RL env for HF library-drift repair
---

# ForgeEnv — OpenEnv Server

This Space hosts the **ForgeEnv** OpenEnv-compliant environment as a FastAPI
service. It exposes the standard `reset`, `step`, and `state` endpoints and is
the runtime that training notebooks (TRL + Unsloth) connect to.

> **Theme:** Self-Improvement (Hackathon Theme #4) — Challenger / Solver
> co-evolution via R-Zero, SPIRAL, and Absolute Zero Reasoner techniques.

## What it does

ForgeEnv simulates **HuggingFace library version drift**. A *Drift Generator*
proposes a realistic breakage to a working training script (renamed APIs,
deprecated imports, changed argument signatures, etc.). A *Repair Agent* then
emits a unified diff that should restore the script. Reward is computed by an
execution simulator + AST checker + held-out evaluator (multi-component to
resist reward hacking).

## API

The server uses [`openenv-core`](https://pypi.org/project/openenv-core/) and
follows the Gym-style contract:

| Endpoint | Method | Purpose |
| -------- | ------ | -------------------------------------------------- |
| `/reset` | POST | Sample a fresh task, return drift-gen observation |
| `/step` | POST | Apply a `ForgeAction` (breakage or repair) |
| `/state` | GET | Inspect the current internal state |
| `/health`| GET | Health probe (used by the container HEALTHCHECK) |

`ForgeAction` is a discriminated union of `BreakageAction` (used in phase 1)
and `RepairAction` (used in phase 2). See
[`forgeenv/env/actions.py`](forgeenv/env/actions.py).

## Quick test

```bash
curl -X POST https://akhiilll-forgeenv.hf.space/reset
curl https://akhiilll-forgeenv.hf.space/state
```

```python
from openenv.core.env_client import EnvClient

async with EnvClient(base_url="https://akhiilll-forgeenv.hf.space") as client:
    obs = await client.reset()
    print(obs.observation.current_phase, obs.observation.task_id)
```

## Project links

- **Main repo / training notebooks / plots:**
  <https://github.com/akhiilll/forgeenv>
- **Repair Agent model (LoRA):**
  <https://huggingface.co/akhiilll/forgeenv-repair-agent>
- **Demo (Gradio + ZeroGPU):**
  <https://huggingface.co/spaces/akhiilll/forgeenv-demo>

## Citations

- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025)
- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025)
- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025)
- [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) — Reward engineering & shaping
- [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) — Reward engineering for RL in software tasks
forgeenv-space/openenv.yaml
CHANGED

name: forgeenv
version: 0.1.0
description: >
  Self-improving RL environment for HuggingFace ecosystem repair.
  Trains agents to fix broken training scripts under library version drift
  through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL,
  and Absolute Zero Reasoner techniques on top of OpenEnv.
theme: self-improvement
tags:
  - openenv
  - self-play
  - code-repair
  - schema-drift
  - multi-role
  - huggingface
  - reinforcement-learning
environment:
  class: forgeenv.env.forge_environment.ForgeEnvironment
  action_model: forgeenv.env.actions.ForgeAction
  observation_model: forgeenv.env.observations.ForgeObservation
server:
  module: forgeenv.env.server
  app: app
  port: 7860
forgeenv-space/requirements.txt
CHANGED

openenv-core>=0.2.0
fastapi>=0.110.0
uvicorn[standard]>=0.27.0
pydantic>=2.6.0
pyyaml>=6.0
nltk>=3.8.0
scikit-learn>=1.4.0
numpy>=1.26.0
rich>=13.7.0
forgeenv/__init__.py
CHANGED

"""ForgeEnv: Self-improving RL environment for HuggingFace ecosystem repair."""

__version__ = "0.1.0"
__author__ = "akhiilll"
forgeenv/artifacts/repair_library.py
CHANGED

"""Persisted "repair library" — the model's accumulated knowledge of
known breakage -> repair pairs. Curated from successful rollouts during
training. Loaded at inference time as a few-shot prefix when the agent
recognises a familiar error class.
"""
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional


@dataclass
class RepairExample:
    primitive_type: str
    breakage_params: dict[str, Any]
    error_signature: str
    repair_diff: str
    visible_reward: float
    held_out: dict[str, float]
    task_id: str = ""

    def signature_key(self) -> str:
        return f"{self.primitive_type}::{self.error_signature[:80]}"


@dataclass
class RepairLibrary:
    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        self.examples.append(example)

    def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]:
        """Return the highest-reward example whose primitive_type matches and
        whose error text overlaps."""
        candidates = [
            e for e in self.examples if e.primitive_type == primitive_type
        ]
        if not candidates:
            return None
        scored = sorted(
            candidates,
            key=lambda e: (
                _ngram_overlap(e.error_signature, error_text),
                e.visible_reward,
            ),
            reverse=True,
        )
        return scored[0] if scored else None

    def to_dict(self) -> dict:
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        examples = [RepairExample(**e) for e in data.get("examples", [])]
        return cls(examples=examples)


def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
    if not a or not b:
        return 0.0

    def grams(text: str) -> set[str]:
        text = text.lower()
        return {text[i : i + n] for i in range(len(text) - n + 1)}

    ga, gb = grams(a), grams(b)
    if not ga or not gb:
        return 0.0
    return len(ga & gb) / max(1, len(ga | gb))


def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for e in examples:
        counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1
    return counts


def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults."""
    lib = RepairLibrary()
    for r in rollout_results:
        get = r.get if isinstance(r, dict) else lambda k, default=None: getattr(r, k, default)
        if float(get("visible_reward", 0.0) or 0.0) < min_reward:
            continue
        if float(get("held_out_breakdown", {}).get("executed_cleanly", 0.0)) < min_held_out_clean:
            continue
        lib.add(
            RepairExample(
                primitive_type=str(get("primitive_type", "unknown")),
                breakage_params=dict(get("info", {}).get("breakage_spec", {}).get("params", {}))
                if isinstance(get("info", {}), dict)
                else {},
                error_signature=str(get("error_trace", "") or "")[:160],
                repair_diff=str(get("repair_completion", "") or get("info", {}).get("repair_diff", ""))[:2000],
                visible_reward=float(get("visible_reward", 0.0) or 0.0),
                held_out=dict(get("held_out_breakdown", {}) or {}),
                task_id=str(get("task_id", "")),
            )
        )
    return lib
forgeenv/drift/library_drift_engine.py
CHANGED

"""Library Drift Engine.

Manages library version snapshots and triggers version upgrades during
training to create non-stationary verification. In simulation mode it
just tracks the current snapshot index — that index influences
breakage selection and is exposed in observations so the Repair Agent
can adapt.

Also exposes Chojecki GVU's SNR computation
(https://arxiv.org/abs/2512.02731 Definition 4.4).
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field

DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
    {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
    {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
    {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
    {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
]


@dataclass
class LibraryDriftEngine:
    snapshots: list[dict[str, str]] = field(
        default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS)
    )
    current_index: int = 0
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        if (
            episode_num > 0
            and episode_num % drift_every == 0
            and self.current_index < len(self.snapshots) - 1
        ):
            prev = self.snapshots[self.current_index]
            self.current_index += 1
            self.drift_history.append(
                {
                    "episode": episode_num,
                    "from": prev,
                    "to": self.snapshots[self.current_index],
                }
            )
            return True
        return False

    def reset(self) -> None:
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards)."""

        def snr(values: list[float]) -> float:
            if len(values) < 2:
                return 0.0
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            return mean**2 / max(var, 1e-8)

        return {
            "snr_verifier": snr(recent_held_out),
            "snr_generator": snr(recent_visible),
        }
|
|
|
| 1 |
+
"""Library Drift Engine.
|
| 2 |
+
|
| 3 |
+
Manages library version snapshots and triggers version upgrades during
|
| 4 |
+
training to create non-stationary verification. In simulation mode it
|
| 5 |
+
just tracks the current snapshot index — that index influences
|
| 6 |
+
breakage selection and is exposed in observations so the Repair Agent
|
| 7 |
+
can adapt.
|
| 8 |
+
|
| 9 |
+
Also exposes Chojecki GVU's SNR computation
|
| 10 |
+
(https://arxiv.org/abs/2512.02731 Definition 4.4).
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
|
| 17 |
+
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
|
| 18 |
+
{"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
|
| 19 |
+
{"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
|
| 20 |
+
{"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
|
| 21 |
+
{"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class LibraryDriftEngine:
|
| 27 |
+
snapshots: list[dict[str, str]] = field(
|
| 28 |
+
default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS)
|
| 29 |
+
)
|
| 30 |
+
current_index: int = 0
|
| 31 |
+
drift_history: list[dict] = field(default_factory=list)
|
| 32 |
+
|
| 33 |
+
def current_versions(self) -> dict[str, str]:
|
| 34 |
+
return dict(self.snapshots[self.current_index])
|
| 35 |
+
|
| 36 |
+
def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
|
| 37 |
+
if (
|
| 38 |
+
episode_num > 0
|
| 39 |
+
and episode_num % drift_every == 0
|
| 40 |
+
and self.current_index < len(self.snapshots) - 1
|
| 41 |
+
):
|
| 42 |
+
prev = self.snapshots[self.current_index]
|
| 43 |
+
self.current_index += 1
|
| 44 |
+
self.drift_history.append(
|
| 45 |
+
{
|
| 46 |
+
"episode": episode_num,
|
| 47 |
+
"from": prev,
|
| 48 |
+
"to": self.snapshots[self.current_index],
|
| 49 |
+
}
|
| 50 |
+
)
|
| 51 |
+
return True
|
| 52 |
+
return False
|
| 53 |
+
|
| 54 |
+
def reset(self) -> None:
|
| 55 |
+
self.current_index = 0
|
| 56 |
+
self.drift_history.clear()
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def compute_snr(
|
| 60 |
+
recent_held_out: list[float], recent_visible: list[float]
|
| 61 |
+
) -> dict[str, float]:
|
| 62 |
+
"""SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards)."""
|
| 63 |
+
|
| 64 |
+
def snr(values: list[float]) -> float:
|
| 65 |
+
if len(values) < 2:
|
| 66 |
+
return 0.0
|
| 67 |
+
mean = sum(values) / len(values)
|
| 68 |
+
var = sum((v - mean) ** 2 for v in values) / len(values)
|
| 69 |
+
return mean**2 / max(var, 1e-8)
|
| 70 |
+
|
| 71 |
+
return {
|
| 72 |
+
"snr_verifier": snr(recent_held_out),
|
| 73 |
+
"snr_generator": snr(recent_visible),
|
| 74 |
+
}
|
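The SNR formula in `compute_snr` is easy to sanity-check in isolation. Below is a standalone restatement of the inner `snr` helper (same math as the listing above, not an import of the module):

```python
def snr(values: list[float]) -> float:
    # SNR per Chojecki GVU Def 4.4: mean(rewards)^2 / variance(rewards),
    # with an epsilon floor on the variance to avoid division by zero.
    if len(values) < 2:
        return 0.0  # too few samples to estimate variance
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean**2 / max(var, 1e-8)

print(snr([0.0, 1.0]))              # mean 0.5, var 0.25 -> 1.0
print(round(snr([1.0, 1.0, 1.0])))  # zero variance hits the 1e-8 floor -> ~1e8
print(snr([5.0]))                   # fewer than 2 samples -> 0.0
```

Note the degenerate case: a constant reward stream drives the variance to the epsilon floor, so an enormous SNR value flags a zero-variance window rather than a genuinely strong signal.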
forgeenv/env/actions.py
CHANGED

"""Pydantic action models for ForgeEnv (compatible with OpenEnv 0.2.x).

Episodes have two phases — drift_gen (Challenger) and repair (Solver) — so
we expose a single union ForgeAction that carries either a BreakageAction
or a RepairAction. The environment dispatches on which sub-field is set.
"""
from __future__ import annotations

from typing import Any, Literal, Optional

from pydantic import Field

from openenv.core import Action


class BreakageAction(Action):
    """Drift Generator's action: pick a primitive type + parameters."""

    action_type: Literal["breakage"] = "breakage"
    primitive_type: str = Field(
        ..., description="One of the registered breakage primitive class names"
    )
    params: dict[str, Any] = Field(
        default_factory=dict, description="Primitive-specific parameters"
    )


class RepairAction(Action):
    """Repair Agent's action: a unified diff (or full replacement script)."""

    action_type: Literal["repair"] = "repair"
    unified_diff: str = Field(..., description="Unified diff or full replacement script")


class ForgeAction(Action):
    """Union action: exactly one of `breakage` / `repair` must be set.

    This is the type registered with OpenEnv's `create_app`. It avoids
    Pydantic discriminated unions to keep the OpenAPI schema flat and
    cross-version-friendly.
    """

    breakage: Optional[BreakageAction] = None
    repair: Optional[RepairAction] = None

    def model_post_init(self, __context: Any) -> None:
        if (self.breakage is None) == (self.repair is None):
            raise ValueError(
                "ForgeAction requires exactly one of `breakage` or `repair` to be set."
            )
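The `model_post_init` guard enforces an exclusive-or between the two sub-actions: setting both fields, or neither, is rejected. A dependency-free sketch of the same invariant, using plain `dataclasses` in place of Pydantic (`ForgeActionSketch` is a hypothetical stand-in, not part of the package):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ForgeActionSketch:
    breakage: Optional[Any] = None
    repair: Optional[Any] = None

    def __post_init__(self) -> None:
        # Same XOR check as ForgeAction.model_post_init: both-set and
        # both-None compare equal, and both are invalid.
        if (self.breakage is None) == (self.repair is None):
            raise ValueError("exactly one of `breakage` or `repair` must be set")


ForgeActionSketch(breakage={"primitive_type": "api_drift", "params": {}})  # accepted
try:
    ForgeActionSketch()  # neither field set -> rejected
except ValueError as exc:
    print(exc)
```

The single `(a is None) == (b is None)` comparison covers both failure modes in one branch, which is why the real model needs no discriminated union.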
forgeenv/env/diff_utils.py
CHANGED

"""Unified-diff application utilities.

The Repair Agent submits a unified diff. We need a permissive applier
because LLM diffs are often malformed (wrong line numbers, missing
context, extra prose). We try the strict applier first, then fall
back to applying hunks via plain string replacement.

The agent may also submit a full Python script instead of a diff
(common when the model's diff format breaks). We detect this and
treat it as a complete replacement.
"""
from __future__ import annotations

import difflib
import re


_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff."""
    lines = text.lstrip().splitlines()
    if not lines:
        return False
    has_diff_header = any(
        line.startswith(("---", "+++", "@@")) for line in lines[:5]
    )
    if has_diff_header:
        return False
    # If we see two or more script-style markers in the first 30 lines,
    # treat as a full replacement script.
    head = "\n".join(lines[:30])
    hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head)
    return hits >= 2


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure."""
    lines = broken_script.splitlines(keepends=True)
    out: list[str] = []
    diff_lines = diff_text.splitlines()
    i = 0
    src_idx = 0
    in_hunk = False
    hunk_old: list[str] = []
    hunk_new: list[str] = []

    while i < len(diff_lines):
        line = diff_lines[i]
        if line.startswith(("---", "+++")):
            i += 1
            continue
        if line.startswith("@@"):
            # Flush previous hunk
            if in_hunk:
                # Find the hunk_old block in the source starting at src_idx.
                target = "".join(hunk_old)
                source_remainder = "".join(lines[src_idx:])
                pos = source_remainder.find(target)
                if pos == -1:
                    return None
                out.append(source_remainder[:pos])
                out.append("".join(hunk_new))
                src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True))
                hunk_old, hunk_new = [], []
            in_hunk = True
            i += 1
            continue
        if in_hunk:
            if line.startswith("+"):
                hunk_new.append(line[1:] + "\n")
            elif line.startswith("-"):
                hunk_old.append(line[1:] + "\n")
            else:
                # context line
                ctx = line[1:] if line.startswith(" ") else line
                hunk_old.append(ctx + "\n")
                hunk_new.append(ctx + "\n")
        i += 1

    # Flush trailing hunk
    if in_hunk and (hunk_old or hunk_new):
        target = "".join(hunk_old)
        source_remainder = "".join(lines[src_idx:])
        pos = source_remainder.find(target)
        if pos == -1:
            return None
        out.append(source_remainder[:pos])
        out.append("".join(hunk_new))
        consumed = source_remainder[: pos + len(target)]
        src_idx += len(consumed.splitlines(keepends=True))

    out.append("".join(lines[src_idx:]))
    return "".join(out)


def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Apply a malformed diff by extracting (-,+) line pairs and doing
    a tolerant search-and-replace.
    """
    repaired = broken_script
    pairs: list[tuple[str, str]] = []
    lines = diff_text.splitlines()
    pending_minus: str | None = None

    for line in lines:
        if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
            pending_minus = None
            continue
        if line.startswith("-"):
            pending_minus = line[1:].strip()
        elif line.startswith("+") and pending_minus is not None:
            pairs.append((pending_minus, line[1:].strip()))
            pending_minus = None
        elif pending_minus is not None and not line.startswith(" "):
            # standalone deletion — skip in permissive mode (we can't
            # reliably know what to delete without context)
            pending_minus = None

    for old, new in pairs:
        if old and old in repaired:
            repaired = repaired.replace(old, new, 1)

    return repaired


def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Try every strategy in order and return the first that produces a change.

    Strategies:
      1. If `diff_text` looks like a full script, return it directly.
      2. Try strict diff application.
      3. Fall back to permissive (-,+) line-pair replacement.
      4. As last resort, return the broken script unchanged.
    """
    diff_text = diff_text or ""
    if not diff_text.strip():
        return broken_script

    if looks_like_full_script(diff_text):
        return diff_text

    if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text:
        strict = _strict_apply(broken_script, diff_text)
        if strict is not None and strict != broken_script:
            return strict

    perm = _permissive_apply(broken_script, diff_text)
    return perm


def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after."""
    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    return "".join(diff)
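The permissive fallback is the piece that makes malformed LLM diffs usable at all. Here is a condensed standalone restatement of the (-,+) pairing logic from `_permissive_apply` (it omits the context-line reset branch for brevity, so it is a sketch of the idea rather than the exact function):

```python
def permissive_apply(broken: str, diff_text: str) -> str:
    # Pair each '-' line with the next '+' line, ignoring headers and
    # hunk markers, then apply each pair as a one-shot string replace.
    pairs: list[tuple[str, str]] = []
    pending: str | None = None
    for line in diff_text.splitlines():
        if line.startswith(("---", "+++", "@@")):
            pending = None
        elif line.startswith("-"):
            pending = line[1:].strip()
        elif line.startswith("+") and pending is not None:
            pairs.append((pending, line[1:].strip()))
            pending = None
    for old, new in pairs:
        if old and old in broken:
            broken = broken.replace(old, new, 1)
    return broken


src = "model.fit(data)\n"
diff = "-model.fit(data)\n+model.train(data)\n"
print(permissive_apply(src, diff))  # -> model.train(data)
```

Because matching is by content rather than by line number, a diff with wrong `@@` offsets (or none) still applies, at the cost of only handling replacement-style edits.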
forgeenv/env/forge_environment.py
CHANGED

"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.

Episode flow (exactly 2 steps per episode):
  reset()              -> sample task, ask Teacher for category
  step(BreakageAction) -> Drift Generator's proposal is applied; broken
                          script is run, error trace captured.
  step(RepairAction)   -> Repair diff is applied; script is re-executed;
                          visible + held-out rewards computed; episode ends.
"""
from __future__ import annotations

import time
import uuid
from typing import Any, Optional

from openenv.core import Environment

from forgeenv.drift.library_drift_engine import LibraryDriftEngine
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.env.observations import ForgeObservation
from forgeenv.primitives.breakage_primitives import (
    PRIMITIVE_REGISTRY,
    parse_breakage_spec,
)
from forgeenv.roles.teacher import Teacher
from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.tasks.task_sampler import TaskSampler
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
from forgeenv.verifier.visible_verifier import compute_visible_reward

DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys())


class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair."""

    SUPPORTS_CONCURRENT_SESSIONS = False  # Teacher state is global per env

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()

        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    # ------------------------------------------------------------------ API
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")
        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Library drift trigger every 50 episodes (configurable from outside).
        drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50)

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        if self._current_phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if self._current_phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {self._current_phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    # ---------------------------------------------------------------- helpers
    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)}
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:  # primitive bug — surface but don't crash server
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._breakage_spec = spec

        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage didn't actually break it; still proceed to repair phase
            # (no-op repair is then a valid choice).
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"

        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "")

        t0 = time.time()
        result = self.executor.execute(repaired, self._current_task)
        result.script_content = repaired  # ensure verifier sees what we ran
        wall_ms = int((time.time() - t0) * 1000)

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=repair.unified_diff or ""
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        # Update Teacher's curriculum state
        self.teacher.update(category, success)

        self._current_phase = "done"

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]),
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising."""
        return ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id if self._current_task else "",
            task_description=self._current_task.description if self._current_task else "",
            target_category=self._target_category,
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
|
|
|
"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.

Episode flow (exactly 2 steps per episode):
    reset() -> sample task, ask Teacher for category
    step(BreakageAction) -> Drift Generator's proposal is applied; broken
        script is run, error trace captured.
    step(RepairAction) -> Repair diff is applied; script is re-executed;
        visible + held-out rewards computed; episode ends.
"""
from __future__ import annotations

import time
import uuid
from typing import Any, Optional

from openenv.core import Environment

from forgeenv.drift.library_drift_engine import LibraryDriftEngine
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.env.observations import ForgeObservation
from forgeenv.primitives.breakage_primitives import (
    PRIMITIVE_REGISTRY,
    parse_breakage_spec,
)
from forgeenv.roles.teacher import Teacher
from forgeenv.sandbox.simulation_mode import SimulationExecutor
from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.tasks.task_sampler import TaskSampler
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
from forgeenv.verifier.visible_verifier import compute_visible_reward

DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys())


class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair."""

    SUPPORTS_CONCURRENT_SESSIONS = False  # Teacher state is global per env

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()

        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    # ------------------------------------------------------------------ API
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")
        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Library drift trigger every 50 episodes (configurable from outside).
        drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50)

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        if self._current_phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if self._current_phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {self._current_phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    # ---------------------------------------------------------------- helpers
    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)}
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:  # primitive bug — surface but don't crash server
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._breakage_spec = spec

        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage didn't actually break it; still proceed to repair phase
            # (no-op repair is then a valid choice).
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"

        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "")

        t0 = time.time()
        result = self.executor.execute(repaired, self._current_task)
        result.script_content = repaired  # ensure verifier sees what we ran
        wall_ms = int((time.time() - t0) * 1000)

        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=repair.unified_diff or ""
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        # Update Teacher's curriculum state
        self.teacher.update(category, success)

        self._current_phase = "done"

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]),
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising."""
        return ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id if self._current_task else "",
            task_description=self._current_task.description if self._current_task else "",
            target_category=self._target_category,
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
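The reset()/step() dispatch above is a two-step phase machine: drift_gen, then repair, then done. A minimal stand-alone sketch of just the transitions (PhaseMachine is a toy stand-in, not the project's class; the executor, Teacher, and verifiers are omitted):

```python
class PhaseMachine:
    """Toy model of ForgeEnvironment's episode phases."""

    def __init__(self) -> None:
        self.phase = "idle"

    def reset(self) -> str:
        self.phase = "drift_gen"
        return self.phase

    def step(self, action_kind: str) -> str:
        if self.phase == "drift_gen" and action_kind == "breakage":
            self.phase = "repair"  # broken script handed to the Repair Agent
        elif self.phase == "repair" and action_kind == "repair":
            self.phase = "done"    # rewards computed, episode ends
        else:
            self.phase = "done"    # wrong action or phase -> error obs, done
        return self.phase
```

An action sent outside its expected phase ends the episode immediately, mirroring how `_error_obs` returns a `done=True` observation instead of raising.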
forgeenv/env/observations.py
CHANGED
"""Pydantic observation model for ForgeEnv."""
from __future__ import annotations

from typing import Any, Optional

from pydantic import Field

from openenv.core import Observation


class ForgeObservation(Observation):
    """What the agent (or the trainer's rollout function) sees at each step.

    Inherits `done`, `reward`, `metadata` from the OpenEnv `Observation` base.
    """

    current_phase: str = Field(
        ..., description="One of 'drift_gen', 'repair', 'verify', 'done'"
    )
    task_id: str = ""
    task_description: str = ""
    target_category: str = ""
    script_content: str = Field(default="", description="Current state of the script")
    error_trace: Optional[str] = None
    library_versions: dict[str, str] = Field(default_factory=dict)
    reward_breakdown: dict[str, Any] = Field(default_factory=dict)
    held_out_breakdown: dict[str, float] = Field(default_factory=dict)
    episode_step: int = 0
    info: dict[str, Any] = Field(default_factory=dict)
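Because every field except `current_phase` has a default, callers can construct partial observations. A rough stdlib-only stand-in (ObsSketch is hypothetical; it skips pydantic validation and the inherited `done`/`reward` fields):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ObsSketch:
    current_phase: str  # required, like the real model's `Field(...)`
    task_id: str = ""
    error_trace: Optional[str] = None
    library_versions: dict = field(default_factory=dict)
    episode_step: int = 0


# Only the required field needs to be supplied.
obs = ObsSketch(current_phase="drift_gen")
```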
forgeenv/env/server.py
CHANGED
"""FastAPI server for ForgeEnv (OpenEnv-compliant).

Exposes /reset, /step, /state HTTP endpoints via OpenEnv's `create_app`.
HF Spaces sets PORT=7860 automatically.
"""
from __future__ import annotations

import os

from fastapi.responses import HTMLResponse
from openenv.core import create_app

from forgeenv.env.actions import ForgeAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.env.observations import ForgeObservation

app = create_app(
    env=ForgeEnvironment,
    action_cls=ForgeAction,
    observation_cls=ForgeObservation,
    env_name="forgeenv",
)


_LANDING_HTML = """<!doctype html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>ForgeEnv — OpenEnv server</title>
  <meta name="viewport" content="width=device-width,initial-scale=1">
  <style>
    :root { color-scheme: light dark; }
    body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
           max-width: 760px; margin: 2.5rem auto; padding: 0 1.25rem;
           line-height: 1.55; color: #1f2937; background: #fafafa; }
    @media (prefers-color-scheme: dark) { body { color: #e5e7eb; background: #0f172a; } }
    h1 { font-size: 1.65rem; margin-bottom: 0.25rem; }
    .sub { color: #6b7280; margin-top: 0; }
    code, pre { font-family: ui-monospace, "SF Mono", Menlo, monospace; }
    pre { background: rgba(127,127,127,0.12); padding: 0.9rem; border-radius: 8px;
          overflow-x: auto; }
    table { border-collapse: collapse; width: 100%; margin: 0.75rem 0 1.25rem; }
    td, th { text-align: left; padding: 0.5rem 0.75rem;
             border-bottom: 1px solid rgba(127,127,127,0.25); }
    th { font-weight: 600; }
    a { color: #2563eb; text-decoration: none; } a:hover { text-decoration: underline; }
    .ok { color: #16a34a; font-weight: 600; }
    .muted { color: #6b7280; font-size: 0.9rem; }
    .pill { display: inline-block; padding: 0.1rem 0.5rem; border-radius: 999px;
            background: rgba(34,197,94,0.15); color: #16a34a; font-size: 0.85rem; }
  </style>
</head>
<body>
  <h1>ForgeEnv 🔧 <span class="pill">running</span></h1>
  <p class="sub">OpenEnv-compliant RL environment for HuggingFace
     ecosystem repair under library version drift.</p>

  <p>This URL serves the environment over HTTP. It is not a UI — it's the
     runtime that <strong>training notebooks connect to</strong>. Open one of
     the endpoints below, or use the demo Space to try the trained Repair
     Agent in a browser.</p>

  <h2>Endpoints</h2>
  <table>
    <tr><th>Method</th><th>Path</th><th>Purpose</th></tr>
    <tr><td>GET </td><td><a href="/health">/health</a></td><td>Health probe</td></tr>
    <tr><td>POST</td><td><code>/reset</code></td><td>Sample task, return drift-gen observation</td></tr>
    <tr><td>POST</td><td><code>/step</code></td><td>Apply <code>ForgeAction</code> (breakage or repair)</td></tr>
    <tr><td>GET </td><td><a href="/state">/state</a></td><td>Current internal state</td></tr>
    <tr><td>GET </td><td><a href="/metadata">/metadata</a></td><td>Env name + version + schema URLs</td></tr>
    <tr><td>GET </td><td><a href="/schema">/schema</a></td><td>Action / observation JSON schemas</td></tr>
    <tr><td>GET </td><td><a href="/docs">/docs</a></td><td>Interactive Swagger UI</td></tr>
  </table>

  <h2>Quick start (Python)</h2>
  <pre><code>import asyncio
from openenv.core import GenericEnvClient

async def go():
    client = GenericEnvClient(base_url="https://akhiilll-forgeenv.hf.space")
    obs = await client.reset()
    print(obs.observation["current_phase"], obs.observation["task_id"])

asyncio.run(go())</code></pre>

  <h2>Project links</h2>
  <ul>
    <li>Space card & README:
        <a href="https://huggingface.co/spaces/akhiilll/forgeenv" target="_blank" rel="noopener noreferrer">huggingface.co/spaces/akhiilll/forgeenv</a></li>
    <li>Gradio demo:
        <a href="https://huggingface.co/spaces/akhiilll/forgeenv-demo" target="_blank" rel="noopener noreferrer">huggingface.co/spaces/akhiilll/forgeenv-demo</a></li>
    <li>Trained model (LoRA) <span class="muted">— published after the Colab training run finishes</span>:
        <a href="https://huggingface.co/akhiilll/forgeenv-repair-agent" target="_blank" rel="noopener noreferrer">huggingface.co/akhiilll/forgeenv-repair-agent</a></li>
  </ul>
  <p class="muted">Tip: if links don't open from inside the embedded Space frame,
     right-click and choose <em>Open in new tab</em>, or open this URL directly
     at <a href="https://akhiilll-forgeenv.hf.space/" target="_blank" rel="noopener noreferrer">akhiilll-forgeenv.hf.space</a>.</p>
</body>
</html>"""


def _attach_supplementary_routes(_app) -> None:
    """Add /health and a friendly GET / landing page if not present."""
    existing = {
        getattr(r, "path", None) for r in getattr(_app, "routes", [])
    }

    if "/health" not in existing:
        @_app.get("/health")
        def _health() -> dict:
            return {"status": "ok", "env": "forgeenv"}

    if "/" not in existing:
        @_app.get("/", response_class=HTMLResponse, include_in_schema=False)
        def _root() -> str:
            return _LANDING_HTML


_attach_supplementary_routes(app)


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
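`_attach_supplementary_routes` registers a route only when its path is absent, so it is safe whether or not `create_app` already added `/` or `/health`. The guard can be sketched without FastAPI (FakeApp and attach_if_missing are hypothetical stand-ins; a real FastAPI app exposes the same `.routes` objects with a `.path` attribute):

```python
from types import SimpleNamespace


class FakeApp:
    """Minimal stand-in exposing a .routes list like a FastAPI app."""

    def __init__(self, paths):
        self.routes = [SimpleNamespace(path=p) for p in paths]

    def add_route(self, path):
        self.routes.append(SimpleNamespace(path=path))


def attach_if_missing(app, wanted):
    # Mirror the server's guard: collect existing paths, register the rest.
    existing = {getattr(r, "path", None) for r in getattr(app, "routes", [])}
    added = []
    for path in wanted:
        if path not in existing:
            app.add_route(path)
            added.append(path)
    return added


app = FakeApp(["/health", "/reset", "/step"])
new_paths = attach_if_missing(app, ["/health", "/"])  # only "/" is missing
```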
forgeenv/primitives/breakage_primitives.py
CHANGED
"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.

Each primitive transforms a working script to simulate a library upgrade
breakage. They double as the Drift Generator's structured action space.
"""
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class BreakagePrimitive(ABC):
    """Abstract base class for all breakage types."""

    category: str = field(default="generic", init=False)
    name: str = field(default="BreakagePrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to introduce the breakage."""

    def to_spec(self) -> dict:
        """Serialize to JSON-compatible spec for the LLM action space."""
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return a JSON-serializable dict of constructor parameters."""


@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation."""

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        if not self.old_name:
            return script
        # Use word-boundary replacement so we don't substring-match identifiers.
        pattern = re.compile(rf"(?<!\w){re.escape(self.old_name)}(?!\w)")
        return pattern.sub(self.new_name, script)

    def _get_params(self) -> dict:
        return {"old_name": self.old_name, "new_name": self.new_name}


@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring."""

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        if not self.old_module:
            return script
        return script.replace(self.old_module, self.new_module)

    def _get_params(self) -> dict:
        return {"old_module": self.old_module, "new_module": self.new_module}


@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one)."""

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        if not self.removed_arg:
            return script
        pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)"
        return re.sub(pattern, "", script)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }


@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift."""

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        if not self.field_name:
            return script
        pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, rf"\g<1>{self.new_value}", script)

    def _get_params(self) -> dict:
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }


@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a dataset column reference to simulate schema drift."""

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        if not self.old_column:
            return script
        return script.replace(
            f'"{self.old_column}"', f'"{self.new_column}"'
        ).replace(
            f"'{self.old_column}'", f"'{self.new_column}'"
        )

    def _get_params(self) -> dict:
        return {"old_column": self.old_column, "new_column": self.new_column}


@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments."""

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}"

    def apply(self, script: str) -> str:
        if not self.old_kwarg:
            return script
        pattern = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}"
        replacement = f"{self.new_kwarg}={self.new_value}"
        return re.sub(pattern, replacement, script)

    def _get_params(self) -> dict:
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
|
| 190 |
-
"new_kwarg": self.new_kwarg,
|
| 191 |
-
"new_value": self.new_value,
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
@dataclass
|
| 196 |
-
class RemoveDeprecatedMethod(BreakagePrimitive):
|
| 197 |
-
"""Remove a method that has been deprecated, leaving a sentinel that
|
| 198 |
-
raises AttributeError-style errors when the script runs."""
|
| 199 |
-
|
| 200 |
-
class_name: str = ""
|
| 201 |
-
method_name: str = ""
|
| 202 |
-
replacement: str = ""
|
| 203 |
-
|
| 204 |
-
def __post_init__(self) -> None:
|
| 205 |
-
self.category = "api_drift"
|
| 206 |
-
self.name = "RemoveDeprecatedMethod"
|
| 207 |
-
self.description = f"Remove {self.class_name}.{self.method_name}"
|
| 208 |
-
|
| 209 |
-
def apply(self, script: str) -> str:
|
| 210 |
-
if not self.method_name:
|
| 211 |
-
return script
|
| 212 |
-
return script.replace(
|
| 213 |
-
f".{self.method_name}(", f".{self.method_name}_DEPRECATED("
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
def _get_params(self) -> dict:
|
| 217 |
-
return {
|
| 218 |
-
"class_name": self.class_name,
|
| 219 |
-
"method_name": self.method_name,
|
| 220 |
-
"replacement": self.replacement,
|
| 221 |
-
}
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
@dataclass
|
| 225 |
-
class ChangeReturnType(BreakagePrimitive):
|
| 226 |
-
"""A function now returns a different structure (e.g. tuple -> object)."""
|
| 227 |
-
|
| 228 |
-
function_name: str = ""
|
| 229 |
-
old_access: str = ""
|
| 230 |
-
new_access: str = ""
|
| 231 |
-
|
| 232 |
-
def __post_init__(self) -> None:
|
| 233 |
-
self.category = "api_drift"
|
| 234 |
-
self.name = "ChangeReturnType"
|
| 235 |
-
self.description = f"Change return type of {self.function_name}"
|
| 236 |
-
|
| 237 |
-
def apply(self, script: str) -> str:
|
| 238 |
-
if self.old_access and self.new_access:
|
| 239 |
-
return script.replace(self.old_access, self.new_access)
|
| 240 |
-
return script
|
| 241 |
-
|
| 242 |
-
def _get_params(self) -> dict:
|
| 243 |
-
return {
|
| 244 |
-
"function_name": self.function_name,
|
| 245 |
-
"old_access": self.old_access,
|
| 246 |
-
"new_access": self.new_access,
|
| 247 |
-
}
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
|
| 251 |
-
"RenameApiCall": RenameApiCall,
|
| 252 |
-
"DeprecateImport": DeprecateImport,
|
| 253 |
-
"ChangeArgumentSignature": ChangeArgumentSignature,
|
| 254 |
-
"ModifyConfigField": ModifyConfigField,
|
| 255 |
-
"RestructureDatasetSchema": RestructureDatasetSchema,
|
| 256 |
-
"ChangeTokenizerBehavior": ChangeTokenizerBehavior,
|
| 257 |
-
"RemoveDeprecatedMethod": RemoveDeprecatedMethod,
|
| 258 |
-
"ChangeReturnType": ChangeReturnType,
|
| 259 |
-
}
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
|
| 263 |
-
"""Parse a JSON breakage spec into a BreakagePrimitive object.
|
| 264 |
-
|
| 265 |
-
Tolerates extra keys; ignores unknown params (LLMs hallucinate these).
|
| 266 |
-
"""
|
| 267 |
-
ptype = spec.get("primitive_type", "")
|
| 268 |
-
params = spec.get("params", {}) or {}
|
| 269 |
-
|
| 270 |
-
if ptype not in PRIMITIVE_REGISTRY:
|
| 271 |
-
raise ValueError(
|
| 272 |
-
f"Unknown primitive type: {ptype!r}. "
|
| 273 |
-
f"Valid types: {list(PRIMITIVE_REGISTRY)}"
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
cls = PRIMITIVE_REGISTRY[ptype]
|
| 277 |
-
# Filter to known fields only so a hallucinated kwarg can't crash us.
|
| 278 |
-
valid_fields = {
|
| 279 |
-
f.name for f in cls.__dataclass_fields__.values() if f.init # type: ignore[attr-defined]
|
| 280 |
-
}
|
| 281 |
-
filtered = {k: v for k, v in params.items() if k in valid_fields}
|
| 282 |
-
return cls(**filtered)
|
|
|
|
"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.

Each primitive transforms a working script to simulate a library upgrade
breakage. They double as the Drift Generator's structured action space.
"""
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class BreakagePrimitive(ABC):
    """Abstract base class for all breakage types."""

    category: str = field(default="generic", init=False)
    name: str = field(default="BreakagePrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to introduce the breakage."""

    def to_spec(self) -> dict:
        """Serialize to JSON-compatible spec for the LLM action space."""
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return a JSON-serializable dict of constructor parameters."""


@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation."""

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        if not self.old_name:
            return script
        # Use word-boundary replacement so we don't substring-match identifiers.
        pattern = re.compile(rf"(?<!\w){re.escape(self.old_name)}(?!\w)")
        return pattern.sub(self.new_name, script)

    def _get_params(self) -> dict:
        return {"old_name": self.old_name, "new_name": self.new_name}


@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring."""

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        if not self.old_module:
            return script
        return script.replace(self.old_module, self.new_module)

    def _get_params(self) -> dict:
        return {"old_module": self.old_module, "new_module": self.new_module}


@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one)."""

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        if not self.removed_arg:
            return script
        pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)"
        return re.sub(pattern, "", script)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }


@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift."""

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        if not self.field_name:
            return script
        pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, rf"\g<1>{self.new_value}", script)

    def _get_params(self) -> dict:
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }


@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a dataset column reference to simulate schema drift."""

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        if not self.old_column:
            return script
        return script.replace(
            f'"{self.old_column}"', f'"{self.new_column}"'
        ).replace(
            f"'{self.old_column}'", f"'{self.new_column}'"
        )

    def _get_params(self) -> dict:
        return {"old_column": self.old_column, "new_column": self.new_column}


@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments."""

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}"

    def apply(self, script: str) -> str:
        if not self.old_kwarg:
            return script
        pattern = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}"
        replacement = f"{self.new_kwarg}={self.new_value}"
        return re.sub(pattern, replacement, script)

    def _get_params(self) -> dict:
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
        }


@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Remove a method that has been deprecated, leaving a sentinel that
    raises AttributeError-style errors when the script runs."""

    class_name: str = ""
    method_name: str = ""
    replacement: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RemoveDeprecatedMethod"
        self.description = f"Remove {self.class_name}.{self.method_name}"

    def apply(self, script: str) -> str:
        if not self.method_name:
            return script
        return script.replace(
            f".{self.method_name}(", f".{self.method_name}_DEPRECATED("
        )

    def _get_params(self) -> dict:
        return {
            "class_name": self.class_name,
            "method_name": self.method_name,
            "replacement": self.replacement,
        }


@dataclass
class ChangeReturnType(BreakagePrimitive):
    """A function now returns a different structure (e.g. tuple -> object)."""

    function_name: str = ""
    old_access: str = ""
    new_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeReturnType"
        self.description = f"Change return type of {self.function_name}"

    def apply(self, script: str) -> str:
        if self.old_access and self.new_access:
            return script.replace(self.old_access, self.new_access)
        return script

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "old_access": self.old_access,
            "new_access": self.new_access,
        }


PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
    "RenameApiCall": RenameApiCall,
    "DeprecateImport": DeprecateImport,
    "ChangeArgumentSignature": ChangeArgumentSignature,
    "ModifyConfigField": ModifyConfigField,
    "RestructureDatasetSchema": RestructureDatasetSchema,
    "ChangeTokenizerBehavior": ChangeTokenizerBehavior,
    "RemoveDeprecatedMethod": RemoveDeprecatedMethod,
    "ChangeReturnType": ChangeReturnType,
}


def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    Tolerates extra keys; ignores unknown params (LLMs hallucinate these).
    """
    ptype = spec.get("primitive_type", "")
    params = spec.get("params", {}) or {}

    if ptype not in PRIMITIVE_REGISTRY:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    cls = PRIMITIVE_REGISTRY[ptype]
    # Filter to known fields only so a hallucinated kwarg can't crash us.
    valid_fields = {
        f.name for f in cls.__dataclass_fields__.values() if f.init  # type: ignore[attr-defined]
    }
    filtered = {k: v for k, v in params.items() if k in valid_fields}
    return cls(**filtered)
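As a sanity check on the two central regex transforms in the file above (the word-boundary rename in RenameApiCall and the kwarg-stripping pattern in ChangeArgumentSignature), here is a minimal self-contained sketch applying them to a toy one-line script. The identifiers and values are illustrative, not taken from the seed corpus:

```python
import re

# Toy training script standing in for a seed-corpus task.
script = 'args = TrainingArguments(evaluation_strategy="epoch", save_strategy="epoch")'

# RenameApiCall-style drift: the lookaround guards mean only whole
# identifiers match, so "evaluation_strategy" is renamed without touching
# names that merely contain the substring.
old_name, new_name = "evaluation_strategy", "eval_strategy"
pattern = re.compile(rf"(?<!\w){re.escape(old_name)}(?!\w)")
broken = pattern.sub(new_name, script)
# -> args = TrainingArguments(eval_strategy="epoch", save_strategy="epoch")

# ChangeArgumentSignature-style drift: strip a kwarg and its value.
# Note this can leave a trailing comma before ")", which Python accepts.
removed = "save_strategy"
broken2 = re.sub(rf"(\b{re.escape(removed)}\s*=\s*[^,)]+,?\s*)", "", broken)
print(broken2)
```

The lookbehind/lookahead pair `(?<!\w)…(?!\w)` is what keeps a rename like `no_cuda -> use_cpu` from corrupting an unrelated identifier such as `no_cuda_fallback`.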
forgeenv/primitives/drift_taxonomy.yaml CHANGED
@@ -1,217 +1,217 @@

# Drift taxonomy: real HuggingFace/PyTorch breakages observed across version bumps.
# Used to seed the Drift Generator's initial proposal distribution and to anchor
# warm-start pair generation in things that actually happened in the wild.
- version_range: "transformers 4.36 -> 4.45"
  affected_api: "Trainer.evaluate"
  description: "Trainer.evaluate() return type changed shape; metrics now nested under .metrics"
  breakage_primitive: "ChangeReturnType"
  params:
    function_name: "evaluate"
    old_access: "trainer.evaluate()"
    new_access: "trainer.evaluate().metrics"
  repair_primitive: "RestoreReturnAccess"
  category: "api_drift"

- version_range: "transformers 4.30 -> 4.40"
  affected_api: "TrainingArguments.evaluation_strategy"
  description: "Renamed evaluation_strategy -> eval_strategy"
  breakage_primitive: "RenameApiCall"
  params:
    old_name: "evaluation_strategy"
    new_name: "eval_strategy"
  repair_primitive: "RestoreApiCall"
  category: "api_drift"

- version_range: "datasets 2.14 -> 3.0"
  affected_api: "load_dataset"
  description: "Default split column was renamed in some GLUE configs"
  breakage_primitive: "RestructureDatasetSchema"
  params:
    old_column: "label"
    new_column: "labels"
  repair_primitive: "RestoreColumn"
  category: "dataset_drift"

- version_range: "transformers 4.40 -> 4.50"
  affected_api: "Trainer.predict"
  description: "Method removed; users should use evaluate() with prediction_loss_only=False"
  breakage_primitive: "RemoveDeprecatedMethod"
  params:
    class_name: "Trainer"
    method_name: "predict"
    replacement: "evaluate"
  repair_primitive: "RestoreMethod"
  category: "api_drift"

- version_range: "transformers 4.36 -> 4.40"
  affected_api: "TrainingArguments"
  description: "num_train_epochs default behavior changed; max_steps now preferred"
  breakage_primitive: "ModifyConfigField"
  params:
    config_class: "TrainingArguments"
    field_name: "num_train_epochs"
    new_value: "0"
  repair_primitive: "RestoreConfigField"
  category: "config_drift"

- version_range: "transformers 4.34 -> 4.42"
  affected_api: "Tokenizer.__call__"
  description: "padding=True semantics changed; users should pass padding='max_length'"
  breakage_primitive: "ChangeTokenizerBehavior"
  params:
    old_kwarg: "padding"
    old_value: "True"
    new_kwarg: "padding"
    new_value: '"max_length"'
  repair_primitive: "RestoreTokenizerKwarg"
  category: "tokenizer_drift"

- version_range: "transformers 4.20 -> 4.30"
  affected_api: "imports"
  description: "transformers.training_args moved to transformers.training_args_pt"
  breakage_primitive: "DeprecateImport"
  params:
    old_module: "from transformers.training_args"
    new_module: "from transformers.training_args_pt"
  repair_primitive: "RestoreImport"
  category: "import_drift"

- version_range: "transformers 4.45 -> 4.50"
  affected_api: "save_pretrained"
  description: "save_pretrained() now requires safe_serialization to default True"
  breakage_primitive: "ChangeArgumentSignature"
  params:
    function_name: "save_pretrained"
    removed_arg: "safe_serialization"
    added_arg: "safe_serialization"
    added_value: "True"
  repair_primitive: "RestoreArgument"
  category: "api_drift"

- version_range: "datasets 2.18 -> 3.0"
  affected_api: "Dataset.set_format"
  description: "set_format(type='torch') signature stricter, columns required"
  breakage_primitive: "ChangeArgumentSignature"
  params:
    function_name: "set_format"
    removed_arg: "columns"
    added_arg: "columns"
    added_value: '["input_ids", "attention_mask", "labels"]'
  repair_primitive: "RestoreArgument"
  category: "api_drift"

- version_range: "transformers 4.36 -> 4.45"
  affected_api: "Tokenizer.__call__"
  description: "max_length default reduced from 512 -> 256 for some tokenizers"
  breakage_primitive: "ModifyConfigField"
  params:
    config_class: "tokenizer"
    field_name: "max_length"
    new_value: "256"
  repair_primitive: "RestoreConfigField"
  category: "tokenizer_drift"

- version_range: "transformers 4.40 -> 4.45"
  affected_api: "DataCollatorWithPadding"
  description: "Renamed `tokenizer` -> `processing_class` in DataCollator constructors"
  breakage_primitive: "RenameApiCall"
  params:
    old_name: "tokenizer"
    new_name: "processing_class"
  repair_primitive: "RestoreApiCall"
  category: "api_drift"

- version_range: "datasets 2.14 -> 2.18"
  affected_api: "load_dataset"
  description: "Some splits renamed train[:500] semantics changed"
  breakage_primitive: "RestructureDatasetSchema"
  params:
    old_column: "sentence"
    new_column: "text"
  repair_primitive: "RestoreColumn"
  category: "dataset_drift"

- version_range: "transformers 4.45 -> 4.50"
  affected_api: "Trainer"
  description: "evaluation_strategy was deprecated and removed"
  breakage_primitive: "RemoveDeprecatedMethod"
  params:
    class_name: "Trainer"
    method_name: "evaluate"
    replacement: "evaluate_legacy"
  repair_primitive: "RestoreMethod"
  category: "api_drift"

- version_range: "transformers 4.30 -> 4.40"
  affected_api: "PreTrainedModel.from_pretrained"
  description: "torch_dtype now required for some quantized model paths"
  breakage_primitive: "ChangeArgumentSignature"
  params:
    function_name: "from_pretrained"
    removed_arg: "torch_dtype"
    added_arg: "torch_dtype"
    added_value: '"auto"'
  repair_primitive: "RestoreArgument"
  category: "api_drift"

- version_range: "datasets 3.0 -> 3.2"
  affected_api: "Dataset.rename_column"
  description: "rename_column raises if target name exists"
  breakage_primitive: "RestructureDatasetSchema"
  params:
    old_column: "labels"
    new_column: "label"
  repair_primitive: "RestoreColumn"
  category: "dataset_drift"

- version_range: "transformers 4.36 -> 4.42"
  affected_api: "TrainingArguments.report_to"
  description: "Default report_to changed from 'all' to 'none'"
  breakage_primitive: "ModifyConfigField"
  params:
    config_class: "TrainingArguments"
    field_name: "report_to"
    new_value: '"all"'
  repair_primitive: "RestoreConfigField"
  category: "config_drift"

- version_range: "transformers 4.40 -> 4.50"
  affected_api: "imports"
  description: "transformers.deepspeed moved to accelerate.utils.deepspeed"
  breakage_primitive: "DeprecateImport"
  params:
    old_module: "from transformers.deepspeed"
    new_module: "from accelerate.utils.deepspeed"
  repair_primitive: "RestoreImport"
  category: "import_drift"

- version_range: "transformers 4.45 -> 4.50"
  affected_api: "Tokenizer return"
  description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute"
  breakage_primitive: "ChangeReturnType"
  params:
    function_name: "tokenizer"
    old_access: "tokenizer(text)"
    new_access: "tokenizer(text).encodings"
  repair_primitive: "RestoreReturnAccess"
  category: "api_drift"

- version_range: "transformers 4.30 -> 4.40"
  affected_api: "save_pretrained"
  description: "save_pretrained -> save_pretrained_directory rename in some classes"
  breakage_primitive: "RenameApiCall"
  params:
    old_name: "save_pretrained"
    new_name: "save_pretrained_directory"
  repair_primitive: "RestoreApiCall"
  category: "api_drift"

- version_range: "transformers 4.45 -> 4.50"
  affected_api: "TrainingArguments.no_cuda"
  description: "no_cuda renamed to use_cpu (logic inverted)"
  breakage_primitive: "RenameApiCall"
  params:
    old_name: "no_cuda"
    new_name: "use_cpu"
  repair_primitive: "RestoreApiCall"
  category: "config_drift"
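Each taxonomy entry carries both a `breakage_primitive` name and its `params`, so it can be shaped directly into the JSON spec that `parse_breakage_spec()` consumes. A minimal sketch of that mapping, with one entry transcribed from the list above into a plain dict (the real code would load the YAML and import the parser; neither is done here):

```python
# One taxonomy entry, transcribed from the YAML above.
entry = {
    "version_range": "transformers 4.30 -> 4.40",
    "affected_api": "TrainingArguments.evaluation_strategy",
    "breakage_primitive": "RenameApiCall",
    "params": {"old_name": "evaluation_strategy", "new_name": "eval_strategy"},
    "repair_primitive": "RestoreApiCall",
    "category": "api_drift",
}

# Shape it into the breakage spec format. Only these three keys matter to
# the parser; taxonomy-only keys like version_range are metadata.
spec = {
    "primitive_type": entry["breakage_primitive"],
    "category": entry["category"],
    "params": entry["params"],
}
print(spec["primitive_type"])
```

Because `parse_breakage_spec()` filters params down to the target dataclass's init fields, extra keys in a taxonomy entry (or in an LLM-proposed spec) are dropped rather than raising `TypeError`.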
|
| 179 |
+
affected_api: "imports"
|
| 180 |
+
description: "transformers.deepspeed moved to accelerate.utils.deepspeed"
|
| 181 |
+
breakage_primitive: "DeprecateImport"
|
| 182 |
+
params:
|
| 183 |
+
old_module: "from transformers.deepspeed"
|
| 184 |
+
new_module: "from accelerate.utils.deepspeed"
|
| 185 |
+
repair_primitive: "RestoreImport"
|
| 186 |
+
category: "import_drift"
|
| 187 |
+
|
| 188 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 189 |
+
affected_api: "Tokenizer return"
|
| 190 |
+
description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute"
|
| 191 |
+
breakage_primitive: "ChangeReturnType"
|
| 192 |
+
params:
|
| 193 |
+
function_name: "tokenizer"
|
| 194 |
+
old_access: "tokenizer(text)"
|
| 195 |
+
new_access: "tokenizer(text).encodings"
|
| 196 |
+
repair_primitive: "RestoreReturnAccess"
|
| 197 |
+
category: "api_drift"
|
| 198 |
+
|
| 199 |
+
- version_range: "transformers 4.30 -> 4.40"
|
| 200 |
+
affected_api: "save_pretrained"
|
| 201 |
+
description: "save_pretrained -> save_pretrained_directory rename in some classes"
|
| 202 |
+
breakage_primitive: "RenameApiCall"
|
| 203 |
+
params:
|
| 204 |
+
old_name: "save_pretrained"
|
| 205 |
+
new_name: "save_pretrained_directory"
|
| 206 |
+
repair_primitive: "RestoreApiCall"
|
| 207 |
+
category: "api_drift"
|
| 208 |
+
|
| 209 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 210 |
+
affected_api: "TrainingArguments.no_cuda"
|
| 211 |
+
description: "no_cuda renamed to use_cpu (logic inverted)"
|
| 212 |
+
breakage_primitive: "RenameApiCall"
|
| 213 |
+
params:
|
| 214 |
+
old_name: "no_cuda"
|
| 215 |
+
new_name: "use_cpu"
|
| 216 |
+
repair_primitive: "RestoreApiCall"
|
| 217 |
+
category: "config_drift"
|
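Each taxonomy entry pairs a `breakage_primitive` with the `repair_primitive` that inverts it, and that pairing is fixed by the `BREAKAGE_TO_REPAIR` table in forgeenv/primitives/repair_primitives.py below. A minimal standalone sketch of a consistency check over entries (the two sample entry dicts are abridged from the taxonomy above; the helper name `validate_entry` is illustrative, not part of the repo):

```python
# Copy of the canonical breakage -> repair mapping from repair_primitives.py.
BREAKAGE_TO_REPAIR = {
    "RenameApiCall": "RestoreApiCall",
    "DeprecateImport": "RestoreImport",
    "ChangeArgumentSignature": "RestoreArgument",
    "ModifyConfigField": "RestoreConfigField",
    "RestructureDatasetSchema": "RestoreColumn",
    "ChangeTokenizerBehavior": "RestoreTokenizerKwarg",
    "RemoveDeprecatedMethod": "RestoreMethod",
    "ChangeReturnType": "RestoreReturnAccess",
}

# Two entries abridged from the taxonomy above.
entries = [
    {"breakage_primitive": "RenameApiCall", "repair_primitive": "RestoreApiCall"},
    {"breakage_primitive": "DeprecateImport", "repair_primitive": "RestoreImport"},
]


def validate_entry(entry: dict) -> bool:
    """An entry is well-formed when its repair is the registered inverse."""
    return BREAKAGE_TO_REPAIR.get(entry["breakage_primitive"]) == entry["repair_primitive"]


assert all(validate_entry(e) for e in entries)
```

Running such a check over the whole YAML file catches copy-paste slips like a `RenameApiCall` entry pointing at `RestoreImport`.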
forgeenv/primitives/repair_primitives.py
CHANGED
"""Repair primitives — direct inverses of the 8 breakage primitives.

Used during warm-start data generation: for every (script, breakage)
pair we know the canonical repair, so we can write SFT pairs.

These are also useful for unit-testing the breakage primitives:
apply(breakage) then apply(repair) should be (close to) the identity.
"""
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class RepairPrimitive(ABC):
    category: str = field(default="generic", init=False)
    name: str = field(default="RepairPrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to undo the corresponding breakage."""

    def to_spec(self) -> dict:
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return JSON-serializable constructor parameters."""


@dataclass
class RestoreApiCall(RepairPrimitive):
    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        if not self.new_name:
            return script
        pattern = re.compile(rf"(?<!\w){re.escape(self.new_name)}(?!\w)")
        return pattern.sub(self.old_name, script)

    def _get_params(self) -> dict:
        return {"new_name": self.new_name, "old_name": self.old_name}


@dataclass
class RestoreImport(RepairPrimitive):
    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        return {"new_module": self.new_module, "old_module": self.old_module}


@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call."""

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        if not self.function_name:
            return script
        # Insert the kwarg right after the function-name's opening paren.
        pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)"
        replacement = rf"\g<1>{self.arg_name}={self.arg_value}, \g<2>"
        return re.sub(pattern, replacement, script, count=1)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }


@dataclass
class RestoreConfigField(RepairPrimitive):
    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        if not self.field_name:
            return script
        pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, rf"\g<1>{self.old_value}", script)

    def _get_params(self) -> dict:
        return {"field_name": self.field_name, "old_value": self.old_value}


@dataclass
class RestoreColumn(RepairPrimitive):
    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        return script.replace(
            f'"{self.new_column}"', f'"{self.old_column}"'
        ).replace(
            f"'{self.new_column}'", f"'{self.old_column}'"
        )

    def _get_params(self) -> dict:
        return {"new_column": self.new_column, "old_column": self.old_column}


@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        if not self.new_kwarg:
            return script
        pattern = rf"{re.escape(self.new_kwarg)}\s*=\s*{re.escape(self.new_value)}"
        replacement = f"{self.old_kwarg}={self.old_value}"
        return re.sub(pattern, replacement, script)

    def _get_params(self) -> dict:
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }


@dataclass
class RestoreMethod(RepairPrimitive):
    method_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreMethod"
        self.description = f"Un-deprecate .{self.method_name}()"

    def apply(self, script: str) -> str:
        if not self.method_name:
            return script
        return script.replace(
            f".{self.method_name}_DEPRECATED(", f".{self.method_name}("
        )

    def _get_params(self) -> dict:
        return {"method_name": self.method_name}


@dataclass
class RestoreReturnAccess(RepairPrimitive):
    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreReturnAccess"
        self.description = f"Restore return-access {self.new_access} -> {self.old_access}"

    def apply(self, script: str) -> str:
        if not self.new_access:
            return script
        return script.replace(self.new_access, self.old_access)

    def _get_params(self) -> dict:
        return {"new_access": self.new_access, "old_access": self.old_access}


REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    "RestoreApiCall": RestoreApiCall,
    "RestoreImport": RestoreImport,
    "RestoreArgument": RestoreArgument,
    "RestoreConfigField": RestoreConfigField,
    "RestoreColumn": RestoreColumn,
    "RestoreTokenizerKwarg": RestoreTokenizerKwarg,
    "RestoreMethod": RestoreMethod,
    "RestoreReturnAccess": RestoreReturnAccess,
}


# Map a breakage primitive's class name to the repair-primitive class that
# inverts it. Used by the warm-start pair generator and by the demo / repair
# library curator.
BREAKAGE_TO_REPAIR: dict[str, str] = {
    "RenameApiCall": "RestoreApiCall",
    "DeprecateImport": "RestoreImport",
    "ChangeArgumentSignature": "RestoreArgument",
    "ModifyConfigField": "RestoreConfigField",
    "RestructureDatasetSchema": "RestoreColumn",
    "ChangeTokenizerBehavior": "RestoreTokenizerKwarg",
    "RemoveDeprecatedMethod": "RestoreMethod",
    "ChangeReturnType": "RestoreReturnAccess",
}
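Per the module docstring, apply(breakage) followed by apply(repair) should be (close to) the identity. A standalone sketch of that round trip, reusing the same regex strategies as RestoreApiCall.apply and RestoreArgument.apply above; the one-line sample scripts are illustrative, not from the seed corpus:

```python
import re


def rename(script: str, old: str, new: str) -> str:
    # Whole-word rename -- the regex used by RestoreApiCall.apply.
    return re.sub(rf"(?<!\w){re.escape(old)}(?!\w)", new, script)


def restore_argument(script: str, function_name: str, arg_name: str, arg_value: str) -> str:
    # Insert a kwarg right after the call's opening paren -- the
    # strategy used by RestoreArgument.apply.
    pattern = rf"({re.escape(function_name)}\s*\()(\s*)"
    return re.sub(pattern, rf"\g<1>{arg_name}={arg_value}, \g<2>", script, count=1)


# RenameApiCall then RestoreApiCall is the identity on this script
# (the whole-word rename also hits the variable, but the inverse undoes both).
original = "collator = DataCollatorWithPadding(tokenizer=tokenizer)"
broken = rename(original, "tokenizer", "processing_class")        # breakage
assert rename(broken, "processing_class", "tokenizer") == original  # repair

# RestoreArgument re-adds a kwarg dropped by ChangeArgumentSignature.
repaired = restore_argument('args = TrainingArguments(output_dir="out")',
                            "TrainingArguments", "evaluation_strategy", '"steps"')
assert repaired == 'args = TrainingArguments(evaluation_strategy="steps", output_dir="out")'
```

Note that these are text-level transforms: the rename is blunt about identifiers that merely share the name, which is exactly why the unit tests only require the round trip to be "close to" the identity.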
forgeenv/roles/drift_generator.py
CHANGED
"""Drift Generator parser + a deterministic baseline policy.

In training the LLM produces a JSON breakage spec; we parse it. In rollouts
where we want a baseline (or a fallback when the LLM emits malformed JSON)
we use `BaselineDriftGenerator`, which samples from the per-category set of
known good primitive parameterisations.
"""
from __future__ import annotations

import json
import random
import re
from dataclasses import dataclass
from typing import Optional

from forgeenv.primitives.breakage_primitives import (
    PRIMITIVE_REGISTRY,
    parse_breakage_spec,
    BreakagePrimitive,
)


_JSON_RE = re.compile(r"\{[\s\S]*\}")


def parse_drift_output(text: str) -> Optional[dict]:
    """Extract a JSON object from possibly-noisy LLM output.

    Handles markdown fences, prose preamble, trailing commas (best-effort).
    Returns None on failure.
    """
    if not text:
        return None
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    match = _JSON_RE.search(text)
    if not match:
        return None
    blob = match.group(0)
    try:
        return json.loads(blob)
    except json.JSONDecodeError:
        cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return None


def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
    """End-to-end: LLM text -> validated BreakagePrimitive (or None)."""
    data = parse_drift_output(text)
    if not isinstance(data, dict):
        return None
    try:
        return parse_breakage_spec(data)
    except (ValueError, TypeError):
        return None


# ---------------------------------------------------------------- baselines
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}


@dataclass
class BaselineDriftGenerator:
    """Deterministic stand-in for the LLM Drift Generator.

    Used for warm-start data, baseline rollouts, and unit tests.
    """

    seed: Optional[int] = None

    def __post_init__(self) -> None:
        self._rng = random.Random(self.seed) if self.seed is not None else random

    def propose(
        self, target_category: str = "", script: str = ""
    ) -> dict:
        """Produce a JSON-serializable breakage spec for `target_category`.

        Order of preference:
        1. A primitive of `target_category` whose default params apply to `script`.
        2. A primitive of any type whose default params apply to `script`.
        3. A primitive of `target_category` (no-op fallback).
        """

        preferred_types = (
            [target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else []
        )
        all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys())

        for type_set in (preferred_types, all_types):
            shuffled = self._rng.sample(type_set, len(type_set)) if type_set else []
            for ptype in shuffled:
                for params in self._rng.sample(
                    _DEFAULT_PARAMS_BY_TYPE[ptype],
                    len(_DEFAULT_PARAMS_BY_TYPE[ptype]),
                ):
                    if self._params_apply_to_script(ptype, params, script):
                        return {"primitive_type": ptype, "params": dict(params)}

        ptype = preferred_types[0] if preferred_types else all_types[0]
        return {
            "primitive_type": ptype,
            "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]),
        }

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristic: would this primitive actually mutate `script`?"""
        if not script:
            return True
        for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"):
            if key in params and params[key] and params[key] in script:
                return True
        return False
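The fence-stripping and trailing-comma fallback in parse_drift_output can be exercised standalone; the body below copies the module's own logic so the snippet is self-contained, and the noisy sample output is illustrative:

```python
import json
import re

_JSON_RE = re.compile(r"\{[\s\S]*\}")


def parse_drift_output(text):
    # Same logic as parse_drift_output above: strip markdown fences,
    # extract the outermost JSON blob, retry after deleting trailing commas.
    if not text:
        return None
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    match = _JSON_RE.search(text)
    if not match:
        return None
    blob = match.group(0)
    try:
        return json.loads(blob)
    except json.JSONDecodeError:
        cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return None


# Prose preamble, a fenced block, and trailing commas -- all tolerated.
noisy = ('Here is the spec:\n```json\n'
         '{"primitive_type": "RenameApiCall", '
         '"params": {"old_name": "no_cuda", "new_name": "use_cpu",},}\n```')
spec = parse_drift_output(noisy)
assert spec == {"primitive_type": "RenameApiCall",
                "params": {"old_name": "no_cuda", "new_name": "use_cpu"}}
```

In training rollouts a None return from this parser is what triggers the BaselineDriftGenerator fallback.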
|
| 110 |
+
],
|
| 111 |
+
"ChangeReturnType": [
|
| 112 |
+
{"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
|
| 113 |
+
{"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
|
| 114 |
+
],
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@dataclass
|
| 119 |
+
class BaselineDriftGenerator:
|
| 120 |
+
"""Deterministic stand-in for the LLM Drift Generator.
|
| 121 |
+
|
| 122 |
+
Used for warm-start data, baseline rollouts, and unit tests.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
seed: Optional[int] = None
|
| 126 |
+
|
| 127 |
+
def __post_init__(self) -> None:
|
| 128 |
+
self._rng = random.Random(self.seed) if self.seed is not None else random
|
| 129 |
+
|
| 130 |
+
def propose(
|
| 131 |
+
self, target_category: str = "", script: str = ""
|
| 132 |
+
) -> dict:
|
| 133 |
+
"""Produce a JSON-serializable breakage spec for `target_category`.
|
| 134 |
+
|
| 135 |
+
Order of preference:
|
| 136 |
+
1. A primitive of `target_category` whose default params apply to `script`.
|
| 137 |
+
2. A primitive of any type whose default params apply to `script`.
|
| 138 |
+
3. A primitive of `target_category` (no-op fallback).
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
preferred_types = (
|
| 142 |
+
[target_category] if target_category in _DEFAULT_PARAMS_BY_TYPE else []
|
| 143 |
+
)
|
| 144 |
+
all_types = list(_DEFAULT_PARAMS_BY_TYPE.keys())
|
| 145 |
+
|
| 146 |
+
for type_set in (preferred_types, all_types):
|
| 147 |
+
shuffled = self._rng.sample(type_set, len(type_set)) if type_set else []
|
| 148 |
+
for ptype in shuffled:
|
| 149 |
+
for params in self._rng.sample(
|
| 150 |
+
_DEFAULT_PARAMS_BY_TYPE[ptype],
|
| 151 |
+
len(_DEFAULT_PARAMS_BY_TYPE[ptype]),
|
| 152 |
+
):
|
| 153 |
+
if self._params_apply_to_script(ptype, params, script):
|
| 154 |
+
return {"primitive_type": ptype, "params": dict(params)}
|
| 155 |
+
|
| 156 |
+
ptype = preferred_types[0] if preferred_types else all_types[0]
|
| 157 |
+
return {
|
| 158 |
+
"primitive_type": ptype,
|
| 159 |
+
"params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]),
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
|
| 164 |
+
"""Heuristic: would this primitive actually mutate `script`?"""
|
| 165 |
+
if not script:
|
| 166 |
+
return True
|
| 167 |
+
for key in ("old_name", "old_module", "removed_arg", "field_name", "old_column", "old_kwarg", "method_name", "old_access"):
|
| 168 |
+
if key in params and params[key] and params[key] in script:
|
| 169 |
+
return True
|
| 170 |
+
return False
|
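For illustration, a standalone run of the JSON extractor on a typical noisy completion (the parser is copied from `parse_drift_output` above; the only change is that the literal code fence is built programmatically via `FENCE` so this snippet stays self-contained):

```python
import json
import re

FENCE = "`" * 3  # literal ``` built programmatically, not part of the original module
_JSON_RE = re.compile(r"\{[\s\S]*\}")

def parse_drift_output(text):
    # Mirror of the parser above, inlined so this demo runs standalone.
    if not text:
        return None
    text = text.strip()
    if text.startswith(FENCE):
        text = re.sub("^" + FENCE + r"[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?" + FENCE + "$", "", text)
    match = _JSON_RE.search(text)
    if not match:
        return None
    blob = match.group(0)
    try:
        return json.loads(blob)
    except json.JSONDecodeError:
        # Best-effort: strip trailing commas before } or ] and retry.
        cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return None

# A fenced, trailing-comma-laden completion, as LLMs often emit.
noisy = FENCE + 'json\n{"primitive_type": "RenameApiCall", "params": {"old_name": "a", "new_name": "b",},}\n' + FENCE
spec = parse_drift_output(noisy)
print(spec)
```

The fence stripping and trailing-comma cleanup both fire here, so the spec parses despite the malformed JSON.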
forgeenv/roles/prompts.py
CHANGED
"""System and user prompts for the two RL roles.

Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with
LoRA adapters per role, so role prompts are the only thing distinguishing
them at inference time. Keep them concise — every token is a token of GPU
budget during GRPO rollouts.
"""
from __future__ import annotations

from typing import Iterable


PRIMITIVE_DESCRIPTIONS = {
    "RenameApiCall": "Rename a function/method call (api_drift)",
    "DeprecateImport": "Change an import path (import_drift)",
    "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)",
    "ModifyConfigField": "Change a config-class default (config_drift)",
    "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)",
    "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)",
    "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    "ChangeReturnType": "Function returns a different structure (api_drift)",
}

DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
You see a working HuggingFace training script and the curriculum target category.
Output exactly one JSON object describing a breakage primitive that simulates
realistic library version drift. The primitive must:
1. Be PLAUSIBLE — match the kind of breakage that happens between real
   transformers/datasets/trl releases.
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
3. Match the requested target_category.

Output schema:
{"primitive_type": "<one of the 8 types>", "params": { ... }}

Available primitive types and parameter schemas:
- RenameApiCall: {"old_name": str, "new_name": str}
- DeprecateImport: {"old_module": str, "new_module": str}
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}

Output ONLY the JSON object — no commentary, no markdown fences.
"""


REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
You see a broken HuggingFace training script, an error trace, and the current
library version snapshot. Output ONLY a unified diff that fixes the script.

Rules:
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
   headers and `@@ ... @@` hunk markers.
2. Make the MINIMAL change that resolves the error AND preserves the original
   training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
   calls.
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
4. If the error is unfixable, output an empty diff.
"""


def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict
) -> str:
    versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items())
    return f"""Target category: {target_category}
Library versions: {versions_str}

Working script:
```python
{script}
```

Output JSON breakage primitive:"""


def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict,
    target_category: str = "",
) -> str:
    versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items())
    return f"""Library versions: {versions_str}
Target category hint: {target_category or 'unknown'}

Broken script:
```python
{broken_script}
```

Error trace:
{error_trace}

Output unified diff (no prose, no fences):"""


def list_primitive_descriptions() -> Iterable[str]:
    return (f"- {k}: {v}" for k, v in PRIMITIVE_DESCRIPTIONS.items())
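A quick sketch of what the rendered user prompt looks like. This is a hypothetical standalone re-implementation of `render_repair_agent_prompt` (the `FENCE` variable builds the literal code fence programmatically so the example itself stays well-formed; the real function embeds the fences directly in its f-string):

```python
FENCE = "`" * 3  # stands in for a literal ``` fence

def render_repair_agent_prompt(broken_script, error_trace, library_versions, target_category=""):
    # Flatten the version snapshot into "name=version, name=version".
    versions_str = ", ".join(f"{k}={v}" for k, v in library_versions.items())
    return (
        f"Library versions: {versions_str}\n"
        f"Target category hint: {target_category or 'unknown'}\n\n"
        f"Broken script:\n{FENCE}python\n{broken_script}\n{FENCE}\n\n"
        f"Error trace:\n{error_trace}\n\n"
        "Output unified diff (no prose, no fences):"
    )

prompt = render_repair_agent_prompt(
    "trainer.start_training()",
    "AttributeError: 'Trainer' object has no attribute 'start_training'",
    {"transformers": "4.44.2"},
    target_category="api_drift",
)
print(prompt.splitlines()[0])  # Library versions: transformers=4.44.2
```

At rollout time this user prompt is paired with `REPAIR_AGENT_SYSTEM_PROMPT`, which is the only thing distinguishing the role from the Drift Generator.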
forgeenv/roles/repair_agent.py
CHANGED
"""Repair Agent helpers: response sanitisation + a deterministic baseline.

The Repair Agent's training output is a unified diff. LLMs frequently emit
prose / fences / chain-of-thought before the diff; this module strips that
preamble. The baseline policy uses the inverse-primitive map from
`repair_primitives.py` to produce ground-truth diffs for warm-start.
"""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Optional

from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
    parse_breakage_spec,
    BreakagePrimitive,
)
from forgeenv.primitives.repair_primitives import (
    BREAKAGE_TO_REPAIR,
    REPAIR_REGISTRY,
    RepairPrimitive,
)


_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles: code fences, leading prose / chain-of-thought, trailing notes.
    """
    if not raw_text:
        return ""
    raw_text = raw_text.strip()

    fence_match = _FENCE_RE.search(raw_text)
    if fence_match:
        raw_text = fence_match.group(1).strip()

    lines = raw_text.splitlines()
    start = 0
    for i, line in enumerate(lines):
        if line.startswith(("---", "+++", "@@")):
            start = i
            break

    return "\n".join(lines[start:])


def looks_like_diff(text: str) -> bool:
    if not text:
        return False
    has_header = "---" in text and "+++" in text
    has_hunk = bool(_DIFF_HUNK_RE.search(text))
    has_pm = any(line.startswith(("+", "-")) for line in text.splitlines())
    return (has_header and has_hunk) or (has_hunk and has_pm)


# ---------------------------------------------------------------- baselines
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent that uses the primitive inverse map.

    Used for warm-start dataset generation and baseline rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Return a unified diff (or full replacement script) that fixes the
        broken script.

        Strategy preference:
        1. If `original_script` is provided, return a diff between the
           broken script and the original (oracle). This is the warm-start
           path — we always know the ground truth.
        2. Otherwise try to invert the structured breakage_spec via the
           repair-primitive registry.
        3. Otherwise return an empty diff.
        """
        if original_script and original_script != broken_script:
            return make_unified_diff(broken_script, original_script)

        if breakage_spec:
            try:
                breakage = parse_breakage_spec(breakage_spec)
            except (ValueError, TypeError):
                breakage = None
            if breakage is not None:
                repair = _invert_breakage(breakage)
                if repair is not None:
                    repaired = repair.apply(broken_script)
                    if repaired != broken_script:
                        return make_unified_diff(broken_script, repaired)

        return ""


_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"},
    "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"},
    "ChangeArgumentSignature": {
        "function_name": "function_name",
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": {"field_name": "field_name"},
    "RestructureDatasetSchema": {
        "old_column": "old_column",
        "new_column": "new_column",
    },
    "ChangeTokenizerBehavior": {
        "old_kwarg": "old_kwarg",
        "old_value": "old_value",
        "new_kwarg": "new_kwarg",
        "new_value": "new_value",
    },
    "RemoveDeprecatedMethod": {"method_name": "method_name"},
    "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"},
}


def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    breakage_name = type(breakage).__name__
    repair_name = BREAKAGE_TO_REPAIR.get(breakage_name)
    if repair_name is None:
        return None
    repair_cls = REPAIR_REGISTRY.get(repair_name)
    if repair_cls is None:
        return None

    breakage_params = breakage._get_params()  # type: ignore[attr-defined]
    remap = _PARAM_REMAP.get(breakage_name, {})
    mapped: dict[str, str] = {}
    for src_key, dst_key in remap.items():
        if src_key in breakage_params:
            mapped[dst_key] = breakage_params[src_key]

    valid_fields = {
        f.name
        for f in repair_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if f.init
    }
    filtered = {k: v for k, v in mapped.items() if k in valid_fields}
    try:
        return repair_cls(**filtered)
    except TypeError:
        return None
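A minimal demonstration of the sanitisation path on a prose-prefixed response. The `strip_preamble` helper below is a hypothetical rename of the prose-stripping loop inside `extract_diff` (the code-fence branch is omitted here), and `looks_like_diff` is copied verbatim:

```python
import re

_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)

def strip_preamble(raw_text):
    # Keep everything from the first diff marker (---, +++, @@) onward,
    # dropping leading prose / chain-of-thought.
    lines = raw_text.strip().splitlines()
    start = 0
    for i, line in enumerate(lines):
        if line.startswith(("---", "+++", "@@")):
            start = i
            break
    return "\n".join(lines[start:])

def looks_like_diff(text):
    if not text:
        return False
    has_header = "---" in text and "+++" in text
    has_hunk = bool(_DIFF_HUNK_RE.search(text))
    has_pm = any(line.startswith(("+", "-")) for line in text.splitlines())
    return (has_header and has_hunk) or (has_hunk and has_pm)

raw = (
    "Sure, here is the fix:\n"
    "--- a/train.py\n"
    "+++ b/train.py\n"
    "@@ -1,1 +1,1 @@\n"
    "-trainer.start_training()\n"
    "+trainer.train()\n"
)
diff = strip_preamble(raw)
print(diff.splitlines()[0])  # --- a/train.py
```

Note `looks_like_diff` accepts either a fully-headed diff or a header-less hunk with +/- lines, which tolerates models that skip the `--- / +++` preamble.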
forgeenv/roles/teacher.py
CHANGED
"""Teacher (curriculum controller).

Deterministic — NOT an LLM. Maintains an EMA success rate per breakage
category and routes the next episode toward the category where the
Repair Agent is closest to a 50% success rate (R-Zero's difficulty band).
"""
from __future__ import annotations

import random
from dataclasses import dataclass, field


@dataclass
class Teacher:
    categories: list[str]
    alpha: float = 0.9
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        for category in self.categories:
            self.success_counts.setdefault(category, 0)
            self.attempt_counts.setdefault(category, 0)
            self.ema_success.setdefault(category, 0.5)

    def update(self, category: str, success: bool) -> None:
        if category not in self.ema_success:
            self.categories.append(category)
            self.ema_success[category] = 0.5
            self.success_counts[category] = 0
            self.attempt_counts[category] = 0

        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        rate = self.success_counts[category] / max(1, self.attempt_counts[category])
        self.ema_success[category] = (
            self.alpha * self.ema_success[category] + (1 - self.alpha) * rate
        )

    def select_next_category(self) -> str:
        in_zone = {
            c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7
        }
        if in_zone:
            weights = [1.0 / (v + 0.01) for v in in_zone.values()]
            return random.choices(list(in_zone.keys()), weights=weights, k=1)[0]
        return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5))

    def get_state(self) -> dict:
        return {
            c: {
                "ema_success": round(self.ema_success[c], 4),
                "attempts": self.attempt_counts[c],
                "successes": self.success_counts[c],
            }
            for c in self.categories
        }
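To see the curriculum signal in action, here is a condensed, self-contained copy of `Teacher` (the unseen-category branch of `update` and `get_state` are omitted for brevity) driven with a lopsided success pattern:

```python
import random
from dataclasses import dataclass, field

@dataclass
class Teacher:
    # Condensed copy of the Teacher above, for a standalone demo.
    categories: list
    alpha: float = 0.9
    success_counts: dict = field(default_factory=dict)
    attempt_counts: dict = field(default_factory=dict)
    ema_success: dict = field(default_factory=dict)

    def __post_init__(self):
        for c in self.categories:
            self.success_counts.setdefault(c, 0)
            self.attempt_counts.setdefault(c, 0)
            self.ema_success.setdefault(c, 0.5)  # start every category at 50%

    def update(self, category, success):
        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        rate = self.success_counts[category] / max(1, self.attempt_counts[category])
        self.ema_success[category] = (
            self.alpha * self.ema_success[category] + (1 - self.alpha) * rate
        )

    def select_next_category(self):
        # Prefer categories inside the 30-70% band, weighted toward 50%.
        in_zone = {c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7}
        if in_zone:
            weights = [1.0 / (v + 0.01) for v in in_zone.values()]
            return random.choices(list(in_zone.keys()), weights=weights, k=1)[0]
        return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5))

teacher = Teacher(categories=["api_drift", "import_drift"])
for _ in range(10):
    teacher.update("api_drift", success=True)      # repair agent always succeeds here
    teacher.update("import_drift", success=False)  # and always fails here
print({c: round(s, 3) for c, s in teacher.ema_success.items()})
```

After ten episodes each, the EMA drifts from 0.5 toward 1.0 and 0.0 respectively; once both categories leave the 30-70% band, `select_next_category` falls back to whichever EMA is closest to 50%.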
forgeenv/sandbox/ast_validator.py
CHANGED
|
"""AST-based script validator.

Catches forbidden imports and dangerous patterns BEFORE any execution
happens. This is a critical defense against reward hacking via system
calls, network access, or process manipulation.
"""
from __future__ import annotations

import ast

from forgeenv.tasks.models import ValidationResult

FORBIDDEN_MODULES = {
    "os",
    "subprocess",
    "socket",
    "urllib",
    "requests",
    "ctypes",
    "shutil",
    "signal",
    "multiprocessing",
    "threading",
}

FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}


def validate_script(script_content: str) -> ValidationResult:
    """Parse a script as AST and reject forbidden patterns.

    Returns a ValidationResult with `is_valid` and a list of `violations`.
    """
    violations: list[str] = []

    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                module_root = alias.name.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import: {alias.name}")

        if isinstance(node, ast.ImportFrom):
            if node.module:
                module_root = node.module.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import from: {node.module}")

        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                if node.func.id in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: {node.func.id}()")
            if isinstance(node.func, ast.Attribute):
                if node.func.attr in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: .{node.func.attr}()")

        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__builtins__":
                    violations.append("Forbidden: __builtins__ assignment")

    return ValidationResult(
        is_valid=len(violations) == 0,
        violations=violations,
    )
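The walk above is straightforward to exercise end-to-end. A condensed standalone version of the same check (same forbidden sets and message formats, but returning a plain list of violation strings instead of the repo's `ValidationResult`, and with the attribute-call and `__builtins__` branches omitted for brevity; `violations_of` is an illustrative name):

import ast

FORBIDDEN_MODULES = {"os", "subprocess", "socket", "urllib", "requests",
                     "ctypes", "shutil", "signal", "multiprocessing", "threading"}
FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}

def violations_of(script: str) -> list[str]:
    try:
        tree = ast.parse(script)
    except SyntaxError as e:
        return [f"SyntaxError: {e}"]
    out: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # `import os.path` is caught by checking the root package.
            out += [f"Forbidden import: {a.name}" for a in node.names
                    if a.name.split(".")[0] in FORBIDDEN_MODULES]
        elif isinstance(node, ast.ImportFrom) and node.module:
            if node.module.split(".")[0] in FORBIDDEN_MODULES:
                out.append(f"Forbidden import from: {node.module}")
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id in FORBIDDEN_FUNCTIONS:
                out.append(f"Forbidden call: {node.func.id}()")
    return out

print(violations_of("import os\neval('1+1')"))

Because the check runs on the parse tree rather than on source text, string contents like `"import os"` inside a literal do not trigger it, while aliased imports (`import os as o`) still do.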
forgeenv/sandbox/simulation_mode.py
CHANGED
"""Fast simulation executor for development.

Static-analysis-based execution simulator. Sub-100ms per call. No Docker
required. The success probability of a simulated run depends on whether
the script contains expected HF training markers (model imports, training
calls, save calls). When the simulation succeeds, a synthetic decreasing
loss curve is emitted; when it fails, a representative HF error is raised.
"""
from __future__ import annotations

import random
import time
from typing import Optional

from forgeenv.sandbox.ast_validator import validate_script
from forgeenv.tasks.models import ExecutionResult, Task


class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        start = time.time()

        validation = validate_script(script_content)
        if not validation.is_valid:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"Validation failed: {'; '.join(validation.violations)}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        try:
            compile(script_content, "<forge_script>", "exec")
        except SyntaxError as e:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"SyntaxError: {e}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        has_model_import = any(
            kw in script_content
            for kw in ("from transformers", "import torch", "from datasets")
        )
        has_training_call = any(
            kw in script_content
            for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")
        )
        has_save = any(
            kw in script_content
            for kw in ("save_pretrained", "save_model", "torch.save")
        )

        success_prob = 0.3
        if has_model_import:
            success_prob += 0.3
        if has_training_call:
            success_prob += 0.2
        if has_save:
            success_prob += 0.1

        # Mark obviously broken patterns as definite failures even when
        # they pass syntactic compilation. The simulator pretends to be a
        # static linter that catches AttributeError / ImportError signatures
        # before they would fire at runtime.
        broken_markers = (
            "_DEPRECATED(",
            "transformers.legacy",
            "from transformers.training import",
            ".start_training(",
            "load_from_hub(",
            "save_to_hub(",
            "pad_to_max_length=",
            "evaluation_loop(",
        )
        if any(marker in script_content for marker in broken_markers):
            success_prob = 0.0
        # Patterns that look like dataset column drift: a renamed column
        # that doesn't appear in real HF datasets.
        import re as _re

        if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        if _re.search(r"['\"]words['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        # Tokenizer kwarg drift (truncate is not valid; truncation is).
        if _re.search(r"\btruncate\s*=", script_content):
            success_prob = min(success_prob, 0.05)

        succeeded = self._rng.random() < success_prob

        if succeeded:
            steps = self._rng.randint(20, 50)
            log_lines: list[str] = []
            loss = self._rng.uniform(2.0, 4.0)
            for step in range(1, steps + 1):
                loss *= self._rng.uniform(0.92, 0.99)
                log_lines.append(f"step={step} loss={loss:.4f}")
            log_lines.append("eval_accuracy=0.78")
            log_lines.append("TRAINING_COMPLETE")

            return ExecutionResult(
                exit_code=0,
                stdout="\n".join(log_lines),
                stderr="",
                wall_time_ms=int((time.time() - start) * 1000)
                + self._rng.randint(1000, 5000),
                checkpoint_exists=True,
                peak_memory_mb=self._rng.uniform(500, 2000),
                script_content=script_content,
            )

        error_types = [
            "ImportError: cannot import name 'OldTrainer' from 'transformers'",
            "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
            "KeyError: 'text' column not found in dataset",
            "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
            "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
            "ModuleNotFoundError: No module named 'transformers.legacy'",
        ]
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=self._rng.choice(error_types),
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(100, 500),
            script_content=script_content,
        )
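The marker-based scoring at the heart of `execute()` can be sanity-checked in isolation. A standalone sketch of just that step (`success_probability` is an illustrative helper name, and only a subset of the broken-marker checks is reproduced):

def success_probability(script: str) -> float:
    # Base 0.3, plus bonuses for the three HF training markers,
    # zeroed by the broken-pattern check (subset of the full tuple shown).
    prob = 0.3
    if any(kw in script for kw in ("from transformers", "import torch", "from datasets")):
        prob += 0.3
    if any(kw in script for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")):
        prob += 0.2
    if any(kw in script for kw in ("save_pretrained", "save_model", "torch.save")):
        prob += 0.1
    if any(m in script for m in ("transformers.legacy", ".start_training(", "_DEPRECATED(")):
        prob = 0.0
    return prob

good = "from transformers import Trainer\ntrainer.train()\ntrainer.save_model('x')"
print(success_probability(good))  # approximately 0.9

A well-formed script therefore succeeds roughly 90% of the time, while any script touching a drifted API fails deterministically, which is what makes the simulator usable as a cheap reward signal.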
forgeenv/tasks/models.py
CHANGED
"""Core data models for ForgeEnv tasks and execution results.

These are framework-internal dataclasses (not Pydantic) used throughout the
simulation, verifier, and primitive layers. The OpenEnv-facing Pydantic
models live in `forgeenv.env.actions` / `forgeenv.env.observations`.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Task:
    """A HuggingFace training script with execution metadata."""

    task_id: str
    description: str
    script_content: str
    difficulty: str  # "easy", "medium", "hard"
    category: str = "general"
    expected_loss_range: tuple[float, float] = (0.0, 5.0)
    expected_accuracy_range: tuple[float, float] = (0.0, 1.0)
    checkpoint_output_path: str = "/tmp/forge_output/checkpoint"


@dataclass
class ExecutionResult:
    """Result of executing a Python script in the sandbox."""

    exit_code: int
    stdout: str
    stderr: str
    wall_time_ms: int
    checkpoint_exists: bool = False
    peak_memory_mb: float = 0.0
    script_content: str = ""


@dataclass
class ValidationResult:
    """Result of AST validation on a script."""

    is_valid: bool
    violations: list[str] = field(default_factory=list)
forgeenv/tasks/seed_corpus/albert_qa.py
CHANGED
"""ALBERT-tiny extractive QA on 100-sample SQuAD subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    Trainer,
    TrainingArguments,
    DefaultDataCollator,
)
from datasets import load_dataset

dataset = load_dataset("squad", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")


def prepare(examples):
    enc = tokenizer(
        examples["question"],
        examples["context"],
        max_length=128,
        truncation="only_second",
        padding="max_length",
        return_offsets_mapping=True,
    )
    start_positions, end_positions = [], []
    for i, offsets in enumerate(enc["offset_mapping"]):
        answer = examples["answers"][i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        token_start = next(
            (idx for idx, (a, b) in enumerate(offsets) if a <= start_char < b), 0
        )
        token_end = next(
            (idx for idx, (a, b) in enumerate(offsets) if a < end_char <= b), token_start
        )
        start_positions.append(token_start)
        end_positions.append(token_end)

    enc["start_positions"] = start_positions
    enc["end_positions"] = end_positions
    enc.pop("offset_mapping")
    return enc


dataset = dataset.map(prepare, batched=True, remove_columns=dataset.column_names)

model = AutoModelForQuestionAnswering.from_pretrained("albert-base-v2")

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DefaultDataCollator(),
)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/bert_ner.py
CHANGED
"""Bert tiny NER fine-tuning on a 200-sample CoNLL-2003 subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
)
from datasets import load_dataset

dataset = load_dataset("conll2003", split="train[:200]")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")


def tokenize_and_align(example):
    enc = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, max_length=64)
    word_ids = enc.word_ids()
    labels = []
    prev_id = None
    for wid in word_ids:
        if wid is None:
            labels.append(-100)
        elif wid != prev_id:
            labels.append(example["ner_tags"][wid])
        else:
            labels.append(-100)
        prev_id = wid
    enc["labels"] = labels
    return enc


dataset = dataset.map(tokenize_and_align, remove_columns=dataset.column_names)

model = AutoModelForTokenClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=9)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForTokenClassification(tokenizer),
)

trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/distilbert_sst2.py
CHANGED
"""DistilBERT fine-tuning on a tiny SST-2 subset.

Minimal HuggingFace text-classification training script. Should complete
in ~60s on CPU.
"""
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("glue", "sst2", split="train[:500]")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_function(examples):
    return tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


dataset = dataset.map(tokenize_function, batched=True)
dataset = dataset.rename_column("label", "labels")
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/electra_classification.py
CHANGED
```python
"""ELECTRA-small classification on 400-sample AG News (4-way text classification)."""
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("ag_news", split="train[:400]")
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")


def tokenize(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


dataset = dataset.map(tokenize, batched=True)
dataset = dataset.rename_column("label", "labels")
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = AutoModelForSequenceClassification.from_pretrained(
    "google/electra-small-discriminator", num_labels=4
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/seed_corpus/gpt2_textgen.py
CHANGED
```python
"""DistilGPT2 causal-LM fine-tuning on 300 lines of WikiText (text generation)."""
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:300]")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token


def tokenize(examples):
    return tokenizer(examples["text"], truncation=True, max_length=64)


dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/seed_corpus/logistic_classifier.py
CHANGED
```python
"""Sklearn logistic-regression baseline on a 500-sample tabular task.

Sanity baseline that doesn't require torch / transformers / datasets.
"""
import json
import pickle
from pathlib import Path

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(
    n_samples=500, n_features=20, n_informative=10, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = LogisticRegression(max_iter=200)
for step in range(1, 11):
    model.set_params(max_iter=step * 20)
    model.fit(X_train, y_train)
    train_loss = -np.mean(
        np.log(np.maximum(model.predict_proba(X_train)[np.arange(len(y_train)), y_train], 1e-9))
    )
    print(f"step={step} loss={train_loss:.4f}")

acc = model.score(X_test, y_test)
print(f"eval_accuracy={acc:.4f}")

ckpt_dir = Path("/tmp/forge_output/checkpoint")
ckpt_dir.mkdir(parents=True, exist_ok=True)
with open(ckpt_dir / "logreg.pkl", "wb") as f:
    pickle.dump(model, f)
with open(ckpt_dir / "metrics.json", "w") as f:
    json.dump({"accuracy": acc}, f)

print("TRAINING_COMPLETE")
```
|
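The per-step loss in this baseline is a hand-rolled mean negative log-likelihood over the true-class probabilities (clipped at 1e-9 to avoid log(0)). A minimal, dependency-free sketch of that computation, using toy probabilities chosen here for illustration:

```python
import math


def nll(probs, labels, eps=1e-9):
    """Mean negative log-likelihood of the true-class probabilities,
    clipped at eps exactly as the training loop does with np.maximum."""
    return -sum(math.log(max(p[y], eps)) for p, y in zip(probs, labels)) / len(labels)


# Toy predicted class probabilities for 2 samples with true labels [0, 1]:
# -(ln 0.9 + ln 0.8) / 2 ≈ 0.1643
probs = [[0.9, 0.1], [0.2, 0.8]]
print(round(nll(probs, [0, 1]), 4))  # 0.1643
```

The clipping mirrors `np.maximum(..., 1e-9)` in the script; without it a confidently wrong prediction with probability 0 would make the loss infinite.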
forgeenv/tasks/seed_corpus/roberta_sentiment.py
CHANGED
```python
"""DistilRoberta sentiment classification on 400-sample IMDB subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("imdb", split="train[:400]")
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")


def tokenize(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


dataset = dataset.map(tokenize, batched=True)
dataset = dataset.rename_column("label", "labels")
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = AutoModelForSequenceClassification.from_pretrained(
    "distilroberta-base", num_labels=2
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/seed_corpus/simple_regression.py
CHANGED
```python
"""Tiny PyTorch regression on synthetic data (no HF imports — sanity baseline)."""
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(500, 4)
y = (x @ torch.tensor([1.5, -2.0, 0.5, 3.0])) + 0.1 * torch.randn(500)

model = nn.Sequential(
    nn.Linear(4, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

for epoch in range(50):
    optimizer.zero_grad()
    preds = model(x).squeeze(-1)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 5 == 0:
        print(f"step={epoch + 1} loss={loss.item():.4f}")

torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/regression.pt")
print("TRAINING_COMPLETE")
```
|
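The regression target in this script is a fixed linear map of the inputs plus Gaussian noise: the noiseless part is just a dot product with the weight vector `[1.5, -2.0, 0.5, 3.0]`. A dependency-free sketch of that target function (the helper name is illustrative, not from the repo):

```python
def linear_target(x, w=(1.5, -2.0, 0.5, 3.0)):
    """Noiseless version of the synthetic regression target: y = x @ w."""
    return sum(xi * wi for xi, wi in zip(x, w))


# All-ones input sums the weights: 1.5 - 2.0 + 0.5 + 3.0 = 3.0
print(linear_target([1.0, 1.0, 1.0, 1.0]))  # 3.0
```

Because the true function is linear, the two-layer MLP in the script only has to recover an affine map plus the 0.1-scale noise floor, which is why 50 full-batch Adam steps suffice.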
forgeenv/tasks/seed_corpus/t5_summarization.py
CHANGED
```python
"""Tiny T5 fine-tuning for summarization on 100-sample CNN/DailyMail."""
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")


def preprocess(examples):
    inputs = tokenizer(
        ["summarize: " + a for a in examples["article"]],
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    targets = tokenizer(
        examples["highlights"],
        max_length=32,
        truncation=True,
        padding="max_length",
    )
    inputs["labels"] = targets["input_ids"]
    return inputs


dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
    predict_with_generate=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py
CHANGED
```python
"""Tiny PyTorch MLP on a 1000-sample MNIST subset (image classification baseline)."""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset = load_dataset("mnist", split="train[:1000]")
dataset = dataset.with_format("torch")


def collate(batch):
    pixel = torch.stack([b["image"].float().flatten() / 255.0 for b in batch])
    labels = torch.tensor([b["label"] for b in batch])
    return pixel, labels


loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

model = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):
    for step, (x, y) in enumerate(loader, start=1):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        if step % 5 == 0:
            print(f"step={step} loss={loss.item():.4f}")

torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/mlp.pt")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/seed_corpus/vit_cifar10.py
CHANGED
```python
"""Tiny ViT image classification on 200-sample CIFAR-10 subset."""
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
import torch

dataset = load_dataset("cifar10", split="train[:200]")
processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224")


def transform(batch):
    images = [img.convert("RGB") for img in batch["img"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["labels"] = torch.tensor(batch["label"])
    return inputs


dataset = dataset.with_transform(transform)

model = AutoModelForImageClassification.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224", num_labels=10, ignore_mismatched_sizes=True
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
```
forgeenv/tasks/task_sampler.py
CHANGED
```python
"""Task sampler: loads the seed corpus and samples Tasks by difficulty.

Difficulty is auto-derived from script line count. Category is auto-detected
from script content (text_classification, ner, translation, etc.).
"""
from __future__ import annotations

import random
from pathlib import Path
from typing import Optional

from forgeenv.tasks.models import Task


def _detect_category(content: str) -> str:
    cl = content.lower()
    if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl:
        return "text_classification"
    if "tokenclassification" in cl or "ner" in cl or "conll" in cl:
        return "ner"
    if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl:
        return "seq2seq"
    if "causallm" in cl or "gpt2" in cl or "wikitext" in cl:
        return "text_generation"
    if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl:
        return "image_classification"
    if "questionanswering" in cl or "squad" in cl:
        return "qa"
    if "logisticregression" in cl or "make_classification" in cl:
        return "tabular"
    if "regression" in cl:
        return "regression"
    return "general"


def _derive_difficulty(content: str) -> str:
    lines = len(content.splitlines())
    if lines < 30:
        return "easy"
    if lines < 60:
        return "medium"
    return "hard"


class TaskSampler:
    """Loads seed corpus and samples tasks by difficulty / category."""

    def __init__(self, seed_dir: Optional[str] = None) -> None:
        if seed_dir is None:
            seed_dir = str(Path(__file__).parent / "seed_corpus")

        self.tasks: list[Task] = []
        self._load_corpus(seed_dir)

    def _load_corpus(self, seed_dir: str) -> None:
        corpus_path = Path(seed_dir)
        if not corpus_path.exists():
            return

        for py_file in sorted(corpus_path.glob("*.py")):
            if py_file.name.startswith("__"):
                continue

            content = py_file.read_text(encoding="utf-8")
            task_id = py_file.stem
            difficulty = _derive_difficulty(content)
            category = _detect_category(content)

            description = ""
            if content.startswith('"""'):
                end = content.find('"""', 3)
                if end != -1:
                    description = content[3:end].strip()

            self.tasks.append(
                Task(
                    task_id=task_id,
                    description=description or f"Training script: {task_id}",
                    script_content=content,
                    difficulty=difficulty,
                    category=category,
                )
            )

    def sample(self, difficulty: Optional[str] = None) -> Optional[Task]:
        candidates = self.tasks
        if difficulty is not None:
            filtered = [t for t in self.tasks if t.difficulty == difficulty]
            if filtered:
                candidates = filtered
        return random.choice(candidates) if candidates else None

    def sample_batch(
        self, n: int, difficulty: Optional[str] = None
    ) -> list[Task]:
        return [t for t in (self.sample(difficulty) for _ in range(n)) if t is not None]

    def get_all_categories(self) -> list[str]:
        return sorted({t.category for t in self.tasks})

    def get_by_id(self, task_id: str) -> Optional[Task]:
        for t in self.tasks:
            if t.task_id == task_id:
                return t
        return None
```
|
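The line-count bucketing used by `_derive_difficulty` is self-contained and easy to sanity-check in isolation. The sketch below mirrors that function verbatim (renamed to drop the private underscore for standalone use):

```python
def derive_difficulty(content: str) -> str:
    # Mirrors TaskSampler's _derive_difficulty: bucket a script by line count.
    lines = len(content.splitlines())
    if lines < 30:
        return "easy"
    if lines < 60:
        return "medium"
    return "hard"


print(derive_difficulty("x = 1\n" * 10))  # easy
print(derive_difficulty("x = 1\n" * 45))  # medium
print(derive_difficulty("x = 1\n" * 80))  # hard
```

Line count is a crude proxy for difficulty, but it needs no parsing and is stable across the seed corpus.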
forgeenv/training/grpo_drift.py
CHANGED

```python
"""GRPO trainer for the Drift Generator.

Uses R-Zero's composite Challenger reward: max(0, uncertainty - repetition).
Each prompt is sampled `group_size` times; for every breakage we run K
independent Repair Agent rollouts to estimate p_hat (success rate).

Heavy and brittle on a single GPU — keep group_size small for hackathon
budgets. Provides a `--dry_run` mode that just exercises the reward function
without any LLM calls.
"""
from __future__ import annotations

import argparse
import json
import os
import random
from pathlib import Path
from typing import Optional

from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
    BaselineDriftGenerator,
    parse_drift_output,
)
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    render_drift_generator_prompt,
)
from forgeenv.training.rollout import (
    GenerateFn,
    rollout_one_episode,
    baseline_oracle_repair_generate,
    _baseline_repair_generate,
)
from forgeenv.training.reward_functions import (
    compute_drift_gen_reward,
    compute_uncertainty_reward,
    compute_repetition_penalty,
)


def evaluate_drift_batch(
    env_factory,
    breakages: list[dict],
    repair_generate: GenerateFn,
    n_repair_attempts_per_breakage: int = 4,
    seed: int = 0,
) -> list[float]:
    """For each breakage spec, run K Repair-Agent attempts and compute
    R-Zero's composite Challenger reward. Returns one reward per breakage."""

    breakage_texts = [
        f"{b.get('primitive_type','')}::{json.dumps(b.get('params', {}), sort_keys=True)}"
        for b in breakages
    ]

    rewards: list[float] = []
    for idx, breakage_spec in enumerate(breakages):
        successes: list[bool] = []
        for k in range(n_repair_attempts_per_breakage):
            env = env_factory()
            env.reset(seed=seed + idx * 100 + k, difficulty="easy")
            try:
                obs2 = env.step(
                    ForgeAction(
                        breakage=BreakageAction(
                            primitive_type=breakage_spec.get("primitive_type", ""),
                            params=breakage_spec.get("params", {}) or {},
                        )
                    )
                )
            except Exception:
                successes.append(False)
                continue

            from forgeenv.roles.repair_agent import extract_diff
            from forgeenv.roles.prompts import render_repair_agent_prompt

            user = render_repair_agent_prompt(
                broken_script=obs2.script_content,
                error_trace=obs2.error_trace or "",
                library_versions=obs2.library_versions,
                target_category=obs2.target_category,
            )
            raw = repair_generate("", user)
            diff = extract_diff(raw or "")
            obs3 = env.step(ForgeAction(repair=RepairAction(unified_diff=diff)))
            successes.append(
                bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5)
            )

        reward = compute_drift_gen_reward(
            breakage_text=breakage_texts[idx],
            repair_successes=successes,
            batch_breakages=breakage_texts,
        )
        rewards.append(reward)
    return rewards


def run_drift_grpo_dry_run(
    output_dir: str, total_episodes: int = 100, group_size: int = 4, seed: int = 0
) -> None:
    """Pure-CPU exercise of the drift-side reward loop. Writes per-step rewards."""
    rng = random.Random(seed)
    drift_gen = BaselineDriftGenerator(seed=seed)
    rewards_log: list[dict] = []

    for ep in range(total_episodes):
        env = ForgeEnvironment(seed=seed + ep)
        env.reset(difficulty="easy")
        target_category = env.state["target_category"]
        script = env._original_script  # noqa: SLF001 — read-only convenience

        # Sample group_size candidate breakages
        candidates = [
            drift_gen.propose(target_category=target_category, script=script)
            for _ in range(group_size)
        ]

        # Use the oracle as repair (so we get a meaningful uncertainty signal:
        # an "unbreakable" breakage gives p_hat=1, an "always-fails" one gives 0)
        rewards = evaluate_drift_batch(
            env_factory=lambda: ForgeEnvironment(seed=rng.randint(0, 1_000_000)),
            breakages=candidates,
            repair_generate=baseline_oracle_repair_generate(env),
            n_repair_attempts_per_breakage=2,
            seed=seed + ep,
        )
        rewards_log.append(
            {"episode": ep, "rewards": rewards, "candidates": candidates}
        )

        if ep % max(1, total_episodes // 10) == 0:
            mean_r = sum(rewards) / max(1, len(rewards))
            print(f"[drift dry-run] ep={ep} mean_reward={mean_r:.3f}")

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    (Path(output_dir) / "drift_dry_run.json").write_text(
        json.dumps(rewards_log, indent=2)
    )
    print(f"[drift dry-run] wrote {len(rewards_log)} episodes to {output_dir}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=100)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dry_run", action="store_true", default=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    if args.dry_run:
        run_drift_grpo_dry_run(
            output_dir=args.output_dir,
            total_episodes=args.total_episodes,
            group_size=args.group_size,
            seed=args.seed,
        )
    else:
        raise NotImplementedError(
            "Full LLM Drift GRPO requires both roles loaded — use the Colab notebook"
        )
```
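The composite Challenger reward named in the module docstring, max(0, uncertainty - repetition), is computed inside `compute_drift_gen_reward`, whose exact definition lives elsewhere in the repo. The sketch below shows only the shape of that signal, under the assumption (common in R-Zero-style setups, not confirmed by this file) that uncertainty peaks when the Repair Agent's empirical success rate p_hat is 0.5, via `1 - |2*p_hat - 1|`, and that repetition arrives as a precomputed penalty:

```python
def challenger_reward(successes: list[bool], repetition_penalty: float = 0.0) -> float:
    # p_hat: empirical repair success rate over K independent rollouts.
    p_hat = sum(successes) / max(1, len(successes))
    # Assumed uncertainty term: 0 when the breakage is trivially repairable
    # (p_hat = 1) or hopeless (p_hat = 0), maximal at p_hat = 0.5.
    uncertainty = 1.0 - abs(2.0 * p_hat - 1.0)
    # Composite Challenger reward: max(0, uncertainty - repetition).
    return max(0.0, uncertainty - repetition_penalty)


print(challenger_reward([True, False, True, False]))  # 1.0 — maximally informative
print(challenger_reward([True, True, True, True]))    # 0.0 — trivially repairable
print(challenger_reward([True, False], 0.25))         # 0.75 — discounted for repetition
```

This is why the dry run pairs candidate breakages with the oracle repairer: p_hat near 0 or 1 zeroes the reward, so only breakages of intermediate difficulty score.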
forgeenv/training/grpo_repair.py
CHANGED
|
@@ -1,213 +1,213 @@
|
|
| 1 |
-
"""GRPO trainer for the Repair Agent.
|
| 2 |
-
|
| 3 |
-
This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout
|
| 4 |
-
function. Each prompt is sampled K times (group size); each sample is
|
| 5 |
-
executed in the env and gets a scalar reward from the visible verifier.
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
python -m forgeenv.training.grpo_repair \\
|
| 9 |
-
--base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\
|
| 10 |
-
--adapter_path artifacts/checkpoints/repair_agent_sft \\
|
| 11 |
-
--output_dir artifacts/checkpoints/repair_agent_grpo \\
|
| 12 |
-
--total_episodes 200 --group_size 4
|
| 13 |
-
"""
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import argparse
|
| 17 |
-
import json
|
| 18 |
-
import os
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
from typing import Any, Optional
|
| 21 |
-
|
| 22 |
-
from forgeenv.env.forge_environment import ForgeEnvironment
|
| 23 |
-
from forgeenv.roles.drift_generator import BaselineDriftGenerator
|
| 24 |
-
from forgeenv.roles.prompts import (
|
| 25 |
-
DRIFT_GENERATOR_SYSTEM_PROMPT,
|
| 26 |
-
REPAIR_AGENT_SYSTEM_PROMPT,
|
| 27 |
-
render_drift_generator_prompt,
|
| 28 |
-
render_repair_agent_prompt,
|
| 29 |
-
)
|
| 30 |
-
from forgeenv.roles.repair_agent import extract_diff
|
| 31 |
-
from forgeenv.training.rollout import rollout_one_episode
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]:
|
| 35 |
-
"""Reset env, run baseline drift generator, return a repair-prompt
|
| 36 |
-
dict ready to feed to TRL's GRPOTrainer."""
|
| 37 |
-
drift_gen = BaselineDriftGenerator()
|
| 38 |
-
|
| 39 |
-
obs = env.reset(difficulty="easy")
|
| 40 |
-
drift_user = render_drift_generator_prompt(
|
| 41 |
-
script=obs.script_content,
|
| 42 |
-
target_category=obs.target_category,
|
| 43 |
-
library_versions=obs.library_versions,
|
| 44 |
-
)
|
| 45 |
-
spec = drift_gen.propose(
|
| 46 |
-
target_category=obs.target_category, script=obs.script_content
|
| 47 |
-
)
|
| 48 |
-
from forgeenv.env.actions import BreakageAction, ForgeAction
|
| 49 |
-
|
| 50 |
-
obs2 = env.step(
|
| 51 |
-
ForgeAction(
|
| 52 |
-
breakage=BreakageAction(
|
| 53 |
-
primitive_type=spec["primitive_type"], params=spec["params"]
|
| 54 |
-
)
|
| 55 |
-
)
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
user = render_repair_agent_prompt(
|
| 59 |
-
broken_script=obs2.script_content,
|
| 60 |
-
error_trace=obs2.error_trace or "",
|
| 61 |
-
library_versions=obs2.library_versions,
|
| 62 |
-
target_category=obs2.target_category,
|
| 63 |
-
)
|
| 64 |
-
return {
|
| 65 |
-
"prompt": [
|
| 66 |
-
{"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
|
| 67 |
-
{"role": "user", "content": user},
|
| 68 |
-
],
|
| 69 |
-
"task_id": obs.task_id,
|
| 70 |
-
"primitive_type": spec["primitive_type"],
|
| 71 |
-
"broken_script": obs2.script_content,
|
| 72 |
-
"drift_user_prompt": drift_user,
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def reward_repair_function(
|
| 77 |
-
completions: list, prompts: list = None, **kwargs
|
| 78 |
-
) -> list[float]:
|
| 79 |
-
"""TRL-compatible reward fn: scores a batch of completions against
|
| 80 |
-
a (broken_script, breakage_spec) tuple stored on each example."""
|
| 81 |
-
from forgeenv.env.actions import RepairAction, ForgeAction
|
| 82 |
-
from forgeenv.env.diff_utils import apply_unified_diff
|
| 83 |
-
from forgeenv.sandbox.simulation_mode import SimulationExecutor
|
| 84 |
-
from forgeenv.tasks.task_sampler import TaskSampler
|
| 85 |
-
from forgeenv.verifier.visible_verifier import compute_visible_reward
|
| 86 |
-
|
| 87 |
-
sampler = TaskSampler()
|
| 88 |
-
executor = SimulationExecutor()
|
| 89 |
-
task_ids = kwargs.get("task_id", [None] * len(completions))
|
| 90 |
-
broken_scripts = kwargs.get("broken_script", [""] * len(completions))
|
| 91 |
-
|
| 92 |
-
rewards: list[float] = []
|
| 93 |
-
for completion, task_id, broken in zip(completions, task_ids, broken_scripts):
|
| 94 |
-
if isinstance(completion, list): # chat format
|
| 95 |
-
completion = completion[-1]["content"]
|
| 96 |
-
diff = extract_diff(completion or "")
|
| 97 |
-
repaired = apply_unified_diff(broken, diff) if diff else broken
|
| 98 |
-
task = sampler.get_by_id(task_id) if task_id else None
|
| 99 |
-
if task is None and sampler.tasks:
|
| 100 |
-
task = sampler.tasks[0]
|
| 101 |
-
result = executor.execute(repaired, task)
|
| 102 |
-
result.script_content = repaired
|
| 103 |
-
reward, _ = compute_visible_reward(result, task)
|
| 104 |
-
rewards.append(float(reward))
|
| 105 |
-
return rewards
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def run_grpo(
|
| 109 |
-
base_model: str,
|
| 110 |
-
adapter_path: Optional[str],
|
| 111 |
-
output_dir: str,
|
| 112 |
-
total_episodes: int = 200,
|
| 113 |
-
group_size: int = 4,
|
| 114 |
-
learning_rate: float = 5e-6,
|
| 115 |
-
seed: int = 0,
|
| 116 |
-
use_unsloth: Optional[bool] = None,
|
| 117 |
-
) -> None:
|
| 118 |
-
"""Launch GRPO training (lazy imports to keep this module importable on CPU)."""
|
| 119 |
-
|
| 120 |
-
if use_unsloth is None:
|
| 121 |
-
use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"
|
| 122 |
-
|
| 123 |
-
if not use_unsloth:
|
| 124 |
-
# Dry-run mode: just exercise the prompt building loop and dump rewards.
|
| 125 |
-
env = ForgeEnvironment(seed=seed)
|
| 126 |
-
rewards = []
|
| 127 |
-
for ep in range(total_episodes):
|
| 128 |
-
result = rollout_one_episode(env)
|
| 129 |
-
rewards.append(result.visible_reward)
|
| 130 |
-
if ep % max(1, total_episodes // 10) == 0:
|
| 131 |
-
print(
|
| 132 |
-
f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} "
|
| 133 |
-
f"primitive={result.primitive_type}"
|
| 134 |
-
)
|
| 135 |
-
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
| 136 |
-
(Path(output_dir) / "dry_run_rewards.json").write_text(
|
| 137 |
-
json.dumps(rewards, indent=2)
|
| 138 |
-
)
|
| 139 |
-
print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}")
|
| 140 |
-
return
|
| 141 |
-
|
| 142 |
-
from datasets import Dataset
|
| 143 |
-
from trl import GRPOConfig, GRPOTrainer
|
| 144 |
-
from unsloth import FastLanguageModel
|
| 145 |
-
from peft import PeftModel
|
| 146 |
-
|
| 147 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 148 |
-
model_name=base_model,
|
| 149 |
-
max_seq_length=4096,
|
| 150 |
-
dtype=None,
|
| 151 |
-
load_in_4bit=True,
|
| 152 |
-
)
|
| 153 |
-
if adapter_path:
|
| 154 |
-
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
|
| 155 |
-
|
| 156 |
-
env = ForgeEnvironment(seed=seed)
|
| 157 |
-
examples = [_build_repair_prompt(env) for _ in range(total_episodes)]
|
| 158 |
-
dataset = Dataset.from_list(examples)
|
| 159 |
-
|
| 160 |
-
grpo_config = GRPOConfig(
|
| 161 |
-
output_dir=output_dir,
|
| 162 |
-
per_device_train_batch_size=1,
|
| 163 |
-
gradient_accumulation_steps=4,
|
| 164 |
-
learning_rate=learning_rate,
|
| 165 |
-
max_steps=total_episodes,
|
| 166 |
-
num_generations=group_size,
|
| 167 |
-
max_completion_length=1024,
|
| 168 |
-
logging_steps=5,
|
| 169 |
-
save_steps=max(50, total_episodes // 4),
|
| 170 |
-
save_total_limit=2,
|
| 171 |
-
seed=seed,
|
| 172 |
-
report_to="none",
|
| 173 |
-
beta=0.04,
|
| 174 |
-
)
|
| 175 |
-
trainer = GRPOTrainer(
|
| 176 |
-
model=model,
|
| 177 |
-
processing_class=tokenizer,
|
| 178 |
-
args=grpo_config,
|
| 179 |
-
train_dataset=dataset,
|
| 180 |
-
reward_funcs=[reward_repair_function],
|
| 181 |
-
)
|
| 182 |
-
trainer.train()
|
| 183 |
-
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
| 184 |
-
model.save_pretrained(output_dir)
|
| 185 |
-
tokenizer.save_pretrained(output_dir)
|
| 186 |
-
print(f"[grpo] saved adapter to {output_dir}")
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def _parse_args() -> argparse.Namespace:
|
| 190 |
-
parser = argparse.ArgumentParser(description=__doc__)
|
| 191 |
-
parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit")
|
| 192 |
-
parser.add_argument("--adapter_path", default=None)
|
| 193 |
-
parser.add_argument("--output_dir", required=True)
|
| 194 |
-
parser.add_argument("--total_episodes", type=int, default=200)
|
| 195 |
-
parser.add_argument("--group_size", type=int, default=4)
|
| 196 |
-
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
| 197 |
-
parser.add_argument("--seed", type=int, default=0)
|
| 198 |
-
parser.add_argument("--dry_run", action="store_true")
|
| 199 |
-
return parser.parse_args()
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
if __name__ == "__main__":
|
| 203 |
-
args = _parse_args()
|
| 204 |
-
run_grpo(
|
| 205 |
-
base_model=args.base_model,
|
| 206 |
-
adapter_path=args.adapter_path,
|
| 207 |
-
output_dir=args.output_dir,
|
| 208 |
-
total_episodes=args.total_episodes,
|
| 209 |
-
group_size=args.group_size,
|
| 210 |
-
learning_rate=args.learning_rate,
|
| 211 |
-
seed=args.seed,
|
| 212 |
-
use_unsloth=not args.dry_run,
|
| 213 |
-
)
|
|
|
|
| 1 |
+
"""GRPO trainer for the Repair Agent.
|
| 2 |
+
|
| 3 |
+
This wires TRL's GRPOTrainer to ForgeEnvironment via a per-prompt rollout
|
| 4 |
+
function. Each prompt is sampled K times (group size); each sample is
|
| 5 |
+
executed in the env and gets a scalar reward from the visible verifier.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python -m forgeenv.training.grpo_repair \\
|
| 9 |
+
--base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\
|
| 10 |
+
--adapter_path artifacts/checkpoints/repair_agent_sft \\
|
| 11 |
+
--output_dir artifacts/checkpoints/repair_agent_grpo \\
|
| 12 |
+
--total_episodes 200 --group_size 4
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
|
| 22 |
+
from forgeenv.env.forge_environment import ForgeEnvironment
|
| 23 |
+
from forgeenv.roles.drift_generator import BaselineDriftGenerator
|
| 24 |
+
from forgeenv.roles.prompts import (
|
| 25 |
+
DRIFT_GENERATOR_SYSTEM_PROMPT,
|
| 26 |
+
REPAIR_AGENT_SYSTEM_PROMPT,
|
| 27 |
+
render_drift_generator_prompt,
|
| 28 |
+
render_repair_agent_prompt,
|
| 29 |
+
)
|
| 30 |
+
from forgeenv.roles.repair_agent import extract_diff
|
| 31 |
+
from forgeenv.training.rollout import rollout_one_episode
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _build_repair_prompt(env: ForgeEnvironment) -> dict[str, Any]:
|
| 35 |
+
"""Reset env, run baseline drift generator, return a repair-prompt
|
| 36 |
+
dict ready to feed to TRL's GRPOTrainer."""
|
| 37 |
+
drift_gen = BaselineDriftGenerator()
|
| 38 |
+
|
| 39 |
+
obs = env.reset(difficulty="easy")
|
| 40 |
+
drift_user = render_drift_generator_prompt(
|
| 41 |
+
script=obs.script_content,
|
| 42 |
+
target_category=obs.target_category,
|
| 43 |
+
library_versions=obs.library_versions,
|
| 44 |
+
)
|
| 45 |
+
spec = drift_gen.propose(
|
| 46 |
+
target_category=obs.target_category, script=obs.script_content
|
| 47 |
+
)
|
| 48 |
+
from forgeenv.env.actions import BreakageAction, ForgeAction
|
| 49 |
+
|
| 50 |
+
obs2 = env.step(
|
| 51 |
+
ForgeAction(
|
| 52 |
+
breakage=BreakageAction(
|
| 53 |
+
primitive_type=spec["primitive_type"], params=spec["params"]
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
user = render_repair_agent_prompt(
|
| 59 |
+
broken_script=obs2.script_content,
|
| 60 |
+
error_trace=obs2.error_trace or "",
|
| 61 |
+
library_versions=obs2.library_versions,
|
| 62 |
+
target_category=obs2.target_category,
|
| 63 |
+
)
|
| 64 |
+
return {
|
| 65 |
+
"prompt": [
|
| 66 |
+
{"role": "system", "content": REPAIR_AGENT_SYSTEM_PROMPT},
|
| 67 |
+
{"role": "user", "content": user},
|
| 68 |
+
],
|
| 69 |
+
"task_id": obs.task_id,
|
| 70 |
+
"primitive_type": spec["primitive_type"],
|
| 71 |
+
"broken_script": obs2.script_content,
|
| 72 |
+
"drift_user_prompt": drift_user,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def reward_repair_function(
|
| 77 |
+
completions: list, prompts: list = None, **kwargs
|
| 78 |
+
) -> list[float]:
|
| 79 |
+
"""TRL-compatible reward fn: scores a batch of completions against
|
| 80 |
+
a (broken_script, breakage_spec) tuple stored on each example."""
|
| 81 |
+
from forgeenv.env.actions import RepairAction, ForgeAction
|
| 82 |
+
from forgeenv.env.diff_utils import apply_unified_diff
|
| 83 |
+
from forgeenv.sandbox.simulation_mode import SimulationExecutor
|
| 84 |
+
from forgeenv.tasks.task_sampler import TaskSampler
|
| 85 |
+
from forgeenv.verifier.visible_verifier import compute_visible_reward
|
| 86 |
+
|
| 87 |
+
sampler = TaskSampler()
|
| 88 |
+
executor = SimulationExecutor()
|
| 89 |
+
task_ids = kwargs.get("task_id", [None] * len(completions))
|
| 90 |
+
broken_scripts = kwargs.get("broken_script", [""] * len(completions))
|
| 91 |
+
|
| 92 |
+
rewards: list[float] = []
|
| 93 |
+
for completion, task_id, broken in zip(completions, task_ids, broken_scripts):
|
| 94 |
+
if isinstance(completion, list): # chat format
|
| 95 |
+
completion = completion[-1]["content"]
|
| 96 |
+
diff = extract_diff(completion or "")
|
| 97 |
+
repaired = apply_unified_diff(broken, diff) if diff else broken
|
| 98 |
+
task = sampler.get_by_id(task_id) if task_id else None
|
| 99 |
+
if task is None and sampler.tasks:
|
| 100 |
+
task = sampler.tasks[0]
|
| 101 |
+
result = executor.execute(repaired, task)
|
| 102 |
+
result.script_content = repaired
|
| 103 |
+
reward, _ = compute_visible_reward(result, task)
|
| 104 |
+
rewards.append(float(reward))
|
| 105 |
+
return rewards
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def run_grpo(
|
| 109 |
+
base_model: str,
|
| 110 |
+
adapter_path: Optional[str],
|
| 111 |
+
output_dir: str,
|
| 112 |
+
total_episodes: int = 200,
|
| 113 |
+
group_size: int = 4,
|
| 114 |
+
learning_rate: float = 5e-6,
|
| 115 |
+
seed: int = 0,
|
| 116 |
+
use_unsloth: Optional[bool] = None,
|
| 117 |
+
) -> None:
|
| 118 |
+
"""Launch GRPO training (lazy imports to keep this module importable on CPU)."""
|
| 119 |
+
|
| 120 |
+
if use_unsloth is None:
|
| 121 |
+
use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"
|
| 122 |
+
|
| 123 |
+
if not use_unsloth:
|
| 124 |
+
# Dry-run mode: just exercise the prompt building loop and dump rewards.
|
| 125 |
+
env = ForgeEnvironment(seed=seed)
|
| 126 |
+
rewards = []
|
| 127 |
+
for ep in range(total_episodes):
|
| 128 |
+
result = rollout_one_episode(env)
|
| 129 |
+
rewards.append(result.visible_reward)
|
| 130 |
+
if ep % max(1, total_episodes // 10) == 0:
|
| 131 |
+
print(
|
| 132 |
+
                f"[grpo dry-run] ep={ep} reward={result.visible_reward:.3f} "
                f"primitive={result.primitive_type}"
            )
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        (Path(output_dir) / "dry_run_rewards.json").write_text(
            json.dumps(rewards, indent=2)
        )
        print(f"[grpo dry-run] wrote {len(rewards)} rewards to {output_dir}")
        return

    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastLanguageModel
    from peft import PeftModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=4096,
        dtype=None,
        load_in_4bit=True,
    )
    if adapter_path:
        model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)

    env = ForgeEnvironment(seed=seed)
    examples = [_build_repair_prompt(env) for _ in range(total_episodes)]
    dataset = Dataset.from_list(examples)

    grpo_config = GRPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        max_steps=total_episodes,
        num_generations=group_size,
        max_completion_length=1024,
        logging_steps=5,
        save_steps=max(50, total_episodes // 4),
        save_total_limit=2,
        seed=seed,
        report_to="none",
        beta=0.04,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=grpo_config,
        train_dataset=dataset,
        reward_funcs=[reward_repair_function],
    )
    trainer.train()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"[grpo] saved adapter to {output_dir}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit")
    parser.add_argument("--adapter_path", default=None)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dry_run", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    run_grpo(
        base_model=args.base_model,
        adapter_path=args.adapter_path,
        output_dir=args.output_dir,
        total_episodes=args.total_episodes,
        group_size=args.group_size,
        learning_rate=args.learning_rate,
        seed=args.seed,
        use_unsloth=not args.dry_run,
    )
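The `_parse_args` helper above is the whole CLI surface of this trainer. As a sanity check, here is a self-contained sketch of the same flag set (flag names and defaults copied from the listing; the `build_parser` wrapper name is hypothetical) showing what a dry-run invocation resolves to:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the CLI flags of grpo_repair.py's _parse_args (names from the listing).
    parser = argparse.ArgumentParser(description="GRPO repair training (sketch)")
    parser.add_argument("--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit")
    parser.add_argument("--adapter_path", default=None)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dry_run", action="store_true")
    return parser


# Parse an explicit argv list rather than sys.argv so the sketch is reproducible.
args = build_parser().parse_args(["--output_dir", "out", "--dry_run"])
```

Note that in the real script `--dry_run` flips `use_unsloth=False`, so the heavyweight `unsloth`/`trl` imports are only reached on a full training run.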
forgeenv/training/plots.py
CHANGED
"""Matplotlib plotting helpers — produces the 3 PNGs that go into the README.

Plots:
1. baseline_vs_trained.png — bar/line comparison
2. training_reward_curve.png — moving-average reward over episodes
3. success_by_category.png — per-primitive-type success rate

All plots are 600x400 @ 100 dpi, label both axes, and use a colour-blind-safe palette.
"""
from __future__ import annotations

from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

PALETTE = {
    "baseline": "#888888",
    "trained": "#1F77B4",
    "ema": "#D62728",
    "raw": "#1F77B4",
}


def _moving_average(values: list[float], window: int = 10) -> list[float]:
    if not values:
        return []
    out: list[float] = []
    cumsum = 0.0
    for i, v in enumerate(values):
        cumsum += v
        if i >= window:
            cumsum -= values[i - window]
        out.append(cumsum / min(i + 1, window))
    return out


def plot_baseline_vs_trained(
    baseline_rewards: list[float],
    trained_rewards: list[float],
    out_path: str | Path,
    title: str = "ForgeEnv: Baseline vs Trained (50 eval episodes)",
) -> str:
    """Side-by-side bar chart of mean reward + per-episode strip plot."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)

    means = [
        sum(baseline_rewards) / max(1, len(baseline_rewards)),
        sum(trained_rewards) / max(1, len(trained_rewards)),
    ]
    labels = ["Baseline (no-op)", "Trained (GRPO)"]
    colors = [PALETTE["baseline"], PALETTE["trained"]]
    bars = ax.bar(labels, means, color=colors, width=0.5, alpha=0.85)
    ax.bar_label(bars, fmt="%.2f", padding=3)

    for x, rewards in zip([0, 1], [baseline_rewards, trained_rewards]):
        if rewards:
            xs = [x + 0.18] * len(rewards)
            ax.scatter(xs, rewards, s=8, color="black", alpha=0.4, zorder=3)

    ax.set_ylabel("Visible verifier reward")
    ax.set_title(title)
    ax.grid(axis="y", linestyle=":", alpha=0.5)
    ax.set_ylim(bottom=min(0, min(means + baseline_rewards + trained_rewards or [0])))
    fig.tight_layout()
    fig.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    return str(out_path)


def plot_reward_curve(
    rewards: list[float],
    out_path: str | Path,
    window: int = 10,
    title: str = "ForgeEnv: Repair Agent reward over training",
) -> str:
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    xs = list(range(1, len(rewards) + 1))
    ax.plot(xs, rewards, color=PALETTE["raw"], alpha=0.35, linewidth=1.0, label="Per-episode")
    if rewards:
        ax.plot(
            xs,
            _moving_average(rewards, window=window),
            color=PALETTE["ema"],
            linewidth=2.0,
            label=f"Moving avg (w={window})",
        )
    ax.set_xlabel("Episode")
    ax.set_ylabel("Visible verifier reward")
    ax.set_title(title)
    ax.legend(loc="lower right")
    ax.grid(linestyle=":", alpha=0.4)
    fig.tight_layout()
    fig.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    return str(out_path)


def plot_success_rate_by_category(
    by_category: dict[str, list[bool]],
    out_path: str | Path,
    title: str = "ForgeEnv: Repair success by primitive type",
) -> str:
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(7, 4), dpi=100)

    cats = list(by_category.keys())
    rates = [sum(by_category[c]) / max(1, len(by_category[c])) for c in cats]
    bars = ax.barh(cats, rates, color=PALETTE["trained"], alpha=0.85)
    ax.bar_label(bars, fmt="%.2f", padding=3)
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Success rate (held-out: executed_cleanly)")
    ax.set_title(title)
    ax.grid(axis="x", linestyle=":", alpha=0.4)
    fig.tight_layout()
    fig.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    return str(out_path)
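`_moving_average` is the only non-trivial logic in this module: a windowed running sum where the divisor is the window size once warmed up, and the prefix length before that. A dependency-free restatement of the same idea (the standalone name `moving_average` is hypothetical; no matplotlib needed) makes the edge behaviour easy to check:

```python
def moving_average(values: list[float], window: int = 10) -> list[float]:
    # Same windowed-cumsum scheme as _moving_average in plots.py:
    # keep a running sum, evict the element that left the window,
    # and divide by however many elements are actually in the window so far.
    out: list[float] = []
    cumsum = 0.0
    for i, v in enumerate(values):
        cumsum += v
        if i >= window:
            cumsum -= values[i - window]
        out.append(cumsum / min(i + 1, window))
    return out


print(moving_average([1.0, 1.0, 4.0, 4.0], window=2))  # → [1.0, 1.0, 2.5, 4.0]
```

The first two points average only the available prefix, which is why the smoothed curve in `plot_reward_curve` starts at the raw reward rather than dropping to zero.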
forgeenv/training/reward_functions.py
CHANGED
"""Reward functions for both roles, following R-Zero's Algorithm 1.

- Repair Agent (Solver): visible verifier reward (binary-ish with partial credit)
- Drift Generator (Challenger): uncertainty reward + repetition penalty
- Alignment metric: Pearson correlation between visible and held-out scores
"""
from __future__ import annotations

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from sklearn.cluster import AgglomerativeClustering

from forgeenv.tasks.models import ExecutionResult, Task
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores  # re-export
from forgeenv.verifier.visible_verifier import compute_visible_reward

__all__ = [
    "compute_repair_reward",
    "compute_uncertainty_reward",
    "compute_repetition_penalty",
    "compute_drift_gen_reward",
    "compute_alignment_score",
    "compute_held_out_scores",
    "compute_visible_reward",
]


def compute_repair_reward(result: ExecutionResult, task: Task) -> float:
    """Repair Agent reward: visible verifier scalar."""
    reward, _ = compute_visible_reward(result, task)
    return reward


def compute_uncertainty_reward(success_rates: list[bool]) -> float:
    """R-Zero uncertainty reward (Section 2.2, Eq. 1).

    r_uncertainty = 1 - 2 * |p_hat - 0.5|

    Peaks at p_hat = 0.5 (maximum learning signal). Drives the Drift
    Generator to propose breakages exactly at the edge of Repair Agent
    capability.
    """
    if not success_rates:
        return 0.0
    p_hat = sum(success_rates) / len(success_rates)
    return 1.0 - 2.0 * abs(p_hat - 0.5)


def compute_repetition_penalty(
    breakage_text: str,
    batch_breakages: list[str],
    threshold: float = 0.5,
) -> float:
    """R-Zero repetition penalty.

    Cluster batch breakages by 1 - BLEU distance using agglomerative
    clustering, then penalize a target proportional to the size of its
    cluster (encouraging diverse proposals).
    """
    if len(batch_breakages) <= 1:
        return 0.0

    smoother = SmoothingFunction().method1
    n = len(batch_breakages)
    distances = np.ones((n, n), dtype=np.float64)

    for i in range(n):
        distances[i][i] = 0.0
        for j in range(i + 1, n):
            tokens_i = batch_breakages[i].split()
            tokens_j = batch_breakages[j].split()
            if tokens_i and tokens_j:
                bleu = sentence_bleu(
                    [tokens_i], tokens_j, smoothing_function=smoother
                )
                dist = 1.0 - bleu
            else:
                dist = 1.0
            distances[i][j] = dist
            distances[j][i] = dist

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        metric="precomputed",
        linkage="average",
    )
    labels = clustering.fit_predict(distances)

    target_idx = (
        batch_breakages.index(breakage_text) if breakage_text in batch_breakages else 0
    )
    target_cluster = labels[target_idx]
    cluster_size = int(sum(1 for label in labels if label == target_cluster))
    return cluster_size / len(batch_breakages)


def compute_drift_gen_reward(
    breakage_text: str,
    repair_successes: list[bool],
    batch_breakages: list[str],
) -> float:
    """R-Zero composite Challenger reward: max(0, uncertainty - repetition)."""
    uncertainty = compute_uncertainty_reward(repair_successes)
    penalty = compute_repetition_penalty(breakage_text, batch_breakages)
    return max(0.0, uncertainty - penalty)


def compute_alignment_score(
    visible_scores: list[float],
    held_out_scores: list[float],
) -> float:
    """Pearson correlation between visible verifier and held-out scores
    across rollouts. Used to train the Drift Generator to propose
    breakages where the visible verifier tracks ground truth (anti-hacking
    signal: an exploitable visible verifier produces low correlation)."""
    if len(visible_scores) < 2 or len(visible_scores) != len(held_out_scores):
        return 0.0

    v = np.asarray(visible_scores, dtype=np.float64)
    h = np.asarray(held_out_scores, dtype=np.float64)

    if v.std() < 1e-8 or h.std() < 1e-8:
        return 0.0

    correlation = np.corrcoef(v, h)[0, 1]
    return 0.0 if np.isnan(correlation) else float(correlation)
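The uncertainty reward and the composite Challenger reward are plain arithmetic, so they can be checked without numpy, nltk, or sklearn. A standalone sketch (hypothetical names `uncertainty_reward` and `challenger_reward`, same formulas as the listing, with the repetition penalty passed in as a precomputed scalar) shows the signal peaking at a 50% solve rate:

```python
def uncertainty_reward(successes: list[bool]) -> float:
    # r = 1 - 2 * |p_hat - 0.5|: maximal when the Repair Agent solves
    # exactly half the rollouts, zero when it always or never succeeds.
    if not successes:
        return 0.0
    p_hat = sum(successes) / len(successes)
    return 1.0 - 2.0 * abs(p_hat - 0.5)


def challenger_reward(successes: list[bool], repetition_penalty: float) -> float:
    # Composite R-Zero Challenger reward: max(0, uncertainty - repetition).
    return max(0.0, uncertainty_reward(successes) - repetition_penalty)


print(uncertainty_reward([True, False]))          # → 1.0 (p_hat = 0.5)
print(uncertainty_reward([True, True, True]))     # → 0.0 (too easy)
print(challenger_reward([True, False], 0.25))     # → 0.75
```

Because the penalty is a cluster-size fraction in [0, 1], a proposal that is both maximally uncertain and the sole member of its cluster still pays 1/n, so the reward never quite reaches 1.0 in a real batch.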
forgeenv/training/rollout.py
CHANGED
"""Rollout function: connects an LLM to ForgeEnvironment for a full episode.

This is the function the GRPO trainer calls to convert a prompt into a
trajectory + reward. It runs both phases of an episode (drift + repair) by
asking the policy twice with role-switched prompts.
"""
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.roles.drift_generator import (
    BaselineDriftGenerator,
    parse_drift_output,
)
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import BaselineRepairAgent, extract_diff


# Generation function signature: takes a (system, user) prompt pair and
# returns the assistant's completion. We keep this abstract so we can plug
# in TRL's batched generator, vLLM, or our deterministic baseline.
GenerateFn = Callable[[str, str], str]


@dataclass
class RolloutResult:
    task_id: str
    primitive_type: str
    drift_prompt: str
    drift_completion: str
    repair_prompt: str
    repair_completion: str
    visible_reward: float
    visible_breakdown: dict[str, float]
    held_out_breakdown: dict[str, float]
    success: bool
    error_trace: str = ""
    info: dict[str, Any] = field(default_factory=dict)


def _baseline_drift_generate(env: ForgeEnvironment) -> GenerateFn:
    """Wrap our deterministic Drift Generator into a GenerateFn."""

    gen = BaselineDriftGenerator(seed=0)

    def fn(system: str, user: str) -> str:
        target = "RenameApiCall"
        for line in user.splitlines():
            if line.lower().startswith("target category:"):
                target = line.split(":", 1)[1].strip()
                break
        # Try to extract the script body so we can pick a primitive that
        # actually mutates it.
        script_block = ""
        if "```python" in user:
            script_block = user.split("```python", 1)[1].split("```", 1)[0]
        spec = gen.propose(target_category=target, script=script_block)
        return json.dumps(spec)

    return fn


def _baseline_repair_generate() -> GenerateFn:
    """Wrap our deterministic Repair Agent into a GenerateFn.

    Recovering the original script from the user prompt is impossible
    (we don't pass it), so this baseline cannot cheat: when called as a
    baseline it just returns an empty diff. Use baseline_oracle_repair_generate
    (which reads env.state) when you want the oracle path.
    """

    def fn(system: str, user: str) -> str:
        return ""  # baseline = no-op (intentional negative baseline)

    return fn


def rollout_one_episode(
    env: ForgeEnvironment,
    drift_generate: Optional[GenerateFn] = None,
    repair_generate: Optional[GenerateFn] = None,
    difficulty: str = "easy",
) -> RolloutResult:
    """Run a single 2-step episode end-to-end and capture all signals."""
    drift_generate = drift_generate or _baseline_drift_generate(env)
    repair_generate = repair_generate or _baseline_repair_generate()

    obs = env.reset(difficulty=difficulty)
    assert obs.current_phase == "drift_gen"

    # ---------- Phase 1: Drift Generator ----------
    drift_prompt = render_drift_generator_prompt(
        script=obs.script_content,
        target_category=obs.target_category,
        library_versions=obs.library_versions,
    )
    drift_raw = drift_generate(DRIFT_GENERATOR_SYSTEM_PROMPT, drift_prompt)
    spec = parse_drift_output(drift_raw)
    if not spec:
        spec = {"primitive_type": "RenameApiCall", "params": {}}

    breakage_action = ForgeAction(
        breakage=BreakageAction(
            primitive_type=spec.get("primitive_type", "RenameApiCall"),
            params=spec.get("params", {}) or {},
        )
    )
    obs2 = env.step(breakage_action)

    # ---------- Phase 2: Repair Agent ----------
    repair_prompt = render_repair_agent_prompt(
        broken_script=obs2.script_content,
        error_trace=obs2.error_trace or "",
        library_versions=obs2.library_versions,
        target_category=obs2.target_category,
    )
    repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt)
    diff = extract_diff(repair_raw) if repair_raw else ""

    repair_action = ForgeAction(repair=RepairAction(unified_diff=diff))
    obs3 = env.step(repair_action)

    return RolloutResult(
        task_id=obs.task_id,
        primitive_type=spec.get("primitive_type", "RenameApiCall"),
        drift_prompt=drift_prompt,
        drift_completion=drift_raw,
        repair_prompt=repair_prompt,
        repair_completion=repair_raw,
        visible_reward=float(obs3.reward or 0.0),
        visible_breakdown=dict(obs3.reward_breakdown),
        held_out_breakdown=dict(obs3.held_out_breakdown),
        success=bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5),
        error_trace=obs3.error_trace or "",
        info=dict(obs3.info),
    )


def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
    """An "oracle" repair generator that reads the original script from
    `env.state` and emits a perfect diff. Useful for sanity-checking the
    end-to-end loop and as the upper-bound baseline in plots.
    """

    repair_agent = BaselineRepairAgent()

    def fn(system: str, user: str) -> str:
        # Pull the original script out of env state via the task sampler
        task_id = env.state.get("task_id")
        if task_id is None:
            return ""
        task = env.task_sampler.get_by_id(task_id)
        if task is None:
            return ""
        # The current script in env._broken_script is what the user sees.
        broken = env._broken_script  # noqa: SLF001 — internal but oracle-only
        return repair_agent.repair(
            broken,
            breakage_spec=env._breakage_spec,  # noqa: SLF001
            original_script=task.script_content,
        )

    return fn
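Since `GenerateFn` is just `Callable[[str, str], str]`, anything that maps a (system, user) prompt pair to a completion can drive a rollout — including a fixed stub that emits the JSON spec shape `parse_drift_output` expects. A minimal hypothetical stand-in (the stub itself is not part of the repo):

```python
import json

# A GenerateFn stand-in: ignores both prompts and always proposes the same
# breakage, in the {"primitive_type": ..., "params": ...} shape that
# parse_drift_output / rollout_one_episode expect from a drift completion.
def stub_drift_generate(system: str, user: str) -> str:
    return json.dumps({"primitive_type": "RenameApiCall", "params": {}})


spec = json.loads(stub_drift_generate("sys-prompt", "user-prompt"))
```

This is handy for exercising the environment plumbing in tests without loading a model, exactly as `_baseline_drift_generate` does with its deterministic generator.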
| 110 |
+
spec = {"primitive_type": "RenameApiCall", "params": {}}
|
| 111 |
+
|
| 112 |
+
breakage_action = ForgeAction(
|
| 113 |
+
breakage=BreakageAction(
|
| 114 |
+
primitive_type=spec.get("primitive_type", "RenameApiCall"),
|
| 115 |
+
params=spec.get("params", {}) or {},
|
| 116 |
+
)
|
| 117 |
+
)
|
| 118 |
+
obs2 = env.step(breakage_action)
|
| 119 |
+
|
| 120 |
+
# ---------- Phase 2: Repair Agent ----------
|
| 121 |
+
repair_prompt = render_repair_agent_prompt(
|
| 122 |
+
broken_script=obs2.script_content,
|
| 123 |
+
error_trace=obs2.error_trace or "",
|
| 124 |
+
library_versions=obs2.library_versions,
|
| 125 |
+
target_category=obs2.target_category,
|
| 126 |
+
)
|
| 127 |
+
repair_raw = repair_generate(REPAIR_AGENT_SYSTEM_PROMPT, repair_prompt)
|
| 128 |
+
diff = extract_diff(repair_raw) if repair_raw else ""
|
| 129 |
+
|
| 130 |
+
repair_action = ForgeAction(repair=RepairAction(unified_diff=diff))
|
| 131 |
+
obs3 = env.step(repair_action)
|
| 132 |
+
|
| 133 |
+
return RolloutResult(
|
| 134 |
+
task_id=obs.task_id,
|
| 135 |
+
primitive_type=spec.get("primitive_type", "RenameApiCall"),
|
| 136 |
+
drift_prompt=drift_prompt,
|
| 137 |
+
drift_completion=drift_raw,
|
| 138 |
+
repair_prompt=repair_prompt,
|
| 139 |
+
repair_completion=repair_raw,
|
| 140 |
+
visible_reward=float(obs3.reward or 0.0),
|
| 141 |
+
visible_breakdown=dict(obs3.reward_breakdown),
|
| 142 |
+
held_out_breakdown=dict(obs3.held_out_breakdown),
|
| 143 |
+
success=bool(obs3.held_out_breakdown.get("executed_cleanly", 0.0) > 0.5),
|
| 144 |
+
error_trace=obs3.error_trace or "",
|
| 145 |
+
info=dict(obs3.info),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def baseline_oracle_repair_generate(env: ForgeEnvironment) -> GenerateFn:
|
| 150 |
+
"""An "oracle" repair generator that reads the original script from
|
| 151 |
+
`env.state` and emits a perfect diff. Useful for sanity-checking the
|
| 152 |
+
end-to-end loop and as the upper-bound baseline in plots.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
repair_agent = BaselineRepairAgent()
|
| 156 |
+
|
| 157 |
+
def fn(system: str, user: str) -> str:
|
| 158 |
+
# Pull the original script out of env state via the task sampler
|
| 159 |
+
task_id = env.state.get("task_id")
|
| 160 |
+
if task_id is None:
|
| 161 |
+
return ""
|
| 162 |
+
task = env.task_sampler.get_by_id(task_id)
|
| 163 |
+
if task is None:
|
| 164 |
+
return ""
|
| 165 |
+
# The current script in env._broken_script is what the user sees.
|
| 166 |
+
broken = env._broken_script # noqa: SLF001 — internal but oracle-only
|
| 167 |
+
return repair_agent.repair(
|
| 168 |
+
broken,
|
| 169 |
+
breakage_spec=env._breakage_spec, # noqa: SLF001
|
| 170 |
+
original_script=task.script_content,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return fn
|
forgeenv/training/sft_warmstart.py
CHANGED
"""SFT warm-start trainer for both roles.

Run on a Colab T4/A100 GPU. Reads `warmstart/data/repair_pairs.jsonl` (or
`drift_pairs.jsonl`), wraps in TRL SFTTrainer with Unsloth's 4-bit Qwen2.5
loader, and saves a LoRA adapter.

Usage:
    python -m forgeenv.training.sft_warmstart \\
        --role repair_agent \\
        --data warmstart/data/repair_pairs.jsonl \\
        --output_dir artifacts/checkpoints/repair_agent_sft \\
        --base_model unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit \\
        --max_steps 200
"""
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Optional


def _load_jsonl(path: str) -> list[dict]:
    rows: list[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _format_chat(rows: list[dict]) -> list[dict]:
    """Flatten messages -> a single `text` field for SFT."""
    out: list[dict] = []
    for row in rows:
        msgs = row["messages"]
        text_parts = []
        for m in msgs:
            text_parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
        out.append({"text": "\n".join(text_parts)})
    return out


def run_sft(
    role: str,
    data_path: str,
    output_dir: str,
    base_model: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
    max_steps: int = 200,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
    lora_r: int = 16,
    seed: int = 0,
    use_unsloth: Optional[bool] = None,
) -> None:
    """Run SFT. Imports unsloth/trl lazily so this module is importable on
    machines without a GPU."""
    rows = _load_jsonl(data_path)
    formatted = _format_chat(rows)
    print(f"[forgeenv.sft] Loaded {len(formatted)} rows for role={role}")

    if use_unsloth is None:
        use_unsloth = os.environ.get("FORGEENV_USE_UNSLOTH", "1") == "1"

    if use_unsloth:
        from unsloth import FastLanguageModel
        from datasets import Dataset
        from trl import SFTConfig, SFTTrainer

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=4096,
            dtype=None,
            load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_r,
            lora_alpha=lora_r * 2,
            lora_dropout=0.0,
            bias="none",
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )

        dataset = Dataset.from_list(formatted)
        sft_config = SFTConfig(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            warmup_steps=10,
            max_steps=max_steps,
            learning_rate=learning_rate,
            logging_steps=10,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            save_steps=max(50, max_steps // 4),
            save_total_limit=2,
            report_to="none",
            dataset_text_field="text",
            max_seq_length=4096,
        )
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            args=sft_config,
        )
        trainer.train()
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        print(f"[forgeenv.sft] Saved adapter to {output_dir}")
        return

    # CPU/dry-run fallback: just dump the formatted dataset to disk so we
    # can verify the pipeline shape locally.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    out_file = Path(output_dir) / "formatted_dataset.jsonl"
    with out_file.open("w", encoding="utf-8") as f:
        for row in formatted:
            f.write(json.dumps(row) + "\n")
    print(f"[forgeenv.sft] (dry run) wrote {len(formatted)} rows to {out_file}")


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--role", choices=["repair_agent", "drift_generator"], required=True
    )
    parser.add_argument("--data", required=True, help="Path to JSONL warm-start file")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--base_model", default="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit"
    )
    parser.add_argument("--max_steps", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dry_run", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    run_sft(
        role=args.role,
        data_path=args.data,
        output_dir=args.output_dir,
        base_model=args.base_model,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lora_r=args.lora_r,
        seed=args.seed,
        use_unsloth=not args.dry_run,
    )
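The `_format_chat` step above is the one place where chat rows are serialized for the tokenizer, so it is worth seeing the exact string it produces. A dependency-free sketch of the same flattening (ChatML-style `<|im_start|>`/`<|im_end|>` markers, as used by Qwen2.5):

```python
# Standalone sketch of `_format_chat`: each chat row becomes one `text`
# string, with every message wrapped in ChatML role markers.
def format_chat(rows):
    out = []
    for row in rows:
        parts = [
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>"
            for m in row["messages"]
        ]
        out.append({"text": "\n".join(parts)})
    return out


rows = [{"messages": [
    {"role": "user", "content": "fix the script"},
    {"role": "assistant", "content": "--- a/train.py"},
]}]
print(format_chat(rows)[0]["text"])
```

Note this bakes the chat template into the data rather than relying on the tokenizer's own `apply_chat_template`, so the markers must match the base model's template exactly.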
forgeenv/verifier/held_out_evaluator.py
CHANGED
"""Held-out evaluator: the deterministic ground-truth scorer.

Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees
this directly; the Drift Generator's training signal derives from
alignment between the visible verifier and this evaluator (Pearson
correlation across the K rollouts).
"""
from __future__ import annotations

import ast
import re

from forgeenv.tasks.models import ExecutionResult, Task


def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Compute 7 independent held-out components."""

    scores: dict[str, float] = {
        "executed_cleanly": 1.0 if result.exit_code == 0 else 0.0,
        "checkpoint_valid": 1.0 if result.checkpoint_exists else 0.0,
        "loss_decreased": _compute_loss_score(result.stdout),
        "metrics_in_range": _check_metrics(result.stdout, task),
        "no_forbidden_workarounds": _check_workarounds(result.script_content),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, result.script_content
        ),
        "hidden_tests_passed": 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0,
    }
    return scores


def _compute_loss_score(stdout: str) -> float:
    """Continuous score based on relative loss decrease from first to last step."""

    losses: list[float] = []
    for line in stdout.splitlines():
        match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                losses.append(float(match.group(1)))
            except ValueError:
                continue

    if len(losses) < 2:
        return 0.0

    decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8)
    return max(0.0, min(1.0, decrease))


def _check_metrics(stdout: str, task: Task) -> float:
    """Return 1.0 if any reported accuracy/eval metric falls in the task's
    expected range; partial credit otherwise; 0.5 if no metric was found."""

    for line in stdout.splitlines():
        match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                val = float(match.group(1))
                low, high = task.expected_accuracy_range
                if low <= val <= high:
                    return 1.0
                distance = min(abs(val - low), abs(val - high))
                return max(0.0, 1.0 - distance)
            except ValueError:
                continue
    return 0.5


def _check_workarounds(script_content: str) -> float:
    """Detect forbidden workaround patterns via AST analysis.

    Catches: bare except, `except Exception: pass`, `except Exception: return`,
    monkey-patching of `__getattr__` / `__class__` / `__dict__`.
    """

    if not script_content:
        return 0.0

    try:
        tree = ast.parse(script_content)
    except SyntaxError:
        return 0.0

    violations = 0

    for node in ast.walk(tree):
        if isinstance(node, ast.Try):
            for handler in node.handlers:
                if handler.type is None:
                    violations += 1
                elif (
                    isinstance(handler.type, ast.Name)
                    and handler.type.id == "Exception"
                ):
                    if len(handler.body) == 1 and isinstance(
                        handler.body[0], (ast.Pass, ast.Return)
                    ):
                        violations += 1

        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Attribute):
                    if target.attr in ("__getattr__", "__class__", "__dict__"):
                        violations += 1

    return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)


def _compute_intent_preservation(original: str, repaired: str) -> float:
    """Measure how much of the original AST structure is preserved.

    Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...).
    """

    if not original or not repaired:
        return 0.0

    try:
        orig_tree = ast.parse(original)
        repair_tree = ast.parse(repaired)
    except SyntaxError:
        return 0.0

    orig_nodes = len(list(ast.walk(orig_tree)))
    repair_nodes = len(list(ast.walk(repair_tree)))

    if orig_nodes == 0:
        return 0.0

    return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)
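The workaround detector is the subtlest of the seven components, since it has to distinguish legitimate error handling from reward-hacking patterns. A self-contained sketch of the exception-handler half of that AST walk (the monkey-patch check is omitted here for brevity), with the same 0.3-per-violation penalty:

```python
import ast

# Standalone sketch of `_check_workarounds`: count bare `except:` handlers
# and single-statement `except Exception: pass/return` handlers, then score
# 1.0 minus 0.3 per violation, floored at 0.
def workaround_score(src: str) -> float:
    try:
        tree = ast.parse(src)
    except SyntaxError:
        return 0.0
    violations = 0
    for node in ast.walk(tree):
        if isinstance(node, ast.Try):
            for h in node.handlers:
                if h.type is None:  # bare `except:`
                    violations += 1
                elif (
                    isinstance(h.type, ast.Name)
                    and h.type.id == "Exception"
                    and len(h.body) == 1
                    and isinstance(h.body[0], (ast.Pass, ast.Return))
                ):
                    violations += 1
    return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)


clean = "x = 1\ntry:\n    x += 1\nexcept ValueError:\n    x = 0\n"
swallow = "try:\n    run()\nexcept Exception:\n    pass\n"
print(workaround_score(clean))              # 1.0
print(round(workaround_score(swallow), 2))  # 0.7
```

Catching a specific exception type and actually handling it passes; swallowing `Exception` with a lone `pass` is exactly the "make the script exit 0 without fixing anything" hack the held-out evaluator is meant to punish.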
forgeenv/verifier/visible_verifier.py
CHANGED
|
@@ -1,64 +1,64 @@
|
|
| 1 |
-
"""Visible verifier: the immediate reward signal the Repair Agent sees.
|
| 2 |
-
|
| 3 |
-
4 weighted components, summed to a scalar. This is what drives the Repair
|
| 4 |
-
Agent's GRPO updates each rollout. Multiple independent components were
|
| 5 |
-
chosen on purpose, per the reward-engineering survey (arxiv 2408.10215)
|
| 6 |
-
and software-tasks survey (arxiv 2601.19100): a single scalar is far
|
| 7 |
-
easier to game than a composable rubric.
|
| 8 |
-
"""
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import re
|
| 12 |
-
|
| 13 |
-
from forgeenv.tasks.models import ExecutionResult, Task
|
| 14 |
-
|
| 15 |
-
WEIGHTS: dict[str, float] = {
|
| 16 |
-
"script_executes": 1.0,
|
| 17 |
-
"loss_decreased": 0.5,
|
| 18 |
-
"checkpoint_appeared": 0.3,
|
| 19 |
-
"diff_size_penalty": 0.2, # multiplied with a non-positive component value
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def compute_visible_reward(
|
| 24 |
-
result: ExecutionResult, task: Task
|
| 25 |
-
) -> tuple[float, dict[str, float]]:
|
| 26 |
-
"""Compute scalar visible reward and per-component breakdown."""
|
| 27 |
-
|
| 28 |
-
components: dict[str, float] = {}
|
| 29 |
-
|
| 30 |
-
components["script_executes"] = 1.0 if result.exit_code == 0 else 0.0
|
| 31 |
-
components["loss_decreased"] = _check_loss_trend(result.stdout)
|
| 32 |
-
components["checkpoint_appeared"] = 1.0 if result.checkpoint_exists else 0.0
|
| 33 |
-
|
| 34 |
-
original_lines = max(len(task.script_content.splitlines()), 1)
|
| 35 |
-
current_lines = (
|
| 36 |
-
len(result.script_content.splitlines()) if result.script_content else original_lines
|
| 37 |
-
)
|
| 38 |
-
diff_ratio = abs(current_lines - original_lines) / original_lines
|
| 39 |
-
components["diff_size_penalty"] = -1.0 * diff_ratio if diff_ratio > 0.5 else 0.0
|
| 40 |
-
|
| 41 |
-
total = sum(components[k] * WEIGHTS[k] for k in components)
|
| 42 |
-
return total, components
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _check_loss_trend(stdout: str) -> float:
|
| 46 |
-
"""Parse stdout for `loss=...` patterns and return the fraction of
|
| 47 |
-
```python
"""Visible verifier: the immediate reward signal the Repair Agent sees.

4 weighted components, summed to a scalar. This is what drives the Repair
Agent's GRPO updates each rollout. Multiple independent components were
chosen on purpose, per the reward-engineering survey (arxiv 2408.10215)
and software-tasks survey (arxiv 2601.19100): a single scalar is far
easier to game than a composable rubric.
"""
from __future__ import annotations

import re

from forgeenv.tasks.models import ExecutionResult, Task

WEIGHTS: dict[str, float] = {
    "script_executes": 1.0,
    "loss_decreased": 0.5,
    "checkpoint_appeared": 0.3,
    "diff_size_penalty": 0.2,  # multiplied with a non-positive component value
}


def compute_visible_reward(
    result: ExecutionResult, task: Task
) -> tuple[float, dict[str, float]]:
    """Compute scalar visible reward and per-component breakdown."""

    components: dict[str, float] = {}

    components["script_executes"] = 1.0 if result.exit_code == 0 else 0.0
    components["loss_decreased"] = _check_loss_trend(result.stdout)
    components["checkpoint_appeared"] = 1.0 if result.checkpoint_exists else 0.0

    original_lines = max(len(task.script_content.splitlines()), 1)
    current_lines = (
        len(result.script_content.splitlines()) if result.script_content else original_lines
    )
    diff_ratio = abs(current_lines - original_lines) / original_lines
    components["diff_size_penalty"] = -1.0 * diff_ratio if diff_ratio > 0.5 else 0.0

    total = sum(components[k] * WEIGHTS[k] for k in components)
    return total, components


def _check_loss_trend(stdout: str) -> float:
    """Parse stdout for `loss=...` patterns and return the fraction of
    consecutive steps where loss strictly decreased."""

    losses: list[float] = []
    for line in stdout.splitlines():
        match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
        if match:
            try:
                losses.append(float(match.group(1)))
            except ValueError:
                continue

    if len(losses) < 2:
        return 0.0

    decreasing_steps = sum(
        1 for i in range(1, len(losses)) if losses[i] < losses[i - 1]
    )
    return decreasing_steps / (len(losses) - 1)
```
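As a minimal, self-contained sketch of how the four-component rubric collapses into one scalar: the `StubResult` dataclass below is hypothetical, standing in for the real `ExecutionResult`/`Task` models (which carry stdout and script text rather than precomputed ratios), but the weights and the mixing logic mirror the module above.

```python
from dataclasses import dataclass

# Same weights as the module's WEIGHTS table.
WEIGHTS = {"script_executes": 1.0, "loss_decreased": 0.5,
           "checkpoint_appeared": 0.3, "diff_size_penalty": 0.2}

@dataclass
class StubResult:
    """Hypothetical stand-in for ExecutionResult + Task; ratios precomputed."""
    exit_code: int
    checkpoint_exists: bool
    loss_trend: float   # fraction of consecutive steps where loss decreased
    diff_ratio: float   # |current - original| / original line count

def visible_reward(r: StubResult) -> float:
    components = {
        "script_executes": 1.0 if r.exit_code == 0 else 0.0,
        "loss_decreased": r.loss_trend,
        "checkpoint_appeared": 1.0 if r.checkpoint_exists else 0.0,
        # Penalty only kicks in once the edit grows beyond half the script.
        "diff_size_penalty": -r.diff_ratio if r.diff_ratio > 0.5 else 0.0,
    }
    return sum(components[k] * WEIGHTS[k] for k in components)

# Clean run: 1.0*1.0 + 0.8*0.5 + 1.0*0.3 + 0.0 = 1.7
reward = visible_reward(StubResult(0, True, 0.8, 0.1))
print(reward)  # 1.7
```

Because each component is bounded and independently weighted, an agent that games one signal (e.g. printing fake checkpoints) still leaves the rest of the rubric unmoved.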
openenv.yaml
CHANGED

```yaml
name: forgeenv
version: 0.1.0
description: >
  Self-improving RL environment for HuggingFace ecosystem repair.
  Trains agents to fix broken training scripts under library version drift
  through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL,
  and Absolute Zero Reasoner techniques on top of OpenEnv.
theme: self-improvement
tags:
  - openenv
  - self-play
  - code-repair
  - schema-drift
  - multi-role
  - huggingface
  - reinforcement-learning
environment:
  class: forgeenv.env.forge_environment.ForgeEnvironment
  action_model: forgeenv.env.actions.ForgeAction
  observation_model: forgeenv.env.observations.ForgeObservation
server:
  module: forgeenv.env.server
  app: app
```
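The `environment` and `server` entries in this manifest are dotted import paths. A loader could turn them into live objects roughly as below; `resolve` is a hypothetical helper for illustration, not part of OpenEnv, and it is demonstrated on a stdlib path since `forgeenv` may not be installed where this runs.

```python
import importlib

def resolve(dotted: str):
    """Split 'pkg.module.Attr' at the last dot and fetch Attr from the module."""
    module_path, _, attr = dotted.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

# e.g. resolve("forgeenv.env.forge_environment.ForgeEnvironment") would import
# forgeenv.env.forge_environment and return its ForgeEnvironment class.
cls = resolve("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```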