"""TRL environment adapter for SQLEnv.""" from __future__ import annotations import collections try: from sql_env.models import SQLAction except ImportError: # pragma: no cover from models import SQLAction # type: ignore[no-redef] try: from sql_env.server.sql_environment import SQLEnvironment except ImportError: # pragma: no cover from server.sql_environment import SQLEnvironment # type: ignore[no-redef] def get_tool_definitions(env_cls: type | None = None) -> list[dict]: """Extract tool definitions from an environment class via introspection. Inspects public methods (excluding reset and dunder) to build the same JSON schema that TRL generates for environment_factory. This guarantees SFT and GRPO see identical tool definitions. """ import inspect if env_cls is None: env_cls = SQLEnvTRL _SKIP = {"reset", "reward"} tools = [] for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction): if name.startswith("_") or name in _SKIP: continue sig = inspect.signature(method) doc = inspect.getdoc(method) or "" # Split docstring into description and Args/Returns sections lines = doc.split("\n") description = lines[0].strip() if lines else name # Parse Args section for parameter descriptions param_descriptions: dict[str, str] = {} return_description = "" section = "" for line in lines[1:]: stripped = line.strip() if stripped.lower().startswith("args:"): section = "args" continue if stripped.lower().startswith("returns:"): section = "returns" continue if section == "args" and ":" in stripped: param_name, param_desc = stripped.split(":", 1) param_descriptions[param_name.strip()] = param_desc.strip() if section == "returns" and stripped: return_description = stripped # Build parameters schema from signature properties = {} required = [] for param_name, param in sig.parameters.items(): if param_name == "self": continue properties[param_name] = { "type": "string", "description": param_descriptions.get( param_name, f"{param_name} parameter." ), } if param.default is inspect.Parameter.empty: required.append(param_name) tool = { "type": "function", "function": { "name": name, "description": description, "parameters": { "type": "object", "properties": properties, "required": required, }, }, } if return_description: tool["function"]["return"] = { "type": "string", "description": return_description, } tools.append(tool) # Sort by name for deterministic ordering tools.sort(key=lambda t: t["function"]["name"]) return tools class _MinimalTokenizer: """Minimal tokenizer stub used only for SQLEnvironment initialization.""" def apply_chat_template( self, messages: list[dict[str, str]], *, tokenize: bool = False, add_generation_prompt: bool = False, ) -> str: """Return an empty rendered prompt string. Parameters ---------- messages Chat message payload. tokenize Unused tokenizer flag. add_generation_prompt Unused generation-prompt flag. Returns ------- str Always an empty string. """ del messages del tokenize del add_generation_prompt return "" _POST_EPISODE_PENALTY = -0.3 # Adapter-level repeat penalty (on top of environment's -0.03 in reward.py). # Intentionally harsher: the env penalty shapes per-step reward, while this # penalty shapes the episode-level signal that GRPO sees. _REPEAT_PENALTY = -0.2 class SQLEnvTRL: """TRL-compatible adapter shell for SQLEnv.""" _questions_path: str | None = None _db_dir: str | None = None _step_budget: int = 10 @classmethod def _configure( cls, *, questions_path: str, db_dir: str, step_budget: int = 10, ) -> None: """Store class-level adapter configuration before TRL instantiation.""" if not questions_path: raise ValueError("questions_path must be a non-empty string") if not db_dir: raise ValueError("db_dir must be a non-empty string") if step_budget <= 0: raise ValueError("step_budget must be a positive integer") cls._questions_path = questions_path cls._db_dir = db_dir cls._step_budget = step_budget def __init__(self) -> None: """Initialize a configured SQLEnvironment-backed adapter instance.""" if self.__class__._questions_path is None or self.__class__._db_dir is None: raise RuntimeError( "SQLEnvTRL.configure() must be called before SQLEnvTRL()" ) tokenizer = _MinimalTokenizer() self._env = SQLEnvironment( questions_path=self.__class__._questions_path, db_dir=self.__class__._db_dir, tokenizer=tokenizer, step_budget=self.__class__._step_budget, ) self.reward = 0.0 self._done = False self._recent_calls: collections.deque[tuple[str, str]] = collections.deque( maxlen=3 ) self._repeat_count = 0 def reset(self, **kwargs: object) -> str | None: """Initialize a new episode and return the initial observation text. TRL passes dataset columns as kwargs. If ``question_text`` is present, the environment resets to the matching question (and therefore the correct database). Args: kwargs: Dataset columns from TRL, may include question_text. Returns: Short observation hint for the language model, or None. """ self.reward = 0.0 self._done = False self._recent_calls.clear() self._repeat_count = 0 question_text = kwargs.get("question_text") if question_text and isinstance(question_text, str): # Filter to the matching question so the right DB loads original = list(self._env.questions) matching = [ q for q in self._env.questions if q.question_text == question_text ] if matching: self._env.questions = matching try: self._obs = self._env.reset(seed=None) finally: self._env.questions = original else: self._obs = self._env.reset(seed=None) else: self._obs = self._env.reset(seed=None) # Return concise hint — full observation via describe/sample tables = [] for line in (self._obs.schema_info or "").split("\n"): stripped = line.strip().lstrip("- ").strip() if stripped and stripped != "Available tables:": tables.append(stripped) return ( f"Tables: {', '.join(tables)}. " "Use describe, sample, query, and answer tools." ) def _dispatch(self, action_type: str, argument: str) -> str: """Execute an action with repeat detection and reward accumulation.""" if self._done: self.reward += _POST_EPISODE_PENALTY raise ValueError("Episode is over") call_key = (action_type.lower(), argument) if call_key in self._recent_calls: self.reward += _REPEAT_PENALTY self._repeat_count += 1 self._recent_calls.append(call_key) observation = self._env.step( SQLAction(action_type=action_type, argument=argument) ) if observation.reward is not None: self.reward += observation.reward self._done = observation.done if observation.result: return observation.result if observation.error: return f"Error: {observation.error}" return "No output." def describe(self, table_name: str) -> str: """Show schema details for a database table. Args: table_name: Name of the table to describe. Returns: Schema information for the specified table. """ return self._dispatch("DESCRIBE", table_name) def sample(self, table_name: str) -> str: """Show sample rows from a database table. Args: table_name: Name of the table to sample. Returns: Sample row output for the specified table. """ return self._dispatch("SAMPLE", table_name) def query(self, sql: str) -> str: """Execute a read-only SQL query. Args: sql: SELECT SQL statement to execute. Returns: Query output text. """ return self._dispatch("QUERY", sql) def answer(self, value: str) -> str: """Submit a final answer for the active episode. Args: value: Final answer value to submit. Returns: Feedback text for the submitted answer. """ return self._dispatch("ANSWER", value) def sql_env_reward_func(environments, **kwargs): """Read accumulated reward from each environment instance. Args: environments: Completed environment instances (passed by TRL). kwargs: Additional TRL reward kwargs (ignored). Returns: Reward values aligned with input environment order. """ return [float(env.reward) for env in environments]