PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /evaluate_model.py
nraptisss's picture
Speed up and resume OOD evaluation with batched dynamic generation
6f5475f verified
#!/usr/bin/env python3
"""Fast generation evaluation for TMF921 intent-to-config models/adapters.
Metrics: JSON parse rate, exact canonical JSON match, field-level F1, simple metadata constraints,
stratified by split/target_layer/slice_type/lifecycle_operation.
The evaluator is resumable and batched. It periodically writes partial predictions so a long OOD
run can be stopped/restarted without losing completed examples.
"""
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tmf921_train.utils import aggregate_metrics, field_f1, get_message, json_exact_match, metadata_constraint_pass, parse_json, write_json
def parse_args():
    """Parse command-line options for the evaluator.

    Boolean options that default to ``True`` (``--resume``, ``--bf16``,
    ``--trust_remote_code``) each get a matching ``--no_*`` switch: with
    ``action="store_true", default=True`` alone the option could never be
    turned off, which made the float16 loading path unreachable.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Base model or merged model path/id")
    p.add_argument("--adapter", default=None, help="Optional PEFT adapter path/id")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
    p.add_argument("--output_dir", default="eval_outputs")
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument("--max_new_tokens", type=int, default=1536, help="Hard cap for generation. 1536 covers audited dataset outputs with margin.")
    p.add_argument("--gold_length_buffer", type=int, default=96, help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens")
    p.add_argument("--batch_size", type=int, default=4, help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.")
    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
    p.add_argument("--bf16", action="store_true", default=True, help="Load weights in bfloat16 (default)")
    p.add_argument("--no_bf16", dest="bf16", action="store_false", help="Load weights in float16 instead of bfloat16")
    p.add_argument("--trust_remote_code", action="store_true", default=True)
    p.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")
    args = p.parse_args()
    # A zero/negative batch size would only fail later as `range(..., 0)`; reject it up front.
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    return args
def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return the conversation to feed the model, without the reference answer.

    If the final turn has role ``assistant`` it is dropped. Every remaining turn
    is copied into a fresh ``{"role": ..., "content": ...}`` dict, with a
    missing ``content`` defaulting to ``""``.
    """
    ends_with_answer = bool(messages) and messages[-1].get("role") == "assistant"
    prompt_turns = messages[:-1] if ends_with_answer else messages
    return [{"role": turn.get("role"), "content": turn.get("content", "")} for turn in prompt_turns]
def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the prompt turns of *messages* through the tokenizer's chat template.

    The trailing assistant answer (if any) is removed by ``make_prompt_messages``.
    The result is an untokenized string rendered with
    ``add_generation_prompt=True``.
    """
    prompt_turns = make_prompt_messages(messages)
    rendered = tokenizer.apply_chat_template(
        prompt_turns,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference output for *example*.

    A truthy ``completion`` field wins. Otherwise the assistant message in
    ``example["messages"]`` is looked up via ``get_message``.
    """
    completion = example.get("completion")
    if completion:
        return completion
    return get_message(example["messages"], "assistant")
def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Choose the generation budget for one batch.

    The budget is the longest gold output in the batch (tokens counted without
    special tokens) plus ``args.gold_length_buffer``, floored at 16. The result
    is then clipped to ``args.max_new_tokens``. The hard cap is applied last,
    so it is never exceeded, even when it is configured below the 16-token floor.

    Args:
        tokenizer: HF tokenizer used to measure gold lengths.
        examples: the batch being generated; an empty batch is allowed.
        args: parsed CLI namespace (reads ``max_new_tokens`` and ``gold_length_buffer``).

    Returns:
        Number of new tokens to allow for this batch.
    """
    hard_cap = int(args.max_new_tokens)
    golds = [gold_text(ex) for ex in examples]
    # One batched tokenizer call instead of one call per example.
    encoded = tokenizer(golds, add_special_tokens=False)["input_ids"] if golds else []
    longest = max((len(ids) for ids in encoded), default=0)
    return min(hard_cap, max(16, longest + int(args.gold_length_buffer)))
def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one completion per example in a single left-padded batch.

    Prompts are rendered with the chat template and tokenized with left padding.
    The tokenizer's original ``padding_side`` is restored afterwards. Decoding
    skips special tokens. Sampling is enabled only when ``args.temperature > 0``.

    Args:
        model: causal LM (optionally PEFT-wrapped) used for ``generate``.
        tokenizer: matching tokenizer with a pad token set.
        examples: batch of dataset rows, each with a ``messages`` field.
        args: parsed CLI namespace.

    Returns:
        Decoded generations, in the same order as *examples*.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
    old_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        # The chat template already inserted the special tokens it needs;
        # letting the tokenizer add them again would duplicate e.g. a BOS token.
        inputs = tokenizer(texts, return_tensors="pt", padding=True, add_special_tokens=False).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side
    max_new = dynamic_max_new_tokens(tokenizer, examples, args)
    # `pad_token_id or eos_token_id` would wrongly discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen_kwargs: Dict[str, Any] = dict(
        max_new_tokens=max_new,
        do_sample=args.temperature > 0,
        temperature=args.temperature if args.temperature > 0 else None,
        top_p=args.top_p,
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # With left padding every prompt ends at the same column, so this slice keeps only new tokens.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score one prediction against its gold output and return the result row.

    The row holds:
    - the example's ``id`` and stratification fields;
    - the raw prediction and gold text;
    - the prediction's parse error;
    - the per-row metrics.

    Exact match, field F1 and the metadata constraint checks are computed only
    when both prediction and gold parse as JSON. Otherwise they are filled
    with zero / ``False`` placeholders.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, _gold_err = parse_json(gold)

    row: Dict[str, Any] = {key: example.get(key) for key in ("id", "target_layer", "slice_type", "lifecycle_operation")}
    row["parse_json"] = pred_obj is not None
    row["gold_parse_json"] = gold_obj is not None
    row["exact_match"] = False
    row["prediction"] = prediction
    row["gold"] = gold
    row["parse_error"] = pred_err

    if pred_obj is None or gold_obj is None:
        row.update({
            "field_precision": 0.0,
            "field_recall": 0.0,
            "field_f1": 0.0,
            "field_tp": 0,
            "field_fp": 0,
            "field_fn": 0,
            "slice_sst_pass": False,
            "kpi_text_presence_pass": False,
            "adversarial_status_pass": False,
        })
        return row

    row["exact_match"] = json_exact_match(pred_obj, gold_obj)
    row.update(field_f1(pred_obj, gold_obj))
    row.update(metadata_constraint_pass(example, prediction, pred_obj))
    return row
def require_cuda():
    """Print the CUDA environment and refuse to continue without a CUDA device.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` is false.
    """
    print("=== CUDA CHECK ===")
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible}")
    if torch.cuda.is_available():
        print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")
        return
    raise RuntimeError(
        "CUDA is not available to PyTorch. Refusing to run generation on CPU. "
        "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
    )
def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously written prediction rows so an interrupted run can resume.

    Args:
        path: location of a ``predictions.json`` file.

    Returns:
        ``(rows, done_ids)``, where ``done_ids`` is the set of ``str(row.get("id"))``
        values found in the rows. Returns ``([], set())`` when the file does not exist.

    Raises:
        ValueError: if the file is not valid JSON or does not contain a list
            (e.g. a truncated partial write). Delete it or pass ``--no_resume``.
    """
    if not path.exists():
        return [], set()
    # Decode explicitly instead of relying on the locale's default encoding.
    # NOTE(review): assumes write_json emits UTF-8 (or ASCII-escaped) JSON — confirm.
    text = path.read_text(encoding="utf-8")
    try:
        rows = json.loads(text)
    except json.JSONDecodeError as err:
        raise ValueError(f"Cannot resume: {path} is not valid JSON ({err}). Delete it or pass --no_resume.") from err
    if not isinstance(rows, list):
        raise ValueError(f"Cannot resume: {path} does not contain a list of prediction rows. Delete it or pass --no_resume.")
    done = {str(r.get("id")) for r in rows}
    return rows, done
def _bucket_rows(rows: List[Dict[str, Any]], field_name: str) -> Dict[str, List[Dict[str, Any]]]:
    """Group *rows* by ``str(row.get(field_name))``, keeping row order inside each group."""
    buckets: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        buckets[str(row.get(field_name))].append(row)
    return buckets


def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Persist one split's predictions and metrics, returning the metrics.

    Writes ``predictions.json`` (the rows as given) and ``metrics.json`` into
    *split_dir*. The metrics are ``aggregate_metrics`` over all rows plus one
    ``by_<field>`` mapping per stratification field. Each mapping goes from the
    field's stringified value (sorted) to ``aggregate_metrics`` of that group.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for field_name in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets = _bucket_rows(rows, field_name)
        summary[f"by_{field_name}"] = {label: aggregate_metrics(buckets[label]) for label in sorted(buckets)}
    write_json(split_dir / "metrics.json", summary)
    return summary
def _load_tokenizer_and_model(args):
    """Load the tokenizer and the (optionally 4-bit, optionally PEFT-wrapped) model in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(args.adapter or args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        # Batched generation pads prompts; reuse EOS when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    quant_config = None
    if args.load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        dtype=torch.bfloat16 if args.bf16 else torch.float16,
        device_map="auto",
        quantization_config=quant_config,
    )
    if args.adapter:
        model = PeftModel.from_pretrained(model, args.adapter)
    model.eval()
    return tokenizer, model


def _generate_with_oom_fallback(model, tokenizer, batch: List[Dict[str, Any]], args) -> List[str]:
    """Generate for *batch*; on CUDA OOM, retry the examples one at a time.

    The retry deliberately runs after the ``except`` clause has finished.
    While the handler is active, the exception's traceback keeps the failed
    call's frames (and the CUDA tensors they reference) alive, so
    ``empty_cache`` inside the handler cannot release that memory.
    """
    try:
        return generate_batch(model, tokenizer, batch, args)
    except torch.cuda.OutOfMemoryError:
        if len(batch) == 1:
            # Nothing smaller to fall back to.
            raise
    torch.cuda.empty_cache()
    preds: List[str] = []
    for ex in batch:
        preds.extend(generate_batch(model, tokenizer, [ex], args))
    return preds


def _evaluate_split(model, tokenizer, ds, split: str, out_dir: Path, args) -> Dict[str, Any]:
    """Evaluate one split (resuming from predictions.json if requested) and return its summary."""
    if split not in ds:
        raise KeyError(f"Split {split!r} not found in {args.dataset}; available: {sorted(ds.keys())}")
    split_ds = ds[split]
    if args.max_samples_per_split:
        split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
    split_dir = out_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)
    pred_path = split_dir / "predictions.json"
    rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
    todo = [ex for ex in split_ds if str(ex.get("id")) not in done_ids]
    # Count completed examples of this selection only. predictions.json may also hold
    # rows outside it (e.g. a rerun with a smaller --max_samples_per_split).
    already_done = len(split_ds) - len(todo)
    print(f"Evaluating {split}: total={len(split_ds)} already_done={already_done} remaining={len(todo)} batch_size={args.batch_size}")
    completed_since_save = 0
    with tqdm(total=len(todo), desc=split) as pbar:
        for start in range(0, len(todo), args.batch_size):
            batch = todo[start:start + args.batch_size]
            preds = _generate_with_oom_fallback(model, tokenizer, batch, args)
            rows.extend(row_metrics(ex, pred.strip()) for ex, pred in zip(batch, preds))
            pbar.update(len(batch))
            completed_since_save += len(batch)
            if completed_since_save >= args.save_every:
                # Periodic checkpoint so a stopped run can resume from here.
                write_split_outputs(split_dir, rows)
                completed_since_save = 0
    return write_split_outputs(split_dir, rows)


def main():
    """Entry point: evaluate each requested split and write per-split and combined metrics."""
    args = parse_args()
    require_cuda()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    tokenizer, model = _load_tokenizer_and_model(args)
    ds = load_dataset(args.dataset)
    all_summary = {}
    for split in args.splits:
        all_summary[split] = _evaluate_split(model, tokenizer, ds, split, out_dir, args)
        # Rewritten after each split so finished splits survive a crash in a later one.
        write_json(out_dir / "all_metrics.json", all_summary)
    print(json.dumps(all_summary, indent=2)[:4000])
# Script entry point: run the evaluator only when executed directly, not on import.
if __name__ == "__main__":
    main()