File size: 390 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch


class Memory:
    """Dense per-node memory table backed by a single 2-D tensor.

    The table is stored in ``self.memory`` with one row per node. Rows
    start out as zeros and are overwritten through :meth:`update`.
    """

    def __init__(self, num_nodes, memory_dim, device):
        """Allocate a zero-filled ``(num_nodes, memory_dim)`` table on *device*."""
        table_shape = (num_nodes, memory_dim)
        self.memory = torch.zeros(table_shape, device=device)

    def get(self, node_ids):
        """Return the rows selected by *node_ids*, detached from autograd."""
        selected_rows = self.memory[node_ids]
        return selected_rows.detach()

    def update(self, node_ids, values):
        """Overwrite the row of each node in *node_ids* with its entry in *values*.

        Entries are written one at a time in the order given, so when an id
        occurs more than once the value at its last occurrence is the one
        that remains. Each value is detached before being stored.
        """
        for position, node_id in enumerate(node_ids):
            new_row = values[position].detach()
            # Each id is converted to a plain Python int before indexing.
            self.memory[int(node_id.item())] = new_row