from __future__ import annotations

import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional

from .scenario_types import EmailScenario


class DatasetLoader:
    """Load and cache ``EmailScenario`` records from per-difficulty JSONL files.

    Scenarios for a difficulty are read from ``<data_dir>/<difficulty>.jsonl``
    and, when present, ``<data_dir>/gold/<difficulty>.jsonl``. Each non-blank
    line of those files is one JSON document handed to
    ``EmailScenario.from_dict``.
    """

    def __init__(self, data_dir: Optional[Path] = None) -> None:
        """Choose the directory scenarios are read from.

        Resolution order: the ``data_dir`` argument, then the
        ``IDPI_SCENARIO_DIR`` environment variable (when set and non-empty),
        then ``data/scenarios`` next to this module.

        Args:
            data_dir: Directory containing the scenario files. A plain string
                is accepted as well and converted to a ``Path``.
        """
        env_data_dir = os.getenv("IDPI_SCENARIO_DIR")
        if data_dir:
            # Coerce so a str argument does not break the "/" joins used later.
            self._data_dir = Path(data_dir)
        elif env_data_dir:
            self._data_dir = Path(env_data_dir)
        else:
            self._data_dir = Path(__file__).resolve().parent / "data" / "scenarios"
        # Keyed by lower-cased difficulty; filled lazily by load_scenarios().
        self._cache: Dict[str, List[EmailScenario]] = {}

    def load_scenarios(self, difficulty: str) -> List[EmailScenario]:
        """Return every scenario for ``difficulty``, reading files on first use.

        The difficulty is lower-cased before lookup. Blank lines are skipped,
        and a scenario whose ``scenario_id`` was already seen (in the same or
        an earlier file) is dropped, so the first occurrence wins. The
        returned list is the cached object itself, not a copy.

        Args:
            difficulty: Difficulty name; matched case-insensitively.

        Returns:
            The de-duplicated scenarios in file order.

        Raises:
            FileNotFoundError: If no scenario file exists for the difficulty.
            json.JSONDecodeError: If a line is not valid JSON (a ``ValueError``
                subclass; the message names the file and line).
            ValueError: If a record lacks a required field, or the files
                contain no scenarios at all.
        """
        normalized = difficulty.lower()
        cached = self._cache.get(normalized)
        if cached is not None:
            return cached

        scenario_files = self._scenario_files(normalized)
        if not scenario_files:
            raise FileNotFoundError(
                f"Missing scenario file for difficulty '{normalized}' in {self._data_dir}"
            )

        scenarios: List[EmailScenario] = []
        seen_ids = set()
        for scenario_path in scenario_files:
            # utf-8-sig transparently drops a leading byte-order mark.
            with scenario_path.open("r", encoding="utf-8-sig") as handle:
                for line_number, raw_line in enumerate(handle, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    scenario = self._parse_line(line, scenario_path, line_number)
                    if scenario.scenario_id in seen_ids:
                        continue
                    seen_ids.add(scenario.scenario_id)
                    scenarios.append(scenario)

        if not scenarios:
            raise ValueError(
                f"No scenarios found for difficulty '{normalized}' in {self._data_dir}"
            )
        # Only cache after a fully successful load.
        self._cache[normalized] = scenarios
        return scenarios

    def get_scenario(self, difficulty: str, seed: Optional[int] = None) -> EmailScenario:
        """Pick one scenario for ``difficulty`` at random.

        Args:
            difficulty: Difficulty name, as accepted by ``load_scenarios``.
            seed: When given, the choice is made by a fresh
                ``random.Random(seed)``; otherwise the module-level ``random``
                generator is used.

        Returns:
            The chosen scenario.
        """
        scenarios = self.load_scenarios(difficulty)
        chooser = random.Random(seed) if seed is not None else random
        return chooser.choice(scenarios)

    def get_scenario_by_id(self, difficulty: str, scenario_id: str) -> EmailScenario:
        """Return the scenario of ``difficulty`` whose id equals ``scenario_id``.

        Raises:
            KeyError: If no loaded scenario has that id.
        """
        for scenario in self.load_scenarios(difficulty):
            if scenario.scenario_id == scenario_id:
                return scenario
        raise KeyError(f"Unknown scenario_id '{scenario_id}' for difficulty '{difficulty}'")

    @staticmethod
    def _parse_line(line: str, scenario_path: Path, line_number: int) -> EmailScenario:
        """Decode one JSONL line into a scenario, adding file/line context to errors."""
        try:
            payload = json.loads(line)
        except json.JSONDecodeError as exc:
            # Re-raise the same exception type so existing handlers still match,
            # but say which file and which line of it was malformed. The
            # position appended by JSONDecodeError is relative to that line.
            raise json.JSONDecodeError(
                f"Scenario file {scenario_path} has invalid JSON on line {line_number}: {exc.msg}",
                exc.doc,
                exc.pos,
            ) from exc
        try:
            return EmailScenario.from_dict(payload)
        except KeyError as exc:
            raise ValueError(
                f"Scenario file {scenario_path} is missing field {exc} on line {line_number}"
            ) from exc

    def _scenario_files(self, difficulty: str) -> List[Path]:
        """Return the existing scenario files for ``difficulty``, main file first."""
        file_name = f"{difficulty}.jsonl"
        candidates = [self._data_dir / file_name, self._data_dir / "gold" / file_name]
        return [path for path in candidates if path.exists()]