| |
| from __future__ import annotations |
| import os, time, random, argparse, math |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import LambdaLR |
| |
|
|
| import matplotlib.pyplot as plt |
|
|
| from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer |
| from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE |
| from modules.modeling import CLIP4Clip |
| from util import get_logger |
| from dataloaders.data_dataloaders import DATALOADER_DICT |
| from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim |
|
|
| |
| |
| |
def get_args(description='CLIP4Clip on Retrieval Task (Single GPU Minimal)'):
    """Parse command-line options for single-GPU CLIP4Clip retrieval.

    Returns:
        argparse.Namespace with all options. `batch_size` is rescaled to the
        per-step micro-batch (effective batch // gradient_accumulation_steps),
        and `loose_type` is forced off for the tightTransf similarity header.

    Raises:
        ValueError: if neither --do_train nor --do_eval is supplied.
    """
    p = argparse.ArgumentParser(description=description)

    # Run-mode flags.
    p.add_argument("--do_train", action="store_true")
    p.add_argument("--do_eval", action="store_true")

    # Data locations.
    p.add_argument('--train_csv', type=str, default='data/.train.csv')
    p.add_argument('--val_csv', type=str, default='data/.val.csv')
    p.add_argument('--data_path', type=str, default='data/caption.pickle')
    p.add_argument('--features_path', type=str, default='data/videos_feature.pickle')
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--cache_dir", type=str, default="")

    # Optimization hyper-parameters.
    p.add_argument('--epochs', type=int, default=20)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--batch_size', type=int, default=256)
    p.add_argument('--batch_size_val', type=int, default=3500)
    p.add_argument('--warmup_proportion', type=float, default=0.1)
    p.add_argument('--gradient_accumulation_steps', type=int, default=1)
    p.add_argument('--lr_decay', type=float, default=0.9)
    p.add_argument('--seed', type=int, default=42)

    # Task / checkpoint selection.
    p.add_argument("--task_type", default="retrieval", type=str)
    p.add_argument("--datatype", default="msrvtt", type=str)
    p.add_argument("--cross_model", default="cross-base", type=str)
    p.add_argument("--init_model", default=None, type=str)
    p.add_argument("--resume_model", default=None, type=str)

    # Input shaping and (legacy) loss options.
    p.add_argument('--max_words', type=int, default=20)
    p.add_argument('--max_frames', type=int, default=100)
    p.add_argument('--feature_framerate', type=int, default=1)
    p.add_argument('--margin', type=float, default=0.1)
    p.add_argument('--hard_negative_rate', type=float, default=0.5)
    p.add_argument('--negative_weighting', type=int, default=1)
    p.add_argument('--n_pair', type=int, default=1)
    p.add_argument('--num_thread_reader', type=int, default=1)

    # Transformer depths (text / visual / cross towers).
    p.add_argument('--text_num_hidden_layers', type=int, default=12)
    p.add_argument('--visual_num_hidden_layers', type=int, default=12)
    p.add_argument('--cross_num_hidden_layers', type=int, default=4)

    # CLIP4Clip-specific options.
    p.add_argument('--loose_type', action='store_true')
    p.add_argument('--expand_msrvtt_sentences', action='store_true')
    p.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--freeze_layer_num', type=int, default=0)
    p.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"])
    p.add_argument('--sim_header', type=str, default="meanP",
                   choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"])
    p.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str)

    # Experimental feature switches.
    p.add_argument("--use_rff", action='store_true')
    p.add_argument("--rff_dim", type=int, default=3000)
    p.add_argument("--use_clip4hashing", action="store_true")
    p.add_argument("--hash_bit", type=int, default=2048)

    # Runtime knobs.
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", action="store_true")
    p.add_argument("--no_amp", action="store_true",
                   help="disable AMP (automatic mixed precision) training")

    args = p.parse_args()

    # tightTransf fuses text and video in a cross transformer, which is
    # incompatible with the loose similarity path.
    if args.sim_header == "tightTransf":
        args.loose_type = False
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `--do_train` or `--do_eval` is required.")
    # The CLI batch size is the effective batch; divide by the accumulation
    # steps to get the micro-batch actually fed to the dataloader each step.
    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    return args
|
|
| |
| |
| |
def setup_env(args):
    """Prepare the run environment: output dir, logger, RNG seeds, device.

    Returns:
        (logger, device): a configured logger writing to output_dir/log.txt
        and the torch device (cuda when available, else cpu).
    """
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    # Seed every RNG source so runs are reproducible (cudnn.benchmark still
    # allows nondeterministic kernel selection, traded for speed).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(args.seed)
    torch.backends.cudnn.benchmark = True

    # Opt into TF32 matmuls where supported; older torch builds lack this API.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device}, cuda_available={torch.cuda.is_available()}")
    # Dump the full configuration for the run log.
    for key in sorted(vars(args)):
        logger.info(f"{key}: {getattr(args, key)}")
    return logger, device
|
|
| |
| |
| |
def init_model(args, device):
    """Build the CLIP4Clip model (optionally from a checkpoint) and freeze
    the bottom `freeze_layer_num` CLIP transformer layers.

    Args:
        args: parsed CLI namespace (uses init_model, cache_dir, cross_model,
            freeze_layer_num, linear_patch; also passed through as task_config).
        device: torch device the model is moved to.

    Returns:
        The CLIP4Clip model on `device` with the freeze policy applied.
    """
    # Load initial weights on CPU when --init_model is given; from_pretrained
    # merges them into the constructed model.
    state = torch.load(args.init_model, map_location='cpu') if args.init_model else None
    cache_dir = args.cache_dir or os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=state, task_config=args)
    model.to(device)

    # Freeze policy: -1 freezes nothing; 0..12 freezes CLIP resblocks below
    # that index plus all non-resblock CLIP params except the exempt heads.
    assert -1 <= args.freeze_layer_num <= 12
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, p in model.clip.named_parameters():
            # Final norms / projections / temperature always stay trainable.
            if name.startswith(("ln_final","text_projection","logit_scale","visual.ln_post","visual.proj")):
                continue
            elif ("visual.transformer.resblocks." in name) or ("transformer.resblocks." in name):
                # Parameter names look like "...resblocks.<idx>...." — parse idx.
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue  # layers at/above the threshold remain trainable
            if args.linear_patch == "3d" and "conv2." in name:
                continue  # 3D patch conv has no pretrained weights; keep it trainable
            p.requires_grad = False
    return model
|
|
| |
| |
| |
def prep_optimizer(args, model, num_training_steps):
    """Create AdamW with decay/no-decay parameter groups and a warmup+cosine
    LambdaLR schedule.

    Args:
        args: namespace providing `lr` and `warmup_proportion`.
        model: the model to optimize (a wrapping `.module` is unwrapped).
        num_training_steps: total optimizer steps for the cosine horizon.

    Returns:
        (optimizer, scheduler) pair.
    """
    if hasattr(model, 'module'):
        model = model.module

    # Biases and LayerNorm parameters are exempt from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    def _wants_decay(param_name):
        return all(tag not in param_name for tag in no_decay)

    decay_params, nodecay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are excluded entirely
        (decay_params if _wants_decay(name) else nodecay_params).append(param)

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': 0.2, 'lr': args.lr},
        {'params': nodecay_params, 'weight_decay': 0.0, 'lr': args.lr},
    ], lr=args.lr)

    warmup_steps = int(num_training_steps * args.warmup_proportion)

    def lr_lambda(current_step: int):
        # Linear warmup to the base LR, then a half-cosine decay to zero.
        if current_step < warmup_steps:
            return float(current_step) / max(1, warmup_steps)
        progress = float(current_step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
|
|
| |
| |
| |
def train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger):
    """Run one training epoch with AMP autocast and gradient accumulation.

    Args:
        epoch: zero-based epoch index (used for logging only).
        args: namespace providing `no_amp` and `gradient_accumulation_steps`.
        model: CLIP4Clip model; its forward returns a scalar loss here.
        train_loader: yields (input_ids, input_mask, segment_ids, video, video_mask).
        device: target device for each batch.
        optimizer / scheduler / scaler: AdamW, LambdaLR, and AMP GradScaler.
        logger: run logger.

    Returns:
        Average per-step loss over the epoch.
        NOTE(review): when gradient_accumulation_steps > 1 the accumulated
        values are already divided by that factor, so the returned mean is
        smaller by the same factor — confirm whether that is intended.
    """
    model.train()
    total_loss = 0.0
    log_step = 100
    start = time.time()

    for step, batch in enumerate(train_loader):
        # Move every tensor of the batch to the device.
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        input_ids, input_mask, segment_ids, video, video_mask = batch

        # Mixed-precision forward; disabled entirely with --no_amp.
        with torch.cuda.amp.autocast(enabled=not args.no_amp):
            loss = model(input_ids, segment_ids, input_mask, video, video_mask)
            if args.gradient_accumulation_steps > 1:
                # Scale so accumulated gradients average over micro-batches.
                loss = loss / args.gradient_accumulation_steps

        scaler.scale(loss).backward()
        total_loss += float(loss)

        # Only every `gradient_accumulation_steps` micro-batches do we
        # actually step the optimizer and the LR schedule.
        if (step + 1) % args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            # Clamp CLIP's learned temperature to ln(100), matching the
            # original CLIP training recipe.
            if hasattr(model, 'clip'):
                torch.clamp_(model.clip.logit_scale.data, max=np.log(100))
            elif hasattr(model, 'module') and hasattr(model.module, 'clip'):
                torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))

        if (step + 1) % log_step == 0:
            logger.info(f"[train] epoch {epoch+1} step {step+1}/{len(train_loader)} "
                        f"loss={float(loss):.4f} time/step={(time.time()-start)/log_step:.4f}")
            start = time.time()

    return total_loss / len(train_loader)
|
|
| def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out): |
| sim_matrix = [] |
| for idx1, b1 in enumerate(batch_list_t): |
| input_mask, segment_ids = b1 |
| sequence_output = batch_seq_out[idx1] |
| each_row = [] |
| for idx2, b2 in enumerate(batch_list_v): |
| video_mask = b2[0] |
| visual_output = batch_vis_out[idx2] |
| logits, *_ = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask, |
| loose_type=model.loose_type) |
| each_row.append(logits.cpu().detach().numpy()) |
| sim_matrix.append(np.concatenate(each_row, axis=-1)) |
| return sim_matrix |
|
|
@torch.no_grad()
def eval_epoch(args, model, test_loader, device, logger):
    """Evaluate retrieval on `test_loader`: cache per-batch encoder features,
    build the full similarity matrix, save a heatmap, and log metrics.

    Returns:
        Text-to-video R@1 (used upstream for best-checkpoint selection).
    """
    # The cache filename encodes option flags that change the features.
    # NOTE(review): the cache is NOT invalidated when model weights change
    # under the same flags — delete stale .pt caches after retraining.
    suffix = ""
    if getattr(args, "use_clip4hashing", False): suffix += "_hash"
    if args.use_rff: suffix += "_rff"
    if args.init_model: suffix += "_trained"

    if "train" in args.val_csv and "10k" in args.val_csv:
        cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
    else:
        cache_name = f"{args.datatype}_eval_cache{suffix}.pt"

    cache_path = os.path.join(args.output_dir, cache_name)

    model.eval()
    if os.path.exists(cache_path):
        # Reuse features previously written below by this same function.
        logger.info(f"[Eval] load cached features: {cache_path}")
        cache = torch.load(cache_path, map_location=device)
        batch_seq_out = cache['batch_sequence_output_list']
        batch_vis_out = cache['batch_visual_output_list']
        batch_list_t = cache['batch_list_t']
        batch_list_v = cache['batch_list_v']
    else:
        # First pass: run the encoders once per batch and keep outputs + masks.
        logger.info("[Eval] caching features...")
        batch_list_t, batch_list_v = [], []
        batch_seq_out, batch_vis_out = [], []
        for bid, batch in enumerate(test_loader):
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask = batch
            with torch.cuda.amp.autocast(enabled=not args.no_amp):
                seq_out, vis_out = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
            batch_seq_out.append(seq_out)
            batch_vis_out.append(vis_out)
            batch_list_t.append((input_mask, segment_ids))
            batch_list_v.append((video_mask,))
            if (bid+1) % 20 == 0:
                logger.info(f"[Eval] cached batch {bid+1}/{len(test_loader)}")

        # Persist so later eval runs skip re-encoding.
        torch.save({
            'batch_sequence_output_list': batch_seq_out,
            'batch_visual_output_list': batch_vis_out,
            'batch_list_t': batch_list_t,
            'batch_list_v': batch_list_v,
        }, cache_path)
        logger.info(f"[Eval] saved cache to {cache_path}")

    # Pairwise text x video similarity (rows = text batches).
    sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out)
    sim_matrix = np.concatenate(sim_matrix, axis=0)
    logger.info(f"[Eval] sim_matrix shape: {sim_matrix.shape}")

    # Best-effort diagnostic heatmap; never let plotting break eval.
    try:
        plt.figure(figsize=(8,6))
        plt.imshow(sim_matrix[:100, :100], aspect='auto')
        plt.title('Similarity Matrix (first 100x100)')
        plt.xlabel('Video Index'); plt.ylabel('Text Index')
        out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
        plt.tight_layout(); plt.savefig(out_path); plt.close()
        logger.info(f"[Eval] heatmap saved: {out_path}")
    except Exception as e:
        logger.info(f"[Eval] heatmap skipped: {e}")

    # Text->video on the matrix; video->text on its transpose.
    tv = compute_metrics(sim_matrix)
    vt = compute_metrics(sim_matrix.T)
    logger.info(f"Text-to-Video: R@1 {tv['R1']:.1f} | R@5 {tv['R5']:.1f} | R@10 {tv['R10']:.1f} | MR {tv['MR']:.1f} | MeanR {tv['MeanR']:.1f}")
    logger.info(f"Video-to-Text: R@1 {vt['R1']:.1f} | R@5 {vt['R5']:.1f} | R@10 {vt['R10']:.1f} | MR {vt['MR']:.1f} | MeanR {vt['MeanR']:.1f}")
    return tv['R1']
|
|
| |
| |
| |
def main():
    """Entry point: parse args, build model and dataloaders, then train and/or evaluate."""
    args = get_args()
    logger, device = setup_env(args)
    assert args.task_type == "retrieval"

    tokenizer = ClipTokenizer()
    model = init_model(args, device)

    # Resolve test/val loaders; each falls back to the other when missing.
    assert args.datatype in DATALOADER_DICT
    test_loader, test_len = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_loader, test_len = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)
    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_loader, val_len = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_loader, val_len = test_loader, test_len
    if test_loader is None:
        test_loader, test_len = val_loader, val_len

    if args.do_train:
        train_loader, train_len, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        # pin_memory is useless without CUDA; best-effort disable it.
        if hasattr(train_loader, "pin_memory") and args.pin_memory and not torch.cuda.is_available():
            try:
                train_loader.pin_memory = False
            except Exception:
                pass

        steps_per_epoch = len(train_loader)
        # Total optimizer steps (not forward passes): accumulation shrinks it.
        num_train_steps = (steps_per_epoch * args.epochs) // max(1, args.gradient_accumulation_steps)
        optimizer, scheduler = prep_optimizer(args, model, num_train_steps)
        scaler = torch.cuda.amp.GradScaler(enabled=not args.no_amp)

        logger.info(f"[Train] examples={train_len} batch_size={args.batch_size} steps/epoch={steps_per_epoch} total_steps={num_train_steps}")
        best_r1 = -1.0

        # NOTE(review): --resume_model restores ONLY the optimizer state; model
        # weights must come from --init_model — confirm this split is intended.
        if args.resume_model:
            ckpt = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            logger.info(f"[Train] resumed optimizer from {args.resume_model}")

        for epoch in range(args.epochs):
            loss = train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger)
            logger.info(f"[Train] epoch {epoch+1}/{args.epochs} loss={loss:.4f}")

            # Evaluate each epoch; keep only the best text->video R@1 checkpoint.
            r1 = eval_epoch(args, model, test_loader, device, logger)
            if r1 > best_r1:
                best_r1 = r1
                model_path = os.path.join(args.output_dir, f"pytorch_model.bin.best")
                torch.save((model.module if hasattr(model,'module') else model).state_dict(), model_path)
                opt_path = os.path.join(args.output_dir, f"pytorch_opt.bin.best")
                torch.save({'epoch': epoch, 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss}, opt_path)
                logger.info(f"[Train] new best R1={best_r1:.2f} saved: {model_path}")

    if args.do_eval:
        eval_epoch(args, model, test_loader, device, logger)


if __name__ == "__main__":
    main()
|
|