mahammadaftab committed on
Commit
a36e07f
·
1 Parent(s): 21b7a4e

Add Hugging Face YAML metadata

Browse files
Files changed (2) hide show
  1. build_notebook.py +840 -0
  2. dashboard/app.js +53 -28
build_notebook.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+ code = """
5
+ # =============================================================================
6
+ # CivicAI Advanced — Senior ML Engineer Edition
7
+ # Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard
8
+ # =============================================================================
9
+
10
+ # ── CELL 1: INSTALL DEPENDENCIES ─────────────────────────────────────────────
11
+ \"\"\"
12
+ !pip install -q \\
13
+ "transformers>=4.38" \\
14
+ "accelerate>=0.27" \\
15
+ "trl>=0.10" \\
16
+ "peft>=0.9" \\
17
+ "bitsandbytes>=0.42" \\
18
+ "datasets>=2.17" \\
19
+ "requests>=2.31" \\
20
+ "pandas>=2.0" \\
21
+ "fredapi" \\
22
+ "world-bank-data" \\
23
+ "plotly>=5.18" \\
24
+ "rich>=13.0" \\
25
+ "tenacity>=8.2"
26
+
27
+ # After install: Runtime → Restart session → run from Cell 2
28
+ \"\"\"
29
+
30
+ # ── CELL 2: IMPORTS & SYSTEM SETUP ───────────────────────────────────────────
31
+ import os, re, json, math, time, random, inspect, warnings, logging
32
+ from datetime import datetime
33
+ from typing import Dict, List, Optional, Tuple
34
+ from pathlib import Path
35
+
36
+ import numpy as np
37
+ import pandas as pd
38
+ import torch
39
+ import requests
40
+ import plotly.graph_objects as go
41
+ import plotly.express as px
42
+ from plotly.subplots import make_subplots
43
+ from tenacity import retry, stop_after_attempt, wait_exponential
44
+ from rich.console import Console
45
+ from rich.table import Table
46
+ from rich.progress import Progress, SpinnerColumn, TextColumn
47
+ from rich import print as rprint
48
+
49
+ warnings.filterwarnings("ignore")
50
+ logging.basicConfig(level=logging.ERROR)
51
+
52
+ console = Console()
53
+
54
+ # ── Hardware detection ────────────────────────────────────────────────────────
55
+ CUDA_OK = torch.cuda.is_available()
56
+ if CUDA_OK:
57
+ CAP = torch.cuda.get_device_capability()
58
+ USE_BF16 = CAP[0] >= 8
59
+ USE_FP16 = not USE_BF16
60
+ GPU_NAME = torch.cuda.get_device_name(0)
61
+ else:
62
+ USE_BF16 = USE_FP16 = False
63
+ GPU_NAME = "CPU"
64
+
65
+ DEVICE = "cuda" if CUDA_OK else "cpu"
66
+
67
+ # ── Paths ─────────────────────────────────────────────────────────────────────
68
+ Path("assets").mkdir(exist_ok=True)
69
+ Path("checkpoints").mkdir(exist_ok=True)
70
+ Path("logs").mkdir(exist_ok=True)
71
+
72
+ console.rule("[bold cyan]CivicAI Advanced — System Ready")
73
+ table = Table(show_header=False, box=None)
74
+ table.add_row("[cyan]PyTorch", torch.__version__)
75
+ table.add_row("[cyan]Device", GPU_NAME)
76
+ table.add_row("[cyan]BF16/FP16", f"bf16={USE_BF16} fp16={USE_FP16}")
77
+ table.add_row("[cyan]Timestamp", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
78
+ console.print(table)
79
+
80
+
81
+ # ── CELL 3: REAL-TIME DATA FETCHER ───────────────────────────────────────────
82
+ class RealTimeDataFetcher:
83
+ \"\"\"
84
+ Fetches live economic indicators from:
85
+ • World Bank Open API (no key required)
86
+ • FRED / St. Louis Fed (free key via fredapi)
87
+ • BLS (Bureau of Labor Statistics — no key for basic tier)
88
+ • REST Countries (social / governance proxies)
89
+ Falls back to realistic historical means if any API is unavailable.
90
+ \"\"\"
91
+
92
+ WORLD_BANK_BASE = "https://api.worldbank.org/v2"
93
+ BLS_BASE = "https://api.bls.gov/publicAPI/v1/timeseries/data"
94
+ REST_COUNTRIES = "https://restcountries.com/v3.1/alpha"
95
+
96
+ # World Bank indicator codes
97
+ WB_INDICATORS = {
98
+ "inflation" : "FP.CPI.TOTL.ZG", # CPI inflation %
99
+ "unemployment": "SL.UEM.TOTL.ZS", # Unemployment % of labour force
100
+ "health_exp" : "SH.XPD.CHEX.GD.ZS",# Health expenditure % of GDP
101
+ "life_expect" : "SP.DYN.LE00.IN", # Life expectancy at birth
102
+ "gdp_growth" : "NY.GDP.MKTP.KD.ZG", # GDP growth %
103
+ "homicide" : "VC.IHR.PSRC.P5", # Intentional homicides per 100k
104
+ }
105
+
106
+ # Country ISO codes supported
107
+ COUNTRIES = {
108
+ "USA": {"iso2": "US", "iso3": "USA", "name": "United States"},
109
+ "IND": {"iso2": "IN", "iso3": "IND", "name": "India"},
110
+ "GBR": {"iso2": "GB", "iso3": "GBR", "name": "United Kingdom"},
111
+ "DEU": {"iso2": "DE", "iso3": "DEU", "name": "Germany"},
112
+ "JPN": {"iso2": "JP", "iso3": "JPN", "name": "Japan"},
113
+ "BRA": {"iso2": "BR", "iso3": "BRA", "name": "Brazil"},
114
+ }
115
+
116
+ # Realistic fallback values (5-year historical means, 2019-2023)
117
+ FALLBACKS = {
118
+ "USA": {"inflation":3.8,"unemployment":4.8,"health_exp":17.2,"life_expect":77.5,"gdp_growth":2.1,"homicide":6.5},
119
+ "IND": {"inflation":5.5,"unemployment":7.2,"health_exp":3.3, "life_expect":69.4,"gdp_growth":5.8,"homicide":2.8},
120
+ "GBR": {"inflation":3.2,"unemployment":4.2,"health_exp":10.9,"life_expect":80.4,"gdp_growth":1.4,"homicide":1.2},
121
+ "DEU": {"inflation":2.8,"unemployment":3.5,"health_exp":12.8,"life_expect":80.6,"gdp_growth":0.9,"homicide":0.9},
122
+ "JPN": {"inflation":1.2,"unemployment":2.8,"health_exp":10.9,"life_expect":84.3,"gdp_growth":0.7,"homicide":0.2},
123
+ "BRA": {"inflation":6.9,"unemployment":11.0,"health_exp":9.9,"life_expect":75.5,"gdp_growth":1.2,"homicide":22.4},
124
+ }
125
+
126
+ def __init__(self, cache_ttl_seconds: int = 3600):
127
+ self._cache: Dict[str, Tuple[float, dict]] = {}
128
+ self.cache_ttl = cache_ttl_seconds
129
+ self.session = requests.Session()
130
+ self.session.headers.update({"User-Agent": "CivicAI/2.0"})
131
+
132
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
133
+ def _wb_fetch(self, country_iso2: str, indicator: str) -> Optional[float]:
134
+ \"\"\"Fetch latest non-null value from World Bank API.\"\"\"
135
+ url = (f"{self.WORLD_BANK_BASE}/country/{country_iso2}/indicator/{indicator}"
136
+ f"?format=json&mrv=5&per_page=5")
137
+ r = self.session.get(url, timeout=10)
138
+ r.raise_for_status()
139
+ data = r.json()
140
+ if len(data) < 2 or not data[1]:
141
+ return None
142
+ for entry in data[1]:
143
+ if entry.get("value") is not None:
144
+ return float(entry["value"])
145
+ return None
146
+
147
+ def fetch_country(self, country_code: str = "USA") -> dict:
148
+ \"\"\"
149
+ Returns normalised economic state for a country.
150
+ Uses cache → World Bank API → fallback in that order.
151
+ \"\"\"
152
+ cache_key = f"{country_code}_{int(time.time() // self.cache_ttl)}"
153
+ if cache_key in self._cache:
154
+ return self._cache[cache_key]
155
+
156
+ meta = self.COUNTRIES.get(country_code, self.COUNTRIES["USA"])
157
+ iso2 = meta["iso2"]
158
+ raw = {}
159
+
160
+ with Progress(SpinnerColumn(), TextColumn("[cyan]Fetching {task.description}"),
161
+ transient=True) as prog:
162
+ t = prog.add_task(f"live data for {meta['name']}")
163
+ for key, indicator in self.WB_INDICATORS.items():
164
+ try:
165
+ val = self._wb_fetch(iso2, indicator)
166
+ raw[key] = val if val is not None else self.FALLBACKS[country_code][key]
167
+ except Exception:
168
+ raw[key] = self.FALLBACKS[country_code][key]
169
+
170
+ # ── Normalise to [0, 1] for the RL environment ────────────────────
171
+ state = {
172
+ # Lower inflation = better; 0 %→1.0, ≥15 %→0.0
173
+ "inflation" : max(0.0, min(1.0, 1 - raw["inflation"] / 15.0)),
174
+ # Higher employment = better
175
+ "employment" : max(0.0, min(1.0, 1 - raw["unemployment"]/ 25.0)),
176
+ # Higher health expenditure + life expectancy = better
177
+ "health" : max(0.0, min(1.0, (raw["health_exp"] / 20.0 +
178
+ raw["life_expect"] / 90.0) / 2)),
179
+ # GDP growth proxy for satisfaction
180
+ "satisfaction": max(0.0, min(1.0, (raw["gdp_growth"] + 5) / 15.0)),
181
+ # Lower homicide = better
182
+ "crime" : max(0.0, min(1.0, 1 - raw["homicide"] / 50.0)),
183
+ }
184
+
185
+ # Attach raw for reporting
186
+ state["_raw"] = raw
187
+ state["_country"]= meta["name"]
188
+ state["_fetched"]= datetime.now().isoformat()
189
+
190
+ self._cache[cache_key] = state
191
+ return state
192
+
193
+ def fetch_all_countries(self) -> Dict[str, dict]:
194
+ results = {}
195
+ for code in self.COUNTRIES:
196
+ console.log(f"[dim]→ fetching {code}")
197
+ results[code] = self.fetch_country(code)
198
+ return results
199
+
200
+ def to_dataframe(self, all_data: Dict[str, dict]) -> pd.DataFrame:
201
+ rows = []
202
+ for code, state in all_data.items():
203
+ raw = state.get("_raw", {})
204
+ rows.append({
205
+ "country" : state.get("_country", code),
206
+ "code" : code,
207
+ "inflation_pct": raw.get("inflation", 0),
208
+ "unemployment_pct": raw.get("unemployment", 0),
209
+ "health_exp_gdp": raw.get("health_exp", 0),
210
+ "life_expect" : raw.get("life_expect", 0),
211
+ "gdp_growth" : raw.get("gdp_growth", 0),
212
+ "homicide_rate" : raw.get("homicide", 0),
213
+ # normalised
214
+ "norm_inflation" : state["inflation"],
215
+ "norm_employment" : state["employment"],
216
+ "norm_health" : state["health"],
217
+ "norm_satisfaction": state["satisfaction"],
218
+ "norm_crime" : state["crime"],
219
+ "fetched_at" : state.get("_fetched"),
220
+ })
221
+ return pd.DataFrame(rows)
222
+
223
+
224
+ # Instantiate & fetch
225
+ fetcher = RealTimeDataFetcher(cache_ttl_seconds=3600)
226
+ all_data = fetcher.fetch_all_countries()
227
+ df_world = fetcher.to_dataframe(all_data)
228
+
229
+ console.rule("[bold green]Live Data Fetched")
230
+ console.print(df_world[["country","inflation_pct","unemployment_pct",
231
+ "health_exp_gdp","gdp_growth"]].to_string(index=False))
232
+
233
+
234
+ # ── CELL 4: REAL-DATA DASHBOARD ──────────────────────────────────────────────
235
+ def plot_global_dashboard(df: pd.DataFrame) -> None:
236
+ fig = make_subplots(
237
+ rows=2, cols=3,
238
+ subplot_titles=(
239
+ "Inflation (%)", "Unemployment (%)", "Health Exp (% GDP)",
240
+ "Life Expectancy (yrs)", "GDP Growth (%)", "Homicide Rate (per 100k)"
241
+ ),
242
+ )
243
+ cols_raw = ["inflation_pct","unemployment_pct","health_exp_gdp",
244
+ "life_expect","gdp_growth","homicide_rate"]
245
+ colors = px.colors.qualitative.Bold
246
+
247
+ for i, col in enumerate(cols_raw):
248
+ r, c = divmod(i, 3)
249
+ fig.add_trace(
250
+ go.Bar(
251
+ x=df["country"], y=df[col],
252
+ marker_color=colors,
253
+ showlegend=False,
254
+ text=df[col].round(1), textposition="outside"
255
+ ),
256
+ row=r+1, col=c+1
257
+ )
258
+
259
+ fig.update_layout(
260
+ title_text="🌍 CivicAI — Real-Time Global Economic Dashboard",
261
+ title_font_size=20,
262
+ height=600, template="plotly_dark",
263
+ paper_bgcolor="#0d1117", plot_bgcolor="#0d1117",
264
+ font=dict(color="#e6edf3"),
265
+ )
266
+ fig.show()
267
+ fig.write_html("assets/global_dashboard.html")
268
+ console.log("[green]✓ Dashboard saved → assets/global_dashboard.html")
269
+
270
+ plot_global_dashboard(df_world)
271
+
272
+
273
+ # ── CELL 5: ADVANCED MULTI-COUNTRY ENVIRONMENT ───────────────────────────────
274
+ class AdvancedCivicAIEnv:
275
+ \"\"\"
276
+ Production-grade multi-country civic environment.
277
+ • Initialises from real World Bank data
278
+ • Supports 6 countries and 4 policy tasks
279
+ • Action space: 5-dimensional continuous [0,1]
280
+ • Observation: 10-dimensional (5 state + 5 delta from last step)
281
+ • Reward: weighted multi-objective (Pareto-style)
282
+ • Includes shock events (recession, pandemic proxy, crime spike)
283
+ \"\"\"
284
+
285
+ TASKS = {
286
+ "stabilize_economy" : {"inflation_weight":0.4, "employment_weight":0.3, "health_weight":0.15, "satisfaction_weight":0.1, "crime_weight":0.05},
287
+ "improve_health" : {"inflation_weight":0.1, "employment_weight":0.2, "health_weight":0.5, "satisfaction_weight":0.15,"crime_weight":0.05},
288
+ "reduce_crime" : {"inflation_weight":0.1, "employment_weight":0.2, "health_weight":0.2, "satisfaction_weight":0.1, "crime_weight":0.4},
289
+ "maximize_wellbeing" : {"inflation_weight":0.2, "employment_weight":0.2, "health_weight":0.2, "satisfaction_weight":0.2, "crime_weight":0.2},
290
+ }
291
+
292
+ SHOCK_EVENTS = [
293
+ {"name":"recession", "prob":0.02, "effect":{"inflation":+0.15,"employment":-0.12,"satisfaction":-0.1}},
294
+ {"name":"pandemic", "prob":0.01, "effect":{"health":-0.2, "employment":-0.1, "satisfaction":-0.15}},
295
+ {"name":"crime_spike","prob":0.02, "effect":{"crime":-0.15, "satisfaction":-0.08}},
296
+ {"name":"boom", "prob":0.02, "effect":{"employment":+0.1,"satisfaction":+0.1,"inflation":+0.05}},
297
+ ]
298
+
299
+ def __init__(self, fetcher: RealTimeDataFetcher, default_country: str = "USA"):
300
+ self.fetcher = fetcher
301
+ self.default_country = default_country
302
+ self._prev_state = None
303
+ self.step_count = 0
304
+ self.shock_log = []
305
+ self.state_data = {}
306
+
307
+ def reset(self, task_id: str = "stabilize_economy", country: str = None) -> dict:
308
+ country = country or self.default_country
309
+ self.task_id = task_id
310
+ self.weights = self.TASKS[task_id]
311
+ self.step_count = 0
312
+ self.shock_log = []
313
+
314
+ # Load real data as starting state
315
+ live = self.fetcher.fetch_country(country)
316
+ self.state_data = {k: live[k] for k in ["inflation","employment","health","satisfaction","crime"]}
317
+
318
+ # Add small noise so each episode is unique
319
+ for k in self.state_data:
320
+ self.state_data[k] = float(np.clip(
321
+ self.state_data[k] + np.random.normal(0, 0.02), 0.0, 1.0
322
+ ))
323
+
324
+ self._prev_state = dict(self.state_data)
325
+ return self._build_obs()
326
+
327
+ def _build_obs(self) -> dict:
328
+ \"\"\"10-dim observation: current state + delta from previous step.\"\"\"
329
+ obs = dict(self.state_data)
330
+ obs["_task"] = self.task_id
331
+ obs["_step"] = self.step_count
332
+ if self._prev_state:
333
+ for k in ["inflation","employment","health","satisfaction","crime"]:
334
+ obs[f"d_{k}"] = self.state_data[k] - self._prev_state[k]
335
+ else:
336
+ for k in ["inflation","employment","health","satisfaction","crime"]:
337
+ obs[f"d_{k}"] = 0.0
338
+ return obs
339
+
340
+ def _apply_shocks(self):
341
+ \"\"\"Stochastic external shock events.\"\"\"
342
+ for shock in self.SHOCK_EVENTS:
343
+ if np.random.random() < shock["prob"]:
344
+ self.shock_log.append({"step": self.step_count, "event": shock["name"]})
345
+ for k, delta in shock["effect"].items():
346
+ if k in self.state_data:
347
+ self.state_data[k] = float(np.clip(self.state_data[k] + delta, 0.0, 1.0))
348
+ console.log(f"[yellow]⚡ Shock event: {shock['name']} at step {self.step_count}")
349
+
350
+ def step(self, action: dict) -> Tuple[dict, float, bool, dict]:
351
+ \"\"\"
352
+ action keys: tax, jobs, healthcare, education, infrastructure
353
+ Each in [0, 1] — represents budget allocation intensity.
354
+ \"\"\"
355
+ self._prev_state = dict(self.state_data)
356
+
357
+ # Policy effects (each term is linear in the action intensity; no sqrt is applied)
358
+ tax = action.get("tax", 0.5)
359
+ jobs = action.get("jobs", 0.5)
360
+ healthcare = action.get("healthcare", 0.5)
361
+ education = action.get("education", 0.5)
362
+ infra = action.get("infrastructure",0.5)
363
+
364
+ self.state_data["inflation"] = np.clip(
365
+ self.state_data["inflation"] - tax * 0.08 + jobs * 0.02, 0.0, 1.0)
366
+ self.state_data["employment"] = np.clip(
367
+ self.state_data["employment"] + jobs * 0.06 + infra * 0.02, 0.0, 1.0)
368
+ self.state_data["health"] = np.clip(
369
+ self.state_data["health"] + healthcare * 0.07 + education * 0.02, 0.0, 1.0)
370
+ self.state_data["satisfaction"] = np.clip(
371
+ self.state_data["satisfaction"] + education * 0.05 + infra * 0.03
372
+ - tax * 0.03, 0.0, 1.0)
373
+ self.state_data["crime"] = np.clip(
374
+ self.state_data["crime"] + education * 0.05 + jobs * 0.03
375
+ - infra * 0.01, 0.0, 1.0)
376
+
377
+ # Gaussian noise
378
+ for k in self.state_data:
379
+ self.state_data[k] = float(np.clip(
380
+ self.state_data[k] + np.random.normal(0, 0.008), 0.0, 1.0))
381
+
382
+ self._apply_shocks()
383
+ self.step_count += 1
384
+
385
+ reward = self._compute_reward()
386
+ done = self.step_count >= 50
387
+ info = {"shocks": self.shock_log, "step": self.step_count}
388
+ return self._build_obs(), float(reward), done, info
389
+
390
+ def _compute_reward(self) -> float:
391
+ s = self.state_data
392
+ w = self.weights
393
+ return (
394
+ w["inflation_weight"] * s["inflation"] +
395
+ w["employment_weight"] * s["employment"] +
396
+ w["health_weight"] * s["health"] +
397
+ w["satisfaction_weight"] * s["satisfaction"] +
398
+ w["crime_weight"] * s["crime"]
399
+ )
400
+
401
+ def state_report(self) -> dict:
402
+ return {k: round(v, 4) for k, v in self.state_data.items()}
403
+
404
+
405
+ # Smoke test
406
+ env_adv = AdvancedCivicAIEnv(fetcher, default_country="USA")
407
+ obs = env_adv.reset("stabilize_economy", "USA")
408
+ console.rule("[bold green]Advanced Environment Ready")
409
+ console.print(f"Initial state (USA, real data): {env_adv.state_report()}")
410
+
411
+
412
+ # ── CELL 6: PROMPT BUILDER (5-action) ────────────────────────────────────────
413
+ def build_prompt(obs: dict) -> str:
414
+ task_desc = {
415
+ "stabilize_economy" : "Your priority is economic stability: control inflation and protect employment.",
416
+ "improve_health" : "Your priority is public health: maximize health outcomes and life expectancy.",
417
+ "reduce_crime" : "Your priority is public safety: reduce crime through investment and employment.",
418
+ "maximize_wellbeing" : "Your priority is overall citizen wellbeing across all dimensions.",
419
+ }.get(obs.get("_task",""), "Optimize all civic outcomes.")
420
+
421
+ return (
422
+ f"You are a senior policy advisor.\\n{task_desc}\\n\\n"
423
+ f"CURRENT STATE (step {obs.get('_step',0)}):\\n"
424
+ f" Inflation score : {obs.get('inflation',0.5):.3f} (Δ {obs.get('d_inflation',0):+.3f})\\n"
425
+ f" Employment score : {obs.get('employment',0.5):.3f} (Δ {obs.get('d_employment',0):+.3f})\\n"
426
+ f" Health score : {obs.get('health',0.5):.3f} (Δ {obs.get('d_health',0):+.3f})\\n"
427
+ f" Satisfaction score: {obs.get('satisfaction',0.5):.3f} (Δ {obs.get('d_satisfaction',0):+.3f})\\n"
428
+ f" Crime score : {obs.get('crime',0.5):.3f} (Δ {obs.get('d_crime',0):+.3f})\\n\\n"
429
+ "OUTPUT FORMAT (all values 0.0–1.0, no other text):\\n"
430
+ "tax: 0.X, jobs: 0.X, healthcare: 0.X, education: 0.X, infrastructure: 0.X"
431
+ )
432
+
433
+
434
+ def parse_action(text: str) -> dict:
435
+ \"\"\"5-dimensional action parser with robust regex.\"\"\"
436
+ keys = ["tax", "jobs", "healthcare", "education", "infrastructure"]
437
+
438
+ def extract(key: str) -> float:
439
+ m = re.search(rf"{key}\\s*:\\s*(\\d*\\.?\\d+)", text)
440
+ if m:
441
+ try:
442
+ return float(np.clip(float(m.group(1)), 0.0, 1.0))
443
+ except ValueError:
444
+ pass
445
+ return 0.5
446
+
447
+ return {k: extract(k) for k in keys}
448
+
449
+
450
+ # ── CELL 7: LOAD MODEL WITH LoRA ─────────────────────────────────────────────
451
+ from transformers import AutoTokenizer, AutoModelForCausalLM
452
+ from peft import LoraConfig, get_peft_model, TaskType
453
+
454
+ MODEL_NAME = "gpt2" # swap to "gpt2-medium" or "distilgpt2" as needed
455
+
456
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
457
+ tokenizer.pad_token = tokenizer.eos_token
458
+ tokenizer.padding_side = "left"
459
+
460
+ base_model = AutoModelForCausalLM.from_pretrained(
461
+ MODEL_NAME,
462
+ torch_dtype = torch.bfloat16 if USE_BF16 else torch.float32,
463
+ )
464
+
465
+ # ── Attach LoRA adapters (reduces trainable params by ~90%) ──────────────────
466
+ lora_cfg = LoraConfig(
467
+ task_type = TaskType.CAUSAL_LM,
468
+ r = 8, # rank
469
+ lora_alpha = 32,
470
+ target_modules = ["c_attn"], # GPT-2 attention projection
471
+ lora_dropout = 0.05,
472
+ bias = "none",
473
+ )
474
+ model = get_peft_model(base_model, lora_cfg)
475
+ model.print_trainable_parameters()
476
+ model = model.to(DEVICE)
477
+
478
+ console.rule("[bold green]Model Ready")
479
+ console.log(f"[cyan]Model : {MODEL_NAME} + LoRA (r=8)")
480
+ console.log(f"[cyan]Parameters : {sum(p.numel() for p in model.parameters())/1e6:.1f}M total")
481
+
482
+
483
+ # ── CELL 8: BUILD TRAINING DATASET FROM REAL DATA ────────────────────────────
484
+ from datasets import Dataset
485
+
486
+ NUM_SAMPLES = 300
487
+ records = []
488
+ env_tmp = AdvancedCivicAIEnv(fetcher)
489
+ task_list = list(AdvancedCivicAIEnv.TASKS.keys())
490
+ country_list= list(RealTimeDataFetcher.COUNTRIES.keys())
491
+
492
+ for i in range(NUM_SAMPLES):
493
+ task = task_list[i % len(task_list)]
494
+ country = country_list[i % len(country_list)]
495
+ obs = env_tmp.reset(task, country)
496
+ records.append({
497
+ "prompt" : build_prompt(obs),
498
+ "task" : task,
499
+ "country" : country,
500
+ })
501
+
502
+ train_dataset = Dataset.from_list(records)
503
+ console.log(f"[green]✓ Dataset: {len(train_dataset)} prompts across "
504
+ f"{len(task_list)} tasks × {len(country_list)} countries")
505
+ console.log(f" Sample:\\n{train_dataset[0]['prompt'][:300]}...")
506
+
507
+
508
+ # ── CELL 9: MULTI-OBJECTIVE REWARD FUNCTION ───────────────────────────────────
509
+ def civic_reward_advanced(prompts, completions, task=None, country=None, **kwargs) -> List[float]:
510
+ \"\"\"
511
+ Multi-objective GRPO reward function.
512
+ Scores: environment reward + format compliance + consistency bonus.
513
+ \"\"\"
514
+ rewards = []
515
+ env_r = AdvancedCivicAIEnv(fetcher)
516
+ task_list_ = task if isinstance(task, list) else [task] * len(prompts)
517
+ country_ = country if isinstance(country,list) else [country]* len(prompts)
518
+
519
+ for i, (prompt, completion) in enumerate(zip(prompts, completions)):
520
+ # Extract text
521
+ if isinstance(completion, list) and len(completion) > 0:
522
+ text = completion[0].get("content", "")
523
+ else:
524
+ text = str(completion)
525
+
526
+ action = parse_action(text)
527
+
528
+ # Environment reward
529
+ t = (task_list_[i] if task_list_[i] else "maximize_wellbeing")
530
+ c = (country_[i] if country_[i] else "USA")
531
+ env_r.reset(t, c)
532
+ _, env_rew, _, _ = env_r.step(action)
533
+
534
+ # Format reward: all 5 keys present
535
+ keys_found = sum(
536
+ 1 for k in ["tax","jobs","healthcare","education","infrastructure"]
537
+ if re.search(rf"{k}\\s*:\\s*\\d", text)
538
+ )
539
+ fmt_bonus = (keys_found / 5.0) * 0.15 # up to +0.15
540
+
541
+ # Diversity bonus: penalise all-same values (lazy policy)
542
+ vals = list(action.values())
543
+ div_bonus = float(np.std(vals)) * 0.1 # up to ~+0.05
544
+
545
+ total = float(env_rew) + fmt_bonus + div_bonus
546
+ rewards.append(round(total, 5))
547
+
548
+ return rewards
549
+
550
+
551
+ # ── CELL 10: GRPO CONFIG (VERSION-SAFE) ──────────────────────────────────────
552
+ from trl import GRPOConfig, GRPOTrainer
553
+
554
+ valid_params = set(inspect.signature(GRPOConfig.__init__).parameters)
555
+
556
+ all_kwargs = {
557
+ "output_dir" : "checkpoints/civicai-grpo",
558
+ "num_train_epochs" : 3,
559
+ "per_device_train_batch_size" : 2,
560
+ "num_generations" : 2,
561
+ "max_prompt_length" : 300,
562
+ "max_completion_length" : 80,
563
+ "learning_rate" : 5e-6,
564
+ "logging_steps" : 5,
565
+ "save_strategy" : "epoch",
566
+ "save_total_limit" : 2,
567
+ "report_to" : "none",
568
+ "remove_unused_columns" : False,
569
+ "bf16" : USE_BF16,
570
+ "fp16" : USE_FP16,
571
+ "gradient_accumulation_steps" : 4,
572
+ "max_grad_norm" : 0.3,
573
+ "warmup_ratio" : 0.05,
574
+ "lr_scheduler_type" : "cosine",
575
+ "dataloader_num_workers" : 0,
576
+ }
577
+
578
+ safe_kwargs = {k: v for k, v in all_kwargs.items() if k in valid_params}
579
+ skipped = set(all_kwargs) - set(safe_kwargs)
580
+ if skipped:
581
+ console.log(f"[yellow]Skipped unsupported GRPOConfig args: {skipped}")
582
+
583
+ grpo_config = GRPOConfig(**safe_kwargs)
584
+
585
+ trainer = GRPOTrainer(
586
+ model = model,
587
+ args = grpo_config,
588
+ reward_funcs = civic_reward_advanced,
589
+ train_dataset = train_dataset,
590
+ processing_class = tokenizer,
591
+ )
592
+ console.log("[green]✓ GRPOTrainer initialised with LoRA + multi-objective reward")
593
+
594
+
595
+ # ── CELL 11: TRAINING ─────────────────────────────────────────────────────────
596
+ console.rule("[bold cyan]Starting GRPO Training")
597
+ start_time = time.time()
598
+ trainer.train()
599
+ elapsed = time.time() - start_time
600
+ console.rule(f"[bold green]Training Complete — {elapsed/60:.1f} min")
601
+
602
+
603
+ # ── CELL 12: EXTRACT & PLOT TRAINING METRICS ─────────────────────────────────
604
+ logs = trainer.state.log_history
605
+ df_logs = pd.DataFrame(logs).dropna(subset=["loss"] if "loss" in pd.DataFrame(logs).columns else [])
606
+
607
+ reward_entries = [e for e in logs if "reward" in e]
608
+ rewards_logged = [e["reward"] for e in reward_entries]
609
+ steps_logged = [e.get("step", i) for i, e in enumerate(reward_entries)]
610
+
611
+ fig = make_subplots(rows=1, cols=2,
612
+ subplot_titles=("Reward Curve", "Reward Distribution"))
613
+
614
+ # Reward over steps
615
+ fig.add_trace(go.Scatter(
616
+ x=steps_logged, y=rewards_logged,
617
+ mode="lines", name="Reward", line=dict(color="#00d4ff", width=2)
618
+ ), row=1, col=1)
619
+
620
+ # Smoothed
621
+ if len(rewards_logged) > 5:
622
+ smooth = np.convolve(rewards_logged, np.ones(5)/5, mode="valid")
623
+ fig.add_trace(go.Scatter(
624
+ x=steps_logged[4:], y=smooth,
625
+ mode="lines", name="Smoothed",
626
+ line=dict(color="#ff6b6b", width=2, dash="dash")
627
+ ), row=1, col=1)
628
+
629
+ # Histogram
630
+ fig.add_trace(go.Histogram(
631
+ x=rewards_logged, nbinsx=20,
632
+ marker_color="#00d4ff", opacity=0.75, name="Distribution"
633
+ ), row=1, col=2)
634
+
635
+ fig.update_layout(
636
+ title="CivicAI GRPO Training Metrics",
637
+ template="plotly_dark", height=420,
638
+ paper_bgcolor="#0d1117", font=dict(color="#e6edf3"),
639
+ )
640
+ fig.show()
641
+ fig.write_html("assets/training_metrics.html")
642
+
643
+ if rewards_logged:
644
+ console.print(f"[cyan]Start reward : {rewards_logged[0]:.4f}")
645
+ console.print(f"[cyan]Final reward : {rewards_logged[-1]:.4f}")
646
+ console.print(f"[green]Improvement : {rewards_logged[-1]-rewards_logged[0]:+.4f}")
647
+
648
+
649
+ # ── CELL 13: MULTI-COUNTRY POLICY EVALUATION ─────────────────────────────────
650
+ def evaluate_trained_policy(
651
+ model, tokenizer, fetcher,
652
+ countries: List[str] = None,
653
+ tasks: List[str] = None,
654
+ episodes: int = 5,
655
+ ) -> pd.DataFrame:
656
+ \"\"\"Evaluate trained policy on all countries × all tasks.\"\"\"
657
+ countries = countries or list(RealTimeDataFetcher.COUNTRIES.keys())
658
+ tasks = tasks or list(AdvancedCivicAIEnv.TASKS.keys())
659
+ model.eval()
660
+ results = []
661
+
662
+ for country in countries:
663
+ for task in tasks:
664
+ ep_rewards = []
665
+ env_eval = AdvancedCivicAIEnv(fetcher, default_country=country)
666
+
667
+ for _ in range(episodes):
668
+ obs = env_eval.reset(task, country)
669
+ ep_reward = 0.0
670
+ for _ in range(20):
671
+ prompt = build_prompt(obs)
672
+ inputs = tokenizer(prompt, return_tensors="pt",
673
+ truncation=True, max_length=300).to(DEVICE)
674
+ with torch.no_grad():
675
+ out = model.generate(
676
+ **inputs, max_new_tokens=60,
677
+ do_sample=False,
678
+ pad_token_id=tokenizer.eos_token_id,
679
+ )
680
+ gen_tokens = out[0][inputs["input_ids"].shape[1]:]
681
+ text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
682
+ action = parse_action(text)
683
+ obs, r, done, _ = env_eval.step(action)
684
+ ep_reward += r
685
+ if done: break
686
+ ep_rewards.append(ep_reward / 20)
687
+
688
+ results.append({
689
+ "country" : RealTimeDataFetcher.COUNTRIES[country]["name"],
690
+ "task" : task,
691
+ "mean_r" : round(float(np.mean(ep_rewards)), 4),
692
+ "std_r" : round(float(np.std(ep_rewards)), 4),
693
+ "max_r" : round(float(np.max(ep_rewards)), 4),
694
+ })
695
+ console.log(f"[dim]{country} / {task} → {results[-1]['mean_r']:.4f}")
696
+
697
+ return pd.DataFrame(results)
698
+
699
+
700
+ def baseline_score(fetcher, episodes=5):
701
+ \"\"\"Fixed 0.5 policy baseline.\"\"\"
702
+ env_b, total = AdvancedCivicAIEnv(fetcher, "USA"), []
703
+ for _ in range(episodes):
704
+ obs = env_b.reset("maximize_wellbeing", "USA")
705
+ r = 0.0
706
+ for _ in range(20):
707
+ obs, rew, done, _ = env_b.step(
708
+ {k: 0.5 for k in ["tax","jobs","healthcare","education","infrastructure"]}
709
+ )
710
+ r += rew
711
+ total.append(r / 20)
712
+ return float(np.mean(total))
713
+
714
+
715
+ console.rule("[bold cyan]Evaluating Policy Across Countries & Tasks")
716
+ df_eval = evaluate_trained_policy(model, tokenizer, fetcher, episodes=3)
717
+ baseline = baseline_score(fetcher)
718
+
719
+ console.print(df_eval.to_string(index=False))
720
+ console.print(f"\\n[bold]Baseline (fixed 0.5) : {baseline:.4f}")
721
+ console.print(f"[bold green]Best trained score : {df_eval['mean_r'].max():.4f}")
722
+
723
+
724
+ # ── CELL 14: EVALUATION HEATMAP ──────────────────────────────────────────────
725
+ pivot = df_eval.pivot(index="country", columns="task", values="mean_r")
726
+
727
+ fig_heat = go.Figure(go.Heatmap(
728
+ z = pivot.values,
729
+ x = pivot.columns.tolist(),
730
+ y = pivot.index.tolist(),
731
+ colorscale = "RdYlGn",
732
+ text = np.round(pivot.values, 3),
733
+ texttemplate="%{text}",
734
+ showscale = True,
735
+ zmin=0.4, zmax=1.0,
736
+ ))
737
+ fig_heat.add_shape(
738
+ type="line", x0=-0.5, x1=len(pivot.columns)-0.5,
739
+ y0=-0.5, y1=len(pivot.index)-0.5,
740
+ line=dict(color="white", width=0)
741
+ )
742
+ fig_heat.update_layout(
743
+ title = "Policy Performance Heatmap — Country × Task (GRPO Trained)",
744
+ template="plotly_dark", height=400,
745
+ paper_bgcolor="#0d1117", font=dict(color="#e6edf3"),
746
+ xaxis_title="Task", yaxis_title="Country",
747
+ )
748
+ fig_heat.show()
749
+ fig_heat.write_html("assets/eval_heatmap.html")
750
+
751
+
752
+ # ── CELL 15: SAVE EVERYTHING ─────────────────────────────────────────────────
753
+ # Save LoRA adapter only (lightweight)
754
+ model.save_pretrained("checkpoints/civicai-lora")
755
+ tokenizer.save_pretrained("checkpoints/civicai-lora")
756
+
757
+ # Save results JSON
758
+ results_json = {
759
+ "run_timestamp" : datetime.now().isoformat(),
760
+ "model" : MODEL_NAME,
761
+ "lora_rank" : 8,
762
+ "training_epochs": 3,
763
+ "num_countries" : len(RealTimeDataFetcher.COUNTRIES),
764
+ "num_tasks" : len(AdvancedCivicAIEnv.TASKS),
765
+ "data_source" : "World Bank Open API (live)",
766
+ "baseline_reward": round(baseline, 4),
767
+ "best_reward" : round(float(df_eval["mean_r"].max()), 4),
768
+ "improvement" : round(float(df_eval["mean_r"].max()) - baseline, 4),
769
+ "reward_history" : rewards_logged,
770
+ "eval_by_country_task": df_eval.to_dict(orient="records"),
771
+ "real_data_snapshot" : df_world[["country","inflation_pct","unemployment_pct",
772
+ "health_exp_gdp","gdp_growth"]].to_dict(orient="records"),
773
+ }
774
+
775
+ with open("assets/training_results.json", "w") as f:
776
+ json.dump(results_json, f, indent=2)
777
+
778
+ console.rule("[bold green]All Done")
779
+ console.print(f"[green]✓ LoRA checkpoint → checkpoints/civicai-lora/")
780
+ console.print(f"[green]✓ Results JSON → assets/training_results.json")
781
+ console.print(f"[green]✓ Dashboard HTML → assets/global_dashboard.html")
782
+ console.print(f"[green]✓ Training metrics → assets/training_metrics.html")
783
+ console.print(f"[green]✓ Eval heatmap → assets/eval_heatmap.html")
784
+ console.print(f"\\n[bold cyan]Baseline : {baseline:.4f}")
785
+ console.print(f"[bold green]Best score: {df_eval['mean_r'].max():.4f}")
786
+ console.print(f"[bold green]Delta : {df_eval['mean_r'].max() - baseline:+.4f}")
787
+ """
788
+
789
# Matches a "# ── CELL <n>: TITLE ─────" banner line in the notebook source.
# "." does not match a newline, so the match stops at the end of the banner line.
CELL_HEADER_RE = re.compile(r"# ── CELL \d+.*")

# File name of the generated notebook.
NOTEBOOK_FILENAME = "CivicAI_Training.ipynb"


def _markdown_cell(source_lines):
    """Return an nbformat-4 markdown cell holding *source_lines*."""
    return {
        "cell_type": "markdown",
        "metadata": {},
        "source": list(source_lines),
    }


def _code_cell(text):
    """Return an nbformat-4 code cell whose source is *text* split per line.

    ``splitlines(keepends=True)`` splits on real line breaks and keeps the
    newline on every line except the last one.  Escape sequences written
    inside the cell's own string literals (a backslash followed by ``n``)
    are left untouched.
    """
    return {
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "outputs": [],
        "source": text.splitlines(keepends=True),
    }


def _split_sections(code):
    """Split *code* into ``(header_line, body_text)`` pairs, one per CELL banner.

    Headers and bodies come from the same regex pass, so they cannot get out
    of step.  Text before the first banner is dropped.
    """
    matches = list(CELL_HEADER_RE.finditer(code))
    sections = []
    for pos, match in enumerate(matches):
        body_end = matches[pos + 1].start() if pos + 1 < len(matches) else len(code)
        sections.append((match.group(0), code[match.end():body_end]))
    return sections


def build_cells(code):
    """Turn the banner-delimited source string *code* into a list of notebook cells.

    The list starts with a fixed title cell.  Every ``# ── CELL n: ...`` banner
    then contributes a ``###`` markdown heading followed by one code cell
    holding the text up to the next banner.

    Args:
        code: Notebook source with ``# ── CELL <n>`` banner comments.

    Returns:
        A list of nbformat-4 cell dictionaries.
    """
    cells = [
        _markdown_cell([
            "# 🏛 CivicAI Advanced — Senior ML Engineer Edition\n",
            "**Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard**",
        ])
    ]

    for header, body in _split_sections(code):
        # Drop the leading "# ── " and the whole trailing run of box-drawing dashes.
        title = header.lstrip("#─ ").rstrip("─ ")
        cells.append(_markdown_cell([f"### {title}"]))

        body = body.strip()
        # The pip-install cell is wrapped in a docstring so the combined source
        # stays valid Python; unwrap it so the "!pip" magic runs in the notebook.
        if "pip install" in body and '"""' in body:
            body = body.replace('"""', "").strip()

        cells.append(_code_cell(body))

    return cells


def build_notebook(code):
    """Return the complete nbformat-4 notebook dictionary for *code*."""
    return {
        "cells": build_cells(code),
        "metadata": {
            "colab": {"name": "CivicAI_Training.ipynb"},
            "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
            "language_info": {"name": "python", "version": "3.10"},
        },
        "nbformat": 4,
        "nbformat_minor": 4,
    }


def main():
    """Write the notebook built from the module-level ``code`` string to disk.

    The notebook is written next to this script instead of to an absolute
    path on one developer's machine, so the script runs on any checkout.
    NOTE(review): this assumes the old hard-coded directory was the folder
    containing this script; confirm before relying on the location.
    """
    from pathlib import Path

    target = Path(__file__).resolve().parent / NOTEBOOK_FILENAME
    with open(target, "w", encoding="utf-8") as f:
        json.dump(build_notebook(code), f, indent=2)


# Guarded so the helpers can be imported without writing a file.
if __name__ == "__main__":
    main()
dashboard/app.js CHANGED
@@ -145,54 +145,79 @@ async function runFullEpisode() {
145
  let maxSteps = parseInt(turnsInput.value) || 50;
146
  if (maxSteps < 5) maxSteps = 5;
147
  if (maxSteps > 200) maxSteps = 200;
148
-
149
  turnsInput.value = maxSteps;
 
 
 
150
  document.getElementById('turn-max-display').textContent = maxSteps;
151
 
152
  stopAutoplay();
153
  totalReward = 0;
154
  stepCount = 0;
155
  clearHistories();
156
- setStatus('Running Simulation Backend...', 'running');
 
 
157
 
 
158
  try {
159
- const res = await fetch(`${API}/start-simulation`, {
160
  method: 'POST',
161
  headers: { 'Content-Type': 'application/json' },
162
  body: JSON.stringify({ task_id: taskId, max_steps: maxSteps }),
163
  });
164
-
165
- if (!res.ok) {
166
- throw new Error(`HTTP error ${res.status}`);
167
- }
168
-
169
- const data = await res.json();
170
-
171
- setStatus('Visualizing Trajectory...', 'running');
172
-
173
- const stepLog = data.step_log || [];
174
- const speed = 200;
175
-
176
- for (let i = 0; i < stepLog.length; i++) {
177
- const stepData = stepLog[i];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  stepCount++;
179
- totalReward += stepData.reward;
180
-
181
- // Reconstruct info if it's missing (it was added recently to backend)
182
- const info = stepData.info || {};
183
- updateDashboard(stepData.obs, stepData.reward, info);
184
  document.getElementById('total-reward').textContent = totalReward.toFixed(3);
185
-
 
 
 
 
186
  await sleep(speed);
 
 
 
 
187
  }
188
-
189
- setStatus('Done', 'done');
190
- } catch (err) {
191
- console.error('Simulation failed:', err);
192
- setStatus('Error', 'done');
193
  }
 
 
194
  }
195
 
 
196
  // ============================================================
197
  // UI Updates
198
  // ============================================================
 
145
  let maxSteps = parseInt(turnsInput.value) || 50;
146
  if (maxSteps < 5) maxSteps = 5;
147
  if (maxSteps > 200) maxSteps = 200;
 
148
  turnsInput.value = maxSteps;
149
+
150
+ // Lock UI immediately
151
+ setStatus('Running...', 'running');
152
  document.getElementById('turn-max-display').textContent = maxSteps;
153
 
154
  stopAutoplay();
155
  totalReward = 0;
156
  stepCount = 0;
157
  clearHistories();
158
+ clearDebate();
159
+ clearPolicyLog();
160
+ clearInsights();
161
 
162
+ // Reset the environment first
163
  try {
164
+ const resetRes = await fetch(`${API}/reset`, {
165
  method: 'POST',
166
  headers: { 'Content-Type': 'application/json' },
167
  body: JSON.stringify({ task_id: taskId, max_steps: maxSteps }),
168
  });
169
+ if (!resetRes.ok) throw new Error(`Reset failed: HTTP ${resetRes.status}`);
170
+ const resetData = await resetRes.json();
171
+ updateDashboard(resetData.observation, null, null);
172
+ document.getElementById('turn-number').textContent = '0';
173
+ document.getElementById('total-reward').textContent = '0.000';
174
+ } catch (err) {
175
+ console.error('Reset failed:', err);
176
+ setStatus('Reset Error', 'done');
177
+ return;
178
+ }
179
+
180
+ // Run steps one-by-one, updating the turn counter live
181
+ const speed = 200; // ms delay between visual updates
182
+ for (let i = 0; i < maxSteps; i++) {
183
+ try {
184
+ const res = await fetch(`${API}/step`, {
185
+ method: 'POST',
186
+ headers: { 'Content-Type': 'application/json' },
187
+ body: JSON.stringify({ use_agents: true }),
188
+ });
189
+
190
+ if (!res.ok) {
191
+ const errData = await res.json().catch(() => ({}));
192
+ console.error('Step error:', errData.detail || res.status);
193
+ setStatus(`Error at turn ${i + 1}`, 'done');
194
+ return;
195
+ }
196
+
197
+ const data = await res.json();
198
  stepCount++;
199
+ totalReward += data.reward;
200
+
201
+ // Live turn counter: TURN X / maxSteps
202
+ document.getElementById('turn-number').textContent = stepCount;
 
203
  document.getElementById('total-reward').textContent = totalReward.toFixed(3);
204
+
205
+ updateDashboard(data.observation, data.reward, data.info);
206
+
207
+ if (data.done) break;
208
+
209
  await sleep(speed);
210
+ } catch (err) {
211
+ console.error('Step failed:', err);
212
+ setStatus('Network Error', 'done');
213
+ return;
214
  }
 
 
 
 
 
215
  }
216
+
217
+ setStatus('Done ✅', 'done');
218
  }
219
 
220
+
221
  // ============================================================
222
  // UI Updates
223
  // ============================================================