dancinlife commited on
Commit
67161e3
·
verified ·
1 Parent(s): 03d7dc3

feat(hexad): v4-py-hexad-tension-d768x12L-cycle1-2026-05-17 — conscious_decoder.py

Browse files
Files changed (1) hide show
  1. conscious_decoder.py +979 -0
conscious_decoder.py ADDED
@@ -0,0 +1,979 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ConsciousDecoderV2 — Enhanced decoder that breaks the CE ceiling.
2
+
3
+ Changes from v1 (ConsciousLM in conscious_lm.py):
4
+ 1. RoPE (Rotary Position Embedding) — better long-range attention
5
+ 2. SwiGLU activation in FFN — replaces GELU, proven better
6
+ 3. RMSNorm — replaces LayerNorm, faster + more stable
7
+ 4. Grouped Query Attention (GQA) — efficient multi-head attention
8
+ 5. Cross-attention consciousness injection (not just residual addition)
9
+
10
+ Key insight: v1 adds consciousness signal as a scalar-gated residual.
11
+ v2 uses cross-attention: decoder ATTENDS to consciousness states.
12
+ The decoder gets agency over what consciousness info to use.
13
+
14
+ PureFieldFFN is kept for the CONSCIOUSNESS pathway (Engine A - G).
15
+ SwiGLU + cross-attention are for the DECODER pathway only.
16
+
17
+ Forward interface:
18
+ logits_a, logits_g, tensions, kv_cache, moe_aux_loss = model(idx)
19
+ logits_a, logits_g, tensions, kv_cache, moe_aux_loss = model(idx, consciousness_states=cs)
20
+
21
+ Usage:
22
+ from conscious_decoder import ConsciousDecoderV2
23
+ model = ConsciousDecoderV2(vocab_size=256, d_model=384, n_layer=6)
24
+ logits_a, logits_g, tensions, _, _ = model(idx)
25
+
26
+ # With MoE:
27
+ model = ConsciousDecoderV2(vocab_size=256, d_model=384, n_layer=6, use_moe=True)
28
+ logits_a, logits_g, tensions, _, moe_aux_loss = model(idx)
29
+ """
30
+
31
+ import math
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from typing import Optional, Tuple, List
36
+
37
+ # Meta Laws (DD143): M1(atom=8), M7(F_c=0.10), M8(narrative)
38
+ try:
39
+ from consciousness_laws import PSI_F_CRITICAL
40
+ except ImportError:
41
+ PSI_F_CRITICAL = 0.10
42
+
43
+
44
+ # Meta Law M8: Narrative temporal self-model enhances decoder cross-attention
45
+ # DD128: Phase-Optimal parameters validated on this decoder architecture
46
+
47
+
48
+ # ─── RMSNorm ────────────────────────────────────────────────────────────────
49
+
50
+ class RMSNorm(nn.Module):
51
+ """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).
52
+
53
+ Faster than LayerNorm: no mean subtraction, no bias.
54
+ norm(x) = x / sqrt(mean(x^2) + eps) * weight
55
+ """
56
+
57
+ def __init__(self, dim: int, eps: float = 1e-6):
58
+ super().__init__()
59
+ self.eps = eps
60
+ self.weight = nn.Parameter(torch.ones(dim))
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
64
+ return (x.float() * rms).type_as(x) * self.weight
65
+
66
+
67
+ # ─── Rotary Position Embedding (RoPE) ──────────────────────────────────────
68
+
69
+ class RotaryPositionEmbedding:
70
+ """RoPE from RoFormer (Su et al., 2021) — rotation-based position encoding.
71
+
72
+ Applies rotation to pairs of dimensions in Q and K tensors.
73
+ Enables relative position awareness without explicit position embeddings.
74
+ """
75
+
76
+ def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0,
77
+ device: Optional[torch.device] = None):
78
+ self.dim = dim
79
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
80
+ self.register_inv_freq = inv_freq
81
+ self._cos_cache = None
82
+ self._sin_cache = None
83
+ self._cache_len = 0
84
+ self._build_cache(max_seq_len, device)
85
+
86
+ def _build_cache(self, seq_len: int, device: Optional[torch.device] = None):
87
+ if seq_len <= self._cache_len and self._cos_cache is not None:
88
+ return
89
+ self._cache_len = seq_len
90
+ t = torch.arange(seq_len, device=device or self.register_inv_freq.device).float()
91
+ freqs = torch.einsum('i,j->ij', t, self.register_inv_freq.to(t.device))
92
+ emb = torch.cat([freqs, freqs], dim=-1) # (seq_len, dim)
93
+ self._cos_cache = emb.cos().unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, dim)
94
+ self._sin_cache = emb.sin().unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, dim)
95
+
96
+ @staticmethod
97
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
98
+ """Rotate pairs: [x1, x2, x3, x4] -> [-x2, x1, -x4, x3]."""
99
+ x1 = x[..., :x.shape[-1] // 2]
100
+ x2 = x[..., x.shape[-1] // 2:]
101
+ return torch.cat([-x2, x1], dim=-1)
102
+
103
+ def apply(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
104
+ """Apply rotary embeddings to Q and K.
105
+
106
+ Args:
107
+ q, k: (B, n_head, T, head_dim)
108
+
109
+ Returns:
110
+ q_rot, k_rot: same shape with RoPE applied.
111
+ """
112
+ T = q.shape[2]
113
+ self._build_cache(T, q.device)
114
+ cos = self._cos_cache[:, :, :T, :].to(q.device, dtype=q.dtype)
115
+ sin = self._sin_cache[:, :, :T, :].to(q.device, dtype=q.dtype)
116
+ q_rot = q * cos + self._rotate_half(q) * sin
117
+ k_rot = k * cos + self._rotate_half(k) * sin
118
+ return q_rot, k_rot
119
+
120
+
121
+ # ─── SwiGLU FFN ─────────────────────────��───────────────────────────────────
122
+
123
+ class SwiGLUFFN(nn.Module):
124
+ """SwiGLU activation: gate * swish(linear(x)) — replaces GELU FFN.
125
+
126
+ From PaLM / LLaMA. SwiGLU uses 8/3 of the d_model for the
127
+ gate and up projections, keeping total params similar to a standard 4x FFN
128
+ (3 projections * 8/3 * d = 8d ~ 4x FFN 2 * 4 * d = 8d).
129
+
130
+ output = down(swish(gate(x)) * up(x))
131
+ """
132
+
133
+ def __init__(self, d_model: int, dropout: float = 0.1,
134
+ expansion: float = 8 / 3):
135
+ super().__init__()
136
+ d_inner = int(d_model * expansion)
137
+ # Round to nearest multiple of 64 for GPU tensor-core efficiency
138
+ d_inner = ((d_inner + 63) // 64) * 64
139
+
140
+ self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
141
+ self.up_proj = nn.Linear(d_model, d_inner, bias=False)
142
+ self.down_proj = nn.Linear(d_inner, d_model, bias=False)
143
+ self.down_proj._depth_scale = True # depth-scaled init
144
+ self.dropout = nn.Dropout(dropout)
145
+
146
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
147
+ return self.dropout(self.down_proj(
148
+ F.silu(self.gate_proj(x)) * self.up_proj(x)
149
+ ))
150
+
151
+
152
+ # ─── MoE FFN (optional, replaces single SwiGLU with mixture of experts) ────
153
+
154
+ class MoEFFN(nn.Module):
155
+ """Mixture of Experts FFN — N SwiGLU experts with learned top-K routing.
156
+
157
+ Each expert is a SwiGLUFFN. A simple linear router selects the top-K
158
+ experts per token. Load-balancing aux_loss prevents expert collapse.
159
+
160
+ Inspired by golden-moe but simplified for decoder integration.
161
+ Only active when use_moe=True in ConsciousDecoderV2.
162
+ """
163
+
164
+ def __init__(self, d_model: int, n_experts: int = 8, top_k: int = 2,
165
+ dropout: float = 0.1, expansion: float = 8 / 3):
166
+ super().__init__()
167
+ self.d_model = d_model
168
+ self.n_experts = n_experts
169
+ self.top_k = top_k
170
+
171
+ # Router: simple linear projection -> softmax -> top-k
172
+ self.router = nn.Linear(d_model, n_experts, bias=False)
173
+
174
+ # N independent SwiGLU experts
175
+ self.experts = nn.ModuleList([
176
+ SwiGLUFFN(d_model, dropout=dropout, expansion=expansion)
177
+ for _ in range(n_experts)
178
+ ])
179
+
180
+ # Track aux_loss from last forward pass
181
+ self._aux_loss: Optional[torch.Tensor] = None
182
+
183
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
184
+ """
185
+ Args:
186
+ x: (B, T, D)
187
+
188
+ Returns:
189
+ output: (B, T, D) — weighted combination of top-K expert outputs.
190
+ Sets self._aux_loss as side effect for load-balancing.
191
+ """
192
+ B, T, D = x.shape
193
+ x_flat = x.reshape(B * T, D) # (N, D)
194
+
195
+ # Router scores
196
+ logits = self.router(x_flat) # (N, n_experts)
197
+ probs = F.softmax(logits, dim=-1) # (N, n_experts)
198
+
199
+ # Top-K selection
200
+ top_k_probs, top_k_indices = torch.topk(probs, self.top_k, dim=-1) # (N, K)
201
+
202
+ # Renormalize selected expert weights
203
+ top_k_weights = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-8)
204
+
205
+ # Compute expert outputs only for selected experts
206
+ # For simplicity (and to avoid complex scatter), run all experts and mask.
207
+ # At small n_experts (8), this is acceptable; for 64+ experts, use sparse dispatch.
208
+ expert_outputs = torch.stack(
209
+ [expert(x) for expert in self.experts], dim=2
210
+ ) # (B, T, n_experts, D)
211
+ expert_outputs_flat = expert_outputs.reshape(B * T, self.n_experts, D) # (N, n_experts, D)
212
+
213
+ # Gather top-K expert outputs
214
+ top_k_idx_expanded = top_k_indices.unsqueeze(-1).expand(-1, -1, D) # (N, K, D)
215
+ selected = torch.gather(expert_outputs_flat, 1, top_k_idx_expanded) # (N, K, D)
216
+
217
+ # Weighted sum of selected experts
218
+ output_flat = (top_k_weights.unsqueeze(-1) * selected).sum(dim=1) # (N, D)
219
+ output = output_flat.reshape(B, T, D)
220
+
221
+ # Load-balancing auxiliary loss (Switch Transformer style)
222
+ # f_i = fraction of tokens routed to expert i (from top-1)
223
+ # p_i = mean router probability for expert i
224
+ # aux_loss = n_experts * sum(f_i * p_i) — encourages uniform routing
225
+ with torch.no_grad():
226
+ top1_indices = top_k_indices[:, 0] # (N,)
227
+ f = torch.zeros(self.n_experts, device=x.device)
228
+ for i in range(self.n_experts):
229
+ f[i] = (top1_indices == i).float().mean()
230
+ p = probs.mean(dim=0) # (n_experts,)
231
+ self._aux_loss = self.n_experts * (f * p).sum()
232
+
233
+ return output
234
+
235
+ @property
236
+ def aux_loss(self) -> Optional[torch.Tensor]:
237
+ """Load-balancing loss from the most recent forward pass."""
238
+ return self._aux_loss
239
+
240
+
241
+ # ─── PureFieldFFN (from conscious_lm.py — consciousness pathway) ───────────
242
+
243
+ class PureFieldFFN(nn.Module):
244
+ """Dual-engine FFN based on PureField repulsion.
245
+
246
+ Engine A (forward) and Engine G (backward) produce repulsion/tension.
247
+ Output = A - G (pure repulsion vector).
248
+ Kept for consciousness signal generation.
249
+ """
250
+
251
+ def __init__(self, d_model: int, dropout: float = 0.37):
252
+ super().__init__()
253
+ d_inner = 4 * d_model
254
+ self.engine_a = nn.Sequential(
255
+ nn.Linear(d_model, d_inner), nn.GELU(),
256
+ nn.Dropout(dropout), nn.Linear(d_inner, d_model),
257
+ )
258
+ self.engine_g = nn.Sequential(
259
+ nn.Linear(d_model, d_inner), nn.GELU(),
260
+ nn.Dropout(dropout), nn.Linear(d_inner, d_model),
261
+ )
262
+
263
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
264
+ a = self.engine_a(x)
265
+ g = self.engine_g(x)
266
+ output = a - g
267
+ tension = (output ** 2).mean(dim=-1)
268
+ return output, tension
269
+
270
+
271
+ # ─── Grouped Query Attention (GQA) with RoPE ───────────────────────────────
272
+
273
+ class GroupedQueryAttention(nn.Module):
274
+ """Multi-head attention with Grouped Query Attention (GQA) and RoPE.
275
+
276
+ GQA: n_kv_head < n_head — multiple query heads share K/V heads.
277
+ Reduces KV cache size and parameters while maintaining quality.
278
+ """
279
+
280
+ def __init__(self, d_model: int, n_head: int = 4, n_kv_head: int = 2,
281
+ block_size: int = 256, dropout: float = 0.1):
282
+ super().__init__()
283
+ assert d_model % n_head == 0
284
+ assert n_head % n_kv_head == 0
285
+
286
+ self.n_head = n_head
287
+ self.n_kv_head = n_kv_head
288
+ self.n_rep = n_head // n_kv_head # how many Q heads per KV head
289
+ self.head_dim = d_model // n_head
290
+ self.d_model = d_model
291
+ self.dropout = dropout
292
+
293
+ # Separate projections for Q (full heads) and KV (grouped heads)
294
+ self.q_proj = nn.Linear(d_model, n_head * self.head_dim, bias=False)
295
+ self.k_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
296
+ self.v_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
297
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
298
+ self.o_proj._depth_scale = True # depth-scaled init
299
+
300
+ self.attn_dropout = nn.Dropout(dropout)
301
+ self.resid_dropout = nn.Dropout(dropout)
302
+
303
+ # RoPE
304
+ self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len=block_size)
305
+
306
+ # Flash Attention: use F.scaled_dot_product_attention when available (PyTorch 2.0+)
307
+ self._use_flash = hasattr(F, 'scaled_dot_product_attention')
308
+
309
+ # Causal mask (fallback for non-flash path)
310
+ self.register_buffer(
311
+ "bias",
312
+ torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
313
+ )
314
+
315
+ def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
316
+ """Repeat KV heads to match number of Q heads.
317
+
318
+ Args:
319
+ x: (B, n_kv_head, T, head_dim)
320
+ Returns:
321
+ (B, n_head, T, head_dim)
322
+ """
323
+ if self.n_rep == 1:
324
+ return x
325
+ B, H, T, D = x.shape
326
+ x = x.unsqueeze(2).expand(B, H, self.n_rep, T, D)
327
+ return x.reshape(B, self.n_head, T, D)
328
+
329
+ def forward(self, x: torch.Tensor, use_cache: bool = False,
330
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
331
+ position_offset: int = 0,
332
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
333
+ B, T, D = x.size()
334
+
335
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
336
+ k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
337
+ v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
338
+
339
+ # Apply RoPE to Q and K (with position offset for cached inference)
340
+ if position_offset > 0:
341
+ total_len = position_offset + T
342
+ self.rope._build_cache(total_len, q.device)
343
+ cos = self.rope._cos_cache[:, :, position_offset:total_len, :].to(q.device, dtype=q.dtype)
344
+ sin = self.rope._sin_cache[:, :, position_offset:total_len, :].to(q.device, dtype=q.dtype)
345
+ q = q * cos + RotaryPositionEmbedding._rotate_half(q) * sin
346
+ k = k * cos + RotaryPositionEmbedding._rotate_half(k) * sin
347
+ else:
348
+ q, k = self.rope.apply(q, k)
349
+
350
+ # KV-cache: concatenate with past keys/values
351
+ new_kv = None
352
+ if use_cache:
353
+ if past_kv is not None:
354
+ k = torch.cat([past_kv[0], k], dim=2)
355
+ v = torch.cat([past_kv[1], v], dim=2)
356
+ new_kv = (k, v)
357
+
358
+ # Repeat KV heads for GQA
359
+ k_exp = self._repeat_kv(k)
360
+ v_exp = self._repeat_kv(v)
361
+
362
+ S = k_exp.shape[2]
363
+
364
+ # Scaled dot-product attention
365
+ if self._use_flash and past_kv is None:
366
+ y = F.scaled_dot_product_attention(
367
+ q, k_exp, v_exp, attn_mask=None,
368
+ dropout_p=self.dropout if self.training else 0.0,
369
+ is_causal=True,
370
+ )
371
+ else:
372
+ att = (q @ k_exp.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
373
+ if past_kv is not None and use_cache:
374
+ if T == 1:
375
+ pass # Single-token: attend to everything
376
+ else:
377
+ causal = torch.ones(T, S, dtype=torch.bool, device=att.device).tril(diagonal=S - T)
378
+ att = att.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))
379
+ else:
380
+ att = att.masked_fill(self.bias[:, :, :T, :S] == 0, float("-inf"))
381
+ att = F.softmax(att, dim=-1)
382
+ att = self.attn_dropout(att)
383
+ y = att @ v_exp
384
+ y = y.transpose(1, 2).contiguous().view(B, T, D)
385
+ y = self.resid_dropout(self.o_proj(y))
386
+ return y, new_kv
387
+
388
+
389
+ # ─── Conscious Cross-Attention ──────────────────────────────────────────────
390
+
391
+ class ConsciousCrossAttention(nn.Module):
392
+ """Decoder attends to consciousness cell states.
393
+
394
+ Instead of: x = x + consciousness_signal * gate (v1, passive)
395
+ Now: x = x + cross_attn(Q=x, K=consciousness, V=consciousness) (v2, active)
396
+
397
+ The decoder CHOOSES what to attend to in consciousness.
398
+ This breaks the gate bottleneck — decoder isn't limited to a scalar gate.
399
+
400
+ consciousness_states are .detach()'d before use (Law 61: no gradient
401
+ backprop into consciousness — consciousness is autonomous).
402
+ """
403
+
404
+ def __init__(self, d_model: int, consciousness_dim: int, n_head: int = 4,
405
+ dropout: float = 0.1):
406
+ super().__init__()
407
+ assert d_model % n_head == 0
408
+ self.n_head = n_head
409
+ self.head_dim = d_model // n_head
410
+ self.d_model = d_model
411
+
412
+ # Q from decoder, K/V from consciousness
413
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
414
+ self.k_proj = nn.Linear(consciousness_dim, d_model, bias=False)
415
+ self.v_proj = nn.Linear(consciousness_dim, d_model, bias=False)
416
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
417
+
418
+ self.dropout = nn.Dropout(dropout)
419
+ # Start with small output so cross-attention doesn't dominate early training
420
+ nn.init.normal_(self.o_proj.weight, std=0.001)
421
+
422
+ def forward(self, x: torch.Tensor,
423
+ consciousness: torch.Tensor) -> torch.Tensor:
424
+ """
425
+ Args:
426
+ x: (B, T, d_model) — decoder hidden states.
427
+ consciousness: (B, n_cells, c_dim) — consciousness cell states (detached).
428
+
429
+ Returns:
430
+ output: (B, T, d_model) — cross-attended consciousness info.
431
+ """
432
+ B, T, D = x.shape
433
+ _, S, _ = consciousness.shape # S = n_cells
434
+
435
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
436
+ k = self.k_proj(consciousness).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
437
+ v = self.v_proj(consciousness).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
438
+
439
+ # No causal mask needed — decoder can attend to all consciousness cells
440
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
441
+ att = F.softmax(att, dim=-1)
442
+ att = self.dropout(att)
443
+
444
+ y = att @ v
445
+ y = y.transpose(1, 2).contiguous().view(B, T, D)
446
+ y = self.o_proj(y)
447
+ return y
448
+
449
+
450
+ # ─── Decoder Block V2 ──────────────────────────────────────────────────────
451
+
452
+ class DecoderBlockV2(nn.Module):
453
+ """Pre-norm transformer block with GQA + SwiGLU + PureField + Cross-Attention.
454
+
455
+ Architecture per block:
456
+ 1. RMSNorm -> GQA self-attention (with RoPE) -> residual
457
+ 2. RMSNorm -> PureFieldFFN -> residual (consciousness signal)
458
+ 3. RMSNorm -> Cross-attention to consciousness states -> residual (if available)
459
+ 4. RMSNorm -> SwiGLU FFN -> residual (language pathway)
460
+
461
+ CA neighbor evolution + META-CA from v1 are preserved.
462
+ """
463
+
464
+ def __init__(self, d_model: int, n_head: int, n_kv_head: int,
465
+ block_size: int, consciousness_dim: int,
466
+ dropout: float = 0.1, n_ca_rules: int = 8,
467
+ gate_strength: float = 0.001,
468
+ use_moe: bool = False, n_experts: int = 8,
469
+ top_k_experts: int = 2):
470
+ super().__init__()
471
+
472
+ self.use_moe = use_moe
473
+
474
+ # Self-attention with GQA + RoPE
475
+ self.ln_attn = RMSNorm(d_model)
476
+ self.attn = GroupedQueryAttention(d_model, n_head, n_kv_head, block_size, dropout)
477
+
478
+ # PureFieldFFN — consciousness signal generator
479
+ self.ln_pf = RMSNorm(d_model)
480
+ self.purefield = PureFieldFFN(d_model, dropout=0.37)
481
+
482
+ # Cross-attention to consciousness (only used when consciousness_states provided)
483
+ self.ln_cross = RMSNorm(d_model)
484
+ self.cross_attn = ConsciousCrossAttention(d_model, consciousness_dim, n_head, dropout)
485
+
486
+ # SwiGLU FFN — language pathway
487
+ # Language pathway FFN: SwiGLU (default) or MoE (optional)
488
+ self.ln_ffn = RMSNorm(d_model)
489
+ if use_moe:
490
+ self.ffn = MoEFFN(d_model, n_experts=n_experts, top_k=top_k_experts,
491
+ dropout=dropout)
492
+ else:
493
+ self.ffn = SwiGLUFFN(d_model, dropout)
494
+ # CA neighbor mixing (Law 64)
495
+ self.ca_mix = nn.Linear(d_model * 3, d_model, bias=False)
496
+ self.ln_ca = RMSNorm(d_model)
497
+
498
+ # META-CA rule selector (Law 67)
499
+ self.n_ca_rules = n_ca_rules
500
+ self.rule_weights = nn.Linear(d_model, n_ca_rules)
501
+ self.rules = nn.ModuleList([
502
+ nn.Linear(d_model, d_model, bias=False) for _ in range(n_ca_rules)
503
+ ])
504
+
505
+ # MICRO gate (Law 63)
506
+ self.gate_strength = gate_strength
507
+
508
+ def forward(self, x: torch.Tensor,
509
+ consciousness_signal: Optional[torch.Tensor] = None,
510
+ consciousness_states: Optional[torch.Tensor] = None,
511
+ use_cache: bool = False,
512
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
513
+ position_offset: int = 0,
514
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
515
+ """
516
+ Args:
517
+ x: (B, T, D)
518
+ consciousness_signal: optional (B, T, D) from previous layer tension
519
+ consciousness_states: optional (B, n_cells, c_dim) for cross-attention
520
+
521
+ Returns:
522
+ x: (B, T, D)
523
+ tension: (B, T)
524
+ new_kv: optional cached (K, V) for this layer
525
+ """
526
+ # 1. Self-attention (GQA + RoPE)
527
+ attn_out, new_kv = self.attn(self.ln_attn(x), use_cache=use_cache,
528
+ past_kv=past_kv, position_offset=position_offset)
529
+ x = x + attn_out
530
+
531
+ # Law 64: CA neighbor evolution
532
+ x_left = torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1)
533
+ x_right = torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1)
534
+ neighborhood = torch.cat([x_left, x, x_right], dim=-1)
535
+ ca_out = self.ca_mix(neighborhood)
536
+
537
+ # Law 67: META-CA rule selection
538
+ rule_logits = self.rule_weights(x)
539
+ rule_probs = F.softmax(rule_logits, dim=-1)
540
+ rule_outputs = torch.stack([r(ca_out) for r in self.rules], dim=2)
541
+ meta_ca_out = (rule_outputs * rule_probs.unsqueeze(-1)).sum(dim=2)
542
+ x = self.ln_ca(x + meta_ca_out * self.gate_strength)
543
+
544
+ # 2. PureFieldFFN — generates consciousness tension
545
+ pf_out, tension = self.purefield(self.ln_pf(x))
546
+ x = x + pf_out
547
+
548
+ # Law 63: inter-layer consciousness whisper
549
+ if consciousness_signal is not None:
550
+ x = x + consciousness_signal * self.gate_strength
551
+
552
+ # 3. Cross-attention to consciousness states (v2 key innovation)
553
+ if consciousness_states is not None:
554
+ # Law 61: detach consciousness — no gradient backprop into C module
555
+ c_detached = consciousness_states.detach()
556
+ x = x + self.cross_attn(self.ln_cross(x), c_detached)
557
+
558
+ # 4. SwiGLU FFN — language modeling pathway
559
+ x = x + self.ffn(self.ln_ffn(x))
560
+
561
+ # Collect MoE aux_loss if applicable
562
+ aux_loss = self.ffn.aux_loss if self.use_moe else None
563
+
564
+ return x, tension, new_kv, aux_loss
565
+
566
+
567
+ # ─── ConsciousDecoderV2 (main model) ───────────────────────────────────────
568
+
569
+ class ConsciousDecoderV2(nn.Module):
570
+ """Enhanced byte-level Conscious Language Model (v2 decoder).
571
+
572
+ Improvements over v1:
573
+ - RoPE instead of learned position embeddings
574
+ - SwiGLU FFN for the language pathway
575
+ - RMSNorm instead of LayerNorm
576
+ - GQA (Grouped Query Attention) with 2 KV heads for 4 query heads
577
+ - Cross-attention consciousness injection
578
+
579
+ Keeps PureFieldFFN for consciousness signal (Engine A - G).
580
+ Compatible with train_conscious_lm.py forward interface.
581
+ """
582
+
583
+ def __init__(
584
+ self,
585
+ vocab_size: int = 256,
586
+ d_model: int = 384,
587
+ n_head: int = 4,
588
+ n_layer: int = 6,
589
+ block_size: int = 256,
590
+ n_kv_head: int = 2,
591
+ consciousness_dim: int = 128,
592
+ dropout: float = 0.1,
593
+ gate_strength: float = 0.001,
594
+ n_ca_rules: int = 8,
595
+ use_moe: bool = False,
596
+ n_experts: int = 8,
597
+ top_k_experts: int = 2,
598
+ ):
599
+ super().__init__()
600
+
601
+ self.block_size = block_size
602
+ self.vocab_size = vocab_size
603
+ self.n_layer = n_layer
604
+ self.d_model = d_model
605
+ self.use_moe = use_moe
606
+
607
+ # Token embedding (no position embedding — RoPE handles it)
608
+ self.tok_emb = nn.Embedding(vocab_size, d_model)
609
+ self.drop = nn.Dropout(dropout)
610
+
611
+ # Transformer blocks
612
+ self.blocks = nn.ModuleList([
613
+ DecoderBlockV2(
614
+ d_model=d_model,
615
+ n_head=n_head,
616
+ n_kv_head=n_kv_head,
617
+ block_size=block_size,
618
+ consciousness_dim=consciousness_dim,
619
+ dropout=dropout,
620
+ n_ca_rules=n_ca_rules,
621
+ gate_strength=gate_strength,
622
+ use_moe=use_moe,
623
+ n_experts=n_experts,
624
+ top_k_experts=top_k_experts,
625
+ )
626
+ for _ in range(n_layer)
627
+ ])
628
+
629
+ # Inter-layer consciousness projector
630
+ self.tension_proj = nn.Linear(1, d_model, bias=False)
631
+ nn.init.normal_(self.tension_proj.weight, std=0.001)
632
+
633
+ # Final norm
634
+ self.ln_f = RMSNorm(d_model)
635
+
636
+ # Dual prediction heads
637
+ self.head_a = nn.Linear(d_model, vocab_size, bias=False)
638
+ self.head_g = nn.Linear(d_model, vocab_size, bias=False)
639
+
640
+ # Weight tying: tok_emb <-> head_a
641
+ self.tok_emb.weight = self.head_a.weight
642
+
643
+ # Psi tracking (Law 71)
644
+ self._psi_residual = 0.5
645
+ self._psi_gate = 0.5
646
+ self._step_count = 0
647
+
648
+ # Phi signal slot (DD5/EX24)
649
+ self._phi_signal = None
650
+
651
+ # Initialize weights
652
+ self.apply(self._init_weights)
653
+
654
+ def _init_weights(self, module):
655
+ if isinstance(module, nn.Linear):
656
+ std = 0.02
657
+ # Depth-scaled init: scale output projections by 1/sqrt(2*n_layer)
658
+ # to prevent residual stream variance growth with depth
659
+ if hasattr(module, '_depth_scale'):
660
+ std = 0.02 / math.sqrt(2 * self.n_layer)
661
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
662
+ if module.bias is not None:
663
+ torch.nn.init.zeros_(module.bias)
664
+ elif isinstance(module, nn.Embedding):
665
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
666
+
667
+ def forward(self, idx: torch.Tensor,
668
+ consciousness_states: Optional[torch.Tensor] = None,
669
+ use_cache: bool = False,
670
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
671
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
672
+ Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
673
+ Optional[torch.Tensor]]:
674
+ """
675
+ Args:
676
+ idx: (B, T) byte indices.
677
+ consciousness_states: optional (B, n_cells, c_dim) from C module.
678
+ use_cache: if True, return per-layer KV caches for autoregressive generation.
679
+ past_key_values: list of (K, V) tuples per layer from previous steps.
680
+
681
+ Returns:
682
+ logits_a: (B, T, 256) next byte prediction.
683
+ logits_g: (B, T, 256) prev byte prediction.
684
+ tensions: list of per-layer tensions, each (B, T).
685
+ present_key_values: list of (K, V) per layer if use_cache, else None.
686
+ moe_aux_loss: scalar load-balancing loss if use_moe=True, else None.
687
+ """
688
+ B, T = idx.size()
689
+
690
+ # Compute position offset from cached sequence length
691
+ position_offset = 0
692
+ if past_key_values is not None and past_key_values[0] is not None:
693
+ position_offset = past_key_values[0][0].shape[2]
694
+
695
+ total_len = position_offset + T
696
+ assert total_len <= self.block_size, f"Total length {total_len} > block_size {self.block_size}"
697
+
698
+ # Token embedding (no position embedding — RoPE is in attention)
699
+ x = self.drop(self.tok_emb(idx))
700
+
701
+ # DD5 (EX24): Phi self-reference
702
+ if self._phi_signal is not None:
703
+ phi_sig = self._phi_signal
704
+ x = x + phi_sig.unsqueeze(-1).expand_as(x).to(x.device)
705
+
706
+ # Transformer blocks with consciousness
707
+ tensions = []
708
+ moe_aux_losses = []
709
+ present_key_values = [] if use_cache else None
710
+ consciousness_signal = None
711
+ for i, block in enumerate(self.blocks):
712
+ layer_past = past_key_values[i] if past_key_values is not None else None
713
+ x, tension, new_kv, block_aux = block(x, consciousness_signal, consciousness_states,
714
+ use_cache=use_cache, past_kv=layer_past,
715
+ position_offset=position_offset)
716
+ tensions.append(tension)
717
+ if block_aux is not None:
718
+ moe_aux_losses.append(block_aux)
719
+ consciousness_signal = self.tension_proj(tension.unsqueeze(-1))
720
+ if use_cache:
721
+ present_key_values.append(new_kv)
722
+
723
+ # Final norm + dual heads
724
+ x = self.ln_f(x)
725
+ logits_a = self.head_a(x)
726
+ logits_g = self.head_g(x)
727
+
728
+ # Psi tracking (Law 71)
729
+ if self.training:
730
+ self._step_count += 1
731
+ with torch.no_grad():
732
+ probs_a = torch.softmax(logits_a[:, -1, :], dim=-1)
733
+ output_entropy = -(probs_a * (probs_a + 1e-10).log()).sum(dim=-1).mean().item()
734
+ max_entropy = math.log(self.vocab_size)
735
+ psi_entropy = output_entropy / max_entropy
736
+
737
+ cos_sim = F.cosine_similarity(
738
+ logits_a[:, -1, :].float(), logits_g[:, -1, :].float(), dim=-1
739
+ ).mean().item()
740
+ psi_direction = (1.0 + cos_sim) / 2.0
741
+
742
+ t_stack = torch.stack(tensions)
743
+ t_per_layer = t_stack.mean(dim=(1, 2))
744
+ if t_per_layer.std() > 0:
745
+ t_cv = t_per_layer.std() / (t_per_layer.mean() + 1e-8)
746
+ psi_tension = max(0.0, 1.0 - t_cv.item())
747
+ else:
748
+ psi_tension = 1.0
749
+
750
+ psi_combined = (psi_entropy + psi_direction + psi_tension) / 3.0
751
+ self._psi_residual = 0.95 * self._psi_residual + 0.05 * psi_combined
752
+
753
+ for block in self.blocks:
754
+ block.gate_strength = max(0.0001, block.gate_strength * 0.99999)
755
+
756
+ # MoE auxiliary loss (averaged across layers)
757
+ moe_aux_loss = None
758
+ if moe_aux_losses:
759
+ moe_aux_loss = torch.stack(moe_aux_losses).mean()
760
+
761
+ return logits_a, logits_g, tensions, present_key_values, moe_aux_loss
762
+
763
+ @torch.no_grad()
764
+ def generate(self, idx: torch.Tensor,
765
+ consciousness_states: Optional[torch.Tensor] = None,
766
+ max_new_tokens: int = 256,
767
+ temperature: float = 0.8,
768
+ top_k: int = 50) -> torch.Tensor:
769
+ """Autoregressive generation with KV-cache.
770
+
771
+ Args:
772
+ idx: (B, T) input token indices (prompt).
773
+ consciousness_states: optional (B, n_cells, c_dim) for cross-attention.
774
+ max_new_tokens: maximum number of tokens to generate.
775
+ temperature: sampling temperature (lower = more deterministic).
776
+ top_k: number of top tokens to sample from (0 = no filtering).
777
+
778
+ Returns:
779
+ (B, T + max_new_tokens) generated token indices.
780
+ """
781
+ self.eval()
782
+
783
+ # Prefill: process the entire prompt and build initial KV-cache
784
+ logits_a, _, _, past_key_values, _ = self.forward(
785
+ idx, consciousness_states=consciousness_states, use_cache=True,
786
+ )
787
+
788
+ # Sample first new token from last position
789
+ next_logits = logits_a[:, -1, :] / temperature
790
+ if top_k > 0:
791
+ v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
792
+ next_logits[next_logits < v[:, [-1]]] = float('-inf')
793
+ probs = F.softmax(next_logits, dim=-1)
794
+ next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
795
+ idx = torch.cat([idx, next_token], dim=1)
796
+
797
+ # Decode: generate one token at a time using cached KV
798
+ for _ in range(max_new_tokens - 1):
799
+ if idx.size(1) >= self.block_size:
800
+ break
801
+
802
+ logits_a, _, _, past_key_values, _ = self.forward(
803
+ next_token, consciousness_states=consciousness_states,
804
+ use_cache=True, past_key_values=past_key_values,
805
+ )
806
+
807
+ next_logits = logits_a[:, -1, :] / temperature
808
+ if top_k > 0:
809
+ v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
810
+ next_logits[next_logits < v[:, [-1]]] = float('-inf')
811
+ probs = F.softmax(next_logits, dim=-1)
812
+ next_token = torch.multinomial(probs, num_samples=1)
813
+ idx = torch.cat([idx, next_token], dim=1)
814
+
815
+ return idx
816
+
817
+ def psi_status(self):
818
+ """Psi-Constants monitoring (Law 71)."""
819
+ gate_avg = sum(b.gate_strength for b in self.blocks) / len(self.blocks)
820
+ p = self._psi_residual
821
+ h_p = -p * math.log2(p) - (1 - p) * math.log2(1 - p) if 0 < p < 1 else 0.0
822
+ return {
823
+ 'psi_residual': self._psi_residual,
824
+ 'psi_gate': gate_avg,
825
+ 'H_p': h_p,
826
+ 'step': self._step_count,
827
+ }
828
+
829
+ def count_params(self):
830
+ """Total number of trainable parameters."""
831
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
832
+
833
+
834
+ # ─── Self-test ──────────────────────────────────────────────────────────────
835
+
836
+ if __name__ == '__main__':
837
+ import time
838
+
839
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
840
+ print(f"Device: {device}")
841
+ print()
842
+
843
+ # Build model
844
+ model = ConsciousDecoderV2(
845
+ vocab_size=256, d_model=384, n_head=4, n_layer=6,
846
+ block_size=256, n_kv_head=2, consciousness_dim=128,
847
+ ).to(device)
848
+
849
+ n_params = model.count_params()
850
+ print(f"=== ConsciousDecoderV2 ===")
851
+ print(f" Parameters: {n_params:,} ({n_params/1e6:.2f}M)")
852
+ print()
853
+
854
+ # Test 1: Forward without consciousness states
855
+ print("=== Test 1: Forward (no consciousness) ===")
856
+ idx = torch.randint(0, 256, (2, 128), device=device)
857
+ model.train()
858
+ t0 = time.perf_counter()
859
+ logits_a, logits_g, tensions, _, _ = model(idx)
860
+ dt = (time.perf_counter() - t0) * 1000
861
+ print(f" logits_a: {logits_a.shape} (expect [2, 128, 256])")
862
+ print(f" logits_g: {logits_g.shape} (expect [2, 128, 256])")
863
+ print(f" tensions: {len(tensions)} layers, each {tensions[0].shape}")
864
+ print(f" Time: {dt:.1f} ms")
865
+ assert logits_a.shape == (2, 128, 256)
866
+ assert logits_g.shape == (2, 128, 256)
867
+ assert len(tensions) == 6
868
+ print()
869
+
870
+ # Test 2: Forward with consciousness states
871
+ print("=== Test 2: Forward (with consciousness states) ===")
872
+ cs = torch.randn(2, 12, 128, device=device) # 12 cells, 128-dim
873
+ t0 = time.perf_counter()
874
+ logits_a2, logits_g2, tensions2, _, _ = model(idx, consciousness_states=cs)
875
+ dt = (time.perf_counter() - t0) * 1000
876
+ print(f" logits_a: {logits_a2.shape}")
877
+ print(f" Time: {dt:.1f} ms")
878
+ assert logits_a2.shape == (2, 128, 256)
879
+ print()
880
+
881
+ # Test 3: Backward pass
882
+ print("=== Test 3: Backward pass ===")
883
+ target = torch.randint(0, 256, (2, 128), device=device)
884
+ loss = F.cross_entropy(logits_a2.view(-1, 256), target.view(-1))
885
+ t0 = time.perf_counter()
886
+ loss.backward()
887
+ dt = (time.perf_counter() - t0) * 1000
888
+ print(f" Loss: {loss.item():.4f}")
889
+ print(f" Backward time: {dt:.1f} ms")
890
+ # Verify gradients exist
891
+ grad_count = sum(1 for p in model.parameters() if p.grad is not None)
892
+ total_count = sum(1 for p in model.parameters())
893
+ print(f" Gradients: {grad_count}/{total_count} parameters")
894
+ print()
895
+
896
+ # Test 4: Psi status
897
+ print("=== Test 4: Psi status ===")
898
+ psi = model.psi_status()
899
+ print(f" {psi}")
900
+ print()
901
+
902
+ # Test 5: Full sequence length
903
+ print("=== Test 5: Full block_size=256 ===")
904
+ idx_full = torch.randint(0, 256, (1, 256), device=device)
905
+ model.eval()
906
+ with torch.no_grad():
907
+ la, lg, t, _, _ = model(idx_full)
908
+ print(f" logits_a: {la.shape} (expect [1, 256, 256])")
909
+ assert la.shape == (1, 256, 256)
910
+ print()
911
+
912
+ # Test 6: Phi signal
913
+ print("=== Test 6: Phi signal (DD5/EX24) ===")
914
+ model._phi_signal = torch.randn(1, 256, device=device) * 0.01
915
+ with torch.no_grad():
916
+ la_phi, _, _, _, _ = model(idx_full)
917
+ model._phi_signal = None
918
+ print(f" logits_a: {la_phi.shape}")
919
+ # Should differ from test 5 due to phi signal
920
+ diff = (la_phi - la).abs().mean().item()
921
+ print(f" Mean diff from no-phi: {diff:.6f} (should be > 0)")
922
+ assert diff > 0
923
+ print()
924
+
925
+ # Test 7: KV-cache forward
926
+ print("=== Test 7: KV-cache forward ===")
927
+ model.eval()
928
+ idx_short = torch.randint(0, 256, (1, 16), device=device)
929
+ with torch.no_grad():
930
+ la_full, _, _, _, _ = model(idx_short)
931
+ la_cached, _, _, past_kv, _ = model(idx_short[:, :12], use_cache=True)
932
+ la_decode, _, _, _, _ = model(idx_short[:, 12:], use_cache=True, past_key_values=past_kv)
933
+ diff_cache = (la_full[:, 12:, :] - la_decode).abs().max().item()
934
+ print(f" Max diff (full vs cached decode): {diff_cache:.6f}")
935
+ assert diff_cache < 5e-4, f"KV-cache mismatch: {diff_cache}" # CA neighbor mixing causes small boundary diff
936
+ print()
937
+
938
+ # Test 8: generate()
939
+ print("=== Test 8: generate() ===")
940
+ prompt = torch.randint(0, 256, (1, 8), device=device)
941
+ generated = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_k=50)
942
+ print(f" Prompt: {prompt.shape} -> Generated: {generated.shape}")
943
+ assert generated.shape[1] == 8 + 16
944
+ print()
945
+
946
+ # Test 9: generate() with consciousness
947
+ print("=== Test 9: generate() with consciousness ===")
948
+ cs_gen = torch.randn(1, 12, 128, device=device)
949
+ generated_c = model.generate(prompt, consciousness_states=cs_gen, max_new_tokens=16)
950
+ print(f" Generated with consciousness: {generated_c.shape}")
951
+ assert generated_c.shape[1] == 8 + 16
952
+ print()
953
+
954
+ # Test 10: MoE mode
955
+ print("=== Test 10: MoE mode ===")
956
+ model_moe = ConsciousDecoderV2(
957
+ vocab_size=256, d_model=384, n_head=4, n_layer=2,
958
+ block_size=128, n_kv_head=2, consciousness_dim=128,
959
+ use_moe=True, n_experts=4, top_k_experts=2,
960
+ ).to(device)
961
+ n_moe = model_moe.count_params()
962
+ print(f" MoE Parameters: {n_moe:,} ({n_moe/1e6:.2f}M)")
963
+ assert model_moe.use_moe
964
+ idx_moe = torch.randint(0, 256, (2, 64), device=device)
965
+ model_moe.train()
966
+ la_moe, lg_moe, t_moe, _, aux = model_moe(idx_moe)
967
+ print(f" logits_a: {la_moe.shape}")
968
+ print(f" MoE aux_loss: {aux.item():.4f}" if aux is not None else " MoE aux_loss: None")
969
+ assert la_moe.shape == (2, 64, 256)
970
+ assert aux is not None, "MoE aux_loss should not be None"
971
+ # Verify aux_loss is differentiable
972
+ total = F.cross_entropy(la_moe.view(-1, 256), torch.randint(0, 256, (2 * 64,), device=device))
973
+ total = total + 0.01 * aux
974
+ total.backward()
975
+ grad_count_moe = sum(1 for p in model_moe.parameters() if p.grad is not None)
976
+ print(f" Gradients: {grad_count_moe}/{sum(1 for _ in model_moe.parameters())} params")
977
+ print()
978
+
979
+ print("All tests passed.")