"""GPURingBuffer — generic GPU ring buffer utility. O(1) append via circular pointer, chronological get_last_n with wrap handling. All storage via register_buffer for device movement and state_dict serialization. """ import torch import torch.nn as nn class GPURingBuffer(nn.Module): def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1): super().__init__() self.max_size = max_size self.ptr = 0 self.size = 0 buffer_shape = (max_size, dim if dim > 1 else 1) self.register_buffer("buffer", torch.zeros(buffer_shape, dtype=dtype)) def append(self, x): if not isinstance(x, torch.Tensor): x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device) if self.buffer.dim() == 2 and x.dim() == 0: x = x.view(1) self.buffer[self.ptr] = x self.ptr = (self.ptr + 1) % self.max_size self.size = min(self.size + 1, self.max_size) def get_last_n(self, n: int): n = min(n, self.size) if n == 0: return torch.zeros(0, *self.buffer.shape[1:], dtype=self.buffer.dtype, device=self.buffer.device) start = (self.ptr - n) % self.max_size if start + n <= self.max_size: result = self.buffer[start:start + n] else: first = self.buffer[start:] second = self.buffer[:n - (self.max_size - start)] result = torch.cat([first, second], dim=0) if result.dim() > 1 and result.shape[1] == 1: result = result.squeeze(-1) return result def get_all(self): return self.get_last_n(self.size) def reset(self): self.buffer.zero_() self.ptr = 0 self.size = 0