Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Post-training inference validation for adapter or merged model artifacts.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import re | |
| from typing import Any | |
| import sys | |
| ROOT = Path(__file__).resolve().parents[1] | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from app.env.env_core import PolyGuardEnv | |
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI parser for the post-save inference check."""
    cli = argparse.ArgumentParser(
        description="Validate inference from saved adapter/merged artifacts."
    )
    # Plain string options, registered table-style; order matters for --help.
    string_flags = (
        ("--merged-model", "checkpoints/merged"),
        ("--adapter-dir", "checkpoints/sft_adapter"),
        ("--base-model", ""),
        ("--prompts", "data/processed/training_corpus_grpo_prompts.jsonl"),
    )
    for flag, default in string_flags:
        cli.add_argument(flag, default=default)
    cli.add_argument("--samples", type=int, default=3)
    cli.add_argument("--output", default="outputs/reports/postsave_inference.json")
    return cli.parse_args()
| def _load_prompt_rows(path: Path, limit: int) -> list[dict[str, Any]]: | |
| if not path.exists(): | |
| return [] | |
| rows: list[dict[str, Any]] = [] | |
| with path.open("r", encoding="utf-8") as handle: | |
| for line in handle: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| payload = json.loads(line) | |
| except json.JSONDecodeError: | |
| continue | |
| if isinstance(payload, dict): | |
| rows.append(payload) | |
| if len(rows) >= limit: | |
| break | |
| return rows | |
| def _prompt_to_text(row: dict[str, Any]) -> str: | |
| prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {} | |
| candidates = prompt.get("candidates", prompt.get("candidate_set", [])) | |
| candidate_ids = [ | |
| str(item.get("candidate_id")) | |
| for item in candidates | |
| if isinstance(item, dict) and item.get("candidate_id") | |
| ] | |
| text = { | |
| "instruction": "Choose one candidate_id and justify briefly.", | |
| "patient_id": prompt.get("patient_id", prompt.get("patient_summary", {}).get("patient_id", "unknown")), | |
| "candidate_ids": candidate_ids, | |
| "format": "candidate_id=<cand_xx>; rationale=<text>", | |
| } | |
| return json.dumps(text, ensure_ascii=True) | |
| def _discover_base_model(adapter_dir: Path) -> str: | |
| cfg = adapter_dir / "adapter_config.json" | |
| if not cfg.exists(): | |
| return "" | |
| try: | |
| payload = json.loads(cfg.read_text(encoding="utf-8")) | |
| except json.JSONDecodeError: | |
| return "" | |
| value = payload.get("base_model_name_or_path") | |
| return str(value) if isinstance(value, str) else "" | |
def _load_model(
    merged_model: Path,
    adapter_dir: Path,
    base_model_arg: str,
):
    """Load the merged checkpoint if present, otherwise base model + adapter.

    Returns ``(model, tokenizer, source)`` where ``source`` is "merged" or
    "adapter".  Raises FileNotFoundError when the adapter directory is
    missing and RuntimeError when no base model can be determined.
    Heavy imports are kept local so the caller's fallback path can run
    without these libraries installed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # fp16 only makes sense on GPU; stay fp32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    merged_ready = merged_model.exists() and (merged_model / "config.json").exists()
    if merged_ready:
        tokenizer = AutoTokenizer.from_pretrained(str(merged_model))
        model = AutoModelForCausalLM.from_pretrained(
            str(merged_model),
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        return model, tokenizer, "merged"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"adapter_dir_not_found:{adapter_dir}")
    from peft import PeftModel

    # CLI override wins; otherwise read the adapter's recorded base model.
    base_name = base_model_arg.strip() or _discover_base_model(adapter_dir)
    if not base_name:
        raise RuntimeError("missing_base_model_for_adapter")
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    backbone = AutoModelForCausalLM.from_pretrained(
        base_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(backbone, str(adapter_dir))
    return model, tokenizer, "adapter"
| def _fallback_completion(row: dict[str, Any]) -> tuple[str, str | None]: | |
| prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {} | |
| candidates = prompt.get("candidates", prompt.get("candidate_set", [])) | |
| candidate_ids = [ | |
| str(item.get("candidate_id")) | |
| for item in candidates | |
| if isinstance(item, dict) and item.get("candidate_id") | |
| ] | |
| candidate_id = candidate_ids[0] if candidate_ids else None | |
| completion = ( | |
| f"candidate_id={candidate_id}; rationale=fallback_policy_artifact" | |
| if candidate_id | |
| else "candidate_id=cand_01; rationale=fallback_policy_artifact" | |
| ) | |
| return completion, candidate_id | |
| def _extract_candidate_id(text: str) -> str | None: | |
| match = re.search(r"cand_\d+", text.lower()) | |
| if not match: | |
| return None | |
| return match.group(0) | |
def main() -> None:
    """Run a small inference smoke test and write a JSON report.

    Loads the merged model (preferred) or base + adapter; if loading fails
    and a fallback policy artifact exists, continues with a deterministic
    fallback completion.  Each sample is scored against a fresh, seeded
    PolyGuardEnv episode and the aggregate report is written to --output.

    Bug fix vs. the previous version: torch was imported unconditionally,
    so the fallback path crashed on machines without torch installed.  The
    import and device selection now happen only when a model actually
    loaded.
    """
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    merged_model = (root / args.merged_model).resolve()
    adapter_dir = (root / args.adapter_dir).resolve()
    prompts_path = (root / args.prompts).resolve()
    rows = _load_prompt_rows(prompts_path, limit=max(1, args.samples))
    if not rows:
        raise SystemExit(f"no_prompts_loaded:{prompts_path}")
    fallback_policy_file = (root / "checkpoints" / "sft_policy_fallback.json").resolve()

    model = None
    tokenizer = None
    model_source = "fallback_policy"
    model_load_error = ""
    try:
        model, tokenizer, model_source = _load_model(
            merged_model=merged_model,
            adapter_dir=adapter_dir,
            base_model_arg=args.base_model,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as exc:  # noqa: BLE001
        model_load_error = str(exc)
        # Only swallow the load failure when a fallback policy artifact is
        # available; otherwise surface the original exception.
        if not fallback_policy_file.exists():
            raise

    device = "cpu"
    if model is not None:
        # Import torch lazily: the fallback path must work without torch.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()

    env = PolyGuardEnv()
    results: list[dict[str, Any]] = []
    for idx, row in enumerate(rows):
        # Fresh, deterministically seeded episode per sample.
        env.reset(seed=17_000 + idx, difficulty="medium")
        prompt_text = _prompt_to_text(row)
        if model is not None and tokenizer is not None:
            encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=80,
                    do_sample=False,
                    temperature=0.0,
                    eos_token_id=tokenizer.eos_token_id,
                )
            decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
            # Strip the echoed prompt when the model repeats it verbatim.
            completion = decoded[len(prompt_text) :].strip() if decoded.startswith(prompt_text) else decoded
            candidate_id = _extract_candidate_id(completion)
        else:
            completion, candidate_id = _fallback_completion(row)

        all_actions = env.get_candidate_actions()
        legal_actions = env.get_legal_actions()
        by_id_all = {str(item.get("candidate_id", "")).lower(): item for item in all_actions}
        by_id_legal = {str(item.get("candidate_id", "")).lower(): item for item in legal_actions}
        # Resolution order: legal action matching the model's pick, then any
        # known candidate, then the first legal action as a last resort.
        action = by_id_legal.get(str(candidate_id or "").lower())
        if action is None:
            action = by_id_all.get(str(candidate_id or "").lower())
        if action is None and legal_actions:
            action = legal_actions[0]
        if action is None:
            results.append(
                {
                    "idx": idx,
                    "prompt": prompt_text,
                    "completion": completion,
                    "candidate_id": candidate_id,
                    "selected_candidate": None,
                    "env_reward": 0.0,
                    "valid": False,
                    "reason": "no_action_available",
                }
            )
            continue
        _, reward, done, info = env.step(action)
        results.append(
            {
                "idx": idx,
                "prompt": prompt_text,
                "completion": completion,
                "candidate_id": candidate_id,
                "selected_candidate": action.get("candidate_id"),
                "env_reward": float(reward),
                "done": bool(done),
                "valid": bool(info.get("safety_report", {}).get("legal", False)),
                "termination_reason": info.get("termination_reason"),
            }
        )

    # `results` is non-empty here because `rows` was non-empty.
    valid_rate = sum(1.0 for entry in results if entry.get("valid")) / len(results)
    avg_reward = sum(float(entry.get("env_reward", 0.0)) for entry in results) / len(results)
    payload = {
        "status": "ok",
        "model_source": model_source,
        "model_load_error": model_load_error,
        "samples": len(results),
        "valid_rate": round(valid_rate, 6),
        "avg_env_reward": round(avg_reward, 6),
        "results": results,
    }
    output_path = root / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
    print("postsave_inference_ok")
| if __name__ == "__main__": | |
| main() | |