PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 10,130 Bytes
d9ba941
6f5475f
d9ba941
 
 
6f5475f
 
 
d9ba941
 
 
 
 
 
6f5475f
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5475f
 
 
 
 
 
d9ba941
 
 
 
 
 
 
 
 
 
6f5475f
 
d9ba941
 
 
 
 
 
 
6f5475f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9ba941
6f5475f
d9ba941
 
 
 
 
 
 
6f5475f
d9ba941
6f5475f
 
d9ba941
 
 
6f5475f
d9ba941
6f5475f
d9ba941
 
 
 
 
 
 
 
 
6f5475f
d9ba941
 
 
 
 
 
 
 
 
 
 
 
91d636a
 
 
 
 
 
 
 
 
 
 
6f5475f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9ba941
 
91d636a
d9ba941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5475f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9ba941
6f5475f
d9ba941
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#!/usr/bin/env python3
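"""Fast generation evaluation for TMF921 intent-to-config models/adapters.

Metrics: JSON parse rate, exact canonical JSON match, field-level F1, and simple metadata
constraints. Results are written per split and additionally stratified by target_layer,
slice_type and lifecycle_operation.

The evaluator is resumable and batched. It periodically writes partial predictions so a long OOD
run can be stopped and restarted without losing completed examples.

Note: each batch's generation budget is derived from the longest gold output in that batch
(plus --gold_length_buffer, capped by --max_new_tokens), so gold references influence how many
tokens the model is allowed to generate.
"""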
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from tmf921_train.utils import aggregate_metrics, field_f1, get_message, json_exact_match, metadata_constraint_pass, parse_json, write_json


def parse_args():
    """Parse the evaluator's command-line options.

    Boolean options that default to on (``--bf16``, ``--trust_remote_code``, ``--resume``)
    each have a ``--no_*`` counterpart; without it a ``store_true`` flag whose default is
    already ``True`` could never be disabled.

    Returns:
        argparse.Namespace with all options.

    Raises:
        SystemExit: via ``ArgumentParser.error`` when ``--batch_size`` is smaller than 1
            (0 would crash ``range()`` and a negative value would silently evaluate nothing).
    """
    p = argparse.ArgumentParser(description="Generation evaluation for TMF921 intent-to-config models/adapters.")
    p.add_argument("--model", required=True, help="Base model or merged model path/id")
    p.add_argument("--adapter", default=None, help="Optional PEFT adapter path/id")
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--splits", nargs="+", default=["test_in_distribution", "test_template_ood", "test_use_case_ood", "test_sector_ood", "test_adversarial"])
    p.add_argument("--output_dir", default="eval_outputs")
    p.add_argument("--max_samples_per_split", type=int, default=None)
    p.add_argument("--max_new_tokens", type=int, default=1536, help="Hard cap for generation. 1536 covers audited dataset outputs with margin.")
    p.add_argument("--gold_length_buffer", type=int, default=96, help="Dynamic cap = max gold output tokens in batch + buffer, clipped by max_new_tokens")
    p.add_argument("--batch_size", type=int, default=4, help="Batched generation size. Use 1 if OOM; 4 is usually safe on RTX 6000 Ada for Qwen3-8B QLoRA.")
    p.add_argument("--save_every", type=int, default=25, help="Write partial predictions every N completed examples")
    p.add_argument("--resume", action="store_true", default=True, help="Resume from existing predictions.json if present")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top_p", type=float, default=1.0)
    p.add_argument("--load_in_4bit", action="store_true", help="Load base model in 4-bit for adapter evaluation")
    p.add_argument("--bf16", action="store_true", default=True, help="Load weights in bfloat16 (default); use --no_bf16 for float16")
    p.add_argument("--no_bf16", dest="bf16", action="store_false")
    p.add_argument("--trust_remote_code", action="store_true", default=True, help="Allow custom model/tokenizer code from the hub (default); use --no_trust_remote_code to disable")
    p.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")
    args = p.parse_args()
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    return args


def make_prompt_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return the prompt part of a chat: everything except a trailing assistant turn.

    Only the *final* message is dropped, and only when its role is ``"assistant"``;
    earlier assistant turns in a multi-turn chat are kept. Each kept message is copied
    down to just its ``role`` and ``content`` (missing content becomes ``""``).
    """
    ends_with_answer = bool(messages) and messages[-1].get("role") == "assistant"
    prompt_part = messages[:-1] if ends_with_answer else messages
    prompt = [{"role": msg.get("role"), "content": msg.get("content", "")} for msg in prompt_part]
    if prompt:
        return prompt
    # Fallback kept from the original implementation: keep any non-assistant turns.
    return [msg for msg in messages if msg.get("role") != "assistant"]


def make_prompt_text(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Render the chat prompt (without the trailing assistant answer) as a string.

    The tokenizer's chat template is applied untokenized and with
    ``add_generation_prompt=True`` so the template's assistant-turn opener is appended.
    """
    prompt_messages = make_prompt_messages(messages)
    rendered = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered


def gold_text(example: Dict[str, Any]) -> str:
    """Return the reference answer for ``example``.

    A truthy ``completion`` field wins; otherwise the assistant message is pulled out of
    ``example["messages"]`` with ``get_message``.
    """
    completion = example.get("completion")
    if completion:
        return completion
    return get_message(example["messages"], "assistant")


def dynamic_max_new_tokens(tokenizer, examples: List[Dict[str, Any]], args) -> int:
    """Choose a per-batch ``max_new_tokens`` from the gold output lengths.

    Budget = token length of the longest gold output in ``examples`` plus
    ``args.gold_length_buffer``, floored at 16, and never above ``args.max_new_tokens``
    (documented as the hard cap, so it wins even when it is below the floor).

    Args:
        tokenizer: callable tokenizer returning a mapping with ``"input_ids"``.
        examples: dataset rows whose gold text is obtained via ``gold_text``.
        args: namespace providing ``max_new_tokens`` and ``gold_length_buffer``.

    Returns:
        The generation budget; for an empty ``examples`` list the hard cap itself
        (instead of failing on ``max()`` of an empty sequence).
    """
    hard_cap = int(args.max_new_tokens)
    if not examples:
        return hard_cap
    longest = max(
        len(tokenizer(gold_text(ex), add_special_tokens=False)["input_ids"]) for ex in examples
    )
    # Floor at 16 so very short gold outputs still leave some room, but apply the hard
    # cap last so a user-supplied limit below 16 is respected.
    return min(hard_cap, max(16, longest + int(args.gold_length_buffer)))


def generate_batch(model, tokenizer, examples: List[Dict[str, Any]], args) -> List[str]:
    """Generate one completion per example using left-padded batched generation.

    Args:
        model: causal LM (optionally PEFT-wrapped) exposing ``generate`` and ``device``.
        tokenizer: tokenizer with a chat template. Its ``padding_side`` is switched to
            ``"left"`` only while encoding and always restored afterwards.
        examples: dataset rows with a ``messages`` list.
        args: namespace providing ``temperature``, ``top_p``, ``max_new_tokens`` and
            ``gold_length_buffer``.

    Returns:
        Decoded newly generated text for each example (prompt tokens sliced off,
        special tokens skipped), in input order.
    """
    texts = [make_prompt_text(tokenizer, ex["messages"]) for ex in examples]
    old_padding_side = tokenizer.padding_side
    # Decoder-only generation needs left padding so every prompt ends at the same position.
    tokenizer.padding_side = "left"
    try:
        # The chat template already emitted whatever special tokens (e.g. BOS) the model
        # expects; letting the tokenizer add them again would duplicate them.
        inputs = tokenizer(texts, return_tensors="pt", padding=True, add_special_tokens=False).to(model.device)
    finally:
        tokenizer.padding_side = old_padding_side
    max_new = dynamic_max_new_tokens(tokenizer, examples, args)
    # Use an explicit None check: `pad_token_id or eos` would discard a valid pad id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    gen_kwargs = dict(
        max_new_tokens=max_new,
        do_sample=args.temperature > 0,
        temperature=args.temperature if args.temperature > 0 else None,
        top_p=args.top_p,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Drop unset options so the model's generation config supplies its own values.
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    # With left padding all prompts share the same length, so one slice removes them.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)


def row_metrics(example: Dict[str, Any], prediction: str) -> Dict[str, Any]:
    """Score a single prediction against the example's gold completion.

    The returned record carries the stratification fields (``id``, ``target_layer``,
    ``slice_type``, ``lifecycle_operation``), parse flags, the raw prediction/gold text and
    the prediction's parse error. When both sides parse as JSON, it also carries the
    exact-match result, the ``field_f1`` counts and the ``metadata_constraint_pass``
    results. Otherwise those metric keys are filled with zeros/False.
    """
    gold = gold_text(example)
    pred_obj, pred_err = parse_json(prediction)
    gold_obj, _ = parse_json(gold)
    pred_ok = pred_obj is not None
    gold_ok = gold_obj is not None
    scorable = pred_ok and gold_ok

    record: Dict[str, Any] = {
        field: example.get(field)
        for field in ("id", "target_layer", "slice_type", "lifecycle_operation")
    }
    record["parse_json"] = pred_ok
    record["gold_parse_json"] = gold_ok
    record["exact_match"] = json_exact_match(pred_obj, gold_obj) if scorable else False
    record["prediction"] = prediction
    record["gold"] = gold
    record["parse_error"] = pred_err

    if not scorable:
        # Unparseable on either side: count as a complete miss for every metric.
        record.update({
            "field_precision": 0.0,
            "field_recall": 0.0,
            "field_f1": 0.0,
            "field_tp": 0,
            "field_fp": 0,
            "field_fn": 0,
            "slice_sst_pass": False,
            "kpi_text_presence_pass": False,
            "adversarial_status_pass": False,
        })
        return record

    record.update(field_f1(pred_obj, gold_obj))
    record.update(metadata_constraint_pass(example, prediction, pred_obj))
    return record


def require_cuda():
    """Print CUDA diagnostics and abort unless PyTorch can see a GPU.

    Raises:
        RuntimeError: when ``torch.cuda.is_available()`` is False, so generation never
            silently falls back to CPU.
    """
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print("=== CUDA CHECK ===")
    print(f"torch={torch.__version__} torch.version.cuda={torch.version.cuda} CUDA_VISIBLE_DEVICES={visible_devices}")
    if torch.cuda.is_available():
        print(f"cuda device_count={torch.cuda.device_count()} gpu0={torch.cuda.get_device_name(0)}")
        return
    message = (
        "CUDA is not available to PyTorch. Refusing to run generation on CPU. "
        "Run `bash scripts/install_rtx6000ada.sh`, verify `nvidia-smi`, and set CUDA_VISIBLE_DEVICES=0."
    )
    raise RuntimeError(message)


def load_existing_predictions(path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Load previously saved prediction rows so an interrupted split can resume.

    Args:
        path: location of a ``predictions.json`` file.

    Returns:
        ``(rows, done_ids)``: the saved list of row dicts and the set of each row's
        ``id`` converted with ``str`` (matching how examples are compared on resume).
        A missing file yields ``([], set())``.

    Raises:
        ValueError: if the file is not valid JSON (e.g. truncated by an interrupted
            write) or is not a JSON list of objects. The message names the file and
            points at ``--no_resume``.
    """
    if not path.exists():
        return [], set()
    try:
        rows = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        raise ValueError(
            f"Cannot resume from {path}: file is not valid JSON ({err}). "
            "Delete it or rerun with --no_resume."
        ) from err
    if not isinstance(rows, list) or not all(isinstance(r, dict) for r in rows):
        raise ValueError(f"Cannot resume from {path}: expected a JSON list of prediction objects.")
    done = {str(r.get("id")) for r in rows}
    return rows, done


def write_split_outputs(split_dir: Path, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Persist a split's predictions and metrics, returning the metrics summary.

    Writes ``predictions.json`` (the rows as-is) and ``metrics.json`` into ``split_dir``.
    The summary is ``aggregate_metrics(rows)`` extended with ``by_target_layer``,
    ``by_slice_type`` and ``by_lifecycle_operation`` entries. Each entry maps the
    stringified group value (sorted) to ``aggregate_metrics`` over that group's rows.
    """
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)
    for strat_key in ("target_layer", "slice_type", "lifecycle_operation"):
        buckets: Dict[str, List[Dict[str, Any]]] = {}
        for row in rows:
            buckets.setdefault(str(row.get(strat_key)), []).append(row)
        summary["by_" + strat_key] = {
            bucket_name: aggregate_metrics(members)
            for bucket_name, members in sorted(buckets.items())
        }
    write_json(split_dir / "metrics.json", summary)
    return summary


def _load_tokenizer_and_model(args):
    """Load the tokenizer and the (optionally 4-bit, optionally PEFT-wrapped) model in eval mode.

    The tokenizer is read from the adapter path when one is given, otherwise from the
    base model. A missing pad token falls back to the EOS token so batches can be padded.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.adapter or args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    quant_config = None
    if args.load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        dtype=torch.bfloat16 if args.bf16 else torch.float16,
        device_map="auto",
        quantization_config=quant_config,
    )
    if args.adapter:
        model = PeftModel.from_pretrained(model, args.adapter)
    model.eval()
    return tokenizer, model


def _generate_with_oom_fallback(model, tokenizer, batch: List[Dict[str, Any]], args) -> List[str]:
    """Generate a batch; on CUDA OOM, retry the examples one at a time.

    Re-raises the OOM when there is nothing smaller to fall back to (batch of one, or
    ``--batch_size 1``).
    """
    try:
        return generate_batch(model, tokenizer, batch, args)
    except torch.cuda.OutOfMemoryError:
        if args.batch_size == 1 or len(batch) == 1:
            torch.cuda.empty_cache()
            raise
    # Recover outside the except clause: the live exception keeps the failed frames (and
    # their tensors) referenced, which would stop empty_cache() from releasing them.
    torch.cuda.empty_cache()
    preds: List[str] = []
    for ex in batch:
        preds.extend(generate_batch(model, tokenizer, [ex], args))
    return preds


def _evaluate_split(model, tokenizer, ds, split: str, out_dir: Path, args) -> Dict[str, Any]:
    """Evaluate one split, resuming from and periodically saving to out_dir/<split>/.

    Returns the split's metrics summary as produced by ``write_split_outputs``.
    """
    split_ds = ds[split]
    if args.max_samples_per_split:
        split_ds = split_ds.select(range(min(args.max_samples_per_split, len(split_ds))))
    split_dir = out_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)
    pred_path = split_dir / "predictions.json"
    rows, done_ids = load_existing_predictions(pred_path) if args.resume else ([], set())
    examples = list(split_ds)
    stale_ids = done_ids - {str(ex.get("id")) for ex in examples}
    if stale_ids:
        # E.g. a previous run used a larger --max_samples_per_split; those rows are kept
        # (nothing on disk is discarded) but they do count toward the metrics below.
        print(
            f"WARNING: {split}: {len(stale_ids)} resumed prediction(s) have ids outside the current "
            "split selection; they are kept in predictions.json and counted in metrics."
        )
    todo = [ex for ex in examples if str(ex.get("id")) not in done_ids]
    print(f"Evaluating {split}: total={len(split_ds)} already_done={len(done_ids)} remaining={len(todo)} batch_size={args.batch_size}")
    completed_since_save = 0
    with tqdm(total=len(todo), desc=split) as pbar:
        try:
            for start in range(0, len(todo), args.batch_size):
                batch = todo[start:start + args.batch_size]
                preds = _generate_with_oom_fallback(model, tokenizer, batch, args)
                rows.extend(row_metrics(ex, pred.strip()) for ex, pred in zip(batch, preds))
                pbar.update(len(batch))
                completed_since_save += len(batch)
                if completed_since_save >= args.save_every:
                    write_split_outputs(split_dir, rows)
                    completed_since_save = 0
        except BaseException:
            # Persist generations finished since the last periodic save so a Ctrl-C or a
            # crash does not throw them away, then propagate the original error.
            if completed_since_save:
                write_split_outputs(split_dir, rows)
            raise
    return write_split_outputs(split_dir, rows)


def main():
    """Entry point: evaluate every requested split and write per-split and combined metrics."""
    args = parse_args()
    require_cuda()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer, model = _load_tokenizer_and_model(args)

    ds = load_dataset(args.dataset)
    all_summary = {}
    for split in args.splits:
        all_summary[split] = _evaluate_split(model, tokenizer, ds, split, out_dir, args)
        # Rewritten after each split so completed splits survive a later failure.
        write_json(out_dir / "all_metrics.json", all_summary)
    print(json.dumps(all_summary, indent=2)[:4000])


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported as a module.
    main()