{
"$schema": "autocode-verification-input-v1",
"feature_id": "F006",
"spec_path": "specs/F006-IMPLEMENTATION_SPEC.md",
"generated": "2026-03-27T12:00:00Z",
"verification_mode": "mvp",
"overview": {
"summary": "GRPO training pipeline that trains a small LLM (default Qwen3-1.7B) to play SQLEnv using TRL's GRPOTrainer. Includes training config, system prompt, rollout function (plays episodes via HF generate), three reward callables (correctness, progress, operational), and an end-to-end training notebook that produces learning curves and before/after episode comparisons.",
"goal": "Enable users to train a model that learns SQL exploration strategy through RL, producing the 'before vs after' demonstration that shows an untrained agent flailing randomly while a trained agent explores strategically."
},
"interfaces": {
"types": [
{
"name": "GRPOConfig",
"fields": [
{"name": "model_name", "type": "str", "description": "HuggingFace model identifier (default: Qwen/Qwen3-1.7B)"},
{"name": "max_new_tokens", "type": "int", "description": "Max tokens per generation (default: 256)"},
{"name": "num_train_epochs", "type": "int", "description": "Number of training epochs (default: 1)"},
{"name": "per_device_train_batch_size", "type": "int", "description": "Batch size per device (default: 2)"},
{"name": "gradient_accumulation_steps", "type": "int", "description": "Gradient accumulation steps (default: 4)"},
{"name": "learning_rate", "type": "float", "description": "Learning rate (default: 5e-6)"},
{"name": "num_generations", "type": "int", "description": "GRPO completions per prompt (default: 4)"},
{"name": "questions_path", "type": "str", "description": "Path to training questions JSON"},
{"name": "db_dir", "type": "str", "description": "Path to SQLite databases directory"},
{"name": "step_budget", "type": "int", "description": "Max steps per episode (default: 10)"},
{"name": "difficulty_filter", "type": "list[str]", "description": "Question difficulties to include (default: [easy, medium])"},
{"name": "seed", "type": "int", "description": "Random seed for reproducibility (default: 42)"},
{"name": "output_dir", "type": "str", "description": "Directory for saving trained model"},
{"name": "logging_steps", "type": "int", "description": "Log metrics every N steps (default: 10)"}
],
"description": "Dataclass holding all GRPO training hyperparameters, model selection, and environment configuration."
}
],
"functions": [
{
"name": "get_system_prompt",
"location": "training/prompts.py",
"signature": "get_system_prompt() -> str",
"description": "Returns the system prompt instructing the model on SQLEnv action format (DESCRIBE, SAMPLE, QUERY, ANSWER) and exploration strategy."
},
{
"name": "format_observation",
"location": "training/prompts.py",
"signature": "format_observation(obs: SQLObservation) -> str",
"description": "Formats an SQLObservation into a user-turn string for the model's chat history."
},
{
"name": "parse_model_output",
"location": "training/rollout.py",
"signature": "parse_model_output(text: str) -> SQLAction",
"description": "Parses free-form model output into an SQLAction. Falls back to QUERY with raw text on parse failure."
},
{
"name": "rollout_func",
"location": "training/rollout.py",
"signature": "rollout_func(prompts: list[str], model: Any, tokenizer: Any, config: GRPOConfig) -> list[dict[str, Any]]",
"description": "Plays full SQLEnv episodes for a batch of question prompts using HF generate. Returns completions and metadata (correctness, progress, operational signals)."
},
{
"name": "reward_correctness",
"location": "training/rewards.py",
"signature": "reward_correctness(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]",
"description": "Binary reward: 1.0 if episode ended with correct answer, 0.0 otherwise. TRL reward_func compatible."
},
{
"name": "reward_progress",
"location": "training/rewards.py",
"signature": "reward_progress(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]",
"description": "Progress reward based on cumulative closeness to gold answer. Normalized to [0, 1]. TRL reward_func compatible."
},
{
"name": "reward_operational",
"location": "training/rewards.py",
"signature": "reward_operational(completions: list[list[dict[str, str]]], **kwargs: Any) -> list[float]",
"description": "Operational reward: sum of per-step L1 signals (exec_ok, new_info, repeat penalty). TRL reward_func compatible."
}
],
"api_endpoints": []
},
"data_flow": {
"primary_flow": [
"Notebook loads GRPOConfig and model/tokenizer from HuggingFace",
"Training questions loaded and filtered by difficulty (easy+medium)",
"GRPOTrainer calls rollout_func for each batch of question prompts",
"rollout_func creates local SQLEnvironment, plays episodes via model.generate() loop",
"Each reward_func receives completions + metadata, returns list[float]",
"GRPOTrainer computes GRPO loss and updates model weights",
"After training: plot learning curve and run comparison episodes"
],
"alternative_flows": [
{
"condition": "Model generates unparseable action text",
"steps": ["parse_model_output falls back to SQLAction(action_type='QUERY', argument=raw_text)", "Environment returns error observation", "Episode continues normally"]
},
{
"condition": "Conversation history exceeds token window",
"steps": ["rollout truncates history to system prompt + last 3 observation-action pairs", "Episode continues with truncated context"]
}
]
},
"error_handling": {
"error_types": [
{"name": "ModelLoadError", "condition": "Model not found on HuggingFace", "strategy": "Fail fast with clear message naming model_name"},
{"name": "ActionParseError", "condition": "Model output not parseable as ACTION_TYPE: argument", "strategy": "Default to QUERY with raw text, log warning"},
{"name": "OOMError", "condition": "GPU out of memory during training", "strategy": "Print guidance to reduce batch_size or num_generations"},
{"name": "QuestionLoadError", "condition": "Questions file missing or empty", "strategy": "Fail fast with path in error message"}
],
"retry_strategy": null
},
"dependencies": {
"external": [
{"name": "trl", "version": ">=0.12.0", "usage": "GRPOTrainer for GRPO training loop"},
{"name": "transformers", "version": "<5", "usage": "Model loading, tokenizer, AutoModelForCausalLM"},
{"name": "accelerate", "version": ">=0.34.0", "usage": "Required by TRL for training orchestration"},
{"name": "torch", "version": "==2.2.2", "usage": "PyTorch backend for model training"},
{"name": "matplotlib", "version": ">=3.0.0", "usage": "Learning curve plots in notebook"}
],
"internal": [
{"name": "models.SQLAction", "usage": "Action type for environment step"},
{"name": "models.SQLObservation", "usage": "Observation type from environment"},
{"name": "models.QuestionRecord", "usage": "Question data structure"},
{"name": "server.sql_environment.SQLEnvironment", "usage": "Local environment instance for rollout episodes"}
]
}
}