"""Components — core neural network modules for the ARB system.""" import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _COMPONENT_CONTEXT, _HAS_TRITON try: from .kernel.ternary_scale import _TritonTernaryEmbedFn except ImportError: _TritonTernaryEmbedFn = None from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary from dataclasses import dataclass, field, fields from math import ceil as _ceil, log2 as _log2 from transformers import AutoModel, AutoFeatureExtractor from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, SPECIAL_VOCAB, CODEBOOK_DIM, CODEBOOK_SIZE, FFN_HIDDEN, CTX, THRESHOLD, KG_EMA_ALPHA, KG_REQUANT_EVERY, KG_TERNARY_THRESHOLD, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, KGVQ_DECAY, KGVQ_COMMITMENT_WEIGHT, KGVQ_DEAD_CODE_THRESHOLD, K_MAX_COMPOSITES, MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, MG_WORKSPACE_DIM, BYTEHEAD_ACT_MAX_ITERS, BYTEHEAD_ACT_HALT_CONSECUTIVE _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0 from .sequencers import ByteEmbedding @dataclass class LossWeights: lm: float = 1.0 vq_commitment: float = 1.0 moe_aux: float = 1.0 graph_l1: float = 0.001 graph_ponder: float = 1.0 moe_ponder: float = 1.0 moegraph_ponder: float = 1.0 memgram_decay_reg: float = 0.01 composite_vq: float = 1.0 @dataclass class LossComponents: lm: torch.Tensor = None vq_commitment: torch.Tensor = None moe_aux: torch.Tensor = None graph_l1: torch.Tensor = None graph_ponder: torch.Tensor = None moe_ponder: torch.Tensor = None moegraph_ponder: torch.Tensor = None memgram_decay_reg: torch.Tensor = None composite_vq: torch.Tensor = None weights: LossWeights = field(default_factory=LossWeights) @property def total(self) -> torch.Tensor: w = self.weights loss = None def add_component(current, weight, component): if component is None: return current weighted = weight * component return weighted if current is None else current + weighted loss = add_component(loss, w.lm, self.lm) loss = add_component(loss, w.vq_commitment, self.vq_commitment) loss = add_component(loss, w.moe_aux, self.moe_aux) loss = add_component(loss, w.graph_l1, self.graph_l1) loss = add_component(loss, w.graph_ponder, self.graph_ponder) loss = add_component(loss, w.moe_ponder, self.moe_ponder) loss = add_component(loss, w.moegraph_ponder, self.moegraph_ponder) loss = add_component(loss, w.memgram_decay_reg, self.memgram_decay_reg) loss = add_component(loss, w.composite_vq, self.composite_vq) if loss is None: raise ValueError("LossComponents.total requested with no active loss tensors") return loss @property def active_fields(self) -> list[tuple[str, torch.Tensor, float]]: result = [] for field in fields(self): name = field.name if name == 'weights': continue tensor = getattr(self, name) if tensor is not None: weight = getattr(self.weights, name) result.append((name, tensor, weight)) return result def log(self, writer, step, prefix="loss"): writer.add_scalar(f"{prefix}/total", self.total.item(), step) if self.lm is not None: writer.add_scalar(f"{prefix}/lm", self.lm.item(), step) if self.vq_commitment is not None: writer.add_scalar(f"{prefix}/vq_commitment", self.vq_commitment.item(), step) if self.moe_aux is not None: writer.add_scalar(f"{prefix}/moe_aux", self.moe_aux.item(), step) if self.graph_l1 is not None: writer.add_scalar(f"{prefix}/graph_l1", self.graph_l1.item(), step) if self.graph_ponder is not None: writer.add_scalar(f"{prefix}/graph_ponder", self.graph_ponder.item(), step) if self.moe_ponder is not None: writer.add_scalar(f"{prefix}/moe_ponder", self.moe_ponder.item(), step) if self.moegraph_ponder is not None: writer.add_scalar(f"{prefix}/moegraph_ponder", self.moegraph_ponder.item(), step) if self.memgram_decay_reg is not None: writer.add_scalar(f"{prefix}/memgram_decay_reg", self.memgram_decay_reg.item(), step) if self.composite_vq is not None: writer.add_scalar(f"{prefix}/composite_vq", self.composite_vq.item(), step) def backward(self, retain_graph=False): self.total.backward(retain_graph=retain_graph) class StickyZoneSTE(torch.autograd.Function): @staticmethod def forward(ctx, w, threshold): ctx.save_for_backward(w, torch.tensor(threshold)) return w.sign() * (w.abs() > threshold).to(w.dtype) @staticmethod def backward(ctx, grad_output): w, threshold_tensor = ctx.saved_tensors threshold = threshold_tensor.item() ratio = torch.clamp(w.abs() / threshold, 0.0, 1.0) return grad_output * ratio, None class TernaryEmbeddingTable(nn.Module): def __init__(self, num_embeddings, embedding_dim, tscale_type=TScaleType.T32, init_std=0.02, threshold=0.05, normalize=False): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.tscale_type = tscale_type init_threshold = min(float(threshold), 0.5 * float(init_std)) if init_std > 0 else threshold self.threshold = init_threshold self.normalize = normalize self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64]) self.sparse_threshold = 65_536 if num_embeddings >= self.sparse_threshold: n_trits = num_embeddings * embedding_dim n_packed = _ceil_div(n_trits, 5) packed_T = torch.randint(0, 243, (n_packed,), dtype=torch.uint8) T_pad = n_packed * 5 - n_trits gpr = _ceil_div(embedding_dim, self.group_size) init_exp = int(round(_log2(max(init_std, 1e-8)))) self.register_buffer("T_packed", packed_T) self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long)) self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long)) self.register_buffer( "E", torch.full((num_embeddings * gpr,), init_exp, dtype=torch.int8), ) self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8)) self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8)) self._ema_alpha: float = 0.1 self._loss_temp_scale: float = 1.0 return w_init = torch.randn(num_embeddings, embedding_dim) * init_std T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype) packed_T, _, T_pad = pack_ternary(T_init) self.register_buffer("T_packed", packed_T) self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long)) self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long)) gpr = _ceil_div(embedding_dim, self.group_size) total_in = gpr * self.group_size padded = torch.zeros(num_embeddings, total_in) padded[:, :embedding_dim] = w_init.abs() grouped = padded.view(num_embeddings, gpr, self.group_size) E_vals = torch.where(grouped.mean(dim=2) > 0, grouped.mean(dim=2), torch.ones(num_embeddings, gpr)) self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8)) self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8)) self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8)) self._ema_alpha: float = 0.1 self._loss_temp_scale: float = 1.0 def _get_T(self): return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())) def _get_T_rows(self, indices): indices = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long) dim = self.embedding_dim cols = torch.arange(dim, device=indices.device, dtype=torch.long) lin = indices[:, None] * dim + cols[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = self.T_packed[pack_idx].to(torch.long) divisors = torch.tensor([1, 3, 9, 27, 81], device=indices.device, dtype=torch.long) code = (packed // divisors[trit_pos]) % 3 return (code.to(torch.int8) - 1) def _expand_E_rows(self, indices): indices = indices.reshape(-1).to(device=self.E.device, dtype=torch.long) gpr = _ceil_div(self.embedding_dim, self.group_size) E_rows = self.E.view(self.num_embeddings, gpr)[indices] E_exp = E_rows.repeat_interleave(self.group_size, dim=1) return E_exp[:, :self.embedding_dim] @torch.no_grad() def _set_T_rows(self, row_indices, rows): row_indices = row_indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long) rows = rows.to(device=self.T_packed.device, dtype=torch.int8).reshape(row_indices.numel(), self.embedding_dim) divisors = [1, 3, 9, 27, 81] for row_pos, row_idx in enumerate(row_indices.tolist()): row = rows[row_pos] for col in range(self.embedding_dim): lin = row_idx * self.embedding_dim + col pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 divisor = divisors[trit_pos] old = int(self.T_packed[pack_idx].item()) old_code = (old // divisor) % 3 new_code = int(row[col].item()) + 1 if old_code != new_code: self.T_packed[pack_idx] = old - old_code * divisor + new_code * divisor def _expand_E(self): out_dim, in_dim = tuple(self._T_shape.tolist()) gpr = _ceil_div(in_dim, self.group_size) E_2d = self.E.view(out_dim, gpr) E_exp = E_2d.repeat_interleave(self.group_size, dim=1) return E_exp[:, :in_dim] def _ensure_E_accum(self): if not hasattr(self, "E_accum"): self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8)) elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device: self.E_accum = torch.zeros_like(self.E, dtype=torch.int8) return self.E_accum def forward(self, indices): use_sparse = self.num_embeddings >= self.sparse_threshold if use_sparse: idx_flat = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long) T_rows = self._get_T_rows(idx_flat) E_exp = self._expand_E_rows(idx_flat) w_eff = torch.exp2(E_exp.float()) * T_rows.float() w_eff_grad = w_eff.detach().requires_grad_(torch.is_grad_enabled()) if torch.is_grad_enabled(): comp_name, _ = _COMPONENT_CONTEXT.get() def capture_sparse_grad(grad): suffix = f"_{comp_name}" if comp_name is not None else "" setattr(self, f"_hook_sparse_indices{suffix}", idx_flat.detach()) setattr(self, f"_hook_sparse_grad_sign{suffix}", grad.reshape(-1, self.embedding_dim).sign().to(torch.int8).detach()) setattr(self, f"_hook_sparse_T{suffix}", T_rows.detach()) w_eff_grad.register_hook(capture_sparse_grad) out = w_eff_grad.reshape(*indices.shape, self.embedding_dim) return F.normalize(out, dim=-1) if self.normalize else out if indices.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None: dummy = torch.zeros(1, device=indices.device, requires_grad=True) out = _TritonTernaryEmbedFn.apply(indices, dummy, self) else: T = self._get_T() w_eff = torch.exp2(self._expand_E().float()) * T.float() w_eff_grad = w_eff.detach().requires_grad_(True) self._hook_T = T def capture_w_grad(grad_w): self._hook_grad_T_sign = grad_w.sign().to(torch.int8) w_eff_grad.register_hook(capture_w_grad) out = F.embedding(indices, w_eff_grad) return F.normalize(out, dim=-1) if self.normalize else out def ternary_step(self, accum_threshold=3): if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"): return self._sparse_ternary_step(accum_threshold=accum_threshold) if hasattr(self, "_hook_grad_T_sign"): if hasattr(self, "_accumulate_corr_from_grad_sign"): self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign) del self._hook_grad_T_sign def update_E(self, loss_signal=None): if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"): return self._sparse_update_E(loss_signal=loss_signal) @torch.no_grad() def _sparse_ternary_step(self, accum_threshold=3): indices = self._hook_sparse_indices.to(device=self.T_accum.device, dtype=torch.long) grad_sign = self._hook_sparse_grad_sign.to(device=self.T_accum.device, dtype=torch.int16) if indices.numel() == 0: return unique, inverse = torch.unique(indices, return_inverse=True) grad_sum = torch.zeros(unique.numel(), self.embedding_dim, device=self.T_accum.device, dtype=torch.int16) grad_sum.index_add_(0, inverse, grad_sign) grad_step = grad_sum.sign().to(torch.int16) * int(getattr(self, "_t_accum_step", 1)) current = self.T_accum[unique].to(torch.int16) updated = torch.clamp(current - grad_step, -128, 127).to(torch.int8) pgt = getattr(self, "per_group_threshold", None) if pgt is not None: gpr = _ceil_div(self.embedding_dim, self.group_size) threshold = pgt.view(self.num_embeddings, gpr)[unique] threshold = threshold.unsqueeze(-1).expand(unique.numel(), gpr, self.group_size) threshold = threshold.reshape(unique.numel(), gpr * self.group_size)[:, :self.embedding_dim] threshold = threshold.to(updated.device) flip_up = updated > threshold flip_down = updated < -threshold else: flip_up = updated > accum_threshold flip_down = updated < -accum_threshold self._had_flip = bool((flip_up | flip_down).any().item()) if self._had_flip: rows = self._get_T_rows(unique).to(updated.device) rows = torch.where(flip_up, torch.ones_like(rows), torch.where(flip_down, -torch.ones_like(rows), rows)) self._set_T_rows(unique, rows) updated = torch.where(flip_up | flip_down, torch.zeros_like(updated), updated) self.T_accum[unique] = updated del self._hook_sparse_indices del self._hook_sparse_grad_sign if hasattr(self, "_hook_sparse_T"): del self._hook_sparse_T @torch.no_grad() def _sparse_update_E(self, loss_signal=None): indices = self._hook_sparse_indices.to(device=self.E.device, dtype=torch.long) grad_sign = self._hook_sparse_grad_sign.to(device=self.E.device, dtype=torch.int16) T_rows = self._hook_sparse_T if hasattr(self, "_hook_sparse_T") else self._get_T_rows(indices) T_rows = T_rows.to(device=self.E.device, dtype=torch.int16) if indices.numel() == 0: return unique, inverse = torch.unique(indices, return_inverse=True) gpr = _ceil_div(self.embedding_dim, self.group_size) total_in = gpr * self.group_size signed = grad_sign * T_rows grouped = F.pad(signed, (0, total_in - self.embedding_dim)).view(indices.numel(), gpr, self.group_size) score = grouped.sum(dim=2) delta = torch.where( score > 0, torch.full_like(score, -1, dtype=torch.int16), torch.where(score < 0, torch.ones_like(score, dtype=torch.int16), torch.zeros_like(score, dtype=torch.int16)), ) delta_sum = torch.zeros(unique.numel(), gpr, device=self.E.device, dtype=torch.int16) delta_sum.index_add_(0, inverse, delta) delta_sign = delta_sum.sign() e_idx = unique[:, None] * gpr + torch.arange(gpr, device=self.E.device, dtype=torch.long)[None, :] accum = torch.clamp(self.E_accum[e_idx].to(torch.int16) + delta_sign, -128, 127) threshold = int(getattr(self, "_e_accum_threshold", 4)) step = torch.where( accum >= threshold, torch.ones_like(accum, dtype=torch.int16), torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)), ) self.E[e_idx] = torch.clamp(self.E[e_idx].to(torch.int16) + step, -128, 127).to(torch.int8) self.E_accum[e_idx] = (accum - step * threshold).to(torch.int8) class TernaryVQCodebook(nn.Module): def __init__(self, codebook_size, codebook_dim, commitment_weight=1.0, tscale_type=TScaleType.T32, exact_lookup_max=16384, candidate_count=256): super().__init__() self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.commitment_weight = commitment_weight self.exact_lookup_max = exact_lookup_max self.candidate_count = candidate_count self.threshold_ema_dead_code = 2 self.table = TernaryEmbeddingTable(codebook_size, codebook_dim, tscale_type=tscale_type, normalize=True) self.register_buffer("cluster_size", torch.zeros(codebook_size, dtype=torch.int16)) @property def embed(self): idx = torch.arange(self.codebook_size, device=self.table.T_packed.device) return self.table(idx) def _candidate_ids(self, flat): c = min(self.candidate_count, self.codebook_size) take = min(flat.shape[1], 16) primes = torch.tensor( [1009, 9176, 6361, 5333, 4447, 3469, 2531, 1613, 811, 421, 211, 109, 59, 31, 17, 7], device=flat.device, dtype=torch.float32, )[:take] signed = torch.sign(flat[:, :take].float()) base = torch.abs(torch.round((signed * primes).sum(dim=1) * 104729)).to(torch.long) offsets = torch.arange(c, device=flat.device, dtype=torch.long) stride = 2_654_435_761 return (base[:, None] + offsets[None, :] * stride) % self.codebook_size def _lookup(self, flat): if self.codebook_size <= self.exact_lookup_max: x_norm = F.normalize(flat.float(), dim=-1) codebook = self.embed.to(device=flat.device) sim = x_norm @ codebook.T indices = sim.argmax(dim=-1) quantized = codebook[indices] return quantized, indices candidate_ids = self._candidate_ids(flat) x_norm = F.normalize(flat.float(), dim=-1) n, c, d = flat.shape[0], candidate_ids.shape[1], flat.shape[1] chunk = 64 quantized = torch.empty_like(flat) indices = torch.empty(n, dtype=torch.long, device=flat.device) for start in range(0, n, chunk): end = min(start + chunk, n) chunk_ids = candidate_ids[start:end] chunk_vecs = self.table(chunk_ids).float() chunk_norm = F.normalize(chunk_vecs, dim=-1) chunk_sim = (chunk_norm * x_norm[start:end].unsqueeze(1)).sum(dim=-1) chunk_best = chunk_sim.argmax(dim=-1) indices[start:end] = candidate_ids[start:end].gather(1, chunk_best.unsqueeze(1)).squeeze(1) quantized[start:end] = chunk_vecs[torch.arange(end - start, device=flat.device), chunk_best] return quantized, indices def forward(self, x): orig_shape = x.shape flat = x.reshape(-1, self.codebook_dim) quantized, indices = self._lookup(flat) commitment = self.commitment_weight * ( F.mse_loss(flat.float(), quantized.detach().float()) + 0.25 * F.mse_loss(quantized.float(), flat.detach().float()) ) quantized = flat + (quantized - flat).detach() with torch.no_grad(): unique, counts = torch.unique(indices, return_counts=True) current = self.cluster_size[unique].to(torch.int32) updated = torch.clamp(current + counts.to(device=current.device, dtype=torch.int32), 0, 32767).to(torch.int16) self.cluster_size[unique] = updated return quantized.reshape(orig_shape), indices.reshape(orig_shape[:-1]), commitment class GNNLoRAAdapter(nn.Module): def __init__(self, dim, rank=32, max_hops=4): super().__init__() self.max_hops = max_hops self.down = TernaryScaleTensor(dim, rank, tscale_type=TScaleType.T32) self.up = TernaryScaleTensor(rank, dim, tscale_type=TScaleType.T32) self.scale = TernaryEmbeddingTable(max_hops, rank, tscale_type=TScaleType.T32) def forward(self, x, hop_t): t_idx = min(hop_t, self.max_hops - 1) s = self.scale(torch.tensor(t_idx, device=x.device)) return self.up(self.down(x) * s) class HaltingUnit(nn.Module): def __init__(self, dim, tscale_type=TScaleType.T32): super().__init__() self.proj = TernaryScaleTensor(dim, 1, tscale_type=tscale_type) self.norm = TernaryRMSNorm(dim, tscale_type=tscale_type) def forward(self, x): return torch.sigmoid(self.proj(self.norm(x))) class _NgramHashMapping: """N-gram hash mapping with CPU offloading (Spider Engram style). Hashes token sequences to fixed-size embedding indices. Hash computation runs on CPU via numpy, O(1) per token via precomputed multipliers. """ def __init__(self, max_ngram_size, num_heads, table_size_base, layer_seed=0): self.max_ngram_size = max_ngram_size self.num_heads = num_heads self.num_ngram_orders = max_ngram_size - 1 import numpy as np PRIME_1 = 10007 g = torch.Generator() g.manual_seed(int(layer_seed + PRIME_1 * int(layer_seed))) r = torch.randint(0, 1 << 30, (max_ngram_size,), generator=g, dtype=torch.int64) self.multipliers = r.numpy() * 2 + 1 seen_primes = set() self.prime_table_sizes = [] for _ in range(self.num_ngram_orders): head_sizes = [] ps = table_size_base - 1 for _ in range(num_heads): p = self._next_prime(ps, seen_primes) seen_primes.add(p) head_sizes.append(p) ps = p self.prime_table_sizes.append(head_sizes) self.all_head_sizes = [s for sub in self.prime_table_sizes for s in sub] offsets = [0] for s in self.all_head_sizes[:-1]: offsets.append(offsets[-1] + s) self.offsets_arr = offsets self.total_slots = sum(self.all_head_sizes) @staticmethod def _next_prime(n, seen): while n in seen or not _is_prime(n): n -= 1 return n def compute_hashes(self, token_ids): import numpy as np x = token_ids.cpu().numpy().astype(np.int64) B, T = x.shape shifts = [x] for k in range(1, self.max_ngram_size): shifts.append(np.pad(x, ((0, 0), (k, 0)), constant_values=0)[:, :T]) all_hashes = [] for order_idx in range(self.num_ngram_orders): n = order_idx + 2 mix = shifts[0] * self.multipliers[0] for k in range(1, n): mix = np.bitwise_xor(mix, shifts[k].astype(np.int64) * self.multipliers[k]) for j, ms in enumerate(self.prime_table_sizes[order_idx]): all_hashes.append((mix % ms).astype(np.int64, copy=False)) result = np.stack(all_hashes, axis=2) return torch.from_numpy(result).to(device=token_ids.device) def _is_prime(n): if n < 2: return False import math for i in range(2, int(math.sqrt(n)) + 1): if n % i == 0: return False return True class MemGram(nn.Module): """Engram-style associative memory with O(1) hashed lookup (CPU offloaded). Features: - O(1) hash -> index -> embedding lookup (no search, no decay for retrieval) - CPU-offloaded hash computation (numpy) - Single offset-stacked embedding table (not per-head tables) - Gated retrieval: sigmoid(Q*K/sqrt(d)) gates the memory read - Depthwise conv1d processes retrieved memory (Engram-style) - No strength/decay buffers (decay is handled by GraphMoE usage frequency) - MemGram lookups do NOT affect KG decaying (separate mechanisms) """ def __init__(self, struct_primes=[64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997, 65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071], conv_primes=[8009, 8011, 8017, 8039], embed_dim=64, hidden_dim=HIDDEN_DIM, key_dim=32, max_ngram_size=3, num_hash_heads=4, layer_seed=0): super().__init__() self.embed_dim = embed_dim self.key_dim = key_dim self.hidden_dim = hidden_dim self.n_struct_heads = len(struct_primes) self.n_conv_heads = len(conv_primes) self.struct_hash = _NgramHashMapping( max_ngram_size=max_ngram_size, num_heads=num_hash_heads, table_size_base=struct_primes[0], layer_seed=layer_seed, ) self.conv_hash = _NgramHashMapping( max_ngram_size=max_ngram_size, num_heads=num_hash_heads, table_size_base=conv_primes[0], layer_seed=layer_seed + 1000, ) total_heads = self.struct_hash.num_ngram_orders * num_hash_heads self.total_mem_dim = total_heads * embed_dim total_slots = self.struct_hash.total_slots + self.conv_hash.total_slots self.mem_embed = nn.Embedding(total_slots, embed_dim) self.k_proj = nn.Linear(self.total_mem_dim, key_dim, bias=False) self.q_proj = nn.Linear(hidden_dim, key_dim, bias=False) self.v_proj = nn.Linear(self.total_mem_dim, hidden_dim, bias=False) with torch.no_grad(): self.v_proj.weight.zero_() self.conv_norm = nn.RMSNorm(hidden_dim) self.conv = nn.Conv1d( hidden_dim, hidden_dim, kernel_size=4, padding=9, dilation=3, groups=hidden_dim, ) with torch.no_grad(): self.conv.weight.zero_() if self.conv.bias is not None: self.conv.bias.zero_() def _retrieve(self, token_ids, hash_mapping): hash_ids = hash_mapping.compute_hashes(token_ids) B, T, H = hash_ids.shape flat_ids = hash_ids.reshape(B * T, H) offsets = torch.tensor(hash_mapping.offsets_arr, device=flat_ids.device, dtype=torch.long) emb = self.mem_embed(flat_ids + offsets) return emb.reshape(B, T, H * self.embed_dim) def forward(self, vq_indices, hidden_state): B, T, D = hidden_state.shape struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash) conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash) mem = struct_mem + conv_mem idx_end = mem.shape[1] q_proj = self.q_proj(hidden_state[:, :idx_end]) k = self.k_proj(mem) v = self.v_proj(mem) gate = torch.sigmoid((q_proj * k).sum(dim=-1, keepdim=True) / (self.key_dim ** 0.5)) v_gated = gate * v v_normed = self.conv_norm(v_gated) v_t = v_normed.transpose(1, 2) conv_out = self.conv(v_t) conv_out = conv_out[:, :, :v_t.shape[-1]].transpose(1, 2) output = hidden_state[:, :idx_end] + F.silu(conv_out) + v_gated if idx_end < T: output = F.pad(output, (0, 0, 0, T - idx_end)) return output def retrieve_cb(self, vq_indices): B, T = vq_indices.shape struct_mem = self._retrieve(vq_indices[:, 1:], self.struct_hash) conv_mem = self._retrieve(vq_indices[:, 1:], self.conv_hash) mem = struct_mem + conv_mem idx_end = mem.shape[1] pad = torch.zeros(B, T - idx_end, mem.shape[2], device=mem.device) mem = torch.cat([mem, pad], dim=1) q = mem.mean(dim=-1, keepdim=True) gate = torch.sigmoid(q) return gate * mem _BOUNDARY_TOKEN_MAP = { SPECIAL_VOCAB['BOS']: 0, SPECIAL_VOCAB['SYSTEM']: 1, SPECIAL_VOCAB['USER']: 2, SPECIAL_VOCAB['ASSISTANT']: 3, } class LTIInjection(nn.Module): """LTI state injection: h = A*h + B*e + trans_out. Spectral radius < 1 guaranteed by construction via ZOH discretization. Prevents divergence in recurrent/ACT loops at high dimensions. """ def __init__(self, dim: int): super().__init__() self.log_A = nn.Parameter(torch.zeros(dim)) self.log_dt = nn.Parameter(torch.zeros(1)) self.B = nn.Parameter(torch.ones(dim) * 0.1) for p in (self.log_A, self.log_dt, self.B): p.requires_grad_(False) def get_A(self): return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))) def forward(self, h, e, trans_out): return self.get_A() * h + self.B * e + trans_out class ByteHead(nn.Module): """Deep 3-layer MLP byte prediction head with ACT loop. Architecture: 8192 → 16384 → 8192 → 16384 → 288 ACT: up to 3 iterations, halts when argmax stable for 2 consecutive steps. """ def __init__(self, tscale_type=TScaleType.T32, act_max_iters=BYTEHEAD_ACT_MAX_ITERS, act_halt_consecutive=BYTEHEAD_ACT_HALT_CONSECUTIVE): super().__init__() H = HIDDEN_DIM W = HIDDEN_DIM * 2 self.act_max_iters = act_max_iters self.act_halt_consecutive = act_halt_consecutive self._last_ponder = 0.0 self.norm = TernaryRMSNorm(H, tscale_type=tscale_type) self.up = TernaryScaleTensor(H, W, tscale_type=tscale_type) self.up_norm = TernaryRMSNorm(W, tscale_type=tscale_type) self.hidden = TernaryScaleTensor(W, H, tscale_type=tscale_type) self.hidden_norm = TernaryRMSNorm(H, tscale_type=tscale_type) self.out = TernaryScaleTensor(H, W, tscale_type=tscale_type) self.out_norm = TernaryRMSNorm(W, tscale_type=tscale_type) self.head = TernaryScaleTensor(W, VOCAB, tscale_type=tscale_type) if act_max_iters > 1: self.act_residual = TernaryScaleTensor(VOCAB, H, tscale_type=tscale_type) self.lti = LTIInjection(H) else: self.act_residual = None self.lti = None def forward(self, x): if self.act_max_iters <= 1 or self.act_residual is None: hn = F.silu(self.up(self.norm(x))) hn = F.silu(self.hidden(self.up_norm(hn))) hn = F.silu(self.out(self.hidden_norm(hn))) return self.head(self.out_norm(hn)) h = x x_initial = x prev_argmax = None stable_count = 0 total_iters = 0 for i in range(self.act_max_iters): hn = F.silu(self.up(self.norm(h))) hn = F.silu(self.hidden(self.up_norm(hn))) hn = F.silu(self.out(self.hidden_norm(hn))) logits = self.head(self.out_norm(hn)) curr_argmax = logits.argmax(dim=-1) if prev_argmax is not None and (curr_argmax == prev_argmax).all(): stable_count += 1 else: stable_count = 0 total_iters = i + 1 if stable_count >= self.act_halt_consecutive: break prev_argmax = curr_argmax trans_out = self.act_residual(logits) h = self.lti(h, x_initial, trans_out) self._last_ponder = total_iters / max(self.act_max_iters, 1) return logits class OutputRouter(nn.Module): """Routes HIDDEN_DIM relational tokens to ByteHead, VideoHead, or TalkerHead. 3-layer MLP when depth=3, 2-layer when depth=2, single projection when depth=1. Argmax at inference, soft weighted routing at training. """ def __init__(self, tscale_type=TScaleType.T32, depth=3): super().__init__() if depth >= 3: self.hidden1 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type) self.hidden1_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type) self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type) elif depth == 2: self.hidden1 = None self.hidden1_norm = None self.hidden2 = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM // 4, tscale_type=tscale_type) self.gate = TernaryScaleTensor(HIDDEN_DIM // 4, 4, tscale_type=tscale_type) else: self.hidden1 = None self.hidden1_norm = None self.hidden2 = None self.gate = TernaryScaleTensor(HIDDEN_DIM, 4, tscale_type=tscale_type) # 0 = Null (continue), 1 = ByteHead, 2 = VideoHead, 3 = TalkerHead def forward(self, x, training=False): h = x if self.hidden1 is not None: h = F.silu(self.hidden1_norm(self.hidden1(h))) if self.hidden2 is not None: h = self.hidden2(h) logits = self.gate(h) # [B, T, 4] logits = torch.nan_to_num(logits, nan=0.0, posinf=30.0, neginf=-30.0).clamp(-30.0, 30.0) if training: weights = F.softmax(logits, dim=-1) return weights, logits return logits.argmax(dim=-1) class KGVQCodebook(TernaryVQCodebook): """Compatibility wrapper for the KG/composite VQ. The old implementation kept float32 `embed` and `embed_avg` buffers. The production path now uses the same packed ternary/int8 backing table as the shared VQ so default 5M-code KG construction cannot allocate hidden float codebook state. """ def __init__(self, codebook_size=KGVQ_CODEBOOK_SIZE, codebook_dim=KGVQ_CODEBOOK_DIM, decay=KGVQ_DECAY, commitment_weight=KGVQ_COMMITMENT_WEIGHT, threshold_ema_dead_code=KGVQ_DEAD_CODE_THRESHOLD): super().__init__( codebook_size=codebook_size, codebook_dim=codebook_dim, commitment_weight=commitment_weight, ) self.decay = decay self.threshold_ema_dead_code = threshold_ema_dead_code @property def embed(self): if self.codebook_size > self.exact_lookup_max: raise RuntimeError( "Full KG VQ materialization is disabled for large ternary codebooks; " "query rows through `table(indices)` instead." ) return super().embed def _ema_update(self, x_flat, indices): unique, counts = torch.unique(indices, return_counts=True) current = self.cluster_size[unique].to(torch.int32) updated = torch.clamp( current + counts.to(device=current.device, dtype=torch.int32), 0, 32767, ).to(torch.int16) self.cluster_size[unique] = updated def _dead_code_reset(self, x_flat): return None class CompositeProposalHead(nn.Module): """Multi-proposal head from pooled GNN output (Phase 17). Projects GNN pool output (graph_pool_out [B, D]) to K_MAX composite motif proposals, quantizes via KGVQ, and applies ACT-style halting. """ def __init__(self, dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM, k_max=K_MAX_COMPOSITES, codebook_size=KGVQ_CODEBOOK_SIZE, tscale_type=TScaleType.T32): super().__init__() self.dim = dim self.k_max = k_max self.codebook_dim = codebook_dim self.proj = TernaryScaleTensor(dim, k_max * codebook_dim, tscale_type=tscale_type) self.kgvq = TernaryVQCodebook(codebook_size=codebook_size, codebook_dim=codebook_dim, tscale_type=tscale_type) self.halt_gate = TernaryScaleTensor(dim, k_max, tscale_type=tscale_type) self.diversity_weight = 0.1 def forward(self, pool_out): B = pool_out.shape[0] projections = self.proj(pool_out).view(B, self.k_max, self.codebook_dim) quantized, composite_ids, vq_loss = self.kgvq(projections) halt_logits = self.halt_gate(pool_out).clamp(-12.0, 12.0) halt = torch.sigmoid(halt_logits) # [B, K_MAX] composite_ids = composite_ids.masked_fill(halt < 0.5, -1) normed = F.normalize(projections, dim=-1) sim_matrix = normed @ normed.transpose(-1, -2) triu = torch.triu(sim_matrix, diagonal=1) n_pairs = self.k_max * (self.k_max - 1) / 2 diversity_loss = triu.sum(dim=(-1, -2)).mean() / max(n_pairs, 1) diversity_loss = diversity_loss * self.diversity_weight return composite_ids, vq_loss + diversity_loss, halt class MoEGraph(nn.Module): """Fused graph traversal + centroid-based MoE routing + ACT halting. Each ACT iteration: traverse KG → aggregate neighbor emb → centroid route → run expert → halt check. All operations at MG_WORKSPACE_DIM (1024). Replaces: TernaryGraph + GraphMoEGate + GraphACTCell + SharedProjectionMoE + MoEACTCell. """ def __init__(self, cb_dim=MG_WORKSPACE_DIM, trigram_dim=HIDDEN_DIM, codebook_dim=CODEBOOK_DIM, num_experts=MG_N_EXPERTS, core_rank=MG_CORE_RANK, shared_inter=MG_SHARED_INTER, max_iters=MG_ACT_ITERS, halt_threshold=0.99, tscale_type=TScaleType.T32, codebook_size=CODEBOOK_SIZE, active_graph_max_nodes=4096, top_k=1): super().__init__() self.cb_dim = cb_dim self.trigram_dim = trigram_dim self.codebook_dim = codebook_dim self.num_experts = num_experts self.core_rank = core_rank self.shared_inter = shared_inter self.max_iters = max_iters self.halt_threshold = halt_threshold self.codebook_size = codebook_size self.active_graph_max_nodes = active_graph_max_nodes self.top_k = top_k self.down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type) self.down_norm = TernaryRMSNorm(trigram_dim, tscale_type=tscale_type) self.up_proj = TernaryScaleTensor(cb_dim, trigram_dim, tscale_type=tscale_type) self.up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type) self.attn_down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type) self.codebook_up = TernaryScaleTensor(codebook_dim, cb_dim, tscale_type=tscale_type) self.use_active_edge_store = self.codebook_size > self.active_graph_max_nodes self.active_edge_capacity = max(int(self.active_graph_max_nodes) * 16, 65_536) if self.use_active_edge_store: self.register_buffer("edge_index", torch.zeros(2, 0, dtype=torch.int32)) self.register_buffer("edge_attr", torch.zeros(0, dtype=torch.int8)) self.register_buffer("edge_score", torch.zeros(0, dtype=torch.int8)) self.register_buffer("active_edge_src", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32)) self.register_buffer("active_edge_dst", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32)) self.register_buffer("active_edge_attr", torch.zeros(self.active_edge_capacity, dtype=torch.int8)) self.register_buffer("active_edge_score", torch.zeros(self.active_edge_capacity, dtype=torch.int8)) self.register_buffer("active_edge_ptr", torch.zeros((), dtype=torch.long)) else: num_edges = self.codebook_size * 10 src = torch.arange(self.codebook_size, dtype=torch.int32).repeat_interleave(10) dst = torch.randint(0, self.codebook_size, (num_edges,), dtype=torch.int32) self.register_buffer("edge_index", torch.stack([src, dst], dim=0)) edge_init = torch.randint(-1, 2, (num_edges,), dtype=torch.int8) self.register_buffer("edge_attr", edge_init) self.register_buffer("edge_score", torch.zeros(num_edges, dtype=torch.int8)) self.register_buffer("_steps_since_requant", torch.tensor(0, dtype=torch.long)) self.requant_every = KG_REQUANT_EVERY self.kg_ternary_threshold = KG_TERNARY_THRESHOLD self.kg_ema_alpha = KG_EMA_ALPHA self.centroids = TernaryEmbeddingTable(num_experts, cb_dim, tscale_type=tscale_type, normalize=True) self.shared_up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type) self.shared_up = TernaryScaleTensor(cb_dim, shared_inter, tscale_type=tscale_type) self.shared_down_norm = TernaryRMSNorm(shared_inter, tscale_type=tscale_type) self.shared_down = TernaryScaleTensor(shared_inter, cb_dim, tscale_type=tscale_type) self.W_gate = nn.ModuleList([ TernaryScaleTensor(cb_dim, core_rank, tscale_type=tscale_type) for _ in range(num_experts) ]) self.W_gate_norms = nn.ModuleList([ TernaryRMSNorm(cb_dim, tscale_type=tscale_type) for _ in range(num_experts) ]) self.W_transform = nn.ModuleList([ TernaryScaleTensor(core_rank, shared_inter, tscale_type=tscale_type) for _ in range(num_experts) ]) self.W_transform_norms = nn.ModuleList([ TernaryRMSNorm(core_rank, tscale_type=tscale_type) for _ in range(num_experts) ]) self.hop_lora = GNNLoRAAdapter(dim=cb_dim, rank=32, max_hops=max_iters) self.halting = HaltingUnit(dim=cb_dim, tscale_type=tscale_type) self.lti = LTIInjection(cb_dim) self._codebook_embed = None self._codebook_table = None def _codebook_tensor(self, device): if self._codebook_table is not None: idx = torch.arange(self.codebook_size, device=device) codebook = self._codebook_table(idx) if codebook.shape[-1] != self.cb_dim: codebook = self.codebook_up(codebook) return codebook if self._codebook_embed is not None: codebook = self._codebook_embed.to(device=device).squeeze(0) if codebook.shape[-1] != self.cb_dim: codebook = self.codebook_up(codebook) return codebook return torch.zeros(self.codebook_size, self.cb_dim, device=device) def _active_codebook_features(self, vq_indices): if self._codebook_table is not None: safe_idx = vq_indices.clamp(min=0, max=self.codebook_size - 1) active_code = self._codebook_table(safe_idx) elif self._codebook_embed is not None: codebook = self._codebook_embed.to(device=vq_indices.device).squeeze(0) safe_idx = vq_indices.clamp(min=0, max=codebook.shape[0] - 1) active_code = codebook[safe_idx] else: return torch.zeros(*vq_indices.shape, self.cb_dim, device=vq_indices.device) if active_code.shape[-1] != self.cb_dim: active_code = self.codebook_up(active_code) return active_code def _neighbor_aggregate(self, node_features, threshold): N, D = node_features.shape aggregated = torch.zeros(self.codebook_size, D, device=node_features.device, dtype=node_features.dtype) edge_ternary = StickyZoneSTE.apply(self.edge_attr, threshold) src_features = node_features[self.edge_index[0]] messages = edge_ternary.unsqueeze(1).to(node_features.dtype) * src_features dst_idx = self.edge_index[1].unsqueeze(1).expand(-1, D) aggregated.scatter_add_(0, dst_idx, messages) return aggregated def _run_expert_batch(self, x, expert_idx): B, T, D = x.shape N = B * T x_flat = rearrange(x, 'b t d -> (b t) d') exp_flat = rearrange(expert_idx, 'b t -> (b t)') shared_hidden = F.silu(self.shared_up(self.shared_up_norm(x_flat))) sort_idx = exp_flat.argsort() sorted_experts = exp_flat[sort_idx] expert_counts = torch.bincount(sorted_experts, minlength=self.num_experts) expert_boundaries = torch.cumsum(expert_counts, dim=0) out_flat = torch.zeros(N, D, device=x.device, dtype=x.dtype) for e in range(self.num_experts): start = expert_boundaries[e] - expert_counts[e] end = expert_boundaries[e] if start == end: continue tok_idx = sort_idx[start:end] inp = x_flat[tok_idx] sh = shared_hidden[tok_idx] gate = self.W_gate[e](self.W_gate_norms[e](inp)) core = self.W_transform[e](self.W_transform_norms[e](gate)) expert_out = self.shared_down(self.shared_down_norm(core * sh)) out_flat[tok_idx] = expert_out return rearrange(out_flat, '(b t) d -> b t d', b=B, t=T) def _run_expert(self, x, expert_idx): return self._run_expert_batch(x, expert_idx) def _active_node_add(self, vq_output, vq_indices): return vq_output + self._active_codebook_features(vq_indices) def forward(self, trigram_input, vq_indices, attention_output=None, memgram_cb_output=None, threshold=0.05): B, T, D = trigram_input.shape device = trigram_input.device x = self.down_proj(self.down_norm(trigram_input)) attn_cb = None if attention_output is not None: attn_cb = self.attn_down_proj(self.down_norm(attention_output)) halted = torch.zeros(B, T, device=device, dtype=torch.bool) cumulative_p = torch.zeros(B, T, device=device) acc = torch.zeros_like(x) total_ponder = torch.zeros(B, T, device=device) last_x = x initial_x = x use_active_graph = self.codebook_size > self.active_graph_max_nodes node_features = None if use_active_graph else self._codebook_tensor(device) for iter_t in range(self.max_iters): if use_active_graph: traversal = self._active_node_add(x, vq_indices) else: node_aggregated = self._neighbor_aggregate(node_features, threshold) traversal = x + node_aggregated[vq_indices] if attn_cb is not None: traversal = traversal + attn_cb if iter_t in [1, 3] and memgram_cb_output is not None: memgram_raw = memgram_cb_output.to(device) if memgram_raw.shape[-1] != self.cb_dim: memgram_raw = memgram_raw.mean(dim=-1, keepdim=True).expand(-1, -1, self.cb_dim) traversal = traversal + memgram_raw traversal = traversal + self.hop_lora(traversal, iter_t) trav_norm = F.normalize(traversal, dim=-1, eps=1e-8) centroid_ids = torch.arange(self.num_experts, device=device) cent_norm = F.normalize(self.centroids(centroid_ids), dim=-1, eps=1e-8) scores = trav_norm @ cent_norm.T if self.top_k <= 1: _, expert_idx = scores.max(dim=-1) expert_out = self._run_expert(traversal, expert_idx) else: scores_topk, topk_idx = scores.topk(k=self.top_k, dim=-1) weights = F.softmax(scores_topk / 0.1, dim=-1) expert_out = 0 for i in range(self.top_k): wi = weights[..., i:i+1] ei = topk_idx[..., i] expert_out = expert_out + wi * self._run_expert(traversal, ei) last_x = expert_out p = self.halting(expert_out).squeeze(-1) still_running = ~halted remainder = (1.0 - cumulative_p).clamp(min=0) weight = torch.where( cumulative_p + p >= self.halt_threshold, remainder, p, ) weight = weight * still_running.float() acc = acc + weight.unsqueeze(-1) * expert_out cumulative_p = cumulative_p + p * still_running.float() halted = halted | (cumulative_p >= self.halt_threshold) total_ponder = total_ponder + (1.0 - cumulative_p).clamp(min=0) x = self.lti(x, initial_x, expert_out) if halted.all(): break never_halted = (~halted).float().unsqueeze(-1) acc = acc + never_halted * last_x output = self.up_proj(self.up_norm(acc)) ponder_loss = total_ponder.mean() / self.max_iters return output, ponder_loss @torch.no_grad() def update_kg_edges(self, all_vq_indices): if self.use_active_edge_store: self._update_active_edges(all_vq_indices) return unique_ids = torch.unique(all_vq_indices.to(device=self.edge_index.device, dtype=torch.int32)) src_in_batch = torch.isin(self.edge_index[0], unique_ids) if src_in_batch.any(): dst_seen = torch.isin(self.edge_index[1][src_in_batch], unique_ids) delta = torch.where( dst_seen, torch.ones_like(self.edge_score[src_in_batch], dtype=torch.int16), torch.full_like(self.edge_score[src_in_batch], -1, dtype=torch.int16), ) score = torch.clamp(self.edge_score[src_in_batch].to(torch.int16) + delta, -128, 127) self.edge_score[src_in_batch] = score.to(torch.int8) self._requantize_dense_edges() @torch.no_grad() def _update_active_edges(self, all_vq_indices): ids = all_vq_indices.to(device=self.active_edge_src.device, dtype=torch.int32) if ids.numel() < 2: self._steps_since_requant.add_(1) return seq = ids.reshape(-1, ids.shape[-1]) if ids.dim() > 1 else ids.reshape(1, -1) src = seq[:, :-1].reshape(-1) dst = seq[:, 1:].reshape(-1) valid = (src >= 0) & (dst >= 0) & (src < self.codebook_size) & (dst < self.codebook_size) & (src != dst) src = src[valid] dst = dst[valid] if src.numel() == 0: self._steps_since_requant.add_(1) return n_edges = min(src.numel(), self.active_edge_capacity) src = src[-n_edges:] dst = dst[-n_edges:] ptr = int(self.active_edge_ptr.item()) slots = (torch.arange(n_edges, device=src.device, dtype=torch.long) + ptr) % self.active_edge_capacity self.active_edge_src[slots] = src self.active_edge_dst[slots] = dst score = torch.clamp(self.active_edge_score[slots].to(torch.int16) + 1, -128, 127) self.active_edge_score[slots] = score.to(torch.int8) self.active_edge_attr[slots] = 1 self.active_edge_ptr.fill_((ptr + n_edges) % self.active_edge_capacity) self._requantize_active_edges() @torch.no_grad() def _requantize_dense_edges(self): if self._steps_since_requant.item() < self.requant_every: self._steps_since_requant.add_(1) return self.edge_attr = self._score_to_attr(self.edge_score) score = self.edge_score.to(torch.int16) score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score)) self.edge_score = score.to(torch.int8) self._steps_since_requant.zero_() @torch.no_grad() def _requantize_active_edges(self): if self._steps_since_requant.item() < self.requant_every: self._steps_since_requant.add_(1) return active = self.active_edge_src >= 0 if active.any(): self.active_edge_attr[active] = self._score_to_attr(self.active_edge_score[active]) score = self.active_edge_score[active].to(torch.int16) score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score)) self.active_edge_score[active] = score.to(torch.int8) self._steps_since_requant.zero_() def _score_to_attr(self, score): threshold = max(1, int(round(float(self.kg_ternary_threshold) * 8))) score_i = score.to(torch.int16) return torch.where( score_i >= threshold, torch.ones_like(score, dtype=torch.int8), torch.where( score_i <= -threshold, torch.full_like(score, -1, dtype=torch.int8), torch.zeros_like(score, dtype=torch.int8), ), ) @torch.no_grad() def monitor_graph_health(self, threshold=0.05): if self.use_active_edge_store: active = self.active_edge_src >= 0 if not active.any(): return { "sparsity": 1.0, "isolated_nodes": self.codebook_size, "avg_polarity": 0.0, "dead_edges": 0, "score_mean": 0.0, "score_max": 0.0, "active_edges": 0, } edge_attr = self.active_edge_attr[active] edge_score = self.active_edge_score[active] nodes_with_edges = torch.unique(torch.cat([self.active_edge_src[active], self.active_edge_dst[active]])) else: edge_attr = self.edge_attr edge_score = self.edge_score nodes_with_edges = torch.unique(torch.cat([self.edge_index[0], self.edge_index[1]])) ternary_edge = edge_attr.sign() sparsity = (ternary_edge == 0).float().mean().item() if ternary_edge.numel() else 1.0 n_isolated = max(int(self.codebook_size) - int(nodes_with_edges.numel()), 0) n_pos = (ternary_edge > 0).sum().item() n_neg = (ternary_edge < 0).sum().item() n_nonzero = n_pos + n_neg avg_polarity = (n_pos - n_neg) / max(n_nonzero, 1) dead_edges = ((ternary_edge == 0) & (edge_score != 0)).sum().item() score_mean = edge_score.float().mean().item() if edge_score.numel() else 0.0 score_max = edge_score.float().abs().max().item() if edge_score.numel() else 0.0 return { "sparsity": sparsity, "isolated_nodes": n_isolated, "avg_polarity": avg_polarity, "dead_edges": dead_edges, "score_mean": score_mean, "score_max": score_max, "active_edges": int(ternary_edge.numel()), } def set_adjacency(self, edge_index, edge_attr_init=None): self.use_active_edge_store = False device = self.edge_attr.device self.edge_index = edge_index.to(device=device, dtype=torch.int32) if edge_attr_init is not None: edge_attr = edge_attr_init.sign() * (edge_attr_init.abs() > 0).to(edge_attr_init.dtype) self.edge_attr = edge_attr.to(device=device, dtype=torch.int8) else: self.edge_attr = torch.randint(-1, 2, (edge_index.size(1),), device=device, dtype=torch.int8) self.edge_score = self.edge_attr.clone()