Add source code and examples
Browse files- .dockerignore +15 -0
- .gitattributes +1 -0
- .github/workflows/ci.yml +29 -0
- .gitignore +17 -0
- .pre-commit-config.yaml +17 -0
- Dockerfile +26 -0
- LICENSE +24 -0
- MODEL_CARD.md +82 -0
- Makefile +22 -0
- README.md +220 -52
- convgru_ensemble/__init__.py +8 -0
- convgru_ensemble/__main__.py +5 -0
- convgru_ensemble/cli.py +126 -0
- convgru_ensemble/datamodule.py +378 -0
- convgru_ensemble/hub.py +102 -0
- convgru_ensemble/lightning_model.py +560 -0
- convgru_ensemble/losses.py +458 -0
- convgru_ensemble/model.py +569 -0
- convgru_ensemble/py.typed +0 -0
- convgru_ensemble/serve.py +139 -0
- convgru_ensemble/train.py +316 -0
- convgru_ensemble/utils.py +122 -0
- docker-compose.yml +12 -0
- examples/sample_data.nc +3 -0
- importance_sampler/filter_nan.py +362 -0
- importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv +0 -0
- importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000_metadata.json +21 -0
- importance_sampler/sample_valid_datacubes.py +257 -0
- notebooks/test_pretrained_model.ipynb +172 -0
- pyproject.toml +76 -0
- scripts/upload_to_hub.py +35 -0
- tests/conftest.py +23 -0
- tests/test_hub.py +16 -0
- tests/test_inference.py +28 -0
- tests/test_lightning_model.py +41 -0
- tests/test_losses.py +33 -0
- tests/test_model.py +34 -0
- tests/test_serve.py +66 -0
- tests/test_utils.py +17 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.github
|
| 3 |
+
.venv
|
| 4 |
+
__pycache__
|
| 5 |
+
*.pyc
|
| 6 |
+
*.zarr
|
| 7 |
+
logs/
|
| 8 |
+
data/
|
| 9 |
+
notebooks/
|
| 10 |
+
importance_sampler/
|
| 11 |
+
tests/
|
| 12 |
+
*.egg-info
|
| 13 |
+
.ruff_cache
|
| 14 |
+
.pytest_cache
|
| 15 |
+
.claude
|
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/sample_data.nc filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
pull_request:
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
lint-test-build:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
steps:
|
| 11 |
+
- name: Checkout
|
| 12 |
+
uses: actions/checkout@v4
|
| 13 |
+
|
| 14 |
+
- name: Install uv
|
| 15 |
+
uses: astral-sh/setup-uv@v4
|
| 16 |
+
with:
|
| 17 |
+
python-version: "3.13"
|
| 18 |
+
|
| 19 |
+
- name: Sync dependencies
|
| 20 |
+
run: uv sync --all-groups --extra serve
|
| 21 |
+
|
| 22 |
+
- name: Lint
|
| 23 |
+
run: uv run ruff check .
|
| 24 |
+
|
| 25 |
+
- name: Tests
|
| 26 |
+
run: uv run pytest -q
|
| 27 |
+
|
| 28 |
+
- name: Build package
|
| 29 |
+
run: uv build
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.nc
|
| 2 |
+
!examples/*.nc
|
| 3 |
+
*.zarr/
|
| 4 |
+
__pycache__/
|
| 5 |
+
.ipynb_checkpoints/
|
| 6 |
+
checkpoints/
|
| 7 |
+
*.ckpt
|
| 8 |
+
*.mp4
|
| 9 |
+
dist/
|
| 10 |
+
build/
|
| 11 |
+
*.egg-info/
|
| 12 |
+
.pytest_cache/
|
| 13 |
+
.ruff_cache/
|
| 14 |
+
logs/
|
| 15 |
+
data/
|
| 16 |
+
.env
|
| 17 |
+
.venv/
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v5.0.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
- id: end-of-file-fixer
|
| 7 |
+
- id: check-yaml
|
| 8 |
+
- id: check-added-large-files
|
| 9 |
+
args: ['--maxkb=12000']
|
| 10 |
+
- id: check-merge-conflict
|
| 11 |
+
|
| 12 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 13 |
+
rev: v0.9.10
|
| 14 |
+
hooks:
|
| 15 |
+
- id: ruff
|
| 16 |
+
args: [--fix]
|
| 17 |
+
- id: ruff-format
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install uv for fast dependency management
|
| 6 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
| 7 |
+
|
| 8 |
+
# Copy dependency files first for layer caching
|
| 9 |
+
COPY pyproject.toml uv.lock ./
|
| 10 |
+
|
| 11 |
+
# Install dependencies (serve extras, no dev)
|
| 12 |
+
RUN uv sync --extra serve --no-dev --no-install-project
|
| 13 |
+
|
| 14 |
+
# Copy package source
|
| 15 |
+
COPY convgru_ensemble/ ./convgru_ensemble/
|
| 16 |
+
|
| 17 |
+
# Install the project itself
|
| 18 |
+
RUN uv sync --extra serve --no-dev
|
| 19 |
+
|
| 20 |
+
# Model checkpoint is mounted at runtime or downloaded from HF Hub
|
| 21 |
+
ENV MODEL_CHECKPOINT=/app/model.ckpt
|
| 22 |
+
ENV DEVICE=cpu
|
| 23 |
+
|
| 24 |
+
EXPOSE 8000
|
| 25 |
+
|
| 26 |
+
CMD ["uv", "run", "uvicorn", "convgru_ensemble.serve:app", "--host", "0.0.0.0", "--port", "8000"]
|
LICENSE
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 2-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026, Data Science for Industry and Physics @ FBK
|
| 4 |
+
|
| 5 |
+
Redistribution and use in source and binary forms, with or without
|
| 6 |
+
modification, are permitted provided that the following conditions are met:
|
| 7 |
+
|
| 8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
list of conditions and the following disclaimer.
|
| 10 |
+
|
| 11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
and/or other materials provided with the distribution.
|
| 14 |
+
|
| 15 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 16 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 17 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 18 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 19 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 20 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 21 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 22 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 23 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 24 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
MODEL_CARD.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: bsd-2-clause
|
| 3 |
+
language: en
|
| 4 |
+
tags:
|
| 5 |
+
- weather
|
| 6 |
+
- nowcasting
|
| 7 |
+
- radar
|
| 8 |
+
- precipitation
|
| 9 |
+
- ensemble-forecasting
|
| 10 |
+
- convgru
|
| 11 |
+
- earth-observation
|
| 12 |
+
library_name: pytorch
|
| 13 |
+
pipeline_tag: image-to-image
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# IRENE — Italian Radar Ensemble Nowcasting Experiment
|
| 17 |
+
|
| 18 |
+
**IRENE** is a ConvGRU encoder-decoder model for short-term precipitation forecasting (nowcasting) from radar data. The model generates probabilistic ensemble forecasts, producing multiple plausible future scenarios from a single input sequence.
|
| 19 |
+
|
| 20 |
+
## Model Description
|
| 21 |
+
|
| 22 |
+
- **Architecture**: ConvGRU encoder-decoder with PixelShuffle/PixelUnshuffle for spatial scaling
|
| 23 |
+
- **Input**: Sequence of past radar rain rate fields (T, H, W) in mm/h
|
| 24 |
+
- **Output**: Ensemble of future rain rate forecasts (E, T, H, W) in mm/h
|
| 25 |
+
- **Temporal resolution**: 5 minutes per timestep
|
| 26 |
+
- **Training loss**: Continuous Ranked Probability Score (CRPS) with temporal consistency regularization
|
| 27 |
+
|
| 28 |
+
The model encodes past radar observations into multi-scale hidden states using stacked ConvGRU blocks with PixelUnshuffle downsampling. The decoder generates forecasts by unrolling with different random noise inputs, producing diverse ensemble members that capture forecast uncertainty.
|
| 29 |
+
|
| 30 |
+
## Intended Uses
|
| 31 |
+
|
| 32 |
+
- Short-term precipitation forecasting (0-60 min ahead) from radar reflectivity data
|
| 33 |
+
- Probabilistic nowcasting with uncertainty quantification via ensemble spread
|
| 34 |
+
- Research on deep learning for weather prediction
|
| 35 |
+
- Fine-tuning on regional radar datasets
|
| 36 |
+
|
| 37 |
+
## How to Use
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
from convgru_ensemble import RadarLightningModel
|
| 41 |
+
|
| 42 |
+
# Load from HuggingFace Hub
|
| 43 |
+
model = RadarLightningModel.from_pretrained("it4lia/irene")
|
| 44 |
+
|
| 45 |
+
# Run inference on past radar data (rain rate in mm/h)
|
| 46 |
+
import numpy as np
|
| 47 |
+
past = np.random.rand(6, 256, 256).astype(np.float32) # 6 past timesteps
|
| 48 |
+
forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
|
| 49 |
+
# forecasts.shape = (10, 12, 256, 256) — 10 ensemble members, 12 future steps
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Training Data
|
| 53 |
+
|
| 54 |
+
Trained on the Italian DPC (Dipartimento della Protezione Civile) radar mosaic surface rain intensity (SRI) dataset, covering the Italian territory at ~1 km resolution with 5-minute temporal resolution.
|
| 55 |
+
|
| 56 |
+
## Training Procedure
|
| 57 |
+
|
| 58 |
+
- **Optimizer**: Adam (lr=1e-4)
|
| 59 |
+
- **Loss**: CRPS with temporal consistency penalty (lambda=0.01)
|
| 60 |
+
- **Batch size**: 16
|
| 61 |
+
- **Ensemble size during training**: 2 members
|
| 62 |
+
- **Input window**: 6 past timesteps (30 min)
|
| 63 |
+
- **Forecast horizon**: 12 future timesteps (60 min)
|
| 64 |
+
- **Data augmentation**: Random rotations and flips
|
| 65 |
+
- **NaN handling**: Masked loss for missing radar data
|
| 66 |
+
|
| 67 |
+
## Limitations
|
| 68 |
+
|
| 69 |
+
- Trained on Italian radar data; performance may degrade on other domains without fine-tuning
|
| 70 |
+
- 5-minute temporal resolution only
|
| 71 |
+
- Best suited for convective and stratiform precipitation; extreme events may be underrepresented
|
| 72 |
+
- Ensemble spread is generated via noisy decoder inputs, not a full Bayesian approach
|
| 73 |
+
|
| 74 |
+
## Acknowledgements
|
| 75 |
+
|
| 76 |
+
This model was developed as part of the **Italian AI-Factory** (IT4LIA), an EU-funded initiative supporting the adoption of AI across SMEs, academia, and public/private sectors. The AI-Factory provides free HPC compute, consultancy, and AI-ready datasets. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.
|
| 77 |
+
|
| 78 |
+
Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy.
|
| 79 |
+
|
| 80 |
+
## License
|
| 81 |
+
|
| 82 |
+
BSD 2-Clause License
|
Makefile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: install lint test serve docker-build docker-run
|
| 2 |
+
|
| 3 |
+
install:
|
| 4 |
+
uv sync --all-groups --extra serve
|
| 5 |
+
|
| 6 |
+
lint:
|
| 7 |
+
uv run ruff check .
|
| 8 |
+
|
| 9 |
+
format:
|
| 10 |
+
uv run ruff format .
|
| 11 |
+
|
| 12 |
+
test:
|
| 13 |
+
uv run pytest -q
|
| 14 |
+
|
| 15 |
+
serve:
|
| 16 |
+
uv run uvicorn convgru_ensemble.serve:app --host 0.0.0.0 --port 8000
|
| 17 |
+
|
| 18 |
+
docker-build:
|
| 19 |
+
docker build -t convgru-ensemble .
|
| 20 |
+
|
| 21 |
+
docker-run:
|
| 22 |
+
docker run -p 8000:8000 -v ./checkpoints:/app/checkpoints convgru-ensemble
|
README.md
CHANGED
|
@@ -1,82 +1,250 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
- nowcasting
|
| 7 |
-
- radar
|
| 8 |
-
- precipitation
|
| 9 |
-
- ensemble-forecasting
|
| 10 |
-
- convgru
|
| 11 |
-
- earth-observation
|
| 12 |
-
library_name: pytorch
|
| 13 |
-
pipeline_tag: image-to-image
|
| 14 |
-
---
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
- Probabilistic nowcasting with uncertainty quantification via ensemble spread
|
| 34 |
-
- Research on deep learning for weather prediction
|
| 35 |
-
- Fine-tuning on regional radar datasets
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
```python
|
| 40 |
from convgru_ensemble import RadarLightningModel
|
| 41 |
|
| 42 |
-
# Load from HuggingFace Hub
|
| 43 |
model = RadarLightningModel.from_pretrained("it4lia/irene")
|
| 44 |
|
| 45 |
-
# Run inference on past radar data (rain rate in mm/h)
|
| 46 |
import numpy as np
|
| 47 |
-
past = np.
|
| 48 |
forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
|
| 49 |
-
# forecasts.shape = (10, 12,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
```
|
| 51 |
|
| 52 |
-
|
| 53 |
|
| 54 |
-
|
| 55 |
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
## Acknowledgements
|
| 75 |
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
|
| 80 |
## License
|
| 81 |
|
| 82 |
-
BSD 2-Clause
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# ConvGRU-Ensemble
|
| 4 |
+
|
| 5 |
+
**Ensemble precipitation nowcasting using Convolutional GRU networks**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
*Pretrained model for Italy:* ***IRENE*** — **I**talian **R**adar **E**nsemble **N**owcasting **E**xperiment
|
| 8 |
|
| 9 |
+
[](https://github.com/DSIP-FBK/ConvGRU-Ensemble/actions)
|
| 10 |
+
[](LICENSE)
|
| 11 |
+
[](https://python.org)
|
| 12 |
+
[](https://huggingface.co/it4lia/irene)
|
| 13 |
|
| 14 |
+
<br>
|
| 15 |
|
| 16 |
+
<a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" height="55" alt="Fondazione Bruno Kessler"></a>
|
| 17 |
+
|
| 18 |
+
<a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" height="55" alt="IT4LIA AI-Factory"></a>
|
| 19 |
+
|
| 20 |
+
<a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" height="55" alt="ItaliaMeteo"></a>
|
| 21 |
|
| 22 |
+
<br><br>
|
| 23 |
|
| 24 |
+
The model encodes past radar frames into multi-scale hidden states and decodes them into an **ensemble of probabilistic forecasts** by running the decoder multiple times with different noise inputs, trained with **CRPS loss**.
|
| 25 |
|
| 26 |
+
</div>
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## Quick Start
|
| 31 |
+
|
| 32 |
+
<details open>
|
| 33 |
+
<summary><b>Load from HuggingFace Hub</b></summary>
|
| 34 |
|
| 35 |
```python
|
| 36 |
from convgru_ensemble import RadarLightningModel
|
| 37 |
|
|
|
|
| 38 |
model = RadarLightningModel.from_pretrained("it4lia/irene")
|
| 39 |
|
|
|
|
| 40 |
import numpy as np
|
| 41 |
+
past = np.load("past_radar.npy") # rain rate in mm/h, shape (T_past, H, W)
|
| 42 |
forecasts = model.predict(past, forecast_steps=12, ensemble_size=10)
|
| 43 |
+
# forecasts.shape = (10, 12, H, W) — 10 members, 12 future steps, mm/h
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
</details>
|
| 47 |
+
|
| 48 |
+
<details>
|
| 49 |
+
<summary><b>CLI Inference</b></summary>
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
convgru-ensemble predict \
|
| 53 |
+
--input examples/sample_data.nc \
|
| 54 |
+
--hub-repo it4lia/irene \
|
| 55 |
+
--forecast-steps 12 \
|
| 56 |
+
--ensemble-size 10 \
|
| 57 |
+
--output predictions.nc
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
</details>
|
| 61 |
+
|
| 62 |
+
<details>
|
| 63 |
+
<summary><b>Serve via API</b></summary>
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
# With Docker
|
| 67 |
+
docker compose up
|
| 68 |
+
|
| 69 |
+
# Or directly
|
| 70 |
+
pip install convgru-ensemble[serve]
|
| 71 |
+
convgru-ensemble serve --hub-repo it4lia/irene --port 8000
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
curl -X POST http://localhost:8000/predict \
|
| 76 |
+
-F "file=@input.nc" -o predictions.nc
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
| Endpoint | Method | Description |
|
| 80 |
+
|---|---|---|
|
| 81 |
+
| `/health` | GET | Health check |
|
| 82 |
+
| `/model/info` | GET | Model metadata and hyperparameters |
|
| 83 |
+
| `/predict` | POST | Upload NetCDF, get ensemble forecast as NetCDF |
|
| 84 |
+
|
| 85 |
+
</details>
|
| 86 |
+
|
| 87 |
+
<details>
|
| 88 |
+
<summary><b>Fine-tune on your data</b></summary>
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
pip install convgru-ensemble
|
| 92 |
+
# See "Training" section below
|
| 93 |
```
|
| 94 |
|
| 95 |
+
</details>
|
| 96 |
|
| 97 |
+
## Setup
|
| 98 |
|
| 99 |
+
Requires Python >= 3.13. Uses [uv](https://github.com/astral-sh/uv) for dependency management.
|
| 100 |
|
| 101 |
+
```bash
|
| 102 |
+
uv sync # core dependencies
|
| 103 |
+
uv sync --extra serve # + FastAPI serving
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## Data Preparation
|
| 107 |
+
|
| 108 |
+
The training pipeline expects a Zarr dataset with a rain rate variable `RR` indexed by `(time, x, y)`.
|
| 109 |
+
|
| 110 |
+
<details>
|
| 111 |
+
<summary><b>1. Filter valid datacubes</b></summary>
|
| 112 |
+
|
| 113 |
+
Scan the Zarr and find all space-time datacubes with fewer than `n_nan` NaN values:
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
cd importance_sampler
|
| 117 |
+
uv run python filter_nan.py path/to/dataset.zarr \
|
| 118 |
+
--start_date 2021-01-01 --end_date 2025-12-11 \
|
| 119 |
+
--Dt 24 --w 256 --h 256 \
|
| 120 |
+
--step_T 3 --step_X 16 --step_Y 16 \
|
| 121 |
+
--n_nan 10000 --n_workers 8
|
| 122 |
+
```
|
| 123 |
|
| 124 |
+
</details>
|
| 125 |
|
| 126 |
+
<details>
|
| 127 |
+
<summary><b>2. Importance sampling</b></summary>
|
| 128 |
+
|
| 129 |
+
Sample valid datacubes with higher probability for rainier events:
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
uv run python sample_valid_datacubes.py path/to/dataset.zarr valid_datacubes_*.csv \
|
| 133 |
+
--q_min 1e-4 --m 0.1 --n_workers 8
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
A pre-sampled CSV is provided in [`importance_sampler/output/`](importance_sampler/output/).
|
| 137 |
+
|
| 138 |
+
</details>
|
| 139 |
+
|
| 140 |
+
## Training
|
| 141 |
+
|
| 142 |
+
Training is configured via [Fiddle](https://github.com/google/fiddle). Run with defaults:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
uv run python -m convgru_ensemble.train
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Override parameters from the command line:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
uv run python -m convgru_ensemble.train \
|
| 152 |
+
--config config:experiment \
|
| 153 |
+
--config set:model.num_blocks=5 \
|
| 154 |
+
--config set:model.forecast_steps=12 \
|
| 155 |
+
--config set:model.loss_class=crps \
|
| 156 |
+
--config set:model.ensemble_size=2 \
|
| 157 |
+
--config set:datamodule.batch_size=16 \
|
| 158 |
+
--config set:trainer.max_epochs=100
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
Monitor with TensorBoard: `uv run tensorboard --logdir logs/`
|
| 162 |
+
|
| 163 |
+
| Parameter | Description | Default |
|
| 164 |
+
|---|---|---|
|
| 165 |
+
| `model.num_blocks` | Encoder/decoder depth | `5` |
|
| 166 |
+
| `model.forecast_steps` | Future steps to predict | `12` |
|
| 167 |
+
| `model.ensemble_size` | Ensemble members during training | `2` |
|
| 168 |
+
| `model.loss_class` | Loss function (`mse`, `mae`, `crps`, `afcrps`) | `crps` |
|
| 169 |
+
| `model.masked_loss` | Mask NaN regions in loss | `True` |
|
| 170 |
+
| `datamodule.steps` | Total timesteps per sample (past + future) | `18` |
|
| 171 |
+
| `datamodule.batch_size` | Batch size | `16` |
|
| 172 |
+
|
| 173 |
+
## Architecture
|
| 174 |
+
|
| 175 |
+
```
|
| 176 |
+
Input (B, T_past, 1, H, W)
|
| 177 |
+
|
|
| 178 |
+
v
|
| 179 |
+
+--------------------------+
|
| 180 |
+
| Encoder | ConvGRU + PixelUnshuffle (x num_blocks)
|
| 181 |
+
| Spatial dims halve at | Channels: 1 -> 4 -> 16 -> 64 -> 256 -> 1024
|
| 182 |
+
| each block |
|
| 183 |
+
+----------+---------------+
|
| 184 |
+
| hidden states
|
| 185 |
+
v
|
| 186 |
+
+--------------------------+
|
| 187 |
+
| Decoder | ConvGRU + PixelShuffle (x num_blocks)
|
| 188 |
+
| Noise input (x M runs) | Each run produces one ensemble member
|
| 189 |
+
| for ensemble generation |
|
| 190 |
+
+----------+---------------+
|
| 191 |
+
|
|
| 192 |
+
v
|
| 193 |
+
Output (B, T_future, M, H, W)
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
## Docker
|
| 197 |
+
|
| 198 |
+
```bash
|
| 199 |
+
docker build -t convgru-ensemble .
|
| 200 |
+
|
| 201 |
+
# Run with local checkpoint
|
| 202 |
+
docker run -p 8000:8000 -v ./checkpoints:/app/checkpoints \
|
| 203 |
+
-e MODEL_CHECKPOINT=/app/checkpoints/model.ckpt convgru-ensemble
|
| 204 |
+
|
| 205 |
+
# Run with HuggingFace Hub
|
| 206 |
+
docker run -p 8000:8000 -e HF_REPO_ID=it4lia/irene convgru-ensemble
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
## Project Structure
|
| 210 |
+
|
| 211 |
+
```
|
| 212 |
+
ConvGRU-Ensemble/
|
| 213 |
+
+-- convgru_ensemble/ # Python package
|
| 214 |
+
| +-- model.py # ConvGRU encoder-decoder architecture
|
| 215 |
+
| +-- losses.py # CRPS, afCRPS, masked loss wrappers
|
| 216 |
+
| +-- lightning_model.py # PyTorch Lightning training module
|
| 217 |
+
| +-- datamodule.py # Dataset and data loading
|
| 218 |
+
| +-- train.py # Training entry point (Fiddle config)
|
| 219 |
+
| +-- utils.py # Rain rate <-> reflectivity conversions
|
| 220 |
+
| +-- hub.py # HuggingFace Hub upload/download
|
| 221 |
+
| +-- cli.py # CLI for inference and serving
|
| 222 |
+
| +-- serve.py # FastAPI inference server
|
| 223 |
+
+-- examples/ # Sample data for testing
|
| 224 |
+
+-- importance_sampler/ # Data preparation scripts
|
| 225 |
+
+-- notebooks/ # Example notebooks
|
| 226 |
+
+-- scripts/ # Utility scripts (e.g., upload to Hub)
|
| 227 |
+
+-- tests/ # Test suite
|
| 228 |
+
+-- Dockerfile # Container for serving API
|
| 229 |
+
+-- MODEL_CARD.md # HuggingFace model card template
|
| 230 |
+
```
|
| 231 |
|
| 232 |
## Acknowledgements
|
| 233 |
|
| 234 |
+
<div align="center">
|
| 235 |
+
|
| 236 |
+
Developed at **Fondazione Bruno Kessler (FBK)**, Trento, Italy, as part of the **Italian AI-Factory (IT4LIA)**, an EU-funded initiative supporting AI adoption across SMEs, academia, and public/private sectors. This work showcases capabilities in the **Earth (weather and climate) vertical domain**.
|
| 237 |
+
|
| 238 |
+
<br>
|
| 239 |
+
|
| 240 |
+
<a href="https://www.fbk.eu"><img src="https://webvalley.fbk.eu/static/img/logos/fbk-logo-blue.png" height="45" alt="Fondazione Bruno Kessler"></a>
|
| 241 |
+
|
| 242 |
+
<a href="https://it4lia-aifactory.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/05/logo-IT4LIA-AI-factory.svg" height="45" alt="IT4LIA AI-Factory"></a>
|
| 243 |
+
|
| 244 |
+
<a href="https://www.italiameteo.eu"><img src="https://it4lia-aifactory.eu/wp-content/uploads/2025/08/logo-italiameteo.svg" height="45" alt="ItaliaMeteo"></a>
|
| 245 |
|
| 246 |
+
</div>
|
| 247 |
|
| 248 |
## License
|
| 249 |
|
| 250 |
+
BSD 2-Clause — see [LICENSE](LICENSE).
|
convgru_ensemble/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ConvGRU-Ensemble: Ensemble precipitation nowcasting using Convolutional GRU networks."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
| 4 |
+
|
| 5 |
+
from convgru_ensemble.lightning_model import RadarLightningModel
|
| 6 |
+
from convgru_ensemble.model import EncoderDecoder
|
| 7 |
+
|
| 8 |
+
__all__ = ["EncoderDecoder", "RadarLightningModel"]
|
convgru_ensemble/__main__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Allow running the package as ``python -m convgru_ensemble``."""
|
| 2 |
+
|
| 3 |
+
from convgru_ensemble.cli import main
|
| 4 |
+
|
| 5 |
+
main()
|
convgru_ensemble/cli.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Command-line interface for ConvGRU-Ensemble inference and serving."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import fire
|
| 6 |
+
import numpy as np
|
| 7 |
+
import xarray as xr
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _load_model(checkpoint: str | None = None, hub_repo: str | None = None, device: str = "cpu"):
|
| 11 |
+
"""Load model from local checkpoint or HuggingFace Hub."""
|
| 12 |
+
from .lightning_model import RadarLightningModel
|
| 13 |
+
|
| 14 |
+
if hub_repo is not None:
|
| 15 |
+
print(f"Loading model from HuggingFace Hub: {hub_repo}")
|
| 16 |
+
return RadarLightningModel.from_pretrained(hub_repo, device=device)
|
| 17 |
+
elif checkpoint is not None:
|
| 18 |
+
print(f"Loading model from checkpoint: {checkpoint}")
|
| 19 |
+
return RadarLightningModel.from_checkpoint(checkpoint, device=device)
|
| 20 |
+
else:
|
| 21 |
+
raise ValueError("Either --checkpoint or --hub-repo must be provided.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def predict(
|
| 25 |
+
input: str,
|
| 26 |
+
checkpoint: str | None = None,
|
| 27 |
+
hub_repo: str | None = None,
|
| 28 |
+
variable: str = "RR",
|
| 29 |
+
forecast_steps: int = 12,
|
| 30 |
+
ensemble_size: int = 10,
|
| 31 |
+
device: str = "cpu",
|
| 32 |
+
output: str = "predictions.nc",
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Run inference on a NetCDF input file and save predictions as NetCDF.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
input: Path to input NetCDF file with rain rate data (T, H, W) or (T, Y, X).
|
| 39 |
+
checkpoint: Path to local .ckpt checkpoint file.
|
| 40 |
+
hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
|
| 41 |
+
variable: Name of the rain rate variable in the NetCDF file.
|
| 42 |
+
forecast_steps: Number of future timesteps to forecast.
|
| 43 |
+
ensemble_size: Number of ensemble members to generate.
|
| 44 |
+
device: Device for inference ('cpu' or 'cuda').
|
| 45 |
+
output: Path for the output NetCDF file.
|
| 46 |
+
"""
|
| 47 |
+
model = _load_model(checkpoint, hub_repo, device)
|
| 48 |
+
|
| 49 |
+
# Load input data
|
| 50 |
+
print(f"Loading input: {input}")
|
| 51 |
+
ds = xr.open_dataset(input)
|
| 52 |
+
if variable not in ds:
|
| 53 |
+
available = list(ds.data_vars)
|
| 54 |
+
raise ValueError(f"Variable '{variable}' not found. Available: {available}")
|
| 55 |
+
|
| 56 |
+
data = ds[variable].values # (T, H, W) or similar
|
| 57 |
+
if data.ndim != 3:
|
| 58 |
+
raise ValueError(f"Expected 3D data (T, H, W), got shape {data.shape}")
|
| 59 |
+
|
| 60 |
+
print(f"Input shape: {data.shape}")
|
| 61 |
+
past = data.astype(np.float32)
|
| 62 |
+
|
| 63 |
+
# Run inference
|
| 64 |
+
t0 = time.perf_counter()
|
| 65 |
+
preds = model.predict(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size)
|
| 66 |
+
elapsed = time.perf_counter() - t0
|
| 67 |
+
print(f"Output shape: {preds.shape} (ensemble, time, H, W)")
|
| 68 |
+
print(f"Elapsed: {elapsed:.2f}s")
|
| 69 |
+
|
| 70 |
+
# Build output dataset
|
| 71 |
+
ds_out = xr.Dataset(
|
| 72 |
+
{
|
| 73 |
+
"precipitation_forecast": xr.DataArray(
|
| 74 |
+
data=preds,
|
| 75 |
+
dims=["ensemble_member", "forecast_step", "y", "x"],
|
| 76 |
+
attrs={"units": "mm/h", "long_name": "Ensemble precipitation forecast"},
|
| 77 |
+
),
|
| 78 |
+
},
|
| 79 |
+
attrs={
|
| 80 |
+
"model": "ConvGRU-Ensemble",
|
| 81 |
+
"forecast_steps": forecast_steps,
|
| 82 |
+
"ensemble_size": ensemble_size,
|
| 83 |
+
"source_file": str(input),
|
| 84 |
+
},
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
ds_out.to_netcdf(output)
|
| 88 |
+
print(f"Predictions saved to: {output}")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def serve(
|
| 92 |
+
checkpoint: str | None = None,
|
| 93 |
+
hub_repo: str | None = None,
|
| 94 |
+
host: str = "0.0.0.0",
|
| 95 |
+
port: int = 8000,
|
| 96 |
+
device: str = "cpu",
|
| 97 |
+
):
|
| 98 |
+
"""
|
| 99 |
+
Start the FastAPI inference server.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
checkpoint: Path to local .ckpt checkpoint file.
|
| 103 |
+
hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
|
| 104 |
+
host: Host to bind to.
|
| 105 |
+
port: Port to listen on.
|
| 106 |
+
device: Device for inference ('cpu' or 'cuda').
|
| 107 |
+
"""
|
| 108 |
+
import os
|
| 109 |
+
|
| 110 |
+
if checkpoint is not None:
|
| 111 |
+
os.environ["MODEL_CHECKPOINT"] = checkpoint
|
| 112 |
+
if hub_repo is not None:
|
| 113 |
+
os.environ["HF_REPO_ID"] = hub_repo
|
| 114 |
+
os.environ.setdefault("DEVICE", device)
|
| 115 |
+
|
| 116 |
+
import uvicorn
|
| 117 |
+
|
| 118 |
+
uvicorn.run("convgru_ensemble.serve:app", host=host, port=port)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
|
| 122 |
+
fire.Fire({"predict": predict, "serve": serve})
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
main()
|
convgru_ensemble/datamodule.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
import torch
|
| 7 |
+
import xarray as xr
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset
|
| 9 |
+
|
| 10 |
+
from .utils import rainrate_to_normalized
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SampledRadarDataset(Dataset):
|
| 14 |
+
"""
|
| 15 |
+
PyTorch dataset that loads radar datacubes from a Zarr store using
|
| 16 |
+
pre-sampled spatial-temporal coordinates from a CSV file.
|
| 17 |
+
|
| 18 |
+
Each sample is a spatio-temporal datacube of shape ``(T, 1, H, W)``
|
| 19 |
+
converted from rain rate to normalized reflectivity.
|
| 20 |
+
|
| 21 |
+
Parameters
|
| 22 |
+
----------
|
| 23 |
+
zarr_path : str
|
| 24 |
+
Path to the Zarr dataset containing the ``'RR'`` rain rate variable.
|
| 25 |
+
csv_path : str
|
| 26 |
+
Path to the CSV file with columns ``(t, x, y)`` specifying the
|
| 27 |
+
top-left corner of each datacube.
|
| 28 |
+
steps : int
|
| 29 |
+
Number of timesteps to extract per sample.
|
| 30 |
+
return_mask : bool, optional
|
| 31 |
+
If ``True``, also return a spatial NaN mask. Default is ``False``.
|
| 32 |
+
deterministic : bool, optional
|
| 33 |
+
If ``True``, use a fixed random seed (42) for reproducibility.
|
| 34 |
+
Default is ``False``.
|
| 35 |
+
augment : bool, optional
|
| 36 |
+
If ``True``, apply random spatial augmentations (rotation, flips).
|
| 37 |
+
Default is ``False``.
|
| 38 |
+
indices : sequence of int or None, optional
|
| 39 |
+
Subset of row indices to use from the CSV. If ``None``, use all rows.
|
| 40 |
+
Default is ``None``.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
zarr_path: str,
|
| 46 |
+
csv_path: str,
|
| 47 |
+
steps: int,
|
| 48 |
+
return_mask: bool = False,
|
| 49 |
+
deterministic: bool = False,
|
| 50 |
+
augment: bool = False,
|
| 51 |
+
indices=None,
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
Initialize SampledRadarDataset.
|
| 55 |
+
|
| 56 |
+
Parameters
|
| 57 |
+
----------
|
| 58 |
+
zarr_path : str
|
| 59 |
+
Path to the Zarr dataset containing the ``'RR'`` rain rate
|
| 60 |
+
variable.
|
| 61 |
+
csv_path : str
|
| 62 |
+
Path to the CSV file with columns ``(t, x, y)``.
|
| 63 |
+
steps : int
|
| 64 |
+
Number of timesteps to extract per sample.
|
| 65 |
+
return_mask : bool, optional
|
| 66 |
+
If ``True``, also return a spatial NaN mask. Default is ``False``.
|
| 67 |
+
deterministic : bool, optional
|
| 68 |
+
If ``True``, use a fixed random seed (42). Default is ``False``.
|
| 69 |
+
augment : bool, optional
|
| 70 |
+
If ``True``, apply random spatial augmentations. Default is
|
| 71 |
+
``False``.
|
| 72 |
+
indices : sequence of int or None, optional
|
| 73 |
+
Subset of row indices from the CSV. Default is ``None``.
|
| 74 |
+
"""
|
| 75 |
+
self.coords = pd.read_csv(csv_path).sort_values("t")
|
| 76 |
+
if indices is not None:
|
| 77 |
+
self.coords = self.coords.iloc[list(indices)].reset_index(drop=True)
|
| 78 |
+
self.zg = xr.open_zarr(zarr_path)
|
| 79 |
+
self.RR = self.zg["RR"]
|
| 80 |
+
self.rng = np.random.default_rng(seed=42) if deterministic else np.random.default_rng(int(time.time()))
|
| 81 |
+
self.return_mask = return_mask
|
| 82 |
+
self.augment = augment
|
| 83 |
+
|
| 84 |
+
if augment:
|
| 85 |
+
print("Data augmentation is enabled.")
|
| 86 |
+
|
| 87 |
+
# default valid grid size and time step
|
| 88 |
+
self.w = 256
|
| 89 |
+
self.h = 256
|
| 90 |
+
self.dt = 24
|
| 91 |
+
self.steps = steps
|
| 92 |
+
|
| 93 |
+
# raise warning if steps > dt
|
| 94 |
+
if self.steps > self.dt:
|
| 95 |
+
print(f"Warning: requested steps ({self.steps}) > sampled time window ({self.dt})")
|
| 96 |
+
|
| 97 |
+
def __len__(self):
|
| 98 |
+
"""
|
| 99 |
+
Return the number of samples in the dataset.
|
| 100 |
+
|
| 101 |
+
Returns
|
| 102 |
+
-------
|
| 103 |
+
length : int
|
| 104 |
+
Number of datacube samples.
|
| 105 |
+
"""
|
| 106 |
+
return len(self.coords)
|
| 107 |
+
|
| 108 |
+
def shape(self):
|
| 109 |
+
"""
|
| 110 |
+
Return the nominal shape of the full dataset.
|
| 111 |
+
|
| 112 |
+
Returns
|
| 113 |
+
-------
|
| 114 |
+
shape : tuple of int
|
| 115 |
+
``(num_samples, steps, 1, width, height)``.
|
| 116 |
+
"""
|
| 117 |
+
return (len(self.coords), self.steps, 1, self.w, self.h)
|
| 118 |
+
|
| 119 |
+
def _apply_augmentations(
|
| 120 |
+
self, *tensors, rotate_prob: float = 0.5, hflip_prob: float = 0.5, vflip_prob: float = 0.5
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Apply random spatial augmentations consistently to all input tensors.
|
| 124 |
+
|
| 125 |
+
All tensors receive the same random transformation so that spatial
|
| 126 |
+
alignment is preserved (e.g. between data and mask).
|
| 127 |
+
|
| 128 |
+
Parameters
|
| 129 |
+
----------
|
| 130 |
+
*tensors : torch.Tensor
|
| 131 |
+
One or more tensors of shape ``(T, C, H, W)``.
|
| 132 |
+
rotate_prob : float, optional
|
| 133 |
+
Probability of applying a random 90-degree rotation. Default is
|
| 134 |
+
``0.5``.
|
| 135 |
+
hflip_prob : float, optional
|
| 136 |
+
Probability of applying a horizontal flip. Default is ``0.5``.
|
| 137 |
+
vflip_prob : float, optional
|
| 138 |
+
Probability of applying a vertical flip. Default is ``0.5``.
|
| 139 |
+
|
| 140 |
+
Returns
|
| 141 |
+
-------
|
| 142 |
+
augmented : torch.Tensor or tuple of torch.Tensor
|
| 143 |
+
Single tensor if one input was given, otherwise a tuple of
|
| 144 |
+
augmented tensors.
|
| 145 |
+
"""
|
| 146 |
+
# Random 90-degree rotation (0, 90, 180, or 270 degrees)
|
| 147 |
+
if self.rng.random() < rotate_prob:
|
| 148 |
+
k = self.rng.integers(1, 4) # 1=90, 2=180, 3=270 degrees
|
| 149 |
+
tensors = [torch.rot90(t, k, dims=[-2, -1]) for t in tensors]
|
| 150 |
+
|
| 151 |
+
# Random horizontal flip
|
| 152 |
+
if self.rng.random() < hflip_prob:
|
| 153 |
+
tensors = [torch.flip(t, dims=[-1]) for t in tensors]
|
| 154 |
+
|
| 155 |
+
# Random vertical flip
|
| 156 |
+
if self.rng.random() < vflip_prob:
|
| 157 |
+
tensors = [torch.flip(t, dims=[-2]) for t in tensors]
|
| 158 |
+
|
| 159 |
+
tensors = [t.contiguous() for t in tensors]
|
| 160 |
+
return tensors[0] if len(tensors) == 1 else tuple(tensors)
|
| 161 |
+
|
| 162 |
+
def __getitem__(self, idx: int):
|
| 163 |
+
"""
|
| 164 |
+
Load and return a single datacube sample.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
idx : int
|
| 169 |
+
Index of the sample in the dataset.
|
| 170 |
+
|
| 171 |
+
Returns
|
| 172 |
+
-------
|
| 173 |
+
sample : dict of str to torch.Tensor
|
| 174 |
+
Dictionary with key ``'data'`` containing a tensor of shape
|
| 175 |
+
``(T, 1, H, W)``. If ``return_mask`` is ``True``, also contains
|
| 176 |
+
``'mask'`` of shape ``(1, 1, H, W)``.
|
| 177 |
+
"""
|
| 178 |
+
t0, x0, y0 = self.coords.iloc[idx]
|
| 179 |
+
|
| 180 |
+
x_slice = slice(x0, x0 + self.w)
|
| 181 |
+
y_slice = slice(y0, y0 + self.h)
|
| 182 |
+
|
| 183 |
+
if self.steps < self.dt:
|
| 184 |
+
# radom sampling within available time window
|
| 185 |
+
t_start = self.rng.integers(t0, t0 + self.dt - self.steps + 1)
|
| 186 |
+
else:
|
| 187 |
+
t_start = t0
|
| 188 |
+
t_slice = slice(t_start, t_start + self.steps)
|
| 189 |
+
|
| 190 |
+
data = rainrate_to_normalized(self.RR[t_slice, x_slice, y_slice])
|
| 191 |
+
|
| 192 |
+
# create a mask for all nan values over time dimension
|
| 193 |
+
# shape: (1, H, W) - NOT repeated over time, broadcasting handles it
|
| 194 |
+
if self.return_mask:
|
| 195 |
+
mask = (~(np.isnan(data).any(axis=0, keepdims=True))).astype(np.float32)
|
| 196 |
+
|
| 197 |
+
# replace nan values with -1
|
| 198 |
+
data = np.nan_to_num(data, nan=-1.0)
|
| 199 |
+
|
| 200 |
+
# convert to tensors
|
| 201 |
+
data = torch.from_numpy(data[:, np.newaxis, :, :])
|
| 202 |
+
if self.return_mask:
|
| 203 |
+
mask = torch.from_numpy(mask.values[:, np.newaxis, :, :])
|
| 204 |
+
|
| 205 |
+
# apply augmentations (training only)
|
| 206 |
+
if self.augment:
|
| 207 |
+
if self.return_mask:
|
| 208 |
+
data, mask = self._apply_augmentations(data, mask)
|
| 209 |
+
else:
|
| 210 |
+
data = self._apply_augmentations(data)
|
| 211 |
+
|
| 212 |
+
if self.return_mask:
|
| 213 |
+
return {"data": data, "mask": mask}
|
| 214 |
+
else:
|
| 215 |
+
return {"data": data}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class RadarDataModule(pl.LightningDataModule):
|
| 219 |
+
"""
|
| 220 |
+
PyTorch Lightning data module for radar datacube datasets.
|
| 221 |
+
|
| 222 |
+
Handles train/val/test splitting and DataLoader creation from a single
|
| 223 |
+
Zarr store and CSV coordinate file.
|
| 224 |
+
|
| 225 |
+
Parameters
|
| 226 |
+
----------
|
| 227 |
+
zarr_path : str
|
| 228 |
+
Path to the Zarr dataset.
|
| 229 |
+
csv_path : str
|
| 230 |
+
Path to the CSV file with datacube coordinates.
|
| 231 |
+
steps : int
|
| 232 |
+
Number of timesteps per sample.
|
| 233 |
+
train_ratio : float, optional
|
| 234 |
+
Fraction of data used for training. Default is ``0.7``.
|
| 235 |
+
val_ratio : float, optional
|
| 236 |
+
Fraction of data used for validation. Default is ``0.15``.
|
| 237 |
+
return_mask : bool, optional
|
| 238 |
+
Whether to return NaN masks. Default is ``False``.
|
| 239 |
+
deterministic : bool, optional
|
| 240 |
+
Whether to use fixed random seeds. Default is ``False``.
|
| 241 |
+
augment : bool, optional
|
| 242 |
+
Whether to apply data augmentation (training set only). Default is
|
| 243 |
+
``True``.
|
| 244 |
+
**dataloader_kwargs
|
| 245 |
+
Additional keyword arguments forwarded to ``DataLoader`` (e.g.
|
| 246 |
+
``batch_size``, ``num_workers``, ``pin_memory``).
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
def __init__(
|
| 250 |
+
self,
|
| 251 |
+
zarr_path,
|
| 252 |
+
csv_path,
|
| 253 |
+
steps,
|
| 254 |
+
train_ratio=0.7,
|
| 255 |
+
val_ratio=0.15,
|
| 256 |
+
return_mask=False,
|
| 257 |
+
deterministic=False,
|
| 258 |
+
augment=True,
|
| 259 |
+
**dataloader_kwargs,
|
| 260 |
+
):
|
| 261 |
+
"""
|
| 262 |
+
Initialize RadarDataModule.
|
| 263 |
+
|
| 264 |
+
Parameters
|
| 265 |
+
----------
|
| 266 |
+
zarr_path : str
|
| 267 |
+
Path to the Zarr dataset.
|
| 268 |
+
csv_path : str
|
| 269 |
+
Path to the CSV file with datacube coordinates.
|
| 270 |
+
steps : int
|
| 271 |
+
Number of timesteps per sample.
|
| 272 |
+
train_ratio : float, optional
|
| 273 |
+
Fraction of data for training. Default is ``0.7``.
|
| 274 |
+
val_ratio : float, optional
|
| 275 |
+
Fraction of data for validation. Default is ``0.15``.
|
| 276 |
+
return_mask : bool, optional
|
| 277 |
+
Whether to return NaN masks. Default is ``False``.
|
| 278 |
+
deterministic : bool, optional
|
| 279 |
+
Whether to use fixed random seeds. Default is ``False``.
|
| 280 |
+
augment : bool, optional
|
| 281 |
+
Whether to apply data augmentation. Default is ``True``.
|
| 282 |
+
**dataloader_kwargs
|
| 283 |
+
Forwarded to ``DataLoader``.
|
| 284 |
+
"""
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.zarr_path = zarr_path
|
| 287 |
+
self.csv_path = csv_path
|
| 288 |
+
self.steps = steps
|
| 289 |
+
self.train_ratio = train_ratio
|
| 290 |
+
self.val_ratio = val_ratio
|
| 291 |
+
self.dataloader_kwargs = dataloader_kwargs
|
| 292 |
+
self.return_mask = return_mask
|
| 293 |
+
self.deterministic = deterministic
|
| 294 |
+
self.augment = augment
|
| 295 |
+
|
| 296 |
+
def setup(self, stage=None):
|
| 297 |
+
"""
|
| 298 |
+
Create train, validation, and test datasets from the CSV coordinates.
|
| 299 |
+
|
| 300 |
+
Splits are chronological: the first ``train_ratio`` fraction is used
|
| 301 |
+
for training, the next ``val_ratio`` for validation, and the rest for
|
| 302 |
+
testing. Augmentation is only applied to the training set.
|
| 303 |
+
|
| 304 |
+
Parameters
|
| 305 |
+
----------
|
| 306 |
+
stage : str or None, optional
|
| 307 |
+
Lightning stage (``'fit'``, ``'test'``, etc.). Ignored; all
|
| 308 |
+
datasets are always created. Default is ``None``.
|
| 309 |
+
"""
|
| 310 |
+
# Load CSV to get total length for splitting
|
| 311 |
+
coords = pd.read_csv(self.csv_path).sort_values("t")
|
| 312 |
+
n = len(coords)
|
| 313 |
+
|
| 314 |
+
# Compute split indices
|
| 315 |
+
train_end = int(n * self.train_ratio)
|
| 316 |
+
val_end = int(n * (self.train_ratio + self.val_ratio))
|
| 317 |
+
|
| 318 |
+
# Create separate datasets (augmentation only for training)
|
| 319 |
+
self.train_dataset = SampledRadarDataset(
|
| 320 |
+
self.zarr_path,
|
| 321 |
+
self.csv_path,
|
| 322 |
+
self.steps,
|
| 323 |
+
self.return_mask,
|
| 324 |
+
self.deterministic,
|
| 325 |
+
augment=self.augment,
|
| 326 |
+
indices=range(0, train_end),
|
| 327 |
+
)
|
| 328 |
+
self.val_dataset = SampledRadarDataset(
|
| 329 |
+
self.zarr_path,
|
| 330 |
+
self.csv_path,
|
| 331 |
+
self.steps,
|
| 332 |
+
self.return_mask,
|
| 333 |
+
self.deterministic,
|
| 334 |
+
augment=False,
|
| 335 |
+
indices=range(train_end, val_end),
|
| 336 |
+
)
|
| 337 |
+
self.test_dataset = SampledRadarDataset(
|
| 338 |
+
self.zarr_path,
|
| 339 |
+
self.csv_path,
|
| 340 |
+
self.steps,
|
| 341 |
+
self.return_mask,
|
| 342 |
+
self.deterministic,
|
| 343 |
+
augment=False,
|
| 344 |
+
indices=range(val_end, n),
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def train_dataloader(self):
|
| 348 |
+
"""
|
| 349 |
+
Return the training DataLoader.
|
| 350 |
+
|
| 351 |
+
Returns
|
| 352 |
+
-------
|
| 353 |
+
loader : DataLoader
|
| 354 |
+
DataLoader over the training dataset with shuffling enabled.
|
| 355 |
+
"""
|
| 356 |
+
return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs)
|
| 357 |
+
|
| 358 |
+
def val_dataloader(self):
|
| 359 |
+
"""
|
| 360 |
+
Return the validation DataLoader.
|
| 361 |
+
|
| 362 |
+
Returns
|
| 363 |
+
-------
|
| 364 |
+
loader : DataLoader
|
| 365 |
+
DataLoader over the validation dataset without shuffling.
|
| 366 |
+
"""
|
| 367 |
+
return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs)
|
| 368 |
+
|
| 369 |
+
def test_dataloader(self):
|
| 370 |
+
"""
|
| 371 |
+
Return the test DataLoader.
|
| 372 |
+
|
| 373 |
+
Returns
|
| 374 |
+
-------
|
| 375 |
+
loader : DataLoader
|
| 376 |
+
DataLoader over the test dataset without shuffling.
|
| 377 |
+
"""
|
| 378 |
+
return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs)
|
convgru_ensemble/hub.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HuggingFace Hub integration for uploading and downloading ConvGRU-Ensemble models."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import shutil
|
| 5 |
+
import tempfile
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def push_to_hub(
|
| 12 |
+
checkpoint_path: str,
|
| 13 |
+
repo_id: str,
|
| 14 |
+
model_card_path: str | None = None,
|
| 15 |
+
private: bool = False,
|
| 16 |
+
) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Upload a trained model checkpoint to HuggingFace Hub.
|
| 19 |
+
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
checkpoint_path : str
|
| 23 |
+
Path to the ``.ckpt`` checkpoint file.
|
| 24 |
+
repo_id : str
|
| 25 |
+
HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
|
| 26 |
+
model_card_path : str or None, optional
|
| 27 |
+
Path to a model card markdown file. If provided, it is uploaded
|
| 28 |
+
as ``README.md``. Default is ``None``.
|
| 29 |
+
private : bool, optional
|
| 30 |
+
Whether to create a private repository. Default is ``False``.
|
| 31 |
+
|
| 32 |
+
Returns
|
| 33 |
+
-------
|
| 34 |
+
url : str
|
| 35 |
+
URL of the uploaded model on HuggingFace Hub.
|
| 36 |
+
"""
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
api = HfApi()
|
| 40 |
+
api.create_repo(repo_id=repo_id, exist_ok=True, private=private)
|
| 41 |
+
|
| 42 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 43 |
+
tmp_path = Path(tmp_dir)
|
| 44 |
+
|
| 45 |
+
# Copy checkpoint
|
| 46 |
+
shutil.copy2(checkpoint_path, tmp_path / "model.ckpt")
|
| 47 |
+
|
| 48 |
+
# Extract and save model config from checkpoint
|
| 49 |
+
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
| 50 |
+
if "hyper_parameters" in ckpt:
|
| 51 |
+
hparams = ckpt["hyper_parameters"]
|
| 52 |
+
# Convert non-serializable values to strings
|
| 53 |
+
config = {}
|
| 54 |
+
for k, v in hparams.items():
|
| 55 |
+
try:
|
| 56 |
+
json.dumps(v)
|
| 57 |
+
config[k] = v
|
| 58 |
+
except (TypeError, ValueError):
|
| 59 |
+
config[k] = str(v)
|
| 60 |
+
with open(tmp_path / "config.json", "w") as f:
|
| 61 |
+
json.dump(config, f, indent=2)
|
| 62 |
+
|
| 63 |
+
# Copy model card as README.md
|
| 64 |
+
if model_card_path is not None:
|
| 65 |
+
shutil.copy2(model_card_path, tmp_path / "README.md")
|
| 66 |
+
|
| 67 |
+
url = api.upload_folder(
|
| 68 |
+
folder_path=str(tmp_path),
|
| 69 |
+
repo_id=repo_id,
|
| 70 |
+
commit_message="Upload ConvGRU-Ensemble model",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
return url
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def from_pretrained(
|
| 77 |
+
repo_id: str,
|
| 78 |
+
filename: str = "model.ckpt",
|
| 79 |
+
device: str = "cpu",
|
| 80 |
+
) -> "RadarLightningModel": # noqa: F821
|
| 81 |
+
"""
|
| 82 |
+
Download and load a pretrained model from HuggingFace Hub.
|
| 83 |
+
|
| 84 |
+
Parameters
|
| 85 |
+
----------
|
| 86 |
+
repo_id : str
|
| 87 |
+
HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
|
| 88 |
+
filename : str, optional
|
| 89 |
+
Name of the checkpoint file in the repository. Default is
|
| 90 |
+
``'model.ckpt'``.
|
| 91 |
+
device : str, optional
|
| 92 |
+
Device to map the model weights to. Default is ``'cpu'``.
|
| 93 |
+
|
| 94 |
+
Returns
|
| 95 |
+
-------
|
| 96 |
+
model : RadarLightningModel
|
| 97 |
+
Model with loaded pretrained weights.
|
| 98 |
+
"""
|
| 99 |
+
from .lightning_model import RadarLightningModel
|
| 100 |
+
|
| 101 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 102 |
+
return RadarLightningModel.from_checkpoint(ckpt_path, device=device)
|
convgru_ensemble/lightning_model.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision
|
| 7 |
+
|
| 8 |
+
from .losses import build_loss
|
| 9 |
+
from .model import EncoderDecoder
|
| 10 |
+
from .utils import normalized_to_rainrate, rainrate_to_normalized
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def apply_radar_colormap(tensor: torch.Tensor) -> torch.Tensor:
|
| 14 |
+
"""
|
| 15 |
+
Convert grayscale radar values to RGB using the STEPS-BE colorscale.
|
| 16 |
+
|
| 17 |
+
Maps normalized values in [0, 1] (representing 0-60 dBZ) to a 14-color
|
| 18 |
+
discrete colormap. Pixels below 10 dBZ are rendered as white.
|
| 19 |
+
|
| 20 |
+
Parameters
|
| 21 |
+
----------
|
| 22 |
+
tensor : torch.Tensor
|
| 23 |
+
Grayscale tensor with values in [0, 1], of shape ``(N, 1, H, W)``.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
rgb : torch.Tensor
|
| 28 |
+
RGB tensor of shape ``(N, 3, H, W)`` with values in [0, 1].
|
| 29 |
+
"""
|
| 30 |
+
# STEPS-BE colors (RGB values normalized to 0-1)
|
| 31 |
+
colors = (
|
| 32 |
+
torch.tensor(
|
| 33 |
+
[
|
| 34 |
+
[0, 255, 255], # cyan
|
| 35 |
+
[0, 191, 255], # deepskyblue
|
| 36 |
+
[30, 144, 255], # dodgerblue
|
| 37 |
+
[0, 0, 255], # blue
|
| 38 |
+
[127, 255, 0], # chartreuse
|
| 39 |
+
[50, 205, 50], # limegreen
|
| 40 |
+
[0, 128, 0], # green
|
| 41 |
+
[0, 100, 0], # darkgreen
|
| 42 |
+
[255, 255, 0], # yellow
|
| 43 |
+
[255, 215, 0], # gold
|
| 44 |
+
[255, 165, 0], # orange
|
| 45 |
+
[255, 0, 0], # red
|
| 46 |
+
[255, 0, 255], # magenta
|
| 47 |
+
[139, 0, 139], # darkmagenta
|
| 48 |
+
],
|
| 49 |
+
dtype=torch.float32,
|
| 50 |
+
device=tensor.device,
|
| 51 |
+
)
|
| 52 |
+
/ 255.0
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# dBZ levels: 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60 (11 levels, 10 intervals)
|
| 56 |
+
# But we have 14 colors, so extend to cover 10-80 dBZ range with 5 dBZ steps
|
| 57 |
+
# Normalized thresholds (0-1 maps to 0-60 dBZ)
|
| 58 |
+
# We'll use 14 intervals from 10 dBZ onwards
|
| 59 |
+
num_colors = len(colors)
|
| 60 |
+
min_dbz_norm = 10 / 60 # ~0.167, below this is background
|
| 61 |
+
max_dbz_norm = 1.0
|
| 62 |
+
thresholds = torch.linspace(min_dbz_norm, max_dbz_norm, num_colors + 1, device=tensor.device)
|
| 63 |
+
|
| 64 |
+
# Output tensor (N, 3, H, W) - initialize with white for values below 10 dBZ
|
| 65 |
+
N, _, H, W = tensor.shape
|
| 66 |
+
output = torch.ones(N, 3, H, W, dtype=torch.float32, device=tensor.device)
|
| 67 |
+
|
| 68 |
+
# Apply colormap: find which bin each pixel falls into
|
| 69 |
+
for i in range(num_colors - 1):
|
| 70 |
+
mask = (tensor[:, 0] >= thresholds[i]) & (tensor[:, 0] < thresholds[i + 1])
|
| 71 |
+
for c in range(3):
|
| 72 |
+
output[:, c][mask] = colors[i, c]
|
| 73 |
+
|
| 74 |
+
# Last color handles all values >= second-to-last threshold (inclusive of max)
|
| 75 |
+
mask = tensor[:, 0] >= thresholds[num_colors - 1]
|
| 76 |
+
for c in range(3):
|
| 77 |
+
output[:, c][mask] = colors[-1, c]
|
| 78 |
+
|
| 79 |
+
return output
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class RadarLightningModel(pl.LightningModule):
|
| 83 |
+
"""
|
| 84 |
+
PyTorch Lightning module for radar precipitation nowcasting.
|
| 85 |
+
|
| 86 |
+
Wraps an :class:`EncoderDecoder` model and handles training, validation,
|
| 87 |
+
and test steps including loss computation, ensemble generation, and
|
| 88 |
+
TensorBoard image logging.
|
| 89 |
+
|
| 90 |
+
Parameters
|
| 91 |
+
----------
|
| 92 |
+
input_channels : int
|
| 93 |
+
Number of input channels per grid point.
|
| 94 |
+
num_blocks : int
|
| 95 |
+
Number of encoder/decoder blocks in the model.
|
| 96 |
+
ensemble_size : int, optional
|
| 97 |
+
Number of ensemble members to generate. Default is ``1``.
|
| 98 |
+
noisy_decoder : bool, optional
|
| 99 |
+
Whether to use random noise as decoder input. Default is ``False``.
|
| 100 |
+
forecast_steps : int or None, optional
|
| 101 |
+
Number of future timesteps to forecast. Default is ``None``.
|
| 102 |
+
loss_class : type, str, or None, optional
|
| 103 |
+
Loss function class or its string name (see ``PIXEL_LOSSES``).
|
| 104 |
+
Default is ``None`` (MSELoss).
|
| 105 |
+
loss_params : dict or None, optional
|
| 106 |
+
Keyword arguments for the loss constructor. Default is ``None``.
|
| 107 |
+
masked_loss : bool, optional
|
| 108 |
+
Whether to wrap the loss with :class:`MaskedLoss`. Default is
|
| 109 |
+
``False``.
|
| 110 |
+
optimizer_class : type or None, optional
|
| 111 |
+
Optimizer class. Default is ``None`` (Adam).
|
| 112 |
+
optimizer_params : dict or None, optional
|
| 113 |
+
Keyword arguments for the optimizer. Default is ``None``.
|
| 114 |
+
lr_scheduler_class : type or None, optional
|
| 115 |
+
Learning rate scheduler class. Default is ``None``.
|
| 116 |
+
lr_scheduler_params : dict or None, optional
|
| 117 |
+
Keyword arguments for the LR scheduler. Default is ``None``.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
input_channels: int,
|
| 123 |
+
num_blocks: int,
|
| 124 |
+
ensemble_size: int = 1,
|
| 125 |
+
noisy_decoder: bool = False,
|
| 126 |
+
forecast_steps: type | int | None = None,
|
| 127 |
+
loss_class: type | str | None = None,
|
| 128 |
+
loss_params: dict[str, Any] | None = None,
|
| 129 |
+
masked_loss: bool = False,
|
| 130 |
+
optimizer_class: type | None = None,
|
| 131 |
+
optimizer_params: dict[str, Any] | None = None,
|
| 132 |
+
lr_scheduler_class: type | None = None,
|
| 133 |
+
lr_scheduler_params: dict[str, Any] | None = None,
|
| 134 |
+
) -> None:
|
| 135 |
+
"""
|
| 136 |
+
Initialize RadarLightningModel.
|
| 137 |
+
|
| 138 |
+
Parameters
|
| 139 |
+
----------
|
| 140 |
+
input_channels : int
|
| 141 |
+
Number of input channels per grid point.
|
| 142 |
+
num_blocks : int
|
| 143 |
+
Number of encoder/decoder blocks.
|
| 144 |
+
ensemble_size : int, optional
|
| 145 |
+
Number of ensemble members. Default is ``1``.
|
| 146 |
+
noisy_decoder : bool, optional
|
| 147 |
+
Use random noise as decoder input. Default is ``False``.
|
| 148 |
+
forecast_steps : int or None, optional
|
| 149 |
+
Number of future timesteps to forecast. Default is ``None``.
|
| 150 |
+
loss_class : type, str, or None, optional
|
| 151 |
+
Loss function class or name. Default is ``None``.
|
| 152 |
+
loss_params : dict or None, optional
|
| 153 |
+
Loss constructor kwargs. Default is ``None``.
|
| 154 |
+
masked_loss : bool, optional
|
| 155 |
+
Wrap loss with masking. Default is ``False``.
|
| 156 |
+
optimizer_class : type or None, optional
|
| 157 |
+
Optimizer class. Default is ``None``.
|
| 158 |
+
optimizer_params : dict or None, optional
|
| 159 |
+
Optimizer kwargs. Default is ``None``.
|
| 160 |
+
lr_scheduler_class : type or None, optional
|
| 161 |
+
LR scheduler class. Default is ``None``.
|
| 162 |
+
lr_scheduler_params : dict or None, optional
|
| 163 |
+
LR scheduler kwargs. Default is ``None``.
|
| 164 |
+
"""
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.save_hyperparameters()
|
| 167 |
+
|
| 168 |
+
# Initialize model
|
| 169 |
+
self.model = EncoderDecoder(self.hparams.input_channels, self.hparams.num_blocks)
|
| 170 |
+
|
| 171 |
+
self.criterion = build_loss(
|
| 172 |
+
loss_class=self.hparams.loss_class,
|
| 173 |
+
loss_params=self.hparams.loss_params,
|
| 174 |
+
masked_loss=self.hparams.masked_loss,
|
| 175 |
+
)
|
| 176 |
+
self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000]
|
| 177 |
+
|
| 178 |
+
if self.hparams.ensemble_size > 1:
|
| 179 |
+
print(f"Using ensemble mode: {self.hparams.ensemble_size} independent ensemble members will be generated.")
|
| 180 |
+
|
| 181 |
+
def forward(self, x: torch.Tensor, forecast_steps: int, ensemble_size: int | None = None) -> torch.Tensor:
|
| 182 |
+
"""
|
| 183 |
+
Run the encoder-decoder forward pass.
|
| 184 |
+
|
| 185 |
+
Parameters
|
| 186 |
+
----------
|
| 187 |
+
x : torch.Tensor
|
| 188 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 189 |
+
forecast_steps : int
|
| 190 |
+
Number of future timesteps to forecast.
|
| 191 |
+
ensemble_size : int or None, optional
|
| 192 |
+
Number of ensemble members. If ``None``, uses the value from
|
| 193 |
+
``hparams``. Default is ``None``.
|
| 194 |
+
|
| 195 |
+
Returns
|
| 196 |
+
-------
|
| 197 |
+
preds : torch.Tensor
|
| 198 |
+
Predictions of shape ``(B, forecast_steps, C, H, W)`` or
|
| 199 |
+
``(B, forecast_steps, ensemble_size, H, W)`` for ensembles.
|
| 200 |
+
"""
|
| 201 |
+
ensemble_size = self.hparams.ensemble_size if ensemble_size is None else ensemble_size
|
| 202 |
+
return self.model(
|
| 203 |
+
x, steps=forecast_steps, noisy_decoder=self.hparams.noisy_decoder, ensemble_size=ensemble_size
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def shared_step(
|
| 207 |
+
self, batch: dict[str, torch.Tensor], split: str = "train", ensemble_size: int | None = None
|
| 208 |
+
) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
Shared forward step used during training, validation, and testing.
|
| 211 |
+
|
| 212 |
+
Splits the input into past and future, runs the model, computes the
|
| 213 |
+
loss, and logs metrics and optional images.
|
| 214 |
+
|
| 215 |
+
Parameters
|
| 216 |
+
----------
|
| 217 |
+
batch : dict of str to torch.Tensor
|
| 218 |
+
Batch dictionary with key ``'data'`` of shape
|
| 219 |
+
``(B, T_total, C, H, W)`` and optionally ``'mask'``.
|
| 220 |
+
split : str, optional
|
| 221 |
+
One of ``'train'``, ``'val'``, or ``'test'``. Controls logging
|
| 222 |
+
behavior. Default is ``'train'``.
|
| 223 |
+
ensemble_size : int or None, optional
|
| 224 |
+
Override for the number of ensemble members. Default is ``None``.
|
| 225 |
+
|
| 226 |
+
Returns
|
| 227 |
+
-------
|
| 228 |
+
loss : torch.Tensor
|
| 229 |
+
Scalar loss value.
|
| 230 |
+
"""
|
| 231 |
+
data = batch["data"]
|
| 232 |
+
past = data[:, : -self.hparams.forecast_steps]
|
| 233 |
+
future = data[:, -self.hparams.forecast_steps :]
|
| 234 |
+
|
| 235 |
+
preds = self(past, forecast_steps=self.hparams.forecast_steps, ensemble_size=ensemble_size).clamp(
|
| 236 |
+
min=-1, max=1
|
| 237 |
+
) # Ensure predictions are within [-1, 1]
|
| 238 |
+
|
| 239 |
+
if self.hparams.masked_loss:
|
| 240 |
+
mask = batch["mask"][:, -self.hparams.forecast_steps :]
|
| 241 |
+
loss = self.criterion(preds, future, mask)
|
| 242 |
+
else:
|
| 243 |
+
loss = self.criterion(preds, future)
|
| 244 |
+
|
| 245 |
+
# Handle tuple return from composite losses
|
| 246 |
+
if isinstance(loss, tuple):
|
| 247 |
+
loss, log_dict = loss
|
| 248 |
+
# log_dict already contains split-prefixed keys like 'val/pixel_loss'
|
| 249 |
+
self.log_dict(
|
| 250 |
+
log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True)
|
| 254 |
+
|
| 255 |
+
# Log ensemble diversity for ensemble training
|
| 256 |
+
if self.hparams.ensemble_size > 1:
|
| 257 |
+
ensemble_std = preds.std(dim=2).mean() # std across ensemble members
|
| 258 |
+
self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True)
|
| 259 |
+
|
| 260 |
+
if split == "train" and (
|
| 261 |
+
self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0
|
| 262 |
+
):
|
| 263 |
+
self.log_images(past, future, preds, split=split)
|
| 264 |
+
return loss
|
| 265 |
+
|
| 266 |
+
def log_images(self, past: torch.Tensor, future: torch.Tensor, preds: torch.Tensor, split: str = "val") -> None:
|
| 267 |
+
"""
|
| 268 |
+
Log radar image grids to TensorBoard.
|
| 269 |
+
|
| 270 |
+
Visualizes the first sample in the batch, showing past frames, ground
|
| 271 |
+
truth future, ensemble average, and individual ensemble members using
|
| 272 |
+
the STEPS-BE radar colormap.
|
| 273 |
+
|
| 274 |
+
Parameters
|
| 275 |
+
----------
|
| 276 |
+
past : torch.Tensor
|
| 277 |
+
Past input frames of shape ``(B, T_past, C, H, W)``.
|
| 278 |
+
future : torch.Tensor
|
| 279 |
+
Ground truth future frames of shape ``(B, T_future, C, H, W)``.
|
| 280 |
+
preds : torch.Tensor
|
| 281 |
+
Predicted frames of shape ``(B, T_future, C_or_E, H, W)``.
|
| 282 |
+
split : str, optional
|
| 283 |
+
Split name used as TensorBoard tag prefix. Default is ``'val'``.
|
| 284 |
+
"""
|
| 285 |
+
# Log first sample in the batch
|
| 286 |
+
sample_idx = 0
|
| 287 |
+
|
| 288 |
+
# Log past separately
|
| 289 |
+
past_sample = past[sample_idx]
|
| 290 |
+
if self.hparams.ensemble_size > 1:
|
| 291 |
+
past_sample = past_sample.mean(dim=1, keepdim=True)
|
| 292 |
+
past_norm = (past_sample + 1) / 2
|
| 293 |
+
past_rgb = apply_radar_colormap(past_norm)
|
| 294 |
+
past_grid = torchvision.utils.make_grid(past_rgb, nrow=past_sample.shape[0])
|
| 295 |
+
self.logger.experiment.add_image(f"{split}/past", past_grid, self.global_step)
|
| 296 |
+
|
| 297 |
+
# Create combined preds grid: future (ground truth) as first row, then avg + ensemble members
|
| 298 |
+
future_sample = future[sample_idx] # (T, C, H, W)
|
| 299 |
+
preds_sample = preds[sample_idx] # (T, E, H, W) or (T, C, H, W)
|
| 300 |
+
|
| 301 |
+
if self.hparams.ensemble_size > 1:
|
| 302 |
+
# Layout: rows = [future, avg, member0, member1, ...], cols = timesteps
|
| 303 |
+
preds_avg = preds_sample.mean(dim=1, keepdim=True) # (T, E, H, W) -> (T, 1, H, W)
|
| 304 |
+
num_members_to_log = min(3, preds_sample.shape[1])
|
| 305 |
+
|
| 306 |
+
# Collect all rows: future first, then average, then individual members
|
| 307 |
+
rows = [future_sample] # (T, 1, H, W)
|
| 308 |
+
rows.append(preds_avg) # (T, 1, H, W)
|
| 309 |
+
for i in range(num_members_to_log):
|
| 310 |
+
rows.append(preds_sample[:, i : i + 1, :, :]) # (T, 1, H, W)
|
| 311 |
+
|
| 312 |
+
# Stack all rows: (num_rows * T, 1, H, W)
|
| 313 |
+
all_frames = torch.cat(rows, dim=0) # ((2 + num_members) * T, 1, H, W)
|
| 314 |
+
all_frames_norm = (all_frames + 1) / 2
|
| 315 |
+
all_frames_rgb = apply_radar_colormap(all_frames_norm)
|
| 316 |
+
grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0])
|
| 317 |
+
self.logger.experiment.add_image(f"{split}/preds", grid, self.global_step)
|
| 318 |
+
else:
|
| 319 |
+
# For non-ensemble: show future and preds in two rows
|
| 320 |
+
rows = [future_sample, preds_sample] # Each is (T, C, H, W)
|
| 321 |
+
all_frames = torch.cat(rows, dim=0) # (2 * T, C, H, W)
|
| 322 |
+
all_frames_norm = (all_frames + 1) / 2
|
| 323 |
+
all_frames_rgb = apply_radar_colormap(all_frames_norm)
|
| 324 |
+
grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0])
|
| 325 |
+
self.logger.experiment.add_image(f"{split}/preds", grid, self.global_step)
|
| 326 |
+
|
| 327 |
+
def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
|
| 328 |
+
"""
|
| 329 |
+
Execute a single training step.
|
| 330 |
+
|
| 331 |
+
Parameters
|
| 332 |
+
----------
|
| 333 |
+
batch : dict of str to torch.Tensor
|
| 334 |
+
Training batch.
|
| 335 |
+
batch_idx : int
|
| 336 |
+
Index of the batch.
|
| 337 |
+
|
| 338 |
+
Returns
|
| 339 |
+
-------
|
| 340 |
+
loss : torch.Tensor
|
| 341 |
+
Training loss.
|
| 342 |
+
"""
|
| 343 |
+
loss = self.shared_step(batch, split="train")
|
| 344 |
+
return loss
|
| 345 |
+
|
| 346 |
+
def validation_step(
|
| 347 |
+
self,
|
| 348 |
+
batch: dict[str, torch.Tensor],
|
| 349 |
+
batch_idx: int,
|
| 350 |
+
) -> torch.Tensor:
|
| 351 |
+
"""
|
| 352 |
+
Execute a single validation step.
|
| 353 |
+
|
| 354 |
+
Uses 10 ensemble members for evaluation.
|
| 355 |
+
|
| 356 |
+
Parameters
|
| 357 |
+
----------
|
| 358 |
+
batch : dict of str to torch.Tensor
|
| 359 |
+
Validation batch.
|
| 360 |
+
batch_idx : int
|
| 361 |
+
Index of the batch.
|
| 362 |
+
|
| 363 |
+
Returns
|
| 364 |
+
-------
|
| 365 |
+
loss : torch.Tensor
|
| 366 |
+
Validation loss.
|
| 367 |
+
"""
|
| 368 |
+
loss = self.shared_step(batch, split="val", ensemble_size=10)
|
| 369 |
+
return loss
|
| 370 |
+
|
| 371 |
+
def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
|
| 372 |
+
"""
|
| 373 |
+
Execute a single test step.
|
| 374 |
+
|
| 375 |
+
Uses 10 ensemble members for evaluation.
|
| 376 |
+
|
| 377 |
+
Parameters
|
| 378 |
+
----------
|
| 379 |
+
batch : dict of str to torch.Tensor
|
| 380 |
+
Test batch.
|
| 381 |
+
batch_idx : int
|
| 382 |
+
Index of the batch.
|
| 383 |
+
|
| 384 |
+
Returns
|
| 385 |
+
-------
|
| 386 |
+
loss : torch.Tensor
|
| 387 |
+
Test loss.
|
| 388 |
+
"""
|
| 389 |
+
loss = self.shared_step(batch, split="test", ensemble_size=10)
|
| 390 |
+
return loss
|
| 391 |
+
|
| 392 |
+
def configure_optimizers(self) -> dict[str, Any]:
|
| 393 |
+
"""
|
| 394 |
+
Configure the optimizer and optional learning rate scheduler.
|
| 395 |
+
|
| 396 |
+
Falls back to Adam with default parameters if no optimizer is
|
| 397 |
+
specified. If a scheduler is provided, it monitors ``val_loss``.
|
| 398 |
+
|
| 399 |
+
Returns
|
| 400 |
+
-------
|
| 401 |
+
config : dict
|
| 402 |
+
Dictionary with ``'optimizer'`` and optionally ``'lr_scheduler'``
|
| 403 |
+
keys, as expected by PyTorch Lightning.
|
| 404 |
+
"""
|
| 405 |
+
if self.hparams.optimizer_class is not None:
|
| 406 |
+
optimizer = (
|
| 407 |
+
self.hparams.optimizer_class(self.parameters(), **self.hparams.optimizer_params)
|
| 408 |
+
if self.hparams.optimizer_params is not None
|
| 409 |
+
else self.hparams.optimizer_class(self.parameters())
|
| 410 |
+
)
|
| 411 |
+
print(
|
| 412 |
+
f"Using optimizer: {self.hparams.optimizer_class.__name__} with params {self.hparams.optimizer_params}"
|
| 413 |
+
)
|
| 414 |
+
else:
|
| 415 |
+
optimizer = torch.optim.Adam(self.parameters())
|
| 416 |
+
print("Using default Adam optimizer with default parameters.")
|
| 417 |
+
|
| 418 |
+
if self.hparams.lr_scheduler_class is not None:
|
| 419 |
+
lr_scheduler = (
|
| 420 |
+
self.hparams.lr_scheduler_class(optimizer, **self.hparams.lr_scheduler_params)
|
| 421 |
+
if self.hparams.lr_scheduler_params is not None
|
| 422 |
+
else self.hparams.lr_scheduler_class(optimizer)
|
| 423 |
+
)
|
| 424 |
+
print(
|
| 425 |
+
f"Using LR scheduler: {self.hparams.lr_scheduler_class.__name__} with params {self.hparams.lr_scheduler_params}"
|
| 426 |
+
)
|
| 427 |
+
return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}}
|
| 428 |
+
else:
|
| 429 |
+
return {"optimizer": optimizer}
|
| 430 |
+
|
| 431 |
+
@classmethod
|
| 432 |
+
def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "RadarLightningModel":
|
| 433 |
+
"""
|
| 434 |
+
Load a model from a checkpoint file.
|
| 435 |
+
|
| 436 |
+
Parameters
|
| 437 |
+
----------
|
| 438 |
+
checkpoint_path : str
|
| 439 |
+
Path to the ``.ckpt`` checkpoint file.
|
| 440 |
+
device : str, optional
|
| 441 |
+
Device to map the checkpoint weights to. Default is ``'cpu'``.
|
| 442 |
+
|
| 443 |
+
Returns
|
| 444 |
+
-------
|
| 445 |
+
model : RadarLightningModel
|
| 446 |
+
Model with loaded weights.
|
| 447 |
+
"""
|
| 448 |
+
return cls.load_from_checkpoint(
|
| 449 |
+
checkpoint_path,
|
| 450 |
+
map_location=torch.device(device),
|
| 451 |
+
strict=True,
|
| 452 |
+
weights_only=False,
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
@classmethod
|
| 456 |
+
def from_pretrained(cls, repo_id: str, filename: str = "model.ckpt", device: str = "cpu") -> "RadarLightningModel":
|
| 457 |
+
"""
|
| 458 |
+
Load a pretrained model from HuggingFace Hub.
|
| 459 |
+
|
| 460 |
+
Parameters
|
| 461 |
+
----------
|
| 462 |
+
repo_id : str
|
| 463 |
+
HuggingFace Hub repository ID (e.g., ``'it4lia/irene'``).
|
| 464 |
+
filename : str, optional
|
| 465 |
+
Name of the checkpoint file in the repository. Default is
|
| 466 |
+
``'model.ckpt'``.
|
| 467 |
+
device : str, optional
|
| 468 |
+
Device to map the model weights to. Default is ``'cpu'``.
|
| 469 |
+
|
| 470 |
+
Returns
|
| 471 |
+
-------
|
| 472 |
+
model : RadarLightningModel
|
| 473 |
+
Model with loaded pretrained weights.
|
| 474 |
+
"""
|
| 475 |
+
from .hub import from_pretrained
|
| 476 |
+
|
| 477 |
+
return from_pretrained(repo_id, filename, device)
|
| 478 |
+
|
| 479 |
+
def predict(self, past: torch.Tensor, forecast_steps: int = 1, ensemble_size: int | None = 1) -> torch.Tensor:
|
| 480 |
+
"""
|
| 481 |
+
Generate precipitation forecasts from past radar observations.
|
| 482 |
+
|
| 483 |
+
Handles padding, NaN removal, unit conversion, and reshaping
|
| 484 |
+
automatically. Input should be raw rain rate values.
|
| 485 |
+
|
| 486 |
+
Parameters
|
| 487 |
+
----------
|
| 488 |
+
past : torch.Tensor
|
| 489 |
+
Past radar frames as rain rate in mm/h, of shape ``(T, H, W)``.
|
| 490 |
+
forecast_steps : int, optional
|
| 491 |
+
Number of future timesteps to forecast. Default is ``1``.
|
| 492 |
+
ensemble_size : int, optional
|
| 493 |
+
Number of ensemble members to generate. If ``None``, uses the
|
| 494 |
+
value from ``hparams``. Default is ``1``.
|
| 495 |
+
|
| 496 |
+
Returns
|
| 497 |
+
-------
|
| 498 |
+
preds : np.ndarray
|
| 499 |
+
Forecasted rain rate in mm/h, of shape
|
| 500 |
+
``(ensemble_size, forecast_steps, H, W)``.
|
| 501 |
+
|
| 502 |
+
Raises
|
| 503 |
+
------
|
| 504 |
+
ValueError
|
| 505 |
+
If ``past`` does not have exactly 3 dimensions.
|
| 506 |
+
"""
|
| 507 |
+
if len(past.shape) != 3:
|
| 508 |
+
raise ValueError("Input must be of shape (T, H, W)")
|
| 509 |
+
|
| 510 |
+
T, H, W = past.shape
|
| 511 |
+
ensemble_size = self.hparams.ensemble_size if ensemble_size is None else ensemble_size
|
| 512 |
+
|
| 513 |
+
# Each block the model decrease the resolution by a factor of 2
|
| 514 |
+
# The input must be divisible by 2^(num_blocks-1)
|
| 515 |
+
divisor = 2 ** (self.hparams.num_blocks)
|
| 516 |
+
padH = (divisor - (H % divisor)) % divisor
|
| 517 |
+
padW = (divisor - (W % divisor)) % divisor
|
| 518 |
+
padded_past = past
|
| 519 |
+
if padH != 0 or padW != 0:
|
| 520 |
+
padded_past = np.pad(past, ((0, 0), (0, padH), (0, padW)), mode="constant", constant_values=0)
|
| 521 |
+
|
| 522 |
+
# Remove Nan
|
| 523 |
+
past_clean = np.nan_to_num(padded_past)
|
| 524 |
+
|
| 525 |
+
# Reshape the input to (B, T, C, H, W)
|
| 526 |
+
past_clean = past_clean[np.newaxis, :, np.newaxis, ...]
|
| 527 |
+
|
| 528 |
+
# Rainrate to normalized reflectivity
|
| 529 |
+
norm_past = rainrate_to_normalized(past_clean)
|
| 530 |
+
|
| 531 |
+
# Numpy to torch tensor
|
| 532 |
+
x = torch.from_numpy(norm_past)
|
| 533 |
+
|
| 534 |
+
# Move to device
|
| 535 |
+
x = x.to(self.device)
|
| 536 |
+
|
| 537 |
+
# Forward pass
|
| 538 |
+
self.eval()
|
| 539 |
+
with torch.no_grad():
|
| 540 |
+
preds = self.model(x, forecast_steps, self.hparams.noisy_decoder, ensemble_size)
|
| 541 |
+
|
| 542 |
+
# Move to CPU
|
| 543 |
+
preds = preds.cpu()
|
| 544 |
+
|
| 545 |
+
# Tensor to numpy array
|
| 546 |
+
preds = preds.numpy()
|
| 547 |
+
|
| 548 |
+
# Rescale back to rain rate
|
| 549 |
+
preds = normalized_to_rainrate(preds)
|
| 550 |
+
|
| 551 |
+
# Remove the batch (T, E, H, W)
|
| 552 |
+
preds = preds.squeeze(0)
|
| 553 |
+
|
| 554 |
+
# Swap the Time and Ensemble dimensions (E, T, H, W)
|
| 555 |
+
preds = np.swapaxes(preds, 0, 1)
|
| 556 |
+
|
| 557 |
+
# Remove the padding
|
| 558 |
+
preds = preds[..., :H, :W]
|
| 559 |
+
|
| 560 |
+
return preds
|
convgru_ensemble/losses.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LossWithReduction(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Base class for losses with reduction options.
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
----------
|
| 13 |
+
reduction : str, optional
|
| 14 |
+
Reduction mode to apply to the loss. Must be one of ``'mean'``,
|
| 15 |
+
``'sum'``, or ``'none'``. Default is ``'mean'``.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, reduction: str = "mean"):
|
| 19 |
+
"""
|
| 20 |
+
Initialize LossWithReduction.
|
| 21 |
+
|
| 22 |
+
Parameters
|
| 23 |
+
----------
|
| 24 |
+
reduction : str, optional
|
| 25 |
+
Reduction mode to apply to the loss. Must be one of ``'mean'``,
|
| 26 |
+
``'sum'``, or ``'none'``. Default is ``'mean'``.
|
| 27 |
+
"""
|
| 28 |
+
super().__init__()
|
| 29 |
+
assert reduction in ["mean", "sum", "none"], "reduction must be 'mean', 'sum', or 'none'"
|
| 30 |
+
self.reduction = reduction
|
| 31 |
+
|
| 32 |
+
def apply_reduction(self, loss: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
"""
|
| 34 |
+
Apply the specified reduction to the loss tensor.
|
| 35 |
+
|
| 36 |
+
Parameters
|
| 37 |
+
----------
|
| 38 |
+
loss : torch.Tensor
|
| 39 |
+
Loss tensor to reduce.
|
| 40 |
+
|
| 41 |
+
Returns
|
| 42 |
+
-------
|
| 43 |
+
reduced_loss : torch.Tensor
|
| 44 |
+
Reduced loss tensor.
|
| 45 |
+
"""
|
| 46 |
+
if self.reduction == "mean":
|
| 47 |
+
return loss.mean()
|
| 48 |
+
elif self.reduction == "sum":
|
| 49 |
+
return loss.sum()
|
| 50 |
+
else: # 'none'
|
| 51 |
+
return loss
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MaskedLoss(LossWithReduction):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper to apply a mask to a given loss function.
|
| 57 |
+
|
| 58 |
+
Masks out invalid pixels before computing the loss, ensuring that only
|
| 59 |
+
valid regions contribute to the final value.
|
| 60 |
+
|
| 61 |
+
Parameters
|
| 62 |
+
----------
|
| 63 |
+
elementwise_loss : nn.Module
|
| 64 |
+
Base loss function to be masked. Must accept ``(preds, target)`` and
|
| 65 |
+
return element-wise (unreduced) loss. Should be instantiated with
|
| 66 |
+
``reduction='none'``.
|
| 67 |
+
reduction : str, optional
|
| 68 |
+
Reduction mode applied after masking. Must be one of ``'mean'``,
|
| 69 |
+
``'sum'``, or ``'none'``. Default is ``'mean'``.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, elementwise_loss: nn.Module, reduction: str = "mean"):
|
| 73 |
+
"""
|
| 74 |
+
Initialize MaskedLoss.
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
elementwise_loss : nn.Module
|
| 79 |
+
Base loss function to be masked. Must accept ``(preds, target)``
|
| 80 |
+
and return element-wise (unreduced) loss. Should be instantiated
|
| 81 |
+
with ``reduction='none'``.
|
| 82 |
+
reduction : str, optional
|
| 83 |
+
Reduction mode applied after masking. Must be one of ``'mean'``,
|
| 84 |
+
``'sum'``, or ``'none'``. Default is ``'mean'``.
|
| 85 |
+
"""
|
| 86 |
+
super().__init__(reduction=reduction)
|
| 87 |
+
self.elementwise_loss = elementwise_loss
|
| 88 |
+
|
| 89 |
+
def forward(self, preds: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
"""
|
| 91 |
+
Compute masked loss.
|
| 92 |
+
|
| 93 |
+
Parameters
|
| 94 |
+
----------
|
| 95 |
+
preds : torch.Tensor
|
| 96 |
+
Predictions of shape (B, T, C, *D).
|
| 97 |
+
target : torch.Tensor
|
| 98 |
+
Target of shape (B, T, C, *D).
|
| 99 |
+
mask : torch.Tensor
|
| 100 |
+
Mask of shape (B, T, C, *D)
|
| 101 |
+
or (B, T, 1, *D)
|
| 102 |
+
or (B, 1, 1, *D), with 1 for valid and 0 for invalid pixels.
|
| 103 |
+
Broadcasted to match preds/target shape if needed.
|
| 104 |
+
|
| 105 |
+
Returns
|
| 106 |
+
-------
|
| 107 |
+
loss : torch.Tensor
|
| 108 |
+
Scalar loss value.
|
| 109 |
+
"""
|
| 110 |
+
# assert preds.shape == target.shape, f"preds and target must have the same shape, got {preds.shape} and {target.shape}"
|
| 111 |
+
# assert mask.shape == preds.shape, f"mask must have the same shape as preds, got {mask.shape} and {preds.shape}"
|
| 112 |
+
|
| 113 |
+
# Compute element-wise loss
|
| 114 |
+
elementwise_loss = self.elementwise_loss(preds, target) # shape (B, T, C, *D)
|
| 115 |
+
|
| 116 |
+
# Apply mask (broadcast if needed)
|
| 117 |
+
masked_loss = elementwise_loss * mask # shape (B, T, C, *D)
|
| 118 |
+
|
| 119 |
+
# Average over valid pixels
|
| 120 |
+
# Account for broadcasting: mask.sum() × broadcast_factor
|
| 121 |
+
broadcast_factor = elementwise_loss.numel() // mask.numel()
|
| 122 |
+
valid_pixels = mask.sum() * broadcast_factor
|
| 123 |
+
if valid_pixels > 0:
|
| 124 |
+
if self.reduction == "mean":
|
| 125 |
+
return masked_loss.sum() / valid_pixels
|
| 126 |
+
elif self.reduction == "sum":
|
| 127 |
+
return masked_loss.sum()
|
| 128 |
+
else: # 'none'
|
| 129 |
+
return masked_loss
|
| 130 |
+
else:
|
| 131 |
+
return torch.tensor(0.0, device=preds.device)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class CRPS(LossWithReduction):
|
| 135 |
+
r"""
|
| 136 |
+
Continuous Ranked Probability Score (CRPS) loss with optional temporal
|
| 137 |
+
consistency regularization.
|
| 138 |
+
|
| 139 |
+
CRPS = E[|X - y|] - 0.5 * E[|X - X'|], where X, X' are independent
|
| 140 |
+
samples from the forecast distribution and y is the observation.
|
| 141 |
+
|
| 142 |
+
Parameters
|
| 143 |
+
----------
|
| 144 |
+
temporal_lambda : float, optional
|
| 145 |
+
Weight for the temporal consistency penalty. If ``0.0`` (default),
|
| 146 |
+
the penalty is disabled. When enabled, adds a penalty for large
|
| 147 |
+
differences between consecutive timesteps within each ensemble
|
| 148 |
+
member, preventing pulsing artifacts.
|
| 149 |
+
reduction : str, optional
|
| 150 |
+
Reduction mode. Must be one of ``'mean'``, ``'sum'``, or ``'none'``.
|
| 151 |
+
``'mean'`` averages over batch and all non-ensemble dimensions.
|
| 152 |
+
Default is ``'mean'``.
|
| 153 |
+
|
| 154 |
+
Expected shapes
|
| 155 |
+
---------------
|
| 156 |
+
preds : (B, T, M, \*D)
|
| 157 |
+
Ensemble predictions with time T on dim=1, ensemble size M on dim=2.
|
| 158 |
+
target : (B, T, C, \*D)
|
| 159 |
+
Deterministic target / analysis with channel C on dim=2 (should be 1).
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(self, temporal_lambda: float = 0.0, reduction: str = "mean"):
|
| 163 |
+
"""
|
| 164 |
+
Initialize CRPS loss.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
temporal_lambda : float, optional
|
| 169 |
+
Weight for the temporal consistency penalty. Default is ``0.0``
|
| 170 |
+
(disabled).
|
| 171 |
+
reduction : str, optional
|
| 172 |
+
Reduction mode. Must be one of ``'mean'``, ``'sum'``, or
|
| 173 |
+
``'none'``. Default is ``'mean'``.
|
| 174 |
+
"""
|
| 175 |
+
super().__init__(reduction=reduction)
|
| 176 |
+
self.temporal_lambda = temporal_lambda
|
| 177 |
+
|
| 178 |
+
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 179 |
+
"""
|
| 180 |
+
Compute CRPS loss.
|
| 181 |
+
|
| 182 |
+
CRPS = E[|X - y|] - 0.5 * E[|X - X'|], where X, X' are independent
|
| 183 |
+
samples from the forecast distribution and y is the observation.
|
| 184 |
+
|
| 185 |
+
Parameters
|
| 186 |
+
----------
|
| 187 |
+
preds : torch.Tensor
|
| 188 |
+
Ensemble forecasts of shape ``(B, T, M, *D)``, where B is batch
|
| 189 |
+
size, T is the number of timesteps, M is ensemble size, and
|
| 190 |
+
``*D`` are spatial dimensions.
|
| 191 |
+
target : torch.Tensor
|
| 192 |
+
Verifying observation / analysis of shape ``(B, T, C, *D)``,
|
| 193 |
+
where C should be 1. Broadcasts against ``preds`` on dim=2.
|
| 194 |
+
|
| 195 |
+
Returns
|
| 196 |
+
-------
|
| 197 |
+
loss : torch.Tensor
|
| 198 |
+
Scalar if ``reduction='mean'`` or ``'sum'``, otherwise a tensor
|
| 199 |
+
of shape ``(B, T, 1, *D)`` (or ``(B, 1, 1, *D)`` when
|
| 200 |
+
``temporal_lambda > 0``).
|
| 201 |
+
"""
|
| 202 |
+
# preds: (B, T, M, *D)
|
| 203 |
+
# target: (B, T, C, *D) where C should be 1
|
| 204 |
+
# target broadcasts against preds: (B, T, 1, *D) vs (B, T, M, *D)
|
| 205 |
+
|
| 206 |
+
# First term: E[|X - y|]
|
| 207 |
+
# Compute absolute difference between each ensemble member and target
|
| 208 |
+
diff_to_target = torch.abs(preds - target) # (B, T, M, *D)
|
| 209 |
+
term1 = diff_to_target.mean(dim=2) # Average over ensemble: (B, T, *D)
|
| 210 |
+
|
| 211 |
+
# Second term: 0.5 * E[|X - X'|]
|
| 212 |
+
# Compute pairwise differences between ensemble members
|
| 213 |
+
# preds: (B, T, M, *D)
|
| 214 |
+
# Expand for pairwise differences
|
| 215 |
+
# preds_i: (B, T, M, 1, *D)
|
| 216 |
+
# preds_j: (B, T, 1, M, *D)
|
| 217 |
+
preds_i = preds.unsqueeze(3)
|
| 218 |
+
preds_j = preds.unsqueeze(2)
|
| 219 |
+
|
| 220 |
+
# Pairwise absolute differences: (B, T, M, M, *D)
|
| 221 |
+
pairwise_diff = torch.abs(preds_i - preds_j)
|
| 222 |
+
|
| 223 |
+
# Average over both ensemble dimensions
|
| 224 |
+
# Sum over M*M pairs and divide by M*M
|
| 225 |
+
term2 = 0.5 * pairwise_diff.mean(dim=(2, 3)) # (B, T, *D)
|
| 226 |
+
|
| 227 |
+
# CRPS
|
| 228 |
+
crps = term1 - term2 # (B, T, *D)
|
| 229 |
+
|
| 230 |
+
# Temporal consistency penalty
|
| 231 |
+
if self.temporal_lambda > 0:
|
| 232 |
+
# average over time dimension
|
| 233 |
+
crps = crps.mean(dim=1) # (B, *D)
|
| 234 |
+
|
| 235 |
+
# preds: (B, T, M, *D)
|
| 236 |
+
# Compute differences between consecutive timesteps per ensemble member
|
| 237 |
+
temporal_diff = preds[:, 1:, :, ...] - preds[:, :-1, :, ...] # (B, T-1, M, *D)
|
| 238 |
+
temporal_penalty = torch.abs(temporal_diff).mean(
|
| 239 |
+
dim=(1, 2)
|
| 240 |
+
) # average over time and ensemble dimensions (B, *D)
|
| 241 |
+
# Add penalty to CRPS (before reduction, averaged over time)
|
| 242 |
+
crps = crps + self.temporal_lambda * temporal_penalty
|
| 243 |
+
crps = crps[:, None, None, ...] # add time and channel dims back for consistency (B, 1, 1, *D)
|
| 244 |
+
else:
|
| 245 |
+
# Keep singleton channel dim for MaskedLoss compatibility: (B, T, 1, *D)
|
| 246 |
+
crps = crps.unsqueeze(2)
|
| 247 |
+
|
| 248 |
+
return self.apply_reduction(crps)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class afCRPS(LossWithReduction):
|
| 252 |
+
r"""
|
| 253 |
+
Almost fair CRPS (afCRPS) loss as in eq. (4) of Lang et al. (2024).
|
| 254 |
+
|
| 255 |
+
Interpolates between the standard (energy-score style) CRPS and the
|
| 256 |
+
fair CRPS via the fairness parameter ``alpha``.
|
| 257 |
+
|
| 258 |
+
Parameters
|
| 259 |
+
----------
|
| 260 |
+
alpha : float, optional
|
| 261 |
+
Fairness parameter in ``(0, 1]``. ``alpha=1`` recovers the fair
|
| 262 |
+
CRPS. Lang et al. (2024) recommend ``alpha=0.95``. Default is
|
| 263 |
+
``0.95``.
|
| 264 |
+
temporal_lambda : float, optional
|
| 265 |
+
Weight for the temporal consistency penalty. If ``0.0`` (default),
|
| 266 |
+
the penalty is disabled. When enabled, adds a penalty for large
|
| 267 |
+
differences between consecutive timesteps within each ensemble
|
| 268 |
+
member.
|
| 269 |
+
reduction : str, optional
|
| 270 |
+
Reduction mode. Must be one of ``'mean'``, ``'sum'``, or ``'none'``.
|
| 271 |
+
``'mean'`` averages over batch and all non-ensemble dimensions.
|
| 272 |
+
Default is ``'mean'``.
|
| 273 |
+
|
| 274 |
+
Expected shapes
|
| 275 |
+
---------------
|
| 276 |
+
preds : (B, T, M, \*D)
|
| 277 |
+
Ensemble predictions with time T on dim=1, ensemble size M on dim=2.
|
| 278 |
+
target : (B, T, C, \*D)
|
| 279 |
+
Deterministic target / analysis with channel C on dim=2 (should be 1).
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
def __init__(self, alpha: float = 0.95, temporal_lambda: float = 0.0, reduction: str = "mean"):
|
| 283 |
+
"""
|
| 284 |
+
Initialize afCRPS loss.
|
| 285 |
+
|
| 286 |
+
Parameters
|
| 287 |
+
----------
|
| 288 |
+
alpha : float, optional
|
| 289 |
+
Fairness parameter in ``(0, 1]``. ``alpha=1`` recovers the fair
|
| 290 |
+
CRPS. Default is ``0.95``.
|
| 291 |
+
temporal_lambda : float, optional
|
| 292 |
+
Weight for the temporal consistency penalty. Default is ``0.0``
|
| 293 |
+
(disabled).
|
| 294 |
+
reduction : str, optional
|
| 295 |
+
Reduction mode. Must be one of ``'mean'``, ``'sum'``, or
|
| 296 |
+
``'none'``. Default is ``'mean'``.
|
| 297 |
+
|
| 298 |
+
Raises
|
| 299 |
+
------
|
| 300 |
+
ValueError
|
| 301 |
+
If ``alpha`` is not in ``(0, 1]``.
|
| 302 |
+
"""
|
| 303 |
+
super().__init__(reduction=reduction)
|
| 304 |
+
if not (0.0 < alpha <= 1.0):
|
| 305 |
+
raise ValueError("alpha must be in (0, 1].")
|
| 306 |
+
self.alpha = alpha
|
| 307 |
+
self.temporal_lambda = temporal_lambda
|
| 308 |
+
|
| 309 |
+
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 310 |
+
"""
|
| 311 |
+
Compute afCRPS over an ensemble.
|
| 312 |
+
|
| 313 |
+
Parameters
|
| 314 |
+
----------
|
| 315 |
+
preds : torch.Tensor
|
| 316 |
+
(B, T, M, *D) ensemble forecasts.
|
| 317 |
+
target : torch.Tensor
|
| 318 |
+
(B, T, C, *D) verifying observation / analysis (C=1).
|
| 319 |
+
|
| 320 |
+
Returns
|
| 321 |
+
-------
|
| 322 |
+
loss : torch.Tensor
|
| 323 |
+
Scalar if reduction=='mean', else per-sample tensor.
|
| 324 |
+
"""
|
| 325 |
+
if preds.dim() < 3:
|
| 326 |
+
raise ValueError("preds must have at least 3 dimensions.")
|
| 327 |
+
if target.shape[0] != preds.shape[0]:
|
| 328 |
+
raise ValueError("batch dimension of preds and target must match.")
|
| 329 |
+
if target.shape[1] != preds.shape[1]:
|
| 330 |
+
raise ValueError("time dimension of preds and target must match.")
|
| 331 |
+
|
| 332 |
+
# preds: (B, T, M, *D), target: (B, T, 1, *D)
|
| 333 |
+
M = preds.shape[2]
|
| 334 |
+
if M < 2:
|
| 335 |
+
raise ValueError("Ensemble size M must be >= 2 for afCRPS.")
|
| 336 |
+
|
| 337 |
+
eps = (1.0 - self.alpha) / float(M)
|
| 338 |
+
|
| 339 |
+
# |x_j - y| : (B, T, M, *D) — target broadcasts via C=1
|
| 340 |
+
abs_x_minus_y = (preds - target).abs()
|
| 341 |
+
|
| 342 |
+
# Pairwise terms over ensemble dim (dim=2), j != k
|
| 343 |
+
# x_j: (B, T, M, 1, *D)
|
| 344 |
+
# x_k: (B, T, 1, M, *D)
|
| 345 |
+
x_j = preds.unsqueeze(3)
|
| 346 |
+
x_k = preds.unsqueeze(2)
|
| 347 |
+
|
| 348 |
+
# |x_j - y|, |x_k - y| broadcast to (B, T, M, M, *D)
|
| 349 |
+
abs_xj_minus_y = abs_x_minus_y.unsqueeze(3)
|
| 350 |
+
abs_xk_minus_y = abs_x_minus_y.unsqueeze(2)
|
| 351 |
+
|
| 352 |
+
# |x_j - x_k|: (B, T, M, M, *D)
|
| 353 |
+
abs_xj_minus_xk = (x_j - x_k).abs()
|
| 354 |
+
|
| 355 |
+
# Per (j,k) term: |x_j - y| + |x_k - y| - (1 - eps)|x_j - x_k|
|
| 356 |
+
term = abs_xj_minus_y + abs_xk_minus_y - (1.0 - eps) * abs_xj_minus_xk
|
| 357 |
+
|
| 358 |
+
# Exclude j == k (diagonal) since eq. (4) sums over k != j
|
| 359 |
+
idx = torch.arange(M, device=preds.device)
|
| 360 |
+
mask = idx[:, None] != idx[None, :] # (M, M)
|
| 361 |
+
term = term * mask.view(1, 1, M, M, *([1] * (term.dim() - 4)))
|
| 362 |
+
|
| 363 |
+
# Sum over j and k dims → (B, T, *D)
|
| 364 |
+
summed = term.sum(dim=(2, 3))
|
| 365 |
+
|
| 366 |
+
# Normalization factor 1 / [2 M (M - 1)]
|
| 367 |
+
afcrps = summed / (2.0 * M * (M - 1)) # (B, T, *D)
|
| 368 |
+
|
| 369 |
+
# Temporal consistency penalty
|
| 370 |
+
if self.temporal_lambda > 0:
|
| 371 |
+
# average over time dimension
|
| 372 |
+
afcrps = afcrps.mean(dim=1) # (B, *D)
|
| 373 |
+
|
| 374 |
+
# preds: (B, T, M, *D)
|
| 375 |
+
# Compute differences between consecutive timesteps per ensemble member
|
| 376 |
+
temporal_diff = preds[:, 1:, :, ...] - preds[:, :-1, :, ...] # (B, T-1, M, *D)
|
| 377 |
+
temporal_penalty = torch.abs(temporal_diff).mean(
|
| 378 |
+
dim=(1, 2)
|
| 379 |
+
) # average over time and ensemble dimensions (B, *D)
|
| 380 |
+
# Add penalty to afCRPS (before reduction, averaged over time)
|
| 381 |
+
afcrps = afcrps + self.temporal_lambda * temporal_penalty
|
| 382 |
+
afcrps = afcrps[:, None, None, ...] # add time and channel dims back for consistency (B, 1, 1, *D)
|
| 383 |
+
else:
|
| 384 |
+
# Keep singleton channel dim for MaskedLoss compatibility: (B, T, 1, *D)
|
| 385 |
+
afcrps = afcrps.unsqueeze(2)
|
| 386 |
+
|
| 387 |
+
return self.apply_reduction(afcrps)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
PIXEL_LOSSES = {"mse": nn.MSELoss, "mae": nn.L1Loss, "crps": CRPS, "afcrps": afCRPS}
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def build_loss(
|
| 394 |
+
loss_class: type | str,
|
| 395 |
+
loss_params: dict[str, Any] | None = None,
|
| 396 |
+
masked_loss: bool = False,
|
| 397 |
+
) -> nn.Module:
|
| 398 |
+
"""
|
| 399 |
+
Build a loss function, optionally wrapped with masking.
|
| 400 |
+
|
| 401 |
+
Resolves a loss class by name (from ``PIXEL_LOSSES``) or accepts a class
|
| 402 |
+
directly, instantiates it with the given parameters, and optionally wraps
|
| 403 |
+
it in a :class:`MaskedLoss`.
|
| 404 |
+
|
| 405 |
+
Parameters
|
| 406 |
+
----------
|
| 407 |
+
loss_class : type or str
|
| 408 |
+
Loss class or its string name. Accepted string names are the keys of
|
| 409 |
+
``PIXEL_LOSSES``: ``'mse'``, ``'mae'``, ``'crps'``, ``'afcrps'``.
|
| 410 |
+
If ``None``, defaults to ``nn.MSELoss``.
|
| 411 |
+
loss_params : dict of str to any, optional
|
| 412 |
+
Keyword arguments forwarded to the loss class constructor. If
|
| 413 |
+
``masked_loss`` is ``True``, the ``'reduction'`` key (if present)
|
| 414 |
+
is extracted and passed to :class:`MaskedLoss` instead. Default is
|
| 415 |
+
``None``.
|
| 416 |
+
masked_loss : bool, optional
|
| 417 |
+
If ``True``, the loss is wrapped in :class:`MaskedLoss` and the
|
| 418 |
+
inner loss is instantiated with ``reduction='none'``. Default is
|
| 419 |
+
``False``.
|
| 420 |
+
|
| 421 |
+
Returns
|
| 422 |
+
-------
|
| 423 |
+
criterion : nn.Module
|
| 424 |
+
Instantiated loss module, optionally wrapped in :class:`MaskedLoss`.
|
| 425 |
+
|
| 426 |
+
Raises
|
| 427 |
+
------
|
| 428 |
+
ValueError
|
| 429 |
+
If ``loss_class`` is a string not found in ``PIXEL_LOSSES``.
|
| 430 |
+
"""
|
| 431 |
+
if isinstance(loss_class, str):
|
| 432 |
+
if loss_class.lower() not in PIXEL_LOSSES:
|
| 433 |
+
raise ValueError(f"Unknown loss class '{loss_class}'. Available: {list(PIXEL_LOSSES.keys())}")
|
| 434 |
+
loss_class = PIXEL_LOSSES[loss_class.lower()]
|
| 435 |
+
elif loss_class is None:
|
| 436 |
+
loss_class = nn.MSELoss # default
|
| 437 |
+
print("No loss_class provided, using default MSELoss.")
|
| 438 |
+
|
| 439 |
+
params = loss_params.copy() if loss_params is not None else None
|
| 440 |
+
|
| 441 |
+
# if the loss is masked, the reduction is handled in MaskedLoss
|
| 442 |
+
if masked_loss and params is not None:
|
| 443 |
+
# pop 'reduction' from loss_params and pass to MaskedLoss
|
| 444 |
+
reduction = params.pop("reduction", "mean")
|
| 445 |
+
criterion = MaskedLoss(loss_class(reduction="none", **params), reduction=reduction)
|
| 446 |
+
print(f"Using masked loss: {loss_class.__name__} with params {params} and reduction {reduction}")
|
| 447 |
+
elif masked_loss:
|
| 448 |
+
criterion = MaskedLoss(loss_class(reduction="none"), reduction="mean")
|
| 449 |
+
print(f"Using masked loss: {loss_class.__name__} with default params and reduction 'mean'")
|
| 450 |
+
else:
|
| 451 |
+
if params is not None:
|
| 452 |
+
criterion = loss_class(**params)
|
| 453 |
+
print(f"Using custom loss: {loss_class.__name__} with params {params}")
|
| 454 |
+
else:
|
| 455 |
+
criterion = loss_class()
|
| 456 |
+
print(f"Using loss: {loss_class.__name__} with default params")
|
| 457 |
+
|
| 458 |
+
return criterion
|
convgru_ensemble/model.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ResidualConvBlock(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Residual convolutional block with two convolutions and a skip connection.
|
| 9 |
+
|
| 10 |
+
Applies two 2D convolutions with a ReLU activation in between. If the
|
| 11 |
+
input and output channel counts differ, a 1x1 projection is used for the
|
| 12 |
+
residual path.
|
| 13 |
+
|
| 14 |
+
Parameters
|
| 15 |
+
----------
|
| 16 |
+
in_channels : int
|
| 17 |
+
Number of input channels.
|
| 18 |
+
out_channels : int
|
| 19 |
+
Number of output channels.
|
| 20 |
+
kernel_size : int, optional
|
| 21 |
+
Kernel size for both convolutions. Default is ``3``.
|
| 22 |
+
padding : int, optional
|
| 23 |
+
Padding for both convolutions. Default is ``1``.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
|
| 27 |
+
"""
|
| 28 |
+
Initialize ResidualConvBlock.
|
| 29 |
+
|
| 30 |
+
Parameters
|
| 31 |
+
----------
|
| 32 |
+
in_channels : int
|
| 33 |
+
Number of input channels.
|
| 34 |
+
out_channels : int
|
| 35 |
+
Number of output channels.
|
| 36 |
+
kernel_size : int, optional
|
| 37 |
+
Kernel size for both convolutions. Default is ``3``.
|
| 38 |
+
padding : int, optional
|
| 39 |
+
Padding for both convolutions. Default is ``1``.
|
| 40 |
+
"""
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
| 43 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
|
| 44 |
+
if in_channels != out_channels:
|
| 45 |
+
self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
| 46 |
+
else:
|
| 47 |
+
self.proj = None
|
| 48 |
+
|
| 49 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Forward pass through the residual convolutional block.
|
| 52 |
+
|
| 53 |
+
Parameters
|
| 54 |
+
----------
|
| 55 |
+
x : torch.Tensor
|
| 56 |
+
Input tensor of shape ``(B, C_in, H, W)``.
|
| 57 |
+
|
| 58 |
+
Returns
|
| 59 |
+
-------
|
| 60 |
+
out : torch.Tensor
|
| 61 |
+
Output tensor of shape ``(B, C_out, H, W)``.
|
| 62 |
+
"""
|
| 63 |
+
residual = x
|
| 64 |
+
out = F.relu(self.conv1(x))
|
| 65 |
+
out = self.conv2(out)
|
| 66 |
+
if self.proj is not None:
|
| 67 |
+
residual = self.proj(residual)
|
| 68 |
+
out += residual
|
| 69 |
+
out = F.relu(out)
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ConvGRUCell(nn.Module):
|
| 74 |
+
"""
|
| 75 |
+
Convolutional GRU cell operating on 2D spatial grids.
|
| 76 |
+
|
| 77 |
+
Implements a single-step GRU update where all linear projections are
|
| 78 |
+
replaced by 2D convolutions, preserving spatial structure.
|
| 79 |
+
|
| 80 |
+
Parameters
|
| 81 |
+
----------
|
| 82 |
+
input_size : int
|
| 83 |
+
Number of channels in the input tensor.
|
| 84 |
+
hidden_size : int
|
| 85 |
+
Number of channels in the hidden state.
|
| 86 |
+
kernel_size : int, optional
|
| 87 |
+
Kernel size for the convolutional gates. Default is ``3``.
|
| 88 |
+
conv_layer : nn.Module, optional
|
| 89 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
|
| 93 |
+
"""
|
| 94 |
+
Initialize ConvGRUCell.
|
| 95 |
+
|
| 96 |
+
Parameters
|
| 97 |
+
----------
|
| 98 |
+
input_size : int
|
| 99 |
+
Number of channels in the input tensor.
|
| 100 |
+
hidden_size : int
|
| 101 |
+
Number of channels in the hidden state.
|
| 102 |
+
kernel_size : int, optional
|
| 103 |
+
Kernel size for the convolutional gates. Default is ``3``.
|
| 104 |
+
conv_layer : nn.Module, optional
|
| 105 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 106 |
+
"""
|
| 107 |
+
super().__init__()
|
| 108 |
+
padding = kernel_size // 2
|
| 109 |
+
self.input_size = input_size
|
| 110 |
+
self.hidden_size = hidden_size
|
| 111 |
+
# update and reset gates are combined for optimization
|
| 112 |
+
self.combined_gates = conv_layer(input_size + hidden_size, 2 * hidden_size, kernel_size, padding=padding)
|
| 113 |
+
self.out_gate = conv_layer(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
|
| 114 |
+
|
| 115 |
+
def forward(self, inpt: torch.Tensor | None = None, h_s: torch.Tensor | None = None) -> torch.Tensor:
|
| 116 |
+
"""
|
| 117 |
+
Forward the ConvGRU cell for a single timestep.
|
| 118 |
+
|
| 119 |
+
If either input is ``None``, it is initialized to zeros based on the
|
| 120 |
+
shape of the other. If both are ``None``, a ``ValueError`` is raised.
|
| 121 |
+
|
| 122 |
+
Parameters
|
| 123 |
+
----------
|
| 124 |
+
inpt : torch.Tensor or None, optional
|
| 125 |
+
Input tensor of shape ``(B, input_size, H, W)``. Default is
|
| 126 |
+
``None``.
|
| 127 |
+
h_s : torch.Tensor or None, optional
|
| 128 |
+
Hidden state tensor of shape ``(B, hidden_size, H, W)``. Default
|
| 129 |
+
is ``None``.
|
| 130 |
+
|
| 131 |
+
Returns
|
| 132 |
+
-------
|
| 133 |
+
new_state : torch.Tensor
|
| 134 |
+
Updated hidden state of shape ``(B, hidden_size, H, W)``.
|
| 135 |
+
|
| 136 |
+
Raises
|
| 137 |
+
------
|
| 138 |
+
ValueError
|
| 139 |
+
If both ``inpt`` and ``h_s`` are ``None``.
|
| 140 |
+
"""
|
| 141 |
+
if h_s is None and inpt is None:
|
| 142 |
+
raise ValueError("Both input and state can't be None")
|
| 143 |
+
elif h_s is None:
|
| 144 |
+
h_s = torch.zeros(
|
| 145 |
+
inpt.size(0), self.hidden_size, inpt.size(2), inpt.size(3), dtype=inpt.dtype, device=inpt.device
|
| 146 |
+
)
|
| 147 |
+
elif inpt is None:
|
| 148 |
+
inpt = torch.zeros(
|
| 149 |
+
h_s.size(0), self.input_size, h_s.size(2), h_s.size(3), dtype=h_s.dtype, device=h_s.device
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
gamma, beta = torch.chunk(self.combined_gates(torch.cat([inpt, h_s], dim=1)), 2, dim=1)
|
| 153 |
+
update = torch.sigmoid(gamma)
|
| 154 |
+
reset = torch.sigmoid(beta)
|
| 155 |
+
|
| 156 |
+
out_inputs = torch.tanh(self.out_gate(torch.cat([inpt, h_s * reset], dim=1)))
|
| 157 |
+
new_state = h_s * (1 - update) + out_inputs * update
|
| 158 |
+
|
| 159 |
+
return new_state
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class ConvGRU(nn.Module):
|
| 163 |
+
"""
|
| 164 |
+
Convolutional GRU that unrolls a :class:`ConvGRUCell` over a sequence.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
input_size : int
|
| 169 |
+
Number of channels in the input tensor.
|
| 170 |
+
hidden_size : int
|
| 171 |
+
Number of channels in the hidden state.
|
| 172 |
+
kernel_size : int, optional
|
| 173 |
+
Kernel size for the convolutional gates. Default is ``3``.
|
| 174 |
+
conv_layer : nn.Module, optional
|
| 175 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
|
| 179 |
+
"""
|
| 180 |
+
Initialize ConvGRU.
|
| 181 |
+
|
| 182 |
+
Parameters
|
| 183 |
+
----------
|
| 184 |
+
input_size : int
|
| 185 |
+
Number of channels in the input tensor.
|
| 186 |
+
hidden_size : int
|
| 187 |
+
Number of channels in the hidden state.
|
| 188 |
+
kernel_size : int, optional
|
| 189 |
+
Kernel size for the convolutional gates. Default is ``3``.
|
| 190 |
+
conv_layer : nn.Module, optional
|
| 191 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 192 |
+
"""
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.cell = ConvGRUCell(input_size, hidden_size, kernel_size, conv_layer)
|
| 195 |
+
|
| 196 |
+
def forward(self, x: torch.Tensor | None = None, h: torch.Tensor | None = None) -> torch.Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Unroll the ConvGRU cell over the sequence (time) dimension.
|
| 199 |
+
|
| 200 |
+
.. code-block:: text
|
| 201 |
+
|
| 202 |
+
x[:, 0] x[:, 1]
|
| 203 |
+
| |
|
| 204 |
+
v v
|
| 205 |
+
*------* *------*
|
| 206 |
+
h --> | Cell | --> h_0 --> | Cell | --> h_1 ...
|
| 207 |
+
*------* *------*
|
| 208 |
+
|
| 209 |
+
If either input is ``None``, it is initialized to zeros based on the
|
| 210 |
+
shape of the other. If both are ``None``, a ``ValueError`` is raised.
|
| 211 |
+
|
| 212 |
+
Parameters
|
| 213 |
+
----------
|
| 214 |
+
x : torch.Tensor or None, optional
|
| 215 |
+
Input tensor of shape ``(B, T, input_size, H, W)``. Default is
|
| 216 |
+
``None``.
|
| 217 |
+
h : torch.Tensor or None, optional
|
| 218 |
+
Initial hidden state of shape ``(B, hidden_size, H, W)``. Default
|
| 219 |
+
is ``None``.
|
| 220 |
+
|
| 221 |
+
Returns
|
| 222 |
+
-------
|
| 223 |
+
hidden_states : torch.Tensor
|
| 224 |
+
Stacked hidden states of shape ``(B, T, hidden_size, H, W)``,
|
| 225 |
+
i.e. ``[h_0, h_1, h_2, ...]``.
|
| 226 |
+
"""
|
| 227 |
+
h_s = []
|
| 228 |
+
for i in range(x.size(1)):
|
| 229 |
+
h = self.cell(x[:, i], h)
|
| 230 |
+
h_s.append(h)
|
| 231 |
+
return torch.stack(h_s, dim=1)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class EncoderBlock(nn.Module):
|
| 235 |
+
"""
|
| 236 |
+
ConvGRU-based encoder block with spatial downsampling.
|
| 237 |
+
|
| 238 |
+
Applies a :class:`ConvGRU` followed by ``nn.PixelUnshuffle(2)`` to
|
| 239 |
+
halve spatial dimensions and quadruple channels.
|
| 240 |
+
|
| 241 |
+
Parameters
|
| 242 |
+
----------
|
| 243 |
+
input_size : int
|
| 244 |
+
Number of input channels.
|
| 245 |
+
kernel_size : int, optional
|
| 246 |
+
Kernel size for the ConvGRU. Default is ``3``.
|
| 247 |
+
conv_layer : nn.Module, optional
|
| 248 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
def __init__(self, input_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
|
| 252 |
+
"""
|
| 253 |
+
Initialize EncoderBlock.
|
| 254 |
+
|
| 255 |
+
Parameters
|
| 256 |
+
----------
|
| 257 |
+
input_size : int
|
| 258 |
+
Number of input channels.
|
| 259 |
+
kernel_size : int, optional
|
| 260 |
+
Kernel size for the ConvGRU. Default is ``3``.
|
| 261 |
+
conv_layer : nn.Module, optional
|
| 262 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 263 |
+
"""
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.convgru = ConvGRU(input_size, input_size, kernel_size, conv_layer)
|
| 266 |
+
self.down = nn.PixelUnshuffle(2)
|
| 267 |
+
|
| 268 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 269 |
+
"""
|
| 270 |
+
Forward the encoder block.
|
| 271 |
+
|
| 272 |
+
Parameters
|
| 273 |
+
----------
|
| 274 |
+
x : torch.Tensor
|
| 275 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 276 |
+
|
| 277 |
+
Returns
|
| 278 |
+
-------
|
| 279 |
+
out : torch.Tensor
|
| 280 |
+
Downsampled tensor of shape ``(B, T, C*4, H/2, W/2)``.
|
| 281 |
+
"""
|
| 282 |
+
x = self.convgru(x)
|
| 283 |
+
x = self.down(x)
|
| 284 |
+
return x
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Encoder(nn.Module):
|
| 288 |
+
r"""
|
| 289 |
+
ConvGRU-based encoder that stacks multiple :class:`EncoderBlock` layers.
|
| 290 |
+
|
| 291 |
+
After each block the spatial resolution is halved via pixel-unshuffle.
|
| 292 |
+
|
| 293 |
+
.. code-block:: text
|
| 294 |
+
|
| 295 |
+
/// Encoder Block 1 \\\ /// Encoder Block 2 \\\
|
| 296 |
+
/--------------------------------------------\ /---------------------------------------\
|
| 297 |
+
| | |
|
| 298 |
+
* *---------* *-----------------* * *---------* *-----------------* *
|
| 299 |
+
X -> | ConvGRU | ---> | Pixel Unshuffle | ---> | ConvGRU | ---> | Pixel Unshuffle | ---> ...
|
| 300 |
+
| *---------* | *-----------------* | *---------* | *-----------------* |
|
| 301 |
+
v v v v v
|
| 302 |
+
(b,t,c,h,w) (b,t,c,h,w) (b,t,c*4,h/2,w/2) (b,t,c*4,h/2,w/2) (b,t,c*16,h/4,w/4)
|
| 303 |
+
|
| 304 |
+
Parameters
|
| 305 |
+
----------
|
| 306 |
+
input_channels : int, optional
|
| 307 |
+
Number of input channels. Default is ``1``.
|
| 308 |
+
num_blocks : int, optional
|
| 309 |
+
Number of encoder blocks to stack. Default is ``4``.
|
| 310 |
+
**kwargs
|
| 311 |
+
Additional keyword arguments forwarded to each :class:`EncoderBlock`.
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
def __init__(self, input_channels: int = 1, num_blocks: int = 4, **kwargs):
|
| 315 |
+
"""
|
| 316 |
+
Initialize Encoder.
|
| 317 |
+
|
| 318 |
+
Parameters
|
| 319 |
+
----------
|
| 320 |
+
input_channels : int, optional
|
| 321 |
+
Number of input channels. Default is ``1``.
|
| 322 |
+
num_blocks : int, optional
|
| 323 |
+
Number of encoder blocks to stack. Default is ``4``.
|
| 324 |
+
**kwargs
|
| 325 |
+
Additional keyword arguments forwarded to each
|
| 326 |
+
:class:`EncoderBlock`.
|
| 327 |
+
"""
|
| 328 |
+
super().__init__()
|
| 329 |
+
self.channel_sizes = [input_channels * 4**i for i in range(num_blocks)] # [1, 4, 16, 64]
|
| 330 |
+
self.blocks = nn.ModuleList([EncoderBlock(self.channel_sizes[i], **kwargs) for i in range(num_blocks)])
|
| 331 |
+
|
| 332 |
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
| 333 |
+
"""
|
| 334 |
+
Forward the encoder through all blocks.
|
| 335 |
+
|
| 336 |
+
Parameters
|
| 337 |
+
----------
|
| 338 |
+
x : torch.Tensor
|
| 339 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 340 |
+
|
| 341 |
+
Returns
|
| 342 |
+
-------
|
| 343 |
+
hidden_states : list of torch.Tensor
|
| 344 |
+
Hidden state tensors from each block, with progressively reduced
|
| 345 |
+
spatial dimensions:
|
| 346 |
+
``[(B, T, C*4, H/2, W/2), (B, T, C*16, H/4, W/4), ...]``.
|
| 347 |
+
"""
|
| 348 |
+
hidden_states = []
|
| 349 |
+
for block in self.blocks:
|
| 350 |
+
x = block(x)
|
| 351 |
+
hidden_states.append(x)
|
| 352 |
+
return hidden_states
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class DecoderBlock(nn.Module):
|
| 356 |
+
"""
|
| 357 |
+
ConvGRU-based decoder block with spatial upsampling.
|
| 358 |
+
|
| 359 |
+
Applies a :class:`ConvGRU` followed by ``nn.PixelShuffle(2)`` to double
|
| 360 |
+
spatial dimensions and quarter channels.
|
| 361 |
+
|
| 362 |
+
Parameters
|
| 363 |
+
----------
|
| 364 |
+
input_size : int
|
| 365 |
+
Number of input channels.
|
| 366 |
+
hidden_size : int
|
| 367 |
+
Number of hidden channels for the ConvGRU.
|
| 368 |
+
kernel_size : int, optional
|
| 369 |
+
Kernel size for the ConvGRU. Default is ``3``.
|
| 370 |
+
conv_layer : nn.Module, optional
|
| 371 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(self, input_size: int, hidden_size: int, kernel_size: int = 3, conv_layer: nn.Module = nn.Conv2d):
|
| 375 |
+
"""
|
| 376 |
+
Initialize DecoderBlock.
|
| 377 |
+
|
| 378 |
+
Parameters
|
| 379 |
+
----------
|
| 380 |
+
input_size : int
|
| 381 |
+
Number of input channels.
|
| 382 |
+
hidden_size : int
|
| 383 |
+
Number of hidden channels for the ConvGRU.
|
| 384 |
+
kernel_size : int, optional
|
| 385 |
+
Kernel size for the ConvGRU. Default is ``3``.
|
| 386 |
+
conv_layer : nn.Module, optional
|
| 387 |
+
Convolutional layer class to use. Default is ``nn.Conv2d``.
|
| 388 |
+
"""
|
| 389 |
+
super().__init__()
|
| 390 |
+
self.convgru = ConvGRU(input_size, hidden_size, kernel_size, conv_layer)
|
| 391 |
+
self.up = nn.PixelShuffle(2)
|
| 392 |
+
|
| 393 |
+
def forward(self, x: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
|
| 394 |
+
"""
|
| 395 |
+
Forward the decoder block.
|
| 396 |
+
|
| 397 |
+
Parameters
|
| 398 |
+
----------
|
| 399 |
+
x : torch.Tensor
|
| 400 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 401 |
+
hidden_state : torch.Tensor
|
| 402 |
+
Hidden state from the corresponding encoder block, of shape
|
| 403 |
+
``(B, hidden_size, H, W)``.
|
| 404 |
+
|
| 405 |
+
Returns
|
| 406 |
+
-------
|
| 407 |
+
out : torch.Tensor
|
| 408 |
+
Upsampled tensor of shape ``(B, T, hidden_size // 4, H*2, W*2)``.
|
| 409 |
+
"""
|
| 410 |
+
x = self.convgru(x, hidden_state)
|
| 411 |
+
x = self.up(x)
|
| 412 |
+
return x
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class Decoder(nn.Module):
|
| 416 |
+
r"""
|
| 417 |
+
ConvGRU-based decoder that stacks multiple :class:`DecoderBlock` layers.
|
| 418 |
+
|
| 419 |
+
After each block the spatial resolution is doubled via pixel-shuffle.
|
| 420 |
+
Hidden sizes are computed from the desired output channels.
|
| 421 |
+
|
| 422 |
+
Parameters
|
| 423 |
+
----------
|
| 424 |
+
output_channels : int, optional
|
| 425 |
+
Number of output channels. Default is ``1``.
|
| 426 |
+
num_blocks : int, optional
|
| 427 |
+
Number of decoder blocks to stack. Default is ``4``.
|
| 428 |
+
**kwargs
|
| 429 |
+
Additional keyword arguments forwarded to each :class:`DecoderBlock`.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
def __init__(self, output_channels: int = 1, num_blocks: int = 4, **kwargs):
|
| 433 |
+
"""
|
| 434 |
+
Initialize Decoder.
|
| 435 |
+
|
| 436 |
+
Parameters
|
| 437 |
+
----------
|
| 438 |
+
output_channels : int, optional
|
| 439 |
+
Number of output channels. Default is ``1``.
|
| 440 |
+
num_blocks : int, optional
|
| 441 |
+
Number of decoder blocks to stack. Default is ``4``.
|
| 442 |
+
**kwargs
|
| 443 |
+
Additional keyword arguments forwarded to each
|
| 444 |
+
:class:`DecoderBlock`.
|
| 445 |
+
"""
|
| 446 |
+
super().__init__()
|
| 447 |
+
self.channel_sizes = [output_channels * 4 ** (i + 1) for i in reversed(range(num_blocks))] # [256, 64, 16, 4]
|
| 448 |
+
self.blocks = nn.ModuleList(
|
| 449 |
+
[DecoderBlock(self.channel_sizes[i], self.channel_sizes[i], **kwargs) for i in range(num_blocks)]
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
def forward(self, x: torch.Tensor, hidden_states: list[torch.Tensor]) -> torch.Tensor:
|
| 453 |
+
"""
|
| 454 |
+
Forward the decoder through all blocks.
|
| 455 |
+
|
| 456 |
+
Parameters
|
| 457 |
+
----------
|
| 458 |
+
x : torch.Tensor
|
| 459 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 460 |
+
hidden_states : list of torch.Tensor
|
| 461 |
+
Hidden states from the encoder (in reverse order), one per block.
|
| 462 |
+
|
| 463 |
+
Returns
|
| 464 |
+
-------
|
| 465 |
+
out : torch.Tensor
|
| 466 |
+
Output tensor of shape
|
| 467 |
+
``(B, T, output_channels, H * 2^num_blocks, W * 2^num_blocks)``.
|
| 468 |
+
"""
|
| 469 |
+
for block, hidden_state in zip(self.blocks, hidden_states, strict=True):
|
| 470 |
+
x = block(x, hidden_state)
|
| 471 |
+
return x
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class EncoderDecoder(nn.Module):
|
| 475 |
+
"""
|
| 476 |
+
Full encoder-decoder model for spatio-temporal forecasting.
|
| 477 |
+
|
| 478 |
+
Encodes an input sequence into multi-scale hidden states and decodes
|
| 479 |
+
them into a forecast sequence, optionally generating multiple ensemble
|
| 480 |
+
members via noisy decoder inputs.
|
| 481 |
+
|
| 482 |
+
Parameters
|
| 483 |
+
----------
|
| 484 |
+
channels : int, optional
|
| 485 |
+
Number of input/output channels. Default is ``1``.
|
| 486 |
+
num_blocks : int, optional
|
| 487 |
+
Number of encoder and decoder blocks. Default is ``4``.
|
| 488 |
+
**kwargs
|
| 489 |
+
Additional keyword arguments forwarded to :class:`Encoder` and
|
| 490 |
+
:class:`Decoder`.
|
| 491 |
+
"""
|
| 492 |
+
|
| 493 |
+
def __init__(self, channels: int = 1, num_blocks: int = 4, **kwargs):
|
| 494 |
+
"""
|
| 495 |
+
Initialize EncoderDecoder.
|
| 496 |
+
|
| 497 |
+
Parameters
|
| 498 |
+
----------
|
| 499 |
+
channels : int, optional
|
| 500 |
+
Number of input/output channels. Default is ``1``.
|
| 501 |
+
num_blocks : int, optional
|
| 502 |
+
Number of encoder and decoder blocks. Default is ``4``.
|
| 503 |
+
**kwargs
|
| 504 |
+
Additional keyword arguments forwarded to :class:`Encoder` and
|
| 505 |
+
:class:`Decoder`.
|
| 506 |
+
"""
|
| 507 |
+
super().__init__()
|
| 508 |
+
self.encoder = Encoder(channels, num_blocks, **kwargs)
|
| 509 |
+
self.decoder = Decoder(channels, num_blocks, **kwargs)
|
| 510 |
+
|
| 511 |
+
def forward(self, x: torch.Tensor, steps: int, noisy_decoder: bool = False, ensemble_size: int = 1) -> torch.Tensor:
|
| 512 |
+
"""
|
| 513 |
+
Forward the encoder-decoder model.
|
| 514 |
+
|
| 515 |
+
Parameters
|
| 516 |
+
----------
|
| 517 |
+
x : torch.Tensor
|
| 518 |
+
Input tensor of shape ``(B, T, C, H, W)``.
|
| 519 |
+
steps : int
|
| 520 |
+
Number of future timesteps to forecast.
|
| 521 |
+
noisy_decoder : bool, optional
|
| 522 |
+
If ``True``, feed random noise (instead of zeros) as input to the
|
| 523 |
+
decoder. Default is ``False``.
|
| 524 |
+
ensemble_size : int, optional
|
| 525 |
+
Number of ensemble members to generate. When ``> 1``, the decoder
|
| 526 |
+
is always run with noisy inputs. Default is ``1``.
|
| 527 |
+
|
| 528 |
+
Returns
|
| 529 |
+
-------
|
| 530 |
+
preds : torch.Tensor
|
| 531 |
+
Forecast tensor. Shape is ``(B, steps, C, H, W)`` when
|
| 532 |
+
``ensemble_size == 1``, or
|
| 533 |
+
``(B, steps, ensemble_size * C, H, W)`` when ``ensemble_size > 1``
|
| 534 |
+
(for C=1, this is ``(B, steps, ensemble_size, H, W)``).
|
| 535 |
+
"""
|
| 536 |
+
|
| 537 |
+
# encode the input tensor into a sequence of hidden states
|
| 538 |
+
encoded = self.encoder(x)
|
| 539 |
+
|
| 540 |
+
# create a tensor with the same shape as the last hidden state of the encoder to use as a input for the decoder
|
| 541 |
+
x_dec_shape = list(encoded[-1].shape)
|
| 542 |
+
|
| 543 |
+
# set the desired number of timestep for the output
|
| 544 |
+
x_dec_shape[1] = steps
|
| 545 |
+
|
| 546 |
+
# collect all the last hidden states of the encoder blocks in reverse order
|
| 547 |
+
last_hidden_per_block = [e[:, -1] for e in reversed(encoded)]
|
| 548 |
+
|
| 549 |
+
if ensemble_size > 1:
|
| 550 |
+
# Generate M ensemble members by running decoder M times with different noise
|
| 551 |
+
preds = []
|
| 552 |
+
for _ in range(ensemble_size):
|
| 553 |
+
# the input will be random noise for each ensemble member
|
| 554 |
+
x_dec = torch.randn(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device)
|
| 555 |
+
# decode (unroll) the input hidden states into a forecast sequence of N timesteps
|
| 556 |
+
decoded = self.decoder(x_dec, last_hidden_per_block)
|
| 557 |
+
preds.append(decoded)
|
| 558 |
+
# stack along channel/ensemble dimension: (B, T, M, H, W)
|
| 559 |
+
return torch.cat(preds, dim=2)
|
| 560 |
+
else:
|
| 561 |
+
# the input will be of random values if noisy_decoder is True, otherwise with zeros
|
| 562 |
+
x_dec_func = torch.randn if noisy_decoder else torch.zeros
|
| 563 |
+
|
| 564 |
+
# create the input tensor for the decoder
|
| 565 |
+
x_dec = x_dec_func(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device)
|
| 566 |
+
|
| 567 |
+
# decode (unroll) the input hidden states into a forecast sequence of N timesteps
|
| 568 |
+
decoded = self.decoder(x_dec, last_hidden_per_block)
|
| 569 |
+
return decoded
|
convgru_ensemble/py.typed
ADDED
|
File without changes
|
convgru_ensemble/serve.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI inference server for ConvGRU-Ensemble nowcasting model."""
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import xarray as xr
|
| 10 |
+
from fastapi import FastAPI, File, Query, UploadFile
|
| 11 |
+
from fastapi.responses import Response
|
| 12 |
+
|
| 13 |
+
_model = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _load_model():
|
| 17 |
+
from .lightning_model import RadarLightningModel
|
| 18 |
+
|
| 19 |
+
device = os.environ.get("DEVICE", "cpu")
|
| 20 |
+
checkpoint = os.environ.get("MODEL_CHECKPOINT")
|
| 21 |
+
hub_repo = os.environ.get("HF_REPO_ID")
|
| 22 |
+
|
| 23 |
+
if hub_repo:
|
| 24 |
+
return RadarLightningModel.from_pretrained(hub_repo, device=device)
|
| 25 |
+
elif checkpoint:
|
| 26 |
+
return RadarLightningModel.from_checkpoint(checkpoint, device=device)
|
| 27 |
+
else:
|
| 28 |
+
raise RuntimeError("Set MODEL_CHECKPOINT or HF_REPO_ID environment variable.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@asynccontextmanager
|
| 32 |
+
async def lifespan(app: FastAPI):
|
| 33 |
+
global _model
|
| 34 |
+
_model = _load_model()
|
| 35 |
+
yield
|
| 36 |
+
_model = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
app = FastAPI(
|
| 40 |
+
title="ConvGRU-Ensemble Nowcasting API",
|
| 41 |
+
version="0.1.0",
|
| 42 |
+
description="Ensemble precipitation nowcasting from radar data",
|
| 43 |
+
lifespan=lifespan,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@app.get("/health")
|
| 48 |
+
async def health():
|
| 49 |
+
"""Health check endpoint."""
|
| 50 |
+
return {"status": "ok", "model_loaded": _model is not None}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@app.get("/model/info")
|
| 54 |
+
async def model_info():
|
| 55 |
+
"""Return model metadata."""
|
| 56 |
+
if _model is None:
|
| 57 |
+
return {"error": "Model not loaded"}
|
| 58 |
+
hp = _model.hparams
|
| 59 |
+
return {
|
| 60 |
+
"architecture": "ConvGRU-Ensemble EncoderDecoder",
|
| 61 |
+
"input_channels": hp.input_channels,
|
| 62 |
+
"num_blocks": hp.num_blocks,
|
| 63 |
+
"forecast_steps": hp.forecast_steps,
|
| 64 |
+
"ensemble_size": hp.ensemble_size,
|
| 65 |
+
"noisy_decoder": hp.noisy_decoder,
|
| 66 |
+
"loss_class": str(hp.loss_class),
|
| 67 |
+
"device": str(_model.device),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@app.post("/predict")
|
| 72 |
+
async def predict(
|
| 73 |
+
file: UploadFile = File(..., description="NetCDF file with rain rate data (T, H, W)"), # noqa: B008
|
| 74 |
+
variable: str = Query("RR", description="Name of the rain rate variable"), # noqa: B008
|
| 75 |
+
forecast_steps: int = Query(12, ge=1, le=48, description="Number of future timesteps"), # noqa: B008
|
| 76 |
+
ensemble_size: int = Query(10, ge=1, le=50, description="Number of ensemble members"), # noqa: B008
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Run ensemble nowcasting inference on uploaded NetCDF data.
|
| 80 |
+
|
| 81 |
+
Accepts a NetCDF file containing past radar rain rate observations and
|
| 82 |
+
returns NetCDF predictions with ensemble forecasts.
|
| 83 |
+
"""
|
| 84 |
+
t0 = time.perf_counter()
|
| 85 |
+
|
| 86 |
+
# Read uploaded NetCDF
|
| 87 |
+
content = await file.read()
|
| 88 |
+
ds = xr.open_dataset(io.BytesIO(content))
|
| 89 |
+
|
| 90 |
+
if variable not in ds:
|
| 91 |
+
available = list(ds.data_vars)
|
| 92 |
+
return Response(
|
| 93 |
+
content=f"Variable '{variable}' not found. Available: {available}",
|
| 94 |
+
status_code=422,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
data = ds[variable].values
|
| 98 |
+
if data.ndim != 3:
|
| 99 |
+
return Response(
|
| 100 |
+
content=f"Expected 3D data (T, H, W), got shape {data.shape}",
|
| 101 |
+
status_code=422,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
past = data.astype(np.float32)
|
| 105 |
+
|
| 106 |
+
# Run inference
|
| 107 |
+
preds = _model.predict(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size)
|
| 108 |
+
|
| 109 |
+
elapsed = time.perf_counter() - t0
|
| 110 |
+
|
| 111 |
+
# Build output NetCDF
|
| 112 |
+
ds_out = xr.Dataset(
|
| 113 |
+
{
|
| 114 |
+
"precipitation_forecast": xr.DataArray(
|
| 115 |
+
data=preds,
|
| 116 |
+
dims=["ensemble_member", "forecast_step", "y", "x"],
|
| 117 |
+
attrs={"units": "mm/h", "long_name": "Ensemble precipitation forecast"},
|
| 118 |
+
),
|
| 119 |
+
},
|
| 120 |
+
attrs={
|
| 121 |
+
"model": "ConvGRU-Ensemble",
|
| 122 |
+
"forecast_steps": forecast_steps,
|
| 123 |
+
"ensemble_size": ensemble_size,
|
| 124 |
+
"elapsed_seconds": f"{elapsed:.3f}",
|
| 125 |
+
},
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
buf = io.BytesIO()
|
| 129 |
+
ds_out.to_netcdf(buf)
|
| 130 |
+
buf.seek(0)
|
| 131 |
+
|
| 132 |
+
return Response(
|
| 133 |
+
content=buf.getvalue(),
|
| 134 |
+
media_type="application/x-netcdf",
|
| 135 |
+
headers={
|
| 136 |
+
"Content-Disposition": "attachment; filename=predictions.nc",
|
| 137 |
+
"X-Elapsed-Seconds": f"{elapsed:.3f}",
|
| 138 |
+
},
|
| 139 |
+
)
|
convgru_ensemble/train.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train.py
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import fiddle as fdl
|
| 7 |
+
import torch
|
| 8 |
+
import yaml
|
| 9 |
+
from absl import app, flags
|
| 10 |
+
from fiddle import absl_flags, printing
|
| 11 |
+
from pytorch_lightning import Trainer, seed_everything
|
| 12 |
+
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
|
| 13 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
| 14 |
+
|
| 15 |
+
from .datamodule import RadarDataModule
|
| 16 |
+
from .lightning_model import RadarLightningModel
|
| 17 |
+
from .losses import PIXEL_LOSSES
|
| 18 |
+
|
| 19 |
+
seed_everything(42, workers=True)
|
| 20 |
+
|
| 21 |
+
FLAGS = flags.FLAGS
|
| 22 |
+
flags.DEFINE_bool("print_config", False, "Print configuration and exit.")
|
| 23 |
+
flags.DEFINE_string("export_yaml", None, "Export configuration to YAML file and exit.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def experiment() -> fdl.Config:
|
| 27 |
+
"""
|
| 28 |
+
Define the default experiment configuration.
|
| 29 |
+
|
| 30 |
+
Returns a Fiddle config that can be overridden from the command line
|
| 31 |
+
with ``--config config:experiment --config set:path.to.value=X``.
|
| 32 |
+
|
| 33 |
+
Returns
|
| 34 |
+
-------
|
| 35 |
+
cfg : fdl.Config
|
| 36 |
+
Nested Fiddle configuration containing datamodule, model, trainer,
|
| 37 |
+
callbacks, and logger settings.
|
| 38 |
+
"""
|
| 39 |
+
cfg = fdl.Config(dict)
|
| 40 |
+
|
| 41 |
+
# resume from checkpoint
|
| 42 |
+
cfg.checkpoint_path = None
|
| 43 |
+
|
| 44 |
+
# enable mixed precision for float32 matmuls if available
|
| 45 |
+
cfg.float32_matmul_precision = None
|
| 46 |
+
|
| 47 |
+
# compile model with torch.compile if desired
|
| 48 |
+
cfg.compile_model = False
|
| 49 |
+
|
| 50 |
+
# DataModule
|
| 51 |
+
cfg.datamodule = fdl.Config(
|
| 52 |
+
RadarDataModule,
|
| 53 |
+
zarr_path="./data/italian-radar-dpc-sri.zarr",
|
| 54 |
+
csv_path="./importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
|
| 55 |
+
steps=18,
|
| 56 |
+
train_ratio=0.90,
|
| 57 |
+
val_ratio=0.05,
|
| 58 |
+
return_mask=True,
|
| 59 |
+
deterministic=False,
|
| 60 |
+
augment=True,
|
| 61 |
+
# DataLoader params
|
| 62 |
+
batch_size=16,
|
| 63 |
+
num_workers=8,
|
| 64 |
+
pin_memory=True,
|
| 65 |
+
multiprocessing_context="fork",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Lightning Model
|
| 69 |
+
cfg.model = fdl.Config(
|
| 70 |
+
RadarLightningModel,
|
| 71 |
+
input_channels=1,
|
| 72 |
+
forecast_steps=12,
|
| 73 |
+
num_blocks=5,
|
| 74 |
+
ensemble_size=2,
|
| 75 |
+
noisy_decoder=False,
|
| 76 |
+
loss_class="crps",
|
| 77 |
+
loss_params={"temporal_lambda": 0.01},
|
| 78 |
+
masked_loss=True,
|
| 79 |
+
optimizer_class=torch.optim.Adam,
|
| 80 |
+
optimizer_params={"lr": 1e-4, "fused": True},
|
| 81 |
+
lr_scheduler_class=torch.optim.lr_scheduler.ReduceLROnPlateau,
|
| 82 |
+
lr_scheduler_params={"mode": "min", "factor": 0.5, "patience": 10},
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Trainer
|
| 86 |
+
cfg.trainer = fdl.Config(
|
| 87 |
+
Trainer,
|
| 88 |
+
accelerator="auto",
|
| 89 |
+
# gradient_clip_val=1.0,
|
| 90 |
+
max_epochs=1,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Callbacks
|
| 94 |
+
cfg.callbacks = fdl.Config(dict)
|
| 95 |
+
cfg.callbacks.checkpoint_val = fdl.Config(
|
| 96 |
+
ModelCheckpoint,
|
| 97 |
+
monitor="val_loss",
|
| 98 |
+
save_top_k=1,
|
| 99 |
+
mode="min",
|
| 100 |
+
dirpath=None,
|
| 101 |
+
filename=None, # Set dynamically: best-val-{ckpt_name}
|
| 102 |
+
save_last=False,
|
| 103 |
+
)
|
| 104 |
+
cfg.callbacks.checkpoint_train = fdl.Config(
|
| 105 |
+
ModelCheckpoint,
|
| 106 |
+
monitor="train_loss_epoch",
|
| 107 |
+
save_top_k=1,
|
| 108 |
+
mode="min",
|
| 109 |
+
dirpath=None,
|
| 110 |
+
filename=None, # Set dynamically: best-train-{ckpt_name}
|
| 111 |
+
save_last=False,
|
| 112 |
+
)
|
| 113 |
+
cfg.callbacks.early_stopping = fdl.Config(
|
| 114 |
+
EarlyStopping,
|
| 115 |
+
monitor="val_loss",
|
| 116 |
+
patience=100,
|
| 117 |
+
mode="min",
|
| 118 |
+
)
|
| 119 |
+
cfg.callbacks.lr_monitor = fdl.Config(
|
| 120 |
+
LearningRateMonitor,
|
| 121 |
+
logging_interval="step",
|
| 122 |
+
log_momentum=False,
|
| 123 |
+
log_weight_decay=False,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Loggers
|
| 127 |
+
cfg.loggers = fdl.Config(dict)
|
| 128 |
+
cfg.loggers.tensorboard = fdl.Config(
|
| 129 |
+
TensorBoardLogger,
|
| 130 |
+
save_dir="logs",
|
| 131 |
+
name=None, # Set dynamically in train()
|
| 132 |
+
version=None, # Set dynamically in train()
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return cfg
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
_CONFIG = absl_flags.DEFINE_fiddle_config(
|
| 139 |
+
"config",
|
| 140 |
+
default_module=sys.modules[__name__],
|
| 141 |
+
help_string="Experiment configuration.",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def train(cfg: fdl.Config) -> None:
|
| 146 |
+
"""
|
| 147 |
+
Run training with the given Fiddle configuration.
|
| 148 |
+
|
| 149 |
+
Builds all components (model, datamodule, trainer, callbacks, loggers),
|
| 150 |
+
sets up dynamic naming for checkpoints and TensorBoard logs, saves the
|
| 151 |
+
config as YAML, and runs ``trainer.fit`` followed by ``trainer.test``.
|
| 152 |
+
|
| 153 |
+
Parameters
|
| 154 |
+
----------
|
| 155 |
+
cfg : fdl.Config
|
| 156 |
+
Fiddle configuration as returned by :func:`experiment`.
|
| 157 |
+
"""
|
| 158 |
+
# enable tensor cores for float32 matmuls if available
|
| 159 |
+
if cfg.float32_matmul_precision is not None:
|
| 160 |
+
torch.set_float32_matmul_precision(cfg.float32_matmul_precision)
|
| 161 |
+
|
| 162 |
+
# Compute dynamic values for naming
|
| 163 |
+
future_steps = cfg.model.forecast_steps
|
| 164 |
+
past_steps = cfg.datamodule.steps - future_steps
|
| 165 |
+
|
| 166 |
+
if cfg.model.loss_class is None:
|
| 167 |
+
loss_name = "MSELoss"
|
| 168 |
+
elif isinstance(cfg.model.loss_class, type):
|
| 169 |
+
loss_name = cfg.model.loss_class.__name__
|
| 170 |
+
else:
|
| 171 |
+
loss_name = (
|
| 172 |
+
PIXEL_LOSSES[cfg.model.loss_class.lower()].__name__
|
| 173 |
+
if cfg.model.loss_class.lower() in PIXEL_LOSSES
|
| 174 |
+
else str(cfg.model.loss_class)
|
| 175 |
+
)
|
| 176 |
+
lr = (
|
| 177 |
+
cfg.model.optimizer_params["lr"]
|
| 178 |
+
if cfg.model.optimizer_params is not None and "lr" in cfg.model.optimizer_params
|
| 179 |
+
else "default"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
noise_str: str = "_noise" if cfg.model.noisy_decoder else ""
|
| 183 |
+
ckpt_base_name: str = f"{past_steps}past-{future_steps}fut{noise_str}_bs{cfg.datamodule.batch_size}_lr{lr}"
|
| 184 |
+
|
| 185 |
+
# Set dynamic logger name and version first (checkpoint folder depends on it)
|
| 186 |
+
if cfg.loggers.tensorboard.name is None:
|
| 187 |
+
cfg.loggers.tensorboard.name = f"{loss_name}_{past_steps}past-{future_steps}fut{noise_str}"
|
| 188 |
+
|
| 189 |
+
jobid = os.getenv("SLURM_JOB_ID", None)
|
| 190 |
+
tb_version = f"_{cfg.loggers.tensorboard.version}" if cfg.loggers.tensorboard.version is not None else ""
|
| 191 |
+
|
| 192 |
+
if jobid is not None:
|
| 193 |
+
cfg.loggers.tensorboard.version = f"job{jobid}_{ckpt_base_name}{tb_version}"
|
| 194 |
+
else:
|
| 195 |
+
cfg.loggers.tensorboard.version = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{ckpt_base_name}{tb_version}"
|
| 196 |
+
|
| 197 |
+
# Set checkpoint paths inside tensorboard experiment folder
|
| 198 |
+
tb_log_dir = f"{cfg.loggers.tensorboard.save_dir}/{cfg.loggers.tensorboard.name}/{cfg.loggers.tensorboard.version}"
|
| 199 |
+
ckpt_dir = f"{tb_log_dir}/checkpoints"
|
| 200 |
+
|
| 201 |
+
# Val checkpoint
|
| 202 |
+
if cfg.callbacks.checkpoint_val.dirpath is None:
|
| 203 |
+
cfg.callbacks.checkpoint_val.dirpath = ckpt_dir
|
| 204 |
+
if cfg.callbacks.checkpoint_val.filename is None:
|
| 205 |
+
cfg.callbacks.checkpoint_val.filename = "best-val-" + ckpt_base_name + "_ep{epoch:03d}_loss{val_loss:.4f}"
|
| 206 |
+
|
| 207 |
+
# Train checkpoint
|
| 208 |
+
if cfg.callbacks.checkpoint_train.dirpath is None:
|
| 209 |
+
cfg.callbacks.checkpoint_train.dirpath = ckpt_dir
|
| 210 |
+
if cfg.callbacks.checkpoint_train.filename is None:
|
| 211 |
+
cfg.callbacks.checkpoint_train.filename = (
|
| 212 |
+
"best-train-" + ckpt_base_name + "_ep{epoch:03d}_loss{train_loss_epoch:.4f}"
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Build all callbacks and loggers dynamically
|
| 216 |
+
callbacks_dict = fdl.build(cfg.callbacks)
|
| 217 |
+
loggers_dict = fdl.build(cfg.loggers)
|
| 218 |
+
callbacks = list(callbacks_dict.values())
|
| 219 |
+
loggers = list(loggers_dict.values())
|
| 220 |
+
|
| 221 |
+
# Add loggers and callbacks to trainer config
|
| 222 |
+
cfg.trainer.logger = loggers
|
| 223 |
+
cfg.trainer.callbacks = callbacks
|
| 224 |
+
|
| 225 |
+
print(printing.as_str_flattened(cfg))
|
| 226 |
+
|
| 227 |
+
# Save config to tensorboard folder
|
| 228 |
+
os.makedirs(tb_log_dir, exist_ok=True)
|
| 229 |
+
config_path = f"{tb_log_dir}/config.yaml"
|
| 230 |
+
with open(config_path, "w") as f:
|
| 231 |
+
yaml.dump(config_to_dict(cfg), f, default_flow_style=False, sort_keys=False)
|
| 232 |
+
print(f"Config saved to {config_path}")
|
| 233 |
+
|
| 234 |
+
# Build all components
|
| 235 |
+
built = fdl.build(cfg)
|
| 236 |
+
datamodule: RadarDataModule = built["datamodule"]
|
| 237 |
+
|
| 238 |
+
if cfg.checkpoint_path is not None:
|
| 239 |
+
print(f"Resuming training from checkpoint: {cfg.checkpoint_path}")
|
| 240 |
+
model = RadarLightningModel.load_from_checkpoint(cfg.checkpoint_path, strict=True, weights_only=False)
|
| 241 |
+
else:
|
| 242 |
+
model = built["model"]
|
| 243 |
+
trainer: Trainer = built["trainer"]
|
| 244 |
+
|
| 245 |
+
datamodule.setup()
|
| 246 |
+
print(
|
| 247 |
+
f"Train: {len(datamodule.train_dataset)}, Val: {len(datamodule.val_dataset)}, Test: {len(datamodule.test_dataset)}"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
if cfg.compile_model:
|
| 251 |
+
print("Compiling model with torch.compile()...")
|
| 252 |
+
model = torch.compile(model, dynamic=True)
|
| 253 |
+
|
| 254 |
+
trainer.fit(model, datamodule=datamodule)
|
| 255 |
+
trainer.test(model, datamodule=datamodule)
|
| 256 |
+
print(f"Best val: {callbacks_dict['checkpoint_val'].best_model_path}")
|
| 257 |
+
print(f"Best train: {callbacks_dict['checkpoint_train'].best_model_path}")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def config_to_dict(cfg: fdl.Config) -> dict:
|
| 261 |
+
"""
|
| 262 |
+
Recursively convert a Fiddle config to a nested dictionary.
|
| 263 |
+
|
| 264 |
+
Parameters
|
| 265 |
+
----------
|
| 266 |
+
cfg : fdl.Config
|
| 267 |
+
Fiddle configuration object.
|
| 268 |
+
|
| 269 |
+
Returns
|
| 270 |
+
-------
|
| 271 |
+
result : dict
|
| 272 |
+
Plain dictionary suitable for YAML serialization.
|
| 273 |
+
"""
|
| 274 |
+
result = {}
|
| 275 |
+
for key, value in fdl.ordered_arguments(cfg).items():
|
| 276 |
+
result[key] = config_to_dict(value) if isinstance(value, fdl.Config) else value
|
| 277 |
+
return result
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def main(argv: list[str]) -> None:
|
| 281 |
+
"""
|
| 282 |
+
Entry point for the training script.
|
| 283 |
+
|
| 284 |
+
Handles ``--print_config`` and ``--export_yaml`` flags, then delegates
|
| 285 |
+
to :func:`train`.
|
| 286 |
+
|
| 287 |
+
Parameters
|
| 288 |
+
----------
|
| 289 |
+
argv : list of str
|
| 290 |
+
Command-line arguments (unused, consumed by ``absl``).
|
| 291 |
+
"""
|
| 292 |
+
del argv
|
| 293 |
+
cfg = _CONFIG.value
|
| 294 |
+
if FLAGS.print_config:
|
| 295 |
+
print(printing.as_str_flattened(cfg))
|
| 296 |
+
return
|
| 297 |
+
if FLAGS.export_yaml:
|
| 298 |
+
cfg_dict = config_to_dict(cfg)
|
| 299 |
+
with open(FLAGS.export_yaml, "w") as f:
|
| 300 |
+
yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False)
|
| 301 |
+
print(f"Config exported to {FLAGS.export_yaml}")
|
| 302 |
+
return
|
| 303 |
+
train(cfg)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Example command to run training with custom configuration overrides.
|
| 307 |
+
# uv run python train.py \
|
| 308 |
+
# --config config:experiment \
|
| 309 |
+
# --config set:callbacks.checkpoint.save_top_k=3 \
|
| 310 |
+
# --config set:model.num_blocks=5 \
|
| 311 |
+
# --config set:model.forecast_steps=12 \
|
| 312 |
+
# --config set:datamodule.steps=18 \
|
| 313 |
+
# --config set:datamodule.num_workers=32 \
|
| 314 |
+
# --config set:datamodule.batch_size=32
|
| 315 |
+
if __name__ == "__main__":
|
| 316 |
+
app.run(main)
|
convgru_ensemble/utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def rainrate_to_reflectivity(rainrate: np.ndarray) -> np.ndarray:
|
| 5 |
+
"""
|
| 6 |
+
Convert rain rate to reflectivity using the Marshall-Palmer relationship.
|
| 7 |
+
|
| 8 |
+
Applies Z = 200 * R^1.6 and converts to dBZ. Values below ~0.037 mm/h
|
| 9 |
+
are clipped to 0 dBZ; values above 60 dBZ are clipped to 60.
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
----------
|
| 13 |
+
rainrate : np.ndarray
|
| 14 |
+
Rain rate in mm/h. Can be any shape.
|
| 15 |
+
|
| 16 |
+
Returns
|
| 17 |
+
-------
|
| 18 |
+
reflectivity : np.ndarray
|
| 19 |
+
Reflectivity in dBZ, clipped to [0, 60]. Same shape as input.
|
| 20 |
+
"""
|
| 21 |
+
epsilon = 1e-16
|
| 22 |
+
# We return 0 for any rain lighter than ~0.037mm/h
|
| 23 |
+
return (10 * np.log10(200 * rainrate**1.6 + epsilon)).clip(0, 60)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def normalize_reflectivity(reflectivity: np.ndarray) -> np.ndarray:
|
| 27 |
+
"""
|
| 28 |
+
Normalize reflectivity from [0, 60] dBZ to [-1, 1].
|
| 29 |
+
|
| 30 |
+
Parameters
|
| 31 |
+
----------
|
| 32 |
+
reflectivity : np.ndarray
|
| 33 |
+
Reflectivity in dBZ, expected in [0, 60]. Can be any shape.
|
| 34 |
+
|
| 35 |
+
Returns
|
| 36 |
+
-------
|
| 37 |
+
normalized : np.ndarray
|
| 38 |
+
Normalized reflectivity in [-1, 1]. Same shape as input.
|
| 39 |
+
"""
|
| 40 |
+
return (reflectivity / 30.0) - 1.0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def denormalize_reflectivity(normalized: np.ndarray) -> np.ndarray:
|
| 44 |
+
"""
|
| 45 |
+
Denormalize from [-1, 1] back to [0, 60] dBZ reflectivity.
|
| 46 |
+
|
| 47 |
+
Parameters
|
| 48 |
+
----------
|
| 49 |
+
normalized : np.ndarray
|
| 50 |
+
Normalized reflectivity in [-1, 1]. Can be any shape.
|
| 51 |
+
|
| 52 |
+
Returns
|
| 53 |
+
-------
|
| 54 |
+
reflectivity : np.ndarray
|
| 55 |
+
Reflectivity in dBZ, in [0, 60]. Same shape as input.
|
| 56 |
+
"""
|
| 57 |
+
return (normalized + 1.0) * 30.0
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def reflectivity_to_rainrate(reflectivity: np.ndarray) -> np.ndarray:
|
| 61 |
+
"""
|
| 62 |
+
Convert reflectivity back to rain rate using the inverse Marshall-Palmer
|
| 63 |
+
relationship.
|
| 64 |
+
|
| 65 |
+
Applies R = (Z_linear / 200)^(1/1.6) where Z_linear = 10^(dBZ/10).
|
| 66 |
+
|
| 67 |
+
Parameters
|
| 68 |
+
----------
|
| 69 |
+
reflectivity : np.ndarray
|
| 70 |
+
Reflectivity in dBZ. Can be any shape.
|
| 71 |
+
|
| 72 |
+
Returns
|
| 73 |
+
-------
|
| 74 |
+
rainrate : np.ndarray
|
| 75 |
+
Rain rate in mm/h. Same shape as input.
|
| 76 |
+
"""
|
| 77 |
+
# Z = 200 * R^1.6
|
| 78 |
+
# R = (Z / 200)^(1/1.6)
|
| 79 |
+
z_linear = 10 ** (reflectivity / 10.0)
|
| 80 |
+
return (z_linear / 200.0) ** (1.0 / 1.6)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def rainrate_to_normalized(rainrate: np.ndarray) -> np.ndarray:
|
| 84 |
+
"""
|
| 85 |
+
Convert rain rate directly to normalized reflectivity.
|
| 86 |
+
|
| 87 |
+
Composes :func:`rainrate_to_reflectivity` and
|
| 88 |
+
:func:`normalize_reflectivity`.
|
| 89 |
+
|
| 90 |
+
Parameters
|
| 91 |
+
----------
|
| 92 |
+
rainrate : np.ndarray
|
| 93 |
+
Rain rate in mm/h. Can be any shape.
|
| 94 |
+
|
| 95 |
+
Returns
|
| 96 |
+
-------
|
| 97 |
+
normalized : np.ndarray
|
| 98 |
+
Normalized reflectivity in [-1, 1]. Same shape as input.
|
| 99 |
+
"""
|
| 100 |
+
reflectivity = rainrate_to_reflectivity(rainrate)
|
| 101 |
+
return normalize_reflectivity(reflectivity)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def normalized_to_rainrate(normalized: np.ndarray) -> np.ndarray:
|
| 105 |
+
"""
|
| 106 |
+
Convert normalized reflectivity back to rain rate.
|
| 107 |
+
|
| 108 |
+
Composes :func:`denormalize_reflectivity` and
|
| 109 |
+
:func:`reflectivity_to_rainrate`.
|
| 110 |
+
|
| 111 |
+
Parameters
|
| 112 |
+
----------
|
| 113 |
+
normalized : np.ndarray
|
| 114 |
+
Normalized reflectivity in [-1, 1]. Can be any shape.
|
| 115 |
+
|
| 116 |
+
Returns
|
| 117 |
+
-------
|
| 118 |
+
rainrate : np.ndarray
|
| 119 |
+
Rain rate in mm/h. Same shape as input.
|
| 120 |
+
"""
|
| 121 |
+
reflectivity = denormalize_reflectivity(normalized)
|
| 122 |
+
return reflectivity_to_rainrate(reflectivity)
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
api:
|
| 3 |
+
build: .
|
| 4 |
+
ports:
|
| 5 |
+
- "8000:8000"
|
| 6 |
+
volumes:
|
| 7 |
+
- ./checkpoints:/app/checkpoints
|
| 8 |
+
environment:
|
| 9 |
+
- MODEL_CHECKPOINT=/app/checkpoints/model.ckpt
|
| 10 |
+
- DEVICE=cpu
|
| 11 |
+
# Alternative: download from HuggingFace Hub
|
| 12 |
+
# - HF_REPO_ID=it4lia/irene
|
examples/sample_data.nc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1a94b2f0c576f8e80cac470b1994bf8faed3eae11fb6fee6feba21e1a52dc0e7
|
| 3 |
+
size 11232217
|
importance_sampler/filter_nan.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
from functools import partial
|
| 6 |
+
from multiprocessing import Pool
|
| 7 |
+
from queue import Queue
|
| 8 |
+
from threading import Thread
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import xarray as xr
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
START = time.time()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# === Functions ===
|
| 19 |
+
def dim_nan_count(mask, dim, delta, dim_len):
|
| 20 |
+
"""
|
| 21 |
+
Count NaN values in a rolling window along a single dimension.
|
| 22 |
+
|
| 23 |
+
Uses a cumulative-sum trick to efficiently compute the number of NaN
|
| 24 |
+
values within each window of size ``delta`` along the specified axis.
|
| 25 |
+
|
| 26 |
+
Parameters
|
| 27 |
+
----------
|
| 28 |
+
mask : np.ndarray
|
| 29 |
+
3-D binary array where 1 indicates NaN and 0 indicates valid.
|
| 30 |
+
dim : int
|
| 31 |
+
Axis along which to compute the rolling NaN count (0, 1, or 2).
|
| 32 |
+
delta : int
|
| 33 |
+
Window size (number of pixels) along ``dim``.
|
| 34 |
+
dim_len : int
|
| 35 |
+
Length of ``dim`` in the input array.
|
| 36 |
+
|
| 37 |
+
Returns
|
| 38 |
+
-------
|
| 39 |
+
nan_counts : np.ndarray
|
| 40 |
+
Integer array with the NaN count for each window position.
|
| 41 |
+
"""
|
| 42 |
+
cumsum = np.cumsum(mask, axis=dim, dtype=np.int32)
|
| 43 |
+
|
| 44 |
+
# Pad with zeros at the start along 'dim'
|
| 45 |
+
pad_width = [(1, 0) if i == dim else (0, 0) for i in range(3)]
|
| 46 |
+
padded_cumsum = np.pad(cumsum, pad_width=pad_width, mode="constant", constant_values=0)
|
| 47 |
+
|
| 48 |
+
# Rolling window: padded[start+delta:start+delta+dim_len] - padded[start:start+dim_len-delta]
|
| 49 |
+
slices_start = [slice(dim_len - delta) if i == dim else slice(None) for i in range(3)]
|
| 50 |
+
slices_end = [slice(delta, dim_len) if i == dim else slice(None) for i in range(3)]
|
| 51 |
+
|
| 52 |
+
# number of nans in each delta
|
| 53 |
+
return padded_cumsum[tuple(slices_end)] - padded_cumsum[tuple(slices_start)]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def dc_nan_count(chunk, deltas, dim_lenghts):
|
| 57 |
+
"""
|
| 58 |
+
Count the number of NaN values in each 3-D datacube within a chunk.
|
| 59 |
+
|
| 60 |
+
Applies :func:`dim_nan_count` sequentially along time, X, and Y to
|
| 61 |
+
compute the total NaN count for every possible datacube position.
|
| 62 |
+
|
| 63 |
+
Parameters
|
| 64 |
+
----------
|
| 65 |
+
chunk : np.ndarray
|
| 66 |
+
3-D array of shape ``(T, X, Y)`` — a time chunk of the full Zarr
|
| 67 |
+
dataset.
|
| 68 |
+
deltas : tuple of int
|
| 69 |
+
Datacube dimensions ``(Dt, w, h)`` along time, X, and Y.
|
| 70 |
+
dim_lenghts : tuple of int
|
| 71 |
+
Shape of ``chunk``, i.e. ``(T, X, Y)``.
|
| 72 |
+
|
| 73 |
+
Returns
|
| 74 |
+
-------
|
| 75 |
+
nans_cube_chunk : np.ndarray
|
| 76 |
+
Integer array where ``nans_cube_chunk[it, ix, iy]`` is the number of
|
| 77 |
+
NaN values in the datacube ``chunk[it:it+Dt, ix:ix+w, iy:iy+h]``.
|
| 78 |
+
"""
|
| 79 |
+
# Compute NaN mask and cumsum along time axis
|
| 80 |
+
nan_mask = np.isnan(chunk).astype(np.int16)
|
| 81 |
+
|
| 82 |
+
# Number of NaN along time
|
| 83 |
+
nans_t = dim_nan_count(nan_mask, dim=0, delta=deltas[0], dim_len=dim_lenghts[0])
|
| 84 |
+
|
| 85 |
+
# Number of NaN along X x T
|
| 86 |
+
nans_xt = dim_nan_count(nans_t, dim=1, delta=deltas[1], dim_len=dim_lenghts[1])
|
| 87 |
+
|
| 88 |
+
# Number of NaN in the datacube (Y x X x T)
|
| 89 |
+
nans_cube_chunk = dim_nan_count(nans_xt, dim=2, delta=deltas[2], dim_len=dim_lenghts[2])
|
| 90 |
+
|
| 91 |
+
return nans_cube_chunk
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def process_chunk(time_range, t_start_idx, data, N_nan, deltas, steps, valid_starts_gap, dc_nan_count):
|
| 95 |
+
"""
|
| 96 |
+
Process a single time chunk and return valid datacube indices.
|
| 97 |
+
|
| 98 |
+
Loads the chunk from the Zarr array, counts NaN values per datacube,
|
| 99 |
+
filters by the maximum allowed NaN threshold and time-continuity
|
| 100 |
+
constraints, and returns the valid ``(t, x, y)`` indices.
|
| 101 |
+
|
| 102 |
+
Parameters
|
| 103 |
+
----------
|
| 104 |
+
time_range : array-like of int
|
| 105 |
+
Two-element sequence ``(start_t, end_t)`` defining the chunk
|
| 106 |
+
boundaries (relative to ``t_start_idx``).
|
| 107 |
+
t_start_idx : int
|
| 108 |
+
Absolute time index corresponding to the dataset start date.
|
| 109 |
+
data : xr.DataArray
|
| 110 |
+
Zarr-backed data array with the ``'RR'`` variable.
|
| 111 |
+
N_nan : int
|
| 112 |
+
Maximum number of NaN values allowed per datacube.
|
| 113 |
+
deltas : tuple of int
|
| 114 |
+
Datacube dimensions ``(Dt, w, h)``.
|
| 115 |
+
steps : tuple of int
|
| 116 |
+
Stride ``(step_T, step_X, step_Y)`` for subsampling valid indices.
|
| 117 |
+
valid_starts_gap : np.ndarray
|
| 118 |
+
Array of valid time-start indices that have no temporal gaps.
|
| 119 |
+
dc_nan_count : callable
|
| 120 |
+
Function to count NaN values per datacube (see :func:`dc_nan_count`).
|
| 121 |
+
|
| 122 |
+
Returns
|
| 123 |
+
-------
|
| 124 |
+
idx_t : np.ndarray
|
| 125 |
+
Absolute time indices of valid datacubes.
|
| 126 |
+
idx_x : np.ndarray
|
| 127 |
+
X indices of valid datacubes.
|
| 128 |
+
idx_y : np.ndarray
|
| 129 |
+
Y indices of valid datacubes.
|
| 130 |
+
"""
|
| 131 |
+
try:
|
| 132 |
+
# start_t: start index of the chunk
|
| 133 |
+
start_t, end_t = time_range
|
| 134 |
+
|
| 135 |
+
# Chunk from Zarr (T, X, Y)
|
| 136 |
+
chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :]
|
| 137 |
+
dim_lenghts = chunk.shape # shape: (T, X, Y)
|
| 138 |
+
|
| 139 |
+
# Compute the number of NaNs in each datacube in chunk
|
| 140 |
+
nans_cube_chunk = dc_nan_count(chunk, deltas, dim_lenghts)
|
| 141 |
+
del chunk
|
| 142 |
+
|
| 143 |
+
# Apply the mask
|
| 144 |
+
valid_mask = nans_cube_chunk <= N_nan
|
| 145 |
+
del nans_cube_chunk
|
| 146 |
+
|
| 147 |
+
# This indices are relative to the chunk
|
| 148 |
+
idx_t_rel, idx_x, idx_y = np.where(valid_mask)
|
| 149 |
+
del valid_mask
|
| 150 |
+
|
| 151 |
+
# Cast to int32
|
| 152 |
+
idx_t_rel = idx_t_rel.astype(np.int32)
|
| 153 |
+
idx_x = idx_x.astype(np.int32)
|
| 154 |
+
idx_y = idx_y.astype(np.int32)
|
| 155 |
+
|
| 156 |
+
# Convert relative time indices
|
| 157 |
+
idx_t = idx_t_rel + start_t
|
| 158 |
+
|
| 159 |
+
# Keep only time indices in valid_starts_gap
|
| 160 |
+
time_mask = np.isin(idx_t, valid_starts_gap)
|
| 161 |
+
idx_t = idx_t[time_mask] + t_start_idx # also convert to absolute index
|
| 162 |
+
idx_x = idx_x[time_mask]
|
| 163 |
+
idx_y = idx_y[time_mask]
|
| 164 |
+
|
| 165 |
+
# Filter datacube indices according to steps
|
| 166 |
+
stride_mask = (idx_t % steps[0] == 0) & (idx_x % steps[1] == 0) & (idx_y % steps[2] == 0)
|
| 167 |
+
idx_x = idx_x[stride_mask]
|
| 168 |
+
idx_y = idx_y[stride_mask]
|
| 169 |
+
idx_t = idx_t[stride_mask]
|
| 170 |
+
|
| 171 |
+
return idx_t, idx_x, idx_y
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"Error processing chunk starting at t={start_t}: {e}", file=sys.stderr)
|
| 175 |
+
sys.exit(1)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def file_writer(output_queue, filename, batch_size=1000):
|
| 179 |
+
"""
|
| 180 |
+
Dedicated writer thread that flushes results to a CSV file in batches.
|
| 181 |
+
|
| 182 |
+
Reads ``(idx_t, idx_x, idx_y)`` tuples from the queue and writes them
|
| 183 |
+
as rows to the output file. Stops when a ``None`` sentinel is received.
|
| 184 |
+
|
| 185 |
+
Parameters
|
| 186 |
+
----------
|
| 187 |
+
output_queue : queue.Queue
|
| 188 |
+
Thread-safe queue providing ``(idx_t, idx_x, idx_y)`` tuples.
|
| 189 |
+
filename : str
|
| 190 |
+
Path to the output CSV file.
|
| 191 |
+
batch_size : int, optional
|
| 192 |
+
Number of rows to buffer before flushing to disk. Default is
|
| 193 |
+
``1000``.
|
| 194 |
+
"""
|
| 195 |
+
with open(filename, "w") as f:
|
| 196 |
+
f.write("t,x,y\n")
|
| 197 |
+
batch = []
|
| 198 |
+
|
| 199 |
+
while True:
|
| 200 |
+
item = output_queue.get()
|
| 201 |
+
|
| 202 |
+
if item is None: # Sentinel value to stop
|
| 203 |
+
# Write remaining batch
|
| 204 |
+
for t, x, y in batch:
|
| 205 |
+
f.write(f"{t},{x},{y}\n")
|
| 206 |
+
break
|
| 207 |
+
|
| 208 |
+
batch.extend(zip(*item, strict=True))
|
| 209 |
+
|
| 210 |
+
if len(batch) >= batch_size:
|
| 211 |
+
for t, x, y in batch:
|
| 212 |
+
f.write(f"{t},{x},{y}\n")
|
| 213 |
+
f.flush()
|
| 214 |
+
batch = []
|
| 215 |
+
|
| 216 |
+
print(f"Results saved to {filename}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# === Parse Arguments ===
|
| 220 |
+
parser = argparse.ArgumentParser(description="Process valid datacubes from Zarr dataset")
|
| 221 |
+
parser.add_argument("zarr_path", help="Path to the Zarr dataset")
|
| 222 |
+
parser.add_argument("--start_date", default=None, type=str, help="Start date (YYYY-MM-DD)")
|
| 223 |
+
parser.add_argument("--end_date", default=None, type=str, help="End date (YYYY-MM-DD)")
|
| 224 |
+
parser.add_argument("--Dt", type=int, default=24, help="Time depth")
|
| 225 |
+
parser.add_argument("--w", type=int, default=256, help="Spatial width")
|
| 226 |
+
parser.add_argument("--h", type=int, default=256, help="Spatial height")
|
| 227 |
+
parser.add_argument("--step_T", type=int, default=3, help="Time step")
|
| 228 |
+
parser.add_argument("--step_X", type=int, default=16, help="X step")
|
| 229 |
+
parser.add_argument("--step_Y", type=int, default=16, help="Y step")
|
| 230 |
+
parser.add_argument("--n_workers", type=int, default=8, help="Number of parallel workers")
|
| 231 |
+
parser.add_argument("--n_nan", type=int, default=10000, help="Maximum NaNs per datacube")
|
| 232 |
+
args = parser.parse_args()
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# === PARAMETERS ===
|
| 236 |
+
Dt = args.Dt # time depth
|
| 237 |
+
w = args.w # x width
|
| 238 |
+
h = args.h # y height
|
| 239 |
+
step_T = args.step_T
|
| 240 |
+
step_X = args.step_X
|
| 241 |
+
step_Y = args.step_Y
|
| 242 |
+
N_nan = args.n_nan # maximum number of nans in each datacube
|
| 243 |
+
|
| 244 |
+
n_workers = args.n_workers
|
| 245 |
+
time_chunk_size = 3 * Dt
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# === Dataset Loading ===
|
| 249 |
+
print(f"Opening Zarr dataset: {args.zarr_path}")
|
| 250 |
+
try:
|
| 251 |
+
# zg = zarr.open(args.zarr_path, mode='r')
|
| 252 |
+
zg = xr.open_zarr(args.zarr_path, decode_times=True)
|
| 253 |
+
RR_full = zg["RR"]
|
| 254 |
+
time_array_full = pd.to_datetime(zg["time"][:])
|
| 255 |
+
|
| 256 |
+
print(f"Full dataset shape: T={RR_full.shape[0]}, X={RR_full.shape[1]}, Y={RR_full.shape[2]}")
|
| 257 |
+
print(f"Full dataset time range: {time_array_full[0]} to {time_array_full[-1]}")
|
| 258 |
+
except Exception as e:
|
| 259 |
+
print(f"Error loading Zarr dataset: {e}")
|
| 260 |
+
sys.exit(1)
|
| 261 |
+
|
| 262 |
+
# Filter the dates
|
| 263 |
+
start_date = pd.to_datetime(args.start_date) if args.start_date else time_array_full[0]
|
| 264 |
+
end_date = pd.to_datetime(args.end_date) if args.end_date else time_array_full[-1]
|
| 265 |
+
|
| 266 |
+
# Find indices corresponding to date range
|
| 267 |
+
mask = (time_array_full >= start_date) & (time_array_full <= end_date)
|
| 268 |
+
valid_indices = np.where(mask)[0]
|
| 269 |
+
|
| 270 |
+
if len(valid_indices) == 0:
|
| 271 |
+
print(f"No data found between {start_date} and {end_date}")
|
| 272 |
+
sys.exit(1)
|
| 273 |
+
|
| 274 |
+
t_start_idx = valid_indices[0]
|
| 275 |
+
t_end_idx = valid_indices[-1] + 1
|
| 276 |
+
|
| 277 |
+
# Slice the data
|
| 278 |
+
size_T = t_end_idx - t_start_idx
|
| 279 |
+
size_X = RR_full.shape[1]
|
| 280 |
+
size_Y = RR_full.shape[2]
|
| 281 |
+
time_array = time_array_full[t_start_idx:t_end_idx]
|
| 282 |
+
|
| 283 |
+
print(f"Filtered dataset shape: T={size_T}, X={size_X}, Y={size_Y}")
|
| 284 |
+
print(f"Filtered dataset time range: {time_array[0]} to {time_array[-1]}")
|
| 285 |
+
|
| 286 |
+
# Calculate maximum valid indices
|
| 287 |
+
max_x = size_X - w + 1
|
| 288 |
+
max_y = size_Y - h + 1
|
| 289 |
+
max_t = size_T - Dt + 1
|
| 290 |
+
|
| 291 |
+
# === Time Continuity ===
|
| 292 |
+
print("Checking time continuity...")
|
| 293 |
+
try:
|
| 294 |
+
expected_step = pd.Timedelta("00:05:00")
|
| 295 |
+
time_diffs = time_array[1:] - time_array[:-1]
|
| 296 |
+
gaps = (time_diffs != expected_step).astype(int)
|
| 297 |
+
|
| 298 |
+
# Check continuity for windows of size Dt
|
| 299 |
+
window_sum = np.convolve(gaps, np.ones(Dt - 1, dtype=int), mode="valid")
|
| 300 |
+
|
| 301 |
+
# Find valid starting times: continuous windows at T_step intervals
|
| 302 |
+
valid_starts_gap = np.where(window_sum == 0)[0]
|
| 303 |
+
print(f"Found {len(valid_starts_gap)} valid time starts without gaps")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"Error in time continuity check: {e}", file=sys.stderr)
|
| 308 |
+
sys.exit(1)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# === Chunked NaN Processing ===
|
| 312 |
+
# Memory per chunk (x 4 because float32)
|
| 313 |
+
estimated_chunk_memory_gb = (time_chunk_size * size_X * size_Y * 4) / (1024**3)
|
| 314 |
+
print(f"Estimated memory per chunk: {estimated_chunk_memory_gb:.2f} GB")
|
| 315 |
+
print(f"Estimated total memory: {(estimated_chunk_memory_gb * n_workers):.2f} GB")
|
| 316 |
+
|
| 317 |
+
# Process time in chunks with overlap Dt
|
| 318 |
+
t_starts = np.arange(0, max_t, time_chunk_size)
|
| 319 |
+
t_ends = np.minimum(t_starts + time_chunk_size + Dt - 1, size_T)
|
| 320 |
+
t_pairs = np.stack((t_starts, t_ends), axis=1)
|
| 321 |
+
|
| 322 |
+
# Create partial function with fixed parameters
|
| 323 |
+
process_chunk_partial = partial(
|
| 324 |
+
process_chunk,
|
| 325 |
+
t_start_idx=t_start_idx,
|
| 326 |
+
data=RR_full,
|
| 327 |
+
N_nan=N_nan,
|
| 328 |
+
deltas=(Dt, w, h),
|
| 329 |
+
steps=(step_T, step_X, step_Y),
|
| 330 |
+
valid_starts_gap=valid_starts_gap,
|
| 331 |
+
dc_nan_count=dc_nan_count,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# Chek if file exists
|
| 335 |
+
output_file = f"valid_datacubes_{args.start_date}-{args.end_date}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{N_nan}.csv"
|
| 336 |
+
if os.path.exists(output_file):
|
| 337 |
+
response = input(f"File {output_file} already exists. Overwrite? (y/n): ")
|
| 338 |
+
if response.lower() != "y":
|
| 339 |
+
print("Exiting without overwriting.")
|
| 340 |
+
sys.exit(0)
|
| 341 |
+
else:
|
| 342 |
+
print(f"Overwriting {output_file}...")
|
| 343 |
+
|
| 344 |
+
# Start writer thread
|
| 345 |
+
output_queue = Queue(maxsize=100)
|
| 346 |
+
writer_thread = Thread(target=file_writer, args=(output_queue, output_file, 1000))
|
| 347 |
+
writer_thread.daemon = False
|
| 348 |
+
writer_thread.start()
|
| 349 |
+
|
| 350 |
+
# Process chunks in parallel
|
| 351 |
+
with Pool(n_workers) as pool:
|
| 352 |
+
for _i, hits in enumerate(
|
| 353 |
+
tqdm(pool.imap(process_chunk_partial, t_pairs, chunksize=1), total=len(t_starts), desc="Processing time chunks")
|
| 354 |
+
):
|
| 355 |
+
output_queue.put(hits)
|
| 356 |
+
|
| 357 |
+
# Signal writer thread to stop
|
| 358 |
+
output_queue.put(None)
|
| 359 |
+
writer_thread.join()
|
| 360 |
+
|
| 361 |
+
print(f"Done in {time.time() - START}s.")
|
| 362 |
+
sys.exit(0)
|
importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
importance_sampler/output/sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000_metadata.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"csv": "valid_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
|
| 3 |
+
"zarr": "/leonardo_scratch/fast/AI4W_forecast_2/italian-radar-dpc-sri.zarr",
|
| 4 |
+
"file": "sampled_datacubes_2021-01-01-2025-12-11_24x256x256_3x16x16_10000.csv",
|
| 5 |
+
"start_date": "2021-01-01",
|
| 6 |
+
"end_date": "2025-12-11",
|
| 7 |
+
"Dt": 24,
|
| 8 |
+
"w": 256,
|
| 9 |
+
"h": 256,
|
| 10 |
+
"step_T": 3,
|
| 11 |
+
"step_X": 16,
|
| 12 |
+
"step_Y": 16,
|
| 13 |
+
"N_nan": 10000,
|
| 14 |
+
"N_rand": 1,
|
| 15 |
+
"n_workers": 112,
|
| 16 |
+
"qmin": 0.0001,
|
| 17 |
+
"m": 0.1,
|
| 18 |
+
"s": 1,
|
| 19 |
+
"seed": null,
|
| 20 |
+
"timestamp": "2025-12-13 21:52:41"
|
| 21 |
+
}
|
importance_sampler/sample_valid_datacubes.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from functools import partial
|
| 7 |
+
from multiprocessing import Pool
|
| 8 |
+
from queue import Queue
|
| 9 |
+
from threading import Thread
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import xarray as xr
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
START = time.time()
|
| 17 |
+
SEED = None # for reproducibility
|
| 18 |
+
|
| 19 |
+
# === Parse Arguments ===
|
| 20 |
+
parser = argparse.ArgumentParser(description="Importance sampler of the valid datacubes (after the nan filtering)")
|
| 21 |
+
parser.add_argument("zarr_path", help="Path to the Zarr dataset")
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"csv_path", help="Path to the CSV with the valid datacube coordinates (created by the nan filtering)"
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument("--q_min", type=float, default=1e-4, help="Minimum selection probability (default 1e-4)")
|
| 26 |
+
parser.add_argument("--s", type=float, default=1, help="Denominator in the exponential")
|
| 27 |
+
parser.add_argument("--m", type=float, default=0.1, help="Factor weighting the mean rescaled rain rate (dafault 0.1)")
|
| 28 |
+
parser.add_argument("--n_workers", type=int, default=8, help="Number of parallel workers (default 8)")
|
| 29 |
+
parser.add_argument("--n_rand", type=int, default=1, help="Number of random sampling of each datacube (dafaut 1)")
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
# === PARAMETERS ===
|
| 33 |
+
s = args.s
|
| 34 |
+
qmin = args.q_min
|
| 35 |
+
m = args.m
|
| 36 |
+
|
| 37 |
+
n_workers = args.n_workers # number of parallel workers
|
| 38 |
+
N_rand = args.n_rand # number of random numbers per region
|
| 39 |
+
chunksize = 16000 # = 500 CSV lines per workers
|
| 40 |
+
|
| 41 |
+
# Parameters from CSV filename
|
| 42 |
+
name_arr = args.csv_path.split("_")
|
| 43 |
+
dates = name_arr[2]
|
| 44 |
+
start_date = "-".join(dates.split("-")[0:3])
|
| 45 |
+
end_date = "-".join(dates.split("-")[3:])
|
| 46 |
+
Dt, w, h = name_arr[3].split("x")
|
| 47 |
+
step_T, step_X, step_Y = name_arr[4].split("x")
|
| 48 |
+
N_nan = name_arr[5][:-4]
|
| 49 |
+
|
| 50 |
+
# Casting
|
| 51 |
+
Dt, w, h = int(Dt), int(w), int(h)
|
| 52 |
+
step_T, step_X, step_Y = int(step_T), int(step_X), int(step_Y)
|
| 53 |
+
N_nan = int(N_nan)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# === FUNCTIONS ===
|
| 57 |
+
def acceptance_probability(data):
|
| 58 |
+
"""
|
| 59 |
+
Calculate the acceptance probability for importance sampling.
|
| 60 |
+
|
| 61 |
+
The probability is ``min(1, q_min + m * mean(data))``, where ``q_min``
|
| 62 |
+
and ``m`` are module-level parameters.
|
| 63 |
+
|
| 64 |
+
Parameters
|
| 65 |
+
----------
|
| 66 |
+
data : np.ndarray
|
| 67 |
+
Rescaled rain rate data for a single datacube.
|
| 68 |
+
|
| 69 |
+
Returns
|
| 70 |
+
-------
|
| 71 |
+
q : float
|
| 72 |
+
Acceptance probability in ``[q_min, 1]``.
|
| 73 |
+
"""
|
| 74 |
+
return min(1.0, qmin + m * np.nanmean(data))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def process_datacube(coord, RR, N_rand, seed, acceptance_probability):
|
| 78 |
+
"""
|
| 79 |
+
Process a single space-time region for importance sampling.
|
| 80 |
+
|
| 81 |
+
Loads the datacube, rescales rain rate, computes an acceptance
|
| 82 |
+
probability, and performs ``N_rand`` random acceptance trials.
|
| 83 |
+
|
| 84 |
+
Parameters
|
| 85 |
+
----------
|
| 86 |
+
coord : array-like of int
|
| 87 |
+
Three-element sequence ``(it, ix, iy)`` specifying the datacube
|
| 88 |
+
origin.
|
| 89 |
+
RR : xr.DataArray
|
| 90 |
+
Rain rate data array from the Zarr dataset.
|
| 91 |
+
N_rand : int
|
| 92 |
+
Number of random acceptance trials per datacube.
|
| 93 |
+
seed : int or None
|
| 94 |
+
Random seed for reproducibility. If ``None``, non-deterministic.
|
| 95 |
+
acceptance_probability : callable
|
| 96 |
+
Function that takes a data array and returns a probability in
|
| 97 |
+
``[0, 1]``.
|
| 98 |
+
|
| 99 |
+
Returns
|
| 100 |
+
-------
|
| 101 |
+
hits : list of tuple of int
|
| 102 |
+
List of accepted ``(it, ix, iy)`` tuples (may contain duplicates
|
| 103 |
+
if accepted multiple times).
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
it, ix, iy = coord
|
| 108 |
+
time_slice = slice(it, it + Dt)
|
| 109 |
+
x_slice = slice(ix, ix + w)
|
| 110 |
+
y_slice = slice(iy, iy + h)
|
| 111 |
+
|
| 112 |
+
# Load data from Zarr
|
| 113 |
+
data = RR[time_slice, x_slice, y_slice]
|
| 114 |
+
data = 1 - np.exp(-data / s)
|
| 115 |
+
|
| 116 |
+
# Calculate acceptance probability
|
| 117 |
+
q = acceptance_probability(data)
|
| 118 |
+
|
| 119 |
+
# Generate random numbers with seed for reproducibility
|
| 120 |
+
rng = np.random.default_rng(seed)
|
| 121 |
+
random_numbers = rng.random(N_rand)
|
| 122 |
+
accepted_count = np.sum(random_numbers <= q)
|
| 123 |
+
|
| 124 |
+
# Return accepted hits
|
| 125 |
+
hits = [(it, ix, iy)] * accepted_count
|
| 126 |
+
return hits
|
| 127 |
+
except Exception as e:
|
| 128 |
+
print(f"Error processing region ({it}, {ix}, {iy}): {e}", file=sys.stderr)
|
| 129 |
+
return []
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def file_writer(output_queue, filename, batch_size=1000):
|
| 133 |
+
"""
|
| 134 |
+
Dedicated writer thread that flushes results to a CSV file in batches.
|
| 135 |
+
|
| 136 |
+
Reads lists of ``(t, x, y)`` tuples from the queue and writes them as
|
| 137 |
+
CSV rows. Stops when a ``None`` sentinel is received.
|
| 138 |
+
|
| 139 |
+
Parameters
|
| 140 |
+
----------
|
| 141 |
+
output_queue : queue.Queue
|
| 142 |
+
Thread-safe queue providing lists of ``(t, x, y)`` tuples.
|
| 143 |
+
filename : str
|
| 144 |
+
Path to the output CSV file.
|
| 145 |
+
batch_size : int, optional
|
| 146 |
+
Number of rows to buffer before flushing to disk. Default is
|
| 147 |
+
``1000``.
|
| 148 |
+
"""
|
| 149 |
+
with open(filename, "w") as f:
|
| 150 |
+
f.write("t,x,y\n")
|
| 151 |
+
batch = []
|
| 152 |
+
|
| 153 |
+
while True:
|
| 154 |
+
item = output_queue.get()
|
| 155 |
+
|
| 156 |
+
if item is None: # Sentinel value to stop
|
| 157 |
+
# Write remaining batch
|
| 158 |
+
for t, x, y in batch:
|
| 159 |
+
f.write(f"{t},{x},{y}\n")
|
| 160 |
+
break
|
| 161 |
+
|
| 162 |
+
batch.extend(item)
|
| 163 |
+
|
| 164 |
+
if len(batch) >= batch_size:
|
| 165 |
+
for t, x, y in batch:
|
| 166 |
+
f.write(f"{t},{x},{y}\n")
|
| 167 |
+
f.flush()
|
| 168 |
+
batch = []
|
| 169 |
+
|
| 170 |
+
print(f"Results saved to {filename}")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# === Dataset Loading ===
|
| 174 |
+
print(f"Opening Zarr dataset: {args.zarr_path}")
|
| 175 |
+
try:
|
| 176 |
+
zg = xr.open_zarr(args.zarr_path, mode="r")
|
| 177 |
+
RR = zg["RR"]
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"Error loading Zarr dataset: {e}")
|
| 180 |
+
sys.exit(1)
|
| 181 |
+
|
| 182 |
+
# Chek if file exists
|
| 183 |
+
output_file = f"sampled_datacubes_{start_date}-{end_date}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{N_nan}.csv"
|
| 184 |
+
if os.path.exists(output_file):
|
| 185 |
+
response = input(f"File {output_file} already exists. Overwrite? (y/n): ")
|
| 186 |
+
if response.lower() != "y":
|
| 187 |
+
print("Exiting without overwriting.")
|
| 188 |
+
sys.exit(0)
|
| 189 |
+
else:
|
| 190 |
+
print(f"Overwriting {output_file}...")
|
| 191 |
+
|
| 192 |
+
# Start writer thread
|
| 193 |
+
output_queue = Queue(maxsize=100)
|
| 194 |
+
writer_thread = Thread(target=file_writer, args=(output_queue, output_file, 1000))
|
| 195 |
+
writer_thread.daemon = False
|
| 196 |
+
writer_thread.start()
|
| 197 |
+
|
| 198 |
+
# save metadata
|
| 199 |
+
metadata = {
|
| 200 |
+
"csv": args.csv_path,
|
| 201 |
+
"zarr": args.zarr_path,
|
| 202 |
+
"file": output_file,
|
| 203 |
+
"start_date": start_date,
|
| 204 |
+
"end_date": end_date,
|
| 205 |
+
"Dt": Dt,
|
| 206 |
+
"w": w,
|
| 207 |
+
"h": h,
|
| 208 |
+
"step_T": step_T,
|
| 209 |
+
"step_X": step_X,
|
| 210 |
+
"step_Y": step_Y,
|
| 211 |
+
"N_nan": N_nan,
|
| 212 |
+
"N_rand": N_rand,
|
| 213 |
+
"n_workers": n_workers,
|
| 214 |
+
"qmin": qmin,
|
| 215 |
+
"m": m,
|
| 216 |
+
"s": s,
|
| 217 |
+
"seed": SEED,
|
| 218 |
+
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
| 219 |
+
}
|
| 220 |
+
metadata_filename = output_file.replace(".csv", "_metadata.json")
|
| 221 |
+
with open(metadata_filename, "w") as f:
|
| 222 |
+
json.dump(metadata, f, indent=2)
|
| 223 |
+
print(f"Saved run metadata to {metadata_filename}")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# === IMPORTANCE SAMPLING ===
|
| 227 |
+
# Create partial function with fixed parameters
|
| 228 |
+
process_datacube_partial = partial(
|
| 229 |
+
process_datacube, RR=RR, N_rand=N_rand, seed=SEED, acceptance_probability=acceptance_probability
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
pool_chunksize = max(1, chunksize // n_workers)
|
| 233 |
+
|
| 234 |
+
with Pool(n_workers) as pool:
|
| 235 |
+
pbar = tqdm(desc="Processing CSV chunks")
|
| 236 |
+
|
| 237 |
+
# Loading the CSV by chunks
|
| 238 |
+
for chunk in pd.read_csv(
|
| 239 |
+
args.csv_path,
|
| 240 |
+
usecols=["t", "x", "y"],
|
| 241 |
+
dtype={"t": "int32", "x": "int32", "y": "int32"},
|
| 242 |
+
engine="c",
|
| 243 |
+
chunksize=chunksize,
|
| 244 |
+
):
|
| 245 |
+
for hits in pool.imap(process_datacube_partial, chunk.values, chunksize=pool_chunksize):
|
| 246 |
+
if hits:
|
| 247 |
+
output_queue.put(hits)
|
| 248 |
+
pbar.update(1)
|
| 249 |
+
|
| 250 |
+
pbar.close()
|
| 251 |
+
|
| 252 |
+
# Signal writer thread to stop
|
| 253 |
+
output_queue.put(None)
|
| 254 |
+
writer_thread.join()
|
| 255 |
+
|
| 256 |
+
print(f"Done in {time.time() - START}s.")
|
| 257 |
+
sys.exit(0)
|
notebooks/test_pretrained_model.ipynb
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "a5d812be",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Test pretrained model\n",
|
| 9 |
+
"This notebook tests the pretrained model on a single datacube taken from the radar dataset (https://arcodatahub.com/datasets/datasets/italian-radar-dpc-sri.zarr)."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"id": "19c9a668",
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"outputs": [],
|
| 18 |
+
"source": "import matplotlib.animation as animation\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pysteps.visualization.precipfields as pysteps_plot\nimport xarray as xr\nfrom IPython.display import HTML\n\nfrom convgru_ensemble import RadarLightningModel"
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "markdown",
|
| 22 |
+
"id": "3bd5ff98",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"source": "# Load radar data\nWe first load a sample of the italian radar dataset provided in the `examples/` folder."
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"id": "008fb6ae",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": "radar = xr.open_dataarray('../examples/sample_data.nc')\nradar"
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "markdown",
|
| 36 |
+
"id": "92f738bc",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"This contains 18 sequences of radar images on the whole Italy, from the 28th to the 29th of October 2024. This is one of the most intense precipitation on Italy during 2024."
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "fcb5529a",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"# Create figure\n",
|
| 50 |
+
"fig, ax = plt.subplots(figsize=(4, 4.5))\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"def update(frame):\n",
|
| 53 |
+
" ax.clear()\n",
|
| 54 |
+
" data = radar.isel(time=frame)\n",
|
| 55 |
+
" pysteps_plot.plot_precip_field(data.values, ax=ax, colorbar=False)\n",
|
| 56 |
+
" ax.set_title(f'Precipitation - {data.time.values}')\n",
|
| 57 |
+
" return ax,\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"# Create animation\n",
|
| 60 |
+
"ani = animation.FuncAnimation(fig, update, frames=len(radar.time),\n",
|
| 61 |
+
" interval=500, blit=False, repeat=True)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"# Display in notebook\n",
|
| 64 |
+
"display(HTML(ani.to_jshtml()))\n",
|
| 65 |
+
"plt.close()"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "markdown",
|
| 70 |
+
"id": "e806b20a",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"source": [
|
| 73 |
+
"# Initialize the model \n",
|
| 74 |
+
"Initialize the model and load the weights from the checkpoint. You can change the number of future steps (forecast steps) and the ensemble size (ensemble_size). The other hyperparameters are fixed."
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "code",
|
| 79 |
+
"execution_count": null,
|
| 80 |
+
"id": "c8ce3ff4",
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"source": "# Set model's parameters\nforecast_steps = 12\nensemble_size = 10\n\n# Initialize the model and load the checkpoint\n# Option A: Load from local checkpoint\nmodel = RadarLightningModel.from_checkpoint(checkpoint_path=\"../checkpoints/ConvGRU-CRPS_6past_12fut.ckpt\")\n\n# Option B: Load from HuggingFace Hub\n# model = RadarLightningModel.from_pretrained(\"it4lia/irene\")"
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"id": "6f6fc1f9",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"# Run the inference\n",
|
| 91 |
+
"We can run the inference and plot the forecast"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"cell_type": "code",
|
| 96 |
+
"execution_count": null,
|
| 97 |
+
"id": "d6d340b2",
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"outputs": [],
|
| 100 |
+
"source": "# Past and future steps\npast_steps = 7\nforecast_steps = 12\npast, future = radar[:past_steps], radar[past_steps:past_steps+forecast_steps]\n\n# Predict the future rainrate\npred = model.predict(past, forecast_steps, ensemble_size=2)"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"cell_type": "markdown",
|
| 104 |
+
"id": "f0705829",
|
| 105 |
+
"metadata": {},
|
| 106 |
+
"source": [
|
| 107 |
+
"### Plot the forecast"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": null,
|
| 113 |
+
"id": "0b0646f9",
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"outputs": [],
|
| 116 |
+
"source": [
|
| 117 |
+
"# Ensemble mean\n",
|
| 118 |
+
"ensemble_mean = np.nanmean(pred, axis=0)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# Initialize plots\n",
|
| 121 |
+
"fig, axs = plt.subplots(1, 4, figsize=(16,4.5))\n",
|
| 122 |
+
"row_labels = ['Ground Truth', 'Ensemble Mean', 'Member 1', 'Member 2']\n",
|
| 123 |
+
"data_sources = [future, ensemble_mean, pred[0], pred[1]]\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"# Plot initial frame\n",
|
| 126 |
+
"for i, (ax, label, data) in enumerate(zip(axs, row_labels, data_sources)):\n",
|
| 127 |
+
" plot_precip_field(data[0], ax=ax, units='mm/h', colorscale='pysteps')\n",
|
| 128 |
+
" ax.set_title(label, fontsize=14)\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"plt.tight_layout()\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"# Animation function\n",
|
| 133 |
+
"def update(frame):\n",
|
| 134 |
+
" for i, (ax, data) in enumerate(zip(axs, data_sources)):\n",
|
| 135 |
+
" ax.clear()\n",
|
| 136 |
+
" plot_precip_field(data[frame], ax=ax, units='mm/h', colorscale='pysteps', colorbar=False)\n",
|
| 137 |
+
" ax.set_title(f'{row_labels[i]} - Step {frame}', fontsize=14)\n",
|
| 138 |
+
" return axs\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"# Create animation\n",
|
| 141 |
+
"anim = FuncAnimation(fig, update, frames=forecast_steps, interval=500, blit=False)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"# Display\n",
|
| 144 |
+
"HTML(anim.to_jshtml())\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"display(HTML(anim.to_jshtml()))\n",
|
| 147 |
+
"plt.close()"
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"metadata": {
|
| 152 |
+
"kernelspec": {
|
| 153 |
+
"display_name": ".venv",
|
| 154 |
+
"language": "python",
|
| 155 |
+
"name": "python3"
|
| 156 |
+
},
|
| 157 |
+
"language_info": {
|
| 158 |
+
"codemirror_mode": {
|
| 159 |
+
"name": "ipython",
|
| 160 |
+
"version": 3
|
| 161 |
+
},
|
| 162 |
+
"file_extension": ".py",
|
| 163 |
+
"mimetype": "text/x-python",
|
| 164 |
+
"name": "python",
|
| 165 |
+
"nbconvert_exporter": "python",
|
| 166 |
+
"pygments_lexer": "ipython3",
|
| 167 |
+
"version": "3.13.12"
|
| 168 |
+
}
|
| 169 |
+
},
|
| 170 |
+
"nbformat": 4,
|
| 171 |
+
"nbformat_minor": 5
|
| 172 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=69", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "ConvGRU-Ensemble"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Ensemble precipitation nowcasting using Convolutional GRU networks"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.13"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"absl-py>=2.3.1",
|
| 13 |
+
"einops>=0.8.1",
|
| 14 |
+
"etils>=1.13.0",
|
| 15 |
+
"fiddle>=0.3.0",
|
| 16 |
+
"fire>=0.7.1",
|
| 17 |
+
"importlib-resources>=6.5.2",
|
| 18 |
+
"jupyterlab>=4.5.5",
|
| 19 |
+
"matplotlib>=3.10.8",
|
| 20 |
+
"numpy>=2.3.5",
|
| 21 |
+
"pandas>=2.3.3",
|
| 22 |
+
"pysteps>=1.19.0",
|
| 23 |
+
"pytorch-lightning>=2.6.0",
|
| 24 |
+
"pyyaml>=6.0.3",
|
| 25 |
+
"tqdm>=4.67.1",
|
| 26 |
+
"tensorboard>=2.20.0",
|
| 27 |
+
"torch>=2.9.1",
|
| 28 |
+
"torchvision>=0.24.1",
|
| 29 |
+
"xarray>=2025.12.0",
|
| 30 |
+
"zarr>=3.1.5",
|
| 31 |
+
"huggingface_hub>=0.27.0",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[project.optional-dependencies]
|
| 35 |
+
serve = [
|
| 36 |
+
"fastapi>=0.115.0",
|
| 37 |
+
"uvicorn>=0.34.0",
|
| 38 |
+
"python-multipart>=0.0.18",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[project.scripts]
|
| 42 |
+
convgru-ensemble = "convgru_ensemble.cli:main"
|
| 43 |
+
|
| 44 |
+
[dependency-groups]
|
| 45 |
+
dev = [
|
| 46 |
+
"pytest>=8.3.5",
|
| 47 |
+
"pre-commit>=4.0.0",
|
| 48 |
+
"ruff>=0.9.10",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
[tool.setuptools]
|
| 52 |
+
packages = ["convgru_ensemble"]
|
| 53 |
+
|
| 54 |
+
[tool.ruff]
|
| 55 |
+
target-version = "py313"
|
| 56 |
+
line-length = 120
|
| 57 |
+
extend-exclude = ["notebooks/", "importance_sampler/"]
|
| 58 |
+
|
| 59 |
+
[tool.ruff.lint]
|
| 60 |
+
select = [
|
| 61 |
+
"E", # pycodestyle errors
|
| 62 |
+
"W", # pycodestyle warnings
|
| 63 |
+
"F", # pyflakes
|
| 64 |
+
"I", # isort
|
| 65 |
+
"UP", # pyupgrade
|
| 66 |
+
"B", # flake8-bugbear
|
| 67 |
+
"SIM", # flake8-simplify
|
| 68 |
+
"NPY", # numpy-specific rules
|
| 69 |
+
]
|
| 70 |
+
ignore = [
|
| 71 |
+
"E501", # line too long (handled by formatter)
|
| 72 |
+
"SIM108", # ternary operator (readability preference)
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
[tool.ruff.lint.isort]
|
| 76 |
+
known-first-party = ["convgru_ensemble"]
|
scripts/upload_to_hub.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""Upload a trained ConvGRU-Ensemble model to HuggingFace Hub."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
parser = argparse.ArgumentParser(description="Upload model to HuggingFace Hub")
|
| 10 |
+
parser.add_argument("--checkpoint", required=True, help="Path to .ckpt checkpoint file")
|
| 11 |
+
parser.add_argument("--repo-id", required=True, help="HuggingFace repo ID (e.g., it4lia/irene)")
|
| 12 |
+
parser.add_argument("--model-card", default=None, help="Path to model card markdown file")
|
| 13 |
+
parser.add_argument("--private", action="store_true", help="Create a private repository")
|
| 14 |
+
args = parser.parse_args()
|
| 15 |
+
|
| 16 |
+
# Default model card
|
| 17 |
+
model_card = args.model_card
|
| 18 |
+
if model_card is None:
|
| 19 |
+
default_card = Path(__file__).parent.parent / "MODEL_CARD.md"
|
| 20 |
+
if default_card.exists():
|
| 21 |
+
model_card = str(default_card)
|
| 22 |
+
|
| 23 |
+
from convgru_ensemble.hub import push_to_hub
|
| 24 |
+
|
| 25 |
+
url = push_to_hub(
|
| 26 |
+
checkpoint_path=args.checkpoint,
|
| 27 |
+
repo_id=args.repo_id,
|
| 28 |
+
model_card_path=model_card,
|
| 29 |
+
private=args.private,
|
| 30 |
+
)
|
| 31 |
+
print(f"Model uploaded: {url}")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
from convgru_ensemble.lightning_model import RadarLightningModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@pytest.fixture
|
| 8 |
+
def model_small():
|
| 9 |
+
"""Small model with num_blocks=2 for fast testing."""
|
| 10 |
+
return RadarLightningModel(
|
| 11 |
+
input_channels=1,
|
| 12 |
+
num_blocks=2,
|
| 13 |
+
forecast_steps=2,
|
| 14 |
+
ensemble_size=2,
|
| 15 |
+
noisy_decoder=False,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@pytest.fixture
|
| 20 |
+
def sample_rain_rate():
|
| 21 |
+
"""Synthetic rain rate data of shape (T, H, W)."""
|
| 22 |
+
rng = np.random.default_rng(42)
|
| 23 |
+
return rng.random((4, 16, 16), dtype=np.float32) * 10.0 # 0-10 mm/h
|
tests/test_hub.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock, patch
|
| 2 |
+
|
| 3 |
+
from convgru_ensemble.hub import from_pretrained
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@patch("convgru_ensemble.hub.hf_hub_download")
|
| 7 |
+
@patch("convgru_ensemble.hub.RadarLightningModel", create=True)
|
| 8 |
+
def test_from_pretrained_calls_hf_hub_download(mock_model_cls, mock_download):
|
| 9 |
+
mock_download.return_value = "/tmp/cached/model.ckpt"
|
| 10 |
+
|
| 11 |
+
# Patch the import inside the function
|
| 12 |
+
with patch("convgru_ensemble.lightning_model.RadarLightningModel") as mock_cls:
|
| 13 |
+
mock_cls.from_checkpoint.return_value = MagicMock()
|
| 14 |
+
from_pretrained("it4lia/irene", filename="model.ckpt", device="cpu")
|
| 15 |
+
|
| 16 |
+
mock_download.assert_called_once_with(repo_id="it4lia/irene", filename="model.ckpt")
|
tests/test_inference.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import xarray as xr
|
| 3 |
+
|
| 4 |
+
from convgru_ensemble.lightning_model import RadarLightningModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_inference_on_sample_data():
|
| 8 |
+
"""End-to-end inference on the sample NetCDF with a freshly initialized model."""
|
| 9 |
+
ds = xr.open_dataset("examples/sample_data.nc")
|
| 10 |
+
rain = ds["RR"].values # (54, 1400, 1200)
|
| 11 |
+
|
| 12 |
+
# Use 6 past frames, full spatial extent
|
| 13 |
+
past = rain[:6].astype(np.float32)
|
| 14 |
+
_, H, W = past.shape
|
| 15 |
+
|
| 16 |
+
model = RadarLightningModel(
|
| 17 |
+
input_channels=1,
|
| 18 |
+
num_blocks=3,
|
| 19 |
+
forecast_steps=4,
|
| 20 |
+
ensemble_size=1,
|
| 21 |
+
noisy_decoder=False,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
preds = model.predict(past, forecast_steps=4, ensemble_size=1)
|
| 25 |
+
|
| 26 |
+
assert preds.shape == (1, 4, H, W)
|
| 27 |
+
assert np.isfinite(preds).all()
|
| 28 |
+
assert preds.dtype == np.float64 or preds.dtype == np.float32
|
tests/test_lightning_model.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from convgru_ensemble.lightning_model import RadarLightningModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_predict_handles_unpadded_inputs():
|
| 8 |
+
model = RadarLightningModel(
|
| 9 |
+
input_channels=1,
|
| 10 |
+
num_blocks=1,
|
| 11 |
+
forecast_steps=2,
|
| 12 |
+
ensemble_size=1,
|
| 13 |
+
noisy_decoder=False,
|
| 14 |
+
)
|
| 15 |
+
past = np.zeros((4, 8, 8), dtype=np.float32)
|
| 16 |
+
|
| 17 |
+
preds = model.predict(past, forecast_steps=2, ensemble_size=1)
|
| 18 |
+
|
| 19 |
+
assert preds.shape == (1, 2, 8, 8)
|
| 20 |
+
assert np.isfinite(preds).all()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
|
| 24 |
+
captured = {}
|
| 25 |
+
|
| 26 |
+
def fake_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
|
| 27 |
+
captured["checkpoint_path"] = checkpoint_path
|
| 28 |
+
captured["map_location"] = map_location
|
| 29 |
+
captured["strict"] = strict
|
| 30 |
+
captured["weights_only"] = weights_only
|
| 31 |
+
return "loaded-model"
|
| 32 |
+
|
| 33 |
+
monkeypatch.setattr(RadarLightningModel, "load_from_checkpoint", classmethod(fake_loader))
|
| 34 |
+
|
| 35 |
+
loaded = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")
|
| 36 |
+
|
| 37 |
+
assert loaded == "loaded-model"
|
| 38 |
+
assert captured["checkpoint_path"] == "/tmp/model.ckpt"
|
| 39 |
+
assert isinstance(captured["map_location"], torch.device)
|
| 40 |
+
assert captured["map_location"].type == "cpu"
|
| 41 |
+
assert captured["strict"] is True
|
tests/test_losses.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from convgru_ensemble.losses import CRPS, MaskedLoss, build_loss
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_crps_reduces_to_scalar():
|
| 7 |
+
loss_fn = CRPS(reduction="mean")
|
| 8 |
+
preds = torch.randn(2, 4, 5, 8, 8) # (B, T, M, H, W)
|
| 9 |
+
target = torch.randn(2, 4, 1, 8, 8) # (B, T, 1, H, W)
|
| 10 |
+
loss = loss_fn(preds, target)
|
| 11 |
+
assert loss.dim() == 0 # scalar
|
| 12 |
+
assert loss.item() > 0 or loss.item() == 0
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_masked_loss_ignores_masked_pixels():
|
| 16 |
+
base_loss = torch.nn.MSELoss(reduction="none")
|
| 17 |
+
loss_fn = MaskedLoss(base_loss, reduction="mean")
|
| 18 |
+
preds = torch.ones(1, 2, 1, 4, 4)
|
| 19 |
+
target = torch.zeros(1, 2, 1, 4, 4)
|
| 20 |
+
# Mask out everything — loss should be 0
|
| 21 |
+
mask = torch.zeros(1, 1, 1, 4, 4)
|
| 22 |
+
loss = loss_fn(preds, target, mask)
|
| 23 |
+
assert loss.item() == 0.0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_build_loss_by_name():
|
| 27 |
+
criterion = build_loss("crps", loss_params=None, masked_loss=False)
|
| 28 |
+
assert isinstance(criterion, CRPS)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_build_loss_masked():
|
| 32 |
+
criterion = build_loss("mse", loss_params=None, masked_loss=True)
|
| 33 |
+
assert isinstance(criterion, MaskedLoss)
|
tests/test_model.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from convgru_ensemble.model import EncoderDecoder
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_forward_single_output_shape():
|
| 7 |
+
model = EncoderDecoder(channels=1, num_blocks=2)
|
| 8 |
+
x = torch.randn(2, 4, 1, 16, 16)
|
| 9 |
+
out = model(x, steps=3, noisy_decoder=False, ensemble_size=1)
|
| 10 |
+
assert out.shape == (2, 3, 1, 16, 16)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_forward_ensemble_output_shape():
|
| 14 |
+
model = EncoderDecoder(channels=1, num_blocks=2)
|
| 15 |
+
x = torch.randn(2, 4, 1, 16, 16)
|
| 16 |
+
out = model(x, steps=3, noisy_decoder=False, ensemble_size=5)
|
| 17 |
+
assert out.shape == (2, 3, 5, 16, 16)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_forward_different_num_blocks():
|
| 21 |
+
model = EncoderDecoder(channels=1, num_blocks=3)
|
| 22 |
+
x = torch.randn(1, 4, 1, 32, 32)
|
| 23 |
+
out = model(x, steps=2, ensemble_size=1)
|
| 24 |
+
assert out.shape == (1, 2, 1, 32, 32)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_noisy_decoder_produces_different_outputs():
|
| 28 |
+
model = EncoderDecoder(channels=1, num_blocks=2)
|
| 29 |
+
model.eval()
|
| 30 |
+
x = torch.randn(1, 4, 1, 16, 16)
|
| 31 |
+
out1 = model(x, steps=2, noisy_decoder=True, ensemble_size=1)
|
| 32 |
+
out2 = model(x, steps=2, noisy_decoder=True, ensemble_size=1)
|
| 33 |
+
# Noisy decoder should produce different outputs (with very high probability)
|
| 34 |
+
assert not torch.allclose(out1, out2)
|
tests/test_serve.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from unittest.mock import MagicMock, patch
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pytest
|
| 6 |
+
import xarray as xr
|
| 7 |
+
|
| 8 |
+
from convgru_ensemble.lightning_model import RadarLightningModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def mock_model():
|
| 13 |
+
model = MagicMock(spec=RadarLightningModel)
|
| 14 |
+
model.hparams = MagicMock()
|
| 15 |
+
model.hparams.input_channels = 1
|
| 16 |
+
model.hparams.num_blocks = 2
|
| 17 |
+
model.hparams.forecast_steps = 12
|
| 18 |
+
model.hparams.ensemble_size = 2
|
| 19 |
+
model.hparams.noisy_decoder = False
|
| 20 |
+
model.hparams.loss_class = "crps"
|
| 21 |
+
model.device = "cpu"
|
| 22 |
+
model.predict.return_value = np.zeros((10, 12, 8, 8), dtype=np.float32)
|
| 23 |
+
return model
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@pytest.fixture
|
| 27 |
+
def client(mock_model):
|
| 28 |
+
from fastapi.testclient import TestClient
|
| 29 |
+
|
| 30 |
+
with patch("convgru_ensemble.serve._load_model", return_value=mock_model):
|
| 31 |
+
from convgru_ensemble.serve import app
|
| 32 |
+
|
| 33 |
+
with TestClient(app) as c:
|
| 34 |
+
yield c
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_health(client):
|
| 38 |
+
resp = client.get("/health")
|
| 39 |
+
assert resp.status_code == 200
|
| 40 |
+
data = resp.json()
|
| 41 |
+
assert data["status"] == "ok"
|
| 42 |
+
assert data["model_loaded"] is True
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_model_info(client):
|
| 46 |
+
resp = client.get("/model/info")
|
| 47 |
+
assert resp.status_code == 200
|
| 48 |
+
data = resp.json()
|
| 49 |
+
assert data["architecture"] == "ConvGRU-Ensemble EncoderDecoder"
|
| 50 |
+
assert data["num_blocks"] == 2
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_predict_returns_netcdf(client):
|
| 54 |
+
# Create a small NetCDF file in memory
|
| 55 |
+
ds = xr.Dataset({"RR": xr.DataArray(np.zeros((4, 8, 8), dtype=np.float32), dims=["time", "y", "x"])})
|
| 56 |
+
buf = io.BytesIO()
|
| 57 |
+
ds.to_netcdf(buf)
|
| 58 |
+
buf.seek(0)
|
| 59 |
+
|
| 60 |
+
resp = client.post(
|
| 61 |
+
"/predict?forecast_steps=12&ensemble_size=10",
|
| 62 |
+
files={"file": ("input.nc", buf, "application/x-netcdf")},
|
| 63 |
+
)
|
| 64 |
+
assert resp.status_code == 200
|
| 65 |
+
assert resp.headers["content-type"] == "application/x-netcdf"
|
| 66 |
+
assert "X-Elapsed-Seconds" in resp.headers
|
tests/test_utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from convgru_ensemble.utils import normalized_to_rainrate, rainrate_to_normalized
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_roundtrip_conversion():
|
| 7 |
+
rain = np.array([0.0, 1.0, 5.0, 20.0, 50.0], dtype=np.float32)
|
| 8 |
+
normalized = rainrate_to_normalized(rain)
|
| 9 |
+
recovered = normalized_to_rainrate(normalized)
|
| 10 |
+
np.testing.assert_allclose(recovered, rain, rtol=0.01, atol=0.05)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_zero_rain_maps_to_low_normalized():
|
| 14 |
+
rain = np.array([0.0], dtype=np.float32)
|
| 15 |
+
normalized = rainrate_to_normalized(rain)
|
| 16 |
+
# Zero rain should map to approximately -1 (0 dBZ → normalized -1)
|
| 17 |
+
assert normalized[0] < -0.9
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|