akhiilll commited on
Commit
fb3e132
·
verified ·
1 Parent(s): bf2b681

ForgeEnv deploy

Browse files
Files changed (44) hide show
  1. Dockerfile +25 -0
  2. README.md +85 -10
  3. forgeenv/__init__.py +4 -0
  4. forgeenv/artifacts/repair_library.py +120 -0
  5. forgeenv/drift/__init__.py +0 -0
  6. forgeenv/drift/library_drift_engine.py +74 -0
  7. forgeenv/env/__init__.py +0 -0
  8. forgeenv/env/actions.py +50 -0
  9. forgeenv/env/diff_utils.py +163 -0
  10. forgeenv/env/forge_environment.py +259 -0
  11. forgeenv/env/observations.py +29 -0
  12. forgeenv/env/server.py +46 -0
  13. forgeenv/primitives/__init__.py +0 -0
  14. forgeenv/primitives/breakage_primitives.py +282 -0
  15. forgeenv/primitives/drift_taxonomy.yaml +217 -0
  16. forgeenv/primitives/repair_primitives.py +241 -0
  17. forgeenv/roles/__init__.py +0 -0
  18. forgeenv/roles/drift_generator.py +170 -0
  19. forgeenv/roles/prompts.py +102 -0
  20. forgeenv/roles/repair_agent.py +153 -0
  21. forgeenv/roles/teacher.py +58 -0
  22. forgeenv/sandbox/__init__.py +0 -0
  23. forgeenv/sandbox/ast_validator.py +70 -0
  24. forgeenv/sandbox/simulation_mode.py +142 -0
  25. forgeenv/tasks/__init__.py +0 -0
  26. forgeenv/tasks/models.py +45 -0
  27. forgeenv/tasks/seed_corpus/__init__.py +0 -0
  28. forgeenv/tasks/seed_corpus/albert_qa.py +67 -0
  29. forgeenv/tasks/seed_corpus/bert_ner.py +55 -0
  30. forgeenv/tasks/seed_corpus/distilbert_sst2.py +53 -0
  31. forgeenv/tasks/seed_corpus/electra_classification.py +44 -0
  32. forgeenv/tasks/seed_corpus/gpt2_textgen.py +43 -0
  33. forgeenv/tasks/seed_corpus/logistic_classifier.py +36 -0
  34. forgeenv/tasks/seed_corpus/roberta_sentiment.py +44 -0
  35. forgeenv/tasks/seed_corpus/simple_regression.py +28 -0
  36. forgeenv/tasks/seed_corpus/t5_summarization.py +55 -0
  37. forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py +38 -0
  38. forgeenv/tasks/seed_corpus/vit_cifar10.py +41 -0
  39. forgeenv/tasks/task_sampler.py +105 -0
  40. forgeenv/verifier/__init__.py +0 -0
  41. forgeenv/verifier/held_out_evaluator.py +134 -0
  42. forgeenv/verifier/visible_verifier.py +64 -0
  43. openenv.yaml +24 -0
  44. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ ENV PYTHONUNBUFFERED=1 \
4
+ PYTHONDONTWRITEBYTECODE=1 \
5
+ PIP_NO_CACHE_DIR=1
6
+
7
+ RUN apt-get update \
8
+ && apt-get install -y --no-install-recommends git curl \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ WORKDIR /app
12
+
13
+ COPY requirements.txt .
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ COPY forgeenv/ forgeenv/
17
+ COPY openenv.yaml .
18
+
19
+ ENV PORT=7860
20
+ EXPOSE 7860
21
+
22
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
23
+ CMD curl -f http://127.0.0.1:7860/health || exit 1
24
+
25
+ CMD ["uvicorn", "forgeenv.env.server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,85 @@
1
- ---
2
- title: Forgeenv
3
- emoji: 🏆
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ForgeEnv
3
+ emoji: 🔧
4
+ colorFrom: indigo
5
+ colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: true
9
+ license: apache-2.0
10
+ tags:
11
+ - openenv
12
+ - self-play
13
+ - self-improvement
14
+ - code-repair
15
+ - schema-drift
16
+ - reinforcement-learning
17
+ - huggingface
18
+ short_description: Self-improving RL env for HF library-drift repair
19
+ ---
20
+
21
+ # ForgeEnv — OpenEnv Server
22
+
23
+ This Space hosts the **ForgeEnv** OpenEnv-compliant environment as a FastAPI
24
+ service. It exposes the standard `reset`, `step`, and `state` endpoints and is
25
+ the runtime that training notebooks (TRL + Unsloth) connect to.
26
+
27
+ > **Theme:** Self-Improvement (Hackathon Theme #4) — Challenger / Solver
28
+ > co-evolution via R-Zero, SPIRAL, and Absolute Zero Reasoner techniques.
29
+
30
+ ## What it does
31
+
32
+ ForgeEnv simulates **HuggingFace library version drift**. A *Drift Generator*
33
+ proposes a realistic breakage to a working training script (renamed APIs,
34
+ deprecated imports, changed argument signatures, etc.). A *Repair Agent* then
35
+ emits a unified diff that should restore the script. Reward is computed by an
36
+ execution simulator + AST checker + held-out evaluator (multi-component to
37
+ resist reward hacking).
38
+
39
+ ## API
40
+
41
+ The server uses [`openenv-core`](https://pypi.org/project/openenv-core/) and
42
+ follows the Gym-style contract:
43
+
44
+ | Endpoint | Method | Purpose |
45
+ | -------- | ------ | -------------------------------------------------- |
46
+ | `/reset` | POST | Sample a fresh task, return drift-gen observation |
47
+ | `/step` | POST | Apply a `ForgeAction` (breakage or repair) |
48
+ | `/state` | GET | Inspect the current internal state |
49
+ | `/health`| GET | Health probe (used by the container HEALTHCHECK) |
50
+
51
+ `ForgeAction` is a discriminated union of `BreakageAction` (used in phase 1)
52
+ and `RepairAction` (used in phase 2). See
53
+ [`forgeenv/env/actions.py`](forgeenv/env/actions.py).
54
+
55
+ ## Quick test
56
+
57
+ ```bash
58
+ curl -X POST https://akhiilll-forgeenv.hf.space/reset
59
+ curl https://akhiilll-forgeenv.hf.space/state
60
+ ```
61
+
62
+ ```python
63
+ from openenv.core.env_client import EnvClient
64
+
65
+ async with EnvClient(base_url="https://akhiilll-forgeenv.hf.space") as client:
66
+ obs = await client.reset()
67
+ print(obs.observation.current_phase, obs.observation.task_id)
68
+ ```
69
+
70
+ ## Project links
71
+
72
+ - **Main repo / training notebooks / plots:**
73
+ <https://github.com/akhiilll/forgeenv>
74
+ - **Repair Agent model (LoRA):**
75
+ <https://huggingface.co/akhiilll/forgeenv-repair-agent>
76
+ - **Demo (Gradio + ZeroGPU):**
77
+ <https://huggingface.co/spaces/akhiilll/forgeenv-demo>
78
+
79
+ ## Citations
80
+
81
+ - Huang et al., *R-Zero: Self-Evolving Reasoning LLM From Zero Data* (2025)
82
+ - Zhao et al., *Absolute Zero: Reinforced Self-play Reasoning with Zero Data* (2025)
83
+ - Liu et al., *SPIRAL: Self-Play on Zero-Sum Games* (2025)
84
+ - [arXiv:2408.10215](https://arxiv.org/abs/2408.10215) — Reward engineering & shaping
85
+ - [arXiv:2601.19100](https://arxiv.org/abs/2601.19100) — Reward engineering for RL in software tasks
forgeenv/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """ForgeEnv: Self-improving RL environment for HuggingFace ecosystem repair."""
2
+
3
+ __version__ = "0.1.0"
4
+ __author__ = "akhiilll"
forgeenv/artifacts/repair_library.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Persisted "repair library" — the model's accumulated knowledge of
2
+ known breakage -> repair pairs. Curated from successful rollouts during
3
+ training. Loaded at inference time as a few-shot prefix when the agent
4
+ recognises a familiar error class.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from dataclasses import asdict, dataclass, field
10
+ from pathlib import Path
11
+ from typing import Any, Optional
12
+
13
+
14
@dataclass
class RepairExample:
    """One curated breakage -> repair pair harvested from a successful rollout.

    Stored fields mirror what the verifier produced for that rollout so the
    example can later be replayed as a few-shot demonstration.
    """

    primitive_type: str            # breakage primitive class name
    breakage_params: dict[str, Any]
    error_signature: str           # leading slice of the observed error trace
    repair_diff: str
    visible_reward: float
    held_out: dict[str, float]     # held-out evaluator breakdown
    task_id: str = ""

    def signature_key(self) -> str:
        """Bucket key: primitive plus a truncated error signature, so
        near-identical tracebacks collapse into the same bucket."""
        truncated = self.error_signature[:80]
        return f"{self.primitive_type}::{truncated}"
26
+
27
+
28
@dataclass
class RepairLibrary:
    """Growing collection of successful repairs, queryable by error class."""

    examples: list[RepairExample] = field(default_factory=list)

    def add(self, example: RepairExample) -> None:
        """Append one curated example to the library."""
        self.examples.append(example)

    def best_match(self, primitive_type: str, error_text: str) -> Optional[RepairExample]:
        """Pick the stored example of the same primitive whose error signature
        is closest to `error_text`, breaking ties by visible reward."""
        same_kind = [e for e in self.examples if e.primitive_type == primitive_type]
        if not same_kind:
            return None
        # max() returns the first-encountered maximum, matching the order a
        # stable reverse sort would have produced.
        return max(
            same_kind,
            key=lambda e: (_ngram_overlap(e.error_signature, error_text), e.visible_reward),
        )

    def to_dict(self) -> dict:
        """Serializable snapshot (versioned for forward compatibility)."""
        return {
            "version": "1",
            "examples": [asdict(e) for e in self.examples],
            "size": len(self.examples),
            "by_primitive": _count_by_primitive(self.examples),
        }

    def save(self, path: str | Path) -> None:
        """Write the library as pretty-printed JSON, creating parent dirs."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> "RepairLibrary":
        """Rebuild a library from a file previously written by `save`."""
        raw = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(examples=[RepairExample(**e) for e in raw.get("examples", [])])
71
+
72
+
73
+ def _ngram_overlap(a: str, b: str, n: int = 3) -> float:
74
+ if not a or not b:
75
+ return 0.0
76
+
77
+ def grams(text: str) -> set[str]:
78
+ text = text.lower()
79
+ return {text[i : i + n] for i in range(len(text) - n + 1)}
80
+
81
+ ga, gb = grams(a), grams(b)
82
+ if not ga or not gb:
83
+ return 0.0
84
+ return len(ga & gb) / max(1, len(ga | gb))
85
+
86
+
87
+ def _count_by_primitive(examples: list[RepairExample]) -> dict[str, int]:
88
+ counts: dict[str, int] = {}
89
+ for e in examples:
90
+ counts[e.primitive_type] = counts.get(e.primitive_type, 0) + 1
91
+ return counts
92
+
93
+
94
def curate_from_rollouts(
    rollout_results: list,
    min_reward: float = 0.6,
    min_held_out_clean: float = 0.5,
) -> RepairLibrary:
    """Build a RepairLibrary from a list of rollout dicts/RolloutResults.

    Keeps only rollouts whose visible reward and held-out ``executed_cleanly``
    score both clear the given thresholds.

    Args:
        rollout_results: Mixed list of dicts or attribute-style objects.
        min_reward: Minimum visible reward for inclusion.
        min_held_out_clean: Minimum held-out ``executed_cleanly`` score.

    Returns:
        A RepairLibrary with one RepairExample per qualifying rollout.
    """
    lib = RepairLibrary()
    for r in rollout_results:
        if isinstance(r, dict):
            get = r.get
        else:
            # Bind `r` via a default so the accessor is self-contained.
            def get(key, default=None, _obj=r):
                return getattr(_obj, key, default)

        visible = float(get("visible_reward", 0.0) or 0.0)
        if visible < min_reward:
            continue
        # Fix: the field may be absent *or* explicitly None — `or {}` guards
        # both (the original crashed with AttributeError on an explicit None).
        held_out = get("held_out_breakdown", {}) or {}
        if float(held_out.get("executed_cleanly", 0.0)) < min_held_out_clean:
            continue

        info = get("info", {})
        if not isinstance(info, dict):
            info = {}
        breakage_spec = info.get("breakage_spec", {}) or {}
        lib.add(
            RepairExample(
                primitive_type=str(get("primitive_type", "unknown")),
                breakage_params=dict(breakage_spec.get("params", {}) or {}),
                error_signature=str(get("error_trace", "") or "")[:160],
                repair_diff=str(get("repair_completion", "") or info.get("repair_diff", ""))[:2000],
                visible_reward=visible,
                held_out=dict(held_out),
                task_id=str(get("task_id", "")),
            )
        )
    return lib
forgeenv/drift/__init__.py ADDED
File without changes
forgeenv/drift/library_drift_engine.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Library Drift Engine.
2
+
3
+ Manages library version snapshots and triggers version upgrades during
4
+ training to create non-stationary verification. In simulation mode it
5
+ just tracks the current snapshot index — that index influences
6
+ breakage selection and is exposed in observations so the Repair Agent
7
+ can adapt.
8
+
9
+ Also exposes Chojecki GVU's SNR computation
10
+ (https://arxiv.org/abs/2512.02731 Definition 4.4).
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from dataclasses import dataclass, field
16
+
17
+ DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
18
+ {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
19
+ {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
20
+ {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
21
+ {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
22
+ ]
23
+
24
+
25
+ @dataclass
26
+ class LibraryDriftEngine:
27
+ snapshots: list[dict[str, str]] = field(
28
+ default_factory=lambda: list(DEFAULT_VERSION_SNAPSHOTS)
29
+ )
30
+ current_index: int = 0
31
+ drift_history: list[dict] = field(default_factory=list)
32
+
33
+ def current_versions(self) -> dict[str, str]:
34
+ return dict(self.snapshots[self.current_index])
35
+
36
+ def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
37
+ if (
38
+ episode_num > 0
39
+ and episode_num % drift_every == 0
40
+ and self.current_index < len(self.snapshots) - 1
41
+ ):
42
+ prev = self.snapshots[self.current_index]
43
+ self.current_index += 1
44
+ self.drift_history.append(
45
+ {
46
+ "episode": episode_num,
47
+ "from": prev,
48
+ "to": self.snapshots[self.current_index],
49
+ }
50
+ )
51
+ return True
52
+ return False
53
+
54
+ def reset(self) -> None:
55
+ self.current_index = 0
56
+ self.drift_history.clear()
57
+
58
+ @staticmethod
59
+ def compute_snr(
60
+ recent_held_out: list[float], recent_visible: list[float]
61
+ ) -> dict[str, float]:
62
+ """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards)."""
63
+
64
+ def snr(values: list[float]) -> float:
65
+ if len(values) < 2:
66
+ return 0.0
67
+ mean = sum(values) / len(values)
68
+ var = sum((v - mean) ** 2 for v in values) / len(values)
69
+ return mean**2 / max(var, 1e-8)
70
+
71
+ return {
72
+ "snr_verifier": snr(recent_held_out),
73
+ "snr_generator": snr(recent_visible),
74
+ }
forgeenv/env/__init__.py ADDED
File without changes
forgeenv/env/actions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic action models for ForgeEnv (compatible with OpenEnv 0.2.x).
2
+
3
+ Episodes have two phases — drift_gen (Challenger) and repair (Solver) — so
4
+ we expose a single union ForgeAction that carries either a BreakageAction
5
+ or a RepairAction. The environment dispatches on which sub-field is set.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Literal, Optional
10
+
11
+ from pydantic import Field
12
+
13
+ from openenv.core import Action
14
+
15
+
16
class BreakageAction(Action):
    """Challenger (Drift Generator) move: pick a breakage primitive plus
    the parameters it should be applied with."""

    # Constant tag so serialized actions are self-describing.
    action_type: Literal["breakage"] = "breakage"
    primitive_type: str = Field(
        ..., description="One of the registered breakage primitive class names"
    )
    params: dict[str, Any] = Field(
        default_factory=dict, description="Primitive-specific parameters"
    )
26
+
27
+
28
class RepairAction(Action):
    """Solver (Repair Agent) move: a unified diff, or a full script."""

    # Constant tag so serialized actions are self-describing.
    action_type: Literal["repair"] = "repair"
    # Free-form text; the environment's diff applier is deliberately permissive.
    unified_diff: str = Field(..., description="Unified diff or full replacement script")
33
+
34
+
35
class ForgeAction(Action):
    """Union action: exactly one of `breakage` / `repair` must be set.

    This is the type registered with OpenEnv's `create_app`. A plain pair
    of optional sub-actions (rather than a Pydantic discriminated union)
    keeps the OpenAPI schema flat and cross-version-friendly.
    """

    breakage: Optional[BreakageAction] = None
    repair: Optional[RepairAction] = None

    def model_post_init(self, __context: Any) -> None:
        # Reject both "neither set" and "both set" in one check.
        populated = sum(sub is not None for sub in (self.breakage, self.repair))
        if populated != 1:
            raise ValueError(
                "ForgeAction requires exactly one of `breakage` or `repair` to be set."
            )
forgeenv/env/diff_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified-diff application utilities.
2
+
3
+ The Repair Agent submits a unified diff. We need a permissive applier
4
+ because LLM diffs are often malformed (wrong line numbers, missing
5
+ context, extra prose). We try the strict applier first, then fall
6
+ back to applying hunks via plain string replacement.
7
+
8
+ The agent may also submit a full Python script instead of a diff
9
+ (common when the model's diff format breaks). We detect this and
10
+ treat it as a complete replacement.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import difflib
15
+ import re
16
+
17
+
18
# Matches a unified-diff hunk header anywhere in the text.
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Tokens that suggest "this is a Python script", not a diff.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff."""
    script_lines = text.lstrip().splitlines()
    if not script_lines:
        return False
    # Any diff header near the top means "this is a diff".
    for candidate in script_lines[:5]:
        if candidate.startswith(("---", "+++", "@@")):
            return False
    # Two or more script-style markers within the first 30 lines is enough
    # evidence to treat the text as a complete replacement script.
    head = "\n".join(script_lines[:30])
    marker_hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head)
    return marker_hits >= 2
37
+
38
+
39
+ def _strict_apply(broken_script: str, diff_text: str) -> str | None:
40
+ """Apply a unified diff strictly. Returns None on any failure."""
41
+ lines = broken_script.splitlines(keepends=True)
42
+ out: list[str] = []
43
+ diff_lines = diff_text.splitlines()
44
+ i = 0
45
+ src_idx = 0
46
+ in_hunk = False
47
+ hunk_old: list[str] = []
48
+ hunk_new: list[str] = []
49
+
50
+ while i < len(diff_lines):
51
+ line = diff_lines[i]
52
+ if line.startswith(("---", "+++")):
53
+ i += 1
54
+ continue
55
+ if line.startswith("@@"):
56
+ # Flush previous hunk
57
+ if in_hunk:
58
+ # Find the hunk_old block in the source starting at src_idx.
59
+ target = "".join(hunk_old)
60
+ source_remainder = "".join(lines[src_idx:])
61
+ pos = source_remainder.find(target)
62
+ if pos == -1:
63
+ return None
64
+ out.append(source_remainder[:pos])
65
+ out.append("".join(hunk_new))
66
+ src_idx += len(source_remainder[: pos + len(target)].splitlines(keepends=True))
67
+ hunk_old, hunk_new = [], []
68
+ in_hunk = True
69
+ i += 1
70
+ continue
71
+ if in_hunk:
72
+ if line.startswith("+"):
73
+ hunk_new.append(line[1:] + "\n")
74
+ elif line.startswith("-"):
75
+ hunk_old.append(line[1:] + "\n")
76
+ else:
77
+ # context line
78
+ ctx = line[1:] if line.startswith(" ") else line
79
+ hunk_old.append(ctx + "\n")
80
+ hunk_new.append(ctx + "\n")
81
+ i += 1
82
+
83
+ # Flush trailing hunk
84
+ if in_hunk and (hunk_old or hunk_new):
85
+ target = "".join(hunk_old)
86
+ source_remainder = "".join(lines[src_idx:])
87
+ pos = source_remainder.find(target)
88
+ if pos == -1:
89
+ return None
90
+ out.append(source_remainder[:pos])
91
+ out.append("".join(hunk_new))
92
+ consumed = source_remainder[: pos + len(target)]
93
+ src_idx += len(consumed.splitlines(keepends=True))
94
+
95
+ out.append("".join(lines[src_idx:]))
96
+ return "".join(out)
97
+
98
+
99
+ def _permissive_apply(broken_script: str, diff_text: str) -> str:
100
+ """Apply a malformed diff by extracting (-,+) line pairs and doing
101
+ a tolerant search-and-replace.
102
+ """
103
+ repaired = broken_script
104
+ pairs: list[tuple[str, str]] = []
105
+ lines = diff_text.splitlines()
106
+ pending_minus: str | None = None
107
+
108
+ for line in lines:
109
+ if line.startswith("---") or line.startswith("+++") or line.startswith("@@"):
110
+ pending_minus = None
111
+ continue
112
+ if line.startswith("-"):
113
+ pending_minus = line[1:].strip()
114
+ elif line.startswith("+") and pending_minus is not None:
115
+ pairs.append((pending_minus, line[1:].strip()))
116
+ pending_minus = None
117
+ elif pending_minus is not None and not line.startswith(" "):
118
+ # standalone deletion — skip in permissive mode (we can't
119
+ # reliably know what to delete without context)
120
+ pending_minus = None
121
+
122
+ for old, new in pairs:
123
+ if old and old in repaired:
124
+ repaired = repaired.replace(old, new, 1)
125
+
126
+ return repaired
127
+
128
+
129
def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Apply the agent's submission using progressively looser strategies.

    Order of attempts:
    1. Full-script replacement when the text doesn't look like a diff.
    2. Strict unified-diff application (must actually change something).
    3. Permissive (-,+) line-pair replacement.
    4. The broken script unchanged, as last resort.
    """
    submission = diff_text or ""
    if not submission.strip():
        return broken_script

    if looks_like_full_script(submission):
        return submission

    has_diff_syntax = (
        _HUNK_RE.search(submission) is not None
        or "---" in submission
        or "+++" in submission
    )
    if has_diff_syntax:
        strictly_applied = _strict_apply(broken_script, submission)
        if strictly_applied is not None and strictly_applied != broken_script:
            return strictly_applied

    return _permissive_apply(broken_script, submission)
152
+
153
+
154
def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Canonical unified diff (2 context lines) from `before` to `after`."""
    hunks = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    return "".join(hunks)
forgeenv/env/forge_environment.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ForgeEnvironment: the OpenEnv Environment subclass for ForgeEnv.
2
+
3
+ Episode flow (exactly 2 steps per episode):
4
+ reset() -> sample task, ask Teacher for category
5
+ step(BreakageAction) -> Drift Generator's proposal is applied; broken
6
+ script is run, error trace captured.
7
+ step(RepairAction) -> Repair diff is applied; script is re-executed;
8
+ visible + held-out rewards computed; episode ends.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import time
13
+ import uuid
14
+ from typing import Any, Optional
15
+
16
+ from openenv.core import Environment
17
+
18
+ from forgeenv.drift.library_drift_engine import LibraryDriftEngine
19
+ from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
20
+ from forgeenv.env.diff_utils import apply_unified_diff
21
+ from forgeenv.env.observations import ForgeObservation
22
+ from forgeenv.primitives.breakage_primitives import (
23
+ PRIMITIVE_REGISTRY,
24
+ parse_breakage_spec,
25
+ )
26
+ from forgeenv.roles.teacher import Teacher
27
+ from forgeenv.sandbox.simulation_mode import SimulationExecutor
28
+ from forgeenv.tasks.models import ExecutionResult, Task
29
+ from forgeenv.tasks.task_sampler import TaskSampler
30
+ from forgeenv.verifier.held_out_evaluator import compute_held_out_scores
31
+ from forgeenv.verifier.visible_verifier import compute_visible_reward
32
+
33
+ DEFAULT_CATEGORIES = sorted(PRIMITIVE_REGISTRY.keys())
34
+
35
+
36
class ForgeEnvironment(Environment[ForgeAction, ForgeObservation, dict]):
    """OpenEnv-compliant environment for HuggingFace ecosystem repair.

    Two-phase episode state machine:
      "drift_gen" -> step(BreakageAction) -> "repair" -> step(RepairAction) -> "done".
    All episode state lives on the instance, which is why concurrent
    sessions are disabled below.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False  # Teacher state is global per env

    def __init__(
        self,
        task_sampler: Optional[TaskSampler] = None,
        teacher: Optional[Teacher] = None,
        executor: Optional[SimulationExecutor] = None,
        drift_engine: Optional[LibraryDriftEngine] = None,
        seed: Optional[int] = None,
    ) -> None:
        """Wire up collaborators (all optional, defaulting to fresh instances).

        Args:
            task_sampler: Source of seed-corpus tasks.
            teacher: Curriculum controller choosing breakage categories.
            executor: Simulated script executor; receives `seed`.
            drift_engine: Library-version snapshot tracker.
            seed: Forwarded to the default SimulationExecutor only.
        """
        super().__init__()
        self.task_sampler = task_sampler or TaskSampler()
        # Fallback category keeps the Teacher usable even if the primitive
        # registry is empty.
        self.teacher = teacher or Teacher(
            categories=list(DEFAULT_CATEGORIES) or ["api_drift"]
        )
        self.executor = executor or SimulationExecutor(seed=seed)
        self.drift_engine = drift_engine or LibraryDriftEngine()

        # Per-episode mutable state (reset() reinitialises most of these).
        self._episode_id: Optional[str] = None
        self._episode_count: int = 0
        self._current_task: Optional[Task] = None
        self._original_script: str = ""
        self._broken_script: str = ""
        self._error_trace: str = ""
        self._breakage_spec: Optional[dict[str, Any]] = None
        self._target_category: str = ""
        # "idle" until the first reset(); then "drift_gen" -> "repair" -> "done".
        self._current_phase: str = "idle"
        self._last_obs: Optional[ForgeObservation] = None

    # ------------------------------------------------------------------ API
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        difficulty: Optional[str] = "easy",
        **kwargs: Any,
    ) -> ForgeObservation:
        """Start a new episode: sample a task and enter the drift_gen phase.

        Raises:
            RuntimeError: if the task sampler has no tasks to offer.
        """
        self._episode_id = episode_id or str(uuid.uuid4())
        self._episode_count += 1
        # Teacher picks which breakage category the Challenger should target.
        self._target_category = self.teacher.select_next_category()

        task = self.task_sampler.sample(difficulty=difficulty)
        if task is None:
            raise RuntimeError("Task sampler returned no tasks (empty seed corpus?)")
        self._current_task = task
        self._original_script = task.script_content
        self._broken_script = ""
        self._error_trace = ""
        self._breakage_spec = None
        self._current_phase = "drift_gen"

        # Library drift trigger every 50 episodes (configurable from outside).
        drifted = self.drift_engine.maybe_drift(self._episode_count, drift_every=50)

        obs = ForgeObservation(
            current_phase="drift_gen",
            task_id=task.task_id,
            task_description=task.description,
            target_category=self._target_category,
            script_content=self._original_script,
            error_trace=None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=0,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "episode_count": self._episode_count,
                "drift_triggered": drifted,
                "available_primitives": sorted(PRIMITIVE_REGISTRY),
            },
        )
        self._last_obs = obs
        return obs

    def step(
        self,
        action: ForgeAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ForgeObservation:
        """Dispatch the action to the handler for the current phase.

        A phase/action mismatch (or stepping outside an episode) produces a
        terminal error observation instead of raising, so the HTTP server
        stays up. `timeout_s` is accepted for interface compatibility but
        not used here.
        """
        if self._current_phase == "drift_gen":
            if action.breakage is None:
                return self._error_obs("Expected BreakageAction in drift_gen phase")
            return self._handle_breakage(action.breakage)

        if self._current_phase == "repair":
            if action.repair is None:
                return self._error_obs("Expected RepairAction in repair phase")
            return self._handle_repair(action.repair)

        return self._error_obs(
            f"step() called in invalid phase {self._current_phase!r} — call reset() first"
        )

    @property
    def state(self) -> dict:
        """Snapshot of internal state, served by the /state endpoint."""
        return {
            "phase": self._current_phase,
            "episode_id": self._episode_id,
            "episode_count": self._episode_count,
            "task_id": self._current_task.task_id if self._current_task else None,
            "target_category": self._target_category,
            "library_versions": self.drift_engine.current_versions(),
            "teacher": self.teacher.get_state(),
            "drift_history": list(self.drift_engine.drift_history),
            "breakage_spec": dict(self._breakage_spec) if self._breakage_spec else None,
        }

    # ---------------------------------------------------------------- helpers
    def _handle_breakage(self, breakage: BreakageAction) -> ForgeObservation:
        """Phase 1: apply the Challenger's breakage and run the broken script."""
        spec = {"primitive_type": breakage.primitive_type, "params": dict(breakage.params)}
        try:
            primitive = parse_breakage_spec(spec)
        except ValueError as exc:
            return self._error_obs(f"Invalid breakage spec: {exc}")

        try:
            self._broken_script = primitive.apply(self._original_script)
        except Exception as exc:  # primitive bug — surface but don't crash server
            return self._error_obs(f"Primitive apply failed: {exc}")

        self._breakage_spec = spec

        # Execute the broken script once to harvest the error trace the
        # Repair Agent will see.
        result = self.executor.execute(self._broken_script, self._current_task)
        if result.exit_code != 0:
            self._error_trace = result.stderr or "non-zero exit code, no stderr"
        else:
            # The breakage didn't actually break it; still proceed to repair phase
            # (no-op repair is then a valid choice).
            self._error_trace = "Script ran without observable error"

        self._current_phase = "repair"

        obs = ForgeObservation(
            current_phase="repair",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=primitive.category,
            script_content=self._broken_script,
            error_trace=self._error_trace,
            library_versions=self.drift_engine.current_versions(),
            episode_step=1,
            done=False,
            reward=0.0,
            info={
                "episode_id": self._episode_id,
                "breakage_primitive": primitive.name,
                "breakage_description": primitive.description,
            },
        )
        self._last_obs = obs
        return obs

    def _handle_repair(self, repair: RepairAction) -> ForgeObservation:
        """Phase 2: apply the repair diff, re-execute, score, end the episode."""
        repaired = apply_unified_diff(self._broken_script, repair.unified_diff or "")

        t0 = time.time()
        result = self.executor.execute(repaired, self._current_task)
        result.script_content = repaired  # ensure verifier sees what we ran
        wall_ms = int((time.time() - t0) * 1000)

        # Two reward channels: a visible reward returned to the agent, and a
        # held-out breakdown for anti-reward-hacking evaluation.
        visible_reward, visible_breakdown = compute_visible_reward(
            result, self._current_task
        )
        held_out = compute_held_out_scores(
            result, self._current_task, repair_diff=repair.unified_diff or ""
        )

        success = result.exit_code == 0
        category = (
            self._breakage_spec.get("primitive_type", "unknown")
            if self._breakage_spec
            else "unknown"
        )
        # Update Teacher's curriculum state
        self.teacher.update(category, success)

        self._current_phase = "done"

        obs = ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id,
            task_description=self._current_task.description,
            target_category=category,
            script_content=repaired,
            error_trace=result.stderr or None,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=visible_reward,
            reward_breakdown=visible_breakdown,
            held_out_breakdown=held_out,
            info={
                "episode_id": self._episode_id,
                "exit_code": result.exit_code,
                "wall_time_ms": wall_ms,
                "checkpoint_exists": result.checkpoint_exists,
                "stdout_tail": "\n".join(result.stdout.splitlines()[-5:]),
                "breakage_spec": self._breakage_spec,
                "teacher_state": self.teacher.get_state(),
            },
        )
        self._last_obs = obs
        return obs

    def _error_obs(self, message: str) -> ForgeObservation:
        """Return a `done=True` error observation rather than raising."""
        return ForgeObservation(
            current_phase="done",
            task_id=self._current_task.task_id if self._current_task else "",
            task_description=self._current_task.description if self._current_task else "",
            target_category=self._target_category,
            script_content=self._broken_script or self._original_script,
            error_trace=message,
            library_versions=self.drift_engine.current_versions(),
            episode_step=2,
            done=True,
            reward=0.0,
            info={"error": message, "episode_id": self._episode_id},
        )
forgeenv/env/observations.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic observation model for ForgeEnv."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Optional
5
+
6
+ from pydantic import Field
7
+
8
+ from openenv.core import Observation
9
+
10
+
11
class ForgeObservation(Observation):
    """What the agent (or the trainer's rollout function) sees at each step.

    Inherits `done`, `reward`, `metadata` from the OpenEnv `Observation` base.
    """

    # Episode phase; tells the client which action is expected next.
    current_phase: str = Field(
        ..., description="One of 'drift_gen', 'repair', 'verify', 'done'"
    )
    # Identity and prompt context of the sampled seed task.
    task_id: str = ""
    task_description: str = ""
    # Breakage category targeted this episode (e.g. "api_drift").
    target_category: str = ""
    script_content: str = Field(default="", description="Current state of the script")
    # Stderr/traceback from the last sandbox run, when one exists.
    error_trace: Optional[str] = None
    # Library version pins as reported by the drift engine.
    # NOTE(review): exact key/value format depends on the drift engine — confirm.
    library_versions: dict[str, str] = Field(default_factory=dict)
    # Visible reward components; held-out metrics are reported separately.
    reward_breakdown: dict[str, Any] = Field(default_factory=dict)
    held_out_breakdown: dict[str, float] = Field(default_factory=dict)
    episode_step: int = 0
    # Free-form diagnostics (exit codes, stdout tail, teacher state, ...).
    info: dict[str, Any] = Field(default_factory=dict)
forgeenv/env/server.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ForgeEnv (OpenEnv-compliant).
2
+
3
+ Exposes /reset, /step, /state HTTP endpoints via OpenEnv's `create_app`.
4
+ HF Spaces sets PORT=7860 automatically.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import os
9
+
10
+ from openenv.core import create_app
11
+
12
+ from forgeenv.env.actions import ForgeAction
13
+ from forgeenv.env.forge_environment import ForgeEnvironment
14
+ from forgeenv.env.observations import ForgeObservation
15
+
16
# Build the OpenEnv-compliant ASGI app: `create_app` wires the /reset, /step
# and /state endpoints around ForgeEnvironment with typed action/observation IO.
app = create_app(
    env=ForgeEnvironment,
    action_cls=ForgeAction,
    observation_cls=ForgeObservation,
    env_name="forgeenv",
)
22
+
23
+
24
# Lightweight health endpoint for Docker / HF Spaces probes. We attach it
# only if `create_app` didn't already register one, so we don't shadow
# whatever the OpenEnv version ships with.
def _ensure_health_route(_app) -> None:
    """Register GET /health on `_app` unless that path already exists."""
    registered = [getattr(route, "path", None) for route in getattr(_app, "routes", [])]
    if any(path == "/health" for path in registered):
        return

    @_app.get("/health")
    def _health() -> dict:
        return {"status": "ok", "env": "forgeenv"}


_ensure_health_route(app)
40
+
41
+
42
if __name__ == "__main__":
    # Local / container entrypoint. HF Spaces injects PORT=7860; default to
    # the same value so `python server.py` behaves identically outside Spaces.
    import uvicorn

    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
forgeenv/primitives/__init__.py ADDED
File without changes
forgeenv/primitives/breakage_primitives.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """8 breakage primitives representing real HuggingFace/PyTorch ecosystem drift.
2
+
3
+ Each primitive transforms a working script to simulate a library upgrade
4
+ breakage. They double as the Drift Generator's structured action space.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass, field
11
+
12
+
13
@dataclass
class BreakagePrimitive(ABC):
    """Abstract base class for all breakage types.

    Subclasses set `category`/`name`/`description` in `__post_init__` and
    implement `apply` as a pure string -> string transform over the script.
    """

    # init=False: these are derived per-subclass, never supplied by callers
    # (or by LLM-generated specs parsed via `parse_breakage_spec`).
    category: str = field(default="generic", init=False)
    name: str = field(default="BreakagePrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to introduce the breakage."""

    def to_spec(self) -> dict:
        """Serialize to JSON-compatible spec for the LLM action space.

        Round-trips through `parse_breakage_spec`.
        """
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return a JSON-serializable dict of constructor parameters."""
37
+
38
@dataclass
class RenameApiCall(BreakagePrimitive):
    """Rename a function/method call to simulate API deprecation."""

    old_name: str = ""
    new_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RenameApiCall"
        self.description = f"Rename {self.old_name} -> {self.new_name}"

    def apply(self, script: str) -> str:
        """Swap every whole-identifier occurrence of `old_name` for `new_name`."""
        if not self.old_name:
            return script
        # Lookarounds reject partial-word hits (e.g. `food` when renaming `foo`).
        needle = re.compile(r"(?<!\w)" + re.escape(self.old_name) + r"(?!\w)")
        return needle.sub(self.new_name, script)

    def _get_params(self) -> dict:
        return {"old_name": self.old_name, "new_name": self.new_name}
59
+
60
+
61
@dataclass
class DeprecateImport(BreakagePrimitive):
    """Change an import path to simulate module restructuring."""

    old_module: str = ""
    new_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "DeprecateImport"
        self.description = f"Move {self.old_module} -> {self.new_module}"

    def apply(self, script: str) -> str:
        """Plain substring rewrite of the import path; no-op when unset."""
        if self.old_module:
            return script.replace(self.old_module, self.new_module)
        return script

    def _get_params(self) -> dict:
        return {"old_module": self.old_module, "new_module": self.new_module}
80
+
81
+
82
@dataclass
class ChangeArgumentSignature(BreakagePrimitive):
    """Remove an expected kwarg (and document a new required one).

    Only the removal is applied to the script; `added_arg`/`added_value`
    travel in the spec so the repair side knows what to reinstate.
    """

    function_name: str = ""
    removed_arg: str = ""
    added_arg: str = ""
    added_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeArgumentSignature"
        self.description = (
            f"Change args of {self.function_name}: -{self.removed_arg} +{self.added_arg}"
        )

    def apply(self, script: str) -> str:
        """Delete every `removed_arg=<value>` kwarg occurrence in the script.

        Note: `function_name` is metadata only; removal is applied globally.
        """
        if not self.removed_arg:
            return script
        # Matches `arg = value` up to the next comma/closing paren plus an
        # optional trailing comma. NOTE(review): `[^,)]+` is bracket-unaware,
        # so values containing commas inside [] or () (e.g. list literals) are
        # only partially removed — confirm acceptable before widening tasks.
        pattern = rf"(\b{re.escape(self.removed_arg)}\s*=\s*[^,)]+,?\s*)"
        return re.sub(pattern, "", script)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "removed_arg": self.removed_arg,
            "added_arg": self.added_arg,
            "added_value": self.added_value,
        }
111
+
112
+
113
@dataclass
class ModifyConfigField(BreakagePrimitive):
    """Change a config-class default value to simulate behaviour drift."""

    config_class: str = ""
    field_name: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "ModifyConfigField"
        self.description = f"Change {self.config_class}.{self.field_name}"

    def apply(self, script: str) -> str:
        """Rewrite `field_name = <value>` assignments/kwargs to `new_value`.

        Two fixes over the naive version:
        * `(?<!\\w)` anchors the field name so e.g. `lr` no longer rewrites
          `foo_lr = ...`.
        * The replacement is built by a callable, so backslashes or `\\g<..>`
          sequences in `new_value` are inserted literally rather than being
          interpreted as `re.sub` template escapes.
        """
        if not self.field_name:
            return script
        pattern = rf"(?<!\w)({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, lambda m: m.group(1) + self.new_value, script)

    def _get_params(self) -> dict:
        return {
            "config_class": self.config_class,
            "field_name": self.field_name,
            "new_value": self.new_value,
        }
138
+
139
+
140
@dataclass
class RestructureDatasetSchema(BreakagePrimitive):
    """Rename a dataset column reference to simulate schema drift."""

    old_column: str = ""
    new_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestructureDatasetSchema"
        self.description = f"Rename column {self.old_column} -> {self.new_column}"

    def apply(self, script: str) -> str:
        """Rewrite the quoted column name in both quoting styles."""
        if not self.old_column:
            return script
        for quote in ('"', "'"):
            script = script.replace(
                f"{quote}{self.old_column}{quote}",
                f"{quote}{self.new_column}{quote}",
            )
        return script

    def _get_params(self) -> dict:
        return {"old_column": self.old_column, "new_column": self.new_column}
163
+
164
+
165
@dataclass
class ChangeTokenizerBehavior(BreakagePrimitive):
    """Change tokenizer call arguments."""

    old_kwarg: str = ""
    old_value: str = ""
    new_kwarg: str = ""
    new_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "ChangeTokenizerBehavior"
        self.description = f"Change tokenizer kwarg {self.old_kwarg}={self.old_value} -> {self.new_kwarg}={self.new_value}"

    def apply(self, script: str) -> str:
        """Rewrite `old_kwarg=old_value` (whitespace-tolerant) to the new pair."""
        if not self.old_kwarg:
            return script
        drifted = re.escape(self.old_kwarg) + r"\s*=\s*" + re.escape(self.old_value)
        return re.sub(drifted, f"{self.new_kwarg}={self.new_value}", script)

    def _get_params(self) -> dict:
        return {
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
        }
193
+
194
+
195
@dataclass
class RemoveDeprecatedMethod(BreakagePrimitive):
    """Remove a method that has been deprecated, leaving a sentinel that
    raises AttributeError-style errors when the script runs."""

    class_name: str = ""
    method_name: str = ""
    replacement: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RemoveDeprecatedMethod"
        self.description = f"Remove {self.class_name}.{self.method_name}"

    def apply(self, script: str) -> str:
        """Rewrite `.method(` call sites to a `_DEPRECATED` sentinel name."""
        if not self.method_name:
            return script
        call_site = f".{self.method_name}("
        sentinel = f".{self.method_name}_DEPRECATED("
        return script.replace(call_site, sentinel)

    def _get_params(self) -> dict:
        return {
            "class_name": self.class_name,
            "method_name": self.method_name,
            "replacement": self.replacement,
        }
222
+
223
+
224
@dataclass
class ChangeReturnType(BreakagePrimitive):
    """A function now returns a different structure (e.g. tuple -> object)."""

    function_name: str = ""
    old_access: str = ""
    new_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "ChangeReturnType"
        self.description = f"Change return type of {self.function_name}"

    def apply(self, script: str) -> str:
        """Rewrite the old access expression; no-op if either side is unset."""
        if not (self.old_access and self.new_access):
            return script
        return script.replace(self.old_access, self.new_access)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "old_access": self.old_access,
            "new_access": self.new_access,
        }
248
+
249
+
250
# Maps primitive class names (as they appear in JSON breakage specs) to their
# classes. `parse_breakage_spec` validates LLM output against this registry.
PRIMITIVE_REGISTRY: dict[str, type[BreakagePrimitive]] = {
    "RenameApiCall": RenameApiCall,
    "DeprecateImport": DeprecateImport,
    "ChangeArgumentSignature": ChangeArgumentSignature,
    "ModifyConfigField": ModifyConfigField,
    "RestructureDatasetSchema": RestructureDatasetSchema,
    "ChangeTokenizerBehavior": ChangeTokenizerBehavior,
    "RemoveDeprecatedMethod": RemoveDeprecatedMethod,
    "ChangeReturnType": ChangeReturnType,
}
260
+
261
+
262
def parse_breakage_spec(spec: dict) -> BreakagePrimitive:
    """Parse a JSON breakage spec into a BreakagePrimitive object.

    Tolerates extra keys; ignores unknown params (LLMs hallucinate these).
    A non-dict `params` value is treated as empty rather than crashing.

    Raises:
        ValueError: if `primitive_type` is missing or not in the registry.
    """
    ptype = spec.get("primitive_type", "")
    params = spec.get("params", {}) or {}
    if not isinstance(params, dict):
        # LLMs occasionally emit a list or string here; fall back to defaults.
        params = {}

    if ptype not in PRIMITIVE_REGISTRY:
        raise ValueError(
            f"Unknown primitive type: {ptype!r}. "
            f"Valid types: {list(PRIMITIVE_REGISTRY)}"
        )

    cls = PRIMITIVE_REGISTRY[ptype]
    # Public dataclass introspection API instead of the __dataclass_fields__
    # dunder (which needed a type-ignore). Filter to init-able fields so a
    # hallucinated kwarg can't crash the constructor.
    from dataclasses import fields as dc_fields

    valid_fields = {f.name for f in dc_fields(cls) if f.init}
    filtered = {k: v for k, v in params.items() if k in valid_fields}
    return cls(**filtered)
forgeenv/primitives/drift_taxonomy.yaml ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Drift taxonomy: real HuggingFace/PyTorch breakages observed across version bumps.
2
+ # Used to seed the Drift Generator's initial proposal distribution and to anchor
3
+ # warm-start pair generation in things that actually happened in the wild.
4
+ - version_range: "transformers 4.36 -> 4.45"
5
+ affected_api: "Trainer.evaluate"
6
+ description: "Trainer.evaluate() return type changed shape; metrics now nested under .metrics"
7
+ breakage_primitive: "ChangeReturnType"
8
+ params:
9
+ function_name: "evaluate"
10
+ old_access: "trainer.evaluate()"
11
+ new_access: "trainer.evaluate().metrics"
12
+ repair_primitive: "RestoreReturnAccess"
13
+ category: "api_drift"
14
+
15
+ - version_range: "transformers 4.30 -> 4.40"
16
+ affected_api: "TrainingArguments.evaluation_strategy"
17
+ description: "Renamed evaluation_strategy -> eval_strategy"
18
+ breakage_primitive: "RenameApiCall"
19
+ params:
20
+ old_name: "evaluation_strategy"
21
+ new_name: "eval_strategy"
22
+ repair_primitive: "RestoreApiCall"
23
+ category: "api_drift"
24
+
25
+ - version_range: "datasets 2.14 -> 3.0"
26
+ affected_api: "load_dataset"
27
+ description: "Default split column was renamed in some GLUE configs"
28
+ breakage_primitive: "RestructureDatasetSchema"
29
+ params:
30
+ old_column: "label"
31
+ new_column: "labels"
32
+ repair_primitive: "RestoreColumn"
33
+ category: "dataset_drift"
34
+
35
+ - version_range: "transformers 4.40 -> 4.50"
36
+ affected_api: "Trainer.predict"
37
+ description: "Method removed; users should use evaluate() with prediction_loss_only=False"
38
+ breakage_primitive: "RemoveDeprecatedMethod"
39
+ params:
40
+ class_name: "Trainer"
41
+ method_name: "predict"
42
+ replacement: "evaluate"
43
+ repair_primitive: "RestoreMethod"
44
+ category: "api_drift"
45
+
46
+ - version_range: "transformers 4.36 -> 4.40"
47
+ affected_api: "TrainingArguments"
48
+ description: "num_train_epochs default behavior changed; max_steps now preferred"
49
+ breakage_primitive: "ModifyConfigField"
50
+ params:
51
+ config_class: "TrainingArguments"
52
+ field_name: "num_train_epochs"
53
+ new_value: "0"
54
+ repair_primitive: "RestoreConfigField"
55
+ category: "config_drift"
56
+
57
+ - version_range: "transformers 4.34 -> 4.42"
58
+ affected_api: "Tokenizer.__call__"
59
+ description: "padding=True semantics changed; users should pass padding='max_length'"
60
+ breakage_primitive: "ChangeTokenizerBehavior"
61
+ params:
62
+ old_kwarg: "padding"
63
+ old_value: "True"
64
+ new_kwarg: "padding"
65
+ new_value: '"max_length"'
66
+ repair_primitive: "RestoreTokenizerKwarg"
67
+ category: "tokenizer_drift"
68
+
69
+ - version_range: "transformers 4.20 -> 4.30"
70
+ affected_api: "imports"
71
+ description: "transformers.training_args moved to transformers.training_args_pt"
72
+ breakage_primitive: "DeprecateImport"
73
+ params:
74
+ old_module: "from transformers.training_args"
75
+ new_module: "from transformers.training_args_pt"
76
+ repair_primitive: "RestoreImport"
77
+ category: "import_drift"
78
+
79
+ - version_range: "transformers 4.45 -> 4.50"
80
+ affected_api: "save_pretrained"
81
+ description: "save_pretrained() now requires safe_serialization to default True"
82
+ breakage_primitive: "ChangeArgumentSignature"
83
+ params:
84
+ function_name: "save_pretrained"
85
+ removed_arg: "safe_serialization"
86
+ added_arg: "safe_serialization"
87
+ added_value: "True"
88
+ repair_primitive: "RestoreArgument"
89
+ category: "api_drift"
90
+
91
+ - version_range: "datasets 2.18 -> 3.0"
92
+ affected_api: "Dataset.set_format"
93
+ description: "set_format(type='torch') signature stricter, columns required"
94
+ breakage_primitive: "ChangeArgumentSignature"
95
+ params:
96
+ function_name: "set_format"
97
+ removed_arg: "columns"
98
+ added_arg: "columns"
99
+ added_value: '["input_ids", "attention_mask", "labels"]'
100
+ repair_primitive: "RestoreArgument"
101
+ category: "api_drift"
102
+
103
+ - version_range: "transformers 4.36 -> 4.45"
104
+ affected_api: "Tokenizer.__call__"
105
+ description: "max_length default reduced from 512 -> 256 for some tokenizers"
106
+ breakage_primitive: "ModifyConfigField"
107
+ params:
108
+ config_class: "tokenizer"
109
+ field_name: "max_length"
110
+ new_value: "256"
111
+ repair_primitive: "RestoreConfigField"
112
+ category: "tokenizer_drift"
113
+
114
+ - version_range: "transformers 4.40 -> 4.45"
115
+ affected_api: "DataCollatorWithPadding"
116
+ description: "Renamed `tokenizer` -> `processing_class` in DataCollator constructors"
117
+ breakage_primitive: "RenameApiCall"
118
+ params:
119
+ old_name: "tokenizer"
120
+ new_name: "processing_class"
121
+ repair_primitive: "RestoreApiCall"
122
+ category: "api_drift"
123
+
124
+ - version_range: "datasets 2.14 -> 2.18"
125
+ affected_api: "load_dataset"
126
+ description: "Some splits renamed train[:500] semantics changed"
127
+ breakage_primitive: "RestructureDatasetSchema"
128
+ params:
129
+ old_column: "sentence"
130
+ new_column: "text"
131
+ repair_primitive: "RestoreColumn"
132
+ category: "dataset_drift"
133
+
134
+ - version_range: "transformers 4.45 -> 4.50"
135
+ affected_api: "Trainer"
136
+ description: "Trainer.evaluate removed; renamed to evaluate_legacy"
137
+ breakage_primitive: "RemoveDeprecatedMethod"
138
+ params:
139
+ class_name: "Trainer"
140
+ method_name: "evaluate"
141
+ replacement: "evaluate_legacy"
142
+ repair_primitive: "RestoreMethod"
143
+ category: "api_drift"
144
+
145
+ - version_range: "transformers 4.30 -> 4.40"
146
+ affected_api: "PreTrainedModel.from_pretrained"
147
+ description: "torch_dtype now required for some quantized model paths"
148
+ breakage_primitive: "ChangeArgumentSignature"
149
+ params:
150
+ function_name: "from_pretrained"
151
+ removed_arg: "torch_dtype"
152
+ added_arg: "torch_dtype"
153
+ added_value: '"auto"'
154
+ repair_primitive: "RestoreArgument"
155
+ category: "api_drift"
156
+
157
+ - version_range: "datasets 3.0 -> 3.2"
158
+ affected_api: "Dataset.rename_column"
159
+ description: "rename_column raises if target name exists"
160
+ breakage_primitive: "RestructureDatasetSchema"
161
+ params:
162
+ old_column: "labels"
163
+ new_column: "label"
164
+ repair_primitive: "RestoreColumn"
165
+ category: "dataset_drift"
166
+
167
+ - version_range: "transformers 4.36 -> 4.42"
168
+ affected_api: "TrainingArguments.report_to"
169
+ description: "Default report_to changed from 'all' to 'none'"
170
+ breakage_primitive: "ModifyConfigField"
171
+ params:
172
+ config_class: "TrainingArguments"
173
+ field_name: "report_to"
174
+ new_value: '"all"'
175
+ repair_primitive: "RestoreConfigField"
176
+ category: "config_drift"
177
+
178
+ - version_range: "transformers 4.40 -> 4.50"
179
+ affected_api: "imports"
180
+ description: "transformers.deepspeed moved to accelerate.utils.deepspeed"
181
+ breakage_primitive: "DeprecateImport"
182
+ params:
183
+ old_module: "from transformers.deepspeed"
184
+ new_module: "from accelerate.utils.deepspeed"
185
+ repair_primitive: "RestoreImport"
186
+ category: "import_drift"
187
+
188
+ - version_range: "transformers 4.45 -> 4.50"
189
+ affected_api: "Tokenizer return"
190
+ description: "Tokenizer call output now returns a BatchEncoding with .encodings attribute"
191
+ breakage_primitive: "ChangeReturnType"
192
+ params:
193
+ function_name: "tokenizer"
194
+ old_access: "tokenizer(text)"
195
+ new_access: "tokenizer(text).encodings"
196
+ repair_primitive: "RestoreReturnAccess"
197
+ category: "api_drift"
198
+
199
+ - version_range: "transformers 4.30 -> 4.40"
200
+ affected_api: "save_pretrained"
201
+ description: "save_pretrained -> save_pretrained_directory rename in some classes"
202
+ breakage_primitive: "RenameApiCall"
203
+ params:
204
+ old_name: "save_pretrained"
205
+ new_name: "save_pretrained_directory"
206
+ repair_primitive: "RestoreApiCall"
207
+ category: "api_drift"
208
+
209
+ - version_range: "transformers 4.45 -> 4.50"
210
+ affected_api: "TrainingArguments.no_cuda"
211
+ description: "no_cuda renamed to use_cpu (logic inverted)"
212
+ breakage_primitive: "RenameApiCall"
213
+ params:
214
+ old_name: "no_cuda"
215
+ new_name: "use_cpu"
216
+ repair_primitive: "RestoreApiCall"
217
+ category: "config_drift"
forgeenv/primitives/repair_primitives.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Repair primitives — direct inverses of the 8 breakage primitives.
2
+
3
+ Used during warm-start data generation: for every (script, breakage)
4
+ pair we know the canonical repair, so we can write SFT pairs.
5
+
6
+ These are also useful for unit-testing the breakage primitives:
7
+ apply(breakage) then apply(repair) should be (close to) the identity.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import re
12
+ from abc import ABC, abstractmethod
13
+ from dataclasses import dataclass, field
14
+
15
+
16
@dataclass
class RepairPrimitive(ABC):
    """Abstract base for repairs; mirrors `BreakagePrimitive`'s interface."""

    # init=False: derived per-subclass in __post_init__, never caller-supplied.
    category: str = field(default="generic", init=False)
    name: str = field(default="RepairPrimitive", init=False)
    description: str = field(default="", init=False)

    @abstractmethod
    def apply(self, script: str) -> str:
        """Transform `script` to undo the corresponding breakage."""

    def to_spec(self) -> dict:
        """Serialize to a JSON-compatible spec (same shape as breakage specs)."""
        return {
            "primitive_type": self.__class__.__name__,
            "category": self.category,
            "params": self._get_params(),
        }

    @abstractmethod
    def _get_params(self) -> dict:
        """Return JSON-serializable constructor parameters."""
37
+
38
@dataclass
class RestoreApiCall(RepairPrimitive):
    new_name: str = ""
    old_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreApiCall"
        self.description = f"Rename {self.new_name} -> {self.old_name}"

    def apply(self, script: str) -> str:
        """Rename the drifted identifier back (whole-identifier matches only)."""
        if not self.new_name:
            return script
        return re.sub(
            r"(?<!\w)" + re.escape(self.new_name) + r"(?!\w)",
            self.old_name,
            script,
        )

    def _get_params(self) -> dict:
        return {"new_name": self.new_name, "old_name": self.old_name}
56
+
57
+
58
@dataclass
class RestoreImport(RepairPrimitive):
    new_module: str = ""
    old_module: str = ""

    def __post_init__(self) -> None:
        self.category = "import_drift"
        self.name = "RestoreImport"
        self.description = f"Restore import {self.new_module} -> {self.old_module}"

    def apply(self, script: str) -> str:
        """Rewrite the drifted import path back to its original form.

        Guards against an empty `new_module`: `str.replace("", x)` inserts
        `x` between every character of the script, so an unset primitive
        must be a no-op (matching every other repair primitive here).
        """
        if not self.new_module:
            return script
        return script.replace(self.new_module, self.old_module)

    def _get_params(self) -> dict:
        return {"new_module": self.new_module, "old_module": self.old_module}
73
+
74
+
75
@dataclass
class RestoreArgument(RepairPrimitive):
    """Re-add a removed argument to a function call."""

    function_name: str = ""
    arg_name: str = ""
    arg_value: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreArgument"
        self.description = (
            f"Add {self.arg_name}={self.arg_value} to {self.function_name}()"
        )

    def apply(self, script: str) -> str:
        """Insert `arg_name=arg_value` into the first `function_name(...)` call.

        Only the first call site is patched (count=1); an empty-arg call
        becomes `f(arg=value, )`, which is still valid Python.
        """
        if not self.function_name:
            return script
        # Insert the kwarg right after the function-name's opening paren.
        pattern = rf"({re.escape(self.function_name)}\s*\()(\s*)"
        replacement = rf"\g<1>{self.arg_name}={self.arg_value}, \g<2>"
        return re.sub(pattern, replacement, script, count=1)

    def _get_params(self) -> dict:
        return {
            "function_name": self.function_name,
            "arg_name": self.arg_name,
            "arg_value": self.arg_value,
        }
104
+
105
+
106
@dataclass
class RestoreConfigField(RepairPrimitive):
    field_name: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "config_drift"
        self.name = "RestoreConfigField"
        self.description = f"Restore {self.field_name}={self.old_value}"

    def apply(self, script: str) -> str:
        """Rewrite `field_name = <value>` back to `old_value`.

        Mirrors the ModifyConfigField fix: `(?<!\\w)` prevents substring
        matches inside longer identifiers, and the callable replacement
        keeps backslashes / `\\g<..>` in `old_value` literal.
        """
        if not self.field_name:
            return script
        pattern = rf"(?<!\w)({re.escape(self.field_name)}\s*=\s*)([^,)\n]+)"
        return re.sub(pattern, lambda m: m.group(1) + self.old_value, script)

    def _get_params(self) -> dict:
        return {"field_name": self.field_name, "old_value": self.old_value}
124
+
125
+
126
@dataclass
class RestoreColumn(RepairPrimitive):
    new_column: str = ""
    old_column: str = ""

    def __post_init__(self) -> None:
        self.category = "dataset_drift"
        self.name = "RestoreColumn"
        self.description = f"Rename column {self.new_column} -> {self.old_column}"

    def apply(self, script: str) -> str:
        """Rename the quoted column literal back, in both quoting styles.

        Guards against an empty `new_column`: without it the patterns become
        the empty-string literals `""`/`''` and any such literal in the
        script would be rewritten. Its inverse (RestructureDatasetSchema)
        already carries the equivalent guard.
        """
        if not self.new_column:
            return script
        return script.replace(
            f'"{self.new_column}"', f'"{self.old_column}"'
        ).replace(
            f"'{self.new_column}'", f"'{self.old_column}'"
        )

    def _get_params(self) -> dict:
        return {"new_column": self.new_column, "old_column": self.old_column}
145
+
146
+
147
@dataclass
class RestoreTokenizerKwarg(RepairPrimitive):
    new_kwarg: str = ""
    new_value: str = ""
    old_kwarg: str = ""
    old_value: str = ""

    def __post_init__(self) -> None:
        self.category = "tokenizer_drift"
        self.name = "RestoreTokenizerKwarg"
        self.description = (
            f"Restore tokenizer {self.new_kwarg}={self.new_value} -> "
            f"{self.old_kwarg}={self.old_value}"
        )

    def apply(self, script: str) -> str:
        """Rewrite the drifted `kwarg=value` pair back to the original one."""
        if not self.new_kwarg:
            return script
        drifted = re.escape(self.new_kwarg) + r"\s*=\s*" + re.escape(self.new_value)
        return re.sub(drifted, f"{self.old_kwarg}={self.old_value}", script)

    def _get_params(self) -> dict:
        return {
            "new_kwarg": self.new_kwarg,
            "new_value": self.new_value,
            "old_kwarg": self.old_kwarg,
            "old_value": self.old_value,
        }
176
+
177
+
178
@dataclass
class RestoreMethod(RepairPrimitive):
    method_name: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreMethod"
        self.description = f"Un-deprecate .{self.method_name}()"

    def apply(self, script: str) -> str:
        """Rewrite `_DEPRECATED` sentinel call sites back to the real method."""
        if not self.method_name:
            return script
        sentinel = f".{self.method_name}_DEPRECATED("
        return script.replace(sentinel, f".{self.method_name}(")

    def _get_params(self) -> dict:
        return {"method_name": self.method_name}
196
+
197
+
198
@dataclass
class RestoreReturnAccess(RepairPrimitive):
    new_access: str = ""
    old_access: str = ""

    def __post_init__(self) -> None:
        self.category = "api_drift"
        self.name = "RestoreReturnAccess"
        self.description = f"Restore return-access {self.new_access} -> {self.old_access}"

    def apply(self, script: str) -> str:
        """Swap the drifted access expression back; no-op when unset."""
        return (
            script.replace(self.new_access, self.old_access)
            if self.new_access
            else script
        )

    def _get_params(self) -> dict:
        return {"new_access": self.new_access, "old_access": self.old_access}
215
+
216
+
217
# Maps repair class names (as they appear in JSON repair specs) to classes,
# mirroring PRIMITIVE_REGISTRY on the breakage side.
REPAIR_REGISTRY: dict[str, type[RepairPrimitive]] = {
    "RestoreApiCall": RestoreApiCall,
    "RestoreImport": RestoreImport,
    "RestoreArgument": RestoreArgument,
    "RestoreConfigField": RestoreConfigField,
    "RestoreColumn": RestoreColumn,
    "RestoreTokenizerKwarg": RestoreTokenizerKwarg,
    "RestoreMethod": RestoreMethod,
    "RestoreReturnAccess": RestoreReturnAccess,
}


# Map a breakage primitive's class name to the repair-primitive class that
# inverts it. Used by the warm-start pair generator and by the demo / repair
# library curator.
BREAKAGE_TO_REPAIR: dict[str, str] = {
    "RenameApiCall": "RestoreApiCall",
    "DeprecateImport": "RestoreImport",
    "ChangeArgumentSignature": "RestoreArgument",
    "ModifyConfigField": "RestoreConfigField",
    "RestructureDatasetSchema": "RestoreColumn",
    "ChangeTokenizerBehavior": "RestoreTokenizerKwarg",
    "RemoveDeprecatedMethod": "RestoreMethod",
    "ChangeReturnType": "RestoreReturnAccess",
}
forgeenv/roles/__init__.py ADDED
File without changes
forgeenv/roles/drift_generator.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Drift Generator parser + a deterministic baseline policy.
2
+
3
+ In training the LLM produces a JSON breakage spec; we parse it. In rollouts
4
+ where we want a baseline (or a fallback when the LLM emits malformed JSON)
5
+ we use `BaselineDriftGenerator`, which samples from the per-category set of
6
+ known good primitive parameterisations.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import random
12
+ import re
13
+ from dataclasses import dataclass
14
+ from typing import Optional
15
+
16
+ from forgeenv.primitives.breakage_primitives import (
17
+ PRIMITIVE_REGISTRY,
18
+ parse_breakage_spec,
19
+ BreakagePrimitive,
20
+ )
21
+
22
+
23
+ _JSON_RE = re.compile(r"\{[\s\S]*\}")
24
+
25
+
26
+ def parse_drift_output(text: str) -> Optional[dict]:
27
+ """Extract a JSON object from possibly-noisy LLM output.
28
+
29
+ Handles markdown fences, prose preamble, trailing commas (best-effort).
30
+ Returns None on failure.
31
+ """
32
+ if not text:
33
+ return None
34
+ text = text.strip()
35
+ if text.startswith("```"):
36
+ text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
37
+ text = re.sub(r"\n?```$", "", text)
38
+ match = _JSON_RE.search(text)
39
+ if not match:
40
+ return None
41
+ blob = match.group(0)
42
+ try:
43
+ return json.loads(blob)
44
+ except json.JSONDecodeError:
45
+ cleaned = re.sub(r",\s*([}\]])", r"\1", blob)
46
+ try:
47
+ return json.loads(cleaned)
48
+ except json.JSONDecodeError:
49
+ return None
50
+
51
+
52
+ def parse_drift_to_primitive(text: str) -> Optional[BreakagePrimitive]:
53
+ """End-to-end: LLM text -> validated BreakagePrimitive (or None)."""
54
+ data = parse_drift_output(text)
55
+ if not isinstance(data, dict):
56
+ return None
57
+ try:
58
+ return parse_breakage_spec(data)
59
+ except (ValueError, TypeError):
60
+ return None
61
+
62
+
63
+ # ---------------------------------------------------------------- baselines
64
# Known-good parameterisations per primitive type, used by the baseline
# generator. Each entry mirrors real breaking changes seen between
# transformers releases (e.g. evaluation_strategy -> eval_strategy).
# NOTE: insertion order matters — the fallback in BaselineDriftGenerator
# picks the first type / first params entry.
_DEFAULT_PARAMS_BY_TYPE: dict[str, list[dict]] = {
    "RenameApiCall": [
        {"old_name": "trainer.train", "new_name": "trainer.start_training"},
        {"old_name": "save_pretrained", "new_name": "save_to_hub"},
        {"old_name": "from_pretrained", "new_name": "load_from_hub"},
    ],
    "DeprecateImport": [
        {
            "old_module": "from transformers import Trainer",
            "new_module": "from transformers.legacy import Trainer",
        },
        {
            "old_module": "from transformers import TrainingArguments",
            "new_module": "from transformers.training import TrainingArguments",
        },
    ],
    "ChangeArgumentSignature": [
        {
            "function_name": "TrainingArguments",
            "removed_arg": "num_train_epochs",
            "added_arg": "max_steps",
            "added_value": "1000",
        },
        {
            "function_name": "TrainingArguments",
            "removed_arg": "evaluation_strategy",
            "added_arg": "eval_strategy",
            "added_value": '"steps"',
        },
    ],
    "ModifyConfigField": [
        {"config_class": "TrainingArguments", "field_name": "learning_rate", "new_value": "5e-3"},
        {"config_class": "TrainingArguments", "field_name": "per_device_train_batch_size", "new_value": "1"},
    ],
    "RestructureDatasetSchema": [
        {"old_column": "text", "new_column": "input_text"},
        {"old_column": "label", "new_column": "labels"},
        {"old_column": "tokens", "new_column": "words"},
    ],
    "ChangeTokenizerBehavior": [
        {"old_kwarg": "padding", "old_value": "True", "new_kwarg": "pad_to_max_length", "new_value": "True"},
        {"old_kwarg": "truncation", "old_value": "True", "new_kwarg": "truncate", "new_value": "True"},
    ],
    "RemoveDeprecatedMethod": [
        {"class_name": "Trainer", "method_name": "evaluate", "replacement": "evaluation_loop"},
        {"class_name": "Trainer", "method_name": "save_model", "replacement": "save_to_hub"},
    ],
    "ChangeReturnType": [
        {"function_name": "Trainer.predict", "old_access": ".predictions", "new_access": "[0]"},
        {"function_name": "tokenizer", "old_access": '["input_ids"]', "new_access": ".input_ids"},
    ],
}
116
+
117
+
118
@dataclass
class BaselineDriftGenerator:
    """Deterministic stand-in for the LLM Drift Generator.

    Used for warm-start data, baseline rollouts, and unit tests.
    """

    seed: Optional[int] = None

    def __post_init__(self) -> None:
        # Seeded instances get a private RNG; unseeded share the module RNG.
        self._rng = random if self.seed is None else random.Random(self.seed)

    def propose(
        self, target_category: str = "", script: str = ""
    ) -> dict:
        """Produce a JSON-serializable breakage spec for `target_category`.

        Order of preference:
        1. A primitive of `target_category` whose default params apply to `script`.
        2. A primitive of any type whose default params apply to `script`.
        3. A primitive of `target_category` (no-op fallback).
        """
        preferred_types: list[str] = []
        if target_category in _DEFAULT_PARAMS_BY_TYPE:
            preferred_types.append(target_category)
        all_types = list(_DEFAULT_PARAMS_BY_TYPE)

        for type_set in (preferred_types, all_types):
            if not type_set:
                continue
            for ptype in self._rng.sample(type_set, len(type_set)):
                candidates = _DEFAULT_PARAMS_BY_TYPE[ptype]
                for params in self._rng.sample(candidates, len(candidates)):
                    if self._params_apply_to_script(ptype, params, script):
                        return {"primitive_type": ptype, "params": dict(params)}

        # Nothing applied: fall back to the first known parameterisation.
        fallback = preferred_types[0] if preferred_types else all_types[0]
        return {
            "primitive_type": fallback,
            "params": dict(_DEFAULT_PARAMS_BY_TYPE[fallback][0]),
        }

    @staticmethod
    def _params_apply_to_script(ptype: str, params: dict, script: str) -> bool:
        """Heuristic: would this primitive actually mutate `script`?"""
        if not script:
            return True
        probe_keys = (
            "old_name",
            "old_module",
            "removed_arg",
            "field_name",
            "old_column",
            "old_kwarg",
            "method_name",
            "old_access",
        )
        return any(
            key in params and params[key] and params[key] in script
            for key in probe_keys
        )
forgeenv/roles/prompts.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """System and user prompts for the two RL roles.
2
+
3
+ Both roles are trained from the same base policy (Qwen-2.5-Coder-7B) with
4
+ LoRA adapters per role, so role prompts are the only thing distinguishing
5
+ them at inference time. Keep them concise — every token is a token of GPU
6
+ budget during GRPO rollouts.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Iterable
11
+
12
+
13
# Human-readable one-liners per primitive type; the parenthesised suffix is
# the drift-taxonomy category the primitive belongs to.
PRIMITIVE_DESCRIPTIONS = {
    "RenameApiCall": "Rename a function/method call (api_drift)",
    "DeprecateImport": "Change an import path (import_drift)",
    "ChangeArgumentSignature": "Remove an expected kwarg from a call (api_drift)",
    "ModifyConfigField": "Change a config-class default (config_drift)",
    "RestructureDatasetSchema": "Rename a dataset column reference (dataset_drift)",
    "ChangeTokenizerBehavior": "Change tokenizer call kwargs (tokenizer_drift)",
    "RemoveDeprecatedMethod": "Remove a method, leaving a sentinel _DEPRECATED suffix (api_drift)",
    "ChangeReturnType": "Function returns a different structure (api_drift)",
}

# System prompt for the Drift Generator role; its output is parsed by
# `forgeenv.roles.drift_generator.parse_drift_output`.
DRIFT_GENERATOR_SYSTEM_PROMPT = """You are the Drift Generator.
You see a working HuggingFace training script and the curriculum target category.
Output exactly one JSON object describing a breakage primitive that simulates
realistic library version drift. The primitive must:
1. Be PLAUSIBLE — match the kind of breakage that happens between real
transformers/datasets/trl releases.
2. Be SOLVABLE — the Repair Agent should be able to fix it from the error trace alone.
3. Match the requested target_category.

Output schema:
{"primitive_type": "<one of the 8 types>", "params": { ... }}

Available primitive types and parameter schemas:
- RenameApiCall: {"old_name": str, "new_name": str}
- DeprecateImport: {"old_module": str, "new_module": str}
- ChangeArgumentSignature: {"function_name": str, "removed_arg": str, "added_arg": str, "added_value": str}
- ModifyConfigField: {"config_class": str, "field_name": str, "new_value": str}
- RestructureDatasetSchema: {"old_column": str, "new_column": str}
- ChangeTokenizerBehavior: {"old_kwarg": str, "old_value": str, "new_kwarg": str, "new_value": str}
- RemoveDeprecatedMethod: {"class_name": str, "method_name": str, "replacement": str}
- ChangeReturnType: {"function_name": str, "old_access": str, "new_access": str}

Output ONLY the JSON object — no commentary, no markdown fences.
"""


# System prompt for the Repair Agent role; its output is sanitised by
# `forgeenv.roles.repair_agent.extract_diff`.
REPAIR_AGENT_SYSTEM_PROMPT = """You are the Repair Agent.
You see a broken HuggingFace training script, an error trace, and the current
library version snapshot. Output ONLY a unified diff that fixes the script.

Rules:
1. Use canonical unified-diff format with `--- a/train.py` / `+++ b/train.py`
headers and `@@ ... @@` hunk markers.
2. Make the MINIMAL change that resolves the error AND preserves the original
training intent. Do NOT add bare-except blocks, monkey-patches, or sys.exit
calls.
3. Do NOT add any prose, markdown fences, or thinking output — diff only.
4. If the error is unfixable, output an empty diff.
"""
63
+
64
+
65
def render_drift_generator_prompt(
    script: str, target_category: str, library_versions: dict
) -> str:
    """Render the per-episode user prompt for the Drift Generator role."""
    versions = ", ".join(f"{name}={ver}" for name, ver in library_versions.items())
    sections = [
        f"Target category: {target_category}",
        f"Library versions: {versions}",
        "",
        "Working script:",
        "```python",
        script,
        "```",
        "",
        "Output JSON breakage primitive:",
    ]
    return "\n".join(sections)
78
+
79
+
80
def render_repair_agent_prompt(
    broken_script: str,
    error_trace: str,
    library_versions: dict,
    target_category: str = "",
) -> str:
    """Render the per-episode user prompt for the Repair Agent role."""
    versions = ", ".join(f"{name}={ver}" for name, ver in library_versions.items())
    hint = target_category if target_category else "unknown"
    sections = [
        f"Library versions: {versions}",
        f"Target category hint: {hint}",
        "",
        "Broken script:",
        "```python",
        broken_script,
        "```",
        "",
        "Error trace:",
        error_trace,
        "",
        "Output unified diff (no prose, no fences):",
    ]
    return "\n".join(sections)
99
+
100
+
101
def list_primitive_descriptions() -> Iterable[str]:
    """Lazily yield one "- Name: description" line per registered primitive."""
    return (f"- {name}: {desc}" for name, desc in PRIMITIVE_DESCRIPTIONS.items())
forgeenv/roles/repair_agent.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Repair Agent helpers: response sanitisation + a deterministic baseline.
2
+
3
+ The Repair Agent's training output is a unified diff. LLMs frequently emit
4
+ prose / fences / chain-of-thought before the diff; this module strips that
5
+ preamble. The baseline policy uses the inverse-primitive map from
6
+ `repair_primitives.py` to produce ground-truth diffs for warm-start.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import re
11
+ from dataclasses import dataclass
12
+ from typing import Optional
13
+
14
+ from forgeenv.env.diff_utils import make_unified_diff
15
+ from forgeenv.primitives.breakage_primitives import (
16
+ parse_breakage_spec,
17
+ BreakagePrimitive,
18
+ )
19
+ from forgeenv.primitives.repair_primitives import (
20
+ BREAKAGE_TO_REPAIR,
21
+ REPAIR_REGISTRY,
22
+ RepairPrimitive,
23
+ )
24
+
25
+
26
# Matches a unified-diff hunk header (e.g. "@@ -1,4 +1,4 @@") at line start.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Captures the body of the first ```-fenced block (optional language tag).
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")
28
+
29
+
30
def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles: code fences, leading prose / chain-of-thought, trailing notes.
    Note: if no diff marker is present, the (possibly unfenced) text is
    returned as-is; callers guard with `looks_like_diff`.
    """
    if not raw_text:
        return ""
    body = raw_text.strip()

    # Prefer the contents of the first fenced block, if any.
    fenced = re.search(r"```[a-zA-Z]*\n([\s\S]*?)\n```", body)
    if fenced is not None:
        body = fenced.group(1).strip()

    # Skip leading prose: the diff begins at the first header/hunk line.
    lines = body.splitlines()
    first = 0
    for idx, line in enumerate(lines):
        if line.startswith(("---", "+++", "@@")):
            first = idx
            break
    return "\n".join(lines[first:])
51
+
52
+
53
def looks_like_diff(text: str) -> bool:
    """Cheap structural check that `text` resembles a unified diff."""
    if not text:
        return False
    header_present = "---" in text and "+++" in text
    hunk_present = re.search(r"^@@.*@@", text, re.MULTILINE) is not None
    plus_minus = any(ln.startswith(("+", "-")) for ln in text.splitlines())
    # Either full headers plus a hunk, or a hunk plus +/- body lines.
    return (header_present and hunk_present) or (hunk_present and plus_minus)
60
+
61
+
62
+ # ---------------------------------------------------------------- baselines
63
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent that uses the primitive inverse map.

    Used for warm-start dataset generation and baseline rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Return a unified diff that fixes the broken script, or "".

        Strategy preference:
        1. If `original_script` is provided, diff broken -> original
           (oracle path: the warm-start ground truth is always known).
        2. Otherwise invert the structured `breakage_spec` via the
           repair-primitive registry.
        3. Otherwise return an empty diff.
        """
        if original_script and original_script != broken_script:
            return make_unified_diff(broken_script, original_script)

        if not breakage_spec:
            return ""

        try:
            breakage = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            return ""

        repair_op = _invert_breakage(breakage)
        if repair_op is None:
            return ""

        repaired = repair_op.apply(broken_script)
        if repaired == broken_script:
            # The inverse primitive found nothing to change.
            return ""
        return make_unified_diff(broken_script, repaired)
103
+
104
+
105
# Maps each breakage primitive's parameter names (keys of the inner dict) to
# the corresponding repair primitive's constructor field names (values).
# Parameters absent from a mapping are simply not forwarded to the repair.
_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"},
    "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"},
    "ChangeArgumentSignature": {
        "function_name": "function_name",
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": {"field_name": "field_name"},
    "RestructureDatasetSchema": {
        "old_column": "old_column",
        "new_column": "new_column",
    },
    "ChangeTokenizerBehavior": {
        "old_kwarg": "old_kwarg",
        "old_value": "old_value",
        "new_kwarg": "new_kwarg",
        "new_value": "new_value",
    },
    "RemoveDeprecatedMethod": {"method_name": "method_name"},
    "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"},
}
126
+
127
+
128
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes `breakage`, or None.

    Looks up the inverse class via BREAKAGE_TO_REPAIR / REPAIR_REGISTRY,
    remaps the breakage's parameters onto the repair constructor's field
    names via _PARAM_REMAP, and drops anything the constructor rejects.
    """
    breakage_name = type(breakage).__name__
    repair_name = BREAKAGE_TO_REPAIR.get(breakage_name)
    if repair_name is None:
        return None
    repair_cls = REPAIR_REGISTRY.get(repair_name)
    if repair_cls is None:
        return None

    # NOTE(review): relies on the breakage primitive's private _get_params().
    breakage_params = breakage._get_params()  # type: ignore[attr-defined]
    remap = _PARAM_REMAP.get(breakage_name, {})
    mapped: dict[str, str] = {}
    for src_key, dst_key in remap.items():
        if src_key in breakage_params:
            mapped[dst_key] = breakage_params[src_key]

    # Keep only kwargs the repair dataclass actually declares as init fields.
    valid_fields = {
        f.name
        for f in repair_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if f.init
    }
    filtered = {k: v for k, v in mapped.items() if k in valid_fields}
    try:
        return repair_cls(**filtered)
    except TypeError:
        # A required field was missing: breakage is not mechanically invertible.
        return None
forgeenv/roles/teacher.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Teacher (curriculum controller).
2
+
3
+ Deterministic — NOT an LLM. Maintains an EMA success rate per breakage
4
+ category and routes the next episode toward the category where the
5
+ Repair Agent is closest to a 50% success rate (R-Zero's difficulty band).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import random
10
+ from dataclasses import dataclass, field
11
+
12
+
13
@dataclass
class Teacher:
    """Deterministic curriculum controller (not an LLM).

    Tracks a per-category EMA of the Repair Agent's success rate and routes
    the next episode toward categories near 50% success.
    """

    categories: list[str]
    alpha: float = 0.9  # EMA smoothing: weight kept on the previous estimate
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Seed every known category with a neutral 50% prior.
        for name in self.categories:
            self.success_counts.setdefault(name, 0)
            self.attempt_counts.setdefault(name, 0)
            self.ema_success.setdefault(name, 0.5)

    def update(self, category: str, success: bool) -> None:
        """Record one episode outcome and refresh the category's EMA."""
        if category not in self.ema_success:
            # Lazily register categories we were never told about up front.
            self.categories.append(category)
            self.ema_success[category] = 0.5
            self.success_counts[category] = 0
            self.attempt_counts[category] = 0

        self.attempt_counts[category] += 1
        if success:
            self.success_counts[category] += 1
        empirical = self.success_counts[category] / max(1, self.attempt_counts[category])
        previous = self.ema_success[category]
        self.ema_success[category] = self.alpha * previous + (1 - self.alpha) * empirical

    def select_next_category(self) -> str:
        """Pick the next category, favouring those closest to 50% success."""
        in_zone = {
            name: abs(rate - 0.5)
            for name, rate in self.ema_success.items()
            if 0.3 <= rate <= 0.7
        }
        if in_zone:
            # Weight inversely by distance from 0.5 (closer => more likely).
            names = list(in_zone)
            weights = [1.0 / (dist + 0.01) for dist in in_zone.values()]
            return random.choices(names, weights=weights, k=1)[0]
        # Nothing in the band: fall back to the single closest category.
        return min(self.ema_success, key=lambda name: abs(self.ema_success[name] - 0.5))

    def get_state(self) -> dict:
        """Snapshot of per-category stats for logging/checkpointing."""
        return {
            name: {
                "ema_success": round(self.ema_success[name], 4),
                "attempts": self.attempt_counts[name],
                "successes": self.success_counts[name],
            }
            for name in self.categories
        }
forgeenv/sandbox/__init__.py ADDED
File without changes
forgeenv/sandbox/ast_validator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AST-based script validator.
2
+
3
+ Catches forbidden imports and dangerous patterns BEFORE any execution
4
+ happens. This is a critical defense against reward hacking via system
5
+ calls, network access, or process manipulation.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import ast
10
+
11
+ from forgeenv.tasks.models import ValidationResult
12
+
13
# Top-level modules whose import is rejected outright: process/shell access,
# networking, FFI, filesystem manipulation, signals, and concurrency.
FORBIDDEN_MODULES = {
    "os",
    "subprocess",
    "socket",
    "urllib",
    "requests",
    "ctypes",
    "shutil",
    "signal",
    "multiprocessing",
    "threading",
}

# Builtins that enable dynamic code execution or import smuggling.
FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}
27
+
28
+
29
def validate_script(script_content: str) -> ValidationResult:
    """Parse a script as AST and reject forbidden patterns.

    Returns a ValidationResult with `is_valid` and a list of `violations`.
    """
    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])

    violations: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # `import os`, `import os.path`, `import os as o` all count.
            for alias in node.names:
                if alias.name.split(".")[0] in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import: {alias.name}")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for relative imports like `from . import x`.
            if node.module and node.module.split(".")[0] in FORBIDDEN_MODULES:
                violations.append(f"Forbidden import from: {node.module}")
        elif isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id in FORBIDDEN_FUNCTIONS:
                violations.append(f"Forbidden call: {func.id}()")
            if isinstance(func, ast.Attribute) and func.attr in FORBIDDEN_FUNCTIONS:
                violations.append(f"Forbidden call: .{func.attr}()")
        elif isinstance(node, ast.Assign):
            # Rebinding __builtins__ would let a script evade the checks above.
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__builtins__":
                    violations.append("Forbidden: __builtins__ assignment")

    return ValidationResult(is_valid=not violations, violations=violations)
forgeenv/sandbox/simulation_mode.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fast simulation executor for development.
2
+
3
+ Static-analysis-based execution simulator. Sub-100ms per call. No Docker
4
+ required. The success probability of a simulated run depends on whether
5
+ the script contains expected HF training markers (model imports, training
6
+ calls, save calls). When the simulation succeeds, a synthetic decreasing
7
+ loss curve is emitted; when it fails, a representative HF error is raised.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ import time
13
+ from typing import Optional
14
+
15
+ from forgeenv.sandbox.ast_validator import validate_script
16
+ from forgeenv.tasks.models import ExecutionResult, Task
17
+
18
+
19
class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        # Seeded instances get a private RNG for reproducible rollouts;
        # unseeded instances share the module-level RNG.
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        """Statically "execute" `script_content`, returning a synthetic result.

        Pipeline: AST validation -> syntax compile -> heuristic success
        probability -> synthetic loss curve (success) or canned HF error
        (failure). `task` is accepted but not used by the simulator.
        """
        start = time.time()

        # Hard gate: forbidden imports/calls fail before anything else.
        validation = validate_script(script_content)
        if not validation.is_valid:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"Validation failed: {'; '.join(validation.violations)}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        # Second gate: the script must at least be syntactically valid Python.
        try:
            compile(script_content, "<forge_script>", "exec")
        except SyntaxError as e:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"SyntaxError: {e}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )

        # Success probability grows with the presence of expected HF
        # training markers (imports, a training call, a save call).
        has_model_import = any(
            kw in script_content
            for kw in ("from transformers", "import torch", "from datasets")
        )
        has_training_call = any(
            kw in script_content
            for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")
        )
        has_save = any(
            kw in script_content
            for kw in ("save_pretrained", "save_model", "torch.save")
        )

        success_prob = 0.3
        if has_model_import:
            success_prob += 0.3
        if has_training_call:
            success_prob += 0.2
        if has_save:
            success_prob += 0.1

        # Mark obviously broken patterns as definite failures even when
        # they pass syntactic compilation. The simulator pretends to be a
        # static linter that catches AttributeError / ImportError signatures
        # before they would fire at runtime.
        broken_markers = (
            "_DEPRECATED(",
            "transformers.legacy",
            "from transformers.training import",
            ".start_training(",
            "load_from_hub(",
            "save_to_hub(",
            "pad_to_max_length=",
            "evaluation_loop(",
        )
        if any(marker in script_content for marker in broken_markers):
            success_prob = 0.0
        # Patterns that look like dataset column drift: a renamed column
        # that doesn't appear in real HF datasets.
        import re as _re

        if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        if _re.search(r"['\"]words['\"]\s*[]:),]", script_content):
            success_prob = min(success_prob, 0.05)
        # Tokenizer kwarg drift (truncate is not valid; truncation is).
        if _re.search(r"\btruncate\s*=", script_content):
            success_prob = min(success_prob, 0.05)

        succeeded = self._rng.random() < success_prob

        if succeeded:
            # Emit a plausible monotonically-decreasing synthetic loss curve.
            steps = self._rng.randint(20, 50)
            log_lines: list[str] = []
            loss = self._rng.uniform(2.0, 4.0)
            for step in range(1, steps + 1):
                loss *= self._rng.uniform(0.92, 0.99)
                log_lines.append(f"step={step} loss={loss:.4f}")
            log_lines.append("eval_accuracy=0.78")
            log_lines.append("TRAINING_COMPLETE")

            return ExecutionResult(
                exit_code=0,
                stdout="\n".join(log_lines),
                stderr="",
                # Pad wall time so it resembles a real (short) training run.
                wall_time_ms=int((time.time() - start) * 1000)
                + self._rng.randint(1000, 5000),
                checkpoint_exists=True,
                peak_memory_mb=self._rng.uniform(500, 2000),
                script_content=script_content,
            )

        # Failure path: surface a representative HF-style error message.
        error_types = [
            "ImportError: cannot import name 'OldTrainer' from 'transformers'",
            "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
            "KeyError: 'text' column not found in dataset",
            "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
            "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
            "ModuleNotFoundError: No module named 'transformers.legacy'",
        ]
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=self._rng.choice(error_types),
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(100, 500),
            script_content=script_content,
        )
+ )
forgeenv/tasks/__init__.py ADDED
File without changes
forgeenv/tasks/models.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core data models for ForgeEnv tasks and execution results.
2
+
3
+ These are framework-internal dataclasses (not Pydantic) used throughout the
4
+ simulation, verifier, and primitive layers. The OpenEnv-facing Pydantic
5
+ models live in `forgeenv.env.actions` / `forgeenv.env.observations`.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional
11
+
12
+
13
@dataclass
class Task:
    """A HuggingFace training script with execution metadata."""

    task_id: str  # stable identifier for the task
    description: str  # short human-readable summary
    script_content: str  # full Python source of the training script
    difficulty: str  # "easy", "medium", "hard"
    category: str = "general"
    # (low, high) bounds for plausible training metrics.
    expected_loss_range: tuple[float, float] = (0.0, 5.0)
    expected_accuracy_range: tuple[float, float] = (0.0, 1.0)
    checkpoint_output_path: str = "/tmp/forge_output/checkpoint"
25
+
26
+
27
@dataclass
class ExecutionResult:
    """Result of executing a Python script in the sandbox."""

    exit_code: int  # 0 on success, non-zero on failure
    stdout: str
    stderr: str
    wall_time_ms: int
    checkpoint_exists: bool = False  # whether a model checkpoint was produced
    peak_memory_mb: float = 0.0
    script_content: str = ""  # echo of the script that was executed
38
+
39
+
40
@dataclass
class ValidationResult:
    """Result of AST validation on a script."""

    is_valid: bool  # True iff `violations` is empty
    violations: list[str] = field(default_factory=list)
forgeenv/tasks/seed_corpus/__init__.py ADDED
File without changes
forgeenv/tasks/seed_corpus/albert_qa.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""ALBERT-tiny extractive QA on 100-sample SQuAD subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    Trainer,
    TrainingArguments,
    DefaultDataCollator,
)
from datasets import load_dataset

dataset = load_dataset("squad", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")


def prepare(examples):
    # Tokenize question/context pairs; keep offsets so we can map the
    # answer's character span onto token indices.
    enc = tokenizer(
        examples["question"],
        examples["context"],
        max_length=128,
        truncation="only_second",
        padding="max_length",
        return_offsets_mapping=True,
    )
    start_positions, end_positions = [], []
    for i, offsets in enumerate(enc["offset_mapping"]):
        answer = examples["answers"][i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # First token whose character span contains the answer start/end;
        # defaults (0 / token_start) cover truncated-away answers.
        token_start = next(
            (idx for idx, (a, b) in enumerate(offsets) if a <= start_char < b), 0
        )
        token_end = next(
            (idx for idx, (a, b) in enumerate(offsets) if a < end_char <= b), token_start
        )
        start_positions.append(token_start)
        end_positions.append(token_end)

    enc["start_positions"] = start_positions
    enc["end_positions"] = end_positions
    enc.pop("offset_mapping")  # not a model input
    return enc


dataset = dataset.map(prepare, batched=True, remove_columns=dataset.column_names)

model = AutoModelForQuestionAnswering.from_pretrained("albert-base-v2")

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DefaultDataCollator(),
)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/bert_ner.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Bert tiny NER fine-tuning on a 200-sample CoNLL-2003 subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
)
from datasets import load_dataset

dataset = load_dataset("conll2003", split="train[:200]")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")


def tokenize_and_align(example):
    # Align word-level NER tags to sub-word tokens: only the first sub-word
    # of each word keeps its tag; the rest get -100 (ignored by the loss).
    enc = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, max_length=64)
    word_ids = enc.word_ids()
    labels = []
    prev_id = None
    for wid in word_ids:
        if wid is None:
            # Special tokens ([CLS]/[SEP]) carry no label.
            labels.append(-100)
        elif wid != prev_id:
            labels.append(example["ner_tags"][wid])
        else:
            labels.append(-100)
        prev_id = wid
    enc["labels"] = labels
    return enc


dataset = dataset.map(tokenize_and_align, remove_columns=dataset.column_names)

model = AutoModelForTokenClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=9)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForTokenClassification(tokenizer),
)

trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/distilbert_sst2.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""DistilBERT fine-tuning on a tiny SST-2 subset.

Minimal HuggingFace text-classification training script. Should complete
in ~60s on CPU.
"""
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("glue", "sst2", split="train[:500]")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_function(examples):
    # Fixed-length encodings keep batches uniform without a data collator.
    return tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


dataset = dataset.map(tokenize_function, batched=True)
dataset = dataset.rename_column("label", "labels")  # Trainer expects "labels"
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/electra_classification.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""ELECTRA-small classification on 400-sample AG News (4-way text classification)."""
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("ag_news", split="train[:400]")
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")


def tokenize(examples):
    # Fixed-length encodings so no data collator is needed.
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )


dataset = dataset.map(tokenize, batched=True)
dataset = dataset.rename_column("label", "labels")  # Trainer expects "labels"
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = AutoModelForSequenceClassification.from_pretrained(
    "google/electra-small-discriminator", num_labels=4
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/gpt2_textgen.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""DistilGPT2 causal-LM fine-tuning on 300 lines of WikiText (text generation)."""
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset

OUTPUT_DIR = "/tmp/forge_output/checkpoint"

# Small WikiText slice so the run stays fast on CPU.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:300]")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# GPT-2 ships without a pad token; reuse EOS so batching works.
tokenizer.pad_token = tokenizer.eos_token


def tokenize(examples):
    """Truncate each line to a 64-token window (padding is left to the collator)."""
    return tokenizer(examples["text"], truncation=True, max_length=64)


dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,  # CPU-only sandbox
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # mlm=False selects the standard next-token (causal) objective.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model(OUTPUT_DIR)
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/logistic_classifier.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Sklearn logistic-regression baseline on a 500-sample tabular task.

Sanity baseline that doesn't require torch / transformers / datasets.
"""
import json
import pickle
from pathlib import Path

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(
    n_samples=500, n_features=20, n_informative=10, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = LogisticRegression(max_iter=200)
for step in range(1, 11):
    # Refit with a growing iteration budget so stdout shows a loss curve.
    model.set_params(max_iter=step * 20)
    model.fit(X_train, y_train)
    true_class_probs = model.predict_proba(X_train)[np.arange(len(y_train)), y_train]
    train_loss = -np.mean(np.log(np.maximum(true_class_probs, 1e-9)))
    print(f"step={step} loss={train_loss:.4f}")

acc = model.score(X_test, y_test)
print(f"eval_accuracy={acc:.4f}")

# Persist the fitted model plus a small metrics file as the "checkpoint".
ckpt_dir = Path("/tmp/forge_output/checkpoint")
ckpt_dir.mkdir(parents=True, exist_ok=True)
(ckpt_dir / "logreg.pkl").write_bytes(pickle.dumps(model))
(ckpt_dir / "metrics.json").write_text(json.dumps({"accuracy": acc}))

print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/roberta_sentiment.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""DistilRoberta sentiment classification on 400-sample IMDB subset."""
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

MODEL_NAME = "distilroberta-base"
OUTPUT_DIR = "/tmp/forge_output/checkpoint"

dataset = load_dataset("imdb", split="train[:400]")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize(examples):
    """Fixed-length 64-token encoding of each review."""
    return tokenizer(
        examples["text"], padding="max_length", truncation=True, max_length=64
    )


dataset = dataset.map(tokenize, batched=True)
dataset = dataset.rename_column("label", "labels")
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,  # CPU-only sandbox
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model(OUTPUT_DIR)
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/simple_regression.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tiny PyTorch regression on synthetic data (no HF imports — sanity baseline)."""
import os
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(500, 4)
y = (x @ torch.tensor([1.5, -2.0, 0.5, 3.0])) + 0.1 * torch.randn(500)

model = nn.Sequential(
    nn.Linear(4, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

for epoch in range(50):
    optimizer.zero_grad()
    preds = model(x).squeeze(-1)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 5 == 0:
        print(f"step={epoch + 1} loss={loss.item():.4f}")
# Fix: create the checkpoint dir first — torch.save does not mkdir and would
# raise FileNotFoundError (sibling logistic_classifier.py already does this).
os.makedirs("/tmp/forge_output/checkpoint", exist_ok=True)
torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/regression.pt")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/t5_summarization.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tiny T5 fine-tuning for summarization on 100-sample CNN/DailyMail."""
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")


def preprocess(examples):
    """Tokenize articles (with the T5 'summarize: ' prefix) and highlight targets."""
    model_inputs = tokenizer(
        ["summarize: " + a for a in examples["article"]],
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    label_encoding = tokenizer(
        examples["highlights"],
        max_length=32,
        truncation=True,
        padding="max_length",
    )
    model_inputs["labels"] = label_encoding["input_ids"]
    return model_inputs


dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,  # CPU-only sandbox
    report_to="none",
    predict_with_generate=False,  # teacher-forced loss only; no generation pass
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/tiny_mlp_mnist.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tiny PyTorch MLP on a 1000-sample MNIST subset (image classification baseline)."""
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset = load_dataset("mnist", split="train[:1000]")
dataset = dataset.with_format("torch")


def collate(batch):
    """Flatten 28x28 images to float vectors scaled into [0, 1]; stack labels."""
    pixel = torch.stack([b["image"].float().flatten() / 255.0 for b in batch])
    labels = torch.tensor([b["label"] for b in batch])
    return pixel, labels


loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

model = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):
    for step, (x, y) in enumerate(loader, start=1):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        if step % 5 == 0:
            print(f"step={step} loss={loss.item():.4f}")

# Fix: create the checkpoint dir first — torch.save does not mkdir and would
# raise FileNotFoundError (sibling logistic_classifier.py already does this).
os.makedirs("/tmp/forge_output/checkpoint", exist_ok=True)
torch.save(model.state_dict(), "/tmp/forge_output/checkpoint/mlp.pt")
print("TRAINING_COMPLETE")
forgeenv/tasks/seed_corpus/vit_cifar10.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tiny ViT image classification on 200-sample CIFAR-10 subset."""
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
import torch

dataset = load_dataset("cifar10", split="train[:200]")
processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224")


def transform(batch):
    """Convert PIL images to model pixel values; attach integer labels."""
    rgb_images = [img.convert("RGB") for img in batch["img"]]
    encoded = processor(images=rgb_images, return_tensors="pt")
    encoded["labels"] = torch.tensor(batch["label"])
    return encoded


dataset = dataset.with_transform(transform)

# ignore_mismatched_sizes: the pretrained head has 1000 classes, ours has 10.
model = AutoModelForImageClassification.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224", num_labels=10, ignore_mismatched_sizes=True
)

training_args = TrainingArguments(
    output_dir="/tmp/forge_output/checkpoint",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=5,
    save_strategy="epoch",
    no_cuda=True,  # CPU-only sandbox
    report_to="none",
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
trainer.save_model("/tmp/forge_output/checkpoint")
print("TRAINING_COMPLETE")
forgeenv/tasks/task_sampler.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Task sampler: loads the seed corpus and samples Tasks by difficulty.
2
+
3
+ Difficulty is auto-derived from script line count. Category is auto-detected
4
+ from script content (text_classification, ner, translation, etc.).
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import random
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ from forgeenv.tasks.models import Task
13
+
14
+
15
+ def _detect_category(content: str) -> str:
16
+ cl = content.lower()
17
+ if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl:
18
+ return "text_classification"
19
+ if "tokenclassification" in cl or "ner" in cl or "conll" in cl:
20
+ return "ner"
21
+ if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl:
22
+ return "seq2seq"
23
+ if "causallm" in cl or "gpt2" in cl or "wikitext" in cl:
24
+ return "text_generation"
25
+ if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl:
26
+ return "image_classification"
27
+ if "questionanswering" in cl or "squad" in cl:
28
+ return "qa"
29
+ if "logisticregression" in cl or "make_classification" in cl:
30
+ return "tabular"
31
+ if "regression" in cl:
32
+ return "regression"
33
+ return "general"
34
+
35
+
36
+ def _derive_difficulty(content: str) -> str:
37
+ lines = len(content.splitlines())
38
+ if lines < 30:
39
+ return "easy"
40
+ if lines < 60:
41
+ return "medium"
42
+ return "hard"
43
+
44
+
45
class TaskSampler:
    """Loads seed corpus and samples tasks by difficulty / category."""

    def __init__(self, seed_dir: Optional[str] = None) -> None:
        # Default to the seed_corpus directory bundled next to this module.
        if seed_dir is None:
            seed_dir = str(Path(__file__).parent / "seed_corpus")
        self.tasks: list[Task] = []
        self._load_corpus(seed_dir)

    def _load_corpus(self, seed_dir: str) -> None:
        """Read every non-dunder *.py under `seed_dir` into a Task record."""
        corpus_path = Path(seed_dir)
        if not corpus_path.exists():
            return
        for py_file in sorted(corpus_path.glob("*.py")):
            if py_file.name.startswith("__"):
                continue
            content = py_file.read_text(encoding="utf-8")
            self.tasks.append(self._build_task(py_file.stem, content))

    @staticmethod
    def _build_task(task_id: str, content: str) -> Task:
        """Build one Task; description comes from a leading module docstring."""
        description = ""
        if content.startswith('"""'):
            end = content.find('"""', 3)
            if end != -1:
                description = content[3:end].strip()
        return Task(
            task_id=task_id,
            description=description or f"Training script: {task_id}",
            script_content=content,
            difficulty=_derive_difficulty(content),
            category=_detect_category(content),
        )

    def sample(self, difficulty: Optional[str] = None) -> Optional[Task]:
        """Sample one task; fall back to the full pool if no difficulty match."""
        pool = self.tasks
        if difficulty is not None:
            matching = [t for t in self.tasks if t.difficulty == difficulty]
            if matching:
                pool = matching
        if not pool:
            return None
        return random.choice(pool)

    def sample_batch(self, n: int, difficulty: Optional[str] = None) -> list[Task]:
        """Sample up to `n` tasks (fewer if the corpus is empty)."""
        batch: list[Task] = []
        for _ in range(n):
            task = self.sample(difficulty)
            if task is not None:
                batch.append(task)
        return batch

    def get_all_categories(self) -> list[str]:
        """Sorted, de-duplicated list of categories present in the corpus."""
        return sorted({task.category for task in self.tasks})

    def get_by_id(self, task_id: str) -> Optional[Task]:
        """Linear lookup by task_id; None when absent."""
        return next((t for t in self.tasks if t.task_id == task_id), None)
forgeenv/verifier/__init__.py ADDED
File without changes
forgeenv/verifier/held_out_evaluator.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Held-out evaluator: the deterministic ground-truth scorer.
2
+
3
+ Returns 7 independent components in [0, 1]. The Repair Agent NEVER sees
4
+ this directly; the Drift Generator's training signal derives from
5
+ alignment between the visible verifier and this evaluator (Pearson
6
+ correlation across the K rollouts).
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import ast
11
+ import re
12
+
13
+ from forgeenv.tasks.models import ExecutionResult, Task
14
+
15
+
16
def compute_held_out_scores(
    result: ExecutionResult, task: Task, repair_diff: str = ""
) -> dict[str, float]:
    """Compute 7 independent held-out components, each in [0, 1].

    `repair_diff` is accepted but unused here — presumably reserved for
    diff-aware scoring; confirm with callers before removing.
    """
    ran_cleanly = 1.0 if result.exit_code == 0 else 0.0
    has_checkpoint = 1.0 if result.checkpoint_exists else 0.0
    finished = 1.0 if "TRAINING_COMPLETE" in result.stdout else 0.0

    return {
        "executed_cleanly": ran_cleanly,
        "checkpoint_valid": has_checkpoint,
        "loss_decreased": _compute_loss_score(result.stdout),
        "metrics_in_range": _check_metrics(result.stdout, task),
        "no_forbidden_workarounds": _check_workarounds(result.script_content),
        "intent_preserved": _compute_intent_preservation(
            task.script_content, result.script_content
        ),
        "hidden_tests_passed": finished,
    }
33
+
34
+
35
+ def _compute_loss_score(stdout: str) -> float:
36
+ """Continuous score based on relative loss decrease from first to last step."""
37
+
38
+ losses: list[float] = []
39
+ for line in stdout.splitlines():
40
+ match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
41
+ if match:
42
+ try:
43
+ losses.append(float(match.group(1)))
44
+ except ValueError:
45
+ continue
46
+
47
+ if len(losses) < 2:
48
+ return 0.0
49
+
50
+ decrease = (losses[0] - losses[-1]) / max(losses[0], 1e-8)
51
+ return max(0.0, min(1.0, decrease))
52
+
53
+
54
+ def _check_metrics(stdout: str, task: Task) -> float:
55
+ """Return 1.0 if any reported accuracy/eval metric falls in the task's
56
+ expected range; partial credit otherwise; 0.5 if no metric was found."""
57
+
58
+ for line in stdout.splitlines():
59
+ match = re.search(r"(?:accuracy|acc|eval)[=:\s]+([\d.]+)", line, re.IGNORECASE)
60
+ if match:
61
+ try:
62
+ val = float(match.group(1))
63
+ low, high = task.expected_accuracy_range
64
+ if low <= val <= high:
65
+ return 1.0
66
+ distance = min(abs(val - low), abs(val - high))
67
+ return max(0.0, 1.0 - distance)
68
+ except ValueError:
69
+ continue
70
+ return 0.5
71
+
72
+
73
+ def _check_workarounds(script_content: str) -> float:
74
+ """Detect forbidden workaround patterns via AST analysis.
75
+
76
+ Catches: bare except, `except Exception: pass`, `except Exception: return`,
77
+ monkey-patching of `__getattr__` / `__class__` / `__dict__`.
78
+ """
79
+
80
+ if not script_content:
81
+ return 0.0
82
+
83
+ try:
84
+ tree = ast.parse(script_content)
85
+ except SyntaxError:
86
+ return 0.0
87
+
88
+ violations = 0
89
+
90
+ for node in ast.walk(tree):
91
+ if isinstance(node, ast.Try):
92
+ for handler in node.handlers:
93
+ if handler.type is None:
94
+ violations += 1
95
+ elif (
96
+ isinstance(handler.type, ast.Name)
97
+ and handler.type.id == "Exception"
98
+ ):
99
+ if len(handler.body) == 1 and isinstance(
100
+ handler.body[0], (ast.Pass, ast.Return)
101
+ ):
102
+ violations += 1
103
+
104
+ if isinstance(node, ast.Assign):
105
+ for target in node.targets:
106
+ if isinstance(target, ast.Attribute):
107
+ if target.attr in ("__getattr__", "__class__", "__dict__"):
108
+ violations += 1
109
+
110
+ return 1.0 if violations == 0 else max(0.0, 1.0 - violations * 0.3)
111
+
112
+
113
+ def _compute_intent_preservation(original: str, repaired: str) -> float:
114
+ """Measure how much of the original AST structure is preserved.
115
+
116
+ Uses ratio of shared AST node count: min(N_orig, N_repair) / max(...).
117
+ """
118
+
119
+ if not original or not repaired:
120
+ return 0.0
121
+
122
+ try:
123
+ orig_tree = ast.parse(original)
124
+ repair_tree = ast.parse(repaired)
125
+ except SyntaxError:
126
+ return 0.0
127
+
128
+ orig_nodes = len(list(ast.walk(orig_tree)))
129
+ repair_nodes = len(list(ast.walk(repair_tree)))
130
+
131
+ if orig_nodes == 0:
132
+ return 0.0
133
+
134
+ return min(orig_nodes, repair_nodes) / max(orig_nodes, repair_nodes)
forgeenv/verifier/visible_verifier.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visible verifier: the immediate reward signal the Repair Agent sees.
2
+
3
+ 4 weighted components, summed to a scalar. This is what drives the Repair
4
+ Agent's GRPO updates each rollout. Multiple independent components were
5
+ chosen on purpose, per the reward-engineering survey (arxiv 2408.10215)
6
+ and software-tasks survey (arxiv 2601.19100): a single scalar is far
7
+ easier to game than a composable rubric.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import re
12
+
13
+ from forgeenv.tasks.models import ExecutionResult, Task
14
+
15
# Component weights for the visible (agent-facing) reward. The scalar reward
# is sum(component_value * weight) over the components produced by
# compute_visible_reward, so the keys here must stay in sync with it.
WEIGHTS: dict[str, float] = {
    "script_executes": 1.0,
    "loss_decreased": 0.5,
    "checkpoint_appeared": 0.3,
    "diff_size_penalty": 0.2,  # multiplied with a non-positive component value
}
21
+
22
+
23
def compute_visible_reward(
    result: ExecutionResult, task: Task
) -> tuple[float, dict[str, float]]:
    """Compute scalar visible reward and per-component breakdown."""
    components = {
        "script_executes": 1.0 if result.exit_code == 0 else 0.0,
        "loss_decreased": _check_loss_trend(result.stdout),
        "checkpoint_appeared": 1.0 if result.checkpoint_exists else 0.0,
    }

    # Penalize only large rewrites: more than a 50% change in line count
    # relative to the original task script.
    original_lines = max(len(task.script_content.splitlines()), 1)
    if result.script_content:
        current_lines = len(result.script_content.splitlines())
    else:
        current_lines = original_lines
    diff_ratio = abs(current_lines - original_lines) / original_lines
    components["diff_size_penalty"] = 0.0 if diff_ratio <= 0.5 else -diff_ratio

    total = sum(value * WEIGHTS[name] for name, value in components.items())
    return total, components
43
+
44
+
45
+ def _check_loss_trend(stdout: str) -> float:
46
+ """Parse stdout for `loss=...` patterns and return the fraction of
47
+ consecutive steps where loss strictly decreased."""
48
+
49
+ losses: list[float] = []
50
+ for line in stdout.splitlines():
51
+ match = re.search(r"loss[=:\s]+([\d.]+)", line, re.IGNORECASE)
52
+ if match:
53
+ try:
54
+ losses.append(float(match.group(1)))
55
+ except ValueError:
56
+ continue
57
+
58
+ if len(losses) < 2:
59
+ return 0.0
60
+
61
+ decreasing_steps = sum(
62
+ 1 for i in range(1, len(losses)) if losses[i] < losses[i - 1]
63
+ )
64
+ return decreasing_steps / (len(losses) - 1)
openenv.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: forgeenv
2
+ version: 0.1.0
3
+ description: >
4
+ Self-improving RL environment for HuggingFace ecosystem repair.
5
+ Trains agents to fix broken training scripts under library version drift
6
+ through Challenger-Solver co-evolution. Implements R-Zero (Tencent), SPIRAL,
7
+ and Absolute Zero Reasoner techniques on top of OpenEnv.
8
+ theme: self-improvement
9
+ tags:
10
+ - openenv
11
+ - self-play
12
+ - code-repair
13
+ - schema-drift
14
+ - multi-role
15
+ - huggingface
16
+ - reinforcement-learning
17
+ environment:
18
+ class: forgeenv.env.forge_environment.ForgeEnvironment
19
+ action_model: forgeenv.env.actions.ForgeAction
20
+ observation_model: forgeenv.env.observations.ForgeObservation
21
+ server:
22
+ module: forgeenv.env.server
23
+ app: app
24
+ port: 7860
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ openenv-core>=0.2.0
2
+ fastapi>=0.110.0
3
+ uvicorn[standard]>=0.27.0
4
+ pydantic>=2.6.0
5
+ pyyaml>=6.0
6
+ nltk>=3.8.0
7
+ scikit-learn>=1.4.0
8
+ numpy>=1.26.0
9
+ rich>=13.7.0