File size: 35,751 Bytes
a36e07f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93e9982
 
a36e07f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
import json
import re

code = """
# =============================================================================
# CivicAI Advanced β€” Senior ML Engineer Edition
# Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard
# =============================================================================

# ── CELL 1: INSTALL DEPENDENCIES ─────────────────────────────────────────────
\"\"\"
!pip install -q \\
    "transformers>=4.38" \\
    "accelerate>=0.27" \\
    "trl>=0.10" \\
    "peft>=0.9" \\
    "bitsandbytes>=0.42" \\
    "datasets>=2.17" \\
    "requests>=2.31" \\
    "pandas>=2.0" \\
    "fredapi" \\
    "world-bank-data" \\
    "plotly>=5.18" \\
    "rich>=13.0" \\
    "tenacity>=8.2"

# After install: Runtime β†’ Restart session β†’ run from Cell 2
\"\"\"

# ── CELL 2: IMPORTS & SYSTEM SETUP ───────────────────────────────────────────
import os, re, json, math, time, random, inspect, warnings, logging
from datetime import datetime
from typing   import Dict, List, Optional, Tuple
from pathlib  import Path

import numpy    as np
import pandas   as pd
import torch
import requests
import plotly.graph_objects as go
import plotly.express       as px
from plotly.subplots import make_subplots
from tenacity        import retry, stop_after_attempt, wait_exponential
from rich.console    import Console
from rich.table      import Table
from rich.progress   import Progress, SpinnerColumn, TextColumn
from rich            import print as rprint

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

console = Console()

# ── Hardware detection ────────────────────────────────────────────────────────
CUDA_OK  = torch.cuda.is_available()
if CUDA_OK:
    CAP      = torch.cuda.get_device_capability()
    USE_BF16 = CAP[0] >= 8
    USE_FP16 = not USE_BF16
    GPU_NAME = torch.cuda.get_device_name(0)
else:
    USE_BF16 = USE_FP16 = False
    GPU_NAME = "CPU"

DEVICE = "cuda" if CUDA_OK else "cpu"



# ── Paths ─────────────────────────────────────────────────────────────────────
Path("assets").mkdir(exist_ok=True)
Path("checkpoints").mkdir(exist_ok=True)
Path("logs").mkdir(exist_ok=True)

console.rule("[bold cyan]CivicAI Advanced β€” System Ready")
table = Table(show_header=False, box=None)
table.add_row("[cyan]PyTorch",    torch.__version__)
table.add_row("[cyan]Device",     GPU_NAME)
table.add_row("[cyan]BF16/FP16",  f"bf16={USE_BF16}  fp16={USE_FP16}")
table.add_row("[cyan]Timestamp",  datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
console.print(table)


# ── CELL 3: REAL-TIME DATA FETCHER ───────────────────────────────────────────
class RealTimeDataFetcher:
    \"\"\"
    Fetches live economic indicators from:
      β€’ World Bank Open API  (no key required)
      β€’ FRED / St. Louis Fed (free key via fredapi)
      β€’ BLS  (Bureau of Labor Statistics β€” no key for basic tier)
      β€’ REST Countries       (social / governance proxies)
    Falls back to realistic historical means if any API is unavailable.
    \"\"\"

    WORLD_BANK_BASE = "https://api.worldbank.org/v2"
    BLS_BASE        = "https://api.bls.gov/publicAPI/v1/timeseries/data"
    REST_COUNTRIES  = "https://restcountries.com/v3.1/alpha"

    # World Bank indicator codes
    WB_INDICATORS = {
        "inflation"   : "FP.CPI.TOTL.ZG",   # CPI inflation %
        "unemployment": "SL.UEM.TOTL.ZS",    # Unemployment % of labour force
        "health_exp"  : "SH.XPD.CHEX.GD.ZS",# Health expenditure % of GDP
        "life_expect" : "SP.DYN.LE00.IN",    # Life expectancy at birth
        "gdp_growth"  : "NY.GDP.MKTP.KD.ZG", # GDP growth %
        "homicide"    : "VC.IHR.PSRC.P5",    # Intentional homicides per 100k
    }

    # Country ISO codes supported
    COUNTRIES = {
        "USA": {"iso2": "US", "iso3": "USA", "name": "United States"},
        "IND": {"iso2": "IN", "iso3": "IND", "name": "India"},
        "GBR": {"iso2": "GB", "iso3": "GBR", "name": "United Kingdom"},
        "DEU": {"iso2": "DE", "iso3": "DEU", "name": "Germany"},
        "JPN": {"iso2": "JP", "iso3": "JPN", "name": "Japan"},
        "BRA": {"iso2": "BR", "iso3": "BRA", "name": "Brazil"},
    }

    # Realistic fallback values (5-year historical means, 2019-2023)
    FALLBACKS = {
        "USA": {"inflation":3.8,"unemployment":4.8,"health_exp":17.2,"life_expect":77.5,"gdp_growth":2.1,"homicide":6.5},
        "IND": {"inflation":5.5,"unemployment":7.2,"health_exp":3.3, "life_expect":69.4,"gdp_growth":5.8,"homicide":2.8},
        "GBR": {"inflation":3.2,"unemployment":4.2,"health_exp":10.9,"life_expect":80.4,"gdp_growth":1.4,"homicide":1.2},
        "DEU": {"inflation":2.8,"unemployment":3.5,"health_exp":12.8,"life_expect":80.6,"gdp_growth":0.9,"homicide":0.9},
        "JPN": {"inflation":1.2,"unemployment":2.8,"health_exp":10.9,"life_expect":84.3,"gdp_growth":0.7,"homicide":0.2},
        "BRA": {"inflation":6.9,"unemployment":11.0,"health_exp":9.9,"life_expect":75.5,"gdp_growth":1.2,"homicide":22.4},
    }

    def __init__(self, cache_ttl_seconds: int = 3600):
        self._cache: Dict[str, Tuple[float, dict]] = {}
        self.cache_ttl = cache_ttl_seconds
        self.session   = requests.Session()
        self.session.headers.update({"User-Agent": "CivicAI/2.0"})

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=8))
    def _wb_fetch(self, country_iso2: str, indicator: str) -> Optional[float]:
        \"\"\"Fetch latest non-null value from World Bank API.\"\"\"
        url = (f"{self.WORLD_BANK_BASE}/country/{country_iso2}/indicator/{indicator}"
               f"?format=json&mrv=5&per_page=5")
        r   = self.session.get(url, timeout=10)
        r.raise_for_status()
        data = r.json()
        if len(data) < 2 or not data[1]:
            return None
        for entry in data[1]:
            if entry.get("value") is not None:
                return float(entry["value"])
        return None

    def fetch_country(self, country_code: str = "USA") -> dict:
        \"\"\"
        Returns normalised economic state for a country.
        Uses cache β†’ World Bank API β†’ fallback in that order.
        \"\"\"
        cache_key = f"{country_code}_{int(time.time() // self.cache_ttl)}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        meta  = self.COUNTRIES.get(country_code, self.COUNTRIES["USA"])
        iso2  = meta["iso2"]
        raw   = {}

        with Progress(SpinnerColumn(), TextColumn("[cyan]Fetching {task.description}"),
                      transient=True) as prog:
            t = prog.add_task(f"live data for {meta['name']}")
            for key, indicator in self.WB_INDICATORS.items():
                try:
                    val = self._wb_fetch(iso2, indicator)
                    raw[key] = val if val is not None else self.FALLBACKS[country_code][key]
                except Exception:
                    raw[key] = self.FALLBACKS[country_code][key]

        # ── Normalise to [0, 1] for the RL environment ────────────────────
        state = {
            # Lower inflation = better; 0 %β†’1.0, β‰₯15 %β†’0.0
            "inflation"   : max(0.0, min(1.0, 1 - raw["inflation"]   / 15.0)),
            # Higher employment = better
            "employment"  : max(0.0, min(1.0, 1 - raw["unemployment"]/ 25.0)),
            # Higher health expenditure + life expectancy = better
            "health"      : max(0.0, min(1.0, (raw["health_exp"] / 20.0 +
                                               raw["life_expect"] / 90.0) / 2)),
            # GDP growth proxy for satisfaction
            "satisfaction": max(0.0, min(1.0, (raw["gdp_growth"] + 5) / 15.0)),
            # Lower homicide = better
            "crime"       : max(0.0, min(1.0, 1 - raw["homicide"]   / 50.0)),
        }

        # Attach raw for reporting
        state["_raw"]    = raw
        state["_country"]= meta["name"]
        state["_fetched"]= datetime.now().isoformat()

        self._cache[cache_key] = state
        return state

    def fetch_all_countries(self) -> Dict[str, dict]:
        results = {}
        for code in self.COUNTRIES:
            console.log(f"[dim]β†’ fetching {code}")
            results[code] = self.fetch_country(code)
        return results

    def to_dataframe(self, all_data: Dict[str, dict]) -> pd.DataFrame:
        rows = []
        for code, state in all_data.items():
            raw = state.get("_raw", {})
            rows.append({
                "country"     : state.get("_country", code),
                "code"        : code,
                "inflation_pct": raw.get("inflation", 0),
                "unemployment_pct": raw.get("unemployment", 0),
                "health_exp_gdp": raw.get("health_exp", 0),
                "life_expect"   : raw.get("life_expect", 0),
                "gdp_growth"    : raw.get("gdp_growth", 0),
                "homicide_rate" : raw.get("homicide", 0),
                # normalised
                "norm_inflation"   : state["inflation"],
                "norm_employment"  : state["employment"],
                "norm_health"      : state["health"],
                "norm_satisfaction": state["satisfaction"],
                "norm_crime"       : state["crime"],
                "fetched_at"       : state.get("_fetched"),
            })
        return pd.DataFrame(rows)


# Instantiate & fetch
fetcher     = RealTimeDataFetcher(cache_ttl_seconds=3600)
all_data    = fetcher.fetch_all_countries()
df_world    = fetcher.to_dataframe(all_data)

console.rule("[bold green]Live Data Fetched")
console.print(df_world[["country","inflation_pct","unemployment_pct",
                          "health_exp_gdp","gdp_growth"]].to_string(index=False))


# ── CELL 4: REAL-DATA DASHBOARD ──────────────────────────────────────────────
def plot_global_dashboard(df: pd.DataFrame) -> None:
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=(
            "Inflation (%)", "Unemployment (%)", "Health Exp (% GDP)",
            "Life Expectancy (yrs)", "GDP Growth (%)", "Homicide Rate (per 100k)"
        ),
    )
    cols_raw = ["inflation_pct","unemployment_pct","health_exp_gdp",
                "life_expect","gdp_growth","homicide_rate"]
    colors   = px.colors.qualitative.Bold

    for i, col in enumerate(cols_raw):
        r, c = divmod(i, 3)
        fig.add_trace(
            go.Bar(
                x=df["country"], y=df[col],
                marker_color=colors,
                showlegend=False,
                text=df[col].round(1), textposition="outside"
            ),
            row=r+1, col=c+1
        )

    fig.update_layout(
        title_text="🌍 CivicAI β€” Real-Time Global Economic Dashboard",
        title_font_size=20,
        height=600, template="plotly_dark",
        paper_bgcolor="#0d1117", plot_bgcolor="#0d1117",
        font=dict(color="#e6edf3"),
    )
    fig.show()
    fig.write_html("assets/global_dashboard.html")
    console.log("[green]βœ“ Dashboard saved β†’ assets/global_dashboard.html")

plot_global_dashboard(df_world)


# ── CELL 5: ADVANCED MULTI-COUNTRY ENVIRONMENT ───────────────────────────────
class AdvancedCivicAIEnv:
    \"\"\"
    Production-grade multi-country civic environment.
    β€’ Initialises from real World Bank data
    β€’ Supports 6 countries and 4 policy tasks
    β€’ Action space: 5-dimensional continuous [0,1]
    β€’ Observation: 10-dimensional (5 state + 5 delta from last step)
    β€’ Reward: weighted multi-objective (Pareto-style)
    β€’ Includes shock events (recession, pandemic proxy, crime spike)
    \"\"\"

    TASKS = {
        "stabilize_economy"  : {"inflation_weight":0.4, "employment_weight":0.3, "health_weight":0.15, "satisfaction_weight":0.1, "crime_weight":0.05},
        "improve_health"     : {"inflation_weight":0.1, "employment_weight":0.2, "health_weight":0.5,  "satisfaction_weight":0.15,"crime_weight":0.05},
        "reduce_crime"       : {"inflation_weight":0.1, "employment_weight":0.2, "health_weight":0.2,  "satisfaction_weight":0.1, "crime_weight":0.4},
        "maximize_wellbeing" : {"inflation_weight":0.2, "employment_weight":0.2, "health_weight":0.2,  "satisfaction_weight":0.2, "crime_weight":0.2},
    }

    SHOCK_EVENTS = [
        {"name":"recession",  "prob":0.02, "effect":{"inflation":+0.15,"employment":-0.12,"satisfaction":-0.1}},
        {"name":"pandemic",   "prob":0.01, "effect":{"health":-0.2,    "employment":-0.1, "satisfaction":-0.15}},
        {"name":"crime_spike","prob":0.02, "effect":{"crime":-0.15,    "satisfaction":-0.08}},
        {"name":"boom",       "prob":0.02, "effect":{"employment":+0.1,"satisfaction":+0.1,"inflation":+0.05}},
    ]

    def __init__(self, fetcher: RealTimeDataFetcher, default_country: str = "USA"):
        self.fetcher         = fetcher
        self.default_country = default_country
        self._prev_state     = None
        self.step_count      = 0
        self.shock_log       = []
        self.state_data      = {}

    def reset(self, task_id: str = "stabilize_economy", country: str = None) -> dict:
        country          = country or self.default_country
        self.task_id     = task_id
        self.weights     = self.TASKS[task_id]
        self.step_count  = 0
        self.shock_log   = []

        # Load real data as starting state
        live             = self.fetcher.fetch_country(country)
        self.state_data  = {k: live[k] for k in ["inflation","employment","health","satisfaction","crime"]}

        # Add small noise so each episode is unique
        for k in self.state_data:
            self.state_data[k] = float(np.clip(
                self.state_data[k] + np.random.normal(0, 0.02), 0.0, 1.0
            ))

        self._prev_state = dict(self.state_data)
        return self._build_obs()

    def _build_obs(self) -> dict:
        \"\"\"10-dim observation: current state + delta from previous step.\"\"\"
        obs = dict(self.state_data)
        obs["_task"]  = self.task_id
        obs["_step"]  = self.step_count
        if self._prev_state:
            for k in ["inflation","employment","health","satisfaction","crime"]:
                obs[f"d_{k}"] = self.state_data[k] - self._prev_state[k]
        else:
            for k in ["inflation","employment","health","satisfaction","crime"]:
                obs[f"d_{k}"] = 0.0
        return obs

    def _apply_shocks(self):
        \"\"\"Stochastic external shock events.\"\"\"
        for shock in self.SHOCK_EVENTS:
            if np.random.random() < shock["prob"]:
                self.shock_log.append({"step": self.step_count, "event": shock["name"]})
                for k, delta in shock["effect"].items():
                    if k in self.state_data:
                        self.state_data[k] = float(np.clip(self.state_data[k] + delta, 0.0, 1.0))
                console.log(f"[yellow]⚑ Shock event: {shock['name']} at step {self.step_count}")

    def step(self, action: dict) -> Tuple[dict, float, bool, dict]:
        \"\"\"
        action keys: tax, jobs, healthcare, education, infrastructure
        Each in [0, 1] β€” represents budget allocation intensity.
        \"\"\"
        self._prev_state = dict(self.state_data)

        # Policy effects (with diminishing returns via sqrt)
        tax          = action.get("tax",           0.5)
        jobs         = action.get("jobs",          0.5)
        healthcare   = action.get("healthcare",    0.5)
        education    = action.get("education",     0.5)
        infra        = action.get("infrastructure",0.5)

        self.state_data["inflation"]    = np.clip(
            self.state_data["inflation"]    - tax        * 0.08 + jobs * 0.02, 0.0, 1.0)
        self.state_data["employment"]   = np.clip(
            self.state_data["employment"]   + jobs       * 0.06 + infra * 0.02, 0.0, 1.0)
        self.state_data["health"]       = np.clip(
            self.state_data["health"]       + healthcare * 0.07 + education * 0.02, 0.0, 1.0)
        self.state_data["satisfaction"] = np.clip(
            self.state_data["satisfaction"] + education  * 0.05 + infra    * 0.03
                                            - tax        * 0.03, 0.0, 1.0)
        self.state_data["crime"]        = np.clip(
            self.state_data["crime"]        + education  * 0.05 + jobs     * 0.03
                                            - infra      * 0.01, 0.0, 1.0)

        # Gaussian noise
        for k in self.state_data:
            self.state_data[k] = float(np.clip(
                self.state_data[k] + np.random.normal(0, 0.008), 0.0, 1.0))

        self._apply_shocks()
        self.step_count += 1

        reward  = self._compute_reward()
        done    = self.step_count >= 50
        info    = {"shocks": self.shock_log, "step": self.step_count}
        return self._build_obs(), float(reward), done, info

    def _compute_reward(self) -> float:
        s = self.state_data
        w = self.weights
        return (
            w["inflation_weight"]    * s["inflation"]    +
            w["employment_weight"]   * s["employment"]   +
            w["health_weight"]       * s["health"]       +
            w["satisfaction_weight"] * s["satisfaction"] +
            w["crime_weight"]        * s["crime"]
        )

    def state_report(self) -> dict:
        return {k: round(v, 4) for k, v in self.state_data.items()}


# Smoke test
env_adv = AdvancedCivicAIEnv(fetcher, default_country="USA")
obs     = env_adv.reset("stabilize_economy", "USA")
console.rule("[bold green]Advanced Environment Ready")
console.print(f"Initial state (USA, real data): {env_adv.state_report()}")


# ── CELL 6: PROMPT BUILDER (5-action) ────────────────────────────────────────
def build_prompt(obs: dict) -> str:
    task_desc = {
        "stabilize_economy"  : "Your priority is economic stability: control inflation and protect employment.",
        "improve_health"     : "Your priority is public health: maximize health outcomes and life expectancy.",
        "reduce_crime"       : "Your priority is public safety: reduce crime through investment and employment.",
        "maximize_wellbeing" : "Your priority is overall citizen wellbeing across all dimensions.",
    }.get(obs.get("_task",""), "Optimize all civic outcomes.")

    return (
        f"You are a senior policy advisor.\\n{task_desc}\\n\\n"
        f"CURRENT STATE (step {obs.get('_step',0)}):\\n"
        f"  Inflation score   : {obs.get('inflation',0.5):.3f}  (Ξ” {obs.get('d_inflation',0):+.3f})\\n"
        f"  Employment score  : {obs.get('employment',0.5):.3f}  (Ξ” {obs.get('d_employment',0):+.3f})\\n"
        f"  Health score      : {obs.get('health',0.5):.3f}  (Ξ” {obs.get('d_health',0):+.3f})\\n"
        f"  Satisfaction score: {obs.get('satisfaction',0.5):.3f}  (Ξ” {obs.get('d_satisfaction',0):+.3f})\\n"
        f"  Crime score       : {obs.get('crime',0.5):.3f}  (Ξ” {obs.get('d_crime',0):+.3f})\\n\\n"
        "OUTPUT FORMAT (all values 0.0–1.0, no other text):\\n"
        "tax: 0.X, jobs: 0.X, healthcare: 0.X, education: 0.X, infrastructure: 0.X"
    )


def parse_action(text: str) -> dict:
    \"\"\"5-dimensional action parser with robust regex.\"\"\"
    keys = ["tax", "jobs", "healthcare", "education", "infrastructure"]

    def extract(key: str) -> float:
        m = re.search(rf"{key}\\s*:\\s*(\\d*\\.?\\d+)", text)
        if m:
            try:
                return float(np.clip(float(m.group(1)), 0.0, 1.0))
            except ValueError:
                pass
        return 0.5

    return {k: extract(k) for k in keys}


# ── CELL 7: LOAD MODEL WITH LoRA ─────────────────────────────────────────────
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft         import LoraConfig, get_peft_model, TaskType

MODEL_NAME = "gpt2"   # swap to "gpt2-medium" or "distilgpt2" as needed

tokenizer            = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token  = tokenizer.eos_token
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype = torch.bfloat16 if USE_BF16 else torch.float32,
)

# ── Attach LoRA adapters (reduces trainable params by ~90%) ──────────────────
lora_cfg = LoraConfig(
    task_type    = TaskType.CAUSAL_LM,
    r            = 8,           # rank
    lora_alpha   = 32,
    target_modules = ["c_attn"],  # GPT-2 attention projection
    lora_dropout = 0.05,
    bias         = "none",
)
model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()
model = model.to(DEVICE)

console.rule("[bold green]Model Ready")
console.log(f"[cyan]Model      : {MODEL_NAME} + LoRA (r=8)")
console.log(f"[cyan]Parameters : {sum(p.numel() for p in model.parameters())/1e6:.1f}M total")


# ── CELL 8: BUILD TRAINING DATASET FROM REAL DATA ────────────────────────────
from datasets import Dataset

NUM_SAMPLES = 300
records     = []
env_tmp     = AdvancedCivicAIEnv(fetcher)
task_list   = list(AdvancedCivicAIEnv.TASKS.keys())
country_list= list(RealTimeDataFetcher.COUNTRIES.keys())

for i in range(NUM_SAMPLES):
    task    = task_list[i % len(task_list)]
    country = country_list[i % len(country_list)]
    obs     = env_tmp.reset(task, country)
    records.append({
        "prompt"  : build_prompt(obs),
        "task"    : task,
        "country" : country,
    })

train_dataset = Dataset.from_list(records)
console.log(f"[green]βœ“ Dataset: {len(train_dataset)} prompts across "
            f"{len(task_list)} tasks Γ— {len(country_list)} countries")
console.log(f"  Sample:\\n{train_dataset[0]['prompt'][:300]}...")


# ── CELL 9: MULTI-OBJECTIVE REWARD FUNCTION ───────────────────────────────────
def civic_reward_advanced(prompts, completions, task=None, country=None, **kwargs) -> List[float]:
    \"\"\"
    Multi-objective GRPO reward function.
    Scores: environment reward + format compliance + consistency bonus.
    \"\"\"
    rewards    = []
    env_r      = AdvancedCivicAIEnv(fetcher)
    task_list_ = task   if isinstance(task,   list) else [task]   * len(prompts)
    country_   = country if isinstance(country,list) else [country]* len(prompts)

    for i, (prompt, completion) in enumerate(zip(prompts, completions)):
        # Extract text
        if isinstance(completion, list) and len(completion) > 0:
            text = completion[0].get("content", "")
        else:
            text = str(completion)

        action = parse_action(text)

        # Environment reward
        t = (task_list_[i] if task_list_[i] else "maximize_wellbeing")
        c = (country_[i]   if country_[i]   else "USA")
        env_r.reset(t, c)
        _, env_rew, _, _ = env_r.step(action)

        # Format reward: all 5 keys present
        keys_found = sum(
            1 for k in ["tax","jobs","healthcare","education","infrastructure"]
            if re.search(rf"{k}\\s*:\\s*\\d", text)
        )
        fmt_bonus  = (keys_found / 5.0) * 0.15    # up to +0.15

        # Diversity bonus: penalise all-same values (lazy policy)
        vals = list(action.values())
        div_bonus  = float(np.std(vals)) * 0.1     # up to ~+0.05

        total = float(env_rew) + fmt_bonus + div_bonus
        rewards.append(round(total, 5))

    return rewards


# ── CELL 10: GRPO CONFIG (VERSION-SAFE) ──────────────────────────────────────
from trl import GRPOConfig, GRPOTrainer

valid_params = set(inspect.signature(GRPOConfig.__init__).parameters)

all_kwargs = {
    "output_dir"                  : "checkpoints/civicai-grpo",
    "num_train_epochs"            : 3,
    "per_device_train_batch_size" : 2,
    "num_generations"             : 2,
    "max_prompt_length"           : 300,
    "max_completion_length"       : 80,
    "learning_rate"               : 5e-6,
    "logging_steps"               : 5,
    "save_strategy"               : "epoch",
    "save_total_limit"            : 2,
    "report_to"                   : "none",
    "remove_unused_columns"       : False,
    "bf16"                        : USE_BF16,
    "fp16"                        : USE_FP16,
    "gradient_accumulation_steps" : 4,
    "max_grad_norm"               : 0.3,
    "warmup_ratio"                : 0.05,
    "lr_scheduler_type"           : "cosine",
    "dataloader_num_workers"      : 0,
}

safe_kwargs = {k: v for k, v in all_kwargs.items() if k in valid_params}
skipped     = set(all_kwargs) - set(safe_kwargs)
if skipped:
    console.log(f"[yellow]Skipped unsupported GRPOConfig args: {skipped}")

grpo_config = GRPOConfig(**safe_kwargs)

trainer = GRPOTrainer(
    model            = model,
    args             = grpo_config,
    reward_funcs     = civic_reward_advanced,
    train_dataset    = train_dataset,
    processing_class = tokenizer,
)
console.log("[green]βœ“ GRPOTrainer initialised with LoRA + multi-objective reward")


# ── CELL 11: TRAINING ─────────────────────────────────────────────────────────
console.rule("[bold cyan]Starting GRPO Training")
start_time = time.time()
trainer.train()
elapsed    = time.time() - start_time
console.rule(f"[bold green]Training Complete β€” {elapsed/60:.1f} min")


# ── CELL 12: EXTRACT & PLOT TRAINING METRICS ─────────────────────────────────
logs = trainer.state.log_history
df_logs = pd.DataFrame(logs).dropna(subset=["loss"] if "loss" in pd.DataFrame(logs).columns else [])

reward_entries = [e for e in logs if "reward" in e]
rewards_logged = [e["reward"] for e in reward_entries]
steps_logged   = [e.get("step", i) for i, e in enumerate(reward_entries)]

fig = make_subplots(rows=1, cols=2,
                    subplot_titles=("Reward Curve", "Reward Distribution"))

# Reward over steps
fig.add_trace(go.Scatter(
    x=steps_logged, y=rewards_logged,
    mode="lines", name="Reward", line=dict(color="#00d4ff", width=2)
), row=1, col=1)

# Smoothed
if len(rewards_logged) > 5:
    smooth = np.convolve(rewards_logged, np.ones(5)/5, mode="valid")
    fig.add_trace(go.Scatter(
        x=steps_logged[4:], y=smooth,
        mode="lines", name="Smoothed",
        line=dict(color="#ff6b6b", width=2, dash="dash")
    ), row=1, col=1)

# Histogram
fig.add_trace(go.Histogram(
    x=rewards_logged, nbinsx=20,
    marker_color="#00d4ff", opacity=0.75, name="Distribution"
), row=1, col=2)

fig.update_layout(
    title="CivicAI GRPO Training Metrics",
    template="plotly_dark", height=420,
    paper_bgcolor="#0d1117", font=dict(color="#e6edf3"),
)
fig.show()
fig.write_html("assets/training_metrics.html")

if rewards_logged:
    console.print(f"[cyan]Start reward : {rewards_logged[0]:.4f}")
    console.print(f"[cyan]Final reward : {rewards_logged[-1]:.4f}")
    console.print(f"[green]Improvement  : {rewards_logged[-1]-rewards_logged[0]:+.4f}")


# ── CELL 13: MULTI-COUNTRY POLICY EVALUATION ─────────────────────────────────
def evaluate_trained_policy(
    model, tokenizer, fetcher,
    countries: List[str] = None,
    tasks: List[str]     = None,
    episodes: int        = 5,
) -> pd.DataFrame:
    \"\"\"Evaluate trained policy on all countries Γ— all tasks.\"\"\"
    countries = countries or list(RealTimeDataFetcher.COUNTRIES.keys())
    tasks     = tasks     or list(AdvancedCivicAIEnv.TASKS.keys())
    model.eval()
    results   = []

    for country in countries:
        for task in tasks:
            ep_rewards = []
            env_eval   = AdvancedCivicAIEnv(fetcher, default_country=country)

            for _ in range(episodes):
                obs         = env_eval.reset(task, country)
                ep_reward   = 0.0
                for _ in range(20):
                    prompt = build_prompt(obs)
                    inputs = tokenizer(prompt, return_tensors="pt",
                                       truncation=True, max_length=300).to(DEVICE)
                    with torch.no_grad():
                        out = model.generate(
                            **inputs, max_new_tokens=60,
                            do_sample=False,
                            pad_token_id=tokenizer.eos_token_id,
                        )
                    gen_tokens = out[0][inputs["input_ids"].shape[1]:]
                    text       = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                    action     = parse_action(text)
                    obs, r, done, _ = env_eval.step(action)
                    ep_reward  += r
                    if done: break
                ep_rewards.append(ep_reward / 20)

            results.append({
                "country"  : RealTimeDataFetcher.COUNTRIES[country]["name"],
                "task"     : task,
                "mean_r"   : round(float(np.mean(ep_rewards)), 4),
                "std_r"    : round(float(np.std(ep_rewards)),  4),
                "max_r"    : round(float(np.max(ep_rewards)),  4),
            })
            console.log(f"[dim]{country} / {task} β†’ {results[-1]['mean_r']:.4f}")

    return pd.DataFrame(results)


def baseline_score(fetcher, episodes=5):
    \"\"\"Fixed 0.5 policy baseline.\"\"\"
    env_b, total = AdvancedCivicAIEnv(fetcher, "USA"), []
    for _ in range(episodes):
        obs = env_b.reset("maximize_wellbeing", "USA")
        r   = 0.0
        for _ in range(20):
            obs, rew, done, _ = env_b.step(
                {k: 0.5 for k in ["tax","jobs","healthcare","education","infrastructure"]}
            )
            r += rew
        total.append(r / 20)
    return float(np.mean(total))


console.rule("[bold cyan]Evaluating Policy Across Countries & Tasks")
df_eval  = evaluate_trained_policy(model, tokenizer, fetcher, episodes=3)
baseline = baseline_score(fetcher)

console.print(df_eval.to_string(index=False))
console.print(f"\\n[bold]Baseline (fixed 0.5) : {baseline:.4f}")
console.print(f"[bold green]Best trained score   : {df_eval['mean_r'].max():.4f}")


# ── CELL 14: EVALUATION HEATMAP ──────────────────────────────────────────────
pivot = df_eval.pivot(index="country", columns="task", values="mean_r")

fig_heat = go.Figure(go.Heatmap(
    z          = pivot.values,
    x          = pivot.columns.tolist(),
    y          = pivot.index.tolist(),
    colorscale = "RdYlGn",
    text       = np.round(pivot.values, 3),
    texttemplate="%{text}",
    showscale  = True,
    zmin=0.4, zmax=1.0,
))
fig_heat.add_shape(
    type="line", x0=-0.5, x1=len(pivot.columns)-0.5,
    y0=-0.5, y1=len(pivot.index)-0.5,
    line=dict(color="white", width=0)
)
fig_heat.update_layout(
    title  = "Policy Performance Heatmap β€” Country Γ— Task (GRPO Trained)",
    template="plotly_dark", height=400,
    paper_bgcolor="#0d1117", font=dict(color="#e6edf3"),
    xaxis_title="Task", yaxis_title="Country",
)
fig_heat.show()
fig_heat.write_html("assets/eval_heatmap.html")


# ── CELL 15: SAVE EVERYTHING ─────────────────────────────────────────────────
# Save LoRA adapter only (lightweight)
model.save_pretrained("checkpoints/civicai-lora")
tokenizer.save_pretrained("checkpoints/civicai-lora")

# Save results JSON
results_json = {
    "run_timestamp"  : datetime.now().isoformat(),
    "model"          : MODEL_NAME,
    "lora_rank"      : 8,
    "training_epochs": 3,
    "num_countries"  : len(RealTimeDataFetcher.COUNTRIES),
    "num_tasks"      : len(AdvancedCivicAIEnv.TASKS),
    "data_source"    : "World Bank Open API (live)",
    "baseline_reward": round(baseline, 4),
    "best_reward"    : round(float(df_eval["mean_r"].max()), 4),
    "improvement"    : round(float(df_eval["mean_r"].max()) - baseline, 4),
    "reward_history" : rewards_logged,
    "eval_by_country_task": df_eval.to_dict(orient="records"),
    "real_data_snapshot"  : df_world[["country","inflation_pct","unemployment_pct",
                                       "health_exp_gdp","gdp_growth"]].to_dict(orient="records"),
}

with open("assets/training_results.json", "w") as f:
    json.dump(results_json, f, indent=2)

console.rule("[bold green]All Done")
console.print(f"[green]βœ“ LoRA checkpoint  β†’ checkpoints/civicai-lora/")
console.print(f"[green]βœ“ Results JSON     β†’ assets/training_results.json")
console.print(f"[green]βœ“ Dashboard HTML   β†’ assets/global_dashboard.html")
console.print(f"[green]βœ“ Training metrics β†’ assets/training_metrics.html")
console.print(f"[green]βœ“ Eval heatmap     β†’ assets/eval_heatmap.html")
console.print(f"\\n[bold cyan]Baseline  : {baseline:.4f}")
console.print(f"[bold green]Best score: {df_eval['mean_r'].max():.4f}")
console.print(f"[bold green]Delta     : {df_eval['mean_r'].max() - baseline:+.4f}")
"""

# Notebook cells accumulate here; the first one is a markdown title cell.
# nbformat stores a cell's text as a list of lines, each ending in a real
# newline (not the two characters backslash + "n") except the last.
cells = [
    {
        "cell_type": "markdown",
        "metadata": {},
        "source": [
            "# 🏛 CivicAI Advanced — Senior ML Engineer Edition\n",
            "**Real-time Economic Data + GRPO + LoRA + Multi-Country + Live Dashboard**",
        ],
    }
]

def _cell_title(header):
    """Return the bare title of a ``# ── CELL N: ... ─────`` banner line.

    Strips ``#``, spaces and every box-drawing rule character from both
    ends, e.g. ``"# ── CELL 11: TRAINING ────"`` -> ``"CELL 11: TRAINING"``.
    """
    return header.strip("#─ ")


def _notebook_source(text):
    """Split *text* into the line list used by a notebook cell's ``source``.

    Every line keeps a trailing newline except the last one. Splitting is
    done on the newline character only (not ``str.splitlines``) so that
    other characters in the code are never mistaken for line breaks.
    """
    lines = text.split("\n")
    return [line + "\n" for line in lines[:-1]] + lines[-1:]


# Split the code by cells. The capturing group keeps each banner line in the
# result, so the list alternates [preamble, header1, body1, header2, body2, ...]
# and a header can never be paired with the wrong body.
_parts = re.split(r'(# ── CELL \d+.*)\n', code)

# _parts[0] is everything before CELL 1 and is intentionally dropped.
for header_text, chunk in zip(_parts[1::2], _parts[2::2]):
    cells.append({
        "cell_type": "markdown",
        "metadata": {},
        "source": [f"### {_cell_title(header_text)}"],
    })
    # Remove trailing and leading newlines / whitespace.
    chunk = chunk.strip()

    # The pip-install cell contains triple-quote delimiters; strip every one
    # of them from that cell, as before.
    if "pip install" in chunk and '"""' in chunk:
        chunk = chunk.replace('"""', '').strip()

    cells.append({
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "outputs": [],
        "source": _notebook_source(chunk),
    })

# Top-level nbformat-4 document: the cells built above plus kernel metadata.
notebook_metadata = {
    "colab": {"name": "CivicAI_Training.ipynb"},
    "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
    "language_info": {"name": "python", "version": "3.10"},
}
notebook = dict(
    cells=cells,
    metadata=notebook_metadata,
    nbformat=4,
    nbformat_minor=4,
)

# NOTE(review): absolute, machine-specific output path; consider deriving it
# from the script location so the generator runs on other machines.
NOTEBOOK_PATH = "c:/Users/mdaft/OneDrive/Desktop/GitHub Projects/AI_Society_Simulator/CivicAI_Training.ipynb"
with open(NOTEBOOK_PATH, "w", encoding='utf-8') as out_file:
    json.dump(notebook, out_file, indent=2)