File size: 1,762 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""GPURingBuffer — generic GPU ring buffer utility.

O(1) append via circular pointer, chronological get_last_n with wrap handling.
Element storage is a registered buffer, so it follows device moves and is saved in
state_dict; the ptr/size counters are plain Python ints and are not serialized.
"""
import torch
import torch.nn as nn


class GPURingBuffer(nn.Module):
    """Fixed-capacity circular buffer backed by a registered tensor buffer.

    Rows live in ``self.buffer`` with shape ``(max_size, max(dim, 1))``.
    ``append`` writes at ``ptr`` and overwrites the oldest row once the buffer
    is full; ``get_last_n`` returns the most recent rows oldest-first.

    NOTE(review): ``ptr`` and ``size`` are plain Python ints, not registered
    buffers, so they are not saved in ``state_dict``. After ``load_state_dict``
    the row contents are restored but ``ptr``/``size`` keep whatever values the
    module already had (0 for a freshly built one). Registering them would add
    new ``state_dict`` keys and break ``strict=True`` loading of existing
    checkpoints, so that is left as a deliberate follow-up.

    Args:
        max_size: Capacity in rows; must be at least 1.
        dtype: Element dtype of the storage tensor.
        dim: Width of each row; values below 1 are treated as 1.

    Raises:
        ValueError: If ``max_size`` is less than 1.
    """

    def __init__(self, max_size: int, dtype: torch.dtype = torch.int32, dim: int = 1):
        super().__init__()
        if max_size < 1:
            # A zero-capacity ring would hit a modulo-by-zero on the first append.
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.max_size = max_size
        self.ptr = 0  # slot the next append writes to
        self.size = 0  # number of valid rows, capped at max_size
        self.register_buffer("buffer", torch.zeros((max_size, max(dim, 1)), dtype=dtype))

    def append(self, x) -> None:
        """Write ``x`` into the next slot, overwriting the oldest row when full.

        Args:
            x: A tensor, or anything ``torch.tensor`` accepts (converted with the
                buffer's dtype and device). It must be broadcastable to one row
                of ``self.buffer``; a scalar fills the whole row.
        """
        if isinstance(x, torch.Tensor):
            # Store values only: copying a grad-tracking tensor into the
            # persistent buffer would chain every append onto one ever-growing
            # autograd graph and keep all of it alive.
            x = x.detach()
        else:
            x = torch.tensor(x, dtype=self.buffer.dtype, device=self.buffer.device)
        if self.buffer.dim() == 2 and x.dim() == 0:
            x = x.view(1)
        self.buffer[self.ptr] = x
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def get_last_n(self, n: int) -> torch.Tensor:
        """Return the ``n`` most recent rows in chronological (oldest-first) order.

        ``n`` is clamped to ``[0, self.size]``. Width-1 rows are squeezed, so a
        ``dim=1`` buffer yields shape ``(k,)`` and a wider one ``(k, dim)`` --
        including the empty case ``k == 0``.

        NOTE(review): when the requested rows are contiguous in storage the
        result is a view of ``self.buffer`` (later appends / in-place edits can
        show through it); when they wrap around it is a new tensor produced by
        ``torch.cat``. Clone the result if a stable snapshot is required.
        """
        n = max(0, min(n, self.size))
        start = (self.ptr - n) % self.max_size
        if start + n <= self.max_size:
            result = self.buffer[start:start + n]
        else:
            # Wrapped: the older rows sit at the tail of storage and the newer
            # ones at the head, ending just before the write pointer.
            result = torch.cat([self.buffer[start:], self.buffer[:self.ptr]], dim=0)
        if result.dim() > 1 and result.shape[1] == 1:
            result = result.squeeze(-1)
        return result

    def get_all(self) -> torch.Tensor:
        """Return every stored row, oldest first (same shape rules as ``get_last_n``)."""
        return self.get_last_n(self.size)

    def reset(self) -> None:
        """Zero the storage and mark the buffer empty; capacity is unchanged."""
        self.buffer.zero_()
        self.ptr = 0
        self.size = 0