"""AE Manager - loads trained MaskablePPO and returns actions for Bomberman.""" import os import sys import numpy as np from sb3_contrib import MaskablePPO # Try to find til_environment (for default_config/obs shape if needed) for p in [ os.path.join(os.path.dirname(__file__), "..", "til-26-ae"), "/app/til-26-ae-repo/til-26-ae", "til-26-ae", ]: if os.path.isdir(p) and os.path.isfile(os.path.join(p, "til_environment", "bomberman_env.py")): sys.path.insert(0, p) break class AEManager: """Loads a trained MaskablePPO model and serves inference for Bomberman.""" def __init__(self): self.model = None self._obs_size = None # Try loading from several locations candidates = [ os.environ.get("MODEL_PATH", ""), os.path.join(os.path.dirname(__file__), "..", "phase1_final.zip"), os.path.join(os.path.dirname(__file__), "..", "phase3_final.zip"), "/app/data/phase3_final.zip", "/app/data/phase2_final.zip", "/app/data/phase1_final.zip", ] for path in candidates: if path and os.path.isfile(path): try: self.model = MaskablePPO.load(path) print(f"[AE Manager] Loaded model from {path}") break except Exception as e: print(f"[AE Manager] Failed to load {path}: {e}") if self.model is None: print("[AE Manager] No trained model found -- will return random valid actions.") @staticmethod def _flatten_obs(obs_dict): """Flatten observation dict into the vector used during training.""" return np.concatenate([ np.array(obs_dict["agent_viewcone"]).flatten(), np.array(obs_dict["base_viewcone"]).flatten(), np.array([obs_dict["direction"]], dtype=np.float32), np.array(obs_dict["location"]).flatten().astype(np.float32), np.array(obs_dict["base_location"]).flatten().astype(np.float32), np.array(obs_dict["health"]).flatten().astype(np.float32), np.array([obs_dict["frozen_ticks"]], dtype=np.float32), np.array(obs_dict["base_health"]).flatten().astype(np.float32), np.array(obs_dict["team_resources"]).flatten().astype(np.float32), np.array([obs_dict["team_bombs"]], dtype=np.float32), np.array([obs_dict["step"]], dtype=np.float32), ], dtype=np.float32) def ae(self, observation: dict) -> int: """Get action from observation dict.""" if self.model is None: # Fallback: random valid action mask = np.array(observation.get("action_mask", [1]*6), dtype=bool) valid = np.where(mask)[0] return int(np.random.choice(valid)) if len(valid) > 0 else 4 obs_vec = self._flatten_obs(observation) action_mask = np.array(observation.get("action_mask", [1]*6), dtype=bool) action, _ = self.model.predict( obs_vec, action_masks=action_mask, deterministic=True, ) return int(action)