| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| import os |
| from PIL import Image |
| from typing import Tuple |
| import yaml |
| import pickle |
| import tqdm |
| from torch.utils.data import Dataset |
| from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords |
| import torchaudio |
|
|
class BaseDataset(Dataset):
    """Common trajectory indexing and action computation for the datasets in this file.

    The constructor reads the list of trajectory names, loads
    ``config/data_config.yaml`` and builds (or loads) the sample index.
    ``__getitem__`` is left to subclasses.
    """

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: str = None,
        goals_per_obs: int = 1,
    ):
        """Store the configuration and build the sample index.

        Args:
            data_folder: Root folder holding one sub-folder per trajectory.
            data_split_folder: Folder holding the trajectory-name file; the
                freshly built index pickle is also written here.
            dataset_name: Key of this dataset inside ``config/data_config.yaml``.
            image_size: Stored on the instance; not used by this class.
            min_dist_cat: Smallest goal offset (in time steps) to sample.
            max_dist_cat: Largest goal offset (in time steps) to sample.
            len_traj_pred: Number of future steps in an action sequence.
            traj_stride: Step between consecutive indexed observation times.
            context_size: Number of observation frames per sample.
            transform: Callable applied to each image by the subclasses.
            traj_names: File name (inside ``data_split_folder``) listing one
                trajectory name per line.
            normalize: If true, divide metric quantities by the dataset's
                ``metric_waypoint_spacing``.
            predefined_index: Optional path to a pickled index. It is opened
                as a file, so it must be a path rather than a list.
            goals_per_obs: Number of goals sampled per observation.
        """
        self.data_folder = data_folder
        self.data_split_folder = data_split_folder
        self.dataset_name = dataset_name
        self.goals_per_obs = goals_per_obs

        self.traj_names = self._read_traj_names(
            os.path.join(data_split_folder, traj_names)
        )

        self.image_size = image_size
        self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
        self.min_dist_cat = self.distance_categories[0]
        self.max_dist_cat = self.distance_categories[-1]
        self.len_traj_pred = len_traj_pred
        self.traj_stride = traj_stride

        self.context_size = context_size
        self.normalize = normalize

        # The path is relative to the current working directory.
        with open("config/data_config.yaml", "r") as f:
            all_data_config = yaml.safe_load(f)

        self.data_config = all_data_config[self.dataset_name]
        self.transform = transform
        self._load_index(predefined_index)

        # Each statistic gets a leading axis of size 1 (np.expand_dims, axis=0).
        self.ACTION_STATS = {
            key: np.expand_dims(value, axis=0)
            for key, value in all_data_config['action_stats'].items()
        }
        self.DISTANCE_DIFF_STATS = {
            key: np.expand_dims(value, axis=0)
            for key, value in all_data_config['distance_diff_stats'].items()
        }

    @staticmethod
    def _read_traj_names(traj_names_file: str) -> list:
        """Return the non-blank, whitespace-stripped lines of ``traj_names_file``.

        Dropping every blank line (not only the first one) prevents an empty
        trajectory name from reaching ``_get_trajectory``.
        """
        with open(traj_names_file, "r") as f:
            return [line.strip() for line in f.read().split("\n") if line.strip()]

    def _load_index(self, predefined_index) -> None:
        """Populate ``self.index_to_data`` and ``self.goals_index``.

        With ``predefined_index`` set, the index is unpickled from that path.
        Otherwise it is rebuilt from the trajectories and written to
        ``data_split_folder``. Samples are tuples
        ``(traj_name, curr_time, min_goal_distance, max_goal_distance)``;
        goals are tuples ``(traj_name, goal_time)``.
        """
        if predefined_index:
            print(f"****** Using a predefined evaluation index... {predefined_index}******")
            # SECURITY: pickle.load executes arbitrary code; only load trusted files.
            with open(predefined_index, "rb") as f:
                loaded = pickle.load(f)
            is_samples_goals_pair = (
                isinstance(loaded, tuple)
                and len(loaded) == 2
                and all(isinstance(part, list) for part in loaded)
            )
            if is_samples_goals_pair:
                # Same layout as the file written by the branch below.
                self.index_to_data, self.goals_index = loaded
            else:
                # Plain sample index without goal information.
                self.index_to_data = loaded
                self.goals_index = None
            return

        print("****** Evaluating from NON PREDEFINED index... ******")
        index_to_data_path = os.path.join(
            self.data_split_folder,
            f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
        )

        # The index is always rebuilt; an existing file is overwritten, not reused.
        self.index_to_data, self.goals_index = self._build_index()
        with open(index_to_data_path, "wb") as f:
            pickle.dump((self.index_to_data, self.goals_index), f)

    def _build_index(self, use_tqdm: bool = False):
        """Build the sample and goal indices from every listed trajectory.

        Returns:
            ``(samples_index, goals_index)``. ``samples_index`` holds tuples
            ``(traj_name, curr_time, min_goal_distance, max_goal_distance)``;
            ``goals_index`` holds one ``(traj_name, goal_time)`` per time step.
        """
        samples_index = []
        goals_index = []

        traj_iter = self.traj_names
        if use_tqdm:
            traj_iter = tqdm.tqdm(self.traj_names, dynamic_ncols=True)

        for traj_name in traj_iter:
            traj_data = self._get_trajectory(traj_name)
            traj_len = len(traj_data["position"])
            goals_index.extend((traj_name, goal_time) for goal_time in range(traj_len))

            # Leave room for a full context before and a full prediction after.
            begin_time = self.context_size - 1
            end_time = traj_len - self.len_traj_pred
            for curr_time in range(begin_time, end_time, self.traj_stride):
                # Clamp the goal offsets so that curr_time + offset stays in the trajectory.
                # NOTE(review): max_goal_distance can fall below min_goal_distance
                # near the end of a trajectory when min_dist_cat is large - confirm
                # that callers sampling between the two can cope with that.
                max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
                min_goal_distance = max(self.min_dist_cat, -curr_time)
                samples_index.append(
                    (traj_name, curr_time, min_goal_distance, max_goal_distance)
                )

        return samples_index, goals_index

    def _get_trajectory(self, trajectory_name):
        """Unpickle ``<data_folder>/<trajectory_name>/traj_data.pkl`` and cast every entry to float.

        The file is re-read on every call; nothing is cached.
        """
        # SECURITY: pickle.load executes arbitrary code; only load trusted files.
        with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
            traj_data = pickle.load(f)
        for key, value in traj_data.items():
            traj_data[key] = value.astype('float')
        return traj_data

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.index_to_data)

    def _compute_actions(self, traj_data, curr_time, goal_time):
        """Express future waypoints and the goal relative to the pose at ``curr_time``.

        Args:
            traj_data: Mapping with ``"position"``, ``"yaw"`` and
                ``"distance_to_target"`` arrays indexed by time step.
            curr_time: Time step of the current observation.
            goal_time: Index (or index array) of the goal time step(s).
                NOTE(review): ``goal_pos[:, :2]`` below needs a 2-D result, so
                this is presumably always an index array - confirm.

        Returns:
            ``(actions, goal_pos, diffs_seq, goal_diff)``: the local waypoints
            with relative yaw for the ``len_traj_pred`` steps after
            ``curr_time``; the local goal position with relative yaw; the change
            in ``distance_to_target`` along the window; and the same change for
            the goal.

        Raises:
            ValueError: If the yaw window does not contain exactly
                ``len_traj_pred + 1`` entries.
        """
        start_index = curr_time
        end_index = curr_time + self.len_traj_pred + 1
        yaw = traj_data["yaw"][start_index:end_index]
        positions = traj_data["position"][start_index:end_index]
        goal_pos = traj_data["position"][goal_time]
        goal_yaw = traj_data["yaw"][goal_time]
        dist_window = traj_data["distance_to_target"][start_index:end_index]
        goal_dist = traj_data["distance_to_target"][goal_time]

        if yaw.ndim == 2:
            yaw = yaw.squeeze(1)

        expected_shape = (self.len_traj_pred + 1,)
        if yaw.shape != expected_shape:
            raise ValueError(
                f"Expected a yaw window of shape {expected_shape} starting at "
                f"curr_time={curr_time}, got {yaw.shape}"
            )

        waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
        waypoints_yaw = angle_difference(yaw[0], yaw)
        actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
        # Drop the first row: it is the current pose relative to itself.
        actions = actions[1:]

        goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
        goal_yaw = angle_difference(yaw[0], goal_yaw)

        # Distance-to-target change since curr_time (current value minus later value).
        diffs_seq = (dist_window[0] - dist_window).reshape(-1, 1)[1:]
        goal_diff = (dist_window[0] - goal_dist).reshape(-1, 1)

        if self.normalize:
            spacing = self.data_config["metric_waypoint_spacing"]
            actions[:, :2] /= spacing
            goal_pos[:, :2] /= spacing
            diffs_seq /= spacing
            goal_diff /= spacing

        goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)
        return actions, goal_pos, diffs_seq, goal_diff
|
|
class TrainingDataset(BaseDataset):
    """Dataset yielding context and goal observations with the relative goal pose.

    Each item stacks ``context_size`` context frames followed by
    ``goals_per_obs`` randomly sampled goal frames from the same trajectory,
    for both images and audio.
    """

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str = 'traj_names.txt',
        normalize: bool = True,
        predefined_index: str = None,
        goals_per_obs: int = 1,
        sample_rate: int = 16000,
        input_sr: int = 48000,
        evaluate: bool = False
    ):
        """Forward the shared arguments to ``BaseDataset`` and set up audio resampling.

        Args:
            sample_rate: Passed to the resampler as ``orig_freq``.
            input_sr: Passed to the resampler as ``new_freq``.
            evaluate: If true, each item also carries the audio as loaded,
                before resampling.

        All other arguments are passed to ``BaseDataset.__init__`` unchanged.
        ``predefined_index`` is forwarded as-is; NOTE(review): presumably a
        path to a pickled index rather than a list - confirm against the base class.
        """
        super().__init__(
            data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
            len_traj_pred, traj_stride, context_size, transform, traj_names, normalize,
            predefined_index, goals_per_obs,
        )
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64
        )
        self.evaluate = evaluate

    def _load_images(self, frames) -> torch.Tensor:
        """Open, transform and stack the image of every ``(traj_name, time)`` in ``frames``."""
        images = []
        for traj_name, t in frames:
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as the transform has consumed the image.
            with Image.open(get_data_path(self.data_folder, traj_name, t)) as img:
                images.append(self.transform(img))
        return torch.stack(images)

    def _load_audio(self, frames) -> torch.Tensor:
        """Load and stack the audio of every ``(traj_name, time)`` in ``frames``."""
        return torch.stack([
            # torchaudio.load returns (waveform, sample_rate); keep the waveform only.
            torchaudio.load(get_data_path(self.data_folder, traj_name, t, data_type="audio"))[0]
            for traj_name, t in frames
        ])

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return the ``i``-th training sample.

        Returns:
            ``(obs_image, obs_audio, goal_pos, goal_diff, rel_time)`` as float32
            tensors. With ``evaluate`` set, the audio as loaded (before
            resampling) is appended as a sixth element.

        Raises:
            Exception: Any failure is printed and re-raised as a plain
                ``Exception`` with the original error attached as its cause.
        """
        try:
            f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]

            # Sample goals_per_obs goal offsets in [min_goal_dist, max_goal_dist].
            goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
            goal_time = (curr_time + goal_offset).astype('int')
            # NOTE(review): 128 is a fixed scaling constant for the offset - confirm its meaning.
            rel_time = (goal_offset).astype('float') / (128.)

            # Context frames first, goal frames appended after them.
            context_times = range(curr_time - self.context_size + 1, curr_time + 1)
            frames = [(f_curr, t) for t in context_times] + [(f_curr, t) for t in goal_time]

            obs_image = self._load_images(frames)
            orig_obs_audio = self._load_audio(frames)
            obs_audio = self.resampler(orig_obs_audio)

            curr_traj_data = self._get_trajectory(f_curr)

            _, goal_pos, _, goal_diff = self._compute_actions(curr_traj_data, curr_time, goal_time)
            goal_pos[:, :2] = normalize_data(goal_pos[:, :2], self.ACTION_STATS)
            goal_diff = normalize_data(goal_diff, self.DISTANCE_DIFF_STATS)

            sample = (
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(obs_audio, dtype=torch.float32),
                torch.as_tensor(goal_pos, dtype=torch.float32),
                torch.as_tensor(goal_diff, dtype=torch.float32),
                torch.as_tensor(rel_time, dtype=torch.float32),
            )
            if self.evaluate:
                sample += (torch.as_tensor(orig_obs_audio, dtype=torch.float32),)
            return sample
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Keep the plain Exception type callers already see, but record the cause.
            raise Exception(e) from e
|
|
class EvalDataset(BaseDataset):
    """Dataset yielding context frames, the following prediction frames and their action deltas."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: str = None,
        goals_per_obs: int = 1,
        sample_rate: int = 16000,
        input_sr: int = 48000
    ):
        """Forward the shared arguments to ``BaseDataset`` and set up audio resampling.

        Args:
            sample_rate: Passed to the resampler as ``orig_freq``.
            input_sr: Passed to the resampler as ``new_freq``.

        All other arguments are passed to ``BaseDataset.__init__`` unchanged.
        ``predefined_index`` is forwarded as-is; NOTE(review): presumably a
        path to a pickled index rather than a list - confirm against the base class.
        """
        super().__init__(
            data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
            len_traj_pred, traj_stride, context_size, transform, traj_names, normalize,
            predefined_index, goals_per_obs,
        )
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64
        )

    def _load_images(self, frames) -> torch.Tensor:
        """Open, transform and stack the image of every ``(traj_name, time)`` in ``frames``."""
        images = []
        for traj_name, t in frames:
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as the transform has consumed the image.
            with Image.open(get_data_path(self.data_folder, traj_name, t)) as img:
                images.append(self.transform(img))
        return torch.stack(images)

    def _load_audio(self, frames) -> torch.Tensor:
        """Load and stack the audio of every ``(traj_name, time)`` in ``frames``."""
        return torch.stack([
            # torchaudio.load returns (waveform, sample_rate); keep the waveform only.
            torchaudio.load(get_data_path(self.data_folder, traj_name, t, data_type="audio"))[0]
            for traj_name, t in frames
        ])

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return the ``i``-th evaluation sample.

        Returns:
            A 9-tuple of float32 tensors: the sample index, context images,
            prediction images, resampled context audio, resampled prediction
            audio, distance-difference deltas, action deltas, and the context
            and prediction audio as loaded (before resampling).

        Raises:
            Exception: Any failure is printed and re-raised as a plain
                ``Exception`` with the original error attached as its cause.
        """
        try:
            # The goal-distance bounds stored in the index are not used here.
            f_curr, curr_time, _, _ = self.index_to_data[i]
            context_times = range(curr_time - self.context_size + 1, curr_time + 1)
            pred_times = range(curr_time + 1, curr_time + self.len_traj_pred + 1)

            context = [(f_curr, t) for t in context_times]
            pred = [(f_curr, t) for t in pred_times]

            obs_image = self._load_images(context)
            pred_image = self._load_images(pred)

            orig_obs_audio = self._load_audio(context)
            orig_pred_audio = self._load_audio(pred)

            obs_audio = self.resampler(orig_obs_audio)
            pred_audio = self.resampler(orig_pred_audio)

            curr_traj_data = self._get_trajectory(f_curr)

            # The goal outputs are discarded, so the next step serves as a placeholder goal.
            actions, _, diffs_seq, _ = self._compute_actions(
                curr_traj_data, curr_time, np.array([curr_time + 1])
            )
            actions[:, :2] = normalize_data(actions[:, :2], self.ACTION_STATS)
            diffs_seq = normalize_data(diffs_seq, self.DISTANCE_DIFF_STATS)

            # NOTE(review): get_delta_np presumably converts cumulative values
            # into per-step deltas - confirm against misc.
            delta = get_delta_np(actions)
            diffs_seq = get_delta_np(diffs_seq)

            return (
                # NOTE(review): float32 cannot represent every integer above 2**24;
                # confirm the index stays below that.
                torch.tensor([i], dtype=torch.float32),
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(pred_image, dtype=torch.float32),
                torch.as_tensor(obs_audio, dtype=torch.float32),
                torch.as_tensor(pred_audio, dtype=torch.float32),
                torch.as_tensor(diffs_seq, dtype=torch.float32),
                torch.as_tensor(delta, dtype=torch.float32),
                torch.as_tensor(orig_obs_audio, dtype=torch.float32),
                torch.as_tensor(orig_pred_audio, dtype=torch.float32),
            )
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Keep the plain Exception type callers already see, but record the cause.
            raise Exception(e) from e