GRPO training with CUDA + results in UI
Browse files- Dockerfile +13 -3
- _fix_notebook.py +75 -0
- app.py +39 -1
- requirements-training.txt +1 -2
- run_training.py +7 -41
Dockerfile
CHANGED
|
@@ -2,12 +2,19 @@
|
|
| 2 |
# Serves both the UI dashboard AND GRPO training.
|
| 3 |
# Set env OPENGRID_MODE=training for training mode.
|
| 4 |
|
| 5 |
-
FROM
|
| 6 |
|
| 7 |
LABEL org.opencontainers.image.title="OpenGrid"
|
| 8 |
LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
|
| 9 |
LABEL openenv="true"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
RUN useradd -m -u 1000 user
|
| 12 |
USER user
|
| 13 |
ENV PATH="/home/user/.local/bin:$PATH"
|
|
@@ -19,9 +26,12 @@ WORKDIR /app
|
|
| 19 |
COPY --chown=user requirements.txt .
|
| 20 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
# Install training deps (only re-runs if training reqs change)
|
| 23 |
COPY --chown=user requirements-training.txt .
|
| 24 |
-
RUN pip install --no-cache-dir --upgrade -r requirements-training.txt
|
| 25 |
|
| 26 |
# --- Application code (selective COPY for lean images) ---
|
| 27 |
# Core Python modules
|
|
@@ -47,7 +57,7 @@ RUN chmod +x entrypoint.sh
|
|
| 47 |
# server = FastAPI UI, training = GRPO pipeline
|
| 48 |
EXPOSE 7860
|
| 49 |
|
| 50 |
-
HEALTHCHECK --interval=
|
| 51 |
CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
|
| 52 |
|
| 53 |
CMD ["./entrypoint.sh"]
|
|
|
|
| 2 |
# Serves both the UI dashboard AND GRPO training.
|
| 3 |
# Set env OPENGRID_MODE=training for training mode.
|
| 4 |
|
| 5 |
+
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04
|
| 6 |
|
| 7 |
LABEL org.opencontainers.image.title="OpenGrid"
|
| 8 |
LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
|
| 9 |
LABEL openenv="true"
|
| 10 |
|
| 11 |
+
# Install Python 3.10
|
| 12 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
python3.10 python3-pip python3.10-venv && \
|
| 14 |
+
ln -sf /usr/bin/python3.10 /usr/bin/python && \
|
| 15 |
+
ln -sf /usr/bin/pip3 /usr/bin/pip && \
|
| 16 |
+
rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
RUN useradd -m -u 1000 user
|
| 19 |
USER user
|
| 20 |
ENV PATH="/home/user/.local/bin:$PATH"
|
|
|
|
| 26 |
COPY --chown=user requirements.txt .
|
| 27 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 28 |
|
| 29 |
+
# Install PyTorch with CUDA support (must come before training deps)
|
| 30 |
+
RUN pip install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121
|
| 31 |
+
|
| 32 |
# Install training deps (only re-runs if training reqs change)
|
| 33 |
COPY --chown=user requirements-training.txt .
|
| 34 |
+
RUN pip install --no-cache-dir --upgrade --no-deps -r requirements-training.txt
|
| 35 |
|
| 36 |
# --- Application code (selective COPY for lean images) ---
|
| 37 |
# Core Python modules
|
|
|
|
| 57 |
# server = FastAPI UI, training = GRPO pipeline
|
| 58 |
EXPOSE 7860
|
| 59 |
|
| 60 |
+
HEALTHCHECK --interval=60s --timeout=10s --start-period=600s \
|
| 61 |
CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
|
| 62 |
|
| 63 |
CMD ["./entrypoint.sh"]
|
_fix_notebook.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Update the notebook: fix rewards, hyperparams, remove emojis, show plots inline."""
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
nb = json.load(open('training/opengrid_grpo_colab.ipynb', encoding='utf-8'))
|
| 5 |
+
|
| 6 |
+
# Remove emojis from all cells
|
| 7 |
+
for cell in nb['cells']:
|
| 8 |
+
for i, line in enumerate(cell.get('source', [])):
|
| 9 |
+
for emoji in ['🔋','⚡','🚀','📊','✅','⚠️']:
|
| 10 |
+
line = line.replace(emoji, '')
|
| 11 |
+
cell['source'][i] = line
|
| 12 |
+
|
| 13 |
+
# Fix Cell 8: use compute_grpo_reward_env
|
| 14 |
+
for cell in nb['cells']:
|
| 15 |
+
src = ''.join(cell.get('source', []))
|
| 16 |
+
if 'compute_grpo_reward,' in src and 'def reward_fn' in src:
|
| 17 |
+
cell['source'] = [
|
| 18 |
+
'import json as _json\n',
|
| 19 |
+
'from training.train_grpo import compute_grpo_reward_env, extract_action\n',
|
| 20 |
+
'\n',
|
| 21 |
+
'def reward_fn(completions, obs_context=None, **kwargs):\n',
|
| 22 |
+
' """GRPO reward function with env-grounded physics rewards."""\n',
|
| 23 |
+
' texts = []\n',
|
| 24 |
+
' for c in completions:\n',
|
| 25 |
+
' if isinstance(c, list):\n',
|
| 26 |
+
' text = c[-1]["content"] if c else ""\n',
|
| 27 |
+
' else:\n',
|
| 28 |
+
' text = str(c)\n',
|
| 29 |
+
' texts.append(text)\n',
|
| 30 |
+
'\n',
|
| 31 |
+
' if obs_context is None:\n',
|
| 32 |
+
' batch_obs = [None] * len(texts)\n',
|
| 33 |
+
' else:\n',
|
| 34 |
+
' batch_obs = [\n',
|
| 35 |
+
' _json.loads(ctx) if isinstance(ctx, str) else ctx\n',
|
| 36 |
+
' for ctx in obs_context\n',
|
| 37 |
+
' ]\n',
|
| 38 |
+
' return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)\n',
|
| 39 |
+
'\n',
|
| 40 |
+
'# Sanity test\n',
|
| 41 |
+
'test_rewards = reward_fn([\n',
|
| 42 |
+
' \'{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}\',\n',
|
| 43 |
+
' "invalid json here",\n',
|
| 44 |
+
'])\n',
|
| 45 |
+
'print(f"Test rewards: {test_rewards}")\n',
|
| 46 |
+
'assert len(test_rewards) == 2\n',
|
| 47 |
+
'print("[OK] reward_fn works")\n',
|
| 48 |
+
]
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
# Fix Cell 9: update hyperparameters
|
| 52 |
+
for cell in nb['cells']:
|
| 53 |
+
src = ''.join(cell.get('source', []))
|
| 54 |
+
if 'GRPOConfig(' in src and 'num_generations' in src:
|
| 55 |
+
new_src = src.replace('num_train_epochs=1', 'num_train_epochs=3')
|
| 56 |
+
new_src = new_src.replace('gradient_accumulation_steps=4', 'gradient_accumulation_steps=8')
|
| 57 |
+
new_src = new_src.replace('learning_rate=5e-6', 'learning_rate=1e-5')
|
| 58 |
+
new_src = new_src.replace('num_generations=4', 'num_generations=8')
|
| 59 |
+
cell['source'] = new_src.splitlines(True)
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
# Fix download cell: replace google.colab with inline display
|
| 63 |
+
for cell in nb['cells']:
|
| 64 |
+
src = ''.join(cell.get('source', []))
|
| 65 |
+
if 'google.colab' in src:
|
| 66 |
+
cell['source'] = [
|
| 67 |
+
'# Display plots inline\n',
|
| 68 |
+
'from IPython.display import Image, display\n',
|
| 69 |
+
'display(Image("training/outputs/before_after.png"))\n',
|
| 70 |
+
'display(Image("training/outputs/training_loss.png"))\n',
|
| 71 |
+
]
|
| 72 |
+
break
|
| 73 |
+
|
| 74 |
+
json.dump(nb, open('training/opengrid_grpo_colab.ipynb', 'w', encoding='utf-8'), indent=1)
|
| 75 |
+
print("Notebook updated successfully")
|
app.py
CHANGED
|
@@ -413,4 +413,42 @@ def visualize(session_id: str):
|
|
| 413 |
hist = list(history.get(session_id, []))
|
| 414 |
|
| 415 |
img_str = generate_dashboard(hist, obs)
|
| 416 |
-
return {"image_base64": img_str}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
hist = list(history.get(session_id, []))
|
| 414 |
|
| 415 |
img_str = generate_dashboard(hist, obs)
|
| 416 |
+
return {"image_base64": img_str}
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# ===========================================================================
|
| 420 |
+
# Training Results
|
| 421 |
+
# ===========================================================================
|
| 422 |
+
|
| 423 |
+
@app.get("/training-results")
|
| 424 |
+
def training_results():
|
| 425 |
+
"""Return GRPO training results if available."""
|
| 426 |
+
summary_path = pathlib.Path("training/outputs/summary.json")
|
| 427 |
+
if not summary_path.exists():
|
| 428 |
+
return {"available": False}
|
| 429 |
+
with open(summary_path) as f:
|
| 430 |
+
data = json.load(f)
|
| 431 |
+
# Check if it was an error
|
| 432 |
+
if "error" in data:
|
| 433 |
+
return {"available": True, "error": data["error"]}
|
| 434 |
+
# Add plot URLs
|
| 435 |
+
data["available"] = True
|
| 436 |
+
data["plots"] = {}
|
| 437 |
+
for name in ["before_after", "training_loss"]:
|
| 438 |
+
p = pathlib.Path(f"training/outputs/{name}.png")
|
| 439 |
+
if p.exists():
|
| 440 |
+
data["plots"][name] = f"/training-plots/{name}"
|
| 441 |
+
return data
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@app.get("/training-plots/{name}")
|
| 445 |
+
def training_plot(name: str):
|
| 446 |
+
"""Serve a training plot image."""
|
| 447 |
+
from fastapi.responses import FileResponse
|
| 448 |
+
allowed = {"before_after", "training_loss"}
|
| 449 |
+
if name not in allowed:
|
| 450 |
+
raise HTTPException(404, "Plot not found")
|
| 451 |
+
p = pathlib.Path(f"training/outputs/{name}.png")
|
| 452 |
+
if not p.exists():
|
| 453 |
+
raise HTTPException(404, "Plot not generated yet")
|
| 454 |
+
return FileResponse(str(p), media_type="image/png")
|
requirements-training.txt
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
# Training dependencies
|
| 2 |
-
torch
|
| 3 |
transformers>=4.51.3
|
| 4 |
trl>=0.12.0,<1.0
|
| 5 |
peft>=0.13.0
|
|
|
|
| 1 |
+
# Training dependencies (torch installed separately in Dockerfile with CUDA)
|
|
|
|
| 2 |
transformers>=4.51.3
|
| 3 |
trl>=0.12.0,<1.0
|
| 4 |
peft>=0.13.0
|
run_training.py
CHANGED
|
@@ -330,44 +330,6 @@ def run_grpo_training():
|
|
| 330 |
return summary
|
| 331 |
|
| 332 |
|
| 333 |
-
# ── Results Server ────────────────────────────────────────────────
|
| 334 |
-
def serve_results():
|
| 335 |
-
"""Serve training results on port 7860."""
|
| 336 |
-
from fastapi import FastAPI
|
| 337 |
-
from fastapi.responses import FileResponse, JSONResponse
|
| 338 |
-
import uvicorn
|
| 339 |
-
|
| 340 |
-
app = FastAPI(title="OpenGrid Training Results")
|
| 341 |
-
|
| 342 |
-
@app.get("/")
|
| 343 |
-
def root():
|
| 344 |
-
summary_path = Path("training/outputs/summary.json")
|
| 345 |
-
if summary_path.exists():
|
| 346 |
-
with open(summary_path) as f:
|
| 347 |
-
return json.load(f)
|
| 348 |
-
return {"status": "Training in progress or no results yet"}
|
| 349 |
-
|
| 350 |
-
@app.get("/plots/before_after")
|
| 351 |
-
def before_after():
|
| 352 |
-
p = Path("training/outputs/before_after.png")
|
| 353 |
-
if p.exists():
|
| 354 |
-
return FileResponse(str(p), media_type="image/png")
|
| 355 |
-
return JSONResponse({"error": "not ready"}, status_code=404)
|
| 356 |
-
|
| 357 |
-
@app.get("/plots/loss")
|
| 358 |
-
def loss():
|
| 359 |
-
p = Path("training/outputs/training_loss.png")
|
| 360 |
-
if p.exists():
|
| 361 |
-
return FileResponse(str(p), media_type="image/png")
|
| 362 |
-
return JSONResponse({"error": "not ready"}, status_code=404)
|
| 363 |
-
|
| 364 |
-
@app.get("/health")
|
| 365 |
-
def health():
|
| 366 |
-
return {"status": "ok"}
|
| 367 |
-
|
| 368 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 369 |
-
|
| 370 |
-
|
| 371 |
# ── Main ──────────────────────────────────────────────────────────
|
| 372 |
if __name__ == "__main__":
|
| 373 |
try:
|
|
@@ -375,10 +337,14 @@ if __name__ == "__main__":
|
|
| 375 |
except Exception as e:
|
| 376 |
print(f"\nERROR during training: {e}")
|
| 377 |
traceback.print_exc()
|
| 378 |
-
# Save error so the
|
| 379 |
os.makedirs("training/outputs", exist_ok=True)
|
| 380 |
with open("training/outputs/summary.json", "w") as f:
|
| 381 |
json.dump({"error": str(e)}, f)
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
return summary
|
| 331 |
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
# ── Main ──────────────────────────────────────────────────────────
|
| 334 |
if __name__ == "__main__":
|
| 335 |
try:
|
|
|
|
| 337 |
except Exception as e:
|
| 338 |
print(f"\nERROR during training: {e}")
|
| 339 |
traceback.print_exc()
|
| 340 |
+
# Save error so the UI can report it
|
| 341 |
os.makedirs("training/outputs", exist_ok=True)
|
| 342 |
with open("training/outputs/summary.json", "w") as f:
|
| 343 |
json.dump({"error": str(e)}, f)
|
| 344 |
|
| 345 |
+
# Start the full UI server (not a mini results server)
|
| 346 |
+
# This serves the control room + training results on port 7860
|
| 347 |
+
print("\nTraining done. Starting full UI server on port 7860...")
|
| 348 |
+
import uvicorn
|
| 349 |
+
from app import app
|
| 350 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|