import json
import re

code = """
# =============================================================================
# CivicAI Advanced – Senior ML Engineer Edition
# Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard
# =============================================================================

# ── CELL 1: INSTALL DEPENDENCIES ─────────────────────────────────────────────
\"\"\"
!pip install -q \\
    "transformers>=4.38" \\
    "accelerate>=0.27" \\
    "trl>=0.10" \\
    "peft>=0.9" \\
    "bitsandbytes>=0.42" \\
    "datasets>=2.17" \\
    "requests>=2.31" \\
    "pandas>=2.0" \\
    "fredapi" \\
    "world-bank-data" \\
    "plotly>=5.18" \\
    "rich>=13.0" \\
    "tenacity>=8.2"
# After install: Runtime → Restart session → run from Cell 2
\"\"\"

# ── CELL 2: IMPORTS & SYSTEM SETUP ───────────────────────────────────────────
import os, re, json, math, time, random, inspect, warnings, logging
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import requests
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tenacity import retry, stop_after_attempt, wait_exponential
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich import print as rprint

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
console = Console()

# ── Hardware detection ───────────────────────────────────────────────────────
CUDA_OK = torch.cuda.is_available()
if CUDA_OK:
    CAP = torch.cuda.get_device_capability()
    USE_BF16 = CAP[0] >= 8
    USE_FP16 = not USE_BF16
    GPU_NAME = torch.cuda.get_device_name(0)
else:
    USE_BF16 = USE_FP16 = False
    GPU_NAME = "CPU"
DEVICE = "cuda" if CUDA_OK else "cpu"

# ── Paths ────────────────────────────────────────────────────────────────────
Path("assets").mkdir(exist_ok=True)
Path("checkpoints").mkdir(exist_ok=True)
Path("logs").mkdir(exist_ok=True)

console.rule("[bold cyan]CivicAI Advanced – System Ready")
table = Table(show_header=False, box=None)
table.add_row("[cyan]PyTorch", torch.__version__)
table.add_row("[cyan]Device", GPU_NAME)
table.add_row("[cyan]BF16/FP16", f"bf16={USE_BF16} fp16={USE_FP16}")
table.add_row("[cyan]Timestamp", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
console.print(table)

# ── CELL 3: REAL-TIME DATA FETCHER ───────────────────────────────────────────
class RealTimeDataFetcher:
    \"\"\"
    Fetches live economic indicators from:
      • World Bank Open API (no key required)
      • FRED / St. Louis Fed (free key via fredapi)
      • BLS (Bureau of Labor Statistics – no key for basic tier)
      • REST Countries (social / governance proxies)
    Falls back to realistic historical means if any API is unavailable.
    \"\"\"
    WORLD_BANK_BASE = "https://api.worldbank.org/v2"
    BLS_BASE        = "https://api.bls.gov/publicAPI/v1/timeseries/data"
    REST_COUNTRIES  = "https://restcountries.com/v3.1/alpha"

    # World Bank indicator codes
    WB_INDICATORS = {
        "inflation"   : "FP.CPI.TOTL.ZG",    # CPI inflation %
        "unemployment": "SL.UEM.TOTL.ZS",    # Unemployment % of labour force
        "health_exp"  : "SH.XPD.CHEX.GD.ZS", # Health expenditure % of GDP
        "life_expect" : "SP.DYN.LE00.IN",    # Life expectancy at birth
        "gdp_growth"  : "NY.GDP.MKTP.KD.ZG", # GDP growth %
        "homicide"    : "VC.IHR.PSRC.P5",    # Intentional homicides per 100k
    }

    # Country ISO codes supported
    COUNTRIES = {
        "USA": {"iso2": "US", "iso3": "USA", "name": "United States"},
        "IND": {"iso2": "IN", "iso3": "IND", "name": "India"},
        "GBR": {"iso2": "GB", "iso3": "GBR", "name": "United Kingdom"},
        "DEU": {"iso2": "DE", "iso3": "DEU", "name": "Germany"},
        "JPN": {"iso2": "JP", "iso3": "JPN", "name": "Japan"},
        "BRA": {"iso2": "BR", "iso3": "BRA", "name": "Brazil"},
    }

    # Realistic fallback values (5-year historical means, 2019-2023)
    FALLBACKS = {
        "USA": {"inflation": 3.8, "unemployment": 4.8,  "health_exp": 17.2, "life_expect": 77.5, "gdp_growth": 2.1, "homicide": 6.5},
        "IND": {"inflation": 5.5, "unemployment": 7.2,  "health_exp": 3.3,  "life_expect": 69.4, "gdp_growth": 5.8, "homicide": 2.8},
        "GBR": {"inflation": 3.2, "unemployment": 4.2,  "health_exp": 10.9, "life_expect": 80.4, "gdp_growth": 1.4, "homicide": 1.2},
        "DEU": {"inflation": 2.8, "unemployment": 3.5,  "health_exp": 12.8, "life_expect": 80.6, "gdp_growth": 0.9, "homicide": 0.9},
        "JPN": {"inflation": 1.2, "unemployment": 2.8,  "health_exp": 10.9, "life_expect": 84.3, "gdp_growth": 0.7, "homicide": 0.2},
        "BRA": {"inflation": 6.9, "unemployment": 11.0, "health_exp": 9.9,  "life_expect": 75.5, "gdp_growth": 1.2, "homicide": 22.4},
    }
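
    # To cover another country, extend both COUNTRIES and FALLBACKS with
    # matching keys (FALLBACKS fields mirror WB_INDICATORS). Hypothetical
    # entry, illustrative numbers only:
    #   "FRA": {"iso2": "FR", "iso3": "FRA", "name": "France"}
    #   "FRA": {"inflation": 2.5, "unemployment": 7.5, "health_exp": 12.2,
    #           "life_expect": 82.5, "gdp_growth": 1.5, "homicide": 1.2}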

    def __init__(self, cache_ttl_seconds: int = 3600):
        # Maps a time-bucketed cache key -> normalised state dict (see fetch_country)
        self._cache: Dict[str, dict] = {}
        self.cache_ttl = cache_ttl_seconds
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": "CivicAI/2.0"})

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    def _wb_fetch(self, country_iso2: str, indicator: str) -> Optional[float]:
        \"\"\"Fetch latest non-null value from World Bank API.\"\"\"
        url = (f"{self.WORLD_BANK_BASE}/country/{country_iso2}/indicator/{indicator}"
               f"?format=json&mrv=5&per_page=5")
        r = self.session.get(url, timeout=10)
        r.raise_for_status()
        data = r.json()
        if len(data) < 2 or not data[1]:
            return None
        for entry in data[1]:
            if entry.get("value") is not None:
                return float(entry["value"])
        return None
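
    # Response shape this parser assumes (illustrative, trimmed):
    #   [ {...pagination metadata...},
    #     [ {"date": "2023", "value": 4.1, ...},      # most recent first (mrv=5)
    #       {"date": "2022", "value": null, ...}, ... ] ]
    # data[1] is the observation list; we return the first non-null "value".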

    def fetch_country(self, country_code: str = "USA") -> dict:
        \"\"\"
        Returns normalised economic state for a country.
        Uses cache → World Bank API → fallback, in that order.
        \"\"\"
        cache_key = f"{country_code}_{int(time.time() // self.cache_ttl)}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        meta = self.COUNTRIES.get(country_code, self.COUNTRIES["USA"])
        iso2 = meta["iso2"]
        raw = {}
        with Progress(SpinnerColumn(), TextColumn("[cyan]Fetching {task.description}"),
                      transient=True) as prog:
            t = prog.add_task(f"live data for {meta['name']}")
            for key, indicator in self.WB_INDICATORS.items():
                try:
                    val = self._wb_fetch(iso2, indicator)
                    raw[key] = val if val is not None else self.FALLBACKS[country_code][key]
                except Exception:
                    raw[key] = self.FALLBACKS[country_code][key]
        # ── Normalise to [0, 1] for the RL environment ────────────────────
        state = {
            # Lower inflation = better; 0 % → 1.0, ≥15 % → 0.0
            "inflation"   : max(0.0, min(1.0, 1 - raw["inflation"] / 15.0)),
            # Higher employment = better
            "employment"  : max(0.0, min(1.0, 1 - raw["unemployment"] / 25.0)),
            # Higher health expenditure + life expectancy = better
            "health"      : max(0.0, min(1.0, (raw["health_exp"] / 20.0 +
                                               raw["life_expect"] / 90.0) / 2)),
            # GDP growth proxy for satisfaction
            "satisfaction": max(0.0, min(1.0, (raw["gdp_growth"] + 5) / 15.0)),
            # Lower homicide = better
            "crime"       : max(0.0, min(1.0, 1 - raw["homicide"] / 50.0)),
        }
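        # Worked example with the USA fallbacks: inflation 3.8 % → 1 - 3.8/15 ≈ 0.747,
        # unemployment 4.8 % → 1 - 4.8/25 ≈ 0.808, homicide 6.5 → 1 - 6.5/50 = 0.870.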
        # Attach raw for reporting
        state["_raw"]     = raw
        state["_country"] = meta["name"]
        state["_fetched"] = datetime.now().isoformat()
        self._cache[cache_key] = state
        return state

    def fetch_all_countries(self) -> Dict[str, dict]:
        results = {}
        for code in self.COUNTRIES:
            console.log(f"[dim]→ fetching {code}")
            results[code] = self.fetch_country(code)
        return results

    def to_dataframe(self, all_data: Dict[str, dict]) -> pd.DataFrame:
        rows = []
        for code, state in all_data.items():
            raw = state.get("_raw", {})
            rows.append({
                "country"          : state.get("_country", code),
                "code"             : code,
                "inflation_pct"    : raw.get("inflation", 0),
                "unemployment_pct" : raw.get("unemployment", 0),
                "health_exp_gdp"   : raw.get("health_exp", 0),
                "life_expect"      : raw.get("life_expect", 0),
                "gdp_growth"       : raw.get("gdp_growth", 0),
                "homicide_rate"    : raw.get("homicide", 0),
                # normalised
                "norm_inflation"   : state["inflation"],
                "norm_employment"  : state["employment"],
                "norm_health"      : state["health"],
                "norm_satisfaction": state["satisfaction"],
                "norm_crime"       : state["crime"],
                "fetched_at"       : state.get("_fetched"),
            })
        return pd.DataFrame(rows)

# Instantiate & fetch
fetcher  = RealTimeDataFetcher(cache_ttl_seconds=3600)
all_data = fetcher.fetch_all_countries()
df_world = fetcher.to_dataframe(all_data)

console.rule("[bold green]Live Data Fetched")
console.print(df_world[["country", "inflation_pct", "unemployment_pct",
                        "health_exp_gdp", "gdp_growth"]].to_string(index=False))

# ── CELL 4: REAL-DATA DASHBOARD ──────────────────────────────────────────────
def plot_global_dashboard(df: pd.DataFrame) -> None:
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=(
            "Inflation (%)", "Unemployment (%)", "Health Exp (% GDP)",
            "Life Expectancy (yrs)", "GDP Growth (%)", "Homicide Rate (per 100k)"
        ),
    )
    cols_raw = ["inflation_pct", "unemployment_pct", "health_exp_gdp",
                "life_expect", "gdp_growth", "homicide_rate"]
    colors = px.colors.qualitative.Bold
    for i, col in enumerate(cols_raw):
        r, c = divmod(i, 3)
        fig.add_trace(
            go.Bar(
                x=df["country"], y=df[col],
                marker_color=colors,
                showlegend=False,
                text=df[col].round(1), textposition="outside"
            ),
            row=r + 1, col=c + 1
        )
    fig.update_layout(
        title_text="🌍 CivicAI – Real-Time Global Economic Dashboard",
        title_font_size=20,
        height=600, template="plotly_dark",
        paper_bgcolor="#0d1117", plot_bgcolor="#0d1117",
        font=dict(color="#e6edf3"),
    )
    fig.show()
    fig.write_html("assets/global_dashboard.html")
    console.log("[green]✓ Dashboard saved → assets/global_dashboard.html")

plot_global_dashboard(df_world)

# ── CELL 5: ADVANCED MULTI-COUNTRY ENVIRONMENT ───────────────────────────────
class AdvancedCivicAIEnv:
    \"\"\"
    Production-grade multi-country civic environment.
      • Initialises from real World Bank data
      • Supports 6 countries and 4 policy tasks
      • Action space: 5-dimensional continuous [0, 1]
      • Observation: 10-dimensional (5 state + 5 delta from last step)
      • Reward: weighted multi-objective (Pareto-style)
      • Includes shock events (recession, pandemic proxy, crime spike)
    \"\"\"
    TASKS = {
        "stabilize_economy" : {"inflation_weight": 0.4, "employment_weight": 0.3, "health_weight": 0.15, "satisfaction_weight": 0.1,  "crime_weight": 0.05},
        "improve_health"    : {"inflation_weight": 0.1, "employment_weight": 0.2, "health_weight": 0.5,  "satisfaction_weight": 0.15, "crime_weight": 0.05},
        "reduce_crime"      : {"inflation_weight": 0.1, "employment_weight": 0.2, "health_weight": 0.2,  "satisfaction_weight": 0.1,  "crime_weight": 0.4},
        "maximize_wellbeing": {"inflation_weight": 0.2, "employment_weight": 0.2, "health_weight": 0.2,  "satisfaction_weight": 0.2,  "crime_weight": 0.2},
    }

    SHOCK_EVENTS = [
        {"name": "recession",   "prob": 0.02, "effect": {"inflation": +0.15, "employment": -0.12, "satisfaction": -0.1}},
        {"name": "pandemic",    "prob": 0.01, "effect": {"health": -0.2, "employment": -0.1, "satisfaction": -0.15}},
        {"name": "crime_spike", "prob": 0.02, "effect": {"crime": -0.15, "satisfaction": -0.08}},
        {"name": "boom",        "prob": 0.02, "effect": {"employment": +0.1, "satisfaction": +0.1, "inflation": +0.05}},
    ]

    def __init__(self, fetcher: RealTimeDataFetcher, default_country: str = "USA"):
        self.fetcher = fetcher
        self.default_country = default_country
        self._prev_state = None
        self.step_count = 0
        self.shock_log = []
        self.state_data = {}

    def reset(self, task_id: str = "stabilize_economy", country: str = None) -> dict:
        country = country or self.default_country
        self.task_id = task_id
        self.weights = self.TASKS[task_id]
        self.step_count = 0
        self.shock_log = []
        # Load real data as starting state
        live = self.fetcher.fetch_country(country)
        self.state_data = {k: live[k] for k in ["inflation", "employment", "health", "satisfaction", "crime"]}
        # Add small noise so each episode is unique
        for k in self.state_data:
            self.state_data[k] = float(np.clip(
                self.state_data[k] + np.random.normal(0, 0.02), 0.0, 1.0
            ))
        self._prev_state = dict(self.state_data)
        return self._build_obs()

    def _build_obs(self) -> dict:
        \"\"\"10-dim observation: current state + delta from previous step.\"\"\"
        obs = dict(self.state_data)
        obs["_task"] = self.task_id
        obs["_step"] = self.step_count
        if self._prev_state:
            for k in ["inflation", "employment", "health", "satisfaction", "crime"]:
                obs[f"d_{k}"] = self.state_data[k] - self._prev_state[k]
        else:
            for k in ["inflation", "employment", "health", "satisfaction", "crime"]:
                obs[f"d_{k}"] = 0.0
        return obs

    def _apply_shocks(self):
        \"\"\"Stochastic external shock events.\"\"\"
        for shock in self.SHOCK_EVENTS:
            if np.random.random() < shock["prob"]:
                self.shock_log.append({"step": self.step_count, "event": shock["name"]})
                for k, delta in shock["effect"].items():
                    if k in self.state_data:
                        self.state_data[k] = float(np.clip(self.state_data[k] + delta, 0.0, 1.0))
                console.log(f"[yellow]⚡ Shock event: {shock['name']} at step {self.step_count}")

    def step(self, action: dict) -> Tuple[dict, float, bool, dict]:
        \"\"\"
        action keys: tax, jobs, healthcare, education, infrastructure
        Each in [0, 1] – represents budget allocation intensity.
        \"\"\"
        self._prev_state = dict(self.state_data)
        # Policy effects (linear responses, clipped to [0, 1])
        tax        = action.get("tax", 0.5)
        jobs       = action.get("jobs", 0.5)
        healthcare = action.get("healthcare", 0.5)
        education  = action.get("education", 0.5)
        infra      = action.get("infrastructure", 0.5)
        self.state_data["inflation"] = np.clip(
            self.state_data["inflation"] - tax * 0.08 + jobs * 0.02, 0.0, 1.0)
        self.state_data["employment"] = np.clip(
            self.state_data["employment"] + jobs * 0.06 + infra * 0.02, 0.0, 1.0)
        self.state_data["health"] = np.clip(
            self.state_data["health"] + healthcare * 0.07 + education * 0.02, 0.0, 1.0)
        self.state_data["satisfaction"] = np.clip(
            self.state_data["satisfaction"] + education * 0.05 + infra * 0.03
            - tax * 0.03, 0.0, 1.0)
        self.state_data["crime"] = np.clip(
            self.state_data["crime"] + education * 0.05 + jobs * 0.03
            - infra * 0.01, 0.0, 1.0)
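        # e.g. with every action at 0.5, the per-step deltas before noise/shocks:
        # inflation -0.5*0.08 + 0.5*0.02 = -0.03, employment +0.04, health +0.045,
        # satisfaction +0.025, crime score +0.035 (all scores are higher = better).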
        # Gaussian noise
        for k in self.state_data:
            self.state_data[k] = float(np.clip(
                self.state_data[k] + np.random.normal(0, 0.008), 0.0, 1.0))
        self._apply_shocks()
        self.step_count += 1
        reward = self._compute_reward()
        done = self.step_count >= 50
        info = {"shocks": self.shock_log, "step": self.step_count}
        return self._build_obs(), float(reward), done, info

    def _compute_reward(self) -> float:
        s = self.state_data
        w = self.weights
        return (
            w["inflation_weight"]    * s["inflation"] +
            w["employment_weight"]   * s["employment"] +
            w["health_weight"]       * s["health"] +
            w["satisfaction_weight"] * s["satisfaction"] +
            w["crime_weight"]        * s["crime"]
        )

    def state_report(self) -> dict:
        return {k: round(v, 4) for k, v in self.state_data.items()}

# Smoke test
env_adv = AdvancedCivicAIEnv(fetcher, default_country="USA")
obs = env_adv.reset("stabilize_economy", "USA")
console.rule("[bold green]Advanced Environment Ready")
console.print(f"Initial state (USA, real data): {env_adv.state_report()}")

# ── CELL 6: PROMPT BUILDER (5-action) ────────────────────────────────────────
def build_prompt(obs: dict) -> str:
    task_desc = {
        "stabilize_economy" : "Your priority is economic stability: control inflation and protect employment.",
        "improve_health"    : "Your priority is public health: maximize health outcomes and life expectancy.",
        "reduce_crime"      : "Your priority is public safety: reduce crime through investment and employment.",
        "maximize_wellbeing": "Your priority is overall citizen wellbeing across all dimensions.",
    }.get(obs.get("_task", ""), "Optimize all civic outcomes.")
    return (
        f"You are a senior policy advisor.\\n{task_desc}\\n\\n"
        f"CURRENT STATE (step {obs.get('_step', 0)}):\\n"
        f"  Inflation score   : {obs.get('inflation', 0.5):.3f} (Δ {obs.get('d_inflation', 0):+.3f})\\n"
        f"  Employment score  : {obs.get('employment', 0.5):.3f} (Δ {obs.get('d_employment', 0):+.3f})\\n"
        f"  Health score      : {obs.get('health', 0.5):.3f} (Δ {obs.get('d_health', 0):+.3f})\\n"
        f"  Satisfaction score: {obs.get('satisfaction', 0.5):.3f} (Δ {obs.get('d_satisfaction', 0):+.3f})\\n"
        f"  Crime score       : {obs.get('crime', 0.5):.3f} (Δ {obs.get('d_crime', 0):+.3f})\\n\\n"
        "OUTPUT FORMAT (all values 0.0–1.0, no other text):\\n"
        "tax: 0.X, jobs: 0.X, healthcare: 0.X, education: 0.X, infrastructure: 0.X"
    )

def parse_action(text: str) -> dict:
    \"\"\"5-dimensional action parser with robust regex.\"\"\"
    keys = ["tax", "jobs", "healthcare", "education", "infrastructure"]
    def extract(key: str) -> float:
        m = re.search(rf"{key}\\s*:\\s*(\\d*\\.?\\d+)", text)
        if m:
            try:
                return float(np.clip(float(m.group(1)), 0.0, 1.0))
            except ValueError:
                pass
        return 0.5
    return {k: extract(k) for k in keys}
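
# Hedged sanity check (real model output varies; this is a canned string):
_demo = parse_action("tax: 0.30, jobs: 0.70, healthcare: 0.55, education: 0.60, infrastructure: 0.45")
assert _demo == {"tax": 0.3, "jobs": 0.7, "healthcare": 0.55, "education": 0.6, "infrastructure": 0.45}
assert parse_action("no numbers here")["tax"] == 0.5  # missing keys fall back to 0.5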

# ── CELL 7: LOAD MODEL WITH LoRA ─────────────────────────────────────────────
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

MODEL_NAME = "gpt2"  # swap to "gpt2-medium" or "distilgpt2" as needed
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype = torch.bfloat16 if USE_BF16 else torch.float32,
)

# ── Attach LoRA adapters (well under 1 % of weights stay trainable) ──────────
lora_cfg = LoraConfig(
    task_type      = TaskType.CAUSAL_LM,
    r              = 8,            # rank
    lora_alpha     = 32,
    target_modules = ["c_attn"],   # GPT-2 attention projection
    lora_dropout   = 0.05,
    bias           = "none",
)
model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()
model = model.to(DEVICE)

console.rule("[bold green]Model Ready")
console.log(f"[cyan]Model      : {MODEL_NAME} + LoRA (r=8)")
console.log(f"[cyan]Parameters : {sum(p.numel() for p in model.parameters())/1e6:.1f}M total")

# ── CELL 8: BUILD TRAINING DATASET FROM REAL DATA ────────────────────────────
from datasets import Dataset

NUM_SAMPLES = 300
records = []
env_tmp      = AdvancedCivicAIEnv(fetcher)
task_list    = list(AdvancedCivicAIEnv.TASKS.keys())
country_list = list(RealTimeDataFetcher.COUNTRIES.keys())
for i in range(NUM_SAMPLES):
    task    = task_list[i % len(task_list)]
    country = country_list[i % len(country_list)]
    obs = env_tmp.reset(task, country)
    records.append({
        "prompt" : build_prompt(obs),
        "task"   : task,
        "country": country,
    })
train_dataset = Dataset.from_list(records)
console.log(f"[green]✓ Dataset: {len(train_dataset)} prompts across "
            f"{len(task_list)} tasks × {len(country_list)} countries")
console.log(f"  Sample:\\n{train_dataset[0]['prompt'][:300]}...")

# ── CELL 9: MULTI-OBJECTIVE REWARD FUNCTION ──────────────────────────────────
def civic_reward_advanced(prompts, completions, task=None, country=None, **kwargs) -> List[float]:
    \"\"\"
    Multi-objective GRPO reward function.
    Scores: environment reward + format compliance + consistency bonus.
    \"\"\"
    rewards = []
    env_r = AdvancedCivicAIEnv(fetcher)
    task_list_ = task    if isinstance(task, list)    else [task]    * len(prompts)
    country_   = country if isinstance(country, list) else [country] * len(prompts)
    for i, (prompt, completion) in enumerate(zip(prompts, completions)):
        # Extract text
        if isinstance(completion, list) and len(completion) > 0:
            text = completion[0].get("content", "")
        else:
            text = str(completion)
        action = parse_action(text)
        # Environment reward
        t = (task_list_[i] if task_list_[i] else "maximize_wellbeing")
        c = (country_[i]   if country_[i]   else "USA")
        env_r.reset(t, c)
        _, env_rew, _, _ = env_r.step(action)
        # Format reward: all 5 keys present
        keys_found = sum(
            1 for k in ["tax", "jobs", "healthcare", "education", "infrastructure"]
            if re.search(rf"{k}\\s*:\\s*\\d", text)
        )
        fmt_bonus = (keys_found / 5.0) * 0.15  # up to +0.15
        # Diversity bonus: penalise all-same values (lazy policy)
        vals = list(action.values())
        div_bonus = float(np.std(vals)) * 0.1  # up to ~+0.05
        total = float(env_rew) + fmt_bonus + div_bonus
        rewards.append(round(total, 5))
    return rewards
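
# Illustrative call with a plain-string completion (TRL may instead pass
# chat-style message lists, which the branch above also handles):
_demo_r = civic_reward_advanced(
    prompts=["<prompt is not used by the scorer>"],
    completions=["tax: 0.3, jobs: 0.7, healthcare: 0.6, education: 0.6, infrastructure: 0.4"],
    task=["stabilize_economy"], country=["USA"],
)
console.log(f"[dim]Demo reward: {_demo_r[0]:.4f}")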

# ── CELL 10: GRPO CONFIG (VERSION-SAFE) ──────────────────────────────────────
from trl import GRPOConfig, GRPOTrainer

valid_params = set(inspect.signature(GRPOConfig.__init__).parameters)
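# Version-safe pattern: GRPOConfig's accepted fields shift between TRL
# releases, so we keep one kwargs dict and drop anything this installed
# version's __init__ signature does not declare (skipped keys logged below).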
all_kwargs = {
    "output_dir"                  : "checkpoints/civicai-grpo",
    "num_train_epochs"            : 3,
    "per_device_train_batch_size" : 2,
    "num_generations"             : 2,
    "max_prompt_length"           : 300,
    "max_completion_length"       : 80,
    "learning_rate"               : 5e-6,
    "logging_steps"               : 5,
    "save_strategy"               : "epoch",
    "save_total_limit"            : 2,
    "report_to"                   : "none",
    "remove_unused_columns"       : False,
    "bf16"                        : USE_BF16,
    "fp16"                        : USE_FP16,
    "gradient_accumulation_steps" : 4,
    "max_grad_norm"               : 0.3,
    "warmup_ratio"                : 0.05,
    "lr_scheduler_type"           : "cosine",
    "dataloader_num_workers"      : 0,
}
safe_kwargs = {k: v for k, v in all_kwargs.items() if k in valid_params}
skipped = set(all_kwargs) - set(safe_kwargs)
if skipped:
    console.log(f"[yellow]Skipped unsupported GRPOConfig args: {skipped}")
grpo_config = GRPOConfig(**safe_kwargs)
trainer = GRPOTrainer(
    model            = model,
    args             = grpo_config,
    reward_funcs     = civic_reward_advanced,
    train_dataset    = train_dataset,
    processing_class = tokenizer,
)
console.log("[green]✓ GRPOTrainer initialised with LoRA + multi-objective reward")

# ── CELL 11: TRAINING ────────────────────────────────────────────────────────
console.rule("[bold cyan]Starting GRPO Training")
start_time = time.time()
trainer.train()
elapsed = time.time() - start_time
console.rule(f"[bold green]Training Complete – {elapsed/60:.1f} min")

# ── CELL 12: EXTRACT & PLOT TRAINING METRICS ─────────────────────────────────
logs = trainer.state.log_history
df_logs = pd.DataFrame(logs)
if "loss" in df_logs.columns:
    df_logs = df_logs.dropna(subset=["loss"])
reward_entries = [e for e in logs if "reward" in e]
rewards_logged = [e["reward"] for e in reward_entries]
steps_logged   = [e.get("step", i) for i, e in enumerate(reward_entries)]

fig = make_subplots(rows=1, cols=2,
                    subplot_titles=("Reward Curve", "Reward Distribution"))
# Reward over steps
fig.add_trace(go.Scatter(
    x=steps_logged, y=rewards_logged,
    mode="lines", name="Reward", line=dict(color="#00d4ff", width=2)
), row=1, col=1)
# Smoothed
if len(rewards_logged) > 5:
    smooth = np.convolve(rewards_logged, np.ones(5)/5, mode="valid")
    fig.add_trace(go.Scatter(
        x=steps_logged[4:], y=smooth,
        mode="lines", name="Smoothed",
        line=dict(color="#ff6b6b", width=2, dash="dash")
    ), row=1, col=1)
# Histogram
fig.add_trace(go.Histogram(
    x=rewards_logged, nbinsx=20,
    marker_color="#00d4ff", opacity=0.75, name="Distribution"
), row=1, col=2)
fig.update_layout(
    title="CivicAI GRPO Training Metrics",
    template="plotly_dark", height=420,
    paper_bgcolor="#0d1117", font=dict(color="#e6edf3"),
)
fig.show()
fig.write_html("assets/training_metrics.html")

if rewards_logged:
    console.print(f"[cyan]Start reward : {rewards_logged[0]:.4f}")
    console.print(f"[cyan]Final reward : {rewards_logged[-1]:.4f}")
    console.print(f"[green]Improvement  : {rewards_logged[-1]-rewards_logged[0]:+.4f}")

# ── CELL 13: MULTI-COUNTRY POLICY EVALUATION ─────────────────────────────────
def evaluate_trained_policy(
    model, tokenizer, fetcher,
    countries: List[str] = None,
    tasks: List[str] = None,
    episodes: int = 5,
) -> pd.DataFrame:
    \"\"\"Evaluate trained policy on all countries × all tasks.\"\"\"
    countries = countries or list(RealTimeDataFetcher.COUNTRIES.keys())
    tasks     = tasks or list(AdvancedCivicAIEnv.TASKS.keys())
    model.eval()
    results = []
    for country in countries:
        for task in tasks:
            ep_rewards = []
            env_eval = AdvancedCivicAIEnv(fetcher, default_country=country)
            for _ in range(episodes):
                obs = env_eval.reset(task, country)
                ep_reward = 0.0
                for _ in range(20):
                    prompt = build_prompt(obs)
                    inputs = tokenizer(prompt, return_tensors="pt",
                                       truncation=True, max_length=300).to(DEVICE)
                    with torch.no_grad():
                        out = model.generate(
                            **inputs, max_new_tokens=60,
                            do_sample=False,
                            pad_token_id=tokenizer.eos_token_id,
                        )
                    gen_tokens = out[0][inputs["input_ids"].shape[1]:]
                    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                    action = parse_action(text)
                    obs, r, done, _ = env_eval.step(action)
                    ep_reward += r
                    if done:
                        break
                ep_rewards.append(ep_reward / 20)
            results.append({
                "country": RealTimeDataFetcher.COUNTRIES[country]["name"],
                "task"   : task,
                "mean_r" : round(float(np.mean(ep_rewards)), 4),
                "std_r"  : round(float(np.std(ep_rewards)), 4),
                "max_r"  : round(float(np.max(ep_rewards)), 4),
            })
            console.log(f"[dim]{country} / {task} → {results[-1]['mean_r']:.4f}")
    return pd.DataFrame(results)

def baseline_score(fetcher, episodes=5):
    \"\"\"Fixed 0.5 policy baseline.\"\"\"
    env_b, total = AdvancedCivicAIEnv(fetcher, "USA"), []
    for _ in range(episodes):
        obs = env_b.reset("maximize_wellbeing", "USA")
        r = 0.0
        for _ in range(20):
            obs, rew, done, _ = env_b.step(
                {k: 0.5 for k in ["tax", "jobs", "healthcare", "education", "infrastructure"]}
            )
            r += rew
        total.append(r / 20)
    return float(np.mean(total))
| console.rule("[bold cyan]Evaluating Policy Across Countries & Tasks") | |
| df_eval = evaluate_trained_policy(model, tokenizer, fetcher, episodes=3) | |
| baseline = baseline_score(fetcher) | |
| console.print(df_eval.to_string(index=False)) | |
| console.print(f"\\n[bold]Baseline (fixed 0.5) : {baseline:.4f}") | |
| console.print(f"[bold green]Best trained score : {df_eval['mean_r'].max():.4f}") | |
| # ββ CELL 14: EVALUATION HEATMAP ββββββββββββββββββββββββββββββββββββββββββββββ | |
| pivot = df_eval.pivot(index="country", columns="task", values="mean_r") | |
| fig_heat = go.Figure(go.Heatmap( | |
| z = pivot.values, | |
| x = pivot.columns.tolist(), | |
| y = pivot.index.tolist(), | |
| colorscale = "RdYlGn", | |
| text = np.round(pivot.values, 3), | |
| texttemplate="%{text}", | |
| showscale = True, | |
| zmin=0.4, zmax=1.0, | |
| )) | |
| fig_heat.add_shape( | |
| type="line", x0=-0.5, x1=len(pivot.columns)-0.5, | |
| y0=-0.5, y1=len(pivot.index)-0.5, | |
| line=dict(color="white", width=0) | |
| ) | |
| fig_heat.update_layout( | |
| title = "Policy Performance Heatmap β Country Γ Task (GRPO Trained)", | |
| template="plotly_dark", height=400, | |
| paper_bgcolor="#0d1117", font=dict(color="#e6edf3"), | |
| xaxis_title="Task", yaxis_title="Country", | |
| ) | |
| fig_heat.show() | |
| fig_heat.write_html("assets/eval_heatmap.html") | |

# ── CELL 15: SAVE EVERYTHING ─────────────────────────────────────────────────
# Save LoRA adapter only (lightweight)
model.save_pretrained("checkpoints/civicai-lora")
tokenizer.save_pretrained("checkpoints/civicai-lora")

# Save results JSON
results_json = {
    "run_timestamp"   : datetime.now().isoformat(),
    "model"           : MODEL_NAME,
    "lora_rank"       : 8,
    "training_epochs" : 3,
    "num_countries"   : len(RealTimeDataFetcher.COUNTRIES),
    "num_tasks"       : len(AdvancedCivicAIEnv.TASKS),
    "data_source"     : "World Bank Open API (live)",
    "baseline_reward" : round(baseline, 4),
    "best_reward"     : round(float(df_eval["mean_r"].max()), 4),
    "improvement"     : round(float(df_eval["mean_r"].max()) - baseline, 4),
    "reward_history"  : rewards_logged,
    "eval_by_country_task": df_eval.to_dict(orient="records"),
    "real_data_snapshot"  : df_world[["country", "inflation_pct", "unemployment_pct",
                                      "health_exp_gdp", "gdp_growth"]].to_dict(orient="records"),
}
with open("assets/training_results.json", "w") as f:
    json.dump(results_json, f, indent=2)

console.rule("[bold green]All Done")
console.print("[green]✓ LoRA checkpoint  → checkpoints/civicai-lora/")
console.print("[green]✓ Results JSON     → assets/training_results.json")
console.print("[green]✓ Dashboard HTML   → assets/global_dashboard.html")
console.print("[green]✓ Training metrics → assets/training_metrics.html")
console.print("[green]✓ Eval heatmap     → assets/eval_heatmap.html")
console.print(f"\\n[bold cyan]Baseline  : {baseline:.4f}")
console.print(f"[bold green]Best score: {df_eval['mean_r'].max():.4f}")
console.print(f"[bold green]Delta     : {df_eval['mean_r'].max() - baseline:+.4f}")
| """ | |
| cells = [] | |
| # Create a title cell | |
| cells.append({ | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# π CivicAI Advanced β Senior ML Engineer Edition\\n", | |
| "**Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard**" | |
| ] | |
| }) | |
| # Split the code by cells | |
| chunks = re.split(r'# ββ CELL \d+.*?\n', code) | |
| headers = re.findall(r'# ββ CELL \d+.*?$', code, re.MULTILINE) | |
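# Pairing assumed here: re.split returns [preamble, body_1, ..., body_n] while
# re.findall returns [header_1, ..., header_n] in the same order, so
# headers[idx] titles chunks[idx + 1]; the banner before CELL 1 is dropped.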
# The first chunk is everything before CELL 1
if len(chunks) > 1:
    for idx, chunk in enumerate(chunks[1:]):
        header_text = headers[idx]
        # Strip the leading "#" and the ─ rules, leaving just "CELL n: TITLE"
        title = re.sub(r'[#─]', '', header_text).strip()
        cells.append({
            "cell_type": "markdown",
            "metadata": {},
            "source": [f"### {title}"]
        })
        # Remove trailing and leading newlines
        chunk = chunk.strip()
        # The pip-install block is wrapped in a docstring; strip the quotes
        if "pip install" in chunk and '"""' in chunk:
            chunk = chunk.replace('"""', '').strip()
        cells.append({
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "outputs": [],
            "source": [line + "\n" for line in chunk.split('\n')]
        })

notebook = {
    "cells": cells,
    "metadata": {
        "colab": {"name": "CivicAI_Training.ipynb"},
        "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
        "language_info": {"name": "python", "version": "3.10"}
    },
    "nbformat": 4,
    "nbformat_minor": 4
}
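
# Optional sanity check before writing (assumes the nbformat package, which is
# not among the deps installed above):
#   import nbformat
#   nbformat.validate(nbformat.from_dict(notebook))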
with open("c:/Users/mdaft/OneDrive/Desktop/GitHub Projects/AI_Society_Simulator/CivicAI_Training.ipynb", "w", encoding="utf-8") as f:
    json.dump(notebook, f, indent=2)