| import utils |
| import torch |
| import numpy as np |
| from pathlib import Path |
| from torch.utils.data import TensorDataset |
| from datasets.core import TrajectoryDataset |
|
|
|
|
class YourTrajectoryDataset(TensorDataset, TrajectoryDataset):
    """Template for a trajectory dataset backed by files in a directory.

    The per-trajectory accessors ``get_seq_length`` and ``get_frames`` are
    stubs that raise ``NotImplementedError``; they must be implemented before
    items can be fetched with ``dataset[idx]``.
    """

    def __init__(self, data_directory):
        """Record where the dataset lives on disk.

        Args:
            data_directory: Location of the dataset; anything accepted by
                ``pathlib.Path`` (e.g. a ``str`` or an existing ``Path``).
        """
        # Keep the normalised path on the instance. It used to be bound to a
        # local name and thrown away, so no other method could find the data.
        self.data_directory = Path(data_directory)
        # NOTE(review): no tensors are loaded and ``super().__init__`` is not
        # called here. ``TensorDataset`` normally receives its tensors through
        # ``super().__init__(*tensors)`` -- add that call once the loading
        # logic for this dataset is written.

    def get_seq_length(self, idx):
        """Return the length of trajectory ``idx``.

        Raises:
            NotImplementedError: Always; this template provides no
                implementation.
        """
        raise NotImplementedError

    def get_frames(self, idx, frames):
        """Return the requested ``frames`` of trajectory ``idx``.

        ``__getitem__`` calls this with ``frames`` set to a ``range`` covering
        the whole trajectory.

        Raises:
            NotImplementedError: Always; this template provides no
                implementation.
        """
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return every frame of trajectory ``idx``.

        The trajectory length is looked up with ``get_seq_length`` and all
        frame indices ``0 .. length - 1`` are forwarded to ``get_frames``.
        """
        seq_length = self.get_seq_length(idx)
        return self.get_frames(idx, range(seq_length))
|
|