| import os |
| import torch |
| import numpy as np |
|
|
| |
| |
| |
|
|
def compute_depth_metrics(gt, pred, mask=None, median_align=False):
    """Compute standard depth-estimation metrics between ground truth and prediction.

    Args:
        gt: ground-truth depth tensor. ``squeeze(1)`` is applied, so a
            singleton dimension at index 1 is dropped if present
            (presumably shape (B, 1, H, W) -- confirm against callers).
        pred: predicted depth tensor with the same shape as ``gt``.
        mask: optional boolean tensor, same shape as ``gt``, selecting the
            pixels to evaluate. Defaults to ``gt > 0`` (valid ground truth).
        median_align: if True, rescale ``pred`` by
            ``median(gt) / median(pred)`` over the masked pixels before
            computing the metrics (useful for scale-ambiguous predictions).

    Returns:
        Tuple of 0-dim tensors
        ``(abs_, abs_rel, sq_rel, rmse, rmse_log, log10, a1, a2, a3)``:
        mean absolute error, mean absolute relative error, mean squared
        relative error, root mean squared error, root mean squared log10
        error, mean absolute log10 error, and the fractions of pixels whose
        ratio ``max(gt/pred, pred/gt)`` is below 1.25, 1.25**2, 1.25**3.

    Note:
        Masked pixels where ``pred`` is zero produce inf in the ratio/log
        based metrics (and negative values produce nan); predictions are
        expected to be positive.
    """
    if mask is None:
        mask = gt > 0

    gt = gt.squeeze(1)
    pred = pred.squeeze(1)
    mask = mask.squeeze(1)
    # Boolean indexing flattens to 1-D vectors of the selected pixels
    # (and copies, so the caller's tensors are never modified below).
    gt = gt[mask]
    pred = pred[mask]

    if median_align:
        # Previously this flag was accepted but ignored.
        pred = pred * (torch.median(gt) / torch.median(pred))

    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    # Square root of the MEAN squared error. The old code took the mean of
    # per-pixel square roots, which is just the mean absolute error.
    rmse = torch.sqrt(((gt - pred) ** 2).mean())

    # NOTE(review): log10 is kept from the original; many depth benchmarks
    # use the natural log for log-RMSE -- confirm the intended base.
    rmse_log = torch.sqrt(((torch.log10(gt) - torch.log10(pred)) ** 2).mean())

    abs_ = torch.mean(torch.abs(gt - pred))
    abs_rel = torch.mean(torch.abs(gt - pred) / gt)
    sq_rel = torch.mean((gt - pred) ** 2 / gt)
    log10 = torch.mean(torch.abs(torch.log10(pred / gt)))

    return abs_, abs_rel, sq_rel, rmse, rmse_log, log10, a1, a2, a3
|
|
|
|
| |
class AverageMeter:
    """Tracks the latest value and a running weighted average of a series.

    ``val`` is the most recently recorded value, ``sum`` and ``count`` are the
    weighted totals, and ``avg`` is ``sum / count``. Every value passed to
    ``update`` is also appended to ``vals``. ``reset`` does NOT clear
    ``vals``, and ``vals`` is not part of ``to_dict``/``from_dict``.
    """

    def __init__(self):
        # Created outside reset(), so the history survives resets.
        self.vals = []
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the average."""
        self.vals.append(val)
        self.val = val
        # Augmented assignment kept deliberately: for tensor values this
        # accumulates in place after the first update.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def to_dict(self):
        """Return the meter state (without ``vals``) as a plain dict."""
        return dict(val=self.val, sum=self.sum, count=self.count, avg=self.avg)

    def from_dict(self, meter_dict):
        """Restore ``val``, ``sum``, ``count`` and ``avg`` from ``meter_dict``."""
        for field in ('val', 'sum', 'count', 'avg'):
            setattr(self, field, meter_dict[field])
|
|
|
|
class Evaluator:
    """Accumulates batch-weighted depth metrics and reports their averages.

    ``self.metrics`` maps each name in ``_METRIC_NAMES`` to an
    ``AverageMeter``. The name order matches the order of the tuple
    returned by ``compute_depth_metrics``.
    """

    # Same order as compute_depth_metrics' return values.
    _METRIC_NAMES = (
        "err/abs_",
        "err/abs_rel",
        "err/sq_rel",
        "err/rms",
        "err/log_rms",
        "err/log10",
        "acc/a1",
        "acc/a2",
        "acc/a3",
    )

    def __init__(self, median_align=False):
        # Forwarded to compute_depth_metrics on every batch.
        self.median_align = median_align
        self.metrics = {name: AverageMeter() for name in self._METRIC_NAMES}

    def reset_eval_metrics(self):
        """
        Resets metrics used to evaluate the model
        """
        for name in self._METRIC_NAMES:
            self.metrics[name].reset()

    def compute_eval_metrics(self, gt_depth, pred_depth, mask):
        """
        Computes metrics used to evaluate the model

        The metrics for one batch are folded into the running meters, each
        weighted by the batch size ``gt_depth.shape[0]``.
        """
        batch_size = gt_depth.shape[0]
        batch_results = compute_depth_metrics(gt_depth, pred_depth, mask,
                                              self.median_align)
        for name, value in zip(self._METRIC_NAMES, batch_results):
            self.metrics[name].update(value, batch_size)

    def print(self, dir=None):
        """Print the abs_rel / rms / a1 averages to stdout.

        If ``dir`` is given, all nine averages are also written to
        ``<dir>/result.txt``, overwriting any existing file.
        """
        every_avg = [self.metrics[name].avg for name in self._METRIC_NAMES]
        summary_avg = [self.metrics[name].avg
                       for name in ("err/abs_rel", "err/rms", "acc/a1")]

        print("\n " + ("{:>8} | " * 3).format("abs_rel", "rms", "a1"))
        print(("& {: 8.5f} " * 3).format(*summary_avg))

        if dir is None:
            return

        titles = ("abs_", "abs_rel", "sq_rel", "rms", "rms_log",
                  "log10", "a1", "a2", "a3")
        with open(os.path.join(dir, "result.txt"), 'w') as out:
            print("\n " + ("{:>9} | " * 9).format(*titles), file=out)
            print(("& {: 8.5f} " * 9).format(*every_avg), file=out)
|
|