from __future__ import annotations

# ── Path bootstrap ──────────────────────────────────────────────────────────
import sys
from pathlib import Path

_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

# ── Load .env ────────────────────────────────────────────────────────────────
from dotenv import load_dotenv
load_dotenv(dotenv_path=_ROOT / ".env", override=False)

import argparse
import json
import os
import random as _random
import re
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any

from app.env import GovWorkflowEnv
from app.models import (
    ActionModel,
    ActionType,
    ObservationModel,
    PriorityMode,
    ServiceType,
    StepInfoModel,
)
from app.tasks import get_task, list_tasks
from app.api_gateway import create_env_gateway, TransportMode


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 1 — Model Registry & Per-Task Pools
# ══════════════════════════════════════════════════════════════════════════════

NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"

# ── Global 10-Model Sequential Pool (April 2026 — Verified on NVIDIA NIM) ────
#
# CHANGES FROM PREVIOUS VERSION:
#   REMOVED (invalid/unavailable IDs):
#     qwen/qwen3-next-80b-a3b-instruct      → invalid model ID
#     moonshotai/kimi-k2-instruct-0905      → not on NVIDIA NIM
#     deepseek-ai/deepseek-v3.2             → wrong ID (use deepseek-v3)
#     google/gemma-3-27b-it                 → outdated (gemma-4 released)
#     mistralai/mixtral-8x22b-instruct-v0.1 → replaced by newer models
#   ADDED (verified April 2026):
#     deepseek-ai/deepseek-v4-flash         → FREE endpoint, 1M context
#     deepseek-ai/deepseek-r1               → reasoning, 685B MoE
#     nvidia/nemotron-3-super-120b-a12b     → hybrid Mamba-Transformer, 1M ctx
#     minimaxai/minimax-m2.7                → FREE endpoint, 230B
#     google/gemma-4-31b-it                 → latest Gemma on NVIDIA NIM
#     qwen/qwen3.5-122b-a10b                → latest Qwen on NVIDIA NIM

GLOBAL_MODEL_POOL: list[str] = [
    "meta/llama-3.3-70b-instruct",          # 1.  Primary
    "deepseek-ai/deepseek-v4-flash",        # 2.  FREE endpoint — 1M context
    "deepseek-ai/deepseek-r1",              # 3.  Reasoning — 685B MoE
    "nvidia/nemotron-3-super-120b-a12b",    # 4.  NVIDIA native — 1M ctx
    "qwen/qwen3.5-122b-a10b",               # 5.  Qwen3.5 — tool calling
    "deepseek-ai/deepseek-v3",              # 6.  DeepSeek V3 — hybrid mode
    "minimaxai/minimax-m2.7",               # 7.  FREE endpoint — 230B
    "google/gemma-4-31b-it",                # 8.  Dense 31B — agentic workflows
    "microsoft/phi-4-mini-instruct",        # 9.  Reliable small — last resort
    "meta/llama-3.1-8b-instruct",           # 10. Fastest safety fallback
]

# ── Free endpoint pool (KEY 2 — NVIDIA_API_KEY_2 fallback) ───────────────────
FREE_POOL: list[str] = [
    "deepseek-ai/deepseek-v4-flash",
    "minimaxai/minimax-m2.7",
    "microsoft/phi-4-mini-instruct",
    "meta/llama-3.1-8b-instruct",
]
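
# Key routing: models listed in FREE_POOL authenticate with NVIDIA_API_KEY_2
# (key id 2); every other pool entry uses NVIDIA_API_KEY (key id 1).
# See ModelRotator.current_key_id and LLMAgent.act for the switch.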

# ── Fixed seeds ────────────────────────────────────────────────────────────────
TASK_SEEDS: dict[str, int] = {
    "district_backlog_easy": 11,
    "mixed_urgency_medium":  22,
    "cross_department_hard": 33,
}

LLM_TEMPERATURE = 0.2
LLM_TOP_P       = 0.7
LLM_MAX_TOKENS  = 512
MAX_LLM_STEPS   = 80

LLM_CALL_DELAY  = float(os.environ.get("LLM_CALL_DELAY", "12.0"))
LLM_CALL_JITTER = 1.0
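
# Pacing: each LLM step sleeps max(1.0, LLM_CALL_DELAY + uniform(-1, +1)) s
# (see run_episode), so the default 12.0 s delay yields an 11-13 s pause.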

# ── Enum fields that MUST be lowercase for Pydantic StrEnum ──────────────────
_ENUM_FIELDS = {"action_type", "priority_mode", "service", "target_service"}

# ── Canonical field names (Phase 2 update — do NOT use legacy names) ─────────
#   CORRECT                        WRONG (legacy)
#   snap.blocked_missing_docs  ←   snap.missing_docs_cases
#   snap.total_pending         ←   snap.active_cases
#   obs.fairness_gap           ←   obs.fairness_index


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 2 — Model Rotator
# ══════════════════════════════════════════════════════════════════════════════

class ModelRotator:
    def __init__(self, task_id: str) -> None:
        self._sequence: list[str] = GLOBAL_MODEL_POOL.copy()
        self._index = 0
        self._task_id = task_id
        self._rotation_log: list[dict[str, str]] = []

    @property
    def current(self) -> str:
        return self._sequence[self._index]

    @property
    def current_key_id(self) -> int:
        return 2 if self.current in FREE_POOL else 1

    @property
    def pool_exhausted(self) -> bool:
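        # Treat 50 logged rotations (five full laps of the 10-model pool) as exhausted.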
        return len(self._rotation_log) >= 50

    def rotate(self, reason: str = "error") -> str:
        old = self.current
        self._rotation_log.append({"from": old, "reason": reason})
        self._index = (self._index + 1) % len(self._sequence)
        new = self._sequence[self._index]
        print(
            f"\n  🔄 Model rotated: "
            f"{old.split('/')[-1]}  →  {new.split('/')[-1]}  ({reason})"
        )
        return new

    def summary(self) -> list[dict]:
        return list(self._rotation_log)
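
# Rotation walk-through (a sketch; names come from GLOBAL_MODEL_POOL above):
#   rotator = ModelRotator("district_backlog_easy")
#   rotator.current          # "meta/llama-3.3-70b-instruct"   (key id 1)
#   rotator.rotate("RateLimitError")
#   rotator.current          # "deepseek-ai/deepseek-v4-flash" (key id 2, FREE_POOL)
#   rotator.summary()        # [{"from": "meta/llama-3.3-70b-instruct",
#                            #   "reason": "RateLimitError"}]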


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 3 — Result Dataclasses
# ══════════════════════════════════════════════════════════════════════════════

@dataclass
class StepRecord:
    step: int
    day: int
    action_type: str
    reward: float
    invalid: bool
    total_backlog: int
    total_completed: int
    model_used: str
    notes: list[str]


@dataclass
class EpisodeResult:
    task_id: str
    agent: str
    primary_model: str
    seed: int
    score: float
    grader_name: str
    total_steps: int
    total_reward: float
    total_completed: int
    total_sla_breaches: int
    total_invalid_actions: int
    final_day: int
    terminated: bool
    truncated: bool
    grader_metrics: dict[str, float]
    step_log: list[StepRecord]
    elapsed_seconds: float
    model_rotations: list[dict]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def summary(self) -> str:
        usage: dict[str, int] = {}
        for r in self.step_log:
            usage[r.model_used] = usage.get(r.model_used, 0) + 1
        usage_str = ", ".join(
            f"{m.split('/')[-1]} ({c})" for m, c in usage.items()
        )
        return (
            f"[{self.task_id}] agent={self.agent} "
            f"score={self.score:.3f} reward={self.total_reward:.2f} "
            f"completed={self.total_completed} breaches={self.total_sla_breaches} "
            f"invalid={self.total_invalid_actions} "
            f"rotations={len(self.model_rotations)} "
            f"day={self.final_day} steps={self.total_steps} "
            f"time={self.elapsed_seconds:.1f}s\n"
            f"  Model usage: {usage_str}"
        )


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 4 — Direct Environment Wrapper
# ══════════════════════════════════════════════════════════════════════════════

class DirectEnvClient:
    """
    FIX: grade() now calls grade_episode(task_id, episode_state) correctly.
    Previous version called grade_episode(self.env.state()) β€” wrong signature.
    get_episode_state() returns EpisodeStateModel, not ObservationModel.
    """

    def __init__(self, task_id: str, seed: int) -> None:
        self.env = GovWorkflowEnv(task_id=task_id)
        self._seed = seed
        self._task_id = task_id
        self.terminated = False
        self.truncated = False

    def reset(self) -> ObservationModel:
        obs, _ = self.env.reset(seed=self._seed)
        self.terminated = False
        self.truncated = False
        return obs

    def step(
        self, action: ActionModel
    ) -> tuple[ObservationModel, float, bool, bool, StepInfoModel]:
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.terminated = terminated
        self.truncated = truncated
        return obs, reward, terminated, truncated, info

    def grade(self) -> tuple[float, str, dict[str, float]]:
        from app.graders import grade_episode
        episode_state = self.env.get_episode_state()
        result = grade_episode(self._task_id, episode_state)
        return result.score, result.grader_name, result.metrics


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 5 — HTTP Environment Wrapper
# ══════════════════════════════════════════════════════════════════════════════

class HttpEnvClient:
    def __init__(
        self, task_id: str, seed: int, base_url: str = "http://localhost:7860"
    ) -> None:
        try:
            import requests as _req
            self._req = _req
        except ImportError:
            raise ImportError("pip install requests — required for --mode http")
        self._task_id = task_id
        self._seed = seed
        self._base_url = base_url.rstrip("/")
        self._session_id: str | None = None
        self.terminated = False
        self.truncated = False

    def _post(self, path: str, body: dict) -> dict:
        r = self._req.post(
            f"{self._base_url}{path}", json=body, timeout=30
        )
        r.raise_for_status()
        return r.json()

    def reset(self) -> ObservationModel:
        data = self._post("/reset", {"task_id": self._task_id, "seed": self._seed})
        self._session_id = data["session_id"]
        self.terminated = False
        self.truncated = False
        return ObservationModel(**data["observation"])

    def step(
        self, action: ActionModel
    ) -> tuple[ObservationModel, float, bool, bool, StepInfoModel]:
        data = self._post("/step", {
            "session_id": self._session_id,
            "action": action.model_dump(exclude_none=True),
        })
        obs  = ObservationModel(**data["observation"])
        info = StepInfoModel(**data["info"])
        self.terminated = data["terminated"]
        self.truncated  = data["truncated"]
        return obs, data["reward"], data["terminated"], data["truncated"], info

    def grade(self) -> tuple[float, str, dict[str, float]]:
        data = self._post("/grade", {"session_id": self._session_id})
        return data["score"], data["grader_name"], data["metrics"]
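
# Wire-protocol sketch (endpoints and payload fields as used above):
#   POST {base_url}/reset  {"task_id", "seed"}       → {"session_id", "observation"}
#   POST {base_url}/step   {"session_id", "action"}  → {"observation", "reward",
#                                                       "terminated", "truncated", "info"}
#   POST {base_url}/grade  {"session_id"}            → {"score", "grader_name", "metrics"}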


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 6 — Heuristic Baseline Agent
# ══════════════════════════════════════════════════════════════════════════════

class HeuristicAgent:
    """
    Rule-based agent. Requires no API key.

    FIXED field names (Phase 2 canonical):
      snap.blocked_missing_docs  ← was snap.missing_docs_cases
      snap.total_pending         ← was snap.active_cases
    """

    current_model = "heuristic"

    def __init__(self) -> None:
        self._priority_set = False
        self._admin_action_day: int | None = None
        self._last_doc_request_day: int | None = None

    def reset(self) -> None:
        self._priority_set = False
        self._admin_action_day = None
        self._last_doc_request_day = None

    def rotation_summary(self) -> list[dict]:
        return []

    def update_reward(self, _: float) -> None:
        pass

    @staticmethod
    def _svc_key(service: str | ServiceType) -> str:
        return service.value if isinstance(service, ServiceType) else str(service)

    def act(self, obs: ObservationModel) -> ActionModel:
        snapshots = list(obs.queue_snapshots.values())

        # One admin action per simulated day; then always advance time.
        if self._admin_action_day == obs.day:
            return ActionModel(action_type=ActionType.ADVANCE_TIME)

        # 1. Set priority mode once
        if not self._priority_set:
            self._priority_set = True
            self._admin_action_day = obs.day
            return ActionModel(
                action_type=ActionType.SET_PRIORITY_MODE,
                priority_mode=PriorityMode.URGENT_FIRST,
            )

        # 2. Allocate any idle officer to the currently most loaded service.
        if obs.officer_pool.idle_officers > 0 and snapshots:
            most_loaded = max(snapshots, key=lambda s: s.total_pending)
            self._admin_action_day = obs.day
            return ActionModel(
                action_type=ActionType.ASSIGN_CAPACITY,
                capacity_assignment={most_loaded.service_type.value: 1},
            )

        days_left = obs.max_days - obs.day

        # 3. Reallocate one officer if load/officer ratio is clearly imbalanced.
        allocated = {
            self._svc_key(svc): int(off)
            for svc, off in obs.officer_pool.allocated.items()
        }
        if snapshots and len(allocated) >= 2:
            case_counts = {s.service_type.value: s.total_pending for s in snapshots}

            best_src: tuple[str, int] | None = None
            best_tgt: tuple[str, int] | None = None
            src_ratio = float("inf")
            tgt_ratio = -1.0

            for svc, officers in allocated.items():
                if officers <= 1:
                    continue
                ratio = case_counts.get(svc, 0) / max(officers, 1)
                if ratio < src_ratio:
                    src_ratio = ratio
                    best_src = (svc, officers)

            for svc, officers in allocated.items():
                ratio = case_counts.get(svc, 0) / max(officers, 1)
                if ratio > tgt_ratio:
                    tgt_ratio = ratio
                    best_tgt = (svc, officers)

            if best_src and best_tgt and best_src[0] != best_tgt[0] and tgt_ratio > src_ratio * 1.8:
                self._admin_action_day = obs.day
                return ActionModel(
                    action_type=ActionType.REALLOCATE_OFFICERS,
                    reallocation_delta={best_src[0]: -1, best_tgt[0]: 1},
                )

        # 4. Request missing docs conservatively to avoid repeatedly resetting
        # resolution days for already-requested cases.
        can_request_docs = (
            any(s.blocked_missing_docs > 0 for s in snapshots)
            and (
                self._last_doc_request_day is None
                or (obs.day - self._last_doc_request_day) >= 3
                or obs.pending_doc_resolutions == 0
            )
        )
        if can_request_docs:
            target_docs = max(
                snapshots,
                key=lambda s: (s.blocked_missing_docs, s.current_sla_risk, s.total_pending),
            )
            if target_docs.blocked_missing_docs > 0:
                self._admin_action_day = obs.day
                self._last_doc_request_day = obs.day
                return ActionModel(
                    action_type=ActionType.REQUEST_MISSING_DOCUMENTS,
                    service_target=target_docs.service_type,
                )

        # 5. Escalate in the final window when urgency is present.
        if obs.escalation_budget_remaining > 0:
            urgent_snaps = [s for s in snapshots if s.urgent_pending > 0]
            if urgent_snaps and days_left <= 5:
                target = max(urgent_snaps, key=lambda s: s.urgent_pending)
                self._admin_action_day = obs.day
                return ActionModel(
                    action_type=ActionType.ESCALATE_SERVICE,
                    escalation_target=target.service_type,
                )

        # 6. Default β€” progress simulation.
        return ActionModel(action_type=ActionType.ADVANCE_TIME)
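
# Decision ladder: at most one admin action per simulated day (priority mode,
# then idle-capacity assignment, officer rebalancing, document requests,
# late-game escalation), with advance_time in every remaining case.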


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 7 — System Prompt
# ══════════════════════════════════════════════════════════════════════════════

SYSTEM_PROMPT = """You are an expert government-office workflow manager AI.
Your job is to control a simulated government district office processing citizen
applications across multiple services.

SERVICES: passport, driving_license, gst_registration, income_certificate,
          caste_certificate, birth_certificate, land_registration

WORKFLOW STAGES (in order):
  submission → document_verification → field_verification → approval → issuance

YOUR GOAL: Maximise the episode score (0.0 to 1.0) by:
  - Completing as many applications as possible within SLA deadlines
  - Prioritising urgent cases (urgency level 3 > 2 > 1)
  - Keeping all services fairly served (no service left behind)
  - Using escalations sparingly — only when a case is about to breach SLA
  - Keeping officers productively busy (not idle)

QUEUE STATUS FIELDS EXPLAINED:
  backlog      = total_pending (applications waiting in the queue)
  missing_docs = blocked_missing_docs (stuck waiting for documents)
  urgent       = urgent_pending (high-urgency applications)
  breached     = breached_cases (already past SLA deadline)

AVAILABLE ACTIONS — return exactly ONE per turn as JSON:

1. Set queue processing order (do this FIRST, on day 0 only):
   {"action_type": "set_priority_mode", "priority_mode": "urgent_first"}
   priority_mode options: urgent_first | oldest_first | balanced | backlog_clearance

2. Deploy a reserve officer to a service (day 0 only, if reserves are available):
   {"action_type": "assign_capacity", "service": "driving_license", "officer_delta": 1}

3. Unblock stuck applications with missing documents:
   {"action_type": "request_missing_documents", "service": "driving_license"}

4. Escalate one case to emergency priority (VERY LIMITED — use wisely):
   {"action_type": "escalate_service", "service": "income_certificate"}

5. Move an officer between services (only when the load ratio exceeds 4x):
   {"action_type": "reallocate_officers", "service": "birth_certificate",
    "target_service": "driving_license", "officer_delta": 1}

6. Let one working day pass — THE ONLY ACTION THAT PROCESSES APPLICATIONS:
   {"action_type": "advance_time"}

CRITICAL RULES:
  - ALL values MUST be lowercase: driving_license NOT DRIVING_LICENSE
  - advance_time is the ONLY action that earns progress reward
  - Do NOT chain more than 2 admin actions before calling advance_time
  - Do NOT escalate before (max_days - 5) unless the case has already breached SLA
  - Do NOT reallocate if the source service has fewer than 2 officers

OPTIMAL STRATEGY:
  Day 0:     set_priority_mode → assign_capacity (if reserves > 0) → advance_time
  Every day: request_missing_documents (ONE service, highest missing_docs) → advance_time
  Final 5:   escalate_service (urgent/breached only) → advance_time

RESPONSE FORMAT — return ONLY a raw JSON object, nothing else:
  CORRECT:   {"action_type": "advance_time"}
  CORRECT:   {"action_type": "request_missing_documents", "service": "driving_license"}
  WRONG:     ```json\n{"action_type": "ADVANCE_TIME"}```
"""


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 8 — JSON Extraction with Lowercase Normaliser
# ══════════════════════════════════════════════════════════════════════════════

def _extract_json_action(raw: str) -> dict[str, Any]:
    cleaned = re.sub(r"```(?:json)?", "", raw).strip()
    parsed: dict[str, Any] | None = None

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    if parsed is None:
        match = re.search(r"\{[^{}]*\}", cleaned, re.DOTALL)
        if match:
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:
                pass

    if parsed is None:
        print(f"  ⚠ JSON parse failed, falling back to advance_time. Raw: {raw[:120]!r}")
        return {"action_type": "advance_time"}

    for enum_field in _ENUM_FIELDS:
        if enum_field in parsed and isinstance(parsed[enum_field], str):
            parsed[enum_field] = parsed[enum_field].lower()

    return parsed
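
# Behaviour sketch (illustrative inputs; the last case also prints a warning):
#   '```json\n{"action_type": "ADVANCE_TIME"}\n```'  → {"action_type": "advance_time"}
#   'Sure: {"action_type": "escalate_service", "service": "PASSPORT"} done'
#       → {"action_type": "escalate_service", "service": "passport"}
#   'no json here at all'                            → {"action_type": "advance_time"}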


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 9 — Observation → User Message Builder
# ══════════════════════════════════════════════════════════════════════════════

def _build_user_message(
    obs: ObservationModel, step_num: int, cumulative_reward: float
) -> str:
    """
    FIXED field names (Phase 2 canonical):
      snap.total_pending        ← was snap.active_cases
      snap.blocked_missing_docs ← was snap.missing_docs_cases
    """
    queue_lines = []
    for snap in obs.queue_snapshots:
        officers = obs.officer_pool.allocations.get(snap.service, 0)
        queue_lines.append(
            f"  {snap.service:<22}: "
            f"backlog={snap.total_pending:>3} "
            f"officers={officers} "
            f"missing_docs={snap.blocked_missing_docs:>2} "
            f"urgent={snap.urgent_cases} "
            f"breached={snap.breached_cases} "
            f"avg_age={snap.avg_age_days:.1f}d"
        )
    return (
        f"STEP {step_num} | Day {obs.day}/{obs.max_days} "
        f"| Days remaining: {obs.max_days - obs.day}\n"
        f"Cumulative reward: {cumulative_reward:.2f}\n"
        f"Priority mode: {obs.priority_mode}\n"
        f"Reserve officers: {obs.officer_pool.reserve_officers}\n"
        f"Escalation budget remaining: {obs.escalation_budget_remaining}\n"
        f"Total pending: {obs.total_backlog} "
        f"| Completed: {obs.total_completed} "
        f"| SLA breaches: {obs.total_sla_breaches}\n"
        f"Fairness gap: {obs.fairness_gap:.3f}\n\n"
        f"QUEUE STATUS:\n" + "\n".join(queue_lines) + "\n\n"
        f"Return a single JSON action object. All values lowercase."
    )
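
# Example rendered queue line (numbers illustrative):
#   driving_license       : backlog= 14 officers=3 missing_docs= 2 urgent=1 breached=0 avg_age=3.4d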


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 10 — LLM Agent with Model Rotation
# ══════════════════════════════════════════════════════════════════════════════

class LLMAgent:
    def __init__(
        self,
        task_id: str,
        model_override: str | None = None,
        api_key: str | None = None,
    ) -> None:
        try:
            from openai import OpenAI
            self._OpenAI = OpenAI
        except ImportError:
            raise ImportError("pip install openai — required for LLM agent")

        resolved_key = api_key or os.environ.get("NVIDIA_API_KEY", "")
        self._api_key_2 = os.environ.get("NVIDIA_API_KEY_2", "")

        if not resolved_key:
            raise ValueError(
                "NVIDIA_API_KEY not set.\n"
                "  .env file : NVIDIA_API_KEY=nvapi-xxxxxxxxxxxx\n"
                "  Get free key: https://build.nvidia.com/explore/discover"
            )

        self._api_key = resolved_key
        self._task_id = task_id
        self._rotator = ModelRotator(task_id)

        if model_override:
            seq = [model_override] + [
                m for m in self._rotator._sequence if m != model_override
            ]
            self._rotator._sequence = seq

        self._client = self._OpenAI(base_url=NVIDIA_BASE_URL, api_key=self._api_key)
        self._client_2 = (
            self._OpenAI(base_url=NVIDIA_BASE_URL, api_key=self._api_key_2)
            if self._api_key_2 else None
        )
        self._history: list[dict[str, str]] = []
        self._cumulative_reward = 0.0

    @property
    def current_model(self) -> str:
        return self._rotator.current

    def reset(self) -> None:
        self._history = []
        self._cumulative_reward = 0.0
        self._rotator = ModelRotator(self._task_id)

    def update_reward(self, reward: float) -> None:
        self._cumulative_reward += reward

    def rotation_summary(self) -> list[dict]:
        return self._rotator.summary()

    def act(self, obs: ObservationModel, step_num: int) -> ActionModel:
        if self._rotator.pool_exhausted:
            print("  ⚠ Pool exhausted — returning advance_time")
            return ActionModel(action_type=ActionType.ADVANCE_TIME)

        user_message = _build_user_message(obs, step_num, self._cumulative_reward)
        self._history.append({"role": "user", "content": user_message})

        if len(self._history) > 20:
            self._history = self._history[-20:]

        messages = [{"role": "system", "content": SYSTEM_PROMPT}] + self._history
        raw_reply = ""

        while True:
            try:
                active_client = self._client
                if self._rotator.current_key_id == 2 and self._client_2:
                    active_client = self._client_2

                response = active_client.chat.completions.create(
                    model=self._rotator.current,
                    messages=messages,
                    temperature=LLM_TEMPERATURE,
                    top_p=LLM_TOP_P,
                    max_tokens=LLM_MAX_TOKENS,
                    timeout=30,
                )
                raw_reply = response.choices[0].message.content or ""
                break

            except KeyboardInterrupt:
                raise

            except Exception as exc:
                err_name = type(exc).__name__
                err_msg  = str(exc)[:120]
                print(f"  ⚠ {err_name} on {self._rotator.current.split('/')[-1]}: {err_msg}")
                self._rotator.rotate(reason=err_name)
                time.sleep(1.0)
                if self._rotator.pool_exhausted:
                    return ActionModel(action_type=ActionType.ADVANCE_TIME)

        self._history.append({"role": "assistant", "content": raw_reply})
        action_dict = _extract_json_action(raw_reply)

        try:
            return ActionModel(**action_dict)
        except Exception as exc:
            print(f"  ⚠ ActionModel parse failed ({exc}), using advance_time")
            return ActionModel(action_type=ActionType.ADVANCE_TIME)
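
# Usage sketch (requires NVIDIA_API_KEY in the environment or .env):
#   agent = LLMAgent("district_backlog_easy")
#   agent.reset()
#   action = agent.act(obs, step_num=1)   # obs comes from a client.reset() call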


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 11 — Episode Runner
# ══════════════════════════════════════════════════════════════════════════════

def run_episode(
    task_id: str,
    agent_type: str,
    model_override: str | None,
    mode: TransportMode,
    server_url: str,
    api_key: str | None,
    verbose: bool,
    max_steps: int = MAX_LLM_STEPS,
    delay_override: float | None = None,
) -> EpisodeResult:
    seed  = TASK_SEEDS.get(task_id, get_task(task_id).seed)
    delay = delay_override if delay_override is not None else LLM_CALL_DELAY

    force_fastapi = os.getenv("FORCE_FASTAPI_GATEWAY", "0").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    env_api_prefix = os.getenv("OPENENV_ENV_API_PREFIX", "").strip()
    client = create_env_gateway(
        task_id=task_id,
        seed=seed,
        mode=mode,  # type: ignore[arg-type]
        base_url=server_url,
        api_prefix=env_api_prefix,
        enforce_fastapi=force_fastapi,
    )

    if agent_type == "llm":
        agent: HeuristicAgent | LLMAgent = LLMAgent(
            task_id=task_id,
            model_override=model_override,
            api_key=api_key,
        )
        primary_label = agent.current_model
    else:
        agent = HeuristicAgent()
        primary_label = "heuristic"

    agent.reset()
    obs = client.reset()

    step_log: list[StepRecord] = []
    total_reward = 0.0
    total_invalid = 0
    step_num = 0
    start = time.perf_counter()

    print(f"\n{'═'*65}")
    print(f"  Task  : {task_id}")
    if agent_type == "llm":
        k1 = "✅ loaded" if os.environ.get("NVIDIA_API_KEY", "") else "❌ MISSING"
        k2 = "✅ loaded" if os.environ.get("NVIDIA_API_KEY_2", "") else "⚠ not set"
        print(f"  KEY 1 : {k1}   KEY 2 : {k2}")
        pool_short = " → ".join(m.split("/")[-1][:14] for m in GLOBAL_MODEL_POOL)
        print(f"  Pool  : {pool_short}")
    resolved_mode = getattr(client, "transport", mode)
    print(f"  Agent : {agent_type}  |  Mode: {resolved_mode}  |  Seed: {seed}")
    print(f"  Max steps: {max_steps}  |  Delay: {delay}s")
    print(f"{'═'*65}")

    while not (client.terminated or client.truncated) and step_num < max_steps:
        step_num += 1
        current_model = agent.current_model

        if agent_type == "llm":
            action = agent.act(obs, step_num)
        else:
            action = agent.act(obs)

        obs, reward, terminated, truncated, info = client.step(action)
        agent.update_reward(reward)

        total_reward += reward
        if info.invalid_action:
            total_invalid += 1

        step_notes: list[str] = []
        legacy_notes = getattr(info, "notes", None)
        if isinstance(legacy_notes, list):
            step_notes.extend(str(n).strip() for n in legacy_notes if str(n).strip())
        elif isinstance(legacy_notes, str) and legacy_notes.strip():
            step_notes.append(legacy_notes.strip())

        if info.action_explanation.strip():
            step_notes.append(info.action_explanation.strip())
        step_notes.extend(s.strip() for s in info.effects_resolved_this_step if s.strip())
        step_notes = list(dict.fromkeys(step_notes))

        record = StepRecord(
            step=step_num,
            day=obs.day,
            action_type=action.action_type.value,
            reward=round(reward, 4),
            invalid=info.invalid_action,
            total_backlog=obs.total_backlog,
            total_completed=obs.total_completed,
            model_used=current_model,
            notes=step_notes,
        )
        step_log.append(record)

        if verbose:
            status    = "❌" if info.invalid_action else "✅"
            model_tag = (
                f"[{current_model.split('/')[-1][:22]}]"
                if agent_type == "llm" else ""
            )
            print(
                f"  step={step_num:3d} day={obs.day:2d} "
                f"action={action.action_type.value:<28} "
                f"reward={reward:+.3f}  {status}  {model_tag}"
            )
            if step_notes:
                print(f"         notes: {step_notes}")

        if agent_type == "llm":
            actual_delay = delay + _random.uniform(-LLM_CALL_JITTER, LLM_CALL_JITTER)
            if not verbose:
                print(
                    f"  Step {step_num}/{max_steps} β€” sleeping {actual_delay:.1f}s "
                    f"[{current_model.split('/')[-1][:20]}]",
                    end="\r", flush=True,
                )
            time.sleep(max(1.0, actual_delay))
            if not verbose:
                print(" " * 80, end="\r", flush=True)

    score, grader_name, grader_metrics = client.grade()
    elapsed = round(time.perf_counter() - start, 2)
    rotations = agent.rotation_summary()

    print(f"\n{'-'*65}")
    print(f"  SCORE  : {score:.3f} / 1.000  (grader: {grader_name})")
    print(f"  Reward : {total_reward:.2f}  |  Steps: {step_num}")
    print(f"  Completed: {obs.total_completed}  |  SLA breaches: {obs.total_sla_breaches}")
    print(f"  Invalid actions: {total_invalid}  |  Model rotations: {len(rotations)}")
    print(f"  Time: {elapsed}s")
    print(f"  Grader metrics:")
    for metric, value in grader_metrics.items():
        bar = "█" * int(value * 20)
        print(f"    {metric:<34} {value:.3f}  {bar}")
    if rotations:
        print(f"  Rotation log:")
        for r in rotations:
            print(f"    {r['from'].split('/')[-1]:<30} β†’ rotated ({r['reason']})")
    print(f"{'-'*65}")

    return EpisodeResult(
        task_id=task_id,
        agent=agent_type,
        primary_model=primary_label,
        seed=seed,
        score=score,
        grader_name=grader_name,
        total_steps=step_num,
        total_reward=round(total_reward, 4),
        total_completed=obs.total_completed,
        total_sla_breaches=obs.total_sla_breaches,
        total_invalid_actions=total_invalid,
        final_day=obs.day,
        terminated=client.terminated,
        truncated=client.truncated,
        grader_metrics=grader_metrics,
        step_log=step_log,
        elapsed_seconds=elapsed,
        model_rotations=rotations,
    )


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 12 — Reporter
# ══════════════════════════════════════════════════════════════════════════════

def save_results(results: list[EpisodeResult], out_dir: Path) -> Path:
    out_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"baseline_run_{ts}.json"
    payload = {
        "run_timestamp": datetime.now().isoformat(),
        "total_episodes": len(results),
        "average_score": round(sum(r.score for r in results) / len(results), 4),
        "model_pool": GLOBAL_MODEL_POOL,
        "free_pool": FREE_POOL,
        "episodes": [asdict(r) for r in results],
    }
    out_path.write_text(json.dumps(payload, indent=2))
    return out_path
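
# With main()'s default out_dir this writes results/baseline_run_YYYYMMDD_HHMMSS.json;
# "episodes" holds one asdict(EpisodeResult) per task, full step log included.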


def print_leaderboard(results: list[EpisodeResult]) -> None:
    print(f"\n{'═'*72}")
    print("  LEADERBOARD")
    print(f"{'═'*72}")
    header = (
        f"  {'TASK':<32} {'MODEL':<24} {'SCORE':>7}  "
        f"{'REWARD':>8}  {'DONE':>5}  {'ROT':>4}"
    )
    print(header)
    print(f"  {'-'*32} {'-'*24} {'-'*7}  {'-'*8}  {'-'*5}  {'-'*4}")
    for r in sorted(results, key=lambda x: -x.score):
        model_label = r.primary_model.split("/")[-1][:23]
        print(
            f"  {r.task_id:<32} {model_label:<24} {r.score:>7.3f}  "
            f"{r.total_reward:>8.2f}  {r.total_completed:>5}  "
            f"{len(r.model_rotations):>4}"
        )
    avg = sum(r.score for r in results) / len(results)
    print(f"  {'-'*32} {'-'*24} {'-'*7}  {'-'*8}  {'-'*5}  {'-'*4}")
    print(f"  {'AVERAGE':<32} {'':<24} {avg:>7.3f}")
    print(f"{'═'*72}\n")


# ══════════════════════════════════════════════════════════════════════════════
# SECTION 13 — CLI Entry Point
# ══════════════════════════════════════════════════════════════════════════════

def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Gov Workflow OpenEnv — Multi-Model Rotating LLM Baseline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
10-model pool (April 2026):
  llama-3.3-70b → deepseek-v4-flash → deepseek-r1 → nemotron-3-super →
  qwen3.5-122b → deepseek-v3 → minimax-m2.7 → gemma-4-31b →
  phi-4-mini → llama-3.1-8b

Examples:
  python baseline_openai.py --agent heuristic --verbose
  python baseline_openai.py --agent llm --task district_backlog_easy --verbose
  python baseline_openai.py --agent llm --task all --save-results
  python baseline_openai.py --agent llm --model deepseek-ai/deepseek-v4-flash
  python baseline_openai.py --mode http --url http://localhost:7860 --agent llm
  python baseline_openai.py --mode auto --url http://localhost:7860 --agent llm
        """,
    )
    p.add_argument("--agent", choices=["llm", "heuristic"], default="heuristic")
    p.add_argument("--task", choices=list_tasks() + ["all"], default="all")
    p.add_argument("--model", default=None)
    p.add_argument("--mode", choices=["direct", "http", "auto"], default="auto")
    p.add_argument("--url", default="http://localhost:7860")
    p.add_argument("--max-steps", type=int, default=MAX_LLM_STEPS)
    p.add_argument("--delay", type=float, default=None)
    p.add_argument("--api-key", default=None)
    p.add_argument("--verbose", action="store_true")
    p.add_argument("--save-results", action="store_true")
    return p


def main() -> None:
    args = build_parser().parse_args()
    tasks = list_tasks() if args.task == "all" else [args.task]

    print(f"\n{'═'*65}")
    print("  Gov Workflow OpenEnv β€” Baseline Runner (April 2026)")
    print(f"  Agent : {args.agent.upper()}")
    if args.agent == "llm":
        pool_disp = " → ".join(m.split("/")[-1][:12] for m in GLOBAL_MODEL_POOL)
        print(f"  Pool  : {pool_disp}")
    print(f"  Mode  : {args.mode}  |  Tasks: {', '.join(tasks)}")
    print(f"{'═'*65}")

    if args.agent == "llm":
        key = args.api_key or os.environ.get("NVIDIA_API_KEY", "")
        if not key:
            print("\n❌  NVIDIA_API_KEY not set.")
            print("    .env file  : NVIDIA_API_KEY=nvapi-xxxx")
            print("    PowerShell : $env:NVIDIA_API_KEY='nvapi-xxxx'")
            print("    Get free key: https://build.nvidia.com/explore/discover\n")
            sys.exit(1)
    else:
        key = None

    results: list[EpisodeResult] = []
    for task_id in tasks:
        result = run_episode(
            task_id=task_id,
            agent_type=args.agent,
            model_override=args.model,
            mode=args.mode,
            server_url=args.url,
            api_key=key,
            verbose=args.verbose,
            max_steps=args.max_steps,
            delay_override=args.delay,
        )
        results.append(result)

    print_leaderboard(results)

    if args.save_results:
        out = save_results(results, Path("results"))
        print(f"  Results saved β†’ {out}\n")


if __name__ == "__main__":
    main()