ataeff commited on
Commit
e7df4fb
·
verified ·
1 Parent(s): 882f913

Wulf's entropy-driven inference script

Browse files
Files changed (1) hide show
  1. entropy_resonance.py +883 -0
entropy_resonance.py ADDED
@@ -0,0 +1,883 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ entropy_resonance.py — Entropy-Driven Adaptive Resonance for Gemma-3 270M-IT
4
+
5
+ The model doesn't decide WHEN to think. The entropy of its own logits does.
6
+ LoRA teaches it HOW to think. Entropy tells it WHEN.
7
+
8
+ Usage:
9
+ # Interactive mode
10
+ python entropy_resonance.py --adapter-path ./gemma3-resonate/best
11
+
12
+ # Single prompt
13
+ python entropy_resonance.py --adapter-path ./gemma3-resonate/best \
14
+ --prompt "Why does emergence happen?"
15
+
16
+ # Base model without LoRA (entropy still works, resonance content will be weaker)
17
+ python entropy_resonance.py --no-lora --prompt "What is consciousness?"
18
+
19
+ # With custom thresholds
20
+ python entropy_resonance.py --adapter-path ./gemma3-resonate/best \
21
+ --h-high 0.38 --h-low 0.12
22
+
23
+ # Verbose mode with entropy curve visualization
24
+ python entropy_resonance.py --adapter-path ./gemma3-resonate/best \
25
+ --prompt "Is free will real?" --verbose --show-curve
26
+
27
+ # Calibrate thresholds first (recommended for new model/adapter)
28
+ python calibrate_entropy.py --adapter-path ./gemma3-resonate/best
29
+
30
+ Author: Wulf (Opus + Oleg)
31
+ Date: 2026-03-28
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import os
37
+ import sys
38
+ import math
39
+ import time
40
+ import argparse
41
+ import logging
42
+ from dataclasses import dataclass, field
43
+ from typing import Optional
44
+
45
+ import torch
46
+ import torch.nn.functional as F
47
+
48
+ from transformers import AutoModelForCausalLM, AutoTokenizer
49
+
50
+ # ============================================================================
51
+ # Constants
52
+ # ============================================================================
53
+
54
+ MODEL_ID = "unsloth/gemma-3-270m-it"
55
+
56
+ # Gemma-3 chat template
57
+ START_OF_TURN = "<start_of_turn>"
58
+ END_OF_TURN = "<end_of_turn>"
59
+
60
+ # Resonance markers — plain text, not special tokens
61
+ RESONATE_OPEN = "/resonate/"
62
+ RESONATE_CLOSE = "/resonated/"
63
+
64
+ # Gemma-3 vocab size
65
+ VOCAB_SIZE = 262_144
66
+ H_MAX = math.log2(VOCAB_SIZE) # 18.0 bits — theoretical maximum entropy
67
+
68
+ # ============================================================================
69
+ # Logging
70
+ # ============================================================================
71
+
72
+ logging.basicConfig(
73
+ level=logging.INFO,
74
+ format="%(asctime)s [%(levelname)s] %(message)s",
75
+ datefmt="%H:%M:%S",
76
+ )
77
+ log = logging.getLogger("entropy_resonance")
78
+
79
+
80
+ # ============================================================================
81
+ # Entropy Computation
82
+ # ============================================================================
83
+
84
def compute_entropy(logits: torch.Tensor, temperature: float = 1.0) -> float:
    """Compute Shannon entropy of the next-token distribution, in bits.

    CRITICAL: entropy is computed from RAW logits (internal temperature=1.0),
    not from temperature-scaled logits. This gives the model's TRUE
    uncertainty, independent of the sampling temperature choice.

    Args:
        logits: shape (vocab_size,) — raw logits from the model's last layer
        temperature: ignored for entropy computation (kept in the signature
            for interface clarity/compatibility)

    Returns:
        H in bits (log base 2). Range: [0, log2(vocab_size)] = [0, 18.0]
    """
    # log_softmax is numerically stable for finite logits: it never hits
    # log(0), so no clamp(min=1e-10) hack is needed and the probabilities
    # still sum to exactly 1 (clamping after softmax silently breaks
    # normalization and biases the entropy upward).
    log_probs = F.log_softmax(logits.float(), dim=-1)
    probs = log_probs.exp()

    # Shannon entropy in nats; p * log p underflows cleanly to 0.0 as p -> 0
    # because probs is exp(log_probs), never an independent rounded value.
    H_nats = -(probs * log_probs).sum().item()

    # Convert nats -> bits (log2 x = ln x / ln 2)
    return H_nats / math.log(2)
108
+
109
+
110
def normalized_entropy(H: float) -> float:
    """Map an entropy value in bits onto [0, 1] relative to the vocabulary.

    Divides by H_max = log2(262144) = 18 bits, so that:
        0.0 -> perfect certainty (one-hot distribution)
        1.0 -> uniform distribution (maximum uncertainty)
    """
    h_norm = H / H_MAX
    return h_norm
120
+
121
+
122
+ # ============================================================================
123
+ # Entropy Curve Visualization (Terminal)
124
+ # ============================================================================
125
+
126
class EntropyCurve:
    """Collects entropy values during generation and renders ASCII visualization.

    One sample per generated token (raw H in bits plus its normalized form),
    with optional point events (resonance enter/exit) marked by step index.
    """

    def __init__(self, width: int = 70, height: int = 20):
        # width/height of the ASCII plot area (characters)
        self.width = width
        self.height = height
        self.values: list[float] = []  # raw H in bits, one per recorded token
        self.normalized: list[float] = []  # H_norm in [0, 1], parallel to values
        self.tokens: list[str] = []  # decoded token strings, parallel to values
        self.events: list[tuple[int, str]] = []  # (step index, event_type)

    def add(self, H: float, token_str: str):
        # Record one token's entropy sample; normalization uses the
        # module-level H_MAX constant.
        self.values.append(H)
        self.normalized.append(normalized_entropy(H))
        self.tokens.append(token_str)

    def mark_event(self, event_type: str):
        """Mark an event at the current step (e.g., 'enter_resonance', 'exit_resonance').

        NOTE(review): if called before any add(), the recorded step is -1;
        render() silently drops such events via its 0 <= col bounds check.
        """
        self.events.append((len(self.values) - 1, event_type))

    def render(self, h_high: float, h_low: float) -> str:
        """Render ASCII entropy curve with threshold lines and events.

        Args:
            h_high: normalized high threshold (enter resonance)
            h_low: normalized low threshold (exit resonance)

        Returns:
            Multi-line string with the visualization
        """
        if not self.normalized:
            return "(no data)"

        n = len(self.normalized)

        # If more data points than width, subsample evenly across the run
        if n > self.width:
            step = n / self.width
            indices = [int(i * step) for i in range(self.width)]
            data = [self.normalized[i] for i in indices]
        else:
            data = list(self.normalized)
            indices = list(range(n))

        # Scale to height; ceiling is at least 0.5 and always leaves headroom
        # above h_high so the threshold line stays inside the plot
        max_val = max(max(data), h_high + 0.05, 0.5)
        min_val = 0.0

        lines = []
        lines.append(f" Entropy Curve ({n} tokens, H_max={H_MAX:.1f} bits)")
        lines.append(f" H_high={h_high:.3f} (enter resonance) H_low={h_low:.3f} (exit resonance)")
        lines.append("")

        # Build grid (rows x columns of single characters)
        grid = [[' ' for _ in range(len(data))] for _ in range(self.height)]

        # Plot data points: row 0 is the top of the plot, hence the 1.0 - ... flip
        for col, val in enumerate(data):
            row = int((1.0 - (val - min_val) / (max_val - min_val)) * (self.height - 1))
            row = max(0, min(self.height - 1, row))
            grid[row][col] = '#'

        # Plot threshold lines ('-' for H_high, '.' for H_low); data points win
        h_high_row = int((1.0 - (h_high - min_val) / (max_val - min_val)) * (self.height - 1))
        h_low_row = int((1.0 - (h_low - min_val) / (max_val - min_val)) * (self.height - 1))
        h_high_row = max(0, min(self.height - 1, h_high_row))
        h_low_row = max(0, min(self.height - 1, h_low_row))

        for col in range(len(data)):
            if grid[h_high_row][col] == ' ':
                grid[h_high_row][col] = '-'
            if grid[h_low_row][col] == ' ':
                grid[h_low_row][col] = '.'

        # Mark events: map each event's step onto a plot column
        event_map = {}
        for step, etype in self.events:
            if n > self.width:
                # Find the subsampled column closest to the event's step
                col = min(range(len(indices)), key=lambda c: abs(indices[c] - step))
            else:
                col = step
            if 0 <= col < len(data):
                event_map[col] = etype

        # Render the grid with a y-axis label per row
        for row_idx, row in enumerate(grid):
            # Y-axis label: normalized entropy value at this row
            val = max_val - row_idx * (max_val - min_val) / (self.height - 1)
            label = f"{val:.2f}"

            row_str = ''.join(row)

            # Annotate threshold rows on the right margin
            suffix = ""
            if row_idx == h_high_row:
                suffix = " <-- H_high (enter)"
            elif row_idx == h_low_row:
                suffix = " <-- H_low (exit)"

            lines.append(f" {label:>5} |{row_str}|{suffix}")

        # X-axis: '+' under columns that carry an event, '-' elsewhere
        lines.append(f" {''.join(['+' if col in event_map else '-' for col in range(len(data))])}")

        # Event legend line: E/X letters aligned under their columns
        event_line = " "
        for col in range(len(data)):
            if col in event_map:
                if event_map[col] == 'enter_resonance':
                    event_line += 'E'
                elif event_map[col] == 'exit_resonance':
                    event_line += 'X'
                else:
                    event_line += '?'
            else:
                event_line += ' '
        lines.append(event_line)
        lines.append(f" E=enter resonance, X=exit resonance")

        # Stats over the FULL (unsubsampled) series
        avg_h = sum(self.normalized) / len(self.normalized)
        max_h = max(self.normalized)
        min_h = min(self.normalized)
        std_h = (sum((v - avg_h)**2 for v in self.normalized) / len(self.normalized)) ** 0.5

        lines.append("")
        lines.append(f" Stats: mean={avg_h:.4f} max={max_h:.4f} min={min_h:.4f} std={std_h:.4f}")
        lines.append(f" Raw H: mean={sum(self.values)/len(self.values):.2f} bits max={max(self.values):.2f} bits")

        # Resonance segments: pair up enter/exit events in order; an unmatched
        # trailing 'enter' is simply not reported
        in_res = False
        segments = []
        seg_start = 0
        for step, etype in self.events:
            if etype == 'enter_resonance' and not in_res:
                in_res = True
                seg_start = step
            elif etype == 'exit_resonance' and in_res:
                in_res = False
                segments.append((seg_start, step))

        if segments:
            lines.append(f" Resonance segments: {len(segments)}")
            for i, (s, e) in enumerate(segments):
                # NOTE(review): slice s..e inclusive holds e-s+1 samples, but the
                # label prints e-s — off by one in the displayed token count.
                seg_h = self.normalized[s:e+1]
                seg_avg = sum(seg_h) / len(seg_h) if seg_h else 0
                lines.append(f" [{i+1}] tokens {s}-{e} ({e-s} tokens, avg H_norm={seg_avg:.4f})")

        return '\n'.join(lines)
276
+
277
+
278
+ # ============================================================================
279
+ # Resonance State Machine
280
+ # ============================================================================
281
+
282
@dataclass
class ResonanceState:
    """State machine tracking whether generation is inside a resonance span.

    Transitions are driven purely by normalized entropy, with hysteresis on
    both sides (a run of high readings to enter, a run of low readings to
    exit) plus a hard length cap as a safety valve.
    """
    in_resonance: bool = False

    # Hysteresis counters — prevent rapid enter/exit flickering
    consecutive_high: int = 0  # current run length of tokens above h_high
    consecutive_low: int = 0   # current run length of tokens below h_low

    # Normalized entropy thresholds in [0, 1]
    h_high: float = 0.35  # enter resonance above this
    h_low: float = 0.12   # exit resonance below this

    # Hysteresis requirements
    enter_count: int = 3  # consecutive high-entropy tokens needed to enter
    exit_count: int = 5   # consecutive low-entropy tokens needed to exit

    # Safeguards
    max_resonance_tokens: int = 500  # force exit after this many resonance tokens
    resonance_token_count: int = 0   # tokens spent in the current resonance span

    # Entropy modulation (Delta Voice integration)
    beta: float = 0.3  # entropy coupling constant for θ = ε + γ + αδ + βH

    # Sampling parameters (widened by entropy while in resonance)
    base_temperature: float = 0.7
    base_top_p: float = 0.9
    base_top_k: int = 40

    # Diagnostic counters
    total_tokens: int = 0
    resonance_entries: int = 0
    forced_exits: int = 0

    def update(self, h_norm: float) -> Optional[str]:
        """Advance the state machine with one normalized entropy reading.

        Returns:
            'enter_resonance' — caller should inject the /resonate/ marker
            'exit_resonance'  — caller should inject the /resonated/ marker
            'force_exit'      — length cap hit; caller treats it like an exit
            None              — no transition this step
        """
        self.total_tokens += 1

        if not self.in_resonance:
            # --- entry side: count consecutive readings above h_high ---
            if h_norm > self.h_high:
                self.consecutive_high += 1
                self.consecutive_low = 0
            else:
                self.consecutive_high = 0

            if self.consecutive_high < self.enter_count:
                return None

            self.in_resonance = True
            self.resonance_token_count = 0
            self.consecutive_high = 0
            self.resonance_entries += 1
            return 'enter_resonance'

        # --- inside resonance ---
        self.resonance_token_count += 1

        # Hard cap fires regardless of entropy.
        if self.resonance_token_count >= self.max_resonance_tokens:
            self.in_resonance = False
            self.resonance_token_count = 0
            self.consecutive_high = 0
            self.consecutive_low = 0
            self.forced_exits += 1
            return 'force_exit'

        # Natural exit: count consecutive readings below h_low.
        if h_norm < self.h_low:
            self.consecutive_low += 1
            self.consecutive_high = 0
        else:
            self.consecutive_low = 0

        if self.consecutive_low < self.exit_count:
            return None

        self.in_resonance = False
        self.resonance_token_count = 0
        self.consecutive_low = 0
        return 'exit_resonance'

    def get_sampling_params(self, h_norm: float) -> dict:
        """Return sampling parameters for the current mode.

        Outside resonance the base parameters are returned unchanged. Inside,
        each parameter is widened in proportion to entropy — the ANALOG βH
        term in θ = ε + γ + αδ + βH.
        """
        if not self.in_resonance:
            return {
                'temperature': self.base_temperature,
                'top_p': self.base_top_p,
                'top_k': self.base_top_k,
            }

        coupling = self.beta * h_norm
        return {
            'temperature': self.base_temperature * (1.0 + coupling),
            'top_p': min(0.98, self.base_top_p + coupling * 0.15),
            'top_k': int(self.base_top_k * (1.0 + coupling)),
        }

    def summary(self) -> str:
        """Return a one-line diagnostic summary of the counters."""
        return (
            f"Resonance: {self.resonance_entries} entries, "
            f"{self.forced_exits} forced exits, "
            f"{self.total_tokens} total tokens"
        )
402
+
403
+
404
+ # ============================================================================
405
+ # The Main Beast: Entropy-Driven Generation
406
+ # ============================================================================
407
+
408
def entropy_generate(
    model,
    tokenizer,
    prompt: str,
    state: ResonanceState,
    max_new_tokens: int = 768,
    verbose: bool = False,
    show_curve: bool = False,
    repetition_penalty: float = 1.3,
) -> tuple[str, EntropyCurve]:
    """Generate text with entropy-driven adaptive resonance.

    This is NOT model.generate(). We run the generation loop manually,
    token by token, computing entropy at each step and making resonance
    decisions in real time:

      1. compute Shannon entropy of the raw next-token logits;
      2. feed it to the ResonanceState machine — on a transition, inject the
         /resonate/ or /resonated/ marker into the context and re-run the
         model so the following logits condition on the marker;
      3. otherwise sample the next token with entropy-modulated
         temperature/top-p/top-k and a repetition penalty.

    Args:
        model: Gemma-3 270M-IT (with or without LoRA adapter)
        tokenizer: Gemma tokenizer
        prompt: user's question/input
        state: ResonanceState with thresholds and parameters (its per-run
            counters are reset in place at the start of each call)
        max_new_tokens: maximum tokens to generate
        verbose: log entropy/sampling info every 10 steps
        show_curve: collect per-token data for visualization
        repetition_penalty: >1.0 penalizes recently generated tokens

    Returns:
        (generated_text, entropy_curve)
    """
    device = next(model.parameters()).device
    model.eval()

    # Format prompt in the Gemma chat template
    input_text = f"{START_OF_TURN}user\n{prompt}{END_OF_TURN}\n{START_OF_TURN}model\n"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    # Initialize
    curve = EntropyCurve()
    generated_ids = []   # tokens produced this call, including injected markers
    generated_text = ""

    # Full context (prompt + everything generated) — re-fed to the model each step
    all_ids = input_ids[0].tolist()

    # Stop conditions: EOS token id, or the textual end-of-turn marker.
    # End-of-turn is matched on the decoded STRING below, which is robust to
    # the marker being split across several tokens (the previously computed
    # token-id form of the marker was dead code and has been removed).
    eos_id = tokenizer.eos_token_id
    eot_text = END_OF_TURN

    # Reset per-generation state so repeated calls don't leak counters
    state.in_resonance = False
    state.consecutive_high = 0
    state.consecutive_low = 0
    state.resonance_token_count = 0
    state.total_tokens = 0
    state.resonance_entries = 0
    state.forced_exits = 0

    # Prefill: get initial logits from the full prompt context
    with torch.no_grad():
        outputs = model(input_ids)
        next_logits = outputs.logits[0, -1, :]  # (vocab_size,)

    for step in range(max_new_tokens):
        # ── 1. Compute entropy from RAW logits ──
        H = compute_entropy(next_logits)
        h_norm = normalized_entropy(H)

        # ── 2. Check resonance state ──
        event = state.update(h_norm)

        if event == 'enter_resonance':
            # Inject /resonate/ marker into the generation
            marker_text = f"\n{RESONATE_OPEN}\n"
            marker_ids = tokenizer.encode(marker_text, add_special_tokens=False)
            generated_ids.extend(marker_ids)
            all_ids.extend(marker_ids)
            generated_text += marker_text

            if verbose:
                log.info(f" [ENTER RESONANCE] H_norm={h_norm:.4f} at token {step}")

            if show_curve:
                curve.mark_event('enter_resonance')

            # Re-run model with the injected marker so subsequent logits
            # (and hence entropy) condition on it
            full_ids = torch.tensor([all_ids], device=device)
            with torch.no_grad():
                outputs = model(full_ids)
                next_logits = outputs.logits[0, -1, :]
            continue  # Re-evaluate entropy after marker injection

        elif event in ('exit_resonance', 'force_exit'):
            # Inject /resonated/ marker
            marker_text = f"\n{RESONATE_CLOSE}\n"
            marker_ids = tokenizer.encode(marker_text, add_special_tokens=False)
            generated_ids.extend(marker_ids)
            all_ids.extend(marker_ids)
            generated_text += marker_text

            if verbose:
                if event == 'force_exit':
                    log.warning(f" [FORCED EXIT] Max resonance tokens exceeded at step {step}")
                else:
                    log.info(f" [EXIT RESONANCE] H_norm={h_norm:.4f} at token {step}")

            if show_curve:
                curve.mark_event('exit_resonance')

            # Re-run model with marker
            full_ids = torch.tensor([all_ids], device=device)
            with torch.no_grad():
                outputs = model(full_ids)
                next_logits = outputs.logits[0, -1, :]
            continue

        # ── 3. Get entropy-modulated sampling parameters ──
        params = state.get_sampling_params(h_norm)

        # ── 4. Apply repetition penalty (CTRL-style: divide positive logits,
        #       multiply negative ones, over the last 50 generated tokens) ──
        logits = next_logits.clone()
        if repetition_penalty != 1.0 and generated_ids:
            for prev_id in set(generated_ids[-50:]):  # look back 50 tokens
                if logits[prev_id] > 0:
                    logits[prev_id] /= repetition_penalty
                else:
                    logits[prev_id] *= repetition_penalty

        # ── 5. Apply temperature ──
        temp = params['temperature']
        if temp > 0:
            logits = logits / temp
        else:
            # temperature=0 → greedy decoding (argmax below, no scaling)
            pass

        # ── 6. Apply top-k filtering ──
        top_k = params['top_k']
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][-1]
            logits[indices_to_remove] = float('-inf')

        # ── 7. Apply top-p (nucleus) filtering ──
        top_p = params['top_p']
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative prob above top_p
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = float('-inf')

        # ── 8. Sample ──
        probs = F.softmax(logits, dim=-1)
        if temp > 0:
            next_token = torch.multinomial(probs, num_samples=1).item()
        else:
            next_token = torch.argmax(logits).item()

        # ── 9. Check for EOS ──
        if next_token == eos_id:
            break

        generated_ids.append(next_token)
        all_ids.append(next_token)

        token_str = tokenizer.decode([next_token])
        generated_text += token_str

        # Stop (and strip the marker) if we just completed an end_of_turn
        if generated_text.rstrip().endswith(eot_text):
            generated_text = generated_text.rstrip()[:-len(eot_text)].rstrip()
            break

        # ── 10. Record for visualization ──
        if show_curve:
            curve.add(H, token_str)

        if verbose and step % 10 == 0:
            mode = "RESONANCE" if state.in_resonance else "crystal"
            log.info(
                f" step={step:3d} H={H:.2f}bits H_norm={h_norm:.4f} "
                f"mode={mode} temp={params['temperature']:.3f} "
                f"token={repr(token_str)}"
            )

        # ── 11. Forward pass for next token ──
        # (The old unused single-token `next_input` tensor was dead code.)
        full_ids = torch.tensor([all_ids], device=device)
        with torch.no_grad():
            # Use full context for each step (no KV cache for simplicity;
            # for production, implement KV cache management)
            outputs = model(full_ids)
            next_logits = outputs.logits[0, -1, :]

    return generated_text, curve
609
+
610
+
611
+ # ============================================================================
612
+ # Pretty Printing
613
+ # ============================================================================
614
+
615
def print_result(prompt: str, generated: str, curve: EntropyCurve,
                 state: ResonanceState, show_curve: bool = False,
                 h_high: float = 0.35, h_low: float = 0.12):
    """Pretty-print one generation result.

    When the output contains both resonance markers it is displayed as
    three sections (pre-resonance text, indented reasoning, final answer);
    otherwise the whole text is shown as a direct answer. Optionally
    appends the rendered ASCII entropy curve.
    """

    print(f"\n{'='*70}")
    print(f" PROMPT: {prompt}")
    print(f"{'='*70}")

    # Only treat the output as resonance-structured if BOTH markers appear
    # somewhere in it (matching the original gating).
    if RESONATE_OPEN in generated and RESONATE_CLOSE in generated:
        # Everything before the first /resonate/ marker.
        pre_resonate, _, rest = generated.partition(RESONATE_OPEN)
        pre_resonate = pre_resonate.strip()

        # Split the remainder at the first /resonated/ marker, if any.
        reasoning_part, close_marker, answer_part = rest.partition(RESONATE_CLOSE)
        if close_marker:
            reasoning = reasoning_part.strip()
            answer = answer_part.strip()
        else:
            # Close marker preceded the open marker — resonance never closed.
            reasoning = rest.strip()
            answer = "[resonance did not crystallize — forced exit or max tokens]"

        if pre_resonate:
            print(f"\n {pre_resonate}")

        print(f"\n --- {RESONATE_OPEN} ---")
        # Reasoning is printed line by line with a gutter indent.
        for line in reasoning.split('\n'):
            print(f" | {line}")

        print(f"\n --- {RESONATE_CLOSE} ---")
        print(f"\n {answer}")
    else:
        # No resonance triggered — entropy stayed low the whole way.
        print(f"\n [direct answer — entropy stayed low, no resonance needed]")
        print(f"\n {generated}")

    print(f"\n{'─'*70}")
    print(f" {state.summary()}")

    if show_curve and curve.values:
        print(f"\n{curve.render(h_high, h_low)}")

    print(f"{'='*70}\n")
660
+
661
+
662
+ # ============================================================================
663
+ # Model Loading
664
+ # ============================================================================
665
+
666
def load_model(model_id: str = MODEL_ID, adapter_path: Optional[str] = None,
               device: Optional[str] = None) -> tuple:
    """Load Gemma-3 270M-IT with optional LoRA adapter.

    Args:
        model_id: base model identifier (hub id or local path)
        adapter_path: path to a LoRA adapter directory (None for base model);
            if the path does not exist the process exits with an error
        device: 'cuda', 'cpu', or 'mps' (auto-detected if None)

    Returns:
        (model, tokenizer, device_str)
    """
    # Auto-detect best available device: CUDA > Apple MPS > CPU
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'

    log.info(f"Loading tokenizer from {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    log.info(f"Loading model from {model_id} onto {device}...")

    # bfloat16 only on CUDA; MPS/CPU stay in float32
    dtype = torch.bfloat16 if device == 'cuda' else torch.float32

    # device_map places the model on CUDA directly; SDPA attention only on
    # CUDA, eager elsewhere
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device if device == 'cuda' else None,
        attn_implementation="sdpa" if device == 'cuda' else "eager",
        trust_remote_code=True,
    )

    # Non-CUDA devices were loaded without device_map — move manually
    if device != 'cuda':
        model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    log.info(f"Base model: {total_params/1e6:.1f}M params, dtype={dtype}")

    # Load LoRA adapter if provided
    if adapter_path:
        if not os.path.isdir(adapter_path):
            log.error(f"Adapter path does not exist: {adapter_path}")
            log.error("Run training first: python train_gemma_resonate.py")
            sys.exit(1)

        # peft is imported lazily so the base-model path works without it
        from peft import PeftModel
        log.info(f"Loading LoRA adapter from {adapter_path}...")
        model = PeftModel.from_pretrained(model, adapter_path)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log.info(f"Adapter loaded: {trainable/1e6:.1f}M trainable params")

    model.eval()
    return model, tokenizer, device
722
+
723
+
724
+ # ============================================================================
725
+ # Interactive Mode
726
+ # ============================================================================
727
+
728
def interactive_mode(model, tokenizer, state: ResonanceState,
                     verbose: bool = False, show_curve: bool = False):
    """Interactive REPL for entropy-driven resonance.

    Commands:
        /quit                     exit the loop
        /verbose                  toggle per-step entropy logging
        /curve                    toggle ASCII curve rendering
        /thresholds H_HIGH H_LOW  update resonance thresholds in place

    Any other input is treated as a prompt. Mutates `state` (thresholds via
    /thresholds, per-run counters via entropy_generate).
    """

    print(f"\n{'='*70}")
    print(f" ENTROPY-DRIVEN ADAPTIVE RESONANCE")
    print(f" Gemma-3 270M-IT + Entropy Monitoring")
    print(f"{'─'*70}")
    print(f" H_high = {state.h_high:.3f} (enter resonance)")
    print(f" H_low = {state.h_low:.3f} (exit resonance)")
    print(f" Beta = {state.beta:.2f} (entropy coupling)")
    print(f" Max resonance tokens = {state.max_resonance_tokens}")
    print(f"{'─'*70}")
    print(f" Commands: /quit /verbose /curve /thresholds H_HIGH H_LOW")
    print(f"{'='*70}\n")

    while True:
        try:
            prompt = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting.")
            break

        if not prompt:
            continue

        if prompt == '/quit':
            break
        elif prompt == '/verbose':
            verbose = not verbose
            print(f" Verbose: {'ON' if verbose else 'OFF'}")
            continue
        elif prompt == '/curve':
            show_curve = not show_curve
            print(f" Curve: {'ON' if show_curve else 'OFF'}")
            continue
        elif prompt.startswith('/thresholds'):
            parts = prompt.split()
            if len(parts) == 3:
                try:
                    state.h_high = float(parts[1])
                    state.h_low = float(parts[2])
                    print(f" Thresholds updated: H_high={state.h_high:.3f}, H_low={state.h_low:.3f}")
                except ValueError:
                    print(f" Usage: /thresholds 0.35 0.12")
            else:
                # No arguments: just report the current thresholds
                print(f" Current: H_high={state.h_high:.3f}, H_low={state.h_low:.3f}")
            continue

        # Generate with entropy monitoring
        t0 = time.time()

        generated, curve = entropy_generate(
            model, tokenizer, prompt, state,
            verbose=verbose,
            show_curve=show_curve,
        )

        elapsed = time.time() - t0

        print_result(prompt, generated, curve, state,
                     show_curve=show_curve,
                     h_high=state.h_high, h_low=state.h_low)

        # BUG FIX: curve.values is only populated when show_curve is on, so
        # the previous `len(curve.values) if curve.values else 0` reported
        # "~0 tokens, 0.0 tok/s" in the default (curve off) mode. Fall back
        # to state.total_tokens, which entropy_generate updates on every
        # step regardless of show_curve.
        tokens_generated = len(curve.values) or state.total_tokens
        tps = tokens_generated / elapsed if elapsed > 0 else 0
        print(f" [{elapsed:.1f}s, ~{tokens_generated} tokens, {tps:.1f} tok/s]\n")
795
+
796
+
797
+ # ============================================================================
798
+ # Main
799
+ # ============================================================================
800
+
801
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser. All flags, defaults, and help text unchanged."""
    parser = argparse.ArgumentParser(
        description="Entropy-Driven Adaptive Resonance — inference for Gemma-3 270M-IT"
    )

    # Model selection
    parser.add_argument("--model", default=MODEL_ID, help="Base model ID")
    parser.add_argument("--adapter-path", default=None, help="LoRA adapter path")
    parser.add_argument("--no-lora", action="store_true", help="Skip LoRA loading")
    parser.add_argument("--device", default=None, help="Device: cuda/cpu/mps (auto)")

    # Generation
    parser.add_argument("--prompt", default=None, help="Single prompt (non-interactive)")
    parser.add_argument("--max-tokens", type=int, default=768, help="Max tokens to generate")

    # Entropy thresholds
    parser.add_argument("--h-high", type=float, default=0.35,
                        help="Normalized entropy threshold to enter resonance (0-1)")
    parser.add_argument("--h-low", type=float, default=0.12,
                        help="Normalized entropy threshold to exit resonance (0-1)")
    parser.add_argument("--beta", type=float, default=0.3,
                        help="Entropy coupling constant (Delta Voice integration)")

    # Hysteresis
    parser.add_argument("--enter-count", type=int, default=3,
                        help="Consecutive high-entropy tokens to enter resonance")
    parser.add_argument("--exit-count", type=int, default=5,
                        help="Consecutive low-entropy tokens to exit resonance")
    parser.add_argument("--max-resonance", type=int, default=500,
                        help="Max tokens in a single resonance section")

    # Sampling
    parser.add_argument("--temperature", type=float, default=0.7, help="Base temperature")
    parser.add_argument("--top-p", type=float, default=0.9, help="Base top-p")
    parser.add_argument("--top-k", type=int, default=40, help="Base top-k")
    parser.add_argument("--repetition-penalty", type=float, default=1.3,
                        help="Repetition penalty")

    # Display
    parser.add_argument("--verbose", action="store_true", help="Show entropy per step")
    parser.add_argument("--show-curve", action="store_true",
                        help="Show ASCII entropy curve after generation")

    return parser


def main():
    """CLI entry point: parse args, load model (+ optional LoRA), then run
    a single prompt or the interactive REPL."""
    args = _build_arg_parser().parse_args()

    # --no-lora suppresses the adapter entirely, even if a path was given
    adapter = None if args.no_lora else args.adapter_path
    model, tokenizer, device = load_model(args.model, adapter, args.device)

    # Build the resonance state machine from the CLI knobs
    state = ResonanceState(
        h_high=args.h_high,
        h_low=args.h_low,
        enter_count=args.enter_count,
        exit_count=args.exit_count,
        max_resonance_tokens=args.max_resonance,
        beta=args.beta,
        base_temperature=args.temperature,
        base_top_p=args.top_p,
        base_top_k=args.top_k,
    )

    if args.prompt:
        # Single-prompt mode
        generated, curve = entropy_generate(
            model, tokenizer, args.prompt, state,
            max_new_tokens=args.max_tokens,
            verbose=args.verbose,
            show_curve=args.show_curve,
            repetition_penalty=args.repetition_penalty,
        )
        print_result(args.prompt, generated, curve, state,
                     show_curve=args.show_curve,
                     h_high=state.h_high, h_low=state.h_low)
    else:
        # Interactive REPL mode
        interactive_mode(model, tokenizer, state,
                         verbose=args.verbose,
                         show_curve=args.show_curve)


if __name__ == "__main__":
    main()