| import argparse |
| import os |
| import sys |
| import torch |
| import shutil |
| import utils |
| import numpy as np |
| import metrics |
| import torch.nn as nn |
| import torch.backends.cudnn as cudnn |
| import vision_transformer as vits |
|
|
| from tqdm import tqdm |
| from datasets import build_dataset |
| from meters import AverageEpochMeter |
| from object_discovery import ncut, detect_box, get_feats |
| from visualization import visualize_fms, visualize_predictions_gt, visualize_img, visualize_eigvec |
|
|
| from collections import OrderedDict |
| |
|
|
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def evaluate():
    """Evaluate classification and localization accuracy on the validation set.

    All configuration is read from the module-level ``args`` namespace (set in
    the ``__main__`` block). The function:

    1. builds the ViT backbone named by ``args.arch`` and a ``LinearClassifier``
       on top of it, then loads their weights from ``args.pretrained_weights``
       (and ``args.classifier_weights`` when given);
    2. iterates over the validation loader, classifying each image and deriving
       a bounding box from the last block's ``qkv`` output via ``get_feats`` /
       ``ncut``;
    3. accumulates Top-1/Top-5 classification and GT-known/Top-1 localization
       meters and prints a summary.

    Returns:
        None. Results are printed; when ``args.visual`` is set, visualizations
        are written under ``./output``.

    Raises:
        SystemExit: if ``args.arch`` is not a name defined in ``vision_transformer``.
    """
    if args.device != 'cuda':
        args.distributed = False
    else:
        utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)
    cudnn.benchmark = True

    # ---- Backbone -------------------------------------------------------
    if args.arch not in vits.__dict__:
        print(f"Unknow architecture: {args.arch}")
        sys.exit(1)
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=args.num_labels)
    # The classifier input is the concatenation of the [CLS] tokens of the
    # last `n_last_blocks` blocks, plus one extra slot when the average-pooled
    # patch tokens are appended.
    embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    print(f'embed_dim: {embed_dim}')

    model.to(device)
    model.eval()

    # ---- Linear head ----------------------------------------------------
    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
    linear_classifier = linear_classifier.to(device)

    # ---- Weights --------------------------------------------------------
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint_model = torch.load(args.pretrained_weights, map_location='cpu')
    if args.classifier_weights is None:
        # Single checkpoint holding both the classifier and the backbone.
        linear_classifier.load_state_dict(checkpoint_model['classifier'])
        model.load_state_dict(checkpoint_model['model'], strict=False)
    else:
        # Backbone and classifier stored in separate files.
        model.load_state_dict(checkpoint_model)
        checkpoint_classifier = torch.load(args.classifier_weights, map_location='cpu')
        # The previous code dropped the first 7 characters of every key, which
        # is the length of 'module.' (presumably added by a DataParallel-style
        # wrapper at training time). Strip it only where it is actually
        # present so un-prefixed checkpoints are not corrupted.
        linear_classifier.load_state_dict(
            _strip_module_prefix(checkpoint_classifier['state_dict'])
        )
    print('Load from checkpoint done.')

    # ---- Data -----------------------------------------------------------
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # ---- Bookkeeping ----------------------------------------------------
    scales = [args.patch_size, args.patch_size]
    top1_cls_meter = AverageEpochMeter('Top-1 Cls')
    top5_cls_meter = AverageEpochMeter('Top-5 Cls')
    gt_loc_meter = AverageEpochMeter('GT-Known Loc')
    top1_loc_meter = AverageEpochMeter('Top-1 Loc')
    feat_out = {}
    bbox_error = 0  # samples with top1_loc == 0 and gt_loc == 0
    cls_error = 0   # samples with top1_loc == 0 and gt_loc == 1
    skip = 0        # images skipped because both padded sides exceed 1000

    # Defined up front so the final summary still prints (as NaN) when the
    # loader is empty or every image is skipped, instead of raising
    # UnboundLocalError.
    top1_cls = top5_cls = gt_loc = top1_loc = float('nan')

    def hook_fn_forward_qkv(module, input, output):
        # Capture the qkv projection output of the last attention block; it is
        # consumed by get_feats() after each forward pass.
        feat_out["qkv"] = output

    model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)

    val_loader = tqdm(val_loader)
    for i, (images, labels, gt_boxes) in enumerate(val_loader):
        # Only the first image of each batch is used, so the effective batch
        # size below is always 1.
        # NOTE(review): this looks like it expects --batch_size_per_gpu 1 — confirm.
        init_image_size = images[0].shape
        images = utils.padding_img(images[0], args)
        images = images.unsqueeze(0)
        if images.shape[-1] > 1000 and images.shape[-2] > 1000:
            skip += 1
            continue

        with torch.no_grad():
            images = images.to(device)
            labels = labels.to(device)

            # ---- Classification ----
            intermediate_output, shape = model.get_intermediate_layers(images, args.n_last_blocks)
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if args.avgpool_patchtokens:
                output = torch.cat(
                    (
                        output.unsqueeze(-1),
                        torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1),
                    ),
                    dim=-1,
                )
                output = output.reshape(output.shape[0], -1)
            output = linear_classifier(output)

            batch_size = images.size(0)
            top1_cls, top5_cls = utils.accuracy(output, labels, topk=(1, 5))
            top1_cls_meter.update(top1_cls, batch_size)
            top5_cls_meter.update(top5_cls, batch_size)

            # ---- Localization ----
            # Patch-grid size of the (padded) input image.
            dims = (images[0].shape[-2] // args.patch_size, images[0].shape[-1] // args.patch_size)
            k = get_feats(feat_out, shape)
            bboxes, _mask, seed, eigvec = ncut(
                k, dims, scales, init_image_size=init_image_size[1:], eps=args.eps, tau=args.tau
            )
            # Keep an array copy for visualization before converting to a tensor.
            bboxes_copy = np.copy(bboxes)
            bboxes = torch.FloatTensor(bboxes).unsqueeze(0)
            gt_loc, top1_loc = metrics.loc_accuracy(output, labels, gt_boxes, bboxes)

            if args.visual:
                visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/all')
                visualize_eigvec(eigvec, './output/all', i, dims, scales)

            if top1_loc == 0 and gt_loc == 0:
                bbox_error += 1
                if args.visual:
                    visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/bbox_error')
            elif top1_loc == 0 and gt_loc == 1:
                cls_error += 1
                if args.visual:
                    visualize_predictions_gt(images[0], bboxes_copy, gt_boxes, i, seed, dims, scales, './output/classification_error')

            gt_loc_meter.update(gt_loc, batch_size)
            top1_loc_meter.update(top1_loc, batch_size)

        # Refresh the running metrics shown on the progress bar.
        top1_cls = top1_cls_meter.compute()
        top5_cls = top5_cls_meter.compute()
        gt_loc = gt_loc_meter.compute()
        top1_loc = top1_loc_meter.compute()
        val_loader.set_description(f'Top1_cls: {top1_cls:.4f}, top5_cls{top5_cls:.4f}, gt_loc: {gt_loc:.4f}, top1_loc:{top1_loc:.4f}')

    print(f'Top1_cls: {top1_cls}, top5_cls{top5_cls}, gt_loc: {gt_loc}, top1_loc:{top1_loc}')
    print(f'Bbox error: {bbox_error}, cls error: {cls_error}')
    print(f'Skip large image: {skip}')
| |
| |
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features.

    Args:
        dim: Number of input features per sample after flattening.
        num_labels: Number of output classes.
    """

    def __init__(self, dim: int, num_labels: int = 1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        # Weights drawn from N(0, 0.01^2), bias set to zero.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten every non-batch dimension of ``x`` and apply the linear layer.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. the result of ``permute``) are accepted; for inputs ``view``
        could handle, the result is identical.
        """
        flat = x.reshape(x.size(0), -1)
        return self.linear(flat)
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the options into the module-level
    # ``args`` namespace (read as a global by evaluate()) and run evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
        for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
        help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
        We typically set this to False for ViT-Small and to True with ViT-Base.""")
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--dataset', default='cub', type=str, choices=['cub', 'imagenet'], help='Dataset to evaluate on')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--input_size', default=224, type=int, help='Input image size, default(224).')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--classifier_weights', default=None, type=str, help="Path to linear classifier pretrained weights to evaluate.")
    parser.add_argument('--batch_size_per_gpu', default=256, type=int, help='Per-GPU batch-size')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--num_labels', default=200, type=int, help='Number of labels for linear classifier')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--distributed', default=False, action='store_true',
        help='Use a distributed sampler for the validation set (forced off when --device is not cuda)')
    parser.add_argument('--ori_size', default=False, action='store_true', help='Evaluate on image raw size')
    parser.add_argument('--visual', default=False, action='store_true', help='Save visualizations under ./output')
    parser.add_argument('--eps', default=1e-5, type=float, help='hyperparameter for tokencut')
    parser.add_argument('--tau', default=0.05, type=float, help='hyperparamter for tokencut')
    # NOTE(review): the help text below reads as the opposite of the flag name;
    # the flag is not consumed in this file, so confirm its meaning where it is
    # used before rewording.
    parser.add_argument('--no_center_crop', default=False, action='store_true', help='Center crop input image')
    args = parser.parse_args()
    # Fail at argument parsing with a clear message rather than later inside
    # torch.load('') after the model has already been built.
    if not args.pretrained_weights:
        parser.error('--pretrained_weights must point to a checkpoint file')
    evaluate()
|
|