Lgr54HFi committed on
Commit
ed37c7e
·
verified ·
1 Parent(s): c4fa83f

Upload chimera/inference.py

Browse files
Files changed (1) hide show
  1. chimera/inference.py +298 -0
chimera/inference.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chimera 5.1 β€” Inference Systems (CPU-Optimized)
3
+ Span bank, Grammar FST, Entropy valve, Debt ledger, Braid state
4
+ - Vectorized span bank queries (batched cosine similarity)
5
+ - Fused grammar constraint computation
6
+ - Efficient entropy calculation (log_softmax path)
7
+ - torch.compile friendly (no Python-level data-dependent branching in hot path)
8
+ """
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ # ─────────────────────────────────────────────────
17
+ # Span Bank β€” Vectorized semantic search
18
+ # ─────────────────────────────────────────────────
19
class SpanBank(nn.Module):
    """Fixed-capacity store of projected span keys with vectorized cosine retrieval.

    Keys are low-dimensional projections (hidden_size // 4) of hidden states.
    Capacity is the smaller of `max_entries` and a budget derived from
    `memory_mb`. Lookup is a single matmul over all stored keys.
    """

    def __init__(self, max_entries: int = 524288, max_tokens: int = 64,
                 hidden_size: int = 2560, memory_mb: int = 384):
        super().__init__()
        self.max_entries = max_entries
        self.max_tokens = max_tokens
        self.hidden_size = hidden_size
        proj_dim = hidden_size // 4
        # Cap entry count by the configured memory budget.
        budget_entries = int(memory_mb * 1024 * 1024 / (max_tokens * 4))
        capacity = min(max_entries, budget_entries)
        self.register_buffer('bank_keys', torch.zeros(capacity, proj_dim))
        self.register_buffer('bank_lengths', torch.zeros(capacity, dtype=torch.long))
        self.register_buffer('bank_count', torch.tensor(0, dtype=torch.long))
        self.semantic_proj = nn.Linear(hidden_size, proj_dim, bias=False)

    def query(self, hidden_state: torch.Tensor, top_k: int = 64):
        """Return (scores, indices) of the top-k most similar stored spans.

        Returns (None, None) while the bank is empty. Similarity is cosine:
        both the query projection and the stored keys are L2-normalized.
        """
        if self.bank_count == 0:
            return None, None
        stored_n = self.bank_count.item()
        query_vec = F.normalize(self.semantic_proj(hidden_state), dim=-1)
        stored_keys = F.normalize(self.bank_keys[:stored_n], dim=-1)
        # One matmul gives all cosine similarities: [*, stored_n]
        similarity = query_vec @ stored_keys.t()
        scores, indices = similarity.topk(min(top_k, stored_n), dim=-1)
        return scores, indices

    @torch.no_grad()
    def add_span(self, hidden_state: torch.Tensor, length: int):
        """Append one span key; silently drops the span when the bank is full."""
        if self.bank_count >= self.bank_keys.shape[0]:
            return
        slot = self.bank_count.item()
        self.bank_keys[slot] = self.semantic_proj(hidden_state.detach()).squeeze(0)
        self.bank_lengths[slot] = length
        self.bank_count += 1
53
+
54
+
55
+ # ─────────────────────────────────────────────────
56
+ # STree Verifier β€” Compact scoring network
57
+ # ─────────────────────────────────────────────────
58
class STreeVerifier(nn.Module):
    """Scores hidden states in (0, 1) with a compact two-layer MLP.

    `tree_width` / `tree_depth` are stored for callers; the scorer itself is a
    Linear -> ReLU -> Linear bottleneck followed by a sigmoid.
    """

    def __init__(self, tree_width: int = 4, tree_depth: int = 5,
                 hidden_size: int = 256):
        super().__init__()
        self.tree_width = tree_width
        self.tree_depth = tree_depth
        bottleneck = hidden_size // 4
        self.score_net = nn.Sequential(
            nn.Linear(hidden_size, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Squash to (0, 1) and drop the trailing singleton score dimension.
        raw_score = self.score_net(hidden_states)
        return torch.sigmoid(raw_score).squeeze(-1)
72
+
73
+
74
+ # ─────────────────────────────────────────────────
75
+ # Certificate Verifier β€” Vectorized field extraction
76
+ # ─────────────────────────────────────────────────
77
class CertificateVerifier(nn.Module):
    """Extracts per-position certificate fields via independent linear heads.

    Returns a dict with 'semantic' (64-d), 'grammar' (16-d), 'entity' (32-d),
    'boundary' (1-d) raw projections, and 'risk' (1-d) squashed to [0, 1].
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.semantic_proj = nn.Linear(hidden_size, 64, bias=False)
        self.grammar_proj = nn.Linear(hidden_size, 16, bias=False)
        self.entity_proj = nn.Linear(hidden_size, 32, bias=False)
        self.boundary_proj = nn.Linear(hidden_size, 1, bias=False)
        self.risk_proj = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> dict:
        fields = {
            'semantic': self.semantic_proj(hidden_states),
            'grammar': self.grammar_proj(hidden_states),
            'entity': self.entity_proj(hidden_states),
            'boundary': self.boundary_proj(hidden_states),
        }
        # Risk is sigmoid-squashed so downstream gates can treat it as a probability.
        fields['risk'] = torch.sigmoid(self.risk_proj(hidden_states))
        return fields
94
+
95
+
96
+ # ─────────────────────────────────────────────────
97
+ # Span Inference Engine
98
+ # ─────────────────────────────────────────────────
99
class SpanInferenceEngine(nn.Module):
    """Risk-gated span inference wrapper.

    Owns a span bank, a tree verifier and a certificate verifier; the forward
    pass currently only uses the certificate's risk field to sigmoid-gate the
    hidden states. When disabled via config, forward is the identity.
    """

    def __init__(self, hidden_size: int, config: dict):
        super().__init__()
        self.enabled = config.get('enabled', True)
        self.hidden_size = hidden_size
        self.span_bank = SpanBank(
            max_entries=config.get('bank_entries', 524288),
            max_tokens=config.get('bank_max_tokens', 64),
            hidden_size=hidden_size,
            memory_mb=config.get('bank_memory_mb', 384),
        )
        # Hoisted: avoid fetching the sub-config dict twice.
        tree_cfg = config.get('tree_verify', {})
        self.tree_verifier = STreeVerifier(
            tree_width=tree_cfg.get('tree_width', 4),
            tree_depth=tree_cfg.get('tree_depth', 5),
            hidden_size=hidden_size,
        )
        self.certificate = CertificateVerifier(hidden_size)
        fast_weights = config.get('scoring_weights_fast', [1.0, 0.8, 0.5, 0.7, 0.35])
        self.scoring_weights = nn.Parameter(torch.tensor(fast_weights))
        self.fallback_threshold = config.get('fallback_below_acceptance', 0.5)
        self.risk_gate = nn.Linear(hidden_size + 1, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return hidden_states
        # Gate each hidden state by a learned function of (state, certificate risk).
        risk = self.certificate(hidden_states)['risk']
        gate_input = torch.cat([hidden_states, risk], dim=-1)
        gate = torch.sigmoid(self.risk_gate(gate_input))
        return hidden_states * gate
129
+
130
+
131
+ # ─────────────────────────────────────────────────
132
+ # Grammar FST β€” Fused constraint penalty
133
+ # ─────────────────────────────────────────────────
134
class GrammarFST(nn.Module):
    """Grammar FST with fused constraint computation.

    Optimizations:
    - Single forward pass for all constraint features
    - Fused entropy + margin + repetition penalty computation
    - Pre-allocated feature buffer

    Feature layout (last dim of the feature tensor):
      index 0                       -> entropy
      indices 1 .. n_hard           -> hard-constraint features
      indices 1+n_hard .. n_feat-1  -> soft-constraint features
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = config.get('enabled', True)
        self.modes = config.get('modes', ['plain_text'])
        self.hard_constraints = config.get('hard_constraints', [])
        self.soft_constraints = config.get('soft_constraints', [])
        # +1 reserves feature slot 0 for the entropy term.
        n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1
        self.constraint_proj = nn.Linear(n_features, 1, bias=True)
        # Near-zero init so the untrained penalty barely perturbs the logits.
        nn.init.normal_(self.constraint_proj.weight, std=0.01)
        nn.init.zeros_(self.constraint_proj.bias)
        self._n_hard = len(self.hard_constraints)
        self._n_soft = len(self.soft_constraints)
        self._n_features = n_features

    def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor:
        """Add a learned scalar constraint penalty to `logits` ([B, T, V]).

        `state` is accepted for interface compatibility but unused here.
        NOTE(review): the penalty is expanded uniformly across the vocab
        dimension, which leaves the softmax distribution unchanged — confirm
        whether a per-token (non-uniform) penalty was intended.
        """
        if not self.enabled:
            return logits
        B, T, V = logits.shape

        # Fused feature computation
        # 1. Entropy from log_softmax (numerically stable, single pass)
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(-1)  # [B, T]

        # 2. Repetition penalty via cosine of adjacent logit vectors
        features = torch.zeros(B, T, self._n_features, device=logits.device,
                               dtype=logits.dtype)
        features[..., 0] = entropy
        if self._n_soft > 0 and T > 1:
            # Cosine similarity with previous position (vectorized)
            cos = F.cosine_similarity(logits[:, 1:], logits[:, :-1], dim=-1)
            # BUGFIX: soft-constraint features start AFTER the entropy slot
            # (index 0) and the hard-constraint slots. The original wrote to
            # column `self._n_hard`, which clobbered the entropy feature
            # whenever no hard constraints were configured (_n_hard == 0).
            features[:, 1:, 1 + self._n_hard] = cos.clamp(min=0)

        penalty = self.constraint_proj(features)  # [B, T, 1]
        return logits + penalty.expand_as(logits)
179
+
180
+
181
+ # ─────────────────────────────────────────────────
182
+ # Entropy Valve β€” Fast entropy routing
183
+ # ─────────────────────────────────────────────────
184
class EntropyValve(nn.Module):
    """Entropy-based compute allocation valve.

    Routes decoding into 'low'/'medium'/'high' effort levels by comparing the
    mean predictive entropy (in bits) against `threshold_bits`:
    below half the threshold -> low, below the threshold -> medium,
    otherwise -> high. Disabled valves always report 'medium' / 2 loops.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = config.get('enabled', True)
        self.threshold_bits = config.get('threshold_bits', 2.0)
        default_levels = {
            'low': {'loops': 1, 'min_span': 8, 'audit': 0.125},
            'medium': {'loops': 2, 'min_span': 4, 'audit': 0.5},
            'high': {'loops': 4, 'min_span': 1, 'audit': 1.0},
        }
        self.levels = config.get('levels', default_levels)
        self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(), nn.Linear(32, 3))
        self._log2 = math.log(2.0)

    def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Shannon entropy in bits via log_softmax (numerically stable)."""
        log_probs = F.log_softmax(logits, dim=-1)
        # H = -sum(p * log p), converted from nats to bits.
        return -(log_probs.exp() * log_probs).sum(dim=-1) / self._log2

    def get_level(self, entropy: torch.Tensor) -> str:
        """Map mean entropy to an effort level name."""
        if not self.enabled:
            return 'medium'
        mean_bits = entropy.mean().item()
        if mean_bits >= self.threshold_bits:
            return 'high'
        if mean_bits >= self.threshold_bits * 0.5:
            return 'medium'
        return 'low'

    def get_loop_count(self, logits: torch.Tensor) -> int:
        """Loop budget for these logits (2 when the valve is disabled)."""
        if not self.enabled:
            return 2
        level = self.get_level(self.compute_entropy(logits))
        return self.levels.get(level, self.levels['medium'])['loops']

    def forward(self, logits: torch.Tensor):
        """Return (level_name, level_params) for the given logits."""
        level = self.get_level(self.compute_entropy(logits))
        return level, self.levels.get(level, self.levels['medium'])
231
+
232
+
233
+ # ─────────────────────────────────────────────────
234
+ # Debt Ledger
235
+ # ─────────────────────────────────────────────────
236
class DebtLedger(nn.Module):
    """Tracks outstanding discourse obligations and nudges logits under pressure.

    Pressure grows linearly with the number of active debts (capped at
    `max_outstanding`); the forward pass adds a tiny learned bias to the
    logits proportional to that pressure.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.enabled = config.get('enabled', True)
        self.obligations = config.get('obligations', [])
        self.max_outstanding = config.get('max_outstanding', 64)
        self.pressure_weight = config.get('pressure_weight', 0.3)
        self.active_debts = []
        self.debt_bias_scale = nn.Parameter(torch.tensor(0.5))
        self.debt_proj = nn.Linear(1, 1, bias=True)
        # Identity init: the untrained projection passes pressure through unchanged.
        nn.init.ones_(self.debt_proj.weight)
        nn.init.zeros_(self.debt_proj.bias)

    def add_debt(self, debt_type: str):
        """Record one obligation unless the ledger is already full."""
        if len(self.active_debts) >= self.max_outstanding:
            return
        self.active_debts.append(debt_type)

    def resolve_debt(self, debt_type: str):
        """Discharge the first matching obligation, if any (EAFP)."""
        try:
            self.active_debts.remove(debt_type)
        except ValueError:
            pass

    def get_pressure(self) -> float:
        """Current pressure in [0, pressure_weight] given outstanding debts."""
        outstanding = len(self.active_debts)
        return self.pressure_weight * outstanding / max(self.max_outstanding, 1)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return logits
        pressure = self.get_pressure()
        if pressure <= 0:
            return logits
        scaled = self.debt_bias_scale * pressure
        # Linear(1,1) on a [1,1] tensor; 0.01 keeps the nudge small.
        bias = self.debt_proj(scaled.unsqueeze(0).unsqueeze(0))
        return logits + bias * 0.01
269
+
270
+
271
+ # ─────────────────────────────────────────────────
272
+ # Braid State (runtime state container, not an nn.Module)
273
+ # ─────────────────────────────────────────────────
274
class BraidState:
    """Per-sequence runtime state container (plain object, not an nn.Module).

    Holds six pre-allocated tensors sized from `config`; `reset()` restores
    all of them to the all-zero post-__init__ condition.
    """

    __slots__ = ['continuous', 'fast', 'semantic_sketch', 'entity_slots',
                 'grammar_stack', 'debt_ledger_slots']

    def __init__(self, config: dict, device: str = 'cpu'):
        # 'continuous_hidden' config is [dim, dtype-name]; only the dim is used here.
        D = config.get('continuous_hidden', [2560, 'float32'])[0]
        self.continuous = torch.zeros(1, D, dtype=torch.float32, device=device)
        self.fast = torch.zeros(1, D, dtype=torch.int8, device=device)
        # 'semantic_sketch' config is [bit-count, layout-name]; stored as packed bytes.
        bits = config.get('semantic_sketch', [8192, 'uint64_x128'])[0]
        self.semantic_sketch = torch.zeros(1, bits // 8, dtype=torch.uint8, device=device)
        et = config.get('entity_table', {})
        self.entity_slots = torch.zeros(
            et.get('slots', 256), et.get('slot_bits', 512) // 8,
            dtype=torch.uint8, device=device)
        gs = config.get('grammar_stack', {})
        self.grammar_stack = torch.zeros(
            gs.get('slots', 64), gs.get('width_bits', 128) // 8,
            dtype=torch.uint8, device=device)
        self.debt_ledger_slots = torch.zeros(
            config.get('debt_ledger_slots', 64), dtype=torch.int32, device=device)

    def reset(self):
        """Zero ALL state tensors in place.

        BUGFIX: the original reset only `continuous`, `fast` and
        `semantic_sketch`, leaving `entity_slots`, `grammar_stack` and
        `debt_ledger_slots` carrying stale state across resets.
        """
        self.continuous.zero_()
        self.fast.zero_()
        self.semantic_sketch.zero_()
        self.entity_slots.zero_()
        self.grammar_stack.zero_()
        self.debt_ledger_slots.zero_()
+ self.semantic_sketch.zero_()