E-Rong committed
Commit 06087ac · verified · 1 Parent(s): cc82514

Upload phase2_resume.py

Files changed (1)
  1. phase2_resume.py +268 -0
phase2_resume.py ADDED
#!/usr/bin/env python3
"""Phase 2 training job - RESUME from checkpoint. Runs in HF Jobs."""
import os, sys, subprocess, numpy as np, torch, gymnasium
from gymnasium.spaces import Box, Discrete

# ── 1. Download TIL env via snapshot_download (works with HF_TOKEN) ──
print("[1/4] Downloading TIL repo via snapshot_download...")
from huggingface_hub import snapshot_download
snapshot_download(
    repo_id="e-rong/til-26-ae",
    repo_type="space",
    local_dir="/app/til-26-ae-repo",
)
# Find package root (contains pyproject.toml)
PKG_ROOT = None
for root, dirs, files in os.walk("/app/til-26-ae-repo"):
    if "pyproject.toml" in files:
        PKG_ROOT = root
        break
if PKG_ROOT is None:
    raise RuntimeError("pyproject.toml not found in downloaded repo")
print(f"  Package root: {PKG_ROOT}")
subprocess.run(["pip", "install", "-e", "."], cwd=PKG_ROOT, check=True)
sys.path.insert(0, PKG_ROOT)
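# The editable install registers the package and its dependencies, while the
# sys.path insert makes it importable in this same interpreter session without
# a restart. Assumes `pip` is available inside the HF Jobs container and that
# HF_TOKEN is set if the Space is private.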

from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from huggingface_hub import HfApi, hf_hub_download

HUB_REPO = "E-Rong/til-26-ae-agent"
DATA_DIR = "/app/data"
os.makedirs(DATA_DIR, exist_ok=True)


def hub_push(local_path, repo_path):
    try:
        HfApi().upload_file(path_or_fileobj=local_path, path_in_repo=repo_path,
                            repo_id=HUB_REPO, repo_type="model")
        print(f" -> pushed {repo_path}")
    except Exception as e:
        print(f" -> push failed: {e}")
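# Failures are swallowed deliberately: a transient Hub error should only log a
# warning, not kill a multi-hour training run. Assumes HF_TOKEN grants write
# access to HUB_REPO.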


class BombermanSingleAgentEnv(gymnasium.Env):
    def __init__(self, cfg=None):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None
        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self._episode_count = 0
        self.action_space = Discrete(6)
        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None
        self._compute_obs_space()

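    # Observation layout sketch: each viewcone cell is treated as 25 features
    # (one per tile type), plus 11 trailing scalars (direction, location,
    # base_location, health, frozen_ticks, base_health, team_resources,
    # team_bombs, step, with location and base_location assumed to be (x, y)
    # pairs). E.g. a hypothetical 7x5 agent viewcone alone would contribute
    # 7 * 5 * 25 = 875 entries of the flattened vector.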
    def _compute_obs_space(self):
        cfg = self.cfg
        vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        av = vl * vw * 25
        br = int(cfg.entities.base.vision_radius)
        bs = 2 * br + 1
        bv = bs * bs * 25
        self._obs_size = av + bv + 11
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        self._episode_count += 1
        obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {self.agent_id: action}
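        # Opponents are not trained here: every other agent samples uniformly
        # from its currently valid actions, so this phase learns against
        # random play rather than self-play.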
        for aid, obs in self._last_obs_dict.items():
            if aid != self.agent_id:
                valid = np.where(obs["action_mask"] == 1)[0]
                actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        if self.agent_id not in obs_dict:
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        obs = self._flatten(obs_dict[self.agent_id])
        r = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
        return obs, r, done, False, infos.get(self.agent_id, {})

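    # ActionMasker (wired up in main()) calls this hook once per step;
    # returning the mask cached from the most recent observation keeps
    # MaskablePPO's invalid-action masking in sync with the environment.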
    def action_masks(self):
        return self._last_action_mask

    def _flatten(self, od):
        return np.concatenate([
            od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
            np.array([od["direction"]], dtype=np.float32),
            od["location"].flatten().astype(np.float32),
            od["base_location"].flatten().astype(np.float32),
            od["health"].flatten().astype(np.float32),
            np.array([od["frozen_ticks"]], dtype=np.float32),
            od["base_health"].flatten().astype(np.float32),
            od["team_resources"].flatten().astype(np.float32),
            np.array([od["team_bombs"]], dtype=np.float32),
            np.array([od["step"]], dtype=np.float32),
        ], dtype=np.float32)

    def close(self):
        self._parallel_env.close()


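# Reward-shaping sketch: the per-cell exploration bonus decays as 1/(1+visits),
# and at each episode end the bonus weight is re-annealed via
# alpha = 1 - tanh(adaptive_k * avg_enemy_deaths), where the kill EMA rises on
# episodes whose raw reward exceeds 20. Illustration with hypothetical numbers:
# with the EMA at 0.5 and adaptive_k = 1.2, tanh(0.6) ≈ 0.537, so the weight
# drops from 0.5 to roughly 0.5 * 0.463 ≈ 0.23; the max(0.1, alpha) floor keeps
# at least 10% of the base exploration signal.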
class RewardShapingWrapper(gymnasium.Wrapper):
    """Visit-count exploration with adaptive annealing."""
    def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
        super().__init__(env)
        self.adaptive_k = adaptive_k
        self.base_explore_weight = base_explore_weight
        self._visit_counts = None
        self._grid_size = 16
        self._avg_enemy_deaths = 0.0
        self._explore_weight = base_explore_weight

    def reset(self, **kwargs):
        self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        pos = info.get("location", None)
        bonus = 0.0
        if pos is not None:
            x, y = int(pos[0]), int(pos[1])
            if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
                visits = self._visit_counts[x, y]
                bonus = 1.0 / (1.0 + visits)
                self._visit_counts[x, y] += 1
        if done:
            alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
            self._explore_weight = self.base_explore_weight * max(0.1, alpha)
            if reward > 20.0:
                self._avg_enemy_deaths = 0.95 * self._avg_enemy_deaths + 0.05 * 1.0
        shaped = reward + self._explore_weight * bonus
        info["raw_reward"] = reward
        info["explore_bonus"] = bonus
        return obs, shaped, done, truncated, info

    def action_masks(self):
        return self.env.action_masks()


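# SB3 invokes _on_step() after every environment step; with a single env the
# timestep counter advances by one per call, so the modulo test fires once
# every save_freq steps and each checkpoint is pushed to the Hub as soon as it
# is written.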
class HubCheckpointCallback(CheckpointCallback):
    """Saves locally + pushes to Hub."""
    def _on_step(self) -> bool:
        if self.num_timesteps % self.save_freq == 0:
            path = os.path.join(self.save_path, f"phase2_ckpt_{self.num_timesteps}.zip")
            self.model.save(path)
            hub_push(path, f"phase2_ckpt_{self.num_timesteps}.zip")
        return True


def main():
    print("=" * 60)
    print("PHASE 2: Adaptive Exploration Annealing — RESUME")
    print("=" * 60)

    # Download latest checkpoint from Hub
    latest = None
    for ckpt in ["phase2_ckpt_600352.zip", "phase2_ckpt_550352.zip", "phase1_final.zip"]:
        try:
            latest = hf_hub_download(repo_id=HUB_REPO, filename=ckpt, repo_type="model", local_dir=DATA_DIR)
            print(f"Downloaded checkpoint: {ckpt}")
            break
        except Exception as e:
            print(f"  {ckpt} not found: {e}")
    if latest is None:
        raise RuntimeError("No checkpoint found on Hub!")
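    # The candidate list is ordered newest-first; hf_hub_download raises when a
    # file is absent from the repo (EntryNotFoundError), which the loop treats
    # as "try the next checkpoint" before falling back to the Phase 1 weights.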

    # Environment
    cfg = default_config()
    cfg.env.render_mode = None
    base = BombermanSingleAgentEnv(cfg=cfg)
    env = ActionMasker(RewardShapingWrapper(base), lambda e: e.action_masks())
    env = Monitor(env)

    # Load model
    print(f"Loading model from {latest}...")
    model = MaskablePPO.load(latest, env=env)
    start_ts = model.num_timesteps
    remaining = 1000000 - start_ts
    print(f"Current: {start_ts}, remaining: {remaining}, target: 1,000,000")

    # Train
    cb = HubCheckpointCallback(save_freq=50000, save_path=DATA_DIR, name_prefix="phase2")
    model.learn(total_timesteps=remaining, callback=cb, progress_bar=False, reset_num_timesteps=False)

    # Save final
    final = os.path.join(DATA_DIR, "phase2_final.zip")
    model.save(final)
    hub_push(final, "phase2_final.zip")
    env.close()

    print("\n=== Phase 2 COMPLETE ===")
    print(f"Final timestep: {model.num_timesteps}")

    # Evaluation
    print("\n=== EVALUATION (100 eps vs Random) ===")
    raw = Bomberman(default_config())
    env = aec_to_parallel(raw)
    wins = 0; total_r = 0; lens = []; bombs = 0
    for ep in range(100):
        obs, _ = env.reset(seed=ep + 50000)
        ep_r = 0; steps = 0; done = False; ep_bombs = 0
        while not done:
            if "agent_0" not in obs: break
            ao = obs["agent_0"]
            mask = np.array(ao.get("action_mask", [1] * 6), dtype=bool)
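            # NOTE: this manual flattening must mirror
            # BombermanSingleAgentEnv._flatten exactly (same fields, same
            # order); any mismatch would silently scramble the policy's input.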
            vec = np.concatenate([
                np.array(ao["agent_viewcone"], np.float32).flatten(),
                np.array(ao["base_viewcone"], np.float32).flatten(),
                np.array([ao["direction"]], np.float32),
                np.array(ao["location"], np.float32).flatten(),
                np.array(ao["base_location"], np.float32).flatten(),
                np.array(ao["health"], np.float32).flatten(),
                np.array([ao["frozen_ticks"]], np.float32),
                np.array(ao["base_health"], np.float32).flatten(),
                np.array(ao["team_resources"], np.float32).flatten(),
                np.array([ao["team_bombs"]], np.float32),
                np.array([ao["step"]], np.float32),
            ], dtype=np.float32)
            action, _ = model.predict(vec, action_masks=mask, deterministic=True)
            if int(action) == 5: ep_bombs += 1
            acts = {"agent_0": int(action)}
            for aid, o in obs.items():
                if aid != "agent_0":
                    v = np.where(np.array(o["action_mask"]) == 1)[0]
                    acts[aid] = int(np.random.choice(v)) if len(v) > 0 else 4
            obs, rewards, terminations, truncations, _ = env.step(acts)
            ep_r += rewards.get("agent_0", 0)
            steps += 1
            done = terminations.get("agent_0", False) or truncations.get("agent_0", False) or "agent_0" not in obs
        total_r += ep_r; lens.append(steps); bombs += ep_bombs
        if ep_r > 10: wins += 1
    env.close()

    results = (
        f"=== Phase 2 Evaluation ===\n"
        f"Episodes: 100\n"
        f"Win Rate: {wins/100:.1%}\n"
        f"Avg Reward: {total_r/100:.1f}\n"
        f"Avg Length: {sum(lens)/len(lens):.1f}\n"
        f"Avg Bombs: {bombs/100:.1f}\n"
    )
    print(results)
    with open("/app/phase2_eval.txt", "w") as f:
        f.write(results)
    hub_push("/app/phase2_eval.txt", "phase2_eval_results.txt")
    print("\n✅ ALL DONE!")


if __name__ == "__main__":
    main()