E-Rong committed · verified
Commit 7be626a · 1 Parent(s): c19c488

Upload train_all_phases.py

Files changed (1): train_all_phases.py (+651, -0)
train_all_phases.py ADDED
@@ -0,0 +1,651 @@
#!/usr/bin/env python3
"""
Full training pipeline: Phase 1 -> Phase 2 -> Phase 3
TIL-26-AE Bomberman Agent Training

Run with:
TOTAL_TIMESTEPS=500_000:500_000:1_000_000 \
HUB_MODEL_ID=E-Rong/til-26-ae-agent \
TRACKIO_PROJECT=til-26-ae \
python train_all_phases.py

References:
- Pommerman multi-agent RL: arxiv:2407.00662
- MAPPO best practices: arxiv:2103.01955
- Invalid Action Masking: arxiv:2006.14171
"""

import os
import sys
import subprocess

# Bootstrap: download and set up the TIL environment if not present
repo_path = "/app/til-26-ae-repo/til-26-ae"
if not os.path.exists(repo_path):
    try:
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id='e-rong/til-26-ae',
            repo_type='space',
            local_dir='/app/til-26-ae-repo',
            local_dir_use_symlinks=False
        )
    except Exception:
        subprocess.run(
            ["git", "clone", "https://huggingface.co/spaces/e-rong/til-26-ae", "/app/til-26-ae-repo"],
            capture_output=True, check=False
        )

if os.path.exists(repo_path):
    sys.path.insert(0, repo_path)
elif os.path.exists("/app/til-26-ae-repo"):
    sys.path.insert(0, "/app/til-26-ae-repo")

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
import torch

from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
import trackio


# ============================================================================
# PHASE 1: Base environment wrapper
# ============================================================================

class BombermanSingleAgentEnv(gym.Env):
    """
    Wraps the parallel PettingZoo Bomberman env into a single-agent gymnasium env.
    Agent 0 is the learning agent; opponents use random valid actions.
    """

    def __init__(self, cfg=None, seed=None, opponent_policy="random"):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None

        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self.opponent_policy = opponent_policy
        self._episode_seed = seed
        self._episode_count = 0

        self.action_space = Discrete(6)

        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None

        self._compute_obs_space()

    def _compute_obs_space(self):
        cfg = self.cfg
        viewcone_l = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        viewcone_w = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        agent_viewcone_size = viewcone_l * viewcone_w * 25
        base_r = int(cfg.entities.base.vision_radius)
        base_side = 2 * base_r + 1
        base_viewcone_size = base_side * base_side * 25
        scalar_size = 11
        self._obs_size = agent_viewcone_size + base_viewcone_size + scalar_size
        self.observation_space = Box(
            low=-np.inf, high=np.inf,
            shape=(self._obs_size,), dtype=np.float32,
        )
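
    # Worked size example (hypothetical config values, not necessarily the
    # defaults): with vision behind=2, ahead=4, left=2, right=2 and base
    # vision_radius=3, the agent viewcone flattens to (2+4+1) * (2+2+1) * 25
    # = 875 floats, the base viewcone to 7 * 7 * 25 = 1225 floats, plus 11
    # scalars -> 2111 observation features in total.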

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._episode_seed = seed
        else:
            self._episode_seed = self._episode_count
            self._episode_count += 1

        obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
        self._store_action_mask(obs_dict[self.agent_id])
        self._last_obs_dict = obs_dict
        return self._flatten_obs(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {}
        for agent_id in self._parallel_env.agents:
            if agent_id == self.agent_id:
                actions[agent_id] = action
            else:
                # Opponents sample uniformly from their currently valid actions.
                mask = (
                    self._last_obs_dict[agent_id]["action_mask"]
                    if self._last_obs_dict and agent_id in self._last_obs_dict
                    else np.ones(6, dtype=np.int8)
                )
                valid = np.where(mask == 1)[0]
                actions[agent_id] = int(np.random.choice(valid)) if len(valid) > 0 else 0

        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict

        if self.agent_id not in obs_dict:
            # Agent eliminated: end the episode with a zero observation.
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}

        self._store_action_mask(obs_dict[self.agent_id])
        obs = self._flatten_obs(obs_dict[self.agent_id])
        reward = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)

        return obs, reward, done, False, infos.get(self.agent_id, {})

    def _store_action_mask(self, obs_dict):
        if "action_mask" in obs_dict:
            self._last_action_mask = obs_dict["action_mask"].copy().astype(bool)
        else:
            self._last_action_mask = np.ones(6, dtype=bool)

    def action_masks(self):
        return self._last_action_mask

    def _flatten_obs(self, obs_dict):
        return np.concatenate(
            [
                obs_dict["agent_viewcone"].flatten(),
                obs_dict["base_viewcone"].flatten(),
                np.array([obs_dict["direction"]], dtype=np.float32),
                obs_dict["location"].flatten().astype(np.float32),
                obs_dict["base_location"].flatten().astype(np.float32),
                obs_dict["health"].flatten().astype(np.float32),
                np.array([obs_dict["frozen_ticks"]], dtype=np.float32),
                obs_dict["base_health"].flatten().astype(np.float32),
                obs_dict["team_resources"].flatten().astype(np.float32),
                np.array([obs_dict["team_bombs"]], dtype=np.float32),
                np.array([obs_dict["step"]], dtype=np.float32),
            ],
            dtype=np.float32,
        )

    def close(self):
        self._parallel_env.close()
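

# Minimal smoke test for the wrapper (a sketch, not part of the training
# pipeline; assumes default_config() produces a playable map):
#
#   env = BombermanSingleAgentEnv()
#   obs, _ = env.reset(seed=0)
#   assert obs.shape == env.observation_space.shape
#   first_valid = int(np.flatnonzero(env.action_masks())[0])
#   obs, reward, done, truncated, info = env.step(first_valid)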


# ============================================================================
# PHASE 2: Exploration reward shaping
# ============================================================================

class RewardShapingWrapper(gym.Wrapper):
    """
    Adds a visit-count exploration bonus with adaptive annealing:
    alpha = 1 - tanh(k * avg_enemy_deaths) gradually reduces the exploration
    weight as the agent learns to engage enemies.
    """

    def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
        super().__init__(env)
        self.adaptive_k = adaptive_k
        self.base_explore_weight = base_explore_weight
        self._visit_counts = None
        self._grid_size = 16
        self._avg_enemy_deaths = 0.0
        self._episode_count = 0
        self._episode_enemy_deaths = 0
        self._explore_weight = base_explore_weight

    def reset(self, **kwargs):
        self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
        self._episode_enemy_deaths = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)

        # Visit bonus decays as 1 / (1 + visits) per cell; it only applies
        # when the underlying env reports the agent's "location" in info.
        pos = info.get("location", None)
        visit_bonus = 0.0
        if pos is not None:
            x, y = int(pos[0]), int(pos[1])
            if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
                visits = self._visit_counts[x, y]
                visit_bonus = 1.0 / (1.0 + visits)
                self._visit_counts[x, y] += 1

        # Track enemy deaths for annealing; assumes the env reports an
        # "enemy_deaths" count in info (hypothetical key -- without it,
        # _episode_enemy_deaths stays 0 and alpha stays at 1).
        self._episode_enemy_deaths += int(info.get("enemy_deaths", 0))

        if done:
            self._episode_count += 1
            alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
            self._explore_weight = self.base_explore_weight * max(0.1, alpha)
            self._avg_enemy_deaths = 0.95 * self._avg_enemy_deaths + 0.05 * self._episode_enemy_deaths

        shaped_reward = reward + self._explore_weight * visit_bonus
        info["raw_reward"] = reward
        info["explore_bonus"] = visit_bonus
        info["explore_weight"] = self._explore_weight

        return obs, shaped_reward, done, truncated, info

    def action_masks(self):
        return self.env.action_masks()
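

# Annealing behavior of alpha = 1 - tanh(k * avg_enemy_deaths) with k = 1.2:
#   avg_enemy_deaths = 0.0 -> alpha = 1.000 (full exploration bonus)
#   avg_enemy_deaths = 0.5 -> alpha ~= 0.463
#   avg_enemy_deaths = 1.0 -> alpha ~= 0.166
#   avg_enemy_deaths = 2.0 -> alpha ~= 0.016, floored to 0.1 by max(0.1, alpha)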


# ============================================================================
# PHASE 3: Rule-based opponents + curriculum
# ============================================================================

class RuleBasedOpponent:
    """Rule-based Bomberman opponent with three difficulty levels."""

    def __init__(self, team_id=1, difficulty="simple"):
        self.team_id = team_id
        self.difficulty = difficulty
        self.visited = None  # visit map (currently unused by the policies)
        self.grid_size = 16

    def reset(self):
        self.visited = np.zeros((self.grid_size, self.grid_size), dtype=np.int32)

    def act(self, obs_dict):
        action_mask = obs_dict["action_mask"]
        valid_actions = np.where(action_mask == 1)[0]
        if len(valid_actions) == 0:
            return 4  # STAY

        if self.difficulty == "static":
            # Static opponents never move.
            return 4

        elif self.difficulty == "simple":
            # Simple opponents bomb on sight, otherwise wander randomly.
            viewcone = obs_dict["agent_viewcone"]
            has_enemy = np.any(viewcone[..., 10] > 0)
            has_enemy_base = np.any(viewcone[..., 12] > 0)

            if (has_enemy or has_enemy_base) and 5 in valid_actions:
                return 5  # bomb

            movement_actions = [a for a in valid_actions if a < 4]
            if len(movement_actions) > 0:
                return int(np.random.choice(movement_actions))
            return 4

        elif self.difficulty == "smart":
            return self._smart_policy(obs_dict, valid_actions)

        return 4

    def _smart_policy(self, obs, valid_actions):
        viewcone = obs["agent_viewcone"]
        h, w, _ = viewcone.shape

        # Channels 6-8 are treated as collectible layers.
        collectibles = np.stack([
            viewcone[..., 7], viewcone[..., 8], viewcone[..., 6],
        ], axis=-1)
        has_collectible = np.any(collectibles > 0, axis=-1)

        # Hard-coded offsets, presumably the agent's own cell in its viewcone.
        cx, cy = 3, 2

        best_action = 4
        best_score = -1

        # Score each movement action by the contents of the destination cell.
        for action in valid_actions:
            if action == 4 or action == 5:
                continue

            if action == 0:
                nx, ny = cx - 1, cy
            elif action == 1:
                nx, ny = cx + 1, cy
            elif action == 2:
                nx, ny = cx, cy - 1
            elif action == 3:
                nx, ny = cx, cy + 1
            else:
                continue

            if 0 <= nx < h and 0 <= ny < w:
                score = 0
                if has_collectible[nx, ny]:
                    score += 10.0
                if viewcone[nx, ny, 0] < 1:
                    score -= 5.0
                wall_score = (
                    viewcone[nx, ny, 1] + viewcone[nx, ny, 2]
                    + viewcone[nx, ny, 3] + viewcone[nx, ny, 4]
                )
                score -= wall_score * 2.0

                if score > best_score:
                    best_score = score
                    best_action = action

        # Bomb with 70% probability when an enemy or enemy base is adjacent.
        for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (0, 0)]:
            nx, ny = cx + dx, cy + dy
            if 0 <= nx < h and 0 <= ny < w:
                if viewcone[nx, ny, 10] > 0 or viewcone[nx, ny, 12] > 0:
                    if 5 in valid_actions and np.random.random() < 0.7:
                        return 5
                    break

        return int(best_action) if best_score > -1 else 4


class CurriculumEnv(gym.Env):
    """Single-agent env with curriculum-based opponent difficulty."""

    CURRICULUM_STAGES = ["static", "simple", "smart", "mixed"]
    WIN_RATE_THRESHOLD = 0.55
    EPISODES_PER_STAGE = 500

    def __init__(self, cfg=None, seed=None):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None

        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self._episode_seed = seed
        self._episode_count = 0

        self.action_space = Discrete(6)

        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None

        self._compute_obs_space()

        self.stage_idx = 0
        self.stage_episodes = 0
        self.stage_wins = 0
        self.stage_rewards = []

        self.opponents = {}
        self._init_opponents()

    def _compute_obs_space(self):
        cfg = self.cfg
        viewcone_l = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        viewcone_w = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        agent_viewcone_size = viewcone_l * viewcone_w * 25
        base_r = int(cfg.entities.base.vision_radius)
        base_side = 2 * base_r + 1
        base_viewcone_size = base_side * base_side * 25
        scalar_size = 11
        self._obs_size = agent_viewcone_size + base_viewcone_size + scalar_size
        self.observation_space = Box(
            low=-np.inf, high=np.inf,
            shape=(self._obs_size,), dtype=np.float32,
        )

    def _init_opponents(self):
        # One opponent per remaining team, all starting at the easiest difficulty.
        for i in range(1, self.cfg.env.num_teams):
            opp_id = f"agent_{i}"
            self.opponents[opp_id] = RuleBasedOpponent(team_id=i, difficulty="static")

    def _update_opponent_difficulty(self):
        stage = self.CURRICULUM_STAGES[self.stage_idx]
        for opp_id, opp in self.opponents.items():
            if stage == "mixed":
                # Alternate smart and simple opponents by agent index.
                opp.difficulty = "smart" if (int(opp_id.split("_")[1]) % 2 == 0) else "simple"
            else:
                opp.difficulty = stage

    def _check_stage_advance(self):
        if self.stage_idx >= len(self.CURRICULUM_STAGES) - 1:
            return False
        if len(self.stage_rewards) < self.EPISODES_PER_STAGE:
            return False
        # Advancement is forced once the stage's episode budget is spent; the
        # win rate is reported in the alert for comparison against
        # WIN_RATE_THRESHOLD.
        win_rate = self.stage_wins / len(self.stage_rewards)
        avg_reward = np.mean(self.stage_rewards)
        trackio.alert(
            "Curriculum Advance",
            f"Stage {self.CURRICULUM_STAGES[self.stage_idx]} complete: "
            f"win_rate={win_rate:.2%}, avg_reward={avg_reward:.1f}. "
            f"Advancing to {self.CURRICULUM_STAGES[self.stage_idx + 1]}",
            "INFO",
        )
        self.stage_idx += 1
        self.stage_episodes = 0
        self.stage_wins = 0
        self.stage_rewards = []
        self._update_opponent_difficulty()
        return True

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._episode_seed = seed
        else:
            self._episode_seed = self._episode_count
            self._episode_count += 1

        for opp in self.opponents.values():
            opp.reset()

        obs_dict, info_dict = self._parallel_env.reset(
            seed=self._episode_seed, options=options
        )

        self._store_action_mask(obs_dict[self.agent_id])
        self._last_obs_dict = obs_dict
        return self._flatten_obs(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {}
        for agent_id in self._parallel_env.agents:
            if agent_id == self.agent_id:
                actions[agent_id] = action
            else:
                opp = self.opponents.get(agent_id)
                if opp is not None and agent_id in self._last_obs_dict:
                    actions[agent_id] = opp.act(self._last_obs_dict[agent_id])
                else:
                    actions[agent_id] = 4

        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict

        if self.agent_id not in obs_dict:
            self.stage_episodes += 1
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}

        self._store_action_mask(obs_dict[self.agent_id])
        obs = self._flatten_obs(obs_dict[self.agent_id])
        reward = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)

        if done:
            self.stage_episodes += 1
            self.stage_rewards.append(reward)
            if reward > 10.0:  # heuristic: a large terminal reward counts as a win
                self.stage_wins += 1
            self._check_stage_advance()

        info = dict(infos.get(self.agent_id, {}))
        info["curriculum_stage"] = self.stage_idx
        info["curriculum_stage_name"] = self.CURRICULUM_STAGES[self.stage_idx]

        return obs, reward, done, False, info

    def _store_action_mask(self, obs_dict):
        if "action_mask" in obs_dict:
            self._last_action_mask = obs_dict["action_mask"].copy().astype(bool)
        else:
            self._last_action_mask = np.ones(6, dtype=bool)

    def action_masks(self):
        return self._last_action_mask

    def _flatten_obs(self, obs_dict):
        return np.concatenate(
            [
                obs_dict["agent_viewcone"].flatten(),
                obs_dict["base_viewcone"].flatten(),
                np.array([obs_dict["direction"]], dtype=np.float32),
                obs_dict["location"].flatten().astype(np.float32),
                obs_dict["base_location"].flatten().astype(np.float32),
                obs_dict["health"].flatten().astype(np.float32),
                np.array([obs_dict["frozen_ticks"]], dtype=np.float32),
                obs_dict["base_health"].flatten().astype(np.float32),
                obs_dict["team_resources"].flatten().astype(np.float32),
                np.array([obs_dict["team_bombs"]], dtype=np.float32),
                np.array([obs_dict["step"]], dtype=np.float32),
            ],
            dtype=np.float32,
        )

    def close(self):
        self._parallel_env.close()
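

# Curriculum progression at a glance: every opponent starts "static"; after
# each 500-episode stage budget the env advances through "simple", "smart",
# and finally "mixed", logging the stage win rate (final reward > 10 counts
# as a win) at each transition.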


# ============================================================================
# Trackio logging callback
# ============================================================================

class TrackioLoggingCallback(BaseCallback):
    def __init__(self, project, run_name, log_interval=2048, verbose=0):
        super().__init__(verbose)
        self.project = project
        self.run_name = run_name
        self.log_interval = log_interval
        self._last_mean_reward = 0.0

    def _on_training_start(self):
        # trackio exposes a wandb-style API: init(project=..., name=...).
        trackio.init(project=self.project, name=self.run_name)
        trackio.alert("Training Started", f"{self.run_name} training began.", "INFO")

    def _on_step(self):
        if self.n_calls % self.log_interval == 0:
            infos = self.locals.get("infos", [{}])
            ep_rewards = [info.get("episode", {}).get("r", 0) for info in infos if "episode" in info]
            ep_lengths = [info.get("episode", {}).get("l", 0) for info in infos if "episode" in info]
            explore_bonuses = [info.get("explore_bonus", 0) for info in infos]
            stages = [info.get("curriculum_stage", 0) for info in infos]

            if ep_rewards:
                mean_r = float(np.mean(ep_rewards))
                self._last_mean_reward = mean_r
                log_dict = {
                    "train/mean_episode_reward": mean_r,
                    "train/mean_episode_length": float(np.mean(ep_lengths)) if ep_lengths else 0.0,
                    "train/timesteps": self.num_timesteps,
                }
                if explore_bonuses:
                    log_dict["train/mean_explore_bonus"] = float(np.mean(explore_bonuses))
                if stages:
                    log_dict["train/curriculum_stage"] = float(np.mean(stages))
                trackio.log(log_dict)

                if mean_r < -5.0 and self.num_timesteps > 50_000:
                    trackio.alert("Low Reward Warning",
                                  f"mean_reward={mean_r:.2f} at step {self.num_timesteps} -- may be camping.", "WARN")
        return True

    def _on_training_end(self):
        trackio.alert("Training Complete",
                      f"Finished at {self.num_timesteps}. Final mean reward: {self._last_mean_reward:.2f}",
                      "INFO")
        trackio.finish()


# ============================================================================
# Main training pipeline
# ============================================================================

def train_phase(cfg, phase, total_timesteps, model=None):
    trackio_project = os.environ.get("TRACKIO_PROJECT", "til-26-ae")

    if phase == 1:
        print("=== PHASE 1: MaskablePPO vs Random Opponents ===")
        base_env = BombermanSingleAgentEnv(cfg=cfg, opponent_policy="random")
        env = ActionMasker(base_env, lambda env: env.action_masks())
        env = Monitor(env)
        run_name = "phase1-maskable-ppo-random"

    elif phase == 2:
        print("=== PHASE 2: Adaptive Exploration Annealing ===")
        base_env = BombermanSingleAgentEnv(cfg=cfg, opponent_policy="random")
        shaped_env = RewardShapingWrapper(base_env, adaptive_k=1.2, base_explore_weight=0.5)
        env = ActionMasker(shaped_env, lambda env: env.action_masks())
        env = Monitor(env)
        run_name = "phase2-adaptive-explore"

    elif phase == 3:
        print("=== PHASE 3: Curriculum + Rule-Based Self-Play ===")
        cfg.env.num_teams = 3
        base_env = CurriculumEnv(cfg=cfg)
        env = ActionMasker(base_env, lambda env: env.action_masks())
        env = Monitor(env)
        run_name = "phase3-curriculum-selfplay"

    else:
        raise ValueError(f"Unknown phase: {phase}")

    if model is None:
        model = MaskablePPO(
            "MlpPolicy", env,
            learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10,
            gamma=0.99, gae_lambda=0.95, clip_range=0.2,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
            verbose=1, tensorboard_log="./tb_logs",
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        # Reuse the weights from the previous phase on the new environment.
        model.set_env(env)

    checkpoint_callback = CheckpointCallback(
        save_freq=50_000, save_path=f"./checkpoints/phase{phase}",
        name_prefix=f"bomberman_phase{phase}",
    )

    trackio_callback = TrackioLoggingCallback(
        trackio_project, run_name, log_interval=2048,
    )

    model.learn(
        total_timesteps=total_timesteps,
        callback=[checkpoint_callback, trackio_callback],
        progress_bar=False,
    )

    model.save(f"bomberman_phase{phase}_final")
    env.close()
    print(f"Phase {phase} complete. Model saved to bomberman_phase{phase}_final.zip")
    return model
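

# Resuming a phase from an intermediate checkpoint (a sketch; the step count
# in the filename depends on when CheckpointCallback last saved):
#
#   model = MaskablePPO.load("./checkpoints/phase1/bomberman_phase1_450000_steps")
#   model = train_phase(cfg, phase=2, total_timesteps=500_000, model=model)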


def main():
    cfg = default_config()
    cfg.env.render_mode = None

    # Per-phase timestep budgets, colon-separated:
    # "500_000:500_000:1_000_000" -> [500000, 500000, 1000000]
    total_ts_env = os.environ.get("TOTAL_TIMESTEPS", "500_000:500_000:1_000_000")
    phase_ts = [int(x.replace("_", "")) for x in total_ts_env.split(":")]

    model = None
    model = train_phase(cfg, phase=1, total_timesteps=phase_ts[0], model=model)

    if len(phase_ts) > 1:
        model = train_phase(cfg, phase=2, total_timesteps=phase_ts[1], model=model)

    if len(phase_ts) > 2:
        model = train_phase(cfg, phase=3, total_timesteps=phase_ts[2], model=model)

    hub_model_id = os.environ.get("HUB_MODEL_ID", "")
    if hub_model_id:
        from huggingface_hub import HfApi
        api = HfApi()
        for phase in range(1, len(phase_ts) + 1):
            try:
                api.upload_file(
                    path_or_fileobj=f"bomberman_phase{phase}_final.zip",
                    path_in_repo=f"bomberman_phase{phase}_final.zip",
                    repo_id=hub_model_id, repo_type="model",
                )
                print(f"Phase {phase} model pushed to {hub_model_id}")
            except Exception as e:
                print(f"Failed to push phase {phase}: {e}")

    print("\n=== All phases complete! ===")
    if hub_model_id:
        print(f"Model repository: https://huggingface.co/{hub_model_id}")


if __name__ == "__main__":
    main()