"""Drug Target Validation Environment. Implements the OpenEnv ``Environment`` interface as a POMDP where the agent issues one structured pharma / bioinformatics step at a time and ultimately submits a go / no_go validation report. """ from __future__ import annotations from typing import Any, Dict, List, Optional from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State from models import ( ActionType, DrugTargetAction, EvidenceDossier, IntermediateOutput, OutputType, ValidationObservation, ValidationStepRecord, ValidationTaskSpec, ) from server.rules.engine import RuleEngine from server.rewards.reward import RewardBreakdown, RewardComputer from server.simulator.latent_state import FullLatentState from server.simulator.noise import NoiseModel from server.simulator.transition import ( ACTION_COSTS, TransitionEngine, compute_action_cost, ) from server.tasks.generator import TaskGenerator MAX_STEPS = 30 class DrugTargetEnvironment(Environment): """POMDP environment for drug target validation. The agent observes ``ValidationObservation`` (partial view) while the environment maintains a ``FullLatentState`` (hidden ``TargetProfile`` plus credit / progress state). """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__( self, scenario_name: Optional[str] = None, *, domain_randomise: bool = True, ) -> None: self._state = State(episode_id=str(uuid4()), step_count=0) self._latent: Optional[FullLatentState] = None self._task: Optional[ValidationTaskSpec] = None self._scenario_name = scenario_name self._noise = NoiseModel() self._engine = TransitionEngine(self._noise) self._rules = RuleEngine() self._rewards = RewardComputer() self._task_gen = TaskGenerator(domain_randomise=domain_randomise) self._history: List[ValidationStepRecord] = [] self._dossier: EvidenceDossier = EvidenceDossier() self._evidence_dimensions_covered: List[str] = [] self._action_history: List[str] = [] self._submitted_decision: Optional[str] = None self._submitted_confidence: Optional[float] = None self._cumulative_reward: float = 0.0 # ── Environment interface ─────────────────────────────────────────── def reset(self, seed: Optional[int] = None) -> ValidationObservation: seed = seed if seed is not None else hash(uuid4()) % (2**31) self._noise.reseed(seed) self._state = State(episode_id=str(uuid4()), step_count=0) self._task, self._latent = self._task_gen.generate( seed=seed, scenario_name=self._scenario_name, ) self._latent.rng_seed = seed self._history.clear() self._dossier = EvidenceDossier( credits_used=0, ) self._evidence_dimensions_covered.clear() self._action_history.clear() self._submitted_decision = None self._submitted_confidence = None self._cumulative_reward = 0.0 return self._build_observation(reward=0.0, done=False) def step( # type: ignore[override] self, action: DrugTargetAction ) -> ValidationObservation: assert self._latent is not None, "Call reset() before step()" assert self._task is not None self._state.step_count += 1 prev_state = self._latent.model_copy(deep=True) prev_history = list(self._action_history) violations = self._rules.check( action, self._latent, evidence_dimensions_covered=self._evidence_dimensions_covered, ) hard_v = self._rules.hard_violations(violations) soft_v = self._rules.soft_violations(violations) result = self._engine.step( self._latent, action, hard_violations=hard_v, soft_violations=soft_v, ) self._latent = result.next_state self._action_history.append(action.action_type.value) step_rb = 
    def step(  # type: ignore[override]
        self, action: DrugTargetAction
    ) -> ValidationObservation:
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None
        self._state.step_count += 1
        prev_state = self._latent.model_copy(deep=True)
        prev_history = list(self._action_history)

        # Check hard / soft rule violations before applying the transition.
        violations = self._rules.check(
            action,
            self._latent,
            evidence_dimensions_covered=self._evidence_dimensions_covered,
        )
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state
        self._action_history.append(action.action_type.value)

        step_rb = self._rewards.step_reward(
            action,
            prev_state,
            self._latent,
            result.output,
            hard_v,
            soft_v,
            action_history=prev_history,
        )

        cost = compute_action_cost(action)
        self._history.append(
            ValidationStepRecord(
                step_index=self._state.step_count,
                action_type=action.action_type,
                parameters=action.parameters,
                output_summary=result.output.summary,
                output_type=result.output.output_type,
                success=result.output.success,
                quality_score=result.output.quality_score,
                credit_cost=cost,
            )
        )
        self._update_discoveries(action, result.output)
        self._dossier.credits_used = self._latent.credits.credits_used

        # A submitted report only counts if it succeeded with no hard violations.
        if (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            and result.output.success
            and not hard_v
        ):
            self._submitted_decision = action.final_decision
            self._submitted_confidence = action.confidence

        done = result.done or self._state.step_count >= MAX_STEPS
        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                final_decision=self._submitted_decision,
                confidence=self._submitted_confidence,
                action_history=list(self._action_history),
            )

        # Merge step and terminal components; terminal keys get a ``term_`` prefix.
        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )
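    # Shape of ``step_reward_breakdown`` on a terminal step (illustrative):
    # per-step keys come from ``RewardBreakdown.to_dict()`` and terminal keys
    # are the same names prefixed with ``term_``, e.g. (hypothetical names)
    #     {"evidence_gain": 0.3, "cost_penalty": -0.1, "term_decision": 1.0}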
    @property
    def state(self) -> State:
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    # ── internal helpers ────────────────────────────────────────────────

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ValidationObservation:
        assert self._task is not None
        assert self._latent is not None
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ValidationObservation(
            target_gene=self._task.target_gene,
            disease_context=self._task.disease_context,
            indication=self._task.indication,
            credits_remaining=self._latent.credits.credits_remaining,
            credits_total=self._latent.credits.credits_total,
            dossier=self._dossier.model_copy(deep=True),
            pipeline_history=[h.model_dump() for h in self._history],
            available_actions=list(self._task.available_actions),
            step_index=self._state.step_count,
            done=done,
            reward=reward,
            step_reward_breakdown=reward_breakdown or {},
            rule_violations=rule_violations or [],
            latest_output=latest_output,
            metadata=meta,
        )

    def _update_discoveries(
        self,
        action: DrugTargetAction,
        output: IntermediateOutput,
    ) -> None:
        """Fold the latest output into the running ``EvidenceDossier``
        and the per-dimension coverage tracker."""
        if not output.success:
            return
        data = dict(output.data or {})

        if output.output_type in {
            OutputType.EXPRESSION_RESULT,
            OutputType.DE_RESULT,
            OutputType.PATHWAY_RESULT,
            OutputType.COEXPRESSION_RESULT,
        }:
            self._dossier.expression_findings[action.action_type.value] = data
            self._track_dim("expression")
            if output.output_type == OutputType.PATHWAY_RESULT:
                self._track_dim("pathway")

        if output.output_type in {
            OutputType.STRUCTURE_RESULT,
            OutputType.BINDING_SITE_RESULT,
            OutputType.INTERACTION_RESULT,
            OutputType.DRUGGABILITY_RESULT,
        }:
            self._dossier.protein_findings[action.action_type.value] = data
            if output.output_type in {
                OutputType.DRUGGABILITY_RESULT,
                OutputType.BINDING_SITE_RESULT,
            }:
                self._track_dim("druggability")
            if output.output_type == OutputType.STRUCTURE_RESULT:
                self._track_dim("structure")
            if output.output_type == OutputType.INTERACTION_RESULT:
                self._track_dim("interactions")

        if output.output_type == OutputType.CLINICAL_RESULT:
            self._dossier.clinical_findings[action.action_type.value] = data
            self._track_dim("clinical")
        if output.output_type == OutputType.PATIENT_STRATIFICATION_RESULT:
            self._dossier.clinical_findings[action.action_type.value] = data
            self._track_dim("patient_stratification")

        if output.output_type in {
            OutputType.TOXICITY_RESULT,
            OutputType.OFF_TARGET_RESULT,
        }:
            self._dossier.safety_findings[action.action_type.value] = data
            if output.output_type == OutputType.TOXICITY_RESULT:
                self._track_dim("toxicity")
            if output.output_type == OutputType.OFF_TARGET_RESULT:
                self._track_dim("off_target")

        if output.output_type in {
            OutputType.LITERATURE_RESULT,
            OutputType.EVIDENCE_SYNTHESIS_RESULT,
            OutputType.COMPETITOR_LANDSCAPE_RESULT,
        }:
            self._dossier.literature_findings[action.action_type.value] = data
            self._track_dim("literature")

        if output.output_type in {
            OutputType.IN_VITRO_RESULT,
            OutputType.IN_VIVO_RESULT,
            OutputType.CRISPR_RESULT,
            OutputType.BIOMARKER_RESULT,
        }:
            entry = {"action": action.action_type.value, **data}
            self._dossier.experimental_results.append(entry)
            if output.output_type == OutputType.IN_VITRO_RESULT:
                self._track_dim("in_vitro")
            if output.output_type == OutputType.IN_VIVO_RESULT:
                self._track_dim("in_vivo")
            if output.output_type == OutputType.CRISPR_RESULT:
                self._track_dim("crispr")
            if output.output_type == OutputType.BIOMARKER_RESULT:
                self._track_dim("biomarker")

        if output.output_type == OutputType.RED_FLAG_NOTE:
            # Normalise to str before the membership check so non-string
            # notes cannot slip past the de-duplication.
            note = str(data.get("note", "(no detail)"))
            if note not in self._dossier.flagged_red_flags:
                self._dossier.flagged_red_flags.append(note)

    def _track_dim(self, dim: str) -> None:
        if dim not in self._evidence_dimensions_covered:
            self._evidence_dimensions_covered.append(dim)


__all__ = ["DrugTargetEnvironment", "MAX_STEPS"]
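

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, assuming the ``openenv`` and
    # ``server`` packages are importable): reset one episode and print the
    # initial partial observation. Stepping is omitted because constructing
    # a ``DrugTargetAction`` requires fields defined in ``models`` that are
    # not shown in this module.
    env = DrugTargetEnvironment()
    obs = env.reset(seed=0)
    print(
        f"target={obs.target_gene} disease={obs.disease_context} "
        f"credits={obs.credits_remaining}/{obs.credits_total} "
        f"actions={len(obs.available_actions)}"
    )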