| """Teacher (curriculum controller).
|
|
|
| Deterministic — NOT an LLM. Maintains an EMA success rate per breakage
|
| category and routes the next episode toward the category where the
|
| Repair Agent is closest to a 50% success rate (R-Zero's difficulty band).
|
| """
|
| from __future__ import annotations
|
|
|
| import random
|
| from dataclasses import dataclass, field
|
|
|
|
|
@dataclass
class Teacher:
    """Curriculum controller that picks the next breakage category to train on.

    For each category it keeps raw attempt/success counts and an exponential
    moving average (EMA) of per-episode outcomes. ``select_next_category``
    prefers categories whose EMA sits near 0.5.

    Attributes:
        categories: Category names in registration order. The list is copied
            (and de-duplicated) in ``__post_init__`` so that registering new
            categories never mutates a list owned by the caller.
        alpha: EMA decay in ``[0, 1]``. It is the weight kept on the previous
            EMA, so higher values react more slowly to new outcomes.
        success_counts: Total successes per category.
        attempt_counts: Total attempts per category.
        ema_success: Smoothed success rate per category. New categories start
            at 0.5.
        rng: Optional ``random.Random`` used for sampling. Pass a seeded
            instance for reproducible curricula. When ``None``, the global
            ``random`` module is used.
    """

    categories: list[str]
    alpha: float = 0.9
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)
    rng: random.Random | None = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Validate ``alpha`` and initialize statistics for every known category.

        Raises:
            ValueError: If ``alpha`` is outside ``[0, 1]``.
        """
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha!r}")
        # Copy (and de-duplicate) so later registrations don't leak into a
        # list the caller may share elsewhere.
        self.categories = list(dict.fromkeys(self.categories))
        # Categories seeded only via ``ema_success`` can be selected, so they
        # also need counters and must show up in ``get_state``.
        for category in [*self.categories, *self.ema_success]:
            self._register(category)

    def _register(self, category: str) -> None:
        """Start tracking *category* if needed. Existing statistics are kept."""
        if category not in self.categories:
            self.categories.append(category)
        self.success_counts.setdefault(category, 0)
        self.attempt_counts.setdefault(category, 0)
        self.ema_success.setdefault(category, 0.5)

    def update(self, category: str, success: bool) -> None:
        """Record the outcome of one episode for *category*.

        Unknown categories are registered on the fly with a neutral 0.5 EMA.
        The EMA is fed this episode's outcome (1.0 or 0.0), not the lifetime
        success ratio. Feeding the lifetime ratio would make the average
        freeze as attempts pile up, instead of tracking the agent's current
        skill.
        """
        self._register(category)
        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        outcome = 1.0 if success else 0.0
        self.ema_success[category] = (
            self.alpha * self.ema_success[category] + (1 - self.alpha) * outcome
        )

    def select_next_category(self) -> str:
        """Return the category to schedule for the next episode.

        Categories whose EMA lies in ``[0.3, 0.7]`` are sampled with weight
        ``1 / (|ema - 0.5| + 0.01)``, so those nearest 50% are most likely.
        The ``0.01`` keeps an exact-0.5 category from getting infinite weight.
        If no category is in that band, the one whose EMA is closest to 0.5
        is returned (ties go to the earliest registered).

        Raises:
            ValueError: If no categories are being tracked.
        """
        if not self.ema_success:
            raise ValueError("Teacher has no categories to select from")
        in_zone = {
            c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7
        }
        if in_zone:
            weights = [1.0 / (d + 0.01) for d in in_zone.values()]
            rng = self.rng if self.rng is not None else random
            return rng.choices(list(in_zone), weights=weights, k=1)[0]
        return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5))

    def get_state(self) -> dict[str, dict[str, float | int]]:
        """Return a JSON-friendly snapshot of per-category statistics.

        Keys follow ``categories`` order. Each value holds the EMA rounded to
        4 decimals plus the raw attempt and success counts.
        """
        return {
            c: {
                "ema_success": round(self.ema_success[c], 4),
                "attempts": self.attempt_counts[c],
                "successes": self.success_counts[c],
            }
            for c in self.categories
        }
|
|
|