# Hub page metadata (Diffusers / Safetensors):
# EvalMDE / Edit2Perceive / utils / eval_normal.py
# Uploaded by zeyuren2002 via upload-large-folder tool (commit 7f921f4, verified)
# Copyright (C) 2019 Jin Han Lee
#
# This file is a part of BTS.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>
from __future__ import absolute_import, division, print_function
from utils.metric import compute_normal_metrics
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import argparse
import fnmatch
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
def hwc2chw(array):
    """Reorder an image array from (H, W, C) layout to (C, H, W)."""
    return np.transpose(array, (2, 0, 1))
def chw2hwc(array):
    """Reorder an image array from (C, H, W) layout back to (H, W, C)."""
    return np.transpose(array, (1, 2, 0))
def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@file`` argument file into individual args.

    Installed as argparse's ``convert_arg_line_to_args`` hook so that
    several arguments may share a single line in the file.
    """
    for token in arg_line.split():
        if token.strip():
            yield token
def resize_tensor(input_tensor, target_height, target_width):
    """Resize an (H, W, C) numpy array with bilinear interpolation.

    Args:
        input_tensor: numpy array of shape (H, W, C); any channel count,
            including C == 1.
        target_height: output height in pixels.
        target_width: output width in pixels.

    Returns:
        numpy float32 array of shape (target_height, target_width, C).
    """
    # HWC -> NCHW, the layout F.interpolate expects.
    batched = torch.from_numpy(input_tensor).permute(2, 0, 1).unsqueeze(0).float()
    resized = F.interpolate(
        batched,
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=True,
    )
    # Drop only the batch dim. A plain squeeze() (as before) would also
    # remove a singleton channel dim and break the HWC transpose for
    # single-channel inputs.
    return resized.squeeze(0).permute(1, 2, 0).numpy()
def load_image_rgb_or_grayscale(image_path):
    """Load an image file as a numpy array.

    ``.npy`` files are loaded directly with numpy; anything else is opened
    with PIL, which handles both RGB and grayscale formats. RGBA inputs
    are reduced to their first three (RGB) channels.
    """
    if image_path.endswith('.npy'):
        return np.load(image_path)
    pixels = np.array(Image.open(image_path))
    # Strip the alpha channel from RGBA images; RGB / grayscale pass through.
    if pixels.ndim == 3 and pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]
    return pixels
def test(args):
    """Gather prediction files, load matching GT normal maps, and run eval().

    Side effects: populates the module-level globals ``pred_filenames``,
    ``gt_depths``, ``gt_depths_mask`` and ``missing_ids`` that ``eval``
    reads afterwards.

    Args:
        args: parsed argparse namespace with ``pred_path``, ``gt_path``,
            ``dataset`` ('nyu'/'scannet'/'ibims'/'oasis'/'diode') and
            optionally ``txt_file_list``.

    Returns:
        The formatted metrics string returned by ``eval``.
    """
    global gt_depths, missing_ids, pred_filenames,gt_depths_mask
    gt_depths = []
    gt_depths_mask = []
    missing_ids = set()
    pred_filenames = []
    if getattr(args, 'txt_file_list', None) is not None:
        # Explicit file list: take the first whitespace-separated token of
        # each line and swap the .png extension for .npy.
        with open(args.txt_file_list, 'r') as f:
            lines = f.readlines()
        for i,line in enumerate(lines):
            # NOTE(review): split()[0] raises IndexError on a blank line,
            # so the '' check below can never fire — confirm list files
            # contain no blank lines.
            line = line.strip().split()[0]
            if line == '':
                continue
            pred_filenames.append(line.replace(".png",".npy"))
    else:
        # No list given: walk pred_path and collect every png/jpg/npy,
        # skipping colormap ('cmap') and ground-truth ('gt') renders.
        for root, dirnames, filenames in os.walk(args.pred_path):
            for pred_filename in fnmatch.filter(filenames, '*.png') + fnmatch.filter(filenames, '*.jpg') + fnmatch.filter(filenames, '*.npy'):
                if 'cmap' in pred_filename or 'gt' in pred_filename:
                    continue
                # Keep the path relative to pred_path so the same relative
                # path can be re-rooted under gt_path later.
                dirname = root.replace(args.pred_path, '')
                if dirname.startswith('/'):
                    dirname = dirname[1:]
                pred_filenames.append(os.path.join(dirname, pred_filename))
    num_test_samples = len(pred_filenames)
    print(f'Found {num_test_samples} prediction files.')
    pred_depths = []
    for i in tqdm(range(num_test_samples)):
        pred_depth_path = os.path.join(args.pred_path,pred_filenames[i])
        pred_depth = load_image_rgb_or_grayscale(pred_depth_path)
        if pred_depth is None:
            print('Missing: %s ' % pred_depth_path)
            missing_ids.add(i)
            continue
        # Predictions may be stored as 0-255 images; work in float32.
        pred_depth = pred_depth.astype(np.float32)
        pred_depths.append(pred_depth)
    # Load the GT normal maps, mirroring the prediction directory layout.
    if args.dataset == 'nyu' or args.dataset == 'scannet' or args.dataset == 'ibims' or args.dataset == 'oasis':
        for t_id in range(num_test_samples):
            if t_id in missing_ids:
                continue
            # Build the GT path by re-rooting the prediction's relative path.
            pred_relative_path = pred_filenames[t_id]
            gt_depth_path = os.path.join(args.gt_path, pred_relative_path)
            gt_depth_path = gt_depth_path.replace("_img.npy","_normal.npy")
            depth = load_image_rgb_or_grayscale(gt_depth_path)
            if depth is None:
                print('Missing: %s ' % gt_depth_path)
                missing_ids.add(t_id)
                continue
            gt_depths.append(depth)
    elif args.dataset == 'diode':
        for t_id in range(num_test_samples):
            if t_id in missing_ids:
                continue
            # Build the GT path by re-rooting the prediction's relative path.
            pred_relative_path = pred_filenames[t_id]
            gt_depth_path = os.path.join(args.gt_path, pred_relative_path)
            gt_depth_path = gt_depth_path.replace(".npy","_normal.npy")
            # NOTE(review): gt_depth_path already ends in "_normal.npy" at
            # this point, so the "_depth.npy" pattern below may never match —
            # verify the intended mask filename convention against the data.
            gt_depth_mask_path = gt_depth_path.replace("_depth.npy","_depth_mask.npy")
            depth = load_image_rgb_or_grayscale(gt_depth_path)
            depth_mask = load_image_rgb_or_grayscale(gt_depth_mask_path)
            if depth is None:
                print('Missing: %s ' % gt_depth_path)
                missing_ids.add(t_id)
                continue
            gt_depths.append(depth)
            gt_depths_mask.append(depth_mask)
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    print(f'### Computing errors for {len(gt_depths)} files with {len(missing_ids)} missing' if not gt_depths_mask else 'Computing errors with masks')
    results = eval(pred_depths,args)
    print('Done.')
    return results
def eval(pred_depths,args):
    """Compute angular-error metrics between predicted and GT normal maps.

    NOTE(review): this shadows the builtin ``eval``; renaming would require
    updating callers, so the name is kept.

    Reads the module-level globals ``gt_depths`` and ``missing_ids`` that
    ``test`` fills in. Each prediction is resized to its GT's resolution
    before metrics are computed by ``compute_normal_metrics``.

    Args:
        pred_depths: list of predicted normal maps (H, W, 3) as float
            numpy arrays, aligned by index with ``pred_filenames``.
        args: parsed CLI namespace (currently unused inside this function).

    Returns:
        A comma-separated string of per-dataset averages:
        mean, median, sub5, sub7.5, sub11.25, sub22.5, sub30.
    """
    num_samples = len(pred_depths)
    pred_depths_valid = []
    gt_depths_valid = []
    # Pair each surviving prediction with its GT. gt_depths only holds
    # entries for non-missing ids, hence the separate gt_idx counter.
    gt_idx = 0
    for t_id in range(num_samples):
        if t_id in missing_ids:
            continue
        pred_depths_valid.append(pred_depths[t_id])
        gt_depths_valid.append(gt_depths[gt_idx])
        gt_idx += 1
    num_samples = len(pred_depths_valid)
    # Per-sample metric buffers: angular errors plus the fraction of pixels
    # under each angular threshold (presumably in degrees — confirm in
    # compute_normal_metrics).
    mean_angular_error = np.zeros(num_samples, dtype=np.float32)
    median_angular_error = np.zeros(num_samples, dtype=np.float32)
    rmse_angular_error = np.zeros(num_samples, dtype=np.float32)
    sub5_error = np.zeros(num_samples, dtype=np.float32)
    sub7_5_error = np.zeros(num_samples, dtype=np.float32)
    sub11_25_error = np.zeros(num_samples, dtype=np.float32)
    sub22_5_error = np.zeros(num_samples, dtype=np.float32)
    sub30_error = np.zeros(num_samples, dtype=np.float32)
    for i in range(num_samples):
        gt_depth = gt_depths_valid[i]
        # Flip the x component of the GT normals in place — presumably to
        # match the predictor's axis convention; TODO confirm.
        gt_depth[:,:,0] *= -1
        # Zero out non-finite values in both maps before comparing.
        gt_depth[np.isinf(gt_depth)] = 0
        gt_depth[np.isnan(gt_depth)] = 0
        pred_depth = pred_depths_valid[i]
        pred_depth[np.isinf(pred_depth)] = 0
        pred_depth[np.isnan(pred_depth)] = 0
        # Resize the prediction to the GT resolution when shapes differ.
        if pred_depth.shape != gt_depth.shape:
            pred_depth = resize_tensor(pred_depth, gt_depth.shape[0], gt_depth.shape[1])
        # Debug visualization of GT vs. predicted normal arrows (disabled):
        # if i < 5:
        #     H, W, _ = gt_depth.shape
        #     # num_points = 200
        #     # ys = np.random.randint(0, H, size=num_points)
        #     # xs = np.random.randint(0, W, size=num_points)
        #     # make grid to sample
        #     sep = 20
        #     grid_y, grid_x = np.mgrid[0:H:sep, 0:W:sep]
        #     ys, xs = grid_y.ravel(), grid_x.ravel()
        #     # extract the normal vectors (x, y, z)
        #     gt_normals = gt_depth[ys, xs, :]
        #     pred_normals = pred_depth[ys, xs, :]
        #     # normalize
        #     gt_normals = gt_normals / (np.linalg.norm(gt_normals, axis=1, keepdims=True) + 1e-8)
        #     pred_normals = pred_normals / (np.linalg.norm(pred_normals, axis=1, keepdims=True) + 1e-8)
        #     plt.figure(figsize=(18, 6))
        #     # -------- left: GT normals --------
        #     plt.subplot(1, 3, 1)
        #     plt.imshow((gt_depth * 127.5 + 127.5).astype(np.uint8))  # visualize normal map in [0, 255]
        #     plt.quiver(xs, ys, gt_normals[:, 0], -gt_normals[:, 1], color='r', scale=20, width=0.005)
        #     plt.title(f'GT Normals {i}')
        #     plt.axis('off')
        #     # -------- middle: Pred normals --------
        #     plt.subplot(1, 3, 2)
        #     plt.imshow((pred_depth * 127.5 + 127.5).astype(np.uint8))
        #     plt.quiver(xs, ys, pred_normals[:, 0], -pred_normals[:, 1], color='b', scale=20, width=0.005)
        #     plt.title(f'Pred Normals {i}')
        #     plt.axis('off')
        #     # -------- right: GT depth + both arrow sets --------
        #     plt.subplot(1, 3, 3)
        #     plt.imshow(gt_depth.astype(np.uint8))
        #     plt.quiver(xs, ys, gt_normals[:, 0], -gt_normals[:, 1], color='r', scale=20, width=0.005, label="GT")
        #     plt.quiver(xs, ys, pred_normals[:, 0], -pred_normals[:, 1], color='b', scale=20, width=0.005, label="Pred")
        #     plt.title(f'GT+Pred Normals {i}')
        #     plt.axis('off')
        #     plt.legend(loc="lower right")
        #     plt.tight_layout()
        #     plt.savefig(f'normals_compare_{i}.png', dpi=300)
        #     plt.close()
        try:
            mean_angular_error[i], median_angular_error[i], rmse_angular_error[i], sub5_error[i], sub7_5_error[i], sub11_25_error[i], sub22_5_error[i], sub30_error[i] = compute_normal_metrics(
                pred_depth, gt_depth)
        except Exception as e:
            print(f'Error computing metrics for sample {i}: {e}')
            continue
    # Average only over samples whose mean error is finite.
    valid_results = ~np.isnan(mean_angular_error) & ~np.isinf(mean_angular_error)
    # NOTE(review): rmse_angular_error is computed but never reported below.
    results = "{:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
        mean_angular_error[valid_results].mean(), median_angular_error[valid_results].mean(), sub5_error[valid_results].mean(),
        sub7_5_error[valid_results].mean(), sub11_25_error[valid_results].mean(), sub22_5_error[valid_results].mean(),
        sub30_error[valid_results].mean())
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format(
        "mean", "median", "sub5", "sub7.5", "sub11.25", "sub22.5", "sub30")
    )
    print(results)
    print(f'Valid results: {valid_results.sum()}/{len(valid_results)}')
    return results
    # return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3
if __name__ == '__main__':
    # Arguments may also be supplied one per line in a file passed as @file
    # (see convert_arg_line_to_args).
    cli = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
    cli.convert_arg_line_to_args = convert_arg_line_to_args
    cli.add_argument('--pred_path', type=str, required=True,
                     help='path to the prediction results in png')
    cli.add_argument('--gt_path', type=str, required=False,
                     help='root path to the groundtruth data')
    cli.add_argument('--dataset', type=str, default='nyu',
                     help='dataset to test on, nyu or kitti')
    cli.add_argument('--txt_file_list', type=str, default=None,
                     help='text file containing list of files to evaluate')
    test(cli.parse_args())