# test / main.py — initial upload by jaewooo (commit de15dc5, verified)
# train_single_gpu.py
from __future__ import annotations
import os, time, random, argparse, math
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
# (removed) from transformers import get_cosine_schedule_with_warmup
import matplotlib.pyplot as plt
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
from util import get_logger
from dataloaders.data_dataloaders import DATALOADER_DICT
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
# -----------------------
# 1) Arguments (์ •๋ฆฌ๋ณธ)
# -----------------------
def get_args(description='CLIP4Clip on Retrieval Task (Single GPU Minimal)'):
    """Parse and validate command-line arguments for single-GPU CLIP4Clip retrieval.

    Returns:
        argparse.Namespace: parsed arguments. `batch_size` is rewritten to the
        per-step micro-batch size (global batch size // gradient_accumulation_steps).

    Raises:
        ValueError: if neither --do_train nor --do_eval is given, or if
            --gradient_accumulation_steps is < 1.
    """
    p = argparse.ArgumentParser(description=description)
    # Core behavior flags
    p.add_argument("--do_train", action="store_true")
    p.add_argument("--do_eval", action="store_true")
    # Data / output paths
    p.add_argument('--train_csv', type=str, default='data/.train.csv')
    p.add_argument('--val_csv', type=str, default='data/.val.csv')
    p.add_argument('--data_path', type=str, default='data/caption.pickle')
    p.add_argument('--features_path', type=str, default='data/videos_feature.pickle')
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--cache_dir", type=str, default="")
    # Hyperparameters
    p.add_argument('--epochs', type=int, default=20)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--batch_size', type=int, default=256)
    p.add_argument('--batch_size_val', type=int, default=3500)
    p.add_argument('--warmup_proportion', type=float, default=0.1)
    p.add_argument('--gradient_accumulation_steps', type=int, default=1)
    p.add_argument('--lr_decay', type=float, default=0.9)  # (possibly unused)
    p.add_argument('--seed', type=int, default=42)
    # Model / behavior options
    p.add_argument("--task_type", default="retrieval", type=str)
    p.add_argument("--datatype", default="msrvtt", type=str)
    p.add_argument("--cross_model", default="cross-base", type=str)
    p.add_argument("--init_model", default=None, type=str)    # load initial weights
    p.add_argument("--resume_model", default=None, type=str)  # resume incl. optimizer state
    # CLIP-related / header options, kept as in the original code
    p.add_argument('--max_words', type=int, default=20)
    p.add_argument('--max_frames', type=int, default=100)
    p.add_argument('--feature_framerate', type=int, default=1)
    p.add_argument('--margin', type=float, default=0.1)
    p.add_argument('--hard_negative_rate', type=float, default=0.5)
    p.add_argument('--negative_weighting', type=int, default=1)
    p.add_argument('--n_pair', type=int, default=1)
    p.add_argument('--num_thread_reader', type=int, default=1)
    p.add_argument('--text_num_hidden_layers', type=int, default=12)
    p.add_argument('--visual_num_hidden_layers', type=int, default=12)
    p.add_argument('--cross_num_hidden_layers', type=int, default=4)
    p.add_argument('--loose_type', action='store_true')
    p.add_argument('--expand_msrvtt_sentences', action='store_true')
    p.add_argument('--train_frame_order', type=int, default=0, choices=[0,1,2])
    p.add_argument('--eval_frame_order', type=int, default=0, choices=[0,1,2])
    p.add_argument('--freeze_layer_num', type=int, default=0)
    p.add_argument('--slice_framepos', type=int, default=0, choices=[0,1,2])
    p.add_argument('--linear_patch', type=str, default="2d", choices=["2d","3d"])
    p.add_argument('--sim_header', type=str, default="meanP",
                   choices=["meanP","seqLSTM","seqTransf","tightTransf"])
    p.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str)
    # Extension flags (kept as-is)
    p.add_argument("--use_rff", action='store_true')
    p.add_argument("--rff_dim", type=int, default=3000)
    p.add_argument("--use_clip4hashing", action="store_true")
    p.add_argument("--hash_bit", type=int, default=2048)
    # Quality / performance options
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--pin_memory", action="store_true")
    p.add_argument("--no_amp", action="store_true", help="AMP ๋„๊ธฐ")
    args = p.parse_args()
    # tightTransf uses the cross-attention path, which is incompatible with loose_type.
    if args.sim_header == "tightTransf":
        args.loose_type = False
    if not args.do_train and not args.do_eval:
        raise ValueError("`--do_train` ๋˜๋Š” `--do_eval` ์ค‘ ํ•˜๋‚˜๋Š” ๋ฐ˜๋“œ์‹œ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.")
    # BUGFIX: guard against 0 (ZeroDivisionError below) or negative values.
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid --gradient_accumulation_steps: {args.gradient_accumulation_steps}, must be >= 1")
    # CLI batch size is the effective (global) one; convert to per-step micro-batch size.
    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    return args
# -----------------------
# 2) Seed/Logger/Device
# -----------------------
def setup_env(args):
    """Prepare the run environment: output dir, logging, RNG seeds, and device.

    Returns:
        (logger, device): a configured logger writing to `output_dir/log.txt`
        and the torch device (CUDA when available, else CPU).
    """
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, "log.txt"))
    # Seed every RNG source we rely on, for (approximate) reproducibility.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(args.seed)
    # Favors throughput; set to False if bitwise reproducibility is needed.
    torch.backends.cudnn.benchmark = True
    # Enable TF32 matmuls where supported (Ampere+); older torch lacks this API.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device}, cuda_available={torch.cuda.is_available()}")
    # Dump the full argument namespace, sorted for easy diffing between runs.
    for key in sorted(vars(args)):
        logger.info(f"{key}: {getattr(args, key)}")
    return logger, device
# -----------------------
# 3) Model
# -----------------------
def init_model(args, device):
    """Build a CLIP4Clip model, optionally loading initial weights, and freeze lower CLIP layers.

    Returns the model already moved to `device`.
    """
    # NOTE(review): torch.load on a user-supplied path uses full pickle — only load trusted checkpoints.
    state = torch.load(args.init_model, map_location='cpu') if args.init_model else None
    cache_dir = args.cache_dir or os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=state, task_config=args)
    model.to(device)
    # Optional layer freezing: -1 disables freezing entirely, 0..12 freezes
    # transformer blocks *below* that index (and non-block CLIP params).
    assert -1 <= args.freeze_layer_num <= 12
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, p in model.clip.named_parameters():
            # Top projection / final-norm / temperature parameters always stay trainable.
            if name.startswith(("ln_final","text_projection","logit_scale","visual.ln_post","visual.proj")):
                continue
            elif ("visual.transformer.resblocks." in name) or ("transformer.resblocks." in name):
                # Parse the block index out of e.g. "transformer.resblocks.7.attn...".
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                # Blocks at or above freeze_layer_num remain trainable.
                if layer_num >= args.freeze_layer_num:
                    continue
            if args.linear_patch == "3d" and "conv2." in name:
                # Presumably conv2 is newly introduced by the 3d patch mode and must train — TODO confirm.
                continue
            p.requires_grad = False
    return model
# -----------------------
# 4) Optimizer & Scheduler (PyTorch-only warmup+cosine)
# -----------------------
def prep_optimizer(args, model, num_training_steps, weight_decay=0.2):
    """Build AdamW (decay / no-decay param groups) and a warmup+cosine LR schedule.

    Args:
        args: parsed arguments; uses `lr` and `warmup_proportion`.
        model: model whose trainable parameters are optimized; a `DataParallel`
            wrapper (with `.module`) is transparently unwrapped.
        num_training_steps: total number of optimizer steps over training.
        weight_decay: decay applied to non-bias / non-LayerNorm parameters.
            Defaults to 0.2 (the previously hard-coded value), so existing
            callers are unaffected.

    Returns:
        (optimizer, scheduler): the scheduler must be stepped once per
        optimizer step (not per batch) so warmup/decay align with
        `num_training_steps`.
    """
    if hasattr(model, 'module'):
        model = model.module
    # Bias and LayerNorm parameters conventionally receive no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    decay_params = [p for n, p in trainable if not any(nd in n for nd in no_decay)]
    nodecay_params = [p for n, p in trainable if any(nd in n for nd in no_decay)]
    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': weight_decay, 'lr': args.lr},
        {'params': nodecay_params, 'weight_decay': 0.0, 'lr': args.lr},
    ], lr=args.lr)
    warmup_steps = int(num_training_steps * args.warmup_proportion)

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup from 0 up to the base LR.
            return float(current_step) / max(1, warmup_steps)
        # Cosine decay from the base LR down to 0 over the remaining steps.
        progress = float(current_step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
# -----------------------
# 5) Train/Eval
# -----------------------
def train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger):
    """Run one AMP-aware training epoch.

    Args:
        epoch: zero-based epoch index (used only for logging).
        args: needs `no_amp` and `gradient_accumulation_steps`.
        model: forward(input_ids, segment_ids, input_mask, video, video_mask) -> scalar loss.
        train_loader: iterable of 5-tuples of tensors in the order
            (input_ids, input_mask, segment_ids, video, video_mask).
        scaler: torch.cuda.amp.GradScaler (may be disabled via --no_amp).

    Returns:
        Mean (accumulation-scaled) loss per batch over the epoch.
    """
    model.train()
    total_loss = 0.0
    log_step = 100
    start = time.time()
    for step, batch in enumerate(train_loader):
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        input_ids, input_mask, segment_ids, video, video_mask = batch
        with torch.cuda.amp.autocast(enabled=not args.no_amp):
            loss = model(input_ids, segment_ids, input_mask, video, video_mask)
        if args.gradient_accumulation_steps > 1:
            # Scale so the summed gradient matches a full-batch gradient.
            loss = loss / args.gradient_accumulation_steps
        scaler.scale(loss).backward()
        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:
            # BUGFIX: gradients must be unscaled before clipping; otherwise the
            # clip threshold of 1.0 is applied to AMP-scaled gradients and the
            # effective clipping norm is arbitrary. No-op when AMP is disabled.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()  # called after optimizer.step(), as PyTorch requires
            # Keep CLIP's learned temperature bounded for numerical stability.
            if hasattr(model, 'clip'):
                torch.clamp_(model.clip.logit_scale.data, max=np.log(100))
            elif hasattr(model, 'module') and hasattr(model.module, 'clip'):
                torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))
        if (step + 1) % log_step == 0:
            logger.info(f"[train] epoch {epoch+1} step {step+1}/{len(train_loader)} "
                        f"loss={float(loss):.4f} time/step={(time.time()-start)/log_step:.4f}")
            start = time.time()
    return total_loss / len(train_loader)
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out):
sim_matrix = []
for idx1, b1 in enumerate(batch_list_t):
input_mask, segment_ids = b1
sequence_output = batch_seq_out[idx1]
each_row = []
for idx2, b2 in enumerate(batch_list_v):
video_mask = b2[0]
visual_output = batch_vis_out[idx2]
logits, *_ = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
loose_type=model.loose_type)
each_row.append(logits.cpu().detach().numpy())
sim_matrix.append(np.concatenate(each_row, axis=-1))
return sim_matrix
@torch.no_grad()
def eval_epoch(args, model, test_loader, device, logger):
    """Evaluate retrieval quality on `test_loader` and return text-to-video R@1.

    Sequence/visual features are cached to disk on the first run and reloaded
    afterwards; a 100x100 similarity heatmap is saved as a best-effort artifact.
    """
    # Build the cache file name from the datatype and active feature flags.
    # NOTE(review): the cache key ignores model weights — when called every
    # training epoch (as in main), later epochs reuse the first epoch's cached
    # features; confirm that is intended.
    suffix = ""
    if getattr(args, "use_clip4hashing", False): suffix += "_hash"
    if args.use_rff: suffix += "_rff"
    if args.init_model: suffix += "_trained"
    if "train" in args.val_csv and "10k" in args.val_csv:
        cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
    else:
        cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
    cache_path = os.path.join(args.output_dir, cache_name)
    model.eval()
    if os.path.exists(cache_path):
        # NOTE(review): torch.load uses pickle — only reuse caches this process created.
        logger.info(f"[Eval] load cached features: {cache_path}")
        cache = torch.load(cache_path, map_location=device)
        batch_seq_out = cache['batch_sequence_output_list']
        batch_vis_out = cache['batch_visual_output_list']
        batch_list_t = cache['batch_list_t']
        batch_list_v = cache['batch_list_v']
    else:
        # First pass: encode every batch once and persist the features.
        logger.info("[Eval] caching features...")
        batch_list_t, batch_list_v = [], []
        batch_seq_out, batch_vis_out = [], []
        for bid, batch in enumerate(test_loader):
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask = batch
            with torch.cuda.amp.autocast(enabled=not args.no_amp):
                seq_out, vis_out = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
            batch_seq_out.append(seq_out)
            batch_vis_out.append(vis_out)
            batch_list_t.append((input_mask, segment_ids))
            batch_list_v.append((video_mask,))
            if (bid+1) % 20 == 0:
                logger.info(f"[Eval] cached batch {bid+1}/{len(test_loader)}")
        torch.save({
            'batch_sequence_output_list': batch_seq_out,
            'batch_visual_output_list': batch_vis_out,
            'batch_list_t': batch_list_t,
            'batch_list_v': batch_list_v,
        }, cache_path)
        logger.info(f"[Eval] saved cache to {cache_path}")
    sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_seq_out, batch_vis_out)
    sim_matrix = np.concatenate(sim_matrix, axis=0)
    logger.info(f"[Eval] sim_matrix shape: {sim_matrix.shape}")
    # Heatmap (optional, best-effort: e.g. headless matplotlib failures are logged and skipped)
    try:
        plt.figure(figsize=(8,6))
        plt.imshow(sim_matrix[:100, :100], aspect='auto')
        plt.title('Similarity Matrix (first 100x100)')
        plt.xlabel('Video Index'); plt.ylabel('Text Index')
        out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
        plt.tight_layout(); plt.savefig(out_path); plt.close()
        logger.info(f"[Eval] heatmap saved: {out_path}")
    except Exception as e:
        logger.info(f"[Eval] heatmap skipped: {e}")
    # Rows = text queries, columns = videos; transpose gives video-to-text.
    tv = compute_metrics(sim_matrix)
    vt = compute_metrics(sim_matrix.T)
    logger.info(f"Text-to-Video: R@1 {tv['R1']:.1f} | R@5 {tv['R5']:.1f} | R@10 {tv['R10']:.1f} | MR {tv['MR']:.1f} | MeanR {tv['MeanR']:.1f}")
    logger.info(f"Video-to-Text: R@1 {vt['R1']:.1f} | R@5 {vt['R5']:.1f} | R@10 {vt['R10']:.1f} | MR {vt['MR']:.1f} | MeanR {vt['MeanR']:.1f}")
    return tv['R1']
# -----------------------
# 6) Main
# -----------------------
def main():
    """Entry point: parse args, build the model and dataloaders, then train and/or evaluate."""
    args = get_args()
    logger, device = setup_env(args)
    assert args.task_type == "retrieval"
    tokenizer = ClipTokenizer()
    model = init_model(args, device)
    # Dataloaders (reuse the existing factory dict as-is)
    assert args.datatype in DATALOADER_DICT
    test_loader, test_len = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_loader, test_len = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)
    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_loader, val_len = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_loader, val_len = test_loader, test_len
    if test_loader is None:  # if there is no test split, fall back to validation
        test_loader, test_len = val_loader, val_len
    if args.do_train:
        train_loader, train_len, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        # Safe pin_memory: only keep it enabled when CUDA is actually available
        if hasattr(train_loader, "pin_memory") and args.pin_memory and not torch.cuda.is_available():
            try:
                train_loader.pin_memory = False
            except Exception:
                pass
        steps_per_epoch = len(train_loader)
        # Total optimizer steps (scheduler is stepped once per optimizer step).
        num_train_steps = (steps_per_epoch * args.epochs) // max(1, args.gradient_accumulation_steps)
        optimizer, scheduler = prep_optimizer(args, model, num_train_steps)
        scaler = torch.cuda.amp.GradScaler(enabled=not args.no_amp)
        logger.info(f"[Train] examples={train_len} batch_size={args.batch_size} steps/epoch={steps_per_epoch} total_steps={num_train_steps}")
        best_r1 = -1.0
        if args.resume_model:
            # NOTE(review): resume restores only the optimizer state here;
            # model weights are expected to come from --init_model.
            ckpt = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            logger.info(f"[Train] resumed optimizer from {args.resume_model}")
        for epoch in range(args.epochs):
            loss = train_epoch(epoch, args, model, train_loader, device, optimizer, scheduler, scaler, logger)
            logger.info(f"[Train] epoch {epoch+1}/{args.epochs} loss={loss:.4f}")
            # Quick validation: use the test set directly (same flow as the original code)
            r1 = eval_epoch(args, model, test_loader, device, logger)
            if r1 > best_r1:
                # Persist best-by-R@1 model and optimizer checkpoints.
                best_r1 = r1
                model_path = os.path.join(args.output_dir, f"pytorch_model.bin.best")
                torch.save((model.module if hasattr(model,'module') else model).state_dict(), model_path)
                opt_path = os.path.join(args.output_dir, f"pytorch_opt.bin.best")
                torch.save({'epoch': epoch, 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss}, opt_path)
                logger.info(f"[Train] new best R1={best_r1:.2f} saved: {model_path}")
    if args.do_eval:
        eval_epoch(args, model, test_loader, device, logger)
# Script entry point.
if __name__ == "__main__":
    main()