Bio-EnvRL / models.py
Ev3Dev's picture
Upload folder using huggingface_hub
df98fca verified
"""
Data models for the Bio-Experiment Planning RL Environment.
Defines the POMDP action and observation contracts for a scientific agent
that constructs biological experiment pipelines step-by-step.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from openenv.core.env_server.types import Action, Observation
# ── Action vocabulary ───────────────────────────────────────────────────────
class ActionType(str, Enum):
COLLECT_SAMPLE = "collect_sample"
SELECT_COHORT = "select_cohort"
PREPARE_LIBRARY = "prepare_library"
CULTURE_CELLS = "culture_cells"
PERTURB_GENE = "perturb_gene"
PERTURB_COMPOUND = "perturb_compound"
SEQUENCE_CELLS = "sequence_cells"
RUN_QC = "run_qc"
FILTER_DATA = "filter_data"
NORMALIZE_DATA = "normalize_data"
INTEGRATE_BATCHES = "integrate_batches"
CLUSTER_CELLS = "cluster_cells"
DIFFERENTIAL_EXPRESSION = "differential_expression"
TRAJECTORY_ANALYSIS = "trajectory_analysis"
PATHWAY_ENRICHMENT = "pathway_enrichment"
REGULATORY_NETWORK_INFERENCE = "regulatory_network_inference"
MARKER_SELECTION = "marker_selection"
VALIDATE_MARKER = "validate_marker"
DESIGN_FOLLOWUP = "design_followup_experiment"
REQUEST_SUBAGENT_REVIEW = "request_subagent_review"
SYNTHESIZE_CONCLUSION = "synthesize_conclusion"
WET_LAB_ACTIONS = frozenset({
ActionType.COLLECT_SAMPLE,
ActionType.SELECT_COHORT,
ActionType.PREPARE_LIBRARY,
ActionType.CULTURE_CELLS,
ActionType.PERTURB_GENE,
ActionType.PERTURB_COMPOUND,
ActionType.SEQUENCE_CELLS,
ActionType.VALIDATE_MARKER,
})
COMPUTATIONAL_ACTIONS = frozenset({
ActionType.RUN_QC,
ActionType.FILTER_DATA,
ActionType.NORMALIZE_DATA,
ActionType.INTEGRATE_BATCHES,
ActionType.CLUSTER_CELLS,
ActionType.DIFFERENTIAL_EXPRESSION,
ActionType.TRAJECTORY_ANALYSIS,
ActionType.PATHWAY_ENRICHMENT,
ActionType.REGULATORY_NETWORK_INFERENCE,
ActionType.MARKER_SELECTION,
})
META_ACTIONS = frozenset({
ActionType.DESIGN_FOLLOWUP,
ActionType.REQUEST_SUBAGENT_REVIEW,
ActionType.SYNTHESIZE_CONCLUSION,
})
class SubagentType(str, Enum):
WET_LAB_PLANNER = "wet_lab_planner"
COMPUTATIONAL_ANALYST = "computational_analyst"
OMICS_QC_AGENT = "omics_qc_agent"
CAUSAL_REASONING_AGENT = "causal_reasoning_agent"
BUDGET_SCHEDULER = "budget_scheduler"
BIOLOGICAL_RULE_CHECKER = "biological_rule_checker"
TOOL_EXECUTOR = "tool_executor"
RETROSPECTIVE_CRITIC = "retrospective_critic"
REPORT_SYNTHESIZER = "report_synthesizer"
# ── Action schema ───────────────────────────────────────────────────────────
class ExperimentAction(Action):
"""Structured, compositional action for one experiment / analysis step.
Hybrid representation: discrete *action_type* plus typed arguments,
optional sub-agent / tool invocation, and calibration fields.
"""
action_type: ActionType = Field(
..., description="Discrete experiment or analysis step type"
)
input_targets: List[str] = Field(
default_factory=list,
description="References to prior outputs, samples, or artifacts",
)
method: Optional[str] = Field(
None, description="Specific method or tool (e.g. 'Seurat', 'CellRanger')"
)
parameters: Dict[str, Any] = Field(
default_factory=dict, description="Method-specific parameters"
)
expected_output_type: Optional[str] = Field(
None, description="What the agent expects this step to produce"
)
justification: Optional[str] = Field(
None, description="Scientific rationale for this step"
)
invoked_subagent: Optional[SubagentType] = Field(
None, description="Sub-agent to delegate to, if any"
)
tool_call_spec: Optional[Dict[str, Any]] = Field(
None, description="Structured tool invocation specification"
)
confidence: float = Field(
0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
)
# ── Intermediate outputs ────────────────────────────────────────────────────
class OutputType(str, Enum):
QC_METRICS = "qc_metrics"
COUNT_MATRIX_SUMMARY = "count_matrix_summary"
EMBEDDING_SUMMARY = "embedding_summary"
CLUSTER_RESULT = "cluster_result"
DE_RESULT = "de_result"
PATHWAY_RESULT = "pathway_result"
TRAJECTORY_RESULT = "trajectory_result"
VALIDATION_RESULT = "validation_result"
NETWORK_RESULT = "network_result"
SAMPLE_COLLECTION_RESULT = "sample_collection_result"
LIBRARY_PREP_RESULT = "library_prep_result"
SEQUENCING_RESULT = "sequencing_result"
PERTURBATION_RESULT = "perturbation_result"
CULTURE_RESULT = "culture_result"
COHORT_RESULT = "cohort_result"
FOLLOWUP_DESIGN = "followup_design"
MARKER_RESULT = "marker_result"
FAILURE_REPORT = "failure_report"
SUBAGENT_REPORT = "subagent_report"
CONCLUSION = "conclusion"
class IntermediateOutput(BaseModel):
"""A single simulated output from one pipeline step."""
output_type: OutputType
step_index: int
success: bool = True
quality_score: float = Field(1.0, ge=0.0, le=1.0)
summary: str = ""
data: Dict[str, Any] = Field(default_factory=dict)
uncertainty: float = Field(0.0, ge=0.0, le=1.0)
warnings: List[str] = Field(default_factory=list)
artifacts_available: List[str] = Field(default_factory=list)
# ── Observable state components ─────────────────────────────────────────────
class ResourceUsage(BaseModel):
budget_used: float = 0.0
budget_remaining: float = 100_000.0
time_used_days: float = 0.0
time_remaining_days: float = 180.0
samples_consumed: int = 0
compute_hours_used: float = 0.0
class PipelineStepRecord(BaseModel):
step_index: int
action_type: ActionType
method: Optional[str] = None
parameters: Dict[str, Any] = Field(default_factory=dict)
output_summary: str = ""
output_type: OutputType
success: bool = True
quality_score: float = 1.0
resource_cost: float = 0.0
time_cost_days: float = 0.0
class PaperReference(BaseModel):
"""Metadata for a literature source used to ground a task."""
title: str
citation: Optional[str] = None
doi: Optional[str] = None
pmid: Optional[str] = None
url: Optional[str] = None
class ExpectedFinding(BaseModel):
"""A paper-backed result that the agent should try to recover."""
finding: str
category: str = "claim"
keywords: List[str] = Field(default_factory=list)
class TaskSpec(BaseModel):
"""Specification of the biological problem to solve."""
problem_statement: str = "Unspecified biological problem"
modality: str = "scRNA-seq"
organism: str = "human"
tissue: str = "blood"
conditions: List[str] = Field(default_factory=list)
available_assays: List[str] = Field(default_factory=lambda: [
"10x_chromium", "smart-seq2", "bulk_rna_seq",
"atac-seq", "cite-seq", "spatial_transcriptomics",
])
available_tools: List[str] = Field(default_factory=lambda: [
"CellRanger", "Seurat", "Scanpy", "DESeq2", "GSEA",
"Monocle", "scVelo", "CellChat", "SCENIC",
])
budget_limit: float = 100_000.0
time_limit_days: float = 180.0
prior_observations: List[str] = Field(default_factory=list)
success_criteria: List[str] = Field(default_factory=list)
dataset_metadata: Dict[str, Any] = Field(default_factory=dict)
paper_references: List[PaperReference] = Field(default_factory=list)
expected_findings: List[ExpectedFinding] = Field(default_factory=list)
class ConclusionClaim(BaseModel):
claim: str
evidence_steps: List[int] = Field(default_factory=list)
confidence: float = Field(0.5, ge=0.0, le=1.0)
claim_type: str = "correlational"
supporting_data: Dict[str, Any] = Field(default_factory=dict)
# ── Observation schema ──────────────────────────────────────────────────────
class ExperimentObservation(Observation):
"""Full observable state returned to the agent at each timestep.
Deliberately excludes hidden latent biological truth, hidden failure
conditions, and ground-truth mechanisms.
"""
task: TaskSpec = Field(default_factory=TaskSpec)
step_index: int = 0
pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
available_assays: List[str] = Field(default_factory=list)
available_tools: List[str] = Field(default_factory=list)
resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)
latest_output: Optional[IntermediateOutput] = None
all_outputs: List[IntermediateOutput] = Field(default_factory=list)
discovered_markers: List[str] = Field(default_factory=list)
candidate_mechanisms: List[str] = Field(default_factory=list)
uncertainty_summary: Dict[str, float] = Field(default_factory=dict)
subagent_outputs: List[Dict[str, Any]] = Field(default_factory=list)
conclusions: List[ConclusionClaim] = Field(default_factory=list)
rule_violations: List[str] = Field(default_factory=list)
step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)