# Copyright (C) 2019 Jin Han Lee
#
# This file is a part of BTS.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import argparse
import fnmatch
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import struct
# Silence TensorFlow C++ INFO-level logs. Leftover from the original BTS
# (TensorFlow) script — the code below otherwise uses PyTorch/numpy.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
def convert_arg_line_to_args(arg_line):
    """Split one line of an @-argument-file into individual tokens.

    Installed as ``parser.convert_arg_line_to_args`` so that files passed
    via ``@args.txt`` may hold several space-separated options per line.

    Args:
        arg_line: A single line read from the argument file.

    Yields:
        str: Each whitespace-separated token on the line.
    """
    # str.split() with no separator already drops empty strings, so the
    # original per-token blank check was unreachable dead code.
    yield from arg_line.split()
def scale_and_shift_align(pred, gt, valid_mask):
    """Least-squares align the scale and shift of a relative depth map.

    Solves ``min || scale * pred[valid] + shift - gt[valid] ||_2`` and
    applies the recovered affine transform to the whole prediction.

    Args:
        pred: Predicted (relative) depth map as a numpy array.
        gt: Ground-truth (absolute) depth map, same shape as ``pred``.
        valid_mask: Boolean mask selecting the pixels used for the fit.

    Returns:
        ``pred * scale + shift`` — the prediction affinely aligned to ``gt``.

    Raises:
        ValueError: If ``valid_mask`` selects no pixels (the least-squares
            system would be empty and ``lstsq`` would fail opaquely).
    """
    pred_valid = pred[valid_mask].flatten()
    gt_valid = gt[valid_mask].flatten()
    if pred_valid.size == 0:
        raise ValueError("valid_mask selects no pixels; cannot fit scale/shift")
    # Least-squares system A @ x = b with A = [pred_valid, ones],
    # x = [scale, shift], b = gt_valid.
    A = np.vstack([pred_valid, np.ones(len(pred_valid))]).T
    scale, shift = np.linalg.lstsq(A, gt_valid, rcond=None)[0]
    return pred * scale + shift
def resize_depth_tensor(depth_img, target_height, target_width):
    """Resize a single-channel depth image with bilinear interpolation.

    Args:
        depth_img: 2-D numpy array (H, W) of depth values.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        2-D numpy float32 array of shape (target_height, target_width).
    """
    # (H, W) -> (1, 1, H, W): F.interpolate expects a batched NCHW tensor.
    depth_tensor = torch.from_numpy(depth_img).unsqueeze(0).unsqueeze(0).float()
    resized_tensor = F.interpolate(
        depth_tensor,
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=True
    )
    # Drop only the batch and channel axes. A bare .squeeze() (as the
    # original used) would also collapse spatial dimensions of size 1,
    # returning a 1-D array for e.g. a 1xW target.
    return resized_tensor.squeeze(0).squeeze(0).numpy()
def read_depth(filename):
    """Read a depth map stored in Middlebury .flo-style binary layout.

    Layout: 4-byte magic tag ``b'PIEH'``, little-endian int32 width,
    little-endian int32 height, then ``width * height`` little-endian
    float32 values in row-major order.

    NOTE(review): the body of this function was garbled in the source file
    (everything after ``struct.unpack('`` was lost); this reconstruction
    follows the standard .flo layout implied by the 'PIEH' tag check —
    confirm against the files actually being read.

    Args:
        filename: Path to the binary depth file.

    Returns:
        float32 numpy array of shape (height, width).

    Raises:
        ValueError: If the magic tag is not ``b'PIEH'``.
    """
    with open(filename, 'rb') as f:
        tag = f.read(4)
        if tag != b'PIEH':
            raise ValueError("Invalid file format: expected 'PIEH' tag")
        width = struct.unpack('<i', f.read(4))[0]
        height = struct.unpack('<i', f.read(4))[0]
        data = np.frombuffer(f.read(4 * width * height), dtype='<f4')
        return data.reshape(height, width).astype(np.float32)
# NOTE(review): the lines below are the interior of the evaluation routine
# (presumably `def test(args):`, called at the bottom of the file) whose
# header, data-loading code, and per-sample loop opening were destroyed by
# the same corruption that truncated read_depth above; indentation has also
# been stripped. Code is kept byte-identical here — only comments were
# added/translated. Names such as `i`, `gt_depth`, `pred_depth`,
# `gt_depths_mask`, `pdf`, and the metric arrays (`silog`, `d1`, ...) must
# come from the missing enclosing code — TODO recover from version control.
# Drop NaN/Inf ground-truth pixels from the valid mask.
valid_mask = np.logical_and(valid_mask, ~np.isnan(gt_depth))
valid_mask = np.logical_and(valid_mask, ~np.isinf(gt_depth))
if gt_depths_mask:
valid_mask = np.logical_and(valid_mask, gt_depths_mask[i] > 0)
# NYU: restrict evaluation to the standard center crop [45:471, 41:601].
if args.dataset == 'nyu':
_valid_mask = np.zeros_like(valid_mask)
_valid_mask[45:471, 41:601] = 1
valid_mask = np.logical_and(valid_mask, _valid_mask)
del _valid_mask
# Handle cropping (KITTI-benchmark-style 352x1216 crop).
if args.do_kb_crop:
height, width = gt_depth.shape
top_margin = int(height - 352)
left_margin = int((width - 1216) / 2)
pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
try:
# If the prediction is roughly full-resolution (~375 rows), copy the
# crop window out of it; otherwise assume it is already 352x1216 and
# paste it into the full-size canvas.
if abs(pred_depth.shape[0]-375) < 10:
pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth[top_margin:top_margin + 352, left_margin:left_margin + 1216]
pred_depth = pred_depth_uncropped
else:
pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
pred_depth = pred_depth_uncropped
except Exception as e:
print(f"Error in do_kb_crop for sample {i}: {e}")
print(f"pred shape:{pred_depth.shape}, uncropped shape:{pred_depth_uncropped.shape}")
# Keep only valid pixels inside the crop window.
_valid_mask = np.zeros_like(valid_mask)
_valid_mask[top_margin:top_margin + 352, left_margin:left_margin + 1216] = valid_mask[top_margin:top_margin + 352, left_margin:left_margin + 1216]
valid_mask = _valid_mask
# Optional Garg (ECCV16) / Eigen (NIPS14) evaluation crops.
if args.garg_crop or args.eigen_crop:
gt_height, gt_width = gt_depth.shape
eval_mask = np.zeros(valid_mask.shape)
if args.garg_crop:
eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
elif args.eigen_crop:
if args.dataset == 'kitti':
eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
else:
eval_mask[45:471, 41:601] = 1
valid_mask = np.logical_and(valid_mask, eval_mask)
# Check whether there are enough valid pixels to fit scale/shift reliably.
if valid_mask.sum() < 100:
print(f'Warning: Sample {i} has very few valid pixels ({valid_mask.sum()})')
continue
# 2. Apply scale-and-shift alignment, optionally in a transformed depth
# space (log, disparity, sqrt-disparity, sqrt, or a PDF-based remapping).
# print("original gt depth min:{:.4f} max:{:.4f} mean:{:.4f}".format(gt_depth[valid_mask].min(), gt_depth[valid_mask].max(), gt_depth[valid_mask].mean()))
if getattr(args, 'using_log',None):
gt_depth_cp = np.log(gt_depth+1e-6)
elif getattr(args, 'using_sqrt_disp',None):
gt_depth_cp = 1/np.sqrt(gt_depth+1e-8)
elif getattr(args, 'using_disp',None):
gt_depth_cp = 1/(gt_depth+1e-8)
elif getattr(args, 'using_sqrt',None):
gt_depth_cp = np.sqrt(gt_depth+1e-6)
elif getattr(args, 'using_pdf',None):
# NOTE(review): `pdf` (with 'bins' and 'y_map') is loaded by the missing
# enclosing code — verify it is defined before enabling --using_pdf.
gt_depth_cp = np.interp(gt_depth, pdf['bins'], pdf['y_map'])
else:
gt_depth_cp = gt_depth.copy()
pred_depth_aligned = scale_and_shift_align(pred_depth, gt_depth_cp, valid_mask)
# Map the aligned prediction back from the transformed space to depth.
if getattr(args, 'using_log',None):
pred_depth_aligned = np.exp(pred_depth_aligned)
elif getattr(args, 'using_sqrt_disp',None):
pred_depth_aligned = 1/(pred_depth_aligned**2)
elif getattr(args, 'using_disp',None):
pred_depth_aligned = 1/pred_depth_aligned
elif getattr(args, 'using_sqrt',None):
pred_depth_aligned = np.power(pred_depth_aligned, 2)
elif getattr(args, 'using_pdf',None):
pred_depth_aligned = np.interp(pred_depth_aligned, pdf['y_map'], pdf['bins'])
pred_depth_aligned = np.clip(pred_depth_aligned, args.min_depth_eval, args.max_depth_eval)
# Compute the per-sample error metrics.
try:
silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i], d1[i], d2[i], d3[i] = compute_errors(
gt_depth[valid_mask], pred_depth_aligned[valid_mask])
except Exception as e:
print(f'Error computing metrics for sample {i}: {e}')
continue
# Filter out invalid (NaN/Inf/zero) per-sample results before averaging.
valid_results = ~np.isnan(silog) & ~np.isinf(silog) & (silog != 0)
results = "{:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}, {:7.5f}".format(
d1[valid_results].mean(), d2[valid_results].mean(), d3[valid_results].mean(),
abs_rel[valid_results].mean(), sq_rel[valid_results].mean(), rms[valid_results].mean(),
log_rms[valid_results].mean(), silog[valid_results].mean(), log10[valid_results].mean())
if not args.no_verbose:
print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format(
'd1', 'd2', 'd3', 'AbsRel', 'SqRel', 'RMSE', 'RMSElog', 'SILog', 'log10'))
print(results)
print(f'Valid results: {valid_results.sum()}/{len(valid_results)}')
return results
# return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3
if __name__ == '__main__':
    # fromfile_prefix_chars='@' lets options be supplied via "@args.txt";
    # convert_arg_line_to_args splits multiple tokens per line in such files.
    parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
    parser.convert_arg_line_to_args = convert_arg_line_to_args
    parser.add_argument('--pred_path', type=str, help='path to the prediction results in png', required=True)
    parser.add_argument('--gt_path', type=str, help='root path to the groundtruth data', required=False)
    parser.add_argument('--dataset', type=str, help='dataset to test on, nyu or kitti', default='nyu')
    parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
    parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
    parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
    parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
    parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
    parser.add_argument('--no_verbose', default=False, action='store_true', help='if set, do not print out per image results')
    parser.add_argument('--using_log', default=False, action='store_true', help='if set, use log depth for eval')
    # Fix: the evaluation body reads args.using_sqrt_disp, but no such flag
    # was registered, so the 1/sqrt(depth) mode could never be enabled.
    parser.add_argument('--using_sqrt_disp', default=False, action='store_true', help='if set, use 1/sqrt(depth) for eval')
    parser.add_argument('--using_disp', default=False, action='store_true', help='if set, use disparity (1/depth) for eval')
    parser.add_argument('--using_sqrt', default=False, action='store_true', help='if set, use sqrt depth for eval')
    parser.add_argument('--using_pdf', default=False, action='store_true', help='if set, use pdf for eval')
    args = parser.parse_args()
    test(args)