| """Local test for GRPO training with SQLEnvTRL. |
| |
| Usage: |
| docker build -f Dockerfile.test -t sqlenv-test . |
| docker run sqlenv-test |
| docker run sqlenv-test python scripts/test_training_local.py \ |
| --config configs/colab_l4.json |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
# Keep any device selection the caller exported; otherwise default to an
# empty device list. This runs at module import, before the heavy ML
# imports that main() performs lazily.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# Repository root: two directory levels above this file. It is placed first
# on sys.path so the project package imported inside main() resolves from
# the checkout.
root = Path(__file__).parent.parent
_root_entry = str(root)
sys.path.insert(0, _root_entry)
|
|
|
|
def load_config(path: str) -> dict:
    """Load a training configuration from a JSON file.

    Args:
        path: Filesystem path of the JSON config file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, so non-ASCII values load the same way on every machine.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def main() -> None:
    """Run an environment smoke test followed by a short GRPO training run.

    Reads a JSON training config (``--config``, default
    ``configs/test_cpu.json``), exercises ``SQLEnvTRL`` once by hand, builds
    a small prompt dataset, trains with ``GRPOTrainer``, then prints the
    logged per-step metrics and a SUCCESS/FAILED line for loss and reward.

    Required config keys: ``questions_path``, ``db_dir``, ``step_budget``,
    ``output_dir``, ``per_device_train_batch_size``, ``num_generations``,
    ``num_train_epochs``, ``max_completion_length``, ``logging_steps`` and
    ``model_name``. Optional keys: ``enable_thinking``, ``dataset_size``,
    ``num_completions_to_print``, ``max_steps`` and ``precision``.

    Raises:
        KeyError: If a required config key is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="configs/test_cpu.json",
        help="Training config JSON",
    )
    args = parser.parse_args()

    cfg = load_config(args.config)
    print(f"Config: {args.config}")
    print(json.dumps(cfg, indent=2))

    # Heavy third-party imports are deferred until the config has been
    # parsed and echoed, so argument/config errors surface quickly.
    import transformers
    import trl
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    from sql_env.training.trl_adapter import (
        SQLEnvTRL,
        sql_env_reward_func,
    )

    print(f"\nTRL: {trl.__version__}, Transformers: {transformers.__version__}")

    # --- 1. Environment smoke test: drive the tools by hand once ---
    SQLEnvTRL._configure(
        questions_path=cfg["questions_path"],
        db_dir=cfg["db_dir"],
        step_budget=cfg["step_budget"],
    )
    env = SQLEnvTRL()
    obs = env.reset()
    print("\n--- Environment smoke test ---")
    print(f"Reset: {obs}")
    r = env.describe(table_name="employee")
    print(f"Describe: {r[:80]}")
    r = env.query(sql="SELECT COUNT(*) FROM employee")
    print(f"Query: {r}")
    r = env.answer(value="10")
    print(f"Answer: {r}")
    print(f"Total reward: {env.reward:.4f}")

    # --- 2. Prompt dataset ---
    enable_thinking = cfg.get("enable_thinking", False)
    system_prompt_base = (
        "You answer questions about a SQL database. "
        "Use ONLY the provided tools.\n\n"
        "Strategy:\n"
        "1. Call describe(table_name=...) to see columns\n"
        "2. Call query(sql=...) to run SELECT queries\n"
        "3. Call answer(value=...) to submit your answer"
    )
    # When thinking is disabled the prompt is prefixed with "/no_think".
    system_prompt = (
        system_prompt_base if enable_thinking else "/no_think\n" + system_prompt_base
    )
    questions = [
        "How many employees are there?",
        "What are the names of all shops?",
        "Find the total number of concerts.",
        "List all singer names.",
    ]
    prompt_msgs = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": q},
        ]
        for q in questions
    ]
    # Cycle through the fixed questions until the requested size is reached;
    # prompts and question texts stay aligned because both use the same index.
    n_prompts = len(prompt_msgs)
    size = cfg.get("dataset_size", n_prompts)
    repeated = [prompt_msgs[i % n_prompts] for i in range(size)]
    repeated_q = [questions[i % n_prompts] for i in range(size)]
    dataset = Dataset.from_dict({"prompt": repeated, "question_text": repeated_q})

    # --- 3. Trainer configuration ---
    print("\n--- Building trainer ---")
    grpo_kwargs = {
        "output_dir": cfg["output_dir"],
        "per_device_train_batch_size": cfg["per_device_train_batch_size"],
        "num_generations": cfg["num_generations"],
        "num_train_epochs": cfg["num_train_epochs"],
        "max_completion_length": cfg["max_completion_length"],
        "logging_steps": cfg["logging_steps"],
        "log_completions": True,
        "num_completions_to_print": cfg.get("num_completions_to_print", 2),
        "remove_unused_columns": False,
    }
    # A missing, null or zero max_steps means "no explicit step limit"; the
    # same value drives the training banner below so the two always agree.
    max_steps = cfg.get("max_steps")
    if max_steps:
        grpo_kwargs["max_steps"] = max_steps
    grpo_kwargs["chat_template_kwargs"] = {
        "enable_thinking": enable_thinking,
    }

    # Any precision other than "bf16"/"fp16" leaves both flags off.
    precision = cfg.get("precision", "fp32")
    grpo_kwargs.update(bf16=precision == "bf16", fp16=precision == "fp16")

    trainer = GRPOTrainer(
        model=cfg["model_name"],
        reward_funcs=sql_env_reward_func,
        train_dataset=dataset,
        environment_factory=SQLEnvTRL,
        args=GRPOConfig(**grpo_kwargs),
    )

    # --- 4. Train ---
    print(f"\n--- Training ({max_steps or 'all'} steps) ---")
    trainer.train()

    # --- 5. Report ---
    print("\n--- Results ---")
    # Only log entries that carry a loss value are reported; the per-step
    # lines and the summary lists below are built from this same selection.
    logged = [e for e in trainer.state.log_history if e.get("loss") is not None]
    for entry in logged:
        step = entry.get("step")
        loss = entry["loss"]
        reward = entry.get("reward", 0)
        reward_std = entry.get("reward_std", 0)
        tools_freq = entry.get("tools/call_frequency", 0)
        clipped = entry.get("completions/clipped_ratio", 0)
        mean_len = entry.get("completions/mean_length", 0)
        print(
            f"Step {step:>3}: "
            f"loss={loss:.4f} "
            f"reward={reward:.4f} +/-{reward_std:.4f} "
            f"tools={tools_freq:.2f} "
            f"clipped={clipped:.0%} "
            f"len={mean_len:.0f}"
        )

    losses = [e["loss"] for e in logged]
    rewards = [e.get("reward", 0) for e in logged]

    print(f"\nLoss: {losses}")
    print(f"Reward: {rewards}")

    # NOTE(review): the process still exits with status 0 when a check below
    # prints FAILED; confirm whether callers need a non-zero exit status.
    if losses and any(v != 0.0 for v in losses):
        print("\nSUCCESS: Non-zero training loss")
    else:
        print("\nFAILED: All losses zero")

    if rewards and any(v != 0.0 for v in rewards):
        print("SUCCESS: Non-zero rewards")
    else:
        print("FAILED: All rewards zero")
|
|
|
|
# Run the test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|