import math
import random

import torch
from torch.utils.data import Dataset


def _as_float_tensor(value):
    """Return ``value`` unchanged if it is a tensor, else a new float32 tensor."""
    if isinstance(value, torch.Tensor):
        return value
    return torch.tensor(value, dtype=torch.float32)


def augment_data(obs, neighbors, future, *, noise_std=0.05):
    """Apply one random scene augmentation to a trajectory sample.

    One reflection (negating x with probability 1/2) and one rotation angle
    drawn uniformly from [0, 2*pi) are chosen per call. They are shared by
    ``obs``, every neighbor and ``future``, so the relative geometry between
    tracks is preserved. Independent Gaussian noise is added to every (x, y)
    position before the reflection and rotation.

    Column layout as used by the indexing below (TODO confirm against the
    data pipeline):

    - columns 0-1 are positions and receive noise, reflection and rotation;
    - columns 2-3 (obs and neighbors only) are reflected and rotated, with
      no noise;
    - columns 5-6 (obs and neighbors only) are rewritten as dy/|v| and
      dx/|v| from the transformed columns 2-3, or set to 0 when
      |v| <= 1e-5;
    - every other column is copied unchanged.

    Args:
        obs: observed track, presumably a (4, 7) tensor. Non-tensors are
            converted to float32.
        neighbors: iterable of neighbor tracks (tensors or nested
            sequences) with the same column layout as ``obs``.
        future: target track, presumably (12, 2). Only columns 0-1 are
            transformed.
        noise_std: standard deviation of the position noise. Random numbers
            are drawn even when this is 0, so RNG consumption does not
            depend on it.

    Returns:
        ``(new_obs, new_neighbors, new_future)``. These are new tensors; the
        inputs are never modified in place.
    """
    theta = random.uniform(0, 2 * math.pi)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    # Negating x mirrors across the y-axis. Composed with a uniformly random
    # rotation, this is equivalent to a reflection about a random axis.
    flip_x = random.choice([-1.0, 1.0])

    def rotate(x, y):
        return x * cos_t - y * sin_t, x * sin_t + y * cos_t

    def apply_transform(feat, is_obs=True):
        new_feat = feat.clone()
        out_dtype = new_feat.dtype
        rows = new_feat.size(0)

        # Draw noise with the stdlib RNG in x0, y0, x1, y1, ... order, so
        # seeding `random` gives reproducible results. Work in float64
        # before casting back to the feature dtype.
        noise = torch.tensor(
            [random.gauss(0, noise_std) for _ in range(2 * rows)],
            dtype=torch.float64,
        ).reshape(rows, 2)
        pos = new_feat[:, :2].to(torch.float64) + noise
        nx, ny = rotate(pos[:, 0] * flip_x, pos[:, 1])
        new_feat[:, 0] = nx.to(out_dtype)
        new_feat[:, 1] = ny.to(out_dtype)

        if is_obs:
            vel = new_feat[:, 2:4].to(torch.float64)
            ndx, ndy = rotate(vel[:, 0] * flip_x, vel[:, 1])
            new_feat[:, 2] = ndx.to(out_dtype)
            new_feat[:, 3] = ndy.to(out_dtype)

            # Recompute the direction columns from the transformed velocity.
            # Near-stationary rows get zeros instead of a noisy direction.
            speed = torch.hypot(ndx, ndy)
            moving = speed > 1e-5
            # Avoid dividing by ~0; those entries are masked out anyway.
            safe_speed = torch.where(moving, speed, torch.ones_like(speed))
            zero = torch.zeros_like(speed)
            new_feat[:, 5] = torch.where(moving, ndy / safe_speed, zero).to(out_dtype)
            new_feat[:, 6] = torch.where(moving, ndx / safe_speed, zero).to(out_dtype)
        return new_feat

    # Transform order (obs, future, neighbors) fixes the RNG consumption order.
    new_obs = apply_transform(_as_float_tensor(obs), is_obs=True)
    new_future = apply_transform(_as_float_tensor(future), is_obs=False)
    new_neighbors = [
        apply_transform(_as_float_tensor(n), is_obs=True) for n in neighbors
    ]
    return new_obs, new_neighbors, new_future


def _stack_float32(items):
    """Stack equally-shaped array-likes (lists, arrays or tensors) as float32.

    An empty list yields a ``(0,)`` tensor.
    """
    if not items:
        return torch.tensor(items, dtype=torch.float32)
    return torch.stack([torch.as_tensor(x, dtype=torch.float32) for x in items])


class TrajectoryDataset(Dataset):
    """Dataset of ``(obs, neighbors, future)`` trajectory samples.

    ``obs`` and ``future`` are stacked into float32 tensors up front.
    Neighbors are kept as per-sample lists, because their count may vary,
    and are converted to float32 tensors in ``__getitem__``.

    Args:
        samples: iterable of ``(obs, neighbors, future)`` triples. ``obs``
            and ``future`` may be nested sequences, arrays or tensors.
        augment: when true, each fetched sample is passed through
            ``augment_data``.
    """

    def __init__(self, samples, augment=False):
        self.augment = augment
        obs, neighbors, future = [], [], []
        for sample_obs, sample_neighbors, sample_future in samples:
            obs.append(sample_obs)
            neighbors.append(sample_neighbors)
            future.append(sample_future)

        # torch.stack of as_tensor results also accepts tensor samples,
        # which torch.tensor(list_of_tensors) rejects.
        self.obs = _stack_float32(obs)
        self.future = _stack_float32(future)
        # Neighbors remain lists of matrices; they are converted in __getitem__.
        self.neighbors = neighbors

    def __len__(self):
        return len(self.obs)

    def __getitem__(self, idx):
        obs = self.obs[idx].clone()
        future = self.future[idx].clone()
        # as_tensor + clone avoids torch.tensor's copy-construct warning on
        # tensor inputs and never returns storage shared with the dataset.
        neighbors = [
            torch.as_tensor(n, dtype=torch.float32).clone()
            for n in self.neighbors[idx]
        ]
        if self.augment:
            obs, neighbors, future = augment_data(obs, neighbors, future)
        return obs, neighbors, future