Initial upload
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .claude/settings.local.json +15 -0
- .gitattributes +20 -0
- .gitignore +2 -0
- .vscode/settings.json +5 -0
- CLAUDE.md +137 -0
- CLIP4Clip.png +3 -0
- LICENSE +21 -0
- README.md +193 -0
- __pycache__/metrics.cpython-312.pyc +0 -0
- __pycache__/metrics.cpython-37.pyc +0 -0
- __pycache__/metrics.cpython-39.pyc +0 -0
- __pycache__/simple_dataloaders.cpython-37.pyc +0 -0
- __pycache__/util.cpython-312.pyc +0 -0
- __pycache__/util.cpython-37.pyc +0 -0
- cache_main_task_retrieval.py +1053 -0
- cache_main_task_retrieval_backup.py +867 -0
- ckpts/cache_train_9k/log.txt +115 -0
- ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt +3 -0
- ckpts/cache_train_9k/sim_matrix_heatmap.png +0 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/log.txt +0 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache.pt +3 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_binary_trained.pt +3 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_trained.pt +3 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0 +3 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0 +3 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/sim_matrix_heatmap.png +0 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.json +1 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.tsv +1 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.json +1 -0
- ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.tsv +1 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/log.txt +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache.pt +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_binary_trained.pt +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_trained.pt +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_trained.pt +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4 +3 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/sim_matrix_heatmap.png +0 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.json +1 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.tsv +1 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.json +1 -0
- ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.tsv +1 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(grep:*)",
|
| 5 |
+
"Bash(python:*)",
|
| 6 |
+
"Bash(conda create:*)",
|
| 7 |
+
"Bash(conda activate:*)",
|
| 8 |
+
"Bash(pip install:*)",
|
| 9 |
+
"Bash(source:*)",
|
| 10 |
+
"Bash(pkill:*)",
|
| 11 |
+
"Bash(chmod:*)"
|
| 12 |
+
],
|
| 13 |
+
"deny": []
|
| 14 |
+
}
|
| 15 |
+
}
|
.gitattributes
CHANGED
|
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
CLIP4Clip.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/log.txt filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4 filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.1 filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.2 filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.1 filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.2 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.idea
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
| 3 |
+
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
| 4 |
+
"python-envs.pythonProjects": []
|
| 5 |
+
}
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
CLIP4Clip is a video-text retrieval model based on OpenAI's CLIP (ViT-B). The project investigates three similarity calculation approaches: parameter-free type, sequential type, and tight type. This repository has been extended with additional features including hashing, hypervector representations, and random Fourier features (RFF).
|
| 8 |
+
|
| 9 |
+
## Environment Setup
|
| 10 |
+
|
| 11 |
+
The project uses conda for environment management:
|
| 12 |
+
```bash
|
| 13 |
+
conda env create -f environment.yml
|
| 14 |
+
conda activate c4c
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
Key dependencies: PyTorch 1.7.1, CUDA 11.0, opencv-python, tqdm, ftfy, regex, pandas, boto3
|
| 18 |
+
|
| 19 |
+
## Development Commands
|
| 20 |
+
|
| 21 |
+
### Training Commands
|
| 22 |
+
- **Standard training**: `./train.sh` - Trains CLIP4Clip with hashing extensions
|
| 23 |
+
- **Distributed training**: Uses `torch.distributed.launch` with multiple GPUs
|
| 24 |
+
- **Key training script**: `cache_main_task_retrieval.py` (main training entry point)
|
| 25 |
+
- **Original training script**: `main_task_retrieval.py` (reference implementation)
|
| 26 |
+
|
| 27 |
+
### Testing/Evaluation Commands
|
| 28 |
+
- **Standard evaluation**: `./test.sh` - Evaluates with RFF features
|
| 29 |
+
- **Hash evaluation**: `./test_hash.sh` - Evaluates hashing-based models
|
| 30 |
+
- **Hypervector evaluation**: `./test_hv.sh` - Evaluates with high-dimensional RFF
|
| 31 |
+
- All test scripts use `cache_main_task_retrieval.py --do_eval`
|
| 32 |
+
|
| 33 |
+
### Data Preprocessing
|
| 34 |
+
```bash
|
| 35 |
+
python preprocess/compress_video.py --input_root [raw_video_path] --output_root [compressed_video_path]
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### CLIP Model Downloads
|
| 39 |
+
```bash
|
| 40 |
+
# ViT-B/32 (default)
|
| 41 |
+
wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
|
| 42 |
+
|
| 43 |
+
# ViT-B/16 (better performance)
|
| 44 |
+
wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Architecture Overview
|
| 48 |
+
|
| 49 |
+
### Core Components
|
| 50 |
+
- **Main Models**:
|
| 51 |
+
- `modules/modeling.py` - CLIP4Clip main model class
|
| 52 |
+
- `modules/module_clip.py` - CLIP model implementation
|
| 53 |
+
- `modules/module_cross.py` - Cross-modal interaction modules
|
| 54 |
+
|
| 55 |
+
- **Extended Features**:
|
| 56 |
+
- `modules/hashnet.py` - Hash-based representations
|
| 57 |
+
- `modules/hypervector.py` - Hypervector/random projection methods
|
| 58 |
+
- `modules/binarize_ste.py` - Straight-through estimator for binarization
|
| 59 |
+
- `modules/minmax_hash.py` - MinMax hashing implementation
|
| 60 |
+
|
| 61 |
+
### Data Loading
|
| 62 |
+
- **Dataloaders**: `dataloaders/data_dataloaders.py` provides unified interface
|
| 63 |
+
- **Supported datasets**: MSR-VTT, MSVD, LSMDC, ActivityNet, DiDeMo
|
| 64 |
+
- **Key parameters**: `--datatype`, `--features_path`, `--max_frames`, `--max_words`
|
| 65 |
+
|
| 66 |
+
### Training Configuration
|
| 67 |
+
- **Similarity headers**: `meanP` (parameter-free), `seqLSTM`, `seqTransf`, `tightTransf`
|
| 68 |
+
- **Linear patch**: `2d` or `3d` patch projection
|
| 69 |
+
- **Frame sampling**: `--slice_framepos` (0=head, 1=tail, 2=uniform)
|
| 70 |
+
- **Model variants**: Standard CLIP4Clip, with hashing (`--use_clip4hashing`), with RFF (`--use_rff`)
|
| 71 |
+
|
| 72 |
+
### Model Checkpoints
|
| 73 |
+
- Saved in `ckpts/` directory with model versions
|
| 74 |
+
- Format: `pytorch_model.bin.{epoch}` and `pytorch_opt.bin.{epoch}`
|
| 75 |
+
- Cache files for evaluation: `*_eval_cache.pt`, `*_eval_cache_trained.pt`
|
| 76 |
+
|
| 77 |
+
## Key Parameters for Development
|
| 78 |
+
|
| 79 |
+
### Essential Arguments
|
| 80 |
+
- `--datatype`: Dataset type (msrvtt, msvd, lsmdc, activity, didemo)
|
| 81 |
+
- `--features_path`: Path to video features/files
|
| 82 |
+
- `--output_dir`: Checkpoint output directory
|
| 83 |
+
- `--sim_header`: Similarity calculation method
|
| 84 |
+
- `--pretrained_clip_name`: CLIP model version (ViT-B/32 or ViT-B/16)
|
| 85 |
+
|
| 86 |
+
### Extended Features
|
| 87 |
+
- `--use_clip4hashing --hash_bit 2048`: Enable hashing with specified bit size
|
| 88 |
+
- `--use_rff --rff_dim 3000`: Enable Random Fourier Features with dimension
|
| 89 |
+
- `--freeze_layer_num`: Number of CLIP layers to freeze (0-12)
|
| 90 |
+
|
| 91 |
+
### Training Parameters
|
| 92 |
+
- Distributed training uses `torch.distributed.launch`
|
| 93 |
+
- Typical batch sizes: 128-512 for training, 16 for validation
|
| 94 |
+
- Learning rates: 1e-4 (base), 1e-3 (coef_lr for CLIP parameters)
|
| 95 |
+
- Standard training: 5-50 epochs depending on dataset
|
| 96 |
+
|
| 97 |
+
## Additional Development Information
|
| 98 |
+
|
| 99 |
+
### Environment and Dependencies
|
| 100 |
+
```bash
|
| 101 |
+
# Create conda environment
|
| 102 |
+
conda env create -f environment.yml
|
| 103 |
+
conda activate c4c
|
| 104 |
+
|
| 105 |
+
# Install additional dependencies if needed
|
| 106 |
+
pip install ftfy regex tqdm opencv-python boto3 requests pandas
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### Quick Start Commands
|
| 110 |
+
- **Setup environment**: `conda env create -f environment.yml && conda activate c4c`
|
| 111 |
+
- **Single GPU training**: `python cache_main_task_retrieval.py --do_train [args]`
|
| 112 |
+
- **Multi-GPU training**: `python -m torch.distributed.launch --nproc_per_node=2 cache_main_task_retrieval.py --do_train [args]`
|
| 113 |
+
- **Evaluation only**: `python cache_main_task_retrieval.py --do_eval [args]`
|
| 114 |
+
|
| 115 |
+
### Debug and Development
|
| 116 |
+
- **Check GPU availability**: Use `torch.cuda.is_available()` in Python
|
| 117 |
+
- **Monitor training**: Logs are saved in output_dir with format `log.txt`
|
| 118 |
+
- **Resume training**: Use `--resume_model` with `--init_model` to continue from checkpoint
|
| 119 |
+
- **Video preprocessing**: Optional compression with `preprocess/compress_video.py` for 3fps/224px
|
| 120 |
+
|
| 121 |
+
### Performance Tips
|
| 122 |
+
- **Batch sizes**: Start with 128 for training, 16 for validation; adjust based on GPU memory
|
| 123 |
+
- **Frame sampling**: `--slice_framepos 2` (uniform) generally works best
|
| 124 |
+
- **CLIP model**: ViT-B/16 provides better performance than ViT-B/32 but requires more memory
|
| 125 |
+
- **Frozen layers**: `--freeze_layer_num 12` freezes all CLIP layers for faster training
|
| 126 |
+
|
| 127 |
+
### Data Structure Expectations
|
| 128 |
+
- **Video files**: Raw videos or compressed (3fps, 224px recommended)
|
| 129 |
+
- **CSV format**: Training/validation splits follow MSRVTT format
|
| 130 |
+
- **JSON data**: Captions and metadata in standard video-text retrieval format
|
| 131 |
+
- **Features path**: Directory containing video files, organized by dataset type
|
| 132 |
+
|
| 133 |
+
### Common Issues and Solutions
|
| 134 |
+
- **CUDA out of memory**: Reduce batch_size or max_frames/max_words
|
| 135 |
+
- **Slow data loading**: Increase num_thread_reader (but not too high)
|
| 136 |
+
- **Poor performance**: Check pretrained CLIP weights are loaded correctly
|
| 137 |
+
- **Resume failures**: Ensure both pytorch_model.bin and pytorch_opt.bin exist for the epoch
|
CLIP4Clip.png
ADDED
|
Git LFS Details
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 ArrowLuo
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval
|
| 2 |
+
|
| 3 |
+
(**July 28, 2021**) Add ViT-B/16 with an extra `--pretrained_clip_name`
|
| 4 |
+
|
| 5 |
+
(**Apr. 22, 2021**) First version
|
| 6 |
+
|
| 7 |
+
The implementation of paper [**CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval**](https://arxiv.org/abs/2104.08860).
|
| 8 |
+
|
| 9 |
+
CLIP4Clip is a video-text retrieval model based on [CLIP (ViT-B)](https://github.com/openai/CLIP). We investigate three similarity calculation approaches: parameter-free type, sequential type, and tight type, in this work. The model achieve SOTA results on MSR-VTT, MSVD, LSMDC, ActivityNet, and DiDeMo.
|
| 10 |
+
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
## Requirement
|
| 14 |
+
```sh
|
| 15 |
+
# From CLIP
|
| 16 |
+
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
|
| 17 |
+
pip install ftfy regex tqdm
|
| 18 |
+
pip install opencv-python boto3 requests pandas
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Data Preparing
|
| 22 |
+
|
| 23 |
+
**For MSRVTT**
|
| 24 |
+
|
| 25 |
+
The official data and video links can be found in [link](http://ms-multimedia-challenge.com/2017/dataset).
|
| 26 |
+
|
| 27 |
+
For the convenience, you can also download the splits and captions by,
|
| 28 |
+
```sh
|
| 29 |
+
wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msrvtt_data.zip
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Besides, the raw videos can be found in [sharing](https://github.com/m-bain/frozen-in-time#-finetuning-benchmarks-msr-vtt) from *Frozen️ in Time*, i.e.,
|
| 33 |
+
```sh
|
| 34 |
+
wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
**For MSVD**
|
| 38 |
+
|
| 39 |
+
Raw videos can be download from [link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/).
|
| 40 |
+
|
| 41 |
+
The splits and `raw_captions` can be found in the wonderful job [collaborative-experts](https://github.com/albanie/collaborative-experts/blob/master/misc/datasets/msvd/README.md). For the convenience, you can also download them by,
|
| 42 |
+
```sh
|
| 43 |
+
wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msvd_data.zip
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
**For LSMDC**
|
| 47 |
+
|
| 48 |
+
You must obtain permission from MPII to download and use the data. The download link is [here](https://sites.google.com/site/describingmovies/download).
|
| 49 |
+
The 1000 test clips data is [link](http://www.google.com/url?q=http%3A%2F%2Fdatasets.d2.mpi-inf.mpg.de%2FmovieDescription%2Fprotected%2Flsmdc2016%2FLSMDC16_challenge_1000_publictect.csv&sa=D&sntz=1&usg=AFQjCNGIaGVhCeb6zNfUs2UL1zNzoEtaSg). Read our paper and the [dataloader](./dataloaders/dataloader_lsmdc_retrieval.py) for more information.
|
| 50 |
+
|
| 51 |
+
**For ActivityNet**
|
| 52 |
+
|
| 53 |
+
The official websit has made the full dataset available on Google and Baidu drives, see more information at [here](http://activity-net.org/download.html) . The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/activity-net).
|
| 54 |
+
|
| 55 |
+
**For DiDeMo**
|
| 56 |
+
|
| 57 |
+
Raw videos can be download from [LisaAnne/LocalizingMoments](https://github.com/LisaAnne/LocalizingMoments). The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/didemo/README.md).
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
## Compress Video for Speed-up (optional)
|
| 61 |
+
```sh
|
| 62 |
+
python preprocess/compress_video.py --input_root [raw_video_path] --output_root [compressed_video_path]
|
| 63 |
+
```
|
| 64 |
+
This script will compress the video to *3fps* with width *224* (or height *224*). Modify the variables for your customization.
|
| 65 |
+
|
| 66 |
+
## How to Run
|
| 67 |
+
|
| 68 |
+
>`--features_path` is the video root path
|
| 69 |
+
>
|
| 70 |
+
>`--linear_patch` can be set with `2d` or `3d`
|
| 71 |
+
>
|
| 72 |
+
> `--sim_header` can be set with `meanP`, `seqLSTM`, `seqTransf`, or `tightTransf`
|
| 73 |
+
>
|
| 74 |
+
> `--pretrained_clip_name` can be set with `ViT-B/32` or `ViT-B/16`
|
| 75 |
+
>
|
| 76 |
+
> `--resume_model` can be used to reload the saved optimizer state to continuely train the model, **Note**: need to set the corresponding chechpoint via `--init_model` simultaneously.
|
| 77 |
+
|
| 78 |
+
read our paper for more details on `--linear_patch` and `--sim_header`. Test more hyperparameters for better performance.
|
| 79 |
+
|
| 80 |
+
Download CLIP (ViT-B/32) weight,
|
| 81 |
+
```sh
|
| 82 |
+
wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
|
| 83 |
+
```
|
| 84 |
+
or, download CLIP (ViT-B/16) weight,
|
| 85 |
+
```sh
|
| 86 |
+
wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Then, run
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
*The CLIP (ViT-B/32) is the default setting in the paper, replacing with the ViT-B/16 for better performance.*
|
| 93 |
+
|
| 94 |
+
### MSRVTT
|
| 95 |
+
|
| 96 |
+
```sh
|
| 97 |
+
DATA_PATH=[Your MSRVTT data and videos path]
|
| 98 |
+
python -m torch.distributed.launch --nproc_per_node=4 \
|
| 99 |
+
main_task_retrieval.py --do_train --num_thread_reader=0 \
|
| 100 |
+
--epochs=5 --batch_size=128 --n_display=50 \
|
| 101 |
+
--train_csv ${DATA_PATH}/MSRVTT_train.9k.csv \
|
| 102 |
+
--val_csv ${DATA_PATH}/MSRVTT_JSFUSION_test.csv \
|
| 103 |
+
--data_path ${DATA_PATH}/MSRVTT_data.json \
|
| 104 |
+
--features_path ${DATA_PATH}/MSRVTT_Videos \
|
| 105 |
+
--output_dir ckpts/ckpt_msrvtt_retrieval_looseType \
|
| 106 |
+
--lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
|
| 107 |
+
--datatype msrvtt --expand_msrvtt_sentences \
|
| 108 |
+
--feature_framerate 1 --coef_lr 1e-3 \
|
| 109 |
+
--freeze_layer_num 0 --slice_framepos 2 \
|
| 110 |
+
--loose_type --linear_patch 2d --sim_header meanP \
|
| 111 |
+
--pretrained_clip_name ViT-B/32
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
### MSVD
|
| 115 |
+
```sh
|
| 116 |
+
DATA_PATH=[Your MSVD data and videos path]
|
| 117 |
+
python -m torch.distributed.launch --nproc_per_node=4 \
|
| 118 |
+
main_task_retrieval.py --do_train --num_thread_reader=2 \
|
| 119 |
+
--epochs=5 --batch_size=128 --n_display=50 \
|
| 120 |
+
--data_path ${DATA_PATH} \
|
| 121 |
+
--features_path ${DATA_PATH}/MSVD_Videos \
|
| 122 |
+
--output_dir ckpts/ckpt_msvd_retrieval_looseType \
|
| 123 |
+
--lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
|
| 124 |
+
--datatype msvd \
|
| 125 |
+
--feature_framerate 1 --coef_lr 1e-3 \
|
| 126 |
+
--freeze_layer_num 0 --slice_framepos 2 \
|
| 127 |
+
--loose_type --linear_patch 2d --sim_header meanP \
|
| 128 |
+
--pretrained_clip_name ViT-B/32
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
### LSMDC
|
| 132 |
+
```sh
|
| 133 |
+
DATA_PATH=[Your LSMDC data and videos path]
|
| 134 |
+
python -m torch.distributed.launch --nproc_per_node=4 \
|
| 135 |
+
main_task_retrieval.py --do_train --num_thread_reader=2 \
|
| 136 |
+
--epochs=5 --batch_size=128 --n_display=50 \
|
| 137 |
+
--data_path ${DATA_PATH} \
|
| 138 |
+
--features_path ${DATA_PATH}/LSMDC_Videos \
|
| 139 |
+
--output_dir ckpts/ckpt_lsmdc_retrieval_looseType \
|
| 140 |
+
--lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
|
| 141 |
+
--datatype lsmdc --feature_framerate 1 --coef_lr 1e-3 \
|
| 142 |
+
--freeze_layer_num 0 --slice_framepos 2 \
|
| 143 |
+
--loose_type --linear_patch 2d --sim_header meanP \
|
| 144 |
+
--pretrained_clip_name ViT-B/32
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### ActivityNet
|
| 148 |
+
ActivityNet is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node).
|
| 149 |
+
```sh
|
| 150 |
+
DATA_PATH=[Your ActivityNet data and videos path]
|
| 151 |
+
python -m torch.distributed.launch --nproc_per_node=8 \
|
| 152 |
+
main_task_retrieval.py --do_train --num_thread_reader=2 \
|
| 153 |
+
--epochs=5 --batch_size=128 --n_display=50 \
|
| 154 |
+
--data_path ${DATA_PATH} \
|
| 155 |
+
--features_path ${DATA_PATH}/Activity_Videos \
|
| 156 |
+
--output_dir ckpts/ckpt_activity_retrieval_looseType \
|
| 157 |
+
--lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \
|
| 158 |
+
--datatype activity --feature_framerate 1 --coef_lr 1e-3 \
|
| 159 |
+
--freeze_layer_num 0 --slice_framepos 2 \
|
| 160 |
+
--loose_type --linear_patch 2d --sim_header meanP \
|
| 161 |
+
--pretrained_clip_name ViT-B/32
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
### DiDeMo
|
| 165 |
+
DiDeMo is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node).
|
| 166 |
+
```sh
|
| 167 |
+
DATA_PATH=[Your DiDeMo data and videos path]
|
| 168 |
+
python -m torch.distributed.launch --nproc_per_node=8 \
|
| 169 |
+
main_task_retrieval.py --do_train --num_thread_reader=2 \
|
| 170 |
+
--epochs=5 --batch_size=128 --n_display=50 \
|
| 171 |
+
--data_path ${DATA_PATH} \
|
| 172 |
+
--features_path ${DATA_PATH}/DiDeMo_Videos \
|
| 173 |
+
--output_dir ckpts/ckpt_didemo_retrieval_looseType \
|
| 174 |
+
--lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \
|
| 175 |
+
--datatype didemo --feature_framerate 1 --coef_lr 1e-3 \
|
| 176 |
+
--freeze_layer_num 0 --slice_framepos 2 \
|
| 177 |
+
--loose_type --linear_patch 2d --sim_header meanP \
|
| 178 |
+
--pretrained_clip_name ViT-B/32
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
# Citation
|
| 182 |
+
If you find CLIP4Clip useful in your work, you can cite the following paper:
|
| 183 |
+
```bibtex
|
| 184 |
+
@Article{Luo2021CLIP4Clip,
|
| 185 |
+
author = {Huaishao Luo and Lei Ji and Ming Zhong and Yang Chen and Wen Lei and Nan Duan and Tianrui Li},
|
| 186 |
+
title = {{CLIP4Clip}: An Empirical Study of CLIP for End to End Video Clip Retrieval},
|
| 187 |
+
journal = {arXiv preprint arXiv:2104.08860},
|
| 188 |
+
year = {2021},
|
| 189 |
+
}
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
# Acknowledgments
|
| 193 |
+
Our code is based on [CLIP](https://github.com/openai/CLIP) and [UniVL](https://github.com/microsoft/UniVL).
|
__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (4.53 kB). View file
|
|
|
__pycache__/metrics.cpython-37.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
__pycache__/metrics.cpython-39.pyc
ADDED
|
Binary file (2.54 kB). View file
|
|
|
__pycache__/simple_dataloaders.cpython-37.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
__pycache__/util.cpython-312.pyc
ADDED
|
Binary file (4.37 kB). View file
|
|
|
__pycache__/util.cpython-37.pyc
ADDED
|
Binary file (2.36 kB). View file
|
|
|
cache_main_task_retrieval.py
ADDED
|
@@ -0,0 +1,1053 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import unicode_literals
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
import os
|
| 11 |
+
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
|
| 12 |
+
import time
|
| 13 |
+
import argparse
|
| 14 |
+
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
|
| 15 |
+
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
| 16 |
+
from modules.modeling import CLIP4Clip
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
from modules.optimization import BertAdam
|
| 19 |
+
|
| 20 |
+
from util import parallel_apply, get_logger
|
| 21 |
+
from modules.until_module import AllGather
|
| 22 |
+
from dataloaders.data_dataloaders import DATALOADER_DICT
|
| 23 |
+
|
| 24 |
+
# torch.distributed.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
global logger
|
| 27 |
+
|
| 28 |
+
def get_args(description='CLIP4Clip on Retrieval Task'):
|
| 29 |
+
parser = argparse.ArgumentParser(description=description)
|
| 30 |
+
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
|
| 31 |
+
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
|
| 32 |
+
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
|
| 33 |
+
|
| 34 |
+
parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='')
|
| 35 |
+
parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='')
|
| 36 |
+
parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
|
| 37 |
+
parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')
|
| 38 |
+
|
| 39 |
+
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
|
| 40 |
+
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
|
| 41 |
+
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
|
| 42 |
+
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
|
| 43 |
+
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
|
| 44 |
+
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
|
| 45 |
+
parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
|
| 46 |
+
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
|
| 47 |
+
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
| 48 |
+
parser.add_argument('--max_words', type=int, default=20, help='')
|
| 49 |
+
parser.add_argument('--max_frames', type=int, default=100, help='')
|
| 50 |
+
parser.add_argument('--feature_framerate', type=int, default=1, help='')
|
| 51 |
+
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
|
| 52 |
+
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
|
| 53 |
+
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
|
| 54 |
+
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
|
| 55 |
+
|
| 56 |
+
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
| 57 |
+
help="The output directory where the model predictions and checkpoints will be written.")
|
| 58 |
+
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
|
| 59 |
+
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
|
| 60 |
+
parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.")
|
| 61 |
+
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
|
| 62 |
+
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
| 63 |
+
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
|
| 64 |
+
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
| 65 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
| 66 |
+
parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
|
| 67 |
+
|
| 68 |
+
parser.add_argument("--cache_dir", default="", type=str,
|
| 69 |
+
help="Where do you want to store the pre-trained models downloaded from s3")
|
| 70 |
+
|
| 71 |
+
parser.add_argument('--fp16', action='store_true',
|
| 72 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
| 73 |
+
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
| 74 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
| 75 |
+
"See details at https://nvidia.github.io/apex/amp.html")
|
| 76 |
+
|
| 77 |
+
parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
|
| 78 |
+
parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")
|
| 79 |
+
|
| 80 |
+
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
|
| 81 |
+
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
|
| 82 |
+
# alias for torch.distributed.run / launch passing --local-rank
|
| 83 |
+
parser.add_argument("--local-rank", dest="local_rank", default=0, type=int, help="alias for local_rank")
|
| 84 |
+
parser.add_argument("--rank", default=0, type=int, help="distribted training")
|
| 85 |
+
parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
|
| 86 |
+
parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
|
| 87 |
+
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")
|
| 88 |
+
|
| 89 |
+
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
|
| 90 |
+
parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
|
| 91 |
+
parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")
|
| 92 |
+
|
| 93 |
+
parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.")
|
| 94 |
+
parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")
|
| 95 |
+
|
| 96 |
+
parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
|
| 97 |
+
help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
|
| 98 |
+
parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
|
| 99 |
+
help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
|
| 100 |
+
|
| 101 |
+
parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
|
| 102 |
+
parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
|
| 103 |
+
help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
|
| 104 |
+
parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
|
| 105 |
+
help="linear projection of flattened patches.")
|
| 106 |
+
parser.add_argument('--sim_header', type=str, default="meanP",
|
| 107 |
+
choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"],
|
| 108 |
+
help="choice a similarity header.")
|
| 109 |
+
|
| 110 |
+
parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
|
| 111 |
+
parser.add_argument("--use_rff", action='store_true', help="Use RFF hypervector encoding for video embeddings")
|
| 112 |
+
parser.add_argument("--rff_dim", type=int, default=3000, help="Hypervector dimension for RFF encoding")
|
| 113 |
+
parser.add_argument("--use_clip4hashing", action="store_true", help="CLIP4Hashing 손실·해시 경로 사용 여부")
|
| 114 |
+
parser.add_argument("--hash_bit", type=int, default=2048, help="해시 코드 비트 수 (default 1024)")
|
| 115 |
+
# Projection options
|
| 116 |
+
parser.add_argument('--proj', type=int, default=0, help='Projection dim (0 to disable, e.g., 3008)')
|
| 117 |
+
parser.add_argument('--proj_act', type=str, default='tanh', choices=['tanh', 'relu', 'gelu', 'sigmoid'],
|
| 118 |
+
help='Activation after projection')
|
| 119 |
+
parser.add_argument('--binary_eval', action='store_true', help='Use binarized retrieval at eval (sign + sum)')
|
| 120 |
+
|
| 121 |
+
args = parser.parse_args()
|
| 122 |
+
|
| 123 |
+
if args.sim_header == "tightTransf":
|
| 124 |
+
args.loose_type = False
|
| 125 |
+
|
| 126 |
+
# Check paramenters
|
| 127 |
+
if args.gradient_accumulation_steps < 1:
|
| 128 |
+
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
| 129 |
+
args.gradient_accumulation_steps))
|
| 130 |
+
if not args.do_train and not args.do_eval:
|
| 131 |
+
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
| 132 |
+
|
| 133 |
+
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
|
| 134 |
+
|
| 135 |
+
# Accept env fallback if provided by torchrun
|
| 136 |
+
if 'LOCAL_RANK' in os.environ:
|
| 137 |
+
try:
|
| 138 |
+
args.local_rank = int(os.environ['LOCAL_RANK'])
|
| 139 |
+
except Exception:
|
| 140 |
+
pass
|
| 141 |
+
if 'RANK' in os.environ:
|
| 142 |
+
try:
|
| 143 |
+
args.rank = int(os.environ['RANK'])
|
| 144 |
+
except Exception:
|
| 145 |
+
pass
|
| 146 |
+
if 'WORLD_SIZE' in os.environ:
|
| 147 |
+
try:
|
| 148 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 149 |
+
except Exception:
|
| 150 |
+
pass
|
| 151 |
+
|
| 152 |
+
return args
|
| 153 |
+
|
| 154 |
+
def set_seed_logger(args):
|
| 155 |
+
global logger
|
| 156 |
+
# predefining random initial seeds
|
| 157 |
+
random.seed(args.seed)
|
| 158 |
+
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
| 159 |
+
np.random.seed(args.seed)
|
| 160 |
+
torch.manual_seed(args.seed)
|
| 161 |
+
torch.cuda.manual_seed(args.seed)
|
| 162 |
+
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
|
| 163 |
+
torch.backends.cudnn.benchmark = False
|
| 164 |
+
torch.backends.cudnn.deterministic = True
|
| 165 |
+
|
| 166 |
+
world_size = torch.distributed.get_world_size()
|
| 167 |
+
torch.cuda.set_device(args.local_rank)
|
| 168 |
+
args.world_size = world_size
|
| 169 |
+
rank = torch.distributed.get_rank()
|
| 170 |
+
args.rank = rank
|
| 171 |
+
|
| 172 |
+
if not os.path.exists(args.output_dir):
|
| 173 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 174 |
+
|
| 175 |
+
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
|
| 176 |
+
|
| 177 |
+
if args.local_rank == 0:
|
| 178 |
+
logger.info("Effective parameters:")
|
| 179 |
+
for key in sorted(args.__dict__):
|
| 180 |
+
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
|
| 181 |
+
|
| 182 |
+
# Sanity check for binary eval + bit packing compatibility
|
| 183 |
+
if getattr(args, 'binary_eval', False) and getattr(args, 'proj', 0) > 0:
|
| 184 |
+
if args.proj % 64 != 0:
|
| 185 |
+
raise ValueError(f"--proj must be divisible by 64 for binary eval, got {args.proj}")
|
| 186 |
+
|
| 187 |
+
return args
|
| 188 |
+
|
| 189 |
+
def init_device(args, local_rank):
|
| 190 |
+
global logger
|
| 191 |
+
|
| 192 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
|
| 193 |
+
|
| 194 |
+
n_gpu = torch.cuda.device_count()
|
| 195 |
+
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
|
| 196 |
+
args.n_gpu = n_gpu
|
| 197 |
+
|
| 198 |
+
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
|
| 199 |
+
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
|
| 200 |
+
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
|
| 201 |
+
|
| 202 |
+
return device, n_gpu
|
| 203 |
+
|
| 204 |
+
def init_model(args, device, n_gpu, local_rank):
|
| 205 |
+
|
| 206 |
+
if args.init_model:
|
| 207 |
+
model_state_dict = torch.load(args.init_model, map_location='cpu')
|
| 208 |
+
else:
|
| 209 |
+
model_state_dict = None
|
| 210 |
+
|
| 211 |
+
# Prepare model
|
| 212 |
+
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
|
| 213 |
+
model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
|
| 214 |
+
|
| 215 |
+
# Attach projection head if requested (before DDP wrapping)
|
| 216 |
+
if getattr(args, 'proj', 0) and args.proj > 0:
|
| 217 |
+
print("Projection")
|
| 218 |
+
# Register projection layer on the model
|
| 219 |
+
proj = torch.nn.Linear(512, args.proj, bias=False)
|
| 220 |
+
torch.nn.init.normal_(proj.weight, mean=0.0, std=0.02)
|
| 221 |
+
# Activation
|
| 222 |
+
if args.proj_act == 'tanh':
|
| 223 |
+
act = torch.nn.Tanh()
|
| 224 |
+
elif args.proj_act == 'relu':
|
| 225 |
+
act = torch.nn.ReLU()
|
| 226 |
+
elif args.proj_act == 'gelu':
|
| 227 |
+
act = torch.nn.GELU()
|
| 228 |
+
elif args.proj_act == 'sigmoid':
|
| 229 |
+
act = torch.nn.Sigmoid()
|
| 230 |
+
else:
|
| 231 |
+
act = torch.nn.Tanh()
|
| 232 |
+
model.proj_head = proj
|
| 233 |
+
model.proj_activation = act
|
| 234 |
+
|
| 235 |
+
# If init_model contains proj_head params, load them now
|
| 236 |
+
if model_state_dict is not None:
|
| 237 |
+
try:
|
| 238 |
+
missing, unexpected = model.load_state_dict(model_state_dict, strict=False)
|
| 239 |
+
except Exception:
|
| 240 |
+
pass
|
| 241 |
+
|
| 242 |
+
model.to(device)
|
| 243 |
+
|
| 244 |
+
return model
|
| 245 |
+
|
| 246 |
+
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
|
| 247 |
+
|
| 248 |
+
if hasattr(model, 'module'):
|
| 249 |
+
model = model.module
|
| 250 |
+
|
| 251 |
+
param_optimizer = list(model.named_parameters())
|
| 252 |
+
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
| 253 |
+
|
| 254 |
+
decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
|
| 255 |
+
no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
|
| 256 |
+
|
| 257 |
+
decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n]
|
| 258 |
+
decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n]
|
| 259 |
+
|
| 260 |
+
no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n]
|
| 261 |
+
no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n]
|
| 262 |
+
|
| 263 |
+
weight_decay = 0.2
|
| 264 |
+
optimizer_grouped_parameters = [
|
| 265 |
+
{'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': args.lr * coef_lr},
|
| 266 |
+
{'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay},
|
| 267 |
+
{'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
|
| 268 |
+
{'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0}
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
scheduler = None
|
| 272 |
+
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
|
| 273 |
+
schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
|
| 274 |
+
t_total=num_train_optimization_steps, weight_decay=weight_decay,
|
| 275 |
+
max_grad_norm=1.0)
|
| 276 |
+
|
| 277 |
+
# 옵티마이저 만든 뒤 곧장 실행
|
| 278 |
+
name2param = {n: p for n, p in model.named_parameters() if p.requires_grad}
|
| 279 |
+
param2name = {id(p): n for n, p in name2param.items()}
|
| 280 |
+
|
| 281 |
+
for gi, g in enumerate(optimizer.param_groups):
|
| 282 |
+
print(f"[group {gi}] lr={g['lr']:.2e}, params={len(g['params'])}")
|
| 283 |
+
# 각 그룹에서 몇 개만 샘플로 찍기
|
| 284 |
+
for p in g["params"][:8]:
|
| 285 |
+
print(" ", param2name.get(id(p), "?"))
|
| 286 |
+
|
| 287 |
+
# Ensure both ranks finish building the exact same model before DDP wrap.
|
| 288 |
+
# Use a CPU barrier to avoid NCCL device-scoped hangs/timeouts.
|
| 289 |
+
|
| 290 |
+
# Quick debug: log param tensor count per rank
|
| 291 |
+
try:
|
| 292 |
+
num_tensors = len(list(model.parameters()))
|
| 293 |
+
if local_rank == 0:
|
| 294 |
+
print(f"[DDP-DEBUG] rank={args.rank} local_rank={local_rank} param_tensors={num_tensors}")
|
| 295 |
+
else:
|
| 296 |
+
print(f"[DDP-DEBUG] rank={args.rank} local_rank={local_rank} param_tensors={num_tensors}")
|
| 297 |
+
except Exception:
|
| 298 |
+
pass
|
| 299 |
+
|
| 300 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
|
| 301 |
+
output_device=local_rank, find_unused_parameters=False)
|
| 302 |
+
|
| 303 |
+
return optimizer, scheduler, model
|
| 304 |
+
|
| 305 |
+
def save_model(epoch, args, model, optimizer, tr_loss, type_name=""):
|
| 306 |
+
# Only save the model it-self
|
| 307 |
+
model_to_save = model.module if hasattr(model, 'module') else model
|
| 308 |
+
output_model_file = os.path.join(
|
| 309 |
+
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
|
| 310 |
+
optimizer_state_file = os.path.join(
|
| 311 |
+
args.output_dir, "pytorch_opt.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
|
| 312 |
+
torch.save(model_to_save.state_dict(), output_model_file)
|
| 313 |
+
torch.save({
|
| 314 |
+
'epoch': epoch,
|
| 315 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 316 |
+
'loss': tr_loss,
|
| 317 |
+
}, optimizer_state_file)
|
| 318 |
+
logger.info("Model saved to %s", output_model_file)
|
| 319 |
+
logger.info("Optimizer saved to %s", optimizer_state_file)
|
| 320 |
+
return output_model_file
|
| 321 |
+
|
| 322 |
+
def load_model(epoch, args, n_gpu, device, model_file=None):
|
| 323 |
+
if model_file is None or len(model_file) == 0:
|
| 324 |
+
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
|
| 325 |
+
if os.path.exists(model_file):
|
| 326 |
+
model_state_dict = torch.load(model_file, map_location='cpu')
|
| 327 |
+
if args.local_rank == 0:
|
| 328 |
+
logger.info("Model loaded from %s", model_file)
|
| 329 |
+
# Prepare model
|
| 330 |
+
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
|
| 331 |
+
model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
|
| 332 |
+
# Attach projection head if needed and load any matching weights
|
| 333 |
+
if getattr(args, 'proj', 0) and args.proj > 0:
|
| 334 |
+
proj = torch.nn.Linear(512, args.proj, bias=False)
|
| 335 |
+
torch.nn.init.normal_(proj.weight, mean=0.0, std=0.02)
|
| 336 |
+
if args.proj_act == 'tanh':
|
| 337 |
+
act = torch.nn.Tanh()
|
| 338 |
+
elif args.proj_act == 'relu':
|
| 339 |
+
act = torch.nn.ReLU()
|
| 340 |
+
elif args.proj_act == 'gelu':
|
| 341 |
+
act = torch.nn.GELU()
|
| 342 |
+
elif args.proj_act == 'sigmoid':
|
| 343 |
+
act = torch.nn.Sigmoid()
|
| 344 |
+
else:
|
| 345 |
+
act = torch.nn.Tanh()
|
| 346 |
+
model.proj_head = proj
|
| 347 |
+
model.proj_activation = act
|
| 348 |
+
try:
|
| 349 |
+
model.load_state_dict(model_state_dict, strict=False)
|
| 350 |
+
except Exception:
|
| 351 |
+
pass
|
| 352 |
+
|
| 353 |
+
model.to(device)
|
| 354 |
+
else:
|
| 355 |
+
model = None
|
| 356 |
+
|
| 357 |
+
logger.info(f"모델을 로드합니다:{cache_dir}")
|
| 358 |
+
return model
|
| 359 |
+
|
| 360 |
+
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
|
| 361 |
+
global logger
|
| 362 |
+
torch.cuda.empty_cache()
|
| 363 |
+
|
| 364 |
+
net = model.module if hasattr(model, 'module') else model
|
| 365 |
+
net.train()
|
| 366 |
+
|
| 367 |
+
log_step = args.n_display
|
| 368 |
+
start_time = time.time()
|
| 369 |
+
total_loss = 0
|
| 370 |
+
|
| 371 |
+
for step, batch in enumerate(train_dataloader):
|
| 372 |
+
if n_gpu == 1:
|
| 373 |
+
# multi-gpu does scattering it-self
|
| 374 |
+
batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
|
| 375 |
+
|
| 376 |
+
input_ids, input_mask, segment_ids, video, video_mask = batch
|
| 377 |
+
|
| 378 |
+
# If projection head enabled, override loss path to use proj+tanh->sum embeddings
|
| 379 |
+
if getattr(args, 'proj', 0) and args.proj > 0 and hasattr(net, 'proj_head'):
|
| 380 |
+
# Forward backbone encoders only
|
| 381 |
+
sequence_output, visual_output = net.get_sequence_visual_output(
|
| 382 |
+
input_ids, segment_ids, input_mask, video, video_mask)
|
| 383 |
+
|
| 384 |
+
# Gather across processes for global negatives
|
| 385 |
+
sequence_output = AllGather.apply(sequence_output, args)
|
| 386 |
+
visual_output = AllGather.apply(visual_output, args)
|
| 387 |
+
video_mask_g = AllGather.apply(video_mask.view(-1, video_mask.shape[-1]), args)
|
| 388 |
+
|
| 389 |
+
# Text: [B,1,512] -> [B,512] -> proj -> act -> L2
|
| 390 |
+
txt = sequence_output.squeeze(1)
|
| 391 |
+
txt = net.proj_activation(net.proj_head(txt))
|
| 392 |
+
txt = F.normalize(txt, dim=-1)
|
| 393 |
+
|
| 394 |
+
# Video: [B,T,512] -> proj -> act -> mask -> sum(T) -> L2
|
| 395 |
+
vid = net.proj_activation(net.proj_head(visual_output))
|
| 396 |
+
vm = video_mask_g.to(dtype=vid.dtype).unsqueeze(-1)
|
| 397 |
+
vid = (vid * vm).sum(dim=1)
|
| 398 |
+
vid = F.normalize(vid, dim=-1)
|
| 399 |
+
|
| 400 |
+
logit_scale = net.clip.logit_scale.exp()
|
| 401 |
+
sim = logit_scale * torch.matmul(txt, vid.t())
|
| 402 |
+
|
| 403 |
+
# Same symmetric CE loss as original
|
| 404 |
+
loss = net.loss_fct(sim) + net.loss_fct(sim.T)
|
| 405 |
+
loss = loss * 0.5
|
| 406 |
+
else:
|
| 407 |
+
loss = model(input_ids, segment_ids, input_mask, video, video_mask)
|
| 408 |
+
|
| 409 |
+
if n_gpu > 1:
|
| 410 |
+
loss = loss.mean() # mean() to average on multi-gpu.
|
| 411 |
+
if args.gradient_accumulation_steps > 1:
|
| 412 |
+
loss = loss / args.gradient_accumulation_steps
|
| 413 |
+
|
| 414 |
+
loss.backward()
|
| 415 |
+
|
| 416 |
+
total_loss += float(loss)
|
| 417 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
| 418 |
+
|
| 419 |
+
torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
|
| 420 |
+
|
| 421 |
+
if scheduler is not None:
|
| 422 |
+
scheduler.step() # Update learning rate schedule
|
| 423 |
+
|
| 424 |
+
optimizer.step()
|
| 425 |
+
optimizer.zero_grad()
|
| 426 |
+
|
| 427 |
+
# https://github.com/openai/CLIP/issues/46
|
| 428 |
+
|
| 429 |
+
torch.clamp_(net.clip.logit_scale.data, max=np.log(100))
|
| 430 |
+
|
| 431 |
+
global_step += 1
|
| 432 |
+
if global_step % log_step == 0 and local_rank == 0:
|
| 433 |
+
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
|
| 434 |
+
args.epochs, step + 1,
|
| 435 |
+
len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
|
| 436 |
+
float(loss),
|
| 437 |
+
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
|
| 438 |
+
start_time = time.time()
|
| 439 |
+
|
| 440 |
+
total_loss = total_loss / len(train_dataloader)
|
| 441 |
+
return total_loss, global_step
|
| 442 |
+
|
| 443 |
+
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
|
| 444 |
+
|
| 445 |
+
sim_matrix = []
|
| 446 |
+
for idx1, b1 in enumerate(batch_list_t):
|
| 447 |
+
input_mask, segment_ids, *_tmp = b1
|
| 448 |
+
sequence_output = batch_sequence_output_list[idx1]
|
| 449 |
+
each_row = []
|
| 450 |
+
for idx2, b2 in enumerate(batch_list_v):
|
| 451 |
+
video_mask, *_tmp = b2
|
| 452 |
+
visual_output = batch_visual_output_list[idx2]
|
| 453 |
+
b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
|
| 454 |
+
loose_type=model.loose_type)
|
| 455 |
+
b1b2_logits = b1b2_logits.cpu().detach().numpy()
|
| 456 |
+
each_row.append(b1b2_logits)
|
| 457 |
+
each_row = np.concatenate(tuple(each_row), axis=-1)
|
| 458 |
+
sim_matrix.append(each_row)
|
| 459 |
+
return sim_matrix
|
| 460 |
+
|
| 461 |
+
def eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer):
|
| 462 |
+
import numpy as np
|
| 463 |
+
import os
|
| 464 |
+
import torch
|
| 465 |
+
import matplotlib.pyplot as plt
|
| 466 |
+
|
| 467 |
+
def _decode_query(tokenizer, ids_tensor):
|
| 468 |
+
# ids_tensor: 1D tensor (token ids)
|
| 469 |
+
try:
|
| 470 |
+
if isinstance(ids_tensor, torch.Tensor):
|
| 471 |
+
ids = ids_tensor.cpu().numpy().tolist()
|
| 472 |
+
else:
|
| 473 |
+
ids = ids_tensor.tolist() if hasattr(ids_tensor, 'tolist') else list(ids_tensor)
|
| 474 |
+
|
| 475 |
+
# ClipTokenizer의 특수 토큰 ID들
|
| 476 |
+
start_token_id = tokenizer.encoder.get('<|startoftext|>', 49406) # 기본값
|
| 477 |
+
end_token_id = tokenizer.encoder.get('<|endoftext|>', 49407) # 기본값
|
| 478 |
+
|
| 479 |
+
# 패딩 토큰(0)과 특수 토큰 제거
|
| 480 |
+
clean_ids = []
|
| 481 |
+
for token_id in ids:
|
| 482 |
+
if token_id > 0 and token_id != start_token_id and token_id != end_token_id:
|
| 483 |
+
clean_ids.append(token_id)
|
| 484 |
+
|
| 485 |
+
if not clean_ids:
|
| 486 |
+
return "<empty_query>"
|
| 487 |
+
|
| 488 |
+
# 유효하지 않은 토큰 ID 필터링 (vocab 범위 내)
|
| 489 |
+
vocab_size = len(tokenizer.decoder)
|
| 490 |
+
valid_ids = [tid for tid in clean_ids if tid < vocab_size]
|
| 491 |
+
|
| 492 |
+
if not valid_ids:
|
| 493 |
+
return "<invalid_tokens>"
|
| 494 |
+
|
| 495 |
+
# 디코딩 시도
|
| 496 |
+
try:
|
| 497 |
+
decoded_text = tokenizer.decode(valid_ids)
|
| 498 |
+
return decoded_text.strip()
|
| 499 |
+
except KeyError as e:
|
| 500 |
+
# 개별 토큰별로 디코딩 시도
|
| 501 |
+
decoded_tokens = []
|
| 502 |
+
for tid in valid_ids:
|
| 503 |
+
if tid in tokenizer.decoder:
|
| 504 |
+
decoded_tokens.append(tokenizer.decoder[tid])
|
| 505 |
+
else:
|
| 506 |
+
decoded_tokens.append(f"<unk_{tid}>")
|
| 507 |
+
text = ''.join(decoded_tokens)
|
| 508 |
+
# BPE 후처리
|
| 509 |
+
text = text.replace('</w>', ' ').strip()
|
| 510 |
+
return text if text else "<decode_partial_error>"
|
| 511 |
+
except Exception as e:
|
| 512 |
+
return f"<decode_error: {str(e)[:50]}>"
|
| 513 |
+
|
| 514 |
+
except Exception as e:
|
| 515 |
+
return f"<general_error: {str(e)[:50]}>"
|
| 516 |
+
|
| 517 |
+
def _get_video_ids_from_dataset(dataset, num_videos):
|
| 518 |
+
# 다양한 후보 속성명 시도 → 없으면 0..N-1 인덱스 문자열로 대체
|
| 519 |
+
for attr in ["video_list", "video_ids", "video_names", "videos", "vids", "id_list"]:
|
| 520 |
+
if hasattr(dataset, attr):
|
| 521 |
+
obj = getattr(dataset, attr)
|
| 522 |
+
try:
|
| 523 |
+
if isinstance(obj, (list, tuple)) and len(obj) == num_videos:
|
| 524 |
+
return list(map(str, obj))
|
| 525 |
+
except Exception:
|
| 526 |
+
pass
|
| 527 |
+
return [str(i) for i in range(num_videos)]
|
| 528 |
+
|
| 529 |
+
global logger
|
| 530 |
+
|
| 531 |
+
multi_sentence_ = False
|
| 532 |
+
cut_off_points_, sentence_num_, video_num_ = [], -1, -1
|
| 533 |
+
|
| 534 |
+
if hasattr(model, 'module'):
|
| 535 |
+
model = model.module.to(device)
|
| 536 |
+
else:
|
| 537 |
+
model = model.to(device)
|
| 538 |
+
|
| 539 |
+
if hasattr(model, 'module'):
|
| 540 |
+
model.module.eval()
|
| 541 |
+
else:
|
| 542 |
+
model.eval()
|
| 543 |
+
|
| 544 |
+
logger.info("Model %s", "training" if model.training else "eval")
|
| 545 |
+
|
| 546 |
+
# suffix for cache/result naming
|
| 547 |
+
suffix = "_hash" if getattr(args, "use_clip4hashing", False) else ""
|
| 548 |
+
suffix += "_rff" if args.use_rff else ""
|
| 549 |
+
suffix += f"_proj{args.proj}" if getattr(args, 'proj', 0) and args.proj > 0 else ""
|
| 550 |
+
suffix += "_binary" if getattr(args, 'binary_eval', False) else ""
|
| 551 |
+
suffix += "_trained" if args.init_model else ""
|
| 552 |
+
|
| 553 |
+
# (A) 캐시 로드/생성
|
| 554 |
+
if "train" in args.val_csv and "10k" in args.val_csv:
|
| 555 |
+
cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
|
| 556 |
+
logger.info(f"9k 훈련 데이터 캐시 생성: {cache_name}")
|
| 557 |
+
else:
|
| 558 |
+
cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
|
| 559 |
+
logger.info(f"평가 데이터 캐시: {cache_name}")
|
| 560 |
+
|
| 561 |
+
cache_path = os.path.join(args.output_dir, cache_name)
|
| 562 |
+
|
| 563 |
+
loaded_from_cache = False
|
| 564 |
+
if os.path.exists(cache_path):
|
| 565 |
+
logger.info(f"캐시된 피처를 로드합니다: {cache_path}")
|
| 566 |
+
cache = torch.load(cache_path, map_location=device)
|
| 567 |
+
batch_sequence_output_list = cache['batch_sequence_output_list']
|
| 568 |
+
batch_visual_output_list = cache['batch_visual_output_list']
|
| 569 |
+
batch_list_t = cache['batch_list_t']
|
| 570 |
+
batch_list_v = cache['batch_list_v']
|
| 571 |
+
text_input_ids_list = cache.get('text_input_ids_list', None)
|
| 572 |
+
video_ids = cache.get('video_ids', None)
|
| 573 |
+
loaded_from_cache = True
|
| 574 |
+
|
| 575 |
+
logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
|
| 576 |
+
f"각 텐서 shape={batch_sequence_output_list[0].shape}")
|
| 577 |
+
logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
|
| 578 |
+
f"각 텐서 shape={batch_visual_output_list[0].shape}")
|
| 579 |
+
else:
|
| 580 |
+
print("Caching feature..")
|
| 581 |
+
if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
|
| 582 |
+
test_dataloader.dataset.multi_sentence_per_video:
|
| 583 |
+
multi_sentence_ = True
|
| 584 |
+
cut_off_points_ = test_dataloader.dataset.cut_off_points
|
| 585 |
+
sentence_num_ = test_dataloader.dataset.sentence_num
|
| 586 |
+
video_num_ = test_dataloader.dataset.video_num
|
| 587 |
+
cut_off_points_ = [itm - 1 for itm in cut_off_points_]
|
| 588 |
+
logger.warning("Eval under multi-sentence-per-video. sentence num: %s, video num: %s",
|
| 589 |
+
sentence_num_, video_num_)
|
| 590 |
+
|
| 591 |
+
with torch.no_grad():
|
| 592 |
+
batch_list_t = []
|
| 593 |
+
batch_list_v = []
|
| 594 |
+
batch_sequence_output_list, batch_visual_output_list = [], []
|
| 595 |
+
text_input_ids_list = []
|
| 596 |
+
total_video_num = 0
|
| 597 |
+
|
| 598 |
+
for bid, batch in enumerate(test_dataloader):
|
| 599 |
+
batch = tuple(t.to(device) for t in batch)
|
| 600 |
+
input_ids, input_mask, segment_ids, video, video_mask = batch
|
| 601 |
+
|
| 602 |
+
if multi_sentence_:
|
| 603 |
+
b, *_t = video.shape
|
| 604 |
+
sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
|
| 605 |
+
batch_sequence_output_list.append(sequence_output)
|
| 606 |
+
# input_ids를 함께 보관 (run_on_single_gpu는 *_tmp로 무시하므로 안전)
|
| 607 |
+
batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
|
| 608 |
+
|
| 609 |
+
s_, e_ = total_video_num, total_video_num + b
|
| 610 |
+
filter_inds = [itm - s_ for itm in cut_off_points_ if s_ <= itm < e_]
|
| 611 |
+
if len(filter_inds) > 0:
|
| 612 |
+
video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
|
| 613 |
+
visual_output = model.get_visual_output(video, video_mask)
|
| 614 |
+
batch_visual_output_list.append(visual_output)
|
| 615 |
+
batch_list_v.append((video_mask,))
|
| 616 |
+
total_video_num += b
|
| 617 |
+
else:
|
| 618 |
+
sequence_output, visual_output = model.get_sequence_visual_output(
|
| 619 |
+
input_ids, segment_ids, input_mask, video, video_mask)
|
| 620 |
+
|
| 621 |
+
batch_sequence_output_list.append(sequence_output)
|
| 622 |
+
batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
|
| 623 |
+
|
| 624 |
+
batch_visual_output_list.append(visual_output)
|
| 625 |
+
batch_list_v.append((video_mask,))
|
| 626 |
+
|
| 627 |
+
print("{}/{}\r".format(bid, len(test_dataloader)), end="")
|
| 628 |
+
|
| 629 |
+
# Build the video ID list (fall back to indices 0..N-1 if the dataset does not expose IDs)
|
| 630 |
+
num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
|
| 631 |
+
video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
|
| 632 |
+
|
| 633 |
+
logger.info(f"추출된 피처를 캐시에 저장합니다: {cache_path}")
|
| 634 |
+
torch.save({
|
| 635 |
+
'batch_sequence_output_list': batch_sequence_output_list,
|
| 636 |
+
'batch_visual_output_list': batch_visual_output_list,
|
| 637 |
+
'batch_list_t': batch_list_t,
|
| 638 |
+
'batch_list_v': batch_list_v,
|
| 639 |
+
'text_input_ids_list': text_input_ids_list,
|
| 640 |
+
'video_ids': video_ids,
|
| 641 |
+
}, cache_path)
|
| 642 |
+
|
| 643 |
+
logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
|
| 644 |
+
f"각 텐서 shape={batch_sequence_output_list[0].shape}")
|
| 645 |
+
logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
|
| 646 |
+
f"각 텐서 shape={batch_visual_output_list[0].shape}")
|
| 647 |
+
|
| 648 |
+
# 캐시에 text_input_ids_list가 없으면, 한 번 더 훑어서 수집 (구버전 캐시 호환)
|
| 649 |
+
if loaded_from_cache and 'text_input_ids_list' not in cache:
|
| 650 |
+
logger.info("캐시에 text_input_ids_list가 없어 재수집합니다(호환성 경로).")
|
| 651 |
+
text_input_ids_list = []
|
| 652 |
+
with torch.no_grad():
|
| 653 |
+
for batch in test_dataloader:
|
| 654 |
+
input_ids = batch[0].detach().cpu()
|
| 655 |
+
text_input_ids_list.append(input_ids)
|
| 656 |
+
elif loaded_from_cache and text_input_ids_list is None:
|
| 657 |
+
# batch_list_t에서 input_ids 추출
|
| 658 |
+
logger.info("batch_list_t에서 input_ids를 추출합니다.")
|
| 659 |
+
text_input_ids_list = []
|
| 660 |
+
for input_mask, segment_ids, input_ids in batch_list_t:
|
| 661 |
+
text_input_ids_list.append(input_ids)
|
| 662 |
+
|
| 663 |
+
# video_ids가 없으면 만들어준다(구버전 캐시 호환)
|
| 664 |
+
if loaded_from_cache and cache.get('video_ids', None) is None:
|
| 665 |
+
num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
|
| 666 |
+
video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
|
| 667 |
+
|
| 668 |
+
# (B) 유사도 행렬 계산
|
| 669 |
+
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
    """Compute the text-to-video similarity matrix from cached features.

    Closure: reads `args` (for the `proj` / `binary_eval` flags), plus `F`,
    `np` and `torch` from enclosing/module scope -- `F` is presumably
    torch.nn.functional; TODO confirm the import at file top.

    Returns a list of numpy arrays (one row-chunk per text batch); each
    chunk has one column per cached video feature.  When a projection head
    is configured it scores projected (optionally sign-binarized) vectors;
    otherwise it defers to the model's own similarity logits.
    """
    sim_matrix = []
    # Projection path is taken only when the flag is set AND the model
    # actually exposes a `proj_head` module.
    use_proj = getattr(args, 'proj', 0) and args.proj > 0 and hasattr(model, 'proj_head')
    # Binary (sign-hash) scoring is meaningful only on projected vectors.
    use_binary = getattr(args, 'binary_eval', False) and use_proj
    for idx1, b1 in enumerate(batch_list_t):
        # Tuples may carry a trailing cached input_ids entry; ignore it.
        input_mask, segment_ids, *_tmp = b1
        sequence_output = batch_sequence_output_list[idx1]

        each_row = []
        if use_proj:
            # Text: [B,1,512] -> [B,512] -> proj -> act
            t = sequence_output.squeeze(1)
            t = model.proj_activation(model.proj_head(t))
            if use_binary:
                # Sign binarization to {+1, -1} (no normalization).
                t_vec = torch.where(t > 0, torch.tensor(1.0, device=t.device, dtype=t.dtype),
                                    torch.tensor(-1.0, device=t.device, dtype=t.dtype))
            else:
                t_vec = F.normalize(t, dim=-1)

            for idx2, b2 in enumerate(batch_list_v):
                video_mask, *_tmp = b2
                visual_output = batch_visual_output_list[idx2]
                v = model.proj_activation(model.proj_head(visual_output))  # [B, T, P]

                # Robust mask handling: ensure vm is [B, T, 1]
                vm = video_mask
                # Squeeze stray singleton dims (e.g., [B, T, 1] -> [B, T])
                while vm.dim() > 2 and vm.size(-1) == 1:
                    vm = vm.squeeze(-1)
                # If still higher-rank, flatten to [B, -1] safely
                if vm.dim() > 2:
                    vm = vm.view(vm.size(0), -1)
                # If [T] vector sneaks in, expand batch
                if vm.dim() == 1:
                    vm = vm.unsqueeze(0)
                # Align time length to visual features
                if vm.size(1) != v.size(1):
                    vm = vm[:, :v.size(1)]
                vm = vm.to(dtype=v.dtype).unsqueeze(-1)  # [B, T, 1]

                # Masked temporal sum -> [B, P]
                # NOTE(review): this is a plain sum, not a mean over valid
                # frames -- scores scale with clip length; confirm intended.
                v = (v * vm).sum(dim=1)
                # Squeeze any trailing singleton dim that might remain
                if v.dim() > 2 and v.size(-1) == 1:
                    v = v.squeeze(-1)
                if use_binary:
                    v_vec = torch.where(v > 0, torch.tensor(1.0, device=v.device, dtype=v.dtype),
                                        torch.tensor(-1.0, device=v.device, dtype=v.dtype))
                    scores = torch.matmul(t_vec, v_vec.t())
                else:
                    v_vec = F.normalize(v, dim=-1)
                    scores = torch.matmul(t_vec, v_vec.t())
                each_row.append(scores.cpu().detach().numpy())
        else:
            # Default path: the model computes its own similarity logits.
            for idx2, b2 in enumerate(batch_list_v):
                video_mask, *_tmp = b2
                visual_output = batch_visual_output_list[idx2]
                b1b2_logits, *_tmp = model.get_similarity_logits(
                    sequence_output, visual_output, input_mask, video_mask, loose_type=model.loose_type)
                b1b2_logits = b1b2_logits.cpu().detach().numpy()
                each_row.append(b1b2_logits)

        # Concatenate per-video-batch columns into one full row chunk.
        each_row = np.concatenate(tuple(each_row), axis=-1)
        sim_matrix.append(each_row)
    return sim_matrix
|
| 734 |
+
|
| 735 |
+
if n_gpu > 1:
|
| 736 |
+
device_ids = list(range(n_gpu))
|
| 737 |
+
batch_list_t_splits, batch_list_v_splits = [], []
|
| 738 |
+
batch_t_output_splits, batch_v_output_splits = [], []
|
| 739 |
+
bacth_len = len(batch_list_t)
|
| 740 |
+
split_len = (bacth_len + n_gpu - 1) // n_gpu
|
| 741 |
+
for dev_id in device_ids:
|
| 742 |
+
s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
|
| 743 |
+
if dev_id == 0:
|
| 744 |
+
batch_list_t_splits.append(batch_list_t[s_:e_]); batch_list_v_splits.append(batch_list_v)
|
| 745 |
+
batch_t_output_splits.append(batch_sequence_output_list[s_:e_]); batch_v_output_splits.append(batch_visual_output_list)
|
| 746 |
+
else:
|
| 747 |
+
devc = torch.device(f'cuda:{dev_id}')
|
| 748 |
+
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
|
| 749 |
+
batch_list_t_splits.append(devc_batch_list)
|
| 750 |
+
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
|
| 751 |
+
batch_list_v_splits.append(devc_batch_list)
|
| 752 |
+
devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
|
| 753 |
+
batch_t_output_splits.append(devc_batch_list)
|
| 754 |
+
devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
|
| 755 |
+
batch_v_output_splits.append(devc_batch_list)
|
| 756 |
+
|
| 757 |
+
parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
|
| 758 |
+
batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
|
| 759 |
+
parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
|
| 760 |
+
sim_matrix = []
|
| 761 |
+
for idx in range(len(parallel_outputs)):
|
| 762 |
+
sim_matrix += parallel_outputs[idx]
|
| 763 |
+
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
| 764 |
+
else:
|
| 765 |
+
sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v,
|
| 766 |
+
batch_sequence_output_list, batch_visual_output_list)
|
| 767 |
+
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
| 768 |
+
|
| 769 |
+
# (C) 멀티센텐스 처리 및 메트릭
|
| 770 |
+
if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
|
| 771 |
+
test_dataloader.dataset.multi_sentence_per_video:
|
| 772 |
+
multi_sentence_ = True
|
| 773 |
+
|
| 774 |
+
if multi_sentence_:
|
| 775 |
+
logger.info("before reshape, sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
|
| 776 |
+
sim_matrix_flat = sim_matrix.copy() # 쿼리별 Top-K 용 2D 보관
|
| 777 |
+
|
| 778 |
+
cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
|
| 779 |
+
max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)])
|
| 780 |
+
sim_matrix_new = []
|
| 781 |
+
for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_):
|
| 782 |
+
sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
|
| 783 |
+
np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0))
|
| 784 |
+
sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
|
| 785 |
+
logger.info("after reshape, sim matrix size: %d x %d x %d",
|
| 786 |
+
sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])
|
| 787 |
+
|
| 788 |
+
tv_metrics = tensor_text_to_video_metrics(sim_matrix)
|
| 789 |
+
vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
|
| 790 |
+
else:
|
| 791 |
+
logger.info("sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
|
| 792 |
+
# 히트맵 저장(샘플)
|
| 793 |
+
plt.figure(figsize=(8,6))
|
| 794 |
+
plt.imshow(sim_matrix[:100, :100], aspect='auto')
|
| 795 |
+
plt.title('Similarity Matrix Heatmap')
|
| 796 |
+
plt.xlabel('Video Index')
|
| 797 |
+
plt.ylabel('Text Index')
|
| 798 |
+
plt.tight_layout()
|
| 799 |
+
out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
|
| 800 |
+
plt.savefig(out_path); plt.close()
|
| 801 |
+
logger.info(f"Saved sim_matrix heatmap to {out_path}")
|
| 802 |
+
|
| 803 |
+
sim_matrix_flat = sim_matrix # 2D 그대로
|
| 804 |
+
tv_metrics = compute_metrics(sim_matrix)
|
| 805 |
+
vt_metrics = compute_metrics(sim_matrix.T)
|
| 806 |
+
logger.info('\t Length-T: %d, Length-V:%d', len(sim_matrix), len(sim_matrix[0]))
|
| 807 |
+
|
| 808 |
+
logger.info("Text-to-Video:")
|
| 809 |
+
logger.info('\t>>> R@1: %.1f - R@5: %.1f - R@10: %.1f - Median R: %.1f - Mean R: %.1f',
|
| 810 |
+
tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])
|
| 811 |
+
logger.info("Video-to-Text:")
|
| 812 |
+
logger.info('\t>>> V2T$R@1: %.1f - V2T$R@5: %.1f - V2T$R@10: %.1f - V2T$Median R: %.1f - V2T$Mean R: %.1f',
|
| 813 |
+
vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])
|
| 814 |
+
|
| 815 |
+
# (D) 쿼리 텍스트 복원 + Top-10 덤프
|
| 816 |
+
# text_input_ids_list: List[Tensor[B_i, L]]
|
| 817 |
+
all_queries = []
|
| 818 |
+
logger.info(f"text_input_ids_list 개수: {len(text_input_ids_list)}")
|
| 819 |
+
|
| 820 |
+
for batch_idx, ids_batch in enumerate(text_input_ids_list):
|
| 821 |
+
if ids_batch is None:
|
| 822 |
+
logger.warning(f"배치 {batch_idx}: ids_batch가 None입니다.")
|
| 823 |
+
continue
|
| 824 |
+
|
| 825 |
+
try:
|
| 826 |
+
ids_batch = ids_batch if isinstance(ids_batch, torch.Tensor) else torch.as_tensor(ids_batch)
|
| 827 |
+
logger.info(f"배치 {batch_idx}: shape={ids_batch.shape}")
|
| 828 |
+
|
| 829 |
+
for row_idx, row in enumerate(ids_batch):
|
| 830 |
+
decoded = _decode_query(tokenizer, row)
|
| 831 |
+
all_queries.append(decoded)
|
| 832 |
+
if batch_idx == 0 and row_idx < 3: # 첫 배치의 처음 3개만 샘플로 출력
|
| 833 |
+
logger.info(f"샘플 디코딩 결과 [{batch_idx}-{row_idx}]: '{decoded}'")
|
| 834 |
+
|
| 835 |
+
except Exception as e:
|
| 836 |
+
logger.error(f"배치 {batch_idx} 처리 중 오류: {str(e)}")
|
| 837 |
+
# 에러가 발생해도 계속 진행
|
| 838 |
+
continue
|
| 839 |
+
|
| 840 |
+
logger.info(f"총 {len(all_queries)}개의 쿼리가 디코딩되었습니다.")
|
| 841 |
+
|
| 842 |
+
# video_ids 길이 보정(안전)
|
| 843 |
+
num_videos = sim_matrix_flat.shape[1]
|
| 844 |
+
if 'video_ids' in locals():
|
| 845 |
+
if len(video_ids) != num_videos:
|
| 846 |
+
logger.warning("video_ids 길이(%d)와 비디오 수(%d)가 달라 index로 대체합니다.",
|
| 847 |
+
len(video_ids), num_videos)
|
| 848 |
+
video_ids = [str(i) for i in range(num_videos)]
|
| 849 |
+
else:
|
| 850 |
+
video_ids = [str(i) for i in range(num_videos)]
|
| 851 |
+
|
| 852 |
+
# 저장 파일
|
| 853 |
+
topk = 10
|
| 854 |
+
out_tsv = os.path.join(args.output_dir, f"t2v_top10{suffix}.tsv")
|
| 855 |
+
out_json = os.path.join(args.output_dir, f"t2v_top10{suffix}.json")
|
| 856 |
+
|
| 857 |
+
if args.local_rank == 0:
|
| 858 |
+
import json
|
| 859 |
+
|
| 860 |
+
# TSV 파일 저장
|
| 861 |
+
with open(out_tsv, "w", encoding="utf-8") as f:
|
| 862 |
+
f.write("query_idx\tquery\tvideo_rank\tvideo_id\tvideo_idx\tscore\n")
|
| 863 |
+
for qi, q in enumerate(all_queries):
|
| 864 |
+
scores = sim_matrix_flat[qi]
|
| 865 |
+
# 효율: argpartition 후 정렬
|
| 866 |
+
idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 867 |
+
idxs = idxs[np.argsort(-scores[idxs])]
|
| 868 |
+
for rank, vidx in enumerate(idxs, 1):
|
| 869 |
+
f.write(f"{qi}\t{q}\t{rank}\t{video_ids[vidx]}\t{int(vidx)}\t{float(scores[vidx]):.6f}\n")
|
| 870 |
+
|
| 871 |
+
# JSON 파일 저장 (구조화된 형태)
|
| 872 |
+
results_dict = {}
|
| 873 |
+
for qi, q in enumerate(all_queries):
|
| 874 |
+
scores = sim_matrix_flat[qi]
|
| 875 |
+
idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 876 |
+
idxs = idxs[np.argsort(-scores[idxs])]
|
| 877 |
+
|
| 878 |
+
results_dict[f"query_{qi+1}"] = {
|
| 879 |
+
"query_text": q,
|
| 880 |
+
"top_videos": []
|
| 881 |
+
}
|
| 882 |
+
|
| 883 |
+
for rank, vidx in enumerate(idxs, 1):
|
| 884 |
+
results_dict[f"query_{qi+1}"]["top_videos"].append({
|
| 885 |
+
"rank": rank,
|
| 886 |
+
"video_id": video_ids[vidx],
|
| 887 |
+
"video_idx": int(vidx),
|
| 888 |
+
"score": float(scores[vidx])
|
| 889 |
+
})
|
| 890 |
+
|
| 891 |
+
with open(out_json, "w", encoding="utf-8") as f:
|
| 892 |
+
json.dump(results_dict, f, ensure_ascii=False, indent=2)
|
| 893 |
+
|
| 894 |
+
logger.info("T2V Top-10 per query 저장 완료:")
|
| 895 |
+
logger.info(" TSV 파일: %s", out_tsv)
|
| 896 |
+
logger.info(" JSON 파일: %s", out_json)
|
| 897 |
+
logger.info("총 %d개 쿼리에 대한 top-10 결과가 저장되었습니다.", len(all_queries))
|
| 898 |
+
|
| 899 |
+
# # 로그에 모든 쿼리의 Top-10 결과 출력
|
| 900 |
+
# logger.info("=== Query-wise Top-10 Results (전체 %d개 쿼리) ===", len(all_queries))
|
| 901 |
+
# for qi in range(len(all_queries)):
|
| 902 |
+
# scores = sim_matrix_flat[qi]
|
| 903 |
+
# idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 904 |
+
# idxs = idxs[np.argsort(-scores[idxs])]
|
| 905 |
+
|
| 906 |
+
# logger.info(f"Query {qi+1}: \"{all_queries[qi]}\"")
|
| 907 |
+
# for rank, vidx in enumerate(idxs, 1):
|
| 908 |
+
# logger.info(f" Rank {rank}: video_id={video_ids[vidx]}, video_idx={vidx}, score={scores[vidx]:.6f}")
|
| 909 |
+
# logger.info("---")
|
| 910 |
+
|
| 911 |
+
# logger.info("=== 모든 쿼리 결과 출력 완료 ===")
|
| 912 |
+
|
| 913 |
+
return tv_metrics['R1']
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def main():
    """Entry point: distributed setup, model init, optional CLIP layer
    freezing, dataloader construction, then training and/or evaluation of
    the retrieval task.
    """
    global logger
    args = get_args()

    # torchrun exports LOCAL_RANK; prefer it over the CLI flag when present.
    if "LOCAL_RANK" in os.environ:
        try:
            args.local_rank = int(os.environ["LOCAL_RANK"])
        except Exception:
            pass
    torch.cuda.set_device(args.local_rank)

    from datetime import timedelta
    import torch.distributed as dist
    # Generous 30-minute timeout: feature caching during eval can stall
    # non-zero ranks for a long time.
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=timedelta(minutes=30)
    )

    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    assert args.task_type == "retrieval"
    model = init_model(args, device, n_gpu, args.local_rank)

    ## ####################################
    # freeze testing
    ## ####################################
    # Freeze the bottom `freeze_layer_num` CLIP transformer blocks; the top
    # projection/normalization parameters always remain trainable.
    assert args.freeze_layer_num <= 12 and args.freeze_layer_num >= -1
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, param in model.clip.named_parameters():
            # top layers always need to train
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                continue    # need to train
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue    # need to train

            # NOTE(review): str.find returns -1 (truthy) when absent and 0
            # (falsy) when the name starts with "conv2.", so under 3d patches
            # this skips freezing for every name NOT starting with "conv2." --
            # looks inverted; matches upstream CLIP4Clip, confirm before fixing.
            if args.linear_patch == "3d" and name.find("conv2."):
                continue
            else:
                # paramenters which < freeze_layer_num will be freezed
                param.requires_grad = False

    ## ####################################
    # dataloader loading
    ## ####################################
    assert args.datatype in DATALOADER_DICT

    # At least one of the test/val loaders must exist for this datatype.
    assert DATALOADER_DICT[args.datatype]["test"] is not None \
        or DATALOADER_DICT[args.datatype]["val"] is not None

    test_dataloader, test_length = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)

    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_dataloader, val_length = test_dataloader, test_length

    ## report validation results if the ["test"] is None
    if test_dataloader is None:
        test_dataloader, test_length = val_dataloader, val_length

    # Only rank 0 logs run configuration.
    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info(" Num examples = %d", test_length)
        logger.info(" Batch size = %d", args.batch_size_val)
        logger.info(" Num steps = %d", len(test_dataloader))
        logger.info("***** Running val *****")
        logger.info(" Num examples = %d", val_length)

    ## ####################################
    # train and eval
    ## ####################################
    if args.do_train:
        train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        # NOTE(review): int() wraps only the numerator, so the division yields
        # a float step count -- matches upstream CLIP4Clip; confirm before
        # changing, prep_optimizer may rely on it.
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info(" Num examples = %d", train_length)
            logger.info(" Batch size = %d", args.batch_size)
            logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = "None"
        ## ##############################################################
        # resume optimizer state besides loss to continue train
        ## ##############################################################
        resumed_epoch = 0
        if args.resume_model:
            # Restore optimizer state and resume from the following epoch.
            checkpoint = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            resumed_epoch = checkpoint['epoch']+1
            resumed_loss = checkpoint['loss']

        global_step = 0
        for epoch in range(resumed_epoch, args.epochs):
            # Reshuffle the distributed sampler each epoch.
            train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step, local_rank=args.local_rank)

            # Saving and evaluation happen on rank 0 only.
            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)

                output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="")

                ## Run on val dataset, this process is *TIME-consuming*.
                # logger.info("Eval on val dataset")
                # R1 = eval_epoch(args, model, val_dataloader, device, n_gpu)

                R1 = eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
                # Track the checkpoint with the best text-to-video R@1.
                if best_score <= R1:
                    best_score = R1
                    best_output_model_file = output_model_file
                logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))

        ## Uncomment if want to test on the best checkpoint
        # if args.local_rank == 0:
        #     model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        #     eval_epoch(args, model, test_dataloader, device, n_gpu)

    elif args.do_eval:
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
|
| 1051 |
+
|
| 1052 |
+
# Script entry point: run distributed retrieval training / evaluation.
if __name__ == "__main__":
    main()
|
cache_main_task_retrieval_backup.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
from __future__ import division
|
| 3 |
+
from __future__ import unicode_literals
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
import os
|
| 10 |
+
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
|
| 11 |
+
import time
|
| 12 |
+
import argparse
|
| 13 |
+
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
|
| 14 |
+
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
| 15 |
+
from modules.modeling import CLIP4Clip
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
from modules.optimization import BertAdam
|
| 18 |
+
|
| 19 |
+
from util import parallel_apply, get_logger
|
| 20 |
+
from dataloaders.data_dataloaders import DATALOADER_DICT
|
| 21 |
+
|
| 22 |
+
# Initialize the default process group for multi-GPU training (NCCL backend).
# NOTE(review): this runs at import time, so the script must be launched under
# torch.distributed (torchrun / launch) with the usual env vars set — confirm.
torch.distributed.init_process_group(backend="nccl")

# Module-level logger handle, assigned later in set_seed_logger().
# (`global` at module scope is a no-op; kept as-is for byte-compatibility.)
global logger
|
| 25 |
+
|
| 26 |
+
def get_args(description='CLIP4Clip on Retrieval Task'):
    """Parse and validate command-line arguments for CLIP4Clip train/eval.

    Args:
        description: text shown in the argparse --help header.

    Returns:
        argparse.Namespace with all options; `batch_size` is rescaled to the
        per-accumulation-step micro-batch size.

    Raises:
        ValueError: if gradient_accumulation_steps < 1, or neither
            --do_train nor --do_eval is given.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='')
    parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='')
    parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
    parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
    parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")

    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument("--rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")

    parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.")
    parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")

    parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")

    parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
    parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
                        help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
    parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
                        help="linear projection of flattened patches.")
    parser.add_argument('--sim_header', type=str, default="meanP",
                        choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"],
                        help="choose a similarity header.")

    parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
    parser.add_argument("--use_rff", action='store_true', help="Use RFF hypervector encoding for video embeddings")
    parser.add_argument("--rff_dim", type=int, default=3000, help="Hypervector dimension for RFF encoding")
    parser.add_argument("--use_clip4hashing", action="store_true", help="CLIP4Hashing 손실·해시 경로 사용 여부")
    # FIX: help text previously claimed "default 1024" while the actual default is 2048.
    parser.add_argument("--hash_bit", type=int, default=2048, help="해시 코드 비트 수 (default 2048)")

    args = parser.parse_args()

    # tightTransf is a cross-attention (tight) similarity head, incompatible with loose type.
    if args.sim_header == "tightTransf":
        args.loose_type = False

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # Rescale to the micro-batch size actually fed per forward pass.
    args.batch_size = args.batch_size // args.gradient_accumulation_steps

    return args
|
| 127 |
+
|
| 128 |
+
def set_seed_logger(args):
    """Seed every RNG for reproducibility, record distributed rank info on
    `args`, create the output dir, and install the module-level logger."""
    global logger

    # Make every source of randomness deterministic for this run.
    seed = args.seed
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Record the distributed topology on args and pin this process to its GPU.
    args.world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(args.local_rank)
    args.rank = torch.distributed.get_rank()

    os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    # Only rank 0 dumps the full argument list.
    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(vars(args)):
            logger.info(" <<< {}: {}".format(key, vars(args)[key]))

    return args
|
| 157 |
+
|
| 158 |
+
def init_device(args, local_rank):
    """Pick the device for this rank and sanity-check batch-size divisibility.

    Returns (device, n_gpu); also writes n_gpu onto args.
    """
    global logger

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu", local_rank)

    gpu_count = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, gpu_count))
    args.n_gpu = gpu_count

    # Both train and eval batch sizes must split evenly across the GPUs.
    if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
        raise ValueError(
            "Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
                args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, gpu_count
|
| 172 |
+
|
| 173 |
+
def init_model(args, device, n_gpu, local_rank):
    """Construct the CLIP4Clip model, optionally warm-started from --init_model,
    and move it to `device`."""
    # Load an initial checkpoint onto CPU first (if one was provided).
    state_dict = torch.load(args.init_model, map_location='cpu') if args.init_model else None

    # Prepare model; fall back to the shared pretrained-weights cache when
    # no explicit cache_dir was given.
    cache_dir = args.cache_dir or os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir,
                                      state_dict=state_dict, task_config=args)

    model.to(device)

    return model
|
| 187 |
+
|
| 188 |
+
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
    """Build a BertAdam optimizer with four parameter groups (CLIP vs. non-CLIP
    crossed with decay vs. no-decay) and wrap the model in DDP.

    Returns (optimizer, scheduler, ddp_model); scheduler is always None here
    because BertAdam schedules its own learning rate.
    """
    if hasattr(model, 'module'):
        model = model.module

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    named_params = list(model.named_parameters())

    def _no_decay(name):
        # Biases and LayerNorm params are conventionally excluded from weight decay.
        return any(nd in name for nd in no_decay)

    def _is_clip(name):
        # Parameters of the pretrained CLIP backbone live under the "clip." prefix.
        return "clip." in name

    weight_decay = 0.2
    # CLIP-backbone groups get a learning rate scaled by coef_lr.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not _no_decay(n) and _is_clip(n)],
         'weight_decay': weight_decay, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in named_params if not _no_decay(n) and not _is_clip(n)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in named_params if _no_decay(n) and _is_clip(n)],
         'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in named_params if _no_decay(n) and not _is_clip(n)],
         'weight_decay': 0.0},
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
                         t_total=num_train_optimization_steps, weight_decay=weight_decay,
                         max_grad_norm=1.0)

    # Debug dump right after building the optimizer: each group's lr and a
    # small sample of the parameter names it contains.
    param2name = {id(p): n for n, p in model.named_parameters() if p.requires_grad}
    for gi, g in enumerate(optimizer.param_groups):
        print(f"[group {gi}] lr={g['lr']:.2e}, params={len(g['params'])}")
        for p in g["params"][:8]:
            print(" ", param2name.get(id(p), "?"))

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      output_device=local_rank, find_unused_parameters=True)

    return optimizer, scheduler, model
|
| 233 |
+
|
| 234 |
+
def save_model(epoch, args, model, optimizer, tr_loss, type_name=""):
    """Persist model weights and optimizer state for this epoch.

    Returns the path of the saved model file.
    """
    # Unwrap DDP so only the bare module's weights are serialized.
    model_to_save = model.module if hasattr(model, 'module') else model

    tag = "" if type_name == "" else type_name + "."
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format(tag, epoch))
    optimizer_state_file = os.path.join(
        args.output_dir, "pytorch_opt.bin.{}{}".format(tag, epoch))

    torch.save(model_to_save.state_dict(), output_model_file)
    checkpoint = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': tr_loss,
    }
    torch.save(checkpoint, optimizer_state_file)

    logger.info("Model saved to %s", output_model_file)
    logger.info("Optimizer saved to %s", optimizer_state_file)
    return output_model_file
|
| 250 |
+
|
| 251 |
+
def load_model(epoch, args, n_gpu, device, model_file=None):
    """Load a saved CLIP4Clip checkpoint for the given epoch.

    Args:
        epoch: epoch number used to build the default checkpoint filename.
        args: parsed arguments (uses output_dir, cache_dir, cross_model, local_rank).
        n_gpu: number of GPUs (unused here, kept for interface compatibility).
        device: target device for the loaded model.
        model_file: explicit checkpoint path; when None/empty, defaults to
            "<output_dir>/pytorch_model.bin.<epoch>".

    Returns:
        The loaded model on `device`, or None when the checkpoint file
        does not exist.
    """
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    if not os.path.exists(model_file):
        # BUG FIX: the original fell through to a log statement referencing
        # `cache_dir`, which is only assigned in the success branch, raising
        # UnboundLocalError whenever the checkpoint was missing.
        return None

    model_state_dict = torch.load(model_file, map_location='cpu')
    if args.local_rank == 0:
        logger.info("Model loaded from %s", model_file)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    model.to(device)

    logger.info(f"모델을 로드합니다:{cache_dir}")
    return model
|
| 268 |
+
|
| 269 |
+
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    """Run one training epoch with gradient accumulation.

    Returns (mean loss over all batches, updated global_step). `global_step`
    only advances on optimizer-update steps, not on every batch.
    """
    global logger
    torch.cuda.empty_cache()
    if hasattr(model, 'module'):
        model.module.train()
    else:
        model.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering it-self
            batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask = batch
        # Forward pass returns the training loss directly.
        loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            # Scale so accumulated gradients average over the effective batch.
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        total_loss += float(loss)
        # Only apply an optimizer update once enough micro-batches accumulated.
        if (step + 1) % args.gradient_accumulation_steps == 0:

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            # Clamp CLIP's learned temperature (logit_scale) at log(100),
            # following https://github.com/openai/CLIP/issues/46
            if hasattr(model, 'module'):
                torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))
            else:
                torch.clamp_(model.clip.logit_scale.data, max=np.log(100))

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                # Lr field joins the distinct learning rates across param groups.
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss),
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
|
| 323 |
+
|
| 324 |
+
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
|
| 325 |
+
|
| 326 |
+
sim_matrix = []
|
| 327 |
+
for idx1, b1 in enumerate(batch_list_t):
|
| 328 |
+
input_mask, segment_ids, *_tmp = b1
|
| 329 |
+
sequence_output = batch_sequence_output_list[idx1]
|
| 330 |
+
each_row = []
|
| 331 |
+
for idx2, b2 in enumerate(batch_list_v):
|
| 332 |
+
video_mask, *_tmp = b2
|
| 333 |
+
visual_output = batch_visual_output_list[idx2]
|
| 334 |
+
b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
|
| 335 |
+
loose_type=model.loose_type)
|
| 336 |
+
b1b2_logits = b1b2_logits.cpu().detach().numpy()
|
| 337 |
+
each_row.append(b1b2_logits)
|
| 338 |
+
each_row = np.concatenate(tuple(each_row), axis=-1)
|
| 339 |
+
sim_matrix.append(each_row)
|
| 340 |
+
return sim_matrix
|
| 341 |
+
|
| 342 |
+
def eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer):
|
| 343 |
+
import numpy as np
|
| 344 |
+
import os
|
| 345 |
+
import torch
|
| 346 |
+
import matplotlib.pyplot as plt
|
| 347 |
+
|
| 348 |
+
def _decode_query(tokenizer, ids_tensor):
|
| 349 |
+
# ids_tensor: 1D tensor (token ids)
|
| 350 |
+
try:
|
| 351 |
+
if isinstance(ids_tensor, torch.Tensor):
|
| 352 |
+
ids = ids_tensor.cpu().numpy().tolist()
|
| 353 |
+
else:
|
| 354 |
+
ids = ids_tensor.tolist() if hasattr(ids_tensor, 'tolist') else list(ids_tensor)
|
| 355 |
+
|
| 356 |
+
# ClipTokenizer의 특수 토큰 ID들
|
| 357 |
+
start_token_id = tokenizer.encoder.get('<|startoftext|>', 49406) # 기본값
|
| 358 |
+
end_token_id = tokenizer.encoder.get('<|endoftext|>', 49407) # 기본값
|
| 359 |
+
|
| 360 |
+
# 패딩 토큰(0)과 특수 토큰 제거
|
| 361 |
+
clean_ids = []
|
| 362 |
+
for token_id in ids:
|
| 363 |
+
if token_id > 0 and token_id != start_token_id and token_id != end_token_id:
|
| 364 |
+
clean_ids.append(token_id)
|
| 365 |
+
|
| 366 |
+
if not clean_ids:
|
| 367 |
+
return "<empty_query>"
|
| 368 |
+
|
| 369 |
+
# 유효하지 않은 토큰 ID 필터링 (vocab 범위 내)
|
| 370 |
+
vocab_size = len(tokenizer.decoder)
|
| 371 |
+
valid_ids = [tid for tid in clean_ids if tid < vocab_size]
|
| 372 |
+
|
| 373 |
+
if not valid_ids:
|
| 374 |
+
return "<invalid_tokens>"
|
| 375 |
+
|
| 376 |
+
# 디코딩 시도
|
| 377 |
+
try:
|
| 378 |
+
decoded_text = tokenizer.decode(valid_ids)
|
| 379 |
+
return decoded_text.strip()
|
| 380 |
+
except KeyError as e:
|
| 381 |
+
# 개별 토큰별로 디코딩 시도
|
| 382 |
+
decoded_tokens = []
|
| 383 |
+
for tid in valid_ids:
|
| 384 |
+
if tid in tokenizer.decoder:
|
| 385 |
+
decoded_tokens.append(tokenizer.decoder[tid])
|
| 386 |
+
else:
|
| 387 |
+
decoded_tokens.append(f"<unk_{tid}>")
|
| 388 |
+
text = ''.join(decoded_tokens)
|
| 389 |
+
# BPE 후처리
|
| 390 |
+
text = text.replace('</w>', ' ').strip()
|
| 391 |
+
return text if text else "<decode_partial_error>"
|
| 392 |
+
except Exception as e:
|
| 393 |
+
return f"<decode_error: {str(e)[:50]}>"
|
| 394 |
+
|
| 395 |
+
except Exception as e:
|
| 396 |
+
return f"<general_error: {str(e)[:50]}>"
|
| 397 |
+
|
| 398 |
+
def _get_video_ids_from_dataset(dataset, num_videos):
|
| 399 |
+
# 다양한 후보 속성명 시도 → 없으면 0..N-1 인덱스 문자열로 대체
|
| 400 |
+
for attr in ["video_list", "video_ids", "video_names", "videos", "vids", "id_list"]:
|
| 401 |
+
if hasattr(dataset, attr):
|
| 402 |
+
obj = getattr(dataset, attr)
|
| 403 |
+
try:
|
| 404 |
+
if isinstance(obj, (list, tuple)) and len(obj) == num_videos:
|
| 405 |
+
return list(map(str, obj))
|
| 406 |
+
except Exception:
|
| 407 |
+
pass
|
| 408 |
+
return [str(i) for i in range(num_videos)]
|
| 409 |
+
|
| 410 |
+
global logger
|
| 411 |
+
|
| 412 |
+
multi_sentence_ = False
|
| 413 |
+
cut_off_points_, sentence_num_, video_num_ = [], -1, -1
|
| 414 |
+
|
| 415 |
+
if hasattr(model, 'module'):
|
| 416 |
+
model = model.module.to(device)
|
| 417 |
+
else:
|
| 418 |
+
model = model.to(device)
|
| 419 |
+
|
| 420 |
+
if hasattr(model, 'module'):
|
| 421 |
+
model.module.eval()
|
| 422 |
+
else:
|
| 423 |
+
model.eval()
|
| 424 |
+
|
| 425 |
+
logger.info("Model %s", "training" if model.training else "eval")
|
| 426 |
+
|
| 427 |
+
# suffix for cache/result naming
|
| 428 |
+
suffix = "_hash" if getattr(args, "use_clip4hashing", False) else ""
|
| 429 |
+
suffix += "_rff" if args.use_rff else ""
|
| 430 |
+
suffix += "_trained" if args.init_model else ""
|
| 431 |
+
|
| 432 |
+
# (A) 캐시 로드/생성
|
| 433 |
+
if "train" in args.val_csv and "10k" in args.val_csv:
|
| 434 |
+
cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
|
| 435 |
+
logger.info(f"9k 훈련 데이터 캐시 생성: {cache_name}")
|
| 436 |
+
else:
|
| 437 |
+
cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
|
| 438 |
+
logger.info(f"평가 데이터 캐시: {cache_name}")
|
| 439 |
+
|
| 440 |
+
cache_path = os.path.join(args.output_dir, cache_name)
|
| 441 |
+
|
| 442 |
+
loaded_from_cache = False
|
| 443 |
+
if os.path.exists(cache_path):
|
| 444 |
+
logger.info(f"캐시된 피처를 로드합니다: {cache_path}")
|
| 445 |
+
cache = torch.load(cache_path, map_location=device)
|
| 446 |
+
batch_sequence_output_list = cache['batch_sequence_output_list']
|
| 447 |
+
batch_visual_output_list = cache['batch_visual_output_list']
|
| 448 |
+
batch_list_t = cache['batch_list_t']
|
| 449 |
+
batch_list_v = cache['batch_list_v']
|
| 450 |
+
text_input_ids_list = cache.get('text_input_ids_list', None)
|
| 451 |
+
video_ids = cache.get('video_ids', None)
|
| 452 |
+
loaded_from_cache = True
|
| 453 |
+
|
| 454 |
+
logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
|
| 455 |
+
f"각 텐서 shape={batch_sequence_output_list[0].shape}")
|
| 456 |
+
logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
|
| 457 |
+
f"각 텐서 shape={batch_visual_output_list[0].shape}")
|
| 458 |
+
else:
|
| 459 |
+
print("Caching feature..")
|
| 460 |
+
if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
|
| 461 |
+
test_dataloader.dataset.multi_sentence_per_video:
|
| 462 |
+
multi_sentence_ = True
|
| 463 |
+
cut_off_points_ = test_dataloader.dataset.cut_off_points
|
| 464 |
+
sentence_num_ = test_dataloader.dataset.sentence_num
|
| 465 |
+
video_num_ = test_dataloader.dataset.video_num
|
| 466 |
+
cut_off_points_ = [itm - 1 for itm in cut_off_points_]
|
| 467 |
+
logger.warning("Eval under multi-sentence-per-video. sentence num: %s, video num: %s",
|
| 468 |
+
sentence_num_, video_num_)
|
| 469 |
+
|
| 470 |
+
with torch.no_grad():
|
| 471 |
+
batch_list_t = []
|
| 472 |
+
batch_list_v = []
|
| 473 |
+
batch_sequence_output_list, batch_visual_output_list = [], []
|
| 474 |
+
text_input_ids_list = []
|
| 475 |
+
total_video_num = 0
|
| 476 |
+
|
| 477 |
+
for bid, batch in enumerate(test_dataloader):
|
| 478 |
+
batch = tuple(t.to(device) for t in batch)
|
| 479 |
+
input_ids, input_mask, segment_ids, video, video_mask = batch
|
| 480 |
+
|
| 481 |
+
if multi_sentence_:
|
| 482 |
+
b, *_t = video.shape
|
| 483 |
+
sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
|
| 484 |
+
batch_sequence_output_list.append(sequence_output)
|
| 485 |
+
# input_ids를 함께 보관 (run_on_single_gpu는 *_tmp로 무시하므로 안전)
|
| 486 |
+
batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
|
| 487 |
+
|
| 488 |
+
s_, e_ = total_video_num, total_video_num + b
|
| 489 |
+
filter_inds = [itm - s_ for itm in cut_off_points_ if s_ <= itm < e_]
|
| 490 |
+
if len(filter_inds) > 0:
|
| 491 |
+
video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
|
| 492 |
+
visual_output = model.get_visual_output(video, video_mask)
|
| 493 |
+
batch_visual_output_list.append(visual_output)
|
| 494 |
+
batch_list_v.append((video_mask,))
|
| 495 |
+
total_video_num += b
|
| 496 |
+
else:
|
| 497 |
+
sequence_output, visual_output = model.get_sequence_visual_output(
|
| 498 |
+
input_ids, segment_ids, input_mask, video, video_mask)
|
| 499 |
+
|
| 500 |
+
batch_sequence_output_list.append(sequence_output)
|
| 501 |
+
batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
|
| 502 |
+
|
| 503 |
+
batch_visual_output_list.append(visual_output)
|
| 504 |
+
batch_list_v.append((video_mask,))
|
| 505 |
+
|
| 506 |
+
print("{}/{}\r".format(bid, len(test_dataloader)), end="")
|
| 507 |
+
|
| 508 |
+
# 비디오 ID 목록 구성 (데이터셋 노출 없으면 0..N-1)
|
| 509 |
+
num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
|
| 510 |
+
video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
|
| 511 |
+
|
| 512 |
+
logger.info(f"추출된 피처를 캐시에 저장합니다: {cache_path}")
|
| 513 |
+
torch.save({
|
| 514 |
+
'batch_sequence_output_list': batch_sequence_output_list,
|
| 515 |
+
'batch_visual_output_list': batch_visual_output_list,
|
| 516 |
+
'batch_list_t': batch_list_t,
|
| 517 |
+
'batch_list_v': batch_list_v,
|
| 518 |
+
'text_input_ids_list': text_input_ids_list,
|
| 519 |
+
'video_ids': video_ids,
|
| 520 |
+
}, cache_path)
|
| 521 |
+
|
| 522 |
+
logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
|
| 523 |
+
f"각 텐서 shape={batch_sequence_output_list[0].shape}")
|
| 524 |
+
logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
|
| 525 |
+
f"각 텐서 shape={batch_visual_output_list[0].shape}")
|
| 526 |
+
|
| 527 |
+
# 캐시에 text_input_ids_list가 없으면, 한 번 더 훑어서 수집 (구버전 캐시 호환)
|
| 528 |
+
if loaded_from_cache and 'text_input_ids_list' not in cache:
|
| 529 |
+
logger.info("캐시에 text_input_ids_list가 없어 재수집합니다(호환성 경로).")
|
| 530 |
+
text_input_ids_list = []
|
| 531 |
+
with torch.no_grad():
|
| 532 |
+
for batch in test_dataloader:
|
| 533 |
+
input_ids = batch[0].detach().cpu()
|
| 534 |
+
text_input_ids_list.append(input_ids)
|
| 535 |
+
elif loaded_from_cache and text_input_ids_list is None:
|
| 536 |
+
# batch_list_t에서 input_ids 추출
|
| 537 |
+
logger.info("batch_list_t에서 input_ids를 추출합니다.")
|
| 538 |
+
text_input_ids_list = []
|
| 539 |
+
for input_mask, segment_ids, input_ids in batch_list_t:
|
| 540 |
+
text_input_ids_list.append(input_ids)
|
| 541 |
+
|
| 542 |
+
# video_ids가 없으면 만들어준다(구버전 캐시 호환)
|
| 543 |
+
if loaded_from_cache and cache.get('video_ids', None) is None:
|
| 544 |
+
num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
|
| 545 |
+
video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
|
| 546 |
+
|
| 547 |
+
# (B) 유사도 행렬 계산
|
| 548 |
+
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
|
| 549 |
+
sim_matrix = []
|
| 550 |
+
for idx1, b1 in enumerate(batch_list_t):
|
| 551 |
+
input_mask, segment_ids, *_tmp = b1
|
| 552 |
+
sequence_output = batch_sequence_output_list[idx1]
|
| 553 |
+
each_row = []
|
| 554 |
+
for idx2, b2 in enumerate(batch_list_v):
|
| 555 |
+
video_mask, *_tmp = b2
|
| 556 |
+
visual_output = batch_visual_output_list[idx2]
|
| 557 |
+
b1b2_logits, *_tmp = model.get_similarity_logits(
|
| 558 |
+
sequence_output, visual_output, input_mask, video_mask, loose_type=model.loose_type)
|
| 559 |
+
b1b2_logits = b1b2_logits.cpu().detach().numpy()
|
| 560 |
+
each_row.append(b1b2_logits)
|
| 561 |
+
each_row = np.concatenate(tuple(each_row), axis=-1)
|
| 562 |
+
sim_matrix.append(each_row)
|
| 563 |
+
return sim_matrix
|
| 564 |
+
|
| 565 |
+
if n_gpu > 1:
|
| 566 |
+
device_ids = list(range(n_gpu))
|
| 567 |
+
batch_list_t_splits, batch_list_v_splits = [], []
|
| 568 |
+
batch_t_output_splits, batch_v_output_splits = [], []
|
| 569 |
+
bacth_len = len(batch_list_t)
|
| 570 |
+
split_len = (bacth_len + n_gpu - 1) // n_gpu
|
| 571 |
+
for dev_id in device_ids:
|
| 572 |
+
s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
|
| 573 |
+
if dev_id == 0:
|
| 574 |
+
batch_list_t_splits.append(batch_list_t[s_:e_]); batch_list_v_splits.append(batch_list_v)
|
| 575 |
+
batch_t_output_splits.append(batch_sequence_output_list[s_:e_]); batch_v_output_splits.append(batch_visual_output_list)
|
| 576 |
+
else:
|
| 577 |
+
devc = torch.device(f'cuda:{dev_id}')
|
| 578 |
+
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
|
| 579 |
+
batch_list_t_splits.append(devc_batch_list)
|
| 580 |
+
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
|
| 581 |
+
batch_list_v_splits.append(devc_batch_list)
|
| 582 |
+
devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
|
| 583 |
+
batch_t_output_splits.append(devc_batch_list)
|
| 584 |
+
devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
|
| 585 |
+
batch_v_output_splits.append(devc_batch_list)
|
| 586 |
+
|
| 587 |
+
parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
|
| 588 |
+
batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
|
| 589 |
+
parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
|
| 590 |
+
sim_matrix = []
|
| 591 |
+
for idx in range(len(parallel_outputs)):
|
| 592 |
+
sim_matrix += parallel_outputs[idx]
|
| 593 |
+
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
| 594 |
+
else:
|
| 595 |
+
sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v,
|
| 596 |
+
batch_sequence_output_list, batch_visual_output_list)
|
| 597 |
+
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
| 598 |
+
|
| 599 |
+
# (C) 멀티센텐스 처리 및 메트릭
|
| 600 |
+
if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
|
| 601 |
+
test_dataloader.dataset.multi_sentence_per_video:
|
| 602 |
+
multi_sentence_ = True
|
| 603 |
+
|
| 604 |
+
if multi_sentence_:
|
| 605 |
+
logger.info("before reshape, sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
|
| 606 |
+
sim_matrix_flat = sim_matrix.copy() # 쿼리별 Top-K 용 2D 보관
|
| 607 |
+
|
| 608 |
+
cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
|
| 609 |
+
max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)])
|
| 610 |
+
sim_matrix_new = []
|
| 611 |
+
for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_):
|
| 612 |
+
sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
|
| 613 |
+
np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0))
|
| 614 |
+
sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
|
| 615 |
+
logger.info("after reshape, sim matrix size: %d x %d x %d",
|
| 616 |
+
sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])
|
| 617 |
+
|
| 618 |
+
tv_metrics = tensor_text_to_video_metrics(sim_matrix)
|
| 619 |
+
vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
|
| 620 |
+
else:
|
| 621 |
+
logger.info("sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
|
| 622 |
+
# 히트맵 저장(샘플)
|
| 623 |
+
plt.figure(figsize=(8,6))
|
| 624 |
+
plt.imshow(sim_matrix[:100, :100], aspect='auto')
|
| 625 |
+
plt.title('Similarity Matrix Heatmap')
|
| 626 |
+
plt.xlabel('Video Index')
|
| 627 |
+
plt.ylabel('Text Index')
|
| 628 |
+
plt.tight_layout()
|
| 629 |
+
out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
|
| 630 |
+
plt.savefig(out_path); plt.close()
|
| 631 |
+
logger.info(f"Saved sim_matrix heatmap to {out_path}")
|
| 632 |
+
|
| 633 |
+
sim_matrix_flat = sim_matrix # 2D 그대로
|
| 634 |
+
tv_metrics = compute_metrics(sim_matrix)
|
| 635 |
+
vt_metrics = compute_metrics(sim_matrix.T)
|
| 636 |
+
logger.info('\t Length-T: %d, Length-V:%d', len(sim_matrix), len(sim_matrix[0]))
|
| 637 |
+
|
| 638 |
+
logger.info("Text-to-Video:")
|
| 639 |
+
logger.info('\t>>> R@1: %.1f - R@5: %.1f - R@10: %.1f - Median R: %.1f - Mean R: %.1f',
|
| 640 |
+
tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])
|
| 641 |
+
logger.info("Video-to-Text:")
|
| 642 |
+
logger.info('\t>>> V2T$R@1: %.1f - V2T$R@5: %.1f - V2T$R@10: %.1f - V2T$Median R: %.1f - V2T$Mean R: %.1f',
|
| 643 |
+
vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])
|
| 644 |
+
|
| 645 |
+
# (D) 쿼리 텍스트 복원 + Top-10 덤프
|
| 646 |
+
# text_input_ids_list: List[Tensor[B_i, L]]
|
| 647 |
+
all_queries = []
|
| 648 |
+
logger.info(f"text_input_ids_list 개수: {len(text_input_ids_list)}")
|
| 649 |
+
|
| 650 |
+
for batch_idx, ids_batch in enumerate(text_input_ids_list):
|
| 651 |
+
if ids_batch is None:
|
| 652 |
+
logger.warning(f"배치 {batch_idx}: ids_batch가 None입니다.")
|
| 653 |
+
continue
|
| 654 |
+
|
| 655 |
+
try:
|
| 656 |
+
ids_batch = ids_batch if isinstance(ids_batch, torch.Tensor) else torch.as_tensor(ids_batch)
|
| 657 |
+
logger.info(f"배치 {batch_idx}: shape={ids_batch.shape}")
|
| 658 |
+
|
| 659 |
+
for row_idx, row in enumerate(ids_batch):
|
| 660 |
+
decoded = _decode_query(tokenizer, row)
|
| 661 |
+
all_queries.append(decoded)
|
| 662 |
+
if batch_idx == 0 and row_idx < 3: # 첫 배치의 처음 3개만 샘플로 출력
|
| 663 |
+
logger.info(f"샘플 디코딩 결과 [{batch_idx}-{row_idx}]: '{decoded}'")
|
| 664 |
+
|
| 665 |
+
except Exception as e:
|
| 666 |
+
logger.error(f"배치 {batch_idx} 처리 중 오류: {str(e)}")
|
| 667 |
+
# 에러가 발생해도 계속 진행
|
| 668 |
+
continue
|
| 669 |
+
|
| 670 |
+
logger.info(f"총 {len(all_queries)}개의 쿼리가 디코딩되었습니다.")
|
| 671 |
+
|
| 672 |
+
# video_ids 길이 보정(안전)
|
| 673 |
+
num_videos = sim_matrix_flat.shape[1]
|
| 674 |
+
if 'video_ids' in locals():
|
| 675 |
+
if len(video_ids) != num_videos:
|
| 676 |
+
logger.warning("video_ids 길이(%d)와 비디오 수(%d)가 달라 index로 대체합니다.",
|
| 677 |
+
len(video_ids), num_videos)
|
| 678 |
+
video_ids = [str(i) for i in range(num_videos)]
|
| 679 |
+
else:
|
| 680 |
+
video_ids = [str(i) for i in range(num_videos)]
|
| 681 |
+
|
| 682 |
+
# 저장 파일
|
| 683 |
+
topk = 10
|
| 684 |
+
out_tsv = os.path.join(args.output_dir, f"t2v_top10{suffix}.tsv")
|
| 685 |
+
out_json = os.path.join(args.output_dir, f"t2v_top10{suffix}.json")
|
| 686 |
+
|
| 687 |
+
if args.local_rank == 0:
|
| 688 |
+
import json
|
| 689 |
+
|
| 690 |
+
# TSV 파일 저장
|
| 691 |
+
with open(out_tsv, "w", encoding="utf-8") as f:
|
| 692 |
+
f.write("query_idx\tquery\tvideo_rank\tvideo_id\tvideo_idx\tscore\n")
|
| 693 |
+
for qi, q in enumerate(all_queries):
|
| 694 |
+
scores = sim_matrix_flat[qi]
|
| 695 |
+
# 효율: argpartition 후 정렬
|
| 696 |
+
idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 697 |
+
idxs = idxs[np.argsort(-scores[idxs])]
|
| 698 |
+
for rank, vidx in enumerate(idxs, 1):
|
| 699 |
+
f.write(f"{qi}\t{q}\t{rank}\t{video_ids[vidx]}\t{int(vidx)}\t{float(scores[vidx]):.6f}\n")
|
| 700 |
+
|
| 701 |
+
# JSON 파일 저장 (구조화된 형태)
|
| 702 |
+
results_dict = {}
|
| 703 |
+
for qi, q in enumerate(all_queries):
|
| 704 |
+
scores = sim_matrix_flat[qi]
|
| 705 |
+
idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 706 |
+
idxs = idxs[np.argsort(-scores[idxs])]
|
| 707 |
+
|
| 708 |
+
results_dict[f"query_{qi+1}"] = {
|
| 709 |
+
"query_text": q,
|
| 710 |
+
"top_videos": []
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
for rank, vidx in enumerate(idxs, 1):
|
| 714 |
+
results_dict[f"query_{qi+1}"]["top_videos"].append({
|
| 715 |
+
"rank": rank,
|
| 716 |
+
"video_id": video_ids[vidx],
|
| 717 |
+
"video_idx": int(vidx),
|
| 718 |
+
"score": float(scores[vidx])
|
| 719 |
+
})
|
| 720 |
+
|
| 721 |
+
with open(out_json, "w", encoding="utf-8") as f:
|
| 722 |
+
json.dump(results_dict, f, ensure_ascii=False, indent=2)
|
| 723 |
+
|
| 724 |
+
logger.info("T2V Top-10 per query 저장 완료:")
|
| 725 |
+
logger.info(" TSV 파일: %s", out_tsv)
|
| 726 |
+
logger.info(" JSON 파일: %s", out_json)
|
| 727 |
+
logger.info("총 %d개 쿼리에 대한 top-10 결과가 저장되었습니다.", len(all_queries))
|
| 728 |
+
|
| 729 |
+
# 로그에 모든 쿼리의 Top-10 결과 출력
|
| 730 |
+
logger.info("=== Query-wise Top-10 Results (전체 %d개 쿼리) ===", len(all_queries))
|
| 731 |
+
for qi in range(len(all_queries)):
|
| 732 |
+
scores = sim_matrix_flat[qi]
|
| 733 |
+
idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
|
| 734 |
+
idxs = idxs[np.argsort(-scores[idxs])]
|
| 735 |
+
|
| 736 |
+
logger.info(f"Query {qi+1}: \"{all_queries[qi]}\"")
|
| 737 |
+
for rank, vidx in enumerate(idxs, 1):
|
| 738 |
+
logger.info(f" Rank {rank}: video_id={video_ids[vidx]}, video_idx={vidx}, score={scores[vidx]:.6f}")
|
| 739 |
+
logger.info("---")
|
| 740 |
+
|
| 741 |
+
logger.info("=== 모든 쿼리 결과 출력 완료 ===")
|
| 742 |
+
|
| 743 |
+
return tv_metrics['R1']
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def main():
    """Entry point: build model/dataloaders, then run training or evaluation.

    Reads all configuration from command-line args (``get_args``), optionally
    freezes lower CLIP layers, loads the MSR-VTT style dataloaders, and either
    trains for ``args.epochs`` (tracking the best R@1 checkpoint) or runs a
    single evaluation pass.
    """
    global logger
    args = get_args()
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    assert args.task_type == "retrieval"
    model = init_model(args, device, n_gpu, args.local_rank)

    ## ####################################
    # freeze testing
    ## ####################################
    assert args.freeze_layer_num <= 12 and args.freeze_layer_num >= -1
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, param in model.clip.named_parameters():
            # top layers always need to train
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                continue    # need to train
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue    # need to train

            # BUGFIX: the original tested `name.find("conv2.")` as a truthy value,
            # which is -1 (truthy) when "conv2." is ABSENT and a non-negative index
            # when present — the inverse of the intent. With the 3D linear patch the
            # conv2 weights are new and must stay trainable, so skip freezing any
            # parameter whose name contains "conv2.".
            if args.linear_patch == "3d" and name.find("conv2.") >= 0:
                continue
            else:
                # parameters below freeze_layer_num are frozen
                param.requires_grad = False

    ## ####################################
    # dataloader loading
    ## ####################################
    assert args.datatype in DATALOADER_DICT

    assert DATALOADER_DICT[args.datatype]["test"] is not None \
        or DATALOADER_DICT[args.datatype]["val"] is not None

    test_dataloader, test_length = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)

    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_dataloader, val_length = test_dataloader, test_length

    ## report validation results if the ["test"] is None
    if test_dataloader is None:
        test_dataloader, test_length = val_dataloader, val_length

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", test_length)
        logger.info("  Batch size = %d", args.batch_size_val)
        logger.info("  Num steps = %d", len(test_dataloader))
        logger.info("***** Running val *****")
        logger.info("  Num examples = %d", val_length)

    ## ####################################
    # train and eval
    ## ####################################
    if args.do_train:
        train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        # BUGFIX: ceil-divide with integer arithmetic. The original wrapped only
        # the numerator in int() and then used true division, producing a float
        # step count that was silently passed to the optimizer/scheduler.
        num_train_optimization_steps = ((len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        // args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", train_length)
            logger.info("  Batch size = %d", args.batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = "None"
        ## ##############################################################
        # resume optimizer state besides loss to continue train
        ## ##############################################################
        resumed_epoch = 0
        if args.resume_model:
            checkpoint = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            resumed_epoch = checkpoint['epoch'] + 1
            resumed_loss = checkpoint['loss']

        global_step = 0
        for epoch in range(resumed_epoch, args.epochs):
            train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step, local_rank=args.local_rank)

            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)

                output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="")

                ## Run on val dataset, this process is *TIME-consuming*.
                # logger.info("Eval on val dataset")
                # R1 = eval_epoch(args, model, val_dataloader, device, n_gpu)

                R1 = eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
                if best_score <= R1:
                    best_score = R1
                    best_output_model_file = output_model_file
                logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))

        ## Uncomment if want to test on the best checkpoint
        # if args.local_rank == 0:
        #     model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        #     eval_epoch(args, model, test_dataloader, device, n_gpu)

    elif args.do_eval:
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
|
| 865 |
+
|
| 866 |
+
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
ckpts/cache_train_9k/log.txt
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-08-12 02:24:18,172:INFO: Effective parameters:
|
| 2 |
+
2025-08-12 02:24:18,172:INFO: <<< batch_size: 128
|
| 3 |
+
2025-08-12 02:24:18,172:INFO: <<< batch_size_val: 16
|
| 4 |
+
2025-08-12 02:24:18,172:INFO: <<< cache_dir:
|
| 5 |
+
2025-08-12 02:24:18,172:INFO: <<< coef_lr: 0.001
|
| 6 |
+
2025-08-12 02:24:18,172:INFO: <<< cross_model: cross-base
|
| 7 |
+
2025-08-12 02:24:18,173:INFO: <<< cross_num_hidden_layers: 4
|
| 8 |
+
2025-08-12 02:24:18,173:INFO: <<< data_path: /disk/gjw/msr-vtt/MSRVTT_data.json
|
| 9 |
+
2025-08-12 02:24:18,173:INFO: <<< datatype: msrvtt
|
| 10 |
+
2025-08-12 02:24:18,173:INFO: <<< do_eval: True
|
| 11 |
+
2025-08-12 02:24:18,173:INFO: <<< do_lower_case: False
|
| 12 |
+
2025-08-12 02:24:18,173:INFO: <<< do_pretrain: False
|
| 13 |
+
2025-08-12 02:24:18,173:INFO: <<< do_train: False
|
| 14 |
+
2025-08-12 02:24:18,173:INFO: <<< epochs: 5
|
| 15 |
+
2025-08-12 02:24:18,173:INFO: <<< eval_frame_order: 0
|
| 16 |
+
2025-08-12 02:24:18,173:INFO: <<< expand_msrvtt_sentences: True
|
| 17 |
+
2025-08-12 02:24:18,173:INFO: <<< feature_framerate: 1
|
| 18 |
+
2025-08-12 02:24:18,173:INFO: <<< features_path: /disk/gjw/msr-vtt/compressed_videos
|
| 19 |
+
2025-08-12 02:24:18,173:INFO: <<< fp16: False
|
| 20 |
+
2025-08-12 02:24:18,173:INFO: <<< fp16_opt_level: O1
|
| 21 |
+
2025-08-12 02:24:18,173:INFO: <<< freeze_layer_num: 0
|
| 22 |
+
2025-08-12 02:24:18,173:INFO: <<< gradient_accumulation_steps: 1
|
| 23 |
+
2025-08-12 02:24:18,173:INFO: <<< hard_negative_rate: 0.5
|
| 24 |
+
2025-08-12 02:24:18,173:INFO: <<< hash_bit: 2048
|
| 25 |
+
2025-08-12 02:24:18,174:INFO: <<< init_model: ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0
|
| 26 |
+
2025-08-12 02:24:18,174:INFO: <<< linear_patch: 2d
|
| 27 |
+
2025-08-12 02:24:18,174:INFO: <<< local_rank: 0
|
| 28 |
+
2025-08-12 02:24:18,174:INFO: <<< loose_type: True
|
| 29 |
+
2025-08-12 02:24:18,174:INFO: <<< lr: 0.0001
|
| 30 |
+
2025-08-12 02:24:18,174:INFO: <<< lr_decay: 0.9
|
| 31 |
+
2025-08-12 02:24:18,174:INFO: <<< margin: 0.1
|
| 32 |
+
2025-08-12 02:24:18,174:INFO: <<< max_frames: 12
|
| 33 |
+
2025-08-12 02:24:18,174:INFO: <<< max_words: 32
|
| 34 |
+
2025-08-12 02:24:18,174:INFO: <<< n_display: 50
|
| 35 |
+
2025-08-12 02:24:18,174:INFO: <<< n_gpu: 1
|
| 36 |
+
2025-08-12 02:24:18,174:INFO: <<< n_pair: 1
|
| 37 |
+
2025-08-12 02:24:18,174:INFO: <<< negative_weighting: 1
|
| 38 |
+
2025-08-12 02:24:18,174:INFO: <<< num_thread_reader: 0
|
| 39 |
+
2025-08-12 02:24:18,174:INFO: <<< output_dir: ckpts/cache_train_9k
|
| 40 |
+
2025-08-12 02:24:18,174:INFO: <<< pretrained_clip_name: ViT-B/32
|
| 41 |
+
2025-08-12 02:24:18,174:INFO: <<< rank: 0
|
| 42 |
+
2025-08-12 02:24:18,174:INFO: <<< resume_model: None
|
| 43 |
+
2025-08-12 02:24:18,174:INFO: <<< rff_dim: 3000
|
| 44 |
+
2025-08-12 02:24:18,175:INFO: <<< sampled_use_mil: False
|
| 45 |
+
2025-08-12 02:24:18,175:INFO: <<< seed: 42
|
| 46 |
+
2025-08-12 02:24:18,175:INFO: <<< sim_header: meanP
|
| 47 |
+
2025-08-12 02:24:18,175:INFO: <<< slice_framepos: 2
|
| 48 |
+
2025-08-12 02:24:18,175:INFO: <<< task_type: retrieval
|
| 49 |
+
2025-08-12 02:24:18,175:INFO: <<< text_num_hidden_layers: 12
|
| 50 |
+
2025-08-12 02:24:18,175:INFO: <<< train_csv: /disk/gjw/msr-vtt/MSRVTT_train.9k.csv
|
| 51 |
+
2025-08-12 02:24:18,175:INFO: <<< train_frame_order: 0
|
| 52 |
+
2025-08-12 02:24:18,175:INFO: <<< use_clip4hashing: False
|
| 53 |
+
2025-08-12 02:24:18,175:INFO: <<< use_mil: False
|
| 54 |
+
2025-08-12 02:24:18,175:INFO: <<< use_rff: False
|
| 55 |
+
2025-08-12 02:24:18,175:INFO: <<< val_csv: /disk/gjw/msr-vtt/MSRVTT_JSFUSION_train_test_10k.csv
|
| 56 |
+
2025-08-12 02:24:18,175:INFO: <<< video_dim: 1024
|
| 57 |
+
2025-08-12 02:24:18,175:INFO: <<< visual_num_hidden_layers: 12
|
| 58 |
+
2025-08-12 02:24:18,175:INFO: <<< warmup_proportion: 0.1
|
| 59 |
+
2025-08-12 02:24:18,175:INFO: <<< world_size: 1
|
| 60 |
+
2025-08-12 02:24:18,176:INFO: device: cuda:0 n_gpu: 1
|
| 61 |
+
2025-08-12 02:24:19,207:INFO: loading archive file /disk/gjw/CLIP4Clip/modules/cross-base
|
| 62 |
+
2025-08-12 02:24:19,207:INFO: Model config {
|
| 63 |
+
"attention_probs_dropout_prob": 0.1,
|
| 64 |
+
"hidden_act": "gelu",
|
| 65 |
+
"hidden_dropout_prob": 0.1,
|
| 66 |
+
"hidden_size": 512,
|
| 67 |
+
"initializer_range": 0.02,
|
| 68 |
+
"intermediate_size": 2048,
|
| 69 |
+
"max_position_embeddings": 128,
|
| 70 |
+
"num_attention_heads": 8,
|
| 71 |
+
"num_hidden_layers": 4,
|
| 72 |
+
"type_vocab_size": 2,
|
| 73 |
+
"vocab_size": 512
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
2025-08-12 02:24:19,208:INFO: Weight doesn't exsits. /disk/gjw/CLIP4Clip/modules/cross-base/cross_pytorch_model.bin
|
| 77 |
+
2025-08-12 02:24:19,208:WARNING: Stage-One:True, Stage-Two:False
|
| 78 |
+
2025-08-12 02:24:19,208:WARNING: Test retrieval by loose type.
|
| 79 |
+
2025-08-12 02:24:19,208:WARNING: embed_dim: 512
|
| 80 |
+
2025-08-12 02:24:19,208:WARNING: image_resolution: 224
|
| 81 |
+
2025-08-12 02:24:19,208:WARNING: vision_layers: 12
|
| 82 |
+
2025-08-12 02:24:19,208:WARNING: vision_width: 768
|
| 83 |
+
2025-08-12 02:24:19,208:WARNING: vision_patch_size: 32
|
| 84 |
+
2025-08-12 02:24:19,208:WARNING: context_length: 77
|
| 85 |
+
2025-08-12 02:24:19,209:WARNING: vocab_size: 49408
|
| 86 |
+
2025-08-12 02:24:19,209:WARNING: transformer_width: 512
|
| 87 |
+
2025-08-12 02:24:19,209:WARNING: transformer_heads: 8
|
| 88 |
+
2025-08-12 02:24:19,209:WARNING: transformer_layers: 12
|
| 89 |
+
2025-08-12 02:24:19,209:WARNING: linear_patch: 2d
|
| 90 |
+
2025-08-12 02:24:19,209:WARNING: cut_top_layer: 0
|
| 91 |
+
2025-08-12 02:24:22,820:WARNING: sim_header: meanP
|
| 92 |
+
2025-08-12 02:24:31,116:INFO: --------------------
|
| 93 |
+
2025-08-12 02:24:31,117:INFO: Weights from pretrained model not used in CLIP4Clip:
|
| 94 |
+
clip.input_resolution
|
| 95 |
+
clip.context_length
|
| 96 |
+
clip.vocab_size
|
| 97 |
+
2025-08-12 02:24:32,963:INFO: ***** Running test *****
|
| 98 |
+
2025-08-12 02:24:32,963:INFO: Num examples = 10000
|
| 99 |
+
2025-08-12 02:24:32,963:INFO: Batch size = 16
|
| 100 |
+
2025-08-12 02:24:32,963:INFO: Num steps = 625
|
| 101 |
+
2025-08-12 02:24:32,963:INFO: ***** Running val *****
|
| 102 |
+
2025-08-12 02:24:32,963:INFO: Num examples = 10000
|
| 103 |
+
2025-08-12 02:24:32,966:INFO: Model testing
|
| 104 |
+
2025-08-12 02:24:32,966:INFO: 9k 훈련 데이터 캐시 생성: msrvtt_train_test_10k_cache_trained.pt
|
| 105 |
+
2025-08-12 02:24:32,966:INFO: suffix: _trained
|
| 106 |
+
2025-08-12 05:49:06,718:INFO: 추출된 피처를 캐시에 저장합니다: ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt
|
| 107 |
+
2025-08-12 05:49:07,252:INFO: [Cache] 텍스트 피쳐 개수=625각 텐서 shape=torch.Size([16, 1, 512])
|
| 108 |
+
2025-08-12 05:49:07,253:INFO: [Cache] 비디오 피쳐 개수=625각 텐서 shape=torch.Size([16, 12, 512])
|
| 109 |
+
2025-08-12 05:51:36,241:INFO: sim matrix size: 10000, 10000
|
| 110 |
+
2025-08-12 05:51:36,355:INFO: Saved sim_matrix heatmap to ckpts/cache_train_9k/sim_matrix_heatmap.png
|
| 111 |
+
2025-08-12 05:53:13,421:INFO: Length-T: 10000, Length-V:10000
|
| 112 |
+
2025-08-12 05:53:13,421:INFO: Text-to-Video:
|
| 113 |
+
2025-08-12 05:53:13,421:INFO: >>> R@1: 26.6 - R@5: 51.2 - R@10: 61.3 - Median R: 5.0 - Mean R: 127.3
|
| 114 |
+
2025-08-12 05:53:13,421:INFO: Video-to-Text:
|
| 115 |
+
2025-08-12 05:53:13,421:INFO: >>> V2T$R@1: 17.9 - V2T$R@5: 32.8 - V2T$R@10: 39.8 - V2T$Median R: 25.0 - V2T$Mean R: 199.8
|
ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:840d1fbfa39d4fda94294e01b7a2d7bdf409ab04154f9c38ebabc79dec4902e0
|
| 3 |
+
size 273271799
|
ckpts/cache_train_9k/sim_matrix_heatmap.png
ADDED
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:326307d8fa7b94df2ae84ef2a983ff9cc7bc771011435d438834849340bbceff
|
| 3 |
+
size 27327894
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_binary_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f276f0ad05bc6827baef7bc199dbce104552d818e71b6003412360c423c5340
|
| 3 |
+
size 27623162
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa0cb7b6f93768d7df200b4da4f666c7a326ef2019697d43deef4e7690df65ec
|
| 3 |
+
size 27620488
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee6b844c4664b89f5fdb5575807a091b22337037716c396837826f694fef2e5e
|
| 3 |
+
size 359716242
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71f4d99edafdbc5681c53834aa5fed86e53c92f87c97ef4201967cda0d04ee1b
|
| 3 |
+
size 494597414
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/sim_matrix_heatmap.png
ADDED
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.tsv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
query_idx query video_rank video_id video_idx score
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.tsv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
query_idx query video_rank video_id video_idx score
|
ckpts/ckpt_msrvtt_retrieval_looseType/log.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6ce285755f9c657d33f5215e1796baf094cda32a8728a2189864baa413e5a4e
|
| 3 |
+
size 22286632
|
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42ff89065849459a75dcbdd4058362bf4071f47ff542b1d6a4f8db03389af43e
|
| 3 |
+
size 27327893
|
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_binary_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:36ee838a7701808f023b70616e893119fa5417eaff0406989d6664d9d1ccc554
|
| 3 |
+
size 27623162
|
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42c7b8bbca50170c1e2c4227b452028e744a6ff390b1b57107399baa11189dda
|
| 3 |
+
size 27620488
|
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_trained.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:450bd99bb45b5ed1855e21e0a8d153d9b61a653944076882b36f66042473686a
|
| 3 |
+
size 27327893
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:905df44d3f0b80981471701f4f7a2ec8430d57f9cedf38eec139a56a7b184858
|
| 3 |
+
size 353556437
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2ae61ac9ce4df7c3c99b825c8c8a7f20cc25337aae07e297312906075d94930
|
| 3 |
+
size 353556437
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20804da0ab5454d35c2ba99fe65f8e08bd1b86c4966fcf961f8e143d659b9066
|
| 3 |
+
size 353556437
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb87fe2412b22327ae764affbbdc056aab4a7191f541036698af0a2744af7219
|
| 3 |
+
size 353556437
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d7e2ba508986c7347380767233b017d4ac4c9fabc135f3457de8625dc833dca
|
| 3 |
+
size 353556437
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8795af40e7f2e6a9e166657b854f20aeb606843a2d8b67021155160f50c9de44
|
| 3 |
+
size 494600757
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89dfc0f35eb9cdabe12930b544434c30e1fcb84f8ac31004fe31a6c9c3747f3c
|
| 3 |
+
size 494600757
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:93a48879a8b7b7a9c5772f9259de9dfdef7698b06e74811df6ff255737a257f0
|
| 3 |
+
size 494600757
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:adfc35ef88500e828fa539e426a46b03dc1f989b5203636e41c228f4ed1be798
|
| 3 |
+
size 494600757
|
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:27c179a6d93b10941a88e691d2213a4c057c9a5d862ac409f26041dacd7fd85c
|
| 3 |
+
size 494600757
|
ckpts/ckpt_msrvtt_retrieval_looseType/sim_matrix_heatmap.png
ADDED
|
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.tsv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
query_idx query video_rank video_id video_idx score
|
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.tsv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
query_idx query video_rank video_id video_idx score
|