| import collections |
| import os, sys |
| ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..') |
| sys.path.append(ROOT_DIR) |
| import numpy as np |
| import pandas as pd |
| import argparse |
| import pickle |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from torchvision.models import ResNet50_Weights, Swin_T_Weights, ViT_B_16_Weights, RegNet_Y_16GF_Weights |
| from torchvision import transforms as trn |
| from torch.hub import load_state_dict_from_url |
|
|
| from openood.evaluation_api import Evaluator |
|
|
| from openood.networks import ResNet50, Swin_T, ViT_B_16, RegNet_Y_16GF |
| from openood.networks.conf_branch_net import ConfBranchNet |
| from openood.networks.godin_net import GodinNet |
| from openood.networks.rot_net import RotNet |
| from openood.networks.cider_net import CIDERNet |
|
|
|
|
def update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    For every key in ``u``:

    * if the new value is a mapping, it is merged into the value already
      stored under that key in ``d`` (recursively);
    * otherwise the new value simply overwrites whatever ``d`` held.

    If ``d`` holds a non-mapping value (e.g. ``None`` or a scalar) under a
    key whose new value is a mapping, the old value is discarded and the
    merge starts from an empty dict instead of raising ``TypeError``.

    Args:
        d: The dict to update. It is mutated.
        u: The mapping whose entries are merged into ``d``.

    Returns:
        The same object ``d``, after merging.
    """
    # Imported explicitly: a bare ``import collections`` does not guarantee
    # that the ``collections.abc`` submodule is loaded on every Python version.
    from collections.abc import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            if not isinstance(existing, Mapping):
                # Missing key or a placeholder such as None: nothing to
                # merge into, so start from a fresh dict.
                existing = {}
            d[k] = update(existing, v)
        else:
            d[k] = v
    return d
|
|
|
|
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
# Backbone architecture to evaluate.
parser.add_argument('--arch',
                    default='resnet50',
                    choices=['resnet50', 'swin-t', 'vit-b-16', 'regnet'])
# Version suffix of the torchvision weights (used to build the weights name).
# type=int is required: argparse compares the *converted* value against
# ``choices``, so without it a command-line '2' (a str) never matches the
# int choices and is rejected as an invalid choice.
parser.add_argument('--tvs-version', default=1, type=int, choices=[1, 2])
# Path to a custom checkpoint; needed when --tvs-pretrained is not given.
parser.add_argument('--ckpt-path', default=None)
# Use the official torchvision pretrained weights instead of a checkpoint.
parser.add_argument('--tvs-pretrained', action='store_true')
# Name of the postprocessor (OOD scoring method) handed to the evaluator.
parser.add_argument('--postprocessor', default='msp')
# Write the resulting metrics to a CSV file.
parser.add_argument('--save-csv', action='store_true')
# Pickle the evaluator's scores for later reuse.
parser.add_argument('--save-score', action='store_true')
# Forwarded to ``evaluator.eval_ood(fsood=...)``; also selects the output
# sub-directory ('fsood' instead of 'ood') for the CSV.
parser.add_argument('--fsood', action='store_true')
parser.add_argument('--batch-size', default=2000, type=int)
args = parser.parse_args()
|
|
# Resolve ``root``: the directory under which cached postprocessors, scores
# and result CSVs are read and written.
if args.tvs_pretrained:
    # torchvision weights have no checkpoint directory, so use a dedicated
    # results folder named after the architecture and weights version.
    root = os.path.join(
        ROOT_DIR, 'results',
        f'imagenet_{args.arch}_tvsv{args.tvs_version}_base_default')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(root, exist_ok=True)
else:
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would turn a missing path into an obscure
    # AttributeError on ``None`` further down.
    if args.ckpt_path is None:
        parser.error('--ckpt-path is required unless --tvs-pretrained is set')
    # Results live next to the checkpoint. os.path.dirname handles the
    # platform separator and a checkpoint located directly under '/'.
    root = os.path.dirname(args.ckpt_path)
|
|
| |
| |
postprocessor_name = args.postprocessor

# Reuse a previously pickled postprocessor for this method if one was saved
# under <root>/postprocessors; otherwise pass None to the evaluator.
# NOTE: pickle.load executes arbitrary code from the file -- only point
# ``root`` at result directories you trust.
pp_cache_file = os.path.join(root, 'postprocessors',
                             f'{postprocessor_name}.pkl')
postprocessor = None
if os.path.isfile(pp_cache_file):
    with open(pp_cache_file, 'rb') as f:
        postprocessor = pickle.load(f)
|
|
| |
| |
| |
# Build the network (``net``) and its input transform (``preprocessor``).
if args.tvs_pretrained:
    # arch -> (network class, torchvision weights enum, weights-name prefix).
    # The weights member is looked up with getattr() rather than eval(), so
    # no string built from command-line input is ever executed as code.
    tvs_archs = {
        'resnet50': (ResNet50, ResNet50_Weights, 'IMAGENET1K_V'),
        'swin-t': (Swin_T, Swin_T_Weights, 'IMAGENET1K_V'),
        'vit-b-16': (ViT_B_16, ViT_B_16_Weights, 'IMAGENET1K_V'),
        'regnet': (RegNet_Y_16GF, RegNet_Y_16GF_Weights,
                   'IMAGENET1K_SWAG_E2E_V'),
    }
    if args.arch not in tvs_archs:
        raise NotImplementedError

    net_cls, weights_enum, weights_prefix = tvs_archs[args.arch]
    net = net_cls()
    # Raises AttributeError if the enum has no member for this version.
    weights = getattr(weights_enum, f'{weights_prefix}{args.tvs_version}')
    net.load_state_dict(load_state_dict_from_url(weights.url))
    # Use the preprocessing transform bundled with the chosen weights.
    preprocessor = weights.transforms()
else:
    # Custom checkpoints are only supported for the ResNet-50 backbone.
    if args.arch != 'resnet50':
        raise NotImplementedError

    # Some methods wrap the backbone in a dedicated network, so the
    # architecture must match the checkpoint that is about to be loaded.
    if postprocessor_name == 'conf_branch':
        net = ConfBranchNet(backbone=ResNet50(), num_classes=1000)
    elif postprocessor_name == 'godin':
        backbone = ResNet50()
        net = GodinNet(backbone=backbone,
                       feature_size=backbone.feature_size,
                       num_classes=1000)
    elif postprocessor_name == 'rotpred':
        net = RotNet(backbone=ResNet50(), num_classes=1000)
    elif postprocessor_name == 'cider':
        net = CIDERNet(backbone=ResNet50(),
                       head='mlp',
                       feat_dim=128,
                       num_classes=1000)
    else:
        net = ResNet50()

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.ckpt_path, map_location='cpu')
    net.load_state_dict(ckpt)
    preprocessor = trn.Compose([
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225]),
    ])

# Move the model to the GPU and switch it to inference mode.
net.cuda()
net.eval()
| |
# Create the OOD evaluator for ImageNet as the in-distribution dataset.
# shuffle is False: this is a pure evaluation run, and shuffling would make
# the order of the per-sample scores (pickled via --save-score and reloaded
# on later runs) differ from run to run.
# NOTE(review): results of batch-dependent methods may also vary with batch
# composition -- confirm against the Evaluator implementation.
evaluator = Evaluator(
    net,
    id_name='imagenet',
    data_root=os.path.join(ROOT_DIR, 'data'),
    config_root=os.path.join(ROOT_DIR, 'configs'),
    preprocessor=preprocessor,
    postprocessor_name=postprocessor_name,
    postprocessor=postprocessor,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8)
|
|
| |
# Merge scores pickled by an earlier run (see --save-score) into the
# evaluator's score dict.
# NOTE: pickle.load executes arbitrary code from the file -- trusted files only.
score_cache_file = os.path.join(root, 'scores', f'{postprocessor_name}.pkl')
if os.path.isfile(score_cache_file):
    with open(score_cache_file, 'rb') as f:
        scores = pickle.load(f)
    update(evaluator.scores, scores)
    print('Loaded pre-computed scores from file.')
|
|
| |
# Persist the postprocessor for future runs when it either exposes a
# ``setup_flag`` attribute or reports a finished hyperparameter search.
# An already existing pickle is never overwritten.
# NOTE(review): this runs before eval_ood(); if hyperparam_search_done is
# only set during evaluation, a freshly searched postprocessor would not be
# saved here -- confirm against the Evaluator implementation.
has_setup_flag = hasattr(evaluator.postprocessor, 'setup_flag')
if (has_setup_flag
        or evaluator.postprocessor.hyperparam_search_done is True):
    pp_save_root = os.path.join(root, 'postprocessors')
    if not os.path.exists(pp_save_root):
        os.makedirs(pp_save_root)

    pp_save_file = os.path.join(pp_save_root, f'{postprocessor_name}.pkl')
    if not os.path.isfile(pp_save_file):
        with open(pp_save_file, 'wb') as f:
            pickle.dump(evaluator.postprocessor, f, pickle.HIGHEST_PROTOCOL)
|
|
| |
# Run the OOD evaluation. The result exposes ``to_csv`` (presumably a pandas
# DataFrame -- confirm against Evaluator.eval_ood).
metrics = evaluator.eval_ood(fsood=args.fsood)

# Optionally store the metrics as CSV under <root>/ood or <root>/fsood.
# An existing CSV for this postprocessor is left untouched.
if args.save_csv:
    saving_root = os.path.join(root, 'fsood' if args.fsood else 'ood')
    if not os.path.exists(saving_root):
        os.makedirs(saving_root)

    csv_file = os.path.join(saving_root, f'{postprocessor_name}.csv')
    if not os.path.isfile(csv_file):
        # Floats are written with two decimal places.
        metrics.to_csv(csv_file, float_format='{:.2f}'.format)
|
|
# Optionally pickle the evaluator's scores under <root>/scores so that a
# later run can load them instead of recomputing. Unlike the CSV and the
# postprocessor pickle, an existing score file IS overwritten.
if args.save_score:
    score_save_root = os.path.join(root, 'scores')
    if not os.path.exists(score_save_root):
        os.makedirs(score_save_root)

    score_save_file = os.path.join(score_save_root,
                                   f'{postprocessor_name}.pkl')
    with open(score_save_file, 'wb') as f:
        pickle.dump(evaluator.scores, f, pickle.HIGHEST_PROTOCOL)
|
|