Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import os | |
| import sqlite3 | |
| import time | |
| from pathlib import Path | |
| from threading import Lock | |
| from typing import Any | |
| from uuid import uuid4 | |
| def _now() -> float: | |
| return time.time() | |
| def _as_json(payload: dict[str, Any]) -> str: | |
| return json.dumps(payload, separators=(",", ":"), ensure_ascii=True) | |
| def _from_json(payload: str) -> dict[str, Any]: | |
| data = json.loads(payload) | |
| return data if isinstance(data, dict) else {} | |
| def _resolve_data_dir(repo_root: Path) -> Path: | |
| configured = os.getenv("OPENENV_DATA_DIR") or os.getenv("STORAGE_DATA_DIR") | |
| if configured: | |
| return Path(configured).expanduser().resolve() | |
| if Path("/data").exists(): | |
| return Path("/data/openenv_rl").resolve() | |
| return (repo_root / "outputs" / "persist").resolve() | |
| def _default_fallback_data_dirs(repo_root: Path) -> list[Path]: | |
| return [ | |
| (repo_root / "outputs" / "persist").resolve(), | |
| Path("/tmp/openenv_rl").resolve(), | |
| ] | |
| def _storage_enabled() -> bool: | |
| raw = str(os.getenv("STORAGE_ENABLED", "true")).strip().lower() | |
| return raw not in {"0", "false", "no", "off"} | |
class PersistenceStore:
    """SQLite-backed persistence for training jobs, simulation runs and comparison runs.

    Each record is stored as a JSON document (``payload_json``) next to a few
    scalar columns used for lookup and ordering. If storage is disabled through
    ``STORAGE_ENABLED`` or no writable directory can be found, ``enabled`` is
    False and every public method returns an empty result without touching disk.
    """

    def __init__(self, repo_root: Path) -> None:
        """Resolve storage paths and, when enabled, create directories and the schema."""
        self.repo_root = repo_root.resolve()
        self.enabled = _storage_enabled()
        self.data_dir = _resolve_data_dir(self.repo_root)
        self.db_path = self.data_dir / "openenv_state.sqlite3"
        self.training_runs_dir = self.data_dir / "training_runs"
        # Serializes database access from threads that share this instance.
        self._lock = Lock()
        if not self.enabled:
            return
        self._initialize_storage_dirs()

    def _initialize_storage_dirs(self) -> None:
        """Adopt the first candidate directory that can hold the database.

        Candidates are the configured ``data_dir`` followed by the default
        fallbacks (deduplicated). On success ``data_dir``, ``db_path`` and
        ``training_runs_dir`` point at the chosen directory and the schema
        exists. If every candidate fails, storage is disabled and a message is
        printed instead of raising.
        """
        candidates: list[Path] = [self.data_dir]
        for fallback in _default_fallback_data_dirs(self.repo_root):
            if fallback not in candidates:
                candidates.append(fallback)
        last_error: Exception | None = None
        for candidate in candidates:
            try:
                candidate.mkdir(parents=True, exist_ok=True)
                self.data_dir = candidate
                self.db_path = self.data_dir / "openenv_state.sqlite3"
                self.training_runs_dir = self.data_dir / "training_runs"
                self.training_runs_dir.mkdir(parents=True, exist_ok=True)
                self._init_schema()
                return
            except (OSError, sqlite3.Error) as exc:
                last_error = exc
        self.enabled = False
        # Keep service startup alive in restricted runtimes (e.g. HF Spaces without writable /data).
        print(
            f"[persistence] disabled: no writable storage directory. "
            f"requested={candidates[0]} last_error={last_error!r}"
        )

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to ``db_path`` whose rows are indexable by column name."""
        conn = sqlite3.connect(self.db_path, timeout=30)
        conn.row_factory = sqlite3.Row
        return conn

    def _open(self):
        """Return a context manager that yields a fresh connection and closes it on exit.

        Using ``sqlite3.Connection`` as a context manager only commits or rolls
        back; it never closes the connection. Write paths therefore use
        ``with self._open() as conn, conn:``. The outer manager closes the
        connection; the inner one commits on success or rolls back on error.
        """
        import contextlib

        return contextlib.closing(self._connect())

    @staticmethod
    def _stored_created_at(conn: sqlite3.Connection, query: str, key: str) -> float | None:
        """Run ``query`` (selecting ``created_at`` by one key) and return the value, or None."""
        row = conn.execute(query, (key,)).fetchone()
        return None if row is None else float(row["created_at"])

    def _init_schema(self) -> None:
        """Create the three record tables if they do not already exist."""
        with self._open() as conn, conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS training_jobs (
                job_id TEXT PRIMARY KEY,
                created_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                payload_json TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS simulation_runs (
                run_id TEXT PRIMARY KEY,
                created_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                task_id TEXT,
                agent_mode TEXT,
                status TEXT,
                payload_json TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS comparison_runs (
                comparison_id TEXT PRIMARY KEY,
                created_at REAL NOT NULL,
                updated_at REAL NOT NULL,
                task_id TEXT,
                payload_json TEXT NOT NULL
                );
                """
            )

    # Training jobs ---------------------------------------------------------

    def upsert_training_job(self, snapshot: dict[str, Any]) -> None:
        """Insert or update a training job snapshot keyed by ``snapshot["job_id"]``.

        Snapshots without a ``job_id`` are ignored. The snapshot is stored
        verbatim. On conflict only ``updated_at`` and the payload are
        replaced; the row's ``created_at`` column is kept.
        """
        if not self.enabled:
            return
        job_id = str(snapshot.get("job_id") or "")
        if not job_id:
            return
        created_at = float(snapshot.get("created_at") or _now())
        updated_at = float(snapshot.get("updated_at") or _now())
        with self._lock, self._open() as conn, conn:
            conn.execute(
                """
                INSERT INTO training_jobs (job_id, created_at, updated_at, payload_json)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(job_id) DO UPDATE SET
                updated_at = excluded.updated_at,
                payload_json = excluded.payload_json
                """,
                (job_id, created_at, updated_at, _as_json(snapshot)),
            )

    def list_training_jobs(self, limit: int = 500) -> list[dict[str, Any]]:
        """Return up to ``limit`` (at least 1) job snapshots, most recently updated first."""
        if not self.enabled:
            return []
        with self._lock, self._open() as conn:
            rows = conn.execute(
                """
                SELECT payload_json FROM training_jobs
                ORDER BY updated_at DESC
                LIMIT ?
                """,
                (max(1, int(limit)),),
            ).fetchall()
        return [_from_json(str(row["payload_json"])) for row in rows]

    def clear_training_jobs(self) -> int:
        """Delete every training job row and return how many were removed."""
        if not self.enabled:
            return 0
        with self._lock, self._open() as conn, conn:
            deleted = conn.execute("DELETE FROM training_jobs").rowcount
        return int(deleted or 0)

    def delete_training_job(self, job_id: str) -> int:
        """Delete the training job with ``job_id``; return the number of rows removed (0 or 1)."""
        if not self.enabled:
            return 0
        with self._lock, self._open() as conn, conn:
            deleted = conn.execute(
                "DELETE FROM training_jobs WHERE job_id = ?", (str(job_id),)
            ).rowcount
        return int(deleted or 0)

    # Simulation runs -------------------------------------------------------

    def upsert_simulation_run(
        self,
        *,
        run_id: str,
        task_id: str,
        agent_mode: str,
        status: str,
        payload: dict[str, Any],
    ) -> None:
        """Insert or update a simulation run.

        The stored payload is a copy of ``payload`` with ``run_id``,
        ``created_at``, ``updated_at``, ``task_id``, ``agent_mode`` and
        ``status`` set. The caller's dict is not modified. When ``payload``
        carries no ``created_at``, the value already stored for this run is
        reused, so the creation time stays stable across updates.
        """
        if not self.enabled:
            return
        now = _now()
        with self._lock, self._open() as conn, conn:
            # The row's created_at column is never overwritten on conflict; keep
            # the JSON copy (which is what readers get back) consistent with it.
            created_at = float(
                payload.get("created_at")
                or self._stored_created_at(
                    conn, "SELECT created_at FROM simulation_runs WHERE run_id = ?", run_id
                )
                or now
            )
            body = dict(payload)
            body["run_id"] = run_id
            body["created_at"] = created_at
            body["updated_at"] = now
            body["task_id"] = task_id
            body["agent_mode"] = agent_mode
            body["status"] = status
            conn.execute(
                """
                INSERT INTO simulation_runs (run_id, created_at, updated_at, task_id, agent_mode, status, payload_json)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(run_id) DO UPDATE SET
                updated_at = excluded.updated_at,
                task_id = excluded.task_id,
                agent_mode = excluded.agent_mode,
                status = excluded.status,
                payload_json = excluded.payload_json
                """,
                (
                    run_id,
                    created_at,
                    now,
                    task_id,
                    agent_mode,
                    status,
                    _as_json(body),
                ),
            )

    def list_simulation_runs(self, limit: int = 50) -> list[dict[str, Any]]:
        """Return up to ``limit`` (at least 1) runs, most recently updated first.

        When a run's payload has a list-valued ``trace``, that list is removed
        and summarized as ``trace_len`` and ``has_trace``.
        """
        if not self.enabled:
            return []
        with self._lock, self._open() as conn:
            rows = conn.execute(
                """
                SELECT payload_json FROM simulation_runs
                ORDER BY updated_at DESC
                LIMIT ?
                """,
                (max(1, int(limit)),),
            ).fetchall()
        out: list[dict[str, Any]] = []
        for row in rows:
            data = _from_json(str(row["payload_json"]))
            if isinstance(data.get("trace"), list):
                data["trace_len"] = len(data["trace"])
                data["has_trace"] = bool(data["trace"])
                data.pop("trace", None)
            out.append(data)
        return out

    def get_simulation_run(self, run_id: str) -> dict[str, Any] | None:
        """Return the full stored payload for ``run_id``, or None if it does not exist."""
        if not self.enabled:
            return None
        with self._lock, self._open() as conn:
            row = conn.execute(
                "SELECT payload_json FROM simulation_runs WHERE run_id = ?",
                (run_id,),
            ).fetchone()
        if row is None:
            return None
        return _from_json(str(row["payload_json"]))

    def clear_simulation_runs(self) -> int:
        """Delete every simulation run row and return how many were removed."""
        if not self.enabled:
            return 0
        with self._lock, self._open() as conn, conn:
            deleted = conn.execute("DELETE FROM simulation_runs").rowcount
        return int(deleted or 0)

    # Comparison runs -------------------------------------------------------

    def create_comparison_run(self, payload: dict[str, Any]) -> str | None:
        """Insert or update a comparison run and return its id (None when storage is disabled).

        Uses ``payload["comparison_id"]`` when present, otherwise a new UUID4.
        When ``payload`` carries no ``created_at``, the value already stored for
        this id is reused, so the creation time stays stable across updates.
        """
        if not self.enabled:
            return None
        comparison_id = str(payload.get("comparison_id") or uuid4())
        now = _now()
        task_id = str(payload.get("task_id") or "")
        with self._lock, self._open() as conn, conn:
            created_at = float(
                payload.get("created_at")
                or self._stored_created_at(
                    conn,
                    "SELECT created_at FROM comparison_runs WHERE comparison_id = ?",
                    comparison_id,
                )
                or now
            )
            body = dict(payload)
            body["comparison_id"] = comparison_id
            body["created_at"] = created_at
            body["updated_at"] = now
            conn.execute(
                """
                INSERT INTO comparison_runs (comparison_id, created_at, updated_at, task_id, payload_json)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(comparison_id) DO UPDATE SET
                updated_at = excluded.updated_at,
                task_id = excluded.task_id,
                payload_json = excluded.payload_json
                """,
                (
                    comparison_id,
                    created_at,
                    now,
                    task_id,
                    _as_json(body),
                ),
            )
        return comparison_id

    def list_comparison_runs(self, limit: int = 50) -> list[dict[str, Any]]:
        """Return up to ``limit`` (at least 1) comparison runs, most recently updated first."""
        if not self.enabled:
            return []
        with self._lock, self._open() as conn:
            rows = conn.execute(
                """
                SELECT payload_json FROM comparison_runs
                ORDER BY updated_at DESC
                LIMIT ?
                """,
                (max(1, int(limit)),),
            ).fetchall()
        return [_from_json(str(row["payload_json"])) for row in rows]

    def get_comparison_run(self, comparison_id: str) -> dict[str, Any] | None:
        """Return the stored payload for ``comparison_id``, or None if it does not exist."""
        if not self.enabled:
            return None
        with self._lock, self._open() as conn:
            row = conn.execute(
                "SELECT payload_json FROM comparison_runs WHERE comparison_id = ?",
                (comparison_id,),
            ).fetchone()
        if row is None:
            return None
        return _from_json(str(row["payload_json"]))

    def clear_comparison_runs(self) -> int:
        """Delete every comparison run row and return how many were removed."""
        if not self.enabled:
            return 0
        with self._lock, self._open() as conn, conn:
            deleted = conn.execute("DELETE FROM comparison_runs").rowcount
        return int(deleted or 0)