# File: drugenv-trainer/tests/test_models.py
# Last commit ad12dda (verified), author anugrahteesdollar:
#   "fix: include requirements-train.txt + tests (glob bug)"
"""Tests for the drug-target-validation POMDP schema models."""
import pytest
from models import (
ActionType,
CreditUsage,
DrugTargetAction,
EvidenceDossier,
IntermediateOutput,
OutputType,
ValidationObservation,
ValidationStepRecord,
ValidationTaskSpec,
)
def test_drug_target_action_roundtrip():
    """Dump a DrugTargetAction to a dict and rebuild it from that dict.

    Checks that the enum action type dumps to its string value, that nested
    parameters are carried through, and that the rebuilt action has the same
    action type as the original.
    """
    original = DrugTargetAction(
        action_type=ActionType.QUERY_EXPRESSION,
        parameters={"database": "GTEx"},
        reasoning="Establish tissue baseline",
    )
    dumped = original.model_dump()
    assert dumped["action_type"] == "query_expression"
    params = dumped["parameters"]
    assert params["database"] == "GTEx"
    # Feed the plain dict back into the constructor to exercise deserialization.
    rebuilt = DrugTargetAction(**dumped)
    assert rebuilt.action_type == ActionType.QUERY_EXPRESSION
def test_drug_target_action_terminal_payload():
    """A SUBMIT_VALIDATION_REPORT action keeps the final_decision and confidence it was built with."""
    submit_action = DrugTargetAction(
        action_type=ActionType.SUBMIT_VALIDATION_REPORT,
        reasoning="All evidence gathered",
        final_decision="go",
        confidence=0.82,
    )
    # confidence is a float, so compare approximately rather than exactly.
    expected_confidence = pytest.approx(0.82)
    assert submit_action.final_decision == "go"
    assert submit_action.confidence == expected_confidence
def test_validation_observation_defaults():
    """An observation built with only done/reward gets the expected defaults.

    Expected defaults: step 0, empty pipeline history, 50 of 50 credits, and an
    EvidenceDossier instance as the dossier.
    """
    observation = ValidationObservation(done=False, reward=0.0)
    assert observation.step_index == 0
    assert observation.pipeline_history == []
    # Both credit counters start at the same 50-credit budget.
    for credit_field in ("credits_remaining", "credits_total"):
        assert getattr(observation, credit_field) == 50, credit_field
    assert isinstance(observation.dossier, EvidenceDossier)
def test_intermediate_output_quality_bounds():
    """IntermediateOutput must reject an out-of-range quality_score (here 1.5).

    pydantic's ValidationError is a subclass of ValueError, so expecting
    ValueError still catches the validation failure. Unlike a bare
    ``Exception``, it does not also accept unrelated errors. For example,
    an AttributeError from a renamed ``OutputType`` member would otherwise
    make this test pass without the bound being checked at all.
    """
    with pytest.raises(ValueError):
        IntermediateOutput(
            output_type=OutputType.EXPRESSION_RESULT,
            step_index=1,
            quality_score=1.5,
        )
def test_validation_task_spec_defaults():
    """A default ValidationTaskSpec has the expected defaults.

    Expected: a 50-credit limit, no expected findings, and
    "submit_validation_report" among its available actions.
    """
    spec = ValidationTaskSpec()
    assert spec.credits_limit == 50
    assert spec.expected_findings == []
    actions = spec.available_actions
    assert "submit_validation_report" in actions
def test_validation_step_record_roundtrip():
    """Dump a ValidationStepRecord, rebuild it from the dump, and check nothing changed.

    Checks that the enum action type dumps to its string value and that the
    credit cost is preserved. It then feeds the dumped dict back into the
    constructor and checks that the rebuilt record dumps to the same dict.
    Comparing dumps, rather than the models directly, keeps the check
    independent of how model equality is defined.
    """
    rec = ValidationStepRecord(
        step_index=2,
        action_type=ActionType.DRUGGABILITY_SCREEN,
        parameters={"include_allosteric": True},
        output_summary="Druggability score=0.81",
        output_type=OutputType.DRUGGABILITY_RESULT,
        success=True,
        quality_score=0.9,
        credit_cost=3,
    )
    d = rec.model_dump()
    assert d["action_type"] == "druggability_screen"
    assert d["credit_cost"] == 3
    # Complete the round trip: dict -> model -> dict must be lossless.
    restored = ValidationStepRecord(**d)
    assert restored.model_dump() == d
def test_credit_usage_defaults():
    """A fresh CreditUsage has 50 of 50 credits remaining and none used."""
    usage = CreditUsage()
    expected_values = {
        "credits_total": 50,
        "credits_remaining": 50,
        "credits_used": 0,
    }
    for attr_name, expected in expected_values.items():
        assert getattr(usage, attr_name) == expected, attr_name