| import abc |
| import utils |
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset |
| from typing import Optional, Callable |
|
|
|
|
class TrajectoryDataset(Dataset, abc.ABC):
    """
    Abstract base for datasets whose items are whole trajectories.

    Indexing with ``TrajectoryDataset[i]`` yields a triple
    ``(observations, actions, mask)``:
        observations: Tensor[T, ...] holding T frames of observations
        actions: Tensor[T, ...] holding T frames of actions
        mask: Tensor[T] where 0 marks an invalid frame and 1 a valid one
    """

    @abc.abstractmethod
    def get_seq_length(self, idx):
        """
        Return how many frames the trajectory at index ``idx`` contains.

        Concrete subclasses must override this method.
        """
        raise NotImplementedError
|
|
|
|
class TrajectorySlicerDataset(TrajectoryDataset):
    """
    Window-level view over the trajectories of another TrajectoryDataset.

    At construction time every trajectory of the wrapped dataset is cut into
    ``(trajectory_index, start, end)`` slices (half-open ``[start, end)``),
    and ``self[idx]`` returns the data for the idx-th slice. Trajectories
    shorter than the required minimum length are skipped with a printed notice.
    """

    def __init__(
        self,
        dataset: TrajectoryDataset,
        window: int,
        action_window: int,
        vqbet_get_future_action_chunk: bool = True,
        future_conditional: bool = False,
        min_future_sep: int = 0,
        future_seq_len: Optional[int] = None,
        only_sample_tail: bool = False,
        transform: Optional[Callable] = None,
        use_libero_goal: bool = False,
    ):
        """
        Build the slice index over ``dataset``.

        Args:
            dataset: wrapped trajectory dataset; must provide ``__len__``,
                ``__getitem__`` and ``get_seq_length``.
            window: number of observation frames per slice.
            action_window: extra action horizon; the action slice of a sample
                spans ``[start, end - 1 + action_window)``.
            vqbet_get_future_action_chunk: if True a trajectory needs at least
                ``window + action_window`` frames to be used, otherwise
                ``max(window, action_window)``.
            future_conditional: if True, ``future_seq_len`` must be given and
                ``get_seq_length`` reports ``future_seq_len + window``.
            min_future_sep: stored on the instance; not otherwise used by the
                code in this class.
            future_seq_len: only read by ``get_seq_length``.
            only_sample_tail: stored on the instance; not otherwise used by
                the code in this class.
            transform: optional callable applied to the list of sliced values
                before they are returned.
            use_libero_goal: if True, element ``[2]`` of each trajectory is
                sliced like the observations and returned as a third value.
        """
        if future_conditional:
            assert future_seq_len is not None, "must specify a future_seq_len"
        self.dataset = dataset
        self.window = window
        self.action_window = action_window
        self.vqbet_get_future_action_chunk = vqbet_get_future_action_chunk
        self.future_conditional = future_conditional
        self.min_future_sep = min_future_sep
        self.future_seq_len = future_seq_len
        self.only_sample_tail = only_sample_tail
        self.transform = transform
        self.slices = []
        self.use_libero_goal = use_libero_goal

        if vqbet_get_future_action_chunk:
            min_window_required = window + action_window
        else:
            min_window_required = max(window, action_window)

        min_seq_length = np.inf
        for i in range(len(self.dataset)):
            T = self.dataset.get_seq_length(i)
            min_seq_length = min(T, min_seq_length)
            if T - min_window_required < 0:
                print(
                    f"Ignored short sequence #{i}: len={T}, window={min_window_required}"
                )
                continue
            # Leading partial windows [0, 1), [0, 2), ..., [0, window - 1);
            # __getitem__ pads these back up to `window` frames.
            self.slices.extend((i, 0, end + 1) for end in range(window - 1))
            # Full-length windows.
            self.slices.extend(
                (i, start, start + window)
                for start in range(T - min_window_required)
            )

        if min_seq_length < min_window_required:
            print(
                f"Ignored short sequences. To include all, set window <= {min_seq_length}."
            )

    def get_seq_length(self, idx: int) -> int:
        """
        Return the nominal length of a slice; it does not depend on ``idx``.

        This is ``future_seq_len + window`` when ``future_conditional`` is
        set, and ``window`` otherwise.
        """
        if self.future_conditional:
            return self.future_seq_len + self.window
        return self.window

    def __len__(self):
        """Return the number of slices collected at construction time."""
        return len(self.slices)

    def __getitem__(self, idx):
        """
        Return the idx-th slice as a tuple ``(obs, act[, goals | ones])``.

        Observations cover ``[start, end)`` and actions cover
        ``[start, end - 1 + action_window)``. Slices shorter than ``window``
        are extended with ``utils.inference.repeat_start_to_length`` to
        ``window`` frames (observations, goals) and
        ``window + action_window - 1`` frames (actions). If the result has
        only two entries after the optional transform, a constant
        ``torch.ones([1, 1, 1])`` is appended as the third element.
        """
        i, start, end = self.slices[idx]

        # Fetch the trajectory exactly once: indexing the wrapped dataset may
        # be expensive, and repeated lookups could disagree with each other
        # if the wrapped dataset is not deterministic.
        trajectory = self.dataset[i]
        needs_padding = end - start < self.window

        obs = trajectory[0][start:end]
        act = trajectory[1][start : end - 1 + self.action_window]
        if needs_padding:
            obs = utils.inference.repeat_start_to_length(obs, self.window, dim=0)
            act = utils.inference.repeat_start_to_length(
                act,
                self.window + self.action_window - 1,
                dim=0,
            )
        values = [obs, act]

        if self.use_libero_goal:
            goals = trajectory[2][start:end]
            if needs_padding:
                goals = utils.inference.repeat_start_to_length(
                    goals, self.window, dim=0
                )
            values.append(goals)

        if self.transform is not None:
            # Normalise to a list so the append below also works when the
            # transform hands back a tuple (or another iterable).
            values = list(self.transform(values))
        if len(values) == 2:
            values.append(torch.ones([1, 1, 1]))
        return tuple(values)
|
|