# File: molforge/scripts/generate_sft_compact_policy_v4_dataset.py
# Uploaded by Adhitya122
# Commit bf9e424 (verified): "Prepare MolForge OpenEnv Docker Space submission"
"""Generate MolForge compact-policy SFT data aligned to MLX inference.

V4 is designed around the failures seen in the v3 adapter:

- train on the exact compact prompt/payload shape used at inference time
- emphasize successful end-to-end expert trajectories
- include recovery examples after governance vetoes
- include enough schema coverage for all core action types without making
  unsafe edits or wasteful assays dominate the positive training signal
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Iterable
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from inference_common import ( # noqa: E402
MolForgeAction,
MolForgeObservation,
attach_reasoning_fields,
attach_team_messages,
heuristic_team_action,
)
from scenarios import DEFAULT_TOOL_COSTS # noqa: E402
from server.molforge_environment import MolForgeEnvironment # noqa: E402
# System message placed first in every generated record (see make_record).
# The text is emitted verbatim into the dataset, so any edit here changes the
# training data; surrounding whitespace is removed by the trailing .strip().
COMPACT_ACTION_SYSTEM_PROMPT = """
You control the MolForge action policy.
Return exactly one JSON object with only these top-level keys:
action_type, acting_role, edit_type, slot, fragment, tool_name, rationale,
evidence, expected_effects.
Valid action_type values are exactly:
edit, run_assay, submit, restart, defer.
Do not output team messages. Do not output proposal, approval, objection,
risk_flag, assay_request, rejection, or submission_recommendation as action_type.
The environment will attach governance messages automatically.
Role rules:
- run_assay uses acting_role "assay_planner" and a valid tool_name.
- edit, submit, restart, and defer use acting_role "lead_chemist".
- unused optional fields must be JSON null.
""".strip()
def main() -> None:
    """Parse CLI options, build every record group, and write the JSONL file.

    Record groups are appended in a fixed order: canonical expert traces,
    randomized expert traces, recovery-after-veto examples, then schema
    coverage examples.  A JSON summary of the dataset is printed to stdout
    once the file has been written.
    """
    cli = argparse.ArgumentParser(description="Generate compact MolForge v4 policy SFT JSONL.")
    cli.add_argument("--episodes", type=int, default=520)
    cli.add_argument("--max-turns", type=int, default=10)
    cli.add_argument("--seed", default="policy-v4")
    cli.add_argument("--output", default="issue/molforge_sft_compact_policy_v4.jsonl")
    options = cli.parse_args()

    dataset: list[dict[str, Any]] = []
    dedup_keys: set[str] = set()

    # A fixed batch of 18 non-randomized episodes comes first, followed by the
    # randomized episodes requested on the command line.
    add_expert_traces(
        dataset, dedup_keys, episodes=18, max_turns=options.max_turns, randomized=False, seed=options.seed
    )
    add_expert_traces(
        dataset,
        dedup_keys,
        episodes=options.episodes,
        max_turns=options.max_turns,
        randomized=True,
        seed=options.seed,
    )
    # Recovery examples scale with --episodes but never drop below 90 episodes.
    add_recovery_traces(dataset, dedup_keys, episodes=max(90, options.episodes // 3), seed=options.seed)
    add_schema_coverage(dataset, dedup_keys, episodes=36, seed=options.seed)

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        # One JSON document per line (JSONL), ASCII-escaped.
        sink.writelines(json.dumps(entry, ensure_ascii=True) + "\n" for entry in dataset)

    print(json.dumps(summarize(dataset, str(destination)), indent=2))
def add_expert_traces(
    records: list[dict[str, Any]],
    seen: set[str],
    *,
    episodes: int,
    max_turns: int,
    randomized: bool,
    seed: str,
) -> None:
    """Roll out the heuristic policy and record every (observation, action) pair.

    Args:
        records: Output list that new records are appended to.
        seen: Deduplication keys shared across all record groups.
        episodes: Number of environment resets to roll out.
        max_turns: Upper bound on recorded steps per episode.
        randomized: Whether training randomization is switched on; also
            selects the ``source`` label stored in each record's metadata.
        seed: Value exported as the environment random seed.
    """
    with_training_randomization(randomized, seed)
    env = MolForgeEnvironment()
    if randomized:
        source = "expert_randomized"
    else:
        source = "expert_canonical"

    for _episode in range(episodes):
        observation = env.reset()
        turns_taken = 0
        # Stop at the turn limit or as soon as the episode reports it is done.
        while turns_taken < max_turns and not observation.done:
            action = heuristic_team_action(observation)
            add_record(records, seen, observation, action, source=source)
            observation = env.step(action)
            turns_taken += 1
def add_recovery_traces(records: list[dict[str, Any]], seen: set[str], *, episodes: int, seed: str) -> None:
    """Record the heuristic's next action after a deliberately bad edit is vetoed.

    For each episode the source environment is advanced by ``episode_index % 3``
    heuristic steps.  Every candidate bad edit is then applied to a freshly
    rebuilt trial environment advanced by the same number of steps; when the
    resulting observation reports governance status ``"policy_veto"``, the
    heuristic action for that observation is stored with source
    ``"recovery_after_veto"``.

    Args:
        records: Output list that new records are appended to.
        seen: Deduplication keys shared across all record groups.
        episodes: Number of source episodes to walk through.
        seed: Base seed; ``"-recovery"`` is appended before it is exported.
    """
    with_training_randomization(True, f"{seed}-recovery")
    env = MolForgeEnvironment()

    for episode_index in range(episodes):
        warmup_steps = episode_index % 3
        observation = env.reset()

        # Move some episodes to a useful intermediate state before injecting a bad decision.
        for _step in range(warmup_steps):
            if observation.done:
                break
            observation = env.step(heuristic_team_action(observation))
        if observation.done:
            continue

        for bad_action in bad_actions_for(observation):
            # Each bad edit gets its own trial environment so attempts do not
            # contaminate one another or the source environment.
            trial = clone_env_at_observation(env, episode_index)
            trial_obs = advance_like_source(trial, warmup_steps)
            if trial_obs.done:
                continue

            prepared = attach_team_messages(trial_obs, attach_reasoning_fields(trial_obs, bad_action))
            veto_obs = trial.step(prepared)
            # Only a live episode that was actually vetoed yields a recovery example.
            if veto_obs.done or veto_obs.governance.status != "policy_veto":
                continue

            add_record(
                records,
                seen,
                veto_obs,
                heuristic_team_action(veto_obs),
                source="recovery_after_veto",
            )
def add_schema_coverage(
    records: list[dict[str, Any]],
    seen: set[str],
    *,
    episodes: int,
    seed: str,
    max_defer_examples: int = 36,
) -> None:
    """Add hand-built examples so each core action shape appears in the data.

    Observations are sampled from randomized episodes (the reset observation
    plus up to two heuristic follow-up steps each).  For every sampled
    observation the function records, in this order: substitute edits from a
    fixed table of safe fragments, a back-pocket remove edit (after step 0),
    one assay per tool returned by ``useful_tool_subset``, and a defer action
    while the defer cap has not been reached.

    Args:
        records: Output list that new records are appended to.
        seen: Deduplication keys shared across all record groups.
        episodes: Number of environment resets to sample observations from.
        seed: Base seed; ``"-coverage"`` is appended before it is exported.
        max_defer_examples: Maximum number of defer examples attempted.
            Defaults to 36, the value that was previously hard-coded.
    """
    with_training_randomization(True, f"{seed}-coverage")
    env = MolForgeEnvironment()

    observations: list[MolForgeObservation] = []
    for _episode in range(episodes):
        observation = env.reset()
        observations.append(observation)
        for _step in range(2):
            if observation.done:
                break
            observation = env.step(heuristic_team_action(observation))
            observations.append(observation)

    # Loop-invariant table of (slot, fragment, rationale); built once instead
    # of once per observation.
    safe_edits = (
        ("solvent_tail", "morpholine", "Use morpholine to reduce safety risk."),
        ("back_pocket", "cyano", "Use cyano to preserve potency with lower lipophilic risk."),
        ("warhead", "reversible_cyanoacrylamide", "Use a softer warhead to reduce reactivity."),
        ("hinge", "azaindole", "Use azaindole when potency needs recovery."),
    )

    defer_examples = 0
    for observation in observations:
        current = {slot.slot: slot.fragment for slot in observation.molecule_slots}

        for slot, fragment, rationale in safe_edits:
            # Substituting a fragment that is already in place would be a no-op edit.
            if current.get(slot) == fragment:
                continue
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="substitute",
                    slot=slot,  # type: ignore[arg-type]
                    fragment=fragment,
                    rationale=rationale,
                ),
                source="schema_safe_edit",
            )

        if observation.step_index > 0:
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="edit",
                    acting_role="lead_chemist",
                    edit_type="remove",
                    slot="back_pocket",
                    rationale="Remove the back-pocket group to simplify risk before reassay.",
                ),
                source="schema_remove",
            )

        for tool_name in useful_tool_subset(observation):
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="run_assay",
                    acting_role="assay_planner",
                    tool_name=tool_name,  # type: ignore[arg-type]
                    rationale=f"Run {tool_name} to close a visible evidence gap.",
                ),
                source="schema_tool_coverage",
            )

        if (
            defer_examples < max_defer_examples
            and observation.step_index >= 1
            and observation.scenario_id != "level_2_hard"
        ):
            add_record(
                records,
                seen,
                observation,
                MolForgeAction(
                    action_type="defer",
                    acting_role="lead_chemist",
                    rationale="Defer because no safe evidence-backed action remains in the current budget window.",
                ),
                source="schema_defer",
            )
            # NOTE(review): this counts attempts, not records actually written;
            # a duplicate skipped by add_record still consumes the cap.
            # Confirm that is intended.
            defer_examples += 1
def useful_tool_subset(observation: MolForgeObservation) -> list[str]:
    """Return the assay tools worth demonstrating for this observation.

    An evidence gap is the first ``_``-separated token of each constraint
    name whose ``evidence_status`` is ``"unknown"``.  Gap-driven tools come
    first (potency, toxicity, synth), then novelty and literature search
    whenever the remaining budget covers them, then the MD simulation for
    the ``level_2_hard`` scenario only.

    Note that ``evaluate_properties`` is added together with ``dock_target``
    and is gated solely on the ``dock_target`` cost.
    """
    gaps = {
        constraint.name.split("_")[0]
        for constraint in observation.constraint_status
        if constraint.evidence_status == "unknown"
    }

    def affordable(tool: str) -> bool:
        # True when the remaining budget covers the tool's default cost.
        return observation.remaining_budget >= DEFAULT_TOOL_COSTS[tool]

    selected: list[str] = []
    if "potency" in gaps and affordable("dock_target"):
        selected += ["evaluate_properties", "dock_target"]
    if "toxicity" in gaps and affordable("assay_toxicity"):
        selected.append("assay_toxicity")
    if "synth" in gaps and affordable("estimate_synthesizability"):
        selected.append("estimate_synthesizability")
    if affordable("evaluate_novelty"):
        selected.append("evaluate_novelty")
    if affordable("search_literature"):
        selected.append("search_literature")
    if observation.scenario_id == "level_2_hard" and affordable("run_md_simulation"):
        selected.append("run_md_simulation")
    return selected
def bad_actions_for(observation: MolForgeObservation) -> Iterable[MolForgeAction]:
    """Yield substitute edits used to provoke a veto for recovery examples.

    Candidates come from a fixed table of (slot, fragment, rationale).  A
    candidate is skipped when the observation already has that fragment in
    that slot.
    """
    fragment_by_slot = {}
    for entry in observation.molecule_slots:
        fragment_by_slot[entry.slot] = entry.fragment

    risky_substitutions = (
        ("solvent_tail", "dimethylamino", "This would add a safety liability and should be recovered from."),
        ("back_pocket", "trifluoromethyl", "This would over-shoot lipophilic risk and should be recovered from."),
        ("hinge", "quinazoline", "This can create route pressure and should be recovered from."),
    )
    for slot, fragment, rationale in risky_substitutions:
        if fragment_by_slot.get(slot) != fragment:
            yield MolForgeAction(
                action_type="edit",
                acting_role="lead_chemist",
                edit_type="substitute",
                slot=slot,  # type: ignore[arg-type]
                fragment=fragment,
                rationale=rationale,
            )
def clone_env_at_observation(source_env: MolForgeEnvironment, episode_index: int) -> MolForgeEnvironment:
    """Build a new environment and reset it ``episode_index + 1`` times.

    ``source_env`` is accepted for the call signature only and is not read.

    NOTE(review): presumably this relies on a fresh environment replaying the
    same reset sequence as the source environment under the current seed
    settings - confirm against MolForgeEnvironment.
    """
    del source_env
    replica = MolForgeEnvironment()
    resets_left = episode_index + 1
    while resets_left > 0:
        replica.reset()
        resets_left -= 1
    return replica
def advance_like_source(env: MolForgeEnvironment, steps: int) -> MolForgeObservation:
    """Advance ``env`` by up to ``steps`` heuristic actions and return the result.

    The starting observation is rebuilt from the environment's current state
    through its private ``_build_observation`` hook.  Stepping stops early as
    soon as an observation reports ``done``.
    """
    observation = env._build_observation(reward=0.0, done=False, reward_components=[])  # noqa: SLF001
    remaining = steps
    while remaining > 0 and not observation.done:
        observation = env.step(heuristic_team_action(observation))
        remaining -= 1
    return observation
def with_training_randomization(enabled: bool, seed: str) -> None:
    """Export the randomization flag and random seed as environment variables.

    ``MOLFORGE_TRAINING_RANDOMIZATION`` is set to ``"1"`` when ``enabled`` is
    true and removed otherwise; ``MOLFORGE_RANDOM_SEED`` is always set to
    ``seed``.  The changes persist after the call (this is not a context
    manager despite the name).
    """
    flag = "MOLFORGE_TRAINING_RANDOMIZATION"
    if not enabled:
        os.environ.pop(flag, None)
    else:
        os.environ[flag] = "1"
    os.environ["MOLFORGE_RANDOM_SEED"] = seed
def add_record(
    records: list[dict[str, Any]],
    seen: set[str],
    observation: MolForgeObservation,
    action: MolForgeAction,
    *,
    source: str,
) -> None:
    """Append one training record unless an identical prompt/target pair exists.

    The action is first passed through ``attach_reasoning_fields``.  Records
    are deduplicated on the (user content, assistant content) pair; a new
    record's assistant target is validated before it is stored.

    Raises:
        ValueError: If the assistant target fails ``validate_target``.
    """
    enriched = attach_reasoning_fields(observation, action)
    record = make_record(observation, enriched, source=source)

    user_text = record["messages"][1]["content"]
    assistant_text = record["messages"][2]["content"]
    key = json.dumps({"user": user_text, "assistant": assistant_text}, sort_keys=True)

    if key not in seen:
        validate_target(assistant_text)
        records.append(record)
        seen.add(key)
def make_record(observation: MolForgeObservation, action: MolForgeAction, *, source: str) -> dict[str, Any]:
    """Build one chat-format record: system prompt, compact user payload, target.

    The user and assistant contents are JSON strings serialized without
    whitespace between separators.  ``metadata`` carries the source label and
    a few observation/action fields used later by ``summarize``.
    """
    compact = (",", ":")
    user_content = json.dumps(compact_action_payload(observation), separators=compact)
    assistant_content = json.dumps(target_action(action), separators=compact)

    messages = [
        {"role": "system", "content": COMPACT_ACTION_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
    metadata = {
        "source": source,
        "scenario_id": observation.scenario_id,
        "difficulty": observation.difficulty,
        "step_index": observation.step_index,
        "action_type": action.action_type,
    }
    return {"messages": messages, "metadata": metadata}
def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]:
    """Flatten an observation into the compact dict used as the user prompt.

    Role-specific fields are read from the first ``lead_chemist`` and
    ``assay_planner`` role observations (an empty dict when the role is
    absent).  Candidate edits are capped at the first 12 entries and known
    assays at the last 8 readings.
    """

    def role_view(role_name: str) -> dict[str, Any]:
        # First matching role observation wins; a missing role yields {}.
        for entry in observation.role_observations:
            if entry.role == role_name:
                return entry.observation
        return {}

    lead_view = role_view("lead_chemist")
    assay_view = role_view("assay_planner")

    recent_assays = []
    for reading in observation.known_assays[-8:]:
        recent_assays.append(
            {
                "tool_name": reading.tool_name,
                "property_name": reading.property_name,
                "estimate": reading.estimate,
                "confidence_low": reading.confidence_low,
                "confidence_high": reading.confidence_high,
                "molecule_signature": reading.molecule_signature,
            }
        )

    # Key order matters: this dict is serialized as-is into the prompt.
    payload: dict[str, Any] = {}
    payload["valid_action_types"] = ["edit", "run_assay", "submit", "restart", "defer"]
    payload["scenario_id"] = observation.scenario_id
    payload["difficulty"] = observation.difficulty
    payload["task_brief"] = observation.task_brief
    payload["current_molecule"] = observation.current_molecule
    payload["current_smiles"] = observation.metadata.get("current_smiles", "")
    payload["visible_metrics"] = observation.visible_metrics
    payload["constraint_status"] = [constraint.model_dump() for constraint in observation.constraint_status]
    payload["remaining_budget"] = observation.remaining_budget
    payload["max_budget"] = observation.max_budget
    payload["step_index"] = observation.step_index
    payload["max_steps"] = observation.max_steps
    payload["molecule_slots"] = lead_view.get("molecule_slots", {})
    payload["candidate_edits"] = lead_view.get("candidate_edits", [])[:12]
    payload["open_questions"] = lead_view.get("open_questions", [])
    payload["known_assays"] = recent_assays
    payload["tool_costs"] = assay_view.get("tool_costs", {})
    payload["evidence_gaps"] = assay_view.get("evidence_gaps", [])
    payload["estimated_information_value"] = assay_view.get("estimated_information_value", {})
    return payload
def target_action(action: MolForgeAction) -> dict[str, Any]:
    """Convert an action into the nine-key dict used as the assistant target.

    ``expected_effects`` always contains exactly the five known effect keys:
    values supplied by the action override the defaults and any other keys
    are dropped.  The rationale is truncated to 220 characters and evidence
    to its first five items.
    """
    effects = dict.fromkeys(("potency", "toxicity", "synth", "novelty"), "unknown")
    effects["budget"] = "neutral"
    for name, value in action.expected_effects.items():
        if name in effects:
            effects[name] = value

    target: dict[str, Any] = {}
    for field in ("action_type", "acting_role", "edit_type", "slot", "fragment", "tool_name"):
        target[field] = getattr(action, field)
    target["rationale"] = action.rationale[:220]
    target["evidence"] = list(action.evidence[:5])
    target["expected_effects"] = effects
    return target
def validate_target(text: str) -> None:
    """Check that a serialized assistant target matches the compact action schema.

    Specific checks run before the generic ones they overlap with, so the
    more informative message is the one raised (a ``messages`` key before the
    key-set comparison, ``proposal`` before the generic action-type check).
    Type guards ensure malformed payloads raise ``ValueError`` rather than
    ``TypeError``.

    Args:
        text: JSON text of a single compact action object.

    Raises:
        ValueError: If the text is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass), is not a JSON object, or violates
            any schema rule checked below.

    Whatever ``MolForgeAction(**data)`` raises for an otherwise well-formed
    payload propagates unchanged.
    """
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(f"compact target must be a JSON object, got {type(data).__name__}")

    # Checked before the key-set comparison, which would otherwise mask it.
    if "messages" in data:
        raise ValueError("compact target must not contain messages")

    allowed = {
        "action_type",
        "acting_role",
        "edit_type",
        "slot",
        "fragment",
        "tool_name",
        "rationale",
        "evidence",
        "expected_effects",
    }
    if set(data) != allowed:
        raise ValueError(f"target keys mismatch: {sorted(data)}")

    action_type = data["action_type"]
    # Checked before the generic membership test, which would otherwise mask it.
    if action_type == "proposal":
        raise ValueError("proposal is not a compact action type")
    valid_action_types = {"edit", "run_assay", "submit", "restart", "defer"}
    # The isinstance guard keeps unhashable JSON values (lists/objects) from
    # raising TypeError in the set membership test.
    if not isinstance(action_type, str) or action_type not in valid_action_types:
        raise ValueError(f"invalid action_type: {action_type}")

    if data["edit_type"] == "replace":
        raise ValueError("replace must never be used; use substitute")
    if not isinstance(data["evidence"], list):
        raise ValueError("evidence must be a list")

    effects = data["expected_effects"]
    effect_keys = {"potency", "toxicity", "synth", "novelty", "budget"}
    if not isinstance(effects, dict) or set(effects) != effect_keys:
        raise ValueError("expected_effects must have exactly five keys")

    # Final check: the action model itself must accept the payload.
    MolForgeAction(**data)
def summarize(records: list[dict[str, Any]], output: str) -> dict[str, Any]:
    """Return counts describing the generated dataset.

    The summary holds the output path, the record count, the number of
    distinct user prompts and assistant targets, and per-value tallies of
    each record's ``action_type``, ``source`` and ``scenario_id`` metadata.
    """
    # Maps each metadata field to its tally table.
    tallies: dict[str, dict[str, int]] = {"action_type": {}, "source": {}, "scenario_id": {}}
    prompts = set()
    targets = set()

    for entry in records:
        meta = entry["metadata"]
        for field, table in tallies.items():
            label = meta[field]
            table[label] = table.get(label, 0) + 1
        messages = entry["messages"]
        prompts.add(messages[1]["content"])
        targets.add(messages[2]["content"])

    return {
        "output": output,
        "records": len(records),
        "unique_user_prompts": len(prompts),
        "unique_assistant_targets": len(targets),
        "action_types": tallies["action_type"],
        "sources": tallies["source"],
        "scenario_ids": tallies["scenario_id"],
    }
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()