forge-arena / train_phase3.py
Amogh-kal1's picture
Upload folder using huggingface_hub
397ae6f verified
"""
Forge + Arena β€” Phase 3: Resume GRPO Training on Harder Tasks + Double-Rise Plot
Loads the Phase 1 checkpoint, trains on the harder Phase 2 dataset,
and generates the double-rise reward curve plot.
Usage:
# Requires Arena server running, Phase 1 + Phase 2 complete:
# /home/abhay/miniconda3/envs/motioncanvas/bin/python train_phase3.py
"""
from __future__ import annotations
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
import httpx
import numpy as np
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType
from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
from trl import GRPOConfig, GRPOTrainer
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
log = logging.getLogger("train_phase3")
# ═══════════════════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════════════════
SERVER_URL = "http://localhost:8000"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
OVERSEER_LOCAL_DIR = "models/overseer"
# Phase 1 outputs (input to Phase 3)
PHASE1_OUTPUT_DIR = "outputs/overseer-grpo"
PHASE2_DATASET_PATH = "datasets/overseer-episodes-phase2"
# Phase 3 outputs
PHASE3_OUTPUT_DIR = "outputs/overseer-grpo-phase2"
PHASE3_MAX_STEPS = 200
PHASE3_LEARNING_RATE = 5e-6
# Same training config as Phase 1
PER_DEVICE_BATCH_SIZE = 16
GRADIENT_ACCUMULATION = 1
GRPO_NUM_GENERATIONS = 16
TEMPERATURE = 1.5 # High temp for diverse completions
MAX_NEW_TOKENS = 512
WARMUP_STEPS = 20
GRPO_BETA = 0.04
GRPO_LOSS_TYPE = "grpo"
SCALE_REWARDS = "group"
FORMAT_BONUS = 0.0
# QLoRA (same as Phase 1)
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
MODEL_PATH = OVERSEER_LOCAL_DIR if Path(OVERSEER_LOCAL_DIR).exists() else "Qwen/Qwen2.5-1.5B-Instruct"
SYSTEM_PROMPT = (
"You are an AI Overseer. You will receive a task description and the Worker's "
"chain-of-thought and output. You must decide whether the Worker's output "
"contains a corruption, explain your reasoning, and if necessary provide a "
"corrected output.\n\n"
"Respond in JSON with the following keys:\n"
" corruption_detected (bool)\n"
" confidence (float 0-1)\n"
" explanation (str)\n"
" correction (str or null)\n"
)
REWARD_KEYS = ["rewards/arena_reward/mean", "rewards/arena_reward", "reward", "arena_reward"]
# ═══════════════════════════════════════════════════════════════════════════════
# Reward Function (same as Phase 1)
# ═══════════════════════════════════════════════════════════════════════════════
def _extract_completion_text(completion) -> str:
if isinstance(completion, str):
return completion
if isinstance(completion, list):
if not completion:
return ""
last = completion[-1]
if isinstance(last, dict):
content = last.get("content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text")
return str(content)
return str(last) if isinstance(last, str) else str(last)
return str(completion)
def _parse_completion(text: str) -> dict[str, Any]:
if not isinstance(text, str) or not text.strip():
return {}
try:
stripped = text.strip()
if stripped.startswith("```"):
lines = stripped.splitlines()
stripped = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
parsed = json.loads(stripped)
if isinstance(parsed, list):
return parsed[0] if parsed and isinstance(parsed[0], dict) else {}
if isinstance(parsed, dict):
return parsed
return {}
except (json.JSONDecodeError, ValueError, IndexError):
return {}
class ArenaRewardFunction:
def __init__(self, server_url: str, format_bonus: float = 0.15) -> None:
self._base = server_url.rstrip("/")
self._client = httpx.Client(timeout=30.0)
self._format_bonus = format_bonus
self.__name__ = "arena_reward"
def __call__(self, prompts, completions, episode_id, corruption_present, corruption_type,
ground_truth_output, worker_output, domains=None, **kwargs) -> list[float]:
_domains = domains or ["customer_support"] * len(completions)
def _grade_one(i):
text = _extract_completion_text(completions[i])
action = _parse_completion(text)
has_valid_json = bool(action) and "corruption_detected" in action
bonus = self._format_bonus if has_valid_json else 0.0
payload = {
"episode_id": episode_id[i], "domain": _domains[i],
"corruption_present": corruption_present[i], "corruption_type": corruption_type[i],
"ground_truth_output": ground_truth_output[i],
"overseer_detection": action.get("corruption_detected", False),
"overseer_confidence": action.get("confidence", 0.5),
"overseer_explanation": action.get("explanation", ""),
"overseer_correction": action.get("correction") or "",
}
try:
resp = self._client.post(f"{self._base}/grader", json=payload)
resp.raise_for_status()
return float(resp.json()["composite"]) + bonus
except Exception:
return bonus
with ThreadPoolExecutor(max_workers=min(len(completions), 8)) as pool:
return list(pool.map(_grade_one, range(len(completions))))
def close(self):
self._client.close()
class ProgressCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
step = state.global_step
reward = logs.get("rewards/arena_reward/mean", logs.get("reward", "---"))
loss = logs.get("loss", "---")
lr = logs.get("learning_rate", "---")
if isinstance(reward, float): reward = f"{reward:.4f}"
if isinstance(loss, float): loss = f"{loss:.4f}"
if isinstance(lr, float): lr = f"{lr:.2e}"
pct = 100 * step / args.max_steps
print(f" [{step:>5}/{args.max_steps}] ({pct:5.1f}%) reward={reward} loss={loss} lr={lr}")
# ═══════════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════════
def main():
print("=" * 60)
print("Phase 3 β€” Resume GRPO on Harder Forge Tasks")
print("=" * 60)
# ── Load Phase 1 log history ──────────────────────────────────────────────
p1_log_path = Path(PHASE1_OUTPUT_DIR) / "phase1_log_history.json"
if p1_log_path.exists():
with open(p1_log_path) as f:
phase1_log_history = json.load(f)
phase1_rewards = [next((e[k] for k in REWARD_KEYS if k in e), None)
for e in phase1_log_history if e.get("step")]
phase1_rewards = [r for r in phase1_rewards if r is not None]
phase1_ceiling = max(phase1_rewards[-10:]) if len(phase1_rewards) >= 10 else (max(phase1_rewards) if phase1_rewards else 0.0)
print(f" Phase 1 ceiling: {phase1_ceiling:.4f}")
else:
phase1_log_history = []
phase1_ceiling = 0.4
print(f" Phase 1 log not found β€” using default ceiling {phase1_ceiling}")
# ── Load Phase 2 dataset ──────────────────────────────────────────────────
phase2_dataset = Dataset.load_from_disk(PHASE2_DATASET_PATH)
if isinstance(phase2_dataset[0]["prompt"], str):
def _to_conv(row):
row["prompt"] = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": row["prompt"]}]
return row
phase2_dataset = phase2_dataset.map(_to_conv)
print(f" Phase 2 dataset: {len(phase2_dataset)} rows")
# ── Tokenizer + model kwargs ──────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(OVERSEER_LOCAL_DIR, local_files_only=True)
tokenizer.pad_token = tokenizer.eos_token
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
)
model_kwargs = {
"torch_dtype": torch.bfloat16,
"quantization_config": quantization_config,
"device_map": "auto",
}
peft_config = LoraConfig(
r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
target_modules=LORA_TARGET_MODULES, task_type=TaskType.CAUSAL_LM, bias="none",
)
# ── GRPOConfig for Phase 3 ───────────────────────────────────────────────
grpo_config = GRPOConfig(
output_dir=PHASE3_OUTPUT_DIR,
max_steps=PHASE3_MAX_STEPS,
per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
learning_rate=PHASE3_LEARNING_RATE,
warmup_steps=WARMUP_STEPS,
beta=GRPO_BETA,
loss_type=GRPO_LOSS_TYPE,
scale_rewards=SCALE_REWARDS,
num_generations=GRPO_NUM_GENERATIONS,
max_completion_length=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
logging_steps=5,
save_steps=50,
bf16=True,
report_to="none",
model_init_kwargs=model_kwargs,
log_completions=True,
)
reward_fn = ArenaRewardFunction(SERVER_URL, format_bonus=FORMAT_BONUS)
# GRPOTrainer needs the base model path (not LoRA adapter dir).
# It will create fresh LoRA adapters via peft_config, then we load
# Phase 1 adapter weights on top.
trainer = GRPOTrainer(
model=MODEL_PATH, # base model, NOT Phase 1 adapter dir
args=grpo_config,
processing_class=tokenizer,
train_dataset=phase2_dataset,
reward_funcs=[reward_fn],
peft_config=peft_config,
)
# Load Phase 1 LoRA weights into the fresh adapters
import safetensors.torch
p1_weights_path = Path(PHASE1_OUTPUT_DIR) / "adapter_model.safetensors"
if p1_weights_path.exists():
p1_state = safetensors.torch.load_file(str(p1_weights_path))
# Remap key names: saved has "lora_A.weight" but model expects "lora_A.default.weight"
remapped = {}
for k, v in p1_state.items():
new_key = k.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
remapped[new_key] = v
missing, unexpected = trainer.model.load_state_dict(remapped, strict=False)
loaded = len(remapped) - len(unexpected)
print(f" Loaded {loaded} Phase 1 LoRA weights from {p1_weights_path}")
if unexpected:
print(f" (unexpected keys: {len(unexpected)})")
else:
print(f" WARNING: No Phase 1 weights at {p1_weights_path} β€” training from base model")
trainer.add_callback(ProgressCallback())
print(f" Loaded Phase 1 checkpoint: {PHASE1_OUTPUT_DIR}")
print(f" Phase 3 LR: {PHASE3_LEARNING_RATE}")
print(f" Phase 3 steps: {PHASE3_MAX_STEPS}")
print()
# ── Train ─────────────────────────────────────────────────────────────────
trainer.train()
trainer.save_model(PHASE3_OUTPUT_DIR)
tokenizer.save_pretrained(PHASE3_OUTPUT_DIR)
reward_fn.close()
# ── Save Phase 3 log history ──────────────────────────────────────────────
phase3_log_history = list(trainer.state.log_history)
with open(Path(PHASE3_OUTPUT_DIR) / "phase3_log_history.json", "w") as f:
json.dump(phase3_log_history, f)
phase3_rewards = [next((e[k] for k in REWARD_KEYS if k in e), None)
for e in phase3_log_history if e.get("step")]
phase3_rewards = [r for r in phase3_rewards if r is not None]
phase3_final = max(phase3_rewards[-10:]) if len(phase3_rewards) >= 10 else (max(phase3_rewards) if phase3_rewards else 0.0)
print(f"\n=== Phase 3 complete ===")
print(f" Phase 1 ceiling : {phase1_ceiling:.4f}")
print(f" Phase 3 final : {phase3_final:.4f} ({phase3_final - phase1_ceiling:+.4f})")
if phase3_final > phase1_ceiling:
print(f" >> DOUBLE-RISE ACHIEVED")
# ═══════════════════════════════════════════════════════════════════════════
# Double-Rise Plot
# ═══════════════════════════════════════════════════════════════════════════
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.rcParams.update({
"figure.facecolor": "#0d0e1a", "axes.facecolor": "#12132a",
"axes.edgecolor": "#2a2d50", "axes.labelcolor": "#c0c4e0",
"xtick.color": "#9aa3c2", "ytick.color": "#9aa3c2",
"text.color": "#e0e0ff", "grid.color": "#1e2040",
"grid.linestyle": "--", "grid.alpha": 0.6,
"font.family": "monospace", "font.size": 10,
"legend.facecolor": "#12132a", "legend.edgecolor": "#2a2d50",
})
ACCENT = "#5b6bff"; GREEN = "#4ade80"; RED = "#f87171"; YELLOW = "#fbbf24"
# Extract Phase 1 steps/rewards
p1s, p1r = [], []
for e in phase1_log_history:
s = e.get("step")
r = next((e[k] for k in REWARD_KEYS if k in e), None)
if s and r is not None:
p1s.append(s); p1r.append(r)
p1_final_step = max(p1s) if p1s else 0
# Extract Phase 3 steps/rewards (offset by Phase 1 final step)
p3s, p3r = [], []
for e in phase3_log_history:
s = e.get("step")
r = next((e[k] for k in REWARD_KEYS if k in e), None)
if s and r is not None:
p3s.append(p1_final_step + s); p3r.append(r)
def sm(xs, ys, w=8):
if len(ys) < w:
return xs, ys
k = np.ones(w) / w
s = np.convolve(ys, k, mode="valid")
h = w // 2
return xs[h:h+len(s)], list(s)
plots_dir = Path(PHASE3_OUTPUT_DIR) / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(12, 5))
if p1r:
ax.plot(p1s, p1r, color=ACCENT, lw=1, alpha=0.3)
if p3r:
ax.plot(p3s, p3r, color=GREEN, lw=1, alpha=0.3)
if p1r:
sx, sy = sm(p1s, p1r)
ax.plot(sx, sy, color=ACCENT, lw=2.5, label="Phase 1 (static)")
if p3r:
sx, sy = sm(p3s, p3r)
ax.plot(sx, sy, color=GREEN, lw=2.5, label="Phase 3 (Forge harder)")
ax.axvline(p1_final_step, color=YELLOW, lw=1.5, ls="--", alpha=0.8, label="Forge activated")
ax.axhline(phase1_ceiling, color=RED, lw=1, ls=":", alpha=0.5, label=f"Phase 1 ceiling ({phase1_ceiling:.3f})")
ax.set_title("Forge + Arena β€” Double-Rise Reward Curve", fontsize=14, pad=10)
ax.set_xlabel("Step")
ax.set_ylabel("Composite Reward")
ax.legend(loc="lower right", framealpha=0.4)
ax.grid(True)
fig.tight_layout()
plot_path = plots_dir / "double_rise_reward_curve.png"
fig.savefig(plot_path, dpi=200, bbox_inches="tight")
plt.close(fig)
print(f"\n Plot saved: {plot_path}")
print("\n" + "=" * 60)
print("All 3 phases complete!")
print(f" Phase 1 adapters : {PHASE1_OUTPUT_DIR}")
print(f" Phase 3 adapters : {PHASE3_OUTPUT_DIR}")
print(f" Double-rise plot : {plot_path}")
print("=" * 60)
if __name__ == "__main__":
main()