ARBS / arbitor /components.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Components — core neural network modules for the ARB system."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _COMPONENT_CONTEXT, _HAS_TRITON
try:
from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
_TritonTernaryEmbedFn = None
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from dataclasses import dataclass, field, fields
from math import ceil as _ceil, log2 as _log2
from transformers import AutoModel, AutoFeatureExtractor
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, SPECIAL_VOCAB, CODEBOOK_DIM, CODEBOOK_SIZE, FFN_HIDDEN, CTX, THRESHOLD, KG_EMA_ALPHA, KG_REQUANT_EVERY, KG_TERNARY_THRESHOLD, KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM, KGVQ_DECAY, KGVQ_COMMITMENT_WEIGHT, KGVQ_DEAD_CODE_THRESHOLD, K_MAX_COMPOSITES, MG_N_EXPERTS, MG_CORE_RANK, MG_SHARED_INTER, MG_ACT_ITERS, MG_WORKSPACE_DIM, BYTEHEAD_ACT_MAX_ITERS, BYTEHEAD_ACT_HALT_CONSECUTIVE
def _ceil_div(a, b):
    """Return ``ceil(a / b)``, or 0 when ``b`` is not greater than zero.

    Integer operands are divided with exact integer arithmetic so the result
    stays correct beyond the 2**53 range where float division loses precision.
    Any other operand types fall back to ``ceil(a / b)``.
    """
    if not (b > 0):
        return 0
    if isinstance(a, int) and isinstance(b, int):
        # -(-a // b) is the exact integer ceiling of a / b.
        return -(-a // b)
    return _ceil(a / b)
from .sequencers import ByteEmbedding
@dataclass
class LossWeights:
    """Scalar multipliers for the individual loss terms.

    Each attribute scales the tensor of the same name held by
    ``LossComponents`` when the weighted total is computed.
    """

    # Terms that contribute at full strength by default.
    lm: float = 1.0
    vq_commitment: float = 1.0
    moe_aux: float = 1.0
    # Down-weighted by default.
    graph_l1: float = 1e-3
    graph_ponder: float = 1.0
    moe_ponder: float = 1.0
    moegraph_ponder: float = 1.0
    # Down-weighted by default.
    memgram_decay_reg: float = 1e-2
    composite_vq: float = 1.0
@dataclass
class LossComponents:
    """Optional per-term loss tensors plus the weights used to combine them.

    A term left as ``None`` is treated as inactive: it is skipped by
    ``total``, ``active_fields`` and ``log``.  All three derive the list of
    terms from the dataclass fields, so they cannot drift apart when a term
    is added.
    """

    lm: torch.Tensor = None
    vq_commitment: torch.Tensor = None
    moe_aux: torch.Tensor = None
    graph_l1: torch.Tensor = None
    graph_ponder: torch.Tensor = None
    moe_ponder: torch.Tensor = None
    moegraph_ponder: torch.Tensor = None
    memgram_decay_reg: torch.Tensor = None
    composite_vq: torch.Tensor = None
    weights: LossWeights = field(default_factory=LossWeights)

    @property
    def active_fields(self) -> list[tuple[str, torch.Tensor, float]]:
        """Return ``(name, tensor, weight)`` for every non-``None`` term.

        Entries follow field declaration order; ``weights`` itself is skipped.
        """
        active = []
        # ``spec`` (not ``field``) so the imported dataclasses.field is not shadowed.
        for spec in fields(self):
            if spec.name == 'weights':
                continue
            tensor = getattr(self, spec.name)
            if tensor is not None:
                active.append((spec.name, tensor, getattr(self.weights, spec.name)))
        return active

    @property
    def total(self) -> torch.Tensor:
        """Return the weighted sum of all active terms.

        Raises:
            ValueError: if every term is ``None``.
        """
        loss = None
        # Summation follows declaration order, matching the previous
        # hand-written sequence of additions.
        for _, tensor, weight in self.active_fields:
            weighted = weight * tensor
            loss = weighted if loss is None else loss + weighted
        if loss is None:
            raise ValueError("LossComponents.total requested with no active loss tensors")
        return loss

    def log(self, writer, step, prefix="loss"):
        """Write the total and each active (unweighted) term via ``writer.add_scalar``."""
        writer.add_scalar(f"{prefix}/total", self.total.item(), step)
        for name, tensor, _ in self.active_fields:
            writer.add_scalar(f"{prefix}/{name}", tensor.item(), step)

    def backward(self, retain_graph=False):
        """Back-propagate through the weighted total."""
        self.total.backward(retain_graph=retain_graph)
class StickyZoneSTE(torch.autograd.Function):
    """Ternarising straight-through estimator with a damped "sticky" zero zone.

    Forward maps ``w`` to ``sign(w)`` where ``|w| > threshold`` and to 0
    elsewhere.  Backward scales the incoming gradient by
    ``clamp(|w| / threshold, 0, 1)``, so weights inside the zero zone receive a
    proportionally smaller gradient.
    """

    @staticmethod
    def forward(ctx, w, threshold):
        ctx.save_for_backward(w)
        # The threshold is a plain scalar: keep it on ctx instead of wrapping
        # it in a saved tensor only to call .item() on it later.
        ctx.threshold = float(threshold)
        return w.sign() * (w.abs() > threshold).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        threshold = ctx.threshold
        if threshold <= 0.0:
            # The zero zone is empty; dividing by the threshold would produce
            # NaN at w == 0 (0 / 0), so pass the gradient straight through.
            return grad_output, None
        ratio = torch.clamp(w.abs() / threshold, 0.0, 1.0)
        # No gradient flows to the threshold argument.
        return grad_output * ratio, None
class TernaryEmbeddingTable(nn.Module):
    """Embedding table stored as packed ternary values plus int8 group exponents.

    A row is reconstructed as ``exp2(E) * T``: ``T`` holds values in
    {-1, 0, +1}, stored five per byte in ``T_packed`` (base-3 digits), and
    ``E`` holds one int8 exponent per group of ``group_size`` columns.  Tables
    with at least ``sparse_threshold`` rows use a row-wise ("sparse") path that
    only unpacks the rows being looked up; smaller tables materialise the full
    weight matrix on each forward call.
    """

    def __init__(self, num_embeddings, embedding_dim, tscale_type=TScaleType.T32,
                 init_std=0.02, threshold=0.05, normalize=False):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.tscale_type = tscale_type
        # The ternarisation threshold is capped at half of init_std (when init_std > 0).
        init_threshold = min(float(threshold), 0.5 * float(init_std)) if init_std > 0 else threshold
        self.threshold = init_threshold
        self.normalize = normalize
        # Falls back to the T64 group size when tscale_type has no entry.
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
        self.sparse_threshold = 65_536
        if num_embeddings >= self.sparse_threshold:
            # Large table: initialise the packed bytes directly with random
            # values in [0, 243) (243 == 3**5) instead of sampling float weights.
            n_trits = num_embeddings * embedding_dim
            n_packed = _ceil_div(n_trits, 5)
            packed_T = torch.randint(0, 243, (n_packed,), dtype=torch.uint8)
            T_pad = n_packed * 5 - n_trits
            gpr = _ceil_div(embedding_dim, self.group_size)  # groups per row
            # Every group starts with the same exponent, round(log2(init_std)).
            init_exp = int(round(_log2(max(init_std, 1e-8))))
            self.register_buffer("T_packed", packed_T)
            self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
            self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
            self.register_buffer(
                "E",
                torch.full((num_embeddings * gpr,), init_exp, dtype=torch.int8),
            )
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
            self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
            self._ema_alpha: float = 0.1
            self._loss_temp_scale: float = 1.0
            return
        # Small table: sample Gaussian weights, ternarise them, and pack.
        w_init = torch.randn(num_embeddings, embedding_dim) * init_std
        T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([num_embeddings, embedding_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
        gpr = _ceil_div(embedding_dim, self.group_size)
        total_in = gpr * self.group_size
        # Zero-pad each row to a whole number of groups before averaging |w|.
        padded = torch.zeros(num_embeddings, total_in)
        padded[:, :embedding_dim] = w_init.abs()
        grouped = padded.view(num_embeddings, gpr, self.group_size)
        # Groups whose mean is 0 use 1 instead, so log2 yields exponent 0.
        E_vals = torch.where(grouped.mean(dim=2) > 0, grouped.mean(dim=2), torch.ones(num_embeddings, gpr))
        # The int8 cast truncates the fractional part of log2.
        self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))
        self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        self.register_buffer("T_accum", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8))
        self._ema_alpha: float = 0.1
        self._loss_temp_scale: float = 1.0

    def _get_T(self):
        """Unpack and return the full ternary matrix via ``unpack_ternary``."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_T_rows(self, indices):
        """Return the ternary rows for ``indices`` as int8, shape ``(n, embedding_dim)``."""
        indices = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
        dim = self.embedding_dim
        cols = torch.arange(dim, device=indices.device, dtype=torch.long)
        # Linear position of every requested trit in the flattened table.
        lin = indices[:, None] * dim + cols[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = self.T_packed[pack_idx].to(torch.long)
        divisors = torch.tensor([1, 3, 9, 27, 81], device=indices.device, dtype=torch.long)
        # Extract the base-3 digit, then shift {0, 1, 2} to {-1, 0, +1}.
        code = (packed // divisors[trit_pos]) % 3
        return (code.to(torch.int8) - 1)

    def _expand_E_rows(self, indices):
        """Return per-column exponents for the given rows, shape ``(n, embedding_dim)``."""
        indices = indices.reshape(-1).to(device=self.E.device, dtype=torch.long)
        gpr = _ceil_div(self.embedding_dim, self.group_size)
        E_rows = self.E.view(self.num_embeddings, gpr)[indices]
        # Repeat each group exponent across its columns, then drop the padding.
        E_exp = E_rows.repeat_interleave(self.group_size, dim=1)
        return E_exp[:, :self.embedding_dim]

    @torch.no_grad()
    def _set_T_rows(self, row_indices, rows):
        """Overwrite the packed trits of ``row_indices`` with the values in ``rows``.

        NOTE(review): this is an element-by-element Python loop with one
        ``.item()`` read per trit; cost grows with rows * embedding_dim.
        """
        row_indices = row_indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
        rows = rows.to(device=self.T_packed.device, dtype=torch.int8).reshape(row_indices.numel(), self.embedding_dim)
        divisors = [1, 3, 9, 27, 81]
        for row_pos, row_idx in enumerate(row_indices.tolist()):
            row = rows[row_pos]
            for col in range(self.embedding_dim):
                lin = row_idx * self.embedding_dim + col
                pack_idx = lin // 5
                trit_pos = lin - pack_idx * 5
                divisor = divisors[trit_pos]
                old = int(self.T_packed[pack_idx].item())
                old_code = (old // divisor) % 3
                new_code = int(row[col].item()) + 1
                if old_code != new_code:
                    # Replace only this base-3 digit inside the packed byte.
                    self.T_packed[pack_idx] = old - old_code * divisor + new_code * divisor

    def _expand_E(self):
        """Return per-column exponents for the whole table."""
        out_dim, in_dim = tuple(self._T_shape.tolist())
        gpr = _ceil_div(in_dim, self.group_size)
        E_2d = self.E.view(out_dim, gpr)
        E_exp = E_2d.repeat_interleave(self.group_size, dim=1)
        return E_exp[:, :in_dim]

    def _ensure_E_accum(self):
        """Return ``E_accum``, (re)creating it as zeros if missing or mismatched with ``E``."""
        if not hasattr(self, "E_accum"):
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device:
            self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
        return self.E_accum

    def forward(self, indices):
        """Look up embeddings for ``indices``; output shape is ``indices.shape + (embedding_dim,)``.

        The effective weights are detached from any graph and re-flagged as
        requiring grad, and a hook records the sign of their gradient on the
        module for later use by ``ternary_step`` / ``update_E``.
        """
        use_sparse = self.num_embeddings >= self.sparse_threshold
        if use_sparse:
            # Row-wise path: only the looked-up rows are unpacked.
            idx_flat = indices.reshape(-1).to(device=self.T_packed.device, dtype=torch.long)
            T_rows = self._get_T_rows(idx_flat)
            E_exp = self._expand_E_rows(idx_flat)
            w_eff = torch.exp2(E_exp.float()) * T_rows.float()
            w_eff_grad = w_eff.detach().requires_grad_(torch.is_grad_enabled())
            if torch.is_grad_enabled():
                comp_name, _ = _COMPONENT_CONTEXT.get()

                def capture_sparse_grad(grad):
                    # NOTE(review): when comp_name is set the attributes get a
                    # "_<comp_name>" suffix, but ternary_step/update_E below only
                    # check the unsuffixed names -- confirm who consumes the
                    # suffixed ones.
                    suffix = f"_{comp_name}" if comp_name is not None else ""
                    setattr(self, f"_hook_sparse_indices{suffix}", idx_flat.detach())
                    setattr(self, f"_hook_sparse_grad_sign{suffix}", grad.reshape(-1, self.embedding_dim).sign().to(torch.int8).detach())
                    setattr(self, f"_hook_sparse_T{suffix}", T_rows.detach())

                w_eff_grad.register_hook(capture_sparse_grad)
            out = w_eff_grad.reshape(*indices.shape, self.embedding_dim)
            return F.normalize(out, dim=-1) if self.normalize else out
        if indices.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
            # Triton kernel path; `dummy` is a grad-requiring placeholder input
            # (presumably so autograd invokes the function's backward -- confirm).
            dummy = torch.zeros(1, device=indices.device, requires_grad=True)
            out = _TritonTernaryEmbedFn.apply(indices, dummy, self)
        else:
            # Fallback: materialise the full weight matrix and use F.embedding.
            T = self._get_T()
            w_eff = torch.exp2(self._expand_E().float()) * T.float()
            w_eff_grad = w_eff.detach().requires_grad_(True)
            self._hook_T = T

            def capture_w_grad(grad_w):
                self._hook_grad_T_sign = grad_w.sign().to(torch.int8)

            w_eff_grad.register_hook(capture_w_grad)
            out = F.embedding(indices, w_eff_grad)
        return F.normalize(out, dim=-1) if self.normalize else out

    def ternary_step(self, accum_threshold=3):
        """Consume the captured gradient signs and update the ternary state.

        Uses the sparse update when sparse hook data is present; otherwise
        forwards the dense gradient sign to ``_accumulate_corr_from_grad_sign``
        if such a method exists on the instance (it is not defined in this class).
        """
        if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
            return self._sparse_ternary_step(accum_threshold=accum_threshold)
        if hasattr(self, "_hook_grad_T_sign"):
            if hasattr(self, "_accumulate_corr_from_grad_sign"):
                self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """Update group exponents from sparse hook data; does nothing otherwise."""
        if hasattr(self, "_hook_sparse_indices") and hasattr(self, "_hook_sparse_grad_sign"):
            return self._sparse_update_E(loss_signal=loss_signal)

    @torch.no_grad()
    def _sparse_ternary_step(self, accum_threshold=3):
        """Accumulate gradient signs per looked-up row and flip trits past the threshold.

        Afterwards the captured sparse hook attributes are deleted (except on
        the early return for an empty index set).
        """
        indices = self._hook_sparse_indices.to(device=self.T_accum.device, dtype=torch.long)
        grad_sign = self._hook_sparse_grad_sign.to(device=self.T_accum.device, dtype=torch.int16)
        if indices.numel() == 0:
            return
        # Merge repeated lookups of the same row by summing their gradient signs.
        unique, inverse = torch.unique(indices, return_inverse=True)
        grad_sum = torch.zeros(unique.numel(), self.embedding_dim, device=self.T_accum.device, dtype=torch.int16)
        grad_sum.index_add_(0, inverse, grad_sign)
        grad_step = grad_sum.sign().to(torch.int16) * int(getattr(self, "_t_accum_step", 1))
        current = self.T_accum[unique].to(torch.int16)
        # Move the accumulator against the gradient sign, saturating at int8 range.
        updated = torch.clamp(current - grad_step, -128, 127).to(torch.int8)
        pgt = getattr(self, "per_group_threshold", None)
        if pgt is not None:
            # Optional per-group thresholds, broadcast to one value per column.
            gpr = _ceil_div(self.embedding_dim, self.group_size)
            threshold = pgt.view(self.num_embeddings, gpr)[unique]
            threshold = threshold.unsqueeze(-1).expand(unique.numel(), gpr, self.group_size)
            threshold = threshold.reshape(unique.numel(), gpr * self.group_size)[:, :self.embedding_dim]
            threshold = threshold.to(updated.device)
            flip_up = updated > threshold
            flip_down = updated < -threshold
        else:
            flip_up = updated > accum_threshold
            flip_down = updated < -accum_threshold
        self._had_flip = bool((flip_up | flip_down).any().item())
        if self._had_flip:
            rows = self._get_T_rows(unique).to(updated.device)
            # Flipped entries are set to +1 / -1; all others keep their value.
            rows = torch.where(flip_up, torch.ones_like(rows), torch.where(flip_down, -torch.ones_like(rows), rows))
            self._set_T_rows(unique, rows)
            # Accumulators of flipped entries restart from zero.
            updated = torch.where(flip_up | flip_down, torch.zeros_like(updated), updated)
        self.T_accum[unique] = updated
        del self._hook_sparse_indices
        del self._hook_sparse_grad_sign
        if hasattr(self, "_hook_sparse_T"):
            del self._hook_sparse_T

    @torch.no_grad()
    def _sparse_update_E(self, loss_signal=None):
        """Step the int8 group exponents of the looked-up rows by at most one.

        ``loss_signal`` is accepted but not used here.  The sparse hook
        attributes are left in place by this method.
        """
        indices = self._hook_sparse_indices.to(device=self.E.device, dtype=torch.long)
        grad_sign = self._hook_sparse_grad_sign.to(device=self.E.device, dtype=torch.int16)
        T_rows = self._hook_sparse_T if hasattr(self, "_hook_sparse_T") else self._get_T_rows(indices)
        T_rows = T_rows.to(device=self.E.device, dtype=torch.int16)
        if indices.numel() == 0:
            return
        unique, inverse = torch.unique(indices, return_inverse=True)
        gpr = _ceil_div(self.embedding_dim, self.group_size)
        total_in = gpr * self.group_size
        # Per-element agreement between gradient sign and trit sign.
        signed = grad_sign * T_rows
        grouped = F.pad(signed, (0, total_in - self.embedding_dim)).view(indices.numel(), gpr, self.group_size)
        score = grouped.sum(dim=2)
        # Positive group score votes to lower the exponent, negative to raise it.
        delta = torch.where(
            score > 0,
            torch.full_like(score, -1, dtype=torch.int16),
            torch.where(score < 0, torch.ones_like(score, dtype=torch.int16), torch.zeros_like(score, dtype=torch.int16)),
        )
        delta_sum = torch.zeros(unique.numel(), gpr, device=self.E.device, dtype=torch.int16)
        delta_sum.index_add_(0, inverse, delta)
        delta_sign = delta_sum.sign()
        # Flat positions in E / E_accum of every group of every unique row.
        e_idx = unique[:, None] * gpr + torch.arange(gpr, device=self.E.device, dtype=torch.long)[None, :]
        accum = torch.clamp(self.E_accum[e_idx].to(torch.int16) + delta_sign, -128, 127)
        threshold = int(getattr(self, "_e_accum_threshold", 4))
        step = torch.where(
            accum >= threshold,
            torch.ones_like(accum, dtype=torch.int16),
            torch.where(accum <= -threshold, torch.full_like(accum, -1, dtype=torch.int16), torch.zeros_like(accum, dtype=torch.int16)),
        )
        self.E[e_idx] = torch.clamp(self.E[e_idx].to(torch.int16) + step, -128, 127).to(torch.int8)
        # Remove the consumed votes from the accumulator.
        self.E_accum[e_idx] = (accum - step * threshold).to(torch.int8)
class TernaryVQCodebook(nn.Module):
    """Vector quantiser whose codebook is a normalised ``TernaryEmbeddingTable``.

    Codebooks with at most ``exact_lookup_max`` entries are searched
    exhaustively by cosine similarity.  Larger ones only score
    ``candidate_count`` hashed candidate codes per input vector.
    """

    def __init__(self, codebook_size, codebook_dim, commitment_weight=1.0,
                 tscale_type=TScaleType.T32, exact_lookup_max=16384,
                 candidate_count=256):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment_weight = commitment_weight
        self.exact_lookup_max = exact_lookup_max
        self.candidate_count = candidate_count
        self.threshold_ema_dead_code = 2
        self.table = TernaryEmbeddingTable(
            codebook_size, codebook_dim, tscale_type=tscale_type, normalize=True,
        )
        # Saturating int16 usage counter, one slot per code.
        self.register_buffer("cluster_size", torch.zeros(codebook_size, dtype=torch.int16))

    @property
    def embed(self):
        """Materialise every codebook row by looking up all indices."""
        all_ids = torch.arange(self.codebook_size, device=self.table.T_packed.device)
        return self.table(all_ids)

    def _candidate_ids(self, flat):
        """Hash each input row to a deterministic set of candidate code ids."""
        n_candidates = min(self.candidate_count, self.codebook_size)
        n_dims = min(flat.shape[1], 16)
        multipliers = torch.tensor(
            [1009, 9176, 6361, 5333, 4447, 3469, 2531, 1613,
             811, 421, 211, 109, 59, 31, 17, 7],
            device=flat.device, dtype=torch.float32,
        )[:n_dims]
        # Only the signs of the leading dimensions feed the hash.
        signs = torch.sign(flat[:, :n_dims].float())
        projected = (signs * multipliers).sum(dim=1) * 104729
        seed = torch.abs(torch.round(projected)).to(torch.long)
        steps = torch.arange(n_candidates, device=flat.device, dtype=torch.long) * 2_654_435_761
        return (seed[:, None] + steps[None, :]) % self.codebook_size

    def _lookup(self, flat):
        """Return ``(quantized, indices)`` for the rows of ``flat``."""
        queries = F.normalize(flat.float(), dim=-1)
        if self.codebook_size <= self.exact_lookup_max:
            # Exhaustive cosine-similarity search over the whole codebook.
            codebook = self.embed.to(device=flat.device)
            best = (queries @ codebook.T).argmax(dim=-1)
            return codebook[best], best
        candidate_ids = self._candidate_ids(flat)
        n = flat.shape[0]
        quantized = torch.empty_like(flat)
        indices = torch.empty(n, dtype=torch.long, device=flat.device)
        # Score candidates 64 input rows at a time.
        for lo in range(0, n, 64):
            hi = min(lo + 64, n)
            ids = candidate_ids[lo:hi]
            vecs = self.table(ids).float()
            scores = (F.normalize(vecs, dim=-1) * queries[lo:hi].unsqueeze(1)).sum(dim=-1)
            winner = scores.argmax(dim=-1)
            indices[lo:hi] = ids.gather(1, winner.unsqueeze(1)).squeeze(1)
            quantized[lo:hi] = vecs[torch.arange(hi - lo, device=flat.device), winner]
        return quantized, indices

    def forward(self, x):
        """Quantise ``x`` along its last dimension.

        Returns ``(quantized, indices, commitment_loss)``: ``quantized`` has the
        shape of ``x`` and carries a straight-through gradient to ``x``;
        ``indices`` has the shape of ``x`` without its last dimension.
        """
        input_shape = x.shape
        flat = x.reshape(-1, self.codebook_dim)
        codes, indices = self._lookup(flat)
        pull_input = F.mse_loss(flat.float(), codes.detach().float())
        pull_code = F.mse_loss(codes.float(), flat.detach().float())
        commitment = self.commitment_weight * (pull_input + 0.25 * pull_code)
        # Straight-through estimator: forward value is the code, gradient goes to flat.
        straight_through = flat + (codes - flat).detach()
        with torch.no_grad():
            used, hits = torch.unique(indices, return_counts=True)
            previous = self.cluster_size[used].to(torch.int32)
            hits = hits.to(device=previous.device, dtype=torch.int32)
            # Saturate at the int16 maximum instead of wrapping.
            self.cluster_size[used] = torch.clamp(previous + hits, 0, 32767).to(torch.int16)
        return straight_through.reshape(input_shape), indices.reshape(input_shape[:-1]), commitment
class GNNLoRAAdapter(nn.Module):
    """Low-rank adapter whose bottleneck is scaled by a per-hop learned vector."""

    def __init__(self, dim, rank=32, max_hops=4):
        super().__init__()
        self.max_hops = max_hops
        self.down = TernaryScaleTensor(dim, rank, tscale_type=TScaleType.T32)
        self.up = TernaryScaleTensor(rank, dim, tscale_type=TScaleType.T32)
        # One scale vector of size `rank` per hop index.
        self.scale = TernaryEmbeddingTable(max_hops, rank, tscale_type=TScaleType.T32)

    def forward(self, x, hop_t):
        """Apply down-projection, hop-specific scaling, then up-projection.

        Hop indices beyond the last table row reuse the last row.
        """
        last_hop = self.max_hops - 1
        hop_index = last_hop if last_hop < hop_t else hop_t
        hop_scale = self.scale(torch.tensor(hop_index, device=x.device))
        low_rank = self.down(x)
        return self.up(low_rank * hop_scale)
class HaltingUnit(nn.Module):
    """Maps a feature vector to a single sigmoid-activated value (halting score)."""

    def __init__(self, dim, tscale_type=TScaleType.T32):
        super().__init__()
        self.proj = TernaryScaleTensor(dim, 1, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(dim, tscale_type=tscale_type)

    def forward(self, x):
        """Normalise ``x``, project it to one channel, and squash with a sigmoid."""
        normed = self.norm(x)
        halt_logit = self.proj(normed)
        return torch.sigmoid(halt_logit)
class _NgramHashMapping:
    """N-gram hash mapping with CPU offloading (Spider Engram style).

    Hashes token sequences to fixed-size embedding indices. Hash computation
    runs on CPU via numpy, using precomputed odd multipliers and one distinct
    prime modulus per (n-gram order, head) pair.
    """

    def __init__(self, max_ngram_size, num_heads, table_size_base, layer_seed=0):
        """Draw the multipliers and pick the per-head prime table sizes.

        Raises:
            ValueError: if there are not enough distinct primes below
                ``table_size_base`` for all (order, head) pairs.
        """
        self.max_ngram_size = max_ngram_size
        self.num_heads = num_heads
        # Orders 2..max_ngram_size are hashed; unigrams are not.
        self.num_ngram_orders = max_ngram_size - 1
        PRIME_1 = 10007
        g = torch.Generator()
        g.manual_seed(int(layer_seed + PRIME_1 * int(layer_seed)))
        r = torch.randint(0, 1 << 30, (max_ngram_size,), generator=g, dtype=torch.int64)
        # Force every multiplier to be odd.
        self.multipliers = r.numpy() * 2 + 1
        seen_primes = set()
        self.prime_table_sizes = []
        for _order in range(self.num_ngram_orders):
            head_sizes = []
            ps = table_size_base - 1
            for _head in range(num_heads):
                # Walk downwards so every head gets a prime not used before.
                p = self._next_prime(ps, seen_primes)
                seen_primes.add(p)
                head_sizes.append(p)
                ps = p
            self.prime_table_sizes.append(head_sizes)
        self.all_head_sizes = [s for sub in self.prime_table_sizes for s in sub]
        # Start offset of each head's slot range inside one stacked table.
        offsets = [0]
        for s in self.all_head_sizes[:-1]:
            offsets.append(offsets[-1] + s)
        self.offsets_arr = offsets
        self.total_slots = sum(self.all_head_sizes)

    @staticmethod
    def _next_prime(n, seen):
        """Return the largest prime ``<= n`` that is not in ``seen``.

        Raises:
            ValueError: if no such prime exists (previously this looped forever).
        """
        while n in seen or not _is_prime(n):
            n -= 1
            if n < 2:
                raise ValueError(
                    "table_size_base is too small to provide a distinct prime for every hash head"
                )
        return n

    def compute_hashes(self, token_ids):
        """Hash a ``(B, T)`` tensor of token ids.

        Returns an int64 tensor of shape ``(B, T, num_ngram_orders * num_heads)``
        on the device of ``token_ids``; each value lies in ``[0, prime)`` for
        the prime of its head.
        """
        import numpy as np
        x = token_ids.cpu().numpy().astype(np.int64)
        B, T = x.shape
        # shifts[k][:, t] is the token k positions before t (0 before the start).
        shifts = [x]
        for k in range(1, self.max_ngram_size):
            shifts.append(np.pad(x, ((0, 0), (k, 0)), constant_values=0)[:, :T])
        all_hashes = []
        for order_idx in range(self.num_ngram_orders):
            n = order_idx + 2
            mix = shifts[0] * self.multipliers[0]
            for k in range(1, n):
                mix = np.bitwise_xor(mix, shifts[k] * self.multipliers[k])
            for ms in self.prime_table_sizes[order_idx]:
                all_hashes.append((mix % ms).astype(np.int64, copy=False))
        result = np.stack(all_hashes, axis=2)
        return torch.from_numpy(result).to(device=token_ids.device)
def _is_prime(n):
    """Return True when the integer ``n`` is prime (odd-only trial division)."""
    if n < 2:
        return False
    if n < 4:
        # 2 and 3
        return True
    if n % 2 == 0:
        return False
    # isqrt is exact for arbitrarily large ints, unlike int(math.sqrt(n)).
    from math import isqrt
    for divisor in range(3, isqrt(n) + 1, 2):
        if n % divisor == 0:
            return False
    return True
class MemGram(nn.Module):
    """Engram-style associative memory with O(1) hashed lookup (CPU offloaded).

    Features:
    - O(1) hash -> index -> embedding lookup (no search, no decay for retrieval)
    - CPU-offloaded hash computation (numpy)
    - Single offset-stacked embedding table (not per-head tables): the struct
      hash heads occupy the first ``struct_hash.total_slots`` rows and the conv
      hash heads the rows after them
    - Gated retrieval: sigmoid(Q*K/sqrt(d)) gates the memory read
    - Depthwise conv1d processes retrieved memory (Engram-style)
    - No strength/decay buffers (decay is handled by GraphMoE usage frequency)
    - MemGram lookups do NOT affect KG decaying (separate mechanisms)

    NOTE: conv-hash lookups used to start at row 0 and therefore shared rows
    with the struct heads while the second part of the table was never read.
    They now read their own region; checkpoints trained with the old indexing
    will see different conv-memory rows.
    """

    def __init__(self, struct_primes=(64901, 64919, 64921, 64927, 64937, 64951, 64969, 64997,
                                      65003, 65011, 65027, 65029, 65033, 65053, 65063, 65071),
                 conv_primes=(8009, 8011, 8017, 8039),
                 embed_dim=64, hidden_dim=HIDDEN_DIM, key_dim=32,
                 max_ngram_size=3, num_hash_heads=4, layer_seed=0):
        # Defaults are tuples (not lists) so they cannot be mutated across instances.
        super().__init__()
        self.embed_dim = embed_dim
        self.key_dim = key_dim
        self.hidden_dim = hidden_dim
        self.n_struct_heads = len(struct_primes)
        self.n_conv_heads = len(conv_primes)
        # Only the first entry of each prime list is used, as the size base.
        self.struct_hash = _NgramHashMapping(
            max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
            table_size_base=struct_primes[0], layer_seed=layer_seed,
        )
        self.conv_hash = _NgramHashMapping(
            max_ngram_size=max_ngram_size, num_heads=num_hash_heads,
            table_size_base=conv_primes[0], layer_seed=layer_seed + 1000,
        )
        total_heads = self.struct_hash.num_ngram_orders * num_hash_heads
        self.total_mem_dim = total_heads * embed_dim
        total_slots = self.struct_hash.total_slots + self.conv_hash.total_slots
        self.mem_embed = nn.Embedding(total_slots, embed_dim)
        self.k_proj = nn.Linear(self.total_mem_dim, key_dim, bias=False)
        self.q_proj = nn.Linear(hidden_dim, key_dim, bias=False)
        self.v_proj = nn.Linear(self.total_mem_dim, hidden_dim, bias=False)
        # Zero-initialised value projection: the memory read starts as a no-op.
        with torch.no_grad():
            self.v_proj.weight.zero_()
        self.conv_norm = nn.RMSNorm(hidden_dim)
        # Depthwise conv; padding 9 == dilation * (kernel_size - 1), and the
        # output is later cropped to the input length.
        self.conv = nn.Conv1d(
            hidden_dim, hidden_dim,
            kernel_size=4, padding=9, dilation=3, groups=hidden_dim,
        )
        with torch.no_grad():
            self.conv.weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()

    def _retrieve(self, token_ids, hash_mapping, base_offset=0):
        """Look up the hashed memory rows for ``token_ids``.

        ``base_offset`` is the first row of this mapping's region inside
        ``mem_embed``.  Returns a ``(B, T, H * embed_dim)`` tensor.
        """
        hash_ids = hash_mapping.compute_hashes(token_ids)
        B, T, H = hash_ids.shape
        flat_ids = hash_ids.reshape(B * T, H)
        offsets = torch.tensor(hash_mapping.offsets_arr, device=flat_ids.device, dtype=torch.long)
        emb = self.mem_embed(flat_ids + offsets + base_offset)
        return emb.reshape(B, T, H * self.embed_dim)

    def _retrieve_memory(self, vq_indices):
        """Sum of struct and conv memories for all positions except the first."""
        shifted = vq_indices[:, 1:]
        struct_mem = self._retrieve(shifted, self.struct_hash)
        # Conv heads live after the struct heads in the stacked table.
        conv_mem = self._retrieve(shifted, self.conv_hash,
                                  base_offset=self.struct_hash.total_slots)
        return struct_mem + conv_mem

    def forward(self, vq_indices, hidden_state):
        """Add gated, convolved memory reads to ``hidden_state`` (shape ``(B, T, D)``).

        Positions past the retrieved length are zero-padded in the output.
        """
        _, T, _ = hidden_state.shape
        mem = self._retrieve_memory(vq_indices)
        idx_end = mem.shape[1]
        q_proj = self.q_proj(hidden_state[:, :idx_end])
        k = self.k_proj(mem)
        v = self.v_proj(mem)
        gate = torch.sigmoid((q_proj * k).sum(dim=-1, keepdim=True) / (self.key_dim ** 0.5))
        v_gated = gate * v
        v_normed = self.conv_norm(v_gated)
        v_t = v_normed.transpose(1, 2)
        conv_out = self.conv(v_t)
        # Crop to the input length (keeps the leading, left-padded outputs).
        conv_out = conv_out[:, :, :v_t.shape[-1]].transpose(1, 2)
        output = hidden_state[:, :idx_end] + F.silu(conv_out) + v_gated
        if idx_end < T:
            output = F.pad(output, (0, 0, 0, T - idx_end))
        return output

    def retrieve_cb(self, vq_indices):
        """Return the raw memory read, self-gated by the sigmoid of its feature mean.

        Output shape is ``(B, T, total_mem_dim)``; trailing positions without a
        retrieved memory are zero.
        """
        B, T = vq_indices.shape
        mem = self._retrieve_memory(vq_indices)
        idx_end = mem.shape[1]
        # Match mem's dtype so the concatenation does not change precision.
        pad = torch.zeros(B, T - idx_end, mem.shape[2], device=mem.device, dtype=mem.dtype)
        mem = torch.cat([mem, pad], dim=1)
        q = mem.mean(dim=-1, keepdim=True)
        gate = torch.sigmoid(q)
        return gate * mem
# Maps the special-token ids of BOS, SYSTEM, USER and ASSISTANT to slots 0..3.
_BOUNDARY_TOKEN_MAP = {
    SPECIAL_VOCAB[role]: slot
    for slot, role in enumerate(('BOS', 'SYSTEM', 'USER', 'ASSISTANT'))
}
class LTIInjection(nn.Module):
    """LTI state injection: h = A*h + B*e + trans_out.

    ``A = exp(-exp(log_dt + log_A))`` with the inner sum clamped to
    [-20, 20], so every entry of ``A`` lies in (0, 1] up to float rounding.
    All three parameters are registered with gradients disabled.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Frozen parameters: present in the state dict, but never trained.
        self.log_A = nn.Parameter(torch.zeros(dim), requires_grad=False)
        self.log_dt = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.B = nn.Parameter(torch.full((dim,), 0.1), requires_grad=False)

    def get_A(self):
        """Return the per-channel decay factor."""
        rate = torch.exp(torch.clamp(self.log_dt + self.log_A, min=-20, max=20))
        return torch.exp(-rate)

    def forward(self, h, e, trans_out):
        """Decay state ``h``, inject input ``e`` scaled by ``B``, and add ``trans_out``."""
        decayed = self.get_A() * h
        injected = self.B * e
        return decayed + injected + trans_out
class ByteHead(nn.Module):
    """Deep MLP byte prediction head with an optional ACT refinement loop.

    Layer widths: HIDDEN_DIM -> 2*HIDDEN_DIM -> HIDDEN_DIM -> 2*HIDDEN_DIM -> VOCAB.
    ACT: runs up to ``act_max_iters`` passes and stops early once the argmax
    prediction has stayed unchanged for ``act_halt_consecutive`` consecutive passes.
    """

    def __init__(self, tscale_type=TScaleType.T32,
                 act_max_iters=BYTEHEAD_ACT_MAX_ITERS,
                 act_halt_consecutive=BYTEHEAD_ACT_HALT_CONSECUTIVE):
        super().__init__()
        narrow = HIDDEN_DIM
        wide = HIDDEN_DIM * 2
        self.act_max_iters = act_max_iters
        self.act_halt_consecutive = act_halt_consecutive
        # Fraction of the iteration budget used by the most recent ACT forward.
        self._last_ponder = 0.0
        self.norm = TernaryRMSNorm(narrow, tscale_type=tscale_type)
        self.up = TernaryScaleTensor(narrow, wide, tscale_type=tscale_type)
        self.up_norm = TernaryRMSNorm(wide, tscale_type=tscale_type)
        self.hidden = TernaryScaleTensor(wide, narrow, tscale_type=tscale_type)
        self.hidden_norm = TernaryRMSNorm(narrow, tscale_type=tscale_type)
        self.out = TernaryScaleTensor(narrow, wide, tscale_type=tscale_type)
        self.out_norm = TernaryRMSNorm(wide, tscale_type=tscale_type)
        self.head = TernaryScaleTensor(wide, VOCAB, tscale_type=tscale_type)
        if act_max_iters > 1:
            # Feeds the logits back into the hidden state between passes.
            self.act_residual = TernaryScaleTensor(VOCAB, narrow, tscale_type=tscale_type)
            self.lti = LTIInjection(narrow)
        else:
            self.act_residual = None
            self.lti = None

    def _mlp_logits(self, h):
        """Run the norm/projection/SiLU stack once and return the logits."""
        z = F.silu(self.up(self.norm(h)))
        z = F.silu(self.hidden(self.up_norm(z)))
        z = F.silu(self.out(self.hidden_norm(z)))
        return self.head(self.out_norm(z))

    def forward(self, x):
        """Return logits for ``x``; with ACT enabled, the logits of the last pass run."""
        if self.act_max_iters <= 1 or self.act_residual is None:
            return self._mlp_logits(x)
        state = x
        previous_choice = None
        streak = 0
        passes_run = 0
        for step in range(self.act_max_iters):
            logits = self._mlp_logits(state)
            choice = logits.argmax(dim=-1)
            # The streak only grows when every position kept its prediction.
            if previous_choice is not None and (choice == previous_choice).all():
                streak += 1
            else:
                streak = 0
            passes_run = step + 1
            if streak >= self.act_halt_consecutive:
                break
            previous_choice = choice
            # Next state: decayed state + re-injected input + logit feedback.
            state = self.lti(state, x, self.act_residual(logits))
        self._last_ponder = passes_run / max(self.act_max_iters, 1)
        return logits
class OutputRouter(nn.Module):
    """Routes HIDDEN_DIM relational tokens to ByteHead, VideoHead, or TalkerHead.

    ``depth >= 3`` builds hidden1 -> hidden2 -> gate, ``depth == 2`` builds
    hidden2 -> gate, and any other depth builds the gate alone.  The gate
    always has four outputs.
    """

    def __init__(self, tscale_type=TScaleType.T32, depth=3):
        super().__init__()
        deep = depth >= 3
        has_bottleneck = deep or depth == 2
        # Layers that a shallower router lacks are stored as None.
        self.hidden1 = (
            TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type) if deep else None
        )
        self.hidden1_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) if deep else None
        gate_in = HIDDEN_DIM // 4 if has_bottleneck else HIDDEN_DIM
        self.hidden2 = (
            TernaryScaleTensor(HIDDEN_DIM, gate_in, tscale_type=tscale_type)
            if has_bottleneck else None
        )
        self.gate = TernaryScaleTensor(gate_in, 4, tscale_type=tscale_type)
        # 0 = Null (continue), 1 = ByteHead, 2 = VideoHead, 3 = TalkerHead

    def forward(self, x, training=False):
        """Route ``x``.

        Returns ``(softmax_weights, logits)`` when ``training`` is truthy,
        otherwise the argmax route index per token.
        """
        features = x
        if self.hidden1 is not None:
            features = self.hidden1(features)
            features = F.silu(self.hidden1_norm(features))
        if self.hidden2 is not None:
            features = self.hidden2(features)
        logits = self.gate(features)  # [B, T, 4]
        # Replace NaN/inf and bound the logits to [-30, 30].
        logits = torch.nan_to_num(logits, nan=0.0, posinf=30.0, neginf=-30.0).clamp(-30.0, 30.0)
        if not training:
            return logits.argmax(dim=-1)
        return F.softmax(logits, dim=-1), logits
class KGVQCodebook(TernaryVQCodebook):
    """Compatibility wrapper for the KG/composite VQ.

    The old implementation kept float32 `embed` and `embed_avg` buffers. The
    production path now uses the same packed ternary/int8 backing table as the
    shared VQ so default 5M-code KG construction cannot allocate hidden float
    codebook state.
    """

    def __init__(self, codebook_size=KGVQ_CODEBOOK_SIZE, codebook_dim=KGVQ_CODEBOOK_DIM,
                 decay=KGVQ_DECAY, commitment_weight=KGVQ_COMMITMENT_WEIGHT,
                 threshold_ema_dead_code=KGVQ_DEAD_CODE_THRESHOLD):
        super().__init__(
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            commitment_weight=commitment_weight,
        )
        # Stored for compatibility; not read by the methods of this class.
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code

    @property
    def embed(self):
        """Full codebook matrix; only available for exact-lookup-sized codebooks."""
        if self.codebook_size <= self.exact_lookup_max:
            return super().embed
        raise RuntimeError(
            "Full KG VQ materialization is disabled for large ternary codebooks; "
            "query rows through `table(indices)` instead."
        )

    def _ema_update(self, x_flat, indices):
        """Add the usage counts of ``indices`` to ``cluster_size`` (saturating int16).

        ``x_flat`` is accepted but not used.
        """
        used, hits = torch.unique(indices, return_counts=True)
        previous = self.cluster_size[used].to(torch.int32)
        hits = hits.to(device=previous.device, dtype=torch.int32)
        self.cluster_size[used] = torch.clamp(previous + hits, 0, 32767).to(torch.int16)

    def _dead_code_reset(self, x_flat):
        """No-op kept for interface compatibility."""
        return None
class CompositeProposalHead(nn.Module):
    """Multi-proposal head from pooled GNN output (Phase 17).

    Projects the pooled output (``pool_out`` [B, D]) to ``k_max`` composite
    motif proposals, quantizes them with a ``TernaryVQCodebook``, and masks
    proposals whose halting score is below 0.5.
    """

    def __init__(self, dim=HIDDEN_DIM, codebook_dim=KGVQ_CODEBOOK_DIM,
                 k_max=K_MAX_COMPOSITES, codebook_size=KGVQ_CODEBOOK_SIZE,
                 tscale_type=TScaleType.T32):
        super().__init__()
        self.dim = dim
        self.k_max = k_max
        self.codebook_dim = codebook_dim
        self.proj = TernaryScaleTensor(dim, k_max * codebook_dim, tscale_type=tscale_type)
        self.kgvq = TernaryVQCodebook(
            codebook_size=codebook_size, codebook_dim=codebook_dim, tscale_type=tscale_type,
        )
        self.halt_gate = TernaryScaleTensor(dim, k_max, tscale_type=tscale_type)
        self.diversity_weight = 0.1

    def forward(self, pool_out):
        """Return ``(composite_ids, loss, halt)``.

        ``composite_ids`` holds -1 wherever ``halt < 0.5``; ``loss`` is the VQ
        loss plus a weighted penalty on pairwise proposal similarity.
        """
        batch = pool_out.shape[0]
        proposals = self.proj(pool_out).view(batch, self.k_max, self.codebook_dim)
        _, composite_ids, vq_loss = self.kgvq(proposals)
        halt = torch.sigmoid(self.halt_gate(pool_out).clamp(-12.0, 12.0))  # [B, K_MAX]
        composite_ids = composite_ids.masked_fill(halt < 0.5, -1)
        # Mean cosine similarity over distinct proposal pairs (upper triangle).
        unit = F.normalize(proposals, dim=-1)
        pairwise = unit @ unit.transpose(-1, -2)
        upper = torch.triu(pairwise, diagonal=1)
        pair_count = self.k_max * (self.k_max - 1) / 2
        diversity_loss = upper.sum(dim=(-1, -2)).mean() / max(pair_count, 1)
        diversity_loss = diversity_loss * self.diversity_weight
        return composite_ids, vq_loss + diversity_loss, halt
class MoEGraph(nn.Module):
"""Fused graph traversal + centroid-based MoE routing + ACT halting.
Each ACT iteration: traverse KG → aggregate neighbor emb → centroid route →
run expert → halt check. All operations at MG_WORKSPACE_DIM (1024).
Replaces: TernaryGraph + GraphMoEGate + GraphACTCell + SharedProjectionMoE + MoEACTCell.
"""
def __init__(self, cb_dim=MG_WORKSPACE_DIM, trigram_dim=HIDDEN_DIM,
             codebook_dim=CODEBOOK_DIM,
             num_experts=MG_N_EXPERTS, core_rank=MG_CORE_RANK,
             shared_inter=MG_SHARED_INTER, max_iters=MG_ACT_ITERS,
             halt_threshold=0.99, tscale_type=TScaleType.T32,
             codebook_size=CODEBOOK_SIZE,
             active_graph_max_nodes=4096,
             top_k=1):
    """Build projections, edge buffers, experts and the halting machinery.

    Args:
        cb_dim: workspace width; ``down_proj`` maps ``trigram_dim -> cb_dim``
            and ``up_proj`` maps back.
        trigram_dim: width of the tensors given to ``forward``.
        codebook_dim: width of codebook entries fed to ``codebook_up``.
        num_experts: number of experts (and of routing centroids).
        core_rank: output width of each expert's ``W_gate``.
        shared_inter: output width of ``shared_up`` / ``W_transform``.
        max_iters: maximum number of iterations of the forward loop; also
            passed to the hop adapter as ``max_hops``.
        halt_threshold: cumulative halting probability at which a position
            is marked as halted.
        tscale_type: scale type forwarded to the ternary layers.
        codebook_size: number of codebook entries (graph nodes).
        active_graph_max_nodes: if ``codebook_size`` exceeds this value the
            fixed-capacity "active" edge store is used instead of dense edges.
        top_k: number of experts evaluated per position.
    """
    super().__init__()
    self.cb_dim = cb_dim
    self.trigram_dim = trigram_dim
    self.codebook_dim = codebook_dim
    self.num_experts = num_experts
    self.core_rank = core_rank
    self.shared_inter = shared_inter
    self.max_iters = max_iters
    self.halt_threshold = halt_threshold
    self.codebook_size = codebook_size
    self.active_graph_max_nodes = active_graph_max_nodes
    self.top_k = top_k
    # Input/output projections between trigram space and the workspace.
    # NOTE(review): TernaryScaleTensor is presumably a linear-like layer taking
    # (in_features, out_features) — confirm against its definition.
    self.down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
    self.down_norm = TernaryRMSNorm(trigram_dim, tscale_type=tscale_type)
    self.up_proj = TernaryScaleTensor(cb_dim, trigram_dim, tscale_type=tscale_type)
    self.up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
    self.attn_down_proj = TernaryScaleTensor(trigram_dim, cb_dim, tscale_type=tscale_type)
    self.codebook_up = TernaryScaleTensor(codebook_dim, cb_dim, tscale_type=tscale_type)
    # Large codebooks use a fixed-capacity edge store instead of dense edges.
    self.use_active_edge_store = self.codebook_size > self.active_graph_max_nodes
    self.active_edge_capacity = max(int(self.active_graph_max_nodes) * 16, 65_536)
    if self.use_active_edge_store:
        # Dense buffers are still registered (empty) so the attribute names exist.
        self.register_buffer("edge_index", torch.zeros(2, 0, dtype=torch.int32))
        self.register_buffer("edge_attr", torch.zeros(0, dtype=torch.int8))
        self.register_buffer("edge_score", torch.zeros(0, dtype=torch.int8))
        # Slots start at -1; other methods treat ``active_edge_src >= 0`` as "in use".
        self.register_buffer("active_edge_src", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
        self.register_buffer("active_edge_dst", torch.full((self.active_edge_capacity,), -1, dtype=torch.int32))
        self.register_buffer("active_edge_attr", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
        self.register_buffer("active_edge_score", torch.zeros(self.active_edge_capacity, dtype=torch.int8))
        # Next write position inside the active edge buffers.
        self.register_buffer("active_edge_ptr", torch.zeros((), dtype=torch.long))
    else:
        # Dense store: 10 outgoing edges per node with random destinations.
        num_edges = self.codebook_size * 10
        src = torch.arange(self.codebook_size, dtype=torch.int32).repeat_interleave(10)
        dst = torch.randint(0, self.codebook_size, (num_edges,), dtype=torch.int32)
        self.register_buffer("edge_index", torch.stack([src, dst], dim=0))
        # Random initial edge values drawn from {-1, 0, 1}; scores start at 0.
        edge_init = torch.randint(-1, 2, (num_edges,), dtype=torch.int8)
        self.register_buffer("edge_attr", edge_init)
        self.register_buffer("edge_score", torch.zeros(num_edges, dtype=torch.int8))
    # Counter of edge updates since the last requantization (both stores).
    self.register_buffer("_steps_since_requant", torch.tensor(0, dtype=torch.long))
    self.requant_every = KG_REQUANT_EVERY
    self.kg_ternary_threshold = KG_TERNARY_THRESHOLD
    self.kg_ema_alpha = KG_EMA_ALPHA
    # One routing centroid per expert.
    self.centroids = TernaryEmbeddingTable(num_experts, cb_dim, tscale_type=tscale_type, normalize=True)
    # Layers shared by all experts.
    self.shared_up_norm = TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
    self.shared_up = TernaryScaleTensor(cb_dim, shared_inter, tscale_type=tscale_type)
    self.shared_down_norm = TernaryRMSNorm(shared_inter, tscale_type=tscale_type)
    self.shared_down = TernaryScaleTensor(shared_inter, cb_dim, tscale_type=tscale_type)
    # Per-expert low-rank path: cb_dim -> core_rank -> shared_inter.
    self.W_gate = nn.ModuleList([
        TernaryScaleTensor(cb_dim, core_rank, tscale_type=tscale_type)
        for _ in range(num_experts)
    ])
    self.W_gate_norms = nn.ModuleList([
        TernaryRMSNorm(cb_dim, tscale_type=tscale_type)
        for _ in range(num_experts)
    ])
    self.W_transform = nn.ModuleList([
        TernaryScaleTensor(core_rank, shared_inter, tscale_type=tscale_type)
        for _ in range(num_experts)
    ])
    self.W_transform_norms = nn.ModuleList([
        TernaryRMSNorm(core_rank, tscale_type=tscale_type)
        for _ in range(num_experts)
    ])
    self.hop_lora = GNNLoRAAdapter(dim=cb_dim, rank=32, max_hops=max_iters)
    self.halting = HaltingUnit(dim=cb_dim, tscale_type=tscale_type)
    self.lti = LTIInjection(cb_dim)
    # Codebook sources read by the codebook helper methods; both start unset.
    # NOTE(review): presumably assigned by the owning model after construction — confirm.
    self._codebook_embed = None
    self._codebook_table = None
def _codebook_tensor(self, device):
if self._codebook_table is not None:
idx = torch.arange(self.codebook_size, device=device)
codebook = self._codebook_table(idx)
if codebook.shape[-1] != self.cb_dim:
codebook = self.codebook_up(codebook)
return codebook
if self._codebook_embed is not None:
codebook = self._codebook_embed.to(device=device).squeeze(0)
if codebook.shape[-1] != self.cb_dim:
codebook = self.codebook_up(codebook)
return codebook
return torch.zeros(self.codebook_size, self.cb_dim, device=device)
def _active_codebook_features(self, vq_indices):
if self._codebook_table is not None:
safe_idx = vq_indices.clamp(min=0, max=self.codebook_size - 1)
active_code = self._codebook_table(safe_idx)
elif self._codebook_embed is not None:
codebook = self._codebook_embed.to(device=vq_indices.device).squeeze(0)
safe_idx = vq_indices.clamp(min=0, max=codebook.shape[0] - 1)
active_code = codebook[safe_idx]
else:
return torch.zeros(*vq_indices.shape, self.cb_dim, device=vq_indices.device)
if active_code.shape[-1] != self.cb_dim:
active_code = self.codebook_up(active_code)
return active_code
def _neighbor_aggregate(self, node_features, threshold):
    """Sum signed source-node features into each edge's destination node.

    For every dense edge ``(src, dst)`` the message is the source node's
    feature row multiplied by the edge value obtained from
    ``StickyZoneSTE.apply(self.edge_attr, threshold)``; messages are summed
    per destination.

    Args:
        node_features: 2-D tensor of per-node features, indexed by node id.
        threshold: forwarded unchanged to ``StickyZoneSTE.apply``.

    Returns:
        Tensor of shape ``[codebook_size, D]`` with the same dtype and device
        as ``node_features``; nodes without incoming edges stay zero.
    """
    _, D = node_features.shape
    aggregated = torch.zeros(self.codebook_size, D, device=node_features.device, dtype=node_features.dtype)
    # The edge buffers are stored as int32, but scatter_add_ only accepts
    # int64 indices, so widen them before use.
    src_idx = self.edge_index[0].long()
    dst_idx = self.edge_index[1].long()
    edge_ternary = StickyZoneSTE.apply(self.edge_attr, threshold)
    messages = edge_ternary.unsqueeze(1).to(node_features.dtype) * node_features[src_idx]
    aggregated.scatter_add_(0, dst_idx.unsqueeze(1).expand(-1, D), messages)
    return aggregated
def _run_expert_batch(self, x, expert_idx):
B, T, D = x.shape
N = B * T
x_flat = rearrange(x, 'b t d -> (b t) d')
exp_flat = rearrange(expert_idx, 'b t -> (b t)')
shared_hidden = F.silu(self.shared_up(self.shared_up_norm(x_flat)))
sort_idx = exp_flat.argsort()
sorted_experts = exp_flat[sort_idx]
expert_counts = torch.bincount(sorted_experts, minlength=self.num_experts)
expert_boundaries = torch.cumsum(expert_counts, dim=0)
out_flat = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for e in range(self.num_experts):
start = expert_boundaries[e] - expert_counts[e]
end = expert_boundaries[e]
if start == end:
continue
tok_idx = sort_idx[start:end]
inp = x_flat[tok_idx]
sh = shared_hidden[tok_idx]
gate = self.W_gate[e](self.W_gate_norms[e](inp))
core = self.W_transform[e](self.W_transform_norms[e](gate))
expert_out = self.shared_down(self.shared_down_norm(core * sh))
out_flat[tok_idx] = expert_out
return rearrange(out_flat, '(b t) d -> b t d', b=B, t=T)
def _run_expert(self, x, expert_idx):
return self._run_expert_batch(x, expert_idx)
def _active_node_add(self, vq_output, vq_indices):
return vq_output + self._active_codebook_features(vq_indices)
def forward(self, trigram_input, vq_indices, attention_output=None,
            memgram_cb_output=None, threshold=0.05):
    """Run the iterative traverse -> route -> expert -> halt loop.

    Args:
        trigram_input: tensor of shape ``[B, T, D]``; normalised and projected
            to the workspace by ``down_norm`` / ``down_proj``.
        vq_indices: codebook ids used to gather graph features per position.
        attention_output: optional tensor, projected with ``attn_down_proj``
            and added to the traversal on every iteration.
        memgram_cb_output: optional tensor added to the traversal on
            iterations 1 and 3; if its last dimension is not ``cb_dim`` it is
            averaged over that dimension and broadcast to ``cb_dim``.
        threshold: forwarded to the neighbour aggregation (dense graph only).

    Returns:
        ``(output, ponder_loss)``: the halting-weighted accumulation of expert
        outputs passed through ``up_norm`` / ``up_proj``, and the mean
        remaining halting budget summed over iterations, divided by
        ``max_iters``.
    """
    B, T, _ = trigram_input.shape
    device = trigram_input.device
    x = self.down_proj(self.down_norm(trigram_input))
    attn_cb = None
    if attention_output is not None:
        attn_cb = self.attn_down_proj(self.down_norm(attention_output))

    halted = torch.zeros(B, T, device=device, dtype=torch.bool)
    cumulative_p = torch.zeros(B, T, device=device)
    acc = torch.zeros_like(x)
    total_ponder = torch.zeros(B, T, device=device)
    last_x = x
    initial_x = x

    # NOTE(review): this re-derives the graph mode from the sizes instead of
    # reading self.use_active_edge_store, so a later set_adjacency() call does
    # not switch a large-codebook model to the dense path here — confirm intent.
    use_active_graph = self.codebook_size > self.active_graph_max_nodes

    # Loop-invariant work, hoisted: the codebook, edge buffers and threshold do
    # not change inside this call, so the aggregated neighbour features (and
    # their gather by vq_indices) only need to be computed once.
    neighbor_term = None
    if not use_active_graph:
        node_features = self._codebook_tensor(device)
        neighbor_term = self._neighbor_aggregate(node_features, threshold)[vq_indices]

    # The normalised routing centroids are likewise identical on every iteration.
    centroid_ids = torch.arange(self.num_experts, device=device)
    cent_norm = F.normalize(self.centroids(centroid_ids), dim=-1, eps=1e-8)

    for iter_t in range(self.max_iters):
        if use_active_graph:
            traversal = self._active_node_add(x, vq_indices)
        else:
            traversal = x + neighbor_term
        if attn_cb is not None:
            traversal = traversal + attn_cb
        if iter_t in (1, 3) and memgram_cb_output is not None:
            memgram_raw = memgram_cb_output.to(device)
            if memgram_raw.shape[-1] != self.cb_dim:
                memgram_raw = memgram_raw.mean(dim=-1, keepdim=True).expand(-1, -1, self.cb_dim)
            traversal = traversal + memgram_raw
        traversal = traversal + self.hop_lora(traversal, iter_t)

        # Cosine-similarity routing against the expert centroids.
        trav_norm = F.normalize(traversal, dim=-1, eps=1e-8)
        scores = trav_norm @ cent_norm.T
        if self.top_k <= 1:
            _, expert_idx = scores.max(dim=-1)
            expert_out = self._run_expert(traversal, expert_idx)
        else:
            scores_topk, topk_idx = scores.topk(k=self.top_k, dim=-1)
            weights = F.softmax(scores_topk / 0.1, dim=-1)
            expert_out = 0
            for i in range(self.top_k):
                wi = weights[..., i:i+1]
                ei = topk_idx[..., i]
                expert_out = expert_out + wi * self._run_expert(traversal, ei)
        last_x = expert_out

        # Halting bookkeeping: a position that crosses the threshold on this
        # step contributes with its remaining budget instead of p; positions
        # that halted earlier contribute nothing.
        p = self.halting(expert_out).squeeze(-1)
        still_running = ~halted
        remainder = (1.0 - cumulative_p).clamp(min=0)
        weight = torch.where(
            cumulative_p + p >= self.halt_threshold,
            remainder, p,
        )
        weight = weight * still_running.float()
        acc = acc + weight.unsqueeze(-1) * expert_out
        cumulative_p = cumulative_p + p * still_running.float()
        halted = halted | (cumulative_p >= self.halt_threshold)
        total_ponder = total_ponder + (1.0 - cumulative_p).clamp(min=0)
        x = self.lti(x, initial_x, expert_out)
        if halted.all():
            break

    # Positions that never crossed the threshold additionally receive the
    # last expert output in full.
    never_halted = (~halted).float().unsqueeze(-1)
    acc = acc + never_halted * last_x
    output = self.up_proj(self.up_norm(acc))
    # max(..., 1) keeps the loss finite if the module is built with max_iters == 0.
    ponder_loss = total_ponder.mean() / max(self.max_iters, 1)
    return output, ponder_loss
@torch.no_grad()
def update_kg_edges(self, all_vq_indices):
    """Update edge scores from the codebook ids seen in a batch.

    With the active edge store enabled, the work is delegated to
    ``_update_active_edges``. Otherwise, every dense edge whose source id
    occurs in ``all_vq_indices`` gets its score raised by 1 if its destination
    id also occurs and lowered by 1 if not (saturating at the int8 range).
    The dense requantization step is invoked afterwards in either case.
    """
    if self.use_active_edge_store:
        self._update_active_edges(all_vq_indices)
        return
    seen_ids = torch.unique(
        all_vq_indices.to(device=self.edge_index.device, dtype=torch.int32)
    )
    touched = torch.isin(self.edge_index[0], seen_ids)
    if touched.any():
        # Widen to int16 so the +/-1 step cannot wrap before clamping.
        previous = self.edge_score[touched].to(torch.int16)
        dst_seen = torch.isin(self.edge_index[1][touched], seen_ids)
        step = dst_seen.to(torch.int16) * 2 - 1  # +1 if seen, -1 otherwise
        self.edge_score[touched] = torch.clamp(previous + step, -128, 127).to(torch.int8)
    self._requantize_dense_edges()
@torch.no_grad()
def _update_active_edges(self, all_vq_indices):
ids = all_vq_indices.to(device=self.active_edge_src.device, dtype=torch.int32)
if ids.numel() < 2:
self._steps_since_requant.add_(1)
return
seq = ids.reshape(-1, ids.shape[-1]) if ids.dim() > 1 else ids.reshape(1, -1)
src = seq[:, :-1].reshape(-1)
dst = seq[:, 1:].reshape(-1)
valid = (src >= 0) & (dst >= 0) & (src < self.codebook_size) & (dst < self.codebook_size) & (src != dst)
src = src[valid]
dst = dst[valid]
if src.numel() == 0:
self._steps_since_requant.add_(1)
return
n_edges = min(src.numel(), self.active_edge_capacity)
src = src[-n_edges:]
dst = dst[-n_edges:]
ptr = int(self.active_edge_ptr.item())
slots = (torch.arange(n_edges, device=src.device, dtype=torch.long) + ptr) % self.active_edge_capacity
self.active_edge_src[slots] = src
self.active_edge_dst[slots] = dst
score = torch.clamp(self.active_edge_score[slots].to(torch.int16) + 1, -128, 127)
self.active_edge_score[slots] = score.to(torch.int8)
self.active_edge_attr[slots] = 1
self.active_edge_ptr.fill_((ptr + n_edges) % self.active_edge_capacity)
self._requantize_active_edges()
@torch.no_grad()
def _requantize_dense_edges(self):
if self._steps_since_requant.item() < self.requant_every:
self._steps_since_requant.add_(1)
return
self.edge_attr = self._score_to_attr(self.edge_score)
score = self.edge_score.to(torch.int16)
score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
self.edge_score = score.to(torch.int8)
self._steps_since_requant.zero_()
@torch.no_grad()
def _requantize_active_edges(self):
if self._steps_since_requant.item() < self.requant_every:
self._steps_since_requant.add_(1)
return
active = self.active_edge_src >= 0
if active.any():
self.active_edge_attr[active] = self._score_to_attr(self.active_edge_score[active])
score = self.active_edge_score[active].to(torch.int16)
score = torch.where(score > 0, score - 1, torch.where(score < 0, score + 1, score))
self.active_edge_score[active] = score.to(torch.int8)
self._steps_since_requant.zero_()
def _score_to_attr(self, score):
threshold = max(1, int(round(float(self.kg_ternary_threshold) * 8)))
score_i = score.to(torch.int16)
return torch.where(
score_i >= threshold,
torch.ones_like(score, dtype=torch.int8),
torch.where(
score_i <= -threshold,
torch.full_like(score, -1, dtype=torch.int8),
torch.zeros_like(score, dtype=torch.int8),
),
)
@torch.no_grad()
def monitor_graph_health(self, threshold=0.05):
    """Return summary statistics of the current edge store as a dict.

    Keys: ``sparsity`` (fraction of edges whose attribute is 0),
    ``isolated_nodes`` (codebook entries that appear in no edge),
    ``avg_polarity`` ((#positive - #negative) / #non-zero edges),
    ``dead_edges`` (attribute 0 but score non-zero), ``score_mean``,
    ``score_max`` (largest absolute score) and ``active_edges`` (number of
    edges inspected). ``threshold`` is accepted but not used here.
    """
    if self.use_active_edge_store:
        occupied = self.active_edge_src >= 0
        if not occupied.any():
            # Empty active store: report a fully sparse, fully isolated graph.
            return {
                "sparsity": 1.0, "isolated_nodes": self.codebook_size,
                "avg_polarity": 0.0, "dead_edges": 0,
                "score_mean": 0.0, "score_max": 0.0,
                "active_edges": 0,
            }
        attrs = self.active_edge_attr[occupied]
        scores = self.active_edge_score[occupied]
        endpoints = torch.cat([self.active_edge_src[occupied], self.active_edge_dst[occupied]])
    else:
        attrs = self.edge_attr
        scores = self.edge_score
        endpoints = torch.cat([self.edge_index[0], self.edge_index[1]])

    connected = torch.unique(endpoints)
    signs = attrs.sign()
    is_zero = signs == 0
    n_pos = (signs > 0).sum().item()
    n_neg = (signs < 0).sum().item()

    if signs.numel():
        sparsity = is_zero.float().mean().item()
    else:
        sparsity = 1.0
    if scores.numel():
        score_mean = scores.float().mean().item()
        score_max = scores.float().abs().max().item()
    else:
        score_mean = 0.0
        score_max = 0.0

    return {
        "sparsity": sparsity,
        "isolated_nodes": max(int(self.codebook_size) - int(connected.numel()), 0),
        "avg_polarity": (n_pos - n_neg) / max(n_pos + n_neg, 1),
        "dead_edges": (is_zero & (scores != 0)).sum().item(),
        "score_mean": score_mean, "score_max": score_max,
        "active_edges": int(signs.numel()),
    }
def set_adjacency(self, edge_index, edge_attr_init=None):
    """Replace the dense edge list and switch off the active edge store.

    Args:
        edge_index: tensor whose second dimension is the number of edges;
            stored as int32 on the device of the current ``edge_attr``.
        edge_attr_init: optional per-edge initial values; only their sign is
            kept (zero stays zero). If omitted, attributes are drawn at
            random from {-1, 0, 1}.

    ``edge_score`` is reset to a copy of the new ``edge_attr``.
    """
    self.use_active_edge_store = False
    target = self.edge_attr.device
    self.edge_index = edge_index.to(device=target, dtype=torch.int32)
    if edge_attr_init is None:
        ternary = torch.randint(-1, 2, (edge_index.size(1),),
                                device=target, dtype=torch.int8)
    else:
        signs = edge_attr_init.sign()
        nonzero = (edge_attr_init.abs() > 0).to(edge_attr_init.dtype)
        ternary = (signs * nonzero).to(device=target, dtype=torch.int8)
    self.edge_attr = ternary
    self.edge_score = self.edge_attr.clone()