Siddharaj Shirke commited on
Commit
c7e793a
·
1 Parent(s): df97e68

feat: enable persistent model checkpoint discovery and upload API for HF storage

Browse files
Files changed (3) hide show
  1. .env.example +5 -0
  2. app/main.py +102 -8
  3. requirements.txt +1 -0
.env.example CHANGED
@@ -44,3 +44,8 @@ LLM_CALL_DELAY=12.0
44
  # For Hugging Face persistent storage, set OPENENV_DATA_DIR=/data/openenv_rl
45
  STORAGE_ENABLED=true
46
  OPENENV_DATA_DIR=
 
 
 
 
 
 
44
  # For Hugging Face persistent storage, set OPENENV_DATA_DIR=/data/openenv_rl
45
  STORAGE_ENABLED=true
46
  OPENENV_DATA_DIR=
47
+
48
+ # Optional CSV list of extra model directories to scan for RL checkpoints.
49
+ # Example:
50
+ # OPENENV_MODEL_SEARCH_DIRS=/data/openenv_rl/results/best_model/phase1,/data/openenv_rl/results/best_model/phase2
51
+ OPENENV_MODEL_SEARCH_DIRS=
app/main.py CHANGED
@@ -41,7 +41,7 @@ import time
41
  from typing import Any, Literal
42
  from uuid import uuid4
43
 
44
- from fastapi import APIRouter, Body, FastAPI, HTTPException, Query, status
45
  from fastapi.middleware.cors import CORSMiddleware
46
  from fastapi.routing import APIRoute
47
  from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
@@ -427,12 +427,55 @@ def _log_line_text(value: Any) -> str:
427
 
428
 
429
  def _phase_model_dirs() -> list[Path]:
430
- base = REPO_ROOT / "results" / "best_model"
431
- return [
432
- base / "phase1",
433
- base / "phase2",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  ]
435
 
 
 
 
 
 
 
 
 
 
 
 
436
 
437
  def _discover_phase12_zip_models() -> list[Path]:
438
  discovered: list[Path] = []
@@ -446,6 +489,14 @@ def _discover_phase12_zip_models() -> list[Path]:
446
  return unique
447
 
448
 
 
 
 
 
 
 
 
 
449
  def _phase_from_model_path(path: Path) -> int:
450
  parent = path.parent.name.lower()
451
  if parent == "phase1":
@@ -1751,7 +1802,7 @@ def api_workflow_components() -> WorkflowComponentsResponse:
1751
  repo_root = REPO_ROOT
1752
  baseline_f = repo_root / "baseline_openai.py"
1753
  inference_f = repo_root / "inference.py"
1754
- phase2_model = repo_root / "results" / "best_model" / "phase2_final.zip"
1755
  components = [
1756
  WorkflowComponentStatus(
1757
  component="baseline_openai.py",
@@ -1770,8 +1821,12 @@ def api_workflow_components() -> WorkflowComponentsResponse:
1770
  WorkflowComponentStatus(
1771
  component="phase2_final.zip",
1772
  description="Trained Phase 2 PPO checkpoint used for local RL evaluation/execution.",
1773
- available=phase2_model.exists(),
1774
- command=r".\.venv\3.11\Scripts\python.exe -m rl.evaluate --model results/best_model/phase2_final.zip --episodes 3 --model-type maskable",
 
 
 
 
1775
  ),
1776
  WorkflowComponentStatus(
1777
  component="openenv-api",
@@ -1939,6 +1994,45 @@ def api_rl_models() -> RLModelsResponse:
1939
  return RLModelsResponse(models=models)
1940
 
1941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1942
  @api.get(
1943
  "/rl/models",
1944
  response_model=list[ModelInfo],
 
41
  from typing import Any, Literal
42
  from uuid import uuid4
43
 
44
+ from fastapi import APIRouter, Body, FastAPI, File, HTTPException, Query, UploadFile, status
45
  from fastapi.middleware.cors import CORSMiddleware
46
  from fastapi.routing import APIRoute
47
  from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
 
427
 
428
 
429
  def _phase_model_dirs() -> list[Path]:
430
+ """
431
+ Discover model directories from multiple roots.
432
+
433
+ Priority:
434
+ 1) Explicit OPENENV_MODEL_SEARCH_DIRS (CSV of absolute/relative paths)
435
+ 2) Persistent storage root OPENENV_DATA_DIR (HF bucket mount recommended)
436
+ 3) Repo-local results/best_model
437
+ """
438
+ configured_dirs = (os.getenv("OPENENV_MODEL_SEARCH_DIRS") or "").strip()
439
+ configured: list[Path] = []
440
+ if configured_dirs:
441
+ for raw in configured_dirs.split(","):
442
+ s = raw.strip()
443
+ if not s:
444
+ continue
445
+ p = Path(s)
446
+ if not p.is_absolute():
447
+ p = (REPO_ROOT / p).resolve()
448
+ configured.append(p)
449
+
450
+ data_root_raw = (os.getenv("OPENENV_DATA_DIR") or "/data/openenv_rl").strip()
451
+ data_root = Path(data_root_raw)
452
+ if not data_root.is_absolute():
453
+ data_root = (REPO_ROOT / data_root).resolve()
454
+
455
+ repo_base = REPO_ROOT / "results" / "best_model"
456
+ data_base = data_root / "results" / "best_model"
457
+
458
+ candidates = [
459
+ *configured,
460
+ data_base / "phase1",
461
+ data_base / "phase2",
462
+ data_root / "best_model" / "phase1",
463
+ data_root / "best_model" / "phase2",
464
+ repo_base / "phase1",
465
+ repo_base / "phase2",
466
  ]
467
 
468
+ # Preserve order, remove duplicates.
469
+ deduped: list[Path] = []
470
+ seen: set[str] = set()
471
+ for p in candidates:
472
+ key = str(p.resolve()) if p.exists() else str(p)
473
+ if key in seen:
474
+ continue
475
+ seen.add(key)
476
+ deduped.append(p)
477
+ return deduped
478
+
479
 
480
  def _discover_phase12_zip_models() -> list[Path]:
481
  discovered: list[Path] = []
 
489
  return unique
490
 
491
 
492
+ def _model_storage_base_dir() -> Path:
493
+ data_root_raw = (os.getenv("OPENENV_DATA_DIR") or "/data/openenv_rl").strip()
494
+ data_root = Path(data_root_raw)
495
+ if not data_root.is_absolute():
496
+ data_root = (REPO_ROOT / data_root).resolve()
497
+ return data_root / "results" / "best_model"
498
+
499
+
500
  def _phase_from_model_path(path: Path) -> int:
501
  parent = path.parent.name.lower()
502
  if parent == "phase1":
 
1802
  repo_root = REPO_ROOT
1803
  baseline_f = repo_root / "baseline_openai.py"
1804
  inference_f = repo_root / "inference.py"
1805
+ phase2_model = next((p for p in _discover_phase12_zip_models() if _phase_from_model_path(p) == 2), None)
1806
  components = [
1807
  WorkflowComponentStatus(
1808
  component="baseline_openai.py",
 
1821
  WorkflowComponentStatus(
1822
  component="phase2_final.zip",
1823
  description="Trained Phase 2 PPO checkpoint used for local RL evaluation/execution.",
1824
+ available=phase2_model is not None,
1825
+ command=(
1826
+ f".\\.venv\\3.11\\Scripts\\python.exe -m rl.evaluate --model {phase2_model} --episodes 3 --model-type maskable"
1827
+ if phase2_model is not None
1828
+ else r".\.venv\3.11\Scripts\python.exe -m rl.evaluate --model results/best_model/phase2_final.zip --episodes 3 --model-type maskable"
1829
+ ),
1830
  ),
1831
  WorkflowComponentStatus(
1832
  component="openenv-api",
 
1994
  return RLModelsResponse(models=models)
1995
 
1996
 
1997
+ @api.post("/rl_models/upload", summary="Upload RL checkpoint zip to persistent storage")
1998
+ async def api_rl_model_upload(
1999
+ phase: int = Query(..., ge=1, le=2, description="Model phase bucket (1 or 2)"),
2000
+ file: UploadFile = File(..., description="Checkpoint zip file"),
2001
+ ) -> dict[str, Any]:
2002
+ name = (file.filename or "").strip()
2003
+ if not name:
2004
+ raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="Missing filename.")
2005
+ if not name.lower().endswith(".zip"):
2006
+ raise HTTPException(
2007
+ status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
2008
+ detail="Only .zip checkpoint files are accepted.",
2009
+ )
2010
+
2011
+ safe_name = Path(name).name
2012
+ base_dir = _model_storage_base_dir()
2013
+ target_dir = base_dir / f"phase{phase}"
2014
+ target_dir.mkdir(parents=True, exist_ok=True)
2015
+ target_path = target_dir / safe_name
2016
+
2017
+ total = 0
2018
+ with target_path.open("wb") as out:
2019
+ while True:
2020
+ chunk = await file.read(1024 * 1024)
2021
+ if not chunk:
2022
+ break
2023
+ out.write(chunk)
2024
+ total += len(chunk)
2025
+ await file.close()
2026
+
2027
+ return {
2028
+ "saved": True,
2029
+ "phase": phase,
2030
+ "filename": safe_name,
2031
+ "size_bytes": total,
2032
+ "path": str(target_path),
2033
+ }
2034
+
2035
+
2036
  @api.get(
2037
  "/rl/models",
2038
  response_model=list[ModelInfo],
requirements.txt CHANGED
@@ -12,3 +12,4 @@ anyio>=4.0,<5.0
12
  PyYAML>=6.0,<7.0
13
  sse-starlette>=2.1,<3.0
14
  numpy>=1.26,<3.0
 
 
12
  PyYAML>=6.0,<7.0
13
  sse-starlette>=2.1,<3.0
14
  numpy>=1.26,<3.0
15
+ python-multipart>=0.0.9,<1.0