K446 committed on
Commit
7bcd08c
·
1 Parent(s): 78131a0

Add GRPO training runner for HF Spaces GPU

Browse files
Files changed (3) hide show
  1. Dockerfile.training +24 -0
  2. requirements-training.txt +20 -0
  3. run_training.py +384 -0
Dockerfile.training ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# OpenGrid GRPO Training Space — Runs on A10G GPU
# After training completes, serves results on port 7860

FROM python:3.10-slim

LABEL org.opencontainers.image.title="OpenGrid GRPO Training"
LABEL org.opencontainers.image.description="GRPO training for power grid multi-agent controller"

# Run as an unprivileged user (uid 1000); pip then installs into
# /home/user/.local, which is why that bin directory is put on PATH.
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# run_training.py reports progress only through print(). Without this, Python
# buffers stdout when it is not attached to a TTY, so the container log can
# stay empty for long stretches of the training run.
ENV PYTHONUNBUFFERED=1

# NOTE(review): the slim base image ships without a C compiler. If any of the
# GPU dependencies JIT-compile kernels at runtime, a compiler may need to be
# installed here (as root, before `USER user`) — confirm on the target Space.

WORKDIR /app

# Install training dependencies first so this layer is cached across code edits
COPY --chown=user requirements-training.txt .
RUN pip install --no-cache-dir --upgrade -r requirements-training.txt

# Copy application code
COPY --chown=user . /app

# Training entrypoint: runs GRPO then serves results
EXPOSE 7860
CMD ["python", "run_training.py"]
requirements-training.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core env
2
+ fastapi
3
+ uvicorn[standard]
4
+ pydantic>=2.0
5
+ numpy
6
+ networkx
7
+ matplotlib
8
+ openai
9
+ httpx
10
+ openenv-core>=0.2.0
11
+
12
+ # Training
13
+ torch
14
+ transformers
15
+ trl>=0.17.0
16
+ peft
17
+ accelerate
18
+ bitsandbytes
19
+ datasets
20
+ unsloth
run_training.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenGrid GRPO Training Runner for HF Spaces.
2
+
3
+ Runs env-grounded GRPO training, saves model + plots,
4
+ then starts a FastAPI server to serve/download results.
5
+ """
6
+ import os
7
+ import sys
8
+ import json
9
+ import copy
10
+ import time
11
+ import shutil
12
+ import traceback
13
+ from pathlib import Path
14
+
15
+ # ── Training ──────────────────────────────────────────────────────
16
def run_grpo_training():
    """Run the GRPO pipeline end to end and return a summary dict.

    Stages, in order:
      1. Load the policy model (Unsloth 4-bit + LoRA; plain ``transformers``
         if an ``ImportError`` is raised while loading through Unsloth).
      2. Score a regex-based heuristic policy on the evaluation tasks.
      3. Roll the environment forward with random actions to collect
         chat-formatted prompts and their JSON-encoded observations.
      4. Train with ``trl.GRPOTrainer`` and save model + tokenizer.
      5. Score the trained model on the same evaluation tasks.
      6. Write ``before_after.png``, ``training_loss.png`` and
         ``summary.json`` under ``training/outputs``.

    Returns:
        dict: the content written to ``training/outputs/summary.json``
        (model name, training task, wall-clock minutes, prompt count, epoch
        count, and rounded baseline/trained reward statistics per task).

    Raises:
        Whatever the model loader, environment, or trainer raise; the caller
        is expected to handle failures.
    """
    import re

    import torch
    import numpy as np

    NUM_EPOCHS = 3  # single source of truth for the trainer config and the summary
    EVAL_TASK_IDS = ["task_easy", "task_medium", "task_karnataka"]
    EVAL_EPISODES = 3

    print("=" * 60)
    print(" OpenGrid GRPO Training")
    print("=" * 60)

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties field is `total_memory` (bytes); `total_mem`
        # does not exist and raised AttributeError on every GPU machine.
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("WARNING: No GPU detected — training will be very slow!")

    # Import project modules (resolved relative to the working directory)
    sys.path.insert(0, ".")
    from src.environment import OpenGridEnv
    from src.tasks import TASKS
    from src.models import GridAction, BusAdjustment
    from training.train_grpo import (
        SYSTEM_PROMPT, format_observation_prompt,
        compute_grpo_reward_env, extract_action,
        rollout_multi_agent,
    )

    # ── 1. Load model ──
    print("\n[1/6] Loading model with Unsloth...")
    try:
        from unsloth import FastLanguageModel
        MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=MODEL_NAME, max_seq_length=2048, load_in_4bit=True,
        )
        model = FastLanguageModel.get_peft_model(
            model, r=16, lora_alpha=16, lora_dropout=0,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
        )
        print(f" Model: {MODEL_NAME}")
        print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        loaded_with_unsloth = True
    except ImportError:
        print("WARNING: Unsloth not available, using standard loading")
        from transformers import AutoTokenizer, AutoModelForCausalLM
        MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        loaded_with_unsloth = False

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def evaluate(generate_fn, tag, log_episodes=False):
        """Roll out `generate_fn` on every available eval task; return per-task stats."""
        results = {}
        for task_id in EVAL_TASK_IDS:
            if task_id not in TASKS:
                continue
            config = TASKS[task_id]
            rewards = []
            for ep in range(EVAL_EPISODES):
                ep_config = copy.deepcopy(config)
                ep_config['seed'] = 42 + ep  # same seeds before and after training
                env = OpenGridEnv(ep_config)
                result = rollout_multi_agent(env, generate_fn, ep_config)
                # Built-in floats keep the stats JSON-serialisable even if the
                # rollout hands back a NumPy scalar.
                rewards.append(float(result['total_reward']))
                if log_episodes:
                    print(f" {task_id} ep{ep}: reward={result['total_reward']:.2f}")
            avg, std = float(np.mean(rewards)), float(np.std(rewards))
            results[task_id] = {"avg": avg, "std": std, "rewards": rewards}
            print(f" [{tag}] {task_id}: {avg:.2f} ± {std:.2f}")
        return results

    # ── 2. Baseline evaluation ──
    print("\n[2/6] Running baseline evaluation...")

    def heuristic_generate(prompt):
        """Proportional controller: nudge the first controllable bus toward 50 Hz."""
        freq_match = re.search(r'Frequency: ([\d.]+)', prompt)
        freq = float(freq_match.group(1)) if freq_match else 50.0
        error = 50.0 - freq
        delta = max(-20, min(20, error * 10))
        bus_match = re.search(r'Bus (\d+) \((generator|battery|slack)\)', prompt)
        if bus_match:
            return json.dumps({"bus_adjustments": [{"bus_id": int(bus_match.group(1)), "delta": round(delta, 1)}], "topology_actions": []})
        return json.dumps({"bus_adjustments": [], "topology_actions": []})

    baseline_results = evaluate(heuristic_generate, "BASELINE")

    # ── 3. Generate training prompts ──
    print("\n[3/6] Generating training prompts...")
    TRAIN_TASK = "task_karnataka" if "task_karnataka" in TASKS else "task_easy"
    task_config = copy.deepcopy(TASKS[TRAIN_TASK])
    base_seed = task_config.get('seed', 42)

    prompts = []
    obs_contexts = []
    rng = np.random.RandomState(base_seed)

    # Bus id -> type, built once; the first entry wins on duplicate ids, which
    # matches a linear first-match scan over task_config['buses'].
    bus_type_by_id = {}
    for bus in task_config['buses']:
        bus_type_by_id.setdefault(bus['id'], bus.get('type'))

    for episode in range(30):
        ep_config = copy.deepcopy(task_config)
        ep_config['seed'] = base_seed + episode
        env = OpenGridEnv(ep_config)
        zone_obs = env.reset_multi()

        # Adversarial: drain batteries every 5th episode
        if episode % 5 == 0:
            for b in env.bus_state:
                b_cfg = env._find_bus_config(b['id'])
                if b_cfg and b_cfg['type'] == 'battery':
                    b['soc'] = max(1.0, b['soc'] * 0.1)

        for t in range(min(15, task_config['max_steps'])):
            # One training prompt per agent observation at this step.
            for agent_id, obs in zone_obs.items():
                obs_dict = json.loads(obs.model_dump_json())
                prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                ]
                formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                prompts.append(formatted)
                obs_contexts.append(json.dumps(obs_dict))

            # Advance the environment with random actions so later prompts
            # come from varied states.
            random_actions = {}
            for aid in range(env.num_agents):
                zone_buses = task_config['zone_bus_ids'].get(aid, [])
                controllable = [
                    bid for bid in zone_buses
                    if bus_type_by_id.get(bid) in ['generator', 'battery']
                ]
                adj = []
                if controllable:
                    # randint's upper bound is exclusive: adjust 1 or 2 buses.
                    n_adj = min(len(controllable), rng.randint(1, 3))
                    chosen = rng.choice(controllable, size=n_adj, replace=False)
                    for bid in chosen:
                        adj.append(BusAdjustment(bus_id=int(bid), delta=float(rng.uniform(-30, 30))))
                random_actions[aid] = GridAction(bus_adjustments=adj)

            result = env.step_multi(random_actions)
            if result.done:
                break
            zone_obs = result.observations

    print(f" Generated {len(prompts)} training prompts")

    # ── 4. Train ──
    print("\n[4/6] Starting GRPO training...")
    from trl import GRPOTrainer, GRPOConfig
    from datasets import Dataset

    # Prefer bf16 where the GPU supports it, fp16 on other GPUs, neither on CPU.
    _bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    _fp16 = torch.cuda.is_available() and not _bf16

    def reward_fn(completions, obs_context=None, **kwargs):
        """Normalise completions and observation contexts, then score them."""
        texts = []
        for c in completions:
            if isinstance(c, list):
                # Conversational format: take the content of the last message.
                text = c[-1]['content'] if c else ""
            else:
                text = str(c)
            texts.append(text)

        if obs_context is None:
            obs_context = [None] * len(texts)

        obs_dicts = []
        for ctx in obs_context:
            if isinstance(ctx, str):
                try:
                    obs_dicts.append(json.loads(ctx))
                except (json.JSONDecodeError, TypeError):
                    obs_dicts.append(None)
            else:
                obs_dicts.append(ctx)

        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=3)

    # NOTE(review): some TRL releases require the generation batch size to be
    # divisible by num_generations — confirm 2 x 8 accumulation steps vs 8
    # generations against the installed TRL version.
    grpo_config = GRPOConfig(
        output_dir="training/outputs/grpo_checkpoints",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-5,
        logging_steps=5,
        save_steps=50,
        max_completion_length=256,
        num_generations=8,
        report_to="none",
        remove_unused_columns=False,  # keep `obs_context` so reward_fn receives it
        bf16=_bf16,
        fp16=_fp16,
    )

    train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
    print(f" Dataset: {len(train_dataset)} rows")
    print(f" Effective batch: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}")

    trainer = GRPOTrainer(
        model=model, args=grpo_config, train_dataset=train_dataset,
        reward_funcs=reward_fn, processing_class=tokenizer,
    )

    t0 = time.time()
    trainer.train()
    train_time = time.time() - t0
    print(f"\n Training complete in {train_time/60:.1f} minutes")

    # Save model
    output_path = "training/outputs/trained_model"
    trainer.save_model(output_path)
    tokenizer.save_pretrained(output_path)
    print(f" Model saved to {output_path}")

    # ── 5. Post-training evaluation ──
    print("\n[5/6] Evaluating trained model...")
    if loaded_with_unsloth:
        # Best effort: evaluation still runs if the switch fails, but the
        # failure is reported instead of being silently discarded.
        try:
            FastLanguageModel.for_inference(model)
        except Exception as exc:
            print(f"WARNING: FastLanguageModel.for_inference failed: {exc}")

    def trained_generate(prompt):
        """Sample one completion from the trained model for `prompt`."""
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.3, do_sample=True)
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    trained_results = evaluate(trained_generate, "TRAINED", log_episodes=True)

    # ── 6. Generate plots ──
    print("\n[6/6] Generating plots...")
    import matplotlib
    matplotlib.use('Agg')  # headless backend: no display in the container
    import matplotlib.pyplot as plt

    os.makedirs("training/outputs", exist_ok=True)

    # Before vs After
    common_tasks = [t for t in baseline_results if t in trained_results]
    if common_tasks:
        fig, ax = plt.subplots(figsize=(10, 6))
        x = np.arange(len(common_tasks))
        width = 0.35
        before = [baseline_results[t]['avg'] for t in common_tasks]
        after = [trained_results[t]['avg'] for t in common_tasks]
        ax.bar(x - width/2, before, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.8)
        ax.bar(x + width/2, after, width, label='GRPO Trained', color='#00d4aa', alpha=0.8)
        ax.set_xlabel('Task')
        ax.set_ylabel('Average Episode Reward')
        ax.set_title('OpenGrid — GRPO Training: Before vs After', fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        # Value labels: above positive bars, below negative ones.
        for bars in ax.containers:
            for bar in bars:
                h = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., h + (1 if h >= 0 else -3),
                        f'{h:.1f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=10)
        plt.tight_layout()
        plt.savefig('training/outputs/before_after.png', dpi=150)
        plt.close()

    # Training loss
    history = trainer.state.log_history
    steps = [h['step'] for h in history if 'loss' in h]
    losses = [h['loss'] for h in history if 'loss' in h]
    if steps:
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(steps, losses, color='#ff6b6b', linewidth=1.5, alpha=0.6, label='Loss')
        if len(losses) > 10:
            # Moving average; 'valid' mode drops the first w-1 points, so the
            # x axis is trimmed to match.
            w = min(20, len(losses) // 3)
            smoothed = np.convolve(losses, np.ones(w)/w, mode='valid')
            ax.plot(steps[w-1:], smoothed, color='#ff6b6b', linewidth=2.5, label=f'Smoothed (w={w})')
        ax.set_xlabel('Step')
        ax.set_ylabel('Loss')
        ax.set_title('OpenGrid GRPO — Training Loss', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('training/outputs/training_loss.png', dpi=150)
        plt.close()

    # Save summary
    summary = {
        "model": MODEL_NAME,
        "train_task": TRAIN_TASK,
        "train_time_minutes": round(train_time / 60, 1),
        "num_prompts": len(prompts),
        "num_epochs": NUM_EPOCHS,
        "baseline": {k: {"avg": round(v["avg"], 2), "std": round(v["std"], 2)} for k, v in baseline_results.items()},
        "trained": {k: {"avg": round(v["avg"], 2), "std": round(v["std"], 2)} for k, v in trained_results.items()},
    }
    with open("training/outputs/summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print("\n" + "=" * 60)
    print(" TRAINING COMPLETE")
    print("=" * 60)
    print(f" Time: {train_time/60:.1f} minutes")
    print(f" {'Task':<20} {'Baseline':>10} {'Trained':>10} {'Δ':>8}")
    print(f" {'-'*50}")
    for t in common_tasks:
        b, a = baseline_results[t]['avg'], trained_results[t]['avg']
        arrow = '↑' if a > b else '↓'
        print(f" {t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(a-b):.2f}")
    print("=" * 60)

    return summary
331
+
332
+
333
+ # ── Results Server ────────────────────────────────────────────────
334
def serve_results():
    """Start a blocking FastAPI server on 0.0.0.0:7860 exposing training artefacts.

    Routes:
        GET /                    -> parsed ``summary.json`` if it exists,
                                    otherwise a status placeholder.
        GET /plots/before_after  -> ``before_after.png`` or a 404 JSON error.
        GET /plots/loss          -> ``training_loss.png`` or a 404 JSON error.
        GET /health              -> ``{"status": "ok"}``.

    Files are looked up under ``training/outputs`` at request time, so
    anything written after start-up is picked up without a restart.
    """
    from fastapi import FastAPI
    from fastapi.responses import FileResponse, JSONResponse
    import uvicorn

    outputs_dir = Path("training/outputs")
    app = FastAPI(title="OpenGrid Training Results")

    def png_or_not_ready(filename):
        """Return the named PNG from the outputs directory, or a 404 payload."""
        image = outputs_dir / filename
        if not image.exists():
            return JSONResponse({"error": "not ready"}, status_code=404)
        return FileResponse(str(image), media_type="image/png")

    @app.get("/")
    def root():
        summary_file = outputs_dir / "summary.json"
        if not summary_file.exists():
            return {"status": "Training in progress or no results yet"}
        with open(summary_file) as fh:
            return json.load(fh)

    @app.get("/plots/before_after")
    def before_after():
        return png_or_not_ready("before_after.png")

    @app.get("/plots/loss")
    def loss():
        return png_or_not_ready("training_loss.png")

    @app.get("/health")
    def health():
        return {"status": "ok"}

    uvicorn.run(app, host="0.0.0.0", port=7860)
369
+
370
+
371
+ # ── Main ──────────────────────────────────────────────────────────
372
if __name__ == "__main__":
    # Train first; whatever happens, fall through to the results server so
    # the outcome (or the failure) can be inspected over HTTP.
    try:
        summary = run_grpo_training()
    except Exception as exc:
        print(f"\nERROR during training: {exc}")
        traceback.print_exc()
        # Persist the error where the server's "/" route reads the summary from.
        failure_dir = Path("training/outputs")
        failure_dir.mkdir(parents=True, exist_ok=True)
        (failure_dir / "summary.json").write_text(json.dumps({"error": str(exc)}))

    print("\nStarting results server on port 7860...")
    serve_results()