Spaces:
Sleeping
Sleeping
| """Task generator — produces (ValidationTaskSpec, FullLatentState) pairs | |
| for drug-target-validation episodes. | |
| Supports two modes: | |
| 1. Select from the curated ``SCENARIO_LIBRARY``. | |
| 2. Add procedurally-generated scenarios on top. | |
| """ | |
| from __future__ import annotations | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| from models import ActionType, ValidationTaskSpec | |
| from server.simulator.latent_state import ( | |
| CreditState, | |
| DataQualityState, | |
| FullLatentState, | |
| TargetProfile, | |
| ValidationProgress, | |
| ) | |
| from .scenarios import SCENARIO_LIBRARY, Scenario | |
| from .procedural_generator import generate_procedural_scenarios | |
class TaskGenerator:
    """Generates task + latent-state pairs for environment episodes.

    The generator holds a list of scenarios.  Each :meth:`generate` call
    deep-copies one scenario's task, target and data-quality objects,
    optionally jitters their numeric fields, and wraps them in a fresh
    ``FullLatentState``.
    """

    def __init__(
        self,
        scenarios: Optional[List[Scenario]] = None,
        domain_randomise: bool = True,
    ):
        """Set up the scenario pool.

        Args:
            scenarios: Scenarios to draw from.  When ``None``, the pool is
                ``SCENARIO_LIBRARY`` followed by 20 procedurally generated
                scenarios (fixed generator seed 42).  The list is stored
                as-is, not copied.
            domain_randomise: When true, :meth:`generate` jitters numeric
                fields of the copied task / target / data-quality objects.
        """
        if scenarios is not None:
            self.scenarios = scenarios
        else:
            self.scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(
                n=20, seed=42,
            )
        self.domain_randomise = domain_randomise

    def generate(
        self,
        *,
        seed: Optional[int] = None,
        scenario_name: Optional[str] = None,
    ) -> Tuple[ValidationTaskSpec, FullLatentState]:
        """Build one ``(task, latent_state)`` pair.

        Args:
            seed: Seed for the numpy generator that drives scenario choice
                and domain randomisation.  ``None`` gives a non-reproducible
                draw; the latent state then records ``rng_seed=0``.
            scenario_name: Name of the scenario to use.  When empty or
                ``None`` a scenario is drawn uniformly from the pool.

        Returns:
            The (possibly randomised) task spec and a new latent state whose
            credit total equals the task's final ``credits_limit``.

        Raises:
            ValueError: If ``scenario_name`` matches no scenario, or if a
                random draw is requested while the pool is empty.
        """
        rng = np.random.default_rng(seed)

        if scenario_name:
            scenario = self._find_scenario(scenario_name)
        else:
            if not self.scenarios:
                # numpy would otherwise fail with an opaque "low >= high".
                raise ValueError("Cannot generate a task: no scenarios are loaded.")
            idx = int(rng.integers(0, len(self.scenarios)))
            scenario = self.scenarios[idx]

        # Work on deep copies so the shared scenario objects are never mutated.
        task = scenario.task.model_copy(deep=True)
        target = scenario.target.model_copy(deep=True)
        data_quality = scenario.data_quality.model_copy(deep=True)

        if self.domain_randomise:
            self._randomise(rng, task, target, data_quality)

        # A scenario that lists no actions allows every action type.
        if not task.available_actions:
            task.available_actions = [a.value for a in ActionType]

        # Built after randomisation so the budget reflects the jittered limit.
        latent = FullLatentState(
            target=target,
            data_quality=data_quality,
            progress=ValidationProgress(),
            credits=CreditState(credits_total=task.credits_limit),
            rng_seed=0 if seed is None else seed,
        )
        return task, latent

    def list_scenarios(self) -> List[str]:
        """Return the names of all scenarios in the pool, in pool order."""
        return [s.name for s in self.scenarios]

    # ---- internals ---------------------------------------------------------

    def _find_scenario(self, name: str) -> Scenario:
        """Return the first scenario whose ``name`` equals *name*.

        Raises:
            ValueError: If no scenario has that name; the message lists the
                available names.
        """
        for s in self.scenarios:
            if s.name == name:
                return s
        available = ", ".join(self.list_scenarios())
        raise ValueError(f"Unknown scenario '{name}'. Available: {available}")

    # Declared static: the function takes no ``self`` but is invoked as
    # ``self._randomise(rng, ...)``; without the decorator that call would
    # bind the instance to ``rng`` and fail with a TypeError.
    @staticmethod
    def _randomise(
        rng: np.random.Generator,
        task: ValidationTaskSpec,
        target: TargetProfile,
        data_quality: DataQualityState,
    ) -> None:
        """Light domain randomisation that nudges noise / numerics without
        flipping ``correct_decision`` or ``key_evidence_dimensions``.

        Mutates *task*, *target* and *data_quality* in place; only the
        numeric fields assigned below are touched.  The order of the random
        draws is fixed, so a given seed always yields the same jitter.
        """
        # Credit budget jitter: +/-10 %, never below 15 credits.
        task.credits_limit = int(
            max(15, round(task.credits_limit * float(rng.uniform(0.9, 1.1))))
        )

        # Data-quality jitter: additive Gaussian noise, clipped to a range.
        data_quality.noise_level = float(np.clip(
            data_quality.noise_level + rng.normal(0, 0.02), 0.02, 0.4
        ))
        data_quality.false_positive_rate = float(np.clip(
            data_quality.false_positive_rate + rng.normal(0, 0.01), 0.0, 0.3
        ))
        data_quality.false_negative_rate = float(np.clip(
            data_quality.false_negative_rate + rng.normal(0, 0.01), 0.0, 0.3
        ))
        data_quality.database_coverage = float(np.clip(
            data_quality.database_coverage + rng.normal(0, 0.03), 0.5, 1.0
        ))

        # Target profile numerics — keep categorical fields fixed.
        # Multiplicative uniform jitter, then clipped or floored per field.
        target.tissue_specificity = float(np.clip(
            target.tissue_specificity * float(rng.uniform(0.9, 1.1)), 0.0, 1.0
        ))
        target.disease_overexpression = float(max(
            0.1, target.disease_overexpression * float(rng.uniform(0.85, 1.15))
        ))
        target.druggability_score = float(np.clip(
            target.druggability_score * float(rng.uniform(0.9, 1.1)), 0.0, 1.0
        ))
        target.selectivity_ratio = float(max(
            0.0, target.selectivity_ratio * float(rng.uniform(0.85, 1.15))
        ))
        target.in_vitro_ic50_nM = float(max(
            0.5, target.in_vitro_ic50_nM * float(rng.uniform(0.7, 1.3))
        ))