"""Data/model loading helpers for the GRPO training notebook.""" from __future__ import annotations import json from pathlib import Path from typing import Any from transformers import AutoModelForCausalLM, AutoTokenizer def filter_questions_by_difficulty( questions: list[dict[str, Any]], allowed: list[str] | None ) -> list[dict[str, Any]]: """Filter question records by case-insensitive difficulty labels.""" if not allowed: return questions allowed_set = {level.lower() for level in allowed} return [ question for question in questions if str(question.get("difficulty", "")).lower() in allowed_set ] def load_question_prompts( questions_path: str, allowed: list[str] | None = None ) -> list[dict[str, str]]: """Load question text prompts from JSON and apply difficulty filtering.""" path = Path(questions_path) if not path.exists(): raise FileNotFoundError(f"Questions file not found: {questions_path}") try: payload = json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError as exc: raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc if not isinstance(payload, list) or not payload: raise ValueError(f"Questions file is empty or invalid: {questions_path}") filtered = filter_questions_by_difficulty(payload, allowed) if not filtered: raise ValueError( f"No questions match difficulty_filter={allowed} in {questions_path}" ) prompts = [ {"prompt": str(item["question_text"])} for item in filtered if item.get("question_text") ] if not prompts: raise ValueError(f"No usable question_text values found in {questions_path}") return prompts def validate_no_data_leak( train_path: str, eval_path: str, ) -> None: """Assert zero question overlap between train and eval sets. Raises ------ ValueError If any question text appears in both files. """ train = json.loads(Path(train_path).read_text(encoding="utf-8")) eval_ = json.loads(Path(eval_path).read_text(encoding="utf-8")) train_qs = {q["question_text"] for q in train if "question_text" in q} eval_qs = {q["question_text"] for q in eval_ if "question_text" in q} overlap = train_qs & eval_qs if overlap: examples = list(overlap)[:3] raise ValueError( f"Data leak: {len(overlap)} questions appear in both train and eval. " f"Examples: {examples}" ) def load_model_and_tokenizer(model_name: str) -> tuple[Any, Any]: """Load HuggingFace tokenizer and model with fail-fast errors.""" try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) except Exception as exc: # pragma: no cover - covered by monkeypatched tests raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc return model, tokenizer