| """KV Ledger — append-only ring buffer of motif IDs (int32), max 256K entries. |
| |
| Per D-57: Append-only ring buffer of motif IDs (int32), max 256K entries. |
| When full, oldest entries are overwritten. Stored as flat tensor on GPU. |
| |
| Per D-59: The ledger stores only what the model outputs (motif IDs), |
| not input prompts. Prompts go through VQ -> GNN -> Motif pipeline first. |
| |
| KV is consumed by the ContextAttentionScheduler. Its output is injected into |
| MoEGraph, which then conditions the router and output heads through the shared |
| processed relational state. |
| """ |
| import torch |
| import torch.nn as nn |
| from ..config import KV_LEDGER_SIZE, SLIDING_WINDOW_SIZE |
| from .ring_buffer import GPURingBuffer |
|
|
|
|
class KVLedger(nn.Module):
    """Append-only ledger of motif IDs backed by a fixed-capacity int32 ring buffer.

    Per D-57, once ``max_size`` entries have been written, new appends
    overwrite the oldest ones. All storage is delegated to ``GPURingBuffer``.
    """

    def __init__(self, max_size=KV_LEDGER_SIZE):
        """Create an empty ledger holding at most ``max_size`` motif IDs."""
        super().__init__()
        # int32 entries; dim=1 presumably means one value per slot, i.e. storage
        # of shape (max_size, 1), which is why reads below squeeze(-1).
        # NOTE(review): confirm against GPURingBuffer.
        self.ring = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)

    def _empty(self) -> torch.Tensor:
        """Return an empty 1-D int32 tensor on the ring's device."""
        return torch.zeros(0, dtype=torch.int32, device=self.ring.buffer.device)

    def append(self, motif_id: int) -> None:
        """Append one motif ID, converted to int32 on the ring's device.

        ``torch.as_tensor`` behaves like ``torch.tensor`` for Python ints, but it
        also accepts an already-built 0-d tensor without the copy-construct
        warning that ``torch.tensor`` emits for tensor inputs.
        """
        value = torch.as_tensor(motif_id, dtype=torch.int32, device=self.ring.buffer.device)
        self.ring.append(value)

    def get_sliding_window(self, n=SLIDING_WINDOW_SIZE) -> torch.Tensor:
        """Return ``self.ring.get_last_n(n)``.

        Presumably these are the ``n`` most recently appended IDs.
        TODO: confirm ordering against GPURingBuffer.
        """
        return self.ring.get_last_n(n)

    def get_range(self, start, end) -> torch.Tensor:
        """Return the motif IDs held in ring-storage slots ``[start, end)``.

        ``start`` and ``end`` index the underlying storage slots directly, not
        append order. On a full ring, a range running past the last slot wraps
        around to slot 0.

        The result is clamped to valid data:

        * Before the ring has wrapped, only slots ``[0, size)`` are read.
        * Once full, at most one lap (``max_size`` slots) is read, so no slot
          is returned twice.

        Returns:
            1-D int32 tensor. It is empty when ``start < 0``, ``end <= start``,
            or ``start >= size``. When no wrap is needed, the result is a view
            of the ring storage, so later overwrites of those slots are visible
            through it.
        """
        size = self.ring.size
        capacity = self.ring.max_size
        if start < 0 or end <= start or start >= size:
            return self._empty()

        if size < capacity:
            # Not yet wrapped. The "start >= size -> empty" guard implies only
            # slots [0, size) hold data (assumes the ring fills from slot 0),
            # so never read past them.
            return self.ring.buffer[start:min(end, size)].squeeze(-1)

        # Full ring: every slot holds data. Read at most one lap so the
        # wrapped tail can never overlap the head.
        n = min(end - start, capacity)
        if start + n <= capacity:
            return self.ring.buffer[start:start + n].squeeze(-1)
        head = self.ring.buffer[start:].squeeze(-1)
        tail = self.ring.buffer[:start + n - capacity].squeeze(-1)
        return torch.cat([head, tail])

    def get_sparse(self, stride=8) -> torch.Tensor:
        """Return every ``stride``-th element of ``self.ring.get_all()``.

        Elements are taken starting at index 0, bounded by both ``size`` and
        the length ``get_all()`` actually returned.

        The result is a fresh tensor (advanced indexing copies), so it never
        aliases ring storage.

        Raises:
            ValueError: if ``stride`` is not positive and the ledger is non-empty.
        """
        size = self.ring.size
        if size == 0:
            return self._empty()
        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride}")
        all_vals = self.ring.get_all()
        # Bounding the arange up front replaces the original post-hoc
        # "indices < len(all_vals)" filter; the selected indices are identical.
        upper = min(size, len(all_vals))
        indices = torch.arange(0, upper, stride, device=all_vals.device, dtype=torch.long)
        return all_vals[indices]

    @property
    def size(self):
        """Number of entries currently held (``self.ring.size``)."""
        return self.ring.size

    def __len__(self):
        """Number of entries currently held; same as ``size``."""
        return self.ring.size

    def reset(self):
        """Clear the ledger by resetting the underlying ring buffer."""
        self.ring.reset()
|
|