"""Local test for GRPO training with SQLEnvTRL. Usage: docker build -f Dockerfile.test -t sqlenv-test . docker run sqlenv-test docker run sqlenv-test python scripts/test_training_local.py \ --config configs/colab_l4.json """ from __future__ import annotations import argparse import json import os import sys from pathlib import Path os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") root = Path(__file__).parent.parent sys.path.insert(0, str(root)) def load_config(path: str) -> dict: with open(path) as f: return json.load(f) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--config", default="configs/test_cpu.json", help="Training config JSON", ) args = parser.parse_args() cfg = load_config(args.config) print(f"Config: {args.config}") print(json.dumps(cfg, indent=2)) import transformers import trl from datasets import Dataset from trl import GRPOConfig, GRPOTrainer from sql_env.training.trl_adapter import ( SQLEnvTRL, sql_env_reward_func, ) print(f"\nTRL: {trl.__version__}, Transformers: {transformers.__version__}") # 1. Configure environment SQLEnvTRL._configure( questions_path=cfg["questions_path"], db_dir=cfg["db_dir"], step_budget=cfg["step_budget"], ) env = SQLEnvTRL() obs = env.reset() print("\n--- Environment smoke test ---") print(f"Reset: {obs}") r = env.describe(table_name="employee") print(f"Describe: {r[:80]}") r = env.query(sql="SELECT COUNT(*) FROM employee") print(f"Query: {r}") r = env.answer(value="10") print(f"Answer: {r}") print(f"Total reward: {env.reward:.4f}") # 2. Dataset enable_thinking = cfg.get("enable_thinking", False) system_prompt_base = ( "You answer questions about a SQL database. " "Use ONLY the provided tools.\n\n" "Strategy:\n" "1. Call describe(table_name=...) to see columns\n" "2. Call query(sql=...) to run SELECT queries\n" "3. Call answer(value=...) to submit your answer" ) system_prompt = ( system_prompt_base if enable_thinking else "/no_think\n" + system_prompt_base ) questions = [ "How many employees are there?", "What are the names of all shops?", "Find the total number of concerts.", "List all singer names.", ] prompt_msgs = [ [ {"role": "system", "content": system_prompt}, {"role": "user", "content": q}, ] for q in questions ] size = cfg.get("dataset_size", len(prompt_msgs)) repeated = (prompt_msgs * ((size // len(prompt_msgs)) + 1))[:size] repeated_q = (questions * ((size // len(questions)) + 1))[:size] dataset = Dataset.from_dict({"prompt": repeated, "question_text": repeated_q}) # 3. Trainer config print("\n--- Building trainer ---") grpo_kwargs = { "output_dir": cfg["output_dir"], "per_device_train_batch_size": cfg["per_device_train_batch_size"], "num_generations": cfg["num_generations"], "num_train_epochs": cfg["num_train_epochs"], "max_completion_length": cfg["max_completion_length"], "logging_steps": cfg["logging_steps"], "log_completions": True, "num_completions_to_print": cfg.get("num_completions_to_print", 2), "remove_unused_columns": False, } if cfg.get("max_steps"): grpo_kwargs["max_steps"] = cfg["max_steps"] grpo_kwargs["chat_template_kwargs"] = { "enable_thinking": enable_thinking, } precision = cfg.get("precision", "fp32") if precision == "bf16": grpo_kwargs.update(bf16=True, fp16=False) elif precision == "fp16": grpo_kwargs.update(bf16=False, fp16=True) else: grpo_kwargs.update(bf16=False, fp16=False) trainer = GRPOTrainer( model=cfg["model_name"], reward_funcs=sql_env_reward_func, train_dataset=dataset, environment_factory=SQLEnvTRL, args=GRPOConfig(**grpo_kwargs), ) # 4. Train print(f"\n--- Training ({cfg.get('max_steps', 'all')} steps) ---") trainer.train() # 5. Results print("\n--- Results ---") for entry in trainer.state.log_history: step = entry.get("step") loss = entry.get("loss") if loss is None: continue reward = entry.get("reward", 0) reward_std = entry.get("reward_std", 0) tools_freq = entry.get("tools/call_frequency", 0) clipped = entry.get("completions/clipped_ratio", 0) mean_len = entry.get("completions/mean_length", 0) print( f"Step {step:>3}: " f"loss={loss:.4f} " f"reward={reward:.4f} +/-{reward_std:.4f} " f"tools={tools_freq:.2f} " f"clipped={clipped:.0%} " f"len={mean_len:.0f}" ) losses = [e["loss"] for e in trainer.state.log_history if "loss" in e] rewards = [e.get("reward", 0) for e in trainer.state.log_history if "loss" in e] print(f"\nLoss: {losses}") print(f"Reward: {rewards}") if losses and any(v != 0.0 for v in losses): print("\nSUCCESS: Non-zero training loss") else: print("\nFAILED: All losses zero") if rewards and any(v != 0.0 for v in rewards): print("SUCCESS: Non-zero rewards") else: print("FAILED: All rewards zero") if __name__ == "__main__": main()