jaewooo commited on
Commit
de15dc5
·
verified ·
1 Parent(s): 71ebe1d

Initial upload

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .claude/settings.local.json +15 -0
  2. .gitattributes +20 -0
  3. .gitignore +2 -0
  4. .vscode/settings.json +5 -0
  5. CLAUDE.md +137 -0
  6. CLIP4Clip.png +3 -0
  7. LICENSE +21 -0
  8. README.md +193 -0
  9. __pycache__/metrics.cpython-312.pyc +0 -0
  10. __pycache__/metrics.cpython-37.pyc +0 -0
  11. __pycache__/metrics.cpython-39.pyc +0 -0
  12. __pycache__/simple_dataloaders.cpython-37.pyc +0 -0
  13. __pycache__/util.cpython-312.pyc +0 -0
  14. __pycache__/util.cpython-37.pyc +0 -0
  15. cache_main_task_retrieval.py +1053 -0
  16. cache_main_task_retrieval_backup.py +867 -0
  17. ckpts/cache_train_9k/log.txt +115 -0
  18. ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt +3 -0
  19. ckpts/cache_train_9k/sim_matrix_heatmap.png +0 -0
  20. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/log.txt +0 -0
  21. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache.pt +3 -0
  22. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_binary_trained.pt +3 -0
  23. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_trained.pt +3 -0
  24. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0 +3 -0
  25. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0 +3 -0
  26. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/sim_matrix_heatmap.png +0 -0
  27. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.json +1 -0
  28. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.tsv +1 -0
  29. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.json +1 -0
  30. ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.tsv +1 -0
  31. ckpts/ckpt_msrvtt_retrieval_looseType/log.txt +3 -0
  32. ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache.pt +3 -0
  33. ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_binary_trained.pt +3 -0
  34. ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_trained.pt +3 -0
  35. ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_trained.pt +3 -0
  36. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0 +3 -0
  37. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1 +3 -0
  38. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2 +3 -0
  39. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3 +3 -0
  40. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4 +3 -0
  41. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0 +3 -0
  42. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1 +3 -0
  43. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2 +3 -0
  44. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3 +3 -0
  45. ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4 +3 -0
  46. ckpts/ckpt_msrvtt_retrieval_looseType/sim_matrix_heatmap.png +0 -0
  47. ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.json +1 -0
  48. ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.tsv +1 -0
  49. ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.json +1 -0
  50. ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.tsv +1 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(grep:*)",
5
+ "Bash(python:*)",
6
+ "Bash(conda create:*)",
7
+ "Bash(conda activate:*)",
8
+ "Bash(pip install:*)",
9
+ "Bash(source:*)",
10
+ "Bash(pkill:*)",
11
+ "Bash(chmod:*)"
12
+ ],
13
+ "deny": []
14
+ }
15
+ }
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ CLIP4Clip.png filter=lfs diff=lfs merge=lfs -text
37
+ ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
38
+ ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
39
+ ckpts/ckpt_msrvtt_retrieval_looseType/log.txt filter=lfs diff=lfs merge=lfs -text
40
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
41
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1 filter=lfs diff=lfs merge=lfs -text
42
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2 filter=lfs diff=lfs merge=lfs -text
43
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3 filter=lfs diff=lfs merge=lfs -text
44
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4 filter=lfs diff=lfs merge=lfs -text
45
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
46
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1 filter=lfs diff=lfs merge=lfs -text
47
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2 filter=lfs diff=lfs merge=lfs -text
48
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3 filter=lfs diff=lfs merge=lfs -text
49
+ ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4 filter=lfs diff=lfs merge=lfs -text
50
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.0 filter=lfs diff=lfs merge=lfs -text
51
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.1 filter=lfs diff=lfs merge=lfs -text
52
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_model.bin.2 filter=lfs diff=lfs merge=lfs -text
53
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.0 filter=lfs diff=lfs merge=lfs -text
54
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.1 filter=lfs diff=lfs merge=lfs -text
55
+ ckpts/ckpt_msrvtt_retrieval_looseType_0909/pytorch_opt.bin.2 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .git
2
+ .idea
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
3
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
4
+ "python-envs.pythonProjects": []
5
+ }
CLAUDE.md ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ CLIP4Clip is a video-text retrieval model based on OpenAI's CLIP (ViT-B). The project investigates three similarity calculation approaches: parameter-free type, sequential type, and tight type. This repository has been extended with additional features including hashing, hypervector representations, and random Fourier features (RFF).
8
+
9
+ ## Environment Setup
10
+
11
+ The project uses conda for environment management:
12
+ ```bash
13
+ conda env create -f environment.yml
14
+ conda activate c4c
15
+ ```
16
+
17
+ Key dependencies: PyTorch 1.7.1, CUDA 11.0, opencv-python, tqdm, ftfy, regex, pandas, boto3
18
+
19
+ ## Development Commands
20
+
21
+ ### Training Commands
22
+ - **Standard training**: `./train.sh` - Trains CLIP4Clip with hashing extensions
23
+ - **Distributed training**: Uses `torch.distributed.launch` with multiple GPUs
24
+ - **Key training script**: `cache_main_task_retrieval.py` (main training entry point)
25
+ - **Original training script**: `main_task_retrieval.py` (reference implementation)
26
+
27
+ ### Testing/Evaluation Commands
28
+ - **Standard evaluation**: `./test.sh` - Evaluates with RFF features
29
+ - **Hash evaluation**: `./test_hash.sh` - Evaluates hashing-based models
30
+ - **Hypervector evaluation**: `./test_hv.sh` - Evaluates with high-dimensional RFF
31
+ - All test scripts use `cache_main_task_retrieval.py --do_eval`
32
+
33
+ ### Data Preprocessing
34
+ ```bash
35
+ python preprocess/compress_video.py --input_root [raw_video_path] --output_root [compressed_video_path]
36
+ ```
37
+
38
+ ### CLIP Model Downloads
39
+ ```bash
40
+ # ViT-B/32 (default)
41
+ wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
42
+
43
+ # ViT-B/16 (better performance)
44
+ wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
45
+ ```
46
+
47
+ ## Architecture Overview
48
+
49
+ ### Core Components
50
+ - **Main Models**:
51
+ - `modules/modeling.py` - CLIP4Clip main model class
52
+ - `modules/module_clip.py` - CLIP model implementation
53
+ - `modules/module_cross.py` - Cross-modal interaction modules
54
+
55
+ - **Extended Features**:
56
+ - `modules/hashnet.py` - Hash-based representations
57
+ - `modules/hypervector.py` - Hypervector/random projection methods
58
+ - `modules/binarize_ste.py` - Straight-through estimator for binarization
59
+ - `modules/minmax_hash.py` - MinMax hashing implementation
60
+
61
+ ### Data Loading
62
+ - **Dataloaders**: `dataloaders/data_dataloaders.py` provides unified interface
63
+ - **Supported datasets**: MSR-VTT, MSVD, LSMDC, ActivityNet, DiDeMo
64
+ - **Key parameters**: `--datatype`, `--features_path`, `--max_frames`, `--max_words`
65
+
66
+ ### Training Configuration
67
+ - **Similarity headers**: `meanP` (parameter-free), `seqLSTM`, `seqTransf`, `tightTransf`
68
+ - **Linear patch**: `2d` or `3d` patch projection
69
+ - **Frame sampling**: `--slice_framepos` (0=head, 1=tail, 2=uniform)
70
+ - **Model variants**: Standard CLIP4Clip, with hashing (`--use_clip4hashing`), with RFF (`--use_rff`)
71
+
72
+ ### Model Checkpoints
73
+ - Saved in `ckpts/` directory with model versions
74
+ - Format: `pytorch_model.bin.{epoch}` and `pytorch_opt.bin.{epoch}`
75
+ - Cache files for evaluation: `*_eval_cache.pt`, `*_eval_cache_trained.pt`
76
+
77
+ ## Key Parameters for Development
78
+
79
+ ### Essential Arguments
80
+ - `--datatype`: Dataset type (msrvtt, msvd, lsmdc, activity, didemo)
81
+ - `--features_path`: Path to video features/files
82
+ - `--output_dir`: Checkpoint output directory
83
+ - `--sim_header`: Similarity calculation method
84
+ - `--pretrained_clip_name`: CLIP model version (ViT-B/32 or ViT-B/16)
85
+
86
+ ### Extended Features
87
+ - `--use_clip4hashing --hash_bit 2048`: Enable hashing with specified bit size
88
+ - `--use_rff --rff_dim 3000`: Enable Random Fourier Features with dimension
89
+ - `--freeze_layer_num`: Number of CLIP layers to freeze (0-12)
90
+
91
+ ### Training Parameters
92
+ - Distributed training uses `torch.distributed.launch`
93
+ - Typical batch sizes: 128-512 for training, 16 for validation
94
+ - Learning rates: 1e-4 (base), 1e-3 (coef_lr for CLIP parameters)
95
+ - Standard training: 5-50 epochs depending on dataset
96
+
97
+ ## Additional Development Information
98
+
99
+ ### Environment and Dependencies
100
+ ```bash
101
+ # Create conda environment
102
+ conda env create -f environment.yml
103
+ conda activate c4c
104
+
105
+ # Install additional dependencies if needed
106
+ pip install ftfy regex tqdm opencv-python boto3 requests pandas
107
+ ```
108
+
109
+ ### Quick Start Commands
110
+ - **Setup environment**: `conda env create -f environment.yml && conda activate c4c`
111
+ - **Single GPU training**: `python cache_main_task_retrieval.py --do_train [args]`
112
+ - **Multi-GPU training**: `python -m torch.distributed.launch --nproc_per_node=2 cache_main_task_retrieval.py --do_train [args]`
113
+ - **Evaluation only**: `python cache_main_task_retrieval.py --do_eval [args]`
114
+
115
+ ### Debug and Development
116
+ - **Check GPU availability**: Use `torch.cuda.is_available()` in Python
117
+ - **Monitor training**: Logs are saved in output_dir with format `log.txt`
118
+ - **Resume training**: Use `--resume_model` with `--init_model` to continue from checkpoint
119
+ - **Video preprocessing**: Optional compression with `preprocess/compress_video.py` for 3fps/224px
120
+
121
+ ### Performance Tips
122
+ - **Batch sizes**: Start with 128 for training, 16 for validation; adjust based on GPU memory
123
+ - **Frame sampling**: `--slice_framepos 2` (uniform) generally works best
124
+ - **CLIP model**: ViT-B/16 provides better performance than ViT-B/32 but requires more memory
125
+ - **Frozen layers**: `--freeze_layer_num 12` freezes all CLIP layers for faster training
126
+
127
+ ### Data Structure Expectations
128
+ - **Video files**: Raw videos or compressed (3fps, 224px recommended)
129
+ - **CSV format**: Training/validation splits follow MSRVTT format
130
+ - **JSON data**: Captions and metadata in standard video-text retrieval format
131
+ - **Features path**: Directory containing video files, organized by dataset type
132
+
133
+ ### Common Issues and Solutions
134
+ - **CUDA out of memory**: Reduce batch_size or max_frames/max_words
135
+ - **Slow data loading**: Increase num_thread_reader (but not too high)
136
+ - **Poor performance**: Check pretrained CLIP weights are loaded correctly
137
+ - **Resume failures**: Ensure both pytorch_model.bin and pytorch_opt.bin exist for the epoch
CLIP4Clip.png ADDED

Git LFS Details

  • SHA256: 99ba99d88aaed58324e108334d67841dd4d36a6f339955080988e6c4f3306bf2
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 ArrowLuo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval
2
+
3
+ (**July 28, 2021**) Add ViT-B/16 with an extra `--pretrained_clip_name`
4
+
5
+ (**Apr. 22, 2021**) First version
6
+
7
+ The implementation of paper [**CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval**](https://arxiv.org/abs/2104.08860).
8
+
9
+ CLIP4Clip is a video-text retrieval model based on [CLIP (ViT-B)](https://github.com/openai/CLIP). We investigate three similarity calculation approaches: parameter-free type, sequential type, and tight type, in this work. The model achieve SOTA results on MSR-VTT, MSVD, LSMDC, ActivityNet, and DiDeMo.
10
+
11
+ ![CLIP4Clip](CLIP4Clip.png)
12
+
13
+ ## Requirement
14
+ ```sh
15
+ # From CLIP
16
+ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
17
+ pip install ftfy regex tqdm
18
+ pip install opencv-python boto3 requests pandas
19
+ ```
20
+
21
+ ## Data Preparing
22
+
23
+ **For MSRVTT**
24
+
25
+ The official data and video links can be found in [link](http://ms-multimedia-challenge.com/2017/dataset).
26
+
27
+ For the convenience, you can also download the splits and captions by,
28
+ ```sh
29
+ wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msrvtt_data.zip
30
+ ```
31
+
32
+ Besides, the raw videos can be found in [sharing](https://github.com/m-bain/frozen-in-time#-finetuning-benchmarks-msr-vtt) from *Frozen️ in Time*, i.e.,
33
+ ```sh
34
+ wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip
35
+ ```
36
+
37
+ **For MSVD**
38
+
39
+ Raw videos can be download from [link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/).
40
+
41
+ The splits and `raw_captions` can be found in the wonderful job [collaborative-experts](https://github.com/albanie/collaborative-experts/blob/master/misc/datasets/msvd/README.md). For the convenience, you can also download them by,
42
+ ```sh
43
+ wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msvd_data.zip
44
+ ```
45
+
46
+ **For LSMDC**
47
+
48
+ You must obtain permission from MPII to download and use the data. The download link is [here](https://sites.google.com/site/describingmovies/download).
49
+ The 1000 test clips data is [link](http://www.google.com/url?q=http%3A%2F%2Fdatasets.d2.mpi-inf.mpg.de%2FmovieDescription%2Fprotected%2Flsmdc2016%2FLSMDC16_challenge_1000_publictect.csv&sa=D&sntz=1&usg=AFQjCNGIaGVhCeb6zNfUs2UL1zNzoEtaSg). Read our paper and the [dataloader](./dataloaders/dataloader_lsmdc_retrieval.py) for more information.
50
+
51
+ **For ActivityNet**
52
+
53
+ The official websit has made the full dataset available on Google and Baidu drives, see more information at [here](http://activity-net.org/download.html) . The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/activity-net).
54
+
55
+ **For DiDeMo**
56
+
57
+ Raw videos can be download from [LisaAnne/LocalizingMoments](https://github.com/LisaAnne/LocalizingMoments). The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/didemo/README.md).
58
+
59
+
60
+ ## Compress Video for Speed-up (optional)
61
+ ```sh
62
+ python preprocess/compress_video.py --input_root [raw_video_path] --output_root [compressed_video_path]
63
+ ```
64
+ This script will compress the video to *3fps* with width *224* (or height *224*). Modify the variables for your customization.
65
+
66
+ ## How to Run
67
+
68
+ >`--features_path` is the video root path
69
+ >
70
+ >`--linear_patch` can be set with `2d` or `3d`
71
+ >
72
+ > `--sim_header` can be set with `meanP`, `seqLSTM`, `seqTransf`, or `tightTransf`
73
+ >
74
+ > `--pretrained_clip_name` can be set with `ViT-B/32` or `ViT-B/16`
75
+ >
76
+ > `--resume_model` can be used to reload the saved optimizer state to continuely train the model, **Note**: need to set the corresponding chechpoint via `--init_model` simultaneously.
77
+
78
+ read our paper for more details on `--linear_patch` and `--sim_header`. Test more hyperparameters for better performance.
79
+
80
+ Download CLIP (ViT-B/32) weight,
81
+ ```sh
82
+ wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
83
+ ```
84
+ or, download CLIP (ViT-B/16) weight,
85
+ ```sh
86
+ wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
87
+ ```
88
+
89
+ Then, run
90
+
91
+
92
+ *The CLIP (ViT-B/32) is the default setting in the paper, replacing with the ViT-B/16 for better performance.*
93
+
94
+ ### MSRVTT
95
+
96
+ ```sh
97
+ DATA_PATH=[Your MSRVTT data and videos path]
98
+ python -m torch.distributed.launch --nproc_per_node=4 \
99
+ main_task_retrieval.py --do_train --num_thread_reader=0 \
100
+ --epochs=5 --batch_size=128 --n_display=50 \
101
+ --train_csv ${DATA_PATH}/MSRVTT_train.9k.csv \
102
+ --val_csv ${DATA_PATH}/MSRVTT_JSFUSION_test.csv \
103
+ --data_path ${DATA_PATH}/MSRVTT_data.json \
104
+ --features_path ${DATA_PATH}/MSRVTT_Videos \
105
+ --output_dir ckpts/ckpt_msrvtt_retrieval_looseType \
106
+ --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
107
+ --datatype msrvtt --expand_msrvtt_sentences \
108
+ --feature_framerate 1 --coef_lr 1e-3 \
109
+ --freeze_layer_num 0 --slice_framepos 2 \
110
+ --loose_type --linear_patch 2d --sim_header meanP \
111
+ --pretrained_clip_name ViT-B/32
112
+ ```
113
+
114
+ ### MSVD
115
+ ```sh
116
+ DATA_PATH=[Your MSVD data and videos path]
117
+ python -m torch.distributed.launch --nproc_per_node=4 \
118
+ main_task_retrieval.py --do_train --num_thread_reader=2 \
119
+ --epochs=5 --batch_size=128 --n_display=50 \
120
+ --data_path ${DATA_PATH} \
121
+ --features_path ${DATA_PATH}/MSVD_Videos \
122
+ --output_dir ckpts/ckpt_msvd_retrieval_looseType \
123
+ --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
124
+ --datatype msvd \
125
+ --feature_framerate 1 --coef_lr 1e-3 \
126
+ --freeze_layer_num 0 --slice_framepos 2 \
127
+ --loose_type --linear_patch 2d --sim_header meanP \
128
+ --pretrained_clip_name ViT-B/32
129
+ ```
130
+
131
+ ### LSMDC
132
+ ```sh
133
+ DATA_PATH=[Your LSMDC data and videos path]
134
+ python -m torch.distributed.launch --nproc_per_node=4 \
135
+ main_task_retrieval.py --do_train --num_thread_reader=2 \
136
+ --epochs=5 --batch_size=128 --n_display=50 \
137
+ --data_path ${DATA_PATH} \
138
+ --features_path ${DATA_PATH}/LSMDC_Videos \
139
+ --output_dir ckpts/ckpt_lsmdc_retrieval_looseType \
140
+ --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \
141
+ --datatype lsmdc --feature_framerate 1 --coef_lr 1e-3 \
142
+ --freeze_layer_num 0 --slice_framepos 2 \
143
+ --loose_type --linear_patch 2d --sim_header meanP \
144
+ --pretrained_clip_name ViT-B/32
145
+ ```
146
+
147
+ ### ActivityNet
148
+ ActivityNet is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node).
149
+ ```sh
150
+ DATA_PATH=[Your ActivityNet data and videos path]
151
+ python -m torch.distributed.launch --nproc_per_node=8 \
152
+ main_task_retrieval.py --do_train --num_thread_reader=2 \
153
+ --epochs=5 --batch_size=128 --n_display=50 \
154
+ --data_path ${DATA_PATH} \
155
+ --features_path ${DATA_PATH}/Activity_Videos \
156
+ --output_dir ckpts/ckpt_activity_retrieval_looseType \
157
+ --lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \
158
+ --datatype activity --feature_framerate 1 --coef_lr 1e-3 \
159
+ --freeze_layer_num 0 --slice_framepos 2 \
160
+ --loose_type --linear_patch 2d --sim_header meanP \
161
+ --pretrained_clip_name ViT-B/32
162
+ ```
163
+
164
+ ### DiDeMo
165
+ DiDeMo is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node).
166
+ ```sh
167
+ DATA_PATH=[Your DiDeMo data and videos path]
168
+ python -m torch.distributed.launch --nproc_per_node=8 \
169
+ main_task_retrieval.py --do_train --num_thread_reader=2 \
170
+ --epochs=5 --batch_size=128 --n_display=50 \
171
+ --data_path ${DATA_PATH} \
172
+ --features_path ${DATA_PATH}/DiDeMo_Videos \
173
+ --output_dir ckpts/ckpt_didemo_retrieval_looseType \
174
+ --lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \
175
+ --datatype didemo --feature_framerate 1 --coef_lr 1e-3 \
176
+ --freeze_layer_num 0 --slice_framepos 2 \
177
+ --loose_type --linear_patch 2d --sim_header meanP \
178
+ --pretrained_clip_name ViT-B/32
179
+ ```
180
+
181
+ # Citation
182
+ If you find CLIP4Clip useful in your work, you can cite the following paper:
183
+ ```bibtex
184
+ @Article{Luo2021CLIP4Clip,
185
+ author = {Huaishao Luo and Lei Ji and Ming Zhong and Yang Chen and Wen Lei and Nan Duan and Tianrui Li},
186
+ title = {{CLIP4Clip}: An Empirical Study of CLIP for End to End Video Clip Retrieval},
187
+ journal = {arXiv preprint arXiv:2104.08860},
188
+ year = {2021},
189
+ }
190
+ ```
191
+
192
+ # Acknowledgments
193
+ Our code is based on [CLIP](https://github.com/openai/CLIP) and [UniVL](https://github.com/microsoft/UniVL).
__pycache__/metrics.cpython-312.pyc ADDED
Binary file (4.53 kB). View file
 
__pycache__/metrics.cpython-37.pyc ADDED
Binary file (2.51 kB). View file
 
__pycache__/metrics.cpython-39.pyc ADDED
Binary file (2.54 kB). View file
 
__pycache__/simple_dataloaders.cpython-37.pyc ADDED
Binary file (1.44 kB). View file
 
__pycache__/util.cpython-312.pyc ADDED
Binary file (4.37 kB). View file
 
__pycache__/util.cpython-37.pyc ADDED
Binary file (2.36 kB). View file
 
cache_main_task_retrieval.py ADDED
@@ -0,0 +1,1053 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import unicode_literals
4
+ from __future__ import print_function
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import random
10
+ import os
11
+ from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
12
+ import time
13
+ import argparse
14
+ from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
15
+ from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
16
+ from modules.modeling import CLIP4Clip
17
+ import matplotlib.pyplot as plt
18
+ from modules.optimization import BertAdam
19
+
20
+ from util import parallel_apply, get_logger
21
+ from modules.until_module import AllGather
22
+ from dataloaders.data_dataloaders import DATALOADER_DICT
23
+
24
+ # torch.distributed.init_process_group(backend="nccl")
25
+
26
+ global logger
27
+
28
+ def get_args(description='CLIP4Clip on Retrieval Task'):
29
+ parser = argparse.ArgumentParser(description=description)
30
+ parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
31
+ parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
32
+ parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
33
+
34
+ parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='')
35
+ parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='')
36
+ parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
37
+ parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')
38
+
39
+ parser.add_argument('--num_thread_reader', type=int, default=1, help='')
40
+ parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
41
+ parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
42
+ parser.add_argument('--batch_size', type=int, default=256, help='batch size')
43
+ parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
44
+ parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
45
+ parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
46
+ parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
47
+ parser.add_argument('--seed', type=int, default=42, help='random seed')
48
+ parser.add_argument('--max_words', type=int, default=20, help='')
49
+ parser.add_argument('--max_frames', type=int, default=100, help='')
50
+ parser.add_argument('--feature_framerate', type=int, default=1, help='')
51
+ parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
52
+ parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
53
+ parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
54
+ parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
55
+
56
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
57
+ help="The output directory where the model predictions and checkpoints will be written.")
58
+ parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
59
+ parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
60
+ parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.")
61
+ parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
62
+ parser.add_argument("--warmup_proportion", default=0.1, type=float,
63
+ help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
64
+ parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
65
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
66
+ parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
67
+
68
+ parser.add_argument("--cache_dir", default="", type=str,
69
+ help="Where do you want to store the pre-trained models downloaded from s3")
70
+
71
+ parser.add_argument('--fp16', action='store_true',
72
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
73
+ parser.add_argument('--fp16_opt_level', type=str, default='O1',
74
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
75
+ "See details at https://nvidia.github.io/apex/amp.html")
76
+
77
+ parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
78
+ parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")
79
+
80
+ parser.add_argument("--world_size", default=0, type=int, help="distribted training")
81
+ parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
82
+ # alias for torch.distributed.run / launch passing --local-rank
83
+ parser.add_argument("--local-rank", dest="local_rank", default=0, type=int, help="alias for local_rank")
84
+ parser.add_argument("--rank", default=0, type=int, help="distribted training")
85
+ parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
86
+ parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
87
+ parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")
88
+
89
+ parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
90
+ parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
91
+ parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")
92
+
93
+ parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.")
94
+ parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")
95
+
96
+ parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
97
+ help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
98
+ parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
99
+ help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
100
+
101
+ parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
102
+ parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
103
+ help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
104
+ parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
105
+ help="linear projection of flattened patches.")
106
+ parser.add_argument('--sim_header', type=str, default="meanP",
107
+ choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"],
108
+ help="choice a similarity header.")
109
+
110
+ parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
111
+ parser.add_argument("--use_rff", action='store_true', help="Use RFF hypervector encoding for video embeddings")
112
+ parser.add_argument("--rff_dim", type=int, default=3000, help="Hypervector dimension for RFF encoding")
113
+ parser.add_argument("--use_clip4hashing", action="store_true", help="CLIP4Hashing 손실·해시 경로 사용 여부")
114
+ parser.add_argument("--hash_bit", type=int, default=2048, help="해시 코드 비트 수 (default 1024)")
115
+ # Projection options
116
+ parser.add_argument('--proj', type=int, default=0, help='Projection dim (0 to disable, e.g., 3008)')
117
+ parser.add_argument('--proj_act', type=str, default='tanh', choices=['tanh', 'relu', 'gelu', 'sigmoid'],
118
+ help='Activation after projection')
119
+ parser.add_argument('--binary_eval', action='store_true', help='Use binarized retrieval at eval (sign + sum)')
120
+
121
+ args = parser.parse_args()
122
+
123
+ if args.sim_header == "tightTransf":
124
+ args.loose_type = False
125
+
126
+ # Check paramenters
127
+ if args.gradient_accumulation_steps < 1:
128
+ raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
129
+ args.gradient_accumulation_steps))
130
+ if not args.do_train and not args.do_eval:
131
+ raise ValueError("At least one of `do_train` or `do_eval` must be True.")
132
+
133
+ args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
134
+
135
+ # Accept env fallback if provided by torchrun
136
+ if 'LOCAL_RANK' in os.environ:
137
+ try:
138
+ args.local_rank = int(os.environ['LOCAL_RANK'])
139
+ except Exception:
140
+ pass
141
+ if 'RANK' in os.environ:
142
+ try:
143
+ args.rank = int(os.environ['RANK'])
144
+ except Exception:
145
+ pass
146
+ if 'WORLD_SIZE' in os.environ:
147
+ try:
148
+ args.world_size = int(os.environ['WORLD_SIZE'])
149
+ except Exception:
150
+ pass
151
+
152
+ return args
153
+
154
def set_seed_logger(args):
    """Seed every RNG for reproducibility, record distributed info, create the run logger.

    Args:
        args: parsed argparse.Namespace; uses `seed`, `local_rank`, `output_dir`,
            and optionally `binary_eval` / `proj`.

    Returns:
        The same `args`, with `world_size` and `rank` refreshed from
        torch.distributed when a process group is initialized.

    Raises:
        ValueError: if binary eval is requested with a projection dim that is
            not divisible by 64 (required for 64-bit word packing).
    """
    global logger
    # Predefine random initial seeds across python / numpy / torch (CPU and all GPUs).
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    # Trade cudnn autotune speed for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Only query torch.distributed when a process group actually exists; the
    # previous unconditional get_world_size() crashed single-process runs.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info("  <<< {}: {}".format(key, args.__dict__[key]))

    # Sanity check for binary eval + bit packing compatibility
    if getattr(args, 'binary_eval', False) and getattr(args, 'proj', 0) > 0:
        if args.proj % 64 != 0:
            raise ValueError(f"--proj must be divisible by 64 for binary eval, got {args.proj}")

    return args
188
+
189
def init_device(args, local_rank):
    """Select the compute device for this rank and record the GPU count on args.

    Returns:
        (device, n_gpu): the torch.device for this process and the number of
        visible CUDA devices.

    Raises:
        ValueError: when either batch size does not divide evenly by n_gpu.
    """
    global logger

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu", local_rank)

    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    # Both the train and eval batch sizes must split evenly across GPUs.
    evenly_divisible = (args.batch_size % args.n_gpu == 0) and (args.batch_size_val % args.n_gpu == 0)
    if not evenly_divisible:
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, n_gpu
203
+
204
def init_model(args, device, n_gpu, local_rank):
    """Build the CLIP4Clip model, optionally attach a projection head, and move it to device.

    Args:
        args: run configuration (uses `init_model`, `cross_model`, `cache_dir`,
            `proj`, `proj_act`).
        device: target torch.device.
        n_gpu: number of visible GPUs (unused here; kept for interface compatibility).
        local_rank: local process rank (unused here; kept for interface compatibility).

    Returns:
        The constructed model placed on `device`.
    """
    # Optional warm-start checkpoint; may also carry proj_head.* weights.
    if args.init_model:
        model_state_dict = torch.load(args.init_model, map_location='cpu')
    else:
        model_state_dict = None

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    # Attach projection head if requested (before DDP wrapping)
    if getattr(args, 'proj', 0) and args.proj > 0:
        # Register projection layer on the model (512 = CLIP ViT-B/32 embedding dim).
        proj = torch.nn.Linear(512, args.proj, bias=False)
        torch.nn.init.normal_(proj.weight, mean=0.0, std=0.02)
        # Activation selected by name; unknown names fall back to Tanh,
        # matching the original if/elif chain's else branch.
        activations = {
            'tanh': torch.nn.Tanh,
            'relu': torch.nn.ReLU,
            'gelu': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
        }
        act = activations.get(args.proj_act, torch.nn.Tanh)()
        model.proj_head = proj
        model.proj_activation = act

        # If init_model contains proj_head params, load them now. Report the
        # result instead of silently swallowing failures (the old bare
        # `except: pass` hid real checkpoint mismatches).
        if model_state_dict is not None:
            try:
                missing, unexpected = model.load_state_dict(model_state_dict, strict=False)
                if missing or unexpected:
                    logger.info("proj load: %d missing / %d unexpected keys",
                                len(missing), len(unexpected))
            except Exception as exc:
                logger.warning("proj_head weight load failed: %s", exc)

    model.to(device)

    return model
245
+
246
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
    """Build the BertAdam optimizer with CLIP/non-CLIP parameter groups, then wrap in DDP.

    CLIP-backbone parameters ("clip." in the name) get `args.lr * coef_lr`;
    everything else gets `args.lr`. Bias/LayerNorm parameters get zero weight decay.

    Args:
        args: run configuration (uses `lr`, `warmup_proportion`, `rank`).
        model: the model (unwrapped from DDP if necessary).
        num_train_optimization_steps: total optimizer steps, for the LR schedule.
        device: unused here; kept for interface compatibility.
        n_gpu: unused here; kept for interface compatibility.
        local_rank: GPU index used for DDP placement.
        coef_lr: LR multiplier for the CLIP branch.

    Returns:
        (optimizer, scheduler, ddp_model); scheduler is always None because
        BertAdam schedules internally via `schedule='warmup_cosine'`.
    """
    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n]
    decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n]

    no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n]
    no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n]

    weight_decay = 0.2
    optimizer_grouped_parameters = [
        {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay},
        {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0}
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
                         t_total=num_train_optimization_steps, weight_decay=weight_decay,
                         max_grad_norm=1.0)

    # Debug dump right after building the optimizer: per-group LR and a few
    # sample parameter names from each group.
    param2name = {id(p): n for n, p in model.named_parameters() if p.requires_grad}
    for gi, g in enumerate(optimizer.param_groups):
        print(f"[group {gi}] lr={g['lr']:.2e}, params={len(g['params'])}")
        for p in g["params"][:8]:
            print("   ", param2name.get(id(p), "?"))

    # Ensure both ranks finish building the exact same model before DDP wrap.
    # Quick debug: log param tensor count per rank (the original had an
    # if/else here whose two branches printed the identical string).
    try:
        num_tensors = len(list(model.parameters()))
        print(f"[DDP-DEBUG] rank={args.rank} local_rank={local_rank} param_tensors={num_tensors}")
    except Exception:
        pass

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      output_device=local_rank, find_unused_parameters=False)

    return optimizer, scheduler, model
304
+
305
def save_model(epoch, args, model, optimizer, tr_loss, type_name=""):
    """Persist model weights and optimizer state for `epoch`.

    Saves `pytorch_model.bin.<epoch>` and `pytorch_opt.bin.<epoch>` (optionally
    tagged with `type_name`) into `args.output_dir` and returns the model path.
    """
    # Only save the model itself: unwrap a DDP-wrapped model first.
    net = model.module if hasattr(model, 'module') else model
    tag = "" if type_name == "" else type_name + "."
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}{}".format(tag, epoch))
    optimizer_state_file = os.path.join(args.output_dir, "pytorch_opt.bin.{}{}".format(tag, epoch))

    torch.save(net.state_dict(), output_model_file)
    checkpoint = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': tr_loss,
    }
    torch.save(checkpoint, optimizer_state_file)

    logger.info("Model saved to %s", output_model_file)
    logger.info("Optimizer saved to %s", optimizer_state_file)
    return output_model_file
321
+
322
def load_model(epoch, args, n_gpu, device, model_file=None):
    """Load a checkpoint written by `save_model` and rebuild the CLIP4Clip model.

    Args:
        epoch: epoch index used to derive the default checkpoint filename.
        args: run configuration (uses `output_dir`, `cross_model`, `cache_dir`,
            `proj`, `proj_act`, `local_rank`).
        n_gpu: unused here; kept for interface compatibility.
        device: target torch.device.
        model_file: explicit checkpoint path; defaults to
            `<output_dir>/pytorch_model.bin.<epoch>`.

    Returns:
        The restored model on `device`, or None when the checkpoint is missing.
    """
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    if not os.path.exists(model_file):
        # The original fell through to a log line that referenced `cache_dir`
        # before it was ever assigned (NameError). Report the miss and bail out.
        logger.warning("Checkpoint not found: %s", model_file)
        return None

    model_state_dict = torch.load(model_file, map_location='cpu')
    if args.local_rank == 0:
        logger.info("Model loaded from %s", model_file)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    # Attach projection head if needed and load any matching weights
    if getattr(args, 'proj', 0) and args.proj > 0:
        proj = torch.nn.Linear(512, args.proj, bias=False)
        torch.nn.init.normal_(proj.weight, mean=0.0, std=0.02)
        # Same name->activation mapping as init_model; unknown falls back to Tanh.
        activations = {
            'tanh': torch.nn.Tanh,
            'relu': torch.nn.ReLU,
            'gelu': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
        }
        act = activations.get(args.proj_act, torch.nn.Tanh)()
        model.proj_head = proj
        model.proj_activation = act
        try:
            model.load_state_dict(model_state_dict, strict=False)
        except Exception as exc:
            # Surface the failure instead of silently dropping it.
            logger.warning("Non-strict state_dict load failed: %s", exc)

    model.to(device)

    logger.info(f"모델을 로드합니다:{cache_dir}")
    return model
359
+
360
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    """Run one training epoch.

    When a projection head is attached (`args.proj > 0`), the loss is computed
    here from projected text/video embeddings instead of the model's built-in
    forward loss path. Returns (mean_loss_over_batches, updated_global_step).

    Args:
        epoch: zero-based epoch index (for logging only).
        args: run configuration (uses `n_display`, `proj`, `epochs`,
            `gradient_accumulation_steps`).
        model: possibly DDP-wrapped model; unwrapped into `net` below.
        train_dataloader: yields (input_ids, input_mask, segment_ids, video, video_mask).
        device: target torch.device.
        n_gpu: GPU count; controls batch placement and loss averaging.
        optimizer: optimizer with a BertAdam-style `get_lr()` method.
        scheduler: optional LR scheduler stepped once per optimizer step.
        global_step: running optimizer-step counter across epochs.
        local_rank: rank used to gate logging to one process.
    """
    global logger
    torch.cuda.empty_cache()

    # Unwrap DDP so we can reach custom attributes (proj_head, clip, loss_fct).
    net = model.module if hasattr(model, 'module') else model
    net.train()

    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering it-self
            batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask = batch

        # If projection head enabled, override loss path to use proj+tanh->sum embeddings
        if getattr(args, 'proj', 0) and args.proj > 0 and hasattr(net, 'proj_head'):
            # Forward backbone encoders only
            sequence_output, visual_output = net.get_sequence_visual_output(
                input_ids, segment_ids, input_mask, video, video_mask)

            # Gather across processes for global negatives
            # (AllGather is a project autograd op; presumably concatenates
            # along the batch dim across DDP ranks — confirm in its module.)
            sequence_output = AllGather.apply(sequence_output, args)
            visual_output = AllGather.apply(visual_output, args)
            video_mask_g = AllGather.apply(video_mask.view(-1, video_mask.shape[-1]), args)

            # Text: [B,1,512] -> [B,512] -> proj -> act -> L2
            txt = sequence_output.squeeze(1)
            txt = net.proj_activation(net.proj_head(txt))
            txt = F.normalize(txt, dim=-1)

            # Video: [B,T,512] -> proj -> act -> mask -> sum(T) -> L2
            vid = net.proj_activation(net.proj_head(visual_output))
            vm = video_mask_g.to(dtype=vid.dtype).unsqueeze(-1)
            vid = (vid * vm).sum(dim=1)
            vid = F.normalize(vid, dim=-1)

            # CLIP's learned temperature scales the cosine-similarity logits.
            logit_scale = net.clip.logit_scale.exp()
            sim = logit_scale * torch.matmul(txt, vid.t())

            # Same symmetric CE loss as original
            # (net.loss_fct applied to sim and sim.T — t2v + v2t directions).
            loss = net.loss_fct(sim) + net.loss_fct(sim.T)
            loss = loss * 0.5
        else:
            # Default path: the model computes its own retrieval loss.
            loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            # Scale so accumulated gradients average over the micro-batches.
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            # https://github.com/openai/CLIP/issues/46
            # Clamp the temperature so logits cannot exceed 100x scaling.
            torch.clamp_(net.clip.logit_scale.data, max=np.log(100))

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss),
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    # NOTE(review): average is over dataloader length, i.e. micro-batches,
    # not optimizer steps.
    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
442
+
443
+ def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
444
+
445
+ sim_matrix = []
446
+ for idx1, b1 in enumerate(batch_list_t):
447
+ input_mask, segment_ids, *_tmp = b1
448
+ sequence_output = batch_sequence_output_list[idx1]
449
+ each_row = []
450
+ for idx2, b2 in enumerate(batch_list_v):
451
+ video_mask, *_tmp = b2
452
+ visual_output = batch_visual_output_list[idx2]
453
+ b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
454
+ loose_type=model.loose_type)
455
+ b1b2_logits = b1b2_logits.cpu().detach().numpy()
456
+ each_row.append(b1b2_logits)
457
+ each_row = np.concatenate(tuple(each_row), axis=-1)
458
+ sim_matrix.append(each_row)
459
+ return sim_matrix
460
+
461
+ def eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer):
462
+ import numpy as np
463
+ import os
464
+ import torch
465
+ import matplotlib.pyplot as plt
466
+
467
+ def _decode_query(tokenizer, ids_tensor):
468
+ # ids_tensor: 1D tensor (token ids)
469
+ try:
470
+ if isinstance(ids_tensor, torch.Tensor):
471
+ ids = ids_tensor.cpu().numpy().tolist()
472
+ else:
473
+ ids = ids_tensor.tolist() if hasattr(ids_tensor, 'tolist') else list(ids_tensor)
474
+
475
+ # ClipTokenizer의 특수 토큰 ID들
476
+ start_token_id = tokenizer.encoder.get('<|startoftext|>', 49406) # 기본값
477
+ end_token_id = tokenizer.encoder.get('<|endoftext|>', 49407) # 기본값
478
+
479
+ # 패딩 토큰(0)과 특수 토큰 제거
480
+ clean_ids = []
481
+ for token_id in ids:
482
+ if token_id > 0 and token_id != start_token_id and token_id != end_token_id:
483
+ clean_ids.append(token_id)
484
+
485
+ if not clean_ids:
486
+ return "<empty_query>"
487
+
488
+ # 유효하지 않은 토큰 ID 필터링 (vocab 범위 내)
489
+ vocab_size = len(tokenizer.decoder)
490
+ valid_ids = [tid for tid in clean_ids if tid < vocab_size]
491
+
492
+ if not valid_ids:
493
+ return "<invalid_tokens>"
494
+
495
+ # 디코딩 시도
496
+ try:
497
+ decoded_text = tokenizer.decode(valid_ids)
498
+ return decoded_text.strip()
499
+ except KeyError as e:
500
+ # 개별 토큰별로 디코딩 시도
501
+ decoded_tokens = []
502
+ for tid in valid_ids:
503
+ if tid in tokenizer.decoder:
504
+ decoded_tokens.append(tokenizer.decoder[tid])
505
+ else:
506
+ decoded_tokens.append(f"<unk_{tid}>")
507
+ text = ''.join(decoded_tokens)
508
+ # BPE 후처리
509
+ text = text.replace('</w>', ' ').strip()
510
+ return text if text else "<decode_partial_error>"
511
+ except Exception as e:
512
+ return f"<decode_error: {str(e)[:50]}>"
513
+
514
+ except Exception as e:
515
+ return f"<general_error: {str(e)[:50]}>"
516
+
517
+ def _get_video_ids_from_dataset(dataset, num_videos):
518
+ # 다양한 후보 속성명 시도 → 없으면 0..N-1 인덱스 문자열로 대체
519
+ for attr in ["video_list", "video_ids", "video_names", "videos", "vids", "id_list"]:
520
+ if hasattr(dataset, attr):
521
+ obj = getattr(dataset, attr)
522
+ try:
523
+ if isinstance(obj, (list, tuple)) and len(obj) == num_videos:
524
+ return list(map(str, obj))
525
+ except Exception:
526
+ pass
527
+ return [str(i) for i in range(num_videos)]
528
+
529
+ global logger
530
+
531
+ multi_sentence_ = False
532
+ cut_off_points_, sentence_num_, video_num_ = [], -1, -1
533
+
534
+ if hasattr(model, 'module'):
535
+ model = model.module.to(device)
536
+ else:
537
+ model = model.to(device)
538
+
539
+ if hasattr(model, 'module'):
540
+ model.module.eval()
541
+ else:
542
+ model.eval()
543
+
544
+ logger.info("Model %s", "training" if model.training else "eval")
545
+
546
+ # suffix for cache/result naming
547
+ suffix = "_hash" if getattr(args, "use_clip4hashing", False) else ""
548
+ suffix += "_rff" if args.use_rff else ""
549
+ suffix += f"_proj{args.proj}" if getattr(args, 'proj', 0) and args.proj > 0 else ""
550
+ suffix += "_binary" if getattr(args, 'binary_eval', False) else ""
551
+ suffix += "_trained" if args.init_model else ""
552
+
553
+ # (A) 캐시 로드/생성
554
+ if "train" in args.val_csv and "10k" in args.val_csv:
555
+ cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
556
+ logger.info(f"9k 훈련 데이터 캐시 생성: {cache_name}")
557
+ else:
558
+ cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
559
+ logger.info(f"평가 데이터 캐시: {cache_name}")
560
+
561
+ cache_path = os.path.join(args.output_dir, cache_name)
562
+
563
+ loaded_from_cache = False
564
+ if os.path.exists(cache_path):
565
+ logger.info(f"캐시된 피처를 로드합니다: {cache_path}")
566
+ cache = torch.load(cache_path, map_location=device)
567
+ batch_sequence_output_list = cache['batch_sequence_output_list']
568
+ batch_visual_output_list = cache['batch_visual_output_list']
569
+ batch_list_t = cache['batch_list_t']
570
+ batch_list_v = cache['batch_list_v']
571
+ text_input_ids_list = cache.get('text_input_ids_list', None)
572
+ video_ids = cache.get('video_ids', None)
573
+ loaded_from_cache = True
574
+
575
+ logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
576
+ f"각 텐서 shape={batch_sequence_output_list[0].shape}")
577
+ logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
578
+ f"각 텐서 shape={batch_visual_output_list[0].shape}")
579
+ else:
580
+ print("Caching feature..")
581
+ if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
582
+ test_dataloader.dataset.multi_sentence_per_video:
583
+ multi_sentence_ = True
584
+ cut_off_points_ = test_dataloader.dataset.cut_off_points
585
+ sentence_num_ = test_dataloader.dataset.sentence_num
586
+ video_num_ = test_dataloader.dataset.video_num
587
+ cut_off_points_ = [itm - 1 for itm in cut_off_points_]
588
+ logger.warning("Eval under multi-sentence-per-video. sentence num: %s, video num: %s",
589
+ sentence_num_, video_num_)
590
+
591
+ with torch.no_grad():
592
+ batch_list_t = []
593
+ batch_list_v = []
594
+ batch_sequence_output_list, batch_visual_output_list = [], []
595
+ text_input_ids_list = []
596
+ total_video_num = 0
597
+
598
+ for bid, batch in enumerate(test_dataloader):
599
+ batch = tuple(t.to(device) for t in batch)
600
+ input_ids, input_mask, segment_ids, video, video_mask = batch
601
+
602
+ if multi_sentence_:
603
+ b, *_t = video.shape
604
+ sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
605
+ batch_sequence_output_list.append(sequence_output)
606
+ # input_ids를 함께 보관 (run_on_single_gpu는 *_tmp로 무시하므로 안전)
607
+ batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
608
+
609
+ s_, e_ = total_video_num, total_video_num + b
610
+ filter_inds = [itm - s_ for itm in cut_off_points_ if s_ <= itm < e_]
611
+ if len(filter_inds) > 0:
612
+ video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
613
+ visual_output = model.get_visual_output(video, video_mask)
614
+ batch_visual_output_list.append(visual_output)
615
+ batch_list_v.append((video_mask,))
616
+ total_video_num += b
617
+ else:
618
+ sequence_output, visual_output = model.get_sequence_visual_output(
619
+ input_ids, segment_ids, input_mask, video, video_mask)
620
+
621
+ batch_sequence_output_list.append(sequence_output)
622
+ batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
623
+
624
+ batch_visual_output_list.append(visual_output)
625
+ batch_list_v.append((video_mask,))
626
+
627
+ print("{}/{}\r".format(bid, len(test_dataloader)), end="")
628
+
629
+ # 비디오 ID 목록 구성 (데이터셋 노��� 없으면 0..N-1)
630
+ num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
631
+ video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
632
+
633
+ logger.info(f"추출된 피처를 캐시에 저장합니다: {cache_path}")
634
+ torch.save({
635
+ 'batch_sequence_output_list': batch_sequence_output_list,
636
+ 'batch_visual_output_list': batch_visual_output_list,
637
+ 'batch_list_t': batch_list_t,
638
+ 'batch_list_v': batch_list_v,
639
+ 'text_input_ids_list': text_input_ids_list,
640
+ 'video_ids': video_ids,
641
+ }, cache_path)
642
+
643
+ logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
644
+ f"각 텐서 shape={batch_sequence_output_list[0].shape}")
645
+ logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
646
+ f"각 텐서 shape={batch_visual_output_list[0].shape}")
647
+
648
+ # 캐시에 text_input_ids_list가 없으면, 한 번 더 훑어서 수집 (구버전 캐시 호환)
649
+ if loaded_from_cache and 'text_input_ids_list' not in cache:
650
+ logger.info("캐시에 text_input_ids_list가 없어 재수집합니다(호환성 경로).")
651
+ text_input_ids_list = []
652
+ with torch.no_grad():
653
+ for batch in test_dataloader:
654
+ input_ids = batch[0].detach().cpu()
655
+ text_input_ids_list.append(input_ids)
656
+ elif loaded_from_cache and text_input_ids_list is None:
657
+ # batch_list_t에서 input_ids 추출
658
+ logger.info("batch_list_t에서 input_ids를 추출합니다.")
659
+ text_input_ids_list = []
660
+ for input_mask, segment_ids, input_ids in batch_list_t:
661
+ text_input_ids_list.append(input_ids)
662
+
663
+ # video_ids가 없으면 만들어준다(구버전 캐시 호환)
664
+ if loaded_from_cache and cache.get('video_ids', None) is None:
665
+ num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
666
+ video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
667
+
668
+ # (B) 유사도 행렬 계산
669
+ def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
670
+ sim_matrix = []
671
+ use_proj = getattr(args, 'proj', 0) and args.proj > 0 and hasattr(model, 'proj_head')
672
+ use_binary = getattr(args, 'binary_eval', False) and use_proj
673
+ for idx1, b1 in enumerate(batch_list_t):
674
+ input_mask, segment_ids, *_tmp = b1
675
+ sequence_output = batch_sequence_output_list[idx1]
676
+
677
+ each_row = []
678
+ if use_proj:
679
+ # Text: [B,1,512] -> [B,512] -> proj -> act
680
+ t = sequence_output.squeeze(1)
681
+ t = model.proj_activation(model.proj_head(t))
682
+ if use_binary:
683
+ t_vec = torch.where(t > 0, torch.tensor(1.0, device=t.device, dtype=t.dtype),
684
+ torch.tensor(-1.0, device=t.device, dtype=t.dtype))
685
+ else:
686
+ t_vec = F.normalize(t, dim=-1)
687
+
688
+ for idx2, b2 in enumerate(batch_list_v):
689
+ video_mask, *_tmp = b2
690
+ visual_output = batch_visual_output_list[idx2]
691
+ v = model.proj_activation(model.proj_head(visual_output)) # [B, T, P]
692
+
693
+ # Robust mask handling: ensure vm is [B, T, 1]
694
+ vm = video_mask
695
+ # Squeeze stray singleton dims (e.g., [B, T, 1] -> [B, T])
696
+ while vm.dim() > 2 and vm.size(-1) == 1:
697
+ vm = vm.squeeze(-1)
698
+ # If still higher-rank, flatten to [B, -1] safely
699
+ if vm.dim() > 2:
700
+ vm = vm.view(vm.size(0), -1)
701
+ # If [T] vector sneaks in, expand batch
702
+ if vm.dim() == 1:
703
+ vm = vm.unsqueeze(0)
704
+ # Align time length to visual features
705
+ if vm.size(1) != v.size(1):
706
+ vm = vm[:, :v.size(1)]
707
+ vm = vm.to(dtype=v.dtype).unsqueeze(-1) # [B, T, 1]
708
+
709
+ # Masked temporal sum -> [B, P]
710
+ v = (v * vm).sum(dim=1)
711
+ # Squeeze any trailing singleton dim that might remain
712
+ if v.dim() > 2 and v.size(-1) == 1:
713
+ v = v.squeeze(-1)
714
+ if use_binary:
715
+ v_vec = torch.where(v > 0, torch.tensor(1.0, device=v.device, dtype=v.dtype),
716
+ torch.tensor(-1.0, device=v.device, dtype=v.dtype))
717
+ scores = torch.matmul(t_vec, v_vec.t())
718
+ else:
719
+ v_vec = F.normalize(v, dim=-1)
720
+ scores = torch.matmul(t_vec, v_vec.t())
721
+ each_row.append(scores.cpu().detach().numpy())
722
+ else:
723
+ for idx2, b2 in enumerate(batch_list_v):
724
+ video_mask, *_tmp = b2
725
+ visual_output = batch_visual_output_list[idx2]
726
+ b1b2_logits, *_tmp = model.get_similarity_logits(
727
+ sequence_output, visual_output, input_mask, video_mask, loose_type=model.loose_type)
728
+ b1b2_logits = b1b2_logits.cpu().detach().numpy()
729
+ each_row.append(b1b2_logits)
730
+
731
+ each_row = np.concatenate(tuple(each_row), axis=-1)
732
+ sim_matrix.append(each_row)
733
+ return sim_matrix
734
+
735
+ if n_gpu > 1:
736
+ device_ids = list(range(n_gpu))
737
+ batch_list_t_splits, batch_list_v_splits = [], []
738
+ batch_t_output_splits, batch_v_output_splits = [], []
739
+ bacth_len = len(batch_list_t)
740
+ split_len = (bacth_len + n_gpu - 1) // n_gpu
741
+ for dev_id in device_ids:
742
+ s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
743
+ if dev_id == 0:
744
+ batch_list_t_splits.append(batch_list_t[s_:e_]); batch_list_v_splits.append(batch_list_v)
745
+ batch_t_output_splits.append(batch_sequence_output_list[s_:e_]); batch_v_output_splits.append(batch_visual_output_list)
746
+ else:
747
+ devc = torch.device(f'cuda:{dev_id}')
748
+ devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
749
+ batch_list_t_splits.append(devc_batch_list)
750
+ devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
751
+ batch_list_v_splits.append(devc_batch_list)
752
+ devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
753
+ batch_t_output_splits.append(devc_batch_list)
754
+ devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
755
+ batch_v_output_splits.append(devc_batch_list)
756
+
757
+ parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
758
+ batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
759
+ parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
760
+ sim_matrix = []
761
+ for idx in range(len(parallel_outputs)):
762
+ sim_matrix += parallel_outputs[idx]
763
+ sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
764
+ else:
765
+ sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v,
766
+ batch_sequence_output_list, batch_visual_output_list)
767
+ sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
768
+
769
+ # (C) 멀티센텐스 처리 및 메트릭
770
+ if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
771
+ test_dataloader.dataset.multi_sentence_per_video:
772
+ multi_sentence_ = True
773
+
774
+ if multi_sentence_:
775
+ logger.info("before reshape, sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
776
+ sim_matrix_flat = sim_matrix.copy() # 쿼리별 Top-K 용 2D 보관
777
+
778
+ cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
779
+ max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)])
780
+ sim_matrix_new = []
781
+ for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_):
782
+ sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
783
+ np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0))
784
+ sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
785
+ logger.info("after reshape, sim matrix size: %d x %d x %d",
786
+ sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])
787
+
788
+ tv_metrics = tensor_text_to_video_metrics(sim_matrix)
789
+ vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
790
+ else:
791
+ logger.info("sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
792
+ # 히트맵 저장(샘플)
793
+ plt.figure(figsize=(8,6))
794
+ plt.imshow(sim_matrix[:100, :100], aspect='auto')
795
+ plt.title('Similarity Matrix Heatmap')
796
+ plt.xlabel('Video Index')
797
+ plt.ylabel('Text Index')
798
+ plt.tight_layout()
799
+ out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
800
+ plt.savefig(out_path); plt.close()
801
+ logger.info(f"Saved sim_matrix heatmap to {out_path}")
802
+
803
+ sim_matrix_flat = sim_matrix # 2D 그대로
804
+ tv_metrics = compute_metrics(sim_matrix)
805
+ vt_metrics = compute_metrics(sim_matrix.T)
806
+ logger.info('\t Length-T: %d, Length-V:%d', len(sim_matrix), len(sim_matrix[0]))
807
+
808
+ logger.info("Text-to-Video:")
809
+ logger.info('\t>>> R@1: %.1f - R@5: %.1f - R@10: %.1f - Median R: %.1f - Mean R: %.1f',
810
+ tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])
811
+ logger.info("Video-to-Text:")
812
+ logger.info('\t>>> V2T$R@1: %.1f - V2T$R@5: %.1f - V2T$R@10: %.1f - V2T$Median R: %.1f - V2T$Mean R: %.1f',
813
+ vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])
814
+
815
+ # (D) 쿼리 텍스트 복원 + Top-10 덤프
816
+ # text_input_ids_list: List[Tensor[B_i, L]]
817
+ all_queries = []
818
+ logger.info(f"text_input_ids_list 개수: {len(text_input_ids_list)}")
819
+
820
+ for batch_idx, ids_batch in enumerate(text_input_ids_list):
821
+ if ids_batch is None:
822
+ logger.warning(f"배치 {batch_idx}: ids_batch가 None입니다.")
823
+ continue
824
+
825
+ try:
826
+ ids_batch = ids_batch if isinstance(ids_batch, torch.Tensor) else torch.as_tensor(ids_batch)
827
+ logger.info(f"배치 {batch_idx}: shape={ids_batch.shape}")
828
+
829
+ for row_idx, row in enumerate(ids_batch):
830
+ decoded = _decode_query(tokenizer, row)
831
+ all_queries.append(decoded)
832
+ if batch_idx == 0 and row_idx < 3: # 첫 배치의 처음 3개만 샘플로 출력
833
+ logger.info(f"샘플 디코딩 결과 [{batch_idx}-{row_idx}]: '{decoded}'")
834
+
835
+ except Exception as e:
836
+ logger.error(f"배치 {batch_idx} 처리 중 오류: {str(e)}")
837
+ # 에러가 발생해도 계속 진행
838
+ continue
839
+
840
+ logger.info(f"총 {len(all_queries)}개의 쿼리가 디코딩되었습니다.")
841
+
842
+ # video_ids 길이 보정(안전)
843
+ num_videos = sim_matrix_flat.shape[1]
844
+ if 'video_ids' in locals():
845
+ if len(video_ids) != num_videos:
846
+ logger.warning("video_ids 길이(%d)와 비디오 수(%d)가 달라 index로 대체합니다.",
847
+ len(video_ids), num_videos)
848
+ video_ids = [str(i) for i in range(num_videos)]
849
+ else:
850
+ video_ids = [str(i) for i in range(num_videos)]
851
+
852
+ # 저장 파일
853
+ topk = 10
854
+ out_tsv = os.path.join(args.output_dir, f"t2v_top10{suffix}.tsv")
855
+ out_json = os.path.join(args.output_dir, f"t2v_top10{suffix}.json")
856
+
857
+ if args.local_rank == 0:
858
+ import json
859
+
860
+ # TSV 파일 저장
861
+ with open(out_tsv, "w", encoding="utf-8") as f:
862
+ f.write("query_idx\tquery\tvideo_rank\tvideo_id\tvideo_idx\tscore\n")
863
+ for qi, q in enumerate(all_queries):
864
+ scores = sim_matrix_flat[qi]
865
+ # 효율: argpartition 후 정렬
866
+ idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
867
+ idxs = idxs[np.argsort(-scores[idxs])]
868
+ for rank, vidx in enumerate(idxs, 1):
869
+ f.write(f"{qi}\t{q}\t{rank}\t{video_ids[vidx]}\t{int(vidx)}\t{float(scores[vidx]):.6f}\n")
870
+
871
+ # JSON 파일 저장 (구조화된 형태)
872
+ results_dict = {}
873
+ for qi, q in enumerate(all_queries):
874
+ scores = sim_matrix_flat[qi]
875
+ idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
876
+ idxs = idxs[np.argsort(-scores[idxs])]
877
+
878
+ results_dict[f"query_{qi+1}"] = {
879
+ "query_text": q,
880
+ "top_videos": []
881
+ }
882
+
883
+ for rank, vidx in enumerate(idxs, 1):
884
+ results_dict[f"query_{qi+1}"]["top_videos"].append({
885
+ "rank": rank,
886
+ "video_id": video_ids[vidx],
887
+ "video_idx": int(vidx),
888
+ "score": float(scores[vidx])
889
+ })
890
+
891
+ with open(out_json, "w", encoding="utf-8") as f:
892
+ json.dump(results_dict, f, ensure_ascii=False, indent=2)
893
+
894
+ logger.info("T2V Top-10 per query 저장 완료:")
895
+ logger.info(" TSV 파일: %s", out_tsv)
896
+ logger.info(" JSON 파일: %s", out_json)
897
+ logger.info("총 %d개 쿼리에 대한 top-10 결과가 저장되었습니다.", len(all_queries))
898
+
899
+ # # 로그에 모든 쿼리의 Top-10 결과 출력
900
+ # logger.info("=== Query-wise Top-10 Results (전체 %d개 쿼리) ===", len(all_queries))
901
+ # for qi in range(len(all_queries)):
902
+ # scores = sim_matrix_flat[qi]
903
+ # idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
904
+ # idxs = idxs[np.argsort(-scores[idxs])]
905
+
906
+ # logger.info(f"Query {qi+1}: \"{all_queries[qi]}\"")
907
+ # for rank, vidx in enumerate(idxs, 1):
908
+ # logger.info(f" Rank {rank}: video_id={video_ids[vidx]}, video_idx={vidx}, score={scores[vidx]:.6f}")
909
+ # logger.info("---")
910
+
911
+ # logger.info("=== 모든 쿼리 결과 출력 완료 ===")
912
+
913
+ return tv_metrics['R1']
914
+
915
+
916
def main():
    """Entry point: initialise distributed training, build the CLIP4Clip model
    and dataloaders, then train (``--do_train``) and/or evaluate (``--do_eval``).

    Side effects: sets the CUDA device for this rank, joins the NCCL process
    group, and writes checkpoints/logs under ``args.output_dir``.
    """
    global logger
    args = get_args()

    # torchrun exports LOCAL_RANK; prefer it over the --local_rank CLI default.
    if "LOCAL_RANK" in os.environ:
        try:
            args.local_rank = int(os.environ["LOCAL_RANK"])
        except Exception:
            pass
    torch.cuda.set_device(args.local_rank)

    from datetime import timedelta
    import torch.distributed as dist
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=timedelta(minutes=30)
    )

    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    assert args.task_type == "retrieval"
    model = init_model(args, device, n_gpu, args.local_rank)

    ## ####################################
    # freeze testing
    ## ####################################
    assert args.freeze_layer_num <= 12 and args.freeze_layer_num >= -1
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, param in model.clip.named_parameters():
            # top layers always need to train
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                continue    # need to train
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue    # need to train

            # BUGFIX: the original test was `name.find("conv2.")`, which is truthy
            # for -1 (substring absent) and falsy for 0 (substring at the start),
            # so under --linear_patch 3d virtually every parameter skipped the
            # freeze. Use an explicit containment check instead.
            if args.linear_patch == "3d" and name.find("conv2.") >= 0:
                continue    # 3d patch conv stays trainable
            else:
                # parameters below freeze_layer_num are frozen
                param.requires_grad = False

    ## ####################################
    # dataloader loading
    ## ####################################
    assert args.datatype in DATALOADER_DICT

    assert DATALOADER_DICT[args.datatype]["test"] is not None \
        or DATALOADER_DICT[args.datatype]["val"] is not None

    test_dataloader, test_length = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)

    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_dataloader, val_length = test_dataloader, test_length

    ## report validation results if the ["test"] is None
    if test_dataloader is None:
        test_dataloader, test_length = val_dataloader, val_length

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info(" Num examples = %d", test_length)
        logger.info(" Batch size = %d", args.batch_size_val)
        logger.info(" Num steps = %d", len(test_dataloader))
        logger.info("***** Running val *****")
        logger.info(" Num examples = %d", val_length)

    ## ####################################
    # train and eval
    ## ####################################
    if args.do_train:
        train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info(" Num examples = %d", train_length)
            logger.info(" Batch size = %d", args.batch_size)
            logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = "None"
        ## ##############################################################
        # resume optimizer state besides loss to continue train
        ## ##############################################################
        resumed_epoch = 0
        if args.resume_model:
            checkpoint = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            resumed_epoch = checkpoint['epoch'] + 1
            resumed_loss = checkpoint['loss']  # kept for checkpoint-format parity; not used below

        global_step = 0
        for epoch in range(resumed_epoch, args.epochs):
            train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step, local_rank=args.local_rank)

            # Only rank 0 saves checkpoints and runs the (expensive) evaluation.
            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)

                output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="")

                ## Run on val dataset, this process is *TIME-consuming*.
                # logger.info("Eval on val dataset")
                # R1 = eval_epoch(args, model, val_dataloader, device, n_gpu)

                R1 = eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
                if best_score <= R1:
                    best_score = R1
                    best_output_model_file = output_model_file
                logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))

        ## Uncomment if want to test on the best checkpoint
        # if args.local_rank == 0:
        #     model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        #     eval_epoch(args, model, test_dataloader, device, n_gpu)

    elif args.do_eval:
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)

if __name__ == "__main__":
    main()
cache_main_task_retrieval_backup.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch
import numpy as np
import random
import os
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
import time
import argparse
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
import matplotlib.pyplot as plt
from modules.optimization import BertAdam

from util import parallel_apply, get_logger
from dataloaders.data_dataloaders import DATALOADER_DICT

# NOTE: the NCCL process group is initialised eagerly at import time, so this
# script must be launched through torch.distributed (e.g. torchrun) with the
# usual rank/world-size environment variables already set.
torch.distributed.init_process_group(backend="nccl")

# Module-level logger handle, actually assigned inside set_seed_logger().
# (`global` at module scope is a no-op; kept to signal intent.)
global logger
26
def get_args(description='CLIP4Clip on Retrieval Task'):
    """Parse command-line options for the retrieval task.

    Also applies post-parse fix-ups: forces tight type for the tightTransf
    header, validates gradient accumulation and run-mode flags, and rescales
    batch_size by the accumulation steps.
    """
    parser = argparse.ArgumentParser(description=description)
    add = parser.add_argument

    # --- run modes ---
    add("--do_pretrain", action='store_true', help="Whether to run training.")
    add("--do_train", action='store_true', help="Whether to run training.")
    add("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    # --- data locations ---
    add('--train_csv', type=str, default='data/.train.csv', help='')
    add('--val_csv', type=str, default='data/.val.csv', help='')
    add('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
    add('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')

    # --- optimisation / loader hyper-parameters ---
    add('--num_thread_reader', type=int, default=1, help='')
    add('--lr', type=float, default=0.0001, help='initial learning rate')
    add('--epochs', type=int, default=20, help='upper epoch limit')
    add('--batch_size', type=int, default=256, help='batch size')
    add('--batch_size_val', type=int, default=3500, help='batch size eval')
    add('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    add('--n_display', type=int, default=100, help='Information display frequence')
    add('--video_dim', type=int, default=1024, help='video feature dimension')
    add('--seed', type=int, default=42, help='random seed')
    add('--max_words', type=int, default=20, help='')
    add('--max_frames', type=int, default=100, help='')
    add('--feature_framerate', type=int, default=1, help='')
    add('--margin', type=float, default=0.1, help='margin for loss')
    add('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    add('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    add('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    # --- checkpointing / model selection ---
    add("--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    add("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    add("--init_model", default=None, type=str, required=False, help="Initial model.")
    add("--resume_model", default=None, type=str, required=False, help="Resume train model.")
    add("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    add("--warmup_proportion", default=0.1, type=float,
        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    add('--gradient_accumulation_steps', type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    add('--n_gpu', type=int, default=1, help="Changed in the execute process.")

    add("--cache_dir", default="", type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")

    # --- mixed precision ---
    add('--fp16', action='store_true',
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    add('--fp16_opt_level', type=str, default='O1',
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
             "See details at https://nvidia.github.io/apex/amp.html")

    # --- task / dataset ---
    add("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
    add("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")

    # --- distributed training bookkeeping ---
    add("--world_size", default=0, type=int, help="distribted training")
    add("--local_rank", default=0, type=int, help="distribted training")
    add("--rank", default=0, type=int, help="distribted training")
    add('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    add('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
    add('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")

    # --- architecture knobs ---
    add('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    add('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
    add('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")

    add('--loose_type', action='store_true', help="Default using tight type for retrieval.")
    add('--expand_msrvtt_sentences', action='store_true', help="")

    add('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    add('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")

    add('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
    add('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
        help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
    add('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
        help="linear projection of flattened patches.")
    add('--sim_header', type=str, default="meanP",
        choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"],
        help="choice a similarity header.")

    add("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
    add("--use_rff", action='store_true', help="Use RFF hypervector encoding for video embeddings")
    add("--rff_dim", type=int, default=3000, help="Hypervector dimension for RFF encoding")
    add("--use_clip4hashing", action="store_true", help="CLIP4Hashing 손실·해시 경로 사용 여부")
    add("--hash_bit", type=int, default=2048, help="해시 코드 비트 수 (default 1024)")

    args = parser.parse_args()

    # tightTransf cannot run in loose mode.
    if args.sim_header == "tightTransf":
        args.loose_type = False

    # Validate run configuration before anything expensive happens.
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # The effective per-step batch shrinks when gradients are accumulated.
    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    return args
127
+
128
def set_seed_logger(args):
    """Seed every RNG for reproducibility, record distributed rank/world-size
    on ``args``, and initialise the module-level file logger.

    Requires the torch.distributed process group to be initialised already.
    Returns the (mutated) args namespace.
    """
    global logger

    # Deterministic behaviour across python / hashing / numpy / torch (CPU + all GPUs).
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Bind this process to its GPU and publish distributed info on args.
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.rank = torch.distributed.get_rank()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    # All ranks share one log file under the output directory.
    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info(" <<< {}: {}".format(key, args.__dict__[key]))

    return args
157
+
158
def init_device(args, local_rank):
    """Select this rank's compute device, record the GPU count on ``args``,
    and verify both batch sizes divide evenly across the available GPUs.

    Returns (device, n_gpu). Raises ValueError on an indivisible batch size.
    """
    global logger

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu", local_rank)

    args.n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, args.n_gpu))

    evenly_divisible = (args.batch_size % args.n_gpu == 0) and (args.batch_size_val % args.n_gpu == 0)
    if not evenly_divisible:
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, args.n_gpu
172
+
173
def init_model(args, device, n_gpu, local_rank):
    """Construct CLIP4Clip on ``device``, optionally warm-started from the
    checkpoint given by ``--init_model``. Returns the model instance."""
    state = torch.load(args.init_model, map_location='cpu') if args.init_model else None

    # Fall back to the shared pretrained-weights cache when no --cache_dir is given.
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir,
                                      state_dict=state, task_config=args)
    model.to(device)

    return model
187
+
188
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
    """Build a BertAdam optimizer over four parameter groups and wrap `model` in DDP.

    Groups are the cross product of (weight-decay vs. no-decay) and
    (CLIP backbone vs. non-CLIP parameters); the CLIP groups use
    `args.lr * coef_lr` so the pretrained backbone can learn more slowly.
    Returns (optimizer, scheduler, ddp_model); `scheduler` is always None here
    because BertAdam applies the 'warmup_cosine' schedule internally.
    """

    # Unwrap a previously DDP-wrapped model before grouping its parameters.
    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains any of these receive no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    # Split each decay bucket into CLIP backbone vs. the rest (heads, etc.).
    decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n]
    decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n]

    no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n]
    no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n]

    weight_decay = 0.2
    optimizer_grouped_parameters = [
        {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay},
        {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0}
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
                         t_total=num_train_optimization_steps, weight_decay=weight_decay,
                         max_grad_norm=1.0)

    # Debug dump: run right after the optimizer is created — print each group's
    # lr/size and a few member parameter names, mapped back via id().
    name2param = {n: p for n, p in model.named_parameters() if p.requires_grad}
    param2name = {id(p): n for n, p in name2param.items()}

    for gi, g in enumerate(optimizer.param_groups):
        print(f"[group {gi}] lr={g['lr']:.2e}, params={len(g['params'])}")
        # sample a handful of names per group
        for p in g["params"][:8]:
            print(" ", param2name.get(id(p), "?"))

    # find_unused_parameters=True: some branches (e.g. unused heads) may not
    # produce gradients every step.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      output_device=local_rank, find_unused_parameters=True)

    return optimizer, scheduler, model
233
+
234
def save_model(epoch, args, model, optimizer, tr_loss, type_name=""):
    """Persist the (unwrapped) model weights and the optimizer state for
    ``epoch`` under ``args.output_dir``; returns the model checkpoint path."""
    # Only save the model it-self, never the DDP wrapper.
    target = model.module if hasattr(model, 'module') else model

    tag = "" if type_name == "" else type_name + "."
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}{}".format(tag, epoch))
    optimizer_state_file = os.path.join(args.output_dir, "pytorch_opt.bin.{}{}".format(tag, epoch))

    torch.save(target.state_dict(), output_model_file)
    # Optimizer checkpoint also carries epoch + loss so training can resume.
    torch.save({
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': tr_loss,
    }, optimizer_state_file)

    logger.info("Model saved to %s", output_model_file)
    logger.info("Optimizer saved to %s", optimizer_state_file)
    return output_model_file
250
+
251
def load_model(epoch, args, n_gpu, device, model_file=None):
    """Load a CLIP4Clip checkpoint for ``epoch`` (or an explicit ``model_file``).

    Returns the model on ``device``, or None when the checkpoint file does not
    exist. BUGFIX: the original logged ``cache_dir`` after the if/else, which
    raised UnboundLocalError on the missing-file path where ``cache_dir`` was
    never assigned; the log line now lives in the success branch only.
    """
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
    if os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        if args.local_rank == 0:
            logger.info("Model loaded from %s", model_file)
        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
        logger.info(f"모델을 로드합니다:{cache_dir}")
    else:
        model = None

    return model
268
+
269
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    """Run one training epoch with gradient accumulation.

    Steps the optimizer every `gradient_accumulation_steps` batches, clips
    gradients to norm 1.0, clamps the CLIP temperature, and logs every
    `args.n_display` optimizer steps on rank 0.
    Returns (mean loss over dataloader length, updated global_step).
    """
    global logger
    torch.cuda.empty_cache()
    # Put the underlying module (not the DDP wrapper) into train mode.
    if hasattr(model, 'module'):
        model.module.train()
    else:
        model.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering it-self; only move to device here.
            batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask = batch
        loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            # Normalise so the accumulated gradient matches a full-batch step.
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step() # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            # Clamp CLIP's learnable temperature so exp(logit_scale) <= 100.
            # https://github.com/openai/CLIP/issues/46
            if hasattr(model, 'module'):
                torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))
            else:
                torch.clamp_(model.clip.logit_scale.data, max=np.log(100))

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                # Lr string joins the distinct per-group learning rates.
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss),
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
323
+
324
+ def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
325
+
326
+ sim_matrix = []
327
+ for idx1, b1 in enumerate(batch_list_t):
328
+ input_mask, segment_ids, *_tmp = b1
329
+ sequence_output = batch_sequence_output_list[idx1]
330
+ each_row = []
331
+ for idx2, b2 in enumerate(batch_list_v):
332
+ video_mask, *_tmp = b2
333
+ visual_output = batch_visual_output_list[idx2]
334
+ b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
335
+ loose_type=model.loose_type)
336
+ b1b2_logits = b1b2_logits.cpu().detach().numpy()
337
+ each_row.append(b1b2_logits)
338
+ each_row = np.concatenate(tuple(each_row), axis=-1)
339
+ sim_matrix.append(each_row)
340
+ return sim_matrix
341
+
342
+ def eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer):
343
+ import numpy as np
344
+ import os
345
+ import torch
346
+ import matplotlib.pyplot as plt
347
+
348
+ def _decode_query(tokenizer, ids_tensor):
349
+ # ids_tensor: 1D tensor (token ids)
350
+ try:
351
+ if isinstance(ids_tensor, torch.Tensor):
352
+ ids = ids_tensor.cpu().numpy().tolist()
353
+ else:
354
+ ids = ids_tensor.tolist() if hasattr(ids_tensor, 'tolist') else list(ids_tensor)
355
+
356
+ # ClipTokenizer의 특수 토큰 ID들
357
+ start_token_id = tokenizer.encoder.get('<|startoftext|>', 49406) # 기본값
358
+ end_token_id = tokenizer.encoder.get('<|endoftext|>', 49407) # 기본값
359
+
360
+ # 패딩 토큰(0)과 특수 토큰 제거
361
+ clean_ids = []
362
+ for token_id in ids:
363
+ if token_id > 0 and token_id != start_token_id and token_id != end_token_id:
364
+ clean_ids.append(token_id)
365
+
366
+ if not clean_ids:
367
+ return "<empty_query>"
368
+
369
+ # 유효하지 않은 토큰 ID 필터링 (vocab 범위 내)
370
+ vocab_size = len(tokenizer.decoder)
371
+ valid_ids = [tid for tid in clean_ids if tid < vocab_size]
372
+
373
+ if not valid_ids:
374
+ return "<invalid_tokens>"
375
+
376
+ # 디코딩 시도
377
+ try:
378
+ decoded_text = tokenizer.decode(valid_ids)
379
+ return decoded_text.strip()
380
+ except KeyError as e:
381
+ # 개별 토큰별로 디코딩 시도
382
+ decoded_tokens = []
383
+ for tid in valid_ids:
384
+ if tid in tokenizer.decoder:
385
+ decoded_tokens.append(tokenizer.decoder[tid])
386
+ else:
387
+ decoded_tokens.append(f"<unk_{tid}>")
388
+ text = ''.join(decoded_tokens)
389
+ # BPE 후처리
390
+ text = text.replace('</w>', ' ').strip()
391
+ return text if text else "<decode_partial_error>"
392
+ except Exception as e:
393
+ return f"<decode_error: {str(e)[:50]}>"
394
+
395
+ except Exception as e:
396
+ return f"<general_error: {str(e)[:50]}>"
397
+
398
+ def _get_video_ids_from_dataset(dataset, num_videos):
399
+ # 다양한 후보 속성명 시도 → 없으면 0..N-1 인덱스 문자열로 대체
400
+ for attr in ["video_list", "video_ids", "video_names", "videos", "vids", "id_list"]:
401
+ if hasattr(dataset, attr):
402
+ obj = getattr(dataset, attr)
403
+ try:
404
+ if isinstance(obj, (list, tuple)) and len(obj) == num_videos:
405
+ return list(map(str, obj))
406
+ except Exception:
407
+ pass
408
+ return [str(i) for i in range(num_videos)]
409
+
410
+ global logger
411
+
412
+ multi_sentence_ = False
413
+ cut_off_points_, sentence_num_, video_num_ = [], -1, -1
414
+
415
+ if hasattr(model, 'module'):
416
+ model = model.module.to(device)
417
+ else:
418
+ model = model.to(device)
419
+
420
+ if hasattr(model, 'module'):
421
+ model.module.eval()
422
+ else:
423
+ model.eval()
424
+
425
+ logger.info("Model %s", "training" if model.training else "eval")
426
+
427
+ # suffix for cache/result naming
428
+ suffix = "_hash" if getattr(args, "use_clip4hashing", False) else ""
429
+ suffix += "_rff" if args.use_rff else ""
430
+ suffix += "_trained" if args.init_model else ""
431
+
432
+ # (A) 캐시 로드/생성
433
+ if "train" in args.val_csv and "10k" in args.val_csv:
434
+ cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
435
+ logger.info(f"9k 훈련 데이터 캐시 생성: {cache_name}")
436
+ else:
437
+ cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
438
+ logger.info(f"평가 데이터 캐시: {cache_name}")
439
+
440
+ cache_path = os.path.join(args.output_dir, cache_name)
441
+
442
+ loaded_from_cache = False
443
+ if os.path.exists(cache_path):
444
+ logger.info(f"캐시된 피처를 로드합니다: {cache_path}")
445
+ cache = torch.load(cache_path, map_location=device)
446
+ batch_sequence_output_list = cache['batch_sequence_output_list']
447
+ batch_visual_output_list = cache['batch_visual_output_list']
448
+ batch_list_t = cache['batch_list_t']
449
+ batch_list_v = cache['batch_list_v']
450
+ text_input_ids_list = cache.get('text_input_ids_list', None)
451
+ video_ids = cache.get('video_ids', None)
452
+ loaded_from_cache = True
453
+
454
+ logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
455
+ f"각 텐서 shape={batch_sequence_output_list[0].shape}")
456
+ logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
457
+ f"각 텐서 shape={batch_visual_output_list[0].shape}")
458
+ else:
459
+ print("Caching feature..")
460
+ if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
461
+ test_dataloader.dataset.multi_sentence_per_video:
462
+ multi_sentence_ = True
463
+ cut_off_points_ = test_dataloader.dataset.cut_off_points
464
+ sentence_num_ = test_dataloader.dataset.sentence_num
465
+ video_num_ = test_dataloader.dataset.video_num
466
+ cut_off_points_ = [itm - 1 for itm in cut_off_points_]
467
+ logger.warning("Eval under multi-sentence-per-video. sentence num: %s, video num: %s",
468
+ sentence_num_, video_num_)
469
+
470
+ with torch.no_grad():
471
+ batch_list_t = []
472
+ batch_list_v = []
473
+ batch_sequence_output_list, batch_visual_output_list = [], []
474
+ text_input_ids_list = []
475
+ total_video_num = 0
476
+
477
+ for bid, batch in enumerate(test_dataloader):
478
+ batch = tuple(t.to(device) for t in batch)
479
+ input_ids, input_mask, segment_ids, video, video_mask = batch
480
+
481
+ if multi_sentence_:
482
+ b, *_t = video.shape
483
+ sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
484
+ batch_sequence_output_list.append(sequence_output)
485
+ # input_ids를 함께 보관 (run_on_single_gpu는 *_tmp로 무시하므로 안전)
486
+ batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
487
+
488
+ s_, e_ = total_video_num, total_video_num + b
489
+ filter_inds = [itm - s_ for itm in cut_off_points_ if s_ <= itm < e_]
490
+ if len(filter_inds) > 0:
491
+ video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
492
+ visual_output = model.get_visual_output(video, video_mask)
493
+ batch_visual_output_list.append(visual_output)
494
+ batch_list_v.append((video_mask,))
495
+ total_video_num += b
496
+ else:
497
+ sequence_output, visual_output = model.get_sequence_visual_output(
498
+ input_ids, segment_ids, input_mask, video, video_mask)
499
+
500
+ batch_sequence_output_list.append(sequence_output)
501
+ batch_list_t.append((input_mask, segment_ids, input_ids.detach().cpu()))
502
+
503
+ batch_visual_output_list.append(visual_output)
504
+ batch_list_v.append((video_mask,))
505
+
506
+ print("{}/{}\r".format(bid, len(test_dataloader)), end="")
507
+
508
+ # 비디오 ID 목록 구성 (데이터셋 노출 없으면 0..N-1)
509
+ num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
510
+ video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
511
+
512
+ logger.info(f"추출된 피처를 캐시에 저장합니다: {cache_path}")
513
+ torch.save({
514
+ 'batch_sequence_output_list': batch_sequence_output_list,
515
+ 'batch_visual_output_list': batch_visual_output_list,
516
+ 'batch_list_t': batch_list_t,
517
+ 'batch_list_v': batch_list_v,
518
+ 'text_input_ids_list': text_input_ids_list,
519
+ 'video_ids': video_ids,
520
+ }, cache_path)
521
+
522
+ logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
523
+ f"각 텐서 shape={batch_sequence_output_list[0].shape}")
524
+ logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
525
+ f"각 텐서 shape={batch_visual_output_list[0].shape}")
526
+
527
+ # 캐시에 text_input_ids_list가 없으면, 한 번 더 훑어서 수집 (구버전 캐시 호환)
528
+ if loaded_from_cache and 'text_input_ids_list' not in cache:
529
+ logger.info("캐시에 text_input_ids_list가 없어 재수집합니다(호환성 경로).")
530
+ text_input_ids_list = []
531
+ with torch.no_grad():
532
+ for batch in test_dataloader:
533
+ input_ids = batch[0].detach().cpu()
534
+ text_input_ids_list.append(input_ids)
535
+ elif loaded_from_cache and text_input_ids_list is None:
536
+ # batch_list_t에서 input_ids 추출
537
+ logger.info("batch_list_t에서 input_ids를 추출합니다.")
538
+ text_input_ids_list = []
539
+ for input_mask, segment_ids, input_ids in batch_list_t:
540
+ text_input_ids_list.append(input_ids)
541
+
542
+ # video_ids가 없으면 만들어준다(구버전 캐시 호환)
543
+ if loaded_from_cache and cache.get('video_ids', None) is None:
544
+ num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
545
+ video_ids = _get_video_ids_from_dataset(test_dataloader.dataset, num_videos)
546
+
547
+ # (B) 유사도 행렬 계산
548
+ def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
549
+ sim_matrix = []
550
+ for idx1, b1 in enumerate(batch_list_t):
551
+ input_mask, segment_ids, *_tmp = b1
552
+ sequence_output = batch_sequence_output_list[idx1]
553
+ each_row = []
554
+ for idx2, b2 in enumerate(batch_list_v):
555
+ video_mask, *_tmp = b2
556
+ visual_output = batch_visual_output_list[idx2]
557
+ b1b2_logits, *_tmp = model.get_similarity_logits(
558
+ sequence_output, visual_output, input_mask, video_mask, loose_type=model.loose_type)
559
+ b1b2_logits = b1b2_logits.cpu().detach().numpy()
560
+ each_row.append(b1b2_logits)
561
+ each_row = np.concatenate(tuple(each_row), axis=-1)
562
+ sim_matrix.append(each_row)
563
+ return sim_matrix
564
+
565
+ if n_gpu > 1:
566
+ device_ids = list(range(n_gpu))
567
+ batch_list_t_splits, batch_list_v_splits = [], []
568
+ batch_t_output_splits, batch_v_output_splits = [], []
569
+ bacth_len = len(batch_list_t)
570
+ split_len = (bacth_len + n_gpu - 1) // n_gpu
571
+ for dev_id in device_ids:
572
+ s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
573
+ if dev_id == 0:
574
+ batch_list_t_splits.append(batch_list_t[s_:e_]); batch_list_v_splits.append(batch_list_v)
575
+ batch_t_output_splits.append(batch_sequence_output_list[s_:e_]); batch_v_output_splits.append(batch_visual_output_list)
576
+ else:
577
+ devc = torch.device(f'cuda:{dev_id}')
578
+ devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
579
+ batch_list_t_splits.append(devc_batch_list)
580
+ devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
581
+ batch_list_v_splits.append(devc_batch_list)
582
+ devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
583
+ batch_t_output_splits.append(devc_batch_list)
584
+ devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
585
+ batch_v_output_splits.append(devc_batch_list)
586
+
587
+ parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
588
+ batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
589
+ parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
590
+ sim_matrix = []
591
+ for idx in range(len(parallel_outputs)):
592
+ sim_matrix += parallel_outputs[idx]
593
+ sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
594
+ else:
595
+ sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v,
596
+ batch_sequence_output_list, batch_visual_output_list)
597
+ sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
598
+
599
+ # (C) 멀티센텐스 처리 및 메트릭
600
+ if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and \
601
+ test_dataloader.dataset.multi_sentence_per_video:
602
+ multi_sentence_ = True
603
+
604
+ if multi_sentence_:
605
+ logger.info("before reshape, sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
606
+ sim_matrix_flat = sim_matrix.copy() # 쿼리별 Top-K 용 2D 보관
607
+
608
+ cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
609
+ max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)])
610
+ sim_matrix_new = []
611
+ for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_):
612
+ sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
613
+ np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0))
614
+ sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
615
+ logger.info("after reshape, sim matrix size: %d x %d x %d",
616
+ sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])
617
+
618
+ tv_metrics = tensor_text_to_video_metrics(sim_matrix)
619
+ vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
620
+ else:
621
+ logger.info("sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
622
+ # 히트맵 저장(샘플)
623
+ plt.figure(figsize=(8,6))
624
+ plt.imshow(sim_matrix[:100, :100], aspect='auto')
625
+ plt.title('Similarity Matrix Heatmap')
626
+ plt.xlabel('Video Index')
627
+ plt.ylabel('Text Index')
628
+ plt.tight_layout()
629
+ out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
630
+ plt.savefig(out_path); plt.close()
631
+ logger.info(f"Saved sim_matrix heatmap to {out_path}")
632
+
633
+ sim_matrix_flat = sim_matrix # 2D 그대로
634
+ tv_metrics = compute_metrics(sim_matrix)
635
+ vt_metrics = compute_metrics(sim_matrix.T)
636
+ logger.info('\t Length-T: %d, Length-V:%d', len(sim_matrix), len(sim_matrix[0]))
637
+
638
+ logger.info("Text-to-Video:")
639
+ logger.info('\t>>> R@1: %.1f - R@5: %.1f - R@10: %.1f - Median R: %.1f - Mean R: %.1f',
640
+ tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])
641
+ logger.info("Video-to-Text:")
642
+ logger.info('\t>>> V2T$R@1: %.1f - V2T$R@5: %.1f - V2T$R@10: %.1f - V2T$Median R: %.1f - V2T$Mean R: %.1f',
643
+ vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])
644
+
645
+ # (D) 쿼리 텍스트 복원 + Top-10 덤프
646
+ # text_input_ids_list: List[Tensor[B_i, L]]
647
+ all_queries = []
648
+ logger.info(f"text_input_ids_list 개수: {len(text_input_ids_list)}")
649
+
650
+ for batch_idx, ids_batch in enumerate(text_input_ids_list):
651
+ if ids_batch is None:
652
+ logger.warning(f"배치 {batch_idx}: ids_batch가 None입니다.")
653
+ continue
654
+
655
+ try:
656
+ ids_batch = ids_batch if isinstance(ids_batch, torch.Tensor) else torch.as_tensor(ids_batch)
657
+ logger.info(f"배치 {batch_idx}: shape={ids_batch.shape}")
658
+
659
+ for row_idx, row in enumerate(ids_batch):
660
+ decoded = _decode_query(tokenizer, row)
661
+ all_queries.append(decoded)
662
+ if batch_idx == 0 and row_idx < 3: # 첫 배치의 처음 3개만 샘플로 출력
663
+ logger.info(f"샘플 디코딩 결과 [{batch_idx}-{row_idx}]: '{decoded}'")
664
+
665
+ except Exception as e:
666
+ logger.error(f"배치 {batch_idx} 처리 중 오류: {str(e)}")
667
+ # 에러가 발생해도 계속 진행
668
+ continue
669
+
670
+ logger.info(f"총 {len(all_queries)}개의 쿼리가 디코딩되었습니다.")
671
+
672
+ # video_ids 길이 보정(안전)
673
+ num_videos = sim_matrix_flat.shape[1]
674
+ if 'video_ids' in locals():
675
+ if len(video_ids) != num_videos:
676
+ logger.warning("video_ids 길이(%d)와 비디오 수(%d)가 달라 index로 대체합니다.",
677
+ len(video_ids), num_videos)
678
+ video_ids = [str(i) for i in range(num_videos)]
679
+ else:
680
+ video_ids = [str(i) for i in range(num_videos)]
681
+
682
+ # 저장 파일
683
+ topk = 10
684
+ out_tsv = os.path.join(args.output_dir, f"t2v_top10{suffix}.tsv")
685
+ out_json = os.path.join(args.output_dir, f"t2v_top10{suffix}.json")
686
+
687
+ if args.local_rank == 0:
688
+ import json
689
+
690
+ # TSV 파일 저장
691
+ with open(out_tsv, "w", encoding="utf-8") as f:
692
+ f.write("query_idx\tquery\tvideo_rank\tvideo_id\tvideo_idx\tscore\n")
693
+ for qi, q in enumerate(all_queries):
694
+ scores = sim_matrix_flat[qi]
695
+ # 효율: argpartition 후 정렬
696
+ idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
697
+ idxs = idxs[np.argsort(-scores[idxs])]
698
+ for rank, vidx in enumerate(idxs, 1):
699
+ f.write(f"{qi}\t{q}\t{rank}\t{video_ids[vidx]}\t{int(vidx)}\t{float(scores[vidx]):.6f}\n")
700
+
701
+ # JSON 파일 저장 (구조화된 형태)
702
+ results_dict = {}
703
+ for qi, q in enumerate(all_queries):
704
+ scores = sim_matrix_flat[qi]
705
+ idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
706
+ idxs = idxs[np.argsort(-scores[idxs])]
707
+
708
+ results_dict[f"query_{qi+1}"] = {
709
+ "query_text": q,
710
+ "top_videos": []
711
+ }
712
+
713
+ for rank, vidx in enumerate(idxs, 1):
714
+ results_dict[f"query_{qi+1}"]["top_videos"].append({
715
+ "rank": rank,
716
+ "video_id": video_ids[vidx],
717
+ "video_idx": int(vidx),
718
+ "score": float(scores[vidx])
719
+ })
720
+
721
+ with open(out_json, "w", encoding="utf-8") as f:
722
+ json.dump(results_dict, f, ensure_ascii=False, indent=2)
723
+
724
+ logger.info("T2V Top-10 per query 저장 완료:")
725
+ logger.info(" TSV 파일: %s", out_tsv)
726
+ logger.info(" JSON 파일: %s", out_json)
727
+ logger.info("총 %d개 쿼리에 대한 top-10 결과가 저장되었습니다.", len(all_queries))
728
+
729
+ # 로그에 모든 쿼리의 Top-10 결과 출력
730
+ logger.info("=== Query-wise Top-10 Results (전체 %d개 쿼리) ===", len(all_queries))
731
+ for qi in range(len(all_queries)):
732
+ scores = sim_matrix_flat[qi]
733
+ idxs = np.argpartition(-scores, kth=min(topk, len(scores)-1))[:topk]
734
+ idxs = idxs[np.argsort(-scores[idxs])]
735
+
736
+ logger.info(f"Query {qi+1}: \"{all_queries[qi]}\"")
737
+ for rank, vidx in enumerate(idxs, 1):
738
+ logger.info(f" Rank {rank}: video_id={video_ids[vidx]}, video_idx={vidx}, score={scores[vidx]:.6f}")
739
+ logger.info("---")
740
+
741
+ logger.info("=== 모든 쿼리 결과 출력 완료 ===")
742
+
743
+ return tv_metrics['R1']
744
+
745
+
746
def main():
    """Entry point: set up model, tokenizer and dataloaders, then run
    retrieval training (``--do_train``) and/or evaluation (``--do_eval``).

    Reads all configuration from ``get_args()``; produces checkpoints, logs
    and evaluation artifacts under ``args.output_dir`` as side effects.
    """
    global logger
    args = get_args()
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    assert args.task_type == "retrieval"
    model = init_model(args, device, n_gpu, args.local_rank)

    ## ####################################
    # freeze testing
    ## ####################################
    assert args.freeze_layer_num <= 12 and args.freeze_layer_num >= -1
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        for name, param in model.clip.named_parameters():
            # top layers always need to train
            if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \
                    or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0:
                continue    # need to train
            elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0:
                layer_num = int(name.split(".resblocks.")[1].split(".")[0])
                if layer_num >= args.freeze_layer_num:
                    continue    # need to train

            # BUGFIX: the original condition was `name.find("conv2.")`, which is
            # truthy when the substring is ABSENT (str.find returns -1), so with
            # linear_patch == "3d" almost every parameter hit `continue` and
            # nothing was frozen. Keep conv2 parameters trainable only when the
            # name actually contains "conv2.".
            if args.linear_patch == "3d" and name.find("conv2.") >= 0:
                continue    # need to train
            else:
                # parameters below freeze_layer_num are frozen
                param.requires_grad = False

    ## ####################################
    # dataloader loading
    ## ####################################
    assert args.datatype in DATALOADER_DICT

    assert DATALOADER_DICT[args.datatype]["test"] is not None \
        or DATALOADER_DICT[args.datatype]["val"] is not None

    test_dataloader, test_length = None, 0
    if DATALOADER_DICT[args.datatype]["test"] is not None:
        test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)

    if DATALOADER_DICT[args.datatype]["val"] is not None:
        val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
    else:
        val_dataloader, val_length = test_dataloader, test_length

    ## report validation results if the ["test"] is None
    if test_dataloader is None:
        test_dataloader, test_length = val_dataloader, val_length

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", test_length)
        logger.info("  Batch size = %d", args.batch_size_val)
        logger.info("  Num steps = %d", len(test_dataloader))
        logger.info("***** Running val *****")
        logger.info("  Num examples = %d", val_length)

    ## ####################################
    # train and eval
    ## ####################################
    if args.do_train:
        train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        # BUGFIX: the original wrapped only the numerator in int() and then used
        # true division, which yields a float step count; use integer ceil
        # division so the optimizer/scheduler receive an int.
        num_train_optimization_steps = ((len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        // args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu,
                                                     args.local_rank, coef_lr=coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", train_length)
            logger.info("  Batch size = %d", args.batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = "None"
        ## ##############################################################
        # resume optimizer state besides loss to continue train
        ## ##############################################################
        resumed_epoch = 0
        if args.resume_model:
            checkpoint = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            resumed_epoch = checkpoint['epoch'] + 1
            resumed_loss = checkpoint['loss']  # kept for reference/debugging; not otherwise used

        global_step = 0
        for epoch in range(resumed_epoch, args.epochs):
            train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step, local_rank=args.local_rank)

            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)

                output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="")

                ## Run on val dataset, this process is *TIME-consuming*.
                # logger.info("Eval on val dataset")
                # R1 = eval_epoch(args, model, val_dataloader, device, n_gpu)

                R1 = eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
                if best_score <= R1:
                    best_score = R1
                    best_output_model_file = output_model_file
                logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))

        ## Uncomment if want to test on the best checkpoint
        # if args.local_rank == 0:
        #     model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        #     eval_epoch(args, model, test_dataloader, device, n_gpu)

    elif args.do_eval:
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)


if __name__ == "__main__":
    main()
ckpts/cache_train_9k/log.txt ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-08-12 02:24:18,172:INFO: Effective parameters:
2
+ 2025-08-12 02:24:18,172:INFO: <<< batch_size: 128
3
+ 2025-08-12 02:24:18,172:INFO: <<< batch_size_val: 16
4
+ 2025-08-12 02:24:18,172:INFO: <<< cache_dir:
5
+ 2025-08-12 02:24:18,172:INFO: <<< coef_lr: 0.001
6
+ 2025-08-12 02:24:18,172:INFO: <<< cross_model: cross-base
7
+ 2025-08-12 02:24:18,173:INFO: <<< cross_num_hidden_layers: 4
8
+ 2025-08-12 02:24:18,173:INFO: <<< data_path: /disk/gjw/msr-vtt/MSRVTT_data.json
9
+ 2025-08-12 02:24:18,173:INFO: <<< datatype: msrvtt
10
+ 2025-08-12 02:24:18,173:INFO: <<< do_eval: True
11
+ 2025-08-12 02:24:18,173:INFO: <<< do_lower_case: False
12
+ 2025-08-12 02:24:18,173:INFO: <<< do_pretrain: False
13
+ 2025-08-12 02:24:18,173:INFO: <<< do_train: False
14
+ 2025-08-12 02:24:18,173:INFO: <<< epochs: 5
15
+ 2025-08-12 02:24:18,173:INFO: <<< eval_frame_order: 0
16
+ 2025-08-12 02:24:18,173:INFO: <<< expand_msrvtt_sentences: True
17
+ 2025-08-12 02:24:18,173:INFO: <<< feature_framerate: 1
18
+ 2025-08-12 02:24:18,173:INFO: <<< features_path: /disk/gjw/msr-vtt/compressed_videos
19
+ 2025-08-12 02:24:18,173:INFO: <<< fp16: False
20
+ 2025-08-12 02:24:18,173:INFO: <<< fp16_opt_level: O1
21
+ 2025-08-12 02:24:18,173:INFO: <<< freeze_layer_num: 0
22
+ 2025-08-12 02:24:18,173:INFO: <<< gradient_accumulation_steps: 1
23
+ 2025-08-12 02:24:18,173:INFO: <<< hard_negative_rate: 0.5
24
+ 2025-08-12 02:24:18,173:INFO: <<< hash_bit: 2048
25
+ 2025-08-12 02:24:18,174:INFO: <<< init_model: ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0
26
+ 2025-08-12 02:24:18,174:INFO: <<< linear_patch: 2d
27
+ 2025-08-12 02:24:18,174:INFO: <<< local_rank: 0
28
+ 2025-08-12 02:24:18,174:INFO: <<< loose_type: True
29
+ 2025-08-12 02:24:18,174:INFO: <<< lr: 0.0001
30
+ 2025-08-12 02:24:18,174:INFO: <<< lr_decay: 0.9
31
+ 2025-08-12 02:24:18,174:INFO: <<< margin: 0.1
32
+ 2025-08-12 02:24:18,174:INFO: <<< max_frames: 12
33
+ 2025-08-12 02:24:18,174:INFO: <<< max_words: 32
34
+ 2025-08-12 02:24:18,174:INFO: <<< n_display: 50
35
+ 2025-08-12 02:24:18,174:INFO: <<< n_gpu: 1
36
+ 2025-08-12 02:24:18,174:INFO: <<< n_pair: 1
37
+ 2025-08-12 02:24:18,174:INFO: <<< negative_weighting: 1
38
+ 2025-08-12 02:24:18,174:INFO: <<< num_thread_reader: 0
39
+ 2025-08-12 02:24:18,174:INFO: <<< output_dir: ckpts/cache_train_9k
40
+ 2025-08-12 02:24:18,174:INFO: <<< pretrained_clip_name: ViT-B/32
41
+ 2025-08-12 02:24:18,174:INFO: <<< rank: 0
42
+ 2025-08-12 02:24:18,174:INFO: <<< resume_model: None
43
+ 2025-08-12 02:24:18,174:INFO: <<< rff_dim: 3000
44
+ 2025-08-12 02:24:18,175:INFO: <<< sampled_use_mil: False
45
+ 2025-08-12 02:24:18,175:INFO: <<< seed: 42
46
+ 2025-08-12 02:24:18,175:INFO: <<< sim_header: meanP
47
+ 2025-08-12 02:24:18,175:INFO: <<< slice_framepos: 2
48
+ 2025-08-12 02:24:18,175:INFO: <<< task_type: retrieval
49
+ 2025-08-12 02:24:18,175:INFO: <<< text_num_hidden_layers: 12
50
+ 2025-08-12 02:24:18,175:INFO: <<< train_csv: /disk/gjw/msr-vtt/MSRVTT_train.9k.csv
51
+ 2025-08-12 02:24:18,175:INFO: <<< train_frame_order: 0
52
+ 2025-08-12 02:24:18,175:INFO: <<< use_clip4hashing: False
53
+ 2025-08-12 02:24:18,175:INFO: <<< use_mil: False
54
+ 2025-08-12 02:24:18,175:INFO: <<< use_rff: False
55
+ 2025-08-12 02:24:18,175:INFO: <<< val_csv: /disk/gjw/msr-vtt/MSRVTT_JSFUSION_train_test_10k.csv
56
+ 2025-08-12 02:24:18,175:INFO: <<< video_dim: 1024
57
+ 2025-08-12 02:24:18,175:INFO: <<< visual_num_hidden_layers: 12
58
+ 2025-08-12 02:24:18,175:INFO: <<< warmup_proportion: 0.1
59
+ 2025-08-12 02:24:18,175:INFO: <<< world_size: 1
60
+ 2025-08-12 02:24:18,176:INFO: device: cuda:0 n_gpu: 1
61
+ 2025-08-12 02:24:19,207:INFO: loading archive file /disk/gjw/CLIP4Clip/modules/cross-base
62
+ 2025-08-12 02:24:19,207:INFO: Model config {
63
+ "attention_probs_dropout_prob": 0.1,
64
+ "hidden_act": "gelu",
65
+ "hidden_dropout_prob": 0.1,
66
+ "hidden_size": 512,
67
+ "initializer_range": 0.02,
68
+ "intermediate_size": 2048,
69
+ "max_position_embeddings": 128,
70
+ "num_attention_heads": 8,
71
+ "num_hidden_layers": 4,
72
+ "type_vocab_size": 2,
73
+ "vocab_size": 512
74
+ }
75
+
76
+ 2025-08-12 02:24:19,208:INFO: Weight doesn't exsits. /disk/gjw/CLIP4Clip/modules/cross-base/cross_pytorch_model.bin
77
+ 2025-08-12 02:24:19,208:WARNING: Stage-One:True, Stage-Two:False
78
+ 2025-08-12 02:24:19,208:WARNING: Test retrieval by loose type.
79
+ 2025-08-12 02:24:19,208:WARNING: embed_dim: 512
80
+ 2025-08-12 02:24:19,208:WARNING: image_resolution: 224
81
+ 2025-08-12 02:24:19,208:WARNING: vision_layers: 12
82
+ 2025-08-12 02:24:19,208:WARNING: vision_width: 768
83
+ 2025-08-12 02:24:19,208:WARNING: vision_patch_size: 32
84
+ 2025-08-12 02:24:19,208:WARNING: context_length: 77
85
+ 2025-08-12 02:24:19,209:WARNING: vocab_size: 49408
86
+ 2025-08-12 02:24:19,209:WARNING: transformer_width: 512
87
+ 2025-08-12 02:24:19,209:WARNING: transformer_heads: 8
88
+ 2025-08-12 02:24:19,209:WARNING: transformer_layers: 12
89
+ 2025-08-12 02:24:19,209:WARNING: linear_patch: 2d
90
+ 2025-08-12 02:24:19,209:WARNING: cut_top_layer: 0
91
+ 2025-08-12 02:24:22,820:WARNING: sim_header: meanP
92
+ 2025-08-12 02:24:31,116:INFO: --------------------
93
+ 2025-08-12 02:24:31,117:INFO: Weights from pretrained model not used in CLIP4Clip:
94
+ clip.input_resolution
95
+ clip.context_length
96
+ clip.vocab_size
97
+ 2025-08-12 02:24:32,963:INFO: ***** Running test *****
98
+ 2025-08-12 02:24:32,963:INFO: Num examples = 10000
99
+ 2025-08-12 02:24:32,963:INFO: Batch size = 16
100
+ 2025-08-12 02:24:32,963:INFO: Num steps = 625
101
+ 2025-08-12 02:24:32,963:INFO: ***** Running val *****
102
+ 2025-08-12 02:24:32,963:INFO: Num examples = 10000
103
+ 2025-08-12 02:24:32,966:INFO: Model testing
104
+ 2025-08-12 02:24:32,966:INFO: 9k 훈련 데이터 캐시 생성: msrvtt_train_test_10k_cache_trained.pt
105
+ 2025-08-12 02:24:32,966:INFO: suffix: _trained
106
+ 2025-08-12 05:49:06,718:INFO: 추출된 피처를 캐시에 저장합니다: ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt
107
+ 2025-08-12 05:49:07,252:INFO: [Cache] 텍스트 피쳐 개수=625각 텐서 shape=torch.Size([16, 1, 512])
108
+ 2025-08-12 05:49:07,253:INFO: [Cache] 비디오 피쳐 개수=625각 텐서 shape=torch.Size([16, 12, 512])
109
+ 2025-08-12 05:51:36,241:INFO: sim matrix size: 10000, 10000
110
+ 2025-08-12 05:51:36,355:INFO: Saved sim_matrix heatmap to ckpts/cache_train_9k/sim_matrix_heatmap.png
111
+ 2025-08-12 05:53:13,421:INFO: Length-T: 10000, Length-V:10000
112
+ 2025-08-12 05:53:13,421:INFO: Text-to-Video:
113
+ 2025-08-12 05:53:13,421:INFO: >>> R@1: 26.6 - R@5: 51.2 - R@10: 61.3 - Median R: 5.0 - Mean R: 127.3
114
+ 2025-08-12 05:53:13,421:INFO: Video-to-Text:
115
+ 2025-08-12 05:53:13,421:INFO: >>> V2T$R@1: 17.9 - V2T$R@5: 32.8 - V2T$R@10: 39.8 - V2T$Median R: 25.0 - V2T$Mean R: 199.8
ckpts/cache_train_9k/msrvtt_train_test_10k_cache_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:840d1fbfa39d4fda94294e01b7a2d7bdf409ab04154f9c38ebabc79dec4902e0
3
+ size 273271799
ckpts/cache_train_9k/sim_matrix_heatmap.png ADDED
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:326307d8fa7b94df2ae84ef2a983ff9cc7bc771011435d438834849340bbceff
3
+ size 27327894
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_binary_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f276f0ad05bc6827baef7bc199dbce104552d818e71b6003412360c423c5340
3
+ size 27623162
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/msrvtt_eval_cache_proj3008_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa0cb7b6f93768d7df200b4da4f666c7a326ef2019697d43deef4e7690df65ec
3
+ size 27620488
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_model.bin.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee6b844c4664b89f5fdb5575807a091b22337037716c396837826f694fef2e5e
3
+ size 359716242
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/pytorch_opt.bin.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71f4d99edafdbc5681c53834aa5fed86e53c92f87c97ef4201967cda0d04ee1b
3
+ size 494597414
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/sim_matrix_heatmap.png ADDED
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_binary_trained.tsv ADDED
@@ -0,0 +1 @@
 
 
1
+ query_idx query video_rank video_id video_idx score
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
ckpts/ckpt_hash_msrvtt_retrieval_looseType_1e-7/t2v_top10_proj3008_trained.tsv ADDED
@@ -0,0 +1 @@
 
 
1
+ query_idx query video_rank video_id video_idx score
ckpts/ckpt_msrvtt_retrieval_looseType/log.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6ce285755f9c657d33f5215e1796baf094cda32a8728a2189864baa413e5a4e
3
+ size 22286632
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42ff89065849459a75dcbdd4058362bf4071f47ff542b1d6a4f8db03389af43e
3
+ size 27327893
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_binary_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36ee838a7701808f023b70616e893119fa5417eaff0406989d6664d9d1ccc554
3
+ size 27623162
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_proj3008_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42c7b8bbca50170c1e2c4227b452028e744a6ff390b1b57107399baa11189dda
3
+ size 27620488
ckpts/ckpt_msrvtt_retrieval_looseType/msrvtt_eval_cache_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:450bd99bb45b5ed1855e21e0a8d153d9b61a653944076882b36f66042473686a
3
+ size 27327893
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:905df44d3f0b80981471701f4f7a2ec8430d57f9cedf38eec139a56a7b184858
3
+ size 353556437
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2ae61ac9ce4df7c3c99b825c8c8a7f20cc25337aae07e297312906075d94930
3
+ size 353556437
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20804da0ab5454d35c2ba99fe65f8e08bd1b86c4966fcf961f8e143d659b9066
3
+ size 353556437
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb87fe2412b22327ae764affbbdc056aab4a7191f541036698af0a2744af7219
3
+ size 353556437
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_model.bin.4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d7e2ba508986c7347380767233b017d4ac4c9fabc135f3457de8625dc833dca
3
+ size 353556437
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8795af40e7f2e6a9e166657b854f20aeb606843a2d8b67021155160f50c9de44
3
+ size 494600757
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89dfc0f35eb9cdabe12930b544434c30e1fcb84f8ac31004fe31a6c9c3747f3c
3
+ size 494600757
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93a48879a8b7b7a9c5772f9259de9dfdef7698b06e74811df6ff255737a257f0
3
+ size 494600757
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adfc35ef88500e828fa539e426a46b03dc1f989b5203636e41c228f4ed1be798
3
+ size 494600757
ckpts/ckpt_msrvtt_retrieval_looseType/pytorch_opt.bin.4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27c179a6d93b10941a88e691d2213a4c057c9a5d862ac409f26041dacd7fd85c
3
+ size 494600757
ckpts/ckpt_msrvtt_retrieval_looseType/sim_matrix_heatmap.png ADDED
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_binary_trained.tsv ADDED
@@ -0,0 +1 @@
 
 
1
+ query_idx query video_rank video_id video_idx score
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
ckpts/ckpt_msrvtt_retrieval_looseType/t2v_top10_proj3008_trained.tsv ADDED
@@ -0,0 +1 @@
 
 
1
+ query_idx query video_rank video_id video_idx score