import httpx import json import os from typing import Any, Dict, Optional from dataclasses import dataclass @dataclass class NL2SQLAction: query: str @dataclass class NL2SQLObservation: question: str schema_context: str task_name: str last_query: str last_result: list last_error: Optional[str] result_columns: list step: int max_steps: int done: bool reward: float score: float @dataclass class StepResult: observation: NL2SQLObservation reward: float done: bool class NL2SQLEnv: def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url.rstrip("/") self.client = httpx.AsyncClient(base_url=self.base_url, timeout=120.0) async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.client.aclose() async def reset(self) -> StepResult: task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter") # Send task_name both ways — some openenv-core versions read from body, # some from the action wrapper. Belt-and-suspenders. payload = {"task_name": task_name} resp = await self.client.post("/reset", json=payload) resp.raise_for_status() return self._parse_result(resp.json()) async def step(self, action: NL2SQLAction) -> StepResult: # CRITICAL FIX: The server's action_cls=NL2SQLAction expects the payload # wrapped in {"action": {"query": ...}} per OpenEnv protocol. # Sending {"query": ...} at the top level bypasses action parsing → 0 reward. payload = {"action": {"query": action.query}} resp = await self.client.post("/step", json=payload) resp.raise_for_status() return self._parse_result(resp.json()) def _parse_result(self, payload: Dict[str, Any]) -> StepResult: obs_data = payload.get("observation", payload) # Extract reward — check top-level payload first (OpenEnv puts it there), # then fall back to nested observation dict. raw_reward = payload.get("reward") if raw_reward is None: raw_reward = obs_data.get("reward") safe_reward = float(raw_reward) if raw_reward is not None else 0.0 safe_score = float(obs_data.get("score") or 0.0) safe_done = bool(payload.get("done") or obs_data.get("done") or False) obs = NL2SQLObservation( question=obs_data.get("question", ""), schema_context=obs_data.get("schema_context", ""), task_name=obs_data.get("task_name", ""), last_query=obs_data.get("last_query", ""), last_result=obs_data.get("last_result", []), last_error=obs_data.get("last_error"), result_columns=obs_data.get("result_columns", []), step=obs_data.get("step", 0), max_steps=obs_data.get("max_steps", 5), done=safe_done, reward=safe_reward, score=safe_score, ) return StepResult( observation=obs, reward=safe_reward, done=safe_done, )