E-Rong commited on
Commit
1659dd8
·
verified ·
1 Parent(s): 4b6177e

Upload smoke_test_v2.py

Browse files
Files changed (1) hide show
  1. smoke_test_v2.py +143 -0
smoke_test_v2.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Smoke test: download TIL repo via snapshot_download, verify imports, run 100 steps, push dummy checkpoint."""
3
+ import os, sys, subprocess
4
+
5
+ print("="*60)
6
+ print("SMOKE TEST: HF Job private repo access + training basics")
7
+ print("="*60)
8
+
9
+ # 1. Test snapshot_download of private Space
10
+ print("\n[1/5] Downloading TIL repo via snapshot_download...")
11
+ from huggingface_hub import snapshot_download
12
+ snapshot_download(
13
+ repo_id="e-rong/til-26-ae",
14
+ repo_type="space",
15
+ local_dir="/app/til-26-ae-repo",
16
+ )
17
+ print(" ✓ Downloaded")
18
+ print(" Listing repo root:")
19
+ for root, dirs, files in os.walk("/app/til-26-ae-repo"):
20
+ level = root.replace("/app/til-26-ae-repo", "").count(os.sep)
21
+ indent = " " * 2 * level
22
+ print(f"{indent}{os.path.basename(root)}/")
23
+ subindent = " " * 2 * (level + 1)
24
+ for f in files[:5]:
25
+ print(f"{subindent}{f}")
26
+ if len(files) > 5:
27
+ print(f"{subindent}... ({len(files)-5} more files)")
28
+
29
+ # 2. Install TIL environment
30
+ print("\n[2/5] Installing TIL environment...")
31
+ # Find the actual package root (contains pyproject.toml)
32
+ PKG_ROOT = None
33
+ for root, dirs, files in os.walk("/app/til-26-ae-repo"):
34
+ if "pyproject.toml" in files:
35
+ PKG_ROOT = root
36
+ break
37
+ if PKG_ROOT is None:
38
+ raise RuntimeError("Could not find pyproject.toml in downloaded repo")
39
+ print(f" Package root found: {PKG_ROOT}")
40
+ subprocess.run(["pip", "install", "-e", "."], cwd=PKG_ROOT, check=True)
41
+ print(" ✓ Installed")
42
+
43
+ # 3. Verify imports
44
+ print("\n[3/5] Verifying imports...")
45
+ sys.path.insert(0, PKG_ROOT)
46
+ from til_environment.bomberman_env import Bomberman
47
+ from til_environment.config import default_config
48
+ from pettingzoo.utils.conversions import aec_to_parallel
49
+ print(" ✓ Imports OK")
50
+
51
+ # 4. Run 100 steps of dummy training
52
+ print("\n[4/5] Running 100 training steps...")
53
+ from sb3_contrib import MaskablePPO
54
+ from sb3_contrib.common.wrappers import ActionMasker
55
+ from stable_baselines3.common.monitor import Monitor
56
+ import gymnasium
57
+ from gymnasium.spaces import Box, Discrete
58
+ import numpy as np
59
+
60
+ class QuickEnv(gymnasium.Env):
61
+ def __init__(self):
62
+ super().__init__()
63
+ cfg = default_config()
64
+ cfg.env.render_mode = None
65
+ raw = Bomberman(cfg)
66
+ self._parallel_env = aec_to_parallel(raw)
67
+ self.agent_id = "agent_0"
68
+ self._episode_count = 0
69
+ self.action_space = Discrete(6)
70
+ vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
71
+ vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
72
+ av = vl * vw * 25
73
+ br = int(cfg.entities.base.vision_radius)
74
+ bs = 2 * br + 1
75
+ bv = bs * bs * 25
76
+ self._obs_size = av + bv + 11
77
+ self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
78
+ self._last_action_mask = None
79
+ self._last_obs_dict = None
80
+ def reset(self, seed=None, options=None):
81
+ self._episode_count += 1
82
+ obs_dict, _ = self._parallel_env.reset(seed=self._episode_count, options=options)
83
+ self._last_obs_dict = obs_dict
84
+ self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
85
+ return self._flatten(obs_dict[self.agent_id]), {}
86
+ def step(self, action):
87
+ actions = {self.agent_id: action}
88
+ for aid, obs in self._last_obs_dict.items():
89
+ if aid != self.agent_id:
90
+ valid = np.where(obs["action_mask"] == 1)[0]
91
+ actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
92
+ obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
93
+ self._last_obs_dict = obs_dict
94
+ if self.agent_id not in obs_dict:
95
+ return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
96
+ self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
97
+ obs = self._flatten(obs_dict[self.agent_id])
98
+ r = float(rewards.get(self.agent_id, 0.0))
99
+ done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
100
+ return obs, r, done, False, infos.get(self.agent_id, {})
101
+ def action_masks(self):
102
+ return self._last_action_mask
103
+ def _flatten(self, od):
104
+ return np.concatenate([
105
+ od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
106
+ np.array([od["direction"]], dtype=np.float32),
107
+ od["location"].flatten().astype(np.float32),
108
+ od["base_location"].flatten().astype(np.float32),
109
+ od["health"].flatten().astype(np.float32),
110
+ np.array([od["frozen_ticks"]], dtype=np.float32),
111
+ od["base_health"].flatten().astype(np.float32),
112
+ od["team_resources"].flatten().astype(np.float32),
113
+ np.array([od["team_bombs"]], dtype=np.float32),
114
+ np.array([od["step"]], dtype=np.float32),
115
+ ], dtype=np.float32)
116
+
117
+ env = ActionMasker(QuickEnv(), lambda e: e.action_masks())
118
+ env = Monitor(env)
119
+
120
+ model = MaskablePPO(
121
+ "MlpPolicy", env,
122
+ learning_rate=3e-4, n_steps=128, batch_size=32, n_epochs=2,
123
+ gamma=0.99, clip_range=0.2, ent_coef=0.01,
124
+ verbose=0, device="cuda",
125
+ )
126
+ model.learn(total_timesteps=100, progress_bar=False)
127
+ print(" ✓ 100 steps completed")
128
+
129
+ # 5. Push dummy checkpoint to Hub
130
+ print("\n[5/5] Pushing dummy checkpoint to Hub...")
131
+ from huggingface_hub import HfApi
132
+ model.save("/app/smoke_test_ckpt.zip")
133
+ HfApi().upload_file(
134
+ path_or_fileobj="/app/smoke_test_ckpt.zip",
135
+ path_in_repo="smoke_test_ckpt.zip",
136
+ repo_id="E-Rong/til-26-ae-agent",
137
+ repo_type="model",
138
+ )
139
+ print(" ✓ Pushed to Hub")
140
+
141
+ print("\n" + "="*60)
142
+ print("SMOKE TEST PASSED — Ready for full training job")
143
+ print("="*60)