Spaces:
Sleeping
Sleeping
| """ | |
CivicAI TRL PPO Training Script — scripts/train_ppo.py
=======================================================
Full training pipeline using HuggingFace TRL.
LLM (GPT-2) receives society state as text → outputs JSON action.
PPO optimises the LLM against the CivicAI environment reward.
| """ | |
| from __future__ import annotations | |
| import os, sys, json, random | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
| from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from civicai.environment import CivicAIEnv | |
| from civicai.models import Action, SubsidyPolicy | |
| from civicai.reward import get_named_scores, compute_reward | |
# ── Config ───────────────────────────────────────────────────────────────────
# Model and training hyper-parameters.
MODEL_NAME = "gpt2"  # swap for "meta-llama/Llama-3.2-1B" on Colab A100
TASK_ID = "stabilize_economy"
N_EPISODES = 20  # episodes to train
STEPS_EP = 50  # max steps per episode
BATCH_SIZE = 1
LR = 1.41e-5
SEED = 42

# Seed the Python, NumPy and PyTorch random number generators with SEED.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Plot palette: figure background, axes background, grid/spine colour.
DARK = "#0f172a"
PANEL = "#1e293b"
GRID = "#334155"
# ── Prompt / Parser ──────────────────────────────────────────────────────────
def obs_to_prompt(obs: dict) -> str:
    """Render an observation dict as the text prompt given to the model.

    Args:
        obs: Mapping that provides the keys ``turn``, ``gdp``, ``inflation``,
            ``employment_rate``, ``public_satisfaction``, ``health_index``
            and ``crime_rate``.

    Returns:
        A one-line prompt: the formatted state, the JSON schema the model is
        asked to emit, and a trailing ``Action:`` cue.

    Raises:
        KeyError: If ``obs`` lacks any of the keys listed above.
    """
    state_fields = [
        f"Turn={obs['turn']}",
        f"GDP=${obs['gdp']:.0f}B",
        f"Inflation={obs['inflation']:.1%}",
        f"Employment={obs['employment_rate']:.1%}",
        f"Satisfaction={obs['public_satisfaction']:.1%}",
        f"Health={obs['health_index']:.1%}",
        f"Crime={obs['crime_rate']:.1%}",
    ]
    # Literal schema shown to the model; not an f-string, so braces are plain.
    schema = (
        '{"tax_rate":0.0-1.0,"healthcare_budget":0.0-1.0,'
        '"education_budget":0.0-1.0,"police_budget":0.0-1.0,'
        '"subsidy_policy":"none|agriculture|industry|technology"}'
    )
    return (
        "You are a policy advisor. State: "
        + ", ".join(state_fields)
        + ". Output JSON: "
        + schema
        + " Action:"
    )
def _first_json_object(text: str) -> dict | None:
    """Return the first decodable JSON object embedded in ``text``, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object, so trailing text
            # (including stray braces or a second object) is ignored.
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return obj
    return None


def _clamp_unit(value) -> float:
    """Coerce ``value`` to float and clamp it into the interval [0.0, 1.0]."""
    return max(0.0, min(1.0, float(value)))


def parse_action(text: str) -> Action:
    """Turn raw model output into an ``Action``.

    The first JSON object found in ``text`` is used. Missing numeric fields
    take the defaults 0.25 / 0.20 / 0.15 / 0.10 (tax, healthcare, education,
    police) and every numeric field is clamped to [0, 1]. A missing
    ``subsidy_policy`` defaults to ``"none"``.

    If no JSON object can be decoded, or building the ``Action`` from it
    fails, a randomly sampled action is returned instead so the caller always
    gets something it can pass to the environment. The fallback draws from
    the module-level ``random`` generator and does not set ``subsidy_policy``.

    Args:
        text: Decoded model response, possibly with text around the JSON.

    Returns:
        The parsed action, or a random fallback action.
    """
    payload = _first_json_object(text)
    if payload is not None:
        try:
            return Action(
                tax_rate=_clamp_unit(payload.get("tax_rate", 0.25)),
                healthcare_budget=_clamp_unit(payload.get("healthcare_budget", 0.20)),
                education_budget=_clamp_unit(payload.get("education_budget", 0.15)),
                police_budget=_clamp_unit(payload.get("police_budget", 0.10)),
                subsidy_policy=SubsidyPolicy(payload.get("subsidy_policy", "none")),
            )
        except Exception:
            # Deliberate best-effort: a non-numeric value, an unknown subsidy
            # name or a rejection by Action itself must not abort training.
            # NOTE(review): narrow this once the exception types raised by
            # Action / SubsidyPolicy are confirmed.
            pass
    return Action(
        tax_rate=random.uniform(0.2, 0.4),
        healthcare_budget=random.uniform(0.1, 0.3),
        education_budget=random.uniform(0.05, 0.2),
        police_budget=random.uniform(0.05, 0.15),
    )
# ── Random Baseline ──────────────────────────────────────────────────────────
def run_random_baseline(n: int = 5) -> float:
    """Estimate the reward of a uniformly random policy.

    Runs ``n`` episodes of task ``TASK_ID`` on one ``CivicAIEnv`` instance.
    Episode ``i`` resets the environment with seed ``i`` and samples its
    actions from a private ``random.Random(i)``, so the module-level random
    generator is not consumed. Each episode lasts at most ``STEPS_EP`` steps
    and ends early when the environment reports it is done.

    Args:
        n: Number of baseline episodes to run.

    Returns:
        The mean, over episodes, of each episode's mean step reward.
    """
    env = CivicAIEnv()
    episode_means = []
    for episode_seed in range(n):
        sampler = random.Random(episode_seed)
        env.reset(task_id=TASK_ID, seed=episode_seed)

        step_rewards = []
        steps_taken = 0
        finished = False
        while steps_taken < STEPS_EP and not finished:
            action = Action(
                tax_rate=sampler.uniform(0.15, 0.5),
                healthcare_budget=sampler.uniform(0.08, 0.35),
                education_budget=sampler.uniform(0.05, 0.25),
                police_budget=sampler.uniform(0.03, 0.18),
            )
            _obs, step_reward, finished, _info = env.step(action)
            step_rewards.append(step_reward)
            steps_taken += 1

        episode_means.append(float(np.mean(step_rewards)))
    return float(np.mean(episode_means))
# ── Main Training ────────────────────────────────────────────────────────────
def train_ppo():
    """Fine-tune ``MODEL_NAME`` with TRL PPO against the CivicAI environment.

    Pipeline:
      1. Build the PPO config, policy model, reference model, tokenizer and
         trainer.
      2. Measure a random-policy baseline with ``run_random_baseline``.
      3. For ``N_EPISODES`` episodes of at most ``STEPS_EP`` steps each:
         prompt the model with the current observation, parse its reply into
         an ``Action``, step the environment, and run one PPO update on the
         resulting (query, response, reward) triple.
      4. Save the model and tokenizer to ``assets/civicai_ppo_model``, the
         metrics to ``assets/training_results.json``, and render both plots.

    Returns:
        dict with the keys ``baseline_avg``, ``episode_rewards`` (mean step
        reward per episode), ``episode_components`` (mean named score per
        episode), ``final_avg`` (mean of the last five episode rewards) and
        ``improvement`` (``final_avg`` minus ``baseline_avg``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[CivicAI] TRL PPO Training | model={MODEL_NAME} device={device}")

    # Models
    config = PPOConfig(
        model_name=MODEL_NAME,
        learning_rate=LR,
        batch_size=BATCH_SIZE,
        mini_batch_size=1,
        gradient_accumulation_steps=1,
        log_with=None,
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device)
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    ppo = PPOTrainer(config, model, ref_model, tokenizer)
    gen_kwargs = dict(
        max_new_tokens=80, do_sample=True, top_k=50, top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    env = CivicAIEnv()

    # Baseline
    print("[CivicAI] Computing random baseline...")
    baseline_avg = run_random_baseline(5)
    print(f" Random baseline avg reward: {baseline_avg:.4f}")

    # Training
    episode_rewards, episode_components = [], []
    print(f"[CivicAI] Training for {N_EPISODES} episodes...")
    for ep in range(N_EPISODES):
        obs = env.reset(task_id=TASK_ID, seed=ep)
        ep_rewards, ep_comp = [], []
        for _ in tqdm(range(STEPS_EP), desc=f"Ep {ep+1}/{N_EPISODES}", leave=False):
            prompt = obs_to_prompt(obs.model_dump())
            query = tokenizer.encode(prompt, return_tensors="pt").to(device)[0]
            # PPOTrainer.generate takes a 1-D (seq_len) tensor and adds the
            # batch dimension itself; passing a 2-D tensor is rejected.
            response = ppo.generate(query, **gen_kwargs)
            # The returned sequence still contains the prompt; keep only the
            # newly generated tokens.
            response_ids = response[0][len(query):]
            text = tokenizer.decode(response_ids, skip_special_tokens=True)
            action = parse_action(text)
            obs, reward, done, _ = env.step(action)

            # Named component scores
            # NOTE(review): these are recomputed from the post-step state;
            # confirm they correspond to the reward env.step just returned.
            state = env.state()
            robj = compute_reward(state, action)
            ep_comp.append(get_named_scores(robj))

            reward_t = torch.tensor([reward], dtype=torch.float).to(device)
            ppo.step([query], [response_ids], [reward_t])
            ep_rewards.append(reward)
            if done:
                break

        avg_r = float(np.mean(ep_rewards))
        episode_rewards.append(avg_r)
        # Per-episode mean of every named score (keys taken from the first
        # step; the loop above always runs at least once for STEPS_EP >= 1).
        episode_components.append({
            k: round(float(np.mean([c[k] for c in ep_comp])), 4)
            for k in ep_comp[0]
        })
        print(f" Ep {ep+1:2d}: avg_reward={avg_r:.4f} "
              + " ".join(f"{k}={v:.3f}" for k, v in episode_components[-1].items()))

    # ── Save model ──────────────────────────────────────────────────────────
    os.makedirs("assets", exist_ok=True)
    model.save_pretrained("assets/civicai_ppo_model")
    tokenizer.save_pretrained("assets/civicai_ppo_model")
    print("\n Model saved to assets/civicai_ppo_model/")

    # ── Save JSON results ───────────────────────────────────────────────────
    final_avg = float(np.mean(episode_rewards[-5:]))
    results = {
        "baseline_avg": baseline_avg,
        "episode_rewards": episode_rewards,
        "episode_components": episode_components,
        "final_avg": final_avg,
        "improvement": final_avg - baseline_avg,
    }
    with open("assets/training_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # ── Plots ───────────────────────────────────────────────────────────────
    _plot_training_curve(episode_rewards, baseline_avg)
    _plot_component_breakdown(episode_components)

    print("\n[CivicAI] Training complete.")
    print(f" Baseline avg: {baseline_avg:.4f}")
    print(f" Final 5-ep avg: {results['final_avg']:.4f}")
    print(f" Improvement: {results['improvement']:+.4f}")
    return results
def _plot_training_curve(rewards: list[float], baseline: float) -> None:
    """Plot per-episode reward against the random baseline and save it.

    Draws the raw rewards, a moving average (window of up to 3 episodes), a
    dashed line at ``baseline`` and a shaded band where the moving average
    exceeds the baseline. The figure is written to
    ``assets/reward_curve.png``; the directory is created if needed.

    Args:
        rewards: Mean step reward of each training episode, in order.
        baseline: Mean reward of the random policy.

    Raises:
        ValueError: If ``rewards`` is empty.
    """
    if len(rewards) == 0:
        raise ValueError("rewards must contain at least one episode")

    series = np.asarray(rewards, dtype=float)
    # Shrink the window for very short runs: with fewer samples than taps,
    # np.convolve swaps its operands and the result is no longer an average.
    window = min(3, series.size)
    smooth = np.convolve(series, np.ones(window) / window, mode="valid")
    # A "valid" moving average drops (window - 1) points; centre what is left
    # so each smoothed value sits over the episodes it averages.
    smooth_x = np.arange(smooth.size) + (window - 1) / 2

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(DARK)
    ax.set_facecolor(PANEL)
    ax.plot(series, color="#06b6d4", alpha=0.4, linewidth=1)
    ax.plot(smooth_x, smooth, color="#06b6d4", linewidth=2.5,
            label=f"PPO Agent (final={rewards[-1]:.3f})")
    ax.axhline(baseline, color="#ef4444", linestyle="--", linewidth=1.8,
               label=f"Random Baseline ({baseline:.3f})")
    ax.fill_between(smooth_x, smooth, baseline,
                    where=smooth > baseline,
                    alpha=0.15, color="#06b6d4", label="Improvement over baseline")
    ax.set_ylim(0, 1.05)
    ax.set_xlabel("Episode", color="#94a3b8")
    ax.set_ylabel("Avg Step Reward", color="#94a3b8")
    # Title dash restored from a mis-encoded character.
    ax.set_title("CivicAI TRL PPO — Training Curve", color="#e2e8f0",
                 fontsize=14, fontweight="bold")
    ax.tick_params(colors="#94a3b8")
    for sp in ax.spines.values():
        sp.set_edgecolor(GRID)
    ax.grid(axis="y", color=GRID, linewidth=0.5, linestyle="--")
    ax.legend(facecolor=PANEL, edgecolor=GRID, labelcolor="#e2e8f0")
    plt.tight_layout()
    # Do not rely on the caller having created the output directory.
    os.makedirs("assets", exist_ok=True)
    plt.savefig("assets/reward_curve.png", dpi=150, facecolor=DARK)
    plt.close()
    print(" Saved: assets/reward_curve.png")
def _plot_component_breakdown(components: list[dict]) -> None:
    """Plot the four named reward components across episodes and save them.

    One panel each is drawn for ``economic_score``, ``health_score``,
    ``satisfaction_score`` and ``crime_score``. The figure is written to
    ``assets/component_scores.png``.

    Args:
        components: One dict per episode; each must contain the four score
            keys listed above.

    Raises:
        KeyError: If an episode dict lacks one of the four keys.
    """
    # Score key -> line colour, in panel order.
    panel_colours = {
        "economic_score": "#f59e0b",
        "health_score": "#10b981",
        "satisfaction_score": "#a78bfa",
        "crime_score": "#f97316",
    }
    n_panels = len(panel_colours)

    fig, axes = plt.subplots(1, n_panels, figsize=(16, 4))
    fig.patch.set_facecolor(DARK)
    fig.suptitle("Named Reward Components Over Training", color="#e2e8f0",
                 fontsize=13, fontweight="bold")

    for panel, (score_key, colour) in zip(axes, panel_colours.items()):
        series = [episode[score_key] for episode in components]
        panel_title = score_key.replace("_score", "").capitalize()

        panel.set_facecolor(PANEL)
        panel.plot(series, color=colour, linewidth=2)
        panel.fill_between(range(len(series)), series, alpha=0.15, color=colour)
        panel.set_ylim(0, 1.05)
        panel.set_title(panel_title, color="#e2e8f0", fontsize=11)
        panel.tick_params(colors="#94a3b8", labelsize=8)
        for spine in panel.spines.values():
            spine.set_edgecolor(GRID)
        panel.grid(color=GRID, linewidth=0.4, linestyle="--")

    plt.tight_layout()
    plt.savefig("assets/component_scores.png", dpi=150, facecolor=DARK)
    plt.close()
    print(" Saved: assets/component_scores.png")
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    train_ppo()