| import sys |
| from os import path as osp |
| import argparse |
| import warnings |
| import torch |
| import numpy as np |
| from PIL import Image |
| from detectron2.config import instantiate, LazyConfig |
|
|
# Make the parent of this script's directory importable so the project-local
# ``utils`` module is found regardless of the current working directory.
# ``abspath`` guards against a relative ``__file__`` (e.g. ``python eval.py``
# run from inside the script's directory on older Pythons): without it the
# double ``dirname`` collapses to "" (the CWD) and the parent is never added.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
# Presumably supplies multi_hot_targets, argmax_pts, dark_inference, auc and
# L2_dist, which are used below but not defined in this file.
from utils import *  # noqa: E402,F401,F403


warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
|
|
def _sample_metrics(heatmap, gazes, imsize, use_dark_inference=False):
    """Compute (AUC, min distance, avg distance) for one predicted heatmap.

    Args:
        heatmap: 2-D numpy heatmap indexed as ``heatmap[row, col]``, i.e. its
            last two axes are (height, width).
        gazes: ground-truth gaze tensor for this sample, padded with ``-1``;
            after dropping the padding it is viewed as ``(N, 2)`` points.
        imsize: image size as stored in the batch. It is forwarded unchanged
            to ``multi_hot_targets`` and used as the PIL ``(width, height)``
            target size when upsampling the heatmap -- TODO confirm ordering.
        use_dark_inference: decode the peak with ``dark_inference`` instead
            of ``argmax_pts``.

    Returns:
        Tuple ``(auc_score, min_distance, avg_distance)``: the AUC of the
        upsampled heatmap, the smallest ``L2_dist`` from the normalized
        prediction to any annotation, and the ``L2_dist`` to the mean
        annotation.
    """
    # Remove the -1 padding and recover the valid ground-truth points.
    valid_gaze = gazes[gazes != -1].view(-1, 2)

    # AUC target: multi-hot map of all annotations at image resolution.
    multi_hot = multi_hot_targets(gazes, imsize)

    if use_dark_inference:
        pred_x, pred_y = dark_inference(heatmap)
    else:
        pred_x, pred_y = argmax_pts(heatmap)

    # Normalize the peak to [0, 1]. Assumes the decoders return (x, y) as
    # (column, row) -- TODO confirm in utils. x is therefore divided by the
    # width (last axis) and y by the height; the previous code had these
    # swapped, which only went unnoticed for square heatmaps.
    height, width = heatmap.shape[-2], heatmap.shape[-1]
    norm_p = [pred_x / width, pred_y / height]

    # PIL needs a plain (int, int) size; imsize may arrive as a tensor.
    pil_size = tuple(int(s) for s in imsize)
    scaled_heatmap = np.array(
        Image.fromarray(heatmap).resize(pil_size, resample=Image.BILINEAR)
    )
    auc_score = auc(scaled_heatmap, multi_hot)

    min_distance = min(L2_dist(gt_gaze, norm_p) for gt_gaze in valid_gaze)
    avg_distance = L2_dist(torch.mean(valid_gaze, 0), norm_p)
    return auc_score, min_distance, avg_distance


def do_test(cfg, model, use_dark_inference=False):
    """Evaluate ``model`` on the validation loader and print gaze metrics.

    Prints a small table with the mean AUC, the mean minimum L2 distance to
    any annotation and the mean L2 distance to the average annotation.

    Args:
        cfg: LazyConfig whose ``dataloader.val`` is instantiated to build the
            validation loader.
        model: called as ``model(data)``; the first element of its output is
            the predicted heatmap batch, whose dim 1 is squeezed away
            (presumably ``(B, 1, H, W)`` -> ``(B, H, W)``).
        use_dark_inference: decode heatmap peaks with ``dark_inference``
            instead of ``argmax_pts``.
    """
    val_loader = instantiate(cfg.dataloader.val)

    model.train(False)
    auc_scores = []
    min_dists = []
    avg_dists = []
    with torch.no_grad():
        for data in val_loader:
            heatmaps, _ = model(data)
            heatmaps = heatmaps.squeeze(1).cpu().detach().numpy()

            for b_i, heatmap in enumerate(heatmaps):
                auc_score, min_distance, avg_distance = _sample_metrics(
                    heatmap,
                    data["gazes"][b_i],
                    data["imsize"][b_i],
                    use_dark_inference=use_dark_inference,
                )
                auc_scores.append(auc_score)
                min_dists.append(min_distance)
                avg_dists.append(avg_distance)

    print("|AUC |min dist|avg dist|")
    print(
        "|{:.4f}|{:.4f} |{:.4f} |".format(
            torch.mean(torch.tensor(auc_scores)),
            torch.mean(torch.tensor(min_dists)),
            torch.mean(torch.tensor(avg_dists)),
        )
    )
|
|
|
|
def main(args):
    """Build the model from ``args.config_file``, load weights and evaluate.

    Args:
        args: parsed CLI namespace providing ``config_file``,
            ``model_weights`` and ``use_dark_inference``.
    """
    cfg = LazyConfig.load(args.config_file)
    # ``torch.Module`` does not exist; the module base class is torch.nn.Module.
    model: torch.nn.Module = instantiate(cfg.model)
    # Load tensors onto the CPU first so checkpoints saved from a GPU also load
    # on CPU-only machines; the model is moved to the target device below.
    # NOTE(review): torch.load unpickles the file -- only use trusted weights.
    checkpoint = torch.load(args.model_weights, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(cfg.train.device)
    do_test(cfg, model, use_dark_inference=args.use_dark_inference)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them main() would fail later with an
    # obscure error inside LazyConfig.load(None) / torch.load(None).
    parser.add_argument("--config_file", type=str, required=True, help="config file")
    parser.add_argument(
        "--model_weights",
        type=str,
        required=True,
        help="model weights",
    )
    parser.add_argument(
        "--use_dark_inference",
        action="store_true",
        help="decode heatmap peaks with dark_inference instead of argmax",
    )
    args = parser.parse_args()
    main(args)
|
|