| """MolForge environment implementation.""" | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from dataclasses import replace | |
| from typing import Any, Dict, List | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from .actions import MolForgeActionMixin | |
| from .governance import MolForgeGovernanceMixin | |
| from .shared import ( | |
| FRAGMENT_LIBRARY, | |
| SCENARIOS, | |
| SLOT_ORDER, | |
| compute_objective_score, | |
| get_scenario, | |
| ) | |
| from .shared import MolForgeSharedMixin | |
| from .views import MolForgeViewMixin | |
| try: | |
| from ..models import GovernanceStatus, MolForgeAction, MolForgeObservation, MolForgeState, RewardComponent | |
| except ImportError: | |
| from models import GovernanceStatus, MolForgeAction, MolForgeObservation, MolForgeState, RewardComponent | |
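
# Note: the ImportError fallback above lets this module load both as part of
# the package (relative import) and as a top-level module, e.g. when the file
# is served from a flat working directory.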

class MolForgeEnvironment(
    MolForgeActionMixin,
    MolForgeGovernanceMixin,
    MolForgeViewMixin,
    MolForgeSharedMixin,
    Environment,
):
    """Deterministic medicinal-chemistry design environment for OpenEnv."""

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._debug_state_enabled = os.getenv("MOLFORGE_DEBUG_STATE", "").lower() in {"1", "true", "yes"}
        self._training_randomization_enabled = os.getenv("MOLFORGE_TRAINING_RANDOMIZATION", "").lower() in {
            "1",
            "true",
            "yes",
        }
        self._reward_mode = os.getenv("MOLFORGE_REWARD_MODE", "assay_gated").lower()
        self._rng = random.Random(os.getenv("MOLFORGE_RANDOM_SEED", "molforge"))
        self._reset_index = -1
        self._state = MolForgeState(episode_id=str(uuid4()), step_count=0)
        self._scenario = SCENARIOS[0]
        self._molecule: Dict[str, str] = {}
        self._assay_runs: Dict[str, int] = {}
        self._known_assays: List = []
        self._message_log: List = []
        self._history: List[Dict[str, Any]] = []
        self._oracle_log: List[Dict[str, Any]] = []
        self._visited_states: set[str] = set()
        self._last_summary = ""
        self._report_card = ""
        self._reward_total = 0.0
        self._restart_used = False
        self._trap_penalty_active = False
        self._role_metrics = self._empty_role_metrics()
        self._state_path: List[str] = ["[start]"]
        self._last_governance = GovernanceStatus(
            status="ready",
            explanation="Awaiting the first coordinated decision.",
            required_roles=[],
            approvals=[],
            objections=[],
            vetoes=[],
            executable=True,
        )
        # Warm-up reset populates the scenario-dependent fields, then the reset
        # index is rewound so the first external reset() still sees scenario 0.
        self.reset()
        self._reset_index = -1
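
    # Configuration sketch: construction-time behaviour is driven entirely by
    # the environment variables read in __init__ above. For example (the entry
    # point name is a placeholder, not part of this module):
    #
    #   MOLFORGE_TRAINING_RANDOMIZATION=1 MOLFORGE_REWARD_MODE=curriculum \
    #       MOLFORGE_RANDOM_SEED=42 python -m your_server_entrypoint
    #
    # Unset variables fall back to the defaults shown in __init__
    # ("assay_gated" reward mode, deterministic scenarios, seed "molforge").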

    def reset(self) -> MolForgeObservation:
        """Start a new scenario in a deterministic rotation."""
        self._reset_index += 1
        self._scenario = self._select_reset_scenario()
        self._molecule = dict(self._scenario.starting_scaffold)
        self._assay_runs = {}
        self._known_assays = []
        self._message_log = []
        self._history = []
        self._oracle_log = []
        self._visited_states = {self._molecule_signature()}
        self._last_summary = "Episode initialized with a fresh multi-agent review board."
        self._report_card = ""
        self._reward_total = 0.0
        self._restart_used = False
        self._trap_penalty_active = self._scenario.trap_penalty
        self._role_metrics = self._empty_role_metrics()
        self._state_path = ["[start]"]
        self._last_governance = GovernanceStatus(
            status="ready",
            explanation="Lead Chemist should propose the first coordinated action.",
            required_roles=list(self._scenario.required_review_roles),
            approvals=[],
            objections=[],
            vetoes=[],
            executable=True,
        )
        self._state = MolForgeState(
            episode_id=str(uuid4()),
            step_count=0,
            scenario_id=self._scenario.scenario_id,
            difficulty=self._scenario.difficulty,
            state_label="[start]",
            state_path=list(self._state_path),
            coordination_mode=self._scenario.coordination_mode,  # type: ignore[arg-type]
            enabled_roles=list(self._scenario.enabled_roles),
            target_name=self._scenario.target_name,
            current_molecule=self._molecule_signature(),
            remaining_budget=self._scenario.oracle_budget,
            budget_used=0,
            max_budget=self._scenario.oracle_budget,
            visited_states=1,
            known_assay_count=0,
            invalid_action_count=0,
            objection_count=0,
            oracle_call_count=0,
            message_count=0,
            decision_count=0,
            submitted=False,
            reward_total=0.0,
            metadata={},
        )
        self._sync_state_metadata()
        return self._build_observation(reward=0.0, done=False, reward_components=[])

    def _select_reset_scenario(self):
        """Select a deterministic judge scenario or a randomized training variant."""
        scenario = get_scenario(self._reset_index)
        if not self._training_randomization_enabled:
            return scenario
        scenario = self._rng.choice(SCENARIOS)
        budget_scale = self._rng.uniform(0.85, 1.15)
        max_steps_delta = self._rng.choice([-1, 0, 0, 1])
        starting_scaffold = dict(scenario.starting_scaffold)
        if self._rng.random() < 0.35:
            slot = self._rng.choice(SLOT_ORDER)
            choices = [
                fragment
                for fragment in FRAGMENT_LIBRARY[slot]
                if fragment != starting_scaffold[slot]
            ]
            if choices:  # guard against a single-fragment slot library
                starting_scaffold[slot] = self._rng.choice(choices)
        return replace(
            scenario,
            oracle_budget=max(1, int(round(scenario.oracle_budget * budget_scale))),
            max_steps=max(4, scenario.max_steps + max_steps_delta),
            starting_scaffold=starting_scaffold,
        )
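
    # Illustrative outcome of the randomization above: a scenario with
    # oracle_budget=20 and a sampled budget_scale of 0.9 yields a variant with
    # oracle_budget=max(1, round(20 * 0.9)) = 18, while max_steps shifts by at
    # most one decision in either direction (never below 4).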

    def step(self, action: MolForgeAction) -> MolForgeObservation:  # type: ignore[override]
        """Execute one coordinated environment action."""
        reward_components: List[RewardComponent] = []
        done = False
        error_code = ""
        self._state.step_count += 1
        self._state.decision_count += 1
        previous_properties = self._true_properties()
        previous_score = compute_objective_score(previous_properties, self._scenario)
        validation_error = self._validate_action(action)
        if validation_error:
            error_code, message = validation_error
            self._state.invalid_action_count += 1
            self._last_governance = GovernanceStatus(
                status="needs_revision",
                explanation=message,
                required_roles=list(self._scenario.required_review_roles),
                approvals=[],
                objections=[],
                vetoes=[],
                executable=False,
            )
            reward_components.append(
                RewardComponent(
                    name="invalid_action",
                    value=-1.0,
                    explanation=message,
                )
            )
            reward = -1.0
            self._last_summary = message
            self._append_state_label("[invalid]")
        else:
            governance, governance_components, policy_veto = self._assess_governance(
                action, previous_properties
            )
            self._last_governance = governance
            reward_components.extend(governance_components)
            reward = sum(component.value for component in governance_components)
            if policy_veto:
                self._last_summary = governance.explanation
                self._append_state_label("[policy_veto]")
            else:
                self._last_governance.status = "executed"
                action_reward, done = self._execute_action(
                    action, reward_components, previous_properties, previous_score
                )
                reward += action_reward
                if not done:
                    reward += self._evaluate_reasoning_consistency(
                        action,
                        previous_properties,
                        self._true_properties(),
                        reward_components,
                    )
                if done and self._state.submitted:
                    self._append_state_label("[submitted]")
                elif not done:
                    self._append_state_label(f"[decision_{self._state.step_count:02d}]")
        if not done and self._state.step_count >= self._scenario.max_steps:
            done = True
            reward_components.append(
                RewardComponent(
                    name="step_limit",
                    value=-0.3,
                    explanation="Episode ended because the maximum decision horizon was reached.",
                )
            )
            reward -= 0.3
            self._report_card = self._build_report_card(submitted=False)
            self._last_summary = "Max-step termination triggered."
            self._append_state_label("[terminated:max_steps]")
        if not done and self._state.remaining_budget <= 0:
            done = True
            reward_components.append(
                RewardComponent(
                    name="budget_exhausted",
                    value=-0.5,
                    explanation="Episode terminated because the oracle budget reached zero.",
                )
            )
            reward -= 0.5
            self._report_card = self._build_report_card(submitted=False)
            self._last_summary = "Budget exhausted before a valid submission."
            self._append_state_label("[terminated:budget]")
        if done and not self._report_card:
            self._report_card = self._build_report_card(submitted=self._state.submitted)
        if done and not self._state.submitted and self._reward_mode == "curriculum":
            reward += self._curriculum_terminal_progress_reward(reward_components)
        reward = round(reward, 4)
        self._reward_total = round(self._reward_total + reward, 4)
        self._state.reward_total = self._reward_total
        self._state.current_molecule = self._molecule_signature()
        self._state.state_label = self._state_path[-1]
        self._state.state_path = list(self._state_path)
        self._state.visited_states = len(self._visited_states)
        self._state.known_assay_count = len(self._known_assays)
        self._state.last_error_code = error_code
        self._history.append(
            {
                "step": self._state.step_count,
                "action": action.model_dump(exclude_none=True),
                "reward": reward,
                "done": done,
                "molecule": self._molecule_signature(),
                "state_label": self._state.state_label,
                "summary": self._last_summary,
                "governance": self._last_governance.model_dump(),
            }
        )
        if done:
            self._report_card = self._build_report_card(submitted=self._state.submitted)
        self._sync_state_metadata()
        return self._build_observation(
            reward=reward,
            done=done,
            reward_components=reward_components,
        )
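
    # Reward accounting in step() above: a step's reward is the sum of the
    # governance components, action execution, the reasoning-consistency bonus,
    # and any terminal penalty, rounded to 4 decimals before accumulating into
    # reward_total; the full breakdown travels in reward_components.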

    def _curriculum_terminal_progress_reward(self, reward_components: List[RewardComponent]) -> float:
        """Give bounded partial credit for near-miss episodes during RL warmup.

        This intentionally does not change the public submission grader. It only
        makes the training reward less sparse when a model builds evidence or a
        chemically plausible candidate but fails to formally submit.
        """
        grader_scores = self._grade_all()
        progress = (
            0.25 * grader_scores["candidate_score"]
            + 0.25 * grader_scores["constraint_margin_score"]
            + 0.25 * grader_scores["evidence_score"]
            + 0.15 * grader_scores["coordination_score"]
            + 0.10 * grader_scores["budget_score"]
        )
        progress = min(0.75, max(0.0, progress))
        reward_components.append(
            RewardComponent(
                name="curriculum_terminal_progress",
                value=round(progress, 4),
                explanation=(
                    "Bounded warmup reward for non-submitted episodes based on candidate quality, "
                    "constraint margin, evidence coverage, coordination, and budget discipline. "
                    "Official submission_score remains 0.0 without a submit action."
                ),
            )
        )
        missed_nomination_penalty = 0.0
        if (
            grader_scores["evidence_score"] >= 0.99
            and grader_scores["constraint_margin_score"] >= 0.9
            and grader_scores["candidate_score"] >= self._scenario.baseline_to_beat
        ):
            missed_nomination_penalty = -0.25
            reward_components.append(
                RewardComponent(
                    name="curriculum_missed_nomination",
                    value=missed_nomination_penalty,
                    explanation=(
                        "The candidate had a strong evidence package near the decision deadline, "
                        "but the team failed to make a formal submit decision."
                    ),
                )
            )
        return progress + missed_nomination_penalty
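
    # Worked example for the progress formula above (illustrative numbers
    # only): scores of 0.8 (candidate), 0.9 (margin), 1.0 (evidence),
    # 0.6 (coordination), 0.5 (budget) give
    # 0.25*0.8 + 0.25*0.9 + 0.25*1.0 + 0.15*0.6 + 0.10*0.5 = 0.815,
    # which the cap then clamps to 0.75.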

    def state(self) -> MolForgeState:
        """Return the current environment state."""
        return self._state
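

# Minimal smoke-test sketch (assumes the relative imports above resolve, i.e.
# the module is loaded as part of its package; the action schema lives in
# models.py and is not exercised here):
if __name__ == "__main__":
    env = MolForgeEnvironment()
    env.reset()
    snapshot = env.state()
    print(snapshot.scenario_id, snapshot.remaining_budget, snapshot.state_label)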