E-Rong committed on
Commit
4105e6d
·
verified ·
1 Parent(s): ef8b845

Add inference-ready AE manager for loading trained MaskablePPO

Browse files
Files changed (1) hide show
  1. ae_manager.py +78 -0
ae_manager.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AE Manager - loads trained MaskablePPO and returns actions for Bomberman."""
2
+
3
+ import os
4
+ import sys
5
+ import numpy as np
6
+ from sb3_contrib import MaskablePPO
7
+
8
# Make the ``til_environment`` package importable (useful for default_config /
# observation shapes) by putting the first checkout that contains it on sys.path.
def _find_env_root(candidates):
    """Return the absolute path of the first candidate holding til_environment.

    A candidate qualifies when it is a directory that contains
    ``til_environment/bomberman_env.py``.

    Args:
        candidates: Iterable of directory paths, checked in order.

    Returns:
        The absolute path of the first qualifying candidate, or ``None`` when
        no candidate qualifies.
    """
    for candidate in candidates:
        marker = os.path.join(candidate, "til_environment", "bomberman_env.py")
        if os.path.isdir(candidate) and os.path.isfile(marker):
            # Absolute path: a relative sys.path entry would silently stop
            # resolving if the working directory changes after import.
            return os.path.abspath(candidate)
    return None


_env_root = _find_env_root([
    os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),
    "/app/til-26-ae-repo/til-26-ae",
    "til-26-ae",
])
if _env_root is not None:
    sys.path.insert(0, _env_root)
17
+
18
+
19
class AEManager:
    """Serve Bomberman actions from a trained MaskablePPO policy.

    On construction, the first checkpoint from a fixed candidate list that
    exists and loads successfully becomes ``self.model``. When none loads,
    ``ae`` falls back to sampling uniformly among the actions allowed by the
    observation's action mask.
    """

    # Length of the all-ones mask assumed when an observation has no
    # "action_mask" entry.
    _NUM_ACTIONS = 6
    # Action returned by the random fallback when the mask allows nothing.
    # NOTE(review): presumably a no-op/"stay" action -- confirm against the env.
    _FALLBACK_ACTION = 4
    # Observation fields, in the order they are concatenated for the policy.
    _OBS_KEYS = (
        "agent_viewcone",
        "base_viewcone",
        "direction",
        "location",
        "base_location",
        "health",
        "frozen_ticks",
        "base_health",
        "team_resources",
        "team_bombs",
        "step",
    )

    def __init__(self):
        """Load the first available checkpoint; leave ``model`` as None if none loads."""
        self.model = None
        self._obs_size = None  # never populated in this class
        here = os.path.dirname(__file__)
        # Checked in order; the MODEL_PATH environment variable takes priority.
        # NOTE(review): next to the package phase1 is tried before phase3,
        # while under /app/data phase3 comes first -- confirm this is intended.
        candidates = [
            os.environ.get("MODEL_PATH", ""),
            os.path.join(here, "..", "phase1_final.zip"),
            os.path.join(here, "..", "phase3_final.zip"),
            "/app/data/phase3_final.zip",
            "/app/data/phase2_final.zip",
            "/app/data/phase1_final.zip",
        ]
        for path in candidates:
            if not path or not os.path.isfile(path):
                continue
            try:
                self.model = MaskablePPO.load(path)
            except Exception as e:
                # Best effort: a broken checkpoint must not prevent trying the
                # remaining candidates, so report it and move on.
                print(f"[AE Manager] Failed to load {path}: {e}")
                continue
            print(f"[AE Manager] Loaded model from {path}")
            break
        if self.model is None:
            print("[AE Manager] No trained model found -- will return random valid actions.")

    @staticmethod
    def _flatten_obs(obs_dict):
        """Flatten an observation dict into a single 1-D float32 vector.

        Each field listed in ``_OBS_KEYS`` is converted to float32, flattened,
        and concatenated in that order. Scalars, lists and arrays of any rank
        are all accepted, so a field delivered as ``3``, ``[3]`` or
        ``np.array([3])`` contributes the same single element.

        Args:
            obs_dict: Mapping containing every key in ``_OBS_KEYS``.

        Returns:
            A 1-D ``np.float32`` array.

        Raises:
            KeyError: If a required field is missing.
        """
        return np.concatenate([
            np.asarray(obs_dict[key], dtype=np.float32).ravel()
            for key in AEManager._OBS_KEYS
        ])

    @classmethod
    def _action_mask(cls, observation):
        """Return the observation's action mask as a flat boolean array.

        A missing (or ``None``) "action_mask" entry means every one of the
        ``_NUM_ACTIONS`` actions is treated as valid.
        """
        raw = observation.get("action_mask")
        if raw is None:
            return np.ones(cls._NUM_ACTIONS, dtype=bool)
        # ravel() so a nested mask such as [[0, 1, ...]] still indexes actions.
        return np.asarray(raw, dtype=bool).ravel()

    def ae(self, observation: dict) -> int:
        """Choose an action for ``observation``.

        Args:
            observation: Observation dict. "action_mask" is optional; the
                fields in ``_OBS_KEYS`` are required only when a model is
                loaded.

        Returns:
            The chosen action index. With a loaded model this is the model's
            deterministic masked prediction. Without one it is a uniformly
            random valid action, or ``_FALLBACK_ACTION`` when the mask allows
            none.
        """
        action_mask = self._action_mask(observation)

        if self.model is None:
            valid = np.flatnonzero(action_mask)
            if valid.size == 0:
                return self._FALLBACK_ACTION
            return int(np.random.choice(valid))

        action, _ = self.model.predict(
            self._flatten_obs(observation),
            action_masks=action_mask,
            deterministic=True,
        )
        # .item() handles both 0-d and shape-(1,) results without relying on
        # the deprecated int() conversion of non-scalar arrays.
        return int(np.asarray(action).item())