til-26-ae-agent / ae_manager.py
E-Rong's picture
Add inference-ready AE manager for loading trained MaskablePPO
4105e6d verified
raw
history blame
3.18 kB
"""AE Manager - loads trained MaskablePPO and returns actions for Bomberman."""
import os
import sys
import numpy as np
from sb3_contrib import MaskablePPO
# Make the `til_environment` package importable (for default_config / obs
# shape if needed): probe the known checkout locations in order and put the
# first one that really contains the environment module at the front of
# sys.path. If none matches, sys.path is left untouched.
for p in (
    os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),
    "/app/til-26-ae-repo/til-26-ae",
    "til-26-ae",
):
    # Skip anything that is not an existing directory.
    if not os.path.isdir(p):
        continue
    # A directory only counts if it holds the Bomberman env module.
    if not os.path.isfile(os.path.join(p, "til_environment", "bomberman_env.py")):
        continue
    sys.path.insert(0, p)
    break
class AEManager:
    """Serve Bomberman actions from a trained MaskablePPO policy.

    On construction, the first checkpoint from a fixed candidate list that
    exists and loads successfully becomes ``self.model``. If none loads,
    :meth:`ae` samples a random action permitted by the observation's
    action mask instead.

    Attributes:
        model: The loaded ``MaskablePPO`` instance, or ``None`` when no
            checkpoint could be loaded.
    """

    # Length of the all-allowed mask assumed when an observation carries no
    # "action_mask" entry.
    NUM_ACTIONS = 6

    # Action returned when the mask permits nothing at all.
    # NOTE(review): presumably the env's no-op / "stay" action -- confirm
    # against the environment's action enum.
    FALLBACK_ACTION = 4

    # Observation keys in the exact order they are concatenated. Per the
    # flattening docstring this must match the layout used during training,
    # so do not reorder.
    _OBS_KEYS = (
        "agent_viewcone",
        "base_viewcone",
        "direction",
        "location",
        "base_location",
        "health",
        "frozen_ticks",
        "base_health",
        "team_resources",
        "team_bombs",
        "step",
    )

    def __init__(self):
        """Try each candidate checkpoint path and keep the first that loads."""
        self.model = None
        # Not read anywhere in this class; kept so external code that may
        # inspect the attribute keeps working.
        self._obs_size = None

        here = os.path.dirname(__file__)
        env_path = os.environ.get("MODEL_PATH", "")
        if env_path and not os.path.isfile(env_path):
            # An explicitly configured path that is silently skipped is hard
            # to debug, so say so before falling back to the defaults.
            print(
                f"[AE Manager] MODEL_PATH={env_path!r} is not a file -- "
                "trying default locations."
            )

        # NOTE(review): the repo-relative entries prefer phase1 over phase3,
        # while the /app/data entries prefer phase3 > phase2 > phase1.
        # Order is preserved as-is; confirm which is intended.
        candidates = [
            env_path,
            os.path.join(here, "..", "phase1_final.zip"),
            os.path.join(here, "..", "phase3_final.zip"),
            "/app/data/phase3_final.zip",
            "/app/data/phase2_final.zip",
            "/app/data/phase1_final.zip",
        ]
        for path in candidates:
            if not path or not os.path.isfile(path):
                continue
            try:
                self.model = MaskablePPO.load(path)
            except Exception as e:
                # Deliberate best-effort: a corrupt or incompatible
                # checkpoint must not prevent trying the next candidate.
                print(f"[AE Manager] Failed to load {path}: {e}")
                continue
            print(f"[AE Manager] Loaded model from {path}")
            break

        if self.model is None:
            print("[AE Manager] No trained model found -- will return random valid actions.")

    @staticmethod
    def _flatten_obs(obs_dict: dict) -> np.ndarray:
        """Flatten an observation dict into a 1-D float32 vector.

        Each field listed in ``_OBS_KEYS`` is converted to float32, raveled,
        and concatenated in that order. Scalars, 0-d arrays, 1-element
        arrays and nested lists are all accepted for any field.

        Args:
            obs_dict: Mapping that contains every key in ``_OBS_KEYS``.

        Returns:
            The concatenated 1-D ``float32`` array.

        Raises:
            KeyError: If a required key is missing from ``obs_dict``.
        """
        return np.concatenate([
            # ravel() (rather than wrapping in a list) keeps every piece 1-D
            # regardless of whether the env delivers a scalar or an array.
            np.asarray(obs_dict[key], dtype=np.float32).ravel()
            for key in AEManager._OBS_KEYS
        ])

    def ae(self, observation: dict) -> int:
        """Choose an action for one observation.

        Args:
            observation: Observation dict. ``"action_mask"`` is optional
                (all ``NUM_ACTIONS`` actions are allowed when absent); the
                remaining fields are only read when a model is loaded.

        Returns:
            The chosen action index. ``FALLBACK_ACTION`` is returned when
            the mask permits no action, whether or not a model is loaded.
        """
        mask = np.asarray(
            observation.get("action_mask", [1] * self.NUM_ACTIONS),
            dtype=bool,
        ).ravel()
        valid = np.flatnonzero(mask)

        # With nothing allowed there is no meaningful choice; never hand an
        # all-False mask to the policy.
        if valid.size == 0:
            return self.FALLBACK_ACTION

        if self.model is None:
            # No checkpoint loaded: uniform random over the allowed actions.
            return int(np.random.choice(valid))

        action, _ = self.model.predict(
            self._flatten_obs(observation),
            action_masks=mask,
            deterministic=True,
        )
        # .item() accepts both 0-d and 1-element arrays; int() on the latter
        # is deprecated in recent NumPy.
        return int(np.asarray(action).item())