| """Training utilities for GRPO-based SQLEnv experiments.""" |
|
|
| import importlib |
|
|
| from .config import GRPOConfig, apply_device_overrides, find_project_root |
|
|
# Public API of the package, i.e. the names exported by ``from <package> import *``.
__all__ = [
    # Imported eagerly from ``.config`` at package import time.
    "GRPOConfig",
    "apply_device_overrides",
    "find_project_root",
    # The remaining names are not bound at import time; each has an entry in
    # ``_LAZY_MAP`` and is resolved on access by the module-level ``__getattr__``.
    "build_trainer",
    "filter_questions_by_difficulty",
    "format_observation",
    "format_oom_guidance",
    "get_system_prompt",
    "load_model_and_tokenizer",
    "load_question_prompts",
    "run_training_with_metrics",
    "sample_random_baseline",
    "reward_correctness",
    "reward_progress",
    "reward_operational",
    "LiveVisualizationCallback",
    "SQLEnvTRL",
    "sql_env_reward_func",
]
|
|
# Lazily exported names, grouped by the relative submodule that defines them.
# The groups are flattened into a ``name -> relative module path`` lookup that
# the module-level ``__getattr__`` consults.
_LAZY_MAP = {
    export: submodule
    for submodule, exports in (
        (
            ".data_loading",
            (
                "filter_questions_by_difficulty",
                "load_model_and_tokenizer",
                "load_question_prompts",
            ),
        ),
        (
            ".notebook_pipeline",
            (
                "build_trainer",
                "format_oom_guidance",
                "run_training_with_metrics",
                "sample_random_baseline",
            ),
        ),
        (".prompts", ("format_observation", "get_system_prompt")),
        (
            ".rewards",
            ("reward_correctness", "reward_operational", "reward_progress"),
        ),
        (".visualization", ("LiveVisualizationCallback",)),
        (".trl_adapter", ("SQLEnvTRL", "sql_env_reward_func")),
    )
    for export in exports
}
|
|
|
|
def __getattr__(name: str):
    """Resolve a lazily exported attribute on demand.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails, so names already bound in the module never
    reach this function.

    Args:
        name: Attribute name being looked up on this module.

    Returns:
        The attribute called ``name`` taken from the submodule that
        ``_LAZY_MAP`` associates with it. The submodule path is relative and
        is resolved against ``__name__``.

    Raises:
        AttributeError: If ``name`` has no entry in ``_LAZY_MAP``.
    """
    relative_module = _LAZY_MAP.get(name)
    if relative_module is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # The result is not stored in the module namespace, so every access goes
    # through this lookup again; the import itself is served from
    # ``sys.modules`` after the first call.
    return getattr(importlib.import_module(relative_module, __name__), name)
|
|