# sql_env/training/__init__.py
# Uploaded by hjerpe via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 9e64e71 (verified).
"""Training utilities for GRPO-based SQLEnv experiments."""
import importlib
from .config import GRPOConfig, apply_device_overrides, find_project_root
# Public API of the package.
# Names imported eagerly from ``.config`` at the top of this module:
__all__ = [
    "GRPOConfig",
    "apply_device_overrides",
    "find_project_root",
]
# Names resolved on first attribute access through the module-level
# ``__getattr__`` hook (their submodules are listed in ``_LAZY_MAP``):
__all__ += [
    "build_trainer",
    "filter_questions_by_difficulty",
    "format_observation",
    "format_oom_guidance",
    "get_system_prompt",
    "load_model_and_tokenizer",
    "load_question_prompts",
    "run_training_with_metrics",
    "sample_random_baseline",
    "reward_correctness",
    "reward_progress",
    "reward_operational",
    "LiveVisualizationCallback",
    "SQLEnvTRL",
    "sql_env_reward_func",
]
# Lazily exported attribute name -> submodule (relative to this package)
# that defines it. Entries are declared grouped by submodule; the resulting
# dict keeps that declaration order. The comprehension's loop variables are
# local to it and do not leak into the module namespace.
_LAZY_MAP = {
    attr: submodule
    for submodule, attrs in (
        (
            ".data_loading",
            (
                "filter_questions_by_difficulty",
                "load_model_and_tokenizer",
                "load_question_prompts",
            ),
        ),
        (
            ".notebook_pipeline",
            (
                "build_trainer",
                "format_oom_guidance",
                "run_training_with_metrics",
                "sample_random_baseline",
            ),
        ),
        (".prompts", ("format_observation", "get_system_prompt")),
        (
            ".rewards",
            (
                "reward_correctness",
                "reward_operational",
                "reward_progress",
            ),
        ),
        (".visualization", ("LiveVisualizationCallback",)),
        (".trl_adapter", ("SQLEnvTRL", "sql_env_reward_func")),
    )
    for attr in attrs
}
def __getattr__(name: str):
    """Lazily import and return a public attribute listed in ``_LAZY_MAP``.

    Python invokes this hook (PEP 562) only when *name* is not found among
    the module's globals, so the heavy submodules are imported on first
    access rather than when the package is imported. The value is fetched
    from the submodule on every call (not cached here), so a later patch of
    the submodule attribute is still visible through the package.

    Args:
        name: Attribute being looked up on this package.

    Returns:
        The object named *name* from the submodule mapped in ``_LAZY_MAP``.

    Raises:
        AttributeError: If *name* is not a lazily exported attribute, or the
            mapped submodule does not define it.
        ImportError: Propagated if the mapped submodule fails to import.
    """
    try:
        module_name = _LAZY_MAP[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None
    # Submodule paths in ``_LAZY_MAP`` are relative (".xyz"), resolved
    # against this package's name.
    mod = importlib.import_module(module_name, __name__)
    return getattr(mod, name)


def __dir__():
    """Return the module's globals plus its lazily exported names.

    Companion to ``__getattr__`` (PEP 562): without it, ``dir()`` and
    REPL/IDE completion would omit attributes that have not been loaded yet.
    """
    return sorted(set(globals()) | set(_LAZY_MAP))