ZENLLC committed on
Commit b83ea71 · verified · Parent: d55a6fe

Create app.py

Files changed (1): app.py +1838 -0
app.py ADDED
@@ -0,0 +1,1838 @@
import json
import math
import hashlib
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
from PIL import Image, ImageDraw

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

import gradio as gr

# ============================================================
# ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena
#
# Additions in this version:
# - Autoplay (Start/Stop) via gr.Timer (watch agents live)
# - One-click "Cinematic Run" (full episode in one click)
# - Example presets (env+seed) + seed controls
# - Autoplay is interruptible: manual buttons still work anytime
#
# Matplotlib HF-safe: uses canvas.buffer_rgba()
# ============================================================

# -----------------------------
# Global config
# -----------------------------
GRID_W, GRID_H = 21, 15
TILE = 22

VIEW_W, VIEW_H = 640, 360
RAY_W = 320
FOV_DEG = 78
MAX_DEPTH = 20

DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [0, 90, 180, 270]

# Tiles
EMPTY = 0
WALL = 1
FOOD = 2
NOISE = 3
DOOR = 4
TELE = 5
KEY = 6
EXIT = 7
ARTIFACT = 8
HAZARD = 9
WOOD = 10
ORE = 11
MEDKIT = 12
SWITCH = 13
BASE = 14

TILE_NAMES = {
    EMPTY: "Empty",
    WALL: "Wall",
    FOOD: "Food",
    NOISE: "Noise",
    DOOR: "Door",
    TELE: "Teleporter",
    KEY: "Key",
    EXIT: "Exit",
    ARTIFACT: "Artifact",
    HAZARD: "Hazard",
    WOOD: "Wood",
    ORE: "Ore",
    MEDKIT: "Medkit",
    SWITCH: "Switch",
    BASE: "Base",
}

AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),
    "Alpha": (255, 205, 120),
    "Bravo": (160, 210, 255),
    "Guardian": (255, 120, 220),
    "BuilderA": (140, 255, 200),
    "BuilderB": (160, 200, 255),
    "Raider": (255, 160, 120),
}

SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([140, 210, 255], dtype=np.uint8)

# Small action space
ACTIONS = ["L", "F", "R", "I"]  # L=turn left, F=forward, R=turn right, I=interact

TRACE_MAX = 500
MAX_HISTORY = 1400

# -----------------------------
# Deterministic RNG
# -----------------------------
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
    return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF)

def stable_stream(name: str) -> int:
    # Python's hash() of a str is salted per process, which would break the
    # determinism promised above; derive per-agent stream ids from a stable
    # digest instead.
    return int(hashlib.sha256(name.encode("utf-8")).hexdigest()[:8], 16) % 1000
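
# A minimal sketch (added for illustration, never called by the app): the same
# (seed, step, stream) triple should always reproduce the same draws.
def _demo_rng_determinism() -> bool:
    a = rng_for(seed=42, step=7, stream=3).integers(0, 100, size=4)
    b = rng_for(seed=42, step=7, stream=3).integers(0, 100, size=4)
    return bool(np.array_equal(a, b))  # expected: True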

# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Agent:
    name: str
    x: int
    y: int
    ori: int
    hp: int = 10
    energy: int = 100
    team: str = "A"
    brain: str = "q"  # q | heuristic | random
    inventory: Dict[str, int] = None

    def __post_init__(self):
        if self.inventory is None:
            self.inventory = {}

@dataclass
class TrainConfig:
    use_q: bool = True
    alpha: float = 0.15
    gamma: float = 0.95
    epsilon: float = 0.10
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995

    step_penalty: float = -0.01
    explore_reward: float = 0.015
    damage_penalty: float = -0.20
    heal_reward: float = 0.10

    chase_close_coeff: float = 0.03
    chase_catch_reward: float = 3.0
    chase_escaped_reward: float = 0.2
    chase_caught_penalty: float = -3.0
    food_reward: float = 0.6

    artifact_pick_reward: float = 1.2
    exit_win_reward: float = 3.0
    guardian_tag_reward: float = 2.0
    tagged_penalty: float = -2.0
    switch_reward: float = 0.8
    key_reward: float = 0.4

    resource_pick_reward: float = 0.15
    deposit_reward: float = 0.4
    base_progress_win_reward: float = 3.5
    raider_elim_reward: float = 2.0
    builder_elim_penalty: float = -2.0

@dataclass
class GlobalMetrics:
    episodes: int = 0
    wins_teamA: int = 0
    wins_teamB: int = 0
    draws: int = 0
    avg_steps: float = 0.0
    rolling_winrate_A: float = 0.0
    epsilon: float = 0.10
    last_outcome: str = "init"
    last_steps: int = 0

@dataclass
class EpisodeMetrics:
    steps: int = 0
    returns: Dict[str, float] = None
    action_counts: Dict[str, Dict[str, int]] = None
    tiles_discovered: Dict[str, int] = None

    def __post_init__(self):
        if self.returns is None:
            self.returns = {}
        if self.action_counts is None:
            self.action_counts = {}
        if self.tiles_discovered is None:
            self.tiles_discovered = {}

@dataclass
class WorldState:
    seed: int
    step: int
    env_key: str
    grid: List[List[int]]
    agents: Dict[str, Agent]

    controlled: str
    pov: str
    overlay: bool

    done: bool
    outcome: str  # A_win | B_win | draw | ongoing

    door_opened_global: bool = False
    base_progress: int = 0
    base_target: int = 10

    event_log: List[str] = None
    trace_log: List[str] = None

    cfg: TrainConfig = None
    q_tables: Dict[str, Dict[str, List[float]]] = None
    gmetrics: GlobalMetrics = None
    emetrics: EpisodeMetrics = None

    def __post_init__(self):
        if self.event_log is None:
            self.event_log = []
        if self.trace_log is None:
            self.trace_log = []
        if self.cfg is None:
            self.cfg = TrainConfig()
        if self.q_tables is None:
            self.q_tables = {}
        if self.gmetrics is None:
            self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon)
        if self.emetrics is None:
            self.emetrics = EpisodeMetrics()

@dataclass
class Snapshot:
    branch: str
    step: int
    env_key: str
    grid: List[List[int]]
    agents: Dict[str, Dict[str, Any]]
    done: bool
    outcome: str
    door_opened_global: bool
    base_progress: int
    base_target: int
    event_tail: List[str]
    trace_tail: List[str]
    emetrics: Dict[str, Any]
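
# A small illustrative check (not wired into the app): the dataclasses here
# are designed to survive an asdict()/**kwargs round trip, which snapshot_of()
# and restore_into() further down rely on.
def _demo_agent_roundtrip() -> bool:
    a = Agent("Scout", x=3, y=4, ori=1)
    return Agent(**asdict(a)) == a  # expected: True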

# -----------------------------
# Helpers
# -----------------------------
def in_bounds(x: int, y: int) -> bool:
    return 0 <= x < GRID_W and 0 <= y < GRID_H

def is_blocking(tile: int, door_open: bool = False) -> bool:
    if tile == WALL:
        return True
    if tile == DOOR and not door_open:
        return True
    return False

def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int:
    return abs(ax - bx) + abs(ay - by)

def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    dx = abs(x1 - x0)
    dy = abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx - dy
    x, y = x0, y0
    while True:
        if (x, y) != (x0, y0) and (x, y) != (x1, y1):
            if grid[y][x] == WALL:
                return False
        if x == x1 and y == y1:
            return True
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy

def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    dx = tx - observer.x
    dy = ty - observer.y
    if dx == 0 and dy == 0:
        return True
    angle = math.degrees(math.atan2(dy, dx)) % 360
    facing = ORI_DEG[observer.ori]
    diff = (angle - facing + 540) % 360 - 180
    return abs(diff) <= (fov_deg / 2)

def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
    if not within_fov(observer, target.x, target.y, FOV_DEG):
        return False
    return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y)

def hash_sha256(txt: str) -> str:
    return hashlib.sha256(txt.encode("utf-8")).hexdigest()

# -----------------------------
# Beliefs
# -----------------------------
def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
    return {nm: (-1 * np.ones((GRID_H, GRID_W), dtype=np.int16)) for nm in agent_names}

def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int:
    before_unknown = int(np.sum(belief == -1))

    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    rays = 45 if agent.name.lower().startswith("scout") else 33

    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        ox, oy = agent.x + 0.5, agent.y + 0.5
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            tile = state.grid[ty][tx]
            if tile == WALL:
                break
            if tile == DOOR and not state.door_opened_global:
                break

    after_unknown = int(np.sum(belief == -1))
    return max(0, before_unknown - after_unknown)

# -----------------------------
# Rendering
# -----------------------------
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    img[:, :] = SKY

    for y in range(VIEW_H // 2, VIEW_H):
        t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6)
        col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR
        img[y, :] = col.astype(np.uint8)

    fov = math.radians(FOV_DEG)
    half_fov = fov / 2

    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov

        ox, oy = observer.x + 0.5, observer.y + 0.5
        sin_a = math.sin(ray_ang)
        cos_a = math.cos(ray_ang)

        depth = 0.0
        hit = None
        side = 0

        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR and not state.door_opened_global:
                hit = "door"
                break

        if hit is None:
            continue

        depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
        depth = max(depth, 0.001)

        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, VIEW_H // 2 - proj_h // 2)
        y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)

        if hit == "door":
            col = DOOR_COL.copy()
        else:
            col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()

        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
        col = (col * dim).astype(np.uint8)

        x0 = int(rx * (VIEW_W / RAY_W))
        x1 = int((rx + 1) * (VIEW_W / RAY_W))
        img[y0:y1, x0:x1] = col

    for nm, other in state.agents.items():
        if nm == observer.name or other.hp <= 0:
            continue
        if visible(state, observer, other):
            dx = other.x - observer.x
            dy = other.y - observer.y
            ang = (math.degrees(math.atan2(dy, dx)) % 360)
            facing = ORI_DEG[observer.ori]
            diff = (ang - facing + 540) % 360 - 180
            sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
            dist = math.sqrt(dx * dx + dy * dy)
            h = int((VIEW_H * 0.65) / max(dist, 0.75))
            w = max(10, h // 3)
            y_mid = VIEW_H // 2
            y0 = max(0, y_mid - h // 2)
            y1 = min(VIEW_H - 1, y_mid + h // 2)
            x0 = max(0, sx - w // 2)
            x1 = min(VIEW_W - 1, sx + w // 2)
            col = AGENT_COLORS.get(nm, (255, 200, 120))
            img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)

    if state.overlay:
        cx, cy = VIEW_W // 2, VIEW_H // 2
        img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)

    return img

def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    w = grid.shape[1] * TILE
    h = grid.shape[0] * TILE
    im = Image.new("RGB", (w, h + 28), (10, 12, 18))
    draw = ImageDraw.Draw(im)

    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            t = int(grid[y, x])
            if t == -1:
                col = (18, 20, 32)
            elif t == EMPTY:
                col = (26, 30, 44)
            elif t == WALL:
                col = (190, 190, 210)
            elif t == FOOD:
                col = (255, 210, 120)
            elif t == NOISE:
                col = (255, 120, 220)
            elif t == DOOR:
                col = (140, 210, 255)
            elif t == TELE:
                col = (120, 190, 255)
            elif t == KEY:
                col = (255, 235, 160)
            elif t == EXIT:
                col = (120, 255, 220)
            elif t == ARTIFACT:
                col = (255, 170, 60)
            elif t == HAZARD:
                col = (255, 90, 90)
            elif t == WOOD:
                col = (170, 120, 60)
            elif t == ORE:
                col = (140, 140, 160)
            elif t == MEDKIT:
                col = (120, 255, 140)
            elif t == SWITCH:
                col = (200, 180, 255)
            elif t == BASE:
                col = (220, 220, 240)
            else:
                col = (80, 80, 90)

            x0, y0 = x * TILE, y * TILE + 28
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)

    for x in range(grid.shape[1] + 1):
        xx = x * TILE
        draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22))
    for y in range(grid.shape[0] + 1):
        yy = y * TILE + 28
        draw.line([0, yy, w, yy], fill=(12, 14, 22))

    if show_agents:
        for nm, a in agents.items():
            if a.hp <= 0:
                continue
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + 28 + TILE // 2
            col = AGENT_COLORS.get(nm, (220, 220, 220))
            r = TILE // 3
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)

    draw.rectangle([0, 0, w, 28], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im
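
# Usage sketch (hypothetical helper, not part of the UI wiring): render one
# POV frame and the truth map to PNG files. It calls init_state(), which is
# defined further down, so it must only run after the module is fully loaded.
def _demo_render(seed: int = 1337) -> None:
    st = init_state(seed, "chase")
    Image.fromarray(raycast_view(st, st.agents[st.pov])).save("pov_demo.png")
    truth = np.array(st.grid, dtype=np.int16)
    render_topdown(truth, st.agents, "Truth Map demo").save("truth_demo.png")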

# -----------------------------
# Environments
# -----------------------------
def grid_with_border() -> List[List[int]]:
    g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
    for x in range(GRID_W):
        g[0][x] = WALL
        g[GRID_H - 1][x] = WALL
    for y in range(GRID_H):
        g[y][0] = WALL
        g[y][GRID_W - 1] = WALL
    return g

def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()
    for x in range(4, 17):
        g[7][x] = WALL
    g[7][10] = DOOR
    g[3][4] = FOOD
    g[11][15] = FOOD
    g[4][14] = NOISE
    g[12][5] = NOISE
    g[2][18] = TELE
    g[13][2] = TELE

    agents = {
        "Predator": Agent("Predator", 2, 2, 0, hp=10, energy=100, team="A", brain="q"),
        "Prey": Agent("Prey", 18, 12, 2, hp=10, energy=100, team="B", brain="q"),
        "Scout": Agent("Scout", 10, 3, 1, hp=10, energy=100, team="A", brain="heuristic"),
    }
    return g, agents

def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()
    for x in range(3, 18):
        g[5][x] = WALL
    for x in range(3, 18):
        g[9][x] = WALL
    g[5][10] = DOOR
    g[9][12] = DOOR

    g[2][2] = KEY
    g[12][18] = EXIT
    g[12][2] = ARTIFACT
    g[2][18] = TELE
    g[13][2] = TELE
    g[7][10] = SWITCH
    g[3][15] = HAZARD
    g[11][6] = MEDKIT
    g[2][12] = FOOD

    agents = {
        "Alpha": Agent("Alpha", 2, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Bravo": Agent("Bravo", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Guardian": Agent("Guardian", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents

def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()
    for y in range(3, 12):
        g[y][9] = WALL
    g[7][9] = DOOR

    g[2][3] = WOOD
    g[3][3] = WOOD
    g[4][3] = WOOD
    g[12][16] = ORE
    g[11][16] = ORE
    g[10][16] = ORE
    g[6][4] = FOOD
    g[8][15] = FOOD

    g[13][10] = BASE
    g[4][15] = HAZARD
    g[10][4] = HAZARD
    g[2][18] = TELE
    g[13][2] = TELE
    g[2][2] = KEY
    g[12][6] = SWITCH

    agents = {
        "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "BuilderB": Agent("BuilderB", 4, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Raider": Agent("Raider", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents

ENV_BUILDERS = {"chase": env_chase, "vault": env_vault, "civ": env_civ}
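
# Example (sketch, unused): every environment builder returns a bordered grid
# plus its cast of agents. The seed parameter is accepted for interface
# uniformity, but these layouts are currently fixed.
def _demo_env_counts() -> Dict[str, int]:
    g, agents = ENV_BUILDERS["vault"](seed=0)
    flat = [t for row in g for t in row]
    return {"agents": len(agents), "doors": flat.count(DOOR)}  # {'agents': 3, 'doors': 2}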

# -----------------------------
# Observation / Q-learning
# -----------------------------
def local_tile_ahead(state: WorldState, a: Agent) -> int:
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return WALL
    return state.grid[ny][nx]

def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
    best = None
    for _, other in state.agents.items():
        if other.hp <= 0:
            continue
        if other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if best is None or d < best[0]:
            best = (d, other.x - a.x, other.y - a.y)
    if best is None:
        return (99, 0, 0)
    d, dx, dy = best
    return (d, int(np.clip(dx, -6, 6)), int(np.clip(dy, -6, 6)))

def obs_key(state: WorldState, who: str) -> str:
    a = state.agents[who]
    d, dx, dy = nearest_enemy_vec(state, a)
    ahead = local_tile_ahead(state, a)
    keys = a.inventory.get("key", 0)
    art = a.inventory.get("artifact", 0)
    wood = a.inventory.get("wood", 0)
    ore = a.inventory.get("ore", 0)
    inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}"
    door = 1 if state.door_opened_global else 0
    return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}"

def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    if key not in q:
        q[key] = [0.0 for _ in ACTIONS]
    return q[key]

def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    if r.random() < eps:
        return int(r.integers(0, len(qvals)))
    return int(np.argmax(qvals))

def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str,
             alpha: float, gamma: float) -> Tuple[float, float, float]:
    qv = q_get(q, key)
    nq = q_get(q, next_key)
    old = qv[a_idx]
    target = reward + gamma * float(np.max(nq))
    new = old + alpha * (target - old)
    qv[a_idx] = new
    return old, target, new
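
# Worked example (sketch, unused): one tabular Q-learning backup,
#   Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)).
# With Q(s,a)=0, r=1.0, max Q(s')=2.0, alpha=0.15, gamma=0.95:
#   target = 1.0 + 0.95*2.0 = 2.9 and new Q = 0 + 0.15*(2.9 - 0) = 0.435.
def _demo_q_update() -> Tuple[float, float, float]:
    q: Dict[str, List[float]] = {"s2": [0.0, 2.0, 0.0, 0.0]}
    return q_update(q, "s1", 1, reward=1.0, next_key="s2", alpha=0.15, gamma=0.95)
    # expected: (0.0, 2.9, 0.435)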

# -----------------------------
# Baseline heuristics
# -----------------------------
def heuristic_action(state: WorldState, who: str) -> str:
    # stable_stream (defined with rng_for) replaces hash(who), whose str
    # hashing is process-salted and was silently non-deterministic.
    r = rng_for(state.seed, state.step, stream=900 + stable_stream(who))
    a = state.agents[who]

    t_here = state.grid[a.y][a.x]
    if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT):
        return "I"

    best = None
    best_d = 999
    for _, other in state.agents.items():
        if other.hp <= 0 or other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if d < best_d:
            best_d = d
            best = other

    if best is not None and best_d <= 6 and visible(state, a, best):
        dx = best.x - a.x
        dy = best.y - a.y
        ang = (math.degrees(math.atan2(dy, dx)) % 360)
        facing = ORI_DEG[a.ori]
        diff = (ang - facing + 540) % 360 - 180
        if diff < -10:
            return "L"
        if diff > 10:
            return "R"
        return "F"

    return r.choice(["F", "F", "L", "R", "I"])

def random_action(state: WorldState, who: str) -> str:
    r = rng_for(state.seed, state.step, stream=700 + stable_stream(who))
    return r.choice(ACTIONS)

# -----------------------------
# Movement + interaction
# -----------------------------
def turn_left(a: Agent) -> None:
    a.ori = (a.ori - 1) % 4

def turn_right(a: Agent) -> None:
    a.ori = (a.ori + 1) % 4

def move_forward(state: WorldState, a: Agent) -> str:
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    tile = state.grid[ny][nx]
    if is_blocking(tile, door_open=state.door_opened_global):
        return "blocked: wall/door"
    a.x, a.y = nx, ny

    if state.grid[ny][nx] == TELE:
        teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(teles) >= 2:
            teles_sorted = sorted(teles)
            idx = teles_sorted.index((nx, ny))
            dest = teles_sorted[(idx + 1) % len(teles_sorted)]
            a.x, a.y = dest
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"

def try_interact(state: WorldState, a: Agent) -> str:
    t = state.grid[a.y][a.x]

    if t == SWITCH:
        state.door_opened_global = True
        state.grid[a.y][a.x] = EMPTY
        a.inventory["switch"] = a.inventory.get("switch", 0) + 1
        return "switch: opened all doors"

    if t == KEY:
        a.inventory["key"] = a.inventory.get("key", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: key"

    if t == ARTIFACT:
        a.inventory["artifact"] = a.inventory.get("artifact", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: artifact"

    if t == FOOD:
        a.energy = min(200, a.energy + 35)
        state.grid[a.y][a.x] = EMPTY
        return "ate: food"

    if t == WOOD:
        a.inventory["wood"] = a.inventory.get("wood", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: wood"

    if t == ORE:
        a.inventory["ore"] = a.inventory.get("ore", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: ore"

    if t == MEDKIT:
        a.hp = min(10, a.hp + 3)
        state.grid[a.y][a.x] = EMPTY
        return "used: medkit"

    if t == BASE:
        w = a.inventory.get("wood", 0)
        o = a.inventory.get("ore", 0)
        dep = min(w, 2) + min(o, 2)
        if dep > 0:
            a.inventory["wood"] = max(0, w - min(w, 2))
            a.inventory["ore"] = max(0, o - min(o, 2))
            state.base_progress += dep
            return f"deposited: +{dep} base_progress"
        return "base: nothing to deposit"

    if t == EXIT:
        return "at_exit"

    return "interact: none"

def apply_action(state: WorldState, who: str, action: str) -> str:
    a = state.agents[who]
    if a.hp <= 0:
        return "dead"
    if action == "L":
        turn_left(a)
        return "turned left"
    if action == "R":
        turn_right(a)
        return "turned right"
    if action == "F":
        return move_forward(state, a)
    if action == "I":
        return try_interact(state, a)
    return "noop"
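
# Behavior sketch (unused): teleporters are paired by sorted coordinate order,
# so stepping onto one TELE tile relocates the agent to the next one in the
# sorted cycle. A minimal check on the chase map (init_state is defined below,
# so call this only after the module is fully loaded):
def _demo_teleport() -> Tuple[int, int]:
    st = init_state(1, "chase")      # chase has TELE at (18, 2) and (2, 13)
    a = st.agents["Scout"]
    a.x, a.y, a.ori = 17, 2, 0       # stand next to (18, 2), facing +x
    move_forward(st, a)
    return (a.x, a.y)                # expected: (2, 13)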

# -----------------------------
# Hazards / collisions / done
# -----------------------------
def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
    if a.hp <= 0:
        return (False, "")
    if state.grid[a.y][a.x] == HAZARD:
        a.hp -= 1
        return (True, "hazard:-hp")
    return (False, "")

def resolve_tags(state: WorldState) -> List[str]:
    msgs = []
    occupied: Dict[Tuple[int, int], List[str]] = {}
    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        occupied.setdefault((a.x, a.y), []).append(nm)

    for (x, y), names in occupied.items():
        if len(names) < 2:
            continue
        teams = set(state.agents[n].team for n in names)
        if len(teams) >= 2:
            for n in names:
                state.agents[n].hp -= 1
            msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
    return msgs

def check_done(state: WorldState) -> None:
    if state.env_key == "chase":
        pred = state.agents["Predator"]
        prey = state.agents["Prey"]
        if pred.hp <= 0 and prey.hp <= 0:
            state.done = True
            state.outcome = "draw"
            return
        if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y:
            state.done = True
            state.outcome = "A_win"
            state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).")
            return
        if state.step >= 300 and prey.hp > 0:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: ESCAPED (Prey survives).")
            return

    if state.env_key == "vault":
        for nm in ["Alpha", "Bravo"]:
            a = state.agents[nm]
            if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
                state.done = True
                state.outcome = "A_win"
                state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).")
                return
        alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: TEAM A ELIMINATED (Guardian wins).")
            return

    if state.env_key == "civ":
        if state.base_progress >= state.base_target:
            state.done = True
            state.outcome = "A_win"
            state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).")
            return
        alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).")
            return
        if state.step >= 350:
            state.done = True
            state.outcome = "draw"
            state.event_log.append(f"t={state.step}: TIMEOUT (draw).")
            return

# -----------------------------
# Rewards
# -----------------------------
def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float:
    cfg = now.cfg
    r = cfg.step_penalty
    if outcome_msg.startswith("moved"):
        r += cfg.explore_reward
    if took_damage:
        r += cfg.damage_penalty
    if outcome_msg.startswith("used: medkit"):
        r += cfg.heal_reward

    if now.env_key == "chase":
        pred = now.agents["Predator"]
        prey = now.agents["Prey"]
        if who == "Predator":
            d0 = manhattan_xy(prev.agents["Predator"].x, prev.agents["Predator"].y,
                              prev.agents["Prey"].x, prev.agents["Prey"].y)
            d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
            r += cfg.chase_close_coeff * float(d0 - d1)
            if now.done and now.outcome == "A_win":
                r += cfg.chase_catch_reward
        if who == "Prey":
            if outcome_msg.startswith("ate: food"):
                r += cfg.food_reward
            if now.done and now.outcome == "B_win":
                r += cfg.chase_escaped_reward
            if now.done and now.outcome == "A_win":
                r += cfg.chase_caught_penalty

    if now.env_key == "vault":
        if outcome_msg.startswith("picked: artifact"):
            r += cfg.artifact_pick_reward
        if outcome_msg.startswith("picked: key"):
            r += cfg.key_reward
        if outcome_msg.startswith("switch:"):
            r += cfg.switch_reward
        if now.done:
            if now.outcome == "A_win" and now.agents[who].team == "A":
                r += cfg.exit_win_reward
            if now.outcome == "B_win" and now.agents[who].team == "B":
                r += cfg.guardian_tag_reward
            if now.outcome == "B_win" and now.agents[who].team == "A":
                r += cfg.tagged_penalty

    if now.env_key == "civ":
        if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"):
            r += cfg.resource_pick_reward
        if outcome_msg.startswith("deposited:"):
            r += cfg.deposit_reward
        if now.done:
            if now.outcome == "A_win" and now.agents[who].team == "A":
                r += cfg.base_progress_win_reward
            if now.outcome == "B_win" and now.agents[who].team == "B":
                r += cfg.raider_elim_reward
            if now.outcome == "B_win" and now.agents[who].team == "A":
                r += cfg.builder_elim_penalty

    return float(r)

# -----------------------------
# Policy selection
# -----------------------------
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    a = state.agents[who]
    cfg = state.cfg
    r = rng_for(state.seed, state.step, stream=stream)

    if a.brain == "random":
        act = random_action(state, who)
        return act, "random", None
    if a.brain == "heuristic":
        act = heuristic_action(state, who)
        return act, "heuristic", None

    if cfg.use_q:
        key = obs_key(state, who)
        qtab = state.q_tables.setdefault(who, {})
        qv = q_get(qtab, key)
        a_idx = epsilon_greedy(qv, state.gmetrics.epsilon, r)
        return ACTIONS[a_idx], f"Q eps={state.gmetrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (key, a_idx)

    act = heuristic_action(state, who)
    return act, "heuristic(fallback)", None

# -----------------------------
# Init / reset
# -----------------------------
def init_state(seed: int, env_key: str) -> WorldState:
    g, agents = ENV_BUILDERS[env_key](seed)
    st = WorldState(
        seed=seed,
        step=0,
        env_key=env_key,
        grid=g,
        agents=agents,
        controlled=list(agents.keys())[0],
        pov=list(agents.keys())[0],
        overlay=False,
        done=False,
        outcome="ongoing",
        door_opened_global=False,
        base_progress=0,
        base_target=10,
    )
    st.event_log = [f"Initialized env={env_key} seed={seed}."]
    return st

def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState:
    if seed is None:
        seed = state.seed
    fresh = init_state(int(seed), state.env_key)
    fresh.cfg = state.cfg
    fresh.q_tables = state.q_tables
    fresh.gmetrics = state.gmetrics
    fresh.gmetrics.epsilon = state.gmetrics.epsilon
    return fresh

def wipe_all(seed: int, env_key: str) -> WorldState:
    st = init_state(seed, env_key)
    st.cfg = TrainConfig()
    st.gmetrics = GlobalMetrics(epsilon=st.cfg.epsilon)
    st.q_tables = {}
    return st

# -----------------------------
# History / branching
# -----------------------------
def snapshot_of(state: WorldState, branch: str) -> Snapshot:
    return Snapshot(
        branch=branch,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: asdict(v) for k, v in state.agents.items()},
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_tail=state.event_log[-25:],
        trace_tail=state.trace_log[-40:],
        emetrics=asdict(state.emetrics),
    )

def restore_into(state: WorldState, snap: Snapshot) -> WorldState:
    state.step = snap.step
    state.env_key = snap.env_key
    state.grid = [row[:] for row in snap.grid]
    state.agents = {k: Agent(**d) for k, d in snap.agents.items()}
    state.done = snap.done
    state.outcome = snap.outcome
    state.door_opened_global = snap.door_opened_global
    state.base_progress = snap.base_progress
    state.base_target = snap.base_target
    state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).")
    return state
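
# Round-trip sketch (unused): a Snapshot deep-copies the grid and agents, so
# mutating the live state afterwards must not alter the snapshot.
def _demo_snapshot_roundtrip() -> bool:
    st = init_state(7, "chase")
    snap = snapshot_of(st, "main")
    st.agents["Prey"].hp = 1          # mutate the live state
    restore_into(st, snap)
    return st.agents["Prey"].hp == 10  # expected: True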

# -----------------------------
# Metrics / dashboard
# -----------------------------
def metrics_dashboard_image(state: WorldState) -> Image.Image:
    gm = state.gmetrics

    fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
    ax = fig.add_subplot(111)

    x1 = max(1, gm.episodes)
    ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A])
    ax.set_title("Global Metrics Snapshot")
    ax.set_xlabel("Episodes")
    ax.set_ylabel("Rolling winrate Team A")
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True)

    txt = (
        f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
        f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n"
        f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
    )
    ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")

    fig.tight_layout()
    canvas = FigureCanvas(fig)
    canvas.draw()
    buf = np.asarray(canvas.buffer_rgba())
    img = Image.fromarray(buf, mode="RGBA").convert("RGB")
    plt.close(fig)
    return img

def action_entropy(counts: Dict[str, int]) -> float:
    total = sum(counts.values())
    if total <= 0:
        return 0.0
    p = np.array([c / total for c in counts.values()], dtype=np.float64)
    p = np.clip(p, 1e-12, 1.0)
    return float(-np.sum(p * np.log2(p)))

def agent_scoreboard(state: WorldState) -> str:
    rows = []
    header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]
    rows.append(header)
    steps = state.emetrics.steps

    for nm, a in state.agents.items():
        ret = state.emetrics.returns.get(nm, 0.0)
        counts = state.emetrics.action_counts.get(nm, {})
        ent = action_entropy(counts)
        td = state.emetrics.tiles_discovered.get(nm, 0)
        qs = len(state.q_tables.get(nm, {}))
        inv = json.dumps(a.inventory, sort_keys=True)
        rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv])

    col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))]
    lines = []
    for ridx, r in enumerate(rows):
        line = " | ".join(str(r[i]).ljust(col_w[i]) for i in range(len(header)))
        lines.append(line)
        if ridx == 0:
            lines.append("-+-".join("-" * w for w in col_w))
    return "\n".join(lines)
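
# Worked example (unused): Shannon entropy in bits over the action histogram.
# A uniform spread over the 4 actions gives log2(4) = 2.0 bits; a policy that
# always picks "F" gives 0.0 bits.
def _demo_entropy() -> Tuple[float, float]:
    uniform = {"L": 5, "F": 5, "R": 5, "I": 5}
    collapsed = {"F": 20}
    return action_entropy(uniform), action_entropy(collapsed)  # (2.0, ~0.0)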

# -----------------------------
# Tick / training
# -----------------------------
def clone_shallow(state: WorldState) -> WorldState:
    return WorldState(
        seed=state.seed,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: Agent(**asdict(v)) for k, v in state.agents.items()},
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,
        q_tables=state.q_tables,
        gmetrics=state.gmetrics,
        emetrics=state.emetrics,
    )

def update_action_counts(state: WorldState, who: str, act: str):
    state.emetrics.action_counts.setdefault(who, {})
    state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1

def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
    if state.done:
        return

    prev = clone_shallow(state)
    chosen: Dict[str, str] = {}
    reasons: Dict[str, str] = {}
    qinfo: Dict[str, Optional[Tuple[str, int]]] = {}

    if manual_action is not None:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None

    order = list(state.agents.keys())
    for who in order:
        if who in chosen:
            continue
        # stable per-agent stream id (see the rng_for section)
        act, reason, qi = choose_action(state, who, stream=200 + stable_stream(who))
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = qi

    outcomes: Dict[str, str] = {}
    took_damage: Dict[str, bool] = {nm: False for nm in order}

    for who in order:
        outcomes[who] = apply_action(state, who, chosen[who])
        dmg, msg = resolve_hazards(state, state.agents[who])
        took_damage[who] = dmg
        if msg:
            state.event_log.append(f"t={state.step}: {who} {msg}")
        update_action_counts(state, who, chosen[who])

    for m in resolve_tags(state):
        state.event_log.append(m)

    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        disc = update_belief_for_agent(state, beliefs[nm], a)
        state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc

    check_done(state)

    q_lines = []
    for who in order:
        state.emetrics.returns.setdefault(who, 0.0)

        r = reward_for(prev, state, who, outcomes[who], took_damage[who])
        state.emetrics.returns[who] += r

        if qinfo.get(who) is not None:
            key, a_idx = qinfo[who]
            next_key = obs_key(state, who)
            qtab = state.q_tables.setdefault(who, {})
            old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
            q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")

    trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
    for who in order:
        a = state.agents[who]
        trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]"
    if q_lines:
        trace += " | Q: " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]

    state.step += 1
    state.emetrics.steps = state.step

def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
    while state.step < max_steps and not state.done:
        tick(state, beliefs, manual_action=None)
    return state.outcome, state.step

def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int):
    gm = state.gmetrics
    gm.episodes += 1
    gm.last_outcome = outcome
    gm.last_steps = steps

    if outcome == "A_win":
        gm.wins_teamA += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 1.0
    elif outcome == "B_win":
        gm.wins_teamB += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.0
    else:
        gm.draws += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5

    gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps)
    gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)

def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    for ep in range(episodes):
        ep_seed = (state.seed * 1_000_003 + (state.gmetrics.episodes + ep) * 97_531) & 0xFFFFFFFF
        state = reset_episode_keep_learning(state, seed=int(ep_seed))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)

    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    state = reset_episode_keep_learning(state, seed=state.seed)
    return state
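
# Usage sketch (unused): headless training without the UI. Q-tables and the
# global metrics persist across the episode resets inside train().
def _demo_train() -> GlobalMetrics:
    st = wipe_all(1337, "chase")
    st = train(st, episodes=5, max_steps=120)
    return st.gmetrics  # episodes == 5, epsilon decayed from 0.10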

# -----------------------------
# Export / Import
# -----------------------------
def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str:
    payload = {
        "seed": state.seed,
        "env_key": state.env_key,
        "controlled": state.controlled,
        "pov": state.pov,
        "overlay": state.overlay,
        "cfg": asdict(state.cfg),
        "gmetrics": asdict(state.gmetrics),
        "q_tables": state.q_tables,
        "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()},
        "active_branch": active_branch,
        "rewind_idx": int(rewind_idx),
        "grid": state.grid,
        "door_opened_global": state.door_opened_global,
        "base_progress": state.base_progress,
        "base_target": state.base_target,
    }
    txt = json.dumps(payload, indent=2)
    proof = hash_sha256(txt)
    return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2)

def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
    parts = txt.strip().split("\n\n")
    data = json.loads(parts[0])

    st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase"))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)
    st.door_opened_global = bool(data.get("door_opened_global", False))
    st.base_progress = int(data.get("base_progress", 0))
    st.base_target = int(data.get("base_target", 10))

    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics)))
    st.q_tables = data.get("q_tables", {})

    branches_in = data.get("branches", {})
    branches: Dict[str, List[Snapshot]] = {}
    for bname, snaps in branches_in.items():
        branches[bname] = [Snapshot(**s) for s in snaps]

    active = data.get("active_branch", "main")
    r_idx = int(data.get("rewind_idx", 0))

    if active in branches and branches[active]:
        st = restore_into(st, branches[active][-1])
        st.event_log.append("Imported run (restored last snapshot).")
    else:
        st.event_log.append("Imported run (no snapshots).")

    beliefs = init_beliefs(list(st.agents.keys()))
    return st, branches, active, r_idx, beliefs
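
# Verification sketch (unused): the exported blob is the payload JSON, a blank
# line, then {"proof_sha256": ...}; re-hashing the payload portion should
# match the recorded proof. (indent=2 JSON never contains blank lines, so the
# split below is unambiguous.)
def _demo_verify_export(blob: str) -> bool:
    payload, proof = blob.strip().split("\n\n", 1)
    expected = json.loads(proof)["proof_sha256"]
    return hash_sha256(payload) == expected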

# -----------------------------
# UI helpers
# -----------------------------
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
    for nm, a in state.agents.items():
        if a.hp > 0:
            update_belief_for_agent(state, beliefs[nm], a)

    pov = raycast_view(state, state.agents[state.pov])
    truth_np = np.array(state.grid, dtype=np.int16)
    truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", True)

    ctrl = state.controlled
    others = [k for k in state.agents.keys() if k != ctrl]
    other = others[0] if others else ctrl
    b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True)
    b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", True)

    dash = metrics_dashboard_image(state)

    status = (
        f"env={state.env_key} | seed={state.seed} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n"
        f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n"
        f"Global: episodes={state.gmetrics.episodes} | A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws} "
        f"| winrateA~{state.gmetrics.rolling_winrate_A:.2f} | eps={state.gmetrics.epsilon:.3f}"
    )
    events = "\n".join(state.event_log[-18:])
    trace = "\n".join(state.trace_log[-18:])
    scoreboard = agent_scoreboard(state)
    return pov, truth_img, b_ctrl, b_other, dash, status, events, trace, scoreboard

def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    x_px, y_px = evt.index
    y_px -= 28
    if y_px < 0:
        return state
    gx = int(x_px // TILE)
    gy = int(y_px // TILE)
    if not in_bounds(gx, gy):
        return state
    if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
        return state
    state.grid[gy][gx] = selected_tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
    return state
1335
+ # -----------------------------
1336
+ # Gradio app
1337
+ # -----------------------------
1338
+ TITLE = "ZEN AgentLab — Agent POV + Autoplay Multi-Agent Sims"
1339
+
1340
+ with gr.Blocks(title=TITLE) as demo:
1341
+ gr.Markdown(
1342
+ f"## {TITLE}\n"
1343
+ "**Press Start Autoplay** to watch the sim unfold live. Interject anytime with manual actions or edits.\n"
1344
+ "Use **Cinematic Run** for an instant full-episode spectacle. No background timers beyond the UI autoplay."
1345
+ )
1346
+
1347
+ st0 = init_state(1337, "chase")
1348
+ st = gr.State(st0)
1349
+ branches = gr.State({"main": [snapshot_of(st0, "main")]})
1350
+ active_branch = gr.State("main")
1351
+ rewind_idx = gr.State(0)
1352
+ beliefs = gr.State(init_beliefs(list(st0.agents.keys())))
1353
+
1354
+ autoplay_on = gr.State(False)
1355
+
1356
+ with gr.Row():
1357
+ pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
1358
+ with gr.Column():
1359
+ status = gr.Textbox(label="Status", lines=3)
1360
+ scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8)
1361
+
1362
+ with gr.Row():
1363
+ truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
1364
+ belief_a = gr.Image(label="Belief (Controlled)", type="pil")
1365
+ belief_b = gr.Image(label="Belief (Other)", type="pil")
1366
+
1367
+ with gr.Row():
1368
+ dash = gr.Image(label="Metrics Dashboard", type="pil")
1369
+
1370
+ with gr.Row():
1371
+ events = gr.Textbox(label="Event Log", lines=10)
1372
+ trace = gr.Textbox(label="Step Trace", lines=10)
1373
+
1374
+ with gr.Row():
1375
+ with gr.Column(scale=2):
1376
+ gr.Markdown("### Quick Start (Examples)")
1377
+ examples = gr.Examples(
1378
+ examples=[
1379
+ ["chase", 1337],
1380
+ ["vault", 2024],
1381
+ ["civ", 777],
1382
+ ],
1383
+ inputs=[],
1384
+ label="",
1385
+ )
1386
+ gr.Markdown("Pick an environment + seed below, then click **Apply**.")
1387
+
1388
+ with gr.Row():
1389
+ env_pick = gr.Radio(
1390
+ choices=[("Chase (Predator vs Prey)", "chase"),
1391
+ ("CoopVault (team vs guardian)", "vault"),
1392
+ ("MiniCiv (build + raid)", "civ")],
1393
+ value="chase",
1394
+ label="Environment"
1395
+ )
1396
+ seed_box = gr.Number(value=1337, precision=0, label="Seed")
1397
+
1398
+ with gr.Row():
1399
+ btn_apply_env_seed = gr.Button("Apply (Env + Seed)")
1400
+ btn_reset_ep = gr.Button("Reset Episode (keep learning)")
1401
+
1402
+ gr.Markdown("### Autoplay + Spectacle")
1403
+ with gr.Row():
1404
+ autoplay_speed = gr.Slider(0.05, 1.0, value=0.20, step=0.05, label="Autoplay step interval (seconds)")
1405
+ with gr.Row():
1406
+ btn_autoplay_start = gr.Button("▶ Start Autoplay")
1407
+ btn_autoplay_stop = gr.Button("⏸ Stop Autoplay")
1408
+ with gr.Row():
1409
+ cinematic_steps = gr.Number(value=350, precision=0, label="Cinematic max steps")
1410
+ btn_cinematic = gr.Button("🎬 Cinematic Run (Full Episode)")
1411
+
1412
+ gr.Markdown("### Manual Controls (Interject Anytime)")
1413
+ with gr.Row():
1414
+ btn_L = gr.Button("L")
1415
+ btn_F = gr.Button("F")
1416
+ btn_R = gr.Button("R")
1417
+ btn_I = gr.Button("I (Interact)")
1418
+ with gr.Row():
1419
+ btn_tick = gr.Button("Tick")
1420
+ run_steps = gr.Number(value=25, label="Run N steps", precision=0)
1421
+ btn_run = gr.Button("Run")
1422
+
1423
+ with gr.Row():
1424
+ btn_toggle_control = gr.Button("Toggle Controlled")
1425
+ btn_toggle_pov = gr.Button("Toggle POV")
1426
+ overlay = gr.Checkbox(False, label="Overlay reticle")
1427
+
1428
+ tile_pick = gr.Radio(
1429
+ choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]],
1430
+ value=WALL,
1431
+ label="Paint tile type"
1432
+ )
1433
+
1434
+ with gr.Column(scale=3):
1435
+ gr.Markdown("### Training Controls (Tabular Q-learning)")
1436
+ use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
1437
+ alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha")
1438
+ gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma")
1439
+ eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon")
1440
+ eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
1441
+ eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
1442
+
1443
+ episodes = gr.Number(value=50, label="Train episodes", precision=0)
1444
+ max_steps = gr.Number(value=260, label="Max steps/episode", precision=0)
1445
+ btn_train = gr.Button("Train")
1446
+
1447
+ btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")
1448
+
1449
+ with gr.Row():
1450
+ with gr.Column(scale=2):
1451
+ gr.Markdown("### Timeline + Branching")
1452
+ rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)")
1453
+ btn_jump = gr.Button("Jump to index")
1454
+ new_branch_name = gr.Textbox(value="fork1", label="New branch name")
1455
+ btn_fork = gr.Button("Fork from current rewind")
1456
+
1457
+ with gr.Column(scale=2):
1458
+ branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch")
1459
+ btn_set_branch = gr.Button("Set Active Branch")
1460
+
1461
+ with gr.Column(scale=3):
1462
+ export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8)
1463
+ btn_export = gr.Button("Export")
1464
+ import_box = gr.Textbox(label="Import JSON", lines=8)
1465
+ btn_import = gr.Button("Import")
1466
+
1467
+ # Autoplay timer (inactive by default)
1468
+ timer = gr.Timer(value=0.20, active=False)
1469
+
1470
+ # ---------- glue ----------
1471
+ def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
1472
+ snaps = branches_d.get(active, [])
1473
+ r_max = max(0, len(snaps) - 1)
1474
+ r = max(0, min(int(r), r_max))
1475
+ pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
1476
+ branch_choices = sorted(list(branches_d.keys()))
1477
+ return (
1478
+ pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt,
1479
+ gr.update(maximum=r_max, value=r), r,
1480
+ gr.update(choices=branch_choices, value=active),
1481
+ gr.update(choices=branch_choices, value=active),
1482
+ )
1483
+
1484
+ def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]:
1485
+ branches_d.setdefault(active, [])
1486
+ branches_d[active].append(snapshot_of(state, active))
1487
+ if len(branches_d[active]) > MAX_HISTORY:
1488
+ branches_d[active].pop(0)
1489
+ return branches_d
1490
+
1491
+ def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState:
1492
+ state.cfg.use_q = bool(use_q_v)
1493
+ state.cfg.alpha = float(a)
1494
+ state.cfg.gamma = float(g)
1495
+ state.gmetrics.epsilon = float(e)
1496
+ state.cfg.epsilon_decay = float(ed)
1497
+ state.cfg.epsilon_min = float(emin)
1498
+ return state
1499
+
1500
+    def do_manual(state, branches_d, active, bel, r, act):
+        tick(state, bel, manual_action=act)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def do_tick(state, branches_d, active, bel, r):
+        tick(state, bel, manual_action=None)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def do_run(state, branches_d, active, bel, r, n):
+        n = max(1, int(n))
+        for _ in range(n):
+            if state.done:
+                break
+            tick(state, bel, manual_action=None)
+            branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def toggle_control(state, branches_d, active, bel, r):
+        order = list(state.agents.keys())
+        i = order.index(state.controlled)
+        state.controlled = order[(i + 1) % len(order)]
+        state.event_log.append(f"Controlled -> {state.controlled}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def toggle_pov(state, branches_d, active, bel, r):
+        order = list(state.agents.keys())
+        i = order.index(state.pov)
+        state.pov = order[(i + 1) % len(order)]
+        state.event_log.append(f"POV -> {state.pov}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def set_overlay(state, branches_d, active, bel, r, ov):
+        state.overlay = bool(ov)
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
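+    # Gradio fills `evt` from the gr.SelectData type hint, which is why it is
+    # not listed among the event's inputs.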
+    def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData):
+        state = grid_click_to_tile(evt, int(tile), state)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def jump(state, branches_d, active, bel, r, idx):
+        snaps = branches_d.get(active, [])
+        if not snaps:
+            out = refresh(state, branches_d, active, bel, r)
+            return out + (state, branches_d, active, bel, r)
+        idx = max(0, min(int(idx), len(snaps) - 1))
+        state = restore_into(state, snaps[idx])
+        r = idx
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
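+    # Forking deep-copies the active branch up to the rewind index via an
+    # asdict()/Snapshot(**...) round trip. Note this assumes Snapshot is a
+    # flat dataclass; nested dataclass fields would come back as plain dicts.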
+    def fork_branch(state, branches_d, active, bel, r, new_name):
+        new_name = (new_name or "").strip() or "fork"
+        new_name = new_name.replace(" ", "_")
+        snaps = branches_d.get(active, [])
+        if not snaps:
+            branches_d[new_name] = [snapshot_of(state, new_name)]
+        else:
+            idx = max(0, min(int(r), len(snaps) - 1))
+            branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]]
+            state = restore_into(state, branches_d[new_name][-1])
+        active = new_name
+        state.event_log.append(f"Forked branch -> {new_name}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def set_active_branch(state, branches_d, active, bel, r, br):
+        br = br or "main"
+        if br not in branches_d:
+            branches_d[br] = [snapshot_of(state, br)]
+        active = br
+        if branches_d[active]:
+            state = restore_into(state, branches_d[active][-1])
+        bel = init_beliefs(list(state.agents.keys()))
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
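+    # Swapping environment/seed rebuilds the world but carries the learned
+    # state (cfg, Q-tables, global metrics) across, and restarts history on a
+    # fresh "main" branch.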
+    def apply_env_seed(state, branches_d, active, bel, r, env_key, seed_val):
+        env_key = env_key or "chase"
+        seed_val = int(seed_val) if seed_val is not None else state.seed
+
+        # Preserve learning across env swaps
+        old_cfg = state.cfg
+        old_q = state.q_tables
+        old_gm = state.gmetrics
+
+        state = init_state(seed_val, env_key)
+        state.cfg = old_cfg
+        state.q_tables = old_q
+        state.gmetrics = old_gm
+
+        bel = init_beliefs(list(state.agents.keys()))
+        active = "main"
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def reset_ep(state, branches_d, active, bel, r):
+        state = reset_episode_keep_learning(state, seed=state.seed)
+        bel = init_beliefs(list(state.agents.keys()))
+        branches_d = {active: [snapshot_of(state, active)]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def reset_all(state, branches_d, active, bel, r, env_key, seed_val):
+        env_key = env_key or state.env_key
+        seed_val = int(seed_val) if seed_val is not None else state.seed
+        state = wipe_all(seed=seed_val, env_key=env_key)
+        bel = init_beliefs(list(state.agents.keys()))
+        active = "main"
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
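+    # Headless training: runs the configured number of episodes off-screen,
+    # then rebases history onto a single fresh "main" branch (other branches
+    # are discarded).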
+    def do_train(state, branches_d, active, bel, r,
+                 use_q_v, a, g, e, ed, emin,
+                 eps_count, max_s):
+        state = set_cfg(state, use_q_v, a, g, e, ed, emin)
+        state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s)))
+        bel = init_beliefs(list(state.agents.keys()))
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        active = "main"
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
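+    # Cinematic mode plays one episode to completion in a single click; only
+    # the final state is snapshotted, so intermediate steps don't land in the
+    # timeline.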
+    def cinematic_run(state, branches_d, active, bel, r, max_s):
+        max_s = max(10, int(max_s))
+        # Reset episode so the cinematic is clean
+        state = reset_episode_keep_learning(state, seed=state.seed)
+        bel = init_beliefs(list(state.agents.keys()))
+        # Run to completion (or max steps) in one click
+        while state.step < max_s and not state.done:
+            tick(state, bel, manual_action=None)
+
+        state.event_log.append(f"Cinematic finished: outcome={state.outcome} steps={state.step}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def export_fn(state, branches_d, active, r):
+        return export_run(state, branches_d, active, int(r))
+
+    def import_fn(txt):
+        state, branches_d, active, r, bel = import_run(txt)
+        branches_d.setdefault(active, [])
+        if not branches_d[active]:
+            branches_d[active].append(snapshot_of(state, active))
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    # ---- Autoplay control ----
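+    # Autoplay = a gr.Timer plus a boolean gr.State flag: start enables both,
+    # stop clears both, and each tick advances one step. The is_on guard also
+    # covers a tick that was already queued when Stop was pressed.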
+    def autoplay_start(state, branches_d, active, bel, r, interval_s):
+        interval_s = float(interval_s)
+        # Enable timer + autoplay flag
+        return (
+            gr.update(value=interval_s, active=True),
+            True,
+            state, branches_d, active, bel, r
+        )
+
+    def autoplay_stop(state, branches_d, active, bel, r):
+        return (
+            gr.update(active=False),
+            False,
+            state, branches_d, active, bel, r
+        )
+
+    def autoplay_tick(state, branches_d, active, bel, r, is_on: bool):
+        # If not on, do nothing (also keep timer active state as-is)
+        if not is_on:
+            out = refresh(state, branches_d, active, bel, r)
+            return out + (state, branches_d, active, bel, r, is_on, gr.update())
+
+        # Step once
+        if not state.done:
+            tick(state, bel, manual_action=None)
+            branches_d = push_hist(state, branches_d, active)
+            r = len(branches_d[active]) - 1
+
+        # If done, stop autoplay automatically
+        if state.done:
+            out = refresh(state, branches_d, active, bel, r)
+            return out + (state, branches_d, active, bel, r, False, gr.update(active=False))
+
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r, True, gr.update())
+
+    # ---- wiring ----
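+    # Output order mirrors refresh()'s 13-tuple followed by the five gr.State
+    # values; branch_pick and rewind_idx each appear twice so the list stays
+    # 1:1 with the returned tuple.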
+    common_outputs = [
+        pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+        rewind, rewind_idx, branch_pick, branch_pick,
+        st, branches, active_branch, beliefs, rewind_idx
+    ]
+
+    btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=common_outputs, queue=True)
+
+    btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=common_outputs, queue=True)
+
+    btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=common_outputs, queue=True)
+
+    btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=common_outputs, queue=True)
+
+    btn_tick.click(do_tick,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                   outputs=common_outputs, queue=True)
+
+    btn_run.click(do_run,
+                  inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps],
+                  outputs=common_outputs, queue=True)
+
+    btn_toggle_control.click(toggle_control,
+                             inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                             outputs=common_outputs, queue=True)
+
+    btn_toggle_pov.click(toggle_pov,
+                         inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                         outputs=common_outputs, queue=True)
+
+    overlay.change(set_overlay,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay],
+                   outputs=common_outputs, queue=True)
+
+    truth.select(click_truth,
+                 inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx],
+                 outputs=common_outputs, queue=True)
+
+    btn_jump.click(jump,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind],
+                   outputs=common_outputs, queue=True)
+
+    btn_fork.click(fork_branch,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name],
+                   outputs=common_outputs, queue=True)
+
+    btn_set_branch.click(set_active_branch,
+                         inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick],
+                         outputs=common_outputs, queue=True)
+
+    btn_apply_env_seed.click(apply_env_seed,
+                             inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box],
+                             outputs=common_outputs, queue=True)
+
+    btn_reset_ep.click(reset_ep,
+                       inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                       outputs=common_outputs, queue=True)
+
+    btn_reset_all.click(reset_all,
+                        inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box],
+                        outputs=common_outputs, queue=True)
+
+    btn_train.click(do_train,
+                    inputs=[st, branches, active_branch, beliefs, rewind_idx,
+                            use_q, alpha, gamma, eps, eps_decay, eps_min,
+                            episodes, max_steps],
+                    outputs=common_outputs, queue=True)
+
+    btn_cinematic.click(cinematic_run,
+                        inputs=[st, branches, active_branch, beliefs, rewind_idx, cinematic_steps],
+                        outputs=common_outputs, queue=True)
+
+    btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
+
+    btn_import.click(import_fn,
+                     inputs=[import_box],
+                     outputs=common_outputs, queue=True)
+
+    # Autoplay start/stop wires
+    btn_autoplay_start.click(
+        autoplay_start,
+        inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_speed],
+        outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
+        queue=True
+    )
+
+    btn_autoplay_stop.click(
+        autoplay_stop,
+        inputs=[st, branches, active_branch, beliefs, rewind_idx],
+        outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
+        queue=True
+    )
+
+    # Timer tick: step and update UI; auto-stop when done
+    timer.tick(
+        autoplay_tick,
+        inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_on],
+        outputs=[
+            pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+            rewind, rewind_idx, branch_pick, branch_pick,
+            st, branches, active_branch, beliefs, rewind_idx,
+            autoplay_on, timer
+        ],
+        queue=True
+    )
+
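+    # Initial render: nothing changes state on page load, so only refresh()'s
+    # 13 view/control outputs are wired here.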
+    demo.load(
+        refresh,
+        inputs=[st, branches, active_branch, beliefs, rewind_idx],
+        outputs=[
+            pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+            rewind, rewind_idx, branch_pick, branch_pick
+        ],
+        queue=True
+    )
+
+# Disable SSR for HF stability
+demo.queue().launch(ssr_mode=False)
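+
+# launch(ssr_mode=...) is a Gradio 5.x parameter; the Space is assumed to pin
+# a 5.x build. A defensive variant (sketch, not from the original source):
+#
+#     try:
+#         demo.queue().launch(ssr_mode=False)
+#     except TypeError:  # older Gradio without ssr_mode
+#         demo.queue().launch()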