echo-ultimate / training /adversarial.py
Vikaspandey582003's picture
Upload folder using huggingface_hub
acb327b verified
"""
ECHO ULTIMATE — Phase 4: Adversarial Self-Play.
After Phase 3, the model generates its own hard calibration questions targeting
its weakest domains, then trains on them for an additional 500 steps.
This is a research feature — all errors are caught and logged without crashing.
"""
import json
import logging
import re
import torch
from dataclasses import dataclass, field
from typing import List, Optional
from config import cfg
logger = logging.getLogger(__name__)
_WEAK_DOMAIN_DEFAULT = ["medical", "coding", "science"]
@dataclass
class AdversarialQuestion:
question: str
domain: str
difficulty: str = "adversarial"
generated_by: str = "self-play"
def generate_adversarial_questions(
model,
tokenizer,
weak_domains: List[str],
n_questions: int = 200,
config=None,
) -> List[dict]:
"""
Model generates questions in domains where it is overconfident.
Returns a list of task dicts compatible with TaskBank format.
"""
config = config or cfg
questions = []
per_domain = max(1, n_questions // len(weak_domains))
for domain in weak_domains:
prompt = (
f"Generate {per_domain} challenging {domain} questions where an AI might be "
f"overconfident. Each should have a clear, non-obvious correct answer.\n"
f"Format each as:\nQ: [question]\nA: [correct answer]\n---\n"
f"Generate {per_domain} questions now:\n"
)
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=1000,
temperature=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
generated = tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
pairs = generated.split("---")
for pair in pairs:
q_match = re.search(r"Q:\s*(.+?)(?=A:|$)", pair, re.DOTALL)
a_match = re.search(r"A:\s*(.+?)(?=Q:|---$|$)", pair, re.DOTALL)
if q_match and a_match:
q_text = q_match.group(1).strip().replace("\n", " ")
a_text = a_match.group(1).strip().replace("\n", " ")
if q_text and a_text:
questions.append({
"id": f"adversarial_{domain}_{len(questions):05d}",
"domain": domain,
"difficulty": "adversarial",
"difficulty_score": 0.10,
"question": q_text,
"answer": a_text,
"answer_aliases": [a_text],
"source_dataset": "self_play",
"metadata": {"generated_by": "echo_phase4"},
})
except Exception as exc:
logger.error("Phase 4 generation failed for domain %s: %s", domain, exc)
logger.info("Phase 4: generated %d adversarial questions", len(questions))
return questions[:n_questions]
def _get_weak_domains(reward_history) -> List[str]:
"""Return the 3 domains with the highest ECE (most miscalibrated)."""
if reward_history is None:
return _WEAK_DOMAIN_DEFAULT
try:
profiles = reward_history.get_domain_profiles()
if not profiles:
return _WEAK_DOMAIN_DEFAULT
sorted_domains = sorted(
[(d, p.ece) for d, p in profiles.items() if p.n_samples > 0],
key=lambda x: x[1],
reverse=True,
)
weak = [d for d, _ in sorted_domains[:3]]
return weak if weak else _WEAK_DOMAIN_DEFAULT
except Exception:
return _WEAK_DOMAIN_DEFAULT
def run_phase_4(trainer, model, tokenizer, reward_history, config=None) -> List[dict]:
"""
Run adversarial self-play phase after Phase 3.
Generates questions targeting weak domains, saves them, and trains 500 more steps.
"""
config = config or cfg
logger.info("=== PHASE 4: ADVERSARIAL SELF-PLAY ===")
print("\n🧪 Phase 4: Adversarial Self-Play")
try:
weak_domains = _get_weak_domains(reward_history)
print(f" Targeting weak domains: {weak_domains}")
questions = generate_adversarial_questions(
model, tokenizer, weak_domains, n_questions=200, config=config
)
print(f" Generated {len(questions)} adversarial questions")
# Save for inspection / reuse
out_path = "adversarial_questions.json"
with open(out_path, "w") as f:
json.dump(questions, f, indent=2)
print(f" Saved to {out_path}")
if not questions:
logger.warning("Phase 4: no questions generated — skipping extra training")
return questions
# Build a small dataset from the adversarial questions and run 500 more steps
try:
from training.dataset import build_grpo_dataset
from env.task_bank import TaskBank
# Inject questions into a temporary TaskBank and rebuild dataset
tmp_bank = TaskBank()
tmp_bank.ensure_loaded()
for q in questions:
d = q["domain"]
if d in tmp_bank._tasks:
tmp_bank._tasks[d]["hard"].append(q)
adv_dataset = build_grpo_dataset(
tmp_bank,
n_samples=min(500 * config.BATCH_SIZE, len(questions) * 4),
phase=3,
tokenizer=tokenizer,
)
trainer.train_dataset = adv_dataset
trainer.args.max_steps = (trainer.state.global_step or 0) + 500
print(" Training 500 steps on adversarial questions…")
trainer.train(resume_from_checkpoint=False)
print(" Phase 4 complete ✅")
except Exception as exc:
logger.error("Phase 4 extra training failed: %s", exc)
return questions
except Exception as exc:
logger.error("Phase 4 run_phase_4 error: %s", exc)
return []