"""Generate MolForge compact-policy SFT data aligned to MLX inference. V4 is designed around the failures seen in the v3 adapter: - train on the exact compact prompt/payload shape used at inference time - emphasize successful end-to-end expert trajectories - include recovery examples after governance vetoes - include enough schema coverage for all core action types without making unsafe edits or wasteful assays dominate the positive training signal """ from __future__ import annotations import argparse import json import os import sys from pathlib import Path from typing import Any, Iterable PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from inference_common import ( # noqa: E402 MolForgeAction, MolForgeObservation, attach_reasoning_fields, attach_team_messages, heuristic_team_action, ) from scenarios import DEFAULT_TOOL_COSTS # noqa: E402 from server.molforge_environment import MolForgeEnvironment # noqa: E402 COMPACT_ACTION_SYSTEM_PROMPT = """ You control the MolForge action policy. Return exactly one JSON object with only these top-level keys: action_type, acting_role, edit_type, slot, fragment, tool_name, rationale, evidence, expected_effects. Valid action_type values are exactly: edit, run_assay, submit, restart, defer. Do not output team messages. Do not output proposal, approval, objection, risk_flag, assay_request, rejection, or submission_recommendation as action_type. The environment will attach governance messages automatically. Role rules: - run_assay uses acting_role "assay_planner" and a valid tool_name. - edit, submit, restart, and defer use acting_role "lead_chemist". - unused optional fields must be JSON null. """.strip() def main() -> None: parser = argparse.ArgumentParser(description="Generate compact MolForge v4 policy SFT JSONL.") parser.add_argument("--episodes", type=int, default=520) parser.add_argument("--max-turns", type=int, default=10) parser.add_argument("--seed", default="policy-v4") parser.add_argument("--output", default="issue/molforge_sft_compact_policy_v4.jsonl") args = parser.parse_args() records: list[dict[str, Any]] = [] seen: set[str] = set() add_expert_traces(records, seen, episodes=18, max_turns=args.max_turns, randomized=False, seed=args.seed) add_expert_traces(records, seen, episodes=args.episodes, max_turns=args.max_turns, randomized=True, seed=args.seed) add_recovery_traces(records, seen, episodes=max(90, args.episodes // 3), seed=args.seed) add_schema_coverage(records, seen, episodes=36, seed=args.seed) output = Path(args.output) output.parent.mkdir(parents=True, exist_ok=True) with output.open("w", encoding="utf-8") as handle: for record in records: handle.write(json.dumps(record, ensure_ascii=True) + "\n") print(json.dumps(summarize(records, str(output)), indent=2)) def add_expert_traces( records: list[dict[str, Any]], seen: set[str], *, episodes: int, max_turns: int, randomized: bool, seed: str, ) -> None: with_training_randomization(randomized, seed) env = MolForgeEnvironment() source = "expert_randomized" if randomized else "expert_canonical" for _ in range(episodes): observation = env.reset() for _ in range(max_turns): if observation.done: break action = heuristic_team_action(observation) add_record(records, seen, observation, action, source=source) observation = env.step(action) def add_recovery_traces(records: list[dict[str, Any]], seen: set[str], *, episodes: int, seed: str) -> None: with_training_randomization(True, f"{seed}-recovery") env = MolForgeEnvironment() for episode_index in range(episodes): observation = env.reset() # Move some episodes to a useful intermediate state before injecting a bad decision. for _ in range(episode_index % 3): if observation.done: break observation = env.step(heuristic_team_action(observation)) if observation.done: continue for bad_action in bad_actions_for(observation): trial = clone_env_at_observation(env, episode_index) trial_obs = advance_like_source(trial, episode_index % 3) if trial_obs.done: continue veto_obs = trial.step(attach_team_messages(trial_obs, attach_reasoning_fields(trial_obs, bad_action))) if veto_obs.done: continue if veto_obs.governance.status != "policy_veto": continue recovery = heuristic_team_action(veto_obs) add_record(records, seen, veto_obs, recovery, source="recovery_after_veto") def add_schema_coverage(records: list[dict[str, Any]], seen: set[str], *, episodes: int, seed: str) -> None: with_training_randomization(True, f"{seed}-coverage") env = MolForgeEnvironment() observations: list[MolForgeObservation] = [] for _ in range(episodes): observation = env.reset() observations.append(observation) for _ in range(2): if observation.done: break observation = env.step(heuristic_team_action(observation)) observations.append(observation) defer_examples = 0 for observation in observations: current = {slot.slot: slot.fragment for slot in observation.molecule_slots} safe_edits = [ ("solvent_tail", "morpholine", "Use morpholine to reduce safety risk."), ("back_pocket", "cyano", "Use cyano to preserve potency with lower lipophilic risk."), ("warhead", "reversible_cyanoacrylamide", "Use a softer warhead to reduce reactivity."), ("hinge", "azaindole", "Use azaindole when potency needs recovery."), ] for slot, fragment, rationale in safe_edits: if current.get(slot) == fragment: continue add_record( records, seen, observation, MolForgeAction( action_type="edit", acting_role="lead_chemist", edit_type="substitute", slot=slot, # type: ignore[arg-type] fragment=fragment, rationale=rationale, ), source="schema_safe_edit", ) if observation.step_index > 0: add_record( records, seen, observation, MolForgeAction( action_type="edit", acting_role="lead_chemist", edit_type="remove", slot="back_pocket", rationale="Remove the back-pocket group to simplify risk before reassay.", ), source="schema_remove", ) for tool_name in useful_tool_subset(observation): add_record( records, seen, observation, MolForgeAction( action_type="run_assay", acting_role="assay_planner", tool_name=tool_name, # type: ignore[arg-type] rationale=f"Run {tool_name} to close a visible evidence gap.", ), source="schema_tool_coverage", ) if ( defer_examples < 36 and observation.step_index >= 1 and observation.scenario_id != "level_2_hard" ): add_record( records, seen, observation, MolForgeAction( action_type="defer", acting_role="lead_chemist", rationale="Defer because no safe evidence-backed action remains in the current budget window.", ), source="schema_defer", ) defer_examples += 1 def useful_tool_subset(observation: MolForgeObservation) -> list[str]: gaps = set() for constraint in observation.constraint_status: if constraint.evidence_status == "unknown": if constraint.name == "toxicity_max": gaps.add("toxicity") else: gaps.add(constraint.name.split("_")[0]) tools: list[str] = [] if "potency" in gaps and observation.remaining_budget >= DEFAULT_TOOL_COSTS["dock_target"]: tools.extend(["evaluate_properties", "dock_target"]) if "toxicity" in gaps and observation.remaining_budget >= DEFAULT_TOOL_COSTS["assay_toxicity"]: tools.append("assay_toxicity") if "synth" in gaps and observation.remaining_budget >= DEFAULT_TOOL_COSTS["estimate_synthesizability"]: tools.append("estimate_synthesizability") if observation.remaining_budget >= DEFAULT_TOOL_COSTS["evaluate_novelty"]: tools.append("evaluate_novelty") if observation.remaining_budget >= DEFAULT_TOOL_COSTS["search_literature"]: tools.append("search_literature") if observation.scenario_id == "level_2_hard" and observation.remaining_budget >= DEFAULT_TOOL_COSTS["run_md_simulation"]: tools.append("run_md_simulation") return tools def bad_actions_for(observation: MolForgeObservation) -> Iterable[MolForgeAction]: current = {slot.slot: slot.fragment for slot in observation.molecule_slots} candidates = [ ("solvent_tail", "dimethylamino", "This would add a safety liability and should be recovered from."), ("back_pocket", "trifluoromethyl", "This would over-shoot lipophilic risk and should be recovered from."), ("hinge", "quinazoline", "This can create route pressure and should be recovered from."), ] for slot, fragment, rationale in candidates: if current.get(slot) == fragment: continue yield MolForgeAction( action_type="edit", acting_role="lead_chemist", edit_type="substitute", slot=slot, # type: ignore[arg-type] fragment=fragment, rationale=rationale, ) def clone_env_at_observation(source_env: MolForgeEnvironment, episode_index: int) -> MolForgeEnvironment: del source_env env = MolForgeEnvironment() for _ in range(episode_index + 1): observation = env.reset() return env def advance_like_source(env: MolForgeEnvironment, steps: int) -> MolForgeObservation: observation = env._build_observation(reward=0.0, done=False, reward_components=[]) # noqa: SLF001 for _ in range(steps): if observation.done: return observation observation = env.step(heuristic_team_action(observation)) return observation def with_training_randomization(enabled: bool, seed: str) -> None: if enabled: os.environ["MOLFORGE_TRAINING_RANDOMIZATION"] = "1" else: os.environ.pop("MOLFORGE_TRAINING_RANDOMIZATION", None) os.environ["MOLFORGE_RANDOM_SEED"] = seed def add_record( records: list[dict[str, Any]], seen: set[str], observation: MolForgeObservation, action: MolForgeAction, *, source: str, ) -> None: action = attach_reasoning_fields(observation, action) record = make_record(observation, action, source=source) key = json.dumps( {"user": record["messages"][1]["content"], "assistant": record["messages"][2]["content"]}, sort_keys=True, ) if key in seen: return validate_target(record["messages"][2]["content"]) records.append(record) seen.add(key) def make_record(observation: MolForgeObservation, action: MolForgeAction, *, source: str) -> dict[str, Any]: return { "messages": [ {"role": "system", "content": COMPACT_ACTION_SYSTEM_PROMPT}, {"role": "user", "content": json.dumps(compact_action_payload(observation), separators=(",", ":"))}, {"role": "assistant", "content": json.dumps(target_action(action), separators=(",", ":"))}, ], "metadata": { "source": source, "scenario_id": observation.scenario_id, "difficulty": observation.difficulty, "step_index": observation.step_index, "action_type": action.action_type, }, } def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]: lead_view = next( (role.observation for role in observation.role_observations if role.role == "lead_chemist"), {}, ) assay_view = next( (role.observation for role in observation.role_observations if role.role == "assay_planner"), {}, ) return { "valid_action_types": ["edit", "run_assay", "submit", "restart", "defer"], "scenario_id": observation.scenario_id, "difficulty": observation.difficulty, "task_brief": observation.task_brief, "current_molecule": observation.current_molecule, "current_smiles": observation.metadata.get("current_smiles", ""), "visible_metrics": observation.visible_metrics, "constraint_status": [constraint.model_dump() for constraint in observation.constraint_status], "remaining_budget": observation.remaining_budget, "max_budget": observation.max_budget, "step_index": observation.step_index, "max_steps": observation.max_steps, "molecule_slots": lead_view.get("molecule_slots", {}), "candidate_edits": lead_view.get("candidate_edits", [])[:12], "open_questions": lead_view.get("open_questions", []), "known_assays": [ { "tool_name": reading.tool_name, "property_name": reading.property_name, "estimate": reading.estimate, "confidence_low": reading.confidence_low, "confidence_high": reading.confidence_high, "molecule_signature": reading.molecule_signature, } for reading in observation.known_assays[-8:] ], "tool_costs": assay_view.get("tool_costs", {}), "evidence_gaps": assay_view.get("evidence_gaps", []), "estimated_information_value": assay_view.get("estimated_information_value", {}), } def target_action(action: MolForgeAction) -> dict[str, Any]: effects = { "potency": "unknown", "toxicity": "unknown", "synth": "unknown", "novelty": "unknown", "budget": "neutral", } effects.update({key: value for key, value in action.expected_effects.items() if key in effects}) return { "action_type": action.action_type, "acting_role": action.acting_role, "edit_type": action.edit_type, "slot": action.slot, "fragment": action.fragment, "tool_name": action.tool_name, "rationale": action.rationale[:220], "evidence": list(action.evidence[:5]), "expected_effects": effects, } def validate_target(text: str) -> None: data = json.loads(text) allowed = { "action_type", "acting_role", "edit_type", "slot", "fragment", "tool_name", "rationale", "evidence", "expected_effects", } if set(data) != allowed: raise ValueError(f"target keys mismatch: {sorted(data)}") if data["action_type"] not in {"edit", "run_assay", "submit", "restart", "defer"}: raise ValueError(f"invalid action_type: {data['action_type']}") if data["action_type"] == "proposal": raise ValueError("proposal is not a compact action type") if data["edit_type"] == "replace": raise ValueError("replace must never be used; use substitute") if "messages" in data: raise ValueError("compact target must not contain messages") if not isinstance(data["evidence"], list): raise ValueError("evidence must be a list") if set(data["expected_effects"]) != {"potency", "toxicity", "synth", "novelty", "budget"}: raise ValueError("expected_effects must have exactly five keys") MolForgeAction(**data) def summarize(records: list[dict[str, Any]], output: str) -> dict[str, Any]: actions: dict[str, int] = {} sources: dict[str, int] = {} scenarios: dict[str, int] = {} users = set() assistants = set() for record in records: metadata = record["metadata"] actions[metadata["action_type"]] = actions.get(metadata["action_type"], 0) + 1 sources[metadata["source"]] = sources.get(metadata["source"], 0) + 1 scenarios[metadata["scenario_id"]] = scenarios.get(metadata["scenario_id"], 0) + 1 users.add(record["messages"][1]["content"]) assistants.add(record["messages"][2]["content"]) return { "output": output, "records": len(records), "unique_user_prompts": len(users), "unique_assistant_targets": len(assistants), "action_types": actions, "sources": sources, "scenario_ids": scenarios, } if __name__ == "__main__": main()