# ARBS/arbitor/attention/kv_ledger.py
# Uploaded by CLIWorks via huggingface_hub (commit d8bc908).
"""KV Ledger — append-only ring buffer of motif IDs (int32), max 256K entries.

Per D-57: the ledger is a ring buffer stored as a flat tensor on the GPU; once
it is full, the oldest entries are overwritten.

Per D-59: the ledger stores only what the model outputs (motif IDs), not input
prompts. Prompts go through the VQ -> GNN -> Motif pipeline first.

The ledger is consumed by the ContextAttentionScheduler, whose output is
injected into MoEGraph; MoEGraph then conditions the router and output heads
through the shared processed relational state.
"""
import torch
import torch.nn as nn
from ..config import KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE
from .ring_buffer import GPURingBuffer
class KVLedger(nn.Module):
    """Ring buffer of motif IDs (int32) backed by a ``GPURingBuffer``.

    Only model outputs (motif IDs) are appended here. All storage, eviction and
    ordering logic is delegated to ``self.ring``.
    """

    def __init__(self, max_size=KV_LEDGER_SIZE):
        super().__init__()
        # dim=1: presumably one int32 value per slot, i.e. ``ring.buffer`` is
        # (max_size, 1) — the ``squeeze(-1)`` calls in get_range rely on that.
        # TODO(review): confirm against GPURingBuffer.
        self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)

    def _empty(self):
        """Return an empty int32 tensor on the same device as the ring buffer."""
        return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)

    def append(self, motif_id: int):
        """Append a single motif ID to the ring as an int32 scalar tensor."""
        self.ring.append(
            torch.tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device)
        )

    def get_sliding_window(self, n=SLIDING_WINDOW_SIZE):
        """Return the result of ``ring.get_last_n(n)`` (the most recent ``n`` entries)."""
        return self.ring.get_last_n(n)

    def get_range(self, start, end):
        """Return the motif IDs held in buffer slots ``[start, end)``.

        ``start``/``end`` index the underlying buffer slots directly. Once the
        ring is full, a range running past the last slot wraps around to slot 0.
        The result is clamped so that it never includes slots at or beyond
        ``size`` while the ring is still filling. It also never covers more than
        one full lap of the ring, so no slot is repeated.

        Args:
            start: First slot index (inclusive). Negative values yield an empty result.
            end: Last slot index (exclusive).

        Returns:
            1-D int32 tensor, empty when the (clamped) range is empty.
        """
        capacity = self.ring.max_size
        # Slots that hold data. min() guards against a size counter larger
        # than the capacity.
        filled = min(self.ring.size, capacity)
        if start < 0 or start >= filled:
            return self._empty()
        if filled < capacity:
            # Not yet wrapped: only slots [0, filled) have been written
            # (assumes the ring fills from slot 0, as the guard above implies).
            n = min(end - start, filled - start)
        else:
            # Full ring: allow wrap-around, but never more than one lap.
            n = min(end - start, capacity)
        if n <= 0:
            return self._empty()
        if start + n <= capacity:
            return self.ring.buffer[start:start + n].squeeze(-1)
        # Wrapped range: tail of the buffer followed by its head.
        first = self.ring.buffer[start:].squeeze(-1)
        second = self.ring.buffer[:n - (capacity - start)].squeeze(-1)
        return torch.cat([first, second])

    def get_sparse(self, stride=8):
        """Return every ``stride``-th entry of ``ring.get_all()``, starting at index 0.

        Raises:
            ValueError: if ``stride`` is smaller than 1.
        """
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        size = self.ring.size
        if size == 0:
            return self._empty()
        all_vals = self.ring.get_all()
        # Never index past what get_all() actually returned.
        limit = min(size, len(all_vals))
        indices = torch.arange(
            0, limit, stride, device=self.ring.buffer.device, dtype=torch.long
        )
        # Advanced indexing returns a copy, so the result does not alias the ring.
        return all_vals[indices]

    @property
    def size(self):
        """Number of entries reported by the underlying ring (``ring.size``)."""
        return self.ring.size

    def __len__(self):
        return self.ring.size

    def reset(self):
        """Clear the ledger by resetting the underlying ring."""
        self.ring.reset()