# molforge/server/actions.py
# (page-scrape residue removed; originally: "Adhitya122 — Prepare MolForge
#  OpenEnv Docker Space submission — bf9e424 verified")
"""Action execution mixin for MolForge."""
from __future__ import annotations
from typing import Dict, List, Mapping
from .shared import (
DEFAULT_TOOL_COSTS,
compute_objective_score,
evaluate_constraint_margins,
evaluate_constraints,
literature_hints,
)
try:
from ..models import AssayReading, MolForgeAction, RewardComponent
except ImportError:
from models import AssayReading, MolForgeAction, RewardComponent
class MolForgeActionMixin:
"""Methods that mutate environment state through actions."""
def _execute_action(
self,
action: MolForgeAction,
reward_components: List[RewardComponent],
previous_properties: Mapping[str, float],
previous_score: float,
) -> tuple[float, bool]:
reward = 0.0
done = False
if action.action_type == "edit":
reward += self._apply_edit(action, reward_components, previous_score)
elif action.action_type == "run_assay":
reward += self._run_assay(action, reward_components)
elif action.action_type == "submit":
reward, done = self._submit(reward_components)
elif action.action_type == "restart":
reward += self._restart(reward_components)
elif action.action_type == "defer":
reward -= 0.05
reward_components.append(
RewardComponent(
name="defer",
value=-0.05,
explanation="Deferring preserves state but lightly penalizes lost project time.",
)
)
self._last_summary = "The team deferred action to gather its thoughts."
return reward, done
def _apply_edit(
self,
action: MolForgeAction,
reward_components: List[RewardComponent],
previous_score: float,
) -> float:
previous_signature = self._molecule_signature()
previous_fragment = self._molecule[action.slot] # type: ignore[index]
safe_defaults = {
"warhead": "nitrile",
"hinge": "pyridine",
"solvent_tail": "morpholine",
"back_pocket": "methoxy",
}
if action.edit_type == "remove":
self._molecule[action.slot] = safe_defaults[action.slot] # type: ignore[index]
else:
self._molecule[action.slot] = action.fragment # type: ignore[index]
new_signature = self._molecule_signature()
new_properties = self._true_properties()
new_score = compute_objective_score(new_properties, self._scenario)
delta = round(new_score - previous_score, 4)
if self._reward_mode == "dense":
reward = delta * 2.0
explanation = (
f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}, "
f"changing the internal objective score by {delta:+.3f}."
)
else:
reward = 0.04 if delta > 0 else (-0.04 if delta < 0 else 0.0)
explanation = (
f"Updated {action.slot} from {previous_fragment} to {self._molecule[action.slot]}. "
"Edit feedback is intentionally coarse; assays and terminal graders provide the main signal."
)
reward_components.append(
RewardComponent(
name="edit_delta",
value=round(reward, 4),
explanation=explanation,
)
)
if new_signature in self._visited_states:
reward -= 0.35
reward_components.append(
RewardComponent(
name="loop_penalty",
value=-0.35,
explanation="This edit revisited a previously explored molecular state.",
)
)
else:
reward += 0.06
self._visited_states.add(new_signature)
reward -= 0.12
reward_components.append(
RewardComponent(
name="turn_cost",
value=-0.12,
explanation="Every chemistry edit consumes simulated project time.",
)
)
self._last_summary = (
f"Lead Chemist edited {action.slot}; molecule changed from "
f"{previous_signature} to {new_signature}."
)
return reward
    def _run_assay(
        self,
        action: MolForgeAction,
        reward_components: List[RewardComponent],
    ) -> float:
        """Execute an assay tool against the current molecule and return the shaped reward.

        Side effects: debits the oracle budget counters, tracks run counts per
        (molecule, tool) pair, merges new readings into the visible assay set,
        and appends a structured entry to the oracle log.
        """
        tool_name = action.tool_name or ""
        # NOTE(review): an unrecognized (or empty) tool_name raises KeyError
        # here — confirm upstream action validation guarantees a known tool.
        cost = DEFAULT_TOOL_COSTS[tool_name]
        self._state.remaining_budget -= cost
        self._state.budget_used += cost
        self._state.oracle_call_count += 1
        # Repeat runs are counted per (current molecule, tool) pair.
        key = f"{self._molecule_signature()}::{tool_name}"
        runs = self._assay_runs.get(key, 0) + 1
        self._assay_runs[key] = runs
        reward = 0.02
        if runs == 1:
            # First look at this pair is rewarded; repeats are penalized.
            reward += 0.10
            explanation = "First assay on this molecule/tool pair increased observability."
        else:
            reward -= 0.08
            explanation = "Repeated assay spent budget on the same molecule/tool pair."
        readings = self._build_assay_readings(tool_name, runs)
        self._merge_assays(readings)
        if tool_name == "search_literature":
            reward += 0.04
        if self._reward_mode == "curriculum" and runs == 1:
            # Curriculum shaping: bonus for first-pass evidence on each
            # property this scenario requires for a submission.
            required_props = {"potency", "toxicity"}
            if "synth_min" in self._scenario.hard_constraints:
                required_props.add("synth")
            covered_props = {
                reading.property_name
                for reading in readings
                if reading.property_name in required_props
            }
            if covered_props:
                bonus = 0.08 * len(covered_props)
                reward += bonus
                reward_components.append(
                    RewardComponent(
                        name="curriculum_evidence_gate",
                        value=round(bonus, 4),
                        explanation=(
                            "Curriculum reward for collecting first-pass evidence "
                            f"for: {', '.join(sorted(covered_props))}."
                        ),
                    )
                )
        # NOTE(review): `reward` here already includes any curriculum bonus,
        # which was also reported as its own component above — the component
        # values double-count that bonus relative to the returned reward.
        # Confirm whether components are expected to sum to the return value.
        reward_components.append(
            RewardComponent(
                name="assay_information_gain",
                value=round(reward, 4),
                explanation=explanation,
            )
        )
        reward_components.append(
            RewardComponent(
                name="budget_spend",
                value=round(-cost / max(self._scenario.oracle_budget, 1), 4),
                explanation=f"Spent {cost} assay budget on {tool_name}.",
            )
        )
        # Budget spend is charged against the return value after being reported.
        reward -= cost / max(self._scenario.oracle_budget, 1)
        self._oracle_log.append(
            {
                "step": self._state.step_count,
                "tool_name": tool_name,
                "runs": runs,
                "molecule": self._molecule_signature(),
                "cost": cost,
                "results": [reading.model_dump() for reading in readings],
            }
        )
        self._last_summary = (
            f"Assay Planner executed {tool_name}; {len(readings)} structured assay result(s) are now visible."
        )
        return reward
    def _submit(self, reward_components: List[RewardComponent]) -> tuple[float, bool]:
        """Grade a formal submission of the current molecule; always terminal.

        Returns ``(reward, True)``. A submission is "valid" only when it
        simultaneously (a) passes all hard constraints, (b) beats the
        scenario baseline score, (c) has current-molecule assay evidence for
        every required property, and (d) — if a target shift is active — has
        post-shift potency evidence for this molecule.
        """
        properties = self._true_properties()
        final_score = compute_objective_score(properties, self._scenario)
        constraint_results = evaluate_constraints(properties, self._scenario)
        constraint_margins = evaluate_constraint_margins(properties, self._scenario)
        # Mean constraint margin; worse violations push it lower, and the
        # penalty below scales with how far the margins fall short.
        margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1)
        violation_penalty = round((1.0 - margin_score) * 2.0, 4)
        hard_constraints_met = all(result[0] for result in constraint_results.values())
        budget_efficiency = self._state.remaining_budget / max(self._scenario.oracle_budget, 1)
        beats_baseline = final_score >= self._scenario.baseline_to_beat
        current_signature = self._molecule_signature()
        # Evidence gate: the submitted molecule itself must have assay-backed
        # estimates for these properties (synth only when the scenario
        # enforces a synthesis constraint).
        evidence_requirements = ["potency", "toxicity"]
        if "synth_min" in self._scenario.hard_constraints:
            evidence_requirements.append("synth")
        missing_evidence = [
            prop for prop in evidence_requirements if self._current_property_estimate(prop, current_signature) is None
        ]
        evidence_met = not missing_evidence
        post_shift_evidence_met = True
        if self._scenario.target_shift_step and self._target_shift_active():
            # After a mid-episode target shift, pre-shift potency readings are
            # stale: require a potency result logged at or after the shift step
            # for the currently submitted molecule.
            post_shift_evidence_met = any(
                entry["step"] >= self._scenario.target_shift_step
                and entry["molecule"] == current_signature
                and any(result["property_name"] == "potency" for result in entry["results"])
                for entry in self._oracle_log
            )
        valid_submission = hard_constraints_met and beats_baseline and evidence_met and post_shift_evidence_met
        # Valid submissions earn the full quality reward; gated ones keep only
        # a small quality trace so the score signal is not entirely lost.
        reward = final_score * 2.0 if valid_submission else final_score * 0.25
        if valid_submission:
            reward += 3.5
        elif not hard_constraints_met:
            reward -= violation_penalty
        if not beats_baseline:
            reward -= 0.6
        if not evidence_met:
            reward -= 1.2
        if not post_shift_evidence_met:
            reward -= 0.8
        if valid_submission:
            # Leftover oracle budget is rewarded only on a valid submission.
            reward += max(0.0, budget_efficiency) * 0.7
        if self._reward_mode == "curriculum" and evidence_met and post_shift_evidence_met:
            submit_bonus = 0.35
            if hard_constraints_met:
                submit_bonus += 0.15
            reward += submit_bonus
        self._state.submitted = True
        self._report_card = self._build_report_card(submitted=True)
        self._last_summary = (
            f"The team submitted a candidate that "
            f"{'passed' if hard_constraints_met else 'failed'} hard constraints."
        )
        # Mirror every reward term above as an explanatory component.
        reward_components.extend(
            [
                RewardComponent(
                    name="submission_quality",
                    value=round((final_score * 2.0 if valid_submission else final_score * 0.25), 4),
                    explanation=(
                        "Full scientific quality reward because the submission met constraints, baseline, and evidence gates."
                        if valid_submission
                        else "Only a small quality trace is awarded because the submit action missed a gate."
                    ),
                ),
                RewardComponent(
                    name="hard_constraints",
                    value=(
                        3.5
                        if valid_submission
                        else (-violation_penalty if not hard_constraints_met else 0.0)
                    ),
                    explanation=(
                        "Large sparse bonus for beating baseline with required current evidence."
                        if valid_submission
                        else "Submission missed constraints, baseline, or evidence requirements; constraint penalty scales with violation severity."
                    ),
                ),
                RewardComponent(
                    name="constraint_margin",
                    value=round(margin_score, 4),
                    explanation=(
                        "Proportional hard-constraint score: worse potency, toxicity, or synthesis violations produce lower values."
                    ),
                ),
                RewardComponent(
                    name="baseline_gate",
                    value=0.0 if beats_baseline else -0.6,
                    explanation=(
                        "Submitted molecule beat the scenario baseline."
                        if beats_baseline
                        else "Submitted molecule did not beat the scenario baseline."
                    ),
                ),
                RewardComponent(
                    name="submission_evidence",
                    value=0.0 if evidence_met else -1.2,
                    explanation=(
                        "Current-molecule potency/toxicity/synthesis evidence was available."
                        if evidence_met
                        else f"Submission lacked current evidence for: {', '.join(missing_evidence)}."
                    ),
                ),
                RewardComponent(
                    name="post_shift_evidence",
                    value=0.0 if post_shift_evidence_met else -0.8,
                    explanation=(
                        "Post-shift potency evidence was available for the submitted molecule."
                        if post_shift_evidence_met
                        else "Hard scenario submission lacked post-shift potency evidence for the current molecule."
                    ),
                ),
                RewardComponent(
                    name="budget_efficiency",
                    value=round(max(0.0, budget_efficiency) * 0.7, 4) if valid_submission else 0.0,
                    explanation=(
                        "Unused budget is rewarded to discourage wasteful oracle usage."
                        if valid_submission
                        else "Budget efficiency is not awarded to a gated or premature submission."
                    ),
                ),
            ]
        )
        # NOTE(review): this condition must mirror the curriculum branch above,
        # which is the only place `submit_bonus` is bound — keep the two
        # conditions in lockstep or this block raises UnboundLocalError.
        if self._reward_mode == "curriculum" and evidence_met and post_shift_evidence_met:
            reward_components.append(
                RewardComponent(
                    name="curriculum_evidence_supported_submit",
                    value=round(submit_bonus, 4),
                    explanation=(
                        "Curriculum reward for making a formal submit decision after the required "
                        "current evidence package was available."
                    ),
                )
            )
        return reward, True
def _restart(self, reward_components: List[RewardComponent]) -> float:
self._molecule = dict(self._scenario.restart_scaffold)
self._trap_penalty_active = False
self._known_assays = []
self._assay_runs = {}
self._restart_used = True
self._visited_states.add(self._molecule_signature())
self._state.remaining_budget -= 350
self._state.budget_used += 350
reward_components.append(
RewardComponent(
name="restart_penalty",
value=-0.4,
explanation="Restarting discards sunk work but switches to a clean scaffold family.",
)
)
self._last_summary = (
"The team abandoned the original scaffold series and restarted from a cleaner alternative."
)
return -0.4
def _build_assay_readings(self, tool_name: str, runs: int) -> List[AssayReading]:
properties = self._true_properties()
signature = self._molecule_signature()
if tool_name == "evaluate_properties":
property_names = ["potency", "novelty"]
elif tool_name == "dock_target":
property_names = ["potency"]
elif tool_name == "assay_toxicity":
property_names = ["toxicity"]
elif tool_name == "estimate_synthesizability":
property_names = ["synth"]
elif tool_name == "evaluate_novelty":
property_names = ["novelty"]
elif tool_name == "search_literature":
hint_score = min(0.95, 0.45 + 0.08 * runs)
return [
AssayReading(
tool_name=tool_name,
property_name="literature_signal",
estimate=round(hint_score, 4),
confidence_low=max(0.0, round(hint_score - 0.08, 4)),
confidence_high=min(1.0, round(hint_score + 0.08, 4)),
runs=runs,
molecule_signature=signature,
summary=literature_hints(self._molecule)[0],
)
]
else:
property_names = ["potency", "toxicity", "synth"]
readings = []
for property_name in property_names:
true_value = properties[property_name]
estimate = self._assay_estimate(signature, tool_name, property_name, runs, true_value)
width = max(0.03, 0.18 / runs)
readings.append(
AssayReading(
tool_name=tool_name,
property_name=property_name,
estimate=estimate,
confidence_low=max(0.0, round(estimate - width, 4)),
confidence_high=min(1.0, round(estimate + width, 4)),
runs=runs,
molecule_signature=signature,
summary=f"{tool_name} estimated {property_name} with run count {runs}.",
)
)
return readings