| import copy |
| import datetime |
| import logging |
| import os |
| import time |
| from os.path import join |
| from shutil import copyfile |
|
|
| import torch |
| import torch.backends.cudnn as cudnn |
| import torch.distributed as dist |
| import wandb |
| from omegaconf import OmegaConf |
|
|
| from dataset import (create_dataset, create_loader, create_sampler, |
| vqa_collate_fn) |
| from dataset.utils import sync_save_result |
| from models.vindlu_qa import VindLU_QA |
| from models.vindlu_vit_qa import VindLU_VIT_QA |
| from models.vindlu_blip_T5 import VindLU_BLIP_T5 |
| from tasks.shared_utils import setup_model |
| from tasks.vqa_utils import eval_qa_acc |
| from utils.basic_utils import (MetricLogger, SmoothedValue, save_json, |
| setup_seed) |
| from utils.config import Config |
| from utils.config_utils import setup_main |
| from utils.distributed import get_rank, get_world_size, is_main_process |
| from utils.logger import log_dict_to_wandb, setup_wandb |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def train(
    model,
    train_loader,
    optimizer,
    tokenizer,
    epoch,
    global_step,
    device,
    scheduler,
    scaler,
    config,
):
    """Train the QA model for one epoch and return the updated global step.

    The optimization-step order (scale -> backward -> unscale -> clip ->
    step -> update -> scheduler.step) is significant for AMP and must not
    be rearranged.

    Args:
        model: VQA model; called with train=True, returns a scalar loss.
        train_loader: loader yielding (image, question, answer, weights, n)
            batches produced by vqa_collate_fn.
        optimizer: torch optimizer driving the parameter update.
        tokenizer: HuggingFace-style tokenizer for questions and answers.
        epoch: zero-based epoch index (used for logging and the DDP sampler).
        global_step: step counter carried across epochs.
        device: torch device the inputs are moved to.
        scheduler: LR scheduler stepped once per iteration.
        scaler: torch.cuda.amp.GradScaler instance.
        config: experiment config (log_freq, max_q_len, max_a_len, fp16,
            optimizer.max_grad_norm, distributed, wandb, debug).

    Returns:
        int: global_step advanced by the number of iterations run.
    """
    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", SmoothedValue(window=1, fmt="{value:.6f}"))
    metric_logger.add_meter("loss", SmoothedValue(window=1, fmt="{value:.4f}"))

    header = "Train Epoch: [{}]".format(epoch)
    log_freq = config.log_freq

    if config.distributed:
        # Reshuffle the distributed sampler so each epoch sees a new ordering.
        train_loader.sampler.set_epoch(epoch)

    iterator = metric_logger.log_every(train_loader, log_freq, header)
    for i, data in enumerate(iterator):
        # n: answers-per-question counts; weights: per-answer weights
        # (see vqa_collate_fn).
        image, question, answer, weights, n = data
        image = image.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        question_input = tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=config.max_q_len,
            return_tensors="pt",
        ).to(device)
        answer_input = tokenizer(
            answer,
            padding="max_length",
            truncation=True,
            max_length=config.max_a_len,
            return_tensors="pt",
        ).to(device)

        # NOTE(review): autocast uses dtype=torch.bfloat16 while a GradScaler
        # is also applied; loss scaling is typically only needed for float16 —
        # confirm this combination is intended.
        with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.bfloat16):
            loss = model(
                image, question_input, answer_input, train=True, k=n, weights=weights,
                raw_question=list(question), raw_answer=list(answer),
            )

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if config.optimizer.max_grad_norm > 0:
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
            logs = metric_logger.get_global_avg_dict()
            log_dict_to_wandb(logs, step=global_step, prefix="train/")

        global_step += 1

        if config.debug and (i + 1) % 5 == 0:
            # Debug mode: cut the epoch short after a handful of iterations.
            break

    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged train stats: {metric_logger.global_avg()}")
    return global_step
|
|
|
|
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Rank the split's candidate answers for every question.

    Returns a list of {"question_id": ..., "answer": ...} dicts, where the
    answer is the top-scoring candidate from the dataset's answer_list.
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "[evaluation] Generating answers:"
    log_freq = config.log_freq // 2

    predictions = []

    # The candidate answer set is fixed for the whole split: tokenize it once,
    # appending the EOS token expected by the decoder.
    candidates = data_loader.dataset.answer_list
    answer_input = tokenizer(
        [cand + " " + config.eos for cand in candidates],
        padding="max_length",
        truncation=True,
        max_length=config.max_a_len,
        return_tensors="pt",
    ).to(device)

    logger.info("Start generating results.")

    for _, batch in enumerate(metric_logger.log_every(data_loader, log_freq, header)):
        image, question, question_id = batch
        image = image.to(device, non_blocking=True)
        question_input = tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=config.max_q_len,
            return_tensors="pt",
        ).to(device)

        # Model scores the top-k candidates per question.
        topk_ids, topk_probs = model(
            image,
            question_input,
            answer_input,
            train=False,
            k=config.evaluation.k_test,
            raw_question=list(question),
            raw_answer=candidates,
        )

        for qid, cand_ids, cand_probs in zip(question_id, topk_ids, topk_probs):
            if not isinstance(qid, str):
                qid = int(qid.item())
            best_idx = cand_probs.argmax(dim=0)
            predictions.append(
                {"question_id": qid, "answer": candidates[cand_ids[best_idx]]}
            )

    return predictions
|
|
|
|
def setup_dataloaders(config, dataset_name="vqa"):
    """Create the train loader and a {split_name: loader} mapping for eval.

    The train dataset always comes first in every parallel list (samplers,
    batch sizes, collate functions); test datasets follow in order.
    """
    logger.info(f"Creating {dataset_name} QA datasets")
    train_dataset = create_dataset("qa_train", config)
    test_datasets, test_dataset_names = create_dataset("qa_eval", config)
    all_datasets = [train_dataset] + test_datasets
    num_test = len(test_datasets)

    if config.distributed:
        samplers = create_sampler(
            all_datasets,
            shuffles=[True] + [False] * num_test,
            num_tasks=get_world_size(),
            global_rank=get_rank(),
        )
    else:
        samplers = [None] * len(all_datasets)

    # Batch size is keyed by media type (image vs. video) per dataset.
    batch_sizes = [config.inputs.batch_size[train_dataset.media_type]]
    batch_sizes += [config.inputs.batch_size_test[d.media_type] for d in test_datasets]
    loaders = create_loader(
        all_datasets,
        samplers,
        batch_size=batch_sizes,
        num_workers=[config.num_workers] * len(all_datasets),
        is_trains=[True] + [False] * num_test,
        collate_fns=[vqa_collate_fn] + [None] * num_test,
    )
    train_loader = loaders[0]
    test_name2loaders = dict(zip(test_dataset_names, loaders[1:]))
    return train_loader, test_name2loaders
|
|
|
|
def main(config):
    """Train and/or evaluate a VQA model as described by `config`.

    Builds data loaders and the model, runs the epoch loop (skipped when
    `config.evaluate` is set), evaluates the configured test splits, keeps
    the best checkpoint by `config.stop_key` overall accuracy when ground
    truth is available, and optionally logs to wandb.
    """
    if is_main_process() and config.wandb.enable:
        run = setup_wandb(config)

    logger.info(f"train_file: {config.train_file}")

    # Per-rank seed offset so augmentation differs across distributed workers.
    setup_seed(config.seed + get_rank())
    device = torch.device(config.device)
    cudnn.benchmark = True

    train_loader, test_name2loaders = setup_dataloaders(config)
    # Scheduler horizons are counted in optimizer steps, not epochs.
    config.scheduler.num_training_steps = len(train_loader) * config.scheduler.epochs
    config.scheduler.num_warmup_steps = len(train_loader) * config.scheduler.warmup_epochs

    # NOTE(review): eval() on a config-provided string — acceptable only
    # because the config is trusted; model_cls must name one of the imported
    # classes (VindLU_QA / VindLU_VIT_QA / VindLU_BLIP_T5).
    model_cls = eval(config.model.get('model_cls', 'VindLU_QA'))
    (
        model,
        model_without_ddp,
        optimizer,
        scheduler,
        scaler,
        tokenizer,
        start_epoch,
        global_step,
    ) = setup_model(
        config,
        model_cls=model_cls,
        has_decoder=True,
        pretrain=False,
        find_unused_parameters=False,
    )
    if is_main_process() and config.wandb.enable:
        wandb.watch(model)

    best = 0
    best_epoch = 0
    # Ground truth exists for non-VQA datasets, or for VQA's "minival" split.
    has_gt = config.dataset_name != "vqa" or config.test_types[0] == "minival"

    # BUG FIX: the original `"Start " + "evaluation" if config.evaluate else
    # "training"` parsed as `("Start evaluation") if ... else "training"`
    # (the conditional expression binds looser than +), so training runs
    # logged just "training". Parenthesize the ternary.
    logger.info("Start " + ("evaluation" if config.evaluate else "training"))
    start_time = time.time()
    for epoch in range(start_epoch, config.scheduler.epochs):
        if not config.evaluate:
            global_step = train(
                model,
                train_loader,
                optimizer,
                tokenizer,
                epoch,
                global_step,
                device,
                scheduler,
                scaler,
                config,
            )

        if config.get('no_test', False) and not config.evaluate:
            # no_test: skip evaluation entirely and always overwrite the
            # "best" checkpoint with the latest epoch's state.
            if is_main_process():
                save_obj = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
                best_epoch = epoch
        else:
            with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.bfloat16):
                eval_res = {}
                pred_name2file = {}
                for test_name, test_loader in test_name2loaders.items():
                    if test_name not in config.test_types:
                        logger.info(
                            f"Skip eval {test_name} split. All test_types {config.test_types}"
                        )
                        continue
                    logger.info(f"Evaluating {test_name} split...")
                    qa_result = evaluation(model, test_loader, tokenizer, device, config)

                    # Gather per-rank predictions and persist them to disk.
                    pred_file, gathered_result = sync_save_result(
                        qa_result, config.result_dir, f"{test_name}_latest"
                    )
                    pred_name2file[test_name] = pred_file
                    if is_main_process() and has_gt:
                        eval_res[test_name] = eval_qa_acc(
                            test_loader.dataset.anno_list,
                            gathered_result,
                            is_vqa=config.dataset_name == "vqa",
                        )

            if is_main_process():
                if len(eval_res) > 0:
                    logger.info(f"eval_res {eval_res}")

                if config.wandb.enable:
                    for name, acc_dict in eval_res.items():
                        log_dict_to_wandb(acc_dict, step=global_step, prefix=f"{name}/")

                # NOTE(review): eval_res[config.stop_key] raises KeyError if
                # stop_key names a split that was not evaluated — confirm
                # stop_key is always included in config.test_types.
                if not config.evaluate and has_gt and eval_res[config.stop_key]["overall"] > best:
                    save_obj = {
                        "model": model_without_ddp.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "scaler": scaler.state_dict(),
                        "config": config,
                        "epoch": epoch,
                        "global_step": global_step,
                    }
                    for name, pred_file in pred_name2file.items():
                        copyfile(pred_file, pred_file.replace("latest", "best"))
                    save_json(
                        eval_res, join(config.output_dir, "eval_res_best.json"), save_pretty=True
                    )
                    torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
                    best = eval_res[config.stop_key]["overall"]
                    best_epoch = epoch
                if config.evaluate:
                    save_json(
                        eval_res, join(config.output_dir, "eval_res_best.json"), save_pretty=True
                    )

                    for name, pred_file in pred_name2file.items():
                        copyfile(pred_file, pred_file.replace("latest", "eval"))
                        if has_gt:
                            save_path = join(config.output_dir, f"{name}_acc_eval.json")
                            save_json(eval_res[name], save_path, save_pretty=True)

        if config.evaluate or config.debug:
            break

        # FIX: only synchronize when running distributed — dist.barrier()
        # raises if the default process group was never initialized
        # (single-GPU / CPU runs).
        if config.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Training time {total_time_str}")
    logger.info(f"best epoch {best_epoch}")
    logger.info(f"Checkpoints and Logs saved at {config.output_dir}")

    if is_main_process() and config.wandb.enable:
        run.finish()
|
|
|
|
def eval_after_training(train_config):
    """After training finishes, run a standalone evaluation of the best ckpt.

    Mutates `train_config` into evaluation mode, derives a deep-copied
    eval config pointing at a fresh `eval_after_training` output dir, and
    re-enters main() with it.
    """
    # Switch to evaluation-only mode and load the best checkpoint.
    train_config.wandb.enable = False
    train_config.evaluate = True
    train_config.pretrained_path = join(train_config.output_dir, "ckpt_best.pth")

    if train_config.get('num_frames_test_final', False):
        # Optionally evaluate with a different (final) number of frames.
        final_frames = train_config.num_frames_test_final
        train_config.num_frames_test = final_frames
        train_config.batch_size = train_config.batch_size_final
        train_config.inputs.video_input.num_frames_test = final_frames
        train_config.model.vision_encoder.num_frames = final_frames

    eval_config = copy.deepcopy(train_config)
    eval_config.test_types = (
        ["minival"] if eval_config.dataset_name == "vqa" else ["val", "test"]
    )
    eval_config.output_dir = join(eval_config.output_dir, "eval_after_training")
    eval_config.result_dir = eval_config.output_dir
    if is_main_process():
        # exist_ok=False: refuse to clobber a previous post-training eval.
        os.makedirs(eval_config.output_dir, exist_ok=False)
        Config.dump(eval_config, os.path.join(eval_config.output_dir, "config.json"))
    logger.info(f"===========> START eval_after_training [{eval_config.test_types}]")
    main(eval_config)
|
|
|
|
if __name__ == "__main__":
    # Parse CLI/config; results are written alongside checkpoints by default.
    cfg = setup_main()
    cfg.result_dir = cfg.output_dir
    main(cfg)
    if not cfg.evaluate:
        # Training run: follow up with an evaluation of the best checkpoint.
        eval_after_training(cfg)
|
|