#!/usr/bin/env python3
"""Fast generation evaluation for TMF921 intent-to-config models/adapters.

Metrics: JSON parse rate, exact canonical JSON match, field-level F1, simple
metadata constraints, stratified by split/target_layer/slice_type/lifecycle_operation.

The evaluator is resumable and batched. It periodically writes partial
predictions so a long OOD run can be stopped/restarted without losing
completed examples.
"""
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from tmf921_train.utils import (
    aggregate_metrics,
    field_f1,
    get_message,
    json_exact_match,
    metadata_constraint_pass,
    parse_json,
    write_json,
)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the evaluator.

    Boolean options that default to True (``--resume``, ``--bf16``,
    ``--trust_remote_code``) each have a matching ``--no_*`` flag. Without it,
    a ``store_true`` option with ``default=True`` could never be switched off.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Base model or merged model path/id")
    p.add_argument("--adapter", default=None, help="Optional PEFT adapter path/id")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument(
        "--splits",
        nargs="+",
        default=[
            "test_in_distribution",
            "test_template_ood",
            "test_use_case_ood",
            "test_sector_ood",
            "test_adversarial",
        ],
    )
    p.add_argument("--output_dir", default="eval_outputs")
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument(
        "--max_new_tokens",
        type=int,
        default=1536,
        help="Hard cap for generation. 1536 covers audited dataset outputs with margin.",
    )
    p.add_argument(
        "--gold_length_buffer",
        type=int,
        default=96,
        help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens",
    )
    p.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.",
    )
    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--no_bf16", dest="bf16", action="store_false", help="Use float16 instead of bfloat16")
    p.add_argument("--trust_remote_code", action="store_true", default=True)
    p.add_argument(
        "--no_trust_remote_code",
        dest="trust_remote_code",
        action="store_false",
        help="Do not allow custom modelling code from the model/tokenizer repo",
    )
    args = p.parse_args()
    # range(0, n, 0) would raise an opaque ValueError deep inside main().
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    return args


def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return the chat turns to feed the model, dropping the gold assistant reply.

    Turns are copied (role/content only) in order. Copying stops at the final
    message if that message is an assistant turn. If nothing remains, fall back
    to every non-assistant message.
    """
    out = []
    for i, m in enumerate(messages):
        if i == len(messages) - 1 and m.get("role") == "assistant":
            break
        out.append({"role": m.get("role"), "content": m.get("content", "")})
    if not out:
        out = [m for m in messages if m.get("role") != "assistant"]
    return out


def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the prompt turns through the tokenizer's chat template, ending with a generation prompt."""
    return tokenizer.apply_chat_template(make_prompt_messages(messages), tokenize=False, add_generation_prompt=True)


def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference output: ``completion`` if truthy, else the assistant message."""
    return example.get("completion") or get_message(example["messages"], "assistant")


def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Compute the per-batch generation budget.

    The budget is the longest gold output in the batch (in tokens, no special
    tokens) plus ``args.gold_length_buffer``. It is clipped to
    ``args.max_new_tokens`` and is never below 16. ``examples`` must be
    non-empty.
    """
    # Batched call without padding gives the same per-text ids as one call per text.
    encoded = tokenizer([gold_text(ex) for ex in examples], add_special_tokens=False)["input_ids"]
    longest = max(len(ids) for ids in encoded)
    return max(16, min(int(args.max_new_tokens), longest + int(args.gold_length_buffer)))


def _pad_token_id(tokenizer) -> Optional[int]:
    """Return the tokenizer's pad id, falling back to EOS only when no pad id is set.

    This uses an explicit ``is None`` test. 0 is a valid token id, and
    ``pad_token_id or eos_token_id`` would wrongly replace it.
    """
    pad_id = tokenizer.pad_token_id
    return tokenizer.eos_token_id if pad_id is None else pad_id


def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one decoded completion per example, in input order.

    Sampling is enabled only when ``args.temperature > 0``. ``top_p`` is
    forwarded only in that case, because greedy decoding ignores it.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
    # Left padding keeps every prompt flush against the generated tokens,
    # so the slice below removes exactly the prompt. Restore the caller's setting.
    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side

    sampling = args.temperature > 0
    gen_kwargs: Dict[str, Any] = dict(
        max_new_tokens=dynamic_max_new_tokens(tokenizer, examples, args),
        do_sample=sampling,
        pad_token_id=_pad_token_id(tokenizer),
        eos_token_id=tokenizer.eos_token_id,
    )
    if sampling:
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p
    # Do not pass unset ids; let generate() use its own defaults.
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)


def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score one prediction against its gold output and return a flat result row.

    If either side fails to parse as JSON, the field metrics are zeroed and the
    constraint flags are set to False. ``exact_match`` stays False.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, gold_err = parse_json(gold)
    out: Dict[str, Any] = {
        "id": example.get("id"),
        "target_layer": example.get("target_layer"),
        "slice_type": example.get("slice_type"),
        "lifecycle_operation": example.get("lifecycle_operation"),
        "parse_json": pred_obj is not None,
        "gold_parse_json": gold_obj is not None,
        "exact_match": False,
        "prediction": prediction,
        "gold": gold,
        "parse_error": pred_err,
    }
    if pred_obj is not None and gold_obj is not None:
        out["exact_match"] = json_exact_match(pred_obj, gold_obj)
        out.update(field_f1(pred_obj, gold_obj))
        out.update(metadata_constraint_pass(example, prediction, pred_obj))
    else:
        out.update({"field_precision": 0.0, "field_recall": 0.0, "field_f1": 0.0, "field_tp": 0, "field_fp": 0, "field_fn": 0})
        out.update({"slice_sst_pass": False, "kpi_text_presence_pass": False, "adversarial_status_pass": False})
    return out


def require_cuda():
    """Print CUDA diagnostics and raise ``RuntimeError`` if PyTorch cannot see a GPU."""
    print("=== CUDA CHECK ===")
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available to PyTorch. Refusing to run generation on CPU. "
            "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
        )
    print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")


def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously saved rows and the set of their ids (as ``str``).

    Returns ``([], set())`` when ``path`` does not exist.

    Raises:
        ValueError: if the file exists but is not valid JSON. The usual cause
            is an interrupted write; rerun with ``--no_resume`` to start over.
    """
    if not path.exists():
        return [], set()
    try:
        rows = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        raise ValueError(f"Cannot resume: {path} is not valid JSON ({err}). Use --no_resume to start over.") from err
    done = {str(r.get("id")) for r in rows}
    return rows, done


def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Write ``predictions.json`` and ``metrics.json`` for a split; return the metrics summary.

    The summary holds the aggregate over all rows, plus ``by_<key>`` breakdowns
    for ``target_layer``, ``slice_type`` and ``lifecycle_operation``. Each
    breakdown is grouped by the stringified key value.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for key in ["target_layer", "slice_type", "lifecycle_operation"]:
        groups = defaultdict(list)
        for r in rows:
            groups[str(r.get(key))].append(r)
        summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
    write_json(split_dir / "metrics.json", summary)
    return summary


def _generate_with_oom_fallback(model, tokenizer, batch: List[Dict[str, Any]], args) -> List[str]:
    """Generate for ``batch``; on CUDA OOM, retry the examples one at a time.

    The retry runs *after* the ``except`` clause has exited. While the handler
    is active, the caught exception keeps the failed call's frames, and their
    GPU tensors, alive, so a retry inside it can run out of memory again.
    A single-example batch that OOMs is re-raised, since there is nothing
    smaller to try.
    """
    try:
        return generate_batch(model, tokenizer, batch, args)
    except torch.cuda.OutOfMemoryError:
        if len(batch) == 1:
            raise
    torch.cuda.empty_cache()
    preds: List[str] = []
    for ex in batch:
        preds.extend(generate_batch(model, tokenizer, [ex], args))
    return preds


def _evaluate_split(model, tokenizer, split_ds, split: str, out_dir: Path, args) -> Dict[str, Any]:
    """Run (or resume) generation for one dataset split and return its metrics summary.

    Outputs go to ``out_dir / split``. Partial outputs are written after every
    ``args.save_every`` newly completed examples, and once more at the end.
    """
    if args.max_samples_per_split:
        split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
    split_dir = out_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)
    pred_path = split_dir / "predictions.json"
    rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
    todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
    print(f"Evaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")

    completed_since_save = 0
    with tqdm(total=len(todo), desc=split) as pbar:
        for start in range(0, len(todo), args.batch_size):
            batch = todo[start:start + args.batch_size]
            preds = _generate_with_oom_fallback(model, tokenizer, batch, args)
            rows.extend(row_metrics(ex, pred.strip()) for ex, pred in zip(batch, preds))
            pbar.update(len(batch))
            completed_since_save += len(batch)
            if completed_since_save >= args.save_every:
                write_split_outputs(split_dir, rows)
                completed_since_save = 0
    return write_split_outputs(split_dir, rows)


def main():
    """Load the model (plus optional adapter), evaluate every requested split, and write metrics."""
    args = parse_args()
    require_cuda()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.adapter or args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # One dtype choice drives both weight loading and 4-bit compute.
    dtype = torch.bfloat16 if args.bf16 else torch.float16
    quant_config = None
    if args.load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=dtype,
        )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        dtype=dtype,
        device_map="auto",
        quantization_config=quant_config,
    )
    if args.adapter:
        model = PeftModel.from_pretrained(model, args.adapter)
    model.eval()

    ds = load_dataset(args.dataset)
    all_summary = {}
    for split in args.splits:
        all_summary[split] = _evaluate_split(model, tokenizer, ds[split], split, out_dir, args)
        # Refresh after each split, so an interrupted multi-split run still
        # leaves a combined summary for every split that finished.
        write_json(out_dir / "all_metrics.json", all_summary)
    print(json.dumps(all_summary, indent=2)[:4000])


if __name__ == "__main__":
    main()