"""OpenGrid GRPO Training Runner for HF Spaces.
Runs env-grounded GRPO training, saves model + plots,
then starts a FastAPI server to serve/download results.
"""
import os
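# Must be set before torch is first imported: the CUDA caching allocator reads
# this setting at init, and expandable_segments reduces fragmentation OOMs.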
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import sys
import json
import copy
import time
import shutil
import traceback
from pathlib import Path
# --- TRITON COMPILER FIX ---
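# Triton JIT-compiles GPU kernels at runtime and shells out to a host C
# compiler, so a missing gcc breaks the first compiled training step.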
import subprocess
try:
    print("Checking for gcc...")
    gcc_path = shutil.which('gcc')  # portable; avoids shelling out to `which`
    print(f"gcc location: {gcc_path or 'NOT FOUND'}")
    if gcc_path:
        os.environ['CC'] = gcc_path
        gxx_path = shutil.which('g++')
        if gxx_path:  # only set CXX when g++ exists; an empty CXX breaks builds
            os.environ['CXX'] = gxx_path
        result = subprocess.run(['gcc', '--version'], capture_output=True, text=True)
        print(f"gcc version:\n{result.stdout.strip()[:100]}")
    else:
        print("WARNING: gcc not found in PATH! Triton kernel compilation will fail.")
except Exception as e:
    print(f"Error checking gcc: {e}")
# ----------------------------
# ── Training ──────────────────────────────────────────────────────
def run_grpo_training():
"""Run GRPO training with env-grounded rewards."""
import torch
import numpy as np
print("=" * 60)
print(" OpenGrid GRPO Training")
print("=" * 60)
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
print("WARNING: No GPU detected — training will be very slow!")
# Import project modules
sys.path.insert(0, ".")
from src.environment import OpenGridEnv
from src.tasks import TASKS
from src.models import GridAction, BusAdjustment
from training.train_grpo import (
SYSTEM_PROMPT, format_observation_prompt,
compute_grpo_reward_env, extract_action,
rollout_multi_agent,
)
# ── 1. Load model ──
print("\n[1/6] Loading model with bitsandbytes 4-bit...")
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# ── Iteration-budget config ── tweak these to trade speed vs quality ──
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
LORA_RANK = 8 # 8 → faster, less VRAM; 16 → more capacity
NUM_EPOCHS = 1 # 1 epoch ≈ 50 min; 3 epochs ≈ 2.5 h
NUM_EPISODES = 4 # prompt generation episodes (×15 steps ×n_agents ≈ prompts)
SAVE_STEPS = 25 # checkpoint every N steps so a late crash still saves progress
# ─────────────────────────────────────────────────────────────────────
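    # Rough prompt budget (illustrative; the actual n_agents comes from the task
    # config): NUM_EPISODES × min(15, max_steps) × n_agents, e.g. 4 × 15 × 2 ≈ 120.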
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, quantization_config=bnb_config, device_map="auto",
)
# Critical for bnb-4bit + LoRA + gradient checkpointing: cast norms to fp32,
# enable input grads, and wire up non-reentrant checkpointing.
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # silences the warning loop during training
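    # LoRA over all attention and MLP projections in Qwen2.5; lora_alpha = 2·r
    # is the common scaling heuristic, and r=8 keeps the adapter tiny.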
lora_config = LoraConfig(
r=LORA_RANK, lora_alpha=LORA_RANK * 2, lora_dropout=0.05,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
print(f" Model: {MODEL_NAME}")
print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# ── 2. Baseline evaluation ──
print("\n[2/6] Running baseline evaluation...")
import re
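    # Heuristic baseline: a proportional controller on the frequency error.
    # It pushes the first controllable bus by 10× the Hz error, clipped to ±20.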
def heuristic_generate(prompt):
freq_match = re.search(r'Frequency: ([\d.]+)', prompt)
freq = float(freq_match.group(1)) if freq_match else 50.0
error = 50.0 - freq
delta = max(-20, min(20, error * 10))
bus_match = re.search(r'Bus (\d+) \((generator|battery|slack)\)', prompt)
if bus_match:
return json.dumps({"bus_adjustments": [{"bus_id": int(bus_match.group(1)), "delta": round(delta, 1)}], "topology_actions": []})
return json.dumps({"bus_adjustments": [], "topology_actions": []})
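    # Example emission (values illustrative; the shape matches what the trained
    # model is rewarded for producing):
    #   {"bus_adjustments": [{"bus_id": 3, "delta": 12.5}], "topology_actions": []}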
baseline_results = {}
for task_id in ["task_easy", "task_medium", "karnataka_easy", "karnataka_medium", "karnataka_hard", "task_karnataka"]:
if task_id not in TASKS:
continue
config = TASKS[task_id]
rewards = []
for ep in range(3):
ep_config = copy.deepcopy(config)
ep_config['seed'] = 42 + ep
env = OpenGridEnv(ep_config)
result = rollout_multi_agent(env, heuristic_generate, ep_config)
rewards.append(result['total_reward'])
baseline_results[task_id] = {"avg": np.mean(rewards), "std": np.std(rewards), "rewards": rewards}
print(f" [BASELINE] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
# ── 3. Generate training prompts ──
print("\n[3/6] Generating training prompts...")
TRAIN_TASK = "task_karnataka" if "task_karnataka" in TASKS else "task_easy"
task_config = copy.deepcopy(TASKS[TRAIN_TASK])
base_seed = task_config.get('seed', 42)
prompts = []
obs_contexts = []
rng = np.random.RandomState(base_seed)
for episode in range(NUM_EPISODES): # NUM_EPISODES × 15 steps × n_agents ≈ prompts
ep_config = copy.deepcopy(task_config)
ep_config['seed'] = base_seed + episode
env = OpenGridEnv(ep_config)
zone_obs = env.reset_multi()
        # Adversarial twist: drain batteries at the start of every 5th episode
        # (with NUM_EPISODES = 4 that is only episode 0).
if episode % 5 == 0:
for b in env.bus_state:
b_cfg = env._find_bus_config(b['id'])
if b_cfg and b_cfg['type'] == 'battery':
b['soc'] = max(1.0, b['soc'] * 0.1)
for t in range(min(15, task_config['max_steps'])):
for agent_id, obs in zone_obs.items():
obs_dict = json.loads(obs.model_dump_json())
prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt_text},
]
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompts.append(formatted)
obs_contexts.append(json.dumps(obs_dict))
random_actions = {}
for aid in range(env.num_agents):
zone_buses = task_config['zone_bus_ids'].get(aid, [])
controllable = [
bid for bid in zone_buses
if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')
in ['generator', 'battery']
]
adj = []
if controllable:
n_adj = min(len(controllable), rng.randint(1, 3))
chosen = rng.choice(controllable, size=n_adj, replace=False)
for bid in chosen:
adj.append(BusAdjustment(bus_id=int(bid), delta=float(rng.uniform(-30, 30))))
random_actions[aid] = GridAction(bus_adjustments=adj)
result = env.step_multi(random_actions)
if result.done:
break
zone_obs = result.observations
print(f" Generated {len(prompts)} training prompts")
# ── 4. Train ──
print("\n[4/6] Starting GRPO training...")
from trl import GRPOTrainer, GRPOConfig
from datasets import Dataset
import inspect as _inspect
_grpo_params = set(_inspect.signature(GRPOConfig.__init__).parameters)
_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
_fp16 = torch.cuda.is_available() and not _bf16
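    # Depending on dataset format and TRL version, completions arrive either as
    # chat-message lists or as plain strings; TRL also forwards extra dataset
    # columns (here obs_context) to reward functions as keyword arguments.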
def reward_fn(completions, obs_context=None, **kwargs):
texts = []
for c in completions:
if isinstance(c, list):
text = c[-1]['content'] if c else ""
else:
text = str(c)
texts.append(text)
if obs_context is None:
obs_context = [None] * len(texts)
obs_dicts = []
for ctx in obs_context:
if isinstance(ctx, str):
try:
obs_dicts.append(json.loads(ctx))
except (json.JSONDecodeError, TypeError):
obs_dicts.append(None)
else:
obs_dicts.append(ctx)
return compute_grpo_reward_env(texts, obs_dicts, task_config)
    # Set the generation config explicitly so EOS is always respected and
    # generations can stop well before max_completion_length.
from transformers import GenerationConfig
model.generation_config = GenerationConfig(
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=64,
)
    # Some GRPOConfig params were renamed or moved between TRL versions; only
    # pass the ones the installed version accepts.
_opt = {}
if 'max_prompt_length' in _grpo_params: _opt['max_prompt_length'] = 512
if 'max_completion_length' in _grpo_params: _opt['max_completion_length'] = 64
if 'torch_compile' in _grpo_params: _opt['torch_compile'] = False
if 'use_vllm' in _grpo_params: _opt['use_vllm'] = False
grpo_config = GRPOConfig(
output_dir="training/outputs/grpo_checkpoints",
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-5, # slightly higher LR for fewer steps
logging_steps=1,
save_steps=SAVE_STEPS, # checkpoint often so late crashes don't lose everything
save_total_limit=3, # keep only 3 checkpoints to save disk
num_generations=4,
report_to="none",
remove_unused_columns=False,
bf16=_bf16,
fp16=_fp16,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="paged_adamw_8bit",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
dataloader_num_workers=0,
**_opt,
)
train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
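    # The extra obs_context column reaches reward_fn only because
    # remove_unused_columns=False above keeps non-prompt columns in the batch.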
print(f" Dataset: {len(train_dataset)} rows")
print(f" Effective batch: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}")
trainer = GRPOTrainer(
model=model, args=grpo_config, train_dataset=train_dataset,
reward_funcs=reward_fn, processing_class=tokenizer,
)
# ── Sanity-check generation before handing off to GRPO ──
# If this hangs, the model/tokenizer setup is the problem.
print(" [DEBUG] Testing model generation (should complete in <30s)...")
_test_inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
_out = model.generate(
**_test_inputs,
max_new_tokens=8,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
print(f" [DEBUG] Generation OK: {tokenizer.decode(_out[0][-8:], skip_special_tokens=True)!r}")
print(" [NOTE] First GRPO step includes Triton JIT — may show 0/N for up to 5 min. That is normal.")
t0 = time.time()
trainer.train()
train_time = time.time() - t0
print(f"\n Training complete in {train_time/60:.1f} minutes")
# Save adapter only (avoids OOM from merging/dequantising the full model)
output_path = "training/outputs/trained_model"
os.makedirs(output_path, exist_ok=True)
torch.cuda.empty_cache() # free activations before saving
try:
model.save_pretrained(output_path) # saves LoRA adapter weights only
tokenizer.save_pretrained(output_path)
print(f" Adapter saved to {output_path}")
except Exception as save_err:
print(f" WARNING: adapter save failed ({save_err}); training metrics still captured")
# ── 5. Post-training evaluation ──
# Only evaluate on 3 tasks × 1 episode to stay within VRAM budget.
# Full 6-task × 3-episode eval can be run offline if needed.
print("\n[5/6] Evaluating trained model (fast: 3 tasks × 1 ep)...")
torch.cuda.empty_cache()
model.eval()
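    # Low-temperature sampling keeps the JSON actions well-formed while leaving
    # a little diversity; greedy decoding would also be a reasonable choice here.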
def trained_generate(prompt):
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=64, # short for speed; enough for JSON action
temperature=0.3, do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
trained_results = {}
EVAL_TASKS = ["task_easy", "task_karnataka", "karnataka_hard"] # representative subset
for task_id in EVAL_TASKS:
if task_id not in TASKS:
continue
try:
config = TASKS[task_id]
ep_config = copy.deepcopy(config)
ep_config['seed'] = 42
env = OpenGridEnv(ep_config)
result = rollout_multi_agent(env, trained_generate, ep_config)
r = result['total_reward']
trained_results[task_id] = {"avg": round(r, 2), "std": 0.0, "rewards": [r]}
print(f" [TRAINED] {task_id}: {r:.2f}")
torch.cuda.empty_cache()
except Exception as eval_err:
print(f" [TRAINED] {task_id}: eval failed ({eval_err})")
trained_results[task_id] = {"avg": None, "std": None, "rewards": []}
# ── 6. Generate plots ──
print("\n[6/6] Generating plots...")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
os.makedirs("training/outputs", exist_ok=True)
# Before vs After
    # Skip tasks whose post-training eval failed (avg is None) so the plot and
    # the Δ table below don't choke on missing values.
    common_tasks = [t for t in baseline_results
                    if t in trained_results and trained_results[t]['avg'] is not None]
if common_tasks:
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(common_tasks))
width = 0.35
before = [baseline_results[t]['avg'] for t in common_tasks]
after = [trained_results[t]['avg'] for t in common_tasks]
ax.bar(x - width/2, before, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.8)
ax.bar(x + width/2, after, width, label='GRPO Trained', color='#00d4aa', alpha=0.8)
ax.set_xlabel('Task'); ax.set_ylabel('Average Episode Reward')
ax.set_title('OpenGrid — GRPO Training: Before vs After', fontweight='bold')
ax.set_xticks(x); ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])
ax.legend(); ax.grid(True, alpha=0.3, axis='y')
for bars in ax.containers:
for bar in bars:
h = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., h + (1 if h >= 0 else -3),
f'{h:.1f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=10)
plt.tight_layout()
plt.savefig('training/outputs/before_after.png', dpi=150)
plt.close()
# Training loss
history = trainer.state.log_history
steps = [h['step'] for h in history if 'loss' in h]
losses = [h['loss'] for h in history if 'loss' in h]
if steps:
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(steps, losses, color='#ff6b6b', linewidth=1.5, alpha=0.6, label='Loss')
if len(losses) > 10:
w = min(20, len(losses) // 3)
smoothed = np.convolve(losses, np.ones(w)/w, mode='valid')
ax.plot(steps[w-1:], smoothed, color='#ff6b6b', linewidth=2.5, label=f'Smoothed (w={w})')
ax.set_xlabel('Step'); ax.set_ylabel('Loss')
ax.set_title('OpenGrid GRPO — Training Loss', fontweight='bold')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training/outputs/training_loss.png', dpi=150)
plt.close()
    # Save summary; includes the run config so multiple runs are comparable,
    # plus the trainer's reward history for the reward curve.
    reward_hist = [h['reward'] for h in trainer.state.log_history if 'reward' in h]
summary = {
"model": MODEL_NAME,
"train_task": TRAIN_TASK,
"train_time_minutes": round(train_time / 60, 1),
"num_prompts": len(prompts),
"num_epochs": NUM_EPOCHS,
"lora_rank": LORA_RANK,
"baseline": {k: {"avg": round(v["avg"], 2), "std": round(v["std"], 2)} for k, v in baseline_results.items()},
"trained": {k: {"avg": round(v["avg"], 2) if v["avg"] is not None else None,
"std": round(v["std"], 2) if v["std"] is not None else None}
for k, v in trained_results.items()},
"reward_start": round(float(np.mean([h['reward'] for h in log_history if 'reward' in h][:5])), 4) if log_history else None,
"reward_end": round(float(np.mean([h['reward'] for h in log_history if 'reward' in h][-20:])), 4) if log_history else None,
}
with open("training/outputs/summary.json", "w") as f:
json.dump(summary, f, indent=2)
print("\n" + "=" * 60)
print(" TRAINING COMPLETE")
print("=" * 60)
print(f" Time: {train_time/60:.1f} minutes")
print(f" {'Task':<20} {'Baseline':>10} {'Trained':>10} {'Δ':>8}")
print(f" {'-'*50}")
for t in common_tasks:
b, a = baseline_results[t]['avg'], trained_results[t]['avg']
arrow = '↑' if a > b else '↓'
print(f" {t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(a-b):.2f}")
print("=" * 60)
return summary
# ── Main ──────────────────────────────────────────────────────────
if __name__ == "__main__":
try:
summary = run_grpo_training()
except Exception as e:
print(f"\nERROR during training: {e}")
traceback.print_exc()
# Save error so the UI can report it
os.makedirs("training/outputs", exist_ok=True)
with open("training/outputs/summary.json", "w") as f:
json.dump({"error": str(e)}, f)
# Start the full UI server (not a mini results server)
# This serves the control room + training results on port 7860
# NOTE: In training mode, entrypoint.sh starts the server in background
# before training. This block is kept for standalone execution only.
if os.environ.get("OPENGRID_MODE") != "training":
print("\nTraining done. Starting full UI server on port 7860...")
import uvicorn
from app import app
uvicorn.run(app, host="0.0.0.0", port=7860)
else:
print("\nTraining done. UI server already running in background.")