Spaces:
Sleeping
Sleeping
| """ | |
| ECHO ULTIMATE — Phase 4: Adversarial Self-Play. | |
| After Phase 3, the model generates its own hard calibration questions targeting | |
| its weakest domains, then trains on them for an additional 500 steps. | |
| This is a research feature — all errors are caught and logged without crashing. | |
| """ | |
| import json | |
| import logging | |
| import re | |
| import torch | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional | |
| from config import cfg | |
# Logger for this module (name taken from ``__name__``).
logger = logging.getLogger(name=__name__)

# Domains targeted when no usable per-domain calibration history is available.
_WEAK_DOMAIN_DEFAULT: List[str] = [
    "medical",
    "coding",
    "science",
]
@dataclass
class AdversarialQuestion:
    """A single calibration question produced by Phase 4 self-play.

    The ``@dataclass`` decorator was missing even though ``dataclass`` is
    imported; without it the annotated fields below never became constructor
    parameters, so ``AdversarialQuestion(question=..., domain=...)`` raised
    ``TypeError``.
    """

    question: str  # question text
    domain: str  # domain the question targets
    difficulty: str = "adversarial"  # difficulty tag
    generated_by: str = "self-play"  # provenance tag
# Field extractors for one "---"-delimited chunk of model output of the form
# "Q: <question>\nA: <answer>". Compiled once rather than on every chunk.
_QUESTION_RE = re.compile(r"Q:\s*(.+?)(?=A:|$)", re.DOTALL)
_ANSWER_RE = re.compile(r"A:\s*(.+?)(?=Q:|$)", re.DOTALL)


def _parse_qa_pairs(generated: str) -> List[tuple]:
    """Return ``(question, answer)`` string pairs parsed from raw model output.

    The text is split on ``---``. A chunk is kept only if it contains both a
    ``Q:`` and an ``A:`` field and neither is empty after stripping. Newlines
    inside a field are replaced with spaces.
    """
    pairs = []
    for chunk in generated.split("---"):
        q_match = _QUESTION_RE.search(chunk)
        a_match = _ANSWER_RE.search(chunk)
        if not (q_match and a_match):
            continue
        q_text = q_match.group(1).strip().replace("\n", " ")
        a_text = a_match.group(1).strip().replace("\n", " ")
        if q_text and a_text:
            pairs.append((q_text, a_text))
    return pairs


def generate_adversarial_questions(
    model,
    tokenizer,
    weak_domains: List[str],
    n_questions: int = 200,
    config=None,
) -> List[dict]:
    """Have ``model`` write hard calibration questions for each weak domain.

    For every domain the model is prompted, with sampling at temperature 0.9
    and up to 1000 new tokens, to write Q/A pairs separated by ``---``. Each
    parsed pair becomes a task dict compatible with the TaskBank format, with
    keys ``id``, ``domain``, ``difficulty``, ``difficulty_score``,
    ``question``, ``answer``, ``answer_aliases``, ``source_dataset`` and
    ``metadata``.

    Args:
        model: object exposing ``device`` and ``generate(...)``.
        tokenizer: tokenizer matching ``model`` (``__call__``, ``decode``,
            ``eos_token_id``).
        weak_domains: domain names to target. The requested count is split
            evenly across them.
        n_questions: maximum number of questions to return.
        config: currently unused; accepted for signature compatibility.

    Returns:
        At most ``n_questions`` task dicts. Returns ``[]`` without calling the
        model when ``weak_domains`` is empty or ``n_questions <= 0``. A
        generation failure for one domain is logged and that domain is skipped.
    """
    # An empty domain list used to raise ZeroDivisionError below, outside any
    # try, which crashed the "best-effort" phase.
    if not weak_domains or n_questions <= 0:
        return []

    # Ceiling division so the per-domain requests add up to at least
    # n_questions (floor division capped the total at e.g. 198 for 200 / 3).
    # The final slice trims any surplus.
    per_domain = -(-n_questions // len(weak_domains))
    questions: List[dict] = []
    for domain in weak_domains:
        if len(questions) >= n_questions:
            break  # cap already reached; more output would just be sliced off
        prompt = (
            f"Generate {per_domain} challenging {domain} questions where an AI might be "
            f"overconfident. Each should have a clear, non-obvious correct answer.\n"
            f"Format each as:\nQ: [question]\nA: [correct answer]\n---\n"
            f"Generate {per_domain} questions now:\n"
        )
        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1000,
                    temperature=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Decode only the continuation, not the echoed prompt tokens.
            prompt_len = inputs["input_ids"].shape[1]
            generated = tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
        except Exception as exc:
            # Best-effort research feature: keep the traceback, skip the domain.
            logger.exception("Phase 4 generation failed for domain %s: %s", domain, exc)
            continue

        for q_text, a_text in _parse_qa_pairs(generated):
            questions.append({
                "id": f"adversarial_{domain}_{len(questions):05d}",
                "domain": domain,
                "difficulty": "adversarial",
                "difficulty_score": 0.10,
                "question": q_text,
                "answer": a_text,
                "answer_aliases": [a_text],
                "source_dataset": "self_play",
                "metadata": {"generated_by": "echo_phase4"},
            })

    questions = questions[:n_questions]
    logger.info("Phase 4: generated %d adversarial questions", len(questions))
    return questions
def _get_weak_domains(reward_history, top_k: int = 3) -> List[str]:
    """Return the ``top_k`` domains with the highest ECE (most miscalibrated).

    Only domains whose profile has ``n_samples > 0`` are ranked. Ties keep the
    order in which ``get_domain_profiles()`` yielded them.

    Args:
        reward_history: object with ``get_domain_profiles()`` returning a
            mapping of domain name -> profile exposing ``ece`` and
            ``n_samples``, or ``None``.
        top_k: how many domains to return (default 3).

    Returns:
        The selected domain names. If ``reward_history`` is None, yields no
        ranked domain, or raises while being queried, a fresh copy of
        ``_WEAK_DOMAIN_DEFAULT`` is returned instead. A copy is returned so
        callers cannot mutate the shared default.
    """
    if reward_history is None:
        return list(_WEAK_DOMAIN_DEFAULT)
    try:
        profiles = reward_history.get_domain_profiles()
        if not profiles:
            return list(_WEAK_DOMAIN_DEFAULT)
        scored = [(domain, p.ece) for domain, p in profiles.items() if p.n_samples > 0]
        # Stable sort: equal ECE values keep their original relative order.
        scored.sort(key=lambda item: item[1], reverse=True)
        weak = [domain for domain, _ in scored[:top_k]]
    except Exception:
        # Best-effort fallback, but no longer silent.
        logger.warning(
            "Phase 4: could not rank domains by ECE; using defaults", exc_info=True
        )
        return list(_WEAK_DOMAIN_DEFAULT)
    return weak or list(_WEAK_DOMAIN_DEFAULT)
def _train_on_questions(trainer, tokenizer, questions: List[dict], config, extra_steps: int) -> None:
    """Mix ``questions`` into a fresh TaskBank and train ``trainer`` for ``extra_steps`` steps.

    Each question is appended to the ``"hard"`` list of its domain in the
    bank's ``_tasks`` mapping. Questions whose domain the bank does not contain
    are skipped, and the skip count is logged. Exceptions propagate to the
    caller.
    """
    # Imported lazily: these project modules are only needed once Phase 4 trains.
    from training.dataset import build_grpo_dataset
    from env.task_bank import TaskBank

    tmp_bank = TaskBank()
    tmp_bank.ensure_loaded()
    injected = 0
    for q in questions:
        domain = q["domain"]
        if domain in tmp_bank._tasks:
            tmp_bank._tasks[domain]["hard"].append(q)
            injected += 1
    if injected < len(questions):
        logger.warning(
            "Phase 4: skipped %d of %d questions whose domain is not in TaskBank",
            len(questions) - injected,
            len(questions),
        )

    # NOTE(review): build_grpo_dataset(phase=3) presumably samples from the
    # whole bank, so the injected questions are mixed with existing hard tasks
    # rather than used exclusively -- confirm this is intended.
    adv_dataset = build_grpo_dataset(
        tmp_bank,
        n_samples=min(extra_steps * config.BATCH_SIZE, len(questions) * 4),
        phase=3,
        tokenizer=tokenizer,
    )
    trainer.train_dataset = adv_dataset
    # With a Hugging Face Trainer, train(resume_from_checkpoint=False) starts a
    # fresh TrainerState with global_step == 0. max_steps is therefore the
    # number of steps this call runs. The former "global_step + 500" ran every
    # earlier step count again plus 500, not the announced 500.
    trainer.args.max_steps = extra_steps
    print(f" Training {extra_steps} steps on adversarial questions…")
    trainer.train(resume_from_checkpoint=False)


def run_phase_4(
    trainer,
    model,
    tokenizer,
    reward_history,
    config=None,
    *,
    n_questions: int = 200,
    extra_steps: int = 500,
    out_path: str = "adversarial_questions.json",
) -> List[dict]:
    """Run the Phase 4 adversarial self-play stage after Phase 3.

    The stage runs in four steps:

    1. Pick the most miscalibrated domains.
    2. Have ``model`` generate questions for them.
    3. Write the questions to ``out_path`` as JSON (even if the list is empty).
    4. If any were generated, continue training ``trainer`` for
       ``extra_steps`` steps on a dataset built from them.

    Failures inside the phase are logged rather than raised.

    Args:
        trainer: trainer exposing ``train_dataset``, ``args.max_steps`` and
            ``train(resume_from_checkpoint=...)``.
        model: generator model passed to ``generate_adversarial_questions``.
        tokenizer: tokenizer for ``model``.
        reward_history: source of per-domain ECE profiles, or None.
        config: config object providing ``BATCH_SIZE``; defaults to ``cfg``.
        n_questions: maximum number of questions to generate (default 200).
        extra_steps: number of additional training steps (default 500).
        out_path: where the generated questions are written as JSON.

    Returns:
        The generated question dicts (possibly empty). Returns ``[]`` if domain
        selection, generation or saving failed. If only the extra training
        failed, the questions are still returned.
    """
    config = config or cfg
    logger.info("=== PHASE 4: ADVERSARIAL SELF-PLAY ===")
    print("\n🧪 Phase 4: Adversarial Self-Play")
    try:
        weak_domains = _get_weak_domains(reward_history)
        print(f" Targeting weak domains: {weak_domains}")
        questions = generate_adversarial_questions(
            model, tokenizer, weak_domains, n_questions=n_questions, config=config
        )
        print(f" Generated {len(questions)} adversarial questions")
        # Save for inspection / reuse.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(questions, f, indent=2)
        print(f" Saved to {out_path}")
    except Exception as exc:
        logger.exception("Phase 4 run_phase_4 error: %s", exc)
        return []

    if not questions:
        logger.warning("Phase 4: no questions generated — skipping extra training")
        return questions

    try:
        _train_on_questions(trainer, tokenizer, questions, config, extra_steps)
        print(" Phase 4 complete ✅")
    except Exception as exc:
        logger.exception("Phase 4 extra training failed: %s", exc)
    return questions