{ "$schema": "autocode-verification-input-v1", "feature_id": "F006", "spec_path": "specs/F006-IMPLEMENTATION_SPEC.md", "generated": "2026-03-27T12:00:00Z", "verification_mode": "mvp", "overview": { "summary": "GRPO training pipeline that trains a small LLM (default Qwen3-1.7B) to play SQLEnv using TRL's GRPOTrainer. Includes training config, system prompt, rollout function (plays episodes via HF generate), three reward callables (correctness, progress, operational), and an end-to-end training notebook that produces learning curves and before/after episode comparisons.", "goal": "Enable users to train a model that learns SQL exploration strategy through RL, producing the 'before vs after' demonstration that shows an untrained agent flailing randomly while a trained agent explores strategically." }, "interfaces": { "types": [ { "name": "GRPOConfig", "fields": [ {"name": "model_name", "type": "str", "description": "HuggingFace model identifier (default: Qwen/Qwen3-1.7B)"}, {"name": "max_new_tokens", "type": "int", "description": "Max tokens per generation (default: 256)"}, {"name": "num_train_epochs", "type": "int", "description": "Number of training epochs (default: 1)"}, {"name": "per_device_train_batch_size", "type": "int", "description": "Batch size per device (default: 2)"}, {"name": "gradient_accumulation_steps", "type": "int", "description": "Gradient accumulation steps (default: 4)"}, {"name": "learning_rate", "type": "float", "description": "Learning rate (default: 5e-6)"}, {"name": "num_generations", "type": "int", "description": "GRPO completions per prompt (default: 4)"}, {"name": "questions_path", "type": "str", "description": "Path to training questions JSON"}, {"name": "db_dir", "type": "str", "description": "Path to SQLite databases directory"}, {"name": "step_budget", "type": "int", "description": "Max steps per episode (default: 10)"}, {"name": "difficulty_filter", "type": "list[str]", "description": "Question difficulties to include (default: [easy, medium])"},
{"name": "seed", "type": "int", "description": "Random seed for reproducibility (default: 42)"}, {"name": "output_dir", "type": "str", "description": "Directory for saving trained model"}, {"name": "logging_steps", "type": "int", "description": "Log metrics every N steps (default: 10)"} ], "description": "Dataclass holding all GRPO training hyperparameters, model selection, and environment configuration." } ], "functions": [ { "name": "get_system_prompt", "location": "training/prompts.py", "signature": "get_system_prompt() -> str", "description": "Returns the system prompt instructing the model on SQLEnv action format (DESCRIBE, SAMPLE, QUERY, ANSWER) and exploration strategy." }, { "name": "format_observation", "location": "training/prompts.py", "signature": "format_observation(obs: SQLObservation) -> str", "description": "Formats an SQLObservation into a user-turn string for the model's chat history." }, { "name": "parse_model_output", "location": "training/rollout.py", "signature": "parse_model_output(text: str) -> SQLAction", "description": "Parses free-form model output into an SQLAction. Falls back to QUERY with raw text on parse failure." }, { "name": "rollout_func", "location": "training/rollout.py", "signature": "rollout_func(prompts: list[str], model: Any, tokenizer: Any, config: GRPOConfig) -> list[dict[str, Any]]", "description": "Plays full SQLEnv episodes for a batch of question prompts using HF generate. Returns completions and metadata (correctness, progress, operational signals)." }, { "name": "reward_correctness", "location": "training/rewards.py", "signature": "reward_correctness(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", "description": "Binary reward: 1.0 if episode ended with correct answer, 0.0 otherwise. TRL reward_func compatible." 
}, { "name": "reward_progress", "location": "training/rewards.py", "signature": "reward_progress(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", "description": "Progress reward based on cumulative closeness to gold answer. Normalized to [0, 1]. TRL reward_func compatible." }, { "name": "reward_operational", "location": "training/rewards.py", "signature": "reward_operational(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", "description": "Operational reward: sum of per-step L1 signals (exec_ok, new_info, repeat penalty). TRL reward_func compatible." } ], "api_endpoints": [] }, "data_flow": { "primary_flow": [ "Notebook loads GRPOConfig and model/tokenizer from HuggingFace", "Training questions loaded and filtered by difficulty (easy+medium)", "GRPOTrainer calls rollout_func for each batch of question prompts", "rollout_func creates local SQLEnvironment, plays episodes via model.generate() loop", "Each reward_func receives completions + metadata, returns list[float]", "GRPOTrainer computes GRPO loss and updates model weights", "After training: plot learning curve and run comparison episodes" ], "alternative_flows": [ { "condition": "Model generates unparseable action text", "steps": ["parse_model_output falls back to SQLAction(action_type='QUERY', argument=raw_text)", "Environment returns error observation", "Episode continues normally"] }, { "condition": "Conversation history exceeds token window", "steps": ["rollout truncates history to system prompt + last 3 observation-action pairs", "Episode continues with truncated context"] } ] }, "error_handling": { "error_types": [ {"name": "ModelLoadError", "condition": "Model not found on HuggingFace", "strategy": "Fail fast with clear message naming model_name"}, {"name": "ActionParseError", "condition": "Model output not parseable as ACTION_TYPE: argument", "strategy": "Default to QUERY with raw text, log warning"}, {"name": "OOMError", "condition": "GPU out of memory during training", "strategy": "Print guidance to reduce batch_size or num_generations"},
{"name": "QuestionLoadError", "condition": "Questions file missing or empty", "strategy": "Fail fast with path in error message"} ], "retry_strategy": null }, "dependencies": { "external": [ {"name": "trl", "version": ">=0.12.0", "usage": "GRPOTrainer for GRPO training loop"}, {"name": "transformers", "version": "<5", "usage": "Model loading, tokenizer, AutoModelForCausalLM"}, {"name": "accelerate", "version": ">=0.34.0", "usage": "Required by TRL for training orchestration"}, {"name": "torch", "version": "==2.2.2", "usage": "PyTorch backend for model training"}, {"name": "matplotlib", "version": ">=3.0.0", "usage": "Learning curve plots in notebook"} ], "internal": [ {"name": "models.SQLAction", "usage": "Action type for environment step"}, {"name": "models.SQLObservation", "usage": "Observation type from environment"}, {"name": "models.QuestionRecord", "usage": "Question data structure"}, {"name": "server.sql_environment.SQLEnvironment", "usage": "Local environment instance for rollout episodes"} ] } }