Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +144 -0
- LICENSE +202 -0
- README.md +106 -0
- autoenc-32-0.01.pt +3 -0
- config/genforecast-radaronly-128x128-20step.yaml +1 -0
- config/genforecast-radaronly-256x256-20step.yaml +5 -0
- environment/environment.yml +4 -0
- environment/ldcast.yml +170 -0
- genforecast-radaronly-256x256-20step.pt +3 -0
- ldcast/analysis/confmatrix.py +117 -0
- ldcast/analysis/crps.py +162 -0
- ldcast/analysis/fss.py +137 -0
- ldcast/analysis/histogram.py +108 -0
- ldcast/analysis/rank.py +190 -0
- ldcast/features/.sampling.py.swp +0 -0
- ldcast/features/batch.py +375 -0
- ldcast/features/batch.py.save +378 -0
- ldcast/features/io.py +125 -0
- ldcast/features/patches.py +429 -0
- ldcast/features/patches.py.save +431 -0
- ldcast/features/sampling.py +215 -0
- ldcast/features/split.py +165 -0
- ldcast/features/transform.py +296 -0
- ldcast/features/utils.py +136 -0
- ldcast/forecast.py +264 -0
- ldcast/models/autoenc/autoenc.py +93 -0
- ldcast/models/autoenc/encoder.py +57 -0
- ldcast/models/autoenc/training.py +41 -0
- ldcast/models/benchmarks/dgmr.py +82 -0
- ldcast/models/benchmarks/pysteps.py +106 -0
- ldcast/models/benchmarks/transform.py +17 -0
- ldcast/models/blocks/afno.py +348 -0
- ldcast/models/blocks/attention.py +104 -0
- ldcast/models/blocks/resnet.py +70 -0
- ldcast/models/diffusion/diffusion.py +222 -0
- ldcast/models/diffusion/ema.py +76 -0
- ldcast/models/diffusion/plms.py +245 -0
- ldcast/models/diffusion/utils.py +246 -0
- ldcast/models/distributions.py +29 -0
- ldcast/models/genforecast/analysis.py +33 -0
- ldcast/models/genforecast/training.py +42 -0
- ldcast/models/genforecast/unet.py +489 -0
- ldcast/models/nowcast/nowcast.py +256 -0
- ldcast/models/utils.py +28 -0
- ldcast/visualization/cm.py +36 -0
- ldcast/visualization/plots.py +606 -0
- models/.keep +0 -0
- models/autoenc/autoenc-32-0.01.pt +3 -0
- models/autoenc/autoenc.pt +3 -0
- scripts/convert_data_NB_2nc.py +76 -0
.gitignore
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.vscode
|
| 2 |
+
.DS_Store
|
| 3 |
+
dask-worker-space
|
| 4 |
+
*~
|
| 5 |
+
lightning_logs
|
| 6 |
+
|
| 7 |
+
# Byte-compiled / optimized / DLL files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.py[cod]
|
| 10 |
+
*$py.class
|
| 11 |
+
|
| 12 |
+
# C extensions
|
| 13 |
+
*.so
|
| 14 |
+
|
| 15 |
+
# Distribution / packaging
|
| 16 |
+
.Python
|
| 17 |
+
build/
|
| 18 |
+
develop-eggs/
|
| 19 |
+
dist/
|
| 20 |
+
downloads/
|
| 21 |
+
eggs/
|
| 22 |
+
.eggs/
|
| 23 |
+
lib/
|
| 24 |
+
lib64/
|
| 25 |
+
parts/
|
| 26 |
+
sdist/
|
| 27 |
+
var/
|
| 28 |
+
wheels/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
# Usually these files are written by a python script from a template
|
| 37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py,cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
cover/
|
| 59 |
+
|
| 60 |
+
# Translations
|
| 61 |
+
*.mo
|
| 62 |
+
*.pot
|
| 63 |
+
|
| 64 |
+
# Django stuff:
|
| 65 |
+
*.log
|
| 66 |
+
local_settings.py
|
| 67 |
+
db.sqlite3
|
| 68 |
+
db.sqlite3-journal
|
| 69 |
+
|
| 70 |
+
# Flask stuff:
|
| 71 |
+
instance/
|
| 72 |
+
.webassets-cache
|
| 73 |
+
|
| 74 |
+
# Scrapy stuff:
|
| 75 |
+
.scrapy
|
| 76 |
+
|
| 77 |
+
# Sphinx documentation
|
| 78 |
+
docs/_build/
|
| 79 |
+
|
| 80 |
+
# PyBuilder
|
| 81 |
+
.pybuilder/
|
| 82 |
+
target/
|
| 83 |
+
|
| 84 |
+
# Jupyter Notebook
|
| 85 |
+
.ipynb_checkpoints
|
| 86 |
+
|
| 87 |
+
# IPython
|
| 88 |
+
profile_default/
|
| 89 |
+
ipython_config.py
|
| 90 |
+
|
| 91 |
+
# pyenv
|
| 92 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 94 |
+
# .python-version
|
| 95 |
+
|
| 96 |
+
# pipenv
|
| 97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 100 |
+
# install all needed dependencies.
|
| 101 |
+
#Pipfile.lock
|
| 102 |
+
|
| 103 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 104 |
+
__pypackages__/
|
| 105 |
+
|
| 106 |
+
# Celery stuff
|
| 107 |
+
celerybeat-schedule
|
| 108 |
+
celerybeat.pid
|
| 109 |
+
|
| 110 |
+
# SageMath parsed files
|
| 111 |
+
*.sage.py
|
| 112 |
+
|
| 113 |
+
# Environments
|
| 114 |
+
.env
|
| 115 |
+
.venv
|
| 116 |
+
env/
|
| 117 |
+
venv/
|
| 118 |
+
ENV/
|
| 119 |
+
env.bak/
|
| 120 |
+
venv.bak/
|
| 121 |
+
|
| 122 |
+
# Spyder project settings
|
| 123 |
+
.spyderproject
|
| 124 |
+
.spyproject
|
| 125 |
+
|
| 126 |
+
# Rope project settings
|
| 127 |
+
.ropeproject
|
| 128 |
+
|
| 129 |
+
# mkdocs documentation
|
| 130 |
+
/site
|
| 131 |
+
|
| 132 |
+
# mypy
|
| 133 |
+
.mypy_cache/
|
| 134 |
+
.dmypy.json
|
| 135 |
+
dmypy.json
|
| 136 |
+
|
| 137 |
+
# Pyre type checker
|
| 138 |
+
.pyre/
|
| 139 |
+
|
| 140 |
+
# pytype static type analyzer
|
| 141 |
+
.pytype/
|
| 142 |
+
|
| 143 |
+
# Cython debug symbols
|
| 144 |
+
cython_debug/
|
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
| 202 |
+
|
README.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LDCast is a precipitation nowcasting model based on a latent diffusion model (LDM, used by e.g. [Stable Diffusion](https://github.com/CompVis/stable-diffusion)).
|
| 2 |
+
|
| 3 |
+
This repository contains the code for using LDCast to make predictions and the code used to generate the analysis in the LDCast paper (a preprint is available at https://arxiv.org/abs/2304.12891).
|
| 4 |
+
|
| 5 |
+
A GPU is recommended for both using and training LDCast, although you may be able to generate some samples with a CPU and enough patience.
|
| 6 |
+
|
| 7 |
+
# Installation
|
| 8 |
+
|
| 9 |
+
It is recommended you install the code in its own virtual environment (created with e.g. pyenv or conda).
|
| 10 |
+
|
| 11 |
+
Clone the repository, then, in the main directory, run
|
| 12 |
+
```bash
|
| 13 |
+
$ pip install -e .
|
| 14 |
+
```
|
| 15 |
+
This should automatically install the required packages (which might take some minutes). In the paper, we used PyTorch 11.2 but are not aware of any problems with newer versions.
|
| 16 |
+
|
| 17 |
+
If you don't want the requirements to be installed (e.g. if you installed them manually with conda), use:
|
| 18 |
+
```bash
|
| 19 |
+
$ pip install --no-dependencies -e .
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
# Using LDCast
|
| 23 |
+
|
| 24 |
+
## Pretrained models
|
| 25 |
+
|
| 26 |
+
The pretrained models are available at the Zenodo repository https://doi.org/10.5281/zenodo.7780914. Unzip the file `ldcast-models.zip`. The default is to unzip it to the `models` directory, but you can also use another location.
|
| 27 |
+
|
| 28 |
+
## Producing predictions
|
| 29 |
+
|
| 30 |
+
The easiest way to produce predictions is to use the `ldcast.forecast.Forecast` class, which will set up all models and data transformations and is callable with a past precipitation array.
|
| 31 |
+
```python
|
| 32 |
+
from ldcast import forecast
|
| 33 |
+
|
| 34 |
+
fc = forecast.Forecast(
|
| 35 |
+
ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn
|
| 36 |
+
)
|
| 37 |
+
R_pred = fc(R_past)
|
| 38 |
+
```
|
| 39 |
+
Here, `ldm_weights_fn` is the path to the LDM weights and `autoenc_weights_fn` is the path to the autoencoder weights. `R_past` is a NumPy array of precipitation rates with shape `(timesteps, height, width)` where `timesteps` must be 4 and `height` and `width` must be divisible by 32.
|
| 40 |
+
|
| 41 |
+
### Ensemble predictions
|
| 42 |
+
|
| 43 |
+
If want to process multiple cases at once and/or generate several ensemble members, there is the `ldcast.forecast.ForecastDistributed` class. The usage is similar to the `Forecast` class, for example:
|
| 44 |
+
```python
|
| 45 |
+
from ldcast import forecast
|
| 46 |
+
|
| 47 |
+
fc = forecast.ForecastDistributed(
|
| 48 |
+
ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn
|
| 49 |
+
)
|
| 50 |
+
R_pred = fc(R_past, ensemble_members=32)
|
| 51 |
+
```
|
| 52 |
+
Here, `R_past` should be of shape `(cases, timesteps, height, width)` where `cases` is the number of cases you want to process. For each case, `ensemble_members` predictions are produced (this is the last axis of `R_pred`). `ForecastDistributed` automatically distributes the workload to multiple GPUs if you have them.
|
| 53 |
+
|
| 54 |
+
## Demo
|
| 55 |
+
|
| 56 |
+
For a practical example, you can run the demo in the `scripts` directory. First download the `ldcast-demo-20210622.zip` file from the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914), then unzip it in the `data` directory. Then run
|
| 57 |
+
```bash
|
| 58 |
+
$ python forecast_demo.py
|
| 59 |
+
```
|
| 60 |
+
A sample output can be found in the file `ldcast-demo-video-20210622.zip` in the data repository. See the function `forecast_demo` in `forecast_demo.py` see how the `Forecast` class works. To run an ensemble mean of 8 members using the `ForecastDistributed` class, you can use:
|
| 61 |
+
```bash
|
| 62 |
+
$ python forecast_demo.py --ensemble-members=8
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
The demo for a single ensemble member runs in a couple of minutes on our system using one V100 GPU; with a CPU around 10 minutes or more would be expected. A progress bar will show the status of the generation.
|
| 66 |
+
|
| 67 |
+
# Training
|
| 68 |
+
|
| 69 |
+
## Training data
|
| 70 |
+
|
| 71 |
+
The preprocessed training data, needed to rerun the LDCast training, can be found at the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914). Unzip the `ldcast-datasets.zip` file to the `data` directory.
|
| 72 |
+
|
| 73 |
+
## Training the autoencoder
|
| 74 |
+
|
| 75 |
+
In the `scripts` directory, run
|
| 76 |
+
```bash
|
| 77 |
+
$ python train_autoenc.py --model_dir="../models/autoenc_train"
|
| 78 |
+
```
|
| 79 |
+
to run the training of the autoencoder with the default parameters. The training checkpoints will be saved in the `../models/autoenc_train` directory (feel free to change this).
|
| 80 |
+
|
| 81 |
+
It has been reported that this training may encounter a condition where the loss goes to `nan`. If this happens, try restarting from the latest checkpoint:
|
| 82 |
+
```bash
|
| 83 |
+
$ python train_autoenc.py --model_dir="../models/autoenc_train" --ckpt_path="../models/autoenc_train/<checkpoint_file>"
|
| 84 |
+
```
|
| 85 |
+
where `<checkpoint_file>` should be the latest checkpoint in the `../models/autoenc_train/` directory.
|
| 86 |
+
|
| 87 |
+
## Training the diffusion model
|
| 88 |
+
|
| 89 |
+
In the `scripts` directory, run
|
| 90 |
+
```bash
|
| 91 |
+
$ python train_genforecast.py --model_dir="../models/genforecast_train"
|
| 92 |
+
```
|
| 93 |
+
to run the training of the diffusion model with the default parameters, or
|
| 94 |
+
```bash
|
| 95 |
+
$ python train_genforecast.py --model_dir="../models/genforecast_train" --config=<path_to_config_file>
|
| 96 |
+
```
|
| 97 |
+
to run the training with different parameters. Some config files can be found in the `config` directory. The training checkpoints will be saved in the `../models/genforecast_train` directory (again, this can be changed freely).
|
| 98 |
+
|
| 99 |
+
# Evaluation
|
| 100 |
+
|
| 101 |
+
You can find scripts for evaluating models in the `scripts` directory:
|
| 102 |
+
* `eval_genforecast.py` to evaluate LDCast
|
| 103 |
+
* `eval_dgmr.py` to evaluate DGMR (requires tensorflow installation and the DGMR model from https://github.com/deepmind/deepmind-research/tree/master/nowcasting placed in the `models/dgmr` directory)
|
| 104 |
+
* `eval_pysteps.py` to evaluate PySTEPS (requires pysteps installation)
|
| 105 |
+
* `metrics.py` to produce metrics from the evaluation results produced with the functions in scripts above
|
| 106 |
+
* `plot_genforecast.py` to make plots from the results generated
|
autoenc-32-0.01.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa5ad4b8689aadbf702376e7afe5cb437ef5057675e78a8986837e8f28b3126e
|
| 3 |
+
size 1617490
|
config/genforecast-radaronly-128x128-20step.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# this configuration is the default - no parameters to override!
|
config/genforecast-radaronly-256x256-20step.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sample_shape: [8,8]
|
| 2 |
+
batch_size: 24
|
| 3 |
+
sampler: "sampler_nowcaster256"
|
| 4 |
+
initial_weights: "../models/genforecast/genforecast-radaronly-128x128-20step.pt"
|
| 5 |
+
lr: 2.5e-5
|
environment/environment.yml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ldcast
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
prefix: /home/mmhk20/.conda/envs/ldcast
|
environment/ldcast.yml
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ldcast
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 7 |
+
- _openmp_mutex=4.5=2_gnu
|
| 8 |
+
- brotli=1.1.0=hd590300_1
|
| 9 |
+
- brotli-bin=1.1.0=hd590300_1
|
| 10 |
+
- bzip2=1.0.8=hd590300_5
|
| 11 |
+
- c-ares=1.27.0=hd590300_0
|
| 12 |
+
- ca-certificates=2024.2.2=hbcca054_0
|
| 13 |
+
- certifi=2024.2.2=pyhd8ed1ab_0
|
| 14 |
+
- cycler=0.12.1=pyhd8ed1ab_0
|
| 15 |
+
- freetype=2.12.1=h267a509_2
|
| 16 |
+
- geos=3.12.1=h59595ed_0
|
| 17 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 18 |
+
- krb5=1.21.2=h659d440_0
|
| 19 |
+
- lcms2=2.16=hb7c19ff_0
|
| 20 |
+
- ld_impl_linux-64=2.40=h41732ed_0
|
| 21 |
+
- lerc=4.0.0=h27087fc_0
|
| 22 |
+
- libblas=3.9.0=21_linux64_openblas
|
| 23 |
+
- libbrotlicommon=1.1.0=hd590300_1
|
| 24 |
+
- libbrotlidec=1.1.0=hd590300_1
|
| 25 |
+
- libbrotlienc=1.1.0=hd590300_1
|
| 26 |
+
- libcblas=3.9.0=21_linux64_openblas
|
| 27 |
+
- libcurl=8.6.0=hca28451_0
|
| 28 |
+
- libdeflate=1.19=hd590300_0
|
| 29 |
+
- libedit=3.1.20191231=he28a2e2_2
|
| 30 |
+
- libev=4.33=hd590300_2
|
| 31 |
+
- libexpat=2.6.2=h59595ed_0
|
| 32 |
+
- libffi=3.4.2=h7f98852_5
|
| 33 |
+
- libgcc-ng=13.2.0=h807b86a_5
|
| 34 |
+
- libgfortran-ng=13.2.0=h69a702a_5
|
| 35 |
+
- libgfortran5=13.2.0=ha4646dd_5
|
| 36 |
+
- libgomp=13.2.0=h807b86a_5
|
| 37 |
+
- libjpeg-turbo=3.0.0=hd590300_1
|
| 38 |
+
- liblapack=3.9.0=21_linux64_openblas
|
| 39 |
+
- libnghttp2=1.58.0=h47da74e_1
|
| 40 |
+
- libnsl=2.0.1=hd590300_0
|
| 41 |
+
- libopenblas=0.3.26=pthreads_h413a1c8_0
|
| 42 |
+
- libpng=1.6.43=h2797004_0
|
| 43 |
+
- libsqlite=3.45.2=h2797004_0
|
| 44 |
+
- libssh2=1.11.0=h0841786_0
|
| 45 |
+
- libstdcxx-ng=13.2.0=h7e041cc_5
|
| 46 |
+
- libtiff=4.6.0=ha9c0a0a_2
|
| 47 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 48 |
+
- libwebp-base=1.3.2=hd590300_0
|
| 49 |
+
- libxcb=1.15=h0b41bf4_0
|
| 50 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 51 |
+
- libzlib=1.2.13=hd590300_5
|
| 52 |
+
- matplotlib-base=3.8.3=py312he5832f3_0
|
| 53 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
| 54 |
+
- ncurses=6.4.20240210=h59595ed_0
|
| 55 |
+
- openjpeg=2.5.2=h488ebb8_0
|
| 56 |
+
- openssl=3.2.1=hd590300_1
|
| 57 |
+
- packaging=24.0=pyhd8ed1ab_0
|
| 58 |
+
- pip=24.0=pyhd8ed1ab_0
|
| 59 |
+
- proj=9.3.1=h1d62c97_0
|
| 60 |
+
- pthread-stubs=0.4=h36c2ea0_1001
|
| 61 |
+
- pyparsing=3.1.2=pyhd8ed1ab_0
|
| 62 |
+
- pyshp=2.3.1=pyhd8ed1ab_0
|
| 63 |
+
- python=3.12.2=hab00c5b_0_cpython
|
| 64 |
+
- python-dateutil=2.9.0=pyhd8ed1ab_0
|
| 65 |
+
- python_abi=3.12=4_cp312
|
| 66 |
+
- readline=8.2=h8228510_1
|
| 67 |
+
- setuptools=69.2.0=pyhd8ed1ab_0
|
| 68 |
+
- six=1.16.0=pyh6c4a22f_0
|
| 69 |
+
- sqlite=3.45.2=h2c6b66d_0
|
| 70 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 71 |
+
- wheel=0.43.0=pyhd8ed1ab_0
|
| 72 |
+
- xorg-libxau=1.0.11=hd590300_0
|
| 73 |
+
- xorg-libxdmcp=1.1.3=h7f98852_0
|
| 74 |
+
- xz=5.2.6=h166bdaf_0
|
| 75 |
+
- zstd=1.5.5=hfc55251_0
|
| 76 |
+
- pip:
|
| 77 |
+
- aiobotocore==2.12.1
|
| 78 |
+
- aiohttp==3.9.3
|
| 79 |
+
- aioitertools==0.11.0
|
| 80 |
+
- aiosignal==1.3.1
|
| 81 |
+
- antlr4-python3-runtime==4.9.3
|
| 82 |
+
- arm-pyart==1.18.0
|
| 83 |
+
- attrs==23.2.0
|
| 84 |
+
- botocore==1.34.51
|
| 85 |
+
- cartopy==0.22.0
|
| 86 |
+
- cftime==1.6.3
|
| 87 |
+
- charset-normalizer==3.3.2
|
| 88 |
+
- click==8.1.7
|
| 89 |
+
- cloudpickle==3.0.0
|
| 90 |
+
- cmweather==0.3.2
|
| 91 |
+
- contourpy==1.2.0
|
| 92 |
+
- dask==2024.3.1
|
| 93 |
+
- deprecation==2.1.0
|
| 94 |
+
- einops==0.7.0
|
| 95 |
+
- filelock==3.13.1
|
| 96 |
+
- fire==0.6.0
|
| 97 |
+
- fonttools==4.50.0
|
| 98 |
+
- frozenlist==1.4.1
|
| 99 |
+
- fsspec==2024.3.1
|
| 100 |
+
- h5netcdf==1.3.0
|
| 101 |
+
- h5py==3.10.0
|
| 102 |
+
- idna==3.6
|
| 103 |
+
- jinja2==3.1.3
|
| 104 |
+
- jmespath==1.0.1
|
| 105 |
+
- jsmin==3.0.1
|
| 106 |
+
- jsonschema==4.22.0
|
| 107 |
+
- jsonschema-specifications==2023.12.1
|
| 108 |
+
- kiwisolver==1.4.5
|
| 109 |
+
- lat-lon-parser==1.3.0
|
| 110 |
+
- lightning==2.2.4
|
| 111 |
+
- lightning-utilities==0.11.0
|
| 112 |
+
- llvmlite==0.42.0
|
| 113 |
+
- locket==1.0.0
|
| 114 |
+
- markupsafe==2.1.5
|
| 115 |
+
- matplotlib==3.8.3
|
| 116 |
+
- mda-xdrlib==0.2.0
|
| 117 |
+
- mpmath==1.3.0
|
| 118 |
+
- multidict==6.0.5
|
| 119 |
+
- netcdf4==1.6.5
|
| 120 |
+
- networkx==3.2.1
|
| 121 |
+
- numba==0.59.1
|
| 122 |
+
- numpy==1.26.4
|
| 123 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 124 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 125 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 126 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 127 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
| 128 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 129 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 130 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 131 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 132 |
+
- nvidia-nccl-cu12==2.20.5
|
| 133 |
+
- nvidia-nvjitlink-cu12==12.4.99
|
| 134 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 135 |
+
- omegaconf==2.3.0
|
| 136 |
+
- open-radar-data==0.1.0
|
| 137 |
+
- opencv-python==4.9.0.80
|
| 138 |
+
- pandas==2.2.1
|
| 139 |
+
- partd==1.4.1
|
| 140 |
+
- pillow==10.2.0
|
| 141 |
+
- platformdirs==4.2.0
|
| 142 |
+
- pooch==1.8.1
|
| 143 |
+
- pyproj==3.6.1
|
| 144 |
+
- pysteps==1.9.0
|
| 145 |
+
- pytorch-lightning==2.2.1
|
| 146 |
+
- pytz==2024.1
|
| 147 |
+
- pyyaml==6.0.1
|
| 148 |
+
- referencing==0.35.1
|
| 149 |
+
- requests==2.31.0
|
| 150 |
+
- rpds-py==0.18.1
|
| 151 |
+
- s3fs==2024.3.1
|
| 152 |
+
- scipy==1.12.0
|
| 153 |
+
- shapely==2.0.3
|
| 154 |
+
- sympy==1.12
|
| 155 |
+
- termcolor==2.4.0
|
| 156 |
+
- toolz==0.12.1
|
| 157 |
+
- torch==2.3.0
|
| 158 |
+
- torchmetrics==1.3.2
|
| 159 |
+
- torchvision==0.18.0
|
| 160 |
+
- tqdm==4.66.2
|
| 161 |
+
- typing-extensions==4.10.0
|
| 162 |
+
- tzdata==2024.1
|
| 163 |
+
- urllib3==2.0.7
|
| 164 |
+
- wradlib==2.0.3
|
| 165 |
+
- wrapt==1.16.0
|
| 166 |
+
- xarray==2024.2.0
|
| 167 |
+
- xarray-datatree==0.0.14
|
| 168 |
+
- xmltodict==0.13.0
|
| 169 |
+
- xradar==0.4.3
|
| 170 |
+
- yarl==1.9.4
|
genforecast-radaronly-256x256-20step.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5fef86b78d29fde8ba66f51ae74f0d84ddc67b711fcab034a3130ea5ac7721cf
|
| 3 |
+
size 5345469521
|
ldcast/analysis/confmatrix.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import concurrent
|
| 3 |
+
import multiprocessing
|
| 4 |
+
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy.integrate import trapezoid
|
| 8 |
+
|
| 9 |
+
from ..features.io import load_batch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def confusion_matrix(fc_frac, obs_frac, prob_threshold):
    """Normalized 2x2 contingency table for a single probability threshold.

    Both inputs are thresholded with a strict ``>`` comparison; counts are
    divided by the total number of elements so the four entries sum to 1.
    Returns ``np.array(((tp, fn), (fp, tn)))``.
    """
    total = np.prod(fc_frac.shape)
    fc_pos = fc_frac > prob_threshold
    obs_pos = obs_frac > prob_threshold
    hits = np.count_nonzero(fc_pos & obs_pos) / total
    false_alarms = np.count_nonzero(fc_pos & ~obs_pos) / total
    misses = np.count_nonzero(~fc_pos & obs_pos) / total
    # remaining fraction is the correct negatives
    correct_neg = 1.0 - hits - false_alarms - misses
    return np.array(((hits, misses), (false_alarms, correct_neg)))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def confusion_matrix_thresholds(fc_frac, obs_frac, thresholds):
    """Compute one confusion matrix per probability threshold.

    Returns an array of shape (2, 2, len(thresholds)): the per-threshold
    confusion matrices stacked along the last axis, in threshold order.
    """
    # NOTE(review): this file does `import concurrent`, which does not load
    # the `futures` submodule by itself — confirm concurrent.futures is
    # imported somewhere before this runs.
    N_threads = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        # one task per threshold; results are gathered in submission order,
        # so the output axis matches the order of `thresholds`
        futures = [
            executor.submit(confusion_matrix, fc_frac, obs_frac, t)
            for t in thresholds
        ]
        conf_matrix = [f.result() for f in futures]
    return np.stack(conf_matrix, axis=-1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def confusion_matrix_thresholds_leadtime(fc_frac, obs_frac, thresholds):
    """Confusion matrices per threshold, separately for each lead time.

    Lead time is taken from axis 2 of `fc_frac`/`obs_frac` (the last two
    axes are treated as spatial). Returns shape (2, 2, n_leadtimes,
    n_thresholds): per-threshold matrices stacked on the last axis, lead
    times stacked on the second-to-last.
    """
    N_threads = multiprocessing.cpu_count()
    conf_matrix = []
    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        for lt in range(fc_frac.shape[2]):
            # one task per threshold for this lead time
            futures = [
                executor.submit(confusion_matrix,
                    fc_frac[...,lt,:,:], obs_frac[...,lt,:,:], t)
                for t in thresholds
            ]
            conf_matrix_lt = [f.result() for f in futures]
            conf_matrix_lt = np.stack(conf_matrix_lt, axis=-1)
            conf_matrix.append(conf_matrix_lt)

    return np.stack(conf_matrix, axis=-2)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def precision(conf_matrix):
    """Precision tp / (tp + fp) from a confusion matrix.

    Where tp + fp == 0 (no positive predictions) precision is defined
    as 1.0. Fix: the original assigned through a boolean index
    (`precision[np.isnan(...)] = 1.0`), which raises TypeError for scalar
    entries and mutates the intermediate array; `np.where` handles both
    scalars and arrays without mutation.
    """
    ((tp, fn), (fp, tn)) = conf_matrix
    with np.errstate(invalid="ignore", divide="ignore"):
        prec = np.asarray(tp, dtype=float) / (tp + fp)
    # 0/0 -> NaN; define precision as 1 when nothing was predicted positive
    return np.where(np.isnan(prec), 1.0, prec)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def recall(conf_matrix):
    """Recall (probability of detection) tp / (tp + fn)."""
    ((tp, fn), (_fp, _tn)) = conf_matrix
    observed_pos = tp + fn
    return tp / observed_pos
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def false_alarm_ratio(conf_matrix):
    """False alarm ratio fp / (tp + fp), i.e. the complement of precision."""
    prec = precision(conf_matrix)
    return 1.0 - prec
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def intersection_over_union(conf_matrix):
    """Critical success index (IoU) tp / (tp + fp + fn)."""
    ((tp, fn), (fp, _tn)) = conf_matrix
    union = tp + fp + fn
    return tp / union
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def equitable_threat_score(conf_matrix):
    """Equitable threat score (Gilbert skill score).

    The hit rate expected by chance, (tp+fn)(tp+fp)/total, is subtracted
    from both numerator and denominator of the threat score.
    """
    ((tp, fn), (fp, tn)) = conf_matrix
    total = tp + fp + tn + fn
    hits_random = (tp + fn) * (tp + fp) / total
    return (tp - hits_random) / (tp + fp + fn - hits_random)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def peirce_skill_score(conf_matrix):
    """Peirce skill score (Hanssen-Kuipers discriminant), POD - POFD."""
    ((tp, fn), (fp, tn)) = conf_matrix
    numer = tp * tn - fn * fp
    denom = (tp + fn) * (fp + tn)
    return numer / denom
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def heidke_skill_score(conf_matrix):
    """Heidke skill score: accuracy relative to random chance."""
    ((tp, fn), (fp, tn)) = conf_matrix
    numer = 2 * (tp * tn - fn * fp)
    denom = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return numer / denom
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def roc_area_under_curve(conf_matrix):
    """Area under the ROC curve by trapezoidal integration.

    Expects the confusion-matrix entries to be arrays over thresholds
    (as produced by `confusion_matrix_thresholds`); the curves are
    reversed so that FPR is increasing before integration.
    """
    ((tp, fn), (fp, tn)) = conf_matrix
    true_pos_rate = tp / (tp + fn)
    false_pos_rate = fp / (fp + tn)
    return trapezoid(true_pos_rate[::-1], x=false_pos_rate[::-1])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def pr_area_under_curve(conf_matrix):
    """Area under the precision-recall curve by trapezoidal integration.

    If the curve does not already terminate at (recall=0, precision=1),
    that endpoint is appended before integrating.
    """
    prec = precision(conf_matrix)
    rec = recall(conf_matrix)

    endpoint_present = (rec[-1] == 0) and (prec[-1] == 1)
    if not endpoint_present:
        rec = np.hstack((rec, 0.0))
        prec = np.hstack((prec, 1.0))

    # reverse so recall is increasing
    return trapezoid(prec[::-1], x=rec[::-1])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def cost_loss_value(conf_matrix, cost, loss, p_clim):
    """Relative economic value in the cost/loss decision model.

    `cost` is the price of taking protective action, `loss` the price of
    an unprotected event, `p_clim` the climatological event probability.
    The forecast expense is compared against the climatological baseline
    (always / never protect, whichever is cheaper) and the perfect
    forecast (protect exactly when the event occurs).
    """
    ((tp, fn), (fp, tn)) = conf_matrix

    expense_clim = min(cost, p_clim * loss)
    expense_perfect = p_clim * cost
    expense_forecast = (tp + fp) * cost + fn * loss
    return (expense_forecast - expense_clim) / (expense_perfect - expense_clim)
|
ldcast/analysis/crps.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ..features.io import load_batch, decode_saved_var_to_rainrate
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def crps_ensemble(observation, forecasts):
    """Pointwise CRPS of an ensemble forecast against observations.

    `forecasts` has one extra trailing axis (the ensemble members) relative
    to `observation`. Returns an array of CRPS values with the same shape
    as `observation`. The work is flattened and split over one thread per
    CPU; threads write disjoint slices of a shared output array.
    """
    shape = observation.shape
    N = np.prod(shape)
    shape_flat = (np.prod(shape),)  # NOTE(review): unused
    observation = observation.reshape((N,))
    forecasts = forecasts.reshape((N, forecasts.shape[-1]))
    crps_all = np.zeros_like(observation)
    N_threads = multiprocessing.cpu_count()

    def crps_chunk(k):
        # process flat indices [i0, i1); chunks partition [0, N) exactly
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        obs = observation[i0:i1].copy()
        fc = forecasts[i0:i1,:].copy()
        fc.sort(axis=-1)
        fc_below = fc < obs[...,None]
        crps = np.zeros_like(obs)

        # CRPS as a weighted sum of distances from the observation to the
        # sorted ensemble members: members below the observation,
        # accumulated with increasing weight (2i+1)/m^2 ...
        for i in range(fc.shape[-1]):
            below = fc_below[...,i]
            weight = ((i+1)**2 - i**2) / fc.shape[-1]**2
            crps[below] += weight * (obs[below]-fc[...,i][below])

        # ... and symmetrically for members above the observation,
        # iterating from the largest member downwards.
        for i in range(fc.shape[-1]-1,-1,-1):
            above = ~fc_below[...,i]
            # NOTE(review): this rebinds `k`, shadowing the chunk index
            # parameter (harmless here since `k` is not used afterwards)
            k = fc.shape[-1]-1-i
            weight = ((k+1)**2 - k**2) / fc.shape[-1]**2
            crps[above] += weight * (fc[...,i][above]-obs[above])

        # each thread writes only its own disjoint slice
        crps_all[i0:i1] = crps

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {}
        for k in range(N_threads):
            args = (crps_chunk, k)
            futures[executor.submit(*args)] = k
        concurrent.futures.wait(futures)

    return crps_all.reshape(shape)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def crps_ensemble_multiscale(observation, forecasts):
    """CRPS evaluated at successively coarser spatial resolutions.

    Starting from the native grid, the last two spatial axes are reduced
    by 2x2 average pooling until the width reaches 1. Returns a dict
    mapping pooling scale (1, 2, 4, ...) to the CRPS array at that scale.
    """
    obs, fc = observation, forecasts
    crps_scales = {}
    scale = 1
    while True:
        crps_scales[scale] = crps_ensemble(obs, fc)
        scale *= 2
        if obs.shape[-1] == 1:
            break
        # 2x2 average pooling; the forecast keeps its trailing ensemble axis
        obs = 0.25 * (obs[..., ::2, ::2] + obs[..., 1::2, ::2]
                      + obs[..., ::2, 1::2] + obs[..., 1::2, 1::2])
        fc = 0.25 * (fc[..., ::2, ::2, :] + fc[..., 1::2, ::2, :]
                     + fc[..., ::2, 1::2, :] + fc[..., 1::2, 1::2, :])
    return crps_scales
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def gather_observation(data_dir):
    """Load and pool the observations from every batch file in `data_dir`.

    Each netCDF file's "future_observations" variable is decoded to rain
    rate and 2x2 average-pooled down to width 1. Returns a dict mapping
    pooling scale (1, 2, 4, ...) to the observations of all files
    concatenated along the sample axis.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    def obs_from_file(fn):
        with netCDF4.Dataset(fn, 'r') as ds:
            obs = np.array(ds["future_observations"][:], copy=False)
            obs = decode_saved_var_to_rainrate(obs)
        p = 1
        obs_pooled = {}
        while True:
            obs_pooled[p] = obs
            # stop once the last spatial axis is fully pooled
            if obs.shape[-1] == 1:
                break
            # 2x2 average pooling over the last two axes
            obs = 0.25 * (
                obs[...,::2,::2] +
                obs[...,1::2,::2] +
                obs[...,::2,1::2] +
                obs[...,1::2,1::2]
            )
            p *= 2
        return obs_pooled

    obs_pooled = {}
    for fn in files:
        print(fn)
        obs_file = obs_from_file(fn)
        # group the per-file arrays by pooling scale
        for k in obs_file:
            if k not in obs_pooled:
                obs_pooled[k] = []
            obs_pooled[k].append(obs_file[k])

    for k in obs_pooled:
        obs_pooled[k] = np.concatenate(obs_pooled[k], axis=0)

    return obs_pooled
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def process_batch(fn, log=False, preproc_fc=None):
    """Load one saved batch file and return its multiscale CRPS."""
    print(fn)
    (_, obs, fc) = load_batch(fn, log=log, preproc_fc=preproc_fc)
    return crps_ensemble_multiscale(obs, fc)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def save_crps_for_dataset(data_dir, result_fn, log=False, preproc_fc=None):
    """Compute multiscale CRPS for every batch file in `data_dir` and save
    the results to the netCDF file `result_fn`.

    One process per CPU is used; `preproc_fc` must therefore be picklable.
    One variable per pooling scale (crps_pool{s}x{s}) is written, chunked
    one sample at a time.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    futures = []
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        for fn in files:
            args = (process_batch, fn)
            kwargs = {"log": log, "preproc_fc": preproc_fc}
            futures.append(executor.submit(*args, **kwargs))

    # results arrive in file order; each is a {scale: array} dict
    crps = [f.result() for f in futures]
    scales = sorted(crps[0].keys())
    crps = {
        s: np.concatenate([c[s] for c in crps], axis=0)
        for s in scales
    }

    with netCDF4.Dataset(result_fn, 'w') as ds:
        # shared leading dimensions taken from the native-scale (1) array
        ds.createDimension("dim_sample", crps[1].shape[0])
        ds.createDimension("dim_channel", crps[1].shape[1])
        ds.createDimension("dim_time_future", crps[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}

        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", crps[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", crps[s].shape[4])
            var = ds.createVariable(
                f"crps_pool{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                chunksizes=(1,)+crps[s].shape[1:],
                **var_params
            )
            var[:] = crps[s]
|
ldcast/analysis/fss.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ..features.io import load_batch, decode_saved_var_to_rainrate
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def fractions_ensemble(observation, forecasts, threshold, max_scale=256):
    """Threshold-exceedance fractions at successive 2x2 pooling scales.

    The observation is binarized with ``>= threshold``; the forecast
    exceedance probability is the fraction of ensemble members (trailing
    axis) exceeding the threshold. Both fields are then average-pooled
    2x2 for every power-of-two scale up to `max_scale`. Returns
    ``(obs_frac, fc_frac)``, each a dict keyed by scale.
    """
    def _pool2x2(a):
        return 0.25 * (a[..., ::2, ::2] + a[..., 1::2, ::2]
                       + a[..., ::2, 1::2] + a[..., 1::2, 1::2])

    obs = (observation >= threshold).astype(np.float32)
    fc = (forecasts >= threshold).astype(np.float32).mean(axis=-1)

    obs_frac = {}
    fc_frac = {}
    scale = 1
    while True:
        obs_frac[scale] = obs.copy()
        fc_frac[scale] = fc.copy()
        if scale * 2 > max_scale:
            break
        obs = _pool2x2(obs)
        fc = _pool2x2(fc)
        scale *= 2

    return (obs_frac, fc_frac)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def frac_from_file(fn, threshold, preproc_fc):
    """Load one saved batch file and return its multiscale exceedance
    fractions for the given rain-rate threshold."""
    print(fn)
    (_, obs, fc) = load_batch(fn, preproc_fc=preproc_fc)
    return fractions_ensemble(obs, fc, threshold)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def save_fractions_for_dataset(data_dir, result_fn, threshold, preproc_fc=None):
    """Compute multiscale exceedance fractions for every batch file in
    `data_dir` and save them to the netCDF file `result_fn`.

    One process per CPU; `preproc_fc` must be picklable. Writes a pair of
    variables (obs_frac_scale{s}x{s}, fc_frac_scale{s}x{s}) per scale.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        futures = []
        for fn in files:
            args = (frac_from_file, fn, threshold, preproc_fc)
            futures.append(executor.submit(*args))

        # one (obs_frac, fc_frac) dict pair per file, in file order
        (obs_frac, fc_frac) = zip(*(f.result() for f in futures))

    scales = list(obs_frac[0].keys())
    obs_frac_dict = {}
    fc_frac_dict = {}
    for s in scales:
        # concatenate per-file arrays along the sample axis
        obs_frac_dict[s] = np.concatenate([f[s] for f in obs_frac], axis=0)
        fc_frac_dict[s] = np.concatenate([f[s] for f in fc_frac], axis=0)
    obs_frac = obs_frac_dict
    fc_frac = fc_frac_dict

    frac_vars = {}  # NOTE(review): unused
    k = 0  # NOTE(review): unused
    with netCDF4.Dataset(result_fn, 'w') as ds:
        # shared leading dimensions taken from the native-scale (1) array
        ds.createDimension("dim_sample", obs_frac[1].shape[0])
        ds.createDimension("dim_channel", obs_frac[1].shape[1])
        ds.createDimension("dim_time_future", obs_frac[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}
        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", obs_frac[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", obs_frac[s].shape[4])
            obs_var = ds.createVariable(
                f"obs_frac_scale{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                chunksizes=(1,)+obs_frac[s].shape[1:],
                **var_params
            )
            obs_var[:] = obs_frac[s]
            fc_var = ds.createVariable(
                f"fc_frac_scale{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                chunksizes=(1,)+fc_frac[s].shape[1:],
                **var_params
            )
            fc_var[:] = fc_frac[s]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_fractions(fn):
    """Load the fractions saved by `save_fractions_for_dataset`.

    Returns ``(obs_frac, fc_frac)``: dicts mapping pooling scale to the
    stored arrays. The scales are recovered from the variable names
    (the integer after the final "x").
    """
    obs_frac = {}
    fc_frac = {}
    with netCDF4.Dataset(fn, 'r') as ds:
        var_list = ds.variables.keys()
        # variable names end in "...{s}x{s}"; parse s from the suffix
        scales = {int(v.split("x")[-1]) for v in var_list}
        for s in scales:
            obs_frac[s] = np.array(ds[f"obs_frac_scale{s}x{s}"][:], copy=False)
            fc_frac[s] = np.array(ds[f"fc_frac_scale{s}x{s}"][:], copy=False)

    return (obs_frac, fc_frac)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def fractions_skill_score(
    obs_frac, fc_frac,
    frac_axes=None, fss_axes=None, use_timesteps=None
):
    """Fractions skill score (FSS) from exceedance-fraction fields.

    If `obs_frac`/`fc_frac` are dicts (scale -> array, as produced by
    `fractions_ensemble`), the score is computed per scale and a dict is
    returned. `frac_axes` are the axes averaged when forming the Brier
    components; the remaining axes are averaged with `fss_axes` at the
    end. `use_timesteps` truncates axis 2 (lead time) before scoring.

    Fix: `use_timesteps != None` replaced by the idiomatic
    `use_timesteps is not None` (identity test against the None
    singleton); behavior is unchanged.
    """
    if isinstance(obs_frac, dict):
        return {
            s: fractions_skill_score(
                obs_frac[s], fc_frac[s],
                frac_axes=frac_axes, fss_axes=fss_axes,
                use_timesteps=use_timesteps
            )
            for s in sorted(obs_frac)
        }

    if use_timesteps is not None:
        obs_frac = obs_frac[:,:,:use_timesteps,...]
        fc_frac = fc_frac[:,:,:use_timesteps,...]
    fbs = ((obs_frac - fc_frac)**2).mean(axis=frac_axes)
    fbs_ref = (obs_frac**2).mean(axis=frac_axes) + \
        (fc_frac**2).mean(axis=frac_axes)
    fss = 1 - fbs/fbs_ref
    if isinstance(fss, np.ndarray):
        # where both fields are identically zero fbs_ref == 0;
        # define the score as 1 (perfect) there
        fss[~np.isfinite(fss)] = 1
        fss = fss.mean(axis=fss_axes)
    return fss
|
ldcast/analysis/histogram.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy.interpolate import interp1d
|
| 8 |
+
|
| 9 |
+
from ..features.io import load_batch, decode_saved_var_to_rainrate
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def histogram(observation, forecasts, bins):
    """Per-leadtime value histograms of observations and forecasts.

    Lead time is axis 2 of both inputs. Returns ``(obs_hist, fc_hist)``,
    each of shape (len(bins)-1, n_leadtimes) with uint64 counts.
    """
    num_bins = len(bins) - 1
    num_steps = observation.shape[2]
    obs_hist = np.zeros((num_bins, num_steps), dtype=np.uint64)
    fc_hist = np.zeros((num_bins, num_steps), dtype=np.uint64)

    for t in range(num_steps):
        obs_hist[:, t] = np.histogram(observation[:, :, t, ...].ravel(), bins=bins)[0]
        fc_hist[:, t] = np.histogram(forecasts[:, :, t, ...].ravel(), bins=bins)[0]

    return (obs_hist, fc_hist)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def hist_from_file(fn, bins):
    """Load one saved batch (thresholded at the lowest bin edge) and
    return its per-leadtime histograms."""
    print(fn)
    (_, obs, fc) = load_batch(fn, threshold=bins[0])
    return histogram(obs, fc, bins)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def save_histogram_for_dataset(data_dir, result_fn, bins=(0.05,120,100)):
    """Histogram all batch files in `data_dir` and save to `result_fn`.

    `bins` is (min, max, count): `count` logarithmically spaced edges
    between min and max, with an extra 0 edge prepended so sub-threshold
    values land in the first bin. One process per CPU.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    # log-spaced bin edges, plus 0 as the lowest edge
    bins = np.exp(np.linspace(np.log(bins[0]), np.log(bins[1]), bins[2]))
    bins = np.hstack((0, bins))

    N_threads = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        futures = []
        for fn in files:
            args = (hist_from_file, fn, bins)
            futures.append(executor.submit(*args))

        # one (obs_hist, fc_hist) pair per file
        (obs_hist, fc_hist) = zip(*(f.result() for f in futures))

    # accumulate counts over all files
    obs_hist = sum(obs_hist)
    fc_hist = sum(fc_hist)

    with netCDF4.Dataset(result_fn, 'w') as ds:
        ds.createDimension("dim_bin", obs_hist.shape[0])
        ds.createDimension("dim_time_future", obs_hist.shape[1])
        var_params = {"zlib": True, "complevel": 1}

        obs_var = ds.createVariable(
            f"obs_hist", np.uint64,
            ("dim_bin", "dim_time_future"),
            **var_params
        )
        obs_var[:] = obs_hist

        fc_var = ds.createVariable(
            f"fc_hist", np.uint64,
            ("dim_bin", "dim_time_future"),
            **var_params
        )
        fc_var[:] = fc_hist

        # also store the bin edges so the file is self-describing
        ds.createDimension("dim_bin_edge", len(bins))
        bin_var = ds.createVariable(
            f"bins", np.float64,
            ("dim_bin_edge",),
            **var_params
        )
        bin_var[:] = bins
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_histogram(fn):
    """Load the histograms saved by `save_histogram_for_dataset`.

    Returns ``(obs_hist, fc_hist, bins)`` as numpy arrays.
    """
    with netCDF4.Dataset(fn, 'r') as ds:
        obs_hist = np.array(ds["obs_hist"][:], copy=False)
        fc_hist = np.array(ds["fc_hist"][:], copy=False)
        bins = np.array(ds["bins"][:], copy=False)

    return (obs_hist, fc_hist, bins)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class ProbabilityMatch:
    """Quantile mapping from the forecast distribution to the observed one.

    Built from histograms over shared bin edges: calling the object maps a
    forecast value to its cumulative probability under the forecast
    distribution, then to the observed value at that same cumulative
    probability. Values outside the bin range are linearly extrapolated.
    """
    def __init__(self, obs_hist, fc_hist, bins):
        obs_cum = np.cumsum(obs_hist)
        obs_cum = obs_cum / obs_cum[-1]
        fc_cum = np.cumsum(fc_hist)
        fc_cum = fc_cum / fc_cum[-1]

        # cumulative probability -> observed value (inverse observed CDF)
        self.obs_cdf = interp1d(np.hstack((0, obs_cum)), bins, fill_value='extrapolate')
        # forecast value -> cumulative probability (forecast CDF)
        self.fc_cdf = interp1d(bins, np.hstack((0, fc_cum)), fill_value='extrapolate')

    def __call__(self, x):
        return self.obs_cdf(self.fc_cdf(x))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def probability_match_timesteps(obs_hist, fc_hist, bins):
    """Build one ProbabilityMatch per lead time (axis 1 of the
    histograms); returns them as a list."""
    matchers = []
    for t in range(obs_hist.shape[1]):
        matchers.append(ProbabilityMatch(obs_hist[:, t], fc_hist[:, t], bins))
    return matchers
|
ldcast/analysis/rank.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import concurrent
|
| 3 |
+
import multiprocessing
|
| 4 |
+
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ..features.io import load_batch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def ranks_ensemble(
    observation, forecasts,
    noise_scale=1e-6, rng=None
):
    """Rank of each observation within its forecast ensemble.

    `forecasts` has one extra trailing (ensemble) axis. The rank is the
    number of members <= the observation, in 0..n_members. Tiny uniform
    noise is added to both fields to break ties randomly. Work is split
    over one thread per CPU writing disjoint slices of a shared output.
    `rng` must expose a numpy-legacy `.rand()` (RandomState or the
    np.random module — a Generator has no `.rand`).
    """
    shape = observation.shape
    N = np.prod(shape)
    shape_flat = (np.prod(shape),)  # NOTE(review): unused
    observation = observation.reshape((N,))
    forecasts = forecasts.reshape((N, forecasts.shape[-1]))
    N_threads = multiprocessing.cpu_count()

    max_rank = forecasts.shape[-1]
    bins = np.arange(-0.5, max_rank+0.6)  # NOTE(review): unused
    ranks_all = np.zeros_like(observation, dtype=np.uint32)

    if rng is None:
        rng = np.random

    def rank_dist_chunk(k):
        # process flat indices [i0, i1); chunks partition [0, N) exactly
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        obs = observation[i0:i1].astype(np.float64, copy=True)
        fc = forecasts[i0:i1,:].astype(np.float64, copy=True)

        # add a tiny amount of noise to forecast to randomize ties
        # (important to add to both obs and fc!)
        obs += (rng.rand(*obs.shape) - 0.5) * noise_scale
        fc += (rng.rand(*fc.shape) - 0.5) * noise_scale

        ranks = np.count_nonzero(obs[...,None] >= fc, axis=-1)
        # each thread writes only its own disjoint slice
        ranks_all[i0:i1] = ranks

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {}
        for k in range(N_threads):
            args = (rank_dist_chunk, k)
            futures[executor.submit(*args)] = k
        concurrent.futures.wait(futures)

    return ranks_all.reshape(shape)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def ranks_multiscale(observation, forecasts):
    """Observation ranks at successively coarser spatial resolutions.

    The last two spatial axes are reduced by 2x2 average pooling until
    the width reaches 1. Returns a dict mapping pooling scale
    (1, 2, 4, ...) to the rank array at that scale.
    """
    obs, fc = observation, forecasts
    rank_scales = {}
    scale = 1
    while True:
        rank_scales[scale] = ranks_ensemble(obs, fc)
        scale *= 2
        if obs.shape[-1] == 1:
            break
        # 2x2 average pooling; the forecast keeps its trailing ensemble axis
        obs = 0.25 * (obs[..., ::2, ::2] + obs[..., 1::2, ::2]
                      + obs[..., ::2, 1::2] + obs[..., 1::2, 1::2])
        fc = 0.25 * (fc[..., ::2, ::2, :] + fc[..., 1::2, ::2, :]
                     + fc[..., ::2, 1::2, :] + fc[..., 1::2, 1::2, :])
    return rank_scales
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def rank_distribution(ranks, num_forecasts=32):
    """Normalized rank histogram over bins 0..num_forecasts.

    The histogram is computed in per-CPU thread chunks (each thread fills
    its own slot of a shared list) and then summed and normalized to a
    probability distribution.
    """
    N = np.prod(ranks.shape)
    # integer-centered bins covering ranks 0..num_forecasts
    bins = np.arange(-0.5, num_forecasts+0.6)
    N_threads = multiprocessing.cpu_count()
    ranks = ranks.ravel()

    hist = [None] * N_threads
    def hist_chunk(k):
        # histogram flat indices [i0, i1); chunks partition [0, N) exactly
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        (h, _) = np.histogram(ranks[i0:i1], bins=bins)
        hist[k] = h

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {}
        for k in range(N_threads):
            args = (hist_chunk, k)
            futures[executor.submit(*args)] = k
        concurrent.futures.wait(futures)

    hist = sum(hist)
    return hist / hist.sum()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def rank_KS(rank_dist, num_forecasts=32):
    """Kolmogorov-Smirnov distance between the rank histogram and a
    uniform distribution (max deviation of the cumulative curves)."""
    freq = rank_dist / rank_dist.sum()
    empirical = np.cumsum(freq)
    uniform = np.linspace(0, 1, len(empirical))
    return np.abs(empirical - uniform).max()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def rank_DKL(rank_dist, num_forecasts=32):
    """KL divergence D(uniform || empirical) of the rank histogram;
    zero iff the histogram is exactly uniform."""
    q = rank_dist / rank_dist.sum()
    p = 1 / len(rank_dist)
    return p * np.log(p / q).sum()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def rank_metric_by_leadtime(ranks, metric=None, num_forecasts=32):
    """Evaluate a rank-histogram metric (default `rank_DKL`) separately
    for every lead time (axis 2 of `ranks`)."""
    if metric is None:
        metric = rank_DKL

    results = []
    for t in range(ranks.shape[2]):
        dist = rank_distribution(ranks[:, :, t, ...])
        results.append(metric(dist, num_forecasts=num_forecasts))
    return np.array(results)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def rank_metric_by_bin(ranks, values, bins, metric=None, num_forecasts=32):
    """Evaluate a rank-histogram metric (default `rank_DKL`) separately
    for each bin of `values` (half-open intervals [b0, b1))."""
    if metric is None:
        metric = rank_DKL

    results = []
    for (lo, hi) in zip(bins[:-1], bins[1:]):
        in_bin = (lo <= values) & (values < hi)
        dist = rank_distribution(ranks[in_bin])
        results.append(metric(dist, num_forecasts=num_forecasts))
    return np.array(results)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def process_batch(fn, preproc_fc=None):
    """Load one saved batch file and return its multiscale ranks."""
    print(fn)
    (_, obs, fc) = load_batch(fn, preproc_fc=preproc_fc)
    return ranks_multiscale(obs, fc)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def save_ranks_for_dataset(data_dir, result_fn, preproc_fc=None):
    """Compute multiscale observation ranks for every batch file in
    `data_dir` and save them to the netCDF file `result_fn`.

    One process per CPU; `preproc_fc` must be picklable. One variable per
    pooling scale (ranks_pool{s}x{s}) is written, chunked one sample at
    a time. Ranks are integers but stored as float32.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    futures = []
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        for fn in files:
            args = (process_batch, fn)
            kwargs = {"preproc_fc": preproc_fc}
            futures.append(executor.submit(*args, **kwargs))

    # results arrive in file order; each is a {scale: array} dict
    ranks = [f.result() for f in futures]
    scales = sorted(ranks[0].keys())
    ranks = {
        s: np.concatenate([r[s] for r in ranks], axis=0)
        for s in scales
    }

    with netCDF4.Dataset(result_fn, 'w') as ds:
        # shared leading dimensions taken from the native-scale (1) array
        ds.createDimension("dim_sample", ranks[1].shape[0])
        ds.createDimension("dim_channel", ranks[1].shape[1])
        ds.createDimension("dim_time_future", ranks[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}

        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", ranks[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", ranks[s].shape[4])
            var = ds.createVariable(
                f"ranks_pool{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                chunksizes=(1,)+ranks[s].shape[1:],
                **var_params
            )
            var[:] = ranks[s]
|
ldcast/features/.sampling.py.swp
ADDED
|
File without changes
|
ldcast/features/batch.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
from numba import njit, prange, types
|
| 6 |
+
from numba.typed import Dict
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 9 |
+
|
| 10 |
+
from .patches import unpack_patches
|
| 11 |
+
from .sampling import EqualFrequencySampler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BatchGenerator:
    """Assemble (predictor, target) training batches from patch-indexed raw data.

    Each batch is built by drawing sample coordinates (time, row, column)
    from a sampler, reading the corresponding boxes of patches for every
    raw data source, and transforming them into model variables.

    Parameters
    ----------
    variables : dict
        Maps variable name to a spec dict with keys "sources" (raw data
        names), "timesteps" (array of relative timestep indices) and
        "transform" (callable mapping raw arrays to the model variable);
        optionally "timestep_secs" for a non-default time resolution.
    raw : dict
        Maps raw data name to packed patch data (see patches.unpack_patches).
    predictors : list of str
        Names of variables used as model inputs.
    target : str
        Name of the variable used as the model target.
    primary_var : str
        Variable whose first raw source is used to build the sampler.
    time_range_sampling : tuple of int, optional
        Sampling time range (in intervals) passed to the sampler.
    forecast_raw_vars : iterable of str, optional
        Raw variables stored per forecast lag under names "<base>-<lag_hours>".
    sampling_bins : optional
        Bin specification for EqualFrequencySampler.
    sampler_file : str, optional
        Path for caching the sampler with pickle (building one is expensive).
    sample_shape : tuple of int, optional
        Sample size expressed in patches (rows, cols).
    batch_size : int, optional
        Default number of samples per batch.
    interval : datetime.timedelta, optional
        Base time resolution of the data.
    random_seed : int, optional
        Seed for the augmentation RNG.
    augment : bool, optional
        If True, apply a random transpose/flip augmentation to each batch.
    """

    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # set up indices for retrieving source raw data
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in forecast_raw_vars:
                # forecast sources are stored per lag as "<base>-<lag>"
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                raw_data = raw[raw_name]
                self.setup_index(raw_name, raw_data, sample_shape)

        # hide the per-lag forecast indices behind one lag-aware wrapper
        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # set up the sampler, rebuilding it unless a cached copy exists
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # find the time range (in base intervals) spanned by all variables
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0,-1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            # NOTE: pickle.load is only safe on trusted cache files
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Build a PatchIndex for one raw data source and register it."""
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) flags for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the given transpose/flip augmentation to the last two axes.

        Always returns a copy; a plain copy when self.augment is False.
        """
        if self.augment:
            if transpose:
                # swap the last two (spatial) axes
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Build one batch.

        Parameters
        ----------
        samples : ndarray, optional
            Precomputed (t, i, j) sample coordinates, one row per sample;
            drawn from the sampler when omitted.
        batch_size : int, optional
            Overrides the default batch size when sampling.

        Returns
        -------
        tuple
            (pred_batch, target_batch) where pred_batch is a list of
            (array, relative_times) pairs, one per predictor, and
            target_batch is the target array without time coordinates.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            # get the sample coordinates from the sampler
            samples = self.sampler(batch_size)

        (t0,i0,j0) = samples.T

        if self.augment:
            # one augmentation per batch, shared by all variables so that
            # predictors and target stay spatially consistent
            augmentations = self.augmentations()

        batch = {}
        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # different timestep from standard (e.g. forecast); round down
            # to times where we have data available
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            # time offsets relative to the sample time, in interval units
            t_relative = (t - t0[:,None]) / self.interval_secs

            # read raw data from index
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )

            # transform to model variable
            batch_var = var_data["transform"](*raw_data)

            # add channel dimension if not already present
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            # data augmentation
            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            # bundle with time coordinates
            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0] # no time coordinates for target
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield `num` batches, or batches indefinitely when num is None."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class StreamBatchDataset(IterableDataset):
    """Iterable-style dataset that streams freshly sampled batches.

    Each pass over the dataset yields `batches_per_epoch` batches
    produced by the wrapped batch generator.
    """

    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        # delegate directly to the generator's batch stream
        return iter(self.batch_gen.batches(num=self.batches_per_epoch))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class DeterministicBatchDataset(Dataset):
    """Map-style dataset whose batch sample coordinates are fixed at init.

    The batch generator's sampler is reseeded with `random_seed` and all
    sample coordinates are drawn up front, so __getitem__(i) always
    rebuilds the same batch — useful for reproducible validation.
    """

    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # reseed so the precomputed sample coordinates are reproducible
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        # fixed defect: removed stray debug print of the sample coordinates
        return self.batch_gen.batch(samples=self.samples[ind])
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class PatchIndex:
    """Random access to patched raw data by (time, row, col) box queries.

    Patches are stored densely in `patch_data`; a numba typed dict maps
    (t, i, j) patch coordinates to the index of the patch in that array.
    Two sentinel indices mark patches known to be all zero and patches
    with no data, which are expanded to constant fills on read.
    """

    IDX_ZERO = -1     # sentinel: patch exists but is entirely zero_value
    IDX_MISSING = -2  # sentinel: no data available for this patch

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # sample size in pixels: patches per box times pixels per patch
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )

        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        # lazily allocated, reused output buffer (see _alloc_batch)
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        """Return a reusable output buffer of at least the requested size."""
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            # release the old buffer before allocating the larger one
            del self._batch
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble samples for times `t` (samples x timesteps) with box
        origins (i0_all, j0_all) given in patch coordinates.

        Returns a view into an internal buffer that is overwritten by the
        next call; copy it if it must persist.
        """
        batch = self._alloc_batch(*t.shape)

        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        # fixed defect: slice BOTH leading dims — the cached buffer may be
        # larger than the current request, and rows k >= t.shape[0] would
        # otherwise contain stale data from a previous call
        return batch[:t.shape[0],:t.shape[1],...]
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    # Populate the numba typed dict mapping (time, row, col) -> index of
    # the patch in the patch_data array.
    for k in range(patch_coords.shape[0]):
        t = patch_times[k]
        i = np.int64(patch_coords[k,0])
        j = np.int64(patch_coords[k,1])
        patch_index[(t,i,j)] = k
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
    zero_patch_times, idx_zero):
    # Mark patches known to contain only the zero value with the sentinel
    # index idx_zero; they are not stored in patch_data and are expanded
    # to a constant fill when a batch is built.
    for k in range(zero_patch_coords.shape[0]):
        t = zero_patch_times[k]
        i = np.int64(zero_patch_coords[k,0])
        j = np.int64(zero_patch_coords[k,1])
        patch_index[(t,i,j)] = idx_zero
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# numba can't find these values from PatchIndex
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING

@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    # Fill `batch` (samples x timesteps x H x W) by tiling stored patches.
    # Samples are processed in parallel; each patch of the
    # [i0,i1) x [j0,j1) box is looked up by (time, row, col) and copied
    # into its slot, with sentinel indices expanded to constant fills.
    for k in prange(t_all.shape[0]):
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]

        for (bt,t) in enumerate(t_all[k,:]):
            for i in range(i0, i1):
                # pixel row range covered by patch row i
                bi0 = (i-i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    # unknown coordinates default to the missing sentinel
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    bj0 = (j-j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k,bt,bi0:bi1,bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k,bt,bi0:bi1,bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k,bt,bi0:bi1,bj0:bj1] = missing_value
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class ForecastPatchIndexWrapper(PatchIndex):
    """Present several per-lag forecast PatchIndex objects as one index.

    The wrapped indices must be named "<base>-<lag_hours>" with evenly
    spaced lags.  For each requested timestep, data are read from the
    lag whose forecast keeps the whole sample within the same forecast
    run.
    """

    def __init__(self, patch_index):
        # NOTE: deliberately does not call PatchIndex.__init__; only the
        # attributes that the inherited _alloc_batch needs are set below.
        self.patch_index = patch_index
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        # NOTE(review): a single lag raises "Lags must be evenly spaced"
        # because np.diff of one element is empty — confirm this is intended
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        self.forecast_interval = 3600 * forecast_interval_hour

        # need to set these for _alloc_batch to work
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble samples for times `t` (samples x timesteps), reading
        each timestep from the appropriate forecast lag."""
        batch = self._alloc_batch(*t.shape)

        # ensure that all data come from the same forecast: compute, per
        # timestep, the lag (in hours) whose forecast run covers it
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour

        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)

        # fixed defect: slice BOTH leading dims so stale rows from a larger
        # cached buffer are never returned (mirrors PatchIndex.__call__)
        return batch[:t.shape[0],:t.shape[1],...]
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    # Copy the (sample, timestep) slices selected by `mask`
    # (samples x timesteps booleans) from one batch buffer into another.
    for k in prange(from_batch.shape[0]):
        for bt in range(from_batch.shape[1]):
            if mask[k,bt]:
                to_batch[k,bt,:,:] = from_batch[k,bt,:,:]
|
ldcast/features/batch.py.save
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
from numba import njit, prange, types
|
| 6 |
+
from numba.typed import Dict
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 9 |
+
|
| 10 |
+
from .patches import unpack_patches
|
| 11 |
+
from .sampling import EqualFrequencySampler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# NOTE(review): batch.py.save appears to be an editor backup that duplicates
# ldcast/features/batch.py — consider removing it from the repository.
class BatchGenerator:
    """Assemble (predictor, target) training batches from patch-indexed
    raw data, drawing sample coordinates from a (possibly cached) sampler."""

    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # setup indices for retrieving source raw data
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in forecast_raw_vars:
                # forecast sources are stored per lag as "<base>-<lag>"
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                raw_data = raw[raw_name]
                self.setup_index(raw_name, raw_data, sample_shape)

        # hide the per-lag forecast indices behind one lag-aware wrapper
        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # setup samplers
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # find the time range (in base intervals) spanned by all variables
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0,-1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            # NOTE: pickle.load is only safe on trusted cache files
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Build a PatchIndex for one raw data source and register it."""
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) flags for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the given transpose/flip augmentation to the last two axes;
        always returns a copy."""
        if self.augment:
            if transpose:
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Build one (pred_batch, target_batch) pair, sampling coordinates
        unless `samples` is given."""
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            # get the sample coordinates from the sampler
            samples = self.sampler(batch_size)

        # NOTE(review): debug print left in — remove for production use
        print(samples)
        (t0,i0,j0) = samples.T

        if self.augment:
            # one augmentation per batch, shared by all variables
            augmentations = self.augmentations()

        batch = {}

        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # different timestep from standard (e.g. forecast); round down
            # to times where we have data available
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            t_relative = (t - t0[:,None]) / self.interval_secs

            # read raw data from index
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )

            # transform to model variable
            batch_var = var_data["transform"](*raw_data)

            # add channel dimension if not already present
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            # data augmentation
            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            # bundle with time coordinates
            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0] # no time coordinates for target
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield `num` batches, or batches indefinitely when num is None."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class StreamBatchDataset(IterableDataset):
    """Iterable-style dataset streaming `batches_per_epoch` freshly
    sampled batches per pass from the wrapped batch generator."""

    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        batches = self.batch_gen.batches(num=self.batches_per_epoch)
        yield from batches
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class DeterministicBatchDataset(Dataset):
    """Map-style dataset whose sample coordinates are fixed at init time,
    so __getitem__(i) always rebuilds the same batch."""

    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # reseed so the precomputed sample coordinates are reproducible
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        # NOTE(review): debug print left in — remove for production use
        print(self.samples[ind])
        return self.batch_gen.batch(samples=self.samples[ind])
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class PatchIndex:
    """Random access to patched raw data by (time, row, col) box queries;
    a numba typed dict maps coordinates to indices into `patch_data`,
    with sentinel indices for all-zero and missing patches."""

    IDX_ZERO = -1     # sentinel: patch exists but is entirely zero_value
    IDX_MISSING = -2  # sentinel: no data available for this patch

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # sample size in pixels: patches per box times pixels per patch
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )

        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        # lazily allocated, reused output buffer (see _alloc_batch)
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        # Return a reusable output buffer of at least the requested size.
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            del self._batch
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble samples for times `t` (samples x timesteps) with box
        origins (i0_all, j0_all) in patch coordinates.  Returns a view
        into an internal buffer overwritten by the next call."""
        batch = self._alloc_batch(*t.shape)

        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        # NOTE(review): dim 0 is not sliced — if the cached buffer is
        # larger than t.shape[0], stale rows are returned; verify callers
        return batch[:,:t.shape[1],...]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    # Populate the numba typed dict mapping (time, row, col) -> index of
    # the patch in the patch_data array.
    for k in range(patch_coords.shape[0]):
        t = patch_times[k]
        i = np.int64(patch_coords[k,0])
        j = np.int64(patch_coords[k,1])
        patch_index[(t,i,j)] = k
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
    zero_patch_times, idx_zero):
    # Mark patches known to contain only the zero value with the sentinel
    # index idx_zero; they are not stored in patch_data.
    for k in range(zero_patch_coords.shape[0]):
        t = zero_patch_times[k]
        i = np.int64(zero_patch_coords[k,0])
        j = np.int64(zero_patch_coords[k,1])
        patch_index[(t,i,j)] = idx_zero
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# numba can't find these values from PatchIndex
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING

@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    # Fill `batch` (samples x timesteps x H x W) by tiling stored patches;
    # samples are processed in parallel, sentinel indices become fills.
    for k in prange(t_all.shape[0]):
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]

        for (bt,t) in enumerate(t_all[k,:]):
            for i in range(i0, i1):
                bi0 = (i-i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    # unknown coordinates default to the missing sentinel
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    bj0 = (j-j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k,bt,bi0:bi1,bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k,bt,bi0:bi1,bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k,bt,bi0:bi1,bj0:bj1] = missing_value
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class ForecastPatchIndexWrapper(PatchIndex):
    """Present several per-lag forecast PatchIndex objects (named
    "<base>-<lag_hours>", evenly spaced lags) as a single index."""

    def __init__(self, patch_index):
        # NOTE: deliberately does not call PatchIndex.__init__; only the
        # attributes needed by the inherited _alloc_batch are set below.
        self.patch_index = patch_index
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        # NOTE(review): a single lag raises "Lags must be evenly spaced"
        # (np.diff of one element is empty) — confirm this is intended
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        self.forecast_interval = 3600 * forecast_interval_hour

        # need to set these for _alloc_batch to work
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble samples for times `t`, reading each timestep from the
        forecast lag that keeps the sample within one forecast run."""
        batch = self._alloc_batch(*t.shape)

        # ensure that all data come from the same forecast
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour

        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)

        # NOTE(review): dim 0 is not sliced — stale rows can be returned
        # when the cached buffer is larger than t.shape[0]; verify callers
        return batch[:,:t.shape[1],...]
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    # Copy the (sample, timestep) slices selected by `mask`
    # (samples x timesteps booleans) from one batch buffer into another.
    for k in prange(from_batch.shape[0]):
        for bt in range(from_batch.shape[1]):
            if mask[k,bt]:
                to_batch[k,bt,:,:] = from_batch[k,bt,:,:]
|
ldcast/features/io.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import netCDF4
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def convert_var_for_saving(
    x, fill_value=0.02, min_value=0.05, max_value=118.428,
    mean=-0.051, std=0.528
):
    """Quantize normalized log-rain-rate values to uint16 codes.

    De-normalizes `x` with (mean, std), then maps the log10 range
    [log10(min_value), log10(max_value)] linearly onto the codes
    1..65534; values below the lower bound become code 0.
    `fill_value` is unused here and kept only for interface compatibility.
    """
    lo = np.log10(min_value)
    hi = np.log10(max_value)
    denorm = x*std + mean
    valid = (denorm >= lo)
    scaled = (denorm[valid].clip(max=hi) - lo) / (hi - lo)
    codes = np.zeros_like(x, dtype=np.uint16)
    codes[valid] = (scaled*65533).round().astype(np.uint16) + 1
    return codes
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def decode_saved_var_to_rainrate(
    x, fill_value=0.02, min_value=0.05, threshold=0.1, max_value=118.428,
    mean=-0.051, std=0.528, log=False, preproc=None
):
    """Invert convert_var_for_saving: uint16 codes -> rain rate.

    Codes 1..65534 are mapped back onto the log10 range
    [log10(min_value), log10(max_value)] and exponentiated; code 0 stays 0.
    Rates below `threshold` are set to 0 (or to `fill_value` before taking
    log10, if `log` is True). `preproc`, if given, is a per-timestep list of
    callables applied along axis 2 (assumes a 5-D sample/channel/time/h/w
    layout). `mean`/`std` are unused here; kept for interface symmetry.
    """
    lo = np.log10(min_value)
    hi = np.log10(max_value)
    coded = (x >= 1)
    step = (hi - lo) / 65533
    rate = np.zeros_like(x, dtype=np.float32)
    rate[coded] = 10 ** (lo + (x[coded].astype(np.float32) - 1) * step)

    if preproc is not None:
        per_step = [preproc[t](rate[:,:,t,...]) for t in range(rate.shape[2])]
        rate = np.stack(per_step, axis=2)

    if log:
        rate[rate < threshold] = fill_value
        return np.log10(rate)
    rate[rate < threshold] = 0.0
    return rate
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def save_batch(x, y, y_pred, batch_index, fn_template, out_dir, out_fn=None):
    """Encode a forecast batch to uint16 and write it to a NetCDF file.

    x: past observations (sample, channel, time_past, h, w) — possibly
       wrapped in nested lists/tuples, of which the first element is used.
    y: future observations (sample, channel, time_future, h, w).
    y_pred: ensemble forecasts (sample, channel, time_future, h, w, member).
    The output name is fn_template.format(batch_index=...) inside out_dir
    unless out_fn is given explicitly.
    """
    # unwrap nested list/tuple inputs down to the first array
    while isinstance(x, (list, tuple)):
        x = x[0]

    x = convert_var_for_saving(np.asarray(x))
    y = convert_var_for_saving(np.asarray(y))
    y_pred = convert_var_for_saving(np.asarray(y_pred))

    if out_fn is None:
        out_fn = fn_template.format(batch_index=batch_index)
    out_fn = os.path.join(out_dir, out_fn)

    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_sample", y.shape[0])
        ds.createDimension("dim_channel", y.shape[1])
        ds.createDimension("dim_time_past", x.shape[2])
        ds.createDimension("dim_time_future", y.shape[2])
        ds.createDimension("dim_h", y.shape[3])
        ds.createDimension("dim_w", y.shape[4])
        ds.createDimension("dim_member", y_pred.shape[5])
        comp = {"zlib": True, "complevel": 1}

        def write_var(name, data, dims):
            # create a compressed variable and fill it in one go
            var = ds.createVariable(name, data.dtype, dims, **comp)
            var[:] = data

        write_var(
            "forecasts", y_pred,
            (
                "dim_sample", "dim_channel",
                "dim_time_future", "dim_h", "dim_w", "dim_member"
            )
        )
        write_var(
            "past_observations", x,
            ("dim_sample", "dim_channel", "dim_time_past", "dim_h", "dim_w")
        )
        write_var(
            "future_observations", y,
            ("dim_sample", "dim_channel", "dim_time_future", "dim_h", "dim_w")
        )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def load_batch(fn, decode=True, preproc_fc=None, **kwargs):
    """Read a batch written by save_batch.

    If `decode`, the stored uint16 codes are converted back to rain rates;
    `preproc_fc` is forwarded as the per-timestep preproc for the forecasts
    only. Extra kwargs go to decode_saved_var_to_rainrate.
    Returns (x, y, y_pred).
    """
    with netCDF4.Dataset(fn, 'r') as ds:
        y_pred = np.asarray(ds["forecasts"][:])
        x = np.asarray(ds["past_observations"][:])
        y = np.asarray(ds["future_observations"][:])

    if decode:
        x = decode_saved_var_to_rainrate(x, **kwargs)
        y = decode_saved_var_to_rainrate(y, **kwargs)
        y_pred = decode_saved_var_to_rainrate(
            y_pred, preproc=preproc_fc, **kwargs
        )

    return (x, y, y_pred)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def load_all_observations(
    ensemble_dir, decode=True, preproc_fc=None,
    timeframe='future', **kwargs
):
    """Load and concatenate observations from every file in ensemble_dir.

    Parameters
    ----------
    ensemble_dir : directory containing NetCDF batch files written by
        save_batch; files are read in sorted name order.
    decode : if True, decode stored uint16 codes to rain rates.
        Bug fix: this flag was previously accepted but ignored — decoding
        always happened regardless of its value.
    preproc_fc : unused for observations; kept for signature symmetry
        with load_batch.
    timeframe : 'future' or 'past' — which observation variable to read.
    kwargs : forwarded to decode_saved_var_to_rainrate.

    Returns
    -------
    np.ndarray concatenated along the sample axis (axis 0).
    """
    var = f"{timeframe}_observations"
    obs = []
    for fn in sorted(os.listdir(ensemble_dir)):
        with netCDF4.Dataset(os.path.join(ensemble_dir, fn), 'r') as ds:
            x = np.array(ds[var][:], copy=False)
        if decode:  # honor the decode flag (previously ignored)
            x = decode_saved_var_to_rainrate(x, **kwargs)
        obs.append(x)

    return np.concatenate(obs, axis=0)
|
ldcast/features/patches.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import dask
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from .utils import average_pool
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def patch_locations(
    time_range,
    patch_box,
    patch_shape=(32,32),
    interval=timedelta(minutes=5),
    epoch=(1970,1,1)
):
    """Enumerate all patch (i, j) coordinates for each time step.

    Returns a dict mapping each datetime in [time_range[0], time_range[1])
    (stepping by `interval`) to an array of shape (num_patches, 2) listing
    every (row, col) patch index inside `patch_box`, row-major.
    `patch_shape` and `epoch` are unused here; kept for interface
    compatibility with callers.
    """
    (i0, i1), (j0, j1) = patch_box
    coords = np.array(
        [(pi, pj) for pi in range(i0, i1) for pj in range(j0, j1)]
    )

    locations = {}
    t = time_range[0]
    while t < time_range[1]:
        # each time step gets its own independent array
        locations[t] = coords.copy()
        t += interval

    return locations
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_patches_radar(
    patches, archive_path, out_dir,
    variables=("RZC", "CPCH"),
    suffix="2020",
    **kwargs
):
    """Extract and save patches for MeteoSwiss radar variables.

    Uses raw (non-physical) values; for RZC/CPCH the raw code 1 encodes
    zero rain, hence zero_value=1 and the "> 1" non-zero test.
    """
    from ..datasets import mchradar

    reader = mchradar.MCHRadarReader(
        archive_path=archive_path,
        variables=variables,
        phys_values=False
    )

    count_above_one = lambda x: np.count_nonzero(x > 1)
    nonzero_count_func = {
        "RZC": count_above_one,
        "CPCH": count_above_one,
    }
    zero_value = {v: 0 for v in variables}
    zero_value["RZC"] = 1
    zero_value["CPCH"] = 1

    save_patches_all(
        reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_patches_dwdradar(
    patches, archive_path, out_dir,
    variables=("RV",),
    suffix="2022",
    patch_shape=(32,32),
    **kwargs
):
    """Extract and save patches for DWD 'RV' radar data.

    Before saving, drops times whose data files are unavailable and, per
    time step, keeps only patches with all-finite values.
    """
    from ..datasets import dwdradar

    reader = dwdradar.DWDRadarReader(
        archive_path=archive_path,
        variables=variables
    )

    (ph, pw) = patch_shape
    valid_patches = {}
    for t in sorted(patches):
        if (t.hour == 0) and (t.minute == 0):
            print(t)  # progress marker: one line per day

        try:
            data = reader.variable_for_time(t, "RV")
        except FileNotFoundError:
            continue  # no data for this time step: skip it entirely

        finite_locs = []
        for (pi, pj) in patches[t]:
            window = data[pi*ph:(pi+1)*ph, pj*pw:(pj+1)*pw]
            if np.isfinite(window).all():
                finite_locs.append((pi, pj))

        if finite_locs:
            valid_patches[t] = np.array(finite_locs)

    print(len(patches), len(valid_patches))

    save_patches_all(
        reader, valid_patches, variables,
        {"RV": np.count_nonzero}, {v: 0 for v in variables},
        out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def save_patches_ifs(
    patches, archive_path, out_dir,
    variables=(
        #"rate-tp", "rate-cp", "t2m", "cape", "cin",
        #"tclw", "tcwv", #, "rate-tpe"
        #"u", "v"
        "cin",
    ),
    suffix="2020",
    lags=(0,12)
):
    """Extract and save patches of IFS NWP fields, one set per forecast lag.

    Fields are stored under "<name>-<lag>" keys and average-pooled by a
    factor of 8 before saving. Only on-the-hour times are used, since IFS
    data exist only at full hours.
    """
    from ..datasets import ifsnwp
    from .. import projection

    proj = projection.GridProjection(projection.ccs4_swiss_grid_area)
    reader = ifsnwp.IFSNWPReader(
        proj,
        archive_path=archive_path,
        variables=variables,
        lags=lags,
    )

    # keep only on-the-hour time steps
    hourly_patches = {
        t: locs for (t, locs) in patches.items()
        if t.minute == 0 and t.second == 0 and t.microsecond == 0
    }

    # one "<var>-<lag>" entry per (lag, variable), lag-major order
    lagged_vars = [
        f"{v}-{lag}" for lag in reader.lags for v in variables
    ]

    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    base_count_func = {
        "rate-tp": count_positive,
        "rate-cp": count_positive,
        "t2m": all_nonzero,
        "cape": count_positive,
        "cin": count_positive,
        "tclw": count_positive,
        "tcwv": count_positive,
        "u": all_nonzero,
        "v": all_nonzero,
        "rate-tpe": count_positive,
    }
    nonzero_count_func = {
        v: base_count_func[v.rsplit("-", 1)[0]] for v in lagged_vars
    }
    # replace NaN (CIN undefined) with 0 before patch extraction
    postproc = {
        f"cin-{lag}": lambda x: np.nan_to_num(x, nan=0.0, copy=False)
        for lag in lags
    }
    zero_value = {v: 0 for v in lagged_vars}
    avg_pool = lambda x: average_pool(x, factor=8, missing=np.nan)
    pool = {v: avg_pool for v in lagged_vars}

    save_patches_all(
        reader, hourly_patches, lagged_vars,
        nonzero_count_func, zero_value, out_dir, suffix, pool=pool,
        postproc=postproc
    )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def save_patches_cosmo(patches, archive_path, out_dir, suffix="2020"):
    """Extract and save patches of COSMO NWP fields.

    Each time t is mapped to the two surrounding full hours (floor(t) and
    floor(t)+1h), since COSMO output is hourly; the patch sets are merged
    per hour.

    Parameters
    ----------
    patches : dict mapping datetime -> iterable of (i, j) patch coordinates.
    archive_path : path to the COSMO archive.
    out_dir : output directory for patch files.
    suffix : filename suffix (typically the year).
    """
    from ..datasets import cosmonwp

    cosmonwp_reader = cosmonwp.COSMOCCS4Reader(
        archive_path=archive_path, cache_size=6000)

    # we only get data for every hour, so modify patches
    cosmo_patches = {}
    for (dt, pset) in patches.items():
        dt0 = datetime(dt.year, dt.month, dt.day, dt.hour)
        dt1 = dt0 + timedelta(hours=1)
        cosmo_patches.setdefault(dt0, set()).update(pset)
        cosmo_patches.setdefault(dt1, set()).update(pset)

    variables = [
        "CAPE_MU", "CIN_MU", "SLI",
        "HZEROCL", "LCL_ML", "MCONV", "OMEGA",
        "T_2M", "T_SO", "SOILTYP"
    ]
    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "CAPE_MU": count_positive,
        "CIN_MU": count_positive,
        "SLI": all_nonzero,
        "HZEROCL": count_positive,
        "LCL_ML": count_positive,
        "MCONV": all_nonzero,
        "OMEGA": all_nonzero,
        "T_2M": all_nonzero,
        "T_SO": all_nonzero,
        # soil type 5 is treated as the "zero"/default category
        "SOILTYP": lambda x: np.count_nonzero(x != 5)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["SOILTYP"] = 5

    # Bug fix: the original passed `pool=pool` although no `pool` was
    # defined in this scope, raising NameError at call time. COSMO patches
    # are saved unpooled (the default).
    save_patches_all(cosmonwp_reader, cosmo_patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def save_patches_all(
    reader, patches, variables, nonzero_count_func, zero_value,
    out_dir, suffix, epoch=datetime(1970,1,1), postproc={}, scale=None,
    pool={}, source_vars={}, parallel=False, min_nonzeros_to_include=1
):
    """Extract patches for each variable with get_patches and save them.

    Parameters
    ----------
    reader : data reader providing variable_for_time (and, optionally,
        get_scale).
    patches : dict mapping datetime -> patch coordinates.
    variables : variable names to process; each goes to
        <out_dir>/<var>/patches_<var>_<suffix>.nc (underscores in the
        variable name are replaced by dashes in the path).
    nonzero_count_func : dict var -> callable counting "non-zero" pixels.
    zero_value : dict var -> the value that represents zero for that
        variable (keyed by the ORIGINAL variable name).
    epoch : reference time for the stored integer timestamps.
    postproc, pool : optional per-variable callables applied to the whole
        field / each kept patch. (Dict defaults are read-only, never
        mutated.)
    scale : optional dict var -> scale array, used when the reader has none.
    parallel : if True, process variables concurrently via dask threads.
    min_nonzeros_to_include : patches with fewer non-zero pixels are stored
        as "zero patches" (coordinates/times only).
    """

    def save_var(var_name):
        # extract all patches for one variable and write them to disk
        src_name = source_vars.get(var_name, var_name)

        (patch_data, patch_coords, patch_times,
         zero_patch_coords, zero_patch_times) = get_patches(
            reader, src_name, patches,
            nonzero_count_func=nonzero_count_func[var_name],
            postproc=postproc.get(var_name),
            pool=pool.get(var_name),
            # bug fix: this threshold was accepted but never forwarded,
            # so get_patches silently used its default of 1
            min_nonzeros_to_include=min_nonzeros_to_include
        )
        try:
            time = epoch + timedelta(seconds=int(patch_times[0]))
            var_scale = reader.get_scale(time, var_name)
        except (AttributeError, KeyError):
            # reader has no scale for this variable; fall back to the
            # explicitly provided scale table, if any
            var_scale = None if (scale is None) else scale[var_name]

        # bug fix: look up zero_value BEFORE renaming — its keys use the
        # original (underscore) variable names, so looking it up after the
        # replace() below raised KeyError for names like "CAPE_MU"
        var_zero = zero_value[var_name]

        var_name = var_name.replace("_", "-")
        out_fn = f"patches_{var_name}_{suffix}.nc"
        out_path = os.path.join(out_dir, var_name)
        os.makedirs(out_path, exist_ok=True)
        out_fn = os.path.join(out_path, out_fn)

        save_patches(
            patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times, out_fn,
            zero_value=var_zero, scale=var_scale
        )

    if parallel:
        save_var = dask.delayed(save_var)

    jobs = [save_var(v) for v in variables]
    if parallel:
        dask.compute(jobs, scheduler='threads')
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def get_patches(
    reader, variable, patches,
    patch_shape=(32,32), nonzero_count_func=None,
    epoch=datetime(1970,1,1), postproc=None,
    pool=None, min_nonzeros_to_include=1
):
    """Extract fixed-size patches of `variable` at the requested times.

    Patches whose non-zero pixel count (per nonzero_count_func) falls below
    min_nonzeros_to_include are recorded by coordinates/time only ("zero
    patches") instead of storing their data. Times whose data cannot be
    read are skipped. Timestamps are stored as int64 seconds since `epoch`.

    Returns (patch_data, patch_coords, patch_times,
             zero_patch_coords, zero_patch_times).
    """
    total = sum(len(locs) for locs in patches.values())
    (ph, pw) = patch_shape
    patch_data = []
    patch_coords = []
    patch_times = []
    zero_patch_coords = []
    zero_patch_times = []

    # temporarily force raw (non-physical) values if the reader supports it
    has_phys = hasattr(reader, "phys_values")
    if has_phys:
        saved_phys_values = reader.phys_values

    k = 0
    try:
        if has_phys:
            reader.phys_values = False
        for (t, locs) in patches.items():
            try:
                field = reader.variable_for_time(t, variable)
            except (ValueError, FileNotFoundError, KeyError, OSError):
                continue  # data unavailable for this time step

            if postproc is not None:
                field = postproc(field)

            time_sec = np.int64((t-epoch).total_seconds())
            for (pi, pj) in locs:
                if k % 100000 == 0:
                    print("{}: {}/{}".format(t, k, total))
                window = field[pi*ph:(pi+1)*ph, pj*pw:(pj+1)*pw].copy()
                # NOTE: the original called this flag `is_nonzero`, but it
                # actually marks patches too empty to keep
                is_zero_patch = (nonzero_count_func is not None) and \
                    (nonzero_count_func(window) < min_nonzeros_to_include)
                if is_zero_patch:
                    zero_patch_coords.append((pi,pj))
                    zero_patch_times.append(time_sec)
                else:
                    if pool is not None:
                        window = pool(window)
                    patch_data.append(window)
                    patch_coords.append((pi,pj))
                    patch_times.append(time_sec)
                k += 1

    finally:
        # restore the reader's original setting even on error
        if has_phys:
            reader.phys_values = saved_phys_values

    if zero_patch_coords:
        zero_patch_coords = np.stack(zero_patch_coords, axis=0).astype(np.uint16)
        zero_patch_times = np.stack(zero_patch_times, axis=0)
    else:
        zero_patch_coords = np.zeros((0,2), dtype=np.uint16)
        zero_patch_times = np.zeros((0,), dtype=np.int64)
    patch_data = np.stack(patch_data, axis=0)
    patch_coords = np.stack(patch_coords, axis=0).astype(np.uint16)
    patch_times = np.stack(patch_times, axis=0)

    return (patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def save_patches(patch_data, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times, out_fn, zero_value=0, scale=None):
    """Write extracted patches and their metadata to a NetCDF file.

    Stores the patch arrays, their (i, j) coordinates and timestamps, the
    coordinates/times of "zero" patches, the value representing zero for
    this variable (as a file attribute), and an optional scale array.
    """
    comp = {"zlib": True, "complevel": 1}
    (n, h, w) = patch_data.shape

    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_patch", n)
        ds.createDimension("dim_zero_patch", zero_patch_coords.shape[0])
        ds.createDimension("dim_coord", 2)
        ds.createDimension("dim_height", h)
        ds.createDimension("dim_width", w)

        # chunk along the patch dimension for efficient per-patch reads
        chunks = (min(2**10, n), h, w)
        v = ds.createVariable("patches", patch_data.dtype,
            ("dim_patch","dim_height","dim_width"),
            chunksizes=chunks, **comp)
        v[:] = patch_data

        v = ds.createVariable("patch_coords", patch_coords.dtype,
            ("dim_patch","dim_coord"), **comp)
        v[:] = patch_coords

        v = ds.createVariable("patch_times", patch_times.dtype,
            ("dim_patch",), **comp)
        v[:] = patch_times

        v = ds.createVariable("zero_patch_coords", zero_patch_coords.dtype,
            ("dim_zero_patch","dim_coord"), **comp)
        v[:] = zero_patch_coords

        v = ds.createVariable("zero_patch_times", zero_patch_times.dtype,
            ("dim_zero_patch",), **comp)
        v[:] = zero_patch_times

        # file-level attribute: the value that encodes "no signal"
        ds.zero_value = zero_value

        if scale is not None:
            ds.createDimension("dim_scale", len(scale))
            v = ds.createVariable("scale", scale.dtype,
                ("dim_scale",), **comp)
            v[:] = scale
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def load_patches(fn, in_memory=True):
    """Load a patch file written by save_patches into a dict of arrays.

    If `in_memory`, the whole file is first read into a buffer and parsed
    from memory (faster on network file systems).
    """
    buf = None
    if in_memory:
        with open(fn, 'rb') as f:
            buf = f.read()
        fn = None

    with netCDF4.Dataset(fn, 'r', memory=buf) as ds:
        result = {
            key: np.array(ds[key])
            for key in (
                "patches", "patch_coords", "patch_times",
                "zero_patch_coords", "zero_patch_times"
            )
        }
        result["zero_value"] = ds.zero_value
        if "scale" in ds.variables:
            result["scale"] = np.array(ds["scale"])

    return result
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def load_all_patches(patch_dir, var):
    """Load and concatenate every patch file for `var` in patch_dir.

    Files are matched by the variable component of the
    'patches_<var>_<suffix>.nc' naming scheme and loaded in parallel with
    dask processes. Array entries are concatenated along axis 0; the
    zero_value (and scale, if present) are taken from the first file.
    """
    tasks = [
        dask.delayed(load_patches)(os.path.join(patch_dir, fn))
        for fn in os.listdir(patch_dir)
        if fn.split("_")[1] == var
    ]
    file_data = dask.compute(tasks, scheduler="processes")[0]

    merged = {
        key: np.concatenate([fd[key] for fd in file_data], axis=0)
        for key in (
            "patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times"
        )
    }
    merged["zero_value"] = file_data[0]["zero_value"]
    if "scale" in file_data[0]:
        merged["scale"] = file_data[0]["scale"]

    return merged
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def unpack_patches(patch_data):
    """Return the patch arrays from a load_patches dict as a tuple,
    in the canonical (data, coords, times, zero_coords, zero_times) order.
    """
    keys = (
        "patches", "patch_coords", "patch_times",
        "zero_patch_coords", "zero_patch_times"
    )
    return tuple(patch_data[k] for k in keys)
|
ldcast/features/patches.py.save
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import dask
|
| 5 |
+
import netCDF4
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from .utils import average_pool
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def patch_locations(
|
| 12 |
+
time_range,
|
| 13 |
+
patch_box,
|
| 14 |
+
patch_shape=(32,32),
|
| 15 |
+
interval=timedelta(minutes=5),
|
| 16 |
+
epoch=(1970,1,1)
|
| 17 |
+
):
|
| 18 |
+
patches = {}
|
| 19 |
+
t = time_range[0]
|
| 20 |
+
while t < time_range[1]:
|
| 21 |
+
patches[t] = []
|
| 22 |
+
for pi in range(patch_box[0][0], patch_box[0][1]):
|
| 23 |
+
for pj in range(patch_box[1][0], patch_box[1][1]):
|
| 24 |
+
patches[t].append((pi,pj))
|
| 25 |
+
patches[t] = np.array(patches[t])
|
| 26 |
+
t += interval
|
| 27 |
+
|
| 28 |
+
return patches
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_patches_radar(
|
| 32 |
+
patches, archive_path, out_dir,
|
| 33 |
+
variables=("RZC", "CPCH"),
|
| 34 |
+
suffix="2020",
|
| 35 |
+
**kwargs
|
| 36 |
+
):
|
| 37 |
+
from ..datasets import mchradar
|
| 38 |
+
|
| 39 |
+
source_vars = {}
|
| 40 |
+
mchradar_reader = mchradar.MCHRadarReader(
|
| 41 |
+
archive_path=archive_path,
|
| 42 |
+
variables=variables,
|
| 43 |
+
phys_values=False
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
ezc_nonzero_count_func = lambda x: np.count_nonzero((x >= 1) & (x<251))
|
| 47 |
+
nonzero_count_func = {
|
| 48 |
+
"RZC": lambda x: np.count_nonzero(x > 1),
|
| 49 |
+
"CPCH": lambda x: np.count_nonzero(x > 1)
|
| 50 |
+
}
|
| 51 |
+
zero_value = {v: 0 for v in variables}
|
| 52 |
+
zero_value["RZC"] = 1
|
| 53 |
+
zero_value["CPCH"] = 1
|
| 54 |
+
|
| 55 |
+
save_patches_all(
|
| 56 |
+
mchradar_reader, patches, variables,
|
| 57 |
+
nonzero_count_func, zero_value, out_dir, suffix,
|
| 58 |
+
source_vars=source_vars, min_nonzeros_to_include=5,
|
| 59 |
+
**kwargs
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_patches_dwdradar(
|
| 64 |
+
patches, archive_path, out_dir,
|
| 65 |
+
variables=("RV",),
|
| 66 |
+
suffix="2022",
|
| 67 |
+
patch_shape=(32,32),
|
| 68 |
+
**kwargs
|
| 69 |
+
):
|
| 70 |
+
from ..datasets import dwdradar
|
| 71 |
+
|
| 72 |
+
source_vars = {}
|
| 73 |
+
dwdradar_reader = dwdradar.DWDRadarReader(
|
| 74 |
+
archive_path=archive_path,
|
| 75 |
+
variables=variables
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
patches_flt = {}
|
| 79 |
+
for t in sorted(patches):
|
| 80 |
+
if (t.hour==0) and (t.minute==0):
|
| 81 |
+
print(t)
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
data = dwdradar_reader.variable_for_time(t, "RV")
|
| 85 |
+
except FileNotFoundError:
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
patch_locs_time = []
|
| 89 |
+
for (pi,pj) in patches[t]:
|
| 90 |
+
i0 = pi * patch_shape[0]
|
| 91 |
+
i1 = i0 + patch_shape[0]
|
| 92 |
+
j0 = pj * patch_shape[1]
|
| 93 |
+
j1 = j0 + patch_shape[1]
|
| 94 |
+
patch = data[i0:i1,j0:j1]
|
| 95 |
+
if np.isfinite(patch).all():
|
| 96 |
+
patch_locs_time.append((pi,pj))
|
| 97 |
+
|
| 98 |
+
if patch_locs_time:
|
| 99 |
+
patches_flt[t] = np.array(patch_locs_time)
|
| 100 |
+
|
| 101 |
+
print(len(patches),len(patches_flt))
|
| 102 |
+
patches = patches_flt
|
| 103 |
+
|
| 104 |
+
nonzero_count_func = {"RV": np.count_nonzero}
|
| 105 |
+
zero_value = {v: 0 for v in variables}
|
| 106 |
+
|
| 107 |
+
save_patches_all(
|
| 108 |
+
dwdradar_reader, patches, variables,
|
| 109 |
+
nonzero_count_func, zero_value, out_dir, suffix,
|
| 110 |
+
source_vars=source_vars, min_nonzeros_to_include=5,
|
| 111 |
+
**kwargs
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def save_patches_ifs(
|
| 116 |
+
patches, archive_path, out_dir,
|
| 117 |
+
variables=(
|
| 118 |
+
#"rate-tp", "rate-cp", "t2m", "cape", "cin",
|
| 119 |
+
#"tclw", "tcwv", #, "rate-tpe"
|
| 120 |
+
#"u", "v"
|
| 121 |
+
"cin",
|
| 122 |
+
),
|
| 123 |
+
suffix="2020",
|
| 124 |
+
lags=(0,12)
|
| 125 |
+
):
|
| 126 |
+
from ..datasets import ifsnwp
|
| 127 |
+
from .. import projection
|
| 128 |
+
|
| 129 |
+
proj = projection.GridProjection(projection.ccs4_swiss_grid_area)
|
| 130 |
+
ifsnwp_reader = ifsnwp.IFSNWPReader(
|
| 131 |
+
proj,
|
| 132 |
+
archive_path=archive_path,
|
| 133 |
+
variables=variables,
|
| 134 |
+
lags=lags,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# we only get data for every hour, so modify patches
|
| 138 |
+
ifs_patches = {
|
| 139 |
+
dt: pset for (dt, pset) in patches.items()
|
| 140 |
+
if (dt.minute == dt.second == dt.microsecond == 0)
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
variables_with_lag = []
|
| 144 |
+
for lag in ifsnwp_reader.lags:
|
| 145 |
+
variables_with_lag.extend(f"{v}-{lag}" for v in variables)
|
| 146 |
+
|
| 147 |
+
count_positive = lambda x: np.count_nonzero(x > 0)
|
| 148 |
+
all_nonzero = lambda x: np.prod(x.shape)
|
| 149 |
+
nonzero_count_func = {
|
| 150 |
+
"rate-tp": count_positive,
|
| 151 |
+
"rate-cp": count_positive,
|
| 152 |
+
"t2m": all_nonzero,
|
| 153 |
+
"cape": count_positive,
|
| 154 |
+
"cin": count_positive,
|
| 155 |
+
"tclw": count_positive,
|
| 156 |
+
"tcwv": count_positive,
|
| 157 |
+
"u": all_nonzero,
|
| 158 |
+
"v": all_nonzero,
|
| 159 |
+
"rate-tpe": count_positive,
|
| 160 |
+
}
|
| 161 |
+
nonzero_count_func = {
|
| 162 |
+
v: nonzero_count_func[v.rsplit("-", 1)[0]]
|
| 163 |
+
for v in variables_with_lag
|
| 164 |
+
}
|
| 165 |
+
postproc = {
|
| 166 |
+
f"cin-{lag}": lambda x: np.nan_to_num(x, nan=0.0, copy=False)
|
| 167 |
+
for lag in lags
|
| 168 |
+
}
|
| 169 |
+
zero_value = {v: 0 for v in variables_with_lag}
|
| 170 |
+
avg_pool = lambda x: average_pool(x, factor=8, missing=np.nan)
|
| 171 |
+
pool = {v: avg_pool for v in variables_with_lag}
|
| 172 |
+
|
| 173 |
+
save_patches_all(ifsnwp_reader, ifs_patches, variables_with_lag,
|
| 174 |
+
nonzero_count_func, zero_value, out_dir, suffix, pool=pool,
|
| 175 |
+
postproc=postproc)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def save_patches_cosmo(patches, archive_path, out_dir, suffix="2020"):
    """Extract and save patches of COSMO NWP variables.

    Parameters:
        patches: dict mapping datetimes to sets of (i, j) patch coordinates.
        archive_path: path of the COSMO CCS4 archive.
        out_dir: output directory (one subdirectory per variable).
        suffix: suffix appended to the output file names.
    """
    from ..datasets import cosmonwp

    cosmonwp_reader = cosmonwp.COSMOCCS4Reader(
        archive_path=archive_path, cache_size=6000)

    # The COSMO data is only available hourly, so map each requested
    # patch to both the full hour before and after its original time.
    cosmo_patches = {}
    for (dt, pset) in patches.items():
        dt0 = datetime(dt.year, dt.month, dt.day, dt.hour)
        dt1 = dt0 + timedelta(hours=1)
        cosmo_patches.setdefault(dt0, set()).update(pset)
        cosmo_patches.setdefault(dt1, set()).update(pset)

    variables = [
        "CAPE_MU", "CIN_MU", "SLI",
        "HZEROCL", "LCL_ML", "MCONV", "OMEGA",
        "T_2M", "T_SO", "SOILTYP"
    ]
    # A patch counts as "nonzero" (worth storing) when the given
    # function returns a count of at least min_nonzeros_to_include.
    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "CAPE_MU": count_positive,
        "CIN_MU": count_positive,
        "SLI": all_nonzero,
        "HZEROCL": count_positive,
        "LCL_ML": count_positive,
        "MCONV": all_nonzero,
        "OMEGA": all_nonzero,
        "T_2M": all_nonzero,
        "T_SO": all_nonzero,
        # NOTE(review): type 5 is presumably the uninteresting/default
        # soil type (see zero_value below) — confirm against COSMO docs.
        "SOILTYP": lambda x: np.count_nonzero(x != 5)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["SOILTYP"] = 5

    # BUGFIX: the call previously passed `pool=pool`, but no `pool` was
    # defined anywhere in this function, so it raised NameError at
    # runtime. COSMO variables are not pooled, so the argument is
    # omitted (save_patches_all defaults to no pooling).
    save_patches_all(cosmonwp_reader, cosmo_patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def save_patches_all(
    reader, patches, variables, nonzero_count_func, zero_value,
    out_dir, suffix, epoch=datetime(1970,1,1), postproc=None, scale=None,
    pool=None, source_vars=None, parallel=False, min_nonzeros_to_include=1
):
    """Extract patches for all `variables` with `reader` and save each
    variable to its own NetCDF file under `out_dir`.

    Parameters:
        reader: data source with a variable_for_time(time, var) method.
        patches: dict mapping datetimes to iterables of (i, j) coords.
        variables: list of variable names to process.
        nonzero_count_func: dict of per-variable functions counting
            "interesting" values in a patch.
        zero_value: dict of per-variable fill values for dropped patches.
        out_dir, suffix: output location and file-name suffix.
        epoch: reference time for the integer patch timestamps.
        postproc: optional dict of per-variable postprocessing functions.
        scale: optional dict of per-variable scale arrays, used when the
            reader does not provide one.
        pool: optional dict of per-variable pooling functions.
        source_vars: optional dict renaming output vars to reader vars.
        parallel: if True, process variables concurrently via dask threads.
        min_nonzeros_to_include: minimum nonzero count to keep a patch.
    """
    # Avoid mutable default arguments (shared between calls).
    if postproc is None:
        postproc = {}
    if pool is None:
        pool = {}
    if source_vars is None:
        source_vars = {}

    def save_var(var_name):
        # Process a single variable: extract its patches and write them out.
        src_name = source_vars.get(var_name, var_name)

        (patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times) = get_patches(
            reader, src_name, patches,
            nonzero_count_func=nonzero_count_func[var_name],
            postproc=postproc.get(var_name),
            pool=pool.get(var_name),
            # BUGFIX: this threshold was accepted by save_patches_all but
            # never forwarded, so get_patches silently used its own
            # default of 1 regardless of the caller's setting.
            min_nonzeros_to_include=min_nonzeros_to_include
        )
        try:
            # Prefer the scale provided by the reader, if any.
            time = epoch + timedelta(seconds=int(patch_times[0]))
            var_scale = reader.get_scale(time, var_name)
        except (AttributeError, KeyError):
            # Reader has no scale for this variable; fall back to the
            # explicitly supplied scale dict (or None).
            var_scale = None if (scale is None) else scale[var_name]

        # Underscores are reserved as separators in the file name pattern.
        var_name = var_name.replace("_", "-")
        out_fn = f"patches_{var_name}_{suffix}.nc"
        out_path = os.path.join(out_dir, var_name)
        os.makedirs(out_path, exist_ok=True)
        out_fn = os.path.join(out_path, out_fn)

        save_patches(
            patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times, out_fn,
            zero_value=zero_value[var_name], scale=var_scale
        )

    if parallel:
        save_var = dask.delayed(save_var)

    jobs = [save_var(v) for v in variables]
    if parallel:
        dask.compute(jobs, scheduler='threads')
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def get_patches(
    reader, variable, patches,
    patch_shape=(32,32), nonzero_count_func=None,
    epoch=datetime(1970,1,1), postproc=None,
    pool=None, min_nonzeros_to_include=1
):
    """Extract fixed-size patches of `variable` for the requested times.

    For each time in `patches`, the full field is read from `reader` and
    cut into patches at the listed (i, j) grid coordinates. Patches whose
    nonzero count (per `nonzero_count_func`) falls below
    `min_nonzeros_to_include` are recorded by coordinates/time only.

    Returns (patch_data, patch_coords, patch_times,
             zero_patch_coords, zero_patch_times).
    """
    total = sum(len(coords) for coords in patches.values())
    kept_data = []
    kept_coords = []
    kept_times = []
    empty_coords = []
    empty_times = []

    # If the reader supports it, temporarily switch off conversion to
    # physical values, restoring the original setting afterwards.
    has_phys = hasattr(reader, "phys_values")
    if has_phys:
        saved_phys = reader.phys_values

    k = 0
    try:
        if has_phys:
            reader.phys_values = False
        for (t, coords) in patches.items():
            try:
                field = reader.variable_for_time(t, variable)
            except (ValueError, FileNotFoundError, KeyError, OSError):
                # Data missing or unreadable for this time: skip it.
                continue

            if postproc is not None:
                field = postproc(field)

            t_sec = np.int64((t - epoch).total_seconds())
            (ph, pw) = patch_shape
            for (pi, pj) in coords:
                if k % 100000 == 0:
                    # Progress report, roughly once per 100k patches.
                    print("{}: {}/{}".format(t, k, total))
                box = field[pi*ph:(pi+1)*ph, pj*pw:(pj+1)*pw].copy()
                below_threshold = (nonzero_count_func is not None) and \
                    (nonzero_count_func(box) < min_nonzeros_to_include)
                if below_threshold:
                    # Not enough signal: store only the metadata.
                    empty_coords.append((pi, pj))
                    empty_times.append(t_sec)
                else:
                    if pool is not None:
                        box = pool(box)
                    kept_data.append(box)
                    kept_coords.append((pi, pj))
                    kept_times.append(t_sec)
                k += 1

    finally:
        if has_phys:
            reader.phys_values = saved_phys

    if empty_coords:
        zero_patch_coords = np.stack(empty_coords, axis=0).astype(np.uint16)
        zero_patch_times = np.stack(empty_times, axis=0)
    else:
        zero_patch_coords = np.zeros((0,2), dtype=np.uint16)
        zero_patch_times = np.zeros((0,), dtype=np.int64)
    patch_data = np.stack(kept_data, axis=0)
    patch_coords = np.stack(kept_coords, axis=0).astype(np.uint16)
    patch_times = np.stack(kept_times, axis=0)

    return (patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def save_patches(patch_data, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times, out_fn, zero_value=0, scale=None):
    """Write patches and their metadata to a compressed NetCDF4 file.

    Parameters:
        patch_data: (N, H, W) array of kept patches.
        patch_coords: (N, 2) array of (i, j) patch grid coordinates.
        patch_times: (N,) array of patch times.
        zero_patch_coords, zero_patch_times: coordinates and times of
            the dropped (below-threshold) patches; only their metadata
            is stored, together with the representative `zero_value`.
        out_fn: output file path.
        zero_value: fill value, stored as a global attribute.
        scale: optional 1-D lookup-table array; presumably maps stored
            integer values to physical values — confirm against reader.
    """

    with netCDF4.Dataset(out_fn, 'w') as ds:
        dim_patch = ds.createDimension("dim_patch", patch_data.shape[0])
        dim_zero_patch = ds.createDimension("dim_zero_patch", zero_patch_coords.shape[0])
        dim_coord = ds.createDimension("dim_coord", 2)
        dim_height = ds.createDimension("dim_height", patch_data.shape[1])
        dim_width = ds.createDimension("dim_width", patch_data.shape[2])

        # Light zlib compression for all variables.
        var_args = {"zlib": True, "complevel": 1}

        # Chunk along the patch dimension (up to 1024 patches per chunk)
        # so individual patches can be read efficiently.
        chunksizes = (min(2**10, patch_data.shape[0]), patch_data.shape[1], patch_data.shape[2])
        var_patch = ds.createVariable("patches", patch_data.dtype,
            ("dim_patch","dim_height","dim_width"), chunksizes=chunksizes, **var_args)
        var_patch[:] = patch_data

        var_patch_coord = ds.createVariable("patch_coords", patch_coords.dtype,
            ("dim_patch","dim_coord"), **var_args)
        var_patch_coord[:] = patch_coords

        var_patch_time = ds.createVariable("patch_times", patch_times.dtype,
            ("dim_patch",), **var_args)
        var_patch_time[:] = patch_times

        var_zero_patch_coord = ds.createVariable("zero_patch_coords", zero_patch_coords.dtype,
            ("dim_zero_patch","dim_coord"), **var_args)
        var_zero_patch_coord[:] = zero_patch_coords

        var_zero_patch_time = ds.createVariable("zero_patch_times", zero_patch_times.dtype,
            ("dim_zero_patch",), **var_args)
        var_zero_patch_time[:] = zero_patch_times

        # Stored as a global attribute rather than a variable.
        ds.zero_value = zero_value

        if scale is not None:
            dim_scale = ds.createDimension("dim_scale", len(scale))
            var_scale = ds.createVariable("scale", scale.dtype, ("dim_scale",), **var_args)
            var_scale[:] = scale
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def load_patches(fn, in_memory=True):
    """Load a patch file previously written with save_patches.

    Parameters:
        fn: path of the NetCDF file.
        in_memory: if True, read the whole file into a buffer first and
            let netCDF4 parse it from memory (avoids repeated disk I/O).

    Returns a dict of the loaded arrays/values.
    """
    if in_memory:
        with open(fn, 'rb') as f:
            ds_raw = f.read()
        fn = None
    else:
        ds_raw = None

    with netCDF4.Dataset(fn, 'r', memory=ds_raw) as ds:
        patch_data = {
            #"patches": np.array(ds["patches"]),
            #"patch_coords": np.array(ds["patch_coords"]),
            #"patch_times": np.array(ds["patch_times"]),
            #"zero_patch_coords": np.array(ds["zero_patch_coords"]),
            #"zero_patch_times": np.array(ds["zero_patch_times"]),
            # NOTE(review): zero_value is hard-coded to 1 instead of
            # being read from the file's zero_value attribute, and a
            # "pr" variable is read that save_patches in this module
            # never writes. This looks like a temporary modification —
            # verify against the data files actually in use.
            "zero_value": 1,
            "pr": np.array(ds["pr"]),
        }
        if "scale" in ds.variables:
            patch_data["scale"] = np.array(ds["scale"])

    return patch_data
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def load_all_patches(patch_dir, var):
    """Load and concatenate every patch file for `var` in `patch_dir`.

    File names are expected to follow the "patches_<var>_<suffix>.nc"
    pattern; files are loaded in parallel with dask processes.
    """
    tasks = []
    for name in os.listdir(patch_dir):
        # Second underscore-separated field of the name is the variable.
        if name.split("_")[1] == var:
            path = os.path.join(patch_dir, name)
            tasks.append(dask.delayed(load_patches)(path))

    loaded = dask.compute(tasks, scheduler="processes")[0]

    #keys = ["patches", "patch_coords", "patch_times",
    #    "zero_patch_coords", "zero_patch_times"]
    combined = {
        key: np.concatenate([d[key] for d in loaded], axis=0)
        for key in ["pr"]
    }
    # Scalar metadata is taken from the first file.
    combined["zero_value"] = loaded[0]["zero_value"]
    if "scale" in loaded[0]:
        combined["scale"] = loaded[0]["scale"]

    return combined
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def unpack_patches(patch_data):
    """Unpack a patch-data dict into the positional tuple expected by
    the bin classification functions.
    """
    keys = ("patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times")
    return tuple(patch_data[key] for key in keys)
|
ldcast/features/sampling.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bisect import bisect_left
|
| 2 |
+
import multiprocessing
|
| 3 |
+
|
| 4 |
+
import dask
|
| 5 |
+
from numba import njit, prange, types
|
| 6 |
+
from numba.typed import Dict
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from .patches import unpack_patches
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EqualFrequencySampler:
    """Samples starting coordinates such that each intensity bin is
    drawn with equal probability.

    Patches are classified into bins, and for each bin the set of valid
    sample starting indices (t, i, j) is precomputed. Each bin is then
    consumed in shuffled order and reshuffled when exhausted.
    """
    def __init__(
        self, bins, patch_data, patch_index,
        sample_shape, time_range_valid, time_range_sampling=None,
        timestep_secs=5*60,
        random_seed=None, preselected_samples=None
    ):
        binned_patches = bin_classify_patches_parallel(
            bins,
            *unpack_patches(patch_data),
            zero_value=patch_data.get("zero_value", 0),
            scale=patch_data.get("scale")
        )
        complete_ind = indices_with_complete_sample(
            patch_index, sample_shape, time_range_valid, timestep_secs
        )
        if time_range_sampling is None:
            time_range_sampling = time_range_valid
        self.starting_ind = [
            starting_indices_for_centers(
                p, complete_ind, sample_shape, time_range_sampling, timestep_secs
            )
            for p in binned_patches
        ]
        self.num_bins = len(self.starting_ind)
        self.rng = np.random.RandomState(seed=random_seed)
        self.preselected_samples = preselected_samples
        # Cursors start past the end of each bin so that the first draw
        # from any bin triggers an initial shuffle.
        self.current_ind = np.array([len(ind) for ind in self.starting_ind])

    def get_bin_sample(self, bin_ind):
        """Return the next (t, i, j) row from bin `bin_ind`,
        reshuffling the bin when it has been exhausted.
        """
        patches = self.starting_ind[bin_ind]
        sample_ind = self.current_ind[bin_ind]
        if sample_ind >= patches.shape[0]:
            # Exhausted: reshuffle in place and restart from the top.
            self.rng.shuffle(patches)
            sample_ind = 0
        # BUGFIX: always advance the cursor past the returned row.
        # Previously the cursor was left at 0 after a reshuffle, so the
        # first row of a freshly shuffled bin was returned twice in a row.
        self.current_ind[bin_ind] = sample_ind + 1
        return patches[sample_ind, :]

    def __call__(self, num):
        """Return `num` sampled (t, i, j) rows, stacked along axis 0."""
        # Choose a bin uniformly at random for each sample.
        bins = self.rng.randint(self.num_bins, size=num)
        coords = np.stack(
            [self.get_bin_sample(b) for b in bins],
            axis=0
        )
        return coords
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def bin_classify_patches(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Assign each patch (including the all-zero ones) to an intensity bin.

    Returns a list of len(bins)+1 arrays, each holding (time, i, j) rows
    of the patches whose metric falls into that bin.
    """
    if metric_func is None:
        # Default metric: per-patch 99th percentile, rounded for
        # integer data and cast back to the input dtype.
        def metric_func(x):
            m = np.percentile(x, 99, axis=(1,2))
            if np.issubdtype(x.dtype, np.integer):
                m = m.round()
            return m.astype(x.dtype)

    grouped = [[] for _ in range(len(bins) + 1)]

    # All zero patches share one metric value derived from zero_value.
    zero_metric = zero_value if scale is None else scale[zero_value]
    zero_bin = bisect_left(bins, zero_metric)
    for (zt, (zi, zj)) in zip(zero_patch_times, zero_patch_coords):
        grouped[zero_bin].append((zt, zi, zj))

    metrics = metric_func(patches)
    if scale is not None:
        # Translate stored values to the scale the bins are defined on.
        metrics = scale[metrics]
    for (metric, t, (pi, pj)) in zip(metrics, patch_times, patch_coords):
        grouped[bisect_left(bins, metric)].append((t, pi, pj))

    return [
        np.array(rows) if rows else np.zeros((0,3), dtype=np.int64)
        for rows in grouped
    ]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def bin_classify_patches_parallel(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Parallel wrapper around bin_classify_patches.

    Splits the patch and zero-patch arrays into one chunk per CPU,
    classifies the chunks concurrently with dask threads, then
    concatenates the per-bin results.
    """
    num_patches = patches.shape[0]
    num_zeros = zero_patch_coords.shape[0]
    num_procs = multiprocessing.cpu_count()

    tasks = []
    for p in range(num_procs):
        # Chunk boundaries for the nonzero and the zero patch arrays.
        pk0 = int(round(num_patches*p/num_procs))
        pk1 = int(round(num_patches*(p+1)/num_procs))
        zk0 = int(round(num_zeros*p/num_procs))
        zk1 = int(round(num_zeros*(p+1)/num_procs))

        task = dask.delayed(bin_classify_patches)(
            bins,
            patches[pk0:pk1,...], patch_coords[pk0:pk1,...],
            patch_times[pk0:pk1],
            zero_patch_coords[zk0:zk1,...], zero_patch_times[zk0:zk1],
            zero_value=zero_value, metric_func=metric_func,
            scale=scale
        )
        tasks.append(task)

    chunked_bins = dask.compute(tasks, scheduler="threads")[0]

    # Merge: bin i of the result is the concatenation of bin i of
    # every chunk.
    n_bins = len(chunked_bins[0])
    binned_patches = [
        np.concatenate([cb[i] for cb in chunked_bins], axis=0)
        for i in range(n_bins)
    ]
    return binned_patches
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def indices_with_complete_sample(
    patch_index, sample_shape, time_range, timestep_secs
):
    """Check which locations will give a sample without missing data.

    A starting index (t0, i0, j0) is "complete" when every index in the
    sample window — all timesteps in `time_range` (scaled by
    `timestep_secs`) and all spatial offsets within `sample_shape` —
    is present in patch_index.patch_index.

    Returns a numba typed Dict whose keys are the complete (t, i, j)
    tuples (values are unused).
    """
    ind = np.array(list(patch_index.patch_index.keys()))
    t0 = ind[:,0]
    i0 = ind[:,1]
    j0 = ind[:,2]
    n = ind.shape[0]
    complete = np.ones(n, dtype=bool)
    # we use this dict like a set - numba doesn't support typed sets
    complete_ind = Dict.empty(
        key_type=types.UniTuple(types.int64, 3),
        value_type=types.uint8
    )

    @njit(parallel=True) # many nested loops, numba optimization needed
    def check_complete(index, complete, complete_ind):
        # Phase 1: mark indices with any missing member as incomplete.
        for k in prange(n):
            for ts in range(*time_range):
                t = t0[k] + ts*timestep_secs
                for di in range(sample_shape[0]):
                    i = i0[k] + di
                    for dj in range(sample_shape[1]):
                        j = j0[k] + dj
                        if (t,i,j) not in index:
                            complete[k] = False

        # Phase 2: collect the surviving indices into the dict.
        for k in range(n): # no prange: can't set dict items in parallel
            if complete[k]:
                complete_ind[(t0[k],i0[k],j0[k])] = np.uint8(0)

    check_complete(patch_index.patch_index, complete, complete_ind)

    return complete_ind
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def starting_indices_for_centers(
    centers, complete_ind, sample_shape, time_range, timestep_secs
):
    """Determine a complete list of sample indices that
    contain one or more of the centerpoints.

    For each center (t0, i0, j0), every starting index whose sample
    window would cover the center (hence the negative offsets) is
    looked up in `complete_ind`; the matches are collected and returned
    as an (N, 3) array of unique (t, i, j) rows.
    """

    @njit
    def find_indices(centers, starting_ind, complete_ind):
        for k in range(centers.shape[0]):
            t0 = centers[k,0]
            i0 = centers[k,1]
            j0 = centers[k,2]
            for ts in range(*time_range):
                t = t0 - ts*timestep_secs # note minus signs in (t,i,j)
                for di in range(sample_shape[0]):
                    i = i0 - di
                    for dj in range(sample_shape[1]):
                        j = j0 - dj
                        if (t,i,j) in complete_ind:
                            starting_ind[(t,i,j)] = np.uint8(0)

    num_chunks = multiprocessing.cpu_count()

    @dask.delayed
    def chunk(i):
        # Dict used as a set (numba has no typed set); one dict per
        # chunk so the worker threads never share a writable dict.
        starting_ind = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.uint8
        )
        k0 = int(round(centers.shape[0] * (i / num_chunks)))
        k1 = int(round(centers.shape[0] * ((i+1) / num_chunks)))
        find_indices(centers[k0:k1,...], starting_ind, complete_ind)
        return starting_ind

    jobs = [chunk(i) for i in range(num_chunks)]
    starting_ind = dask.compute(jobs, scheduler='threads')[0]
    # Skip empty chunk results: np.array of an empty key list would
    # not have the (n, 3) shape concatenate expects.
    starting_ind = np.concatenate(
        [np.array(list(st_ind.keys())) for st_ind in starting_ind if st_ind],
        axis=0
    )
    return starting_ind
|
ldcast/features/split.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bisect import bisect_left
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
from . import batch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_chunks(
    primary_raw, valid_frac=0.1, test_frac=0.1,
    chunk_seconds=2*24*60*60, random_seed=None
):
    """Randomly partition the covered time span into train/valid/test
    chunks of `chunk_seconds` length.

    NOTE(review): this assumes patch_times and zero_patch_times are
    sorted, since their first/last elements are used as min/max —
    confirm with the producers of primary_raw.

    Returns a dict mapping "valid"/"test"/"train" to sorted lists of
    (start_time, end_time) chunk boundaries.
    """
    span_start = min(
        primary_raw["patch_times"][0],
        primary_raw["zero_patch_times"][0]
    )
    span_end = max(
        primary_raw["patch_times"][-1],
        primary_raw["zero_patch_times"][-1]
    ) + 1

    rng = np.random.RandomState(seed=random_seed)
    limits = np.arange(span_start, span_end, chunk_seconds)
    num_chunks = len(limits) - 1

    # Shuffle the chunk order, then slice off the validation and test
    # fractions; the remainder is the training set.
    order = np.arange(num_chunks)
    rng.shuffle(order)
    n_valid = int(round(num_chunks * valid_frac))
    n_test = int(round(num_chunks * test_frac))
    assignment = {
        "valid": order[:n_valid],
        "test": order[n_valid:n_valid + n_test],
        "train": order[n_valid + n_test:]
    }
    return {
        split: sorted((limits[i], limits[i+1]) for i in members)
        for (split, members) in assignment.items()
    }
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def train_valid_test_split(
    raw_data, primary_raw_var, chunks=None, **kwargs
):
    """Split raw patch data into train/valid/test splits of time chunks.

    Parameters:
        raw_data: dict of per-variable patch-data dicts.
        primary_raw_var: variable whose time span defines the chunks
            (only used when `chunks` is None).
        chunks: optional precomputed chunk dict (as from get_chunks);
            if None, chunks are generated with `kwargs` passed through
            to get_chunks.

    Returns (split_raw_data, chunks), where split_raw_data maps each
    split name to a shallow per-variable copy of raw_data.

    NOTE: the per-chunk slicing of the patch arrays was disabled
    (commented out) in the original code, so every split currently
    receives references to the *full* arrays of each variable; the
    dead slicing code has been removed here for clarity.
    """
    if chunks is None:
        chunks = get_chunks(raw_data[primary_raw_var], **kwargs)

    # Shallow-copy each variable dict per split: the dicts are
    # independent, but the contained arrays are shared across splits.
    split_raw_data = {
        split: {var: dict(raw_data[var]) for var in raw_data}
        for split in chunks
    }

    return (split_raw_data, chunks)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class DataModule(pl.LightningDataModule):
    """PyTorch Lightning data module wrapping the patch batch generators.

    Builds one batch.BatchGenerator per split from `raw` and exposes
    train/valid/test dataloaders over the corresponding datasets.
    Data augmentation is enabled for the training split only.
    """
    def __init__(
        self,
        variables, raw, predictors, target, primary_var,
        sampling_bins, sampler_file,
        batch_size=8,
        train_epoch_size=10, valid_epoch_size=2, test_epoch_size=10,
        valid_seed=None, test_seed=None,
        **kwargs
    ):
        super().__init__()
        # One generator per split present in `raw`; extra kwargs are
        # forwarded to BatchGenerator unchanged.
        self.batch_gen = {
            split: batch.BatchGenerator(
                variables, raw_var, predictors, target, primary_var,
                sampling_bins=sampling_bins, batch_size=batch_size,
                sampler_file=sampler_file.get(split),
                augment=(split=="train"),
                **kwargs
            )
            for (split,raw_var) in raw.items()
        }
        self.datasets = {}
        # Training streams batches; valid/test use seeded deterministic
        # datasets so evaluation is reproducible across runs.
        if "train" in self.batch_gen:
            self.datasets["train"] = batch.StreamBatchDataset(
                self.batch_gen["train"], train_epoch_size
            )
        if "valid" in self.batch_gen:
            self.datasets["valid"] = batch.DeterministicBatchDataset(
                self.batch_gen["valid"], valid_epoch_size, random_seed=valid_seed
            )
        if "test" in self.batch_gen:
            self.datasets["test"] = batch.DeterministicBatchDataset(
                self.batch_gen["test"], test_epoch_size, random_seed=test_seed
            )

    def dataloader(self, split):
        """Return a DataLoader for the given split.

        batch_size=None because the datasets already yield full batches.
        """
        return DataLoader(
            self.datasets[split], batch_size=None,
            pin_memory=True, num_workers=0
        )

    def train_dataloader(self):
        return self.dataloader("train")

    def val_dataloader(self):
        return self.dataloader("valid")

    def test_dataloader(self):
        return self.dataloader("test")
|
ldcast/features/transform.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
import multiprocessing
|
| 3 |
+
|
| 4 |
+
from numba import njit, prange
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.ndimage import convolve
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def quick_cast(x, y):
    """Copy `x` into `y` in place (casting to y's dtype) using threads.

    The leading axis is divided into one slab per CPU and each worker
    thread writes its slab of `y`.
    """
    n_workers = multiprocessing.cpu_count()
    bounds = np.linspace(0, x.shape[0], n_workers + 1).round().astype(int)

    def copy_slab(a, b):
        y[a:b, ...] = x[a:b, ...]

    with concurrent.futures.ThreadPoolExecutor(n_workers) as pool:
        pending = [
            pool.submit(copy_slab, bounds[i], bounds[i + 1])
            for i in range(n_workers)
        ]
        concurrent.futures.wait(pending)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def cast(dtype=np.float16):
    """Return a transform that casts arrays to `dtype`.

    The output buffer is cached across calls and reused whenever the
    input shape is unchanged, so callers must not hold on to previous
    results.
    """
    buf = None

    def transform(raw):
        nonlocal buf
        if buf is None or buf.shape != raw.shape:
            buf = np.empty_like(raw, dtype=dtype)
        quick_cast(raw, buf)
        return buf

    return transform
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@njit(parallel=True)
def scale_array(in_arr, out_arr, scale):
    """Map each element of in_arr through the `scale` lookup table,
    writing the results into out_arr in place.

    Assumes contiguous arrays so that ravel() returns views sharing
    memory with the caller's arrays — TODO confirm with callers.
    """
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        out_arr[i] = scale[in_arr[i]]
|
| 39 |
+
|
| 40 |
+
# NumPy version
|
| 41 |
+
#def scale_array(in_arr, out_arr, scale):
|
| 42 |
+
# out_arr[:] = scale[in_arr]
|
| 43 |
+
|
| 44 |
+
def normalize(mean=0.0, std=1.0, dtype=np.float32):
    """Return a transform that normalizes arrays with the given mean
    and std, optionally casting the result to `dtype`.

    Intermediate buffers are cached across calls and reused while the
    input shape stays the same.
    """
    buf_f32 = None
    buf_out = None

    def transform(raw):
        nonlocal buf_f32, buf_out
        if buf_f32 is None or buf_f32.shape != raw.shape:
            buf_f32 = np.empty_like(raw, dtype=np.float32)
            buf_out = np.empty_like(raw, dtype=dtype)
        normalize_array(raw, buf_f32, mean, std)

        if dtype == np.float32:
            return buf_f32
        # Cast to the requested output dtype in a second pass.
        quick_cast(buf_f32, buf_out)
        return buf_out

    return transform
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def normalize_threshold(mean=0.0, std=1.0, threshold=0.0, fill_value=0.0, log=False):
    """Return a transform that normalizes arrays with thresholding:
    values below `threshold` are replaced by `fill_value`, optionally
    in log space. The output buffer is reused across calls of the same
    input shape.
    """
    buf = None

    def transform(raw):
        nonlocal buf
        if buf is None or buf.shape != raw.shape:
            buf = np.empty_like(raw, dtype=np.float32)
        normalize_threshold_array(raw, buf, mean, std, threshold, fill_value, log=log)
        return buf

    return transform
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def scale_log_norm(scale, threshold=None, missing_value=None,
    fill_value=0, mean=0.0, std=1.0, dtype=np.float32):
    """Return a transform mapping stored values through a log10 version
    of the `scale` lookup table, normalized with `mean` and `std`.

    Parameters:
        scale: lookup-table array of physical values per stored value.
        threshold: entries whose log10 falls below log10(threshold) are
            set to log10(fill_value).
        missing_value: optional index/mask of entries to fill.
        fill_value: replacement for thresholded/missing/invalid entries.
        mean, std: normalization applied after the log transform.
        dtype: output dtype; a second cached buffer is used when it is
            not float32.
    """
    log_fill = np.log10(fill_value)
    # BUGFIX: np.log10(scale, where=scale>0) without `out=` leaves the
    # masked (scale <= 0) entries of the result *uninitialized*; the
    # garbage there may happen to be finite and then escape the
    # ~isfinite fixup below. Initialize the output to the fill value
    # before computing the masked log.
    log_scale = np.full(scale.shape, log_fill, dtype=np.float32)
    np.log10(scale, where=scale > 0, out=log_scale)
    if threshold is not None:
        log_scale[log_scale < np.log10(threshold)] = log_fill
    if missing_value is not None:
        log_scale[missing_value] = log_fill
    log_scale[~np.isfinite(log_scale)] = log_fill
    log_scale -= mean
    log_scale /= std
    scaled = scaled_dt = None

    def transform(raw):
        nonlocal scaled, scaled_dt
        # Buffers are cached and reused while the input shape is stable.
        if (scaled is None) or (scaled.shape != raw.shape):
            scaled = np.empty_like(raw, dtype=np.float32)
            scaled_dt = np.empty_like(raw, dtype=dtype)
        scale_array(raw, scaled, log_scale)

        if dtype == np.float32:
            return scaled
        else:
            quick_cast(scaled, scaled_dt)
            return scaled_dt

    return transform
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def combine(transforms, memory_format="channels_first", dim=3):
    """Apply one transform per input and concatenate along the channel axis.

    Any transformed input whose rank is ``dim + 1`` (i.e. missing a
    channel axis) gets a singleton channel axis inserted first.
    """
    ch_axis = 1 if memory_format == "channels_first" else -1

    def transform(*raw):
        parts = [f(r) for (f, r) in zip(transforms, raw)]
        parts = [
            np.expand_dims(p, ch_axis) if p.ndim == dim + 1 else p
            for p in parts
        ]
        return np.concatenate(parts, axis=ch_axis)

    return transform
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Antialiasing:
    """Gaussian antialiasing filter applied framewise over the last two
    (spatial) axes of a 4D array, with edge renormalization.

    Per-shape scratch buffers are cached, so the returned array is reused
    and overwritten by subsequent calls with the same input shape.
    """

    def __init__(self):
        # 5x5 Gaussian kernel (sigma = 0.5), normalized to unit sum
        (x, y) = np.mgrid[-2:3, -2:3]
        self.kernel = np.exp(-0.5*(x**2 + y**2)/(0.5**2))
        self.kernel /= self.kernel.sum()
        # per-spatial-shape edge correction factors: constant-mode
        # convolution attenuates edges; dividing by the convolved ones
        # image undoes that
        self.edge_factors = {}
        # per-full-shape cached output buffers
        self.img_smooth = {}
        num_threads = multiprocessing.cpu_count()
        self.executor = concurrent.futures.ThreadPoolExecutor(num_threads)

    def __call__(self, img):
        img_shape = img.shape[-2:]
        if img_shape not in self.edge_factors:
            s = convolve(np.ones(img_shape, dtype=np.float32),
                self.kernel, mode="constant")
            s = 1.0/s
            self.edge_factors[img_shape] = s
        else:
            s = self.edge_factors[img_shape]

        # Bug fix: the buffer cache was looked up with img.shape but
        # stored under img_shape (spatial dims only), so lookups never
        # hit -- and a hit could have returned a buffer with the wrong
        # leading (batch/time) dimensions. Key consistently on the full
        # shape.
        if img.shape not in self.img_smooth:
            img_smooth = np.empty_like(img)
            self.img_smooth[img.shape] = img_smooth
        else:
            img_smooth = self.img_smooth[img.shape]

        def _convolve_frame(i, j):
            # smooth one (batch, time) frame in place into the cache
            convolve(img[i,j,:,:], self.kernel,
                mode="constant", output=img_smooth[i,j,:,:])
            img_smooth[i,j,:,:] *= s

        futures = []
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                args = (_convolve_frame, i, j)
                futures.append(self.executor.submit(*args))
        concurrent.futures.wait(futures)

        return img_smooth
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def default_rainrate_transform(scale):
    """Standard rain-rate preprocessing: log-normalized lookup followed
    by Gaussian antialiasing. The constants match the training setup."""
    scaling = scale_log_norm(
        scale, threshold=0.1, fill_value=0.02,
        mean=-0.051, std=0.528, dtype=np.float32
    )
    antialiasing = Antialiasing()

    def transform(raw):
        return antialiasing(scaling(raw))

    return transform
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def scale_norm(scale, threshold=None, missing_value=None,
        fill_value=0, mean=0.0, std=1.0, dtype=np.float32):
    """Build a transform mapping quantized codes through a standardized
    linear-space lookup table (see scale_log_norm for the log variant).

    The lookup table is sanitized up front: NaNs, values below
    `threshold`, and `missing_value` codes all become `fill_value`; the
    table is then standardized with `mean`/`std`.
    """
    lut = scale.astype(np.float32).copy()
    lut[np.isnan(lut)] = fill_value
    if threshold is not None:
        lut[lut < threshold] = fill_value
    if missing_value is not None:
        for m in np.atleast_1d(missing_value):
            lut[m] = fill_value
    lut -= mean
    lut /= std
    buf32 = buf_cast = None

    def transform(raw):
        nonlocal buf32, buf_cast
        if buf32 is None or buf32.shape != raw.shape:
            buf32 = np.empty_like(raw, dtype=np.float32)
            buf_cast = np.empty_like(raw, dtype=dtype)
        scale_array(raw, buf32, lut)
        if dtype != np.float32:
            quick_cast(buf32, buf_cast)
            return buf_cast
        return buf32

    return transform
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@njit(parallel=True)
def threshold_array(in_arr, out_arr, threshold):
    """Write a binary exceedance mask of in_arr into out_arr.

    out_arr[i] becomes 1.0 where in_arr[i] >= threshold, else 0.0.
    Both arrays are accessed through raveled views, so they are assumed
    to be C-contiguous (ravel must not copy) and of equal size.
    """
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        out_arr[i] = np.float32(in_arr[i] >= threshold)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def one_hot(values):
    """Build a transform that one-hot encodes integer category codes.

    Args:
        values: iterable of raw integer codes to recognize; the code
            values[i] is mapped to output channel i.

    Returns:
        transform(raw): encodes a 4D integer array of codes into an
        array of shape raw.shape + (len(values),). The output buffers
        are reused between calls with the same input shape.
    """
    translation = np.zeros(max(values)+1, dtype=int)
    num_categories = len(values)
    for (i, v) in enumerate(values):
        translation[v] = i
    onehot = onehot_dt = None

    def transform(raw):
        nonlocal onehot, onehot_dt
        if (onehot is None) or (onehot.shape[:-1] != raw.shape):
            onehot = np.empty(raw.shape+(num_categories,),
                dtype=np.float32)
            # Bug fix: this uint8 buffer was previously assigned to
            # `onehot`, clobbering the float32 buffer and leaving
            # onehot_dt as None, which crashed quick_cast below on the
            # first call.
            onehot_dt = np.empty(raw.shape+(num_categories,),
                dtype=np.uint8)
        onehot_transform(raw, onehot, translation)
        quick_cast(onehot, onehot_dt)
        # return the compact uint8 encoding (the cast's purpose);
        # the float32 buffer is only scratch space for onehot_transform
        return onehot_dt

    return transform
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@njit(parallel=True)
def onehot_transform(in_arr, out_arr, translation):
    """One-hot encode a 4D array of integer codes.

    Args:
        in_arr: (K, T, H, W) array of raw integer codes.
        out_arr: (K, T, H, W, C) output; zeroed, then set to 1.0 at the
            channel given by translation[code].
        translation: 1D lookup table from raw code value to channel
            index; must cover every code present in in_arr.
    """
    for k in prange(in_arr.shape[0]):
        out_arr[k,...] = 0.0
        for t in range(in_arr.shape[1]):
            for i in range(in_arr.shape[2]):
                for j in range(in_arr.shape[3]):
                    ind = np.uint64(in_arr[k,t,i,j])
                    c = translation[ind]
                    out_arr[k,t,i,j,c] = 1.0
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@njit(parallel=True)
def normalize_array(in_arr, out_arr, mean, std):
    """Elementwise (x - mean) / std from in_arr into out_arr.

    Uses the precomputed reciprocal of std to avoid a division per
    element. Arrays are accessed through raveled views, so they are
    assumed C-contiguous (ravel must not copy) and of equal size.
    """
    mean = np.float32(mean)
    inv_std = np.float32(1.0/std)
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        out_arr[i] = (in_arr[i]-mean)*inv_std
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@njit(parallel=True)
def normalize_threshold_array(
    in_arr, out_arr,
    mean, std,
    threshold, fill_value, log=False
):
    """Thresholded (and optionally log10) standardization, elementwise.

    Values below `threshold` are replaced with `fill_value` before the
    optional log10 and the (x - mean) / std standardization. Arrays are
    accessed through raveled views and assumed contiguous.

    NOTE(review): with log=True, fill_value must be > 0 or np.log10
    yields -inf/nan -- confirm against callers.
    """
    mean = np.float32(mean)
    inv_std = np.float32(1.0/std)
    threshold = np.float32(threshold)
    fill_value = np.float32(fill_value)
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        x = in_arr[i]
        if x < threshold:
            x = fill_value
        if log:
            x = np.log10(x)
        out_arr[i] = (x-mean)*inv_std
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# NumPy version
|
| 280 |
+
#def threshold_array(in_arr, out_arr, threshold):
|
| 281 |
+
# out_arr[:] = (in_arr >= threshold).astype(np.float32)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def R_threshold(scale, threshold):
    """Build a transform producing a binary exceedance mask from
    quantized rain-rate codes.

    The physical threshold is converted once into the quantized domain
    by locating the first scale entry above it; raw codes are then
    compared against that index. The output buffer is reused between
    calls of the same input shape.
    """
    buf = None
    # index of the first scale entry exceeding the physical threshold
    scale_threshold = np.nanargmax(scale > threshold)

    def transform(rzc_raw):
        nonlocal buf
        if buf is None or buf.shape != rzc_raw.shape:
            buf = np.empty_like(rzc_raw, dtype=np.float32)
        threshold_array(rzc_raw, buf, scale_threshold)
        return buf

    return transform
|
ldcast/features/utils.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numba import njit, prange
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.signal import convolve
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def log_scale_with_zero(range, n=65536, dtype=np.float32):
    """Build an n-entry quantization scale whose index 0 is exactly zero
    and whose indices 1..n-1 are logarithmically spaced over
    [range[0], range[1]]."""
    (lo, hi) = (np.log10(range[0]), np.log10(range[1]))
    log_levels = np.linspace(lo, hi, n - 1)
    return np.hstack((0, 10.0 ** log_levels)).astype(dtype)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def log_quantize_with_zero(x, range, n=65536, dtype=np.uint16):
    """Quantize x onto the log scale from log_scale_with_zero.

    Returns (codes, scale): codes[i] indexes into scale so that
    scale[codes[i]] approximates x[i]; values below range[0] map to the
    reserved zero code 0.
    """
    scale = log_scale_with_zero(range, n=n, dtype=x.dtype)
    codes = np.empty_like(x, dtype=dtype)
    log_quant_with_zeros(x, codes, np.log10(scale[1:]))
    return (codes, scale)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# optimized helper function for the above
|
| 20 |
+
@njit(parallel=True)
|
| 21 |
+
def log_quant_with_zeros(x, y, scale):
|
| 22 |
+
x = x.ravel()
|
| 23 |
+
y = y.ravel()
|
| 24 |
+
min_val = 10**scale[0]
|
| 25 |
+
|
| 26 |
+
for i in prange(x.shape[0]):
|
| 27 |
+
# map small values to 0
|
| 28 |
+
if x[i] < min_val:
|
| 29 |
+
y[i] = 0
|
| 30 |
+
continue
|
| 31 |
+
|
| 32 |
+
lx = np.log10(x[i])
|
| 33 |
+
if lx >= scale[-1]:
|
| 34 |
+
# map too big values to max of scale
|
| 35 |
+
y[i] = len(scale)
|
| 36 |
+
else:
|
| 37 |
+
# binary search for the rest
|
| 38 |
+
k0 = 0
|
| 39 |
+
k1 = len(scale)
|
| 40 |
+
while k1-k0 > 1:
|
| 41 |
+
km = k0 + (k1-k0)//2
|
| 42 |
+
if lx < scale[km]:
|
| 43 |
+
k1 = km
|
| 44 |
+
else:
|
| 45 |
+
k0 = km
|
| 46 |
+
|
| 47 |
+
if k0 == len(scale)-1:
|
| 48 |
+
q = k0
|
| 49 |
+
elif k0 == 0:
|
| 50 |
+
q = 0
|
| 51 |
+
else:
|
| 52 |
+
d0 = abs(lx-scale[k0])
|
| 53 |
+
d1 = abs(lx-scale[k1])
|
| 54 |
+
if d0 < d1:
|
| 55 |
+
q = k0
|
| 56 |
+
else:
|
| 57 |
+
q = k1
|
| 58 |
+
|
| 59 |
+
y[i] = q+1 # add 1 to leave space for zero
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@njit(parallel=True)
def average_pool(x, factor=2, missing=65535):
    """Downsample a 2D array by `factor` with a missing-aware mean.

    Each factor x factor window is averaged over its non-missing pixels;
    if fewer than half of the window's pixels are valid, the output
    pixel is set to `missing` instead.

    NOTE(review): with factor=1, N_thresh == 0 and a fully-missing pixel
    would divide by zero in v/num_valid -- assumes factor >= 2; confirm
    with callers.
    """
    y = np.empty((x.shape[0]//factor, x.shape[1]//factor), dtype=x.dtype)
    N = factor**2
    # minimum number of valid pixels required to emit an average
    N_thresh = N//2

    for iy in prange(y.shape[0]):
        ix0 = iy * factor
        ix1 = ix0 + factor
        for jy in range(y.shape[1]):
            jx0 = jy * factor
            jx1 = jx0 + factor
            v = float(0.0)
            num_valid = 0

            for ix in range(ix0, ix1):
                for jx in range(jx0, jx1):
                    if x[ix,jx] != missing:
                        v += x[ix,jx]
                        num_valid += 1

            if num_valid >= N_thresh:
                y[iy,jy] = v/num_valid
            else:
                y[iy,jy] = missing

    return y
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@njit(parallel=True)
def mode_pool(x, num_values=256, factor=2):
    """Downsample a 2D array of categorical codes by taking the most
    frequent value (the mode) of each factor x factor window.

    num_values must exceed the largest code present in x, since it
    sizes the histogram buffer (out-of-range codes would index past the
    end).
    """
    y = np.empty((x.shape[0]//factor, x.shape[1]//factor), dtype=x.dtype)

    for iy in prange(y.shape[0]):
        # histogram buffer, private to each parallel row iteration
        v = np.empty(num_values, dtype=np.int64)
        ix0 = iy * factor
        ix1 = ix0 + factor
        for jy in range(y.shape[1]):
            jx0 = jy * factor
            jx1 = jx0 + factor
            v[:] = 0

            for ix in range(ix0, ix1):
                for jx in range(jx0, jx1):
                    v[x[ix,jx]] += 1

            # argmax picks the smallest code on ties
            y[iy,jy] = v.argmax()

    return y
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def fill_holes(missing=65535, rad=1):
    """Build a transform filling isolated missing pixels in a 2D array.

    Each pixel equal to `missing` that has at least one valid pixel in
    its (2*rad+1) x (2*rad+1) neighborhood is replaced by the mean of
    those valid neighbors (rounded for integer dtypes). Missing pixels
    with no valid neighbors are left untouched.

    Args:
        missing: sentinel value marking missing data.
        rad: neighborhood radius.

    Returns:
        fill(x): returns a filled copy of x (x is not modified).
    """
    # neighborhood kernel is invariant across calls; build it once
    # (previously rebuilt inside fill on every call)
    o = np.ones((2*rad+1, 2*rad+1), dtype=np.uint16)

    def fill(x):
        # identify mask of points to fill: missing, but with at least
        # one valid neighbor
        valid = (x != missing)
        num_valid_neighbors = convolve(valid, o, mode='same', method='direct')
        mask = ~valid & (num_valid_neighbors > 0)

        # compute mean of valid points around each fillable point
        # (missing pixels are zeroed so they contribute nothing)
        fx = x.copy()
        fx[~valid] = 0
        mx = convolve(fx, o.astype(np.float64), mode='same', method='direct')
        mx = mx[mask] / num_valid_neighbors[mask]
        if np.issubdtype(x.dtype, np.integer):
            mx = mx.round().astype(x.dtype)

        # fill holes with the neighborhood mean
        fx = x.copy()
        fx[mask] = mx
        return fx

    return fill
|
| 136 |
+
|
ldcast/forecast.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import gc
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.multiprocessing as mp
|
| 7 |
+
|
| 8 |
+
from .features.transform import Antialiasing
|
| 9 |
+
from .models.autoenc import autoenc, encoder
|
| 10 |
+
from .models.genforecast import analysis, unet
|
| 11 |
+
from .models.diffusion import diffusion, plms
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Forecast:
    """End-to-end precipitation nowcasting with a latent diffusion model.

    Wraps the autoencoder, the AFNO analysis cascade and the denoising
    U-Net into a single callable that maps past rain-rate frames to
    predicted future frames via PLMS sampling.
    """

    def __init__(
        self,
        *,
        ldm_weights_fn,        # path to LDM weights (torch.load-able state dict)
        autoenc_weights_fn,    # path to autoencoder weights
        gpu='auto',            # 'auto': use CUDA if available; int: cuda:<gpu>; None: CPU
        past_timesteps=4,
        future_timesteps=20,
        autoenc_time_ratio=4,      # temporal compression factor of the autoencoder
        autoenc_hidden_dim=32,     # channel width of the latent space
        verbose=True,
        R_min_value=0.1,           # rain rates below this count as "no rain"
        R_zero_value=0.02,         # value substituted for "no rain" before log10
        R_min_output=0.1,          # predictions below this are zeroed
        R_max_output=118.428,      # predictions above this are clipped
        log_R_mean=-0.051,         # standardization constants in log10 space
        log_R_std=0.528,
    ):
        self.ldm_weights_fn = ldm_weights_fn
        self.autoenc_weights_fn = autoenc_weights_fn
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim
        self.antialiasing = Antialiasing()

        # setup LDM
        self.ldm = self._init_model()
        if gpu is not None:
            if gpu == 'auto':
                if torch.cuda.device_count() > 0:
                    self.ldm.to(device="cuda")
            else:
                self.ldm.to(device=f"cuda:{gpu}")
        # setup sampler
        self.sampler = plms.PLMSSampler(self.ldm)
        print(self.ldm.device)  # NOTE(review): leftover debug print?
        gc.collect()

    def _init_model(self):
        """Assemble autoencoder, analysis cascade, denoiser and LDM, and
        load pretrained weights from the configured files."""
        # setup autoencoder
        enc = encoder.SimpleConvEncoder()
        dec = encoder.SimpleConvDecoder()
        autoencoder_obs = autoenc.AutoencoderKL(enc, dec)
        #print(torch.load(self.autoenc_weights_fn)['state_dict'].keys())
        # autoencoder_obs.load_state_dict(torch.load(self.autoenc_weights_fn)['state_dict'])
        # the weights file is expected to hold a bare state dict rather
        # than a Lightning checkpoint (see commented-out variant above)
        autoencoder_obs.load_state_dict(torch.load(self.autoenc_weights_fn))
        autoencoders = [autoencoder_obs]
        # number of latent time patches covering the past frames
        input_patches = [self.past_timesteps//self.autoenc_time_ratio]
        input_size_ratios = [1]
        embed_dim = [128]
        analysis_depth = [4]

        # setup forecaster
        analysis_net = analysis.AFNONowcastNetCascade(
            autoencoders,
            input_patches=input_patches,
            input_size_ratios=input_size_ratios,
            train_autoenc=False,
            output_patches=self.future_timesteps//self.autoenc_time_ratio,
            cascade_depth=3,
            embed_dim=embed_dim,
            analysis_depth=analysis_depth
        )

        # setup denoiser
        denoiser = unet.UNetModel(in_channels=autoencoder_obs.hidden_width,
            model_channels=256, out_channels=autoencoder_obs.hidden_width,
            num_res_blocks=2, attention_resolutions=(1,2),
            dims=3, channel_mult=(1, 2, 4), num_heads=8,
            num_timesteps=self.future_timesteps//self.autoenc_time_ratio,
            context_ch=analysis_net.cascade_dims
        )

        # create LDM
        ldm = diffusion.LatentDiffusion(denoiser, autoencoder_obs,
            context_encoder=analysis_net)
        # ldm.load_state_dict(torch.load(self.ldm_weights_fn)['state_dict'])
        ldm.load_state_dict(torch.load(self.ldm_weights_fn))
        return ldm

    def __call__(
        self,
        R_past,
        num_diffusion_iters=50
    ):
        """Generate one forecast from past rain-rate frames.

        Args:
            R_past: past rain-rate frames; presumably shaped
                (past_timesteps, H, W) with H and W divisible by 4 --
                TODO confirm against callers.
            num_diffusion_iters: number of PLMS sampling steps.

        Returns:
            Predicted rain rates for the future timesteps (NumPy array,
            leading batch dimension removed).
        """
        # preprocess inputs and setup correct input shape
        x = self.transform_precip(R_past)
        timesteps = self.input_timesteps(x)
        future_patches = self.future_timesteps // self.autoenc_time_ratio
        # latent shape: the autoencoder downsamples space by a factor of 4
        gen_shape = (self.autoenc_hidden_dim, future_patches) + \
            (x.shape[-2]//4, x.shape[-1]//4)
        x = [[x, timesteps]]

        # run LDM sampler, silencing its stdout output
        with contextlib.redirect_stdout(None):
            (s, intermediates) = self.sampler.sample(
                num_diffusion_iters,
                x[0][0].shape[0],
                gen_shape,
                x,
                progbar=self.verbose
            )

        # postprocess outputs
        y_pred = self.ldm.autoencoder.decode(s)
        R_pred = self.inv_transform_precip(y_pred)

        return R_pred[0,...]

    def transform_precip(self, R):
        """Map rain rates into standardized log10 space and antialias.

        NOTE(review): R is treated as a torch tensor here
        (clone/detach), yet np.log10 and the scipy-based Antialiasing
        are applied to it afterwards; the commented-out R.copy()
        suggests the input was originally a NumPy array -- confirm the
        expected input type.
        """
        # x = R.copy()
        x = R.clone().detach()
        # sub-threshold and NaN values all become the "zero rain" value
        x[~(x >= self.R_min_value)] = self.R_zero_value
        x = np.log10(x)
        x -= self.log_R_mean
        x /= self.log_R_std
        x = x.reshape((1,) + x.shape)
        x = self.antialiasing(x)
        x = x.reshape((1,) + x.shape)  # prepend batch dimension
        return torch.Tensor(x).to(device=self.ldm.device)

    def inv_transform_precip(self, x):
        """Invert transform_precip: de-standardize, exponentiate, clip
        to [R_min_output, R_max_output], and return as a NumPy array."""
        x *= self.log_R_std
        x += self.log_R_mean
        R = torch.pow(10, x)
        if self.R_min_output:
            R[R < self.R_min_output] = 0.0
        if self.R_max_output is not None:
            R[R > self.R_max_output] = self.R_max_output
        R = R[:,0,...]  # drop the channel dimension
        return R.to(device='cpu').numpy()

    def input_timesteps(self, x):
        """Relative time coordinates for the past frames: one value per
        frame, ending at 0 for the most recent frame."""
        batch_size = x.shape[0]
        t0 = -x.shape[2]+1
        t1 = 1
        timesteps = torch.arange(t0, t1,
            dtype=x.dtype, device=self.ldm.device)
        return timesteps.unsqueeze(0).expand(batch_size,-1)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class ForecastDistributed:
    """Multi-GPU wrapper around Forecast.

    Spawns one worker process per CUDA device, each holding its own
    Forecast instance, and distributes (sample, ensemble member)
    requests over a queue.
    """

    def __init__(
        self,
        ldm_weights_fn,
        autoenc_weights_fn,
        past_timesteps=4,
        future_timesteps=8,
        autoenc_time_ratio=4,
        autoenc_hidden_dim=32,
        verbose=True,
        R_min_value=0.1,
        R_zero_value=0.02,
        R_min_output=0.1,
        R_max_output=118.428,
        log_R_mean=-0.051,
        log_R_std=0.528,
    ):
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim

        # start worker processes ('spawn' is required for CUDA safety)
        context = mp.get_context('spawn')
        self.input_queue = context.Queue()
        self.output_queue = context.Queue()
        process_kwargs = {
            "past_timesteps": past_timesteps,
            "future_timesteps": future_timesteps,
            "ldm_weights_fn": ldm_weights_fn,
            "autoenc_weights_fn": autoenc_weights_fn,
            "autoenc_time_ratio": autoenc_time_ratio,
            "autoenc_hidden_dim": autoenc_hidden_dim,
            "R_min_value": R_min_value,
            "R_zero_value": R_zero_value,
            "R_min_output": R_min_output,
            "R_max_output": R_max_output,
            "log_R_mean": log_R_mean,
            "log_R_std": log_R_std,
            "verbose": True
        }
        # NOTE(review): max(0, ...) is a no-op -- with no GPUs this
        # spawns zero workers and __call__ would block forever; was
        # max(1, ...) intended as a CPU fallback? Confirm.
        self.num_procs = max(0, torch.cuda.device_count())
        self.compute_procs = mp.spawn(
            _compute_process,
            args=(self.input_queue, self.output_queue, process_kwargs),
            nprocs=self.num_procs,
            join=False
        )

        # wait for worker processes to be ready (each sends one token)
        for _ in range(self.num_procs):
            self.output_queue.get()

        gc.collect()

    def __call__(
        self,
        R_past,
        ensemble_members=1,
        num_diffusion_iters=50
    ):
        """Forecast a batch of samples with an ensemble per sample.

        Args:
            R_past: past frames, presumably shaped
                (num_samples, past_timesteps, H, W) -- TODO confirm.
            ensemble_members: forecasts generated per sample.
            num_diffusion_iters: PLMS steps per forecast.

        Returns:
            Array of shape
            (num_samples, future_timesteps) + R_past.shape[2:]
            + (ensemble_members,).
        """
        # send samples to compute processes
        for (i, R_past_sample) in enumerate(R_past):
            for j in range(ensemble_members):
                self.input_queue.put((R_past_sample, num_diffusion_iters, i, j))

        # build output array
        pred_shape = (R_past.shape[0], self.future_timesteps) + \
            R_past.shape[2:] + (ensemble_members,)
        R_pred = np.empty(pred_shape, R_past.dtype)

        # gather outputs from processes (arrival order is arbitrary;
        # each result carries its sample/member indices)
        predictions_needed = R_past.shape[0] * ensemble_members
        for _ in range(predictions_needed):
            (R_pred_sample, i, j) = self.output_queue.get()
            R_pred[i,...,j] = R_pred_sample

        return R_pred

    def __del__(self):
        # send one shutdown sentinel per worker, then wait for them
        for _ in range(self.num_procs):
            self.input_queue.put(None)
        self.compute_procs.join()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _compute_process(process_index, input_queue, output_queue, kwargs):
    """Worker loop for ForecastDistributed.

    Builds a Forecast bound to one GPU (or the CPU when no GPU is
    available), signals readiness, then serves forecast requests from
    input_queue until a None sentinel arrives.
    """
    use_gpu = torch.cuda.device_count() > 0
    fc = Forecast(gpu=process_index if use_gpu else None, **kwargs)
    output_queue.put("Ready")  # signal process ready to accept inputs

    while True:
        data = input_queue.get()
        if data is None:  # shutdown sentinel from the parent
            break
        (R_past, num_diffusion_iters, sample, member) = data
        R_pred = fc(R_past, num_diffusion_iters=num_diffusion_iters)
        output_queue.put((R_pred, sample, member))
|
ldcast/models/autoenc/autoenc.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
|
| 5 |
+
from ..distributions import kl_from_standard_normal, ensemble_nll_normal
|
| 6 |
+
from ..distributions import sample_from_standard_normal
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AutoencoderKL(pl.LightningModule):
    """Variational autoencoder with a KL-regularized Gaussian latent.

    The encoder output is projected by a 1x1x1 Conv3d to mean/log-var of
    the latent distribution; a matching projection maps latents back to
    the decoder's input width. Trained with L1 reconstruction plus a
    KL penalty weighted by kl_weight.
    """

    def __init__(
        self,
        encoder, decoder,
        kl_weight=0.01,          # weight of the KL term in the loss
        encoded_channels=64,     # channel width of encoder output / decoder input
        hidden_width=32,         # channel width of the latent space
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_width = hidden_width
        # projects encoder features to (mean, log_var), hence 2x width
        self.to_moments = nn.Conv3d(encoded_channels, 2*hidden_width,
            kernel_size=1)
        self.to_decoder = nn.Conv3d(hidden_width, encoded_channels,
            kernel_size=1)
        # scalar parameter; NOTE(review): not used in the visible loss --
        # possibly consumed elsewhere (e.g. ensemble_nll_normal)?
        self.log_var = nn.Parameter(torch.zeros(size=()))
        self.kl_weight = kl_weight

    def encode(self, x):
        """Encode x into the latent Gaussian's (mean, log_var)."""
        h = self.encoder(x)
        (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1)
        return (mean, log_var)

    def decode(self, z):
        """Decode a latent sample z back into data space."""
        z = self.to_decoder(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full pass: encode, sample (or take the mean), decode.

        Returns (reconstruction, mean, log_var).
        """
        (mean, log_var) = self.encode(input)
        if sample_posterior:
            z = sample_from_standard_normal(mean, log_var)
        else:
            z = mean
        dec = self.decode(z)
        return (dec, mean, log_var)

    def _loss(self, batch):
        """Compute (total, reconstruction, KL) losses for one batch."""
        (x,y) = batch
        # unwrap nested [[tensor, ...]] batch structures produced by the
        # data pipeline until a raw tensor is reached
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0][0]
        (y_pred, mean, log_var) = self.forward(x)

        rec_loss = (y-y_pred).abs().mean()  # L1 reconstruction
        kl_loss = kl_from_standard_normal(mean, log_var)

        total_loss = rec_loss + self.kl_weight * kl_loss

        return (total_loss, rec_loss, kl_loss)

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)[0]
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def val_test_step(self, batch, batch_idx, split="val"):
        """Shared evaluation step for validation and test splits."""
        (total_loss, rec_loss, kl_loss) = self._loss(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log(f"{split}_loss", total_loss, **log_params)
        self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params)
        self.log(f"{split}_kl_loss", kl_loss, **log_params)

    def validation_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="val")

    def test_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="test")

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau on validation reconstruction loss."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3,
            betas=(0.5, 0.9), weight_decay=1e-3)
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_rec_loss",
                "frequency": 1,
            },
        }
|
ldcast/models/autoenc/encoder.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from ..blocks.resnet import ResBlock3D
|
| 5 |
+
from ..utils import activation, normalization
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SimpleConvEncoder(nn.Sequential):
    """Convolutional encoder: `levels` stages of (ResBlock3D + strided
    Conv3d), each stage halving time, height and width.

    The first residual block uses a full 3x3x3 kernel; deeper stages use
    spatial-only 1x3x3 kernels.
    """

    def __init__(self, in_dim=1, levels=2, min_ch=64):
        sequence = []
        # channel progression: in_dim, then max(8**k, min_ch) per level
        channels = np.hstack([
            in_dim,
            (8**np.arange(1,levels+1)).clip(min=min_ch)
        ])

        for i in range(levels):
            in_channels = int(channels[i])
            out_channels = int(channels[i+1])
            res_kernel_size = (3,3,3) if i == 0 else (1,3,3)
            res_block = ResBlock3D(
                in_channels, out_channels,
                kernel_size=res_kernel_size,
                norm_kwargs={"num_groups": 1}
            )
            sequence.append(res_block)
            # 2x downsampling in each of T, H, W
            downsample = nn.Conv3d(out_channels, out_channels,
                kernel_size=(2,2,2), stride=(2,2,2))
            sequence.append(downsample)
            # NOTE(review): redundant -- in_channels is reassigned from
            # channels[i] at the top of the loop
            in_channels = out_channels

        super().__init__(*sequence)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SimpleConvDecoder(nn.Sequential):
    """Convolutional decoder mirroring SimpleConvEncoder: `levels`
    stages of (ConvTranspose3d upsample + ResBlock3D), each stage
    doubling time, height and width.

    The last residual block (i == 0) uses a full 3x3x3 kernel; earlier
    stages use spatial-only 1x3x3 kernels.
    """

    def __init__(self, in_dim=1, levels=2, min_ch=64):
        sequence = []
        # same channel table as the encoder, traversed in reverse
        channels = np.hstack([
            in_dim,
            (8**np.arange(1,levels+1)).clip(min=min_ch)
        ])

        for i in reversed(list(range(levels))):
            in_channels = int(channels[i+1])
            out_channels = int(channels[i])
            # 2x upsampling in each of T, H, W
            upsample = nn.ConvTranspose3d(in_channels, in_channels,
                kernel_size=(2,2,2), stride=(2,2,2))
            sequence.append(upsample)
            res_kernel_size = (3,3,3) if (i == 0) else (1,3,3)
            res_block = ResBlock3D(
                in_channels, out_channels,
                kernel_size=res_kernel_size,
                norm_kwargs={"num_groups": 1}
            )
            sequence.append(res_block)
            # NOTE(review): redundant -- in_channels is reassigned from
            # channels[i+1] at the top of the loop
            in_channels = out_channels

        super().__init__(*sequence)
|
ldcast/models/autoenc/training.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from . import autoenc
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def setup_autoenc_training(
    encoder,
    decoder,
    model_dir,
):
    """Build an AutoencoderKL and a pl.Trainer configured to train it.

    Args:
        encoder, decoder: network modules passed to AutoencoderKL.
        model_dir: directory where checkpoints are written.

    Returns:
        (autoencoder, trainer) tuple ready for trainer.fit(...).
    """
    autoencoder = autoenc.AutoencoderKL(encoder, decoder)

    num_gpus = torch.cuda.device_count()
    accelerator = "gpu" if (num_gpus > 0) else "cpu"
    devices = torch.cuda.device_count() if (accelerator == "gpu") else 1

    # stop when validation reconstruction loss stops improving
    early_stopping = pl.callbacks.EarlyStopping(
        "val_rec_loss", patience=6, verbose=True
    )
    print(model_dir)  # NOTE(review): leftover debug print?
    # keep the 3 best checkpoints by validation reconstruction loss
    checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=model_dir,
        filename="{epoch}-{val_rec_loss:.4f}",
        #filename=ckpt,
        monitor="val_rec_loss",
        every_n_epochs=1,
        save_top_k=3,
        save_weights_only=False,
    )
    callbacks = [early_stopping, checkpoint]

    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=1000,
        #strategy='ddp' if (num_gpus > 1) else None,
        callbacks=callbacks,
    )

    return (autoencoder, trainer)
|
ldcast/models/benchmarks/dgmr.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DGMRModel:
    """Wrapper around a pretrained DGMR TensorFlow SavedModel for
    ensemble precipitation nowcasting.

    The model is loaded under a ``tf.distribute`` strategy scope and
    called through its ``'default'`` serving signature. Optional
    transforms convert between the project's normalized units and
    physical rain rates at the model boundary.
    """
    def __init__(
        self,
        model_handle,
        multi_gpu=True,
        transform_to_rainrate=None,
        transform_from_rainrate=None,
        data_format='channels_first',
        calibrated=False,
    ):
        # model_handle: path/handle accepted by tf.saved_model.load
        # calibrated: if True, the latent noise is scaled by 2.0 in __call__
        self.transform_to_rainrate = transform_to_rainrate
        self.transform_from_rainrate = transform_from_rainrate
        self.data_format = data_format
        self.calibrated = calibrated

        if multi_gpu and len(tf.config.list_physical_devices('GPU')) > 1:
            # initialize multi-GPU strategy
            strategy = tf.distribute.MirroredStrategy()
        else: # use default strategy
            strategy = tf.distribute.get_strategy()

        with strategy.scope():
            module = tf.saved_model.load(model_handle)

        # Read the latent-noise width and the number of conditioning
        # frames directly from the serving signature so they always
        # match the loaded checkpoint.
        self.model = module.signatures['default']
        input_signature = self.model.structured_input_signature[1]
        self.noise_dim = input_signature['z'].shape[1]
        self.past_timesteps = input_signature['labels$cond_frames'].shape[1]

    def __call__(self, x):
        """Run one stochastic forecast for a batch of past-frame sequences.

        Returns only the predicted future frames (the conditioning
        frames echoed by the model are stripped).
        """
        # unwrap nested list/tuple containers down to the first array
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0]
        x = np.array(x, copy=False)
        if self.data_format == "channels_first":
            # N,C,T,H,W -> N,T,H,W,C: the TF model expects channels last
            x = x.transpose(0,2,3,4,1)
        if self.transform_to_rainrate is not None:
            x = self.transform_to_rainrate(x)
        x = tf.convert_to_tensor(x)

        num_samples = x.shape[0]
        # fresh latent noise per call -> each call yields one ensemble member
        z = tf.random.normal(shape=(num_samples, self.noise_dim))
        if self.calibrated:
            # NOTE(review): widening the noise by 2.0 presumably comes from
            # an external calibration study -- confirm before changing
            z = z * 2.0

        onehot = tf.ones(shape=(num_samples, 1))
        inputs = {
            "z": z,
            "labels$onehot" : onehot,
            "labels$cond_frames" : x
        }
        y = self.model(**inputs)['default']
        # the output contains the conditioning frames first; keep only the forecast
        y = y[:,self.past_timesteps:,...]

        y = np.array(y)
        if self.transform_from_rainrate is not None:
            y = self.transform_from_rainrate(y)
        if self.data_format == "channels_first":
            # back to N,C,T,H,W
            y = y.transpose(0,4,1,2,3)

        return y
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def create_ensemble(
    dgmr, x,
    ensemble_size=32,
    model_path="../models/dgmr/256x256",  # unused; kept only for backward compatibility
):
    """Generate an ensemble forecast by calling ``dgmr`` repeatedly.

    Parameters
    ----------
    dgmr : callable
        A stochastic forecast model (e.g. ``DGMRModel``); each call
        produces one ensemble member from input ``x``.
    x : array-like
        The conditioning input, passed unchanged to every call.
    ensemble_size : int
        Number of members to generate.
    model_path : str
        Deprecated/unused; retained so existing callers that pass it
        keep working.

    Returns
    -------
    np.ndarray
        Members stacked along a new last axis.
    """
    y_pred = []
    for member in range(ensemble_size):
        print(f"Generating member {member+1}/{ensemble_size}")
        y_pred.append(dgmr(x))
        # free graph/tensor garbage between members to bound memory use
        gc.collect()

    return np.stack(y_pred, axis=-1)
|
ldcast/models/benchmarks/pysteps.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# following https://pysteps.readthedocs.io/en/stable/auto_examples/plot_steps_nowcast.html
|
| 2 |
+
from datetime import timedelta
|
| 3 |
+
|
| 4 |
+
import dask
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pysteps import nowcasts
|
| 7 |
+
from pysteps.motion.lucaskanade import dense_lucaskanade
|
| 8 |
+
from pysteps.utils import transformation
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PySTEPSModel:
    """PySTEPS STEPS-nowcast ensemble benchmark model.

    Wraps ``pysteps.nowcasts`` "steps" with Lucas-Kanade optical flow,
    converting between the project's normalized units and rain rates
    (dB-transformed internally, as STEPS expects).
    """
    def __init__(
        self,
        data_format='channels_first',
        future_timesteps=20,
        ensemble_size=32,
        km_per_pixel=1.0,
        interval=timedelta(minutes=5),
        transform_to_rainrate=None,
        transform_from_rainrate=None,
    ):
        self.transform_to_rainrate = transform_to_rainrate
        self.transform_from_rainrate = transform_from_rainrate
        self.data_format = data_format
        self.nowcast_method = nowcasts.get_method("steps")
        self.future_timesteps = future_timesteps
        self.ensemble_size = ensemble_size
        self.km_per_pixel = km_per_pixel
        self.interval = interval

    def zero_prediction(self, R, zerovalue):
        """Return an all-`zerovalue` forecast shaped (T, H, W, ensemble)."""
        out_shape = (self.future_timesteps,) + R.shape[1:] + \
            (self.ensemble_size,)
        return np.full(out_shape, zerovalue, dtype=R.dtype)

    def predict_sample(self, x, threshold=-10.0, zerovalue=-15.0):
        """Nowcast one sample (T, H, W); thresholds are in dB units."""
        # Fix: the transform was applied unconditionally, which raised
        # TypeError whenever the default transform_to_rainrate=None was
        # kept; now guarded, matching transform_from_rainrate below.
        if self.transform_to_rainrate is not None:
            R = self.transform_to_rainrate(x)
        else:
            R = x
        (R, _) = transformation.dB_transform(
            R, threshold=0.1, zerovalue=zerovalue
        )
        R[~np.isfinite(R)] = zerovalue
        if (R == zerovalue).all():
            # no precipitation anywhere: STEPS cannot run, forecast zeros
            R_f = self.zero_prediction(R, zerovalue)
        else:
            V = dense_lucaskanade(R)
            try:
                R_f = self.nowcast_method(
                    R,
                    V,
                    self.future_timesteps,
                    n_ens_members=self.ensemble_size,
                    n_cascade_levels=6,
                    precip_thr=threshold,
                    kmperpixel=self.km_per_pixel,
                    timestep=self.interval.total_seconds()/60,
                    noise_method="nonparametric",
                    vel_pert_method="bps",
                    mask_method="incremental",
                    num_workers=2
                )
                # (ens, T, H, W) -> (T, H, W, ens)
                R_f = R_f.transpose(1,2,3,0)
            except (ValueError, RuntimeError) as e:
                zero_error = str(e).endswith("contains non-finite values") or \
                    str(e).startswith("zero-size array to reduction operation") or \
                    str(e).endswith("nonstationary AR(p) process")
                if zero_error:
                    # occasional PySTEPS errors that happen with little/no precip
                    # therefore returning all zeros makes sense
                    R_f = self.zero_prediction(R, zerovalue)
                else:
                    raise

        # Back-transform to rain rates
        R_f = transformation.dB_transform(
            R_f, threshold=threshold, inverse=True
        )[0]

        if self.transform_from_rainrate is not None:
            R_f = self.transform_from_rainrate(R_f)

        return R_f

    def __call__(self, x, parallel=True):
        """Nowcast a batch; optionally run samples in parallel with dask threads."""
        # unwrap nested list/tuple containers down to the first array
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0]
        x = np.array(x, copy=False)
        if self.data_format == "channels_first":
            x = x.transpose(0,2,3,4,1)

        pred = self.predict_sample
        if parallel:
            pred = dask.delayed(pred)
        y = [
            pred(x[i,:,:,:,0])
            for i in range(x.shape[0])
        ]
        if parallel:
            y = dask.compute(y, scheduler="threads", num_workers=len(y))[0]
        y = np.stack(y, axis=0)

        # restore a singleton channel axis in the requested layout
        if self.data_format == "channels_first":
            y = np.expand_dims(y, 1)
        else:
            y = np.expand_dims(y, -2)

        return y
|
ldcast/models/benchmarks/transform.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def transform_to_rainrate(x, mean=-0.051, std=0.528, threshold=0.1):
    """Invert the standardized log10 transform back to rain rate (mm/h).

    Values below `threshold` are clamped to exactly zero (no rain).
    """
    rainrate = 10 ** (x * std + mean)
    rainrate[rainrate < threshold] = 0
    return rainrate
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def transform_from_rainrate(
    R, mean=-0.051, std=0.528,
    threshold=0.1, fill_value=0.02
):
    """Standardize rain rate (mm/h) into log10 units: (log10(R) - mean) / std.

    Sub-threshold values are replaced by `fill_value` before the log so the
    result stays finite; the input array is not modified.
    """
    filled = np.where(R < threshold, fill_value, R)
    return (np.log10(filled) - mean) / std
|
ldcast/models/blocks/afno.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#reference: https://github.com/NVlabs/AFNO-transformer
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from .attention import TemporalAttention
|
| 9 |
+
|
| 10 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear,
    with dropout after each stage (identity when drop == 0)."""
    def __init__(
        self,
        in_features, hidden_features=None, out_features=None,
        act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AFNO2D(nn.Module):
    """Adaptive Fourier Neural Operator mixing layer for 2D inputs (B, H, W, C).

    Mixes tokens in the Fourier domain: rfft2 over (H, W), a block-diagonal
    two-layer complex MLP over the channel blocks, soft-shrinkage sparsity,
    then an inverse FFT and a residual connection.
    """
    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        # channels are split into num_blocks independent groups of block_size
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Complex weights stored as [real, imag] pairs in the leading dim:
        # w1/b1: first (expanding) layer, w2/b2: second (contracting) layer.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # keep the input for the residual connection
        bias = x

        # FFT is done in float32; cast back to the original dtype at the end
        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        # split channels into (num_blocks, block_size) for block-diagonal mixing
        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        # pre-allocate outputs; modes outside the kept band stay zero
        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # NOTE(review): the H slice keeps a band centered at total_modes
        # without an fftshift; with hard_thresholding_fraction == 1 this
        # covers all rows, but for fractions < 1 confirm the intended
        # frequency selection against the upstream AFNO reference.

        # first complex layer with ReLU: (a+bi)(c+di) expanded into
        # real/imag parts via separate einsums
        o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        # second complex layer (no activation)
        o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        # soft-shrinkage promotes sparsity in the spectral coefficients
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
        x = x.type(dtype)

        return x + bias
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(nn.Module):
    """AFNO transformer block: pre-norm spectral filter followed by a
    pre-norm MLP, each wrapped in a residual connection.

    With double_skip=True the filter output gets its own residual and
    becomes the new skip for the MLP stage; otherwise a single residual
    from the block input is added at the end.
    """
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x):
        skip = x
        h = self.filter(self.norm1(x))
        if self.double_skip:
            h = h + skip
            skip = h
        h = self.mlp(self.norm2(h))
        return h + skip
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class AFNO3D(nn.Module):
    """Adaptive Fourier Neural Operator mixing layer for 3D inputs (B, D, H, W, C).

    Same construction as AFNO2D but with rfftn over (D, H, W); the hard
    thresholding band is computed from H and applied to the H and W axes
    (the D axis is kept in full).
    """
    def __init__(
        self, hidden_size, num_blocks=8, sparsity_threshold=0.01,
        hard_thresholding_fraction=1, hidden_size_factor=1
    ):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        # channels are split into num_blocks independent groups of block_size
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Complex weights stored as [real, imag] pairs in the leading dim.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # keep the input for the residual connection
        bias = x

        # FFT runs in float32; restore the original dtype at the end
        dtype = x.dtype
        x = x.float()
        B, D, H, W, C = x.shape

        x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho")
        # split channels into (num_blocks, block_size) for block-diagonal mixing
        x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size)

        # pre-allocate outputs; modes outside the kept band stay zero
        o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # NOTE(review): as in AFNO2D, the H band is centered at total_modes
        # with no fftshift; at fraction 1 all modes are kept -- for
        # fractions < 1 verify the intended selection.

        # first complex layer with ReLU (real/imag parts via separate einsums)
        o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        # second complex layer (no activation)
        o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        # soft-shrinkage promotes sparsity in the spectral coefficients
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, D, H, W // 2 + 1, C)
        x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho")
        x = x.type(dtype)

        return x + bias
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class AFNOBlock3d(nn.Module):
    """3D AFNO transformer block (spectral filter + MLP, pre-norm residuals).

    Accepts either channels-last (native) or channels-first input; in the
    latter case the tensor is permuted on entry and back on exit.
    """
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
        data_format="channels_last",
        mlp_out_features=None,
    ):
        super().__init__()
        self.norm_layer = norm_layer
        self.norm1 = norm_layer(dim)
        self.filter = AFNO3D(dim, num_blocks, sparsity_threshold,
            hard_thresholding_fraction)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim, out_features=mlp_out_features,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer, drop=drop
        )
        self.double_skip = double_skip
        self.channels_first = (data_format == "channels_first")

    def forward(self, x):
        if self.channels_first:
            # AFNO natively uses a channels-last data format
            x = x.permute(0, 2, 3, 4, 1)

        skip = x
        h = self.filter(self.norm1(x))
        if self.double_skip:
            h = h + skip
            skip = h
        h = self.mlp(self.norm2(h)) + skip

        if self.channels_first:
            h = h.permute(0, 4, 1, 2, 3)

        return h
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class PatchEmbed3d(nn.Module):
    """Embed a 3D volume into patch tokens with a strided Conv3d.

    Input is (B, C, D, H, W); output is channels-last (B, D', H', W', embed_dim)
    where each spatial dim is divided by the corresponding patch size.
    """
    def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride: non-overlapping patches
        self.proj = nn.Conv3d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # project, then move channels to the last axis (BDHWC)
        return self.proj(x).permute(0, 2, 3, 4, 1)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class PatchExpand3d(nn.Module):
    """Expand patch tokens back into a 3D volume (inverse of PatchEmbed3d).

    Input is channels-last (B, D, H, W, embed_dim); output is channels-first
    (B, out_chans, D*p0, H*p1, W*p2). The per-token projection produces
    p0*p1*p2*out_chans values that are scattered into the patch.

    Rewritten with native view/permute instead of einops.rearrange: the
    pattern "(p0 p1 p2 c_out)" split + "(d p0)(h p1)(w p2)" merge is a fixed
    reshape, so the third-party dependency (and the redundant d/h/w keyword
    arguments einops infers by itself) is unnecessary.
    """
    def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size))

    def forward(self, x):
        x = self.proj(x)
        (b, d, h, w, _) = x.shape
        (p0, p1, p2) = self.patch_size
        # split the channel axis as (p0, p1, p2, c_out), p0 slowest --
        # identical to einops "(p0 p1 p2 c_out)"
        x = x.view(b, d, h, w, p0, p1, p2, -1)
        # b d h w p0 p1 p2 c -> b c (d p0) (h p1) (w p2)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(b, -1, d * p0, h * p1, w * p2)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class AFNOCrossAttentionBlock3d(nn.Module):
    """ AFNO 3D Block with channel mixing from two sources.

    The conditioning tensor `y` is concatenated to `x` along the channel
    axis; a linear pre-projection, an AFNO3D spectral filter, and an MLP
    (projecting back to `dim` channels) are applied, each with a residual.
    """
    def __init__(
        self,
        dim,
        context_dim,
        mlp_ratio=2.,
        drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.Identity,
        double_skip=True,  # NOTE(review): accepted but never used in this class
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
        data_format="channels_last",
        timesteps=None  # NOTE(review): accepted but never used in this class
    ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim+context_dim)
        mlp_hidden_dim = int((dim+context_dim) * mlp_ratio)
        # residual linear mixing of the concatenated channels
        self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim)
        self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold,
            hard_thresholding_fraction)
        # the MLP contracts the concatenated channels back to `dim`
        self.mlp = Mlp(
            in_features=dim+context_dim,
            out_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer, drop=drop
        )
        self.channels_first = (data_format == "channels_first")

    def forward(self, x, y):
        """Mix `x` (dim channels) with context `y` (context_dim channels)."""
        if self.channels_first:
            # AFNO natively uses a channels-last order
            x = x.permute(0,2,3,4,1)
            y = y.permute(0,2,3,4,1)

        # concatenate normalized input with the (un-normalized) context
        xy = torch.concat((self.norm1(x),y), axis=-1)
        xy = self.pre_proj(xy) + xy
        xy = self.filter(self.norm2(xy)) + xy # AFNO filter
        # final residual is on x only (the MLP output has `dim` channels)
        x = self.mlp(xy) + x # feed-forward

        if self.channels_first:
            x = x.permute(0,4,1,2,3)

        return x
|
ldcast/models/blocks/attention.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TemporalAttention(nn.Module):
    """Multi-head attention over the time axis of (B, T, H, W, C) tensors.

    Each spatial location (H, W) attends independently along its own time
    dimension; queries come from `x`, keys/values from `y` (self-attention
    when `y` is omitted). Query and context sequences may have different
    lengths.
    """
    def __init__(
        self, channels, context_channels=None,
        head_dim=32, num_heads=8
    ):
        super().__init__()
        self.channels = channels
        if context_channels is None:
            context_channels = channels
        self.context_channels = context_channels
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.inner_dim = head_dim * num_heads
        # standard 1/sqrt(d) attention scaling
        self.attn_scale = self.head_dim ** -0.5
        if channels % num_heads:
            raise ValueError("channels must be divisible by num_heads")
        # fused key/value projection; split with chunk() in forward
        self.KV = nn.Linear(context_channels, self.inner_dim*2)
        self.Q = nn.Linear(channels, self.inner_dim)
        self.proj = nn.Linear(self.inner_dim, channels)

    def forward(self, x, y=None):
        # y is the context; default to self-attention
        if y is None:
            y = x

        (K,V) = self.KV(y).chunk(2, dim=-1)
        (B, Dk, H, W, C) = K.shape  # Dk: context (key) sequence length
        shape = (B, Dk, H, W, self.num_heads, self.head_dim)
        K = K.reshape(shape)
        V = V.reshape(shape)

        Q = self.Q(x)
        (B, Dq, H, W, C) = Q.shape  # Dq: query sequence length
        shape = (B, Dq, H, W, self.num_heads, self.head_dim)
        Q = Q.reshape(shape)

        # move time to the matmul axes: attention acts on (Dq x Dk)
        K = K.permute((0,2,3,4,5,1)) # K^T
        V = V.permute((0,2,3,4,1,5))
        Q = Q.permute((0,2,3,4,1,5))

        attn = torch.matmul(Q, K) * self.attn_scale
        attn = F.softmax(attn, dim=-1)
        y = torch.matmul(attn, V)
        # restore (B, Dq, H, W, heads, head_dim) and merge the heads
        y = y.permute((0,4,1,2,3,5))
        y = y.reshape((B,Dq,H,W,C))
        y = self.proj(y)
        return y
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TemporalTransformer(nn.Module):
    """Pre-norm transformer layer over the time axis:
    self-attention, then cross-attention to `y`, then a feed-forward MLP,
    each with a residual connection.
    """
    def __init__(self,
        channels,
        mlp_dim_mul=1,
        **kwargs
    ):
        super().__init__()
        self.attn1 = TemporalAttention(channels, **kwargs)
        self.attn2 = TemporalAttention(channels, **kwargs)
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        self.norm3 = nn.LayerNorm(channels)
        self.mlp = MLP(channels, dim_mul=mlp_dim_mul)

    def forward(self, x, y):
        x = x + self.attn1(self.norm1(x))      # self attention
        x = x + self.attn2(self.norm2(x), y)   # cross attention
        return x + self.mlp(self.norm3(x))     # feed-forward
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MLP(nn.Sequential):
    """Feed-forward block: expand by `dim_mul`, SiLU, project back to `dim`."""
    def __init__(self, dim, dim_mul=4):
        hidden = dim * dim_mul
        super().__init__(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def positional_encoding(position, dims, add_dims=()):
    """Sinusoidal positional encoding.

    position: 1-D (N,) or 2-D (B, N) tensor of positions.
    dims: size of the encoding; the result holds dims//2 sine features
        followed by dims//2 cosine features on the last axis.
    add_dims: axes at which singleton dimensions are inserted afterwards.
    """
    # geometric frequency ladder, as in "Attention Is All You Need"
    freqs = torch.exp(
        torch.arange(0, dims, 2, device=position.device) *
        (-math.log(10000.0) / dims)
    )
    if position.ndim == 1:
        phase = position[:, None] * freqs[None, :]
    else:
        phase = position[:, :, None] * freqs[None, None, :]

    enc = torch.cat([phase.sin(), phase.cos()], dim=-1)
    for axis in add_dims:
        enc = enc.unsqueeze(axis)
    return enc
|
ldcast/models/blocks/resnet.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from torch.nn.utils.parametrizations import spectral_norm as sn
|
| 3 |
+
|
| 4 |
+
from ..utils import activation, normalization
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ResBlock3D(nn.Module):
    """Pre-activation 3D residual block with optional up-/downsampling.

    The main path is (norm -> act -> conv) twice; the skip path is a 1x1x1
    channel projection (identity if channels match) followed by the same
    spatial resampling as the main path, so both paths can be summed.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        resample: None, "down" (strided conv / avg-pool) or "up"
            (transposed conv / trilinear upsample).
        resample_factor: per-axis stride/scale factor for resampling.
        kernel_size: per-axis convolution kernel size (odd sizes expected
            so that k//2 padding preserves spatial size).
        act: activation name, or a pair of names for the two conv stages.
        norm: normalization type passed to the `normalization` factory.
        norm_kwargs: extra kwargs for the `normalization` factory.
        spectral_norm: if True, wrap the convs (and projection) in
            spectral normalization.
    """

    def __init__(
        self, in_channels, out_channels, resample=None,
        resample_factor=(1,1,1), kernel_size=(3,3,3),
        act='swish', norm='group', norm_kwargs=None,
        spectral_norm=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        # skip path: 1x1x1 projection only when the channel count changes
        if in_channels != out_channels:
            self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.proj = nn.Identity()

        # "same" padding for odd kernel sizes
        padding = tuple(k//2 for k in kernel_size)
        if resample == "down":
            # skip path pools; main path downsamples via the strided conv1
            self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True)
            self.conv1 = nn.Conv3d(in_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor, padding=padding)
            self.conv2 = nn.Conv3d(out_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
        elif resample == "up":
            # skip path interpolates; main path upsamples via the strided conv2
            self.resample = nn.Upsample(
                scale_factor=resample_factor, mode='trilinear')
            self.conv1 = nn.ConvTranspose3d(in_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
            # output_padding chosen so the transposed conv output size is
            # exactly input_size * stride for each axis
            output_padding = tuple(
                2*p+s-k for (p,s,k) in zip(padding,resample_factor,kernel_size)
            )
            self.conv2 = nn.ConvTranspose3d(out_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor,
                padding=padding, output_padding=output_padding)
        else:
            # no resampling: plain residual block
            self.resample = nn.Identity()
            self.conv1 = nn.Conv3d(in_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
            self.conv2 = nn.Conv3d(out_channels, out_channels,
                kernel_size=kernel_size, padding=padding)

        # a single string means: use the same activation for both stages
        if isinstance(act, str):
            act = (act, act)
        self.act1 = activation(act_type=act[0])
        self.act2 = activation(act_type=act[1])

        if norm_kwargs is None:
            norm_kwargs = {}
        self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs)
        self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs)
        if spectral_norm:
            self.conv1 = sn(self.conv1)
            self.conv2 = sn(self.conv2)
            if not isinstance(self.proj, nn.Identity):
                self.proj = sn(self.proj)


    def forward(self, x):
        # skip path: channel projection, then resampling
        x_in = self.resample(self.proj(x))
        # main path: pre-activation order (norm -> act -> conv), twice
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.conv2(x)
        return x + x_in
|
ldcast/models/diffusion/diffusion.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py
|
| 3 |
+
Pared down to simplify code.
|
| 4 |
+
|
| 5 |
+
The original file acknowledges:
|
| 6 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 7 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| 8 |
+
https://github.com/CompVis/taming-transformers
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pytorch_lightning as pl
|
| 15 |
+
from contextlib import contextmanager
|
| 16 |
+
from functools import partial
|
| 17 |
+
from torchmetrics import MeanSquaredError
|
| 18 |
+
|
| 19 |
+
from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding
|
| 20 |
+
from .ema import LitEma
|
| 21 |
+
from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LatentDiffusion(pl.LightningModule):
    """Latent diffusion model: a DDPM operating in the latent space of a
    frozen autoencoder, optionally conditioned through a context encoder.

    `model` predicts either the added noise ("eps" parameterization) or
    the clean latent ("x0" parameterization) from a noisy latent and a
    diffusion timestep.
    """

    def __init__(self,
        model,
        autoencoder,
        context_encoder=None,
        timesteps=1000,
        beta_schedule="linear",
        loss_type="l2",
        use_ema=True,
        lr=1e-4,
        lr_warmup=0,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        parameterization="eps",  # all assuming fixed variance schedules
    ):
        """
        Args:
            model: denoiser network called as model(x_noisy, t, context=...).
            autoencoder: pretrained latent autoencoder; frozen here and used
                only through its `encode` method.
            context_encoder: optional module encoding the conditioning input;
                if None the model is unconditional.
            timesteps: number of steps in the forward diffusion process.
            beta_schedule: name of the beta schedule (e.g. "linear", "cosine").
            loss_type: "l1" or "l2".
            use_ema: keep an exponential moving average of model weights.
            lr: base learning rate for AdamW.
            lr_warmup: number of steps of linear LR warmup (0 disables it).
            linear_start, linear_end, cosine_s: beta schedule parameters.
            parameterization: "eps" (predict noise) or "x0" (predict latent).
        """
        super().__init__()
        self.model = model
        # the autoencoder is not trained here; freeze its parameters
        self.autoencoder = autoencoder.requires_grad_(False)
        self.conditional = (context_encoder is not None)
        self.context_encoder = context_encoder
        self.lr = lr
        self.lr_warmup = lr_warmup

        self.val_loss = MeanSquaredError()

        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)

        self.register_schedule(
            beta_schedule=beta_schedule, timesteps=timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )

        self.loss_type = loss_type

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Precompute the noise schedule and register it as buffers so it
        moves with the module between devices and is checkpointed."""
        betas = make_beta_schedule(
            beta_schedule, timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha-bar at the previous step; alpha-bar_{-1} is defined as 1
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the model weights for their EMA; restore on exit.

        A no-op when `use_ema` is False.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def apply_model(self, x_noisy, t, cond=None, return_ids=False):
        """Run the denoiser on a noisy latent, using EMA weights if enabled.

        `return_ids` is accepted for interface compatibility but unused.
        """
        if self.conditional:
            cond = self.context_encoder(cond)
        with self.ema_scope():
            return self.model(x_noisy, t, context=cond)

    def q_sample(self, x_start, t, noise=None):
        """Sample from q(x_t | x_0): diffuse `x_start` to timestep `t`."""
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_loss(self, pred, target, mean=True):
        """Compute the configured loss ("l1" or "l2") between pred and target.

        Raises:
            NotImplementedError: for an unknown `loss_type`.
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            # fixed: was a plain string literal, so the placeholder was
            # never substituted into the error message
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None, context=None):
        """Diffuse `x_start` to step `t` and score the model's prediction
        against the target implied by the parameterization."""
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t, context=context)

        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        return self.get_loss(model_out, target, mean=False).mean()

    def forward(self, x, *args, **kwargs):
        """Compute the diffusion loss at a random timestep per batch element."""
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        """Encode the target to latent space and compute the diffusion loss."""
        (x, y) = batch
        y = self.autoencoder.encode(y)[0]
        context = self.context_encoder(x) if self.conditional else None
        return self(y, context=context)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        with self.ema_scope():
            loss_ema = self.shared_step(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log("val_loss", loss, **log_params)
        # fixed: previously logged `loss` under this key, so val_loss_ema
        # duplicated val_loss and the LR scheduler (which monitors
        # val_loss_ema) never saw the EMA loss
        self.log("val_loss_ema", loss_ema, **log_params)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        # update the EMA shadow weights after every optimizer step
        if self.use_ema:
            self.model_ema(self.model)

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau monitoring the EMA validation loss."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
            betas=(0.5, 0.9), weight_decay=1e-3)
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss_ema",
                "frequency": 1,
            },
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        #optimizer_closure,
        **kwargs
    ):
        """Apply linear LR warmup for the first `lr_warmup` global steps,
        then delegate to the default Lightning optimizer step.

        NOTE(review): `optimizer_closure` is commented out in the signature;
        this matches a specific pytorch_lightning version's hook signature —
        confirm against the installed PL version.
        """
        if self.trainer.global_step < self.lr_warmup:
            lr_scale = (self.trainer.global_step+1) / self.lr_warmup
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.lr

        super().optimizer_step(
            epoch, batch_idx, optimizer,
            optimizer_idx,
            #optimizer_closure,
            **kwargs
        )
|
| 222 |
+
|
ldcast/models/diffusion/ema.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    A shadow copy of every trainable parameter is registered as a buffer
    (so it is saved/loaded with the state dict and moved across devices).
    Calling the module updates the shadow copies towards the model's
    current parameters.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        # decay: EMA smoothing factor in [0, 1]
        # use_num_upates: if True, warm up the effective decay over the
        # first updates (num_updates == -1 disables the warmup)
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # maps parameter names to their shadow buffer names
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
                             else torch.tensor(-1,dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                #remove as '.'-character is not allowed in buffers
                s_name = name.replace('.','')
                self.m_name2s_name.update({name:s_name})
                self.register_buffer(s_name,p.clone().detach().data)

        # filled by store(); consumed by restore()
        self.collected_params = []

    def forward(self,model):
        """Update each shadow parameter towards the model's current value."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # warmup: effective decay grows towards self.decay as updates
            # accumulate, so early averages are not dominated by init values
            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # in-place EMA step: shadow -= (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) values into the model's parameters."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
|
ldcast/models/diffusion/plms.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
"""SAMPLING ONLY."""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PLMSSampler:
    """Pseudo Linear Multistep (PLMS) sampler for a trained diffusion model.

    Draws samples by integrating the reverse diffusion ODE with
    Adams-Bashforth-style multistep updates over a reduced (DDIM-style)
    timestep subsequence. Sampling only; no training logic.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        # model: a diffusion model exposing num_timesteps, betas,
        # alphas_cumprod, apply_model(), q_sample(), parameterization, ...
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Plain attribute assignment; kept as a method to mirror the
        # nn.Module API (the original optionally moved tensors to CUDA here).
        #if type(attr) == torch.Tensor:
        #    if attr.device != torch.device("cuda"):
        #        attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the reduced timestep sequence and per-step schedule
        quantities (alphas, sigmas, etc.) used during sampling.

        Raises:
            ValueError: if ddim_eta != 0 (PLMS requires deterministic steps).
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # with ddim_eta forced to 0 above, these sigmas are all zero
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               progbar=True,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Draw `batch_size` samples of the given per-sample `shape` using
        `S` PLMS steps. Returns (samples, intermediates dict).

        `mask`/`x0` enable inpainting; `unconditional_guidance_scale` > 1
        with `unconditional_conditioning` enables classifier-free guidance.
        """
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
        """

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        size = (batch_size,) + shape
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    progbar=progbar
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, progbar=True):
        """Run the PLMS reverse-diffusion loop from pure noise (or `x_T`)
        down to t=0, keeping the last few noise estimates for the
        multistep update. Returns (final image, intermediates dict).
        """
        device = self.model.betas.device
        b = shape[0]
        # start from provided x_T or standard Gaussian noise
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # truncate the precomputed ddim sequence to the requested length
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # iterate timesteps from high (noisy) to low (clean)
        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = time_range
        if progbar:
            iterator = tqdm(iterator, desc='PLMS Sampler', total=total_steps)
        # rolling window of past noise estimates for the multistep formulas
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: keep known regions pinned to the diffused x0
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            # keep at most 3 past estimates (4th-order multistep)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        """Single PLMS update: estimate the noise at (x, t), combine it with
        past estimates via the appropriate-order multistep formula, and
        step to the previous timestep. Returns (x_prev, pred_x0, e_t).
        """
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # noise estimate, optionally with classifier-free guidance
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                # run conditional and unconditional branches in one batch
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            param_shape = (b,) + (1,)*(x.ndim-1)
            a_t = torch.full(param_shape, alphas[index], device=device)
            a_prev = torch.full(param_shape, alphas_prev[index], device=device)
            sigma_t = torch.full(param_shape, sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(param_shape, sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
|
ldcast/models/diffusion/utils.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# adopted from
|
| 2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
| 3 |
+
# and
|
| 4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 5 |
+
# and
|
| 6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
| 7 |
+
#
|
| 8 |
+
# thanks!
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import numpy as np
|
| 15 |
+
from einops import repeat
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a diffusion beta (noise-variance) schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion timesteps (length of the result).
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset used by the cosine schedule.
    :return: a float64 numpy array of `n_timestep` betas.
    :raises ValueError: if `schedule` is not recognized.
    """
    if schedule == "linear":
        # interpolate in sqrt-space so the betas grow quadratically
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Fix: clamp with torch so `betas` stays a torch.Tensor. The original
        # used np.clip on a torch tensor, which (depending on numpy/torch
        # versions) can return a numpy array and break the .numpy() call below.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Select a subset of the DDPM timesteps to use for DDIM sampling.

    :param ddim_discr_method: 'uniform' (evenly strided) or 'quad'
        (quadratically spaced, denser near t=0).
    :param num_ddim_timesteps: number of sampling steps to select.
    :param num_ddpm_timesteps: total number of training timesteps.
    :param verbose: print the selected steps when True.
    :return: integer numpy array of selected timesteps, shifted by +1.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        spaced = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)
        ddim_timesteps = (spaced ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the per-step (sigma, alpha, alpha_prev) schedule for DDIM sampling.

    :param alphacums: cumulative alpha products indexed by DDPM timestep.
    :param ddim_timesteps: the DDPM timesteps selected for DDIM sampling.
    :param eta: stochasticity parameter (0 = deterministic DDIM).
    :param verbose: print the selected values when True.
    :return: tuple (sigmas, alphas, alphas_prev).
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
        produces the cumulative product of (1-beta) up to that part of the
        diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
        prevent singularities.
    :return: numpy array of `num_diffusion_timesteps` betas.
    """
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    grid = [step / num_diffusion_timesteps for step in range(num_diffusion_timesteps + 1)]
    return np.array([
        min(1 - alpha_bar(grid[i + 1]) / alpha_bar(grid[i]), max_beta)
        for i in range(num_diffusion_timesteps)
    ])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def extract_into_tensor(a, t, x_shape):
    """
    Gather the per-timestep entries of `a` selected by `t` and reshape the
    result to broadcast against a tensor of shape `x_shape` (batch dim kept,
    all remaining dims set to 1).
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # checkpointing disabled: plain call
        return func(*inputs)
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class CheckpointFunction(torch.autograd.Function):
    """
    Custom autograd function implementing gradient checkpointing: the forward
    pass runs without recording a graph, and the backward pass re-runs the
    function with gradients enabled to recompute the intermediates.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # Split the flat argument list back into the explicit inputs
        # (first `length` entries) and the extra parameters the function
        # closes over; both are needed by autograd.grad in backward().
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        # no_grad: deliberately discard the graph here; it is rebuilt
        # from scratch during backward().
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach and re-attach the inputs so the replayed forward pass
        # builds a fresh graph rooted at them.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            # params not used by run_function simply get None gradients
            allow_unused=True,
        )
        # Drop references so the recomputed tensors can be freed promptly.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones match the run_function and length arguments.
        return (None, None) + input_grads
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep value instead of
        computing sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    # geometric frequency ladder from 1 down to 1/max_period
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    phases = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        # odd output width: pad with a single zero column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def zero_module(module):
    """Set every parameter of `module` to zero in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def scale_module(module, scale):
    """Multiply every parameter of `module` by `scale` in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels (currently unused).
    :return: an nn.Module for normalization. Normalization is intentionally
        disabled here; the GroupNorm32(32, channels) variant is kept only as
        a reference.
    """
    return nn.Identity()  # GroupNorm32(32, channels)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def noise_like(shape, device, repeat=False):
    """
    Sample standard-normal noise of the given shape on `device`.

    When `repeat` is True, one noise sample is drawn for a single batch
    element and tiled across the batch dimension (all rows identical).
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: 1, 2, or 3.
    :raises ValueError: for any other value of `dims`.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    conv_cls = conv_classes.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def linear(*args, **kwargs):
    """Create a linear (fully-connected) module; thin alias for nn.Linear."""
    return nn.Linear(*args, **kwargs)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: 1, 2, or 3.
    :raises ValueError: for any other value of `dims`.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    pool_cls = pool_classes.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
|
ldcast/models/distributions.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def kl_from_standard_normal(mean, log_var):
    """
    Elementwise-averaged KL divergence between N(mean, exp(log_var)) and the
    standard normal N(0, 1).
    """
    var = log_var.exp()
    kl_elementwise = 0.5 * (var + mean.square() - 1.0 - log_var)
    return kl_elementwise.mean()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def sample_from_standard_normal(mean, log_var, num=None):
    """
    Draw reparameterized Gaussian samples: mean + std * eps, eps ~ N(0, 1).

    :param mean: mean tensor.
    :param log_var: log-variance tensor, same shape as `mean`.
    :param num: if given, draw `num` samples per batch element; a sample
        dimension is inserted at axis 1.
    """
    std = (0.5 * log_var).exp()
    shape = mean.shape
    if num is not None:
        # expand channel 1 to create several samples
        shape = shape[:1] + (num,) + shape[1:]
        mean = mean[:, None, ...]
        std = std[:, None, ...]
    noise = torch.randn(shape, device=mean.device)
    return mean + std * noise
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def ensemble_nll_normal(ensemble, sample, epsilon=1e-5):
    """
    Score `sample` against a Gaussian fitted to an ensemble of predictions.

    The mean and (unbiased) variance are taken over the ensemble axis
    (dim 1). The returned value is proportional to the Gaussian negative
    log-likelihood, averaged over all elements; the conventional factor of
    0.5 is omitted here.

    :param ensemble: tensor with ensemble members along dim 1.
    :param sample: observation tensor to evaluate.
    :param epsilon: variance floor avoiding log(0)/division-by-zero when the
        ensemble members agree exactly.
    :return: scalar tensor.
    """
    mean = ensemble.mean(dim=1)
    var = ensemble.var(dim=1, unbiased=True) + epsilon
    logvar = var.log()

    # NOTE(review): `sample[:,None,...]` inserts an axis that `mean` (the
    # ensemble-reduced tensor) does not have, so this line relies on a
    # specific broadcast between the two — confirm the intended shapes of
    # `sample` and `ensemble` against the callers.
    diff = sample[:,None,...] - mean
    logtwopi = np.log(2*np.pi)
    nll = (logtwopi + logvar + diff.square() / var).mean()
    return nll
|
ldcast/models/genforecast/analysis.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from ..nowcast.nowcast import AFNONowcastNetBase
|
| 6 |
+
from ..blocks.resnet import ResBlock3D
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AFNONowcastNetCascade(AFNONowcastNetBase):
    """
    Nowcasting network that exposes a multi-resolution cascade of features.

    The base network's output is repeatedly halved in the two spatial
    dimensions (temporal axis untouched) and passed through a ResBlock3D
    that doubles the channel count at each level. `forward` returns a dict
    mapping each level's spatial shape (H, W) to its feature tensor.
    """

    def __init__(self, *args, cascade_depth=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.cascade_depth = cascade_depth
        self.resnet = nn.ModuleList()
        channels = self.embed_dim_out
        self.cascade_dims = [channels]
        # one res-block per downsampling level, doubling channels each time
        for _ in range(cascade_depth - 1):
            doubled = 2 * channels
            self.cascade_dims.append(doubled)
            self.resnet.append(
                ResBlock3D(channels, doubled, kernel_size=(1, 3, 3), norm=None)
            )
            channels = doubled

    def forward(self, x):
        features = super().forward(x)
        cascade = {tuple(features.shape[-2:]): features}
        for block in self.resnet:
            # halve the spatial resolution, keep the temporal axis
            features = block(F.avg_pool3d(features, (1, 2, 2)))
            cascade[tuple(features.shape[-2:])] = features
        return cascade
|
ldcast/models/genforecast/training.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from ..diffusion import diffusion
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def setup_genforecast_training(
    model,
    autoencoder,
    context_encoder,
    model_dir,
    lr=1e-4
):
    """
    Assemble the latent diffusion module and its PyTorch Lightning trainer.

    :param model: denoiser network for the diffusion process.
    :param autoencoder: pretrained latent-space autoencoder.
    :param context_encoder: encoder producing the conditioning context.
    :param model_dir: directory where checkpoints are written.
    :param lr: learning rate for the diffusion module.
    :return: tuple (ldm, trainer).
    """
    ldm = diffusion.LatentDiffusion(
        model, autoencoder, context_encoder=context_encoder, lr=lr
    )

    # train on all available GPUs, otherwise fall back to a single CPU process
    gpu_count = torch.cuda.device_count()
    if gpu_count > 0:
        accelerator = "gpu"
        devices = gpu_count
    else:
        accelerator = "cpu"
        devices = 1

    callbacks = [
        pl.callbacks.EarlyStopping(
            "val_loss_ema", patience=6, verbose=True, check_finite=False
        ),
        pl.callbacks.ModelCheckpoint(
            dirpath=model_dir,
            filename="{epoch}-{val_loss_ema:.4f}",
            monitor="val_loss_ema",
            every_n_epochs=1,
            save_top_k=3
        ),
    ]

    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=300,
        callbacks=callbacks,
    )

    return (ldm, trainer)
|
ldcast/models/genforecast/unet.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from functools import partial
|
| 3 |
+
import math
|
| 4 |
+
from typing import Iterable
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch as th
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from ..diffusion.utils import (
|
| 12 |
+
checkpoint,
|
| 13 |
+
conv_nd,
|
| 14 |
+
linear,
|
| 15 |
+
avg_pool_nd,
|
| 16 |
+
zero_module,
|
| 17 |
+
normalization,
|
| 18 |
+
timestep_embedding,
|
| 19 |
+
)
|
| 20 |
+
from ..blocks.afno import AFNOCrossAttentionBlock3d
|
| 21 |
+
SpatialTransformer = type(None)
|
| 22 |
+
#from ldm.modules.attention import SpatialTransformer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Used as a marker interface by TimestepEmbedSequential to decide which
    children receive the embedding.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that routes extra conditioning to the children
    that support it: TimestepBlock children also receive the timestep
    embedding, and AFNOCrossAttentionBlock3d children receive the context
    feature map matching the current spatial resolution.
    """

    def forward(self, x, emb, context=None):
        for module in self:
            if isinstance(module, TimestepBlock):
                x = module(x, emb)
            elif isinstance(module, AFNOCrossAttentionBlock3d):
                # context is keyed by spatial shape (H, W)
                resolution = tuple(x.shape[-2:])
                x = module(x, context[resolution])
            else:
                x = module(x)
        return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions (the first spatial/
        temporal axis is left unchanged).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the leading (temporal) axis, double the two spatial axes
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their leading axis; only the two spatial dims halve
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # average pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: use a FiLM-like (scale, shift) conditioning
        from the embedding instead of an additive one.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # main path: norm -> activation -> conv (split in _forward when updown)
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # resolution change is applied to both the residual and skip paths
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # embedding projection; doubled width when producing (scale, shift)
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # zero-initialized final conv so the block starts as an identity map
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # skip connection matches channel counts when they differ
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            # apply the resolution change between the activation and the conv
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # broadcast the embedding over the spatial dimensions
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style: h = norm(h) * (1 + scale) + shift
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class UNetModel(nn.Module):
|
| 231 |
+
"""
|
| 232 |
+
The full UNet model with attention and timestep embedding.
|
| 233 |
+
:param in_channels: channels in the input Tensor.
|
| 234 |
+
:param model_channels: base channel count for the model.
|
| 235 |
+
:param out_channels: channels in the output Tensor.
|
| 236 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
| 237 |
+
:param attention_resolutions: a collection of downsample rates at which
|
| 238 |
+
attention will take place. May be a set, list, or tuple.
|
| 239 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
| 240 |
+
will be used.
|
| 241 |
+
:param dropout: the dropout probability.
|
| 242 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
| 243 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
| 244 |
+
downsampling.
|
| 245 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
| 246 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
| 247 |
+
:param num_heads: the number of attention heads in each attention layer.
|
| 248 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
| 249 |
+
a fixed channel width per attention head.
|
| 250 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
| 251 |
+
of heads for upsampling. Deprecated.
|
| 252 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
| 253 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
| 254 |
+
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
model_channels,
|
| 260 |
+
in_channels=1,
|
| 261 |
+
out_channels=1,
|
| 262 |
+
num_res_blocks=2,
|
| 263 |
+
attention_resolutions=(1,2,4),
|
| 264 |
+
context_ch=128,
|
| 265 |
+
dropout=0,
|
| 266 |
+
channel_mult=(1, 2, 4, 4),
|
| 267 |
+
conv_resample=True,
|
| 268 |
+
dims=3,
|
| 269 |
+
use_checkpoint=False,
|
| 270 |
+
use_fp16=False,
|
| 271 |
+
num_heads=-1,
|
| 272 |
+
num_head_channels=-1,
|
| 273 |
+
num_heads_upsample=-1,
|
| 274 |
+
use_scale_shift_norm=False,
|
| 275 |
+
resblock_updown=False,
|
| 276 |
+
legacy=True,
|
| 277 |
+
num_timesteps=1
|
| 278 |
+
):
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
if num_heads_upsample == -1:
|
| 282 |
+
num_heads_upsample = num_heads
|
| 283 |
+
|
| 284 |
+
if num_heads == -1:
|
| 285 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
| 286 |
+
|
| 287 |
+
if num_head_channels == -1:
|
| 288 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
| 289 |
+
|
| 290 |
+
self.in_channels = in_channels
|
| 291 |
+
self.model_channels = model_channels
|
| 292 |
+
self.out_channels = out_channels
|
| 293 |
+
self.num_res_blocks = num_res_blocks
|
| 294 |
+
self.attention_resolutions = attention_resolutions
|
| 295 |
+
self.dropout = dropout
|
| 296 |
+
self.channel_mult = channel_mult
|
| 297 |
+
self.conv_resample = conv_resample
|
| 298 |
+
self.use_checkpoint = use_checkpoint
|
| 299 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 300 |
+
self.num_heads = num_heads
|
| 301 |
+
self.num_head_channels = num_head_channels
|
| 302 |
+
self.num_heads_upsample = num_heads_upsample
|
| 303 |
+
timesteps = th.arange(1, num_timesteps+1)
|
| 304 |
+
|
| 305 |
+
time_embed_dim = model_channels * 4
|
| 306 |
+
self.time_embed = nn.Sequential(
|
| 307 |
+
linear(model_channels, time_embed_dim),
|
| 308 |
+
nn.SiLU(),
|
| 309 |
+
linear(time_embed_dim, time_embed_dim),
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
self.input_blocks = nn.ModuleList(
|
| 313 |
+
[
|
| 314 |
+
TimestepEmbedSequential(
|
| 315 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
| 316 |
+
)
|
| 317 |
+
]
|
| 318 |
+
)
|
| 319 |
+
self._feature_size = model_channels
|
| 320 |
+
input_block_chans = [model_channels]
|
| 321 |
+
ch = model_channels
|
| 322 |
+
ds = 1
|
| 323 |
+
for level, mult in enumerate(channel_mult):
|
| 324 |
+
for _ in range(num_res_blocks):
|
| 325 |
+
layers = [
|
| 326 |
+
ResBlock(
|
| 327 |
+
ch,
|
| 328 |
+
time_embed_dim,
|
| 329 |
+
dropout,
|
| 330 |
+
out_channels=mult * model_channels,
|
| 331 |
+
dims=dims,
|
| 332 |
+
use_checkpoint=use_checkpoint,
|
| 333 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 334 |
+
)
|
| 335 |
+
]
|
| 336 |
+
ch = mult * model_channels
|
| 337 |
+
if ds in attention_resolutions:
|
| 338 |
+
if num_head_channels == -1:
|
| 339 |
+
dim_head = ch // num_heads
|
| 340 |
+
else:
|
| 341 |
+
num_heads = ch // num_head_channels
|
| 342 |
+
dim_head = num_head_channels
|
| 343 |
+
if legacy:
|
| 344 |
+
dim_head = num_head_channels
|
| 345 |
+
layers.append(
|
| 346 |
+
AFNOCrossAttentionBlock3d(
|
| 347 |
+
ch, context_dim=context_ch[level], num_blocks=num_heads,
|
| 348 |
+
data_format="channels_first", timesteps=timesteps
|
| 349 |
+
)
|
| 350 |
+
)
|
| 351 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 352 |
+
self._feature_size += ch
|
| 353 |
+
input_block_chans.append(ch)
|
| 354 |
+
if level != len(channel_mult) - 1:
|
| 355 |
+
out_ch = ch
|
| 356 |
+
self.input_blocks.append(
|
| 357 |
+
TimestepEmbedSequential(
|
| 358 |
+
ResBlock(
|
| 359 |
+
ch,
|
| 360 |
+
time_embed_dim,
|
| 361 |
+
dropout,
|
| 362 |
+
out_channels=out_ch,
|
| 363 |
+
dims=dims,
|
| 364 |
+
use_checkpoint=use_checkpoint,
|
| 365 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 366 |
+
down=True,
|
| 367 |
+
)
|
| 368 |
+
if resblock_updown
|
| 369 |
+
else Downsample(
|
| 370 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
| 371 |
+
)
|
| 372 |
+
)
|
| 373 |
+
)
|
| 374 |
+
ch = out_ch
|
| 375 |
+
input_block_chans.append(ch)
|
| 376 |
+
ds *= 2
|
| 377 |
+
self._feature_size += ch
|
| 378 |
+
|
| 379 |
+
if num_head_channels == -1:
|
| 380 |
+
dim_head = ch // num_heads
|
| 381 |
+
else:
|
| 382 |
+
num_heads = ch // num_head_channels
|
| 383 |
+
dim_head = num_head_channels
|
| 384 |
+
if legacy:
|
| 385 |
+
dim_head = num_head_channels
|
| 386 |
+
self.middle_block = TimestepEmbedSequential(
|
| 387 |
+
ResBlock(
|
| 388 |
+
ch,
|
| 389 |
+
time_embed_dim,
|
| 390 |
+
dropout,
|
| 391 |
+
dims=dims,
|
| 392 |
+
use_checkpoint=use_checkpoint,
|
| 393 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 394 |
+
),
|
| 395 |
+
AFNOCrossAttentionBlock3d(
|
| 396 |
+
ch, context_dim=context_ch[-1], num_blocks=num_heads,
|
| 397 |
+
data_format="channels_first", timesteps=timesteps
|
| 398 |
+
),
|
| 399 |
+
ResBlock(
|
| 400 |
+
ch,
|
| 401 |
+
time_embed_dim,
|
| 402 |
+
dropout,
|
| 403 |
+
dims=dims,
|
| 404 |
+
use_checkpoint=use_checkpoint,
|
| 405 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 406 |
+
),
|
| 407 |
+
)
|
| 408 |
+
self._feature_size += ch
|
| 409 |
+
|
| 410 |
+
self.output_blocks = nn.ModuleList([])
|
| 411 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
| 412 |
+
for i in range(num_res_blocks + 1):
|
| 413 |
+
ich = input_block_chans.pop()
|
| 414 |
+
layers = [
|
| 415 |
+
ResBlock(
|
| 416 |
+
ch + ich,
|
| 417 |
+
time_embed_dim,
|
| 418 |
+
dropout,
|
| 419 |
+
out_channels=model_channels * mult,
|
| 420 |
+
dims=dims,
|
| 421 |
+
use_checkpoint=use_checkpoint,
|
| 422 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 423 |
+
)
|
| 424 |
+
]
|
| 425 |
+
ch = model_channels * mult
|
| 426 |
+
if ds in attention_resolutions:
|
| 427 |
+
if num_head_channels == -1:
|
| 428 |
+
dim_head = ch // num_heads
|
| 429 |
+
else:
|
| 430 |
+
num_heads = ch // num_head_channels
|
| 431 |
+
dim_head = num_head_channels
|
| 432 |
+
if legacy:
|
| 433 |
+
#num_heads = 1
|
| 434 |
+
dim_head = num_head_channels
|
| 435 |
+
layers.append(
|
| 436 |
+
AFNOCrossAttentionBlock3d(
|
| 437 |
+
ch, context_dim=context_ch[level], num_blocks=num_heads,
|
| 438 |
+
data_format="channels_first", timesteps=timesteps
|
| 439 |
+
)
|
| 440 |
+
)
|
| 441 |
+
if level and i == num_res_blocks:
|
| 442 |
+
out_ch = ch
|
| 443 |
+
layers.append(
|
| 444 |
+
ResBlock(
|
| 445 |
+
ch,
|
| 446 |
+
time_embed_dim,
|
| 447 |
+
dropout,
|
| 448 |
+
out_channels=out_ch,
|
| 449 |
+
dims=dims,
|
| 450 |
+
use_checkpoint=use_checkpoint,
|
| 451 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 452 |
+
up=True,
|
| 453 |
+
)
|
| 454 |
+
if resblock_updown
|
| 455 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
| 456 |
+
)
|
| 457 |
+
ds //= 2
|
| 458 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
| 459 |
+
self._feature_size += ch
|
| 460 |
+
|
| 461 |
+
self.out = nn.Sequential(
|
| 462 |
+
normalization(ch),
|
| 463 |
+
nn.SiLU(),
|
| 464 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
def forward(self, x, timesteps=None, context=None):
    """Run the U-Net on a batch.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: conditioning plugged in via cross-attention.
    :return: an [N x C x ...] Tensor of outputs.
    """
    emb = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    )

    skips = []
    h = x.type(self.dtype)
    # Encoder path: record every activation for the skip connections.
    for block in self.input_blocks:
        h = block(h, emb, context)
        skips.append(h)

    h = self.middle_block(h, emb, context)

    # Decoder path: consume the recorded skip activations in reverse order,
    # concatenating each with the running feature map along channels.
    for block in self.output_blocks:
        h = block(th.cat([h, skips.pop()], dim=1), emb, context)

    return self.out(h.type(x.dtype))
|
| 489 |
+
|
ldcast/models/nowcast/nowcast.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
|
| 8 |
+
from ..blocks.afno import AFNOBlock3d
|
| 9 |
+
from ..blocks.attention import positional_encoding, TemporalTransformer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Nowcaster(pl.LightningModule):
    """Lightning module wrapping a nowcasting network, trained with an MSE loss."""

    def __init__(self, nowcast_net):
        super().__init__()
        self.nowcast_net = nowcast_net

    def forward(self, x):
        return self.nowcast_net(x)

    def _loss(self, batch):
        # Mean squared error between the prediction and the target.
        (x, y) = batch
        prediction = self.forward(x)
        return torch.mean(torch.square(y - prediction))

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def val_test_step(self, batch, batch_idx, split="val"):
        # Shared evaluation logic for the validation and test splits.
        loss = self._loss(batch)
        self.log(
            f"{split}_loss", loss,
            on_step=False, on_epoch=True, prog_bar=True
        )

    def validation_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="val")

    def test_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="test")

    def configure_optimizers(self):
        """Set up AdamW with a plateau-based learning-rate schedule on val_loss."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=1e-3,
            betas=(0.5, 0.9), weight_decay=1e-3
        )
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss",
                "frequency": 1,
            },
        }
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class AFNONowcastNetBasic(nn.Sequential):
    # Simple nowcasting pipeline: patch-embed -> AFNO transformer stack ->
    # patch-expand, run as a plain nn.Sequential.
    #
    # NOTE(review): PatchEmbed3d, AFNOBlock and PatchExpand3d are not defined
    # or imported anywhere in this module (only AFNOBlock3d is imported), so
    # instantiating this class as written would raise NameError -- confirm
    # the intended imports / whether AFNOBlock should be AFNOBlock3d.
    def __init__(
        self,
        embed_dim=256,      # token embedding width
        depth=12,           # number of AFNO blocks
        patch_size=(4,4,4)  # (time, height, width) patch size
    ):
        patch_embed = PatchEmbed3d(
            embed_dim=embed_dim, patch_size=patch_size
        )
        blocks = nn.Sequential(
            *(AFNOBlock(embed_dim) for _ in range(depth))
        )
        patch_expand = PatchExpand3d(
            embed_dim=embed_dim, patch_size=patch_size
        )
        super().__init__(*[patch_embed, blocks, patch_expand])
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class FusionBlock3d(nn.Module):
    """Merge feature maps from multiple input sources into one map.

    Each source is first upscaled spatially (by chained x2 transposed
    convolutions) to a common resolution, then the sources are combined
    either with a learned AFNO-based fusion or with a plain sum.

    Parameters
    ----------
    dim : int or sequence of int
        Channel count of each input source; a single int is broadcast
        to all sources.
    size_ratios : sequence of int
        Spatial upsampling factor per source; expected to be powers of 2.
    dim_out : int, optional
        Output channel count; defaults to the first source's channels.
    afno_fusion : bool
        If True, concatenate sources along channels and fuse with an AFNO
        block; otherwise simply sum them.
    """
    def __init__(self, dim, size_ratios, dim_out=None, afno_fusion=False):
        super().__init__()

        N_sources = len(size_ratios)
        if not isinstance(dim, collections.abc.Sequence):
            dim = (dim,) * N_sources
        if dim_out is None:
            dim_out = dim[0]

        # One upscaling module per source: identity at ratio 1, otherwise a
        # chain of x2 transposed convolutions acting on the spatial dims only.
        self.scale = nn.ModuleList()
        for (i, size_ratio) in enumerate(size_ratios):
            if size_ratio == 1:
                scale = nn.Identity()
            else:
                scale = []
                while size_ratio > 1:
                    # Project to dim_out channels on the final x2 step.
                    scale.append(nn.ConvTranspose3d(
                        dim[i], dim_out if size_ratio == 2 else dim[i],
                        kernel_size=(1,3,3), stride=(1,2,2),
                        padding=(0,1,1), output_padding=(0,1,1)
                    ))
                    size_ratio //= 2
                scale = nn.Sequential(*scale)
            self.scale.append(scale)

        self.afno_fusion = afno_fusion

        if self.afno_fusion:
            if N_sources > 1:
                # Fix: the AFNO block operates on the *concatenated* channels,
                # so its width is sum(dim); the original `dim*N_sources`
                # replicated the tuple instead of giving a channel count.
                self.fusion = nn.Sequential(
                    nn.Linear(sum(dim), sum(dim)),
                    AFNOBlock3d(sum(dim), mlp_ratio=2),
                    nn.Linear(sum(dim), dim_out)
                )
            else:
                self.fusion = nn.Identity()

    def resize_proj(self, x, i):
        """Upscale source i: channels-last -> channels-first for the convs, then back."""
        x = x.permute(0,4,1,2,3)
        x = self.scale[i](x)
        x = x.permute(0,2,3,4,1)
        return x

    def forward(self, x):
        """Fuse a list of channels-last feature maps into a single map."""
        x = [self.resize_proj(xx, i) for (i, xx) in enumerate(x)]
        if self.afno_fusion:
            x = torch.concat(x, axis=-1)
            x = self.fusion(x)
        else:
            x = sum(x)
        return x
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class AFNONowcastNetBase(nn.Module):
    """Core AFNO nowcasting network operating in autoencoder latent space.

    Each input source is encoded with its (usually frozen) autoencoder,
    projected to an embedding width, processed by AFNO analysis blocks,
    optionally remapped in time by a temporal transformer, fused across
    sources, and finally passed through AFNO forecast blocks.
    """
    def __init__(
        self,
        autoencoder,
        embed_dim=128,
        embed_dim_out=None,
        analysis_depth=4,
        forecast_depth=4,
        input_patches=(1,),
        input_size_ratios=(1,),
        output_patches=2,
        train_autoenc=False,
        afno_fusion=False
    ):
        super().__init__()

        # autoencoder / input_patches / embed_dim / analysis_depth may be
        # given per source or as scalars broadcast to all sources.
        self.train_autoenc = train_autoenc
        if not isinstance(autoencoder, collections.abc.Sequence):
            autoencoder = [autoencoder]
        if not isinstance(input_patches, collections.abc.Sequence):
            input_patches = [input_patches]
        num_inputs = len(autoencoder)
        if not isinstance(embed_dim, collections.abc.Sequence):
            embed_dim = [embed_dim] * num_inputs
        if embed_dim_out is None:
            embed_dim_out = embed_dim[0]
        if not isinstance(analysis_depth, collections.abc.Sequence):
            analysis_depth = [analysis_depth] * num_inputs
        self.embed_dim = embed_dim
        self.embed_dim_out = embed_dim_out
        self.output_patches = output_patches

        # encoding + analysis for each input
        self.autoencoder = nn.ModuleList()
        self.proj = nn.ModuleList()
        self.analysis = nn.ModuleList()
        for i in range(num_inputs):
            # Freeze (or unfreeze) the autoencoder weights as requested.
            ae = autoencoder[i].requires_grad_(train_autoenc)
            self.autoencoder.append(ae)

            # 1x1x1 conv projecting latent channels to the embedding width.
            proj = nn.Conv3d(ae.hidden_width, embed_dim[i], kernel_size=1)
            self.proj.append(proj)

            analysis = nn.Sequential(
                *(AFNOBlock3d(embed_dim[i]) for _ in range(analysis_depth[i]))
            )
            self.analysis.append(analysis)

        # temporal transformer
        # Only needed when some input's time-patch count differs from the
        # requested number of output patches.
        self.use_temporal_transformer = \
            any((ipp != output_patches) for ipp in input_patches)
        if self.use_temporal_transformer:
            self.temporal_transformer = nn.ModuleList(
                TemporalTransformer(embed_dim[i]) for i in range(num_inputs)
            )

        # data fusion
        self.fusion = FusionBlock3d(embed_dim, input_size_ratios,
            afno_fusion=afno_fusion, dim_out=embed_dim_out)

        # forecast
        self.forecast = nn.Sequential(
            *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth))
        )

    def add_pos_enc(self, x, t):
        """Add a temporal positional encoding derived from t to x.

        NOTE(review): exact shape contract depends on positional_encoding
        (defined elsewhere); x appears to be channels-last with time on
        axis 1 and t to hold per-frame relative time coordinates -- confirm.
        """
        if t.shape[1] != x.shape[1]:
            # this can happen if x has been compressed
            # by the autoencoder in the time dimension
            ds_factor = t.shape[1] // x.shape[1]
            t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:]

        pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3))
        return x + pos_enc

    def forward(self, x):
        # x is a sequence of (data, relative_time) pairs, one per source.
        (x, t_relative) = list(zip(*x))

        # encoding + analysis for each input
        def process_input(i):
            z = self.autoencoder[i].encode(x[i])[0]
            z = self.proj[i](z)
            # channels-first -> channels-last for the AFNO blocks
            z = z.permute(0,2,3,4,1)
            z = self.analysis[i](z)
            if self.use_temporal_transformer:
                # add positional encoding
                z = self.add_pos_enc(z, t_relative[i])

                # transform to output shape and coordinates
                expand_shape = z.shape[:1] + (-1,) + z.shape[2:]
                pos_enc_output = positional_encoding(
                    torch.arange(1,self.output_patches+1, device=z.device),
                    self.embed_dim[i], add_dims=(0,2,3)
                )
                # Queries are the output-time positional encodings; the
                # analyzed sequence z provides the keys/values.
                pe_out = pos_enc_output.expand(*expand_shape)
                z = self.temporal_transformer[i](pe_out, z)
            return z

        x = [process_input(i) for i in range(len(x))]

        # merge inputs
        x = self.fusion(x)
        # produce prediction
        x = self.forecast(x)
        return x.permute(0,4,1,2,3) # to channels-first order
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class AFNONowcastNet(AFNONowcastNetBase):
    """AFNO nowcasting network that decodes predictions back to data space.

    Extends AFNONowcastNetBase with a 1x1x1 output projection and an
    autoencoder decoder so the forecast is produced in the original domain.

    Parameters
    ----------
    autoencoder : module or sequence of modules
        Input autoencoder(s), forwarded to the base class.
    output_autoencoder : module, optional
        Autoencoder whose decoder maps latents back to data space;
        defaults to the first input autoencoder.
    **kwargs
        Forwarded to AFNONowcastNetBase.
    """
    def __init__(self, autoencoder, output_autoencoder=None, **kwargs):
        super().__init__(autoencoder, **kwargs)
        if output_autoencoder is None:
            # Fix: the base class accepts either a single autoencoder or a
            # sequence of them; the original `autoencoder[0]` failed with a
            # lone (non-sequence) module.
            if isinstance(autoencoder, collections.abc.Sequence):
                output_autoencoder = autoencoder[0]
            else:
                output_autoencoder = autoencoder
        self.output_autoencoder = output_autoencoder.requires_grad_(
            self.train_autoenc)
        self.out_proj = nn.Conv3d(
            self.embed_dim_out, output_autoencoder.hidden_width, kernel_size=1
        )

    def forward(self, x):
        x = super().forward(x)
        x = self.out_proj(x)
        return self.output_autoencoder.decode(x)
|
ldcast/models/utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def normalization(channels, norm_type="group", num_groups=32):
    """Create a normalization layer for 3D feature maps.

    Parameters
    ----------
    channels : int
        Number of input channels.
    norm_type : str or None
        "batch", "group", or a falsy value / "none" (case-insensitive)
        for no normalization.
    num_groups : int
        Number of groups when norm_type == "group".

    Raises
    ------
    NotImplementedError
        If norm_type is not recognized.
    """
    if norm_type == "batch":
        return nn.BatchNorm3d(channels)
    elif norm_type == "group":
        return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
    elif (not norm_type) or (norm_type.lower() == 'none'):
        # Fix: str has no .tolower() method; use .lower().
        return nn.Identity()
    else:
        # Fix: the original raised with the undefined name `norm`.
        raise NotImplementedError(norm_type)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def activation(act_type="swish"):
    """Return the activation layer named by act_type (identity if falsy).

    Raises NotImplementedError for unknown activation names.
    """
    if not act_type:
        return nn.Identity()
    layer_types = {
        "swish": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
    }
    try:
        return layer_types[act_type]()
    except KeyError:
        raise NotImplementedError(act_type) from None
|
ldcast/visualization/cm.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def yuv_rainbow_24(nc):
    """Build a LinearSegmentedColormap segment dict for the rainbow colormap.

    Adapted from
    https://github.com/ARM-DOE/pyart/blob/main/pyart/graph/_cm_colorblind.py
    """
    # Luma ramps up over the first 2/5 of the range and back down after;
    # the chroma components follow sine-shaped paths.
    luma = np.concatenate([np.linspace(0.3, 0.85, nc*2//5),
                           np.linspace(0.9, 0.0, nc - nc*2//5)])
    chroma_u = 0.40 * np.sin(np.linspace(0.8*np.pi, 1.8*np.pi, nc))
    chroma_v = 0.55 * np.sin(np.linspace(-0.33*np.pi, 0.33*np.pi, nc)) + 0.1

    # Standard YUV -> RGB conversion matrix.
    rgb_from_yuv = np.array([[1, 0, 1.13983],
                             [1, -0.39465, -0.58060],
                             [1, 2.03211, 0]])

    cmap_dict = {'blue': [], 'green': [], 'red': []}
    denom = len(luma) - 1.0
    for (i, yuv) in enumerate(zip(luma, chroma_u, chroma_v)):
        (r, g, b) = rgb_from_yuv.dot(np.array(yuv))
        pos = i / denom
        cmap_dict['blue'].append((pos, b, b))
        cmap_dict['red'].append((pos, r, r))
        cmap_dict['green'].append((pos, g, g))

    return cmap_dict
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Module-level colormap instance built from the 15-node YUV rainbow segments.
_HOMEYER_SEGMENTS = yuv_rainbow_24(15)
homeyer_rainbow = LinearSegmentedColormap(
    "homeyer_rainbow", segmentdata=_HOMEYER_SEGMENTS
)
|
ldcast/visualization/plots.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import netCDF4
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from matplotlib import gridspec, colors, pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from ..analysis import confmatrix, fss, rank
|
| 11 |
+
from ..features import io
|
| 12 |
+
from .cm import homeyer_rainbow
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def reverse_transform_R(R, mean=-0.051, std=0.528):
    """Invert the log10 normalization: map normalized values back to rain rate."""
    return 10.0 ** (std * R + mean)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def plot_precip_image(
    ax, R,
    Rmin=0, Rmax=25, threshold_mmh=0.1,
    transform_R=False,
    grid_spacing=64,
    mean=-0.051, std=0.528
):
    """Plot a precipitation field on a logarithmic color scale.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    R : ndarray or torch.Tensor
        2D precipitation field; if transform_R is True, values are in the
        normalized log domain and are converted back to mm/h first.
    Rmin, Rmax : float
        Color scale limits (also transformed if transform_R is True).
    threshold_mmh : float
        Rates below this are set to NaN so they render transparent.
        NOTE: this mutates R in place when no copy was made above.
    transform_R : bool
        Apply reverse_transform_R with the given mean/std first.
    grid_spacing : int
        Spacing of grid lines in pixels.
    mean, std : float
        Normalization parameters for the reverse transform.
        (Fix: the original referenced `mean`/`std` without defining them,
        raising NameError whenever transform_R was True.)

    Returns the AxesImage from imshow.
    """
    if isinstance(R, torch.Tensor):
        R = R.detach().numpy()
    if transform_R:
        R = reverse_transform_R(R, mean=mean, std=std)
        Rmin = reverse_transform_R(Rmin, mean=mean, std=std)
        Rmax = reverse_transform_R(Rmax, mean=mean, std=std)
    if threshold_mmh:
        Rmin = max(Rmin, threshold_mmh)
        R[R < threshold_mmh] = np.nan
    norm = colors.LogNorm(Rmin, Rmax)
    ax.set_yticks(np.arange(0, R.shape[0], grid_spacing))
    ax.set_xticks(np.arange(0, R.shape[1], grid_spacing))
    ax.grid(which='major', alpha=0.35)
    ax.tick_params(left=False, bottom=False,
        labelleft=False, labelbottom=False)

    return ax.imshow(R, norm=norm, cmap=homeyer_rainbow)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def plot_autoencoder_reconstruction(
    R, R_hat, samples=8, timesteps=4,
    out_file=None
):
    """Plot original (left) and reconstructed (right) precipitation frames.

    Draws `samples` rows; each row shows the last `timesteps` frames of the
    original sequence R followed by the corresponding frames of the
    reconstruction R_hat, with a shared colorbar in the final column.

    :param R: original fields, indexed [sample, channel, time, y, x].
    :param R_hat: reconstructed fields with the same indexing.
    :param samples: number of rows (samples) to plot.
    :param timesteps: number of trailing frames to show per sequence.
    :param out_file: if given, save the figure to this path and close it.
    """
    fig = plt.figure(figsize=(2*timesteps*2+0.5, samples*2), dpi=150)

    gs = gridspec.GridSpec(
        samples, 2*timesteps+1,
        width_ratios=(1,)*(2*timesteps)+(0.2,),
        wspace=0.02, hspace=0.02
    )
    for k in range(samples):
        # Left half of the row: original frames.
        for (i,j) in enumerate(range(-timesteps,0)):
            ax = fig.add_subplot(gs[k,i])
            im = plot_precip_image(ax, R[k,0,j,:,:])
        # Right half of the row: reconstructed frames.
        for (i,j) in enumerate(range(-timesteps,0)):
            ax = fig.add_subplot(gs[k,i+timesteps])
            im = plot_precip_image(ax, R_hat[k,0,j,:,:])

    cax = fig.add_subplot(gs[:,-1])
    plt.colorbar(im, cax=cax)

    if out_file is not None:
        # Fix: fig.savefig returns None; do not rebind the out_file argument.
        fig.savefig(out_file, bbox_inches='tight')
        plt.close(fig)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def plot_animation(x, y, out_dir, sample=0, fmt="{}_{:02d}.png"):
    """Save per-frame precipitation images for input x and target/output y.

    One PNG per timestep is written to out_dir, named via fmt as e.g.
    "x_03.png" / "y_03.png".

    :param x: input fields, indexed [sample, channel, time, y, x].
    :param y: output fields with the same indexing.
    :param out_dir: directory the frames are written to.
    :param sample: which sample (batch index) to render.
    :param fmt: filename format taking (label, timestep).
    """
    def plot_frame(R, label, timestep):
        # Render a single frame and save it to out_dir.
        fig = plt.figure()
        ax = fig.add_subplot()
        plot_precip_image(ax, R[sample,0,timestep,:,:])
        # Fix: use the timestep parameter rather than the loop variable `k`
        # captured from the enclosing scope (same value today, but fragile).
        fn = fmt.format(label, timestep)
        fn = os.path.join(out_dir, fn)
        fig.savefig(fn, bbox_inches='tight')
        plt.close(fig)

    for k in range(x.shape[2]):
        plot_frame(x, "x", k)

    for k in range(y.shape[2]):
        plot_frame(y, "y", k)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Color-blind-friendly plot colors keyed by "<dataset>-<model>" (with a
# "pm-" prefix for the precipitation-masked variants); a given model uses
# the same color across datasets.
model_colors = {
    "mch-dgmr": "#E69F00",
    "mch-pysteps": "#009E73",
    "mch-iters=50-res=256": "#0072B2",
    "mch-persistence": "#888888",
    "dwd-dgmr": "#E69F00",
    "dwd-pysteps": "#009E73",
    "dwd-iters=50-res=256": "#0072B2",
    "dwd-persistence": "#888888",
    "pm-mch-dgmr": "#E69F00",
    "pm-mch-pysteps": "#009E73",
    "pm-mch-iters=50-res=256": "#0072B2",
    "pm-dwd-dgmr": "#E69F00",
    "pm-dwd-pysteps": "#009E73",
    "pm-dwd-iters=50-res=256": "#0072B2",
}

# Line styles distinguishing the pooling scales in the score plots.
scale_linestyles = {
    "1x1": "-",
    "8x8": "--",
    "64x64": ":"
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def plot_crps(
    log=False,
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS", "Persist."),
    interval_mins=5,
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot (Log)CRPS versus lead time for each model and pooling scale.

    Scores are read from ../results/crps/(log)crps-<model>.nc; one curve is
    drawn per (model, scale), with colors per model and linestyles per scale.

    :param log: plot LogCRPS instead of CRPS.
    :param models: model identifiers (used in filenames and model_colors).
    :param scales: pooling scales, e.g. "8x8" (variables "crps_pool<scale>").
    :param model_labels: legend labels paired with models (zip-truncated).
    :param interval_mins: minutes between consecutive lead times.
    :param out_fn: if given (and the figure was created here), save and close.
    :param ax: Axes to draw into; a new figure is created when None.
    :param crop_box: ((y0,y1),(x0,x1)) crop in full-resolution pixels.
    """
    crps = {}
    crps_name = "logcrps" if log else "crps"
    for model in models:
        crps[model] = {}
        fn = f"../results/crps/{crps_name}-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"crps_pool{scale}"
                crps_model_scale = np.array(ds[var][:], copy=False)
                if crop_box is not None:
                    # Crop box is in full-resolution pixels; convert to
                    # pooled-grid indices for this scale.
                    scale_int = int(scale.split("x")[0])
                    crps_model_scale = crps_model_scale[
                        ...,
                        crop_box[0][0]//scale_int:crop_box[0][1]//scale_int,
                        crop_box[1][0]//scale_int:crop_box[1][1]//scale_int
                    ]
                # Average over all axes except lead time (axis 2).
                crps[model][scale] = crps_model_scale.mean(axis=(0,1,3,4))
                del crps_model_scale

    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = crps[model][scale]
            color = model_colors[model]
            linestyle = scale_linestyles[scale]
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label)

    # Fix: use the Axes methods rather than the pyplot state machine, so
    # the legend/labels are attached to the ax we drew into (plt.* targets
    # the *current* axes, which need not be ax when one is passed in).
    if add_legend:
        ax.legend()
    if add_xlabel:
        ax.set_xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        ax.set_ylabel(
            "LogCRPS" if log else "CRPS [mm h$^\\mathrm{-1}$]",
            fontsize=12
        )

    ax.set_xlim((0, max_t))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.tick_params(axis='both', which='major', labelsize=12)

    # Fix: `fig` only exists when created above; the original raised
    # NameError if out_fn was given together with an external ax.
    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def plot_rank_distribution(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot normalized rank histograms for each model and pooling scale.

    Ranks are read from ../results/ranks/ranks-<model>.nc; the legend shows
    each model's rank KL divergence, and the flat-histogram ideal is drawn
    as a gray reference line.

    :param num_ensemble_members: ensemble size (histogram has N+1 bins).
    :param out_fn: if given (and the figure was created here), save and close.
    :param ax: Axes to draw into; a new figure is created when None.
    :param crop_box: ((y0,y1),(x0,x1)) crop in full-resolution pixels.
    """
    rank_hist = {}
    rank_KL = {}
    for model in models:
        fn = f"../results/ranks/ranks-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"ranks_pool{scale}"
                ranks_model_scale = np.array(ds[var][:], copy=False)
                if crop_box is not None:
                    # Convert full-resolution crop pixels to pooled indices.
                    scale_int = int(scale.split("x")[0])
                    ranks_model_scale = ranks_model_scale[
                        ...,
                        crop_box[0][0]//scale_int:crop_box[0][1]//scale_int,
                        crop_box[1][0]//scale_int:crop_box[1][1]//scale_int
                    ]
                rank_hist[(model,scale)] = rank.rank_distribution(ranks_model_scale)
                rank_KL[(model,scale)] = rank.rank_DKL(rank_hist[(model,scale)])
                del ranks_model_scale

    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    for scale in scales:
        linestyle = scale_linestyles[scale]
        for (model, label) in zip(models, model_labels):
            h = rank_hist[(model,scale)]
            color = model_colors[model]
            x = np.linspace(0, 1, num_ensemble_members+1)
            label_with_score = f"{label}: {rank_KL[(model,scale)]:.3f}"
            ax.plot(x, h, color=color, linestyle=linestyle,
                label=label_with_score)
    # Reference line: the ideal (uniform) rank distribution.
    h_ideal = 1/(num_ensemble_members+1)
    ax.plot([0, 1], [h_ideal, h_ideal], color=(0.4,0.4,0.4),
        linewidth=1.0)

    if add_legend:
        ax.legend(loc='upper center')
    if add_xlabel:
        ax.set_xlabel("Normalized rank", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("Occurrence", fontsize=12)

    ax.set_xlim((0, 1))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    # Fix: `fig` only exists when created above; the original raised
    # NameError if out_fn was given together with an external ax.
    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def plot_rank_metric(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    interval_mins=5,
    metric_name="KL",
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
):
    """Plot a rank-distribution metric versus lead time per model and scale.

    Ranks are read from ../results/ranks/ranks-<model>.nc and reduced with
    rank.rank_metric_by_leadtime; one curve per (model, scale).

    :param metric_name: name shown on the y axis (e.g. "KL").
    :param out_fn: if given (and the figure was created here), save and close.
    :param ax: Axes to draw into; a new figure is created when None.
    """
    rank_metric = {}
    for model in models:
        rank_metric[model] = {}
        fn = f"../results/ranks/ranks-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"ranks_pool{scale}"
                ranks = np.array(ds[var][:], copy=False)
                rank_metric[model][scale] = rank.rank_metric_by_leadtime(ranks)
                del ranks

    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = rank_metric[model][scale]
            color = model_colors[model]
            linestyle = scale_linestyles[scale]
            label_with_scale = f"{label} {scale}"
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label_with_scale)

    # Fix: use the Axes methods rather than the pyplot state machine, so
    # the legend/labels go to the ax we drew into (consistent with
    # plot_rank_distribution).
    if add_legend:
        ax.legend()
    if add_xlabel:
        ax.set_xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        ax.set_ylabel(f"Rank {metric_name}", fontsize=12)

    ax.set_xlim((0, max_t))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.tick_params(axis='both', which='major', labelsize=12)

    # Fix: `fig` only exists when created above; the original raised
    # NameError if out_fn was given together with an external ax.
    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def load_fss(model, scale, use_timesteps, crop_box):
    """Load observed/forecast fractions for one model and scale and return the FSS.

    :param model: model identifier used in the fractions filename.
    :param scale: pooling scale (int); variables are named "..._scale<s>x<s>".
    :param use_timesteps: number of lead times passed to fractions_skill_score.
    :param crop_box: ((y0,y1),(x0,x1)) crop in full-resolution pixels, or None.
    """
    def crop(frac):
        # Convert full-resolution crop pixels to pooled-grid indices.
        (y0, y1) = (crop_box[0][0]//scale, crop_box[0][1]//scale)
        (x0, x1) = (crop_box[1][0]//scale, crop_box[1][1]//scale)
        return frac[..., y0:y1, x0:x1]

    fn = f"../results/fractions/fractions-{model}.nc"
    with netCDF4.Dataset(fn, 'r') as ds:
        sn = f"{scale}x{scale}"
        obs_frac = np.array(ds[f"obs_frac_scale{sn}"][:], copy=False)
        fc_frac = np.array(ds[f"fc_frac_scale{sn}"][:], copy=False)
    if crop_box is not None:
        obs_frac = crop(obs_frac)
        fc_frac = crop(fc_frac)
    return fss.fractions_skill_score(
        obs_frac, fc_frac, use_timesteps=use_timesteps
    )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def plot_fss(
    log=False,
    models=("iters=50-res=256", "dgmr", "pysteps"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    interval_mins=5,
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    scales=None,
    use_timesteps=18,
    crop_box=None
):
    """Plot the fractions skill score as a function of spatial scale.

    FSS values are computed in parallel (one process per model/scale pair)
    via :func:`load_fss`. One curve per model is drawn on `ax` (a new
    figure is created when `ax` is None).

    Parameters `log` and `interval_mins` are accepted for interface
    compatibility but are currently unused in this function.

    If `out_fn` is given, the figure is saved there; the figure is closed
    only when it was created by this function.
    """
    if scales is None:
        scales = 2**np.arange(9)
    fss_scale = {}
    n_workers = min(multiprocessing.cpu_count(), len(models)*len(scales))
    with concurrent.futures.ProcessPoolExecutor(n_workers) as executor:
        # Submit everything first so all model/scale pairs run concurrently.
        for model in models:
            fss_scale[model] = {}
            for scale in scales:
                fss_scale[model][scale] = executor.submit(
                    load_fss, model, scale, use_timesteps, crop_box
                )
        # Then replace each future with its result.
        for model in models:
            for scale in scales:
                fss_scale[model][scale] = fss_scale[model][scale].result()

    # Fix: previously `fig` was only defined when ax was None, so passing
    # an external ax together with out_fn raised NameError at save time.
    created_fig = ax is None
    if created_fig:
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot()
    else:
        fig = ax.figure

    for (model, label) in zip(models, model_labels):
        # NOTE: rebinds `scales` to this model's sorted scales; the last
        # model's scales also set the x-axis limits below (original behavior).
        scales = sorted(fss_scale[model])
        fss_for_model = [fss_scale[model][s] for s in scales]

        # Strip the threshold component from the model name to look up its
        # color; "pm-" models keep their first part.
        model_parts = model.split("-")
        if model.startswith("pm-"):
            model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
        else:
            model_without_threshold = "-".join(model_parts[1:])
        color = model_colors[model_without_threshold]

        ax.plot(scales, fss_for_model, color=color,
            label=label)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Scale [km]", fontsize=12)
    if add_ylabel:
        plt.ylabel("FSS")

    ax.set_xlim((scales[0], scales[-1]))
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        if created_fig:
            plt.close(fig)
+
def plot_csi_threshold(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    num_ensemble_members=32,
    max_timestep=18,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot CSI as a function of probability threshold.

    For each (model, scale) pair, forecast and observed fraction fields
    are loaded (truncated to `max_timestep` lead times), a confusion
    matrix is built per probability threshold, and the CSI
    (intersection over union) is plotted against `prob_thresholds`.

    `num_ensemble_members` and `crop_box` are accepted for interface
    compatibility but are currently unused in this function.

    If `out_fn` is given, the figure is saved there; the figure is closed
    only when it was created by this function.
    """
    csi = {}
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                fc_frac = fc_frac[...,:max_timestep,:,:]
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                obs_frac = obs_frac[...,:max_timestep,:,:]
                conf_matrix = confmatrix.confusion_matrix_thresholds(
                    fc_frac, obs_frac, prob_thresholds
                )
                # Free the large fraction arrays before computing scores.
                del fc_frac, obs_frac
                csi_scale = confmatrix.intersection_over_union(conf_matrix)
                csi[(model,scale)] = csi_scale

    # Fix: previously `fig` was only defined when ax was None, so passing
    # an external ax together with out_fn raised NameError at save time.
    created_fig = ax is None
    if created_fig:
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot()
    else:
        fig = ax.figure

    for scale in scales:
        linestyle = scale_linestyles[scale]
        for (model, label) in zip(models, model_labels):
            c = csi[(model,scale)]
            # Strip the threshold component from the model name for the
            # color lookup; "pm-" models keep their first part.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            ax.plot(prob_thresholds, c, color=color, linestyle=linestyle, label=label)

    if add_legend:
        ax.legend(loc='upper center')
    if add_xlabel:
        ax.set_xlabel("Prob. threshold", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("CSI", fontsize=12)

    ax.set_xlim((0, 1))
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        if created_fig:
            plt.close(fig)
+
def plot_csi_leadtime(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    interval_mins=5,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot the best CSI (over probability thresholds) vs. lead time.

    For each (model, scale) pair, a per-lead-time confusion matrix is
    built over `prob_thresholds`; the CSI is then maximized over the
    threshold axis so each curve shows the best achievable CSI at each
    lead time (spaced `interval_mins` minutes apart).

    `num_ensemble_members` and `crop_box` are accepted for interface
    compatibility but are currently unused in this function.

    If `out_fn` is given, the figure is saved there; the figure is closed
    only when it was created by this function.
    """
    csi = {}
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                conf_matrix = confmatrix.confusion_matrix_thresholds_leadtime(
                    fc_frac, obs_frac, prob_thresholds
                )

                csi_scale = confmatrix.intersection_over_union(conf_matrix)
                # Best CSI over the probability-threshold axis
                # (presumably axis 1 — TODO confirm against confmatrix).
                csi[(model,scale)] = np.nanmax(csi_scale, axis=1)

    # Fix: unlike its sibling functions, this one never created a figure
    # when ax was None, so the default call crashed at ax.plot and `fig`
    # was undefined at save time.
    created_fig = ax is None
    if created_fig:
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot()
    else:
        fig = ax.figure

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = csi[(model,scale)]

            # Strip the threshold component from the model name for the
            # color lookup; "pm-" models keep their first part.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            linestyle = scale_linestyles[scale]
            # Lead times: interval_mins, 2*interval_mins, ...; the +0.1
            # guards against float rounding excluding the last step.
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        plt.ylabel("CSI", fontsize=12)

    ax.set_xlim((0, max_t))
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        if created_fig:
            plt.close(fig)
+
def plot_cost_loss_value(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    interval_mins=5,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot forecast economic value vs. the cost/loss ratio.

    For each (model, scale) pair, a confusion matrix per probability
    threshold is built, then ``confmatrix.cost_loss_value`` is evaluated
    on a grid of cost/loss ratios (loss fixed at 1.0), using the
    climatological event probability of the observations. The value at
    the middle probability threshold is plotted for each cost.

    `interval_mins`, `num_ensemble_members` and `crop_box` are accepted
    for interface compatibility but are currently unused here.

    If `out_fn` is given, the figure is saved there; the figure is closed
    only when it was created by this function.
    """
    value = {}
    loss = 1.0
    cost = np.linspace(0.01, 1, 100)
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                conf_matrix = confmatrix.confusion_matrix_thresholds(
                    fc_frac, obs_frac, prob_thresholds
                )

                # Climatological event frequency from the observations.
                p_clim = obs_frac.mean()
                value_scale = []
                for c in cost:
                    v = confmatrix.cost_loss_value(
                        conf_matrix, c, loss, p_clim
                    )
                    # Keep the value at the middle probability threshold.
                    value_scale.append(v[len(v)//2])
                value[(model,scale)] = np.array(value_scale)

    # Fix: unlike its sibling functions, this one never created a figure
    # when ax was None, so the default call crashed at ax.plot and `fig`
    # was undefined at save time.
    created_fig = ax is None
    if created_fig:
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot()
    else:
        fig = ax.figure

    max_score = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = value[(model,scale)]
            max_score = max(max_score, score[np.isfinite(score)].max())

            # Strip the threshold component from the model name for the
            # color lookup; "pm-" models keep their first part.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            linestyle = scale_linestyles[scale]

            ax.plot(cost, score, color=color, linestyle=linestyle,
                label=label)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Cost/loss ratio", fontsize=12)
    if add_ylabel:
        plt.ylabel("Value", fontsize=12)

    ax.set_xlim((0, 1))
    ylim = (0, max_score*1.05)
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        if created_fig:
            plt.close(fig)
models/.keep
ADDED
|
File without changes
|
models/autoenc/autoenc-32-0.01.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa5ad4b8689aadbf702376e7afe5cb437ef5057675e78a8986837e8f28b3126e
|
| 3 |
+
size 1617490
|
models/autoenc/autoenc.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf3792a94ee4cf347ca498b10b851d58a0f2ce6b3062e7b59fec5761c7edbf24
|
| 3 |
+
size 1616323
|
scripts/convert_data_NB_2nc.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import dask.array as da
|
| 4 |
+
import xarray as xr
|
| 5 |
+
|
| 6 |
+
def load_all_file(data_dir=""):
    """Load all June-2023 .npy frames from `data_dir` in filename order.

    Only files whose names start with "202306" are loaded; sorting the
    filenames puts the frames in chronological order (names are assumed
    to begin with a YYYYMMDD... timestamp — TODO confirm).

    Parameters
    ----------
    data_dir : str
        Directory containing the .npy files.

    Returns
    -------
    list of numpy.ndarray
        One array per selected file.
    """
    selected = sorted(
        fn for fn in os.listdir(data_dir) if fn.startswith("202306")
    )
    # The original also computed lon/lat grids here but never used or
    # returned them; that dead code has been removed.
    return [np.load(os.path.join(data_dir, fn)) for fn in selected]
|
| 26 |
+
def preprocess_data(data_list, out_dir="", patch_size=32, grid_size=640):
    """Cut each frame into square patches and write them as a netCDF file
    in the patch-dataset layout used by the ldcast feature pipeline.

    Parameters
    ----------
    data_list : sequence of 2D arrays
        Input frames; each must cover at least grid_size x grid_size pixels.
    out_dir : str
        Output directory; an "RZC" subdirectory is created inside it.
    patch_size : int
        Edge length of the square patches in pixels (default 32, as in the
        original hard-coded version).
    grid_size : int
        Extent of the frame that is tiled (default 640, matching the
        original hard-coded 640x640 grid).

    Returns
    -------
    int
        Number of input frames processed.
    """
    patch_list = []
    for frame in data_list:
        for i in range(0, grid_size, patch_size):
            for j in range(0, grid_size, patch_size):
                patch_list.append(frame[i:i+patch_size, j:j+patch_size])

    print(len(patch_list))
    data_shape = len(patch_list)
    patches_array = np.array(patch_list, dtype=np.uint8)
    # NOTE(review): coords, times and scale are placeholder values
    # (random numbers / arange), not real metadata — confirm whether the
    # downstream pipeline actually needs real values here.
    temp_array = np.array(np.random.rand(data_shape, 2), dtype=np.uint16)
    temp_array2 = np.arange(256, dtype=np.float32)
    temp_array3 = np.arange(data_shape, dtype=np.int64)

    # Wrap as dask arrays so xarray writes them lazily; each variable is a
    # single chunk. Adjust chunk sizes as needed for your data.
    data_da = da.from_array(patches_array, chunks=(data_shape, patch_size, patch_size))
    data_da2 = da.from_array(temp_array, chunks=(data_shape, 2))
    data_da3 = da.from_array(temp_array3, chunks=(data_shape, ))
    data_da4 = da.from_array(temp_array2, chunks=(256, ))

    # Assemble the dataset with the dimension/variable names expected by
    # the feature pipeline.
    patches_da = xr.DataArray(data_da, dims=("dim_patch", "dim_heigh", "dim_width"))
    patch_coords = xr.DataArray(data_da2, dims=("dim_patch1", "dim_coord"))
    patch_times = xr.DataArray(data_da3, dims=("dim_patch2"))
    zero_patch_coords = xr.DataArray(data_da2, dims=("dim_zero_patch", "dim_coord"))
    zero_patch_times = xr.DataArray(data_da3, dims=("dim_zero_patch1"))
    scale = xr.DataArray(data_da4, dims=("dim_scale"))

    ds = patches_da.to_dataset(name = 'patches')
    ds['patch_coords'] = patch_coords
    ds['patch_times'] = patch_times
    ds['zero_patch_coords'] = zero_patch_coords
    ds['zero_patch_times'] = zero_patch_times
    ds['scale'] = scale

    ds.attrs["zero_value"] = 1
    out_dir = out_dir + "/" + "RZC"
    os.makedirs(out_dir, exist_ok=True)
    file_name = os.path.join(out_dir, "patches_RV_202306.nc")
    ds.to_netcdf(file_name)

    return len(data_list)
| 75 |
+
if __name__ == "__main__":
    # Entry point of the conversion script. The original bound the loaded
    # frames to the name `list`, shadowing the builtin; renamed here.
    frames = load_all_file(data_dir="/data/data_WF/ldcast_precipitation/test")
    print(preprocess_data(
        frames, out_dir="/data/data_WF/ldcast_precipitation/preprocess_data_test"
    ))
|