Diffusers
Safetensors
File size: 11,858 Bytes
7f921f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Copyright (C) 2019 Jin Han Lee
#
# This file is a part of BTS.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>

from __future__ import absolute_import, division, print_function
from utils.metric import compute_normal_metrics
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import argparse
import fnmatch
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
# NOTE(review): TF_CPP_MIN_LOG_LEVEL only tunes TensorFlow's C++ logging, and
# nothing imported in this file is TensorFlow; this looks like a leftover from
# the original BTS script — confirm it is unneeded before removing.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

def hwc2chw(array):
    """Return a view of a (H, W, C) array with its axes reordered to (C, H, W)."""
    return np.transpose(array, (2, 0, 1))
def chw2hwc(array):
    """Return a view of a (C, H, W) array with its axes reordered to (H, W, C)."""
    return np.transpose(array, (1, 2, 0))

def convert_arg_line_to_args(arg_line):
    """Split one line of an argparse ``@file`` into whitespace-separated tokens.

    Installed as ``ArgumentParser.convert_arg_line_to_args`` so that an
    argument file may hold several ``--flag value`` pairs per line. This is a
    generator; blank lines yield nothing.
    """
    # str.split() with no separator never produces empty tokens, so every
    # piece can be forwarded as-is.
    yield from arg_line.split()



def resize_tensor(input_tensor, target_height, target_width):
    """Bilinearly resize a (H, W, C) or (H, W) numpy array to a new spatial size.

    Used to bring a prediction to the ground-truth resolution before metrics
    are computed. Despite the parameter name, a numpy array is expected.

    Args:
        input_tensor: numpy array of shape (H, W, C) or (H, W).
        target_height: output height in pixels.
        target_width: output width in pixels.

    Returns:
        float32 numpy array of shape (target_height, target_width, C), or
        (target_height, target_width) when the input was 2-D.
    """
    array = np.ascontiguousarray(input_tensor, dtype=np.float32)
    is_2d = array.ndim == 2
    if is_2d:
        # Treat a single-plane map as a one-channel image.
        array = array[:, :, None]

    # (H, W, C) -> (1, C, H, W), the layout F.interpolate expects.
    batch = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
    resized = F.interpolate(
        batch,
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=True,
    )

    # Drop only the batch axis: a bare squeeze() would also collapse a
    # one-channel axis or a size-1 output dimension and break the
    # transpose back to (H, W, C).
    out = resized[0].permute(1, 2, 0).numpy()
    return out[:, :, 0] if is_2d else out

def load_image_rgb_or_grayscale(image_path):
    """Load an image or ``.npy`` file from disk as a numpy array.

    ``.npy`` files are returned exactly as stored by ``np.load``. Any other
    path is opened with PIL. An RGBA image has its alpha channel dropped;
    RGB and single-channel images are returned unchanged.

    Args:
        image_path: path to a ``.npy`` file or a PIL-readable image.

    Returns:
        The loaded numpy array.

    Raises:
        Whatever ``np.load`` / ``PIL.Image.open`` raise, e.g.
        ``FileNotFoundError`` for a missing path.
    """
    if image_path.endswith('.npy'):
        return np.load(image_path)

    # The context manager closes the file handle as soon as the pixels have
    # been copied into the array, instead of leaving it to the GC.
    with Image.open(image_path) as img:
        img_array = np.array(img)

    if img_array.ndim == 3 and img_array.shape[2] == 4:  # RGBA
        img_array = img_array[:, :, :3]  # drop the alpha channel
    return img_array


def _load_or_none(path):
    """Load ``path`` with load_image_rgb_or_grayscale, or return None if it does not exist."""
    # The loader raises on a missing file; the callers below treat None as a
    # "missing sample", so translate that one case instead of aborting the run.
    try:
        return load_image_rgb_or_grayscale(path)
    except FileNotFoundError:
        return None


def test(args):
    """Evaluate predicted surface-normal maps against ground truth.

    Collects prediction files from ``args.txt_file_list`` (first column of each
    non-blank line, ``.png`` renamed to ``.npy``) or, if no list is given, by
    walking ``args.pred_path``. It then loads the matching ground-truth normal
    maps from ``args.gt_path`` and hands everything to ``eval``.

    State is shared with ``eval`` through module globals:

    - ``pred_filenames``: the candidate relative paths.
    - ``missing_ids``: indices into ``pred_filenames`` whose prediction or GT
      could not be found.
    - ``gt_depths``: one GT array per remaining index, in order.
    - ``gt_depths_mask``: DIODE masks (None where a mask file is missing).

    Args:
        args: namespace with ``pred_path``, ``gt_path``, ``dataset`` and
            optionally ``txt_file_list``.

    Returns:
        The formatted metrics string produced by ``eval``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    global gt_depths, missing_ids, pred_filenames, gt_depths_mask
    gt_depths = []
    gt_depths_mask = []
    missing_ids = set()
    pred_filenames = []

    # Fail fast instead of after loading every prediction.
    if args.dataset not in ('nyu', 'scannet', 'ibims', 'oasis', 'diode'):
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    if getattr(args, 'txt_file_list', None) is not None:
        with open(args.txt_file_list, 'r') as f:
            for line in f:
                fields = line.split()
                # Blank lines used to hit split()[0] and raise IndexError.
                if not fields:
                    continue
                pred_filenames.append(fields[0].replace(".png", ".npy"))
    else:
        for root, _dirnames, filenames in os.walk(args.pred_path):
            candidates = (fnmatch.filter(filenames, '*.png')
                          + fnmatch.filter(filenames, '*.jpg')
                          + fnmatch.filter(filenames, '*.npy'))
            for pred_filename in candidates:
                # Skip colour-mapped visualisations and GT copies.
                if 'cmap' in pred_filename or 'gt' in pred_filename:
                    continue
                dirname = root.replace(args.pred_path, '')
                if dirname.startswith('/'):
                    dirname = dirname[1:]
                pred_filenames.append(os.path.join(dirname, pred_filename))

    num_test_samples = len(pred_filenames)
    print(f'Found {num_test_samples} prediction files.')

    # Exactly one entry per pred_filenames index (None for a missing
    # prediction), so eval() can index predictions by the same ids it
    # checks against missing_ids.
    pred_depths = []
    for i in tqdm(range(num_test_samples)):
        pred_depth_path = os.path.join(args.pred_path, pred_filenames[i])
        pred_depth = _load_or_none(pred_depth_path)
        if pred_depth is None:
            print('Missing: %s ' % pred_depth_path)
            missing_ids.add(i)
            pred_depths.append(None)
            continue
        pred_depths.append(pred_depth.astype(np.float32))

    # Load GT, mirroring the prediction's relative directory structure.
    for t_id in range(num_test_samples):
        if t_id in missing_ids:
            continue
        gt_base_path = os.path.join(args.gt_path, pred_filenames[t_id])
        if args.dataset == 'diode':
            gt_depth_path = gt_base_path.replace(".npy", "_normal.npy")
            # Derive the mask from the base name: the old code replaced
            # "_depth.npy" on a path already ending in "_normal.npy", which
            # never matched and re-loaded the normal map as the "mask".
            gt_depth_mask_path = gt_base_path.replace(".npy", "_depth_mask.npy")
        else:
            gt_depth_path = gt_base_path.replace("_img.npy", "_normal.npy")
            gt_depth_mask_path = None

        depth = _load_or_none(gt_depth_path)
        if depth is None:
            print('Missing: %s ' % gt_depth_path)
            missing_ids.add(t_id)
            continue
        gt_depths.append(depth)

        if gt_depth_mask_path is not None:
            depth_mask = _load_or_none(gt_depth_mask_path)
            if depth_mask is None:
                print('Missing mask: %s ' % gt_depth_mask_path)
            gt_depths_mask.append(depth_mask)

    has_masks = any(mask is not None for mask in gt_depths_mask)
    print(f'### Computing errors for {len(gt_depths)} files with {len(missing_ids)} missing' if not has_masks else 'Computing errors with masks')

    results = eval(pred_depths, args)

    print('Done.')
    return results

def eval(pred_depths, args):
    """Compute surface-normal angular-error metrics and print a summary row.

    Predictions are paired with ground truth via the module globals set by
    ``test``:

    - Indices in ``missing_ids`` are skipped.
    - ``gt_depths`` is consumed in order for the remaining indices.
    - ``pred_depths`` is indexed by the same index space as ``missing_ids``.

    For each pair:

    - The GT x-channel is negated (presumably to match the prediction's axis
      convention — TODO confirm).
    - NaN/inf entries in both maps are set to 0.
    - The prediction is bilinearly resized to the GT resolution if the shapes
      differ.

    Args:
        pred_depths: per-sample predicted normal maps (the name is inherited
            from a depth-evaluation script).
        args: parsed command-line namespace; not read here.

    Returns:
        Comma-separated string of mean, median, sub5, sub7.5, sub11.25,
        sub22.5 and sub30, each averaged over samples whose metrics were
        computed successfully.
    """
    pred_depths_valid = []
    gt_depths_valid = []

    # Collect matching (prediction, GT) pairs.
    gt_idx = 0
    for t_id in range(len(pred_depths)):
        if t_id in missing_ids:
            continue
        pred_depths_valid.append(pred_depths[t_id])
        gt_depths_valid.append(gt_depths[gt_idx])
        gt_idx += 1

    num_samples = len(pred_depths_valid)

    # Initialise with NaN, not 0: a sample whose metrics fail must be dropped
    # by the validity filter below rather than averaged in as a perfect
    # 0-degree error.
    mean_angular_error = np.full(num_samples, np.nan, dtype=np.float32)
    median_angular_error = np.full(num_samples, np.nan, dtype=np.float32)
    rmse_angular_error = np.full(num_samples, np.nan, dtype=np.float32)
    sub5_error = np.full(num_samples, np.nan, dtype=np.float32)
    sub7_5_error = np.full(num_samples, np.nan, dtype=np.float32)
    sub11_25_error = np.full(num_samples, np.nan, dtype=np.float32)
    sub22_5_error = np.full(num_samples, np.nan, dtype=np.float32)
    sub30_error = np.full(num_samples, np.nan, dtype=np.float32)

    for i in range(num_samples):
        # Work on copies: the GT arrays are the module-level gt_depths
        # entries, and modifying them in place would re-flip the x-channel
        # on every call to eval().
        gt_depth = gt_depths_valid[i].copy()
        gt_depth[:, :, 0] *= -1
        gt_depth[~np.isfinite(gt_depth)] = 0
        pred_depth = pred_depths_valid[i].copy()
        pred_depth[~np.isfinite(pred_depth)] = 0

        # Bring the prediction to the GT resolution.
        if pred_depth.shape != gt_depth.shape:
            pred_depth = resize_tensor(pred_depth, gt_depth.shape[0], gt_depth.shape[1])

        try:
            (mean_angular_error[i], median_angular_error[i], rmse_angular_error[i],
             sub5_error[i], sub7_5_error[i], sub11_25_error[i], sub22_5_error[i],
             sub30_error[i]) = compute_normal_metrics(pred_depth, gt_depth)
        except Exception as e:
            # Best-effort per sample: report it and leave its entries NaN.
            print(f'Error computing metrics for sample {i}: {e}')

    # Keep only samples with a finite mean error.
    valid_results = ~np.isnan(mean_angular_error) & ~np.isinf(mean_angular_error)
    results = "{:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
        mean_angular_error[valid_results].mean(), median_angular_error[valid_results].mean(), sub5_error[valid_results].mean(),
        sub7_5_error[valid_results].mean(), sub11_25_error[valid_results].mean(), sub22_5_error[valid_results].mean(),
        sub30_error[valid_results].mean())
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format(
        "mean", "median", "sub5", "sub7.5", "sub11.25", "sub22.5", "sub30")
    )
    print(results)

    print(f'Valid results: {valid_results.sum()}/{len(valid_results)}')
    return results



if __name__ == '__main__':
    # Command-line entry point; arguments may also be read from an @file,
    # one or more "--flag value" pairs per line.
    parser = argparse.ArgumentParser(description='Evaluate predicted surface normals against ground truth.', fromfile_prefix_chars='@')
    parser.convert_arg_line_to_args = convert_arg_line_to_args

    parser.add_argument('--pred_path',           type=str,   help='root path to the predicted normal maps (.npy, .png or .jpg)', required=True)
    # Always needed: every GT path is built under this root.
    parser.add_argument('--gt_path',             type=str,   help='root path to the groundtruth normal maps', required=True)
    parser.add_argument('--dataset',             type=str,   help='dataset to evaluate on', default='nyu',
                        choices=['nyu', 'scannet', 'ibims', 'oasis', 'diode'])
    parser.add_argument('--txt_file_list',       type=str,   help='text file listing the prediction files to evaluate (first column of each line)', default=None)
    args = parser.parse_args()
    test(args)