File size: 2,166 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Teacher (curriculum controller).

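Rule-based — NOT an LLM. Maintains an EMA success rate per breakage
category and steers the next episode toward categories where the
Repair Agent's success rate is near 50% (R-Zero's difficulty band):
categories whose EMA lies in [0.3, 0.7] are sampled with weight inversely
proportional to their distance from 0.5, and if none qualify, the category
closest to 0.5 is chosen.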
"""
from __future__ import annotations

import random
from dataclasses import dataclass, field


@dataclass
class Teacher:
    """Curriculum controller that picks which breakage category to train on next.

    Tracks, per category, the number of attempts, the number of successes,
    and an exponential moving average (EMA) of the per-episode success
    outcome. Selection favours categories whose EMA is near 0.5.

    Attributes:
        categories: Known category names. Copied on construction so that
            categories registered later by ``update`` are never appended to
            a list owned by the caller (or shared with another Teacher).
        alpha: EMA decay in ``[0, 1]``; higher values weight history more.
            Each update is ``ema = alpha * ema + (1 - alpha) * outcome``.
        success_counts: Number of successful episodes per category.
        attempt_counts: Total number of episodes per category.
        ema_success: EMA of the 0/1 success outcome per category; new
            categories start at 0.5.
        rng: Optional random source used by ``select_next_category``. Pass
            ``random.Random(seed)`` for a reproducible curriculum; ``None``
            uses the global ``random`` module.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """

    categories: list[str]
    alpha: float = 0.9
    success_counts: dict[str, int] = field(default_factory=dict)
    attempt_counts: dict[str, int] = field(default_factory=dict)
    ema_success: dict[str, float] = field(default_factory=dict)
    rng: random.Random | None = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha!r}")
        # Own a private copy: update() appends to this list, and mutating a
        # list shared with the caller or another Teacher would make that
        # owner's get_state() fail with KeyError on the unseen category.
        self.categories = list(self.categories)
        for category in self.categories:
            self._ensure_category(category)

    def _ensure_category(self, category: str) -> None:
        """Initialise bookkeeping for *category* without clobbering existing values."""
        self.success_counts.setdefault(category, 0)
        self.attempt_counts.setdefault(category, 0)
        self.ema_success.setdefault(category, 0.5)

    def update(self, category: str, success: bool) -> None:
        """Record one episode outcome for *category*.

        Unknown categories are registered on the fly with a neutral 0.5 EMA.

        Args:
            category: Breakage category the episode was drawn from.
            success: Whether the Repair Agent succeeded on the episode.
        """
        if category not in self.ema_success:
            if category not in self.categories:
                self.categories.append(category)
            self._ensure_category(category)

        self.attempt_counts[category] += 1
        self.success_counts[category] += int(success)
        # Smooth the per-episode outcome (1.0 / 0.0) so the estimate tracks the
        # agent's *current* ability. Smoothing the lifetime success ratio
        # instead would anchor the EMA to old results, and the curriculum
        # would never notice a category being mastered (or regressing).
        outcome = 1.0 if success else 0.0
        self.ema_success[category] = (
            self.alpha * self.ema_success[category] + (1.0 - self.alpha) * outcome
        )

    def select_next_category(self) -> str:
        """Choose the category to train on next.

        Categories whose EMA lies in ``[0.3, 0.7]`` are sampled with weight
        ``1 / (|ema - 0.5| + 0.01)``; the ``0.01`` avoids division by zero at
        exactly 0.5. If no category is in that band, the one whose EMA is
        closest to 0.5 is returned (ties go to the first key in
        ``ema_success`` iteration order).

        Returns:
            The selected category name.

        Raises:
            ValueError: If no categories are known.
        """
        if not self.ema_success:
            raise ValueError("Teacher has no categories to select from")

        in_zone = {
            c: abs(s - 0.5) for c, s in self.ema_success.items() if 0.3 <= s <= 0.7
        }
        if in_zone:
            chooser = random if self.rng is None else self.rng
            weights = [1.0 / (dist + 0.01) for dist in in_zone.values()]
            return chooser.choices(list(in_zone), weights=weights, k=1)[0]
        return min(self.ema_success, key=lambda c: abs(self.ema_success[c] - 0.5))

    def get_state(self) -> dict[str, dict[str, float | int]]:
        """Return a plain-dict snapshot of per-category statistics.

        Returns:
            Mapping ``category -> {"ema_success", "attempts", "successes"}``
            for every entry in ``categories``; the EMA is rounded to 4 places.
        """
        return {
            c: {
                "ema_success": round(self.ema_success[c], 4),
                "attempts": self.attempt_counts[c],
                "successes": self.success_counts[c],
            }
            for c in self.categories
        }