Spaces:
Running
Running
Siddharaj Shirke committed on
Commit ·
c7e793a
1
Parent(s): df97e68
feat: enable persistent model checkpoint discovery and upload API for HF storage
Browse files- .env.example +5 -0
- app/main.py +102 -8
- requirements.txt +1 -0
.env.example
CHANGED
|
@@ -44,3 +44,8 @@ LLM_CALL_DELAY=12.0
|
|
| 44 |
# For Hugging Face persistent storage, set OPENENV_DATA_DIR=/data/openenv_rl
|
| 45 |
STORAGE_ENABLED=true
|
| 46 |
OPENENV_DATA_DIR=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
# For Hugging Face persistent storage, set OPENENV_DATA_DIR=/data/openenv_rl
|
| 45 |
STORAGE_ENABLED=true
|
| 46 |
OPENENV_DATA_DIR=
|
| 47 |
+
|
| 48 |
+
# Optional CSV list of extra model directories to scan for RL checkpoints.
|
| 49 |
+
# Example:
|
| 50 |
+
# OPENENV_MODEL_SEARCH_DIRS=/data/openenv_rl/results/best_model/phase1,/data/openenv_rl/results/best_model/phase2
|
| 51 |
+
OPENENV_MODEL_SEARCH_DIRS=
|
app/main.py
CHANGED
|
@@ -41,7 +41,7 @@ import time
|
|
| 41 |
from typing import Any, Literal
|
| 42 |
from uuid import uuid4
|
| 43 |
|
| 44 |
-
from fastapi import APIRouter, Body, FastAPI, HTTPException, Query, status
|
| 45 |
from fastapi.middleware.cors import CORSMiddleware
|
| 46 |
from fastapi.routing import APIRoute
|
| 47 |
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
|
@@ -427,12 +427,55 @@ def _log_line_text(value: Any) -> str:
|
|
| 427 |
|
| 428 |
|
| 429 |
def _phase_model_dirs() -> list[Path]:
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
]
|
| 435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
def _discover_phase12_zip_models() -> list[Path]:
|
| 438 |
discovered: list[Path] = []
|
|
@@ -446,6 +489,14 @@ def _discover_phase12_zip_models() -> list[Path]:
|
|
| 446 |
return unique
|
| 447 |
|
| 448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
def _phase_from_model_path(path: Path) -> int:
|
| 450 |
parent = path.parent.name.lower()
|
| 451 |
if parent == "phase1":
|
|
@@ -1751,7 +1802,7 @@ def api_workflow_components() -> WorkflowComponentsResponse:
|
|
| 1751 |
repo_root = REPO_ROOT
|
| 1752 |
baseline_f = repo_root / "baseline_openai.py"
|
| 1753 |
inference_f = repo_root / "inference.py"
|
| 1754 |
-
phase2_model =
|
| 1755 |
components = [
|
| 1756 |
WorkflowComponentStatus(
|
| 1757 |
component="baseline_openai.py",
|
|
@@ -1770,8 +1821,12 @@ def api_workflow_components() -> WorkflowComponentsResponse:
|
|
| 1770 |
WorkflowComponentStatus(
|
| 1771 |
component="phase2_final.zip",
|
| 1772 |
description="Trained Phase 2 PPO checkpoint used for local RL evaluation/execution.",
|
| 1773 |
-
available=phase2_model
|
| 1774 |
-
command=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1775 |
),
|
| 1776 |
WorkflowComponentStatus(
|
| 1777 |
component="openenv-api",
|
|
@@ -1939,6 +1994,45 @@ def api_rl_models() -> RLModelsResponse:
|
|
| 1939 |
return RLModelsResponse(models=models)
|
| 1940 |
|
| 1941 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1942 |
@api.get(
|
| 1943 |
"/rl/models",
|
| 1944 |
response_model=list[ModelInfo],
|
|
|
|
| 41 |
from typing import Any, Literal
|
| 42 |
from uuid import uuid4
|
| 43 |
|
| 44 |
+
from fastapi import APIRouter, Body, FastAPI, File, HTTPException, Query, UploadFile, status
|
| 45 |
from fastapi.middleware.cors import CORSMiddleware
|
| 46 |
from fastapi.routing import APIRoute
|
| 47 |
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
|
|
|
| 427 |
|
| 428 |
|
| 429 |
def _phase_model_dirs() -> list[Path]:
    """
    Discover model directories from multiple roots.

    Priority:
      1) Explicit OPENENV_MODEL_SEARCH_DIRS (CSV of absolute/relative paths)
      2) Persistent storage root OPENENV_DATA_DIR (HF bucket mount recommended)
      3) Repo-local results/best_model

    Relative paths are resolved against REPO_ROOT. An unset, empty, or
    whitespace-only OPENENV_DATA_DIR falls back to /data/openenv_rl.

    Returns:
        Candidate directories in priority order with duplicates removed.
        The directories are not required to exist.
    """
    configured: list[Path] = []
    for raw in (os.getenv("OPENENV_MODEL_SEARCH_DIRS") or "").split(","):
        entry = raw.strip()
        if not entry:
            continue
        path = Path(entry)
        configured.append(path if path.is_absolute() else (REPO_ROOT / path).resolve())

    # Strip *before* applying the default: a whitespace-only value would
    # otherwise become Path("") and silently resolve to REPO_ROOT itself.
    data_root_raw = (os.getenv("OPENENV_DATA_DIR") or "").strip() or "/data/openenv_rl"
    data_root = Path(data_root_raw)
    if not data_root.is_absolute():
        data_root = (REPO_ROOT / data_root).resolve()

    repo_base = REPO_ROOT / "results" / "best_model"
    data_base = data_root / "results" / "best_model"

    candidates = [
        *configured,
        data_base / "phase1",
        data_base / "phase2",
        data_root / "best_model" / "phase1",
        data_root / "best_model" / "phase2",
        repo_base / "phase1",
        repo_base / "phase2",
    ]

    # Preserve order, remove duplicates. Existing paths are compared by their
    # resolved form so two spellings of the same directory collapse into one.
    deduped: list[Path] = []
    seen: set[str] = set()
    for candidate in candidates:
        key = str(candidate.resolve()) if candidate.exists() else str(candidate)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(candidate)
    return deduped
|
| 478 |
+
|
| 479 |
|
| 480 |
def _discover_phase12_zip_models() -> list[Path]:
|
| 481 |
discovered: list[Path] = []
|
|
|
|
| 489 |
return unique
|
| 490 |
|
| 491 |
|
| 492 |
+
def _model_storage_base_dir() -> Path:
    """
    Return the directory under which uploaded checkpoints are stored.

    The root comes from OPENENV_DATA_DIR; an unset, empty, or whitespace-only
    value falls back to /data/openenv_rl. A relative root is resolved against
    REPO_ROOT.

    Returns:
        ``<data root>/results/best_model``. The directory is not created here.
    """
    # Strip *before* applying the default: a whitespace-only value would
    # otherwise become Path("") and silently resolve to REPO_ROOT itself.
    data_root_raw = (os.getenv("OPENENV_DATA_DIR") or "").strip() or "/data/openenv_rl"
    data_root = Path(data_root_raw)
    if not data_root.is_absolute():
        data_root = (REPO_ROOT / data_root).resolve()
    return data_root / "results" / "best_model"
|
| 498 |
+
|
| 499 |
+
|
| 500 |
def _phase_from_model_path(path: Path) -> int:
|
| 501 |
parent = path.parent.name.lower()
|
| 502 |
if parent == "phase1":
|
|
|
|
| 1802 |
repo_root = REPO_ROOT
|
| 1803 |
baseline_f = repo_root / "baseline_openai.py"
|
| 1804 |
inference_f = repo_root / "inference.py"
|
| 1805 |
+
phase2_model = next((p for p in _discover_phase12_zip_models() if _phase_from_model_path(p) == 2), None)
|
| 1806 |
components = [
|
| 1807 |
WorkflowComponentStatus(
|
| 1808 |
component="baseline_openai.py",
|
|
|
|
| 1821 |
WorkflowComponentStatus(
|
| 1822 |
component="phase2_final.zip",
|
| 1823 |
description="Trained Phase 2 PPO checkpoint used for local RL evaluation/execution.",
|
| 1824 |
+
available=phase2_model is not None,
|
| 1825 |
+
command=(
|
| 1826 |
+
f".\\.venv\\3.11\\Scripts\\python.exe -m rl.evaluate --model {phase2_model} --episodes 3 --model-type maskable"
|
| 1827 |
+
if phase2_model is not None
|
| 1828 |
+
else r".\.venv\3.11\Scripts\python.exe -m rl.evaluate --model results/best_model/phase2_final.zip --episodes 3 --model-type maskable"
|
| 1829 |
+
),
|
| 1830 |
),
|
| 1831 |
WorkflowComponentStatus(
|
| 1832 |
component="openenv-api",
|
|
|
|
| 1994 |
return RLModelsResponse(models=models)
|
| 1995 |
|
| 1996 |
|
| 1997 |
+
@api.post("/rl_models/upload", summary="Upload RL checkpoint zip to persistent storage")
async def api_rl_model_upload(
    phase: int = Query(..., ge=1, le=2, description="Model phase bucket (1 or 2)"),
    file: UploadFile = File(..., description="Checkpoint zip file"),
) -> dict[str, Any]:
    """
    Store an uploaded checkpoint zip under ``<storage base>/phase<phase>/``.

    The upload is streamed in 1 MiB chunks into a temporary ``.part`` file in
    the target directory and renamed into place only once fully written, so a
    failed or interrupted upload never leaves a truncated ``*.zip`` behind.
    An existing file with the same name is replaced.

    Args:
        phase: Phase bucket, 1 or 2 (validated by the query constraints).
        file: Uploaded checkpoint; its filename must end with ``.zip``.

    Returns:
        A dict with ``saved``, ``phase``, ``filename``, ``size_bytes`` and
        ``path`` describing the stored file.

    Raises:
        HTTPException: 422 if the filename is missing, does not end with
            ``.zip``, or the uploaded file is empty.
    """
    name = (file.filename or "").strip()
    if not name:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="Missing filename.")
    if not name.lower().endswith(".zip"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail="Only .zip checkpoint files are accepted.",
        )

    # Keep only the final path component. Backslashes are normalised first so
    # a Windows-style client path is also reduced to its basename on POSIX.
    safe_name = Path(name.replace("\\", "/")).name
    if not safe_name:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="Missing filename.")

    base_dir = _model_storage_base_dir()
    target_dir = base_dir / f"phase{phase}"
    target_path = target_dir / safe_name
    # The temporary name does not end in ".zip", so it cannot be mistaken for
    # a finished checkpoint while the upload is still in progress.
    tmp_path = target_dir / f".{safe_name}.{uuid4().hex}.part"

    # NOTE(review): no upload size cap is enforced here, and the disk writes
    # are synchronous inside this async handler.
    total = 0
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        with tmp_path.open("wb") as out:
            while chunk := await file.read(1024 * 1024):
                out.write(chunk)
                total += len(chunk)
        if total == 0:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
                detail="Uploaded file is empty.",
            )
        # Same-directory rename: the final path only ever holds a complete file.
        tmp_path.replace(target_path)
    finally:
        try:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp_path.unlink(missing_ok=True)
        finally:
            await file.close()

    return {
        "saved": True,
        "phase": phase,
        "filename": safe_name,
        "size_bytes": total,
        "path": str(target_path),
    }
|
| 2034 |
+
|
| 2035 |
+
|
| 2036 |
@api.get(
|
| 2037 |
"/rl/models",
|
| 2038 |
response_model=list[ModelInfo],
|
requirements.txt
CHANGED
|
@@ -12,3 +12,4 @@ anyio>=4.0,<5.0
|
|
| 12 |
PyYAML>=6.0,<7.0
|
| 13 |
sse-starlette>=2.1,<3.0
|
| 14 |
numpy>=1.26,<3.0
|
|
|
|
|
|
| 12 |
PyYAML>=6.0,<7.0
|
| 13 |
sse-starlette>=2.1,<3.0
|
| 14 |
numpy>=1.26,<3.0
|
| 15 |
+
python-multipart>=0.0.9,<1.0
|