import logging
from os.path import join

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from dataset import create_dataset, create_loader
from models.utils import tile
from models.vindlu import VindLU
from models.vindlu_vit import VindLU_VIT
from tasks.shared_utils import setup_model
from utils.basic_utils import (MetricLogger, flat_list_of_lists, save_json,
                               setup_seed)
from utils.config_utils import setup_main
from utils.distributed import get_rank

logger = logging.getLogger(__name__)

def get_sim_for_each_question(model, pooled_image_feat, pooled_text_feat):
    """Compute question-level similarity between a video and its answer candidates.

    Args:
        model: model with `vision_proj`, `text_proj`, and `temp` attributes.
        pooled_image_feat (torch.Tensor): Shape: [b, t, c]. t is the number of frames/clips.
        pooled_text_feat (torch.Tensor): Shape: [b, n, c]. n is the number of answer candidates.

    Returns:
        torch.Tensor: Shape: [b, n]. Softmax-normalized similarity over the n candidates.
    """
    image_proj = model.vision_proj
    text_proj = model.text_proj

    image_feat = F.normalize(image_proj(pooled_image_feat), dim=-1)
    text_feat = F.normalize(text_proj(pooled_text_feat), dim=-1)
    sim = torch.matmul(image_feat, rearrange(text_feat, "b n c -> b c n"))  # [b, t, n]
    sim = sim.mean(1) / model.temp  # average over frames -> [b, n]
    sim = F.softmax(sim, dim=1)  # distribution over answer candidates
    return sim
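

# A minimal shape walkthrough of get_sim_for_each_question (sizes are illustrative
# only; the projection width depends on the model config):
#   pooled_image_feat [b=2, t=4, c=768] --vision_proj--> image_feat [2, 4, 256]
#   pooled_text_feat  [b=2, n=5, c=768] --text_proj--->  text_feat  [2, 5, 256]
#   matmul([2, 4, 256], [2, 256, 5]) -> [2, 4, 5]; mean over t -> [2, 5];
#   softmax over the n=5 candidates yields one distribution per question.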


def main(config):
    logger.info(f"config: \n{config}")
    logger.info(f"train_file: {config.train_file}")

    setup_seed(config.seed + get_rank())
    device = torch.device(config.device)
    cudnn.benchmark = True

    # test dataset and loader
    test_dataset = create_dataset("mc_test", config)
    test_loader = create_loader(
        [test_dataset],
        [None],
        batch_size=[config.batch_size_test.video],
        num_workers=[config.num_workers],
        is_trains=[False],
        collate_fns=[None],
    )[0]

    # placeholder values; the scheduler is never stepped during evaluation
    config.scheduler.num_training_steps = 10
    config.scheduler.num_warmup_steps = 10
    model_cls = eval(config.model.get("model_cls", "VindLU"))
    (
        model,
        model_without_ddp,
        optimizer,
        scheduler,
        scaler,
        tokenizer,
        start_epoch,
        global_step,
    ) = setup_model(
        config,
        model_cls=model_cls,
        has_decoder=False,
        pretrain=False,
        find_unused_parameters=False,
    )
    model = model_without_ddp

    logger.info("Start " + ("evaluation" if config.evaluate else "training"))
    metric_logger = MetricLogger(delimiter=" ")
    iterator = metric_logger.log_every(test_loader, 5, "Evaluation: ")
    num_options_per_q = 5
    all_gt_answers = []
    all_pred_answers = []
    with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.bfloat16), torch.no_grad():
        for image, text, ans, ann in iterator:
            image = image.to(device, non_blocking=True)
            all_gt_answers.append(ans)
            # transpose then flatten so the candidates of each question are contiguous
            text = flat_list_of_lists(list(zip(*text)))
            text_input = tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=config.max_txt_l,
                return_tensors="pt",
            ).to(device)
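
            # A sketch of the reordering above with hypothetical strings, where
            # `text` arrives as num_options lists of length batch_size:
            #   text = [["q0_a", "q1_a"], ["q0_b", "q1_b"]]
            #   list(zip(*text))        -> [("q0_a", "q0_b"), ("q1_a", "q1_b")]
            #   flat_list_of_lists(...) -> ["q0_a", "q0_b", "q1_a", "q1_b"]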

            # encode text and video; tile the video features so each video row is
            # paired with all num_options_per_q candidate answers
            text_feat = model.encode_text(text_input)[0]
            image_feat, pooled_image_feat = model.encode_image(image)
            image_feat = tile(image_feat, 0, num_options_per_q)
            image_mask = torch.ones(image_feat.size()[:-1], dtype=torch.long).to(
                device, non_blocking=True
            )

            # cross-modal fusion, then score each (video, answer) pair with the ITM head
            output = model.get_text_encoder()(
                encoder_embeds=text_feat,
                attention_mask=text_input.attention_mask,
                encoder_hidden_states=image_feat,
                encoder_attention_mask=image_mask,
                return_dict=True,
                mode="fusion",
            )
            itm_embeds = output.last_hidden_state[:, 0]

            # logit of the "match" class; the highest-scoring option is the prediction
            score = model.itm_head(itm_embeds)[:, 1]
            pred_ans = score.view(-1, num_options_per_q).max(1)[1].cpu()
            all_pred_answers.append(pred_ans)

    all_gt_answers = torch.cat(all_gt_answers, 0)
    all_pred_answers = torch.cat(all_pred_answers, 0)
    acc = all_gt_answers == all_pred_answers
    acc = float(torch.sum(acc) / len(acc))
    eval_res = {"test": round(100 * acc, 2)}
    logger.info(f"\n{eval_res}")
    save_json(eval_res, join(config.output_dir, "eval_res.json"))

    dist.barrier()


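# main_with_ensemble extends main(): it mixes a contrastive (ITC) similarity prior
# into the ITM score, optionally ensembles scores across multiple test clips
# (mean/max/lse), and dumps per-question prediction records to disk.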
def main_with_ensemble(config):
    logger.info(f"train_file: {config.train_file}")

    setup_seed(config.seed + get_rank())
    device = torch.device(config.device)
    cudnn.benchmark = True

    # test dataset and loader
    test_dataset = create_dataset("mc_test", config)
    test_loader = create_loader(
        [test_dataset],
        [None],
        batch_size=[config.inputs.batch_size_test.video],
        num_workers=[config.num_workers],
        is_trains=[False],
        collate_fns=[None],
    )[0]

    # placeholder values; the scheduler is never stepped during evaluation
    config.scheduler.num_training_steps = 10
    config.scheduler.num_warmup_steps = 10
    model_cls = eval(config.model.get("model_cls", "VindLU"))
    (
        model,
        model_without_ddp,
        optimizer,
        scheduler,
        scaler,
        tokenizer,
        start_epoch,
        global_step,
    ) = setup_model(
        config,
        model_cls=model_cls,
        has_decoder=False,
        pretrain=False,
        find_unused_parameters=False,
    )
    model = model_without_ddp

    logger.info("Start " + ("evaluation" if config.evaluate else "training"))
    metric_logger = MetricLogger(delimiter=" ")
    iterator = metric_logger.log_every(test_loader, 5, "Evaluation: ")
    num_options_per_q = 5
    all_gt_answers = []
    all_pred_answers = []
    predictions = []
    with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.bfloat16), torch.no_grad():
        for image, text, ans, ann in iterator:
            image = image.to(device, non_blocking=True)
            all_gt_answers.append(ans)
            # transpose then flatten so the candidates of each question are contiguous
            text = flat_list_of_lists(list(zip(*text)))
            text_input = tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=config.max_txt_l,
                return_tensors="pt",
            ).to(device)

            # encode text: full sequence for fusion, pooled features for the ITC prior
            text_feat, pooled_text_feat = model.encode_text(text_input)

            # encode video; "concat" keeps all test clips in one long token sequence
            if config.evaluation.eval_frame_ensemble == "concat":
                image_feats, pooled_image_feat = model.encode_vision(image, test=True)
                if len(image_feats.shape) == 4:
                    image_feats = rearrange(image_feats, "b t l c -> b (t l) c")
                image_feats = image_feats.unsqueeze(1)  # (b, 1, t*l, c)
                pooled_image_feat = pooled_image_feat.unsqueeze(1)  # (b, 1, t, c)
            else:
                assert config.video_input.num_frames == 1, "only support single-frame"
                assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
                image_feats, pooled_image_feat = model.encode_vision(image)

            n_clip_per_video = image_feats.shape[1]
            # group the pooled text features per question: (b*n, c) -> (b, n, c);
            # done once, outside the clip loop, so every clip reuses the same layout
            pooled_text_feat = rearrange(
                pooled_text_feat, "(b n) c -> b n c", n=num_options_per_q
            )
            clip_scores = []
            for clip_idx in range(n_clip_per_video):
                image_feat = image_feats[:, clip_idx]
                # slice into a new name so later iterations still see the full tensor
                pooled_image_feat_clip = pooled_image_feat[:, clip_idx]
                image_feat = tile(image_feat, 0, num_options_per_q)
                image_mask = torch.ones(image_feat.size()[:-1], dtype=torch.long).to(
                    device, non_blocking=True
                )

                # contrastive (ITC) similarity prior over the candidate answers
                sim = get_sim_for_each_question(
                    model, pooled_image_feat_clip, pooled_text_feat
                )
                sim = sim.flatten()  # (b*n,)

                # cross-modal fusion, then the ITM head scores each (video, answer) pair
                output = model.get_text_encoder()(
                    encoder_embeds=text_feat,
                    attention_mask=text_input.attention_mask,
                    encoder_hidden_states=image_feat,
                    encoder_attention_mask=image_mask,
                    return_dict=True,
                    mode="fusion",
                )
                itm_embeds = output.last_hidden_state[:, 0]

                # late fusion: fixed 0.7/0.3 mix of ITM probability and ITC prior
                score = F.softmax(model.itm_head(itm_embeds), dim=1)[:, 1]
                score = score * 0.7 + sim * 0.3
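
                # A worked example of the weighting above (hypothetical values;
                # the 0.7/0.3 weights are hard-coded, not read from the config):
                #   itm = 0.80, itc = 0.60 -> 0.7 * 0.80 + 0.3 * 0.60 = 0.74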

                clip_scores.append(score)

            if len(clip_scores) == 1:
                score = clip_scores[0]
            else:
                assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
                clip_scores = torch.stack(clip_scores)  # (num_clips, b*n)
                if config.evaluation.eval_frame_ensemble == "mean":
                    score = clip_scores.mean(0)
                elif config.evaluation.eval_frame_ensemble == "max":
                    score = clip_scores.max(0)[0]
                elif config.evaluation.eval_frame_ensemble == "lse":  # LogSumExp
                    score = torch.logsumexp(clip_scores, dim=0)
                else:
                    raise ValueError(
                        "config.evaluation.eval_frame_ensemble must be one of "
                        "[mean, max, lse] when #clips > 1."
                    )
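
            # Ensemble behavior on a toy example (illustrative numbers only), for
            # per-option clip scores stacked as (num_clips=2, ...):
            #   clip_scores = [[0.2, 0.9], [0.4, 0.1]]
            #   mean -> [0.3, 0.5]; max -> [0.4, 0.9];
            #   lse  -> logsumexp over clips, a smooth upper bound on max.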

            pred_ans = score.view(-1, num_options_per_q).max(1)[1].cpu()
            all_pred_answers.append(pred_ans)

            # assemble per-question prediction records for later inspection
            ensemble_scores = score.view(-1, num_options_per_q).cpu()
            if n_clip_per_video > 1:
                clip_scores = clip_scores.view(
                    n_clip_per_video, -1, num_options_per_q
                ).cpu()
            for q_idx in range(len(ensemble_scores)):
                _pred = dict(
                    video=ann["video"][q_idx],
                    options=[e[q_idx] for e in ann["caption"]],
                    answer=ann["answer"][q_idx].item(),
                    pred_ans_ensemble=pred_ans[q_idx].item(),
                    pred_scores_ensemble=ensemble_scores[q_idx].numpy(),
                )
                # also keep per-clip scores and answers when multiple clips were scored
                if n_clip_per_video > 1:
                    _pred["pred_scores_frame"] = clip_scores[:, q_idx].numpy()
                    _pred["pred_ans_frame"] = clip_scores[:, q_idx].max(1)[1].numpy()
                predictions.append(_pred)
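
            # A hypothetical record as saved to prediction_scores.pth (values are
            # illustrative, not from a real run):
            #   {"video": "videos/0001.mp4", "options": [...five strings...],
            #    "answer": 2, "pred_ans_ensemble": 2,
            #    "pred_scores_ensemble": array([0.05, 0.1, 0.7, 0.1, 0.05])}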

    all_gt_answers = torch.cat(all_gt_answers, 0)
    all_pred_answers = torch.cat(all_pred_answers, 0)
    acc = all_gt_answers == all_pred_answers
    acc = float(torch.sum(acc) / len(acc))
    eval_res = {"test": round(100 * acc, 2)}
    logger.info(f"\n{eval_res}")
    save_json(eval_res, join(config.output_dir, "eval_res.json"))
    torch.save(predictions, join(config.output_dir, "prediction_scores.pth"))

    dist.barrier()


if __name__ == "__main__":
    cfg = setup_main()
    main_with_ensemble(cfg)
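
# Typical launch is via the project's distributed run scripts; the exact entry
# point and flags depend on setup_main and your cluster setup (a sketch only):
#   torchrun --nproc_per_node=1 <this_file> <config.py> evaluate=True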