"""Action execution mixin for MolForge.""" from __future__ import annotations from typing import Dict, List, Mapping from .shared import ( DEFAULT_TOOL_COSTS, compute_objective_score, evaluate_constraint_margins, evaluate_constraints, literature_hints, ) try: from ..models import AssayReading, MolForgeAction, RewardComponent except ImportError: from models import AssayReading, MolForgeAction, RewardComponent class MolForgeActionMixin: """Methods that mutate environment state through actions.""" def _execute_action( self, action: MolForgeAction, reward_components: List[RewardComponent], previous_properties: Mapping[str, float], previous_score: float, ) -> tuple[float, bool]: reward = 0.0 done = False if action.action_type == "edit": reward += self._apply_edit(action, reward_components, previous_score) elif action.action_type == "run_assay": reward += self._run_assay(action, reward_components) elif action.action_type == "submit": reward, done = self._submit(reward_components) elif action.action_type == "restart": reward += self._restart(reward_components) elif action.action_type == "defer": reward -= 0.05 reward_components.append( RewardComponent( name="defer", value=-0.05, explanation="Deferring preserves state but lightly penalizes lost project time.", ) ) self._last_summary = "The team deferred action to gather its thoughts." return reward, done def _apply_edit( self, action: MolForgeAction, reward_components: List[RewardComponent], previous_score: float, ) -> float: previous_signature = self._molecule_signature() previous_fragment = self._molecule[action.slot] # type: ignore[index] safe_defaults = { "warhead": "nitrile", "hinge": "pyridine", "solvent_tail": "morpholine", "back_pocket": "methoxy", } if action.edit_type == "remove": self._molecule[action.slot] = safe_defaults[action.slot] # type: ignore[index] else: self._molecule[action.slot] = action.fragment # type: ignore[index] new_signature = self._molecule_signature() new_properties = self._true_properties() new_score = compute_objective_score(new_properties, self._scenario) delta = round(new_score - previous_score, 4) if self._reward_mode == "dense": reward = delta * 2.0 explanation = ( f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}, " f"changing the internal objective score by {delta:+.3f}." ) else: reward = 0.04 if delta > 0 else (-0.04 if delta < 0 else 0.0) explanation = ( f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}. " "Edit feedback is intentionally coarse; assays and terminal graders provide the main signal." ) reward_components.append( RewardComponent( name="edit_delta", value=round(reward, 4), explanation=explanation, ) ) if new_signature in self._visited_states: reward -= 0.35 reward_components.append( RewardComponent( name="loop_penalty", value=-0.35, explanation="This edit revisited a previously explored molecular state.", ) ) else: reward += 0.06 self._visited_states.add(new_signature) reward -= 0.12 reward_components.append( RewardComponent( name="turn_cost", value=-0.12, explanation="Every chemistry edit consumes simulated project time.", ) ) self._last_summary = ( f"Lead Chemist edited {action.slot}; molecule changed from " f"{previous_signature} to {new_signature}." ) return reward def _run_assay( self, action: MolForgeAction, reward_components: List[RewardComponent], ) -> float: tool_name = action.tool_name or "" cost = DEFAULT_TOOL_COSTS[tool_name] self._state.remaining_budget -= cost self._state.budget_used += cost self._state.oracle_call_count += 1 key = f"{self._molecule_signature()}::{tool_name}" runs = self._assay_runs.get(key, 0) + 1 self._assay_runs[key] = runs reward = 0.02 if runs == 1: reward += 0.10 explanation = "First assay on this molecule/tool pair increased observability." else: reward -= 0.08 explanation = "Repeated assay spent budget on the same molecule/tool pair." readings = self._build_assay_readings(tool_name, runs) self._merge_assays(readings) if tool_name == "search_literature": reward += 0.04 if self._reward_mode == "curriculum" and runs == 1: required_props = {"potency", "toxicity"} if "synth_min" in self._scenario.hard_constraints: required_props.add("synth") covered_props = { reading.property_name for reading in readings if reading.property_name in required_props } if covered_props: bonus = 0.08 * len(covered_props) reward += bonus reward_components.append( RewardComponent( name="curriculum_evidence_gate", value=round(bonus, 4), explanation=( "Curriculum reward for collecting first-pass evidence " f"for: {', '.join(sorted(covered_props))}." ), ) ) reward_components.append( RewardComponent( name="assay_information_gain", value=round(reward, 4), explanation=explanation, ) ) reward_components.append( RewardComponent( name="budget_spend", value=round(-cost / max(self._scenario.oracle_budget, 1), 4), explanation=f"Spent {cost} assay budget on {tool_name}.", ) ) reward -= cost / max(self._scenario.oracle_budget, 1) self._oracle_log.append( { "step": self._state.step_count, "tool_name": tool_name, "runs": runs, "molecule": self._molecule_signature(), "cost": cost, "results": [reading.model_dump() for reading in readings], } ) self._last_summary = ( f"Assay Planner executed {tool_name}; {len(readings)} structured assay result(s) are now visible." ) return reward def _submit(self, reward_components: List[RewardComponent]) -> tuple[float, bool]: properties = self._true_properties() final_score = compute_objective_score(properties, self._scenario) constraint_results = evaluate_constraints(properties, self._scenario) constraint_margins = evaluate_constraint_margins(properties, self._scenario) margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1) violation_penalty = round((1.0 - margin_score) * 2.0, 4) hard_constraints_met = all(result[0] for result in constraint_results.values()) budget_efficiency = self._state.remaining_budget / max(self._scenario.oracle_budget, 1) beats_baseline = final_score >= self._scenario.baseline_to_beat current_signature = self._molecule_signature() evidence_requirements = ["potency", "toxicity"] if "synth_min" in self._scenario.hard_constraints: evidence_requirements.append("synth") missing_evidence = [ prop for prop in evidence_requirements if self._current_property_estimate(prop, current_signature) is None ] evidence_met = not missing_evidence post_shift_evidence_met = True if self._scenario.target_shift_step and self._target_shift_active(): post_shift_evidence_met = any( entry["step"] >= self._scenario.target_shift_step and entry["molecule"] == current_signature and any(result["property_name"] == "potency" for result in entry["results"]) for entry in self._oracle_log ) valid_submission = hard_constraints_met and beats_baseline and evidence_met and post_shift_evidence_met reward = final_score * 2.0 if valid_submission else final_score * 0.25 if valid_submission: reward += 3.5 elif not hard_constraints_met: reward -= violation_penalty if not beats_baseline: reward -= 0.6 if not evidence_met: reward -= 1.2 if not post_shift_evidence_met: reward -= 0.8 if valid_submission: reward += max(0.0, budget_efficiency) * 0.7 if self._reward_mode == "curriculum" and evidence_met and post_shift_evidence_met: submit_bonus = 0.35 if hard_constraints_met: submit_bonus += 0.15 reward += submit_bonus self._state.submitted = True self._report_card = self._build_report_card(submitted=True) self._last_summary = ( f"The team submitted a candidate that " f"{'passed' if hard_constraints_met else 'failed'} hard constraints." ) reward_components.extend( [ RewardComponent( name="submission_quality", value=round((final_score * 2.0 if valid_submission else final_score * 0.25), 4), explanation=( "Full scientific quality reward because the submission met constraints, baseline, and evidence gates." if valid_submission else "Only a small quality trace is awarded because the submit action missed a gate." ), ), RewardComponent( name="hard_constraints", value=( 3.5 if valid_submission else (-violation_penalty if not hard_constraints_met else 0.0) ), explanation=( "Large sparse bonus for beating baseline with required current evidence." if valid_submission else "Submission missed constraints, baseline, or evidence requirements; constraint penalty scales with violation severity." ), ), RewardComponent( name="constraint_margin", value=round(margin_score, 4), explanation=( "Proportional hard-constraint score: worse potency, toxicity, or synthesis violations produce lower values." ), ), RewardComponent( name="baseline_gate", value=0.0 if beats_baseline else -0.6, explanation=( "Submitted molecule beat the scenario baseline." if beats_baseline else "Submitted molecule did not beat the scenario baseline." ), ), RewardComponent( name="submission_evidence", value=0.0 if evidence_met else -1.2, explanation=( "Current-molecule potency/toxicity/synthesis evidence was available." if evidence_met else f"Submission lacked current evidence for: {', '.join(missing_evidence)}." ), ), RewardComponent( name="post_shift_evidence", value=0.0 if post_shift_evidence_met else -0.8, explanation=( "Post-shift potency evidence was available for the submitted molecule." if post_shift_evidence_met else "Hard scenario submission lacked post-shift potency evidence for the current molecule." ), ), RewardComponent( name="budget_efficiency", value=round(max(0.0, budget_efficiency) * 0.7, 4) if valid_submission else 0.0, explanation=( "Unused budget is rewarded to discourage wasteful oracle usage." if valid_submission else "Budget efficiency is not awarded to a gated or premature submission." ), ), ] ) if self._reward_mode == "curriculum" and evidence_met and post_shift_evidence_met: reward_components.append( RewardComponent( name="curriculum_evidence_supported_submit", value=round(submit_bonus, 4), explanation=( "Curriculum reward for making a formal submit decision after the required " "current evidence package was available." ), ) ) return reward, True def _restart(self, reward_components: List[RewardComponent]) -> float: self._molecule = dict(self._scenario.restart_scaffold) self._trap_penalty_active = False self._known_assays = [] self._assay_runs = {} self._restart_used = True self._visited_states.add(self._molecule_signature()) self._state.remaining_budget -= 350 self._state.budget_used += 350 reward_components.append( RewardComponent( name="restart_penalty", value=-0.4, explanation="Restarting discards sunk work but switches to a clean scaffold family.", ) ) self._last_summary = ( "The team abandoned the original scaffold series and restarted from a cleaner alternative." ) return -0.4 def _build_assay_readings(self, tool_name: str, runs: int) -> List[AssayReading]: properties = self._true_properties() signature = self._molecule_signature() if tool_name == "evaluate_properties": property_names = ["potency", "novelty"] elif tool_name == "dock_target": property_names = ["potency"] elif tool_name == "assay_toxicity": property_names = ["toxicity"] elif tool_name == "estimate_synthesizability": property_names = ["synth"] elif tool_name == "evaluate_novelty": property_names = ["novelty"] elif tool_name == "search_literature": hint_score = min(0.95, 0.45 + 0.08 * runs) return [ AssayReading( tool_name=tool_name, property_name="literature_signal", estimate=round(hint_score, 4), confidence_low=max(0.0, round(hint_score - 0.08, 4)), confidence_high=min(1.0, round(hint_score + 0.08, 4)), runs=runs, molecule_signature=signature, summary=literature_hints(self._molecule)[0], ) ] else: property_names = ["potency", "toxicity", "synth"] readings = [] for property_name in property_names: true_value = properties[property_name] estimate = self._assay_estimate(signature, tool_name, property_name, runs, true_value) width = max(0.03, 0.18 / runs) readings.append( AssayReading( tool_name=tool_name, property_name=property_name, estimate=estimate, confidence_low=max(0.0, round(estimate - width, 4)), confidence_high=min(1.0, round(estimate + width, 4)), runs=runs, molecule_signature=signature, summary=f"{tool_name} estimated {property_name} with run count {runs}.", ) ) return readings