"""Tests for the latent-state simulator modules.""" from models import ActionType, DrugTargetAction, OutputType from server.simulator.latent_state import ( CreditState, DataQualityState, FullLatentState, TargetProfile, ValidationProgress, ) from server.simulator.noise import NoiseModel from server.simulator.output_generator import OutputGenerator from server.simulator.transition import TransitionEngine def _make_state() -> FullLatentState: return FullLatentState( target=TargetProfile( expression_level="high_specific", tissue_specificity=0.8, disease_overexpression=3.0, druggability_score=0.7, binding_pocket_quality="good", has_known_ligands=True, allosteric_site_available=True, selectivity_ratio=8.0, off_target_count=2, off_target_genes=["OFF1", "OFF2"], toxicity_profile="mild", toxicity_tissues=[], clinical_precedent="positive", clinical_stage_reached="phase2", in_vitro_ic50_nM=10.0, in_vivo_efficacy="strong", crispr_essentiality=-1.0, true_viability_score=0.8, correct_decision="go", key_evidence_dimensions=["expression", "druggability"], ), data_quality=DataQualityState(noise_level=0.1), progress=ValidationProgress(), credits=CreditState(credits_total=50), ) class TestNoiseModel: def test_deterministic_with_seed(self): n1 = NoiseModel(seed=42) n2 = NoiseModel(seed=42) assert n1.sample_qc_metric(0.5, 0.1) == n2.sample_qc_metric(0.5, 0.1) def test_false_positives(self): n = NoiseModel(seed=0) fps = n.generate_false_positives(1000, 0.01) assert all(g.startswith("FP_GENE_") for g in fps) def test_quality_degradation_bounded(self): n = NoiseModel(seed=0) for _ in range(100): q = n.quality_degradation(0.9, [0.8, 0.7]) assert 0.0 <= q <= 1.0 class TestOutputGenerator: def test_query_expression_returns_expression_result(self): gen = OutputGenerator(NoiseModel(seed=1)) out = gen.generate( DrugTargetAction( action_type=ActionType.QUERY_EXPRESSION, parameters={"database": "GTEx"}, ), _make_state(), 1, ) assert out.output_type == OutputType.EXPRESSION_RESULT assert out.data["expression_level"] == "high_specific" def test_druggability_screen_returns_druggability_result(self): gen = OutputGenerator(NoiseModel(seed=42)) out = gen.generate( DrugTargetAction(action_type=ActionType.DRUGGABILITY_SCREEN), _make_state(), 2, ) assert out.output_type == OutputType.DRUGGABILITY_RESULT assert "druggability_score" in out.data def test_binding_site_allosteric_flag(self): gen = OutputGenerator(NoiseModel(seed=7)) out = gen.generate( DrugTargetAction( action_type=ActionType.BINDING_SITE_ANALYSIS, parameters={"include_allosteric": True}, ), _make_state(), 3, ) assert out.output_type == OutputType.BINDING_SITE_RESULT assert out.data["allosteric_site_detected"] is True class TestTransitionEngine: def test_progress_flag_set_after_action(self): engine = TransitionEngine(NoiseModel(seed=0)) result = engine.step( _make_state(), DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION), ) assert result.next_state.progress.expression_queried is True def test_credit_deduction(self): engine = TransitionEngine(NoiseModel(seed=0)) state = _make_state() result = engine.step( state, DrugTargetAction(action_type=ActionType.IN_VIVO_MODEL), ) assert result.next_state.credits.credits_used == 8 def test_hard_violation_blocks(self): engine = TransitionEngine(NoiseModel(seed=0)) result = engine.step( _make_state(), DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION), hard_violations=["test_block"], ) assert result.output.success is False assert result.output.output_type == OutputType.FAILURE_REPORT def test_submit_validation_report_ends_episode(self): engine = TransitionEngine(NoiseModel(seed=0)) result = engine.step( _make_state(), DrugTargetAction( action_type=ActionType.SUBMIT_VALIDATION_REPORT, final_decision="go", confidence=0.8, ), ) assert result.done is True def test_credit_exhaustion_ends_episode(self): engine = TransitionEngine(NoiseModel(seed=0)) state = _make_state() state.credits.credits_used = state.credits.credits_total - 1 result = engine.step( state, DrugTargetAction(action_type=ActionType.QUERY_EXPRESSION), ) assert result.done is True