# Uploaded by akhiilll: forgeenv source snapshot for training job
# Commit a15535e (verified)
"""Teacher (curriculum controller).
Deterministic — NOT an LLM. Maintains an EMA success rate per breakage
category and routes the next episode toward the category where the
Repair Agent is closest to a 50% success rate (R-Zero's difficulty band).
"""
from __future__ import annotations
import random
from dataclasses import dataclass, field
@dataclass
class Teacher:
    """Deterministic curriculum controller that picks the next category.

    For every category the teacher keeps an attempt counter, a success
    counter and a smoothed success score (``ema_success``).  Categories whose
    score lies in the 0.3-0.7 band are sampled with a weight that grows as the
    score approaches 0.5; when no category is in the band, the category whose
    score is closest to 0.5 is returned.

    Attributes:
        categories: Names of the tracked categories.  The list is extended in
            place when a previously unseen category is registered.
        alpha: Weight kept on the previous smoothed score in ``update``.
        success_counts: Successful attempts per category.
        attempt_counts: Total attempts per category.
        ema_success: Smoothed success score per category; new categories
            start at 0.5.
    """

    categories: list[str]
    alpha: float = 0.9
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Make ``categories`` and the three state dicts agree with each other.

        Seeded state dicts may mention categories that are missing from
        ``categories`` (or from one another).  Those are registered too, so
        that ``update`` never hits a missing key and ``get_state`` reports
        every category that ``select_next_category`` can return.
        """
        seeded = [*self.ema_success, *self.attempt_counts, *self.success_counts]
        # Iterate over a fresh list: ``_track`` may append to ``categories``.
        for category in [*self.categories, *seeded]:
            self._track(category)

    def _track(self, category: str) -> None:
        """Register *category* everywhere, keeping any values already stored."""
        if category not in self.categories:
            self.categories.append(category)
        self.success_counts.setdefault(category, 0)
        self.attempt_counts.setdefault(category, 0)
        self.ema_success.setdefault(category, 0.5)

    def update(self, category: str, success: bool) -> None:
        """Record the outcome of one attempt and refresh the smoothed score.

        Args:
            category: Category the attempt belonged to.  Unknown categories
                are registered on the fly with a starting score of 0.5.
            success: Whether the attempt succeeded.
        """
        self._track(category)
        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        # NOTE(review): the value blended into the EMA is the cumulative
        # success rate, not the latest outcome -- confirm this is intended.
        rate = self.success_counts[category] / max(1, self.attempt_counts[category])
        self.ema_success[category] = (
            self.alpha * self.ema_success[category] + (1 - self.alpha) * rate
        )

    def select_next_category(self) -> str:
        """Return the category to use for the next episode.

        Categories with a score in [0.3, 0.7] are sampled at random with
        weight ``1 / (|score - 0.5| + 0.01)``.  If none qualifies, the
        category whose score is closest to 0.5 is returned.

        Raises:
            ValueError: If no category is tracked.
        """
        if not self.ema_success:
            raise ValueError("Teacher has no categories to select from")
        in_zone = {
            c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7
        }
        if in_zone:
            # The 0.01 offset keeps the weight finite when a score is exactly 0.5.
            weights = [1.0 / (distance + 0.01) for distance in in_zone.values()]
            return random.choices(list(in_zone), weights=weights, k=1)[0]
        return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5))

    def get_state(self) -> dict:
        """Return a per-category snapshot of the tracked statistics.

        Returns:
            A dict mapping each name in ``categories`` to a dict with the keys
            ``"ema_success"`` (rounded to 4 decimals), ``"attempts"`` and
            ``"successes"``.
        """
        return {
            c: {
                "ema_success": round(self.ema_success[c], 4),
                "attempts": self.attempt_counts[c],
                "successes": self.success_counts[c],
            }
            for c in self.categories
        }