File size: 390 Bytes
a3682cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | import torch
class Memory:
    """Fixed-size table holding one ``memory_dim``-wide vector per node.

    The table is ``self.memory``, a zero-initialised tensor of shape
    ``(num_nodes, memory_dim)`` allocated on ``device``. Both reads (``get``)
    and writes (``update``) are detached, so no autograd graph is kept
    through the memory.
    """

    def __init__(self, num_nodes, memory_dim, device):
        self.memory = torch.zeros((num_nodes, memory_dim), device=device)

    def get(self, node_ids):
        """Return the rows of ``self.memory`` selected by ``node_ids``, detached."""
        return self.memory[node_ids].detach()

    def update(self, node_ids, values):
        """Overwrite the rows of ``self.memory`` at ``node_ids`` with ``values``.

        Writes happen in one vectorized indexed assignment rather than a Python
        loop calling ``.item()`` per id (one host/device sync each on GPU).
        If an id occurs more than once, its *last* occurrence wins, matching
        the semantics of assigning the rows one after another.

        Args:
            node_ids: 1-D tensor, array or sequence of integer node indices
                (negative indices count from the end, as in normal indexing).
            values: Tensor (or sequence of tensors) whose ``i``-th entry is
                written to row ``node_ids[i]``; it is detached and cast to the
                memory's device and dtype before writing.

        Raises:
            TypeError: if ``node_ids`` is a scalar rather than a sequence.
            IndexError: if ``values`` has fewer entries than ``node_ids``.
        """
        ids = torch.as_tensor(node_ids, device=self.memory.device)
        if ids.dim() == 0:
            raise TypeError("node_ids must be a 1-D sequence of node indices, got a scalar")
        ids = ids.reshape(-1).long()
        n = ids.numel()
        if n == 0:
            return

        if not torch.is_tensor(values):
            values = torch.stack([torch.as_tensor(v) for v in values])
        if values.dim() == 0 or values.shape[0] < n:
            raise IndexError(f"values has fewer entries than node_ids ({n})")
        # index_put_ requires matching dtype/device, unlike the per-row copy
        # the element-wise loop performed.
        values = values.detach().to(device=self.memory.device, dtype=self.memory.dtype)

        # Map negative ids to their positive alias so duplicates are detected
        # even when one row is addressed as both -1 and num_nodes - 1.
        ids = torch.where(ids < 0, ids + self.memory.shape[0], ids)

        # Indexed assignment with duplicate indices has unspecified write
        # order (nondeterministic on CUDA), so keep only the last occurrence
        # of each id. A stable sort preserves the original order within each
        # run of equal ids, so the final element of every run is the last write.
        sorted_ids, order = torch.sort(ids, stable=True)
        is_last = torch.ones_like(sorted_ids, dtype=torch.bool)
        is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]
        keep = order[is_last]

        rows = values[keep]
        # Reshaping to (k, -1) lets a single scalar per id broadcast across
        # its whole row, as the original element-wise assignment allowed.
        self.memory[ids[keep]] = rows.reshape(keep.numel(), -1)
|