| import abc |
| import utils |
| import torch |
| import numpy as np |
| from torch import default_generator, randperm |
| from torch.utils.data import Dataset, Subset |
| from typing import Callable, Optional, Sequence, List, Any |
| from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
| |
| def _accumulate(iterable, fn=lambda x, y: x + y): |
| "Return running totals" |
| |
| |
| it = iter(iterable) |
| try: |
| total = next(it) |
| except StopIteration: |
| return |
| yield total |
| for element in it: |
| total = fn(total, element) |
| yield total |
|
|
|
|
| class TrajectoryDataset(Dataset, abc.ABC): |
| """ |
| A dataset containing trajectories. |
| TrajectoryDataset[i] returns: (observations, actions, mask) |
| observations: Tensor[T, ...], T frames of observations |
| actions: Tensor[T, ...], T frames of actions |
| mask: Tensor[T]: False: invalid; True: valid |
| """ |
|
|
| @abc.abstractmethod |
| def get_seq_length(self, idx): |
| """ |
| Returns the length of the idx-th trajectory. |
| """ |
| raise NotImplementedError |
|
|
| @abc.abstractmethod |
| def get_frames(self, idx, frames): |
| """ |
| Returns the frames from the idx-th trajectory at the specified frames. |
| Used to speed up slicing. |
| """ |
| raise NotImplementedError |
|
|
|
|
| class TrajectorySubset(TrajectoryDataset, Subset): |
| """ |
| Subset of a trajectory dataset at specified indices. |
| |
| Args: |
| dataset (TrajectoryDataset): The whole Dataset |
| indices (sequence): Indices in the whole set selected for subset |
| """ |
|
|
| def __init__(self, dataset: TrajectoryDataset, indices: Sequence[int]): |
| Subset.__init__(self, dataset, indices) |
|
|
| def get_seq_length(self, idx): |
| return self.dataset.get_seq_length(self.indices[idx]) |
|
|
| def get_all_actions(self): |
| return self.dataset.get_all_actions() |
|
|
| def get_frames(self, idx, frames): |
| return self.dataset.get_frames(self.indices[idx], frames) |
|
|
|
|
| class TrajectorySlicerDataset: |
| def __init__( |
| self, |
| dataset: TrajectoryDataset, |
| window: int, |
| future_conditional: bool = False, |
| min_future_sep: int = 0, |
| future_seq_len: Optional[int] = None, |
| only_sample_tail: bool = False, |
| transform: Optional[Callable] = None, |
| num_extra_predicted_actions: Optional[int] = None, |
| frame_step: int = 1, |
| repeat_first_frame: bool = False, |
| ): |
| """ |
| Slice a trajectory dataset into unique (but overlapping) sequences of length `window`. |
| |
| dataset: a trajectory dataset that satisfies: |
| dataset.get_seq_length(i) is implemented to return the length of sequence i |
| dataset[i] = (observations, actions, mask) |
| observations: Tensor[T, ...] |
| actions: Tensor[T, ...] |
| mask: Tensor[T] |
| False: invalid |
| True: valid |
| window: int |
| number of timesteps to include in each slice |
| future_conditional: bool = False |
| if True, observations will be augmented with future observations sampled from the same trajectory |
| min_future_sep: int = 0 |
| minimum number of timesteps between the end of the current sequence and the start of the future sequence |
| for the future conditional |
| future_seq_len: Optional[int] = None |
| the length of the future conditional sequence; |
| required if future_conditional is True |
| only_sample_tail: bool = False |
| if True, only sample future sequences from the tail of the trajectory |
| transform: function (observations, actions, mask[, goal]) -> (observations, actions, mask[, goal]) |
| """ |
| if future_conditional: |
| assert future_seq_len is not None, "must specify a future_seq_len" |
| self.dataset = dataset |
| self.window = window |
| self.future_conditional = future_conditional |
| self.min_future_sep = min_future_sep |
| self.future_seq_len = future_seq_len |
| self.only_sample_tail = only_sample_tail |
| self.transform = transform |
| self.num_extra_predicted_actions = num_extra_predicted_actions or 0 |
| self.slices = [] |
| self.frame_step = frame_step |
| min_seq_length = np.inf |
| if num_extra_predicted_actions: |
| window = window + num_extra_predicted_actions |
| for i in range(len(self.dataset)): |
| T = self.dataset.get_seq_length(i) |
| min_seq_length = min(T, min_seq_length) |
| if T - window < 0: |
| print(f"Ignored short sequence #{i}: len={T}, window={window}") |
| else: |
| if repeat_first_frame: |
| self.slices += [(i, 0, end + 1) for end in range(window - 1)] |
| window_len_with_step = (window - 1) * frame_step + 1 |
| last_start = T - window_len_with_step |
| self.slices += [ |
| (i, start, start + window_len_with_step) |
| for start in range(last_start) |
| ] |
|
|
| if min_seq_length < window: |
| print( |
| f"Ignored short sequences. To include all, set window <= {min_seq_length}." |
| ) |
|
|
| def get_seq_length(self, idx: int) -> int: |
| if self.future_conditional: |
| return self.future_seq_len + self.window |
| else: |
| return self.window |
|
|
| def get_all_actions(self) -> torch.Tensor: |
| return self.dataset.get_all_actions() |
|
|
| def __len__(self): |
| return len(self.slices) |
|
|
| def __getitem__(self, idx): |
| i, start, end = self.slices[idx] |
| T = self.dataset.get_seq_length(i) |
|
|
| if ( |
| self.num_extra_predicted_actions is not None |
| and self.num_extra_predicted_actions != 0 |
| ): |
| assert self.frame_step == 1, "NOT TESTED" |
| if self.future_conditional: |
| raise NotImplementedError( |
| "num_extra_predicted_actions with future_conditional not implemented" |
| ) |
| assert end <= T, f"end={end} > T={T}" |
| observations, actions, mask = self.dataset.get_frames(i, range(start, end)) |
| observations = observations[: self.window] |
|
|
| values = [observations, actions, mask.bool()] |
| else: |
| if self.future_conditional: |
| assert self.frame_step == 1, "NOT TESTED" |
| valid_start_range = ( |
| end + self.min_future_sep, |
| self.dataset.get_seq_length(i) - self.future_seq_len, |
| ) |
| if valid_start_range[0] < valid_start_range[1]: |
| if self.only_sample_tail: |
| future_obs_range = range(T - self.future_seq_len, T) |
| else: |
| future_start = np.random.randint(*valid_start_range) |
| future_end = future_start + self.future_seq_len |
| future_obs_range = range(future_start, future_end) |
| obs, actions, mask = self.dataset.get_frames( |
| i, list(range(start, end)) + list(future_obs_range) |
| ) |
| future_obs = obs[end - start :] |
| obs = obs[: end - start] |
| actions = actions[: end - start] |
| mask = mask[: end - start] |
| else: |
| |
| obs, actions, mask = self.dataset.get_frames(i, range(start, end)) |
| obs_dims = obs.shape[1:] |
| future_obs = torch.zeros((self.future_seq_len, *obs_dims)) |
|
|
| |
| values = [obs, actions, mask.bool(), future_obs] |
| else: |
| observations, actions, mask = self.dataset.get_frames( |
| i, range(start, end, self.frame_step) |
| ) |
| values = [observations, actions, mask.bool()] |
|
|
| if end - start < self.window + self.num_extra_predicted_actions: |
| |
| values = [ |
| utils.inference.repeat_start_to_length( |
| x, self.window + self.num_extra_predicted_actions, dim=0 |
| ) |
| for x in values |
| ] |
| values[0] = values[0][: self.window] |
|
|
| |
| if self.transform is not None: |
| values = self.transform(values) |
| return tuple(values) |
|
|
|
|
| class TrajectoryEmbeddingDataset(TrajectoryDataset): |
| def __init__( |
| self, |
| model, |
| dataset: TrajectoryDataset, |
| device="cpu", |
| embed_goal=False, |
| ): |
| self.data = utils.inference.embed_trajectory_dataset( |
| model, |
| dataset, |
| obs_only=False, |
| device=device, |
| embed_goal=embed_goal, |
| ) |
| assert len(self.data) == len(dataset) |
|
|
| self.seq_lengths = [len(x[0]) for x in self.data] |
| self.on_device_data = [] |
| n_tensors = len(self.data[0]) |
| for i in range(n_tensors): |
| self.on_device_data.append( |
| pad_sequence([x[i] for x in self.data], batch_first=True).to(device) |
| ) |
| self.data = self.on_device_data |
|
|
| def get_seq_length(self, idx): |
| return self.seq_lengths[idx] |
|
|
| def get_all_actions(self): |
| return torch.cat([x[1] for x in self.data], dim=0) |
|
|
| def get_frames(self, idx, frames): |
| return [x[idx, frames] for x in self.data] |
|
|
| def __getitem__(self, idx): |
| return self.get_frames(idx, range(self.get_seq_length(idx))) |
|
|
| def __len__(self): |
| return len(self.seq_lengths) |
|
|
|
|
| def get_train_val_sliced( |
| traj_dataset: TrajectoryDataset, |
| train_fraction: float = 0.9, |
| random_seed: int = 42, |
| window_size: int = 10, |
| future_conditional: bool = False, |
| min_future_sep: int = 0, |
| future_seq_len: Optional[int] = None, |
| only_sample_tail: bool = False, |
| transform: Optional[Callable[[Any], Any]] = None, |
| num_extra_predicted_actions: Optional[int] = None, |
| frame_step: int = 1, |
| ): |
| train, val = split_traj_datasets( |
| traj_dataset, |
| train_fraction=train_fraction, |
| random_seed=random_seed, |
| ) |
| traj_slicer_kwargs = { |
| "window": window_size, |
| "future_conditional": future_conditional, |
| "min_future_sep": min_future_sep, |
| "future_seq_len": future_seq_len, |
| "only_sample_tail": only_sample_tail, |
| "transform": transform, |
| "num_extra_predicted_actions": num_extra_predicted_actions, |
| "frame_step": frame_step, |
| } |
|
|
| train_slices = TrajectorySlicerDataset(train, **traj_slicer_kwargs) |
| val_slices = TrajectorySlicerDataset(val, **traj_slicer_kwargs) |
| return train_slices, val_slices |
|
|
|
|
| def random_split_traj( |
| dataset: TrajectoryDataset, |
| lengths: Sequence[int], |
| generator: Optional[torch.Generator] = default_generator, |
| ) -> List[TrajectorySubset]: |
| """ |
| (Modified from torch.utils.data.dataset.random_split) |
| |
| Randomly split a trajectory dataset into non-overlapping new datasets of given lengths. |
| Optionally fix the generator for reproducible results, e.g.: |
| |
| >>> random_split_traj(range(10), [3, 7], generator=torch.Generator().manual_seed(42)) |
| |
| Args: |
| dataset (TrajectoryDataset): TrajectoryDataset to be split |
| lengths (sequence): lengths of splits to be produced |
| generator (Generator): Generator used for the random permutation. |
| """ |
| |
| if sum(lengths) != len(dataset): |
| raise ValueError( |
| "Sum of input lengths does not equal the length of the input dataset!" |
| ) |
|
|
| indices = randperm(sum(lengths), generator=generator).tolist() |
| return [ |
| TrajectorySubset(dataset, indices[offset - length : offset]) |
| for offset, length in zip(_accumulate(lengths), lengths) |
| ] |
|
|
|
|
| def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42): |
| dataset_length = len(dataset) |
| lengths = [ |
| int(train_fraction * dataset_length), |
| dataset_length - int(train_fraction * dataset_length), |
| ] |
| train_set, val_set = random_split_traj( |
| dataset, lengths, generator=torch.Generator().manual_seed(random_seed) |
| ) |
| return train_set, val_set |
|
|