ataeff commited on
Commit
f8b98f7
Β·
verified Β·
1 Parent(s): e7df4fb

Entropy threshold calibration

Browse files
Files changed (1) hide show
  1. calibrate_entropy.py +578 -0
calibrate_entropy.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ calibrate_entropy.py β€” Calibrate entropy thresholds for Adaptive Resonance
4
+
5
+ Runs the model on diverse prompts WITHOUT resonance, recording entropy
6
+ at every generation step. Then computes optimal H_high and H_low thresholds.
7
+
8
+ The calibration is PER-MODEL. Different LoRA adapters will have different
9
+ entropy profiles. ALWAYS recalibrate after training a new adapter.
10
+
11
+ Usage:
12
+ # Calibrate with LoRA adapter
13
+ python calibrate_entropy.py --adapter-path ./gemma3-resonate/best
14
+
15
+ # Calibrate base model (no adapter)
16
+ python calibrate_entropy.py --no-lora
17
+
18
+ # Custom prompts file
19
+ python calibrate_entropy.py --adapter-path ./gemma3-resonate/best \
20
+ --prompts calibration_prompts.txt
21
+
22
+ # Save calibration result
23
+ python calibrate_entropy.py --adapter-path ./gemma3-resonate/best \
24
+ --save calibration.json
25
+
26
+ Author: Wulf (Opus + Oleg)
27
+ Date: 2026-03-28
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import os
33
+ import sys
34
+ import json
35
+ import math
36
+ import time
37
+ import argparse
38
+ import logging
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+
44
+ from transformers import AutoModelForCausalLM, AutoTokenizer
45
+
46
+ # ============================================================================
47
+ # Constants
48
+ # ============================================================================
49
+
50
+ MODEL_ID = "unsloth/gemma-3-270m-it"
51
+ VOCAB_SIZE = 262_144
52
+ H_MAX = math.log2(VOCAB_SIZE) # 18.0 bits
53
+
54
+ START_OF_TURN = "<start_of_turn>"
55
+ END_OF_TURN = "<end_of_turn>"
56
+
57
+ # ============================================================================
58
+ # Logging
59
+ # ============================================================================
60
+
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ format="%(asctime)s [%(levelname)s] %(message)s",
64
+ datefmt="%H:%M:%S",
65
+ )
66
+ log = logging.getLogger("calibrate")
67
+
68
+ # ============================================================================
69
+ # Calibration Prompts β€” diverse, multilingual, varying difficulty
70
+ # ============================================================================
71
+
72
+ DEFAULT_PROMPTS = [
73
+ # Easy factual (should NOT trigger resonance)
74
+ "What is 2 + 2?",
75
+ "What color is the sky?",
76
+ "Who wrote Romeo and Juliet?",
77
+ "What is the capital of France?",
78
+ "How many days are in a week?",
79
+
80
+ # Medium difficulty (may or may not trigger)
81
+ "Explain what a neural network is in simple terms.",
82
+ "What causes inflation?",
83
+ "Why do birds migrate?",
84
+ "How does encryption work?",
85
+ "What is the difference between RNA and DNA?",
86
+
87
+ # Hard reasoning (SHOULD trigger resonance)
88
+ "Why do small language models sometimes outperform larger ones?",
89
+ "Is consciousness computable?",
90
+ "What is the relationship between compression and intelligence?",
91
+ "Can a system understand something it was never explicitly taught?",
92
+ "Why does emergence happen at specific scale thresholds?",
93
+
94
+ # Philosophy (SHOULD trigger)
95
+ "Is free will an illusion?",
96
+ "What is the meaning of life?",
97
+ "If all your memories were replaced, would you still be you?",
98
+ "Does objective morality exist?",
99
+ "What is the nature of time?",
100
+
101
+ # Code (mixed β€” simple bugs shouldn't, architecture should)
102
+ "What does `print(1 + 1)` output in Python?",
103
+ "Why would a recursive function without a base case crash?",
104
+ "How would you design a distributed consensus algorithm?",
105
+ "Explain why attention mechanisms are O(n^2).",
106
+
107
+ # Russian (SHOULD trigger on hard ones)
108
+ "Бколько Π±ΡƒΠ΄Π΅Ρ‚ Π΄Π²Π° плюс Π΄Π²Π°?",
109
+ "ΠŸΠΎΡ‡Π΅ΠΌΡƒ Π½Π΅Π±ΠΎ Π³ΠΎΠ»ΡƒΠ±ΠΎΠ΅?",
110
+ "Π§Ρ‚ΠΎ Ρ‚Π°ΠΊΠΎΠ΅ ΡΠΌΠ΅Ρ€Π΄ΠΆΠ΅Π½Ρ‚Π½ΠΎΡΡ‚ΡŒ Π² Π½Π΅ΠΉΡ€ΠΎΠ½Π½Ρ‹Ρ… сСтях?",
111
+ "Π‘Π²ΠΎΠ±ΠΎΠ΄Π° Π²ΠΎΠ»ΠΈ β€” это иллюзия?",
112
+ "ΠŸΠΎΡ‡Π΅ΠΌΡƒ малСнькиС языковыС ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈΠ½ΠΎΠ³Π΄Π° Π»ΡƒΡ‡ΡˆΠ΅ Π±ΠΎΠ»ΡŒΡˆΠΈΡ…?",
113
+
114
+ # French
115
+ "Quelle est la capitale de la France?",
116
+ "Pourquoi les petits modeles de langage sont-ils importants?",
117
+ "Quel est le sens de la vie?",
118
+
119
+ # German
120
+ "Was ist der Sinn des Lebens?",
121
+ "Was bedeutet Emergenz im Kontext neuronaler Netzwerke?",
122
+
123
+ # Ambiguous / creative (high entropy expected)
124
+ "Write a haiku about debugging.",
125
+ "If neural networks could dream, what would they dream about?",
126
+ "Tell me something nobody has ever said before.",
127
+ "What would happen if entropy decreased instead of increased?",
128
+
129
+ # Meta (interesting entropy behavior expected)
130
+ "Explain your reasoning process.",
131
+ "How confident are you in your answers?",
132
+ "What don't you know?",
133
+
134
+ # Math
135
+ "What is the sum of the first 100 positive integers?",
136
+ "Prove that the square root of 2 is irrational.",
137
+ "What is the derivative of x^x?",
138
+
139
+ # Simple instructions (should NOT trigger)
140
+ "List three colors.",
141
+ "Say hello in five languages.",
142
+ "Count to ten.",
143
+ ]
144
+
145
+
146
+ # ============================================================================
147
+ # Entropy Collection
148
+ # ============================================================================
149
+
150
+ def collect_entropy_profile(
151
+ model,
152
+ tokenizer,
153
+ prompt: str,
154
+ max_tokens: int = 100,
155
+ temperature: float = 0.7,
156
+ device: str = 'cuda',
157
+ ) -> dict:
158
+ """Generate from a prompt and collect entropy at every step.
159
+
160
+ We generate normally (no resonance intervention) and just observe
161
+ the entropy curve. This gives us the model's natural entropy profile.
162
+
163
+ Returns dict with:
164
+ 'prompt': str
165
+ 'entropies': list of (H_bits, H_norm) tuples
166
+ 'tokens': list of generated token strings
167
+ 'mean_h': float
168
+ 'max_h': float
169
+ 'min_h': float
170
+ 'std_h': float
171
+ 'first_5_mean': float (mean of first 5 tokens β€” initial uncertainty)
172
+ """
173
+ model.eval()
174
+
175
+ input_text = f"{START_OF_TURN}user\n{prompt}{END_OF_TURN}\n{START_OF_TURN}model\n"
176
+ input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
177
+
178
+ all_ids = input_ids[0].tolist()
179
+ entropies = []
180
+ tokens = []
181
+
182
+ eos_id = tokenizer.eos_token_id
183
+ eot_text = END_OF_TURN
184
+
185
+ generated_text = ""
186
+
187
+ with torch.no_grad():
188
+ outputs = model(input_ids)
189
+ next_logits = outputs.logits[0, -1, :]
190
+
191
+ for step in range(max_tokens):
192
+ # Compute entropy from raw logits
193
+ probs = F.softmax(next_logits.float(), dim=-1).clamp(min=1e-10)
194
+ H = -(probs * probs.log2()).sum().item()
195
+ h_norm = H / H_MAX
196
+
197
+ entropies.append((H, h_norm))
198
+
199
+ # Sample token (normal generation, no resonance intervention)
200
+ logits = next_logits / temperature
201
+ probs_sampling = F.softmax(logits, dim=-1)
202
+ next_token = torch.multinomial(probs_sampling, num_samples=1).item()
203
+
204
+ if next_token == eos_id:
205
+ break
206
+
207
+ all_ids.append(next_token)
208
+ token_str = tokenizer.decode([next_token])
209
+ tokens.append(token_str)
210
+ generated_text += token_str
211
+
212
+ if generated_text.rstrip().endswith(eot_text):
213
+ break
214
+
215
+ # Next step
216
+ full_ids = torch.tensor([all_ids], device=device)
217
+ with torch.no_grad():
218
+ outputs = model(full_ids)
219
+ next_logits = outputs.logits[0, -1, :]
220
+
221
+ # Compute stats
222
+ if not entropies:
223
+ return {
224
+ 'prompt': prompt,
225
+ 'entropies': [],
226
+ 'tokens': [],
227
+ 'mean_h': 0, 'max_h': 0, 'min_h': 0, 'std_h': 0,
228
+ 'first_5_mean': 0,
229
+ }
230
+
231
+ h_norms = [h_norm for _, h_norm in entropies]
232
+ mean_h = sum(h_norms) / len(h_norms)
233
+ max_h = max(h_norms)
234
+ min_h = min(h_norms)
235
+ std_h = (sum((v - mean_h)**2 for v in h_norms) / len(h_norms)) ** 0.5
236
+ first_5 = h_norms[:5]
237
+ first_5_mean = sum(first_5) / len(first_5) if first_5 else 0
238
+
239
+ return {
240
+ 'prompt': prompt,
241
+ 'entropies': entropies,
242
+ 'tokens': tokens,
243
+ 'mean_h': mean_h,
244
+ 'max_h': max_h,
245
+ 'min_h': min_h,
246
+ 'std_h': std_h,
247
+ 'first_5_mean': first_5_mean,
248
+ 'generated': generated_text[:200],
249
+ }
250
+
251
+
252
+ # ============================================================================
253
+ # Threshold Computation
254
+ # ============================================================================
255
+
256
+ def compute_thresholds(profiles: list[dict], target_resonance_rate: float = 0.45) -> dict:
257
+ """Compute optimal H_high and H_low from collected entropy profiles.
258
+
259
+ Algorithm:
260
+ 1. Collect max-entropy and min-entropy per prompt
261
+ 2. H_high = percentile of max-entropies where ~target_resonance_rate
262
+ of prompts would trigger resonance
263
+ 3. H_low = mean of per-prompt min entropies + small margin
264
+
265
+ The target_resonance_rate controls how aggressive resonance is:
266
+ - 0.3 = conservative (resonance on ~30% of prompts, only hard ones)
267
+ - 0.5 = balanced (resonance on ~50% of prompts)
268
+ - 0.7 = aggressive (resonance on ~70% of prompts, even medium questions)
269
+
270
+ Returns dict with calibration results.
271
+ """
272
+ if not profiles:
273
+ return {'h_high': 0.35, 'h_low': 0.12, 'error': 'no profiles'}
274
+
275
+ # Collect per-prompt statistics
276
+ max_entropies = [p['max_h'] for p in profiles if p['entropies']]
277
+ min_entropies = [p['min_h'] for p in profiles if p['entropies']]
278
+ mean_entropies = [p['mean_h'] for p in profiles if p['entropies']]
279
+ std_entropies = [p['std_h'] for p in profiles if p['entropies']]
280
+ first_5_means = [p['first_5_mean'] for p in profiles if p['entropies']]
281
+
282
+ if not max_entropies:
283
+ return {'h_high': 0.35, 'h_low': 0.12, 'error': 'no valid profiles'}
284
+
285
+ # Sort for percentile computation
286
+ max_entropies_sorted = sorted(max_entropies)
287
+ min_entropies_sorted = sorted(min_entropies)
288
+
289
+ # H_high: we want resonance to trigger on (target_resonance_rate)% of prompts
290
+ # That means H_high should be at the (1 - target_resonance_rate) percentile
291
+ # of per-prompt max entropies
292
+ h_high_idx = int(len(max_entropies_sorted) * (1 - target_resonance_rate))
293
+ h_high_idx = max(0, min(len(max_entropies_sorted) - 1, h_high_idx))
294
+ h_high = max_entropies_sorted[h_high_idx]
295
+
296
+ # H_low: mean of per-prompt minimums + 0.5*std for safety margin
297
+ mean_of_mins = sum(min_entropies) / len(min_entropies)
298
+ std_of_mins = (sum((v - mean_of_mins)**2 for v in min_entropies) / len(min_entropies)) ** 0.5
299
+ h_low = mean_of_mins + 0.5 * std_of_mins
300
+
301
+ # Sanity checks
302
+ if h_low >= h_high:
303
+ log.warning(f"h_low ({h_low:.4f}) >= h_high ({h_high:.4f}). Adjusting.")
304
+ # Force minimum gap
305
+ midpoint = (h_low + h_high) / 2
306
+ h_high = midpoint + 0.05
307
+ h_low = midpoint - 0.05
308
+
309
+ if h_high < 0.10:
310
+ log.warning(f"h_high ({h_high:.4f}) is suspiciously low. Setting to 0.20.")
311
+ h_high = 0.20
312
+
313
+ if h_low < 0.02:
314
+ h_low = 0.02
315
+
316
+ # Compute what the actual resonance rate would be
317
+ would_trigger = sum(1 for m in max_entropies if m > h_high)
318
+ actual_rate = would_trigger / len(max_entropies)
319
+
320
+ # Compute global statistics
321
+ all_h = []
322
+ for p in profiles:
323
+ all_h.extend([h_norm for _, h_norm in p['entropies']])
324
+
325
+ global_mean = sum(all_h) / len(all_h) if all_h else 0
326
+ global_std = (sum((v - global_mean)**2 for v in all_h) / len(all_h)) ** 0.5 if all_h else 0
327
+ global_max = max(all_h) if all_h else 0
328
+ global_min = min(all_h) if all_h else 0
329
+
330
+ result = {
331
+ 'h_high': round(h_high, 4),
332
+ 'h_low': round(h_low, 4),
333
+ 'target_resonance_rate': target_resonance_rate,
334
+ 'actual_resonance_rate': round(actual_rate, 3),
335
+ 'num_prompts': len(profiles),
336
+ 'num_valid': len(max_entropies),
337
+ 'global_entropy_stats': {
338
+ 'mean': round(global_mean, 4),
339
+ 'std': round(global_std, 4),
340
+ 'max': round(global_max, 4),
341
+ 'min': round(global_min, 4),
342
+ },
343
+ 'per_prompt_max_entropy': {
344
+ 'mean': round(sum(max_entropies) / len(max_entropies), 4),
345
+ 'std': round((sum((v - sum(max_entropies)/len(max_entropies))**2 for v in max_entropies) / len(max_entropies)) ** 0.5, 4),
346
+ 'min': round(min(max_entropies), 4),
347
+ 'max': round(max(max_entropies), 4),
348
+ },
349
+ 'per_prompt_min_entropy': {
350
+ 'mean': round(mean_of_mins, 4),
351
+ 'std': round(std_of_mins, 4),
352
+ },
353
+ 'recommended_enter_count': 3,
354
+ 'recommended_exit_count': 5,
355
+ }
356
+
357
+ return result
358
+
359
+
360
+ # ============================================================================
361
+ # Report
362
+ # ============================================================================
363
+
364
+ def print_report(result: dict, profiles: list[dict]):
365
+ """Print a detailed calibration report."""
366
+
367
+ print(f"\n{'='*70}")
368
+ print(f" ENTROPY CALIBRATION REPORT")
369
+ print(f"{'='*70}")
370
+
371
+ print(f"\n Calibrated on {result['num_prompts']} prompts ({result['num_valid']} valid)")
372
+ print(f"\n RECOMMENDED THRESHOLDS:")
373
+ print(f" H_high = {result['h_high']:.4f} (enter resonance above this)")
374
+ print(f" H_low = {result['h_low']:.4f} (exit resonance below this)")
375
+ print(f"\n Expected resonance rate: {result['actual_resonance_rate']:.0%} of prompts")
376
+ print(f" Target was: {result['target_resonance_rate']:.0%}")
377
+
378
+ gs = result['global_entropy_stats']
379
+ print(f"\n Global entropy (H_norm):")
380
+ print(f" mean={gs['mean']:.4f} std={gs['std']:.4f} min={gs['min']:.4f} max={gs['max']:.4f}")
381
+
382
+ pm = result['per_prompt_max_entropy']
383
+ print(f"\n Per-prompt max entropy:")
384
+ print(f" mean={pm['mean']:.4f} std={pm['std']:.4f} range=[{pm['min']:.4f}, {pm['max']:.4f}]")
385
+
386
+ # Per-prompt breakdown
387
+ print(f"\n{'─'*70}")
388
+ print(f" PER-PROMPT ANALYSIS")
389
+ print(f"{'─'*70}")
390
+ print(f" {'Prompt':<50} {'MaxH':>7} {'MeanH':>7} {'Trigger':>8}")
391
+ print(f" {'─'*50} {'─'*7} {'─'*7} {'─'*8}")
392
+
393
+ for p in sorted(profiles, key=lambda x: -x['max_h']):
394
+ if not p['entropies']:
395
+ continue
396
+ prompt_short = p['prompt'][:48]
397
+ trigger = "YES" if p['max_h'] > result['h_high'] else "no"
398
+ trigger_mark = ">>>" if trigger == "YES" else " "
399
+ print(f" {trigger_mark}{prompt_short:<47} {p['max_h']:>7.4f} {p['mean_h']:>7.4f} {trigger:>8}")
400
+
401
+ # Histogram of max entropies
402
+ print(f"\n{'─'*70}")
403
+ print(f" MAX ENTROPY DISTRIBUTION")
404
+ print(f"{'─'*70}")
405
+
406
+ max_hs = sorted([p['max_h'] for p in profiles if p['entropies']])
407
+ if max_hs:
408
+ n_bins = 15
409
+ bin_min = 0.0
410
+ bin_max = max(max_hs) * 1.1
411
+ bin_width = (bin_max - bin_min) / n_bins
412
+
413
+ bins = [0] * n_bins
414
+ for v in max_hs:
415
+ idx = min(int((v - bin_min) / bin_width), n_bins - 1)
416
+ bins[idx] += 1
417
+
418
+ max_count = max(bins) if bins else 1
419
+ bar_width = 40
420
+
421
+ for i, count in enumerate(bins):
422
+ lo = bin_min + i * bin_width
423
+ hi = lo + bin_width
424
+ bar_len = int(count / max_count * bar_width) if max_count > 0 else 0
425
+ bar = '#' * bar_len
426
+
427
+ # Mark threshold
428
+ marker = ""
429
+ if lo <= result['h_high'] < hi:
430
+ marker = " <-- H_high"
431
+
432
+ print(f" {lo:.3f}-{hi:.3f} |{bar:<{bar_width}}| {count:>3}{marker}")
433
+
434
+ # Usage instructions
435
+ print(f"\n{'─'*70}")
436
+ print(f" USAGE")
437
+ print(f"{'─'*70}")
438
+ print(f" python entropy_resonance.py \\")
439
+ print(f" --adapter-path ./gemma3-resonate/best \\")
440
+ print(f" --h-high {result['h_high']:.4f} \\")
441
+ print(f" --h-low {result['h_low']:.4f}")
442
+ print(f"\n{'='*70}\n")
443
+
444
+
445
+ # ============================================================================
446
+ # Main
447
+ # ============================================================================
448
+
449
+ def main():
450
+ parser = argparse.ArgumentParser(
451
+ description="Calibrate entropy thresholds for Adaptive Resonance"
452
+ )
453
+
454
+ parser.add_argument("--model", default=MODEL_ID, help="Base model ID")
455
+ parser.add_argument("--adapter-path", default=None, help="LoRA adapter path")
456
+ parser.add_argument("--no-lora", action="store_true", help="Skip LoRA loading")
457
+ parser.add_argument("--device", default=None, help="Device: cuda/cpu/mps")
458
+
459
+ parser.add_argument("--prompts", default=None,
460
+ help="Text file with prompts, one per line")
461
+ parser.add_argument("--max-tokens", type=int, default=100,
462
+ help="Max tokens per generation during calibration")
463
+ parser.add_argument("--target-rate", type=float, default=0.45,
464
+ help="Target resonance trigger rate (0-1)")
465
+ parser.add_argument("--temperature", type=float, default=0.7,
466
+ help="Sampling temperature during calibration")
467
+
468
+ parser.add_argument("--save", default=None,
469
+ help="Save calibration result to JSON file")
470
+
471
+ args = parser.parse_args()
472
+
473
+ # Device
474
+ if args.device is None:
475
+ if torch.cuda.is_available():
476
+ device = 'cuda'
477
+ elif torch.backends.mps.is_available():
478
+ device = 'mps'
479
+ else:
480
+ device = 'cpu'
481
+ else:
482
+ device = args.device
483
+
484
+ # Load model
485
+ log.info(f"Loading tokenizer from {args.model}...")
486
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
487
+
488
+ dtype = torch.bfloat16 if device == 'cuda' else torch.float32
489
+ log.info(f"Loading model from {args.model} onto {device}...")
490
+
491
+ model = AutoModelForCausalLM.from_pretrained(
492
+ args.model,
493
+ torch_dtype=dtype,
494
+ device_map=device if device == 'cuda' else None,
495
+ attn_implementation="sdpa" if device == 'cuda' else "eager",
496
+ trust_remote_code=True,
497
+ )
498
+
499
+ if device != 'cuda':
500
+ model = model.to(device)
501
+
502
+ if args.adapter_path and not args.no_lora:
503
+ from peft import PeftModel
504
+ log.info(f"Loading adapter from {args.adapter_path}...")
505
+ model = PeftModel.from_pretrained(model, args.adapter_path)
506
+
507
+ model.eval()
508
+
509
+ # Load prompts
510
+ if args.prompts:
511
+ with open(args.prompts, 'r', encoding='utf-8') as f:
512
+ prompts = [line.strip() for line in f if line.strip()]
513
+ log.info(f"Loaded {len(prompts)} prompts from {args.prompts}")
514
+ else:
515
+ prompts = DEFAULT_PROMPTS
516
+ log.info(f"Using {len(prompts)} default calibration prompts")
517
+
518
+ # Collect entropy profiles
519
+ log.info(f"Collecting entropy profiles ({args.max_tokens} tokens/prompt)...")
520
+ profiles = []
521
+ t0 = time.time()
522
+
523
+ for i, prompt in enumerate(prompts):
524
+ log.info(f" [{i+1}/{len(prompts)}] {prompt[:60]}...")
525
+ profile = collect_entropy_profile(
526
+ model, tokenizer, prompt,
527
+ max_tokens=args.max_tokens,
528
+ temperature=args.temperature,
529
+ device=device,
530
+ )
531
+ profiles.append(profile)
532
+
533
+ if profile['entropies']:
534
+ log.info(f" H_norm: mean={profile['mean_h']:.4f} max={profile['max_h']:.4f} "
535
+ f"min={profile['min_h']:.4f} ({len(profile['entropies'])} tokens)")
536
+
537
+ elapsed = time.time() - t0
538
+ log.info(f"Collection complete in {elapsed:.1f}s")
539
+
540
+ # Compute thresholds
541
+ result = compute_thresholds(profiles, target_resonance_rate=args.target_rate)
542
+
543
+ # Print report
544
+ print_report(result, profiles)
545
+
546
+ # Save if requested
547
+ if args.save:
548
+ # Don't save the full entropy traces (too large) β€” just the result
549
+ save_data = {
550
+ 'calibration': result,
551
+ 'per_prompt_summary': [
552
+ {
553
+ 'prompt': p['prompt'],
554
+ 'mean_h': round(p['mean_h'], 4),
555
+ 'max_h': round(p['max_h'], 4),
556
+ 'min_h': round(p['min_h'], 4),
557
+ 'std_h': round(p['std_h'], 4),
558
+ 'first_5_mean': round(p['first_5_mean'], 4),
559
+ 'n_tokens': len(p['entropies']),
560
+ 'would_trigger': p['max_h'] > result['h_high'],
561
+ }
562
+ for p in profiles if p['entropies']
563
+ ],
564
+ 'model': args.model,
565
+ 'adapter': args.adapter_path,
566
+ 'target_rate': args.target_rate,
567
+ 'max_tokens': args.max_tokens,
568
+ 'temperature': args.temperature,
569
+ }
570
+ with open(args.save, 'w', encoding='utf-8') as f:
571
+ json.dump(save_data, f, indent=2, ensure_ascii=False)
572
+ log.info(f"Calibration saved to {args.save}")
573
+
574
+ log.info("Done. Use the recommended thresholds with entropy_resonance.py.")
575
+
576
+
577
+ if __name__ == "__main__":
578
+ main()