| { |
| "$schema": "autocode-verification-input-v1", |
| "feature_id": "F006", |
| "spec_path": "specs/F006-IMPLEMENTATION_SPEC.md", |
| "generated": "2026-03-27T12:00:00Z", |
| "verification_mode": "mvp", |
|
|
| "overview": { |
| "summary": "GRPO training pipeline that trains a small LLM (default Qwen3-1.7B) to play SQLEnv using TRL's GRPOTrainer. Includes training config, system prompt, rollout function (plays episodes via HF generate), three reward callables (correctness, progress, operational), and an end-to-end training notebook that produces learning curves and before/after episode comparisons.", |
| "goal": "Enable users to train a model that learns SQL exploration strategy through RL, producing the 'before vs after' demonstration that shows an untrained agent flailing randomly while a trained agent explores strategically." |
| }, |
|
|
| "interfaces": { |
| "types": [ |
| { |
| "name": "GRPOConfig", |
| "fields": [ |
| {"name": "model_name", "type": "str", "description": "HuggingFace model identifier (default: Qwen/Qwen3-1.7B)"}, |
| {"name": "max_new_tokens", "type": "int", "description": "Max tokens per generation (default: 256)"}, |
| {"name": "num_train_epochs", "type": "int", "description": "Number of training epochs (default: 1)"}, |
| {"name": "per_device_train_batch_size", "type": "int", "description": "Batch size per device (default: 2)"}, |
| {"name": "gradient_accumulation_steps", "type": "int", "description": "Gradient accumulation steps (default: 4)"}, |
| {"name": "learning_rate", "type": "float", "description": "Learning rate (default: 5e-6)"}, |
| {"name": "num_generations", "type": "int", "description": "GRPO completions per prompt (default: 4)"}, |
| {"name": "questions_path", "type": "str", "description": "Path to training questions JSON"}, |
| {"name": "db_dir", "type": "str", "description": "Path to SQLite databases directory"}, |
| {"name": "step_budget", "type": "int", "description": "Max steps per episode (default: 10)"}, |
| {"name": "difficulty_filter", "type": "list[str]", "description": "Question difficulties to include (default: ['easy', 'medium'])"}, |
| {"name": "seed", "type": "int", "description": "Random seed for reproducibility (default: 42)"}, |
| {"name": "output_dir", "type": "str", "description": "Directory for saving trained model"}, |
| {"name": "logging_steps", "type": "int", "description": "Log metrics every N steps (default: 10)"} |
| ], |
| "description": "Dataclass holding all GRPO training hyperparameters, model selection, and environment configuration." |
| } |
| ], |
| "functions": [ |
| { |
| "name": "get_system_prompt", |
| "location": "training/prompts.py", |
| "signature": "get_system_prompt() -> str", |
| "description": "Returns the system prompt instructing the model on SQLEnv action format (DESCRIBE, SAMPLE, QUERY, ANSWER) and exploration strategy." |
| }, |
| { |
| "name": "format_observation", |
| "location": "training/prompts.py", |
| "signature": "format_observation(obs: SQLObservation) -> str", |
| "description": "Formats an SQLObservation into a user-turn string for the model's chat history." |
| }, |
| { |
| "name": "parse_model_output", |
| "location": "training/rollout.py", |
| "signature": "parse_model_output(text: str) -> SQLAction", |
| "description": "Parses free-form model output into an SQLAction. Falls back to QUERY with raw text on parse failure." |
| }, |
| { |
| "name": "rollout_func", |
| "location": "training/rollout.py", |
| "signature": "rollout_func(prompts: list[str], model: Any, tokenizer: Any, config: GRPOConfig) -> list[dict[str, Any]]", |
| "description": "Plays full SQLEnv episodes for a batch of question prompts using HF generate. Returns completions and metadata (correctness, progress, operational signals)." |
| }, |
| { |
| "name": "reward_correctness", |
| "location": "training/rewards.py", |
| "signature": "reward_correctness(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", |
| "description": "Binary reward: 1.0 if episode ended with correct answer, 0.0 otherwise. TRL reward_func compatible." |
| }, |
| { |
| "name": "reward_progress", |
| "location": "training/rewards.py", |
| "signature": "reward_progress(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", |
| "description": "Progress reward based on cumulative closeness to gold answer. Normalized to [0, 1]. TRL reward_func compatible." |
| }, |
| { |
| "name": "reward_operational", |
| "location": "training/rewards.py", |
| "signature": "reward_operational(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]", |
| "description": "Operational reward: sum of per-step L1 signals (exec_ok, new_info, repeat penalty). TRL reward_func compatible." |
| } |
| ], |
| "api_endpoints": [] |
| }, |
|
|
| "data_flow": { |
| "primary_flow": [ |
| "Notebook loads GRPOConfig and model/tokenizer from HuggingFace", |
| "Training questions loaded and filtered by difficulty (easy+medium)", |
| "GRPOTrainer calls rollout_func for each batch of question prompts", |
| "rollout_func creates local SQLEnvironment, plays episodes via model.generate() loop", |
| "Each reward_func receives completions + metadata, returns list[float]", |
| "GRPOTrainer computes GRPO loss and updates model weights", |
| "After training: plot learning curve and run comparison episodes" |
| ], |
| "alternative_flows": [ |
| { |
| "condition": "Model generates unparseable action text", |
| "steps": ["parse_model_output falls back to SQLAction(action_type='QUERY', argument=raw_text)", "Environment returns error observation", "Episode continues normally"] |
| }, |
| { |
| "condition": "Conversation history exceeds token window", |
| "steps": ["rollout truncates history to system prompt + last 3 observation-action pairs", "Episode continues with truncated context"] |
| } |
| ] |
| }, |
|
|
| "error_handling": { |
| "error_types": [ |
| {"name": "ModelLoadError", "condition": "Model not found on HuggingFace", "strategy": "Fail fast with clear message naming model_name"}, |
| {"name": "ActionParseError", "condition": "Model output not parseable as ACTION_TYPE: argument", "strategy": "Default to QUERY with raw text, log warning"}, |
| {"name": "OOMError", "condition": "GPU out of memory during training", "strategy": "Print guidance to reduce per_device_train_batch_size or num_generations"}, |
| {"name": "QuestionLoadError", "condition": "Questions file missing or empty", "strategy": "Fail fast with path in error message"} |
| ], |
| "retry_strategy": null |
| }, |
|
|
| "dependencies": { |
| "external": [ |
| {"name": "trl", "version": ">=0.12.0", "usage": "GRPOTrainer for GRPO training loop"}, |
| {"name": "transformers", "version": "<5", "usage": "Model loading, tokenizer, AutoModelForCausalLM"}, |
| {"name": "accelerate", "version": ">=0.34.0", "usage": "Required by TRL for training orchestration"}, |
| {"name": "torch", "version": "==2.2.2", "usage": "PyTorch backend for model training"}, |
| {"name": "matplotlib", "version": ">=3.0.0", "usage": "Learning curve plots in notebook"} |
| ], |
| "internal": [ |
| {"name": "models.SQLAction", "usage": "Action type for environment step"}, |
| {"name": "models.SQLObservation", "usage": "Observation type from environment"}, |
| {"name": "models.QuestionRecord", "usage": "Question data structure"}, |
| {"name": "server.sql_environment.SQLEnvironment", "usage": "Local environment instance for rollout episodes"} |
| ] |
| } |
| } |
|
|