Initial release: Mosaic weather model (era5 + hres variants)
- .gitattributes +6 -0
- .gitignore +10 -0
- README.md +218 -0
- base.py +17 -0
- config.py +25 -0
- dataset.py +33 -0
- era5_best.pt +3 -0
- figures_weather/bsa.jpg +3 -0
- figures_weather/bsa_runtime.jpg +3 -0
- figures_weather/healpix.jpg +0 -0
- figures_weather/hurricane_tracking.jpg +3 -0
- figures_weather/main.jpg +3 -0
- figures_weather/results_hres.jpg +3 -0
- figures_weather/results_spectra_pareto.jpg +3 -0
- hres_best.pt +3 -0
- inference.py +527 -0
- mosaic.py +337 -0
- ops.py +566 -0
- primitives.py +359 -0
- requirements.txt +11 -0
- utils.py +40 -0
.gitattributes (CHANGED)

@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figures_weather/bsa.jpg filter=lfs diff=lfs merge=lfs -text
+figures_weather/bsa_runtime.jpg filter=lfs diff=lfs merge=lfs -text
+figures_weather/hurricane_tracking.jpg filter=lfs diff=lfs merge=lfs -text
+figures_weather/main.jpg filter=lfs diff=lfs merge=lfs -text
+figures_weather/results_hres.jpg filter=lfs diff=lfs merge=lfs -text
+figures_weather/results_spectra_pareto.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore (ADDED)

__pycache__/
*.pyc
*.pyo
*.pyd
*.egg-info/
dist/
build/
.env
*.npz
checkpoints/
README.md (ADDED)
---
license: cc-by-nc-4.0
library_name: pytorch
tags:
- weather
- weather-forecasting
- climate
- atmospheric-science
- sparse-attention
- transformer
- probabilistic-forecasting
---

# Mosaic — Block-Sparse Attention for Weather Forecasting

**Mosaic** is a probabilistic weather forecasting model that operates on native-resolution grids via mesh-aligned block-sparse attention. At 1.5° resolution with 214M parameters, Mosaic matches or outperforms models trained on 6× finer resolution on key variables, and individual ensemble members exhibit near-perfect spectral alignment across all resolved frequencies. A 24-member, 10-day forecast takes under 12 s on a single H100 GPU.

> **(Sparse) Attention to the Details: Preserving Spectral Fidelity in ML-based Weather Forecasting Models** \
> Maksim Zhdanov, Ana Lucic, Max Welling, Jan-Willem van de Meent \
> *ICML 2026* · [arXiv:2604.16429](https://arxiv.org/abs/2604.16429) · [GitHub](https://github.com/maxxxzdn/mosaic)

![Mosaic overview](figures_weather/main.jpg)

## TL;DR

Mosaic addresses two distinct failure modes of spectral degradation in ML-based weather prediction:

1. **Spectral damping** caused by deterministic training against ensemble means. Mosaic addresses this with learned functional perturbations that produce ensemble members preserving realistic spectral variability.
2. **High-frequency aliasing** caused by compressive encoding onto a coarse latent grid. Mosaic operates at native resolution via block-sparse attention before any coarsening, eliminating the compress-first bottleneck.

The block-sparse attention captures long-range dependencies at **linear** cost by sharing keys and values across spatially adjacent queries arranged on the HEALPix mesh. Each query block jointly selects which key blocks to attend to.

## Published Variants

This repository ships **two trained variants**. They share the same Mosaic architecture and 82-channel variable set, but differ in training data, time cadence, history length, and normalization statistics.

| Variant | Training data | Native step | Input history | k-neighbours | Suggested input zarr |
|---------|---------------|-------------|---------------|--------------|----------------------|
| `era5` | ERA5 reanalysis only | 24 h | 2 states (48 h) | 24 | WB2 ERA5 1.5° |
| `hres` | ERA5 pretrain + HRES finetune | 6 h | 4 states (24 h) | 20 | WB2 HRES-fc0 1.5° |

Choose `era5` when initializing from reanalysis (matches the training distribution); choose `hres` when initializing from HRES analysis or a similar operational state.

## Architecture

**Inputs.** 82 atmospheric channels at 1.5° equiangular resolution (240 lon × 121 lat = 29 040 points) plus 3 static channels and sinusoidal day/year time encodings.

**Backbone.** A U-Net of transformer blocks over the HEALPix mesh, where spatial neighbours occupy contiguous memory and queries can be grouped into hardware-aligned blocks:

| Stage | Nside | Hidden dim | Heads | Enc / Dec depth |
|------------|------:|-----------:|------:|----------------:|
| Stage 1 | 64 | 768 | 12 | 4 / 2 |
| Stage 2 | 32 | 1024 | 16 | 4 / 2 |
| Bottleneck | 16 | 1280 | 20 | 2 |

Grouped-Query Attention with ratio 4 (3, 4, and 5 KV heads across the three stages), 2D RoPE on (longitude, latitude), and additive noise injection in SwiGLU gates for ensemble generation. ~214M parameters total.

![Block-sparse attention](figures_weather/bsa.jpg)

**Block-sparse attention.** Three branches combined by learned gates: (i) **compression** — block-to-block coarse attention captures broad synoptic patterns at $\mathcal{O}(N^2/b^2)$; (ii) **selection** — each query block top-k-selects fine-scale key blocks at $\mathcal{O}(Nnb)$; (iii) **local** — full attention inside each block at $\mathcal{O}(Nb)$. Spatially close points occupy contiguous memory on the HEALPix mesh, enabling coalesced GPU reads and hardware-aligned block computation. Implemented as a single Triton kernel; in practice up to **61.8× faster than dense attention** and **9.4× faster than NSA**.
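
The three cost terms above can be made concrete with a back-of-the-envelope interaction count. This is a sketch, not the Triton kernel; the numbers (N = 12 · 64² HEALPix pixels at Stage 1, block size b = 128, selection count n = 24) follow the Stage-1 preset in `inference.py`:

```python
# Rough per-layer interaction counts for the three attention branches,
# assuming the Stage-1 preset: nside=64, sparse_block_size=128, sparse_block_count=24.
N, b, n = 12 * 64**2, 128, 24   # N = 49_152 tokens

compression = (N // b) ** 2     # (i)  block-to-block coarse attention, O(N^2 / b^2)
selection = N * n * b           # (ii) each query reads n selected key blocks, O(Nnb)
local = N * b                   # (iii) full attention inside each block, O(Nb)
dense = N * N                   # dense-attention reference, O(N^2)

sparse_total = compression + selection + local
print(f"dense / sparse interaction ratio ≈ {dense / sparse_total:.1f}x")
```

The selection branch dominates the sparse cost, and all three terms grow linearly in N for fixed b and n, which is what makes finer grids tractable.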

<p align="center">
  <img src="figures_weather/healpix.jpg" alt="HEALPix mesh" width="48%">
  <img src="figures_weather/bsa_runtime.jpg" alt="Runtime scaling" width="48%">
</p>

### Variables (82 channels)

- **Surface (4):** `2m_temperature`, `10m_u_component_of_wind`, `10m_v_component_of_wind`, `mean_sea_level_pressure`
- **Pressure level (6 × 13 = 78):** `geopotential`, `specific_humidity`, `temperature`, `u_component_of_wind`, `v_component_of_wind`, `vertical_velocity` at levels [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] hPa
- **Static (3, conditioning only — not in output):** `geopotential_at_surface`, `land_sea_mask`, `soil_type`
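
The 82-channel count follows from the definitions in `config.py` (4 surface + 6 × 13 pressure-level channels). A sketch of reconstructing the channel names; the ordering shown (surface first, then each pressure-level variable expanded over the 13 levels) is an assumption — consult `config.py` / `inference.py` if the exact ordering matters:

```python
# Channel names rebuilt from config.py's SL_VARS, PL_VARS and LEVELS.
# Ordering is assumed, not guaranteed by this README.
SL_VARS = ["2m_temperature", "10m_u_component_of_wind",
           "10m_v_component_of_wind", "mean_sea_level_pressure"]
PL_VARS = ["geopotential", "specific_humidity", "temperature",
           "u_component_of_wind", "v_component_of_wind", "vertical_velocity"]
LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

channels = SL_VARS + [f"{v}_{lev}" for v in PL_VARS for lev in LEVELS]
print(len(channels))  # 82
```

Pressure-level channels use the `name_level` convention (e.g. `geopotential_500`), matching the lookup in the "Reading the output" example below.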

## Results

The headline figure at the top of this page shows that individual ensemble members preserve realistic kinetic-energy spectra (left, 1.5°; centre, 0.25°), and that Mosaic sits on the favourable end of the skill–speed–memory Pareto front (right). All metrics are computed at 240 h lead time, over 720 initial conditions throughout the 2020 test year (1.5° benchmark) and a single 6 h forecast (0.25° benchmark).

On the 0.25° HRES benchmark, Mosaic competes with state-of-the-art 0.25° models despite operating at 1.5°:

![HRES benchmark results](figures_weather/results_hres.jpg)

And a case study of **Hurricane Ian (2022)** — Mosaic's ensemble correctly brackets the observed track 7 days ahead, with the spread narrowing progressively as lead time decreases:

![Hurricane Ian track forecasts](figures_weather/hurricane_tracking.jpg)

See the paper for full benchmark tables, CRPS curves, and spread-to-skill ratios.

## Hardware Requirements

- **GPU:** any CUDA GPU; 16 GB is enough for a 1-member rollout, A100/H100 recommended for multi-member ensembles
- **Memory:** ~9 GB GPU RAM for a 1-member, 40-step (10-day) rollout in float16
- **Throughput:** 24-member, 10-day forecast in under 12 s on a single H100
- **CUDA:** 11.8+ with matching `triton` and `flash-attn` versions

## Installation

```bash
pip install -r requirements.txt
pip install flash-attn --no-build-isolation  # built separately; needs nvcc
```

For reading data from Google Cloud Storage (WeatherBench2 zarr stores):

```bash
pip install gcsfs
```

## Quick Start

```bash
# ERA5 variant — 10-day forecast at 24 h resolution from ERA5 reanalysis
python inference.py --variant era5 \
    --zarr gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr \
    --init-time "2020-01-01T00:00" \
    --steps 10 --members 1 \
    --output forecast_era5.npz

# HRES variant — 10-day forecast at 6 h resolution from HRES initial conditions
python inference.py --variant hres \
    --zarr gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr \
    --init-time "2022-01-01T00:00" \
    --steps 40 --members 1 \
    --output forecast_hres.npz

# Ensemble forecast (16 members) — change --members
python inference.py --variant hres --zarr <...> --init-time "2020-06-15T12:00" \
    --steps 40 --members 16 --output ensemble.npz
```

`--variant` selects the checkpoint, normalization statistics, history length, time stride, and neighbour count automatically. Pass `--checkpoint` or `--norm-stats` to override the bundled defaults.
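
The lead times in a forecast file follow directly from each variant's native step (`step_stride` × 6 h in `inference.py`). A sketch with a hypothetical helper name:

```python
# Hypothetical helper: lead times implied by each variant's native step.
# era5 advances 24 h per model step (step_stride=4 × 6 h); hres advances 6 h.
def lead_time_hours(variant: str, steps: int) -> list[int]:
    step_h = {"era5": 24, "hres": 6}[variant]
    return [step_h * (i + 1) for i in range(steps)]

print(lead_time_hours("era5", 10))     # [24, 48, ..., 240] — 10 days in 10 steps
print(lead_time_hours("hres", 40)[-1]) # 240 — 10 days in 40 steps
```

So `--steps 10` for `era5` and `--steps 40` for `hres` both produce a 10-day rollout.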

## Output Format

The output `.npz` file contains:

| Array | Shape | Description |
|-------|-------|-------------|
| `forecasts` | `(members, steps, 240, 121, 82)` | Predicted states in physical units |
| `variables` | `(82,)` | Variable names |
| `lead_time_hours` | `(steps,)` | Lead times (era5: 24, 48, …; hres: 6, 12, …) |
| `init_time` | scalar | Initialization timestamp |
| `longitude` | `(240,)` | Longitude values (0 to 358.5°) |
| `latitude` | `(121,)` | Latitude values, South→North (−90 to 90°) |

### Reading the output

```python
import numpy as np

data = np.load("forecast_era5.npz", allow_pickle=True)
forecasts = data['forecasts']         # (members, steps, 240, 121, 82)
variables = list(data['variables'])   # ['2m_temperature', ...]
lead_hours = data['lead_time_hours']  # e.g. [24, 48, ..., 240]

# Extract 500 hPa geopotential at 24 h lead time
z500_idx = variables.index('geopotential_500')
i_24h = int(np.where(lead_hours == 24)[0][0])
z500_24h = forecasts[0, i_24h, :, :, z500_idx]  # (240, 121) lon × lat
```
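
For multi-member runs, the ensemble mean and spread reduce over the leading `members` axis. A minimal sketch (the helper name is ours, not part of the repo):

```python
import numpy as np

def ensemble_summary(forecasts: np.ndarray):
    """Mean and spread over the leading (members) axis.

    forecasts: (members, steps, lon, lat, channels), as written by inference.py.
    """
    ens_mean = forecasts.mean(axis=0)            # (steps, lon, lat, channels)
    ens_spread = forecasts.std(axis=0, ddof=1)   # sample std across members
    return ens_mean, ens_spread

# Usage with a saved ensemble (assumes an "ensemble.npz" from Quick Start):
#   data = np.load("ensemble.npz", allow_pickle=True)
#   mean, spread = ensemble_summary(data["forecasts"])
```

`ddof=1` gives the unbiased sample standard deviation, the usual choice when comparing spread to ensemble-mean error.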

## Input Data Format

The model accepts ERA5 or HRES data in zarr format at 1.5° resolution with:

- **Grid:** 240 lon × 121 lat equiangular with poles
- **Time:** 6-hourly timesteps (integer hours since an arbitrary origin, parsed from the `units` zarr attribute, or as datetime64)
- **Variables:** all 10 atmospheric variables listed above; per-variable layout is auto-detected from `_ARRAY_DIMENSIONS` (either `(time, latitude, longitude)` or `(time, longitude, latitude)`), and latitude is flipped if stored North→South

Compatible zarr stores from [WeatherBench2](https://weatherbench2.readthedocs.io/):

```
gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr
gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr
```
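
The latitude flip mentioned above is a one-liner; this sketch mirrors the check in `inference.py`'s `load_initial_state` (the function name here is ours):

```python
import numpy as np

def normalize_latitude(latitude: np.ndarray):
    """Return (latitude ordered South→North, whether a flip was applied)."""
    if latitude[0] > latitude[-1]:          # stored North→South
        return latitude[::-1].copy(), True
    return latitude, False

lat, flipped = normalize_latitude(np.linspace(90, -90, 121))  # N→S store
print(lat[0], lat[-1], flipped)  # -90.0 90.0 True
```

If the store is flipped, the data arrays must be flipped along the same axis so fields stay aligned with the coordinate.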

## Repository Contents

| File | Description |
|------|-------------|
| `inference.py` | Main inference script (variant-aware via `--variant {era5,hres}`) |
| `mosaic.py` | Mosaic U-Net transformer |
| `primitives.py` | Attention blocks, RoPE, HEALPix sampling, noise generator |
| `ops.py` | Triton block-sparse attention kernels |
| `utils.py` | HEALPix grid utilities |
| `base.py` | `WeatherModel` wrapper |
| `config.py` | Variable / level definitions |
| `dataset.py` | Metadata dataclasses |
| `norm_stats_era5.npz` | Normalization statistics for the `era5` variant |
| `norm_stats_hres.npz` | Normalization statistics for the `hres` variant |
| `static_vars.npz` | Static fields (orography, land–sea mask, soil type) — shared between variants |
| `era5_best.pt` | Trained checkpoint, `era5` variant (~1.7 GB) |
| `hres_best.pt` | Trained checkpoint, `hres` variant (~1.7 GB) |
| `figures_weather/` | Figures from the paper |

## Limitations

Mosaic operates at 1.5° (~166 km), which cannot resolve mesoscale phenomena such as tropical-cyclone inner-core structure or individual severe thunderstorms. The block-sparse attention is designed to scale linearly with sequence length, so finer grids (e.g. 0.25°, ~700k tokens) are a natural next step but are not part of this release.

## Citation

If you use Mosaic, please cite:

```bibtex
@inproceedings{zhdanov2026mosaic,
  title     = {(Sparse) Attention to the Details: Preserving Spectral Fidelity in ML-based Weather Forecasting Models},
  author    = {Zhdanov, Maksim and Lucic, Ana and Welling, Max and van de Meent, Jan-Willem},
  booktitle = {Proceedings of the 43rd International Conference on Machine Learning (ICML)},
  year      = {2026},
  url       = {https://arxiv.org/abs/2604.16429}
}
```

## License

Released under [CC-BY-NC-4.0](https://creativecommons.org/licenses/by-nc/4.0/). Free for non-commercial research and educational use with attribution; commercial use requires a separate license. Underlying training data (ERA5, HRES) is subject to its own licensing terms set by ECMWF.

## Acknowledgements

MZ acknowledges support from Microsoft Research AI4Science. JWvdM acknowledges support from the European Union Horizon Framework Programme (Grant agreement ID: 101120237). This work used the Dutch national e-infrastructure with the support of the SURF Cooperative using grant no. EINF-16923. Computations were partially performed using the UvA/FNWI HPC Facility.
base.py (ADDED)

import torch
from torch import nn
from dataset import WeatherMetadata


class WeatherModel(nn.Module):
    """Weather forecasting model wrapper."""

    def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
        super().__init__()
        self.model = model
        self.model.initialize_static_vars(weather_metadata.static_data, weather_metadata.longitude, weather_metadata.latitude)
        self.model.initialize_interpolation(weather_metadata.longitude, weather_metadata.latitude)
        self.weather_metadata = weather_metadata

    def forward(self, norm_state: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
        return self.model(norm_state, day_year_time, num_noise_samples)
config.py (ADDED)

# Variable names and pressure levels for the Mosaic weather forecasting model.

SL_VARS: list[str] = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "mean_sea_level_pressure",
]

PL_VARS: list[str] = [
    "geopotential",
    "specific_humidity",
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "vertical_velocity",
]

ST_VARS: list[str] = [
    "geopotential_at_surface",
    "land_sea_mask",
    "soil_type",
]

LEVELS: list[int] = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
dataset.py (ADDED)

"""Metadata dataclasses for weather forecasting inference."""

import torch
from dataclasses import dataclass


@dataclass(frozen=True)
class NormalizationStats:
    """Normalization statistics for state variables."""
    state_mean: torch.Tensor
    state_std: torch.Tensor
    residual_mean: torch.Tensor
    residual_std: torch.Tensor

    def to(self, device) -> 'NormalizationStats':
        return NormalizationStats(
            state_mean=self.state_mean.to(device),
            state_std=self.state_std.to(device),
            residual_mean=self.residual_mean.to(device),
            residual_std=self.residual_std.to(device),
        )


@dataclass(frozen=True)
class WeatherMetadata:
    """Metadata for the weather dataset."""
    variables: list[str]
    static_variables: list[str]
    longitude: torch.Tensor
    latitude: torch.Tensor
    static_data: torch.Tensor
    day_year_delta: torch.Tensor
    norm_stats: NormalizationStats
era5_best.pt (ADDED, Git LFS pointer)

version https://git-lfs.github.com/spec/v1
oid sha256:187df31a0caef3e934e4eb8a11506c9fea518ff0e02ce6d1f804fc6c8a78a940
size 1713557607
figures_weather/bsa.jpg (ADDED, Git LFS)
figures_weather/bsa_runtime.jpg (ADDED, Git LFS)
figures_weather/healpix.jpg (ADDED)
figures_weather/hurricane_tracking.jpg (ADDED, Git LFS)
figures_weather/main.jpg (ADDED, Git LFS)
figures_weather/results_hres.jpg (ADDED, Git LFS)
figures_weather/results_spectra_pareto.jpg (ADDED, Git LFS)
hres_best.pt (ADDED, Git LFS pointer)

version https://git-lfs.github.com/spec/v1
oid sha256:6e0bc244382fa1b0f09ecdbd527c601d352cc5989697149742e8983f38a25e5e
size 1714571315
inference.py (ADDED)

"""
Run autoregressive global weather forecasts with the Mosaic 1.5° model.

The model advances the atmospheric state autoregressively in multiples of the
native 6 h cadence (era5: 24 h steps; hres: 6 h steps), supporting both
deterministic (1 member) and probabilistic (N members) forecasts.

Usage:
    # ERA5 variant (24h steps), default checkpoint and norm stats inferred from --variant
    python inference.py --variant era5 \\
        --zarr gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr \\
        --init-time "2020-01-01T00:00" --steps 10 --output forecast_era5.npz

    # HRES variant (6h steps)
    python inference.py --variant hres \\
        --zarr gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr \\
        --init-time "2022-01-01T00:00" --steps 40 --output forecast_hres.npz

Input zarr format:
    The zarr store must contain the following variables at 1.5° resolution
    (240 lon × 121 lat, 6-hourly timesteps):
    - Surface: 2m_temperature, 10m_u_component_of_wind, 10m_v_component_of_wind,
      mean_sea_level_pressure
    - Pressure-level (at 13 levels 50..1000 hPa): geopotential, specific_humidity,
      temperature, u_component_of_wind, v_component_of_wind, vertical_velocity
    - Coordinates: longitude (240,), latitude (121,), time (hours since 1959-01-01)

Output npz:
    forecasts        float32 (members, steps, 240, 121, 82) – physical units
    variables        list of 82 variable names
    lead_time_hours  int32 (steps,) – multiples of step_stride*6h
                     (era5: 24, 48, ...; hres: 6, 12, ...)
    init_time        str – initialization timestamp

Hardware:
    Requires a CUDA GPU. A 16 GB GPU is enough for 1 member; A100 80 GB recommended
    for multi-member ensembles. float16 inference (~9 GB for 1 member, 40-step rollout).
"""

import argparse
import math
import os
import sys
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import zarr

# ---------------------------------------------------------------------------
# Model imports
# ---------------------------------------------------------------------------
from config import SL_VARS, PL_VARS, ST_VARS, LEVELS
from dataset import NormalizationStats, WeatherMetadata
from mosaic import Transformer, ModelConfig, StageConfig, BottleneckConfig
from base import WeatherModel

DTYPE = torch.float16

# ---------------------------------------------------------------------------
# Model variant presets
# ---------------------------------------------------------------------------
# The two published variants share the same Mosaic architecture (stage / bottleneck
# sizes) but differ in training data, time cadence, history length, neighbour
# count, and normalisation statistics:
#   - `era5`: ERA5-only training, 24h steps (4 x 6h), 2 input states, k=24 neighbours
#   - `hres`: ERA5 pretrain + HRES finetune, 6h steps, 4 input states, k=20 neighbours
# ---------------------------------------------------------------------------

_STAGE_CFGS_COMMON = [
    StageConfig(
        nside=64, dim=768, num_heads=12,
        block_attn_size=1024, sparse_block_size=128, sparse_block_count=24,
        encoder_depth=4, decoder_depth=2, mlp_ratio=4.0, gqa_ratio=4,
    ),
    StageConfig(
        nside=32, dim=1024, num_heads=16,
        block_attn_size=1024, sparse_block_size=128, sparse_block_count=12,
        encoder_depth=4, decoder_depth=2, mlp_ratio=4.0, gqa_ratio=4,
    ),
]

_BOTTLENECK_CFG_COMMON = BottleneckConfig(
    nside=16, dim=1280, num_heads=20,
    block_attn_size=1024, sparse_block_size=128, sparse_block_count=4,
    depth=2, mlp_ratio=4.0, gqa_ratio=4,
)


@dataclass
class Preset:
    step_stride: int        # number of native 6h timesteps per model step
    num_history_steps: int  # number of input states fed to the model
    k_neighbors: int        # neighbours used in cross-attention interpolation
    default_checkpoint: str
    default_norm_stats: str
    stage_cfgs: list
    bottleneck_cfg: BottleneckConfig


PRESETS = {
    "era5": Preset(
        step_stride=4, num_history_steps=2, k_neighbors=24,
        default_checkpoint="era5_best.pt",
        default_norm_stats="norm_stats_era5.npz",
        stage_cfgs=_STAGE_CFGS_COMMON,
        bottleneck_cfg=_BOTTLENECK_CFG_COMMON,
    ),
    "hres": Preset(
        step_stride=1, num_history_steps=4, k_neighbors=20,
        default_checkpoint="hres_best.pt",
        default_norm_stats="norm_stats_hres.npz",
        stage_cfgs=_STAGE_CFGS_COMMON,
        bottleneck_cfg=_BOTTLENECK_CFG_COMMON,
    ),
}

# ---------------------------------------------------------------------------
# Time utilities
# ---------------------------------------------------------------------------

def compute_day_year_progress(timestamp: pd.Timestamp):
    """Return (day_progress, year_progress) fractions for a single timestamp."""
    day_progress = (timestamp.hour * 3600 + timestamp.minute * 60 + timestamp.second) / 86400.0
    days_in_year = 366 if timestamp.is_leap_year else 365
    year_progress = (timestamp.day_of_year - 1) / days_in_year
    return float(day_progress), float(year_progress)
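
A quick worked example of the progress fractions computed above (it mirrors `compute_day_year_progress`; 2020 is a leap year and June 15 is its 167th day):

```python
import pandas as pd

ts = pd.Timestamp("2020-06-15T12:00")
day_progress = (ts.hour * 3600 + ts.minute * 60 + ts.second) / 86400.0
year_progress = (ts.day_of_year - 1) / 366  # 2020 is a leap year

print(day_progress)             # 0.5 (noon)
print(round(year_progress, 4))  # 0.4536 (= 166/366)
```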

# ---------------------------------------------------------------------------
# Zarr loading
# ---------------------------------------------------------------------------

def _load_zarr_times(store) -> pd.DatetimeIndex:
    """Load and decode the time coordinate from the zarr store, honouring its units attr."""
    time_raw = np.asarray(store['time'])
    if not np.issubdtype(time_raw.dtype, np.integer):
        return pd.to_datetime(time_raw)
    # Integer encoding: parse 'units' attr e.g. "hours since 1959-01-01"
    units = store['time'].attrs.get('units', 'hours since 1959-01-01')
    try:
        unit_word, _, origin = units.partition(' since ')
    except Exception:
        unit_word, origin = 'hours', '1959-01-01'
    unit_map = {'hours': 'h', 'hour': 'h', 'days': 'D', 'day': 'D',
                'minutes': 'm', 'minute': 'm', 'seconds': 's', 'second': 's'}
    unit = unit_map.get(unit_word.strip().lower(), 'h')
    return pd.to_datetime(time_raw, unit=unit, origin=origin.strip() or '1959-01-01')
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def load_initial_state(zarr_path: str, init_time: str, num_history_steps: int = 4, step_stride: int = 1):
    """
    Load `num_history_steps` timesteps ending at `init_time` from a zarr store,
    spaced `step_stride * 6h` apart (so step_stride=4 -> 24h spacing).

    Returns:
        state: np.ndarray of shape (num_history_steps, 240, 121, 82) in physical units
        day_year_time: tuple (day_progress, year_progress) for init_time
        longitude: np.ndarray (240,)
        latitude: np.ndarray (121,) South→North
    """
    # Open zarr (supports local paths, gs://, s3://, etc.)
    if zarr_path.startswith('gs://'):
        import gcsfs
        fs = gcsfs.GCSFileSystem(token='anon')
        store_obj = zarr.open(fs.get_mapper(zarr_path), mode='r')
    else:
        store_obj = zarr.open(zarr_path, mode='r')

    times = _load_zarr_times(store_obj)
    init_ts = pd.Timestamp(init_time)

    # Find the index of init_time
    idx = times.searchsorted(init_ts)
    if idx >= len(times) or times[idx] != init_ts:
        raise ValueError(
            f"init_time '{init_time}' not found in zarr store. "
            f"Available range: {times[0]} to {times[-1]}"
        )

    # history indices: [idx - (H-1)*S, idx - (H-2)*S, ..., idx]
    history_indices = [idx - (num_history_steps - 1 - i) * step_stride for i in range(num_history_steps)]
    if history_indices[0] < 0:
        raise ValueError(
            f"Not enough history: need {num_history_steps} steps spaced {step_stride*6}h apart "
            f"before {init_time}, but data starts at {times[0]}"
        )

    # Load longitude/latitude in the canonical (lon, lat S→N) order the model expects.
    longitude = np.asarray(store_obj['longitude'])    # (240,) 0..358.5
    latitude_raw = np.asarray(store_obj['latitude'])  # (121,)
    if latitude_raw[0] > latitude_raw[-1]:  # N→S in store → flip
        latitude = latitude_raw[::-1].copy()
        flip_lat = True
    else:
        latitude = latitude_raw.copy()
        flip_lat = False

    n_lon, n_lat = len(longitude), len(latitude)
    n_vars = len(SL_VARS) + len(PL_VARS) * len(LEVELS)
    state = np.empty((num_history_steps, n_lon, n_lat, n_vars), dtype=np.float32)

    def _to_lon_lat(arr: np.ndarray, dims: list) -> np.ndarray:
        """Normalise a (lat,lon) or (lon,lat) slice to (lon, lat S→N)."""
        if dims[-2:] == ['latitude', 'longitude']:
            arr = arr.T  # (lat,lon) -> (lon,lat)
        elif dims[-2:] != ['longitude', 'latitude']:
            raise ValueError(f"unexpected dim order: {dims}")
        if flip_lat:
            arr = arr[:, ::-1]
        return np.ascontiguousarray(arr)

    all_levels_in_store = list(np.asarray(store_obj['level'])) if 'level' in store_obj else None

    for step_i, t_idx in enumerate(history_indices):
        ch = 0
        for var in SL_VARS:
            dims = list(store_obj[var].attrs.get('_ARRAY_DIMENSIONS', ['time', 'latitude', 'longitude']))
            arr = np.asarray(store_obj[var][t_idx])  # 2D
            state[step_i, :, :, ch] = _to_lon_lat(arr, dims)
            ch += 1

        for var in PL_VARS:
            dims = list(store_obj[var].attrs.get('_ARRAY_DIMENSIONS', ['time', 'level', 'latitude', 'longitude']))
            arr_full = np.asarray(store_obj[var][t_idx])  # 3D (level, ...)
            spatial_dims = [d for d in dims if d != 'time']  # drop time (already indexed)
            for level in LEVELS:
                lev_idx = all_levels_in_store.index(level) if all_levels_in_store is not None else LEVELS.index(level)
                arr = arr_full[lev_idx]  # 2D
                # spatial_dims still includes 'level' at the front; pass just the 2D part
                state[step_i, :, :, ch] = _to_lon_lat(arr, spatial_dims[1:] if spatial_dims[0] == 'level' else spatial_dims)
                ch += 1

    day_progress, year_progress = compute_day_year_progress(init_ts)
    return state, (day_progress, year_progress), longitude, latitude


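The history-index arithmetic and the axis normalisation above can be illustrated on toy values (the index 100 and the 3x2 slice are made up for illustration):

```python
import numpy as np

# Four snapshots ending at idx, spaced step_stride entries apart
num_history_steps, step_stride, idx = 4, 4, 100
history_indices = [idx - (num_history_steps - 1 - i) * step_stride
                   for i in range(num_history_steps)]
# With 6h data and step_stride=4, these are four 24h-spaced indices ending at 100.

# A (lat, lon) slice stored North-to-South ...
arr = np.arange(6).reshape(3, 2)
# ... transposed to (lon, lat) and flipped to South-to-North, as _to_lon_lat does
arr_lon_lat = np.ascontiguousarray(arr.T[:, ::-1])
```

`history_indices` comes out as `[88, 92, 96, 100]`, and the slice ends up with shape `(lon=2, lat=3)`.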
# ---------------------------------------------------------------------------
# Model building
# ---------------------------------------------------------------------------

def build_model(
    checkpoint_path: str,
    variables: list,
    longitude: np.ndarray,
    latitude: np.ndarray,
    preset: Preset,
    norm_stats_path: str = "norm_stats.npz",
    static_vars_path: str = "static_vars.npz",
    device: str = "cuda",
):
    """Build and return the WeatherModel with loaded checkpoint and metadata."""

    # Load normalization statistics
    _ns = np.load(norm_stats_path)
    norm_stats = NormalizationStats(
        state_mean=torch.from_numpy(_ns['state_mean'].astype(np.float32)),
        state_std=torch.from_numpy(_ns['state_std'].astype(np.float32)),
        residual_mean=torch.from_numpy(_ns['residual_mean'].astype(np.float32)) if 'residual_mean' in _ns else torch.zeros(len(variables)),
        residual_std=torch.from_numpy(_ns['residual_std'].astype(np.float32)) if 'residual_std' in _ns else torch.ones(len(variables)),
    )

    # Load static variables
    _sv = np.load(static_vars_path)
    static_data = torch.from_numpy(_sv['data'].astype(np.float32))  # (lon, lat, 3)
    lon_tensor = torch.from_numpy(longitude.astype(np.float32))
    lat_tensor = torch.from_numpy(latitude.astype(np.float32))

    day_year_delta = torch.tensor(
        [preset.step_stride / 4.0, preset.step_stride / 365.25], dtype=torch.float32
    )

    metadata = WeatherMetadata(
        variables=variables,
        static_variables=list(ST_VARS),
        longitude=lon_tensor,
        latitude=lat_tensor,
        static_data=static_data,
        day_year_delta=day_year_delta,
        norm_stats=norm_stats,
    )

    # Build model
    model_config = ModelConfig(
        dim=preset.stage_cfgs[0].dim,
        num_heads=preset.stage_cfgs[0].num_heads,
        variables=variables,
        static_variables=list(ST_VARS),
        k_neighbors=preset.k_neighbors,
        qk_norm=False,
        rope=True,
        rope_theta=10000,
        sparse_every=1,
        qkv_compress_ratio=1,
        num_history_steps=preset.num_history_steps,
        noise_dim=32,
        rmsnorm_elementwise_affine=False,
        cg_stage_cfgs=preset.stage_cfgs,
        bottleneck_cfg=preset.bottleneck_cfg,
    )

    backbone = Transformer(model_config)
    model = WeatherModel(backbone, metadata).to(device).eval()

    # Load checkpoint. The model registers several deterministic buffers (RoPE
    # tables, HEALPix neighbour indices, static_vars) that are recomputed at
    # __init__ from the metadata/config and therefore aren't expected in the
    # saved checkpoint, so we load non-strictly and treat *unexpected* keys as
    # a hard error, since those indicate a real architecture mismatch.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = ckpt.get('model_state_dict', ckpt)
    result = model.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        raise RuntimeError(
            f"Unexpected keys in checkpoint (architecture mismatch): {result.unexpected_keys[:5]}"
            + (f" ... and {len(result.unexpected_keys)-5} more" if len(result.unexpected_keys) > 5 else "")
        )
    print(f"Loaded checkpoint from {checkpoint_path} (epoch {ckpt.get('epoch', '?')})")
    print(f" {len(state_dict)} keys loaded, {len(result.missing_keys)} buffer keys re-computed from config")

    return model, metadata


# ---------------------------------------------------------------------------
# Autoregressive rollout (direct state prediction)
# ---------------------------------------------------------------------------

@torch.no_grad()
def unroll_direct(
    model: WeatherModel,
    initial_unnorm_state: torch.Tensor,
    day_year_time: torch.Tensor,
    day_year_delta: torch.Tensor,
    norm_stats: NormalizationStats,
    num_unroll_steps: int,
    num_ensemble_members: int,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Autoregressively forecast using direct state prediction (learn_direct=True).

    Args:
        model: WeatherModel
        initial_unnorm_state: (B, num_history_steps, lon, lat, channels) in physical units
        day_year_time: (B, 2) day/year progress fractions at init_time
        day_year_delta: (2,) increment per step
        norm_stats: NormalizationStats on the target device
        num_unroll_steps: number of model steps to forecast (each step = step_stride * 6h)
        num_ensemble_members: number of ensemble members (noise samples on step 0)
        dtype: computation dtype (float16 recommended)

    Returns:
        trajectory: (B, members, num_history_steps + num_unroll_steps, lon, lat, channels)
    """
    batch_size = initial_unnorm_state.shape[0]
    num_history_steps = initial_unnorm_state.shape[1]
    device = initial_unnorm_state.device

    trajectory = torch.empty(
        (batch_size, num_ensemble_members, num_unroll_steps + num_history_steps)
        + initial_unnorm_state.shape[2:],
        dtype=initial_unnorm_state.dtype,
        device=device,
    )

    # Expand initial state to ensemble dimension
    current_unnorm_state = initial_unnorm_state.unsqueeze(1)  # (B, 1, H, lon, lat, C)
    current_day_year_time = day_year_time.unsqueeze(1)        # (B, 1, 2)

    trajectory[:, :, :num_history_steps] = current_unnorm_state

    for t in range(num_unroll_steps):
        # Expand ensemble only on the first step
        num_ens_step = num_ensemble_members if t == 0 else 1

        current_norm_state = (current_unnorm_state - norm_stats.state_mean) / norm_stats.state_std
        with torch.amp.autocast('cuda', dtype=dtype):
            norm_next_state = model(current_norm_state, current_day_year_time, num_ens_step)

        next_unnorm_state = norm_next_state * norm_stats.state_std + norm_stats.state_mean
        current_day_year_time = current_day_year_time + day_year_delta.unsqueeze(0).unsqueeze(0).expand(
            batch_size, num_ens_step, -1
        )

        trajectory[:, :, t + num_history_steps] = next_unnorm_state
        current_unnorm_state = trajectory[:, :, t + 1 : t + 1 + num_history_steps]

    return trajectory


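A shape-only sketch of the rollout's ring buffer, with an identity step standing in for the real network (the tiny dimensions and the stand-in are hypothetical; the actual model call is the autocast block above):

```python
import torch

B, H, LON, LAT, C, STEPS, MEMBERS = 1, 2, 4, 3, 5, 3, 2
state = torch.zeros(B, 1, H, LON, LAT, C)            # ensemble dim starts at 1
traj = torch.empty(B, MEMBERS, STEPS + H, LON, LAT, C)
traj[:, :, :H] = state  # broadcasts the single member across the ensemble

for t in range(STEPS):
    n_ens = MEMBERS if t == 0 else 1  # members are spawned on step 0 only
    # identity "model": predict the most recent frame again
    nxt = traj[:, :, t + H - 1]
    traj[:, :, t + H] = nxt
    # slide the H-frame history window forward by one step
    state = traj[:, :, t + 1 : t + 1 + H]
```

Reading the history window back out of `traj` avoids keeping a separate rolling copy of the last `H` states per member.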
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Mosaic 1.5° Weather Forecast Inference")
    parser.add_argument("--variant", type=str, required=True, choices=sorted(PRESETS.keys()),
                        help="Model variant: 'era5' (ERA5-only, 24h steps) or 'hres' (ERA5+HRES finetune, 6h steps)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint (.pt). Default: preset's default_checkpoint")
    parser.add_argument("--zarr", type=str, required=True,
                        help="Path or GCS URI to zarr store with ERA5/HRES data at 1.5°")
    parser.add_argument("--init-time", type=str, required=True,
                        help="Initialization time (ISO 8601), e.g. '2020-01-01T00:00'")
    parser.add_argument("--steps", type=int, default=10,
                        help="Number of forecast steps (each step = step_stride*6h; e.g. era5 step=24h, hres step=6h)")
    parser.add_argument("--members", type=int, default=1,
                        help="Number of ensemble members (default: 1)")
    parser.add_argument("--output", type=str, default="forecast.npz",
                        help="Output file path (default: forecast.npz)")
    parser.add_argument("--norm-stats", type=str, default=None,
                        help="Path to norm_stats .npz. Default: preset's default_norm_stats")
    parser.add_argument("--static-vars", type=str, default="static_vars.npz",
                        help="Path to static_vars.npz (default: static_vars.npz in current dir)")
    parser.add_argument("--k-neighbors", type=int, default=None,
                        help="Override preset's k_neighbors (advanced, for ablation only)")
    parser.add_argument("--no-compile", action="store_true",
                        help="Disable torch.compile (slower but easier to debug)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device (default: cuda)")
    args = parser.parse_args()

    preset = PRESETS[args.variant]
    if args.k_neighbors is not None and args.k_neighbors != preset.k_neighbors:
        from dataclasses import replace
        preset = replace(preset, k_neighbors=args.k_neighbors)
    checkpoint_path = args.checkpoint or preset.default_checkpoint
    norm_stats_path = args.norm_stats or preset.default_norm_stats
    print(f"Variant: {args.variant} "
          f"(step_stride={preset.step_stride}, num_history_steps={preset.num_history_steps}, "
          f"k_neighbors={preset.k_neighbors})")

    device = args.device
    torch.set_float32_matmul_precision('high')

    # Build variable list: 4 surface + 6*13 pressure-level = 82 channels
    variables = list(SL_VARS)
    for var in PL_VARS:
        for level in LEVELS:
            variables.append(f"{var}_{level}")
    print(f"Variables: {len(variables)} channels")

    # Load initial state from zarr
    print(f"Loading initial state from zarr: {args.zarr}")
    print(f" Init time: {args.init_time} (history: {preset.num_history_steps} x {preset.step_stride*6}h steps)")
    initial_state_np, (day_prog, year_prog), longitude, latitude = load_initial_state(
        args.zarr, args.init_time,
        num_history_steps=preset.num_history_steps,
        step_stride=preset.step_stride,
    )
    print(f" State shape: {initial_state_np.shape} (steps, lon, lat, channels)")

    # Build model and load checkpoint
    print(f"\nBuilding model and loading checkpoint: {checkpoint_path}")
    model, metadata = build_model(
        checkpoint_path=checkpoint_path,
        variables=variables,
        longitude=longitude,
        latitude=latitude,
        preset=preset,
        norm_stats_path=norm_stats_path,
        static_vars_path=args.static_vars,
        device=device,
    )
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Parameters: {num_params:.1f}M")

    # Optionally compile
    if not args.no_compile:
        print("Compiling model with torch.compile (reduce-overhead)...")
        unroll_fn = torch.compile(unroll_direct, mode='reduce-overhead')
    else:
        unroll_fn = unroll_direct

    # Prepare tensors
    initial_state = torch.from_numpy(initial_state_np).unsqueeze(0).to(device)  # (1, H, lon, lat, C)
    day_year_time = torch.tensor([[day_prog, year_prog]], dtype=torch.float32, device=device)  # (1, 2)
    norm_stats_d = metadata.norm_stats.to(device)
    day_year_delta_d = metadata.day_year_delta.to(device)

    # Run forecast
    total_hours = args.steps * preset.step_stride * 6
    print(f"\nRunning {args.steps}-step forecast ({total_hours}h) with {args.members} member(s)...")
    if device == 'cuda':
        torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        trajectory = unroll_fn(
            model=model,
            initial_unnorm_state=initial_state,
            day_year_time=day_year_time,
            day_year_delta=day_year_delta_d,
            norm_stats=norm_stats_d,
            num_unroll_steps=args.steps,
            num_ensemble_members=args.members,
            dtype=DTYPE,
        )

    if device == 'cuda':
        torch.cuda.synchronize()
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f" Peak GPU memory: {peak_gb:.1f} GB")

    # Extract forecast steps (skip history)
    forecasts = trajectory[0, :, preset.num_history_steps:].cpu().numpy()  # (members, steps, lon, lat, C)
    print(f" Forecast shape: {forecasts.shape}")

    # Save output
    lead_time_hours = np.arange(1, args.steps + 1) * 6 * preset.step_stride
    np.savez(
        args.output,
        forecasts=forecasts,
        variables=np.array(variables),
        lead_time_hours=lead_time_hours,
        init_time=np.str_(args.init_time),
        longitude=longitude,
        latitude=latitude,
    )
    print(f"\nSaved forecast to: {args.output}")
    print(f" Shape: forecasts {forecasts.shape} (members, steps, lon=240, lat=121, channels=82)")
    print(f" Lead times: {lead_time_hours[0]}h to {lead_time_hours[-1]}h")


if __name__ == "__main__":
    main()
mosaic.py
ADDED
|
@@ -0,0 +1,337 @@
"""
Mosaic: U-Net transformer with block-sparse attention for weather forecasting.

Architecture:
- Cross-attention interpolation between lon/lat and HEALPix grids
- Block-sparse attention (local block + compressed + top-k selection branches)
  arranged in a U-Net encoder–bottleneck–decoder
- Probabilistic training with noise injection
"""

import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
from dataclasses import dataclass
from torch.nn import RMSNorm

from utils import get_healpix_grid, rad_to_xyz
from primitives import (
    MosaicBlock as _MosaicBlock,
    CrossAttentionInterpolate,
    NoiseGenerator,
    HEALPixDownsample,
    HEALPixUpsample,
)


@dataclass
class StageConfig:
    """Configuration for a U-Net encoder/decoder stage."""
    nside: int
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    encoder_depth: int
    decoder_depth: int
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class BottleneckConfig:
    """Configuration for the U-Net bottleneck stage."""
    nside: int
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    depth: int
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class ModelConfig:
    """Configuration for the Mosaic model."""
    dim: int
    num_heads: int
    k_neighbors: int
    qk_norm: bool
    rope: bool
    rope_theta: int
    sparse_every: int
    variables: list[str]
    static_variables: list[str]
    qkv_compress_ratio: int
    cg_stage_cfgs: list[StageConfig]
    bottleneck_cfg: BottleneckConfig
    num_history_steps: int = 1
    noise_dim: int = 32
    ortho_init: bool = False
    rmsnorm_elementwise_affine: bool = True
    no_compression: bool = False


@dataclass
class _MergedStageConfig:
    """Merges ModelConfig and StageConfig for compatibility with MosaicBlock."""
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    gqa_ratio: int
    qkv_compress_ratio: int
    rope: bool
    rope_theta: int
    mlp_ratio: float
    noise_dim: int
    rmsnorm_elementwise_affine: bool


def _merge_configs(config: ModelConfig, stage_cfg) -> _MergedStageConfig:
    return _MergedStageConfig(
        dim=stage_cfg.dim,
        num_heads=stage_cfg.num_heads,
        block_attn_size=stage_cfg.block_attn_size,
        sparse_block_size=stage_cfg.sparse_block_size,
        sparse_block_count=stage_cfg.sparse_block_count,
        gqa_ratio=stage_cfg.gqa_ratio,
        qkv_compress_ratio=config.qkv_compress_ratio,
        rope=config.rope,
        rope_theta=config.rope_theta,
        mlp_ratio=stage_cfg.mlp_ratio,
        noise_dim=config.noise_dim,
        rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
    )


def _make_mosaic_block(config: ModelConfig, stage_cfg, block_attn_only: bool) -> _MosaicBlock:
    return _MosaicBlock(_merge_configs(config, stage_cfg), block_attn_only, no_compression=config.no_compression)


class UNetStage(nn.Module):
    def __init__(self, config, stage_cfg, depth):
        super().__init__()
        self.nside = stage_cfg.nside
        self.blocks = nn.ModuleList([
            _make_mosaic_block(
                config=config,
                stage_cfg=stage_cfg,
                block_attn_only=(config.sparse_every <= 0) or (i % config.sparse_every != 0),
            )
            for i in range(depth)
        ])

    def forward(self, x, z=None):
        for block in self.blocks:
            x = block(x, z)
        return x


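The `block_attn_only` flag above alternates the full sparse-attention path with cheaper block-attention-only blocks; for a hypothetical depth-4 stage with `sparse_every=2`:

```python
sparse_every, depth = 2, 4
# A block keeps only local block attention unless its index is a multiple of
# sparse_every (sparse_every <= 0 disables the sparse branches everywhere).
block_attn_only = [(sparse_every <= 0) or (i % sparse_every != 0)
                   for i in range(depth)]
# → [False, True, False, True]: sparse branches run on every other block
```

With `sparse_every=1` (as used by `build_model` in inference.py) every block gets the full sparse path.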
class Transformer(nn.Module):
    """U-Net style Transformer for weather forecasting on HEALPix grids."""

    space_dim = 3
    time_dim = 4

    def __init__(self, config: ModelConfig, seed: int = 42):
        super().__init__()

        self.config = config
        self.nside = config.cg_stage_cfgs[0].nside
        self.noise_dim = config.noise_dim

        initial_dim = config.dim
        feature_dim = (len(config.variables) * config.num_history_steps
                       + len(config.static_variables) + self.space_dim + self.time_dim)

        if self.noise_dim > 0:
            self.noise_generator = NoiseGenerator(self.noise_dim, seed)

        self.preprocess = nn.Sequential(
            nn.Linear(feature_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
            nn.SiLU(),
            nn.Linear(initial_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
        )

        self.interp_to_hp = CrossAttentionInterpolate(config)
        self.interp_to_ll = CrossAttentionInterpolate(config)

        self.encoder_stages = nn.ModuleList()
        self.downsample_layers = nn.ModuleList()

        all_stages = [*config.cg_stage_cfgs, config.bottleneck_cfg]

        for i in range(len(config.cg_stage_cfgs)):
            current_stage = all_stages[i]
            next_stage = all_stages[i + 1]

            self.encoder_stages.append(UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.encoder_depth))
            self.downsample_layers.append(
                HEALPixDownsample(
                    in_dim=current_stage.dim,
                    out_dim=next_stage.dim,
                    nside_before=current_stage.nside,
                    nside_after=next_stage.nside,
                    rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
                )
            )

        self.bottleneck = UNetStage(config=config, stage_cfg=config.bottleneck_cfg, depth=config.bottleneck_cfg.depth)

        self.decoder_stages = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()

        for i in reversed(range(len(config.cg_stage_cfgs))):
            prev_stage = all_stages[i + 1]
            current_stage = all_stages[i]

            self.upsample_layers.append(
                HEALPixUpsample(
                    in_dim=prev_stage.dim,
                    out_dim=current_stage.dim,
                    nside_before=prev_stage.nside,
                    nside_after=current_stage.nside,
                    rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
                )
            )
            self.decoder_stages.append(UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.decoder_depth))

        self.norm_before_interp_ll = RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine)

        self.postprocess = nn.Sequential(
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
            nn.Linear(initial_dim, initial_dim, bias=False),
            nn.SiLU(),
            nn.Linear(initial_dim, len(config.variables), bias=False),
        )

        self.apply(self._initialize_weights)
        self._zero_init_residual_layers()
        self.initialize_rope()

    def _initialize_weights(self, module):
        if module is self:
            return
        ortho_init = self.config.ortho_init

        if isinstance(module, nn.Linear):
            fan_in, fan_out = module.weight.size(1), module.weight.size(0)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            if ortho_init:
                nn.init.orthogonal_(module.weight)
                module.weight.data.mul_(std)
            else:
                nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _zero_init_residual_layers(self):
        ortho_init = self.config.ortho_init

        for stage in [*self.encoder_stages, self.bottleneck, *self.decoder_stages]:
            for block in stage.blocks:
                if ortho_init:
                    nn.init.orthogonal_(block.attention.to_o.weight)
                    block.attention.to_o.weight.data.mul_(0.01)
                    nn.init.orthogonal_(block.ffn.w2.weight)
                    block.ffn.w2.weight.data.mul_(0.01)
                else:
                    nn.init.normal_(block.attention.to_o.weight, mean=0.0, std=0.01)
                    nn.init.normal_(block.ffn.w2.weight, mean=0.0, std=0.01)

                if self.noise_dim > 0:
                    nn.init.normal_(block.ffn.noise_bias.weight, mean=0.0, std=0.01)

        for upsample in self.upsample_layers:
            if ortho_init:
                nn.init.orthogonal_(upsample.proj_x.weight)
                upsample.proj_x.weight.data.mul_(0.01)
                nn.init.orthogonal_(upsample.proj_pos.weight)
                upsample.proj_pos.weight.data.mul_(0.01)
            else:
                nn.init.normal_(upsample.proj_x.weight, mean=0.0, std=0.01)
                nn.init.normal_(upsample.proj_pos.weight, mean=0.0, std=0.01)

        if self.noise_dim > 0:
            nn.init.normal_(self.noise_generator.to_noise.weight, mean=0.0, std=0.01)

    def initialize_rope(self):
        if not self.config.rope:
            return
        for stage in [*self.encoder_stages, self.bottleneck, *self.decoder_stages]:
            hp_grid = get_healpix_grid(stage.nside)
            for block in stage.blocks:
                if block.attention.q_rope is not None:
                    block.attention.q_rope.initialize_rope(hp_grid)
                    block.attention.k_rope.initialize_rope(hp_grid)

    def initialize_interpolation(self, longitude: torch.Tensor, latitude: torch.Tensor):
        ll_grid_rad = torch.deg2rad(torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1).reshape(-1, 2))
        hp_grid_rad = torch.deg2rad(get_healpix_grid(self.nside)).to(longitude.device)
        self.interp_to_hp.initialize_interpolation_scheme(ll_grid_rad, hp_grid_rad)
        self.interp_to_ll.initialize_interpolation_scheme(hp_grid_rad, ll_grid_rad)

    @torch.no_grad()
    def initialize_static_vars(self, static_vars: torch.Tensor, longitude: torch.Tensor, latitude: torch.Tensor):
        ll_grid_rad = torch.deg2rad(torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1))
        ll_grid_xyz = rad_to_xyz(ll_grid_rad)
        static_vars = torch.concat([static_vars, ll_grid_xyz], dim=-1)
        static_vars_mean = static_vars.mean(dim=(0, 1), keepdim=True)
        static_vars_std = static_vars.std(dim=(0, 1), keepdim=True) + 1e-6
        static_vars_norm = (static_vars - static_vars_mean) / static_vars_std
        static_vars = rearrange(static_vars_norm, 'lon lat c -> (lon lat) 1 c').contiguous()
        self.register_buffer('static_vars', static_vars, persistent=True)

    @torch.no_grad()
    def time_embedding(self, day_year_time: torch.Tensor):
        day = day_year_time[:, 0:1]
        year = day_year_time[:, 1:2]
        day_sin = torch.sin(2 * math.pi * day)
        day_cos = torch.cos(2 * math.pi * day)
        year_sin = torch.sin(2 * math.pi * year)
        year_cos = torch.cos(2 * math.pi * year)
        return torch.cat([day_sin, day_cos, year_sin, year_cos], dim=-1)

def forward(self, x: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
|
| 300 |
+
b, n, _, lon, lat, _ = x.shape
|
| 301 |
+
batch_size = b * num_noise_samples * n
|
| 302 |
+
|
| 303 |
+
if self.noise_dim > 0:
|
| 304 |
+
z = self.noise_generator(batch_size, x.device, x.dtype)
|
| 305 |
+
else:
|
| 306 |
+
z = None
|
| 307 |
+
|
| 308 |
+
x = repeat(x, 'b n t lon lat c -> (lon lat) (b s n) (t c)', s=num_noise_samples)
|
| 309 |
+
day_year_time = repeat(day_year_time, 'b n d -> (b s n) d', s=num_noise_samples)
|
| 310 |
+
|
| 311 |
+
x = torch.cat([
|
| 312 |
+
x,
|
| 313 |
+
self.static_vars.expand(-1, batch_size, -1),
|
| 314 |
+
self.time_embedding(day_year_time).unsqueeze(0).expand(x.shape[0], -1, -1)
|
| 315 |
+
], dim=-1)
|
| 316 |
+
|
| 317 |
+
x = self.preprocess(x)
|
| 318 |
+
x = self.interp_to_hp(x)
|
| 319 |
+
|
| 320 |
+
skip_connections = []
|
| 321 |
+
for encoder_stage, downsample in zip(self.encoder_stages, self.downsample_layers):
|
| 322 |
+
x = encoder_stage(x, z)
|
| 323 |
+
skip_connections.append(x)
|
| 324 |
+
x = downsample(x)
|
| 325 |
+
|
| 326 |
+
x = self.bottleneck(x, z)
|
| 327 |
+
|
| 328 |
+
for decoder_stage, upsample, skip in zip(self.decoder_stages, self.upsample_layers, reversed(skip_connections)):
|
| 329 |
+
x = upsample(x, skip)
|
| 330 |
+
x = decoder_stage(x, z)
|
| 331 |
+
|
| 332 |
+
x = self.norm_before_interp_ll(x)
|
| 333 |
+
x = self.interp_to_ll(x)
|
| 334 |
+
x = self.postprocess(x)
|
| 335 |
+
|
| 336 |
+
x = rearrange(x, '(lon lat) (b n s) c -> b (n s) lon lat c', lon=lon, lat=lat, b=b, s=num_noise_samples)
|
| 337 |
+
return x
|
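For reference, the cyclical encoding computed by `time_embedding` reduces to four values per timestamp: a sin/cos pair for the daily phase and one for the yearly phase. A scalar pure-Python sketch (the helper name `time_embedding_scalar` is hypothetical, shown only for illustration; the model's version operates on batched tensors):

```python
import math

def time_embedding_scalar(day_frac: float, year_frac: float) -> list:
    """Encode a timestamp's daily and yearly phase (fractions in [0, 1))
    as two sin/cos pairs, mirroring the tensor version above."""
    two_pi = 2 * math.pi
    return [
        math.sin(two_pi * day_frac), math.cos(two_pi * day_frac),
        math.sin(two_pi * year_frac), math.cos(two_pi * year_frac),
    ]

print(time_embedding_scalar(0.0, 0.0))  # → [0.0, 1.0, 0.0, 1.0]
```

Encoding each cycle as a sin/cos pair keeps the features continuous across the wrap-around (e.g. Dec 31 23:59 and Jan 1 00:00 map to nearly identical values), which a raw fraction would not.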
ops.py
ADDED
@@ -0,0 +1,566 @@
import torch
import triton
import triton.language as tl


def get_autotuning_configs(q_tile_sizes: list):
    """Generate autotuning configurations optimized for H100."""
    warps = [4, 8]
    stages = [2, 3]

    return [
        triton.Config({'q_tile_size': t}, num_warps=w, num_stages=s)
        for t in q_tile_sizes
        for w in warps
        for s in stages
    ]


@triton.autotune(
    configs=get_autotuning_configs([64, 128]),
    key=['seq_len', 'feature_dim'],
)
@triton.jit
def mosaic_attn_fwd_kernel(
    q_ptr, k_ptr, v_ptr, output_ptr, lse_ptr, block_indices_ptr,
    softmax_scale: tl.constexpr,
    seq_len: tl.constexpr,
    num_kv_heads: tl.constexpr,
    num_q_heads: tl.constexpr,
    q_heads_per_kv_head: tl.constexpr,
    feature_dim: tl.constexpr,
    kv_block_size: tl.constexpr,
    num_kv_blocks_per_q_block: tl.constexpr,
    q_tile_size: tl.constexpr,
):
    """
    Sparse attention forward kernel:
    for each query tile (i.e. block chunk), for each query head, attend to a subset of key/value blocks.
    """
    LOG2_E: tl.constexpr = 1.44269504089

    q_tile_id = tl.program_id(0)
    q_head_id = tl.program_id(1)
    batch_kv_head_id = tl.program_id(2)

    batch_idx = batch_kv_head_id // num_kv_heads
    kv_head_idx = batch_kv_head_id % num_kv_heads
    q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id

    batch_offset = batch_idx * seq_len
    q_tile_start = q_tile_id * q_tile_size
    num_blocks_in_seq = seq_len // kv_block_size
    tiles_per_block = kv_block_size // q_tile_size
    q_block_id = q_tile_id // tiles_per_block

    block_indices_offset = (
        batch_idx * num_blocks_in_seq * num_kv_heads * num_kv_blocks_per_q_block +
        q_block_id * num_kv_heads * num_kv_blocks_per_q_block +
        kv_head_idx * num_kv_blocks_per_q_block
    )

    q_base_ptr = q_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim
    k_base_ptr = k_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim
    v_base_ptr = v_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim

    q_tile_ptr = tl.make_block_ptr(
        base=q_base_ptr,
        shape=(seq_len, feature_dim),
        strides=(num_q_heads * feature_dim, 1),
        offsets=(q_tile_start, 0),
        block_shape=(q_tile_size, feature_dim),
        order=(1, 0)
    )

    output_tile_ptr = tl.make_block_ptr(
        base=output_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim,
        shape=(seq_len, feature_dim),
        strides=(num_q_heads * feature_dim, 1),
        offsets=(q_tile_start, 0),
        block_shape=(q_tile_size, feature_dim),
        order=(1, 0)
    )

    lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads + tl.arange(0, q_tile_size) * num_q_heads + q_head_idx

    output_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32)
    max_scores = tl.full([q_tile_size], float('-inf'), dtype=tl.float32)
    sum_exp_scores = tl.zeros([q_tile_size], dtype=tl.float32)

    q_tile = tl.load(q_tile_ptr)
    q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16)

    for i in range(num_kv_blocks_per_q_block):
        kv_block_start = kv_block_size * tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32)

        k_block_ptr = tl.make_block_ptr(
            base=k_base_ptr,
            shape=(feature_dim, seq_len),
            strides=(1, num_kv_heads * feature_dim),
            offsets=(0, kv_block_start),
            block_shape=(feature_dim, kv_block_size),
            order=(1, 0)
        )

        v_block_ptr = tl.make_block_ptr(
            base=v_base_ptr,
            shape=(seq_len, feature_dim),
            strides=(num_kv_heads * feature_dim, 1),
            offsets=(kv_block_start, 0),
            block_shape=(kv_block_size, feature_dim),
            order=(1, 0)
        )

        k_block = tl.load(k_block_ptr).to(tl.bfloat16)
        v_block = tl.load(v_block_ptr).to(tl.bfloat16)

        attention_scores = tl.dot(q_tile, k_block)

        new_max = tl.max(attention_scores, axis=1)
        old_max = max_scores
        max_scores = tl.maximum(max_scores, new_max)
        rescale = tl.exp2(old_max - max_scores)
        attention_probs = tl.exp2(attention_scores - max_scores[:, None])
        sum_exp_scores = sum_exp_scores * rescale + tl.sum(attention_probs, axis=1)

        output_accum = output_accum * rescale[:, None]
        output_accum += tl.dot(attention_probs.to(tl.bfloat16), v_block)

    final_output = output_accum / sum_exp_scores[:, None]
    log_sum_exp = (max_scores + tl.log2(sum_exp_scores))

    tl.store(output_tile_ptr, final_output.to(q_ptr.dtype.element_ty))
    tl.store(lse_base_ptr, log_sum_exp.to(tl.float32))


def mosaic_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_size: int,
    softmax_scale: float,
):
    batch_size, seq_len, num_kv_heads, feature_dim = k.shape
    num_q_heads = q.shape[2]
    num_kv_blocks_per_q_block = block_indices.shape[-1]
    q_heads_per_kv_head = num_q_heads // num_kv_heads

    output = torch.empty(batch_size, seq_len, num_q_heads, feature_dim, dtype=v.dtype, device=q.device)
    lse = torch.empty(batch_size, seq_len, num_q_heads, dtype=torch.float32, device=q.device)

    grid = lambda META: (
        triton.cdiv(seq_len, META['q_tile_size']),
        q_heads_per_kv_head,
        batch_size * num_kv_heads
    )

    mosaic_attn_fwd_kernel[grid](
        q_ptr=q,
        k_ptr=k,
        v_ptr=v,
        output_ptr=output,
        lse_ptr=lse,
        block_indices_ptr=block_indices,
        softmax_scale=softmax_scale,
        seq_len=seq_len,
        num_kv_heads=num_kv_heads,
        num_q_heads=num_q_heads,
        q_heads_per_kv_head=q_heads_per_kv_head,
        feature_dim=feature_dim,
        kv_block_size=block_size,
        num_kv_blocks_per_q_block=num_kv_blocks_per_q_block,
    )

    return output, lse


@triton.autotune(
    configs=get_autotuning_configs([64, 128]),
    key=['seq_len', 'feature_dim'],
)
@triton.jit
def mosaic_attn_bwd_q_kernel(
    q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr, grad_o_ptr, grad_q_ptr, block_indices_ptr,
    softmax_scale: tl.constexpr,
    seq_len: tl.constexpr,
    num_kv_heads: tl.constexpr,
    num_q_heads: tl.constexpr,
    q_heads_per_kv_head: tl.constexpr,
    feature_dim: tl.constexpr,
    kv_block_size: tl.constexpr,
    num_kv_blocks_per_q_block: tl.constexpr,
    q_tile_size: tl.constexpr,
):
    LOG2_E: tl.constexpr = 1.44269504089
    LN_2: tl.constexpr = 0.69314718056

    q_tile_id = tl.program_id(0)
    q_head_id = tl.program_id(1)
    batch_kv_head_id = tl.program_id(2)

    batch_idx = batch_kv_head_id // num_kv_heads
    kv_head_idx = batch_kv_head_id % num_kv_heads
    q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id

    batch_offset = batch_idx * seq_len
    q_tile_start = q_tile_id * q_tile_size
    tiles_per_block = kv_block_size // q_tile_size
    q_block_id = q_tile_id // tiles_per_block
    num_q_blocks = seq_len // kv_block_size

    block_indices_offset = (
        batch_idx * num_q_blocks * num_kv_heads * num_kv_blocks_per_q_block +
        q_block_id * num_kv_heads * num_kv_blocks_per_q_block +
        kv_head_idx * num_kv_blocks_per_q_block
    )

    q_offsets = (
        tl.arange(0, q_tile_size)[:, None] * num_q_heads * feature_dim +
        q_head_idx * feature_dim +
        tl.arange(0, feature_dim)[None, :]
    )

    lse_offsets = tl.arange(0, q_tile_size) * num_q_heads + q_head_idx

    q_base_ptr = q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim
    grad_o_base_ptr = grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim
    delta_base_ptr = delta_ptr + (batch_offset + q_tile_start) * num_q_heads
    lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads
    grad_q_base_ptr = grad_q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim

    grad_q_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32)

    q_tile = tl.load(q_base_ptr + q_offsets)
    q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16)

    grad_o_tile = tl.load(grad_o_base_ptr + q_offsets).to(tl.bfloat16)
    delta_vals = tl.load(delta_base_ptr + lse_offsets)
    lse_vals = tl.load(lse_base_ptr + lse_offsets).to(tl.float32)

    for i in range(num_kv_blocks_per_q_block):
        kv_block_idx = tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32)

        k_block_ptr = tl.make_block_ptr(
            base=k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
            shape=(feature_dim, seq_len),
            strides=(1, num_kv_heads * feature_dim),
            offsets=(0, kv_block_idx * kv_block_size),
            block_shape=(feature_dim, kv_block_size),
            order=(0, 1)
        )

        v_block_ptr = tl.make_block_ptr(
            base=v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
            shape=(feature_dim, seq_len),
            strides=(1, num_kv_heads * feature_dim),
            offsets=(0, kv_block_idx * kv_block_size),
            block_shape=(feature_dim, kv_block_size),
            order=(0, 1)
        )

        k_block = tl.load(k_block_ptr).to(tl.bfloat16)
        v_block = tl.load(v_block_ptr).to(tl.bfloat16)

        attention_scores = tl.dot(q_tile, k_block)
        attention_probs = tl.exp2(attention_scores - lse_vals[:, None]) * LN_2

        grad_times_v = tl.dot(grad_o_tile, v_block)
        grad_scores = attention_probs * (grad_times_v - delta_vals[:, None])
        grad_q_accum += tl.dot(grad_scores.to(tl.bfloat16), tl.trans(k_block.to(tl.bfloat16)))

    grad_q_accum = grad_q_accum * softmax_scale * LOG2_E
    tl.store(grad_q_base_ptr + q_offsets, grad_q_accum.to(q_ptr.dtype.element_ty))


@torch.compile
@torch.no_grad()
def mosaic_block_mask(
    block_indices: torch.LongTensor,
):
    batch_size, num_blocks, num_heads, _ = block_indices.shape

    block_mask = torch.zeros(
        batch_size, num_blocks, num_heads, num_blocks,
        dtype=torch.bool, device=block_indices.device
    )

    batch_idx = torch.arange(batch_size, device=block_indices.device)[:, None, None, None]
    q_block_idx = torch.arange(num_blocks, device=block_indices.device)[None, :, None, None]
    head_idx = torch.arange(num_heads, device=block_indices.device)[None, None, :, None]

    block_mask[batch_idx, q_block_idx, head_idx, block_indices] = True

    block_mask_transposed = block_mask.permute(0, 2, 3, 1).contiguous()

    return block_mask_transposed

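`mosaic_block_mask` scatters the per-query-block KV indices into a dense boolean mask and transposes it, so the dK/dV kernel can scan, for a fixed KV block, which query blocks select it. A single-batch, single-head pure-Python sketch of the same scatter-and-transpose (illustrative helper, not part of this repo):

```python
def block_mask_from_indices(block_indices, num_blocks):
    """block_indices[q] lists the KV blocks that query block q attends to.
    Returns the transposed mask indexed as [kv_block][q_block]."""
    mask = [[False] * num_blocks for _ in range(num_blocks)]  # [q_block][kv_block]
    for q_block, kv_list in enumerate(block_indices):
        for kv_block in kv_list:
            mask[q_block][kv_block] = True
    # Transpose: rows become KV blocks, columns query blocks.
    return [list(row) for row in zip(*mask)]

# Query block 0 attends to KV blocks {0, 2}; query block 1 to {1, 2}; etc.
indices = [[0, 2], [1, 2], [2, 3], [3, 0]]
mask_t = block_mask_from_indices(indices, 4)
```

Row `mask_t[kv]` is exactly the `is_connected` sequence the dK/dV kernel loads as it loops over query blocks for that KV block.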
@triton.autotune(
    configs=get_autotuning_configs([16, 32]),
    key=['seq_len', 'feature_dim'],
)
@triton.jit
def mosaic_attn_bwd_kv_kernel(
    q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr,
    grad_o_ptr, grad_k_ptr, grad_v_ptr,
    block_mask_ptr,
    softmax_scale: tl.constexpr,
    seq_len: tl.constexpr,
    num_kv_heads: tl.constexpr,
    num_q_heads: tl.constexpr,
    q_heads_per_kv_head: tl.constexpr,
    feature_dim: tl.constexpr,
    kv_block_size: tl.constexpr,
    q_tile_size: tl.constexpr,
):
    LOG2_E: tl.constexpr = 1.44269504089
    LN_2: tl.constexpr = 0.69314718056

    kv_block_id = tl.program_id(0)
    batch_kv_head_id = tl.program_id(1)

    batch_idx = batch_kv_head_id // num_kv_heads
    kv_head_idx = batch_kv_head_id % num_kv_heads
    batch_offset = batch_idx * seq_len

    num_blocks_in_seq = seq_len // kv_block_size
    tiles_per_block = kv_block_size // q_tile_size

    fine_mask_start = (
        batch_idx * num_kv_heads * num_blocks_in_seq * num_blocks_in_seq +
        kv_head_idx * num_blocks_in_seq * num_blocks_in_seq +
        kv_block_id * num_blocks_in_seq
    )

    k_block_ptr = tl.make_block_ptr(
        k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
    )

    v_block_ptr = tl.make_block_ptr(
        v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
    )

    grad_k_ptr = tl.make_block_ptr(
        grad_k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
    )

    grad_v_ptr = tl.make_block_ptr(
        grad_v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
    )

    k_block = tl.load(k_block_ptr).to(tl.bfloat16)
    v_block = tl.load(v_block_ptr).to(tl.bfloat16)

    grad_k_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32)
    grad_v_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32)

    for q_block_id in range(num_blocks_in_seq):
        is_connected = tl.load(block_mask_ptr + fine_mask_start + q_block_id)

        if is_connected:
            for tile_in_block in range(tiles_per_block):
                tile_idx = q_block_id * tiles_per_block + tile_in_block
                q_tile_start = tile_idx * q_tile_size

                q_tile_ptr = tl.make_block_ptr(
                    base=q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim,
                    shape=(q_tile_size, num_q_heads, feature_dim),
                    strides=(num_q_heads * feature_dim, feature_dim, 1),
                    offsets=(0, kv_head_idx * q_heads_per_kv_head, 0),
                    block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim),
                    order=(0, 1, 2),
                )

                grad_o_tile_ptr = tl.make_block_ptr(
                    base=grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim,
                    shape=(q_tile_size, num_q_heads, feature_dim),
                    strides=(num_q_heads * feature_dim, feature_dim, 1),
                    offsets=(0, kv_head_idx * q_heads_per_kv_head, 0),
                    block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim),
                    order=(0, 1, 2),
                )

                lse_tile_ptr = tl.make_block_ptr(
                    base=lse_ptr + (batch_offset + q_tile_start) * num_q_heads,
                    shape=(q_tile_size, num_q_heads),
                    strides=(num_q_heads, 1),
                    offsets=(0, kv_head_idx * q_heads_per_kv_head),
                    block_shape=(q_tile_size, q_heads_per_kv_head),
                    order=(1, 0),
                )

                delta_tile_ptr = tl.make_block_ptr(
                    base=delta_ptr + (batch_offset + q_tile_start) * num_q_heads,
                    shape=(q_tile_size, num_q_heads),
                    strides=(num_q_heads, 1),
                    offsets=(0, kv_head_idx * q_heads_per_kv_head),
                    block_shape=(q_tile_size, q_heads_per_kv_head),
                    order=(1, 0),
                )

                q_tile = tl.load(q_tile_ptr) * softmax_scale * LOG2_E
                q_tile = tl.reshape(q_tile, (q_tile_size * q_heads_per_kv_head, feature_dim))
                q_tile = q_tile.to(tl.bfloat16)

                grad_o_block = tl.load(grad_o_tile_ptr)
                grad_o_block = tl.reshape(grad_o_block, (q_tile_size * q_heads_per_kv_head, feature_dim))
                grad_o_block = grad_o_block.to(tl.bfloat16)

                lse_vals = tl.load(lse_tile_ptr)
                lse_vals = tl.reshape(lse_vals, (q_tile_size * q_heads_per_kv_head,))

                delta_vals = tl.load(delta_tile_ptr)
                delta_vals = tl.reshape(delta_vals, (q_tile_size * q_heads_per_kv_head,))

                attention_scores = tl.dot(k_block, tl.trans(q_tile))
                attention_probs = tl.exp2(attention_scores - lse_vals[None, :])
                grad_v_accum += tl.dot(attention_probs.to(tl.bfloat16), grad_o_block)
                grad_times_v = tl.dot(v_block, tl.trans(grad_o_block))
                grad_scores = attention_probs * (grad_times_v - delta_vals[None, :]) * LN_2
                grad_k_accum += tl.dot(grad_scores.to(tl.bfloat16), q_tile)

    tl.store(grad_k_ptr, grad_k_accum.to(grad_k_ptr.dtype.element_ty))
    tl.store(grad_v_ptr, grad_v_accum.to(grad_v_ptr.dtype.element_ty))


def mosaic_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    output: torch.Tensor,
    lse: torch.Tensor,
    grad_o: torch.Tensor,
    softmax_scale: float,
    block_indices: torch.LongTensor,
    block_size: int,
):
    batch_size, seq_len, num_kv_heads, feature_dim = k.shape
    num_q_heads = q.shape[2]
    num_kv_blocks_per_q_block = block_indices.shape[-1]
    q_heads_per_kv_head = num_q_heads // num_kv_heads
    num_blocks_in_seq = seq_len // block_size

    grad_q = torch.empty_like(q)
    grad_k = torch.empty_like(k)
    grad_v = torch.empty_like(v)

    block_mask = mosaic_block_mask(block_indices)

    delta = (output * grad_o).sum(dim=-1)

    grid_dq = lambda META: (
        triton.cdiv(seq_len, META['q_tile_size']),
        q_heads_per_kv_head,
        batch_size * num_kv_heads
    )

    mosaic_attn_bwd_q_kernel[grid_dq](
        q_ptr=q,
        k_ptr=k,
        v_ptr=v,
        lse_ptr=lse,
        delta_ptr=delta,
        grad_o_ptr=grad_o,
        grad_q_ptr=grad_q,
        block_indices_ptr=block_indices,
        softmax_scale=softmax_scale,
        seq_len=seq_len,
        num_kv_heads=num_kv_heads,
        num_q_heads=num_q_heads,
        q_heads_per_kv_head=q_heads_per_kv_head,
        feature_dim=feature_dim,
        kv_block_size=block_size,
        num_kv_blocks_per_q_block=num_kv_blocks_per_q_block,
    )

    grid_dkv = (num_blocks_in_seq, batch_size * num_kv_heads)

    mosaic_attn_bwd_kv_kernel[grid_dkv](
        q_ptr=q,
        k_ptr=k,
        v_ptr=v,
        lse_ptr=lse,
        delta_ptr=delta,
        grad_o_ptr=grad_o,
        grad_k_ptr=grad_k,
        grad_v_ptr=grad_v,
        block_mask_ptr=block_mask,
        softmax_scale=softmax_scale,
        seq_len=seq_len,
        num_kv_heads=num_kv_heads,
        num_q_heads=num_q_heads,
        q_heads_per_kv_head=q_heads_per_kv_head,
        feature_dim=feature_dim,
        kv_block_size=block_size,
    )

    return grad_q, grad_k, grad_v


class MosaicAttnFunction(torch.autograd.Function):

    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda')
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_indices: torch.Tensor,
        block_size: int,
        softmax_scale: float
    ):
        q, k, v, block_indices = map(lambda x: x.contiguous(), (q, k, v, block_indices))

        ctx.dtype = q.dtype

        output, lse = mosaic_attn_fwd(
            q=q, k=k, v=v,
            block_indices=block_indices,
            block_size=block_size,
            softmax_scale=softmax_scale,
        )

        ctx.save_for_backward(q, k, v, output, lse, block_indices)
        ctx.block_size = block_size
        ctx.softmax_scale = softmax_scale

        return output.to(q.dtype)

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_o: torch.Tensor
    ):
        q, k, v, output, lse, block_indices = ctx.saved_tensors
        grad_o = grad_o.contiguous()
        grad_q, grad_k, grad_v = mosaic_attn_bwd(
            q=q, k=k, v=v, output=output, lse=lse, grad_o=grad_o,
            softmax_scale=ctx.softmax_scale,
            block_indices=block_indices,
            block_size=ctx.block_size,
        )
        return grad_q, grad_k, grad_v, None, None, None


def mosaic_sparse_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_size: int,
    softmax_scale: float = None,
):
    softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    return MosaicAttnFunction.apply(q, k, v, block_indices, block_size, softmax_scale)
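All three kernels launch one program per (query head within a KV group, batch × KV head) pair and recover indices with the same divmod arithmetic (`batch_idx = pid // num_kv_heads`, `kv_head_idx = pid % num_kv_heads`, `q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id`). A pure-Python enumeration (hypothetical helper, for illustration only) showing that this covers every (batch, query head) pair exactly once, with each query head mapped to the KV head that serves it:

```python
def gqa_program_to_heads(num_kv_heads: int, num_q_heads: int, batch_size: int):
    """Enumerate the (program_id(1), program_id(2)) launch grid and recover
    (batch, kv_head, q_head) with the kernels' index arithmetic."""
    q_heads_per_kv_head = num_q_heads // num_kv_heads
    mapping = []
    for batch_kv_head_id in range(batch_size * num_kv_heads):   # program_id(2)
        for q_head_id in range(q_heads_per_kv_head):            # program_id(1)
            batch_idx = batch_kv_head_id // num_kv_heads
            kv_head_idx = batch_kv_head_id % num_kv_heads
            q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id
            mapping.append((batch_idx, kv_head_idx, q_head_idx))
    return mapping

# With 2 KV heads serving 4 query heads, every (batch, q_head) pair appears once.
m = gqa_program_to_heads(num_kv_heads=2, num_q_heads=4, batch_size=2)
```

Grouping the grid by KV head this way keeps all programs that share a K/V stripe adjacent, which is the usual motivation for this launch layout under grouped-query attention.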
primitives.py
ADDED
@@ -0,0 +1,359 @@
"""
Primitive building blocks for the Mosaic transformer.

Components:
- Block-sparse attention with learned strategy weighting (local block, compressed,
  and top-k selection branches combined with a learned gate)
- Rotary positional embeddings (RoPE) for 2D lon/lat
- Cross-attention interpolation between grids
- HEALPix spatial up/downsampling
- Conditional SwiGLU FFN with noise injection
"""

import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from torch.nn import RMSNorm

try:
    from flash_attn import flash_attn_func  # FlashAttention v2
except ImportError:
    import flash_attn_interface as fa  # FlashAttention v3
    flash_attn_func = fa.flash_attn_func

from utils import get_healpix_grid, get_neighbors, rad_to_xyz
from ops import mosaic_sparse_attn


def block_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int):
    batch_size = q.shape[0]
    q, k, v = map(lambda x: rearrange(x, 'b (nb bs) h d -> (b nb) bs h d', bs=block_size), (q, k, v))
    o_ba = flash_attn_func(q, k, v)
    return rearrange(o_ba, '(b nb) bs h d -> b (nb bs) h d', b=batch_size)


@torch.no_grad()
def attn_topk(q: torch.Tensor, k: torch.Tensor, block_count: int):
    Hq, Hk = q.shape[2], k.shape[2]
    G = Hq // Hk
    k = k.repeat_interleave(G, dim=2)

    scores = torch.matmul(
        rearrange(q, 'b t h d -> b h t d'),
        rearrange(k, 'b t h d -> b h d t')
    )

    if Hq != Hk:
        scores = reduce(scores, 'b (g h) t k -> b h t k', 'mean', g=G)

    scores = rearrange(scores, 'b h t k -> b t h k')
    top_indices = scores.topk(k=block_count, dim=-1, largest=True)[1]
    return top_indices


def mosaic_attn_func(
    q, k, v,
    weight_ba_cmp_slc,
    block_attn_size, sparse_block_size, sparse_block_count,
    block_attn_only, no_compression=False,
):
    o_ba = block_attention(q, k, v, block_attn_size)

    if block_attn_only:
        return o_ba

    q_cmp = reduce(q, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
    k_cmp = reduce(k, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)

    if no_compression:
        block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
        o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)
        w_ba = weight_ba_cmp_slc[0]
        w_slc = weight_ba_cmp_slc[2]
        w_sum = w_ba + w_slc + 1e-8
        return o_ba * (w_ba / w_sum) + o_slc * (w_slc / w_sum)

    v_cmp = reduce(v, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
    o_cmp = flash_attn_func(q_cmp, k_cmp, v_cmp)
    o_cmp = o_cmp.repeat_interleave(sparse_block_size, dim=1)

    if sparse_block_count == 0:
        w_ba = weight_ba_cmp_slc[0]
        w_cmp = weight_ba_cmp_slc[1]
        w_sum = w_ba + w_cmp + 1e-8
        return o_ba * (w_ba / w_sum) + o_cmp * (w_cmp / w_sum)

    block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
    o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)

    return o_ba * weight_ba_cmp_slc[0] + o_cmp * weight_ba_cmp_slc[1] + o_slc * weight_ba_cmp_slc[2]


class cSwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, noise_dim: int):
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.act_fn = nn.SiLU()

        if noise_dim > 0:
            self.noise_bias = nn.Linear(noise_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
        noise = self.noise_bias(z).unsqueeze(0) if z is not None else 0
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(self.act_fn(x1 + noise) * x3)


class RoPE(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.theta = theta

    def initialize_rope(self, positions):
        base_freqs = 1. / (self.theta ** (torch.arange(0, self.dim // 2, 2).float() / (self.dim // 2)))
        lon_pos = torch.deg2rad(positions[:, 0:1])
        lat_pos = torch.deg2rad(positions[:, 1:2])
        lon_freqs = torch.matmul(lon_pos, base_freqs.unsqueeze(0))
        lat_freqs = torch.matmul(lat_pos, base_freqs.unsqueeze(0))
        freqs = torch.cat([lon_freqs, lat_freqs], dim=-1)
        self.register_buffer('cos_freqs', freqs.cos().contiguous(), persistent=True)
        self.register_buffer('sin_freqs', freqs.sin().contiguous(), persistent=True)

    @staticmethod
    def rotate_half(x):
        x = rearrange(x, '... (d r) -> ... d r', r=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return rearrange(x, '... d r -> ... (d r)')

    def forward(self, x):
        cos = self.cos_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
        sin = self.sin_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
        return (x.float() * cos + self.rotate_half(x.float()) * sin).to(x.dtype)


class MosaicAttention(nn.Module):
    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
        super().__init__()
        self.block_attn_only = block_attn_only
        self.no_compression = no_compression
        self.block_attn_size = config.block_attn_size
        self.sparse_block_size = config.sparse_block_size
        self.sparse_block_count = config.sparse_block_count

        q_heads = config.num_heads
        gqa_ratio = config.gqa_ratio
        dim = config.dim
        qkv_compress_ratio = config.qkv_compress_ratio
        rope = config.rope
        rope_theta = config.rope_theta

        kv_heads = q_heads // gqa_ratio
        head_dim = int(dim // q_heads // qkv_compress_ratio)

        self.q_heads = q_heads
        self.kv_heads = kv_heads

        self.to_q = nn.Linear(dim, q_heads * head_dim, bias=False)
        self.to_k = nn.Linear(dim, kv_heads * head_dim, bias=False)
        self.to_v = nn.Linear(dim, kv_heads * head_dim, bias=False)
        self.to_o = nn.Linear(q_heads * head_dim, dim, bias=False)

        self.q_rope = RoPE(head_dim, rope_theta) if rope else None
        self.k_rope = RoPE(head_dim, rope_theta) if rope else None

        if block_attn_only:
            self.to_strategy_combine_mlp = None
        else:
            self.to_strategy_combine_mlp = nn.Linear(dim, 3 * q_heads, bias=False)

    def generate_strategy_weights(self, x):
        if self.block_attn_only:
            return [None, None, None]
        strategy_logits = self.to_strategy_combine_mlp(x)
        strategy_logits = rearrange(strategy_logits, 't b (h s) -> s b t h', h=self.q_heads)
        strategy_weights = torch.softmax(strategy_logits.float(), dim=0).type_as(x)
        strategy_weights = strategy_weights.unsqueeze(-1)
        return strategy_weights

    def forward(self, x):
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        strategy_weights = self.generate_strategy_weights(x)

        q = rearrange(q, 's b (h d) -> b s h d', h=self.q_heads)
        k = rearrange(k, 's b (h d) -> b s h d', h=self.kv_heads)
        v = rearrange(v, 's b (h d) -> b s h d', h=self.kv_heads)

        if self.q_rope is not None:
            q = self.q_rope(q)
            k = self.k_rope(k)

        output = mosaic_attn_func(
            q=q, k=k, v=v,
            weight_ba_cmp_slc=strategy_weights,
            block_attn_size=self.block_attn_size,
            sparse_block_size=self.sparse_block_size,
            sparse_block_count=self.sparse_block_count,
            block_attn_only=self.block_attn_only,
            no_compression=self.no_compression,
        )

        output = rearrange(output, 'b s h d -> s b (h d)')
        output = self.to_o(output)
        return output


class MosaicBlock(nn.Module):
    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
        super().__init__()
        dim = config.dim
        noise_dim = config.noise_dim
        mlp_ratio = config.mlp_ratio

        self.attention = MosaicAttention(config, block_attn_only, no_compression)
        self.norm1 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.norm2 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.ffn = cSwiGLU(dim, int(dim * mlp_ratio), noise_dim)

    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x), z)
        return x


class CrossAttentionInterpolate(nn.Module):
    space_dim = 3

    def __init__(self, config):
        super().__init__()
        self.k_neighbors = config.k_neighbors

        dim = config.dim
        num_heads = config.num_heads

        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.kv_norm = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.to_q = nn.Linear(self.space_dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim, bias=False)
        self.to_o = nn.Linear(dim, dim, bias=False)

    @torch.no_grad()
    def initialize_interpolation_scheme(self, pos_from_rad, pos_to_rad):
        neighbors_np = get_neighbors(pos_from_rad.cpu().numpy(), pos_to_rad.cpu().numpy(), k=self.k_neighbors)
        neighbors = torch.from_numpy(neighbors_np).long().to(pos_from_rad.device).contiguous()

        pos_to_xyz = rad_to_xyz(pos_to_rad)
        pos_from_xyz = rad_to_xyz(pos_from_rad)

        rel_pos_xyz = (pos_to_xyz.unsqueeze(1) - pos_from_xyz[neighbors]).contiguous()
        norm_rel_pos_xyz = torch.nn.functional.normalize(rel_pos_xyz, dim=-1).contiguous()

        self.register_buffer('neighbors', neighbors, persistent=True)
        self.register_buffer('rel_pos', norm_rel_pos_xyz, persistent=True)

    def forward(self, x_from: torch.Tensor):
        # Buffers only exist after initialize_interpolation_scheme() has run,
        # so use getattr to raise an informative error rather than AttributeError.
        if getattr(self, 'neighbors', None) is None or getattr(self, 'rel_pos', None) is None:
            raise ValueError("Interpolation scheme not initialized.")

        q = self.to_q(self.rel_pos)
        q = rearrange(q, 's k (h d) -> s k 1 h d', h=self.num_heads)

        x = self.kv_norm(x_from)

        kv = self.to_kv(x)
        kv = rearrange(kv, 's b (n h d) -> n s b h d', h=self.num_heads, n=2)

        k, v = kv[:, self.neighbors]

        attn_scores = (q * k).sum(dim=-1, keepdim=True) * self.scale
        attn_weights = torch.softmax(attn_scores, dim=1, dtype=torch.float32).type_as(k)
        out = (attn_weights * v).sum(dim=1)

        out = rearrange(out, 's b h d -> s b (h d)')
        out = self.to_o(out)
        return out


class NoiseGenerator(nn.Module):
    def __init__(self, noise_dim: int, seed: int):
        super().__init__()
        self.seed = seed
        self.to_noise = nn.Linear(noise_dim, noise_dim, bias=False)
        self.generator = None

    def forward(self, num_samples: int, device: torch.device, dtype: torch.dtype):
        if self.generator is None:
            self.generator = torch.Generator(device=device)
            self.generator.manual_seed(self.seed)

        noise = torch.randn((num_samples, self.to_noise.in_features),
                            generator=self.generator, device=device, dtype=dtype)
        noise = self.to_noise(noise)
        return noise


class HEALPixDownsample(nn.Module):
    space_dim: int = 3

    def __init__(self, in_dim, out_dim, nside_before, nside_after,
                 rmsnorm_elementwise_affine=True):
        super().__init__()
        self.factor = (nside_before // nside_after) ** 2

        self.proj_x = nn.Linear(self.factor * in_dim, out_dim, bias=False)
        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim, bias=False)
        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)

        hp_grid_fine_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_before)))
        hp_grid_coarse_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_after)))

        pos = rearrange(hp_grid_fine_xyz, '(n f) d -> n f d', f=self.factor)
        rel_pos = rearrange(pos - hp_grid_coarse_xyz[:, None], 'n f d -> n (f d)')
        rel_pos = (rel_pos - rel_pos.mean(dim=0, keepdim=True)) / (rel_pos.std(dim=0, keepdim=True) + 1e-6)

        self.register_buffer('rel_pos', rel_pos.contiguous(), persistent=True)

    def forward(self, x: torch.Tensor):
        x = rearrange(x, '(n f) b c -> n b (f c)', f=self.factor)
        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
        x = self.norm(x)
        return x


class HEALPixUpsample(nn.Module):
    space_dim: int = 3

    def __init__(self, in_dim, out_dim, nside_before, nside_after,
                 rmsnorm_elementwise_affine=True):
        super().__init__()
        self.factor = (nside_after // nside_before) ** 2

        self.proj_x = nn.Linear(in_dim, out_dim * self.factor, bias=False)
        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim * self.factor, bias=False)
        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)

        hp_grid_fine_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_after)))
        hp_grid_coarse_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_before)))

        children_pos_reshaped = rearrange(hp_grid_fine_xyz, '(n f) d -> n f d', f=self.factor)
        rel_pos = rearrange(children_pos_reshaped - hp_grid_coarse_xyz[:, None], 'n f d -> n (f d)')
        rel_pos = (rel_pos - rel_pos.mean(dim=0, keepdim=True)) / (rel_pos.std(dim=0, keepdim=True) + 1e-6)

        self.register_buffer('rel_pos', rel_pos.contiguous(), persistent=True)

    def forward(self, x: torch.Tensor, shortcut: torch.Tensor):
        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
        x = rearrange(x, 'n b (f d) -> (n f) b d', f=self.factor)
        x = x + shortcut
        x = self.norm(x)
        return x
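The local branch in `block_attention` above works by folding fixed-size blocks into the batch dimension, so a dense attention kernel only ever attends within one block. A minimal pure-PyTorch sketch of the same reshape trick, substituting `F.scaled_dot_product_attention` for `flash_attn_func` (which takes the same `(batch, seq, heads, head_dim)` layout), with `block_attention_ref` as a hypothetical reference name:

```python
import torch
import torch.nn.functional as F

def block_attention_ref(q, k, v, block_size):
    """Local block attention: tokens attend only within their own block.

    q, k, v: (batch, seq, heads, head_dim); seq must be divisible by block_size.
    """
    b, s, h, d = q.shape
    nb = s // block_size
    # Fold blocks into the batch dim, then move heads before the sequence dim,
    # since F.scaled_dot_product_attention expects (batch, heads, seq, dim).
    fold = lambda x: x.reshape(b * nb, block_size, h, d).transpose(1, 2)
    o = F.scaled_dot_product_attention(fold(q), fold(k), fold(v))
    # Undo the folding: (b*nb, h, block_size, d) -> (b, s, h, d)
    return o.transpose(1, 2).reshape(b, s, h, d)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 2, 4) for _ in range(3))
out = block_attention_ref(q, k, v, block_size=4)
print(out.shape)  # torch.Size([2, 8, 2, 4])

# With one block spanning the whole sequence, local equals full attention.
full = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
print(torch.allclose(block_attention_ref(q, k, v, block_size=8), full, atol=1e-5))  # True
```

The folding keeps the kernel's cost linear in sequence length (each of the `nb` blocks pays only `block_size**2`), which is why the compressed and top-k branches are needed to recover long-range information.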
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch>=2.0
einops>=0.6
healpy>=1.16
scikit-learn>=1.0
numpy>=1.24
zarr>=2.10
pandas>=1.5
triton>=2.0
flash-attn>=2.0
# For reading WeatherBench2 zarr stores directly from gs://
gcsfs>=2023.0
utils.py
ADDED
@@ -0,0 +1,40 @@
import torch
import numpy as np
import healpy as hp
from sklearn.neighbors import BallTree


def rad_to_xyz(lonlat: torch.Tensor):
    """Convert lon-lat (in radians) to unit-sphere xyz."""
    lon = lonlat[..., 0]
    lat = lonlat[..., 1]

    x = torch.cos(lat) * torch.cos(lon)
    y = torch.cos(lat) * torch.sin(lon)
    z = torch.sin(lat)

    return torch.stack([x, y, z], dim=-1)


def get_healpix_grid(nside: int) -> torch.Tensor:
    """Return HEALPix (lon, lat) coordinates in degrees, shape (npix, 2)."""
    indices = np.arange(hp.nside2npix(nside))
    theta, phi = hp.pix2ang(nside, indices, nest=True)

    phi = np.rad2deg(phi)
    theta = 90. - np.rad2deg(theta)

    phi = torch.from_numpy(phi)
    theta = torch.from_numpy(theta)

    return torch.stack((phi, theta), dim=-1).float()


def get_neighbors(pos_from: np.ndarray, pos_to: np.ndarray, k: int = 8) -> np.ndarray:
    """Build a BallTree over pos_from and return indices of the k nearest
    neighbors of each point in pos_to, using the haversine metric.

    Inputs are (lon, lat) in radians; columns are flipped to (lat, lon)
    because BallTree's haversine metric expects latitude first.
    """
    pos_from_rad = pos_from[:, ::-1]
    pos_to_rad = pos_to[:, ::-1]

    tree = BallTree(pos_from_rad, metric='haversine')
    _, neighbors = tree.query(pos_to_rad, k=k)
    return neighbors
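The lon/lat-to-Cartesian convention used by `rad_to_xyz` (and relied on by the interpolation and HEALPix modules) can be sanity-checked against three reference points. A standalone sketch that duplicates the function so it runs without the repo:

```python
import torch

def rad_to_xyz(lonlat):
    # (..., 2) lon/lat in radians -> (..., 3) points on the unit sphere.
    lon, lat = lonlat[..., 0], lonlat[..., 1]
    return torch.stack([torch.cos(lat) * torch.cos(lon),
                        torch.cos(lat) * torch.sin(lon),
                        torch.sin(lat)], dim=-1)

pts = torch.tensor([[0.0, 0.0],            # lon 0,  lat 0   -> +x axis
                    [torch.pi / 2, 0.0],   # lon 90E, lat 0  -> +y axis
                    [0.0, torch.pi / 2]])  # north pole      -> +z axis
xyz = rad_to_xyz(pts)
print(torch.allclose(xyz, torch.eye(3), atol=1e-6))  # True
# Every output lies on the unit sphere:
print(torch.allclose(xyz.norm(dim=-1), torch.ones(3), atol=1e-6))  # True
```

Note this is the (lon, lat) column order the repo uses throughout; `get_neighbors` flips to (lat, lon) only because BallTree's haversine metric requires it.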