# drugenv/server/hackathon_environment.py
# initial: drugenv FastAPI + gradio demo (commit 77e1e28)
"""Drug Target Validation Environment.
Implements the OpenEnv ``Environment`` interface as a POMDP where the
agent issues one structured pharma / bioinformatics step at a time and
ultimately submits a go / no_go validation report.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import (
ActionType,
DrugTargetAction,
EvidenceDossier,
IntermediateOutput,
OutputType,
ValidationObservation,
ValidationStepRecord,
ValidationTaskSpec,
)
from server.rules.engine import RuleEngine
from server.rewards.reward import RewardBreakdown, RewardComputer
from server.simulator.latent_state import FullLatentState
from server.simulator.noise import NoiseModel
from server.simulator.transition import (
ACTION_COSTS,
TransitionEngine,
compute_action_cost,
)
from server.tasks.generator import TaskGenerator
# Hard cap on episode length: step() marks the episode done once
# ``step_count`` reaches this, even without a submitted report.
MAX_STEPS = 30
class DrugTargetEnvironment(Environment):
    """POMDP environment for drug target validation.

    The agent observes ``ValidationObservation`` (a partial view) while the
    environment maintains a ``FullLatentState`` (hidden ``TargetProfile``
    plus credit / progress state).  One call to :meth:`step` corresponds to
    one structured pharma / bioinformatics action; the episode ends when the
    agent submits a validation report or ``MAX_STEPS`` is reached.
    """

    # Each instance keeps all episode state on ``self``, so the serving
    # layer may run several sessions concurrently with separate instances.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Create an environment.

        Args:
            scenario_name: Scenario forwarded to the task generator on every
                reset; ``None`` lets the generator choose.
            domain_randomise: Forwarded to ``TaskGenerator`` to toggle domain
                randomisation of generated tasks.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[ValidationTaskSpec] = None
        self._scenario_name = scenario_name
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)
        # Per-episode bookkeeping, re-initialised by reset().
        self._history: List[ValidationStepRecord] = []
        self._dossier: EvidenceDossier = EvidenceDossier()
        self._evidence_dimensions_covered: List[str] = []
        self._action_history: List[str] = []
        self._submitted_decision: Optional[str] = None
        self._submitted_confidence: Optional[float] = None
        self._cumulative_reward: float = 0.0

    # ── Environment interface ───────────────────────────────────────────
    def reset(self, seed: Optional[int] = None) -> ValidationObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: RNG seed; when ``None`` a fresh random seed is derived
                from a UUID so each un-seeded episode differs.
        """
        seed = seed if seed is not None else hash(uuid4()) % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        self._latent.rng_seed = seed
        self._history.clear()
        self._dossier = EvidenceDossier(
            credits_used=0,
        )
        self._evidence_dimensions_covered.clear()
        self._action_history.clear()
        self._submitted_decision = None
        self._submitted_confidence = None
        self._cumulative_reward = 0.0
        return self._build_observation(reward=0.0, done=False)

    def step(  # type: ignore[override]
        self, action: DrugTargetAction
    ) -> ValidationObservation:
        """Apply one agent action and return the next observation.

        Checks rule violations, advances the latent state through the
        transition engine, records the step, folds any new evidence into
        the dossier, and computes step (plus terminal, when the episode
        ends) reward.
        """
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None
        self._state.step_count += 1
        # Snapshot pre-transition state/history for reward computation.
        prev_state = self._latent.model_copy(deep=True)
        prev_history = list(self._action_history)
        violations = self._rules.check(
            action,
            self._latent,
            evidence_dimensions_covered=self._evidence_dimensions_covered,
        )
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)
        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state
        self._action_history.append(action.action_type.value)
        step_rb = self._rewards.step_reward(
            action,
            prev_state,
            self._latent,
            result.output,
            hard_v,
            soft_v,
            action_history=prev_history,
        )
        cost = compute_action_cost(action)
        self._history.append(ValidationStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            credit_cost=cost,
        ))
        self._update_discoveries(action, result.output)
        self._dossier.credits_used = self._latent.credits.credits_used
        # Only a successful, rule-clean submission counts as the agent's
        # final decision for terminal-reward purposes.
        if (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            and result.output.success
            and not hard_v
        ):
            self._submitted_decision = action.final_decision
            self._submitted_confidence = action.confidence
        done = result.done or self._state.step_count >= MAX_STEPS
        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                final_decision=self._submitted_decision,
                confidence=self._submitted_confidence,
                action_history=list(self._action_history),
            )
        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward
        # Merge step and terminal components; terminal keys are prefixed
        # so they cannot collide with the step-level keys.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})
        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    @property
    def state(self) -> State:
        """Current server-side episode state (id + step count)."""
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    # ── internal helpers ────────────────────────────────────────────────
    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ValidationObservation:
        """Assemble the partial-view observation returned to the agent.

        The dossier is deep-copied so the agent cannot mutate the
        environment's working copy through the observation.
        """
        assert self._task is not None
        assert self._latent is not None
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ValidationObservation(
            target_gene=self._task.target_gene,
            disease_context=self._task.disease_context,
            indication=self._task.indication,
            credits_remaining=self._latent.credits.credits_remaining,
            credits_total=self._latent.credits.credits_total,
            dossier=self._dossier.model_copy(deep=True),
            pipeline_history=[h.model_dump() for h in self._history],
            available_actions=list(self._task.available_actions),
            step_index=self._state.step_count,
            done=done,
            reward=reward,
            step_reward_breakdown=reward_breakdown or {},
            rule_violations=rule_violations or [],
            latest_output=latest_output,
            metadata=meta,
        )

    def _update_discoveries(
        self,
        action: DrugTargetAction,
        output: IntermediateOutput,
    ) -> None:
        """Fold the latest output into the running ``EvidenceDossier`` and
        the per-dimension coverage tracker.

        Failed outputs contribute nothing.  Each output type routes its
        payload into one dossier section and marks the corresponding
        evidence dimension(s) as covered.
        """
        if not output.success:
            return
        data = dict(output.data or {})
        # Expression-family results.
        if output.output_type in {
            OutputType.EXPRESSION_RESULT,
            OutputType.DE_RESULT,
            OutputType.PATHWAY_RESULT,
            OutputType.COEXPRESSION_RESULT,
        }:
            self._dossier.expression_findings[action.action_type.value] = data
            self._track_dim("expression")
            if output.output_type == OutputType.PATHWAY_RESULT:
                self._track_dim("pathway")
        # Protein / structure-family results.
        if output.output_type in {
            OutputType.STRUCTURE_RESULT,
            OutputType.BINDING_SITE_RESULT,
            OutputType.INTERACTION_RESULT,
            OutputType.DRUGGABILITY_RESULT,
        }:
            self._dossier.protein_findings[action.action_type.value] = data
            if output.output_type in {
                OutputType.DRUGGABILITY_RESULT,
                OutputType.BINDING_SITE_RESULT,
            }:
                self._track_dim("druggability")
            if output.output_type == OutputType.STRUCTURE_RESULT:
                self._track_dim("structure")
            if output.output_type == OutputType.INTERACTION_RESULT:
                self._track_dim("interactions")
        # Clinical results.
        if output.output_type == OutputType.CLINICAL_RESULT:
            self._dossier.clinical_findings[action.action_type.value] = data
            self._track_dim("clinical")
        if output.output_type == OutputType.PATIENT_STRATIFICATION_RESULT:
            self._dossier.clinical_findings[action.action_type.value] = data
            self._track_dim("patient_stratification")
        # Safety results.
        if output.output_type in {
            OutputType.TOXICITY_RESULT,
            OutputType.OFF_TARGET_RESULT,
        }:
            self._dossier.safety_findings[action.action_type.value] = data
            if output.output_type == OutputType.TOXICITY_RESULT:
                self._track_dim("toxicity")
            if output.output_type == OutputType.OFF_TARGET_RESULT:
                self._track_dim("off_target")
        # Literature / synthesis results.
        if output.output_type in {
            OutputType.LITERATURE_RESULT,
            OutputType.EVIDENCE_SYNTHESIS_RESULT,
            OutputType.COMPETITOR_LANDSCAPE_RESULT,
        }:
            self._dossier.literature_findings[action.action_type.value] = data
            self._track_dim("literature")
        # Experimental results are appended as a list of tagged entries.
        if output.output_type in {
            OutputType.IN_VITRO_RESULT,
            OutputType.IN_VIVO_RESULT,
            OutputType.CRISPR_RESULT,
            OutputType.BIOMARKER_RESULT,
        }:
            entry = {"action": action.action_type.value, **data}
            self._dossier.experimental_results.append(entry)
            if output.output_type == OutputType.IN_VITRO_RESULT:
                self._track_dim("in_vitro")
            if output.output_type == OutputType.IN_VIVO_RESULT:
                self._track_dim("in_vivo")
            if output.output_type == OutputType.CRISPR_RESULT:
                self._track_dim("crispr")
            if output.output_type == OutputType.BIOMARKER_RESULT:
                self._track_dim("biomarker")
        if output.output_type == OutputType.RED_FLAG_NOTE:
            # Stringify BEFORE the membership test: the dossier stores
            # strings, so comparing the raw note against it would let
            # non-string duplicates slip through.
            note = str(data.get("note", "(no detail)"))
            if note not in self._dossier.flagged_red_flags:
                self._dossier.flagged_red_flags.append(note)

    def _track_dim(self, dim: str) -> None:
        """Record *dim* as covered exactly once, preserving first-seen order."""
        if dim not in self._evidence_dimensions_covered:
            self._evidence_dimensions_covered.append(dim)
# Public API of this module.
__all__ = ["DrugTargetEnvironment", "MAX_STEPS"]