queryforge / client.py
Prithvigg's picture
Upload folder using huggingface_hub
3867c62 verified
"""QueryForge Environment Client."""
from typing import Any, Dict, List, Optional
import httpx
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
try:
from .models import SQLAction, SQLObservation, TaskSpec
except ImportError:
from models import SQLAction, SQLObservation, TaskSpec
class QueryforgeEnv(EnvClient[SQLAction, SQLObservation, State]):
"""
Client for the QueryForge SQL Debugger & Optimiser environment.
Maintains a persistent WebSocket connection to the environment server.
Each client instance has its own dedicated session (isolated task state).
Example:
>>> with QueryforgeEnv(base_url="http://localhost:8000") as env:
... obs = env.reset()
... print(obs.task_title)
... print(obs.task_description)
...
... result = env.step(SQLAction(sql="SELECT name, age FROM users WHERE age > 30"))
... print(result.reward, result.observation.feedback)
Example with Docker:
>>> env = QueryforgeEnv.from_docker_image("queryforge-env:latest")
>>> try:
... obs = env.reset()
... result = env.step(SQLAction(sql="SELECT ..."))
... finally:
... env.close()
"""
def _step_payload(self, action: SQLAction) -> Dict:
return {"sql": action.sql}
def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
obs_data = payload.get("observation", {})
observation = SQLObservation(
task_id=obs_data.get("task_id", ""),
task_level=obs_data.get("task_level", ""),
task_title=obs_data.get("task_title", ""),
task_description=obs_data.get("task_description", ""),
syntax_valid=obs_data.get("syntax_valid", False),
execution_success=obs_data.get("execution_success", False),
execution_error=obs_data.get("execution_error"),
rows_returned=obs_data.get("rows_returned", 0),
feedback=obs_data.get("feedback", ""),
hint=obs_data.get("hint", ""),
attempt=obs_data.get("attempt", 0),
best_score=obs_data.get("best_score", 0.0),
done=payload.get("done", False),
reward=payload.get("reward", 0.0),
metadata=obs_data.get("metadata", {}),
)
return StepResult(
observation=observation,
reward=payload.get("reward", 0.0),
done=payload.get("done", False),
)
def _parse_state(self, payload: Dict) -> State:
return State(
episode_id=payload.get("episode_id"),
step_count=payload.get("step_count", 0),
)
# ── Task Registry helpers ─────────────────────────────────────────────────
def register_task(self, spec: TaskSpec) -> Dict[str, Any]:
"""Register a custom task on the server. Returns the server response dict."""
resp = httpx.post(
f"{self.base_url}/tasks",
json=spec.model_dump(),
timeout=10,
)
resp.raise_for_status()
return resp.json()
def list_tasks(self) -> List[Dict[str, Any]]:
"""Return all registered tasks (built-in + custom) as a list of dicts."""
resp = httpx.get(f"{self.base_url}/tasks", timeout=10)
resp.raise_for_status()
return resp.json()
def delete_task(self, task_id: str) -> Dict[str, Any]:
"""Remove a custom task by ID. Raises httpx.HTTPStatusError on 403/404."""
resp = httpx.delete(f"{self.base_url}/tasks/{task_id}", timeout=10)
resp.raise_for_status()
return resp.json()