| import torch |
| import einops |
| import pickle |
| from pathlib import Path |
| from typing import Optional |
| from datasets.core import TrajectoryDataset |
|
|
|
|
| class PushTDataset(TrajectoryDataset): |
| def __init__( |
| self, |
| data_directory, |
| subset_fraction: Optional[float] = None, |
| relative=False, |
| ): |
| self.data_directory = Path(data_directory) |
| self.relative = relative |
| self.states = torch.load(self.data_directory / "states.pth") |
| if relative: |
| self.actions = torch.load(self.data_directory / "rel_actions.pth") |
| else: |
| self.actions = torch.load(self.data_directory / "abs_actions.pth") |
| with open(self.data_directory / "seq_lengths.pkl", "rb") as f: |
| self.seq_lengths = pickle.load(f) |
|
|
| self.subset_fraction = subset_fraction |
| if self.subset_fraction: |
| assert self.subset_fraction > 0 and self.subset_fraction <= 1 |
| n = int(len(self.states) * self.subset_fraction) |
| else: |
| n = len(self.states) |
| self.states = self.states[:n] |
| self.actions = self.actions[:n] |
| self.seq_lengths = self.seq_lengths[:n] |
|
|
| for i in range(n): |
| T = self.seq_lengths[i] |
| self.actions[i, T:] = 0 |
|
|
| def get_seq_length(self, idx): |
| return self.seq_lengths[idx] |
|
|
| def get_all_actions(self): |
| result = [] |
| for i in range(len(self.seq_lengths)): |
| T = self.seq_lengths[i] |
| result.append(self.actions[i, :T, :]) |
| return torch.cat(result, dim=0) |
|
|
| def get_frames(self, idx, frames): |
| vid_dir = self.data_directory / "obses" |
| obs = torch.load(str(vid_dir / f"episode_{idx:03d}.pth")) |
| obs = obs[frames] |
| obs = einops.rearrange(obs, "T H W C -> T 1 C H W") / 255.0 |
| act = self.actions[idx, frames] |
| mask = torch.ones(len(act)).bool() |
| return obs, act, mask |
|
|
| def __getitem__(self, idx): |
| return self.get_frames(idx, range(self.get_seq_length(idx))) |
|
|
| def __len__(self): |
| return len(self.seq_lengths) |
|
|