import torch


class Memory:
    """Fixed-size table holding one ``memory_dim``-long vector per node.

    The table is a single tensor allocated on ``device``; rows are read and
    written by node id. Everything stored is detached first, so the table
    never carries autograd history.
    """

    def __init__(self, num_nodes, memory_dim, device):
        """Allocate a zero-initialised ``(num_nodes, memory_dim)`` table.

        Args:
            num_nodes: Number of rows in the table.
            memory_dim: Length of each node's vector.
            device: Device on which the table is allocated.
        """
        self.memory = torch.zeros((num_nodes, memory_dim), device=device)

    def get(self, node_ids):
        """Return the rows selected by ``node_ids``, detached from autograd.

        Args:
            node_ids: Any index accepted by tensor indexing (for example an
                integer tensor of node ids).
        """
        return self.memory[node_ids].detach()

    def update(self, node_ids, values):
        """Overwrite the rows for ``node_ids`` with ``values`` in one write.

        Row ``node_ids[i]`` receives ``values[i]``. When an id occurs more
        than once, the value at its last occurrence wins, the same result a
        sequential per-element write loop would produce. Values are detached
        and cast to the table's dtype and device before being stored.

        Args:
            node_ids: 1-D tensor (or array-like / list) of integer node ids.
            values: Tensor whose first dimension is aligned with
                ``node_ids``, or a sequence of per-row tensors. Each row must
                be broadcastable to ``(memory_dim,)``; a 1-D ``values``
                therefore fills each selected row with a single scalar.
        """
        node_ids = torch.as_tensor(node_ids, device=self.memory.device)
        node_ids = node_ids.long().reshape(-1)
        if node_ids.numel() == 0:
            return

        if not isinstance(values, torch.Tensor):
            values = torch.stack([torch.as_tensor(v) for v in values])
        # index_put_ requires matching dtypes, unlike a per-row copy, so cast
        # explicitly to keep accepting e.g. integer or half-precision inputs.
        values = values.detach().to(
            device=self.memory.device, dtype=self.memory.dtype
        )

        # A plain scatter (`memory[ids] = values`) leaves the winner among
        # duplicate ids unspecified. A stable sort keeps equal ids in their
        # original order, so the final element of each run of equal ids is
        # that id's last occurrence; keep only those positions.
        sorted_ids, order = torch.sort(node_ids, stable=True)
        is_last = torch.ones_like(sorted_ids, dtype=torch.bool)
        is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]
        keep = order[is_last]

        rows = values[keep]
        if rows.dim() == 1:
            # One scalar per id: add a trailing axis so it broadcasts across
            # the row, matching per-row assignment of a 0-d tensor.
            rows = rows.unsqueeze(-1)
        self.memory[node_ids[keep]] = rows