#!/usr/bin/env python3
"""Post-training inference validation for adapter or merged model artifacts."""

from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.env.env_core import PolyGuardEnv  # noqa: E402  (needs ROOT on sys.path)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the validation script."""
    parser = argparse.ArgumentParser(description="Validate inference from saved adapter/merged artifacts.")
    parser.add_argument("--merged-model", default="checkpoints/merged")
    parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    parser.add_argument("--base-model", default="")
    parser.add_argument("--prompts", default="data/processed/training_corpus_grpo_prompts.jsonl")
    parser.add_argument("--samples", type=int, default=3)
    parser.add_argument("--output", default="outputs/reports/postsave_inference.json")
    return parser.parse_args()


def _load_prompt_rows(path: Path, limit: int) -> list[dict[str, Any]]:
    """Read up to *limit* JSON-object rows from a JSONL file.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped. A missing file yields an empty list.
    """
    if not path.exists():
        return []

    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict):
                rows.append(payload)
            if len(rows) >= limit:
                break
    return rows


def _prompt_payload(row: dict[str, Any]) -> dict[str, Any]:
    """Return ``row["prompt"]`` when it is a dict, otherwise an empty dict."""
    prompt = row.get("prompt")
    return prompt if isinstance(prompt, dict) else {}


def _candidate_ids(prompt: dict[str, Any]) -> list[str]:
    """Collect the non-empty ``candidate_id`` values listed in a prompt payload.

    Candidates are read from ``"candidates"``, falling back to
    ``"candidate_set"``. A value that is not a list/tuple (for example
    ``null`` in the JSON) is treated as "no candidates" instead of raising.
    """
    candidates = prompt.get("candidates", prompt.get("candidate_set", []))
    if not isinstance(candidates, (list, tuple)):
        return []
    return [
        str(item.get("candidate_id"))
        for item in candidates
        if isinstance(item, dict) and item.get("candidate_id")
    ]


def _prompt_to_text(row: dict[str, Any]) -> str:
    """Serialize a prompt row into the compact JSON instruction given to the model."""
    prompt = _prompt_payload(row)

    if "patient_id" in prompt:
        patient_id = prompt["patient_id"]
    else:
        # Only trust patient_summary when it really is a mapping.
        summary = prompt.get("patient_summary")
        patient_id = summary.get("patient_id", "unknown") if isinstance(summary, dict) else "unknown"

    text = {
        "instruction": "Choose one candidate_id and justify briefly.",
        "patient_id": patient_id,
        "candidate_ids": _candidate_ids(prompt),
        "format": "candidate_id=; rationale=",
    }
    return json.dumps(text, ensure_ascii=True)


def _discover_base_model(adapter_dir: Path) -> str:
    """Read ``base_model_name_or_path`` from the adapter's ``adapter_config.json``.

    Returns an empty string when the file is missing, unreadable, not valid
    JSON, not a JSON object, or when the field is absent or not a string.
    """
    cfg = adapter_dir / "adapter_config.json"
    if not cfg.exists():
        return ""
    try:
        payload = json.loads(cfg.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return ""
    if not isinstance(payload, dict):
        return ""
    value = payload.get("base_model_name_or_path")
    return value if isinstance(value, str) else ""


def _load_model(
    merged_model: Path,
    adapter_dir: Path,
    base_model_arg: str,
):
    """Load a causal LM from the merged checkpoint, else from base model + adapter.

    The merged directory is used when it exists and contains ``config.json``.
    Otherwise the base model (from *base_model_arg*, or discovered from the
    adapter config) is loaded and wrapped with the PEFT adapter.

    Returns:
        ``(model, tokenizer, source)`` where *source* is ``"merged"`` or
        ``"adapter"``.

    Raises:
        FileNotFoundError: no usable merged model and *adapter_dir* is missing.
        RuntimeError: the adapter exists but no base model name is available.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    if merged_model.exists() and (merged_model / "config.json").exists():
        tokenizer = AutoTokenizer.from_pretrained(str(merged_model))
        model = AutoModelForCausalLM.from_pretrained(
            str(merged_model),
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        return model, tokenizer, "merged"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"adapter_dir_not_found:{adapter_dir}")

    from peft import PeftModel

    base_model = base_model_arg.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise RuntimeError("missing_base_model_for_adapter")

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter_dir))
    return model, tokenizer, "adapter"


def _fallback_completion(row: dict[str, Any]) -> tuple[str, str | None]:
    """Build a completion without a model by picking the row's first candidate.

    Returns ``(completion, candidate_id)``. When the row lists no candidates
    the completion text names ``cand_01`` while the returned id is ``None``.
    """
    candidate_ids = _candidate_ids(_prompt_payload(row))
    candidate_id = candidate_ids[0] if candidate_ids else None
    completion = (
        f"candidate_id={candidate_id}; rationale=fallback_policy_artifact"
        if candidate_id
        else "candidate_id=cand_01; rationale=fallback_policy_artifact"
    )
    return completion, candidate_id


def _extract_candidate_id(text: str) -> str | None:
    """Return the first ``cand_<digits>`` token in *text* (lower-cased), or ``None``."""
    match = re.search(r"cand_\d+", text.lower())
    return match.group(0) if match else None


def _generate_completion(model: Any, tokenizer: Any, prompt_text: str, device: str) -> str:
    """Greedy-decode a completion for *prompt_text* and return only the new text."""
    import torch

    encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=80,
            do_sample=False,  # greedy decoding; no sampling temperature applies
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt tokens before decoding. The prompt itself lists every
    # candidate id, so letting it leak into the completion would make the
    # candidate-id extraction match the prompt rather than the model's answer.
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()


def main() -> None:
    """Run the saved model (or the fallback policy) on sample prompts and write a report."""
    args = parse_args()
    root = ROOT
    merged_model = (root / args.merged_model).resolve()
    adapter_dir = (root / args.adapter_dir).resolve()
    prompts_path = (root / args.prompts).resolve()

    rows = _load_prompt_rows(prompts_path, limit=max(1, args.samples))
    if not rows:
        raise SystemExit(f"no_prompts_loaded:{prompts_path}")

    fallback_policy_file = (root / "checkpoints" / "sft_policy_fallback.json").resolve()
    model = None
    tokenizer = None
    model_source = "fallback_policy"
    model_load_error = ""
    try:
        model, tokenizer, model_source = _load_model(
            merged_model=merged_model,
            adapter_dir=adapter_dir,
            base_model_arg=args.base_model,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as exc:  # noqa: BLE001
        model_load_error = str(exc)
        # A load failure is tolerated only when the fallback policy artifact exists.
        if not fallback_policy_file.exists():
            raise

    device = "cpu"
    if model is not None:
        # Import torch only when a model was loaded, so the fallback path
        # keeps working on machines without torch installed.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()

    env = PolyGuardEnv()
    results: list[dict[str, Any]] = []
    for idx, row in enumerate(rows):
        env.reset(seed=17_000 + idx, difficulty="medium")
        prompt_text = _prompt_to_text(row)

        if model is not None and tokenizer is not None:
            completion = _generate_completion(model, tokenizer, prompt_text, device)
            candidate_id = _extract_candidate_id(completion)
        else:
            completion, candidate_id = _fallback_completion(row)

        all_actions = env.get_candidate_actions()
        legal_actions = env.get_legal_actions()
        by_id_all = {str(item.get("candidate_id", "")).lower(): item for item in all_actions}
        by_id_legal = {str(item.get("candidate_id", "")).lower(): item for item in legal_actions}

        # Resolution order: legal action with that id, any action with that
        # id, then the first legal action as a last resort.
        lookup_key = str(candidate_id or "").lower()
        action = by_id_legal.get(lookup_key)
        if action is None:
            action = by_id_all.get(lookup_key)
        if action is None and legal_actions:
            action = legal_actions[0]

        if action is None:
            results.append(
                {
                    "idx": idx,
                    "prompt": prompt_text,
                    "completion": completion,
                    "candidate_id": candidate_id,
                    "selected_candidate": None,
                    "env_reward": 0.0,
                    "valid": False,
                    "reason": "no_action_available",
                }
            )
            continue

        _, reward, done, info = env.step(action)
        results.append(
            {
                "idx": idx,
                "prompt": prompt_text,
                "completion": completion,
                "candidate_id": candidate_id,
                "selected_candidate": action.get("candidate_id"),
                "env_reward": float(reward),
                "done": bool(done),
                "valid": bool(info.get("safety_report", {}).get("legal", False)),
                "termination_reason": info.get("termination_reason"),
            }
        )

    # `rows` is non-empty here, and every row appends exactly one result.
    valid_rate = sum(1.0 for row in results if row.get("valid")) / len(results)
    avg_reward = sum(float(row.get("env_reward", 0.0)) for row in results) / len(results)
    payload = {
        "status": "ok",
        "model_source": model_source,
        "model_load_error": model_load_error,
        "samples": len(results),
        "valid_rate": round(valid_rate, 6),
        "avg_env_reward": round(avg_reward, 6),
        "results": results,
    }

    output_path = root / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
    print("postsave_inference_ok")


if __name__ == "__main__":
    main()