# sql_env/training/notebook_pipeline.py
"""Notebook-oriented helpers for GRPO training orchestration."""
from __future__ import annotations
import random
from typing import Any
def _precision_kwargs(precision: str) -> dict[str, bool]:
"""Map precision string to TRL config kwargs."""
if precision == "fp16":
return {"fp16": True, "bf16": False}
if precision == "bf16":
return {"fp16": False, "bf16": True}
if precision == "fp32":
return {"fp16": False, "bf16": False}
# "auto" — let TRL/transformers decide
return {}
def sample_random_baseline(
    prompts: list[str],
    *,
    step_budget: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Generate simple random-action transcripts for baseline comparison.

    For each prompt, samples between 1 and 5 actions (capped by
    ``step_budget``) using a seeded RNG so results are reproducible.

    Args:
        prompts: Question prompts to pair with random transcripts.
        step_budget: Requested steps per transcript; clamped to [1, 5].
        seed: Seed for the deterministic RNG.

    Returns:
        One dict per prompt with ``prompt``, ``completion``, ``content``
        (identical to ``completion``) and ``metadata`` keys.
    """
    rng = random.Random(seed)
    action_types = ["DESCRIBE", "SAMPLE", "QUERY", "ANSWER"]
    transcripts: list[dict[str, Any]] = []
    # Clamp once outside the loop: budget is the same for every prompt.
    step_count = max(1, min(step_budget, 5))
    for prompt in prompts:
        lines = []
        for _ in range(step_count):
            action = rng.choice(action_types)
            # QUERY gets a trivial SQL statement; everything else a table ref.
            argument = "table_1" if action != "QUERY" else "SELECT 1"
            lines.append(f"{action}: {argument}")
        # Join once; "completion" and "content" intentionally mirror each
        # other so downstream consumers can use either key.
        text = "\n".join(lines)
        transcripts.append(
            {
                "prompt": prompt,
                "completion": text,
                "content": text,
                "metadata": {"policy": "random", "step_count": step_count},
            }
        )
    return transcripts
def build_trainer(
*,
model: Any,
tokenizer: Any,
prompts: list[str],
config: Any,
trl_grpo_config_cls: type,
grpo_trainer_cls: type,
reward_funcs: list[Any],
environment_factory: type | None = None,
callbacks: list[Any] | None = None,
) -> Any:
"""Build a GRPO trainer instance using notebook config objects."""
extra_kwargs: dict[str, Any] = {}
if getattr(config, "gradient_checkpointing", False):
extra_kwargs["gradient_checkpointing"] = True
trainer_config = trl_grpo_config_cls(
output_dir=config.output_dir,
learning_rate=config.learning_rate,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
num_train_epochs=config.num_train_epochs,
logging_steps=config.logging_steps,
max_completion_length=config.max_new_tokens,
num_generations=config.num_generations,
generation_batch_size=config.num_generations,
beta=getattr(config, "beta", 0.04),
**_precision_kwargs(getattr(config, "precision", "auto")),
**extra_kwargs,
remove_unused_columns=False,
log_completions=True,
num_completions_to_print=1,
chat_template_kwargs={
"enable_thinking": getattr(config, "enable_thinking", False),
},
)
trainer_kwargs: dict[str, Any] = {
"model": model,
"processing_class": tokenizer,
"args": trainer_config,
"train_dataset": prompts,
"reward_funcs": reward_funcs,
}
if environment_factory is not None:
configure = getattr(environment_factory, "configure", None)
if not callable(configure):
configure = getattr(environment_factory, "_configure", None)
if callable(configure):
configure(
questions_path=config.questions_path,
db_dir=config.db_dir,
step_budget=config.step_budget,
)
trainer_kwargs["environment_factory"] = environment_factory
if callbacks is not None:
trainer_kwargs["callbacks"] = callbacks
return grpo_trainer_cls(
**trainer_kwargs,
)
def run_training_with_metrics(trainer: Any) -> tuple[Any, list[int], list[float]]:
    """Execute ``trainer.train()`` and collect plotting-friendly metrics.

    Returns:
        A 3-tuple of the raw train output, the list of logged step numbers,
        and the matching list of logged rewards. Log entries that are not
        dicts or lack either "step" or "reward" are skipped.
    """
    output = trainer.train()

    # Defensively pull the log history; anything other than a list is treated
    # as "no history" (e.g. trainers without a populated state).
    history = getattr(getattr(trainer, "state", None), "log_history", None)
    records = history if isinstance(history, list) else []

    steps: list[int] = []
    rewards: list[float] = []
    for entry in records:
        if isinstance(entry, dict) and "step" in entry and "reward" in entry:
            steps.append(int(entry["step"]))
            rewards.append(float(entry["reward"]))
    return output, steps, rewards
def format_oom_guidance(error: Exception) -> str:
    """Produce an actionable message for an out-of-memory training failure."""
    # Keep the exact wording stable: notebooks surface this string verbatim.
    prefix = f"Training failed with OOM: {error}. "
    advice = "Try reducing per_device_train_batch_size or num_generations."
    return prefix + advice