| """Components — core neural network modules for the ARB system.""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _COMPONENT_CONTEXT, _HAS_TRITON |
| try: |
| from .kernel.ternary_scale import _TritonTernaryEmbedFn |
| except ImportError: |
| _TritonTernaryEmbedFn = None |
| from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary |
| from dataclasses import dataclass, field, fields |
| from math import ceil as _ceil, log2 as _log2 |
| from transformers import AutoModel, AutoFeatureExtractor |
| from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, SPECIAL_VOCAB, CODEBOOK_DIM, CODEBOOK_SIZE, FFN_HIDDEN, CTX, THRESHOLD, KG_EMA_ALPHA, KG_REQUANT_EVERY, KG_TERNARY_THRESHOLD, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, KGVQ_DECAY, KGVQ_COMMITMENT_WEIGHT, KGVQ_DEAD_CODE_THRESHOLD, K_MAX_COMPOSITES, MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, MG_WORKSPACE_DIM, BYTEHEAD_ACT_MAX_ITERS, BYTEHEAD_ACT_HALT_CONSECUTIVE |
|
|
def _ceil_div(a, b):
    """Return ``ceil(a / b)``, or 0 when ``b <= 0``.

    Integer arguments use exact integer arithmetic so the result stays correct
    for values too large to round-trip through a float (e.g. huge trit counts);
    other numeric types fall back to ``math.ceil(a / b)``.
    """
    if b <= 0:
        return 0
    if isinstance(a, int) and isinstance(b, int):
        return -(-a // b)
    return _ceil(a / b)
|
|
| from .sequencers import ByteEmbedding |
|
|
|
|
@dataclass
class LossWeights:
    """Per-term multipliers for the loss terms held by ``LossComponents``.

    Every field name mirrors a tensor field of ``LossComponents``; the
    weighted sum of the active terms is formed in ``LossComponents.total``.
    """

    lm: float = 1.0                  # language-model loss
    vq_commitment: float = 1.0       # VQ commitment loss
    moe_aux: float = 1.0             # MoE auxiliary loss
    graph_l1: float = 1e-3           # graph L1 term (down-weighted by default)
    graph_ponder: float = 1.0        # graph ponder cost
    moe_ponder: float = 1.0          # MoE ponder cost
    moegraph_ponder: float = 1.0     # MoE-graph ponder cost
    memgram_decay_reg: float = 1e-2  # MemGram decay regulariser (down-weighted by default)
    composite_vq: float = 1.0        # composite VQ loss
|
|
|
|
@dataclass
class LossComponents:
    """Container for the individual loss terms of one step.

    Each tensor field is optional: ``None`` marks the term as inactive. The
    multiplier for a term is the same-named attribute of ``weights``.
    ``total``, ``log`` and ``active_fields`` all derive from the dataclass field
    list, so adding a new term only requires a new field here (and a matching
    weight on ``LossWeights``).
    """

    lm: torch.Tensor = None
    vq_commitment: torch.Tensor = None
    moe_aux: torch.Tensor = None
    graph_l1: torch.Tensor = None
    graph_ponder: torch.Tensor = None
    moe_ponder: torch.Tensor = None
    moegraph_ponder: torch.Tensor = None
    memgram_decay_reg: torch.Tensor = None
    composite_vq: torch.Tensor = None
    weights: LossWeights = field(default_factory=LossWeights)

    @property
    def total(self) -> torch.Tensor:
        """Weighted sum of all active terms, accumulated in field-declaration order.

        Raises:
            ValueError: if every loss term is ``None``.
        """
        loss = None
        for _name, tensor, weight in self.active_fields:
            weighted = weight * tensor
            loss = weighted if loss is None else loss + weighted
        if loss is None:
            raise ValueError("LossComponents.total requested with no active loss tensors")
        return loss

    @property
    def active_fields(self) -> list[tuple[str, torch.Tensor, float]]:
        """``(name, tensor, weight)`` for every non-``None`` term, in field-declaration order."""
        result = []
        # `f`, not `field`: avoid shadowing the module-level dataclasses.field import.
        for f in fields(self):
            if f.name == "weights":
                continue
            tensor = getattr(self, f.name)
            if tensor is not None:
                result.append((f.name, tensor, getattr(self.weights, f.name)))
        return result

    def log(self, writer, step, prefix="loss"):
        """Write ``<prefix>/total`` and one ``<prefix>/<name>`` scalar per active term.

        ``writer`` only needs an ``add_scalar(tag, value, step)`` method.

        Raises:
            ValueError: if no term is active (propagated from ``total``).
        """
        writer.add_scalar(f"{prefix}/total", self.total.item(), step)
        for name, tensor, _weight in self.active_fields:
            writer.add_scalar(f"{prefix}/{name}", tensor.item(), step)

    def backward(self, retain_graph=False):
        """Backpropagate through ``total``."""
        self.total.backward(retain_graph=retain_graph)
|
|
|
|
class StickyZoneSTE(torch.autograd.Function):
    """Dead-zone ternarisation with a ramped straight-through gradient.

    Forward: ``sign(w)`` where ``|w| > threshold``, otherwise 0.
    Backward: the incoming gradient is scaled by ``clamp(|w| / threshold, 0, 1)``,
    so weights deep inside the dead zone receive proportionally smaller updates.
    A non-positive threshold means "no dead zone"; the gradient then passes
    straight through instead of becoming NaN at ``w == 0`` (0 / 0).
    ``threshold`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, w, threshold):
        threshold = float(threshold)
        ctx.save_for_backward(w)
        # A Python float is stored on ctx directly; save_for_backward is for tensors.
        ctx.threshold = threshold
        return w.sign() * (w.abs() > threshold).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        threshold = ctx.threshold
        if threshold <= 0.0:
            return grad_output, None
        ratio = torch.clamp(w.abs() / threshold, 0.0, 1.0)
        return grad_output * ratio, None
|
|
|
|
class TernaryEmbeddingTable(nn.Module):
    """Embedding table stored as packed ternary trits with per-group power-of-two scales.

    The effective weight of row ``r``, column ``c`` is
    ``2 ** E[r, c // group_size] * T[r, c]`` with ``T`` in {-1, 0, +1}.

    Storage (all buffers; there are no ``nn.Parameter`` objects):
        T_packed: packed trits; ``_T_shape`` / ``_T_pad`` describe how to unpack
            them. The row accessors below decode five trits per byte as base-3
            digits ``T + 1`` with place values 1, 3, 9, 27, 81.
        E: int8 log2 scales, flattened to
            ``num_embeddings * ceil(embedding_dim / group_size)`` entries.
        E_accum / T_accum: int8 vote accumulators for the sparse update paths.

    Tables with ``num_embeddings >= sparse_threshold`` (65,536) never
    materialise the full weight: ``forward`` decodes only the requested rows,
    and a backward hook records gradient signs for ``_sparse_ternary_step`` /
    ``_sparse_update_E``. Smaller tables decode the whole matrix on every call.
    """

    def __init__(self, num_embeddings, embedding_dim, tscale_type=TScaleType.T32,
                 init_std=0.02, threshold=0.05, normalize=False):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.tscale_type = tscale_type
        # Keep the init dead zone below half the init std so the initial T is not almost all zeros.
        init_threshold = min(float(threshold), 0.5 * float(init_std)) if init_std > 0 else threshold
        self.threshold = init_threshold
        self.normalize = normalize
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
        self.sparse_threshold = 65_536

        if num_embeddings >= self.sparse_threshold:
            # Large table: random packed trits (bytes 0..242 = 3**5 - 1 codes) and a
            # constant exponent, without materialising a float weight matrix.
            n_trits = num_embeddings * embedding_dim
            n_packed = _ceil_div(n_trits, 5)
            packed_T = torch.randint(0, 243, (n_packed,), dtype=torch.uint8)
            T_pad = n_packed * 5 - n_trits
            gpr = _ceil_div(embedding_dim, self.group_size)
            init_exp = int(round(_log2(max(init_std, 1e-8))))
            self.register_buffer("T_packed", packed_T)
            self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
            self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
            self.register_buffer(
                "E",
                torch.full((num_embeddings * gpr,), init_exp, dtype=torch.int8),
            )
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
            self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
            self._ema_alpha: float = 0.1
            self._loss_temp_scale: float = 1.0
            return

        # Small table: ternarise a Gaussian init and derive per-group scales from it.
        w_init = torch.randn(num_embeddings, embedding_dim) * init_std
        T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        gpr = _ceil_div(embedding_dim, self.group_size)
        total_in = gpr * self.group_size
        padded = torch.zeros(num_embeddings, total_in)
        padded[:, :embedding_dim] = w_init.abs()
        grouped = padded.view(num_embeddings, gpr, self.group_size)
        group_mean = grouped.mean(dim=2)
        E_vals = torch.where(group_mean > 0, group_mean, torch.ones(num_embeddings, gpr))
        # Round to the nearest exponent, matching the sparse branch's round(log2(init_std)).
        # A bare int8 cast truncates toward zero, which inflates sub-unit scales by up to 2x.
        self.register_buffer("E", E_vals.flatten().log2().round().clamp(-128, 127).to(torch.int8))
        self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
        self._ema_alpha: float = 0.1
        self._loss_temp_scale: float = 1.0

    def _get_T(self):
        """Unpack the full ternary matrix via ``unpack_ternary``."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_T_rows(self, indices):
        """Decode the ternary rows for ``indices`` (flattened) as int8 of shape ``(n, embedding_dim)``."""
        indices = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
        dim = self.embedding_dim
        cols = torch.arange(dim, device=indices.device, dtype=torch.long)
        lin = indices[:, None] * dim + cols[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = self.T_packed[pack_idx].to(torch.long)
        divisors = torch.tensor([1, 3, 9, 27, 81], device=indices.device, dtype=torch.long)
        code = (packed // divisors[trit_pos]) % 3
        return (code.to(torch.int8) - 1)

    def _expand_E_rows(self, indices):
        """Per-column int8 exponents for the rows in ``indices``: ``(n, embedding_dim)``."""
        indices = indices.reshape(-1).to(device=self.E.device, dtype=torch.long)
        gpr = _ceil_div(self.embedding_dim, self.group_size)
        E_rows = self.E.view(self.num_embeddings, gpr)[indices]
        E_exp = E_rows.repeat_interleave(self.group_size, dim=1)
        return E_exp[:, :self.embedding_dim]

    @torch.no_grad()
    def _set_T_rows(self, row_indices, rows):
        """Overwrite whole rows of the packed ternary matrix in place.

        Args:
            row_indices: row ids (any shape; flattened). When an id repeats, the
                last occurrence wins, as with sequential row writes.
            rows: ternary values in {-1, 0, 1}, reshapeable to
                ``(row_indices.numel(), embedding_dim)``.
        """
        device = self.T_packed.device
        row_indices = row_indices.reshape(-1).to(device=device, dtype=torch.long)
        rows = rows.to(device=device, dtype=torch.int8).reshape(row_indices.numel(), self.embedding_dim)
        if row_indices.numel() == 0:
            return
        # Keep only the last write per row so every trit is written exactly once.
        unique_rows, inverse = torch.unique(row_indices, return_inverse=True)
        positions = torch.arange(row_indices.numel(), device=device, dtype=torch.long)
        last = torch.full((unique_rows.numel(),), -1, device=device, dtype=torch.long)
        last.scatter_reduce_(0, inverse, positions, reduce="amax")
        rows = rows[last]

        cols = torch.arange(self.embedding_dim, device=device, dtype=torch.long)
        lin = unique_rows[:, None] * self.embedding_dim + cols[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        divisor = torch.tensor([1, 3, 9, 27, 81], device=device, dtype=torch.long)[trit_pos]
        old_code = (self.T_packed[pack_idx].to(torch.long) // divisor) % 3
        new_code = rows.to(torch.long) + 1
        delta = ((new_code - old_code) * divisor).reshape(-1)
        # Up to five trits share a byte; each change only touches its own base-3
        # digit, so the per-byte deltas can simply be summed.
        touched, byte_inverse = torch.unique(pack_idx.reshape(-1), return_inverse=True)
        byte_delta = torch.zeros(touched.numel(), device=device, dtype=torch.long)
        byte_delta.index_add_(0, byte_inverse, delta)
        self.T_packed[touched] = (self.T_packed[touched].to(torch.long) + byte_delta).to(self.T_packed.dtype)

    def _expand_E(self):
        """Per-column int8 exponents for the whole table: ``(out_dim, in_dim)``."""
        out_dim, in_dim = tuple(self._T_shape.tolist())
        gpr = _ceil_div(in_dim, self.group_size)
        E_2d = self.E.view(out_dim, gpr)
        E_exp = E_2d.repeat_interleave(self.group_size, dim=1)
        return E_exp[:, :in_dim]

    def _ensure_E_accum(self):
        """Return ``E_accum``, (re)creating it when missing or mismatched in shape/device with ``E``."""
        if not hasattr(self, "E_accum"):
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
            self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
        return self.E_accum

    def forward(self, indices):
        """Look up rows; returns ``(*indices.shape, embedding_dim)``, L2-normalised if ``normalize``.

        The returned tensor is derived from a detached copy of the effective
        weights; the ternary state is not updated by autograd. Instead, backward
        hooks stash gradient information on ``self`` for ``ternary_step`` /
        ``update_E``:
          * sparse path: ``_hook_sparse_indices``, ``_hook_sparse_grad_sign``,
            ``_hook_sparse_T`` (suffixed with ``_<component>`` when a component
            context is active; NOTE(review): only the unsuffixed names are
            consumed in this class).
          * dense CPU path: ``_hook_T`` (at forward) and ``_hook_grad_T_sign``.
        """
        use_sparse = self.num_embeddings >= self.sparse_threshold
        if use_sparse:
            idx_flat = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
            T_rows = self._get_T_rows(idx_flat)
            E_exp = self._expand_E_rows(idx_flat)
            w_eff = torch.exp2(E_exp.float()) * T_rows.float()
            w_eff_grad = w_eff.detach().requires_grad_(torch.is_grad_enabled())
            if torch.is_grad_enabled():
                comp_name, _ = _COMPONENT_CONTEXT.get()
                def capture_sparse_grad(grad):
                    suffix = f"_{comp_name}" if comp_name is not None else ""
                    setattr(self, f"_hook_sparse_indices{suffix}", idx_flat.detach())
                    setattr(self, f"_hook_sparse_grad_sign{suffix}", grad.reshape(-1, self.embedding_dim).sign().to(torch.int8).detach())
                    setattr(self, f"_hook_sparse_T{suffix}", T_rows.detach())
                w_eff_grad.register_hook(capture_sparse_grad)
            out = w_eff_grad.reshape(*indices.shape, self.embedding_dim)
            return F.normalize(out, dim=-1) if self.normalize else out
        if indices.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
            dummy = torch.zeros(1, device=indices.device, requires_grad=True)
            out = _TritonTernaryEmbedFn.apply(indices, dummy, self)
        else:
            T = self._get_T()
            w_eff = torch.exp2(self._expand_E().float()) * T.float()
            w_eff_grad = w_eff.detach().requires_grad_(True)
            self._hook_T = T
            def capture_w_grad(grad_w):
                self._hook_grad_T_sign = grad_w.sign().to(torch.int8)
            w_eff_grad.register_hook(capture_w_grad)
            out = F.embedding(indices, w_eff_grad)
        return F.normalize(out, dim=-1) if self.normalize else out

    def ternary_step(self, accum_threshold=3):
        """Apply a trit-update step from the gradient state captured by the last backward.

        Sparse tables delegate to ``_sparse_ternary_step``. For dense tables the
        captured sign is passed to ``_accumulate_corr_from_grad_sign`` when such a
        method exists (not defined in this class) and then discarded.
        """
        if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
            return self._sparse_ternary_step(accum_threshold=accum_threshold)
        if hasattr(self, "_hook_grad_T_sign"):
            if hasattr(self, "_accumulate_corr_from_grad_sign"):
                self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """Apply an exponent-update step; only implemented for sparse tables (else no-op)."""
        if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
            return self._sparse_update_E(loss_signal=loss_signal)

    def _clear_sparse_hooks(self):
        """Drop the unsuffixed sparse gradient state captured by the backward hook."""
        for name in ("_hook_sparse_indices", "_hook_sparse_grad_sign", "_hook_sparse_T"):
            if hasattr(self, name):
                delattr(self, name)

    @torch.no_grad()
    def _sparse_ternary_step(self, accum_threshold=3):
        """Vote-based trit flips for the rows touched in the last sparse backward.

        Gradient signs are summed per (row, column) over duplicate lookups;
        ``T_accum`` moves ``_t_accum_step`` (default 1) against the sign of that
        sum. Entries whose accumulator exceeds the threshold (``per_group_threshold``
        if present, else ``accum_threshold``) flip to +1 / -1 and reset their
        accumulator to 0. Sets ``_had_flip`` (non-empty case) and always clears
        the captured hook state.
        """
        indices = self._hook_sparse_indices.to(device=self.T_accum.device, dtype=torch.long)
        grad_sign = self._hook_sparse_grad_sign.to(device=self.T_accum.device, dtype=torch.int16)
        if indices.numel() == 0:
            # Still consume the (empty) capture so it is not mistaken for fresh state later.
            self._clear_sparse_hooks()
            return
        unique, inverse = torch.unique(indices, return_inverse=True)
        grad_sum = torch.zeros(unique.numel(), self.embedding_dim, device=self.T_accum.device, dtype=torch.int16)
        grad_sum.index_add_(0, inverse, grad_sign)
        grad_step = grad_sum.sign().to(torch.int16) * int(getattr(self, "_t_accum_step", 1))
        current = self.T_accum[unique].to(torch.int16)
        updated = torch.clamp(current - grad_step, -128, 127).to(torch.int8)

        pgt = getattr(self, "per_group_threshold", None)
        if pgt is not None:
            # Expand the per-(row, group) thresholds to per-column values.
            gpr = _ceil_div(self.embedding_dim, self.group_size)
            threshold = pgt.view(self.num_embeddings, gpr)[unique]
            threshold = threshold.unsqueeze(-1).expand(unique.numel(), gpr, self.group_size)
            threshold = threshold.reshape(unique.numel(), gpr * self.group_size)[:, :self.embedding_dim]
            threshold = threshold.to(updated.device)
            flip_up = updated > threshold
            flip_down = updated < -threshold
        else:
            flip_up = updated > accum_threshold
            flip_down = updated < -accum_threshold
        self._had_flip = bool((flip_up | flip_down).any().item())
        if self._had_flip:
            rows = self._get_T_rows(unique).to(updated.device)
            rows = torch.where(flip_up, torch.ones_like(rows), torch.where(flip_down, -torch.ones_like(rows), rows))
            self._set_T_rows(unique, rows)
            updated = torch.where(flip_up | flip_down, torch.zeros_like(updated), updated)
        self.T_accum[unique] = updated
        self._clear_sparse_hooks()

    @torch.no_grad()
    def _sparse_update_E(self, loss_signal=None):
        """Vote-based exponent steps for the rows touched in the last sparse backward.

        Per (lookup, group) the score ``sum(sign(grad) * T)`` votes -1 when
        positive (shrink the scale) and +1 when negative. Votes are summed over
        duplicate lookups, reduced to their sign and added to ``E_accum``; once
        ``|E_accum|`` reaches ``_e_accum_threshold`` (default 4), ``E`` moves one
        step and the accumulator is reduced by the threshold. ``loss_signal`` is
        unused. The captured hook state is left in place.
        """
        indices = self._hook_sparse_indices.to(device=self.E.device, dtype=torch.long)
        grad_sign = self._hook_sparse_grad_sign.to(device=self.E.device, dtype=torch.int16)
        T_rows = self._hook_sparse_T if hasattr(self, "_hook_sparse_T") else self._get_T_rows(indices)
        T_rows = T_rows.to(device=self.E.device, dtype=torch.int16)
        if indices.numel() == 0:
            return
        unique, inverse = torch.unique(indices, return_inverse=True)
        gpr = _ceil_div(self.embedding_dim, self.group_size)
        total_in = gpr * self.group_size
        signed = grad_sign * T_rows
        grouped = F.pad(signed, (0, total_in - self.embedding_dim)).view(indices.numel(), gpr, self.group_size)
        score = grouped.sum(dim=2)
        delta = torch.where(
            score > 0,
            torch.full_like(score, -1, dtype=torch.int16),
            torch.where(score < 0, torch.ones_like(score, dtype=torch.int16), torch.zeros_like(score, dtype=torch.int16)),
        )
        delta_sum = torch.zeros(unique.numel(), gpr, device=self.E.device, dtype=torch.int16)
        delta_sum.index_add_(0, inverse, delta)
        delta_sign = delta_sum.sign()
        e_idx = unique[:, None] * gpr + torch.arange(gpr, device=self.E.device, dtype=torch.long)[None, :]
        accum = torch.clamp(self.E_accum[e_idx].to(torch.int16) + delta_sign, -128, 127)
        threshold = int(getattr(self, "_e_accum_threshold", 4))
        step = torch.where(
            accum >= threshold,
            torch.ones_like(accum, dtype=torch.int16),
            torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
        )
        self.E[e_idx] = torch.clamp(self.E[e_idx].to(torch.int16) + step, -128, 127).to(torch.int8)
        self.E_accum[e_idx] = (accum - step * threshold).to(torch.int8)
|
|
|
|
|
|
class TernaryVQCodebook(nn.Module):
    """Cosine-similarity vector quantiser backed by a normalised TernaryEmbeddingTable.

    Lookup is exact (full similarity matrix) when ``codebook_size <=
    exact_lookup_max``. Otherwise each input is compared against
    ``candidate_count`` codes chosen by hashing the signs of its first (up to)
    16 dimensions. The output is passed straight through (``x + sg(q - x)``) and
    the loss is ``commitment_weight * (mse(x, sg(q)) + 0.25 * mse(q, sg(x)))``.
    ``cluster_size`` counts selections per code, saturating at 32767 (int16).
    """

    def __init__(self, codebook_size, codebook_dim, commitment_weight=1.0,
                 tscale_type=TScaleType.T32, exact_lookup_max=16384,
                 candidate_count=256):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment_weight = commitment_weight
        self.exact_lookup_max = exact_lookup_max
        self.candidate_count = candidate_count
        self.threshold_ema_dead_code = 2
        self.table = TernaryEmbeddingTable(codebook_size, codebook_dim, tscale_type=tscale_type, normalize=True)
        self.register_buffer("cluster_size", torch.zeros(codebook_size, dtype=torch.int16))

    @property
    def embed(self):
        """Materialise every (normalised) code vector: ``(codebook_size, codebook_dim)``."""
        idx = torch.arange(self.codebook_size, device=self.table.T_packed.device)
        return self.table(idx)

    def _candidate_ids(self, flat):
        """Deterministic candidate code ids per input row: ``(n, min(candidate_count, codebook_size))``."""
        c = min(self.candidate_count, self.codebook_size)
        take = min(flat.shape[1], 16)
        # NOTE: 9176 is not prime; the constants are kept as-is because changing
        # them would remap candidates for already-trained codebooks.
        primes = torch.tensor(
            [1009, 9176, 6361, 5333, 4447, 3469, 2531, 1613,
             811, 421, 211, 109, 59, 31, 17, 7],
            device=flat.device, dtype=torch.float32,
        )[:take]
        signed = torch.sign(flat[:, :take].float())
        base = torch.abs(torch.round((signed * primes).sum(dim=1) * 104729)).to(torch.long)
        offsets = torch.arange(c, device=flat.device, dtype=torch.long)
        stride = 2_654_435_761
        return (base[:, None] + offsets[None, :] * stride) % self.codebook_size

    def _lookup(self, flat):
        """Return ``(quantized, indices)`` for ``flat`` of shape ``(n, codebook_dim)``."""
        if self.codebook_size <= self.exact_lookup_max:
            x_norm = F.normalize(flat.float(), dim=-1)
            codebook = self.embed.to(device=flat.device)
            indices = (x_norm @ codebook.T).argmax(dim=-1)
            return codebook[indices], indices

        candidate_ids = self._candidate_ids(flat)
        n = flat.shape[0]
        chunk = 64
        indices = torch.empty(n, dtype=torch.long, device=flat.device)
        # Score candidates without autograd. Every grad-enabled table call registers
        # a backward hook that overwrites the table's captured gradient state, so
        # scoring chunk-by-chunk with grad on would keep only the last chunk's
        # gradients. The chosen rows are gathered once, differentiably, below.
        with torch.no_grad():
            x_norm = F.normalize(flat.float(), dim=-1)
            for start in range(0, n, chunk):
                end = min(start + chunk, n)
                chunk_ids = candidate_ids[start:end]
                chunk_norm = F.normalize(self.table(chunk_ids).float(), dim=-1)
                chunk_sim = (chunk_norm * x_norm[start:end].unsqueeze(1)).sum(dim=-1)
                chunk_best = chunk_sim.argmax(dim=-1)
                indices[start:end] = chunk_ids.gather(1, chunk_best.unsqueeze(1)).squeeze(1)
        quantized = self.table(indices).to(flat.dtype)
        return quantized, indices

    def forward(self, x):
        """Quantise ``x`` (last dim ``codebook_dim``).

        Returns:
            ``(quantized, indices, commitment)``: ``quantized`` has ``x``'s shape
            (straight-through), ``indices`` has ``x.shape[:-1]``, ``commitment`` is a scalar loss.
        """
        orig_shape = x.shape
        flat = x.reshape(-1, self.codebook_dim)
        quantized, indices = self._lookup(flat)
        commitment = self.commitment_weight * (
            F.mse_loss(flat.float(), quantized.detach().float())
            + 0.25 * F.mse_loss(quantized.float(), flat.detach().float())
        )
        quantized = flat + (quantized - flat).detach()
        with torch.no_grad():
            unique, counts = torch.unique(indices, return_counts=True)
            current = self.cluster_size[unique].to(torch.int32)
            updated = torch.clamp(current + counts.to(device=current.device, dtype=torch.int32), 0, 32767).to(torch.int16)
            self.cluster_size[unique] = updated
        return quantized.reshape(orig_shape), indices.reshape(orig_shape[:-1]), commitment
|
|
|
|
class GNNLoRAAdapter(nn.Module):
    """Low-rank adapter ``up(down(x) * s_t)`` with a per-hop ternary scale vector.

    ``s_t`` is row ``t`` of a ``(max_hops, rank)`` TernaryEmbeddingTable, where
    ``t`` is the hop counter clamped into ``[0, max_hops - 1]``.
    """

    def __init__(self, dim, rank=32, max_hops=4):
        super().__init__()
        self.max_hops = max_hops
        self.down = TernaryScaleTensor(dim, rank, tscale_type=TScaleType.T32)
        self.up = TernaryScaleTensor(rank, dim, tscale_type=TScaleType.T32)
        self.scale = TernaryEmbeddingTable(max_hops, rank, tscale_type=TScaleType.T32)

    def forward(self, x, hop_t):
        """Apply the adapter for hop ``hop_t`` (int or 0-d tensor).

        Hops beyond the table reuse the last scale row; negative hops use row 0
        (they would otherwise produce an invalid embedding index).
        """
        t_idx = max(0, min(int(hop_t), self.max_hops - 1))
        s = self.scale(torch.tensor(t_idx, device=x.device))
        return self.up(self.down(x) * s)
|
|
|
|
class HaltingUnit(nn.Module):
    """Per-position halting probability: ``sigmoid(proj(rmsnorm(x)))`` with a width-1 projection."""

    def __init__(self, dim, tscale_type=TScaleType.T32):
        super().__init__()
        # Registration order (proj, then norm) is kept for state-dict / RNG parity.
        self.proj = TernaryScaleTensor(dim, 1, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(dim, tscale_type=tscale_type)

    def forward(self, x):
        normed = self.norm(x)
        return self.proj(normed).sigmoid()
|
|
|
|
class _NgramHashMapping:
    """N-gram hash mapping with CPU offloading (Spider Engram style).

    Hashes token sequences to fixed-size embedding indices. Hash computation
    runs on CPU via numpy, O(1) per token via precomputed multipliers.

    For every n-gram order ``n`` in ``2..max_ngram_size`` there are ``num_heads``
    heads. Each head reduces the order-``n`` hash modulo its own prime table
    size; all primes are distinct. ``offsets_arr[h]`` is the start of head ``h``
    in a concatenated table of ``total_slots`` rows.
    """

    def __init__(self, max_ngram_size, num_heads, table_size_base, layer_seed=0):
        self.max_ngram_size = max_ngram_size
        self.num_heads = num_heads
        self.num_ngram_orders = max_ngram_size - 1

        # Odd per-position multipliers, seeded deterministically from the layer seed.
        PRIME_1 = 10007
        g = torch.Generator()
        g.manual_seed(int(layer_seed + PRIME_1 * int(layer_seed)))
        r = torch.randint(0, 1 << 30, (max_ngram_size,), generator=g, dtype=torch.int64)
        self.multipliers = r.numpy() * 2 + 1

        # Distinct primes below table_size_base, one per (order, head).
        seen_primes = set()
        self.prime_table_sizes = []
        for _ in range(self.num_ngram_orders):
            head_sizes = []
            ps = table_size_base - 1
            for _ in range(num_heads):
                p = self._next_prime(ps, seen_primes)
                seen_primes.add(p)
                head_sizes.append(p)
                ps = p
            self.prime_table_sizes.append(head_sizes)

        self.all_head_sizes = [s for sub in self.prime_table_sizes for s in sub]
        offsets = [0]
        for s in self.all_head_sizes[:-1]:
            offsets.append(offsets[-1] + s)
        self.offsets_arr = offsets
        self.total_slots = sum(self.all_head_sizes)

    @staticmethod
    def _next_prime(n, seen):
        """Return the largest prime ``<= n`` that is not in ``seen`` (searches downward).

        Raises:
            ValueError: if no such prime exists. Without this check the search
                would decrement forever once it passed below 2.
        """
        while n >= 2 and (n in seen or not _is_prime(n)):
            n -= 1
        if n < 2:
            raise ValueError("no unused prime available below the requested table size")
        return n

    def compute_hashes(self, token_ids):
        """Hash a ``(B, T)`` integer tensor to ``(B, T, num_ngram_orders * num_heads)`` int64 ids.

        Position ``t`` of order ``n`` hashes tokens ``t-n+1 .. t`` (earlier
        positions padded with 0). The result is moved back to ``token_ids.device``.
        """
        import numpy as np
        x = token_ids.cpu().numpy().astype(np.int64)
        _, T = x.shape

        # shifts[k][:, t] == x[:, t - k] (0 when t < k).
        shifts = [x]
        for k in range(1, self.max_ngram_size):
            shifts.append(np.pad(x, ((0, 0), (k, 0)), constant_values=0)[:, :T])

        all_hashes = []
        for order_idx, head_sizes in enumerate(self.prime_table_sizes):
            n = order_idx + 2
            mix = shifts[0] * self.multipliers[0]
            for k in range(1, n):
                mix = np.bitwise_xor(mix, shifts[k].astype(np.int64) * self.multipliers[k])
            all_hashes.extend((mix % ms).astype(np.int64, copy=False) for ms in head_sizes)

        result = np.stack(all_hashes, axis=2)
        return torch.from_numpy(result).to(device=token_ids.device)
|
|
|
|
def _is_prime(n):
    """Return True if ``n`` is a prime number (deterministic trial division).

    ``math.isqrt`` gives an exact divisor bound even for integers too large to
    be represented exactly as floats. Even candidates are handled up front, so
    only odd divisors are tried.
    """
    from math import isqrt
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    for divisor in range(3, isqrt(int(n)) + 1, 2):
        if n % divisor == 0:
            return False
    return True
|
|
|
|
class MemGram(nn.Module):
    """Engram-style associative memory with O(1) hashed lookup (CPU offloaded).

    Features:
    - O(1) hash -> index -> embedding lookup (no search, no decay for retrieval)
    - CPU-offloaded hash computation (numpy)
    - Single offset-stacked embedding table (not per-head tables): rows
      ``[0, struct_hash.total_slots)`` belong to the structural hash heads, and
      the following ``conv_hash.total_slots`` rows belong to the conv hash heads
    - Gated retrieval: sigmoid(Q*K/sqrt(d)) gates the memory read
    - Depthwise causal conv1d (kernel 4, dilation 3) processes retrieved memory (Engram-style)
    - No strength/decay buffers (decay is handled by GraphMoE usage frequency)
    - MemGram lookups do NOT affect KG decaying (separate mechanisms)

    Only ``struct_primes[0]`` / ``conv_primes[0]`` are used, as the base table
    size of the corresponding hash mapping. The sequence lengths only set
    ``n_struct_heads`` / ``n_conv_heads``.
    """

    def __init__(self, struct_primes=(64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
                                      65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071),
                 conv_primes=(8009, 8011, 8017, 8039),
                 embed_dim=64, hidden_dim=HIDDEN_DIM, key_dim=32,
                 max_ngram_size=3, num_hash_heads=4, layer_seed=0):
        super().__init__()
        self.embed_dim = embed_dim
        self.key_dim = key_dim
        self.hidden_dim = hidden_dim
        self.n_struct_heads = len(struct_primes)
        self.n_conv_heads = len(conv_primes)

        self.struct_hash = _NgramHashMapping(
            max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
            table_size_base=struct_primes[0], layer_seed=layer_seed,
        )
        self.conv_hash = _NgramHashMapping(
            max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
            table_size_base=conv_primes[0], layer_seed=layer_seed + 1000,
        )

        # Both mappings have the same head count, so their retrievals can be summed.
        total_heads = self.struct_hash.num_ngram_orders * num_hash_heads
        self.total_mem_dim = total_heads * embed_dim

        total_slots = self.struct_hash.total_slots + self.conv_hash.total_slots
        self.mem_embed = nn.Embedding(total_slots, embed_dim)
        # The conv region starts right after the struct region; without this
        # offset both mappings would index rows from 0 and overlap.
        self.conv_slot_offset = self.struct_hash.total_slots

        self.k_proj = nn.Linear(self.total_mem_dim, key_dim, bias=False)
        self.q_proj = nn.Linear(hidden_dim, key_dim, bias=False)
        self.v_proj = nn.Linear(self.total_mem_dim, hidden_dim, bias=False)

        # Zero-init value path: the module initially adds nothing to the residual.
        with torch.no_grad():
            self.v_proj.weight.zero_()

        # Output position t sees inputs t-9, t-6, t-3, t after trimming to the input length (causal).
        self.conv_norm = nn.RMSNorm(hidden_dim)
        self.conv = nn.Conv1d(
            hidden_dim, hidden_dim,
            kernel_size=4, padding=9, dilation=3, groups=hidden_dim,
        )
        with torch.no_grad():
            self.conv.weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()

    def _retrieve(self, token_ids, hash_mapping, base_offset=0):
        """Embed all hash heads of ``hash_mapping`` for ``(B, T)`` ids -> ``(B, T, H * embed_dim)``.

        ``base_offset`` is the first ``mem_embed`` row of this mapping's region.
        """
        hash_ids = hash_mapping.compute_hashes(token_ids)
        B, T, H = hash_ids.shape
        flat_ids = hash_ids.reshape(B * T, H)
        offsets = torch.tensor(hash_mapping.offsets_arr, device=flat_ids.device, dtype=torch.long) + base_offset
        emb = self.mem_embed(flat_ids + offsets)
        return emb.reshape(B, T, H * self.embed_dim)

    def _combined_memory(self, vq_indices):
        """Sum of struct and conv retrievals for ``vq_indices[:, 1:]`` -> ``(B, T - 1, total_mem_dim)``."""
        tokens = vq_indices[:, 1:]
        struct_mem = self._retrieve(tokens, self.struct_hash)
        conv_mem = self._retrieve(tokens, self.conv_hash, self.conv_slot_offset)
        return struct_mem + conv_mem

    def forward(self, vq_indices, hidden_state):
        """Add gated, conv-processed memory to ``hidden_state`` (``(B, T, hidden_dim)``).

        Memory is retrieved for ``vq_indices[:, 1:]``. Hidden positions without a
        memory entry (the trailing ones) are returned unchanged.
        """
        T = hidden_state.shape[1]
        mem = self._combined_memory(vq_indices)
        idx_end = mem.shape[1]
        if idx_end == 0:
            return hidden_state

        h = hidden_state[:, :idx_end]
        q = self.q_proj(h)
        k = self.k_proj(mem)
        v = self.v_proj(mem)
        gate = torch.sigmoid((q * k).sum(dim=-1, keepdim=True) / (self.key_dim ** 0.5))
        v_gated = gate * v

        v_t = self.conv_norm(v_gated).transpose(1, 2)
        conv_out = self.conv(v_t)[:, :, :idx_end].transpose(1, 2)
        output = h + F.silu(conv_out) + v_gated

        if idx_end < T:
            # Keep the residual stream for trailing positions instead of zeroing it.
            output = torch.cat([output, hidden_state[:, idx_end:]], dim=1)
        return output

    def retrieve_cb(self, vq_indices):
        """Self-gated raw memory for ``(B, T)`` ids -> ``(B, T, total_mem_dim)``.

        Trailing positions without a memory entry are zero-filled.
        """
        B, T = vq_indices.shape
        mem = self._combined_memory(vq_indices)
        pad = mem.new_zeros(B, T - mem.shape[1], mem.shape[2])
        mem = torch.cat([mem, pad], dim=1)
        gate = torch.sigmoid(mem.mean(dim=-1, keepdim=True))
        return gate * mem
|
|
|
|
# Special-token id -> dense boundary slot: BOS=0, SYSTEM=1, USER=2, ASSISTANT=3.
_BOUNDARY_TOKEN_MAP = {
    SPECIAL_VOCAB[token_name]: slot
    for slot, token_name in enumerate(("BOS", "SYSTEM", "USER", "ASSISTANT"))
}
|
|
|
|
class LTIInjection(nn.Module):
    """Linear time-invariant state update ``h' = A * h + B * e + trans_out``.

    ``A = exp(-exp(clamp(log_dt + log_A, -20, 20)))`` is elementwise in (0, 1)
    in exact arithmetic (a ZOH-style decay), which keeps repeated application
    from diverging. In float32 the clamp bounds can round to exactly 0.0 or 1.0.
    All parameters are created frozen (``requires_grad=False``); at init
    ``A = exp(-1)`` and ``B = 0.1``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.log_A = nn.Parameter(torch.zeros(dim), requires_grad=False)
        self.log_dt = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.B = nn.Parameter(torch.ones(dim) * 0.1, requires_grad=False)

    def get_A(self):
        rate = (self.log_dt + self.log_A).clamp(-20, 20)
        return torch.exp(-torch.exp(rate))

    def forward(self, h, e, trans_out):
        decay = self.get_A()
        return decay * h + self.B * e + trans_out
|
|
|
|
class ByteHead(nn.Module):
    """Deep MLP byte-prediction head with an optional ACT refinement loop.

    Trunk (H = HIDDEN_DIM): RMSNorm -> H->2H -> SiLU -> RMSNorm -> 2H->H -> SiLU
    -> RMSNorm -> H->2H -> SiLU -> RMSNorm -> 2H->VOCAB logits.

    With ``act_max_iters > 1`` the trunk is re-run on the state
    ``lti(state, x, act_residual(logits))``. Iteration stops early once the
    batch-wide argmax has equalled the previous iteration's argmax for
    ``act_halt_consecutive`` consecutive iterations. ``_last_ponder`` records
    iterations run / ``act_max_iters`` (only updated on the ACT path).
    """

    def __init__(self, tscale_type=TScaleType.T32,
                 act_max_iters=BYTEHEAD_ACT_MAX_ITERS,
                 act_halt_consecutive=BYTEHEAD_ACT_HALT_CONSECUTIVE):
        super().__init__()
        narrow = HIDDEN_DIM
        wide = HIDDEN_DIM * 2
        self.act_max_iters = act_max_iters
        self.act_halt_consecutive = act_halt_consecutive
        self._last_ponder = 0.0

        # Registration order is preserved for state-dict / RNG parity.
        self.norm = TernaryRMSNorm(narrow, tscale_type=tscale_type)
        self.up = TernaryScaleTensor(narrow, wide, tscale_type=tscale_type)
        self.up_norm = TernaryRMSNorm(wide, tscale_type=tscale_type)
        self.hidden = TernaryScaleTensor(wide, narrow, tscale_type=tscale_type)
        self.hidden_norm = TernaryRMSNorm(narrow, tscale_type=tscale_type)
        self.out = TernaryScaleTensor(narrow, wide, tscale_type=tscale_type)
        self.out_norm = TernaryRMSNorm(wide, tscale_type=tscale_type)
        self.head = TernaryScaleTensor(wide, VOCAB, tscale_type=tscale_type)

        use_act = act_max_iters > 1
        self.act_residual = TernaryScaleTensor(VOCAB, narrow, tscale_type=tscale_type) if use_act else None
        self.lti = LTIInjection(narrow) if use_act else None

    def _trunk(self, state):
        """One pass through the four-layer MLP; returns byte logits."""
        state = F.silu(self.up(self.norm(state)))
        state = F.silu(self.hidden(self.up_norm(state)))
        state = F.silu(self.out(self.hidden_norm(state)))
        return self.head(self.out_norm(state))

    def forward(self, x):
        if self.act_max_iters <= 1 or self.act_residual is None:
            return self._trunk(x)

        state = x
        last_pred = None
        streak = 0
        iters_run = 0
        for step in range(self.act_max_iters):
            logits = self._trunk(state)
            pred = logits.argmax(dim=-1)
            unchanged = last_pred is not None and bool((pred == last_pred).all())
            streak = streak + 1 if unchanged else 0
            iters_run = step + 1
            if streak >= self.act_halt_consecutive:
                break
            last_pred = pred
            # NOTE(review): on the final iteration this state update is computed but never used.
            state = self.lti(state, x, self.act_residual(logits))

        self._last_ponder = iters_run / max(self.act_max_iters, 1)
        return logits
|
|
|
|
class OutputRouter(nn.Module):
    """Routing head producing 4 logits per HIDDEN_DIM token.

    The original description names ByteHead, VideoHead and TalkerHead as
    destinations. NOTE(review): the gate emits 4 logits; the 4th route is not
    identifiable from this file.

    depth >= 3: silu(norm(H->H)) -> H->H/4 -> gate
    depth == 2: H->H/4 -> gate
    depth <= 1: gate directly on the input
    (There is no nonlinearity between ``hidden2`` and ``gate``.)

    ``forward(x, training=True)`` returns ``(softmax weights, logits)``;
    otherwise it returns the argmax route index. Logits are NaN/inf-sanitised
    and clamped to [-30, 30].
    """

    def __init__(self, tscale_type=TScaleType.T32, depth=3):
        super().__init__()
        deep = depth >= 3
        has_bottleneck = depth >= 2
        # Construction order (hidden1, hidden1_norm, hidden2, gate) matches the
        # original for identical RNG consumption and state-dict ordering.
        self.hidden1 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type) if deep else None
        self.hidden1_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if deep else None
        self.hidden2 = (
            TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type)
            if has_bottleneck else None
        )
        gate_in = HIDDEN_DIM // 4 if has_bottleneck else HIDDEN_DIM
        self.gate = TernaryScaleTensor(gate_in, 4, tscale_type=tscale_type)

    def forward(self, x, training=False):
        feats = x if self.hidden1 is None else F.silu(self.hidden1_norm(self.hidden1(x)))
        if self.hidden2 is not None:
            feats = self.hidden2(feats)
        scores = torch.nan_to_num(self.gate(feats), nan=0.0, posinf=30.0, neginf=-30.0).clamp(-30.0, 30.0)
        if not training:
            return scores.argmax(dim=-1)
        return F.softmax(scores, dim=-1), scores
|
|
|
|
class KGVQCodebook(TernaryVQCodebook):
    """Compatibility wrapper for the KG/composite vector quantiser.

    Earlier versions kept float32 ``embed`` / ``embed_avg`` buffers. This class
    reuses the packed ternary/int8 table of ``TernaryVQCodebook``, so building a
    very large KG codebook allocates no float codebook state. ``decay`` and
    ``threshold_ema_dead_code`` are stored but not used by any method here.
    """

    def __init__(self, codebook_size=KGVQ_CODEBOOK_SIZE, codebook_dim=KGVQ_CODEBOOK_DIM,
                 decay=KGVQ_DECAY, commitment_weight=KGVQ_COMMITMENT_WEIGHT,
                 threshold_ema_dead_code=KGVQ_DEAD_CODE_THRESHOLD):
        super().__init__(
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            commitment_weight=commitment_weight,
        )
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code

    @property
    def embed(self):
        """Full codebook; refused for codebooks larger than ``exact_lookup_max``."""
        if self.codebook_size <= self.exact_lookup_max:
            return super().embed
        raise RuntimeError(
            "Full KG VQ materialization is disabled for large ternary codebooks; "
            "query rows through `table(indices)` instead."
        )

    def _ema_update(self, x_flat, indices):
        """Add per-code selection counts for ``indices`` to ``cluster_size`` (saturating at 32767).

        ``x_flat`` is accepted for interface compatibility and ignored.
        """
        codes, hits = torch.unique(indices, return_counts=True)
        previous = self.cluster_size[codes].to(torch.int32)
        totals = previous + hits.to(device=previous.device, dtype=torch.int32)
        self.cluster_size[codes] = totals.clamp(0, 32767).to(torch.int16)

    def _dead_code_reset(self, x_flat):
        """No-op; always returns ``None`` (presumably kept so legacy callers still work)."""
        return None
|
|
|
|
class CompositeProposalHead(nn.Module):
    """Multi-proposal head from pooled GNN output (Phase 17).

    Projects GNN pool output (graph_pool_out [B, D]) to ``k_max`` composite
    motif proposals, quantizes them with a ``TernaryVQCodebook``
    (``self.kgvq``), and applies ACT-style halting: proposals whose halt
    probability is below 0.5 get composite id ``-1``.

    Args:
        dim: width of the pooled input.
        codebook_dim: width of each proposal vector.
        k_max: number of proposals per example.
        codebook_size: number of codes in the proposal codebook.
        tscale_type: ternary scale type used by every layer.
        diversity_weight: scale of the pairwise-cosine diversity term added to
            the VQ loss (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
                 k_max=K_MAX_COMPOSITES, codebook_size=KGVQ_CODEBOOK_SIZE,
                 tscale_type=TScaleType.T32, diversity_weight=0.1):
        super().__init__()
        self.dim = dim
        self.k_max = k_max
        self.codebook_dim = codebook_dim
        self.proj = TernaryScaleTensor(dim, k_max * codebook_dim, tscale_type=tscale_type)
        self.kgvq = TernaryVQCodebook(codebook_size=codebook_size, codebook_dim=codebook_dim,
                                      tscale_type=tscale_type)
        self.halt_gate = TernaryScaleTensor(dim, k_max, tscale_type=tscale_type)
        self.diversity_weight = diversity_weight

    def _diversity_loss(self, projections):
        """Average pairwise cosine similarity between proposals, times ``diversity_weight``.

        ``projections`` is ``[B, k_max, codebook_dim]``. Returns a 0-dim tensor.
        An empty batch yields 0 rather than the NaN that ``mean()`` of nothing gives.
        """
        if projections.shape[0] == 0:
            return projections.new_zeros(())
        normed = F.normalize(projections, dim=-1)
        sim_matrix = normed @ normed.transpose(-1, -2)
        # Strict upper triangle: each unordered pair once, self-similarity excluded.
        upper = torch.triu(sim_matrix, diagonal=1)
        n_pairs = self.k_max * (self.k_max - 1) / 2
        per_example = upper.sum(dim=(-1, -2))
        return per_example.mean() / max(n_pairs, 1) * self.diversity_weight

    def forward(self, pool_out):
        """Propose, quantize and halt ``k_max`` composites per example.

        Args:
            pool_out: pooled graph features, ``[B, dim]``.

        Returns:
            ``(composite_ids, loss, halt)``. ``composite_ids`` holds the VQ
            indices with halted slots set to -1. ``loss`` is the VQ loss plus the
            diversity term. ``halt`` is the sigmoid halting probability from
            ``halt_gate`` (logits clamped to [-12, 12]).
        """
        batch = pool_out.shape[0]
        # reshape (not view) also works if the projection output is non-contiguous.
        projections = self.proj(pool_out).reshape(batch, self.k_max, self.codebook_dim)
        _, composite_ids, vq_loss = self.kgvq(projections)

        halt_logits = self.halt_gate(pool_out).clamp(-12.0, 12.0)
        halt = torch.sigmoid(halt_logits)
        composite_ids = composite_ids.masked_fill(halt < 0.5, -1)

        return composite_ids, vq_loss + self._diversity_loss(projections), halt
|
|
|
|
class MoEGraph(nn.Module):
    """Fused graph traversal + centroid-based MoE routing + ACT halting.

    Each ACT iteration: traverse KG → aggregate neighbor emb → centroid route →
    run expert → halt check. All operations run at ``cb_dim`` width
    (``MG_WORKSPACE_DIM`` by default).

    Replaces: TernaryGraph + GraphMoEGate + GraphACTCell + SharedProjectionMoE + MoEACTCell.

    Edge storage is chosen by comparing ``codebook_size`` with ``active_graph_max_nodes``:

    * dense (small codebooks): ``edge_index`` / ``edge_attr`` / ``edge_score``
      start with 10 random out-edges per code. ``forward`` aggregates ternary
      neighbour messages over the whole codebook.
    * active (large codebooks): a fixed-capacity ring buffer (``active_edge_*``)
      records recently observed code transitions, which ``monitor_graph_health``
      reports on. ``forward`` only looks up codebook rows for the indices it is
      given, so the full codebook is never materialised.
    """

    # ACT iterations (0-based) at which ``memgram_cb_output`` is injected.
    memgram_inject_iters = (1, 3)

    def __init__(self, cb_dim=MG_WORKSPACE_DIM, trigram_dim=HIDDEN_DIM,
                 codebook_dim=CODEBOOK_DIM,
                 num_experts=MG_N_EXPERTS, core_rank=MG_CORE_RANK,
                 shared_inter=MG_SHARED_INTER, max_iters=MG_ACT_ITERS,
                 halt_threshold=0.99, tscale_type=TScaleType.T32,
                 codebook_size=CODEBOOK_SIZE,
                 active_graph_max_nodes=4096,
                 top_k=1):
        super().__init__()
        self.cb_dim = cb_dim
        self.trigram_dim = trigram_dim
        self.codebook_dim = codebook_dim
        self.num_experts = num_experts
        self.core_rank = core_rank
        self.shared_inter = shared_inter
        self.max_iters = max_iters
        self.halt_threshold = halt_threshold
        self.codebook_size = codebook_size
        self.active_graph_max_nodes = active_graph_max_nodes
        self.top_k = top_k

        # Projections between the trigram stream and the cb_dim workspace.
        self.down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
        self.down_norm = TernaryRMSNorm(trigram_dim, tscale_type=tscale_type)
        self.up_proj = TernaryScaleTensor(cb_dim, trigram_dim, tscale_type=tscale_type)
        self.up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
        self.attn_down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
        self.codebook_up = TernaryScaleTensor(codebook_dim, cb_dim, tscale_type=tscale_type)

        # Knowledge-graph edges (see the class docstring for the two storage modes).
        self.use_active_edge_store = self.codebook_size > self.active_graph_max_nodes
        self.active_edge_capacity = max(int(self.active_graph_max_nodes) * 16, 65_536)
        if self.use_active_edge_store:
            # Dense buffers stay empty; transitions are written to the ring buffer.
            cap = self.active_edge_capacity
            self.register_buffer("edge_index", torch.zeros(2, 0, dtype=torch.int32))
            self.register_buffer("edge_attr", torch.zeros(0, dtype=torch.int8))
            self.register_buffer("edge_score", torch.zeros(0, dtype=torch.int8))
            self.register_buffer("active_edge_src", torch.full((cap,), -1, dtype=torch.int32))
            self.register_buffer("active_edge_dst", torch.full((cap,), -1, dtype=torch.int32))
            self.register_buffer("active_edge_attr", torch.zeros(cap, dtype=torch.int8))
            self.register_buffer("active_edge_score", torch.zeros(cap, dtype=torch.int8))
            self.register_buffer("active_edge_ptr", torch.zeros((), dtype=torch.long))
        else:
            # 10 outgoing edges per code: random destinations, random ternary weights.
            num_edges = self.codebook_size * 10
            src = torch.arange(self.codebook_size, dtype=torch.int32).repeat_interleave(10)
            dst = torch.randint(0, self.codebook_size, (num_edges,), dtype=torch.int32)
            self.register_buffer("edge_index", torch.stack([src, dst], dim=0))
            edge_init = torch.randint(-1, 2, (num_edges,), dtype=torch.int8)
            self.register_buffer("edge_attr", edge_init)
            self.register_buffer("edge_score", torch.zeros(num_edges, dtype=torch.int8))
        # Both storage modes count steps toward the next requantization.
        self.register_buffer("_steps_since_requant", torch.tensor(0, dtype=torch.long))
        self.requant_every = KG_REQUANT_EVERY
        self.kg_ternary_threshold = KG_TERNARY_THRESHOLD
        self.kg_ema_alpha = KG_EMA_ALPHA

        # One normalized routing centroid per expert.
        self.centroids = TernaryEmbeddingTable(num_experts, cb_dim, tscale_type=tscale_type, normalize=True)

        # Up/down projections shared by every expert.
        self.shared_up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
        self.shared_up = TernaryScaleTensor(cb_dim, shared_inter, tscale_type=tscale_type)
        self.shared_down_norm = TernaryRMSNorm(shared_inter, tscale_type=tscale_type)
        self.shared_down = TernaryScaleTensor(shared_inter, cb_dim, tscale_type=tscale_type)

        # Per-expert low-rank gate (cb_dim -> core_rank) and transform (core_rank -> shared_inter).
        self.W_gate = nn.ModuleList([
            TernaryScaleTensor(cb_dim, core_rank, tscale_type=tscale_type)
            for _ in range(num_experts)
        ])
        self.W_gate_norms = nn.ModuleList([
            TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
            for _ in range(num_experts)
        ])
        self.W_transform = nn.ModuleList([
            TernaryScaleTensor(core_rank, shared_inter, tscale_type=tscale_type)
            for _ in range(num_experts)
        ])
        self.W_transform_norms = nn.ModuleList([
            TernaryRMSNorm(core_rank, tscale_type=tscale_type)
            for _ in range(num_experts)
        ])

        self.hop_lora = GNNLoRAAdapter(dim=cb_dim, rank=32, max_hops=max_iters)
        self.halting = HaltingUnit(dim=cb_dim, tscale_type=tscale_type)
        self.lti = LTIInjection(cb_dim)

        # Codebook sources, attached externally; both stay None until then.
        self._codebook_embed = None
        self._codebook_table = None

    def _to_workspace(self, features):
        """Project codebook rows to ``cb_dim`` when their width differs."""
        if features.shape[-1] != self.cb_dim:
            features = self.codebook_up(features)
        return features

    def _codebook_tensor(self, device):
        """Return every codebook row, projected to ``cb_dim``, on ``device``.

        Prefers ``_codebook_table`` (row lookup), then ``_codebook_embed``.
        Returns zeros of shape ``[codebook_size, cb_dim]`` when neither is attached.
        """
        if self._codebook_table is not None:
            idx = torch.arange(self.codebook_size, device=device)
            return self._to_workspace(self._codebook_table(idx))
        if self._codebook_embed is not None:
            return self._to_workspace(self._codebook_embed.to(device=device).squeeze(0))
        return torch.zeros(self.codebook_size, self.cb_dim, device=device)

    def _active_codebook_features(self, vq_indices):
        """Look up (clamped) codebook rows for ``vq_indices`` only, projected to ``cb_dim``."""
        if self._codebook_table is not None:
            safe_idx = vq_indices.clamp(min=0, max=self.codebook_size - 1)
            active_code = self._codebook_table(safe_idx)
        elif self._codebook_embed is not None:
            codebook = self._codebook_embed.to(device=vq_indices.device).squeeze(0)
            safe_idx = vq_indices.clamp(min=0, max=codebook.shape[0] - 1)
            active_code = codebook[safe_idx]
        else:
            return torch.zeros(*vq_indices.shape, self.cb_dim, device=vq_indices.device)
        return self._to_workspace(active_code)

    def _neighbor_aggregate(self, node_features, threshold):
        """Sum ternary-weighted source features into each destination node.

        ``edge_attr`` is ternarised through ``StickyZoneSTE`` with ``threshold``.
        Returns ``[codebook_size, D]``.
        """
        D = node_features.shape[1]
        # edge_index is stored as int32; the reductions below need int64 indices.
        src_idx = self.edge_index[0].long()
        dst_idx = self.edge_index[1].long()
        edge_ternary = StickyZoneSTE.apply(self.edge_attr, threshold)
        messages = edge_ternary.unsqueeze(1).to(node_features.dtype) * node_features[src_idx]
        aggregated = torch.zeros(self.codebook_size, D, device=node_features.device, dtype=node_features.dtype)
        aggregated.index_add_(0, dst_idx, messages)
        return aggregated

    def _run_expert_batch(self, x, expert_idx):
        """Apply, per token, the expert chosen in ``expert_idx`` (``[B, T]``) to ``x`` (``[B, T, D]``)."""
        B, T, D = x.shape
        x_flat = rearrange(x, 'b t d -> (b t) d')
        exp_flat = rearrange(expert_idx, 'b t -> (b t)')
        shared_hidden = F.silu(self.shared_up(self.shared_up_norm(x_flat)))
        # Group tokens by expert; read all group sizes with a single host sync.
        sort_idx = exp_flat.argsort()
        counts = torch.bincount(exp_flat, minlength=self.num_experts).tolist()
        out_flat = torch.zeros(B * T, D, device=x.device, dtype=x.dtype)
        start = 0
        for e in range(self.num_experts):
            end = start + counts[e]
            if end > start:
                tok_idx = sort_idx[start:end]
                inp = x_flat[tok_idx]
                gate = self.W_gate[e](self.W_gate_norms[e](inp))
                core = self.W_transform[e](self.W_transform_norms[e](gate))
                expert_out = self.shared_down(self.shared_down_norm(core * shared_hidden[tok_idx]))
                out_flat[tok_idx] = expert_out.to(out_flat.dtype)
            start = end
        return rearrange(out_flat, '(b t) d -> b t d', b=B, t=T)

    def _run_expert(self, x, expert_idx):
        return self._run_expert_batch(x, expert_idx)

    def _active_node_add(self, vq_output, vq_indices):
        return vq_output + self._active_codebook_features(vq_indices)

    def forward(self, trigram_input, vq_indices, attention_output=None,
                memgram_cb_output=None, threshold=0.05):
        """Run up to ``max_iters`` ACT steps of traverse → route → expert.

        Args:
            trigram_input: ``[B, T, trigram_dim]`` features.
            vq_indices: codebook indices used to fetch node features; assumed to be
                ``[B, T]`` so they align with ``trigram_input`` (TODO confirm).
            attention_output: optional tensor (last dim ``trigram_dim``), projected
                and added to every traversal.
            memgram_cb_output: optional 3-D tensor added at ``memgram_inject_iters``.
                If its last dim is not ``cb_dim`` it is mean-pooled and broadcast.
            threshold: sticky-zone threshold for dense edge ternarisation.

        Returns:
            ``(output, ponder_loss)``. ``output`` is ``up_proj`` of the ACT-weighted
            expert outputs; ``ponder_loss`` is a scalar.
        """
        B, T, _ = trigram_input.shape
        device = trigram_input.device

        x = self.down_proj(self.down_norm(trigram_input))
        initial_x = x
        last_x = x

        attn_cb = None
        if attention_output is not None:
            attn_cb = self.attn_down_proj(self.down_norm(attention_output))

        memgram_raw = None
        if memgram_cb_output is not None:
            memgram_raw = memgram_cb_output.to(device)
            if memgram_raw.shape[-1] != self.cb_dim:
                memgram_raw = memgram_raw.mean(dim=-1, keepdim=True).expand(-1, -1, self.cb_dim)

        halted = torch.zeros(B, T, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T, device=device)
        total_ponder = torch.zeros(B, T, device=device)
        acc = torch.zeros_like(x)

        # Large codebooks take the lookup-only path so the full codebook is never built.
        use_active_graph = self.codebook_size > self.active_graph_max_nodes
        # Codebook, edges and threshold are fixed during forward: aggregate once.
        node_aggregated = None
        if not use_active_graph:
            node_aggregated = self._neighbor_aggregate(self._codebook_tensor(device), threshold)

        centroid_ids = torch.arange(self.num_experts, device=device)
        cent_norm = F.normalize(self.centroids(centroid_ids), dim=-1, eps=1e-8)
        k = min(self.top_k, self.num_experts)

        for iter_t in range(self.max_iters):
            if use_active_graph:
                traversal = self._active_node_add(x, vq_indices)
            else:
                traversal = x + node_aggregated[vq_indices]
            if attn_cb is not None:
                traversal = traversal + attn_cb
            if memgram_raw is not None and iter_t in self.memgram_inject_iters:
                traversal = traversal + memgram_raw
            traversal = traversal + self.hop_lora(traversal, iter_t)

            # Cosine-similarity routing against the expert centroids.
            scores = F.normalize(traversal, dim=-1, eps=1e-8) @ cent_norm.T
            if k <= 1:
                _, expert_idx = scores.max(dim=-1)
                expert_out = self._run_expert(traversal, expert_idx)
            else:
                scores_topk, topk_idx = scores.topk(k=k, dim=-1)
                weights = F.softmax(scores_topk / 0.1, dim=-1)
                expert_out = 0
                for i in range(k):
                    expert_out = expert_out + weights[..., i:i + 1] * self._run_expert(traversal, topk_idx[..., i])
            last_x = expert_out

            # ACT: a token's step weight is p, or the remainder on the step that crosses the threshold.
            p = self.halting(expert_out).squeeze(-1)
            running = (~halted).float()
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(cumulative_p + p >= self.halt_threshold, remainder, p) * running
            acc = acc + weight.unsqueeze(-1) * expert_out
            cumulative_p = cumulative_p + p * running
            halted = halted | (cumulative_p >= self.halt_threshold)
            total_ponder = total_ponder + (1.0 - cumulative_p).clamp(min=0)

            x = self.lti(x, initial_x, expert_out)

            if halted.all():
                break

        # Tokens that never halted get only the leftover probability mass on the last
        # output, so every token's mixture weights sum to 1.
        leftover = (1.0 - cumulative_p).clamp(min=0) * (~halted).float()
        acc = acc + leftover.unsqueeze(-1) * last_x

        output = self.up_proj(self.up_norm(acc))
        ponder_loss = total_ponder.mean() / self.max_iters
        return output, ponder_loss

    @torch.no_grad()
    def update_kg_edges(self, all_vq_indices):
        """Update edge scores from observed indices, then maybe requantize.

        Dense mode: for edges whose source appears in ``all_vq_indices``, the score
        is raised by 1 if the destination also appears, otherwise lowered by 1
        (saturating int8). Active mode writes transitions to the ring buffer.
        """
        if self.use_active_edge_store:
            self._update_active_edges(all_vq_indices)
            return

        unique_ids = torch.unique(all_vq_indices.to(device=self.edge_index.device, dtype=torch.int32))
        src_in_batch = torch.isin(self.edge_index[0], unique_ids)
        if src_in_batch.any():
            dst_seen = torch.isin(self.edge_index[1][src_in_batch], unique_ids)
            current = self.edge_score[src_in_batch].to(torch.int16)
            delta = dst_seen.to(torch.int16) * 2 - 1  # +1 if seen, -1 otherwise
            self.edge_score[src_in_batch] = torch.clamp(current + delta, -128, 127).to(torch.int8)

        self._requantize_dense_edges()

    @torch.no_grad()
    def _update_active_edges(self, all_vq_indices):
        """Append consecutive-index transitions (last axis) to the ring buffer."""
        ids = all_vq_indices.to(device=self.active_edge_src.device, dtype=torch.int32)
        if ids.numel() < 2:
            self._steps_since_requant.add_(1)
            return

        seq = ids.reshape(-1, ids.shape[-1]) if ids.dim() > 1 else ids.reshape(1, -1)
        src = seq[:, :-1].reshape(-1)
        dst = seq[:, 1:].reshape(-1)
        valid = ((src >= 0) & (dst >= 0)
                 & (src < self.codebook_size) & (dst < self.codebook_size)
                 & (src != dst))
        src = src[valid]
        dst = dst[valid]
        if src.numel() == 0:
            self._steps_since_requant.add_(1)
            return

        # Keep only the newest transitions that fit in the buffer.
        n_edges = min(src.numel(), self.active_edge_capacity)
        src = src[-n_edges:]
        dst = dst[-n_edges:]
        ptr = int(self.active_edge_ptr.item())
        slots = (torch.arange(n_edges, device=src.device, dtype=torch.long) + ptr) % self.active_edge_capacity

        # A slot previously holding a *different* edge must not pass its score on.
        same_edge = (self.active_edge_src[slots] == src) & (self.active_edge_dst[slots] == dst)
        prev_score = torch.where(
            same_edge,
            self.active_edge_score[slots].to(torch.int16),
            torch.zeros(n_edges, device=src.device, dtype=torch.int16),
        )

        self.active_edge_src[slots] = src
        self.active_edge_dst[slots] = dst
        self.active_edge_score[slots] = torch.clamp(prev_score + 1, -128, 127).to(torch.int8)
        self.active_edge_attr[slots] = 1
        self.active_edge_ptr.fill_((ptr + n_edges) % self.active_edge_capacity)
        self._requantize_active_edges()

    @staticmethod
    def _decay_scores(score):
        """Move every int8 score one step toward zero."""
        score_i = score.to(torch.int16)
        return (score_i - score_i.sign()).to(torch.int8)

    @torch.no_grad()
    def _requantize_dense_edges(self):
        """Every ``requant_every`` calls, rebuild ``edge_attr`` from scores and decay the scores."""
        if self._steps_since_requant.item() < self.requant_every:
            self._steps_since_requant.add_(1)
            return
        self.edge_attr = self._score_to_attr(self.edge_score)
        self.edge_score = self._decay_scores(self.edge_score)
        self._steps_since_requant.zero_()

    @torch.no_grad()
    def _requantize_active_edges(self):
        """Ring-buffer counterpart of ``_requantize_dense_edges`` (occupied slots only)."""
        if self._steps_since_requant.item() < self.requant_every:
            self._steps_since_requant.add_(1)
            return
        active = self.active_edge_src >= 0
        if active.any():
            self.active_edge_attr[active] = self._score_to_attr(self.active_edge_score[active])
            self.active_edge_score[active] = self._decay_scores(self.active_edge_score[active])
        self._steps_since_requant.zero_()

    def _score_to_attr(self, score):
        """Map scores to ternary int8 attrs.

        The cut-off is ``max(1, round(kg_ternary_threshold * 8))``: scores at or
        above it give +1, at or below its negative give -1, anything else 0.
        """
        threshold = max(1, int(round(float(self.kg_ternary_threshold) * 8)))
        score_i = score.to(torch.int16)
        return torch.where(
            score_i >= threshold,
            torch.ones_like(score, dtype=torch.int8),
            torch.where(
                score_i <= -threshold,
                torch.full_like(score, -1, dtype=torch.int8),
                torch.zeros_like(score, dtype=torch.int8),
            ),
        )

    @torch.no_grad()
    def monitor_graph_health(self, threshold=0.05):
        """Summarise edge sparsity, polarity and score statistics.

        ``threshold`` is accepted for API compatibility and is not used.
        """
        if self.use_active_edge_store:
            active = self.active_edge_src >= 0
            if not active.any():
                return {
                    "sparsity": 1.0, "isolated_nodes": self.codebook_size,
                    "avg_polarity": 0.0, "dead_edges": 0,
                    "score_mean": 0.0, "score_max": 0.0,
                    "active_edges": 0,
                }
            edge_attr = self.active_edge_attr[active]
            edge_score = self.active_edge_score[active]
            nodes_with_edges = torch.unique(torch.cat([self.active_edge_src[active], self.active_edge_dst[active]]))
        else:
            edge_attr = self.edge_attr
            edge_score = self.edge_score
            nodes_with_edges = torch.unique(torch.cat([self.edge_index[0], self.edge_index[1]]))

        ternary_edge = edge_attr.sign()
        sparsity = (ternary_edge == 0).float().mean().item() if ternary_edge.numel() else 1.0
        n_isolated = max(int(self.codebook_size) - int(nodes_with_edges.numel()), 0)
        n_pos = (ternary_edge > 0).sum().item()
        n_neg = (ternary_edge < 0).sum().item()
        avg_polarity = (n_pos - n_neg) / max(n_pos + n_neg, 1)
        # Zeroed edges that still carry a non-zero score.
        dead_edges = ((ternary_edge == 0) & (edge_score != 0)).sum().item()
        score_mean = edge_score.float().mean().item() if edge_score.numel() else 0.0
        score_max = edge_score.float().abs().max().item() if edge_score.numel() else 0.0
        return {
            "sparsity": sparsity, "isolated_nodes": n_isolated,
            "avg_polarity": avg_polarity, "dead_edges": dead_edges,
            "score_mean": score_mean, "score_max": score_max,
            "active_edges": int(ternary_edge.numel()),
        }

    def set_adjacency(self, edge_index, edge_attr_init=None):
        """Replace the graph with an explicit dense adjacency.

        Args:
            edge_index: ``[2, E]`` source/destination indices.
            edge_attr_init: optional per-edge values (``E`` elements); only their
                sign is kept. Random ternary weights are drawn when omitted.

        Raises:
            ValueError: if ``edge_index`` is not ``[2, E]`` or ``edge_attr_init``
                does not have ``E`` elements.
        """
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(f"edge_index must have shape [2, E], got {tuple(edge_index.shape)}")
        if edge_attr_init is not None and edge_attr_init.numel() != edge_index.size(1):
            raise ValueError(
                f"edge_attr_init has {edge_attr_init.numel()} elements, expected {edge_index.size(1)}"
            )
        self.use_active_edge_store = False
        device = self.edge_attr.device
        self.edge_index = edge_index.to(device=device, dtype=torch.int32)
        if edge_attr_init is not None:
            self.edge_attr = edge_attr_init.sign().to(device=device, dtype=torch.int8)
        else:
            self.edge_attr = torch.randint(-1, 2, (edge_index.size(1),),
                                           device=device, dtype=torch.int8)
        self.edge_score = self.edge_attr.clone()
|
|