"""MLX-backed local LoRA inference runner for MolForge on Apple Silicon.""" from __future__ import annotations import json import os import time from pathlib import Path from typing import Any, Dict, Optional, Tuple from mlx_lm import generate, load from mlx_lm.sample_utils import make_sampler from inference_common import ( COMPACT_SYSTEM_PROMPT, SYSTEM_PROMPT, attach_team_messages, build_model_payload, extract_json, ) try: from molforge.models import MolForgeAction, MolForgeObservation from molforge.server.molforge_environment import MolForgeEnvironment except ImportError: from models import MolForgeAction, MolForgeObservation from server.molforge_environment import MolForgeEnvironment ADAPTER_PATH = Path(os.getenv("LORA_ADAPTER_PATH", "qwen3_5_2b_lora_adapters_strict")) BASE_MODEL_NAME = os.getenv("BASE_MODEL_NAME", "unsloth/Qwen3.5-2B") LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3")) LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10")) MLX_MAX_TOKENS = int(os.getenv("MLX_MAX_TOKENS", "768")) MLX_RETRY_MAX_TOKENS = int(os.getenv("MLX_RETRY_MAX_TOKENS", "512")) MLX_JSON_PREFILL = os.getenv("MLX_JSON_PREFILL", "true").lower() == "true" MLX_COMPACT_ACTION = os.getenv("MLX_COMPACT_ACTION", "false").lower() == "true" MLX_COMPACT_REPAIR = os.getenv("MLX_COMPACT_REPAIR", "false").lower() == "true" MLX_FORCED_ACTION_TYPES = [ item.strip() for item in os.getenv("MLX_FORCED_ACTION_TYPES", "").split(",") if item.strip() ] JSON_PREFILL = '{"action_type":"' COMPACT_ACTION_SYSTEM_PROMPT = """ You control the MolForge action policy. Return exactly one JSON object with only these top-level keys: action_type, acting_role, edit_type, slot, fragment, tool_name, rationale, evidence, expected_effects. Valid action_type values are exactly: edit, run_assay, submit, restart, defer. Do not output team messages. Do not output proposal, approval, objection, risk_flag, assay_request, rejection, or submission_recommendation as action_type. The environment will attach governance messages automatically. Role rules: - run_assay uses acting_role "assay_planner" and a valid tool_name. - edit, submit, restart, and defer use acting_role "lead_chemist". - unused optional fields must be JSON null. """.strip() def main() -> None: adapter_path = ADAPTER_PATH.expanduser().resolve() print(f"Using MLX base model: {BASE_MODEL_NAME}", flush=True) print(f"Using LoRA adapter: {adapter_path}", flush=True) model, tokenizer = load(BASE_MODEL_NAME, adapter_path=str(adapter_path)) sampler = make_sampler(temp=0.0) env = MolForgeEnvironment() scores = [] submission_scores = [] progress_scores = [] for episode_index in range(LOCAL_NUM_EPISODES): observation = env.reset() print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True) for _ in range(LOCAL_MAX_TURNS): if observation.done: break action, source, elapsed = choose_mlx_action(model, tokenizer, sampler, observation) if MLX_COMPACT_ACTION: action = attach_team_messages(observation, action) observation = env.step(action) print( f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} " f"source={source} gen_s={elapsed:.2f} reward={observation.reward:+.3f} " f"budget={observation.remaining_budget} governance={observation.governance.status}", flush=True, ) print(f" {observation.last_transition_summary}", flush=True) if observation.done: break grader_scores = observation.metadata.get("terminal_grader_scores", {}) final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0))) submission_score = float(grader_scores.get("submission_score", 0.0)) progress_score = float(grader_scores.get("progress_score", 0.0)) scores.append(final_score) submission_scores.append(submission_score) progress_scores.append(progress_score) print(f"final_score={final_score:.3f}", flush=True) print(f"submission_score={submission_score:.3f}", flush=True) print(f"progress_score={progress_score:.3f}", flush=True) if observation.report_card: print(observation.report_card, flush=True) average = sum(scores) / len(scores) average_progress = sum(progress_scores) / len(progress_scores) print("\n=== MLX LoRA Local Summary ===", flush=True) print( json.dumps( { "adapter": str(adapter_path), "base_model": BASE_MODEL_NAME, "scores": scores, "average_final_score": round(average, 4), "submission_scores": submission_scores, "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4), "progress_scores": progress_scores, "average_progress_score": round(average_progress, 4), }, indent=2, ), flush=True, ) def choose_mlx_action( model, tokenizer, sampler, observation: MolForgeObservation, ) -> Tuple[MolForgeAction, str, float]: started = time.perf_counter() action, error = ask_mlx_model( model, tokenizer, sampler, observation, compact=False, max_tokens=MLX_MAX_TOKENS, forced_action_type=None, ) if action is not None: return action, "mlx_lora_model", time.perf_counter() - started forced_errors = [] for forced_action_type in forced_action_types(observation): forced_action, forced_error = ask_mlx_model( model, tokenizer, sampler, observation, compact=True, max_tokens=MLX_RETRY_MAX_TOKENS, forced_action_type=forced_action_type, ) if forced_action is not None: return ( forced_action, f"mlx_lora_forced_{forced_action_type}", time.perf_counter() - started, ) forced_errors.append(f"{forced_action_type}:{forced_error}") retry_action, retry_error = ask_mlx_model( model, tokenizer, sampler, observation, compact=True, max_tokens=MLX_RETRY_MAX_TOKENS, forced_action_type=None, ) if retry_action is not None: return retry_action, "mlx_lora_compact_retry", time.perf_counter() - started raise RuntimeError( "MLX LoRA action failed: " f"full_prompt:{error} | forced:{' || '.join(forced_errors)} | compact_prompt:{retry_error}" ) def ask_mlx_model( model, tokenizer, sampler, observation: MolForgeObservation, *, compact: bool, max_tokens: int, forced_action_type: Optional[str], ) -> Tuple[Optional[MolForgeAction], str]: response_text = "" try: payload = ( compact_action_payload(observation) if MLX_COMPACT_ACTION else build_model_payload(observation, compact=compact) ) system_prompt = ( COMPACT_ACTION_SYSTEM_PROMPT if MLX_COMPACT_ACTION else (COMPACT_SYSTEM_PROMPT if compact else SYSTEM_PROMPT) ) response_text = generate_response( model, tokenizer, sampler, system_prompt=system_prompt, user_payload=payload, max_tokens=max_tokens, use_json_prefill=MLX_JSON_PREFILL, forced_action_type=forced_action_type, ) if MLX_JSON_PREFILL: response_text = json_prefill(forced_action_type) + response_text data = extract_json(response_text) repair_notes: list[str] = [] if MLX_COMPACT_ACTION and MLX_COMPACT_REPAIR: data, repair_notes = repair_compact_action(data) if MLX_COMPACT_ACTION and "messages" in data: raise ValueError("compact action output must not include messages") action = MolForgeAction(**data) if repair_notes: action.metadata["compact_repair_notes"] = repair_notes return action, "" except Exception as exc: snippet = response_text[:1200].replace("\n", "\\n") return None, f"{exc.__class__.__name__}:{exc}; raw={snippet}" def generate_response( model, tokenizer, sampler, *, system_prompt: str, user_payload: Dict[str, Any], max_tokens: int, use_json_prefill: bool, forced_action_type: Optional[str], ) -> str: messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))}, ] prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) if use_json_prefill: prompt += json_prefill(forced_action_type) return generate( model, tokenizer, prompt, verbose=False, max_tokens=max_tokens, sampler=sampler, ).strip() def json_prefill(forced_action_type: Optional[str]) -> str: if forced_action_type: return f'{{"action_type":"{forced_action_type}",' return JSON_PREFILL def forced_action_types(observation: MolForgeObservation) -> list[str]: if MLX_FORCED_ACTION_TYPES: return MLX_FORCED_ACTION_TYPES if observation.step_index == 0: if observation.scenario_id == "level_2_hard": return ["restart", "edit", "run_assay", "defer"] return ["edit", "run_assay", "defer"] return ["run_assay", "edit", "submit", "restart", "defer"] def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]: lead_view = next( (role.observation for role in observation.role_observations if role.role == "lead_chemist"), {}, ) assay_view = next( (role.observation for role in observation.role_observations if role.role == "assay_planner"), {}, ) return { "valid_action_types": ["edit", "run_assay", "submit", "restart", "defer"], "scenario_id": observation.scenario_id, "difficulty": observation.difficulty, "task_brief": observation.task_brief, "current_molecule": observation.current_molecule, "current_smiles": observation.metadata.get("current_smiles", ""), "visible_metrics": observation.visible_metrics, "constraint_status": [constraint.model_dump() for constraint in observation.constraint_status], "remaining_budget": observation.remaining_budget, "max_budget": observation.max_budget, "step_index": observation.step_index, "max_steps": observation.max_steps, "molecule_slots": lead_view.get("molecule_slots", {}), "candidate_edits": lead_view.get("candidate_edits", [])[:12], "open_questions": lead_view.get("open_questions", []), "known_assays": [ { "tool_name": reading.tool_name, "property_name": reading.property_name, "estimate": reading.estimate, "confidence_low": reading.confidence_low, "confidence_high": reading.confidence_high, "molecule_signature": reading.molecule_signature, } for reading in observation.known_assays[-8:] ], "tool_costs": assay_view.get("tool_costs", {}), "evidence_gaps": assay_view.get("evidence_gaps", []), "estimated_information_value": assay_view.get("estimated_information_value", {}), } def repair_compact_action(data: Dict[str, Any]) -> tuple[Dict[str, Any], list[str]]: """Bounded normalization for compact-action models. This repairs only schema-near-misses. It does not invent an action from a non-action wrapper and it still rejects invalid top-level action types. """ repaired = dict(data) notes: list[str] = [] if "role" in repaired and "acting_role" not in repaired: repaired["acting_role"] = repaired.pop("role") notes.append("role->acting_role") action_type = repaired.get("action_type") if action_type not in {"edit", "run_assay", "submit", "restart", "defer"}: return repaired, notes if repaired.get("edit_type") == "replace": repaired["edit_type"] = "substitute" notes.append("edit_type:replace->substitute") if isinstance(repaired.get("evidence"), str): repaired["evidence"] = [repaired["evidence"]] notes.append("evidence:string->list") repaired["expected_effects"] = repair_effects(repaired.get("expected_effects"), notes) if action_type == "run_assay": repaired["acting_role"] = "assay_planner" repaired["edit_type"] = None repaired["slot"] = None repaired["fragment"] = None if repaired.get("tool_name") not in { "evaluate_properties", "dock_target", "assay_toxicity", "estimate_synthesizability", "evaluate_novelty", "search_literature", "run_md_simulation", }: repaired["tool_name"] = "evaluate_properties" notes.append("tool_name:invalid->evaluate_properties") else: repaired["acting_role"] = "lead_chemist" if action_type == "edit": if repaired.get("edit_type") not in {"add_fragment", "substitute", "remove", "undo_last_edit"}: repaired["edit_type"] = "substitute" notes.append("edit_type:invalid->substitute") if repaired.get("tool_name") is not None: repaired["tool_name"] = None notes.append("tool_name:edit->null") else: for key in ("edit_type", "slot", "fragment", "tool_name"): if repaired.get(key) is not None: repaired[key] = None notes.append(f"{key}:{action_type}->null") allowed_keys = { "action_type", "acting_role", "edit_type", "slot", "fragment", "tool_name", "rationale", "evidence", "expected_effects", } for key in list(repaired): if key not in allowed_keys: repaired.pop(key) notes.append(f"drop_extra:{key}") repaired.setdefault("rationale", "Choose the next compact MolForge action.") repaired.setdefault("evidence", []) for key in ("edit_type", "slot", "fragment", "tool_name"): repaired.setdefault(key, None) return repaired, notes def repair_effects(value: Any, notes: list[str]) -> dict[str, str]: defaults = { "potency": "unknown", "toxicity": "unknown", "synth": "unknown", "novelty": "unknown", "budget": "neutral", } if not isinstance(value, dict): notes.append("expected_effects:non_dict->defaults") return defaults aliases = { "synthesizability": "synth", "synthesis": "synth", } for raw_key, raw_value in value.items(): key = aliases.get(raw_key, raw_key) if key not in defaults: notes.append(f"expected_effects:drop_extra:{raw_key}") continue defaults[key] = normalize_effect_value(raw_value, notes, key) return defaults def normalize_effect_value(value: Any, notes: list[str], key: str) -> str: if value in {"up", "down", "neutral", "unknown", "not_applicable"}: return value text = str(value).lower().strip().replace("-", "_").replace(" ", "_") if any(token in text for token in ("increase", "improve", "higher", "upward", "+")): notes.append(f"expected_effects:{key}:{value}->up") return "up" if any(token in text for token in ("decrease", "lower", "reduce", "downward", "-")): notes.append(f"expected_effects:{key}:{value}->down") return "down" if any(token in text for token in ("maintain", "stable", "unchanged", "same")): notes.append(f"expected_effects:{key}:{value}->neutral") return "neutral" if "not_applicable" in text or text == "na": notes.append(f"expected_effects:{key}:{value}->not_applicable") return "not_applicable" notes.append(f"expected_effects:{key}:{value}->unknown") return "unknown" if __name__ == "__main__": main()