# train_single_gpu.py
"""Minimal single-GPU training / evaluation script for CLIP4Clip retrieval.

Flow: parse args -> seed/log/device setup -> build CLIP4Clip -> (optionally)
train with AMP + warmup/cosine LR -> evaluate text<->video retrieval metrics.
"""
from __future__ import annotations
import os, time, random, argparse, math
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
# (removed) from transformers import get_cosine_schedule_with_warmup
import matplotlib.pyplot as plt

from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
from util import get_logger
from dataloaders.data_dataloaders import DATALOADER_DICT
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim


# -----------------------
# 1) Arguments
# -----------------------
def get_args(description='CLIP4Clip on Retrieval Task (Single GPU Minimal)'):
    """Parse command-line arguments, validate mode flags, and derive the
    effective per-step batch size under gradient accumulation."""
    p = argparse.ArgumentParser(description=description)

    # Core mode flags: at least one of --do_train / --do_eval is required.
    p.add_argument("--do_train", action="store_true")
    p.add_argument("--do_eval", action="store_true")

    # Data / output paths.
    p.add_argument('--train_csv', type=str, default='data/.train.csv')
    p.add_argument('--val_csv', type=str, default='data/.val.csv')
    p.add_argument('--data_path', type=str, default='data/caption.pickle')
    p.add_argument('--features_path', type=str, default='data/videos_feature.pickle')
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--cache_dir", type=str, default="")

    # Hyperparameters.
    p.add_argument('--epochs', type=int, default=20)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--batch_size', type=int, default=256)
    p.add_argument('--batch_size_val', type=int, default=3500)
    p.add_argument('--warmup_proportion', type=float, default=0.1)
    p.add_argument('--gradient_accumulation_steps', type=int, default=1)
    p.add_argument('--lr_decay', type=float, default=0.9)  # possibly unused downstream
    p.add_argument('--seed', type=int, default=42)

    # Model / behavior options.
    p.add_argument("--task_type", default="retrieval", type=str)
    p.add_argument("--datatype", default="msrvtt", type=str)
    p.add_argument("--cross_model", default="cross-base", type=str)
    p.add_argument("--init_model", default=None, type=str)    # initial weights to load
    p.add_argument("--resume_model", default=None, type=str)  # resume incl. optimizer state

    # CLIP-related / header options (kept from the original code base).
    p.add_argument('--max_words', type=int, default=20)
    p.add_argument('--max_frames', type=int, default=100)
    p.add_argument('--feature_framerate', type=int, default=1)
    p.add_argument('--margin', type=float, default=0.1)
    p.add_argument('--hard_negative_rate', type=float, default=0.5)
    p.add_argument('--negative_weighting', type=int, default=1)
    p.add_argument('--n_pair', type=int, default=1)
    p.add_argument('--num_thread_reader', type=int, default=1)
    p.add_argument('--text_num_hidden_layers', type=int, default=12)
    p.add_argument('--visual_num_hidden_layers', type=int, default=12)
    p.add_argument('--cross_num_hidden_layers', type=int, default=4)
    p.add_argument('--loose_type', action='store_true')
    p.add_argument('--expand_msrvtt_sentences', action='store_true')
    p.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--freeze_layer_num', type=int, default=0)
    p.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2])
    p.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"])
    p.add_argument('--sim_header', type=str, default="meanP",
                   choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"])
    p.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str)

    # Extension flags (kept as-is).
    p.add_argument("--use_rff", action='store_true')
    p.add_argument("--rff_dim", type=int, default=3000)
    p.add_argument("--use_clip4hashing", action="store_true")
    p.add_argument("--hash_bit", type=int, default=2048)

    # Quality / performance options.
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", action="store_true")
    p.add_argument("--no_amp", action="store_true", help="AMP 끄기")

    args = p.parse_args()

    # tightTransf requires the non-loose similarity path.
    if args.sim_header == "tightTransf":
        args.loose_type = False
    if not args.do_train and not args.do_eval:
        raise ValueError("`--do_train` 또는 `--do_eval` 중 하나는 반드시 필요합니다.")
    # Effective per-step batch size; integer floor division replaces int(a/b)
    # (identical for positive ints, clearer intent).
    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    return args


# -----------------------
# 2) Seed/Logger/Device
# -----------------------
def setup_env(args):
    """Create the output dir, seed all RNGs, pick the device, and log config.

    Returns (logger, device).
    """
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, "log.txt"))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True  # speed up; set False for full reproducibility

    # matmul precision (Ampere+); older torch versions lack this API.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device}, cuda_available={torch.cuda.is_available()}")
    for k in sorted(args.__dict__):
        logger.info(f"{k}: {getattr(args, k)}")
    return logger, device


# -----------------------
# 3) Model
# -----------------------
def init_model(args, device):
    """Build CLIP4Clip (optionally from --init_model weights), move it to the
    device, and freeze the lowest CLIP layers per --freeze_layer_num."""
    state = torch.load(args.init_model, map_location='cpu') if args.init_model else None
    cache_dir = args.cache_dir or os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir,
                                      state_dict=state, task_config=args)
    model.to(device)

    # Optional freezing: -1 disables freezing entirely.
    assert -1 <= args.freeze_layer_num <= 12
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, p in model.clip.named_parameters():
            # Projection heads / final norms / logit scale always stay trainable.
            if name.startswith(("ln_final", "text_projection", "logit_scale",
                                "visual.ln_post", "visual.proj")):
                continue
            elif ("visual.transformer.resblocks." in name) or ("transformer.resblocks." in name):
                # Transformer blocks at or above the cutoff stay trainable.
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue
            # With a 3D patch stem, keep conv2 trainable.
            if args.linear_patch == "3d" and "conv2." in name:
                continue
            p.requires_grad = False
    return model


# -----------------------
# 4) Optimizer & Scheduler (PyTorch-only warmup+cosine)
# -----------------------
def prep_optimizer(args, model, num_training_steps):
    """Build AdamW (weight decay split by parameter type) and a linear-warmup
    + cosine-decay LambdaLR scheduler over `num_training_steps`."""
    if hasattr(model, 'module'):
        model = model.module

    # No weight decay for biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    decay_params = [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay) and p.requires_grad]
    nodecay_params = [p for n, p in param_optimizer
                      if any(nd in n for nd in no_decay) and p.requires_grad]

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': 0.2, 'lr': args.lr},
        {'params': nodecay_params, 'weight_decay': 0.0, 'lr': args.lr},
    ], lr=args.lr)

    warmup_steps = int(num_training_steps * args.warmup_proportion)

    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            return float(current_step) / max(1, warmup_steps)  # linear warmup
        progress = float(current_step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay

    scheduler = LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler


# -----------------------
# 5) Train/Eval
# -----------------------
def train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger):
    """Run one training epoch with AMP and gradient accumulation.

    Returns the mean (accumulation-scaled) loss over the epoch.
    """
    model.train()
    total_loss = 0.0
    log_step = 100
    start = time.time()

    for step, batch in enumerate(train_loader):
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        input_ids, input_mask, segment_ids, video, video_mask = batch

        with torch.cuda.amp.autocast(enabled=not args.no_amp):
            loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        scaler.scale(loss).backward()
        total_loss += float(loss)

        if (step + 1) % args.gradient_accumulation_steps == 0:
            # FIX: unscale gradients before clipping, otherwise the norm is
            # computed on loss-scaled gradients and the 1.0 threshold is
            # effectively meaningless under AMP (per torch.cuda.amp docs).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()  # must run after optimizer.step()

            # Stabilize CLIP's temperature: cap logit_scale at ln(100).
            if hasattr(model, 'clip'):
                torch.clamp_(model.clip.logit_scale.data, max=np.log(100))
            elif hasattr(model, 'module') and hasattr(model.module, 'clip'):
                torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))

        if (step + 1) % log_step == 0:
            logger.info(f"[train] epoch {epoch+1} step {step+1}/{len(train_loader)} "
                        f"loss={float(loss):.4f} time/step={(time.time()-start)/log_step:.4f}")
            start = time.time()

    return total_loss / len(train_loader)


def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out):
    """Compute the full text-vs-video similarity matrix, one text batch row at
    a time, from cached sequence/visual outputs. Returns a list of np rows."""
    sim_matrix = []
    for idx1, b1 in enumerate(batch_list_t):
        input_mask, segment_ids = b1
        sequence_output = batch_seq_out[idx1]
        each_row = []
        for idx2, b2 in enumerate(batch_list_v):
            video_mask = b2[0]
            visual_output = batch_vis_out[idx2]
            logits, *_ = model.get_similarity_logits(sequence_output, visual_output,
                                                     input_mask, video_mask,
                                                     loose_type=model.loose_type)
            each_row.append(logits.cpu().detach().numpy())
        sim_matrix.append(np.concatenate(each_row, axis=-1))
    return sim_matrix


@torch.no_grad()
def eval_epoch(args, model, test_loader, device, logger, use_cache: bool = True):
    """Evaluate retrieval metrics; returns text-to-video R@1.

    Feature extraction is cached to disk. `use_cache=False` forces
    recomputation (and overwrites the cache) — required when the model's
    weights have changed since the cache was written, e.g. between training
    epochs; otherwise stale features would be evaluated every epoch.
    """
    # Build the cache filename from the active feature flags.
    suffix = ""
    if getattr(args, "use_clip4hashing", False):
        suffix += "_hash"
    if args.use_rff:
        suffix += "_rff"
    if args.init_model:
        suffix += "_trained"
    if "train" in args.val_csv and "10k" in args.val_csv:
        cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
    else:
        cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
    cache_path = os.path.join(args.output_dir, cache_name)

    model.eval()
    if use_cache and os.path.exists(cache_path):
        logger.info(f"[Eval] load cached features: {cache_path}")
        cache = torch.load(cache_path, map_location=device)
        batch_seq_out = cache['batch_sequence_output_list']
        batch_vis_out = cache['batch_visual_output_list']
        batch_list_t = cache['batch_list_t']
        batch_list_v = cache['batch_list_v']
    else:
        logger.info("[Eval] caching features...")
        batch_list_t, batch_list_v = [], []
        batch_seq_out, batch_vis_out = [], []
        for bid, batch in enumerate(test_loader):
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask = batch
            with torch.cuda.amp.autocast(enabled=not args.no_amp):
                seq_out, vis_out = model.get_sequence_visual_output(
                    input_ids, segment_ids, input_mask, video, video_mask)
            batch_seq_out.append(seq_out)
            batch_vis_out.append(vis_out)
            batch_list_t.append((input_mask, segment_ids))
            batch_list_v.append((video_mask,))
            if (bid + 1) % 20 == 0:
                logger.info(f"[Eval] cached batch {bid+1}/{len(test_loader)}")
        torch.save({
            'batch_sequence_output_list': batch_seq_out,
            'batch_visual_output_list': batch_vis_out,
            'batch_list_t': batch_list_t,
            'batch_list_v': batch_list_v,
        }, cache_path)
        logger.info(f"[Eval] saved cache to {cache_path}")

    sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out)
    sim_matrix = np.concatenate(sim_matrix, axis=0)
    logger.info(f"[Eval] sim_matrix shape: {sim_matrix.shape}")

    # Heatmap (best effort; never fail evaluation over plotting).
    try:
        plt.figure(figsize=(8, 6))
        plt.imshow(sim_matrix[:100, :100], aspect='auto')
        plt.title('Similarity Matrix (first 100x100)')
        plt.xlabel('Video Index'); plt.ylabel('Text Index')
        out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
        plt.tight_layout(); plt.savefig(out_path); plt.close()
        logger.info(f"[Eval] heatmap saved: {out_path}")
    except Exception as e:
        logger.info(f"[Eval] heatmap skipped: {e}")

    tv = compute_metrics(sim_matrix)
    vt = compute_metrics(sim_matrix.T)
    logger.info(f"Text-to-Video: R@1 {tv['R1']:.1f} | R@5 {tv['R5']:.1f} | R@10 {tv['R10']:.1f} | MR {tv['MR']:.1f} | MeanR {tv['MeanR']:.1f}")
    logger.info(f"Video-to-Text: R@1 {vt['R1']:.1f} | R@5 {vt['R5']:.1f} | R@10 {vt['R10']:.1f} | MR {vt['MR']:.1f} | MeanR {vt['MeanR']:.1f}")
    return tv['R1']


# -----------------------
# 6) Main
# -----------------------
def main():
    args = get_args()
    logger, device = setup_env(args)
    assert args.task_type == "retrieval"

    tokenizer = ClipTokenizer()
    model = init_model(args, device)

    # Data loaders (reuse the existing factory).
    assert args.datatype in DATALOADER_DICT
    test_loader, test_len = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_loader, test_len = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)
    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_loader, val_len = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_loader, val_len = test_loader, test_len
    if test_loader is None:  # fall back to validation when no test split exists
        test_loader, test_len = val_loader, val_len

    if args.do_train:
        train_loader, train_len, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)

        # Safe pin_memory: only meaningful when CUDA is available.
        if hasattr(train_loader, "pin_memory") and args.pin_memory and not torch.cuda.is_available():
            try:
                train_loader.pin_memory = False
            except Exception:
                pass

        steps_per_epoch = len(train_loader)
        num_train_steps = (steps_per_epoch * args.epochs) // max(1, args.gradient_accumulation_steps)
        optimizer, scheduler = prep_optimizer(args, model, num_train_steps)
        scaler = torch.cuda.amp.GradScaler(enabled=not args.no_amp)

        logger.info(f"[Train] examples={train_len} batch_size={args.batch_size} steps/epoch={steps_per_epoch} total_steps={num_train_steps}")

        best_r1 = -1.0
        if args.resume_model:
            # NOTE(review): only the optimizer state is restored here — model
            # weights come from --init_model and the scheduler restarts from
            # step 0. Confirm this is the intended resume semantics.
            ckpt = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            logger.info(f"[Train] resumed optimizer from {args.resume_model}")

        for epoch in range(args.epochs):
            loss = train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger)
            logger.info(f"[Train] epoch {epoch+1}/{args.epochs} loss={loss:.4f}")
            # Quick validation on the test set (same flow as the original code).
            # FIX: use_cache=False — the weights change every epoch, so reusing
            # the on-disk feature cache would evaluate stale features forever.
            r1 = eval_epoch(args, model, test_loader, device, logger, use_cache=False)
            if r1 > best_r1:
                best_r1 = r1
                model_path = os.path.join(args.output_dir, "pytorch_model.bin.best")
                torch.save((model.module if hasattr(model, 'module') else model).state_dict(), model_path)
                opt_path = os.path.join(args.output_dir, "pytorch_opt.bin.best")
                torch.save({'epoch': epoch,
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': loss}, opt_path)
                logger.info(f"[Train] new best R1={best_r1:.2f} saved: {model_path}")

    if args.do_eval:
        eval_epoch(args, model, test_loader, device, logger)


if __name__ == "__main__":
    main()