# Copyright (C) 2019 Jin Han Lee
#
# This file is a part of BTS.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>

import os
import argparse
import fnmatch
import struct

import numpy as np
import torch
import torch.nn.functional as F

# NOTE(review): cv2 and PIL are imported by the original file but are not used
# by any code visible here; they are kept (guarded) in case the corrupted
# sections further down rely on them -- confirm against version control.
try:
    import cv2
except ImportError:  # optional dependency; only needed by unseen code paths
    cv2 = None
try:
    from PIL import Image
except ImportError:  # optional dependency; only needed by unseen code paths
    Image = None

# Silence TensorFlow INFO-level logging (the original BTS tooling is TF-based).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'


def convert_arg_line_to_args(arg_line):
    """Split one line of an @args-file into individual CLI arguments.

    Installed as argparse's ``convert_arg_line_to_args`` hook so that several
    whitespace-separated options may share a single line in the args file.

    Args:
        arg_line: one raw line read from the argument file.

    Yields:
        Each whitespace-separated token of the line, in order.
    """
    # str.split() with no separator already drops empty fields, so the
    # original per-token blank check was dead code and has been removed.
    yield from arg_line.split()


def scale_and_shift_align(pred, gt, valid_mask):
    """Align the scale and shift of a relative depth map via least squares.

    Solves ``min_{s, t} || s * pred[valid] + t - gt[valid] ||^2`` and applies
    the recovered affine transform to the whole prediction.

    Args:
        pred: predicted (relative) depth map, ndarray.
        gt: ground-truth (metric) depth map, broadcastable against ``pred``.
        valid_mask: boolean mask selecting the pixels used for the fit.

    Returns:
        ``pred * scale + shift`` -- the prediction mapped into the gt range.

    Note:
        A well-determined fit needs at least two valid pixels with
        non-constant ``pred`` values; degenerate masks fall back to the
        minimum-norm least-squares solution computed by ``np.linalg.lstsq``.
    """
    pred_valid = pred[valid_mask].flatten()
    gt_valid = gt[valid_mask].flatten()

    # Build the linear system A x = b with A = [pred, 1], x = [scale, shift].
    A = np.vstack([pred_valid, np.ones(len(pred_valid))]).T
    scale, shift = np.linalg.lstsq(A, gt_valid, rcond=None)[0]

    return pred * scale + shift


def resize_depth_tensor(depth_img, target_height, target_width):
    """Resize a single-channel depth image with bilinear interpolation.

    Args:
        depth_img: 2-D numpy array holding the depth values.
        target_height: output height in pixels.
        target_width: output width in pixels.

    Returns:
        2-D float32 numpy array of shape ``(target_height, target_width)``.
    """
    # (H, W) -> (1, 1, H, W) so F.interpolate treats it as a 1-channel image.
    depth_tensor = torch.from_numpy(depth_img).unsqueeze(0).unsqueeze(0).float()

    resized_tensor = F.interpolate(
        depth_tensor,
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=True
    )

    # Drop the batch/channel dims and return to numpy.
    # (Original's ``return resized_depth`` was split across a chunk boundary;
    # completed here.)
    return resized_tensor.squeeze().numpy()
# NOTE(review): everything below was corrupted during extraction. Real
# newlines were collapsed into a few very long lines, and a text filter
# stripped every "<...>" span: the format string of struct.unpack (almost
# certainly '<i' little-endian int32, matching the Middlebury-style 'PIEH'
# magic read above -- TODO confirm), the remainder of read_depth, the whole
# compute_errors helper, and the "def test(args):" header plus its
# prediction/ground-truth loading code (which defined gt_depths_mask, pdf,
# silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3 referenced below).
# Recover these sections from version control before running. The fragments
# below are preserved byte-for-byte to aid reconstruction.
resized_depth def read_depth(filename): with open(filename, 'rb') as f: tag = f.read(4) if tag != b'PIEH': raise ValueError("Invalid file format: expected 'PIEH' tag") width = struct.unpack(' args.min_depth_eval, gt_depth < args.max_depth_eval) valid_mask = np.logical_and(valid_mask, ~np.isnan(gt_depth)) valid_mask = np.logical_and(valid_mask, ~np.isinf(gt_depth)) if gt_depths_mask: valid_mask = np.logical_and(valid_mask, gt_depths_mask[i] > 0) if args.dataset == 'nyu': _valid_mask = np.zeros_like(valid_mask) _valid_mask[45:471, 41:601] = 1 valid_mask = np.logical_and(valid_mask, _valid_mask) del _valid_mask # handle cropping if args.do_kb_crop: height, width = gt_depth.shape top_margin = int(height - 352) left_margin = int((width - 1216) / 2) pred_depth_uncropped = np.zeros((height, width), dtype=np.float32) try: if abs(pred_depth.shape[0]-375) < 10: pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth[top_margin:top_margin + 352, left_margin:left_margin + 1216] pred_depth = pred_depth_uncropped else: pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth pred_depth = pred_depth_uncropped except Exception as e: print(f"Error in do_kb_crop for sample {i}: {e}") print(f"pred shape:{pred_depth.shape}, uncropped shape:{pred_depth_uncropped.shape}") _valid_mask = np.zeros_like(valid_mask) _valid_mask[top_margin:top_margin + 352, left_margin:left_margin + 1216] = valid_mask[top_margin:top_margin + 352, left_margin:left_margin + 1216] valid_mask = _valid_mask if args.garg_crop or args.eigen_crop: gt_height, gt_width = gt_depth.shape eval_mask = np.zeros(valid_mask.shape) if args.garg_crop: eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 elif args.eigen_crop: if args.dataset == 'kitti': eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 else:
# -- fragment: continuation of test(): NYU eigen crop, valid-pixel count
#    check, scale-and-shift alignment in the chosen depth parameterization
#    (log / sqrt-disparity / disparity / sqrt / pdf lookup), inverse
#    transform, clipping, and per-sample metric computation --
eval_mask[45:471, 41:601] = 1 valid_mask = np.logical_and(valid_mask, eval_mask) # check that there are enough valid pixels if valid_mask.sum() < 100: print(f'Warning: Sample {i} has very few valid pixels ({valid_mask.sum()})') continue # 2. apply scale-and-shift alignment # print("original gt depth min:{:.4f} max:{:.4f} mean:{:.4f}".format(gt_depth[valid_mask].min(), gt_depth[valid_mask].max(), gt_depth[valid_mask].mean())) if getattr(args, 'using_log',None): gt_depth_cp = np.log(gt_depth+1e-6) elif getattr(args, 'using_sqrt_disp',None): gt_depth_cp = 1/np.sqrt(gt_depth+1e-8) elif getattr(args, 'using_disp',None): gt_depth_cp = 1/(gt_depth+1e-8) elif getattr(args, 'using_sqrt',None): gt_depth_cp = np.sqrt(gt_depth+1e-6) elif getattr(args, 'using_pdf',None): gt_depth_cp = np.interp(gt_depth, pdf['bins'], pdf['y_map']) else: gt_depth_cp = gt_depth.copy() pred_depth_aligned = scale_and_shift_align(pred_depth, gt_depth_cp, valid_mask) if getattr(args, 'using_log',None): pred_depth_aligned = np.exp(pred_depth_aligned) elif getattr(args, 'using_sqrt_disp',None): pred_depth_aligned = 1/(pred_depth_aligned**2) elif getattr(args, 'using_disp',None): pred_depth_aligned = 1/pred_depth_aligned elif getattr(args, 'using_sqrt',None): pred_depth_aligned = np.power(pred_depth_aligned, 2) elif getattr(args, 'using_pdf',None): pred_depth_aligned = np.interp(pred_depth_aligned, pdf['y_map'], pdf['bins']) pred_depth_aligned = np.clip(pred_depth_aligned, args.min_depth_eval, args.max_depth_eval) # compute errors try: silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i], d1[i], d2[i], d3[i] = compute_errors( gt_depth[valid_mask], pred_depth_aligned[valid_mask]) except Exception as e: print(f'Error computing metrics for sample {i}: {e}') continue # filter out invalid values valid_results = ~np.isnan(silog) & ~np.isinf(silog) & (silog != 0) results = "{:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}".format( d1[valid_results].mean(), d2[valid_results].mean(), d3[valid_results].mean(), abs_rel[valid_results].mean(),
# -- fragment: averaged-metric formatting, result printing, test()'s return,
#    and the start of the __main__ argparse setup --
sq_rel[valid_results].mean(), rms[valid_results].mean(), log_rms[valid_results].mean(), silog[valid_results].mean(), log10[valid_results].mean()) if not args.no_verbose: print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format( 'd1', 'd2', 'd3', 'AbsRel', 'SqRel', 'RMSE', 'RMSElog', 'SILog', 'log10')) print(results) print(f'Valid results: {valid_results.sum()}/{len(valid_results)}') return results # return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3 if __name__ == '__main__': parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@') parser.convert_arg_line_to_args = convert_arg_line_to_args parser.add_argument('--pred_path', type=str, help='path to the prediction results in png', required=True) parser.add_argument('--gt_path', type=str, help='root path to the groundtruth data', required=False) parser.add_argument('--dataset', type=str, help='dataset to test on, nyu or kitti', default='nyu') parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true') parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true') parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3) parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80) parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true') parser.add_argument('--no_verbose', default=False, action='store_true', help='if set, do not print out per image results') parser.add_argument('--using_log', default=False, action='store_true', help='if set, use log depth for eval') parser.add_argument('--using_disp', default=False, action='store_true', help='if set, use disparity (1/depth) for eval') parser.add_argument('--using_sqrt', default=False, action='store_true', help='if set, use sqrt depth for eval')
# -- fragment: remaining argparse flag, argument parsing, and script entry.
#    NOTE(review): no --using_sqrt_disp flag is declared although the code
#    above checks it via getattr (hence the getattr defaults) -- confirm
#    whether it was lost to the same corruption or intentionally omitted --
parser.add_argument('--using_pdf', default=False, action='store_true', help='if set, use pdf for eval') args = parser.parse_args() test(args) # load_image_rgb_or_grayscale("/opt/liblibai-models/user-workspace2/users/syq/Depth_Post_Train/dataset/Eval/depth/ETH3D/depth/kicker_dslr_depth/kicker/ground_truth_depth/dslr_images/DSC_6493.JPG")