weatherforecast1024 commited on
Commit
d2f661a
·
verified ·
1 Parent(s): 58d34e4

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +144 -0
  2. LICENSE +202 -0
  3. README.md +106 -0
  4. autoenc-32-0.01.pt +3 -0
  5. config/genforecast-radaronly-128x128-20step.yaml +1 -0
  6. config/genforecast-radaronly-256x256-20step.yaml +5 -0
  7. environment/environment.yml +4 -0
  8. environment/ldcast.yml +170 -0
  9. genforecast-radaronly-256x256-20step.pt +3 -0
  10. ldcast/analysis/confmatrix.py +117 -0
  11. ldcast/analysis/crps.py +162 -0
  12. ldcast/analysis/fss.py +137 -0
  13. ldcast/analysis/histogram.py +108 -0
  14. ldcast/analysis/rank.py +190 -0
  15. ldcast/features/.sampling.py.swp +0 -0
  16. ldcast/features/batch.py +375 -0
  17. ldcast/features/batch.py.save +378 -0
  18. ldcast/features/io.py +125 -0
  19. ldcast/features/patches.py +429 -0
  20. ldcast/features/patches.py.save +431 -0
  21. ldcast/features/sampling.py +215 -0
  22. ldcast/features/split.py +165 -0
  23. ldcast/features/transform.py +296 -0
  24. ldcast/features/utils.py +136 -0
  25. ldcast/forecast.py +264 -0
  26. ldcast/models/autoenc/autoenc.py +93 -0
  27. ldcast/models/autoenc/encoder.py +57 -0
  28. ldcast/models/autoenc/training.py +41 -0
  29. ldcast/models/benchmarks/dgmr.py +82 -0
  30. ldcast/models/benchmarks/pysteps.py +106 -0
  31. ldcast/models/benchmarks/transform.py +17 -0
  32. ldcast/models/blocks/afno.py +348 -0
  33. ldcast/models/blocks/attention.py +104 -0
  34. ldcast/models/blocks/resnet.py +70 -0
  35. ldcast/models/diffusion/diffusion.py +222 -0
  36. ldcast/models/diffusion/ema.py +76 -0
  37. ldcast/models/diffusion/plms.py +245 -0
  38. ldcast/models/diffusion/utils.py +246 -0
  39. ldcast/models/distributions.py +29 -0
  40. ldcast/models/genforecast/analysis.py +33 -0
  41. ldcast/models/genforecast/training.py +42 -0
  42. ldcast/models/genforecast/unet.py +489 -0
  43. ldcast/models/nowcast/nowcast.py +256 -0
  44. ldcast/models/utils.py +28 -0
  45. ldcast/visualization/cm.py +36 -0
  46. ldcast/visualization/plots.py +606 -0
  47. models/.keep +0 -0
  48. models/autoenc/autoenc-32-0.01.pt +3 -0
  49. models/autoenc/autoenc.pt +3 -0
  50. scripts/convert_data_NB_2nc.py +76 -0
.gitignore ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .vscode
2
+ .DS_Store
3
+ dask-worker-space
4
+ *~
5
+ lightning_logs
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104
+ __pypackages__/
105
+
106
+ # Celery stuff
107
+ celerybeat-schedule
108
+ celerybeat.pid
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+
122
+ # Spyder project settings
123
+ .spyderproject
124
+ .spyproject
125
+
126
+ # Rope project settings
127
+ .ropeproject
128
+
129
+ # mkdocs documentation
130
+ /site
131
+
132
+ # mypy
133
+ .mypy_cache/
134
+ .dmypy.json
135
+ dmypy.json
136
+
137
+ # Pyre type checker
138
+ .pyre/
139
+
140
+ # pytype static type analyzer
141
+ .pytype/
142
+
143
+ # Cython debug symbols
144
+ cython_debug/
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+
README.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LDCast is a precipitation nowcasting model based on a latent diffusion model (LDM, used by e.g. [Stable Diffusion](https://github.com/CompVis/stable-diffusion)).
2
+
3
+ This repository contains the code for using LDCast to make predictions and the code used to generate the analysis in the LDCast paper (a preprint is available at https://arxiv.org/abs/2304.12891).
4
+
5
+ A GPU is recommended for both using and training LDCast, although you may be able to generate some samples with a CPU and enough patience.
6
+
7
+ # Installation
8
+
9
+ It is recommended you install the code in its own virtual environment (created with e.g. pyenv or conda).
10
+
11
+ Clone the repository, then, in the main directory, run
12
+ ```bash
13
+ $ pip install -e .
14
+ ```
15
+ This should automatically install the required packages (which might take some minutes). In the paper, we used PyTorch 1.12 but are not aware of any problems with newer versions.
16
+
17
+ If you don't want the requirements to be installed (e.g. if you installed them manually with conda), use:
18
+ ```bash
19
+ $ pip install --no-dependencies -e .
20
+ ```
21
+
22
+ # Using LDCast
23
+
24
+ ## Pretrained models
25
+
26
+ The pretrained models are available at the Zenodo repository https://doi.org/10.5281/zenodo.7780914. Unzip the file `ldcast-models.zip`. The default is to unzip it to the `models` directory, but you can also use another location.
27
+
28
+ ## Producing predictions
29
+
30
+ The easiest way to produce predictions is to use the `ldcast.forecast.Forecast` class, which will set up all models and data transformations and is callable with a past precipitation array.
31
+ ```python
32
+ from ldcast import forecast
33
+
34
+ fc = forecast.Forecast(
35
+ ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn
36
+ )
37
+ R_pred = fc(R_past)
38
+ ```
39
+ Here, `ldm_weights_fn` is the path to the LDM weights and `autoenc_weights_fn` is the path to the autoencoder weights. `R_past` is a NumPy array of precipitation rates with shape `(timesteps, height, width)` where `timesteps` must be 4 and `height` and `width` must be divisible by 32.
40
+
41
+ ### Ensemble predictions
42
+
43
+ If you want to process multiple cases at once and/or generate several ensemble members, there is the `ldcast.forecast.ForecastDistributed` class. The usage is similar to the `Forecast` class, for example:
44
+ ```python
45
+ from ldcast import forecast
46
+
47
+ fc = forecast.ForecastDistributed(
48
+ ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn
49
+ )
50
+ R_pred = fc(R_past, ensemble_members=32)
51
+ ```
52
+ Here, `R_past` should be of shape `(cases, timesteps, height, width)` where `cases` is the number of cases you want to process. For each case, `ensemble_members` predictions are produced (this is the last axis of `R_pred`). `ForecastDistributed` automatically distributes the workload to multiple GPUs if you have them.
53
+
54
+ ## Demo
55
+
56
+ For a practical example, you can run the demo in the `scripts` directory. First download the `ldcast-demo-20210622.zip` file from the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914), then unzip it in the `data` directory. Then run
57
+ ```bash
58
+ $ python forecast_demo.py
59
+ ```
60
+ A sample output can be found in the file `ldcast-demo-video-20210622.zip` in the data repository. See the function `forecast_demo` in `forecast_demo.py` to see how the `Forecast` class works. To run an ensemble mean of 8 members using the `ForecastDistributed` class, you can use:
61
+ ```bash
62
+ $ python forecast_demo.py --ensemble-members=8
63
+ ```
64
+
65
+ The demo for a single ensemble member runs in a couple of minutes on our system using one V100 GPU; with a CPU around 10 minutes or more would be expected. A progress bar will show the status of the generation.
66
+
67
+ # Training
68
+
69
+ ## Training data
70
+
71
+ The preprocessed training data, needed to rerun the LDCast training, can be found at the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914). Unzip the `ldcast-datasets.zip` file to the `data` directory.
72
+
73
+ ## Training the autoencoder
74
+
75
+ In the `scripts` directory, run
76
+ ```bash
77
+ $ python train_autoenc.py --model_dir="../models/autoenc_train"
78
+ ```
79
+ to run the training of the autoencoder with the default parameters. The training checkpoints will be saved in the `../models/autoenc_train` directory (feel free to change this).
80
+
81
+ It has been reported that this training may encounter a condition where the loss goes to `nan`. If this happens, try restarting from the latest checkpoint:
82
+ ```bash
83
+ $ python train_autoenc.py --model_dir="../models/autoenc_train" --ckpt_path="../models/autoenc_train/<checkpoint_file>"
84
+ ```
85
+ where `<checkpoint_file>` should be the latest checkpoint in the `../models/autoenc_train/` directory.
86
+
87
+ ## Training the diffusion model
88
+
89
+ In the `scripts` directory, run
90
+ ```bash
91
+ $ python train_genforecast.py --model_dir="../models/genforecast_train"
92
+ ```
93
+ to run the training of the diffusion model with the default parameters, or
94
+ ```bash
95
+ $ python train_genforecast.py --model_dir="../models/genforecast_train" --config=<path_to_config_file>
96
+ ```
97
+ to run the training with different parameters. Some config files can be found in the `config` directory. The training checkpoints will be saved in the `../models/genforecast_train` directory (again, this can be changed freely).
98
+
99
+ # Evaluation
100
+
101
+ You can find scripts for evaluating models in the `scripts` directory:
102
+ * `eval_genforecast.py` to evaluate LDCast
103
+ * `eval_dgmr.py` to evaluate DGMR (requires tensorflow installation and the DGMR model from https://github.com/deepmind/deepmind-research/tree/master/nowcasting placed in the `models/dgmr` directory)
104
+ * `eval_pysteps.py` to evaluate PySTEPS (requires pysteps installation)
105
+ * `metrics.py` to produce metrics from the evaluation results produced with the functions in scripts above
106
+ * `plot_genforecast.py` to make plots from the results generated
autoenc-32-0.01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa5ad4b8689aadbf702376e7afe5cb437ef5057675e78a8986837e8f28b3126e
3
+ size 1617490
config/genforecast-radaronly-128x128-20step.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ # this configuration is the default - no parameters to override!
config/genforecast-radaronly-256x256-20step.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ sample_shape: [8,8]
2
+ batch_size: 24
3
+ sampler: "sampler_nowcaster256"
4
+ initial_weights: "../models/genforecast/genforecast-radaronly-128x128-20step.pt"
5
+ lr: 2.5e-5
environment/environment.yml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: ldcast
2
+ channels:
3
+ - defaults
4
+ prefix: /home/mmhk20/.conda/envs/ldcast
environment/ldcast.yml ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ldcast
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - brotli=1.1.0=hd590300_1
9
+ - brotli-bin=1.1.0=hd590300_1
10
+ - bzip2=1.0.8=hd590300_5
11
+ - c-ares=1.27.0=hd590300_0
12
+ - ca-certificates=2024.2.2=hbcca054_0
13
+ - certifi=2024.2.2=pyhd8ed1ab_0
14
+ - cycler=0.12.1=pyhd8ed1ab_0
15
+ - freetype=2.12.1=h267a509_2
16
+ - geos=3.12.1=h59595ed_0
17
+ - keyutils=1.6.1=h166bdaf_0
18
+ - krb5=1.21.2=h659d440_0
19
+ - lcms2=2.16=hb7c19ff_0
20
+ - ld_impl_linux-64=2.40=h41732ed_0
21
+ - lerc=4.0.0=h27087fc_0
22
+ - libblas=3.9.0=21_linux64_openblas
23
+ - libbrotlicommon=1.1.0=hd590300_1
24
+ - libbrotlidec=1.1.0=hd590300_1
25
+ - libbrotlienc=1.1.0=hd590300_1
26
+ - libcblas=3.9.0=21_linux64_openblas
27
+ - libcurl=8.6.0=hca28451_0
28
+ - libdeflate=1.19=hd590300_0
29
+ - libedit=3.1.20191231=he28a2e2_2
30
+ - libev=4.33=hd590300_2
31
+ - libexpat=2.6.2=h59595ed_0
32
+ - libffi=3.4.2=h7f98852_5
33
+ - libgcc-ng=13.2.0=h807b86a_5
34
+ - libgfortran-ng=13.2.0=h69a702a_5
35
+ - libgfortran5=13.2.0=ha4646dd_5
36
+ - libgomp=13.2.0=h807b86a_5
37
+ - libjpeg-turbo=3.0.0=hd590300_1
38
+ - liblapack=3.9.0=21_linux64_openblas
39
+ - libnghttp2=1.58.0=h47da74e_1
40
+ - libnsl=2.0.1=hd590300_0
41
+ - libopenblas=0.3.26=pthreads_h413a1c8_0
42
+ - libpng=1.6.43=h2797004_0
43
+ - libsqlite=3.45.2=h2797004_0
44
+ - libssh2=1.11.0=h0841786_0
45
+ - libstdcxx-ng=13.2.0=h7e041cc_5
46
+ - libtiff=4.6.0=ha9c0a0a_2
47
+ - libuuid=2.38.1=h0b41bf4_0
48
+ - libwebp-base=1.3.2=hd590300_0
49
+ - libxcb=1.15=h0b41bf4_0
50
+ - libxcrypt=4.4.36=hd590300_1
51
+ - libzlib=1.2.13=hd590300_5
52
+ - matplotlib-base=3.8.3=py312he5832f3_0
53
+ - munkres=1.1.4=pyh9f0ad1d_0
54
+ - ncurses=6.4.20240210=h59595ed_0
55
+ - openjpeg=2.5.2=h488ebb8_0
56
+ - openssl=3.2.1=hd590300_1
57
+ - packaging=24.0=pyhd8ed1ab_0
58
+ - pip=24.0=pyhd8ed1ab_0
59
+ - proj=9.3.1=h1d62c97_0
60
+ - pthread-stubs=0.4=h36c2ea0_1001
61
+ - pyparsing=3.1.2=pyhd8ed1ab_0
62
+ - pyshp=2.3.1=pyhd8ed1ab_0
63
+ - python=3.12.2=hab00c5b_0_cpython
64
+ - python-dateutil=2.9.0=pyhd8ed1ab_0
65
+ - python_abi=3.12=4_cp312
66
+ - readline=8.2=h8228510_1
67
+ - setuptools=69.2.0=pyhd8ed1ab_0
68
+ - six=1.16.0=pyh6c4a22f_0
69
+ - sqlite=3.45.2=h2c6b66d_0
70
+ - tk=8.6.13=noxft_h4845f30_101
71
+ - wheel=0.43.0=pyhd8ed1ab_0
72
+ - xorg-libxau=1.0.11=hd590300_0
73
+ - xorg-libxdmcp=1.1.3=h7f98852_0
74
+ - xz=5.2.6=h166bdaf_0
75
+ - zstd=1.5.5=hfc55251_0
76
+ - pip:
77
+ - aiobotocore==2.12.1
78
+ - aiohttp==3.9.3
79
+ - aioitertools==0.11.0
80
+ - aiosignal==1.3.1
81
+ - antlr4-python3-runtime==4.9.3
82
+ - arm-pyart==1.18.0
83
+ - attrs==23.2.0
84
+ - botocore==1.34.51
85
+ - cartopy==0.22.0
86
+ - cftime==1.6.3
87
+ - charset-normalizer==3.3.2
88
+ - click==8.1.7
89
+ - cloudpickle==3.0.0
90
+ - cmweather==0.3.2
91
+ - contourpy==1.2.0
92
+ - dask==2024.3.1
93
+ - deprecation==2.1.0
94
+ - einops==0.7.0
95
+ - filelock==3.13.1
96
+ - fire==0.6.0
97
+ - fonttools==4.50.0
98
+ - frozenlist==1.4.1
99
+ - fsspec==2024.3.1
100
+ - h5netcdf==1.3.0
101
+ - h5py==3.10.0
102
+ - idna==3.6
103
+ - jinja2==3.1.3
104
+ - jmespath==1.0.1
105
+ - jsmin==3.0.1
106
+ - jsonschema==4.22.0
107
+ - jsonschema-specifications==2023.12.1
108
+ - kiwisolver==1.4.5
109
+ - lat-lon-parser==1.3.0
110
+ - lightning==2.2.4
111
+ - lightning-utilities==0.11.0
112
+ - llvmlite==0.42.0
113
+ - locket==1.0.0
114
+ - markupsafe==2.1.5
115
+ - matplotlib==3.8.3
116
+ - mda-xdrlib==0.2.0
117
+ - mpmath==1.3.0
118
+ - multidict==6.0.5
119
+ - netcdf4==1.6.5
120
+ - networkx==3.2.1
121
+ - numba==0.59.1
122
+ - numpy==1.26.4
123
+ - nvidia-cublas-cu12==12.1.3.1
124
+ - nvidia-cuda-cupti-cu12==12.1.105
125
+ - nvidia-cuda-nvrtc-cu12==12.1.105
126
+ - nvidia-cuda-runtime-cu12==12.1.105
127
+ - nvidia-cudnn-cu12==8.9.2.26
128
+ - nvidia-cufft-cu12==11.0.2.54
129
+ - nvidia-curand-cu12==10.3.2.106
130
+ - nvidia-cusolver-cu12==11.4.5.107
131
+ - nvidia-cusparse-cu12==12.1.0.106
132
+ - nvidia-nccl-cu12==2.20.5
133
+ - nvidia-nvjitlink-cu12==12.4.99
134
+ - nvidia-nvtx-cu12==12.1.105
135
+ - omegaconf==2.3.0
136
+ - open-radar-data==0.1.0
137
+ - opencv-python==4.9.0.80
138
+ - pandas==2.2.1
139
+ - partd==1.4.1
140
+ - pillow==10.2.0
141
+ - platformdirs==4.2.0
142
+ - pooch==1.8.1
143
+ - pyproj==3.6.1
144
+ - pysteps==1.9.0
145
+ - pytorch-lightning==2.2.1
146
+ - pytz==2024.1
147
+ - pyyaml==6.0.1
148
+ - referencing==0.35.1
149
+ - requests==2.31.0
150
+ - rpds-py==0.18.1
151
+ - s3fs==2024.3.1
152
+ - scipy==1.12.0
153
+ - shapely==2.0.3
154
+ - sympy==1.12
155
+ - termcolor==2.4.0
156
+ - toolz==0.12.1
157
+ - torch==2.3.0
158
+ - torchmetrics==1.3.2
159
+ - torchvision==0.18.0
160
+ - tqdm==4.66.2
161
+ - typing-extensions==4.10.0
162
+ - tzdata==2024.1
163
+ - urllib3==2.0.7
164
+ - wradlib==2.0.3
165
+ - wrapt==1.16.0
166
+ - xarray==2024.2.0
167
+ - xarray-datatree==0.0.14
168
+ - xmltodict==0.13.0
169
+ - xradar==0.4.3
170
+ - yarl==1.9.4
genforecast-radaronly-256x256-20step.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fef86b78d29fde8ba66f51ae74f0d84ddc67b711fcab034a3130ea5ac7721cf
3
+ size 5345469521
ldcast/analysis/confmatrix.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import concurrent
3
+ import multiprocessing
4
+
5
+ import netCDF4
6
+ import numpy as np
7
+ from scipy.integrate import trapezoid
8
+
9
+ from ..features.io import load_batch
10
+
11
+
12
+ def confusion_matrix(fc_frac, obs_frac, prob_threshold):
13
+ N = np.prod(fc_frac.shape)
14
+ fc_above = fc_frac > prob_threshold
15
+ obs_above = obs_frac > prob_threshold
16
+ tp = np.count_nonzero(fc_above & obs_above) / N
17
+ fp = np.count_nonzero(fc_above & ~obs_above) / N
18
+ fn = np.count_nonzero(~fc_above & obs_above) / N
19
+ tn = 1.0 - tp - fp - fn
20
+ return np.array(((tp, fn), (fp, tn)))
21
+
22
+
23
+ def confusion_matrix_thresholds(fc_frac, obs_frac, thresholds):
24
+ N_threads = multiprocessing.cpu_count()
25
+ with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
26
+ futures = [
27
+ executor.submit(confusion_matrix, fc_frac, obs_frac, t)
28
+ for t in thresholds
29
+ ]
30
+ conf_matrix = [f.result() for f in futures]
31
+ return np.stack(conf_matrix, axis=-1)
32
+
33
+
34
+ def confusion_matrix_thresholds_leadtime(fc_frac, obs_frac, thresholds):
35
+ N_threads = multiprocessing.cpu_count()
36
+ conf_matrix = []
37
+ with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
38
+ for lt in range(fc_frac.shape[2]):
39
+ futures = [
40
+ executor.submit(confusion_matrix,
41
+ fc_frac[...,lt,:,:], obs_frac[...,lt,:,:], t)
42
+ for t in thresholds
43
+ ]
44
+ conf_matrix_lt = [f.result() for f in futures]
45
+ conf_matrix_lt = np.stack(conf_matrix_lt, axis=-1)
46
+ conf_matrix.append(conf_matrix_lt)
47
+
48
+ return np.stack(conf_matrix, axis=-2)
49
+
50
+
51
+
52
+ def precision(conf_matrix):
53
+ ((tp, fn), (fp, tn)) = conf_matrix
54
+ precision = tp / (tp + fp)
55
+ precision[np.isnan(precision)] = 1.0
56
+ return precision
57
+
58
+
59
+ def recall(conf_matrix):
60
+ ((tp, fn), (fp, tn)) = conf_matrix
61
+ return tp / (tp + fn)
62
+
63
+
64
+ def false_alarm_ratio(conf_matrix):
65
+ return 1.0 - precision(conf_matrix)
66
+
67
+
68
+ def intersection_over_union(conf_matrix):
69
+ ((tp, fn), (fp, tn)) = conf_matrix
70
+ return tp / (tp+fp+fn)
71
+
72
+
73
+ def equitable_threat_score(conf_matrix):
74
+ ((tp, fn), (fp, tn)) = conf_matrix
75
+ tp_rnd = (tp+fn) * (tp+fp) / (tp+fp+tn+fn)
76
+ return (tp-tp_rnd) / (tp+fp+fn-tp_rnd)
77
+
78
+
79
+ def peirce_skill_score(conf_matrix):
80
+ ((tp, fn), (fp, tn)) = conf_matrix
81
+ return (tp*tn - fn*fp) / ((tp+fn) * (fp+tn))
82
+
83
+
84
+ def heidke_skill_score(conf_matrix):
85
+ ((tp, fn), (fp, tn)) = conf_matrix
86
+ return 2 * (tp*tn - fn*fp) / ((tp+fn)*(fn+tn) + (tp+fp)*(fp+tn))
87
+
88
+
89
+ def roc_area_under_curve(conf_matrix):
90
+ ((tp, fn), (fp, tn)) = conf_matrix
91
+ tpr = tp / (tp + fn)
92
+ fpr = fp / (fp + tn)
93
+
94
+ auc = trapezoid(tpr[::-1], x=fpr[::-1])
95
+ return auc
96
+
97
+
98
+ def pr_area_under_curve(conf_matrix):
99
+ prec = precision(conf_matrix)
100
+ rec = recall(conf_matrix)
101
+
102
+ if (rec[-1] != 0) or (prec[-1] != 1):
103
+ rec = np.hstack((rec, 0.0))
104
+ prec = np.hstack((prec, 1.0))
105
+
106
+ auc = trapezoid(prec[::-1], x=rec[::-1])
107
+ return auc
108
+
109
+
110
+ def cost_loss_value(conf_matrix, cost, loss, p_clim):
111
+ ((tp, fn), (fp, tn)) = conf_matrix
112
+
113
+ E_c = min(cost, p_clim*loss)
114
+ E_p = p_clim * cost
115
+ E_f = (tp+fp)*cost + fn*loss
116
+ #print(cost, loss, p_clim, E_c, E_p, E_f[len(E_f)//2]/E_p, (E_f[len(E_f)//2] - E_c) / (E_p - E_c))
117
+ return (E_f - E_c) / (E_p - E_c)
ldcast/analysis/crps.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import multiprocessing
3
+ import os
4
+
5
+ import netCDF4
6
+ import numpy as np
7
+
8
+ from ..features.io import load_batch, decode_saved_var_to_rainrate
9
+
10
+
11
def crps_ensemble(observation, forecasts):
    """Pointwise CRPS of an ensemble forecast against observations.

    The CRPS integral is accumulated member by member over the sorted
    ensemble, with weights derived from the empirical CDF. Work is split
    into contiguous chunks, one thread per CPU; the threads only write
    disjoint slices of the shared output array.

    Parameters
    ----------
    observation : array of any shape.
    forecasts : array with the same shape plus a trailing
        ensemble-member axis.

    Returns
    -------
    Array of CRPS values with the same shape as observation.
    """
    shape = observation.shape
    N = np.prod(shape)
    shape_flat = (np.prod(shape),)
    observation = observation.reshape((N,))
    forecasts = forecasts.reshape((N, forecasts.shape[-1]))
    crps_all = np.zeros_like(observation)
    N_threads = multiprocessing.cpu_count()

    def crps_chunk(k):
        # process the k-th contiguous slice of the flattened arrays
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        obs = observation[i0:i1].copy()
        fc = forecasts[i0:i1,:].copy()
        fc.sort(axis=-1)
        fc_below = fc < obs[...,None]
        crps = np.zeros_like(obs)

        # contribution of members below the observation; the weight
        # ((i+1)^2 - i^2) / m^2 integrates the squared empirical CDF
        # between consecutive sorted members
        for i in range(fc.shape[-1]):
            below = fc_below[...,i]
            weight = ((i+1)**2 - i**2) / fc.shape[-1]**2
            crps[below] += weight * (obs[below]-fc[...,i][below])

        # symmetric contribution of members above the observation,
        # counted from the top of the sorted ensemble
        for i in range(fc.shape[-1]-1,-1,-1):
            above = ~fc_below[...,i]
            k = fc.shape[-1]-1-i
            weight = ((k+1)**2 - k**2) / fc.shape[-1]**2
            crps[above] += weight * (fc[...,i][above]-obs[above])

        crps_all[i0:i1] = crps

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {}
        for k in range(N_threads):
            args = (crps_chunk, k)
            futures[executor.submit(*args)] = k
        concurrent.futures.wait(futures)

    return crps_all.reshape(shape)
50
+
51
+
52
def crps_ensemble_multiscale(observation, forecasts):
    """CRPS evaluated at successively coarser (2x2 average-pooled) scales.

    Returns a dict mapping the pooling scale (1, 2, 4, ...) to the CRPS
    field at that scale. Pooling stops once the last spatial axis of the
    observation reaches length 1.
    """
    crps_scales = {}
    obs, fc = observation, forecasts
    scale = 1
    while True:
        crps_scales[scale] = crps_ensemble(obs, fc)
        scale *= 2
        if obs.shape[-1] == 1:
            break
        # 2x2 average pooling of the observation's spatial axes
        obs = 0.25 * (
            obs[..., ::2, ::2] + obs[..., 1::2, ::2] +
            obs[..., ::2, 1::2] + obs[..., 1::2, 1::2]
        )
        # forecasts carry a trailing ensemble axis, pool axes -3 and -2
        fc = 0.25 * (
            fc[..., ::2, ::2, :] + fc[..., 1::2, ::2, :] +
            fc[..., ::2, 1::2, :] + fc[..., 1::2, 1::2, :]
        )
    return crps_scales
79
+
80
+
81
def gather_observation(data_dir):
    """Collect pooled observation fields from all batch files in a directory.

    Reads the "future_observations" variable of every NetCDF file in
    data_dir, decodes it to rain rate, and average-pools each field by
    factors of 2 until the last spatial axis reaches length 1.

    Returns
    -------
    Dict mapping pooling scale (1, 2, 4, ...) to an array concatenated
    over all files along the sample axis.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    def obs_from_file(fn):
        with netCDF4.Dataset(fn, 'r') as ds:
            obs = np.array(ds["future_observations"][:], copy=False)
        obs = decode_saved_var_to_rainrate(obs)
        p = 1
        obs_pooled = {}
        while True:
            obs_pooled[p] = obs
            if obs.shape[-1] == 1:
                break
            # 2x2 average pooling of the spatial axes
            obs = 0.25 * (
                obs[...,::2,::2] +
                obs[...,1::2,::2] +
                obs[...,::2,1::2] +
                obs[...,1::2,1::2]
            )
            p *= 2
        return obs_pooled

    obs_pooled = {}
    for fn in files:
        print(fn)
        obs_file = obs_from_file(fn)
        for k in obs_file:
            if k not in obs_pooled:
                obs_pooled[k] = []
            obs_pooled[k].append(obs_file[k])

    # concatenate the per-file fields along the sample axis
    for k in obs_pooled:
        obs_pooled[k] = np.concatenate(obs_pooled[k], axis=0)

    return obs_pooled
117
+
118
+
119
def process_batch(fn, log=False, preproc_fc=None):
    """Load one saved batch file and return its multiscale CRPS."""
    print(fn)
    batch = load_batch(fn, log=log, preproc_fc=preproc_fc)
    (obs, fc) = (batch[1], batch[2])
    return crps_ensemble_multiscale(obs, fc)
123
+
124
+
125
def save_crps_for_dataset(data_dir, result_fn, log=False, preproc_fc=None):
    """Compute multiscale CRPS for every batch file and save to NetCDF.

    Batch files in data_dir are processed in parallel worker processes;
    the per-file CRPS fields are concatenated along the sample axis and
    written as one compressed variable per pooling scale.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    futures = []
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        for fn in files:
            args = (process_batch, fn)
            kwargs = {"log": log, "preproc_fc": preproc_fc}
            futures.append(executor.submit(*args, **kwargs))

    crps = [f.result() for f in futures]
    scales = sorted(crps[0].keys())
    # merge the per-file dicts: one concatenated array per scale
    crps = {
        s: np.concatenate([c[s] for c in crps], axis=0)
        for s in scales
    }

    with netCDF4.Dataset(result_fn, 'w') as ds:
        # dims shared by all scales: (sample, channel, lead time)
        ds.createDimension("dim_sample", crps[1].shape[0])
        ds.createDimension("dim_channel", crps[1].shape[1])
        ds.createDimension("dim_time_future", crps[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}

        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", crps[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", crps[s].shape[4])
            var = ds.createVariable(
                f"crps_pool{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                # chunk per sample for efficient per-sample access
                chunksizes=(1,)+crps[s].shape[1:],
                **var_params
            )
            var[:] = crps[s]
ldcast/analysis/fss.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import multiprocessing
3
+ import os
4
+
5
+ import netCDF4
6
+ import numpy as np
7
+
8
+ from ..features.io import load_batch, decode_saved_var_to_rainrate
9
+
10
+
11
def fractions_ensemble(observation, forecasts, threshold, max_scale=256):
    """Exceedance fractions of obs and ensemble forecast at multiple scales.

    The observation exceedance field and the ensemble-mean forecast
    exceedance probability are average-pooled by factors of 2 up to
    max_scale, or until the spatial size reaches one pixel (whichever
    comes first; the latter guard prevents an invalid pooling of a
    single-pixel field).

    Parameters
    ----------
    observation : array (..., h, w)
    forecasts : array (..., h, w, n_members)
    threshold : exceedance threshold applied to both inputs.
    max_scale : largest pooling scale to evaluate.

    Returns
    -------
    (obs_frac, fc_frac) : dicts mapping scale -> fraction field.
    """
    obs = (observation >= threshold).astype(np.float32)
    # ensemble exceedance probability per pixel
    fc = (forecasts >= threshold).astype(np.float32).mean(axis=-1)
    obs_frac = {}
    fc_frac = {}

    scale = 1
    while True:
        obs_frac[scale] = obs.copy()
        fc_frac[scale] = fc.copy()
        scale *= 2
        # stop at max_scale or when pooling is no longer possible
        if (scale > max_scale) or (obs.shape[-1] == 1):
            break
        # 2x2 average pooling of the two trailing spatial axes
        obs = 0.25 * (
            obs[...,::2,::2] +
            obs[...,1::2,::2] +
            obs[...,::2,1::2] +
            obs[...,1::2,1::2]
        )
        fc = 0.25 * (
            fc[...,::2,::2] +
            fc[...,1::2,::2] +
            fc[...,::2,1::2] +
            fc[...,1::2,1::2]
        )

    return (obs_frac, fc_frac)
38
+
39
+
40
def frac_from_file(fn, threshold, preproc_fc):
    """Load one saved batch file and compute its exceedance fractions."""
    print(fn)
    batch = load_batch(fn, preproc_fc=preproc_fc)
    (obs, fc) = (batch[1], batch[2])
    return fractions_ensemble(obs, fc, threshold)
44
+
45
+
46
def save_fractions_for_dataset(data_dir, result_fn, threshold, preproc_fc=None):
    """Compute exceedance fractions for every batch file and save to NetCDF.

    Batch files in data_dir are processed in parallel worker processes;
    the per-file obs/forecast fraction fields are concatenated along the
    sample axis and written as one compressed variable pair per scale.

    Parameters
    ----------
    data_dir : directory containing saved batch files.
    result_fn : output NetCDF file.
    threshold : exceedance threshold passed to fractions_ensemble.
    preproc_fc : optional forecast preprocessing passed to load_batch.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        futures = []
        for fn in files:
            args = (frac_from_file, fn, threshold, preproc_fc)
            futures.append(executor.submit(*args))

        (obs_frac, fc_frac) = zip(*(f.result() for f in futures))

    scales = list(obs_frac[0].keys())
    # merge the per-file dicts: one concatenated array per scale
    obs_frac_dict = {}
    fc_frac_dict = {}
    for s in scales:
        obs_frac_dict[s] = np.concatenate([f[s] for f in obs_frac], axis=0)
        fc_frac_dict[s] = np.concatenate([f[s] for f in fc_frac], axis=0)
    obs_frac = obs_frac_dict
    fc_frac = fc_frac_dict

    with netCDF4.Dataset(result_fn, 'w') as ds:
        # dims shared by all scales: (sample, channel, lead time)
        ds.createDimension("dim_sample", obs_frac[1].shape[0])
        ds.createDimension("dim_channel", obs_frac[1].shape[1])
        ds.createDimension("dim_time_future", obs_frac[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}
        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", obs_frac[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", obs_frac[s].shape[4])
            obs_var = ds.createVariable(
                f"obs_frac_scale{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                # chunk per sample for efficient per-sample access
                chunksizes=(1,)+obs_frac[s].shape[1:],
                **var_params
            )
            obs_var[:] = obs_frac[s]
            fc_var = ds.createVariable(
                f"fc_frac_scale{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                chunksizes=(1,)+fc_frac[s].shape[1:],
                **var_params
            )
            fc_var[:] = fc_frac[s]
98
+
99
+
100
def load_fractions(fn):
    """Read the obs/forecast fraction fields written by
    save_fractions_for_dataset, returning (obs_frac, fc_frac) dicts
    keyed by pooling scale."""
    obs_frac = {}
    fc_frac = {}
    with netCDF4.Dataset(fn, 'r') as ds:
        # recover the available scales from the variable names
        scales = {int(name.split("x")[-1]) for name in ds.variables.keys()}
        for s in scales:
            obs_frac[s] = np.array(ds[f"obs_frac_scale{s}x{s}"][:], copy=False)
            fc_frac[s] = np.array(ds[f"fc_frac_scale{s}x{s}"][:], copy=False)

    return (obs_frac, fc_frac)
111
+
112
+
113
def fractions_skill_score(
    obs_frac, fc_frac,
    frac_axes=None, fss_axes=None, use_timesteps=None
):
    """Fractions skill score (FSS) from precomputed exceedance fractions.

    Parameters
    ----------
    obs_frac, fc_frac : arrays of fractions, or dicts mapping scale ->
        array as produced by fractions_ensemble; for dicts, each scale
        is scored separately and a dict of scores is returned.
    frac_axes : axes over which the fractions Brier score is averaged.
    fss_axes : axes over which the resulting FSS is averaged.
    use_timesteps : if given, only the first use_timesteps lead times
        (axis 2) are evaluated.
    """
    if isinstance(obs_frac, dict):
        return {
            s: fractions_skill_score(
                obs_frac[s], fc_frac[s],
                frac_axes=frac_axes, fss_axes=fss_axes,
                use_timesteps=use_timesteps
            )
            for s in sorted(obs_frac)
        }

    if use_timesteps is not None:  # identity check, not != None
        obs_frac = obs_frac[:,:,:use_timesteps,...]
        fc_frac = fc_frac[:,:,:use_timesteps,...]
    # fractions Brier score and its reference value
    fbs = ((obs_frac - fc_frac)**2).mean(axis=frac_axes)
    fbs_ref = (obs_frac**2).mean(axis=frac_axes) + \
        (fc_frac**2).mean(axis=frac_axes)
    fss = 1 - fbs/fbs_ref
    # where neither field has any exceedances the score is 0/0;
    # treat the trivially correct forecast as perfect
    if isinstance(fss, np.ndarray):
        fss[~np.isfinite(fss)] = 1
    fss = fss.mean(axis=fss_axes)
    return fss
ldcast/analysis/histogram.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import multiprocessing
3
+ import os
4
+
5
+ import netCDF4
6
+ import numpy as np
7
+ from scipy.interpolate import interp1d
8
+
9
+ from ..features.io import load_batch, decode_saved_var_to_rainrate
10
+
11
+
12
def histogram(observation, forecasts, bins):
    """Per-lead-time histograms of observed and forecast values.

    The lead time is taken from axis 2 of the inputs; all other axes
    are flattened into each histogram.

    Returns
    -------
    (obs_hist, fc_hist), each of shape (len(bins)-1, n_timesteps)
    with uint64 counts.
    """
    num_bins = len(bins) - 1
    num_timesteps = observation.shape[2]
    obs_hist = np.zeros((num_bins, num_timesteps), dtype=np.uint64)
    fc_hist = np.zeros((num_bins, num_timesteps), dtype=np.uint64)

    for t in range(num_timesteps):
        # np.histogram flattens its input internally
        (obs_hist[:,t], _) = np.histogram(observation[:,:,t,...], bins=bins)
        (fc_hist[:,t], _) = np.histogram(forecasts[:,:,t,...], bins=bins)

    return (obs_hist, fc_hist)
25
+
26
+
27
def hist_from_file(fn, bins):
    """Load one saved batch file and histogram it; values below the first
    bin edge are thresholded away by load_batch."""
    print(fn)
    batch = load_batch(fn, threshold=bins[0])
    return histogram(batch[1], batch[2], bins)
31
+
32
+
33
def save_histogram_for_dataset(data_dir, result_fn, bins=(0.05,120,100)):
    """Accumulate obs/forecast histograms over a batch dataset and save them.

    Parameters
    ----------
    data_dir : directory containing saved batch files.
    result_fn : output NetCDF file.
    bins : (min, max, num) spec for logarithmically spaced bin edges;
        a zero edge is prepended to catch sub-threshold values.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    # log-spaced bin edges with an extra zero edge at the start
    bins = np.exp(np.linspace(np.log(bins[0]), np.log(bins[1]), bins[2]))
    bins = np.hstack((0, bins))

    N_threads = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        futures = []
        for fn in files:
            args = (hist_from_file, fn, bins)
            futures.append(executor.submit(*args))

        (obs_hist, fc_hist) = zip(*(f.result() for f in futures))

    # accumulate the per-file counts
    obs_hist = sum(obs_hist)
    fc_hist = sum(fc_hist)

    with netCDF4.Dataset(result_fn, 'w') as ds:
        ds.createDimension("dim_bin", obs_hist.shape[0])
        ds.createDimension("dim_time_future", obs_hist.shape[1])
        var_params = {"zlib": True, "complevel": 1}

        obs_var = ds.createVariable(
            "obs_hist", np.uint64,
            ("dim_bin", "dim_time_future"),
            **var_params
        )
        obs_var[:] = obs_hist

        fc_var = ds.createVariable(
            "fc_hist", np.uint64,
            ("dim_bin", "dim_time_future"),
            **var_params
        )
        fc_var[:] = fc_hist

        # store the bin edges so the histograms can be interpreted later
        ds.createDimension("dim_bin_edge", len(bins))
        bin_var = ds.createVariable(
            "bins", np.float64,
            ("dim_bin_edge",),
            **var_params
        )
        bin_var[:] = bins
78
+
79
+
80
def load_histogram(fn):
    """Read the histograms and bin edges written by
    save_histogram_for_dataset."""
    with netCDF4.Dataset(fn, 'r') as ds:
        obs_hist = np.array(ds["obs_hist"][:], copy=False)
        fc_hist = np.array(ds["fc_hist"][:], copy=False)
        bin_edges = np.array(ds["bins"][:], copy=False)

    return (obs_hist, fc_hist, bin_edges)
87
+
88
+
89
class ProbabilityMatch:
    """Monotone mapping of forecast values to climatological observed values.

    Built by histogram matching: a forecast value is passed through the
    forecast CDF and then through the inverse observed CDF.
    """
    def __init__(self, obs_hist, fc_hist, bins):
        def normalized_cdf(hist):
            c = hist.cumsum()
            return c / c[-1]

        obs_c = normalized_cdf(obs_hist)
        fc_c = normalized_cdf(fc_hist)

        # inverse observed CDF: cumulative probability -> value
        self.obs_cdf = interp1d(
            np.hstack((0, obs_c)), bins, fill_value='extrapolate')
        # forward forecast CDF: value -> cumulative probability
        self.fc_cdf = interp1d(
            bins, np.hstack((0, fc_c)), fill_value='extrapolate')

    def __call__(self, x):
        """Map forecast value(s) x to probability-matched values."""
        return self.obs_cdf(self.fc_cdf(x))
101
+
102
+
103
def probability_match_timesteps(obs_hist, fc_hist, bins):
    """Build one ProbabilityMatch per lead time (histogram columns)."""
    matchers = []
    for t in range(obs_hist.shape[1]):
        matchers.append(ProbabilityMatch(obs_hist[:,t], fc_hist[:,t], bins))
    return matchers
ldcast/analysis/rank.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import concurrent
import concurrent.futures
import multiprocessing
import os

import netCDF4
import numpy as np

from ..features.io import load_batch
9
+
10
+
11
def ranks_ensemble(
    observation, forecasts,
    noise_scale=1e-6, rng=None
):
    """Rank of each observation within the corresponding forecast ensemble.

    The rank is the number of ensemble members not exceeding the
    observation (0..n_members). A tiny uniform noise is added to both
    fields to randomize ties. Work is split into contiguous chunks,
    one thread per CPU.

    Parameters
    ----------
    observation : array of any shape.
    forecasts : array with the same shape plus a trailing
        ensemble-member axis.
    noise_scale : amplitude of the tie-breaking noise.
    rng : RandomState-like object with a rand() method; defaults to
        the global np.random.

    Returns
    -------
    uint32 array of ranks with the same shape as observation.
    """
    shape = observation.shape
    N = np.prod(shape)
    observation = observation.reshape((N,))
    forecasts = forecasts.reshape((N, forecasts.shape[-1]))
    N_threads = multiprocessing.cpu_count()

    ranks_all = np.zeros_like(observation, dtype=np.uint32)

    if rng is None:
        rng = np.random

    def rank_dist_chunk(k):
        # process the k-th contiguous slice of the flattened arrays
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        obs = observation[i0:i1].astype(np.float64, copy=True)
        fc = forecasts[i0:i1,:].astype(np.float64, copy=True)

        # add a tiny amount of noise to randomize ties
        # (important to add to both obs and fc!)
        obs += (rng.rand(*obs.shape) - 0.5) * noise_scale
        fc += (rng.rand(*fc.shape) - 0.5) * noise_scale

        ranks = np.count_nonzero(obs[...,None] >= fc, axis=-1)
        ranks_all[i0:i1] = ranks

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {}
        for k in range(N_threads):
            futures[executor.submit(rank_dist_chunk, k)] = k
        concurrent.futures.wait(futures)

    return ranks_all.reshape(shape)
51
+
52
+
53
def ranks_multiscale(observation, forecasts):
    """Observation ranks at successively coarser (2x2 pooled) scales.

    Returns a dict mapping the pooling scale (1, 2, 4, ...) to the rank
    field at that scale. Pooling stops once the last spatial axis of
    the observation reaches length 1.
    """
    rank_scales = {}
    obs, fc = observation, forecasts
    scale = 1
    while True:
        rank_scales[scale] = ranks_ensemble(obs, fc)
        scale *= 2
        if obs.shape[-1] == 1:
            break
        # 2x2 average pooling of the observation's spatial axes
        obs = 0.25 * (
            obs[..., ::2, ::2] + obs[..., 1::2, ::2] +
            obs[..., ::2, 1::2] + obs[..., 1::2, 1::2]
        )
        # forecasts carry a trailing ensemble axis, pool axes -3 and -2
        fc = 0.25 * (
            fc[..., ::2, ::2, :] + fc[..., 1::2, ::2, :] +
            fc[..., ::2, 1::2, :] + fc[..., 1::2, 1::2, :]
        )
    return rank_scales
80
+
81
+
82
def rank_distribution(ranks, num_forecasts=32):
    """Normalized histogram of ranks over 0..num_forecasts.

    The histogram is accumulated in contiguous chunks, one thread per
    CPU, then merged and normalized to sum to 1.
    """
    ranks = ranks.ravel()
    N = ranks.shape[0]
    bins = np.arange(-0.5, num_forecasts+0.6)
    N_threads = multiprocessing.cpu_count()

    partial_hists = [None] * N_threads

    def hist_chunk(k):
        # histogram the k-th contiguous slice
        i0 = int(round((k/N_threads) * N))
        i1 = int(round(((k+1) / N_threads) * N))
        partial_hists[k] = np.histogram(ranks[i0:i1], bins=bins)[0]

    with concurrent.futures.ThreadPoolExecutor(N_threads) as executor:
        futures = {
            executor.submit(hist_chunk, k): k for k in range(N_threads)
        }
        concurrent.futures.wait(futures)

    total = sum(partial_hists)
    return total / total.sum()
104
+
105
+
106
def rank_KS(rank_dist, num_forecasts=32):
    """Kolmogorov-Smirnov distance of the rank distribution from uniform."""
    q = rank_dist / rank_dist.sum()
    empirical = np.cumsum(q)
    uniform = np.linspace(0, 1, len(empirical))
    return np.abs(empirical - uniform).max()
112
+
113
+
114
def rank_DKL(rank_dist, num_forecasts=32):
    """KL divergence of the uniform distribution from the rank distribution.

    Zero for a perfectly flat (well-calibrated) rank histogram.
    """
    q = rank_dist / rank_dist.sum()
    p = 1 / len(q)  # uniform reference probability per bin
    return p * np.log(p/q).sum()
119
+
120
+
121
def rank_metric_by_leadtime(ranks, metric=None, num_forecasts=32):
    """Evaluate a rank-distribution metric separately per lead time (axis 2).

    metric defaults to rank_DKL.
    """
    if metric is None:
        metric = rank_DKL

    values = []
    for t in range(ranks.shape[2]):
        dist = rank_distribution(ranks[:,:,t,...])
        values.append(metric(dist, num_forecasts=num_forecasts))
    return np.array(values)
132
+
133
+
134
def rank_metric_by_bin(ranks, values, bins, metric=None, num_forecasts=32):
    """Evaluate a rank-distribution metric in bins of a covariate.

    For each pair of consecutive bin edges, the ranks whose matching
    value falls in [b0, b1) are pooled and the metric (default
    rank_DKL) is evaluated on their distribution.
    """
    if metric is None:
        metric = rank_DKL

    results = []
    for (b0, b1) in zip(bins[:-1], bins[1:]):
        in_bin = (b0 <= values) & (values < b1)
        dist = rank_distribution(ranks[in_bin])
        results.append(metric(dist, num_forecasts=num_forecasts))
    return np.array(results)
145
+
146
+
147
def process_batch(fn, preproc_fc=None):
    """Load one saved batch file and return its multiscale ranks."""
    print(fn)
    batch = load_batch(fn, preproc_fc=preproc_fc)
    return ranks_multiscale(batch[1], batch[2])
151
+
152
+
153
def save_ranks_for_dataset(data_dir, result_fn, preproc_fc=None):
    """Compute multiscale observation ranks for every batch file and save.

    Batch files in data_dir are processed in parallel worker processes;
    the per-file rank fields are concatenated along the sample axis and
    written as one compressed variable per pooling scale.
    """
    files = sorted(os.listdir(data_dir))
    files = [os.path.join(data_dir,fn) for fn in files]

    N_threads = multiprocessing.cpu_count()
    futures = []
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        for fn in files:
            args = (process_batch, fn)
            kwargs = {"preproc_fc": preproc_fc}
            futures.append(executor.submit(*args, **kwargs))

    ranks = [f.result() for f in futures]
    scales = sorted(ranks[0].keys())
    # merge the per-file dicts: one concatenated array per scale
    ranks = {
        s: np.concatenate([r[s] for r in ranks], axis=0)
        for s in scales
    }

    with netCDF4.Dataset(result_fn, 'w') as ds:
        # dims shared by all scales: (sample, channel, lead time)
        ds.createDimension("dim_sample", ranks[1].shape[0])
        ds.createDimension("dim_channel", ranks[1].shape[1])
        ds.createDimension("dim_time_future", ranks[1].shape[2])
        var_params = {"zlib": True, "complevel": 1}

        for s in scales:
            ds.createDimension(f"dim_h_pool{s}x{s}", ranks[s].shape[3])
            ds.createDimension(f"dim_w_pool{s}x{s}", ranks[s].shape[4])
            # NOTE(review): integer ranks are stored as float32 here;
            # presumably for tooling consistency - confirm before changing
            var = ds.createVariable(
                f"ranks_pool{s}x{s}", np.float32,
                (
                    "dim_sample", "dim_channel", "dim_time_future",
                    f"dim_h_pool{s}x{s}", f"dim_w_pool{s}x{s}",
                ),
                # chunk per sample for efficient per-sample access
                chunksizes=(1,)+ranks[s].shape[1:],
                **var_params
            )
            var[:] = ranks[s]
ldcast/features/.sampling.py.swp ADDED
File without changes
ldcast/features/batch.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import os
3
+ import pickle
4
+
5
+ from numba import njit, prange, types
6
+ from numba.typed import Dict
7
+ import numpy as np
8
+ from torch.utils.data import Dataset, IterableDataset
9
+
10
+ from .patches import unpack_patches
11
+ from .sampling import EqualFrequencySampler
12
+
13
+
14
class BatchGenerator:
    """Assembles training batches from patch-indexed raw data.

    The generator maps raw data sources to model variables via the
    per-variable "transform" callables, samples spatio-temporal
    locations with an EqualFrequencySampler (optionally cached to
    disk), and optionally applies random flip/transpose augmentation.
    """
    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # setup indices for retrieving source raw data
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in forecast_raw_vars:
                # forecast sources are stored per lag as "<name>-<lag>"
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                raw_data = raw[raw_name]
                self.setup_index(raw_name, raw_data, sample_shape)

        # wrap the per-lag indices of each forecast variable so they can
        # be addressed under the base name
        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # setup samplers
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # derive the valid time range from the union of all
            # variables' timestep ranges (in units of self.interval)
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0,-1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Create the PatchIndex for one raw data source."""
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) flags for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the given spatial augmentations to the last two axes."""
        if self.augment:
            if transpose:
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Build one batch of (predictors, target).

        Returns (pred_batch, target_batch) where pred_batch is a list
        of (array, relative-time) pairs, one per predictor, and
        target_batch is the target array without time coordinates.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            # get the sample coordinates from the sampler
            samples = self.sampler(batch_size)

        (t0,i0,j0) = samples.T

        if self.augment:
            # the same augmentation is applied to all variables
            augmentations = self.augmentations()

        batch = {}
        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # different timestep from standard (e.g. forecast); round down
            # to times where we have data available
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            t_relative = (t - t0[:,None]) / self.interval_secs

            # read raw data from index
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )

            # transform to model variable
            batch_var = var_data["transform"](*raw_data)

            # add channel dimension if not already present
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            # data augmentation
            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            # bundle with time coordinates
            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0] # no time coordinates for target
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield num batches, or batches indefinitely if num is None."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
183
+
184
+
185
class StreamBatchDataset(IterableDataset):
    """Iterable dataset streaming batches from a BatchGenerator.

    Each epoch yields batches_per_epoch freshly generated batches.
    """
    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        return iter(self.batch_gen.batches(num=self.batches_per_epoch))
194
+
195
+
196
class DeterministicBatchDataset(Dataset):
    """Map-style dataset with a fixed, reproducible set of batches.

    The sample coordinates for all batches are drawn once at
    construction with the given random seed, so repeated epochs and
    repeated __getitem__ calls yield batches built from the same
    sample locations.
    """
    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # reseed the sampler so the drawn coordinates are reproducible
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        return self.batch_gen.batch(samples=self.samples[ind])
213
+
214
+
215
class PatchIndex:
    """Lookup of patch data by (time, patch-row, patch-col) key.

    Wraps the patch arrays with a numba typed Dict so that batches can
    be assembled by the jit-compiled build_batch kernel. All-zero
    patches are stored only as a sentinel; keys absent from the index
    are treated as missing data.
    """
    # sentinel values stored in the index instead of a data row
    IDX_ZERO = -1
    IDX_MISSING = -2

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # spatial size of one assembled sample: box_size patches tiled
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )

        # (time, i, j) -> row in patch_data, or a sentinel
        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        # lazily allocated, reused output buffer (see _alloc_batch)
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        # Reuse the cached buffer when it is large enough; this avoids
        # reallocating on every call with the same batch geometry.
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            del self._batch
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble samples for times t (batch, timestep) starting at
        patch coordinates (i0_all, j0_all) per batch element.

        NOTE: the returned array is a view of a reused internal buffer;
        it is overwritten by the next call.
        """
        batch = self._alloc_batch(*t.shape)

        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        # the buffer may be larger than requested; trim the time axis
        return batch[:,:t.shape[1],...]
272
+
273
+
274
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    """Fill patch_index with (time, i, j) -> row index into patch_data."""
    for row in range(patch_coords.shape[0]):
        key = (
            patch_times[row],
            np.int64(patch_coords[row,0]),
            np.int64(patch_coords[row,1]),
        )
        patch_index[key] = row
281
+
282
+
283
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
        zero_patch_times, idx_zero):
    """Mark all-zero patches in the index with the sentinel idx_zero."""
    for row in range(zero_patch_coords.shape[0]):
        key = (
            zero_patch_times[row],
            np.int64(zero_patch_coords[row,0]),
            np.int64(zero_patch_coords[row,1]),
        )
        patch_index[key] = idx_zero
292
+
293
+
294
# numba can't find these values from PatchIndex
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING
@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    # Assemble batch samples by tiling stored patches into place.
    # For each sample k (processed in parallel) and each requested
    # (time, patch-row, patch-col) key, patch_index yields either a row
    # into patch_data, the IDX_ZERO sentinel (tile filled with
    # zero_value), or nothing -- the IDX_MISSING default, in which case
    # the tile is filled with missing_value.
    for k in prange(t_all.shape[0]):
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]

        for (bt,t) in enumerate(t_all[k,:]):
            for i in range(i0, i1):
                # destination rows of this patch within the sample
                bi0 = (i-i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    bj0 = (j-j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k,bt,bi0:bi1,bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k,bt,bi0:bi1,bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k,bt,bi0:bi1,bj0:bj1] = missing_value
324
+
325
class ForecastPatchIndexWrapper(PatchIndex):
    """Presents per-lag forecast patch indices as a single index.

    Wraps a dict of PatchIndex objects keyed "<base>-<lag_hours>" and,
    for each requested timestep, selects the lag such that all data of
    a sample come from the same forecast run.
    """
    def __init__(self, patch_index):
        self.patch_index = patch_index
        # all keys must share the same base name before the last "-"
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        # forecast issue interval in seconds
        self.forecast_interval = 3600 * forecast_interval_hour

        # need to set these for _alloc_batch to work
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble samples for times t, drawing each timestep from the
        appropriate forecast lag of the wrapped indices."""
        batch = self._alloc_batch(*t.shape)

        # ensure that all data come from the same forecast
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        # lag (in hours) of the forecast valid at each timestep
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour

        # query every lag's index and copy only the timesteps that
        # belong to that lag into the shared output buffer
        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)

        return batch[:,:t.shape[1],...]
368
+
369
+
370
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    """Copy the timesteps selected by mask from from_batch to to_batch."""
    for sample in prange(from_batch.shape[0]):
        for step in range(from_batch.shape[1]):
            if mask[sample, step]:
                to_batch[sample, step, :, :] = from_batch[sample, step, :, :]
ldcast/features/batch.py.save ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import os
3
+ import pickle
4
+
5
+ from numba import njit, prange, types
6
+ from numba.typed import Dict
7
+ import numpy as np
8
+ from torch.utils.data import Dataset, IterableDataset
9
+
10
+ from .patches import unpack_patches
11
+ from .sampling import EqualFrequencySampler
12
+
13
+
14
class BatchGenerator:
    """Generate training batches from patch-indexed raw data.

    Combines one PatchIndex per raw data source with a sampler that picks
    (time, i, j) sample coordinates, then applies the per-variable
    transforms defined in `variables`.

    Parameters:
        variables: dict of variable definitions; each entry provides
            "sources" (raw data names), "timesteps", optionally
            "timestep_secs", and a "transform" callable mapping raw
            arrays to the model variable.
        raw: dict of raw patch data (see patches.load_all_patches).
        predictors / target: variable names used as inputs / output.
        primary_var: variable whose first source drives the sampler.
        forecast_raw_vars: raw sources stored as "<name>-<lag>" forecasts.
        sampler_file: optional pickle cache for the (slow to build) sampler.
        augment: enable random transpose/flip augmentation.
    """

    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1, 2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4, 4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # set up one patch index per raw data source; forecast sources
        # get one index per lag, plus a wrapper combining the lags
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in forecast_raw_vars:
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base + "-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                self.setup_index(raw_name, raw[raw_name], sample_shape)

        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k, v) in self.patch_index.items()
                if k.startswith(raw_name + "-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # set up the sampler, cached on disk because it is slow to build
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # determine the time range covered by all variables
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0, -1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0, timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1, timesteps[-1])
            time_range_valid = (t0, t1 + 1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Create the PatchIndex for one raw data source."""
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) flags for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the given spatial augmentations to the last two axes.

        Always returns a copy so downstream code can mutate it safely.
        """
        if self.augment:
            if transpose:
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Build one batch.

        Returns (pred_batch, target_batch) where pred_batch is a list of
        (data, relative_time) pairs, one per predictor, and target_batch
        is the target data array without time coordinates.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            # get the sample coordinates from the sampler
            samples = self.sampler(batch_size)
        # fix: removed leftover debug print(samples)
        (t0, i0, j0) = samples.T

        if self.augment:
            augmentations = self.augmentations()

        batch = {}
        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # different timestep from standard (e.g. forecast); round down
            # to times where we have data available
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:, None] + ts_secs * var_data["timesteps"][None, :]
            t_relative = (t - t0[:, None]) / self.interval_secs

            # read raw data from the per-source indices
            raw_data = (
                self.patch_index[raw_name](t, i0, j0)
                for raw_name in var_data["sources"]
            )

            # transform to model variable
            batch_var = var_data["transform"](*raw_data)

            # add channel dimension if not already present
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            # data augmentation (same flags for all variables of the batch)
            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            # bundle with time coordinates
            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0]  # no time coordinates for target
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield `num` batches, or batches indefinitely if num is None."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
186
+
187
+
188
class StreamBatchDataset(IterableDataset):
    """Iterable-style dataset that streams freshly sampled batches from a
    BatchGenerator; one epoch is `batches_per_epoch` batches long."""

    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        # delegate batch production entirely to the generator
        yield from self.batch_gen.batches(num=self.batches_per_epoch)
197
+
198
+
199
class DeterministicBatchDataset(Dataset):
    """Map-style dataset with sample coordinates fixed at construction,
    so that the same index always yields the same batch (e.g. for
    validation/testing). Reseeds the sampler RNG for reproducibility."""

    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # reseed so the drawn sample coordinates are reproducible
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        # draw all sample coordinates up front
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        # fix: removed leftover debug print(self.samples[ind])
        return self.batch_gen.batch(samples=self.samples[ind])
216
+
217
+
218
class PatchIndex:
    """Looks up patch data by (time, i, j) and assembles sample arrays.

    Patches absent from the index are filled with `missing_value`;
    patches recorded as all-zero are filled with `zero_value`.
    """

    IDX_ZERO = -1
    IDX_MISSING = -2

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # full sample size in pixels: patches per box * pixels per patch
        self.sample_shape = (
            self.patch_data.shape[1] * box_size[0],
            self.patch_data.shape[2] * box_size[1]
        )

        # numba-typed dict: (time, i, j) -> index into patch_data, or one
        # of the special IDX_ZERO / IDX_MISSING markers
        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        self._batch = None  # lazily allocated, reused between calls

    def _alloc_batch(self, batch_size, num_timesteps):
        """Return a scratch batch array at least the requested size,
        reallocating only when it needs to grow."""
        too_small = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if too_small:
            del self._batch
            self._batch = np.zeros(
                (batch_size, num_timesteps) + self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble samples for times t (shape (batch, time), seconds)
        starting at patch offsets (i0_all, j0_all)."""
        batch = self._alloc_batch(*t.shape)

        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        # the scratch array may be larger than requested: trim time axis
        return batch[:, :t.shape[1], ...]
275
+
276
+
277
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    """Register each stored patch under its (time, i, j) key."""
    for idx in range(patch_coords.shape[0]):
        key = (
            patch_times[idx],
            np.int64(patch_coords[idx, 0]),
            np.int64(patch_coords[idx, 1])
        )
        patch_index[key] = idx
284
+
285
+
286
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
        zero_patch_times, idx_zero):
    """Register all-zero patches under the special idx_zero marker
    instead of a data index."""
    for idx in range(zero_patch_coords.shape[0]):
        key = (
            zero_patch_times[idx],
            np.int64(zero_patch_coords[idx, 0]),
            np.int64(zero_patch_coords[idx, 1])
        )
        patch_index[key] = idx_zero
295
+
296
+
297
# numba can't find these values from PatchIndex
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING


@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    """Fill `batch` by copying indexed patches into place.

    Patches marked IDX_ZERO become zero_value; patches absent from the
    index (IDX_MISSING) become missing_value.
    """
    for k in prange(t_all.shape[0]):
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]

        for (bt, t) in enumerate(t_all[k, :]):
            for i in range(i0, i1):
                bi0 = (i - i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    ind = int(patch_index.get((t, i, j), IDX_MISSING))
                    bj0 = (j - j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k, bt, bi0:bi1, bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k, bt, bi0:bi1, bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k, bt, bi0:bi1, bj0:bj1] = missing_value
326
+
327
+
328
class ForecastPatchIndexWrapper(PatchIndex):
    """Presents several per-lag PatchIndex objects ("<base>-<lag>") as a
    single index, selecting for each timestep the lag that keeps all
    timesteps of a sample within the same forecast run."""

    def __init__(self, patch_index):
        self.patch_index = patch_index
        base_names = {"-".join(name.split("-")[:-1]) for name in patch_index}
        if len(base_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        (self.raw_name,) = base_names
        lag_list = [int(name.split("-")[-1]) for name in patch_index]
        self.lags_hour = set(lag_list)
        spacing = np.diff(sorted(lag_list))
        if len(set(spacing)) != 1:
            raise ValueError("Lags must be evenly spaced")
        spacing = spacing[0]
        if 24 % spacing:
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = spacing
        self.forecast_interval = 3600 * spacing  # in seconds

        # the inherited _alloc_batch relies on these attributes
        self._batch = None
        first_key = next(iter(self.patch_index))
        self.sample_shape = self.patch_index[first_key].sample_shape
        self.patch_data = self.patch_index[first_key].patch_data

    def __call__(self, t, i0, j0):
        batch = self._alloc_batch(*t.shape)

        # offset of each timestep from the start of its forecast cycle,
        # so all timesteps of a sample come from the same forecast
        t_first = t[:, :1]
        offset_in_cycle = (t_first % self.forecast_interval) + (t - t_first)
        lag_of_step = (offset_in_cycle // self.forecast_interval) * \
            self.forecast_interval_hour

        for lag in self.lags_hour:
            lag_batch = self.patch_index[f"{self.raw_name}-{lag}"](t, i0, j0)
            copy_masked_times(lag_batch, batch, lag_of_step == lag)

        return batch[:, :t.shape[1], ...]
371
+
372
+
373
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    """Copy (sample, time) slices of from_batch into to_batch where mask
    is True; slices where mask is False are left untouched."""
    for sample in prange(from_batch.shape[0]):
        for step in range(from_batch.shape[1]):
            if mask[sample, step]:
                to_batch[sample, step, :, :] = from_batch[sample, step, :, :]
ldcast/features/io.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import netCDF4
4
+ import numpy as np
5
+
6
+
7
def convert_var_for_saving(
    x, fill_value=0.02, min_value=0.05, max_value=118.428,
    mean=-0.051, std=0.528
):
    """Quantize a standardized log10 rain-rate field to uint16.

    Code 0 marks below-minimum (dry) pixels; codes 1..65534 span
    log10(min_value)..log10(max_value) linearly. (fill_value is accepted
    for interface symmetry with the decoder but is unused here.)
    """
    denorm = x * std + mean  # undo standardization -> log10 rain rate
    lo = np.log10(min_value)
    hi = np.log10(max_value)
    wet = (denorm >= lo)
    scaled = (denorm[wet].clip(max=hi) - lo) / (hi - lo)
    out = np.zeros_like(x, dtype=np.uint16)
    out[wet] = (scaled * 65533).round().astype(np.uint16) + 1
    return out
20
+
21
+
22
def decode_saved_var_to_rainrate(
    x, fill_value=0.02, min_value=0.05, threshold=0.1, max_value=118.428,
    mean=-0.051, std=0.528, log=False, preproc=None
):
    """Invert convert_var_for_saving, returning rain rates.

    Codes >= 1 are mapped back through the logarithmic scale; rates below
    `threshold` are set to 0 (or to log10(fill_value) when log=True).
    `preproc`, if given, is applied per index of axis 2 (the time axis in
    the save_batch layout). mean/std are accepted for interface symmetry
    with the encoder but are unused here.
    """
    wet = (x >= 1)
    lo = np.log10(min_value)
    hi = np.log10(max_value)
    log_rate = lo + (x[wet].astype(np.float32) - 1) * ((hi - lo) / 65533)
    y = np.zeros_like(x, dtype=np.float32)
    y[wet] = 10 ** log_rate

    if preproc is not None:
        y = np.stack(
            [preproc[ti](y[:, :, ti, ...]) for ti in range(y.shape[2])],
            axis=2
        )

    if log:
        y[y < threshold] = fill_value
        y = np.log10(y)
    else:
        y[y < threshold] = 0.0

    return y
46
+
47
+
48
def save_batch(x, y, y_pred, batch_index, fn_template, out_dir, out_fn=None):
    """Quantize and save one batch (past observations, future
    observations, ensemble forecasts) to a compressed NetCDF file.

    x may arrive wrapped in nested lists/tuples (one entry per predictor);
    only the first array is kept.
    """
    while isinstance(x, (list, tuple)):
        x = x[0]

    x = convert_var_for_saving(np.array(x, copy=False))
    y = convert_var_for_saving(np.array(y, copy=False))
    y_pred = convert_var_for_saving(np.array(y_pred, copy=False))

    if out_fn is None:
        out_fn = fn_template.format(batch_index=batch_index)
    out_fn = os.path.join(out_dir, out_fn)

    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_sample", y.shape[0])
        ds.createDimension("dim_channel", y.shape[1])
        ds.createDimension("dim_time_past", x.shape[2])
        ds.createDimension("dim_time_future", y.shape[2])
        ds.createDimension("dim_h", y.shape[3])
        ds.createDimension("dim_w", y.shape[4])
        ds.createDimension("dim_member", y_pred.shape[5])
        var_params = {"zlib": True, "complevel": 1}

        var_fc = ds.createVariable(
            "forecasts", y_pred.dtype,
            (
                "dim_sample", "dim_channel",
                "dim_time_future", "dim_h", "dim_w", "dim_member"
            ),
            **var_params
        )
        var_fc[:] = y_pred

        var_obs_past = ds.createVariable(
            "past_observations", x.dtype,
            ("dim_sample", "dim_channel", "dim_time_past", "dim_h", "dim_w"),
            **var_params
        )
        var_obs_past[:] = x

        var_obs_future = ds.createVariable(
            "future_observations", y.dtype,
            ("dim_sample", "dim_channel", "dim_time_future", "dim_h", "dim_w"),
            **var_params
        )
        var_obs_future[:] = y
93
+
94
+
95
def load_batch(fn, decode=True, preproc_fc=None, **kwargs):
    """Load a batch written by save_batch; optionally decode the uint16
    codes back to rain rates (kwargs forwarded to the decoder)."""
    with netCDF4.Dataset(fn, 'r') as ds:
        y_pred = np.array(ds["forecasts"][:], copy=False)
        x = np.array(ds["past_observations"][:], copy=False)
        y = np.array(ds["future_observations"][:], copy=False)

    if decode:
        x = decode_saved_var_to_rainrate(x, **kwargs)
        y = decode_saved_var_to_rainrate(y, **kwargs)
        # forecasts may need extra per-timestep preprocessing
        y_pred = decode_saved_var_to_rainrate(
            y_pred, preproc=preproc_fc, **kwargs
        )

    return (x, y, y_pred)
109
+
110
+
111
def load_all_observations(
    ensemble_dir, decode=True, preproc_fc=None,
    timeframe='future', **kwargs
):
    """Concatenate the '<timeframe>_observations' variable from every
    file in ensemble_dir along the sample axis.

    NOTE(review): decoding is always applied here regardless of the
    `decode` flag, matching the original behavior; preproc_fc is unused.
    """
    var_name = f"{timeframe}_observations"
    obs = []
    for fn in sorted(os.listdir(ensemble_dir)):
        with netCDF4.Dataset(os.path.join(ensemble_dir, fn), 'r') as ds:
            data = np.array(ds[var_name][:], copy=False)
            obs.append(decode_saved_var_to_rainrate(data, **kwargs))

    return np.concatenate(obs, axis=0)
ldcast/features/patches.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import os
3
+
4
+ import dask
5
+ import netCDF4
6
+ import numpy as np
7
+
8
+ from .utils import average_pool
9
+
10
+
11
def patch_locations(
    time_range,
    patch_box,
    patch_shape=(32,32),
    interval=timedelta(minutes=5),
    epoch=(1970,1,1)
):
    """Enumerate all patch coordinates for each time in
    [time_range[0], time_range[1]) at the given interval.

    Returns {time: (N,2) array of (pi, pj) patch indices} covering
    patch_box = ((i0, i1), (j0, j1)). patch_shape and epoch are accepted
    for interface compatibility but not used here.
    """
    (i_range, j_range) = patch_box
    coords = np.array([
        (pi, pj)
        for pi in range(i_range[0], i_range[1])
        for pj in range(j_range[0], j_range[1])
    ])

    patches = {}
    t = time_range[0]
    while t < time_range[1]:
        patches[t] = coords.copy()
        t += interval

    return patches
29
+
30
+
31
def save_patches_radar(
    patches, archive_path, out_dir,
    variables=("RZC", "CPCH"),
    suffix="2020",
    **kwargs
):
    """Extract and save patches for MeteoSwiss radar products.

    RZC/CPCH are stored as raw digital numbers where code 1 means zero
    rain, hence zero_value = 1 and the (> 1) nonzero test.
    """
    from ..datasets import mchradar

    reader = mchradar.MCHRadarReader(
        archive_path=archive_path,
        variables=variables,
        phys_values=False
    )

    nonzero_count_func = {
        "RZC": lambda x: np.count_nonzero(x > 1),
        "CPCH": lambda x: np.count_nonzero(x > 1)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["RZC"] = 1
    zero_value["CPCH"] = 1

    save_patches_all(
        reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
61
+
62
+
63
def save_patches_dwdradar(
    patches, archive_path, out_dir,
    variables=("RV",),
    suffix="2022",
    patch_shape=(32,32),
    **kwargs
):
    """Extract and save patches for the DWD RV radar product, keeping
    only patch locations whose data are entirely finite."""
    from ..datasets import dwdradar

    reader = dwdradar.DWDRadarReader(
        archive_path=archive_path,
        variables=variables
    )

    # filter out patches containing non-finite values
    patches_flt = {}
    for t in sorted(patches):
        if (t.hour == 0) and (t.minute == 0):
            print(t)  # progress: one line per day

        try:
            data = reader.variable_for_time(t, "RV")
        except FileNotFoundError:
            continue  # skip missing timesteps

        valid_locs = []
        for (pi, pj) in patches[t]:
            i0 = pi * patch_shape[0]
            j0 = pj * patch_shape[1]
            patch = data[i0:i0+patch_shape[0], j0:j0+patch_shape[1]]
            if np.isfinite(patch).all():
                valid_locs.append((pi, pj))

        if valid_locs:
            patches_flt[t] = np.array(valid_locs)

    print(len(patches), len(patches_flt))
    patches = patches_flt

    nonzero_count_func = {"RV": np.count_nonzero}
    zero_value = {v: 0 for v in variables}

    save_patches_all(
        reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
113
+
114
+
115
def save_patches_ifs(
    patches, archive_path, out_dir,
    variables=(
        # other available fields: "rate-tp", "rate-cp", "t2m", "cape",
        # "tclw", "tcwv", "u", "v", "rate-tpe"
        "cin",
    ),
    suffix="2020",
    lags=(0,12)
):
    """Extract and save patches for IFS NWP forecast fields, producing
    one patch set per (variable, forecast lag) pair named "<var>-<lag>".
    Fields are average-pooled by a factor of 8 before saving."""
    from ..datasets import ifsnwp
    from .. import projection

    proj = projection.GridProjection(projection.ccs4_swiss_grid_area)
    reader = ifsnwp.IFSNWPReader(
        proj,
        archive_path=archive_path,
        variables=variables,
        lags=lags,
    )

    # IFS data are hourly: keep only on-the-hour patch times
    ifs_patches = {
        dt: pset for (dt, pset) in patches.items()
        if (dt.minute == dt.second == dt.microsecond == 0)
    }

    variables_with_lag = []
    for lag in reader.lags:
        variables_with_lag.extend(f"{v}-{lag}" for v in variables)

    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    base_count_func = {
        "rate-tp": count_positive,
        "rate-cp": count_positive,
        "t2m": all_nonzero,
        "cape": count_positive,
        "cin": count_positive,
        "tclw": count_positive,
        "tcwv": count_positive,
        "u": all_nonzero,
        "v": all_nonzero,
        "rate-tpe": count_positive,
    }
    # map each lagged name back to its base variable's counting rule
    nonzero_count_func = {
        v: base_count_func[v.rsplit("-", 1)[0]]
        for v in variables_with_lag
    }
    # CIN comes with NaN where undefined; replace those with 0
    postproc = {
        f"cin-{lag}": lambda x: np.nan_to_num(x, nan=0.0, copy=False)
        for lag in lags
    }
    zero_value = {v: 0 for v in variables_with_lag}
    avg_pool = lambda x: average_pool(x, factor=8, missing=np.nan)
    pool = {v: avg_pool for v in variables_with_lag}

    save_patches_all(reader, ifs_patches, variables_with_lag,
        nonzero_count_func, zero_value, out_dir, suffix, pool=pool,
        postproc=postproc)
176
+
177
+
178
def save_patches_cosmo(patches, archive_path, out_dir, suffix="2020"):
    """Extract and save patches for COSMO NWP fields.

    COSMO data are hourly, so each patch set is attached to both the
    full hours surrounding its time.
    """
    from ..datasets import cosmonwp

    reader = cosmonwp.COSMOCCS4Reader(
        archive_path=archive_path, cache_size=6000)

    # attach each patch set to the full hours bracketing its time
    cosmo_patches = {}
    for (dt, pset) in patches.items():
        dt0 = datetime(dt.year, dt.month, dt.day, dt.hour)
        dt1 = dt0 + timedelta(hours=1)
        if dt0 not in cosmo_patches:
            cosmo_patches[dt0] = set()
        if dt1 not in cosmo_patches:
            cosmo_patches[dt1] = set()
        cosmo_patches[dt0].update(pset)
        cosmo_patches[dt1].update(pset)

    variables = [
        "CAPE_MU", "CIN_MU", "SLI",
        "HZEROCL", "LCL_ML", "MCONV", "OMEGA",
        "T_2M", "T_SO", "SOILTYP"
    ]
    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "CAPE_MU": count_positive,
        "CIN_MU": count_positive,
        "SLI": all_nonzero,
        "HZEROCL": count_positive,
        "LCL_ML": count_positive,
        "MCONV": all_nonzero,
        "OMEGA": all_nonzero,
        "T_2M": all_nonzero,
        "T_SO": all_nonzero,
        "SOILTYP": lambda x: np.count_nonzero(x != 5)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["SOILTYP"] = 5  # soil type 5 is treated as the background

    # fix: the original passed pool=pool with `pool` undefined, raising
    # NameError on every call; COSMO needs no pooling, so use the default
    save_patches_all(reader, cosmo_patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix)
220
+
221
+
222
def save_patches_all(
    reader, patches, variables, nonzero_count_func, zero_value,
    out_dir, suffix, epoch=datetime(1970,1,1), postproc={}, scale=None,
    pool={}, source_vars={}, parallel=False, min_nonzeros_to_include=1
):
    """Extract patches for each variable with get_patches and write them
    to `<out_dir>/<var>/patches_<var>_<suffix>.nc`, one file per variable
    (optionally in parallel via dask threads)."""

    def save_var(var_name):
        # the reader may use a different name than the output variable
        src_name = source_vars.get(var_name, var_name)

        (patch_data, patch_coords, patch_times,
         zero_patch_coords, zero_patch_times) = get_patches(
            reader, src_name, patches,
            nonzero_count_func=nonzero_count_func[var_name],
            postproc=postproc.get(var_name),
            pool=pool.get(var_name),
            # fix: this threshold was accepted but never forwarded, so
            # callers passing e.g. 5 silently got the default of 1
            min_nonzeros_to_include=min_nonzeros_to_include
        )
        # per-variable scale, if the reader provides one
        try:
            time = epoch + timedelta(seconds=int(patch_times[0]))
            var_scale = reader.get_scale(time, var_name)
        except (AttributeError, KeyError):
            var_scale = None if (scale is None) else scale[var_name]

        # output names use dashes instead of underscores; note the
        # zero_value lookup uses the dashed name, as in the original
        var_name = var_name.replace("_", "-")
        out_fn = f"patches_{var_name}_{suffix}.nc"
        out_path = os.path.join(out_dir, var_name)
        os.makedirs(out_path, exist_ok=True)
        out_fn = os.path.join(out_path, out_fn)

        save_patches(
            patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times, out_fn,
            zero_value=zero_value[var_name], scale=var_scale
        )

    if parallel:
        save_var = dask.delayed(save_var)

    jobs = [save_var(v) for v in variables]
    if parallel:
        dask.compute(jobs, scheduler='threads')
263
+
264
+
265
def get_patches(
    reader, variable, patches,
    patch_shape=(32,32), nonzero_count_func=None,
    epoch=datetime(1970,1,1), postproc=None,
    pool=None, min_nonzeros_to_include=1
):
    """Cut `patch_shape` patches out of the reader's fields at the
    requested times/locations.

    Patches with fewer than `min_nonzeros_to_include` nonzero pixels
    (per nonzero_count_func) are recorded by coordinates/times only;
    the rest are stored with their data, optionally pooled.

    Returns (patch_data, patch_coords, patch_times,
             zero_patch_coords, zero_patch_times).
    """
    num_patches = sum(len(p) for p in patches.values())
    patch_data = []
    patch_coords = []
    patch_times = []
    zero_patch_coords = []
    zero_patch_times = []

    # temporarily read raw (non-physical) values if the reader supports it
    if hasattr(reader, "phys_values"):
        phys_values = reader.phys_values

    k = 0
    try:
        if hasattr(reader, "phys_values"):
            reader.phys_values = False
        for (t, p_coord) in patches.items():
            try:
                data = reader.variable_for_time(t, variable)
            except (ValueError, FileNotFoundError, KeyError, OSError):
                continue  # best effort: skip unreadable timesteps

            if postproc is not None:
                data = postproc(data)

            time_sec = np.int64((t - epoch).total_seconds())
            for (pi, pj) in p_coord:
                if k % 100000 == 0:
                    print("{}: {}/{}".format(t, k, num_patches))
                patch_box = data[
                    pi*patch_shape[0]:(pi+1)*patch_shape[0],
                    pj*patch_shape[1]:(pj+1)*patch_shape[1],
                ].copy()
                # sparse (essentially zero) patches get coordinates only
                is_sparse = (nonzero_count_func is not None) and \
                    (nonzero_count_func(patch_box) < min_nonzeros_to_include)
                if is_sparse:
                    zero_patch_coords.append((pi, pj))
                    zero_patch_times.append(time_sec)
                else:
                    if pool is not None:
                        patch_box = pool(patch_box)
                    patch_data.append(patch_box)
                    patch_coords.append((pi, pj))
                    patch_times.append(time_sec)
                k += 1

    finally:
        # always restore the reader's original setting
        if hasattr(reader, "phys_values"):
            reader.phys_values = phys_values

    if zero_patch_coords:
        zero_patch_coords = \
            np.stack(zero_patch_coords, axis=0).astype(np.uint16)
        zero_patch_times = np.stack(zero_patch_times, axis=0)
    else:
        zero_patch_coords = np.zeros((0, 2), dtype=np.uint16)
        zero_patch_times = np.zeros((0,), dtype=np.int64)
    # NOTE(review): np.stack raises if no patch passed the filter,
    # matching the original behavior
    patch_data = np.stack(patch_data, axis=0)
    patch_coords = np.stack(patch_coords, axis=0).astype(np.uint16)
    patch_times = np.stack(patch_times, axis=0)

    return (patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times)
331
+
332
+
333
def save_patches(patch_data, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times, out_fn, zero_value=0, scale=None):
    """Write one variable's patches to a compressed NetCDF file.

    zero_value is stored as a file attribute; scale, if given, is stored
    as a 1-D variable.
    """
    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_patch", patch_data.shape[0])
        ds.createDimension("dim_zero_patch", zero_patch_coords.shape[0])
        ds.createDimension("dim_coord", 2)
        ds.createDimension("dim_height", patch_data.shape[1])
        ds.createDimension("dim_width", patch_data.shape[2])

        var_args = {"zlib": True, "complevel": 1}

        # chunk along the patch dimension for fast per-patch access
        chunksizes = (
            min(2**10, patch_data.shape[0]),
            patch_data.shape[1], patch_data.shape[2]
        )
        var_patch = ds.createVariable("patches", patch_data.dtype,
            ("dim_patch", "dim_height", "dim_width"),
            chunksizes=chunksizes, **var_args)
        var_patch[:] = patch_data

        var_coord = ds.createVariable("patch_coords", patch_coords.dtype,
            ("dim_patch", "dim_coord"), **var_args)
        var_coord[:] = patch_coords

        var_time = ds.createVariable("patch_times", patch_times.dtype,
            ("dim_patch",), **var_args)
        var_time[:] = patch_times

        var_zcoord = ds.createVariable(
            "zero_patch_coords", zero_patch_coords.dtype,
            ("dim_zero_patch", "dim_coord"), **var_args)
        var_zcoord[:] = zero_patch_coords

        var_ztime = ds.createVariable(
            "zero_patch_times", zero_patch_times.dtype,
            ("dim_zero_patch",), **var_args)
        var_ztime[:] = zero_patch_times

        ds.zero_value = zero_value

        if scale is not None:
            ds.createDimension("dim_scale", len(scale))
            var_scale = ds.createVariable("scale", scale.dtype,
                ("dim_scale",), **var_args)
            var_scale[:] = scale
372
+
373
+
374
def load_patches(fn, in_memory=True):
    """Load one patch file written by save_patches.

    With in_memory=True the whole file is read into RAM first and
    parsed from there (faster on networked filesystems).
    """
    if in_memory:
        with open(fn, 'rb') as f:
            ds_raw = f.read()
        fn = None
    else:
        ds_raw = None

    with netCDF4.Dataset(fn, 'r', memory=ds_raw) as ds:
        keys = (
            "patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times"
        )
        patch_data = {k: np.array(ds[k]) for k in keys}
        patch_data["zero_value"] = ds.zero_value
        if "scale" in ds.variables:
            patch_data["scale"] = np.array(ds["scale"])

    return patch_data
395
+
396
+
397
def load_all_patches(patch_dir, var):
    """Load and concatenate all patch files for `var` found in patch_dir
    (files named patches_<var>_<suffix>.nc), reading in parallel with
    dask processes."""
    jobs = []
    for fn in os.listdir(patch_dir):
        if fn.split("_")[1] == var:
            path = os.path.join(patch_dir, fn)
            jobs.append(dask.delayed(load_patches)(path))

    file_data = dask.compute(jobs, scheduler="processes")[0]
    patch_data = {}
    for key in ("patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times"):
        patch_data[key] = np.concatenate(
            [fd[key] for fd in file_data], axis=0
        )
    # zero_value/scale are assumed identical across files; take the first
    patch_data["zero_value"] = file_data[0]["zero_value"]
    if "scale" in file_data[0]:
        patch_data["scale"] = file_data[0]["scale"]

    return patch_data
420
+
421
+
422
def unpack_patches(patch_data):
    """Unpack a patch dict (from load_patches/load_all_patches) into the
    positional arguments expected by PatchIndex."""
    keys = ("patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times")
    return tuple(patch_data[k] for k in keys)
ldcast/features/patches.py.save ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import os
3
+
4
+ import dask
5
+ import netCDF4
6
+ import numpy as np
7
+
8
+ from .utils import average_pool
9
+
10
+
11
def patch_locations(
    time_range,
    patch_box,
    patch_shape=(32,32),
    interval=timedelta(minutes=5),
    epoch=(1970,1,1)
):
    """Enumerate every patch (row, col) index for each timestep.

    Returns a dict mapping each time in ``[time_range[0], time_range[1])``
    (stepping by ``interval``) to an array of all (pi, pj) pairs inside
    ``patch_box``.  ``patch_shape`` and ``epoch`` are accepted for
    interface compatibility but not used here.
    """
    (i_range, j_range) = patch_box
    grid = np.array([
        (pi, pj)
        for pi in range(i_range[0], i_range[1])
        for pj in range(j_range[0], j_range[1])
    ])

    locations = {}
    t = time_range[0]
    while t < time_range[1]:
        locations[t] = grid.copy()
        t += interval

    return locations
29
+
30
+
31
def save_patches_radar(
    patches, archive_path, out_dir,
    variables=("RZC", "CPCH"),
    suffix="2020",
    **kwargs
):
    """Extract and save patches from the MeteoSwiss radar archive.

    Values are read raw (``phys_values=False``); raw value 1 is treated
    as the zero/background level (``zero_value=1``), hence the ``> 1``
    nonzero tests.  Patches with fewer than 5 nonzero pixels are stored
    as zero patches.
    """
    from ..datasets import mchradar

    source_vars = {}
    mchradar_reader = mchradar.MCHRadarReader(
        archive_path=archive_path,
        variables=variables,
        phys_values=False
    )

    # NOTE(review): unused here; EZC-style counting kept for reference
    ezc_nonzero_count_func = lambda x: np.count_nonzero((x >= 1) & (x<251))
    nonzero_count_func = {
        "RZC": lambda x: np.count_nonzero(x > 1),
        "CPCH": lambda x: np.count_nonzero(x > 1)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["RZC"] = 1
    zero_value["CPCH"] = 1

    save_patches_all(
        mchradar_reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars=source_vars, min_nonzeros_to_include=5,
        **kwargs
    )
61
+
62
+
63
def save_patches_dwdradar(
    patches, archive_path, out_dir,
    variables=("RV",),
    suffix="2022",
    patch_shape=(32,32),
    **kwargs
):
    """Extract and save patches from the DWD radar (RV product) archive.

    Before extraction, the requested patch locations are filtered down
    to those whose data are fully finite (no NaN/inf inside the patch),
    since the DWD composite has out-of-coverage regions.
    """
    from ..datasets import dwdradar

    source_vars = {}
    dwdradar_reader = dwdradar.DWDRadarReader(
        archive_path=archive_path,
        variables=variables
    )

    # keep only patches with fully finite data
    patches_flt = {}
    for t in sorted(patches):
        if (t.hour==0) and (t.minute==0):
            print(t)  # progress: one line per day

        try:
            data = dwdradar_reader.variable_for_time(t, "RV")
        except FileNotFoundError:
            continue  # timestep missing from archive

        patch_locs_time = []
        for (pi,pj) in patches[t]:
            i0 = pi * patch_shape[0]
            i1 = i0 + patch_shape[0]
            j0 = pj * patch_shape[1]
            j1 = j0 + patch_shape[1]
            patch = data[i0:i1,j0:j1]
            if np.isfinite(patch).all():
                patch_locs_time.append((pi,pj))

        if patch_locs_time:
            patches_flt[t] = np.array(patch_locs_time)

    print(len(patches),len(patches_flt))  # how many timesteps survived
    patches = patches_flt

    nonzero_count_func = {"RV": np.count_nonzero}
    zero_value = {v: 0 for v in variables}

    save_patches_all(
        dwdradar_reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars=source_vars, min_nonzeros_to_include=5,
        **kwargs
    )
113
+
114
+
115
def save_patches_ifs(
    patches, archive_path, out_dir,
    variables=(
        #"rate-tp", "rate-cp", "t2m", "cape", "cin",
        #"tclw", "tcwv", #, "rate-tpe"
        #"u", "v"
        "cin",
    ),
    suffix="2020",
    lags=(0,12)
):
    """Extract and save patches of IFS NWP fields on the CCS4 grid.

    Each variable is saved once per forecast lag as ``<var>-<lag>``.
    Fields are average-pooled by a factor of 8 to match the coarser
    NWP resolution relative to the radar grid.
    """
    from ..datasets import ifsnwp
    from .. import projection

    proj = projection.GridProjection(projection.ccs4_swiss_grid_area)
    ifsnwp_reader = ifsnwp.IFSNWPReader(
        proj,
        archive_path=archive_path,
        variables=variables,
        lags=lags,
    )

    # we only get data for every hour, so modify patches
    ifs_patches = {
        dt: pset for (dt, pset) in patches.items()
        if (dt.minute == dt.second == dt.microsecond == 0)
    }

    variables_with_lag = []
    for lag in ifsnwp_reader.lags:
        variables_with_lag.extend(f"{v}-{lag}" for v in variables)

    # per-variable definition of "nonzero": positive values for
    # threshold-like fields, everything for continuous fields
    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "rate-tp": count_positive,
        "rate-cp": count_positive,
        "t2m": all_nonzero,
        "cape": count_positive,
        "cin": count_positive,
        "tclw": count_positive,
        "tcwv": count_positive,
        "u": all_nonzero,
        "v": all_nonzero,
        "rate-tpe": count_positive,
    }
    # map the lagged names back to their base variable's count function
    nonzero_count_func = {
        v: nonzero_count_func[v.rsplit("-", 1)[0]]
        for v in variables_with_lag
    }
    # CIN is NaN where undefined; treat those pixels as 0
    postproc = {
        f"cin-{lag}": lambda x: np.nan_to_num(x, nan=0.0, copy=False)
        for lag in lags
    }
    zero_value = {v: 0 for v in variables_with_lag}
    avg_pool = lambda x: average_pool(x, factor=8, missing=np.nan)
    pool = {v: avg_pool for v in variables_with_lag}

    save_patches_all(ifsnwp_reader, ifs_patches, variables_with_lag,
        nonzero_count_func, zero_value, out_dir, suffix, pool=pool,
        postproc=postproc)
176
+
177
+
178
def save_patches_cosmo(patches, archive_path, out_dir, suffix="2020"):
    """Extract and save patches of COSMO NWP fields.

    COSMO output is hourly, so each requested time is mapped to the two
    surrounding full hours and the patch coordinates are merged per hour.

    Parameters
    ----------
    patches : dict mapping datetime -> iterable of (pi, pj) patch coords
    archive_path : path of the COSMO CCS4 archive
    out_dir : output directory for the patch files
    suffix : file-name suffix (typically the year)
    """
    from ..datasets import cosmonwp

    cosmonwp_reader = cosmonwp.COSMOCCS4Reader(
        archive_path=archive_path, cache_size=6000)

    # we only get data for every hour, so modify patches
    cosmo_patches = {}
    for (dt, pset) in patches.items():
        dt0 = datetime(dt.year, dt.month, dt.day, dt.hour)
        dt1 = dt0 + timedelta(hours=1)
        if dt0 not in cosmo_patches:
            cosmo_patches[dt0] = set()
        if dt1 not in cosmo_patches:
            cosmo_patches[dt1] = set()
        # convert coordinates to tuples: numpy rows are unhashable and
        # would raise TypeError when added to a set
        coords = [tuple(c) for c in pset]
        cosmo_patches[dt0].update(coords)
        cosmo_patches[dt1].update(coords)

    variables = [
        "CAPE_MU", "CIN_MU", "SLI",
        "HZEROCL", "LCL_ML", "MCONV", "OMEGA",
        "T_2M", "T_SO", "SOILTYP"
    ]
    count_positive = lambda x: np.count_nonzero(x>0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "CAPE_MU": count_positive,
        "CIN_MU": count_positive,
        "SLI": all_nonzero,
        "HZEROCL": count_positive,
        "LCL_ML": count_positive,
        "MCONV": all_nonzero,
        "OMEGA": all_nonzero,
        "T_2M": all_nonzero,
        "T_SO": all_nonzero,
        "SOILTYP": lambda x: np.count_nonzero(x!=5)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["SOILTYP"] = 5  # soil type 5 is the background class

    # bug fix: the original passed `pool=pool` where `pool` was never
    # defined in this function (NameError); COSMO fields use no pooling
    save_patches_all(cosmonwp_reader, cosmo_patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix)
220
+
221
+
222
def save_patches_all(
    reader, patches, variables, nonzero_count_func, zero_value,
    out_dir, suffix, epoch=datetime(1970,1,1), postproc={}, scale=None,
    pool={}, source_vars={}, parallel=False, min_nonzeros_to_include=1
):
    """Extract patches for each variable and write one NetCDF file per variable.

    Parameters
    ----------
    reader : object with a ``variable_for_time(time, var)`` method
    patches : dict mapping datetime -> array of (pi, pj) patch coords
    variables : iterable of variable names to process
    nonzero_count_func : dict var -> callable counting "nonzero" pixels
    zero_value : dict var -> value representing the background level
    out_dir, suffix : output location; files are named
        ``<out_dir>/<var>/patches_<var>_<suffix>.nc``
    postproc, pool : optional per-variable callables applied to the full
        field / to each kept patch
    source_vars : optional mapping from output name to reader variable name
    parallel : if True, process variables concurrently with dask threads
    min_nonzeros_to_include : patches with fewer nonzero pixels are saved
        as zero patches (coordinates only)
    """

    def save_var(var_name):
        # extract, scale-look-up and write the patches of one variable
        src_name = source_vars.get(var_name, var_name)

        (patch_data, patch_coords, patch_times,
         zero_patch_coords, zero_patch_times) = get_patches(
            reader, src_name, patches,
            nonzero_count_func=nonzero_count_func[var_name],
            postproc=postproc.get(var_name),
            pool=pool.get(var_name),
            # bug fix: this argument was previously dropped, so callers
            # passing min_nonzeros_to_include (e.g. 5) had no effect
            min_nonzeros_to_include=min_nonzeros_to_include
        )
        try:
            # readers that store digital values provide a lookup table
            time = epoch + timedelta(seconds=int(patch_times[0]))
            var_scale = reader.get_scale(time, var_name)
        except (AttributeError, KeyError):
            var_scale = None if (scale is None) else scale[var_name]

        var_name = var_name.replace("_", "-")  # file-name friendly
        out_fn = f"patches_{var_name}_{suffix}.nc"
        out_path = os.path.join(out_dir, var_name)
        os.makedirs(out_path, exist_ok=True)
        out_fn = os.path.join(out_path, out_fn)

        save_patches(
            patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times, out_fn,
            zero_value=zero_value[var_name], scale=var_scale
        )

    if parallel:
        save_var = dask.delayed(save_var)

    jobs = [save_var(v) for v in variables]
    if parallel:
        dask.compute(jobs, scheduler='threads')
263
+
264
+
265
def get_patches(
    reader, variable, patches,
    patch_shape=(32,32), nonzero_count_func=None,
    epoch=datetime(1970,1,1), postproc=None,
    pool=None, min_nonzeros_to_include=1
):
    """Extract fixed-size patches of ``variable`` at the given times/locations.

    Patches whose nonzero count (per ``nonzero_count_func``) is below
    ``min_nonzeros_to_include`` are recorded only by coordinates and time
    ("zero patches"); all others are stored with their data (optionally
    pooled by ``pool``).

    Returns
    -------
    (patch_data, patch_coords, patch_times,
     zero_patch_coords, zero_patch_times) as numpy arrays.
    """
    num_patches = sum(len(patches[t]) for t in patches)
    patch_data = []
    patch_coords = []
    patch_times = []
    zero_patch_coords = []
    zero_patch_times = []

    # remember the reader's physical-values flag so it can be restored
    if hasattr(reader, "phys_values"):
        phys_values = reader.phys_values

    k = 0
    try:
        # read raw (digital) values; scaling happens downstream
        if hasattr(reader, "phys_values"):
            reader.phys_values = False
        for (t, p_coord) in patches.items():
            try:
                data = reader.variable_for_time(t, variable)
            except (ValueError, FileNotFoundError, KeyError, OSError):
                continue  # missing/corrupt timestep: skip silently

            if postproc is not None:
                data = postproc(data)

            time_sec = np.int64((t-epoch).total_seconds())
            for (pi, pj) in p_coord:
                if k % 100000 == 0:
                    print("{}: {}/{}".format(t, k, num_patches))  # progress
                patch_box = data[
                    pi*patch_shape[0]:(pi+1)*patch_shape[0],
                    pj*patch_shape[1]:(pj+1)*patch_shape[1],
                ].copy()
                # NOTE(review): despite its name, this flag is True when
                # the patch counts as a *zero* (background) patch
                is_nonzero = (nonzero_count_func is not None) and \
                    (nonzero_count_func(patch_box) < min_nonzeros_to_include)
                if is_nonzero:
                    zero_patch_coords.append((pi,pj))
                    zero_patch_times.append(time_sec)
                else:
                    if pool is not None:
                        patch_box = pool(patch_box)
                    patch_data.append(patch_box)
                    patch_coords.append((pi,pj))
                    patch_times.append(time_sec)
                k += 1

    finally:
        # restore the reader's original phys_values setting
        if hasattr(reader, "phys_values"):
            reader.phys_values = phys_values

    if zero_patch_coords:
        zero_patch_coords = np.stack(zero_patch_coords, axis=0).astype(np.uint16)
        zero_patch_times = np.stack(zero_patch_times, axis=0)
    else:
        # no zero patches: keep dtypes consistent with the non-empty case
        zero_patch_coords = np.zeros((0,2), dtype=np.uint16)
        zero_patch_times = np.zeros((0,), dtype=np.int64)
    patch_data = np.stack(patch_data, axis=0)
    patch_coords = np.stack(patch_coords, axis=0).astype(np.uint16)
    patch_times = np.stack(patch_times, axis=0)

    return (patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times)
331
+
332
+
333
def save_patches(patch_data, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times, out_fn, zero_value=0, scale=None):
    """Write patch arrays to a compressed NetCDF file.

    ``zero_value`` is stored as a global attribute; the optional
    ``scale`` lookup table (digital value -> physical value) is stored
    as its own variable.
    """
    with netCDF4.Dataset(out_fn, 'w') as ds:
        dim_patch = ds.createDimension("dim_patch", patch_data.shape[0])
        dim_zero_patch = ds.createDimension("dim_zero_patch", zero_patch_coords.shape[0])
        dim_coord = ds.createDimension("dim_coord", 2)
        dim_height = ds.createDimension("dim_height", patch_data.shape[1])
        dim_width = ds.createDimension("dim_width", patch_data.shape[2])

        # light zlib compression for all variables
        var_args = {"zlib": True, "complevel": 1}

        # chunk along the patch dimension for efficient partial reads
        chunksizes = (min(2**10, patch_data.shape[0]), patch_data.shape[1], patch_data.shape[2])
        var_patch = ds.createVariable("patches", patch_data.dtype,
            ("dim_patch","dim_height","dim_width"), chunksizes=chunksizes, **var_args)
        var_patch[:] = patch_data

        var_patch_coord = ds.createVariable("patch_coords", patch_coords.dtype,
            ("dim_patch","dim_coord"), **var_args)
        var_patch_coord[:] = patch_coords

        var_patch_time = ds.createVariable("patch_times", patch_times.dtype,
            ("dim_patch",), **var_args)
        var_patch_time[:] = patch_times

        var_zero_patch_coord = ds.createVariable("zero_patch_coords", zero_patch_coords.dtype,
            ("dim_zero_patch","dim_coord"), **var_args)
        var_zero_patch_coord[:] = zero_patch_coords

        var_zero_patch_time = ds.createVariable("zero_patch_times", zero_patch_times.dtype,
            ("dim_zero_patch",), **var_args)
        var_zero_patch_time[:] = zero_patch_times

        ds.zero_value = zero_value

        if scale is not None:
            dim_scale = ds.createDimension("dim_scale", len(scale))
            var_scale = ds.createVariable("scale", scale.dtype, ("dim_scale",), **var_args)
            var_scale[:] = scale
372
+
373
+
374
def load_patches(fn, in_memory=True):
    """Load a patch file (modified backup version).

    NOTE(review): this ``.save`` backup has been hand-edited: the
    standard patch variables are commented out, ``zero_value`` is
    hard-coded to 1 and a ``pr`` variable is read instead — presumably
    for a dataset that stores precipitation under that name; verify
    against the files actually being loaded.
    """
    if in_memory:
        # read the whole file into RAM and open it as an in-memory dataset
        with open(fn, 'rb') as f:
            ds_raw = f.read()
        fn = None
    else:
        ds_raw = None

    with netCDF4.Dataset(fn, 'r', memory=ds_raw) as ds:
        patch_data = {
            #"patches": np.array(ds["patches"]),
            #"patch_coords": np.array(ds["patch_coords"]),
            #"patch_times": np.array(ds["patch_times"]),
            #"zero_patch_coords": np.array(ds["zero_patch_coords"]),
            #"zero_patch_times": np.array(ds["zero_patch_times"]),
            "zero_value": 1,
            "pr": np.array(ds["pr"]),
        }
        if "scale" in ds.variables:
            patch_data["scale"] = np.array(ds["scale"])

    return patch_data
396
+
397
+
398
def load_all_patches(patch_dir, var):
    """Load and concatenate patch files for ``var`` (modified backup version).

    NOTE(review): matches the hand-edited ``load_patches`` above — only
    the ``pr`` array is concatenated; the standard keys are disabled.
    """
    files = os.listdir(patch_dir)
    jobs = []
    for fn in files:
        # file names look like patches_<var>_<suffix>.nc
        file_var = fn.split("_")[1]
        if file_var == var:
            fn = os.path.join(patch_dir, fn)
            jobs.append(dask.delayed(load_patches)(fn))

    file_data = dask.compute(jobs, scheduler="processes")[0]
    patch_data = {}
    #keys = ["patches", "patch_coords", "patch_times",
    #    "zero_patch_coords", "zero_patch_times"]
    keys = ["pr"]
    for k in keys:
        patch_data[k] = np.concatenate(
            [fd[k] for fd in file_data],
            axis=0
        )
    # metadata taken from the first file; assumed identical across files
    patch_data["zero_value"] = file_data[0]["zero_value"]
    if "scale" in file_data[0]:
        patch_data["scale"] = file_data[0]["scale"]

    return patch_data
422
+
423
+
424
def unpack_patches(patch_data):
    """Return the five core patch arrays of ``patch_data`` as a tuple.

    Order matches the return value of ``get_patches``.
    """
    keys = (
        "patches", "patch_coords", "patch_times",
        "zero_patch_coords", "zero_patch_times",
    )
    return tuple(patch_data[k] for k in keys)
ldcast/features/sampling.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bisect import bisect_left
2
+ import multiprocessing
3
+
4
+ import dask
5
+ from numba import njit, prange, types
6
+ from numba.typed import Dict
7
+ import numpy as np
8
+
9
+ from .patches import unpack_patches
10
+
11
+
12
class EqualFrequencySampler:
    """Samples patch starting indices with equal probability per intensity bin.

    Patches are classified into ``bins`` by a per-patch metric; sampling
    then draws bins uniformly so that rare high-intensity cases are not
    under-represented in training batches.
    """
    def __init__(
        self, bins, patch_data, patch_index,
        sample_shape, time_range_valid, time_range_sampling=None,
        timestep_secs=5*60,
        random_seed=None, preselected_samples=None
    ):
        # classify every patch (and zero patch) into an intensity bin
        binned_patches = bin_classify_patches_parallel(
            bins,
            *unpack_patches(patch_data),
            zero_value=patch_data.get("zero_value", 0),
            scale=patch_data.get("scale")
        )
        # keep only starting locations where the full space-time sample
        # window has no missing data
        complete_ind = indices_with_complete_sample(
            patch_index, sample_shape, time_range_valid, timestep_secs
        )
        if time_range_sampling is None:
            time_range_sampling = time_range_valid
        # per bin: all sample starting indices covering a bin member
        self.starting_ind = [
            starting_indices_for_centers(
                p, complete_ind, sample_shape, time_range_sampling, timestep_secs
            )
            for p in binned_patches
        ]
        self.num_bins = len(self.starting_ind)
        self.rng = np.random.RandomState(seed=random_seed)
        # NOTE(review): accepted but not used in this class - verify callers
        self.preselected_samples = preselected_samples
        # initialized to len(ind) so the first access of each bin shuffles
        self.current_ind = np.array([len(ind) for ind in self.starting_ind])

    def get_bin_sample(self, bin_ind):
        """Return the next (t, i, j) starting index from bin ``bin_ind``,
        reshuffling the bin when it has been exhausted."""
        patches = self.starting_ind[bin_ind]
        sample_ind = self.current_ind[bin_ind]
        if sample_ind >= patches.shape[0]:
            self.rng.shuffle(patches)
            # NOTE(review): the counter is reset but not advanced here, so
            # index 0 is returned again on the next call - confirm intended
            sample_ind = self.current_ind[bin_ind] = 0
        else:
            self.current_ind[bin_ind] += 1
        return patches[sample_ind,:]

    def __call__(self, num):
        # sample each bin with equal probability
        bins = self.rng.randint(self.num_bins, size=num)
        coords = np.stack(
            [self.get_bin_sample(b) for b in bins],
            axis=0
        )
        return coords
59
+
60
+
61
def bin_classify_patches(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Assign each patch to an intensity bin.

    The default metric is the (approximate) 99th percentile of each
    patch; ``scale`` optionally maps digital values to physical ones
    before binning.  Zero patches all land in the bin of ``zero_value``.
    Returns a list of ``len(bins)+1`` arrays of (time, pi, pj) triples.
    """
    if metric_func is None:
        def metric_func(arr):
            m = np.percentile(arr, 99, axis=(1, 2))
            if np.issubdtype(arr.dtype, np.integer):
                m = m.round()
            return m.astype(arr.dtype)

    num_bins = len(bins) + 1
    binned = [[] for _ in range(num_bins)]

    # all zero patches share the bin of the (possibly scaled) zero value
    zero_metric = zero_value if scale is None else scale[zero_value]
    zero_bin = bisect_left(bins, zero_metric)
    for (t, coord) in zip(zero_patch_times, zero_patch_coords):
        binned[zero_bin].append((t, coord[0], coord[1]))

    metrics = metric_func(patches)
    if scale is not None:
        metrics = scale[metrics]
    for (m, t, coord) in zip(metrics, patch_times, patch_coords):
        binned[bisect_left(bins, m)].append((t, coord[0], coord[1]))

    return [
        np.array(b) if b else np.zeros((0, 3), dtype=np.int64)
        for b in binned
    ]
97
+
98
+
99
def bin_classify_patches_parallel(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Parallel wrapper around ``bin_classify_patches``.

    Splits the patch and zero-patch arrays into one chunk per CPU,
    classifies the chunks with dask threads and concatenates the
    per-bin results.
    """
    num_patches = patches.shape[0]
    num_zeros = zero_patch_coords.shape[0]
    num_procs = multiprocessing.cpu_count()

    tasks = []
    for p in range(num_procs):
        # chunk boundaries for this worker (patches and zero patches)
        pk0 = int(round(num_patches*p/num_procs))
        pk1 = int(round(num_patches*(p+1)/num_procs))
        zk0 = int(round(num_zeros*p/num_procs))
        zk1 = int(round(num_zeros*(p+1)/num_procs))

        task = dask.delayed(bin_classify_patches)(
            bins,
            patches[pk0:pk1,...], patch_coords[pk0:pk1,...],
            patch_times[pk0:pk1],
            zero_patch_coords[zk0:zk1,...], zero_patch_times[zk0:zk1],
            zero_value=zero_value, metric_func=metric_func,
            scale=scale
        )
        tasks.append(task)

    chunked_bins = dask.compute(tasks, scheduler="threads")[0]

    # merge the per-chunk bin lists bin by bin
    n_bins = len(chunked_bins[0])
    binned_patches = [
        np.concatenate([cb[i] for cb in chunked_bins], axis=0)
        for i in range(n_bins)
    ]
    return binned_patches
134
+
135
+
136
def indices_with_complete_sample(
    patch_index, sample_shape, time_range, timestep_secs
):
    """Check which locations will give a sample without missing data.

    A starting index (t0, i0, j0) is "complete" when every (t, i, j)
    cell of the ``sample_shape`` x ``time_range`` window exists in
    ``patch_index.patch_index``.  Returns a numba typed Dict used as a
    set of complete starting indices.
    """
    ind = np.array(list(patch_index.patch_index.keys()))
    t0 = ind[:,0]
    i0 = ind[:,1]
    j0 = ind[:,2]
    n = ind.shape[0]
    complete = np.ones(n, dtype=bool)
    # we use this dict like a set - numba doesn't support typed sets
    complete_ind = Dict.empty(
        key_type=types.UniTuple(types.int64, 3),
        value_type=types.uint8
    )

    @njit(parallel=True) # many nested loops, numba optimization needed
    def check_complete(index, complete, complete_ind):
        for k in prange(n):
            # scan the full space-time window of candidate k
            for ts in range(*time_range):
                t = t0[k] + ts*timestep_secs
                for di in range(sample_shape[0]):
                    i = i0[k] + di
                    for dj in range(sample_shape[1]):
                        j = j0[k] + dj
                        if (t,i,j) not in index:
                            complete[k] = False

        for k in range(n): # no prange: can't set dict items in parallel
            if complete[k]:
                complete_ind[(t0[k],i0[k],j0[k])] = np.uint8(0)

    check_complete(patch_index.patch_index, complete, complete_ind)

    return complete_ind
172
+
173
+
174
def starting_indices_for_centers(
    centers, complete_ind, sample_shape, time_range, timestep_secs
):
    """Determine a complete list of sample indices that
    contain one or more of the centerpoints.

    For each center (t0, i0, j0), every starting index whose sample
    window would cover the center is tested against ``complete_ind``;
    matches are collected (deduplicated via the dict-as-set) and
    returned as an (N, 3) array.
    """

    @njit
    def find_indices(centers, starting_ind, complete_ind):
        for k in range(centers.shape[0]):
            t0 = centers[k,0]
            i0 = centers[k,1]
            j0 = centers[k,2]
            for ts in range(*time_range):
                t = t0 - ts*timestep_secs # note minus signs in (t,i,j)
                for di in range(sample_shape[0]):
                    i = i0 - di
                    for dj in range(sample_shape[1]):
                        j = j0 - dj
                        if (t,i,j) in complete_ind:
                            starting_ind[(t,i,j)] = np.uint8(0)

    num_chunks = multiprocessing.cpu_count()

    @dask.delayed
    def chunk(i):
        # each chunk gets its own dict-as-set of found starting indices
        starting_ind = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.uint8
        )
        k0 = int(round(centers.shape[0] * (i / num_chunks)))
        k1 = int(round(centers.shape[0] * ((i+1) / num_chunks)))
        find_indices(centers[k0:k1,...], starting_ind, complete_ind)
        return starting_ind

    jobs = [chunk(i) for i in range(num_chunks)]
    starting_ind = dask.compute(jobs, scheduler='threads')[0]
    # merge non-empty chunk results into one (N, 3) index array;
    # NOTE(review): duplicates across chunks are not removed here
    starting_ind = np.concatenate(
        [np.array(list(st_ind.keys())) for st_ind in starting_ind if st_ind],
        axis=0
    )
    return starting_ind
ldcast/features/split.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bisect import bisect_left
2
+
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+
7
+ from . import batch
8
+
9
+
10
+ def get_chunks(
11
+ primary_raw, valid_frac=0.1, test_frac=0.1,
12
+ chunk_seconds=2*24*60*60, random_seed=None
13
+ ):
14
+ t0 = min(
15
+ primary_raw["patch_times"][0],
16
+ primary_raw["zero_patch_times"][0]
17
+ )
18
+ t1 = max(
19
+ primary_raw["patch_times"][-1],
20
+ primary_raw["zero_patch_times"][-1]
21
+ )+1
22
+
23
+ rng = np.random.RandomState(seed=random_seed)
24
+ chunk_limits = np.arange(t0,t1,chunk_seconds)
25
+ num_chunks = len(chunk_limits)-1
26
+
27
+ chunk_ind = np.arange(num_chunks)
28
+ rng.shuffle(chunk_ind)
29
+ i_valid = int(round(num_chunks * valid_frac))
30
+ i_test = i_valid + int(round(num_chunks * test_frac))
31
+ chunk_ind = {
32
+ "valid": chunk_ind[:i_valid],
33
+ "test": chunk_ind[i_valid:i_test],
34
+ "train": chunk_ind[i_test:]
35
+ }
36
+ def get_chunk_limits(chunk_ind_split):
37
+ return sorted(
38
+ (chunk_limits[i], chunk_limits[i+1])
39
+ for i in chunk_ind_split
40
+ )
41
+ chunks = {
42
+ split: get_chunk_limits(chunk_ind_split)
43
+ for (split, chunk_ind_split) in chunk_ind.items()
44
+ }
45
+ return chunks
46
+
47
+
48
def train_valid_test_split(
    raw_data, primary_raw_var, chunks=None, **kwargs
):
    """Split raw patch data into train/valid/test sets by time chunks.

    NOTE(review): in this uploaded version the per-chunk array slicing
    is entirely commented out, so every split receives references to
    the FULL raw arrays (no actual splitting happens); only the chunk
    boundaries are computed and returned.  Confirm this is intentional.
    """
    if chunks is None:
        primary = raw_data[primary_raw_var]
        chunks = get_chunks(primary, **kwargs)

    def split_chunks_from_array(x, chunks_split, times):
        # copy the rows of x whose times fall inside any selected chunk;
        # `times` must be sorted for the bisection to be correct
        n = 0
        chunk_ind = []
        for (t0,t1) in chunks_split:
            k0 = bisect_left(times, t0)
            k1 = bisect_left(times, t1)
            n += k1 - k0
            chunk_ind.append((k0,k1))

        shape = (n,) + x.shape[1:]
        x_chunk = np.empty_like(x, shape=shape)

        j0 = 0
        for (k0,k1) in chunk_ind:
            j1 = j0 + (k1-k0)
            x_chunk[j0:j1,...] = x[k0:k1,...]
            j0 = j1

        return x_chunk

    split_raw_data = {
        split: {var: {} for var in raw_data}
        for split in chunks
    }

    for (var, raw_data_var) in raw_data.items():
        for (split, chunks_split) in chunks.items():

            # original per-array chunk slicing, currently disabled:
            #split_raw_data[split][var]["patches"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patches"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["patch_coords"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patch_coords"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["patch_times"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patch_times"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["zero_patch_coords"] = \
            #    split_chunks_from_array(
            #        raw_data_var["zero_patch_coords"], chunks_split,
            #        raw_data_var["zero_patch_times"]
            #    )
            #split_raw_data[split][var]["zero_patch_times"] = \
            #    split_chunks_from_array(
            #        raw_data_var["zero_patch_times"], chunks_split,
            #        raw_data_var["zero_patch_times"]
            #    )

            # with the slicing disabled, every key falls through here
            # and the full arrays are shared across all splits
            added_keys = set(split_raw_data[split][var].keys())
            missing_keys = set(raw_data[var].keys()) - added_keys
            for k in missing_keys:
                split_raw_data[split][var][k] = raw_data[var][k]

    return (split_raw_data, chunks)
115
+
116
+
117
class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping per-split ``BatchGenerator`` datasets.

    Builds one batch generator per split from the raw patch data, with
    data augmentation enabled for training only, and exposes the usual
    train/val/test dataloaders.
    """
    def __init__(
        self,
        variables, raw, predictors, target, primary_var,
        sampling_bins, sampler_file,
        batch_size=8,
        train_epoch_size=10, valid_epoch_size=2, test_epoch_size=10,
        valid_seed=None, test_seed=None,
        **kwargs
    ):
        super().__init__()
        # one batch generator per split; augmentation only for training
        self.batch_gen = {
            split: batch.BatchGenerator(
                variables, raw_var, predictors, target, primary_var,
                sampling_bins=sampling_bins, batch_size=batch_size,
                sampler_file=sampler_file.get(split),
                augment=(split=="train"),
                **kwargs
            )
            for (split,raw_var) in raw.items()
        }
        self.datasets = {}
        # training streams random batches; valid/test are seeded for
        # reproducible evaluation
        if "train" in self.batch_gen:
            self.datasets["train"] = batch.StreamBatchDataset(
                self.batch_gen["train"], train_epoch_size
            )
        if "valid" in self.batch_gen:
            self.datasets["valid"] = batch.DeterministicBatchDataset(
                self.batch_gen["valid"], valid_epoch_size, random_seed=valid_seed
            )
        if "test" in self.batch_gen:
            self.datasets["test"] = batch.DeterministicBatchDataset(
                self.batch_gen["test"], test_epoch_size, random_seed=test_seed
            )

    def dataloader(self, split):
        # batch_size=None: the datasets already yield complete batches
        return DataLoader(
            self.datasets[split], batch_size=None,
            pin_memory=True, num_workers=0
        )

    def train_dataloader(self):
        return self.dataloader("train")

    def val_dataloader(self):
        return self.dataloader("valid")

    def test_dataloader(self):
        return self.dataloader("test")
ldcast/features/transform.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import multiprocessing
3
+
4
+ from numba import njit, prange
5
+ import numpy as np
6
+ from scipy.ndimage import convolve
7
+
8
+
9
def quick_cast(x, y):
    """Cast/copy ``x`` into the preallocated array ``y`` using threads.

    The first axis is sharded across one thread per CPU; each thread
    copies (and dtype-casts) its slice in place.
    """
    n_workers = multiprocessing.cpu_count()
    bounds = np.linspace(0, x.shape[0], n_workers + 1).round().astype(int)

    def copy_shard(a, b):
        y[a:b, ...] = x[a:b, ...]

    with concurrent.futures.ThreadPoolExecutor(n_workers) as pool:
        pending = [
            pool.submit(copy_shard, bounds[i], bounds[i + 1])
            for i in range(n_workers)
        ]
        concurrent.futures.wait(pending)
20
+
21
+
22
def cast(dtype=np.float16):
    """Return a transform that casts arrays to ``dtype``.

    The output buffer is allocated once per input shape and reused on
    subsequent calls, so callers must consume the result before the
    next invocation.
    """
    buffer = None

    def transform(raw):
        nonlocal buffer
        needs_alloc = buffer is None or buffer.shape != raw.shape
        if needs_alloc:
            buffer = np.empty_like(raw, dtype=dtype)
        quick_cast(raw, buffer)
        return buffer

    return transform
31
+
32
+
33
@njit(parallel=True)
def scale_array(in_arr, out_arr, scale):
    """Numba-parallel lookup-table application: out[i] = scale[in[i]].

    Operates on flattened views; ``in_arr`` must hold integer indices
    valid for ``scale``.  ``out_arr`` is written in place.
    """
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        out_arr[i] = scale[in_arr[i]]
39
+
40
+ # NumPy version
41
+ #def scale_array(in_arr, out_arr, scale):
42
+ # out_arr[:] = scale[in_arr]
43
+
44
def normalize(mean=0.0, std=1.0, dtype=np.float32):
    """Return a transform computing (x - mean) / std into a reused buffer.

    NOTE(review): relies on ``normalize_array``, defined elsewhere in
    this module - verify it matches this call signature.
    """
    scaled = scaled_dt = None

    def transform(raw):
        nonlocal scaled, scaled_dt
        # (re)allocate buffers only when the input shape changes
        if (scaled is None) or (scaled.shape != raw.shape):
            scaled = np.empty_like(raw, dtype=np.float32)
            scaled_dt = np.empty_like(raw, dtype=dtype)
        normalize_array(raw, scaled, mean, std)

        if dtype == np.float32:
            return scaled
        else:
            # cast to the requested output dtype in a second buffer
            quick_cast(scaled, scaled_dt)
            return scaled_dt

    return transform
61
+
62
+
63
def normalize_threshold(mean=0.0, std=1.0, threshold=0.0, fill_value=0.0, log=False):
    """Return a transform normalizing values above ``threshold``.

    Values below the threshold are replaced by ``fill_value``; with
    ``log`` the data are log-transformed first.
    NOTE(review): relies on ``normalize_threshold_array``, defined
    elsewhere in this module - verify it matches this call signature.
    """
    scaled = None

    def transform(raw):
        nonlocal scaled
        # (re)allocate the output buffer only when the input shape changes
        if (scaled is None) or (scaled.shape != raw.shape):
            scaled = np.empty_like(raw, dtype=np.float32)
        normalize_threshold_array(raw, scaled, mean, std, threshold, fill_value, log=log)

        return scaled

    return transform
75
+
76
+
77
def scale_log_norm(scale, threshold=None, missing_value=None,
    fill_value=0, mean=0.0, std=1.0, dtype=np.float32):
    """Return a transform mapping digital values through a log-normalized LUT.

    The lookup table ``scale`` (digital value -> physical value) is
    log10-transformed, thresholded, normalized by ``mean``/``std`` once
    up front; the transform then applies it per element via
    ``scale_array`` into a reused buffer.
    """
    # precompute the normalized log lookup table
    log_scale = np.log10(scale, where=scale>0).astype(np.float32)
    if threshold is not None:
        # everything below the threshold maps to the fill value
        log_scale[log_scale < np.log10(threshold)] = np.log10(fill_value)
    if missing_value is not None:
        log_scale[missing_value] = np.log10(fill_value)
    # entries that were <= 0 before the log are invalid; fill them too
    log_scale[~np.isfinite(log_scale)] = np.log10(fill_value)
    log_scale -= mean
    log_scale /= std
    scaled = scaled_dt = None

    def transform(raw):
        nonlocal scaled, scaled_dt
        # (re)allocate buffers only when the input shape changes
        if (scaled is None) or (scaled.shape != raw.shape):
            scaled = np.empty_like(raw, dtype=np.float32)
            scaled_dt = np.empty_like(raw, dtype=dtype)
        scale_array(raw, scaled, log_scale)

        if dtype == np.float32:
            return scaled
        else:
            quick_cast(scaled, scaled_dt)
            return scaled_dt

    return transform
104
+
105
+
106
def combine(transforms, memory_format="channels_first", dim=3):
    """Compose per-variable transforms and concatenate along channels.

    Each element of ``transforms`` is applied to the corresponding raw
    input; results with ``dim + 1`` dimensions get a singleton channel
    axis inserted before all outputs are concatenated channel-wise.
    """
    ch_axis = 1 if memory_format == "channels_first" else -1

    def transform(*raw):
        parts = []
        for (tf, r) in zip(transforms, raw):
            out = tf(r)
            if out.ndim == dim + 1:
                out = np.expand_dims(out, ch_axis)
            parts.append(out)
        return np.concatenate(parts, axis=ch_axis)

    return transform
120
+
121
+
122
class Antialiasing:
    """Gaussian antialiasing filter with boundary compensation.

    Smooths each (batch, time, H, W) frame with a 5x5 Gaussian kernel.
    An edge-correction factor (1 / convolution of ones) removes the
    attenuation that zero-padded convolution introduces at the borders.
    Both the correction factors and the output buffers are cached per
    shape so repeated calls do not reallocate.
    """
    def __init__(self):
        (x, y) = np.mgrid[-2:3, -2:3]
        # normalized Gaussian kernel, sigma = 0.5
        self.kernel = np.exp(-0.5*(x**2 + y**2)/(0.5**2))
        self.kernel /= self.kernel.sum()
        self.edge_factors = {}  # (H, W) -> border compensation factors
        self.img_smooth = {}    # full img.shape -> reusable output buffer
        num_threads = multiprocessing.cpu_count()
        self.executor = concurrent.futures.ThreadPoolExecutor(num_threads)

    def __call__(self, img):
        """Return the smoothed image; the result buffer is reused between
        calls of the same shape, so consume it before the next call."""
        img_shape = img.shape[-2:]
        if img_shape not in self.edge_factors:
            # convolving ones shows the attenuation at the borders;
            # its reciprocal restores a flat response
            s = convolve(np.ones(img_shape, dtype=np.float32),
                         self.kernel, mode="constant")
            s = 1.0/s
            self.edge_factors[img_shape] = s
        else:
            s = self.edge_factors[img_shape]

        # bug fix: the buffer cache was checked with the full img.shape
        # but stored under the 2-D img_shape, so lookups never hit (and
        # a hit could have returned a buffer with wrong leading dims);
        # key consistently on the full shape
        if img.shape not in self.img_smooth:
            img_smooth = np.empty_like(img)
            self.img_smooth[img.shape] = img_smooth
        else:
            img_smooth = self.img_smooth[img.shape]

        def _convolve_frame(i, j):
            # smooth one 2-D frame in place, then fix the borders
            convolve(img[i, j, :, :], self.kernel,
                     mode="constant", output=img_smooth[i, j, :, :])
            img_smooth[i, j, :, :] *= s

        # one task per frame, executed on the shared thread pool
        futures = []
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                args = (_convolve_frame, i, j)
                futures.append(self.executor.submit(*args))
        concurrent.futures.wait(futures)

        return img_smooth
161
+
162
+
163
def default_rainrate_transform(scale):
    """Return the standard rain-rate preprocessing: log-normalization
    with the default climatological statistics, followed by spatial
    antialiasing.
    """
    normalize = scale_log_norm(
        scale, threshold=0.1, fill_value=0.02,
        mean=-0.051, std=0.528, dtype=np.float32
    )
    smooth = Antialiasing()

    def transform(raw):
        return smooth(normalize(raw))

    return transform
173
+
174
+
175
def scale_norm(scale, threshold=None, missing_value=None,
    fill_value=0, mean=0.0, std=1.0, dtype=np.float32):
    """Create a transform mapping quantized codes to normalized floats.

    The lookup table `scale` is cleaned first: NaNs, sub-`threshold`
    entries and the codes listed in `missing_value` are replaced by
    `fill_value`. It is then standardized with `mean` and `std`. The
    returned transform applies the table to raw index arrays, reusing
    its output buffers between calls with the same input shape.
    """
    table = scale.astype(np.float32).copy()
    table[np.isnan(table)] = fill_value
    if threshold is not None:
        table[table < threshold] = fill_value
    if missing_value is not None:
        for m in np.atleast_1d(missing_value):
            table[m] = fill_value
    table -= mean
    table /= std

    buf32 = buf_out = None

    def transform(raw):
        nonlocal buf32, buf_out
        if (buf32 is None) or (buf32.shape != raw.shape):
            buf32 = np.empty_like(raw, dtype=np.float32)
            buf_out = np.empty_like(raw, dtype=dtype)
        scale_array(raw, buf32, table)

        if dtype == np.float32:
            return buf32
        # cast to the requested output dtype only when needed
        quick_cast(buf32, buf_out)
        return buf_out

    return transform
204
+
205
+
206
@njit(parallel=True)
def threshold_array(in_arr, out_arr, threshold):
    """Write 1.0 into out_arr where in_arr >= threshold, else 0.0.

    Both arrays are processed through flattened views, so any pair of
    equally-sized arrays works; out_arr receives float32 values.
    """
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    # parallel elementwise comparison over the flattened array
    for i in prange(in_arr.shape[0]):
        out_arr[i] = np.float32(in_arr[i] >= threshold)
212
+
213
+
214
def one_hot(values):
    """Create a transform that one-hot encodes arrays of category codes.

    `values` lists the valid raw codes; code values[i] maps to channel i.
    The returned transform produces a uint8 array with a trailing channel
    axis of length len(values), reusing its buffers between calls with
    the same input shape.
    """
    translation = np.zeros(max(values)+1, dtype=int)
    num_categories = len(values)
    for (i,v) in enumerate(values):
        translation[v] = i
    onehot = onehot_dt = None

    def transform(raw):
        nonlocal onehot, onehot_dt
        if (onehot is None) or (onehot.shape[:-1] != raw.shape):
            # float32 working buffer for the jitted transform plus a
            # uint8 output buffer. (Fixed: the original assigned the
            # second allocation to `onehot` again, leaving `onehot_dt`
            # as None when passed to quick_cast.)
            onehot = np.empty(raw.shape+(num_categories,),
                dtype=np.float32)
            onehot_dt = np.empty(raw.shape+(num_categories,),
                dtype=np.uint8)
        onehot_transform(raw, onehot, translation)
        quick_cast(onehot, onehot_dt)

        return onehot_dt

    return transform
234
+
235
+
236
@njit(parallel=True)
def onehot_transform(in_arr, out_arr, translation):
    """One-hot encode a 4D array of category codes.

    in_arr: (K, T, H, W) array of integer-valued codes.
    out_arr: (K, T, H, W, C) output; channel translation[code] is set
        to 1.0, all other channels to 0.0.
    translation: lookup table from raw code value to channel index.
    """
    for k in prange(in_arr.shape[0]):
        # clear the whole slab before scattering ones
        out_arr[k,...] = 0.0
        for t in range(in_arr.shape[1]):
            for i in range(in_arr.shape[2]):
                for j in range(in_arr.shape[3]):
                    ind = np.uint64(in_arr[k,t,i,j])
                    c = translation[ind]
                    out_arr[k,t,i,j,c] = 1.0
246
+
247
+
248
@njit(parallel=True)
def normalize_array(in_arr, out_arr, mean, std):
    """Standardize in_arr into out_arr: out = (in - mean) / std.

    Works on flattened views of equally-sized arrays; 1/std is
    precomputed so the parallel loop only multiplies.
    """
    mean = np.float32(mean)
    inv_std = np.float32(1.0/std)
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        out_arr[i] = (in_arr[i]-mean)*inv_std
256
+
257
+
258
@njit(parallel=True)
def normalize_threshold_array(
    in_arr, out_arr,
    mean, std,
    threshold, fill_value, log=False
):
    """Standardize in_arr into out_arr with thresholding.

    Values below `threshold` are replaced by `fill_value`; when `log`
    is True, log10 is applied before normalizing with mean and std.
    """
    mean = np.float32(mean)
    inv_std = np.float32(1.0/std)
    threshold = np.float32(threshold)
    fill_value = np.float32(fill_value)
    in_arr = in_arr.ravel()
    out_arr = out_arr.ravel()
    for i in prange(in_arr.shape[0]):
        x = in_arr[i]
        if x < threshold:
            x = fill_value
        if log:
            x = np.log10(x)
        out_arr[i] = (x-mean)*inv_std
277
+
278
+
279
+ # NumPy version
280
+ #def threshold_array(in_arr, out_arr, threshold):
281
+ # out_arr[:] = (in_arr >= threshold).astype(np.float32)
282
+
283
+
284
def R_threshold(scale, threshold):
    """Create a transform that maps raw quantized rain-rate codes to a
    binary exceedance mask: 1.0 where the decoded value exceeds
    `threshold`, 0.0 elsewhere.
    """
    # index of the first scale entry above the physical threshold;
    # raw codes are compared directly against this index
    code_threshold = np.nanargmax(scale > threshold)
    out = None

    def transform(rzc_raw):
        nonlocal out
        if (out is None) or (out.shape != rzc_raw.shape):
            out = np.empty_like(rzc_raw, dtype=np.float32)
        threshold_array(rzc_raw, out, code_threshold)
        return out

    return transform
ldcast/features/utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numba import njit, prange
2
+ import numpy as np
3
+ from scipy.signal import convolve
4
+
5
+
6
def log_scale_with_zero(range, n=65536, dtype=np.float32):
    """Build a quantization scale of `n` values: a leading zero followed
    by n-1 logarithmically spaced values spanning `range`."""
    (lo, hi) = (np.log10(range[0]), np.log10(range[1]))
    log_levels = np.linspace(lo, hi, n-1)
    return np.concatenate(([0.0], 10.0**log_levels)).astype(dtype)
10
+
11
+
12
def log_quantize_with_zero(x, range, n=65536, dtype=np.uint16):
    """Quantize `x` onto a logarithmic scale with a dedicated zero bin.

    Returns (codes, scale): codes[i] indexes into scale, and
    scale[0] == 0 represents values below the scale minimum.
    """
    scale = log_scale_with_zero(range, n=n, dtype=x.dtype)
    codes = np.empty_like(x, dtype=dtype)
    log_quant_with_zeros(x, codes, np.log10(scale[1:]))
    return (codes, scale)
17
+
18
+
19
+ # optimized helper function for the above
20
@njit(parallel=True)
def log_quant_with_zeros(x, y, scale):
    """Quantize x onto the log10 levels in `scale`, writing codes into y.

    scale: sorted log10 bin centers (excluding the zero bin).
    Values below the smallest level get code 0 (the zero bin); values at
    or above the largest level get code len(scale); everything else gets
    1 + the index of the nearest level, found by binary search.
    """
    x = x.ravel()
    y = y.ravel()
    min_val = 10**scale[0]

    for i in prange(x.shape[0]):
        # map small values to 0
        if x[i] < min_val:
            y[i] = 0
            continue

        lx = np.log10(x[i])
        if lx >= scale[-1]:
            # map too big values to max of scale
            y[i] = len(scale)
        else:
            # binary search for the bracketing pair (k0, k1)
            k0 = 0
            k1 = len(scale)
            while k1-k0 > 1:
                km = k0 + (k1-k0)//2
                if lx < scale[km]:
                    k1 = km
                else:
                    k0 = km

            if k0 == len(scale)-1:
                q = k0
            elif k0 == 0:
                q = 0
            else:
                # pick the closer of the two neighboring levels
                d0 = abs(lx-scale[k0])
                d1 = abs(lx-scale[k1])
                if d0 < d1:
                    q = k0
                else:
                    q = k1

            y[i] = q+1 # add 1 to leave space for zero
60
+
61
+
62
@njit(parallel=True)
def average_pool(x, factor=2, missing=65535):
    """Downsample a 2D array by `factor` using the mean of valid pixels.

    Pixels equal to `missing` are excluded from each window's mean; an
    output pixel is itself set to `missing` when fewer than factor**2 // 2
    of its input pixels are valid.
    """
    y = np.empty((x.shape[0]//factor, x.shape[1]//factor), dtype=x.dtype)
    N = factor**2
    N_thresh = N//2

    for iy in prange(y.shape[0]):
        ix0 = iy * factor
        ix1 = ix0 + factor
        for jy in range(y.shape[1]):
            jx0 = jy * factor
            jx1 = jx0 + factor
            v = float(0.0)
            num_valid = 0

            # accumulate the valid pixels of the factor x factor window
            for ix in range(ix0, ix1):
                for jx in range(jx0, jx1):
                    if x[ix,jx] != missing:
                        v += x[ix,jx]
                        num_valid += 1

            if num_valid >= N_thresh:
                y[iy,jy] = v/num_valid
            else:
                y[iy,jy] = missing

    return y
89
+
90
+
91
@njit(parallel=True)
def mode_pool(x, num_values=256, factor=2):
    """Downsample a 2D array of category codes by `factor`, assigning
    each output pixel the most frequent code in its window.

    num_values: exclusive upper bound of the codes in x (histogram
    size). Ties resolve to the smallest code via argmax.
    """
    y = np.empty((x.shape[0]//factor, x.shape[1]//factor), dtype=x.dtype)

    for iy in prange(y.shape[0]):
        # per-row histogram buffer, private to each parallel iteration
        v = np.empty(num_values, dtype=np.int64)
        ix0 = iy * factor
        ix1 = ix0 + factor
        for jy in range(y.shape[1]):
            jx0 = jy * factor
            jx1 = jx0 + factor
            v[:] = 0

            # count code occurrences in the factor x factor window
            for ix in range(ix0, ix1):
                for jx in range(jx0, jx1):
                    v[x[ix,jx]] += 1

            y[iy,jy] = v.argmax()

    return y
111
+
112
+
113
def fill_holes(missing=65535, rad=1):
    """Create a function that fills missing pixels with the mean of the
    valid pixels in their (2*rad+1) x (2*rad+1) neighborhood.

    Pixels equal to `missing` with at least one valid neighbor are
    replaced by the neighborhood mean (rounded for integer inputs);
    missing pixels with no valid neighbors are left untouched.
    """
    # neighborhood footprint; invariant across calls, so built once here
    # instead of on every invocation of the returned function
    o = np.ones((2*rad+1, 2*rad+1), dtype=np.uint16)

    def fill(x):
        # mask of fillable points: missing but with >= 1 valid neighbor
        valid = (x != missing)
        num_valid_neighbors = convolve(valid, o, mode='same', method='direct')
        mask = ~valid & (num_valid_neighbors > 0)

        # mean of the valid points around each fillable point
        # (missing pixels are zeroed so they do not contribute)
        fx = x.copy()
        fx[~valid] = 0
        mx = convolve(fx, o.astype(np.float64), mode='same', method='direct')
        mx = mx[mask] / num_valid_neighbors[mask]
        if np.issubdtype(x.dtype, np.integer):
            mx = mx.round().astype(x.dtype)

        # write the neighborhood means into the holes
        fx = x.copy()
        fx[mask] = mx
        return fx

    return fill
136
+
ldcast/forecast.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import gc
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.multiprocessing as mp
7
+
8
+ from .features.transform import Antialiasing
9
+ from .models.autoenc import autoenc, encoder
10
+ from .models.genforecast import analysis, unet
11
+ from .models.diffusion import diffusion, plms
12
+
13
+
14
class Forecast:
    """Latent diffusion precipitation nowcast (single process / GPU).

    Wraps a pretrained autoencoder, a conditioning (analysis) network and
    a denoising U-Net into a latent diffusion model (LDM), then runs PLMS
    sampling to turn `past_timesteps` frames of rain rate into a
    `future_timesteps`-frame forecast.

    Keyword-only parameters:
        ldm_weights_fn: path to the LDM state dict (torch.load-able).
        autoenc_weights_fn: path to the autoencoder state dict.
        gpu: 'auto' = use CUDA if available; an int = CUDA device index;
            None = leave the model on CPU.
        past_timesteps / future_timesteps: input / output frame counts.
        autoenc_time_ratio: temporal compression of the autoencoder
            (input frames per latent timestep).
        autoenc_hidden_dim: channel width of the latent space.
        R_min_value: rain rates below this count as "no rain".
        R_zero_value: fill value for "no rain" before the log transform.
        R_min_output / R_max_output: clipping range of decoded output.
        log_R_mean / log_R_std: normalization statistics of log10(R).
    """

    def __init__(
        self,
        *,
        ldm_weights_fn,
        autoenc_weights_fn,
        gpu='auto',
        past_timesteps=4,
        future_timesteps=20,
        autoenc_time_ratio=4,
        autoenc_hidden_dim=32,
        verbose=True,
        R_min_value=0.1,
        R_zero_value=0.02,
        R_min_output=0.1,
        R_max_output=118.428,
        log_R_mean=-0.051,
        log_R_std=0.528,
    ):
        self.ldm_weights_fn = ldm_weights_fn
        self.autoenc_weights_fn = autoenc_weights_fn
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim
        self.antialiasing = Antialiasing()

        # setup LDM
        self.ldm = self._init_model()
        if gpu is not None:
            if gpu == 'auto':
                if torch.cuda.device_count() > 0:
                    self.ldm.to(device="cuda")
            else:
                self.ldm.to(device=f"cuda:{gpu}")
        # setup sampler
        self.sampler = plms.PLMSSampler(self.ldm)
        print(self.ldm.device)
        gc.collect()

    def _init_model(self):
        """Build the autoencoder + analysis net + denoiser stack, load
        the pretrained weights and return the assembled LDM."""
        # setup autoencoder
        enc = encoder.SimpleConvEncoder()
        dec = encoder.SimpleConvDecoder()
        autoencoder_obs = autoenc.AutoencoderKL(enc, dec)
        #print(torch.load(self.autoenc_weights_fn)['state_dict'].keys())
        # autoencoder_obs.load_state_dict(torch.load(self.autoenc_weights_fn)['state_dict'])
        autoencoder_obs.load_state_dict(torch.load(self.autoenc_weights_fn))
        autoencoders = [autoencoder_obs]
        # one latent patch per autoenc_time_ratio input frames
        input_patches = [self.past_timesteps//self.autoenc_time_ratio]
        input_size_ratios = [1]
        embed_dim = [128]
        analysis_depth = [4]

        # setup forecaster
        analysis_net = analysis.AFNONowcastNetCascade(
            autoencoders,
            input_patches=input_patches,
            input_size_ratios=input_size_ratios,
            train_autoenc=False,
            output_patches=self.future_timesteps//self.autoenc_time_ratio,
            cascade_depth=3,
            embed_dim=embed_dim,
            analysis_depth=analysis_depth
        )

        # setup denoiser
        denoiser = unet.UNetModel(in_channels=autoencoder_obs.hidden_width,
            model_channels=256, out_channels=autoencoder_obs.hidden_width,
            num_res_blocks=2, attention_resolutions=(1,2),
            dims=3, channel_mult=(1, 2, 4), num_heads=8,
            num_timesteps=self.future_timesteps//self.autoenc_time_ratio,
            context_ch=analysis_net.cascade_dims
        )

        # create LDM
        ldm = diffusion.LatentDiffusion(denoiser, autoencoder_obs,
            context_encoder=analysis_net)
        # ldm.load_state_dict(torch.load(self.ldm_weights_fn)['state_dict'])
        ldm.load_state_dict(torch.load(self.ldm_weights_fn))
        return ldm

    def __call__(
        self,
        R_past,
        num_diffusion_iters=50
    ):
        """Run one forecast: past rain-rate frames -> predicted future
        frames, using `num_diffusion_iters` PLMS sampling steps."""
        # preprocess inputs and setup correct input shape
        x = self.transform_precip(R_past)
        timesteps = self.input_timesteps(x)
        future_patches = self.future_timesteps // self.autoenc_time_ratio
        # latent shape: (channels, latent timesteps, H/4, W/4)
        gen_shape = (self.autoenc_hidden_dim, future_patches) + \
            (x.shape[-2]//4, x.shape[-1]//4)
        x = [[x, timesteps]]

        # run LDM sampler (stdout silenced to hide sampler prints)
        with contextlib.redirect_stdout(None):
            (s, intermediates) = self.sampler.sample(
                num_diffusion_iters,
                x[0][0].shape[0],
                gen_shape,
                x,
                progbar=self.verbose
            )

        # postprocess outputs
        y_pred = self.ldm.autoencoder.decode(s)
        R_pred = self.inv_transform_precip(y_pred)

        return R_pred[0,...]

    def transform_precip(self, R):
        """Normalize rain rates to the model's log10 representation and
        reshape to (batch=1, channel=1, T, H, W) on the model device.

        NOTE(review): expects a torch tensor (uses .clone().detach());
        np.log10 on a tensor dispatches back to torch -- confirm callers
        always pass tensors rather than numpy arrays.
        """
        # x = R.copy()
        x = R.clone().detach()
        # ~(x >= min) also catches NaN values
        x[~(x >= self.R_min_value)] = self.R_zero_value
        x = np.log10(x)
        x -= self.log_R_mean
        x /= self.log_R_std
        x = x.reshape((1,) + x.shape)
        x = self.antialiasing(x)
        x = x.reshape((1,) + x.shape)
        return torch.Tensor(x).to(device=self.ldm.device)

    def inv_transform_precip(self, x):
        """Invert transform_precip: de-normalize, exponentiate, clip to
        [R_min_output, R_max_output] and return a numpy array."""
        x *= self.log_R_std
        x += self.log_R_mean
        R = torch.pow(10, x)
        if self.R_min_output:
            R[R < self.R_min_output] = 0.0
        if self.R_max_output is not None:
            R[R > self.R_max_output] = self.R_max_output
        R = R[:,0,...]
        return R.to(device='cpu').numpy()

    def input_timesteps(self, x):
        """Relative time coordinates of the past frames: the last input
        frame is t = 0, earlier frames are negative."""
        batch_size = x.shape[0]
        t0 = -x.shape[2]+1
        t1 = 1
        timesteps = torch.arange(t0, t1,
            dtype=x.dtype, device=self.ldm.device)
        return timesteps.unsqueeze(0).expand(batch_size,-1)
162
+
163
+
164
class ForecastDistributed:
    """Multi-GPU ensemble forecasting front end.

    Spawns one worker process per CUDA device, each owning its own
    `Forecast` instance, and distributes (sample, ensemble member) jobs
    to them over multiprocessing queues. Parameters mirror `Forecast`;
    see there for their meaning.
    """
    def __init__(
        self,
        ldm_weights_fn,
        autoenc_weights_fn,
        past_timesteps=4,
        future_timesteps=8,
        autoenc_time_ratio=4,
        autoenc_hidden_dim=32,
        verbose=True,
        R_min_value=0.1,
        R_zero_value=0.02,
        R_min_output=0.1,
        R_max_output=118.428,
        log_R_mean=-0.051,
        log_R_std=0.528,
    ):
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim

        # start worker processes
        context = mp.get_context('spawn')
        self.input_queue = context.Queue()
        self.output_queue = context.Queue()
        process_kwargs = {
            "past_timesteps": past_timesteps,
            "future_timesteps": future_timesteps,
            "ldm_weights_fn": ldm_weights_fn,
            "autoenc_weights_fn": autoenc_weights_fn,
            "autoenc_time_ratio": autoenc_time_ratio,
            "autoenc_hidden_dim": autoenc_hidden_dim,
            "R_min_value": R_min_value,
            "R_zero_value": R_zero_value,
            "R_min_output": R_min_output,
            "R_max_output": R_max_output,
            "log_R_mean": log_R_mean,
            "log_R_std": log_R_std,
            "verbose": True
        }
        # NOTE(review): with zero CUDA devices this spawns no workers
        # and __call__ would block forever on the output queue --
        # confirm CPU-only use is handled by the caller
        self.num_procs = max(0, torch.cuda.device_count())
        self.compute_procs = mp.spawn(
            _compute_process,
            args=(self.input_queue, self.output_queue, process_kwargs),
            nprocs=self.num_procs,
            join=False
        )

        # wait for worker processes to be ready
        for _ in range(self.num_procs):
            self.output_queue.get()

        gc.collect()

    def __call__(
        self,
        R_past,
        ensemble_members=1,
        num_diffusion_iters=50
    ):
        """Forecast every sample in R_past with `ensemble_members`
        realizations each.

        Returns an array of shape
        (samples, future_timesteps) + R_past.shape[2:] + (ensemble_members,).
        """
        # send samples to compute processes
        for (i, R_past_sample) in enumerate(R_past):
            for j in range(ensemble_members):
                self.input_queue.put((R_past_sample, num_diffusion_iters, i, j))

        # build output array
        pred_shape = (R_past.shape[0], self.future_timesteps) + \
            R_past.shape[2:] + (ensemble_members,)
        R_pred = np.empty(pred_shape, R_past.dtype)

        # gather outputs from processes; arrival order is arbitrary,
        # so results are placed by their (sample, member) indices
        predictions_needed = R_past.shape[0] * ensemble_members
        for _ in range(predictions_needed):
            (R_pred_sample, i, j) = self.output_queue.get()
            R_pred[i,...,j] = R_pred_sample

        return R_pred

    def __del__(self):
        # one None sentinel per worker stops its loop; then join all
        for _ in range(self.num_procs):
            self.input_queue.put(None)
        self.compute_procs.join()
254
+
255
+
256
def _compute_process(process_index, input_queue, output_queue, kwargs):
    """Worker loop for ForecastDistributed.

    Builds a Forecast on the GPU matching this process index (CPU when
    no GPU is available) and serves requests from input_queue until a
    None sentinel arrives.
    """
    gpu = process_index if (torch.cuda.device_count() > 0) else None
    fc = Forecast(gpu=gpu, **kwargs)
    output_queue.put("Ready")  # signal readiness to the parent process

    while True:
        data = input_queue.get()
        if data is None:  # shutdown sentinel
            break
        (R_past, num_diffusion_iters, sample, member) = data
        R_pred = fc(R_past, num_diffusion_iters=num_diffusion_iters)
        output_queue.put((R_pred, sample, member))
ldcast/models/autoenc/autoenc.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import pytorch_lightning as pl
4
+
5
+ from ..distributions import kl_from_standard_normal, ensemble_nll_normal
6
+ from ..distributions import sample_from_standard_normal
7
+
8
+
9
class AutoencoderKL(pl.LightningModule):
    """Variational autoencoder with a KL-regularized latent space.

    The encoder output is projected (1x1x1 conv) to 2*hidden_width
    channels, interpreted as mean and log-variance of a diagonal
    Gaussian posterior; a latent of hidden_width channels is projected
    back to encoded_channels and decoded. Trained with an L1
    reconstruction loss plus kl_weight * KL(posterior || N(0, I)).
    """
    def __init__(
        self,
        encoder, decoder,
        kl_weight=0.01,
        encoded_channels=64,
        hidden_width=32,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_width = hidden_width
        # 1x1x1 convolutions projecting to/from the latent moments
        self.to_moments = nn.Conv3d(encoded_channels, 2*hidden_width,
            kernel_size=1)
        self.to_decoder = nn.Conv3d(hidden_width, encoded_channels,
            kernel_size=1)
        # learned scalar log-variance; not referenced in the visible loss
        self.log_var = nn.Parameter(torch.zeros(size=()))
        self.kl_weight = kl_weight

    def encode(self, x):
        """Encode x to the posterior moments (mean, log_var)."""
        h = self.encoder(x)
        (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1)
        return (mean, log_var)

    def decode(self, z):
        """Decode a latent z back to data space."""
        z = self.to_decoder(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full autoencoding pass.

        Samples the posterior unless sample_posterior is False (then the
        posterior mean is used). Returns (reconstruction, mean, log_var).
        """
        (mean, log_var) = self.encode(input)
        if sample_posterior:
            z = sample_from_standard_normal(mean, log_var)
        else:
            z = mean
        dec = self.decode(z)
        return (dec, mean, log_var)

    def _loss(self, batch):
        """Compute (total, reconstruction, KL) losses for one batch."""
        (x,y) = batch
        # unwrap nested [[tensor, ...]] batch structures down to the
        # raw input tensor
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0][0]
        (y_pred, mean, log_var) = self.forward(x)

        rec_loss = (y-y_pred).abs().mean()  # L1 reconstruction
        kl_loss = kl_from_standard_normal(mean, log_var)

        total_loss = rec_loss + self.kl_weight * kl_loss

        return (total_loss, rec_loss, kl_loss)

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)[0]
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def val_test_step(self, batch, batch_idx, split="val"):
        """Shared validation/test step; logs epoch-level losses."""
        (total_loss, rec_loss, kl_loss) = self._loss(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log(f"{split}_loss", total_loss, **log_params)
        self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params)
        self.log(f"{split}_kl_loss", kl_loss, **log_params)

    def validation_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="val")

    def test_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="test")

    def configure_optimizers(self):
        """AdamW with plateau-based LR reduction monitored on the
        validation reconstruction loss."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3,
            betas=(0.5, 0.9), weight_decay=1e-3)
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_rec_loss",
                "frequency": 1,
            },
        }
ldcast/models/autoenc/encoder.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn as nn
3
+
4
+ from ..blocks.resnet import ResBlock3D
5
+ from ..utils import activation, normalization
6
+
7
+
8
class SimpleConvEncoder(nn.Sequential):
    """Convolutional encoder: `levels` stages, each a 3D residual block
    followed by a strided 3D convolution that halves every dimension."""

    def __init__(self, in_dim=1, levels=2, min_ch=64):
        # channel counts: in_dim, then 8**k clipped from below at min_ch
        channels = np.hstack([
            in_dim,
            (8**np.arange(1,levels+1)).clip(min=min_ch)
        ])

        layers = []
        for level in range(levels):
            ch_in = int(channels[level])
            ch_out = int(channels[level+1])
            # only the first stage mixes along the time axis
            kernel = (3,3,3) if level == 0 else (1,3,3)
            layers.append(ResBlock3D(
                ch_in, ch_out,
                kernel_size=kernel,
                norm_kwargs={"num_groups": 1}
            ))
            layers.append(nn.Conv3d(ch_out, ch_out,
                kernel_size=(2,2,2), stride=(2,2,2)))

        super().__init__(*layers)
32
+
33
+
34
class SimpleConvDecoder(nn.Sequential):
    """Convolutional decoder mirroring SimpleConvEncoder: each stage is
    a transposed 3D convolution that doubles every dimension, followed
    by a 3D residual block."""

    def __init__(self, in_dim=1, levels=2, min_ch=64):
        # channel counts: in_dim, then 8**k clipped from below at min_ch
        channels = np.hstack([
            in_dim,
            (8**np.arange(1,levels+1)).clip(min=min_ch)
        ])

        layers = []
        for level in reversed(range(levels)):
            ch_in = int(channels[level+1])
            ch_out = int(channels[level])
            layers.append(nn.ConvTranspose3d(ch_in, ch_in,
                kernel_size=(2,2,2), stride=(2,2,2)))
            # only the final stage mixes along the time axis
            kernel = (3,3,3) if (level == 0) else (1,3,3)
            layers.append(ResBlock3D(
                ch_in, ch_out,
                kernel_size=kernel,
                norm_kwargs={"num_groups": 1}
            ))

        super().__init__(*layers)
ldcast/models/autoenc/training.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+
4
+ from . import autoenc
5
+
6
+
7
+ def setup_autoenc_training(
8
+ encoder,
9
+ decoder,
10
+ model_dir,
11
+ ):
12
+ autoencoder = autoenc.AutoencoderKL(encoder, decoder)
13
+
14
+ num_gpus = torch.cuda.device_count()
15
+ accelerator = "gpu" if (num_gpus > 0) else "cpu"
16
+ devices = torch.cuda.device_count() if (accelerator == "gpu") else 1
17
+
18
+ early_stopping = pl.callbacks.EarlyStopping(
19
+ "val_rec_loss", patience=6, verbose=True
20
+ )
21
+ print(model_dir)
22
+ checkpoint = pl.callbacks.ModelCheckpoint(
23
+ dirpath=model_dir,
24
+ filename="{epoch}-{val_rec_loss:.4f}",
25
+ #filename=ckpt,
26
+ monitor="val_rec_loss",
27
+ every_n_epochs=1,
28
+ save_top_k=3,
29
+ save_weights_only=False,
30
+ )
31
+ callbacks = [early_stopping, checkpoint]
32
+
33
+ trainer = pl.Trainer(
34
+ accelerator=accelerator,
35
+ devices=devices,
36
+ max_epochs=1000,
37
+ #strategy='ddp' if (num_gpus > 1) else None,
38
+ callbacks=callbacks,
39
+ )
40
+
41
+ return (autoencoder, trainer)
ldcast/models/benchmarks/dgmr.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+
6
+
7
class DGMRModel:
    """Wrapper around a DeepMind DGMR nowcasting SavedModel.

    Loads the model under a (possibly multi-GPU) TF distribution
    strategy and exposes a numpy-in / numpy-out callable that produces
    one stochastic forecast per call.
    """
    def __init__(
        self,
        model_handle,
        multi_gpu=True,
        transform_to_rainrate=None,
        transform_from_rainrate=None,
        data_format='channels_first',
        calibrated=False,
    ):
        self.transform_to_rainrate = transform_to_rainrate
        self.transform_from_rainrate = transform_from_rainrate
        self.data_format = data_format
        self.calibrated = calibrated

        if multi_gpu and len(tf.config.list_physical_devices('GPU')) > 1:
            # initialize multi-GPU strategy
            strategy = tf.distribute.MirroredStrategy()
        else: # use default strategy
            strategy = tf.distribute.get_strategy()

        with strategy.scope():
            module = tf.saved_model.load(model_handle)

        self.model = module.signatures['default']
        # infer the noise size and conditioning length from the signature
        input_signature = self.model.structured_input_signature[1]
        self.noise_dim = input_signature['z'].shape[1]
        self.past_timesteps = input_signature['labels$cond_frames'].shape[1]

    def __call__(self, x):
        """Generate one forecast for batch x.

        Input/output layout follows self.data_format; only the predicted
        (future) frames are returned, conditioning frames are stripped.
        """
        # unwrap nested list/tuple batch structures
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0]
        x = np.array(x, copy=False)
        if self.data_format == "channels_first":
            x = x.transpose(0,2,3,4,1)  # the model expects channels last
        if self.transform_to_rainrate is not None:
            x = self.transform_to_rainrate(x)
        x = tf.convert_to_tensor(x)

        num_samples = x.shape[0]
        z = tf.random.normal(shape=(num_samples, self.noise_dim))
        if self.calibrated:
            # widen the latent noise in calibrated mode
            # NOTE(review): factor 2.0 presumably from calibration
            # experiments -- confirm
            z = z * 2.0

        onehot = tf.ones(shape=(num_samples, 1))
        inputs = {
            "z": z,
            "labels$onehot" : onehot,
            "labels$cond_frames" : x
        }
        y = self.model(**inputs)['default']
        # the model output contains conditioning + forecast frames;
        # keep only the forecast part
        y = y[:,self.past_timesteps:,...]

        y = np.array(y)
        if self.transform_from_rainrate is not None:
            y = self.transform_from_rainrate(y)
        if self.data_format == "channels_first":
            y = y.transpose(0,4,1,2,3)

        return y
67
+
68
+
69
def create_ensemble(
    dgmr, x,
    ensemble_size=32,
    model_path="../models/dgmr/256x256",

):
    """Generate an ensemble by sampling `dgmr` repeatedly on `x` and
    stacking the members along a new trailing axis.

    `model_path` is accepted for interface compatibility but not used
    by this function.
    """
    members = []
    for k in range(ensemble_size):
        print(f"Generating member {k+1}/{ensemble_size}")
        members.append(dgmr(x))
        gc.collect()  # free per-member buffers between generations

    return np.stack(members, axis=-1)
ldcast/models/benchmarks/pysteps.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # following https://pysteps.readthedocs.io/en/stable/auto_examples/plot_steps_nowcast.html
2
+ from datetime import timedelta
3
+
4
+ import dask
5
+ import numpy as np
6
+ from pysteps import nowcasts
7
+ from pysteps.motion.lucaskanade import dense_lucaskanade
8
+ from pysteps.utils import transformation
9
+
10
+
11
class PySTEPSModel:
    """STEPS ensemble nowcast wrapper (pysteps), numpy in / numpy out.

    Converts model-space inputs to rain rate, runs the stochastic STEPS
    nowcast per sample (optionally in parallel via dask threads) and
    converts the result back to model space.
    """
    def __init__(
        self,
        data_format='channels_first',
        future_timesteps=20,
        ensemble_size=32,
        km_per_pixel=1.0,
        interval=timedelta(minutes=5),
        transform_to_rainrate=None,
        transform_from_rainrate=None,
    ):
        self.transform_to_rainrate = transform_to_rainrate
        self.transform_from_rainrate = transform_from_rainrate
        self.data_format = data_format
        self.nowcast_method = nowcasts.get_method("steps")
        self.future_timesteps = future_timesteps
        self.ensemble_size = ensemble_size
        self.km_per_pixel = km_per_pixel
        self.interval = interval

    def zero_prediction(self, R, zerovalue):
        """All-`zerovalue` forecast of shape (T, H, W, ensemble)."""
        out_shape = (self.future_timesteps,) + R.shape[1:] + \
            (self.ensemble_size,)
        return np.full(out_shape, zerovalue, dtype=R.dtype)

    def predict_sample(self, x, threshold=-10.0, zerovalue=-15.0):
        """Nowcast one (T, H, W) sample.

        threshold / zerovalue are in the dB-transformed space used by
        pysteps. Falls back to an all-zero forecast when there is no
        precipitation or when STEPS fails with a known low-precip error.
        """
        R = self.transform_to_rainrate(x)
        (R, _) = transformation.dB_transform(
            R, threshold=0.1, zerovalue=zerovalue
        )
        R[~np.isfinite(R)] = zerovalue
        if (R == zerovalue).all():
            # no precipitation anywhere: STEPS cannot run, predict zeros
            R_f = self.zero_prediction(R, zerovalue)
        else:
            V = dense_lucaskanade(R)  # advection (motion) field
            try:
                R_f = self.nowcast_method(
                    R,
                    V,
                    self.future_timesteps,
                    n_ens_members=self.ensemble_size,
                    n_cascade_levels=6,
                    precip_thr=threshold,
                    kmperpixel=self.km_per_pixel,
                    timestep=self.interval.total_seconds()/60,
                    noise_method="nonparametric",
                    vel_pert_method="bps",
                    mask_method="incremental",
                    num_workers=2
                )
                # (ens, T, H, W) -> (T, H, W, ens)
                R_f = R_f.transpose(1,2,3,0)
            except (ValueError, RuntimeError) as e:
                # match known pysteps failure messages by their text
                zero_error = str(e).endswith("contains non-finite values") or \
                    str(e).startswith("zero-size array to reduction operation") or \
                    str(e).endswith("nonstationary AR(p) process")
                if zero_error:
                    # occasional PySTEPS errors that happen with little/no precip
                    # therefore returning all zeros makes sense
                    R_f = self.zero_prediction(R, zerovalue)
                else:
                    raise

        # Back-transform to rain rates
        R_f = transformation.dB_transform(
            R_f, threshold=threshold, inverse=True
        )[0]

        if self.transform_from_rainrate is not None:
            R_f = self.transform_from_rainrate(R_f)

        return R_f

    def __call__(self, x, parallel=True):
        """Nowcast a batch; with parallel=True the samples run as dask
        threads (one worker per sample)."""
        # unwrap nested list/tuple batch structures
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0]
        x = np.array(x, copy=False)
        if self.data_format == "channels_first":
            x = x.transpose(0,2,3,4,1)

        pred = self.predict_sample
        if parallel:
            pred = dask.delayed(pred)
        y = [
            pred(x[i,:,:,:,0])
            for i in range(x.shape[0])
        ]
        if parallel:
            y = dask.compute(y, scheduler="threads", num_workers=len(y))[0]
        y = np.stack(y, axis=0)

        # restore the channel axis in the requested layout
        if self.data_format == "channels_first":
            y = np.expand_dims(y, 1)
        else:
            y = np.expand_dims(y, -2)

        return y
ldcast/models/benchmarks/transform.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def transform_to_rainrate(x, mean=-0.051, std=0.528, threshold=0.1):
5
+ x = x*std + mean
6
+ R = 10**x
7
+ R[R < threshold] = 0
8
+ return R
9
+
10
+
11
+ def transform_from_rainrate(
12
+ R, mean=-0.051, std=0.528,
13
+ threshold=0.1, fill_value=0.02
14
+ ):
15
+ R = R.copy()
16
+ R[R < threshold] = fill_value
17
+ return (np.log10(R)-mean) / std
ldcast/models/blocks/afno.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #reference: https://github.com/NVlabs/AFNO-transformer
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ from .attention import TemporalAttention
9
+
10
class Mlp(nn.Module):
    """Position-wise feed-forward network.

    Linear -> activation -> dropout -> Linear -> dropout; hidden and
    output widths default to the input width.
    """
    def __init__(
        self,
        in_features, hidden_features=None, out_features=None,
        act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        # fall back to the input width when widths are not given
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # skip the Dropout module entirely when the rate is zero
        self.drop = nn.Identity() if not drop > 0 else nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
31
+
32
+
33
class AFNO2D(nn.Module):
    """Adaptive Fourier Neural Operator channel mixing over 2D (H, W) data.

    Tokens are mixed in the Fourier domain by a block-diagonal complex-valued
    two-layer MLP applied to the FFT coefficients, followed by soft-shrinkage
    sparsification, an inverse FFT and a residual connection.
    Reference: https://github.com/NVlabs/AFNO-transformer
    """
    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold  # lambda for softshrink
        self.num_blocks = num_blocks  # number of diagonal blocks in the mixing weights
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction  # fraction of Fourier modes kept
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02  # std of the weight initialization

        # complex weights stored as separate real/imaginary parts (index 0/1)
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # x: (B, H, W, C) channels-last; the input also serves as the residual
        bias = x

        dtype = x.dtype
        x = x.float()  # do the FFT in float32 regardless of input dtype
        B, H, W, C = x.shape

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        # NOTE(review): the mode count is derived from H but used to slice both
        # dim 1 (H) and dim 2 (W//2+1); this mirrors the upstream AFNO code
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # first complex linear layer (block-diagonal) with ReLU, applied only
        # to the kept modes; (a+bi)(c+di): real = ac - bd, imag = ad + bc
        o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        # second complex linear layer (no activation)
        o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        # sparsify in the Fourier domain, then transform back
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
        x = x.type(dtype)

        return x + bias
101
+
102
+
103
class Block(nn.Module):
    """Transformer-style block using an AFNO2D filter as the token mixer.

    Pre-norm residual structure: norm -> AFNO filter (+skip), then
    norm -> MLP (+skip). With double_skip=False the first skip is omitted
    and a single residual spans the whole block.
    """
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(
            dim, num_blocks, sparsity_threshold, hard_thresholding_fraction
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x):
        skip = x
        x = self.filter(self.norm1(x))
        if self.double_skip:
            x = x + skip
            skip = x
        return self.mlp(self.norm2(x)) + skip
138
+
139
+
140
+
141
class AFNO3D(nn.Module):
    """Adaptive Fourier Neural Operator channel mixing over 3D (D, H, W) data.

    3D analogue of AFNO2D: a block-diagonal complex-valued two-layer MLP on
    the FFT coefficients, soft-shrinkage sparsification, inverse FFT and a
    residual connection.
    """
    def __init__(
        self, hidden_size, num_blocks=8, sparsity_threshold=0.01,
        hard_thresholding_fraction=1, hidden_size_factor=1
    ):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold  # lambda for softshrink
        self.num_blocks = num_blocks  # number of diagonal blocks in the mixing weights
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction  # fraction of Fourier modes kept
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02  # std of the weight initialization

        # complex weights stored as separate real/imaginary parts (index 0/1)
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # x: (B, D, H, W, C) channels-last; the input also serves as the residual
        bias = x

        dtype = x.dtype
        x = x.float()  # do the FFT in float32 regardless of input dtype
        B, D, H, W, C = x.shape

        x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho")
        x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        # NOTE(review): the mode count is derived from H but used to slice both
        # dim 2 (H) and dim 3 (W//2+1); this mirrors the upstream AFNO code
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # first complex linear layer (block-diagonal) with ReLU, applied only
        # to the kept modes; (a+bi)(c+di): real = ac - bd, imag = ad + bc
        o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        # second complex linear layer (no activation)
        o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        # sparsify in the Fourier domain, then transform back
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, D, H, W // 2 + 1, C)
        x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho")
        x = x.type(dtype)

        return x + bias
212
+
213
+
214
class AFNOBlock3d(nn.Module):
    """AFNO3D filter plus an MLP, each in a pre-norm residual connection.

    Accepts channels-last (B, D, H, W, C) input natively; with
    data_format="channels_first" the input is permuted on the way in
    and the output permuted back.
    """
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
        data_format="channels_last",
        mlp_out_features=None,
    ):
        super().__init__()
        self.norm_layer = norm_layer
        self.norm1 = norm_layer(dim)
        self.filter = AFNO3D(
            dim, num_blocks, sparsity_threshold, hard_thresholding_fraction
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            out_features=mlp_out_features,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.double_skip = double_skip
        self.channels_first = (data_format == "channels_first")

    def forward(self, x):
        # AFNO natively uses a channels-last data format
        if self.channels_first:
            x = x.permute(0, 2, 3, 4, 1)

        skip = x
        x = self.filter(self.norm1(x))
        if self.double_skip:
            x = x + skip
            skip = x
        x = self.mlp(self.norm2(x)) + skip

        if self.channels_first:
            x = x.permute(0, 4, 1, 2, 3)
        return x
265
+
266
+
267
class PatchEmbed3d(nn.Module):
    """Tokenize a 3D volume into non-overlapping patch embeddings.

    Input is channels-first (B, C, D, H, W); output is channels-last
    (B, D/p0, H/p1, W/p2, embed_dim).
    """
    def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        # a strided convolution projects each patch to one embedding vector
        self.proj = nn.Conv3d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        tokens = self.proj(x)
        # convert to channels-last (BDHWC)
        return tokens.permute(0, 2, 3, 4, 1)
277
+
278
+
279
class PatchExpand3d(nn.Module):
    """Inverse of PatchEmbed3d: expand embeddings back to voxel patches.

    Input is channels-last (B, D, H, W, embed_dim); output is channels-first
    (B, out_chans, D*p0, H*p1, W*p2).
    """
    def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        # each embedding vector is projected to one full patch of output voxels
        self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size))

    def forward(self, x):
        x = self.proj(x)
        # unfold the flattened patch channels back into the spatial dims
        x = rearrange(
            x,
            "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)",
            p0=self.patch_size[0],
            p1=self.patch_size[1],
            p2=self.patch_size[2],
            d=x.shape[1],
            h=x.shape[2],
            w=x.shape[3],
        )
        return x
298
+
299
+
300
class AFNOCrossAttentionBlock3d(nn.Module):
    """AFNO 3D block mixing channels from an input and a context source.

    The (normalized) input x and context y are concatenated on the channel
    axis, passed through a linear pre-projection, an AFNO3D filter and an
    MLP that maps back to the input width — each stage with a residual.
    The `timesteps` argument is accepted but unused here.
    """
    def __init__(
        self,
        dim,
        context_dim,
        mlp_ratio=2.,
        drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.Identity,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
        data_format="channels_last",
        timesteps=None
    ):
        super().__init__()

        joint_dim = dim + context_dim
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(joint_dim)
        self.pre_proj = nn.Linear(joint_dim, joint_dim)
        self.filter = AFNO3D(
            joint_dim, num_blocks, sparsity_threshold,
            hard_thresholding_fraction
        )
        self.mlp = Mlp(
            in_features=joint_dim,
            out_features=dim,
            hidden_features=int(joint_dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.channels_first = (data_format == "channels_first")

    def forward(self, x, y):
        if self.channels_first:
            # AFNO natively uses a channels-last order
            x = x.permute(0, 2, 3, 4, 1)
            y = y.permute(0, 2, 3, 4, 1)

        xy = torch.concat((self.norm1(x), y), axis=-1)
        xy = xy + self.pre_proj(xy)
        xy = xy + self.filter(self.norm2(xy))  # AFNO filter
        x = x + self.mlp(xy)                   # feed-forward

        if self.channels_first:
            x = x.permute(0, 4, 1, 2, 3)

        return x
ldcast/models/blocks/attention.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class TemporalAttention(nn.Module):
    """Multi-head (cross-)attention over the time axis of (B, T, H, W, C)
    tensors, computed independently at every spatial location.

    Queries come from `x` (width `channels`); keys/values come from `y`
    (width `context_channels`, defaulting to `channels`). When y is None
    this is self-attention.
    """
    def __init__(
        self, channels, context_channels=None,
        head_dim=32, num_heads=8
    ):
        super().__init__()
        self.channels = channels
        if context_channels is None:
            context_channels = channels
        self.context_channels = context_channels
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.inner_dim = head_dim * num_heads
        self.attn_scale = self.head_dim ** -0.5  # 1/sqrt(d_k) scaling
        if channels % num_heads:
            raise ValueError("channels must be divisible by num_heads")
        # fused key/value projection; split with chunk() in forward
        self.KV = nn.Linear(context_channels, self.inner_dim*2)
        self.Q = nn.Linear(channels, self.inner_dim)
        self.proj = nn.Linear(self.inner_dim, channels)

    def forward(self, x, y=None):
        # default to self-attention when no context is given
        if y is None:
            y = x

        (K,V) = self.KV(y).chunk(2, dim=-1)
        (B, Dk, H, W, C) = K.shape
        # split the channel axis into heads: (B, Dk, H, W, heads, head_dim)
        shape = (B, Dk, H, W, self.num_heads, self.head_dim)
        K = K.reshape(shape)
        V = V.reshape(shape)

        Q = self.Q(x)
        (B, Dq, H, W, C) = Q.shape  # C == inner_dim after the projection
        shape = (B, Dq, H, W, self.num_heads, self.head_dim)
        Q = Q.reshape(shape)

        # move the time axis into matmul position so attention runs over time
        K = K.permute((0,2,3,4,5,1)) # K^T
        V = V.permute((0,2,3,4,1,5))
        Q = Q.permute((0,2,3,4,1,5))

        # scaled dot-product attention
        attn = torch.matmul(Q, K) * self.attn_scale
        attn = F.softmax(attn, dim=-1)
        y = torch.matmul(attn, V)
        # back to (B, Dq, H, W, inner_dim), then project to `channels`
        y = y.permute((0,4,1,2,3,5))
        y = y.reshape((B,Dq,H,W,C))
        y = self.proj(y)
        return y
54
+
55
+
56
class TemporalTransformer(nn.Module):
    """Pre-norm transformer layer over the time axis.

    Temporal self-attention, then cross-attention against context `y`,
    then an MLP — each wrapped in a residual connection.
    """
    def __init__(self,
        channels,
        mlp_dim_mul=1,
        **kwargs
    ):
        super().__init__()
        self.attn1 = TemporalAttention(channels, **kwargs)
        self.attn2 = TemporalAttention(channels, **kwargs)
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        self.norm3 = nn.LayerNorm(channels)
        self.mlp = MLP(channels, dim_mul=mlp_dim_mul)

    def forward(self, x, y):
        x = x + self.attn1(self.norm1(x))     # self attention
        x = x + self.attn2(self.norm2(x), y)  # cross attention
        return x + self.mlp(self.norm3(x))    # feed-forward
74
+
75
+
76
class MLP(nn.Sequential):
    """Two-layer feed-forward network: Linear -> SiLU -> Linear.

    The hidden width is dim * dim_mul; input and output widths are `dim`.
    """
    def __init__(self, dim, dim_mul=4):
        hidden = dim * dim_mul
        super().__init__(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )
+
86
+
87
def positional_encoding(position, dims, add_dims=()):
    """Sinusoidal positional encoding.

    For 1D `position` of shape (N,) returns shape (N, dims); for 2D
    position (N, M) returns (N, M, dims). The first dims//2 channels are
    sines, the rest cosines, at geometrically spaced frequencies.
    Extra singleton axes are inserted at each index in `add_dims`.
    """
    freqs = torch.exp(
        torch.arange(0, dims, 2, device=position.device)
        * (-math.log(10000.0) / dims)
    )
    if position.ndim == 1:
        phase = position.unsqueeze(-1) * freqs.unsqueeze(0)
    else:
        phase = position.unsqueeze(-1) * freqs.reshape(1, 1, -1)

    enc = torch.concat(
        (torch.sin(phase), torch.cos(phase)),
        dim=-1
    )
    for extra_dim in add_dims:
        enc = enc.unsqueeze(extra_dim)
    return enc
ldcast/models/blocks/resnet.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torch.nn.utils.parametrizations import spectral_norm as sn
3
+
4
+ from ..utils import activation, normalization
5
+
6
+
7
class ResBlock3D(nn.Module):
    """3D residual block: (norm -> act -> conv) twice, plus a skip path.

    The skip path uses a 1x1x1 projection when the channel count changes and
    is resampled to match the main path. `resample` may be None, "down"
    (strided conv on the main path, average pooling on the skip) or "up"
    (transposed conv on the main path, trilinear upsampling on the skip),
    with the given `resample_factor`. Optionally applies spectral
    normalization to the learnable convolutions.
    """
    def __init__(
        self, in_channels, out_channels, resample=None,
        resample_factor=(1,1,1), kernel_size=(3,3,3),
        act='swish', norm='group', norm_kwargs=None,
        spectral_norm=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        # 1x1x1 projection aligns channel counts on the skip path
        if in_channels != out_channels:
            self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.proj = nn.Identity()

        # "same"-style padding for odd kernel sizes
        padding = tuple(k//2 for k in kernel_size)
        if resample == "down":
            self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True)
            self.conv1 = nn.Conv3d(in_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor, padding=padding)
            self.conv2 = nn.Conv3d(out_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
        elif resample == "up":
            self.resample = nn.Upsample(
                scale_factor=resample_factor, mode='trilinear')
            self.conv1 = nn.ConvTranspose3d(in_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
            # output_padding compensates so the transposed conv output size
            # matches input size * resample_factor
            output_padding = tuple(
                2*p+s-k for (p,s,k) in zip(padding,resample_factor,kernel_size)
            )
            self.conv2 = nn.ConvTranspose3d(out_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor,
                padding=padding, output_padding=output_padding)
        else:
            # no resampling
            self.resample = nn.Identity()
            self.conv1 = nn.Conv3d(in_channels, out_channels,
                kernel_size=kernel_size, padding=padding)
            self.conv2 = nn.Conv3d(out_channels, out_channels,
                kernel_size=kernel_size, padding=padding)

        # separate activations may be specified for the two conv stages
        if isinstance(act, str):
            act = (act, act)
        self.act1 = activation(act_type=act[0])
        self.act2 = activation(act_type=act[1])

        if norm_kwargs is None:
            norm_kwargs = {}
        self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs)
        self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs)
        # optionally wrap the learnable convolutions in spectral norm
        if spectral_norm:
            self.conv1 = sn(self.conv1)
            self.conv2 = sn(self.conv2)
            if not isinstance(self.proj, nn.Identity):
                self.proj = sn(self.proj)


    def forward(self, x):
        # skip path: project channels, then resample to the output size
        x_in = self.resample(self.proj(x))
        # main path: pre-activation ordering (norm -> act -> conv)
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.conv2(x)
        return x + x_in
ldcast/models/diffusion/diffusion.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py
3
+ Pared down to simplify code.
4
+
5
+ The original file acknowledges:
6
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
7
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
8
+ https://github.com/CompVis/taming-transformers
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import pytorch_lightning as pl
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from torchmetrics import MeanSquaredError
18
+
19
+ from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding
20
+ from .ema import LitEma
21
+ from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d
22
+
23
+
24
class LatentDiffusion(pl.LightningModule):
    """Denoising diffusion trained in the latent space of a frozen autoencoder.

    Targets are encoded with `autoencoder`; noise is added according to a
    fixed-variance beta schedule and `model` learns to predict either the
    noise ("eps" parameterization) or the clean latent ("x0").

    Parameters
    ----------
    model : nn.Module
        Denoiser, called as ``model(x_noisy, t, context=cond)``.
    autoencoder : nn.Module
        Pretrained autoencoder providing ``encode``; its weights are frozen.
    context_encoder : nn.Module, optional
        Encodes the conditioning input; if None the model is unconditional.
    timesteps : int
        Number of diffusion steps in the schedule.
    beta_schedule : str
        Schedule type passed to ``make_beta_schedule`` (e.g. "linear").
    loss_type : str
        "l1" or "l2".
    use_ema : bool
        Maintain an exponential moving average of the model weights.
    lr : float
        Base learning rate.
    lr_warmup : int
        Number of steps of linear learning-rate warmup.
    parameterization : str
        "eps" or "x0".
    """
    def __init__(self,
        model,
        autoencoder,
        context_encoder=None,
        timesteps=1000,
        beta_schedule="linear",
        loss_type="l2",
        use_ema=True,
        lr=1e-4,
        lr_warmup=0,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        parameterization="eps",  # all assuming fixed variance schedules
    ):
        super().__init__()
        self.model = model
        # the autoencoder is used only as a frozen feature extractor
        self.autoencoder = autoencoder.requires_grad_(False)
        self.conditional = (context_encoder is not None)
        self.context_encoder = context_encoder
        self.lr = lr
        self.lr_warmup = lr_warmup

        self.val_loss = MeanSquaredError()

        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)

        self.register_schedule(
            beta_schedule=beta_schedule, timesteps=timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )

        self.loss_type = loss_type

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Precompute the beta/alpha schedule and register it as buffers."""
        betas = make_beta_schedule(
            beta_schedule, timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into the model; restore on exit."""
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def apply_model(self, x_noisy, t, cond=None, return_ids=False):
        """Run the denoiser (with EMA weights) on a noisy latent."""
        if self.conditional:
            cond = self.context_encoder(cond)
        with self.ema_scope():
            return self.model(x_noisy, t, context=cond)

    def q_sample(self, x_start, t, noise=None):
        """Sample from q(x_t | x_0): add schedule-scaled noise to x_start."""
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_loss(self, pred, target, mean=True):
        """Compute the l1/l2 training loss, optionally reduced to a mean."""
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            # bug fix: the message previously lacked the f-prefix and always
            # printed the literal text '{loss_type}'
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None, context=None):
        """Diffusion loss for a batch of clean latents at timesteps t."""
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t, context=context)

        # the training target depends on the parameterization
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        return self.get_loss(model_out, target, mean=False).mean()

    def forward(self, x, *args, **kwargs):
        # draw a uniformly random timestep for each sample in the batch
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        (x, y) = batch
        # diffuse in latent space: encode the target with the frozen autoencoder
        y = self.autoencoder.encode(y)[0]
        context = self.context_encoder(x) if self.conditional else None
        return self(y, context=context)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        with self.ema_scope():
            loss_ema = self.shared_step(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log("val_loss", loss, **log_params)
        # bug fix: previously `loss` (not `loss_ema`) was logged here, so the
        # EMA metric — also the LR-scheduler monitor — duplicated val_loss
        self.log("val_loss_ema", loss_ema, **log_params)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        # update the EMA weights after every training batch
        if self.use_ema:
            self.model_ema(self.model)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
            betas=(0.5, 0.9), weight_decay=1e-3)
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss_ema",
                "frequency": 1,
            },
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        #optimizer_closure,
        **kwargs
    ):
        # linear learning-rate warmup over the first lr_warmup steps
        if self.trainer.global_step < self.lr_warmup:
            lr_scale = (self.trainer.global_step+1) / self.lr_warmup
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.lr

        super().optimizer_step(
            epoch, batch_idx, optimizer,
            optimizer_idx,
            #optimizer_closure,
            **kwargs
        )
+ )
222
+
ldcast/models/diffusion/ema.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies of all parameters with requires_grad are kept as buffers.
    Call the instance with the model after each training step to update the
    shadow weights; use store/copy_to/restore to evaluate with EMA weights.
    """
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # maps parameter names to their buffer-safe shadow names
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates < 0 disables the update-count-based decay warmup
        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
                             else torch.tensor(-1,dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                #remove as '.'-character is not allowed in buffers
                s_name = name.replace('.','')
                self.m_name2s_name.update({name:s_name})
                self.register_buffer(s_name,p.clone().detach().data)

        self.collected_params = []

    def forward(self,model):
        """Update the shadow parameters towards the model's current values."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # ramp the effective decay up during early training
            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # in-place: shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) parameters into the model's parameters."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
ldcast/models/diffusion/plms.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
3
+ """
4
+
5
+
6
+ """SAMPLING ONLY."""
7
+
8
+ import torch
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
13
+
14
+
15
class PLMSSampler:
    """
    Pseudo Linear Multistep (PLMS) sampler for a trained diffusion model.

    Adapted from the CompVis latent-diffusion implementation. The wrapped
    `model` must expose `num_timesteps`, `betas`, `alphas_cumprod`,
    `alphas_cumprod_prev`, `device` and an `apply_model(x, t, c)` method.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Unlike nn.Module.register_buffer, this simply stores the attribute
        # on the sampler; the upstream forced-CUDA transfer is kept disabled.
        #if type(attr) == torch.Tensor:
        #    if attr.device != torch.device("cuda"):
        #        attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM-style alpha/sigma schedule used by the PLMS steps."""
        if ddim_eta != 0:
            # PLMS is derived for the deterministic (eta = 0) DDIM process
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # move schedule tensors to the model's device in float32
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               progbar=True,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """
        Draw `batch_size` samples of `shape` using `S` PLMS steps.

        Returns (samples, intermediates) where `intermediates` holds the
        periodically logged `x_inter`/`pred_x0` tensors.
        """
        # upstream batch-size sanity check, kept disabled:
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
        """

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        size = (batch_size,) + shape
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    progbar=progbar
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, progbar=True):
        """
        Run the reverse PLMS loop from pure noise (or `x_T`) down to x_0.

        Keeps a window of up to 3 previous noise predictions (`old_eps`) for
        the multistep (Adams-Bashforth) update in `p_sample_plms`.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # caller requested a truncated schedule: keep only the first part
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # iterate timesteps from high noise to low noise
        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = time_range
        if progbar:
            iterator = tqdm(iterator, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # t_next is needed by the 2nd-order startup step (first iteration)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: keep masked region pinned to the noised original
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                # only the 3 most recent predictions are needed for 4th order
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        """
        Single PLMS update: predict noise e_t, combine it with up to three
        previous predictions (linear multistep), and take one DDIM-style step.

        Returns (x_prev, pred_x0, e_t).
        """
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # classifier-free guidance: run conditional and unconditional
            # passes in one batch and blend the predictions
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            param_shape = (b,) + (1,)*(x.ndim-1)
            a_t = torch.full(param_shape, alphas[index], device=device)
            a_prev = torch.full(param_shape, alphas_prev[index], device=device)
            sigma_t = torch.full(param_shape, sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(param_shape, sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
ldcast/models/diffusion/utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+ import os
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ from einops import repeat
16
+
17
+
18
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a DDPM noise (beta) schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion timesteps.
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset for the cosine schedule
        (Nichol & Dhariwal, "Improved DDPM").
    :return: np.ndarray of shape (n_timestep,), dtype float64.
    :raises ValueError: if `schedule` is not recognized.
    """
    if schedule == "linear":
        # linear in sqrt(beta) space, then squared
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Fix: the original applied np.clip to a torch.Tensor, which relies on
        # fragile numpy<->torch ufunc interop (and can yield an ndarray, making
        # the final .numpy() call fail). Clamp with the native torch op instead.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
41
+
42
+
43
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Choose which of the `num_ddpm_timesteps` training steps the DDIM sampler
    visits, using either uniform or quadratic spacing.

    :return: np.ndarray of selected timestep indices (shifted by +1).
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(range(0, num_ddpm_timesteps, stride))
    elif ddim_discr_method == 'quad':
        # denser sampling near t=0, sparser near the end
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
58
+
59
+
60
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the (sigma, alpha, alpha_prev) triplets for the selected DDIM
    timesteps from the full cumulative-alpha schedule.

    Follows eq. (16) of Song et al., https://arxiv.org/abs/2010.02502;
    eta = 0 gives the deterministic DDIM sampler.
    """
    # alphas at the sampled steps, and the same sequence shifted back by one
    alphas = alphacums[ddim_timesteps]
    shifted = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + shifted)

    variance_ratio = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance_ratio)
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
72
+
73
+
74
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # evaluate alpha_bar on the uniform grid t = i / T and take ratios of
    # consecutive values: beta_i = 1 - abar(t_{i+1}) / abar(t_i), capped
    grid = [i / num_diffusion_timesteps for i in range(num_diffusion_timesteps + 1)]
    betas = [
        min(1 - alpha_bar(t_hi) / alpha_bar(t_lo), max_beta)
        for t_lo, t_hi in zip(grid[:-1], grid[1:])
    ]
    return np.array(betas)
91
+
92
+
93
def extract_into_tensor(a, t, x_shape):
    """
    Gather the per-timestep values `a[t]` and reshape them so they broadcast
    against a tensor of shape `x_shape` (batch first, singleton elsewhere).
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    target_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(target_shape)
97
+
98
+
99
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # checkpointing disabled: plain call
        return func(*inputs)
    args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)
114
+
115
+
116
class CheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing: the forward pass runs `run_function` without
    tracking gradients; the backward pass re-runs it under `enable_grad`
    to rebuild the graph and then differentiates.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # the first `length` args are tensor inputs; the rest are parameters
        # the function depends on implicitly
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # detach and re-enable grad on the saved inputs, then recompute
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # free recomputed graph references before returning
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # first two None grads correspond to run_function and length
        return (None, None) + input_grads
146
+
147
+
148
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, tile the raw timestep instead of using
                        sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # odd dim: append a zero column so the output has exactly `dim` columns
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
169
+
170
+
171
def zero_module(module):
    """
    Zero out the parameters of a module in place and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
178
+
179
+
180
def scale_module(module, scale):
    """
    Multiply every parameter of a module by `scale` in place and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
187
+
188
+
189
def mean_flat(tensor):
    """
    Average over every dimension except the batch (first) one.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
194
+
195
+
196
class GroupNorm32(nn.GroupNorm):
    """GroupNorm evaluated in float32, result cast back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)
199
+
200
+
201
def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # Normalization is deliberately disabled in this codebase; the upstream
    # choice would be GroupNorm32(32, channels).
    return nn.Identity()
208
+
209
+
210
def noise_like(shape, device, repeat=False):
    """
    Standard-normal noise of the given shape; if `repeat`, draw a single
    sample and tile it along the batch (first) dimension.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
214
+
215
+
216
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :raises ValueError: if `dims` is not 1, 2, or 3.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_types[dims](*args, **kwargs)
227
+
228
+
229
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
234
+
235
+
236
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :raises ValueError: if `dims` is not 1, 2, or 3.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_types[dims](*args, **kwargs)
ldcast/models/distributions.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def kl_from_standard_normal(mean, log_var):
    """
    KL divergence of N(mean, exp(log_var)) from N(0, I), averaged over all
    elements of the input tensors.
    """
    var = log_var.exp()
    kl_elementwise = 0.5 * (var + mean.square() - 1.0 - log_var)
    return kl_elementwise.mean()
8
+
9
+
10
def sample_from_standard_normal(mean, log_var, num=None):
    """
    Draw reparameterized sample(s) from N(mean, exp(log_var)).

    If `num` is given, a sample axis of that size is inserted right after the
    batch dimension, producing `num` independent draws per element.
    """
    std = torch.exp(0.5 * log_var)
    if num is None:
        noise_shape = mean.shape
    else:
        # expand channel 1 to create several samples
        noise_shape = mean.shape[:1] + (num,) + mean.shape[1:]
        mean = mean.unsqueeze(1)
        std = std.unsqueeze(1)
    return mean + std * torch.randn(noise_shape, device=mean.device)
19
+
20
+
21
def ensemble_nll_normal(ensemble, sample, epsilon=1e-5):
    """
    Gaussian negative log-likelihood of `sample` under a normal distribution
    whose mean/variance are estimated from `ensemble` along dimension 1
    (the ensemble-member axis), averaged over all elements.
    """
    mean = ensemble.mean(dim=1)
    # epsilon keeps the variance strictly positive (avoids log(0) / div-by-0)
    var = ensemble.var(dim=1, unbiased=True) + epsilon
    logvar = var.log()

    # NOTE(review): `sample[:,None,...]` has shape (batch, 1, ...) while
    # `mean` is (batch, ...); broadcasting then pairs every sample with every
    # batch element's mean instead of matching them elementwise — presumably
    # a plain `sample - mean` was intended. Verify against callers before
    # changing, as this alters the loss value.
    diff = sample[:,None,...] - mean
    logtwopi = np.log(2*np.pi)
    nll = (logtwopi + logvar + diff.square() / var).mean()
    return nll
ldcast/models/genforecast/analysis.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from ..nowcast.nowcast import AFNONowcastNetBase
6
+ from ..blocks.resnet import ResBlock3D
7
+
8
+
9
class AFNONowcastNetCascade(AFNONowcastNetBase):
    """
    AFNO nowcasting context encoder that additionally produces a resolution
    cascade: the base network's output plus `cascade_depth - 1` progressively
    pooled-and-refined feature maps, keyed by their spatial (H, W) shape.
    """

    def __init__(self, *args, cascade_depth=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.cascade_depth = cascade_depth
        self.resnet = nn.ModuleList()
        ch = self.embed_dim_out
        # channel count doubles at each coarser cascade level
        self.cascade_dims = [ch]
        for i in range(cascade_depth-1):
            ch_out = 2*ch
            self.cascade_dims.append(ch_out)
            self.resnet.append(
                ResBlock3D(ch, ch_out, kernel_size=(1,3,3), norm=None)
            )
            ch = ch_out

    def forward(self, x):
        # base AFNO features at full resolution
        x = super().forward(x)
        img_shape = tuple(x.shape[-2:])
        # dict keyed by spatial shape so consumers can pick the level
        # matching their current resolution
        cascade = {img_shape: x}
        for i in range(self.cascade_depth-1):
            # halve the spatial dimensions (temporal axis untouched), then refine
            x = F.avg_pool3d(x, (1,2,2))
            x = self.resnet[i](x)
            img_shape = tuple(x.shape[-2:])
            cascade[img_shape] = x
        return cascade
ldcast/models/genforecast/training.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+
4
+ from ..diffusion import diffusion
5
+
6
+
7
def setup_genforecast_training(
    model,
    autoencoder,
    context_encoder,
    model_dir,
    lr=1e-4
):
    """
    Assemble the latent-diffusion training module and its PyTorch Lightning
    trainer for the generative forecast model.

    Args:
        model: the denoising network trained inside the latent diffusion.
        autoencoder: the pretrained latent-space autoencoder.
        context_encoder: encoder producing the conditioning context.
        model_dir: directory where checkpoints are written.
        lr: learning rate passed to the LatentDiffusion module.

    Returns:
        (ldm, trainer) tuple: the LightningModule and the configured Trainer.
    """
    ldm = diffusion.LatentDiffusion(model, autoencoder,
        context_encoder=context_encoder, lr=lr)

    # use all available GPUs, otherwise fall back to a single CPU process
    num_gpus = torch.cuda.device_count()
    accelerator = "gpu" if (num_gpus > 0) else "cpu"
    devices = torch.cuda.device_count() if (accelerator == "gpu") else 1

    # both callbacks monitor the EMA validation loss;
    # check_finite=False tolerates transient NaN/inf validation losses
    early_stopping = pl.callbacks.EarlyStopping(
        "val_loss_ema", patience=6, verbose=True, check_finite=False
    )
    checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=model_dir,
        filename="{epoch}-{val_loss_ema:.4f}",
        monitor="val_loss_ema",
        every_n_epochs=1,
        save_top_k=3
    )
    callbacks = [early_stopping, checkpoint]

    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=300,
        #strategy='dp' if (num_gpus > 1) else None,
        callbacks=callbacks,
        #precision=16
    )

    return (ldm, trainer)
ldcast/models/genforecast/unet.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ..diffusion.utils import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ..blocks.afno import AFNOCrossAttentionBlock3d
21
+ SpatialTransformer = type(None)
22
+ #from ldm.modules.attention import SpatialTransformer
23
+
24
+
25
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.

        Subclasses must implement this; the base class provides no body.
        """
35
+
36
+
37
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        # per-child dispatch: TimestepBlocks also receive the timestep
        # embedding; AFNO cross-attention blocks receive the conditioning
        # entry matching the current spatial resolution; plain layers get x.
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, AFNOCrossAttentionBlock3d):
                img_shape = tuple(x.shape[-2:])
                # `context` is expected to be a mapping keyed by (H, W)
                # spatial shape (as produced by the conditioning cascade)
                x = layer(x, context[img_shape])
            else:
                x = layer(x)
        return x
53
+
54
+
55
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D input: keep the first (temporal) axis, double the two
            # spatial axes
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
84
+
85
+
86
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D: keep the first (temporal) axis, halve the two spatial axes
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            # strided convolution doubles as the downsampling operation
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # average pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
113
+
114
+
115
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> activation -> conv (norm is nn.Identity in this codebase,
        # see `normalization`)
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # projects the timestep embedding to per-channel bias (or
        # scale+shift, if use_scale_shift_norm)
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # final conv is zero-initialized so the block starts as an identity
        # residual
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # optionally wraps _forward in gradient checkpointing
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            # apply the resampling between the pre-activation and the conv,
            # and resample the skip path (x) identically
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # append singleton spatial dims so the embedding broadcasts over h
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: norm, then scale-and-shift by emb
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # additive conditioning
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
228
+
229
+
230
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    """

    def __init__(
        self,
        model_channels,
        in_channels=1,
        out_channels=1,
        num_res_blocks=2,
        attention_resolutions=(1,2,4),
        # NOTE(review): context_ch is indexed per level below
        # (context_ch[level]); the scalar default 128 would fail whenever an
        # attention block is built -- confirm callers always pass a sequence.
        context_ch=128,
        dropout=0,
        channel_mult=(1, 2, 4, 4),
        conv_resample=True,
        dims=3,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        legacy=True,
        num_timesteps=1
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        # Forecast timestep indices (1..num_timesteps) passed to the AFNO
        # cross-attention blocks; local only, not registered as a buffer.
        timesteps = th.arange(1, num_timesteps+1)

        # MLP that maps the sinusoidal diffusion-timestep embedding to the
        # conditioning vector used by the ResBlocks.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # --- Encoder (downsampling) path ---
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel counts of every encoder output, consumed (LIFO) by the
        # decoder's skip connections.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = num_head_channels
                    # NOTE(review): dim_head is computed but not passed on;
                    # the AFNO block receives only num_blocks=num_heads.
                    layers.append(
                        AFNOCrossAttentionBlock3d(
                            ch, context_dim=context_ch[level], num_blocks=num_heads,
                            data_format="channels_first", timesteps=timesteps
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # --- Bottleneck ---
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AFNOCrossAttentionBlock3d(
                ch, context_dim=context_ch[-1], num_blocks=num_heads,
                data_format="channels_first", timesteps=timesteps
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # --- Decoder (upsampling) path, mirroring the encoder ---
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Pop the matching encoder output for the skip connection.
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = num_head_channels
                    layers.append(
                        AFNOCrossAttentionBlock3d(
                            ch, context_dim=context_ch[level], num_blocks=num_heads,
                            data_format="channels_first", timesteps=timesteps
                        )
                    )
                if level and i == num_res_blocks:
                    # Upsample at the end of each level except the topmost.
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Final projection back to out_channels; zero-initialized so the
        # model starts as an identity-like mapping.
        # NOTE(review): the conv input width is model_channels, which equals
        # ch only when channel_mult[0] == 1 -- confirm for other multipliers.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, context=None):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        # Embed the diffusion timestep and map it through the MLP.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
        # Encoder: keep every intermediate output for skip connections.
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        # Decoder: concatenate the matching encoder output at each step.
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        return self.out(h)
489
+
ldcast/models/nowcast/nowcast.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import pytorch_lightning as pl
7
+
8
+ from ..blocks.afno import AFNOBlock3d
9
+ from ..blocks.attention import positional_encoding, TemporalTransformer
10
+
11
+
12
class Nowcaster(pl.LightningModule):
    """Lightning wrapper training a nowcasting network with an MSE loss."""

    def __init__(self, nowcast_net):
        super().__init__()
        self.nowcast_net = nowcast_net

    def forward(self, x):
        return self.nowcast_net(x)

    def _loss(self, batch):
        # Mean squared error between prediction and target.
        x, y = batch
        prediction = self.forward(x)
        residual = y - prediction
        return residual.square().mean()

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def val_test_step(self, batch, batch_idx, split="val"):
        # Shared evaluation step; `split` selects the logged metric name.
        loss = self._loss(batch)
        self.log(
            f"{split}_loss", loss,
            on_step=False, on_epoch=True, prog_bar=True
        )

    def validation_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="val")

    def test_step(self, batch, batch_idx):
        self.val_test_step(batch, batch_idx, split="test")

    def configure_optimizers(self):
        # AdamW with LR reduction when the validation loss plateaus.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=1e-3,
            betas=(0.5, 0.9), weight_decay=1e-3
        )
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss",
                "frequency": 1,
            },
        }
60
+
61
+
62
class AFNONowcastNetBasic(nn.Sequential):
    """
    Minimal AFNO nowcasting network: 3D patch embedding, a stack of AFNO
    blocks, then patch expansion back to the input resolution.

    NOTE(review): `PatchEmbed3d` and `PatchExpand3d` are not imported in this
    module -- confirm where they are defined and import them.

    :param embed_dim: channel width of the embedded patches.
    :param depth: number of AFNO blocks.
    :param patch_size: (time, height, width) patch size.
    """
    def __init__(
        self,
        embed_dim=256,
        depth=12,
        patch_size=(4,4,4)
    ):
        patch_embed = PatchEmbed3d(
            embed_dim=embed_dim, patch_size=patch_size
        )
        # Bug fix: was `AFNOBlock` (undefined NameError) -- this module
        # imports `AFNOBlock3d` from ..blocks.afno.
        blocks = nn.Sequential(
            *(AFNOBlock3d(embed_dim) for _ in range(depth))
        )
        patch_expand = PatchExpand3d(
            embed_dim=embed_dim, patch_size=patch_size
        )
        super().__init__(patch_embed, blocks, patch_expand)
79
+
80
+
81
class FusionBlock3d(nn.Module):
    """
    Merge multiple channels-last [N, T, H, W, C] feature streams of
    different spatial resolutions into a single stream.

    Each input is upscaled to the common resolution with transposed
    convolutions (factor 2 per stage), then the streams are merged either
    by summation or by an AFNO-based fusion network.

    :param dim: channel count per input (scalar is broadcast to all inputs).
    :param size_ratios: per-input spatial upscaling factor (powers of 2;
        1 means no rescaling).
    :param dim_out: output channel count; defaults to dim[0].
    :param afno_fusion: if True, concatenate and fuse with an AFNO block
        instead of summing.
    """
    def __init__(self, dim, size_ratios, dim_out=None, afno_fusion=False):
        super().__init__()

        N_sources = len(size_ratios)
        if not isinstance(dim, collections.abc.Sequence):
            dim = (dim,) * N_sources
        if dim_out is None:
            dim_out = dim[0]

        # One upscaling pipeline per input stream.
        self.scale = nn.ModuleList()
        for (i,size_ratio) in enumerate(size_ratios):
            if size_ratio == 1:
                scale = nn.Identity()
            else:
                # Chain of stride-2 transposed convs; the last stage maps
                # to dim_out, the earlier ones keep the input width.
                scale = []
                while size_ratio > 1:
                    scale.append(nn.ConvTranspose3d(
                        dim[i], dim_out if size_ratio==2 else dim[i],
                        kernel_size=(1,3,3), stride=(1,2,2),
                        padding=(0,1,1), output_padding=(0,1,1)
                    ))
                    size_ratio //= 2
                scale = nn.Sequential(*scale)
            self.scale.append(scale)

        self.afno_fusion = afno_fusion

        if self.afno_fusion:
            if N_sources > 1:
                self.fusion = nn.Sequential(
                    nn.Linear(sum(dim), sum(dim)),
                    # Bug fix: was `AFNOBlock3d(dim*N_sources, ...)` --
                    # `dim` is a tuple here, so that repeated the tuple
                    # instead of giving the concatenated channel count
                    # used by the surrounding Linear layers.
                    AFNOBlock3d(sum(dim), mlp_ratio=2),
                    nn.Linear(sum(dim), dim_out)
                )
            else:
                self.fusion = nn.Identity()

    def resize_proj(self, x, i):
        # Rescale input i: convs expect channels-first, so permute
        # [N,T,H,W,C] -> [N,C,T,H,W] and back.
        x = x.permute(0,4,1,2,3)
        x = self.scale[i](x)
        x = x.permute(0,2,3,4,1)
        return x

    def forward(self, x):
        """
        :param x: sequence of channels-last [N, T, H, W, C] tensors.
        :return: fused channels-last tensor at the common resolution.
        """
        x = [self.resize_proj(xx, i) for (i, xx) in enumerate(x)]
        if self.afno_fusion:
            x = torch.concat(x, axis=-1)
            x = self.fusion(x)
        else:
            x = sum(x)
        return x
133
+
134
+
135
class AFNONowcastNetBase(nn.Module):
    """
    AFNO-based nowcasting backbone operating in the latent space of one or
    more pretrained autoencoders.

    Each input stream is encoded, projected to its embedding width, and
    passed through a stack of AFNO "analysis" blocks; streams are then
    aligned in time (via a temporal transformer, if needed), fused, and
    processed by AFNO "forecast" blocks. The output is a channels-first
    latent tensor.
    """
    def __init__(
        self,
        autoencoder,
        embed_dim=128,
        embed_dim_out=None,
        analysis_depth=4,
        forecast_depth=4,
        input_patches=(1,),
        input_size_ratios=(1,),
        output_patches=2,
        train_autoenc=False,
        afno_fusion=False
    ):
        super().__init__()

        self.train_autoenc = train_autoenc
        # Normalize scalar arguments to per-input sequences.
        if not isinstance(autoencoder, collections.abc.Sequence):
            autoencoder = [autoencoder]
        if not isinstance(input_patches, collections.abc.Sequence):
            input_patches = [input_patches]
        num_inputs = len(autoencoder)
        if not isinstance(embed_dim, collections.abc.Sequence):
            embed_dim = [embed_dim] * num_inputs
        if embed_dim_out is None:
            embed_dim_out = embed_dim[0]
        if not isinstance(analysis_depth, collections.abc.Sequence):
            analysis_depth = [analysis_depth] * num_inputs
        self.embed_dim = embed_dim
        self.embed_dim_out = embed_dim_out
        self.output_patches = output_patches

        # encoding + analysis for each input
        self.autoencoder = nn.ModuleList()
        self.proj = nn.ModuleList()
        self.analysis = nn.ModuleList()
        for i in range(num_inputs):
            # Freeze (or unfreeze) the pretrained autoencoder weights.
            ae = autoencoder[i].requires_grad_(train_autoenc)
            self.autoencoder.append(ae)

            # 1x1x1 conv projecting latent channels to the embedding width.
            proj = nn.Conv3d(ae.hidden_width, embed_dim[i], kernel_size=1)
            self.proj.append(proj)

            analysis = nn.Sequential(
                *(AFNOBlock3d(embed_dim[i]) for _ in range(analysis_depth[i]))
            )
            self.analysis.append(analysis)

        # temporal transformer: only needed when some input's number of
        # time patches differs from the output's.
        self.use_temporal_transformer = \
            any((ipp != output_patches) for ipp in input_patches)
        if self.use_temporal_transformer:
            self.temporal_transformer = nn.ModuleList(
                TemporalTransformer(embed_dim[i]) for i in range(num_inputs)
            )

        # data fusion
        self.fusion = FusionBlock3d(embed_dim, input_size_ratios,
            afno_fusion=afno_fusion, dim_out=embed_dim_out)

        # forecast
        self.forecast = nn.Sequential(
            *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth))
        )

    def add_pos_enc(self, x, t):
        # Add a temporal positional encoding derived from the relative
        # timestamps t to the channels-last feature tensor x.
        if t.shape[1] != x.shape[1]:
            # this can happen if x has been compressed
            # by the autoencoder in the time dimension
            ds_factor = t.shape[1] // x.shape[1]
            t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:]

        pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3))
        return x + pos_enc

    def forward(self, x):
        # x is a sequence of (tensor, relative_timestamps) pairs,
        # one pair per input stream.
        (x, t_relative) = list(zip(*x))

        # encoding + analysis for each input
        def process_input(i):
            # Encode to latent space (take the first element of the
            # encoder output), project and run the AFNO analysis stack
            # in channels-last layout.
            z = self.autoencoder[i].encode(x[i])[0]
            z = self.proj[i](z)
            z = z.permute(0,2,3,4,1)
            z = self.analysis[i](z)
            if self.use_temporal_transformer:
                # add positional encoding
                z = self.add_pos_enc(z, t_relative[i])

                # transform to output shape and coordinates
                expand_shape = z.shape[:1] + (-1,) + z.shape[2:]
                pos_enc_output = positional_encoding(
                    torch.arange(1,self.output_patches+1, device=z.device),
                    self.embed_dim[i], add_dims=(0,2,3)
                )
                # Query the temporal transformer with the output-time
                # positional encodings to map to output_patches steps.
                pe_out = pos_enc_output.expand(*expand_shape)
                z = self.temporal_transformer[i](pe_out, z)
            return z

        x = [process_input(i) for i in range(len(x))]

        # merge inputs
        x = self.fusion(x)
        # produce prediction
        x = self.forecast(x)
        return x.permute(0,4,1,2,3) # to channels-first order
240
+
241
+
242
class AFNONowcastNet(AFNONowcastNetBase):
    """
    AFNO nowcasting network that decodes the forecast latent state back to
    data space with an output autoencoder.

    :param autoencoder: input autoencoder, or a sequence of them
        (one per input stream), as accepted by AFNONowcastNetBase.
    :param output_autoencoder: autoencoder whose decoder maps the forecast
        latents to data space; defaults to the (first) input autoencoder.
    """
    def __init__(self, autoencoder, output_autoencoder=None, **kwargs):
        super().__init__(autoencoder, **kwargs)
        if output_autoencoder is None:
            # Bug fix: the base class accepts a single (non-sequence)
            # autoencoder, but this previously always indexed it.
            if isinstance(autoencoder, collections.abc.Sequence):
                output_autoencoder = autoencoder[0]
            else:
                output_autoencoder = autoencoder
        self.output_autoencoder = output_autoencoder.requires_grad_(
            self.train_autoenc)
        # 1x1x1 conv mapping the fused embedding back to the decoder's
        # latent channel count.
        self.out_proj = nn.Conv3d(
            self.embed_dim_out, output_autoencoder.hidden_width, kernel_size=1
        )

    def forward(self, x):
        """Run the base forecast, project, and decode to data space."""
        x = super().forward(x)
        x = self.out_proj(x)
        return self.output_autoencoder.decode(x)
ldcast/models/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
def normalization(channels, norm_type="group", num_groups=32):
    """
    Return a normalization layer for `channels` feature channels.

    :param channels: number of feature channels to normalize.
    :param norm_type: "batch", "group", or a falsy value / "none" for no-op.
    :param num_groups: group count for group norm (must divide `channels`).
    :return: an nn.Module performing the requested normalization.
    :raises NotImplementedError: for an unrecognized norm_type.
    """
    if norm_type == "batch":
        return nn.BatchNorm3d(channels)
    elif norm_type == "group":
        return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
    elif (not norm_type) or (norm_type.lower() == 'none'):
        # Bug fix: was `norm_type.tolower()` -- str has no tolower(), so
        # any norm_type other than "batch"/"group" raised AttributeError.
        return nn.Identity()
    else:
        # Bug fix: was `NotImplementedError(norm)` with `norm` undefined.
        raise NotImplementedError(norm_type)
14
+
15
+
16
def activation(act_type="swish"):
    """
    Return the activation module named by `act_type`.

    :param act_type: one of "swish", "gelu", "relu", "tanh", or a falsy
        value for no activation.
    :return: the corresponding nn.Module (nn.Identity for falsy input).
    :raises NotImplementedError: for an unrecognized act_type.
    """
    if not act_type:
        return nn.Identity()
    known = {
        "swish": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
    }
    try:
        return known[act_type]()
    except KeyError:
        raise NotImplementedError(act_type)
ldcast/visualization/cm.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib.colors import LinearSegmentedColormap
3
+
4
+
5
+ def yuv_rainbow_24(nc):
6
+ """ From https://github.com/ARM-DOE/pyart/blob/main/pyart/graph/_cm_colorblind.py
7
+ """
8
+ path1 = np.linspace(0.8*np.pi, 1.8*np.pi, nc)
9
+ path2 = np.linspace(-0.33*np.pi, 0.33*np.pi, nc)
10
+
11
+ y = np.concatenate([np.linspace(0.3, 0.85, nc*2//5),
12
+ np.linspace(0.9, 0.0, nc - nc*2//5)])
13
+ u = 0.40*np.sin(path1)
14
+ v = 0.55*np.sin(path2) + 0.1
15
+
16
+ rgb_from_yuv = np.array([[1, 0, 1.13983],
17
+ [1, -0.39465, -0.58060],
18
+ [1, 2.03211, 0]])
19
+ cmap_dict = {'blue': [], 'green': [], 'red': []}
20
+ for i in range(len(y)):
21
+ yuv = np.array([y[i], u[i], v[i]])
22
+ rgb = rgb_from_yuv.dot(yuv)
23
+ red_tuple = (i/(len(y)-1.0), rgb[0], rgb[0])
24
+ green_tuple = (i/(len(y)-1.0), rgb[1], rgb[1])
25
+ blue_tuple = (i/(len(y)-1.0), rgb[2], rgb[2])
26
+ cmap_dict['blue'].append(blue_tuple)
27
+ cmap_dict['red'].append(red_tuple)
28
+ cmap_dict['green'].append(green_tuple)
29
+
30
+ return cmap_dict
31
+
32
+
33
# Module-level colormap instance built from 15 YUV rainbow anchors;
# used for precipitation rendering elsewhere in the package.
homeyer_rainbow = LinearSegmentedColormap(
    "homeyer_rainbow",
    yuv_rainbow_24(15)
)
ldcast/visualization/plots.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import concurrent
import concurrent.futures
import multiprocessing
import os

from matplotlib import gridspec, colors, pyplot as plt
import netCDF4
import numpy as np
import torch

from ..analysis import confmatrix, fss, rank
from ..features import io
from .cm import homeyer_rainbow
13
+
14
+
15
def reverse_transform_R(R, mean=-0.051, std=0.528):
    """Invert the standardized log10 transform: map a normalized value
    back to rain rate via 10**(R*std + mean)."""
    log10_rate = R * std + mean
    return 10 ** log10_rate
17
+
18
+
19
def plot_precip_image(
    ax, R,
    Rmin=0, Rmax=25, threshold_mmh=0.1,
    transform_R=False,
    grid_spacing=64,
    mean=-0.051, std=0.528
):
    """
    Render a precipitation field on the given axes with a log color scale.

    :param ax: matplotlib Axes to draw on.
    :param R: 2D precipitation field (numpy array or torch Tensor).
    :param Rmin, Rmax: color scale limits.
    :param threshold_mmh: values below this rate (mm/h) are masked out.
    :param transform_R: if True, R (and Rmin/Rmax) are in the normalized
        log10 space and are converted back to rain rates first.
    :param grid_spacing: spacing of the grid lines in pixels.
    :param mean, std: normalization parameters used when transform_R is True.
    :return: the AxesImage created by imshow.
    """
    if isinstance(R, torch.Tensor):
        R = R.detach().numpy()
    if transform_R:
        # Bug fix: `mean` and `std` were previously undefined free names
        # here (NameError); they are now keyword parameters with the same
        # defaults as reverse_transform_R.
        R = reverse_transform_R(R, mean=mean, std=std)
        Rmin = reverse_transform_R(Rmin, mean=mean, std=std)
        Rmax = reverse_transform_R(Rmax, mean=mean, std=std)
    if threshold_mmh:
        # Mask sub-threshold rain so it renders as background.
        Rmin = max(Rmin, threshold_mmh)
        R[R < threshold_mmh] = np.nan
    norm = colors.LogNorm(Rmin, Rmax)
    ax.set_yticks(np.arange(0, R.shape[0], grid_spacing))
    ax.set_xticks(np.arange(0, R.shape[1], grid_spacing))
    ax.grid(which='major', alpha=0.35)
    ax.tick_params(left=False, bottom=False,
        labelleft=False, labelbottom=False)

    return ax.imshow(R, norm=norm, cmap=homeyer_rainbow)
42
+
43
+
44
def plot_autoencoder_reconstruction(
    R, R_hat, samples=8, timesteps=4,
    out_file=None
):
    """
    Plot original fields R and reconstructions R_hat side by side:
    one row per sample, the last `timesteps` frames of R on the left
    and of R_hat on the right, with a shared colorbar.

    R and R_hat are indexed as [sample, 0, time, y, x].
    """
    fig = plt.figure(figsize=(2*timesteps*2+0.5, samples*2), dpi=150)

    # Extra narrow column at the right for the colorbar.
    gs = gridspec.GridSpec(
        samples, 2*timesteps+1,
        width_ratios=(1,)*(2*timesteps)+(0.2,),
        wspace=0.02, hspace=0.02
    )
    for k in range(samples):
        # Negative j indexes the last `timesteps` frames.
        for (i,j) in enumerate(range(-timesteps,0)):
            ax = fig.add_subplot(gs[k,i])
            im = plot_precip_image(ax, R[k,0,j,:,:])
        for (i,j) in enumerate(range(-timesteps,0)):
            ax = fig.add_subplot(gs[k,i+timesteps])
            im = plot_precip_image(ax, R_hat[k,0,j,:,:])

    # Colorbar from the last drawn image (all panels share the scale).
    cax = fig.add_subplot(gs[:,-1])
    plt.colorbar(im, cax=cax)

    if out_file is not None:
        out_file = fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)
69
+
70
+
71
def plot_animation(x, y, out_dir, sample=0, fmt="{}_{:02d}.png"):
    """
    Save each time frame of the input `x` and output `y` of one sample as
    individual PNG files in `out_dir`.

    :param x: input sequence, indexed as [sample, 0, time, y, x].
    :param y: output sequence, same indexing.
    :param out_dir: directory for the image files.
    :param sample: which sample of the batch to plot.
    :param fmt: filename format receiving (label, timestep).
    """
    def plot_frame(R, label, timestep):
        # Render one frame and save it.
        fig = plt.figure()
        ax = fig.add_subplot()
        plot_precip_image(ax, R[sample,0,timestep,:,:])
        # Bug fix: the filename was formatted with the enclosing loop
        # variable `k` instead of the `timestep` parameter; that only
        # worked by coincidence because callers passed k as timestep.
        fn = fmt.format(label, timestep)
        fn = os.path.join(out_dir, fn)
        fig.savefig(fn, bbox_inches='tight')
        plt.close(fig)

    for k in range(x.shape[2]):
        plot_frame(x, "x", k)

    for k in range(y.shape[2]):
        plot_frame(y, "y", k)
86
+
87
+
88
# Plot color per model key (dataset prefix + model name); the same model
# gets the same hue across datasets, persistence is gray.
model_colors = {
    "mch-dgmr": "#E69F00",
    "mch-pysteps": "#009E73",
    "mch-iters=50-res=256": "#0072B2",
    "mch-persistence": "#888888",
    "dwd-dgmr": "#E69F00",
    "dwd-pysteps": "#009E73",
    "dwd-iters=50-res=256": "#0072B2",
    "dwd-persistence": "#888888",
    "pm-mch-dgmr": "#E69F00",
    "pm-mch-pysteps": "#009E73",
    "pm-mch-iters=50-res=256": "#0072B2",
    "pm-dwd-dgmr": "#E69F00",
    "pm-dwd-pysteps": "#009E73",
    "pm-dwd-iters=50-res=256": "#0072B2",
}

# Line style per spatial pooling scale used in the score plots.
scale_linestyles = {
    "1x1": "-",
    "8x8": "--",
    "64x64": ":"
}
110
+
111
+
112
def plot_crps(
    log=False,
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS", "Persist."),
    interval_mins=5,
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """
    Plot (Log)CRPS as a function of lead time for several models and
    pooling scales, reading precomputed scores from
    ../results/crps/{crps,logcrps}-<model>.nc.

    crop_box, if given, is ((y0,y1),(x0,x1)) in full-resolution pixels and
    is rescaled to each pooling scale before cropping.
    """
    crps = {}
    crps_name = "logcrps" if log else "crps"
    for model in models:
        crps[model] = {}
        fn = f"../results/crps/{crps_name}-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"crps_pool{scale}"
                crps_model_scale = np.array(ds[var][:], copy=False)
                if crop_box is not None:
                    # Convert the crop box to this pooling scale.
                    scale_int = int(scale.split("x")[0])
                    crps_model_scale = crps_model_scale[
                        ...,
                        crop_box[0][0]//scale_int:crop_box[0][1]//scale_int,
                        crop_box[1][0]//scale_int:crop_box[1][1]//scale_int
                    ]
                # Average over all axes except lead time.
                # NOTE(review): assumes axis 2 is lead time -- confirm
                # against the file layout written by the CRPS script.
                crps[model][scale] = crps_model_scale.mean(axis=(0,1,3,4))
                del crps_model_scale

    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = crps[model][scale]
            color = model_colors[model]
            linestyle = scale_linestyles[scale]
            # Lead-time axis in minutes; the +0.1 guards against
            # floating-point exclusion of the last point.
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        plt.ylabel(
            "LogCRPS" if log else "CRPS [mm h$^\\mathrm{-1}$]",
            fontsize=12
        )

    ax.set_xlim((0, max_t))
    # Anchor the y axis at zero.
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
180
+
181
+
182
def plot_rank_distribution(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """
    Plot normalized rank histograms per model and pooling scale, with the
    rank KL divergence (vs. the uniform distribution) in the legend.
    Ranks are read from ../results/ranks/ranks-<model>.nc.
    """
    rank_hist = {}
    rank_KL = {}
    for model in models:
        fn = f"../results/ranks/ranks-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"ranks_pool{scale}"
                ranks_model_scale = np.array(ds[var][:], copy=False)
                if crop_box is not None:
                    # Convert the crop box to this pooling scale.
                    scale_int = int(scale.split("x")[0])
                    ranks_model_scale = ranks_model_scale[
                        ...,
                        crop_box[0][0]//scale_int:crop_box[0][1]//scale_int,
                        crop_box[1][0]//scale_int:crop_box[1][1]//scale_int
                    ]
                rank_hist[(model,scale)] = rank.rank_distribution(ranks_model_scale)
                rank_KL[(model,scale)] = rank.rank_DKL(rank_hist[(model,scale)])
                del ranks_model_scale

    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    for scale in scales:
        linestyle = scale_linestyles[scale]
        for (model, label) in zip(models, model_labels):
            h = rank_hist[(model,scale)]
            color = model_colors[model]
            x = np.linspace(0, 1, num_ensemble_members+1)
            label_with_score = f"{label}: {rank_KL[(model,scale)]:.3f}"
            ax.plot(x, h, color=color, linestyle=linestyle,
                label=label_with_score)
    # Gray reference line: a perfectly calibrated ensemble has a flat
    # rank histogram at 1/(members+1).
    h_ideal = 1/(num_ensemble_members+1)
    ax.plot([0, 1], [h_ideal, h_ideal], color=(0.4,0.4,0.4),
        linewidth=1.0)

    if add_legend:
        ax.legend(loc='upper center')
    if add_xlabel:
        ax.set_xlabel("Normalized rank", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("Occurrence", fontsize=12)

    ax.set_xlim((0, 1))
    # Anchor the y axis at zero.
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
249
+
250
+
251
def plot_rank_metric(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    interval_mins=5,
    metric_name="KL",
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
):
    """
    Plot a rank-histogram-based metric (computed per lead time by
    rank.rank_metric_by_leadtime) for several models and pooling scales.
    Ranks are read from ../results/ranks/ranks-<model>.nc;
    metric_name is only used for the y axis label.
    """
    rank_metric = {}
    for model in models:
        rank_metric[model] = {}
        fn = f"../results/ranks/ranks-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                var = f"ranks_pool{scale}"
                ranks = np.array(ds[var][:], copy=False)
                rank_metric[model][scale] = rank.rank_metric_by_leadtime(ranks)
                del ranks

    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = rank_metric[model][scale]
            color = model_colors[model]
            linestyle = scale_linestyles[scale]
            label_with_scale = f"{label} {scale}"
            # Lead-time axis in minutes; the +0.1 guards against
            # floating-point exclusion of the last point.
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label_with_scale)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        plt.ylabel(f"Rank {metric_name}", fontsize=12)

    ax.set_xlim((0, max_t))
    # Anchor the y axis at zero.
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
308
+
309
+
310
def load_fss(model, scale, use_timesteps, crop_box):
    """
    Load observed and forecast exceedance fractions for one model at one
    spatial scale and return the fractions skill score.

    crop_box, if given, is ((y0,y1),(x0,x1)) in full-resolution pixels and
    is rescaled by `scale` before cropping.
    """
    fn = f"../results/fractions/fractions-{model}.nc"
    with netCDF4.Dataset(fn, 'r') as ds:
        sn = f"{scale}x{scale}"
        obs_frac = np.array(ds[f"obs_frac_scale{sn}"][:], copy=False)
        fc_frac = np.array(ds[f"fc_frac_scale{sn}"][:], copy=False)
        if crop_box is not None:
            obs_frac = obs_frac[
                ...,
                crop_box[0][0]//scale:crop_box[0][1]//scale,
                crop_box[1][0]//scale:crop_box[1][1]//scale
            ]
            fc_frac = fc_frac[
                ...,
                crop_box[0][0]//scale:crop_box[0][1]//scale,
                crop_box[1][0]//scale:crop_box[1][1]//scale
            ]
        return fss.fractions_skill_score(
            obs_frac, fc_frac, use_timesteps=use_timesteps
        )
330
+
331
+
332
def plot_fss(
    log=False,
    models=("iters=50-res=256", "dgmr", "pysteps"),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    interval_mins=5,
    out_fn=None,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    scales=None,
    use_timesteps=18,
    crop_box=None
):
    """
    Plot the Fractions Skill Score as a function of spatial scale for
    several models, loading the fraction files in parallel worker
    processes via load_fss.
    """
    # Bug fix: the module imports only `concurrent`, which does not make
    # the `concurrent.futures` submodule available; import it explicitly.
    import concurrent.futures

    if scales is None:
        scales = 2**np.arange(9)
    fss_scale = {}
    # One worker per (model, scale) pair, capped at the CPU count.
    N_threads = min(multiprocessing.cpu_count(), len(models)*len(scales))
    with concurrent.futures.ProcessPoolExecutor(N_threads) as executor:
        for model in models:
            fss_scale[model] = {}
            for scale in scales:
                fss_scale[model][scale] = executor.submit(
                    load_fss, model, scale, use_timesteps, crop_box
                )

    # Replace the futures with their results (all done after the
    # executor shutdown above).
    for model in models:
        for scale in scales:
            fss_scale[model][scale] = fss_scale[model][scale].result()

    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    for (model, label) in zip(models, model_labels):
        scales = sorted(fss_scale[model])
        fss_for_model = [fss_scale[model][s] for s in scales]

        # Model keys may carry a threshold component that is absent from
        # the color table; strip it before the lookup.
        model_parts = model.split("-")
        if model.startswith("pm-"):
            model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
        else:
            model_without_threshold = "-".join(model_parts[1:])
        color = model_colors[model_without_threshold]

        ax.plot(scales, fss_for_model, color=color,
            label=label)

    if add_legend:
        plt.legend()
    if add_xlabel:
        plt.xlabel("Scale [km]", fontsize=12)
    if add_ylabel:
        plt.ylabel("FSS")

    ax.set_xlim((scales[0], scales[-1]))
    # Anchor the y axis at zero.
    ylim = ax.get_ylim()
    ylim = (0, ylim[1])
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if out_fn is not None:
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
396
+
397
+
398
def plot_csi_threshold(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    num_ensemble_members=32,
    max_timestep=18,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot CSI as a function of probability threshold for each model
    and spatial scale.

    Exceedance fractions are read from the per-model NetCDF files under
    ``../results/fractions/`` and turned into confusion matrices at the
    given probability thresholds.

    Parameters
    ----------
    models : sequence of str
        Model identifiers; each maps to a fractions NetCDF file.
    scales : sequence of str
        Scale labels matching the variable names in the NetCDF files.
    prob_thresholds : sequence of float
        Probability thresholds at which CSI is evaluated.
    model_labels : sequence of str
        Legend labels, parallel to `models`.
    out_fn : str or None
        If given and this function created the figure, save and close it.
    num_ensemble_members : int
        Unused; kept for interface compatibility.
    max_timestep : int
        Number of lead times included in the evaluation.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure/axes is created if None.
    add_xlabel, add_ylabel, add_legend : bool
        Toggle axis decorations.
    crop_box : tuple or None
        Unused; kept for interface compatibility.
    """
    csi = {}
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                fc_frac = fc_frac[...,:max_timestep,:,:]
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                obs_frac = obs_frac[...,:max_timestep,:,:]
                conf_matrix = confmatrix.confusion_matrix_thresholds(
                    fc_frac, obs_frac, prob_thresholds
                )
                # Free the large fraction arrays before the next scale.
                del fc_frac, obs_frac
                csi_scale = confmatrix.intersection_over_union(conf_matrix)
                csi[(model,scale)] = csi_scale

    # Track the figure only if created here; previously `fig` was
    # undefined (NameError) when `ax` was passed together with `out_fn`.
    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    for scale in scales:
        linestyle = scale_linestyles[scale]
        for (model, label) in zip(models, model_labels):
            c = csi[(model,scale)]
            # Strip the threshold component of the model id to look up
            # its assigned plotting color.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            ax.plot(prob_thresholds, c, color=color, linestyle=linestyle, label=label)

    if add_legend:
        ax.legend(loc='upper center')
    if add_xlabel:
        ax.set_xlabel("Prob. threshold", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("CSI", fontsize=12)

    ax.set_xlim((0, 1))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
465
+
466
+
467
def plot_csi_leadtime(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    interval_mins=5,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot the best-threshold CSI as a function of lead time for each
    model and spatial scale.

    For every lead time, CSI is evaluated over all probability
    thresholds and the maximum (ignoring NaNs) is plotted.

    Parameters
    ----------
    models : sequence of str
        Model identifiers; each maps to a fractions NetCDF file.
    scales : sequence of str
        Scale labels matching the variable names in the NetCDF files.
    prob_thresholds : sequence of float
        Probability thresholds searched for the per-leadtime maximum.
    model_labels : sequence of str
        Legend labels, parallel to `models`.
    out_fn : str or None
        If given and this function created the figure, save and close it.
    interval_mins : int
        Time step between lead times, used for the x axis in minutes.
    num_ensemble_members : int
        Unused; kept for interface compatibility.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure/axes is created if None.
    add_xlabel, add_ylabel, add_legend : bool
        Toggle axis decorations.
    crop_box : tuple or None
        Unused; kept for interface compatibility.
    """
    csi = {}
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                conf_matrix = confmatrix.confusion_matrix_thresholds_leadtime(
                    fc_frac, obs_frac, prob_thresholds
                )
                csi_scale = confmatrix.intersection_over_union(conf_matrix)
                # Best CSI over the threshold axis, per lead time.
                csi[(model,scale)] = np.nanmax(csi_scale, axis=1)

    # Create a figure if no axes were supplied; previously this function
    # had no such block, so the default ax=None call crashed on ax.plot,
    # and `fig` was undefined when saving.
    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_t = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = csi[(model,scale)]

            # Strip the threshold component of the model id to look up
            # its assigned plotting color.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            linestyle = scale_linestyles[scale]
            # Lead times in minutes; +0.1 guards against float rounding
            # excluding the final step from np.arange.
            t = np.arange(
                interval_mins, (len(score)+0.1)*interval_mins, interval_mins
            )
            max_t = max(max_t, t[-1])
            ax.plot(t, score, color=color, linestyle=linestyle,
                label=label)

    # Decorate the axes we actually drew on (not the implicit current
    # axes, which may differ when `ax` is passed in by the caller).
    if add_legend:
        ax.legend()
    if add_xlabel:
        ax.set_xlabel("Lead time [min]", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("CSI", fontsize=12)

    ax.set_xlim((0, max_t))
    ylim = ax.get_ylim()
    ax.set_ylim((0, ylim[1]))
    ax.tick_params(axis='both', which='major', labelsize=12)

    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
532
+
533
+
534
def plot_cost_loss_value(
    models=("iters=50-res=256", "dgmr", "pysteps"),
    scales=("1x1", "8x8", "64x64"),
    prob_thresholds=tuple(np.linspace(0,1,33)),
    model_labels=("LDCast", "DGMR", "PySTEPS"),
    out_fn=None,
    interval_mins=5,
    num_ensemble_members=32,
    ax=None,
    add_xlabel=True,
    add_ylabel=True,
    add_legend=True,
    crop_box=None
):
    """Plot the economic (cost/loss) value of each model's forecasts as
    a function of the cost/loss ratio, for several spatial scales.

    Parameters
    ----------
    models : sequence of str
        Model identifiers; each maps to a fractions NetCDF file.
    scales : sequence of str
        Scale labels matching the variable names in the NetCDF files.
    prob_thresholds : sequence of float
        Probability thresholds used to build the confusion matrices.
    model_labels : sequence of str
        Legend labels, parallel to `models`.
    out_fn : str or None
        If given and this function created the figure, save and close it.
    interval_mins, num_ensemble_members, crop_box
        Unused; kept for interface compatibility.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure/axes is created if None.
    add_xlabel, add_ylabel, add_legend : bool
        Toggle axis decorations.
    """
    value = {}
    loss = 1.0
    cost = np.linspace(0.01, 1, 100)
    for model in models:
        fn = f"../results/fractions/fractions-{model}.nc"
        with netCDF4.Dataset(fn, 'r') as ds:
            for scale in scales:
                fc_var = f"fc_frac_scale{scale}"
                fc_frac = np.array(ds[fc_var], copy=False)
                obs_var = f"obs_frac_scale{scale}"
                obs_frac = np.array(ds[obs_var], copy=False)
                conf_matrix = confmatrix.confusion_matrix_thresholds(
                    fc_frac, obs_frac, prob_thresholds
                )

                # Climatological event probability as the reference.
                p_clim = obs_frac.mean()
                value_scale = []
                for c in cost:
                    v = confmatrix.cost_loss_value(
                        conf_matrix, c, loss, p_clim
                    )
                    # Take the value at the middle probability threshold.
                    value_scale.append(v[len(v)//2])
                value[(model,scale)] = np.array(value_scale)

    # Create a figure if no axes were supplied; previously this function
    # had no such block, so the default ax=None call crashed on ax.plot,
    # and `fig` was undefined when saving.
    fig = None
    if ax is None:
        fig = plt.figure(figsize=(8,5))
        ax = fig.add_subplot()

    max_score = 0
    for (model, label) in zip(models, model_labels):
        for scale in scales:
            score = value[(model,scale)]
            max_score = max(max_score, score[np.isfinite(score)].max())

            # Strip the threshold component of the model id to look up
            # its assigned plotting color.
            model_parts = model.split("-")
            if model.startswith("pm-"):
                model_without_threshold = "-".join(model_parts[:1] + model_parts[2:])
            else:
                model_without_threshold = "-".join(model_parts[1:])
            color = model_colors[model_without_threshold]
            linestyle = scale_linestyles[scale]

            ax.plot(cost, score, color=color, linestyle=linestyle,
                label=label)

    # Decorate the axes we actually drew on (not the implicit current
    # axes, which may differ when `ax` is passed in by the caller).
    if add_legend:
        ax.legend()
    if add_xlabel:
        ax.set_xlabel("Cost/loss ratio", fontsize=12)
    if add_ylabel:
        ax.set_ylabel("Value", fontsize=12)

    ax.set_xlim((0, 1))
    ax.set_ylim((0, max_score*1.05))
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
    # int labels for 0 and 1 to save space
    ax.set_xticklabels(["0", "0.25", "0.5", "0.75", "1"])

    if (out_fn is not None) and (fig is not None):
        fig.savefig(out_fn, bbox_inches='tight')
        plt.close(fig)
models/.keep ADDED
File without changes
models/autoenc/autoenc-32-0.01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa5ad4b8689aadbf702376e7afe5cb437ef5057675e78a8986837e8f28b3126e
3
+ size 1617490
models/autoenc/autoenc.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf3792a94ee4cf347ca498b10b851d58a0f2ce6b3062e7b59fec5761c7edbf24
3
+ size 1616323
scripts/convert_data_NB_2nc.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import dask.array as da
4
+ import xarray as xr
5
+
6
def load_all_file(data_dir="", prefix="202306"):
    """Load all .npy files in `data_dir` whose names start with `prefix`.

    Files are loaded in sorted filename order (which for the default
    YYYYMMDD-style prefix is chronological).

    Parameters
    ----------
    data_dir : str
        Directory containing the .npy files.
    prefix : str
        Filename prefix to filter on; defaults to "202306" (June 2023),
        preserving the original hard-coded behavior.

    Returns
    -------
    list of np.ndarray
        One array per matching file, ordered by filename.
    """
    # Filter and sort in one pass instead of building an intermediate
    # list; use os.path.join rather than manual "/" concatenation.
    sorted_files = sorted(
        fn for fn in os.listdir(data_dir) if fn.startswith(prefix)
    )
    return [np.load(os.path.join(data_dir, fn)) for fn in sorted_files]
25
+
26
def preprocess_data(data_list, out_dir="", patch_size=32, grid_size=640):
    """Cut radar fields into square patches and write them to a NetCDF
    dataset in the layout expected by the LDCast feature pipeline.

    Parameters
    ----------
    data_list : sequence of np.ndarray
        2D radar fields; the top-left `grid_size` x `grid_size` region
        of each is tiled into patches.
    out_dir : str
        Output directory; an "RZC" subdirectory is created inside it.
    patch_size : int
        Side length of each square patch (default 32, as before).
    grid_size : int
        Extent of the region to tile (default 640, as before).

    Returns
    -------
    int
        The number of input fields processed.
    """
    # Tile each field into non-overlapping patch_size x patch_size blocks.
    patch_list = []
    for field in data_list:
        for i in range(0, grid_size, patch_size):
            for j in range(0, grid_size, patch_size):
                patch_list.append(field[i:i+patch_size, j:j+patch_size])

    print(len(patch_list))
    data_shape = len(patch_list)
    patches_array = np.array(patch_list, dtype=np.uint8)
    # Placeholder coordinate/time/scale arrays; only their shapes and
    # dtypes matter to downstream readers here.
    temp_array = np.array(np.random.rand(data_shape, 2), dtype=np.uint16)
    temp_array2 = np.arange(256, dtype=np.float32)
    temp_array3 = np.arange(data_shape, dtype=np.int64)

    data_da = da.from_array(patches_array, chunks=(data_shape,32,32)) # Adjust chunk size as needed for your data
    data_da2 = da.from_array(temp_array, chunks=(data_shape, 2))
    data_da3 = da.from_array(temp_array3, chunks=(data_shape, ))
    data_da4 = da.from_array(temp_array2, chunks=(256, ))

    # Create xarray DataArrays with the Dask arrays as their backend.
    # NOTE(review): dimension names (including "dim_heigh") are kept
    # exactly as downstream consumers expect them.
    patches = xr.DataArray(data_da, dims=("dim_patch", "dim_heigh", "dim_width"))
    patch_coords = xr.DataArray(data_da2, dims=("dim_patch1", "dim_coord"))
    patch_times = xr.DataArray(data_da3, dims=("dim_patch2"))
    zero_patch_coords = xr.DataArray(data_da2, dims=("dim_zero_patch", "dim_coord"))
    zero_patch_times = xr.DataArray(data_da3, dims=("dim_zero_patch1"))
    scale = xr.DataArray(data_da4, dims=("dim_scale"))

    ds = patches.to_dataset(name = 'patches')
    ds['patch_coords'] = patch_coords
    ds['patch_times'] = patch_times
    ds['zero_patch_coords'] = zero_patch_coords
    ds['zero_patch_times'] = zero_patch_times
    ds['scale'] = scale

    ds.attrs["zero_value"] = 1
    out_dir = out_dir + "/" + "RZC"
    os.makedirs(out_dir, exist_ok=True)
    file_name = os.path.join(out_dir, "patches_RV_202306.nc")
    ds.to_netcdf(file_name)

    return len(data_list)
73
+
74
+
75
if __name__ == "__main__":
    # Run the conversion only when executed as a script (not on import).
    # NOTE(review): paths are hardcoded for the original environment.
    # Renamed from `list`, which shadowed the builtin.
    data_list = load_all_file(data_dir="/data/data_WF/ldcast_precipitation/test")
    print(preprocess_data(data_list, out_dir="/data/data_WF/ldcast_precipitation/preprocess_data_test"))