qiushuocheng commited on
Commit
a39be45
·
verified ·
1 Parent(s): 1fd8ddd

Initial upload

Browse files
.gitignore CHANGED
@@ -1,6 +1,163 @@
1
- *.zip
2
- LaSA/dataset/
3
- LaSA/result/
4
- USDRL/data/
5
- USDRL/pretrained/
6
- Simple-Skeleton-Detection/fig/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset
2
+ data
3
+ dataset
4
+ work_dir
5
+
6
+ *.pkl
7
+ *.mp4
8
+
9
+ *.sh.o*
10
+
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.vscode/launch.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Python: Debug Current File (gvhmr)",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "program": "${file}",
9
+ "console": "integratedTerminal",
10
+ "cwd": "/root/autodl-tmp/workshop2/",
11
+ "args": [
12
+ "--config",
13
+ // "config/eval.yaml"
14
+ "configss/train_split3_full_MSFF_DTCL.yaml",
15
+ // "config/pretrain2.yaml"
16
+ "--work-dir",
17
+ "./work_dir/3_agcn_cl_all/",
18
+ "-model_saved_name",
19
+ "./work_dir/3_agcn_cl_all/",
20
+ "--weights",
21
+ "/root/autodl-tmp/RVTCLR/work_dir/ntu60_cs/skeletonclr_joint/U/cl/pretext_babel/epoch300_model.pt"
22
+ ]
23
+ }
24
+ ]
25
+ }
26
+ // CUBLAS_WORKSPACE_CONFIG=:4096:8 python train_full_SSL.py --config configss/train_split3_full_MSFF_DTCL.yaml \
27
+ // --work-dir ./work_dir/3_agcn_cl_all/ \
28
+ // -model_saved_name ./work_dir/3_agcn_cl_all/ \
29
+ // --weights /root/autodl-tmp/RVTCLR/work_dir/ntu60_cs/skeletonclr_joint/U/both/pretext_babel_dense_U/epoch300_model.pt
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
NOTICE.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Open Source Software Notice ###
2
+
3
+ This product from LINE Corporation contains the open source software or third-party software listed below, to which the terms in the LICENSE file in this repository do not apply.
4
+ Please refer to the licences of the respective software repositories for the terms and conditions of their use.
5
+
6
+ BABEL (Non-commercial License)
7
+ https://github.com/abhinanda-punnakkal/BABEL
8
+
9
+ 2s-AGCN (CC BY-NC 4.0 License)
10
+ https://github.com/lshiwjx/2s-AGCN
11
+
12
+ FAC-Net (MIT License)
13
+ https://github.com/LeonHLJ/FAC-Net
14
+
15
+ pytorch-classification(MIT License)
16
+ https://github.com/bearpaw/pytorch-classification
README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Skeleton-Temporal-Action-Localization
2
+
3
+ Code for the paper "Frame-Level Label Refinement for Skeleton-Based Weakly-Supervised Action Recognition" (AAAI 2023).
4
+
5
+ ## Overview
6
+
7
+ Architecture of Network
8
+
9
+ ![Architecture of Network](./images/framework.jpg)
10
+
11
+ ## Requirements
12
+ ```bash
13
+ conda create -n stal python=3.7
14
+ conda activate stal
15
+ conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 -c pytorch
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ## Data Preparation
20
+ Due to the distribution policy of AMASS dataset, we are not allowed to distribute the data directly. We provide a series of scripts that can reproduce our motion segmentation dataset from BABEL dataset.
21
+
22
+ Download [AMASS Dataset](https://amass.is.tue.mpg.de/) and [BABEL Dataset](https://babel.is.tue.mpg.de/). Unzip and locate them in the `dataset` folder.
23
+
24
+ Prepare the SMPLH Model following [this](https://github.com/vchoutas/smplx/blob/main/tools/README.md#smpl-h-version-used-in-amass) and put the merged model `SMPLH_male.pkl` into the `human_model` folder.
25
+
26
+ The whole directory should look like this:
27
+ ```
28
+ Skeleton-Temporal-Action-Localization
29
+ │ README.md
30
+ │ train.py
31
+ | ...
32
+ |
33
+ └───config
34
+ └───prepare
35
+ └───...
36
+
37
+ └───human_model
38
+ │ └───SMPLH_male.pkl
39
+
40
+ └───dataset
41
+ └───amass
42
+ | └───ACCAD
43
+ | └───BMLmovi
44
+ | └───...
45
+
46
+ └───babel_v1.0_release
47
+ └───train.json
48
+ └───val.json
49
+ └───...
50
+ ```
51
+
52
+ And also clone the BABEL official code into the `dataset` folder.
53
+
54
+ ```bash
55
+ git clone https://github.com/abhinanda-punnakkal/BABEL.git dataset/BABEL
56
+ ```
57
+
58
+ Finally, the motion segmentation dataset can be generated by:
59
+ ```bash
60
+ bash prepare/generate_dataset.sh
61
+ ```
62
+
63
+ ## Training and Evaluation
64
+ To train and evaluate the model with subset-1 of BABEL, run the following commands:
65
+ ```bash
66
+ python train.py --config config/train_split1.yaml
67
+ ```
68
+
69
+ ## Acknowledgement
70
+ Our codes are based on [BABEL](https://github.com/abhinanda-punnakkal/BABEL), [2s-AGCN](https://github.com/lshiwjx/2s-AGCN) and [FAC-Net](https://github.com/LeonHLJ/FAC-Net).
71
+
72
+
73
+ ## Citation
74
+
75
+ ```
76
+ @InProceedings{yu2023frame,
77
+ title={Frame-Level Label Refinement for Skeleton-Based Weakly-Supervised Action Recognition},
78
+ author={Yu, Qing and Fujiwara, Kent},
79
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
80
+ volume={37},
81
+ number={3},
82
+ pages={3322--3330},
83
+ year={2023}
84
+ }
85
+ ```
86
+
87
+ ## License
88
+ [Apache License 2.0](LICENSE)
89
+
90
+ Additionally, this repository contains third-party software. Refer to [NOTICE.txt](NOTICE.txt) for more details and follow the terms and conditions of their use.
configs/train_split1_full_MSFF_DTCL.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ work_dir: ./work_dir/1_w_MSFF_full/
2
+ model_saved_name: ./work_dir/1_w_MSFF_full/
3
+
4
+ # feeder
5
+ feeder: feeders.feeder.Feeder
6
+ train_feeder_args:
7
+ data_path: ./dataset/train_split1.pkl
8
+ debug: False
9
+ random_choose: False
10
+ random_shift: False
11
+ random_move: False
12
+ window_size: -1
13
+ nb_class: 4
14
+
15
+ test_feeder_args:
16
+ data_path: ./dataset/val_split1.pkl
17
+ nb_class: 4
18
+
19
+ # model
20
+ model: model.agcn_Unet.Model
21
+ model_args:
22
+ num_class: 4
23
+ num_person: 1
24
+ num_point: 25 # checked from 25
25
+ graph: graph.ntu_rgb_d.Graph
26
+ graph_args:
27
+ labeling_mode: 'spatial'
28
+
29
+ #optim
30
+ weight_decay: 0.001
31
+ base_lr: 0.0002
32
+ step: [60,80]
33
+
34
+ # training
35
+ device: [0]
36
+ optimizer: 'Adam'
37
+ loss: 'CE'
38
+ batch_size: 8
39
+ test_batch_size: 1
40
+ num_epoch: 101 #101
41
+ nesterov: True
42
+ lambda_mil: 1.0
43
+ weights: /home/newDisk/epoch300_model.pt
configs/train_split2_full_MSFF_DTCL.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ work_dir: ./work_dir/2_w_MSFF_full/
2
+ model_saved_name: ./work_dir/2_w_MSFF_full/
3
+
4
+ # feeder
5
+ feeder: feeders.feeder.Feeder
6
+ train_feeder_args:
7
+ data_path: ./dataset/train_split2.pkl
8
+ debug: False
9
+ random_choose: False
10
+ random_shift: False
11
+ random_move: False
12
+ window_size: -1
13
+ nb_class: 4
14
+
15
+ test_feeder_args:
16
+ data_path: ./dataset/val_split2.pkl
17
+ nb_class: 4
18
+
19
+ # model
20
+ model: model.agcn_Unet.Model
21
+ model_args:
22
+ num_class: 4
23
+ num_person: 1
24
+ num_point: 25 # checked from 25
25
+ graph: graph.ntu_rgb_d.Graph
26
+ graph_args:
27
+ labeling_mode: 'spatial'
28
+
29
+ #optim
30
+ weight_decay: 0.0001
31
+ base_lr: 0.001
32
+ step: [60,80]
33
+
34
+ # training
35
+ device: [0]
36
+ optimizer: 'Adam'
37
+ loss: 'CE'
38
+ batch_size: 8
39
+ test_batch_size: 1
40
+ num_epoch: 101 #101
41
+ nesterov: True
42
+ lambda_mil: 1.0
43
+ weights: /home/newDisk/epoch300_model.pt
configs/train_split3_full_MSFF_DTCL.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ work_dir: ./work_dir/3_w_MSFF_full/
2
+ model_saved_name: ./work_dir/3_w_MSFF_full/
3
+
4
+ # feeder
5
+ feeder: feeders.feeder.Feeder
6
+ train_feeder_args:
7
+ data_path: ./dataset/train_split3.pkl
8
+ debug: False
9
+ random_choose: False
10
+ random_shift: False
11
+ random_move: False
12
+ window_size: -1
13
+ nb_class: 4
14
+
15
+ test_feeder_args:
16
+ data_path: ./dataset/val_split3.pkl
17
+ nb_class: 4
18
+
19
+ # model
20
+ model: model.agcn_Unet.Model
21
+ model_args:
22
+ num_class: 4
23
+ num_person: 1
24
+ num_point: 25 # checked from 25
25
+ graph: graph.ntu_rgb_d.Graph
26
+ graph_args:
27
+ labeling_mode: 'spatial'
28
+
29
+ #optim
30
+ weight_decay: 0.0001
31
+ base_lr: 0.001
32
+ step: [60,80]
33
+
34
+ # training
35
+ device: [0]
36
+ optimizer: 'Adam'
37
+ loss: 'CE'
38
+ batch_size: 8
39
+ test_batch_size: 1
40
+ num_epoch: 101 #101
41
+ nesterov: True
42
+ lambda_mil: 1.0
43
+ weights: /home/newDisk/epoch300_model.pt
eval.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from __future__ import print_function
4
+
5
+ import argparse
6
+ import inspect
7
+ import os
8
+ import pdb
9
+ import pickle
10
+ import random
11
+ import re
12
+ import shutil
13
+ import time
14
+ from collections import *
15
+
16
+ import ipdb
17
+ import numpy as np
18
+
19
+ # torch
20
+ import torch
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torch.optim as optim
25
+ import yaml
26
+ from einops import rearrange, reduce, repeat
27
+ from evaluation.classificationMAP import getClassificationMAP as cmAP
28
+ from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
29
+ from feeders.tools import collate_with_padding_multi_joint
30
+ from model.losses import cross_entropy_loss, mvl_loss
31
+ from sklearn.metrics import f1_score
32
+
33
+ # Custom
34
+ from tensorboardX import SummaryWriter
35
+ from torch.autograd import Variable
36
+ from torch.optim.lr_scheduler import _LRScheduler
37
+ from tqdm import tqdm
38
+ from utils.logger import Logger
39
+
40
+
41
+ def init_seed(seed):
42
+ torch.cuda.manual_seed_all(seed)
43
+ torch.manual_seed(seed)
44
+ np.random.seed(seed)
45
+ random.seed(seed)
46
+ torch.backends.cudnn.deterministic = True
47
+ torch.backends.cudnn.benchmark = False
48
+
49
+
50
+ def get_parser():
51
+ # parameter priority: command line > config > default
52
+ parser = argparse.ArgumentParser(
53
+ description="Spatial Temporal Graph Convolution Network"
54
+ )
55
+ parser.add_argument(
56
+ "--work-dir",
57
+ default="./work_dir/temp",
58
+ help="the work folder for storing results",
59
+ )
60
+
61
+ parser.add_argument("-model_saved_name", default="")
62
+ parser.add_argument(
63
+ "--config",
64
+ default="./config/nturgbd-cross-view/test_bone.yaml",
65
+ help="path to the configuration file",
66
+ )
67
+
68
+ # processor
69
+ parser.add_argument("--phase", default="test", help="must be train or test")
70
+
71
+ # visulize and debug
72
+ parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
73
+ parser.add_argument(
74
+ "--log-interval",
75
+ type=int,
76
+ default=100,
77
+ help="the interval for printing messages (#iteration)",
78
+ )
79
+ parser.add_argument(
80
+ "--save-interval",
81
+ type=int,
82
+ default=2,
83
+ help="the interval for storing models (#iteration)",
84
+ )
85
+ parser.add_argument(
86
+ "--eval-interval",
87
+ type=int,
88
+ default=5,
89
+ help="the interval for evaluating models (#iteration)",
90
+ )
91
+ parser.add_argument(
92
+ "--print-log", type=str2bool, default=True, help="print logging or not"
93
+ )
94
+ parser.add_argument(
95
+ "--show-topk",
96
+ type=int,
97
+ default=[1, 5],
98
+ nargs="+",
99
+ help="which Top K accuracy will be shown",
100
+ )
101
+
102
+ # feeder
103
+ parser.add_argument(
104
+ "--feeder", default="feeder.feeder", help="data loader will be used"
105
+ )
106
+ parser.add_argument(
107
+ "--num-worker",
108
+ type=int,
109
+ default=32,
110
+ help="the number of worker for data loader",
111
+ )
112
+ parser.add_argument(
113
+ "--train-feeder-args",
114
+ default=dict(),
115
+ help="the arguments of data loader for training",
116
+ )
117
+ parser.add_argument(
118
+ "--test-feeder-args",
119
+ default=dict(),
120
+ help="the arguments of data loader for test",
121
+ )
122
+
123
+ # model
124
+ parser.add_argument("--model", default=None, help="the model will be used")
125
+ parser.add_argument(
126
+ "--model-args", type=dict, default=dict(), help="the arguments of model"
127
+ )
128
+ parser.add_argument(
129
+ "--weights", default=None, help="the weights for network initialization"
130
+ )
131
+ parser.add_argument(
132
+ "--ignore-weights",
133
+ type=str,
134
+ default=[],
135
+ nargs="+",
136
+ help="the name of weights which will be ignored in the initialization",
137
+ )
138
+
139
+ # optim
140
+ parser.add_argument(
141
+ "--base-lr", type=float, default=0.01, help="initial learning rate"
142
+ )
143
+ parser.add_argument(
144
+ "--step",
145
+ type=int,
146
+ default=[60,80],
147
+ nargs="+",
148
+ help="the epoch where optimizer reduce the learning rate",
149
+ )
150
+
151
+ # training
152
+ parser.add_argument(
153
+ "--device",
154
+ type=int,
155
+ default=0,
156
+ nargs="+",
157
+ help="the indexes of GPUs for training or testing",
158
+ )
159
+ parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
160
+ parser.add_argument(
161
+ "--nesterov", type=str2bool, default=False, help="use nesterov or not"
162
+ )
163
+ parser.add_argument(
164
+ "--batch-size", type=int, default=256, help="training batch size"
165
+ )
166
+ parser.add_argument(
167
+ "--test-batch-size", type=int, default=256, help="test batch size"
168
+ )
169
+ parser.add_argument(
170
+ "--start-epoch", type=int, default=0, help="start training from which epoch"
171
+ )
172
+ parser.add_argument(
173
+ "--num-epoch", type=int, default=80, help="stop training in which epoch"
174
+ )
175
+ parser.add_argument(
176
+ "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
177
+ )
178
+ # loss
179
+ parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
180
+ parser.add_argument(
181
+ "--label_count_path",
182
+ default=None,
183
+ type=str,
184
+ help="Path to label counts (used in loss weighting)",
185
+ )
186
+ parser.add_argument(
187
+ "---beta",
188
+ type=float,
189
+ default=0.9999,
190
+ help="Hyperparameter for Class balanced loss",
191
+ )
192
+ parser.add_argument(
193
+ "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
194
+ )
195
+
196
+ parser.add_argument("--only_train_part", default=False)
197
+ parser.add_argument("--only_train_epoch", default=0)
198
+ parser.add_argument("--warm_up_epoch", default=0)
199
+
200
+ parser.add_argument(
201
+ "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
202
+ )
203
+
204
+ parser.add_argument(
205
+ "--class-threshold",
206
+ type=float,
207
+ default=0.1,
208
+ help="class threshold for rejection",
209
+ )
210
+ parser.add_argument(
211
+ "--start-threshold",
212
+ type=float,
213
+ default=0.03,
214
+ help="start threshold for action localization",
215
+ )
216
+ parser.add_argument(
217
+ "--end-threshold",
218
+ type=float,
219
+ default=0.055,
220
+ help="end threshold for action localization",
221
+ )
222
+ parser.add_argument(
223
+ "--threshold-interval",
224
+ type=float,
225
+ default=0.005,
226
+ help="threshold interval for action localization",
227
+ )
228
+ return parser
229
+
230
+
231
class Processor:
    """
    Processor for Skeleton-based Action Recgnition.

    Orchestrates argument persistence, data loading, model/optimizer
    construction, and evaluation for weakly-supervised temporal action
    localization on skeleton sequences (MIL classification branch +
    frame-level localization scores).
    """

    def __init__(self, arg):
        # arg: parsed argparse.Namespace (see get_parser / the yaml config).
        self.arg = arg
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    # Interactive confirmation disabled: always wipe a stale log dir.
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # Debug mode: share a single writer between train and val.
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        # Binary cross-entropy over the per-class MIL probabilities
        # (model outputs are expected to already be in [0, 1]).
        self.loss_nce = torch.nn.BCELoss()

        # Plain-text metric logger: one row per eval step.
        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)])

    def load_data(self):
        """Build train/test DataLoaders from the configured feeder class."""
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()
        if self.arg.phase == "train":
            self.data_loader["train"] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                collate_fn=collate_with_padding_multi_joint,
            )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            collate_fn=collate_with_padding_multi_joint,
        )

    def load_model(self):
        """Instantiate the model on GPU and optionally load pre-trained weights."""
        # Primary CUDA device: first entry when a device list is given.
        output_device = (
            self.arg.device[0] if type(self.arg.device) is list else self.arg.device
        )
        self.output_device = output_device
        Model = import_class(self.arg.model)
        # Snapshot the model source file next to the results for reproducibility.
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        self.model = Model(**self.arg.model_args).cuda(output_device)
        # NOTE(review): reads the module-level `arg` instead of `self.arg`;
        # this only works when the file is run as a script — consider self.arg.loss.
        self.loss_type = arg.loss

        if self.arg.weights:
            # self.global_step = int(arg.weights[:-3].split("-")[-1])
            self.print_log("Load weights from {}.".format(self.arg.weights))
            if ".pkl" in self.arg.weights:
                with open(self.arg.weights, "r") as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            # Strip a possible DataParallel "module." prefix and move tensors
            # to the output device.
            weights = OrderedDict(
                [
                    [k.split("module.")[-1], v.cuda(output_device)]
                    for k, v in weights.items()
                ]
            )

            # Drop any entries the user asked to ignore (substring match).
            keys = list(weights.keys())
            for w in self.arg.ignore_weights:
                for key in keys:
                    if w in key:
                        if weights.pop(key, None) is not None:
                            self.print_log(
                                "Sucessfully Remove Weights: {}.".format(key)
                            )
                        else:
                            self.print_log("Can Not Remove Weights: {}.".format(key))

            try:
                self.model.load_state_dict(weights)
            except:
                # Partial load: report missing keys, then merge whatever matches.
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                print("Can not find these weights:")
                for d in diff:
                    print(" " + d)
                state.update(weights)
                self.model.load_state_dict(state)

        # Wrap in DataParallel only when more than one device is listed.
        if type(self.arg.device) is list:
            if len(self.arg.device) > 1:
                self.model = nn.DataParallel(
                    self.model, device_ids=self.arg.device, output_device=output_device
                )

    def load_optimizer(self):
        """Create the optimizer named by --optimizer (SGD or Adam)."""
        if self.arg.optimizer == "SGD":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.arg.base_lr,
                momentum=0.9,
                nesterov=self.arg.nesterov,
                weight_decay=self.arg.weight_decay,
            )
        elif self.arg.optimizer == "Adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.arg.base_lr,
                weight_decay=self.arg.weight_decay,
            )
        else:
            raise ValueError()

    def save_arg(self):
        """Dump the full argument namespace to <work_dir>/config.yaml."""
        arg_dict = vars(self.arg)
        if not os.path.exists(self.arg.work_dir):
            os.makedirs(self.arg.work_dir)
        with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
            yaml.dump(arg_dict, f)

    def adjust_learning_rate(self, epoch):
        """Linear warm-up followed by step decay (x0.1 at each --step epoch)."""
        if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
            if epoch < self.arg.warm_up_epoch:
                lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
            else:
                lr = self.arg.base_lr * (
                    0.1 ** np.sum(epoch >= np.array(self.arg.step))
                )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr

            return lr
        else:
            raise ValueError()

    def print_time(self):
        """Log the current wall-clock time."""
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log("Local current time : " + localtime)

    def print_log(self, str, print_time=True):
        # NOTE(review): parameter `str` shadows the builtin; kept for
        # compatibility with existing callers.
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            str = "[ " + localtime + " ] " + str
        print(str)
        if self.arg.print_log:
            with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
                print(str, file=f)

    def record_time(self):
        """Reset the split timer and return the current timestamp."""
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        """Seconds since the last record_time(); also resets the timer."""
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    @torch.no_grad()
    def eval(
        self,
        epoch,
        wb_dict,
        loader_name=["test"],
    ):
        """Run evaluation: MIL/co-training losses, classification mAP (cmAP)
        and detection mAP (dsmAP) over the given loaders.

        Returns wb_dict updated with "val loss" and "val acc".
        """
        self.model.eval()
        self.print_log("Eval epoch: {}".format(epoch + 1))

        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        for ln in loader_name:
            loss_value = []
            step = 0
            process = tqdm(self.data_loader[ln])

            # target, index and soft_label come from the feeder but are unused here.
            for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
                process
            ):
                data = data.float().cuda(self.output_device)
                label = label.cuda(self.output_device)
                mask = mask.cuda(self.output_device)

                # Append an always-on background/auxiliary column to the labels.
                ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

                # forward: two MIL heads (video-level) + two frame-score streams
                mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

                cls_mil_loss = self.loss_nce(
                    mil_pred, ab_labels.float()
                ) + self.loss_nce(mil_pred_2, ab_labels.float())

                # Cross-stream consistency loss between the two frame-score views.
                loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)

                loss = cls_mil_loss * self.arg.lambda_mil + loss_co

                loss_value.append(loss.data.item())

                # Collect per-sample predictions, truncated to the unpadded length.
                for i in range(data.size(0)):
                    frm_scr = frm_scrs[i]
                    vid_pred = mil_pred[i]

                    label_ = label[i].cpu().numpy()
                    mask_ = mask[i].cpu().numpy()
                    vid_len = mask_.sum()

                    frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                    vid_pred = vid_pred.cpu().numpy()

                    vid_preds.append(vid_pred)
                    frm_preds.append(frm_pred)
                    vid_lens.append(vid_len)
                    labels.append(label_)

                step += 1

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Video-level classification mAP.
        cmap = cmAP(vid_preds, labels)

        score = cmap
        loss = np.mean(loss_value)

        # Detection (localization) mAP at several IoU thresholds.
        dmap, iou = dsmAP(
            vid_preds,
            frm_preds,
            vid_lens,
            self.arg.test_feeder_args["data_path"],
            self.arg,
            multi=True,
        )

        print("Classification map %f" % cmap)

        for item in list(zip(iou, dmap)):
            print("Detection map @ %f = %f" % (item[0], item[1]))

        self.my_logger.append([epoch + 1, cmap] + dmap)

        wb_dict["val loss"] = loss
        wb_dict["val acc"] = score

        if score > self.best_acc:
            self.best_acc = score

        print("Acc score: ", score, " model: ", self.arg.model_saved_name)
        if self.arg.phase == "train":
            self.val_writer.add_scalar("loss", loss, self.global_step)
            self.val_writer.add_scalar("acc", score, self.global_step)

        self.print_log(
            "\tMean {} loss of {} batches: {}.".format(
                ln, len(self.data_loader[ln]), np.mean(loss_value)
            )
        )
        self.print_log("\tAcc score: {:.3f}%".format(score))

        return wb_dict

    def start(self):
        """Entry point: in test phase, run one evaluation pass and report."""
        wb_dict = {}

        if self.arg.phase == "test":
            # wf / rf are currently only computed, not passed to eval (see
            # the commented-out arguments below).
            if not self.arg.test_feeder_args["debug"]:
                wf = self.arg.model_saved_name + "_wrong.txt"
                rf = self.arg.model_saved_name + "_right.txt"
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError("Please appoint --weights.")
            self.arg.print_log = False
            self.print_log("Model: {}.".format(self.arg.model))
            self.print_log("Weights: {}.".format(self.arg.weights))

            wb_dict = self.eval(
                epoch=0,
                wb_dict=wb_dict,
                loader_name=["test"],
                # wrong_file=wf,
                # result_file=rf,
            )
            print("Inference metrics: ", wb_dict)
            self.print_log("Done.\n")
546
+
547
+
548
def str2bool(v):
    """Parse a human-friendly boolean string for argparse.

    Accepts yes/no, true/false, t/f, y/n, 1/0 (case-insensitive).
    Raises argparse.ArgumentTypeError for anything else.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
555
+
556
+
557
def import_class(name):
    """Import and return the object named by a dotted path.

    Example: ``import_class("feeders.feeder.Feeder")`` imports the top-level
    package and walks attributes down to the final object.
    """
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
563
+
564
+
565
if __name__ == "__main__":
    parser = get_parser()

    # load arg form config file: parse once to find --config, then merge the
    # yaml values in as defaults and re-parse so CLI flags still win.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        # Every key in the yaml must correspond to a known CLI argument.
        for k in default_arg.keys():
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
evaluation/classificationMAP.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def getAP(conf, labels):
5
+ assert len(conf) == len(labels)
6
+ sortind = np.argsort(-conf)
7
+ tp = labels[sortind] == 1
8
+ fp = labels[sortind] != 1
9
+ npos = np.sum(labels)
10
+
11
+ fp = np.cumsum(fp).astype("float32")
12
+ tp = np.cumsum(tp).astype("float32")
13
+ rec = tp / npos
14
+ prec = tp / (fp + tp)
15
+ tmp = (labels[sortind] == 1).astype("float32")
16
+
17
+ return np.sum(tmp * prec) / npos if npos > 0 else 1
18
+
19
+
20
def getClassificationMAP(confidence, labels):
    """Mean average precision (percent) over classes.

    confidence and labels are of dimension n_samples x n_label.
    """
    n_classes = np.shape(labels)[1]
    per_class_ap = [getAP(confidence[:, c], labels[:, c]) for c in range(n_classes)]
    return 100 * sum(per_class_ap) / len(per_class_ap)
evaluation/detectionMAP.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from collections import Counter
3
+
4
+ import numpy as np
5
+
6
+
7
def str2ind(categoryname, classlist):
    """Index of the first entry in classlist equal to categoryname.

    Raises IndexError when the name is absent (same as the original).
    """
    matches = [i for i, name in enumerate(classlist) if name == categoryname]
    return matches[0]
9
+
10
+
11
def encode_mask_to_rle(mask):
    """
    mask: numpy array binary mask
    1 - mask
    0 - background
    Returns encoded run length as [start1, len1, start2, len2, ...]
    with 1-indexed starts.
    """
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # Positions (1-indexed) where the value changes: run starts and run ends.
    change_points = np.where(padded[1:] != padded[:-1])[0] + 1
    # Convert every second entry from an end position into a run length.
    change_points[1::2] -= change_points[::2]
    return change_points
23
+
24
+
25
def filter_segments(segment_predict, videonames, ambilist, factor):
    """Drop predicted segments overlapping any ambiguous ground-truth span.

    segment_predict: ndarray (#pred, 4) rows [video_idx, start, end, score].
    ambilist: entries like [videoname, _, start_sec, end_sec]; start/end are
    scaled by `factor` to frame units before the overlap test.
    """
    n = np.shape(segment_predict)[0]
    drop = np.zeros(n)
    for i in range(n):
        vid = videonames[int(segment_predict[i, 0])]
        for amb in ambilist:
            if amb[0] != vid:
                continue
            gt = range(
                int(round(float(amb[2]) * factor)), int(round(float(amb[3]) * factor))
            )
            pd = range(int(segment_predict[i][1]), int(segment_predict[i][2]))
            IoU = float(len(set(gt).intersection(set(pd)))) / float(
                len(set(gt).union(set(pd)))
            )
            if IoU > 0:
                drop[i] = 1
    kept = [segment_predict[i, :] for i in range(n) if drop[i] == 0]
    return np.array(kept)
46
+
47
+
48
def getActLoc(
    vid_preds, frm_preds, vid_lens, act_thresh_cas, annotation_path, args, multi=False
):
    """Generate per-class action proposals from frame-level class scores.

    For each class and each threshold in act_thresh_cas, contiguous runs of
    frames scoring above the threshold become candidate segments; each is
    scored as inner mean minus 0.6 * surrounding-context mean, then reduced
    by NMS (IoU 0.2). Returns a list with one ndarray of proposals
    [video_idx, start, end, score] per class index that produced proposals.

    NOTE(review): classes with zero surviving proposals are skipped entirely
    (append is guarded by `len(c_temp) > 0`), which can shift class indices
    in the returned list — confirm against getLocMAP's `seg_preds[c]` usage.
    """
    # Load pickled annotations; fall back to latin1 for python2 pickles.
    try:
        with open(annotation_path) as f:
            data = pickle.load(f)
    except:
        # for pickle file from python2
        with open(annotation_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")

    if multi:
        # Multi-label frame masks: one RLE decode per class present in "L".
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_ = set(gt)
            # The num_class value is treated as background and ignored.
            gt_.discard(args.model_args["num_class"])
            gts = []
            gtl = []
            for c in list(gt_):
                gt_encoded = encode_mask_to_rle(gt == c)
                # RLE pairs (1-indexed start, length) -> 0-indexed [start, end].
                gts.extend(
                    [
                        [x - 1, x + y - 2]
                        for x, y in zip(gt_encoded[::2], gt_encoded[1::2])
                    ]
                )
                gtl.extend([c for item in gt_encoded[::2]])
            gtsegments.append(gts)
            gtlabels.append(gtl)
    else:
        # Single-label path: binary mask in "L", video label in "Y".
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_encoded = encode_mask_to_rle(gt)
            gtsegments.append(
                [[x - 1, x + y - 2] for x, y in zip(gt_encoded[::2], gt_encoded[1::2])]
            )
            gtlabels.append([data["Y"][idx] for item in gt_encoded[::2]])

    videoname = np.array(data["sid"])

    # keep ground truth and predictions for instances with temporal annotations
    gtl, vn, vp, fp, vl = [], [], [], [], []
    for i, s in enumerate(gtsegments):
        if len(s):
            gtl.append(gtlabels[i])
            vn.append(videoname[i])
            vp.append(vid_preds[i])
            fp.append(frm_preds[i])
            vl.append(vid_lens[i])
        else:
            print(i)
    gtlabels = gtl
    videoname = vn

    # which categories have temporal labels ?
    templabelidx = sorted(list(set([l for gtl in gtlabels for l in gtl])))

    dataset_segment_predict = []
    class_threshold = args.class_threshold
    for c in range(frm_preds[0].shape[1]):
        c_temp = []
        # Get list of all predictions for class c
        for i in range(len(fp)):
            vid_cls_score = vp[i][c]
            vid_cas = fp[i][:, c]
            vid_cls_proposal = []
            # if vid_cls_score < class_threshold:
            #     continue
            for t in range(len(act_thresh_cas)):
                thres = act_thresh_cas[t]
                # Pad with zeros so a run touching either border still yields
                # a +1/-1 transition below.
                vid_pred = np.concatenate(
                    [np.zeros(1), (vid_cas > thres).astype("float32"), np.zeros(1)],
                    axis=0,
                )
                vid_pred_diff = [
                    vid_pred[idt] - vid_pred[idt - 1] for idt in range(1, len(vid_pred))
                ]
                # Run starts (+1 transitions) and ends (-1 transitions).
                s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1]
                e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1]
                for j in range(len(s)):
                    len_proposal = e[j] - s[j]
                    # Keep only proposals at least 3 frames long.
                    if len_proposal >= 3:
                        inner_score = np.mean(vid_cas[s[j] : e[j] + 1])
                        # 25%-length context windows on each side of the run.
                        outer_s = max(0, int(s[j] - 0.25 * len_proposal))
                        outer_e = min(
                            int(vid_cas.shape[0] - 1),
                            int(e[j] + 0.25 * len_proposal + 1),
                        )
                        outer_temp_list = list(range(outer_s, int(s[j]))) + list(
                            range(int(e[j] + 1), outer_e)
                        )
                        if len(outer_temp_list) == 0:
                            outer_score = 0
                        else:
                            outer_score = np.mean(vid_cas[outer_temp_list])
                        # Inner-outer contrast score.
                        c_score = inner_score - 0.6 * outer_score
                        vid_cls_proposal.append([i, s[j], e[j] + 1, c_score])
            pick_idx = NonMaximumSuppression(np.array(vid_cls_proposal), 0.2)
            nms_vid_cls_proposal = [vid_cls_proposal[k] for k in pick_idx]
            c_temp += nms_vid_cls_proposal
        if len(c_temp) > 0:
            c_temp = np.array(c_temp)
            dataset_segment_predict.append(c_temp)
    """
    for i, pred in enumerate(dataset_segment_predict):
        print (f"#{i} class {c} has {len(pred)} predictions")
    """
    return dataset_segment_predict
161
+
162
+
163
def IntergrateSegs(rgb_segs, flow_segs, th, args):
    """Fuse RGB-stream and flow-stream proposals per class via NMS.

    rgb_segs / flow_segs: per-class proposal arrays with rows
    [video_idx, start, end, score]. Proposals of the same video are merged
    and de-duplicated with NMS at threshold `th`.

    NOTE(review): NUM_VID is hard-coded to 212 — confirm this matches the
    evaluated dataset's video count.
    """
    NUM_CLASS = args.class_num
    NUM_VID = 212
    segs = []
    for i in range(NUM_CLASS):
        class_seg = []
        rgb_seg = rgb_segs[i]
        flow_seg = flow_segs[i]
        # Column 0 is the video index of each proposal.
        rgb_seg_ind = np.array(rgb_seg)[:, 0]
        flow_seg_ind = np.array(flow_seg)[:, 0]
        for j in range(NUM_VID):
            rgb_find = np.where(rgb_seg_ind == j)
            flow_find = np.where(flow_seg_ind == j)
            if len(rgb_find[0]) == 0 and len(flow_find[0]) == 0:
                # No proposals for this video in either stream.
                continue
            elif len(rgb_find[0]) != 0 and len(flow_find[0]) != 0:
                # Both streams: concatenate, then NMS the union.
                rgb_vid_seg = rgb_seg[rgb_find[0]]
                flow_vid_seg = flow_seg[flow_find[0]]
                fuse_seg = np.concatenate([rgb_vid_seg, flow_vid_seg], axis=0)
                pick_idx = NonMaximumSuppression(fuse_seg, th)
                fuse_segs = fuse_seg[pick_idx]
                class_seg.append(fuse_segs)
            elif len(rgb_find[0]) != 0 and len(flow_find[0]) == 0:
                # RGB only.
                vid_seg = rgb_seg[rgb_find[0]]
                class_seg.append(vid_seg)
            elif len(rgb_find[0]) == 0 and len(flow_find[0]) != 0:
                # Flow only.
                vid_seg = flow_seg[flow_find[0]]
                class_seg.append(vid_seg)
        # NOTE(review): raises ValueError if class_seg is empty (no video had
        # proposals for this class) — confirm that never happens in practice.
        class_seg = np.concatenate(class_seg, axis=0)
        segs.append(class_seg)
    return segs
194
+
195
+
196
def NonMaximumSuppression(segs, overlapThresh):
    """Greedy temporal NMS.

    segs: ndarray (#seg, 4) rows [video_idx, start, end, score].
    Keeps the highest-scoring segment, discards others whose overlap with it
    (relative to their own length) exceeds overlapThresh; repeats.
    Returns the list of picked row indices.
    """
    # if there are no boxes, return an empty list
    if len(segs) == 0:
        return []
    # Work in floats so the overlap division below is exact.
    if segs.dtype.kind == "i":
        segs = segs.astype("float")

    starts = segs[:, 1]
    ends = segs[:, 2]
    confs = segs[:, 3]
    # Inclusive segment lengths; candidates visited in ascending score order
    # from the back of `order`.
    lengths = ends - starts + 1
    order = np.argsort(confs)

    picked = []
    while len(order) > 0:
        # Highest-scoring remaining segment.
        best = order[-1]
        picked.append(best)
        rest = order[:-1]

        # Overlap of every remaining segment with the picked one.
        inter_start = np.maximum(starts[best], starts[rest])
        inter_end = np.minimum(ends[best], ends[rest])
        inter_len = np.maximum(0, inter_end - inter_start + 1)
        overlap = inter_len / lengths[rest]

        # Keep only segments whose overlap is within the threshold.
        order = rest[np.where(overlap <= overlapThresh)[0]]
    return picked
242
+
243
+
244
def getLocMAP(seg_preds, th, annotation_path, args, multi=False, factor=1.0):
    """Localization mAP at IoU threshold `th` for precomputed proposals.

    seg_preds: list of per-class proposal ndarrays, rows
    [video_idx, start, end, score]. Ground truth is decoded from the pickled
    annotation file; GT frame indices are scaled by `factor`.
    Returns mean AP over the hard-coded class ids [0, 1, 2, 3], in percent.

    NOTE(review): only the multi=True decode path remains; calling with
    multi=False leaves gtsegments undefined and raises NameError.
    """
    # Load pickled annotations; fall back to latin1 for python2 pickles.
    try:
        with open(annotation_path) as f:
            data = pickle.load(f)
    except:
        # for pickle file from python2
        with open(annotation_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")

    if multi:
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_ = set(gt)
            # gt_.discard(args.model_args["num_classes"])
            # NOTE(review): background class id hard-coded to 4 — confirm it
            # matches model_args["num_class"].
            gt_.discard(4)
            gts = []
            gtl = []
            for c in list(gt_):
                gt_encoded = encode_mask_to_rle(gt == c)
                # RLE pairs (1-indexed start, length) -> 0-indexed [start, end].
                gts.extend(
                    [
                        [x - 1, x + y - 2]
                        for x, y in zip(gt_encoded[::2], gt_encoded[1::2])
                    ]
                )
                gtl.extend([c for item in gt_encoded[::2]])
            gtsegments.append(gts)
            gtlabels.append(gtl)
    # Single-label decode path intentionally disabled (this evaluator is only
    # exercised with multi=True in this repository).

    # Evaluate the fixed action classes 0-3 (background id 4 excluded).
    # templabelidx = sorted(list(set([l for gtl in gtlabels for l in gtl])))
    templabelidx = [0, 1, 2, 3]
    ap = []
    for c in templabelidx:
        segment_predict = seg_preds[c]
        # Sort the list of predictions for class c based on score
        if len(segment_predict) == 0:
            ap.append(0.0)
            continue
        segment_predict = segment_predict[np.argsort(-segment_predict[:, 3])]

        # Create gt list: [video_idx, start, end] for every GT segment of class c.
        segment_gt = [
            [i, gtsegments[i][j][0], gtsegments[i][j][1]]
            for i in range(len(gtsegments))
            for j in range(len(gtsegments[i]))
            if gtlabels[i][j] == c
        ]
        gtpos = len(segment_gt)

        # Compare predictions and gt: greedy matching, each GT segment is
        # consumed (deleted) by its best-matching prediction.
        tp, fp = [], []
        for i in range(len(segment_predict)):
            matched = False
            best_iou = 0
            for j in range(len(segment_gt)):
                if segment_predict[i][0] == segment_gt[j][0]:
                    gt = range(
                        int(round(segment_gt[j][1] * factor)),
                        int(round(segment_gt[j][2] * factor)),
                    )
                    p = range(int(segment_predict[i][1]), int(segment_predict[i][2]))
                    # Frame-set IoU; guard against an empty union.
                    union_set = set(gt).union(set(p))
                    if len(union_set) == 0:
                        IoU = 0.0  # or handle the case as needed
                    else:
                        IoU = float(len(set(gt).intersection(set(p)))) / float(len(union_set))
                    if IoU >= th:
                        matched = True
                        if IoU > best_iou:
                            best_iou = IoU
                            best_j = j
            if matched:
                # Remove the matched GT so it cannot be matched twice.
                del segment_gt[best_j]
            tp.append(float(matched))
            fp.append(1.0 - float(matched))
        tp_c = np.cumsum(tp)
        fp_c = np.cumsum(fp)
        if sum(tp) == 0:
            prc = 0.0
        else:
            cur_prec = tp_c / (fp_c + tp_c)
            cur_rec = tp_c / gtpos
            prc = _ap_from_pr(cur_prec, cur_rec)
        ap.append(prc)

    # Per-class AP (percent), space-separated.
    print(f" ".join([f"{item*100:.2f}" for item in ap]))
    if ap:
        return 100 * np.mean(ap)
    else:
        return 0
359
+
360
+
361
+ # Inspired by Pascal VOC evaluation tool.
362
+ def _ap_from_pr(prec, rec):
363
+ mprec = np.hstack([[0], prec, [0]])
364
+ mrec = np.hstack([[0], rec, [1]])
365
+
366
+ for i in range(len(mprec) - 1)[::-1]:
367
+ mprec[i] = max(mprec[i], mprec[i + 1])
368
+
369
+ idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
370
+ ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx])
371
+
372
+ return ap
373
+
374
+
375
def compute_iou(dur1, dur2):
    """Temporal IoU of two [start, end] intervals; 0 when they are disjoint.

    Note: "union" here is the spanning interval (max end - min start), which
    equals the true union only when the intervals overlap.
    """
    inter_start = max(dur1[0], dur2[0])
    inter_end = min(dur1[1], dur2[1])

    if inter_start >= inter_end:
        return 0

    intersect = inter_end - inter_start
    union = max(dur1[1], dur2[1]) - min(dur1[0], dur2[0])
    return intersect / union
387
+
388
def getActLoc1(
    frm_preds, act_thresh_cas=np.arange(0.03, 0.055, 0.005)
):
    """Generate per-class proposals from frame scores only (no annotations).

    Same thresholding / inner-outer scoring / NMS scheme as getActLoc, but
    driven purely by frm_preds (list of per-video (T, n_class) score arrays).
    Note: the ndarray default for act_thresh_cas is evaluated once at
    definition time; it is only read, never mutated, so this is safe.

    NOTE(review): classes with zero surviving proposals are skipped (append
    is inside `if len(c_temp) > 0`), which can shift class indices in the
    returned list — confirm against getLocMAP's `seg_preds[c]` indexing.
    """
    fp = []
    for i, s in enumerate(frm_preds):
        fp.append(frm_preds[i])

    dataset_segment_predict = []
    for c in range(frm_preds[0].shape[1]):
        c_temp = []
        # Get list of all predictions for class c
        for i in range(len(fp)):
            vid_cas = fp[i][:, c]
            vid_cls_proposal = []

            for t in range(len(act_thresh_cas)):
                thres = act_thresh_cas[t]
                # Zero-pad so runs touching the borders still yield transitions.
                vid_pred = np.concatenate(
                    [np.zeros(1), (vid_cas > thres).astype("float32"), np.zeros(1)],
                    axis=0,
                )
                vid_pred_diff = [
                    vid_pred[idt] - vid_pred[idt - 1] for idt in range(1, len(vid_pred))
                ]
                # Run starts (+1) and ends (-1).
                s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1]
                e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1]
                for j in range(len(s)):
                    len_proposal = e[j] - s[j]
                    # Keep only proposals at least 3 frames long.
                    if len_proposal >= 3:
                        inner_score = np.mean(vid_cas[s[j] : e[j] + 1])
                        # 25%-length context windows on either side.
                        outer_s = max(0, int(s[j] - 0.25 * len_proposal))
                        outer_e = min(
                            int(vid_cas.shape[0] - 1),
                            int(e[j] + 0.25 * len_proposal + 1),
                        )
                        outer_temp_list = list(range(outer_s, int(s[j]))) + list(
                            range(int(e[j] + 1), outer_e)
                        )
                        if len(outer_temp_list) == 0:
                            outer_score = 0
                        else:
                            outer_score = np.mean(vid_cas[outer_temp_list])
                        # Inner-outer contrast score.
                        c_score = inner_score - 0.6 * outer_score
                        vid_cls_proposal.append([i, s[j], e[j] + 1, c_score])
            pick_idx = NonMaximumSuppression(np.array(vid_cls_proposal), 0.2)
            nms_vid_cls_proposal = [vid_cls_proposal[k] for k in pick_idx]
            c_temp += nms_vid_cls_proposal
        if len(c_temp) > 0:
            c_temp = np.array(c_temp)
            dataset_segment_predict.append(c_temp)
    """
    for i, pred in enumerate(dataset_segment_predict):
        print (f"#{i} class {c} has {len(pred)} predictions")
    """
    return dataset_segment_predict
443
+
444
+
445
def getSingleStreamDetectionMAP(
    vid_preds, frm_preds, vid_lens, annotation_path, args, multi=False, factor=1.0
):
    """Detection mAP for one stream at IoU thresholds 0.1-0.5.

    Proposals are generated from frm_preds via getActLoc1 with thresholds
    np.arange(start_threshold, end_threshold, threshold_interval), then
    scored with getLocMAP per IoU. Returns (dmap_list, iou_list).

    Note: vid_preds and vid_lens are accepted for interface compatibility
    but are not used by the getActLoc1 path.
    """
    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    dmap_list = []

    seg = getActLoc1(
        frm_preds,
        np.arange(args.start_threshold, args.end_threshold, args.threshold_interval),
    )

    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        dmap_list.append(
            getLocMAP(seg, iou, annotation_path, args, multi=multi, factor=factor)
        )
    return dmap_list, iou_list
462
+
463
+
464
def getTwoStreamDetectionMAP(
    rgb_vid_preds,
    flow_vid_preds,
    rgb_frm_preds,
    flow_frm_preds,
    vid_lens,
    annotation_path,
    args,
):
    """Detection mAP after fusing RGB and flow proposals (IoU 0.1-0.7).

    RGB frame scores and activation thresholds are scaled by 0.1 before
    localization (stream-specific calibration), flow uses them unscaled;
    the two proposal sets are fused per class with NMS at 0.9 via
    IntergrateSegs. Returns (dmap_list, iou_list).
    """
    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    dmap_list = []
    rgb_seg = getActLoc(
        rgb_vid_preds,
        rgb_frm_preds * 0.1,
        vid_lens,
        np.arange(args.start_threshold, args.end_threshold, args.threshold_interval)
        * 0.1,
        annotation_path,
        args,
    )
    flow_seg = getActLoc(
        flow_vid_preds,
        flow_frm_preds,
        vid_lens,
        np.arange(args.start_threshold, args.end_threshold, args.threshold_interval),
        annotation_path,
        args,
    )
    seg = IntergrateSegs(rgb_seg, flow_seg, 0.9, args)
    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        dmap_list.append(getLocMAP(seg, iou, annotation_path, args))

    return dmap_list, iou_list
498
+
499
+
500
def getSingleStreamDetectionMAP_gcn(
    seg, annotation_path, args, multi=False, factor=1.0
):
    """Detection mAP for pre-computed proposals at IoU thresholds 0.1-0.5.

    Parameters
    ----------
    seg : list of ndarrays, one per class; each of shape (#pred, 4) with rows
        [video_index, start, end + 1, c_score].
    annotation_path : path to the pickled ground-truth annotations.
    args, multi, factor : forwarded to getLocMAP.

    Returns
    -------
    (dmap_list, iou_list) : per-IoU detection mAPs and the IoU thresholds.
    """
    # Fix: the original assigned iou_list = [0.3, 0.5] and immediately
    # overwrote it — the dead first assignment has been removed.
    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    dmap_list = []

    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        dmap_list.append(
            getLocMAP(seg, iou, annotation_path, args, multi=multi, factor=factor)
        )
    return dmap_list, iou_list
evaluation/eval.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+
6
+ from .classificationMAP import getClassificationMAP as cmAP
7
+ from .detectionMAP import getSingleStreamDetectionMAP as dsmAP
8
+ from .detectionMAP import getTwoStreamDetectionMAP as dtmAP
9
+ from .utils import write_results_to_eval_file, write_results_to_file
10
+
11
+
12
def ss_eval(epoch, dataloader, args, logger, model, device):
    """Single-stream evaluation: classification mAP plus detection mAP.

    Runs the model over the dataloader (batch size assumed to be 1 — each
    prediction is squeezed on axis 0), aggregates video-level and frame-level
    predictions, logs metrics and appends them to the results file.
    """
    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    for num, sample in enumerate(dataloader):
        if (num + 1) % 100 == 0:
            print("Testing test data point %d of %d" % (num + 1, len(dataloader)))

        features = sample["data"].numpy()
        label = sample["labels"].numpy()
        vid_len = sample["vid_len"].numpy()

        features = torch.from_numpy(features).float().to(device)

        with torch.no_grad():
            # Model returns (_, video_scores, _, frame_scores).
            _, vid_pred, _, frm_scr = model(Variable(features))
            frm_pred = F.softmax(frm_scr, -1)
            # Drop the batch dimension (batch size 1).
            vid_pred = np.squeeze(vid_pred.cpu().data.numpy(), axis=0)
            frm_pred = np.squeeze(frm_pred.cpu().data.numpy(), axis=0)
            label = np.squeeze(label, axis=0)

        vid_preds.append(vid_pred)
        frm_preds.append(frm_pred)
        vid_lens.append(vid_len)
        labels.append(label)

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # Video-level classification mAP and per-IoU detection mAP.
    cmap = cmAP(vid_preds, labels)
    dmap, iou = dsmAP(
        vid_preds, frm_preds, vid_lens, dataloader.dataset.path_to_annotations, args
    )

    print("Classification map %f" % cmap)
    for item in list(zip(iou, dmap)):
        print("Detection map @ %f = %f" % (item[0], item[1]))

    logger.log_value("Test Classification mAP", cmap, epoch)
    for item in list(zip(dmap, iou)):
        logger.log_value("Test Detection1 mAP @ IoU = " + str(item[1]), item[0], epoch)

    write_results_to_file(args, dmap, cmap, epoch)
59
+
60
+
61
def ts_eval(dataloader, args, logger, rgb_model, flow_model, device):
    """Two-stream (RGB + optical flow) evaluation.

    Runs both models over the dataloader (batch size assumed to be 1),
    fuses their predictions via the two-stream detection-mAP routine,
    prints per-IoU results and the average over the first 7 thresholds,
    and appends the results to the eval log file.
    """
    rgb_vid_preds = []
    rgb_frame_preds = []
    flow_vid_preds = []
    flow_frame_preds = []
    vid_lens = []
    labels = []

    for num, sample in enumerate(dataloader):
        if (num + 1) % 100 == 0:
            print("Testing test data point %d of %d" % (num + 1, len(dataloader)))

        rgb_features = sample["rgb_data"].numpy()
        flow_features = sample["flow_data"].numpy()
        label = sample["labels"].numpy()
        vid_len = sample["vid_len"].numpy()

        rgb_features_inp = torch.from_numpy(rgb_features).float().to(device)
        flow_features_inp = torch.from_numpy(flow_features).float().to(device)

        with torch.no_grad():
            # Both models return (_, video_scores, _, frame_scores).
            _, rgb_video_pred, _, rgb_frame_scr = rgb_model(Variable(rgb_features_inp))
            _, flow_video_pred, _, flow_frame_scr = flow_model(
                Variable(flow_features_inp)
            )

            rgb_frame_pred = F.softmax(rgb_frame_scr, -1)
            flow_frame_pred = F.softmax(flow_frame_scr, -1)

            # Drop the batch dimension (batch size 1).
            rgb_frame_pred = np.squeeze(rgb_frame_pred.cpu().data.numpy(), axis=0)
            flow_frame_pred = np.squeeze(flow_frame_pred.cpu().data.numpy(), axis=0)
            rgb_video_pred = np.squeeze(rgb_video_pred.cpu().data.numpy(), axis=0)
            flow_video_pred = np.squeeze(flow_video_pred.cpu().data.numpy(), axis=0)
            label = np.squeeze(label, axis=0)

        rgb_vid_preds.append(rgb_video_pred)
        rgb_frame_preds.append(rgb_frame_pred)
        flow_vid_preds.append(flow_video_pred)
        flow_frame_preds.append(flow_frame_pred)
        vid_lens.append(vid_len)
        labels.append(label)

    rgb_vid_preds = np.array(rgb_vid_preds)
    rgb_frame_preds = np.array(rgb_frame_preds)
    flow_vid_preds = np.array(flow_vid_preds)
    flow_frame_preds = np.array(flow_frame_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    dmap, iou = dtmAP(
        rgb_vid_preds,
        flow_vid_preds,
        rgb_frame_preds,
        flow_frame_preds,
        vid_lens,
        dataloader.dataset.path_to_annotations,
        args,
    )

    # Average the detection mAPs over at most the first 7 IoU thresholds.
    # Fix: the accumulator was named `sum`, shadowing the builtin.
    dmap_sum = 0
    count = 0
    for item in list(zip(iou, dmap)):
        print("Detection map @ %f = %f" % (item[0], item[1]))
        if count < 7:
            dmap_sum = dmap_sum + item[1]
            count += 1

    print("average map = %f" % (dmap_sum / count))
    write_results_to_eval_file(args, dmap, args.rgb_load_epoch, args.flow_load_epoch)
evaluation/utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def str2ind(categoryname, classlist):
    """Return the index of *categoryname* within a list of byte-string class names.

    Raises IndexError when the name is not present.
    """
    hits = [
        pos
        for pos, cls in enumerate(classlist)
        if cls.decode("utf-8") == categoryname
    ]
    return hits[0]
10
+
11
+
12
def strlist2indlist(strlist, classlist):
    """Map every category name in *strlist* to its index in *classlist*."""
    indices = []
    for name in strlist:
        indices.append(str2ind(name, classlist))
    return indices
14
+
15
+
16
def strlist2multihot(strlist, classlist):
    """Encode a list of category names as one multi-hot vector over *classlist*."""
    one_hots = np.eye(len(classlist))[strlist2indlist(strlist, classlist)]
    return one_hots.sum(axis=0)
18
+
19
+
20
def idx2multihot(id_list, num_class):
    """Encode class indices as a multi-hot vector of length *num_class*.

    Repeated indices accumulate (the rows of the identity are summed).
    """
    eye = np.eye(num_class)
    return eye[id_list].sum(axis=0)
22
+
23
+
24
def write_results_to_eval_file(args, dmap, itr1, itr2):
    """Append one two-stream evaluation line to the dataset's eval log.

    Line format: "<itr1> <itr2> <dmap[0]:.2f> <dmap[1]:.2f> ...".

    :param args: namespace with at least ``dataset_name``
    :param dmap: iterable of detection-mAP values
    :param itr1: RGB checkpoint epoch/iteration
    :param itr2: flow checkpoint epoch/iteration
    """
    import os

    file_folder = "./ckpt/" + args.dataset_name + "/eval/"
    file_name = args.dataset_name + "-results.log"
    # Robustness: create the log directory on first use instead of crashing.
    os.makedirs(file_folder, exist_ok=True)
    string_to_write = str(itr1)
    string_to_write += " " + str(itr2)
    for item in dmap:
        string_to_write += " " + "%.2f" % item
    # Context manager guarantees the handle is closed even if the write fails.
    with open(file_folder + file_name, "a+") as fid:
        fid.write(string_to_write + "\n")
34
+
35
+
36
def write_results_to_file(args, dmap, cmap, itr):
    """Append one training-evaluation line to the model's result log.

    Line format: "<itr> <dmap...:.2f> <cmap:.2f>".

    :param args: namespace with ``dataset_name`` and ``model_id``
    :param dmap: iterable of detection-mAP values
    :param cmap: classification mAP
    :param itr: training iteration
    """
    import os

    file_folder = "./ckpt/" + args.dataset_name + "/" + str(args.model_id) + "/"
    file_name = args.dataset_name + "-results.log"
    # Robustness: create the log directory on first use instead of crashing.
    os.makedirs(file_folder, exist_ok=True)
    string_to_write = str(itr)
    for item in dmap:
        string_to_write += " " + "%.2f" % item
    string_to_write += " " + "%.2f" % cmap
    # Context manager guarantees the handle is closed even if the write fails.
    with open(file_folder + file_name, "a+") as fid:
        fid.write(string_to_write + "\n")
46
+
47
+
48
def write_settings_to_file(args):
    """Dump every attribute of *args* into the model's result log.

    Writes a "#"-ruled header, one "name: value" line per argument, and a
    "*"-ruled footer, so runs are visually separated in the shared log.

    :param args: argparse-style namespace with ``dataset_name`` and ``model_id``
    """
    import os

    file_folder = "./ckpt/" + args.dataset_name + "/" + str(args.model_id) + "/"
    file_name = args.dataset_name + "-results.log"
    # Robustness: create the log directory on first use instead of crashing.
    os.makedirs(file_folder, exist_ok=True)
    string_to_write = "#" * 80 + "\n"
    for arg in vars(args):
        string_to_write += str(arg) + ": " + str(getattr(args, arg)) + "\n"
    string_to_write += "*" * 80 + "\n"
    # Context manager guarantees the handle is closed even if the write fails.
    with open(file_folder + file_name, "a+") as fid:
        fid.write(string_to_write)
feeders/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import feeder
feeders/feeder.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # -*- coding: utf-8 -*-
4
+ #
5
+ # Adapted from https://github.com/lshiwjx/2s-AGCN for BABEL (https://babel.is.tue.mpg.de/)
6
+
7
+ import json
8
+ import math
9
+ import os
10
+ import os.path as osp
11
+ import pdb
12
+ import pickle
13
+ import random
14
+ import shutil
15
+ import subprocess
16
+ import sys
17
+ import uuid
18
+
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ import torch
22
+ from feeders import tools
23
+ from torch.utils.data import Dataset
24
+
25
+ sys.path.extend(["../"])
26
+
27
+
28
class Feeder(Dataset):
    """Skeleton-sequence dataset for weakly-supervised temporal action
    localisation (adapted from 2s-AGCN for BABEL).

    The pickle at *data_path* is expected to hold a dict with keys:
      * "X": per-video joint arrays of shape (C, T, V, M),
      * "Y": per-video lists of video-level class ids,
      * "L": per-video frame-level label arrays of length T.
    The feeder also maintains per-frame soft labels as the mean of the last
    10 epochs of network predictions (see ``label_update``).
    """

    def __init__(
        self,
        data_path,
        random_choose=False,
        random_shift=False,
        random_move=False,
        window_size=-1,
        debug=False,
        use_mmap=True,
        frame_pad=False,
        nb_class=3,
    ):
        """
        :param data_path: path to the pickled dataset dict
        :param random_choose: If true, randomly choose a portion of the input sequence
        :param random_shift: If true, randomly pad zeros at the begining or end of sequence
        :param random_move: If true, apply random affine jitter (tools.random_move)
        :param window_size: The length of the output sequence (-1 keeps original)
        :param debug: If true, only use the first 100 samples
        :param use_mmap: If true, use mmap mode to load data, which can save the running memory
        :param frame_pad: If true, zero-pad T to the next multiple of 15
        :param nb_class: number of action classes; an extra "background" slot
            is appended internally (hence ``nb_class + 1`` below)
        """

        self.debug = debug
        self.data_path = data_path
        self.random_choose = random_choose
        self.random_shift = random_shift
        self.random_move = random_move
        self.window_size = window_size
        self.use_mmap = use_mmap
        self.nb_class = nb_class
        self.frame_pad = frame_pad
        self.load_data()
        # Number of label_update() calls so far; drives the 10-slot ring buffer.
        self.count = 0
        # Sanity check: per-frame labels must align with the joint frame count.
        for i in range(len(self.data["X"])):
            assert self.data["L"][i].shape[0] == self.data["X"][i].shape[1]

        # Ring buffer of the last 10 per-frame prediction maps, one per video.
        self.prediction = [
            np.zeros((item.shape[0], 10, self.nb_class + 1), dtype=np.float32)
            for item in self.data["L"]
        ]
        # Per-frame soft labels: running mean over the ring buffer.
        self.soft_labels = [
            np.zeros((item.shape[0], self.nb_class + 1), dtype=np.float32)
            for item in self.data["L"]
        ]

    def load_data(self):
        """Load the pickled dict into ``self.data``.

        Data layout: N, C, T, V, M.
        """
        # load data
        try:
            with open(self.data_path) as f:
                self.data = pickle.load(f)
        except:  # noqa: E722 -- intentional broad fallback, see comment below
            # for pickle file from python2 (binary, latin1-encoded strings)
            with open(self.data_path, "rb") as f:
                self.data = pickle.load(f, encoding="latin1")

    def label_update(self, results, indexs):
        """Store new per-frame predictions and refresh the soft labels.

        :param results: iterable of (T_i, nb_class + 1) prediction arrays
        :param indexs: matching dataset indices for ``results``
        """
        self.count += 1

        # While updating the noisy label y_i by the probability s, we used the
        # average output probability of the network of the past 10 epochs as s.
        idx = (self.count - 1) % 10

        for ind, res in zip(indexs, results):
            self.prediction[ind][:, idx, :] = res

        for i in range(len(self.prediction)):
            self.soft_labels[i] = self.prediction[i].mean(axis=1)

    def __len__(self):
        return len(self.data["X"])

    def __iter__(self):
        # NOTE(review): returns self without defining __next__, so direct
        # iteration over this object would fail; DataLoader indexes via
        # __getitem__ instead. Confirm before relying on iterator behavior.
        return self

    def __getitem__(self, index):
        """Return one sample as a 6-tuple.

        Returns:
            data_numpy: joints (C, T, V, M), optionally augmented/padded
            label: multi-hot video-level label vector of length nb_class
            gt: per-frame action label array
            mask: ones_like(gt) (all frames valid before collate padding)
            index: the dataset index (fed back to ``label_update``)
            frame_label: current per-frame soft labels for this video
        """
        data_numpy = self.data["X"][index]
        data_numpy = np.array(data_numpy)

        # Multi-hot encode the list of video-level class ids.
        label = self.data["Y"][index]
        label_np = np.zeros(self.nb_class)
        for item in label:
            label_np[item] = 1
        label = np.array(label_np)

        gt = self.data["L"][index]
        gt = np.array(gt)

        if self.random_shift:
            data_numpy = tools.random_shift(data_numpy)
        if self.random_choose:
            data_numpy = tools.random_choose(data_numpy, self.window_size)
        elif self.window_size > 0:
            data_numpy = tools.auto_pading(data_numpy, self.window_size)
        if self.random_move:
            data_numpy = tools.random_move(data_numpy)

        # Zero-pad T up to a multiple of 15 -- presumably the model's temporal
        # downsampling factor; confirm against the network strides.
        if self.frame_pad:
            C, T, V, M = data_numpy.shape
            if T % 15 != 0:
                new_T = T + 15 - T % 15

                data_numpy_paded = np.zeros((C, new_T, V, M))
                data_numpy_paded[:, :T, :, :] = data_numpy

                data_numpy = data_numpy_paded

        mask = np.ones_like(gt)

        frame_label = self.soft_labels[index]

        return data_numpy, label, gt, mask, index, frame_label
152
+
153
+
154
def import_class(name):
    """Resolve a dotted path like ``"pkg.mod.Cls"`` to the named object."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
160
+
161
+
162
def test(
    dataset,
    preds=None,
    th=None,
    idx=None,
    graph="graph.ntu_rgb_d.Graph",
    is_3d=True,
    folder_p="viz",
    label_json="prepare/configs/action_label_split1.json",
):
    """Visualise sample *idx* of *dataset* with matplotlib and write a video.

    :param dataset: indexable dataset; item must start with (data, label, gt)
    :param preds: optional (T, num_class) per-frame scores to overlay
    :param th: unused (kept for interface compatibility)
    :param idx: the id of sample to render
    :param graph: dotted path of the skeleton Graph class, or None for dots
    :param is_3d: when vis NTU, set it True
    :param folder_p: output folder; frames go to folder_p/frames and the
        video to folder_p/<idx>.mp4
    :param label_json: JSON mapping action name -> class index
    :return: None (side effects: files under *folder_p*)
    """

    with open(label_json) as infile:
        jc = json.load(infile)

    # Invert name->index to index->name; the trailing extra class is "other".
    idx2act = {v: k for k, v in jc.items()}
    idx2act[len(idx2act)] = "other"

    # Start from an empty frames directory.
    if osp.exists(osp.join(folder_p, "frames")):
        shutil.rmtree(osp.join(folder_p, "frames"))
    os.makedirs(osp.join(folder_p, "frames"))

    # Bug fix: Feeder.__getitem__ returns 6 values (data, label, gt, mask,
    # index, frame_label); the old 4-way unpack raised ValueError. The starred
    # unpack absorbs any extras, so 4- and 6-tuple datasets both work.
    data, label, gt, *_ = dataset[idx]
    data = data.reshape((1,) + data.shape)

    # for batch_idx, (data, label) in enumerate(loader):
    N, C, T, V, M = data.shape

    plt.ion()
    fig = plt.figure()
    if is_3d:
        from mpl_toolkits.mplot3d import Axes3D

        ax = fig.add_subplot(111, projection="3d")
    else:
        ax = fig.add_subplot(111)

    if graph is None:
        # No skeleton topology available: draw each person as a point cloud.
        p_type = ["b.", "g.", "r.", "c.", "m.", "y.", "k.", "k.", "k.", "k."]
        pose = [ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M)]
        ax.axis([-1, 1, -1, 1])
        for t in range(T):
            for m in range(M):
                pose[m].set_xdata(data[0, 0, t, :, m])
                pose[m].set_ydata(data[0, 1, t, :, m])
            fig.canvas.draw()
            plt.pause(0.001)
    else:
        # Draw bones using the inward edge list of the configured graph class.
        p_type = ["b-", "g-", "r-", "c-", "m-", "y-", "k-", "k-", "k-", "k-"]
        import sys
        from os import path

        sys.path.append(
            path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
        )
        G = import_class(graph)()
        edge = G.inward
        pose = []
        for m in range(M):
            a = []
            for i in range(len(edge)):
                if is_3d:
                    a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0])
                else:
                    a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0])
            pose.append(a)
        ax.axis([-1, 1, -1, 1])
        if is_3d:
            ax.set_zlim3d(-1, 1)
        for t in range(T):
            for m in range(M):
                for i, (v1, v2) in enumerate(edge):
                    x1 = data[0, :2, t, v1, m]
                    x2 = data[0, :2, t, v2, m]
                    # Skip bones whose endpoints are all-zero (missing joints),
                    # except edges touching joint 1 (kept as an anchor).
                    if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1:
                        pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m])
                        pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m])
                        if is_3d:
                            pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m])

            # Overlay the ground-truth action name on action frames.
            # NOTE(review): int(label) assumes a single scalar label; Feeder
            # returns a multi-hot vector -- confirm against the caller.
            if gt[t]:
                text = ax.text2D(
                    0.1, 0.9, idx2act[int(label)], size=20, transform=ax.transAxes
                )

            if preds is not None:
                pred_idx = preds[t].argmax()
                text_pred = ax.text2D(
                    0.6,
                    0.9,
                    idx2act[int(pred_idx)] + f": {preds[t, pred_idx]:.2f}",
                    size=20,
                    transform=ax.transAxes,
                )

            fig.canvas.draw()
            plt.savefig(osp.join(folder_p, "frames", str(t) + ".jpg"), dpi=300)
            if gt[t]:
                text.remove()
            if preds is not None:
                text_pred.remove()

    write_vid_from_imgs(folder_p, idx)
273
+
274
+
275
def write_vid_from_imgs(folder_p, fname, fps=30):
    """Collate frames into a video sequence.

    Args:
        folder_p (str): Frame images are in the path: folder_p/frames/<int>.jpg
        fname: Output video basename (written to folder_p/<fname>.mp4).
        fps (float): Output frame rate.

    Returns:
        None. Output video is stored in the path: folder_p/<fname>.mp4; the
        frames directory is removed afterwards (even on ffmpeg failure).
    """
    vid_p = osp.join(folder_p, f"{fname}.mp4")
    cmd = [
        "ffmpeg",
        "-r",
        str(int(fps)),
        "-i",
        osp.join(folder_p, "frames", "%d.jpg"),
        "-y",
        vid_p,
    ]
    # Fix: use subprocess.DEVNULL instead of opening os.devnull manually --
    # the old FNULL handle was never closed (file-descriptor leak).
    retcode = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    if not 0 == retcode:
        print(
            "*******ValueError(Error {0} executing command: {1}*********".format(
                retcode, " ".join(cmd)
            )
        )
    shutil.rmtree(osp.join(folder_p, "frames"))
304
+
305
+
306
if __name__ == "__main__":
    # Demo: load a processed split and render the first sample.
    import os

    # Route matplotlib's interactive window through X forwarding.
    os.environ["DISPLAY"] = "localhost:10.0"
    data_path = "dataset/processed_data/train_split1.pkl"
    graph = "graph.ntu_rgb_d.Graph"
    dataset = Feeder(data_path)
    test(dataset, idx=0, graph=graph, is_3d=True)
feeders/tools.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn.utils.rnn import pad_sequence
6
+
7
+
8
def downsample(data_numpy, step, random_sample=True):
    """Temporally subsample a (C, T, V, M) array, keeping every *step*-th frame.

    When *random_sample* is True the starting offset is drawn uniformly from
    [0, step); otherwise sampling starts at frame 0.
    """
    offset = np.random.randint(step) if random_sample else 0
    return data_numpy[:, offset::step, :, :]
12
+
13
+
14
def temporal_slice(data_numpy, step):
    """Fold temporal groups of *step* frames into the person axis.

    Input: (C, T, V, M) with T divisible by *step*; output
    (C, T // step, V, step * M).

    Bug fix: ``T / step`` is a float in Python 3 and makes ``reshape`` raise
    TypeError; integer division is required.
    """
    C, T, V, M = data_numpy.shape
    return (
        data_numpy.reshape(C, T // step, step, V, M)
        .transpose((0, 1, 3, 2, 4))
        .reshape(C, T // step, V, step * M)
    )
22
+
23
+
24
def mean_subtractor(data_numpy, mean):
    """Subtract *mean* from all frames up to the last non-zero (valid) frame.

    Input: (C, T, V, M). Mutates *data_numpy* in place and returns it.
    Returns None (no-op) when mean == 0, matching the original guard.
    """
    # naive version
    if mean == 0:
        return
    C, T, V, M = data_numpy.shape
    # Frames with at least one non-zero coordinate count as valid.
    nonzero = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
    start = nonzero.argmax()  # kept for parity with the original (unused)
    stop = len(nonzero) - nonzero[::-1].argmax()
    data_numpy[:, :stop, :, :] = data_numpy[:, :stop, :, :] - mean
    return data_numpy
35
+
36
+
37
def auto_pading(data_numpy, size, random_pad=False):
    """Zero-pad a (C, T, V, M) array along time up to *size* frames.

    Sequences already at least *size* frames long are returned unchanged.
    With *random_pad* the original clip lands at a random temporal offset.
    """
    C, T, V, M = data_numpy.shape
    if T >= size:
        return data_numpy
    start = random.randint(0, size - T) if random_pad else 0
    padded = np.zeros((C, size, V, M))
    padded[:, start : start + T, :, :] = data_numpy
    return padded
46
+
47
+
48
def random_choose(data_numpy, size, auto_pad=True):
    """Crop a random temporal window of length *size* from (C, T, V, M) data.

    Shorter clips are randomly zero-padded when *auto_pad* is set, otherwise
    returned as-is. (Note from the original author: cropping may select a
    zero-padded region, which is not ideal.)
    """
    C, T, V, M = data_numpy.shape
    if T < size:
        if auto_pad:
            return auto_pading(data_numpy, size, random_pad=True)
        return data_numpy
    if T == size:
        return data_numpy
    start = random.randint(0, T - size)
    return data_numpy[:, start : start + size, :, :]
61
+
62
+
63
def random_move(
    data_numpy,
    angle_candidate=[-10.0, -5.0, 0.0, 5.0, 10.0],
    scale_candidate=[0.9, 1.0, 1.1],
    transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2],
    move_time_candidate=[1],
):
    """Apply a smoothly time-varying random 2D rotation/scale/translation.

    Input: (C, T, V, M); channels 0-1 (x, y) are transformed in place and the
    array is returned. Parameters are drawn at a few key frames and linearly
    interpolated in between.

    NOTE(review): the mutable default lists are shared across calls; harmless
    here because they are only read from, but do not mutate them.
    """
    # input: C,T,V,M
    C, T, V, M = data_numpy.shape
    move_time = random.choice(move_time_candidate)
    # Key-frame positions at which fresh random parameters are drawn.
    node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
    node = np.append(node, T)
    num_node = len(node)

    A = np.random.choice(angle_candidate, num_node)  # rotation angles (deg)
    S = np.random.choice(scale_candidate, num_node)  # scale factors
    T_x = np.random.choice(transform_candidate, num_node)  # x translations
    T_y = np.random.choice(transform_candidate, num_node)  # y translations

    a = np.zeros(T)
    s = np.zeros(T)
    t_x = np.zeros(T)
    t_y = np.zeros(T)

    # linspace: interpolate each parameter between consecutive key frames
    for i in range(num_node - 1):
        a[node[i] : node[i + 1]] = (
            np.linspace(A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
        )
        s[node[i] : node[i + 1]] = np.linspace(S[i], S[i + 1], node[i + 1] - node[i])
        t_x[node[i] : node[i + 1]] = np.linspace(
            T_x[i], T_x[i + 1], node[i + 1] - node[i]
        )
        t_y[node[i] : node[i + 1]] = np.linspace(
            T_y[i], T_y[i + 1], node[i + 1] - node[i]
        )

    # Per-frame 2x2 rotation-and-scale matrices (shape 2, 2, T).
    theta = np.array(
        [[np.cos(a) * s, -np.sin(a) * s], [np.sin(a) * s, np.cos(a) * s]]
    )

    # perform transformation: rotate/scale then translate the (x, y) channels
    for i_frame in range(T):
        xy = data_numpy[0:2, i_frame, :, :]
        new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
        new_xy[0] += t_x[i_frame]
        new_xy[1] += t_y[i_frame]  # translation
        data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)

    return data_numpy
113
+
114
+
115
def random_shift(data_numpy):
    """Move the non-zero (valid) frame span to a random temporal offset.

    Input: (C, T, V, M). The valid span is the contiguous range of frames
    containing any non-zero coordinate; it is copied into a fresh zero array
    starting at a random position. Returns the new array.
    """
    C, T, V, M = data_numpy.shape
    data_shift = np.zeros(data_numpy.shape)
    # Frames with at least one non-zero value are considered valid.
    valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
    begin = valid_frame.argmax()
    end = len(valid_frame) - valid_frame[::-1].argmax()

    size = end - begin
    bias = random.randint(0, T - size)
    data_shift[:, bias : bias + size, :, :] = data_numpy[:, begin:end, :, :]

    return data_shift
128
+
129
+
130
def openpose_match(data_numpy):
    """Greedy cross-frame person tracking for OpenPose-style output.

    Re-orders the M person slots so each tracked person keeps a consistent
    slot across frames (matched frame-to-frame by minimal joint displacement,
    highest-confidence person choosing first), then sorts slots by overall
    confidence. Input/output: (3, T, V, M); channel 2 holds per-joint scores.
    """
    C, T, V, M = data_numpy.shape
    assert C == 3
    # Total confidence of each person in each frame (shape: T, M).
    score = data_numpy[2, :, :, :].sum(axis=1)
    # the rank of body confidence in each frame (shape: T-1, M)
    rank = (-score[0 : T - 1]).argsort(axis=1).reshape(T - 1, M)

    # data of frame 1
    xy1 = data_numpy[0:2, 0 : T - 1, :, :].reshape(2, T - 1, V, M, 1)
    # data of frame 2
    xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M)
    # square of distance between frame 1&2 (shape: T-1, M, M)
    distance = ((xy2 - xy1) ** 2).sum(axis=2).sum(axis=0)

    # match pose: higher-confidence persons claim their nearest match first
    forward_map = np.zeros((T, M), dtype=int) - 1
    forward_map[0] = range(M)
    for m in range(M):
        choose = rank == m
        forward = distance[choose].argmin(axis=1)
        for t in range(T - 1):
            # Mark the chosen slot as taken so lower-ranked persons skip it.
            distance[t, :, forward[t]] = np.inf
        forward_map[1:][choose] = forward
    assert np.all(forward_map >= 0)

    # string data: compose per-step matches into absolute slot assignments
    for t in range(T - 1):
        forward_map[t + 1] = forward_map[t + 1][forward_map[t]]

    # generate data re-ordered by the tracked slot assignment
    new_data_numpy = np.zeros(data_numpy.shape)
    for t in range(T):
        new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[t]].transpose(
            1, 2, 0
        )
    data_numpy = new_data_numpy

    # score sort: most confident person occupies slot 0 across the clip
    trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0)
    rank = (-trace_score).argsort()
    data_numpy = data_numpy[:, :, :, rank]

    return data_numpy
173
+
174
+
175
def pad(tensor, padding_value=0):
    """Pad a list of variable-length tensors into one batch-first tensor."""
    padded = pad_sequence(tensor, padding_value=padding_value, batch_first=True)
    return padded
177
+
178
+
179
def collate_with_padding(batch):
    """Collate (data, label, gt, mask) samples, zero-padding along time.

    Each sample's data is (C, T, V, M) numpy; the batch is padded over T and
    returned as [data (N, C, T, V, M), target, gt (N, T), mask (N, T)].

    Bug fix: ``torch.tensor(list_of_tensors)`` raises for non-scalar labels;
    ``torch.stack`` is used instead (consistent with the *_multi collates).
    """
    data = [torch.tensor(item[0].transpose(1, 0, 2, 3)) for item in batch]
    target = [torch.tensor(item[1]) for item in batch]
    gt = [torch.tensor(item[2]) for item in batch]
    mask = [torch.tensor(item[3]) for item in batch]

    # Pad with zeros over the time axis, then restore (N, C, T, V, M).
    data = pad_sequence(data, batch_first=True).transpose(1, 2)
    target = torch.stack(target)
    gt = pad_sequence(gt, batch_first=True)
    mask = pad_sequence(mask, batch_first=True)
    return [data, target, gt, mask]
190
+
191
+
192
def collate_with_padding_multi(batch):
    """Collate (data, target, gt, mask) samples with zero padding over time.

    Returns [data (N, C, T, V, M), target (stacked), gt (N, T), mask (N, T)].
    """
    seqs = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    targets = [torch.tensor(sample[1]) for sample in batch]
    gts = [torch.tensor(sample[2]) for sample in batch]
    masks = [torch.tensor(sample[3]) for sample in batch]

    # pad_sequence inlined from the local `pad` helper (same defaults).
    batched = pad_sequence(seqs, batch_first=True).transpose(1, 2)
    return [
        batched,
        torch.stack(targets),
        pad_sequence(gts, batch_first=True),
        pad_sequence(masks, batch_first=True),
    ]
203
+
204
+
205
def collate_with_padding_multi_velo(batch):
    """Collate (joints, velocity, target, gt, mask) samples with time padding.

    Joints and velocities are both (C, T, V, M) per sample; the output keeps
    them as (N, C, T, V, M).
    """
    seqs = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    velos = [torch.tensor(sample[1].transpose(1, 0, 2, 3)) for sample in batch]
    targets = [torch.tensor(sample[2]) for sample in batch]
    gts = [torch.tensor(sample[3]) for sample in batch]
    masks = [torch.tensor(sample[4]) for sample in batch]

    # pad_sequence inlined from the local `pad` helper (same defaults).
    batched = pad_sequence(seqs, batch_first=True).transpose(1, 2)
    batched_velo = pad_sequence(velos, batch_first=True).transpose(1, 2)
    return [
        batched,
        batched_velo,
        torch.stack(targets),
        pad_sequence(gts, batch_first=True),
        pad_sequence(masks, batch_first=True),
    ]
218
+
219
+
220
def collate_with_padding_multi_joint(batch):
    """Collate the 6-tuple samples produced by Feeder, padding over time.

    Joints and masks are zero-padded; gt is padded with 4 (per the original
    note: with 4 action classes the labels run 0..4, so the pad value stands
    for "background"); soft labels are padded with -100 (ignore value).
    """
    seqs = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    targets = [torch.tensor(sample[1]) for sample in batch]  # video-level label
    gts = [torch.tensor(sample[2]) for sample in batch]  # frame-level label
    masks = [torch.tensor(sample[3]) for sample in batch]
    indices = [torch.tensor(sample[4]) for sample in batch]
    softs = [torch.tensor(sample[5]) for sample in batch]

    # Zero-pad the joint sequences (not last-frame repeat), then restore
    # the (N, C, T, V, M) layout.
    batched = pad_sequence(seqs, batch_first=True).transpose(1, 2)
    return [
        batched,
        torch.stack(targets),
        pad_sequence(gts, batch_first=True, padding_value=4),
        pad_sequence(masks, batch_first=True),
        torch.tensor(indices),
        pad_sequence(softs, batch_first=True, padding_value=-100),
    ]
graph/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import kinetics, ntu_rgb_d, tools,ntu_rgb_d_infogcn, nturgbd_blockgcn
graph/kinetics.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import networkx as nx
4
+ import numpy as np
5
+ from graph import tools
6
+
7
+ sys.path.extend(["../"])
8
+
9
+ # Joint index:
10
+ # {0, "Nose"}
11
+ # {1, "Neck"},
12
+ # {2, "RShoulder"},
13
+ # {3, "RElbow"},
14
+ # {4, "RWrist"},
15
+ # {5, "LShoulder"},
16
+ # {6, "LElbow"},
17
+ # {7, "LWrist"},
18
+ # {8, "RHip"},
19
+ # {9, "RKnee"},
20
+ # {10, "RAnkle"},
21
+ # {11, "LHip"},
22
+ # {12, "LKnee"},
23
+ # {13, "LAnkle"},
24
+ # {14, "REye"},
25
+ # {15, "LEye"},
26
+ # {16, "REar"},
27
+ # {17, "LEar"},
28
+
29
+ # Edge format: (origin, neighbor)
30
# OpenPose 18-joint topology (see joint table above); edges point from the
# extremity joint toward the body center ("inward").
num_node = 18
self_link = [(i, i) for i in range(num_node)]
inward = [
    (4, 3),
    (3, 2),
    (7, 6),
    (6, 5),
    (13, 12),
    (12, 11),
    (10, 9),
    (9, 8),
    (11, 5),
    (8, 2),
    (5, 1),
    (2, 1),
    (0, 1),
    (15, 0),
    (14, 0),
    (17, 15),
    (16, 14),
]
# Reversed edges (center -> extremity) and their union.
outward = [(j, i) for (i, j) in inward]
neighbor = inward + outward
53
+
54
+
55
class Graph:
    """Kinetics (OpenPose, 18-joint) skeleton graph.

    Builds the spatially partitioned adjacency tensor ``A`` of shape
    (3, 18, 18): identity, normalized inward and normalized outward edges.
    """

    def __init__(self, labeling_mode="spatial"):
        # A is computed first, from the module-level edge lists.
        self.A = self.get_adjacency_matrix(labeling_mode)
        self.num_node = num_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor

    def get_adjacency_matrix(self, labeling_mode=None):
        """Return the (3, V, V) adjacency tensor.

        With ``labeling_mode=None`` the cached ``self.A`` is returned
        (only valid after ``__init__`` has completed).
        """
        if labeling_mode is None:
            return self.A
        if labeling_mode == "spatial":
            A = tools.get_spatial_graph(num_node, self_link, inward, outward)
        else:
            raise ValueError()
        return A
72
+
73
+
74
if __name__ == "__main__":
    # Smoke test: build the spatially partitioned adjacency tensor.
    A = Graph("spatial").get_adjacency_matrix()
    print("")
graph/ntu_rgb_d.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from graph import tools
4
+
5
+ sys.path.extend(["../"])
6
+
7
# NTU RGB+D 25-joint skeleton: bone list in 1-based (child, parent) indexing,
# converted below to 0-based "inward" (toward spine) edges.
num_node = 25
self_link = [(i, i) for i in range(num_node)]
inward_ori_index = [
    (1, 2),
    (2, 21),
    (3, 21),
    (4, 3),
    (5, 21),
    (6, 5),
    (7, 6),
    (8, 7),
    (9, 21),
    (10, 9),
    (11, 10),
    (12, 11),
    (13, 1),
    (14, 13),
    (15, 14),
    (16, 15),
    (17, 1),
    (18, 17),
    (19, 18),
    (20, 19),
    (22, 23),
    (23, 8),
    (24, 25),
    (25, 12),
]
# Shift to 0-based joint indices.
inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
outward = [(j, i) for (i, j) in inward]
neighbor = inward + outward
38
+
39
+
40
class Graph:
    """NTU RGB+D (25-joint) skeleton graph.

    Builds the spatially partitioned adjacency tensor ``A`` of shape
    (3, 25, 25): identity, normalized inward and normalized outward edges.
    """

    def __init__(self, labeling_mode="spatial"):
        # A is computed first, from the module-level edge lists.
        self.A = self.get_adjacency_matrix(labeling_mode)
        self.num_node = num_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor

    def get_adjacency_matrix(self, labeling_mode=None):
        """Return the (3, V, V) adjacency tensor.

        With ``labeling_mode=None`` the cached ``self.A`` is returned
        (only valid after ``__init__`` has completed).
        """
        if labeling_mode is None:
            return self.A
        if labeling_mode == "spatial":
            A = tools.get_spatial_graph(num_node, self_link, inward, outward)
        else:
            raise ValueError()
        return A
57
+
58
+
59
if __name__ == "__main__":
    # Visual smoke test: show each of the three adjacency sub-matrices
    # (identity, inward, outward) as grayscale images.
    import os

    import matplotlib.pyplot as plt

    # os.environ['DISPLAY'] = 'localhost:11.0'
    A = Graph("spatial").get_adjacency_matrix()
    for i in A:
        plt.imshow(i, cmap="gray")
        plt.show()
    print(A)
graph/tools.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def edge2mat(link, num_node):
    """Directed edge list -> (V, V) adjacency with A[j, i] = 1 for edge (i, j)."""
    A = np.zeros((num_node, num_node))
    for i, j in link:
        A[j, i] = 1
    return A


def normalize_digraph(A):
    """Column-normalize A: divide each column by its sum (zero columns stay 0)."""
    Dl = np.sum(A, 0)
    h, w = A.shape
    Dn = np.zeros((w, w))
    for i in range(w):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    return np.dot(A, Dn)


def get_spatial_graph(num_node, self_link, inward, outward):
    """Stack (identity, normalized inward, normalized outward) -> (3, V, V)."""
    I = edge2mat(self_link, num_node)
    In = normalize_digraph(edge2mat(inward, num_node))
    Out = normalize_digraph(edge2mat(outward, num_node))
    return np.stack((I, In, Out))


# NOTE: the original file defined edge2mat, normalize_digraph and
# get_spatial_graph twice (with an extra `import numpy as np` in between);
# the duplicates were byte-identical, so deduplicating preserves behavior.


def get_sgp_mat(num_in, num_out, link):
    """Build a column-normalized (num_in, num_out) pooling/selection matrix."""
    A = np.zeros((num_in, num_out))
    for i, j in link:
        A[i, j] = 1
    return A / np.sum(A, axis=0, keepdims=True)


def get_k_scale_graph(scale, A):
    """Binary reachability matrix within *scale* hops of A (scale=1 -> A)."""
    if scale == 1:
        return A
    An = np.zeros_like(A)
    A_power = np.eye(A.shape[0])
    for _ in range(scale):
        A_power = A_power @ A
        An += A_power
    An[An > 0] = 1
    return An


def normalize_adjacency_matrix(A):
    """Symmetric normalization D^(-1/2) A D^(-1/2), returned as float32."""
    node_degrees = A.sum(-1)
    degs_inv_sqrt = np.power(node_degrees, -0.5)
    norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt
    return (norm_degs_matrix @ A @ norm_degs_matrix).astype(np.float32)


def k_adjacency(A, k, with_self=False, self_factor=1):
    """Exact k-hop adjacency: nodes reachable in k hops but not in k-1.

    With *with_self*, ``self_factor * I`` is added to the result.
    """
    assert isinstance(A, np.ndarray)
    I = np.eye(len(A), dtype=A.dtype)
    if k == 0:
        return I
    Ak = np.minimum(np.linalg.matrix_power(A + I, k), 1) - np.minimum(
        np.linalg.matrix_power(A + I, k - 1), 1
    )
    if with_self:
        Ak += self_factor * I
    return Ak


def get_multiscale_spatial_graph(num_node, self_link, inward, outward):
    """Stack (I, inward, outward, 2-hop inward, 2-hop outward) -> (5, V, V)."""
    I = edge2mat(self_link, num_node)
    A1 = edge2mat(inward, num_node)
    A2 = edge2mat(outward, num_node)
    A3 = k_adjacency(A1, 2)
    A4 = k_adjacency(A2, 2)
    A1 = normalize_digraph(A1)
    A2 = normalize_digraph(A2)
    A3 = normalize_digraph(A3)
    A4 = normalize_digraph(A4)
    return np.stack((I, A1, A2, A3, A4))


def get_adjacency_matrix(edges, num_nodes):
    """Edge list (row, col) pairs -> float32 (V, V) adjacency with A[edge] = 1."""
    A = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for edge in edges:
        A[edge] = 1.0
    return A


def get_uniform_graph(num_node, self_link, neighbor):
    """Single normalized adjacency over neighbor + self-link edges."""
    return normalize_digraph(edge2mat(neighbor + self_link, num_node))
huggingface.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from huggingface_hub import upload_folder

# One-shot script: push the local working directory to the Hugging Face Hub.
# NOTE(review): this runs at import time -- guard with
# `if __name__ == "__main__":` if the module is ever imported elsewhere.
upload_folder(
    folder_path="/root/autodl-tmp/workshop2",
    repo_id="qiushuocheng/workshop",
    repo_type="model",  # or "dataset", "space"
    path_in_repo="",  # optional subdirectory in repo
    commit_message="Initial upload"
)
human_model/Put SMPLH model here.txt ADDED
File without changes
model/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from . import agcn
2
+ from . import agcn_Unet
3
+ from . import agcn_concat
4
+ from . import agcn_MSFF
5
+ from . import blockgcn_MSFF
6
+
7
+ from . import blockgcn
model/agcn.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ import math
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from einops import rearrange, reduce, repeat
21
+ from torch.autograd import Variable
22
+
23
+
24
def import_class(name):
    """Import and return the object addressed by a dotted-path string."""
    head, *rest = name.split(".")
    obj = __import__(head)
    for piece in rest:
        obj = getattr(obj, piece)
    return obj
30
+
31
+
32
def conv_branch_init(conv, branches):
    """Branch-aware conv init: normal weights scaled by fan-in and the number
    of parallel branches; bias set to zero."""
    w = conv.weight
    out_ch = w.size(0)
    in_ch = w.size(1)
    k = w.size(2)
    std = math.sqrt(2.0 / (out_ch * in_ch * k * branches))
    nn.init.normal_(w, 0, std)
    nn.init.constant_(conv.bias, 0)
39
+
40
+
41
def conv_init(conv):
    """He (fan-out) initialisation for conv weights; bias starts at zero."""
    nn.init.constant_(conv.bias, 0)
    nn.init.kaiming_normal_(conv.weight, mode="fan_out")
44
+
45
+
46
def bn_init(bn, scale):
    """Set BatchNorm affine parameters: weight = *scale*, bias = 0."""
    nn.init.constant_(bn.bias, 0)
    nn.init.constant_(bn.weight, scale)
49
+
50
+
51
class unit_tcn(nn.Module):
    """Temporal convolution unit: a (kernel_size, 1) conv over the time axis
    followed by BatchNorm. No activation is applied in ``forward``; the
    ``relu`` module is kept only for interface/state-dict parity."""

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = (kernel_size - 1) // 2  # "same"-style temporal padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        return self.bn(self.conv(x))
71
+
72
+
73
class unit_gcn(nn.Module):
    """Adaptive graph convolution unit (2s-AGCN style).

    Per subset, combines three adjacency terms: the fixed skeleton graph
    ``A``, a learned global offset ``PA``, and a data-dependent attention map
    computed from embedded features (conv_a / conv_b).
    """

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        # Learned additive adjacency, initialised near zero so training starts
        # from the fixed skeleton graph.
        self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        nn.init.constant_(self.PA, 1e-6)
        # Fixed (non-trainable) skeleton adjacency.
        self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
        self.num_subset = num_subset

        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        # 1x1 projection on the residual path when channel counts differ.
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1), nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # Output BN starts near zero so the block initially passes the residual.
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        N, C, T, V = x.size()
        A = self.A
        if -1 != x.get_device():
            # Move the fixed graph to the input's GPU when running on CUDA.
            A = A.cuda(x.get_device())
        A = A + self.PA

        y = None
        for i in range(self.num_subset):
            # Data-dependent attention: embed features, then compute a
            # joint-by-joint affinity (softmax over the second-to-last dim).
            A1 = (
                self.conv_a[i](x)
                .permute(0, 3, 1, 2)
                .contiguous()
                .view(N, V, self.inter_c * T)
            )
            A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
            A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))  # N V V
            A1 = A1 + A[i]
            A2 = x.view(N, C * T, V)
            z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
            y = z + y if y is not None else z

        y = self.bn(y)
        y += self.down(x)
        return self.relu(y)
136
+
137
+
138
class TCN_GCN_unit(nn.Module):
    """One backbone stage: spatial graph conv followed by temporal conv,
    with a residual connection and a final ReLU."""

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A)
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU()
        if not residual:
            # No skip connection at all.
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            # Shapes already match: identity skip.
            self.residual = lambda x: x
        else:
            # Project the input so the skip matches the output shape.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride
            )

    def forward(self, x):
        out = self.tcn1(self.gcn1(x))
        return self.relu(out + self.residual(x))
158
+
159
+
160
class Classifier(nn.Module):
    """Attention-based multiple-instance-learning head over frame scores.

    Holds `num_class + 1` learned 256-d class centers (the extra slot acts
    as a background/foreground class). Frame scores are cosine similarities
    to the centers scaled by `scale_factor`; video-level scores aggregate
    frames with class-wise softmax attention at several temperatures.
    """

    def __init__(self, num_class=60, scale_factor=5.0, temperature=[1.0, 2.0, 5.0]):
        super(Classifier, self).__init__()

        # action features: one center per class (+1 background slot)
        self.ac_center = nn.Parameter(torch.zeros(num_class + 1, 256))
        nn.init.xavier_uniform_(self.ac_center)
        # foreground feature

        self.temperature = temperature
        self.scale_factor = scale_factor

    def forward(self, x):
        """x: (n*m, c, t, v) backbone features.

        Returns:
            mil_vid_pred: (n, num_class + 1) video-level probabilities.
            frm_scrs:     (n, t, num_class + 1) per-frame class scores.
        """
        N = x.size(0)

        # Average over persons and vertices -> (n, t, c) frame embeddings.
        # NOTE(review): `n=N` with N = x.size(0) implies m == 1 here.
        x_emb = reduce(x, "(n m) c t v -> n t c", "mean", n=N)

        norms_emb = F.normalize(x_emb, dim=2)
        norms_ac = F.normalize(self.ac_center)

        # generate foreground and action scores (scaled cosine similarity)
        frm_scrs = (
            torch.einsum("ntd,cd->ntc", [norms_emb, norms_ac]) * self.scale_factor
        )

        # class-wise temporal attention at several temperatures
        class_wise_atts = [F.softmax(frm_scrs * t, 1) for t in self.temperature]

        # multiple instance learning branch
        # temporal score aggregation
        mid_vid_scrs = [
            torch.einsum("ntc,ntc->nc", [frm_scrs, att]) for att in class_wise_atts
        ]
        mil_vid_scr = (
            torch.stack(mid_vid_scrs, -1).mean(-1) * 2.0
        )  # frm_scrs have been multiplied by the scale factor
        # torch.sigmoid replaces the deprecated F.sigmoid (same behavior).
        mil_vid_pred = torch.sigmoid(mil_vid_scr)

        return mil_vid_pred, frm_scrs
200
+
201
+
202
class Model(nn.Module):
    """Skeleton backbone (ten stacked TCN-GCN units) with two MIL heads.

    The backbone downsamples time by 8x (strides at l2, l5, l8). The second
    head receives detached features, so only head 1 trains the backbone
    (co-training setup). Frame scores from both heads are linearly
    interpolated back to the input temporal length T.
    """

    def __init__(
        self,
        num_class=60,
        num_point=25,
        num_person=1,
        graph=None,
        graph_args=dict(),
        in_channels=2,
        scale_factor=5.0,
        temperature=[1.0, 2.0, 5.0],
    ):
        super(Model, self).__init__()

        # A graph class path (resolved via import_class) must be supplied.
        if graph is None:
            raise ValueError()
        else:
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A
        # NOTE(review): data_bn is constructed and initialized but its call
        # in forward is commented out — presumably kept for checkpoint
        # compatibility; confirm.
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        # NOTE(review): the first unit hard-codes 3 input channels instead
        # of using `in_channels` (default 2) — confirm the expected input.
        self.l1 = TCN_GCN_unit(3, 64, A, residual=False)  # save (B,64,25,T)
        self.l2 = TCN_GCN_unit(64, 64, A, stride=2)
        self.l3 = TCN_GCN_unit(64, 64, A)
        self.l4 = TCN_GCN_unit(64, 64, A)  # save (B,64,25,T/2)
        self.l5 = TCN_GCN_unit(64, 128, A, stride=2)
        self.l6 = TCN_GCN_unit(128, 128, A)
        self.l7 = TCN_GCN_unit(128, 128, A)  # save (B,128,25,T/4)
        self.l8 = TCN_GCN_unit(128, 256, A, stride=2)
        self.l9 = TCN_GCN_unit(256, 256, A)
        self.l10 = TCN_GCN_unit(256, 256, A)  # save (B,256,25,T/8)

        bn_init(self.data_bn, 1)

        # Two classification heads; head 2 gets detached features in forward.
        self.classifier_1 = Classifier(num_class, scale_factor, temperature)

        self.classifier_2 = Classifier(num_class, scale_factor, temperature)

    def forward(self, x,mask):
        # x: (N, C, T, V, M). NOTE(review): `mask` is accepted but unused.
        N, C, T, V, M = x.size()

        # Fold persons/vertices/channels together, then split back so the
        # backbone sees one (n*m, c, t, v) tensor.
        x = rearrange(x, "n c t v m -> n (m v c) t")
        # x = self.data_bn(x)
        x = rearrange(x, "n (m v c) t -> (n m) c t v", m=M, v=V, c=C)

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        mil_vid_pred_1, frm_scrs_1 = self.classifier_1(x)

        # Detach: head 2 does not backpropagate into the backbone.
        mil_vid_pred_2, frm_scrs_2 = self.classifier_2(x.detach())

        # print (frm_scrs_1.size(), T)

        # Upsample frame scores from T/8 back to T along the time axis.
        frm_scrs_1 = rearrange(frm_scrs_1, "n t c -> n c t")
        frm_scrs_1 = F.interpolate(
            frm_scrs_1, size=(T), mode="linear", align_corners=True
        )
        frm_scrs_1 = rearrange(frm_scrs_1, "n c t -> n t c")

        frm_scrs_2 = rearrange(frm_scrs_2, "n t c -> n c t")
        frm_scrs_2 = F.interpolate(
            frm_scrs_2, size=(T), mode="linear", align_corners=True
        )
        frm_scrs_2 = rearrange(frm_scrs_2, "n c t -> n t c")

        return mil_vid_pred_1, frm_scrs_1, mil_vid_pred_2, frm_scrs_2
model/losses.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ from platform import mac_ver
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from einops import rearrange, reduce, repeat
21
+ from torch.autograd import Variable
22
+
23
+
24
def kl_loss_compute(pred, soft_targets, reduce=True):
    """Per-sample KL divergence KL(softmax(soft_targets) || softmax(pred)).

    Args:
        pred:         (n, c) raw logits of the predicted distribution.
        soft_targets: (n, c) raw logits of the target distribution.
        reduce:       if True, return the scalar mean over samples;
                      otherwise the per-sample (n,) divergences.
    """
    # reduction="none" keeps the per-element terms; it replaces the
    # deprecated `reduce=False` keyword with identical behavior.
    kl = F.kl_div(
        F.log_softmax(pred, dim=1), F.softmax(soft_targets, dim=1), reduction="none"
    )
    if reduce:
        return torch.mean(torch.sum(kl, dim=1))
    else:
        return torch.sum(kl, 1)
33
+
34
+
35
def mvl_loss(y_1, y_2, rate=0.2, weight=0.1):
    """Small-loss selection over the symmetric KL disagreement of two heads.

    Args:
        y_1, y_2: (n, t, c) frame scores from the two classifiers.
        rate:     fraction of lowest-disagreement frames kept.
        weight:   scale applied to each KL direction.

    Returns:
        Scalar mean loss over the selected ("small-loss") frames.
    """
    flat_1 = rearrange(y_1, "n t c -> (n t) c")
    flat_2 = rearrange(y_2, "n t c -> (n t) c")

    # Symmetric per-frame KL divergence between the two predictions.
    per_frame = weight * kl_loss_compute(
        flat_1, flat_2, reduce=False
    ) + weight * kl_loss_compute(flat_2, flat_1, reduce=False)

    # Selection happens outside the graph, on CPU.
    per_frame = per_frame.cpu().detach()

    # Rank frames by disagreement, keep the `rate` fraction with the
    # smallest loss (co-teaching "small-loss trick").
    order = torch.argsort(per_frame.data)
    ranked = per_frame[order]
    keep = order[: int(rate * len(ranked))]

    return torch.mean(per_frame[keep])
55
+
56
+
57
def cross_entropy_loss(outputs, soft_targets):
    """Soft-target cross entropy, skipping rows whose targets are all -100.

    A row of `soft_targets` with no entry different from -100 marks an
    unlabeled sample; such rows are excluded before averaging.
    """
    labeled = (soft_targets != -100).sum(1) > 0
    logp = F.log_softmax(outputs[labeled], dim=1)
    per_sample = torch.sum(logp * soft_targets[labeled], dim=1)
    return -torch.mean(per_sample)
63
+
prepare/configs/action_label_split1.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "walk": 0,
3
+ "stand": 1,
4
+ "turn": 2,
5
+ "jump": 3
6
+ }
prepare/configs/action_label_split2.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "sit": 0,
3
+ "run": 1,
4
+ "stand up": 2,
5
+ "kick": 3
6
+ }
prepare/configs/action_label_split3.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "jog": 0,
3
+ "wave": 1,
4
+ "dance": 2,
5
+ "gesture": 3
6
+ }
prepare/create_dataset.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/abhinanda-punnakkal/BABEL/ to frame-wise motion segmentation
2
+ #! /usr/bin/env python
3
+ # -*- coding: utf-8 -*-
4
+ # vim:fenc=utf-8
5
+ #
6
+ # Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+
11
+ import csv
12
+ import json
13
+ import os
14
+ import pdb
15
+ import pickle
16
+ import sys
17
+ from collections import *
18
+ from itertools import *
19
+ from os.path import basename as ospb
20
+ from os.path import dirname as ospd
21
+ from os.path import join as ospj
22
+
23
+ import dutils
24
+ import ipdb
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ # Custom
29
+ import preprocess
30
+ import torch
31
+ import viz
32
+ from pandas.core.common import flatten
33
+ from tqdm import tqdm
34
+
35
+ """
36
+ Script to load BABEL segments with NTU skeleton format and pre-process.
37
+ """
38
+
39
+
40
def ntu_style_preprocessing(b_dset_path):
    """Normalize BABEL samples with the NTU-RGBD preprocessing pipeline.

    Loads the pickled subset at `b_dset_path`, normalizes every sequence
    (center on the spine joint, align the spine bone to z and the shoulder
    bone to x), drops sequences with missing skeletons, and writes the
    result next to the input ("samples" -> "ntu_sk_ntu-style_preprocessed").

    NOTE(review): this function returns None; the call site assigns its
    result — confirm whether a return value was intended.
    """
    print("Load BABEL v1.0 dataset subset", b_dset_path)
    b_dset = dutils.read_pkl(b_dset_path)

    X_new = []
    Y_new = []
    id_new = []
    for idx in range(len(b_dset["X"])):
        # Get unnormalized 5-sec. samples
        X = np.array(b_dset["X"][idx])
        print("X (old) = ", np.shape(X))  # T, V, C

        X = X[np.newaxis, :, :, :]

        # Prep. data for normalization
        X = X.transpose(0, 3, 1, 2)  # N, C, T, V
        X = X[:, :, :, :, np.newaxis]  # N, C, T, V, M
        print("Shape of prepped X: ", X.shape)

        # Normalize (pre-process) in NTU RGBD-style.
        # Bones defining the canonical pose: z = spine, x = shoulders.
        ntu_sk_spine_bone = np.array([0, 1])
        ntu_sk_shoulder_bone = np.array([8, 4])
        X, l_m_sk = preprocess.pre_normalization(
            X, zaxis=ntu_sk_spine_bone, xaxis=ntu_sk_shoulder_bone
        )

        # Keep only sequences where no skeleton was missing.
        if len(l_m_sk) == 0:
            X_new.append(X[0])
            Y_new.append(b_dset["Y"][idx])
            id_new.append(b_dset["sid"][idx])
        else:
            print("Skipped")

    # Dataset w/ processed seg. chunks. (Skip samples w/ missing skeletons)
    b_AR_dset = {"sid": id_new, "X": X_new, "Y": Y_new}

    fp = b_dset_path.replace("samples", "ntu_sk_ntu-style_preprocessed")
    # fp = '../data/babel_v1.0/babel_v1.0_ntu_sk_ntu-style_preprocessed.pkl'
    # dutils.write_pkl(b_AR_dset, fp)
    # protocol=4 supports pickles larger than 4 GiB.
    with open(fp, "wb") as of:
        pickle.dump(b_AR_dset, of, protocol=4)
82
+
83
+
84
def get_act_idx(y, act2idx, n_classes):
    """Map action name `y` to its class index; unknown actions map to
    `n_classes` (the catch-all "other" slot)."""
    return act2idx[y] if y in act2idx else n_classes
90
+
91
+
92
def store_splits_subsets(
    n_classes, spl, plus_extra=True, w_folder="../data_created/babel_v1.0/"
):
    """Filter the preprocessed BABEL data to the first `n_classes` action
    categories and the AMASS split `spl`, then write features (.npy) and
    labels (.pkl) in the layout the dataloader expects.

    Args:
        n_classes:  number of action categories kept (indices < n_classes).
        spl:        split name, e.g. "train" or "val".
        plus_extra: load the files that include the "extra" annotations.
        w_folder:   output folder for the filtered features/labels.
    """
    # Get splits: map each sequence id (parsed from the render URL) to its
    # split name. NOTE(review): the comprehension variable `spl` shadows the
    # parameter inside the comprehension only (Python 3 scoping).
    splits = dutils.read_json("../data_created/amass_splits.json")
    sid2split = {
        int(ospb(u).replace(".mp4", "")): spl for spl in splits for u in splits[spl]
    }

    # In labels, act. cat. --> idx; keep only the first n_classes categories.
    act2idx_150 = dutils.read_json("../data_created/action_label_2_idx.json")
    act2idx = {k: act2idx_150[k] for k in act2idx_150 if act2idx_150[k] < n_classes}
    print("{0} actions in label set: {1}".format(len(act2idx), act2idx))

    if plus_extra:
        fp = w_folder + "babel_v1.0_" + spl + "_extra_ntu_sk_ntu-style_preprocessed.pkl"
    else:
        fp = w_folder + "babel_v1.0_" + spl + "_ntu_sk_ntu-style_preprocessed.pkl"

    # Get full dataset
    b_AR_dset = dutils.read_pkl(fp)

    # Store idxs of samples to include in learning
    split_idxs = defaultdict(list)
    for i, y1 in enumerate(b_AR_dset["Y1"]):

        # Check if action category in list of classes
        if y1 not in act2idx:
            continue

        sid = b_AR_dset["sid"][i]
        split_idxs[sid2split[sid]].append(i)  # Include idx in dataset

    # Save features that'll be loaded by dataloader
    ar_idxs = np.array(split_idxs[spl])
    X = b_AR_dset["X"][ar_idxs]
    if plus_extra:
        fn = w_folder + f"{spl}_extra_ntu_sk_{n_classes}.npy"
    else:
        fn = w_folder + f"{spl}_ntu_sk_{n_classes}.npy"
    np.save(fn, X)

    # Labels: every key except the features, restricted to the kept indices.
    labels = {k: np.array(b_AR_dset[k])[ar_idxs] for k in b_AR_dset if k != "X"}

    # Create, save label data structure that'll be loaded by dataloader
    label_idxs = defaultdict(list)
    for i, y1 in enumerate(labels["Y1"]):
        # y1: single representative action index
        label_idxs["Y1"].append(act2idx[y1])
        # yk: all actions covering the whole segment (unknown -> n_classes)
        yk = [get_act_idx(y, act2idx, n_classes) for y in labels["Yk"][i]]
        label_idxs["Yk"].append(yk)
        # yov: fraction of the segment covered by each action
        yov_o = labels["Yov"][i]
        yov = {get_act_idx(y, act2idx, n_classes): yov_o[y] for y in yov_o}
        label_idxs["Yov"].append(yov)
        # Bookkeeping ids carried through unchanged.
        label_idxs["seg_id"].append(labels["seg_id"][i])
        label_idxs["sid"].append(labels["sid"][i])
        label_idxs["chunk_n"].append(labels["chunk_n"][i])
        label_idxs["anntr_id"].append(labels["anntr_id"][i])

    if plus_extra:
        wr_f = w_folder + f"{spl}_extra_label_{n_classes}.pkl"
    else:
        wr_f = w_folder + f"{spl}_label_{n_classes}.pkl"
    dutils.write_pkl(
        (
            label_idxs["seg_id"],
            (
                label_idxs["Y1"],
                label_idxs["sid"],
                label_idxs["chunk_n"],
                label_idxs["anntr_id"],
            ),
        ),
        wr_f,
    )
171
+ )
172
+
173
+
174
class Babel_AR:
    """Object containing data, methods for Action Recognition.

    Task
    -----
    Given: x (Segment from Babel)
    Predict: \hat{p}(x) (Distribution over action categories)

    GT
    ---
    How to compute GT for a given segment?
    - yk: All action categories that are labeled for the entirety of segment
    - y1: One of yk
    - yov: Any y that belongs to part of a segment is considered to be GT.
      Fraction of segment covered by an action: {'walk': 1.0, 'wave': 0.5}
    """

    def __init__(self, dataset, dense=True):
        """Build the frame-wise dataset `self.d` from BABEL annotations.

        Args:
            dataset: list of BABEL annotation dicts (one per sequence).
            dense:   only dense (frame-level) annotations are supported.
        """
        # Load dataset
        self.babel = dataset
        self.dense = dense
        # Root folder containing the precomputed joint positions.
        self.jpos_p = "dataset/amass"

        # Get frame-rate for each seq. in AMASS
        f_p = "dataset/BABEL/action_recognition/data/featp_2_fps.json"
        self.ft_p_2_fps = dutils.read_json(f_p)

        # Dataset w/ keys = {'X', 'Y', 'sid'}; lists grow per processed seq.
        self.d = defaultdict(list)
        for ann in tqdm(self.babel):
            self._update_dataset(ann)

    def _subsample_to_30fps(self, orig_ft, orig_fps):
        """Get features at 30fps frame-rate
        Args:
            orig_ft <array> (T, 25, 3): Feats. @ `orig_fps` frame-rate
            orig_fps <float>: Frame-rate in original (ft) seq.
        Return:
            ft <array> (T', 25, 3): Feats. @ 30fps
        """
        T, n_j, _ = orig_ft.shape
        out_fps = 30.0
        # Matching the sub-sampling used for rendering: when the original
        # rate is not a multiple of 30, pick 30 evenly spaced frames per
        # second of footage.
        if int(orig_fps) % int(out_fps):
            sel_fr = np.floor(orig_fps / out_fps * np.arange(int(out_fps))).astype(int)
            n_duration = int(T / int(orig_fps))
            t_idxs = []
            for i in range(n_duration):
                t_idxs += list(i * int(orig_fps) + sel_fr)
            # Handle the trailing partial second, keeping indices in range.
            if int(T % int(orig_fps)):
                last_sec_frame_idx = n_duration * int(orig_fps)
                t_idxs += [
                    x + last_sec_frame_idx for x in sel_fr if x + last_sec_frame_idx < T
                ]
        else:
            # Integer ratio: take every (orig_fps/30)-th frame.
            t_idxs = np.arange(0, T, orig_fps / out_fps, dtype=int)

        ft = orig_ft[t_idxs, :, :]
        return ft

    def _viz_x(self, ft, fn="test_sample"):
        """Wraper to Viz. the given sample (w/ NTU RGBD skeleton)"""
        viz.viz_seq(seq=ft, folder_p=f"test_viz/{fn}", sk_type="nturgbd", debug=True)
        return None

    def _load_seq_feats(self, ft_p, sk_type):
        """Given path to joint position features, return them in 30fps"""
        # Identify appropriate feature directory path on disk
        if "smpl_wo_hands" == sk_type:  # SMPL w/o hands (T, 22*3)
            jpos_p = ospj(self.jpos_p, "joint_pos")
        if "nturgbd" == sk_type:  # NTU (T, 219)
            jpos_p = ospj(self.jpos_p, "babel_joint_pos")

        # Get the correct dataset folder name (some AMASS subsets were
        # renamed between releases).
        ddir_n = ospb(ospd(ospd(ft_p)))
        ddir_map = {"BioMotionLab_NTroje": "BMLrub", "DFaust_67": "DFaust"}
        ddir_n = ddir_map[ddir_n] if ddir_n in ddir_map else ddir_n
        # Get the subject folder name
        sub_fol_n = ospb(ospd(ft_p))

        # Sanity check
        fft_p = ospj(jpos_p, ddir_n, sub_fol_n, ospb(ft_p))
        assert os.path.exists(fft_p)

        # Load seq. fts.
        ft = np.load(fft_p)["joint_pos"]
        T, ft_sz = ft.shape

        # Select the 25 NTU skeleton joints from the SMPL-H joint superset.
        ntu_js = dutils.smpl_to_nturgbd(model_type="smplh", out_format="nturgbd")
        ft = ft.reshape(T, -1, 3)
        ft = ft[:, ntu_js, :]

        # Sub-sample to 30fps
        orig_fps = self.ft_p_2_fps[ft_p]
        ft = self._subsample_to_30fps(ft, orig_fps)
        # print(f'Feat. shape = {ft.shape}, fps = {orig_fps}')
        # if orig_fps != 30.0:
        #     self._viz_x(ft)
        return ft

    def _get_per_f_labels(self, ann, ann_type, seq_dur):
        """Per-frame labels at 30fps: {0: ['walk'], 1: ['walk', 'wave'], ...}.

        NOTE(review): frames covered by no segment never get a key, so
        len(yf) can be smaller than the frame count — confirm this is the
        intended truncation behavior in _compute_dur_samples.
        """
        yf = defaultdict(list)
        T = int(30.0 * seq_dur)
        for n_f in range(T):
            cur_t = float(n_f / 30.0)
            for seg in ann["labels"]:

                if seg["act_cat"] is None:
                    continue

                # Sequence-level annotations cover the whole duration.
                if "seq_ann" == ann_type:
                    seg["start_t"] = 0.0
                    seg["end_t"] = seq_dur

                if cur_t >= float(seg["start_t"]) and cur_t < float(seg["end_t"]):
                    yf[n_f] += seg["act_cat"]
        return yf

    def _compute_dur_samples(self, id, ann, ann_type, seq_ft, seq_dur, dur=5.0):
        """Return each motion and its frame-wise GT action

        Return:
            [ { 'seg_id': motion id,
                'x': motion feats,
                'y': per-frame labels },
              ... ]
        """
        yf = self._get_per_f_labels(ann, ann_type, seq_dur)

        # Truncate features to the labeled frame count.
        seq_ft = seq_ft[:len(yf)]
        assert seq_ft.shape[0] == len(yf)

        seq_samples = []
        seq_samples.append(
            {"seg_id": id, "x": seq_ft, "y": yf,}
        )

        return seq_samples

    def _sample_at_seg_chunk_level(self, ann, seq_samples):
        # Append every sample of this sequence to the aggregate dataset.
        for i, sample in enumerate(seq_samples):

            self.d["sid"].append(ann["babel_sid"])  # Seq. info
            self.d["X"].append(sample["x"])  # motion feats.
            self.d["Y"].append(sample["y"])  # labels of each motion.
        return

    def _update_dataset(self, ann):
        """Process one annotation: load features, derive frame labels, store."""

        # Get feats. for seq.
        seq_ft = self._load_seq_feats(ann["feat_p"], "nturgbd")

        # Only dense (frame-level) annotations are handled.
        seq_samples = None
        if self.dense:
            if ann["frame_ann"] is not None:
                ann_ar = ann["frame_ann"]
                seq_samples = self._compute_dur_samples(
                    ann["babel_sid"], ann_ar, "frame_ann", seq_ft, ann["dur"]
                )
                self._sample_at_seg_chunk_level(ann, seq_samples)
            else:
                print("not supported data")

        else:
            raise NotImplementedError

        return
352
+
353
+
354
# Create dataset
# --------------------------
# Runs at import time: builds the frame-wise BABEL dataset for the train
# and val splits, then applies NTU-style preprocessing.
d_folder = "dataset/babel_v1.0_release/"
w_folder = "dataset/babel_v1.0_sequence/"
os.makedirs(w_folder, exist_ok=True)
for spl in ["train", "val"]:
    # Load Dense BABEL annotations for this split.
    data = dutils.read_json(ospj(d_folder, f"{spl}.json"))
    dataset = [data[sid] for sid in data]
    dense_babel = Babel_AR(dataset, dense=True)
    # Store Dense BABEL
    d_filename = w_folder + "babel_v1.0_" + spl + "_samples.pkl"
    dutils.write_pkl(dense_babel.d, d_filename)

    # Pre-process, Store data in dataset
    print("NTU-style preprocessing")
    # NOTE(review): ntu_style_preprocessing writes its output to disk and
    # returns None, so this assignment is always None — confirm intent.
    babel_dataset_AR = ntu_style_preprocessing(d_filename)
prepare/dutils.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # vim:fenc=utf-8
4
+ #
5
+ # Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
6
+ #
7
+ # Distributed under terms of the MIT license.
8
+
9
+ import csv
10
+ import json
11
+ import os
12
+ import os.path as osp
13
+ import pdb
14
+ import pickle
15
+ import sys
16
+ from collections import Counter
17
+ from os.path import basename as ospb
18
+ from os.path import dirname as ospd
19
+ from os.path import join as ospj
20
+
21
+ import numpy as np
22
+ import torch
23
+ import viz
24
+ from smplx import SMPLH
25
+ from tqdm import tqdm
26
+
27
+
28
def read_json(json_filename):
    """Load and return the parsed contents of a JSON file."""
    with open(json_filename) as infile:
        return json.load(infile)
34
+
35
+
36
def read_pkl(pkl_filename):
    """Load and return the unpickled contents of a pickle file."""
    with open(pkl_filename, "rb") as infile:
        return pickle.load(infile)
42
+
43
+
44
def write_json(contents, filename):
    """Serialize `contents` to `filename` as indented JSON."""
    with open(filename, "w") as outfile:
        json.dump(contents, outfile, indent=2)
47
+
48
+
49
def write_pkl(contents, filename):
    """Pickle `contents` to `filename` (default protocol)."""
    with open(filename, "wb") as outfile:
        pickle.dump(contents, outfile)
52
+
53
+
54
def smpl_to_nturgbd(model_type="smplh", out_format="nturgbd"):
    """Borrowed from https://gitlab.tuebingen.mpg.de/apunnakkal/2s_agcn/-/blob/master/data_gen/smpl_data_utils.py
    NTU mapping
    -----------
    0 --> ?
    1-base of the spine
    2-middle of the spine
    3-neck
    4-head
    5-left shoulder
    6-left elbow
    7-left wrist
    8-left hand
    9-right shoulder
    10-right elbow
    11-right wrist
    12-right hand
    13-left hip
    14-left knee
    15-left ankle
    16-left foot
    17-right hip
    18-right knee
    19-right ankle
    20-right foot
    21-spine
    22-tip of the left hand
    23-left thumb
    24-tip of the right hand
    25-right thumb

    :param model_type:
    :param out_format:
    :return: (25,) int32 array of SMPL-H joint indices, or None for any
        other (model_type, out_format) combination.
    """
    if model_type == "smplh" and out_format == "nturgbd":
        # 22 and 37 are approximation for hand (base of index finger).
        joints = [0, 3, 12, 15, 16, 18, 20, 22]  # spine base, neck, head, left arm/hand
        joints += [17, 19, 21, 37]               # right arm + right hand
        joints += [1, 4, 7, 10]                  # left leg
        joints += [2, 5, 8, 11]                  # right leg
        joints += [9, 63, 64, 68, 69]            # spine + finger tips / thumbs
        return np.array(joints, dtype=np.int32)
121
+
122
+
123
class dotdict(dict):
    """dot.notation access to dictionary attributes.

    Reads go through dict.get, so a missing attribute returns None instead
    of raising AttributeError; writes and deletes behave like item access.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
129
+
130
+
131
def store_counts(label_fp):
    """Compute # samples per class, from stored labels

    Args:
        label_fp <str>: Path to label file

    Writes (to same path as label file, suffix "_count.pkl"):
        out_fp <dict>: # samples per class = {<idx>: <count>, ...}
    """
    # Label pickle layout: (seg_ids, (Y1, sid, chunk_n, anntr_id)) — see
    # store_splits_subsets in create_dataset.py.
    Y_tup = read_pkl(label_fp)
    Y_idxs = Y_tup[1][0]
    print("# Samples in set = ", len(Y_idxs))

    label_count = Counter(Y_idxs)
    print("File ", label_fp, "len", len(label_count))

    out_fp = label_fp.replace(".pkl", "_count.pkl")
    write_pkl(label_count, out_fp)
149
+
150
+
151
def load_babel_dataset(d_folder="dataset/babel_v1.0_release"):
    """Load every BABEL split (dense + extra) into one dict keyed by split.

    Args:
        d_folder: folder containing "<split>.json" files.

    Returns:
        {"train": ..., "val": ..., "test": ..., "extra_train": ...,
         "extra_val": ...} parsed JSON contents per split.
    """
    # Data folder
    l_babel_dense_files = ["train", "val", "test"]
    l_babel_extra_files = ["extra_train", "extra_val"]

    # BABEL Dataset. The context manager fixes the original's leaked file
    # handles from json.load(open(...)).
    babel = {}
    for fn in l_babel_dense_files + l_babel_extra_files:
        with open(ospj(d_folder, fn + ".json")) as f:
            babel[fn] = json.load(f)

    return babel
163
+
164
+
165
def store_seq_fps(amass_p):
    """Get fps for each seq. in BABEL
    Arguments:
    ---------
    amass_p <str>: Path where you download AMASS to.
    Save:
    -----
    featp_2_fps.json <dict>: Key: feat path <str>, value: orig. fps
    in AMASS <float>. E.g.,: {'KIT/KIT/4/RightTurn01_poses.npz': 100.0, ...}
    """
    # Get BABEL dataset
    babel = load_babel_dataset()

    # Loop over each BABEL seq, store frame-rate
    ft_p_2_fps = {}
    for fn in babel:
        for sid in tqdm(babel[fn]):
            ann = babel[fn][sid]
            # print (ann)
            if ann["feat_p"] not in ft_p_2_fps:
                # Drop the leading path component; BABEL feat paths start
                # with a dataset prefix not present on disk.
                ddir_n = ann["feat_p"]
                ddir_n_ = ddir_n.split("/")
                ddir_n_ = ddir_n_[1:]

                # Map renamed AMASS subset folders back to their on-disk names.
                ddir_map = {"BMLrub": "BioMotionLab_NTroje", "DFaust": "DFaust_67"}
                ddir_n_[0] = (
                    ddir_map[ddir_n_[0]] if ddir_n_[0] in ddir_map else ddir_n_[0]
                )

                p = ospj(amass_p, "/".join(ddir_n_))

                # Read the mocap frame-rate from the AMASS .npz archive.
                fps = np.load(p)["mocap_framerate"]
                ft_p_2_fps[ann["feat_p"]] = float(fps)
    dest_fp = "dataset/featp_2_fps.json"
    write_json(ft_p_2_fps, dest_fp)
    return None
201
+
202
+
203
def store_ntu_jpos(smplh_model_p, dest_jpos_p, amass_p):
    """Store joint positions for the NTU-RGBD skeleton.

    Forward-passes every missing BABEL sequence through an SMPL-H body
    model and saves the resulting joint positions (T, 219) as .npz under
    `dest_jpos_p`, mirroring the AMASS folder layout.

    Args:
        smplh_model_p: path to the SMPL-H male model file (.pkl).
        dest_jpos_p:   destination root for the joint-position archives.
        amass_p:       root of the downloaded AMASS dataset.
    """
    # Model to forward-pass through, to store joint positions
    smplh = SMPLH(
        smplh_model_p,
        create_transl=False,
        ext="pkl",
        gender="male",
        use_pca=False,
        batch_size=1,
    )

    # Load paths to all BABEL features
    featp_2_fps = read_json("dataset/featp_2_fps.json")

    # Loop over all BABEL data, verify that joint positions are stored on disk
    l_m_ft_p = []
    for ft_p in featp_2_fps:

        # Get the correct dataset folder name
        ddir_n = ospb(ospd(ospd(ft_p)))
        ddir_map = {"BioMotionLab_NTroje": "BMLrub", "DFaust_67": "DFaust"}
        ddir_n = ddir_map[ddir_n] if ddir_n in ddir_map else ddir_n
        # Get the subject folder name
        sub_fol_n = ospb(ospd(ft_p))

        # Sanity check
        fft_p = ospj(dest_jpos_p, ddir_n, sub_fol_n, ospb(ft_p))
        if not os.path.exists(fft_p):
            l_m_ft_p.append((ft_p, fft_p))
    print("Total # missing NTU RGBD skeleton features = ", len(l_m_ft_p))

    # Loop over missing joint positions and store them on disk
    for i, (ft_p, ntu_jpos_p) in enumerate(tqdm(l_m_ft_p)):
        # Strip the BABEL path prefix to locate the file under amass_p.
        ft_p_ = ft_p.split("/")
        ft_p_ = ft_p_[1:]

        ft_p = ospj(amass_p, "/".join(ft_p_))

        jrot_smplh = np.load(ft_p)["poses"]
        # Break joints down into body parts (SMPL-H pose vector layout:
        # [0:3] root orientation, [3:66] body, [66:111] left hand, rest right).
        smpl_body_jrot = jrot_smplh[:, 3:66]
        left_hand_jrot = jrot_smplh[:, 66:111]
        right_hand_jrot = jrot_smplh[:, 111:]
        root_orient = jrot_smplh[:, 0:3].reshape(-1, 3)

        # Forward through model to get a superset of required joints,
        # one frame at a time (batch_size=1 above).
        T = jrot_smplh.shape[0]
        ntu_jpos = np.zeros((T, 219))
        for t in range(T):
            res = smplh(
                body_pose=torch.Tensor(smpl_body_jrot[t : t + 1, :]),
                global_orient=torch.Tensor(root_orient[t : t + 1, :]),
                left_hand_pose=torch.Tensor(left_hand_jrot[t : t + 1, :]),
                right_hand_pose=torch.Tensor(right_hand_jrot[t : t + 1, :]),
                # transl=torch.Tensor(transl)
            )
            jpos = res.joints.detach().cpu().numpy()[:, :, :].reshape(-1)
            ntu_jpos[t, :] = jpos

        # Save to disk
        if not os.path.exists(ospd(ntu_jpos_p)):
            os.makedirs(ospd(ntu_jpos_p))
        np.savez(ntu_jpos_p, joint_pos=ntu_jpos, allow_pickle=True)

    return
269
+
270
+
271
def viz_ntu_jpos(jpos_p, l_ft_p):
    """Visualize sequences of NTU-skeleton joint positions.

    Args:
        jpos_p: root folder of stored joint-position .npz archives.
        l_ft_p: list of relative feature paths to visualize.
    """
    # Indices that are in the NTU RGBD skeleton
    smpl2nturgbd = smpl_to_nturgbd()
    # Iterate over each requested sequence.
    for ft_p in l_ft_p:
        x = np.load(ospj(jpos_p, ft_p))["joint_pos"]
        T, ft_sz = x.shape
        # (T, 219) -> (T, 73, 3) -> keep the 25 NTU joints.
        x = x.reshape(T, ft_sz // 3, 3)
        # print('Data shape = {0}'.format(x.shape))
        x = x[:, smpl2nturgbd, :]
        # print('Data shape = {0}'.format(x.shape))
        # x = x[:,:,:, 0].transpose(1, 2, 0) # (3, 150, 22, 1) --> (150, 22, 3)
        print("Data shape = {0}".format(x.shape))
        viz.viz_seq(
            seq=x, folder_p="test_viz/test_ntu_w_axis", sk_type="nturgbd", debug=True
        )
        print("-" * 50)
289
+
290
+
291
def main():
    """Store preliminary stuff: fps metadata, NTU joint positions, and a
    sample visualization. Paths are hard-coded relative to the repo root."""
    amass_p = "dataset/amass/"

    # Save feature paths --> fps (released in babel/action_recognition/data/)
    store_seq_fps(amass_p)

    # Save joint positions in NTU-RGBD skeleton format
    smplh_model_p = "./human_model/SMPLH_male.pkl"
    # model is generated by https://github.com/vchoutas/smplx
    jpos_p = "./dataset/amass/babel_joint_pos"
    store_ntu_jpos(smplh_model_p, jpos_p, amass_p)

    # Viz. saved seqs.
    l_ft_p = ["ACCAD/Male2MartialArtsStances_c3d/D7 - walk to bow_poses.npz"]
    viz_ntu_jpos(jpos_p, l_ft_p)


if __name__ == "__main__":
    main()
prepare/generate_dataset.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ python prepare/dutils.py
2
+ python prepare/create_dataset.py
3
+ python prepare/split_dataset.py --split 1
4
+ # python prepare/split_dataset.py --split 2
5
+ # python prepare/split_dataset.py --split 3
prepare/preprocess.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from rotation import *
4
+ from tqdm import tqdm
5
+
6
+
7
def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]):
    """NTU-RGBD style skeleton normalization.

    Steps: (1) pad empty frames by repeating earlier frames, (2) subtract
    the first person's center joint, (3) rotate so the `zaxis` bone is
    vertical, (4) rotate so the `xaxis` bone is horizontal.

    Args:
        data:  (N, C, T, V, M) array of joint positions.
        zaxis: joint index pair aligned to the z axis (hip->spine default).
        xaxis: joint index pair aligned to the x axis (shoulders default).

    Returns:
        (data, l_m_sk): the normalized array and indices of sequences that
        contained no skeleton at all.

    NOTE(review): `s` is a transposed *view* of `data`, so `data` is also
    modified in place — callers rely on this (see the __main__ block).
    """
    N, C, T, V, M = data.shape
    s = np.transpose(data, [0, 4, 2, 3, 1])  # N, C, T, V, M to N, M, T, V, C
    l_m_sk = []  # List idxs of missing skeletons

    print("pad the null frames with the previous frames")
    for i_s, skeleton in enumerate(tqdm(s)):  # pad
        if skeleton.sum() == 0:
            print(i_s, " has no skeleton")
            l_m_sk.append(i_s)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            if person[0].sum() == 0:
                # Shift the first non-empty frames to the start.
                index = person.sum(-1).sum(-1) != 0
                tmp = person[index].copy()
                person *= 0
                person[: len(tmp)] = tmp
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    if person[i_f:].sum() == 0:
                        # Tile the non-empty prefix to fill the empty tail.
                        rest = len(person) - i_f
                        num = int(np.ceil(rest / i_f))
                        pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[
                            :rest
                        ]
                        s[i_s, i_p, i_f:] = pad
                        break

    print("sub the center joint #1 (spine joint in ntu and neck joint in kinetics)")
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        # Center everyone on person 0's joint 1 trajectory.
        main_body_center = skeleton[0][:, 1:2, :].copy()
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            # Mask keeps padded/empty joints at the origin after centering.
            mask = (person.sum(-1) != 0).reshape(T, V, 1)
            s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask

    print(
        "parallel the bone between hip(jpt 0) and spine(jpt 1) of the first person to the z axis"
    )
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        # One rotation per sequence, derived from person 0, frame 0.
        joint_bottom = skeleton[0, 0, zaxis[0]]
        joint_top = skeleton[0, 0, zaxis[1]]
        axis = np.cross(joint_top - joint_bottom, [0, 0, 1])
        angle = angle_between(joint_top - joint_bottom, [0, 0, 1])
        matrix_z = rotation_matrix(axis, angle)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    continue
                for i_j, joint in enumerate(frame):
                    s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint)

    print(
        "parallel the bone between right shoulder(jpt 8) and left shoulder(jpt 4) of the first person to the x axis"
    )
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        joint_rshoulder = skeleton[0, 0, xaxis[0]]
        joint_lshoulder = skeleton[0, 0, xaxis[1]]
        axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0])
        angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0])
        matrix_x = rotation_matrix(axis, angle)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    continue
                for i_j, joint in enumerate(frame):
                    s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint)

    # Transpose back to (N, C, T, V, M); shares memory with the input.
    data = np.transpose(s, [0, 4, 2, 3, 1])
    return data, l_m_sk
89
+
90
+
91
if __name__ == "__main__":
    data = np.load("../data/ntu/xview/val_data.npy")
    # NOTE(review): the return value is discarded; this relies on
    # pre_normalization mutating `data` in place through the transposed
    # view it operates on — confirm before changing either side.
    pre_normalization(data)
    np.save("../data/ntu/xview/data_val_pre.npy", data)
prepare/rotation.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # vim:fenc=utf-8
4
+ #
5
+ # Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
6
+ #
7
+ # Distributed under terms of the MIT license.
8
+
9
+ import math
10
+
11
+ import numpy as np
12
+
13
+
14
+ def rotation_matrix(axis, theta):
15
+ """
16
+ Return the rotation matrix associated with counterclockwise rotation about
17
+ the given axis by theta radians.
18
+ """
19
+ if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6:
20
+ return np.eye(3)
21
+ axis = np.asarray(axis)
22
+ axis = axis / math.sqrt(np.dot(axis, axis))
23
+ a = math.cos(theta / 2.0)
24
+ b, c, d = -axis * math.sin(theta / 2.0)
25
+ aa, bb, cc, dd = a * a, b * b, c * c, d * d
26
+ bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
27
+ return np.array(
28
+ [
29
+ [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
30
+ [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
31
+ [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
32
+ ]
33
+ )
34
+
35
+
36
+ def unit_vector(vector):
37
+ """ Returns the unit vector of the vector. """
38
+ return vector / np.linalg.norm(vector)
39
+
40
+
41
+ def angle_between(v1, v2):
42
+ """Returns the angle in radians between vectors 'v1' and 'v2'::
43
+
44
+ >>> angle_between((1, 0, 0), (0, 1, 0))
45
+ 1.5707963267948966
46
+ >>> angle_between((1, 0, 0), (1, 0, 0))
47
+ 0.0
48
+ >>> angle_between((1, 0, 0), (-1, 0, 0))
49
+ 3.141592653589793
50
+ """
51
+ if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6:
52
+ return 0
53
+ v1_u = unit_vector(v1)
54
+ v2_u = unit_vector(v2)
55
+ return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
56
+
57
+
58
+ def x_rotation(vector, theta):
59
+ """Rotates 3-D vector around x-axis"""
60
+ R = np.array(
61
+ [
62
+ [1, 0, 0],
63
+ [0, np.cos(theta), -np.sin(theta)],
64
+ [0, np.sin(theta), np.cos(theta)],
65
+ ]
66
+ )
67
+ return np.dot(R, vector)
68
+
69
+
70
+ def y_rotation(vector, theta):
71
+ """Rotates 3-D vector around y-axis"""
72
+ R = np.array(
73
+ [
74
+ [np.cos(theta), 0, np.sin(theta)],
75
+ [0, 1, 0],
76
+ [-np.sin(theta), 0, np.cos(theta)],
77
+ ]
78
+ )
79
+ return np.dot(R, vector)
80
+
81
+
82
+ def z_rotation(vector, theta):
83
+ """Rotates 3-D vector around z-axis"""
84
+ R = np.array(
85
+ [
86
+ [np.cos(theta), -np.sin(theta), 0],
87
+ [np.sin(theta), np.cos(theta), 0],
88
+ [0, 0, 1],
89
+ ]
90
+ )
91
+ return np.dot(R, vector)
prepare/split_dataset.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ import numpy as np
15
+ import dutils
16
+ import pandas as pd
17
+ from collections import Counter
18
+ from tqdm import tqdm
19
+ import os
20
+ from pandas.core.common import flatten
21
+ import argparse
22
+
23
+ MAX_LEN = 1000
24
+ N_CLASS = 4
25
+
26
+ parser = argparse.ArgumentParser(
27
+ description="Spatial Temporal Graph Convolution Network"
28
+ )
29
+ parser.add_argument(
30
+ "--data-root",
31
+ default="dataset/babel_v1.0_sequence/",
32
+ help="the root path of the dataset",
33
+ type=str
34
+ )
35
+ parser.add_argument(
36
+ "--split",
37
+ default=1,
38
+ help="the split of the dataset",
39
+ type=int
40
+ )
41
+ parser.add_argument(
42
+ "--output-folder",
43
+ default="dataset/processed_data",
44
+ help="the output folder of the generated data",
45
+ type=str
46
+ )
47
+ args = parser.parse_args()
48
+
49
+ os.makedirs(args.output_folder, exist_ok=True)
50
+
51
+ def main(data_root):
52
+ train_data = dutils.read_pkl(os.path.join(data_root, "babel_v1.0_train_ntu_sk_ntu-style_preprocessed.pkl"))
53
+ test_data = dutils.read_pkl(os.path.join(data_root, "babel_v1.0_val_ntu_sk_ntu-style_preprocessed.pkl"))
54
+
55
+ act2idx = dutils.read_json(f"./prepare/configs/action_label_split{args.split}.json")
56
+
57
+ label_train_data(data_root, train_data, act2idx)
58
+ label_val_data(data_root, test_data, act2idx)
59
+
60
+
61
+ def label_train_data(data_root, train_data, act2idx):
62
+ sid = []
63
+ x = []
64
+ y = []
65
+ loc = []
66
+
67
+ for i, seq_labels in enumerate(tqdm(train_data["Y"])):
68
+ if len(seq_labels) > MAX_LEN:
69
+ continue
70
+
71
+ y_ = []
72
+ loc_ = []
73
+ flag = False
74
+
75
+ for frame, labels in seq_labels.items():
76
+ label_set = set(labels) & set(act2idx.keys())
77
+ label_list = list(label_set)
78
+ if len(label_list) > 0:
79
+ flag = True
80
+ loc_.append(act2idx[label_list[0]])
81
+ y_.append(act2idx[label_list[0]])
82
+ else:
83
+ loc_.append(N_CLASS)
84
+
85
+ max_t = len(loc_)
86
+ loc_ = np.array(loc_)
87
+ y_ = list(set(y_))
88
+
89
+ if flag:
90
+
91
+ # print (train_data["X"][i].shape, len(loc_))
92
+ loc.append(loc_)
93
+ sid.append(train_data["sid"][i])
94
+ x.append(train_data["X"][i][:,:max_t,...])
95
+ y.append(y_)
96
+
97
+ data = {"sid": sid, "X": x, "Y": y, "L":loc}
98
+
99
+ dutils.write_pkl(data, os.path.join(args.output_folder, f"train_split{args.split}.pkl"))
100
+ print (f"#Train sequence: {len(x)}")
101
+
102
+
103
+ def label_val_data(data_root, test_data, act2idx):
104
+ sid = []
105
+ x = []
106
+ y = []
107
+ loc = []
108
+ for i, seq_labels in enumerate(tqdm(test_data["Y"])):
109
+ if len(seq_labels) > MAX_LEN:
110
+ continue
111
+
112
+ y_ = []
113
+ loc_ = []
114
+ flag = False
115
+
116
+ for frame, labels in seq_labels.items():
117
+ label_set = set(labels) & set(act2idx.keys())
118
+ label_list = list(label_set)
119
+ if len(label_list) > 0:
120
+ flag = True
121
+ loc_.append(act2idx[label_list[0]])
122
+ y_.append(act2idx[label_list[0]])
123
+ else:
124
+ loc_.append(N_CLASS)
125
+
126
+ max_t = len(loc_)
127
+ loc_ = np.array(loc_)
128
+ y_ = list(set(y_))
129
+
130
+ if flag:
131
+ loc.append(loc_)
132
+ sid.append(test_data["sid"][i])
133
+ x.append(test_data["X"][i][:,:max_t,...])
134
+ y.append(y_)
135
+
136
+ data = {"sid": sid, "X": x, "Y": y, "L":loc}
137
+
138
+ dutils.write_pkl(data, os.path.join(args.output_folder, f"val_split{args.split}.pkl"))
139
+ print (f"#Test sequence: {len(x)}")
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main(args.data_root)
prepare/viz.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # vim:fenc=utf-8
4
+ #
5
+ # Copyright © 2020 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
6
+ #
7
+ # Distributed under terms of the MIT license.
8
+
9
+ import math
10
+ import os
11
+ import os.path as osp
12
+ import pdb
13
+ import random
14
+ import shutil
15
+ import subprocess
16
+ import sys
17
+ import uuid
18
+
19
+ import cv2
20
+ import dutils
21
+ import numpy as np
22
+ import torch
23
+ from matplotlib import pyplot as plt
24
+ from mpl_toolkits.mplot3d import Axes3D
25
+ from torch.nn.functional import interpolate as intrp
26
+
27
+ """
28
+ Visualize input and output motion sequences and labels
29
+ """
30
+
31
+
32
+ def get_smpl_skeleton():
33
+ """Skeleton ordering so that you traverse joints in this order:
34
+ Left lower, Left upper, Spine, Neck, Head, Right lower, Right upper.
35
+ """
36
+ return np.array(
37
+ [
38
+ # Left lower
39
+ [0, 1],
40
+ [1, 4],
41
+ [4, 7],
42
+ [7, 10],
43
+ # Left upper
44
+ [9, 13],
45
+ [13, 16],
46
+ [16, 18],
47
+ [18, 20],
48
+ # [20, 22],
49
+ # Spinal column
50
+ [0, 3],
51
+ [3, 6],
52
+ [6, 9],
53
+ [9, 12],
54
+ [12, 15],
55
+ # Right lower
56
+ [0, 2],
57
+ [2, 5],
58
+ [5, 8],
59
+ [8, 11],
60
+ # Right upper
61
+ [9, 14],
62
+ [14, 17],
63
+ [17, 19],
64
+ [19, 21],
65
+ # [21, 23],
66
+ ]
67
+ )
68
+
69
+
70
+ def get_nturgbd_joint_names():
71
+ """From paper:
72
+ 1-base of the spine 2-middle of the spine 3-neck 4-head 5-left shoulder 6-left elbow 7-left wrist 8- left hand 9-right shoulder 10-right elbow 11-right wrist 12- right hand 13-left hip 14-left knee 15-left ankle 16-left foot 17- right hip 18-right knee 19-right ankle 20-right foot 21-spine 22- tip of the left hand 23-left thumb 24-tip of the right hand 25- right thumb
73
+ """
74
+ # Joint names by AC, based on SMPL names
75
+ joint_names_map = {
76
+ 0: "Pelvis",
77
+ 12: "L_Hip",
78
+ 13: "L_Knee",
79
+ 14: "L_Ankle",
80
+ 15: "L_Foot",
81
+ 16: "R_Hip",
82
+ 17: "R_Knee",
83
+ 18: "R_Ankle",
84
+ 19: "R_Foot",
85
+ 1: "Spine1",
86
+ # 'Spine2',
87
+ 20: "Spine3",
88
+ 2: "Neck",
89
+ 3: "Head",
90
+ # 'L_Collar',
91
+ 4: "L_Shoulder",
92
+ 5: "L_Elbow",
93
+ 6: "L_Wrist",
94
+ 7: "L_Hand",
95
+ 21: "L_HandTip", # Not in SMPL
96
+ 22: "L_Thumb", # Not in SMPL
97
+ # 'R_Collar',
98
+ 8: "R_Shoulder",
99
+ 9: "R_Elbow",
100
+ 10: "R_Wrist",
101
+ 11: "R_Hand",
102
+ 23: "R_HandTip", # Not in SMPL
103
+ 24: "R_Thumb", # Not in SMPL
104
+ }
105
+
106
+ return [joint_names_map[idx] for idx in range(len(joint_names_map))]
107
+
108
+
109
+ def get_smpl_joint_names():
110
+ # Joint names from SMPL Wiki
111
+ joint_names_map = {
112
+ 0: "Pelvis",
113
+ 1: "L_Hip",
114
+ 4: "L_Knee",
115
+ 7: "L_Ankle",
116
+ 10: "L_Foot",
117
+ 2: "R_Hip",
118
+ 5: "R_Knee",
119
+ 8: "R_Ankle",
120
+ 11: "R_Foot",
121
+ 3: "Spine1",
122
+ 6: "Spine2",
123
+ 9: "Spine3",
124
+ 12: "Neck",
125
+ 15: "Head",
126
+ 13: "L_Collar",
127
+ 16: "L_Shoulder",
128
+ 18: "L_Elbow",
129
+ 20: "L_Wrist",
130
+ 22: "L_Hand",
131
+ 14: "R_Collar",
132
+ 17: "R_Shoulder",
133
+ 19: "R_Elbow",
134
+ 21: "R_Wrist",
135
+ 23: "R_Hand",
136
+ }
137
+
138
+ # Return all joints except indices 22 (L_Hand), 23 (R_Hand)
139
+ return [joint_names_map[idx] for idx in range(len(joint_names_map) - 2)]
140
+
141
+
142
+ def get_nturgbd_skeleton():
143
+ """Skeleton ordering such that you traverse joints in this order:
144
+ Left lower, Left upper, Spine, Neck, Head, Right lower, Right upper.
145
+ """
146
+ return np.array(
147
+ [
148
+ # Left lower
149
+ [0, 12],
150
+ [12, 13],
151
+ [13, 14],
152
+ [14, 15],
153
+ # Left upper
154
+ [4, 20],
155
+ [4, 5],
156
+ [5, 6],
157
+ [6, 7],
158
+ [7, 21],
159
+ [7, 22], # --> L Thumb
160
+ # Spinal column
161
+ [0, 1],
162
+ [1, 20],
163
+ [20, 2],
164
+ [2, 3],
165
+ # Right lower
166
+ [0, 16],
167
+ [16, 17],
168
+ [17, 18],
169
+ [18, 19],
170
+ # Right upper
171
+ [20, 8],
172
+ [8, 9],
173
+ [9, 10],
174
+ [10, 11],
175
+ [11, 24],
176
+ # [24, 11] --> R Thumb
177
+ [21, 22],
178
+ [23, 24],
179
+ ]
180
+ )
181
+
182
+
183
+ def get_joint_colors(joint_names):
184
+ """Return joints based on a color spectrum. Also, joints on
185
+ L and R should have distinctly different colors.
186
+ """
187
+ # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
188
+ cmap = plt.get_cmap("rainbow")
189
+ colors = [cmap(i) for i in np.linspace(0, 1, len(joint_names))]
190
+ colors = [np.array((c[2], c[1], c[0])) for c in colors]
191
+ return colors
192
+
193
+
194
+ def calc_angle_from_x(sk):
195
+ """Given skeleton, calc. angle from x-axis"""
196
+ # Hip bone
197
+ id_l_hip = get_smpl_joint_names().index("L_Hip")
198
+ id_r_hip = get_smpl_joint_names().index("R_Hip")
199
+ pl, pr = sk[id_l_hip], sk[id_r_hip]
200
+ bone = np.array(pr - pl)
201
+ unit_v = bone / np.linalg.norm(bone)
202
+ # Angle with x-axis
203
+ pdb.set_trace()
204
+ x_ax = np.array([1, 0, 0])
205
+ x_angle = math.degrees(np.arccos(np.dot(x_ax, unit_v)))
206
+
207
+ """
208
+ l_hip_z = seq[0, joint_names.index('L_Hip'), 2]
209
+ r_hip_z = seq[0, joint_names.index('R_Hip'), 2]
210
+ az = 0 if (l_hip_z > zroot and zroot > r_hip_z) else 180
211
+ """
212
+ if bone[1] > 0:
213
+ x_angle = -x_angle
214
+
215
+ return x_angle
216
+
217
+
218
+ def calc_angle_from_y(sk):
219
+ """Given skeleton, calc. angle from x-axis"""
220
+ # Hip bone
221
+ id_l_hip = get_smpl_joint_names().index("L_Hip")
222
+ id_r_hip = get_smpl_joint_names().index("R_Hip")
223
+ pl, pr = sk[id_l_hip], sk[id_r_hip]
224
+ bone = np.array(pl - pr)
225
+ unit_v = bone / np.linalg.norm(bone)
226
+ print(unit_v)
227
+ # Angle with x-axis
228
+ pdb.set_trace()
229
+ y_ax = np.array([0, 1, 0])
230
+ y_angle = math.degrees(np.arccos(np.dot(y_ax, unit_v)))
231
+
232
+ """
233
+ l_hip_z = seq[0, joint_names.index('L_Hip'), 2]
234
+ r_hip_z = seq[0, joint_names.index('R_Hip'), 2]
235
+ az = 0 if (l_hip_z > zroot and zroot > r_hip_z) else 180
236
+ """
237
+ # if bone[1] > 0:
238
+ # y_angle = - y_angle
239
+ seq_y_proj = bone * np.cos(np.deg2rad(y_angle))
240
+ print("Bone projected onto y-axis: ", seq_y_proj)
241
+
242
+ return y_angle
243
+
244
+
245
+ def viz_skeleton(
246
+ seq,
247
+ folder_p,
248
+ sk_type="smpl",
249
+ radius=1,
250
+ lcolor="#ff0000",
251
+ rcolor="#0000ff",
252
+ action="",
253
+ debug=False,
254
+ ):
255
+ """Visualize skeletons for given sequence and store as images.
256
+
257
+ Args:
258
+ seq (np.array): Array (frames) of joint positions.
259
+ Size depends on sk_type (see below).
260
+ if sk_type is 'smpl' then assume:
261
+ 1. first 3 dims = translation.
262
+ 2. Size = (# frames, 69)
263
+ elif sk_type is 'nturgbd', then assume:
264
+ 1. no translation.
265
+ 2. Size = (# frames, 25, 3)
266
+ folder_p (str): Path to root folder containing visualized frames.
267
+ Frames are dumped to the path: folder_p/frames/*.jpg
268
+ radius (float): Space around the subject?
269
+
270
+ Returns:
271
+ Stores skeleton sequence as jpg frames.
272
+ """
273
+ joint_names = (
274
+ get_nturgbd_joint_names() if "nturgbd" == sk_type else get_smpl_joint_names()
275
+ )
276
+ n_j = n_j = len(joint_names)
277
+
278
+ az = 90
279
+ if "smpl" == sk_type:
280
+ # SMPL kinematic chain, joint list.
281
+ # NOTE that hands are skipped.
282
+ kin_chain = get_smpl_skeleton()
283
+ # Reshape flat pose features into (frames, joints, (x,y,z)) (skip trans)
284
+ seq = seq[:, 3:].reshape(-1, n_j, 3).cpu().detach().numpy()
285
+
286
+ elif "nturgbd" == sk_type:
287
+ kin_chain = get_nturgbd_skeleton()
288
+ az = 0
289
+
290
+ # Get color-spectrum for skeleton
291
+ colors = get_joint_colors(joint_names)
292
+ labels = [(joint_names[jidx[0]], joint_names[jidx[1]]) for jidx in kin_chain]
293
+
294
+ # xroot, yroot, zroot = 0.0, 0.0, 0.0
295
+ xroot, yroot, zroot = seq[0, 0, 0], seq[0, 0, 1], seq[0, 0, 2]
296
+ # seq = seq - seq[0, :, :]
297
+
298
+ # Change viewing angle so that first frame is in frontal pose
299
+ # az = calc_angle_from_x(seq[0]-np.array([xroot, yroot, zroot]))
300
+ # az = calc_angle_from_y(seq[0]-np.array([xroot, yroot, zroot]))
301
+
302
+ # Viz. skeleton for each frame
303
+ for t in range(seq.shape[0]):
304
+
305
+ # Fig. settings
306
+ fig = plt.figure(figsize=(7, 6)) if debug else plt.figure(figsize=(5, 5))
307
+ ax = fig.add_subplot(111, projection="3d")
308
+
309
+ for i, (j1, j2) in enumerate(kin_chain):
310
+ # Store bones
311
+ x = np.array([seq[t, j1, 0], seq[t, j2, 0]])
312
+ y = np.array([seq[t, j1, 1], seq[t, j2, 1]])
313
+ z = np.array([seq[t, j1, 2], seq[t, j2, 2]])
314
+ # Plot bones in skeleton
315
+ ax.plot(x, y, z, c=colors[i], marker="o", linewidth=2, label=labels[i])
316
+
317
+ # More figure settings
318
+ ax.set_title(action)
319
+ ax.set_xlabel("X")
320
+ ax.set_ylabel("Y")
321
+ ax.set_zlabel("Z")
322
+ # xroot, yroot, zroot = seq[t, 0, 0], seq[t, 0, 1], seq[t, 0, 2]
323
+
324
+ # pdb.set_trace()
325
+ ax.set_xlim3d(-radius + xroot, radius + xroot)
326
+ ax.set_ylim3d([-radius + yroot, radius + yroot])
327
+ ax.set_zlim3d([-radius + zroot, radius + zroot])
328
+
329
+ if True == debug:
330
+ ax.axis("on")
331
+ ax.grid(b=True)
332
+ else:
333
+ ax.axis("off")
334
+ ax.grid(b=None)
335
+ # Turn off tick labels
336
+ ax.set_yticklabels([])
337
+ ax.set_xticklabels([])
338
+ ax.set_zticklabels([])
339
+
340
+ cv2.waitKey(0)
341
+
342
+ # ax.view_init(-75, 90)
343
+ # ax.view_init(elev=20, azim=90+az)
344
+ ax.view_init(elev=20, azim=az)
345
+
346
+ if True == debug:
347
+ ax.legend(bbox_to_anchor=(1.1, 1), loc="upper right")
348
+ pass
349
+
350
+ fig.savefig(osp.join(folder_p, "frames", "{0}.jpg".format(t)))
351
+ plt.close(fig)
352
+
353
+ # break
354
+
355
+
356
+ def write_vid_from_imgs(folder_p, fps):
357
+ """Collate frames into a video sequence.
358
+
359
+ Args:
360
+ folder_p (str): Frame images are in the path: folder_p/frames/<int>.jpg
361
+ fps (float): Output frame rate.
362
+
363
+ Returns:
364
+ Output video is stored in the path: folder_p/video.mp4
365
+ """
366
+ vid_p = osp.join(folder_p, "video.mp4")
367
+ cmd = [
368
+ "ffmpeg",
369
+ "-r",
370
+ str(int(fps)),
371
+ "-i",
372
+ osp.join(folder_p, "frames", "%d.jpg"),
373
+ "-y",
374
+ vid_p,
375
+ ]
376
+ FNULL = open(os.devnull, "w")
377
+ retcode = subprocess.call(cmd, stdout=FNULL, stderr=subprocess.STDOUT)
378
+ if not 0 == retcode:
379
+ print(
380
+ "*******ValueError(Error {0} executing command: {1}*********".format(
381
+ retcode, " ".join(cmd)
382
+ )
383
+ )
384
+ shutil.rmtree(osp.join(folder_p, "frames"))
385
+
386
+
387
+ def viz_seq(seq, folder_p, sk_type, orig_fps=30.0, debug=False):
388
+ """1. Dumps sequence of skeleton images for the given sequence of joints.
389
+ 2. Collates the sequence of images into an mp4 video.
390
+
391
+ Args:
392
+ seq (np.array): Array of joint positions.
393
+ folder_p (str): Path to root folder that will contain frames folder.
394
+ sk_type (str): {'smpl', 'nturgbd'}
395
+
396
+ Return:
397
+ None. Path of mp4 video: folder_p/video.mp4
398
+ """
399
+ # Delete folder if exists
400
+ if osp.exists(folder_p):
401
+ print("Deleting existing folder ", folder_p)
402
+ shutil.rmtree(folder_p)
403
+
404
+ # Create folder for frames
405
+ os.makedirs(osp.join(folder_p, "frames"))
406
+
407
+ # Dump frames into folder. Args: (data, radius, frames path)
408
+ viz_skeleton(seq, folder_p=folder_p, sk_type=sk_type, radius=1.2, debug=debug)
409
+ write_vid_from_imgs(folder_p, orig_fps)
410
+
411
+ return None
412
+
413
+
414
+ def viz_rand_seq(X, Y, dtype, epoch, wb, urls=None, k=3, pred_labels=None):
415
+ """
416
+ Args:
417
+ X (np.array): Array (frames) of SMPL joint positions.
418
+ Y (np.array): Multiple labels for each frame in x \in X.
419
+ dtype (str): {'input', 'pred'}
420
+ k (int): # samples to viz.
421
+ urls (tuple): Tuple of URLs of the rendered videos from original mocap.
422
+ wb (dict): Wandb log dict.
423
+ Returns:
424
+ viz_ds (dict): Data structure containing all viz. info so far.
425
+ """
426
+ import wandb
427
+
428
+ # `idx2al`: idx --> action label string
429
+ al2idx = dutils.read_json("data/action_label_to_idx.json")
430
+ idx2al = {al2idx[k]: k for k in al2idx}
431
+
432
+ # Sample k random seqs. to viz.
433
+ for s_idx in random.sample(list(range(X.shape[0])), k):
434
+ # Visualize a single seq. in path `folder_p`
435
+ folder_p = osp.join("viz", str(uuid.uuid4()))
436
+ viz_seq(seq=X[s_idx], folder_p=folder_p)
437
+ title = "{0} seq. {1}: ".format(dtype, s_idx)
438
+ acts_str = ", ".join([idx2al[l] for l in torch.unique(Y[s_idx])])
439
+ wb[title + urls[s_idx]] = wandb.Video(
440
+ osp.join(folder_p, "video.mp4"), caption="Actions: " + acts_str
441
+ )
442
+
443
+ if "pred" == dtype or "preds" == dtype:
444
+ raise NotImplementedError
445
+
446
+ print("Done viz. {0} seqs.".format(k))
447
+ return wb
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.pysen]
2
+ version = "0.10"
3
+
4
+ [tool.pysen.lint]
5
+ enable_black = true
6
+ enable_flake8 = true
7
+ enable_isort = true
8
+ enable_mypy = true
9
+ mypy_preset = "strict"
10
+ line_length = 88
11
+ py_version = "py37"
12
+ [[tool.pysen.lint.mypy_targets]]
13
+ paths = ["."]
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ backcall==0.2.0
2
+ certifi==2020.12.5
3
+ decorator==4.4.2
4
+ ipdb==0.13.4
5
+ jedi==0.18.0
6
+ joblib==1.0.0
7
+ networkx==2.5
8
+ parso==0.8.1
9
+ pexpect==4.8.0
10
+ pickleshare==0.7.5
11
+ prompt-toolkit==3.0.10
12
+ protobuf==3.14.0
13
+ ptyprocess==0.7.0
14
+ Pygments==2.7.4
15
+ six==1.15.0
16
+ tensorboardX==2.1
17
+ threadpoolctl==2.1.0
18
+ tqdm==4.56.0
19
+ traitlets==5.0.5
20
+ typing-extensions==3.7.4.3
21
+ wcwidth==0.2.5
22
+ smplx==0.1.28
23
+ opencv-python
24
+ einops
25
+ matplotlib
26
+ scikit-learn
27
+ pandas
28
+ chumpy
train.py ADDED
@@ -0,0 +1,830 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ from __future__ import print_function
15
+
16
+ import argparse
17
+ import inspect
18
+ import os
19
+ import pdb
20
+ import pickle
21
+ import random
22
+ import re
23
+ import shutil
24
+ import time
25
+ from collections import *
26
+
27
+ import ipdb
28
+ import numpy as np
29
+
30
+ # torch
31
+ import torch
32
+ import torch.backends.cudnn as cudnn
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ import torch.optim as optim
36
+ import yaml
37
+ from einops import rearrange, reduce, repeat
38
+ from evaluation.classificationMAP import getClassificationMAP as cmAP
39
+ from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
40
+ from feeders.tools import collate_with_padding_multi_joint
41
+ from model.losses import cross_entropy_loss, mvl_loss
42
+ from sklearn.metrics import f1_score
43
+
44
+ # Custom
45
+ from tensorboardX import SummaryWriter
46
+ from torch.autograd import Variable
47
+ from torch.optim.lr_scheduler import _LRScheduler
48
+ from tqdm import tqdm
49
+ from utils.logger import Logger
50
+
51
+ def remove_prefix_from_state_dict(state_dict, prefix):
52
+ new_state_dict = {}
53
+ for k, v in state_dict.items():
54
+ if k.startswith(prefix):
55
+ print(k)
56
+ new_k = k[len(prefix):] # strip the prefix
57
+ print(new_k)
58
+ else:
59
+ print(k)
60
+ new_k = k
61
+ new_state_dict[new_k] = v
62
+ return new_state_dict
63
+
64
+
65
+ def init_seed(seed):
66
+ torch.cuda.manual_seed_all(seed)
67
+ torch.manual_seed(seed)
68
+ np.random.seed(seed)
69
+ random.seed(seed)
70
+ torch.backends.cudnn.deterministic = True
71
+ torch.backends.cudnn.benchmark = False
72
+
73
+
74
+ def get_parser():
75
+ # parameter priority: command line > config > default
76
+ parser = argparse.ArgumentParser(
77
+ description="Spatial Temporal Graph Convolution Network"
78
+ )
79
+ parser.add_argument(
80
+ "--work-dir",
81
+ default="./work_dir/temp",
82
+ help="the work folder for storing results",
83
+ )
84
+
85
+ parser.add_argument("-model_saved_name", default="")
86
+ parser.add_argument(
87
+ "--config",
88
+ default="./config/nturgbd-cross-view/test_bone.yaml",
89
+ help="path to the configuration file",
90
+ )
91
+
92
+ # processor
93
+ parser.add_argument("--phase", default="train", help="must be train or test")
94
+
95
+ # visulize and debug
96
+ parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
97
+ parser.add_argument(
98
+ "--log-interval",
99
+ type=int,
100
+ default=100,
101
+ help="the interval for printing messages (#iteration)",
102
+ )
103
+ parser.add_argument(
104
+ "--save-interval",
105
+ type=int,
106
+ default=2,
107
+ help="the interval for storing models (#iteration)",
108
+ )
109
+ parser.add_argument(
110
+ "--eval-interval",
111
+ type=int,
112
+ default=5,
113
+ help="the interval for evaluating models (#iteration)",
114
+ )
115
+ parser.add_argument(
116
+ "--print-log", type=str2bool, default=True, help="print logging or not"
117
+ )
118
+ parser.add_argument(
119
+ "--show-topk",
120
+ type=int,
121
+ default=[1, 5],
122
+ nargs="+",
123
+ help="which Top K accuracy will be shown",
124
+ )
125
+
126
+ # feeder
127
+ parser.add_argument(
128
+ "--feeder", default="feeder.feeder", help="data loader will be used"
129
+ )
130
+ parser.add_argument(
131
+ "--num-worker",
132
+ type=int,
133
+ default=32,
134
+ help="the number of worker for data loader",
135
+ )
136
+ parser.add_argument(
137
+ "--train-feeder-args",
138
+ default=dict(),
139
+ help="the arguments of data loader for training",
140
+ )
141
+ parser.add_argument(
142
+ "--test-feeder-args",
143
+ default=dict(),
144
+ help="the arguments of data loader for test",
145
+ )
146
+
147
+ # model
148
+ parser.add_argument("--model", default=None, help="the model will be used")
149
+ parser.add_argument(
150
+ "--model-args", type=dict, default=dict(), help="the arguments of model"
151
+ )
152
+ parser.add_argument(
153
+ "--weights", default=None, help="the weights for network initialization"
154
+ )
155
+ parser.add_argument(
156
+ "--ignore-weights",
157
+ type=str,
158
+ default=[],
159
+ nargs="+",
160
+ help="the name of weights which will be ignored in the initialization",
161
+ )
162
+
163
+ # optim
164
+ parser.add_argument(
165
+ "--base-lr", type=float, default=0.01, help="initial learning rate"
166
+ )
167
+ parser.add_argument(
168
+ "--step",
169
+ type=int,
170
+ default=[60,80],
171
+ nargs="+",
172
+ help="the epoch where optimizer reduce the learning rate",
173
+ )
174
+
175
+ # training
176
+ parser.add_argument(
177
+ "--device",
178
+ type=int,
179
+ default=0,
180
+ nargs="+",
181
+ help="the indexes of GPUs for training or testing",
182
+ )
183
+ parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
184
+ parser.add_argument(
185
+ "--nesterov", type=str2bool, default=False, help="use nesterov or not"
186
+ )
187
+ parser.add_argument(
188
+ "--batch-size", type=int, default=256, help="training batch size"
189
+ )
190
+ parser.add_argument(
191
+ "--test-batch-size", type=int, default=256, help="test batch size"
192
+ )
193
+ parser.add_argument(
194
+ "--start-epoch", type=int, default=0, help="start training from which epoch"
195
+ )
196
+ parser.add_argument(
197
+ "--num-epoch", type=int, default=80, help="stop training in which epoch"
198
+ )
199
+ parser.add_argument(
200
+ "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
201
+ )
202
+ # loss
203
+ parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
204
+ parser.add_argument(
205
+ "--label_count_path",
206
+ default=None,
207
+ type=str,
208
+ help="Path to label counts (used in loss weighting)",
209
+ )
210
+ parser.add_argument(
211
+ "---beta",
212
+ type=float,
213
+ default=0.9999,
214
+ help="Hyperparameter for Class balanced loss",
215
+ )
216
+ parser.add_argument(
217
+ "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
218
+ )
219
+
220
+ parser.add_argument("--only_train_part", default=False)
221
+ parser.add_argument("--only_train_epoch", default=0)
222
+ parser.add_argument("--warm_up_epoch", default=0)
223
+
224
+ parser.add_argument(
225
+ "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
226
+ )
227
+
228
+ parser.add_argument(
229
+ "--class-threshold",
230
+ type=float,
231
+ default=0.1,
232
+ help="class threshold for rejection",
233
+ )
234
+ parser.add_argument(
235
+ "--start-threshold",
236
+ type=float,
237
+ default=0.03,
238
+ help="start threshold for action localization",
239
+ )
240
+ parser.add_argument(
241
+ "--end-threshold",
242
+ type=float,
243
+ default=0.055,
244
+ help="end threshold for action localization",
245
+ )
246
+ parser.add_argument(
247
+ "--threshold-interval",
248
+ type=float,
249
+ default=0.005,
250
+ help="threshold interval for action localization",
251
+ )
252
+ return parser
253
+
254
+
255
class Processor:
    """
    Processor for Skeleton-based Action Recognition.

    Drives training and evaluation of a weakly-supervised temporal action
    localization model: a video-level MIL branch plus two frame-level scoring
    branches, with TensorBoard logging, soft-label self-training after a
    warm-up, and periodic checkpointing.
    """

    def __init__(self, arg):
        self.arg = arg
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # Deliberately auto-confirms deletion of a stale log dir
                    # (the interactive prompt is kept for reference).
                    # answer = input('delete it? y/n:')
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        # BCE over the (already sigmoid-activated) video-level MIL predictions.
        self.loss_nce = torch.nn.BCELoss()

        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)])

    def load_data(self):
        """Build train/test DataLoaders from the configured feeder class."""
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()
        if self.arg.phase == "train":
            self.data_loader["train"] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                collate_fn=collate_with_padding_multi_joint,
            )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            collate_fn=collate_with_padding_multi_joint,
        )

    def load_model(self):
        """Instantiate the model on the output device and (optionally) load
        pretrained weights, dropping pretraining-specific tensors."""
        output_device = (
            self.arg.device[0] if type(self.arg.device) is list else self.arg.device
        )
        self.output_device = output_device
        Model = import_class(self.arg.model)
        # Keep a copy of the model source next to the results for reproducibility.
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        self.model = Model(**self.arg.model_args).cuda(output_device)
        # Fix: read the loss type from self.arg instead of the module-level
        # `arg` global the original relied on.
        self.loss_type = self.arg.loss

        if self.arg.weights:
            # self.global_step = int(arg.weights[:-3].split("-")[-1])
            self.print_log("Load weights from {}.".format(self.arg.weights))
            if ".pkl" in self.arg.weights:
                # Fix: pickled checkpoints must be opened in binary mode.
                with open(self.arg.weights, "rb") as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            # Strip any DataParallel "module." prefix and move tensors to GPU.
            weights = OrderedDict(
                [
                    [k.split("module.")[-1], v.cuda(output_device)]
                    for k, v in weights.items()
                ]
            )
            weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
            keys = list(weights.keys())
            # NOTE(review): the CLI --ignore-weights value is intentionally
            # overwritten here so that all MoCo-style pretraining tensors
            # (momentum encoder, queue, projection heads, ...) are dropped.
            self.arg.ignore_weights = [
                'encoder_q', 'encoder_q.relation', 'data_bn', 'fc',
                'encoder_k', 'queue', 'queue_ptr', 'value_transform',
            ]
            for w in self.arg.ignore_weights:
                for key in keys:
                    if w in key:
                        if weights.pop(key, None) is not None:
                            self.print_log(
                                "Successfully Remove Weights: {}.".format(key)
                            )
                        else:
                            self.print_log("Can Not Remove Weights: {}.".format(key))

            try:
                self.model.load_state_dict(weights)
            except RuntimeError:
                # Partial load: report the missing keys, then merge what we have
                # into the freshly initialised state dict.
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                print("Can not find these weights:")
                for d in diff:
                    print(" " + d)
                state.update(weights)
                self.model.load_state_dict(state)

        if type(self.arg.device) is list:
            if len(self.arg.device) > 1:
                self.model = nn.DataParallel(
                    self.model, device_ids=self.arg.device, output_device=output_device
                )

    def load_optimizer(self):
        """Create the optimizer named by --optimizer (SGD or Adam)."""
        if self.arg.optimizer == "SGD":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.arg.base_lr,
                momentum=0.9,
                nesterov=self.arg.nesterov,
                weight_decay=self.arg.weight_decay,
            )
        elif self.arg.optimizer == "Adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.arg.base_lr,
                weight_decay=self.arg.weight_decay,
            )
        else:
            raise ValueError()

    def save_arg(self):
        """Dump the full argument namespace to <work_dir>/config.yaml."""
        arg_dict = vars(self.arg)
        if not os.path.exists(self.arg.work_dir):
            os.makedirs(self.arg.work_dir)
        with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
            yaml.dump(arg_dict, f)

    def adjust_learning_rate(self, epoch):
        """Linear warm-up for the first warm_up_epoch epochs, then step decay
        (x0.1) at each milestone in --step. Returns the lr that was applied."""
        if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
            if epoch < self.arg.warm_up_epoch:
                lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
            else:
                lr = self.arg.base_lr * (
                    0.1 ** np.sum(epoch >= np.array(self.arg.step))
                )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr

            return lr
        else:
            raise ValueError()

    def print_time(self):
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log("Local current time : " + localtime)

    def print_log(self, msg, print_time=True):
        """Print *msg* (optionally timestamped) and append it to the work-dir
        log file. Parameter renamed from `str` to stop shadowing the builtin;
        every call site in this file passes it positionally."""
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            msg = "[ " + localtime + " ] " + msg
        print(msg)
        if self.arg.print_log:
            with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
                print(msg, file=f)

    def record_time(self):
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        """Return seconds since the last record_time()/split_time() call."""
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    def train(self, epoch, wb_dict, save_model=False):
        """Run one training epoch and return wb_dict updated with loss/acc."""
        self.model.train()
        self.print_log("Training epoch: {}".format(epoch + 1))
        loader = self.data_loader["train"]
        self.adjust_learning_rate(epoch)

        loss_value, batch_acc = [], []
        self.train_writer.add_scalar("epoch", epoch, self.global_step)
        self.record_time()
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
        process = tqdm(loader)
        if self.arg.only_train_part:
            # Freeze/unfreeze the learned adjacency ("PA") parameters around
            # --only_train_epoch.
            if epoch > self.arg.only_train_epoch:
                print("only train part, require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = True
            else:
                print("only train part, do not require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = False

        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        results = []
        indexs = []

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):

            self.global_step += 1
            # Move the batch to the configured device.
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            mask = mask.cuda(self.output_device)
            soft_label = soft_label.cuda(self.output_device)
            timer["dataloader"] += self.split_time()

            indexs.extend(index.cpu().numpy().tolist())

            # Append an always-on extra column to the video-level labels
            # (background/ambiguity class for the MIL branch).
            # Fix: allocate the ones tensor on the configured output device.
            ab_labels = torch.cat(
                [label, torch.ones(label.size(0), 1).cuda(self.output_device)], -1
            )

            # forward
            mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

            cls_mil_loss = self.loss_nce(mil_pred, ab_labels.float()) + self.loss_nce(
                mil_pred_2, ab_labels.float()
            )

            if epoch > 10:
                # After warm-up: add frame-level self-training against the
                # feeder-maintained soft labels.
                frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
                frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
                soft_label = rearrange(soft_label, "n t c -> (n t) c")

                loss = cls_mil_loss * 0.1 + mvl_loss(
                    frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                )

                loss += cross_entropy_loss(
                    frm_scrs_re, soft_label
                ) + cross_entropy_loss(frm_scrs_2_re, soft_label)

            else:
                loss = cls_mil_loss * self.arg.lambda_mil + mvl_loss(
                    frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                )

            # Collect per-video predictions (frame scores cropped to the
            # unpadded length) for mAP computation and soft-label refresh.
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
                vid_pred = mil_pred[i].detach().cpu().numpy()

                results.append(frm_pred)

                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_value.append(loss.data.item())
            timer["model"] += self.split_time()

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Push this epoch's frame predictions back so the feeder can refresh
        # its soft labels for the next epoch.
        loader.dataset.label_update(results, indexs)

        cmap = cmAP(vid_preds, labels)

        self.train_writer.add_scalar("acc", cmap, self.global_step)
        self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

        # statistics
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.train_writer.add_scalar("lr", self.lr, self.global_step)
        timer["statistics"] += self.split_time()

        # statistics of time consumption and loss
        self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
        self.print_log("\tAcc score: {:.3f}%".format(cmap))

        # Log
        wb_dict["train loss"] = np.mean(loss_value)
        wb_dict["train acc"] = cmap

        if save_model:
            # Save CPU copies with any DataParallel "module." prefix stripped.
            state_dict = self.model.state_dict()
            weights = OrderedDict(
                [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
            )

            torch.save(
                weights,
                self.arg.model_saved_name + str(epoch) + ".pt",
            )

        return wb_dict

    @torch.no_grad()
    def eval(
        self,
        epoch,
        wb_dict,
        loader_name=None,
        wrong_file=None,
        result_file=None,
    ):
        """Evaluate on the given loaders; returns wb_dict with val loss/acc.

        `wrong_file` / `result_file` are accepted (start() passes them in the
        test phase — the original signature raised a TypeError there) but are
        currently unused.
        """
        # Fix: avoid a mutable default argument.
        if loader_name is None:
            loader_name = ["test"]
        self.model.eval()
        self.print_log("Eval epoch: {}".format(epoch + 1))

        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        for ln in loader_name:
            loss_value = []
            step = 0
            process = tqdm(self.data_loader[ln])

            for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
                process
            ):
                data = data.float().cuda(self.output_device)
                label = label.cuda(self.output_device)
                mask = mask.cuda(self.output_device)

                # Same background-augmented label as in train().
                ab_labels = torch.cat(
                    [label, torch.ones(label.size(0), 1).cuda(self.output_device)], -1
                )

                # forward
                mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

                cls_mil_loss = self.loss_nce(
                    mil_pred, ab_labels.float()
                ) + self.loss_nce(mil_pred_2, ab_labels.float())

                loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)

                loss = cls_mil_loss * self.arg.lambda_mil + loss_co

                loss_value.append(loss.data.item())

                for i in range(data.size(0)):
                    frm_scr = frm_scrs[i]
                    vid_pred = mil_pred[i]

                    label_ = label[i].cpu().numpy()
                    mask_ = mask[i].cpu().numpy()
                    vid_len = mask_.sum()

                    frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                    vid_pred = vid_pred.cpu().numpy()

                    vid_preds.append(vid_pred)
                    frm_preds.append(frm_pred)
                    vid_lens.append(vid_len)
                    labels.append(label_)

                step += 1

            vid_preds = np.array(vid_preds)
            frm_preds = np.array(frm_preds)
            vid_lens = np.array(vid_lens)
            labels = np.array(labels)

            cmap = cmAP(vid_preds, labels)

            score = cmap
            loss = np.mean(loss_value)

            # Detection mAP at multiple temporal-IoU thresholds.
            dmap, iou = dsmAP(
                vid_preds,
                frm_preds,
                vid_lens,
                self.arg.test_feeder_args["data_path"],
                self.arg,
                multi=True,
            )

            print("Classification map %f" % cmap)
            for item in list(zip(iou, dmap)):
                print("Detection map @ %f = %f" % (item[0], item[1]))

            self.my_logger.append([epoch + 1, cmap] + dmap)

            wb_dict["val loss"] = loss
            wb_dict["val acc"] = score

            if score > self.best_acc:
                self.best_acc = score

            print("Acc score: ", score, " model: ", self.arg.model_saved_name)
            if self.arg.phase == "train":
                self.val_writer.add_scalar("loss", loss, self.global_step)
                self.val_writer.add_scalar("acc", score, self.global_step)

            self.print_log(
                "\tMean {} loss of {} batches: {}.".format(
                    ln, len(self.data_loader[ln]), np.mean(loss_value)
                )
            )
            self.print_log("\tAcc score: {:.3f}%".format(score))

        return wb_dict

    def start(self):
        """Entry point: full train/eval loop, or a single test-phase eval."""
        wb_dict = {}
        if self.arg.phase == "train":
            self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
            self.global_step = (
                self.arg.start_epoch
                * len(self.data_loader["train"])
                / self.arg.batch_size
            )

            for epoch in range(self.arg.start_epoch, self.arg.num_epoch):

                save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
                    epoch + 1 == self.arg.num_epoch
                )
                wb_dict = {"lr": self.lr}

                # Train
                wb_dict = self.train(epoch, wb_dict, save_model=save_model)

                # Eval. on val set
                wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
                # Log stats. for this epoch
                print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))

            print(
                "best accuracy: ",
                self.best_acc,
                " model_name: ",
                self.arg.model_saved_name,
            )

        elif self.arg.phase == "test":
            if not self.arg.test_feeder_args["debug"]:
                wf = self.arg.model_saved_name + "_wrong.txt"
                rf = self.arg.model_saved_name + "_right.txt"
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError("Please appoint --weights.")
            self.arg.print_log = False
            self.print_log("Model: {}.".format(self.arg.model))
            self.print_log("Weights: {}.".format(self.arg.weights))

            # eval() now accepts these kwargs (the original raised TypeError).
            wb_dict = self.eval(
                epoch=0,
                wb_dict=wb_dict,
                loader_name=["test"],
                wrong_file=wf,
                result_file=rf,
            )
            print("Inference metrics: ", wb_dict)
            self.print_log("Done.\n")
+
793
def str2bool(v):
    """Parse a command-line string into a bool (argparse ``type=`` helper).

    Accepts the usual truthy/falsy spellings case-insensitively and raises
    ``argparse.ArgumentTypeError`` for anything else.
    """
    lowered = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
800
+
801
+
802
def import_class(name):
    """Import and return the object at dotted path *name* (e.g. ``"pkg.mod.Cls"``)."""
    parts = name.split(".")
    obj = __import__(parts[0])
    # __import__ returns the top-level package; walk down attribute by attribute.
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
808
+
809
+
810
if __name__ == "__main__":
    parser = get_parser()

    # Parameter priority: command line > config file > argparse defaults.
    # A first parse fetches --config; the YAML values are then installed as
    # new defaults, so explicit CLI flags still win on the second parse.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        for k in default_arg.keys():
            # Every config key must match a known CLI argument; abort otherwise.
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
train_full.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import print_function
3
+
4
+ import argparse
5
+ import inspect
6
+ import os
7
+ import pdb
8
+ import pickle
9
+ import random
10
+ import re
11
+ import shutil
12
+ import time
13
+ from collections import *
14
+
15
+ import ipdb
16
+ import numpy as np
17
+
18
+ # torch
19
+ import torch
20
+ import torch.backends.cudnn as cudnn
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ import yaml
25
+ from einops import rearrange, reduce, repeat
26
+ from evaluation.classificationMAP import getClassificationMAP as cmAP
27
+ from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
28
+ from feeders.tools import collate_with_padding_multi_joint
29
+ from model.losses import cross_entropy_loss, mvl_loss
30
+ from sklearn.metrics import f1_score
31
+
32
+ # Custom
33
+ from tensorboardX import SummaryWriter
34
+ from torch.autograd import Variable
35
+ from torch.optim.lr_scheduler import _LRScheduler
36
+ from tqdm import tqdm
37
+ from utils.logger import Logger
38
+
39
+
40
+ # seed = 0
41
+ # random.seed(seed)
42
+ # np.random.seed(seed)
43
+ # torch.manual_seed(seed)
44
+ # torch.cuda.manual_seed_all(seed)
45
+ # torch.use_deterministic_algorithms(True)
46
+ # torch.backends.cudnn.deterministic = True
47
+ # torch.backends.cudnn.benchmark = False
48
+
49
+
50
def init_seed(seed):
    """Seed every RNG used in training (python, numpy, torch CPU and CUDA)
    and force deterministic cuDNN kernels for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
57
+
58
+
59
def get_parser():
    """Build the argument parser for training/evaluation.

    Parameter priority at runtime: command line > config file > these defaults
    (the __main__ block re-installs YAML config values via set_defaults).
    """
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="train", help="must be train or test")

    # visualize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[60, 80],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # Fix: was "---beta" (triple dash); argparse strips all leading prefix
    # chars when deriving the dest, so the dest stays "beta" — only the
    # (previously unusable) CLI spelling changes.
    parser.add_argument(
        "--beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    # Fix: give these arguments real types; without them any CLI value
    # arrived as a (always-truthy) string.
    parser.add_argument("--only_train_part", type=str2bool, default=False)
    parser.add_argument("--only_train_epoch", type=int, default=0)
    parser.add_argument("--warm_up_epoch", type=int, default=10)

    # Fix: type=float so a CLI-supplied value can be multiplied with a loss.
    parser.add_argument(
        "--lambda-mil",
        type=float,
        default=1.0,
        help="balancing hyper-parameter of mil branch",
    )

    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
238
+
239
+
240
+ class Processor:
241
+ """
242
+ Processor for Skeleton-based Action Recgnition
243
+ """
244
+
245
+ def __init__(self, arg):
246
+ self.arg = arg
247
+ self.save_arg()
248
+ if arg.phase == "train":
249
+ if not arg.train_feeder_args["debug"]:
250
+ if os.path.isdir(arg.model_saved_name):
251
+ print("log_dir: ", arg.model_saved_name, "already exist")
252
+ # answer = input('delete it? y/n:')
253
+ answer = "y"
254
+ if answer == "y":
255
+ print("Deleting dir...")
256
+ shutil.rmtree(arg.model_saved_name)
257
+ print("Dir removed: ", arg.model_saved_name)
258
+ # input('Refresh the website of tensorboard by pressing any keys')
259
+ else:
260
+ print("Dir not removed: ", arg.model_saved_name)
261
+ self.train_writer = SummaryWriter(
262
+ os.path.join(arg.model_saved_name, "train"), "train"
263
+ )
264
+ self.val_writer = SummaryWriter(
265
+ os.path.join(arg.model_saved_name, "val"), "val"
266
+ )
267
+ else:
268
+ self.train_writer = self.val_writer = SummaryWriter(
269
+ os.path.join(arg.model_saved_name, "test"), "test"
270
+ )
271
+ self.global_step = 0
272
+ self.load_model()
273
+ self.load_optimizer()
274
+ self.load_data()
275
+ self.lr = self.arg.base_lr
276
+ self.best_acc = 0
277
+ self.best_per_class_acc = 0
278
+ self.loss_nce = torch.nn.BCELoss()
279
+
280
+ self.my_logger = Logger(
281
+ os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
282
+ )
283
+ self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)]+['avg'])
284
+
285
+ def load_data(self):
286
+ Feeder = import_class(self.arg.feeder)
287
+ self.data_loader = dict()
288
+ if self.arg.phase == "train":
289
+ self.data_loader["train"] = torch.utils.data.DataLoader(
290
+ dataset=Feeder(**self.arg.train_feeder_args),
291
+ batch_size=self.arg.batch_size,
292
+ shuffle=True,
293
+ num_workers=self.arg.num_worker,
294
+ drop_last=True,
295
+ collate_fn=collate_with_padding_multi_joint,
296
+ )
297
+ self.data_loader["test"] = torch.utils.data.DataLoader(
298
+ dataset=Feeder(**self.arg.test_feeder_args),
299
+ batch_size=self.arg.test_batch_size,
300
+ shuffle=False,
301
+ num_workers=self.arg.num_worker,
302
+ drop_last=False,
303
+ collate_fn=collate_with_padding_multi_joint,
304
+ )
305
+
306
+ def load_model(self):
307
+ output_device = (
308
+ self.arg.device[0] if type(self.arg.device) is list else self.arg.device
309
+ )
310
+ self.output_device = output_device
311
+ Model = import_class(self.arg.model)
312
+ shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
313
+ # print(Model)
314
+ self.model = Model(**self.arg.model_args).cuda(output_device)
315
+ # print(self.model)
316
+ self.loss_type = arg.loss
317
+
318
+ if self.arg.weights:
319
+ # self.global_step = int(arg.weights[:-3].split("-")[-1])
320
+ self.print_log("Load weights from {}.".format(self.arg.weights))
321
+ if ".pkl" in self.arg.weights:
322
+ with open(self.arg.weights, "r") as f:
323
+ weights = pickle.load(f)
324
+ else:
325
+ weights = torch.load(self.arg.weights)
326
+
327
+ weights = OrderedDict(
328
+ [
329
+ [k.split("module.")[-1], v.cuda(output_device)]
330
+ for k, v in weights.items()
331
+ ]
332
+ )
333
+
334
+ keys = list(weights.keys())
335
+ for w in self.arg.ignore_weights:
336
+ for key in keys:
337
+ if w in key:
338
+ if weights.pop(key, None) is not None:
339
+ self.print_log(
340
+ "Sucessfully Remove Weights: {}.".format(key)
341
+ )
342
+ else:
343
+ self.print_log("Can Not Remove Weights: {}.".format(key))
344
+
345
+ try:
346
+ self.model.load_state_dict(weights)
347
+ except:
348
+ state = self.model.state_dict()
349
+ diff = list(set(state.keys()).difference(set(weights.keys())))
350
+ print("Can not find these weights:")
351
+ for d in diff:
352
+ print(" " + d)
353
+ state.update(weights)
354
+ self.model.load_state_dict(state)
355
+
356
+ if type(self.arg.device) is list:
357
+ if len(self.arg.device) > 1:
358
+ self.model = nn.DataParallel(
359
+ self.model, device_ids=self.arg.device, output_device=output_device
360
+ )
361
+
362
+ def load_optimizer(self):
363
+ if self.arg.optimizer == "SGD":
364
+ self.optimizer = optim.SGD(
365
+ self.model.parameters(),
366
+ lr=self.arg.base_lr,
367
+ momentum=0.9,
368
+ nesterov=self.arg.nesterov,
369
+ weight_decay=self.arg.weight_decay,
370
+ )
371
+ elif self.arg.optimizer == "Adam":
372
+ self.optimizer = optim.Adam(
373
+ self.model.parameters(),
374
+ lr=self.arg.base_lr,
375
+ weight_decay=self.arg.weight_decay,
376
+ )
377
+ else:
378
+ raise ValueError()
379
+
380
+ def save_arg(self):
381
+ # save arg
382
+ arg_dict = vars(self.arg)
383
+ if not os.path.exists(self.arg.work_dir):
384
+ os.makedirs(self.arg.work_dir)
385
+ with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
386
+ yaml.dump(arg_dict, f)
387
+
388
+ def adjust_learning_rate(self, epoch):
389
+ if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
390
+ if epoch < self.arg.warm_up_epoch:
391
+ lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
392
+ else:
393
+ lr = self.arg.base_lr * (
394
+ 0.1 ** np.sum(epoch >= np.array(self.arg.step))
395
+ )
396
+ for param_group in self.optimizer.param_groups:
397
+ param_group["lr"] = lr
398
+
399
+ return lr
400
+ else:
401
+ raise ValueError()
402
+
403
+ def print_time(self):
404
+ localtime = time.asctime(time.localtime(time.time()))
405
+ self.print_log("Local current time : " + localtime)
406
+
407
+ def print_log(self, str, print_time=True):
408
+ if print_time:
409
+ localtime = time.asctime(time.localtime(time.time()))
410
+ str = "[ " + localtime + " ] " + str
411
+ print(str)
412
+ if self.arg.print_log:
413
+ with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
414
+ print(str, file=f)
415
+
416
+ def record_time(self):
417
+ self.cur_time = time.time()
418
+ return self.cur_time
419
+
420
+ def split_time(self):
421
+ split_time = time.time() - self.cur_time
422
+ self.record_time()
423
+ return split_time
424
+
425
def train(self, epoch, wb_dict, save_model=False):
    """Run one training epoch of the two-stream weakly/fully supervised model.

    Combines a video-level MIL BCE loss with a frame-level cross-entropy
    against dense targets, collects per-frame predictions to feed the
    dataset's pseudo-label update, and logs metrics to TensorBoard.

    Args:
        epoch: zero-based epoch index (printed one-based).
        wb_dict: metrics dict; "train loss"/"train acc" are added and the
            dict is returned.
        save_model: when True, a CPU copy of the (de-DataParallel-ed)
            state dict is saved to <model_saved_name><epoch>.pt.
    """
    self.model.train()
    self.print_log("Training epoch: {}".format(epoch + 1))
    loader = self.data_loader["train"]
    self.adjust_learning_rate(epoch)

    loss_value, batch_acc = [], []
    self.train_writer.add_scalar("epoch", epoch, self.global_step)
    self.record_time()
    # small epsilon initial values avoid divide-by-zero in timing ratios
    timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
    process = tqdm(loader)
    if self.arg.only_train_part:
        # freeze/unfreeze only the adjacency ("PA") parameters
        if epoch > self.arg.only_train_epoch:
            print("only train part, require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = True
        else:
            print("only train part, do not require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = False

    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    results = []
    indexs = []

    '''
    Switch to FULL supervision
    Dataloader->Feeder ->collate_with_padding_multi_joint
    '''

    for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
        process
    ):
        self.global_step += 1
        # get data onto the output device
        data = data.float().cuda(self.output_device)
        label = label.cuda(self.output_device)
        target = target.cuda(self.output_device)
        mask = mask.cuda(self.output_device)
        soft_label = soft_label.cuda(self.output_device)
        timer["dataloader"] += self.split_time()

        # flatten dense frame targets into one-hot form (5 classes)
        ground_truth_flat = target.view(-1)
        one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)

        indexs.extend(index.cpu().numpy().tolist())

        # append an always-on "background/ambiguous" column to the labels
        ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

        # forward: two parallel heads, each with video- and frame-level scores
        mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

        # video-level MIL BCE loss on both heads
        cls_mil_loss = self.loss_nce(mil_pred, ab_labels.float()) + self.loss_nce(
            mil_pred_2, ab_labels.float()
        )

        if epoch > -1:  # NOTE(review): always true — effectively unconditional
            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
            frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
            # soft_label = rearrange(soft_label, "n t c -> (n t) c")

            # mutual-view loss between the two heads plus down-weighted MIL
            loss = cls_mil_loss * 0.1 + mvl_loss(
                frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
            )

            # frame-level supervision on both heads
            loss += cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)

        # collect per-sample predictions (only the unpadded prefix of frames)
        for i in range(data.size(0)):
            frm_scr = frm_scrs[i]

            label_ = label[i].cpu().numpy()
            mask_ = mask[i].cpu().numpy()
            vid_len = mask_.sum()

            frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
            vid_pred = mil_pred[i].detach().cpu().numpy()

            results.append(frm_pred)

            vid_preds.append(vid_pred)
            frm_preds.append(frm_pred)
            vid_lens.append(vid_len)
            labels.append(label_)

        # backward
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_value.append(loss.data.item())
        timer["model"] += self.split_time()

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # feed frame predictions back to the dataset for pseudo-label refinement
    loader.dataset.label_update(results, indexs)

    # classification mAP over video-level predictions
    cmap = cmAP(vid_preds, labels)

    self.train_writer.add_scalar("acc", cmap, self.global_step)
    self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

    # statistics
    self.lr = self.optimizer.param_groups[0]["lr"]
    self.train_writer.add_scalar("lr", self.lr, self.global_step)
    timer["statistics"] += self.split_time()

    # statistics of time consumption and loss
    self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
    self.print_log("\tAcc score: {:.3f}%".format(cmap))

    # Log
    wb_dict["train loss"] = np.mean(loss_value)
    wb_dict["train acc"] = cmap

    if save_model:
        state_dict = self.model.state_dict()
        # strip any "module." DataParallel prefix and move tensors to CPU
        weights = OrderedDict(
            [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
        )

        torch.save(
            weights,
            self.arg.model_saved_name + str(epoch) + ".pt",
        )

    return wb_dict
572
+
573
@torch.no_grad()
def eval(self, epoch, wb_dict, loader_name=None, wrong_file=None, result_file=None):
    """Evaluate on the named loaders, computing classification and detection mAP.

    Fixes vs. original: accepts ``wrong_file``/``result_file`` — ``start()``
    passes them in the test phase, which previously raised TypeError (they
    are currently accepted but unused); and avoids the mutable default for
    ``loader_name``.

    Args:
        epoch: zero-based epoch index (printed one-based).
        wb_dict: metrics dict; "val loss"/"val acc" are added and the dict
            is returned.
        loader_name: list of data_loader keys to evaluate (default ["test"]).
        wrong_file / result_file: optional output paths supplied by the
            test-phase driver; retained for interface compatibility.
    """
    loader_name = ["test"] if loader_name is None else loader_name
    self.model.eval()
    self.print_log("Eval epoch: {}".format(epoch + 1))

    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    for ln in loader_name:
        loss_value = []
        step = 0
        process = tqdm(self.data_loader[ln])

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            mask = mask.cuda(self.output_device)

            # append an always-on extra column to the video-level labels
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward: two heads with video- and frame-level scores
            mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

            # frame-level localization loss inputs (5 classes, flattened)
            target = target.cuda(self.output_device)
            ground_truth_flat = target.view(-1)
            one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
            frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")

            cls_mil_loss = self.loss_nce(
                mil_pred, ab_labels.float()
            ) + self.loss_nce(mil_pred_2, ab_labels.float())

            loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)

            loss = cls_mil_loss * self.arg.lambda_mil + loss_co

            # frame-level supervision on both heads
            loss += cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)

            loss_value.append(loss.data.item())

            # collect per-sample predictions (unpadded frame prefix only)
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]
                vid_pred = mil_pred[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                vid_pred = vid_pred.cpu().numpy()

                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            step += 1

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # video-level classification mAP
    cmap = cmAP(vid_preds, labels)

    score = cmap
    loss = np.mean(loss_value)

    # detection mAP at multiple IoU thresholds
    dmap, iou = dsmAP(
        vid_preds,
        frm_preds,
        vid_lens,
        self.arg.test_feeder_args["data_path"],
        self.arg,
        multi=True,
    )

    print("Classification map %f" % cmap)
    for item in list(zip(iou, dmap)):
        print("Detection map @ %f = %f" % (item[0], item[1]))

    self.my_logger.append([epoch + 1, cmap] + dmap + [np.mean(dmap)])

    wb_dict["val loss"] = loss
    wb_dict["val acc"] = score

    if score > self.best_acc:
        self.best_acc = score

    print("Acc score: ", score, " model: ", self.arg.model_saved_name)
    if self.arg.phase == "train":
        self.val_writer.add_scalar("loss", loss, self.global_step)
        self.val_writer.add_scalar("acc", score, self.global_step)

    # NOTE: ln refers to the last evaluated loader
    self.print_log(
        "\tMean {} loss of {} batches: {}.".format(
            ln, len(self.data_loader[ln]), np.mean(loss_value)
        )
    )
    self.print_log("\tAcc score: {:.3f}%".format(score))

    return wb_dict
695
+
696
def start(self):
    """Entry point: run the training loop or a one-shot test evaluation.

    Train phase: trains every epoch, evaluates every 10th epoch, and
    checkpoints according to ``save_interval``. Test phase: loads nothing
    here (weights were loaded in __init__ via load_model) and runs eval once.
    """
    wb_dict = {}
    if self.arg.phase == "train":
        self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
        # resume-aware global step estimate
        # NOTE(review): dividing the number of batches by batch_size looks
        # inconsistent (len(loader) is already in batches) — confirm intent.
        self.global_step = (
            self.arg.start_epoch
            * len(self.data_loader["train"])
            / self.arg.batch_size
        )

        for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
            save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
                epoch + 1 == self.arg.num_epoch
            )
            wb_dict = {"lr": self.lr}

            # Train
            wb_dict = self.train(epoch, wb_dict, save_model=save_model)
            if epoch % 10 == 0:
                # Eval. on val set
                wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
            # Log stats. for this epoch
            print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))

        print(
            "best accuracy: ",
            self.best_acc,
            " model_name: ",
            self.arg.model_saved_name,
        )

    elif self.arg.phase == "test":
        if not self.arg.test_feeder_args["debug"]:
            wf = self.arg.model_saved_name + "_wrong.txt"
            rf = self.arg.model_saved_name + "_right.txt"
        else:
            wf = rf = None
        if self.arg.weights is None:
            raise ValueError("Please appoint --weights.")
        self.arg.print_log = False
        self.print_log("Model: {}.".format(self.arg.model))
        self.print_log("Weights: {}.".format(self.arg.weights))

        # NOTE(review): eval() as written above does not declare
        # wrong_file/result_file parameters — confirm its signature accepts
        # them, otherwise this call raises TypeError.
        wb_dict = self.eval(
            epoch=0,
            wb_dict=wb_dict,
            loader_name=["test"],
            wrong_file=wf,
            result_file=rf,
        )
        print("Inference metrics: ", wb_dict)
        self.print_log("Done.\n")
749
+
750
+
751
def str2bool(v):
    """argparse type converter: parse a yes/no style string into a bool.

    Raises argparse.ArgumentTypeError for unrecognized values.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
758
+
759
+
760
def import_class(name):
    """Resolve a dotted path ("pkg.mod.Class") to the named object."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
766
+
767
+
768
if __name__ == "__main__":
    parser = get_parser()

    # load arg from config file; YAML values become parser defaults, so
    # explicit command-line flags still take priority over the config
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        for k in default_arg.keys():
            # every config key must correspond to a declared argument
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
train_full_SSL.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ from __future__ import print_function
15
+
16
+ import argparse
17
+ import inspect
18
+ import os
19
+ import pdb
20
+ import pickle
21
+ import random
22
+ import re
23
+ import shutil
24
+ import time
25
+ from collections import *
26
+
27
+ import ipdb
28
+ import numpy as np
29
+
30
+ # torch
31
+ import torch
32
+ import torch.backends.cudnn as cudnn
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ import torch.optim as optim
36
+ import yaml
37
+ from einops import rearrange, reduce, repeat
38
+ from evaluation.classificationMAP import getClassificationMAP as cmAP
39
+ from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
40
+ from feeders.tools import collate_with_padding_multi_joint
41
+ from model.losses import cross_entropy_loss, mvl_loss
42
+ from sklearn.metrics import f1_score
43
+
44
+ # Custom
45
+ from tensorboardX import SummaryWriter
46
+ from torch.autograd import Variable
47
+ from torch.optim.lr_scheduler import _LRScheduler
48
+ from tqdm import tqdm
49
+ from utils.logger import Logger
50
+
51
+ def remove_prefix_from_state_dict(state_dict, prefix):
52
+ new_state_dict = {}
53
+ for k, v in state_dict.items():
54
+ if k.startswith(prefix):
55
+ new_k = k[len(prefix):] # strip the prefix
56
+ else:
57
+ new_k = k
58
+ new_state_dict[new_k] = v
59
+ return new_state_dict
60
+
61
+
62
+ def init_seed(seed):
63
+ torch.cuda.manual_seed_all(seed)
64
+ torch.manual_seed(seed)
65
+ np.random.seed(seed)
66
+ random.seed(seed)
67
+ torch.backends.cudnn.deterministic = True
68
+ torch.backends.cudnn.benchmark = False
69
+
70
+
71
+ def get_parser():
72
+ # parameter priority: command line > config > default
73
+ parser = argparse.ArgumentParser(
74
+ description="Spatial Temporal Graph Convolution Network"
75
+ )
76
+ parser.add_argument(
77
+ "--work-dir",
78
+ default="./work_dir/temp",
79
+ help="the work folder for storing results",
80
+ )
81
+
82
+ parser.add_argument("-model_saved_name", default="")
83
+ parser.add_argument(
84
+ "--config",
85
+ default="./config/nturgbd-cross-view/test_bone.yaml",
86
+ help="path to the configuration file",
87
+ )
88
+
89
+ # processor
90
+ parser.add_argument("--phase", default="train", help="must be train or test")
91
+
92
+ # visulize and debug
93
+ parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
94
+ parser.add_argument(
95
+ "--log-interval",
96
+ type=int,
97
+ default=100,
98
+ help="the interval for printing messages (#iteration)",
99
+ )
100
+ parser.add_argument(
101
+ "--save-interval",
102
+ type=int,
103
+ default=2,
104
+ help="the interval for storing models (#iteration)",
105
+ )
106
+ parser.add_argument(
107
+ "--eval-interval",
108
+ type=int,
109
+ default=5,
110
+ help="the interval for evaluating models (#iteration)",
111
+ )
112
+ parser.add_argument(
113
+ "--print-log", type=str2bool, default=True, help="print logging or not"
114
+ )
115
+ parser.add_argument(
116
+ "--show-topk",
117
+ type=int,
118
+ default=[1, 5],
119
+ nargs="+",
120
+ help="which Top K accuracy will be shown",
121
+ )
122
+
123
+ # feeder
124
+ parser.add_argument(
125
+ "--feeder", default="feeder.feeder", help="data loader will be used"
126
+ )
127
+ parser.add_argument(
128
+ "--num-worker",
129
+ type=int,
130
+ default=32,
131
+ help="the number of worker for data loader",
132
+ )
133
+ parser.add_argument(
134
+ "--train-feeder-args",
135
+ default=dict(),
136
+ help="the arguments of data loader for training",
137
+ )
138
+ parser.add_argument(
139
+ "--test-feeder-args",
140
+ default=dict(),
141
+ help="the arguments of data loader for test",
142
+ )
143
+
144
+ # model
145
+ parser.add_argument("--model", default=None, help="the model will be used")
146
+ parser.add_argument(
147
+ "--model-args", type=dict, default=dict(), help="the arguments of model"
148
+ )
149
+ parser.add_argument(
150
+ "--weights", default=None, help="the weights for network initialization"
151
+ )
152
+ parser.add_argument(
153
+ "--ignore-weights",
154
+ type=str,
155
+ default=[],
156
+ nargs="+",
157
+ help="the name of weights which will be ignored in the initialization",
158
+ )
159
+
160
+ # optim
161
+ parser.add_argument(
162
+ "--base-lr", type=float, default=0.01, help="initial learning rate"
163
+ )
164
+ parser.add_argument(
165
+ "--step",
166
+ type=int,
167
+ default=[200],
168
+ nargs="+",
169
+ help="the epoch where optimizer reduce the learning rate",
170
+ )
171
+
172
+ # training
173
+ parser.add_argument(
174
+ "--device",
175
+ type=int,
176
+ default=0,
177
+ nargs="+",
178
+ help="the indexes of GPUs for training or testing",
179
+ )
180
+ parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
181
+ parser.add_argument(
182
+ "--nesterov", type=str2bool, default=False, help="use nesterov or not"
183
+ )
184
+ parser.add_argument(
185
+ "--batch-size", type=int, default=256, help="training batch size"
186
+ )
187
+ parser.add_argument(
188
+ "--test-batch-size", type=int, default=256, help="test batch size"
189
+ )
190
+ parser.add_argument(
191
+ "--start-epoch", type=int, default=0, help="start training from which epoch"
192
+ )
193
+ parser.add_argument(
194
+ "--num-epoch", type=int, default=80, help="stop training in which epoch"
195
+ )
196
+ parser.add_argument(
197
+ "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
198
+ )
199
+ # loss
200
+ parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
201
+ parser.add_argument(
202
+ "--label_count_path",
203
+ default=None,
204
+ type=str,
205
+ help="Path to label counts (used in loss weighting)",
206
+ )
207
+ parser.add_argument(
208
+ "---beta",
209
+ type=float,
210
+ default=0.9999,
211
+ help="Hyperparameter for Class balanced loss",
212
+ )
213
+ parser.add_argument(
214
+ "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
215
+ )
216
+
217
+ parser.add_argument("--only_train_part", default=False)
218
+ parser.add_argument("--only_train_epoch", default=0)
219
+ parser.add_argument("--warm_up_epoch", default=10)
220
+
221
+ parser.add_argument(
222
+ "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
223
+ )
224
+
225
+ parser.add_argument(
226
+ "--class-threshold",
227
+ type=float,
228
+ default=0.1,
229
+ help="class threshold for rejection",
230
+ )
231
+ parser.add_argument(
232
+ "--start-threshold",
233
+ type=float,
234
+ default=0.03,
235
+ help="start threshold for action localization",
236
+ )
237
+ parser.add_argument(
238
+ "--end-threshold",
239
+ type=float,
240
+ default=0.055,
241
+ help="end threshold for action localization",
242
+ )
243
+ parser.add_argument(
244
+ "--threshold-interval",
245
+ type=float,
246
+ default=0.005,
247
+ help="threshold interval for action localization",
248
+ )
249
+ return parser
250
+
251
+
252
class Processor:
    """
    Processor for Skeleton-based Action Recognition
    """

    def __init__(self, arg):
        """Set up writers, model, optimizer, data loaders, and the metric logger.

        In the train phase (non-debug), any existing log dir at
        ``model_saved_name`` is deleted automatically (the interactive
        prompt is hard-coded to "y").
        """
        self.arg = arg
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # debug mode shares a single writer for train and val
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        # MIL video-level loss (expects probabilities)
        self.loss_nce = torch.nn.BCELoss()

        # CSV-style metric logger: step, classification mAP, detection mAP
        # at IoU 0.1..0.5, and their average
        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)] + ["avg"])
296
+
297
def load_data(self):
    """Build the train/test DataLoaders from the configured Feeder class.

    The train loader exists only in the "train" phase; the test loader is
    always built (eval() needs it in both phases). Both use the multi-joint
    padding collate function so variable-length skeleton sequences can be
    batched.
    """
    Feeder = import_class(self.arg.feeder)
    self.data_loader = dict()
    if self.arg.phase == "train":
        self.data_loader["train"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.train_feeder_args),
            batch_size=self.arg.batch_size,
            shuffle=True,
            num_workers=self.arg.num_worker,
            drop_last=True,
            collate_fn=collate_with_padding_multi_joint,
        )
    self.data_loader["test"] = torch.utils.data.DataLoader(
        dataset=Feeder(**self.arg.test_feeder_args),
        batch_size=self.arg.test_batch_size,
        shuffle=False,
        num_workers=self.arg.num_worker,
        drop_last=False,
        collate_fn=collate_with_padding_multi_joint,
    )
317
+
318
def load_model(self):
    """Instantiate the model on the output device and optionally warm-start
    it from ``--weights``.

    Fixes vs. original: reads the loss type from ``self.arg`` instead of the
    module-level ``arg`` global, opens pickled checkpoints in binary mode
    ("rb", not "r"), and narrows the bare ``except:`` to ``Exception``.
    """
    output_device = (
        self.arg.device[0] if type(self.arg.device) is list else self.arg.device
    )
    self.output_device = output_device
    Model = import_class(self.arg.model)
    # keep a copy of the model source next to the run for reproducibility
    shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
    self.model = Model(**self.arg.model_args).cuda(output_device)
    self.loss_type = self.arg.loss  # was: arg.loss (module-level global)

    if self.arg.weights:
        self.print_log("Load weights from {}.".format(self.arg.weights))
        if ".pkl" in self.arg.weights:
            # pickled checkpoints are binary data
            with open(self.arg.weights, "rb") as f:
                weights = pickle.load(f)
        else:
            weights = torch.load(self.arg.weights)

        # strip any "module." DataParallel prefix, move to the output device
        weights = OrderedDict(
            [
                [k.split("module.")[-1], v.cuda(output_device)]
                for k, v in weights.items()
            ]
        )
        # strip the SSL pretraining encoder prefix so keys match this model
        weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
        keys = list(weights.keys())

        # drop pretraining-only tensors with no counterpart in this model
        self.arg.ignore_weights = ['data_bn', 'fc', 'encoder_q', 'encoder_k', 'queue', 'queue_ptr', 'value_transform']
        for w in self.arg.ignore_weights:
            for key in keys:
                if w in key:
                    weights.pop(key, None)

        try:
            self.model.load_state_dict(weights)
        except Exception:
            # partial load: report what is missing, then load the union
            state = self.model.state_dict()
            diff = list(set(state.keys()).difference(set(weights.keys())))
            print("Can not find these weights:")
            for d in diff:
                print("  " + d)
            state.update(weights)
            self.model.load_state_dict(state)

    if type(self.arg.device) is list:
        if len(self.arg.device) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.arg.device, output_device=output_device
            )
374
+ )
375
+
376
def load_optimizer(self):
    """Build self.optimizer from self.arg; supports SGD (momentum 0.9,
    optional Nesterov) and Adam, else raises ValueError."""
    chosen = self.arg.optimizer
    if chosen == "SGD":
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.arg.base_lr,
            momentum=0.9,
            nesterov=self.arg.nesterov,
            weight_decay=self.arg.weight_decay,
        )
    elif chosen == "Adam":
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.arg.base_lr,
            weight_decay=self.arg.weight_decay,
        )
    else:
        raise ValueError()
393
+
394
def save_arg(self):
    """Serialize the parsed run configuration to <work_dir>/config.yaml.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of an
    ``exists()``-then-``makedirs()`` pair, which could raise if two runs
    raced on creating the same work dir.
    """
    arg_dict = vars(self.arg)
    os.makedirs(self.arg.work_dir, exist_ok=True)
    with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
        yaml.dump(arg_dict, f)
401
+
402
def adjust_learning_rate(self, epoch):
    """Set the optimizer LR for this epoch and return it.

    Linear warm-up for the first ``warm_up_epoch`` epochs, then step decay:
    the base LR is multiplied by 0.1 for each milestone in ``arg.step``
    already passed. Raises ValueError for unsupported optimizers.
    """
    if self.arg.optimizer not in ("SGD", "Adam"):
        raise ValueError()
    if epoch < self.arg.warm_up_epoch:
        new_lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
    else:
        decay_steps = np.sum(epoch >= np.array(self.arg.step))
        new_lr = self.arg.base_lr * (0.1 ** decay_steps)
    for group in self.optimizer.param_groups:
        group["lr"] = new_lr
    return new_lr
416
+
417
def print_time(self):
    """Log the current wall-clock time through print_log."""
    now = time.asctime(time.localtime(time.time()))
    self.print_log("Local current time : " + now)
420
+
421
def print_log(self, message, print_time=True):
    """Print a log line and mirror it to <work_dir>/print_log.txt.

    Args:
        message: text to log. (Originally named ``str``, which shadowed the
            builtin; renamed — every in-file caller passes it positionally.)
        print_time: when True, prefix the line with a human-readable
            timestamp.
    """
    if print_time:
        localtime = time.asctime(time.localtime(time.time()))
        message = "[ " + localtime + " ] " + message
    print(message)
    if self.arg.print_log:
        # append so logs accumulate across epochs/restarts
        with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
            print(message, file=f)
429
+
430
def record_time(self):
    """Stamp the current time on self.cur_time and return it."""
    now = time.time()
    self.cur_time = now
    return now
433
+
434
def split_time(self):
    """Return seconds elapsed since the last record_time(), then re-stamp."""
    elapsed = time.time() - self.cur_time
    self.record_time()
    return elapsed
438
+
439
def train(self, epoch, wb_dict, save_model=False):
    """Run one fully-supervised training epoch (single-head variant).

    Only the frame-level cross-entropy against dense targets is trained
    here — the MIL branch of the two-head version has been cut, so the
    model forward returns frame scores only, video-level predictions are
    stubbed to 0, and cmap is hard-coded to 0.

    Args:
        epoch: zero-based epoch index (printed one-based).
        wb_dict: metrics dict; "train loss"/"train acc" are added and the
            dict is returned.
        save_model: when True, a CPU copy of the (de-DataParallel-ed)
            state dict is saved to <model_saved_name><epoch>.pt.
    """
    self.model.train()
    self.print_log("Training epoch: {}".format(epoch + 1))
    loader = self.data_loader["train"]
    self.adjust_learning_rate(epoch)

    loss_value, batch_acc = [], []
    self.train_writer.add_scalar("epoch", epoch, self.global_step)
    self.record_time()
    # small epsilon initial values avoid divide-by-zero in timing ratios
    timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
    process = tqdm(loader)
    if self.arg.only_train_part:
        # freeze/unfreeze only the adjacency ("PA") parameters
        if epoch > self.arg.only_train_epoch:
            print("only train part, require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = True
        else:
            print("only train part, do not require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = False

    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    results = []
    indexs = []

    '''
    Switch to FULL supervision
    Dataloader->Feeder -> collate_with_padding_multi_joint
    '''

    for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
        process
    ):
        self.global_step += 1
        # get data onto the output device
        data = data.float().cuda(self.output_device)
        label = label.cuda(self.output_device)
        target = target.cuda(self.output_device)
        mask = mask.cuda(self.output_device)
        soft_label = soft_label.cuda(self.output_device)
        timer["dataloader"] += self.split_time()

        # flatten dense frame targets into one-hot form (5 classes)
        ground_truth_flat = target.view(-1)
        one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)

        indexs.extend(index.cpu().numpy().tolist())

        # NOTE(review): ab_labels is computed but never used in this
        # variant (the MIL loss was removed) — dead code kept verbatim.
        ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

        # forward: frame-level scores only
        frm_scrs = self.model(data)

        if epoch > -1:  # NOTE(review): always true — effectively unconditional
            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")

            # frame-level supervision only
            loss = cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            )

        # collect per-sample predictions (only the unpadded prefix of frames)
        for i in range(data.size(0)):
            frm_scr = frm_scrs[i]

            label_ = label[i].cpu().numpy()
            mask_ = mask[i].cpu().numpy()
            vid_len = mask_.sum()

            frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]

            # no MIL head in this variant — video-level prediction stubbed
            vid_pred = 0
            results.append(frm_pred)

            vid_preds.append(vid_pred)
            frm_preds.append(frm_pred)
            vid_lens.append(vid_len)
            labels.append(label_)

        # backward
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_value.append(loss.data.item())
        timer["model"] += self.split_time()

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # feed frame predictions back to the dataset for pseudo-label refinement
    loader.dataset.label_update(results, indexs)

    # classification mAP disabled for this variant (no video-level head)
    cmap = 0

    self.train_writer.add_scalar("acc", cmap, self.global_step)
    self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

    # statistics
    self.lr = self.optimizer.param_groups[0]["lr"]
    self.train_writer.add_scalar("lr", self.lr, self.global_step)
    timer["statistics"] += self.split_time()

    # statistics of time consumption and loss
    self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
    self.print_log("\tAcc score: {:.3f}%".format(cmap))

    # Log
    wb_dict["train loss"] = np.mean(loss_value)
    wb_dict["train acc"] = cmap

    if save_model:
        state_dict = self.model.state_dict()
        # strip any "module." DataParallel prefix and move tensors to CPU
        weights = OrderedDict(
            [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
        )

        torch.save(
            weights,
            self.arg.model_saved_name + str(epoch) + ".pt",
        )

    return wb_dict
579
+
580
@torch.no_grad()
def eval(self, epoch, wb_dict, loader_name=None, wrong_file=None, result_file=None):
    """Evaluate the fully-supervised (single-head) model.

    Computes the frame-level cross-entropy loss and detection mAP; the
    classification mAP is stubbed to 0 because this variant has no
    video-level head.

    Fixes vs. original: accepts ``wrong_file``/``result_file`` — ``start()``
    passes them in the test phase, which previously raised TypeError (they
    are currently accepted but unused); and avoids the mutable default for
    ``loader_name``.

    Args:
        epoch: zero-based epoch index (printed one-based).
        wb_dict: metrics dict; "val loss"/"val acc" are added and the dict
            is returned.
        loader_name: list of data_loader keys to evaluate (default ["test"]).
        wrong_file / result_file: optional output paths supplied by the
            test-phase driver; retained for interface compatibility.
    """
    loader_name = ["test"] if loader_name is None else loader_name
    self.model.eval()
    self.print_log("Eval epoch: {}".format(epoch + 1))

    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    for ln in loader_name:
        loss_value = []
        step = 0
        process = tqdm(self.data_loader[ln])

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            mask = mask.cuda(self.output_device)

            # NOTE: computed but unused in this variant (no MIL head)
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward: frame-level scores only
            frm_scrs = self.model(data)

            # frame-level localization loss against dense targets
            target = target.cuda(self.output_device)
            ground_truth_flat = target.view(-1)
            one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
            loss = cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            )

            loss_value.append(loss.data.item())

            # collect per-sample predictions (unpadded frame prefix only)
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]

                # no video-level head — stub prediction
                vid_pred = 0
                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            step += 1

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # classification mAP disabled for this variant
    cmap = 0
    score = cmap
    loss = np.mean(loss_value)

    # detection mAP at multiple IoU thresholds
    dmap, iou = dsmAP(
        vid_preds,
        frm_preds,
        vid_lens,
        self.arg.test_feeder_args["data_path"],
        self.arg,
        multi=True,
    )

    print("Classification map %f" % cmap)
    for item in list(zip(iou, dmap)):
        print("Detection map @ %f = %f" % (item[0], item[1]))

    self.my_logger.append([epoch + 1, cmap] + dmap + [np.mean(dmap)])

    wb_dict["val loss"] = loss
    wb_dict["val acc"] = score

    if score > self.best_acc:
        self.best_acc = score

    print("Acc score: ", score, " model: ", self.arg.model_saved_name)
    if self.arg.phase == "train":
        self.val_writer.add_scalar("loss", loss, self.global_step)
        self.val_writer.add_scalar("acc", score, self.global_step)

    # NOTE: ln refers to the last evaluated loader
    self.print_log(
        "\tMean {} loss of {} batches: {}.".format(
            ln, len(self.data_loader[ln]), np.mean(loss_value)
        )
    )
    self.print_log("\tAcc score: {:.3f}%".format(score))

    return wb_dict
691
+
692
def start(self):
    """Entry point: run the training loop or a one-shot test evaluation.

    Train phase: trains every epoch and (here) evaluates after every epoch.
    Test phase: requires --weights and runs eval once.
    """
    wb_dict = {}
    if self.arg.phase == "train":
        self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
        # resume-aware global step estimate
        # NOTE(review): dividing the number of batches by batch_size looks
        # inconsistent (len(loader) is already in batches) — confirm intent.
        self.global_step = (
            self.arg.start_epoch
            * len(self.data_loader["train"])
            / self.arg.batch_size
        )

        for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
            save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
                epoch + 1 == self.arg.num_epoch
            )
            wb_dict = {"lr": self.lr}

            # Train
            wb_dict = self.train(epoch, wb_dict, save_model=save_model)
            if epoch % 1 == 0:  # NOTE(review): always true — evaluates every epoch
                # Eval. on val set
                wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
            # Log stats. for this epoch
            print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))

        print(
            "best accuracy: ",
            self.best_acc,
            " model_name: ",
            self.arg.model_saved_name,
        )

    elif self.arg.phase == "test":
        if not self.arg.test_feeder_args["debug"]:
            wf = self.arg.model_saved_name + "_wrong.txt"
            rf = self.arg.model_saved_name + "_right.txt"
        else:
            wf = rf = None
        if self.arg.weights is None:
            raise ValueError("Please appoint --weights.")
        self.arg.print_log = False
        self.print_log("Model: {}.".format(self.arg.model))
        self.print_log("Weights: {}.".format(self.arg.weights))

        # NOTE(review): eval() as written above does not declare
        # wrong_file/result_file parameters — confirm its signature accepts
        # them, otherwise this call raises TypeError.
        wb_dict = self.eval(
            epoch=0,
            wb_dict=wb_dict,
            loader_name=["test"],
            wrong_file=wf,
            result_file=rf,
        )
        print("Inference metrics: ", wb_dict)
        self.print_log("Done.\n")
745
+
746
+
747
def str2bool(v):
    """argparse type converter: parse a yes/no style string into a bool.

    Raises argparse.ArgumentTypeError for unrecognized values.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
754
+
755
+
756
def import_class(name):
    """Resolve a dotted path ("pkg.mod.Class") to the named object."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
762
+
763
+
764
if __name__ == "__main__":
    parser = get_parser()

    # load arg from config file; YAML values become parser defaults, so
    # explicit command-line flags still take priority over the config
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        for k in default_arg.keys():
            # every config key must correspond to a declared argument
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
train_full_SSL_Unet.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright 2023 LINE Corporation
3
+ LINE Corporation licenses this file to you under the Apache License,
4
+ version 2.0 (the "License"); you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License at:
6
+ https://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10
+ License for the specific language governing permissions and limitations
11
+ under the License.
12
+ """
13
+
14
+ from __future__ import print_function
15
+
16
+ import argparse
17
+ import inspect
18
+ import os
19
+ import pdb
20
+ import pickle
21
+ import random
22
+ import re
23
+ import shutil
24
+ import time
25
+ from collections import *
26
+
27
+ import ipdb
28
+ import numpy as np
29
+
30
+ # torch
31
+ import torch
32
+ import torch.backends.cudnn as cudnn
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ import torch.optim as optim
36
+ import yaml
37
+ from einops import rearrange, reduce, repeat
38
+ from evaluation.classificationMAP import getClassificationMAP as cmAP
39
+ from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
40
+ from feeders.tools import collate_with_padding_multi_joint
41
+ from model.losses import cross_entropy_loss, mvl_loss
42
+ from sklearn.metrics import f1_score
43
+
44
+ # Custom
45
+ from tensorboardX import SummaryWriter
46
+ from torch.autograd import Variable
47
+ from torch.optim.lr_scheduler import _LRScheduler
48
+ from tqdm import tqdm
49
+ from utils.logger import Logger
50
+
51
def remove_prefix_from_state_dict(state_dict, prefix):
    """Return a copy of *state_dict* with *prefix* stripped from matching keys.

    Keys that do not start with *prefix* are kept unchanged; values are
    passed through as-is.
    """
    stripped = {}
    for key, value in state_dict.items():
        target = key[len(prefix):] if key.startswith(prefix) else key
        stripped[target] = value
    return stripped
60
+
61
+
62
def init_seed(seed):
    """Seed every RNG the pipeline touches and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
70
+
71
def get_parser():
    """Build the argparse parser for the training/testing pipeline.

    Parameter priority: command line > config file > defaults (the config
    file only overrides parser defaults in the __main__ block).
    """
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    # NOTE(review): single dash with a long name — the flag is `-model_saved_name`.
    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="train", help="must be train or test")

    # visualize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    # NOTE(review): dict-valued args are expected to come from the YAML config,
    # not the command line (argparse would pass a string from the CLI).
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[200],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # NOTE(review): triple dash — the CLI flag is literally `---beta`
    # (dest resolves to `beta`); confirm this is intended before changing,
    # since existing configs/scripts may rely on it.
    parser.add_argument(
        "---beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    # NOTE(review): no type= given, so any non-empty CLI string (even "False")
    # is truthy; these are presumably config-driven — confirm.
    parser.add_argument("--only_train_part", default=False)
    parser.add_argument("--only_train_epoch", default=0)
    parser.add_argument("--warm_up_epoch", default=10)

    # NOTE(review): no type=float, so a CLI-supplied value stays a string.
    parser.add_argument(
        "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
    )

    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
250
+
251
+
252
class Processor:
    """
    Processor for Skeleton-based Action Recognition.
    """

    def __init__(self, arg):
        """Set up logging dirs, model, optimizer, data loaders and metric
        bookkeeping from the parsed argument namespace *arg*."""
        self.arg = arg
        # Persist the full config to work_dir before anything can fail.
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    # Auto-confirms deletion of a previous run's tensorboard logs.
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # Debug mode: one shared writer for train and val.
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0  # best classification mAP seen so far
        self.best_per_class_acc = 0
        self.loss_nce = torch.nn.BCELoss()  # video-level MIL classification loss

        # Tab-separated metric log: step, classification mAP, detection mAP
        # at IoU 0.1..0.5, and their average.
        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)]+ ['avg'])
296
+
297
+
298
+
299
+
300
    def load_data(self):
        """Create the train/test DataLoaders with fully seeded workers."""

        seed = self.arg.seed if hasattr(self.arg, "seed") else 42

        def seed_worker(worker_id):
            # Pin each dataloader worker's RNGs so augmentation is reproducible.
            worker_seed = seed + worker_id
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        # Seeded generator controls the shuffling order.
        g = torch.Generator()
        g.manual_seed(seed)

        Feeder = import_class(self.arg.feeder)

        self.data_loader = dict()
        if self.arg.phase == "train":
            self.data_loader["train"] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                collate_fn=collate_with_padding_multi_joint,
                worker_init_fn=seed_worker,  # pin each worker's seed
                generator=g
            )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            collate_fn=collate_with_padding_multi_joint,
            worker_init_fn=seed_worker,  # pin each worker's seed
            generator=g
        )
336
+ )
337
+
338
+ def load_model(self):
339
+ output_device = (
340
+ self.arg.device[0] if type(self.arg.device) is list else self.arg.device
341
+ )
342
+ self.output_device = output_device
343
+ Model = import_class(self.arg.model)
344
+ shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
345
+ # print(Model)
346
+ self.model = Model(**self.arg.model_args).cuda(output_device)
347
+ # print(self.model)
348
+ self.loss_type = arg.loss
349
+
350
+ if self.arg.weights:
351
+ # if False:
352
+ # self.global_step = int(arg.weights[:-3].split("-")[-1])
353
+ self.print_log("Load weights from {}.".format(self.arg.weights))
354
+ if ".pkl" in self.arg.weights:
355
+ with open(self.arg.weights, "r") as f:
356
+ weights = pickle.load(f)
357
+ else:
358
+ weights = torch.load(self.arg.weights)
359
+
360
+ weights = OrderedDict(
361
+ [
362
+ [k.split("module.")[-1], v.cuda(output_device)]
363
+ for k, v in weights.items()
364
+ ]
365
+ )
366
+ weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
367
+ keys = list(weights.keys())
368
+ self.arg.ignore_weights = ['data_bn','fc','encoder_q','encoder_k','queue','queue_ptr','value_transform']
369
+ for w in self.arg.ignore_weights:
370
+ for key in keys:
371
+ if w in key:
372
+ if weights.pop(key, None) is not None:
373
+ self.print_log(
374
+ "Sucessfully Remove Weights: {}.".format(key)
375
+ )
376
+ else:
377
+ self.print_log("Can Not Remove Weights: {}.".format(key))
378
+
379
+ try:
380
+ self.model.load_state_dict(weights)
381
+ except:
382
+ state = self.model.state_dict()
383
+ diff = list(set(state.keys()).difference(set(weights.keys())))
384
+ print("Can not find these weights:")
385
+ for d in diff:
386
+ print(" " + d)
387
+ state.update(weights)
388
+ self.model.load_state_dict(state)
389
+
390
+ if type(self.arg.device) is list:
391
+ if len(self.arg.device) > 1:
392
+ self.model = nn.DataParallel(
393
+ self.model, device_ids=self.arg.device, output_device=output_device
394
+ )
395
+
396
+ def load_optimizer(self):
397
+ if self.arg.optimizer == "SGD":
398
+ self.optimizer = optim.SGD(
399
+ self.model.parameters(),
400
+ lr=self.arg.base_lr,
401
+ momentum=0.9,
402
+ nesterov=self.arg.nesterov,
403
+ weight_decay=self.arg.weight_decay,
404
+ )
405
+ elif self.arg.optimizer == "Adam":
406
+ self.optimizer = optim.Adam(
407
+ self.model.parameters(),
408
+ lr=self.arg.base_lr,
409
+ weight_decay=self.arg.weight_decay,
410
+ )
411
+ else:
412
+ raise ValueError()
413
+
414
+ def save_arg(self):
415
+ # save arg
416
+ arg_dict = vars(self.arg)
417
+ if not os.path.exists(self.arg.work_dir):
418
+ os.makedirs(self.arg.work_dir)
419
+ with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
420
+ yaml.dump(arg_dict, f)
421
+
422
+ def adjust_learning_rate(self, epoch):
423
+ if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
424
+ if epoch < self.arg.warm_up_epoch:
425
+ lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
426
+ else:
427
+ lr = self.arg.base_lr * (
428
+ 0.1 ** np.sum(epoch >= np.array(self.arg.step))
429
+ )
430
+ for param_group in self.optimizer.param_groups:
431
+ param_group["lr"] = lr
432
+
433
+ return lr
434
+ else:
435
+ raise ValueError()
436
+
437
    def print_time(self):
        """Log the current wall-clock time in human-readable form."""
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log("Local current time : " + localtime)
440
+
441
+ def print_log(self, str, print_time=True):
442
+ if print_time:
443
+ localtime = time.asctime(time.localtime(time.time()))
444
+ str = "[ " + localtime + " ] " + str
445
+ print(str)
446
+ if self.arg.print_log:
447
+ with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
448
+ print(str, file=f)
449
+
450
+ def record_time(self):
451
+ self.cur_time = time.time()
452
+ return self.cur_time
453
+
454
+ def split_time(self):
455
+ split_time = time.time() - self.cur_time
456
+ self.record_time()
457
+ return split_time
458
+
459
    def train(self, epoch, wb_dict, save_model=False):
        """Run one fully-supervised training epoch.

        Combines a down-weighted video-level MIL loss with frame-level
        cross-entropy on both prediction streams plus a mutual-view loss,
        updates the feeder's soft labels, logs metrics to tensorboard, and
        optionally checkpoints the model. Returns *wb_dict* augmented with
        "train loss" and "train acc".
        """
        self.model.train()
        self.print_log("Training epoch: {}".format(epoch + 1))
        loader = self.data_loader["train"]
        self.adjust_learning_rate(epoch)

        loss_value, batch_acc = [], []
        self.train_writer.add_scalar("epoch", epoch, self.global_step)
        self.record_time()
        # Small epsilon seeds avoid divide-by-zero in timing ratios.
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
        process = tqdm(loader)
        if self.arg.only_train_part:
            # Freeze/unfreeze the learnable adjacency ("PA") parameters
            # depending on the current epoch.
            if epoch > self.arg.only_train_epoch:
                print("only train part, require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = True
            else:
                print("only train part, do not require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = False

        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        results = []
        indexs = []

        '''
        Switch to FULL supervision
        Dataloader->Feeder -> collate_with_padding_multi_joint
        '''

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):

            self.global_step += 1
            # get data
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            target = target.cuda(self.output_device)
            mask = mask.cuda(self.output_device)
            soft_label = soft_label.cuda(self.output_device)
            timer["dataloader"] += self.split_time()

            ''' into one hot'''
            # Frame-level ground truth, flattened over (batch, time).
            # NOTE(review): num_classes is hard-coded to 5 — presumably the
            # dataset's class count; confirm against the feeder.
            ground_truth_flat = target.view(-1)
            one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
            ''' into one hot'''

            indexs.extend(index.cpu().numpy().tolist())

            # Append an always-on background/ambiguity column to video labels.
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward
            mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)

            # Video-level MIL loss on both streams.
            cls_mil_loss = self.loss_nce(mil_pred.float(), ab_labels.float()) + self.loss_nce(
                mil_pred_2.float(), ab_labels.float()
            )

            # NOTE(review): this condition is always true; kept as-is.
            if epoch > -1:

                frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
                frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
                soft_label = rearrange(soft_label, "n t c -> (n t) c")

                # Down-weighted MIL loss + mutual-view loss between streams.
                loss = cls_mil_loss * 0.1 + mvl_loss(
                    frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                )

                # Fully-supervised frame-level cross entropy on both streams.
                loss += cross_entropy_loss(
                    frm_scrs_re, one_hot_ground_truth
                ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)

            # Collect per-video predictions (truncated to the real length)
            # for mAP computation and feeder label updates.
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
                vid_pred = mil_pred[i].detach().cpu().numpy()

                results.append(frm_pred)

                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_value.append(loss.data.item())
            timer["model"] += self.split_time()

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Push this epoch's frame predictions back into the feeder
        # (self-training style soft-label refresh).
        loader.dataset.label_update(results, indexs)

        cmap = cmAP(vid_preds, labels)

        self.train_writer.add_scalar("acc", cmap, self.global_step)
        self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

        # statistics
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.train_writer.add_scalar("lr", self.lr, self.global_step)
        timer["statistics"] += self.split_time()

        # statistics of time consumption and loss
        self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
        self.print_log("\tAcc score: {:.3f}%".format(cmap))

        # Log
        wb_dict["train loss"] = np.mean(loss_value)
        wb_dict["train acc"] = cmap

        if save_model:
            # Strip any DataParallel "module." prefix and save on CPU.
            state_dict = self.model.state_dict()
            weights = OrderedDict(
                [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
            )

            torch.save(
                weights,
                self.arg.model_saved_name + str(epoch) + ".pt",
            )

        return wb_dict
601
+
602
+ @torch.no_grad()
603
+ def eval(
604
+ self,
605
+ epoch,
606
+ wb_dict,
607
+ loader_name=["test"],
608
+ ):
609
+ self.model.eval()
610
+ self.print_log("Eval epoch: {}".format(epoch + 1))
611
+
612
+ vid_preds = []
613
+ frm_preds = []
614
+ vid_lens = []
615
+ labels = []
616
+
617
+ for ln in loader_name:
618
+ loss_value = []
619
+ step = 0
620
+ process = tqdm(self.data_loader[ln])
621
+
622
+ for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
623
+ process
624
+ ):
625
+ data = data.float().cuda(self.output_device)
626
+ label = label.cuda(self.output_device)
627
+ mask = mask.cuda(self.output_device)
628
+
629
+ ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)
630
+ # print(data.shape)
631
+ # forward
632
+ mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)
633
+
634
+ # cls_mil_loss = self.loss_nce(
635
+ # mil_pred, ab_labels.float()
636
+ # ) + self.loss_nce(mil_pred_2, ab_labels.float())
637
+
638
+ # loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)
639
+
640
+ # loss = cls_mil_loss * self.arg.lambda_mil + loss_co
641
+ '''Loc LOSS'''
642
+ target = target.cuda(self.output_device)
643
+ ''' into one hot'''
644
+ ground_truth_flat = target.view(-1)
645
+ one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
646
+ ''' into one hot'''
647
+ frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
648
+ frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
649
+ '''Loc LOSS'''
650
+ '''Loc LOSS'''
651
+ loss = cross_entropy_loss(
652
+ frm_scrs_re, one_hot_ground_truth
653
+ ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)
654
+ '''Loc LOSS'''
655
+
656
+ loss_value.append(loss.data.item())
657
+
658
+ for i in range(data.size(0)):
659
+ frm_scr = frm_scrs[i]
660
+ vid_pred = mil_pred[i]
661
+
662
+ label_ = label[i].cpu().numpy()
663
+ mask_ = mask[i].cpu().numpy()
664
+ vid_len = mask_.sum()
665
+
666
+ frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
667
+ vid_pred = vid_pred.cpu().numpy()
668
+
669
+ vid_preds.append(vid_pred)
670
+ frm_preds.append(frm_pred)
671
+ vid_lens.append(vid_len)
672
+ labels.append(label_)
673
+
674
+ step += 1
675
+
676
+ vid_preds = np.array(vid_preds)
677
+ frm_preds = np.array(frm_preds)
678
+ vid_lens = np.array(vid_lens)
679
+ labels = np.array(labels)
680
+
681
+ cmap = cmAP(vid_preds, labels)
682
+
683
+ score = cmap
684
+ loss = np.mean(loss_value)
685
+
686
+ dmap, iou = dsmAP(
687
+ vid_preds,
688
+ frm_preds,
689
+ vid_lens,
690
+ self.arg.test_feeder_args["data_path"],
691
+ self.arg,
692
+ multi=True,
693
+ )
694
+
695
+ print("Classification map %f" % cmap)
696
+ for item in list(zip(iou, dmap)):
697
+ print("Detection map @ %f = %f" % (item[0], item[1]))
698
+
699
+ self.my_logger.append([epoch + 1, cmap] + dmap+ [np.mean(dmap)])
700
+
701
+ wb_dict["val loss"] = loss
702
+ wb_dict["val acc"] = score
703
+
704
+ if score > self.best_acc:
705
+ self.best_acc = score
706
+
707
+ print("Acc score: ", score, " model: ", self.arg.model_saved_name)
708
+ if self.arg.phase == "train":
709
+ self.val_writer.add_scalar("loss", loss, self.global_step)
710
+ self.val_writer.add_scalar("acc", score, self.global_step)
711
+
712
+ self.print_log(
713
+ "\tMean {} loss of {} batches: {}.".format(
714
+ ln, len(self.data_loader[ln]), np.mean(loss_value)
715
+ )
716
+ )
717
+ self.print_log("\tAcc score: {:.3f}%".format(score))
718
+
719
+ return wb_dict
720
+
721
    def start(self):
        """Entry point: run the training loop or a single test-set
        evaluation depending on self.arg.phase."""
        wb_dict = {}
        if self.arg.phase == "train":
            self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
            # NOTE(review): len(data_loader) is already a batch count, so
            # dividing by batch_size again looks suspect — confirm intended.
            self.global_step = (
                self.arg.start_epoch
                * len(self.data_loader["train"])
                / self.arg.batch_size
            )

            for epoch in range(self.arg.start_epoch, self.arg.num_epoch):

                # Checkpoint on the save interval and on the final epoch.
                save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
                    epoch + 1 == self.arg.num_epoch
                )
                wb_dict = {"lr": self.lr}

                # Train
                wb_dict = self.train(epoch, wb_dict, save_model=save_model)
                # Evaluate every 5 epochs only.
                if epoch%5==0:
                    # Eval. on val set
                    wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
                # Log stats. for this epoch
                print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))

                print(
                    "best accuracy: ",
                    self.best_acc,
                    " model_name: ",
                    self.arg.model_saved_name,
                )

        elif self.arg.phase == "test":
            if not self.arg.test_feeder_args["debug"]:
                wf = self.arg.model_saved_name + "_wrong.txt"
                rf = self.arg.model_saved_name + "_right.txt"
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError("Please appoint --weights.")
            self.arg.print_log = False
            self.print_log("Model: {}.".format(self.arg.model))
            self.print_log("Weights: {}.".format(self.arg.weights))

            # NOTE(review): this relies on eval() accepting wrong_file /
            # result_file keyword arguments — verify the signature matches.
            wb_dict = self.eval(
                epoch=0,
                wb_dict=wb_dict,
                loader_name=["test"],
                wrong_file=wf,
                result_file=rf,
            )
            print("Inference metrics: ", wb_dict)
            self.print_log("Done.\n")
774
+
775
+
776
def str2bool(v):
    """Parse a human-readable boolean flag value for argparse.

    Accepts yes/no, true/false, t/f, y/n, 1/0 case-insensitively; any other
    value raises ArgumentTypeError.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
783
+
784
+
785
def import_class(name):
    """Resolve a dotted path such as "pkg.mod.Class" to the named object."""
    top, *attrs = name.split(".")
    obj = __import__(top)
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
791
+
792
+
793
if __name__ == "__main__":
    parser = get_parser()

    # Load args from the config file; command-line values still win because
    # the config only updates parser defaults before the final parse.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        # Reject config keys the parser does not know about.
        for k in default_arg.keys():
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    # Seed all RNGs before any model/data construction for reproducibility.
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
utils/__init__.py ADDED
File without changes
utils/logger.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A simple torch style logger
2
+ # (C) Wei YANG 2017
3
+ from __future__ import absolute_import
4
+
5
+ import os
6
+ import sys
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+
11
+ __all__ = ["Logger", "LoggerMonitor", "savefig"]
12
+
13
+
14
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to *fname* (default 150 dpi)."""
    # `is None` instead of `== None` (identity check for the singleton).
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)
17
+
18
+
19
def plot_overlap(logger, names=None):
    """Plot each named series from *logger* onto the current axes.

    Returns the legend labels, formatted as "<title>(<name>)".
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    # Plain iteration — the enumerate index was unused.
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + "(" + name + ")" for name in names]
26
+
27
+
28
class Logger(object):
    """Save training metrics to a tab-separated log file with a simple
    plotting helper.

    File format: one header line of names, then one row of "%.6f" values per
    append() call; every field (including the last) is followed by a tab.
    """

    def __init__(self, fpath, title=None, resume=False):
        """Open *fpath* for logging.

        With resume=True the existing header and rows are re-read (as
        strings) so that appended rows continue the same series; otherwise
        the file is truncated.
        """
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        if fpath is not None:
            if resume:
                # Re-read the existing header/values, then reopen for append.
                with open(fpath, "r") as f:
                    header = f.readline()
                    self.names = header.rstrip().split("\t")
                    self.numbers = {name: [] for name in self.names}
                    for line in f:
                        values = line.rstrip().split("\t")
                        for i in range(0, len(values)):
                            self.numbers[self.names[i]].append(values[i])
                self.file = open(fpath, "a")
            else:
                self.file = open(fpath, "w")

    def set_names(self, names):
        """Write the header line and initialize per-name value lists."""
        if self.resume:
            pass
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write("\t")
            self.numbers[name] = []
        self.file.write("\n")
        self.file.flush()

    def append(self, numbers):
        """Append one row of values (must match the header length)."""
        assert len(self.names) == len(numbers), "Numbers do not match names"
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write("\t")
            self.numbers[self.names[index]].append(num)
        self.file.write("\n")
        self.file.flush()

    def plot(self, names=None):
        """Plot the selected series (all by default) on the current axes."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file, if one was opened."""
        if self.file is not None:
            self.file.close()
88
+
89
class LoggerMonitor(object):
    """Load and visualize multiple logs side by side."""

    def __init__(self, paths):
        """*paths* is a dictionary of {title: filepath} pairs; each file is
        loaded as a resumed Logger."""
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            self.loggers.append(logger)

    def plot(self, names=None):
        # Overlay the selected series from every logger in one subplot.
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        # Legend placed outside the axes to the right.
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        plt.grid(True)
107
+
108
+
109
if __name__ == "__main__":
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # NOTE(review): demo code with hard-coded, machine-specific log paths;
    # it will fail anywhere these files do not exist.
    paths = {
        "resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
        "resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
        "resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
    }

    field = ["Valid Acc."]

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig("test.eps")
+ savefig("test.eps")