Builder-Neekhil committed on
Commit e26c99f · verified · 1 Parent(s): 6ab4d9f

Upload train_efficient.py

Files changed (1)
  1. train_efficient.py +630 -0
train_efficient.py ADDED
@@ -0,0 +1,630 @@
+ #!/usr/bin/env python3
+ """
+ Orbit Wars — Efficient PPO Self-Play Training for Adaptive Parameter Controller.
+
+ Optimized version: loads the agent module ONCE, then modifies its globals
+ in place each episode.
+ """
+
+ import copy
+ import math
+ import os
+ import random
+ import sys
+ import time
+ from collections import defaultdict
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.distributions import Normal
+
+ # ============================================================
+ # Import the base agent as a module-level namespace
+ # ============================================================
+ sys.path.insert(0, '/app')
+ _SUBMISSION_SRC = Path('/app/submission.py').read_text()
+ _BASE_NS = {}
+ exec(_SUBMISSION_SRC, _BASE_NS)
+ print("Base agent loaded successfully.")
+
+ # Also create a separate namespace for the opponent
+ _OPP_NS = {}
+ exec(_SUBMISSION_SRC, _OPP_NS)
+
+ from kaggle_environments import make as _make_env
+
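+ # Each exec() above builds an independent dict of module globals, so the
+ # learner and the opponent can run the same agent code with different
+ # parameter overrides in the same process; apply_params()/reset_params()
+ # below simply mutate these dicts.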
+ # ============================================================
+ # Feature Extraction
+ # ============================================================
+
+ FEATURE_DIM = 33
+
+ def extract_features(obs):
+     get = obs.get if isinstance(obs, dict) else lambda k, d=None: getattr(obs, k, d)
+     player = int(get("player", 0) or 0)
+     step = int(get("step", 0) or 0)
+     planets = get("planets") or []
+     fleets = get("fleets") or []
+     ang_vel = float(get("angular_velocity", 0.0) or 0.0)
+     comet_ids = set(get("comet_planet_ids") or [])
+
+     my_p = my_s = my_pr = en_p = en_s = en_pr = ne_p = ne_s = 0
+     my_st = my_ro = en_st = 0
+     en_by = defaultdict(int)
+
+     for p in planets:
+         _, owner, x, y, radius, ships, prod = p
+         is_st = (math.hypot(x - 50, y - 50) + radius) >= 50.0
+         if owner == player:
+             my_p += 1; my_s += ships; my_pr += prod
+             my_st += is_st; my_ro += (not is_st)
+         elif owner == -1:
+             ne_p += 1; ne_s += ships
+         else:
+             en_p += 1; en_s += ships; en_pr += prod; en_by[owner] += ships
+             en_st += is_st
+
+     my_fs = sum(f[6] for f in fleets if f[1] == player)
+     en_fs = sum(f[6] for f in fleets if f[1] != player)
+     my_fc = sum(1 for f in fleets if f[1] == player)
+     en_fc = sum(1 for f in fleets if f[1] != player)
+     mt = my_s + my_fs; et = en_s + en_fs; ta = mt + et + ne_s
+     ne = len(en_by)
+     mx_e = max(en_by.values()) if en_by else 0
+     mn_e = min(en_by.values()) if en_by else 0
+     nc = sum(1 for p in planets if p[0] in comet_ids)
+
+     return np.array([
+         step/500, min(1, step/100), max(0, (500-step)/500), float(step > 400),
+         min(1, my_p/15), min(1, en_p/15), min(1, ne_p/15), min(1, my_st/10), min(1, my_ro/10),
+         min(1, mt/max(1, ta)), min(1, et/max(1, ta)),
+         math.log1p(mt)/10, math.log1p(et)/10, math.log1p(my_fs)/10, math.log1p(en_fs)/10,
+         min(1, my_pr/max(1, my_pr+en_pr)), my_pr/30, en_pr/30,
+         np.clip((mt-et)/max(1, ta), -1, 1), np.clip((my_p-en_p)/15, -1, 1), np.clip((my_pr-en_pr)/15, -1, 1),
+         min(1, ne/3), float(ne >= 3), min(1, mx_e/max(1, et)), min(1, mn_e/max(1, mx_e+1)), min(1, en_fc/20),
+         min(1, my_fc/20), my_fs/max(1, mt), en_fs/max(1, et),
+         abs(ang_vel)*100, min(1, nc/5), min(1, len(planets)/30), ne_s/max(1, ta),
+     ], dtype=np.float32)
+
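+ # Illustrative shape check: extract_features({"player": 0}) is a valid call
+ # (missing keys fall back to empty defaults) and returns a (33,) float32
+ # vector matching FEATURE_DIM; most entries are squashed into [0, 1] or
+ # [-1, 1] via min/clip/log1p, so the controller sees roughly unit-scale inputs.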
+
+ class OpponentProfiler:
+     """Exponentially-weighted profile of the opponent (judging by the update
+     rules: `agg` tracks new fleet launches, `exp` tracks planet-count growth,
+     `trt` tracks how little of the enemy's strength is kept in flight)."""
+
+     def __init__(self):
+         self.a = 0.1; self.agg = 0.5; self.exp = 0.5; self.trt = 0.5
+         self.pp = 0; self.pf = 0; self.ps = 0; self.sc = 0  # previous counts + update counter
+
+     def update(self, obs):
+         get = obs.get if isinstance(obs, dict) else lambda k, d=None: getattr(obs, k, d)
+         player = int(get("player", 0) or 0)
+         planets = get("planets") or []; fleets = get("fleets") or []
+         ep = sum(1 for p in planets if p[1] not in (-1, player))
+         ef = sum(1 for f in fleets if f[1] != player)
+         es = sum(p[5] for p in planets if p[1] not in (-1, player))
+         es += sum(f[6] for f in fleets if f[1] != player)
+         if self.sc > 0:
+             fd = max(0, ef - self.pf)
+             self.agg = (1-self.a)*self.agg + self.a*min(1, fd/5)
+             pd = ep - self.pp
+             self.exp = (1-self.a)*self.exp + self.a*np.clip(pd/3+0.5, 0, 1)
+             efs = sum(f[6] for f in fleets if f[1] != player)
+             t = 1 - min(1, efs/max(1, es)) if es > 0 else 0.5
+             self.trt = (1-self.a)*self.trt + self.a*t
+         self.pp = ep; self.pf = ef; self.ps = es; self.sc += 1
+         return np.array([self.agg, self.exp, self.trt, min(1, self.sc/100), float(self.sc > 50)], dtype=np.float32)
+
+
+ # ============================================================
+ # Parameter Controller
+ # ============================================================
+
+ # Each entry maps a tunable constant in submission.py to (default, low, high);
+ # reset_params() restores the default, decode_params() maps controller output
+ # into [low, high].
+ TUNABLE_PARAMS = {
+     "HOSTILE_TARGET_VALUE_MULT": (2.05, 1.0, 3.0),
+     "ELIMINATION_BONUS": (55.0, 10.0, 100.0),
+     "PROACTIVE_DEFENSE_RATIO": (0.28, 0.05, 0.5),
+     "FINISHING_HOSTILE_VALUE_MULT": (1.3, 0.8, 2.0),
+     "WEAK_ENEMY_THRESHOLD": (110.0, 30.0, 200.0),
+     "ATTACK_COST_TURN_WEIGHT": (0.50, 0.2, 0.8),
+     "HOSTILE_MARGIN_BASE": (3.0, 1.0, 6.0),
+     "FOUR_PLAYER_TARGET_MARGIN": (2.0, 0.0, 5.0),
+     "FINISHING_HOSTILE_SEND_BONUS": (5.0, 1.0, 10.0),
+     "STATIC_HOSTILE_VALUE_MULT": (1.65, 1.0, 2.5),
+     "GANG_UP_VALUE_MULT": (1.4, 1.0, 2.0),
+     "EXPOSED_PLANET_VALUE_MULT": (2.0, 1.0, 3.0),
+     "REINFORCE_VALUE_MULT": (1.35, 0.8, 2.0),
+     "DEFENSE_SHIP_VALUE": (0.55, 0.2, 1.0),
+     "BEHIND_DOMINATION": (-0.20, -0.5, 0.0),
+     "AHEAD_DOMINATION": (0.15, 0.0, 0.4),
+     "LATE_REMAINING_TURNS": (70.0, 40.0, 100.0),
+     "REAR_SEND_RATIO_TWO_PLAYER": (0.62, 0.3, 0.9),
+     "COMET_VALUE_MULT": (0.65, 0.3, 1.2),
+     "SNIPE_VALUE_MULT": (1.12, 0.7, 1.6),
+ }
+ PARAM_NAMES = list(TUNABLE_PARAMS.keys())
+ NUM_PARAMS = len(PARAM_NAMES)
+ INPUT_DIM = FEATURE_DIM + 5  # features + profile
+
+
+ class ParameterController(nn.Module):
+     def __init__(self, input_dim=INPUT_DIM, hidden_size=128):
+         super().__init__()
+         self.shared = nn.Sequential(
+             nn.Linear(input_dim, hidden_size), nn.ReLU(),
+             nn.Linear(hidden_size, hidden_size), nn.ReLU(),
+         )
+         self.param_mean = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(),
+             nn.Linear(hidden_size // 2, NUM_PARAMS),
+         )
+         self.param_log_std = nn.Parameter(torch.zeros(NUM_PARAMS))
+         self.value_head = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(),
+             nn.Linear(hidden_size // 2, 1),
+         )
+
+     def forward(self, x):
+         h = self.shared(x)
+         return torch.tanh(self.param_mean(h)), self.param_log_std, self.value_head(h).squeeze(-1)
+
+
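+ # Shape sketch for a batch of B observations:
+ #   x: (B, INPUT_DIM=38)  ->  means: (B, NUM_PARAMS=20) in (-1, 1) via tanh,
+ #   log_std: (20,) shared across states, values: (B,).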
+ def decode_params(raw):
+     params = {}
+     for i, name in enumerate(PARAM_NAMES):
+         _, low, high = TUNABLE_PARAMS[name]
+         t = (float(raw[i]) + 1.0) / 2.0
+         params[name] = low + t * (high - low)
+     return params
+
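+ # Worked example: for HOSTILE_TARGET_VALUE_MULT with bounds (1.0, 3.0),
+ # raw = -1.0 -> t = 0.0 -> 1.0; raw = 0.0 -> t = 0.5 -> 2.0; raw = +1.0 -> 3.0.
+ # The tanh mean can never quite reach the bounds, but sampled actions are
+ # clipped to [-1, 1] before decoding, so the endpoints remain attainable.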
+
+ def apply_params(ns, params):
+     """Apply parameter overrides to agent namespace (in-place, very fast)."""
+     for name, value in params.items():
+         if name in ns:
+             ns[name] = value
+
+
+ def reset_params(ns):
+     """Reset parameters to defaults."""
+     for name, (default, _, _) in TUNABLE_PARAMS.items():
+         if name in ns:
+             ns[name] = default
+
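+ # Typical per-episode cycle (a sketch; `overrides` stands in for the decoded
+ # controller output):
+ #     reset_params(_BASE_NS)             # back to hand-tuned defaults
+ #     apply_params(_BASE_NS, overrides)  # overlay this episode's values
+ # Mutating the exec'd globals is what lets every episode reuse the namespaces
+ # loaded once at import time instead of re-exec'ing submission.py.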
+
+ # ============================================================
+ # Potential-based reward shaping
+ # ============================================================
+
+ def compute_potential(obs, player):
+     get = obs.get if isinstance(obs, dict) else lambda k, d=None: getattr(obs, k, d)
+     planets = get("planets") or []; fleets = get("fleets") or []
+     my_p = my_s = my_pr = en_p = en_s = en_pr = 0
+     for p in planets:
+         _, owner, _, _, _, ships, prod = p
+         if owner == player: my_p += 1; my_s += ships; my_pr += prod
+         elif owner >= 0: en_p += 1; en_s += ships; en_pr += prod
+     for f in fleets:
+         _, owner, _, _, _, _, ships = f
+         if owner == player: my_s += ships
+         elif owner >= 0: en_s += ships
+     eps = 1e-6; lr = math.log(10.0)
+     pp = np.clip(math.log((my_p+eps)/(en_p+eps))/lr, -1, 1)
+     ps = np.clip(math.log((my_s+eps)/(en_s+eps))/lr, -1, 1)
+     pprod = np.clip(math.log((my_pr+eps)/(en_pr+eps))/lr, -1, 1)
+     return 0.4*pp + 0.3*ps + 0.3*pprod
+
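+ # The per-step shaping term used below is the potential-based form
+ # F(s, s') = gamma*Phi(s') - Phi(s) with gamma = 0.99 (Ng, Harada & Russell,
+ # 1999): summed over an episode the terms telescope, so shaping densifies the
+ # learning signal without changing which policies are optimal.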
+
+ # ============================================================
+ # Efficient training loop
+ # ============================================================
+
+ def run_episode_vs_random(learner_ns, seed, learner_slot=0):
+     """Run episode against Kaggle's built-in random agent (very fast)."""
+     from kaggle_environments.envs.orbit_wars.orbit_wars import random_agent
+
+     env = _make_env("orbit_wars", configuration={"seed": seed}, debug=False)
+     env.reset(num_agents=2)
+     learner_ns['_agent_step'] = 0
+
+     profiler = OpponentProfiler()
+     states = env.step([[], []])
+     learner_obs = states[learner_slot].observation
+
+     features = extract_features(learner_obs)
+     profile = profiler.update(learner_obs)
+     initial_obs_vec = np.concatenate([features, profile])
+
+     player = int(learner_obs.get("player", 0) if isinstance(learner_obs, dict) else learner_obs.player)
+     prev_potential = compute_potential(learner_obs, player)
+     total_shaped_reward = 0.0
+     step_count = 0
+     done = False
+
+     while not done:
+         try:
+             learner_moves = learner_ns['agent'](learner_obs)
+         except Exception:
+             learner_moves = []
+
+         opp_obs = states[1 - learner_slot].observation
+         try:
+             opponent_moves = random_agent(opp_obs)
+         except Exception:
+             opponent_moves = []
+
+         if learner_slot == 0:
+             actions = [learner_moves, opponent_moves]
+         else:
+             actions = [opponent_moves, learner_moves]
+
+         states = env.step(actions)
+         learner_state = states[learner_slot]
+         learner_obs = learner_state.observation
+         done = learner_state.status != "ACTIVE"
+
+         curr_potential = compute_potential(learner_obs, player)
+         step_reward = 0.99 * curr_potential - prev_potential
+         prev_potential = curr_potential
+
+         if done:
+             raw_reward = float(learner_state.reward) if learner_state.reward else 0.0
+             step_reward += raw_reward
+
+         total_shaped_reward += step_reward
+         step_count += 1
+         profile = profiler.update(learner_obs)
+
+     final_features = extract_features(learner_obs)
+     final_obs_vec = np.concatenate([final_features, profile])
+     final_reward = float(learner_state.reward) if learner_state.reward else 0.0
+
+     return initial_obs_vec, final_obs_vec, total_shaped_reward, final_reward, step_count
+
+
+ def run_episode(learner_ns, opponent_ns, seed, learner_slot=0):
+     """Run a full game episode.
+
+     Returns (initial_obs_vec, final_obs_vec, total_shaped_reward, final_reward, step_count).
+     The controller makes ONE decision per episode (a parameter setting for the
+     whole game); this is much more efficient than per-step parameter tuning.
+     """
+     env = _make_env("orbit_wars", configuration={"seed": seed}, debug=False)
+     env.reset(num_agents=2)
+
+     # Reset step counters in both agents
+     learner_ns['_agent_step'] = 0
+     opponent_ns['_agent_step'] = 0
+
+     profiler = OpponentProfiler()
+
+     # Collect initial observation
+     states = env.step([[], []])
+     learner_obs = states[learner_slot].observation
+     opp_obs = states[1 - learner_slot].observation
+
+     # Extract initial features for controller decision
+     features = extract_features(learner_obs)
+     profile = profiler.update(learner_obs)
+     initial_obs_vec = np.concatenate([features, profile])
+
+     prev_potential = compute_potential(
+         learner_obs,
+         int(learner_obs.get("player", 0) if isinstance(learner_obs, dict) else learner_obs.player))
+
+     total_shaped_reward = 0.0
+     step_count = 0
+     done = False
+
+     # Run the full game
+     while not done:
+         # Get moves from both agents
+         try:
+             learner_moves = learner_ns['agent'](learner_obs)
+         except Exception:
+             learner_moves = []
+
+         try:
+             opponent_moves = opponent_ns['agent'](opp_obs)
+         except Exception:
+             opponent_moves = []
+
+         if learner_slot == 0:
+             actions = [learner_moves, opponent_moves]
+         else:
+             actions = [opponent_moves, learner_moves]
+
+         states = env.step(actions)
+         learner_state = states[learner_slot]
+         opp_state = states[1 - learner_slot]
+
+         learner_obs = learner_state.observation
+         opp_obs = opp_state.observation
+         done = learner_state.status != "ACTIVE"
+
+         # Shaped reward
+         player = int(learner_obs.get("player", 0) if isinstance(learner_obs, dict) else learner_obs.player)
+         curr_potential = compute_potential(learner_obs, player)
+         step_reward = 0.99 * curr_potential - prev_potential
+         prev_potential = curr_potential
+
+         if done:
+             raw_reward = float(learner_state.reward) if learner_state.reward else 0.0
+             step_reward += raw_reward
+
+         total_shaped_reward += step_reward
+         step_count += 1
+
+         # Update opponent profile
+         profile = profiler.update(learner_obs)
+
+     # Final features for the last state
+     final_features = extract_features(learner_obs)
+     final_obs_vec = np.concatenate([final_features, profile])
+
+     final_reward = float(learner_state.reward) if learner_state.reward else 0.0
+
+     return initial_obs_vec, final_obs_vec, total_shaped_reward, final_reward, step_count
+
+
+ def train():
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Device: {device}")
+
+     # Config
+     total_updates = int(os.environ.get("TOTAL_UPDATES", "500"))
+     episodes_per_update = int(os.environ.get("EPISODES_PER_UPDATE", "4"))
+     eval_every = int(os.environ.get("EVAL_EVERY", "25"))
+     eval_games = int(os.environ.get("EVAL_GAMES", "6"))
+     lr = float(os.environ.get("LR", "3e-4"))
+     gamma = 0.99  # discount; the episode runners hard-code the same 0.99 in their shaping term
+     clip_coef = 0.2
+     ent_coef = 0.01
+     vf_coef = 0.5
+     epochs = 4
+     pool_size = 3
+     save_dir = Path(os.environ.get("SAVE_DIR", "/app/checkpoints"))
+     save_dir.mkdir(parents=True, exist_ok=True)
+     random.seed(42); np.random.seed(42); torch.manual_seed(42)
+
+     controller = ParameterController().to(device)
+     optimizer = torch.optim.Adam(controller.parameters(), lr=lr)
+
+     # Opponent pool: list of parameter snapshots (dicts of param values)
+     opponent_pool = [None]  # None = baseline (no overrides)
+     best_win_rate = 0.0
+     seed_counter = 0
+
+     # Opponent curriculum: random first, then baseline, then self-play
+     def get_opponent_ns(update_idx):
+         """Return opponent namespace and label based on training phase."""
+         phase_fraction = update_idx / total_updates
+
+         if phase_fraction < 0.2:
+             # Phase 1: Train vs random (very fast ~20s/episode)
+             return None, "random"
+         elif phase_fraction < 0.5:
+             # Phase 2: Train vs baseline (medium ~60s/episode)
+             reset_params(_OPP_NS)
+             return _OPP_NS, "baseline"
+         else:
+             # Phase 3: Train vs pool (self-play)
+             opp_params = random.choice(opponent_pool)
+             reset_params(_OPP_NS)
+             if opp_params is not None:
+                 apply_params(_OPP_NS, opp_params)
+             return _OPP_NS, "pool"
+
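+     # With the default TOTAL_UPDATES=500 this schedule works out to:
+     # updates 0-99 vs random, 100-249 vs baseline, 250-499 vs the pool.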
+     print(f"\nTraining: {total_updates} updates × {episodes_per_update} episodes")
+     print("Phase 1 (0-20%): vs random | Phase 2 (20-50%): vs baseline | Phase 3 (50-100%): self-play")
+     print(f"Eval every {eval_every} updates, {eval_games} games\n")
+
+     for update in range(total_updates):
+         t0 = time.time()
+
+         # Collect episodes
+         obs_batch = []
+         reward_batch = []
+         wins = 0
+         total_steps = 0
+
+         for ep in range(episodes_per_update):
+             seed_counter += 1
+             learner_slot = (update * episodes_per_update + ep) % 2
+
+             # Pick opponent based on curriculum
+             opp_ns, opp_label = get_opponent_ns(update)
+
+             # Get controller output for this episode
+             with torch.inference_mode():
+                 # Use a dummy observation to get initial params
+                 # (we'll use the same params for the whole episode)
+                 dummy_obs = np.zeros(INPUT_DIM, dtype=np.float32)
+                 dummy_obs[0] = 0.0  # start of game
+                 x = torch.from_numpy(dummy_obs).unsqueeze(0).to(device)
+                 param_mean, log_std, value = controller(x)
+
+                 std = torch.exp(log_std)
+                 dist = Normal(param_mean.squeeze(0), std)
+                 action = dist.sample()
+                 log_prob = dist.log_prob(action).sum().item()
+                 value_np = value.item()
+                 action_np = action.cpu().numpy()
+
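+             # The policy is a diagonal Gaussian over the 20 raw parameters, so
+             # the joint log-probability is the sum of per-dimension log-probs;
+             # the state-independent log_std keeps the exploration noise learnable.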
+             # Apply learned params to learner
+             params = decode_params(np.clip(action_np, -1, 1))
+             reset_params(_BASE_NS)
+             apply_params(_BASE_NS, params)
+
+             # Run episode
+             if opp_ns is None:
+                 # Use random agent (fast)
+                 init_obs, final_obs, shaped_reward, raw_reward, steps = run_episode_vs_random(
+                     _BASE_NS, seed=seed_counter * 37 + 1, learner_slot=learner_slot
+                 )
+             else:
+                 init_obs, final_obs, shaped_reward, raw_reward, steps = run_episode(
+                     _BASE_NS, opp_ns, seed=seed_counter * 37 + 1, learner_slot=learner_slot
+                 )
+
+             obs_batch.append((init_obs, action_np, log_prob, value_np, shaped_reward))
+             reward_batch.append(raw_reward)
+             if raw_reward > 0:
+                 wins += 1
+             total_steps += steps
+
+         # PPO update
+         if obs_batch:
+             obs_t = torch.tensor(np.stack([o[0] for o in obs_batch]), dtype=torch.float32, device=device)
+             actions_t = torch.tensor(np.stack([o[1] for o in obs_batch]), dtype=torch.float32, device=device)
+             old_log_probs_t = torch.tensor([o[2] for o in obs_batch], dtype=torch.float32, device=device)
+             old_values_t = torch.tensor([o[3] for o in obs_batch], dtype=torch.float32, device=device)
+             rewards_t = torch.tensor([o[4] for o in obs_batch], dtype=torch.float32, device=device)
+
+             # Returns = rewards (single step per "episode" from controller's perspective)
+             returns_t = rewards_t
+             advantages_t = returns_t - old_values_t
+             if advantages_t.std() > 1e-6:
+                 advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)
+
+             metrics = {"loss": 0, "pl": 0, "vl": 0, "ent": 0}
+             n_updates = 0
+
+             for _ in range(epochs):
+                 param_mean, log_std, values = controller(obs_t)
+                 std = torch.exp(log_std)
+                 dist = Normal(param_mean, std)
+                 new_log_probs = dist.log_prob(actions_t).sum(-1)
+                 entropy = dist.entropy().sum(-1)
+
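+                 # PPO clipped surrogate: with ratio r = exp(new_lp - old_lp),
+                 # the policy loss is max(-A*r, -A*clip(r, 1-eps, 1+eps)) for
+                 # eps = clip_coef, which bounds how far one batch can move the
+                 # policy away from the one that collected these episodes.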
+                 ratio = (new_log_probs - old_log_probs_t).exp()
+                 s1 = -advantages_t * ratio
+                 s2 = -advantages_t * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
+                 pl = torch.max(s1, s2).mean()
+                 vl = 0.5 * (returns_t - values).pow(2).mean()
+                 el = -entropy.mean()
+
+                 loss = pl + vf_coef * vl + ent_coef * el
+                 optimizer.zero_grad()
+                 loss.backward()
+                 nn.utils.clip_grad_norm_(controller.parameters(), 0.5)
+                 optimizer.step()
+
+                 metrics["loss"] += loss.item()
+                 metrics["pl"] += pl.item()
+                 metrics["vl"] += vl.item()
+                 metrics["ent"] += entropy.mean().item()
+                 n_updates += 1
+
+             metrics = {k: v / max(1, n_updates) for k, v in metrics.items()}
+
+         elapsed = time.time() - t0
+         win_rate = wins / episodes_per_update
+         avg_reward = np.mean(reward_batch) if reward_batch else 0
+
+         print(f"U{update+1:4d}/{total_updates} | "
+               f"WR: {win_rate:.0%} | R: {avg_reward:+.2f} | "
+               f"L: {metrics.get('loss',0):.4f} PL: {metrics.get('pl',0):.4f} "
+               f"VL: {metrics.get('vl',0):.4f} Ent: {metrics.get('ent',0):.3f} | "
+               f"Steps: {total_steps} | {elapsed:.1f}s | vs: {opp_label}")
+
+         # Evaluation and pool management
+         if (update + 1) % eval_every == 0:
+             print(f"\n Evaluating vs baseline ({eval_games} games)...")
+             eval_wins = 0
+
+             # Get current best params from controller
+             with torch.inference_mode():
+                 x = torch.zeros(1, INPUT_DIM, device=device)
+                 pm, _, _ = controller(x)
+                 eval_params = decode_params(pm.squeeze(0).cpu().numpy())
+
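+             # Evaluation is deterministic: it decodes the tanh mean at the
+             # zero observation instead of sampling, and this same parameter
+             # vector is what gets snapshotted into the pool and checkpoints.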
+             for g in range(eval_games):
+                 slot = g % 2
+                 reset_params(_BASE_NS); apply_params(_BASE_NS, eval_params)
+                 reset_params(_OPP_NS)  # opponent = baseline
+
+                 _, _, _, raw_r, _ = run_episode(_BASE_NS, _OPP_NS, seed=10000 + g, learner_slot=slot)
+                 if raw_r > 0:
+                     eval_wins += 1
+                 print(f" Game {g+1}: {'WIN' if raw_r > 0 else 'LOSS'} (slot={slot})")
+
+             wr = eval_wins / eval_games
+             print(f" Win rate: {wr:.0%} ({eval_wins}/{eval_games})")
+
+             # Add to pool if good
+             if wr >= 0.45:
+                 if len(opponent_pool) >= pool_size:
+                     opponent_pool.pop(0)
+                 opponent_pool.append(copy.deepcopy(eval_params))
+                 print(f" ✓ Added to pool (size={len(opponent_pool)})")
+
+             if wr > best_win_rate:
+                 best_win_rate = wr
+                 torch.save({
+                     "controller": controller.state_dict(),
+                     "params": eval_params,
+                     "win_rate": wr,
+                     "update": update + 1,
+                 }, save_dir / "best_controller.pt")
+                 print(f" ★ New best: {wr:.0%}")
+             print()
+
+         # Checkpoint
+         if (update + 1) % 100 == 0:
+             torch.save({
+                 "controller": controller.state_dict(),
+                 "optimizer": optimizer.state_dict(),
+                 "update": update + 1,
+             }, save_dir / f"ckpt_{update+1:05d}.pt")
+
+     # Final save
+     torch.save({
+         "controller": controller.state_dict(),
+         "best_win_rate": best_win_rate,
+     }, save_dir / "final_controller.pt")
+
+     print(f"\nDone! Best win rate: {best_win_rate:.0%}")
+     print(f"Checkpoints: {save_dir}")
+
+     # Push to hub
+     try:
+         from huggingface_hub import HfApi
+         api = HfApi(token=os.environ.get("HF_TOKEN"))
+
+         # Upload best checkpoint
+         best_path = save_dir / "best_controller.pt"
+         if best_path.exists():
+             api.upload_file(
+                 path_or_fileobj=str(best_path),
+                 path_in_repo="best_controller.pt",
+                 repo_id="Builder-Neekhil/orbit-wars-agent",
+                 commit_message=f"Upload trained controller (WR: {best_win_rate:.0%})",
+             )
+             print("Uploaded best_controller.pt to HF Hub")
+
+         # Generate and upload adaptive submission
+         final_path = save_dir / "final_controller.pt"
+         if not best_path.exists():
+             best_path = final_path
+         if best_path.exists():
+             sys.path.insert(0, '/app')
+             from generate_submission import generate_submission
+             generate_submission(
+                 base_agent_path="/app/submission.py",
+                 checkpoint_path=str(best_path),
+                 output_path="/app/submission_adaptive.py",
+             )
+             api.upload_file(
+                 path_or_fileobj="/app/submission_adaptive.py",
+                 path_in_repo="submission_adaptive.py",
+                 repo_id="Builder-Neekhil/orbit-wars-agent",
+                 commit_message=f"Upload adaptive submission (WR: {best_win_rate:.0%})",
+             )
+             print("Uploaded submission_adaptive.py to HF Hub")
+     except Exception as e:
+         print(f"Hub upload error: {e}")
+
+
+ if __name__ == "__main__":
+     train()