# Prompt-Injection-RL-environment/scripts/curate_llmail_dataset.py
# Author: Jatin997
# Commit 3b452a0: Add benchmark core
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Directory two levels above this script's resolved path (the parent of the
# folder that contains this file).
ROOT = Path(__file__).resolve().parents[1]
# Prepend ROOT to the module search path so the `env` package imported below
# is looked up there first when this file is run directly as a script.
sys.path.insert(0, str(ROOT))
from env.llmail_curator import (
DIFFICULTIES,
curate_rows,
load_llmail_contexts,
load_rows_from_hf,
load_rows_from_jsonl,
write_curated_scenarios,
)
# Default destination for curated scenario files; used as the fallback value
# of the --output-dir command-line option.
DEFAULT_OUTPUT_DIR = ROOT.joinpath("env", "data", "scenarios", "curated")
def parse_args() -> argparse.Namespace:
    """Define the curation script's command-line options and parse them.

    Options are registered in a fixed order from a table of
    ``(flag, add_argument keyword arguments)`` pairs, so ``--help`` lists them
    in that same order.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()`` for the
        current process arguments.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Curate LLMail submissions into local easy/medium/hard scenario "
            "files for the IDPI mail benchmark."
        ),
    )

    option_table = (
        (
            "--phase",
            {
                "action": "append",
                "choices": ["Phase1", "Phase2"],
                "help": "Repeat to choose one or both Hugging Face splits. Defaults to both.",
            },
        ),
        (
            "--input-jsonl-phase1",
            {"type": Path, "help": "Optional local Phase1 JSONL file."},
        ),
        (
            "--input-jsonl-phase2",
            {"type": Path, "help": "Optional local Phase2 JSONL file."},
        ),
        (
            "--scenarios-json",
            {"type": Path, "help": "Optional local copy of LLMail scenarios.json."},
        ),
        ("--output-dir", {"type": Path, "default": DEFAULT_OUTPUT_DIR}),
        ("--prefix", {"default": ""}),
        ("--max-per-difficulty", {"type": int, "default": 50}),
        (
            "--sample-limit",
            {
                "type": int,
                "default": None,
                "help": "Optional cap per split for quick dry runs.",
            },
        ),
        ("--seed", {"type": int, "default": 42}),
        ("--overwrite", {"action": "store_true"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def rows_for_phase(phase: str, args: argparse.Namespace):
    """Load the rows for one phase from a local JSONL file or Hugging Face.

    Args:
        phase: Split name; ``"Phase1"`` and ``"Phase2"`` each have a matching
            local-file override option.
        args: Parsed command-line options. Reads ``input_jsonl_phase1`` or
            ``input_jsonl_phase2`` (depending on ``phase``) and
            ``sample_limit``.

    Returns:
        Whatever ``load_rows_from_jsonl`` returns when a local file was given
        for this phase, otherwise whatever ``load_rows_from_hf`` returns.
    """
    # Pick the local override that belongs to this phase, if there is one.
    if phase == "Phase1":
        local_file = args.input_jsonl_phase1
    elif phase == "Phase2":
        local_file = args.input_jsonl_phase2
    else:
        local_file = None

    if not local_file:
        # No local file supplied for this phase: fall back to the HF loader.
        return load_rows_from_hf(phase, sample_limit=args.sample_limit)
    return load_rows_from_jsonl(local_file, sample_limit=args.sample_limit)
def main() -> None:
    """Curate the selected phases, merge them per difficulty, and write files.

    Steps:
        1. Parse command-line options.
        2. Load the LLMail contexts (optionally from ``--scenarios-json``).
        3. For every selected phase, load its rows and curate them, then
           append the curated entries to a per-difficulty accumulator.
        4. Write the merged entries via ``write_curated_scenarios`` and print
           one summary line per written file with its line count.
    """
    args = parse_args()

    # ``--phase`` is an "append" option, so the same split may be given more
    # than once. Drop repeats (keeping first-seen order) so a split is never
    # curated twice and merged into the output as duplicates.
    phases = list(dict.fromkeys(args.phase or ["Phase1", "Phase2"]))

    contexts = load_llmail_contexts(local_path=args.scenarios_json)

    # One accumulator list per difficulty, shared across all phases.
    merged = {difficulty: [] for difficulty in DIFFICULTIES}
    for phase in phases:
        curated = curate_rows(
            rows_for_phase(phase, args),
            phase=phase,
            contexts=contexts,
            max_per_difficulty=args.max_per_difficulty,
            seed=args.seed,
        )
        for difficulty in DIFFICULTIES:
            merged[difficulty].extend(curated[difficulty])

    written = write_curated_scenarios(
        merged,
        output_dir=args.output_dir,
        prefix=args.prefix,
        overwrite=args.overwrite,
    )

    for path in written:
        # The reported count is the number of lines in the written file
        # (presumably one scenario per line; confirm against the writer).
        # "utf-8-sig" tolerates a leading byte-order mark if one is present.
        with path.open("r", encoding="utf-8-sig") as handle:
            count = sum(1 for _ in handle)
        print(f"wrote {count:>3} scenarios -> {path}")
# Run the curation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()