File size: 6,338 Bytes
acb327b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
ECHO ULTIMATE — Phase 4: Adversarial Self-Play.

After Phase 3, the model generates its own hard calibration questions targeting
its weakest domains, then trains on them for an additional 500 steps.
This is a research feature — all errors are caught and logged without crashing.
"""

import json
import logging
import re
import torch
from dataclasses import dataclass, field
from typing import List, Optional

from config import cfg

logger = logging.getLogger(__name__)

_WEAK_DOMAIN_DEFAULT = ["medical", "coding", "science"]


@dataclass
class AdversarialQuestion:
    """One self-generated calibration question produced by Phase 4 self-play.

    NOTE(review): the pipeline in this module passes plain dicts (see
    generate_adversarial_questions); confirm whether other callers use this
    dataclass before removing it.
    """

    question: str  # the question text generated by the model
    domain: str  # target domain the question probes (e.g. "medical", "coding")
    difficulty: str = "adversarial"  # fixed difficulty label for this phase
    generated_by: str = "self-play"  # provenance marker


def generate_adversarial_questions(
    model,
    tokenizer,
    weak_domains: List[str],
    n_questions: int = 200,
    config=None,
) -> List[dict]:
    """
    Model generates questions in domains where it is overconfident.
    Returns a list of task dicts compatible with TaskBank format.

    Parameters
    ----------
    model : a generation model exposing ``.generate`` and ``.device``.
    tokenizer : tokenizer callable supporting ``return_tensors="pt"``.
    weak_domains : domains to target; an empty list yields an empty result.
    n_questions : hard cap on the total number of questions returned.
    config : optional config override; defaults to the global ``cfg``.

    Per-domain generation failures are logged and skipped, never raised
    (this phase is best-effort by design — see the module docstring).
    """
    config = config or cfg
    questions: List[dict] = []
    # Guard: an empty domain list would otherwise divide by zero below.
    if not weak_domains or n_questions <= 0:
        return questions
    per_domain = max(1, n_questions // len(weak_domains))

    # Compile once outside the loop; both patterns span newlines (DOTALL)
    # because the model may wrap Q/A text across lines.
    q_pattern = re.compile(r"Q:\s*(.+?)(?=A:|$)", re.DOTALL)
    a_pattern = re.compile(r"A:\s*(.+?)(?=Q:|---$|$)", re.DOTALL)

    for domain in weak_domains:
        prompt = (
            f"Generate {per_domain} challenging {domain} questions where an AI might be "
            f"overconfident. Each should have a clear, non-obvious correct answer.\n"
            f"Format each as:\nQ: [question]\nA: [correct answer]\n---\n"
            f"Generate {per_domain} questions now:\n"
        )
        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1000,
                    temperature=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Strip the prompt tokens; decode only the newly generated tail.
            generated = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )

            # The prompt asks for "---" between Q/A pairs; split on it and
            # parse each chunk independently so one malformed pair doesn't
            # discard the rest.
            for pair in generated.split("---"):
                q_match = q_pattern.search(pair)
                a_match = a_pattern.search(pair)
                if q_match and a_match:
                    q_text = q_match.group(1).strip().replace("\n", " ")
                    a_text = a_match.group(1).strip().replace("\n", " ")
                    if q_text and a_text:
                        questions.append({
                            "id":               f"adversarial_{domain}_{len(questions):05d}",
                            "domain":           domain,
                            "difficulty":       "adversarial",
                            "difficulty_score": 0.10,
                            "question":         q_text,
                            "answer":           a_text,
                            "answer_aliases":   [a_text],
                            "source_dataset":   "self_play",
                            "metadata":         {"generated_by": "echo_phase4"},
                        })
        except Exception as exc:
            # Best-effort: a failure in one domain must not abort the others.
            logger.error("Phase 4 generation failed for domain %s: %s", domain, exc)

    logger.info("Phase 4: generated %d adversarial questions", len(questions))
    return questions[:n_questions]


def _get_weak_domains(reward_history) -> List[str]:
    """Return the 3 domains with the highest ECE (most miscalibrated)."""
    if reward_history is None:
        return _WEAK_DOMAIN_DEFAULT

    try:
        profiles = reward_history.get_domain_profiles()
        if not profiles:
            return _WEAK_DOMAIN_DEFAULT
        sorted_domains = sorted(
            [(d, p.ece) for d, p in profiles.items() if p.n_samples > 0],
            key=lambda x: x[1],
            reverse=True,
        )
        weak = [d for d, _ in sorted_domains[:3]]
        return weak if weak else _WEAK_DOMAIN_DEFAULT
    except Exception:
        return _WEAK_DOMAIN_DEFAULT


def run_phase_4(
    trainer,
    model,
    tokenizer,
    reward_history,
    config=None,
    *,
    n_questions: int = 200,
    extra_steps: int = 500,
    out_path: str = "adversarial_questions.json",
) -> List[dict]:
    """
    Run adversarial self-play phase after Phase 3.
    Generates questions targeting weak domains, saves them, and trains
    ``extra_steps`` more steps (500 by default, matching prior behavior).

    Parameters
    ----------
    trainer : trainer whose ``train_dataset`` / ``args.max_steps`` are mutated.
    model, tokenizer : generation model/tokenizer for question synthesis.
    reward_history : per-domain calibration history (may be None).
    config : optional config override; defaults to the global ``cfg``.
    n_questions : how many adversarial questions to request.
    extra_steps : additional training steps on the adversarial set.
    out_path : where the generated questions are saved as JSON.

    Returns the generated question dicts ([] on total failure) — all errors
    are caught and logged, never raised (see module docstring).
    """
    config = config or cfg
    logger.info("=== PHASE 4: ADVERSARIAL SELF-PLAY ===")
    print("\n🧪  Phase 4: Adversarial Self-Play")

    try:
        weak_domains = _get_weak_domains(reward_history)
        print(f"    Targeting weak domains: {weak_domains}")

        questions = generate_adversarial_questions(
            model, tokenizer, weak_domains, n_questions=n_questions, config=config
        )
        print(f"    Generated {len(questions)} adversarial questions")

        # Save for inspection / reuse. Explicit UTF-8 + ensure_ascii=False so
        # model-generated non-ASCII text survives round-tripping readably.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(questions, f, indent=2, ensure_ascii=False)
        print(f"    Saved to {out_path}")

        if not questions:
            logger.warning("Phase 4: no questions generated — skipping extra training")
            return questions

        # Build a small dataset from the adversarial questions and train more
        try:
            from training.dataset import build_grpo_dataset
            from env.task_bank import TaskBank

            # Inject questions into a temporary TaskBank and rebuild dataset.
            # NOTE(review): reaches into TaskBank._tasks — a public insertion
            # API would be cleaner if TaskBank grows one.
            tmp_bank = TaskBank()
            tmp_bank.ensure_loaded()
            for q in questions:
                d = q["domain"]
                if d in tmp_bank._tasks:
                    tmp_bank._tasks[d]["hard"].append(q)

            adv_dataset = build_grpo_dataset(
                tmp_bank,
                # Cap samples by both the training budget and available data.
                n_samples=min(extra_steps * config.BATCH_SIZE, len(questions) * 4),
                phase=3,
                tokenizer=tokenizer,
            )
            trainer.train_dataset = adv_dataset
            # Extend max_steps past wherever training currently stands.
            trainer.args.max_steps = (trainer.state.global_step or 0) + extra_steps
            print(f"    Training {extra_steps} steps on adversarial questions…")
            trainer.train(resume_from_checkpoint=False)
            print("    Phase 4 complete ✅")
        except Exception as exc:
            logger.error("Phase 4 extra training failed: %s", exc)

        return questions

    except Exception as exc:
        logger.error("Phase 4 run_phase_4 error: %s", exc)
        return []