"""Notebook-oriented helpers for GRPO training orchestration."""

from __future__ import annotations

import random
from typing import Any


def _precision_kwargs(precision: str) -> dict[str, bool]:
    """Map precision string to TRL config kwargs.

    Args:
        precision: One of ``"fp16"``, ``"bf16"`` or ``"fp32"``. Any other
            value (including ``"auto"``) selects no explicit flags.

    Returns:
        A dict with explicit ``fp16``/``bf16`` booleans for the three known
        precisions, or an empty dict otherwise.
    """
    if precision in ("fp16", "bf16", "fp32"):
        # "fp32" matches neither comparison, so both flags come out False.
        return {"fp16": precision == "fp16", "bf16": precision == "bf16"}
    # "auto" — let TRL/transformers decide
    return {}


def sample_random_baseline(
    prompts: list[str],
    *,
    step_budget: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Generate simple random-action transcripts for baseline comparison.

    Args:
        prompts: Prompts to produce one transcript each for.
        step_budget: Requested number of actions per transcript; clamped to
            the range 1..5.
        seed: Seed for the private ``random.Random`` instance, so the same
            arguments always yield the same transcripts.

    Returns:
        One dict per prompt with keys ``prompt``, ``completion``, ``content``
        (same text as ``completion``) and ``metadata``.
    """
    rng = random.Random(seed)
    action_types = ["DESCRIBE", "SAMPLE", "QUERY", "ANSWER"]
    # Independent of the prompt, so compute the clamped step count once.
    step_count = max(1, min(step_budget, 5))

    transcripts: list[dict[str, Any]] = []
    for prompt in prompts:
        lines = []
        for _ in range(step_count):
            action = rng.choice(action_types)
            argument = "SELECT 1" if action == "QUERY" else "table_1"
            lines.append(f"{action}: {argument}")
        completion = "\n".join(lines)
        transcripts.append(
            {
                "prompt": prompt,
                "completion": completion,
                "content": completion,
                # Fresh dict per transcript so callers can mutate one safely.
                "metadata": {"policy": "random", "step_count": step_count},
            }
        )
    return transcripts


def build_trainer(
    *,
    model: Any,
    tokenizer: Any,
    prompts: list[str],
    config: Any,
    trl_grpo_config_cls: type,
    grpo_trainer_cls: type,
    reward_funcs: list[Any],
    environment_factory: type | None = None,
    callbacks: list[Any] | None = None,
) -> Any:
    """Build a GRPO trainer instance using notebook config objects.

    Args:
        model: Passed to the trainer as ``model``.
        tokenizer: Passed to the trainer as ``processing_class``.
        prompts: Passed to the trainer as ``train_dataset``.
        config: Notebook config object. Required attributes: ``output_dir``,
            ``learning_rate``, ``per_device_train_batch_size``,
            ``gradient_accumulation_steps``, ``num_train_epochs``,
            ``logging_steps``, ``max_new_tokens``, ``num_generations``.
            Optional attributes: ``gradient_checkpointing`` (default False),
            ``beta`` (default 0.04), ``precision`` (default ``"auto"``),
            ``enable_thinking`` (default False). ``questions_path``,
            ``db_dir`` and ``step_budget`` are read only when the
            environment factory exposes a configure hook.
        trl_grpo_config_cls: Callable used to build the trainer config.
        grpo_trainer_cls: Callable used to build the trainer.
        reward_funcs: Passed to the trainer as ``reward_funcs``.
        environment_factory: Optional factory forwarded to the trainer. If it
            has a callable ``configure`` (or, failing that, ``_configure``)
            attribute, that hook is called first with the dataset paths and
            step budget from ``config``.
        callbacks: Optional callbacks; forwarded only when not ``None``.

    Returns:
        Whatever ``grpo_trainer_cls(**kwargs)`` returns.
    """
    extra_kwargs: dict[str, Any] = {}
    if getattr(config, "gradient_checkpointing", False):
        extra_kwargs["gradient_checkpointing"] = True

    trainer_config = trl_grpo_config_cls(
        output_dir=config.output_dir,
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        num_train_epochs=config.num_train_epochs,
        logging_steps=config.logging_steps,
        max_completion_length=config.max_new_tokens,
        num_generations=config.num_generations,
        generation_batch_size=config.num_generations,
        beta=getattr(config, "beta", 0.04),
        **_precision_kwargs(getattr(config, "precision", "auto")),
        **extra_kwargs,
        remove_unused_columns=False,
        log_completions=True,
        num_completions_to_print=1,
        chat_template_kwargs={
            "enable_thinking": getattr(config, "enable_thinking", False),
        },
    )

    trainer_kwargs: dict[str, Any] = {
        "model": model,
        "processing_class": tokenizer,
        "args": trainer_config,
        "train_dataset": prompts,
        "reward_funcs": reward_funcs,
    }

    if environment_factory is not None:
        # Prefer the public hook; fall back to the private spelling.
        configure = getattr(environment_factory, "configure", None)
        if not callable(configure):
            configure = getattr(environment_factory, "_configure", None)
        if callable(configure):
            configure(
                questions_path=config.questions_path,
                db_dir=config.db_dir,
                step_budget=config.step_budget,
            )
        trainer_kwargs["environment_factory"] = environment_factory

    if callbacks is not None:
        trainer_kwargs["callbacks"] = callbacks

    return grpo_trainer_cls(**trainer_kwargs)


def run_training_with_metrics(trainer: Any) -> tuple[Any, list[int], list[float]]:
    """Run trainer.train() and extract plotting-friendly step/reward vectors.

    Args:
        trainer: Object with a ``train()`` method and, optionally, a
            ``state.log_history`` list of dict entries.

    Returns:
        A ``(train_output, steps, rewards)`` tuple. ``steps`` and ``rewards``
        are aligned and contain one element per log entry that has both a
        ``"step"`` and a ``"reward"`` key with numeric-convertible values.
        Both are empty when no usable history is available.
    """
    train_output = trainer.train()

    log_history = getattr(getattr(trainer, "state", None), "log_history", None)
    if not isinstance(log_history, list):
        log_history = []

    steps: list[int] = []
    rewards: list[float] = []
    for item in log_history:
        if not isinstance(item, dict):
            continue
        if "step" not in item or "reward" not in item:
            continue
        try:
            step = int(item["step"])
            reward = float(item["reward"])
        except (TypeError, ValueError, OverflowError):
            # Training has already finished; a malformed log entry must not
            # discard train_output, so skip it like other unusable entries.
            continue
        # Append only after both conversions succeed to keep lists aligned.
        steps.append(step)
        rewards.append(reward)

    return train_output, steps, rewards


def format_oom_guidance(error: Exception) -> str:
    """Return actionable guidance when training hits OOM.

    Args:
        error: The exception raised by the failed training run; its string
            form is embedded in the message.

    Returns:
        A one-line message naming the error and the settings to reduce.
    """
    return (
        f"Training failed with OOM: {error}. "
        "Try reducing per_device_train_batch_size or num_generations."
    )