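"""Teacher (curriculum controller).

Rule-based, NOT an LLM. Maintains an exponential moving average (EMA) of the
success rate for each breakage category and routes the next episode toward
the categories where the Repair Agent's success rate is closest to 50%
(R-Zero's difficulty band). Categories whose EMA lies in the 30-70% band are
sampled at random, weighted toward 50%; if none do, the category closest to
50% is chosen.
"""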
| from __future__ import annotations |
|
|
| import random |
| from dataclasses import dataclass, field |
|
|
|
|
| @dataclass |
| class Teacher: |
| categories: list[str] |
| alpha: float = 0.9 |
| success_counts: dict[str, int] = field(default_factory=dict) |
| attempt_counts: dict[str, int] = field(default_factory=dict) |
| ema_success: dict[str, float] = field(default_factory=dict) |
|
|
| def __post_init__(self) -> None: |
| for category in self.categories: |
| self.success_counts.setdefault(category, 0) |
| self.attempt_counts.setdefault(category, 0) |
| self.ema_success.setdefault(category, 0.5) |
|
|
| def update(self, category: str, success: bool) -> None: |
| if category not in self.ema_success: |
| self.categories.append(category) |
| self.ema_success[category] = 0.5 |
| self.success_counts[category] = 0 |
| self.attempt_counts[category] = 0 |
|
|
| self.attempt_counts[category] += 1 |
| self.success_counts[category] += int(success) |
| rate = self.success_counts[category] / max(1, self.attempt_counts[category]) |
| self.ema_success[category] = ( |
| self.alpha * self.ema_success[category] + (1 - self.alpha) * rate |
| ) |
|
|
| def select_next_category(self) -> str: |
| in_zone = { |
| c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7 |
| } |
| if in_zone: |
| weights = [1.0 / (v + 0.01) for v in in_zone.values()] |
| return random.choices(list(in_zone.keys()), weights=weights, k=1)[0] |
| return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5)) |
|
|
| def get_state(self) -> dict: |
| return { |
| c: { |
| "ema_success": round(self.ema_success[c], 4), |
| "attempts": self.attempt_counts[c], |
| "successes": self.success_counts[c], |
| } |
| for c in self.categories |
| } |
|
|