franch committed on
Commit
df27dfb
·
verified ·
1 Parent(s): 39d3c0f

Add source code and examples

Browse files
.dockerignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .github
3
+ .venv
4
+ __pycache__
5
+ *.pyc
6
+ *.zarr
7
+ logs/
8
+ data/
9
+ notebooks/
10
+ importance_sampler/
11
+ tests/
12
+ *.egg-info
13
+ .ruff_cache
14
+ .pytest_cache
15
+ .claude
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/sample_data.nc filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ pull_request:
6
+
7
+ jobs:
8
+ lint-test-build:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - name: Checkout
12
+ uses: actions/checkout@v4
13
+
14
+ - name: Install uv
15
+ uses: astral-sh/setup-uv@v4
16
+ with:
17
+ python-version: "3.13"
18
+
19
+ - name: Sync dependencies
20
+ run: uv sync --all-groups --extra serve
21
+
22
+ - name: Lint
23
+ run: uv run ruff check .
24
+
25
+ - name: Tests
26
+ run: uv run pytest -q
27
+
28
+ - name: Build package
29
+ run: uv build
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.nc
2
+ !examples/*.nc
3
+ *.zarr/
4
+ __pycache__/
5
+ .ipynb_checkpoints/
6
+ checkpoints/
7
+ *.ckpt
8
+ *.mp4
9
+ dist/
10
+ build/
11
+ *.egg-info/
12
+ .pytest_cache/
13
+ .ruff_cache/
14
+ logs/
15
+ data/
16
+ .env
17
+ .venv/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-yaml
8
+ - id: check-added-large-files
9
+ args: ['--maxkb=12000']
10
+ - id: check-merge-conflict
11
+
12
+ - repo: https://github.com/astral-sh/ruff-pre-commit
13
+ rev: v0.9.10
14
+ hooks:
15
+ - id: ruff
16
+ args: [--fix]
17
+ - id: ruff-format
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install uv for fast dependency management
6
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
7
+
8
+ # Copy dependency files first for layer caching
9
+ COPY pyproject.toml uv.lock ./
10
+
11
+ # Install dependencies (serve extras, no dev)
12
+ RUN uv sync --extra serve --no-dev --no-install-project
13
+
14
+ # Copy package source
15
+ COPY convgru_ensemble/ ./convgru_ensemble/
16
+
17
+ # Install the project itself
18
+ RUN uv sync --extra serve --no-dev
19
+
20
+ # Model checkpoint is mounted at runtime or downloaded from HF Hub
21
+ ENV MODEL_CHECKPOINT=/app/model.ckpt
22
+ ENV DEVICE=cpu
23
+
24
+ EXPOSE 8000
25
+
26
+ CMD ["uv", "run", "uvicorn", "convgru_ensemble.serve:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 2-Clause License
2
+
3
+ Copyright (c) 2026, Data Science for Industry and Physics @ FBK
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MODEL_CARD.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-2-clause
3
+ language: en
4
+ tags:
5
+ - weather
6
+ - nowcasting
7
+ - radar
8
+ - precipitation
9
+ - ensemble-forecasting
10
+ - convgru
11
+ - earth-observation
12
+ library_name: pytorch
13
+ pipeline_tag: image-to-image
14
+ ---
15
+
16
+ # IRENE — Italian Radar Ensemble Nowcasting Experiment
17
+
18
+ **IRENE** is a ConvGRU encoder-decoder model for short-term precipitation forecasting (nowcasting) from radar data. The model generates probabilistic ensemble forecasts, producing multiple plausible future scenarios from a single input sequence.
19
+
20
+ ## Model Description
21
+
22
+ - **Architecture**: ConvGRU encoder-decoder with PixelShuffle/PixelUnshuffle for spatial scaling
23
+ - **Input**: Sequence of past radar rain rate fields (T, H, W) in mm/h
24
+ - **Output**: Ensemble of future rain rate forecasts (E, T, H, W) in mm/h
25
+ - **Temporal resolution**: 5 minutes per timestep
26
+ - **Training loss**: Continuous Ranked Probability Score (CRPS) with temporal consistency regularization
27
+
28
+ The model encodes past radar observations into multi-scale hidden states using stacked ConvGRU blocks with PixelUnshuffle downsampling. The decoder generates forecasts by unrolling with different random noise inputs, producing diverse ensemble members that capture forecast uncertainty.
29
+
30
+ ## Intended Uses
31
+
32
+ - Short-term precipitation forecasting (0-60 min ahead) from radar reflectivity data
33
+ - Probabilistic nowcasting with uncertainty quantification via ensemble spread
34
+ - Research on deep learning for weather prediction
35
+ - Fine-tuning on regional radar datasets
36
+
37
+ ## How to Use
38
+
39
+ ```python
40
+ from convgru_ensemble import RadarLightningModel
41
+
42
+ # Load from HuggingFace Hub
43
+ model = RadarLightningModel.from_pretrained("it4lia/irene")
44
+
45
+ # Run inference on past radar data (rain rate in mm/h)
46
+ import numpy as np
47
+ past = np.random.rand(6, 256, 256).astype(np.float32) # 6 past timesteps
48
+ forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
49
+ # forecasts.shape = (10, 12, 256, 256) — 10 ensemble members, 12 future steps
50
+ ```
51
+
52
+ ## Training Data
53
+
54
+ Trained on the Italian DPC (Dipartimento della Protezione Civile) radar mosaic surface rain intensity (SRI) dataset, covering the Italian territory at ~1 km resolution with 5-minute temporal resolution.
55
+
56
+ ## Training Procedure
57
+
58
+ - **Optimizer**: Adam (lr=1e-4)
59
+ - **Loss**: CRPS with temporal consistency penalty (lambda=0.01)
60
+ - **Batch size**: 16
61
+ - **Ensemble size during training**: 2 members
62
+ - **Input window**: 6 past timesteps (30 min)
63
+ - **Forecast horizon**: 12 future timesteps (60 min)
64
+ - **Data augmentation**: Random rotations and flips
65
+ - **NaN handling**: Masked loss for missing radar data
66
+
67
+ ## Limitations
68
+
69
+ - Trained on Italian radar data; performance may degrade on other domains without fine-tuning
70
+ - 5-minute temporal resolution only
71
+ - Best suited for convective and stratiform precipitation; extreme events may be underrepresented
72
+ - Ensemble spread is generated via noisy decoder inputs, not a full Bayesian approach
73
+
74
+ ## Acknowledgements
75
+
76
+ This model was developed as part of the **Italian AI-Factory** (IT4LIA), an EU-funded initiative supporting the adoption of AI across SMEs, academia, and public/private sectors. The AI-Factory provides free HPC compute, consultancy, and AI-ready datasets. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.
77
+
78
+ Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy.
79
+
80
+ ## License
81
+
82
+ BSD 2-Clause License
Makefile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Every target below is a command alias, not a file, so all of them
# (including `format`) are declared phony.
.PHONY: install lint format test serve docker-build docker-run

# Install all dependency groups plus the serving extra.
install:
	uv sync --all-groups --extra serve

lint:
	uv run ruff check .

format:
	uv run ruff format .

test:
	uv run pytest -q

# Run the inference API locally.
serve:
	uv run uvicorn convgru_ensemble.serve:app --host 0.0.0.0 --port 8000

docker-build:
	docker build -t convgru-ensemble .

# Mount ./checkpoints and point the server at the mounted file, mirroring the
# README's "Run with local checkpoint" command.
docker-run:
	docker run -p 8000:8000 -v ./checkpoints:/app/checkpoints \
		-e MODEL_CHECKPOINT=/app/checkpoints/model.ckpt convgru-ensemble
README.md CHANGED
@@ -1,82 +1,250 @@
1
- ---
2
- license: bsd-2-clause
3
- language: en
4
- tags:
5
- - weather
6
- - nowcasting
7
- - radar
8
- - precipitation
9
- - ensemble-forecasting
10
- - convgru
11
- - earth-observation
12
- library_name: pytorch
13
- pipeline_tag: image-to-image
14
- ---
15
 
16
- # IRENE — Italian Radar Ensemble Nowcasting Experiment
17
 
18
- **IRENE** is a ConvGRU encoder-decoder model for short-term precipitation forecasting (nowcasting) from radar data. The model generates probabilistic ensemble forecasts, producing multiple plausible future scenarios from a single input sequence.
 
 
 
19
 
20
- ## Model Description
21
 
22
- - **Architecture**: ConvGRU encoder-decoder with PixelShuffle/PixelUnshuffle for spatial scaling
23
- - **Input**: Sequence of past radar rain rate fields (T, H, W) in mm/h
24
- - **Output**: Ensemble of future rain rate forecasts (E, T, H, W) in mm/h
25
- - **Temporal resolution**: 5 minutes per timestep
26
- - **Training loss**: Continuous Ranked Probability Score (CRPS) with temporal consistency regularization
27
 
28
- The model encodes past radar observations into multi-scale hidden states using stacked ConvGRU blocks with PixelUnshuffle downsampling. The decoder generates forecasts by unrolling with different random noise inputs, producing diverse ensemble members that capture forecast uncertainty.
29
 
30
- ## Intended Uses
31
 
32
- - Short-term precipitation forecasting (0-60 min ahead) from radar reflectivity data
33
- - Probabilistic nowcasting with uncertainty quantification via ensemble spread
34
- - Research on deep learning for weather prediction
35
- - Fine-tuning on regional radar datasets
36
 
37
- ## How to Use
 
 
 
 
 
38
 
39
  ```python
40
  from convgru_ensemble import RadarLightningModel
41
 
42
- # Load from HuggingFace Hub
43
  model = RadarLightningModel.from_pretrained("it4lia/irene")
44
 
45
- # Run inference on past radar data (rain rate in mm/h)
46
  import numpy as np
47
- past = np.random.rand(6, 256, 256).astype(np.float32) # 6 past timesteps
48
  forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
49
- # forecasts.shape = (10, 12, 256, 256) — 10 ensemble members, 12 future steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  ```
51
 
52
- ## Training Data
53
 
54
- Trained on the Italian DPC (Dipartimento della Protezione Civile) radar mosaic surface rain intensity (SRI) dataset, covering the Italian territory at ~1 km resolution with 5-minute temporal resolution.
55
 
56
- ## Training Procedure
57
 
58
- - **Optimizer**: Adam (lr=1e-4)
59
- - **Loss**: CRPS with temporal consistency penalty (lambda=0.01)
60
- - **Batch size**: 16
61
- - **Ensemble size during training**: 2 members
62
- - **Input window**: 6 past timesteps (30 min)
63
- - **Forecast horizon**: 12 future timesteps (60 min)
64
- - **Data augmentation**: Random rotations and flips
65
- - **NaN handling**: Masked loss for missing radar data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- ## Limitations
68
 
69
- - Trained on Italian radar data; performance may degrade on other domains without fine-tuning
70
- - 5-minute temporal resolution only
71
- - Best suited for convective and stratiform precipitation; extreme events may be underrepresented
72
- - Ensemble spread is generated via noisy decoder inputs, not a full Bayesian approach
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  ## Acknowledgements
75
 
76
- This model was developed as part of the **Italian AI-Factory** (IT4LIA), an EU-funded initiative supporting the adoption of AI across SMEs, academia, and public/private sectors. The AI-Factory provides free HPC compute, consultancy, and AI-ready datasets. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.
 
 
 
 
 
 
 
 
 
 
77
 
78
- Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy.
79
 
80
  ## License
81
 
82
- BSD 2-Clause License
 
1
+ <div align="center">
2
+
3
+ # ConvGRU-Ensemble
4
+
5
+ **Ensemble precipitation nowcasting using Convolutional GRU networks**
 
 
 
 
 
 
 
 
 
6
 
7
+ *Pretrained model for Italy:* ***IRENE*** — **I**talian **R**adar **E**nsemble **N**owcasting **E**xperiment
8
 
9
+ [![CI](https://github.com/DSIP-FBK/ConvGRU-Ensemble/actions/workflows/ci.yml/badge.svg)](https://github.com/DSIP-FBK/ConvGRU-Ensemble/actions)
10
+ [![License: BSD-2](https://img.shields.io/badge/license-BSD--2-blue.svg)](LICENSE)
11
+ [![Python 3.13+](https://img.shields.io/badge/python-3.13%2B-blue.svg)](https://python.org)
12
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97-Model-yellow)](https://huggingface.co/it4lia/irene)
13
 
14
+ <br>
15
 
16
+ <a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" height="55" alt="Fondazione Bruno Kessler"></a>
17
+ &nbsp;&nbsp;&nbsp;&nbsp;
18
+ <a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" height="55" alt="IT4LIA AI-Factory"></a>
19
+ &nbsp;&nbsp;&nbsp;&nbsp;
20
+ <a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" height="55" alt="ItaliaMeteo"></a>
21
 
22
+ <br><br>
23
 
24
+ The model encodes past radar frames into multi-scale hidden states and decodes them into an **ensemble of probabilistic forecasts** by running the decoder multiple times with different noise inputs, trained with **CRPS loss**.
25
 
26
+ </div>
 
 
 
27
 
28
+ ---
29
+
30
+ ## Quick Start
31
+
32
+ <details open>
33
+ <summary><b>Load from HuggingFace Hub</b></summary>
34
 
35
  ```python
36
  from convgru_ensemble import RadarLightningModel
37
 
 
38
  model = RadarLightningModel.from_pretrained("it4lia/irene")
39
 
 
40
  import numpy as np
41
+ past = np.load("past_radar.npy") # rain rate in mm/h, shape (T_past, H, W)
42
  forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
43
+ # forecasts.shape = (10, 12, H, W) — 10 members, 12 future steps, mm/h
44
+ ```
45
+
46
+ </details>
47
+
48
+ <details>
49
+ <summary><b>CLI Inference</b></summary>
50
+
51
+ ```bash
52
+ convgru-ensemble predict \
53
+ --input examples/sample_data.nc \
54
+ --hub-repo it4lia/irene \
55
+ --forecast-steps 12 \
56
+ --ensemble-size 10 \
57
+ --output predictions.nc
58
+ ```
59
+
60
+ </details>
61
+
62
+ <details>
63
+ <summary><b>Serve via API</b></summary>
64
+
65
+ ```bash
66
+ # With Docker
67
+ docker compose up
68
+
69
+ # Or directly
70
+ pip install convgru-ensemble[serve]
71
+ convgru-ensemble serve --hub-repo it4lia/irene --port 8000
72
+ ```
73
+
74
+ ```bash
75
+ curl -X POST http://localhost:8000/predict \
76
+ -F "file=@input.nc" -o predictions.nc
77
+ ```
78
+
79
+ | Endpoint | Method | Description |
80
+ |---|---|---|
81
+ | `/health` | GET | Health check |
82
+ | `/model/info` | GET | Model metadata and hyperparameters |
83
+ | `/predict` | POST | Upload NetCDF, get ensemble forecast as NetCDF |
84
+
85
+ </details>
86
+
87
+ <details>
88
+ <summary><b>Fine-tune on your data</b></summary>
89
+
90
+ ```bash
91
+ pip install convgru-ensemble
92
+ # See "Training" section below
93
  ```
94
 
95
+ </details>
96
 
97
+ ## Setup
98
 
99
+ Requires Python >= 3.13. Uses [uv](https://github.com/astral-sh/uv) for dependency management.
100
 
101
+ ```bash
102
+ uv sync # core dependencies
103
+ uv sync --extra serve # + FastAPI serving
104
+ ```
105
+
106
+ ## Data Preparation
107
+
108
+ The training pipeline expects a Zarr dataset with a rain rate variable `RR` indexed by `(time, x, y)`.
109
+
110
+ <details>
111
+ <summary><b>1. Filter valid datacubes</b></summary>
112
+
113
+ Scan the Zarr and find all space-time datacubes with fewer than `n_nan` NaN values:
114
+
115
+ ```bash
116
+ cd importance_sampler
117
+ uv run python filter_nan.py path/to/dataset.zarr \
118
+ --start_date 2021-01-01 --end_date 2025-12-11 \
119
+ --Dt 24 --w 256 --h 256 \
120
+ --step_T 3 --step_X 16 --step_Y 16 \
121
+ --n_nan 10000 --n_workers 8
122
+ ```
123
 
124
+ </details>
125
 
126
+ <details>
127
+ <summary><b>2. Importance sampling</b></summary>
128
+
129
+ Sample valid datacubes with higher probability for rainier events:
130
+
131
+ ```bash
132
+ uv run python sample_valid_datacubes.py path/to/dataset.zarr valid_datacubes_*.csv \
133
+ --q_min 1e-4 --m 0.1 --n_workers 8
134
+ ```
135
+
136
+ A pre-sampled CSV is provided in [`importance_sampler/output/`](importance_sampler/output/).
137
+
138
+ </details>
139
+
140
+ ## Training
141
+
142
+ Training is configured via [Fiddle](https://github.com/google/fiddle). Run with defaults:
143
+
144
+ ```bash
145
+ uv run python -m convgru_ensemble.train
146
+ ```
147
+
148
+ Override parameters from the command line:
149
+
150
+ ```bash
151
+ uv run python -m convgru_ensemble.train \
152
+ --config config:experiment \
153
+ --config set:model.num_blocks=5 \
154
+ --config set:model.forecast_steps=12 \
155
+ --config set:model.loss_class=crps \
156
+ --config set:model.ensemble_size=2 \
157
+ --config set:datamodule.batch_size=16 \
158
+ --config set:trainer.max_epochs=100
159
+ ```
160
+
161
+ Monitor with TensorBoard: `uv run tensorboard --logdir logs/`
162
+
163
+ | Parameter | Description | Default |
164
+ |---|---|---|
165
+ | `model.num_blocks` | Encoder/decoder depth | `5` |
166
+ | `model.forecast_steps` | Future steps to predict | `12` |
167
+ | `model.ensemble_size` | Ensemble members during training | `2` |
168
+ | `model.loss_class` | Loss function (`mse`, `mae`, `crps`, `afcrps`) | `crps` |
169
+ | `model.masked_loss` | Mask NaN regions in loss | `True` |
170
+ | `datamodule.steps` | Total timesteps per sample (past + future) | `18` |
171
+ | `datamodule.batch_size` | Batch size | `16` |
172
+
173
+ ## Architecture
174
+
175
+ ```
176
+ Input (B, T_past, 1, H, W)
177
+ |
178
+ v
179
+ +--------------------------+
180
+ | Encoder | ConvGRU + PixelUnshuffle (x num_blocks)
181
+ | Spatial dims halve at | Channels: 1 -> 4 -> 16 -> 64 -> 256 -> 1024
182
+ | each block |
183
+ +----------+---------------+
184
+ | hidden states
185
+ v
186
+ +--------------------------+
187
+ | Decoder | ConvGRU + PixelShuffle (x num_blocks)
188
+ | Noise input (x M runs) | Each run produces one ensemble member
189
+ | for ensemble generation |
190
+ +----------+---------------+
191
+ |
192
+ v
193
+ Output (B, T_future, M, H, W)
194
+ ```
195
+
196
+ ## Docker
197
+
198
+ ```bash
199
+ docker build -t convgru-ensemble .
200
+
201
+ # Run with local checkpoint
202
+ docker run -p 8000:8000 -v ./checkpoints:/app/checkpoints \
203
+ -e MODEL_CHECKPOINT=/app/checkpoints/model.ckpt convgru-ensemble
204
+
205
+ # Run with HuggingFace Hub
206
+ docker run -p 8000:8000 -e HF_REPO_ID=it4lia/irene convgru-ensemble
207
+ ```
208
+
209
+ ## Project Structure
210
+
211
+ ```
212
+ ConvGRU-Ensemble/
213
+ +-- convgru_ensemble/ # Python package
214
+ | +-- model.py # ConvGRU encoder-decoder architecture
215
+ | +-- losses.py # CRPS, afCRPS, masked loss wrappers
216
+ | +-- lightning_model.py # PyTorch Lightning training module
217
+ | +-- datamodule.py # Dataset and data loading
218
+ | +-- train.py # Training entry point (Fiddle config)
219
+ | +-- utils.py # Rain rate <-> reflectivity conversions
220
+ | +-- hub.py # HuggingFace Hub upload/download
221
+ | +-- cli.py # CLI for inference and serving
222
+ | +-- serve.py # FastAPI inference server
223
+ +-- examples/ # Sample data for testing
224
+ +-- importance_sampler/ # Data preparation scripts
225
+ +-- notebooks/ # Example notebooks
226
+ +-- scripts/ # Utility scripts (e.g., upload to Hub)
227
+ +-- tests/ # Test suite
228
+ +-- Dockerfile # Container for serving API
229
+ +-- MODEL_CARD.md # HuggingFace model card template
230
+ ```
231
 
232
  ## Acknowledgements
233
 
234
+ <div align="center">
235
+
236
+ Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy, as part of the **Italian AI-Factory (IT4LIA)**, an EU-funded initiative supporting AI adoption across SMEs, academia, and public/private sectors. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.
237
+
238
+ <br>
239
+
240
+ <a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" height="45" alt="Fondazione Bruno Kessler"></a>
241
+ &nbsp;&nbsp;&nbsp;&nbsp;
242
+ <a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" height="45" alt="IT4LIA AI-Factory"></a>
243
+ &nbsp;&nbsp;&nbsp;&nbsp;
244
+ <a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" height="45" alt="ItaliaMeteo"></a>
245
 
246
+ </div>
247
 
248
  ## License
249
 
250
+ BSD 2-Clause — see [LICENSE](LICENSE).
convgru_ensemble/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """ConvGRU-Ensemble: Ensemble precipitation nowcasting using Convolutional GRU networks."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from convgru_ensemble.lightning_model import RadarLightningModel
6
+ from convgru_ensemble.model import EncoderDecoder
7
+
8
+ __all__ = ["EncoderDecoder", "RadarLightningModel"]
convgru_ensemble/__main__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Allow running the package as ``python -m convgru_ensemble``."""
2
+
3
+ from convgru_ensemble.cli import main
4
+
5
+ main()
convgru_ensemble/cli.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Command-line interface for ConvGRU-Ensemble inference and serving."""
2
+
3
+ import time
4
+
5
+ import fire
6
+ import numpy as np
7
+ import xarray as xr
8
+
9
+
10
+ def _load_model(checkpoint: str | None = None, hub_repo: str | None = None, device: str = "cpu"):
11
+ """Load model from local checkpoint or HuggingFace Hub."""
12
+ from .lightning_model import RadarLightningModel
13
+
14
+ if hub_repo is not None:
15
+ print(f"Loading model from HuggingFace Hub: {hub_repo}")
16
+ return RadarLightningModel.from_pretrained(hub_repo, device=device)
17
+ elif checkpoint is not None:
18
+ print(f"Loading model from checkpoint: {checkpoint}")
19
+ return RadarLightningModel.from_checkpoint(checkpoint, device=device)
20
+ else:
21
+ raise ValueError("Either --checkpoint or --hub-repo must be provided.")
22
+
23
+
24
def predict(
    input: str,
    checkpoint: str | None = None,
    hub_repo: str | None = None,
    variable: str = "RR",
    forecast_steps: int = 12,
    ensemble_size: int = 10,
    device: str = "cpu",
    output: str = "predictions.nc",
):
    """
    Run inference on a NetCDF input file and save predictions as NetCDF.

    Args:
        input: Path to input NetCDF file with rain rate data (T, H, W) or (T, Y, X).
        checkpoint: Path to local .ckpt checkpoint file.
        hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
        variable: Name of the rain rate variable in the NetCDF file.
        forecast_steps: Number of future timesteps to forecast.
        ensemble_size: Number of ensemble members to generate.
        device: Device for inference ('cpu' or 'cuda').
        output: Path for the output NetCDF file.

    Raises:
        ValueError: If no model source is given, if ``variable`` is not in the
            input file, or if the selected variable is not 3-dimensional.
    """
    model = _load_model(checkpoint, hub_repo, device)

    # Load input data. The context manager closes the file handle on every
    # path, including the validation errors below; ``.values`` materialises
    # the array in memory, so it stays usable after the dataset is closed.
    print(f"Loading input: {input}")
    with xr.open_dataset(input) as ds:
        if variable not in ds:
            available = list(ds.data_vars)
            raise ValueError(f"Variable '{variable}' not found. Available: {available}")

        data = ds[variable].values  # (T, H, W) or similar
        if data.ndim != 3:
            raise ValueError(f"Expected 3D data (T, H, W), got shape {data.shape}")

    print(f"Input shape: {data.shape}")
    past = data.astype(np.float32)

    # Run inference (timed for the user's benefit only)
    t0 = time.perf_counter()
    preds = model.predict(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size)
    elapsed = time.perf_counter() - t0
    print(f"Output shape: {preds.shape} (ensemble, time, H, W)")
    print(f"Elapsed: {elapsed:.2f}s")

    # Build output dataset
    ds_out = xr.Dataset(
        {
            "precipitation_forecast": xr.DataArray(
                data=preds,
                dims=["ensemble_member", "forecast_step", "y", "x"],
                attrs={"units": "mm/h", "long_name": "Ensemble precipitation forecast"},
            ),
        },
        attrs={
            "model": "ConvGRU-Ensemble",
            "forecast_steps": forecast_steps,
            "ensemble_size": ensemble_size,
            "source_file": str(input),
        },
    )

    ds_out.to_netcdf(output)
    print(f"Predictions saved to: {output}")
89
+
90
+
91
def serve(
    checkpoint: str | None = None,
    hub_repo: str | None = None,
    host: str = "0.0.0.0",
    port: int = 8000,
    device: str = "cpu",
):
    """
    Launch the FastAPI inference server under uvicorn.

    Configuration is published through environment variables before the
    server starts (presumably read by ``convgru_ensemble.serve`` -- confirm
    there). ``MODEL_CHECKPOINT`` and ``HF_REPO_ID`` are overwritten whenever
    the matching argument is given; ``DEVICE`` is only filled in when it is
    not already present in the environment, so an existing ``DEVICE`` value
    wins over the ``device`` argument.

    Args:
        checkpoint: Path to local .ckpt checkpoint file.
        hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
        host: Host to bind to.
        port: Port to listen on.
        device: Device for inference ('cpu' or 'cuda').
    """
    import os

    explicit_sources = (("MODEL_CHECKPOINT", checkpoint), ("HF_REPO_ID", hub_repo))
    for env_name, value in explicit_sources:
        if value is not None:
            os.environ[env_name] = value

    if "DEVICE" not in os.environ:
        os.environ["DEVICE"] = device

    # Imported only here so the serving extra is needed just for this command.
    import uvicorn

    uvicorn.run("convgru_ensemble.serve:app", host=host, port=port)
119
+
120
+
121
def main():
    """CLI entry point: expose ``predict`` and ``serve`` as Fire subcommands."""
    subcommands = {"predict": predict, "serve": serve}
    fire.Fire(subcommands)
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
convgru_ensemble/datamodule.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import xarray as xr
8
+ from torch.utils.data import DataLoader, Dataset
9
+
10
+ from .utils import rainrate_to_normalized
11
+
12
+
13
+ class SampledRadarDataset(Dataset):
14
+ """
15
+ PyTorch dataset that loads radar datacubes from a Zarr store using
16
+ pre-sampled spatial-temporal coordinates from a CSV file.
17
+
18
+ Each sample is a spatio-temporal datacube of shape ``(T, 1, H, W)``
19
+ converted from rain rate to normalized reflectivity.
20
+
21
+ Parameters
22
+ ----------
23
+ zarr_path : str
24
+ Path to the Zarr dataset containing the ``'RR'`` rain rate variable.
25
+ csv_path : str
26
+ Path to the CSV file with columns ``(t, x, y)`` specifying the
27
+ top-left corner of each datacube.
28
+ steps : int
29
+ Number of timesteps to extract per sample.
30
+ return_mask : bool, optional
31
+ If ``True``, also return a spatial NaN mask. Default is ``False``.
32
+ deterministic : bool, optional
33
+ If ``True``, use a fixed random seed (42) for reproducibility.
34
+ Default is ``False``.
35
+ augment : bool, optional
36
+ If ``True``, apply random spatial augmentations (rotation, flips).
37
+ Default is ``False``.
38
+ indices : sequence of int or None, optional
39
+ Subset of row indices to use from the CSV. If ``None``, use all rows.
40
+ Default is ``None``.
41
+ """
42
+
43
def __init__(
    self,
    zarr_path: str,
    csv_path: str,
    steps: int,
    return_mask: bool = False,
    deterministic: bool = False,
    augment: bool = False,
    indices=None,
):
    """
    Initialize SampledRadarDataset.

    Parameters
    ----------
    zarr_path : str
        Path to the Zarr dataset containing the ``'RR'`` rain rate
        variable.
    csv_path : str
        Path to the CSV file with columns ``(t, x, y)``.
    steps : int
        Number of timesteps to extract per sample.
    return_mask : bool, optional
        If ``True``, also return a spatial NaN mask. Default is ``False``.
    deterministic : bool, optional
        If ``True``, use a fixed random seed (42). Otherwise the seed is
        taken from the nanosecond wall clock. Default is ``False``.
    augment : bool, optional
        If ``True``, apply random spatial augmentations. Default is
        ``False``.
    indices : sequence of int or None, optional
        Positional row indices into the CSV *after* it has been sorted by
        ``t``. Default is ``None`` (use all rows).
    """
    self.coords = pd.read_csv(csv_path).sort_values("t")
    if indices is not None:
        self.coords = self.coords.iloc[list(indices)].reset_index(drop=True)
    self.zg = xr.open_zarr(zarr_path)
    self.RR = self.zg["RR"]

    # A whole-second timestamp gives every dataset created within the same
    # second an identical random stream; nanosecond resolution avoids that.
    seed = 42 if deterministic else time.time_ns()
    self.rng = np.random.default_rng(seed)
    self.return_mask = return_mask
    self.augment = augment

    if augment:
        print("Data augmentation is enabled.")

    # default valid grid size and time step
    self.w = 256
    self.h = 256
    self.dt = 24
    self.steps = steps

    # Only a warning is printed when steps > dt; construction still succeeds.
    if self.steps > self.dt:
        print(f"Warning: requested steps ({self.steps}) > sampled time window ({self.dt})")
96
+
97
def __len__(self):
    """
    Count the datacube samples available in this dataset.

    Returns
    -------
    length : int
        Number of rows in the coordinate table, one per sample.
    """
    n_samples = len(self.coords)
    return n_samples
107
+
108
def shape(self):
    """
    Report the nominal shape of the dataset as a whole.

    Returns
    -------
    shape : tuple of int
        ``(num_samples, steps, 1, width, height)``.
    """
    n_samples = len(self.coords)
    spatial = (self.w, self.h)
    return (n_samples, self.steps, 1, *spatial)
118
+
119
def _apply_augmentations(
    self, *tensors, rotate_prob: float = 0.5, hflip_prob: float = 0.5, vflip_prob: float = 0.5
):
    """
    Apply one shared random spatial transform to every given tensor.

    The random draws happen once per call, so all inputs (e.g. data and
    mask) receive exactly the same rotation/flips and stay aligned.

    Parameters
    ----------
    *tensors : torch.Tensor
        One or more tensors whose last two dimensions are spatial.
    rotate_prob : float, optional
        Probability of rotating by a random non-zero multiple of 90
        degrees. Default is ``0.5``.
    hflip_prob : float, optional
        Probability of flipping along the last dimension. Default is
        ``0.5``.
    vflip_prob : float, optional
        Probability of flipping along the second-to-last dimension.
        Default is ``0.5``.

    Returns
    -------
    augmented : torch.Tensor or tuple of torch.Tensor
        The single tensor when exactly one input was given, otherwise a
        tuple with one contiguous tensor per input.
    """
    views = tensors

    # Rotation draw first, then (only if taken) the number of quarter turns.
    if self.rng.random() < rotate_prob:
        quarter_turns = self.rng.integers(1, 4)  # 1, 2 or 3 -> 90/180/270 degrees
        views = [torch.rot90(v, quarter_turns, dims=[-2, -1]) for v in views]

    # Horizontal flip (last axis) is drawn before vertical flip (second-to-last).
    for axis, probability in ((-1, hflip_prob), (-2, vflip_prob)):
        if self.rng.random() < probability:
            views = [torch.flip(v, dims=[axis]) for v in views]

    result = tuple(v.contiguous() for v in views)
    if len(result) == 1:
        return result[0]
    return result
161
+
162
def __getitem__(self, idx: int):
    """
    Load one datacube sample and return it as tensors.

    The row ``idx`` of ``self.coords`` gives the ``(t, x, y)`` origin of the
    cube. When fewer than ``self.dt`` steps are requested, the start time is
    drawn at random inside the available window.

    Parameters
    ----------
    idx : int
        Positional index of the sample in ``self.coords``.

    Returns
    -------
    sample : dict of str to torch.Tensor
        Always contains ``'data'`` with a singleton channel axis inserted
        after the time axis, i.e. ``(T, 1, w, h)`` for the cropped window.
        If ``return_mask`` is ``True`` it also contains ``'mask'`` with a
        single time step, i.e. ``(1, 1, w, h)``; a mask value of 1 marks a
        pixel with no NaN at any time step.
    """
    t0, x0, y0 = self.coords.iloc[idx]

    if self.steps < self.dt:
        # random sampling within the available time window
        t_start = self.rng.integers(t0, t0 + self.dt - self.steps + 1)
    else:
        t_start = t0

    window = (
        slice(t_start, t_start + self.steps),
        slice(x0, x0 + self.w),
        slice(y0, y0 + self.h),
    )
    cube = rainrate_to_normalized(self.RR[window])

    if not self.return_mask:
        # NaNs become -1 before conversion to a tensor
        frames = torch.from_numpy(np.nan_to_num(cube, nan=-1.0)[:, np.newaxis, :, :])
        if self.augment:
            frames = self._apply_augmentations(frames)
        return {"data": frames}

    # The mask must be computed before NaNs are overwritten. It keeps a single
    # time step and relies on broadcasting downstream.
    valid = (~(np.isnan(cube).any(axis=0, keepdims=True))).astype(np.float32)

    frames = torch.from_numpy(np.nan_to_num(cube, nan=-1.0)[:, np.newaxis, :, :])
    mask = torch.from_numpy(valid.values[:, np.newaxis, :, :])

    if self.augment:
        # data and mask share one random transform so they stay aligned
        frames, mask = self._apply_augmentations(frames, mask)

    return {"data": frames, "mask": mask}
216
+
217
+
218
class RadarDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning data module for radar datacube datasets.

    Handles train/val/test splitting and DataLoader creation from a single
    Zarr store and CSV coordinate file.

    Parameters
    ----------
    zarr_path : str
        Path to the Zarr dataset.
    csv_path : str
        Path to the CSV file with datacube coordinates.
    steps : int
        Number of timesteps per sample.
    train_ratio : float, optional
        Fraction of data used for training. Default is ``0.7``.
    val_ratio : float, optional
        Fraction of data used for validation. Default is ``0.15``.
    return_mask : bool, optional
        Whether to return NaN masks. Default is ``False``.
    deterministic : bool, optional
        Whether to use fixed random seeds. Default is ``False``.
    augment : bool, optional
        Whether to apply data augmentation (training set only). Default is
        ``True``.
    **dataloader_kwargs
        Additional keyword arguments forwarded to ``DataLoader`` (e.g.
        ``batch_size``, ``num_workers``, ``pin_memory``).

    Raises
    ------
    ValueError
        If either ratio lies outside ``[0, 1]`` or if
        ``train_ratio + val_ratio`` exceeds 1.
    """

    def __init__(
        self,
        zarr_path,
        csv_path,
        steps,
        train_ratio=0.7,
        val_ratio=0.15,
        return_mask=False,
        deterministic=False,
        augment=True,
        **dataloader_kwargs,
    ):
        """
        Store the configuration and validate the split ratios.

        See the class docstring for the meaning of each parameter.

        Raises
        ------
        ValueError
            If the split ratios cannot describe a valid partition.
        """
        super().__init__()

        # Fail fast: an impossible split would otherwise only surface later as
        # an out-of-bounds positional index inside the dataset constructor.
        if not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1):
            raise ValueError(
                f"train_ratio and val_ratio must be within [0, 1], got {train_ratio} and {val_ratio}"
            )
        if train_ratio + val_ratio > 1 + 1e-9:  # tolerance for float round-off
            raise ValueError(
                f"train_ratio + val_ratio must not exceed 1, got {train_ratio} + {val_ratio}"
            )

        self.zarr_path = zarr_path
        self.csv_path = csv_path
        self.steps = steps
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.dataloader_kwargs = dataloader_kwargs
        self.return_mask = return_mask
        self.deterministic = deterministic
        self.augment = augment

    def _make_dataset(self, indices, augment):
        """Build a ``SampledRadarDataset`` restricted to ``indices`` with the shared settings."""
        return SampledRadarDataset(
            self.zarr_path,
            self.csv_path,
            self.steps,
            self.return_mask,
            self.deterministic,
            augment=augment,
            indices=indices,
        )

    def setup(self, stage=None):
        """
        Create train, validation, and test datasets from the CSV coordinates.

        Splits are contiguous index ranges over the rows as ordered by the
        dataset: the first ``train_ratio`` fraction is used for training,
        the next ``val_ratio`` for validation, and the rest for testing.
        Augmentation is only applied to the training set.

        Parameters
        ----------
        stage : str or None, optional
            Lightning stage (``'fit'``, ``'test'``, etc.). Ignored; all
            datasets are always created. Default is ``None``.
        """
        # Only the row count is needed here; each dataset reads and orders the
        # CSV itself before applying its index range.
        n = len(pd.read_csv(self.csv_path))

        train_end = int(n * self.train_ratio)
        # Clamp so float round-off can never push the boundary past the data.
        val_end = min(n, int(n * (self.train_ratio + self.val_ratio)))

        self.train_dataset = self._make_dataset(range(0, train_end), augment=self.augment)
        self.val_dataset = self._make_dataset(range(train_end, val_end), augment=False)
        self.test_dataset = self._make_dataset(range(val_end, n), augment=False)

    def train_dataloader(self):
        """
        Return the training DataLoader.

        Returns
        -------
        loader : DataLoader
            DataLoader over the training dataset with shuffling enabled.
        """
        return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs)

    def val_dataloader(self):
        """
        Return the validation DataLoader.

        Returns
        -------
        loader : DataLoader
            DataLoader over the validation dataset without shuffling.
        """
        return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs)

    def test_dataloader(self):
        """
        Return the test DataLoader.

        Returns
        -------
        loader : DataLoader
            DataLoader over the test dataset without shuffling.
        """
        return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs)
convgru_ensemble/hub.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace Hub integration for uploading and downloading ConvGRU-Ensemble models."""
2
+
3
+ import json
4
+ import shutil
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ from huggingface_hub import HfApi, hf_hub_download
9
+
10
+
11
+ def push_to_hub(
12
+ checkpoint_path: str,
13
+ repo_id: str,
14
+ model_card_path: str | None = None,
15
+ private: bool = False,
16
+ ) -> str:
17
+ """
18
+ Upload a trained model checkpoint to HuggingFace Hub.
19
+
20
+ Parameters
21
+ ----------
22
+ checkpoint_path : str
23
+ Path to the ``.ckpt`` checkpoint file.
24
+ repo_id : str
25
+ HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
26
+ model_card_path : str or None, optional
27
+ Path to a model card markdown file. If provided, it is uploaded
28
+ as ``README.md``. Default is ``None``.
29
+ private : bool, optional
30
+ Whether to create a private repository. Default is ``False``.
31
+
32
+ Returns
33
+ -------
34
+ url : str
35
+ URL of the uploaded model on HuggingFace Hub.
36
+ """
37
+ import torch
38
+
39
+ api = HfApi()
40
+ api.create_repo(repo_id=repo_id, exist_ok=True, private=private)
41
+
42
+ with tempfile.TemporaryDirectory() as tmp_dir:
43
+ tmp_path = Path(tmp_dir)
44
+
45
+ # Copy checkpoint
46
+ shutil.copy2(checkpoint_path, tmp_path / "model.ckpt")
47
+
48
+ # Extract and save model config from checkpoint
49
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
50
+ if "hyper_parameters" in ckpt:
51
+ hparams = ckpt["hyper_parameters"]
52
+ # Convert non-serializable values to strings
53
+ config = {}
54
+ for k, v in hparams.items():
55
+ try:
56
+ json.dumps(v)
57
+ config[k] = v
58
+ except (TypeError, ValueError):
59
+ config[k] = str(v)
60
+ with open(tmp_path / "config.json", "w") as f:
61
+ json.dump(config, f, indent=2)
62
+
63
+ # Copy model card as README.md
64
+ if model_card_path is not None:
65
+ shutil.copy2(model_card_path, tmp_path / "README.md")
66
+
67
+ url = api.upload_folder(
68
+ folder_path=str(tmp_path),
69
+ repo_id=repo_id,
70
+ commit_message="Upload ConvGRU-Ensemble model",
71
+ )
72
+
73
+ return url
74
+
75
+
76
def from_pretrained(
    repo_id: str,
    filename: str = "model.ckpt",
    device: str = "cpu",
) -> "RadarLightningModel":  # noqa: F821
    """
    Fetch a checkpoint from HuggingFace Hub and load it as a model.

    Parameters
    ----------
    repo_id : str
        HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
    filename : str, optional
        Name of the checkpoint file inside the repository. Default is
        ``'model.ckpt'``.
    device : str, optional
        Device the checkpoint weights are mapped to. Default is ``'cpu'``.

    Returns
    -------
    model : RadarLightningModel
        Model restored via ``RadarLightningModel.from_checkpoint``.
    """
    # Deferred import: the model class is only needed once a download is requested.
    from .lightning_model import RadarLightningModel

    local_checkpoint = hf_hub_download(repo_id=repo_id, filename=filename)
    model = RadarLightningModel.from_checkpoint(local_checkpoint, device=device)
    return model
convgru_ensemble/lightning_model.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ import torchvision
7
+
8
+ from .losses import build_loss
9
+ from .model import EncoderDecoder
10
+ from .utils import normalized_to_rainrate, rainrate_to_normalized
11
+
12
+
13
def apply_radar_colormap(tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert grayscale radar values to RGB using the STEPS-BE colorscale.

    The range from ``10 / 60`` to ``1.0`` is divided into 14 equal-width
    bins, one per palette color. Values below ``10 / 60`` (10 dBZ when
    [0, 1] represents 0-60 dBZ) stay white, and everything at or above the
    start of the last bin gets the last color. NaN pixels match no bin and
    therefore also stay white.

    Parameters
    ----------
    tensor : torch.Tensor
        Grayscale tensor of shape ``(N, 1, H, W)``; only channel 0 is read.

    Returns
    -------
    rgb : torch.Tensor
        RGB tensor of shape ``(N, 3, H, W)`` with values in [0, 1].
    """
    dev = tensor.device

    # STEPS-BE palette, scaled from 0-255 to 0-1
    palette = (
        torch.tensor(
            [
                [0, 255, 255],  # cyan
                [0, 191, 255],  # deepskyblue
                [30, 144, 255],  # dodgerblue
                [0, 0, 255],  # blue
                [127, 255, 0],  # chartreuse
                [50, 205, 50],  # limegreen
                [0, 128, 0],  # green
                [0, 100, 0],  # darkgreen
                [255, 255, 0],  # yellow
                [255, 215, 0],  # gold
                [255, 165, 0],  # orange
                [255, 0, 0],  # red
                [255, 0, 255],  # magenta
                [139, 0, 139],  # darkmagenta
            ],
            dtype=torch.float32,
            device=dev,
        )
        / 255.0
    )

    # One more edge than colors: edges[i]..edges[i + 1] delimits bin i.
    n_bins = len(palette)
    edges = torch.linspace(10 / 60, 1.0, n_bins + 1, device=dev)

    batch, _, height, width = tensor.shape
    values = tensor[:, 0]

    # Paint channel-last so a (N, H, W) boolean mask selects whole RGB triplets;
    # white is the background for anything that matches no bin.
    canvas = torch.ones(batch, height, width, 3, dtype=torch.float32, device=dev)

    for idx in range(n_bins - 1):
        in_bin = (values >= edges[idx]) & (values < edges[idx + 1])
        canvas[in_bin] = palette[idx]

    # The last bin is open-ended so the maximum value (and anything above) is colored.
    canvas[values >= edges[n_bins - 1]] = palette[-1]

    return canvas.permute(0, 3, 1, 2).contiguous()
80
+
81
+
82
class RadarLightningModel(pl.LightningModule):
    """
    PyTorch Lightning module for radar precipitation nowcasting.

    Wraps an :class:`EncoderDecoder` model and handles training, validation,
    and test steps including loss computation, ensemble generation, and
    image logging through ``self.logger.experiment.add_image``.

    Parameters
    ----------
    input_channels : int
        Number of input channels, forwarded to :class:`EncoderDecoder`.
    num_blocks : int
        Number of blocks, forwarded to :class:`EncoderDecoder`.
    ensemble_size : int, optional
        Number of ensemble members to generate. Default is ``1``.
    noisy_decoder : bool, optional
        Forwarded to the model on every forward pass. Default is ``False``.
    forecast_steps : int or None, optional
        Number of trailing timesteps of each batch treated as the forecast
        target. Must be set for training/validation/testing, because
        ``shared_step`` slices with it. Default is ``None``.
    loss_class : type, str, or None, optional
        Loss class or name, forwarded to ``build_loss``. Default is ``None``.
    loss_params : dict or None, optional
        Loss keyword arguments, forwarded to ``build_loss``. Default is
        ``None``.
    masked_loss : bool, optional
        Forwarded to ``build_loss``; when ``True`` the criterion is also
        called with the batch mask. Default is ``False``.
    optimizer_class : type or None, optional
        Optimizer class. Default is ``None`` (Adam).
    optimizer_params : dict or None, optional
        Keyword arguments for the optimizer. Default is ``None``.
    lr_scheduler_class : type or None, optional
        Learning rate scheduler class. Default is ``None``.
    lr_scheduler_params : dict or None, optional
        Keyword arguments for the LR scheduler. Default is ``None``.
    """

    def __init__(
        self,
        input_channels: int,
        num_blocks: int,
        ensemble_size: int = 1,
        noisy_decoder: bool = False,
        forecast_steps: type | int | None = None,
        loss_class: type | str | None = None,
        loss_params: dict[str, Any] | None = None,
        masked_loss: bool = False,
        optimizer_class: type | None = None,
        optimizer_params: dict[str, Any] | None = None,
        lr_scheduler_class: type | None = None,
        lr_scheduler_params: dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize RadarLightningModel.

        All arguments are captured by ``save_hyperparameters`` and read back
        through ``self.hparams``; see the class docstring for their meaning.
        """
        super().__init__()
        self.save_hyperparameters()

        # Initialize model
        self.model = EncoderDecoder(self.hparams.input_channels, self.hparams.num_blocks)

        self.criterion = build_loss(
            loss_class=self.hparams.loss_class,
            loss_params=self.hparams.loss_params,
            masked_loss=self.hparams.masked_loss,
        )
        # Global steps at which training images are logged; afterwards images
        # are logged every `log_images_iterations[-1]` steps (see shared_step).
        self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000]

        if self.hparams.ensemble_size > 1:
            print(f"Using ensemble mode: {self.hparams.ensemble_size} independent ensemble members will be generated.")

    def forward(self, x: torch.Tensor, forecast_steps: int, ensemble_size: int | None = None) -> torch.Tensor:
        """
        Run the encoder-decoder forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        forecast_steps : int
            Number of future timesteps to forecast.
        ensemble_size : int or None, optional
            Number of ensemble members. If ``None``, uses the value from
            ``hparams``. Default is ``None``.

        Returns
        -------
        preds : torch.Tensor
            Output of the wrapped model. Presumably
            ``(B, forecast_steps, C, H, W)`` or
            ``(B, forecast_steps, ensemble_size, H, W)`` for ensembles;
            confirm against ``EncoderDecoder``.
        """
        ensemble_size = self.hparams.ensemble_size if ensemble_size is None else ensemble_size
        return self.model(
            x, steps=forecast_steps, noisy_decoder=self.hparams.noisy_decoder, ensemble_size=ensemble_size
        )

    def shared_step(
        self, batch: dict[str, torch.Tensor], split: str = "train", ensemble_size: int | None = None
    ) -> torch.Tensor:
        """
        Shared forward step used during training, validation, and testing.

        Splits the input into past and future along the time axis, runs the
        model, clamps predictions to ``[-1, 1]``, computes the loss, and logs
        metrics and (for training only) images.

        Parameters
        ----------
        batch : dict of str to torch.Tensor
            Batch dictionary with key ``'data'`` of shape
            ``(B, T_total, C, H, W)``. ``'mask'`` is required when
            ``hparams.masked_loss`` is ``True``.
        split : str, optional
            One of ``'train'``, ``'val'``, or ``'test'``. Used as metric
            prefix; step-level logging and image logging happen only for
            ``'train'``. Default is ``'train'``.
        ensemble_size : int or None, optional
            Override for the number of ensemble members. Default is ``None``.

        Returns
        -------
        loss : torch.Tensor
            Loss value returned by the criterion (first element if the
            criterion returns a tuple).
        """
        data = batch["data"]
        # The last `forecast_steps` frames are the target; everything before is input.
        past = data[:, : -self.hparams.forecast_steps]
        future = data[:, -self.hparams.forecast_steps :]

        preds = self(past, forecast_steps=self.hparams.forecast_steps, ensemble_size=ensemble_size).clamp(
            min=-1, max=1
        )  # Ensure predictions are within [-1, 1]

        if self.hparams.masked_loss:
            # A mask with a single time step survives this slice unchanged.
            mask = batch["mask"][:, -self.hparams.forecast_steps :]
            loss = self.criterion(preds, future, mask)
        else:
            loss = self.criterion(preds, future)

        # Handle tuple return from composite losses
        if isinstance(loss, tuple):
            loss, log_dict = loss
            # log_dict is logged as-is; its keys are expected to be ready for logging
            self.log_dict(
                log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True
            )

        self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True)

        # Log ensemble diversity; gated on the configured ensemble size, not on
        # the `ensemble_size` override passed to this call.
        if self.hparams.ensemble_size > 1:
            ensemble_std = preds.std(dim=2).mean()  # std across ensemble members
            self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True)

        # Images are logged at the listed early steps and then at every multiple
        # of the last listed step (which includes global step 0).
        if split == "train" and (
            self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0
        ):
            self.log_images(past, future, preds, split=split)
        return loss

    def log_images(self, past: torch.Tensor, future: torch.Tensor, preds: torch.Tensor, split: str = "val") -> None:
        """
        Log radar image grids via ``self.logger.experiment.add_image``.

        Visualizes the first sample in the batch, showing past frames, ground
        truth future, ensemble average, and up to three individual ensemble
        members using the STEPS-BE radar colormap. Requires a logger whose
        ``experiment`` exposes ``add_image`` (e.g. TensorBoard).

        Parameters
        ----------
        past : torch.Tensor
            Past input frames of shape ``(B, T_past, C, H, W)``.
        future : torch.Tensor
            Ground truth future frames of shape ``(B, T_future, C, H, W)``.
        preds : torch.Tensor
            Predicted frames of shape ``(B, T_future, C_or_E, H, W)``.
        split : str, optional
            Split name used as image tag prefix. Default is ``'val'``.
        """
        # Log first sample in the batch
        sample_idx = 0

        # Log past separately
        past_sample = past[sample_idx]
        if self.hparams.ensemble_size > 1:
            past_sample = past_sample.mean(dim=1, keepdim=True)
        # Map values from [-1, 1] to [0, 1] as expected by the colormap
        past_norm = (past_sample + 1) / 2
        past_rgb = apply_radar_colormap(past_norm)
        past_grid = torchvision.utils.make_grid(past_rgb, nrow=past_sample.shape[0])
        self.logger.experiment.add_image(f"{split}/past", past_grid, self.global_step)

        # Create combined preds grid: future (ground truth) as first row, then avg + ensemble members
        future_sample = future[sample_idx]  # (T, C, H, W)
        preds_sample = preds[sample_idx]  # (T, E, H, W) or (T, C, H, W)

        if self.hparams.ensemble_size > 1:
            # Layout: rows = [future, avg, member0, member1, ...], cols = timesteps
            preds_avg = preds_sample.mean(dim=1, keepdim=True)  # (T, E, H, W) -> (T, 1, H, W)
            num_members_to_log = min(3, preds_sample.shape[1])

            # Collect all rows: future first, then average, then individual members
            rows = [future_sample]  # (T, 1, H, W)
            rows.append(preds_avg)  # (T, 1, H, W)
            for i in range(num_members_to_log):
                rows.append(preds_sample[:, i : i + 1, :, :])  # (T, 1, H, W)

            # Stack all rows: (num_rows * T, 1, H, W)
            all_frames = torch.cat(rows, dim=0)  # ((2 + num_members) * T, 1, H, W)
            all_frames_norm = (all_frames + 1) / 2
            all_frames_rgb = apply_radar_colormap(all_frames_norm)
            grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0])
            self.logger.experiment.add_image(f"{split}/preds", grid, self.global_step)
        else:
            # For non-ensemble: show future and preds in two rows
            rows = [future_sample, preds_sample]  # Each is (T, C, H, W)
            all_frames = torch.cat(rows, dim=0)  # (2 * T, C, H, W)
            all_frames_norm = (all_frames + 1) / 2
            all_frames_rgb = apply_radar_colormap(all_frames_norm)
            grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0])
            self.logger.experiment.add_image(f"{split}/preds", grid, self.global_step)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Execute a single training step.

        Parameters
        ----------
        batch : dict of str to torch.Tensor
            Training batch.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        loss : torch.Tensor
            Training loss.
        """
        loss = self.shared_step(batch, split="train")
        return loss

    def validation_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """
        Execute a single validation step.

        Always evaluates with 10 ensemble members, regardless of
        ``hparams.ensemble_size``.

        Parameters
        ----------
        batch : dict of str to torch.Tensor
            Validation batch.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        loss : torch.Tensor
            Validation loss.
        """
        loss = self.shared_step(batch, split="val", ensemble_size=10)
        return loss

    def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Execute a single test step.

        Always evaluates with 10 ensemble members, regardless of
        ``hparams.ensemble_size``.

        Parameters
        ----------
        batch : dict of str to torch.Tensor
            Test batch.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        loss : torch.Tensor
            Test loss.
        """
        loss = self.shared_step(batch, split="test", ensemble_size=10)
        return loss

    def configure_optimizers(self) -> dict[str, Any]:
        """
        Configure the optimizer and optional learning rate scheduler.

        Falls back to Adam with default parameters if no optimizer is
        specified. If a scheduler is provided, it monitors ``val_loss``.

        Returns
        -------
        config : dict
            Dictionary with ``'optimizer'`` and optionally ``'lr_scheduler'``
            keys, as expected by PyTorch Lightning.
        """
        if self.hparams.optimizer_class is not None:
            optimizer = (
                self.hparams.optimizer_class(self.parameters(), **self.hparams.optimizer_params)
                if self.hparams.optimizer_params is not None
                else self.hparams.optimizer_class(self.parameters())
            )
            print(
                f"Using optimizer: {self.hparams.optimizer_class.__name__} with params {self.hparams.optimizer_params}"
            )
        else:
            optimizer = torch.optim.Adam(self.parameters())
            print("Using default Adam optimizer with default parameters.")

        if self.hparams.lr_scheduler_class is not None:
            lr_scheduler = (
                self.hparams.lr_scheduler_class(optimizer, **self.hparams.lr_scheduler_params)
                if self.hparams.lr_scheduler_params is not None
                else self.hparams.lr_scheduler_class(optimizer)
            )
            print(
                f"Using LR scheduler: {self.hparams.lr_scheduler_class.__name__} with params {self.hparams.lr_scheduler_params}"
            )
            return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}}
        else:
            return {"optimizer": optimizer}

    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "RadarLightningModel":
        """
        Load a model from a checkpoint file.

        Parameters
        ----------
        checkpoint_path : str
            Path to the ``.ckpt`` checkpoint file.
        device : str, optional
            Device to map the checkpoint weights to. Default is ``'cpu'``.

        Returns
        -------
        model : RadarLightningModel
            Model with loaded weights.
        """
        # NOTE(review): weights_only=False allows full unpickling of the
        # checkpoint; only load checkpoints from trusted sources.
        return cls.load_from_checkpoint(
            checkpoint_path,
            map_location=torch.device(device),
            strict=True,
            weights_only=False,
        )

    @classmethod
    def from_pretrained(cls, repo_id: str, filename: str = "model.ckpt", device: str = "cpu") -> "RadarLightningModel":
        """
        Load a pretrained model from HuggingFace Hub.

        Thin wrapper around ``hub.from_pretrained``.

        Parameters
        ----------
        repo_id : str
            HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
        filename : str, optional
            Name of the checkpoint file in the repository. Default is
            ``'model.ckpt'``.
        device : str, optional
            Device to map the model weights to. Default is ``'cpu'``.

        Returns
        -------
        model : RadarLightningModel
            Model with loaded pretrained weights.
        """
        # Imported here rather than at module level; hub.py imports this module lazily as well.
        from .hub import from_pretrained

        return from_pretrained(repo_id, filename, device)

    def predict(self, past: np.ndarray, forecast_steps: int = 1, ensemble_size: int | None = 1) -> np.ndarray:
        """
        Generate precipitation forecasts from past radar observations.

        Handles padding, NaN removal, unit conversion, and reshaping
        automatically. Input should be raw rain rate values. The module is
        switched to eval mode and is not switched back afterwards.

        Parameters
        ----------
        past : np.ndarray
            Past radar frames as rain rate in mm/h, of shape ``(T, H, W)``.
            The array is processed with NumPy before being converted to a
            tensor (NOTE(review): the previous annotation said
            ``torch.Tensor``; confirm whether tensor input must be supported).
        forecast_steps : int, optional
            Number of future timesteps to forecast. Default is ``1``.
        ensemble_size : int or None, optional
            Number of ensemble members to generate. If ``None``, uses the
            value from ``hparams``. Default is ``1``.

        Returns
        -------
        preds : np.ndarray
            Forecasted rain rate in mm/h, of shape
            ``(ensemble_size, forecast_steps, H, W)``, cropped back to the
            original ``H`` and ``W``.

        Raises
        ------
        ValueError
            If ``past`` does not have exactly 3 dimensions.
        """
        if len(past.shape) != 3:
            raise ValueError("Input must be of shape (T, H, W)")

        T, H, W = past.shape
        ensemble_size = self.hparams.ensemble_size if ensemble_size is None else ensemble_size

        # Pad H and W up to the next multiple of 2**num_blocks (presumably
        # because each block halves the spatial resolution; confirm against
        # EncoderDecoder). Padding is added at the bottom/right with zeros.
        divisor = 2 ** (self.hparams.num_blocks)
        padH = (divisor - (H % divisor)) % divisor
        padW = (divisor - (W % divisor)) % divisor
        padded_past = past
        if padH != 0 or padW != 0:
            padded_past = np.pad(past, ((0, 0), (0, padH), (0, padW)), mode="constant", constant_values=0)

        # Remove NaN (np.nan_to_num replaces NaN with 0 by default)
        past_clean = np.nan_to_num(padded_past)

        # Reshape the input to (B, T, C, H, W)
        past_clean = past_clean[np.newaxis, :, np.newaxis, ...]

        # Rainrate to normalized reflectivity
        norm_past = rainrate_to_normalized(past_clean)

        # Numpy to torch tensor
        x = torch.from_numpy(norm_past)

        # Move to device
        x = x.to(self.device)

        # Forward pass (calls the wrapped model directly, bypassing self.forward)
        self.eval()
        with torch.no_grad():
            preds = self.model(x, forecast_steps, self.hparams.noisy_decoder, ensemble_size)

        # Move to CPU
        preds = preds.cpu()

        # Tensor to numpy array
        preds = preds.numpy()

        # Rescale back to rain rate
        preds = normalized_to_rainrate(preds)

        # Remove the batch (T, E, H, W)
        preds = preds.squeeze(0)

        # Swap the Time and Ensemble dimensions (E, T, H, W)
        preds = np.swapaxes(preds, 0, 1)

        # Remove the padding
        preds = preds[..., :H, :W]

        return preds
convgru_ensemble/losses.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class LossWithReduction(nn.Module):
    """
    Base class for losses with a configurable reduction.

    Subclasses compute an unreduced loss tensor and pass it through
    :meth:`apply_reduction`.

    Parameters
    ----------
    reduction : str, optional
        Reduction mode to apply to the loss. Must be one of ``'mean'``,
        ``'sum'``, or ``'none'``. Default is ``'mean'``.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported modes.
    """

    # Supported reduction modes, in the order they are documented.
    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, reduction: str = "mean"):
        """
        Initialize LossWithReduction.

        Parameters
        ----------
        reduction : str, optional
            Reduction mode to apply to the loss. Must be one of ``'mean'``,
            ``'sum'``, or ``'none'``. Default is ``'mean'``.

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the supported modes.
        """
        super().__init__()
        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would let an invalid mode through and make
        # ``apply_reduction`` silently behave like ``'none'``.
        if reduction not in self._REDUCTIONS:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction

    def apply_reduction(self, loss: torch.Tensor) -> torch.Tensor:
        """
        Apply the configured reduction to the loss tensor.

        Parameters
        ----------
        loss : torch.Tensor
            Loss tensor to reduce.

        Returns
        -------
        reduced_loss : torch.Tensor
            ``loss.mean()`` for ``'mean'``, ``loss.sum()`` for ``'sum'``,
            and ``loss`` itself (unchanged) for ``'none'``.
        """
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # 'none'
52
+
53
+
54
class MaskedLoss(LossWithReduction):
    """
    Wrapper that applies a validity mask to an element-wise loss.

    The wrapped loss is evaluated on every element, multiplied by the mask,
    and then reduced so that only valid elements contribute.

    Parameters
    ----------
    elementwise_loss : nn.Module
        Base loss function to be masked. Must accept ``(preds, target)`` and
        return element-wise (unreduced) loss. Should be instantiated with
        ``reduction='none'``.
    reduction : str, optional
        Reduction mode applied after masking. Must be one of ``'mean'``,
        ``'sum'``, or ``'none'``. Default is ``'mean'``.
    """

    def __init__(self, elementwise_loss: nn.Module, reduction: str = "mean"):
        """
        Initialize MaskedLoss.

        Parameters
        ----------
        elementwise_loss : nn.Module
            Base loss function to be masked. Must accept ``(preds, target)``
            and return element-wise (unreduced) loss. Should be instantiated
            with ``reduction='none'``.
        reduction : str, optional
            Reduction mode applied after masking. Must be one of ``'mean'``,
            ``'sum'``, or ``'none'``. Default is ``'mean'``.
        """
        super().__init__(reduction=reduction)
        self.elementwise_loss = elementwise_loss

    def forward(self, preds: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the masked loss.

        Parameters
        ----------
        preds : torch.Tensor
            Predictions of shape (B, T, C, *D).
        target : torch.Tensor
            Target of shape (B, T, C, *D).
        mask : torch.Tensor
            Mask with 1 for valid and 0 for invalid pixels. Any shape that
            broadcasts against the element-wise loss is accepted, e.g.
            (B, T, C, *D), (B, T, 1, *D) or (B, 1, 1, *D).

        Returns
        -------
        loss : torch.Tensor
            For ``'mean'``: sum of the masked loss divided by the number of
            valid elements. For ``'sum'``: sum of the masked loss. For
            ``'none'``: the masked element-wise loss. If the mask contains no
            valid element, a scalar zero is returned for every mode.
        """
        # Element-wise loss; its shape is decided by the wrapped loss and may
        # be smaller than the mask (e.g. a loss already averaged over time).
        elementwise_loss = self.elementwise_loss(preds, target)

        # Apply mask (broadcasts in either direction).
        masked_loss = elementwise_loss * mask

        # Count valid elements in the *broadcast* result. Using the size of
        # the product (not of the unmasked loss) keeps the count correct when
        # the mask is larger than the loss; the previous ratio floored to 0 in
        # that case and the loss silently collapsed to zero.
        broadcast_factor = masked_loss.numel() // max(mask.numel(), 1)
        valid_pixels = mask.sum() * broadcast_factor

        if valid_pixels > 0:
            if self.reduction == "mean":
                return masked_loss.sum() / valid_pixels
            if self.reduction == "sum":
                return masked_loss.sum()
            return masked_loss  # 'none'

        # No valid element: return a zero that keeps the dtype, device and
        # autograd graph of the loss, so ``backward()`` still works during
        # training (a freshly created constant tensor has no grad_fn).
        return masked_loss.sum() * 0.0
132
+
133
+
134
class CRPS(LossWithReduction):
    r"""
    Continuous Ranked Probability Score (CRPS) loss with optional temporal
    consistency regularization.

    CRPS = E[|X - y|] - 0.5 * E[|X - X'|], where X, X' are independent
    samples from the forecast distribution and y is the observation.

    Parameters
    ----------
    temporal_lambda : float, optional
        Weight for the temporal consistency penalty. If ``0.0`` (default),
        the penalty is disabled. When enabled, adds a penalty for large
        differences between consecutive timesteps within each ensemble
        member, preventing pulsing artifacts.
    reduction : str, optional
        Reduction mode. Must be one of ``'mean'``, ``'sum'``, or ``'none'``.
        ``'mean'`` averages over batch and all non-ensemble dimensions.
        Default is ``'mean'``.

    Expected shapes
    ---------------
    preds : (B, T, M, \*D)
        Ensemble predictions with time T on dim=1, ensemble size M on dim=2.
    target : (B, T, C, \*D)
        Deterministic target / analysis with channel C on dim=2 (should be 1).
    """

    def __init__(self, temporal_lambda: float = 0.0, reduction: str = "mean"):
        """
        Initialize CRPS loss.

        Parameters
        ----------
        temporal_lambda : float, optional
            Weight for the temporal consistency penalty. Default is ``0.0``
            (disabled).
        reduction : str, optional
            Reduction mode. Must be one of ``'mean'``, ``'sum'``, or
            ``'none'``. Default is ``'mean'``.
        """
        super().__init__(reduction=reduction)
        self.temporal_lambda = temporal_lambda

    @staticmethod
    def _temporal_penalty(preds: torch.Tensor) -> torch.Tensor:
        """Return the mean absolute step-to-step change of ``preds``, shape ``(B, *D)``."""
        if preds.shape[1] < 2:
            # With a single timestep there is no consecutive pair; averaging
            # the empty difference tensor would give NaN, so the penalty is 0.
            return torch.zeros(
                (preds.shape[0], *preds.shape[3:]), dtype=preds.dtype, device=preds.device
            )
        step_change = preds[:, 1:] - preds[:, :-1]  # (B, T-1, M, *D)
        # Average over time and ensemble dimensions -> (B, *D)
        return step_change.abs().mean(dim=(1, 2))

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute CRPS loss.

        Parameters
        ----------
        preds : torch.Tensor
            Ensemble forecasts of shape ``(B, T, M, *D)``, where B is batch
            size, T is the number of timesteps, M is ensemble size, and
            ``*D`` are spatial dimensions.
        target : torch.Tensor
            Verifying observation / analysis of shape ``(B, T, C, *D)``,
            where C should be 1. Broadcasts against ``preds`` on dim=2.

        Returns
        -------
        loss : torch.Tensor
            Scalar if ``reduction='mean'`` or ``'sum'``, otherwise a tensor
            of shape ``(B, T, 1, *D)`` (or ``(B, 1, 1, *D)`` when
            ``temporal_lambda > 0``).
        """
        # First term: E[|X - y|], averaged over the ensemble -> (B, T, *D)
        term1 = torch.abs(preds - target).mean(dim=2)

        # Second term: 0.5 * E[|X - X'|] over all M*M member pairs.
        # (B, T, M, 1, *D) - (B, T, 1, M, *D) -> (B, T, M, M, *D)
        pairwise_diff = torch.abs(preds.unsqueeze(3) - preds.unsqueeze(2))
        term2 = 0.5 * pairwise_diff.mean(dim=(2, 3))  # (B, T, *D)

        crps = term1 - term2  # (B, T, *D)

        if self.temporal_lambda > 0:
            # The penalty is defined per pixel across the whole sequence, so
            # the score is first averaged over time: (B, *D).
            crps = crps.mean(dim=1)
            crps = crps + self.temporal_lambda * self._temporal_penalty(preds)
            # Restore singleton time and channel dims: (B, 1, 1, *D)
            crps = crps[:, None, None, ...]
        else:
            # Keep singleton channel dim for MaskedLoss compatibility: (B, T, 1, *D)
            crps = crps.unsqueeze(2)

        return self.apply_reduction(crps)
249
+
250
+
251
class afCRPS(LossWithReduction):
    r"""
    Almost fair CRPS (afCRPS) loss as in eq. (4) of Lang et al. (2024).

    Interpolates between the standard (energy-score style) CRPS and the
    fair CRPS via the fairness parameter ``alpha``.

    Parameters
    ----------
    alpha : float, optional
        Fairness parameter in ``(0, 1]``. ``alpha=1`` recovers the fair
        CRPS. Lang et al. (2024) recommend ``alpha=0.95``. Default is
        ``0.95``.
    temporal_lambda : float, optional
        Weight for the temporal consistency penalty. If ``0.0`` (default),
        the penalty is disabled. When enabled, adds a penalty for large
        differences between consecutive timesteps within each ensemble
        member.
    reduction : str, optional
        Reduction mode. Must be one of ``'mean'``, ``'sum'``, or ``'none'``.
        ``'mean'`` averages over batch and all non-ensemble dimensions.
        Default is ``'mean'``.

    Expected shapes
    ---------------
    preds : (B, T, M, \*D)
        Ensemble predictions with time T on dim=1, ensemble size M on dim=2.
    target : (B, T, C, \*D)
        Deterministic target / analysis with channel C on dim=2 (should be 1).
    """

    def __init__(self, alpha: float = 0.95, temporal_lambda: float = 0.0, reduction: str = "mean"):
        """
        Initialize afCRPS loss.

        Parameters
        ----------
        alpha : float, optional
            Fairness parameter in ``(0, 1]``. ``alpha=1`` recovers the fair
            CRPS. Default is ``0.95``.
        temporal_lambda : float, optional
            Weight for the temporal consistency penalty. Default is ``0.0``
            (disabled).
        reduction : str, optional
            Reduction mode. Must be one of ``'mean'``, ``'sum'``, or
            ``'none'``. Default is ``'mean'``.

        Raises
        ------
        ValueError
            If ``alpha`` is not in ``(0, 1]``.
        """
        super().__init__(reduction=reduction)
        if not (0.0 < alpha <= 1.0):
            raise ValueError("alpha must be in (0, 1].")
        self.alpha = alpha
        self.temporal_lambda = temporal_lambda

    @staticmethod
    def _temporal_penalty(preds: torch.Tensor) -> torch.Tensor:
        """Return the mean absolute step-to-step change of ``preds``, shape ``(B, *D)``."""
        if preds.shape[1] < 2:
            # With a single timestep there is no consecutive pair; averaging
            # the empty difference tensor would give NaN, so the penalty is 0.
            return torch.zeros(
                (preds.shape[0], *preds.shape[3:]), dtype=preds.dtype, device=preds.device
            )
        step_change = preds[:, 1:] - preds[:, :-1]  # (B, T-1, M, *D)
        # Average over time and ensemble dimensions -> (B, *D)
        return step_change.abs().mean(dim=(1, 2))

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute afCRPS over an ensemble.

        Parameters
        ----------
        preds : torch.Tensor
            (B, T, M, *D) ensemble forecasts.
        target : torch.Tensor
            (B, T, C, *D) verifying observation / analysis (C=1).

        Returns
        -------
        loss : torch.Tensor
            Scalar if ``reduction`` is ``'mean'`` or ``'sum'``, otherwise a
            tensor of shape ``(B, T, 1, *D)`` (or ``(B, 1, 1, *D)`` when
            ``temporal_lambda > 0``).

        Raises
        ------
        ValueError
            If ``preds`` has fewer than 3 dimensions, if the batch or time
            dimensions of ``preds`` and ``target`` differ, or if the ensemble
            size is smaller than 2.
        """
        if preds.dim() < 3:
            raise ValueError("preds must have at least 3 dimensions.")
        if target.shape[0] != preds.shape[0]:
            raise ValueError("batch dimension of preds and target must match.")
        if target.shape[1] != preds.shape[1]:
            raise ValueError("time dimension of preds and target must match.")

        # preds: (B, T, M, *D), target: (B, T, 1, *D)
        M = preds.shape[2]
        if M < 2:
            raise ValueError("Ensemble size M must be >= 2 for afCRPS.")

        eps = (1.0 - self.alpha) / float(M)

        # Eq. (4) sums, over all pairs j != k,
        #     |x_j - y| + |x_k - y| - (1 - eps) |x_j - x_k|
        # and divides by 2 M (M - 1). Each |x_j - y| appears 2 (M - 1) times,
        # and |x_j - x_j| = 0, so the double sum collapses to
        #     mean_j |x_j - y| - (1 - eps) * sum_{j,k} |x_j - x_k| / (2 M (M - 1)).
        # This avoids materialising the broadcast accuracy terms and the
        # diagonal mask as extra (B, T, M, M, *D) tensors.
        accuracy = (preds - target).abs().mean(dim=2)  # (B, T, *D)

        # |x_j - x_k|: (B, T, M, 1, *D) - (B, T, 1, M, *D) -> (B, T, M, M, *D)
        pairwise = (preds.unsqueeze(3) - preds.unsqueeze(2)).abs()
        spread = pairwise.sum(dim=(2, 3)) / (2.0 * M * (M - 1))  # (B, T, *D)

        afcrps = accuracy - (1.0 - eps) * spread  # (B, T, *D)

        if self.temporal_lambda > 0:
            # The penalty is defined per pixel across the whole sequence, so
            # the score is first averaged over time: (B, *D).
            afcrps = afcrps.mean(dim=1)
            afcrps = afcrps + self.temporal_lambda * self._temporal_penalty(preds)
            # Restore singleton time and channel dims: (B, 1, 1, *D)
            afcrps = afcrps[:, None, None, ...]
        else:
            # Keep singleton channel dim for MaskedLoss compatibility: (B, T, 1, *D)
            afcrps = afcrps.unsqueeze(2)

        return self.apply_reduction(afcrps)
388
+
389
+
390
# Registry of loss classes selectable by (lower-case) string name.
PIXEL_LOSSES = {
    "mse": nn.MSELoss,
    "mae": nn.L1Loss,
    "crps": CRPS,
    "afcrps": afCRPS,
}
391
+
392
+
393
def build_loss(
    loss_class: type | str | None = None,
    loss_params: dict[str, Any] | None = None,
    masked_loss: bool = False,
) -> nn.Module:
    """
    Build a loss function, optionally wrapped with masking.

    Resolves a loss class by name (from ``PIXEL_LOSSES``) or accepts a class
    directly, instantiates it with the given parameters, and optionally wraps
    it in a :class:`MaskedLoss`. A short description of the chosen
    configuration is printed to standard output.

    Parameters
    ----------
    loss_class : type or str or None, optional
        Loss class or its string name (case-insensitive). Accepted string
        names are the keys of ``PIXEL_LOSSES``: ``'mse'``, ``'mae'``,
        ``'crps'``, ``'afcrps'``. If ``None`` (default), ``nn.MSELoss`` is
        used.
    loss_params : dict of str to any, optional
        Keyword arguments forwarded to the loss class constructor. The dict
        is copied and never modified. If ``masked_loss`` is ``True``, the
        ``'reduction'`` key (if present) is extracted and passed to
        :class:`MaskedLoss` instead. Default is ``None``.
    masked_loss : bool, optional
        If ``True``, the loss is wrapped in :class:`MaskedLoss` and the
        inner loss is instantiated with ``reduction='none'``. Default is
        ``False``.

    Returns
    -------
    criterion : nn.Module
        Instantiated loss module, optionally wrapped in :class:`MaskedLoss`.

    Raises
    ------
    ValueError
        If ``loss_class`` is a string not found in ``PIXEL_LOSSES``.
    """
    if loss_class is None:
        loss_class = nn.MSELoss  # default
        print("No loss_class provided, using default MSELoss.")
    elif isinstance(loss_class, str):
        key = loss_class.lower()
        if key not in PIXEL_LOSSES:
            raise ValueError(f"Unknown loss class '{loss_class}'. Available: {list(PIXEL_LOSSES.keys())}")
        loss_class = PIXEL_LOSSES[key]

    # Work on a copy so that popping 'reduction' never mutates the caller's dict.
    params = loss_params.copy() if loss_params is not None else None

    if masked_loss:
        # The reduction is handled by MaskedLoss; the inner loss must stay
        # element-wise.
        if params is not None:
            reduction = params.pop("reduction", "mean")
            criterion = MaskedLoss(loss_class(reduction="none", **params), reduction=reduction)
            print(f"Using masked loss: {loss_class.__name__} with params {params} and reduction {reduction}")
        else:
            criterion = MaskedLoss(loss_class(reduction="none"), reduction="mean")
            print(f"Using masked loss: {loss_class.__name__} with default params and reduction 'mean'")
    elif params is not None:
        criterion = loss_class(**params)
        print(f"Using custom loss: {loss_class.__name__} with params {params}")
    else:
        criterion = loss_class()
        print(f"Using loss: {loss_class.__name__} with default params")

    return criterion
convgru_ensemble/model.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ResidualConvBlock(nn.Module):
    """
    Two-convolution residual block.

    The main path is ``conv1 -> ReLU -> conv2``; the input is added back
    through a skip connection and a final ReLU is applied. When the number of
    channels changes, the skip path goes through a 1x1 convolution so that
    both paths have matching channel counts.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int, optional
        Kernel size for both convolutions. Default is ``3``.
    padding : int, optional
        Padding for both convolutions. Default is ``1``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        """
        Create the two convolutions and, if needed, the 1x1 skip projection.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : int, optional
            Kernel size for both convolutions. Default is ``3``.
        padding : int, optional
            Padding for both convolutions. Default is ``1``.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        # ``None`` means the skip connection is the identity.
        needs_projection = in_channels != out_channels
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) if needs_projection else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the block on a batch of feature maps.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, C_in, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape ``(B, C_out, H, W)``.
        """
        shortcut = x if self.proj is None else self.proj(x)
        features = self.conv2(F.relu(self.conv1(x)))
        return F.relu(features + shortcut)
71
+
72
+
73
+ class ConvGRUCell(nn.Module):
74
+ """
75
+ Convolutional GRU cell operating on 2D spatial grids.
76
+
77
+ Implements a single-step GRU update where all linear projections are
78
+ replaced by 2D convolutions, preserving spatial structure.
79
+
80
+ Parameters
81
+ ----------
82
+ input_size : int
83
+ Number of channels in the input tensor.
84
+ hidden_size : int
85
+ Number of channels in the hidden state.
86
+ kernel_size : int, optional
87
+ Kernel size for the convolutional gates. Default is ``3``.
88
+ conv_layer : nn.Module, optional
89
+ Convolutional layer class to use. Default is ``nn.Conv2d``.
90
+ """
91
+
92
+ def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
93
+ """
94
+ Initialize ConvGRUCell.
95
+
96
+ Parameters
97
+ ----------
98
+ input_size : int
99
+ Number of channels in the input tensor.
100
+ hidden_size : int
101
+ Number of channels in the hidden state.
102
+ kernel_size : int, optional
103
+ Kernel size for the convolutional gates. Default is ``3``.
104
+ conv_layer : nn.Module, optional
105
+ Convolutional layer class to use. Default is ``nn.Conv2d``.
106
+ """
107
+ super().__init__()
108
+ padding = kernel_size // 2
109
+ self.input_size = input_size
110
+ self.hidden_size = hidden_size
111
+ # update and reset gates are combined for optimization
112
+ self.combined_gates = conv_layer(input_size + hidden_size, 2 * hidden_size, kernel_size, padding=padding)
113
+ self.out_gate = conv_layer(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
114
+
115
+ def forward(self, inpt: torch.Tensor | None = None, h_s: torch.Tensor | None = None) -> torch.Tensor:
116
+ """
117
+ Forward the ConvGRU cell for a single timestep.
118
+
119
+ If either input is ``None``, it is initialized to zeros based on the
120
+ shape of the other. If both are ``None``, a ``ValueError`` is raised.
121
+
122
+ Parameters
123
+ ----------
124
+ inpt : torch.Tensor or None, optional
125
+ Input tensor of shape ``(B, input_size, H, W)``. Default is
126
+ ``None``.
127
+ h_s : torch.Tensor or None, optional
128
+ Hidden state tensor of shape ``(B, hidden_size, H, W)``. Default
129
+ is ``None``.
130
+
131
+ Returns
132
+ -------
133
+ new_state : torch.Tensor
134
+ Updated hidden state of shape ``(B, hidden_size, H, W)``.
135
+
136
+ Raises
137
+ ------
138
+ ValueError
139
+ If both ``inpt`` and ``h_s`` are ``None``.
140
+ """
141
+ if h_s is None and inpt is None:
142
+ raise ValueError("Both input and state can't be None")
143
+ elif h_s is None:
144
+ h_s = torch.zeros(
145
+ inpt.size(0), self.hidden_size, inpt.size(2), inpt.size(3), dtype=inpt.dtype, device=inpt.device
146
+ )
147
+ elif inpt is None:
148
+ inpt = torch.zeros(
149
+ h_s.size(0), self.input_size, h_s.size(2), h_s.size(3), dtype=h_s.dtype, device=h_s.device
150
+ )
151
+
152
+ gamma, beta = torch.chunk(self.combined_gates(torch.cat([inpt, h_s], dim=1)), 2, dim=1)
153
+ update = torch.sigmoid(gamma)
154
+ reset = torch.sigmoid(beta)
155
+
156
+ out_inputs = torch.tanh(self.out_gate(torch.cat([inpt, h_s * reset], dim=1)))
157
+ new_state = h_s * (1 - update) + out_inputs * update
158
+
159
+ return new_state
160
+
161
+
162
+ class ConvGRU(nn.Module):
163
+ """
164
+ Convolutional GRU that unrolls a :class:`ConvGRUCell` over a sequence.
165
+
166
+ Parameters
167
+ ----------
168
+ input_size : int
169
+ Number of channels in the input tensor.
170
+ hidden_size : int
171
+ Number of channels in the hidden state.
172
+ kernel_size : int, optional
173
+ Kernel size for the convolutional gates. Default is ``3``.
174
+ conv_layer : nn.Module, optional
175
+ Convolutional layer class to use. Default is ``nn.Conv2d``.
176
+ """
177
+
178
+ def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
179
+ """
180
+ Initialize ConvGRU.
181
+
182
+ Parameters
183
+ ----------
184
+ input_size : int
185
+ Number of channels in the input tensor.
186
+ hidden_size : int
187
+ Number of channels in the hidden state.
188
+ kernel_size : int, optional
189
+ Kernel size for the convolutional gates. Default is ``3``.
190
+ conv_layer : nn.Module, optional
191
+ Convolutional layer class to use. Default is ``nn.Conv2d``.
192
+ """
193
+ super().__init__()
194
+ self.cell = ConvGRUCell(input_size, hidden_size, kernel_size, conv_layer)
195
+
196
+ def forward(self, x: torch.Tensor | None = None, h: torch.Tensor | None = None) -> torch.Tensor:
197
+ """
198
+ Unroll the ConvGRU cell over the sequence (time) dimension.
199
+
200
+ .. code-block:: text
201
+
202
+ x[:, 0] x[:, 1]
203
+ | |
204
+ v v
205
+ *------* *------*
206
+ h --> | Cell | --> h_0 --> | Cell | --> h_1 ...
207
+ *------* *------*
208
+
209
+ If either input is ``None``, it is initialized to zeros based on the
210
+ shape of the other. If both are ``None``, a ``ValueError`` is raised.
211
+
212
+ Parameters
213
+ ----------
214
+ x : torch.Tensor or None, optional
215
+ Input tensor of shape ``(B, T, input_size, H, W)``. Default is
216
+ ``None``.
217
+ h : torch.Tensor or None, optional
218
+ Initial hidden state of shape ``(B, hidden_size, H, W)``. Default
219
+ is ``None``.
220
+
221
+ Returns
222
+ -------
223
+ hidden_states : torch.Tensor
224
+ Stacked hidden states of shape ``(B, T, hidden_size, H, W)``,
225
+ i.e. ``[h_0, h_1, h_2, ...]``.
226
+ """
227
+ h_s = []
228
+ for i in range(x.size(1)):
229
+ h = self.cell(x[:, i], h)
230
+ h_s.append(h)
231
+ return torch.stack(h_s, dim=1)
232
+
233
+
234
class EncoderBlock(nn.Module):
    """
    Encoder stage: a :class:`ConvGRU` followed by 2x spatial downsampling.

    The ConvGRU keeps the channel count unchanged (hidden size equals input
    size); ``nn.PixelUnshuffle(2)`` then halves height and width while
    multiplying the channel count by four.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    kernel_size : int, optional
        Kernel size for the ConvGRU. Default is ``3``.
    conv_layer : nn.Module, optional
        Convolutional layer class to use. Default is ``nn.Conv2d``.
    """

    def __init__(self, input_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
        """
        Create the recurrent layer and the downsampling layer.

        Parameters
        ----------
        input_size : int
            Number of input channels.
        kernel_size : int, optional
            Kernel size for the ConvGRU. Default is ``3``.
        conv_layer : nn.Module, optional
            Convolutional layer class to use. Default is ``nn.Conv2d``.
        """
        super().__init__()
        self.convgru = ConvGRU(
            input_size=input_size,
            hidden_size=input_size,
            kernel_size=kernel_size,
            conv_layer=conv_layer,
        )
        self.down = nn.PixelUnshuffle(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the ConvGRU, then downsample its output.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Downsampled tensor of shape ``(B, T, C*4, H/2, W/2)``.
        """
        return self.down(self.convgru(x))
285
+
286
+
287
class Encoder(nn.Module):
    """
    Stack of :class:`EncoderBlock` stages.

    Every stage halves the spatial resolution and quadruples the channel
    count, so stage ``i`` receives ``input_channels * 4**i`` channels.

    .. code-block:: text

        X --> [ConvGRU] --> [PixelUnshuffle] --> [ConvGRU] --> [PixelUnshuffle] --> ...
        (b,t,c,h,w)         (b,t,c*4,h/2,w/2)                  (b,t,c*16,h/4,w/4)

    Parameters
    ----------
    input_channels : int, optional
        Number of input channels. Default is ``1``.
    num_blocks : int, optional
        Number of encoder blocks to stack. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to each :class:`EncoderBlock`.
    """

    def __init__(self, input_channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Build the encoder stages.

        Parameters
        ----------
        input_channels : int, optional
            Number of input channels. Default is ``1``.
        num_blocks : int, optional
            Number of encoder blocks to stack. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to each
            :class:`EncoderBlock`.
        """
        super().__init__()
        # Input width of each stage, e.g. [1, 4, 16, 64] with the defaults.
        self.channel_sizes = [input_channels * (4**level) for level in range(num_blocks)]
        stages = [EncoderBlock(width, **kwargs) for width in self.channel_sizes]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Run the input through every stage and collect each stage's output.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.

        Returns
        -------
        hidden_states : list of torch.Tensor
            Output of each block in order, with progressively reduced
            spatial dimensions:
            ``[(B, T, C*4, H/2, W/2), (B, T, C*16, H/4, W/4), ...]``.
        """
        stage_outputs = []
        features = x
        for stage in self.blocks:
            features = stage(features)
            stage_outputs.append(features)
        return stage_outputs
353
+
354
+
355
class DecoderBlock(nn.Module):
    """
    Decoder stage: a :class:`ConvGRU` followed by 2x spatial upsampling.

    ``nn.PixelShuffle(2)`` doubles height and width of the ConvGRU output
    while dividing its channel count by four.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    hidden_size : int
        Number of hidden channels for the ConvGRU.
    kernel_size : int, optional
        Kernel size for the ConvGRU. Default is ``3``.
    conv_layer : nn.Module, optional
        Convolutional layer class to use. Default is ``nn.Conv2d``.
    """

    def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
        """
        Create the recurrent layer and the upsampling layer.

        Parameters
        ----------
        input_size : int
            Number of input channels.
        hidden_size : int
            Number of hidden channels for the ConvGRU.
        kernel_size : int, optional
            Kernel size for the ConvGRU. Default is ``3``.
        conv_layer : nn.Module, optional
            Convolutional layer class to use. Default is ``nn.Conv2d``.
        """
        super().__init__()
        self.convgru = ConvGRU(
            input_size=input_size,
            hidden_size=hidden_size,
            kernel_size=kernel_size,
            conv_layer=conv_layer,
        )
        self.up = nn.PixelShuffle(2)

    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
        """
        Apply the ConvGRU starting from ``hidden_state``, then upsample.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        hidden_state : torch.Tensor
            Initial hidden state handed to the ConvGRU, of shape
            ``(B, hidden_size, H, W)``.

        Returns
        -------
        out : torch.Tensor
            Upsampled tensor of shape ``(B, T, hidden_size // 4, H*2, W*2)``.
        """
        return self.up(self.convgru(x, hidden_state))
413
+
414
+
415
class Decoder(nn.Module):
    """
    Stack of :class:`DecoderBlock` layers that progressively upsamples.

    Every block doubles the spatial resolution and quarters the channel
    count, so the channel width of each block is derived from the desired
    number of output channels.

    Parameters
    ----------
    output_channels : int, optional
        Number of output channels. Default is ``1``.
    num_blocks : int, optional
        Number of decoder blocks to stack. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to each :class:`DecoderBlock`.
    """

    def __init__(self, output_channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Create ``num_blocks`` decoder blocks, widest first.

        Parameters
        ----------
        output_channels : int, optional
            Number of output channels. Default is ``1``.
        num_blocks : int, optional
            Number of decoder blocks to stack. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to each
            :class:`DecoderBlock`.
        """
        super().__init__()
        # Widths shrink by 4x per block, e.g. [256, 64, 16, 4] for the defaults;
        # the final pixel-shuffle then leaves ``output_channels`` channels.
        self.channel_sizes = [output_channels * 4**power for power in range(num_blocks, 0, -1)]
        # Each block uses the same width for its input and its hidden state.
        stages = [DecoderBlock(width, width, **kwargs) for width in self.channel_sizes]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor, hidden_states: list[torch.Tensor]) -> torch.Tensor:
        """
        Pass ``x`` through every block, seeding each with its hidden state.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        hidden_states : list of torch.Tensor
            Hidden states from the encoder (in reverse order), one per block.

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape
            ``(B, T, output_channels, H * 2^num_blocks, W * 2^num_blocks)``.
        """
        out = x
        # strict=True: a mismatch between blocks and hidden states is an error.
        for stage, state in zip(self.blocks, hidden_states, strict=True):
            out = stage(out, state)
        return out
472
+
473
+
474
class EncoderDecoder(nn.Module):
    """
    Encoder-decoder model for spatio-temporal forecasting.

    The encoder turns the input sequence into multi-scale hidden states; the
    decoder unrolls them into a forecast sequence. Several ensemble members
    can be produced by feeding the decoder independent noise inputs.

    Parameters
    ----------
    channels : int, optional
        Number of input/output channels. Default is ``1``.
    num_blocks : int, optional
        Number of encoder and decoder blocks. Default is ``4``.
    **kwargs
        Additional keyword arguments forwarded to :class:`Encoder` and
        :class:`Decoder`.
    """

    def __init__(self, channels: int = 1, num_blocks: int = 4, **kwargs):
        """
        Build the encoder and the matching decoder.

        Parameters
        ----------
        channels : int, optional
            Number of input/output channels. Default is ``1``.
        num_blocks : int, optional
            Number of encoder and decoder blocks. Default is ``4``.
        **kwargs
            Additional keyword arguments forwarded to :class:`Encoder` and
            :class:`Decoder`.
        """
        super().__init__()
        self.encoder = Encoder(channels, num_blocks, **kwargs)
        self.decoder = Decoder(channels, num_blocks, **kwargs)

    def forward(self, x: torch.Tensor, steps: int, noisy_decoder: bool = False, ensemble_size: int = 1) -> torch.Tensor:
        """
        Encode ``x`` and decode a forecast of ``steps`` timesteps.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(B, T, C, H, W)``.
        steps : int
            Number of future timesteps to forecast.
        noisy_decoder : bool, optional
            If ``True``, feed random noise (instead of zeros) as input to the
            decoder. Default is ``False``.
        ensemble_size : int, optional
            Number of ensemble members to generate. When ``> 1``, the decoder
            is always run with noisy inputs. Default is ``1``.

        Returns
        -------
        preds : torch.Tensor
            Forecast tensor. Shape is ``(B, steps, C, H, W)`` when
            ``ensemble_size == 1``, or
            ``(B, steps, ensemble_size * C, H, W)`` when ``ensemble_size > 1``
            (for C=1, this is ``(B, steps, ensemble_size, H, W)``).
        """
        # One hidden-state sequence per encoder block.
        encoded = self.encoder(x)
        deepest = encoded[-1]

        # The decoder input mirrors the deepest encoder output, except that
        # its time axis is the requested forecast length.
        x_dec_shape = list(deepest.shape)
        x_dec_shape[1] = steps
        tensor_opts = {"dtype": deepest.dtype, "device": deepest.device}

        # Final hidden state of every encoder block, deepest block first.
        last_hidden_per_block = [states[:, -1] for states in reversed(encoded)]

        if ensemble_size <= 1:
            # Single forecast: noise only when explicitly requested.
            make_input = torch.randn if noisy_decoder else torch.zeros
            x_dec = make_input(x_dec_shape, **tensor_opts)
            return self.decoder(x_dec, last_hidden_per_block)

        # Ensemble: one decoder pass per member, each with fresh noise.
        members = []
        for _ in range(ensemble_size):
            noise = torch.randn(x_dec_shape, **tensor_opts)
            members.append(self.decoder(noise, last_hidden_per_block))
        # Concatenate members along the channel axis: (B, T, M * C, H, W).
        return torch.cat(members, dim=2)
convgru_ensemble/py.typed ADDED
File without changes
convgru_ensemble/serve.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI inference server for ConvGRU-Ensemble nowcasting model."""
2
+
3
+ import io
4
+ import os
5
+ import time
6
+ from contextlib import asynccontextmanager
7
+
8
+ import numpy as np
9
+ import xarray as xr
10
+ from fastapi import FastAPI, File, Query, UploadFile
11
+ from fastapi.responses import Response
12
+
13
+ _model = None
14
+
15
+
16
def _load_model():
    """
    Load the nowcasting model selected through environment variables.

    ``HF_REPO_ID`` takes precedence over ``MODEL_CHECKPOINT``; ``DEVICE``
    (default ``"cpu"``) is passed through to the loader.

    Raises
    ------
    RuntimeError
        If neither ``HF_REPO_ID`` nor ``MODEL_CHECKPOINT`` is set (an empty
        value counts as unset).
    """
    # Imported lazily so importing this module does not pull in the model code.
    from .lightning_model import RadarLightningModel

    target_device = os.environ.get("DEVICE", "cpu")

    repo_id = os.environ.get("HF_REPO_ID")
    if repo_id:
        return RadarLightningModel.from_pretrained(repo_id, device=target_device)

    checkpoint_path = os.environ.get("MODEL_CHECKPOINT")
    if checkpoint_path:
        return RadarLightningModel.from_checkpoint(checkpoint_path, device=target_device)

    raise RuntimeError("Set MODEL_CHECKPOINT or HF_REPO_ID environment variable.")
29
+
30
+
31
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook: load the model on startup, drop it on shutdown.

    The model is stored in the module-level ``_model`` so request handlers
    can reach it. The reference is cleared in a ``finally`` block so it is
    released even when an exception is thrown into the generator at the
    ``yield`` point.
    """
    global _model
    _model = _load_model()
    try:
        yield
    finally:
        _model = None
37
+
38
+
39
+ app = FastAPI(
40
+ title="ConvGRU-Ensemble Nowcasting API",
41
+ version="0.1.0",
42
+ description="Ensemble precipitation nowcasting from radar data",
43
+ lifespan=lifespan,
44
+ )
45
+
46
+
47
@app.get("/health")
async def health():
    """Report liveness and whether the model is currently loaded."""
    model_loaded = _model is not None
    return {"status": "ok", "model_loaded": model_loaded}
51
+
52
+
53
@app.get("/model/info")
async def model_info():
    """
    Return model metadata taken from the loaded model's hyper-parameters.

    Returns
    -------
    dict or JSONResponse
        Metadata dictionary on success. When no model is loaded, the same
        ``{"error": "Model not loaded"}`` body as before is returned, but with
        HTTP status 503 instead of 200 so clients can detect the failure from
        the status code.
    """
    if _model is None:
        # Local import: keeps this handler self-contained without touching
        # the module-level import block.
        from fastapi.responses import JSONResponse

        return JSONResponse(status_code=503, content={"error": "Model not loaded"})

    hp = _model.hparams
    return {
        "architecture": "ConvGRU-Ensemble EncoderDecoder",
        "input_channels": hp.input_channels,
        "num_blocks": hp.num_blocks,
        "forecast_steps": hp.forecast_steps,
        "ensemble_size": hp.ensemble_size,
        "noisy_decoder": hp.noisy_decoder,
        "loss_class": str(hp.loss_class),
        "device": str(_model.device),
    }
69
+
70
+
71
@app.post("/predict")
async def predict(
    file: UploadFile = File(..., description="NetCDF file with rain rate data (T, H, W)"),  # noqa: B008
    variable: str = Query("RR", description="Name of the rain rate variable"),  # noqa: B008
    forecast_steps: int = Query(12, ge=1, le=48, description="Number of future timesteps"),  # noqa: B008
    ensemble_size: int = Query(10, ge=1, le=50, description="Number of ensemble members"),  # noqa: B008
):
    """
    Run ensemble nowcasting inference on uploaded NetCDF data.

    Accepts a NetCDF file containing past radar rain rate observations and
    returns NetCDF predictions with ensemble forecasts.

    Returns
    -------
    Response
        * 200 with an ``application/x-netcdf`` body holding the variable
          ``precipitation_forecast``.
        * 422 when the upload cannot be opened, ``variable`` is missing, or
          the variable is not 3-D.
        * 503 when no model is loaded.
    """
    t0 = time.perf_counter()

    # Without a model the call below would fail with an AttributeError (500).
    if _model is None:
        return Response(content="Model not loaded", status_code=503)

    # Read uploaded NetCDF
    content = await file.read()
    try:
        ds = xr.open_dataset(io.BytesIO(content))
    except (OSError, ValueError) as exc:
        # NOTE(review): these are the exception types xarray is expected to
        # raise for unreadable or unrecognised input - confirm against the
        # installed backends.
        return Response(
            content=f"Could not read uploaded file as NetCDF: {exc}",
            status_code=422,
        )

    # The context manager closes the dataset on every exit path.
    with ds:
        if variable not in ds:
            available = list(ds.data_vars)
            return Response(
                content=f"Variable '{variable}' not found. Available: {available}",
                status_code=422,
            )
        # ``.values`` materialises the data, so it stays usable after close.
        data = ds[variable].values

    if data.ndim != 3:
        return Response(
            content=f"Expected 3D data (T, H, W), got shape {data.shape}",
            status_code=422,
        )

    past = data.astype(np.float32)

    # Run inference
    # NOTE(review): this is a synchronous call inside an async handler, so the
    # event loop is blocked for the duration of the forecast.
    preds = _model.predict(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size)

    elapsed = time.perf_counter() - t0

    # Build output NetCDF
    ds_out = xr.Dataset(
        {
            "precipitation_forecast": xr.DataArray(
                data=preds,
                dims=["ensemble_member", "forecast_step", "y", "x"],
                attrs={"units": "mm/h", "long_name": "Ensemble precipitation forecast"},
            ),
        },
        attrs={
            "model": "ConvGRU-Ensemble",
            "forecast_steps": forecast_steps,
            "ensemble_size": ensemble_size,
            "elapsed_seconds": f"{elapsed:.3f}",
        },
    )

    buf = io.BytesIO()
    ds_out.to_netcdf(buf)

    return Response(
        content=buf.getvalue(),
        media_type="application/x-netcdf",
        headers={
            "Content-Disposition": "attachment; filename=predictions.nc",
            "X-Elapsed-Seconds": f"{elapsed:.3f}",
        },
    )
convgru_ensemble/train.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ import os
3
+ import sys
4
+ from datetime import datetime
5
+
6
+ import fiddle as fdl
7
+ import torch
8
+ import yaml
9
+ from absl import app, flags
10
+ from fiddle import absl_flags, printing
11
+ from pytorch_lightning import Trainer, seed_everything
12
+ from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
13
+ from pytorch_lightning.loggers import TensorBoardLogger
14
+
15
+ from .datamodule import RadarDataModule
16
+ from .lightning_model import RadarLightningModel
17
+ from .losses import PIXEL_LOSSES
18
+
19
+ seed_everything(42, workers=True)
20
+
21
+ FLAGS = flags.FLAGS
22
+ flags.DEFINE_bool("print_config", False, "Print configuration and exit.")
23
+ flags.DEFINE_string("export_yaml", None, "Export configuration to YAML file and exit.")
24
+
25
+
26
def experiment() -> fdl.Config:
    """
    Build the default experiment configuration.

    The returned Fiddle config can be overridden from the command line with
    ``--config config:experiment --config set:path.to.value=X``.

    Returns
    -------
    cfg : fdl.Config
        Nested Fiddle configuration containing datamodule, model, trainer,
        callbacks, and logger settings.
    """
    # --- Data ---------------------------------------------------------------
    datamodule_cfg = fdl.Config(
        RadarDataModule,
        zarr_path="./data/italian-radar-dpc-sri.zarr",
        csv_path="./importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
        steps=18,
        train_ratio=0.90,
        val_ratio=0.05,
        return_mask=True,
        deterministic=False,
        augment=True,
        # DataLoader params
        batch_size=16,
        num_workers=8,
        pin_memory=True,
        multiprocessing_context="fork",
    )

    # --- Lightning model ------------------------------------------------------
    model_cfg = fdl.Config(
        RadarLightningModel,
        input_channels=1,
        forecast_steps=12,
        num_blocks=5,
        ensemble_size=2,
        noisy_decoder=False,
        loss_class="crps",
        loss_params={"temporal_lambda": 0.01},
        masked_loss=True,
        optimizer_class=torch.optim.Adam,
        optimizer_params={"lr": 1e-4, "fused": True},
        lr_scheduler_class=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lr_scheduler_params={"mode": "min", "factor": 0.5, "patience": 10},
    )

    # --- Trainer --------------------------------------------------------------
    trainer_cfg = fdl.Config(
        Trainer,
        accelerator="auto",
        # gradient_clip_val=1.0,
        max_epochs=1,
    )

    # --- Callbacks ------------------------------------------------------------
    # ``dirpath`` / ``filename`` of both checkpoint callbacks are left as None
    # here and filled in by train() once the run name is known.
    callbacks_cfg = fdl.Config(dict)
    callbacks_cfg.checkpoint_val = fdl.Config(
        ModelCheckpoint,
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        dirpath=None,
        filename=None,  # Set dynamically: best-val-{ckpt_name}
        save_last=False,
    )
    callbacks_cfg.checkpoint_train = fdl.Config(
        ModelCheckpoint,
        monitor="train_loss_epoch",
        save_top_k=1,
        mode="min",
        dirpath=None,
        filename=None,  # Set dynamically: best-train-{ckpt_name}
        save_last=False,
    )
    callbacks_cfg.early_stopping = fdl.Config(
        EarlyStopping,
        monitor="val_loss",
        patience=100,
        mode="min",
    )
    callbacks_cfg.lr_monitor = fdl.Config(
        LearningRateMonitor,
        logging_interval="step",
        log_momentum=False,
        log_weight_decay=False,
    )

    # --- Loggers --------------------------------------------------------------
    loggers_cfg = fdl.Config(dict)
    loggers_cfg.tensorboard = fdl.Config(
        TensorBoardLogger,
        save_dir="logs",
        name=None,  # Set dynamically in train()
        version=None,  # Set dynamically in train()
    )

    # --- Assemble (key order is the order used when exporting to YAML) --------
    cfg = fdl.Config(dict)
    # Path of a checkpoint to load the model from (None: fresh model).
    cfg.checkpoint_path = None
    # Passed to torch.set_float32_matmul_precision by train() when not None.
    cfg.float32_matmul_precision = None
    # Wrap the model with torch.compile when True.
    cfg.compile_model = False
    cfg.datamodule = datamodule_cfg
    cfg.model = model_cfg
    cfg.trainer = trainer_cfg
    cfg.callbacks = callbacks_cfg
    cfg.loggers = loggers_cfg
    return cfg
136
+
137
+
138
+ _CONFIG = absl_flags.DEFINE_fiddle_config(
139
+ "config",
140
+ default_module=sys.modules[__name__],
141
+ help_string="Experiment configuration.",
142
+ )
143
+
144
+
145
def train(cfg: fdl.Config) -> None:
    """
    Train and test a model described by a Fiddle configuration.

    Fills in run-dependent names (TensorBoard run name/version, checkpoint
    directory and filenames), builds callbacks, loggers, datamodule, model
    and trainer, writes the configuration to ``config.yaml`` in the run
    folder, then calls ``trainer.fit`` followed by ``trainer.test``.

    Parameters
    ----------
    cfg : fdl.Config
        Fiddle configuration as returned by :func:`experiment`. It is
        modified in place (names, paths, trainer logger/callbacks).
    """
    # Optional float32 matmul precision override.
    matmul_precision = cfg.float32_matmul_precision
    if matmul_precision is not None:
        torch.set_float32_matmul_precision(matmul_precision)

    # --- Pieces used to name the run and its checkpoints ---------------------
    future_steps = cfg.model.forecast_steps
    past_steps = cfg.datamodule.steps - future_steps

    # The loss may be unset, a class, or a string key into PIXEL_LOSSES.
    loss_class = cfg.model.loss_class
    if loss_class is None:
        loss_name = "MSELoss"
    elif isinstance(loss_class, type):
        loss_name = loss_class.__name__
    elif loss_class.lower() in PIXEL_LOSSES:
        loss_name = PIXEL_LOSSES[loss_class.lower()].__name__
    else:
        loss_name = str(loss_class)

    optimizer_params = cfg.model.optimizer_params
    has_lr = optimizer_params is not None and "lr" in optimizer_params
    lr = optimizer_params["lr"] if has_lr else "default"

    noise_str: str = "_noise" if cfg.model.noisy_decoder else ""
    ckpt_base_name: str = f"{past_steps}past-{future_steps}fut{noise_str}_bs{cfg.datamodule.batch_size}_lr{lr}"

    # --- TensorBoard name / version (the checkpoint folder depends on them) --
    tb_cfg = cfg.loggers.tensorboard
    if tb_cfg.name is None:
        tb_cfg.name = f"{loss_name}_{past_steps}past-{future_steps}fut{noise_str}"

    jobid = os.getenv("SLURM_JOB_ID", None)
    # A user-supplied version is kept as a suffix of the generated one.
    tb_version = f"_{tb_cfg.version}" if tb_cfg.version is not None else ""
    # Prefix with the SLURM job id when available, otherwise a timestamp.
    run_prefix = f"job{jobid}" if jobid is not None else datetime.now().strftime("%Y%m%d_%H%M%S")
    tb_cfg.version = f"{run_prefix}_{ckpt_base_name}{tb_version}"

    # --- Checkpoints live inside the TensorBoard run folder ------------------
    tb_log_dir = f"{tb_cfg.save_dir}/{tb_cfg.name}/{tb_cfg.version}"
    ckpt_dir = f"{tb_log_dir}/checkpoints"

    checkpoint_defaults = (
        (cfg.callbacks.checkpoint_val, "best-val-" + ckpt_base_name + "_ep{epoch:03d}_loss{val_loss:.4f}"),
        (cfg.callbacks.checkpoint_train, "best-train-" + ckpt_base_name + "_ep{epoch:03d}_loss{train_loss_epoch:.4f}"),
    )
    # Only fill values the user did not override.
    for ckpt_cfg, default_filename in checkpoint_defaults:
        if ckpt_cfg.dirpath is None:
            ckpt_cfg.dirpath = ckpt_dir
        if ckpt_cfg.filename is None:
            ckpt_cfg.filename = default_filename

    # --- Build callbacks and loggers, then hand them to the trainer config ---
    callbacks_dict = fdl.build(cfg.callbacks)
    loggers_dict = fdl.build(cfg.loggers)
    # From here on cfg.trainer holds already-built logger/callback objects.
    cfg.trainer.logger = list(loggers_dict.values())
    cfg.trainer.callbacks = list(callbacks_dict.values())

    print(printing.as_str_flattened(cfg))

    # --- Persist the configuration next to the logs --------------------------
    os.makedirs(tb_log_dir, exist_ok=True)
    config_path = f"{tb_log_dir}/config.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config_to_dict(cfg), f, default_flow_style=False, sort_keys=False)
    print(f"Config saved to {config_path}")

    # --- Build all components -------------------------------------------------
    built = fdl.build(cfg)
    datamodule: RadarDataModule = built["datamodule"]

    checkpoint_path = cfg.checkpoint_path
    if checkpoint_path is None:
        model = built["model"]
    else:
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        model = RadarLightningModel.load_from_checkpoint(checkpoint_path, strict=True, weights_only=False)
    trainer: Trainer = built["trainer"]

    datamodule.setup()
    print(
        f"Train: {len(datamodule.train_dataset)}, Val: {len(datamodule.val_dataset)}, Test: {len(datamodule.test_dataset)}"
    )

    if cfg.compile_model:
        print("Compiling model with torch.compile()...")
        model = torch.compile(model, dynamic=True)

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    # The trainer uses the callback instances built above, so their
    # best_model_path reflects this run.
    print(f"Best val: {callbacks_dict['checkpoint_val'].best_model_path}")
    print(f"Best train: {callbacks_dict['checkpoint_train'].best_model_path}")
258
+
259
+
260
def config_to_dict(cfg: fdl.Config) -> dict:
    """
    Convert a Fiddle config into a nested plain dictionary.

    Arguments whose value is itself an ``fdl.Config`` are converted
    recursively; every other value is copied through unchanged.

    Parameters
    ----------
    cfg : fdl.Config
        Fiddle configuration object.

    Returns
    -------
    result : dict
        Dictionary of the config's arguments, in ``fdl.ordered_arguments``
        order.
    """
    return {
        name: config_to_dict(argument) if isinstance(argument, fdl.Config) else argument
        for name, argument in fdl.ordered_arguments(cfg).items()
    }
278
+
279
+
280
def main(argv: list[str]) -> None:
    """
    Entry point for the training script.

    ``--print_config`` prints the resolved configuration and exits;
    ``--export_yaml`` writes it to the given YAML file and exits; otherwise
    the configuration is passed to :func:`train`.

    Parameters
    ----------
    argv : list of str
        Command-line arguments (unused, consumed by ``absl``).
    """
    del argv
    cfg = _CONFIG.value

    if FLAGS.print_config:
        print(printing.as_str_flattened(cfg))
        return

    export_path = FLAGS.export_yaml
    if export_path:
        exported = config_to_dict(cfg)
        with open(export_path, "w") as f:
            yaml.dump(exported, f, default_flow_style=False, sort_keys=False)
        print(f"Config exported to {export_path}")
        return

    train(cfg)
304
+
305
+
306
+ # Example command to run training with custom configuration overrides.
307
+ # uv run python -m convgru_ensemble.train \
308
+ # --config config:experiment \
309
+ # --config set:callbacks.checkpoint_val.save_top_k=3 \
310
+ # --config set:model.num_blocks=5 \
311
+ # --config set:model.forecast_steps=12 \
312
+ # --config set:datamodule.steps=18 \
313
+ # --config set:datamodule.num_workers=32 \
314
+ # --config set:datamodule.batch_size=32
315
+ if __name__ == "__main__":
316
+ app.run(main)
convgru_ensemble/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def rainrate_to_reflectivity(rainrate: np.ndarray) -> np.ndarray:
    """
    Convert rain rate to radar reflectivity (Marshall-Palmer, Z = 200 R^1.6).

    The linear reflectivity is converted to dBZ and clipped to [0, 60].
    Rain rates below ~0.037 mm/h therefore map to 0 dBZ.

    Parameters
    ----------
    rainrate : np.ndarray
        Rain rate in mm/h. Can be any shape.

    Returns
    -------
    reflectivity : np.ndarray
        Reflectivity in dBZ, clipped to [0, 60]. Same shape as input.
    """
    # Tiny offset so that a rain rate of exactly 0 does not hit log10(0).
    epsilon = 1e-16
    z_linear = 200 * rainrate**1.6 + epsilon
    dbz = 10 * np.log10(z_linear)
    return dbz.clip(0, 60)
24
+
25
+
26
def normalize_reflectivity(reflectivity: np.ndarray) -> np.ndarray:
    """
    Linearly map reflectivity from [0, 60] dBZ onto [-1, 1].

    Parameters
    ----------
    reflectivity : np.ndarray
        Reflectivity in dBZ, expected in [0, 60]. Can be any shape.

    Returns
    -------
    normalized : np.ndarray
        Normalized reflectivity in [-1, 1]. Same shape as input.
    """
    # 30 dBZ is the midpoint of the range and maps to 0.
    scaled = reflectivity / 30.0
    return scaled - 1.0
41
+
42
+
43
def denormalize_reflectivity(normalized: np.ndarray) -> np.ndarray:
    """
    Linearly map normalized values from [-1, 1] back onto [0, 60] dBZ.

    Parameters
    ----------
    normalized : np.ndarray
        Normalized reflectivity in [-1, 1]. Can be any shape.

    Returns
    -------
    reflectivity : np.ndarray
        Reflectivity in dBZ, in [0, 60]. Same shape as input.
    """
    shifted = normalized + 1.0
    return shifted * 30.0
58
+
59
+
60
def reflectivity_to_rainrate(reflectivity: np.ndarray) -> np.ndarray:
    """
    Convert reflectivity in dBZ to rain rate (inverse Marshall-Palmer).

    Applies R = (Z_linear / 200)^(1/1.6) where Z_linear = 10^(dBZ/10).
    Note that 0 dBZ maps to about 0.036 mm/h, not to 0.

    Parameters
    ----------
    reflectivity : np.ndarray
        Reflectivity in dBZ. Can be any shape.

    Returns
    -------
    rainrate : np.ndarray
        Rain rate in mm/h. Same shape as input.
    """
    # dBZ -> linear reflectivity factor Z
    z_linear = 10 ** (reflectivity / 10.0)
    # Invert Z = 200 * R^1.6
    ratio = z_linear / 200.0
    return ratio ** (1.0 / 1.6)
81
+
82
+
83
def rainrate_to_normalized(rainrate: np.ndarray) -> np.ndarray:
    """
    Convert rain rate directly to normalized reflectivity.

    Chains :func:`rainrate_to_reflectivity` and
    :func:`normalize_reflectivity`.

    Parameters
    ----------
    rainrate : np.ndarray
        Rain rate in mm/h. Can be any shape.

    Returns
    -------
    normalized : np.ndarray
        Normalized reflectivity in [-1, 1]. Same shape as input.
    """
    return normalize_reflectivity(rainrate_to_reflectivity(rainrate))
102
+
103
+
104
def normalized_to_rainrate(normalized: np.ndarray) -> np.ndarray:
    """
    Convert normalized reflectivity back to rain rate.

    Chains :func:`denormalize_reflectivity` and
    :func:`reflectivity_to_rainrate`.

    Parameters
    ----------
    normalized : np.ndarray
        Normalized reflectivity in [-1, 1]. Can be any shape.

    Returns
    -------
    rainrate : np.ndarray
        Rain rate in mm/h. Same shape as input.
    """
    return reflectivity_to_rainrate(denormalize_reflectivity(normalized))
docker-compose.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ api:
3
+ build: .
4
+ ports:
5
+ - "8000:8000"
6
+ volumes:
7
+ - ./checkpoints:/app/checkpoints
8
+ environment:
9
+ - MODEL_CHECKPOINT=/app/checkpoints/model.ckpt
10
+ - DEVICE=cpu
11
+ # Alternative: download from HuggingFace Hub
12
+ # - HF_REPO_ID=it4lia/irene
examples/sample_data.nc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a94b2f0c576f8e80cac470b1994bf8faed3eae11fb6fee6feba21e1a52dc0e7
3
+ size 11232217
importance_sampler/filter_nan.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+ from functools import partial
6
+ from multiprocessing import Pool
7
+ from queue import Queue
8
+ from threading import Thread
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import xarray as xr
13
+ from tqdm import tqdm
14
+
15
+ START = time.time()
16
+
17
+
18
+ # === Functions ===
19
def dim_nan_count(mask, dim, delta, dim_len):
    """
    Sum ``mask`` over a sliding window of length ``delta`` along one axis.

    A zero-padded cumulative sum is differenced at an offset of ``delta``,
    so each window sum costs one subtraction.

    Parameters
    ----------
    mask : np.ndarray
        3-D binary array where 1 indicates NaN and 0 indicates valid (or a
        3-D array of partial counts from a previous call).
    dim : int
        Axis along which to compute the rolling NaN count (0, 1, or 2).
    delta : int
        Window size (number of pixels) along ``dim``.
    dim_len : int
        Length of ``dim`` in the input array.

    Returns
    -------
    nan_counts : np.ndarray
        Integer array whose entry ``s`` along ``dim`` is the sum of
        ``mask[s : s + delta]``. It has ``dim_len - delta`` entries along
        ``dim``.

    Notes
    -----
    NOTE(review): ``dim_len - delta + 1`` windows fit in the axis, so the
    window starting at ``dim_len - delta`` is not returned - confirm this is
    intended by the callers.
    """
    axes = range(3)

    def along_dim(selection):
        # Index tuple applying ``selection`` on ``dim`` and ":" elsewhere.
        return tuple(selection if axis == dim else slice(None) for axis in axes)

    running = np.cumsum(mask, axis=dim, dtype=np.int32)

    # Prepend one zero along ``dim`` so that running[k] = sum(mask[:k]).
    pad_spec = [(int(axis == dim), 0) for axis in axes]
    running = np.pad(running, pad_width=pad_spec, mode="constant", constant_values=0)

    window_end = running[along_dim(slice(delta, dim_len))]
    window_start = running[along_dim(slice(dim_len - delta))]
    return window_end - window_start
54
+
55
+
56
def dc_nan_count(chunk, deltas, dim_lenghts):
    """
    Count the NaN values contained in every 3-D datacube of a chunk.

    The NaN mask is reduced with :func:`dim_nan_count` along time, then X,
    then Y, so the result holds the total NaN count of each datacube.

    Parameters
    ----------
    chunk : np.ndarray
        3-D array of shape ``(T, X, Y)`` - a time chunk of the full Zarr
        dataset.
    deltas : tuple of int
        Datacube dimensions ``(Dt, w, h)`` along time, X, and Y.
    dim_lenghts : tuple of int
        Shape of ``chunk``, i.e. ``(T, X, Y)``.

    Returns
    -------
    nans_cube_chunk : np.ndarray
        Integer array where ``nans_cube_chunk[it, ix, iy]`` is the number of
        NaN values in the datacube ``chunk[it:it+Dt, ix:ix+w, iy:iy+h]``.
    """
    # 1 where the value is NaN, 0 elsewhere.
    counts = np.isnan(chunk).astype(np.int16)

    # Reducing one axis leaves the other axes at their original length, so
    # the original chunk shape is the right ``dim_len`` for every pass.
    for axis in range(3):
        counts = dim_nan_count(counts, dim=axis, delta=deltas[axis], dim_len=dim_lenghts[axis])

    return counts
92
+
93
+
94
def process_chunk(time_range, t_start_idx, data, N_nan, deltas, steps, valid_starts_gap, dc_nan_count):
    """
    Process a single time chunk and return valid datacube indices.

    Slices the chunk out of ``data``, counts NaN values per datacube,
    filters by the maximum allowed NaN threshold, by the allowed time-start
    indices and by the sampling stride, and returns the surviving
    ``(t, x, y)`` indices.

    Parameters
    ----------
    time_range : array-like of int
        Two-element sequence ``(start_t, end_t)`` defining the chunk
        boundaries (relative to ``t_start_idx``).
    t_start_idx : int
        Absolute time index corresponding to the dataset start date.
    data : array-like
        3-D ``(T, X, Y)`` array supporting slicing, e.g. the Zarr-backed
        ``'RR'`` variable.
    N_nan : int
        Maximum number of NaN values allowed per datacube.
    deltas : tuple of int
        Datacube dimensions ``(Dt, w, h)``.
    steps : tuple of int
        Stride ``(step_T, step_X, step_Y)`` for subsampling valid indices.
        The time stride is applied to the absolute time index.
    valid_starts_gap : np.ndarray
        Array of valid time-start indices (relative to ``t_start_idx``) that
        have no temporal gaps.
    dc_nan_count : callable
        Function to count NaN values per datacube (see :func:`dc_nan_count`).

    Returns
    -------
    idx_t : np.ndarray
        Absolute time indices of valid datacubes.
    idx_x : np.ndarray
        X indices of valid datacubes.
    idx_y : np.ndarray
        Y indices of valid datacubes.

    Raises
    ------
    Exception
        Any error raised while processing is logged to stderr and re-raised.
        Re-raising (rather than calling ``sys.exit``) lets a multiprocessing
        pool propagate the failure to the parent instead of losing the
        worker that was running the task.
    """
    # Unpacked outside the ``try`` so ``start_t`` is always bound when the
    # error message below is formatted.
    start_t, end_t = time_range

    try:
        # Chunk from Zarr (T, X, Y)
        chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :]
        dim_lenghts = chunk.shape

        # Number of NaNs in each datacube of the chunk; intermediates are
        # deleted eagerly to keep peak memory down.
        nans_cube_chunk = dc_nan_count(chunk, deltas, dim_lenghts)
        del chunk

        valid_mask = nans_cube_chunk <= N_nan
        del nans_cube_chunk

        # These indices are relative to the chunk
        idx_t_rel, idx_x, idx_y = np.where(valid_mask)
        del valid_mask

        idx_t_rel = idx_t_rel.astype(np.int32)
        idx_x = idx_x.astype(np.int32)
        idx_y = idx_y.astype(np.int32)

        # Chunk-relative -> relative to t_start_idx
        idx_t = idx_t_rel + start_t

        # Keep only time indices in valid_starts_gap
        time_mask = np.isin(idx_t, valid_starts_gap)
        idx_t = idx_t[time_mask] + t_start_idx  # also convert to absolute index
        idx_x = idx_x[time_mask]
        idx_y = idx_y[time_mask]

        # Filter datacube indices according to steps
        stride_mask = (idx_t % steps[0] == 0) & (idx_x % steps[1] == 0) & (idx_y % steps[2] == 0)
        return idx_t[stride_mask], idx_x[stride_mask], idx_y[stride_mask]

    except Exception as e:
        print(f"Error processing chunk starting at t={start_t}: {e}", file=sys.stderr)
        raise
176
+
177
+
178
+ def file_writer(output_queue, filename, batch_size=1000):
179
+ """
180
+ Dedicated writer thread that flushes results to a CSV file in batches.
181
+
182
+ Reads ``(idx_t, idx_x, idx_y)`` tuples from the queue and writes them
183
+ as rows to the output file. Stops when a ``None`` sentinel is received.
184
+
185
+ Parameters
186
+ ----------
187
+ output_queue : queue.Queue
188
+ Thread-safe queue providing ``(idx_t, idx_x, idx_y)`` tuples.
189
+ filename : str
190
+ Path to the output CSV file.
191
+ batch_size : int, optional
192
+ Number of rows to buffer before flushing to disk. Default is
193
+ ``1000``.
194
+ """
195
+ with open(filename, "w") as f:
196
+ f.write("t,x,y\n")
197
+ batch = []
198
+
199
+ while True:
200
+ item = output_queue.get()
201
+
202
+ if item is None: # Sentinel value to stop
203
+ # Write remaining batch
204
+ for t, x, y in batch:
205
+ f.write(f"{t},{x},{y}\n")
206
+ break
207
+
208
+ batch.extend(zip(*item, strict=True))
209
+
210
+ if len(batch) >= batch_size:
211
+ for t, x, y in batch:
212
+ f.write(f"{t},{x},{y}\n")
213
+ f.flush()
214
+ batch = []
215
+
216
+ print(f"Results saved to {filename}")
217
+
218
+
219
# === Parse Arguments ===
# Command-line interface: dataset path, optional date range, datacube size
# (Dt, w, h), sampling strides, worker count and NaN threshold.
parser = argparse.ArgumentParser(description="Process valid datacubes from Zarr dataset")
parser.add_argument("zarr_path", help="Path to the Zarr dataset")
parser.add_argument("--start_date", default=None, type=str, help="Start date (YYYY-MM-DD)")
parser.add_argument("--end_date", default=None, type=str, help="End date (YYYY-MM-DD)")
parser.add_argument("--Dt", type=int, default=24, help="Time depth")
parser.add_argument("--w", type=int, default=256, help="Spatial width")
parser.add_argument("--h", type=int, default=256, help="Spatial height")
parser.add_argument("--step_T", type=int, default=3, help="Time step")
parser.add_argument("--step_X", type=int, default=16, help="X step")
parser.add_argument("--step_Y", type=int, default=16, help="Y step")
parser.add_argument("--n_workers", type=int, default=8, help="Number of parallel workers")
parser.add_argument("--n_nan", type=int, default=10000, help="Maximum NaNs per datacube")
args = parser.parse_args()


# === PARAMETERS ===
Dt = args.Dt  # time depth
w = args.w  # x width
h = args.h  # y height
step_T = args.step_T
step_X = args.step_X
step_Y = args.step_Y
N_nan = args.n_nan  # maximum number of nans in each datacube

n_workers = args.n_workers
# Number of start times handled per chunk (three datacube depths).
time_chunk_size = 3 * Dt


# === Dataset Loading ===
print(f"Opening Zarr dataset: {args.zarr_path}")
try:
    # Previous approach, kept for reference: zarr.open(args.zarr_path, mode='r')
    zg = xr.open_zarr(args.zarr_path, decode_times=True)
    RR_full = zg["RR"]
    time_array_full = pd.to_datetime(zg["time"][:])

    print(f"Full dataset shape: T={RR_full.shape[0]}, X={RR_full.shape[1]}, Y={RR_full.shape[2]}")
    print(f"Full dataset time range: {time_array_full[0]} to {time_array_full[-1]}")
except Exception as e:
    print(f"Error loading Zarr dataset: {e}")
    sys.exit(1)

# Filter the dates: fall back to the dataset's first/last timestamp when a
# bound is not given on the command line.
start_date = pd.to_datetime(args.start_date) if args.start_date else time_array_full[0]
end_date = pd.to_datetime(args.end_date) if args.end_date else time_array_full[-1]

# Find indices corresponding to date range (both bounds inclusive)
mask = (time_array_full >= start_date) & (time_array_full <= end_date)
valid_indices = np.where(mask)[0]

if len(valid_indices) == 0:
    print(f"No data found between {start_date} and {end_date}")
    sys.exit(1)

# Absolute index of the first selected time step and one past the last.
t_start_idx = valid_indices[0]
t_end_idx = valid_indices[-1] + 1

# Sizes of the selected sub-dataset (the data itself is sliced lazily per chunk)
size_T = t_end_idx - t_start_idx
size_X = RR_full.shape[1]
size_Y = RR_full.shape[2]
time_array = time_array_full[t_start_idx:t_end_idx]

print(f"Filtered dataset shape: T={size_T}, X={size_X}, Y={size_Y}")
print(f"Filtered dataset time range: {time_array[0]} to {time_array[-1]}")

# Number of possible datacube origins along each axis
# (only max_t is used further down in this script section)
max_x = size_X - w + 1
max_y = size_Y - h + 1
max_t = size_T - Dt + 1

# === Time Continuity ===
print("Checking time continuity...")
try:
    # A "gap" is any pair of consecutive timestamps not exactly 5 minutes apart.
    expected_step = pd.Timedelta("00:05:00")
    time_diffs = time_array[1:] - time_array[:-1]
    gaps = (time_diffs != expected_step).astype(int)

    # Check continuity for windows of size Dt: summing Dt - 1 consecutive
    # gap flags gives the number of gaps inside each window.
    window_sum = np.convolve(gaps, np.ones(Dt - 1, dtype=int), mode="valid")

    # Valid starting times are the windows containing no gap
    # (the step_T stride is not applied here).
    valid_starts_gap = np.where(window_sum == 0)[0]
    print(f"Found {len(valid_starts_gap)} valid time starts without gaps")


except Exception as e:
    print(f"Error in time continuity check: {e}", file=sys.stderr)
    sys.exit(1)


# === Chunked NaN Processing ===
# Memory per chunk (x 4 bytes per value - assumes float32 samples, TODO confirm)
estimated_chunk_memory_gb = (time_chunk_size * size_X * size_Y * 4) / (1024**3)
print(f"Estimated memory per chunk: {estimated_chunk_memory_gb:.2f} GB")
print(f"Estimated total memory: {(estimated_chunk_memory_gb * n_workers):.2f} GB")

# Process time in chunks; each chunk is extended by Dt - 1 steps (clipped to
# size_T) so that datacubes starting near its end are still fully covered.
t_starts = np.arange(0, max_t, time_chunk_size)
t_ends = np.minimum(t_starts + time_chunk_size + Dt - 1, size_T)
t_pairs = np.stack((t_starts, t_ends), axis=1)

# Create partial function with fixed parameters, leaving only the
# (start, end) time range to be supplied per task
process_chunk_partial = partial(
    process_chunk,
    t_start_idx=t_start_idx,
    data=RR_full,
    N_nan=N_nan,
    deltas=(Dt, w, h),
    steps=(step_T, step_X, step_Y),
    valid_starts_gap=valid_starts_gap,
    dc_nan_count=dc_nan_count,
)
333
+
334
# Check if file exists
# When --start_date/--end_date are omitted, use the resolved date range in the
# file name; otherwise the name would contain the literal "None" and tools that
# parse the dates back out of the file name would get garbage.
start_label = args.start_date or start_date.strftime("%Y-%m-%d")
end_label = args.end_date or end_date.strftime("%Y-%m-%d")
output_file = f"valid_datacubes_{start_label}-{end_label}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{N_nan}.csv"
if os.path.exists(output_file):
    response = input(f"File {output_file} already exists. Overwrite? (y/n): ")
    if response.lower() != "y":
        print("Exiting without overwriting.")
        sys.exit(0)
    else:
        print(f"Overwriting {output_file}...")

# Start writer thread
output_queue = Queue(maxsize=100)
writer_thread = Thread(target=file_writer, args=(output_queue, output_file, 1000))
writer_thread.daemon = False
writer_thread.start()

# Process chunks in parallel. The ``finally`` clause guarantees the writer
# thread receives its stop sentinel even when a chunk fails or the run is
# interrupted; without it the non-daemon thread would block on the queue
# forever and keep the process alive.
try:
    with Pool(n_workers) as pool:
        results = pool.imap(process_chunk_partial, t_pairs, chunksize=1)
        for hits in tqdm(results, total=len(t_starts), desc="Processing time chunks"):
            output_queue.put(hits)
finally:
    # Signal writer thread to stop
    output_queue.put(None)
    writer_thread.join()

print(f"Done in {time.time() - START}s.")
sys.exit(0)
importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv ADDED
The diff for this file is too large to render. See raw diff
 
importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000_metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "csv": "valid_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
3
+ "zarr": "/leonardo_scratch/fast/AI4W_forecast_2/italian-radar-dpc-sri.zarr",
4
+ "file": "sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
5
+ "start_date": "2021-01-01",
6
+ "end_date": "2025-12-11",
7
+ "Dt": 24,
8
+ "w": 256,
9
+ "h": 256,
10
+ "step_T": 3,
11
+ "step_X": 16,
12
+ "step_Y": 16,
13
+ "N_nan": 10000,
14
+ "N_rand": 1,
15
+ "n_workers": 112,
16
+ "qmin": 0.0001,
17
+ "m": 0.1,
18
+ "s": 1,
19
+ "seed": null,
20
+ "timestamp": "2025-12-13 21:52:41"
21
+ }
importance_sampler/sample_valid_datacubes.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ import time
6
+ from functools import partial
7
+ from multiprocessing import Pool
8
+ from queue import Queue
9
+ from threading import Thread
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import xarray as xr
14
+ from tqdm import tqdm
15
+
16
START = time.time()
SEED = None  # for reproducibility

# === Parse Arguments ===
parser = argparse.ArgumentParser(description="Importance sampler of the valid datacubes (after the nan filtering)")
parser.add_argument("zarr_path", help="Path to the Zarr dataset")
parser.add_argument(
    "csv_path", help="Path to the CSV with the valid datacube coordinates (created by the nan filtering)"
)
parser.add_argument("--q_min", type=float, default=1e-4, help="Minimum selection probability (default 1e-4)")
parser.add_argument("--s", type=float, default=1, help="Denominator in the exponential")
parser.add_argument("--m", type=float, default=0.1, help="Factor weighting the mean rescaled rain rate (dafault 0.1)")
parser.add_argument("--n_workers", type=int, default=8, help="Number of parallel workers (default 8)")
parser.add_argument("--n_rand", type=int, default=1, help="Number of random sampling of each datacube (dafaut 1)")
args = parser.parse_args()

# === PARAMETERS ===
s = args.s
qmin = args.q_min
m = args.m

n_workers = args.n_workers  # number of parallel workers
N_rand = args.n_rand  # number of random numbers per region
chunksize = 16000  # number of CSV rows read (and dispatched to the pool) at a time

# Parameters from CSV filename, expected to look like
#   valid_datacubes_<start>-<end>_<Dt>x<w>x<h>_<step_T>x<step_X>x<step_Y>_<N_nan>.csv
# Only the base name is parsed: splitting the full path would shift every
# field as soon as a directory name contains an underscore.
csv_stem = os.path.splitext(os.path.basename(args.csv_path))[0]
name_arr = csv_stem.split("_")
dates = name_arr[2]
date_parts = dates.split("-")
start_date = "-".join(date_parts[0:3])
end_date = "-".join(date_parts[3:])
Dt, w, h = name_arr[3].split("x")
step_T, step_X, step_Y = name_arr[4].split("x")
N_nan = name_arr[5]

# Casting
Dt, w, h = int(Dt), int(w), int(h)
step_T, step_X, step_Y = int(step_T), int(step_X), int(step_Y)
N_nan = int(N_nan)
54
+
55
+
56
+ # === FUNCTIONS ===
57
def acceptance_probability(data):
    """
    Return the importance-sampling acceptance probability of a datacube.

    Computes ``q_min + m * nanmean(data)`` and caps the result at ``1.0``.
    ``q_min`` and ``m`` are read from the module-level ``qmin`` and ``m``.

    Parameters
    ----------
    data : np.ndarray
        Rescaled rain rate data for a single datacube; NaN entries are
        ignored by the mean.

    Returns
    -------
    q : float
        Acceptance probability, never larger than ``1.0``.
    """
    mean_intensity = np.nanmean(data)
    uncapped = qmin + m * mean_intensity
    return min(1.0, uncapped)
75
+
76
+
77
def process_datacube(coord, RR, N_rand, seed, acceptance_probability):
    """
    Process a single space-time region for importance sampling.

    Loads the datacube starting at ``coord`` (its extent comes from the
    module-level ``Dt``, ``w`` and ``h``), rescales the rain rate with
    ``1 - exp(-RR / s)``, computes an acceptance probability, and performs
    ``N_rand`` random acceptance trials.

    Parameters
    ----------
    coord : array-like of int
        Three-element sequence ``(it, ix, iy)`` specifying the datacube
        origin.
    RR : xr.DataArray
        Rain rate data array from the Zarr dataset.
    N_rand : int
        Number of random acceptance trials per datacube.
    seed : int or None
        Random seed for reproducibility. If ``None``, non-deterministic.
        Otherwise the generator is seeded with ``(seed, it, ix, iy)``, so a
        run is reproducible while each datacube still gets its own draws.
    acceptance_probability : callable
        Function that takes a data array and returns a probability in
        ``[0, 1]``.

    Returns
    -------
    hits : list of tuple of int
        List of accepted ``(it, ix, iy)`` tuples (may contain duplicates
        if accepted multiple times). Empty if processing failed; the error
        is reported on stderr.
    """
    # Bound up front so the error message below can always be formatted,
    # even when unpacking ``coord`` is what failed.
    it = ix = iy = None
    try:
        it, ix, iy = coord
        time_slice = slice(it, it + Dt)
        x_slice = slice(ix, ix + w)
        y_slice = slice(iy, iy + h)

        # Load data from Zarr
        data = RR[time_slice, x_slice, y_slice]
        data = 1 - np.exp(-data / s)

        # Calculate acceptance probability
        q = acceptance_probability(data)

        # Seeding every datacube with the bare ``seed`` would hand all of them
        # the same random numbers, turning the sampling into a fixed threshold
        # on q. Mixing the coordinates into the seed keeps runs reproducible
        # while making the draws differ between datacubes.
        if seed is None:
            rng = np.random.default_rng()
        else:
            rng = np.random.default_rng([int(seed), int(it), int(ix), int(iy)])
        random_numbers = rng.random(N_rand)
        accepted_count = int(np.sum(random_numbers <= q))

        # Return accepted hits
        return [(it, ix, iy)] * accepted_count
    except Exception as e:
        print(f"Error processing region ({it}, {ix}, {iy}): {e}", file=sys.stderr)
        return []
130
+
131
+
132
def file_writer(output_queue, filename, batch_size=1000):
    """
    Drain a queue of accepted datacubes into a CSV file, writing in batches.

    Each queue item is a list of ``(t, x, y)`` tuples; every tuple becomes
    one CSV row. A ``None`` item ends the loop, after which any buffered
    rows are written.

    Parameters
    ----------
    output_queue : queue.Queue
        Thread-safe queue providing lists of ``(t, x, y)`` tuples.
    filename : str
        Path to the output CSV file (overwritten).
    batch_size : int, optional
        Number of buffered rows that triggers a write and flush. Default
        is ``1000``.
    """

    def _emit(handle, rows):
        # One CSV line per buffered (t, x, y) row.
        for t, x, y in rows:
            handle.write(f"{t},{x},{y}\n")

    with open(filename, "w") as f:
        f.write("t,x,y\n")
        pending = []

        # ``None`` is the sentinel that tells the writer to stop.
        while (item := output_queue.get()) is not None:
            pending.extend(item)

            if len(pending) < batch_size:
                continue

            _emit(f, pending)
            f.flush()
            pending = []

        # Write whatever is left below the batch threshold.
        _emit(f, pending)

    print(f"Results saved to {filename}")
171
+
172
+
173
# === Dataset Loading ===
print(f"Opening Zarr dataset: {args.zarr_path}")
try:
    zg = xr.open_zarr(args.zarr_path, mode="r")
    RR = zg["RR"]
except Exception as e:
    print(f"Error loading Zarr dataset: {e}")
    sys.exit(1)

# Check if file exists
output_file = f"sampled_datacubes_{start_date}-{end_date}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{N_nan}.csv"
if os.path.exists(output_file):
    response = input(f"File {output_file} already exists. Overwrite? (y/n): ")
    if response.lower() != "y":
        print("Exiting without overwriting.")
        sys.exit(0)
    else:
        print(f"Overwriting {output_file}...")

# Start writer thread
output_queue = Queue(maxsize=100)
writer_thread = Thread(target=file_writer, args=(output_queue, output_file, 1000))
writer_thread.daemon = False
writer_thread.start()

# The writer thread is non-daemon and blocks on the queue until it receives
# the ``None`` sentinel, so everything from here on runs under ``try/finally``:
# whatever fails (metadata dump, CSV parsing, a worker error, Ctrl-C), the
# sentinel is still sent and the process can exit instead of hanging.
pbar = None
try:
    # save metadata
    metadata = {
        "csv": args.csv_path,
        "zarr": args.zarr_path,
        "file": output_file,
        "start_date": start_date,
        "end_date": end_date,
        "Dt": Dt,
        "w": w,
        "h": h,
        "step_T": step_T,
        "step_X": step_X,
        "step_Y": step_Y,
        "N_nan": N_nan,
        "N_rand": N_rand,
        "n_workers": n_workers,
        "qmin": qmin,
        "m": m,
        "s": s,
        "seed": SEED,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    metadata_filename = output_file.replace(".csv", "_metadata.json")
    with open(metadata_filename, "w") as f:
        json.dump(metadata, f, indent=2)
    print(f"Saved run metadata to {metadata_filename}")

    # === IMPORTANCE SAMPLING ===
    # Create partial function with fixed parameters
    process_datacube_partial = partial(
        process_datacube, RR=RR, N_rand=N_rand, seed=SEED, acceptance_probability=acceptance_probability
    )

    pool_chunksize = max(1, chunksize // n_workers)

    with Pool(n_workers) as pool:
        pbar = tqdm(desc="Processing CSV chunks")

        # Loading the CSV by chunks
        for chunk in pd.read_csv(
            args.csv_path,
            usecols=["t", "x", "y"],
            dtype={"t": "int32", "x": "int32", "y": "int32"},
            engine="c",
            chunksize=chunksize,
        ):
            for hits in pool.imap(process_datacube_partial, chunk.values, chunksize=pool_chunksize):
                # Empty results (rejected or failed datacubes) are not queued.
                if hits:
                    output_queue.put(hits)
                pbar.update(1)  # one tick per processed datacube
finally:
    if pbar is not None:
        pbar.close()

    # Signal writer thread to stop
    output_queue.put(None)
    writer_thread.join()

print(f"Done in {time.time() - START}s.")
sys.exit(0)
notebooks/test_pretrained_model.ipynb ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a5d812be",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Test pretrained model\n",
9
+ "This notebook tests the pretrained model on a single datacube taken from the radar dataset (https://arcodatahub.com/datasets/datasets/italian-radar-dpc-sri.zarr)."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "id": "19c9a668",
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": "import matplotlib.animation as animation\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pysteps.visualization.precipfields as pysteps_plot\nimport xarray as xr\nfrom IPython.display import HTML\n\nfrom convgru_ensemble import RadarLightningModel"
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "id": "3bd5ff98",
23
+ "metadata": {},
24
+ "source": "# Load radar data\nWe first load a sample of the Italian radar dataset provided in the `examples/` folder."
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "008fb6ae",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": "radar = xr.open_dataarray('../examples/sample_data.nc')\nradar"
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "92f738bc",
37
+ "metadata": {},
38
+ "source": [
39
+ "This contains 18 sequences of radar images covering the whole of Italy, from the 28th to the 29th of October 2024. This was one of the most intense precipitation events over Italy in 2024."
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "fcb5529a",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# Create figure\n",
50
+ "fig, ax = plt.subplots(figsize=(4, 4.5))\n",
51
+ "\n",
52
+ "def update(frame):\n",
53
+ " ax.clear()\n",
54
+ " data = radar.isel(time=frame)\n",
55
+ " pysteps_plot.plot_precip_field(data.values, ax=ax, colorbar=False)\n",
56
+ " ax.set_title(f'Precipitation - {data.time.values}')\n",
57
+ " return ax,\n",
58
+ "\n",
59
+ "# Create animation\n",
60
+ "ani = animation.FuncAnimation(fig, update, frames=len(radar.time),\n",
61
+ " interval=500, blit=False, repeat=True)\n",
62
+ "\n",
63
+ "# Display in notebook\n",
64
+ "display(HTML(ani.to_jshtml()))\n",
65
+ "plt.close()"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "e806b20a",
71
+ "metadata": {},
72
+ "source": [
73
+ "# Initialize the model \n",
74
+ "Initialize the model and load the weights from the checkpoint. You can change the number of future steps (forecast steps) and the ensemble size (ensemble_size). The other hyperparameters are fixed."
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "id": "c8ce3ff4",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": "# Set model's parameters\nforecast_steps = 12\nensemble_size = 10\n\n# Initialize the model and load the checkpoint\n# Option A: Load from local checkpoint\nmodel = RadarLightningModel.from_checkpoint(checkpoint_path=\"../checkpoints/ConvGRU-CRPS_6past_12fut.ckpt\")\n\n# Option B: Load from HuggingFace Hub\n# model = RadarLightningModel.from_pretrained(\"it4lia/irene\")"
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "6f6fc1f9",
88
+ "metadata": {},
89
+ "source": [
90
+ "# Run the inference\n",
91
+ "We can run the inference and plot the forecast"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "d6d340b2",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": "# Past and future steps\npast_steps = 7\nforecast_steps = 12\npast, future = radar[:past_steps], radar[past_steps:past_steps+forecast_steps]\n\n# Predict the future rainrate\npred = model.predict(past, forecast_steps, ensemble_size=2)"
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "f0705829",
105
+ "metadata": {},
106
+ "source": [
107
+ "### Plot the forecast"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "0b0646f9",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "# Ensemble mean\n",
118
+ "ensemble_mean = np.nanmean(pred, axis=0)\n",
119
+ "\n",
120
+ "# Initialize plots\n",
121
+ "fig, axs = plt.subplots(1, 4, figsize=(16,4.5))\n",
122
+ "row_labels = ['Ground Truth', 'Ensemble Mean', 'Member 1', 'Member 2']\n",
123
+ "data_sources = [future, ensemble_mean, pred[0], pred[1]]\n",
124
+ "\n",
125
+ "# Plot initial frame\n",
126
+ "for i, (ax, label, data) in enumerate(zip(axs, row_labels, data_sources)):\n",
127
+ " plot_precip_field(data[0], ax=ax, units='mm/h', colorscale='pysteps')\n",
128
+ " ax.set_title(label, fontsize=14)\n",
129
+ "\n",
130
+ "plt.tight_layout()\n",
131
+ "\n",
132
+ "# Animation function\n",
133
+ "def update(frame):\n",
134
+ " for i, (ax, data) in enumerate(zip(axs, data_sources)):\n",
135
+ " ax.clear()\n",
136
+ " plot_precip_field(data[frame], ax=ax, units='mm/h', colorscale='pysteps', colorbar=False)\n",
137
+ " ax.set_title(f'{row_labels[i]} - Step {frame}', fontsize=14)\n",
138
+ " return axs\n",
139
+ "\n",
140
+ "# Create animation\n",
141
+ "anim = FuncAnimation(fig, update, frames=forecast_steps, interval=500, blit=False)\n",
142
+ "\n",
143
+ "# Display\n",
144
+ "HTML(anim.to_jshtml())\n",
145
+ "\n",
146
+ "display(HTML(anim.to_jshtml()))\n",
147
+ "plt.close()"
148
+ ]
149
+ }
150
+ ],
151
+ "metadata": {
152
+ "kernelspec": {
153
+ "display_name": ".venv",
154
+ "language": "python",
155
+ "name": "python3"
156
+ },
157
+ "language_info": {
158
+ "codemirror_mode": {
159
+ "name": "ipython",
160
+ "version": 3
161
+ },
162
+ "file_extension": ".py",
163
+ "mimetype": "text/x-python",
164
+ "name": "python",
165
+ "nbconvert_exporter": "python",
166
+ "pygments_lexer": "ipython3",
167
+ "version": "3.13.12"
168
+ }
169
+ },
170
+ "nbformat": 4,
171
+ "nbformat_minor": 5
172
+ }
pyproject.toml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "ConvGRU-Ensemble"
7
+ version = "0.1.0"
8
+ description = "Ensemble precipitation nowcasting using Convolutional GRU networks"
9
+ readme = "README.md"
10
+ requires-python = ">=3.13"
11
+ dependencies = [
12
+ "absl-py>=2.3.1",
13
+ "einops>=0.8.1",
14
+ "etils>=1.13.0",
15
+ "fiddle>=0.3.0",
16
+ "fire>=0.7.1",
17
+ "importlib-resources>=6.5.2",
18
+ "jupyterlab>=4.5.5",
19
+ "matplotlib>=3.10.8",
20
+ "numpy>=2.3.5",
21
+ "pandas>=2.3.3",
22
+ "pysteps>=1.19.0",
23
+ "pytorch-lightning>=2.6.0",
24
+ "pyyaml>=6.0.3",
25
+ "tqdm>=4.67.1",
26
+ "tensorboard>=2.20.0",
27
+ "torch>=2.9.1",
28
+ "torchvision>=0.24.1",
29
+ "xarray>=2025.12.0",
30
+ "zarr>=3.1.5",
31
+ "huggingface_hub>=0.27.0",
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ serve = [
36
+ "fastapi>=0.115.0",
37
+ "uvicorn>=0.34.0",
38
+ "python-multipart>=0.0.18",
39
+ ]
40
+
41
+ [project.scripts]
42
+ convgru-ensemble = "convgru_ensemble.cli:main"
43
+
44
+ [dependency-groups]
45
+ dev = [
46
+ "pytest>=8.3.5",
47
+ "pre-commit>=4.0.0",
48
+ "ruff>=0.9.10",
49
+ ]
50
+
51
+ [tool.setuptools]
52
+ packages = ["convgru_ensemble"]
53
+
54
+ [tool.ruff]
55
+ target-version = "py313"
56
+ line-length = 120
57
+ extend-exclude = ["notebooks/", "importance_sampler/"]
58
+
59
+ [tool.ruff.lint]
60
+ select = [
61
+ "E", # pycodestyle errors
62
+ "W", # pycodestyle warnings
63
+ "F", # pyflakes
64
+ "I", # isort
65
+ "UP", # pyupgrade
66
+ "B", # flake8-bugbear
67
+ "SIM", # flake8-simplify
68
+ "NPY", # numpy-specific rules
69
+ ]
70
+ ignore = [
71
+ "E501", # line too long (handled by formatter)
72
+ "SIM108", # ternary operator (readability preference)
73
+ ]
74
+
75
+ [tool.ruff.lint.isort]
76
+ known-first-party = ["convgru_ensemble"]
scripts/upload_to_hub.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Upload a trained ConvGRU-Ensemble model to HuggingFace Hub."""
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+
8
def main():
    """Parse the command line and push a checkpoint to the HuggingFace Hub.

    When ``--model-card`` is not given, ``MODEL_CARD.md`` one directory above
    this script is used if it exists; otherwise no model card path is passed.
    The URL returned by ``push_to_hub`` is printed.
    """
    cli = argparse.ArgumentParser(description="Upload model to HuggingFace Hub")
    cli.add_argument("--checkpoint", required=True, help="Path to .ckpt checkpoint file")
    cli.add_argument("--repo-id", required=True, help="HuggingFace repo ID (e.g., it4lia/irene)")
    cli.add_argument("--model-card", default=None, help="Path to model card markdown file")
    cli.add_argument("--private", action="store_true", help="Create a private repository")
    options = cli.parse_args()

    # Fall back to the repository-level model card when none was requested.
    card_path = options.model_card
    if card_path is None:
        fallback_card = Path(__file__).parent.parent / "MODEL_CARD.md"
        card_path = str(fallback_card) if fallback_card.exists() else None

    # Imported here rather than at module level, after argument parsing.
    from convgru_ensemble.hub import push_to_hub

    upload_kwargs = {
        "checkpoint_path": options.checkpoint,
        "repo_id": options.repo_id,
        "model_card_path": card_path,
        "private": options.private,
    }
    url = push_to_hub(**upload_kwargs)
    print(f"Model uploaded: {url}")
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
tests/conftest.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from convgru_ensemble.lightning_model import RadarLightningModel
5
+
6
+
7
@pytest.fixture
def model_small():
    """Provide a small ``RadarLightningModel`` (``num_blocks=2``) for fast tests."""
    config = {
        "input_channels": 1,
        "num_blocks": 2,
        "forecast_steps": 2,
        "ensemble_size": 2,
        "noisy_decoder": False,
    }
    return RadarLightningModel(**config)
17
+
18
+
19
@pytest.fixture
def sample_rain_rate():
    """Provide seeded synthetic rain-rate data of shape (T, H, W) = (4, 16, 16)."""
    generator = np.random.default_rng(42)
    unit_field = generator.random((4, 16, 16), dtype=np.float32)
    # Scale the unit-interval samples to 0-10 mm/h.
    return unit_field * 10.0
tests/test_hub.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ from convgru_ensemble.hub import from_pretrained
4
+
5
+
6
@patch("convgru_ensemble.hub.hf_hub_download")
@patch("convgru_ensemble.hub.RadarLightningModel", create=True)
def test_from_pretrained_calls_hf_hub_download(mock_model_cls, mock_download):
    """Check that ``from_pretrained`` fetches the checkpoint through ``hf_hub_download``."""
    # Stacked ``patch`` decorators inject their mocks bottom-up: the first
    # argument belongs to the ``RadarLightningModel`` patch, the second to the
    # ``hf_hub_download`` patch.
    mock_download.return_value = "/tmp/cached/model.ckpt"

    # Patch the import inside the function
    # (presumably ``from_pretrained`` imports the model class lazily - TODO confirm)
    with patch("convgru_ensemble.lightning_model.RadarLightningModel") as mock_cls:
        mock_cls.from_checkpoint.return_value = MagicMock()
        from_pretrained("it4lia/irene", filename="model.ckpt", device="cpu")

    # Exactly one download, with the requested repository and file name.
    mock_download.assert_called_once_with(repo_id="it4lia/irene", filename="model.ckpt")
tests/test_inference.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import xarray as xr
3
+
4
+ from convgru_ensemble.lightning_model import RadarLightningModel
5
+
6
+
7
def test_inference_on_sample_data():
    """End-to-end inference on the sample NetCDF with a freshly initialized model."""
    from pathlib import Path

    # Resolve the sample relative to this file so the test does not depend on
    # the directory pytest is launched from.
    sample_path = Path(__file__).resolve().parent.parent / "examples" / "sample_data.nc"

    # The context manager closes the NetCDF handle once the values are loaded.
    with xr.open_dataset(sample_path) as ds:
        rain = ds["RR"].values  # (54, 1400, 1200)

    # Use 6 past frames, full spatial extent
    past = rain[:6].astype(np.float32)
    _, H, W = past.shape

    model = RadarLightningModel(
        input_channels=1,
        num_blocks=3,
        forecast_steps=4,
        ensemble_size=1,
        noisy_decoder=False,
    )

    preds = model.predict(past, forecast_steps=4, ensemble_size=1)

    # One ensemble member, four forecast steps, same spatial extent as the input.
    assert preds.shape == (1, 4, H, W)
    assert np.isfinite(preds).all()
    assert preds.dtype == np.float64 or preds.dtype == np.float32
tests/test_lightning_model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from convgru_ensemble.lightning_model import RadarLightningModel
5
+
6
+
7
def test_predict_handles_unpadded_inputs():
    """``predict`` on an 8x8 input returns finite values at the same resolution."""
    net = RadarLightningModel(
        input_channels=1,
        num_blocks=1,
        forecast_steps=2,
        ensemble_size=1,
        noisy_decoder=False,
    )
    height = width = 8
    frames = np.zeros((4, height, width), dtype=np.float32)

    forecast = net.predict(frames, forecast_steps=2, ensemble_size=1)

    # (ensemble member, forecast step, H, W)
    assert forecast.shape == (1, 2, height, width)
    assert np.isfinite(forecast).all()
21
+
22
+
23
def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
    """``from_checkpoint`` forwards to ``load_from_checkpoint`` with a CPU device and ``strict=True``."""
    seen = {}

    def fake_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
        # Record every argument the real loader would have received.
        seen.update(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            weights_only=weights_only,
        )
        return "loaded-model"

    monkeypatch.setattr(RadarLightningModel, "load_from_checkpoint", classmethod(fake_loader))

    result = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")

    assert result == "loaded-model"
    assert seen["checkpoint_path"] == "/tmp/model.ckpt"
    device = seen["map_location"]
    assert isinstance(device, torch.device)
    assert device.type == "cpu"
    assert seen["strict"] is True
tests/test_losses.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from convgru_ensemble.losses import CRPS, MaskedLoss, build_loss
4
+
5
+
6
def test_crps_reduces_to_scalar():
    """``CRPS`` with mean reduction yields a non-negative 0-d tensor."""
    criterion = CRPS(reduction="mean")
    ensemble = torch.randn(2, 4, 5, 8, 8)  # (B, T, M, H, W)
    truth = torch.randn(2, 4, 1, 8, 8)  # (B, T, 1, H, W)
    value = criterion(ensemble, truth)
    assert value.dim() == 0  # scalar
    assert value.item() >= 0
13
+
14
+
15
def test_masked_loss_ignores_masked_pixels():
    """With an all-zero mask, ``MaskedLoss`` returns exactly zero."""
    criterion = MaskedLoss(torch.nn.MSELoss(reduction="none"), reduction="mean")
    ones = torch.ones(1, 2, 1, 4, 4)
    zeros = torch.zeros(1, 2, 1, 4, 4)
    # Every pixel is masked out, so nothing may contribute to the loss.
    empty_mask = torch.zeros(1, 1, 1, 4, 4)
    value = criterion(ones, zeros, empty_mask)
    assert value.item() == 0.0
24
+
25
+
26
def test_build_loss_by_name():
    """``build_loss("crps", ...)`` without masking returns a ``CRPS`` instance."""
    built = build_loss("crps", loss_params=None, masked_loss=False)
    assert isinstance(built, CRPS)
29
+
30
+
31
def test_build_loss_masked():
    """``build_loss`` with ``masked_loss=True`` returns a ``MaskedLoss`` wrapper."""
    built = build_loss("mse", loss_params=None, masked_loss=True)
    assert isinstance(built, MaskedLoss)
tests/test_model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from convgru_ensemble.model import EncoderDecoder
4
+
5
+
6
def test_forward_single_output_shape():
    """A one-member forecast has shape (batch, steps, 1, H, W)."""
    batch, steps, size = 2, 3, 16
    net = EncoderDecoder(channels=1, num_blocks=2)
    frames = torch.randn(batch, 4, 1, size, size)
    forecast = net(frames, steps=steps, noisy_decoder=False, ensemble_size=1)
    assert forecast.shape == (batch, steps, 1, size, size)
11
+
12
+
13
def test_forward_ensemble_output_shape():
    """Asking for five members puts the ensemble size on the third axis."""
    batch, steps, members, size = 2, 3, 5, 16
    net = EncoderDecoder(channels=1, num_blocks=2)
    frames = torch.randn(batch, 4, 1, size, size)
    forecast = net(frames, steps=steps, noisy_decoder=False, ensemble_size=members)
    assert forecast.shape == (batch, steps, members, size, size)
18
+
19
+
20
def test_forward_different_num_blocks():
    """A three-block model on 32x32 input still returns the expected shape."""
    steps, size = 2, 32
    net = EncoderDecoder(channels=1, num_blocks=3)
    frames = torch.randn(1, 4, 1, size, size)
    forecast = net(frames, steps=steps, ensemble_size=1)
    assert forecast.shape == (1, steps, 1, size, size)
25
+
26
+
27
def test_noisy_decoder_produces_different_outputs():
    """Two noisy-decoder passes over the same input should not match."""
    net = EncoderDecoder(channels=1, num_blocks=2)
    net.eval()
    frames = torch.randn(1, 4, 1, 16, 16)
    # Run the identical call twice, in order.
    first, second = [
        net(frames, steps=2, noisy_decoder=True, ensemble_size=1)
        for _ in range(2)
    ]
    # Identical results are possible in principle but extremely unlikely.
    assert not torch.allclose(first, second)
tests/test_serve.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import xarray as xr
7
+
8
+ from convgru_ensemble.lightning_model import RadarLightningModel
9
+
10
+
11
@pytest.fixture
def mock_model():
    """Build a ``MagicMock`` standing in for a loaded ``RadarLightningModel``.

    The mock carries fixed hyperparameters, reports ``"cpu"`` as its device,
    and its ``predict`` returns an all-zero float32 array of shape
    (10, 12, 8, 8).
    """
    hparams = MagicMock(
        input_channels=1,
        num_blocks=2,
        forecast_steps=12,
        ensemble_size=2,
        noisy_decoder=False,
        loss_class="crps",
    )
    fake = MagicMock(spec=RadarLightningModel)
    fake.hparams = hparams
    fake.device = "cpu"
    fake.predict.return_value = np.zeros((10, 12, 8, 8), dtype=np.float32)
    return fake
24
+
25
+
26
@pytest.fixture
def client(mock_model):
    """Yield a FastAPI ``TestClient`` for the serve app with model loading patched.

    ``convgru_ensemble.serve._load_model`` is replaced so that it returns the
    ``mock_model`` fixture while the client is open.
    """
    # Imported inside the fixture rather than at module import time.
    from fastapi.testclient import TestClient

    with patch("convgru_ensemble.serve._load_model", return_value=mock_model):
        from convgru_ensemble.serve import app

        # NOTE(review): nesting reconstructed from a flattened diff; presumably
        # the client is entered inside the patch so app startup sees the mock
        # loader - confirm against the real file.
        with TestClient(app) as c:
            yield c
35
+
36
+
37
def test_health(client):
    """``GET /health`` answers 200 and reports an ok status with a loaded model."""
    response = client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
    assert payload["model_loaded"] is True
43
+
44
+
45
def test_model_info(client):
    """``GET /model/info`` returns the architecture name and block count."""
    response = client.get("/model/info")
    assert response.status_code == 200
    info = response.json()
    assert info["architecture"] == "ConvGRU-Ensemble EncoderDecoder"
    assert info["num_blocks"] == 2
52
+
53
def test_predict_returns_netcdf(client):
    """Posting a NetCDF upload to ``/predict`` returns a NetCDF response.

    The response must be 200, carry the ``application/x-netcdf`` content type
    and include an ``X-Elapsed-Seconds`` header.
    """
    # Serialize a tiny all-zero dataset to NetCDF held entirely in memory.
    rain = xr.DataArray(np.zeros((4, 8, 8), dtype=np.float32), dims=["time", "y", "x"])
    upload = io.BytesIO()
    xr.Dataset({"RR": rain}).to_netcdf(upload)
    upload.seek(0)

    response = client.post(
        "/predict?forecast_steps=12&ensemble_size=10",
        files={"file": ("input.nc", upload, "application/x-netcdf")},
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "application/x-netcdf"
    assert "X-Elapsed-Seconds" in response.headers
tests/test_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from convgru_ensemble.utils import normalized_to_rainrate, rainrate_to_normalized
4
+
5
+
6
def test_roundtrip_conversion():
    """Normalizing rain rates and converting back recovers the input within tolerance."""
    original = np.array([0.0, 1.0, 5.0, 20.0, 50.0], dtype=np.float32)
    restored = normalized_to_rainrate(rainrate_to_normalized(original))
    np.testing.assert_allclose(restored, original, rtol=0.01, atol=0.05)
11
+
12
+
13
def test_zero_rain_maps_to_low_normalized():
    """A zero rain rate must normalize to a value below -0.9."""
    dry = np.array([0.0], dtype=np.float32)
    # Per the original author's note, zero rain should land near -1.
    value = rainrate_to_normalized(dry)[0]
    assert value < -0.9
uv.lock ADDED
The diff for this file is too large to render. See raw diff