# molforge / mlx_lora_inference.py
# MolForge OpenEnv Docker Space submission (commit bf9e424, verified; author: Adhitya122)
"""MLX-backed local LoRA inference runner for MolForge on Apple Silicon."""
from __future__ import annotations
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler
from inference_common import (
COMPACT_SYSTEM_PROMPT,
SYSTEM_PROMPT,
attach_team_messages,
build_model_payload,
extract_json,
)
try:
from molforge.models import MolForgeAction, MolForgeObservation
from molforge.server.molforge_environment import MolForgeEnvironment
except ImportError:
from models import MolForgeAction, MolForgeObservation
from server.molforge_environment import MolForgeEnvironment
# Local filesystem path to the trained LoRA adapter weights.
ADAPTER_PATH = Path(os.getenv("LORA_ADAPTER_PATH", "qwen3_5_2b_lora_adapters_strict"))
# Identifier of the MLX base model the adapter is applied to.
BASE_MODEL_NAME = os.getenv("BASE_MODEL_NAME", "unsloth/Qwen3.5-2B")
# Number of local evaluation episodes to play.
LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3"))
# Hard cap on environment steps taken per episode.
LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10"))
# Token budget for the first (full-prompt) generation attempt.
MLX_MAX_TOKENS = int(os.getenv("MLX_MAX_TOKENS", "768"))
# Smaller token budget used by the forced/compact retry attempts.
MLX_RETRY_MAX_TOKENS = int(os.getenv("MLX_RETRY_MAX_TOKENS", "512"))
# When true, the prompt is seeded with a JSON prefix (see json_prefill())
# and the same prefix is prepended back onto the completion before parsing.
MLX_JSON_PREFILL = os.getenv("MLX_JSON_PREFILL", "true").lower() == "true"
# When true, use the compact action schema (COMPACT_ACTION_SYSTEM_PROMPT and
# compact_action_payload) instead of the full payload from inference_common.
MLX_COMPACT_ACTION = os.getenv("MLX_COMPACT_ACTION", "false").lower() == "true"
# When true (and compact mode is on), run repair_compact_action() on parsed
# model output to normalize schema near-misses.
MLX_COMPACT_REPAIR = os.getenv("MLX_COMPACT_REPAIR", "false").lower() == "true"
# Optional comma-separated override for the forced-retry action order;
# blank entries are discarded.
MLX_FORCED_ACTION_TYPES = [
    item.strip()
    for item in os.getenv("MLX_FORCED_ACTION_TYPES", "").split(",")
    if item.strip()
]
# Generic JSON prefix used when prefill is enabled and no action is forced.
JSON_PREFILL = '{"action_type":"'
# System prompt for compact-action mode: restricts output to the five
# executable action types and forbids governance-message output.
COMPACT_ACTION_SYSTEM_PROMPT = """
You control the MolForge action policy.
Return exactly one JSON object with only these top-level keys:
action_type, acting_role, edit_type, slot, fragment, tool_name, rationale,
evidence, expected_effects.
Valid action_type values are exactly:
edit, run_assay, submit, restart, defer.
Do not output team messages. Do not output proposal, approval, objection,
risk_flag, assay_request, rejection, or submission_recommendation as action_type.
The environment will attach governance messages automatically.
Role rules:
- run_assay uses acting_role "assay_planner" and a valid tool_name.
- edit, submit, restart, and defer use acting_role "lead_chemist".
- unused optional fields must be JSON null.
""".strip()
def main() -> None:
    """Run local MolForge episodes with the MLX LoRA policy and print a summary.

    Loads the base model plus LoRA adapter once, plays up to
    LOCAL_NUM_EPISODES episodes of at most LOCAL_MAX_TURNS steps each,
    printing a per-step trace and per-episode grader scores, then emits a
    JSON summary of all episode scores.
    """
    adapter_path = ADAPTER_PATH.expanduser().resolve()
    print(f"Using MLX base model: {BASE_MODEL_NAME}", flush=True)
    print(f"Using LoRA adapter: {adapter_path}", flush=True)
    model, tokenizer = load(BASE_MODEL_NAME, adapter_path=str(adapter_path))
    # temp=0.0 -> greedy decoding, so runs are reproducible.
    sampler = make_sampler(temp=0.0)
    env = MolForgeEnvironment()
    scores: list[float] = []
    submission_scores: list[float] = []
    progress_scores: list[float] = []
    for episode_index in range(LOCAL_NUM_EPISODES):
        observation = env.reset()
        print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True)
        for _ in range(LOCAL_MAX_TURNS):
            if observation.done:
                break
            action, source, elapsed = choose_mlx_action(model, tokenizer, sampler, observation)
            # In compact mode the model emits only the bare action; the
            # governance/team messages are attached here before stepping.
            if MLX_COMPACT_ACTION:
                action = attach_team_messages(observation, action)
            observation = env.step(action)
            print(
                f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} "
                f"source={source} gen_s={elapsed:.2f} reward={observation.reward:+.3f} "
                f"budget={observation.remaining_budget} governance={observation.governance.status}",
                flush=True,
            )
            print(f" {observation.last_transition_summary}", flush=True)
            if observation.done:
                break
        # Grader scores live in the terminal observation's metadata; fall
        # back to submission_score when final_score is absent.
        grader_scores = observation.metadata.get("terminal_grader_scores", {})
        final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0)))
        submission_score = float(grader_scores.get("submission_score", 0.0))
        progress_score = float(grader_scores.get("progress_score", 0.0))
        scores.append(final_score)
        submission_scores.append(submission_score)
        progress_scores.append(progress_score)
        print(f"final_score={final_score:.3f}", flush=True)
        print(f"submission_score={submission_score:.3f}", flush=True)
        print(f"progress_score={progress_score:.3f}", flush=True)
        if observation.report_card:
            print(observation.report_card, flush=True)
    if not scores:
        # Guard: LOCAL_NUM_EPISODES <= 0 would otherwise cause a
        # ZeroDivisionError when averaging below.
        print("No episodes were run; set LOCAL_NUM_EPISODES >= 1.", flush=True)
        return
    average = sum(scores) / len(scores)
    average_progress = sum(progress_scores) / len(progress_scores)
    print("\n=== MLX LoRA Local Summary ===", flush=True)
    print(
        json.dumps(
            {
                "adapter": str(adapter_path),
                "base_model": BASE_MODEL_NAME,
                "scores": scores,
                "average_final_score": round(average, 4),
                "submission_scores": submission_scores,
                "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
                "progress_scores": progress_scores,
                "average_progress_score": round(average_progress, 4),
            },
            indent=2,
        ),
        flush=True,
    )
def choose_mlx_action(
    model,
    tokenizer,
    sampler,
    observation: MolForgeObservation,
) -> Tuple[MolForgeAction, str, float]:
    """Pick the next action via a cascade of generation attempts.

    Order: full prompt, then one forced attempt per candidate action type,
    then a compact-prompt retry. Returns (action, source_label, seconds
    elapsed since the first attempt started); raises RuntimeError with all
    collected errors when every attempt fails.
    """
    start = time.perf_counter()

    def elapsed() -> float:
        return time.perf_counter() - start

    # Attempt 1: full prompt with the large token budget.
    primary_action, primary_error = ask_mlx_model(
        model,
        tokenizer,
        sampler,
        observation,
        compact=False,
        max_tokens=MLX_MAX_TOKENS,
        forced_action_type=None,
    )
    if primary_action is not None:
        return primary_action, "mlx_lora_model", elapsed()

    # Attempt 2: force each candidate action type in turn (compact prompt).
    forced_failures: list[str] = []
    for candidate in forced_action_types(observation):
        attempt, attempt_error = ask_mlx_model(
            model,
            tokenizer,
            sampler,
            observation,
            compact=True,
            max_tokens=MLX_RETRY_MAX_TOKENS,
            forced_action_type=candidate,
        )
        if attempt is not None:
            return attempt, f"mlx_lora_forced_{candidate}", elapsed()
        forced_failures.append(f"{candidate}:{attempt_error}")

    # Attempt 3: compact prompt without forcing an action type.
    compact_action, compact_error = ask_mlx_model(
        model,
        tokenizer,
        sampler,
        observation,
        compact=True,
        max_tokens=MLX_RETRY_MAX_TOKENS,
        forced_action_type=None,
    )
    if compact_action is not None:
        return compact_action, "mlx_lora_compact_retry", elapsed()

    raise RuntimeError(
        "MLX LoRA action failed: "
        f"full_prompt:{primary_error} | forced:{' || '.join(forced_failures)} | compact_prompt:{compact_error}"
    )
def ask_mlx_model(
    model,
    tokenizer,
    sampler,
    observation: MolForgeObservation,
    *,
    compact: bool,
    max_tokens: int,
    forced_action_type: Optional[str],
) -> Tuple[Optional[MolForgeAction], str]:
    """Run one generation attempt and parse it into a MolForgeAction.

    Returns (action, "") on success or (None, error_description) on any
    failure; never raises. The error description embeds up to 1200 chars
    of the raw model output for debugging.
    """
    raw_output = ""
    try:
        # Compact-action mode has its own payload and system prompt; the
        # default path uses the shared inference_common builders.
        if MLX_COMPACT_ACTION:
            payload = compact_action_payload(observation)
            system_prompt = COMPACT_ACTION_SYSTEM_PROMPT
        else:
            payload = build_model_payload(observation, compact=compact)
            system_prompt = COMPACT_SYSTEM_PROMPT if compact else SYSTEM_PROMPT
        raw_output = generate_response(
            model,
            tokenizer,
            sampler,
            system_prompt=system_prompt,
            user_payload=payload,
            max_tokens=max_tokens,
            use_json_prefill=MLX_JSON_PREFILL,
            forced_action_type=forced_action_type,
        )
        # The prefill prefix was part of the prompt, not the completion, so
        # it must be stitched back on before JSON extraction.
        if MLX_JSON_PREFILL:
            raw_output = json_prefill(forced_action_type) + raw_output
        parsed = extract_json(raw_output)
        repair_notes: list[str] = []
        if MLX_COMPACT_ACTION and MLX_COMPACT_REPAIR:
            parsed, repair_notes = repair_compact_action(parsed)
        if MLX_COMPACT_ACTION and "messages" in parsed:
            raise ValueError("compact action output must not include messages")
        action = MolForgeAction(**parsed)
        if repair_notes:
            action.metadata["compact_repair_notes"] = repair_notes
        return action, ""
    except Exception as exc:
        # Broad catch is deliberate: every failure mode (generation, JSON
        # parse, validation) is reported to the caller as a retryable error.
        snippet = raw_output[:1200].replace("\n", "\\n")
        return None, f"{exc.__class__.__name__}:{exc}; raw={snippet}"
def generate_response(
    model,
    tokenizer,
    sampler,
    *,
    system_prompt: str,
    user_payload: Dict[str, Any],
    max_tokens: int,
    use_json_prefill: bool,
    forced_action_type: Optional[str],
) -> str:
    """Render the chat prompt and run a single MLX generation pass.

    The user payload is serialized as minified JSON; when prefill is on,
    the JSON prefix is appended to the rendered prompt so the model
    continues an already-open JSON object. Returns the stripped completion.
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    if use_json_prefill:
        prompt_text = prompt_text + json_prefill(forced_action_type)
    completion = generate(
        model,
        tokenizer,
        prompt_text,
        verbose=False,
        max_tokens=max_tokens,
        sampler=sampler,
    )
    return completion.strip()
def json_prefill(forced_action_type: Optional[str]) -> str:
    """Return the JSON prefix used to seed generation.

    With a forced action type, the prefix pins "action_type" to that value;
    otherwise the generic JSON_PREFILL prefix is returned.
    """
    if not forced_action_type:
        return JSON_PREFILL
    return '{"action_type":"' + forced_action_type + '",'
def forced_action_types(observation: MolForgeObservation) -> list[str]:
    """Return the ordered action types to try in forced-retry mode.

    The MLX_FORCED_ACTION_TYPES env override wins when non-empty. Otherwise
    the opening turn tries edit/run_assay/defer (prefixed with restart on
    the level_2_hard scenario) and every later turn tries all five types.
    """
    if MLX_FORCED_ACTION_TYPES:
        return MLX_FORCED_ACTION_TYPES
    if observation.step_index != 0:
        return ["run_assay", "edit", "submit", "restart", "defer"]
    opening = ["edit", "run_assay", "defer"]
    if observation.scenario_id == "level_2_hard":
        return ["restart", *opening]
    return opening
def compact_action_payload(observation: MolForgeObservation) -> dict[str, Any]:
    """Build the trimmed per-turn payload for compact-action mode.

    Merges the shared observation fields with the lead-chemist and
    assay-planner role views (first match per role; an empty dict when a
    role view is absent), truncating candidate edits to 12 entries and
    assay readings to the 8 most recent.
    """

    def role_view(role_name: str) -> dict[str, Any]:
        # First matching role observation, or {} when the role is absent.
        for entry in observation.role_observations:
            if entry.role == role_name:
                return entry.observation
        return {}

    lead_view = role_view("lead_chemist")
    assay_view = role_view("assay_planner")

    recent_readings = []
    for reading in observation.known_assays[-8:]:
        recent_readings.append(
            {
                "tool_name": reading.tool_name,
                "property_name": reading.property_name,
                "estimate": reading.estimate,
                "confidence_low": reading.confidence_low,
                "confidence_high": reading.confidence_high,
                "molecule_signature": reading.molecule_signature,
            }
        )

    return {
        "valid_action_types": ["edit", "run_assay", "submit", "restart", "defer"],
        "scenario_id": observation.scenario_id,
        "difficulty": observation.difficulty,
        "task_brief": observation.task_brief,
        "current_molecule": observation.current_molecule,
        "current_smiles": observation.metadata.get("current_smiles", ""),
        "visible_metrics": observation.visible_metrics,
        "constraint_status": [constraint.model_dump() for constraint in observation.constraint_status],
        "remaining_budget": observation.remaining_budget,
        "max_budget": observation.max_budget,
        "step_index": observation.step_index,
        "max_steps": observation.max_steps,
        "molecule_slots": lead_view.get("molecule_slots", {}),
        "candidate_edits": lead_view.get("candidate_edits", [])[:12],
        "open_questions": lead_view.get("open_questions", []),
        "known_assays": recent_readings,
        "tool_costs": assay_view.get("tool_costs", {}),
        "evidence_gaps": assay_view.get("evidence_gaps", []),
        "estimated_information_value": assay_view.get("estimated_information_value", {}),
    }
def repair_compact_action(data: Dict[str, Any]) -> tuple[Dict[str, Any], list[str]]:
    """Bounded normalization for compact-action models.

    This repairs only schema-near-misses. It does not invent an action from a
    non-action wrapper and it still rejects invalid top-level action types.

    Args:
        data: Parsed JSON object emitted by the model.

    Returns:
        A ``(repaired, notes)`` pair, where ``notes`` holds one short tag per
        applied repair (empty when nothing changed).
    """
    repaired = dict(data)  # shallow copy: the caller's dict is not mutated
    notes: list[str] = []
    # Near-miss: model emitted "role" instead of "acting_role".
    if "role" in repaired and "acting_role" not in repaired:
        repaired["acting_role"] = repaired.pop("role")
        notes.append("role->acting_role")
    action_type = repaired.get("action_type")
    # Unknown action types are returned unrepaired so that downstream
    # MolForgeAction validation can reject them.
    if action_type not in {"edit", "run_assay", "submit", "restart", "defer"}:
        return repaired, notes
    # "replace" is a common synonym for the schema's "substitute".
    if repaired.get("edit_type") == "replace":
        repaired["edit_type"] = "substitute"
        notes.append("edit_type:replace->substitute")
    # Evidence must be a list; wrap a bare string.
    if isinstance(repaired.get("evidence"), str):
        repaired["evidence"] = [repaired["evidence"]]
        notes.append("evidence:string->list")
    # Normalize expected_effects; repair_effects always returns a full dict.
    repaired["expected_effects"] = repair_effects(repaired.get("expected_effects"), notes)
    if action_type == "run_assay":
        # Assays are always taken by the assay planner; edit-only fields
        # are cleared.
        repaired["acting_role"] = "assay_planner"
        repaired["edit_type"] = None
        repaired["slot"] = None
        repaired["fragment"] = None
        # Unknown tool names fall back to evaluate_properties.
        if repaired.get("tool_name") not in {
            "evaluate_properties",
            "dock_target",
            "assay_toxicity",
            "estimate_synthesizability",
            "evaluate_novelty",
            "search_literature",
            "run_md_simulation",
        }:
            repaired["tool_name"] = "evaluate_properties"
            notes.append("tool_name:invalid->evaluate_properties")
    else:
        # All non-assay actions are taken by the lead chemist.
        repaired["acting_role"] = "lead_chemist"
        if action_type == "edit":
            if repaired.get("edit_type") not in {"add_fragment", "substitute", "remove", "undo_last_edit"}:
                repaired["edit_type"] = "substitute"
                notes.append("edit_type:invalid->substitute")
            if repaired.get("tool_name") is not None:
                repaired["tool_name"] = None
                notes.append("tool_name:edit->null")
        else:
            # submit / restart / defer use none of the optional fields.
            for key in ("edit_type", "slot", "fragment", "tool_name"):
                if repaired.get(key) is not None:
                    repaired[key] = None
                    notes.append(f"{key}:{action_type}->null")
    # Drop anything outside the compact-action schema.
    allowed_keys = {
        "action_type",
        "acting_role",
        "edit_type",
        "slot",
        "fragment",
        "tool_name",
        "rationale",
        "evidence",
        "expected_effects",
    }
    for key in list(repaired):
        if key not in allowed_keys:
            repaired.pop(key)
            notes.append(f"drop_extra:{key}")
    # Backfill fields the schema expects to be present.
    repaired.setdefault("rationale", "Choose the next compact MolForge action.")
    repaired.setdefault("evidence", [])
    for key in ("edit_type", "slot", "fragment", "tool_name"):
        repaired.setdefault(key, None)
    return repaired, notes
def repair_effects(value: Any, notes: list[str]) -> dict[str, str]:
    """Normalize a model-emitted expected_effects object.

    Always returns a dict with exactly the five known effect keys. Unknown
    keys are dropped (with a note), synthesizability aliases are folded
    into "synth", and each kept value is normalized via
    normalize_effect_value(). Repairs are appended to *notes* in place.
    """
    effects = {
        "potency": "unknown",
        "toxicity": "unknown",
        "synth": "unknown",
        "novelty": "unknown",
        "budget": "neutral",
    }
    # Anything that is not a JSON object collapses to the defaults.
    if not isinstance(value, dict):
        notes.append("expected_effects:non_dict->defaults")
        return effects
    alias_map = {
        "synthesizability": "synth",
        "synthesis": "synth",
    }
    for raw_key, raw_value in value.items():
        canonical_key = alias_map.get(raw_key, raw_key)
        if canonical_key not in effects:
            notes.append(f"expected_effects:drop_extra:{raw_key}")
            continue
        effects[canonical_key] = normalize_effect_value(raw_value, notes, canonical_key)
    return effects
def normalize_effect_value(value: Any, notes: list[str], key: str) -> str:
    """Coerce one expected-effects value into the closed effect vocabulary.

    Args:
        value: Raw model-emitted value (usually, but not always, a string).
        notes: Repair log; a tag is appended whenever *value* is rewritten.
        key: Effect name, used only in the note tags.

    Returns:
        One of "up", "down", "neutral", "unknown", "not_applicable".
    """
    valid = {"up", "down", "neutral", "unknown", "not_applicable"}
    # isinstance guard: unhashable values (lists/dicts) would otherwise
    # raise TypeError on the set membership test instead of degrading
    # gracefully to "unknown" below.
    if isinstance(value, str) and value in valid:
        return value
    lowered = str(value).lower().strip()
    text = lowered.replace("-", "_").replace(" ", "_")
    if text in valid:
        # Case/spacing variant of a valid value, e.g. "Down", "not applicable".
        notes.append(f"expected_effects:{key}:{value}->{text}")
        return text
    if any(token in text for token in ("increase", "improve", "higher", "upward", "+")):
        notes.append(f"expected_effects:{key}:{value}->up")
        return "up"
    # "-" must be matched against the pre-replacement text: the "-" -> "_"
    # normalization above would otherwise make the token unmatchable
    # (e.g. "-10%" would fall through to "unknown").
    if "-" in lowered or any(token in text for token in ("decrease", "lower", "reduce", "downward")):
        notes.append(f"expected_effects:{key}:{value}->down")
        return "down"
    if any(token in text for token in ("maintain", "stable", "unchanged", "same")):
        notes.append(f"expected_effects:{key}:{value}->neutral")
        return "neutral"
    if "not_applicable" in text or text == "na":
        notes.append(f"expected_effects:{key}:{value}->not_applicable")
        return "not_applicable"
    notes.append(f"expected_effects:{key}:{value}->unknown")
    return "unknown"
# Script entry point: run the local MLX LoRA evaluation loop.
if __name__ == "__main__":
    main()