"""OpenGrid GRPO Training Runner for HF Spaces. Runs env-grounded GRPO training, saves model + plots, then starts a FastAPI server to serve/download results. """ import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import sys import json import copy import time import shutil import traceback from pathlib import Path # --- TRITON COMPILER FIX --- import subprocess try: print("Checking for gcc...") result = subprocess.run(['which', 'gcc'], capture_output=True, text=True) gcc_path = result.stdout.strip() print(f"gcc location: {gcc_path or 'NOT FOUND'}") if gcc_path: os.environ['CC'] = gcc_path os.environ['CXX'] = shutil.which('g++') or '' result2 = subprocess.run(['gcc', '--version'], capture_output=True, text=True) print(f"gcc version:\n{result2.stdout.strip()[:100]}") else: print("WARNING: gcc still not found in PATH!") except Exception as e: print(f"Error checking gcc: {e}") # ---------------------------- # ── Training ────────────────────────────────────────────────────── def run_grpo_training(): """Run GRPO training with env-grounded rewards.""" import torch import numpy as np print("=" * 60) print(" OpenGrid GRPO Training") print("=" * 60) if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") else: print("WARNING: No GPU detected — training will be very slow!") # Import project modules sys.path.insert(0, ".") from src.environment import OpenGridEnv from src.tasks import TASKS from src.models import GridAction, BusAdjustment from training.train_grpo import ( SYSTEM_PROMPT, format_observation_prompt, compute_grpo_reward_env, extract_action, rollout_multi_agent, ) # ── 1. Load model ── print("\n[1/6] Loading model with bitsandbytes 4-bit...") from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training # ── Iteration-budget config ── tweak these to trade speed vs quality ── MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" LORA_RANK = 8 # 8 → faster, less VRAM; 16 → more capacity NUM_EPOCHS = 1 # 1 epoch ≈ 50 min; 3 epochs ≈ 2.5 h NUM_EPISODES = 4 # prompt generation episodes (×15 steps ×n_agents ≈ prompts) SAVE_STEPS = 25 # checkpoint every N steps so a late crash still saves progress # ───────────────────────────────────────────────────────────────────── bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, bnb_4bit_use_double_quant=True, ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map="auto", ) # Critical for bnb-4bit + LoRA + gradient checkpointing: cast norms to fp32, # enable input grads, and wire up non-reentrant checkpointing. 
    # Critical for bnb-4bit + LoRA + gradient checkpointing: cast norms to fp32,
    # enable input grads, and wire up non-reentrant checkpointing.
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False  # silences the warning loop during training

    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.enable_input_require_grads()

    print(f" Model: {MODEL_NAME}")
    print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ── 2. Baseline evaluation ──
    print("\n[2/6] Running baseline evaluation...")
    import re

    def heuristic_generate(prompt):
        # Proportional control: delta = clip(10 × (50 − f), ±20) applied to the
        # first controllable bus named in the prompt.
        # e.g. f = 49.8 Hz → error = 0.2 → delta = +2.0.
        freq_match = re.search(r'Frequency: ([\d.]+)', prompt)
        freq = float(freq_match.group(1)) if freq_match else 50.0
        error = 50.0 - freq
        delta = max(-20, min(20, error * 10))
        bus_match = re.search(r'Bus (\d+) \((generator|battery|slack)\)', prompt)
        if bus_match:
            return json.dumps({
                "bus_adjustments": [{"bus_id": int(bus_match.group(1)), "delta": round(delta, 1)}],
                "topology_actions": [],
            })
        return json.dumps({"bus_adjustments": [], "topology_actions": []})

    baseline_results = {}
    for task_id in ["task_easy", "task_medium", "karnataka_easy",
                    "karnataka_medium", "karnataka_hard", "task_karnataka"]:
        if task_id not in TASKS:
            continue
        config = TASKS[task_id]
        rewards = []
        for ep in range(3):
            ep_config = copy.deepcopy(config)
            ep_config['seed'] = 42 + ep
            env = OpenGridEnv(ep_config)
            result = rollout_multi_agent(env, heuristic_generate, ep_config)
            rewards.append(result['total_reward'])
        baseline_results[task_id] = {"avg": np.mean(rewards), "std": np.std(rewards), "rewards": rewards}
        print(f" [BASELINE] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
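    # Prompt-budget arithmetic (illustrative; n_agents depends on the task):
    # the loop below collects NUM_EPISODES × min(15, max_steps) × n_agents
    # prompts, e.g. 4 episodes × 15 steps × 3 agents ≈ 180 prompts. Each prompt
    # later yields num_generations scored completions per GRPO step.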
    # ── 3. Generate training prompts ──
    print("\n[3/6] Generating training prompts...")
    TRAIN_TASK = "task_karnataka" if "task_karnataka" in TASKS else "task_easy"
    task_config = copy.deepcopy(TASKS[TRAIN_TASK])
    base_seed = task_config.get('seed', 42)

    prompts = []
    obs_contexts = []
    rng = np.random.RandomState(base_seed)

    for episode in range(NUM_EPISODES):  # NUM_EPISODES × 15 steps × n_agents ≈ prompts
        ep_config = copy.deepcopy(task_config)
        ep_config['seed'] = base_seed + episode
        env = OpenGridEnv(ep_config)
        zone_obs = env.reset_multi()

        # Adversarial: drain batteries every 5th episode
        if episode % 5 == 0:
            for b in env.bus_state:
                b_cfg = env._find_bus_config(b['id'])
                if b_cfg and b_cfg['type'] == 'battery':
                    b['soc'] = max(1.0, b['soc'] * 0.1)

        for t in range(min(15, task_config['max_steps'])):
            for agent_id, obs in zone_obs.items():
                obs_dict = json.loads(obs.model_dump_json())
                prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                ]
                formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                prompts.append(formatted)
                obs_contexts.append(json.dumps(obs_dict))

            # Advance the env with random but plausible actions so later steps
            # see varied grid states.
            random_actions = {}
            for aid in range(env.num_agents):
                zone_buses = task_config['zone_bus_ids'].get(aid, [])
                controllable = [
                    bid for bid in zone_buses
                    if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')
                    in ['generator', 'battery']
                ]
                adj = []
                if controllable:
                    n_adj = min(len(controllable), rng.randint(1, 3))
                    chosen = rng.choice(controllable, size=n_adj, replace=False)
                    for bid in chosen:
                        adj.append(BusAdjustment(bus_id=int(bid), delta=float(rng.uniform(-30, 30))))
                random_actions[aid] = GridAction(bus_adjustments=adj)
            result = env.step_multi(random_actions)
            if result.done:
                break
            zone_obs = result.observations

    print(f" Generated {len(prompts)} training prompts")

    # ── 4. Train ──
    print("\n[4/6] Starting GRPO training...")
    from trl import GRPOTrainer, GRPOConfig
    from datasets import Dataset
    import inspect as _inspect

    _grpo_params = set(_inspect.signature(GRPOConfig.__init__).parameters)
    _bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    _fp16 = torch.cuda.is_available() and not _bf16

    def reward_fn(completions, obs_context=None, **kwargs):
        texts = []
        for c in completions:
            if isinstance(c, list):
                text = c[-1]['content'] if c else ""
            else:
                text = str(c)
            texts.append(text)
        if obs_context is None:
            obs_context = [None] * len(texts)
        obs_dicts = []
        for ctx in obs_context:
            if isinstance(ctx, str):
                try:
                    obs_dicts.append(json.loads(ctx))
                except (json.JSONDecodeError, TypeError):
                    obs_dicts.append(None)
            else:
                obs_dicts.append(ctx)
        return compute_grpo_reward_env(texts, obs_dicts, task_config)

    # Set the generation config explicitly so EOS is always respected and
    # generation doesn't run to max_completion_length every single time.
    from transformers import GenerationConfig
    model.generation_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=64,
    )
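    # How GRPO consumes reward_fn (sketch of the standard GRPO objective, not
    # project-specific code): each prompt is sampled num_generations times and
    # each completion's advantage is its reward normalised within that group,
    #     A_i = (r_i − mean(r_1..r_G)) / (std(r_1..r_G) + ε)
    # so only the *relative* quality of completions for the same grid state
    # drives the policy gradient; no separate value model is needed.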
    # Some GRPOConfig params were renamed/moved between TRL versions; only pass
    # what the installed TRL version accepts.
    _opt = {}
    if 'max_prompt_length' in _grpo_params:
        _opt['max_prompt_length'] = 512
    if 'max_completion_length' in _grpo_params:
        _opt['max_completion_length'] = 64
    if 'torch_compile' in _grpo_params:
        _opt['torch_compile'] = False
    if 'use_vllm' in _grpo_params:
        _opt['use_vllm'] = False

    grpo_config = GRPOConfig(
        output_dir="training/outputs/grpo_checkpoints",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,     # slightly higher LR for fewer steps
        logging_steps=1,
        save_steps=SAVE_STEPS,  # checkpoint often so late crashes don't lose everything
        save_total_limit=3,     # keep only 3 checkpoints to save disk
        num_generations=4,
        report_to="none",
        remove_unused_columns=False,
        bf16=_bf16,
        fp16=_fp16,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="paged_adamw_8bit",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        dataloader_num_workers=0,
        **_opt,
    )

    train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
    print(f" Dataset: {len(train_dataset)} rows")
    print(f" Effective batch: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}")

    trainer = GRPOTrainer(
        model=model,
        args=grpo_config,
        train_dataset=train_dataset,
        reward_funcs=reward_fn,
        processing_class=tokenizer,
    )

    # ── Sanity-check generation before handing off to GRPO ──
    # If this hangs, the model/tokenizer setup is the problem.
    print(" [DEBUG] Testing model generation (should complete in <30s)...")
    _test_inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    with torch.no_grad():
        _out = model.generate(
            **_test_inputs,
            max_new_tokens=8,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    print(f" [DEBUG] Generation OK: {tokenizer.decode(_out[0][-8:], skip_special_tokens=True)!r}")
    print(" [NOTE] First GRPO step includes Triton JIT — may show 0/N for up to 5 min. That is normal.")

    t0 = time.time()
    trainer.train()
    train_time = time.time() - t0
    print(f"\n Training complete in {train_time/60:.1f} minutes")

    # Save the adapter only (avoids OOM from merging/dequantising the full model)
    output_path = "training/outputs/trained_model"
    os.makedirs(output_path, exist_ok=True)
    torch.cuda.empty_cache()  # free activations before saving
    try:
        model.save_pretrained(output_path)  # saves LoRA adapter weights only
        tokenizer.save_pretrained(output_path)
        print(f" Adapter saved to {output_path}")
    except Exception as save_err:
        print(f" WARNING: adapter save failed ({save_err}); training metrics still captured")
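    # To reuse the saved adapter later (illustrative sketch; names as above):
    #     from peft import PeftModel
    #     base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config)
    #     model = PeftModel.from_pretrained(base, "training/outputs/trained_model")
    # At rank 8 over 7 projection matrices the adapter is typically tens of MB,
    # which is why saving it is cheap even when merging the full model would OOM.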
print("\n[5/6] Evaluating trained model (fast: 3 tasks × 1 ep)...") torch.cuda.empty_cache() model.eval() def trained_generate(prompt): messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ] formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tokenizer(formatted, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=64, # short for speed; enough for JSON action temperature=0.3, do_sample=True, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True) trained_results = {} EVAL_TASKS = ["task_easy", "task_karnataka", "karnataka_hard"] # representative subset for task_id in EVAL_TASKS: if task_id not in TASKS: continue try: config = TASKS[task_id] ep_config = copy.deepcopy(config) ep_config['seed'] = 42 env = OpenGridEnv(ep_config) result = rollout_multi_agent(env, trained_generate, ep_config) r = result['total_reward'] trained_results[task_id] = {"avg": round(r, 2), "std": 0.0, "rewards": [r]} print(f" [TRAINED] {task_id}: {r:.2f}") torch.cuda.empty_cache() except Exception as eval_err: print(f" [TRAINED] {task_id}: eval failed ({eval_err})") trained_results[task_id] = {"avg": None, "std": None, "rewards": []} # ── 6. Generate plots ── print("\n[6/6] Generating plots...") import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt os.makedirs("training/outputs", exist_ok=True) # Before vs After common_tasks = [t for t in baseline_results if t in trained_results] if common_tasks: fig, ax = plt.subplots(figsize=(10, 6)) x = np.arange(len(common_tasks)) width = 0.35 before = [baseline_results[t]['avg'] for t in common_tasks] after = [trained_results[t]['avg'] for t in common_tasks] ax.bar(x - width/2, before, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.8) ax.bar(x + width/2, after, width, label='GRPO Trained', color='#00d4aa', alpha=0.8) ax.set_xlabel('Task'); ax.set_ylabel('Average Episode Reward') ax.set_title('OpenGrid — GRPO Training: Before vs After', fontweight='bold') ax.set_xticks(x); ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks]) ax.legend(); ax.grid(True, alpha=0.3, axis='y') for bars in ax.containers: for bar in bars: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2., h + (1 if h >= 0 else -3), f'{h:.1f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=10) plt.tight_layout() plt.savefig('training/outputs/before_after.png', dpi=150) plt.close() # Training loss history = trainer.state.log_history steps = [h['step'] for h in history if 'loss' in h] losses = [h['loss'] for h in history if 'loss' in h] if steps: fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(steps, losses, color='#ff6b6b', linewidth=1.5, alpha=0.6, label='Loss') if len(losses) > 10: w = min(20, len(losses) // 3) smoothed = np.convolve(losses, np.ones(w)/w, mode='valid') ax.plot(steps[w-1:], smoothed, color='#ff6b6b', linewidth=2.5, label=f'Smoothed (w={w})') ax.set_xlabel('Step'); ax.set_ylabel('Loss') ax.set_title('OpenGrid GRPO — Training Loss', fontweight='bold') ax.legend(); ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig('training/outputs/training_loss.png', dpi=150) plt.close() # Save summary — includes run config so multiple runs are comparable # Also record trainer log history for the reward curve log_history = trainer.state.log_history summary = { "model": MODEL_NAME, 
"train_task": TRAIN_TASK, "train_time_minutes": round(train_time / 60, 1), "num_prompts": len(prompts), "num_epochs": NUM_EPOCHS, "lora_rank": LORA_RANK, "baseline": {k: {"avg": round(v["avg"], 2), "std": round(v["std"], 2)} for k, v in baseline_results.items()}, "trained": {k: {"avg": round(v["avg"], 2) if v["avg"] is not None else None, "std": round(v["std"], 2) if v["std"] is not None else None} for k, v in trained_results.items()}, "reward_start": round(float(np.mean([h['reward'] for h in log_history if 'reward' in h][:5])), 4) if log_history else None, "reward_end": round(float(np.mean([h['reward'] for h in log_history if 'reward' in h][-20:])), 4) if log_history else None, } with open("training/outputs/summary.json", "w") as f: json.dump(summary, f, indent=2) print("\n" + "=" * 60) print(" TRAINING COMPLETE") print("=" * 60) print(f" Time: {train_time/60:.1f} minutes") print(f" {'Task':<20} {'Baseline':>10} {'Trained':>10} {'Δ':>8}") print(f" {'-'*50}") for t in common_tasks: b, a = baseline_results[t]['avg'], trained_results[t]['avg'] arrow = '↑' if a > b else '↓' print(f" {t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(a-b):.2f}") print("=" * 60) return summary # ── Main ────────────────────────────────────────────────────────── if __name__ == "__main__": try: summary = run_grpo_training() except Exception as e: print(f"\nERROR during training: {e}") traceback.print_exc() # Save error so the UI can report it os.makedirs("training/outputs", exist_ok=True) with open("training/outputs/summary.json", "w") as f: json.dump({"error": str(e)}, f) # Start the full UI server (not a mini results server) # This serves the control room + training results on port 7860 # NOTE: In training mode, entrypoint.sh starts the server in background # before training. This block is kept for standalone execution only. if os.environ.get("OPENGRID_MODE") != "training": print("\nTraining done. Starting full UI server on port 7860...") import uvicorn from app import app uvicorn.run(app, host="0.0.0.0", port=7860) else: print("\nTraining done. UI server already running in background.")