E-Rong committed on
Commit 2f3c7cd · verified · 1 Parent(s): 2a7b40e

Add Phase 2 HF Job training script

Files changed (1):
  phase2_job.py  +243 -0
phase2_job.py ADDED
@@ -0,0 +1,243 @@
+ #!/usr/bin/env python3
+ """Phase 2 training job - runs in HF Jobs, resumes from Hub checkpoint."""
+ import os, sys, subprocess
+ import numpy as np
+ import torch
+ import gymnasium
+ from gymnasium.spaces import Box, Discrete
+ 
+ # Install TIL environment from source
+ TIL_REPO = "e-rong/til-26-ae"
+ TIL_PATH = "/app/til-26-ae-repo/til-26-ae"
+ if not os.path.exists(TIL_PATH):
+     subprocess.run(["git", "clone", f"https://huggingface.co/spaces/{TIL_REPO}", "/app/til-26-ae-repo"], check=True)
+ subprocess.run(["pip", "install", "-e", "."], cwd=TIL_PATH, check=True)
+ sys.path.insert(0, TIL_PATH)
+ 
+ from til_environment.bomberman_env import Bomberman
+ from til_environment.config import default_config
+ from pettingzoo.utils.conversions import aec_to_parallel
+ from sb3_contrib import MaskablePPO
+ from sb3_contrib.common.wrappers import ActionMasker
+ from stable_baselines3.common.callbacks import CheckpointCallback
+ from stable_baselines3.common.monitor import Monitor
+ from huggingface_hub import HfApi, hf_hub_download
+ 
+ HUB_REPO = "E-Rong/til-26-ae-agent"
+ DATA_DIR = "/app/data"
+ os.makedirs(DATA_DIR, exist_ok=True)
+ 
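+ # Best-effort push: failures are logged but never abort training.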
+ def hub_push(local_path, repo_path):
+     try:
+         HfApi().upload_file(path_or_fileobj=local_path, path_in_repo=repo_path,
+                             repo_id=HUB_REPO, repo_type="model")
+         print(f" -> pushed {repo_path}")
+     except Exception as e:
+         print(f" -> push failed: {e}")
+ 
+ class BombermanSingleAgentEnv(gymnasium.Env):
+     def __init__(self, cfg=None):
+         super().__init__()
+         self.cfg = cfg or default_config()
+         self.cfg.env.render_mode = None
+         raw = Bomberman(self.cfg)
+         self._parallel_env = aec_to_parallel(raw)
+         self.agent_id = "agent_0"
+         self._episode_count = 0
+         self.action_space = Discrete(6)
+         self._last_action_mask = None
+         self._obs_size = None
+         self._last_obs_dict = None
+         self._compute_obs_space()
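+     # Observation: two flattened viewcones (25 feature channels per tile) plus 11 scalar features.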
+     def _compute_obs_space(self):
+         cfg = self.cfg
+         vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
+         vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
+         av = vl * vw * 25
+         br = int(cfg.entities.base.vision_radius)
+         bs = 2 * br + 1
+         bv = bs * bs * 25
+         self._obs_size = av + bv + 11
+         self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
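+     # Note: the Gymnasium `seed` argument is accepted but ignored; each episode is
+     # seeded from an internal counter so successive episodes differ deterministically.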
+     def reset(self, seed=None, options=None):
+         self._episode_count += 1
+         obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
+         self._last_obs_dict = obs_dict
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         return self._flatten(obs_dict[self.agent_id]), {}
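+     # All non-learning agents pick uniformly at random among their valid (masked) actions.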
+     def step(self, action):
+         actions = {self.agent_id: action}
+         for aid, obs in self._last_obs_dict.items():
+             if aid != self.agent_id:
+                 valid = np.where(obs["action_mask"] == 1)[0]
+                 actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
+         obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
+         self._last_obs_dict = obs_dict
+         if self.agent_id not in obs_dict:
+             return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         obs = self._flatten(obs_dict[self.agent_id])
+         r = float(rewards.get(self.agent_id, 0.0))
+         done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
+         return obs, r, done, False, infos.get(self.agent_id, {})
+     def action_masks(self):
+         return self._last_action_mask
+     def _flatten(self, od):
+         return np.concatenate([
+             od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
+             np.array([od["direction"]], dtype=np.float32),
+             od["location"].flatten().astype(np.float32),
+             od["base_location"].flatten().astype(np.float32),
+             od["health"].flatten().astype(np.float32),
+             np.array([od["frozen_ticks"]], dtype=np.float32),
+             od["base_health"].flatten().astype(np.float32),
+             od["team_resources"].flatten().astype(np.float32),
+             np.array([od["team_bombs"]], dtype=np.float32),
+             np.array([od["step"]], dtype=np.float32),
+         ], dtype=np.float32)
+     def close(self):
+         self._parallel_env.close()
+ 
+ class RewardShapingWrapper(gymnasium.Wrapper):
+     """Visit-count exploration with adaptive annealing."""
+     def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
+         super().__init__(env)
+         self.adaptive_k = adaptive_k
+         self.base_explore_weight = base_explore_weight
+         self._visit_counts = None
+         self._grid_size = 16
+         self._avg_enemy_deaths = 0.0
+         self._explore_weight = base_explore_weight
+     def reset(self, **kwargs):
+         self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
+         return self.env.reset(**kwargs)
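+     # Per-cell bonus decays as 1/(1+visits); on episode end, the exploration weight
+     # anneals via 1 - tanh(k * kill EMA), floored at 10% of the base weight. A high
+     # terminal reward (> 20) is used as a proxy for an enemy kill.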
+     def step(self, action):
+         obs, reward, done, truncated, info = self.env.step(action)
+         pos = info.get("location", None)
+         bonus = 0.0
+         if pos is not None:
+             x, y = int(pos[0]), int(pos[1])
+             if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
+                 visits = self._visit_counts[x, y]
+                 bonus = 1.0 / (1.0 + visits)
+                 self._visit_counts[x, y] += 1
+         if done:
+             alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
+             self._explore_weight = self.base_explore_weight * max(0.1, alpha)
+             if reward > 20.0:
+                 self._avg_enemy_deaths = 0.95 * self._avg_enemy_deaths + 0.05 * 1.0
+         shaped = reward + self._explore_weight * bonus
+         info["raw_reward"] = reward
+         info["explore_bonus"] = bonus
+         return obs, shaped, done, truncated, info
+     def action_masks(self):
+         return self.env.action_masks()
+ 
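+ # With a single env, num_timesteps advances by 1 per step, so the check below
+ # fires exactly at multiples of save_freq.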
+ class HubCheckpointCallback(CheckpointCallback):
+     """Saves locally + pushes to Hub."""
+     def _on_step(self) -> bool:
+         if self.num_timesteps % self.save_freq == 0:
+             path = os.path.join(self.save_path, f"phase2_ckpt_{self.num_timesteps}.zip")
+             self.model.save(path)
+             hub_push(path, f"phase2_ckpt_{self.num_timesteps}.zip")
+         return True
+ 
+ 
+ def main():
+     print("=" * 60)
+     print("PHASE 2: Adaptive Exploration Annealing")
+     print("=" * 60)
+ 
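+     # Candidates are ordered newest-first, so the most recent available checkpoint wins.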
+     # Download latest checkpoint
+     latest = None
+     for ckpt in ["phase2_ckpt_600352.zip", "phase2_ckpt_550352.zip", "phase1_final.zip"]:
+         try:
+             latest = hf_hub_download(repo_id=HUB_REPO, filename=ckpt, repo_type="model", local_dir=DATA_DIR)
+             print(f"Downloaded checkpoint: {ckpt}")
+             break
+         except Exception:
+             print(f" {ckpt} not found, trying next...")
+     if latest is None:
+         raise RuntimeError("No checkpoint found on Hub!")
+ 
+     # Environment
+     cfg = default_config()
+     cfg.env.render_mode = None
+     base = BombermanSingleAgentEnv(cfg=cfg)
+     env = ActionMasker(RewardShapingWrapper(base), lambda e: e.action_masks())
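+     # RewardShapingWrapper forwards action_masks(), so ActionMasker can still reach the base env's mask.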
+     env = Monitor(env)
+ 
+     # Load model
+     print(f"Loading model from {latest}...")
+     model = MaskablePPO.load(latest, env=env)
+     start_ts = model.num_timesteps
+     remaining = 1000000 - start_ts
+     print(f"Current: {start_ts}, remaining: {remaining}, target: 1,000,000 (~1,000,352 after PPO rounds up to a full rollout)")
+ 
+     # Train
+     cb = HubCheckpointCallback(save_freq=50000, save_path=DATA_DIR, name_prefix="phase2")
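+     # With reset_num_timesteps=False, SB3 adds total_timesteps on top of the restored
+     # counter, so this trains `remaining` additional steps to reach the 1M target.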
+     model.learn(total_timesteps=remaining, callback=cb, progress_bar=False, reset_num_timesteps=False)
+ 
+     # Save final
+     final = os.path.join(DATA_DIR, "phase2_final.zip")
+     model.save(final)
+     hub_push(final, "phase2_final.zip")
+     env.close()
+ 
+     print("\n=== Phase 2 COMPLETE ===")
+     print(f"Final timestep: {model.num_timesteps}")
+ 
+     # Evaluation
+     print("\n=== EVALUATION (100 eps vs Random) ===")
+     raw = Bomberman(default_config())
+     env = aec_to_parallel(raw)
+     wins = 0; total_r = 0; lens = []; bombs = 0
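+     # Evaluation bypasses the training wrappers, so observations are flattened inline
+     # in the same order as BombermanSingleAgentEnv._flatten.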
+     for ep in range(100):
+         obs, _ = env.reset(seed=ep + 50000)
+         ep_r = 0; steps = 0; done = False; ep_bombs = 0
+         while not done:
+             if "agent_0" not in obs: break
+             ao = obs["agent_0"]
+             mask = np.array(ao.get("action_mask", [1]*6), dtype=bool)
+             vec = np.concatenate([
+                 np.array(ao["agent_viewcone"], np.float32).flatten(),
+                 np.array(ao["base_viewcone"], np.float32).flatten(),
+                 np.array([ao["direction"]], np.float32),
+                 np.array(ao["location"], np.float32).flatten(),
+                 np.array(ao["base_location"], np.float32).flatten(),
+                 np.array(ao["health"], np.float32).flatten(),
+                 np.array([ao["frozen_ticks"]], np.float32),
+                 np.array(ao["base_health"], np.float32).flatten(),
+                 np.array(ao["team_resources"], np.float32).flatten(),
+                 np.array([ao["team_bombs"]], np.float32),
+                 np.array([ao["step"]], np.float32),
+             ], dtype=np.float32)
+             action, _ = model.predict(vec, action_masks=mask, deterministic=True)
+             if int(action) == 5: ep_bombs += 1
+             acts = {"agent_0": int(action)}
+             for aid, o in obs.items():
+                 if aid != "agent_0":
+                     v = np.where(np.array(o["action_mask"]) == 1)[0]
+                     acts[aid] = int(np.random.choice(v)) if len(v) > 0 else 4
+             obs, rewards, terminations, truncations, _ = env.step(acts)
+             ep_r += rewards.get("agent_0", 0)
+             steps += 1
+             done = terminations.get("agent_0", False) or truncations.get("agent_0", False) or "agent_0" not in obs
+         total_r += ep_r; lens.append(steps); bombs += ep_bombs
+         if ep_r > 10: wins += 1
+     env.close()
+ 
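+     # "Win" is a proxy here: any episode with return above 10 counts as a win.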
+     results = (
+         f"=== Phase 2 Evaluation ===\n"
+         f"Episodes: 100\n"
+         f"Win Rate: {wins/100:.1%}\n"
+         f"Avg Reward: {total_r/100:.1f}\n"
+         f"Avg Length: {sum(lens)/len(lens):.1f}\n"
+         f"Avg Bombs: {bombs/100:.1f}\n"
+     )
+     print(results)
+     with open("/app/phase2_eval.txt", "w") as f:
+         f.write(results)
+     hub_push("/app/phase2_eval.txt", "phase2_eval_results.txt")
+     print("\n✅ ALL DONE!")
+ 
+ 
+ if __name__ == "__main__":
+     main()