# calibrate_entropy.py — entropy threshold calibration
# (Hugging Face repo page header removed: author ataeff, commit f8b98f7 verified)
#!/usr/bin/env python3
"""
calibrate_entropy.py — Calibrate entropy thresholds for Adaptive Resonance
Runs the model on diverse prompts WITHOUT resonance, recording entropy
at every generation step. Then computes optimal H_high and H_low thresholds.
The calibration is PER-MODEL. Different LoRA adapters will have different
entropy profiles. ALWAYS recalibrate after training a new adapter.
Usage:
# Calibrate with LoRA adapter
python calibrate_entropy.py --adapter-path ./gemma3-resonate/best
# Calibrate base model (no adapter)
python calibrate_entropy.py --no-lora
# Custom prompts file
python calibrate_entropy.py --adapter-path ./gemma3-resonate/best \
--prompts calibration_prompts.txt
# Save calibration result
python calibrate_entropy.py --adapter-path ./gemma3-resonate/best \
--save calibration.json
Author: Wulf (Opus + Oleg)
Date: 2026-03-28
"""
from __future__ import annotations
import os
import sys
import json
import math
import time
import argparse
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================================
# Constants
# ============================================================================
MODEL_ID = "unsloth/gemma-3-270m-it"  # default base checkpoint; overridable via --model
VOCAB_SIZE = 262_144  # Gemma-3 vocabulary size (= 2**18, so H_MAX is exact)
H_MAX = math.log2(VOCAB_SIZE)  # 18.0 bits — maximum possible next-token entropy; used to normalize H
# Gemma chat-template turn delimiters used to build the prompt string.
START_OF_TURN = "<start_of_turn>"
END_OF_TURN = "<end_of_turn>"
# ============================================================================
# Logging
# ============================================================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("calibrate")
# ============================================================================
# Calibration Prompts — diverse, multilingual, varying difficulty
# ============================================================================
# Built-in prompt set used when --prompts is not given. Categories are chosen
# to span the expected entropy range: easy factual (low entropy), hard
# reasoning/philosophy (high entropy), code, several languages, and creative
# prompts. The comments note whether each category is expected to trigger
# resonance once thresholds are calibrated.
DEFAULT_PROMPTS = [
    # Easy factual (should NOT trigger resonance)
    "What is 2 + 2?",
    "What color is the sky?",
    "Who wrote Romeo and Juliet?",
    "What is the capital of France?",
    "How many days are in a week?",
    # Medium difficulty (may or may not trigger)
    "Explain what a neural network is in simple terms.",
    "What causes inflation?",
    "Why do birds migrate?",
    "How does encryption work?",
    "What is the difference between RNA and DNA?",
    # Hard reasoning (SHOULD trigger resonance)
    "Why do small language models sometimes outperform larger ones?",
    "Is consciousness computable?",
    "What is the relationship between compression and intelligence?",
    "Can a system understand something it was never explicitly taught?",
    "Why does emergence happen at specific scale thresholds?",
    # Philosophy (SHOULD trigger)
    "Is free will an illusion?",
    "What is the meaning of life?",
    "If all your memories were replaced, would you still be you?",
    "Does objective morality exist?",
    "What is the nature of time?",
    # Code (mixed — simple bugs shouldn't, architecture should)
    "What does `print(1 + 1)` output in Python?",
    "Why would a recursive function without a base case crash?",
    "How would you design a distributed consensus algorithm?",
    "Explain why attention mechanisms are O(n^2).",
    # Russian (SHOULD trigger on hard ones)
    "Сколько будет два плюс два?",
    "Почему небо голубое?",
    "Что такое эмерджентность в нейронных сетях?",
    "Свобода воли — это иллюзия?",
    "Почему маленькие языковые модели иногда лучше больших?",
    # French
    "Quelle est la capitale de la France?",
    "Pourquoi les petits modeles de langage sont-ils importants?",
    "Quel est le sens de la vie?",
    # German
    "Was ist der Sinn des Lebens?",
    "Was bedeutet Emergenz im Kontext neuronaler Netzwerke?",
    # Ambiguous / creative (high entropy expected)
    "Write a haiku about debugging.",
    "If neural networks could dream, what would they dream about?",
    "Tell me something nobody has ever said before.",
    "What would happen if entropy decreased instead of increased?",
    # Meta (interesting entropy behavior expected)
    "Explain your reasoning process.",
    "How confident are you in your answers?",
    "What don't you know?",
    # Math
    "What is the sum of the first 100 positive integers?",
    "Prove that the square root of 2 is irrational.",
    "What is the derivative of x^x?",
    # Simple instructions (should NOT trigger)
    "List three colors.",
    "Say hello in five languages.",
    "Count to ten.",
]
# ============================================================================
# Entropy Collection
# ============================================================================
def collect_entropy_profile(
    model,
    tokenizer,
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.7,
    device: str = 'cuda',
) -> dict:
    """Generate from a prompt and collect entropy at every step.

    We generate normally (no resonance intervention) and just observe
    the entropy curve. This gives us the model's natural entropy profile.

    Args:
        model: HF causal LM (or PEFT-wrapped model) in eval mode.
        tokenizer: matching tokenizer with encode/decode and eos_token_id.
        prompt: user message placed inside the Gemma chat template.
        max_tokens: generation budget per prompt.
        temperature: sampling temperature (assumed > 0).
        device: device the model lives on.

    Returns dict with:
        'prompt': str
        'entropies': list of (H_bits, H_norm) tuples
        'tokens': list of generated token strings
        'mean_h': float
        'max_h': float
        'min_h': float
        'std_h': float
        'first_5_mean': float (mean of first 5 tokens — initial uncertainty)
        'generated': first 200 chars of the generated text (absent when
            nothing was generated)
    """
    model.eval()
    input_text = f"{START_OF_TURN}user\n{prompt}{END_OF_TURN}\n{START_OF_TURN}model\n"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    entropies = []
    tokens = []
    eos_id = tokenizer.eos_token_id
    generated_text = ""
    # Prefill: one forward pass over the full prompt. With use_cache=True each
    # later step feeds only the newly sampled token through the model (O(n)
    # total work) instead of re-running the whole sequence every step (the
    # previous behavior, O(n^2) in generated length).
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    past_key_values = getattr(outputs, 'past_key_values', None)
    next_logits = outputs.logits[0, -1, :]
    for _ in range(max_tokens):
        # Shannon entropy (bits) of the raw (untempered) next-token
        # distribution; normalized by H_MAX for a model-size-independent scale.
        probs = F.softmax(next_logits.float(), dim=-1).clamp(min=1e-10)
        H = -(probs * probs.log2()).sum().item()
        entropies.append((H, H / H_MAX))
        # Sample token (normal generation, no resonance intervention).
        # .float(): torch.multinomial does not support bfloat16 on CPU.
        probs_sampling = F.softmax(next_logits.float() / temperature, dim=-1)
        next_token = torch.multinomial(probs_sampling, num_samples=1).item()
        if next_token == eos_id:
            break
        token_str = tokenizer.decode([next_token])
        tokens.append(token_str)
        generated_text += token_str
        # Stop when the model closes its turn with the end-of-turn marker.
        if generated_text.rstrip().endswith(END_OF_TURN):
            break
        # Advance one step, reusing the KV cache when the model returned one.
        step_ids = torch.tensor([[next_token]], device=device)
        with torch.no_grad():
            if past_key_values is not None:
                outputs = model(step_ids, past_key_values=past_key_values,
                                use_cache=True)
                past_key_values = getattr(outputs, 'past_key_values', None)
            else:
                # Model exposed no cache — fall back to full recompute.
                input_ids = torch.cat([input_ids, step_ids], dim=1)
                outputs = model(input_ids)
        next_logits = outputs.logits[0, -1, :]
    # Compute stats (degenerate result when max_tokens <= 0).
    if not entropies:
        return {
            'prompt': prompt,
            'entropies': [],
            'tokens': [],
            'mean_h': 0, 'max_h': 0, 'min_h': 0, 'std_h': 0,
            'first_5_mean': 0,
        }
    h_norms = [h_norm for _, h_norm in entropies]
    mean_h = sum(h_norms) / len(h_norms)
    # Population std over the normalized entropies.
    std_h = (sum((v - mean_h) ** 2 for v in h_norms) / len(h_norms)) ** 0.5
    first_5 = h_norms[:5]
    return {
        'prompt': prompt,
        'entropies': entropies,
        'tokens': tokens,
        'mean_h': mean_h,
        'max_h': max(h_norms),
        'min_h': min(h_norms),
        'std_h': std_h,
        'first_5_mean': sum(first_5) / len(first_5) if first_5 else 0,
        'generated': generated_text[:200],
    }
# ============================================================================
# Threshold Computation
# ============================================================================
def compute_thresholds(profiles: list[dict], target_resonance_rate: float = 0.45) -> dict:
    """Compute optimal H_high and H_low from collected entropy profiles.

    Algorithm:
      1. Collect max-entropy and min-entropy per prompt
      2. H_high = percentile of max-entropies where ~target_resonance_rate
         of prompts would trigger resonance
      3. H_low = mean of per-prompt min entropies + small margin

    The target_resonance_rate controls how aggressive resonance is:
      - 0.3 = conservative (resonance on ~30% of prompts, only hard ones)
      - 0.5 = balanced (resonance on ~50% of prompts)
      - 0.7 = aggressive (resonance on ~70% of prompts, even medium questions)

    Args:
        profiles: output dicts of collect_entropy_profile().
        target_resonance_rate: desired fraction (0-1) of prompts whose max
            entropy should exceed H_high.

    Returns:
        Calibration dict with 'h_high', 'h_low', rate/summary statistics and
        recommended enter/exit streak counts; falls back to default
        thresholds (plus an 'error' key) when no usable profiles exist.
    """
    if not profiles:
        return {'h_high': 0.35, 'h_low': 0.12, 'error': 'no profiles'}
    # Only prompts that actually produced tokens contribute statistics.
    valid = [p for p in profiles if p['entropies']]
    max_entropies = [p['max_h'] for p in valid]
    min_entropies = [p['min_h'] for p in valid]
    if not max_entropies:
        return {'h_high': 0.35, 'h_low': 0.12, 'error': 'no valid profiles'}
    # H_high: we want resonance to trigger on target_resonance_rate of the
    # prompts, so take the (1 - target_resonance_rate) percentile of the
    # per-prompt max entropies.
    max_entropies_sorted = sorted(max_entropies)
    h_high_idx = int(len(max_entropies_sorted) * (1 - target_resonance_rate))
    h_high_idx = max(0, min(len(max_entropies_sorted) - 1, h_high_idx))
    h_high = max_entropies_sorted[h_high_idx]
    # H_low: mean of per-prompt minimums + 0.5*std for safety margin.
    mean_of_mins = sum(min_entropies) / len(min_entropies)
    std_of_mins = (sum((v - mean_of_mins) ** 2 for v in min_entropies)
                   / len(min_entropies)) ** 0.5
    h_low = mean_of_mins + 0.5 * std_of_mins
    # Sanity checks: keep a usable gap and avoid degenerate thresholds.
    if h_low >= h_high:
        log.warning(f"h_low ({h_low:.4f}) >= h_high ({h_high:.4f}). Adjusting.")
        # Force minimum gap
        midpoint = (h_low + h_high) / 2
        h_high = midpoint + 0.05
        h_low = midpoint - 0.05
    if h_high < 0.10:
        log.warning(f"h_high ({h_high:.4f}) is suspiciously low. Setting to 0.20.")
        h_high = 0.20
    if h_low < 0.02:
        h_low = 0.02
    # Resonance rate the chosen threshold would actually produce.
    would_trigger = sum(1 for m in max_entropies if m > h_high)
    actual_rate = would_trigger / len(max_entropies)
    # Global token-level statistics across every recorded generation step.
    all_h = [h_norm for p in valid for _, h_norm in p['entropies']]
    global_mean = sum(all_h) / len(all_h) if all_h else 0
    global_std = (sum((v - global_mean) ** 2 for v in all_h) / len(all_h)) ** 0.5 if all_h else 0
    # Mean hoisted out of the std computation (was recomputed per element).
    mean_of_maxes = sum(max_entropies) / len(max_entropies)
    std_of_maxes = (sum((v - mean_of_maxes) ** 2 for v in max_entropies)
                    / len(max_entropies)) ** 0.5
    return {
        'h_high': round(h_high, 4),
        'h_low': round(h_low, 4),
        'target_resonance_rate': target_resonance_rate,
        'actual_resonance_rate': round(actual_rate, 3),
        'num_prompts': len(profiles),
        'num_valid': len(max_entropies),
        'global_entropy_stats': {
            'mean': round(global_mean, 4),
            'std': round(global_std, 4),
            'max': round(max(all_h) if all_h else 0, 4),
            'min': round(min(all_h) if all_h else 0, 4),
        },
        'per_prompt_max_entropy': {
            'mean': round(mean_of_maxes, 4),
            'std': round(std_of_maxes, 4),
            'min': round(min(max_entropies), 4),
            'max': round(max(max_entropies), 4),
        },
        'per_prompt_min_entropy': {
            'mean': round(mean_of_mins, 4),
            'std': round(std_of_mins, 4),
        },
        # Streak lengths the resonance consumer should use to enter/exit.
        'recommended_enter_count': 3,
        'recommended_exit_count': 5,
    }
# ============================================================================
# Report
# ============================================================================
def print_report(result: dict, profiles: list[dict]):
    """Render the calibration report on stdout.

    Shows recommended thresholds, aggregate entropy statistics, a per-prompt
    trigger table (hardest first), an ASCII histogram of per-prompt max
    entropies, and a ready-to-paste usage command.
    """
    heavy = '=' * 70
    light = '─' * 70
    # --- Summary -----------------------------------------------------------
    print(f"\n{heavy}")
    print(f" ENTROPY CALIBRATION REPORT")
    print(f"{heavy}")
    print(f"\n Calibrated on {result['num_prompts']} prompts ({result['num_valid']} valid)")
    print(f"\n RECOMMENDED THRESHOLDS:")
    print(f" H_high = {result['h_high']:.4f} (enter resonance above this)")
    print(f" H_low = {result['h_low']:.4f} (exit resonance below this)")
    print(f"\n Expected resonance rate: {result['actual_resonance_rate']:.0%} of prompts")
    print(f" Target was: {result['target_resonance_rate']:.0%}")
    stats = result['global_entropy_stats']
    print(f"\n Global entropy (H_norm):")
    print(f" mean={stats['mean']:.4f} std={stats['std']:.4f} min={stats['min']:.4f} max={stats['max']:.4f}")
    per_max = result['per_prompt_max_entropy']
    print(f"\n Per-prompt max entropy:")
    print(f" mean={per_max['mean']:.4f} std={per_max['std']:.4f} range=[{per_max['min']:.4f}, {per_max['max']:.4f}]")
    # --- Per-prompt table, highest max entropy first -----------------------
    print(f"\n{light}")
    print(f" PER-PROMPT ANALYSIS")
    print(f"{light}")
    print(f" {'Prompt':<50} {'MaxH':>7} {'MeanH':>7} {'Trigger':>8}")
    print(f" {'─'*50} {'─'*7} {'─'*7} {'─'*8}")
    ranked = sorted(profiles, key=lambda prof: prof['max_h'], reverse=True)
    for prof in ranked:
        if not prof['entropies']:
            continue
        would_fire = prof['max_h'] > result['h_high']
        trigger = "YES" if would_fire else "no"
        trigger_mark = ">>>" if would_fire else " "
        prompt_short = prof['prompt'][:48]
        print(f" {trigger_mark}{prompt_short:<47} {prof['max_h']:>7.4f} {prof['mean_h']:>7.4f} {trigger:>8}")
    # --- Histogram of per-prompt max entropies -----------------------------
    print(f"\n{light}")
    print(f" MAX ENTROPY DISTRIBUTION")
    print(f"{light}")
    max_hs = sorted(prof['max_h'] for prof in profiles if prof['entropies'])
    if max_hs:
        n_bins = 15
        lo_edge = 0.0
        hi_edge = max(max_hs) * 1.1
        width = (hi_edge - lo_edge) / n_bins
        counts = [0] * n_bins
        for value in max_hs:
            counts[min(int((value - lo_edge) / width), n_bins - 1)] += 1
        tallest = max(counts) if counts else 1
        bar_width = 40
        for bin_idx, count in enumerate(counts):
            left = lo_edge + bin_idx * width
            right = left + width
            filled = int(count / tallest * bar_width) if tallest > 0 else 0
            bar = '#' * filled
            # Annotate the bin containing the H_high threshold.
            marker = " <-- H_high" if left <= result['h_high'] < right else ""
            print(f" {left:.3f}-{right:.3f} |{bar:<{bar_width}}| {count:>3}{marker}")
    # --- Copy-paste usage hint --------------------------------------------
    print(f"\n{light}")
    print(f" USAGE")
    print(f"{light}")
    print(f" python entropy_resonance.py \\")
    print(f" --adapter-path ./gemma3-resonate/best \\")
    print(f" --h-high {result['h_high']:.4f} \\")
    print(f" --h-low {result['h_low']:.4f}")
    print(f"\n{heavy}\n")
# ============================================================================
# Main
# ============================================================================
def main():
    """CLI entry point: parse arguments, load model/tokenizer (plus optional
    LoRA adapter), collect an entropy profile per prompt, compute the
    recommended H_high/H_low thresholds, print the report, and optionally
    save the calibration to JSON."""
    parser = argparse.ArgumentParser(
        description="Calibrate entropy thresholds for Adaptive Resonance"
    )
    parser.add_argument("--model", default=MODEL_ID, help="Base model ID")
    parser.add_argument("--adapter-path", default=None, help="LoRA adapter path")
    parser.add_argument("--no-lora", action="store_true", help="Skip LoRA loading")
    parser.add_argument("--device", default=None, help="Device: cuda/cpu/mps")
    parser.add_argument("--prompts", default=None,
                        help="Text file with prompts, one per line")
    parser.add_argument("--max-tokens", type=int, default=100,
                        help="Max tokens per generation during calibration")
    parser.add_argument("--target-rate", type=float, default=0.45,
                        help="Target resonance trigger rate (0-1)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature during calibration")
    parser.add_argument("--save", default=None,
                        help="Save calibration result to JSON file")
    args = parser.parse_args()
    # Device: auto-detect (cuda > mps > cpu) unless explicitly overridden.
    if args.device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
    else:
        device = args.device
    # Load tokenizer and model: bfloat16 on CUDA, float32 elsewhere (cpu/mps).
    log.info(f"Loading tokenizer from {args.model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    dtype = torch.bfloat16 if device == 'cuda' else torch.float32
    log.info(f"Loading model from {args.model} onto {device}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        # accelerate-style placement only on CUDA; otherwise moved below.
        device_map=device if device == 'cuda' else None,
        attn_implementation="sdpa" if device == 'cuda' else "eager",
        trust_remote_code=True,
    )
    if device != 'cuda':
        # device_map was None above, so move the weights manually.
        model = model.to(device)
    if args.adapter_path and not args.no_lora:
        # Lazy import: peft is only required when an adapter is requested.
        from peft import PeftModel
        log.info(f"Loading adapter from {args.adapter_path}...")
        model = PeftModel.from_pretrained(model, args.adapter_path)
    model.eval()
    # Load prompts: custom file (one non-empty line each) or built-in set.
    if args.prompts:
        with open(args.prompts, 'r', encoding='utf-8') as f:
            prompts = [line.strip() for line in f if line.strip()]
        log.info(f"Loaded {len(prompts)} prompts from {args.prompts}")
    else:
        prompts = DEFAULT_PROMPTS
        log.info(f"Using {len(prompts)} default calibration prompts")
    # Collect one entropy profile per prompt (plain generation, no resonance).
    log.info(f"Collecting entropy profiles ({args.max_tokens} tokens/prompt)...")
    profiles = []
    t0 = time.time()
    for i, prompt in enumerate(prompts):
        log.info(f" [{i+1}/{len(prompts)}] {prompt[:60]}...")
        profile = collect_entropy_profile(
            model, tokenizer, prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            device=device,
        )
        profiles.append(profile)
        if profile['entropies']:
            log.info(f" H_norm: mean={profile['mean_h']:.4f} max={profile['max_h']:.4f} "
                     f"min={profile['min_h']:.4f} ({len(profile['entropies'])} tokens)")
    elapsed = time.time() - t0
    log.info(f"Collection complete in {elapsed:.1f}s")
    # Turn raw profiles into recommended H_high / H_low thresholds.
    result = compute_thresholds(profiles, target_resonance_rate=args.target_rate)
    # Print report
    print_report(result, profiles)
    # Save if requested
    if args.save:
        # Don't save the full entropy traces (too large) — just the result
        # plus a compact per-prompt summary and the run configuration.
        save_data = {
            'calibration': result,
            'per_prompt_summary': [
                {
                    'prompt': p['prompt'],
                    'mean_h': round(p['mean_h'], 4),
                    'max_h': round(p['max_h'], 4),
                    'min_h': round(p['min_h'], 4),
                    'std_h': round(p['std_h'], 4),
                    'first_5_mean': round(p['first_5_mean'], 4),
                    'n_tokens': len(p['entropies']),
                    'would_trigger': p['max_h'] > result['h_high'],
                }
                for p in profiles if p['entropies']
            ],
            'model': args.model,
            'adapter': args.adapter_path,
            'target_rate': args.target_rate,
            'max_tokens': args.max_tokens,
            'temperature': args.temperature,
        }
        with open(args.save, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        log.info(f"Calibration saved to {args.save}")
    log.info("Done. Use the recommended thresholds with entropy_resonance.py.")
# Script entry point.
if __name__ == "__main__":
    main()