Initial upload
Browse files- .gitignore +163 -6
- .vscode/launch.json +29 -0
- LICENSE +202 -0
- NOTICE.txt +16 -0
- README.md +90 -0
- configs/train_split1_full_MSFF_DTCL.yaml +43 -0
- configs/train_split2_full_MSFF_DTCL.yaml +43 -0
- configs/train_split3_full_MSFF_DTCL.yaml +43 -0
- eval.py +585 -0
- evaluation/classificationMAP.py +26 -0
- evaluation/detectionMAP.py +516 -0
- evaluation/eval.py +129 -0
- evaluation/utils.py +57 -0
- feeders/__init__.py +1 -0
- feeders/feeder.py +313 -0
- feeders/tools.py +234 -0
- graph/__init__.py +1 -0
- graph/kinetics.py +76 -0
- graph/ntu_rgb_d.py +69 -0
- graph/tools.py +113 -0
- huggingface.py +9 -0
- human_model/Put SMPLH model here.txt +0 -0
- model/__init__.py +7 -0
- model/agcn.py +278 -0
- model/losses.py +63 -0
- prepare/configs/action_label_split1.json +6 -0
- prepare/configs/action_label_split2.json +6 -0
- prepare/configs/action_label_split3.json +6 -0
- prepare/create_dataset.py +370 -0
- prepare/dutils.py +310 -0
- prepare/generate_dataset.sh +5 -0
- prepare/preprocess.py +94 -0
- prepare/rotation.py +91 -0
- prepare/split_dataset.py +143 -0
- prepare/viz.py +447 -0
- pyproject.toml +13 -0
- requirements.txt +28 -0
- train.py +830 -0
- train_full.py +788 -0
- train_full_SSL.py +784 -0
- train_full_SSL_Unet.py +813 -0
- utils/__init__.py +0 -0
- utils/logger.py +135 -0
.gitignore
CHANGED
|
@@ -1,6 +1,163 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataset
|
| 2 |
+
data
|
| 3 |
+
dataset
|
| 4 |
+
work_dir
|
| 5 |
+
|
| 6 |
+
*.pkl
|
| 7 |
+
*.mp4
|
| 8 |
+
|
| 9 |
+
*.sh.o*
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Byte-compiled / optimized / DLL files
|
| 13 |
+
__pycache__/
|
| 14 |
+
*.py[cod]
|
| 15 |
+
*$py.class
|
| 16 |
+
|
| 17 |
+
# C extensions
|
| 18 |
+
*.so
|
| 19 |
+
|
| 20 |
+
# Distribution / packaging
|
| 21 |
+
.Python
|
| 22 |
+
build/
|
| 23 |
+
develop-eggs/
|
| 24 |
+
dist/
|
| 25 |
+
downloads/
|
| 26 |
+
eggs/
|
| 27 |
+
.eggs/
|
| 28 |
+
lib/
|
| 29 |
+
lib64/
|
| 30 |
+
parts/
|
| 31 |
+
sdist/
|
| 32 |
+
var/
|
| 33 |
+
wheels/
|
| 34 |
+
share/python-wheels/
|
| 35 |
+
*.egg-info/
|
| 36 |
+
.installed.cfg
|
| 37 |
+
*.egg
|
| 38 |
+
MANIFEST
|
| 39 |
+
|
| 40 |
+
# PyInstaller
|
| 41 |
+
# Usually these files are written by a python script from a template
|
| 42 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 43 |
+
*.manifest
|
| 44 |
+
*.spec
|
| 45 |
+
|
| 46 |
+
# Installer logs
|
| 47 |
+
pip-log.txt
|
| 48 |
+
pip-delete-this-directory.txt
|
| 49 |
+
|
| 50 |
+
# Unit test / coverage reports
|
| 51 |
+
htmlcov/
|
| 52 |
+
.tox/
|
| 53 |
+
.nox/
|
| 54 |
+
.coverage
|
| 55 |
+
.coverage.*
|
| 56 |
+
.cache
|
| 57 |
+
nosetests.xml
|
| 58 |
+
coverage.xml
|
| 59 |
+
*.cover
|
| 60 |
+
*.py,cover
|
| 61 |
+
.hypothesis/
|
| 62 |
+
.pytest_cache/
|
| 63 |
+
cover/
|
| 64 |
+
|
| 65 |
+
# Translations
|
| 66 |
+
*.mo
|
| 67 |
+
*.pot
|
| 68 |
+
|
| 69 |
+
# Django stuff:
|
| 70 |
+
*.log
|
| 71 |
+
local_settings.py
|
| 72 |
+
db.sqlite3
|
| 73 |
+
db.sqlite3-journal
|
| 74 |
+
|
| 75 |
+
# Flask stuff:
|
| 76 |
+
instance/
|
| 77 |
+
.webassets-cache
|
| 78 |
+
|
| 79 |
+
# Scrapy stuff:
|
| 80 |
+
.scrapy
|
| 81 |
+
|
| 82 |
+
# Sphinx documentation
|
| 83 |
+
docs/_build/
|
| 84 |
+
|
| 85 |
+
# PyBuilder
|
| 86 |
+
.pybuilder/
|
| 87 |
+
target/
|
| 88 |
+
|
| 89 |
+
# Jupyter Notebook
|
| 90 |
+
.ipynb_checkpoints
|
| 91 |
+
|
| 92 |
+
# IPython
|
| 93 |
+
profile_default/
|
| 94 |
+
ipython_config.py
|
| 95 |
+
|
| 96 |
+
# pyenv
|
| 97 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 98 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 99 |
+
# .python-version
|
| 100 |
+
|
| 101 |
+
# pipenv
|
| 102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 105 |
+
# install all needed dependencies.
|
| 106 |
+
#Pipfile.lock
|
| 107 |
+
|
| 108 |
+
# poetry
|
| 109 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 110 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 111 |
+
# commonly ignored for libraries.
|
| 112 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 113 |
+
#poetry.lock
|
| 114 |
+
|
| 115 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 116 |
+
__pypackages__/
|
| 117 |
+
|
| 118 |
+
# Celery stuff
|
| 119 |
+
celerybeat-schedule
|
| 120 |
+
celerybeat.pid
|
| 121 |
+
|
| 122 |
+
# SageMath parsed files
|
| 123 |
+
*.sage.py
|
| 124 |
+
|
| 125 |
+
# Environments
|
| 126 |
+
.env
|
| 127 |
+
.venv
|
| 128 |
+
env/
|
| 129 |
+
venv/
|
| 130 |
+
ENV/
|
| 131 |
+
env.bak/
|
| 132 |
+
venv.bak/
|
| 133 |
+
|
| 134 |
+
# Spyder project settings
|
| 135 |
+
.spyderproject
|
| 136 |
+
.spyproject
|
| 137 |
+
|
| 138 |
+
# Rope project settings
|
| 139 |
+
.ropeproject
|
| 140 |
+
|
| 141 |
+
# mkdocs documentation
|
| 142 |
+
/site
|
| 143 |
+
|
| 144 |
+
# mypy
|
| 145 |
+
.mypy_cache/
|
| 146 |
+
.dmypy.json
|
| 147 |
+
dmypy.json
|
| 148 |
+
|
| 149 |
+
# Pyre type checker
|
| 150 |
+
.pyre/
|
| 151 |
+
|
| 152 |
+
# pytype static type analyzer
|
| 153 |
+
.pytype/
|
| 154 |
+
|
| 155 |
+
# Cython debug symbols
|
| 156 |
+
cython_debug/
|
| 157 |
+
|
| 158 |
+
# PyCharm
|
| 159 |
+
# JetBrains-specific template is maintained in a separate JetBrains.gitignore that can
|
| 160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 163 |
+
#.idea/
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "0.2.0",
|
| 3 |
+
"configurations": [
|
| 4 |
+
{
|
| 5 |
+
"name": "Python: Debug Current File (gvhmr)",
|
| 6 |
+
"type": "debugpy",
|
| 7 |
+
"request": "launch",
|
| 8 |
+
"program": "${file}",
|
| 9 |
+
"console": "integratedTerminal",
|
| 10 |
+
"cwd": "/root/autodl-tmp/workshop2/",
|
| 11 |
+
"args": [
|
| 12 |
+
"--config",
|
| 13 |
+
// "config/eval.yaml"
|
| 14 |
+
"configss/train_split3_full_MSFF_DTCL.yaml",
|
| 15 |
+
// "config/pretrain2.yaml"
|
| 16 |
+
"--work-dir",
|
| 17 |
+
"./work_dir/3_agcn_cl_all/",
|
| 18 |
+
"-model_saved_name",
|
| 19 |
+
"./work_dir/3_agcn_cl_all/",
|
| 20 |
+
"--weights",
|
| 21 |
+
"/root/autodl-tmp/RVTCLR/work_dir/ntu60_cs/skeletonclr_joint/U/cl/pretext_babel/epoch300_model.pt"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
// CUBLAS_WORKSPACE_CONFIG=:4096:8 python train_full_SSL.py --config configss/train_split3_full_MSFF_DTCL.yaml \
|
| 27 |
+
// --work-dir ./work_dir/3_agcn_cl_all/ \
|
| 28 |
+
// -model_saved_name ./work_dir/3_agcn_cl_all/ \
|
| 29 |
+
// --weights /root/autodl-tmp/RVTCLR/work_dir/ntu60_cs/skeletonclr_joint/U/both/pretext_babel_dense_U/epoch300_model.pt
|
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
NOTICE.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### Open Source Software Notice ###
|
| 2 |
+
|
| 3 |
+
This product from LINE Corporation contains the open source software or third-party software listed below, to which the terms in the LICENSE file in this repository do not apply.
|
| 4 |
+
Please refer to the licences of the respective software repositories for the terms and conditions of their use.
|
| 5 |
+
|
| 6 |
+
BABEL (Non-commercial License)
|
| 7 |
+
https://github.com/abhinanda-punnakkal/BABEL
|
| 8 |
+
|
| 9 |
+
2s-AGCN (CC BY-NC 4.0 License)
|
| 10 |
+
https://github.com/lshiwjx/2s-AGCN
|
| 11 |
+
|
| 12 |
+
FAC-Net (MIT License)
|
| 13 |
+
https://github.com/LeonHLJ/FAC-Net
|
| 14 |
+
|
| 15 |
+
pytorch-classification (MIT License)
|
| 16 |
+
https://github.com/bearpaw/pytorch-classification
|
README.md
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Skeleton-Temporal-Action-Localization
|
| 2 |
+
|
| 3 |
+
Code for the paper "Frame-Level Label Refinement for Skeleton-Based Weakly-Supervised Action Recognition" (AAAI 2023).
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
Architecture of Network
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
## Requirements
|
| 12 |
+
```bash
|
| 13 |
+
conda create -n stal python=3.7
|
| 14 |
+
conda activate stal
|
| 15 |
+
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 -c pytorch
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Data Preparation
|
| 20 |
+
Due to the distribution policy of the AMASS dataset, we are not allowed to distribute the data directly. We provide a series of scripts that can reproduce our motion segmentation dataset from the BABEL dataset.
|
| 21 |
+
|
| 22 |
+
Download [AMASS Dataset](https://amass.is.tue.mpg.de/) and [BABEL Dataset](https://babel.is.tue.mpg.de/). Unzip and locate them in the `dataset` folder.
|
| 23 |
+
|
| 24 |
+
Prepare the SMPLH Model following [this](https://github.com/vchoutas/smplx/blob/main/tools/README.md#smpl-h-version-used-in-amass) and put the merged model `SMPLH_male.pkl` into the `human_model` folder.
|
| 25 |
+
|
| 26 |
+
The whole directory should look like this:
|
| 27 |
+
```
|
| 28 |
+
Skeleton-Temporal-Action-Localization
|
| 29 |
+
│ README.md
|
| 30 |
+
│ train.py
|
| 31 |
+
| ...
|
| 32 |
+
|
|
| 33 |
+
└───config
|
| 34 |
+
└───prepare
|
| 35 |
+
└───...
|
| 36 |
+
│
|
| 37 |
+
└───human_model
|
| 38 |
+
│ └───SMPLH_male.pkl
|
| 39 |
+
│
|
| 40 |
+
└───dataset
|
| 41 |
+
└───amass
|
| 42 |
+
| └───ACCAD
|
| 43 |
+
| └───BMLmovi
|
| 44 |
+
| └───...
|
| 45 |
+
│
|
| 46 |
+
└───babel_v1.0_release
|
| 47 |
+
└───train.json
|
| 48 |
+
└───val.json
|
| 49 |
+
└───...
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
Also clone the official BABEL code into the `dataset` folder.
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
git clone https://github.com/abhinanda-punnakkal/BABEL.git dataset/BABEL
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Finally, the motion segmentation dataset can be generated by:
|
| 59 |
+
```bash
|
| 60 |
+
bash prepare/generate_dataset.sh
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Training and Evaluation
|
| 64 |
+
To train and evaluate the model with subset-1 of BABEL, run the following commands:
|
| 65 |
+
```bash
|
| 66 |
+
python train.py --config config/train_split1.yaml
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Acknowledgement
|
| 70 |
+
Our codes are based on [BABEL](https://github.com/abhinanda-punnakkal/BABEL), [2s-AGCN](https://github.com/lshiwjx/2s-AGCN) and [FAC-Net](https://github.com/LeonHLJ/FAC-Net).
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
## Citation
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
@InProceedings{yu2023frame,
|
| 77 |
+
title={Frame-Level Label Refinement for Skeleton-Based Weakly-Supervised Action Recognition},
|
| 78 |
+
author={Yu, Qing and Fujiwara, Kent},
|
| 79 |
+
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
| 80 |
+
volume={37},
|
| 81 |
+
number={3},
|
| 82 |
+
pages={3322--3330},
|
| 83 |
+
year={2023}
|
| 84 |
+
}
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## License
|
| 88 |
+
[Apache License 2.0](LICENSE)
|
| 89 |
+
|
| 90 |
+
Additionally, this repository contains third-party software. Refer [NOTICE.txt](NOTICE.txt) for more details and follow the terms and conditions of their use.
|
configs/train_split1_full_MSFF_DTCL.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
work_dir: ./work_dir/1_w_MSFF_full/
|
| 2 |
+
model_saved_name: ./work_dir/1_w_MSFF_full/
|
| 3 |
+
|
| 4 |
+
# feeder
|
| 5 |
+
feeder: feeders.feeder.Feeder
|
| 6 |
+
train_feeder_args:
|
| 7 |
+
data_path: ./dataset/train_split1.pkl
|
| 8 |
+
debug: False
|
| 9 |
+
random_choose: False
|
| 10 |
+
random_shift: False
|
| 11 |
+
random_move: False
|
| 12 |
+
window_size: -1
|
| 13 |
+
nb_class: 4
|
| 14 |
+
|
| 15 |
+
test_feeder_args:
|
| 16 |
+
data_path: ./dataset/val_split1.pkl
|
| 17 |
+
nb_class: 4
|
| 18 |
+
|
| 19 |
+
# model
|
| 20 |
+
model: model.agcn_Unet.Model
|
| 21 |
+
model_args:
|
| 22 |
+
num_class: 4
|
| 23 |
+
num_person: 1
|
| 24 |
+
num_point: 25 # checked from 25
|
| 25 |
+
graph: graph.ntu_rgb_d.Graph
|
| 26 |
+
graph_args:
|
| 27 |
+
labeling_mode: 'spatial'
|
| 28 |
+
|
| 29 |
+
#optim
|
| 30 |
+
weight_decay: 0.001
|
| 31 |
+
base_lr: 0.0002
|
| 32 |
+
step: [60,80]
|
| 33 |
+
|
| 34 |
+
# training
|
| 35 |
+
device: [0]
|
| 36 |
+
optimizer: 'Adam'
|
| 37 |
+
loss: 'CE'
|
| 38 |
+
batch_size: 8
|
| 39 |
+
test_batch_size: 1
|
| 40 |
+
num_epoch: 101 #101
|
| 41 |
+
nesterov: True
|
| 42 |
+
lambda_mil: 1.0
|
| 43 |
+
weights: /home/newDisk/epoch300_model.pt
|
configs/train_split2_full_MSFF_DTCL.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
work_dir: ./work_dir/2_w_MSFF_full/
|
| 2 |
+
model_saved_name: ./work_dir/2_w_MSFF_full/
|
| 3 |
+
|
| 4 |
+
# feeder
|
| 5 |
+
feeder: feeders.feeder.Feeder
|
| 6 |
+
train_feeder_args:
|
| 7 |
+
data_path: ./dataset/train_split2.pkl
|
| 8 |
+
debug: False
|
| 9 |
+
random_choose: False
|
| 10 |
+
random_shift: False
|
| 11 |
+
random_move: False
|
| 12 |
+
window_size: -1
|
| 13 |
+
nb_class: 4
|
| 14 |
+
|
| 15 |
+
test_feeder_args:
|
| 16 |
+
data_path: ./dataset/val_split2.pkl
|
| 17 |
+
nb_class: 4
|
| 18 |
+
|
| 19 |
+
# model
|
| 20 |
+
model: model.agcn_Unet.Model
|
| 21 |
+
model_args:
|
| 22 |
+
num_class: 4
|
| 23 |
+
num_person: 1
|
| 24 |
+
num_point: 25 # checked from 25
|
| 25 |
+
graph: graph.ntu_rgb_d.Graph
|
| 26 |
+
graph_args:
|
| 27 |
+
labeling_mode: 'spatial'
|
| 28 |
+
|
| 29 |
+
#optim
|
| 30 |
+
weight_decay: 0.0001
|
| 31 |
+
base_lr: 0.001
|
| 32 |
+
step: [60,80]
|
| 33 |
+
|
| 34 |
+
# training
|
| 35 |
+
device: [0]
|
| 36 |
+
optimizer: 'Adam'
|
| 37 |
+
loss: 'CE'
|
| 38 |
+
batch_size: 8
|
| 39 |
+
test_batch_size: 1
|
| 40 |
+
num_epoch: 101 #101
|
| 41 |
+
nesterov: True
|
| 42 |
+
lambda_mil: 1.0
|
| 43 |
+
weights: /home/newDisk/epoch300_model.pt
|
configs/train_split3_full_MSFF_DTCL.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
work_dir: ./work_dir/3_w_MSFF_full/
|
| 2 |
+
model_saved_name: ./work_dir/3_w_MSFF_full/
|
| 3 |
+
|
| 4 |
+
# feeder
|
| 5 |
+
feeder: feeders.feeder.Feeder
|
| 6 |
+
train_feeder_args:
|
| 7 |
+
data_path: ./dataset/train_split3.pkl
|
| 8 |
+
debug: False
|
| 9 |
+
random_choose: False
|
| 10 |
+
random_shift: False
|
| 11 |
+
random_move: False
|
| 12 |
+
window_size: -1
|
| 13 |
+
nb_class: 4
|
| 14 |
+
|
| 15 |
+
test_feeder_args:
|
| 16 |
+
data_path: ./dataset/val_split3.pkl
|
| 17 |
+
nb_class: 4
|
| 18 |
+
|
| 19 |
+
# model
|
| 20 |
+
model: model.agcn_Unet.Model
|
| 21 |
+
model_args:
|
| 22 |
+
num_class: 4
|
| 23 |
+
num_person: 1
|
| 24 |
+
num_point: 25 # checked from 25
|
| 25 |
+
graph: graph.ntu_rgb_d.Graph
|
| 26 |
+
graph_args:
|
| 27 |
+
labeling_mode: 'spatial'
|
| 28 |
+
|
| 29 |
+
#optim
|
| 30 |
+
weight_decay: 0.0001
|
| 31 |
+
base_lr: 0.001
|
| 32 |
+
step: [60,80]
|
| 33 |
+
|
| 34 |
+
# training
|
| 35 |
+
device: [0]
|
| 36 |
+
optimizer: 'Adam'
|
| 37 |
+
loss: 'CE'
|
| 38 |
+
batch_size: 8
|
| 39 |
+
test_batch_size: 1
|
| 40 |
+
num_epoch: 101 #101
|
| 41 |
+
nesterov: True
|
| 42 |
+
lambda_mil: 1.0
|
| 43 |
+
weights: /home/newDisk/epoch300_model.pt
|
eval.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
from __future__ import print_function
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import inspect
|
| 7 |
+
import os
|
| 8 |
+
import pdb
|
| 9 |
+
import pickle
|
| 10 |
+
import random
|
| 11 |
+
import re
|
| 12 |
+
import shutil
|
| 13 |
+
import time
|
| 14 |
+
from collections import *
|
| 15 |
+
|
| 16 |
+
import ipdb
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# torch
|
| 20 |
+
import torch
|
| 21 |
+
import torch.backends.cudnn as cudnn
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torch.optim as optim
|
| 25 |
+
import yaml
|
| 26 |
+
from einops import rearrange, reduce, repeat
|
| 27 |
+
from evaluation.classificationMAP import getClassificationMAP as cmAP
|
| 28 |
+
from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 29 |
+
from feeders.tools import collate_with_padding_multi_joint
|
| 30 |
+
from model.losses import cross_entropy_loss, mvl_loss
|
| 31 |
+
from sklearn.metrics import f1_score
|
| 32 |
+
|
| 33 |
+
# Custom
|
| 34 |
+
from tensorboardX import SummaryWriter
|
| 35 |
+
from torch.autograd import Variable
|
| 36 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
from utils.logger import Logger
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def init_seed(seed):
    """Seed every RNG source used by this project and pin cuDNN to
    deterministic kernels so experiments are reproducible."""
    seeders = (
        torch.cuda.manual_seed_all,  # all CUDA devices
        torch.manual_seed,           # torch CPU generator
        np.random.seed,              # numpy
        random.seed,                 # python stdlib
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # trade cuDNN autotuning speed for bit-exact reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_parser():
    """Build the command-line parser.

    Parameter priority: command line > config file > default (the yaml config
    is merged into the parser via ``parser.set_defaults`` in ``__main__``, so
    every config key must match an argument ``dest`` here).
    """
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="test", help="must be train or test")

    # visualize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    # feeder kwargs are normally supplied by the yaml config, not the CLI
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    # NOTE: type=dict is only usable through the config file, not from the CLI
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[60, 80],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # BUGFIX: was "---beta" (three dashes), which exposed the flag as
    # ``---beta`` on the CLI; the dest stays "beta", so config files and all
    # downstream reads of arg.beta are unaffected.
    parser.add_argument(
        "--beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    parser.add_argument("--only_train_part", default=False)
    parser.add_argument("--only_train_epoch", default=0)
    parser.add_argument("--warm_up_epoch", default=0)

    # BUGFIX: added type=float so a command-line override is parsed as a
    # number instead of a string (the default and the yaml value are floats).
    parser.add_argument(
        "--lambda-mil",
        type=float,
        default=1.0,
        help="balancing hyper-parameter of mil branch",
    )

    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class Processor:
    """
    Processor for Skeleton-based Action Recognition: builds the model,
    data loaders and optimizer from the merged argparse/yaml namespace and
    runs classification/detection mAP evaluation.
    """

    def __init__(self, arg):
        # `arg` is the namespace produced in __main__ (CLI merged with yaml).
        self.arg = arg
        self.save_arg()  # snapshot the effective config into work_dir/config.yaml
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    answer = "y"  # hard-coded: always wipe a stale log dir
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # debug mode: one writer shared by train and val
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        # NOTE(review): in phase "test" no SummaryWriter is created; eval()
        # only touches self.val_writer when phase == "train", so this holds.
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        # multi-label video-level objective; model outputs are probabilities
        self.loss_nce = torch.nn.BCELoss()

        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        # columns: step, classification mAP, detection mAP at IoU 0.1 .. 0.5
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)])

    def load_data(self):
        """Build the train (train phase only) and test DataLoaders from the
        Feeder class named in the config."""
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()
        if self.arg.phase == "train":
            self.data_loader["train"] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                # pads variable-length skeleton sequences into one batch
                collate_fn=collate_with_padding_multi_joint,
            )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            collate_fn=collate_with_padding_multi_joint,
        )

    def load_model(self):
        """Instantiate the model on the output device and optionally restore a
        checkpoint (.pkl pickle or torch file), honouring --ignore-weights."""
        output_device = (
            self.arg.device[0] if type(self.arg.device) is list else self.arg.device
        )
        self.output_device = output_device
        Model = import_class(self.arg.model)
        # keep a copy of the model source next to the results for provenance
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        # print(Model)
        self.model = Model(**self.arg.model_args).cuda(output_device)
        # print(self.model)
        # NOTE(review): reads the module-level `arg` (assigned in __main__),
        # not self.arg — works here but is fragile; confirm intended.
        self.loss_type = arg.loss

        if self.arg.weights:
            # self.global_step = int(arg.weights[:-3].split("-")[-1])
            self.print_log("Load weights from {}.".format(self.arg.weights))
            if ".pkl" in self.arg.weights:
                with open(self.arg.weights, "r") as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            # strip a possible DataParallel "module." prefix, move to GPU
            weights = OrderedDict(
                [
                    [k.split("module.")[-1], v.cuda(output_device)]
                    for k, v in weights.items()
                ]
            )

            keys = list(weights.keys())
            for w in self.arg.ignore_weights:
                # substring match: drop every checkpoint entry containing w
                for key in keys:
                    if w in key:
                        if weights.pop(key, None) is not None:
                            self.print_log(
                                "Sucessfully Remove Weights: {}.".format(key)
                            )
                        else:
                            self.print_log("Can Not Remove Weights: {}.".format(key))

            try:
                self.model.load_state_dict(weights)
            except:
                # partial load: report the missing keys, then merge what exists
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                print("Can not find these weights:")
                for d in diff:
                    print(" " + d)
                state.update(weights)
                self.model.load_state_dict(state)

        if type(self.arg.device) is list:
            if len(self.arg.device) > 1:
                self.model = nn.DataParallel(
                    self.model, device_ids=self.arg.device, output_device=output_device
                )

    def load_optimizer(self):
        """Create the optimizer named by --optimizer (SGD or Adam only)."""
        if self.arg.optimizer == "SGD":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.arg.base_lr,
                momentum=0.9,
                nesterov=self.arg.nesterov,
                weight_decay=self.arg.weight_decay,
            )
        elif self.arg.optimizer == "Adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.arg.base_lr,
                weight_decay=self.arg.weight_decay,
            )
        else:
            raise ValueError()

    def save_arg(self):
        """Dump the full argument namespace to <work_dir>/config.yaml."""
        arg_dict = vars(self.arg)
        if not os.path.exists(self.arg.work_dir):
            os.makedirs(self.arg.work_dir)
        with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
            yaml.dump(arg_dict, f)

    def adjust_learning_rate(self, epoch):
        """Step-decay schedule with optional linear warm-up; writes the new lr
        into every optimizer param group and returns it."""
        if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
            if epoch < self.arg.warm_up_epoch:
                # linear warm-up from base_lr/warm_up_epoch up to base_lr
                lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
            else:
                # decay by 0.1 at each milestone in self.arg.step
                lr = self.arg.base_lr * (
                    0.1 ** np.sum(epoch >= np.array(self.arg.step))
                )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr

            return lr
        else:
            raise ValueError()

    def print_time(self):
        """Log the wall-clock time."""
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log("Local current time : " + localtime)

    def print_log(self, str, print_time=True):
        """Print a message (optionally timestamped) and append it to
        <work_dir>/print_log.txt when --print-log is enabled."""
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            str = "[ " + localtime + " ] " + str
        print(str)
        if self.arg.print_log:
            with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
                print(str, file=f)

    def record_time(self):
        # start/reset the split timer used by split_time()
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        # seconds elapsed since the last record_time()/split_time() call
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    @torch.no_grad()
    def eval(
        self,
        epoch,
        wb_dict,
        loader_name=["test"],  # NOTE(review): mutable default, shared across calls
    ):
        """Run one evaluation pass over the named loaders.

        Accumulates per-video and per-frame predictions, then reports
        classification mAP and detection mAP over the IoU thresholds set in
        ``self.arg``. Returns ``wb_dict`` updated with "val loss"/"val acc".
        """
        self.model.eval()
        self.print_log("Eval epoch: {}".format(epoch + 1))

        vid_preds = []  # video-level class probabilities
        frm_preds = []  # per-frame softmax scores, trimmed to true length
        vid_lens = []   # unpadded sequence lengths
        labels = []     # ground-truth multi-hot labels

        for ln in loader_name:
            loss_value = []
            step = 0
            process = tqdm(self.data_loader[ln])

            for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
                process
            ):
                data = data.float().cuda(self.output_device)
                label = label.cuda(self.output_device)
                mask = mask.cuda(self.output_device)

                # append an extra always-on column (background/ambient class)
                ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

                # forward: two MIL heads and two per-frame score streams
                mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data, mask)

                cls_mil_loss = self.loss_nce(
                    mil_pred, ab_labels.float()
                ) + self.loss_nce(mil_pred_2, ab_labels.float())

                # consistency loss between the two score streams
                loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)

                loss = cls_mil_loss * self.arg.lambda_mil + loss_co

                loss_value.append(loss.data.item())

                for i in range(data.size(0)):
                    frm_scr = frm_scrs[i]
                    vid_pred = mil_pred[i]

                    label_ = label[i].cpu().numpy()
                    mask_ = mask[i].cpu().numpy()
                    vid_len = mask_.sum()  # number of valid (unpadded) frames

                    frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                    vid_pred = vid_pred.cpu().numpy()

                    vid_preds.append(vid_pred)
                    frm_preds.append(frm_pred)
                    vid_lens.append(vid_len)
                    labels.append(label_)

                step += 1

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        cmap = cmAP(vid_preds, labels)

        score = cmap
        loss = np.mean(loss_value)

        # detection mAP over the thresholds configured in self.arg
        dmap, iou = dsmAP(
            vid_preds,
            frm_preds,
            vid_lens,
            self.arg.test_feeder_args["data_path"],
            self.arg,
            multi=True,
        )

        print("Classification map %f" % cmap)

        for item in list(zip(iou, dmap)):
            print("Detection map @ %f = %f" % (item[0], item[1]))

        self.my_logger.append([epoch + 1, cmap] + dmap)

        wb_dict["val loss"] = loss
        wb_dict["val acc"] = score

        if score > self.best_acc:
            self.best_acc = score

        print("Acc score: ", score, " model: ", self.arg.model_saved_name)
        if self.arg.phase == "train":
            self.val_writer.add_scalar("loss", loss, self.global_step)
            self.val_writer.add_scalar("acc", score, self.global_step)

        # NOTE: uses `ln`/`loss_value` leaked from the loop above, i.e. the
        # last loader evaluated
        self.print_log(
            "\tMean {} loss of {} batches: {}.".format(
                ln, len(self.data_loader[ln]), np.mean(loss_value)
            )
        )
        self.print_log("\tAcc score: {:.3f}%".format(score))

        return wb_dict

    def start(self):
        """Entry point: in phase "test", validate --weights and run eval()."""
        wb_dict = {}

        if self.arg.phase == "test":
            if not self.arg.test_feeder_args["debug"]:
                wf = self.arg.model_saved_name + "_wrong.txt"
                rf = self.arg.model_saved_name + "_right.txt"
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError("Please appoint --weights.")
            self.arg.print_log = False
            self.print_log("Model: {}.".format(self.arg.model))
            self.print_log("Weights: {}.".format(self.arg.weights))

            wb_dict = self.eval(
                epoch=0,
                wb_dict=wb_dict,
                loader_name=["test"],
                # wrong_file=wf,
                # result_file=rf,
            )
            print("Inference metrics: ", wb_dict)
            self.print_log("Done.\n")
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def str2bool(v):
    """Parse a truthy/falsey command-line string into a bool (argparse type)."""
    normalized = v.lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def import_class(name):
    """Resolve a dotted path string (e.g. "feeders.feeder.Feeder") to the
    object it names, importing the root package as a side effect."""
    module_path, *attrs = name.split(".")
    obj = __import__(module_path)
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
if __name__ == "__main__":
|
| 566 |
+
parser = get_parser()
|
| 567 |
+
|
| 568 |
+
# load arg form config file
|
| 569 |
+
p = parser.parse_args()
|
| 570 |
+
if p.config is not None:
|
| 571 |
+
with open(p.config, "r") as f:
|
| 572 |
+
default_arg = yaml.safe_load(f)
|
| 573 |
+
key = vars(p).keys()
|
| 574 |
+
for k in default_arg.keys():
|
| 575 |
+
if k not in key:
|
| 576 |
+
print("WRONG ARG: {}".format(k))
|
| 577 |
+
assert k in key
|
| 578 |
+
parser.set_defaults(**default_arg)
|
| 579 |
+
|
| 580 |
+
arg = parser.parse_args()
|
| 581 |
+
print("BABEL Action Recognition")
|
| 582 |
+
print("Config: ", arg)
|
| 583 |
+
init_seed(arg.seed)
|
| 584 |
+
processor = Processor(arg)
|
| 585 |
+
processor.start()
|
evaluation/classificationMAP.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def getAP(conf, labels):
    """Average precision of one class.

    conf   : 1-D array of per-sample confidence scores.
    labels : 1-D binary array (1 = positive) of the same length.
    Returns AP in [0, 1]; by the original convention a class with no
    positives scores 1.
    """
    assert len(conf) == len(labels)
    npos = np.sum(labels)
    # BUGFIX: guard first — the original computed tp/npos before the check,
    # emitting a divide-by-zero RuntimeWarning (and an unused `rec`) when the
    # class has no positives. Return value is unchanged.
    if npos == 0:
        return 1

    sortind = np.argsort(-conf)  # descending by confidence
    tp = np.cumsum(labels[sortind] == 1).astype("float32")
    fp = np.cumsum(labels[sortind] != 1).astype("float32")
    prec = tp / (fp + tp)  # precision at each rank

    # sum precision at the rank of each true positive, normalized by #positives
    tmp = (labels[sortind] == 1).astype("float32")
    return np.sum(tmp * prec) / npos
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def getClassificationMAP(confidence, labels):
    """Mean average precision over classes, as a percentage.

    Both inputs are arrays of shape (n_samples, n_labels).
    """
    n_labels = np.shape(labels)[1]
    per_class_ap = [
        getAP(confidence[:, cls], labels[:, cls]) for cls in range(n_labels)
    ]
    return 100 * sum(per_class_ap) / len(per_class_ap)
|
evaluation/detectionMAP.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from collections import Counter
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def str2ind(categoryname, classlist):
    """Return the index of the first entry of classlist equal to categoryname
    (raises IndexError when absent, like the original list-index form)."""
    for position, candidate in enumerate(classlist):
        if candidate == categoryname:
            return position
    raise IndexError("list index out of range")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def encode_mask_to_rle(mask):
    """
    mask: numpy array binary mask
    1 - mask
    0 - background
    Returns the run-length encoding as alternating (1-indexed start, length)
    values in a flat array.
    """
    # pad with zeros on both ends so every run has a detectable start and end
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # 1-indexed positions where the value changes: run starts and run ends
    change_points = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # turn every (start, end) pair into (start, length)
    change_points[1::2] -= change_points[::2]
    return change_points
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def filter_segments(segment_predict, videonames, ambilist, factor):
    """Drop predicted segments that overlap (IoU > 0) any ambiguous ground
    truth interval of the same video; returns the surviving rows as an array.

    segment_predict rows: [video index, start frame, end frame, ...].
    ambilist entries: [video name, _, start seconds, end seconds]; `factor`
    converts seconds to frames.
    """
    n_segments = np.shape(segment_predict)[0]
    discard = np.zeros(n_segments)
    for idx in range(n_segments):
        video = videonames[int(segment_predict[idx, 0])]
        for entry in ambilist:
            if entry[0] != video:
                continue
            gt_frames = set(
                range(
                    int(round(float(entry[2]) * factor)),
                    int(round(float(entry[3]) * factor)),
                )
            )
            pred_frames = set(
                range(int(segment_predict[idx][1]), int(segment_predict[idx][2]))
            )
            overlap = float(len(gt_frames.intersection(pred_frames))) / float(
                len(gt_frames.union(pred_frames))
            )
            if overlap > 0:
                discard[idx] = 1
    survivors = [
        segment_predict[idx, :] for idx in range(n_segments) if discard[idx] == 0
    ]
    return np.array(survivors)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def getActLoc(
    vid_preds, frm_preds, vid_lens, act_thresh_cas, annotation_path, args, multi=False
):
    """Generate per-class action-localization proposals from frame-level scores.

    vid_preds: per-video class scores; frm_preds: per-video (T, C) frame scores.
    act_thresh_cas: iterable of activation thresholds used to binarize the CAS.
    annotation_path: pickle with keys "L" (per-frame labels), "Y", "sid".
    multi: when True, treat each video's label track as multi-class and build
    one segment list per class present (the class id args.model_args["num_class"]
    is discarded — presumably the background class; TODO confirm).

    Returns a list with one entry per class; each entry is an (N, 4) array of
    [video_idx, start, end, outer-inner contrast score] proposals (or an empty
    list when the class produced none).
    """
    # Annotations are pickled; fall back to binary/latin1 for python2-era pickles.
    try:
        with open(annotation_path) as f:
            data = pickle.load(f)
    except:
        # for pickle file from python2
        with open(annotation_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")

    if multi:
        # Build ground-truth segments separately for every class present
        # in the frame-label track, via run-length encoding of (gt == c).
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_ = set(gt)
            gt_.discard(args.model_args["num_class"])
            gts = []
            gtl = []
            for c in list(gt_):
                gt_encoded = encode_mask_to_rle(gt == c)
                # RLE starts are 1-based: convert (start, length) to
                # 0-based inclusive [start, end] pairs.
                gts.extend(
                    [
                        [x - 1, x + y - 2]
                        for x, y in zip(gt_encoded[::2], gt_encoded[1::2])
                    ]
                )
                gtl.extend([c for item in gt_encoded[::2]])
            gtsegments.append(gts)
            gtlabels.append(gtl)
    else:
        # Single-label case: one RLE over the whole binary label track,
        # with the video-level label data["Y"][idx] attached to each run.
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_encoded = encode_mask_to_rle(gt)
            gtsegments.append(
                [[x - 1, x + y - 2] for x, y in zip(gt_encoded[::2], gt_encoded[1::2])]
            )
            gtlabels.append([data["Y"][idx] for item in gt_encoded[::2]])

    videoname = np.array(data["sid"])

    # keep ground truth and predictions for instances with temporal annotations
    gtl, vn, vp, fp, vl = [], [], [], [], []
    for i, s in enumerate(gtsegments):
        if len(s):
            gtl.append(gtlabels[i])
            vn.append(videoname[i])
            vp.append(vid_preds[i])
            fp.append(frm_preds[i])
            vl.append(vid_lens[i])
        else:
            # NOTE(review): videos without temporal annotations are only
            # printed and silently dropped from evaluation.
            print(i)
    gtlabels = gtl
    videoname = vn

    # which categories have temporal labels ?
    templabelidx = sorted(list(set([l for gtl in gtlabels for l in gtl])))

    dataset_segment_predict = []
    class_threshold = args.class_threshold
    for c in range(frm_preds[0].shape[1]):
        c_temp = []
        # Get list of all predictions for class c
        for i in range(len(fp)):
            vid_cls_score = vp[i][c]
            vid_cas = fp[i][:, c]
            vid_cls_proposal = []
            # Video-level class gating is intentionally disabled here:
            # if vid_cls_score < class_threshold:
            #     continue
            for t in range(len(act_thresh_cas)):
                thres = act_thresh_cas[t]
                # Zero-pad so rising/falling edges at the sequence ends count.
                vid_pred = np.concatenate(
                    [np.zeros(1), (vid_cas > thres).astype("float32"), np.zeros(1)],
                    axis=0,
                )
                vid_pred_diff = [
                    vid_pred[idt] - vid_pred[idt - 1] for idt in range(1, len(vid_pred))
                ]
                # +1 edges are proposal starts, -1 edges are proposal ends.
                s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1]
                e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1]
                for j in range(len(s)):
                    len_proposal = e[j] - s[j]
                    # Keep only proposals at least 3 frames long.
                    if len_proposal >= 3:
                        inner_score = np.mean(vid_cas[s[j] : e[j] + 1])
                        # Outer context: 25% of the proposal length on each side.
                        outer_s = max(0, int(s[j] - 0.25 * len_proposal))
                        outer_e = min(
                            int(vid_cas.shape[0] - 1),
                            int(e[j] + 0.25 * len_proposal + 1),
                        )
                        outer_temp_list = list(range(outer_s, int(s[j]))) + list(
                            range(int(e[j] + 1), outer_e)
                        )
                        if len(outer_temp_list) == 0:
                            outer_score = 0
                        else:
                            outer_score = np.mean(vid_cas[outer_temp_list])
                        # Inner-minus-outer contrast; 0.6 is an empirical weight.
                        c_score = inner_score - 0.6 * outer_score
                        vid_cls_proposal.append([i, s[j], e[j] + 1, c_score])
            # NMS across thresholds within one video/class at IoU 0.2.
            pick_idx = NonMaximumSuppression(np.array(vid_cls_proposal), 0.2)
            nms_vid_cls_proposal = [vid_cls_proposal[k] for k in pick_idx]
            c_temp += nms_vid_cls_proposal
        if len(c_temp) > 0:
            c_temp = np.array(c_temp)
        dataset_segment_predict.append(c_temp)
    """
    for i, pred in enumerate(dataset_segment_predict):
        print (f"#{i} class {c} has {len(pred)} predictions")
    """
    return dataset_segment_predict
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def IntergrateSegs(rgb_segs, flow_segs, th, args):
    """Fuse per-class RGB and flow proposals video by video with NMS.

    rgb_segs / flow_segs: lists (one entry per class) of (N, 4) proposal
    arrays [video_idx, start, end, score], as produced by getActLoc.
    th: NMS overlap threshold used when both streams propose on a video.

    Returns a list (one entry per class) of fused (N, 4) proposal arrays.
    """
    NUM_CLASS = args.class_num
    # NOTE(review): hard-coded video count — presumably the THUMOS'14-style
    # test-set size for this dataset; confirm it matches the loader.
    NUM_VID = 212
    segs = []
    for i in range(NUM_CLASS):
        class_seg = []
        rgb_seg = rgb_segs[i]
        flow_seg = flow_segs[i]
        # Column 0 is the video index each proposal belongs to.
        rgb_seg_ind = np.array(rgb_seg)[:, 0]
        flow_seg_ind = np.array(flow_seg)[:, 0]
        for j in range(NUM_VID):
            rgb_find = np.where(rgb_seg_ind == j)
            flow_find = np.where(flow_seg_ind == j)
            if len(rgb_find[0]) == 0 and len(flow_find[0]) == 0:
                continue
            elif len(rgb_find[0]) != 0 and len(flow_find[0]) != 0:
                # Both streams fired: pool their proposals and suppress overlaps.
                rgb_vid_seg = rgb_seg[rgb_find[0]]
                flow_vid_seg = flow_seg[flow_find[0]]
                fuse_seg = np.concatenate([rgb_vid_seg, flow_vid_seg], axis=0)
                pick_idx = NonMaximumSuppression(fuse_seg, th)
                fuse_segs = fuse_seg[pick_idx]
                class_seg.append(fuse_segs)
            elif len(rgb_find[0]) != 0 and len(flow_find[0]) == 0:
                vid_seg = rgb_seg[rgb_find[0]]
                class_seg.append(vid_seg)
            elif len(rgb_find[0]) == 0 and len(flow_find[0]) != 0:
                vid_seg = flow_seg[flow_find[0]]
                class_seg.append(vid_seg)
        # NOTE(review): if a class has no proposals in any video, class_seg is
        # empty and np.concatenate raises — upstream seems to guarantee at
        # least one proposal per class; verify.
        class_seg = np.concatenate(class_seg, axis=0)
        segs.append(class_seg)
    return segs
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def NonMaximumSuppression(segs, overlapThresh):
    """Greedy non-maximum suppression over 1-D temporal segments.

    segs: (N, 4) array of [video_idx, start, end, score] rows.
    overlapThresh: a remaining segment is dropped when its overlap with the
    picked one, normalised by the remaining segment's own length, exceeds this.

    Returns the indices of the kept segments, best score first.
    """
    if len(segs) == 0:
        return []
    # Work in floats so the overlap-ratio division below is exact.
    if segs.dtype.kind == "i":
        segs = segs.astype("float")

    starts = segs[:, 1]
    ends = segs[:, 2]
    scores = segs[:, 3]
    lengths = ends - starts + 1

    pick = []
    order = np.argsort(scores)  # ascending; best candidate sits at the back
    while len(order) > 0:
        best = order[-1]
        pick.append(best)
        rest = order[:-1]
        # Overlap of every remaining segment with the picked one.
        inter = np.maximum(
            0,
            np.minimum(ends[best], ends[rest])
            - np.maximum(starts[best], starts[rest])
            + 1,
        )
        ratio = inter / lengths[rest]
        # Keep only segments whose overlap ratio stays within the threshold.
        order = rest[ratio <= overlapThresh]
    return pick
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def getLocMAP(seg_preds, th, annotation_path, args, multi=False, factor=1.0):
    """Compute localization mAP at a single IoU threshold *th*.

    seg_preds: list of per-class (N, 4) proposal arrays
    [video_idx, start, end, score].
    annotation_path: pickle with per-frame labels under key "L".
    factor: seconds/frames scale applied to ground-truth boundaries.

    Returns mean AP (in percent) over the hard-coded class set [0, 1, 2, 3].

    NOTE(review): only the multi=True path builds gtsegments/gtlabels; calling
    with multi=False raises NameError below (the single-label branch was
    commented out). Confirm all callers pass multi=True.
    """
    # Annotations are pickled; fall back to binary/latin1 for python2-era pickles.
    try:
        with open(annotation_path) as f:
            data = pickle.load(f)
    except:
        # for pickle file from python2
        with open(annotation_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")

    if multi:
        gtsegments = []
        gtlabels = []
        for idx in range(len(data["L"])):
            gt = data["L"][idx]
            gt_ = set(gt)
            # Hard-coded background class id (was args.model_args["num_classes"]).
            # NOTE(review): assumes exactly 4 foreground classes — confirm.
            gt_.discard(4)
            gts = []
            gtl = []
            for c in list(gt_):
                gt_encoded = encode_mask_to_rle(gt == c)
                # RLE starts are 1-based: convert (start, length) to
                # 0-based inclusive [start, end] pairs.
                gts.extend(
                    [
                        [x - 1, x + y - 2]
                        for x, y in zip(gt_encoded[::2], gt_encoded[1::2])
                    ]
                )
                gtl.extend([c for item in gt_encoded[::2]])
            gtsegments.append(gts)
            gtlabels.append(gtl)

    # Hard-coded set of classes with temporal labels (was derived from gtlabels).
    templabelidx = [0,1,2,3]
    ap = []
    for c in templabelidx:
        segment_predict = seg_preds[c]
        # Sort the list of predictions for class c based on score
        if len(segment_predict) == 0:
            ap.append(0.0)
            continue
        segment_predict = segment_predict[np.argsort(-segment_predict[:, 3])]

        # Create gt list: [video_idx, start, end] for every GT segment of class c.
        segment_gt = [
            [i, gtsegments[i][j][0], gtsegments[i][j][1]]
            for i in range(len(gtsegments))
            for j in range(len(gtsegments[i]))
            if gtlabels[i][j] == c
        ]
        gtpos = len(segment_gt)

        # Compare predictions and gt: greedy one-to-one matching in score order.
        tp, fp = [], []
        for i in range(len(segment_predict)):
            matched = False
            best_iou = 0
            for j in range(len(segment_gt)):
                if segment_predict[i][0] == segment_gt[j][0]:
                    gt = range(
                        int(round(segment_gt[j][1] * factor)),
                        int(round(segment_gt[j][2] * factor)),
                    )
                    p = range(int(segment_predict[i][1]), int(segment_predict[i][2]))
                    union_set = set(gt).union(set(p))
                    if len(union_set) == 0:
                        IoU = 0.0  # both ranges empty: treat as no overlap
                    else:
                        IoU = float(len(set(gt).intersection(set(p)))) / float(len(union_set))
                    if IoU >= th:
                        matched = True
                        # best_j is always assigned before use: IoU >= th > 0
                        # implies IoU > the initial best_iou of 0.
                        if IoU > best_iou:
                            best_iou = IoU
                            best_j = j
            if matched:
                # Consume the matched GT so it cannot match a later prediction.
                del segment_gt[best_j]
            tp.append(float(matched))
            fp.append(1.0 - float(matched))
        tp_c = np.cumsum(tp)
        fp_c = np.cumsum(fp)
        if sum(tp) == 0:
            prc = 0.0
        else:
            cur_prec = tp_c / (fp_c + tp_c)
            cur_rec = tp_c / gtpos
            prc = _ap_from_pr(cur_prec, cur_rec)
        ap.append(prc)

    # Per-class AP in percent, space-separated.
    print(f" ".join([f"{item*100:.2f}" for item in ap]))
    if ap:
        return 100 * np.mean(ap)
    else:
        return 0
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# Inspired by Pascal VOC evaluation tool.
|
| 362 |
+
def _ap_from_pr(prec, rec):
|
| 363 |
+
mprec = np.hstack([[0], prec, [0]])
|
| 364 |
+
mrec = np.hstack([[0], rec, [1]])
|
| 365 |
+
|
| 366 |
+
for i in range(len(mprec) - 1)[::-1]:
|
| 367 |
+
mprec[i] = max(mprec[i], mprec[i + 1])
|
| 368 |
+
|
| 369 |
+
idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
|
| 370 |
+
ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx])
|
| 371 |
+
|
| 372 |
+
return ap
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def compute_iou(dur1, dur2):
    """Temporal IoU of two [start, end] intervals; 0 when they do not overlap."""
    overlap_start = max(dur1[0], dur2[0])
    overlap_end = min(dur1[1], dur2[1])

    # No positive-length intersection means zero IoU.
    if overlap_start >= overlap_end:
        return 0

    intersection = overlap_end - overlap_start
    span = max(dur1[1], dur2[1]) - min(dur1[0], dur2[0])
    return intersection / span
|
| 387 |
+
|
| 388 |
+
def getActLoc1(
    frm_preds, act_thresh_cas=np.arange(0.03, 0.055, 0.005)
):
    """Generate per-class action proposals from frame-level class scores.

    frm_preds: sequence of per-video (T, C) frame-score arrays.
    act_thresh_cas: activation thresholds used to binarize each class CAS.

    Returns a list with one entry per class; each entry is an (N, 4) array of
    [video_idx, start, end, inner-outer contrast score] proposals (or an empty
    list when the class produced none).

    Fixes vs. original: removed the pointless element-by-element copy of
    frm_preds into a temporary list and the hand-rolled first-difference loop
    (replaced with np.diff); behavior is unchanged.
    """
    dataset_segment_predict = []
    num_classes = frm_preds[0].shape[1]
    for c in range(num_classes):
        c_temp = []
        # Get list of all predictions for class c
        for i, frame_scores in enumerate(frm_preds):
            vid_cas = frame_scores[:, c]
            vid_cls_proposal = []
            for thres in act_thresh_cas:
                # Zero-pad so edges at the sequence boundaries are detected.
                vid_pred = np.concatenate(
                    [np.zeros(1), (vid_cas > thres).astype("float32"), np.zeros(1)],
                    axis=0,
                )
                edges = np.diff(vid_pred)
                proposal_starts = np.where(edges == 1)[0]
                proposal_ends = np.where(edges == -1)[0]
                for s_t, e_t in zip(proposal_starts, proposal_ends):
                    len_proposal = e_t - s_t
                    # Keep only proposals at least 3 frames long.
                    if len_proposal < 3:
                        continue
                    inner_score = np.mean(vid_cas[s_t : e_t + 1])
                    # Outer context: 25% of the proposal length on each side.
                    outer_s = max(0, int(s_t - 0.25 * len_proposal))
                    outer_e = min(
                        int(vid_cas.shape[0] - 1),
                        int(e_t + 0.25 * len_proposal + 1),
                    )
                    outer_idx = list(range(outer_s, int(s_t))) + list(
                        range(int(e_t + 1), outer_e)
                    )
                    outer_score = np.mean(vid_cas[outer_idx]) if outer_idx else 0
                    # Inner-minus-outer contrast; 0.6 is an empirical weight.
                    c_score = inner_score - 0.6 * outer_score
                    vid_cls_proposal.append([i, s_t, e_t + 1, c_score])
            # NMS across thresholds within one video/class at IoU 0.2.
            pick_idx = NonMaximumSuppression(np.array(vid_cls_proposal), 0.2)
            c_temp += [vid_cls_proposal[k] for k in pick_idx]
        if len(c_temp) > 0:
            c_temp = np.array(c_temp)
        dataset_segment_predict.append(c_temp)
    return dataset_segment_predict
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def getSingleStreamDetectionMAP(
    vid_preds, frm_preds, vid_lens, annotation_path, args, multi=False, factor=1.0
):
    """Detection mAP of one prediction stream at IoU 0.1..0.5.

    Proposals are generated once from the frame predictions, then scored with
    getLocMAP at each IoU threshold. Returns (dmap_list, iou_list).
    """
    thresholds = np.arange(
        args.start_threshold, args.end_threshold, args.threshold_interval
    )
    proposals = getActLoc1(frm_preds, thresholds)

    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    dmap_list = []
    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        score = getLocMAP(proposals, iou, annotation_path, args, multi=multi, factor=factor)
        dmap_list.append(score)
    return dmap_list, iou_list
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def getTwoStreamDetectionMAP(
    rgb_vid_preds,
    flow_vid_preds,
    rgb_frm_preds,
    flow_frm_preds,
    vid_lens,
    annotation_path,
    args,
):
    """Two-stream detection mAP: fuse RGB and flow proposals, score at IoU 0.1..0.7.

    Returns (dmap_list, iou_list).
    """
    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    dmap_list = []
    # NOTE(review): RGB frame scores and thresholds are both scaled by 0.1 —
    # presumably to down-weight the RGB stream relative to flow; the scale
    # cancels in the thresholding but changes proposal contrast scores. Confirm.
    rgb_seg = getActLoc(
        rgb_vid_preds,
        rgb_frm_preds * 0.1,
        vid_lens,
        np.arange(args.start_threshold, args.end_threshold, args.threshold_interval)
        * 0.1,
        annotation_path,
        args,
    )
    flow_seg = getActLoc(
        flow_vid_preds,
        flow_frm_preds,
        vid_lens,
        np.arange(args.start_threshold, args.end_threshold, args.threshold_interval),
        annotation_path,
        args,
    )
    # Fuse the two proposal sets per class/video with NMS at overlap 0.9.
    seg = IntergrateSegs(rgb_seg, flow_seg, 0.9, args)
    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        dmap_list.append(getLocMAP(seg, iou, annotation_path, args))

    return dmap_list, iou_list
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def getSingleStreamDetectionMAP_gcn(
    seg, annotation_path, args, multi=False, factor=1.0
):
    """Detection mAP for pre-computed GCN proposals at IoU 0.1..0.5.

    seg: list of per-class ndarrays of shape (#pred, 4); each row expands as
    [video_index, start, end, score].

    Returns (dmap_list, iou_list).

    Fix vs. original: the first assignment `iou_list = [0.3, 0.5]` was dead
    code, immediately overwritten — removed.
    """
    iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    dmap_list = []

    for iou in iou_list:
        print("Testing for IoU %f" % iou)
        dmap_list.append(
            getLocMAP(seg, iou, annotation_path, args, multi=multi, factor=factor)
        )
    return dmap_list, iou_list
|
evaluation/eval.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
|
| 6 |
+
from .classificationMAP import getClassificationMAP as cmAP
|
| 7 |
+
from .detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 8 |
+
from .detectionMAP import getTwoStreamDetectionMAP as dtmAP
|
| 9 |
+
from .utils import write_results_to_eval_file, write_results_to_file
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def ss_eval(epoch, dataloader, args, logger, model, device):
    """Single-stream evaluation: classification mAP and detection mAP.

    Runs *model* over every sample in *dataloader* (batch size assumed 1 —
    outputs are squeezed on axis 0; TODO confirm), collects video- and
    frame-level predictions, then computes classification mAP and detection
    mAP, logs them, and appends them to the results file.
    """
    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    for num, sample in enumerate(dataloader):
        if (num + 1) % 100 == 0:
            print("Testing test data point %d of %d" % (num + 1, len(dataloader)))

        features = sample["data"].numpy()
        label = sample["labels"].numpy()
        vid_len = sample["vid_len"].numpy()

        features = torch.from_numpy(features).float().to(device)

        with torch.no_grad():
            # Model is assumed to return (_, video scores, _, frame scores).
            _, vid_pred, _, frm_scr = model(Variable(features))
            frm_pred = F.softmax(frm_scr, -1)
            vid_pred = np.squeeze(vid_pred.cpu().data.numpy(), axis=0)
            frm_pred = np.squeeze(frm_pred.cpu().data.numpy(), axis=0)
            label = np.squeeze(label, axis=0)

        vid_preds.append(vid_pred)
        frm_preds.append(frm_pred)
        vid_lens.append(vid_len)
        labels.append(label)

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # Classification mAP over video-level predictions.
    cmap = cmAP(vid_preds, labels)
    # Detection mAP per IoU threshold over frame-level predictions.
    dmap, iou = dsmAP(
        vid_preds, frm_preds, vid_lens, dataloader.dataset.path_to_annotations, args
    )

    print("Classification map %f" % cmap)
    for item in list(zip(iou, dmap)):
        print("Detection map @ %f = %f" % (item[0], item[1]))

    logger.log_value("Test Classification mAP", cmap, epoch)
    for item in list(zip(dmap, iou)):
        logger.log_value("Test Detection1 mAP @ IoU = " + str(item[1]), item[0], epoch)

    write_results_to_file(args, dmap, cmap, epoch)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def ts_eval(dataloader, args, logger, rgb_model, flow_model, device):
    """Two-stream (RGB + optical flow) evaluation.

    Runs both models over every sample (batch size assumed 1 — outputs are
    squeezed on axis 0; TODO confirm), computes the fused detection mAP via
    dtmAP, prints the per-IoU results and their average, and appends the
    results to the eval log file.

    Fix vs. original: the running total was named `sum`, shadowing the
    builtin — renamed to `total`.
    """
    rgb_vid_preds = []
    rgb_frame_preds = []
    flow_vid_preds = []
    flow_frame_preds = []
    vid_lens = []
    labels = []

    for num, sample in enumerate(dataloader):
        if (num + 1) % 100 == 0:
            print("Testing test data point %d of %d" % (num + 1, len(dataloader)))

        rgb_features = sample["rgb_data"].numpy()
        flow_features = sample["flow_data"].numpy()
        label = sample["labels"].numpy()
        vid_len = sample["vid_len"].numpy()

        rgb_features_inp = torch.from_numpy(rgb_features).float().to(device)
        flow_features_inp = torch.from_numpy(flow_features).float().to(device)

        with torch.no_grad():
            # Each model is assumed to return (_, video scores, _, frame scores).
            _, rgb_video_pred, _, rgb_frame_scr = rgb_model(Variable(rgb_features_inp))
            _, flow_video_pred, _, flow_frame_scr = flow_model(
                Variable(flow_features_inp)
            )

            rgb_frame_pred = F.softmax(rgb_frame_scr, -1)
            flow_frame_pred = F.softmax(flow_frame_scr, -1)

            rgb_frame_pred = np.squeeze(rgb_frame_pred.cpu().data.numpy(), axis=0)
            flow_frame_pred = np.squeeze(flow_frame_pred.cpu().data.numpy(), axis=0)
            rgb_video_pred = np.squeeze(rgb_video_pred.cpu().data.numpy(), axis=0)
            flow_video_pred = np.squeeze(flow_video_pred.cpu().data.numpy(), axis=0)
            label = np.squeeze(label, axis=0)

        rgb_vid_preds.append(rgb_video_pred)
        rgb_frame_preds.append(rgb_frame_pred)
        flow_vid_preds.append(flow_video_pred)
        flow_frame_preds.append(flow_frame_pred)
        vid_lens.append(vid_len)
        labels.append(label)

    rgb_vid_preds = np.array(rgb_vid_preds)
    rgb_frame_preds = np.array(rgb_frame_preds)
    flow_vid_preds = np.array(flow_vid_preds)
    flow_frame_preds = np.array(flow_frame_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    dmap, iou = dtmAP(
        rgb_vid_preds,
        flow_vid_preds,
        rgb_frame_preds,
        flow_frame_preds,
        vid_lens,
        dataloader.dataset.path_to_annotations,
        args,
    )

    # Average the detection mAP over (at most) the first 7 IoU thresholds.
    total = 0
    count = 0
    for iou_th, dmap_value in zip(iou, dmap):
        print("Detection map @ %f = %f" % (iou_th, dmap_value))
        if count < 7:
            total = total + dmap_value
            count += 1

    print("average map = %f" % (total / count))
    write_results_to_eval_file(args, dmap, args.rgb_load_epoch, args.flow_load_epoch)
|
evaluation/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def str2ind(categoryname, classlist):
    """Index of *categoryname* within a list of UTF-8 byte-string class names.

    Raises IndexError when the category is absent (same as the original).
    """
    matches = [
        idx
        for idx, name in enumerate(classlist)
        if name.decode("utf-8") == categoryname
    ]
    return matches[0]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def strlist2indlist(strlist, classlist):
    """Map every category name in *strlist* to its index in *classlist*."""
    indices = []
    for category in strlist:
        indices.append(str2ind(category, classlist))
    return indices
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def strlist2multihot(strlist, classlist):
    """Multi-hot vector over *classlist* for the categories named in *strlist*."""
    one_hots = np.eye(len(classlist))[strlist2indlist(strlist, classlist)]
    return one_hots.sum(axis=0)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def idx2multihot(id_list, num_class):
    """Multi-hot encoding (length *num_class*) of the class indices in *id_list*.

    Repeated indices accumulate, matching the original summation behavior.
    """
    identity = np.eye(num_class)
    return identity[id_list].sum(axis=0)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def write_results_to_eval_file(args, dmap, itr1, itr2):
    """Append one line of detection-mAP results to the eval log.

    Line format: "<itr1> <itr2> <dmap values to 2 decimals...>".
    The target folder ./ckpt/<dataset>/eval/ must already exist.

    Fix vs. original: the file handle was opened without a context manager and
    leaked if the write raised — now uses `with`.
    """
    file_folder = "./ckpt/" + args.dataset_name + "/eval/"
    file_name = args.dataset_name + "-results.log"
    string_to_write = str(itr1)
    string_to_write += " " + str(itr2)
    for item in dmap:
        string_to_write += " " + "%.2f" % item
    with open(file_folder + file_name, "a+") as fid:
        fid.write(string_to_write + "\n")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def write_results_to_file(args, dmap, cmap, itr):
|
| 37 |
+
file_folder = "./ckpt/" + args.dataset_name + "/" + str(args.model_id) + "/"
|
| 38 |
+
file_name = args.dataset_name + "-results.log"
|
| 39 |
+
fid = open(file_folder + file_name, "a+")
|
| 40 |
+
string_to_write = str(itr)
|
| 41 |
+
for item in dmap:
|
| 42 |
+
string_to_write += " " + "%.2f" % item
|
| 43 |
+
string_to_write += " " + "%.2f" % cmap
|
| 44 |
+
fid.write(string_to_write + "\n")
|
| 45 |
+
fid.close()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def write_settings_to_file(args):
    """Append every attribute of *args* (one "name: value" per line) to the log,
    framed by '#'*80 above and '*'*80 below.

    The target folder ./ckpt/<dataset>/<model_id>/ must already exist.

    Fix vs. original: the file handle was opened without a context manager and
    leaked if the write raised — now uses `with`.
    """
    file_folder = "./ckpt/" + args.dataset_name + "/" + str(args.model_id) + "/"
    file_name = args.dataset_name + "-results.log"
    string_to_write = "#" * 80 + "\n"
    for arg in vars(args):
        string_to_write += str(arg) + ": " + str(getattr(args, arg)) + "\n"
    string_to_write += "*" * 80 + "\n"
    with open(file_folder + file_name, "a+") as fid:
        fid.write(string_to_write)
|
feeders/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import feeder
|
feeders/feeder.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# -*- coding: utf-8 -*-
|
| 4 |
+
#
|
| 5 |
+
# Adapted from https://github.com/lshiwjx/2s-AGCN for BABEL (https://babel.is.tue.mpg.de/)
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import os.path as osp
|
| 11 |
+
import pdb
|
| 12 |
+
import pickle
|
| 13 |
+
import random
|
| 14 |
+
import shutil
|
| 15 |
+
import subprocess
|
| 16 |
+
import sys
|
| 17 |
+
import uuid
|
| 18 |
+
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from feeders import tools
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
|
| 25 |
+
sys.path.extend(["../"])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Feeder(Dataset):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
data_path,
|
| 32 |
+
random_choose=False,
|
| 33 |
+
random_shift=False,
|
| 34 |
+
random_move=False,
|
| 35 |
+
window_size=-1,
|
| 36 |
+
debug=False,
|
| 37 |
+
use_mmap=True,
|
| 38 |
+
frame_pad=False,
|
| 39 |
+
nb_class=3,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
:param data_path:
|
| 44 |
+
:param label_path:
|
| 45 |
+
:param random_choose: If true, randomly choose a portion of the input sequence
|
| 46 |
+
:param random_shift: If true, randomly pad zeros at the begining or end of sequence
|
| 47 |
+
:param random_move:
|
| 48 |
+
:param window_size: The length of the output sequence
|
| 49 |
+
:param normalization: If true, normalize input sequence
|
| 50 |
+
:param debug: If true, only use the first 100 samples
|
| 51 |
+
:param use_mmap: If true, use mmap mode to load data, which can save the running memory
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
self.debug = debug
|
| 55 |
+
self.data_path = data_path
|
| 56 |
+
self.random_choose = random_choose
|
| 57 |
+
self.random_shift = random_shift
|
| 58 |
+
self.random_move = random_move
|
| 59 |
+
self.window_size = window_size
|
| 60 |
+
self.use_mmap = use_mmap
|
| 61 |
+
self.nb_class = nb_class
|
| 62 |
+
self.frame_pad = frame_pad
|
| 63 |
+
self.load_data()
|
| 64 |
+
self.count = 0
|
| 65 |
+
for i in range(len(self.data["X"])):
|
| 66 |
+
assert self.data["L"][i].shape[0] == self.data["X"][i].shape[1]
|
| 67 |
+
|
| 68 |
+
self.prediction = [
|
| 69 |
+
np.zeros((item.shape[0], 10, self.nb_class + 1), dtype=np.float32)
|
| 70 |
+
for item in self.data["L"]
|
| 71 |
+
]
|
| 72 |
+
self.soft_labels = [
|
| 73 |
+
np.zeros((item.shape[0], self.nb_class + 1), dtype=np.float32)
|
| 74 |
+
for item in self.data["L"]
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def load_data(self):
|
| 78 |
+
# data: N, C, T, V, M
|
| 79 |
+
# load data
|
| 80 |
+
try:
|
| 81 |
+
with open(self.data_path) as f:
|
| 82 |
+
self.data = pickle.load(f)
|
| 83 |
+
except:
|
| 84 |
+
# for pickle file from python2
|
| 85 |
+
with open(self.data_path, "rb") as f:
|
| 86 |
+
self.data = pickle.load(f, encoding="latin1")
|
| 87 |
+
|
| 88 |
+
def label_update(self, results, indexs):
|
| 89 |
+
self.count += 1
|
| 90 |
+
|
| 91 |
+
# While updating the noisy label y_i by the probability s, we used the average output probability of the network of the past 10 epochs as s.
|
| 92 |
+
idx = (self.count - 1) % 10
|
| 93 |
+
|
| 94 |
+
for ind, res in zip(indexs, results):
|
| 95 |
+
self.prediction[ind][:, idx, :] = res
|
| 96 |
+
|
| 97 |
+
for i in range(len(self.prediction)):
|
| 98 |
+
self.soft_labels[i] = self.prediction[i].mean(axis=1)
|
| 99 |
+
|
| 100 |
+
def __len__(self):
|
| 101 |
+
return len(self.data["X"])
|
| 102 |
+
|
| 103 |
+
def __iter__(self):
|
| 104 |
+
return self
|
| 105 |
+
|
| 106 |
+
def __getitem__(self, index):
    """Return one training sample.

    :param index: dataset index.
    :return: tuple of
        data_numpy: (C, T, V, M) joints read from the PKL; time-padded only
            when self.frame_pad is set
        label: multi-hot video-level label, shape (nb_class,)
        gt: per-frame action label array
        mask: all-ones array shaped like gt
        index: the dataset index passed in (used by label_update)
        frame_label: per-frame soft labels accumulated over past epochs
    """
    data_numpy = self.data["X"][index]
    data_numpy = np.array(data_numpy)

    # multi-hot encode the (possibly multiple) video-level labels
    label = self.data["Y"][index]
    label_np = np.zeros(self.nb_class)
    for item in label:
        label_np[item] = 1
    label = np.array(label_np)

    # frame-level ground-truth action labels
    gt = self.data["L"][index]
    gt = np.array(gt)

    # optional temporal augmentations
    if self.random_shift:
        data_numpy = tools.random_shift(data_numpy)
    if self.random_choose:
        data_numpy = tools.random_choose(data_numpy, self.window_size)
    elif self.window_size > 0:
        data_numpy = tools.auto_pading(data_numpy, self.window_size)
    if self.random_move:
        data_numpy = tools.random_move(data_numpy)

    # zero-pad T up to the next multiple of 15 — presumably the model's
    # temporal downsampling factor; TODO confirm. NOTE(review): gt/mask
    # are NOT padded to the same length here.
    if self.frame_pad:
        C, T, V, M = data_numpy.shape
        if T % 15 != 0:
            new_T = T + 15 - T % 15

            data_numpy_paded = np.zeros((C, new_T, V, M))
            data_numpy_paded[:, :T, :, :] = data_numpy

            data_numpy = data_numpy_paded

    mask = np.ones_like(gt)

    frame_label = self.soft_labels[index]

    return data_numpy, label, gt, mask, index, frame_label
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def import_class(name):
    """Resolve a dotted path such as "graph.ntu_rgb_d.Graph" to the object it names."""
    module_name, *attrs = name.split(".")
    obj = __import__(module_name)
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test(
    dataset,
    preds=None,
    th=None,
    idx=None,
    graph="graph.ntu_rgb_d.Graph",
    is_3d=True,
    folder_p="viz",
    label_json="prepare/configs/action_label_split1.json",
):
    """Visualize one sample frame by frame with matplotlib and write a video.

    :param dataset: feeder whose __getitem__ yields (data, label, gt, ...).
    :param preds: optional per-frame class scores, indexed as preds[t, c],
        overlaid on each frame.
    :param th: unused; kept for backward compatibility of the signature.
    :param idx: index of the sample to visualize.
    :param graph: dotted path of the skeleton Graph class; None plots
        unconnected joints instead of bones.
    :param is_3d: when visualizing NTU (3D joints), set it True.
    :param folder_p: output folder; frames go to folder_p/frames/<t>.jpg and
        the video to folder_p/<idx>.mp4.
    :param label_json: JSON mapping action name -> class index.
    """

    with open(label_json) as infile:
        jc = json.load(infile)

    # invert name -> index into index -> name, plus a catch-all class
    idx2act = {v: k for k, v in jc.items()}
    idx2act[len(idx2act)] = "other"

    # start from a clean frame directory
    if osp.exists(osp.join(folder_p, "frames")):
        shutil.rmtree(osp.join(folder_p, "frames"))
    os.makedirs(osp.join(folder_p, "frames"))

    # BUGFIX: Feeder.__getitem__ returns six values
    # (data, label, gt, mask, index, frame_label); the old 4-way unpack
    # raised "too many values to unpack". Star-unpack the remainder.
    data, label, gt, *_ = dataset[idx]
    data = data.reshape((1,) + data.shape)

    N, C, T, V, M = data.shape

    plt.ion()
    fig = plt.figure()
    if is_3d:
        # import registers the "3d" projection with matplotlib
        from mpl_toolkits.mplot3d import Axes3D

        ax = fig.add_subplot(111, projection="3d")
    else:
        ax = fig.add_subplot(111)

    if graph is None:
        # no skeleton topology given: scatter the joints per person
        p_type = ["b.", "g.", "r.", "c.", "m.", "y.", "k.", "k.", "k.", "k."]
        pose = [ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M)]
        ax.axis([-1, 1, -1, 1])
        for t in range(T):
            for m in range(M):
                pose[m].set_xdata(data[0, 0, t, :, m])
                pose[m].set_ydata(data[0, 1, t, :, m])
            fig.canvas.draw()
            plt.pause(0.001)
    else:
        p_type = ["b-", "g-", "r-", "c-", "m-", "y-", "k-", "k-", "k-", "k-"]
        import sys
        from os import path

        sys.path.append(
            path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
        )
        G = import_class(graph)()
        edge = G.inward
        # one polyline per bone per person
        pose = []
        for m in range(M):
            a = []
            for i in range(len(edge)):
                if is_3d:
                    a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0])
                else:
                    a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0])
            pose.append(a)
        ax.axis([-1, 1, -1, 1])
        if is_3d:
            ax.set_zlim3d(-1, 1)
        for t in range(T):
            for m in range(M):
                for i, (v1, v2) in enumerate(edge):
                    x1 = data[0, :2, t, v1, m]
                    x2 = data[0, :2, t, v2, m]
                    # skip bones whose endpoints are both missing (all-zero)
                    if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1:
                        pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m])
                        pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m])
                        if is_3d:
                            pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m])

            # NOTE(review): int(label) assumes a scalar label; the feeder in
            # this repo returns a multi-hot vector — confirm before relying
            # on the overlay text.
            if gt[t]:
                text = ax.text2D(
                    0.1, 0.9, idx2act[int(label)], size=20, transform=ax.transAxes
                )

            if preds is not None:
                pred_idx = preds[t].argmax()
                text_pred = ax.text2D(
                    0.6,
                    0.9,
                    idx2act[int(pred_idx)] + f": {preds[t, pred_idx]:.2f}",
                    size=20,
                    transform=ax.transAxes,
                )

            fig.canvas.draw()
            plt.savefig(osp.join(folder_p, "frames", str(t) + ".jpg"), dpi=300)
            if gt[t]:
                text.remove()
            if preds is not None:
                text_pred.remove()

    write_vid_from_imgs(folder_p, idx)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def write_vid_from_imgs(folder_p, fname, fps=30):
    """Collate frames into a video with ffmpeg, then delete the frame dir.

    Args:
        folder_p (str): Frame images are in the path: folder_p/frames/<int>.jpg
        fname: Basename (without extension) of the output video.
        fps (float): Output frame rate.

    Returns:
        None. Output video is stored in the path: folder_p/<fname>.mp4
    """
    vid_p = osp.join(folder_p, f"{fname}.mp4")
    cmd = [
        "ffmpeg",
        "-r",
        str(int(fps)),
        "-i",
        osp.join(folder_p, "frames", "%d.jpg"),
        "-y",
        vid_p,
    ]
    # subprocess.DEVNULL fixes the file-handle leak of the previous
    # never-closed open(os.devnull, "w").
    retcode = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    if retcode != 0:
        print(
            "*******ValueError(Error {0} executing command: {1}*********".format(
                retcode, " ".join(cmd)
            )
        )
    shutil.rmtree(osp.join(folder_p, "frames"))
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if __name__ == "__main__":
    import os

    # Forward X11 rendering over SSH; adjust/remove for a local display.
    os.environ["DISPLAY"] = "localhost:10.0"
    data_path = "dataset/processed_data/train_split1.pkl"
    graph = "graph.ntu_rgb_d.Graph"
    dataset = Feeder(data_path)
    # Visualize the first sample of split 1 in 3D.
    test(dataset, idx=0, graph=graph, is_3d=True)
|
feeders/tools.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def downsample(data_numpy, step, random_sample=True):
    """Keep every `step`-th frame of a (C, T, V, M) sequence.

    With random_sample=True the stride phase is drawn uniformly from
    [0, step); otherwise sampling starts at frame 0.
    """
    offset = np.random.randint(step) if random_sample else 0
    return data_numpy[:, offset::step, :, :]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def temporal_slice(data_numpy, step):
    """Fold consecutive groups of `step` frames into the person axis.

    Input (C, T, V, M) -> output (C, T // step, V, step * M). T must be a
    multiple of `step`.
    """
    # input: C,T,V,M
    C, T, V, M = data_numpy.shape
    # T // step: integer division — the old `T / step` produced a float
    # shape, which raises TypeError in NumPy reshape on Python 3.
    return (
        data_numpy.reshape(C, T // step, step, V, M)
        .transpose((0, 1, 3, 2, 4))
        .reshape(C, T // step, V, step * M)
    )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def mean_subtractor(data_numpy, mean):
    """Subtract a scalar `mean` from all frames up to the last valid one.

    A frame is "valid" when any of its entries is non-zero; only frames
    before the last valid frame are shifted (the subtraction is in place).
    Returns None when mean == 0, matching the historical contract.
    """
    # input: C,T,V,M
    if mean == 0:
        return
    C, T, V, M = data_numpy.shape
    # per-frame validity: any non-zero entry across C, V and M
    valid = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
    last = len(valid) - valid[::-1].argmax()
    data_numpy[:, :last, :, :] = data_numpy[:, :last, :, :] - mean
    return data_numpy
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def auto_pading(data_numpy, size, random_pad=False):
    """Zero-pad a (C, T, V, M) sequence along time up to length `size`.

    If T >= size the input is returned untouched. With random_pad=True the
    original frames are placed at a random offset inside the padding.
    """
    C, T, V, M = data_numpy.shape
    if T >= size:
        return data_numpy
    offset = random.randint(0, size - T) if random_pad else 0
    padded = np.zeros((C, size, V, M))
    padded[:, offset : offset + T, :, :] = data_numpy
    return padded
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def random_choose(data_numpy, size, auto_pad=True):
    """Crop a random temporal window of length `size` from (C, T, V, M) data.

    A clip shorter than `size` is zero-padded at a random offset when
    auto_pad is True, otherwise returned as-is. Note: all-zero frames may
    end up inside the chosen window.
    """
    C, T, V, M = data_numpy.shape
    if T == size:
        return data_numpy
    if T < size:
        if not auto_pad:
            return data_numpy
        return auto_pading(data_numpy, size, random_pad=True)
    start = random.randint(0, T - size)
    return data_numpy[:, start : start + size, :, :]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def random_move(
    data_numpy,
    angle_candidate=(-10.0, -5.0, 0.0, 5.0, 10.0),
    scale_candidate=(0.9, 1.0, 1.1),
    transform_candidate=(-0.2, -0.1, 0.0, 0.1, 0.2),
    move_time_candidate=(1,),
):
    """Apply a random, temporally-interpolated 2D rotation/scale/translation.

    Candidate defaults are tuples now — mutable list defaults are shared
    between calls, a classic Python pitfall — with identical values.

    Args:
        data_numpy: (C, T, V, M) array; channels 0 and 1 (x, y) are
            transformed IN PLACE, other channels are untouched.
        angle_candidate: rotation angles (degrees) sampled per key node.
        scale_candidate: scale factors sampled per key node.
        transform_candidate: x/y translations sampled per key node.
        move_time_candidate: number of interpolation segments.

    Returns:
        The same (mutated) data_numpy array.
    """
    # input: C,T,V,M
    C, T, V, M = data_numpy.shape
    move_time = random.choice(move_time_candidate)
    # key frames between which the transform parameters are interpolated
    node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
    node = np.append(node, T)
    num_node = len(node)

    A = np.random.choice(angle_candidate, num_node)
    S = np.random.choice(scale_candidate, num_node)
    T_x = np.random.choice(transform_candidate, num_node)
    T_y = np.random.choice(transform_candidate, num_node)

    a = np.zeros(T)
    s = np.zeros(T)
    t_x = np.zeros(T)
    t_y = np.zeros(T)

    # linearly interpolate parameters between consecutive key frames
    for i in range(num_node - 1):
        a[node[i] : node[i + 1]] = (
            np.linspace(A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
        )
        s[node[i] : node[i + 1]] = np.linspace(S[i], S[i + 1], node[i + 1] - node[i])
        t_x[node[i] : node[i + 1]] = np.linspace(
            T_x[i], T_x[i + 1], node[i + 1] - node[i]
        )
        t_y[node[i] : node[i + 1]] = np.linspace(
            T_y[i], T_y[i + 1], node[i + 1] - node[i]
        )

    # per-frame 2x2 rotation+scale matrices
    theta = np.array(
        [[np.cos(a) * s, -np.sin(a) * s], [np.sin(a) * s, np.cos(a) * s]]
    )

    # rotate/scale then translate the (x, y) channels of every frame
    for i_frame in range(T):
        xy = data_numpy[0:2, i_frame, :, :]
        new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
        new_xy[0] += t_x[i_frame]
        new_xy[1] += t_y[i_frame]  # translation
        data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)

    return data_numpy
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def random_shift(data_numpy):
    """Relocate the valid (non-zero) frame span to a random temporal offset.

    The contiguous block of frames containing non-zero data is copied into
    a fresh zero array of the same (C, T, V, M) shape at a random position.
    """
    C, T, V, M = data_numpy.shape
    shifted = np.zeros(data_numpy.shape)
    # frames with any non-zero entry count as valid
    valid = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
    start = valid.argmax()
    stop = len(valid) - valid[::-1].argmax()

    span = stop - start
    offset = random.randint(0, T - span)
    shifted[:, offset : offset + span, :, :] = data_numpy[:, start:stop, :, :]

    return shifted
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def openpose_match(data_numpy):
    """Re-associate pose detections across frames so each person slot is
    temporally consistent, then sort people by overall confidence.

    Expects (C=3, T, V, M) where channel 2 holds per-joint confidence
    scores (OpenPose output convention — TODO confirm with the producer).
    """
    C, T, V, M = data_numpy.shape
    assert C == 3
    # per-frame, per-person total confidence
    score = data_numpy[2, :, :, :].sum(axis=1)
    # the rank of body confidence in each frame (shape: T-1, M)
    rank = (-score[0 : T - 1]).argsort(axis=1).reshape(T - 1, M)

    # data of frame 1
    xy1 = data_numpy[0:2, 0 : T - 1, :, :].reshape(2, T - 1, V, M, 1)
    # data of frame 2
    xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M)
    # square of distance between frame 1&2 (shape: T-1, M, M)
    distance = ((xy2 - xy1) ** 2).sum(axis=2).sum(axis=0)

    # match pose: greedy nearest-neighbour assignment, most confident first
    forward_map = np.zeros((T, M), dtype=int) - 1
    forward_map[0] = range(M)
    for m in range(M):
        choose = rank == m
        forward = distance[choose].argmin(axis=1)
        for t in range(T - 1):
            # mark the chosen slot as taken for lower-ranked people
            distance[t, :, forward[t]] = np.inf
        forward_map[1:][choose] = forward
    assert np.all(forward_map >= 0)

    # string data: compose per-step maps into absolute index maps
    for t in range(T - 1):
        forward_map[t + 1] = forward_map[t + 1][forward_map[t]]

    # generate data: apply the person permutation frame by frame
    new_data_numpy = np.zeros(data_numpy.shape)
    for t in range(T):
        new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[t]].transpose(
            1, 2, 0
        )
    data_numpy = new_data_numpy

    # score sort: order people by total confidence over the whole clip
    trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0)
    rank = (-trace_score).argsort()
    data_numpy = data_numpy[:, :, :, rank]

    return data_numpy
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def pad(tensor, padding_value=0):
    """Batch-first pad a list of variable-length tensors (wraps pad_sequence)."""
    padded = pad_sequence(tensor, batch_first=True, padding_value=padding_value)
    return padded
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def collate_with_padding(batch):
    """Collate (data, target, gt, mask) samples into padded batch tensors.

    Each sample's data arrives as (C, T, V, M); sequences are zero-padded
    along T and the batch is returned as (N, C, T_max, V, M).
    """
    data = [torch.tensor(item[0].transpose(1, 0, 2, 3)) for item in batch]
    target = [torch.tensor(item[1]) for item in batch]
    gt = [torch.tensor(item[2]) for item in batch]
    mask = [torch.tensor(item[3]) for item in batch]

    data = pad(data).transpose(1, 2)
    # torch.stack, not torch.tensor: constructing a tensor from a list of
    # tensors raises in current PyTorch (and matches collate_with_padding_multi).
    target = torch.stack(target)
    gt = pad(gt)
    mask = pad(mask)
    return [data, target, gt, mask]
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def collate_with_padding_multi(batch):
    """Collate (data, target, gt, mask) samples; multi-hot targets are stacked.

    Data is zero-padded over time and returned as (N, C, T_max, V, M).
    """
    skeletons = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    labels = [torch.tensor(sample[1]) for sample in batch]
    frame_gts = [torch.tensor(sample[2]) for sample in batch]
    masks = [torch.tensor(sample[3]) for sample in batch]

    batched_data = pad(skeletons).transpose(1, 2)
    batched_target = torch.stack(labels)
    batched_gt = pad(frame_gts)
    batched_mask = pad(masks)
    return [batched_data, batched_target, batched_gt, batched_mask]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def collate_with_padding_multi_velo(batch):
    """Collate (data, velo, target, gt, mask) samples with joint + velocity streams.

    Both streams arrive as (C, T, V, M), are zero-padded over time, and are
    returned as (N, C, T_max, V, M).
    """
    skeletons = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    velocities = [torch.tensor(sample[1].transpose(1, 0, 2, 3)) for sample in batch]
    labels = [torch.tensor(sample[2]) for sample in batch]
    frame_gts = [torch.tensor(sample[3]) for sample in batch]
    masks = [torch.tensor(sample[4]) for sample in batch]

    batched_data = pad(skeletons).transpose(1, 2)
    batched_velo = pad(velocities).transpose(1, 2)
    batched_target = torch.stack(labels)
    batched_gt = pad(frame_gts)
    batched_mask = pad(masks)
    return [batched_data, batched_velo, batched_target, batched_gt, batched_mask]
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def collate_with_padding_multi_joint(batch):
    """Collate (data, target, gt, mask, index, soft_label) samples.

    - data: (C, T, V, M) per sample -> (N, C, T_max, V, M), zero-padded over
      time (padding is zeros, not a repeat of the last frame).
    - gt is padded with 4 so the pad value maps onto the extra class index:
      for 4 actions the frame labels run 0..3 plus 4 for "background"/pad.
    - soft_label is padded with -100 (the conventional ignore value).
    """
    skeletons = [torch.tensor(sample[0].transpose(1, 0, 2, 3)) for sample in batch]
    labels = [torch.tensor(sample[1]) for sample in batch]  # video-level label
    frame_gts = [torch.tensor(sample[2]) for sample in batch]  # frame-level label
    masks = [torch.tensor(sample[3]) for sample in batch]
    indices = [torch.tensor(sample[4]) for sample in batch]
    soft_labels = [torch.tensor(sample[5]) for sample in batch]

    data = pad(skeletons).transpose(1, 2)
    target = torch.stack(labels)
    gt = pad(frame_gts, padding_value=4)
    mask = pad(masks)
    index = torch.tensor(indices)
    soft_label = pad(soft_labels, padding_value=-100)
    return [data, target, gt, mask, index, soft_label]
|
graph/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import kinetics, ntu_rgb_d, tools,ntu_rgb_d_infogcn, nturgbd_blockgcn
|
graph/kinetics.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import networkx as nx
|
| 4 |
+
import numpy as np
|
| 5 |
+
from graph import tools
|
| 6 |
+
|
| 7 |
+
sys.path.extend(["../"])
|
| 8 |
+
|
| 9 |
+
# Joint index:
|
| 10 |
+
# {0, "Nose"}
|
| 11 |
+
# {1, "Neck"},
|
| 12 |
+
# {2, "RShoulder"},
|
| 13 |
+
# {3, "RElbow"},
|
| 14 |
+
# {4, "RWrist"},
|
| 15 |
+
# {5, "LShoulder"},
|
| 16 |
+
# {6, "LElbow"},
|
| 17 |
+
# {7, "LWrist"},
|
| 18 |
+
# {8, "RHip"},
|
| 19 |
+
# {9, "RKnee"},
|
| 20 |
+
# {10, "RAnkle"},
|
| 21 |
+
# {11, "LHip"},
|
| 22 |
+
# {12, "LKnee"},
|
| 23 |
+
# {13, "LAnkle"},
|
| 24 |
+
# {14, "REye"},
|
| 25 |
+
# {15, "LEye"},
|
| 26 |
+
# {16, "REar"},
|
| 27 |
+
# {17, "LEar"},
|
| 28 |
+
|
| 29 |
+
# Edge format: (origin, neighbor)
|
| 30 |
+
num_node = 18
|
| 31 |
+
self_link = [(i, i) for i in range(num_node)]
|
| 32 |
+
inward = [
|
| 33 |
+
(4, 3),
|
| 34 |
+
(3, 2),
|
| 35 |
+
(7, 6),
|
| 36 |
+
(6, 5),
|
| 37 |
+
(13, 12),
|
| 38 |
+
(12, 11),
|
| 39 |
+
(10, 9),
|
| 40 |
+
(9, 8),
|
| 41 |
+
(11, 5),
|
| 42 |
+
(8, 2),
|
| 43 |
+
(5, 1),
|
| 44 |
+
(2, 1),
|
| 45 |
+
(0, 1),
|
| 46 |
+
(15, 0),
|
| 47 |
+
(14, 0),
|
| 48 |
+
(17, 15),
|
| 49 |
+
(16, 14),
|
| 50 |
+
]
|
| 51 |
+
outward = [(j, i) for (i, j) in inward]
|
| 52 |
+
neighbor = inward + outward
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Graph:
    """Spatial graph over the 18 OpenPose/Kinetics joints.

    self.A stacks the identity, normalized inward and normalized outward
    adjacency matrices (shape (3, 18, 18)) built by tools.get_spatial_graph.
    """

    def __init__(self, labeling_mode="spatial"):
        self.A = self.get_adjacency_matrix(labeling_mode)
        self.num_node = num_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor

    def get_adjacency_matrix(self, labeling_mode=None):
        # None means "return whatever was already built"
        if labeling_mode is None:
            return self.A
        if labeling_mode != "spatial":
            raise ValueError()
        return tools.get_spatial_graph(num_node, self_link, inward, outward)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
    # Smoke test: build the spatial adjacency for the 18-joint graph.
    A = Graph("spatial").get_adjacency_matrix()
    print("")
|
graph/ntu_rgb_d.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
from graph import tools
|
| 4 |
+
|
| 5 |
+
sys.path.extend(["../"])
|
| 6 |
+
|
| 7 |
+
num_node = 25
|
| 8 |
+
self_link = [(i, i) for i in range(num_node)]
|
| 9 |
+
inward_ori_index = [
|
| 10 |
+
(1, 2),
|
| 11 |
+
(2, 21),
|
| 12 |
+
(3, 21),
|
| 13 |
+
(4, 3),
|
| 14 |
+
(5, 21),
|
| 15 |
+
(6, 5),
|
| 16 |
+
(7, 6),
|
| 17 |
+
(8, 7),
|
| 18 |
+
(9, 21),
|
| 19 |
+
(10, 9),
|
| 20 |
+
(11, 10),
|
| 21 |
+
(12, 11),
|
| 22 |
+
(13, 1),
|
| 23 |
+
(14, 13),
|
| 24 |
+
(15, 14),
|
| 25 |
+
(16, 15),
|
| 26 |
+
(17, 1),
|
| 27 |
+
(18, 17),
|
| 28 |
+
(19, 18),
|
| 29 |
+
(20, 19),
|
| 30 |
+
(22, 23),
|
| 31 |
+
(23, 8),
|
| 32 |
+
(24, 25),
|
| 33 |
+
(25, 12),
|
| 34 |
+
]
|
| 35 |
+
inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
|
| 36 |
+
outward = [(j, i) for (i, j) in inward]
|
| 37 |
+
neighbor = inward + outward
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Graph:
    """Spatial graph over the 25 NTU RGB+D joints.

    self.A stacks the identity, normalized inward and normalized outward
    adjacency matrices (shape (3, 25, 25)) built by tools.get_spatial_graph.
    """

    def __init__(self, labeling_mode="spatial"):
        self.A = self.get_adjacency_matrix(labeling_mode)
        self.num_node = num_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor

    def get_adjacency_matrix(self, labeling_mode=None):
        # None means "return whatever was already built"
        if labeling_mode is None:
            return self.A
        if labeling_mode != "spatial":
            raise ValueError()
        return tools.get_spatial_graph(num_node, self_link, inward, outward)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
    import os

    import matplotlib.pyplot as plt

    # os.environ['DISPLAY'] = 'localhost:11.0'
    # Visualize each of the three partition matrices (I, inward, outward).
    A = Graph("spatial").get_adjacency_matrix()
    for i in A:
        plt.imshow(i, cmap="gray")
        plt.show()
    print(A)
|
graph/tools.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def edge2mat(link, num_node):
    """Build a (num_node, num_node) matrix with A[j, i] = 1 for each edge (i, j)."""
    A = np.zeros((num_node, num_node))
    for src, dst in link:
        A[dst, src] = 1
    return A
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def normalize_digraph(A):
    """Right-normalize A by column sums: returns A @ D^-1.

    Columns that sum to zero are left as zero rather than divided.
    """
    degree = np.sum(A, 0)
    _, width = A.shape
    D_inv = np.zeros((width, width))
    for i in range(width):
        if degree[i] > 0:
            D_inv[i, i] = degree[i] ** (-1)
    return np.dot(A, D_inv)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_spatial_graph(num_node, self_link, inward, outward):
    """Stack identity, normalized inward and normalized outward adjacency into (3, V, V)."""
    identity = edge2mat(self_link, num_node)
    a_in = normalize_digraph(edge2mat(inward, num_node))
    a_out = normalize_digraph(edge2mat(outward, num_node))
    return np.stack((identity, a_in, a_out))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
def get_sgp_mat(num_in, num_out, link):
    """Column-normalized pooling matrix: A[i, j] = 1 per (i, j) link, columns scaled to sum 1."""
    A = np.zeros((num_in, num_out))
    for src, dst in link:
        A[src, dst] = 1
    return A / np.sum(A, axis=0, keepdims=True)
|
| 38 |
+
|
| 39 |
+
def edge2mat(link, num_node):
    """Adjacency matrix with A[j, i] = 1 for every directed edge (i, j)."""
    # NOTE(review): duplicate definition — an identical edge2mat appears
    # earlier in this module; this later one silently shadows it at import.
    A = np.zeros((num_node, num_node))
    for i, j in link:
        A[j, i] = 1
    return A
|
| 44 |
+
|
| 45 |
+
def get_k_scale_graph(scale, A):
    """Binary reachability graph within `scale` hops of A.

    scale == 1 returns A unchanged; otherwise the result has a 1 wherever a
    path of length 1..scale exists.
    """
    if scale == 1:
        return A
    reach = np.zeros_like(A)
    power = np.eye(A.shape[0])
    for _ in range(scale):
        power = power @ A
        reach += power
    reach[reach > 0] = 1
    return reach
|
| 55 |
+
|
| 56 |
+
def normalize_digraph(A):
    """Right-normalize A by column sums: A @ D^-1; zero columns stay zero."""
    # NOTE(review): duplicate of normalize_digraph defined earlier in this
    # module (identical body); this later definition is the one that wins.
    Dl = np.sum(A, 0)
    h, w = A.shape
    Dn = np.zeros((w, w))
    for i in range(w):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    AD = np.dot(A, Dn)
    return AD
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_spatial_graph(num_node, self_link, inward, outward):
    """Stack (I, normalized inward, normalized outward) into a (3, V, V) tensor."""
    # NOTE(review): duplicate of get_spatial_graph defined earlier in this module.
    I = edge2mat(self_link, num_node)
    In = normalize_digraph(edge2mat(inward, num_node))
    Out = normalize_digraph(edge2mat(outward, num_node))
    A = np.stack((I, In, Out))
    return A
|
| 73 |
+
|
| 74 |
+
def normalize_adjacency_matrix(A):
    """Symmetrically normalize A: D^-1/2 @ A @ D^-1/2, returned as float32.

    Zero-degree nodes are mapped to zero rows/columns; previously
    np.power(0, -0.5) emitted a divide warning and produced inf entries.
    """
    node_degrees = A.sum(-1)
    degs_inv_sqrt = np.zeros_like(node_degrees, dtype=float)
    nonzero = node_degrees > 0
    degs_inv_sqrt[nonzero] = np.power(node_degrees[nonzero], -0.5)
    norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt
    return (norm_degs_matrix @ A @ norm_degs_matrix).astype(np.float32)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def k_adjacency(A, k, with_self=False, self_factor=1):
    """Exactly-k-hop adjacency: nodes reachable in k hops but not in k - 1.

    Computed from binarized powers of (A + I); k == 0 yields the identity.
    with_self adds self_factor * I to the result.
    """
    assert isinstance(A, np.ndarray)
    identity = np.eye(len(A), dtype=A.dtype)
    if k == 0:
        return identity
    within_k = np.minimum(np.linalg.matrix_power(A + identity, k), 1)
    within_k_minus_1 = np.minimum(np.linalg.matrix_power(A + identity, k - 1), 1)
    Ak = within_k - within_k_minus_1
    if with_self:
        Ak = Ak + self_factor * identity
    return Ak
|
| 91 |
+
|
| 92 |
+
def get_multiscale_spatial_graph(num_node, self_link, inward, outward):
    """Stack (I, inward, outward, 2-hop inward, 2-hop outward), each column-normalized."""
    identity = edge2mat(self_link, num_node)
    a_in = edge2mat(inward, num_node)
    a_out = edge2mat(outward, num_node)
    a_in_2hop = k_adjacency(a_in, 2)
    a_out_2hop = k_adjacency(a_out, 2)
    return np.stack(
        (
            identity,
            normalize_digraph(a_in),
            normalize_digraph(a_out),
            normalize_digraph(a_in_2hop),
            normalize_digraph(a_out_2hop),
        )
    )
|
| 104 |
+
|
| 105 |
+
def get_adjacency_matrix(edges, num_nodes):
    """Binary float32 (num_nodes, num_nodes) matrix with A[edge] = 1 per edge tuple."""
    A = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for e in edges:
        A[e] = 1.0
    return A
|
| 110 |
+
|
| 111 |
+
def get_uniform_graph(num_node, self_link, neighbor):
    """Single column-normalized adjacency over neighbor edges plus self-loops."""
    return normalize_digraph(edge2mat(neighbor + self_link, num_node))
|
huggingface.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import upload_folder

# One-shot script: push the local working directory to the Hugging Face Hub
# as a model repository.
upload_folder(
    folder_path="/root/autodl-tmp/workshop2",
    repo_id="qiushuocheng/workshop",
    repo_type="model",  # or "dataset", "space"
    path_in_repo="",  # optional subdirectory in repo
    commit_message="Initial upload"
)
|
human_model/Put SMPLH model here.txt
ADDED
|
File without changes
|
model/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import agcn
|
| 2 |
+
from . import agcn_Unet
|
| 3 |
+
from . import agcn_concat
|
| 4 |
+
from . import agcn_MSFF
|
| 5 |
+
from . import blockgcn_MSFF
|
| 6 |
+
|
| 7 |
+
from . import blockgcn
|
model/agcn.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from einops import rearrange, reduce, repeat
|
| 21 |
+
from torch.autograd import Variable
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def import_class(name):
    """Import and return the object named by a dotted path (e.g. "model.agcn.Model")."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for part in parts[1:]:
        obj = getattr(obj, part)
    return obj
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def conv_branch_init(conv, branches):
    """Initialize one branch conv: zero bias, normal weights with std scaled
    by the total fan across all `branches` parallel branches."""
    weight = conv.weight
    out_channels = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    std = math.sqrt(2.0 / (out_channels * k1 * k2 * branches))
    nn.init.normal_(weight, 0, std)
    nn.init.constant_(conv.bias, 0)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def conv_init(conv):
    """Kaiming-normal (fan_out) init for a conv's weight; zero its bias.

    Args:
        conv: a Conv2d-like module with `.weight` (and optionally `.bias`).
    """
    nn.init.kaiming_normal_(conv.weight, mode="fan_out")
    # Fix: guard against Conv2d(..., bias=False), where conv.bias is None.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def bn_init(bn, scale):
    """Initialise BatchNorm affine parameters: weight = scale, bias = 0."""
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class unit_tcn(nn.Module):
    """Temporal convolution unit: a (kernel_size, 1) conv along the time
    axis followed by BatchNorm.

    Note: `self.relu` is created for API parity with the rest of the file
    but is intentionally not applied in `forward`.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = (kernel_size - 1) // 2  # 'same' padding along time for odd kernels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        """Apply temporal conv then BN; no activation here."""
        return self.bn(self.conv(x))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class unit_gcn(nn.Module):
    """Adaptive spatial graph convolution unit (2s-AGCN style).

    For each of `num_subset` adjacency subsets, combines:
      - the fixed graph adjacency A (non-trainable),
      - a learned global offset PA (shared across samples),
      - a data-dependent attention adjacency computed from embedded features,
    then applies a 1x1 conv per subset and sums the results, with a
    residual ("down") path and BatchNorm + ReLU on top.
    """

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
        super(unit_gcn, self).__init__()
        # Embedding width used for the attention adjacency (conv_a / conv_b).
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        # Learned adjacency offset; created from A for shape, then
        # re-initialised to a small constant so training starts near A alone.
        self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        nn.init.constant_(self.PA, 1e-6)
        # Fixed (non-trainable) adjacency. NOTE(review): Variable is a
        # deprecated no-op wrapper in modern PyTorch; a buffer would be
        # the contemporary equivalent.
        self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
        self.num_subset = num_subset

        # Per-subset 1x1 convs: a/b embed features for attention, d projects
        # the aggregated features to out_channels.
        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        # Residual path: 1x1 conv + BN when channel counts differ, identity otherwise.
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1), nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm2d(out_channels)
        # Softmax over dim=-2 normalises attention over source vertices.
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU()

        # Standard inits for all convs/BNs ...
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # ... but the output BN starts near zero so the unit begins as
        # (almost) a pure residual.
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        # x: (N, C, T, V) — batch, channels, time, vertices (joints).
        N, C, T, V = x.size()
        A = self.A
        # Move the fixed adjacency to the input's GPU when applicable
        # (get_device() returns -1 for CPU tensors).
        if -1 != x.get_device():
            A = A.cuda(x.get_device())
        # Fixed adjacency plus learned offset.
        A = A + self.PA

        y = None
        for i in range(self.num_subset):
            # Attention adjacency: embed, flatten (C*T), and correlate
            # vertices against each other -> (N, V, V).
            A1 = (
                self.conv_a[i](x)
                .permute(0, 3, 1, 2)
                .contiguous()
                .view(N, V, self.inter_c * T)
            )
            A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
            A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))  # N V V
            # Combine data-dependent attention with the (offset) fixed graph.
            A1 = A1 + A[i]
            # Aggregate features over the graph and project.
            A2 = x.view(N, C * T, V)
            z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
            # Sum contributions across subsets.
            y = z + y if y is not None else z

        y = self.bn(y)
        y += self.down(x)
        return self.relu(y)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class TCN_GCN_unit(nn.Module):
    """One ST-GCN block: spatial graph conv (unit_gcn) followed by a
    temporal conv (unit_tcn), with a residual connection and ReLU.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A)
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU()
        if not residual:
            # No skip connection at all.
            self.residual = lambda x: 0
        elif in_channels == out_channels and stride == 1:
            # Shapes match: identity skip.
            self.residual = lambda x: x
        else:
            # Project the skip path to the new channel count / stride.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride
            )

    def forward(self, x):
        out = self.gcn1(x)
        out = self.tcn1(out)
        return self.relu(out + self.residual(x))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class Classifier(nn.Module):
    """Cosine-similarity classification head with multi-temperature
    attention pooling (multiple-instance-learning branch).

    Frame embeddings are compared against learned per-class centers; the
    frame scores are pooled over time with class-wise softmax attention at
    several temperatures and averaged to produce video-level predictions.
    """

    def __init__(self, num_class=60, scale_factor=5.0, temperature=(1.0, 2.0, 5.0)):
        """
        Args:
            num_class: number of action classes; one extra center is kept
                (num_class + 1), presumably for a background slot —
                TODO confirm against the training labels.
            scale_factor: multiplier applied to the cosine scores.
            temperature: iterable of softmax temperatures for attention.
                (Changed from a mutable list default to a tuple; it is only
                iterated, so behavior is unchanged.)
        """
        super(Classifier, self).__init__()

        # Action class centers in a 256-d embedding space.
        self.ac_center = nn.Parameter(torch.zeros(num_class + 1, 256))
        nn.init.xavier_uniform_(self.ac_center)

        self.temperature = temperature
        self.scale_factor = scale_factor

    def forward(self, x):
        """
        Args:
            x: (n*m, c, t, v) feature map from the backbone. With n = x.size(0)
               passed to einops' reduce, m is forced to 1 (single person).
        Returns:
            mil_vid_pred: (n, num_class + 1) sigmoid video-level predictions.
            frm_scrs: (n, t, num_class + 1) scaled per-frame cosine scores.
        """
        N = x.size(0)

        # Mean over person and joint axes -> per-frame embeddings (n, t, c).
        x_emb = reduce(x, "(n m) c t v -> n t c", "mean", n=N)

        norms_emb = F.normalize(x_emb, dim=2)
        norms_ac = F.normalize(self.ac_center)

        # Generate foreground and action scores: cosine similarity of each
        # frame embedding to each class center, scaled.
        frm_scrs = (
            torch.einsum("ntd,cd->ntc", [norms_emb, norms_ac]) * self.scale_factor
        )

        # Class-wise temporal attention at each temperature (softmax over t).
        class_wise_atts = [F.softmax(frm_scrs * t, 1) for t in self.temperature]

        # Multiple instance learning branch: attention-weighted temporal
        # score aggregation, averaged across temperatures.
        mid_vid_scrs = [
            torch.einsum("ntc,ntc->nc", [frm_scrs, att]) for att in class_wise_atts
        ]
        mil_vid_scr = (
            torch.stack(mid_vid_scrs, -1).mean(-1) * 2.0
        )  # frm_scrs have been multiplied by the scale factor
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        mil_vid_pred = torch.sigmoid(mil_vid_scr)

        return mil_vid_pred, frm_scrs
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Model(nn.Module):
    """2s-AGCN backbone (10 TCN-GCN blocks) with two MIL classifier heads.

    The second head receives detached features, so its gradients do not
    flow back into the backbone (co-training / co-teaching setup).
    """

    def __init__(
        self,
        num_class=60,
        num_point=25,
        num_person=1,
        graph=None,
        graph_args=dict(),
        in_channels=2,
        scale_factor=5.0,
        temperature=[1.0, 2.0, 5.0],
    ):
        """
        Args:
            num_class: number of action classes for the classifier heads.
            num_point: number of skeleton joints (V).
            num_person: number of bodies per sample (M).
            graph: dotted import path of the graph class (required).
            graph_args: kwargs forwarded to the graph class.
            in_channels: channels used to size data_bn.
                NOTE(review): l1 below hardcodes 3 input channels, which
                disagrees with this default of 2 — confirm callers pass
                3-channel data (data_bn is unused in forward).
            scale_factor, temperature: forwarded to Classifier.
        """
        super(Model, self).__init__()

        if graph is None:
            raise ValueError()
        else:
            # Resolve and instantiate the skeleton graph from its dotted path.
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A
        # Defined (and initialised) but not applied in forward — see the
        # commented-out call there.
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        # Backbone: time is halved at l2, l5 and l8.
        self.l1 = TCN_GCN_unit(3, 64, A, residual=False)  # save (B,64,25,T)
        self.l2 = TCN_GCN_unit(64, 64, A, stride=2)
        self.l3 = TCN_GCN_unit(64, 64, A)
        self.l4 = TCN_GCN_unit(64, 64, A)  # save (B,64,25,T/2)
        self.l5 = TCN_GCN_unit(64, 128, A, stride=2)
        self.l6 = TCN_GCN_unit(128, 128, A)
        self.l7 = TCN_GCN_unit(128, 128, A)  # save (B,128,25,T/4)
        self.l8 = TCN_GCN_unit(128, 256, A, stride=2)
        self.l9 = TCN_GCN_unit(256, 256, A)
        self.l10 = TCN_GCN_unit(256, 256, A)  # save (B,256,25,T/8)

        bn_init(self.data_bn, 1)

        self.classifier_1 = Classifier(num_class, scale_factor, temperature)

        self.classifier_2 = Classifier(num_class, scale_factor, temperature)

    def forward(self, x,mask):
        """
        Args:
            x: (N, C, T, V, M) skeleton sequence batch.
            mask: accepted for interface compatibility but not used in this
                implementation.
        Returns:
            (mil_vid_pred_1, frm_scrs_1, mil_vid_pred_2, frm_scrs_2):
            video-level predictions and per-frame scores from both heads;
            frame scores are upsampled back to the input length T.
        """
        N, C, T, V, M = x.size()

        # Flatten person/joint/channel dims to feed the (unused) data_bn ...
        x = rearrange(x, "n c t v m -> n (m v c) t")
        # x = self.data_bn(x)
        # ... then reshape to (n*m, c, t, v) for the graph backbone.
        x = rearrange(x, "n (m v c) t -> (n m) c t v", m=M, v=V, c=C)

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        mil_vid_pred_1, frm_scrs_1 = self.classifier_1(x)

        # detach(): head 2 does not backpropagate into the backbone.
        mil_vid_pred_2, frm_scrs_2 = self.classifier_2(x.detach())

        # print (frm_scrs_1.size(), T)

        # Upsample frame scores (downsampled 8x by the backbone) back to
        # the original temporal resolution T.
        frm_scrs_1 = rearrange(frm_scrs_1, "n t c -> n c t")
        frm_scrs_1 = F.interpolate(
            frm_scrs_1, size=(T), mode="linear", align_corners=True
        )
        frm_scrs_1 = rearrange(frm_scrs_1, "n c t -> n t c")

        frm_scrs_2 = rearrange(frm_scrs_2, "n t c -> n c t")
        frm_scrs_2 = F.interpolate(
            frm_scrs_2, size=(T), mode="linear", align_corners=True
        )
        frm_scrs_2 = rearrange(frm_scrs_2, "n c t -> n t c")

        return mil_vid_pred_1, frm_scrs_1, mil_vid_pred_2, frm_scrs_2
|
model/losses.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from platform import mac_ver
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from einops import rearrange, reduce, repeat
|
| 21 |
+
from torch.autograd import Variable
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def kl_loss_compute(pred, soft_targets, reduce=True):
    """KL divergence KL(softmax(soft_targets) || softmax(pred)) from logits.

    Args:
        pred: (N, C) logits of the "student" distribution.
        soft_targets: (N, C) logits of the target distribution.
        reduce: if True, return the scalar mean over the batch; otherwise
            return a (N,) tensor of per-sample divergences.
    """
    kl = F.kl_div(
        F.log_softmax(pred, dim=1),
        F.softmax(soft_targets, dim=1),
        # Fix: `reduce=False` is a deprecated F.kl_div kwarg; the supported
        # spelling with identical element-wise behavior is reduction="none".
        reduction="none",
    )
    if reduce:
        return torch.mean(torch.sum(kl, dim=1))
    else:
        return torch.sum(kl, 1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def mvl_loss(y_1, y_2, rate=0.2, weight=0.1):
    """Small-loss selection on the symmetric KL between two heads.

    Flattens both heads' per-frame scores, computes a weighted symmetric KL
    per frame, and averages only over the `rate` fraction of frames with the
    smallest divergence (co-teaching style agreement loss).
    """
    flat_1 = rearrange(y_1, "n t c -> (n t) c")
    flat_2 = rearrange(y_2, "n t c -> (n t) c")

    per_frame = weight * kl_loss_compute(
        flat_1, flat_2, reduce=False
    ) + weight * kl_loss_compute(flat_2, flat_1, reduce=False)

    # Selection happens outside the graph, on CPU.
    per_frame = per_frame.cpu().detach()

    order = torch.argsort(per_frame.data)
    sorted_losses = per_frame[order]

    # Keep only the lowest-divergence fraction of frames.
    n_keep = int(rate * len(sorted_losses))
    selected = order[:n_keep]

    return torch.mean(per_frame[selected])
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def cross_entropy_loss(outputs, soft_targets):
    """Soft-target cross entropy, skipping fully-ignored rows.

    Rows of `soft_targets` where every entry equals the ignore marker -100
    are dropped before computing -mean(sum(log_softmax(outputs) * targets)).
    """
    valid = (soft_targets != -100).sum(1) > 0
    kept_outputs = outputs[valid]
    kept_targets = soft_targets[valid]
    log_probs = F.log_softmax(kept_outputs, dim=1)
    return -torch.mean(torch.sum(log_probs * kept_targets, dim=1))
|
| 63 |
+
|
prepare/configs/action_label_split1.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"walk": 0,
|
| 3 |
+
"stand": 1,
|
| 4 |
+
"turn": 2,
|
| 5 |
+
"jump": 3
|
| 6 |
+
}
|
prepare/configs/action_label_split2.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"sit": 0,
|
| 3 |
+
"run": 1,
|
| 4 |
+
"stand up": 2,
|
| 5 |
+
"kick": 3
|
| 6 |
+
}
|
prepare/configs/action_label_split3.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"jog": 0,
|
| 3 |
+
"wave": 1,
|
| 4 |
+
"dance": 2,
|
| 5 |
+
"gesture": 3
|
| 6 |
+
}
|
prepare/create_dataset.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/abhinanda-punnakkal/BABEL/ to frame-wise motion segmentation
|
| 2 |
+
#! /usr/bin/env python
|
| 3 |
+
# -*- coding: utf-8 -*-
|
| 4 |
+
# vim:fenc=utf-8
|
| 5 |
+
#
|
| 6 |
+
# Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
|
| 7 |
+
#
|
| 8 |
+
# Distributed under terms of the MIT license.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import csv
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import pdb
|
| 15 |
+
import pickle
|
| 16 |
+
import sys
|
| 17 |
+
from collections import *
|
| 18 |
+
from itertools import *
|
| 19 |
+
from os.path import basename as ospb
|
| 20 |
+
from os.path import dirname as ospd
|
| 21 |
+
from os.path import join as ospj
|
| 22 |
+
|
| 23 |
+
import dutils
|
| 24 |
+
import ipdb
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pandas as pd
|
| 27 |
+
|
| 28 |
+
# Custom
|
| 29 |
+
import preprocess
|
| 30 |
+
import torch
|
| 31 |
+
import viz
|
| 32 |
+
from pandas.core.common import flatten
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
Script to load BABEL segments with NTU skeleton format and pre-process.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def ntu_style_preprocessing(b_dset_path):
    """Normalize BABEL samples NTU-RGBD-style and write them back to disk.

    Loads the pickled dataset at `b_dset_path`, normalizes each sample with
    preprocess.pre_normalization, drops samples with missing skeletons, and
    pickles the result next to the input (path derived by string replace).

    NOTE(review): this function returns None implicitly, although the
    module-level caller binds its result — verify that is intended.
    """
    print("Load BABEL v1.0 dataset subset", b_dset_path)
    b_dset = dutils.read_pkl(b_dset_path)

    X_new = []
    Y_new = []
    id_new = []
    for idx in range(len(b_dset["X"])):
        # Get unnormalized 5-sec. samples
        X = np.array(b_dset["X"][idx])
        print("X (old) = ", np.shape(X))  # T, V, C

        # Add a leading batch axis expected by pre_normalization.
        X = X[np.newaxis, :, :, :]

        # Prep. data for normalization
        X = X.transpose(0, 3, 1, 2)  # N, C, T, V
        X = X[:, :, :, :, np.newaxis]  # N, C, T, V, M
        print("Shape of prepped X: ", X.shape)

        # Normalize (pre-process) in NTU RGBD-style: align the spine bone to
        # the z-axis and the shoulder bone to the x-axis (joint index pairs).
        ntu_sk_spine_bone = np.array([0, 1])
        ntu_sk_shoulder_bone = np.array([8, 4])
        X, l_m_sk = preprocess.pre_normalization(
            X, zaxis=ntu_sk_spine_bone, xaxis=ntu_sk_shoulder_bone
        )

        # l_m_sk lists samples with missing skeletons; keep only clean ones.
        if len(l_m_sk) == 0:
            X_new.append(X[0])
            Y_new.append(b_dset["Y"][idx])
            id_new.append(b_dset["sid"][idx])
        else:
            print("Skipped")

    # Dataset w/ processed seg. chunks. (Skip samples w/ missing skeletons)
    b_AR_dset = {"sid": id_new, "X": X_new, "Y": Y_new}

    # Output path is derived from the input path by convention.
    fp = b_dset_path.replace("samples", "ntu_sk_ntu-style_preprocessed")
    # fp = '../data/babel_v1.0/babel_v1.0_ntu_sk_ntu-style_preprocessed.pkl'
    # dutils.write_pkl(b_AR_dset, fp)
    # protocol=4 supports large (>4GB) pickles.
    with open(fp, "wb") as of:
        pickle.dump(b_AR_dset, of, protocol=4)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_act_idx(y, act2idx, n_classes):
    """Map an action name to its index; unknown actions fall back to
    n_classes (the out-of-vocabulary bucket).
    """
    return act2idx.get(y, n_classes)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def store_splits_subsets(
    n_classes, spl, plus_extra=True, w_folder="../data_created/babel_v1.0/"
):
    """Filter one data split to the top `n_classes` actions and save
    features (.npy) and labels (.pkl) for the dataloader.

    Args:
        n_classes: keep only actions whose index is < n_classes.
        spl: split name ('train'/'val'/...), used in file names and to pick
            sample indices.
        plus_extra: selects the '..._extra_...' variants of input/output files.
        w_folder: directory holding the preprocessed datasets and outputs.
    """
    # Get splits
    splits = dutils.read_json("../data_created/amass_splits.json")
    # sequence id -> split name. NOTE: the comprehension's `spl` is local to
    # the comprehension (Python 3) and does not clobber the parameter.
    sid2split = {
        int(ospb(u).replace(".mp4", "")): spl for spl in splits for u in splits[spl]
    }

    # In labels, act. cat. --> idx; restrict to the first n_classes actions.
    act2idx_150 = dutils.read_json("../data_created/action_label_2_idx.json")
    act2idx = {k: act2idx_150[k] for k in act2idx_150 if act2idx_150[k] < n_classes}
    print("{0} actions in label set: {1}".format(len(act2idx), act2idx))

    if plus_extra:
        fp = w_folder + "babel_v1.0_" + spl + "_extra_ntu_sk_ntu-style_preprocessed.pkl"
    else:
        fp = w_folder + "babel_v1.0_" + spl + "_ntu_sk_ntu-style_preprocessed.pkl"

    # Get full dataset
    b_AR_dset = dutils.read_pkl(fp)

    # Store idxs of samples to include in learning
    split_idxs = defaultdict(list)
    for i, y1 in enumerate(b_AR_dset["Y1"]):

        # Check if action category in list of classes
        if y1 not in act2idx:
            continue

        sid = b_AR_dset["sid"][i]
        split_idxs[sid2split[sid]].append(i)  # Include idx in dataset

    # Save features that'll be loaded by dataloader
    ar_idxs = np.array(split_idxs[spl])
    X = b_AR_dset["X"][ar_idxs]
    if plus_extra:
        fn = w_folder + f"{spl}_extra_ntu_sk_{n_classes}.npy"
    else:
        fn = w_folder + f"{spl}_ntu_sk_{n_classes}.npy"
    np.save(fn, X)

    # labels: every non-feature key, restricted to the selected indices.
    labels = {k: np.array(b_AR_dset[k])[ar_idxs] for k in b_AR_dset if k != "X"}

    # Create, save label data structure that'll be loaded by dataloader
    label_idxs = defaultdict(list)
    for i, y1 in enumerate(labels["Y1"]):
        # y1: the single main action, mapped to its class index.
        label_idxs["Y1"].append(act2idx[y1])
        # yk: all whole-segment actions; unknowns map to n_classes.
        yk = [get_act_idx(y, act2idx, n_classes) for y in labels["Yk"][i]]
        label_idxs["Yk"].append(yk)
        # yov: per-action coverage fraction, keys remapped to class indices.
        yov_o = labels["Yov"][i]
        yov = {get_act_idx(y, act2idx, n_classes): yov_o[y] for y in yov_o}
        label_idxs["Yov"].append(yov)
        # Bookkeeping metadata carried through unchanged.
        label_idxs["seg_id"].append(labels["seg_id"][i])
        label_idxs["sid"].append(labels["sid"][i])
        label_idxs["chunk_n"].append(labels["chunk_n"][i])
        label_idxs["anntr_id"].append(labels["anntr_id"][i])

    if plus_extra:
        wr_f = w_folder + f"{spl}_extra_label_{n_classes}.pkl"
    else:
        wr_f = w_folder + f"{spl}_label_{n_classes}.pkl"
    # Saved as (seg_ids, (Y1, sid, chunk_n, anntr_id)) — the nested-tuple
    # layout the downstream feeder expects.
    dutils.write_pkl(
        (
            label_idxs["seg_id"],
            (
                label_idxs["Y1"],
                label_idxs["sid"],
                label_idxs["chunk_n"],
                label_idxs["anntr_id"],
            ),
        ),
        wr_f,
    )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Babel_AR:
    """Object containing data, methods for Action Recognition.

    Task
    -----
    Given: x (Segment from Babel)
    Predict: \hat{p}(x) (Distribution over action categories)

    GT
    ---
    How to compute GT for a given segment?
    - yk: All action categories that are labeled for the entirety of segment
    - y1: One of yk
    - yov: Any y that belongs to part of a segment is considered to be GT.
    Fraction of segment covered by an action: {'walk': 1.0, 'wave': 0.5}

    """

    def __init__(self, dataset, dense=True):
        """Dataset with (samples, different GTs).

        Args:
            dataset: iterable of BABEL annotation dicts (one per sequence).
            dense: only dense (frame-level) annotations are supported.
        """
        # Load dataset
        self.babel = dataset
        self.dense = dense
        # Root folder for AMASS joint-position features.
        self.jpos_p = "dataset/amass"

        # Get frame-rate for each seq. in AMASS
        f_p = "dataset/BABEL/action_recognition/data/featp_2_fps.json"
        self.ft_p_2_fps = dutils.read_json(f_p)

        # Dataset w/ keys = {'X', 'Y1', 'Yk', 'Yov', 'seg_id', 'sid',
        # 'seg_dur'}
        self.d = defaultdict(list)
        for ann in tqdm(self.babel):
            self._update_dataset(ann)

    def _subsample_to_30fps(self, orig_ft, orig_fps):
        """Get features at 30fps frame-rate
        Args:
            orig_ft <array> (T, 25*3): Feats. @ `orig_fps` frame-rate
            orig_fps <float>: Frame-rate in original (ft) seq.
        Return:
            ft <array> (T', 25*3): Feats. @ 30fps
        """
        T, n_j, _ = orig_ft.shape
        out_fps = 30.0
        # Matching the sub-sampling used for rendering
        if int(orig_fps) % int(out_fps):
            # Non-integer ratio: pick 30 evenly spread frames per second.
            sel_fr = np.floor(orig_fps / out_fps * np.arange(int(out_fps))).astype(int)
            n_duration = int(T / int(orig_fps))
            t_idxs = []
            for i in range(n_duration):
                t_idxs += list(i * int(orig_fps) + sel_fr)
            # Handle the trailing partial second, keeping indices in range.
            if int(T % int(orig_fps)):
                last_sec_frame_idx = n_duration * int(orig_fps)
                t_idxs += [
                    x + last_sec_frame_idx for x in sel_fr if x + last_sec_frame_idx < T
                ]
        else:
            # Integer ratio: plain strided subsampling.
            t_idxs = np.arange(0, T, orig_fps / out_fps, dtype=int)

        ft = orig_ft[t_idxs, :, :]
        return ft

    def _viz_x(self, ft, fn="test_sample"):
        """Wrapper to visualize the given sample (w/ NTU RGBD skeleton)."""
        viz.viz_seq(seq=ft, folder_p=f"test_viz/{fn}", sk_type="nturgbd", debug=True)
        return None

    def _load_seq_feats(self, ft_p, sk_type):
        """Given path to joint position features, return them in 30fps"""
        # Identify appropriate feature directory path on disk
        if "smpl_wo_hands" == sk_type:  # SMPL w/o hands (T, 22*3)
            jpos_p = ospj(self.jpos_p, "joint_pos")
        if "nturgbd" == sk_type:  # NTU (T, 219)
            jpos_p = ospj(self.jpos_p, "babel_joint_pos")

        # Get the correct dataset folder name
        ddir_n = ospb(ospd(ospd(ft_p)))
        # A couple of AMASS sub-datasets were renamed on disk.
        ddir_map = {"BioMotionLab_NTroje": "BMLrub", "DFaust_67": "DFaust"}
        ddir_n = ddir_map[ddir_n] if ddir_n in ddir_map else ddir_n
        # Get the subject folder name
        sub_fol_n = ospb(ospd(ft_p))

        # Sanity check
        fft_p = ospj(jpos_p, ddir_n, sub_fol_n, ospb(ft_p))
        assert os.path.exists(fft_p)

        # Load seq. fts.
        ft = np.load(fft_p)["joint_pos"]
        T, ft_sz = ft.shape

        # Get NTU skeleton joints: select the SMPL-H joints that map onto
        # the 25-joint NTU skeleton.
        ntu_js = dutils.smpl_to_nturgbd(model_type="smplh", out_format="nturgbd")
        ft = ft.reshape(T, -1, 3)
        ft = ft[:, ntu_js, :]

        # Sub-sample to 30fps
        orig_fps = self.ft_p_2_fps[ft_p]
        ft = self._subsample_to_30fps(ft, orig_fps)
        # print(f'Feat. shape = {ft.shape}, fps = {orig_fps}')
        # if orig_fps != 30.0:
        #     self._viz_x(ft)
        return ft

    def _get_per_f_labels(self, ann, ann_type, seq_dur):
        """Per-frame labels: {0: ['walk'], 1: ['walk', 'wave'], ... T: ['stand']}.

        Frames with no overlapping segment get no entry (defaultdict),
        so len(yf) counts labeled frames only.
        """
        yf = defaultdict(list)
        T = int(30.0 * seq_dur)  # label timeline at 30fps
        for n_f in range(T):
            cur_t = float(n_f / 30.0)
            for seg in ann["labels"]:

                if seg["act_cat"] is None:
                    continue

                # Sequence-level annotations span the whole sequence.
                if "seq_ann" == ann_type:
                    seg["start_t"] = 0.0
                    seg["end_t"] = seq_dur

                # Frame falls inside [start_t, end_t) of this segment.
                if cur_t >= float(seg["start_t"]) and cur_t < float(seg["end_t"]):
                    yf[n_f] += seg["act_cat"]
        return yf

    def _compute_dur_samples(self, id, ann, ann_type, seq_ft, seq_dur, dur=5.0):
        """Return each motion and its frame-wise GT action

        Return:
            [ { 'seg_id': motion id,
                'x': motion feats,
                'yall': labels of each motion,
                { ... }, ...
            ]
        """
        yf = self._get_per_f_labels(ann, ann_type, seq_dur)

        # Truncate features to the label timeline length.
        # NOTE(review): len(yf) counts only frames that received a label
        # (defaultdict) — confirm that unlabeled frames never occur here.
        seq_ft = seq_ft[:len(yf)]
        assert seq_ft.shape[0] == len(yf)

        seq_samples = []
        seq_samples.append(
            {"seg_id": id, "x": seq_ft, "y": yf,}
        )

        return seq_samples

    def _sample_at_seg_chunk_level(self, ann, seq_samples):
        # Samples at segment-chunk-level: append each sample's fields to the
        # accumulated dataset dict.
        for i, sample in enumerate(seq_samples):

            self.d["sid"].append(ann["babel_sid"])  # Seq. info
            self.d["X"].append(sample["x"])  # motion feats.
            self.d["Y"].append(sample["y"])  # labels of each motion.
        return

    def _update_dataset(self, ann):
        """Add one annotated sequence to the dataset (X, Y, sid)."""

        # Get feats. for seq.
        seq_ft = self._load_seq_feats(ann["feat_p"], "nturgbd")

        # To keep track of type of annotation for loading 'extra'
        # Compute all GT labels for this seq.
        seq_samples = None
        if self.dense:
            # Only frame-level ("dense") annotations are handled.
            if ann["frame_ann"] is not None:
                ann_ar = ann["frame_ann"]
                seq_samples = self._compute_dur_samples(
                    ann["babel_sid"], ann_ar, "frame_ann", seq_ft, ann["dur"]
                )
                self._sample_at_seg_chunk_level(ann, seq_samples)
            else:
                print("not supported data")

        else:
            raise NotImplementedError

        return
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Create dataset
# --------------------------
# Module-level driver: builds the dense BABEL dataset for each split, pickles
# the raw samples, then runs NTU-style preprocessing on the pickle.
d_folder = "dataset/babel_v1.0_release/"
w_folder = "dataset/babel_v1.0_sequence/"
os.makedirs(w_folder, exist_ok=True)
for spl in ["train", "val"]:
    # Load Dense BABEL
    data = dutils.read_json(ospj(d_folder, f"{spl}.json"))
    dataset = [data[sid] for sid in data]
    dense_babel = Babel_AR(dataset, dense=True)
    # Store Dense BABEL
    d_filename = w_folder + "babel_v1.0_" + spl + "_samples.pkl"
    dutils.write_pkl(dense_babel.d, d_filename)

    # Pre-process, Store data in dataset
    print("NTU-style preprocessing")
    # NOTE(review): ntu_style_preprocessing writes its output to disk and
    # returns None, so this binding is always None.
    babel_dataset_AR = ntu_style_preprocessing(d_filename)
|
prepare/dutils.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# vim:fenc=utf-8
|
| 4 |
+
#
|
| 5 |
+
# Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
|
| 6 |
+
#
|
| 7 |
+
# Distributed under terms of the MIT license.
|
| 8 |
+
|
| 9 |
+
import csv
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import os.path as osp
|
| 13 |
+
import pdb
|
| 14 |
+
import pickle
|
| 15 |
+
import sys
|
| 16 |
+
from collections import Counter
|
| 17 |
+
from os.path import basename as ospb
|
| 18 |
+
from os.path import dirname as ospd
|
| 19 |
+
from os.path import join as ospj
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import viz
|
| 24 |
+
from smplx import SMPLH
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def read_json(json_filename):
    """Return the parsed contents of a JSON file."""
    with open(json_filename) as fh:
        return json.load(fh)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def read_pkl(pkl_filename):
    """Return the unpickled contents of a pickle file."""
    with open(pkl_filename, "rb") as fh:
        return pickle.load(fh)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def write_json(contents, filename):
    """Serialize `contents` to `filename` as indented JSON text."""
    with open(filename, "w") as fh:
        fh.write(json.dumps(contents, indent=2))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def write_pkl(contents, filename):
    """Pickle `contents` into `filename` (binary mode)."""
    with open(filename, "wb") as fh:
        pickle.dump(contents, fh)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def smpl_to_nturgbd(model_type="smplh", out_format="nturgbd"):
    """Index map from SMPL-H joints to the 25-joint NTU-RGBD skeleton.

    (Borrowed from
    https://gitlab.tuebingen.mpg.de/apunnakkal/2s_agcn/-/blob/master/data_gen/smpl_data_utils.py)

    NTU joint order produced:
    1-base of the spine, 2-middle of the spine, 3-neck, 4-head,
    5-left shoulder, 6-left elbow, 7-left wrist, 8-left hand,
    9-right shoulder, 10-right elbow, 11-right wrist, 12-right hand,
    13-left hip, 14-left knee, 15-left ankle, 16-left foot,
    17-right hip, 18-right knee, 19-right ankle, 20-right foot,
    21-spine, 22-tip of the left hand, 23-left thumb,
    24-tip of the right hand, 25-right thumb.

    :param model_type: source skeleton type ('smplh' supported).
    :param out_format: target skeleton type ('nturgbd' supported).
    :return: int32 array of 25 SMPL-H joint indices, or None for
        unsupported combinations.
    """
    if model_type != "smplh" or out_format != "nturgbd":
        return None
    # 22 and 37 approximate the hands (base of the index finger).
    mapping = [
        0, 3, 12, 15,              # spine base/mid, neck, head
        16, 18, 20, 22,            # left arm + left hand
        17, 19, 21, 37,            # right arm + right hand
        1, 4, 7, 10,               # left leg
        2, 5, 8, 11,               # right leg
        9,                         # spine
        63, 64, 68, 69,            # hand tips / thumbs
    ]
    return np.array(mapping, dtype=np.int32)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class dotdict(dict):
    """Dictionary subclass allowing dot-notation access to its keys.

    Missing attributes resolve to None (dict.get semantics) instead of
    raising AttributeError; deleting a missing attribute raises KeyError.
    """

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; mirrors dict.get.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def store_counts(label_fp):
    """Compute # samples per class from stored labels and save the tally.

    Args:
        label_fp <str>: Path to the pickled label file.

    Writes (next to the label file):
        <stem>_count.pkl <Counter>: class idx -> sample count.
    """
    label_tup = read_pkl(label_fp)
    # Class indices are stored at position [1][0] of the pickled tuple.
    idxs = label_tup[1][0]
    print("# Samples in set = ", len(idxs))

    counts = Counter(idxs)
    print("File ", label_fp, "len", len(counts))

    write_pkl(counts, label_fp.replace(".pkl", "_count.pkl"))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_babel_dataset(d_folder="dataset/babel_v1.0_release"):
    """Load all BABEL annotation files from `d_folder`.

    Args:
        d_folder <str>: Folder containing the BABEL v1.0 release JSONs.

    Returns:
        babel <dict>: split name (e.g. 'train', 'extra_val') -> parsed JSON.
    """
    # Data folder
    l_babel_dense_files = ["train", "val", "test"]
    l_babel_extra_files = ["extra_train", "extra_val"]

    # BABEL Dataset
    babel = {}
    for fn in l_babel_dense_files + l_babel_extra_files:
        # Use a context manager so the file handle is closed promptly;
        # the original json.load(open(...)) leaked one descriptor per file.
        with open(ospj(d_folder, fn + ".json")) as infile:
            babel[fn] = json.load(infile)

    return babel
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def store_seq_fps(amass_p):
    """Get fps for each seq. in BABEL.

    Arguments:
    ---------
    amass_p <str>: Path where you download AMASS to.

    Save:
    -----
    featp_2_fps.json <dict>: Key: feat path <str>, value: orig. fps
    in AMASS <float>. E.g.,: {'KIT/KIT/4/RightTurn01_poses.npz': 100.0, ...}
    """
    # Get BABEL dataset
    babel = load_babel_dataset()

    # Loop over each BABEL seq, store frame-rate
    ft_p_2_fps = {}
    for fn in babel:
        for sid in tqdm(babel[fn]):
            ann = babel[fn][sid]
            # print (ann)
            if ann["feat_p"] not in ft_p_2_fps:
                ddir_n = ann["feat_p"]
                ddir_n_ = ddir_n.split("/")
                # Drop the leading path component of BABEL's feat_p before
                # joining with the local AMASS root.
                ddir_n_ = ddir_n_[1:]

                # BABEL uses short names for two datasets that AMASS ships
                # under longer folder names.
                ddir_map = {"BMLrub": "BioMotionLab_NTroje", "DFaust": "DFaust_67"}
                ddir_n_[0] = (
                    ddir_map[ddir_n_[0]] if ddir_n_[0] in ddir_map else ddir_n_[0]
                )

                p = ospj(amass_p, "/".join(ddir_n_))

                # Original capture frame-rate stored in the AMASS npz.
                fps = np.load(p)["mocap_framerate"]
                ft_p_2_fps[ann["feat_p"]] = float(fps)
    # NOTE(review): output path is hard-coded relative to the CWD.
    dest_fp = "dataset/featp_2_fps.json"
    write_json(ft_p_2_fps, dest_fp)
    return None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def store_ntu_jpos(smplh_model_p, dest_jpos_p, amass_p):
    """Store SMPL-H joint positions for later NTU-RGBD skeleton extraction.

    Args:
        smplh_model_p <str>: Path to the male SMPL-H model pickle.
        dest_jpos_p <str>: Root folder for per-sequence npz outputs.
        amass_p <str>: Root folder of the downloaded AMASS datasets.

    For every BABEL sequence not yet present under dest_jpos_p, forward-passes
    the AMASS pose parameters through SMPL-H one frame at a time and saves
    the flattened joint positions under key 'joint_pos'.
    """
    # Model to forward-pass through, to store joint positions
    smplh = SMPLH(
        smplh_model_p,
        create_transl=False,
        ext="pkl",
        gender="male",
        use_pca=False,
        batch_size=1,
    )

    # Load paths to all BABEL features
    featp_2_fps = read_json("dataset/featp_2_fps.json")

    # Loop over all BABEL data, verify that joint positions are stored on disk
    l_m_ft_p = []
    for ft_p in featp_2_fps:

        # Get the correct dataset folder name
        ddir_n = ospb(ospd(ospd(ft_p)))
        # AMASS long folder names -> BABEL short names.
        ddir_map = {"BioMotionLab_NTroje": "BMLrub", "DFaust_67": "DFaust"}
        ddir_n = ddir_map[ddir_n] if ddir_n in ddir_map else ddir_n
        # Get the subject folder name
        sub_fol_n = ospb(ospd(ft_p))

        # Sanity check: collect (source, destination) pairs still missing.
        fft_p = ospj(dest_jpos_p, ddir_n, sub_fol_n, ospb(ft_p))
        if not os.path.exists(fft_p):
            l_m_ft_p.append((ft_p, fft_p))
    print("Total # missing NTU RGBD skeleton features = ", len(l_m_ft_p))

    # Loop over missing joint positions and store them on disk
    for i, (ft_p, ntu_jpos_p) in enumerate(tqdm(l_m_ft_p)):
        # Strip the leading dataset-root component before joining to amass_p.
        ft_p_ = ft_p.split("/")
        ft_p_ = ft_p_[1:]

        ft_p = ospj(amass_p, "/".join(ft_p_))

        jrot_smplh = np.load(ft_p)["poses"]
        # Break joints down into body parts (axis-angle params:
        # [0:3] root orient, [3:66] body, [66:111] left hand, [111:] right hand)
        smpl_body_jrot = jrot_smplh[:, 3:66]
        left_hand_jrot = jrot_smplh[:, 66:111]
        right_hand_jrot = jrot_smplh[:, 111:]
        root_orient = jrot_smplh[:, 0:3].reshape(-1, 3)

        # Forward through model to get a superset of required joints
        T = jrot_smplh.shape[0]
        # 219 = 73 joints x 3 coords — presumably what this SMPLH config
        # returns; TODO confirm against the smplx output.
        ntu_jpos = np.zeros((T, 219))
        for t in range(T):
            res = smplh(
                body_pose=torch.Tensor(smpl_body_jrot[t : t + 1, :]),
                global_orient=torch.Tensor(root_orient[t : t + 1, :]),
                left_hand_pose=torch.Tensor(left_hand_jrot[t : t + 1, :]),
                right_hand_pose=torch.Tensor(right_hand_jrot[t : t + 1, :]),
                # transl=torch.Tensor(transl)
            )
            jpos = res.joints.detach().cpu().numpy()[:, :, :].reshape(-1)
            ntu_jpos[t, :] = jpos

        # Save to disk
        if not os.path.exists(ospd(ntu_jpos_p)):
            os.makedirs(ospd(ntu_jpos_p))
        np.savez(ntu_jpos_p, joint_pos=ntu_jpos, allow_pickle=True)

    return
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def viz_ntu_jpos(jpos_p, l_ft_p):
    """Visualize sequences of NTU-skeleton joint positions.

    Args:
        jpos_p <str>: Root folder of stored joint-position npz files.
        l_ft_p <list of str>: Relative npz paths (one per sequence) to render.

    Renders each sequence to test_viz/test_ntu_w_axis via viz.viz_seq.
    """
    # Indices that are in the NTU RGBD skeleton
    smpl2nturgbd = smpl_to_nturgbd()
    # Iterate over each
    for ft_p in l_ft_p:
        x = np.load(ospj(jpos_p, ft_p))["joint_pos"]
        T, ft_sz = x.shape
        # Unflatten (T, J*3) back into per-joint xyz triples.
        x = x.reshape(T, ft_sz // 3, 3)
        # print('Data shape = {0}'.format(x.shape))
        # Keep only the 25 joints of the NTU skeleton.
        x = x[:, smpl2nturgbd, :]
        # print('Data shape = {0}'.format(x.shape))
        # x = x[:,:,:, 0].transpose(1, 2, 0) # (3, 150, 22, 1) --> (150, 22, 3)
        print("Data shape = {0}".format(x.shape))
        viz.viz_seq(
            seq=x, folder_p="test_viz/test_ntu_w_axis", sk_type="nturgbd", debug=True
        )
        print("-" * 50)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def main():
    """Run the preliminary data-preparation pipeline.

    1. Record each AMASS sequence's original frame-rate.
    2. Cache SMPL-H joint positions for NTU-RGBD skeleton extraction.
    3. Visualize one stored sequence as a sanity check.
    """
    amass_p = "dataset/amass/"

    # Save feature paths --> fps (released in babel/action_recognition/data/)
    store_seq_fps(amass_p)

    # Save joint positions in NTU-RGBD skeleton format
    smplh_model_p = "./human_model/SMPLH_male.pkl"
    # model is generated by https://github.com/vchoutas/smplx
    jpos_p = "./dataset/amass/babel_joint_pos"
    store_ntu_jpos(smplh_model_p, jpos_p, amass_p)

    # Viz. saved seqs.
    l_ft_p = ["ACCAD/Male2MartialArtsStances_c3d/D7 - walk to bow_poses.npz"]
    viz_ntu_jpos(jpos_p, l_ft_p)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
    # Run the data-preparation pipeline when invoked as a script.
    main()
|
prepare/generate_dataset.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# End-to-end dataset generation: cache fps + joint positions, build the
# frame-level dataset, then label split 1 (splits 2/3 are optional).
python prepare/dutils.py
python prepare/create_dataset.py
python prepare/split_dataset.py --split 1
# python prepare/split_dataset.py --split 2
# python prepare/split_dataset.py --split 3
|
prepare/preprocess.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
from rotation import *
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]):
    """Normalize skeleton sequences NTU-style: pad empty frames, center on
    the spine joint, then align the body to the z- and x-axes.

    Args:
        data: array of shape (N, C, T, V, M) — samples, coords, frames,
            joints, persons. Mutated in place through the transposed view `s`.
        zaxis: joint index pair (bottom, top) whose bone is rotated onto the
            z-axis (hip -> spine for NTU).
        xaxis: joint index pair (right, left shoulder) rotated onto the x-axis.

    Returns:
        (data, l_m_sk): the normalized array and the indices of samples that
        contained no skeleton at all.

    NOTE(review): the list defaults are never mutated, so the shared
    mutable-default pitfall does not bite here — but worth keeping in mind.
    """
    N, C, T, V, M = data.shape
    s = np.transpose(data, [0, 4, 2, 3, 1])  # N, C, T, V, M to N, M, T, V, C
    l_m_sk = []  # List idxs of missing skeletons

    print("pad the null frames with the previous frames")
    for i_s, skeleton in enumerate(tqdm(s)):  # pad
        if skeleton.sum() == 0:
            print(i_s, " has no skeleton")
            l_m_sk.append(i_s)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            if person[0].sum() == 0:
                # Shift all non-empty frames to the front of the sequence.
                index = person.sum(-1).sum(-1) != 0
                tmp = person[index].copy()
                person *= 0
                person[: len(tmp)] = tmp
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    if person[i_f:].sum() == 0:
                        # All remaining frames are empty: tile the leading
                        # non-empty prefix to fill the tail, then stop.
                        rest = len(person) - i_f
                        num = int(np.ceil(rest / i_f))
                        pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[
                            :rest
                        ]
                        s[i_s, i_p, i_f:] = pad
                        break

    print("sub the center joint #1 (spine joint in ntu and neck joint in kinetics)")
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        # Center of the FIRST person is subtracted from every person.
        main_body_center = skeleton[0][:, 1:2, :].copy()
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            # Mask keeps genuinely-empty joints at zero after centering.
            mask = (person.sum(-1) != 0).reshape(T, V, 1)
            s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask

    print(
        "parallel the bone between hip(jpt 0) and spine(jpt 1) of the first person to the z axis"
    )
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        # Rotation computed from the first person's FIRST frame, applied to all.
        joint_bottom = skeleton[0, 0, zaxis[0]]
        joint_top = skeleton[0, 0, zaxis[1]]
        axis = np.cross(joint_top - joint_bottom, [0, 0, 1])
        angle = angle_between(joint_top - joint_bottom, [0, 0, 1])
        matrix_z = rotation_matrix(axis, angle)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    continue
                for i_j, joint in enumerate(frame):
                    s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint)

    print(
        "parallel the bone between right shoulder(jpt 8) and left shoulder(jpt 4) of the first person to the x axis"
    )
    for i_s, skeleton in enumerate(tqdm(s)):
        if skeleton.sum() == 0:
            continue
        joint_rshoulder = skeleton[0, 0, xaxis[0]]
        joint_lshoulder = skeleton[0, 0, xaxis[1]]
        axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0])
        angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0])
        matrix_x = rotation_matrix(axis, angle)
        for i_p, person in enumerate(skeleton):
            if person.sum() == 0:
                continue
            for i_f, frame in enumerate(person):
                if frame.sum() == 0:
                    continue
                for i_j, joint in enumerate(frame):
                    s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint)

    # `s` is a view of `data`, so `data` already holds the normalized values;
    # this transpose just restores the (N, C, T, V, M) layout.
    data = np.transpose(s, [0, 4, 2, 3, 1])
    return data, l_m_sk
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # Smoke test on a pre-generated NTU cross-view validation split.
    data = np.load("../data/ntu/xview/val_data.npy")
    # NOTE(review): the return value is discarded; this still works because
    # pre_normalization writes through a transposed *view* of `data`.
    pre_normalization(data)
    np.save("../data/ntu/xview/data_val_pre.npy", data)
|
prepare/rotation.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# vim:fenc=utf-8
|
| 4 |
+
#
|
| 5 |
+
# Copyright © 2021 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
|
| 6 |
+
#
|
| 7 |
+
# Distributed under terms of the MIT license.
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def rotation_matrix(axis, theta):
    """Return the 3x3 matrix for a counterclockwise rotation by `theta`
    radians about `axis` (quaternion / Rodrigues formulation).

    A degenerate (near-zero) axis or angle yields the identity matrix.
    """
    if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6:
        return np.eye(3)
    ax = np.asarray(axis)
    ax = ax / math.sqrt(np.dot(ax, ax))
    # Quaternion components (w, x, y, z) of the rotation.
    w = math.cos(theta / 2.0)
    x, y, z = -ax * math.sin(theta / 2.0)
    return np.array(
        [
            [w * w + x * x - y * y - z * z, 2 * (x * y + w * z), 2 * (x * z - w * y)],
            [2 * (x * y - w * z), w * w + y * y - x * x - z * z, 2 * (y * z + w * x)],
            [2 * (x * z + w * y), 2 * (y * z - w * x), w * w + z * z - x * x - y * y],
        ]
    )
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def unit_vector(vector):
    """Return `vector` scaled to unit Euclidean length."""
    norm = np.linalg.norm(vector)
    return vector / norm
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def angle_between(v1, v2):
    """Return the angle in radians between vectors `v1` and `v2`.

    Degenerate (near-zero) inputs return 0.

    >>> angle_between((1, 0, 0), (0, 1, 0))
    1.5707963267948966
    >>> angle_between((1, 0, 0), (1, 0, 0))
    0.0
    >>> angle_between((1, 0, 0), (-1, 0, 0))
    3.141592653589793
    """
    if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6:
        return 0
    # Normalize both inputs, then clamp the cosine to dodge fp overshoot.
    v1_u = np.asarray(v1) / np.linalg.norm(v1)
    v2_u = np.asarray(v2) / np.linalg.norm(v2)
    cos_sim = np.dot(v1_u, v2_u)
    return np.arccos(np.clip(cos_sim, -1.0, 1.0))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def x_rotation(vector, theta):
    """Rotate a 3-D vector counterclockwise by `theta` radians about the x-axis."""
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array(
        [
            [1, 0, 0],
            [0, c, -s],
            [0, s, c],
        ]
    )
    return np.dot(rot, vector)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def y_rotation(vector, theta):
    """Rotate a 3-D vector counterclockwise by `theta` radians about the y-axis."""
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array(
        [
            [c, 0, s],
            [0, 1, 0],
            [-s, 0, c],
        ]
    )
    return np.dot(rot, vector)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def z_rotation(vector, theta):
    """Rotate a 3-D vector counterclockwise by `theta` radians about the z-axis."""
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array(
        [
            [c, -s, 0],
            [s, c, 0],
            [0, 0, 1],
        ]
    )
    return np.dot(rot, vector)
|
prepare/split_dataset.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import dutils
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from collections import Counter
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from pandas.core.common import flatten
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
# Sequences with more than MAX_LEN labeled frames are discarded.
MAX_LEN = 1000
# Number of target action classes; the value N_CLASS itself is used as the
# per-frame "no target action" (background) label.
N_CLASS = 4

parser = argparse.ArgumentParser(
    description="Spatial Temporal Graph Convolution Network"
)
parser.add_argument(
    "--data-root",
    default="dataset/babel_v1.0_sequence/",
    help="the root path of the dataset",
    type=str
)
parser.add_argument(
    "--split",
    default=1,
    help="the split of the dataset",
    type=int
)
parser.add_argument(
    "--output-folder",
    default="dataset/processed_data",
    help="the output folder of the generated data",
    type=str
)
# NOTE(review): arguments are parsed at import time, so importing this module
# from elsewhere will consume sys.argv.
args = parser.parse_args()

os.makedirs(args.output_folder, exist_ok=True)
|
| 50 |
+
|
| 51 |
+
def main(data_root):
    """Load preprocessed BABEL pickles and write labeled train/val splits.

    Args:
        data_root <str>: Folder holding the *_preprocessed.pkl files.
    """
    train_data = dutils.read_pkl(os.path.join(data_root, "babel_v1.0_train_ntu_sk_ntu-style_preprocessed.pkl"))
    # BABEL's val split serves as the held-out set here.
    test_data = dutils.read_pkl(os.path.join(data_root, "babel_v1.0_val_ntu_sk_ntu-style_preprocessed.pkl"))

    # Action-name -> class-index map for the chosen split.
    act2idx = dutils.read_json(f"./prepare/configs/action_label_split{args.split}.json")

    label_train_data(data_root, train_data, act2idx)
    label_val_data(data_root, test_data, act2idx)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def label_train_data(data_root, train_data, act2idx):
    """Build per-frame class labels for the training set and save a pickle.

    Args:
        data_root <str>: Dataset root (unused; kept for signature symmetry).
        train_data <dict>: Keys 'sid', 'X' (features, time on axis 1) and
            'Y' (per-sequence dict of frame -> label list).
        act2idx <dict>: action-label string -> class index.

    Writes train_split{split}.pkl with keys:
        sid: sequence ids; X: features truncated to the labeled length;
        Y: unique class indices present per sequence; L: per-frame class
        indices, with N_CLASS marking frames without a target action.
    """
    sid = []
    x = []
    y = []
    loc = []

    for i, seq_labels in enumerate(tqdm(train_data["Y"])):
        # Drop over-long sequences outright.
        if len(seq_labels) > MAX_LEN:
            continue

        y_ = []
        loc_ = []
        flag = False  # becomes True once any frame carries a target action

        for frame, labels in seq_labels.items():
            # Keep only labels belonging to this split's action set.
            label_set = set(labels) & set(act2idx.keys())
            label_list = list(label_set)
            if len(label_list) > 0:
                flag = True
                # NOTE(review): when several target actions co-occur on a
                # frame, an arbitrary one wins (set ordering) — confirm intended.
                loc_.append(act2idx[label_list[0]])
                y_.append(act2idx[label_list[0]])
            else:
                loc_.append(N_CLASS)  # background frame

        max_t = len(loc_)
        loc_ = np.array(loc_)
        y_ = list(set(y_))

        if flag:

            # print (train_data["X"][i].shape, len(loc_))
            loc.append(loc_)
            sid.append(train_data["sid"][i])
            # Truncate features to the number of labeled frames.
            x.append(train_data["X"][i][:,:max_t,...])
            y.append(y_)

    data = {"sid": sid, "X": x, "Y": y, "L":loc}

    dutils.write_pkl(data, os.path.join(args.output_folder, f"train_split{args.split}.pkl"))
    print (f"#Train sequence: {len(x)}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def label_val_data(data_root, test_data, act2idx):
    """Build per-frame class labels for the validation set and save a pickle.

    Mirrors label_train_data (same filtering and truncation rules) but reads
    from test_data and writes val_split{split}.pkl.

    Args:
        data_root <str>: Dataset root (unused; kept for signature symmetry).
        test_data <dict>: Keys 'sid', 'X' and 'Y' as in label_train_data.
        act2idx <dict>: action-label string -> class index.
    """
    sid = []
    x = []
    y = []
    loc = []
    for i, seq_labels in enumerate(tqdm(test_data["Y"])):
        # Drop over-long sequences outright.
        if len(seq_labels) > MAX_LEN:
            continue

        y_ = []
        loc_ = []
        flag = False  # becomes True once any frame carries a target action

        for frame, labels in seq_labels.items():
            label_set = set(labels) & set(act2idx.keys())
            label_list = list(label_set)
            if len(label_list) > 0:
                flag = True
                # NOTE(review): arbitrary pick among co-occurring target
                # actions (set ordering) — confirm intended.
                loc_.append(act2idx[label_list[0]])
                y_.append(act2idx[label_list[0]])
            else:
                loc_.append(N_CLASS)  # background frame

        max_t = len(loc_)
        loc_ = np.array(loc_)
        y_ = list(set(y_))

        if flag:
            loc.append(loc_)
            sid.append(test_data["sid"][i])
            # Truncate features to the number of labeled frames.
            x.append(test_data["X"][i][:,:max_t,...])
            y.append(y_)

    data = {"sid": sid, "X": x, "Y": y, "L":loc}

    dutils.write_pkl(data, os.path.join(args.output_folder, f"val_split{args.split}.pkl"))
    print (f"#Test sequence: {len(x)}")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
    # args.data_root comes from the module-level argparse.
    main(args.data_root)
|
prepare/viz.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# vim:fenc=utf-8
|
| 4 |
+
#
|
| 5 |
+
# Copyright © 2020 achandrasekaran <arjun.chandrasekaran@tuebingen.mpg.de>
|
| 6 |
+
#
|
| 7 |
+
# Distributed under terms of the MIT license.
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import os.path as osp
|
| 12 |
+
import pdb
|
| 13 |
+
import random
|
| 14 |
+
import shutil
|
| 15 |
+
import subprocess
|
| 16 |
+
import sys
|
| 17 |
+
import uuid
|
| 18 |
+
|
| 19 |
+
import cv2
|
| 20 |
+
import dutils
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from matplotlib import pyplot as plt
|
| 24 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 25 |
+
from torch.nn.functional import interpolate as intrp
|
| 26 |
+
|
| 27 |
+
"""
|
| 28 |
+
Visualize input and output motion sequences and labels
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_smpl_skeleton():
    """Bone list (joint-index pairs) for the SMPL body, traversed in order:
    left lower, left upper, spine, neck/head, right lower, right upper.
    """
    bones = [
        (0, 1), (1, 4), (4, 7), (7, 10),            # left leg
        (9, 13), (13, 16), (16, 18), (18, 20),      # left arm  # (20, 22) omitted
        (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),  # spinal column
        (0, 2), (2, 5), (5, 8), (8, 11),            # right leg
        (9, 14), (14, 17), (17, 19), (19, 21),      # right arm  # (21, 23) omitted
    ]
    return np.array(bones)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_nturgbd_joint_names():
    """Names of the 25 NTU-RGBD joints, indexed 0-24, using SMPL-style naming.

    From the NTU paper's numbering (1-based): 1-base of the spine, 2-middle
    of the spine, 3-neck, 4-head, 5..8 left arm, 9..12 right arm, 13..16
    left leg, 17..20 right leg, 21-spine, 22-left hand tip, 23-left thumb,
    24-right hand tip, 25-right thumb.
    """
    return [
        "Pelvis",      # 0
        "Spine1",      # 1
        "Neck",        # 2
        "Head",        # 3
        "L_Shoulder",  # 4
        "L_Elbow",     # 5
        "L_Wrist",     # 6
        "L_Hand",      # 7
        "R_Shoulder",  # 8
        "R_Elbow",     # 9
        "R_Wrist",     # 10
        "R_Hand",      # 11
        "L_Hip",       # 12
        "L_Knee",      # 13
        "L_Ankle",     # 14
        "L_Foot",      # 15
        "R_Hip",       # 16
        "R_Knee",      # 17
        "R_Ankle",     # 18
        "R_Foot",      # 19
        "Spine3",      # 20
        "L_HandTip",   # 21 (not in SMPL)
        "L_Thumb",     # 22 (not in SMPL)
        "R_HandTip",   # 23 (not in SMPL)
        "R_Thumb",     # 24 (not in SMPL)
    ]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_smpl_joint_names():
    """First 22 SMPL joint names (Wiki ordering); hand joints 22/23 excluded."""
    return [
        "Pelvis",      # 0
        "L_Hip",       # 1
        "R_Hip",       # 2
        "Spine1",      # 3
        "L_Knee",      # 4
        "R_Knee",      # 5
        "Spine2",      # 6
        "L_Ankle",     # 7
        "R_Ankle",     # 8
        "Spine3",      # 9
        "L_Foot",      # 10
        "R_Foot",      # 11
        "Neck",        # 12
        "L_Collar",    # 13
        "R_Collar",    # 14
        "Head",        # 15
        "L_Shoulder",  # 16
        "R_Shoulder",  # 17
        "L_Elbow",     # 18
        "R_Elbow",     # 19
        "L_Wrist",     # 20
        "R_Wrist",     # 21
    ]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_nturgbd_skeleton():
    """Bone list (joint-index pairs) for the 25-joint NTU-RGBD skeleton,
    traversed: left lower, left upper, spine, right lower, right upper,
    then the hand-tip/thumb links.
    """
    bones = [
        (0, 12), (12, 13), (13, 14), (14, 15),              # left leg
        (4, 20), (4, 5), (5, 6), (6, 7), (7, 21), (7, 22),  # left arm (+thumb)
        (0, 1), (1, 20), (20, 2), (2, 3),                   # spinal column
        (0, 16), (16, 17), (17, 18), (18, 19),              # right leg
        (20, 8), (8, 9), (9, 10), (10, 11), (11, 24),       # right arm (+thumb)
        (21, 22), (23, 24),                                 # hand tips
    ]
    return np.array(bones)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_joint_colors(joint_names):
    """Return one color per joint, sampled evenly from the 'rainbow' colormap.

    plt yields RGBA tuples in [0, 1]; channels are reordered to BGR for
    OpenCV drawing (values are kept as floats).
    """
    cmap = plt.get_cmap("rainbow")
    samples = (cmap(t) for t in np.linspace(0, 1, len(joint_names)))
    return [np.array((rgba[2], rgba[1], rgba[0])) for rgba in samples]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def calc_angle_from_x(sk):
|
| 195 |
+
"""Given skeleton, calc. angle from x-axis"""
|
| 196 |
+
# Hip bone
|
| 197 |
+
id_l_hip = get_smpl_joint_names().index("L_Hip")
|
| 198 |
+
id_r_hip = get_smpl_joint_names().index("R_Hip")
|
| 199 |
+
pl, pr = sk[id_l_hip], sk[id_r_hip]
|
| 200 |
+
bone = np.array(pr - pl)
|
| 201 |
+
unit_v = bone / np.linalg.norm(bone)
|
| 202 |
+
# Angle with x-axis
|
| 203 |
+
pdb.set_trace()
|
| 204 |
+
x_ax = np.array([1, 0, 0])
|
| 205 |
+
x_angle = math.degrees(np.arccos(np.dot(x_ax, unit_v)))
|
| 206 |
+
|
| 207 |
+
"""
|
| 208 |
+
l_hip_z = seq[0, joint_names.index('L_Hip'), 2]
|
| 209 |
+
r_hip_z = seq[0, joint_names.index('R_Hip'), 2]
|
| 210 |
+
az = 0 if (l_hip_z > zroot and zroot > r_hip_z) else 180
|
| 211 |
+
"""
|
| 212 |
+
if bone[1] > 0:
|
| 213 |
+
x_angle = -x_angle
|
| 214 |
+
|
| 215 |
+
return x_angle
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def calc_angle_from_y(sk):
|
| 219 |
+
"""Given skeleton, calc. angle from x-axis"""
|
| 220 |
+
# Hip bone
|
| 221 |
+
id_l_hip = get_smpl_joint_names().index("L_Hip")
|
| 222 |
+
id_r_hip = get_smpl_joint_names().index("R_Hip")
|
| 223 |
+
pl, pr = sk[id_l_hip], sk[id_r_hip]
|
| 224 |
+
bone = np.array(pl - pr)
|
| 225 |
+
unit_v = bone / np.linalg.norm(bone)
|
| 226 |
+
print(unit_v)
|
| 227 |
+
# Angle with x-axis
|
| 228 |
+
pdb.set_trace()
|
| 229 |
+
y_ax = np.array([0, 1, 0])
|
| 230 |
+
y_angle = math.degrees(np.arccos(np.dot(y_ax, unit_v)))
|
| 231 |
+
|
| 232 |
+
"""
|
| 233 |
+
l_hip_z = seq[0, joint_names.index('L_Hip'), 2]
|
| 234 |
+
r_hip_z = seq[0, joint_names.index('R_Hip'), 2]
|
| 235 |
+
az = 0 if (l_hip_z > zroot and zroot > r_hip_z) else 180
|
| 236 |
+
"""
|
| 237 |
+
# if bone[1] > 0:
|
| 238 |
+
# y_angle = - y_angle
|
| 239 |
+
seq_y_proj = bone * np.cos(np.deg2rad(y_angle))
|
| 240 |
+
print("Bone projected onto y-axis: ", seq_y_proj)
|
| 241 |
+
|
| 242 |
+
return y_angle
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def viz_skeleton(
|
| 246 |
+
seq,
|
| 247 |
+
folder_p,
|
| 248 |
+
sk_type="smpl",
|
| 249 |
+
radius=1,
|
| 250 |
+
lcolor="#ff0000",
|
| 251 |
+
rcolor="#0000ff",
|
| 252 |
+
action="",
|
| 253 |
+
debug=False,
|
| 254 |
+
):
|
| 255 |
+
"""Visualize skeletons for given sequence and store as images.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
seq (np.array): Array (frames) of joint positions.
|
| 259 |
+
Size depends on sk_type (see below).
|
| 260 |
+
if sk_type is 'smpl' then assume:
|
| 261 |
+
1. first 3 dims = translation.
|
| 262 |
+
2. Size = (# frames, 69)
|
| 263 |
+
elif sk_type is 'nturgbd', then assume:
|
| 264 |
+
1. no translation.
|
| 265 |
+
2. Size = (# frames, 25, 3)
|
| 266 |
+
folder_p (str): Path to root folder containing visualized frames.
|
| 267 |
+
Frames are dumped to the path: folder_p/frames/*.jpg
|
| 268 |
+
radius (float): Space around the subject?
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
Stores skeleton sequence as jpg frames.
|
| 272 |
+
"""
|
| 273 |
+
joint_names = (
|
| 274 |
+
get_nturgbd_joint_names() if "nturgbd" == sk_type else get_smpl_joint_names()
|
| 275 |
+
)
|
| 276 |
+
n_j = n_j = len(joint_names)
|
| 277 |
+
|
| 278 |
+
az = 90
|
| 279 |
+
if "smpl" == sk_type:
|
| 280 |
+
# SMPL kinematic chain, joint list.
|
| 281 |
+
# NOTE that hands are skipped.
|
| 282 |
+
kin_chain = get_smpl_skeleton()
|
| 283 |
+
# Reshape flat pose features into (frames, joints, (x,y,z)) (skip trans)
|
| 284 |
+
seq = seq[:, 3:].reshape(-1, n_j, 3).cpu().detach().numpy()
|
| 285 |
+
|
| 286 |
+
elif "nturgbd" == sk_type:
|
| 287 |
+
kin_chain = get_nturgbd_skeleton()
|
| 288 |
+
az = 0
|
| 289 |
+
|
| 290 |
+
# Get color-spectrum for skeleton
|
| 291 |
+
colors = get_joint_colors(joint_names)
|
| 292 |
+
labels = [(joint_names[jidx[0]], joint_names[jidx[1]]) for jidx in kin_chain]
|
| 293 |
+
|
| 294 |
+
# xroot, yroot, zroot = 0.0, 0.0, 0.0
|
| 295 |
+
xroot, yroot, zroot = seq[0, 0, 0], seq[0, 0, 1], seq[0, 0, 2]
|
| 296 |
+
# seq = seq - seq[0, :, :]
|
| 297 |
+
|
| 298 |
+
# Change viewing angle so that first frame is in frontal pose
|
| 299 |
+
# az = calc_angle_from_x(seq[0]-np.array([xroot, yroot, zroot]))
|
| 300 |
+
# az = calc_angle_from_y(seq[0]-np.array([xroot, yroot, zroot]))
|
| 301 |
+
|
| 302 |
+
# Viz. skeleton for each frame
|
| 303 |
+
for t in range(seq.shape[0]):
|
| 304 |
+
|
| 305 |
+
# Fig. settings
|
| 306 |
+
fig = plt.figure(figsize=(7, 6)) if debug else plt.figure(figsize=(5, 5))
|
| 307 |
+
ax = fig.add_subplot(111, projection="3d")
|
| 308 |
+
|
| 309 |
+
for i, (j1, j2) in enumerate(kin_chain):
|
| 310 |
+
# Store bones
|
| 311 |
+
x = np.array([seq[t, j1, 0], seq[t, j2, 0]])
|
| 312 |
+
y = np.array([seq[t, j1, 1], seq[t, j2, 1]])
|
| 313 |
+
z = np.array([seq[t, j1, 2], seq[t, j2, 2]])
|
| 314 |
+
# Plot bones in skeleton
|
| 315 |
+
ax.plot(x, y, z, c=colors[i], marker="o", linewidth=2, label=labels[i])
|
| 316 |
+
|
| 317 |
+
# More figure settings
|
| 318 |
+
ax.set_title(action)
|
| 319 |
+
ax.set_xlabel("X")
|
| 320 |
+
ax.set_ylabel("Y")
|
| 321 |
+
ax.set_zlabel("Z")
|
| 322 |
+
# xroot, yroot, zroot = seq[t, 0, 0], seq[t, 0, 1], seq[t, 0, 2]
|
| 323 |
+
|
| 324 |
+
# pdb.set_trace()
|
| 325 |
+
ax.set_xlim3d(-radius + xroot, radius + xroot)
|
| 326 |
+
ax.set_ylim3d([-radius + yroot, radius + yroot])
|
| 327 |
+
ax.set_zlim3d([-radius + zroot, radius + zroot])
|
| 328 |
+
|
| 329 |
+
if True == debug:
|
| 330 |
+
ax.axis("on")
|
| 331 |
+
ax.grid(b=True)
|
| 332 |
+
else:
|
| 333 |
+
ax.axis("off")
|
| 334 |
+
ax.grid(b=None)
|
| 335 |
+
# Turn off tick labels
|
| 336 |
+
ax.set_yticklabels([])
|
| 337 |
+
ax.set_xticklabels([])
|
| 338 |
+
ax.set_zticklabels([])
|
| 339 |
+
|
| 340 |
+
cv2.waitKey(0)
|
| 341 |
+
|
| 342 |
+
# ax.view_init(-75, 90)
|
| 343 |
+
# ax.view_init(elev=20, azim=90+az)
|
| 344 |
+
ax.view_init(elev=20, azim=az)
|
| 345 |
+
|
| 346 |
+
if True == debug:
|
| 347 |
+
ax.legend(bbox_to_anchor=(1.1, 1), loc="upper right")
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
fig.savefig(osp.join(folder_p, "frames", "{0}.jpg".format(t)))
|
| 351 |
+
plt.close(fig)
|
| 352 |
+
|
| 353 |
+
# break
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def write_vid_from_imgs(folder_p, fps):
|
| 357 |
+
"""Collate frames into a video sequence.
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
folder_p (str): Frame images are in the path: folder_p/frames/<int>.jpg
|
| 361 |
+
fps (float): Output frame rate.
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
Output video is stored in the path: folder_p/video.mp4
|
| 365 |
+
"""
|
| 366 |
+
vid_p = osp.join(folder_p, "video.mp4")
|
| 367 |
+
cmd = [
|
| 368 |
+
"ffmpeg",
|
| 369 |
+
"-r",
|
| 370 |
+
str(int(fps)),
|
| 371 |
+
"-i",
|
| 372 |
+
osp.join(folder_p, "frames", "%d.jpg"),
|
| 373 |
+
"-y",
|
| 374 |
+
vid_p,
|
| 375 |
+
]
|
| 376 |
+
FNULL = open(os.devnull, "w")
|
| 377 |
+
retcode = subprocess.call(cmd, stdout=FNULL, stderr=subprocess.STDOUT)
|
| 378 |
+
if not 0 == retcode:
|
| 379 |
+
print(
|
| 380 |
+
"*******ValueError(Error {0} executing command: {1}*********".format(
|
| 381 |
+
retcode, " ".join(cmd)
|
| 382 |
+
)
|
| 383 |
+
)
|
| 384 |
+
shutil.rmtree(osp.join(folder_p, "frames"))
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def viz_seq(seq, folder_p, sk_type, orig_fps=30.0, debug=False):
|
| 388 |
+
"""1. Dumps sequence of skeleton images for the given sequence of joints.
|
| 389 |
+
2. Collates the sequence of images into an mp4 video.
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
seq (np.array): Array of joint positions.
|
| 393 |
+
folder_p (str): Path to root folder that will contain frames folder.
|
| 394 |
+
sk_type (str): {'smpl', 'nturgbd'}
|
| 395 |
+
|
| 396 |
+
Return:
|
| 397 |
+
None. Path of mp4 video: folder_p/video.mp4
|
| 398 |
+
"""
|
| 399 |
+
# Delete folder if exists
|
| 400 |
+
if osp.exists(folder_p):
|
| 401 |
+
print("Deleting existing folder ", folder_p)
|
| 402 |
+
shutil.rmtree(folder_p)
|
| 403 |
+
|
| 404 |
+
# Create folder for frames
|
| 405 |
+
os.makedirs(osp.join(folder_p, "frames"))
|
| 406 |
+
|
| 407 |
+
# Dump frames into folder. Args: (data, radius, frames path)
|
| 408 |
+
viz_skeleton(seq, folder_p=folder_p, sk_type=sk_type, radius=1.2, debug=debug)
|
| 409 |
+
write_vid_from_imgs(folder_p, orig_fps)
|
| 410 |
+
|
| 411 |
+
return None
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def viz_rand_seq(X, Y, dtype, epoch, wb, urls=None, k=3, pred_labels=None):
|
| 415 |
+
"""
|
| 416 |
+
Args:
|
| 417 |
+
X (np.array): Array (frames) of SMPL joint positions.
|
| 418 |
+
Y (np.array): Multiple labels for each frame in x \in X.
|
| 419 |
+
dtype (str): {'input', 'pred'}
|
| 420 |
+
k (int): # samples to viz.
|
| 421 |
+
urls (tuple): Tuple of URLs of the rendered videos from original mocap.
|
| 422 |
+
wb (dict): Wandb log dict.
|
| 423 |
+
Returns:
|
| 424 |
+
viz_ds (dict): Data structure containing all viz. info so far.
|
| 425 |
+
"""
|
| 426 |
+
import wandb
|
| 427 |
+
|
| 428 |
+
# `idx2al`: idx --> action label string
|
| 429 |
+
al2idx = dutils.read_json("data/action_label_to_idx.json")
|
| 430 |
+
idx2al = {al2idx[k]: k for k in al2idx}
|
| 431 |
+
|
| 432 |
+
# Sample k random seqs. to viz.
|
| 433 |
+
for s_idx in random.sample(list(range(X.shape[0])), k):
|
| 434 |
+
# Visualize a single seq. in path `folder_p`
|
| 435 |
+
folder_p = osp.join("viz", str(uuid.uuid4()))
|
| 436 |
+
viz_seq(seq=X[s_idx], folder_p=folder_p)
|
| 437 |
+
title = "{0} seq. {1}: ".format(dtype, s_idx)
|
| 438 |
+
acts_str = ", ".join([idx2al[l] for l in torch.unique(Y[s_idx])])
|
| 439 |
+
wb[title + urls[s_idx]] = wandb.Video(
|
| 440 |
+
osp.join(folder_p, "video.mp4"), caption="Actions: " + acts_str
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
if "pred" == dtype or "preds" == dtype:
|
| 444 |
+
raise NotImplementedError
|
| 445 |
+
|
| 446 |
+
print("Done viz. {0} seqs.".format(k))
|
| 447 |
+
return wb
|
pyproject.toml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.pysen]
|
| 2 |
+
version = "0.10"
|
| 3 |
+
|
| 4 |
+
[tool.pysen.lint]
|
| 5 |
+
enable_black = true
|
| 6 |
+
enable_flake8 = true
|
| 7 |
+
enable_isort = true
|
| 8 |
+
enable_mypy = true
|
| 9 |
+
mypy_preset = "strict"
|
| 10 |
+
line_length = 88
|
| 11 |
+
py_version = "py37"
|
| 12 |
+
[[tool.pysen.lint.mypy_targets]]
|
| 13 |
+
paths = ["."]
|
requirements.txt
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
backcall==0.2.0
|
| 2 |
+
certifi==2020.12.5
|
| 3 |
+
decorator==4.4.2
|
| 4 |
+
ipdb==0.13.4
|
| 5 |
+
jedi==0.18.0
|
| 6 |
+
joblib==1.0.0
|
| 7 |
+
networkx==2.5
|
| 8 |
+
parso==0.8.1
|
| 9 |
+
pexpect==4.8.0
|
| 10 |
+
pickleshare==0.7.5
|
| 11 |
+
prompt-toolkit==3.0.10
|
| 12 |
+
protobuf==3.14.0
|
| 13 |
+
ptyprocess==0.7.0
|
| 14 |
+
Pygments==2.7.4
|
| 15 |
+
six==1.15.0
|
| 16 |
+
tensorboardX==2.1
|
| 17 |
+
threadpoolctl==2.1.0
|
| 18 |
+
tqdm==4.56.0
|
| 19 |
+
traitlets==5.0.5
|
| 20 |
+
typing-extensions==3.7.4.3
|
| 21 |
+
wcwidth==0.2.5
|
| 22 |
+
smplx==0.1.28
|
| 23 |
+
opencv-python
|
| 24 |
+
einops
|
| 25 |
+
matplotlib
|
| 26 |
+
scikit-learn
|
| 27 |
+
pandas
|
| 28 |
+
chumpy
|
train.py
ADDED
|
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import print_function
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
import pdb
|
| 20 |
+
import pickle
|
| 21 |
+
import random
|
| 22 |
+
import re
|
| 23 |
+
import shutil
|
| 24 |
+
import time
|
| 25 |
+
from collections import *
|
| 26 |
+
|
| 27 |
+
import ipdb
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# torch
|
| 31 |
+
import torch
|
| 32 |
+
import torch.backends.cudnn as cudnn
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
import torch.optim as optim
|
| 36 |
+
import yaml
|
| 37 |
+
from einops import rearrange, reduce, repeat
|
| 38 |
+
from evaluation.classificationMAP import getClassificationMAP as cmAP
|
| 39 |
+
from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 40 |
+
from feeders.tools import collate_with_padding_multi_joint
|
| 41 |
+
from model.losses import cross_entropy_loss, mvl_loss
|
| 42 |
+
from sklearn.metrics import f1_score
|
| 43 |
+
|
| 44 |
+
# Custom
|
| 45 |
+
from tensorboardX import SummaryWriter
|
| 46 |
+
from torch.autograd import Variable
|
| 47 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 48 |
+
from tqdm import tqdm
|
| 49 |
+
from utils.logger import Logger
|
| 50 |
+
|
| 51 |
+
def remove_prefix_from_state_dict(state_dict, prefix):
|
| 52 |
+
new_state_dict = {}
|
| 53 |
+
for k, v in state_dict.items():
|
| 54 |
+
if k.startswith(prefix):
|
| 55 |
+
print(k)
|
| 56 |
+
new_k = k[len(prefix):] # strip the prefix
|
| 57 |
+
print(new_k)
|
| 58 |
+
else:
|
| 59 |
+
print(k)
|
| 60 |
+
new_k = k
|
| 61 |
+
new_state_dict[new_k] = v
|
| 62 |
+
return new_state_dict
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def init_seed(seed):
|
| 66 |
+
torch.cuda.manual_seed_all(seed)
|
| 67 |
+
torch.manual_seed(seed)
|
| 68 |
+
np.random.seed(seed)
|
| 69 |
+
random.seed(seed)
|
| 70 |
+
torch.backends.cudnn.deterministic = True
|
| 71 |
+
torch.backends.cudnn.benchmark = False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_parser():
|
| 75 |
+
# parameter priority: command line > config > default
|
| 76 |
+
parser = argparse.ArgumentParser(
|
| 77 |
+
description="Spatial Temporal Graph Convolution Network"
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--work-dir",
|
| 81 |
+
default="./work_dir/temp",
|
| 82 |
+
help="the work folder for storing results",
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
parser.add_argument("-model_saved_name", default="")
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--config",
|
| 88 |
+
default="./config/nturgbd-cross-view/test_bone.yaml",
|
| 89 |
+
help="path to the configuration file",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# processor
|
| 93 |
+
parser.add_argument("--phase", default="train", help="must be train or test")
|
| 94 |
+
|
| 95 |
+
# visulize and debug
|
| 96 |
+
parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--log-interval",
|
| 99 |
+
type=int,
|
| 100 |
+
default=100,
|
| 101 |
+
help="the interval for printing messages (#iteration)",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--save-interval",
|
| 105 |
+
type=int,
|
| 106 |
+
default=2,
|
| 107 |
+
help="the interval for storing models (#iteration)",
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--eval-interval",
|
| 111 |
+
type=int,
|
| 112 |
+
default=5,
|
| 113 |
+
help="the interval for evaluating models (#iteration)",
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--print-log", type=str2bool, default=True, help="print logging or not"
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--show-topk",
|
| 120 |
+
type=int,
|
| 121 |
+
default=[1, 5],
|
| 122 |
+
nargs="+",
|
| 123 |
+
help="which Top K accuracy will be shown",
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# feeder
|
| 127 |
+
parser.add_argument(
|
| 128 |
+
"--feeder", default="feeder.feeder", help="data loader will be used"
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--num-worker",
|
| 132 |
+
type=int,
|
| 133 |
+
default=32,
|
| 134 |
+
help="the number of worker for data loader",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--train-feeder-args",
|
| 138 |
+
default=dict(),
|
| 139 |
+
help="the arguments of data loader for training",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--test-feeder-args",
|
| 143 |
+
default=dict(),
|
| 144 |
+
help="the arguments of data loader for test",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# model
|
| 148 |
+
parser.add_argument("--model", default=None, help="the model will be used")
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--model-args", type=dict, default=dict(), help="the arguments of model"
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--weights", default=None, help="the weights for network initialization"
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--ignore-weights",
|
| 157 |
+
type=str,
|
| 158 |
+
default=[],
|
| 159 |
+
nargs="+",
|
| 160 |
+
help="the name of weights which will be ignored in the initialization",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# optim
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--base-lr", type=float, default=0.01, help="initial learning rate"
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--step",
|
| 169 |
+
type=int,
|
| 170 |
+
default=[60,80],
|
| 171 |
+
nargs="+",
|
| 172 |
+
help="the epoch where optimizer reduce the learning rate",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# training
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--device",
|
| 178 |
+
type=int,
|
| 179 |
+
default=0,
|
| 180 |
+
nargs="+",
|
| 181 |
+
help="the indexes of GPUs for training or testing",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--nesterov", type=str2bool, default=False, help="use nesterov or not"
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--batch-size", type=int, default=256, help="training batch size"
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--test-batch-size", type=int, default=256, help="test batch size"
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--start-epoch", type=int, default=0, help="start training from which epoch"
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--num-epoch", type=int, default=80, help="stop training in which epoch"
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
|
| 201 |
+
)
|
| 202 |
+
# loss
|
| 203 |
+
parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--label_count_path",
|
| 206 |
+
default=None,
|
| 207 |
+
type=str,
|
| 208 |
+
help="Path to label counts (used in loss weighting)",
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"---beta",
|
| 212 |
+
type=float,
|
| 213 |
+
default=0.9999,
|
| 214 |
+
help="Hyperparameter for Class balanced loss",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
parser.add_argument("--only_train_part", default=False)
|
| 221 |
+
parser.add_argument("--only_train_epoch", default=0)
|
| 222 |
+
parser.add_argument("--warm_up_epoch", default=0)
|
| 223 |
+
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--class-threshold",
|
| 230 |
+
type=float,
|
| 231 |
+
default=0.1,
|
| 232 |
+
help="class threshold for rejection",
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--start-threshold",
|
| 236 |
+
type=float,
|
| 237 |
+
default=0.03,
|
| 238 |
+
help="start threshold for action localization",
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--end-threshold",
|
| 242 |
+
type=float,
|
| 243 |
+
default=0.055,
|
| 244 |
+
help="end threshold for action localization",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--threshold-interval",
|
| 248 |
+
type=float,
|
| 249 |
+
default=0.005,
|
| 250 |
+
help="threshold interval for action localization",
|
| 251 |
+
)
|
| 252 |
+
return parser
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Processor:
|
| 256 |
+
"""
|
| 257 |
+
Processor for Skeleton-based Action Recgnition
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, arg):
|
| 261 |
+
self.arg = arg
|
| 262 |
+
self.save_arg()
|
| 263 |
+
if arg.phase == "train":
|
| 264 |
+
if not arg.train_feeder_args["debug"]:
|
| 265 |
+
if os.path.isdir(arg.model_saved_name):
|
| 266 |
+
print("log_dir: ", arg.model_saved_name, "already exist")
|
| 267 |
+
# answer = input('delete it? y/n:')
|
| 268 |
+
answer = "y"
|
| 269 |
+
if answer == "y":
|
| 270 |
+
print("Deleting dir...")
|
| 271 |
+
shutil.rmtree(arg.model_saved_name)
|
| 272 |
+
print("Dir removed: ", arg.model_saved_name)
|
| 273 |
+
# input('Refresh the website of tensorboard by pressing any keys')
|
| 274 |
+
else:
|
| 275 |
+
print("Dir not removed: ", arg.model_saved_name)
|
| 276 |
+
self.train_writer = SummaryWriter(
|
| 277 |
+
os.path.join(arg.model_saved_name, "train"), "train"
|
| 278 |
+
)
|
| 279 |
+
self.val_writer = SummaryWriter(
|
| 280 |
+
os.path.join(arg.model_saved_name, "val"), "val"
|
| 281 |
+
)
|
| 282 |
+
else:
|
| 283 |
+
self.train_writer = self.val_writer = SummaryWriter(
|
| 284 |
+
os.path.join(arg.model_saved_name, "test"), "test"
|
| 285 |
+
)
|
| 286 |
+
self.global_step = 0
|
| 287 |
+
self.load_model()
|
| 288 |
+
self.load_optimizer()
|
| 289 |
+
self.load_data()
|
| 290 |
+
self.lr = self.arg.base_lr
|
| 291 |
+
self.best_acc = 0
|
| 292 |
+
self.best_per_class_acc = 0
|
| 293 |
+
self.loss_nce = torch.nn.BCELoss()
|
| 294 |
+
|
| 295 |
+
self.my_logger = Logger(
|
| 296 |
+
os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
|
| 297 |
+
)
|
| 298 |
+
self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)])
|
| 299 |
+
|
| 300 |
+
def load_data(self):
|
| 301 |
+
Feeder = import_class(self.arg.feeder)
|
| 302 |
+
self.data_loader = dict()
|
| 303 |
+
if self.arg.phase == "train":
|
| 304 |
+
self.data_loader["train"] = torch.utils.data.DataLoader(
|
| 305 |
+
dataset=Feeder(**self.arg.train_feeder_args),
|
| 306 |
+
batch_size=self.arg.batch_size,
|
| 307 |
+
shuffle=True,
|
| 308 |
+
num_workers=self.arg.num_worker,
|
| 309 |
+
drop_last=True,
|
| 310 |
+
collate_fn=collate_with_padding_multi_joint,
|
| 311 |
+
)
|
| 312 |
+
self.data_loader["test"] = torch.utils.data.DataLoader(
|
| 313 |
+
dataset=Feeder(**self.arg.test_feeder_args),
|
| 314 |
+
batch_size=self.arg.test_batch_size,
|
| 315 |
+
shuffle=False,
|
| 316 |
+
num_workers=self.arg.num_worker,
|
| 317 |
+
drop_last=False,
|
| 318 |
+
collate_fn=collate_with_padding_multi_joint,
|
| 319 |
+
)
|
| 320 |
+
def load_model(self):
|
| 321 |
+
output_device = (
|
| 322 |
+
self.arg.device[0] if type(self.arg.device) is list else self.arg.device
|
| 323 |
+
)
|
| 324 |
+
self.output_device = output_device
|
| 325 |
+
Model = import_class(self.arg.model)
|
| 326 |
+
shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
|
| 327 |
+
# print(Model)
|
| 328 |
+
self.model = Model(**self.arg.model_args).cuda(output_device)
|
| 329 |
+
# print(self.model)
|
| 330 |
+
self.loss_type = arg.loss
|
| 331 |
+
|
| 332 |
+
if self.arg.weights:
|
| 333 |
+
# if False:
|
| 334 |
+
# self.global_step = int(arg.weights[:-3].split("-")[-1])
|
| 335 |
+
self.print_log("Load weights from {}.".format(self.arg.weights))
|
| 336 |
+
if ".pkl" in self.arg.weights:
|
| 337 |
+
with open(self.arg.weights, "r") as f:
|
| 338 |
+
weights = pickle.load(f)
|
| 339 |
+
else:
|
| 340 |
+
weights = torch.load(self.arg.weights)
|
| 341 |
+
|
| 342 |
+
weights = OrderedDict(
|
| 343 |
+
[
|
| 344 |
+
[k.split("module.")[-1], v.cuda(output_device)]
|
| 345 |
+
for k, v in weights.items()
|
| 346 |
+
]
|
| 347 |
+
)
|
| 348 |
+
weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
|
| 349 |
+
keys = list(weights.keys())
|
| 350 |
+
self.arg.ignore_weights = ['encoder_q','encoder_q.relation','data_bn','fc','encoder_k','queue','queue_ptr','value_transform']
|
| 351 |
+
for w in self.arg.ignore_weights:
|
| 352 |
+
for key in keys:
|
| 353 |
+
if w in key:
|
| 354 |
+
if weights.pop(key, None) is not None:
|
| 355 |
+
self.print_log(
|
| 356 |
+
"Sucessfully Remove Weights: {}.".format(key)
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
self.print_log("Can Not Remove Weights: {}.".format(key))
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
self.model.load_state_dict(weights)
|
| 363 |
+
except:
|
| 364 |
+
state = self.model.state_dict()
|
| 365 |
+
diff = list(set(state.keys()).difference(set(weights.keys())))
|
| 366 |
+
print("Can not find these weights:")
|
| 367 |
+
for d in diff:
|
| 368 |
+
print(" " + d)
|
| 369 |
+
state.update(weights)
|
| 370 |
+
self.model.load_state_dict(state)
|
| 371 |
+
|
| 372 |
+
if type(self.arg.device) is list:
|
| 373 |
+
if len(self.arg.device) > 1:
|
| 374 |
+
self.model = nn.DataParallel(
|
| 375 |
+
self.model, device_ids=self.arg.device, output_device=output_device
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# # def load_model(self):
|
| 379 |
+
# output_device = (
|
| 380 |
+
# self.arg.device[0] if type(self.arg.device) is list else self.arg.device
|
| 381 |
+
# )
|
| 382 |
+
# self.output_device = output_device
|
| 383 |
+
# Model = import_class(self.arg.model)
|
| 384 |
+
# shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
|
| 385 |
+
# # print(Model)
|
| 386 |
+
# self.model = Model(**self.arg.model_args).cuda(output_device)
|
| 387 |
+
# # print(self.model)
|
| 388 |
+
# self.loss_type = arg.loss
|
| 389 |
+
|
| 390 |
+
# if self.arg.weights:
|
| 391 |
+
# # self.global_step = int(arg.weights[:-3].split("-")[-1])
|
| 392 |
+
# self.print_log("Load weights from {}.".format(self.arg.weights))
|
| 393 |
+
# if ".pkl" in self.arg.weights:
|
| 394 |
+
# with open(self.arg.weights, "r") as f:
|
| 395 |
+
# weights = pickle.load(f)
|
| 396 |
+
# else:
|
| 397 |
+
# weights = torch.load(self.arg.weights)
|
| 398 |
+
|
| 399 |
+
# weights = OrderedDict(
|
| 400 |
+
# [
|
| 401 |
+
# [k.split("module.")[-1], v.cuda(output_device)]
|
| 402 |
+
# for k, v in weights.items()
|
| 403 |
+
# ]
|
| 404 |
+
# )
|
| 405 |
+
|
| 406 |
+
# keys = list(weights.keys())
|
| 407 |
+
# for w in self.arg.ignore_weights:
|
| 408 |
+
# for key in keys:
|
| 409 |
+
# if w in key:
|
| 410 |
+
# if weights.pop(key, None) is not None:
|
| 411 |
+
# self.print_log(
|
| 412 |
+
# "Sucessfully Remove Weights: {}.".format(key)
|
| 413 |
+
# )
|
| 414 |
+
# else:
|
| 415 |
+
# self.print_log("Can Not Remove Weights: {}.".format(key))
|
| 416 |
+
|
| 417 |
+
# try:
|
| 418 |
+
# self.model.load_state_dict(weights)
|
| 419 |
+
# except:
|
| 420 |
+
# state = self.model.state_dict()
|
| 421 |
+
# diff = list(set(state.keys()).difference(set(weights.keys())))
|
| 422 |
+
# print("Can not find these weights:")
|
| 423 |
+
# for d in diff:
|
| 424 |
+
# print(" " + d)
|
| 425 |
+
# state.update(weights)
|
| 426 |
+
# self.model.load_state_dict(state)
|
| 427 |
+
|
| 428 |
+
# if type(self.arg.device) is list:
|
| 429 |
+
# if len(self.arg.device) > 1:
|
| 430 |
+
# self.model = nn.DataParallel(
|
| 431 |
+
# self.model, device_ids=self.arg.device, output_device=output_device
|
| 432 |
+
# )
|
| 433 |
+
|
| 434 |
+
def load_optimizer(self):
|
| 435 |
+
if self.arg.optimizer == "SGD":
|
| 436 |
+
self.optimizer = optim.SGD(
|
| 437 |
+
self.model.parameters(),
|
| 438 |
+
lr=self.arg.base_lr,
|
| 439 |
+
momentum=0.9,
|
| 440 |
+
nesterov=self.arg.nesterov,
|
| 441 |
+
weight_decay=self.arg.weight_decay,
|
| 442 |
+
)
|
| 443 |
+
elif self.arg.optimizer == "Adam":
|
| 444 |
+
self.optimizer = optim.Adam(
|
| 445 |
+
self.model.parameters(),
|
| 446 |
+
lr=self.arg.base_lr,
|
| 447 |
+
weight_decay=self.arg.weight_decay,
|
| 448 |
+
)
|
| 449 |
+
else:
|
| 450 |
+
raise ValueError()
|
| 451 |
+
|
| 452 |
+
def save_arg(self):
|
| 453 |
+
# save arg
|
| 454 |
+
arg_dict = vars(self.arg)
|
| 455 |
+
if not os.path.exists(self.arg.work_dir):
|
| 456 |
+
os.makedirs(self.arg.work_dir)
|
| 457 |
+
with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
|
| 458 |
+
yaml.dump(arg_dict, f)
|
| 459 |
+
|
| 460 |
+
def adjust_learning_rate(self, epoch):
|
| 461 |
+
if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
|
| 462 |
+
if epoch < self.arg.warm_up_epoch:
|
| 463 |
+
lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
|
| 464 |
+
else:
|
| 465 |
+
lr = self.arg.base_lr * (
|
| 466 |
+
0.1 ** np.sum(epoch >= np.array(self.arg.step))
|
| 467 |
+
)
|
| 468 |
+
for param_group in self.optimizer.param_groups:
|
| 469 |
+
param_group["lr"] = lr
|
| 470 |
+
|
| 471 |
+
return lr
|
| 472 |
+
else:
|
| 473 |
+
raise ValueError()
|
| 474 |
+
|
| 475 |
+
def print_time(self):
|
| 476 |
+
localtime = time.asctime(time.localtime(time.time()))
|
| 477 |
+
self.print_log("Local current time : " + localtime)
|
| 478 |
+
|
| 479 |
+
def print_log(self, str, print_time=True):
|
| 480 |
+
if print_time:
|
| 481 |
+
localtime = time.asctime(time.localtime(time.time()))
|
| 482 |
+
str = "[ " + localtime + " ] " + str
|
| 483 |
+
print(str)
|
| 484 |
+
if self.arg.print_log:
|
| 485 |
+
with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
|
| 486 |
+
print(str, file=f)
|
| 487 |
+
|
| 488 |
+
def record_time(self):
|
| 489 |
+
self.cur_time = time.time()
|
| 490 |
+
return self.cur_time
|
| 491 |
+
|
| 492 |
+
def split_time(self):
|
| 493 |
+
split_time = time.time() - self.cur_time
|
| 494 |
+
self.record_time()
|
| 495 |
+
return split_time
|
| 496 |
+
|
| 497 |
+
    def train(self, epoch, wb_dict, save_model=False):
        """Run one training epoch over self.data_loader["train"].

        Two-stream weakly-supervised training: video-level BCE (loss_nce) on
        both MIL streams plus a mutual-view consistency loss (mvl_loss); after
        epoch 10 a frame-level cross-entropy against the feeder's soft labels
        is added and the MIL weight is hard-coded to 0.1.

        Args:
            epoch: 0-based epoch index (logged as epoch + 1).
            wb_dict: metrics dict, updated in place with "train loss"/"train acc".
            save_model: when True, save a CPU state_dict checkpoint at the end.

        Returns:
            The updated wb_dict.
        """
        self.model.train()
        self.print_log("Training epoch: {}".format(epoch + 1))
        loader = self.data_loader["train"]
        self.adjust_learning_rate(epoch)

        # NOTE(review): batch_acc is allocated but never filled in this method.
        loss_value, batch_acc = [], []
        self.train_writer.add_scalar("epoch", epoch, self.global_step)
        self.record_time()
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
        process = tqdm(loader)
        # Optionally freeze/unfreeze only the adaptive-graph "PA" parameters
        # depending on the epoch (partial-training schedule).
        if self.arg.only_train_part:
            if epoch > self.arg.only_train_epoch:
                print("only train part, require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = True
            else:
                print("only train part, do not require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = False

        # Per-epoch accumulators for classification mAP and soft-label update.
        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        results = []
        indexs = []

        # `target` is yielded by the collate function but unused here.
        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):

            self.global_step += 1
            # get data
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            mask = mask.cuda(self.output_device)
            soft_label = soft_label.cuda(self.output_device)
            timer["dataloader"] += self.split_time()

            # Remember dataset indices so label_update() can match predictions
            # back to their samples at the end of the epoch.
            indexs.extend(index.cpu().numpy().tolist())

            # Append an extra all-ones column to the multi-label targets
            # (presumably a background/ambiguous class — confirm with model).
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward
            mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)

            # Video-level BCE on both streams.
            cls_mil_loss = self.loss_nce(mil_pred, ab_labels.float()) + self.loss_nce(
                mil_pred_2, ab_labels.float()
            )

            # Hard-coded warm phase: frame-level soft-label supervision only
            # kicks in after epoch 10, and the MIL weight switches to 0.1.
            if epoch > 10:

                frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
                frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
                soft_label = rearrange(soft_label, "n t c -> (n t) c")

                loss = cls_mil_loss * 0.1 + mvl_loss(
                    frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                )

                loss += cross_entropy_loss(
                    frm_scrs_re, soft_label
                ) + cross_entropy_loss(frm_scrs_2_re, soft_label)

            else:
                loss = cls_mil_loss * self.arg.lambda_mil + mvl_loss(
                    frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                )

            # Collect per-sample predictions (truncated to the unpadded length).
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                # Number of valid (non-padding) frames for this sample.
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
                vid_pred = mil_pred[i].detach().cpu().numpy()

                results.append(frm_pred)

                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # NOTE(review): loss.data is legacy autograd style; loss.item() is
            # the modern equivalent.
            loss_value.append(loss.data.item())
            timer["model"] += self.split_time()

        # NOTE(review): frm_preds rows have different lengths, so this becomes
        # a ragged object array — confirm downstream consumers expect that.
        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Feeder refreshes its per-frame soft labels from this epoch's predictions.
        loader.dataset.label_update(results, indexs)

        # Video-level classification mAP over the whole epoch.
        cmap = cmAP(vid_preds, labels)

        self.train_writer.add_scalar("acc", cmap, self.global_step)
        self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

        # statistics
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.train_writer.add_scalar("lr", self.lr, self.global_step)
        timer["statistics"] += self.split_time()

        # statistics of time consumption and loss
        self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
        self.print_log("\tAcc score: {:.3f}%".format(cmap))

        # Log
        wb_dict["train loss"] = np.mean(loss_value)
        wb_dict["train acc"] = cmap

        if save_model:
            # Strip a possible DataParallel "module." prefix and store on CPU.
            state_dict = self.model.state_dict()
            weights = OrderedDict(
                [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
            )

            torch.save(
                weights,
                self.arg.model_saved_name + str(epoch) + ".pt",
            )

        return wb_dict
|
| 632 |
+
|
| 633 |
+
@torch.no_grad()
|
| 634 |
+
def eval(
|
| 635 |
+
self,
|
| 636 |
+
epoch,
|
| 637 |
+
wb_dict,
|
| 638 |
+
loader_name=["test"],
|
| 639 |
+
):
|
| 640 |
+
self.model.eval()
|
| 641 |
+
self.print_log("Eval epoch: {}".format(epoch + 1))
|
| 642 |
+
|
| 643 |
+
vid_preds = []
|
| 644 |
+
frm_preds = []
|
| 645 |
+
vid_lens = []
|
| 646 |
+
labels = []
|
| 647 |
+
|
| 648 |
+
for ln in loader_name:
|
| 649 |
+
loss_value = []
|
| 650 |
+
step = 0
|
| 651 |
+
process = tqdm(self.data_loader[ln])
|
| 652 |
+
|
| 653 |
+
for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
|
| 654 |
+
process
|
| 655 |
+
):
|
| 656 |
+
data = data.float().cuda(self.output_device)
|
| 657 |
+
label = label.cuda(self.output_device)
|
| 658 |
+
mask = mask.cuda(self.output_device)
|
| 659 |
+
|
| 660 |
+
ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)
|
| 661 |
+
|
| 662 |
+
# forward
|
| 663 |
+
mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)
|
| 664 |
+
|
| 665 |
+
cls_mil_loss = self.loss_nce(
|
| 666 |
+
mil_pred, ab_labels.float()
|
| 667 |
+
) + self.loss_nce(mil_pred_2, ab_labels.float())
|
| 668 |
+
|
| 669 |
+
loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)
|
| 670 |
+
|
| 671 |
+
loss = cls_mil_loss * self.arg.lambda_mil + loss_co
|
| 672 |
+
|
| 673 |
+
loss_value.append(loss.data.item())
|
| 674 |
+
|
| 675 |
+
for i in range(data.size(0)):
|
| 676 |
+
frm_scr = frm_scrs[i]
|
| 677 |
+
vid_pred = mil_pred[i]
|
| 678 |
+
|
| 679 |
+
label_ = label[i].cpu().numpy()
|
| 680 |
+
mask_ = mask[i].cpu().numpy()
|
| 681 |
+
vid_len = mask_.sum()
|
| 682 |
+
|
| 683 |
+
frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
|
| 684 |
+
vid_pred = vid_pred.cpu().numpy()
|
| 685 |
+
|
| 686 |
+
vid_preds.append(vid_pred)
|
| 687 |
+
frm_preds.append(frm_pred)
|
| 688 |
+
vid_lens.append(vid_len)
|
| 689 |
+
labels.append(label_)
|
| 690 |
+
|
| 691 |
+
step += 1
|
| 692 |
+
|
| 693 |
+
vid_preds = np.array(vid_preds)
|
| 694 |
+
frm_preds = np.array(frm_preds)
|
| 695 |
+
vid_lens = np.array(vid_lens)
|
| 696 |
+
labels = np.array(labels)
|
| 697 |
+
|
| 698 |
+
cmap = cmAP(vid_preds, labels)
|
| 699 |
+
|
| 700 |
+
score = cmap
|
| 701 |
+
loss = np.mean(loss_value)
|
| 702 |
+
|
| 703 |
+
dmap, iou = dsmAP(
|
| 704 |
+
vid_preds,
|
| 705 |
+
frm_preds,
|
| 706 |
+
vid_lens,
|
| 707 |
+
self.arg.test_feeder_args["data_path"],
|
| 708 |
+
self.arg,
|
| 709 |
+
multi=True,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
print("Classification map %f" % cmap)
|
| 713 |
+
for item in list(zip(iou, dmap)):
|
| 714 |
+
print("Detection map @ %f = %f" % (item[0], item[1]))
|
| 715 |
+
|
| 716 |
+
self.my_logger.append([epoch + 1, cmap] + dmap)
|
| 717 |
+
|
| 718 |
+
wb_dict["val loss"] = loss
|
| 719 |
+
wb_dict["val acc"] = score
|
| 720 |
+
|
| 721 |
+
if score > self.best_acc:
|
| 722 |
+
self.best_acc = score
|
| 723 |
+
|
| 724 |
+
print("Acc score: ", score, " model: ", self.arg.model_saved_name)
|
| 725 |
+
if self.arg.phase == "train":
|
| 726 |
+
self.val_writer.add_scalar("loss", loss, self.global_step)
|
| 727 |
+
self.val_writer.add_scalar("acc", score, self.global_step)
|
| 728 |
+
|
| 729 |
+
self.print_log(
|
| 730 |
+
"\tMean {} loss of {} batches: {}.".format(
|
| 731 |
+
ln, len(self.data_loader[ln]), np.mean(loss_value)
|
| 732 |
+
)
|
| 733 |
+
)
|
| 734 |
+
self.print_log("\tAcc score: {:.3f}%".format(score))
|
| 735 |
+
|
| 736 |
+
return wb_dict
|
| 737 |
+
|
| 738 |
+
def start(self):
|
| 739 |
+
wb_dict = {}
|
| 740 |
+
if self.arg.phase == "train":
|
| 741 |
+
self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
|
| 742 |
+
self.global_step = (
|
| 743 |
+
self.arg.start_epoch
|
| 744 |
+
* len(self.data_loader["train"])
|
| 745 |
+
/ self.arg.batch_size
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
|
| 749 |
+
|
| 750 |
+
save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
|
| 751 |
+
epoch + 1 == self.arg.num_epoch
|
| 752 |
+
)
|
| 753 |
+
wb_dict = {"lr": self.lr}
|
| 754 |
+
|
| 755 |
+
# Train
|
| 756 |
+
wb_dict = self.train(epoch, wb_dict, save_model=save_model)
|
| 757 |
+
|
| 758 |
+
# Eval. on val set
|
| 759 |
+
wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
|
| 760 |
+
# Log stats. for this epoch
|
| 761 |
+
print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))
|
| 762 |
+
|
| 763 |
+
print(
|
| 764 |
+
"best accuracy: ",
|
| 765 |
+
self.best_acc,
|
| 766 |
+
" model_name: ",
|
| 767 |
+
self.arg.model_saved_name,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
elif self.arg.phase == "test":
|
| 771 |
+
if not self.arg.test_feeder_args["debug"]:
|
| 772 |
+
wf = self.arg.model_saved_name + "_wrong.txt"
|
| 773 |
+
rf = self.arg.model_saved_name + "_right.txt"
|
| 774 |
+
else:
|
| 775 |
+
wf = rf = None
|
| 776 |
+
if self.arg.weights is None:
|
| 777 |
+
raise ValueError("Please appoint --weights.")
|
| 778 |
+
self.arg.print_log = False
|
| 779 |
+
self.print_log("Model: {}.".format(self.arg.model))
|
| 780 |
+
self.print_log("Weights: {}.".format(self.arg.weights))
|
| 781 |
+
|
| 782 |
+
wb_dict = self.eval(
|
| 783 |
+
epoch=0,
|
| 784 |
+
wb_dict=wb_dict,
|
| 785 |
+
loader_name=["test"],
|
| 786 |
+
wrong_file=wf,
|
| 787 |
+
result_file=rf,
|
| 788 |
+
)
|
| 789 |
+
print("Inference metrics: ", wb_dict)
|
| 790 |
+
self.print_log("Done.\n")
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def str2bool(v):
    """Parse a command-line string into a bool (argparse ``type=`` helper).

    Raises argparse.ArgumentTypeError for unrecognized values.
    """
    normalized = v.lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def import_class(name):
    """Resolve a dotted path such as "model.agcn.Model" to the object it names."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
if __name__ == "__main__":
    parser = get_parser()

    # load arg form config file
    # First parse picks up --config; YAML keys then become new defaults so
    # that explicit command-line flags still win on the second parse.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        # Every YAML key must correspond to a declared argparse dest.
        for k in default_arg.keys():
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    # Re-parse: command line > config file > parser defaults.
    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
|
train_full.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import print_function
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import inspect
|
| 6 |
+
import os
|
| 7 |
+
import pdb
|
| 8 |
+
import pickle
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
import shutil
|
| 12 |
+
import time
|
| 13 |
+
from collections import *
|
| 14 |
+
|
| 15 |
+
import ipdb
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
# torch
|
| 19 |
+
import torch
|
| 20 |
+
import torch.backends.cudnn as cudnn
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torch.optim as optim
|
| 24 |
+
import yaml
|
| 25 |
+
from einops import rearrange, reduce, repeat
|
| 26 |
+
from evaluation.classificationMAP import getClassificationMAP as cmAP
|
| 27 |
+
from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 28 |
+
from feeders.tools import collate_with_padding_multi_joint
|
| 29 |
+
from model.losses import cross_entropy_loss, mvl_loss
|
| 30 |
+
from sklearn.metrics import f1_score
|
| 31 |
+
|
| 32 |
+
# Custom
|
| 33 |
+
from tensorboardX import SummaryWriter
|
| 34 |
+
from torch.autograd import Variable
|
| 35 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
from utils.logger import Logger
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# seed = 0
|
| 41 |
+
# random.seed(seed)
|
| 42 |
+
# np.random.seed(seed)
|
| 43 |
+
# torch.manual_seed(seed)
|
| 44 |
+
# torch.cuda.manual_seed_all(seed)
|
| 45 |
+
# torch.use_deterministic_algorithms(True)
|
| 46 |
+
# torch.backends.cudnn.deterministic = True
|
| 47 |
+
# torch.backends.cudnn.benchmark = False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def init_seed(seed):
    """Seed all RNGs used in training (python, numpy, torch CPU/GPU) and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_parser():
    """Build the argparse parser for the training scripts.

    Values resolve with priority: command line > YAML config (--config,
    applied via set_defaults in __main__) > the defaults declared here.
    """
    # parameter priority: command line > config > default
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    # NOTE(review): single dash, so this is spelled "-model_saved_name".
    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="train", help="must be train or test")

    # visulize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    # These dict-valued args are normally supplied via the YAML config,
    # not on the command line.
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[60,80],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # NOTE(review): three leading dashes make the argparse dest "_beta";
    # likely intended "--beta" — confirm against configs before changing,
    # since YAML keys must match the dest.
    parser.add_argument(
        "---beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    # Partial-training schedule for adaptive-graph ("PA") parameters.
    parser.add_argument("--only_train_part", default=False)
    parser.add_argument("--only_train_epoch", default=0)
    parser.add_argument("--warm_up_epoch", default=10)

    parser.add_argument(
        "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
    )

    # Thresholds used by the detection-mAP evaluation (dsmAP).
    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class Processor:
|
| 241 |
+
"""
|
| 242 |
+
Processor for Skeleton-based Action Recgnition
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
    def __init__(self, arg):
        """Set up writers, model, optimizer, data loaders and the CSV logger.

        Args:
            arg: parsed argparse.Namespace (command line merged with YAML config).
        """
        self.arg = arg
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                # A previous run's tensorboard dir is wiped automatically
                # (the interactive confirmation is hard-coded to "y").
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # Debug mode: one shared writer for both train and val.
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        # NOTE(review): when phase != "train" no writer is created here —
        # confirm eval() is never reached with phase "test" writers unset.
        self.global_step = 0
        # Order matters: the optimizer needs self.model to exist.
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        # Video-level multi-label BCE used by train()/eval().
        self.loss_nce = torch.nn.BCELoss()

        # CSV-style logger: one row per eval (step, cmap, det-mAP@0.1..0.5, avg).
        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)]+['avg'])
|
| 284 |
+
|
| 285 |
+
    def load_data(self):
        """Instantiate self.data_loader["train"/"test"] from the configured feeder.

        The train loader is only built in the "train" phase; the test loader
        is always built. Both use the padding-aware multi-joint collate fn.
        """
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()
        if self.arg.phase == "train":
            self.data_loader["train"] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,  # keep batch statistics consistent
                collate_fn=collate_with_padding_multi_joint,
            )
        self.data_loader["test"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            collate_fn=collate_with_padding_multi_joint,
        )
|
| 305 |
+
|
| 306 |
+
def load_model(self):
|
| 307 |
+
output_device = (
|
| 308 |
+
self.arg.device[0] if type(self.arg.device) is list else self.arg.device
|
| 309 |
+
)
|
| 310 |
+
self.output_device = output_device
|
| 311 |
+
Model = import_class(self.arg.model)
|
| 312 |
+
shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
|
| 313 |
+
# print(Model)
|
| 314 |
+
self.model = Model(**self.arg.model_args).cuda(output_device)
|
| 315 |
+
# print(self.model)
|
| 316 |
+
self.loss_type = arg.loss
|
| 317 |
+
|
| 318 |
+
if self.arg.weights:
|
| 319 |
+
# self.global_step = int(arg.weights[:-3].split("-")[-1])
|
| 320 |
+
self.print_log("Load weights from {}.".format(self.arg.weights))
|
| 321 |
+
if ".pkl" in self.arg.weights:
|
| 322 |
+
with open(self.arg.weights, "r") as f:
|
| 323 |
+
weights = pickle.load(f)
|
| 324 |
+
else:
|
| 325 |
+
weights = torch.load(self.arg.weights)
|
| 326 |
+
|
| 327 |
+
weights = OrderedDict(
|
| 328 |
+
[
|
| 329 |
+
[k.split("module.")[-1], v.cuda(output_device)]
|
| 330 |
+
for k, v in weights.items()
|
| 331 |
+
]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
keys = list(weights.keys())
|
| 335 |
+
for w in self.arg.ignore_weights:
|
| 336 |
+
for key in keys:
|
| 337 |
+
if w in key:
|
| 338 |
+
if weights.pop(key, None) is not None:
|
| 339 |
+
self.print_log(
|
| 340 |
+
"Sucessfully Remove Weights: {}.".format(key)
|
| 341 |
+
)
|
| 342 |
+
else:
|
| 343 |
+
self.print_log("Can Not Remove Weights: {}.".format(key))
|
| 344 |
+
|
| 345 |
+
try:
|
| 346 |
+
self.model.load_state_dict(weights)
|
| 347 |
+
except:
|
| 348 |
+
state = self.model.state_dict()
|
| 349 |
+
diff = list(set(state.keys()).difference(set(weights.keys())))
|
| 350 |
+
print("Can not find these weights:")
|
| 351 |
+
for d in diff:
|
| 352 |
+
print(" " + d)
|
| 353 |
+
state.update(weights)
|
| 354 |
+
self.model.load_state_dict(state)
|
| 355 |
+
|
| 356 |
+
if type(self.arg.device) is list:
|
| 357 |
+
if len(self.arg.device) > 1:
|
| 358 |
+
self.model = nn.DataParallel(
|
| 359 |
+
self.model, device_ids=self.arg.device, output_device=output_device
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
def load_optimizer(self):
    """Build ``self.optimizer`` (SGD or Adam) from the parsed arguments.

    Reads ``self.arg.optimizer``, ``base_lr``, ``nesterov`` and
    ``weight_decay``; optimizes all parameters of ``self.model``.

    Raises:
        ValueError: if ``self.arg.optimizer`` is neither "SGD" nor "Adam".
    """
    if self.arg.optimizer == "SGD":
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.arg.base_lr,
            momentum=0.9,  # fixed momentum; not exposed as a CLI flag
            nesterov=self.arg.nesterov,
            weight_decay=self.arg.weight_decay,
        )
    elif self.arg.optimizer == "Adam":
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.arg.base_lr,
            weight_decay=self.arg.weight_decay,
        )
    else:
        # Fail loudly with the offending value instead of a bare ValueError.
        raise ValueError(
            "Unsupported optimizer: {!r} (expected 'SGD' or 'Adam')".format(
                self.arg.optimizer
            )
        )
|
| 379 |
+
|
| 380 |
+
def save_arg(self):
    """Persist every parsed argument to ``<work_dir>/config.yaml``.

    Creates the work directory on first use so later log writes succeed.
    """
    config = vars(self.arg)
    os.makedirs(self.arg.work_dir, exist_ok=True)
    out_path = "{}/config.yaml".format(self.arg.work_dir)
    with open(out_path, "w") as f:
        yaml.dump(config, f)
|
| 387 |
+
|
| 388 |
+
def adjust_learning_rate(self, epoch):
    """Set the learning rate for ``epoch`` and return it.

    Schedule: linear warm-up from ``base_lr / warm_up_epoch`` to ``base_lr``
    over the first ``warm_up_epoch`` epochs, then step decay multiplying by
    0.1 for every milestone in ``self.arg.step`` already passed.

    Raises:
        ValueError: if ``self.arg.optimizer`` is neither "SGD" nor "Adam".
    """
    if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
        if epoch < self.arg.warm_up_epoch:
            # Linear warm-up: (epoch+1)/warm_up_epoch fraction of base_lr.
            lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
        else:
            # 0.1 ** (#milestones passed) step decay.
            lr = self.arg.base_lr * (
                0.1 ** np.sum(epoch >= np.array(self.arg.step))
            )
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        return lr
    else:
        # Name the bad value instead of raising a bare ValueError.
        raise ValueError(
            "Unsupported optimizer: {!r} (expected 'SGD' or 'Adam')".format(
                self.arg.optimizer
            )
        )
|
| 402 |
+
|
| 403 |
+
def print_time(self):
    """Log the current wall-clock time through ``print_log``."""
    now = time.asctime(time.localtime(time.time()))
    self.print_log("Local current time : " + now)
|
| 406 |
+
|
| 407 |
+
def print_log(self, str, print_time=True):
    """Print a message to stdout and, when ``self.arg.print_log`` is set,
    append it to ``<work_dir>/print_log.txt``.

    When ``print_time`` is true the message is prefixed with a timestamp.
    (Parameter name ``str`` shadows the builtin but is kept for caller
    compatibility.)
    """
    if print_time:
        stamp = time.asctime(time.localtime(time.time()))
        str = "[ " + stamp + " ] " + str
    print(str)
    if self.arg.print_log:
        log_path = "{}/print_log.txt".format(self.arg.work_dir)
        with open(log_path, "a") as f:
            print(str, file=f)
|
| 415 |
+
|
| 416 |
+
def record_time(self):
    """Remember the current timestamp (used by ``split_time``) and return it."""
    self.cur_time = time.time()
    return self.cur_time
|
| 419 |
+
|
| 420 |
+
def split_time(self):
    """Return seconds elapsed since the last ``record_time``/``split_time``
    call, and reset the reference timestamp."""
    elapsed = time.time() - self.cur_time
    self.record_time()
    return elapsed
|
| 424 |
+
|
| 425 |
+
def train(self, epoch, wb_dict, save_model=False):
    """Run one fully-supervised training epoch.

    For every batch: forwards the two-stream model, combines a video-level
    MIL (BCE) loss, a mutual-view loss between the two frame-score streams,
    and a frame-level cross-entropy against the dense ``target`` labels;
    then backprops and steps the optimizer. Accumulates per-video
    predictions to compute classification mAP and pushes the frame
    predictions back into the dataset via ``label_update``.

    Args:
        epoch: zero-based epoch index (controls lr schedule and logging).
        save_model: when true, save a CPU copy of the weights to
            ``<model_saved_name><epoch>.pt``.

    Returns:
        ``wb_dict`` updated with "train loss" and "train acc".
    """
    self.model.train()
    self.print_log("Training epoch: {}".format(epoch + 1))
    loader = self.data_loader["train"]
    self.adjust_learning_rate(epoch)

    # NOTE(review): batch_acc is accumulated nowhere — appears unused.
    loss_value, batch_acc = [], []
    self.train_writer.add_scalar("epoch", epoch, self.global_step)
    self.record_time()
    # Small non-zero seeds avoid divide-by-zero in timing ratios.
    timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
    process = tqdm(loader)
    if self.arg.only_train_part:
        # Freeze/unfreeze only the adaptive-graph "PA" parameters depending
        # on whether we are past only_train_epoch.
        if epoch > self.arg.only_train_epoch:
            print("only train part, require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = True
        else:
            print("only train part, do not require grad")
            for key, value in self.model.named_parameters():
                if "PA" in key:
                    value.requires_grad = False

    # Per-video accumulators for epoch-level mAP computation.
    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    # Frame predictions + dataset indices fed back to the feeder.
    results = []
    indexs = []

    '''
    Switch to FULL supervision
    Dataloader->Feeder ->collate_with_padding_multi_joint
    '''

    for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
        process
    ):

        self.global_step += 1
        # get data
        data = data.float().cuda(self.output_device)
        label = label.cuda(self.output_device)
        target = target.cuda(self.output_device)
        mask = mask.cuda(self.output_device)
        soft_label = soft_label.cuda(self.output_device)
        timer["dataloader"] += self.split_time()

        ''' into one hot'''
        # Flatten dense per-frame targets and one-hot encode them.
        # NOTE(review): num_classes is hard-coded to 5 — confirm it matches
        # the dataset split's action-class count.
        ground_truth_flat = target.view(-1)
        one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
        ''' into one hot'''

        indexs.extend(index.cpu().numpy().tolist())

        # Append a constant "background/ambient" column of ones to the
        # video-level labels for the MIL loss.
        ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

        # forward
        mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)

        # Video-level MIL BCE loss, summed over both streams.
        cls_mil_loss = self.loss_nce(mil_pred, ab_labels.float()) + self.loss_nce(
            mil_pred_2, ab_labels.float()
        )

        if epoch > -1:  # always true; kept for easy schedule experiments

            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
            frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
            # soft_label = rearrange(soft_label, "n t c -> (n t) c")

            # MIL loss (down-weighted 0.1) + mutual-view loss between streams.
            loss = cls_mil_loss * 0.1 + mvl_loss(
                frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
            )

            # Dense frame-level supervision on both streams.
            loss += cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)

        # else:
        #     loss = cls_mil_loss * self.arg.lambda_mil + mvl_loss(
        #         frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
        #     )

        # Collect per-video predictions truncated to the unpadded length.
        for i in range(data.size(0)):
            frm_scr = frm_scrs[i]

            label_ = label[i].cpu().numpy()
            mask_ = mask[i].cpu().numpy()
            vid_len = mask_.sum()

            frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
            vid_pred = mil_pred[i].detach().cpu().numpy()

            results.append(frm_pred)

            vid_preds.append(vid_pred)
            frm_preds.append(frm_pred)
            vid_lens.append(vid_len)
            labels.append(label_)

        # backward
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_value.append(loss.data.item())
        timer["model"] += self.split_time()

    vid_preds = np.array(vid_preds)
    frm_preds = np.array(frm_preds)
    vid_lens = np.array(vid_lens)
    labels = np.array(labels)

    # Feed frame predictions back to the dataset (pseudo-label refresh).
    loader.dataset.label_update(results, indexs)

    # Classification mean average precision over the whole epoch.
    cmap = cmAP(vid_preds, labels)

    self.train_writer.add_scalar("acc", cmap, self.global_step)
    self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

    # statistics
    self.lr = self.optimizer.param_groups[0]["lr"]
    self.train_writer.add_scalar("lr", self.lr, self.global_step)
    timer["statistics"] += self.split_time()

    # statistics of time consumption and loss
    self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
    self.print_log("\tAcc score: {:.3f}%".format(cmap))

    # Log
    wb_dict["train loss"] = np.mean(loss_value)
    wb_dict["train acc"] = cmap

    if save_model:
        # Strip DataParallel's "module." prefix before saving CPU weights.
        state_dict = self.model.state_dict()
        weights = OrderedDict(
            [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
        )

        torch.save(
            weights,
            self.arg.model_saved_name + str(epoch) + ".pt",
        )

    return wb_dict
|
| 572 |
+
|
| 573 |
+
@torch.no_grad()
def eval(
    self,
    epoch,
    wb_dict,
    loader_name=["test"],
):
    """Evaluate on the named data loaders (no gradients).

    Computes the same composite loss as ``train`` for monitoring, then
    classification mAP (``cmAP``) and detection mAP over IoU thresholds
    (``dsmAP``); appends a row to ``self.my_logger`` and tracks
    ``self.best_acc``.

    NOTE(review): the mutable default ``loader_name=["test"]`` is shared
    across calls — harmless here since it is never mutated, but fragile.

    Returns:
        ``wb_dict`` updated with "val loss" and "val acc".
    """
    self.model.eval()
    self.print_log("Eval epoch: {}".format(epoch + 1))

    # Per-video accumulators across all evaluated loaders.
    vid_preds = []
    frm_preds = []
    vid_lens = []
    labels = []

    for ln in loader_name:
        loss_value = []
        step = 0
        process = tqdm(self.data_loader[ln])

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            mask = mask.cuda(self.output_device)

            # Append the constant background column for the MIL loss.
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward
            mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)

            '''Loc LOSS'''
            target = target.cuda(self.output_device)
            ''' into one hot'''
            # NOTE(review): num_classes hard-coded to 5, same as in train().
            ground_truth_flat = target.view(-1)
            one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
            ''' into one hot'''
            frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
            frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
            '''Loc LOSS'''


            # Video-level MIL BCE loss over both streams.
            cls_mil_loss = self.loss_nce(
                mil_pred, ab_labels.float()
            ) + self.loss_nce(mil_pred_2, ab_labels.float())

            # Mutual-view consistency loss between the two frame streams.
            loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)

            loss = cls_mil_loss * self.arg.lambda_mil + loss_co

            '''Loc LOSS'''
            loss += cross_entropy_loss(
                frm_scrs_re, one_hot_ground_truth
            ) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)
            '''Loc LOSS'''


            loss_value.append(loss.data.item())

            # Truncate per-video frame predictions to the unpadded length.
            for i in range(data.size(0)):
                frm_scr = frm_scrs[i]
                vid_pred = mil_pred[i]

                label_ = label[i].cpu().numpy()
                mask_ = mask[i].cpu().numpy()
                vid_len = mask_.sum()

                frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                vid_pred = vid_pred.cpu().numpy()

                vid_preds.append(vid_pred)
                frm_preds.append(frm_pred)
                vid_lens.append(vid_len)
                labels.append(label_)

            step += 1

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Classification mAP over all videos of this loader.
        cmap = cmAP(vid_preds, labels)

        score = cmap
        loss = np.mean(loss_value)

        # Detection mAP across IoU thresholds (returns per-threshold maps).
        dmap, iou = dsmAP(
            vid_preds,
            frm_preds,
            vid_lens,
            self.arg.test_feeder_args["data_path"],
            self.arg,
            multi=True,
        )

        print("Classification map %f" % cmap)
        for item in list(zip(iou, dmap)):
            print("Detection map @ %f = %f" % (item[0], item[1]))

        self.my_logger.append([epoch + 1, cmap] + dmap+ [np.mean(dmap)])

        wb_dict["val loss"] = loss
        wb_dict["val acc"] = score

        if score > self.best_acc:
            self.best_acc = score

        print("Acc score: ", score, " model: ", self.arg.model_saved_name)
        if self.arg.phase == "train":
            self.val_writer.add_scalar("loss", loss, self.global_step)
            self.val_writer.add_scalar("acc", score, self.global_step)

        self.print_log(
            "\tMean {} loss of {} batches: {}.".format(
                ln, len(self.data_loader[ln]), np.mean(loss_value)
            )
        )
        self.print_log("\tAcc score: {:.3f}%".format(score))

    return wb_dict
|
| 695 |
+
|
| 696 |
+
def start(self):
    """Entry point: run the training loop or a single test-set evaluation.

    In "train" phase: trains for ``num_epoch`` epochs, evaluating on the
    test loader every 10th epoch and saving weights on the configured
    interval. In "test" phase: loads nothing new (weights were loaded in
    ``load_model``) and runs one evaluation pass.

    Raises:
        ValueError: in "test" phase when ``--weights`` was not supplied.
    """
    wb_dict = {}
    if self.arg.phase == "train":
        self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
        # Resume the global step counter when starting from a later epoch.
        self.global_step = (
            self.arg.start_epoch
            * len(self.data_loader["train"])
            / self.arg.batch_size
        )

        for epoch in range(self.arg.start_epoch, self.arg.num_epoch):

            save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
                epoch + 1 == self.arg.num_epoch
            )
            wb_dict = {"lr": self.lr}

            # Train
            wb_dict = self.train(epoch, wb_dict, save_model=save_model)
            if epoch % 10 == 0:
                # Eval. on val set every 10 epochs
                wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
            # Log stats. for this epoch
            print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))

        print(
            "best accuracy: ",
            self.best_acc,
            " model_name: ",
            self.arg.model_saved_name,
        )

    elif self.arg.phase == "test":
        if self.arg.weights is None:
            raise ValueError("Please appoint --weights.")
        self.arg.print_log = False
        self.print_log("Model: {}.".format(self.arg.model))
        self.print_log("Weights: {}.".format(self.arg.weights))

        # BUG FIX: eval() has no wrong_file/result_file parameters, so the
        # original call self.eval(..., wrong_file=wf, result_file=rf) always
        # raised TypeError in test phase. The wf/rf paths were computed but
        # never consumed; they are removed along with the bad keyword args.
        wb_dict = self.eval(
            epoch=0,
            wb_dict=wb_dict,
            loader_name=["test"],
        )
        print("Inference metrics: ", wb_dict)
        self.print_log("Done.\n")
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def str2bool(v):
    """Parse a yes/no-style CLI string into a bool (for argparse ``type=``).

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognized
            truthy/falsy spelling.
    """
    lowered = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def import_class(name):
    """Import and return the object named by a dotted path, e.g. 'pkg.mod.Cls'."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
if __name__ == "__main__":
    parser = get_parser()

    # load arg form config file
    # First parse picks up --config; YAML values become new defaults, then a
    # second parse lets explicit command-line flags override them
    # (priority: command line > config > argparse defaults).
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        for k in default_arg.keys():
            # Reject YAML keys that do not correspond to a known argument.
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
|
train_full_SSL.py
ADDED
|
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import print_function
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
import pdb
|
| 20 |
+
import pickle
|
| 21 |
+
import random
|
| 22 |
+
import re
|
| 23 |
+
import shutil
|
| 24 |
+
import time
|
| 25 |
+
from collections import *
|
| 26 |
+
|
| 27 |
+
import ipdb
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# torch
|
| 31 |
+
import torch
|
| 32 |
+
import torch.backends.cudnn as cudnn
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
import torch.optim as optim
|
| 36 |
+
import yaml
|
| 37 |
+
from einops import rearrange, reduce, repeat
|
| 38 |
+
from evaluation.classificationMAP import getClassificationMAP as cmAP
|
| 39 |
+
from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 40 |
+
from feeders.tools import collate_with_padding_multi_joint
|
| 41 |
+
from model.losses import cross_entropy_loss, mvl_loss
|
| 42 |
+
from sklearn.metrics import f1_score
|
| 43 |
+
|
| 44 |
+
# Custom
|
| 45 |
+
from tensorboardX import SummaryWriter
|
| 46 |
+
from torch.autograd import Variable
|
| 47 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 48 |
+
from tqdm import tqdm
|
| 49 |
+
from utils.logger import Logger
|
| 50 |
+
|
| 51 |
+
def remove_prefix_from_state_dict(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from every
    key that starts with it; other keys are kept unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def init_seed(seed):
    """Seed all RNGs (python, numpy, torch CPU + all GPUs) and force
    deterministic cudnn behavior (disables autotuning)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Reproducibility at the cost of cudnn's benchmark autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_parser():
    """Build the argparse parser for training/evaluation.

    Parameter priority when the script runs: command line > config file
    (YAML, applied via ``set_defaults``) > the defaults declared here.
    """
    # parameter priority: command line > config > default
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    # NOTE(review): single-dash long option; argparse still maps it to
    # dest "model_saved_name".
    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="train", help="must be train or test")

    # visulize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[200],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # NOTE(review): triple-dash "---beta" is almost certainly a typo for
    # "--beta". argparse strips leading dashes so the dest is still "beta",
    # but the CLI flag must be typed as "---beta". Confirm no configs/scripts
    # depend on the odd spelling before fixing.
    parser.add_argument(
        "---beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    # NOTE(review): these three lack type=; YAML config values keep their
    # native types but CLI-supplied values would arrive as strings.
    parser.add_argument("--only_train_part", default=False)
    parser.add_argument("--only_train_epoch", default=0)
    parser.add_argument("--warm_up_epoch", default=10)

    parser.add_argument(
        "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
    )

    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Processor:
|
| 253 |
+
"""
|
| 254 |
+
Processor for Skeleton-based Action Recgnition
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
def __init__(self, arg):
    """Set up work dir, tensorboard writers, model, optimizer, data and logger.

    Args:
        arg: parsed argparse namespace (see ``get_parser``).
    """
    self.arg = arg
    self.save_arg()
    if arg.phase == "train":
        if not arg.train_feeder_args["debug"]:
            if os.path.isdir(arg.model_saved_name):
                print("log_dir: ", arg.model_saved_name, "already exist")
                # answer = input('delete it? y/n:')
                # Auto-confirm: an existing log dir is wiped without prompting.
                answer = "y"
                if answer == "y":
                    print("Deleting dir...")
                    shutil.rmtree(arg.model_saved_name)
                    print("Dir removed: ", arg.model_saved_name)
                    # input('Refresh the website of tensorboard by pressing any keys')
                else:
                    print("Dir not removed: ", arg.model_saved_name)
            self.train_writer = SummaryWriter(
                os.path.join(arg.model_saved_name, "train"), "train"
            )
            self.val_writer = SummaryWriter(
                os.path.join(arg.model_saved_name, "val"), "val"
            )
        else:
            # Debug mode: one shared writer under "test".
            self.train_writer = self.val_writer = SummaryWriter(
                os.path.join(arg.model_saved_name, "test"), "test"
            )
    self.global_step = 0
    self.load_model()
    self.load_optimizer()
    self.load_data()
    self.lr = self.arg.base_lr
    self.best_acc = 0
    self.best_per_class_acc = 0
    # BCE over per-class video scores (MIL loss); named "nce" historically.
    self.loss_nce = torch.nn.BCELoss()

    # CSV-style metrics logger: step, classification mAP, detection mAP at
    # IoU 0.1..0.5, and their average.
    self.my_logger = Logger(
        os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
    )
    self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)]+["avg"])
|
| 296 |
+
|
| 297 |
+
def load_data(self):
    """Instantiate the configured Feeder and build the DataLoader dict.

    Always builds a "test" loader; a shuffled, drop-last "train" loader is
    added only in train phase. Both use the multi-joint padding collate.
    """
    Feeder = import_class(self.arg.feeder)
    loaders = dict()
    if self.arg.phase == "train":
        loaders["train"] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.train_feeder_args),
            batch_size=self.arg.batch_size,
            shuffle=True,
            num_workers=self.arg.num_worker,
            drop_last=True,
            collate_fn=collate_with_padding_multi_joint,
        )
    loaders["test"] = torch.utils.data.DataLoader(
        dataset=Feeder(**self.arg.test_feeder_args),
        batch_size=self.arg.test_batch_size,
        shuffle=False,
        num_workers=self.arg.num_worker,
        drop_last=False,
        collate_fn=collate_with_padding_multi_joint,
    )
    self.data_loader = loaders
|
| 317 |
+
|
| 318 |
+
def load_model(self):
|
| 319 |
+
output_device = (
|
| 320 |
+
self.arg.device[0] if type(self.arg.device) is list else self.arg.device
|
| 321 |
+
)
|
| 322 |
+
self.output_device = output_device
|
| 323 |
+
Model = import_class(self.arg.model)
|
| 324 |
+
shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
|
| 325 |
+
# print(Model)
|
| 326 |
+
self.model = Model(**self.arg.model_args).cuda(output_device)
|
| 327 |
+
# print(self.model)
|
| 328 |
+
self.loss_type = arg.loss
|
| 329 |
+
|
| 330 |
+
if self.arg.weights:
|
| 331 |
+
self.print_log("Load weights from {}.".format(self.arg.weights))
|
| 332 |
+
if ".pkl" in self.arg.weights:
|
| 333 |
+
with open(self.arg.weights, "r") as f:
|
| 334 |
+
weights = pickle.load(f)
|
| 335 |
+
else:
|
| 336 |
+
weights = torch.load(self.arg.weights)
|
| 337 |
+
|
| 338 |
+
weights = OrderedDict(
|
| 339 |
+
[
|
| 340 |
+
[k.split("module.")[-1], v.cuda(output_device)]
|
| 341 |
+
for k, v in weights.items()
|
| 342 |
+
]
|
| 343 |
+
)
|
| 344 |
+
weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
|
| 345 |
+
keys = list(weights.keys())
|
| 346 |
+
|
| 347 |
+
self.arg.ignore_weights = ['data_bn','fc','encoder_q','encoder_k','queue','queue_ptr','value_transform']
|
| 348 |
+
for w in self.arg.ignore_weights:
|
| 349 |
+
for key in keys:
|
| 350 |
+
if w in key:
|
| 351 |
+
if weights.pop(key, None) is not None:
|
| 352 |
+
continue
|
| 353 |
+
# self.print_log(
|
| 354 |
+
# "Sucessfully Remove Weights: {}.".format(key)
|
| 355 |
+
# )
|
| 356 |
+
# else:
|
| 357 |
+
# self.print_log("Can Not Remove Weights: {}.".format(key))
|
| 358 |
+
|
| 359 |
+
try:
|
| 360 |
+
self.model.load_state_dict(weights)
|
| 361 |
+
except:
|
| 362 |
+
state = self.model.state_dict()
|
| 363 |
+
diff = list(set(state.keys()).difference(set(weights.keys())))
|
| 364 |
+
print("Can not find these weights:")
|
| 365 |
+
for d in diff:
|
| 366 |
+
print(" " + d)
|
| 367 |
+
state.update(weights)
|
| 368 |
+
self.model.load_state_dict(state)
|
| 369 |
+
|
| 370 |
+
if type(self.arg.device) is list:
|
| 371 |
+
if len(self.arg.device) > 1:
|
| 372 |
+
self.model = nn.DataParallel(
|
| 373 |
+
self.model, device_ids=self.arg.device, output_device=output_device
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
def load_optimizer(self):
|
| 377 |
+
if self.arg.optimizer == "SGD":
|
| 378 |
+
self.optimizer = optim.SGD(
|
| 379 |
+
self.model.parameters(),
|
| 380 |
+
lr=self.arg.base_lr,
|
| 381 |
+
momentum=0.9,
|
| 382 |
+
nesterov=self.arg.nesterov,
|
| 383 |
+
weight_decay=self.arg.weight_decay,
|
| 384 |
+
)
|
| 385 |
+
elif self.arg.optimizer == "Adam":
|
| 386 |
+
self.optimizer = optim.Adam(
|
| 387 |
+
self.model.parameters(),
|
| 388 |
+
lr=self.arg.base_lr,
|
| 389 |
+
weight_decay=self.arg.weight_decay,
|
| 390 |
+
)
|
| 391 |
+
else:
|
| 392 |
+
raise ValueError()
|
| 393 |
+
|
| 394 |
+
def save_arg(self):
|
| 395 |
+
# save arg
|
| 396 |
+
arg_dict = vars(self.arg)
|
| 397 |
+
if not os.path.exists(self.arg.work_dir):
|
| 398 |
+
os.makedirs(self.arg.work_dir)
|
| 399 |
+
with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
|
| 400 |
+
yaml.dump(arg_dict, f)
|
| 401 |
+
|
| 402 |
+
def adjust_learning_rate(self, epoch):
|
| 403 |
+
if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
|
| 404 |
+
if epoch < self.arg.warm_up_epoch:
|
| 405 |
+
lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
|
| 406 |
+
else:
|
| 407 |
+
lr = self.arg.base_lr * (
|
| 408 |
+
0.1 ** np.sum(epoch >= np.array(self.arg.step))
|
| 409 |
+
)
|
| 410 |
+
for param_group in self.optimizer.param_groups:
|
| 411 |
+
param_group["lr"] = lr
|
| 412 |
+
|
| 413 |
+
return lr
|
| 414 |
+
else:
|
| 415 |
+
raise ValueError()
|
| 416 |
+
|
| 417 |
+
def print_time(self):
|
| 418 |
+
localtime = time.asctime(time.localtime(time.time()))
|
| 419 |
+
self.print_log("Local current time : " + localtime)
|
| 420 |
+
|
| 421 |
+
    def print_log(self, str, print_time=True):
        """Print a message and, if ``arg.print_log`` is set, append it to
        ``<work_dir>/print_log.txt``.

        Args:
            str: the message text. NOTE(review): this parameter shadows the
                builtin ``str``; kept as-is to preserve the keyword-callable
                signature.
            print_time: when True, prefix the message with the local time.
        """
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            str = "[ " + localtime + " ] " + str
        print(str)
        if self.arg.print_log:
            # Append-mode: the log file accumulates across the whole run.
            with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
                print(str, file=f)
|
| 429 |
+
|
| 430 |
+
def record_time(self):
|
| 431 |
+
self.cur_time = time.time()
|
| 432 |
+
return self.cur_time
|
| 433 |
+
|
| 434 |
+
def split_time(self):
|
| 435 |
+
split_time = time.time() - self.cur_time
|
| 436 |
+
self.record_time()
|
| 437 |
+
return split_time
|
| 438 |
+
|
| 439 |
+
    def train(self, epoch, wb_dict, save_model=False):
        """Run one fully-supervised training epoch.

        Args:
            epoch: zero-based epoch index.
            wb_dict: metrics dict to be filled in and returned.
            save_model: when True, dump a CPU copy of the state dict afterwards.

        Returns:
            ``wb_dict`` updated with "train loss" and "train acc".
        """
        self.model.train()
        self.print_log("Training epoch: {}".format(epoch + 1))
        loader = self.data_loader["train"]
        self.adjust_learning_rate(epoch)

        loss_value, batch_acc = [], []
        self.train_writer.add_scalar("epoch", epoch, self.global_step)
        self.record_time()
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
        process = tqdm(loader)
        # Optionally freeze/unfreeze only the parameters whose name contains
        # "PA" depending on the epoch threshold.
        if self.arg.only_train_part:
            if epoch > self.arg.only_train_epoch:
                print("only train part, require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = True
            else:
                print("only train part, do not require grad")
                for key, value in self.model.named_parameters():
                    if "PA" in key:
                        value.requires_grad = False

        # Per-sample accumulators over the whole epoch.
        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        results = []
        indexs = []

        '''
        Switch to FULL supervision
        Dataloader->Feeder -> collate_with_padding_multi_joint
        '''

        for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
            process
        ):

            self.global_step += 1
            # get data
            data = data.float().cuda(self.output_device)
            label = label.cuda(self.output_device)
            target = target.cuda(self.output_device)
            mask = mask.cuda(self.output_device)
            soft_label = soft_label.cuda(self.output_device)
            timer["dataloader"] += self.split_time()

            ''' into one hot'''
            # Flatten per-frame targets and one-hot encode them.
            # NOTE(review): num_classes=5 is hard-coded — confirm it matches
            # the dataset's class count.
            ground_truth_flat = target.view(-1)
            one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
            ''' into one hot'''

            indexs.extend(index.cpu().numpy().tolist())

            # Video-level labels extended with an extra all-ones column.
            ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

            # forward
            frm_scrs = self.model(data)


            if epoch > -1:

                frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
                # frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
                # soft_label = rearrange(soft_label, "n t c -> (n t) c")

                # loss = cls_mil_loss * 0.1 + mvl_loss(
                #     frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
                # )

                # Frame-level cross entropy against the one-hot targets.
                loss = cross_entropy_loss(
                    frm_scrs_re, one_hot_ground_truth
                ) #+ cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)

                for i in range(data.size(0)):
                    frm_scr = frm_scrs[i]

                    label_ = label[i].cpu().numpy()
                    mask_ = mask[i].cpu().numpy()
                    # Unpadded length of this sequence.
                    vid_len = mask_.sum()

                    # Per-frame class probabilities truncated to true length.
                    frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
                    # vid_pred = mil_pred[i].detach().cpu().numpy()

                    vid_pred = 0
                    results.append(frm_pred)

                    vid_preds.append(vid_pred)
                    frm_preds.append(frm_pred)
                    vid_lens.append(vid_len)
                    labels.append(label_)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_value.append(loss.data.item())
            timer["model"] += self.split_time()

        vid_preds = np.array(vid_preds)
        frm_preds = np.array(frm_preds)
        vid_lens = np.array(vid_lens)
        labels = np.array(labels)

        # Feed this epoch's frame predictions back to the dataset
        # (self-training style label refinement implemented by the feeder).
        loader.dataset.label_update(results, indexs)

        # cmap = cmAP(vid_preds, labels)
        # NOTE(review): classification mAP is disabled; 0 is logged instead.
        cmap = 0

        self.train_writer.add_scalar("acc", cmap, self.global_step)
        self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)

        # statistics
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.train_writer.add_scalar("lr", self.lr, self.global_step)
        timer["statistics"] += self.split_time()

        # statistics of time consumption and loss
        self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
        self.print_log("\tAcc score: {:.3f}%".format(cmap))

        # Log
        wb_dict["train loss"] = np.mean(loss_value)
        wb_dict["train acc"] = cmap

        if save_model:
            # Save a CPU copy of the weights, stripping any DataParallel
            # "module." prefix from the keys.
            state_dict = self.model.state_dict()
            weights = OrderedDict(
                [[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
            )

            torch.save(
                weights,
                self.arg.model_saved_name + str(epoch) + ".pt",
            )

        return wb_dict
|
| 579 |
+
|
| 580 |
+
    @torch.no_grad()
    def eval(
        self,
        epoch,
        wb_dict,
        loader_name=["test"],
    ):
        """Evaluate on the given loaders and compute detection mAP.

        NOTE(review): ``loader_name`` uses a mutable default; harmless here
        because it is never mutated, but worth confirming before extending.

        Returns:
            ``wb_dict`` updated with "val loss" and "val acc".
        """
        self.model.eval()
        self.print_log("Eval epoch: {}".format(epoch + 1))

        # Accumulated across all loaders in loader_name.
        vid_preds = []
        frm_preds = []
        vid_lens = []
        labels = []

        for ln in loader_name:
            loss_value = []
            step = 0
            process = tqdm(self.data_loader[ln])

            for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
                process
            ):
                data = data.float().cuda(self.output_device)
                label = label.cuda(self.output_device)
                mask = mask.cuda(self.output_device)

                # Video-level labels extended with an extra all-ones column.
                ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)

                # forward
                frm_scrs = self.model(data)


                '''Loc LOSS'''
                target = target.cuda(self.output_device)
                ''' into one hot'''
                ground_truth_flat = target.view(-1)
                # NOTE(review): num_classes=5 is hard-coded, as in train().
                one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
                ''' into one hot'''
                frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
                '''Loc LOSS'''
                '''Loc LOSS'''
                # Frame-level cross entropy, mirroring the training loss.
                loss = cross_entropy_loss(
                    frm_scrs_re, one_hot_ground_truth
                )
                '''Loc LOSS'''

                loss_value.append(loss.data.item())

                for i in range(data.size(0)):
                    frm_scr = frm_scrs[i]

                    label_ = label[i].cpu().numpy()
                    mask_ = mask[i].cpu().numpy()
                    # Unpadded length of this sequence.
                    vid_len = mask_.sum()

                    # Per-frame probabilities truncated to the unpadded length.
                    frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
                    # vid_pred = vid_pred.cpu().numpy()

                    vid_pred = 0
                    vid_preds.append(vid_pred)
                    frm_preds.append(frm_pred)
                    vid_lens.append(vid_len)
                    labels.append(label_)

                step += 1

            vid_preds = np.array(vid_preds)
            frm_preds = np.array(frm_preds)
            vid_lens = np.array(vid_lens)
            labels = np.array(labels)

            # cmap = cmAP(vid_preds, labels)
            # NOTE(review): classification mAP is disabled; 0 is logged.
            cmap = 0
            score = cmap
            loss = np.mean(loss_value)

            # Detection mAP over a range of IoU thresholds.
            dmap, iou = dsmAP(
                vid_preds,
                frm_preds,
                vid_lens,
                self.arg.test_feeder_args["data_path"],
                self.arg,
                multi=True,
            )

            print("Classification map %f" % cmap)
            for item in list(zip(iou, dmap)):
                print("Detection map @ %f = %f" % (item[0], item[1]))

            self.my_logger.append([epoch + 1, cmap] + dmap+[np.mean(dmap)])

            wb_dict["val loss"] = loss
            wb_dict["val acc"] = score

            if score > self.best_acc:
                self.best_acc = score

            print("Acc score: ", score, " model: ", self.arg.model_saved_name)
            if self.arg.phase == "train":
                self.val_writer.add_scalar("loss", loss, self.global_step)
                self.val_writer.add_scalar("acc", score, self.global_step)

            self.print_log(
                "\tMean {} loss of {} batches: {}.".format(
                    ln, len(self.data_loader[ln]), np.mean(loss_value)
                )
            )
            self.print_log("\tAcc score: {:.3f}%".format(score))

        return wb_dict
|
| 691 |
+
|
| 692 |
+
def start(self):
|
| 693 |
+
wb_dict = {}
|
| 694 |
+
if self.arg.phase == "train":
|
| 695 |
+
self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
|
| 696 |
+
self.global_step = (
|
| 697 |
+
self.arg.start_epoch
|
| 698 |
+
* len(self.data_loader["train"])
|
| 699 |
+
/ self.arg.batch_size
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
|
| 703 |
+
|
| 704 |
+
save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
|
| 705 |
+
epoch + 1 == self.arg.num_epoch
|
| 706 |
+
)
|
| 707 |
+
wb_dict = {"lr": self.lr}
|
| 708 |
+
|
| 709 |
+
# Train
|
| 710 |
+
wb_dict = self.train(epoch, wb_dict, save_model=save_model)
|
| 711 |
+
if epoch%1==0:
|
| 712 |
+
# Eval. on val set
|
| 713 |
+
wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
|
| 714 |
+
# Log stats. for this epoch
|
| 715 |
+
print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))
|
| 716 |
+
|
| 717 |
+
print(
|
| 718 |
+
"best accuracy: ",
|
| 719 |
+
self.best_acc,
|
| 720 |
+
" model_name: ",
|
| 721 |
+
self.arg.model_saved_name,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
elif self.arg.phase == "test":
|
| 725 |
+
if not self.arg.test_feeder_args["debug"]:
|
| 726 |
+
wf = self.arg.model_saved_name + "_wrong.txt"
|
| 727 |
+
rf = self.arg.model_saved_name + "_right.txt"
|
| 728 |
+
else:
|
| 729 |
+
wf = rf = None
|
| 730 |
+
if self.arg.weights is None:
|
| 731 |
+
raise ValueError("Please appoint --weights.")
|
| 732 |
+
self.arg.print_log = False
|
| 733 |
+
self.print_log("Model: {}.".format(self.arg.model))
|
| 734 |
+
self.print_log("Weights: {}.".format(self.arg.weights))
|
| 735 |
+
|
| 736 |
+
wb_dict = self.eval(
|
| 737 |
+
epoch=0,
|
| 738 |
+
wb_dict=wb_dict,
|
| 739 |
+
loader_name=["test"],
|
| 740 |
+
wrong_file=wf,
|
| 741 |
+
result_file=rf,
|
| 742 |
+
)
|
| 743 |
+
print("Inference metrics: ", wb_dict)
|
| 744 |
+
self.print_log("Done.\n")
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def str2bool(v):
    """argparse type converter: map yes/no-style strings to bool.

    Raises:
        argparse.ArgumentTypeError: for any unrecognized value.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def import_class(name):
    """Import and return the object at dotted path ``name`` (e.g. "pkg.mod.Cls")."""
    parts = name.split(".")
    obj = __import__(parts[0])
    for attr in parts[1:]:
        obj = getattr(obj, attr)
    return obj
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
if __name__ == "__main__":
    parser = get_parser()

    # Load args from the config file: every YAML key must correspond to a
    # known CLI argument; YAML values become new defaults, so explicit CLI
    # flags still win on the second parse below.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        key = vars(p).keys()
        for k in default_arg.keys():
            if k not in key:
                print("WRONG ARG: {}".format(k))
            assert k in key
        parser.set_defaults(**default_arg)

    # Re-parse with config-derived defaults applied.
    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
|
train_full_SSL_Unet.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright 2023 LINE Corporation
|
| 3 |
+
LINE Corporation licenses this file to you under the Apache License,
|
| 4 |
+
version 2.0 (the "License"); you may not use this file except in compliance
|
| 5 |
+
with the License. You may obtain a copy of the License at:
|
| 6 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
| 9 |
+
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
| 10 |
+
License for the specific language governing permissions and limitations
|
| 11 |
+
under the License.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import print_function
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
import pdb
|
| 20 |
+
import pickle
|
| 21 |
+
import random
|
| 22 |
+
import re
|
| 23 |
+
import shutil
|
| 24 |
+
import time
|
| 25 |
+
from collections import *
|
| 26 |
+
|
| 27 |
+
import ipdb
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# torch
|
| 31 |
+
import torch
|
| 32 |
+
import torch.backends.cudnn as cudnn
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
import torch.optim as optim
|
| 36 |
+
import yaml
|
| 37 |
+
from einops import rearrange, reduce, repeat
|
| 38 |
+
from evaluation.classificationMAP import getClassificationMAP as cmAP
|
| 39 |
+
from evaluation.detectionMAP import getSingleStreamDetectionMAP as dsmAP
|
| 40 |
+
from feeders.tools import collate_with_padding_multi_joint
|
| 41 |
+
from model.losses import cross_entropy_loss, mvl_loss
|
| 42 |
+
from sklearn.metrics import f1_score
|
| 43 |
+
|
| 44 |
+
# Custom
|
| 45 |
+
from tensorboardX import SummaryWriter
|
| 46 |
+
from torch.autograd import Variable
|
| 47 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 48 |
+
from tqdm import tqdm
|
| 49 |
+
from utils.logger import Logger
|
| 50 |
+
|
| 51 |
+
def remove_prefix_from_state_dict(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from matching keys.

    Keys that do not start with ``prefix`` are kept unchanged; insertion order
    and values are preserved.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def init_seed(seed):
    """Seed every RNG in play (random, numpy, torch CPU/GPU) and pin cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (benchmark) ones so
    # runs are reproducible.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
|
| 70 |
+
|
| 71 |
+
def get_parser():
    """Build the CLI argument parser.

    Parameter priority is: command line > config file > defaults (the config
    file is folded in by __main__ via ``parser.set_defaults``).
    """
    parser = argparse.ArgumentParser(
        description="Spatial Temporal Graph Convolution Network"
    )
    parser.add_argument(
        "--work-dir",
        default="./work_dir/temp",
        help="the work folder for storing results",
    )

    parser.add_argument("-model_saved_name", default="")
    parser.add_argument(
        "--config",
        default="./config/nturgbd-cross-view/test_bone.yaml",
        help="path to the configuration file",
    )

    # processor
    parser.add_argument("--phase", default="train", help="must be train or test")

    # visualize and debug
    parser.add_argument("--seed", type=int, default=5, help="random seed for pytorch")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        help="the interval for printing messages (#iteration)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=2,
        help="the interval for storing models (#iteration)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="the interval for evaluating models (#iteration)",
    )
    parser.add_argument(
        "--print-log", type=str2bool, default=True, help="print logging or not"
    )
    parser.add_argument(
        "--show-topk",
        type=int,
        default=[1, 5],
        nargs="+",
        help="which Top K accuracy will be shown",
    )

    # feeder
    parser.add_argument(
        "--feeder", default="feeder.feeder", help="data loader will be used"
    )
    parser.add_argument(
        "--num-worker",
        type=int,
        default=32,
        help="the number of worker for data loader",
    )
    parser.add_argument(
        "--train-feeder-args",
        default=dict(),
        help="the arguments of data loader for training",
    )
    parser.add_argument(
        "--test-feeder-args",
        default=dict(),
        help="the arguments of data loader for test",
    )

    # model
    parser.add_argument("--model", default=None, help="the model will be used")
    parser.add_argument(
        "--model-args", type=dict, default=dict(), help="the arguments of model"
    )
    parser.add_argument(
        "--weights", default=None, help="the weights for network initialization"
    )
    parser.add_argument(
        "--ignore-weights",
        type=str,
        default=[],
        nargs="+",
        help="the name of weights which will be ignored in the initialization",
    )

    # optim
    parser.add_argument(
        "--base-lr", type=float, default=0.01, help="initial learning rate"
    )
    parser.add_argument(
        "--step",
        type=int,
        default=[200],
        nargs="+",
        help="the epoch where optimizer reduce the learning rate",
    )

    # training
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        nargs="+",
        help="the indexes of GPUs for training or testing",
    )
    parser.add_argument("--optimizer", default="SGD", help="type of optimizer")
    parser.add_argument(
        "--nesterov", type=str2bool, default=False, help="use nesterov or not"
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=256, help="test batch size"
    )
    parser.add_argument(
        "--start-epoch", type=int, default=0, help="start training from which epoch"
    )
    parser.add_argument(
        "--num-epoch", type=int, default=80, help="stop training in which epoch"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.0005, help="weight decay for optimizer"
    )
    # loss
    parser.add_argument("--loss", type=str, default="CE", help="loss type(CE or focal)")
    parser.add_argument(
        "--label_count_path",
        default=None,
        type=str,
        help="Path to label counts (used in loss weighting)",
    )
    # BUGFIX: was "---beta" (three dashes). argparse derived dest "_beta" from
    # it, so a YAML config key "beta" failed the __main__ key check and
    # ``arg.beta`` never existed.
    parser.add_argument(
        "--beta",
        type=float,
        default=0.9999,
        help="Hyperparameter for Class balanced loss",
    )
    parser.add_argument(
        "--gamma", type=float, default=2.0, help="Hyperparameter for Focal loss"
    )

    parser.add_argument("--only_train_part", default=False)
    parser.add_argument("--only_train_epoch", default=0)
    parser.add_argument("--warm_up_epoch", default=10)

    parser.add_argument(
        "--lambda-mil", default=1.0, help="balancing hyper-parameter of mil branch"
    )

    parser.add_argument(
        "--class-threshold",
        type=float,
        default=0.1,
        help="class threshold for rejection",
    )
    parser.add_argument(
        "--start-threshold",
        type=float,
        default=0.03,
        help="start threshold for action localization",
    )
    parser.add_argument(
        "--end-threshold",
        type=float,
        default=0.055,
        help="end threshold for action localization",
    )
    parser.add_argument(
        "--threshold-interval",
        type=float,
        default=0.005,
        help="threshold interval for action localization",
    )
    return parser
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class Processor:
|
| 253 |
+
"""
|
| 254 |
+
Processor for Skeleton-based Action Recgnition
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
    def __init__(self, arg):
        """Set up the experiment: log dirs/writers, model, optimizer, data, logger.

        Args:
            arg: parsed argparse namespace (see ``get_parser``).
        """
        self.arg = arg
        self.save_arg()
        if arg.phase == "train":
            if not arg.train_feeder_args["debug"]:
                if os.path.isdir(arg.model_saved_name):
                    print("log_dir: ", arg.model_saved_name, "already exist")
                    # answer = input('delete it? y/n:')
                    # NOTE(review): a previous run's log dir is deleted
                    # unconditionally — the interactive prompt is bypassed.
                    answer = "y"
                    if answer == "y":
                        print("Deleting dir...")
                        shutil.rmtree(arg.model_saved_name)
                        print("Dir removed: ", arg.model_saved_name)
                        # input('Refresh the website of tensorboard by pressing any keys')
                    else:
                        print("Dir not removed: ", arg.model_saved_name)
                self.train_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "train"), "train"
                )
                self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "val"), "val"
                )
            else:
                # Debug mode: a single writer serves both train and val curves.
                self.train_writer = self.val_writer = SummaryWriter(
                    os.path.join(arg.model_saved_name, "test"), "test"
                )
        self.global_step = 0
        self.load_model()
        self.load_optimizer()
        self.load_data()
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_per_class_acc = 0
        self.loss_nce = torch.nn.BCELoss()

        self.my_logger = Logger(
            os.path.join(arg.model_saved_name, "log.txt"), title="SWTAL"
        )
        # Columns: step, classification mAP, detection mAP@0.1..0.5, average.
        self.my_logger.set_names(["Step", "cmap"] + [f"map_0.{i}" for i in range(1, 6)] + ['avg'])
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def load_data(self):
|
| 301 |
+
|
| 302 |
+
seed = self.arg.seed if hasattr(self.arg, "seed") else 42
|
| 303 |
+
|
| 304 |
+
def seed_worker(worker_id):
|
| 305 |
+
worker_seed = seed + worker_id
|
| 306 |
+
np.random.seed(worker_seed)
|
| 307 |
+
random.seed(worker_seed)
|
| 308 |
+
|
| 309 |
+
g = torch.Generator()
|
| 310 |
+
g.manual_seed(seed)
|
| 311 |
+
|
| 312 |
+
Feeder = import_class(self.arg.feeder)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
self.data_loader = dict()
|
| 316 |
+
if self.arg.phase == "train":
|
| 317 |
+
self.data_loader["train"] = torch.utils.data.DataLoader(
|
| 318 |
+
dataset=Feeder(**self.arg.train_feeder_args),
|
| 319 |
+
batch_size=self.arg.batch_size,
|
| 320 |
+
shuffle=True,
|
| 321 |
+
num_workers=self.arg.num_worker,
|
| 322 |
+
drop_last=True,
|
| 323 |
+
collate_fn=collate_with_padding_multi_joint,
|
| 324 |
+
worker_init_fn=seed_worker, # ✅ 固定每个worker的seed
|
| 325 |
+
generator=g
|
| 326 |
+
)
|
| 327 |
+
self.data_loader["test"] = torch.utils.data.DataLoader(
|
| 328 |
+
dataset=Feeder(**self.arg.test_feeder_args),
|
| 329 |
+
batch_size=self.arg.test_batch_size,
|
| 330 |
+
shuffle=False,
|
| 331 |
+
num_workers=self.arg.num_worker,
|
| 332 |
+
drop_last=False,
|
| 333 |
+
collate_fn=collate_with_padding_multi_joint,
|
| 334 |
+
worker_init_fn=seed_worker, # ✅ 固定每个worker的seed
|
| 335 |
+
generator=g
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
def load_model(self):
|
| 339 |
+
output_device = (
|
| 340 |
+
self.arg.device[0] if type(self.arg.device) is list else self.arg.device
|
| 341 |
+
)
|
| 342 |
+
self.output_device = output_device
|
| 343 |
+
Model = import_class(self.arg.model)
|
| 344 |
+
shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
|
| 345 |
+
# print(Model)
|
| 346 |
+
self.model = Model(**self.arg.model_args).cuda(output_device)
|
| 347 |
+
# print(self.model)
|
| 348 |
+
self.loss_type = arg.loss
|
| 349 |
+
|
| 350 |
+
if self.arg.weights:
|
| 351 |
+
# if False:
|
| 352 |
+
# self.global_step = int(arg.weights[:-3].split("-")[-1])
|
| 353 |
+
self.print_log("Load weights from {}.".format(self.arg.weights))
|
| 354 |
+
if ".pkl" in self.arg.weights:
|
| 355 |
+
with open(self.arg.weights, "r") as f:
|
| 356 |
+
weights = pickle.load(f)
|
| 357 |
+
else:
|
| 358 |
+
weights = torch.load(self.arg.weights)
|
| 359 |
+
|
| 360 |
+
weights = OrderedDict(
|
| 361 |
+
[
|
| 362 |
+
[k.split("module.")[-1], v.cuda(output_device)]
|
| 363 |
+
for k, v in weights.items()
|
| 364 |
+
]
|
| 365 |
+
)
|
| 366 |
+
weights = remove_prefix_from_state_dict(weights, 'encoder_q.agcn.')
|
| 367 |
+
keys = list(weights.keys())
|
| 368 |
+
self.arg.ignore_weights = ['data_bn','fc','encoder_q','encoder_k','queue','queue_ptr','value_transform']
|
| 369 |
+
for w in self.arg.ignore_weights:
|
| 370 |
+
for key in keys:
|
| 371 |
+
if w in key:
|
| 372 |
+
if weights.pop(key, None) is not None:
|
| 373 |
+
self.print_log(
|
| 374 |
+
"Sucessfully Remove Weights: {}.".format(key)
|
| 375 |
+
)
|
| 376 |
+
else:
|
| 377 |
+
self.print_log("Can Not Remove Weights: {}.".format(key))
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
self.model.load_state_dict(weights)
|
| 381 |
+
except:
|
| 382 |
+
state = self.model.state_dict()
|
| 383 |
+
diff = list(set(state.keys()).difference(set(weights.keys())))
|
| 384 |
+
print("Can not find these weights:")
|
| 385 |
+
for d in diff:
|
| 386 |
+
print(" " + d)
|
| 387 |
+
state.update(weights)
|
| 388 |
+
self.model.load_state_dict(state)
|
| 389 |
+
|
| 390 |
+
if type(self.arg.device) is list:
|
| 391 |
+
if len(self.arg.device) > 1:
|
| 392 |
+
self.model = nn.DataParallel(
|
| 393 |
+
self.model, device_ids=self.arg.device, output_device=output_device
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
def load_optimizer(self):
|
| 397 |
+
if self.arg.optimizer == "SGD":
|
| 398 |
+
self.optimizer = optim.SGD(
|
| 399 |
+
self.model.parameters(),
|
| 400 |
+
lr=self.arg.base_lr,
|
| 401 |
+
momentum=0.9,
|
| 402 |
+
nesterov=self.arg.nesterov,
|
| 403 |
+
weight_decay=self.arg.weight_decay,
|
| 404 |
+
)
|
| 405 |
+
elif self.arg.optimizer == "Adam":
|
| 406 |
+
self.optimizer = optim.Adam(
|
| 407 |
+
self.model.parameters(),
|
| 408 |
+
lr=self.arg.base_lr,
|
| 409 |
+
weight_decay=self.arg.weight_decay,
|
| 410 |
+
)
|
| 411 |
+
else:
|
| 412 |
+
raise ValueError()
|
| 413 |
+
|
| 414 |
+
def save_arg(self):
|
| 415 |
+
# save arg
|
| 416 |
+
arg_dict = vars(self.arg)
|
| 417 |
+
if not os.path.exists(self.arg.work_dir):
|
| 418 |
+
os.makedirs(self.arg.work_dir)
|
| 419 |
+
with open("{}/config.yaml".format(self.arg.work_dir), "w") as f:
|
| 420 |
+
yaml.dump(arg_dict, f)
|
| 421 |
+
|
| 422 |
+
def adjust_learning_rate(self, epoch):
|
| 423 |
+
if self.arg.optimizer == "SGD" or self.arg.optimizer == "Adam":
|
| 424 |
+
if epoch < self.arg.warm_up_epoch:
|
| 425 |
+
lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
|
| 426 |
+
else:
|
| 427 |
+
lr = self.arg.base_lr * (
|
| 428 |
+
0.1 ** np.sum(epoch >= np.array(self.arg.step))
|
| 429 |
+
)
|
| 430 |
+
for param_group in self.optimizer.param_groups:
|
| 431 |
+
param_group["lr"] = lr
|
| 432 |
+
|
| 433 |
+
return lr
|
| 434 |
+
else:
|
| 435 |
+
raise ValueError()
|
| 436 |
+
|
| 437 |
+
def print_time(self):
|
| 438 |
+
localtime = time.asctime(time.localtime(time.time()))
|
| 439 |
+
self.print_log("Local current time : " + localtime)
|
| 440 |
+
|
| 441 |
+
def print_log(self, str, print_time=True):
|
| 442 |
+
if print_time:
|
| 443 |
+
localtime = time.asctime(time.localtime(time.time()))
|
| 444 |
+
str = "[ " + localtime + " ] " + str
|
| 445 |
+
print(str)
|
| 446 |
+
if self.arg.print_log:
|
| 447 |
+
with open("{}/print_log.txt".format(self.arg.work_dir), "a") as f:
|
| 448 |
+
print(str, file=f)
|
| 449 |
+
|
| 450 |
+
def record_time(self):
|
| 451 |
+
self.cur_time = time.time()
|
| 452 |
+
return self.cur_time
|
| 453 |
+
|
| 454 |
+
def split_time(self):
|
| 455 |
+
split_time = time.time() - self.cur_time
|
| 456 |
+
self.record_time()
|
| 457 |
+
return split_time
|
| 458 |
+
|
| 459 |
+
def train(self, epoch, wb_dict, save_model=False):
|
| 460 |
+
self.model.train()
|
| 461 |
+
self.print_log("Training epoch: {}".format(epoch + 1))
|
| 462 |
+
loader = self.data_loader["train"]
|
| 463 |
+
self.adjust_learning_rate(epoch)
|
| 464 |
+
|
| 465 |
+
loss_value, batch_acc = [], []
|
| 466 |
+
self.train_writer.add_scalar("epoch", epoch, self.global_step)
|
| 467 |
+
self.record_time()
|
| 468 |
+
timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
|
| 469 |
+
process = tqdm(loader)
|
| 470 |
+
if self.arg.only_train_part:
|
| 471 |
+
if epoch > self.arg.only_train_epoch:
|
| 472 |
+
print("only train part, require grad")
|
| 473 |
+
for key, value in self.model.named_parameters():
|
| 474 |
+
if "PA" in key:
|
| 475 |
+
value.requires_grad = True
|
| 476 |
+
else:
|
| 477 |
+
print("only train part, do not require grad")
|
| 478 |
+
for key, value in self.model.named_parameters():
|
| 479 |
+
if "PA" in key:
|
| 480 |
+
value.requires_grad = False
|
| 481 |
+
|
| 482 |
+
vid_preds = []
|
| 483 |
+
frm_preds = []
|
| 484 |
+
vid_lens = []
|
| 485 |
+
labels = []
|
| 486 |
+
|
| 487 |
+
results = []
|
| 488 |
+
indexs = []
|
| 489 |
+
|
| 490 |
+
'''
|
| 491 |
+
Switch to FULL supervision
|
| 492 |
+
Dataloader->Feeder -> collate_with_padding_multi_joint
|
| 493 |
+
'''
|
| 494 |
+
|
| 495 |
+
for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
|
| 496 |
+
process
|
| 497 |
+
):
|
| 498 |
+
|
| 499 |
+
self.global_step += 1
|
| 500 |
+
# get data
|
| 501 |
+
data = data.float().cuda(self.output_device)
|
| 502 |
+
label = label.cuda(self.output_device)
|
| 503 |
+
target = target.cuda(self.output_device)
|
| 504 |
+
mask = mask.cuda(self.output_device)
|
| 505 |
+
soft_label = soft_label.cuda(self.output_device)
|
| 506 |
+
timer["dataloader"] += self.split_time()
|
| 507 |
+
|
| 508 |
+
''' into one hot'''
|
| 509 |
+
ground_truth_flat = target.view(-1)
|
| 510 |
+
one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
|
| 511 |
+
''' into one hot'''
|
| 512 |
+
|
| 513 |
+
indexs.extend(index.cpu().numpy().tolist())
|
| 514 |
+
|
| 515 |
+
ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)
|
| 516 |
+
|
| 517 |
+
# forward
|
| 518 |
+
# print(data.shape)
|
| 519 |
+
mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)
|
| 520 |
+
|
| 521 |
+
cls_mil_loss = self.loss_nce(mil_pred.float(), ab_labels.float()) + self.loss_nce(
|
| 522 |
+
mil_pred_2.float(), ab_labels.float()
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if epoch > -1:
|
| 526 |
+
|
| 527 |
+
frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
|
| 528 |
+
frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
|
| 529 |
+
soft_label = rearrange(soft_label, "n t c -> (n t) c")
|
| 530 |
+
|
| 531 |
+
loss = cls_mil_loss * 0.1 + mvl_loss(
|
| 532 |
+
frm_scrs, frm_scrs_2, rate=0.2, weight=0.5
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
loss += cross_entropy_loss(
|
| 536 |
+
frm_scrs_re, one_hot_ground_truth
|
| 537 |
+
) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)
|
| 538 |
+
|
| 539 |
+
for i in range(data.size(0)):
|
| 540 |
+
frm_scr = frm_scrs[i]
|
| 541 |
+
|
| 542 |
+
label_ = label[i].cpu().numpy()
|
| 543 |
+
mask_ = mask[i].cpu().numpy()
|
| 544 |
+
vid_len = mask_.sum()
|
| 545 |
+
|
| 546 |
+
frm_pred = F.softmax(frm_scr, -1).detach().cpu().numpy()[:vid_len]
|
| 547 |
+
vid_pred = mil_pred[i].detach().cpu().numpy()
|
| 548 |
+
|
| 549 |
+
results.append(frm_pred)
|
| 550 |
+
|
| 551 |
+
vid_preds.append(vid_pred)
|
| 552 |
+
frm_preds.append(frm_pred)
|
| 553 |
+
vid_lens.append(vid_len)
|
| 554 |
+
labels.append(label_)
|
| 555 |
+
|
| 556 |
+
# backward
|
| 557 |
+
self.optimizer.zero_grad()
|
| 558 |
+
loss.backward()
|
| 559 |
+
self.optimizer.step()
|
| 560 |
+
|
| 561 |
+
loss_value.append(loss.data.item())
|
| 562 |
+
timer["model"] += self.split_time()
|
| 563 |
+
|
| 564 |
+
vid_preds = np.array(vid_preds)
|
| 565 |
+
frm_preds = np.array(frm_preds)
|
| 566 |
+
vid_lens = np.array(vid_lens)
|
| 567 |
+
labels = np.array(labels)
|
| 568 |
+
|
| 569 |
+
loader.dataset.label_update(results, indexs)
|
| 570 |
+
|
| 571 |
+
cmap = cmAP(vid_preds, labels)
|
| 572 |
+
|
| 573 |
+
self.train_writer.add_scalar("acc", cmap, self.global_step)
|
| 574 |
+
self.train_writer.add_scalar("loss", np.mean(loss_value), self.global_step)
|
| 575 |
+
|
| 576 |
+
# statistics
|
| 577 |
+
self.lr = self.optimizer.param_groups[0]["lr"]
|
| 578 |
+
self.train_writer.add_scalar("lr", self.lr, self.global_step)
|
| 579 |
+
timer["statistics"] += self.split_time()
|
| 580 |
+
|
| 581 |
+
# statistics of time consumption and loss
|
| 582 |
+
self.print_log("\tMean training loss: {:.4f}.".format(np.mean(loss_value)))
|
| 583 |
+
self.print_log("\tAcc score: {:.3f}%".format(cmap))
|
| 584 |
+
|
| 585 |
+
# Log
|
| 586 |
+
wb_dict["train loss"] = np.mean(loss_value)
|
| 587 |
+
wb_dict["train acc"] = cmap
|
| 588 |
+
|
| 589 |
+
if save_model:
|
| 590 |
+
state_dict = self.model.state_dict()
|
| 591 |
+
weights = OrderedDict(
|
| 592 |
+
[[k.split("module.")[-1], v.cpu()] for k, v in state_dict.items()]
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
torch.save(
|
| 596 |
+
weights,
|
| 597 |
+
self.arg.model_saved_name + str(epoch) + ".pt",
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
return wb_dict
|
| 601 |
+
|
| 602 |
+
@torch.no_grad()
|
| 603 |
+
def eval(
|
| 604 |
+
self,
|
| 605 |
+
epoch,
|
| 606 |
+
wb_dict,
|
| 607 |
+
loader_name=["test"],
|
| 608 |
+
):
|
| 609 |
+
self.model.eval()
|
| 610 |
+
self.print_log("Eval epoch: {}".format(epoch + 1))
|
| 611 |
+
|
| 612 |
+
vid_preds = []
|
| 613 |
+
frm_preds = []
|
| 614 |
+
vid_lens = []
|
| 615 |
+
labels = []
|
| 616 |
+
|
| 617 |
+
for ln in loader_name:
|
| 618 |
+
loss_value = []
|
| 619 |
+
step = 0
|
| 620 |
+
process = tqdm(self.data_loader[ln])
|
| 621 |
+
|
| 622 |
+
for batch_idx, (data, label, target, mask, index, soft_label) in enumerate(
|
| 623 |
+
process
|
| 624 |
+
):
|
| 625 |
+
data = data.float().cuda(self.output_device)
|
| 626 |
+
label = label.cuda(self.output_device)
|
| 627 |
+
mask = mask.cuda(self.output_device)
|
| 628 |
+
|
| 629 |
+
ab_labels = torch.cat([label, torch.ones(label.size(0), 1).cuda()], -1)
|
| 630 |
+
# print(data.shape)
|
| 631 |
+
# forward
|
| 632 |
+
mil_pred, frm_scrs, mil_pred_2, frm_scrs_2 = self.model(data,mask)
|
| 633 |
+
|
| 634 |
+
# cls_mil_loss = self.loss_nce(
|
| 635 |
+
# mil_pred, ab_labels.float()
|
| 636 |
+
# ) + self.loss_nce(mil_pred_2, ab_labels.float())
|
| 637 |
+
|
| 638 |
+
# loss_co = mvl_loss(frm_scrs, frm_scrs_2, rate=0.2, weight=0.5)
|
| 639 |
+
|
| 640 |
+
# loss = cls_mil_loss * self.arg.lambda_mil + loss_co
|
| 641 |
+
'''Loc LOSS'''
|
| 642 |
+
target = target.cuda(self.output_device)
|
| 643 |
+
''' into one hot'''
|
| 644 |
+
ground_truth_flat = target.view(-1)
|
| 645 |
+
one_hot_ground_truth = F.one_hot(ground_truth_flat, num_classes=5)
|
| 646 |
+
''' into one hot'''
|
| 647 |
+
frm_scrs_re = rearrange(frm_scrs, "n t c -> (n t) c")
|
| 648 |
+
frm_scrs_2_re = rearrange(frm_scrs_2, "n t c -> (n t) c")
|
| 649 |
+
'''Loc LOSS'''
|
| 650 |
+
'''Loc LOSS'''
|
| 651 |
+
loss = cross_entropy_loss(
|
| 652 |
+
frm_scrs_re, one_hot_ground_truth
|
| 653 |
+
) + cross_entropy_loss(frm_scrs_2_re, one_hot_ground_truth)
|
| 654 |
+
'''Loc LOSS'''
|
| 655 |
+
|
| 656 |
+
loss_value.append(loss.data.item())
|
| 657 |
+
|
| 658 |
+
for i in range(data.size(0)):
|
| 659 |
+
frm_scr = frm_scrs[i]
|
| 660 |
+
vid_pred = mil_pred[i]
|
| 661 |
+
|
| 662 |
+
label_ = label[i].cpu().numpy()
|
| 663 |
+
mask_ = mask[i].cpu().numpy()
|
| 664 |
+
vid_len = mask_.sum()
|
| 665 |
+
|
| 666 |
+
frm_pred = F.softmax(frm_scr, -1).cpu().numpy()[:vid_len]
|
| 667 |
+
vid_pred = vid_pred.cpu().numpy()
|
| 668 |
+
|
| 669 |
+
vid_preds.append(vid_pred)
|
| 670 |
+
frm_preds.append(frm_pred)
|
| 671 |
+
vid_lens.append(vid_len)
|
| 672 |
+
labels.append(label_)
|
| 673 |
+
|
| 674 |
+
step += 1
|
| 675 |
+
|
| 676 |
+
vid_preds = np.array(vid_preds)
|
| 677 |
+
frm_preds = np.array(frm_preds)
|
| 678 |
+
vid_lens = np.array(vid_lens)
|
| 679 |
+
labels = np.array(labels)
|
| 680 |
+
|
| 681 |
+
cmap = cmAP(vid_preds, labels)
|
| 682 |
+
|
| 683 |
+
score = cmap
|
| 684 |
+
loss = np.mean(loss_value)
|
| 685 |
+
|
| 686 |
+
dmap, iou = dsmAP(
|
| 687 |
+
vid_preds,
|
| 688 |
+
frm_preds,
|
| 689 |
+
vid_lens,
|
| 690 |
+
self.arg.test_feeder_args["data_path"],
|
| 691 |
+
self.arg,
|
| 692 |
+
multi=True,
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
print("Classification map %f" % cmap)
|
| 696 |
+
for item in list(zip(iou, dmap)):
|
| 697 |
+
print("Detection map @ %f = %f" % (item[0], item[1]))
|
| 698 |
+
|
| 699 |
+
self.my_logger.append([epoch + 1, cmap] + dmap+ [np.mean(dmap)])
|
| 700 |
+
|
| 701 |
+
wb_dict["val loss"] = loss
|
| 702 |
+
wb_dict["val acc"] = score
|
| 703 |
+
|
| 704 |
+
if score > self.best_acc:
|
| 705 |
+
self.best_acc = score
|
| 706 |
+
|
| 707 |
+
print("Acc score: ", score, " model: ", self.arg.model_saved_name)
|
| 708 |
+
if self.arg.phase == "train":
|
| 709 |
+
self.val_writer.add_scalar("loss", loss, self.global_step)
|
| 710 |
+
self.val_writer.add_scalar("acc", score, self.global_step)
|
| 711 |
+
|
| 712 |
+
self.print_log(
|
| 713 |
+
"\tMean {} loss of {} batches: {}.".format(
|
| 714 |
+
ln, len(self.data_loader[ln]), np.mean(loss_value)
|
| 715 |
+
)
|
| 716 |
+
)
|
| 717 |
+
self.print_log("\tAcc score: {:.3f}%".format(score))
|
| 718 |
+
|
| 719 |
+
return wb_dict
|
| 720 |
+
|
| 721 |
+
def start(self):
|
| 722 |
+
wb_dict = {}
|
| 723 |
+
if self.arg.phase == "train":
|
| 724 |
+
self.print_log("Parameters:\n{}\n".format(str(vars(self.arg))))
|
| 725 |
+
self.global_step = (
|
| 726 |
+
self.arg.start_epoch
|
| 727 |
+
* len(self.data_loader["train"])
|
| 728 |
+
/ self.arg.batch_size
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
|
| 732 |
+
|
| 733 |
+
save_model = ((epoch + 1) % self.arg.save_interval == 0) or (
|
| 734 |
+
epoch + 1 == self.arg.num_epoch
|
| 735 |
+
)
|
| 736 |
+
wb_dict = {"lr": self.lr}
|
| 737 |
+
|
| 738 |
+
# Train
|
| 739 |
+
wb_dict = self.train(epoch, wb_dict, save_model=save_model)
|
| 740 |
+
if epoch%5==0:
|
| 741 |
+
# Eval. on val set
|
| 742 |
+
wb_dict = self.eval(epoch, wb_dict, loader_name=["test"])
|
| 743 |
+
# Log stats. for this epoch
|
| 744 |
+
print("Epoch: {0}\nMetrics: {1}".format(epoch, wb_dict))
|
| 745 |
+
|
| 746 |
+
print(
|
| 747 |
+
"best accuracy: ",
|
| 748 |
+
self.best_acc,
|
| 749 |
+
" model_name: ",
|
| 750 |
+
self.arg.model_saved_name,
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
elif self.arg.phase == "test":
|
| 754 |
+
if not self.arg.test_feeder_args["debug"]:
|
| 755 |
+
wf = self.arg.model_saved_name + "_wrong.txt"
|
| 756 |
+
rf = self.arg.model_saved_name + "_right.txt"
|
| 757 |
+
else:
|
| 758 |
+
wf = rf = None
|
| 759 |
+
if self.arg.weights is None:
|
| 760 |
+
raise ValueError("Please appoint --weights.")
|
| 761 |
+
self.arg.print_log = False
|
| 762 |
+
self.print_log("Model: {}.".format(self.arg.model))
|
| 763 |
+
self.print_log("Weights: {}.".format(self.arg.weights))
|
| 764 |
+
|
| 765 |
+
wb_dict = self.eval(
|
| 766 |
+
epoch=0,
|
| 767 |
+
wb_dict=wb_dict,
|
| 768 |
+
loader_name=["test"],
|
| 769 |
+
wrong_file=wf,
|
| 770 |
+
result_file=rf,
|
| 771 |
+
)
|
| 772 |
+
print("Inference metrics: ", wb_dict)
|
| 773 |
+
self.print_log("Done.\n")
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def str2bool(v):
    """Parse a human-friendly boolean flag value for argparse.

    Raises:
        argparse.ArgumentTypeError: if *v* is not a recognised boolean word.
    """
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def import_class(name):
    """Resolve a dotted path such as ``feeders.feeder.Feeder`` to the object."""
    module_name, *attrs = name.split(".")
    obj = __import__(module_name)
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
if __name__ == "__main__":
    parser = get_parser()

    # First parse: pick up --config, then overlay the YAML file's values as
    # argparse defaults so the command line still wins.
    p = parser.parse_args()
    if p.config is not None:
        with open(p.config, "r") as f:
            default_arg = yaml.safe_load(f)
        known = vars(p).keys()
        for k in default_arg.keys():
            if k not in known:
                print("WRONG ARG: {}".format(k))
            assert k in known
        parser.set_defaults(**default_arg)

    arg = parser.parse_args()
    print("BABEL Action Recognition")
    print("Config: ", arg)
    init_seed(arg.seed)
    processor = Processor(arg)
    processor.start()
utils/__init__.py
ADDED
|
File without changes
|
utils/logger.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A simple torch style logger
|
| 2 |
+
# (C) Wei YANG 2017
|
| 3 |
+
from __future__ import absolute_import
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
__all__ = ["Logger", "LoggerMonitor", "savefig"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to *fname* (default 150 dpi)."""
    # `is None` rather than `== None`: identity test per PEP 8.
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
    """Plot each named series of *logger* onto the current axes.

    Returns:
        Legend labels of the form "<title>(<name>)", one per plotted series.
    """
    # `is None` rather than `== None`: identity test per PEP 8.
    names = logger.names if names is None else names
    numbers = logger.numbers
    for _, name in enumerate(names):
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + "(" + name + ")" for name in names]
class Logger(object):
    """Save the training process to a tab-separated log file, with a simple
    plot helper.

    Fixes vs. the original:
    - ``is None`` instead of ``== None`` comparisons,
    - the dead ``if self.resume: pass`` branch in ``set_names`` is removed,
    - resumed rows are parsed back to ``float`` so that ``plot`` works on a
      resumed history (the original stored raw strings).
    """

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        if fpath is not None:
            if resume:
                # Re-read the existing header and data rows, then append.
                self.file = open(fpath, "r")
                name = self.file.readline()
                self.names = name.rstrip().split("\t")
                self.numbers = {}
                for name in self.names:
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split("\t")
                    for i in range(0, len(numbers)):
                        # Parse to float so resumed series are numeric.
                        self.numbers[self.names[i]].append(float(numbers[i]))
                self.file.close()
                self.file = open(fpath, "a")
            else:
                self.file = open(fpath, "w")

    def set_names(self, names):
        """Write the tab-separated header row and initialize one series per name."""
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write("\t")
            self.numbers[name] = []
        self.file.write("\n")
        self.file.flush()

    def append(self, numbers):
        """Append one row of values (must match the header length)."""
        assert len(self.names) == len(numbers), "Numbers do not match names"
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write("\t")
            self.numbers[self.names[index]].append(num)
        self.file.write("\n")
        self.file.flush()

    def plot(self, names=None):
        """Plot the selected (default: all) series with a legend and grid."""
        names = self.names if names is None else names
        numbers = self.numbers
        for _, name in enumerate(names):
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file handle, if any."""
        if self.file is not None:
            self.file.close()
class LoggerMonitor(object):
    """Load and visualize multiple saved logs on a shared figure."""

    def __init__(self, paths):
        """*paths* is a dictionary of {title: filepath} pairs."""
        self.loggers = [
            Logger(path, title=title, resume=True)
            for title, path in paths.items()
        ]

    def plot(self, names=None):
        """Overlay the selected series from every loaded logger."""
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        plt.grid(True)
if __name__ == "__main__":
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: monitor several saved logs and plot one shared field.
    paths = {
        "resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
        "resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
        "resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
    }
    field = ["Valid Acc."]

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig("test.eps")