""" CivicAI TRL PPO Training Script — scripts/train_ppo.py ======================================================= Full training pipeline using HuggingFace TRL. LLM (GPT-2) receives society state as text → outputs JSON action. PPO optimises the LLM against the CivicAI environment reward. """ from __future__ import annotations import os, sys, json, random import numpy as np import torch import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from tqdm import tqdm from transformers import AutoTokenizer from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from civicai.environment import CivicAIEnv from civicai.models import Action, SubsidyPolicy from civicai.reward import get_named_scores, compute_reward # ── Config ──────────────────────────────────────────────────────────────────── MODEL_NAME = "gpt2" # swap for "meta-llama/Llama-3.2-1B" on Colab A100 TASK_ID = "stabilize_economy" N_EPISODES = 20 # episodes to train STEPS_EP = 50 # max steps per episode BATCH_SIZE = 1 LR = 1.41e-5 SEED = 42 random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED) DARK, PANEL, GRID = "#0f172a", "#1e293b", "#334155" # ── Prompt / Parser ─────────────────────────────────────────────────────────── def obs_to_prompt(obs: dict) -> str: return ( f"You are a policy advisor. State: Turn={obs['turn']}, " f"GDP=${obs['gdp']:.0f}B, Inflation={obs['inflation']:.1%}, " f"Employment={obs['employment_rate']:.1%}, " f"Satisfaction={obs['public_satisfaction']:.1%}, " f"Health={obs['health_index']:.1%}, Crime={obs['crime_rate']:.1%}. " f"Output JSON: {{\"tax_rate\":0.0-1.0,\"healthcare_budget\":0.0-1.0," f"\"education_budget\":0.0-1.0,\"police_budget\":0.0-1.0," f"\"subsidy_policy\":\"none|agriculture|industry|technology\"}} Action:" ) def parse_action(text: str) -> Action: try: s, e = text.find("{"), text.rfind("}") if s != -1 and e != -1: d = json.loads(text[s:e+1]) return Action( tax_rate=max(0.0, min(1.0, float(d.get("tax_rate", 0.25)))), healthcare_budget=max(0.0, min(1.0, float(d.get("healthcare_budget", 0.20)))), education_budget=max(0.0, min(1.0, float(d.get("education_budget", 0.15)))), police_budget=max(0.0, min(1.0, float(d.get("police_budget", 0.10)))), subsidy_policy=SubsidyPolicy(d.get("subsidy_policy", "none")), ) except Exception: pass return Action( tax_rate=random.uniform(0.2, 0.4), healthcare_budget=random.uniform(0.1, 0.3), education_budget=random.uniform(0.05, 0.2), police_budget=random.uniform(0.05, 0.15), ) # ── Random Baseline ─────────────────────────────────────────────────────────── def run_random_baseline(n: int = 5) -> float: rewards = [] env = CivicAIEnv() for seed in range(n): rng = random.Random(seed) env.reset(task_id=TASK_ID, seed=seed) ep = [] for _ in range(STEPS_EP): a = Action( tax_rate=rng.uniform(0.15, 0.5), healthcare_budget=rng.uniform(0.08, 0.35), education_budget=rng.uniform(0.05, 0.25), police_budget=rng.uniform(0.03, 0.18), ) _, r, done, _ = env.step(a) ep.append(r) if done: break rewards.append(float(np.mean(ep))) return float(np.mean(rewards)) # ── Main Training ───────────────────────────────────────────────────────────── def train_ppo(): device = "cuda" if torch.cuda.is_available() else "cpu" print(f"[CivicAI] TRL PPO Training | model={MODEL_NAME} device={device}") # Models config = PPOConfig( model_name=MODEL_NAME, learning_rate=LR, batch_size=BATCH_SIZE, mini_batch_size=1, gradient_accumulation_steps=1, log_with=None, ) model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device) ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) tokenizer.pad_token = tokenizer.eos_token ppo = PPOTrainer(config, model, ref_model, tokenizer) gen_kwargs = dict( max_new_tokens=80, do_sample=True, top_k=50, top_p=0.95, pad_token_id=tokenizer.eos_token_id, ) env = CivicAIEnv() # Baseline print("[CivicAI] Computing random baseline...") baseline_avg = run_random_baseline(5) print(f" Random baseline avg reward: {baseline_avg:.4f}") # Training episode_rewards, episode_components = [], [] print(f"[CivicAI] Training for {N_EPISODES} episodes...") for ep in range(N_EPISODES): obs = env.reset(task_id=TASK_ID, seed=ep) ep_rewards, ep_comp = [], [] for step in tqdm(range(STEPS_EP), desc=f"Ep {ep+1}/{N_EPISODES}", leave=False): prompt = obs_to_prompt(obs.model_dump()) query = tokenizer.encode(prompt, return_tensors="pt").to(device)[0] response = ppo.generate(query.unsqueeze(0), **gen_kwargs) response_ids = response[0][len(query):] text = tokenizer.decode(response_ids, skip_special_tokens=True) action = parse_action(text) obs, reward, done, info = env.step(action) # Named component scores state = env.state() robj = compute_reward(state, action) ep_comp.append(get_named_scores(robj)) reward_t = torch.tensor([reward], dtype=torch.float).to(device) ppo.step([query], [response_ids], [reward_t]) ep_rewards.append(reward) if done: break avg_r = float(np.mean(ep_rewards)) episode_rewards.append(avg_r) episode_components.append({ k: round(float(np.mean([c[k] for c in ep_comp])), 4) for k in ep_comp[0] }) print(f" Ep {ep+1:2d}: avg_reward={avg_r:.4f} " + " ".join(f"{k}={v:.3f}" for k, v in episode_components[-1].items())) # ── Save model ──────────────────────────────────────────────────────────── os.makedirs("assets", exist_ok=True) model.save_pretrained("assets/civicai_ppo_model") tokenizer.save_pretrained("assets/civicai_ppo_model") print("\n Model saved to assets/civicai_ppo_model/") # ── Save JSON results ───────────────────────────────────────────────────── results = { "baseline_avg": baseline_avg, "episode_rewards": episode_rewards, "episode_components": episode_components, "final_avg": float(np.mean(episode_rewards[-5:])), "improvement": float(np.mean(episode_rewards[-5:])) - baseline_avg, } with open("assets/training_results.json", "w") as f: json.dump(results, f, indent=2) # ── Plots ───────────────────────────────────────────────────────────────── _plot_training_curve(episode_rewards, baseline_avg) _plot_component_breakdown(episode_components) print("\n[CivicAI] Training complete.") print(f" Baseline avg: {baseline_avg:.4f}") print(f" Final 5-ep avg: {results['final_avg']:.4f}") print(f" Improvement: {results['improvement']:+.4f}") return results def _plot_training_curve(rewards: list[float], baseline: float) -> None: smooth = np.convolve(rewards, np.ones(3)/3, mode="valid") fig, ax = plt.subplots(figsize=(10, 5)) fig.patch.set_facecolor(DARK); ax.set_facecolor(PANEL) ax.plot(rewards, color="#06b6d4", alpha=0.4, linewidth=1) ax.plot(range(len(smooth)), smooth, color="#06b6d4", linewidth=2.5, label=f"PPO Agent (final={rewards[-1]:.3f})") ax.axhline(baseline, color="#ef4444", linestyle="--", linewidth=1.8, label=f"Random Baseline ({baseline:.3f})") ax.fill_between(range(len(smooth)), smooth, baseline, where=[s > baseline for s in smooth], alpha=0.15, color="#06b6d4", label="Improvement over baseline") ax.set_ylim(0, 1.05) ax.set_xlabel("Episode", color="#94a3b8"); ax.set_ylabel("Avg Step Reward", color="#94a3b8") ax.set_title("CivicAI TRL PPO — Training Curve", color="#e2e8f0", fontsize=14, fontweight="bold") ax.tick_params(colors="#94a3b8") for sp in ax.spines.values(): sp.set_edgecolor(GRID) ax.grid(axis="y", color=GRID, linewidth=0.5, linestyle="--") ax.legend(facecolor=PANEL, edgecolor=GRID, labelcolor="#e2e8f0") plt.tight_layout() plt.savefig("assets/reward_curve.png", dpi=150, facecolor=DARK) plt.close() print(" Saved: assets/reward_curve.png") def _plot_component_breakdown(components: list[dict]) -> None: keys = ["economic_score", "health_score", "satisfaction_score", "crime_score"] colors = ["#f59e0b", "#10b981", "#a78bfa", "#f97316"] fig, axes = plt.subplots(1, 4, figsize=(16, 4)) fig.patch.set_facecolor(DARK) fig.suptitle("Named Reward Components Over Training", color="#e2e8f0", fontsize=13, fontweight="bold") for ax, key, col in zip(axes, keys, colors): vals = [c[key] for c in components] ax.set_facecolor(PANEL) ax.plot(vals, color=col, linewidth=2) ax.fill_between(range(len(vals)), vals, alpha=0.15, color=col) ax.set_ylim(0, 1.05) ax.set_title(key.replace("_score", "").capitalize(), color="#e2e8f0", fontsize=11) ax.tick_params(colors="#94a3b8", labelsize=8) for sp in ax.spines.values(): sp.set_edgecolor(GRID) ax.grid(color=GRID, linewidth=0.4, linestyle="--") plt.tight_layout() plt.savefig("assets/component_scores.png", dpi=150, facecolor=DARK) plt.close() print(" Saved: assets/component_scores.png") if __name__ == "__main__": train_ppo()