""" nl2sql-bench/server/tasks/base.py ================================== Abstract base for all NL2SQL tasks and the global task registry. Each task holds a list of (question, ground_truth_sql) pairs. The environment picks one pair per episode via a deterministic round-robin so that the same task always cycles through the same question sequence — this keeps grader results reproducible across runs. """ from __future__ import annotations import sqlite3 from abc import ABC, abstractmethod from typing import Dict, List, NamedTuple, Tuple, Type class TaskExample(NamedTuple): question: str sql: str # Human-readable description of what makes this question that difficulty notes: str = "" class BaseTask(ABC): """Abstract base class for all tasks.""" name: str = "" difficulty: str = "" # easy | medium | hard examples: List[TaskExample] = [] def __init__(self) -> None: if not self.examples: raise ValueError(f"Task {self.name!r} has no examples defined.") self._cursor = 0 # round-robin index def next_example(self) -> TaskExample: """Return the next question in round-robin order.""" example = self.examples[self._cursor % len(self.examples)] self._cursor += 1 return example @classmethod def schema_context(cls) -> str: """Return a compact schema description for the agent system prompt.""" return _SCHEMA_CONTEXT @abstractmethod def description(self) -> str: """One-sentence description for openenv.yaml.""" # ── Global schema context string (injected into every observation) ───────── _SCHEMA_CONTEXT = """\ Database: ecommerce (SQLite, read-only) TABLES ------ categories(id INTEGER PK, name TEXT) products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id, price REAL, stock_quantity INTEGER) customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601) orders(id INTEGER PK, customer_id INTEGER FK→customers.id, status TEXT ∈ {pending|processing|shipped|delivered|cancelled}, created_at TEXT ISO-8601, total_amount REAL) order_items(id INTEGER PK, order_id INTEGER FK→orders.id, product_id INTEGER FK→products.id, quantity INTEGER, unit_price REAL) reviews(id INTEGER PK, product_id INTEGER FK→products.id, customer_id INTEGER FK→customers.id, rating INTEGER 1-5, created_at TEXT ISO-8601) NOTES ----- - Date comparisons: use created_at >= '2024-01-01' (text ISO sort works) - SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available - strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings - All monetary values are in USD """ # ── Task registry ────────────────────────────────────────────────────────── _REGISTRY: Dict[str, Type[BaseTask]] = {} def register(cls: Type[BaseTask]) -> Type[BaseTask]: """Class decorator to register a task.""" _REGISTRY[cls.name] = cls return cls def get_task(name: str) -> BaseTask: if name not in _REGISTRY: raise KeyError(f"Unknown task {name!r}. Available: {list(_REGISTRY)}") return _REGISTRY[name]() def all_task_names() -> List[str]: return list(_REGISTRY.keys())