"""Judge-facing baseline inference script for MolForge."""

from __future__ import annotations

import json
import math
import os
from typing import Any, Optional, cast

from openai import OpenAI

from inference_common import (
    COMPACT_SYSTEM_PROMPT,
    SYSTEM_PROMPT,
    build_model_payload,
    extract_json,
)

try:
    from molforge.models import MolForgeAction, MolForgeObservation
    from molforge.server.molforge_environment import MolForgeEnvironment
except ImportError:
    # Presumably the non-package layout (script run from inside the project
    # directory without `molforge` importable) -- confirm against packaging.
    from models import MolForgeAction, MolForgeObservation
    from server.molforge_environment import MolForgeEnvironment

# Model endpoint configuration; all three are required (checked in main()).
API_BASE_URL = os.getenv("API_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME")
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")

# Number of episodes per invocation and the per-episode turn cap.
NUM_EPISODES = 3
MAX_TURNS = 10

# Per-request model timeouts (seconds) and completion length cap.
MODEL_TIMEOUT_S = float(os.getenv("MODEL_TIMEOUT_S", "35"))
MODEL_LONG_TIMEOUT_S = float(os.getenv("MODEL_LONG_TIMEOUT_S", "45"))
MODEL_RETRY_TIMEOUT_S = float(os.getenv("MODEL_RETRY_TIMEOUT_S", "15"))
MODEL_MAX_TOKENS = int(os.getenv("MODEL_MAX_TOKENS", "220"))

# Validator-facing scores must lie strictly inside (0, 1). They are printed
# and rounded with six decimals, so the bounds are the extreme six-decimal
# values that still sit inside the open interval.
MIN_REPORTED_SCORE = 1e-6
MAX_REPORTED_SCORE = 1.0 - 1e-6


def main() -> None:
    """Run the model-driven baseline over NUM_EPISODES episodes.

    For each episode this prints a ``[START]`` line, one ``[STEP]`` line per
    environment step (or a single failure line), an ``[END]`` line with the
    terminal grader scores, and the observation's report card when present.
    A final ``[SUMMARY]`` line carries the aggregated scores as compact JSON.

    Raises:
        RuntimeError: if API_BASE_URL, MODEL_NAME, or an API key is missing.
    """
    # Validate configuration before constructing the environment so a
    # misconfigured run fails immediately with the actionable error.
    if not API_BASE_URL or not MODEL_NAME or not API_KEY:
        raise RuntimeError(
            "API_BASE_URL, MODEL_NAME, and API_KEY or HF_TOKEN are required. "
            "No heuristic fallback is available."
        )
    env = MolForgeEnvironment()
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    scores: list[float] = []
    raw_final_scores: list[float] = []
    submission_scores: list[float] = []
    progress_scores: list[float] = []
    model_action_count = 0

    for episode_index in range(NUM_EPISODES):
        observation = env.reset()
        task_name = observation.scenario_id
        episode_error = ""
        print(
            f"[START] task={task_name} difficulty={observation.difficulty} episode={episode_index + 1}",
            flush=True,
        )

        for _ in range(MAX_TURNS):
            if observation.done:
                break
            try:
                action = choose_action(client, observation)
                model_action_count += 1
                observation = env.step(action)
            except Exception as exc:
                # Abandon the episode on any failure; the END line is still
                # printed below from the last successful observation.
                # NOTE: exceptions raised by env.step are also labelled
                # action=model_error on this line.
                episode_error = f"{exc.__class__.__name__}:{exc}"
                print(
                    f"[STEP] task={task_name} step={observation.step_index + 1} "
                    "reward=0.000000 action=model_error status=failed",
                    flush=True,
                )
                break
            print(
                f"[STEP] task={task_name} step={observation.step_index} "
                f"reward={observation.reward:.6f} action={action.action_type} "
                f"actor={action.acting_role} status={observation.governance.status}",
                flush=True,
            )

        # Terminal scores come from the last observation's metadata; any
        # missing entry (e.g. an episode cut short) is reported as 0.0.
        metadata = observation.metadata or {}
        grader_scores = metadata.get("terminal_grader_scores") or {}
        raw_final_score = float(
            grader_scores.get("final_score", grader_scores.get("submission_score", 0.0))
        )
        final_score = reportable_score(raw_final_score)
        submission_score = float(grader_scores.get("submission_score", 0.0))
        progress_score = float(grader_scores.get("progress_score", 0.0))

        scores.append(final_score)
        raw_final_scores.append(raw_final_score)
        submission_scores.append(submission_score)
        progress_scores.append(progress_score)

        end_line = (
            f"[END] task={task_name} score={final_score:.6f} raw_score={raw_final_score:.6f} "
            f"submission_score={submission_score:.6f} progress_score={progress_score:.6f} "
            f"steps={observation.step_index}"
        )
        if episode_error:
            # json.dumps quotes/escapes the message so the line stays parseable.
            end_line += f" error={json.dumps(episode_error)}"
        print(end_line, flush=True)
        if observation.report_card:
            print(observation.report_card, flush=True)

    average = sum(scores) / len(scores)
    average_progress = sum(progress_scores) / len(progress_scores)
    summary = {
        "scores": scores,
        "raw_final_scores": raw_final_scores,
        "average_final_score": round(reportable_score(average), 6),
        "submission_scores": submission_scores,
        "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4),
        "progress_scores": progress_scores,
        "average_progress_score": round(average_progress, 4),
        "model_action_count": model_action_count,
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "fallback_enabled": False,
    }
    print("[SUMMARY] " + json.dumps(summary, separators=(",", ":")), flush=True)


def reportable_score(score: float) -> float:
    """Validator-facing scores must be strictly between 0 and 1.

    The score is clamped to ``[MIN_REPORTED_SCORE, MAX_REPORTED_SCORE]``. The
    check is against those bounds rather than only ``<= 0`` / ``>= 1``. The
    values are later formatted with ``:.6f`` and rounded to 6 places, so a
    value such as 0.9999996 would otherwise surface as ``1.000000``, and
    4e-7 as ``0.000000``. NaN is treated as the minimum so that it never
    reaches the output (``json.dumps`` would emit a non-standard ``NaN``).

    Args:
        score: Raw score; may be out of range, infinite, or NaN.

    Returns:
        A float inside the open interval (0, 1) that stays inside it after
        six-decimal rounding.
    """
    if math.isnan(score):
        return MIN_REPORTED_SCORE
    return min(max(score, MIN_REPORTED_SCORE), MAX_REPORTED_SCORE)


def choose_action(client: OpenAI, observation: MolForgeObservation) -> MolForgeAction:
    """Use the model and fail loudly when it cannot produce a valid action.

    Raises:
        RuntimeError: carrying the joined per-attempt errors from ask_model.
    """
    action, error = ask_model(client, observation)
    if action is None:
        raise RuntimeError(f"Model action failed: {error}")
    return action


def ask_model(client: OpenAI, observation: MolForgeObservation) -> tuple[Optional[MolForgeAction], str]:
    """Request a structured team action from the model and parse it safely.

    Tries the full prompt first (with a step-dependent timeout). On any
    failure it retries once with the compact prompt and MODEL_RETRY_TIMEOUT_S.

    Returns:
        ``(action, "")`` on success, or ``(None, errors)`` where ``errors``
        joins the failure from each attempt with `` | ``.
    """
    errors: list[str] = []
    try:
        full_payload = build_model_payload(observation, compact=False)
        timeout_s = model_timeout_for_step(observation)
        data = request_action_json(
            client=client,
            system_prompt=SYSTEM_PROMPT,
            user_payload=full_payload,
            timeout_s=timeout_s,
        )
        return MolForgeAction(**data), ""
    except Exception as exc:
        # Broad on purpose: network, parsing and action-validation failures
        # all fall through to the compact retry.
        errors.append(f"full_prompt:{exc.__class__.__name__}:{exc}")

    try:
        compact_payload = build_model_payload(observation, compact=True)
        data = request_action_json(
            client=client,
            system_prompt=COMPACT_SYSTEM_PROMPT,
            user_payload=compact_payload,
            timeout_s=MODEL_RETRY_TIMEOUT_S,
        )
        return MolForgeAction(**data), ""
    except Exception as retry_exc:
        errors.append(f"compact_prompt:{retry_exc.__class__.__name__}:{retry_exc}")

    return None, " | ".join(errors)


def request_action_json(
    *,
    client: OpenAI,
    system_prompt: str,
    user_payload: dict[str, Any],
    timeout_s: float,
) -> dict[str, Any]:
    """Call the remote model with a bounded timeout and parse a JSON action.

    Args:
        client: Base OpenAI client; a copy with ``timeout_s`` is used.
        system_prompt: System message content.
        user_payload: Serialized to indented JSON as the user message.
        timeout_s: Request timeout in seconds.

    Returns:
        Whatever ``extract_json`` parses from the model's reply text.

    Raises:
        ValueError: if the response contains no choices.
    """
    configured_client = client.with_options(timeout=timeout_s)
    completion = configured_client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.0,
        max_tokens=MODEL_MAX_TOKENS,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": json.dumps(user_payload, indent=2)},
        ],
    )
    if not completion.choices:
        raise ValueError("model response contained no choices")
    message_content = completion.choices[0].message.content
    if isinstance(message_content, list):
        # Some backends return content as a list of parts; concatenate the
        # text parts (assumes dict-shaped parts with an optional "text" key).
        text = "".join(part.get("text", "") for part in cast(list[dict[str, Any]], message_content))
    else:
        text = message_content or ""
    return extract_json(text)


def model_timeout_for_step(observation: MolForgeObservation) -> float:
    """Allow more time for high-value late-stage decisions without making every step unbounded.

    Returns MODEL_LONG_TIMEOUT_S for "hard" scenarios and for the last two
    steps before ``max_steps``; MODEL_TIMEOUT_S otherwise.
    """
    if observation.difficulty == "hard":
        return MODEL_LONG_TIMEOUT_S
    if observation.step_index >= observation.max_steps - 2:
        return MODEL_LONG_TIMEOUT_S
    return MODEL_TIMEOUT_S


if __name__ == "__main__":
    main()