Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
| from .dataset import augment_data | |
class FusionTrajectoryDataset(Dataset):
    """Map-style dataset over ``(obs, neighbors, fusion_obs, future)`` samples.

    ``obs``, ``fusion_obs`` and ``future`` are stacked across samples into
    float32 tensors at construction time, so each of them must have the same
    shape in every sample. ``neighbors`` is kept per sample as given (the
    number of neighbors may vary) and is converted to float32 tensors on
    access.

    NOTE(review): the per-sample array layouts (e.g. time steps x
    coordinates) are not visible from this class -- confirm against the code
    that builds ``samples``.
    """

    def __init__(self, samples, augment=False):
        """Collect the samples and stack the fixed-shape parts into tensors.

        Args:
            samples: Iterable of 4-tuples ``(obs, neighbors, fusion_obs,
                future)``. ``obs``, ``fusion_obs`` and ``future`` may be
                nested lists, NumPy arrays or tensors; ``neighbors`` is an
                iterable of per-neighbor arrays.
            augment: When true, ``__getitem__`` passes ``obs``, ``neighbors``
                and ``future`` through ``augment_data``.
        """
        super().__init__()
        self.augment = augment

        obs_items = []
        fusion_items = []
        future_items = []
        self.neighbors = []
        # Single pass, so ``samples`` may also be a one-shot iterator.
        for obs, neighbors, fusion_obs, future in samples:
            obs_items.append(obs)
            self.neighbors.append(neighbors)
            fusion_items.append(fusion_obs)
            future_items.append(future)

        self.obs = self._stack_float32(obs_items)
        self.fusion = self._stack_float32(fusion_items)
        self.future = self._stack_float32(future_items)

    @staticmethod
    def _stack_float32(items):
        """Stack per-sample arrays into one float32 tensor (sample axis first).

        ``torch.tensor(list_of_arrays)`` fails when the entries are already
        multi-element tensors and is slow for lists of NumPy arrays;
        converting each entry and stacking handles lists, arrays and tensors
        alike. ``torch.stack`` always allocates new storage, so the result
        never aliases the caller's data.
        """
        if not items:
            # torch.stack rejects an empty list; keep the empty-dataset case
            # working (length 0) as a plain empty float32 tensor.
            return torch.tensor([], dtype=torch.float32)
        return torch.stack(
            [torch.as_tensor(item, dtype=torch.float32) for item in items]
        )

    def __len__(self):
        """Return the number of samples."""
        return len(self.obs)

    def __getitem__(self, idx):
        """Return ``(obs, neighbors, fusion_obs, future)`` for sample ``idx``.

        The tensors are clones and the neighbor tensors are freshly built, so
        callers (and augmentation) cannot modify the stored data in place.
        """
        obs = self.obs[idx].clone()
        fusion_obs = self.fusion[idx].clone()
        future = self.future[idx].clone()
        # torch.tensor always copies, giving each call its own neighbor tensors.
        neighbors = [
            torch.tensor(n, dtype=torch.float32) for n in self.neighbors[idx]
        ]

        if self.augment:
            # NOTE(review): fusion_obs is not passed to augment_data and is
            # returned unaugmented -- confirm this is intended.
            obs, neighbors, future = augment_data(obs, neighbors, future)

        return obs, neighbors, fusion_obs, future