"""Local PEFT/LoRA inference runner for MolForge. Use this to test an SFT adapter against the environment before RL. It loads the base model named in the adapter config, attaches the LoRA weights, and requires the model to emit a valid MolForgeAction JSON object. There is no heuristic fallback or schema repair. """ from __future__ import annotations import json import os from pathlib import Path from typing import Any, Dict, Optional, Tuple import torch from peft import PeftConfig, PeftModel from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Qwen3_5ForConditionalGeneration from inference_common import ( COMPACT_SYSTEM_PROMPT, SYSTEM_PROMPT, build_model_payload, extract_json, ) try: from molforge.models import MolForgeAction, MolForgeObservation from molforge.server.molforge_environment import MolForgeEnvironment except ImportError: from models import MolForgeAction, MolForgeObservation from server.molforge_environment import MolForgeEnvironment ADAPTER_PATH = Path(os.getenv("LORA_ADAPTER_PATH", "qwen3_5_2b_lora_adapters")) LOCAL_NUM_EPISODES = int(os.getenv("LOCAL_NUM_EPISODES", "3")) LOCAL_MAX_TURNS = int(os.getenv("LOCAL_MAX_TURNS", "10")) LORA_MAX_NEW_TOKENS = int(os.getenv("LORA_MAX_NEW_TOKENS", "768")) LORA_RETRY_MAX_NEW_TOKENS = int(os.getenv("LORA_RETRY_MAX_NEW_TOKENS", "512")) LORA_DEVICE = os.getenv("LORA_DEVICE", "auto") def main() -> None: adapter_path = ADAPTER_PATH.expanduser().resolve() tokenizer, model, base_model_name, device = load_adapter_model(adapter_path) env = MolForgeEnvironment() scores = [] submission_scores = [] progress_scores = [] print(f"Using LoRA adapter: {adapter_path}", flush=True) print(f"Base model: {base_model_name}", flush=True) print(f"Device: {device}", flush=True) for episode_index in range(LOCAL_NUM_EPISODES): observation = env.reset() print(f"\n=== Episode {episode_index + 1}: {observation.scenario_id} ===", flush=True) for _ in range(LOCAL_MAX_TURNS): if observation.done: break action, source = choose_lora_action(tokenizer, model, observation, device) observation = env.step(action) print( f"step={observation.step_index:02d} action={action.action_type} actor={action.acting_role} " f"source={source} reward={observation.reward:+.3f} budget={observation.remaining_budget} " f"governance={observation.governance.status}", flush=True, ) print(f" {observation.last_transition_summary}", flush=True) if observation.done: break grader_scores = observation.metadata.get("terminal_grader_scores", {}) final_score = float(grader_scores.get("final_score", grader_scores.get("submission_score", 0.0))) submission_score = float(grader_scores.get("submission_score", 0.0)) progress_score = float(grader_scores.get("progress_score", 0.0)) scores.append(final_score) submission_scores.append(submission_score) progress_scores.append(progress_score) print(f"final_score={final_score:.3f}", flush=True) print(f"submission_score={submission_score:.3f}", flush=True) print(f"progress_score={progress_score:.3f}", flush=True) if observation.report_card: print(observation.report_card, flush=True) average = sum(scores) / len(scores) average_progress = sum(progress_scores) / len(progress_scores) print("\n=== LoRA Local Summary ===", flush=True) print( json.dumps( { "adapter": str(adapter_path), "base_model": base_model_name, "scores": scores, "average_final_score": round(average, 4), "submission_scores": submission_scores, "average_submission_score": round(sum(submission_scores) / len(submission_scores), 4), "progress_scores": progress_scores, "average_progress_score": round(average_progress, 4), }, indent=2, ), flush=True, ) def load_adapter_model(adapter_path: Path): config = PeftConfig.from_pretrained(adapter_path) base_model_name = config.base_model_name_or_path device = resolve_device() dtype = torch.float16 if device in {"cuda", "mps"} else torch.float32 tokenizer = AutoTokenizer.from_pretrained( adapter_path, trust_remote_code=True, use_fast=True, ) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token base_config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True) model_class = ( Qwen3_5ForConditionalGeneration if "Qwen3_5ForConditionalGeneration" in (base_config.architectures or []) else AutoModelForCausalLM ) base_model = model_class.from_pretrained( base_model_name, dtype=dtype, trust_remote_code=True, low_cpu_mem_usage=True, ) model = PeftModel.from_pretrained(base_model, adapter_path) model.to(device) model.eval() return tokenizer, model, base_model_name, device def resolve_device() -> str: if LORA_DEVICE != "auto": return LORA_DEVICE if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps" return "cpu" def choose_lora_action( tokenizer, model, observation: MolForgeObservation, device: str, ) -> Tuple[MolForgeAction, str]: action, error = ask_lora_model( tokenizer, model, observation, device, compact=False, max_new_tokens=LORA_MAX_NEW_TOKENS, ) if action is not None: return action, "lora_model" retry_action, retry_error = ask_lora_model( tokenizer, model, observation, device, compact=True, max_new_tokens=LORA_RETRY_MAX_NEW_TOKENS, ) if retry_action is not None: return retry_action, "lora_model_compact_retry" raise RuntimeError(f"LoRA model action failed: full_prompt:{error} | compact_prompt:{retry_error}") def ask_lora_model( tokenizer, model, observation: MolForgeObservation, device: str, *, compact: bool, max_new_tokens: int, ) -> Tuple[Optional[MolForgeAction], str]: response_text = "" try: payload = build_model_payload(observation, compact=compact) system_prompt = COMPACT_SYSTEM_PROMPT if compact else SYSTEM_PROMPT response_text = generate_response( tokenizer, model, device, system_prompt=system_prompt, user_payload=payload, max_new_tokens=max_new_tokens, ) data = extract_json(response_text) return MolForgeAction(**data), "" except Exception as exc: snippet = response_text[:1200].replace("\n", "\\n") return None, f"{exc.__class__.__name__}:{exc}; raw={snippet}" def generate_response( tokenizer, model, device: str, *, system_prompt: str, user_payload: Dict[str, Any], max_new_tokens: int, ) -> str: messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": json.dumps(user_payload, separators=(",", ":"))}, ] prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) inputs = tokenizer(prompt, return_tensors="pt").to(device) with torch.inference_mode(): generated = model.generate( **inputs, do_sample=False, temperature=None, top_p=None, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) new_tokens = generated[0, inputs["input_ids"].shape[-1] :] return tokenizer.decode(new_tokens, skip_special_tokens=True).strip() if __name__ == "__main__": main()