til-26-ae-agent / train_in_space.py
#!/usr/bin/env python3
"""
TIL-26-AE Bomberman Agent Training - Runs inside the Space
Uses the local til_environment package (already in this repo) and pushes checkpoints to the Hub model repo.
"""
import os
import sys
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
import torch
# In the Space, til-26-ae is at the repo root; in the sandbox it's elsewhere.
# Try multiple candidate paths and use the first that exists.
for path in [
    "/home/user/app/til-26-ae",  # HF Space typical path
    "/app/til-26-ae",  # sandbox path
    os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),  # relative to this file
    "til-26-ae",  # current dir
]:
    if os.path.isdir(path):
        sys.path.insert(0, path)
        print(f"Using til_environment from: {path}")
        break
else:
    # No candidate matched: fail fast with a clear message instead of an
    # opaque ModuleNotFoundError on the imports below.
    raise FileNotFoundError("Could not locate the til-26-ae directory on any known path.")

from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from huggingface_hub import HfApi
# ---------------------------------------------------------------------------
# Environment wrappers
# ---------------------------------------------------------------------------
class BombermanSingleAgentEnv(gym.Env):
    """Single-agent Gymnasium view of the multi-agent Bomberman env.

    The learner controls agent_0; every other agent takes a uniformly
    random valid action each step.
    """

    def __init__(self, cfg=None, seed=None, opponent_policy="random"):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None
        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self._episode_seed = seed
        self._episode_count = 0
        self.action_space = Discrete(6)
        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None
        self._compute_obs_space()

    def _compute_obs_space(self):
        # Flattened observation: agent viewcone + base viewcone (25 channels
        # per tile) plus 11 scalar features.
        cfg = self.cfg
        viewcone_l = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        viewcone_w = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        agent_viewcone_size = viewcone_l * viewcone_w * 25
        base_r = int(cfg.entities.base.vision_radius)
        base_side = 2 * base_r + 1
        base_viewcone_size = base_side * base_side * 25
        scalar_size = 11
        self._obs_size = agent_viewcone_size + base_viewcone_size + scalar_size
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        self._episode_seed = self._episode_count if seed is None else seed
        self._episode_count += 1
        obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_seed, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten_obs(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {self.agent_id: action}
        # Opponents pick a random valid action (action 0 if nothing is valid).
        for aid, obs in self._last_obs_dict.items():
            if aid != self.agent_id:
                valid = np.where(obs["action_mask"] == 1)[0]
                actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        if self.agent_id not in obs_dict:
            # Our agent was removed from the env (e.g. eliminated): end the episode.
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        obs = self._flatten_obs(obs_dict[self.agent_id])
        r = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
        return obs, r, done, False, infos.get(self.agent_id, {})

    def action_masks(self):
        # Consumed by sb3_contrib's ActionMasker / MaskablePPO.
        return self._last_action_mask

    def _flatten_obs(self, od):
        return np.concatenate([
            od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
            np.array([od["direction"]], dtype=np.float32),
            od["location"].flatten().astype(np.float32),
            od["base_location"].flatten().astype(np.float32),
            od["health"].flatten().astype(np.float32),
            np.array([od["frozen_ticks"]], dtype=np.float32),
            od["base_health"].flatten().astype(np.float32),
            od["team_resources"].flatten().astype(np.float32),
            np.array([od["team_bombs"]], dtype=np.float32),
            np.array([od["step"]], dtype=np.float32),
        ], dtype=np.float32)

    def close(self):
        self._parallel_env.close()

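# Hedged smoke test for the wrapper above (run manually, not during training;
# assumes til_environment imports cleanly and the default config is valid):
#
#   env = BombermanSingleAgentEnv()
#   obs, info = env.reset(seed=0)
#   first_valid = int(np.flatnonzero(env.action_masks())[0])
#   obs, r, done, trunc, info = env.step(first_valid)
#   env.close()
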
class RewardShapingWrapper(gym.Wrapper):
    """Adds a count-based exploration bonus, annealed by enemy kill rate."""

    def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
        super().__init__(env)
        self.adaptive_k = adaptive_k
        self.base_explore_weight = base_explore_weight
        self._visit_counts = None
        self._grid_size = 16
        self._avg_enemy_deaths = 0.0  # NOTE: never updated below, so the anneal is a no-op until it is wired up
        self._episode_count = 0
        self._explore_weight = base_explore_weight

    def reset(self, **kwargs):
        self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        # Count-based bonus: 1/(1 + visits) for the current cell, applied only
        # if the env reports a "location" in `info` (otherwise no bonus).
        pos = info.get("location", None)
        bonus = 0.0
        if pos is not None:
            x, y = int(pos[0]), int(pos[1])
            if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
                bonus = 1.0 / (1.0 + self._visit_counts[x, y])
                self._visit_counts[x, y] += 1
        if done:
            self._episode_count += 1
            # Anneal exploration as the agent starts killing enemies more often.
            alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
            self._explore_weight = self.base_explore_weight * max(0.1, alpha)
        return obs, reward + self._explore_weight * bonus, done, truncated, info

    def action_masks(self):
        return self.env.action_masks()

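# Worked example of the bonus above: the first visit to a cell yields
# 1/(1+0) = 1.0, the second 1/(1+1) = 0.5, the third 1/3, and so on, each
# scaled by _explore_weight (0.5 by default, and unchanged until
# _avg_enemy_deaths is actually updated somewhere).
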
class RuleBasedOpponent:
    """Scripted opponent with selectable difficulty."""

    def __init__(self, difficulty="simple"):
        self.difficulty = difficulty

    def act(self, od):
        valid = np.where(od["action_mask"] == 1)[0]
        if len(valid) == 0:
            return 4
        if self.difficulty == "static":
            return 4
        if self.difficulty == "simple":
            vc = od["agent_viewcone"]
            # If viewcone channels 10 or 12 are active anywhere, take action 5
            # when allowed; otherwise pick a random movement action (< 4).
            if (np.any(vc[..., 10] > 0) or np.any(vc[..., 12] > 0)) and 5 in valid:
                return 5
            mv = [a for a in valid if a < 4]
            return int(np.random.choice(mv)) if mv else 4
        # "smart" (and any other difficulty) is not implemented yet and
        # currently falls back to action 4.
        return 4

class CurriculumEnv(gym.Env):
    """Single-agent env whose scripted opponents get harder as the agent wins.

    Stages advance once the stage's win rate reaches WIN_RATE over at least
    EPS_PER_STAGE finished episodes.
    """

    STAGES = ["static", "simple", "smart", "mixed"]
    WIN_RATE = 0.55
    EPS_PER_STAGE = 500

    def __init__(self, cfg=None, seed=None):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None
        self._parallel_env = aec_to_parallel(Bomberman(self.cfg))
        self.agent_id = "agent_0"
        self._episode_count = 0
        self.action_space = Discrete(6)
        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None
        self._compute_obs_space()
        self.stage_idx = 0
        self.stage_eps = 0
        self.stage_wins = 0
        self.stage_rewards = []
        self.opponents = {}
        self._init_opponents()

    def _compute_obs_space(self):
        # Same flattened layout as BombermanSingleAgentEnv: two viewcones
        # (25 channels per tile) plus 11 scalar features.
        cfg = self.cfg
        vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        av = vl * vw * 25
        br = int(cfg.entities.base.vision_radius)
        bs = 2 * br + 1
        bv = bs * bs * 25
        self._obs_size = av + bv + 11
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)

    def _init_opponents(self):
        for i in range(1, self.cfg.env.num_teams):
            self.opponents[f"agent_{i}"] = RuleBasedOpponent(difficulty="static")

    def _update_difficulty(self):
        stage = self.STAGES[self.stage_idx]
        for oid, opp in self.opponents.items():
            # In the "mixed" stage, even-numbered opponents become "smart";
            # the rest keep the stage name (RuleBasedOpponent treats unknown
            # names like "mixed" as static).
            opp.difficulty = "smart" if (stage == "mixed" and int(oid.split("_")[1]) % 2 == 0) else stage

    def _check_advance(self):
        if self.stage_idx >= len(self.STAGES) - 1:
            return False
        if len(self.stage_rewards) >= self.EPS_PER_STAGE:
            wr = self.stage_wins / max(1, len(self.stage_rewards))
            if wr >= self.WIN_RATE:
                print(f"Stage {self.STAGES[self.stage_idx]} done (wr={wr:.1%}). Advancing.")
                self.stage_idx += 1
                self.stage_eps = self.stage_wins = 0
                self.stage_rewards = []
                self._update_difficulty()
                return True
        return False

    def reset(self, seed=None, options=None):
        self._episode_count += 1
        obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {self.agent_id: action}
        for aid, obs in self._last_obs_dict.items():
            if aid != self.agent_id:
                opp = self.opponents.get(aid)
                actions[aid] = opp.act(obs) if opp else 4
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        if self.agent_id not in obs_dict:
            self.stage_eps += 1
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        obs = self._flatten(obs_dict[self.agent_id])
        r = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
        if done:
            self.stage_eps += 1
            self.stage_rewards.append(r)
            if r > 10.0:  # heuristic: a large terminal reward counts as a win
                self.stage_wins += 1
            self._check_advance()
        return obs, r, done, False, {"stage": self.stage_idx, "stage_name": self.STAGES[self.stage_idx]}

    def action_masks(self):
        return self._last_action_mask

    def _flatten(self, od):
        return np.concatenate([
            od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
            np.array([od["direction"]], dtype=np.float32),
            od["location"].flatten().astype(np.float32),
            od["base_location"].flatten().astype(np.float32),
            od["health"].flatten().astype(np.float32),
            np.array([od["frozen_ticks"]], dtype=np.float32),
            od["base_health"].flatten().astype(np.float32),
            od["team_resources"].flatten().astype(np.float32),
            np.array([od["team_bombs"]], dtype=np.float32),
            np.array([od["step"]], dtype=np.float32),
        ], dtype=np.float32)

    def close(self):
        self._parallel_env.close()

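# Advancement arithmetic for the curriculum above: a "win" is an episode
# whose final reward exceeds 10.0; once a stage has logged at least
# EPS_PER_STAGE=500 finished episodes with a cumulative win rate of
# WIN_RATE=0.55 or better (i.e. at least 275 wins out of 500), the env
# moves to the next stage.
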
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
HUB_REPO = os.environ.get("HUB_MODEL_ID", "E-Rong/til-26-ae-agent")

def hub_push(path_in_local, path_in_repo, repo_id=HUB_REPO):
    """Push a file to the Hub model repo."""
    try:
        api = HfApi()
        api.upload_file(path_or_fileobj=path_in_local, path_in_repo=path_in_repo,
                        repo_id=repo_id, repo_type="model")
        print(f" -> pushed {path_in_repo}")
    except Exception as e:
        print(f" -> push failed: {e}")

class HubCheckpointCallback(BaseCallback):
    """Pushes .zip checkpoints to the Hub every N steps."""

    def __init__(self, save_freq=50000, repo_id=HUB_REPO, verbose=0):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.repo_id = repo_id

    def _on_step(self) -> bool:
        if self.num_timesteps % self.save_freq == 0:
            path = f"/tmp/checkpoint_{self.num_timesteps}.zip"
            self.model.save(path)
            hub_push(path, f"checkpoint_{self.num_timesteps}.zip", self.repo_id)
        return True

def train_phase(phase, total_timesteps, model=None):
    cfg = default_config()
    cfg.env.render_mode = None
    if phase == 1:
        print("=== Phase 1: vs Random ===")
        base = BombermanSingleAgentEnv(cfg=cfg)
        env = ActionMasker(Monitor(base), lambda e: e.action_masks())
    elif phase == 2:
        print("=== Phase 2: + Exploration Shaping ===")
        base = BombermanSingleAgentEnv(cfg=cfg)
        base = RewardShapingWrapper(base)
        env = ActionMasker(Monitor(base), lambda e: e.action_masks())
    elif phase == 3:
        print("=== Phase 3: Curriculum Self-Play ===")
        cfg.env.num_teams = 3
        base = CurriculumEnv(cfg=cfg)
        env = ActionMasker(Monitor(base), lambda e: e.action_masks())
    else:
        raise ValueError(f"Unknown phase: {phase}")
    if model is None:
        print("Creating MaskablePPO...")
        model = MaskablePPO(
            "MlpPolicy", env,
            learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10,
            gamma=0.99, gae_lambda=0.95, clip_range=0.2,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
            verbose=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        # Reuse the weights from the previous phase; only the env changes.
        model.set_env(env)
    ckpt_cb = CheckpointCallback(save_freq=50000, save_path="./ckpts", name_prefix=f"p{phase}")
    hub_cb = HubCheckpointCallback(save_freq=50000, repo_id=HUB_REPO)
    model.learn(total_timesteps=total_timesteps, callback=[ckpt_cb, hub_cb], progress_bar=False)
    final = f"phase{phase}_final.zip"
    model.save(final)
    hub_push(final, final, HUB_REPO)
    env.close()
    print(f"Phase {phase} complete.")
    return model

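# Hedged sketch for resuming a phase from a pushed checkpoint instead of
# carrying over the in-memory model (the checkpoint filename is an
# assumption; use whichever step count was actually pushed):
#
#   from huggingface_hub import hf_hub_download
#   ckpt = hf_hub_download(HUB_REPO, "checkpoint_500000.zip")
#   model = train_phase(3, 1_000_000, MaskablePPO.load(ckpt))
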
def main():
    # Colon-separated per-phase timesteps, e.g. "500000:500000:1000000";
    # underscores are allowed as digit separators.
    ts = os.environ.get("TOTAL_TIMESTEPS", "500000:500000:1000000")
    phase_ts = [int(x.replace("_", "")) for x in ts.split(":")]
    print(f"Phase timesteps: {phase_ts}")
    model = None
    for i, t in enumerate(phase_ts[:3], 1):
        model = train_phase(i, t, model)
    print("\n=== All phases complete ===")

if __name__ == "__main__":
    main()
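# Usage inside the Space (HF_TOKEN must grant write access to the model repo;
# both env vars below show the script's defaults):
#
#   HUB_MODEL_ID="E-Rong/til-26-ae-agent" \
#   TOTAL_TIMESTEPS="500000:500000:1000000" python train_in_space.py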