nraptisss committed
Commit 1dcff4f · verified · 1 Parent(s): 77ca79f

Upload benchmark.py

Files changed (1)
  1. benchmark.py +296 -0
benchmark.py ADDED
@@ -0,0 +1,296 @@
"""
Benchmark evaluation script for telecom intent-to-config models.
Evaluates on a test dataset and computes metrics:
- JSON validity rate
- Schema compliance (key presence)
- Semantic fidelity (embedding similarity)
- Per-target-layer breakdown

Usage on Kaggle:
    python benchmark.py \
        --adapter_path ./qwen2.5-7b-telecom-intent-lora \
        --dataset nraptisss/TMF921-intent-to-config-augmented \
        --split test \
        --max_samples 100 \
        --output benchmark_results.json
"""

import argparse
import json
import re

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from sentence_transformers import SentenceTransformer
import numpy as np

# ============================================================================
# CONFIGURATION
# ============================================================================

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.1
TOP_P = 0.95
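
# Note on the decoding settings above: generate_config() samples with
# do_sample=True, so TEMPERATURE=0.1 keeps outputs near-greedy while still
# sampling. A hedged alternative, if fully deterministic JSON is preferred,
# would be greedy decoding instead:
#     outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS,
#                              do_sample=False)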

def load_model(adapter_path: str, base_model: str):
    """Load base model + LoRA adapters."""
    print(f"Loading base model: {base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    print(f"Loading LoRA adapters: {adapter_path}")
    model = PeftModel.from_pretrained(model, adapter_path)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=True,
        padding_side="left",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
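
# Optional speed-up (an assumption, not part of the original script): since
# the adapters are never detached during evaluation, they could be merged
# into the base weights right after loading, e.g.:
#     model = PeftModel.from_pretrained(model, adapter_path).merge_and_unload()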

def generate_config(model, tokenizer, messages: list) -> str:
    """Generate a config from a messages list."""
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is unreliable because skip_special_tokens changes lengths.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Extract JSON from markdown code blocks
    json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", response, re.DOTALL)
    if json_match:
        response = json_match.group(1)

    return response.strip()
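
# Sketch of the fence extraction above on a typical fenced model response
# (illustrative input, not taken from the dataset):
#     >>> re.search(r"```(?:json)?\s*(.*?)\s*```",
#     ...           '```json\n{"a": 1}\n```', re.DOTALL).group(1)
#     '{"a": 1}'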

def validate_json(text: str) -> tuple[bool, dict | None]:
    """Try to parse the response as a JSON object."""
    try:
        text = text.strip()
        # Heuristic: slice from the first "{" to the last "}" so that any
        # surrounding prose is dropped before parsing.
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            text = text[start:end + 1]
        parsed = json.loads(text)
        # Guard against non-object JSON (e.g. a bare list), which would not
        # match the declared return type or the schema checks downstream.
        if not isinstance(parsed, dict):
            return False, None
        return True, parsed
    except json.JSONDecodeError:
        return False, None
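
# Sketch of the brace-slicing heuristic above (illustrative input):
#     >>> validate_json('Here is the config: {"intent": {"id": 1}} done')
#     (True, {'intent': {'id': 1}})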

def check_schema_compliance(parsed: dict, target_layer: str) -> dict:
    """Check required top-level keys based on the target layer."""
    # Expected top-level keys per target layer
    schema_map = {
        "tmf921": ["intent", "intentId", "name"],
        "camara": ["networkSliceBooking", "sliceType"],
        "intent_3gpp": ["ManagedElement", "intent"],
        "etsi_zsm": ["intent", "serviceProfile"],
        "a1_policy": ["policy", "policyType"],
        "o1_nrm": ["ManagedElement", "GNBDUFunction"],
    }

    expected = schema_map.get(target_layer.lower(), [])
    present = [k for k in expected if k in parsed]
    missing = [k for k in expected if k not in parsed]

    return {
        "compliance_score": len(present) / max(len(expected), 1),
        "present_keys": present,
        "missing_keys": missing,
    }
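
# Sketch of the compliance score (illustrative payload): with two of the
# three expected TMF921 keys present, compliance_score is 2/3:
#     >>> check_schema_compliance({"intent": {}, "intentId": "i-1"}, "tmf921")["missing_keys"]
#     ['name']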

def main():
    parser = argparse.ArgumentParser(description="Telecom Intent Benchmark")
    parser.add_argument(
        "--adapter_path",
        type=str,
        default="./qwen2.5-7b-telecom-intent-lora",
        help="Path to LoRA adapters",
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default=BASE_MODEL,
        help="Base model name",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nraptisss/TMF921-intent-to-config-augmented",
        help="Dataset to evaluate on",
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        default="default",
        help="Dataset config name",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to evaluate",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=100,
        help="Max number of samples to evaluate",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="benchmark_results.json",
        help="Output file for results",
    )
    args = parser.parse_args()

    # Load model
    model, tokenizer = load_model(args.adapter_path, args.base_model)

    # Load dataset
    print(f"\nLoading dataset: {args.dataset} ({args.split})")
    ds = load_dataset(args.dataset, args.dataset_config, split=args.split)
    if args.max_samples:
        ds = ds.select(range(min(args.max_samples, len(ds))))
    print(f"Evaluating on {len(ds)} samples")

    # Load embedding model for semantic similarity (optional; without it only
    # validity and compliance metrics are reported)
    try:
        embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        use_embedding = True
        print("Loaded embedding model for semantic similarity")
    except Exception as e:
        print(f"Embedding model not available ({e}), skipping semantic similarity")
        use_embedding = False

    # Run evaluation
    results = []
    valid_count = 0
    compliance_scores = []
    layer_stats = {}

    for i, sample in enumerate(ds):
        messages = sample["messages"]
        target_layer = sample.get("target_layer", "unknown")

        # Extract the reference output (assistant turn)
        reference = ""
        for m in messages:
            if m.get("role") == "assistant":
                reference = m.get("content", "")
                break

        # Keep only the non-assistant turns as the generation prompt
        gen_messages = [m for m in messages if m.get("role") != "assistant"]

        # Generate
        generated = generate_config(model, tokenizer, gen_messages)
        is_valid, parsed = validate_json(generated)

        if is_valid:
            valid_count += 1
            schema = check_schema_compliance(parsed, target_layer)
            compliance_scores.append(schema["compliance_score"])
        else:
            schema = {"compliance_score": 0.0, "present_keys": [], "missing_keys": []}

        # Semantic similarity (if embeddings are available)
        semantic_sim = None
        if use_embedding and is_valid:
            ref_emb = embed_model.encode(reference, convert_to_tensor=True)
            gen_emb = embed_model.encode(generated, convert_to_tensor=True)
            semantic_sim = float(torch.cosine_similarity(ref_emb, gen_emb, dim=0))

        result = {
            "id": sample.get("id", i),
            "target_layer": target_layer,
            "slice_type": sample.get("slice_type", "unknown"),
            "intent": next((m["content"] for m in messages if m.get("role") == "user"), ""),
            "generated": generated,
            "reference": reference,
            "json_valid": is_valid,
            "schema_compliance": schema,
            "semantic_similarity": semantic_sim,
        }
        results.append(result)

        # Per-layer stats
        if target_layer not in layer_stats:
            layer_stats[target_layer] = {"total": 0, "valid": 0, "compliance": []}
        layer_stats[target_layer]["total"] += 1
        if is_valid:
            layer_stats[target_layer]["valid"] += 1
            layer_stats[target_layer]["compliance"].append(schema["compliance_score"])

        if (i + 1) % 10 == 0:
            print(f"  Processed {i + 1}/{len(ds)} samples")

    # Compute summary statistics
    total = len(results)
    sims = [r["semantic_similarity"] for r in results if r["semantic_similarity"] is not None]
    summary = {
        "total_samples": total,
        "json_valid_rate": valid_count / max(total, 1),
        "avg_schema_compliance": float(np.mean(compliance_scores)) if compliance_scores else 0.0,
        "semantic_similarity_avg": float(np.mean(sims)) if sims else None,
        "per_layer": {},
    }

    for layer, stats in layer_stats.items():
        summary["per_layer"][layer] = {
            "total": stats["total"],
            "valid_rate": stats["valid"] / stats["total"],
            "avg_compliance": float(np.mean(stats["compliance"])) if stats["compliance"] else 0.0,
        }

    # Save results
    output_data = {"summary": summary, "results": results}
    with open(args.output, "w") as f:
        json.dump(output_data, f, indent=2)

    # Print summary
    print(f"\n{'=' * 60}")
    print("BENCHMARK RESULTS")
    print(f"{'=' * 60}")
    print(f"Total samples: {summary['total_samples']}")
    print(f"JSON valid rate: {summary['json_valid_rate']:.1%}")
    print(f"Schema compliance: {summary['avg_schema_compliance']:.1%}")
    if summary["semantic_similarity_avg"] is not None:
        print(f"Semantic similarity: {summary['semantic_similarity_avg']:.3f}")
    print("\nPer-layer breakdown:")
    for layer, s in summary["per_layer"].items():
        print(f"  {layer:20s} valid={s['valid_rate']:.1%} compliance={s['avg_compliance']:.1%}")
    print(f"\nDetailed results saved to: {args.output}")


if __name__ == "__main__":
    main()
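
# Rough shape of the benchmark_results.json written above (field names come
# from the code; the numbers are purely illustrative):
# {
#   "summary": {
#     "total_samples": 100,
#     "json_valid_rate": 0.95,
#     "avg_schema_compliance": 0.90,
#     "semantic_similarity_avg": 0.87,
#     "per_layer": {"tmf921": {"total": 20, "valid_rate": 1.0, "avg_compliance": 0.95}}
#   },
#   "results": [{"id": 0, "target_layer": "tmf921", "json_valid": true, "...": "..."}]
# }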