molforge / server /shared.py
Adhitya122's picture
Prepare MolForge OpenEnv Docker Space submission
bf9e424 verified
"""Shared imports, constants, and utility mixins for MolForge."""
from __future__ import annotations
import hashlib
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional
try:
from ..models import (
AgentMessage,
AssayReading,
MolForgeAction,
)
from ..scenarios import (
DEFAULT_TOOL_COSTS,
EDITABLE_SLOTS,
FRAGMENT_LIBRARY,
SLOT_ORDER,
SCENARIOS,
ScenarioConfig,
compute_objective_score,
enumerate_candidate_edits,
evaluate_constraint_margins,
evaluate_constraints,
evaluate_molecule,
format_molecule,
get_scenario,
literature_hints,
molecule_to_smiles,
oracle_backend_status,
)
except ImportError:
from models import (
AgentMessage,
AssayReading,
MolForgeAction,
)
from scenarios import (
DEFAULT_TOOL_COSTS,
EDITABLE_SLOTS,
FRAGMENT_LIBRARY,
SLOT_ORDER,
SCENARIOS,
ScenarioConfig,
compute_objective_score,
enumerate_candidate_edits,
evaluate_constraint_margins,
evaluate_constraints,
evaluate_molecule,
format_molecule,
get_scenario,
literature_hints,
molecule_to_smiles,
oracle_backend_status,
)
ROLE_PERMISSIONS: Dict[str, List[str]] = {
"lead_chemist": ["edit", "submit", "restart", "defer"],
"toxicologist": [],
"assay_planner": ["run_assay"],
"process_chemist": [],
}
ROLE_MESSAGE_TYPES: Dict[str, List[str]] = {
"lead_chemist": ["proposal", "revision_request", "submission_recommendation"],
"toxicologist": ["approval", "objection", "risk_flag", "assay_request", "rejection"],
"assay_planner": ["proposal", "approval", "rejection", "assay_request", "submission_recommendation"],
"process_chemist": ["approval", "objection", "risk_flag", "assay_request"],
}
class MolForgeSharedMixin:
"""Utility methods shared across the environment mixins."""
def _merge_assays(self, readings: List[AssayReading]) -> None:
keyed = {
(reading.tool_name, reading.property_name, reading.molecule_signature): reading
for reading in self._known_assays
}
for reading in readings:
keyed[(reading.tool_name, reading.property_name, reading.molecule_signature)] = reading
self._known_assays = list(keyed.values())
def _current_property_estimate(
self,
property_name: str,
molecule_signature: Optional[str] = None,
) -> Optional[float]:
signature = molecule_signature or self._molecule_signature()
for reading in reversed(self._known_assays):
if reading.molecule_signature == signature and reading.property_name == property_name:
return reading.estimate
return None
def _estimate_information_gain(self, tool_name: str) -> float:
current_signature = self._molecule_signature()
prior_runs = self._assay_runs.get(f"{current_signature}::{tool_name}", 0)
base = {
"evaluate_properties": 0.7,
"dock_target": 0.62,
"assay_toxicity": 0.78 if self._scenario.difficulty != "easy" else 0.52,
"estimate_synthesizability": 0.66 if "synth_min" in self._scenario.hard_constraints else 0.42,
"evaluate_novelty": 0.38,
"search_literature": 0.32,
"run_md_simulation": 0.84,
}.get(tool_name, 0.25)
decay = 0.4**prior_runs
return round(base * decay, 4)
def _simulate_action_properties(self, action: MolForgeAction) -> Dict[str, float]:
if action.action_type == "edit" and action.slot:
molecule = dict(self._molecule)
if action.edit_type == "remove":
defaults = {
"warhead": "nitrile",
"hinge": "pyridine",
"solvent_tail": "morpholine",
"back_pocket": "methoxy",
}
molecule[action.slot] = defaults[action.slot]
elif action.fragment:
molecule[action.slot] = action.fragment
return self._evaluate_for_molecule(molecule, self._trap_penalty_active)
if action.action_type == "restart":
return self._evaluate_for_molecule(dict(self._scenario.restart_scaffold), False)
return self._true_properties()
def _record_message(self, message: AgentMessage) -> None:
if not message.message_id:
message.message_id = f"msg_{self._state.step_count:03d}_{len(self._message_log):03d}"
self._message_log.append(deepcopy(message))
self._state.message_count += 1
self._role_metrics[message.sender]["messages_sent"] += 1
if message.message_type in {"objection", "risk_flag", "rejection"}:
self._state.objection_count += 1
def _sync_state_metadata(self) -> None:
self._state.metadata = {
"state_label": self._state.state_label,
"state_path": list(self._state_path),
"trace": deepcopy(self._history),
"message_log": [message.model_dump() for message in self._message_log],
"oracle_log": deepcopy(self._oracle_log),
"role_metrics": deepcopy(self._role_metrics),
"terminal_grader_scores": self._grade_all() if self._state.submitted else {},
}
if self._debug_state_enabled:
self._state.metadata["debug_hidden_properties"] = self._true_properties()
def _true_properties(self) -> Dict[str, float]:
return self._evaluate_for_molecule(self._molecule, self._trap_penalty_active)
def _evaluate_for_molecule(
self,
molecule: Mapping[str, str],
trap_penalty_active: bool,
) -> Dict[str, float]:
return evaluate_molecule(
molecule,
self._scenario.__class__(
**{**self._scenario.__dict__, "trap_penalty": trap_penalty_active}
),
target_shift_active=self._target_shift_active(),
)
def _target_shift_active(self) -> bool:
return bool(
self._scenario.target_shift_step
and self._state.step_count >= self._scenario.target_shift_step
)
def _molecule_signature(self) -> str:
return format_molecule(self._molecule)
def _append_state_label(self, label: str) -> None:
if not self._state_path or self._state_path[-1] != label:
self._state_path.append(label)
def _safety_alerts(self) -> List[str]:
alerts = []
if self._molecule["solvent_tail"] == "dimethylamino":
alerts.append("Dimethylamino tail is a recurring liability for cardiac safety.")
if self._molecule["back_pocket"] == "trifluoromethyl":
alerts.append("Trifluoromethyl group may overshoot lipophilic safety windows.")
if self._molecule["hinge"] == "fluorophenyl" and self._molecule["back_pocket"] == "chloro":
alerts.append("Hydrophobic hinge/back-pocket combination looks safety-negative.")
return alerts
def _route_warnings(self) -> List[str]:
warnings = []
if self._molecule["hinge"] == "quinazoline":
warnings.append("Quinazoline hinge increases route complexity.")
if self._molecule["warhead"] == "vinyl_sulfonamide":
warnings.append("Vinyl sulfonamide warhead is reactive and harder to handle.")
if self._molecule["back_pocket"] == "trifluoromethyl":
warnings.append("CF3 substitution raises cost and scale-up complexity.")
return warnings
@staticmethod
def _empty_role_metrics() -> Dict[str, Dict[str, int]]:
return {
role: {"messages_sent": 0, "correct_messages": 0, "incorrect_messages": 0}
for role in ["lead_chemist", "toxicologist", "assay_planner", "process_chemist"]
}
@staticmethod
def _open_unit_interval(value: float, epsilon: float = 1e-4) -> float:
return round(min(max(value, epsilon), 1.0 - epsilon), 4)
@staticmethod
def _assay_estimate(
signature: str,
tool_name: str,
property_name: str,
runs: int,
true_value: float,
) -> float:
digest = hashlib.sha256(
f"{signature}|{tool_name}|{property_name}|{runs}".encode("utf-8")
).hexdigest()
centered = (int(digest[:8], 16) / 0xFFFFFFFF) - 0.5
noise = centered * (0.16 / runs)
return round(min(max(true_value + noise, 0.0), 1.0), 4)