File size: 1,762 Bytes
"""GPURingBuffer — generic GPU ring buffer utility.
O(1) append via a circular write pointer; chronological get_last_n with wrap-around handling.
Buffer contents are stored via register_buffer, so they follow device moves and appear in state_dict;
the ptr/size counters are plain Python ints and are not serialized.
"""
import torch
import torch.nn as nn
class GPURingBuffer(nn.Module):
    """Fixed-capacity circular buffer whose rows live in a registered tensor.

    Rows are stored in ``self.buffer`` with shape ``(max_size, w)``, where
    ``w = dim`` if ``dim > 1`` else ``1``. Once ``max_size`` rows have been
    appended, each new row overwrites the oldest one.

    ``buffer`` is registered with ``register_buffer``, so it follows
    ``.to(...)`` and is included in ``state_dict``. The ``ptr`` and ``size``
    counters are plain Python ints and are NOT included in ``state_dict``.

    Args:
        max_size: Capacity in rows; must be at least 1.
        dtype: Element dtype of the storage tensor.
        dim: Row width; values below 2 produce a width-1 buffer.

    Raises:
        ValueError: If ``max_size`` is less than 1.
    """

    def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1):
        super().__init__()
        if max_size < 1:
            # A zero/negative capacity would otherwise only fail later
            # (ZeroDivisionError from ``% max_size`` on the first append).
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.max_size = max_size
        self.ptr = 0  # slot index the next append writes to
        self.size = 0  # number of valid rows, capped at max_size
        buffer_shape = (max_size, dim if dim > 1 else 1)
        self.register_buffer("buffer", torch.zeros(buffer_shape, dtype=dtype))

    def append(self, x) -> None:
        """Write one row at the write pointer, overwriting the oldest row when full.

        Args:
            x: A tensor, or anything ``torch.tensor`` accepts. Non-tensors are
                converted to the buffer's dtype and device. A 0-d value is
                reshaped to ``(1,)`` and broadcast across the row.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device)
        if self.buffer.dim() == 2 and x.dim() == 0:
            x = x.view(1)
        self.buffer[self.ptr] = x
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def get_last_n(self, n: int) -> torch.Tensor:
        """Return the ``n`` most recent rows in chronological order (oldest first).

        ``n`` is clamped to ``[0, self.size]``. For a width-1 buffer the result
        is 1-D with shape ``(k,)``; otherwise it has shape ``(k, w)``. This
        also holds when ``k == 0``.

        Note: if the requested rows do not wrap past the end of storage, the
        result is a view into ``self.buffer``, so later appends can change it.
        If they do wrap, the result is a new tensor built by ``torch.cat``.
        """
        # Clamp below as well as above: a negative n used to produce a slice
        # with a negative stop index that could return never-written slots.
        n = max(0, min(n, self.size))
        start = (self.ptr - n) % self.max_size
        if start + n <= self.max_size:
            # Contiguous range; n == 0 yields an empty slice with the right
            # dtype/device, which the squeeze below gives a consistent shape.
            result = self.buffer[start:start + n]
        else:
            # Range wraps: tail of storage followed by its head.
            first = self.buffer[start:]
            second = self.buffer[:n - (self.max_size - start)]
            result = torch.cat([first, second], dim=0)
        if result.dim() > 1 and result.shape[1] == 1:
            result = result.squeeze(-1)
        return result

    def get_all(self) -> torch.Tensor:
        """Return every valid row in chronological order (see ``get_last_n``)."""
        return self.get_last_n(self.size)

    def reset(self) -> None:
        """Zero the storage and mark the buffer empty."""
        self.buffer.zero_()
        self.ptr = 0
        self.size = 0
|