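"""GPURingBuffer — generic fixed-capacity ring buffer backed by a torch tensor.

O(1) append via a circular write pointer; get_last_n returns rows in
chronological order (oldest first), handling wrap-around.
Row storage is a registered buffer, so it follows the module across devices and
is included in state_dict; the write pointer and fill count are plain Python
attributes and are not saved.
"""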
| import torch |
| import torch.nn as nn |
|
|
|
|
class GPURingBuffer(nn.Module):
    """Fixed-capacity circular buffer whose rows live in a registered tensor buffer.

    Rows are written at ``self.ptr``; the pointer wraps modulo ``max_size``,
    so once the buffer is full each append overwrites the oldest row.

    Args:
        max_size: Number of rows the buffer holds. ``append`` indexes slot 0 on
            the first call, so it requires ``max_size >= 1``.
        dtype: Element dtype of the backing tensor (default ``torch.int32``).
        dim: Width of each row. Values <= 1 produce a single-column buffer whose
            reads are returned as 1-D tensors.

    NOTE(review): ``ptr`` and ``size`` are plain Python ints, not registered
    buffers, so ``state_dict()`` / ``load_state_dict()`` round-trip only the
    stored rows; a freshly reloaded module reports itself empty until appended to.
    """

    def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1):
        super().__init__()
        self.max_size = max_size
        self.ptr = 0  # slot the next append writes to
        self.size = 0  # number of valid rows; saturates at max_size
        # Always 2-D so scalar rows and vector rows share one indexing path.
        buffer_shape = (max_size, dim if dim > 1 else 1)
        self.register_buffer("buffer", torch.zeros(buffer_shape, dtype=dtype))

    def append(self, x) -> None:
        """Write one row at the write pointer, overwriting the oldest row when full.

        Args:
            x: A tensor, or any value accepted by ``torch.tensor`` (converted to
                the buffer's dtype and device). A 0-d value is written as a
                1-element row and broadcast across the row's width.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device)
        if self.buffer.dim() == 2 and x.dim() == 0:
            x = x.view(1)
        self.buffer[self.ptr] = x
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def get_last_n(self, n: int) -> torch.Tensor:
        """Return up to the ``n`` most recent rows, oldest first.

        Args:
            n: Number of rows requested; clamped to ``[0, self.size]``, so
                values <= 0 yield an empty result.

        Returns:
            A new tensor (never a view into ``self.buffer``) of shape ``(k,)``
            for single-column buffers or ``(k, dim)`` otherwise, where
            ``k = min(max(n, 0), self.size)``.
        """
        # Clamp below as well: a negative n used to produce a bogus slice.
        n = max(0, min(n, self.size))
        if n == 0:
            # Handled separately so no modulo is taken (max_size may be 0);
            # shaped like a non-empty result and squeezed the same way below.
            result = self.buffer[:0].clone()
        else:
            start = (self.ptr - n) % self.max_size
            end = start + n
            if end <= self.max_size:
                # Contiguous range. Clone so later appends cannot mutate the
                # caller's result (the wrapped branch already copies via cat).
                result = self.buffer[start:end].clone()
            else:
                # Range wraps past the end: tail segment, then head segment.
                result = torch.cat(
                    [self.buffer[start:], self.buffer[: end - self.max_size]], dim=0
                )
        if result.dim() > 1 and result.shape[1] == 1:
            result = result.squeeze(-1)
        return result

    def get_all(self) -> torch.Tensor:
        """Return every valid row, oldest first (see ``get_last_n``)."""
        return self.get_last_n(self.size)

    def reset(self) -> None:
        """Zero the storage and mark the buffer empty."""
        self.buffer.zero_()
        self.ptr = 0
        self.size = 0
|
|