E-Rong committed on
Commit 7b2f944 · verified · 1 Parent(s): 0e66bf6

Add train_in_space.py for running training inside the Space

Files changed (1)
  1. train_in_space.py +366 -0
train_in_space.py ADDED
@@ -0,0 +1,366 @@
+ #!/usr/bin/env python3
+ """
+ TIL-26-AE Bomberman Agent Training - runs inside the Space.
+ Uses the local til_environment (already in the repo) and pushes checkpoints to the Hub model repo.
+ """
+
+ import os
+ import sys
+
+ import numpy as np
+ import gymnasium as gym
+ from gymnasium.spaces import Box, Discrete
+ import torch
+
+ # In the Space, til-26-ae is at the repo root; in the sandbox it's elsewhere.
+ # Try multiple paths.
+ for path in [
+     "/home/user/app/til-26-ae",                                   # typical HF Space path
+     "/app/til-26-ae",                                             # sandbox path
+     os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),  # relative to this file
+     "til-26-ae",                                                  # current dir
+ ]:
+     if os.path.isdir(path):
+         sys.path.insert(0, path)
+         print(f"Using til_environment from: {path}")
+         break
+
+ from til_environment.bomberman_env import Bomberman
+ from til_environment.config import default_config
+ from pettingzoo.utils.conversions import aec_to_parallel
+ from sb3_contrib import MaskablePPO
+ from sb3_contrib.common.wrappers import ActionMasker
+ from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
+ from stable_baselines3.common.monitor import Monitor
+ from huggingface_hub import HfApi
+
+ # ---------------------------------------------------------------------------
+ # Environment wrappers
+ # ---------------------------------------------------------------------------
+
+
+ class BombermanSingleAgentEnv(gym.Env):
+     """Single-agent view of the multi-agent Bomberman env; other agents act randomly."""
+
+     def __init__(self, cfg=None, seed=None, opponent_policy="random"):
+         super().__init__()
+         self.cfg = cfg or default_config()
+         self.cfg.env.render_mode = None
+         raw = Bomberman(self.cfg)
+         self._parallel_env = aec_to_parallel(raw)
+         self.agent_id = "agent_0"
+         self._episode_seed = seed
+         self._episode_count = 0
+         self.action_space = Discrete(6)
+         self._last_action_mask = None
+         self._obs_size = None
+         self._last_obs_dict = None
+         self._compute_obs_space()
+
+     def _compute_obs_space(self):
+         cfg = self.cfg
+         viewcone_l = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
+         viewcone_w = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
+         agent_viewcone_size = viewcone_l * viewcone_w * 25
+         base_r = int(cfg.entities.base.vision_radius)
+         base_side = 2 * base_r + 1
+         base_viewcone_size = base_side * base_side * 25
+         scalar_size = 11
+         self._obs_size = agent_viewcone_size + base_viewcone_size + scalar_size
+         self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
+
+     def reset(self, seed=None, options=None):
+         self._episode_seed = self._episode_count if seed is None else seed
+         self._episode_count += 1
+         obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
+         self._last_obs_dict = obs_dict
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         return self._flatten_obs(obs_dict[self.agent_id]), {}
+
+     def step(self, action):
+         actions = {self.agent_id: action}
+         # All other agents pick a uniformly random valid action.
+         for aid, obs in self._last_obs_dict.items():
+             if aid != self.agent_id:
+                 valid = np.where(obs["action_mask"] == 1)[0]
+                 actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
+         obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
+         self._last_obs_dict = obs_dict
+         if self.agent_id not in obs_dict:  # agent was eliminated this step
+             return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         obs = self._flatten_obs(obs_dict[self.agent_id])
+         r = float(rewards.get(self.agent_id, 0.0))
+         done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
+         return obs, r, done, False, infos.get(self.agent_id, {})
+
+     def action_masks(self):
+         return self._last_action_mask
+
+     def _flatten_obs(self, od):
+         return np.concatenate([
+             od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
+             np.array([od["direction"]], dtype=np.float32),
+             od["location"].flatten().astype(np.float32),
+             od["base_location"].flatten().astype(np.float32),
+             od["health"].flatten().astype(np.float32),
+             np.array([od["frozen_ticks"]], dtype=np.float32),
+             od["base_health"].flatten().astype(np.float32),
+             od["team_resources"].flatten().astype(np.float32),
+             np.array([od["team_bombs"]], dtype=np.float32),
+             np.array([od["step"]], dtype=np.float32),
+         ], dtype=np.float32)
+
+     def close(self):
+         self._parallel_env.close()
+
+
+ class RewardShapingWrapper(gym.Wrapper):
+     """Adds a count-based exploration bonus whose weight anneals over training."""
+
+     def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
+         super().__init__(env)
+         self.adaptive_k = adaptive_k
+         self.base_explore_weight = base_explore_weight
+         self._visit_counts = None
+         self._grid_size = 16
+         self._avg_enemy_deaths = 0.0  # placeholder; never updated, so alpha stays 1.0 until wired to env stats
+         self._episode_count = 0
+         self._explore_weight = base_explore_weight
+
+     def reset(self, **kwargs):
+         self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
+         return self.env.reset(**kwargs)
+
+     def step(self, action):
+         obs, reward, done, truncated, info = self.env.step(action)
+         pos = info.get("location", None)
+         bonus = 0.0
+         if pos is not None:
+             x, y = int(pos[0]), int(pos[1])
+             if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
+                 bonus = 1.0 / (1.0 + self._visit_counts[x, y])  # novelty decays with revisits
+                 self._visit_counts[x, y] += 1
+         if done:
+             self._episode_count += 1
+             # Shrink the exploration weight as average enemy deaths rise.
+             alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
+             self._explore_weight = self.base_explore_weight * max(0.1, alpha)
+         return obs, reward + self._explore_weight * bonus, done, truncated, info
+
+     def action_masks(self):
+         return self.env.action_masks()
+
+
+ class RuleBasedOpponent:
+     """Scripted opponent. Action 4 is the stay/no-op action, action 5 drops a bomb."""
+
+     def __init__(self, difficulty="simple"):
+         self.difficulty = difficulty
+
+     def act(self, od):
+         valid = np.where(od["action_mask"] == 1)[0]
+         if len(valid) == 0:
+             return 4
+         if self.difficulty == "static":
+             return 4
+         if self.difficulty == "simple":
+             vc = od["agent_viewcone"]
+             # Bomb if anything shows up in viewcone channels 10 or 12 and bombing is legal.
+             if (np.any(vc[..., 10] > 0) or np.any(vc[..., 12] > 0)) and 5 in valid:
+                 return 5
+             mv = [a for a in valid if a < 4]
+             return int(np.random.choice(mv)) if mv else 4
+         # Difficulties other than "static"/"simple" (e.g. "smart", "mixed")
+         # currently fall back to the no-op action.
+         return 4
+
+
+ class CurriculumEnv(gym.Env):
+     """Single-agent env whose scripted opponents get harder as the win rate improves."""
+
+     STAGES = ["static", "simple", "smart", "mixed"]
+     WIN_RATE = 0.55       # win-rate threshold to advance a stage
+     EPS_PER_STAGE = 500   # minimum episodes before advancement is considered
+
+     def __init__(self, cfg=None, seed=None):
+         super().__init__()
+         self.cfg = cfg or default_config()
+         self.cfg.env.render_mode = None
+         self._parallel_env = aec_to_parallel(Bomberman(self.cfg))
+         self.agent_id = "agent_0"
+         self._episode_count = 0
+         self.action_space = Discrete(6)
+         self._last_action_mask = None
+         self._obs_size = None
+         self._last_obs_dict = None
+         self._compute_obs_space()
+         self.stage_idx = 0
+         self.stage_eps = 0
+         self.stage_wins = 0
+         self.stage_rewards = []
+         self.opponents = {}
+         self._init_opponents()
+
+     def _compute_obs_space(self):
+         cfg = self.cfg
+         vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
+         vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
+         av = vl * vw * 25
+         br = int(cfg.entities.base.vision_radius)
+         bs = 2 * br + 1
+         bv = bs * bs * 25
+         self._obs_size = av + bv + 11
+         self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
+
+     def _init_opponents(self):
+         for i in range(1, self.cfg.env.num_teams):
+             self.opponents[f"agent_{i}"] = RuleBasedOpponent(difficulty="static")
+
+     def _update_difficulty(self):
+         stage = self.STAGES[self.stage_idx]
+         # In the "mixed" stage, even-numbered opponents play "smart"; the rest keep the stage name.
+         for oid, opp in self.opponents.items():
+             opp.difficulty = "smart" if (stage == "mixed" and int(oid.split("_")[1]) % 2 == 0) else stage
+
+     def _check_advance(self):
+         if self.stage_idx >= len(self.STAGES) - 1:
+             return False
+         if len(self.stage_rewards) >= self.EPS_PER_STAGE:
+             wr = self.stage_wins / max(1, len(self.stage_rewards))
+             if wr >= self.WIN_RATE:
+                 print(f"Stage {self.STAGES[self.stage_idx]} done (wr={wr:.1%}). Advancing.")
+                 self.stage_idx += 1
+                 self.stage_eps = self.stage_wins = 0
+                 self.stage_rewards = []
+                 self._update_difficulty()
+                 return True
+         return False
+
+     def reset(self, seed=None, options=None):
+         self._episode_count += 1
+         obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
+         self._last_obs_dict = obs_dict
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         return self._flatten(obs_dict[self.agent_id]), {}
+
+     def step(self, action):
+         actions = {self.agent_id: action}
+         for aid, obs in self._last_obs_dict.items():
+             if aid != self.agent_id:
+                 opp = self.opponents.get(aid)
+                 actions[aid] = opp.act(obs) if opp else 4
+         obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
+         self._last_obs_dict = obs_dict
+         if self.agent_id not in obs_dict:  # eliminated: counts as an episode but not a win
+             self.stage_eps += 1
+             return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
+         self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
+         obs = self._flatten(obs_dict[self.agent_id])
+         r = float(rewards.get(self.agent_id, 0.0))
+         done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
+         if done:
+             self.stage_eps += 1
+             self.stage_rewards.append(r)
+             if r > 10.0:  # a final-step reward above 10 is counted as a win
+                 self.stage_wins += 1
+             self._check_advance()
+         return obs, r, done, False, {"stage": self.stage_idx, "stage_name": self.STAGES[self.stage_idx]}
+
+     def action_masks(self):
+         return self._last_action_mask
+
+     def _flatten(self, od):
+         return np.concatenate([
+             od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
+             np.array([od["direction"]], dtype=np.float32),
+             od["location"].flatten().astype(np.float32),
+             od["base_location"].flatten().astype(np.float32),
+             od["health"].flatten().astype(np.float32),
+             np.array([od["frozen_ticks"]], dtype=np.float32),
+             od["base_health"].flatten().astype(np.float32),
+             od["team_resources"].flatten().astype(np.float32),
+             np.array([od["team_bombs"]], dtype=np.float32),
+             np.array([od["step"]], dtype=np.float32),
+         ], dtype=np.float32)
+
+     def close(self):
+         self._parallel_env.close()
+
+
+ # ---------------------------------------------------------------------------
+ # Training
+ # ---------------------------------------------------------------------------
+
+ HUB_REPO = os.environ.get("HUB_MODEL_ID", "E-Rong/til-26-ae-agent")
+
+
+ def hub_push(path_in_local, path_in_repo, repo_id=HUB_REPO):
+     """Push a file to the Hub model repo."""
+     try:
+         api = HfApi()
+         api.upload_file(path_or_fileobj=path_in_local, path_in_repo=path_in_repo,
+                         repo_id=repo_id, repo_type="model")
+         print(f" -> pushed {path_in_repo}")
+     except Exception as e:
+         print(f" -> push failed: {e}")
+
+
+ class HubCheckpointCallback(BaseCallback):
+     """Pushes .zip checkpoints to the Hub every N steps."""
+
+     def __init__(self, save_freq=50000, repo_id=HUB_REPO, verbose=0):
+         super().__init__(verbose)
+         self.save_freq = save_freq
+         self.repo_id = repo_id
+
+     def _on_step(self) -> bool:
+         if self.num_timesteps % self.save_freq == 0:
+             path = f"/tmp/checkpoint_{self.num_timesteps}.zip"
+             self.model.save(path)
+             hub_push(path, f"checkpoint_{self.num_timesteps}.zip", self.repo_id)
+         return True
+
+
+ def train_phase(phase, total_timesteps, model=None):
+     cfg = default_config()
+     cfg.env.render_mode = None
+
+     if phase == 1:
+         print("=== Phase 1: vs Random ===")
+         base = BombermanSingleAgentEnv(cfg=cfg)
+         env = ActionMasker(Monitor(base), lambda e: e.action_masks())
+     elif phase == 2:
+         print("=== Phase 2: + Exploration Shaping ===")
+         base = BombermanSingleAgentEnv(cfg=cfg)
+         base = RewardShapingWrapper(base)
+         env = ActionMasker(Monitor(base), lambda e: e.action_masks())
+     elif phase == 3:
+         print("=== Phase 3: Curriculum Self-Play ===")
+         cfg.env.num_teams = 3
+         base = CurriculumEnv(cfg=cfg)
+         env = ActionMasker(Monitor(base), lambda e: e.action_masks())
+     else:
+         raise ValueError(f"Unknown phase: {phase}")
+
+     if model is None:
+         print("Creating MaskablePPO...")
+         model = MaskablePPO(
+             "MlpPolicy", env,
+             learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10,
+             gamma=0.99, gae_lambda=0.95, clip_range=0.2,
+             ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
+             verbose=1,
+             device="cuda" if torch.cuda.is_available() else "cpu",
+         )
+     else:
+         model.set_env(env)
+
+     ckpt_cb = CheckpointCallback(save_freq=50000, save_path="./ckpts", name_prefix=f"p{phase}")
+     hub_cb = HubCheckpointCallback(save_freq=50000, repo_id=HUB_REPO)
+
+     model.learn(total_timesteps=total_timesteps, callback=[ckpt_cb, hub_cb], progress_bar=False)
+     final = f"phase{phase}_final.zip"
+     model.save(final)
+     hub_push(final, final, HUB_REPO)
+     env.close()
+     print(f"Phase {phase} complete.")
+     return model
+
+
+ def main():
+     # "A:B:C" -> timesteps for phases 1..3; underscores may be used as digit separators.
+     ts = os.environ.get("TOTAL_TIMESTEPS", "500000:500000:1000000")
+     phase_ts = [int(x.replace("_", "")) for x in ts.split(":")]
+     print(f"Phase timesteps: {phase_ts}")
+
+     model = None
+     for i, t in enumerate(phase_ts[:3], 1):
+         model = train_phase(i, t, model)
+
+     print("\n=== All phases complete ===")
+
+
+ if __name__ == "__main__":
+     main()
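
For context, a minimal sketch of how this script might be launched from the Space's entry point. The subprocess launcher itself is an assumption (not part of this commit); `TOTAL_TIMESTEPS` and `HUB_MODEL_ID` are the two environment variables the script actually reads, with the defaults shown below, and a write token (e.g. `HF_TOKEN`) is assumed to be available in the Space so the `HfApi` pushes can authenticate:

# Hypothetical launcher sketch -- not part of this commit.
import os
import subprocess

env = dict(os.environ)
env.setdefault("TOTAL_TIMESTEPS", "500000:500000:1000000")  # colon-separated per-phase budgets
env.setdefault("HUB_MODEL_ID", "E-Rong/til-26-ae-agent")    # checkpoint destination repo
subprocess.run(["python", "train_in_space.py"], env=env, check=True)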