E-Rong committed
Commit 7d18d2b · verified · 1 parent: 67d546f

Upload phase3_curriculum.py

Files changed (1)
  1. phase3_curriculum.py +323 -0
phase3_curriculum.py ADDED
@@ -0,0 +1,323 @@
+ #!/usr/bin/env python3
+ """Phase 3: Rule-based curriculum training - 1M steps with progressive opponents."""
+ import os, sys, subprocess, numpy as np, torch, gymnasium
+ from gymnasium.spaces import Box, Discrete
+
+ # ── 1. Download TIL env via snapshot_download ──
+ print("[1/5] Downloading TIL repo...")
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id="e-rong/til-26-ae", repo_type="space", local_dir="/app/til-26-ae-repo")
+ PKG_ROOT = None
+ for root, dirs, files in os.walk("/app/til-26-ae-repo"):
+     if "pyproject.toml" in files:
+         PKG_ROOT = root
+         break
+ if PKG_ROOT is None:
+     raise RuntimeError("pyproject.toml not found")
+ subprocess.run(["pip", "install", "-e", "."], cwd=PKG_ROOT, check=True)
+ sys.path.insert(0, PKG_ROOT)
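+ # The editable install registers the package's dependencies with pip, while
+ # the sys.path insert makes til_environment importable in this same
+ # interpreter session without a restart.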
+
+ from til_environment.bomberman_env import Bomberman
+ from til_environment.config import default_config
+ from pettingzoo.utils.conversions import aec_to_parallel
+ from sb3_contrib import MaskablePPO
+ from sb3_contrib.common.wrappers import ActionMasker
+ from stable_baselines3.common.callbacks import BaseCallback
+ from stable_baselines3.common.monitor import Monitor
+ from huggingface_hub import HfApi, hf_hub_download
+
+ HUB_REPO = "E-Rong/til-26-ae-agent"
+ DATA_DIR = "/app/data"
+ os.makedirs(DATA_DIR, exist_ok=True)
+
+
+ def hub_push(local_path, repo_path):
+     """Best-effort upload to the Hub; a transient failure must not kill training."""
+     try:
+         HfApi().upload_file(path_or_fileobj=local_path, path_in_repo=repo_path,
+                             repo_id=HUB_REPO, repo_type="model")
+         print(f" -> pushed {repo_path}")
+     except Exception as e:
+         print(f" -> push failed: {e}")
+
+
+ # ── Opponent Policies ──
+ def static_opponent(obs):
+     """Never moves, never places bombs."""
+     return 4  # STAY
+
+
+ def random_valid_opponent(obs):
+     """Random valid action (Phase 1 style)."""
+     mask = np.array(obs.get("action_mask", [1]*6), dtype=bool)
+     valid = np.where(mask)[0]
+     return int(np.random.choice(valid)) if len(valid) > 0 else 4
+
+
+ def simple_bomb_opponent(obs):
+     """Moves randomly but places bombs when enemies are visible."""
+     mask = np.array(obs.get("action_mask", [1]*6), dtype=bool)
+     # Check if enemies are visible in the viewcone
+     view = np.array(obs.get("agent_viewcone", np.zeros((7, 5, 25))))
+     if view.shape[-1] >= 11:  # ENEMY_AGENT channel exists
+         enemy_present = np.any(view[..., 10] > 0)  # ENEMY_AGENT channel
+         if enemy_present and mask[5]:  # PLACE_BOMB is valid
+             return 5
+     valid = np.where(mask)[0]
+     # Prefer movement over staying put
+     move_actions = [v for v in valid if v < 4]
+     if move_actions:
+         return int(np.random.choice(move_actions))
+     return int(np.random.choice(valid)) if len(valid) > 0 else 4
+
+
+ def evasive_opponent(obs):
+     """Tries to move away from bombs, random otherwise."""
+     mask = np.array(obs.get("action_mask", [1]*6), dtype=bool)
+     view = np.array(obs.get("agent_viewcone", np.zeros((7, 5, 25))))
+     # If an enemy bomb is visible, try to move away
+     if view.shape[-1] >= 20:
+         enemy_bombs = view[..., 18]  # ENEMY_BOMB channel
+         if np.any(enemy_bombs > 0):
+             bomb_y, bomb_x = np.where(enemy_bombs > 0)
+             if len(bomb_y) > 0:
+                 # No direction-aware pathing: just pick any valid movement action
+                 move_actions = [v for v in np.where(mask)[0] if v < 4]
+                 if move_actions:
+                     return int(np.random.choice(move_actions))
+     valid = np.where(mask)[0]
+     return int(np.random.choice(valid)) if len(valid) > 0 else 4
+
+
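+ # The "mixed" stage below stores None as its policy. This helper is a minimal
+ # sketch of the "cycles through all" behaviour that entry describes: it
+ # samples one scripted policy per opponent decision.
+ def mixed_opponent(obs):
+     """Mixed stage: pick a random scripted policy for each action."""
+     fn = np.random.choice([static_opponent, random_valid_opponent,
+                            simple_bomb_opponent, evasive_opponent])
+     return fn(obs)
+
+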
+ # (stage name, opponent policy, step budget). The budget is advisory:
+ # advancement is gated purely on eval win rate (see CurriculumCallback).
+ CURRICULUM_STAGES = [
+     ("static", static_opponent, 150000),
+     ("random", random_valid_opponent, 200000),
+     ("simple_bomb", simple_bomb_opponent, 250000),
+     ("evasive", evasive_opponent, 200000),
+     ("mixed", None, 200000),  # None -> mixed_opponent (cycles through all)
+ ]
+
+
+ class CurriculumEnv(gymnasium.Env):
+     """Single-agent env with curriculum opponents."""
+     def __init__(self, opponent_fn=None, cfg=None):
+         super().__init__()
+         self.cfg = cfg or default_config()
+         self.cfg.env.render_mode = None
+         raw = Bomberman(self.cfg)
+         self._parallel_env = aec_to_parallel(raw)
+         self.agent_id = "agent_0"
+         self._episode_count = 0
+         self.action_space = Discrete(6)
+         self._last_action_mask = None
+         self._obs_size = None
+         self._last_obs_dict = None
+         self.opponent_fn = opponent_fn or random_valid_opponent
+         self._compute_obs_space()
+
+     def _compute_obs_space(self):
+         cfg = self.cfg
+         # Viewcone length/width from the vision config; 25 channels per cell
+         vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
+         vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
+         av = vl * vw * 25
+         br = int(cfg.entities.base.vision_radius)
+         bs = 2 * br + 1
+         bv = bs * bs * 25
+         # +11 covers the scalar features appended in _flatten()
+         self._obs_size = av + bv + 11
+         self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
+
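+     # NOTE: reset() below ignores the Gymnasium `seed` argument and instead
+     # reseeds deterministically from an internal episode counter.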
+     def reset(self, seed=None, options=None):
+         self._episode_count += 1
+         obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
+         self._last_obs_dict = obs_dict
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         return self._flatten(obs_dict[self.agent_id]), {}
+
+     def step(self, action):
+         actions = {self.agent_id: action}
+         for aid, obs in self._last_obs_dict.items():
+             if aid != self.agent_id:
+                 actions[aid] = self.opponent_fn(obs)
+         obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
+         self._last_obs_dict = obs_dict
+         if self.agent_id not in obs_dict:
+             return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         obs = self._flatten(obs_dict[self.agent_id])
+         r = float(rewards.get(self.agent_id, 0.0))
+         done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
+         return obs, r, done, False, infos.get(self.agent_id, {})
+
+     def action_masks(self):
+         return self._last_action_mask
+
+     def _flatten(self, od):
+         return np.concatenate([
+             od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
+             np.array([od["direction"]], dtype=np.float32),
+             od["location"].flatten().astype(np.float32),
+             od["base_location"].flatten().astype(np.float32),
+             od["health"].flatten().astype(np.float32),
+             np.array([od["frozen_ticks"]], dtype=np.float32),
+             od["base_health"].flatten().astype(np.float32),
+             od["team_resources"].flatten().astype(np.float32),
+             np.array([od["team_bombs"]], dtype=np.float32),
+             np.array([od["step"]], dtype=np.float32),
+         ], dtype=np.float32)
+
+     def close(self):
+         self._parallel_env.close()
+
+
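+ # A minimal smoke test for CurriculumEnv (not executed during training),
+ # assuming the Gymnasium API implemented above:
+ #   env = CurriculumEnv()
+ #   obs, _ = env.reset()
+ #   obs, r, terminated, truncated, info = env.step(4)  # 4 = STAY
+
+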
+ class CurriculumCallback(BaseCallback):
+     """Advances curriculum stage based on win rate + pushes checkpoints."""
+     def __init__(self, train_env, eval_freq=50000, save_freq=50000):
+         super().__init__()
+         self.train_env = train_env  # unwrapped CurriculumEnv whose opponent gets swapped
+         self.eval_freq = eval_freq
+         self.save_freq = save_freq
+         self.stage_idx = 0
+         self.wins_history = []
+         self.eval_episodes = 100
+
+     def _on_step(self) -> bool:
+         if self.num_timesteps % self.save_freq == 0:
+             path = os.path.join(DATA_DIR, f"phase3_ckpt_{self.num_timesteps}.zip")
+             self.model.save(path)
+             hub_push(path, f"phase3_ckpt_{self.num_timesteps}.zip")
+
+         if self.num_timesteps % self.eval_freq == 0:
+             self._evaluate_and_maybe_advance()
+         return True
+
+     def _evaluate_and_maybe_advance(self):
+         stage_name, opp_fn, _ = CURRICULUM_STAGES[self.stage_idx]
+         print(f"\n--- Evaluating at stage {stage_name} (step {self.num_timesteps}) ---")
+
+         # Run eval episodes against the current stage's opponent
+         env = CurriculumEnv(opponent_fn=opp_fn or mixed_opponent, cfg=default_config())
+         env = ActionMasker(env, lambda e: e.action_masks())
+         wins = 0; total_r = 0
+         for ep in range(self.eval_episodes):
+             obs, _ = env.reset(seed=ep + 100000 + self.num_timesteps)
+             ep_r = 0; done = False
+             while not done:
+                 action, _ = self.model.predict(obs, action_masks=env.action_masks(), deterministic=True)
+                 obs, r, done, _, _ = env.step(int(action))
+                 ep_r += r
+             total_r += ep_r
+             if ep_r > 10:  # episode return above 10 counts as a win (heuristic)
+                 wins += 1
+         env.close()
+
+         win_rate = wins / self.eval_episodes
+         avg_r = total_r / self.eval_episodes
+         print(f" Win rate: {win_rate:.1%}, Avg reward: {avg_r:.1f}")
+         self.wins_history.append((self.num_timesteps, stage_name, win_rate, avg_r))
+
+         # Save eval results
+         eval_file = f"/app/phase3_eval_{self.num_timesteps}.txt"
+         with open(eval_file, "w") as f:
+             f.write(f"Stage: {stage_name}\nStep: {self.num_timesteps}\nWinRate: {win_rate:.1%}\nAvgReward: {avg_r:.1f}\n")
+         hub_push(eval_file, f"phase3_eval_{self.num_timesteps}.txt")
+
+         # Advance when the win rate clears 55%. The per-stage step budgets in
+         # CURRICULUM_STAGES are advisory; only this win-rate gate is enforced.
+         if win_rate > 0.55 and self.stage_idx < len(CURRICULUM_STAGES) - 1:
+             self.stage_idx += 1
+             new_stage, new_fn, _ = CURRICULUM_STAGES[self.stage_idx]
+             # Swap the opponent in the live training env ("mixed"/None maps to mixed_opponent)
+             self.train_env.opponent_fn = new_fn or mixed_opponent
+             print(f" >>> ADVANCING to stage: {new_stage} <<<")
+
+
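+ # The callback holds the *base* CurriculumEnv (not the Monitor/ActionMasker
+ # wrappers built in main) so the opponent swap above reaches the env that
+ # actually drives the scripted opponents.
+
+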
+ def main():
+     print("=" * 60)
+     print("PHASE 3: Rule-Based Curriculum")
+     print("=" * 60)
+
+     # Download the latest checkpoint (phase2_final or best available)
+     latest = None
+     for ckpt in ["phase2_final.zip", "phase2_ckpt_600352.zip", "phase1_final.zip"]:
+         try:
+             latest = hf_hub_download(repo_id=HUB_REPO, filename=ckpt, repo_type="model", local_dir=DATA_DIR)
+             print(f"Downloaded checkpoint: {ckpt}")
+             break
+         except Exception:
+             pass
+     if latest is None:
+         raise RuntimeError("No checkpoint found!")
+
+     # Start with the first curriculum stage
+     stage_name, opp_fn, _ = CURRICULUM_STAGES[0]
+     print(f"Starting curriculum stage: {stage_name}")
+
+     cfg = default_config()
+     cfg.env.render_mode = None
+     base = CurriculumEnv(opponent_fn=opp_fn, cfg=cfg)
+     env = ActionMasker(base, lambda e: e.action_masks())
+     env = Monitor(env)
+
+     model = MaskablePPO.load(latest, env=env)
+     start_ts = model.num_timesteps
+     print(f"Loaded model at timestep {start_ts}")
+
+     cb = CurriculumCallback(base, eval_freq=50000, save_freq=50000)
+     model.learn(total_timesteps=1000000, callback=cb, progress_bar=False, reset_num_timesteps=False)
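+     # With reset_num_timesteps=False, SB3 keeps Phase 2's global step counter
+     # and adds total_timesteps on top of it, so this run trains 1M additional
+     # steps and the checkpoint names stay monotonic across phases.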
+
+     # Save final
+     final = os.path.join(DATA_DIR, "phase3_final.zip")
+     model.save(final)
+     hub_push(final, "phase3_final.zip")
+
+     # Final eval
+     print("\n=== FINAL EVALUATION ===")
+     raw = Bomberman(default_config())
+     env = aec_to_parallel(raw)
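+     # Final eval runs against random_valid_opponent (the Phase 1-style
+     # baseline), presumably to keep scores comparable across phases.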
+     wins = 0; total_r = 0
+     for ep in range(200):
+         obs, _ = env.reset(seed=ep + 200000)
+         ep_r = 0; done = False
+         while not done:
+             if "agent_0" not in obs:
+                 break
+             ao = obs["agent_0"]
+             mask = np.array(ao.get("action_mask", [1]*6), dtype=bool)
+             # Flatten the raw observation exactly as CurriculumEnv._flatten does
+             vec = np.concatenate([
+                 np.array(ao["agent_viewcone"], np.float32).flatten(),
+                 np.array(ao["base_viewcone"], np.float32).flatten(),
+                 np.array([ao["direction"]], np.float32),
+                 np.array(ao["location"], np.float32).flatten(),
+                 np.array(ao["base_location"], np.float32).flatten(),
+                 np.array(ao["health"], np.float32).flatten(),
+                 np.array([ao["frozen_ticks"]], np.float32),
+                 np.array(ao["base_health"], np.float32).flatten(),
+                 np.array(ao["team_resources"], np.float32).flatten(),
+                 np.array([ao["team_bombs"]], np.float32),
+                 np.array([ao["step"]], np.float32),
+             ], dtype=np.float32)
+             action, _ = model.predict(vec, action_masks=mask, deterministic=True)
+             acts = {"agent_0": int(action)}
+             for aid, o in obs.items():
+                 if aid != "agent_0":
+                     acts[aid] = random_valid_opponent(o)
+             obs, rewards, terminations, truncations, _ = env.step(acts)
+             ep_r += rewards.get("agent_0", 0)
+             done = terminations.get("agent_0", False) or truncations.get("agent_0", False) or "agent_0" not in obs
+         total_r += ep_r
+         if ep_r > 10:
+             wins += 1
+     env.close()
+
+     results = (
+         f"=== Phase 3 Final Evaluation ===\n"
+         f"Episodes: 200\n"
+         f"Win Rate: {wins/200:.1%}\n"
+         f"Avg Reward: {total_r/200:.1f}\n"
+     )
+     print(results)
+     with open("/app/phase3_final_eval.txt", "w") as f:
+         f.write(results)
+     hub_push("/app/phase3_final_eval.txt", "phase3_final_eval.txt")
+     print("\n✅ PHASE 3 COMPLETE!")
+
+
+ if __name__ == "__main__":
+     main()