| import utils |
| import torch |
| import numpy as np |
| from pathlib import Path |
| from datasets.core import TrajectoryDataset |
|
|
|
|
class SimKitchenTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset loaded from a SimKitchen data directory.

    The directory is read for the following files:

    * ``observations_seq.npy`` and ``actions_seq.npy`` -- state and action
      arrays.
    * ``onehot_goals.pth`` -- goal data, loaded with ``torch.load``.
    * ``existence_mask.npy`` -- mask whose sum over axis 0 yields one length
      per trajectory.
    * ``obses/NNN.pth`` -- one file of observation frames per trajectory,
      named by the zero-padded trajectory index.
    """

    def __init__(self, data_directory, prefetch=True, onehot_goals=False):
        """Load the sequence arrays and, optionally, all observation files.

        Args:
            data_directory: Path (or path-like string) of the data directory.
            prefetch: When true, every ``obses/NNN.pth`` file is loaded up
                front and kept in ``self.obses``; otherwise the file is read
                on each ``get_frames`` call.
            onehot_goals: When true, ``get_frames`` also returns the goals
                for the requested frames.
        """
        root = Path(data_directory)
        self.data_directory = root

        raw_states = torch.from_numpy(np.load(root / "observations_seq.npy"))
        raw_actions = torch.from_numpy(np.load(root / "actions_seq.npy"))
        raw_goals = torch.load(root / "onehot_goals.pth")

        # NOTE(review): the helper presumably swaps the time and batch axes so
        # that the trajectory index comes first -- confirm against `utils`.
        transposed = utils.transpose_batch_timestep(raw_states, raw_actions, raw_goals)
        self.states, self.actions, self.goals = transposed

        # One integer length per trajectory, obtained by summing the mask
        # over its first axis.
        existence_mask = np.load(root / "existence_mask.npy")
        self.Ts = existence_mask.sum(axis=0).astype(int).tolist()

        self.prefetch = prefetch
        if prefetch:
            obs_dir = root / "obses"
            self.obses = [
                torch.load(obs_dir / f"{episode:03d}.pth")
                for episode in range(len(self.Ts))
            ]
        self.onehot_goals = onehot_goals

    def get_seq_length(self, idx):
        """Return the stored length of trajectory ``idx``."""
        length = self.Ts[idx]
        return length

    def get_all_actions(self):
        """Return the actions of every trajectory concatenated along dim 0.

        Each trajectory contributes only its first ``Ts[i]`` steps, so steps
        beyond the stored length are left out.
        """
        valid_chunks = [
            self.actions[episode, :length, :]
            for episode, length in enumerate(self.Ts)
        ]
        return torch.cat(valid_chunks, dim=0)

    def get_frames(self, idx, frames):
        """Return the data of trajectory ``idx`` at the given frame indices.

        Args:
            idx: Trajectory index.
            frames: Sized collection of frame indices; it is used directly to
                index the observation and action tensors.

        Returns:
            ``(obs, act, mask)``, or ``(obs, act, mask, goal)`` when the
            dataset was built with ``onehot_goals=True``. ``obs`` is the
            selected observations divided by 255.0 and ``mask`` is a vector of
            ones with one entry per requested frame.
        """
        if self.prefetch:
            episode_obs = self.obses[idx]
        else:
            # Without prefetching, the trajectory file is read on every call.
            episode_obs = torch.load(self.data_directory / "obses" / f"{idx:03d}.pth")

        # NOTE(review): dividing by 255.0 suggests 8-bit image frames -- confirm.
        obs = episode_obs[frames] / 255.0
        act = self.actions[idx, frames]
        mask = torch.ones((len(frames)))

        if not self.onehot_goals:
            return obs, act, mask
        goal = self.goals[idx, frames]
        return obs, act, mask, goal

    def __getitem__(self, idx):
        """Return ``get_frames`` for all ``Ts[idx]`` frames of trajectory ``idx``."""
        every_frame = range(self.Ts[idx])
        return self.get_frames(idx, every_frame)

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.Ts)