File size: 887 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""KQ Cache — small ring buffer of last 8K motif IDs for O(1) peek.

Per D-64: Small ring buffer holding last 8K motif IDs. No compression - just raw IDs.
O(1) peek for fast motif lookup without MemGram query.

Per D-65: Updated after each ByteHead output append to ledger.
"""
import torch
import torch.nn as nn
from ..config import KQ_CACHE_SIZE
from .ring_buffer import GPURingBuffer


class KQCache(nn.Module):
    """Cache of recent motif IDs backed by a ``GPURingBuffer``.

    Every operation is delegated to the underlying ring buffer, which is
    created with ``dtype=torch.int32`` and ``dim=1``. IDs are stored as-is;
    this class performs no transformation beyond wrapping each ID in a
    tensor.
    """

    def __init__(self, max_size=KQ_CACHE_SIZE):
        """Create the backing ring buffer.

        Args:
            max_size: Forwarded to ``GPURingBuffer`` as its ``max_size``;
                defaults to the configured ``KQ_CACHE_SIZE``.
        """
        super().__init__()
        backing_store = GPURingBuffer(max_size=max_size, dtype=torch.int32, dim=1)
        self.ring = backing_store

    def append(self, motif_id: int):
        """Push one motif ID into the ring buffer.

        The ID is wrapped in an ``int32`` tensor allocated on the same device
        as the ring's ``buffer`` tensor before being handed to the ring.

        Args:
            motif_id: The motif ID to record.
        """
        target_device = self.ring.buffer.device
        id_tensor = torch.tensor(motif_id, dtype=torch.int32, device=target_device)
        self.ring.append(id_tensor)

    def peek(self, n=1):
        """Return the result of ``GPURingBuffer.get_last_n(n)``.

        Args:
            n: Number of entries requested from the ring buffer.

        Returns:
            Whatever ``get_last_n`` yields -- presumably the ``n`` most
            recently appended IDs; confirm against ``GPURingBuffer``.
        """
        recent_ids = self.ring.get_last_n(n)
        return recent_ids

    @property
    def size(self):
        """The ring buffer's ``size`` attribute, passed through unchanged."""
        current_size = self.ring.size
        return current_size

    def reset(self):
        """Reset the cache by calling ``reset()`` on the ring buffer."""
        backing_store = self.ring
        backing_store.reset()