"""Unit + integration tests for the composable analyzer rubric system. Covers: 1. Each leaf rubric in isolation (DetectionRubric, MissedScamRubric, FalsePositiveRubric, CalibrationRubric, ExplanationRubric) 2. Non-terminal observations return 0.0 from every rubric 3. AnalyzerRubric composability: weighted sum matches manual calculation 4. ``last_score`` is populated on every child after each call 5. ``named_children`` / ``named_rubrics`` introspection surface 6. Weight overrides propagate correctly and invalid configs raise 7. State-dict round-trip for checkpointing 8. Integration: the rubric wired into ChakravyuhOpenEnv drives the actual ``observation.reward`` and matches the sum of its child scores """ from __future__ import annotations from types import SimpleNamespace import pytest from chakravyuh_env import ( AnalyzerRubric, CalibrationRubric, ChakravyuhAction, ChakravyuhOpenEnv, DetectionRubric, ExplanationRubric, FalsePositiveRubric, MissedScamRubric, ) from chakravyuh_env.rubrics import DEFAULT_WEIGHTS from openenv.core.rubrics import Rubric # --------------------------------------------------------------------------- # Helpers: stub action / observation # --------------------------------------------------------------------------- def _action( score: float = 0.8, signals: list[str] | None = None, explanation: str = "", ) -> ChakravyuhAction: return ChakravyuhAction( score=score, signals=signals or [], explanation=explanation, ) def _obs( *, done: bool = True, outcome: dict | None = None, ): """Duck-typed observation stub — the rubric only reads `.done` and `.outcome`, so a SimpleNamespace is enough for leaf-rubric tests.""" return SimpleNamespace(done=done, outcome=outcome) # --------------------------------------------------------------------------- # Leaf rubric unit tests # --------------------------------------------------------------------------- @pytest.mark.unit class TestDetectionRubric: def test_returns_zero_on_non_terminal(self) -> None: r = DetectionRubric() assert r(_action(), _obs(done=False, outcome=None)) == 0.0 def test_returns_zero_when_not_flagged(self) -> None: r = DetectionRubric() out = {"analyzer_flagged": False, "detected_by_turn": None, "false_positive": False} assert r(_action(), _obs(outcome=out)) == 0.0 def test_returns_one_on_early_flag(self) -> None: r = DetectionRubric() out = {"analyzer_flagged": True, "detected_by_turn": 3, "false_positive": False} assert r(_action(), _obs(outcome=out)) == 1.0 def test_returns_zero_on_late_flag(self) -> None: r = DetectionRubric() out = {"analyzer_flagged": True, "detected_by_turn": 8, "false_positive": False} assert r(_action(), _obs(outcome=out)) == 0.0 def test_returns_zero_on_false_positive(self) -> None: # Flagged early but the episode was benign → NOT a detection win. # Semantic ground truth lives on `is_benign`; the legacy # `false_positive` shortcut produced wrong answers when benign was # not flagged (audit P1-11). Tests now declare ground truth via # `is_benign` directly. r = DetectionRubric() out = { "analyzer_flagged": True, "detected_by_turn": 3, "is_benign": True, "false_positive": True, } assert r(_action(), _obs(outcome=out)) == 0.0 def test_returns_zero_on_benign_not_flagged(self) -> None: # Benign episode that didn't get flagged is not a detection (regression # test for the audit P1-11 fix — legacy code returned 1.0 here when # the action would have set `false_positive=False`). r = DetectionRubric() out = {"analyzer_flagged": False, "detected_by_turn": None, "is_benign": True} assert r(_action(), _obs(outcome=out)) == 0.0 def test_custom_cutoff(self) -> None: r = DetectionRubric(early_cutoff=4) out = {"analyzer_flagged": True, "detected_by_turn": 5, "false_positive": False} assert r(_action(), _obs(outcome=out)) == 0.0 @pytest.mark.unit class TestMissedScamRubric: def test_fires_only_when_missed_and_money_extracted(self) -> None: r = MissedScamRubric() out = {"analyzer_flagged": False, "money_extracted": True} assert r(_action(), _obs(outcome=out)) == 1.0 def test_zero_when_flagged_but_missed_money(self) -> None: r = MissedScamRubric() out = {"analyzer_flagged": True, "money_extracted": True} assert r(_action(), _obs(outcome=out)) == 0.0 def test_zero_when_no_money_extracted(self) -> None: r = MissedScamRubric() out = {"analyzer_flagged": False, "money_extracted": False} assert r(_action(), _obs(outcome=out)) == 0.0 @pytest.mark.unit class TestFalsePositiveRubric: def test_fires_only_when_flagged_benign(self) -> None: r = FalsePositiveRubric() out = {"analyzer_flagged": True, "false_positive": True} assert r(_action(), _obs(outcome=out)) == 1.0 def test_zero_when_not_flagged(self) -> None: r = FalsePositiveRubric() out = {"analyzer_flagged": False, "false_positive": True} assert r(_action(), _obs(outcome=out)) == 0.0 def test_zero_when_flagged_real_scam(self) -> None: r = FalsePositiveRubric() out = {"analyzer_flagged": True, "false_positive": False} assert r(_action(), _obs(outcome=out)) == 0.0 @pytest.mark.unit class TestCalibrationRubric: def test_perfect_score_on_scam(self) -> None: r = CalibrationRubric(scam_target=0.9) out = {"is_benign": False} assert r(_action(score=0.9), _obs(outcome=out)) == pytest.approx(1.0) def test_decreases_with_distance(self) -> None: r = CalibrationRubric(scam_target=0.9) out = {"is_benign": False} # Symmetric 0.1 distance on either side → 0.9 each. s_low = r(_action(score=0.8), _obs(outcome=out)) s_high = r(_action(score=1.0), _obs(outcome=out)) assert s_low == pytest.approx(0.9) assert s_high == pytest.approx(0.9) def test_clips_to_zero(self) -> None: r = CalibrationRubric(scam_target=0.9) out = {"is_benign": False} # |0.0 - 0.9| = 0.9 → 1 - 0.9 = 0.1 assert r(_action(score=0.0), _obs(outcome=out)) == pytest.approx(0.1) def test_benign_flips_target(self) -> None: r = CalibrationRubric(scam_target=0.9, benign_target=0.1) out = {"is_benign": True} assert r(_action(score=0.1), _obs(outcome=out)) == pytest.approx(1.0) assert r(_action(score=0.9), _obs(outcome=out)) == pytest.approx(0.2) def test_rejects_out_of_range_targets(self) -> None: with pytest.raises(ValueError): CalibrationRubric(scam_target=1.5) with pytest.raises(ValueError): CalibrationRubric(benign_target=-0.1) @pytest.mark.unit class TestExplanationRubric: def test_empty_explanation_returns_zero(self) -> None: r = ExplanationRubric() out = {} assert r(_action(explanation=""), _obs(outcome=out)) == 0.0 def test_short_baseline(self) -> None: r = ExplanationRubric() out = {} # 10-char explanation, no matching signal → baseline 0.3 only. assert r(_action(explanation="suspicious"), _obs(outcome=out)) == pytest.approx(0.3) def test_rewards_matching_signal(self) -> None: r = ExplanationRubric() out = {} score = r( _action(signals=["urgency"], explanation="clear urgency cue"), _obs(outcome=out), ) # 0.3 baseline + 0.3 signal match = 0.6 (length < 30) assert score == pytest.approx(0.6) def test_length_bonuses(self) -> None: r = ExplanationRubric() out = {} medium = "This message shows clear signs of scam behaviour via" # ~50 chars long_ = ( "This message shows clear signs of scam behaviour via urgency and " "impersonation patterns common in Indian UPI fraud cases." ) assert len(medium) >= 30 and len(medium) < 60 assert len(long_) >= 60 assert r(_action(explanation=medium), _obs(outcome=out)) == pytest.approx(0.5) assert r(_action(explanation=long_), _obs(outcome=out)) == pytest.approx(0.7) def test_max_clip(self) -> None: r = ExplanationRubric() out = {} long_with_signal = ( "High urgency info_request from impersonated bank with suspicious_link; " "typical Indian UPI fraud pattern, high confidence classification." ) score = r( _action( signals=["urgency", "info_request", "suspicious_link", "impersonation"], explanation=long_with_signal, ), _obs(outcome={}), ) assert score == 1.0 # --------------------------------------------------------------------------- # AnalyzerRubric composability tests # --------------------------------------------------------------------------- @pytest.mark.unit class TestAnalyzerRubricComposability: def test_registers_all_five_children(self) -> None: r = AnalyzerRubric() names = [n for n, _ in r.named_children()] assert names == [ "detection", "missed_scam", "false_positive", "calibration", "explanation", ] # named_rubrics is flat (our children are leaves). named = [n for n, _ in r.named_rubrics()] assert set(named) == set(names) def test_children_are_rubric_instances(self) -> None: r = AnalyzerRubric() for _, child in r.named_children(): assert isinstance(child, Rubric) def test_get_rubric_by_path(self) -> None: r = AnalyzerRubric() assert isinstance(r.get_rubric("detection"), DetectionRubric) assert isinstance(r.get_rubric("calibration"), CalibrationRubric) with pytest.raises(KeyError): r.get_rubric("no_such_child") def test_non_terminal_returns_zero(self) -> None: r = AnalyzerRubric() assert r(_action(), _obs(done=False, outcome=None)) == 0.0 def test_last_scores_populated_after_call(self) -> None: r = AnalyzerRubric() out = { "analyzer_flagged": True, "detected_by_turn": 3, "money_extracted": False, "false_positive": False, "is_benign": False, } r(_action(score=0.9, signals=["urgency"], explanation="urgent info request"), _obs(outcome=out)) last = r.last_scores() for k in ("total", "detection", "missed_scam", "false_positive", "calibration", "explanation"): assert k in last assert last[k] is not None def test_weighted_sum_matches_manual(self) -> None: r = AnalyzerRubric() out = { "analyzer_flagged": True, "detected_by_turn": 3, "money_extracted": False, "false_positive": False, "is_benign": False, } action = _action( score=0.9, signals=["urgency", "info_request"], explanation="High urgency info request with impersonation signals; typical scam.", ) total = r(action, _obs(outcome=out)) # Expected child scores: # detection = 1.0, missed = 0.0, fp = 0.0 # calibration = 1.0 (score matches 0.9 target exactly) # explanation = 0.3 base + 0.3 signal match + 0.2 (>=30) + 0.2 (>=60) = 1.0 expected = ( DEFAULT_WEIGHTS["detection"] * 1.0 + DEFAULT_WEIGHTS["missed_scam"] * 0.0 + DEFAULT_WEIGHTS["false_positive"] * 0.0 + DEFAULT_WEIGHTS["calibration"] * 1.0 + DEFAULT_WEIGHTS["explanation"] * 1.0 ) assert total == pytest.approx(expected) def test_custom_weights_override(self) -> None: new_weights = { "detection": 2.0, "missed_scam": -1.0, "false_positive": -0.5, "calibration": 0.0, "explanation": 0.0, } r = AnalyzerRubric(weights=new_weights) out = { "analyzer_flagged": True, "detected_by_turn": 3, "money_extracted": False, "false_positive": False, "is_benign": False, } total = r(_action(score=0.5, explanation=""), _obs(outcome=out)) assert total == pytest.approx(2.0) def test_missing_weight_raises(self) -> None: with pytest.raises(ValueError, match="missing keys"): AnalyzerRubric(weights={"detection": 1.0}) def test_default_weights_do_not_mutate_on_override(self) -> None: r = AnalyzerRubric(weights={**DEFAULT_WEIGHTS, "detection": 99.0}) assert r.weights["detection"] == 99.0 # Module-level DEFAULT_WEIGHTS must be unchanged. assert DEFAULT_WEIGHTS["detection"] == 1.0 def test_state_dict_round_trip(self) -> None: r = AnalyzerRubric() r.weights["detection"] = 42.0 state = r.state_dict() assert state["weights"]["detection"] == 42.0 r2 = AnalyzerRubric() r2.load_state_dict(state) assert r2.weights["detection"] == 42.0 # --------------------------------------------------------------------------- # Integration with ChakravyuhOpenEnv # --------------------------------------------------------------------------- @pytest.mark.integration class TestRubricIntegration: def test_env_uses_analyzer_rubric_by_default(self) -> None: env = ChakravyuhOpenEnv() assert isinstance(env.rubric, AnalyzerRubric) def test_rubric_reward_matches_terminal_observation_reward(self) -> None: env = ChakravyuhOpenEnv() env.reset(seed=42) obs = env.step(_action(score=0.9, signals=["urgency"], explanation="u")) if not obs.done: obs = env.step(_action(score=0.9, signals=["impersonation"], explanation="i")) # The terminal observation.reward must equal the top-level rubric # last_score (they're literally the same value by construction). assert obs.reward == pytest.approx(env.rubric.last_score) def test_reward_breakdown_includes_all_sub_rubrics(self) -> None: env = ChakravyuhOpenEnv() env.reset(seed=42) obs = env.step(_action(score=0.9)) if not obs.done: obs = env.step(_action(score=0.9)) bd = obs.reward_breakdown assert bd is not None for key in ("detection", "missed_scam", "false_positive", "calibration", "explanation"): assert key in bd def test_custom_rubric_is_used(self) -> None: class ConstantRubric(Rubric): def forward(self, action, observation): return 7.0 if getattr(observation, "done", False) else 0.0 env = ChakravyuhOpenEnv(rubric=ConstantRubric()) env.reset(seed=42) obs = env.step(_action(score=0.9)) if not obs.done: obs = env.step(_action(score=0.9)) assert obs.reward == pytest.approx(7.0) def test_rubric_reset_clears_last_score(self) -> None: env = ChakravyuhOpenEnv() env.reset(seed=42) env.step(_action(score=0.9)) env.step(_action(score=0.9)) assert env.rubric.last_score is not None # reset() must trigger a rubric reset — for stateless Rubrics # last_score stays from the previous episode, but trajectory # rubrics would have their accumulated state cleared. Verify # last_score gets re-populated on the new episode's terminal step. env.reset(seed=99) env.step(_action(score=0.9)) env.step(_action(score=0.9)) # After a full second episode, last_score is freshly set (not # a stale carry-over from the first episode). assert env.rubric.last_score is not None def test_forward_pre_hook_fires_on_every_call(self) -> None: calls: list[str] = [] def pre_hook(rubric, action, observation): calls.append("pre") def post_hook(rubric, action, observation, result): calls.append(f"post:{result:.2f}") env = ChakravyuhOpenEnv() env.rubric.register_forward_pre_hook(pre_hook) env.rubric.register_forward_hook(post_hook) env.reset(seed=42) env.step(_action(score=0.9)) env.step(_action(score=0.9)) # One rubric call per step (2 steps) — pre + post each. assert calls.count("pre") >= 1 assert any(c.startswith("post:") for c in calls) # --------------------------------------------------------------------------- # V2 anti-collapse profile + new leaf rubrics # --------------------------------------------------------------------------- @pytest.mark.unit class TestSignalAccuracyRubric: def test_zero_when_no_expected_signals(self) -> None: from chakravyuh_env import SignalAccuracyRubric r = SignalAccuracyRubric() out = {"analyzer_flagged": True, "expected_signals": ()} assert r(_action(signals=["urgency"]), _obs(outcome=out)) == 0.0 def test_full_match_returns_one(self) -> None: from chakravyuh_env import SignalAccuracyRubric r = SignalAccuracyRubric() out = {"analyzer_flagged": True, "expected_signals": ["urgency", "info_request"]} action = _action(signals=["urgency", "info_request"]) assert r(action, _obs(outcome=out)) == 1.0 def test_partial_match_proportional(self) -> None: from chakravyuh_env import SignalAccuracyRubric r = SignalAccuracyRubric() out = { "analyzer_flagged": True, "expected_signals": ["urgency", "info_request", "impersonation"], } action = _action(signals=["urgency"]) assert abs(r(action, _obs(outcome=out)) - (1 / 3)) < 1e-6 @pytest.mark.unit class TestFormatRubric: def test_zero_on_empty_explanation(self) -> None: from chakravyuh_env import FormatRubric r = FormatRubric() assert r(_action(explanation=""), _obs(outcome={"is_benign": False})) == 0.0 def test_zero_on_non_json_explanation(self) -> None: from chakravyuh_env import FormatRubric r = FormatRubric() action = _action(explanation="just plain text reasoning") assert r(action, _obs(outcome={"is_benign": False})) == 0.0 def test_one_on_well_formed_json_for_scam(self) -> None: from chakravyuh_env import FormatRubric r = FormatRubric() action = _action( score=0.9, explanation='{"score": 0.9, "explanation": "scam"}', ) out = {"is_benign": False} assert r(action, _obs(outcome=out)) == 1.0 def test_zero_on_benign_flagged_as_scam(self) -> None: """v2 anti-collapse: deny format reward when flagging benign as scam.""" from chakravyuh_env import FormatRubric r = FormatRubric() action = _action( score=0.9, # >= 0.5 default threshold → flagged explanation='{"score": 0.9, "explanation": "scam"}', ) out = {"is_benign": True} assert r(action, _obs(outcome=out)) == 0.0 @pytest.mark.unit class TestLengthRubric: def test_zero_on_empty(self) -> None: from chakravyuh_env import LengthRubric r = LengthRubric() assert r(_action(explanation=""), _obs(outcome={})) == 0.0 def test_peak_at_target(self) -> None: from chakravyuh_env import LengthRubric r = LengthRubric(target_tokens=45) # exactly 45 tokens text = " ".join(["word"] * 45) assert r(_action(explanation=text), _obs(outcome={})) == 1.0 def test_negative_above_upper_band(self) -> None: from chakravyuh_env import LengthRubric r = LengthRubric(upper_band=70) # Use a duck-typed action so we can exceed Pydantic's 300-char # explanation limit (the trainer's reward function sees raw model # output, not a validated ChakravyuhAction). text = " ".join(["w"] * 130) # 130 tokens, well above upper_band=70 action = SimpleNamespace(score=0.5, explanation=text, signals=[]) result = r(action, _obs(outcome={})) assert result == -1.0 @pytest.mark.unit class TestAnalyzerRubricV2: def test_default_weights_match_v2(self) -> None: from chakravyuh_env import AnalyzerRubricV2, V2_WEIGHTS r = AnalyzerRubricV2() assert r.weights == V2_WEIGHTS assert r.weights["false_positive"] == -0.8 assert r.weights["calibration"] == 0.5 def test_eight_children_registered(self) -> None: from chakravyuh_env import AnalyzerRubricV2 r = AnalyzerRubricV2() names = {name for name, _ in r.named_children()} assert names == { "detection", "missed_scam", "false_positive", "calibration", "explanation", "signal_accuracy", "format", "length", } def test_weighted_sum_includes_new_leaves(self) -> None: from chakravyuh_env import AnalyzerRubricV2 r = AnalyzerRubricV2() action = _action( score=0.95, signals=["urgency", "info_request"], explanation='{"score": 0.95, "explanation": "OTP urgency info_request impersonation flag"}', ) out = { "analyzer_flagged": True, "detected_by_turn": 3, "is_benign": False, "false_positive": False, "money_extracted": False, "expected_signals": ["urgency", "info_request"], } result = r(action, _obs(outcome=out)) # detection (1 * 1.0) + calibration (~0.95 * 0.5) + # explanation (1 * 0.4) + signal_accuracy (1 * 0.2) + format (1 * 0.15) # plus length contribution ~ 0 (short text). False positive = 0. assert result > 1.5 assert result < 2.5 def test_env_default_is_v2(self) -> None: from chakravyuh_env import AnalyzerRubricV2, ChakravyuhOpenEnv env = ChakravyuhOpenEnv() assert isinstance(env.rubric, AnalyzerRubricV2)