drugenv-trainer / tests /test_rules.py
anugrahteesdollar's picture
fix: include requirements-train.txt + tests (glob bug)
ad12dda verified
"""Tests for the drug-target-validation rule engine."""
from models import ActionType, DrugTargetAction
from server.rules.engine import RuleEngine
from server.simulator.latent_state import (
CreditState,
FullLatentState,
TargetProfile,
ValidationProgress,
)
def _state(
    *,
    credits_used: int = 0,
    credits_total: int = 50,
    call_counts: dict | None = None,
    progress: dict | None = None,
) -> FullLatentState:
    """Assemble a ``FullLatentState`` fixture for the rule-engine tests.

    Args:
        credits_used: Forwarded to ``CreditState(credits_used=...)``.
        credits_total: Forwarded to ``CreditState(credits_total=...)``.
        call_counts: Optional mapping copied into ``action_call_counts``;
            a falsy value results in an empty dict.
        progress: Optional keyword arguments expanded into
            ``ValidationProgress``; a falsy value means no overrides.

    Returns:
        A state built around a default-constructed ``TargetProfile``.
    """
    target_profile = TargetProfile()
    progress_kwargs = progress if progress else {}
    validation = ValidationProgress(**progress_kwargs)
    budget = CreditState(credits_total=credits_total, credits_used=credits_used)
    # Copy so the caller's mapping is never shared with the state object.
    counts = dict(call_counts) if call_counts else {}
    return FullLatentState(
        target=target_profile,
        progress=validation,
        credits=budget,
        action_call_counts=counts,
    )
class TestResourceConstraints:
    """Checks that credit shortages surface as hard violations."""

    def test_exhausted_credits_blocked(self):
        """QUERY_EXPRESSION with 50 of 50 credits used reports 'credits'."""
        rules = RuleEngine()
        spent_state = _state(credits_used=50)
        action = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)

        found = rules.check(action, spent_state)

        blocking = rules.hard_violations(found)
        assert any("credits" in msg.lower() for msg in blocking)

    def test_insufficient_credits_blocked(self):
        """IN_VIVO_MODEL with 49 of 50 credits used reports 'credits'."""
        rules = RuleEngine()
        nearly_spent_state = _state(credits_used=49, credits_total=50)
        action = DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL)

        found = rules.check(action, nearly_spent_state)

        blocking = rules.hard_violations(found)
        assert any("credits" in msg.lower() for msg in blocking)
class TestSubmissionRules:
    """Checks the rules applied to SUBMIT_VALIDATION_REPORT actions."""

    def test_report_without_evidence_is_blocked(self):
        """An empty evidence list yields a hard violation naming 'evidence'."""
        rules = RuleEngine()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.8,
        )

        found = rules.check(report, _state(), evidence_dimensions_covered=[])

        blocking = rules.hard_violations(found)
        assert any("evidence" in msg.lower() for msg in blocking)

    def test_report_without_decision_is_blocked(self):
        """Omitting final_decision yields a hard 'final_decision' violation."""
        rules = RuleEngine()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            confidence=0.8,
        )

        found = rules.check(
            report,
            _state(),
            evidence_dimensions_covered=["expression"],
        )

        blocking = rules.hard_violations(found)
        assert any("final_decision" in msg.lower() for msg in blocking)

    def test_report_without_confidence_is_blocked(self):
        """Omitting confidence yields a hard violation naming 'confidence'."""
        rules = RuleEngine()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="no_go",
        )

        found = rules.check(
            report,
            _state(),
            evidence_dimensions_covered=["expression"],
        )

        blocking = rules.hard_violations(found)
        assert any("confidence" in msg.lower() for msg in blocking)

    def test_low_confidence_report_is_soft_violation(self):
        """A confidence of 0.10 yields a soft violation naming 'confidence'."""
        rules = RuleEngine()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.10,
        )

        found = rules.check(
            report,
            _state(),
            evidence_dimensions_covered=["expression", "druggability"],
        )

        warnings = rules.soft_violations(found)
        assert any("confidence" in msg.lower() for msg in warnings)
class TestRedundancyAndOrdering:
    """Checks soft violations for repeated and out-of-order actions."""

    def test_running_same_action_three_times_is_soft_blocked(self):
        """QUERY_EXPRESSION with a recorded call count of 2 is 'redundant'."""
        rules = RuleEngine()
        action = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        prior_calls = {ActionType.QUERY_EXPRESSION.value: 2}

        found = rules.check(action, _state(call_counts=prior_calls))

        warnings = rules.soft_violations(found)
        assert any("redundant" in msg.lower() for msg in warnings)

    def test_in_vivo_before_in_vitro_is_soft_violation(self):
        """IN_VIVO_MODEL on a default state yields a soft 'in_vitro' message."""
        rules = RuleEngine()
        action = DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL)

        found = rules.check(action, _state())

        warnings = rules.soft_violations(found)
        assert any("in_vitro" in msg.lower() for msg in warnings)

    def test_toxicity_before_expression_is_soft_violation(self):
        """TOXICITY_PANEL on a default state yields a soft 'expression' message."""
        rules = RuleEngine()
        action = DrugTargetAction(action_type=ActionType.TOXICITY_PANEL)

        found = rules.check(action, _state())

        warnings = rules.soft_violations(found)
        assert any("expression" in msg.lower() for msg in warnings)

    def test_in_vivo_after_in_vitro_is_clean(self):
        """With in_vitro_done set, no soft violation mentions 'in_vitro'."""
        rules = RuleEngine()
        action = DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL)
        after_in_vitro = _state(progress={"in_vitro_done": True})

        found = rules.check(action, after_in_vitro)

        warnings = rules.soft_violations(found)
        # Equivalent to "not any(... in ...)": every message must omit the term.
        assert all("in_vitro" not in msg.lower() for msg in warnings)