| import datetime |
| import logging |
| import time |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| from models.criterions import get_sim |
| from utils.basic_utils import MetricLogger |
| from utils.distributed import get_rank, get_world_size |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def extract_text_feats(texts, max_txt_l, tokenizer, model, device, text_bs=256):
    """Tokenize ``texts`` in mini-batches and encode them with ``model``.

    Args:
        texts: Sliceable sequence of raw text strings.
        max_txt_l: Length every batch is padded / truncated to
            (``padding="max_length"``, ``truncation=True``).
        tokenizer: Callable returning an object that exposes ``input_ids``,
            ``attention_mask`` and a ``.to(device)`` method.
        model: Object whose ``encode_text(text_input)`` returns an indexable
            result; only element ``0`` is kept as the text features.
        device: Device the tokenized batches are moved to.
        text_bs: Number of texts tokenized and encoded per forward pass.
            Defaults to 256, the previously hard-coded value.

    Returns:
        ``(text_feats, text_atts, text_ids)``, each concatenated along dim 0
        in the same order as ``texts``.

    Raises:
        ValueError: if ``text_bs`` is not a positive integer.
    """
    if text_bs <= 0:
        raise ValueError(f"text_bs must be positive, got {text_bs}")

    text_feats = []
    text_atts = []
    text_ids = []
    for i in range(0, len(texts), text_bs):
        text_input = tokenizer(
            texts[i : i + text_bs],
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)
        text_feats.append(model.encode_text(text_input)[0])
        text_atts.append(text_input.attention_mask)
        text_ids.append(text_input.input_ids)

    return (
        torch.cat(text_feats, dim=0),
        torch.cat(text_atts, dim=0),
        torch.cat(text_ids, dim=0),
    )
|
|
|
|
def extract_vision_feats(data_loader, model, device, config):
    """Run ``model.encode_vision`` over every batch yielded by ``data_loader``.

    Each batch is a ``(image, id)`` pair; the id is ignored. When
    ``config.evaluation.eval_frame_ensemble`` is ``"concat"``, a 4-D feature is
    flattened with the pattern ``"b t l c -> b (t l) c"``. The feature then gets
    a singleton axis inserted at dim 1. Any other ensemble mode is only
    accepted for single-frame input and must be one of ``mean``/``max``/``lse``.

    With ``config.evaluation.eval_offload`` set, per-batch results are moved to
    CPU before being collected.

    Returns:
        ``(image_feats, pooled_image_feats)``, each concatenated along dim 0.
    """
    eval_cfg = config.evaluation
    token_feats, pooled_feats = [], []

    progress = MetricLogger(delimiter=" ").log_every(
        data_loader, 100, "extracting image feats"
    )
    for image, _ in progress:
        feat, pooled = model.encode_vision(
            image.to(device, non_blocking=True), test=True
        )

        if eval_cfg.eval_frame_ensemble != "concat":
            assert config.video_input.num_frames == 1, "only support single-frame"
            assert eval_cfg.eval_frame_ensemble in ["mean", "max", "lse"]
        else:
            if feat.dim() == 4:
                feat = rearrange(feat, "b t l c -> b (t l) c").contiguous()
            feat = feat.unsqueeze(1)

        if eval_cfg.eval_offload:
            feat, pooled = feat.cpu(), pooled.cpu()
        token_feats.append(feat)
        pooled_feats.append(pooled)

    return torch.cat(token_feats, dim=0), torch.cat(pooled_feats, dim=0)
|
|
|
|
@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
    """Evaluate retrieval and return the same recall metrics on every rank.

    Models whose ``config.model.model_cls`` is ``"VindLU_VideoCLIP"`` or
    ``"ViCLIP"`` are scored by ``evaluation_video_clip`` on rank 0 only.

    Every other model goes through ``evaluation``. That function splits the
    cross-encoder re-ranking across ranks and, when ``config.distributed`` is
    set, synchronizes with ``dist.barrier``/``dist.all_reduce``. It is
    therefore a collective and must be entered by every rank; calling it from
    rank 0 alone deadlocks.

    Metrics are computed on rank 0 and broadcast to the other ranks only when
    a process group is initialized, so single-process runs also work.

    Args:
        prefix: Key prefix. The first score pair is stored under
            ``prefix + "/"`` and the second under ``prefix + "_emb/"``.

    Returns:
        dict mapping each key to the ``itm_eval`` metrics dict. Pairs whose
        i2t score matrix is ``None`` are omitted.
    """
    is_dist = dist.is_available() and dist.is_initialized()
    is_main = not is_dist or dist.get_rank() == 0
    model_cls = config.model.model_cls
    use_dual_encoder = model_cls == "VindLU_VideoCLIP" or model_cls == "ViCLIP"

    scores = None
    with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.float):
        if not use_dual_encoder:
            # Collective: every rank participates (see docstring).
            scores = evaluation(model, data_loader, tokenizer, device, config)
        elif is_main:
            scores = evaluation_video_clip(
                model, data_loader, tokenizer, device, config
            )

    res = dict()
    if is_main:
        i2t_x, t2i_x, i2t_emb, t2i_emb = scores
        score_pairs = [
            (prefix + "/", i2t_x, t2i_x),
            (prefix + "_emb/", i2t_emb, t2i_emb),
        ]
        for name, i2t, t2i in score_pairs:
            if i2t is not None:
                txt2img_ids = data_loader.dataset.txt2img
                img2txt_ids = data_loader.dataset.img2txt
                res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)

    if is_dist:
        res_list = [res]
        dist.broadcast_object_list(res_list, src=0)
        res = res_list[0]
    return res
|
|
|
|
def _ensemble_clip_scores(clip_scores, method):
    """Reduce a list of per-clip ITM score vectors to a single score vector.

    A single clip is returned unchanged. Otherwise the clips are stacked on a
    new dim 0 and reduced with ``method`` ("mean", "max" or "lse").

    Raises:
        ValueError: if there is more than one clip and ``method`` is not a
            supported reduction.
    """
    if len(clip_scores) == 1:
        return clip_scores[0]
    if method not in ("mean", "max", "lse"):
        raise ValueError(
            "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1."
        )
    stacked = torch.stack(clip_scores)
    if method == "mean":
        return stacked.mean(0)
    if method == "max":
        return stacked.max(0)[0]
    return torch.logsumexp(stacked, dim=0)


def _cross_encoder_embeds(
    model, text_encoder, config, text_feats, text_atts, text_ids,
    text_idx, vision_embeds, vision_atts,
):
    """Run the fusion encoder on one batch of text/vision pairs.

    ``text_idx`` selects the rows of ``text_feats``/``text_atts``/``text_ids``
    forming the batch. ``vision_embeds``/``vision_atts`` are lists with one
    entry per vision feature tensor; the list is passed as-is when
    ``config.deep_fusion`` is set, and unwrapped to its single element
    otherwise.

    Returns ``model.vtm_embed(...)`` for ``VindLU_BLIP`` models. For all other
    models it returns the first-token hidden state of ``text_encoder`` run in
    ``"fusion"`` mode.
    """
    if not config.deep_fusion:
        vision_embeds, vision_atts = vision_embeds[0], vision_atts[0]
    if "VindLU_BLIP" in config.model.get("model_cls", ""):
        return model.vtm_embed(
            text_ids=text_ids[text_idx],
            text_atts=text_atts[text_idx],
            vision_embeds=vision_embeds,
            vision_atts=vision_atts,
        )
    return text_encoder(
        encoder_embeds=text_feats[text_idx],
        attention_mask=text_atts[text_idx],
        encoder_hidden_states=vision_embeds,
        encoder_attention_mask=vision_atts,
        return_dict=True,
        mode="fusion",
    ).last_hidden_state[:, 0]


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Dual-encoder retrieval followed by cross-encoder re-ranking of the top-k.

    Steps:

    1. Extract text and vision features and compute ITC similarities with
       ``get_sim``.
    2. For each image (resp. text) row assigned to this rank, take the top
       ``config.evaluation.k_test`` candidates. Re-score them with the
       cross-encoder ITM head in batches of 128, once per clip, and combine
       the clips with ``config.evaluation.eval_frame_ensemble``.
    3. When ``config.distributed`` is set, merge the per-rank rows with an
       ``all_reduce`` SUM. Every rank must therefore call this function.

    Entries that are not re-scored keep the fill value -100.

    Returns:
        ``(i2t_x, t2i_x, i2t_itc, i2t_itc.T)`` as float numpy arrays. The
        first two hold the re-ranked scores; the last two hold the ITC scores
        and their transpose.
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"
    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation for media_type={media_type}")

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    # --- dual-encoder features and ITC similarities ---
    texts = data_loader.dataset.text
    max_txt_l = config.inputs.max_txt_l
    if not isinstance(max_txt_l, int):
        max_txt_l = max_txt_l[media_type]
    text_feats, text_atts, text_ids = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )
    image_feats, pooled_image_feats = extract_vision_feats(
        data_loader, model, device, config
    )
    logger.info("Finished feature extraction")
    logger.info("Computing ITC scores [dot-product]")
    if config.evaluation.eval_offload:
        pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
    i2t_scores, t2i_scores = get_sim(
        model.vision_proj(pooled_image_feats),
        model.text_proj(text_feats[:, 0]),
        agg_method=config.model.get("agg_method", "mean"),
    )
    logger.info("Computing ITC scores [dot-product], done!")

    num_images = len(data_loader.dataset.image)
    num_text = len(texts)
    k_test = config.evaluation.k_test
    ensemble = config.evaluation.eval_frame_ensemble
    bs = 128  # cross-encoder mini-batch size
    # Normalize to a list of feature tensors so both fusion modes share one path.
    feat_list = image_feats if config.deep_fusion else [image_feats]
    n_clip_per_video = feat_list[0].shape[1]

    # --- image -> text re-ranking ---
    i2t_scores_x = torch.full((num_images, num_text), -100.0).to(
        device, torch.float, non_blocking=True
    )

    logger.info("Rerank dual-encoder results with cross-encoder...")
    num_tasks = get_world_size()
    rank = get_rank()
    text_encoder = model.get_text_encoder()

    # Each rank handles a contiguous slice of rows.
    step = num_images // num_tasks + 1
    start = rank * step
    end = min(num_images, start + step)
    iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
    logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
    logger.info(
        f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={ensemble}"
    )
    for i, sims in enumerate(iterator):
        img_idx = start + i
        k = min(len(sims), k_test)
        topk_idx = sims.topk(k=k, dim=0)[1]

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            # Repeat this image's clip features to a full batch once; each
            # batch below slices off the rows it needs.
            vision = [
                feat[img_idx, clip_idx].to(device, non_blocking=True).repeat(bs, 1, 1)
                for feat in feat_list
            ]
            vision_atts = [
                torch.ones(v.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
                for v in vision
            ]
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):
                cur_bs = min(bs, len(topk_idx) - j)
                itm_embeds.append(
                    _cross_encoder_embeds(
                        model, text_encoder, config, text_feats, text_atts, text_ids,
                        text_idx=topk_idx[j : j + bs],
                        vision_embeds=[v[:cur_bs] for v in vision],
                        vision_atts=[a[:cur_bs] for a in vision_atts],
                    )
                )
            itm_embeds = torch.cat(itm_embeds, dim=0)
            clip_scores.append(model.itm_head(itm_embeds)[:, 1])

        score = _ensemble_clip_scores(clip_scores, ensemble)
        i2t_scores_x[img_idx, topk_idx] = score.to(i2t_scores_x.dtype)

    # --- text -> image re-ranking ---
    t2i_scores_x = torch.full((num_text, num_images), -100.0).to(
        device, torch.float, non_blocking=True
    )

    step = num_text // num_tasks + 1
    start = rank * step
    end = min(num_text, start + step)
    iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
    logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
    logger.info(f"Top-{k_test} matching")
    for i, sims in enumerate(iterator):
        txt_idx = start + i
        k = min(len(sims), k_test)
        topk_idx = sims.topk(k=k, dim=0)[1]

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):
                img_batch = topk_idx[j : j + bs]
                cur_bs = len(img_batch)
                # Index on the features' own device: with eval_offload the
                # features live on CPU, while topk_idx lives on the score device.
                vision = [
                    feat[img_batch.to(feat.device), clip_idx].to(device, non_blocking=True)
                    for feat in feat_list
                ]
                vision_atts = [
                    torch.ones(v.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
                    for v in vision
                ]
                # The same text is paired with every candidate image of the batch.
                text_idx = torch.full(
                    (cur_bs,), txt_idx, dtype=torch.long, device=text_feats.device
                )
                itm_embeds.append(
                    _cross_encoder_embeds(
                        model, text_encoder, config, text_feats, text_atts, text_ids,
                        text_idx=text_idx,
                        vision_embeds=vision,
                        vision_atts=vision_atts,
                    )
                )
            itm_embeds = torch.cat(itm_embeds, dim=0)
            clip_scores.append(model.itm_head(itm_embeds)[:, 1])

        score = _ensemble_clip_scores(clip_scores, ensemble)
        t2i_scores_x[txt_idx, topk_idx] = score.to(t2i_scores_x.dtype)

    if config.distributed:
        # Each row is re-scored by one rank. Every other rank adds -100 to every
        # entry of that row, so the within-row ordering survives the SUM.
        dist.barrier()
        dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
        dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    return (
        i2t_scores_x.cpu().float().numpy(),
        t2i_scores_x.cpu().float().numpy(),
        i2t_scores.cpu().float().numpy(),
        i2t_scores.T.cpu().float().numpy(),
    )
|
|
|
|
@torch.no_grad()
def evaluation_video_clip(model, data_loader, tokenizer, device, config):
    """Score all vision/text pairs with a dual encoder, without re-ranking.

    ``tokenizer`` and ``config`` are not used. They are accepted so this
    function has the same signature as ``evaluation``.

    Text features are computed in batches of 256 raw strings via
    ``model.encode_text``. If the dataset defines ``num_prompts``, they are
    reshaped to ``(num_texts // num_prompts, num_prompts, -1)`` before scoring.
    Vision features come from ``model.encode_vision(image, test=True)``.
    Everything is moved to CPU before ``get_sim``.

    Returns:
        ``(i2t, i2t.T, i2t_dsl, i2t_dsl_T)`` as float numpy arrays, where
        ``i2t_dsl = i2t * i2t.softmax(dim=0)`` and
        ``i2t_dsl_T = i2t.T * i2t.T.softmax(dim=0)``.
    """
    model.eval()

    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation for media_type={media_type}")
    logger.info("Computing dual encoder features...")

    # --- text features ---
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    for i in range(0, num_text, text_bs):
        text_feat = model.encode_text(texts[i : min(num_text, i + text_bs)])
        text_feats.append(text_feat.cpu())
    text_feats = torch.cat(text_feats, dim=0)
    logger.info("Finished computing text features")

    if hasattr(data_loader.dataset, "num_prompts"):
        # Named `num_prompts` (not `np`) so the module-level numpy alias is not shadowed.
        num_prompts = data_loader.dataset.num_prompts
        logger.info("Using {} prompts".format(num_prompts))
        num_groups = len(data_loader.dataset.text) // num_prompts
        text_feats = text_feats.view(num_groups, num_prompts, -1)

    # --- vision features ---
    image_feats = []
    metric_logger = MetricLogger(delimiter=" ")
    iterator = metric_logger.log_every(data_loader, 100, "extracting image feats")
    for image, _ in iterator:
        image = image.to(device, non_blocking=True)
        image_feat = model.encode_vision(image, test=True)
        image_feats.append(image_feat.cpu())
    image_feats = torch.cat(image_feats, dim=0)
    logger.info("Finished feature extraction")

    logger.info("Computing ITC scores [dot-product]")
    i2t_scores, _ = get_sim(image_feats, text_feats)
    del image_feats, text_feats
    logger.info("Computing ITC scores [dot-product], done!")

    # Rescore each matrix by a softmax taken over its dim 0.
    i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
    i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)

    return (
        i2t_scores.cpu().float().numpy(),
        i2t_scores.T.cpu().float().numpy(),
        i2t_scores_dsl.cpu().float().numpy(),
        i2t_scores_dsl_T.cpu().float().numpy(),
    )
|
|
|
|
def _gt_ranks(scores, gt_ids):
    """Return, per query row of ``scores``, the best (lowest) 0-based rank of a ground-truth id.

    ``gt_ids[row]`` is either a single integer id (Python ``int`` or a numpy
    integer) or an iterable of ids. A row with an empty id collection gets
    rank 1e20, so it always counts as a miss.
    """
    ranks = np.zeros(scores.shape[0])
    for row, score in enumerate(scores):
        order = np.argsort(score)[::-1]  # candidate ids, highest score first
        gt = gt_ids[row]
        if isinstance(gt, (int, np.integer)):
            gt = (gt,)
        ranks[row] = min((np.where(order == g)[0][0] for g in gt), default=1e20)
    return ranks


def _recall(ranks, k):
    """Percentage of queries whose best ground-truth rank is below ``k``."""
    return 100.0 * np.count_nonzero(ranks < k) / len(ranks)


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """Compute R@1/5/10 for both retrieval directions from dense score matrices.

    Args:
        scores_i2t: Array of shape (num_images, num_texts); row i scores every
            text for image i.
        scores_t2i: Array of shape (num_texts, num_images).
        txt2img: Mapping/sequence from text index to its ground-truth image
            id(s).
        img2txt: Mapping/sequence from image index to its ground-truth text
            id(s).

    Returns:
        dict of recall percentages rounded to 2 decimals. The ``txt_*`` keys
        are image->text retrieval, the ``img_*`` keys are text->image
        retrieval, and ``r_mean`` is the average of the two means.
    """
    txt_ranks = _gt_ranks(scores_i2t, img2txt)
    tr1, tr5, tr10 = (_recall(txt_ranks, k) for k in (1, 5, 10))

    img_ranks = _gt_ranks(scores_t2i, txt2img)
    ir1, ir5, ir10 = (_recall(img_ranks, k) for k in (1, 5, 10))

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "img_r1": ir1,
        "img_r5": ir5,
        "img_r10": ir10,
        "img_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    return {k: round(v, 2) for k, v in eval_result.items()}
|
|