K446 committed on
Commit
bcce6af
·
1 Parent(s): 69bab30

GRPO training with CUDA + results in UI

Browse files
Files changed (5) hide show
  1. Dockerfile +13 -3
  2. _fix_notebook.py +75 -0
  3. app.py +39 -1
  4. requirements-training.txt +1 -2
  5. run_training.py +7 -41
Dockerfile CHANGED
@@ -2,12 +2,19 @@
2
  # Serves both the UI dashboard AND GRPO training.
3
  # Set env OPENGRID_MODE=training for training mode.
4
 
5
- FROM python:3.10-slim
6
 
7
  LABEL org.opencontainers.image.title="OpenGrid"
8
  LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
9
  LABEL openenv="true"
10
 
 
 
 
 
 
 
 
11
  RUN useradd -m -u 1000 user
12
  USER user
13
  ENV PATH="/home/user/.local/bin:$PATH"
@@ -19,9 +26,12 @@ WORKDIR /app
19
  COPY --chown=user requirements.txt .
20
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
21
 
 
 
 
22
  # Install training deps (only re-runs if training reqs change)
23
  COPY --chown=user requirements-training.txt .
24
- RUN pip install --no-cache-dir --upgrade -r requirements-training.txt
25
 
26
  # --- Application code (selective COPY for lean images) ---
27
  # Core Python modules
@@ -47,7 +57,7 @@ RUN chmod +x entrypoint.sh
47
  # server = FastAPI UI, training = GRPO pipeline
48
  EXPOSE 7860
49
 
50
- HEALTHCHECK --interval=30s --timeout=5s --start-period=15s \
51
  CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
52
 
53
  CMD ["./entrypoint.sh"]
 
2
  # Serves both the UI dashboard AND GRPO training.
3
  # Set env OPENGRID_MODE=training for training mode.
4
 
5
+ FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04
6
 
7
  LABEL org.opencontainers.image.title="OpenGrid"
8
  LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
9
  LABEL openenv="true"
10
 
11
+ # Install Python 3.10
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ python3.10 python3-pip python3.10-venv && \
14
+ ln -sf /usr/bin/python3.10 /usr/bin/python && \
15
+ ln -sf /usr/bin/pip3 /usr/bin/pip && \
16
+ rm -rf /var/lib/apt/lists/*
17
+
18
  RUN useradd -m -u 1000 user
19
  USER user
20
  ENV PATH="/home/user/.local/bin:$PATH"
 
26
  COPY --chown=user requirements.txt .
27
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
28
 
29
+ # Install PyTorch with CUDA support (must come before training deps)
30
+ RUN pip install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121
31
+
32
  # Install training deps (only re-runs if training reqs change)
33
  COPY --chown=user requirements-training.txt .
34
+ RUN pip install --no-cache-dir --upgrade --no-deps -r requirements-training.txt
35
 
36
  # --- Application code (selective COPY for lean images) ---
37
  # Core Python modules
 
57
  # server = FastAPI UI, training = GRPO pipeline
58
  EXPOSE 7860
59
 
60
+ HEALTHCHECK --interval=60s --timeout=10s --start-period=600s \
61
  CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
62
 
63
  CMD ["./entrypoint.sh"]
_fix_notebook.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Update the notebook: fix rewards, hyperparams, remove emojis, show plots inline."""
2
+ import json
3
+
4
+ nb = json.load(open('training/opengrid_grpo_colab.ipynb', encoding='utf-8'))
5
+
6
+ # Remove emojis from all cells
7
+ for cell in nb['cells']:
8
+ for i, line in enumerate(cell.get('source', [])):
9
+ for emoji in ['🔋','⚡','🚀','📊','✅','⚠️']:
10
+ line = line.replace(emoji, '')
11
+ cell['source'][i] = line
12
+
13
+ # Fix Cell 8: use compute_grpo_reward_env
14
+ for cell in nb['cells']:
15
+ src = ''.join(cell.get('source', []))
16
+ if 'compute_grpo_reward,' in src and 'def reward_fn' in src:
17
+ cell['source'] = [
18
+ 'import json as _json\n',
19
+ 'from training.train_grpo import compute_grpo_reward_env, extract_action\n',
20
+ '\n',
21
+ 'def reward_fn(completions, obs_context=None, **kwargs):\n',
22
+ ' """GRPO reward function with env-grounded physics rewards."""\n',
23
+ ' texts = []\n',
24
+ ' for c in completions:\n',
25
+ ' if isinstance(c, list):\n',
26
+ ' text = c[-1]["content"] if c else ""\n',
27
+ ' else:\n',
28
+ ' text = str(c)\n',
29
+ ' texts.append(text)\n',
30
+ '\n',
31
+ ' if obs_context is None:\n',
32
+ ' batch_obs = [None] * len(texts)\n',
33
+ ' else:\n',
34
+ ' batch_obs = [\n',
35
+ ' _json.loads(ctx) if isinstance(ctx, str) else ctx\n',
36
+ ' for ctx in obs_context\n',
37
+ ' ]\n',
38
+ ' return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)\n',
39
+ '\n',
40
+ '# Sanity test\n',
41
+ 'test_rewards = reward_fn([\n',
42
+ ' \'{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}\',\n',
43
+ ' "invalid json here",\n',
44
+ '])\n',
45
+ 'print(f"Test rewards: {test_rewards}")\n',
46
+ 'assert len(test_rewards) == 2\n',
47
+ 'print("[OK] reward_fn works")\n',
48
+ ]
49
+ break
50
+
51
+ # Fix Cell 9: update hyperparameters
52
+ for cell in nb['cells']:
53
+ src = ''.join(cell.get('source', []))
54
+ if 'GRPOConfig(' in src and 'num_generations' in src:
55
+ new_src = src.replace('num_train_epochs=1', 'num_train_epochs=3')
56
+ new_src = new_src.replace('gradient_accumulation_steps=4', 'gradient_accumulation_steps=8')
57
+ new_src = new_src.replace('learning_rate=5e-6', 'learning_rate=1e-5')
58
+ new_src = new_src.replace('num_generations=4', 'num_generations=8')
59
+ cell['source'] = new_src.splitlines(True)
60
+ break
61
+
62
+ # Fix download cell: replace google.colab with inline display
63
+ for cell in nb['cells']:
64
+ src = ''.join(cell.get('source', []))
65
+ if 'google.colab' in src:
66
+ cell['source'] = [
67
+ '# Display plots inline\n',
68
+ 'from IPython.display import Image, display\n',
69
+ 'display(Image("training/outputs/before_after.png"))\n',
70
+ 'display(Image("training/outputs/training_loss.png"))\n',
71
+ ]
72
+ break
73
+
74
+ json.dump(nb, open('training/opengrid_grpo_colab.ipynb', 'w', encoding='utf-8'), indent=1)
75
+ print("Notebook updated successfully")
app.py CHANGED
@@ -413,4 +413,42 @@ def visualize(session_id: str):
413
  hist = list(history.get(session_id, []))
414
 
415
  img_str = generate_dashboard(hist, obs)
416
- return {"image_base64": img_str}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  hist = list(history.get(session_id, []))
414
 
415
  img_str = generate_dashboard(hist, obs)
416
+ return {"image_base64": img_str}
417
+
418
+
419
+ # ===========================================================================
420
+ # Training Results
421
+ # ===========================================================================
422
+
423
+ @app.get("/training-results")
424
+ def training_results():
425
+ """Return GRPO training results if available."""
426
+ summary_path = pathlib.Path("training/outputs/summary.json")
427
+ if not summary_path.exists():
428
+ return {"available": False}
429
+ with open(summary_path) as f:
430
+ data = json.load(f)
431
+ # Check if it was an error
432
+ if "error" in data:
433
+ return {"available": True, "error": data["error"]}
434
+ # Add plot URLs
435
+ data["available"] = True
436
+ data["plots"] = {}
437
+ for name in ["before_after", "training_loss"]:
438
+ p = pathlib.Path(f"training/outputs/{name}.png")
439
+ if p.exists():
440
+ data["plots"][name] = f"/training-plots/{name}"
441
+ return data
442
+
443
+
444
+ @app.get("/training-plots/{name}")
445
+ def training_plot(name: str):
446
+ """Serve a training plot image."""
447
+ from fastapi.responses import FileResponse
448
+ allowed = {"before_after", "training_loss"}
449
+ if name not in allowed:
450
+ raise HTTPException(404, "Plot not found")
451
+ p = pathlib.Path(f"training/outputs/{name}.png")
452
+ if not p.exists():
453
+ raise HTTPException(404, "Plot not generated yet")
454
+ return FileResponse(str(p), media_type="image/png")
requirements-training.txt CHANGED
@@ -1,5 +1,4 @@
1
- # Training dependencies
2
- torch
3
  transformers>=4.51.3
4
  trl>=0.12.0,<1.0
5
  peft>=0.13.0
 
1
+ # Training dependencies (torch installed separately in Dockerfile with CUDA)
 
2
  transformers>=4.51.3
3
  trl>=0.12.0,<1.0
4
  peft>=0.13.0
run_training.py CHANGED
@@ -330,44 +330,6 @@ def run_grpo_training():
330
  return summary
331
 
332
 
333
- # ── Results Server ────────────────────────────────────────────────
334
- def serve_results():
335
- """Serve training results on port 7860."""
336
- from fastapi import FastAPI
337
- from fastapi.responses import FileResponse, JSONResponse
338
- import uvicorn
339
-
340
- app = FastAPI(title="OpenGrid Training Results")
341
-
342
- @app.get("/")
343
- def root():
344
- summary_path = Path("training/outputs/summary.json")
345
- if summary_path.exists():
346
- with open(summary_path) as f:
347
- return json.load(f)
348
- return {"status": "Training in progress or no results yet"}
349
-
350
- @app.get("/plots/before_after")
351
- def before_after():
352
- p = Path("training/outputs/before_after.png")
353
- if p.exists():
354
- return FileResponse(str(p), media_type="image/png")
355
- return JSONResponse({"error": "not ready"}, status_code=404)
356
-
357
- @app.get("/plots/loss")
358
- def loss():
359
- p = Path("training/outputs/training_loss.png")
360
- if p.exists():
361
- return FileResponse(str(p), media_type="image/png")
362
- return JSONResponse({"error": "not ready"}, status_code=404)
363
-
364
- @app.get("/health")
365
- def health():
366
- return {"status": "ok"}
367
-
368
- uvicorn.run(app, host="0.0.0.0", port=7860)
369
-
370
-
371
  # ── Main ──────────────────────────────────────────────────────────
372
  if __name__ == "__main__":
373
  try:
@@ -375,10 +337,14 @@ if __name__ == "__main__":
375
  except Exception as e:
376
  print(f"\nERROR during training: {e}")
377
  traceback.print_exc()
378
- # Save error so the results server can report it
379
  os.makedirs("training/outputs", exist_ok=True)
380
  with open("training/outputs/summary.json", "w") as f:
381
  json.dump({"error": str(e)}, f)
382
 
383
- print("\nStarting results server on port 7860...")
384
- serve_results()
 
 
 
 
 
330
  return summary
331
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  # ── Main ──────────────────────────────────────────────────────────
334
  if __name__ == "__main__":
335
  try:
 
337
  except Exception as e:
338
  print(f"\nERROR during training: {e}")
339
  traceback.print_exc()
340
+ # Save error so the UI can report it
341
  os.makedirs("training/outputs", exist_ok=True)
342
  with open("training/outputs/summary.json", "w") as f:
343
  json.dump({"error": str(e)}, f)
344
 
345
+ # Start the full UI server (not a mini results server)
346
+ # This serves the control room + training results on port 7860
347
+ print("\nTraining done. Starting full UI server on port 7860...")
348
+ import uvicorn
349
+ from app import app
350
+ uvicorn.run(app, host="0.0.0.0", port=7860)