"""Fast generation evaluation for TMF921 intent-to-config models/adapters.

Metrics: JSON parse rate, exact canonical JSON match, field-level F1, and simple metadata
constraints, stratified by split/target_layer/slice_type/lifecycle_operation.

The evaluator is resumable and batched: it periodically writes partial predictions so a
long OOD run can be stopped and restarted without losing completed examples.
"""
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from tmf921_train.utils import (
    aggregate_metrics,
    field_f1,
    get_message,
    json_exact_match,
    metadata_constraint_pass,
    parse_json,
    write_json,
)

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Base model or merged model path/id")
    p.add_argument("--adapter", default=None, help="Optional PEFT adapter path/id")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument(
        "--splits",
        nargs="+",
        default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"],
    )
    p.add_argument("--output_dir", default="eval_outputs")
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument("--max_new_tokens", type=int, default=1536, help="Hard cap for generation. 1536 covers audited dataset outputs with margin.")
    p.add_argument("--gold_length_buffer", type=int, default=96, help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens")
    p.add_argument("--batch_size", type=int, default=4, help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.")
    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--trust_remote_code", action="store_true", default=True)
    return p.parse_args()
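# Decoding is greedy by default (temperature=0.0 turns sampling off in generate_batch).
# A hypothetical sampled run, using the placeholder script path from above, would be:
#   python path/to/this_script.py --model <model> --temperature 0.7 --top_p 0.9
# Note that --resume is on by default; pass --no_resume to regenerate every example.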
def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Strip the final gold assistant turn so only the prompt context is sent to the model."""
    out = []
    for i, m in enumerate(messages):
        if i == len(messages) - 1 and m.get("role") == "assistant":
            break
        out.append({"role": m.get("role"), "content": m.get("content", "")})
    if not out:
        out = [m for m in messages if m.get("role") != "assistant"]
    return out
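# Worked example (hypothetical messages, not taken from the dataset):
#   make_prompt_messages([
#       {"role": "system", "content": "You translate TMF921 intents into configuration."},
#       {"role": "user", "content": "Deploy an eMBB slice with ..."},
#       {"role": "assistant", "content": "{ ...gold JSON... }"},
#   ])
# keeps only the system and user turns; the trailing assistant turn is the held-out target.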
def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    return tokenizer.apply_chat_template(make_prompt_messages(messages), tokenize=False, add_generation_prompt=True)
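# With add_generation_prompt=True the chat template appends the assistant header so the
# model generates the answer directly. For a ChatML-style template the rendered prompt
# ends roughly like this (exact tokens depend on the model's template):
#   <|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n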
def gold_text(example: Dict[str, Any]) -> str:
    return example.get("completion") or get_message(example["messages"], "assistant")

def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Cap generation at the longest gold output in the batch plus a safety buffer."""
    lens = []
    for ex in examples:
        ids = tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"]
        lens.append(len(ids))
    return max(16, min(int(args.max_new_tokens), max(lens) + int(args.gold_length_buffer)))
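# Worked example with hypothetical token counts: if the gold outputs in a batch tokenize
# to 410 and 620 tokens and --gold_length_buffer is 96, the cap is
#   max(16, min(1536, 620 + 96)) = 716
# so short batches never pay for the full 1536-token hard limit.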
def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
    # Decoder-only models need left padding so every prompt ends right where generation starts.
    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side
    max_new = dynamic_max_new_tokens(tokenizer, examples, args)
    gen_kwargs = dict(
        max_new_tokens=max_new,
        do_sample=args.temperature > 0,
        temperature=args.temperature if args.temperature > 0 else None,
        top_p=args.top_p,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # Keep only newly generated tokens; the prompt occupies the first input_ids.shape[1] positions.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
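# Why left padding matters here (illustrative): with right padding a short prompt ends in
# pad tokens, e.g. [p1 p2 p3 PAD PAD], and generation would continue after the padding;
# with left padding the same prompt becomes [PAD PAD p1 p2 p3], so generation starts
# immediately after the real prompt for every row in the batch.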
def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, gold_err = parse_json(gold)
    out: Dict[str, Any] = {
        "id": example.get("id"),
        "target_layer": example.get("target_layer"),
        "slice_type": example.get("slice_type"),
        "lifecycle_operation": example.get("lifecycle_operation"),
        "parse_json": pred_obj is not None,
        "gold_parse_json": gold_obj is not None,
        "exact_match": False,
        "prediction": prediction,
        "gold": gold,
        "parse_error": pred_err,
    }
    if pred_obj is not None and gold_obj is not None:
        out["exact_match"] = json_exact_match(pred_obj, gold_obj)
        out.update(field_f1(pred_obj, gold_obj))
        out.update(metadata_constraint_pass(example, prediction, pred_obj))
    else:
        # Unparseable predictions (or gold) score zero on field metrics and fail all constraint checks.
        out.update({"field_precision": 0.0, "field_recall": 0.0, "field_f1": 0.0, "field_tp": 0, "field_fp": 0, "field_fn": 0})
        out.update({"slice_sst_pass": False, "kpi_text_presence_pass": False, "adversarial_status_pass": False})
    return out
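# Each row is a flat dict, roughly (values illustrative; the field_* keys come from
# field_f1 and the *_pass keys from metadata_constraint_pass):
#   {"id": "123", "target_layer": "...", "parse_json": True, "exact_match": False,
#    "field_precision": 0.91, "field_recall": 0.88, "field_f1": 0.89, ...}
# aggregate_metrics() and the per-group breakdowns below consume these rows directly.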
def require_cuda():
    print("=== CUDA CHECK ===")
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available to PyTorch. Refusing to run generation on CPU. "
            "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
        )
    print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")

def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    if path.exists():
        rows = json.loads(path.read_text())
        done = {str(r.get("id")) for r in rows}
        return rows, done
    return [], set()
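# Resume support: predictions.json is the source of truth for completed work. On restart,
# ids already present there are skipped in main(), so interrupting a long OOD run loses at
# most the examples generated since the last --save_every checkpoint.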
def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for key in ["target_layer", "slice_type", "lifecycle_operation"]:
        groups = defaultdict(list)
        for r in rows:
            groups[str(r.get(key))].append(r)
        summary[f"by_{key}"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
    write_json(split_dir / "metrics.json", summary)
    return summary
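# metrics.json therefore nests one aggregate block per group value, shaped roughly like
# (the exact metric keys depend on aggregate_metrics in tmf921_train.utils):
#   {"exact_match": ..., "field_f1": ...,
#    "by_target_layer": {"<layer>": {...}, ...},
#    "by_slice_type": {...}, "by_lifecycle_operation": {...}}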
def main():
    args = parse_args()
    require_cuda()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.adapter or args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    quant_config = None
    if args.load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        dtype=torch.bfloat16 if args.bf16 else torch.float16,
        device_map="auto",
        quantization_config=quant_config,
    )
    if args.adapter:
        model = PeftModel.from_pretrained(model, args.adapter)
    model.eval()

    ds = load_dataset(args.dataset)
    all_summary = {}
    for split in args.splits:
        split_ds = ds[split]
        if args.max_samples_per_split:
            split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
        split_dir = out_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        pred_path = split_dir / "predictions.json"
        rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
        todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
        print(f"Evaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")
        pbar = tqdm(total=len(todo), desc=split)
        completed_since_save = 0
        for start in range(0, len(todo), args.batch_size):
            batch = todo[start:start + args.batch_size]
            try:
                preds = generate_batch(model, tokenizer, batch, args)
            except torch.cuda.OutOfMemoryError:
                # On OOM, free cached memory and retry the batch one example at a time.
                torch.cuda.empty_cache()
                if args.batch_size == 1 or len(batch) == 1:
                    raise
                preds = []
                for ex in batch:
                    preds.extend(generate_batch(model, tokenizer, [ex], args))
            for ex, pred in zip(batch, preds):
                rows.append(row_metrics(ex, pred.strip()))
            pbar.update(len(batch))
            completed_since_save += len(batch)
            if completed_since_save >= args.save_every:
                write_split_outputs(split_dir, rows)
                completed_since_save = 0
        pbar.close()
        summary = write_split_outputs(split_dir, rows)
        all_summary[split] = summary
        write_json(out_dir / "all_metrics.json", all_summary)
    print(json.dumps(all_summary, indent=2)[:4000])

if __name__ == "__main__":
    main()
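# Quick smoke test (illustrative flags; the script path is a placeholder as above):
#   python path/to/this_script.py --model <base-or-merged-model> --adapter <peft-adapter> \
#     --load_in_4bit --max_samples_per_split 8 --splits test_in_distribution --no_resume
# This exercises the full pipeline on a handful of examples before launching long OOD runs.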