Diffusers
Safetensors
File size: 2,857 Bytes
40a3ea8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import random

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

from . import aug_basic

import logging
logger = logging.getLogger('root')


def get_transform(dataset_name='hypersim', mode='test'):
    """Build the preprocessing pipeline applied to every loaded sample.

    The pipeline chains ``aug_basic.ToTensor``, ``aug_basic.Normalize`` (with
    mean 0.5 and std 0.5) and ``aug_basic.ToDict``, in that order.

    Args:
        dataset_name: Only used in the log messages.
        mode: Must be ``'test'``; no other mode is accepted.

    Returns:
        A ``transforms.Compose`` wrapping the three steps above.
    """
    assert mode in ['test']

    log_args = (mode, dataset_name)
    logger.info('Defining %s transform for %s dataset' % log_args)

    steps = [
        aug_basic.ToTensor(),
        aug_basic.Normalize(mean=[0.5], std=[0.5]),
        aug_basic.ToDict(),
    ]

    logger.info('Defining %s transform for %s dataset ... DONE' % log_args)
    return transforms.Compose(steps)


class NormalDataset(Dataset):
    """Test-time dataset serving the samples listed in a split file.

    The split file is read from
    ``<dataset_split_path>/<dataset_name>/split/<split>.txt``; every line,
    stripped of surrounding whitespace, is kept as one sample path. Loading a
    sample is delegated to the ``get_sample`` function of the module
    ``evaluation.dataset_normal.<dataset_name>``, and its result is passed
    through the transform returned by ``get_transform``.

    Args:
        base_data_dir: Forwarded unchanged to ``get_sample`` as
            ``base_data_dir``.
        dataset_split_path: Directory that contains the per-dataset split
            folders.
        dataset_name: One of ``'nyuv2'``, ``'scannet'``, ``'ibims'``,
            ``'sintel'``, ``'vkitti'`` or ``'oasis'``.
        split: Name of the split file, without the ``.txt`` extension.
        mode: Only ``'test'`` is accepted.
        epoch: Not used by this class; kept so existing callers keep working.

    Raises:
        AssertionError: If ``mode`` is not ``'test'`` or the split file does
            not exist.
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """

    def __init__(self, base_data_dir, dataset_split_path, dataset_name='nyuv2', split='test', mode='test',  epoch=0):
        self.split = split
        self.mode = mode
        self.base_data_dir = base_data_dir
        assert mode in ['test'], 'unsupported mode: %r' % (mode,)

        # data split
        split_path = os.path.join(dataset_split_path, dataset_name, 'split', split+'.txt') # dataset_split_path: eval/dataset_normal/
        assert os.path.exists(split_path), 'split file not found: %s' % split_path
        with open(split_path, 'r') as f:
            self.filenames = [line.strip() for line in f]
        self.split_path = split_path

        # get_sample function (imported lazily so only the selected dataset
        # module has to be importable)
        if dataset_name == 'nyuv2':
            from evaluation.dataset_normal.nyuv2 import get_sample
        elif dataset_name == 'scannet':
            from evaluation.dataset_normal.scannet import get_sample
        elif dataset_name == 'ibims':
            from evaluation.dataset_normal.ibims import get_sample
        elif dataset_name == 'sintel':
            from evaluation.dataset_normal.sintel import get_sample
        elif dataset_name == 'vkitti':
            from evaluation.dataset_normal.vkitti import get_sample
        elif dataset_name == 'oasis':
            from evaluation.dataset_normal.oasis import get_sample
        else:
            # Without this branch an unknown name left `get_sample` unbound and
            # the assignment below failed with an opaque UnboundLocalError.
            raise ValueError(
                'Unsupported dataset_name %r; expected one of: '
                'nyuv2, scannet, ibims, sintel, vkitti, oasis' % (dataset_name,)
            )
        self.get_sample = get_sample

        # data preprocessing/augmentation
        self.transform = get_transform(dataset_name=dataset_name, mode=mode)

    def __len__(self):
        """Return the number of entries read from the split file."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load the sample at ``index`` and return it after the transform."""
        # A fresh dict is handed to get_sample on every call.
        info = {}

        raw_sample = self.get_sample(
            base_data_dir=self.base_data_dir,
            sample_path=self.filenames[index],
            info=info,
        )
        return self.transform(raw_sample)
    
class TestLoader(object):
    """Bundle a test-mode ``NormalDataset`` with a ``DataLoader`` over it.

    Attributes:
        test_samples: The ``NormalDataset`` built for ``dataset_name_test`` and
            ``test_split``.
        data: A ``DataLoader`` over ``test_samples`` with batch size 1, no
            shuffling, one worker process and pinned memory.
    """

    def __init__(self, base_data_dir, dataset_split_path, dataset_name_test, test_split):
        dataset = NormalDataset(
            base_data_dir,
            dataset_split_path,
            dataset_name=dataset_name_test,
            split=test_split,
            mode='test',
            epoch=None,
        )
        self.test_samples = dataset
        self.data = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
        )