# drugenv-trainer / tests/test_simulator.py
# Last commit by anugrahteesdollar: ad12dda (verified)
# "fix: include requirements-train.txt + tests (glob bug)"
"""Tests for the latent-state simulator modules."""
from models import ActionType, DrugTargetAction, OutputType
from server.simulator.latent_state import (
CreditState,
DataQualityState,
FullLatentState,
TargetProfile,
ValidationProgress,
)
from server.simulator.noise import NoiseModel
from server.simulator.output_generator import OutputGenerator
from server.simulator.transition import TransitionEngine
def _make_state() -> FullLatentState:
    """Build the fixed latent-state fixture shared by the simulator tests.

    The target is expressed at a ``"high_specific"`` level, has a ``"good"``
    binding pocket, a ``"mild"`` toxicity profile and a ``"positive"``
    clinical precedent. Its ground-truth ``correct_decision`` is ``"go"``.
    Data noise is 0.1, progress is default-constructed, and the credit
    budget is 50.
    """
    profile = TargetProfile(
        # Expression / disease relevance
        expression_level="high_specific",
        tissue_specificity=0.8,
        disease_overexpression=3.0,
        # Druggability and binding site
        druggability_score=0.7,
        binding_pocket_quality="good",
        has_known_ligands=True,
        allosteric_site_available=True,
        # Selectivity and off-targets
        selectivity_ratio=8.0,
        off_target_count=2,
        off_target_genes=["OFF1", "OFF2"],
        # Toxicity
        toxicity_profile="mild",
        toxicity_tissues=[],
        # Clinical history
        clinical_precedent="positive",
        clinical_stage_reached="phase2",
        # Experimental readouts
        in_vitro_ic50_nM=10.0,
        in_vivo_efficacy="strong",
        crispr_essentiality=-1.0,
        # Hidden ground truth (presumably what an agent's final report is
        # compared against — confirm in the grader)
        true_viability_score=0.8,
        correct_decision="go",
        key_evidence_dimensions=["expression", "druggability"],
    )
    quality = DataQualityState(noise_level=0.1)
    progress = ValidationProgress()
    budget = CreditState(credits_total=50)
    return FullLatentState(
        target=profile,
        data_quality=quality,
        progress=progress,
        credits=budget,
    )
class TestNoiseModel:
    """Tests for the seeded sampling helpers on ``NoiseModel``."""

    def test_deterministic_with_seed(self):
        # Two models built with the same seed must yield the same first draw.
        model_a = NoiseModel(seed=42)
        model_b = NoiseModel(seed=42)
        sample_a = model_a.sample_qc_metric(0.5, 0.1)
        sample_b = model_b.sample_qc_metric(0.5, 0.1)
        assert sample_a == sample_b

    def test_false_positives(self):
        # Every generated false-positive gene carries the "FP_GENE_" prefix.
        # NOTE(review): this passes vacuously if no genes are generated.
        fake_hits = NoiseModel(seed=0).generate_false_positives(1000, 0.01)
        offenders = [gene for gene in fake_hits if not gene.startswith("FP_GENE_")]
        assert not offenders

    def test_quality_degradation_bounded(self):
        # Repeated degradation draws must stay inside the closed unit interval.
        model = NoiseModel(seed=0)
        draws = [model.quality_degradation(0.9, [0.8, 0.7]) for _ in range(100)]
        assert all(0.0 <= draw <= 1.0 for draw in draws)
class TestOutputGenerator:
    """Tests that ``OutputGenerator.generate`` returns the expected output type."""

    @staticmethod
    def _run(seed, action, step):
        """Generate output for *action* at *step* against a fresh fixture state."""
        generator = OutputGenerator(NoiseModel(seed=seed))
        return generator.generate(action, _make_state(), step)

    def test_query_expression_returns_expression_result(self):
        action = DrugTargetAction(
            action_type=ActionType.QUERY_EXPRESSION,
            parameters={"database": "GTEx"},
        )
        result = self._run(1, action, 1)
        assert result.output_type == OutputType.EXPRESSION_RESULT
        # The fixture target's expression level is "high_specific".
        assert result.data["expression_level"] == "high_specific"

    def test_druggability_screen_returns_druggability_result(self):
        action = DrugTargetAction(action_type=ActionType.DRUGGABILITY_SCREEN)
        result = self._run(42, action, 2)
        assert result.output_type == OutputType.DRUGGABILITY_RESULT
        assert "druggability_score" in result.data

    def test_binding_site_allosteric_flag(self):
        # The fixture target has an allosteric site, so requesting allosteric
        # analysis should report one.
        action = DrugTargetAction(
            action_type=ActionType.BINDING_SITE_ANALYSIS,
            parameters={"include_allosteric": True},
        )
        result = self._run(7, action, 3)
        assert result.output_type == OutputType.BINDING_SITE_RESULT
        assert result.data["allosteric_site_detected"] is True
class TestTransitionEngine:
    """Tests for the state-transition behaviour of ``TransitionEngine.step``."""

    @staticmethod
    def _engine():
        """Return a transition engine backed by a seed-0 noise model."""
        return TransitionEngine(NoiseModel(seed=0))

    def test_progress_flag_set_after_action(self):
        engine = self._engine()
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(_make_state(), query)
        assert outcome.next_state.progress.expression_queried is True

    def test_credit_deduction(self):
        engine = self._engine()
        initial = _make_state()
        experiment = DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL)
        outcome = engine.step(initial, experiment)
        # NOTE(review): assumes an in-vivo model costs 8 credits and that the
        # fixture starts with credits_used == 0 — confirm against the cost table.
        assert outcome.next_state.credits.credits_used == 8

    def test_hard_violation_blocks(self):
        # A hard violation should turn the step into a failed FAILURE_REPORT.
        engine = self._engine()
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(
            _make_state(),
            query,
            hard_violations=["test_block"],
        )
        assert outcome.output.success is False
        assert outcome.output.output_type == OutputType.FAILURE_REPORT

    def test_submit_validation_report_ends_episode(self):
        engine = self._engine()
        report = DrugTargetAction(
            action_type=ActionType.SUBMIT_VALIDATION_REPORT,
            final_decision="go",
            confidence=0.8,
        )
        outcome = engine.step(_make_state(), report)
        assert outcome.done is True

    def test_credit_exhaustion_ends_episode(self):
        engine = self._engine()
        state = _make_state()
        budget = state.credits
        # Leave a single credit; presumes QUERY_EXPRESSION costs at least one.
        budget.credits_used = budget.credits_total - 1
        query = DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION)
        outcome = engine.step(state, query)
        assert outcome.done is True