Spaces:
Sleeping
Sleeping
| """Drug Target Validation Environment. | |
| Implements the OpenEnv ``Environment`` interface as a POMDP where the | |
| agent issues one structured pharma / bioinformatics step at a time and | |
| ultimately submits a go / no_go validation report. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import ( | |
| ActionType, | |
| DrugTargetAction, | |
| EvidenceDossier, | |
| IntermediateOutput, | |
| OutputType, | |
| ValidationObservation, | |
| ValidationStepRecord, | |
| ValidationTaskSpec, | |
| ) | |
| from server.rules.engine import RuleEngine | |
| from server.rewards.reward import RewardBreakdown, RewardComputer | |
| from server.simulator.latent_state import FullLatentState | |
| from server.simulator.noise import NoiseModel | |
| from server.simulator.transition import ( | |
| ACTION_COSTS, | |
| TransitionEngine, | |
| compute_action_cost, | |
| ) | |
| from server.tasks.generator import TaskGenerator | |
# Per-episode step limit used by ``DrugTargetEnvironment``: once an episode's
# step count reaches this value the episode is forced to end (``done``), even
# if no validation report has been submitted.
MAX_STEPS: int = 30
class DrugTargetEnvironment(Environment):
    """POMDP environment for drug target validation.

    The agent observes ``ValidationObservation`` (partial view) while the
    environment maintains a ``FullLatentState`` (hidden ``TargetProfile``
    plus credit / progress state).

    Each ``step`` checks the action against the ``RuleEngine``, advances the
    latent state through the ``TransitionEngine``, scores the transition with
    the ``RewardComputer`` and folds successful outputs into a running
    ``EvidenceDossier``. An episode ends when the transition engine reports
    ``done`` or when the step count reaches ``max_steps``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
        max_steps: int = MAX_STEPS,
    ) -> None:
        """Create the environment; ``reset()`` must be called before ``step()``.

        Args:
            scenario_name: Scenario name forwarded to the task generator on
                every ``reset`` (``None`` is forwarded as-is).
            domain_randomise: Forwarded to ``TaskGenerator``.
            max_steps: Per-episode step cap; the episode is forced to end
                once the step count reaches it. Defaults to ``MAX_STEPS``.

        Raises:
            ValueError: If ``max_steps`` is smaller than 1.
        """
        # Let the OpenEnv base class initialise its own attributes first.
        super().__init__()
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self._max_steps = max_steps

        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Populated by reset(); None means "no episode started yet".
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[ValidationTaskSpec] = None
        self._scenario_name = scenario_name

        # Simulation / scoring components. The transition engine shares the
        # noise model, so reseeding the noise model in reset() affects it.
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)

        # Per-episode bookkeeping, cleared on every reset().
        self._history: List[ValidationStepRecord] = []
        self._dossier: EvidenceDossier = EvidenceDossier()
        self._evidence_dimensions_covered: List[str] = []
        self._action_history: List[str] = []
        self._submitted_decision: Optional[str] = None
        self._submitted_confidence: Optional[float] = None
        self._cumulative_reward: float = 0.0

    # ── Environment interface ───────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ValidationObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Seed for the noise model and the task generator. When
                ``None``, a fresh non-negative 31-bit seed is drawn from a
                random UUID, so the episode is not reproducible.

        Returns:
            The first ``ValidationObservation`` (reward ``0.0``, not done).
        """
        if seed is None:
            seed = uuid4().int % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        self._latent.rng_seed = seed

        # Clear everything accumulated during the previous episode.
        self._history.clear()
        self._dossier = EvidenceDossier(credits_used=0)
        self._evidence_dimensions_covered.clear()
        self._action_history.clear()
        self._submitted_decision = None
        self._submitted_confidence = None
        self._cumulative_reward = 0.0
        return self._build_observation(reward=0.0, done=False)

    def step(  # type: ignore[override]
        self, action: DrugTargetAction
    ) -> ValidationObservation:
        """Apply one agent action and return the resulting observation.

        The action is checked by the rule engine (hard and soft violations
        are both passed to the transition engine), the latent state is
        advanced, a step reward is computed and — if the episode ends on
        this step — a terminal reward is added.

        Args:
            action: The structured action chosen by the agent.

        Returns:
            The next ``ValidationObservation``. Its ``reward`` is the step
            reward plus any terminal reward. The per-component breakdown is
            in ``step_reward_breakdown`` (terminal components prefixed with
            ``term_``).

        Raises:
            AssertionError: If ``reset()`` has not been called.
        """
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None
        self._state.step_count += 1

        # Snapshot the pre-transition state and history for reward shaping.
        prev_state = self._latent.model_copy(deep=True)
        prev_history = list(self._action_history)

        violations = self._rules.check(
            action,
            self._latent,
            evidence_dimensions_covered=self._evidence_dimensions_covered,
        )
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state
        self._action_history.append(action.action_type.value)

        step_rb = self._rewards.step_reward(
            action,
            prev_state,
            self._latent,
            result.output,
            hard_v,
            soft_v,
            action_history=prev_history,
        )

        self._history.append(ValidationStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            credit_cost=compute_action_cost(action),
        ))
        self._update_discoveries(action, result.output)
        self._dossier.credits_used = self._latent.credits.credits_used

        # Only a successful submission with no hard violations is recorded
        # as the agent's final decision for terminal scoring.
        if (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            and result.output.success
            and not hard_v
        ):
            self._submitted_decision = action.final_decision
            self._submitted_confidence = action.confidence

        done = result.done or self._state.step_count >= self._max_steps

        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                final_decision=self._submitted_decision,
                confidence=self._submitted_confidence,
                action_history=list(self._action_history),
            )

        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward

        # Merge step and terminal components; terminal keys are prefixed
        # with "term_" so they cannot collide with step keys.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    def state(self) -> State:
        """Return the current episode ``State`` (episode id and step count)."""
        # NOTE(review): OpenEnv's ``Environment`` base appears to declare
        # ``state`` as a property (read as ``env.state``). Here it is a plain
        # method. Confirm how the server and clients access it before
        # switching to ``@property``, since that would break ``env.state()``.
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    # ── internal helpers ────────────────────────────────────────────────

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ValidationObservation:
        """Assemble the agent-visible observation for the current state.

        The observation exposes the task spec, the remaining/total credits,
        a deep copy of the dossier and a dumped copy of the step history.
        Hidden latent fields are not exposed.

        ``metadata`` always contains ``episode_id``, ``step`` and
        ``cumulative_reward``. Entries from ``metadata_extra`` are merged on
        top (and may override them).
        """
        assert self._task is not None
        assert self._latent is not None
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ValidationObservation(
            target_gene=self._task.target_gene,
            disease_context=self._task.disease_context,
            indication=self._task.indication,
            credits_remaining=self._latent.credits.credits_remaining,
            credits_total=self._latent.credits.credits_total,
            dossier=self._dossier.model_copy(deep=True),
            pipeline_history=[h.model_dump() for h in self._history],
            available_actions=list(self._task.available_actions),
            step_index=self._state.step_count,
            done=done,
            reward=reward,
            step_reward_breakdown=reward_breakdown or {},
            rule_violations=rule_violations or [],
            latest_output=latest_output,
            metadata=meta,
        )

    def _update_discoveries(
        self,
        action: DrugTargetAction,
        output: IntermediateOutput,
    ) -> None:
        """Fold the latest output into the running ``EvidenceDossier`` and
        the per-dimension coverage tracker.

        Failed outputs are ignored. Each output type is routed to one dossier
        section (keyed by the action type's value, or appended for
        experimental results). It also marks one or more evidence dimensions
        as covered.
        """
        if not output.success:
            return
        # Shallow copy so later edits to the dossier entry don't touch output.data.
        data = dict(output.data or {})
        key = action.action_type.value
        otype = output.output_type

        # Expression / pathway evidence.
        if otype in {
            OutputType.EXPRESSION_RESULT,
            OutputType.DE_RESULT,
            OutputType.PATHWAY_RESULT,
            OutputType.COEXPRESSION_RESULT,
        }:
            self._dossier.expression_findings[key] = data
            self._track_dim("expression")
            if otype == OutputType.PATHWAY_RESULT:
                self._track_dim("pathway")

        # Protein-level evidence.
        if otype in {
            OutputType.STRUCTURE_RESULT,
            OutputType.BINDING_SITE_RESULT,
            OutputType.INTERACTION_RESULT,
            OutputType.DRUGGABILITY_RESULT,
        }:
            self._dossier.protein_findings[key] = data
            if otype in {
                OutputType.DRUGGABILITY_RESULT,
                OutputType.BINDING_SITE_RESULT,
            }:
                self._track_dim("druggability")
            if otype == OutputType.STRUCTURE_RESULT:
                self._track_dim("structure")
            if otype == OutputType.INTERACTION_RESULT:
                self._track_dim("interactions")

        # Clinical evidence (two output types share the clinical section).
        if otype == OutputType.CLINICAL_RESULT:
            self._dossier.clinical_findings[key] = data
            self._track_dim("clinical")
        if otype == OutputType.PATIENT_STRATIFICATION_RESULT:
            self._dossier.clinical_findings[key] = data
            self._track_dim("patient_stratification")

        # Safety evidence.
        if otype in {
            OutputType.TOXICITY_RESULT,
            OutputType.OFF_TARGET_RESULT,
        }:
            self._dossier.safety_findings[key] = data
            if otype == OutputType.TOXICITY_RESULT:
                self._track_dim("toxicity")
            if otype == OutputType.OFF_TARGET_RESULT:
                self._track_dim("off_target")

        # Literature / landscape evidence.
        if otype in {
            OutputType.LITERATURE_RESULT,
            OutputType.EVIDENCE_SYNTHESIS_RESULT,
            OutputType.COMPETITOR_LANDSCAPE_RESULT,
        }:
            self._dossier.literature_findings[key] = data
            self._track_dim("literature")

        # Experimental results are appended (not keyed) so repeats are kept.
        if otype in {
            OutputType.IN_VITRO_RESULT,
            OutputType.IN_VIVO_RESULT,
            OutputType.CRISPR_RESULT,
            OutputType.BIOMARKER_RESULT,
        }:
            entry = {"action": key, **data}
            self._dossier.experimental_results.append(entry)
            if otype == OutputType.IN_VITRO_RESULT:
                self._track_dim("in_vitro")
            if otype == OutputType.IN_VIVO_RESULT:
                self._track_dim("in_vivo")
            if otype == OutputType.CRISPR_RESULT:
                self._track_dim("crispr")
            if otype == OutputType.BIOMARKER_RESULT:
                self._track_dim("biomarker")

        if otype == OutputType.RED_FLAG_NOTE:
            # Normalise to str *before* the membership test: the list stores
            # strings, so comparing a raw non-string note would never match
            # and the same flag would be appended repeatedly.
            note = str(data.get("note", "(no detail)"))
            if note not in self._dossier.flagged_red_flags:
                self._dossier.flagged_red_flags.append(note)

    def _track_dim(self, dim: str) -> None:
        """Record ``dim`` as covered, keeping first-seen order and no duplicates."""
        if dim not in self._evidence_dimensions_covered:
            self._evidence_dimensions_covered.append(dim)
# Public API of this module: the environment class and its step cap.
__all__ = [
    "DrugTargetEnvironment",
    "MAX_STEPS",
]