Fix agent tracking to use possible_agents instead of agents attribute
Browse files- train_all_phases.py +16 -12
train_all_phases.py
CHANGED
|
@@ -3,12 +3,6 @@
|
|
| 3 |
Full training pipeline: Phase 1 -> Phase 2 -> Phase 3
|
| 4 |
TIL-26-AE Bomberman Agent Training
|
| 5 |
|
| 6 |
-
Run with:
|
| 7 |
-
TOTAL_TIMESTEPS=500_000:500_000:1_000_000 \
|
| 8 |
-
HUB_MODEL_ID=E-Rong/til-26-ae-agent \
|
| 9 |
-
TRACKIO_PROJECT=til-26-ae \
|
| 10 |
-
python train_all_phases.py
|
| 11 |
-
|
| 12 |
References:
|
| 13 |
- Pommerman multi-agent RL: arxiv:2407.00662
|
| 14 |
- MAPPO best practices: arxiv:2103.01955
|
|
@@ -101,6 +95,12 @@ class BombermanSingleAgentEnv(gym.Env):
|
|
| 101 |
shape=(self._obs_size,), dtype=np.float32,
|
| 102 |
)
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
def reset(self, seed=None, options=None):
|
| 105 |
if seed is not None:
|
| 106 |
self._episode_seed = seed
|
|
@@ -109,18 +109,18 @@ class BombermanSingleAgentEnv(gym.Env):
|
|
| 109 |
self._episode_count += 1
|
| 110 |
|
| 111 |
obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
|
| 112 |
-
self._store_action_mask(obs_dict[self.agent_id])
|
| 113 |
self._last_obs_dict = obs_dict
|
|
|
|
| 114 |
return self._flatten_obs(obs_dict[self.agent_id]), {}
|
| 115 |
|
| 116 |
def step(self, action):
|
| 117 |
actions = {}
|
| 118 |
-
for agent_id in self.
|
| 119 |
if agent_id == self.agent_id:
|
| 120 |
actions[agent_id] = action
|
| 121 |
else:
|
| 122 |
mask = (
|
| 123 |
-
self._last_obs_dict[agent_id]
|
| 124 |
if self._last_obs_dict and agent_id in self._last_obs_dict
|
| 125 |
else np.ones(6, dtype=np.int8)
|
| 126 |
)
|
|
@@ -374,6 +374,11 @@ class CurriculumEnv(gym.Env):
|
|
| 374 |
shape=(self._obs_size,), dtype=np.float32,
|
| 375 |
)
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
def _init_opponents(self):
|
| 378 |
for i in range(1, self.cfg.env.num_teams):
|
| 379 |
opp_id = f"agent_{i}"
|
|
@@ -422,14 +427,13 @@ class CurriculumEnv(gym.Env):
|
|
| 422 |
obs_dict, info_dict = self._parallel_env.reset(
|
| 423 |
seed=self._episode_seed, options=options
|
| 424 |
)
|
| 425 |
-
|
| 426 |
-
self._store_action_mask(obs_dict[self.agent_id])
|
| 427 |
self._last_obs_dict = obs_dict
|
|
|
|
| 428 |
return self._flatten_obs(obs_dict[self.agent_id]), {}
|
| 429 |
|
| 430 |
def step(self, action):
|
| 431 |
actions = {}
|
| 432 |
-
for agent_id in self.
|
| 433 |
if agent_id == self.agent_id:
|
| 434 |
actions[agent_id] = action
|
| 435 |
else:
|
|
|
|
| 3 |
Full training pipeline: Phase 1 -> Phase 2 -> Phase 3
|
| 4 |
TIL-26-AE Bomberman Agent Training
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
References:
|
| 7 |
- Pommerman multi-agent RL: arxiv:2407.00662
|
| 8 |
- MAPPO best practices: arxiv:2103.01955
|
|
|
|
| 95 |
shape=(self._obs_size,), dtype=np.float32,
|
| 96 |
)
|
| 97 |
|
| 98 |
+
def _get_agents(self):
|
| 99 |
+
"""Get list of currently active agents from obs_dict."""
|
| 100 |
+
if self._last_obs_dict is not None:
|
| 101 |
+
return list(self._last_obs_dict.keys())
|
| 102 |
+
return self._parallel_env.possible_agents
|
| 103 |
+
|
| 104 |
def reset(self, seed=None, options=None):
|
| 105 |
if seed is not None:
|
| 106 |
self._episode_seed = seed
|
|
|
|
| 109 |
self._episode_count += 1
|
| 110 |
|
| 111 |
obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
|
|
|
|
| 112 |
self._last_obs_dict = obs_dict
|
| 113 |
+
self._store_action_mask(obs_dict[self.agent_id])
|
| 114 |
return self._flatten_obs(obs_dict[self.agent_id]), {}
|
| 115 |
|
| 116 |
def step(self, action):
|
| 117 |
actions = {}
|
| 118 |
+
for agent_id in self._get_agents():
|
| 119 |
if agent_id == self.agent_id:
|
| 120 |
actions[agent_id] = action
|
| 121 |
else:
|
| 122 |
mask = (
|
| 123 |
+
self._last_obs_dict[agent_id].get("action_mask")
|
| 124 |
if self._last_obs_dict and agent_id in self._last_obs_dict
|
| 125 |
else np.ones(6, dtype=np.int8)
|
| 126 |
)
|
|
|
|
| 374 |
shape=(self._obs_size,), dtype=np.float32,
|
| 375 |
)
|
| 376 |
|
| 377 |
+
def _get_agents(self):
|
| 378 |
+
if self._last_obs_dict is not None:
|
| 379 |
+
return list(self._last_obs_dict.keys())
|
| 380 |
+
return self._parallel_env.possible_agents
|
| 381 |
+
|
| 382 |
def _init_opponents(self):
|
| 383 |
for i in range(1, self.cfg.env.num_teams):
|
| 384 |
opp_id = f"agent_{i}"
|
|
|
|
| 427 |
obs_dict, info_dict = self._parallel_env.reset(
|
| 428 |
seed=self._episode_seed, options=options
|
| 429 |
)
|
|
|
|
|
|
|
| 430 |
self._last_obs_dict = obs_dict
|
| 431 |
+
self._store_action_mask(obs_dict[self.agent_id])
|
| 432 |
return self._flatten_obs(obs_dict[self.agent_id]), {}
|
| 433 |
|
| 434 |
def step(self, action):
|
| 435 |
actions = {}
|
| 436 |
+
for agent_id in self._get_agents():
|
| 437 |
if agent_id == self.agent_id:
|
| 438 |
actions[agent_id] = action
|
| 439 |
else:
|