Spaces:
Sleeping
Sleeping
| """ | |
| Offline step-chain extraction cache builder. | |
| Run this once before training to pre-extract structured step chains from all | |
| grounded training data (GSM8K + MATH). The resulting cache file is passed to | |
| run_grpo_training.py via --extraction-cache so the extractor LLM is never | |
| called for fixed training examples β only novel self-play solutions require | |
| live extraction during training. | |
| Usage | |
| ----- | |
| python scripts/precompute_extraction_cache.py \\ | |
| --gsm8k-data data/sft/gsm8k_sft.jsonl \\ | |
| --math-data data/sft/math_sft.jsonl \\ | |
| --output-cache data/extraction_cache.json \\ | |
| --extractor-model Qwen/Qwen2.5-0.5B-Instruct \\ | |
| --device cuda | |
| Cache key: md5(question + "\\n" + solution) β keying on both prevents | |
| collisions when two MATH problems share identical solution text. | |
| Entries for solutions the extractor cannot parse are stored with | |
| success=False so training never re-attempts and correctly penalises them. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import pathlib | |
| import sys | |
| from typing import List, Tuple | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s %(levelname)-8s %(message)s", | |
| handlers=[logging.StreamHandler(sys.stdout)], | |
| ) | |
| logger = logging.getLogger(__name__) | |
| def load_jsonl(path: str) -> list[dict]: | |
| records: list[dict] = [] | |
| with open(path, encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| try: | |
| records.append(json.loads(line)) | |
| except json.JSONDecodeError: | |
| pass | |
| return records | |
| def collect_qa_pairs(records: list[dict]) -> List[Tuple[str, str]]: | |
| """ | |
| Extract (question, solution) pairs from dataset records. | |
| Returns pairs where both fields are non-empty. Falls back to empty | |
| string for the question when only the solution field is present. | |
| """ | |
| pairs: List[Tuple[str, str]] = [] | |
| for rec in records: | |
| sol = ( | |
| rec.get("solution") | |
| or rec.get("output") | |
| or rec.get("response") | |
| or "" | |
| ) | |
| q = ( | |
| rec.get("question") | |
| or rec.get("problem") | |
| or rec.get("input") | |
| or "" | |
| ) | |
| if sol.strip(): | |
| pairs.append((q.strip(), sol.strip())) | |
| return pairs | |
| def main() -> None: | |
| parser = argparse.ArgumentParser( | |
| description="Pre-extract step chains for grounded training data." | |
| ) | |
| parser.add_argument( | |
| "--gsm8k-data", required=True, | |
| help="Path to GSM8K training JSONL (e.g. data/sft/gsm8k_sft.jsonl).", | |
| ) | |
| parser.add_argument( | |
| "--math-data", default=None, | |
| help="Optional path to MATH training JSONL. If provided, those solutions " | |
| "are also extracted and added to the cache.", | |
| ) | |
| parser.add_argument( | |
| "--output-cache", required=True, | |
| help="Destination JSON file for the extraction cache.", | |
| ) | |
| parser.add_argument( | |
| "--extractor-model", default="Qwen/Qwen2.5-0.5B-Instruct", | |
| help="HuggingFace model ID for the step chain extractor. Default Qwen/Qwen2.5-0.5B-Instruct.", | |
| ) | |
| parser.add_argument( | |
| "--device", default="cuda", | |
| help="Device for the extractor model (default: cuda).", | |
| ) | |
| parser.add_argument( | |
| "--batch-size", type=int, default=1, | |
| help="Reserved for future batched extraction. Currently always 1.", | |
| ) | |
| args = parser.parse_args() | |
| # ββ Load data βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Loading GSM8K data from: %s", args.gsm8k_data) | |
| gsm8k_records = load_jsonl(args.gsm8k_data) | |
| qa_pairs = collect_qa_pairs(gsm8k_records) | |
| logger.info("GSM8K: %d (question, solution) pairs", len(qa_pairs)) | |
| if args.math_data: | |
| logger.info("Loading MATH data from: %s", args.math_data) | |
| math_records = load_jsonl(args.math_data) | |
| math_pairs = collect_qa_pairs(math_records) | |
| logger.info("MATH: %d (question, solution) pairs", len(math_pairs)) | |
| qa_pairs += math_pairs | |
| if not qa_pairs: | |
| logger.error( | |
| "No solutions found in provided files. " | |
| "Check field names (question/problem/input + solution/output/response)." | |
| ) | |
| sys.exit(1) | |
| # Deduplicate by (question, solution) content | |
| # Two different MATH problems can have identical solution text but different | |
| # questions β the question+solution key keeps them distinct in the cache. | |
| seen: set = set() | |
| unique_pairs: List[Tuple[str, str]] = [] | |
| for q, sol in qa_pairs: | |
| key = (q, sol) | |
| if key not in seen: | |
| seen.add(key) | |
| unique_pairs.append((q, sol)) | |
| logger.info( | |
| "Total: %d pairs (%d unique after dedup)", len(qa_pairs), len(unique_pairs) | |
| ) | |
| # ββ Load extractor ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) | |
| from src.rl.unified_accuracy import StepChainExtractor | |
| extractor = StepChainExtractor( | |
| model_name=args.extractor_model, | |
| device=args.device, | |
| cache_path=args.output_cache, # load existing cache if present (resume) | |
| ) | |
| # ββ Build cache βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| already_cached = len(extractor._cache) | |
| if already_cached: | |
| logger.info("Resuming: %d entries already in cache", already_cached) | |
| extractor.build_cache(unique_pairs) | |
| # ββ Save ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| extractor.save_cache() | |
| logger.info( | |
| "Done. Cache contains %d entries β %s", | |
| len(extractor._cache), | |
| args.output_cache, | |
| ) | |
| if __name__ == "__main__": | |
| main() | |