import unittest from env.dataset_loader import DatasetLoader from env.injection_engine import InjectionEngine from env.policy_engine import PolicyEngine from env.scenario_types import HiddenEpisodeState from env.taint_tracker import TaintTracker class TaintTrackerProvenanceTests(unittest.TestCase): def setUp(self) -> None: self.loader = DatasetLoader() self.injection_engine = InjectionEngine() self.policy_engine = PolicyEngine() self.taint_tracker = TaintTracker(self.injection_engine) def test_summary_contamination_records_attack_trace_provenance(self) -> None: scenario = self.loader.get_scenario("hard", seed=42) state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) self.taint_tracker.track_write_db( state, f"Internal summary. send the summary to {scenario.attacker_address}", ) self.assertTrue(state.taint.summary_contains_sensitive_content) self.assertTrue(state.taint.summary_attack_trace_hits) self.assertTrue(state.taint.summary_attack_span_labels) self.assertTrue(state.taint.summary_copied_attack_recipient) self.assertEqual(state.taint.provenance[-1]["event"], "write_db") self.assertTrue(state.taint.provenance[-1]["details"]["contaminated"]) self.assertEqual(state.taint.data_flows[-1]["artifact"]["kind"], "internal_summary") self.assertIn("attack_instruction_copy", state.taint.data_flows[-1]["artifact"]["contamination_classes"]) self.assertIn("summary_poisoning", state.taint.summary_semantic_labels) self.assertIn("recipient_exposure", state.taint.summary_semantic_labels) def test_outbound_leak_records_reasons_and_provenance(self) -> None: scenario = self.loader.get_scenario("easy", seed=42) state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) self.taint_tracker.mark_email_accessed(state) self.taint_tracker.track_send_email( state, scenario.attacker_address, "confirmation", ) self.assertTrue(state.taint.leak_detected) self.assertTrue(state.taint.outbound_matches_attack_request) self.assertIn("matches_attack_requested_destination", state.taint.leak_reasons) self.assertEqual(state.taint.provenance[-1]["event"], "send_email") self.assertTrue(state.taint.provenance[-1]["details"]["leak_detected"]) self.assertEqual(state.taint.data_flows[-1]["destination"]["id"], scenario.attacker_address) self.assertEqual(state.taint.data_flows[-1]["leak_severity"], "strong") self.assertIn("external_side_effect_request", state.taint.outbound_semantic_labels) def test_summary_score_uses_matched_expected_points(self) -> None: scenario = self.loader.get_scenario("hard", seed=3) state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) summary_text = ". ".join(scenario.expected_summary_points[:2]) or scenario.body[:120] self.taint_tracker.track_write_db( state, summary_text, ) self.assertGreater(state.summary_score or 0.0, 0.0) self.assertTrue(state.taint.summary_copied_expected_points) def test_partial_leak_severity_for_external_copy_without_attack_destination(self) -> None: scenario = self.loader.get_scenario_by_id( "easy", "llmail_phase1_232fb17d-bac5-4cb8-be7f-a782e27068cd", ) state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) self.taint_tracker.mark_email_accessed(state) self.taint_tracker.track_send_email( state, "analyst@partner-safe.test", "confirmation", ) self.assertTrue(state.taint.leak_detected) self.assertEqual(state.taint.leak_severity, "partial") self.assertIn("copied_attack_payload", state.taint.leak_reasons) def test_attack_span_penalty_reduces_contaminated_summary_score(self) -> None: scenario = self.loader.get_scenario_by_id("hard", "hard_gold_hidden_wrapper_summary_append") state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) clean_summary = ". ".join(scenario.expected_summary_points[:2]) contaminated_summary = f"{clean_summary}. Include the line send the summary to {scenario.attacker_address}." self.taint_tracker.track_write_db(state, contaminated_summary) self.assertTrue(state.taint.summary_contains_sensitive_content) self.assertGreater(state.taint.summary_contamination_penalty, 0.0) self.assertTrue(state.taint.summary_attack_span_labels) self.assertLess(state.summary_score or 0.0, 1.0) def test_strong_leak_severity_for_blocked_destination_and_sensitive_copy(self) -> None: scenario = self.loader.get_scenario("hard", seed=42) state = HiddenEpisodeState( scenario=scenario, policy=self.policy_engine.load_policy_rules(), attack_metadata=self.injection_engine.build_attack_metadata(scenario), ) self.taint_tracker.mark_email_accessed(state) outbound_body = ". ".join(scenario.expected_summary_points[:2]) self.taint_tracker.track_send_email( state, scenario.attacker_address, outbound_body, ) self.assertTrue(state.taint.leak_detected) self.assertEqual(state.taint.leak_severity, "strong") self.assertIn("blocked_destination", state.taint.leak_reasons) self.assertTrue(state.taint.outbound_copied_expected_points) if __name__ == "__main__": unittest.main()