'''
This script performs zero-shot evaluation on ImageNet-1K (single-GPU).
'''
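# Example invocation (illustrative only -- the script name, checkpoint, and
# data paths below are placeholders, not shipped defaults):
#
#   python zeroshot_evaluation.py \
#       --vision-model ViT-B-16 \
#       --text-model RoBERTa-wwm-ext-base-chinese \
#       --precision amp \
#       --datapath /path/to/test-set \
#       --label-file /path/to/labels.txt \
#       --dataset imagenet \
#       --resume /path/to/checkpoint.pt
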
import os
import argparse
from pathlib import Path
import json
from tqdm import tqdm

import torch

from clip.model import convert_weights, CLIP
from clip import tokenize
from clip.utils import image_transform
from eval.data import get_zeroshot_dataset, _preprocess_text
from eval.cvinw_zeroshot_templates import (
    openai_templates,
    flower_templates,
    food_templates,
    aircraft_templates,
    eurosat_templates,
    country211_templates,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-16", "ViT-L-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--label-file",
        type=str,
        help="Path to a file containing one class label per line.",
    )
    parser.add_argument(
        "--datapath",
        type=str,
        required=True,
        help="Path to the test set for conducting zero-shot evaluation.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="imagenet",
        help="Name of the dataset to evaluate on.",
    )
    parser.add_argument(
        "--index",
        type=str,
        default="",
        help="Optional JSON file with indices used to rearrange the output logits.",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="",
        help="Directory in which to save evaluation outputs.",
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=52,
        help="The maximum length of input text (including [CLS] and [SEP] tokens).",
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="Path to the latest checkpoint (default: none).",
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="Number of workers for the dataloader."
    )
    args = parser.parse_args()

    return args


def convert_models_to_fp32(model):
    # Cast parameters (and any existing gradients) back to fp32, e.g. after
    # convert_weights() has cast them to fp16.
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def zero_shot_classifier(model, classnames, templates, args):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            # Fill every prompt template with the class name and tokenize;
            # the model returns text features when the image input is None.
            texts = [_preprocess_text(template(classname)) for template in templates]
            texts = tokenize(texts, context_length=args.context_length).to(args.gpu)
            class_embeddings = model(None, texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            # Prompt ensembling: average the normalized embeddings, then renormalize.
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.gpu)
    return zeroshot_weights

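# Note: zero_shot_classifier() implements standard CLIP prompt ensembling:
# every template is filled with the class name, the text embeddings are
# L2-normalized, averaged, and renormalized, giving one embedding per class.
# The stacked result has shape [embed_dim, num_classes].
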
def accuracy(output, target, topk=(1,)):
    # Count how many targets appear among the top-k predictions for each k.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

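# Shapes: `output` is [batch, num_classes] and `target` is [batch]. The values
# returned by accuracy() are raw counts of correct predictions rather than
# percentages; run() below divides them by the number of evaluated samples.
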
def run(model, classifier, dataloader, args):
    total_logits = []
    total_targets = []
    with torch.no_grad():
        top1, top5, n = 0.0, 0.0, 0.0
        for images, target in tqdm(dataloader):
            images = images.to(args.gpu)
            target = target.to(args.gpu)
            total_targets.append(target)

            # Encode the images and score them against the class embeddings.
            image_features = model(images, None)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = (100.0 * image_features @ classifier).softmax(dim=-1)
            total_logits.append(logits)

            # Accumulate top-1 / top-5 correct-prediction counts.
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    outputs = torch.cat(total_logits, dim=0)
    targets = torch.cat(total_targets, dim=0)

    if getattr(args, "index", ""):
        print("Using index file to rearrange the logits...")
        with open(args.index, "r", encoding="utf-8") as f:
            index = json.load(f)
        outputs = outputs[index]
        targets = targets[index]

    top1 = top1 / n
    top5 = top5 / n

    return top1, top5, outputs

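# Note: the fixed 100.0 multiplier in run() mirrors CLIP's learned logit
# scale, which is clamped to at most 100 during training. Scaling by a
# positive constant preserves the ordering of the logits, so top-k accuracy
# is unaffected; it only sharpens the saved softmax distributions.
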
if __name__ == "__main__":
    args = parse_args()

    # Log the parsed arguments.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    # Single-GPU evaluation: pin everything to GPU 0.
    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Initialize the model from its vision and text config files.
    vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            # ResNet configs store the layer counts as a literal string, e.g. "[3,4,6,3]".
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        for k, v in json.load(ft).items():
            model_info[k] = v

    model = CLIP(**model_info)
    convert_weights(model)

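    # Note: convert_weights() above casts eligible weights to fp16, mirroring
    # the OpenAI CLIP convention. The precision handling below therefore casts
    # back to fp32 for "amp"/"fp32" runs, and re-applies the fp16 cast only
    # when "fp16" is requested.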
    if args.precision == "amp" or args.precision == "fp32":
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)

    # Build the eval dataset with the model's input resolution.
    print("Preparing zeroshot dataset.")
    data = {}
    print(f"Image resolution: {model_info['image_resolution']}")
    data[args.dataset] = get_zeroshot_dataset(
        args, image_transform(model_info["image_resolution"])
    )

    # Load the checkpoint on CPU first; load_state_dict then copies the
    # weights into the CUDA model created above.
    print("Begin to load model checkpoint from {}.".format(args.resume))
    assert os.path.exists(args.resume), "The checkpoint file {} does not exist!".format(args.resume)

    checkpoint = torch.load(args.resume, map_location='cpu')
    start_epoch = checkpoint["epoch"]
    sd = checkpoint["state_dict"]
    # Strip the 'module.' prefix left by DataParallel/DistributedDataParallel
    # training, and drop the unused BERT pooler weights.
    if next(iter(sd.items()))[0].startswith('module.'):
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd, strict=False)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    print('Building zero-shot classifier')

    model.eval()

    # One class label per line.
    with open(args.label_file, "r", encoding="utf8") as f:
        classnames = [line.strip() for line in f.readlines()]

    # Pick the prompt templates matching the dataset; fall back to the
    # generic OpenAI templates for anything else (including ImageNet).
    template_dict = {
        "fgvc-aircraft-2013b-variants102": aircraft_templates,
        "food-101": food_templates,
        "oxford-flower-102": flower_templates,
        "eurosat_clip": eurosat_templates,
        "resisc45_clip": eurosat_templates,
        "country211": country211_templates,
        "openai": openai_templates,
    }
    if args.dataset in template_dict:
        templates = template_dict[args.dataset]
    else:
        templates = template_dict['openai']

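    # Each entry in `templates` is a callable mapping a class name to a full
    # prompt; an illustrative (assumed) form for the OpenAI-style templates:
    #   lambda c: f"a photo of a {c}."
    # zero_shot_classifier() applies every template to every class name.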
    print('Evaluating with the zero-shot classifier')
    classifier = zero_shot_classifier(model, classnames, templates, args)
    results = {}
    top1, top5, logits = run(model, classifier, data[args.dataset].dataloader, args)

    results["zeroshot-top1"] = top1
    results["zeroshot-top5"] = top5

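    # The --save-dir flag is parsed above but otherwise unused; a minimal
    # sketch (an assumption, not part of the original flow) for persisting
    # the softmax outputs returned by run():
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(logits.cpu(), os.path.join(args.save_dir, f"{args.dataset}_logits.pt"))
        print(f"Saved logits to {args.save_dir}")
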
    print('Result:')
    print(", ".join(["{}: {}".format(k, v) for k, v in results.items()]))
    print('Finished.')