| |
| |
| |
| |
|
|
| import json |
| import os |
| from abc import abstractmethod |
| from pathlib import Path |
|
|
| import json5 |
| import torch |
| import yaml |
|
|
|
|
| |
| class BaseDataset(torch.utils.data.Dataset): |
| r"""Base dataset for training and validating.""" |
|
|
| def __init__(self, args, cfg, is_valid=False): |
| pass |
|
|
|
|
| class BaseTestDataset(torch.utils.data.Dataset): |
| r"""Test dataset for inference.""" |
|
|
| def __init__(self, args=None, cfg=None, infer_type="from_dataset"): |
| assert infer_type in ["from_dataset", "from_file"] |
|
|
| self.args = args |
| self.cfg = cfg |
| self.infer_type = infer_type |
|
|
| @abstractmethod |
| def __getitem__(self, index): |
| pass |
|
|
| def __len__(self): |
| return len(self.metadata) |
|
|
| def get_metadata(self): |
| path = Path(self.args.source) |
| if path.suffix == ".json" or path.suffix == ".jsonc": |
| metadata = json5.load(open(self.args.source, "r")) |
| elif path.suffix == ".yaml" or path.suffix == ".yml": |
| metadata = yaml.full_load(open(self.args.source, "r")) |
| else: |
| raise ValueError(f"Unsupported file type: {path.suffix}") |
|
|
| return metadata |
|
|