temporal-twins-code / src /tgn /memory.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
raw
history blame contribute delete
390 Bytes
import torch
class Memory:
    """Dense per-node memory table.

    Row ``i`` of :attr:`memory` holds the memory vector of node ``i``.
    Values returned by :meth:`get` and written by :meth:`update` are
    detached, so no autograd graph is built through the stored state.
    """

    def __init__(self, num_nodes, memory_dim, device):
        """Allocate a zero-initialised ``(num_nodes, memory_dim)`` table.

        Args:
            num_nodes: Number of rows (one per node).
            memory_dim: Width of each node's memory vector.
            device: Torch device the table is allocated on.
        """
        self.memory = torch.zeros((num_nodes, memory_dim), device=device)

    def get(self, node_ids):
        """Return the memory rows selected by ``node_ids``, detached from autograd."""
        return self.memory[node_ids].detach()

    def update(self, node_ids, values):
        """Overwrite the memory rows of ``node_ids`` with ``values``.

        ``values[i]`` is written to row ``node_ids[i]``. If a node id occurs
        more than once, the value at its last occurrence is kept, exactly as
        a sequential row-by-row write would leave it. ``values`` is detached
        and cast to the memory's dtype and device before writing.

        Args:
            node_ids: 1-D integer ids (tensor, array or sequence).
            values: Tensor (or sequence of tensors) with one entry per id;
                each entry is a full row or a scalar broadcast across the row.
        """
        dev = self.memory.device
        # One transfer for all ids instead of a `.item()` per element, which
        # forced a device-to-host sync per row on GPU.
        ids = torch.as_tensor(node_ids, device=dev).long().reshape(-1)
        if ids.numel() == 0:
            return
        if not torch.is_tensor(values):
            values = torch.stack(list(values))
        # index_put requires matching dtype/device (per-row copy_ used to cast).
        vals = values.detach().to(device=dev, dtype=self.memory.dtype)

        # Keep only the last position of each id: index_put with duplicate
        # indices does not guarantee which write wins, but the previous
        # sequential loop was deterministically last-write-wins.
        sorted_ids, order = torch.sort(ids, stable=True)
        is_last = torch.ones_like(sorted_ids, dtype=torch.bool)
        is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]
        keep = order[is_last]

        rows = vals[keep]
        if rows.dim() == 1:
            # One scalar per node: broadcast it across the whole row, as the
            # per-row assignment `memory[i] = scalar` did.
            rows = rows.unsqueeze(-1)
        self.memory[ids[keep]] = rows