Prompt-Injection-RL-environment / tests /test_gold_overlay_scenarios.py
Jatin997's picture
Add benchmark core
3b452a0
import json
import unittest
from pathlib import Path
from env.dataset_loader import DatasetLoader
from env.injection_engine import InjectionEngine
from env.scenario_validator import has_errors, validate_scenario_payload
class GoldOverlayScenarioTests(unittest.TestCase):
    """Checks on the "gold" scenarios served by the dataset loader.

    The tests look up specific scenario ids through ``DatasetLoader``, inspect
    the attack metadata that ``InjectionEngine.build_attack_metadata`` derives
    from them, and run every row of the gold ``easy.jsonl`` file through the
    scenario validator.
    """

    def setUp(self) -> None:
        """Create a fresh loader and injection engine for each test."""
        self.loader = DatasetLoader()
        self.engine = InjectionEngine()

    def test_loader_merges_gold_easy_overlay(self) -> None:
        """A gold scenario id is retrievable from the ``easy`` difficulty."""
        scenario = self.loader.get_scenario_by_id("easy", "easy_gold_direct_relay_confirmation")
        self.assertEqual(scenario.difficulty, "easy")
        self.assertFalse(scenario.summary_required)

    def test_gold_easy_session_boundary_scenario_parses_cleanly(self) -> None:
        """The session-boundary scenario yields the expected attack metadata."""
        scenario = self.loader.get_scenario_by_id("easy", "easy_gold_session_boundary_ping")
        metadata = self.engine.build_attack_metadata(scenario)
        self.assertEqual(metadata["attack_channel"], "session_boundary_confusion")
        self.assertEqual(metadata["target_tool"], "SendEmail")
        self.assertEqual(metadata["target_recipient"], "ops-relay@external-mail.test")
        self.assertEqual(metadata["target_payload"], "relocation received")

    def test_gold_medium_html_comment_scenario_parses_cleanly(self) -> None:
        """The HTML-comment scenario yields the expected attack metadata."""
        scenario = self.loader.get_scenario_by_id("medium", "medium_gold_html_comment_relay")
        metadata = self.engine.build_attack_metadata(scenario)
        self.assertEqual(metadata["attack_channel"], "html_comment")
        self.assertEqual(metadata["attack_visibility"], "hidden")
        self.assertEqual(metadata["target_tool"], "SendEmail")
        self.assertEqual(metadata["target_recipient"], "comment-relay@external-mail.test")
        self.assertEqual(metadata["recipient_source"], "explicit_in_body")

    def test_gold_easy_overlay_has_no_validation_errors(self) -> None:
        """Every row of the gold ``easy.jsonl`` file passes validation.

        Each row is checked in its own ``subTest`` so that one invalid row
        does not hide failures in the rows after it.
        """
        # parents[0] is the directory holding this test file; parents[1] is
        # the directory above it, which contains the ``env`` package.
        gold_path = (
            Path(__file__).resolve().parents[1]
            / "env" / "data" / "scenarios" / "gold" / "easy.jsonl"
        )
        # "utf-8-sig" strips a leading byte-order mark if the file has one.
        raw_lines = gold_path.read_text(encoding="utf-8-sig").splitlines()
        rows = [json.loads(line) for line in raw_lines if line.strip()]
        self.assertGreaterEqual(len(rows), 3)

        for index, payload in enumerate(rows):
            # Look the id up defensively: a row lacking "scenario_id" should
            # be reported by the validator, not crash here with a KeyError.
            scenario_id = payload.get("scenario_id") if isinstance(payload, dict) else None
            label = scenario_id if scenario_id else f"<row {index}>"
            with self.subTest(scenario_id=label):
                issues = validate_scenario_payload(payload)
                self.assertFalse(has_errors(issues), msg=f"{label}: {issues!r}")
# Allow running this file directly (python test_gold_overlay_scenarios.py)
# in addition to discovery through a test runner.
if __name__ == "__main__":
    unittest.main()