ForgeEnv deploy
Browse files- Dockerfile +25 -0
- README.md +85 -10
- forgeenv/__init__.py +4 -0
- forgeenv/artifacts/repair_library.py +120 -0
- forgeenv/drift/__init__.py +0 -0
- forgeenv/drift/library_drift_engine.py +74 -0
- forgeenv/env/__init__.py +0 -0
- forgeenv/env/actions.py +50 -0
- forgeenv/env/diff_utils.py +163 -0
- forgeenv/env/forge_environment.py +259 -0
- forgeenv/env/observations.py +29 -0
- forgeenv/env/server.py +46 -0
- forgeenv/primitives/__init__.py +0 -0
- forgeenv/primitives/breakage_primitives.py +282 -0
- forgeenv/primitives/drift_taxonomy.yaml +217 -0
- forgeenv/primitives/repair_primitives.py +241 -0
- forgeenv/roles/__init__.py +0 -0
- forgeenv/roles/drift_generator.py +170 -0
- forgeenv/roles/prompts.py +102 -0
- forgeenv/roles/repair_agent.py +153 -0
- forgeenv/roles/teacher.py +58 -0
- forgeenv/sandbox/__init__.py +0 -0
- forgeenv/sandbox/ast_validator.py +70 -0
- forgeenv/sandbox/simulation_mode.py +142 -0
- forgeenv/tasks/__init__.py +0 -0
- forgeenv/tasks/models.py +45 -0
- forgeenv/tasks/seed_corpus/__init__.py +0 -0
- forgeenv/tasks/seed_corpus/albert_qa.py +67 -0
- forgeenv/tasks/seed_corpus/bert_ner.py +55 -0
- forgeenv/tasks/seed_corpus/distilbert_sst2.py +53 -0
- forgeenv/tasks/seed_corpus/electra_classification.py +44 -0
- forgeenv/tasks/seed_corpus/gpt2_textgen.py +43 -0
- forgeenv/tasks/seed_corpus/logistic_classifier.py +36 -0
- forgeenv/tasks/seed_corpus/roberta_sentiment.py +44 -0
- forgeenv/tasks/seed_corpus/simple_regression.py +28 -0
- forgeenv/tasks/seed_corpus/t5_summarization.py +55 -0
- forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py +38 -0
- forgeenv/tasks/seed_corpus/vit_cifar10.py +41 -0
- forgeenv/tasks/task_sampler.py +105 -0
- forgeenv/verifier/__init__.py +0 -0
- forgeenv/verifier/held_out_evaluator.py +134 -0
- forgeenv/verifier/visible_verifier.py +64 -0
- openenv.yaml +24 -0
- requirements.txt +9 -0
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Slim Python base keeps the image small; 3.11 supports the `X | Y` union
# syntax used throughout the forgeenv package.
FROM python:3.11-slim

# Unbuffered stdout/stderr for live container logs; skip .pyc files and the
# pip download cache to keep the layer lean.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1

# git is needed for VCS-pinned requirements; curl is used by the HEALTHCHECK
# below. Removing the apt lists keeps the layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends git curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first so dependency installation is cached independently
# of source-code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY forgeenv/ forgeenv/
COPY openenv.yaml .

# 7860 is the conventional HuggingFace Spaces port (matches app_port in README).
ENV PORT=7860
EXPOSE 7860

# Probe the FastAPI /health endpoint; start-period gives uvicorn time to boot.
HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
    CMD curl -f http://127.0.0.1:7860/health || exit 1

# Exec-form CMD so uvicorn receives signals directly (clean shutdown).
CMD ["uvicorn", "forgeenv.env.server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,85 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ForgeEnv
|
| 3 |
+
emoji: 🔧
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: true
|
| 9 |
+
license: apache-2.0
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- self-play
|
| 13 |
+
- self-improvement
|
| 14 |
+
- code-repair
|
| 15 |
+
- schema-drift
|
| 16 |
+
- reinforcement-learning
|
| 17 |
+
- huggingface
|
| 18 |
+
short_description: Self-improving RL env for HF library-drift repair
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
# ForgeEnv — OpenEnv Server
|
| 22 |
+
|
| 23 |
+
This Space hosts the **ForgeEnv** OpenEnv-compliant environment as a FastAPI
|
| 24 |
+
service. It exposes the standard `reset`, `step`, and `state` endpoints and is
|
| 25 |
+
the runtime that training notebooks (TRL + Unsloth) connect to.
|
| 26 |
+
|
| 27 |
+
> **Theme:** Self-Improvement (Hackathon Theme #4) — Challenger / Solver
|
| 28 |
+
> co-evolution via R-Zero, SPIRAL, and Absolute Zero Reasoner techniques.
|
| 29 |
+
|
| 30 |
+
## What it does
|
| 31 |
+
|
| 32 |
+
ForgeEnv simulates **HuggingFace library version drift**. A *Drift Generator*
|
| 33 |
+
proposes a realistic breakage to a working training script (renamed APIs,
|
| 34 |
+
deprecated imports, changed argument signatures, etc.). A *Repair Agent* then
|
| 35 |
+
emits a unified diff that should restore the script. Reward is computed by an
|
| 36 |
+
execution simulator + AST checker + held-out evaluator (multi-component to
|
| 37 |
+
resist reward hacking).
|
| 38 |
+
|
| 39 |
+
## API
|
| 40 |
+
|
| 41 |
+
The server uses [`openenv-core`](https://pypi.org/project/openenv-core/) and
|
| 42 |
+
follows the Gym-style contract:
|
| 43 |
+
|
| 44 |
+
| Endpoint | Method | Purpose |
|
| 45 |
+
| -------- | ------ | -------------------------------------------------- |
|
| 46 |
+
| `/reset` | POST | Sample a fresh task, return drift-gen observation |
|
| 47 |
+
| `/step` | POST | Apply a `ForgeAction` (breakage or repair) |
|
| 48 |
+
| `/state` | GET | Inspect the current internal state |
|
| 49 |
+
| `/health`| GET | Health probe (used by the container HEALTHCHECK) |
|
| 50 |
+
|
| 51 |
+
`ForgeAction` is a discriminated union of `BreakageAction` (used in phase 1)
|
| 52 |
+
and `RepairAction` (used in phase 2). See
|
| 53 |
+
[`forgeenv/env/actions.py`](forgeenv/env/actions.py).
|
| 54 |
+
|
| 55 |
+
## Quick test
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
curl -X POST https://akhiilll-forgeenv.hf.space/reset
|
| 59 |
+
curl https://akhiilll-forgeenv.hf.space/state
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from openenv.core.env_client import EnvClient
|
| 64 |
+
|
| 65 |
+
async with EnvClient(base_url="https://akhiilll-forgeenv.hf.space") as client:
|
| 66 |
+
obs = await client.reset()
|
| 67 |
+
print(obs.observation.current_phase, obs.observation.task_id)
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Project links
|
| 71 |
+
|
| 72 |
+
- **Main repo / training notebooks / plots:**
|
| 73 |
+
<https://github.com/akhiilll/forgeenv>
|
| 74 |
+
- **Repair Agent model (LoRA):**
|
| 75 |
+
<https://huggingface.co/akhiilll/forgeenv-repair-agent>
|
| 76 |
+
- **Demo (Gradio + ZeroGPU):**
|
| 77 |
+
<https://huggingface.co/spaces/akhiilll/forgeenv-demo>
|
| 78 |
+
|
| 79 |
+
## Citations
|
| 80 |
+
|
| 81 |
+
- Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025)
|
| 82 |
+
- Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025)
|
| 83 |
+
- Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025)
|
| 84 |
+
- [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) — Reward engineering & shaping
|
| 85 |
+
- [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) — Reward engineering for RL in software tasks
|
forgeenv/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ForgeEnv: Self-improving RL environment for HuggingFace ecosystem repair."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
| 4 |
+
__author__ = "akhiilll"
|
forgeenv/artifacts/repair_library.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Persisted "repair library" — the model's accumulated knowledge of
|
| 2 |
+
known breakage -> repair pairs. Curated from successful rollouts during
|
| 3 |
+
training. Loaded at inference time as a few-shot prefix when the agent
|
| 4 |
+
recognises a familiar error class.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations

import json
from collections import Counter
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class RepairExample:
|
| 16 |
+
primitive_type: str
|
| 17 |
+
breakage_params: dict[str, Any]
|
| 18 |
+
error_signature: str
|
| 19 |
+
repair_diff: str
|
| 20 |
+
visible_reward: float
|
| 21 |
+
held_out: dict[str, float]
|
| 22 |
+
task_id: str = ""
|
| 23 |
+
|
| 24 |
+
def signature_key(self) -> str:
|
| 25 |
+
return f"{self.primitive_type}::{self.error_signature[:80]}"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class RepairLibrary:
    """Growable collection of RepairExample records.

    Supports lookup of the best-matching prior repair for a fresh error and
    JSON (de)serialisation to disk.
    """

    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append one curated example to the library."""
        self.examples.append(example)

    def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]:
        """Return the highest-reward example whose primitive_type matches and
        whose error text overlaps.

        Returns None when no example of that primitive type exists.
        """
        candidates = [
            e for e in self.examples if e.primitive_type == primitive_type
        ]
        if not candidates:
            return None
        # max() with the same (overlap, reward) key replaces the previous
        # full descending sort + [0]: O(n) instead of O(n log n), and both
        # pick the first maximal element, so the winner is identical.
        return max(
            candidates,
            key=lambda e: (
                _ngram_overlap(e.error_signature, error_text),
                e.visible_reward,
            ),
        )

    def to_dict(self) -> dict:
        """JSON-serialisable snapshot including per-primitive counts."""
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        """Write the library as pretty-printed JSON, creating parent dirs."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Inverse of save(); missing 'examples' key yields an empty library."""
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        examples = [RepairExample(**e) for e in data.get("examples", [])]
        return cls(examples=examples)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
|
| 74 |
+
if not a or not b:
|
| 75 |
+
return 0.0
|
| 76 |
+
|
| 77 |
+
def grams(text: str) -> set[str]:
|
| 78 |
+
text = text.lower()
|
| 79 |
+
return {text[i : i + n] for i in range(len(text) - n + 1)}
|
| 80 |
+
|
| 81 |
+
ga, gb = grams(a), grams(b)
|
| 82 |
+
if not ga or not gb:
|
| 83 |
+
return 0.0
|
| 84 |
+
return len(ga & gb) / max(1, len(ga | gb))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
|
| 88 |
+
counts: dict[str, int] = {}
|
| 89 |
+
for e in examples:
|
| 90 |
+
counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1
|
| 91 |
+
return counts
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    Keeps only rollouts whose visible reward >= min_reward and whose
    held-out "executed_cleanly" score >= min_held_out_clean.

    Fixes over the previous version:
    - the per-record accessor is a named helper taking the record explicitly,
      instead of a lambda re-created each iteration that closed over the loop
      variable (ruff B023 smell);
    - keys that exist but hold None (held_out_breakdown, info, breakage_spec)
      no longer crash the filter with an AttributeError — the `or {}` guard
      the original applied only at `held_out=` is now applied consistently.
    """

    def _get(record: Any, key: str, default: Any = None) -> Any:
        # Works for both plain dicts and attribute-style result objects.
        if isinstance(record, dict):
            return record.get(key, default)
        return getattr(record, key, default)

    lib = RepairLibrary()
    for r in rollout_results:
        visible = float(_get(r, "visible_reward", 0.0) or 0.0)
        if visible < min_reward:
            continue
        held_out = dict(_get(r, "held_out_breakdown", {}) or {})
        if float(held_out.get("executed_cleanly", 0.0)) < min_held_out_clean:
            continue
        info = _get(r, "info", {})
        if not isinstance(info, dict):
            info = {}
        breakage_spec = info.get("breakage_spec") or {}
        lib.add(
            RepairExample(
                primitive_type=str(_get(r, "primitive_type", "unknown")),
                breakage_params=dict(breakage_spec.get("params", {}) or {}),
                error_signature=str(_get(r, "error_trace", "") or "")[:160],
                repair_diff=str(
                    _get(r, "repair_completion", "") or info.get("repair_diff", "")
                )[:2000],
                visible_reward=visible,
                held_out=held_out,
                task_id=str(_get(r, "task_id", "")),
            )
        )
    return lib
|
forgeenv/drift/__init__.py
ADDED
|
File without changes
|
forgeenv/drift/library_drift_engine.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Library Drift Engine.
|
| 2 |
+
|
| 3 |
+
Manages library version snapshots and triggers version upgrades during
|
| 4 |
+
training to create non-stationary verification. In simulation mode it
|
| 5 |
+
just tracks the current snapshot index — that index influences
|
| 6 |
+
breakage selection and is exposed in observations so the Repair Agent
|
| 7 |
+
can adapt.
|
| 8 |
+
|
| 9 |
+
Also exposes Chojecki GVU's SNR computation
|
| 10 |
+
(https://arxiv.org/abs/2512.02731 Definition 4.4).
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
|
| 17 |
+
# Ordered library-version snapshots; the engine walks forward through this
# list as training progresses.
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
    {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
    {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
    {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
    {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
]


@dataclass
class LibraryDriftEngine:
    """Tracks which library-version snapshot is currently "live" and advances
    it periodically to create non-stationary verification during training."""

    snapshots: list[dict[str, str]] = field(
        default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS)
    )
    current_index: int = 0
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        """Shallow copy of the active package -> version mapping."""
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        """Advance one snapshot on every drift_every-th episode.

        Returns True only when an actual advance happened (never past the
        last snapshot, never on episode 0).
        """
        # Guard clauses replace the original single compound condition.
        if episode_num <= 0 or episode_num % drift_every != 0:
            return False
        if self.current_index >= len(self.snapshots) - 1:
            return False
        previous = self.snapshots[self.current_index]
        self.current_index += 1
        self.drift_history.append(
            {
                "episode": episode_num,
                "from": previous,
                "to": self.snapshots[self.current_index],
            }
        )
        return True

    def reset(self) -> None:
        """Rewind to the first snapshot and clear the recorded history."""
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards)."""

        def signal_to_noise(samples: list[float]) -> float:
            count = len(samples)
            if count < 2:
                return 0.0
            mu = sum(samples) / count
            variance = sum((s - mu) ** 2 for s in samples) / count
            # Epsilon floor keeps zero-variance batches finite.
            return mu**2 / max(variance, 1e-8)

        return {
            "snr_verifier": signal_to_noise(recent_held_out),
            "snr_generator": signal_to_noise(recent_visible),
        }
|
forgeenv/env/__init__.py
ADDED
|
File without changes
|
forgeenv/env/actions.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic action models for ForgeEnv (compatible with OpenEnv 0.2.x).
|
| 2 |
+
|
| 3 |
+
Episodes have two phases — drift_gen (Challenger) and repair (Solver) — so
|
| 4 |
+
we expose a single union ForgeAction that carries either a BreakageAction
|
| 5 |
+
or a RepairAction. The environment dispatches on which sub-field is set.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Literal, Optional
|
| 10 |
+
|
| 11 |
+
from pydantic import Field
|
| 12 |
+
|
| 13 |
+
from openenv.core import Action
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BreakageAction(Action):
    """Drift Generator's action: pick a primitive type + parameters."""

    # Fixed Literal tag lets consumers discriminate this from RepairAction.
    action_type: Literal["breakage"] = "breakage"
    # Must name a key of PRIMITIVE_REGISTRY — validated by the environment,
    # not here (presumably; confirm against forge_environment.step).
    primitive_type: str = Field(
        ..., description="One of the registered breakage primitive class names"
    )
    # Free-form per-primitive kwargs; defaults to an empty dict.
    params: dict[str, Any] = Field(
        default_factory=dict, description="Primitive-specific parameters"
    )
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class RepairAction(Action):
    """Repair Agent's action: a unified diff (or full replacement script)."""

    # Fixed Literal tag lets consumers discriminate this from BreakageAction.
    action_type: Literal["repair"] = "repair"
    # Interpreted by diff_utils.apply_unified_diff, which also accepts a
    # complete replacement script instead of a diff.
    unified_diff: str = Field(..., description="Unified diff or full replacement script")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ForgeAction(Action):
    """Union action: exactly one of `breakage` / `repair` must be set.

    This is the type registered with OpenEnv's `create_app`. It avoids
    Pydantic discriminated unions to keep the OpenAPI schema flat and
    cross-version-friendly.
    """

    breakage: Optional[BreakageAction] = None
    repair: Optional[RepairAction] = None

    def model_post_init(self, __context: Any) -> None:
        # XOR check: `(a is None) == (b is None)` is true both when neither
        # sub-action is set and when both are — the two invalid cases.
        if (self.breakage is None) == (self.repair is None):
            raise ValueError(
                "ForgeAction requires exactly one of `breakage` or `repair` to be set."
            )
|
forgeenv/env/diff_utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified-diff application utilities.
|
| 2 |
+
|
| 3 |
+
The Repair Agent submits a unified diff. We need a permissive applier
|
| 4 |
+
because LLM diffs are often malformed (wrong line numbers, missing
|
| 5 |
+
context, extra prose). We try the strict applier first, then fall
|
| 6 |
+
back to applying hunks via plain string replacement.
|
| 7 |
+
|
| 8 |
+
The agent may also submit a full Python script instead of a diff
|
| 9 |
+
(common when the model's diff format breaks). We detect this and
|
| 10 |
+
treat it as a complete replacement.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import difflib
|
| 15 |
+
import re
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Unified-diff hunk header (`@@ ... @@`) anywhere in the text.
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Tokens that suggest a plain Python script rather than a diff.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff."""
    stripped_lines = text.lstrip().splitlines()
    if not stripped_lines:
        return False
    # Any diff header within the first five lines rules out "full script".
    for candidate in stripped_lines[:5]:
        if candidate.startswith(("---", "+++", "@@")):
            return False
    # Two or more script-style markers within the first 30 lines => treat
    # the submission as a complete replacement script.
    preview = "\n".join(stripped_lines[:30])
    marker_hits = sum(marker in preview for marker in _SCRIPT_MARKERS)
    return marker_hits >= 2
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure.

    Works hunk by hunk: the "-"/context lines of each hunk are joined into a
    search target, located by plain substring search in the not-yet-consumed
    tail of the source (hunk line numbers are ignored — LLM diffs get them
    wrong), and replaced with the "+"/context lines. Any hunk whose target
    text cannot be found aborts the whole application with None.
    """
    lines = broken_script.splitlines(keepends=True)
    out: list[str] = []  # already-rewritten output fragments
    diff_lines = diff_text.splitlines()
    i = 0
    src_idx = 0  # index into `lines`: everything before it is consumed
    in_hunk = False
    hunk_old: list[str] = []  # text the hunk expects to find (-, context)
    hunk_new: list[str] = []  # text the hunk writes out (+, context)

    while i < len(diff_lines):
        line = diff_lines[i]
        if line.startswith(("---", "+++")):
            # File headers carry no content — skip.
            i += 1
            continue
        if line.startswith("@@"):
            # Flush previous hunk
            if in_hunk:
                # Find the hunk_old block in the source starting at src_idx.
                target = "".join(hunk_old)
                source_remainder = "".join(lines[src_idx:])
                pos = source_remainder.find(target)
                if pos == -1:
                    return None
                out.append(source_remainder[:pos])
                out.append("".join(hunk_new))
                # Advance src_idx past everything just emitted/replaced.
                src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True))
                hunk_old, hunk_new = [], []
            in_hunk = True
            i += 1
            continue
        if in_hunk:
            if line.startswith("+"):
                hunk_new.append(line[1:] + "\n")
            elif line.startswith("-"):
                hunk_old.append(line[1:] + "\n")
            else:
                # context line
                # NOTE(review): a context line without the leading space is
                # tolerated as-is — common in LLM-emitted diffs.
                ctx = line[1:] if line.startswith(" ") else line
                hunk_old.append(ctx + "\n")
                hunk_new.append(ctx + "\n")
        i += 1

    # Flush trailing hunk
    if in_hunk and (hunk_old or hunk_new):
        target = "".join(hunk_old)
        source_remainder = "".join(lines[src_idx:])
        pos = source_remainder.find(target)
        if pos == -1:
            return None
        out.append(source_remainder[:pos])
        out.append("".join(hunk_new))
        consumed = source_remainder[: pos + len(target)]
        src_idx += len(consumed.splitlines(keepends=True))

    # Whatever of the source was never touched by a hunk is kept verbatim.
    out.append("".join(lines[src_idx:]))
    return "".join(out)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _permissive_apply(broken_script: str, diff_text: str) -> str:
|
| 100 |
+
"""Apply a malformed diff by extracting (-,+) line pairs and doing
|
| 101 |
+
a tolerant search-and-replace.
|
| 102 |
+
"""
|
| 103 |
+
repaired = broken_script
|
| 104 |
+
pairs: list[tuple[str, str]] = []
|
| 105 |
+
lines = diff_text.splitlines()
|
| 106 |
+
pending_minus: str | None = None
|
| 107 |
+
|
| 108 |
+
for line in lines:
|
| 109 |
+
if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
|
| 110 |
+
pending_minus = None
|
| 111 |
+
continue
|
| 112 |
+
if line.startswith("-"):
|
| 113 |
+
pending_minus = line[1:].strip()
|
| 114 |
+
elif line.startswith("+") and pending_minus is not None:
|
| 115 |
+
pairs.append((pending_minus, line[1:].strip()))
|
| 116 |
+
pending_minus = None
|
| 117 |
+
elif pending_minus is not None and not line.startswith(" "):
|
| 118 |
+
# standalone deletion — skip in permissive mode (we can't
|
| 119 |
+
# reliably know what to delete without context)
|
| 120 |
+
pending_minus = None
|
| 121 |
+
|
| 122 |
+
for old, new in pairs:
|
| 123 |
+
if old and old in repaired:
|
| 124 |
+
repaired = repaired.replace(old, new, 1)
|
| 125 |
+
|
| 126 |
+
return repaired
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Try every strategy in order and return the first that produces a change.

    Strategies:
    1. If `diff_text` looks like a full script, return it directly.
    2. Try strict diff application.
    3. Fall back to permissive (-,+) line-pair replacement.
    4. As last resort, return the broken script unchanged.
    """
    text = diff_text or ""
    if not text.strip():
        return broken_script

    if looks_like_full_script(text):
        return text

    has_diff_syntax = bool(_HUNK_RE.search(text)) or "---" in text or "+++" in text
    if has_diff_syntax:
        strict_result = _strict_apply(broken_script, text)
        if strict_result is not None and strict_result != broken_script:
            return strict_result

    # Permissive mode returns the input unchanged when nothing matches,
    # which doubles as strategy 4.
    return _permissive_apply(broken_script, text)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after."""
    before_lines = before.splitlines(keepends=True)
    after_lines = after.splitlines(keepends=True)
    # n=2 context lines keeps diffs compact while staying re-applicable.
    return "".join(
        difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile=f"a/{path}",
            tofile=f"b/{path}",
            n=2,
        )
    )
|
forgeenv/env/forge_environment.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.
|
| 2 |
+
|
| 3 |
+
Episode flow (exactly 2 steps per episode):
|
| 4 |
+
reset() -> sample task, ask Teacher for category
|
| 5 |
+
step(BreakageAction) -> Drift Generator's proposal is applied; broken
|
| 6 |
+
script is run, error trace captured.
|
| 7 |
+
step(RepairAction) -> Repair diff is applied; script is re-executed;
|
| 8 |
+
visible + held-out rewards computed; episode ends.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import time
|
| 13 |
+
import uuid
|
| 14 |
+
from typing import Any, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core import Environment
|
| 17 |
+
|
| 18 |
+
from forgeenv.drift.library_drift_engine import LibraryDriftEngine
|
| 19 |
+
from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
|
| 20 |
+
from forgeenv.env.diff_utils import apply_unified_diff
|
| 21 |
+
from forgeenv.env.observations import ForgeObservation
|
| 22 |
+
from forgeenv.primitives.breakage_primitives import (
|
| 23 |
+
PRIMITIVE_REGISTRY,
|
| 24 |
+
parse_breakage_spec,
|
| 25 |
+
)
|
| 26 |
+
from forgeenv.roles.teacher import Teacher
|
| 27 |
+
from forgeenv.sandbox.simulation_mode import SimulationExecutor
|
| 28 |
+
from forgeenv.tasks.models import ExecutionResult, Task
|
| 29 |
+
from forgeenv.tasks.task_sampler import TaskSampler
|
| 30 |
+
from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
|
| 31 |
+
from forgeenv.verifier.visible_verifier import compute_visible_reward
|
| 32 |
+
|
| 33 |
+
DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
|
| 37 |
+
"""OpenEnv-compliant environment for HuggingFace ecosystem repair."""
|
| 38 |
+
|
| 39 |
+
SUPPORTS_CONCURRENT_SESSIONS = False # Teacher state is global per env
|
| 40 |
+
|
| 41 |
+
    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        """Wire up the environment's collaborators; any omitted argument gets
        a default instance so the env is usable with zero configuration.

        `seed` is forwarded only to the SimulationExecutor; task sampling is
        not seeded here.
        """
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        # Fallback category list guards against an empty primitive registry.
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()

        # Per-episode state, populated by reset()/step(); "idle" means no
        # episode is in flight yet.
        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None
|
| 67 |
+
|
| 68 |
+
# ------------------------------------------------------------------ API
|
| 69 |
+
def reset(
|
| 70 |
+
self,
|
| 71 |
+
seed: Optional[int] = None,
|
| 72 |
+
episode_id: Optional[str] = None,
|
| 73 |
+
difficulty: Optional[str] = "easy",
|
| 74 |
+
**kwargs: Any,
|
| 75 |
+
) -> ForgeObservation:
|
| 76 |
+
self._episode_id = episode_id or str(uuid.uuid4())
|
| 77 |
+
self._episode_count += 1
|
| 78 |
+
self._target_category = self.teacher.select_next_category()
|
| 79 |
+
|
| 80 |
+
task = self.task_sampler.sample(difficulty=difficulty)
|
| 81 |
+
if task is None:
|
| 82 |
+
raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")
|
| 83 |
+
self._current_task = task
|
| 84 |
+
self._original_script = task.script_content
|
| 85 |
+
self._broken_script = ""
|
| 86 |
+
self._error_trace = ""
|
| 87 |
+
self._breakage_spec = None
|
| 88 |
+
self._current_phase = "drift_gen"
|
| 89 |
+
|
| 90 |
+
# Library drift trigger every 50 episodes (configurable from outside).
|
| 91 |
+
drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50)
|
| 92 |
+
|
| 93 |
+
obs = ForgeObservation(
|
| 94 |
+
current_phase="drift_gen",
|
| 95 |
+
task_id=task.task_id,
|
| 96 |
+
task_description=task.description,
|
| 97 |
+
target_category=self._target_category,
|
| 98 |
+
script_content=self._original_script,
|
| 99 |
+
error_trace=None,
|
| 100 |
+
library_versions=self.drift_engine.current_versions(),
|
| 101 |
+
episode_step=0,
|
| 102 |
+
done=False,
|
| 103 |
+
reward=0.0,
|
| 104 |
+
info={
|
| 105 |
+
"episode_id": self._episode_id,
|
| 106 |
+
"episode_count": self._episode_count,
|
| 107 |
+
"drift_triggered": drifted,
|
| 108 |
+
"available_primitives": sorted(PRIMITIVE_REGISTRY),
|
| 109 |
+
},
|
| 110 |
+
)
|
| 111 |
+
self._last_obs = obs
|
| 112 |
+
return obs
|
| 113 |
+
|
| 114 |
+
def step(
|
| 115 |
+
self,
|
| 116 |
+
action: ForgeAction,
|
| 117 |
+
timeout_s: Optional[float] = None,
|
| 118 |
+
**kwargs: Any,
|
| 119 |
+
) -> ForgeObservation:
|
| 120 |
+
if self._current_phase == "drift_gen":
|
| 121 |
+
if action.breakage is None:
|
| 122 |
+
return self._error_obs("Expected BreakageAction in drift_gen phase")
|
| 123 |
+
return self._handle_breakage(action.breakage)
|
| 124 |
+
|
| 125 |
+
if self._current_phase == "repair":
|
| 126 |
+
if action.repair is None:
|
| 127 |
+
return self._error_obs("Expected RepairAction in repair phase")
|
| 128 |
+
return self._handle_repair(action.repair)
|
| 129 |
+
|
| 130 |
+
return self._error_obs(
|
| 131 |
+
f"step() called in invalid phase {self._current_phase!r} — call reset() first"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def state(self) -> dict:
|
| 136 |
+
return {
|
| 137 |
+
"phase": self._current_phase,
|
| 138 |
+
"episode_id": self._episode_id,
|
| 139 |
+
"episode_count": self._episode_count,
|
| 140 |
+
"task_id": self._current_task.task_id if self._current_task else None,
|
| 141 |
+
"target_category": self._target_category,
|
| 142 |
+
"library_versions": self.drift_engine.current_versions(),
|
| 143 |
+
"teacher": self.teacher.get_state(),
|
| 144 |
+
"drift_history": list(self.drift_engine.drift_history),
|
| 145 |
+
"breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
# ---------------------------------------------------------------- helpers
|
| 149 |
+
def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
|
| 150 |
+
spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)}
|
| 151 |
+
try:
|
| 152 |
+
primitive = parse_breakage_spec(spec)
|
| 153 |
+
except ValueError as exc:
|
| 154 |
+
return self._error_obs(f"Invalid breakage spec: {exc}")
|
| 155 |
+
|
| 156 |
+
try:
|
| 157 |
+
self._broken_script = primitive.apply(self._original_script)
|
| 158 |
+
except Exception as exc: # primitive bug — surface but don't crash server
|
| 159 |
+
return self._error_obs(f"Primitive apply failed: {exc}")
|
| 160 |
+
|
| 161 |
+
self._breakage_spec = spec
|
| 162 |
+
|
| 163 |
+
result = self.executor.execute(self._broken_script, self._current_task)
|
| 164 |
+
if result.exit_code != 0:
|
| 165 |
+
self._error_trace = result.stderr or "non-zero exit code, no stderr"
|
| 166 |
+
else:
|
| 167 |
+
# The breakage didn't actually break it; still proceed to repair phase
|
| 168 |
+
# (no-op repair is then a valid choice).
|
| 169 |
+
self._error_trace = "Script ran without observable error"
|
| 170 |
+
|
| 171 |
+
self._current_phase = "repair"
|
| 172 |
+
|
| 173 |
+
obs = ForgeObservation(
|
| 174 |
+
current_phase="repair",
|
| 175 |
+
task_id=self._current_task.task_id,
|
| 176 |
+
task_description=self._current_task.description,
|
| 177 |
+
target_category=primitive.category,
|
| 178 |
+
script_content=self._broken_script,
|
| 179 |
+
error_trace=self._error_trace,
|
| 180 |
+
library_versions=self.drift_engine.current_versions(),
|
| 181 |
+
episode_step=1,
|
| 182 |
+
done=False,
|
| 183 |
+
reward=0.0,
|
| 184 |
+
info={
|
| 185 |
+
"episode_id": self._episode_id,
|
| 186 |
+
"breakage_primitive": primitive.name,
|
| 187 |
+
"breakage_description": primitive.description,
|
| 188 |
+
},
|
| 189 |
+
)
|
| 190 |
+
self._last_obs = obs
|
| 191 |
+
return obs
|
| 192 |
+
|
| 193 |
+
def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
|
| 194 |
+
repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "")
|
| 195 |
+
|
| 196 |
+
t0 = time.time()
|
| 197 |
+
result = self.executor.execute(repaired, self._current_task)
|
| 198 |
+
result.script_content = repaired # ensure verifier sees what we ran
|
| 199 |
+
wall_ms = int((time.time() - t0) * 1000)
|
| 200 |
+
|
| 201 |
+
visible_reward, visible_breakdown = compute_visible_reward(
|
| 202 |
+
result, self._current_task
|
| 203 |
+
)
|
| 204 |
+
held_out = compute_held_out_scores(
|
| 205 |
+
result, self._current_task, repair_diff=repair.unified_diff or ""
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
success = result.exit_code == 0
|
| 209 |
+
category = (
|
| 210 |
+
self._breakage_spec.get("primitive_type", "unknown")
|
| 211 |
+
if self._breakage_spec
|
| 212 |
+
else "unknown"
|
| 213 |
+
)
|
| 214 |
+
# Update Teacher's curriculum state
|
| 215 |
+
self.teacher.update(category, success)
|
| 216 |
+
|
| 217 |
+
self._current_phase = "done"
|
| 218 |
+
|
| 219 |
+
obs = ForgeObservation(
|
| 220 |
+
current_phase="done",
|
| 221 |
+
task_id=self._current_task.task_id,
|
| 222 |
+
task_description=self._current_task.description,
|
| 223 |
+
target_category=category,
|
| 224 |
+
script_content=repaired,
|
| 225 |
+
error_trace=result.stderr or None,
|
| 226 |
+
library_versions=self.drift_engine.current_versions(),
|
| 227 |
+
episode_step=2,
|
| 228 |
+
done=True,
|
| 229 |
+
reward=visible_reward,
|
| 230 |
+
reward_breakdown=visible_breakdown,
|
| 231 |
+
held_out_breakdown=held_out,
|
| 232 |
+
info={
|
| 233 |
+
"episode_id": self._episode_id,
|
| 234 |
+
"exit_code": result.exit_code,
|
| 235 |
+
"wall_time_ms": wall_ms,
|
| 236 |
+
"checkpoint_exists": result.checkpoint_exists,
|
| 237 |
+
"stdout_tail": "\n".join(result.stdout.splitlines()[-5:]),
|
| 238 |
+
"breakage_spec": self._breakage_spec,
|
| 239 |
+
"teacher_state": self.teacher.get_state(),
|
| 240 |
+
},
|
| 241 |
+
)
|
| 242 |
+
self._last_obs = obs
|
| 243 |
+
return obs
|
| 244 |
+
|
| 245 |
+
def _error_obs(self, message: str) -> ForgeObservation:
|
| 246 |
+
"""Return a `done=True` error observation rather than raising."""
|
| 247 |
+
return ForgeObservation(
|
| 248 |
+
current_phase="done",
|
| 249 |
+
task_id=self._current_task.task_id if self._current_task else "",
|
| 250 |
+
task_description=self._current_task.description if self._current_task else "",
|
| 251 |
+
target_category=self._target_category,
|
| 252 |
+
script_content=self._broken_script or self._original_script,
|
| 253 |
+
error_trace=message,
|
| 254 |
+
library_versions=self.drift_engine.current_versions(),
|
| 255 |
+
episode_step=2,
|
| 256 |
+
done=True,
|
| 257 |
+
reward=0.0,
|
| 258 |
+
info={"error": message, "episode_id": self._episode_id},
|
| 259 |
+
)
|
forgeenv/env/observations.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic observation model for ForgeEnv."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
|
| 8 |
+
from openenv.core import Observation
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ForgeObservation(Observation):
    """What the agent (or the trainer's rollout function) sees at each step.

    Inherits `done`, `reward`, `metadata` from the OpenEnv `Observation` base.
    """

    current_phase: str = Field(
        ..., description="One of 'drift_gen', 'repair', 'verify', 'done'"
    )
    # Identity and natural-language context of the sampled seed task.
    task_id: str = ""
    task_description: str = ""
    # Drift category the Teacher asked the generator to target this episode.
    target_category: str = ""
    script_content: str = Field(default="", description="Current state of the script")
    # Stderr (or a synthetic message) from the most recent execution, if any.
    error_trace: Optional[str] = None
    # Library name -> version string, as reported by the drift engine.
    library_versions: dict[str, str] = Field(default_factory=dict)
    # Per-component breakdown of the visible (training) reward.
    reward_breakdown: dict[str, Any] = Field(default_factory=dict)
    # Held-out evaluator scores; populated only on the terminal observation.
    held_out_breakdown: dict[str, float] = Field(default_factory=dict)
    # 0 = drift_gen, 1 = repair, 2 = done/error.
    episode_step: int = 0
    # Free-form diagnostics (episode id, exit code, teacher state, ...).
    info: dict[str, Any] = Field(default_factory=dict)
|
forgeenv/env/server.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for ForgeEnv (OpenEnv-compliant).
|
| 2 |
+
|
| 3 |
+
Exposes /reset, /step, /state HTTP endpoints via OpenEnv's `create_app`.
|
| 4 |
+
HF Spaces sets PORT=7860 automatically.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from openenv.core import create_app
|
| 11 |
+
|
| 12 |
+
from forgeenv.env.actions import ForgeAction
|
| 13 |
+
from forgeenv.env.forge_environment import ForgeEnvironment
|
| 14 |
+
from forgeenv.env.observations import ForgeObservation
|
| 15 |
+
|
| 16 |
+
# Module-level ASGI app: OpenEnv's create_app wires the /reset, /step and
# /state HTTP endpoints to a ForgeEnvironment, using our pydantic action and
# observation models for request/response (de)serialization.
app = create_app(
    env=ForgeEnvironment,
    action_cls=ForgeAction,
    observation_cls=ForgeObservation,
    env_name="forgeenv",
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Lightweight health endpoint for Docker / HF Spaces probes. We attach it
|
| 25 |
+
# only if `create_app` didn't already register one, so we don't shadow
|
| 26 |
+
# whatever the OpenEnv version ships with.
|
| 27 |
+
def _ensure_health_route(_app) -> None:
|
| 28 |
+
existing = {
|
| 29 |
+
getattr(r, "path", None) for r in getattr(_app, "routes", [])
|
| 30 |
+
}
|
| 31 |
+
if "/health" in existing:
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
@_app.get("/health")
|
| 35 |
+
def _health() -> dict:
|
| 36 |
+
return {"status": "ok", "env": "forgeenv"}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Attach the fallback health route (no-op if create_app registered one).
_ensure_health_route(app)


if __name__ == "__main__":
    # Local/dev entry point. HF Spaces and Docker inject PORT (default 7860).
    import uvicorn

    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
forgeenv/primitives/__init__.py
ADDED
|
File without changes
|
forgeenv/primitives/breakage_primitives.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.
|
| 2 |
+
|
| 3 |
+
Each primitive transforms a working script to simulate a library upgrade
|
| 4 |
+
breakage. They double as the Drift Generator's structured action space.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class BreakagePrimitive(ABC):
    """Abstract base class for all breakage types.

    Subclasses set `category` / `name` / `description` in `__post_init__`
    and implement `apply` (the script transformation) plus `_get_params`
    (serialization of their constructor arguments).
    """

    # Metadata filled in by each subclass's __post_init__; `init=False`
    # keeps these out of the generated __init__, so specs carry only the
    # subclass's real parameters.
    category: str = field(default="generic", init=False)
    name: str = field(default="BreakagePrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to introduce the breakage."""

    def to_spec(self) -> dict:
        """Serialize to JSON-compatible spec for the LLM action space."""
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return a JSON-serializable dict of constructor parameters."""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation."""

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        """Substitute every standalone occurrence of ``old_name``."""
        if not self.old_name:
            return script
        # (?<!\w) / (?!\w) restrict matches to whole identifiers so we never
        # rewrite a substring of some longer name.
        identifier_re = re.compile(rf"(?<!\w){re.escape(self.old_name)}(?!\w)")
        return identifier_re.sub(self.new_name, script)

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(old_name=self.old_name, new_name=self.new_name)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring."""

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        """Plain substring rewrite; import paths are distinctive enough."""
        if self.old_module:
            return script.replace(self.old_module, self.new_module)
        return script

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(old_module=self.old_module, new_module=self.new_module)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one)."""

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        """Strip every ``removed_arg=<value>`` keyword (plus trailing comma)."""
        if not self.removed_arg:
            return script
        kwarg_re = re.compile(
            rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)"
        )
        return kwarg_re.sub("", script)

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(
            function_name=self.function_name,
            removed_arg=self.removed_arg,
            added_arg=self.added_arg,
            added_value=self.added_value,
        )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift."""

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        """Rewrite every ``field_name=<value>`` assignment to ``new_value``.

        Uses a callable replacement so backslashes or ``\\g<...>`` sequences
        in `new_value` (which may come straight from an LLM) are inserted
        literally instead of being interpreted as re.sub template syntax —
        a template string here would raise or corrupt the script.
        """
        if not self.field_name:
            return script
        pattern = rf"({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, lambda m: m.group(1) + self.new_value, script)

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a dataset column reference to simulate schema drift."""

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        """Rewrite the column name in both quoting styles."""
        if not self.old_column:
            return script
        # Handle batch["label"] and batch['label'] alike.
        rewritten = script.replace(f'"{self.old_column}"', f'"{self.new_column}"')
        return rewritten.replace(f"'{self.old_column}'", f"'{self.new_column}'")

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(old_column=self.old_column, new_column=self.new_column)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments."""

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}"

    def apply(self, script: str) -> str:
        """Rewrite ``old_kwarg=old_value`` into ``new_kwarg=new_value``."""
        if not self.old_kwarg:
            return script
        target = rf"{re.escape(self.old_kwarg)}\s*=\s*{re.escape(self.old_value)}"
        return re.sub(target, f"{self.new_kwarg}={self.new_value}", script)

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(
            old_kwarg=self.old_kwarg,
            old_value=self.old_value,
            new_kwarg=self.new_kwarg,
            new_value=self.new_value,
        )
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Remove a method that has been deprecated, leaving a sentinel that
    raises AttributeError-style errors when the script runs."""

    # Informational only; apply() rewrites call sites by name, not by class.
    class_name: str = ""
    method_name: str = ""
    # Suggested successor API; recorded in the spec but not used by apply().
    replacement: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RemoveDeprecatedMethod"
        self.description = f"Remove {self.class_name}.{self.method_name}"

    def apply(self, script: str) -> str:
        """Rename every ``.method(`` call site to a ``_DEPRECATED`` sentinel.

        Nothing is actually deleted: the renamed attribute won't exist at
        runtime, so the script fails when it reaches the call site.
        """
        if not self.method_name:
            return script
        return script.replace(
            f".{self.method_name}(", f".{self.method_name}_DEPRECATED("
        )

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return {
            "class_name": self.class_name,
            "method_name": self.method_name,
            "replacement": self.replacement,
        }
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@dataclass
class ChangeReturnType(BreakagePrimitive):
    """A function now returns a different structure (e.g. tuple -> object)."""

    function_name: str = ""
    old_access: str = ""
    new_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeReturnType"
        self.description = f"Change return type of {self.function_name}"

    def apply(self, script: str) -> str:
        """Rewrite the access expression when both forms are configured."""
        if not (self.old_access and self.new_access):
            return script
        return script.replace(self.old_access, self.new_access)

    def _get_params(self) -> dict:
        """Constructor args for round-tripping through to_spec()."""
        return dict(
            function_name=self.function_name,
            old_access=self.old_access,
            new_access=self.new_access,
        )
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# Canonical name -> class map. Keys double as the `primitive_type` values
# accepted by parse_breakage_spec() and advertised to the drift generator.
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
    "RenameApiCall": RenameApiCall,
    "DeprecateImport": DeprecateImport,
    "ChangeArgumentSignature": ChangeArgumentSignature,
    "ModifyConfigField": ModifyConfigField,
    "RestructureDatasetSchema": RestructureDatasetSchema,
    "ChangeTokenizerBehavior": ChangeTokenizerBehavior,
    "RemoveDeprecatedMethod": RemoveDeprecatedMethod,
    "ChangeReturnType": ChangeReturnType,
}
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    Tolerates extra keys; ignores unknown params (LLMs hallucinate these).

    Raises:
        ValueError: If `primitive_type` is missing or not in the registry.
    """
    ptype = spec.get("primitive_type", "")
    params = spec.get("params", {}) or {}

    cls = PRIMITIVE_REGISTRY.get(ptype)
    if cls is None:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    # Accept only real constructor fields so a hallucinated kwarg can't crash us.
    accepted = {
        spec_field.name
        for spec_field in cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if spec_field.init
    }
    kwargs = {key: value for key, value in params.items() if key in accepted}
    return cls(**kwargs)
|
forgeenv/primitives/drift_taxonomy.yaml
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Drift taxonomy: real HuggingFace/PyTorch breakages observed across version bumps.
|
| 2 |
+
# Used to seed the Drift Generator's initial proposal distribution and to anchor
|
| 3 |
+
# warm-start pair generation in things that actually happened in the wild.
|
| 4 |
+
- version_range: "transformers 4.36 -> 4.45"
|
| 5 |
+
affected_api: "Trainer.evaluate"
|
| 6 |
+
description: "Trainer.evaluate() return type changed shape; metrics now nested under .metrics"
|
| 7 |
+
breakage_primitive: "ChangeReturnType"
|
| 8 |
+
params:
|
| 9 |
+
function_name: "evaluate"
|
| 10 |
+
old_access: "trainer.evaluate()"
|
| 11 |
+
new_access: "trainer.evaluate().metrics"
|
| 12 |
+
repair_primitive: "RestoreReturnAccess"
|
| 13 |
+
category: "api_drift"
|
| 14 |
+
|
| 15 |
+
- version_range: "transformers 4.30 -> 4.40"
|
| 16 |
+
affected_api: "TrainingArguments.evaluation_strategy"
|
| 17 |
+
description: "Renamed evaluation_strategy -> eval_strategy"
|
| 18 |
+
breakage_primitive: "RenameApiCall"
|
| 19 |
+
params:
|
| 20 |
+
old_name: "evaluation_strategy"
|
| 21 |
+
new_name: "eval_strategy"
|
| 22 |
+
repair_primitive: "RestoreApiCall"
|
| 23 |
+
category: "api_drift"
|
| 24 |
+
|
| 25 |
+
- version_range: "datasets 2.14 -> 3.0"
|
| 26 |
+
affected_api: "load_dataset"
|
| 27 |
+
description: "Default split column was renamed in some GLUE configs"
|
| 28 |
+
breakage_primitive: "RestructureDatasetSchema"
|
| 29 |
+
params:
|
| 30 |
+
old_column: "label"
|
| 31 |
+
new_column: "labels"
|
| 32 |
+
repair_primitive: "RestoreColumn"
|
| 33 |
+
category: "dataset_drift"
|
| 34 |
+
|
| 35 |
+
- version_range: "transformers 4.40 -> 4.50"
|
| 36 |
+
affected_api: "Trainer.predict"
|
| 37 |
+
description: "Method removed; users should use evaluate() with prediction_loss_only=False"
|
| 38 |
+
breakage_primitive: "RemoveDeprecatedMethod"
|
| 39 |
+
params:
|
| 40 |
+
class_name: "Trainer"
|
| 41 |
+
method_name: "predict"
|
| 42 |
+
replacement: "evaluate"
|
| 43 |
+
repair_primitive: "RestoreMethod"
|
| 44 |
+
category: "api_drift"
|
| 45 |
+
|
| 46 |
+
- version_range: "transformers 4.36 -> 4.40"
|
| 47 |
+
affected_api: "TrainingArguments"
|
| 48 |
+
description: "num_train_epochs default behavior changed; max_steps now preferred"
|
| 49 |
+
breakage_primitive: "ModifyConfigField"
|
| 50 |
+
params:
|
| 51 |
+
config_class: "TrainingArguments"
|
| 52 |
+
field_name: "num_train_epochs"
|
| 53 |
+
new_value: "0"
|
| 54 |
+
repair_primitive: "RestoreConfigField"
|
| 55 |
+
category: "config_drift"
|
| 56 |
+
|
| 57 |
+
- version_range: "transformers 4.34 -> 4.42"
|
| 58 |
+
affected_api: "Tokenizer.__call__"
|
| 59 |
+
description: "padding=True semantics changed; users should pass padding='max_length'"
|
| 60 |
+
breakage_primitive: "ChangeTokenizerBehavior"
|
| 61 |
+
params:
|
| 62 |
+
old_kwarg: "padding"
|
| 63 |
+
old_value: "True"
|
| 64 |
+
new_kwarg: "padding"
|
| 65 |
+
new_value: '"max_length"'
|
| 66 |
+
repair_primitive: "RestoreTokenizerKwarg"
|
| 67 |
+
category: "tokenizer_drift"
|
| 68 |
+
|
| 69 |
+
- version_range: "transformers 4.20 -> 4.30"
|
| 70 |
+
affected_api: "imports"
|
| 71 |
+
description: "transformers.training_args moved to transformers.training_args_pt"
|
| 72 |
+
breakage_primitive: "DeprecateImport"
|
| 73 |
+
params:
|
| 74 |
+
old_module: "from transformers.training_args"
|
| 75 |
+
new_module: "from transformers.training_args_pt"
|
| 76 |
+
repair_primitive: "RestoreImport"
|
| 77 |
+
category: "import_drift"
|
| 78 |
+
|
| 79 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 80 |
+
affected_api: "save_pretrained"
|
| 81 |
+
description: "save_pretrained() now requires safe_serialization to default True"
|
| 82 |
+
breakage_primitive: "ChangeArgumentSignature"
|
| 83 |
+
params:
|
| 84 |
+
function_name: "save_pretrained"
|
| 85 |
+
removed_arg: "safe_serialization"
|
| 86 |
+
added_arg: "safe_serialization"
|
| 87 |
+
added_value: "True"
|
| 88 |
+
repair_primitive: "RestoreArgument"
|
| 89 |
+
category: "api_drift"
|
| 90 |
+
|
| 91 |
+
- version_range: "datasets 2.18 -> 3.0"
|
| 92 |
+
affected_api: "Dataset.set_format"
|
| 93 |
+
description: "set_format(type='torch') signature stricter, columns required"
|
| 94 |
+
breakage_primitive: "ChangeArgumentSignature"
|
| 95 |
+
params:
|
| 96 |
+
function_name: "set_format"
|
| 97 |
+
removed_arg: "columns"
|
| 98 |
+
added_arg: "columns"
|
| 99 |
+
added_value: '["input_ids", "attention_mask", "labels"]'
|
| 100 |
+
repair_primitive: "RestoreArgument"
|
| 101 |
+
category: "api_drift"
|
| 102 |
+
|
| 103 |
+
- version_range: "transformers 4.36 -> 4.45"
|
| 104 |
+
affected_api: "Tokenizer.__call__"
|
| 105 |
+
description: "max_length default reduced from 512 -> 256 for some tokenizers"
|
| 106 |
+
breakage_primitive: "ModifyConfigField"
|
| 107 |
+
params:
|
| 108 |
+
config_class: "tokenizer"
|
| 109 |
+
field_name: "max_length"
|
| 110 |
+
new_value: "256"
|
| 111 |
+
repair_primitive: "RestoreConfigField"
|
| 112 |
+
category: "tokenizer_drift"
|
| 113 |
+
|
| 114 |
+
- version_range: "transformers 4.40 -> 4.45"
|
| 115 |
+
affected_api: "DataCollatorWithPadding"
|
| 116 |
+
description: "Renamed `tokenizer` -> `processing_class` in DataCollator constructors"
|
| 117 |
+
breakage_primitive: "RenameApiCall"
|
| 118 |
+
params:
|
| 119 |
+
old_name: "tokenizer"
|
| 120 |
+
new_name: "processing_class"
|
| 121 |
+
repair_primitive: "RestoreApiCall"
|
| 122 |
+
category: "api_drift"
|
| 123 |
+
|
| 124 |
+
- version_range: "datasets 2.14 -> 2.18"
|
| 125 |
+
affected_api: "load_dataset"
|
| 126 |
+
description: "Some splits renamed train[:500] semantics changed"
|
| 127 |
+
breakage_primitive: "RestructureDatasetSchema"
|
| 128 |
+
params:
|
| 129 |
+
old_column: "sentence"
|
| 130 |
+
new_column: "text"
|
| 131 |
+
repair_primitive: "RestoreColumn"
|
| 132 |
+
category: "dataset_drift"
|
| 133 |
+
|
| 134 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 135 |
+
affected_api: "Trainer"
|
| 136 |
+
  description: "evaluation_strategy was deprecated and removed"  # NOTE(review): params below rename Trainer.evaluate, not evaluation_strategy — reconcile description and params
|
| 137 |
+
breakage_primitive: "RemoveDeprecatedMethod"
|
| 138 |
+
params:
|
| 139 |
+
class_name: "Trainer"
|
| 140 |
+
method_name: "evaluate"
|
| 141 |
+
replacement: "evaluate_legacy"
|
| 142 |
+
repair_primitive: "RestoreMethod"
|
| 143 |
+
category: "api_drift"
|
| 144 |
+
|
| 145 |
+
- version_range: "transformers 4.30 -> 4.40"
|
| 146 |
+
affected_api: "PreTrainedModel.from_pretrained"
|
| 147 |
+
description: "torch_dtype now required for some quantized model paths"
|
| 148 |
+
breakage_primitive: "ChangeArgumentSignature"
|
| 149 |
+
params:
|
| 150 |
+
function_name: "from_pretrained"
|
| 151 |
+
removed_arg: "torch_dtype"
|
| 152 |
+
added_arg: "torch_dtype"
|
| 153 |
+
added_value: '"auto"'
|
| 154 |
+
repair_primitive: "RestoreArgument"
|
| 155 |
+
category: "api_drift"
|
| 156 |
+
|
| 157 |
+
- version_range: "datasets 3.0 -> 3.2"
|
| 158 |
+
affected_api: "Dataset.rename_column"
|
| 159 |
+
description: "rename_column raises if target name exists"
|
| 160 |
+
breakage_primitive: "RestructureDatasetSchema"
|
| 161 |
+
params:
|
| 162 |
+
old_column: "labels"
|
| 163 |
+
new_column: "label"
|
| 164 |
+
repair_primitive: "RestoreColumn"
|
| 165 |
+
category: "dataset_drift"
|
| 166 |
+
|
| 167 |
+
- version_range: "transformers 4.36 -> 4.42"
|
| 168 |
+
affected_api: "TrainingArguments.report_to"
|
| 169 |
+
description: "Default report_to changed from 'all' to 'none'"
|
| 170 |
+
breakage_primitive: "ModifyConfigField"
|
| 171 |
+
params:
|
| 172 |
+
config_class: "TrainingArguments"
|
| 173 |
+
field_name: "report_to"
|
| 174 |
+
new_value: '"all"'
|
| 175 |
+
repair_primitive: "RestoreConfigField"
|
| 176 |
+
category: "config_drift"
|
| 177 |
+
|
| 178 |
+
- version_range: "transformers 4.40 -> 4.50"
|
| 179 |
+
affected_api: "imports"
|
| 180 |
+
description: "transformers.deepspeed moved to accelerate.utils.deepspeed"
|
| 181 |
+
breakage_primitive: "DeprecateImport"
|
| 182 |
+
params:
|
| 183 |
+
old_module: "from transformers.deepspeed"
|
| 184 |
+
new_module: "from accelerate.utils.deepspeed"
|
| 185 |
+
repair_primitive: "RestoreImport"
|
| 186 |
+
category: "import_drift"
|
| 187 |
+
|
| 188 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 189 |
+
affected_api: "Tokenizer return"
|
| 190 |
+
description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute"
|
| 191 |
+
breakage_primitive: "ChangeReturnType"
|
| 192 |
+
params:
|
| 193 |
+
function_name: "tokenizer"
|
| 194 |
+
old_access: "tokenizer(text)"
|
| 195 |
+
new_access: "tokenizer(text).encodings"
|
| 196 |
+
repair_primitive: "RestoreReturnAccess"
|
| 197 |
+
category: "api_drift"
|
| 198 |
+
|
| 199 |
+
- version_range: "transformers 4.30 -> 4.40"
|
| 200 |
+
affected_api: "save_pretrained"
|
| 201 |
+
description: "save_pretrained -> save_pretrained_directory rename in some classes"
|
| 202 |
+
breakage_primitive: "RenameApiCall"
|
| 203 |
+
params:
|
| 204 |
+
old_name: "save_pretrained"
|
| 205 |
+
new_name: "save_pretrained_directory"
|
| 206 |
+
repair_primitive: "RestoreApiCall"
|
| 207 |
+
category: "api_drift"
|
| 208 |
+
|
| 209 |
+
- version_range: "transformers 4.45 -> 4.50"
|
| 210 |
+
affected_api: "TrainingArguments.no_cuda"
|
| 211 |
+
description: "no_cuda renamed to use_cpu (logic inverted)"
|
| 212 |
+
breakage_primitive: "RenameApiCall"
|
| 213 |
+
params:
|
| 214 |
+
old_name: "no_cuda"
|
| 215 |
+
new_name: "use_cpu"
|
| 216 |
+
repair_primitive: "RestoreApiCall"
|
| 217 |
+
category: "config_drift"
|
forgeenv/primitives/repair_primitives.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Repair primitives — direct inverses of the 8 breakage primitives.
|
| 2 |
+
|
| 3 |
+
Used during warm-start data generation: for every (script, breakage)
|
| 4 |
+
pair we know the canonical repair, so we can write SFT pairs.
|
| 5 |
+
|
| 6 |
+
These are also useful for unit-testing the breakage primitives:
|
| 7 |
+
apply(breakage) then apply(repair) should be (close to) the identity.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import re
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class RepairPrimitive(ABC):
    """Abstract base for all repair primitives.

    Each concrete subclass is the inverse of exactly one breakage
    primitive: `apply` rewrites a script to undo that breakage, and
    `to_spec` serializes the primitive for logging / dataset generation.
    Subclasses fill in `category`/`name`/`description` in __post_init__.
    """

    # Metadata fields are init=False: they are not constructor arguments,
    # each subclass assigns them in its own __post_init__.
    category: str = field(default="generic", init=False)
    name: str = field(default="RepairPrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to undo the corresponding breakage."""

    def to_spec(self) -> dict:
        """Return a JSON-serializable description of this primitive."""
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return JSON-serializable constructor parameters."""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class RestoreApiCall(RepairPrimitive):
    """Undo a RenameApiCall breakage: rename `new_name` back to `old_name`.

    Matches whole identifiers only (no substring hits inside longer names).
    """

    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        """Return `script` with every whole-word `new_name` replaced."""
        if not self.new_name:
            # Empty needle: nothing meaningful to match, leave untouched.
            return script
        pattern = re.compile(rf"(?<!\w){re.escape(self.new_name)}(?!\w)")
        # Use a callable replacement so backslashes / "\g<...>" sequences in
        # old_name are inserted literally rather than being interpreted as
        # re.sub replacement-template escapes.
        return pattern.sub(lambda _m: self.old_name, script)

    def _get_params(self) -> dict:
        return {"new_name": self.new_name, "old_name": self.old_name}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
class RestoreImport(RepairPrimitive):
    """Undo a DeprecateImport breakage: swap the new import line back."""

    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        """Replace every occurrence of `new_module` with `old_module`."""
        if not self.new_module:
            # Guard against the default-constructed case: str.replace with an
            # empty needle would insert old_module between EVERY character of
            # the script, destroying it.
            return script
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        return {"new_module": self.new_module, "old_module": self.old_module}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call.

    Inserts `arg_name=arg_value` as the first argument of the FIRST
    `function_name(...)` call found in the script.
    """

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        """Insert the kwarg right after the function-name's opening paren."""
        if not self.function_name or not self.arg_name:
            # Without both a target call and a kwarg name the insertion
            # would produce malformed code (e.g. "f(=value, ...)").
            return script
        # (?<!\w) guard: "train(" must not match inside "pretrain(".
        # The optional (\)?) group detects an empty argument list.
        pattern = re.compile(
            rf"(?<!\w)({re.escape(self.function_name)}\s*\()(\s*)(\)?)"
        )

        def _insert(match: re.Match) -> str:
            opener, spacing, closer = match.groups()
            # No trailing ", " when the call had no other arguments, so
            # f() becomes f(x=1) rather than f(x=1, ). A callable
            # replacement also keeps backslashes in arg_value literal.
            separator = "" if closer else ", "
            return (
                f"{opener}{self.arg_name}={self.arg_value}"
                f"{separator}{spacing}{closer}"
            )

        return pattern.sub(_insert, script, count=1)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class RestoreConfigField(RepairPrimitive):
    """Undo a ModifyConfigField breakage: set `field_name` back to `old_value`."""

    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        """Rewrite every `field_name=<value>` assignment to use `old_value`.

        The value is taken to end at the first comma, close-paren, or
        newline — good enough for kwarg-style config assignments.
        """
        if not self.field_name:
            return script
        # (?<!\w) guard: "lr" must not match inside "base_lr".
        pattern = rf"(?<!\w)({re.escape(self.field_name)}\s*=\s*)[^,)\n]+"
        # Callable replacement keeps backslashes in old_value literal instead
        # of letting re.sub interpret them as template escapes.
        return re.sub(pattern, lambda m: m.group(1) + self.old_value, script)

    def _get_params(self) -> dict:
        return {"field_name": self.field_name, "old_value": self.old_value}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@dataclass
class RestoreColumn(RepairPrimitive):
    """Undo a RestructureDatasetSchema breakage: rename the column back.

    Only quoted references ("col" or 'col') are rewritten, so identifiers
    that merely share the name are left alone.
    """

    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        """Replace quoted `new_column` references with `old_column`."""
        if not self.new_column:
            # Guard: with an empty name the needle degenerates to '""' / "''",
            # which would rewrite unrelated empty-string literals.
            return script
        repaired = script.replace(
            f'"{self.new_column}"', f'"{self.old_column}"'
        )
        return repaired.replace(
            f"'{self.new_column}'", f"'{self.old_column}'"
        )

    def _get_params(self) -> dict:
        return {"new_column": self.new_column, "old_column": self.old_column}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    """Undo a ChangeTokenizerBehavior breakage.

    Rewrites `new_kwarg=new_value` back to `old_kwarg=old_value` wherever
    it appears (whitespace around `=` is tolerated).
    """

    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        """Swap the drifted kwarg assignment back to the original one."""
        if not self.new_kwarg:
            return script
        # (?<!\w) guard avoids matching inside a longer kwarg name.
        pattern = (
            rf"(?<!\w){re.escape(self.new_kwarg)}\s*=\s*"
            rf"{re.escape(self.new_value)}"
        )
        # Callable replacement keeps any backslashes in the restored text
        # literal instead of treating them as re.sub template escapes.
        return re.sub(
            pattern, lambda _m: f"{self.old_kwarg}={self.old_value}", script
        )

    def _get_params(self) -> dict:
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
class RestoreMethod(RepairPrimitive):
    """Undo a RemoveDeprecatedMethod breakage.

    The breakage leaves calls tagged with a `_DEPRECATED` suffix; this
    strips that sentinel so the original method is called again.
    """

    method_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreMethod"
        self.description = f"Un-deprecate .{self.method_name}()"

    def apply(self, script: str) -> str:
        """Strip the `_DEPRECATED` sentinel from every tagged call site."""
        if not self.method_name:
            return script
        sentinel_call = f".{self.method_name}_DEPRECATED("
        restored_call = f".{self.method_name}("
        return script.replace(sentinel_call, restored_call)

    def _get_params(self) -> dict:
        return {"method_name": self.method_name}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@dataclass
class RestoreReturnAccess(RepairPrimitive):
    """Undo a ChangeReturnType breakage by rewriting the access pattern."""

    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreReturnAccess"
        self.description = (
            f"Restore return-access {self.new_access} -> {self.old_access}"
        )

    def apply(self, script: str) -> str:
        """Replace every `new_access` expression with `old_access`."""
        if self.new_access:
            return script.replace(self.new_access, self.old_access)
        return script

    def _get_params(self) -> dict:
        return {"new_access": self.new_access, "old_access": self.old_access}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Name -> class lookup for deserializing repair specs (the inverse-side
# counterpart of the breakage PRIMITIVE_REGISTRY).
REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    "RestoreApiCall": RestoreApiCall,
    "RestoreImport": RestoreImport,
    "RestoreArgument": RestoreArgument,
    "RestoreConfigField": RestoreConfigField,
    "RestoreColumn": RestoreColumn,
    "RestoreTokenizerKwarg": RestoreTokenizerKwarg,
    "RestoreMethod": RestoreMethod,
    "RestoreReturnAccess": RestoreReturnAccess,
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# Map a breakage primitive's class name to the repair-primitive class name
# that inverts it. Used by the warm-start pair generator and by the demo /
# repair library curator. Every breakage type has exactly one inverse here.
BREAKAGE_TO_REPAIR: dict[str, str] = {
    "RenameApiCall": "RestoreApiCall",
    "DeprecateImport": "RestoreImport",
    "ChangeArgumentSignature": "RestoreArgument",
    "ModifyConfigField": "RestoreConfigField",
    "RestructureDatasetSchema": "RestoreColumn",
    "ChangeTokenizerBehavior": "RestoreTokenizerKwarg",
    "RemoveDeprecatedMethod": "RestoreMethod",
    "ChangeReturnType": "RestoreReturnAccess",
}
|
forgeenv/roles/__init__.py
ADDED
|
File without changes
|
forgeenv/roles/drift_generator.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Drift Generator parser + a deterministic baseline policy.
|
| 2 |
+
|
| 3 |
+
In training the LLM produces a JSON breakage spec; we parse it. In rollouts
|
| 4 |
+
where we want a baseline (or a fallback when the LLM emits malformed JSON)
|
| 5 |
+
we use `BaselineDriftGenerator`, which samples from the per-category set of
|
| 6 |
+
known good primitive parameterisations.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import random
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
from forgeenv.primitives.breakage_primitives import (
|
| 17 |
+
PRIMITIVE_REGISTRY,
|
| 18 |
+
parse_breakage_spec,
|
| 19 |
+
BreakagePrimitive,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Greedy brace matcher: grabs from the first "{" to the last "}".
_JSON_RE = re.compile(r"\{[\s\S]*\}")


def parse_drift_output(text: str) -> Optional[dict]:
    """Extract a JSON object from possibly-noisy LLM output.

    Handles markdown fences, prose preamble, trailing commas (best-effort).
    Returns None on failure.
    """
    if not text:
        return None
    candidate = text.strip()
    if candidate.startswith("```"):
        # Strip an opening fence (with optional language tag) and a
        # closing fence, if present.
        candidate = re.sub(r"^```[a-zA-Z]*\n?", "", candidate)
        candidate = re.sub(r"\n?```$", "", candidate)
    found = _JSON_RE.search(candidate)
    if found is None:
        return None
    blob = found.group(0)
    # Second attempt removes trailing commas before "}" / "]".
    for attempt in (blob, re.sub(r",\s*([}\]])", r"\1", blob)):
        try:
            return json.loads(attempt)
        except json.JSONDecodeError:
            continue
    return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
    """End-to-end: LLM text -> validated BreakagePrimitive (or None)."""
    spec = parse_drift_output(text)
    if isinstance(spec, dict):
        try:
            return parse_breakage_spec(spec)
        except (ValueError, TypeError):
            # Structurally JSON but not a valid primitive spec.
            pass
    return None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------- baselines
# Known-good parameterisations per breakage primitive. Keys are primitive
# *type* names (not drift categories); each value is a list of params dicts
# accepted by that primitive's constructor. BaselineDriftGenerator samples
# from these when proposing a breakage without an LLM.
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Drift *categories* (as used by the curriculum / taxonomy) -> the breakage
# primitive *types* that realise them. Category tags match the ones each
# primitive declares (api_drift, import_drift, ...).
_CATEGORY_TO_TYPES: dict[str, list[str]] = {
    "api_drift": [
        "RenameApiCall",
        "ChangeArgumentSignature",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    ],
    "import_drift": ["DeprecateImport"],
    "config_drift": ["ModifyConfigField"],
    "dataset_drift": ["RestructureDatasetSchema"],
    "tokenizer_drift": ["ChangeTokenizerBehavior"],
}


@dataclass
class BaselineDriftGenerator:
    """Deterministic stand-in for the LLM Drift Generator.

    Used for warm-start data, baseline rollouts, and unit tests.
    """

    # Optional RNG seed; None falls back to the module-level `random` state.
    seed: Optional[int] = None

    def __post_init__(self) -> None:
        self._rng = random.Random(self.seed) if self.seed is not None else random

    def propose(
        self, target_category: str = "", script: str = ""
    ) -> dict:
        """Produce a JSON-serializable breakage spec for `target_category`.

        `target_category` may be either a drift category ("api_drift", ...)
        or a primitive type name ("RenameApiCall", ...).

        Order of preference:
        1. A primitive of `target_category` whose default params apply to `script`.
        2. A primitive of any type whose default params apply to `script`.
        3. A primitive of `target_category` (no-op fallback).
        """
        if target_category in _DEFAULT_PARAMS_BY_TYPE:
            # Caller passed a primitive type name directly.
            preferred_types = [target_category]
        else:
            # BUGFIX: target_category is normally a drift *category*
            # (e.g. "api_drift") while _DEFAULT_PARAMS_BY_TYPE is keyed by
            # primitive *type* names, so the previous membership test never
            # matched and the stated preference order was silently skipped.
            preferred_types = list(_CATEGORY_TO_TYPES.get(target_category, []))
        all_types = list(_DEFAULT_PARAMS_BY_TYPE)

        for type_set in (preferred_types, all_types):
            shuffled = self._rng.sample(type_set, len(type_set)) if type_set else []
            for ptype in shuffled:
                candidates = _DEFAULT_PARAMS_BY_TYPE[ptype]
                for params in self._rng.sample(candidates, len(candidates)):
                    if self._params_apply_to_script(ptype, params, script):
                        return {"primitive_type": ptype, "params": dict(params)}

        # Nothing applies to the script: fall back to the first
        # parameterisation of the preferred (or first known) type, even
        # though it may be a no-op on this particular script.
        ptype = preferred_types[0] if preferred_types else all_types[0]
        return {
            "primitive_type": ptype,
            "params": dict(_DEFAULT_PARAMS_BY_TYPE[ptype][0]),
        }

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristic: would this primitive actually mutate `script`?"""
        if not script:
            return True
        # A primitive only mutates the script if its "source" token occurs
        # in it; check every param key that names such a token.
        needle_keys = (
            "old_name",
            "old_module",
            "removed_arg",
            "field_name",
            "old_column",
            "old_kwarg",
            "method_name",
            "old_access",
        )
        for key in needle_keys:
            if key in params and params[key] and params[key] in script:
                return True
        return False
|
forgeenv/roles/prompts.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""System and user prompts for the two RL roles.
|
| 2 |
+
|
| 3 |
+
Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with
|
| 4 |
+
LoRA adapters per role, so role prompts are the only thing distinguishing
|
| 5 |
+
them at inference time. Keep them concise — every token is a token of GPU
|
| 6 |
+
budget during GRPO rollouts.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Iterable
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# One-line human-readable summary per breakage primitive; the parenthesised
# tag is the drift category the primitive belongs to. Formatted into bullets
# by list_primitive_descriptions().
PRIMITIVE_DESCRIPTIONS = {
    "RenameApiCall": "Rename a function/method call (api_drift)",
    "DeprecateImport": "Change an import path (import_drift)",
    "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)",
    "ModifyConfigField": "Change a config-class default (config_drift)",
    "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)",
    "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)",
    "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    "ChangeReturnType": "Function returns a different structure (api_drift)",
}
|
| 23 |
+
|
| 24 |
+
DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
|
| 25 |
+
You see a working HuggingFace training script and the curriculum target category.
|
| 26 |
+
Output exactly one JSON object describing a breakage primitive that simulates
|
| 27 |
+
realistic library version drift. The primitive must:
|
| 28 |
+
1. Be PLAUSIBLE — match the kind of breakage that happens between real
|
| 29 |
+
transformers/datasets/trl releases.
|
| 30 |
+
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
|
| 31 |
+
3. Match the requested target_category.
|
| 32 |
+
|
| 33 |
+
Output schema:
|
| 34 |
+
{"primitive_type": "<one of the 8 types>", "params": { ... }}
|
| 35 |
+
|
| 36 |
+
Available primitive types and parameter schemas:
|
| 37 |
+
- RenameApiCall: {"old_name": str, "new_name": str}
|
| 38 |
+
- DeprecateImport: {"old_module": str, "new_module": str}
|
| 39 |
+
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
|
| 40 |
+
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
|
| 41 |
+
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
|
| 42 |
+
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
|
| 43 |
+
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
|
| 44 |
+
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}
|
| 45 |
+
|
| 46 |
+
Output ONLY the JSON object — no commentary, no markdown fences.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
|
| 51 |
+
You see a broken HuggingFace training script, an error trace, and the current
|
| 52 |
+
library version snapshot. Output ONLY a unified diff that fixes the script.
|
| 53 |
+
|
| 54 |
+
Rules:
|
| 55 |
+
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
|
| 56 |
+
headers and `@@ ... @@` hunk markers.
|
| 57 |
+
2. Make the MINIMAL change that resolves the error AND preserves the original
|
| 58 |
+
training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
|
| 59 |
+
calls.
|
| 60 |
+
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
|
| 61 |
+
4. If the error is unfixable, output an empty diff.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict
) -> str:
    """Render the Drift Generator user prompt for one episode."""
    version_summary = ", ".join(
        f"{lib}={ver}" for lib, ver in library_versions.items()
    )
    return f"""Target category: {target_category}
Library versions: {version_summary}

Working script:
```python
{script}
```

Output JSON breakage primitive:"""
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict,
    target_category: str = "",
) -> str:
    """Render the Repair Agent user prompt for one episode."""
    version_summary = ", ".join(
        f"{lib}={ver}" for lib, ver in library_versions.items()
    )
    category_hint = target_category if target_category else "unknown"
    return f"""Library versions: {version_summary}
Target category hint: {category_hint}

Broken script:
```python
{broken_script}
```

Error trace:
{error_trace}

Output unified diff (no prose, no fences):"""
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def list_primitive_descriptions() -> Iterable[str]:
    """Yield one "- Name: description" bullet per known primitive."""
    for primitive, summary in PRIMITIVE_DESCRIPTIONS.items():
        yield f"- {primitive}: {summary}"
|
forgeenv/roles/repair_agent.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Repair Agent helpers: response sanitisation + a deterministic baseline.
|
| 2 |
+
|
| 3 |
+
The Repair Agent's training output is a unified diff. LLMs frequently emit
|
| 4 |
+
prose / fences / chain-of-thought before the diff; this module strips that
|
| 5 |
+
preamble. The baseline policy uses the inverse-primitive map from
|
| 6 |
+
`repair_primitives.py` to produce ground-truth diffs for warm-start.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
from forgeenv.env.diff_utils import make_unified_diff
|
| 15 |
+
from forgeenv.primitives.breakage_primitives import (
|
| 16 |
+
parse_breakage_spec,
|
| 17 |
+
BreakagePrimitive,
|
| 18 |
+
)
|
| 19 |
+
from forgeenv.primitives.repair_primitives import (
|
| 20 |
+
BREAKAGE_TO_REPAIR,
|
| 21 |
+
REPAIR_REGISTRY,
|
| 22 |
+
RepairPrimitive,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Unified-diff hunk marker ("@@ ... @@") anywhere in the text.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Markdown code fence with optional language tag; group 1 is the content.
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles: code fences, leading prose / chain-of-thought, trailing notes.
    """
    if not raw_text:
        return ""
    candidate = raw_text.strip()

    fenced = _FENCE_RE.search(candidate)
    if fenced is not None:
        candidate = fenced.group(1).strip()

    lines = candidate.splitlines()
    # Drop everything before the first line that looks like a diff header
    # or hunk marker; if none is found, keep the whole text unchanged.
    first_diff_line = next(
        (
            index
            for index, line in enumerate(lines)
            if line.startswith(("---", "+++", "@@"))
        ),
        0,
    )
    return "\n".join(lines[first_diff_line:])
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def looks_like_diff(text: str) -> bool:
    """Cheap structural check that `text` resembles a unified diff."""
    if not text:
        return False
    header_present = "---" in text and "+++" in text
    hunk_present = _DIFF_HUNK_RE.search(text) is not None
    if header_present and hunk_present:
        return True
    # No full header: accept a hunk marker plus at least one +/- line.
    changed_line_present = any(
        line.startswith(("+", "-")) for line in text.splitlines()
    )
    return hunk_present and changed_line_present
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------- baselines
|
| 63 |
+
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent built on the primitive inverse map.

    Used for warm-start dataset generation and baseline rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Return a unified diff (or full replacement script) that fixes the
        broken script.

        Strategy preference:
        1. If `original_script` is provided, return a diff between the
           broken script and the original (oracle). This is the warm-start
           path — we always know the ground truth.
        2. Otherwise try to invert the structured breakage_spec via the
           repair-primitive registry.
        3. Otherwise return an empty diff.
        """
        # Oracle path: diff straight back to the known-good script.
        if original_script and original_script != broken_script:
            return make_unified_diff(broken_script, original_script)

        if not breakage_spec:
            return ""
        try:
            breakage = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            return ""
        if breakage is None:
            return ""

        inverse = _invert_breakage(breakage)
        if inverse is None:
            return ""
        candidate = inverse.apply(broken_script)
        # An unchanged script means the inverse was a no-op here.
        if candidate == broken_script:
            return ""
        return make_unified_diff(broken_script, candidate)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# For each breakage primitive (outer key), maps its param names (inner keys)
# to the corresponding repair primitive's constructor field names (values).
# Breakage params with no repair-side counterpart are simply not listed.
# NOTE(review): ModifyConfigField does not carry the pre-change value, so the
# derived RestoreConfigField gets old_value="" — verify callers on this path
# always prefer the original_script oracle instead.
_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"},
    "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"},
    "ChangeArgumentSignature": {
        "function_name": "function_name",
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": {"field_name": "field_name"},
    "RestructureDatasetSchema": {
        "old_column": "old_column",
        "new_column": "new_column",
    },
    "ChangeTokenizerBehavior": {
        "old_kwarg": "old_kwarg",
        "old_value": "old_value",
        "new_kwarg": "new_kwarg",
        "new_value": "new_value",
    },
    "RemoveDeprecatedMethod": {"method_name": "method_name"},
    "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"},
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes `breakage`, or None."""
    source_name = type(breakage).__name__
    target_name = BREAKAGE_TO_REPAIR.get(source_name)
    repair_cls = REPAIR_REGISTRY.get(target_name) if target_name else None
    if repair_cls is None:
        return None

    source_params = breakage._get_params()  # type: ignore[attr-defined]
    key_map = _PARAM_REMAP.get(source_name, {})
    # Only pass kwargs that are real init fields of the repair dataclass.
    init_fields = {
        f.name
        for f in repair_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if f.init
    }
    kwargs = {
        dst_key: source_params[src_key]
        for src_key, dst_key in key_map.items()
        if src_key in source_params and dst_key in init_fields
    }
    try:
        return repair_cls(**kwargs)
    except TypeError:
        return None
|
forgeenv/roles/teacher.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Teacher (curriculum controller).
|
| 2 |
+
|
| 3 |
+
Deterministic — NOT an LLM. Maintains an EMA success rate per breakage
|
| 4 |
+
category and routes the next episode toward the category where the
|
| 5 |
+
Repair Agent is closest to a 50% success rate (R-Zero's difficulty band).
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class Teacher:
    """Deterministic curriculum controller (not an LLM).

    Tracks raw and exponentially-smoothed success rates per breakage
    category and steers the next episode toward categories whose EMA
    success sits nearest the 50% difficulty band.
    """

    categories: list[str]
    alpha: float = 0.9  # EMA weight on the previous estimate
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Seed counters at zero and the EMA at the neutral 0.5 per category."""
        for name in self.categories:
            self.success_counts.setdefault(name, 0)
            self.attempt_counts.setdefault(name, 0)
            self.ema_success.setdefault(name, 0.5)

    def update(self, category: str, success: bool) -> None:
        """Record one episode outcome and refresh the category's EMA."""
        if category not in self.ema_success:
            # First sighting of this category: register it on the fly.
            self.categories.append(category)
            self.ema_success[category] = 0.5
            self.success_counts[category] = 0
            self.attempt_counts[category] = 0

        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        observed = self.success_counts[category] / max(1, self.attempt_counts[category])
        previous = self.ema_success[category]
        self.ema_success[category] = self.alpha * previous + (1 - self.alpha) * observed

    def select_next_category(self) -> str:
        """Pick the next category, favoring EMAs closest to 50%."""
        zone_distance = {
            name: abs(score - 0.5)
            for name, score in self.ema_success.items()
            if 0.3 <= score <= 0.7
        }
        if not zone_distance:
            # Nothing in the productive band: take whichever is nearest 0.5.
            return min(self.ema_success, key=lambda n: abs(self.ema_success[n] - 0.5))
        names = list(zone_distance.keys())
        weights = [1.0 / (dist + 0.01) for dist in zone_distance.values()]
        return random.choices(names, weights=weights, k=1)[0]

    def get_state(self) -> dict:
        """Return a JSON-friendly snapshot of per-category statistics."""
        snapshot: dict = {}
        for name in self.categories:
            snapshot[name] = {
                "ema_success": round(self.ema_success[name], 4),
                "attempts": self.attempt_counts[name],
                "successes": self.success_counts[name],
            }
        return snapshot
|
forgeenv/sandbox/__init__.py
ADDED
|
File without changes
|
forgeenv/sandbox/ast_validator.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AST-based script validator.
|
| 2 |
+
|
| 3 |
+
Catches forbidden imports and dangerous patterns BEFORE any execution
|
| 4 |
+
happens. This is a critical defense against reward hacking via system
|
| 5 |
+
calls, network access, or process manipulation.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import ast
|
| 10 |
+
|
| 11 |
+
from forgeenv.tasks.models import ValidationResult
|
| 12 |
+
|
| 13 |
+
# Module roots whose import is rejected outright: anything granting
# filesystem, process, network, FFI, signal, or concurrency access from a
# generated training script.
FORBIDDEN_MODULES = {
    "os",
    "subprocess",
    "socket",
    "urllib",
    "requests",
    "ctypes",
    "shutil",
    "signal",
    "multiprocessing",
    "threading",
}

# Builtins that allow dynamic code execution / import and would therefore
# trivially bypass the static import checks above.
FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def validate_script(script_content: str) -> ValidationResult:
    """Parse a script as AST and reject forbidden patterns.

    Returns a ValidationResult with `is_valid` and a list of `violations`.
    """
    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        # Unparseable scripts are rejected with a single syntax violation.
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])

    violations: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # `import x.y` is judged by its root package `x`.
            violations.extend(
                f"Forbidden import: {alias.name}"
                for alias in node.names
                if alias.name.split(".")[0] in FORBIDDEN_MODULES
            )

        if isinstance(node, ast.ImportFrom) and node.module:
            if node.module.split(".")[0] in FORBIDDEN_MODULES:
                violations.append(f"Forbidden import from: {node.module}")

        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id in FORBIDDEN_FUNCTIONS:
                violations.append(f"Forbidden call: {func.id}()")
            if isinstance(func, ast.Attribute) and func.attr in FORBIDDEN_FUNCTIONS:
                violations.append(f"Forbidden call: .{func.attr}()")

        if isinstance(node, ast.Assign):
            # Rebinding __builtins__ would let a script swap out the
            # builtin namespace entirely.
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__builtins__":
                    violations.append("Forbidden: __builtins__ assignment")

    return ValidationResult(is_valid=not violations, violations=violations)
|
forgeenv/sandbox/simulation_mode.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fast simulation executor for development.
|
| 2 |
+
|
| 3 |
+
Static-analysis-based execution simulator. Sub-100ms per call. No Docker
|
| 4 |
+
required. The success probability of a simulated run depends on whether
|
| 5 |
+
the script contains expected HF training markers (model imports, training
|
| 6 |
+
calls, save calls). When the simulation succeeds, a synthetic decreasing
|
| 7 |
+
loss curve is emitted; when it fails, a representative HF error is raised.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
import time
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from forgeenv.sandbox.ast_validator import validate_script
|
| 16 |
+
from forgeenv.tasks.models import ExecutionResult, Task
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        # Dedicated RNG when a seed is given (reproducible runs);
        # otherwise fall back to the shared module-level RNG.
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        """Statically "run" a script and return a synthetic ExecutionResult.

        Pipeline: AST safety validation -> syntax compile check -> keyword
        heuristics that set a success probability -> a random draw that
        yields either a fake decreasing loss log or a representative HF
        error. `task` is accepted for interface parity but unused here.
        """
        start = time.time()

        validation = validate_script(script_content)
        if not validation.is_valid:
            # Forbidden import/call: report all violations in stderr.
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"Validation failed: {'; '.join(validation.violations)}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        try:
            compile(script_content, "<forge_script>", "exec")
        except SyntaxError as e:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"SyntaxError: {e}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        # Substring heuristics for the three markers of a plausible
        # HF training script; each present marker raises success_prob.
        has_model_import = any(
            kw in script_content
            for kw in ("from transformers", "import torch", "from datasets")
        )
        has_training_call = any(
            kw in script_content
            for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")
        )
        has_save = any(
            kw in script_content
            for kw in ("save_pretrained", "save_model", "torch.save")
        )

        # Base 0.3, up to 0.9 with all three markers present.
        success_prob = 0.3
        if has_model_import:
            success_prob += 0.3
        if has_training_call:
            success_prob += 0.2
        if has_save:
            success_prob += 0.1

        # Mark obviously broken patterns as definite failures even when
        # they pass syntactic compilation. The simulator pretends to be a
        # static linter that catches AttributeError / ImportError signatures
        # before they would fire at runtime.
        broken_markers = (
            "_DEPRECATED(",
            "transformers.legacy",
            "from transformers.training import",
            ".start_training(",
            "load_from_hub(",
            "save_to_hub(",
            "pad_to_max_length=",
            "evaluation_loop(",
        )
        if any(marker in script_content for marker in broken_markers):
            success_prob = 0.0
        # Patterns that look like dataset column drift: a renamed column
        # that doesn't appear in real HF datasets.
        import re as _re

        # The class []:),] matches a literal ], :, ) or , after the quoted
        # name, i.e. the string used as a key/argument rather than free text.
        if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        if _re.search(r"['\"]words['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        # Tokenizer kwarg drift (truncate is not valid; truncation is).
        if _re.search(r"\btruncate\s*=", script_content):
            success_prob = min(success_prob, 0.05)

        succeeded = self._rng.random() < success_prob

        if succeeded:
            # Emit a synthetic, geometrically decaying loss curve.
            steps = self._rng.randint(20, 50)
            log_lines: list[str] = []
            loss = self._rng.uniform(2.0, 4.0)
            for step in range(1, steps + 1):
                loss *= self._rng.uniform(0.92, 0.99)
                log_lines.append(f"step={step} loss={loss:.4f}")
            log_lines.append("eval_accuracy=0.78")
            log_lines.append("TRAINING_COMPLETE")

            return ExecutionResult(
                exit_code=0,
                stdout="\n".join(log_lines),
                stderr="",
                # Fake wall time: real overhead plus simulated 1-5s of training.
                wall_time_ms=int((time.time() - start) * 1000)
                + self._rng.randint(1000, 5000),
                checkpoint_exists=True,
                peak_memory_mb=self._rng.uniform(500, 2000),
                script_content=script_content,
            )

        # Failure path: pick one representative HF runtime error at random.
        error_types = [
            "ImportError: cannot import name 'OldTrainer' from 'transformers'",
            "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
            "KeyError: 'text' column not found in dataset",
            "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
            "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
            "ModuleNotFoundError: No module named 'transformers.legacy'",
        ]
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=self._rng.choice(error_types),
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(100, 500),
            script_content=script_content,
        )
|
forgeenv/tasks/__init__.py
ADDED
|
File without changes
|
forgeenv/tasks/models.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core data models for ForgeEnv tasks and execution results.
|
| 2 |
+
|
| 3 |
+
These are framework-internal dataclasses (not Pydantic) used throughout the
|
| 4 |
+
simulation, verifier, and primitive layers. The OpenEnv-facing Pydantic
|
| 5 |
+
models live in `forgeenv.env.actions` / `forgeenv.env.observations`.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class Task:
    """A HuggingFace training script with execution metadata."""

    task_id: str  # unique identifier for this task
    description: str  # human-readable summary of what the script trains
    script_content: str  # full Python source of the training script
    difficulty: str  # "easy", "medium", "hard"
    category: str = "general"  # task/breakage category used by the curriculum
    expected_loss_range: tuple[float, float] = (0.0, 5.0)  # plausible final-loss bounds
    expected_accuracy_range: tuple[float, float] = (0.0, 1.0)  # plausible eval-accuracy bounds
    checkpoint_output_path: str = "/tmp/forge_output/checkpoint"  # where the script saves
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ExecutionResult:
    """Result of executing a Python script in the sandbox."""

    exit_code: int  # 0 on success, nonzero on failure
    stdout: str  # captured standard output (e.g. training log lines)
    stderr: str  # captured standard error (error text, empty on success)
    wall_time_ms: int  # elapsed wall-clock time in milliseconds
    checkpoint_exists: bool = False  # whether a checkpoint was produced
    peak_memory_mb: float = 0.0  # peak resident memory, if measured
    script_content: str = ""  # the script that was executed
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class ValidationResult:
    """Result of AST validation on a script."""

    is_valid: bool  # True when no forbidden patterns were found
    violations: list[str] = field(default_factory=list)  # human-readable findings
|
forgeenv/tasks/seed_corpus/__init__.py
ADDED
|
File without changes
|
forgeenv/tasks/seed_corpus/albert_qa.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ALBERT-tiny extractive QA on 100-sample SQuAD subset."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForQuestionAnswering,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
DefaultDataCollator,
|
| 8 |
+
)
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
dataset = load_dataset("squad", split="train[:100]")
|
| 12 |
+
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def prepare(examples):
|
| 16 |
+
enc = tokenizer(
|
| 17 |
+
examples["question"],
|
| 18 |
+
examples["context"],
|
| 19 |
+
max_length=128,
|
| 20 |
+
truncation="only_second",
|
| 21 |
+
padding="max_length",
|
| 22 |
+
return_offsets_mapping=True,
|
| 23 |
+
)
|
| 24 |
+
start_positions, end_positions = [], []
|
| 25 |
+
for i, offsets in enumerate(enc["offset_mapping"]):
|
| 26 |
+
answer = examples["answers"][i]
|
| 27 |
+
start_char = answer["answer_start"][0]
|
| 28 |
+
end_char = start_char + len(answer["text"][0])
|
| 29 |
+
|
| 30 |
+
token_start = next(
|
| 31 |
+
(idx for idx, (a, b) in enumerate(offsets) if a <= start_char < b), 0
|
| 32 |
+
)
|
| 33 |
+
token_end = next(
|
| 34 |
+
(idx for idx, (a, b) in enumerate(offsets) if a < end_char <= b), token_start
|
| 35 |
+
)
|
| 36 |
+
start_positions.append(token_start)
|
| 37 |
+
end_positions.append(token_end)
|
| 38 |
+
|
| 39 |
+
enc["start_positions"] = start_positions
|
| 40 |
+
enc["end_positions"] = end_positions
|
| 41 |
+
enc.pop("offset_mapping")
|
| 42 |
+
return enc
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
dataset = dataset.map(prepare, batched=True, remove_columns=dataset.column_names)
|
| 46 |
+
|
| 47 |
+
model = AutoModelForQuestionAnswering.from_pretrained("albert-base-v2")
|
| 48 |
+
|
| 49 |
+
training_args = TrainingArguments(
|
| 50 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 51 |
+
num_train_epochs=1,
|
| 52 |
+
per_device_train_batch_size=4,
|
| 53 |
+
logging_steps=5,
|
| 54 |
+
save_strategy="epoch",
|
| 55 |
+
no_cuda=True,
|
| 56 |
+
report_to="none",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
trainer = Trainer(
|
| 60 |
+
model=model,
|
| 61 |
+
args=training_args,
|
| 62 |
+
train_dataset=dataset,
|
| 63 |
+
data_collator=DefaultDataCollator(),
|
| 64 |
+
)
|
| 65 |
+
trainer.train()
|
| 66 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 67 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/bert_ner.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bert tiny NER fine-tuning on a 200-sample CoNLL-2003 subset."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForTokenClassification,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
DataCollatorForTokenClassification,
|
| 8 |
+
)
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
dataset = load_dataset("conll2003", split="train[:200]")
|
| 12 |
+
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def tokenize_and_align(example):
|
| 16 |
+
enc = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, max_length=64)
|
| 17 |
+
word_ids = enc.word_ids()
|
| 18 |
+
labels = []
|
| 19 |
+
prev_id = None
|
| 20 |
+
for wid in word_ids:
|
| 21 |
+
if wid is None:
|
| 22 |
+
labels.append(-100)
|
| 23 |
+
elif wid != prev_id:
|
| 24 |
+
labels.append(example["ner_tags"][wid])
|
| 25 |
+
else:
|
| 26 |
+
labels.append(-100)
|
| 27 |
+
prev_id = wid
|
| 28 |
+
enc["labels"] = labels
|
| 29 |
+
return enc
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
dataset = dataset.map(tokenize_and_align, remove_columns=dataset.column_names)
|
| 33 |
+
|
| 34 |
+
model = AutoModelForTokenClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=9)
|
| 35 |
+
|
| 36 |
+
training_args = TrainingArguments(
|
| 37 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 38 |
+
num_train_epochs=1,
|
| 39 |
+
per_device_train_batch_size=8,
|
| 40 |
+
logging_steps=5,
|
| 41 |
+
save_strategy="epoch",
|
| 42 |
+
no_cuda=True,
|
| 43 |
+
report_to="none",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
trainer = Trainer(
|
| 47 |
+
model=model,
|
| 48 |
+
args=training_args,
|
| 49 |
+
train_dataset=dataset,
|
| 50 |
+
data_collator=DataCollatorForTokenClassification(tokenizer),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
trainer.train()
|
| 54 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 55 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/distilbert_sst2.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DistilBERT fine-tuning on a tiny SST-2 subset.
|
| 2 |
+
|
| 3 |
+
Minimal HuggingFace text-classification training script. Should complete
|
| 4 |
+
in ~60s on CPU.
|
| 5 |
+
"""
|
| 6 |
+
from transformers import (
|
| 7 |
+
DistilBertTokenizer,
|
| 8 |
+
DistilBertForSequenceClassification,
|
| 9 |
+
Trainer,
|
| 10 |
+
TrainingArguments,
|
| 11 |
+
)
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
|
| 14 |
+
dataset = load_dataset("glue", "sst2", split="train[:500]")
|
| 15 |
+
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def tokenize_function(examples):
|
| 19 |
+
return tokenizer(
|
| 20 |
+
examples["sentence"],
|
| 21 |
+
padding="max_length",
|
| 22 |
+
truncation=True,
|
| 23 |
+
max_length=64,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
dataset = dataset.map(tokenize_function, batched=True)
|
| 28 |
+
dataset = dataset.rename_column("label", "labels")
|
| 29 |
+
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
|
| 30 |
+
|
| 31 |
+
model = DistilBertForSequenceClassification.from_pretrained(
|
| 32 |
+
"distilbert-base-uncased", num_labels=2
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
training_args = TrainingArguments(
|
| 36 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 37 |
+
num_train_epochs=1,
|
| 38 |
+
per_device_train_batch_size=16,
|
| 39 |
+
logging_steps=5,
|
| 40 |
+
save_strategy="epoch",
|
| 41 |
+
no_cuda=True,
|
| 42 |
+
report_to="none",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
trainer = Trainer(
|
| 46 |
+
model=model,
|
| 47 |
+
args=training_args,
|
| 48 |
+
train_dataset=dataset,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
trainer.train()
|
| 52 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 53 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/electra_classification.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ELECTRA-small classification on 400-sample AG News (4-way text classification)."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForSequenceClassification,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
)
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
dataset = load_dataset("ag_news", split="train[:400]")
|
| 11 |
+
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def tokenize(examples):
|
| 15 |
+
return tokenizer(
|
| 16 |
+
examples["text"],
|
| 17 |
+
padding="max_length",
|
| 18 |
+
truncation=True,
|
| 19 |
+
max_length=64,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
dataset = dataset.map(tokenize, batched=True)
|
| 24 |
+
dataset = dataset.rename_column("label", "labels")
|
| 25 |
+
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
|
| 26 |
+
|
| 27 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 28 |
+
"google/electra-small-discriminator", num_labels=4
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
training_args = TrainingArguments(
|
| 32 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 33 |
+
num_train_epochs=1,
|
| 34 |
+
per_device_train_batch_size=8,
|
| 35 |
+
logging_steps=5,
|
| 36 |
+
save_strategy="epoch",
|
| 37 |
+
no_cuda=True,
|
| 38 |
+
report_to="none",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
|
| 42 |
+
trainer.train()
|
| 43 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 44 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/gpt2_textgen.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DistilGPT2 causal-LM fine-tuning on 300 lines of WikiText (text generation)."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForCausalLM,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
DataCollatorForLanguageModeling,
|
| 8 |
+
)
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:300]")
|
| 12 |
+
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
| 13 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def tokenize(examples):
|
| 17 |
+
return tokenizer(examples["text"], truncation=True, max_length=64)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
|
| 21 |
+
|
| 22 |
+
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
| 23 |
+
|
| 24 |
+
training_args = TrainingArguments(
|
| 25 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 26 |
+
num_train_epochs=1,
|
| 27 |
+
per_device_train_batch_size=4,
|
| 28 |
+
logging_steps=5,
|
| 29 |
+
save_strategy="epoch",
|
| 30 |
+
no_cuda=True,
|
| 31 |
+
report_to="none",
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
trainer = Trainer(
|
| 35 |
+
model=model,
|
| 36 |
+
args=training_args,
|
| 37 |
+
train_dataset=dataset,
|
| 38 |
+
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
trainer.train()
|
| 42 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 43 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/logistic_classifier.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sklearn logistic-regression baseline on a 500-sample tabular task.
|
| 2 |
+
|
| 3 |
+
Sanity baseline that doesn't require torch / transformers / datasets.
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import pickle
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from sklearn.datasets import make_classification
|
| 11 |
+
from sklearn.linear_model import LogisticRegression
|
| 12 |
+
from sklearn.model_selection import train_test_split
|
| 13 |
+
|
| 14 |
+
X, y = make_classification(
|
| 15 |
+
n_samples=500, n_features=20, n_informative=10, random_state=0
|
| 16 |
+
)
|
| 17 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
|
| 18 |
+
|
| 19 |
+
model = LogisticRegression(max_iter=200)
|
| 20 |
+
for step in range(1, 11):
|
| 21 |
+
model.set_params(max_iter=step * 20)
|
| 22 |
+
model.fit(X_train, y_train)
|
| 23 |
+
train_loss = -np.mean(np.log(np.maximum(model.predict_proba(X_train)[np.arange(len(y_train)), y_train], 1e-9)))
|
| 24 |
+
print(f"step={step} loss={train_loss:.4f}")
|
| 25 |
+
|
| 26 |
+
acc = model.score(X_test, y_test)
|
| 27 |
+
print(f"eval_accuracy={acc:.4f}")
|
| 28 |
+
|
| 29 |
+
ckpt_dir = Path("/tmp/forge_output/checkpoint")
|
| 30 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
with open(ckpt_dir / "logreg.pkl", "wb") as f:
|
| 32 |
+
pickle.dump(model, f)
|
| 33 |
+
with open(ckpt_dir / "metrics.json", "w") as f:
|
| 34 |
+
json.dump({"accuracy": acc}, f)
|
| 35 |
+
|
| 36 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/roberta_sentiment.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DistilRoberta sentiment classification on 400-sample IMDB subset."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForSequenceClassification,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
)
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
dataset = load_dataset("imdb", split="train[:400]")
|
| 11 |
+
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def tokenize(examples):
|
| 15 |
+
return tokenizer(
|
| 16 |
+
examples["text"],
|
| 17 |
+
padding="max_length",
|
| 18 |
+
truncation=True,
|
| 19 |
+
max_length=64,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
dataset = dataset.map(tokenize, batched=True)
|
| 24 |
+
dataset = dataset.rename_column("label", "labels")
|
| 25 |
+
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
|
| 26 |
+
|
| 27 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 28 |
+
"distilroberta-base", num_labels=2
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
training_args = TrainingArguments(
|
| 32 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 33 |
+
num_train_epochs=1,
|
| 34 |
+
per_device_train_batch_size=8,
|
| 35 |
+
logging_steps=5,
|
| 36 |
+
save_strategy="epoch",
|
| 37 |
+
no_cuda=True,
|
| 38 |
+
report_to="none",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
|
| 42 |
+
trainer.train()
|
| 43 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 44 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/simple_regression.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny PyTorch regression on synthetic data (no HF imports — sanity baseline)."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
torch.manual_seed(0)
|
| 6 |
+
x = torch.randn(500, 4)
|
| 7 |
+
y = (x @ torch.tensor([1.5, -2.0, 0.5, 3.0])) + 0.1 * torch.randn(500)
|
| 8 |
+
|
| 9 |
+
model = nn.Sequential(
|
| 10 |
+
nn.Linear(4, 16),
|
| 11 |
+
nn.ReLU(),
|
| 12 |
+
nn.Linear(16, 1),
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
|
| 16 |
+
criterion = nn.MSELoss()
|
| 17 |
+
|
| 18 |
+
for epoch in range(50):
|
| 19 |
+
optimizer.zero_grad()
|
| 20 |
+
preds = model(x).squeeze(-1)
|
| 21 |
+
loss = criterion(preds, y)
|
| 22 |
+
loss.backward()
|
| 23 |
+
optimizer.step()
|
| 24 |
+
if (epoch + 1) % 5 == 0:
|
| 25 |
+
print(f"step={epoch + 1} loss={loss.item():.4f}")
|
| 26 |
+
|
| 27 |
+
torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/regression.pt")
|
| 28 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/t5_summarization.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny T5 fine-tuning for summarization on 100-sample CNN/DailyMail."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForSeq2SeqLM,
|
| 5 |
+
DataCollatorForSeq2Seq,
|
| 6 |
+
Seq2SeqTrainer,
|
| 7 |
+
Seq2SeqTrainingArguments,
|
| 8 |
+
)
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:100]")
|
| 12 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def preprocess(examples):
|
| 16 |
+
inputs = tokenizer(
|
| 17 |
+
["summarize: " + a for a in examples["article"]],
|
| 18 |
+
max_length=128,
|
| 19 |
+
truncation=True,
|
| 20 |
+
padding="max_length",
|
| 21 |
+
)
|
| 22 |
+
targets = tokenizer(
|
| 23 |
+
examples["highlights"],
|
| 24 |
+
max_length=32,
|
| 25 |
+
truncation=True,
|
| 26 |
+
padding="max_length",
|
| 27 |
+
)
|
| 28 |
+
inputs["labels"] = targets["input_ids"]
|
| 29 |
+
return inputs
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
|
| 33 |
+
|
| 34 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
|
| 35 |
+
|
| 36 |
+
training_args = Seq2SeqTrainingArguments(
|
| 37 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 38 |
+
num_train_epochs=1,
|
| 39 |
+
per_device_train_batch_size=4,
|
| 40 |
+
logging_steps=5,
|
| 41 |
+
save_strategy="epoch",
|
| 42 |
+
no_cuda=True,
|
| 43 |
+
report_to="none",
|
| 44 |
+
predict_with_generate=False,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
trainer = Seq2SeqTrainer(
|
| 48 |
+
model=model,
|
| 49 |
+
args=training_args,
|
| 50 |
+
train_dataset=dataset,
|
| 51 |
+
data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
|
| 52 |
+
)
|
| 53 |
+
trainer.train()
|
| 54 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 55 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny PyTorch MLP on a 1000-sample MNIST subset (image classification baseline)."""
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
|
| 7 |
+
dataset = load_dataset("mnist", split="train[:1000]")
|
| 8 |
+
dataset = dataset.with_format("torch")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def collate(batch):
|
| 12 |
+
pixel = torch.stack([b["image"].float().flatten() / 255.0 for b in batch])
|
| 13 |
+
labels = torch.tensor([b["label"] for b in batch])
|
| 14 |
+
return pixel, labels
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)
|
| 18 |
+
|
| 19 |
+
model = nn.Sequential(
|
| 20 |
+
nn.Linear(784, 64),
|
| 21 |
+
nn.ReLU(),
|
| 22 |
+
nn.Linear(64, 10),
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
| 26 |
+
criterion = nn.CrossEntropyLoss()
|
| 27 |
+
|
| 28 |
+
for epoch in range(2):
|
| 29 |
+
for step, (x, y) in enumerate(loader, start=1):
|
| 30 |
+
optimizer.zero_grad()
|
| 31 |
+
loss = criterion(model(x), y)
|
| 32 |
+
loss.backward()
|
| 33 |
+
optimizer.step()
|
| 34 |
+
if step % 5 == 0:
|
| 35 |
+
print(f"step={step} loss={loss.item():.4f}")
|
| 36 |
+
|
| 37 |
+
torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/mlp.pt")
|
| 38 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/seed_corpus/vit_cifar10.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny ViT image classification on 200-sample CIFAR-10 subset."""
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoImageProcessor,
|
| 4 |
+
AutoModelForImageClassification,
|
| 5 |
+
Trainer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
)
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
dataset = load_dataset("cifar10", split="train[:200]")
|
| 12 |
+
processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def transform(batch):
|
| 16 |
+
images = [img.convert("RGB") for img in batch["img"]]
|
| 17 |
+
inputs = processor(images=images, return_tensors="pt")
|
| 18 |
+
inputs["labels"] = torch.tensor(batch["label"])
|
| 19 |
+
return inputs
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
dataset = dataset.with_transform(transform)
|
| 23 |
+
|
| 24 |
+
model = AutoModelForImageClassification.from_pretrained(
|
| 25 |
+
"WinKawaks/vit-tiny-patch16-224", num_labels=10, ignore_mismatched_sizes=True
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
training_args = TrainingArguments(
|
| 29 |
+
output_dir="/tmp/forge_output/checkpoint",
|
| 30 |
+
num_train_epochs=1,
|
| 31 |
+
per_device_train_batch_size=4,
|
| 32 |
+
logging_steps=5,
|
| 33 |
+
save_strategy="epoch",
|
| 34 |
+
no_cuda=True,
|
| 35 |
+
report_to="none",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
|
| 39 |
+
trainer.train()
|
| 40 |
+
trainer.save_model("/tmp/forge_output/checkpoint")
|
| 41 |
+
print("TRAINING_COMPLETE")
|
forgeenv/tasks/task_sampler.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task sampler: loads the seed corpus and samples Tasks by difficulty.
|
| 2 |
+
|
| 3 |
+
Difficulty is auto-derived from script line count. Category is auto-detected
|
| 4 |
+
from script content (text_classification, ner, translation, etc.).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from forgeenv.tasks.models import Task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _detect_category(content: str) -> str:
|
| 16 |
+
cl = content.lower()
|
| 17 |
+
if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl:
|
| 18 |
+
return "text_classification"
|
| 19 |
+
if "tokenclassification" in cl or "ner" in cl or "conll" in cl:
|
| 20 |
+
return "ner"
|
| 21 |
+
if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl:
|
| 22 |
+
return "seq2seq"
|
| 23 |
+
if "causallm" in cl or "gpt2" in cl or "wikitext" in cl:
|
| 24 |
+
return "text_generation"
|
| 25 |
+
if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl:
|
| 26 |
+
return "image_classification"
|
| 27 |
+
if "questionanswering" in cl or "squad" in cl:
|
| 28 |
+
return "qa"
|
| 29 |
+
if "logisticregression" in cl or "make_classification" in cl:
|
| 30 |
+
return "tabular"
|
| 31 |
+
if "regression" in cl:
|
| 32 |
+
return "regression"
|
| 33 |
+
return "general"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _derive_difficulty(content: str) -> str:
|
| 37 |
+
lines = len(content.splitlines())
|
| 38 |
+
if lines < 30:
|
| 39 |
+
return "easy"
|
| 40 |
+
if lines < 60:
|
| 41 |
+
return "medium"
|
| 42 |
+
return "hard"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TaskSampler:
    """Loads seed corpus and samples tasks by difficulty / category."""

    def __init__(self, seed_dir: Optional[str] = None) -> None:
        # Default to the seed_corpus directory shipped next to this module.
        if seed_dir is None:
            seed_dir = str(Path(__file__).parent / "seed_corpus")
        # Ordered list of every loaded Task; empty if the corpus dir is missing.
        self.tasks: list[Task] = []
        self._load_corpus(seed_dir)

    def _load_corpus(self, seed_dir: str) -> None:
        """Turn every non-dunder *.py file under seed_dir into a Task."""
        root = Path(seed_dir)
        if not root.exists():
            return

        for script in sorted(root.glob("*.py")):
            if script.name.startswith("__"):
                continue

            source = script.read_text(encoding="utf-8")
            summary = self._leading_docstring(source)
            self.tasks.append(
                Task(
                    task_id=script.stem,
                    description=summary or f"Training script: {script.stem}",
                    script_content=source,
                    difficulty=_derive_difficulty(source),
                    category=_detect_category(source),
                )
            )

    @staticmethod
    def _leading_docstring(source: str) -> str:
        """Return the leading triple-double-quoted summary, or "" if absent."""
        if not source.startswith('"""'):
            return ""
        closing = source.find('"""', 3)
        if closing == -1:
            return ""
        return source[3:closing].strip()

    def sample(self, difficulty: Optional[str] = None) -> Optional[Task]:
        """Pick a random task; prefer the requested difficulty when any match.

        Falls back to the full corpus when no task has that difficulty, and
        returns None only when the corpus is empty.
        """
        pool = self.tasks
        if difficulty is not None:
            matching = [task for task in self.tasks if task.difficulty == difficulty]
            if matching:
                pool = matching
        if not pool:
            return None
        return random.choice(pool)

    def sample_batch(
        self, n: int, difficulty: Optional[str] = None
    ) -> list[Task]:
        """Sample n tasks with replacement; shorter list if the corpus is empty."""
        batch: list[Task] = []
        for _ in range(n):
            picked = self.sample(difficulty)
            if picked is not None:
                batch.append(picked)
        return batch

    def get_all_categories(self) -> list[str]:
        """Sorted unique categories present in the loaded corpus."""
        return sorted({task.category for task in self.tasks})

    def get_by_id(self, task_id: str) -> Optional[Task]:
        """Linear lookup of a task by id; None when absent."""
        return next((task for task in self.tasks if task.task_id == task_id), None)
|
forgeenv/verifier/__init__.py
ADDED
|
File without changes
|
forgeenv/verifier/held_out_evaluator.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Held-out evaluator: the deterministic ground-truth scorer.
|
| 2 |
+
|
| 3 |
+
Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees
|
| 4 |
+
this directly; the Drift Generator's training signal derives from
|
| 5 |
+
alignment between the visible verifier and this evaluator (Pearson
|
| 6 |
+
correlation across the K rollouts).
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import ast
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
from forgeenv.tasks.models import ExecutionResult, Task
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Compute 7 independent held-out components, each in [0, 1].

    ``repair_diff`` is accepted for interface compatibility but is not read
    by any component in this implementation.
    """
    scores: dict[str, float] = {}
    # Hard pass/fail signals read straight off the execution record.
    scores["executed_cleanly"] = 1.0 if result.exit_code == 0 else 0.0
    scores["checkpoint_valid"] = 1.0 if result.checkpoint_exists else 0.0
    # Continuous signals parsed out of captured stdout.
    scores["loss_decreased"] = _compute_loss_score(result.stdout)
    scores["metrics_in_range"] = _check_metrics(result.stdout, task)
    # Static-analysis signals over the repaired script's text.
    scores["no_forbidden_workarounds"] = _check_workarounds(result.script_content)
    scores["intent_preserved"] = _compute_intent_preservation(
        task.script_content, result.script_content
    )
    # Sentinel every seed script prints on successful completion.
    scores["hidden_tests_passed"] = 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0
    return scores
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _compute_loss_score(stdout: str) -> float:
|
| 36 |
+
"""Continuous score based on relative loss decrease from first to last step."""
|
| 37 |
+
|
| 38 |
+
losses: list[float] = []
|
| 39 |
+
for line in stdout.splitlines():
|
| 40 |
+
match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
|
| 41 |
+
if match:
|
| 42 |
+
try:
|
| 43 |
+
losses.append(float(match.group(1)))
|
| 44 |
+
except ValueError:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
if len(losses) < 2:
|
| 48 |
+
return 0.0
|
| 49 |
+
|
| 50 |
+
decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8)
|
| 51 |
+
return max(0.0, min(1.0, decrease))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _check_metrics(stdout: str, task: Task) -> float:
|
| 55 |
+
"""Return 1.0 if any reported accuracy/eval metric falls in the task's
|
| 56 |
+
expected range; partial credit otherwise; 0.5 if no metric was found."""
|
| 57 |
+
|
| 58 |
+
for line in stdout.splitlines():
|
| 59 |
+
match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE)
|
| 60 |
+
if match:
|
| 61 |
+
try:
|
| 62 |
+
val = float(match.group(1))
|
| 63 |
+
low, high = task.expected_accuracy_range
|
| 64 |
+
if low <= val <= high:
|
| 65 |
+
return 1.0
|
| 66 |
+
distance = min(abs(val - low), abs(val - high))
|
| 67 |
+
return max(0.0, 1.0 - distance)
|
| 68 |
+
except ValueError:
|
| 69 |
+
continue
|
| 70 |
+
return 0.5
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _check_workarounds(script_content: str) -> float:
|
| 74 |
+
"""Detect forbidden workaround patterns via AST analysis.
|
| 75 |
+
|
| 76 |
+
Catches: bare except, `except Exception: pass`, `except Exception: return`,
|
| 77 |
+
monkey-patching of `__getattr__` / `__class__` / `__dict__`.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
if not script_content:
|
| 81 |
+
return 0.0
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
tree = ast.parse(script_content)
|
| 85 |
+
except SyntaxError:
|
| 86 |
+
return 0.0
|
| 87 |
+
|
| 88 |
+
violations = 0
|
| 89 |
+
|
| 90 |
+
for node in ast.walk(tree):
|
| 91 |
+
if isinstance(node, ast.Try):
|
| 92 |
+
for handler in node.handlers:
|
| 93 |
+
if handler.type is None:
|
| 94 |
+
violations += 1
|
| 95 |
+
elif (
|
| 96 |
+
isinstance(handler.type, ast.Name)
|
| 97 |
+
and handler.type.id == "Exception"
|
| 98 |
+
):
|
| 99 |
+
if len(handler.body) == 1 and isinstance(
|
| 100 |
+
handler.body[0], (ast.Pass, ast.Return)
|
| 101 |
+
):
|
| 102 |
+
violations += 1
|
| 103 |
+
|
| 104 |
+
if isinstance(node, ast.Assign):
|
| 105 |
+
for target in node.targets:
|
| 106 |
+
if isinstance(target, ast.Attribute):
|
| 107 |
+
if target.attr in ("__getattr__", "__class__", "__dict__"):
|
| 108 |
+
violations += 1
|
| 109 |
+
|
| 110 |
+
return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _compute_intent_preservation(original: str, repaired: str) -> float:
|
| 114 |
+
"""Measure how much of the original AST structure is preserved.
|
| 115 |
+
|
| 116 |
+
Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...).
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
if not original or not repaired:
|
| 120 |
+
return 0.0
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
orig_tree = ast.parse(original)
|
| 124 |
+
repair_tree = ast.parse(repaired)
|
| 125 |
+
except SyntaxError:
|
| 126 |
+
return 0.0
|
| 127 |
+
|
| 128 |
+
orig_nodes = len(list(ast.walk(orig_tree)))
|
| 129 |
+
repair_nodes = len(list(ast.walk(repair_tree)))
|
| 130 |
+
|
| 131 |
+
if orig_nodes == 0:
|
| 132 |
+
return 0.0
|
| 133 |
+
|
| 134 |
+
return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)
|
forgeenv/verifier/visible_verifier.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visible verifier: the immediate reward signal the Repair Agent sees.
|
| 2 |
+
|
| 3 |
+
4 weighted components, summed to a scalar. This is what drives the Repair
|
| 4 |
+
Agent's GRPO updates each rollout. Multiple independent components were
|
| 5 |
+
chosen on purpose, per the reward-engineering survey (arxiv 2408.10215)
|
| 6 |
+
and software-tasks survey (arxiv 2601.19100): a single scalar is far
|
| 7 |
+
easier to game than a composable rubric.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
from forgeenv.tasks.models import ExecutionResult, Task
|
| 14 |
+
|
| 15 |
+
# Per-component weights for the visible (agent-facing) scalar reward.
# Keys must match the component names filled in by compute_visible_reward;
# the scalar is sum(component_value * weight) over these entries.
WEIGHTS: dict[str, float] = {
    "script_executes": 1.0,
    "loss_decreased": 0.5,
    "checkpoint_appeared": 0.3,
    "diff_size_penalty": 0.2,  # multiplied with a non-positive component value
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def compute_visible_reward(
    result: ExecutionResult, task: Task
) -> tuple[float, dict[str, float]]:
    """Compute scalar visible reward and per-component breakdown.

    Components (weighted by WEIGHTS): binary execution success, fraction of
    strictly decreasing loss steps, binary checkpoint presence, and a size
    penalty that only activates once the repaired script's line count drifts
    more than 50% from the original's.
    """
    # Relative line-count drift between the original and repaired scripts;
    # a missing repaired script counts as zero drift.
    baseline_lines = max(len(task.script_content.splitlines()), 1)
    if result.script_content:
        repaired_lines = len(result.script_content.splitlines())
    else:
        repaired_lines = baseline_lines
    drift = abs(repaired_lines - baseline_lines) / baseline_lines

    components: dict[str, float] = {
        "script_executes": 1.0 if result.exit_code == 0 else 0.0,
        "loss_decreased": _check_loss_trend(result.stdout),
        "checkpoint_appeared": 1.0 if result.checkpoint_exists else 0.0,
        # Negative only past the 50% drift threshold.
        "diff_size_penalty": -drift if drift > 0.5 else 0.0,
    }

    total = sum(value * WEIGHTS[name] for name, value in components.items())
    return total, components
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _check_loss_trend(stdout: str) -> float:
|
| 46 |
+
"""Parse stdout for `loss=...` patterns and return the fraction of
|
| 47 |
+
consecutive steps where loss strictly decreased."""
|
| 48 |
+
|
| 49 |
+
losses: list[float] = []
|
| 50 |
+
for line in stdout.splitlines():
|
| 51 |
+
match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
|
| 52 |
+
if match:
|
| 53 |
+
try:
|
| 54 |
+
losses.append(float(match.group(1)))
|
| 55 |
+
except ValueError:
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
if len(losses) < 2:
|
| 59 |
+
return 0.0
|
| 60 |
+
|
| 61 |
+
decreasing_steps = sum(
|
| 62 |
+
1 for i in range(1, len(losses)) if losses[i] < losses[i - 1]
|
| 63 |
+
)
|
| 64 |
+
return decreasing_steps / (len(losses) - 1)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: forgeenv
|
| 2 |
+
version: 0.1.0
|
| 3 |
+
description: >
|
| 4 |
+
Self-improving RL environment for HuggingFace ecosystem repair.
|
| 5 |
+
Trains agents to fix broken training scripts under library version drift
|
| 6 |
+
through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL,
|
| 7 |
+
and Absolute Zero Reasoner techniques on top of OpenEnv.
|
| 8 |
+
theme: self-improvement
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- self-play
|
| 12 |
+
- code-repair
|
| 13 |
+
- schema-drift
|
| 14 |
+
- multi-role
|
| 15 |
+
- huggingface
|
| 16 |
+
- reinforcement-learning
|
| 17 |
+
environment:
|
| 18 |
+
class: forgeenv.env.forge_environment.ForgeEnvironment
|
| 19 |
+
action_model: forgeenv.env.actions.ForgeAction
|
| 20 |
+
observation_model: forgeenv.env.observations.ForgeObservation
|
| 21 |
+
server:
|
| 22 |
+
module: forgeenv.env.server
|
| 23 |
+
app: app
|
| 24 |
+
port: 7860
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.0
|
| 2 |
+
fastapi>=0.110.0
|
| 3 |
+
uvicorn[standard]>=0.27.0
|
| 4 |
+
pydantic>=2.6.0
|
| 5 |
+
pyyaml>=6.0
|
| 6 |
+
nltk>=3.8.0
|
| 7 |
+
scikit-learn>=1.4.0
|
| 8 |
+
numpy>=1.26.0
|
| 9 |
+
rich>=13.7.0
|