IntentDrive / backend /app /legacy /dataset_fusion.py
sajith-0701
Deploy FastAPI backend to HF Spaces (Docker SDK)
98075af
import torch
from torch.utils.data import Dataset
from .dataset import augment_data
class FusionTrajectoryDataset(Dataset):
    """Map-style dataset of trajectory samples that carry an extra "fusion" feature stream.

    Each sample is a 4-tuple ``(obs, neighbors, fusion_obs, future)``. ``obs``,
    ``fusion_obs`` and ``future`` are each stacked into one float32 tensor, so all
    samples must share the same shape for a given field. ``neighbors`` is kept as a
    per-sample list of float32 tensors, so the number (and length) of neighbor
    tracks may differ between samples.

    Args:
        samples: Iterable of ``(obs, neighbors, fusion_obs, future)`` tuples. Each
            field may be a nested list, a numpy array or a tensor; ``neighbors`` is
            iterated to get the individual neighbor tracks.
            NOTE(review): presumably obs/future are per-timestep coordinate tracks
            and fusion_obs holds extra per-sample features — confirm with callers.
        augment: If True, ``__getitem__`` passes ``obs``, ``neighbors`` and
            ``future`` through ``augment_data`` before returning them.
    """

    def __init__(self, samples, augment=False):
        self.augment = augment
        obs_rows, neighbor_rows, fusion_rows, future_rows = [], [], [], []
        for obs, neighbors, fusion_obs, future in samples:
            obs_rows.append(obs)
            # Convert each neighbor track once here rather than on every __getitem__.
            # clone() guarantees we own the storage even when as_tensor would share
            # memory with a caller-provided float32 numpy array or tensor.
            neighbor_rows.append(
                [torch.as_tensor(n, dtype=torch.float32).clone() for n in neighbors]
            )
            fusion_rows.append(fusion_obs)
            future_rows.append(future)
        self.obs = self._stack_float32(obs_rows)
        self.neighbors = neighbor_rows
        self.fusion = self._stack_float32(fusion_rows)
        self.future = self._stack_float32(future_rows)

    @staticmethod
    def _stack_float32(rows):
        """Stack equally-shaped rows (lists, numpy arrays or tensors) into one float32 tensor."""
        # torch.tensor() on a list of multi-element tensors raises, and on a list of
        # numpy arrays it takes a slow element-by-element path. Converting each row
        # first and then stacking avoids both; torch.stack allocates new storage, so
        # nothing is shared with the caller's data.
        if not rows:
            # Match torch.tensor([]) so an empty dataset still reports len() == 0.
            return torch.empty(0, dtype=torch.float32)
        return torch.stack([torch.as_tensor(r, dtype=torch.float32) for r in rows])

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        """Return ``(obs, neighbors, fusion_obs, future)`` for sample ``idx``.

        Stored tensors are cloned before being returned or augmented, so in-place
        edits by the caller or by ``augment_data`` never modify the dataset itself.
        """
        obs = self.obs[idx].clone()
        fusion_obs = self.fusion[idx].clone()
        future = self.future[idx].clone()
        neighbors = [n.clone() for n in self.neighbors[idx]]
        if self.augment:
            # NOTE(review): fusion_obs is not passed to augment_data, so whatever
            # transform is applied to obs/neighbors/future is not applied to it —
            # confirm this is intended for the fusion features.
            obs, neighbors, future = augment_data(obs, neighbors, future)
        return obs, neighbors, fusion_obs, future