| |
'''
This script extracts image and text features for evaluation on a single GPU.
'''
|
|
| import os |
| import argparse |
| import logging |
| from pathlib import Path |
| import json |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from clip.model import convert_weights, CLIP |
| from eval.data import get_eval_img_dataset, get_eval_txt_dataset |
|
|
def parse_args(argv=None):
    """Parse the command-line options of the feature-extraction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace: the parsed options (flags selecting which
        features to extract, input/output paths, batch sizes, context
        length, checkpoint path, precision and backbone names).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--extract-image-feats',
        action="store_true",
        default=False,
        help="Whether to extract image features."
    )
    parser.add_argument(
        '--extract-text-feats',
        action="store_true",
        default=False,
        help="Whether to extract text features."
    )
    parser.add_argument(
        '--image-data',
        type=str,
        default="../Multimodal_Retrieval/lmdb/test/imgs",
        help="If --extract-image-feats is True, specify the path of the LMDB directory storing input image base64 strings."
    )
    parser.add_argument(
        '--text-data',
        type=str,
        default="../Multimodal_Retrieval/test_texts.jsonl",
        help="If --extract-text-feats is True, specify the path of input text Jsonl file."
    )
    parser.add_argument(
        '--image-feat-output-path',
        type=str,
        default=None,
        help="If --extract-image-feats is True, specify the path of output image features."
    )
    parser.add_argument(
        '--text-feat-output-path',
        type=str,
        default=None,
        # The original help text mistakenly referred to --extract-image-feats.
        help="If --extract-text-feats is True, specify the path of output text features."
    )
    parser.add_argument(
        "--img-batch-size", type=int, default=64, help="Image batch size."
    )
    parser.add_argument(
        "--text-batch-size", type=int, default=64, help="Text batch size."
    )
    parser.add_argument(
        "--context-length", type=int, default=64, help="The maximum length of input text (include [CLS] & [SEP] tokens)."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--vision-model",
        choices=["ViT-B-16", "ViT-L-14", "RN50"],
        default="ViT-B-16",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--text-model",
        choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
        default="RoBERTa-wwm-ext-base-chinese",
        help="Name of the text backbone to use.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    return parser.parse_args(argv)
|
|
| |
| |
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and any existing gradient, to float32.

    The conversion is done in place by rebinding ``.data`` on each parameter
    (and on its ``.grad`` when one is present); nothing is returned.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with
            ``.data`` and ``.grad`` attributes (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: `if p.grad:` asks for the truth
        # value of a tensor, which raises for gradients holding more than one
        # element and wrongly skips a single-element gradient equal to zero.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
|
|
|
|
def _dump_features(dataloader, output_path, id_key, gpu, encode):
    """Encode every batch of ``dataloader`` and write one JSON line per sample.

    Each output line is ``{id_key: <id>, "feature": <list of floats>}`` where
    the feature vector has been divided by its norm along the last dimension.

    Args:
        dataloader: Iterable yielding ``(ids, inputs)`` pairs; ``ids`` must
            support ``.tolist()`` and ``inputs`` must support ``.cuda()``.
        output_path: Path of the JSONL file to (over)write.
        id_key: JSON key under which the sample id is stored.
        gpu: CUDA device index the inputs are moved to.
        encode: Callable mapping the device-side inputs to a feature tensor.

    Returns:
        int: number of feature lines written.
    """
    write_cnt = 0
    with open(output_path, "w") as fout:
        with torch.no_grad():
            for ids, inputs in tqdm(dataloader):
                inputs = inputs.cuda(gpu, non_blocking=True)
                features = encode(inputs)
                features /= features.norm(dim=-1, keepdim=True)
                for sample_id, feature in zip(ids.tolist(), features.tolist()):
                    fout.write("{}\n".format(json.dumps({id_key: sample_id, "feature": feature})))
                    write_cnt += 1
    return write_cnt


if __name__ == "__main__":
    args = parse_args()

    # Validate the arguments with explicit raises: `assert` statements are
    # stripped when Python runs with -O, and we want to fail before any
    # expensive dataset/model setup.
    if not (args.extract_image_feats or args.extract_text_feats):
        raise ValueError("--extract-image-feats and --extract-text-feats cannot both be False!")
    if args.resume is None:
        # --resume defaults to None; os.path.exists(None) would otherwise
        # fail with an unhelpful TypeError.
        raise ValueError("--resume must point to a model checkpoint file!")
    if not os.path.exists(args.resume):
        raise FileNotFoundError("The checkpoint file {} not exists!".format(args.resume))

    # Log the effective parameters.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f" {name}: {val}")

    # This script is single-GPU only, so device 0 is always used.
    args.gpu = 0
    torch.cuda.set_device(args.gpu)

    # Locate the vision/text model configs relative to this file.
    config_dir = Path(__file__).parent.parent.parent / "clip" / "model_configs"

    vision_model_config_file = config_dir / f"{args.vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    if not os.path.exists(vision_model_config_file):
        raise FileNotFoundError("Vision model config {} not found!".format(vision_model_config_file))

    text_model_config_file = config_dir / f"{args.text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    if not os.path.exists(text_model_config_file):
        raise FileNotFoundError("Text model config {} not found!".format(text_model_config_file))

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        if isinstance(model_info['vision_layers'], str):
            # NOTE(security): eval() executes arbitrary code from the config
            # file. Only use config files you trust; consider
            # ast.literal_eval if the value is always a plain literal.
            model_info['vision_layers'] = eval(model_info['vision_layers'])
        # Text-model settings are merged on top of the vision-model settings.
        model_info.update(json.load(ft))

    model = CLIP(**model_info)
    convert_weights(model)

    # Pick the parameter precision before/after moving the model to the GPU.
    if args.precision in ("amp", "fp32"):
        convert_models_to_fp32(model)
    model.cuda(args.gpu)
    if args.precision == "fp16":
        convert_weights(model)

    # Build only the datasets that were requested.
    if args.extract_image_feats:
        print("Preparing image inference dataset.")
        img_data = get_eval_img_dataset(args)
    if args.extract_text_feats:
        print("Preparing text inference dataset.")
        text_data = get_eval_txt_dataset(args, max_txt_length=args.context_length)

    # Restore the model weights from the checkpoint.
    print("Begin to load model checkpoint from {}.".format(args.resume))
    checkpoint = torch.load(args.resume, map_location='cpu')
    sd = checkpoint["state_dict"]
    if next(iter(sd)).startswith('module'):
        # Keys carry a 'module.' prefix: strip it and drop the
        # "bert.pooler" entries before loading.
        sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
    print(
        f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
    )

    model.eval()

    # Text feature extraction.
    if args.extract_text_feats:
        print('Make inference for texts...')
        if args.text_feat_output_path is None:
            # Strip the ".jsonl" suffix only when it is really there; blind
            # slicing would chop six arbitrary characters off other names.
            text_base = args.text_data
            if text_base.endswith(".jsonl"):
                text_base = text_base[:-len(".jsonl")]
            args.text_feat_output_path = "{}.txt_feat.jsonl".format(text_base)
        write_cnt = _dump_features(
            text_data.dataloader,
            args.text_feat_output_path,
            "text_id",
            args.gpu,
            lambda texts: model(None, texts),
        )
        print('{} text features are stored in {}'.format(write_cnt, args.text_feat_output_path))

    # Image feature extraction.
    if args.extract_image_feats:
        print('Make inference for images...')
        if args.image_feat_output_path is None:
            # The default location is derived from --text-data (not
            # --image-data), so image features land next to the text file.
            args.image_feat_output_path = "{}.img_feat.jsonl".format(args.text_data.replace("_texts.jsonl", "_imgs"))
        write_cnt = _dump_features(
            img_data.dataloader,
            args.image_feat_output_path,
            "image_id",
            args.gpu,
            lambda images: model(images, None),
        )
        print('{} image features are stored in {}'.format(write_cnt, args.image_feat_output_path))

    print("Done!")