E-Rong committed on
Commit
0e66bf6
·
verified ·
1 Parent(s): dae5fb8

Fix agent tracking to use possible_agents instead of agents attribute

Browse files
Files changed (1) hide show
  1. train_all_phases.py +16 -12
train_all_phases.py CHANGED
@@ -3,12 +3,6 @@
3
  Full training pipeline: Phase 1 -> Phase 2 -> Phase 3
4
  TIL-26-AE Bomberman Agent Training
5
 
6
- Run with:
7
- TOTAL_TIMESTEPS=500_000:500_000:1_000_000 \
8
- HUB_MODEL_ID=E-Rong/til-26-ae-agent \
9
- TRACKIO_PROJECT=til-26-ae \
10
- python train_all_phases.py
11
-
12
  References:
13
  - Pommerman multi-agent RL: arxiv:2407.00662
14
  - MAPPO best practices: arxiv:2103.01955
@@ -101,6 +95,12 @@ class BombermanSingleAgentEnv(gym.Env):
101
  shape=(self._obs_size,), dtype=np.float32,
102
  )
103
 
 
 
 
 
 
 
104
  def reset(self, seed=None, options=None):
105
  if seed is not None:
106
  self._episode_seed = seed
@@ -109,18 +109,18 @@ class BombermanSingleAgentEnv(gym.Env):
109
  self._episode_count += 1
110
 
111
  obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
112
- self._store_action_mask(obs_dict[self.agent_id])
113
  self._last_obs_dict = obs_dict
 
114
  return self._flatten_obs(obs_dict[self.agent_id]), {}
115
 
116
  def step(self, action):
117
  actions = {}
118
- for agent_id in self._parallel_env.agents:
119
  if agent_id == self.agent_id:
120
  actions[agent_id] = action
121
  else:
122
  mask = (
123
- self._last_obs_dict[agent_id]["action_mask"]
124
  if self._last_obs_dict and agent_id in self._last_obs_dict
125
  else np.ones(6, dtype=np.int8)
126
  )
@@ -374,6 +374,11 @@ class CurriculumEnv(gym.Env):
374
  shape=(self._obs_size,), dtype=np.float32,
375
  )
376
 
 
 
 
 
 
377
  def _init_opponents(self):
378
  for i in range(1, self.cfg.env.num_teams):
379
  opp_id = f"agent_{i}"
@@ -422,14 +427,13 @@ class CurriculumEnv(gym.Env):
422
  obs_dict, info_dict = self._parallel_env.reset(
423
  seed=self._episode_seed, options=options
424
  )
425
-
426
- self._store_action_mask(obs_dict[self.agent_id])
427
  self._last_obs_dict = obs_dict
 
428
  return self._flatten_obs(obs_dict[self.agent_id]), {}
429
 
430
  def step(self, action):
431
  actions = {}
432
- for agent_id in self._parallel_env.agents:
433
  if agent_id == self.agent_id:
434
  actions[agent_id] = action
435
  else:
 
3
  Full training pipeline: Phase 1 -> Phase 2 -> Phase 3
4
  TIL-26-AE Bomberman Agent Training
5
 
 
 
 
 
 
 
6
  References:
7
  - Pommerman multi-agent RL: arxiv:2407.00662
8
  - MAPPO best practices: arxiv:2103.01955
 
95
  shape=(self._obs_size,), dtype=np.float32,
96
  )
97
 
98
+ def _get_agents(self):
99
+ """Get list of currently active agents from obs_dict."""
100
+ if self._last_obs_dict is not None:
101
+ return list(self._last_obs_dict.keys())
102
+ return self._parallel_env.possible_agents
103
+
104
  def reset(self, seed=None, options=None):
105
  if seed is not None:
106
  self._episode_seed = seed
 
109
  self._episode_count += 1
110
 
111
  obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
 
112
  self._last_obs_dict = obs_dict
113
+ self._store_action_mask(obs_dict[self.agent_id])
114
  return self._flatten_obs(obs_dict[self.agent_id]), {}
115
 
116
  def step(self, action):
117
  actions = {}
118
+ for agent_id in self._get_agents():
119
  if agent_id == self.agent_id:
120
  actions[agent_id] = action
121
  else:
122
  mask = (
123
+ self._last_obs_dict[agent_id].get("action_mask")
124
  if self._last_obs_dict and agent_id in self._last_obs_dict
125
  else np.ones(6, dtype=np.int8)
126
  )
 
374
  shape=(self._obs_size,), dtype=np.float32,
375
  )
376
 
377
+ def _get_agents(self):
378
+ if self._last_obs_dict is not None:
379
+ return list(self._last_obs_dict.keys())
380
+ return self._parallel_env.possible_agents
381
+
382
  def _init_opponents(self):
383
  for i in range(1, self.cfg.env.num_teams):
384
  opp_id = f"agent_{i}"
 
427
  obs_dict, info_dict = self._parallel_env.reset(
428
  seed=self._episode_seed, options=options
429
  )
 
 
430
  self._last_obs_dict = obs_dict
431
+ self._store_action_mask(obs_dict[self.agent_id])
432
  return self._flatten_obs(obs_dict[self.agent_id]), {}
433
 
434
  def step(self, action):
435
  actions = {}
436
+ for agent_id in self._get_agents():
437
  if agent_id == self.agent_id:
438
  actions[agent_id] = action
439
  else: