maxxxzdn committed on
Commit
5f226eb
·
verified ·
1 Parent(s): 71ee02b

Initial release: Mosaic weather model (era5 + hres variants)

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures_weather/bsa.jpg filter=lfs diff=lfs merge=lfs -text
+ figures_weather/bsa_runtime.jpg filter=lfs diff=lfs merge=lfs -text
+ figures_weather/hurricane_tracking.jpg filter=lfs diff=lfs merge=lfs -text
+ figures_weather/main.jpg filter=lfs diff=lfs merge=lfs -text
+ figures_weather/results_hres.jpg filter=lfs diff=lfs merge=lfs -text
+ figures_weather/results_spectra_pareto.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.egg-info/
+ dist/
+ build/
+ .env
+ *.npz
+ checkpoints/
README.md ADDED
@@ -0,0 +1,218 @@
+ ---
+ license: cc-by-nc-4.0
+ library_name: pytorch
+ tags:
+ - weather
+ - weather-forecasting
+ - climate
+ - atmospheric-science
+ - sparse-attention
+ - transformer
+ - probabilistic-forecasting
+ ---
+
+ # Mosaic — Block-Sparse Attention for Weather Forecasting
+
+ **Mosaic** is a probabilistic weather forecasting model that operates on native-resolution grids via mesh-aligned block-sparse attention. At 1.5° resolution with 214M parameters, Mosaic matches or outperforms models trained on 6× finer resolution on key variables, and individual ensemble members exhibit near-perfect spectral alignment across all resolved frequencies. A 24-member, 10-day forecast takes under 12 s on a single H100 GPU.
+
+ > **(Sparse) Attention to the Details: Preserving Spectral Fidelity in ML-based Weather Forecasting Models** \
+ > Maksim Zhdanov, Ana Lucic, Max Welling, Jan-Willem van de Meent \
+ > *ICML 2026* · [arXiv:2604.16429](https://arxiv.org/abs/2604.16429) · [GitHub](https://github.com/maxxxzdn/mosaic)
+
+ ![Spectral fidelity and skill–speed Pareto](figures_weather/results_spectra_pareto.jpg)
+
+ ## TL;DR
+
+ Mosaic addresses two distinct failure modes behind spectral degradation in ML-based weather prediction:
+
+ 1. **Spectral damping** caused by deterministic training against ensemble means. Mosaic addresses this with learned functional perturbations that produce ensemble members that preserve realistic spectral variability.
+ 2. **High-frequency aliasing** caused by compressive encoding onto a coarse latent grid. Mosaic operates at native resolution via block-sparse attention before any coarsening, eliminating the compress-first bottleneck.
+
+ The block-sparse attention captures long-range dependencies at **linear** cost by sharing keys and values across spatially adjacent queries arranged on the HEALPix mesh. Each query block jointly selects which key blocks to attend to.
+
+ ## Published Variants
+
+ This repository ships **two trained variants**. They share the same Mosaic architecture and 82-channel variable set but differ in training data, time cadence, input history, neighbour count, and normalization statistics.
+
+ | Variant | Training data | Native step | Input history | k-neighbours | Suggested input zarr |
+ |---------|---------------|-------------|---------------|--------------|----------------------|
+ | `era5` | ERA5 reanalysis only | 24 h | 2 states (48 h) | 24 | WB2 ERA5 1.5° |
+ | `hres` | ERA5 pretrain + HRES finetune | 6 h | 4 states (24 h) | 20 | WB2 HRES-fc0 1.5° |
+
+ Choose `era5` when initializing from reanalysis (it matches the training distribution); choose `hres` when initializing from HRES analysis or a similar operational state.
+
+ ## Architecture
+
+ **Inputs.** 82 atmospheric channels at 1.5° equiangular resolution (240 lon × 121 lat = 29 040 points) plus 3 static channels and sinusoidal day/year time encodings.
+
+ **Backbone.** A U-Net of transformer blocks over the HEALPix mesh, where spatial neighbours occupy contiguous memory and queries can be grouped into hardware-aligned blocks:
+
+ | Stage | Nside | Hidden dim | Heads | Enc / Dec depth |
+ |------------|------:|-----------:|------:|----------------:|
+ | Stage 1 | 64 | 768 | 12 | 4 / 2 |
+ | Stage 2 | 32 | 1024 | 16 | 4 / 2 |
+ | Bottleneck | 16 | 1280 | 20 | 2 |
+
+ Grouped-Query Attention with ratio 4 (3, 4, and 5 KV heads across the three stages), 2D RoPE on (longitude, latitude), and additive noise injection in SwiGLU gates for ensemble generation. ~214M parameters total.
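The HEALPix mesh at resolution `Nside` has `12 * Nside**2` pixels, so the token count at each stage, and the number of 128-point sparse blocks it yields (matching `sparse_block_size=128` in `inference.py`), follows directly. A quick sanity check:

```python
# Token counts per U-Net stage on the HEALPix mesh (Npix = 12 * Nside^2),
# and how many 128-point sparse blocks each stage holds. Stage sizes are
# taken from the table above.
def healpix_npix(nside: int) -> int:
    return 12 * nside * nside

for name, nside in [("Stage 1", 64), ("Stage 2", 32), ("Bottleneck", 16)]:
    npix = healpix_npix(nside)
    print(f"{name}: nside={nside} -> {npix} tokens, {npix // 128} blocks of 128")
    # Stage 1: 49152 tokens / 384 blocks; Stage 2: 12288 / 96; Bottleneck: 3072 / 24
```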
+
+ ![Block-sparse attention for weather forecasting](figures_weather/bsa.jpg)
+
+ **Block-sparse attention.** Three branches combined by learned gates: (i) **compression** — block-to-block coarse attention captures broad synoptic patterns at $\mathcal{O}(N^2/b^2)$; (ii) **selection** — each query block top-k-selects fine-scale key blocks at $\mathcal{O}(Nnb)$; (iii) **local** — full attention inside each block at $\mathcal{O}(Nb)$. Spatially close points occupy contiguous memory on the HEALPix mesh, enabling coalesced GPU reads and hardware-aligned block computation. Implemented as a single Triton kernel; in practice up to **61.8× faster than dense attention** and **9.4× faster than NSA**.
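As an illustration of the selection branch, here is a minimal NumPy sketch in which each query block scores every key block through mean-pooled block representatives and keeps only its top-k. The shapes and scoring rule are illustrative simplifications, not the released Triton kernel:

```python
import numpy as np

# Toy sketch of the "selection" branch: each query block scores all key
# blocks via mean-pooled representatives, then attends only to its top-k.
rng = np.random.default_rng(0)
N, b, d, k = 1024, 128, 64, 4           # tokens, block size, channels, blocks kept

q = rng.standard_normal((N, d))
kv = rng.standard_normal((N, d))

nb = N // b                              # 8 blocks of 128 tokens
q_rep = q.reshape(nb, b, d).mean(axis=1)   # (nb, d) query-block representatives
k_rep = kv.reshape(nb, b, d).mean(axis=1)  # (nb, d) key-block representatives

scores = q_rep @ k_rep.T                 # (nb, nb) block-to-block affinities
topk = np.argsort(-scores, axis=-1)[:, :k]  # (nb, k) selected key blocks per query block

# Each query block now attends to k*b = 512 keys instead of all N = 1024.
selected = kv.reshape(nb, b, d)[topk]    # (nb, k, b, d)
selected = selected.reshape(nb, k * b, d)
print(selected.shape)                    # (8, 512, 64)
```

Because the key/value gather is shared by all `b` queries in a block, the cost grows linearly in `N` for fixed `k` and `b`.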
+
+ <p align="center">
+ <img src="figures_weather/healpix.jpg" alt="HEALPix mesh" width="48%">
+ <img src="figures_weather/bsa_runtime.jpg" alt="Runtime scaling" width="48%">
+ </p>
+
+ ### Variables (82 channels)
+
+ - **Surface (4):** `2m_temperature`, `10m_u_component_of_wind`, `10m_v_component_of_wind`, `mean_sea_level_pressure`
+ - **Pressure level (6 × 13 = 78):** `geopotential`, `specific_humidity`, `temperature`, `u_component_of_wind`, `v_component_of_wind`, `vertical_velocity` at levels [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] hPa
+ - **Static (3, conditioning only — not in output):** `geopotential_at_surface`, `land_sea_mask`, `soil_type`
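A hedged sketch of how the 82 channel names can be assembled from the lists in `config.py`, assuming surface channels come first and each pressure-level variable is expanded over all 13 levels with a `_{level}` suffix (the ordering that makes a name like `geopotential_500` resolvable):

```python
# Reconstruct the 82 channel names from config.py's lists. The ordering
# (surface first, then each pressure-level variable expanded over levels)
# is an assumption for illustration.
SL_VARS = ["2m_temperature", "10m_u_component_of_wind",
           "10m_v_component_of_wind", "mean_sea_level_pressure"]
PL_VARS = ["geopotential", "specific_humidity", "temperature",
           "u_component_of_wind", "v_component_of_wind", "vertical_velocity"]
LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

variables = SL_VARS + [f"{v}_{lev}" for v in PL_VARS for lev in LEVELS]
print(len(variables))                       # 82
print(variables.index("geopotential_500"))  # 11
```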
+
+ ## Results
+
+ The headline plot is at the top of this page: individual ensemble members preserve realistic kinetic-energy spectra (left, 1.5°; centre, 0.25°), and Mosaic sits on the favourable end of the skill–speed–memory Pareto front (right). All metrics are computed at 240 h lead time, over 720 initial conditions spanning the 2020 test year (1.5° benchmark) and a single 6 h forecast (0.25° benchmark).
+
+ On the 0.25° HRES benchmark, Mosaic competes with state-of-the-art 0.25° models despite operating at 1.5°:
+
+ ![HRES benchmark results](figures_weather/results_hres.jpg)
+
+ A case study of **Hurricane Ian (2022)** shows Mosaic's ensemble correctly bracketing the observed track 7 days ahead, with spread narrowing progressively as lead time decreases:
+
+ ![Hurricane Ian ensemble tracks](figures_weather/hurricane_tracking.jpg)
+
+ See the paper for full benchmark tables, CRPS curves, and spread-to-skill ratios.
+
+ ## Hardware Requirements
+
+ - **GPU:** any CUDA GPU; 16 GB is enough for a 1-member rollout, A100/H100 recommended for multi-member ensembles
+ - **Memory:** ~9 GB GPU RAM for a 1-member, 40-step (10-day) rollout in float16
+ - **Throughput:** 24-member, 10-day forecast in under 12 s on a single H100
+ - **CUDA:** 11.8+ with matching `triton` and `flash-attn` versions
+
+ ## Installation
+
+ ```bash
+ pip install -r requirements.txt
+ pip install flash-attn --no-build-isolation  # built separately; needs nvcc
+ ```
+
+ For reading data from Google Cloud Storage (WeatherBench2 zarr stores):
+
+ ```bash
+ pip install gcsfs
+ ```
+
+ ## Quick Start
+
+ ```bash
+ # ERA5 variant — 10-day forecast at 24 h resolution from ERA5 reanalysis
+ python inference.py --variant era5 \
+     --zarr gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr \
+     --init-time "2020-01-01T00:00" \
+     --steps 10 --members 1 \
+     --output forecast_era5.npz
+
+ # HRES variant — 10-day forecast at 6 h resolution from HRES initial conditions
+ python inference.py --variant hres \
+     --zarr gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr \
+     --init-time "2022-01-01T00:00" \
+     --steps 40 --members 1 \
+     --output forecast_hres.npz
+
+ # Ensemble forecast (16 members) — change --members
+ python inference.py --variant hres --zarr <...> --init-time "2020-06-15T12:00" \
+     --steps 40 --members 16 --output ensemble.npz
+ ```
+
+ `--variant` selects the checkpoint, normalization statistics, history length, time stride, and neighbour count automatically. Pass `--checkpoint` or `--norm-stats` to override the bundled defaults.
+
+ ## Output Format
+
+ The output `.npz` file contains:
+
+ | Array | Shape | Description |
+ |-------|-------|-------------|
+ | `forecasts` | `(members, steps, 240, 121, 82)` | Predicted states in physical units |
+ | `variables` | `(82,)` | Variable names |
+ | `lead_time_hours` | `(steps,)` | Lead times (era5: 24, 48, …; hres: 6, 12, …) |
+ | `init_time` | scalar | Initialization timestamp |
+ | `longitude` | `(240,)` | Longitude values (0° to 358.5°) |
+ | `latitude` | `(121,)` | Latitude values, South→North (−90° to 90°) |
+
+ ### Reading the output
+
+ ```python
+ import numpy as np
+
+ data = np.load("forecast_era5.npz", allow_pickle=True)
+ forecasts = data['forecasts']         # (members, steps, 240, 121, 82)
+ variables = list(data['variables'])   # ['2m_temperature', ...]
+ lead_hours = data['lead_time_hours']  # e.g. [24, 48, ..., 240]
+
+ # Extract 500 hPa geopotential at 24 h lead time
+ z500_idx = variables.index('geopotential_500')
+ i_24h = int(np.where(lead_hours == 24)[0][0])
+ z500_24h = forecasts[0, i_24h, :, :, z500_idx]  # (240, 121) lon × lat
+ ```
+
+ ## Input Data Format
+
+ The model accepts ERA5 or HRES data in zarr format at 1.5° resolution with:
+
+ - **Grid:** 240 lon × 121 lat, equiangular with poles
+ - **Time:** 6-hourly timesteps (integer hours since an arbitrary origin, parsed from the `units` zarr attribute, or as datetime64)
+ - **Variables:** all 10 atmospheric variables listed above; per-variable layout is auto-detected from `_ARRAY_DIMENSIONS` (either `(time, latitude, longitude)` or `(time, longitude, latitude)`), and latitude is flipped if stored North→South
+
+ Compatible zarr stores from [WeatherBench2](https://weatherbench2.readthedocs.io/):
+
+ ```
+ gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr
+ gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr
+ ```
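The latitude-orientation rule above can be sketched in a few lines; this mirrors the flip performed by `load_initial_state` in `inference.py`, with synthetic data standing in for a real store:

```python
import numpy as np

# Some stores hold latitude North->South; the model expects South->North,
# so the latitude axis (and the matching data axis) is flipped when needed.
latitude = np.linspace(90, -90, 121)   # N->S ordering, as some stores provide
field = np.arange(240 * 121, dtype=np.float64).reshape(240, 121)  # (lon, lat) slice

if latitude[0] > latitude[-1]:         # detect N->S ordering
    latitude = latitude[::-1].copy()
    field = field[:, ::-1].copy()

print(latitude[0], latitude[-1])       # -90.0 90.0
print(field.shape)                     # (240, 121)
```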
+
+ ## Repository Contents
+
+ | File | Description |
+ |------|-------------|
+ | `inference.py` | Main inference script (variant-aware via `--variant {era5,hres}`) |
+ | `mosaic.py` | Mosaic U-Net transformer |
+ | `primitives.py` | Attention blocks, RoPE, HEALPix sampling, noise generator |
+ | `ops.py` | Triton block-sparse attention kernels |
+ | `utils.py` | HEALPix grid utilities |
+ | `base.py` | `WeatherModel` wrapper |
+ | `config.py` | Variable / level definitions |
+ | `dataset.py` | Metadata dataclasses |
+ | `norm_stats_era5.npz` | Normalization statistics for the `era5` variant |
+ | `norm_stats_hres.npz` | Normalization statistics for the `hres` variant |
+ | `static_vars.npz` | Static fields (orography, land–sea mask, soil type) — shared between variants |
+ | `era5_best.pt` | Trained checkpoint, `era5` variant (~1.7 GB) |
+ | `hres_best.pt` | Trained checkpoint, `hres` variant (~1.7 GB) |
+ | `figures_weather/` | Figures from the paper |
+
+ ## Limitations
+
+ Mosaic operates at 1.5° (~166 km), which cannot resolve mesoscale phenomena such as tropical-cyclone inner-core structure or individual severe thunderstorms. The block-sparse attention is designed to scale linearly with sequence length, so finer grids (e.g. 0.25°, ~700k tokens) are a natural next step but are not part of this release.
+
+ ## Citation
+
+ If you use Mosaic, please cite:
+
+ ```bibtex
+ @inproceedings{zhdanov2026mosaic,
+     title     = {(Sparse) Attention to the Details: Preserving Spectral Fidelity in ML-based Weather Forecasting Models},
+     author    = {Zhdanov, Maksim and Lucic, Ana and Welling, Max and van de Meent, Jan-Willem},
+     booktitle = {Proceedings of the 43rd International Conference on Machine Learning (ICML)},
+     year      = {2026},
+     url       = {https://arxiv.org/abs/2604.16429}
+ }
+ ```
+
+ ## License
+
+ Released under [CC-BY-NC-4.0](https://creativecommons.org/licenses/by-nc/4.0/). Free for non-commercial research and educational use with attribution; commercial use requires a separate license. Underlying training data (ERA5, HRES) is subject to its own licensing terms set by ECMWF.
+
+ ## Acknowledgements
+
+ MZ acknowledges support from Microsoft Research AI4Science. JWvdM acknowledges support from the European Union Horizon Framework Programme (Grant agreement ID: 101120237). This work used the Dutch national e-infrastructure with the support of the SURF Cooperative using grant no. EINF-16923. Computations were partially performed using the UvA/FNWI HPC Facility.
base.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from torch import nn
+ from dataset import WeatherMetadata
+
+
+ class WeatherModel(nn.Module):
+     """Weather forecasting model wrapper."""
+
+     def __init__(self, model: nn.Module, weather_metadata: WeatherMetadata):
+         super().__init__()
+         self.model = model
+         self.model.initialize_static_vars(weather_metadata.static_data, weather_metadata.longitude, weather_metadata.latitude)
+         self.model.initialize_interpolation(weather_metadata.longitude, weather_metadata.latitude)
+         self.weather_metadata = weather_metadata
+
+     def forward(self, norm_state: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
+         return self.model(norm_state, day_year_time, num_noise_samples)
config.py ADDED
@@ -0,0 +1,25 @@
+ # Variable names and pressure levels for the Mosaic weather forecasting model.
+
+ SL_VARS: list[str] = [
+     "2m_temperature",
+     "10m_u_component_of_wind",
+     "10m_v_component_of_wind",
+     "mean_sea_level_pressure",
+ ]
+
+ PL_VARS: list[str] = [
+     "geopotential",
+     "specific_humidity",
+     "temperature",
+     "u_component_of_wind",
+     "v_component_of_wind",
+     "vertical_velocity",
+ ]
+
+ ST_VARS: list[str] = [
+     "geopotential_at_surface",
+     "land_sea_mask",
+     "soil_type",
+ ]
+
+ LEVELS: list[int] = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
dataset.py ADDED
@@ -0,0 +1,33 @@
+ """Metadata dataclasses for weather forecasting inference."""
+
+ import torch
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class NormalizationStats:
+     """Normalization statistics for state variables."""
+     state_mean: torch.Tensor
+     state_std: torch.Tensor
+     residual_mean: torch.Tensor
+     residual_std: torch.Tensor
+
+     def to(self, device) -> 'NormalizationStats':
+         return NormalizationStats(
+             state_mean=self.state_mean.to(device),
+             state_std=self.state_std.to(device),
+             residual_mean=self.residual_mean.to(device),
+             residual_std=self.residual_std.to(device),
+         )
+
+
+ @dataclass(frozen=True)
+ class WeatherMetadata:
+     """Metadata for the weather dataset."""
+     variables: list[str]
+     static_variables: list[str]
+     longitude: torch.Tensor
+     latitude: torch.Tensor
+     static_data: torch.Tensor
+     day_year_delta: torch.Tensor
+     norm_stats: NormalizationStats
era5_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:187df31a0caef3e934e4eb8a11506c9fea518ff0e02ce6d1f804fc6c8a78a940
+ size 1713557607
figures_weather/bsa.jpg ADDED

Git LFS Details

  • SHA256: 8da4601faed4832cc5339d78ef80f149ec841493146b91c9bdee7a623d83d0e9
  • Pointer size: 131 Bytes
  • Size of remote file: 302 kB
figures_weather/bsa_runtime.jpg ADDED

Git LFS Details

  • SHA256: 266b2199d88bf159212c3e26c232cc273258ce561a601f28d565e419d5e6299a
  • Pointer size: 131 Bytes
  • Size of remote file: 111 kB
figures_weather/healpix.jpg ADDED
figures_weather/hurricane_tracking.jpg ADDED

Git LFS Details

  • SHA256: c11d5a0477d2b5f3ad591e6077bb09176aaa6abbe7b4bb5ca62f42c06315ef98
  • Pointer size: 131 Bytes
  • Size of remote file: 268 kB
figures_weather/main.jpg ADDED

Git LFS Details

  • SHA256: 45fb47af2cef25871e5cad8f57b77a55dec907fdb3e42de5076e50b3f67c93dc
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
figures_weather/results_hres.jpg ADDED

Git LFS Details

  • SHA256: fc651643d9605fec7545734943bc5d1f19c95caa85df384f441afee3fc22f900
  • Pointer size: 131 Bytes
  • Size of remote file: 509 kB
figures_weather/results_spectra_pareto.jpg ADDED

Git LFS Details

  • SHA256: 2aa7afae00e1a55312820883978caf1f5e8fbdf2a2340e0ea4fcce839606cb89
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
hres_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e0bc244382fa1b0f09ecdbd527c601d352cc5989697149742e8983f38a25e5e
+ size 1714571315
inference.py ADDED
@@ -0,0 +1,527 @@
+ """
+ Run autoregressive global weather forecasts with the Mosaic 1.5° model.
+
+ The model predicts 6-hourly atmospheric states autoregressively, supporting
+ both deterministic (1 member) and probabilistic (N members) forecasts.
+
+ Usage:
+     # ERA5 variant (24h steps), default checkpoint and norm stats inferred from --variant
+     python inference.py --variant era5 \\
+         --zarr gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr \\
+         --init-time "2020-01-01T00:00" --steps 10 --output forecast_era5.npz
+
+     # HRES variant (6h steps)
+     python inference.py --variant hres \\
+         --zarr gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr \\
+         --init-time "2022-01-01T00:00" --steps 40 --output forecast_hres.npz
+
+ Input zarr format:
+     The zarr store must contain the following variables at 1.5° resolution
+     (240 lon × 121 lat, 6-hourly timesteps):
+     - Surface: 2m_temperature, 10m_u_component_of_wind, 10m_v_component_of_wind,
+       mean_sea_level_pressure
+     - Pressure-level (at 13 levels 50..1000 hPa): geopotential, specific_humidity,
+       temperature, u_component_of_wind, v_component_of_wind, vertical_velocity
+     - Coordinates: longitude (240,), latitude (121,), time (hours since 1959-01-01)
+
+ Output npz:
+     forecasts        float32 (members, steps, 240, 121, 82) – physical units
+     variables        list of 82 variable names
+     lead_time_hours  int32 (steps,) – multiples of step_stride*6h
+                      (era5: 24, 48, ...; hres: 6, 12, ...)
+     init_time        str – initialization timestamp
+
+ Hardware:
+     Requires a CUDA GPU. A 16 GB GPU is enough for 1 member; A100 80 GB recommended
+     for multi-member ensembles. float16 inference (~9 GB for 1 member, 40-step rollout).
+ """
+
+ import argparse
+ import math
+ import os
+ import sys
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import zarr
+
+ # ---------------------------------------------------------------------------
+ # Model imports
+ # ---------------------------------------------------------------------------
+ from config import SL_VARS, PL_VARS, ST_VARS, LEVELS
+ from dataset import NormalizationStats, WeatherMetadata
+ from mosaic import Transformer, ModelConfig, StageConfig, BottleneckConfig
+ from base import WeatherModel
+
+ DTYPE = torch.float16
+
+ # ---------------------------------------------------------------------------
+ # Model variant presets
+ # ---------------------------------------------------------------------------
+ # The two published variants share the same Mosaic architecture (stage / bottleneck
+ # sizes) but differ in training data, time cadence, history length, neighbour
+ # count, and normalisation statistics:
+ #   - `era5`: ERA5-only training, 24h steps (4 x 6h), 2 input states, k=24 neighbours
+ #   - `hres`: ERA5 pretrain + HRES finetune, 6h steps, 4 input states, k=20 neighbours
+ # ---------------------------------------------------------------------------
+
+ _STAGE_CFGS_COMMON = [
+     StageConfig(
+         nside=64, dim=768, num_heads=12,
+         block_attn_size=1024, sparse_block_size=128, sparse_block_count=24,
+         encoder_depth=4, decoder_depth=2, mlp_ratio=4.0, gqa_ratio=4,
+     ),
+     StageConfig(
+         nside=32, dim=1024, num_heads=16,
+         block_attn_size=1024, sparse_block_size=128, sparse_block_count=12,
+         encoder_depth=4, decoder_depth=2, mlp_ratio=4.0, gqa_ratio=4,
+     ),
+ ]
+
+ _BOTTLENECK_CFG_COMMON = BottleneckConfig(
+     nside=16, dim=1280, num_heads=20,
+     block_attn_size=1024, sparse_block_size=128, sparse_block_count=4,
+     depth=2, mlp_ratio=4.0, gqa_ratio=4,
+ )
+
+
+ @dataclass
+ class Preset:
+     step_stride: int        # number of native 6h timesteps per model step
+     num_history_steps: int  # number of input states fed to the model
+     k_neighbors: int        # neighbours used in cross-attention interpolation
+     default_checkpoint: str
+     default_norm_stats: str
+     stage_cfgs: list
+     bottleneck_cfg: BottleneckConfig
+
+
+ PRESETS = {
+     "era5": Preset(
+         step_stride=4, num_history_steps=2, k_neighbors=24,
+         default_checkpoint="era5_best.pt",
+         default_norm_stats="norm_stats_era5.npz",
+         stage_cfgs=_STAGE_CFGS_COMMON,
+         bottleneck_cfg=_BOTTLENECK_CFG_COMMON,
+     ),
+     "hres": Preset(
+         step_stride=1, num_history_steps=4, k_neighbors=20,
+         default_checkpoint="hres_best.pt",
+         default_norm_stats="norm_stats_hres.npz",
+         stage_cfgs=_STAGE_CFGS_COMMON,
+         bottleneck_cfg=_BOTTLENECK_CFG_COMMON,
+     ),
+ }
+
+
+ # ---------------------------------------------------------------------------
+ # Time utilities
+ # ---------------------------------------------------------------------------
+
+ def compute_day_year_progress(timestamp: pd.Timestamp):
+     """Return (day_progress, year_progress) fractions for a single timestamp."""
+     day_progress = (timestamp.hour * 3600 + timestamp.minute * 60 + timestamp.second) / 86400.0
+     days_in_year = 366 if timestamp.is_leap_year else 365
+     year_progress = (timestamp.day_of_year - 1) / days_in_year
+     return float(day_progress), float(year_progress)
+
+
+ # ---------------------------------------------------------------------------
+ # Zarr loading
+ # ---------------------------------------------------------------------------
+
+ def _load_zarr_times(store) -> pd.DatetimeIndex:
+     """Load and decode the time coordinate from the zarr store, honouring its units attr."""
+     time_raw = np.asarray(store['time'])
+     if not np.issubdtype(time_raw.dtype, np.integer):
+         return pd.to_datetime(time_raw)
+     # Integer encoding: parse 'units' attr e.g. "hours since 1959-01-01"
+     units = store['time'].attrs.get('units', 'hours since 1959-01-01')
+     try:
+         unit_word, _, origin = units.partition(' since ')
+     except Exception:
+         unit_word, origin = 'hours', '1959-01-01'
+     unit_map = {'hours': 'h', 'hour': 'h', 'days': 'D', 'day': 'D',
+                 'minutes': 'm', 'minute': 'm', 'seconds': 's', 'second': 's'}
+     unit = unit_map.get(unit_word.strip().lower(), 'h')
+     return pd.to_datetime(time_raw, unit=unit, origin=origin.strip() or '1959-01-01')
+
+
+ def load_initial_state(zarr_path: str, init_time: str, num_history_steps: int = 4, step_stride: int = 1):
+     """
+     Load `num_history_steps` timesteps ending at `init_time` from a zarr store,
+     spaced `step_stride * 6h` apart (so step_stride=4 -> 24h spacing).
+
+     Returns:
+         state: np.ndarray of shape (num_history_steps, 240, 121, 82) in physical units
+         day_year_time: tuple (day_progress, year_progress) for init_time
+         longitude: np.ndarray (240,)
+         latitude: np.ndarray (121,) South→North
+     """
+     # Open zarr (supports local paths, gs://, s3://, etc.)
+     if zarr_path.startswith('gs://'):
+         import gcsfs
+         fs = gcsfs.GCSFileSystem(token='anon')
+         store_obj = zarr.open(fs.get_mapper(zarr_path), mode='r')
+     else:
+         store_obj = zarr.open(zarr_path, mode='r')
+
+     times = _load_zarr_times(store_obj)
+     init_ts = pd.Timestamp(init_time)
+
+     # Find the index of init_time
+     idx = times.searchsorted(init_ts)
+     if idx >= len(times) or times[idx] != init_ts:
+         raise ValueError(
+             f"init_time '{init_time}' not found in zarr store. "
+             f"Available range: {times[0]} to {times[-1]}"
+         )
+
+     # history indices: [idx - (H-1)*S, idx - (H-2)*S, ..., idx]
+     history_indices = [idx - (num_history_steps - 1 - i) * step_stride for i in range(num_history_steps)]
+     if history_indices[0] < 0:
+         raise ValueError(
+             f"Not enough history: need {num_history_steps} steps spaced {step_stride*6}h apart "
+             f"before {init_time}, but data starts at {times[0]}"
+         )
+
+     # Load longitude/latitude in the canonical (lon, lat S→N) order the model expects.
+     longitude = np.asarray(store_obj['longitude'])    # (240,) 0..358.5
+     latitude_raw = np.asarray(store_obj['latitude'])  # (121,)
+     if latitude_raw[0] > latitude_raw[-1]:            # N→S in store → flip
+         latitude = latitude_raw[::-1].copy()
+         flip_lat = True
+     else:
+         latitude = latitude_raw.copy()
+         flip_lat = False
+
+     n_lon, n_lat = len(longitude), len(latitude)
+     n_vars = len(SL_VARS) + len(PL_VARS) * len(LEVELS)
+     state = np.empty((num_history_steps, n_lon, n_lat, n_vars), dtype=np.float32)
+
+     def _to_lon_lat(arr: np.ndarray, dims: list) -> np.ndarray:
+         """Normalise a (lat,lon) or (lon,lat) slice to (lon, lat S→N)."""
+         if dims[-2:] == ['latitude', 'longitude']:
+             arr = arr.T  # (lat,lon) -> (lon,lat)
+         elif dims[-2:] != ['longitude', 'latitude']:
+             raise ValueError(f"unexpected dim order: {dims}")
+         if flip_lat:
+             arr = arr[:, ::-1]
+         return np.ascontiguousarray(arr)
+
+     all_levels_in_store = list(np.asarray(store_obj['level'])) if 'level' in store_obj else None
+
+     for step_i, t_idx in enumerate(history_indices):
+         ch = 0
+         for var in SL_VARS:
+             dims = list(store_obj[var].attrs.get('_ARRAY_DIMENSIONS', ['time', 'latitude', 'longitude']))
+             arr = np.asarray(store_obj[var][t_idx])  # 2D
+             state[step_i, :, :, ch] = _to_lon_lat(arr, dims)
+             ch += 1
+
+         for var in PL_VARS:
+             dims = list(store_obj[var].attrs.get('_ARRAY_DIMENSIONS', ['time', 'level', 'latitude', 'longitude']))
+             arr_full = np.asarray(store_obj[var][t_idx])  # 3D (level, ...)
+             spatial_dims = [d for d in dims if d != 'time']  # drop time (already indexed)
+             for level in LEVELS:
+                 lev_idx = all_levels_in_store.index(level) if all_levels_in_store is not None else LEVELS.index(level)
+                 arr = arr_full[lev_idx]  # 2D
+                 # spatial_dims still includes 'level' at the front; pass just the 2D part
+                 state[step_i, :, :, ch] = _to_lon_lat(arr, spatial_dims[1:] if spatial_dims[0] == 'level' else spatial_dims)
+                 ch += 1
+
+     day_progress, year_progress = compute_day_year_progress(init_ts)
+     return state, (day_progress, year_progress), longitude, latitude
+
+
+ # ---------------------------------------------------------------------------
241
+ # Model building
242
+ # ---------------------------------------------------------------------------
243
+
244
+ def build_model(
245
+ checkpoint_path: str,
246
+ variables: list,
247
+ longitude: np.ndarray,
248
+ latitude: np.ndarray,
249
+ preset: Preset,
250
+ norm_stats_path: str = "norm_stats.npz",
251
+ static_vars_path: str = "static_vars.npz",
252
+ device: str = "cuda",
253
+ ):
254
+ """Build and return the WeatherModel with loaded checkpoint and metadata."""
255
+
256
+ # Load normalization statistics
257
+ _ns = np.load(norm_stats_path)
258
+ norm_stats = NormalizationStats(
259
+ state_mean=torch.from_numpy(_ns['state_mean'].astype(np.float32)),
260
+ state_std=torch.from_numpy(_ns['state_std'].astype(np.float32)),
261
+ residual_mean=torch.from_numpy(_ns['residual_mean'].astype(np.float32)) if 'residual_mean' in _ns else torch.zeros(len(variables)),
262
+ residual_std=torch.from_numpy(_ns['residual_std'].astype(np.float32)) if 'residual_std' in _ns else torch.ones(len(variables)),
263
+ )
264
+
265
+ # Load static variables
266
+ _sv = np.load(static_vars_path)
267
+ static_data = torch.from_numpy(_sv['data'].astype(np.float32)) # (lon, lat, 3)
268
+ lon_tensor = torch.from_numpy(longitude.astype(np.float32))
269
+ lat_tensor = torch.from_numpy(latitude.astype(np.float32))
270
+
271
+ day_year_delta = torch.tensor(
272
+ [preset.step_stride / 4.0, preset.step_stride / 365.25], dtype=torch.float32
273
+ )
274
+
275
+ metadata = WeatherMetadata(
276
+ variables=variables,
277
+ static_variables=list(ST_VARS),
278
+ longitude=lon_tensor,
279
+ latitude=lat_tensor,
280
+ static_data=static_data,
281
+ day_year_delta=day_year_delta,
282
+ norm_stats=norm_stats,
283
+ )
284
+
285
+ # Build model
286
+ model_config = ModelConfig(
287
+ dim=preset.stage_cfgs[0].dim,
288
+ num_heads=preset.stage_cfgs[0].num_heads,
289
+ variables=variables,
290
+ static_variables=list(ST_VARS),
291
+ k_neighbors=preset.k_neighbors,
292
+ qk_norm=False,
293
+ rope=True,
294
+ rope_theta=10000,
295
+ sparse_every=1,
296
+ qkv_compress_ratio=1,
297
+ num_history_steps=preset.num_history_steps,
298
+ noise_dim=32,
299
+ rmsnorm_elementwise_affine=False,
300
+ cg_stage_cfgs=preset.stage_cfgs,
301
+ bottleneck_cfg=preset.bottleneck_cfg,
302
+ )
303
+
304
+ backbone = Transformer(model_config)
305
+ model = WeatherModel(backbone, metadata).to(device).eval()
306
+
307
+ # Load checkpoint. The model registers several deterministic buffers (RoPE
308
+ # tables, HEALPix neighbour indices, static_vars) that are recomputed at
309
+ # __init__ from the metadata/config and therefore aren't expected in the
310
+ # saved checkpoint — so we load non-strictly and only warn on *unexpected*
311
+ # keys, which would indicate a real architecture mismatch.
312
+ ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
313
+ state_dict = ckpt.get('model_state_dict', ckpt)
314
+ result = model.load_state_dict(state_dict, strict=False)
315
+ if result.unexpected_keys:
316
+ raise RuntimeError(
317
+ f"Unexpected keys in checkpoint (architecture mismatch): {result.unexpected_keys[:5]}"
318
+ + (f" ... and {len(result.unexpected_keys)-5} more" if len(result.unexpected_keys) > 5 else "")
319
+ )
320
+ print(f"Loaded checkpoint from {checkpoint_path} (epoch {ckpt.get('epoch', '?')})")
321
+ print(f" {len(state_dict)} keys loaded, {len(result.missing_keys)} buffer keys re-computed from config")
322
+
323
+ return model, metadata
324
+
325
+
# ---------------------------------------------------------------------------
# Autoregressive rollout (direct state prediction)
# ---------------------------------------------------------------------------

@torch.no_grad()
def unroll_direct(
    model: WeatherModel,
    initial_unnorm_state: torch.Tensor,
    day_year_time: torch.Tensor,
    day_year_delta: torch.Tensor,
    norm_stats: NormalizationStats,
    num_unroll_steps: int,
    num_ensemble_members: int,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Autoregressively forecast using direct state prediction (learn_direct=True).

    Args:
        model: WeatherModel
        initial_unnorm_state: (B, num_history_steps, lon, lat, channels) in physical units
        day_year_time: (B, 2) day/year progress fractions at init_time
        day_year_delta: (2,) increment per step
        norm_stats: NormalizationStats on the target device
        num_unroll_steps: number of autoregressive model steps to forecast
            (each step spans step_stride * 6h)
        num_ensemble_members: number of ensemble members (noise samples on step 0)
        dtype: computation dtype (float16 recommended)

    Returns:
        trajectory: (B, members, num_history_steps + num_unroll_steps, lon, lat, channels)
    """
    batch_size = initial_unnorm_state.shape[0]
    num_history_steps = initial_unnorm_state.shape[1]
    device = initial_unnorm_state.device

    trajectory = torch.empty(
        (batch_size, num_ensemble_members, num_unroll_steps + num_history_steps)
        + initial_unnorm_state.shape[2:],
        dtype=initial_unnorm_state.dtype,
        device=device,
    )

    # Expand initial state to ensemble dimension
    current_unnorm_state = initial_unnorm_state.unsqueeze(1)  # (B, 1, H, lon, lat, C)
    current_day_year_time = day_year_time.unsqueeze(1)  # (B, 1, 2)

    trajectory[:, :, :num_history_steps] = current_unnorm_state

    for t in range(num_unroll_steps):
        # Expand the ensemble only on the first step; afterwards each member
        # evolves from its own history window.
        num_ens_step = num_ensemble_members if t == 0 else 1

        current_norm_state = (current_unnorm_state - norm_stats.state_mean) / norm_stats.state_std
        with torch.amp.autocast('cuda', dtype=dtype):
            norm_next_state = model(current_norm_state, current_day_year_time, num_ens_step)

        next_unnorm_state = norm_next_state * norm_stats.state_std + norm_stats.state_mean
        current_day_year_time = current_day_year_time + day_year_delta.unsqueeze(0).unsqueeze(0).expand(
            batch_size, num_ens_step, -1
        )

        trajectory[:, :, t + num_history_steps] = next_unnorm_state
        current_unnorm_state = trajectory[:, :, t + 1 : t + 1 + num_history_steps]

    return trajectory

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Mosaic 1.5° Weather Forecast Inference")
    parser.add_argument("--variant", type=str, required=True, choices=sorted(PRESETS.keys()),
                        help="Model variant: 'era5' (ERA5-only, 24h steps) or 'hres' (ERA5+HRES finetune, 6h steps)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint (.pt). Default: preset's default_checkpoint")
    parser.add_argument("--zarr", type=str, required=True,
                        help="Path or GCS URI to zarr store with ERA5/HRES data at 1.5°")
    parser.add_argument("--init-time", type=str, required=True,
                        help="Initialization time (ISO 8601), e.g. '2020-01-01T00:00'")
    parser.add_argument("--steps", type=int, default=10,
                        help="Number of forecast steps (each step = step_stride*6h; e.g. era5 step=24h, hres step=6h)")
    parser.add_argument("--members", type=int, default=1,
                        help="Number of ensemble members (default: 1)")
    parser.add_argument("--output", type=str, default="forecast.npz",
                        help="Output file path (default: forecast.npz)")
    parser.add_argument("--norm-stats", type=str, default=None,
                        help="Path to norm_stats .npz. Default: preset's default_norm_stats")
    parser.add_argument("--static-vars", type=str, default="static_vars.npz",
                        help="Path to static_vars.npz (default: static_vars.npz in current dir)")
    parser.add_argument("--k-neighbors", type=int, default=None,
                        help="Override preset's k_neighbors (advanced, for ablation only)")
    parser.add_argument("--no-compile", action="store_true",
                        help="Disable torch.compile (slower but easier to debug)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device (default: cuda)")
    args = parser.parse_args()

    preset = PRESETS[args.variant]
    if args.k_neighbors is not None and args.k_neighbors != preset.k_neighbors:
        from dataclasses import replace
        preset = replace(preset, k_neighbors=args.k_neighbors)
    checkpoint_path = args.checkpoint or preset.default_checkpoint
    norm_stats_path = args.norm_stats or preset.default_norm_stats
    print(f"Variant: {args.variant} "
          f"(step_stride={preset.step_stride}, num_history_steps={preset.num_history_steps}, "
          f"k_neighbors={preset.k_neighbors})")

    device = args.device
    torch.set_float32_matmul_precision('high')

    # Build variable list: 4 surface + 6*13 pressure-level = 82 channels
    variables = list(SL_VARS)
    for var in PL_VARS:
        for level in LEVELS:
            variables.append(f"{var}_{level}")
    print(f"Variables: {len(variables)} channels")

    # Load initial state from zarr
    print(f"Loading initial state from zarr: {args.zarr}")
    print(f"  Init time: {args.init_time} (history: {preset.num_history_steps} x {preset.step_stride*6}h steps)")
    initial_state_np, (day_prog, year_prog), longitude, latitude = load_initial_state(
        args.zarr, args.init_time,
        num_history_steps=preset.num_history_steps,
        step_stride=preset.step_stride,
    )
    print(f"  State shape: {initial_state_np.shape} (steps, lon, lat, channels)")

    # Build model and load checkpoint
    print(f"\nBuilding model and loading checkpoint: {checkpoint_path}")
    model, metadata = build_model(
        checkpoint_path=checkpoint_path,
        variables=variables,
        longitude=longitude,
        latitude=latitude,
        preset=preset,
        norm_stats_path=norm_stats_path,
        static_vars_path=args.static_vars,
        device=device,
    )
    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"  Parameters: {num_params:.1f}M")

    # Optionally compile
    if not args.no_compile:
        print("Compiling model with torch.compile (reduce-overhead)...")
        unroll_fn = torch.compile(unroll_direct, mode='reduce-overhead')
    else:
        unroll_fn = unroll_direct

    # Prepare tensors
    initial_state = torch.from_numpy(initial_state_np).unsqueeze(0).to(device)  # (1, H, lon, lat, C)
    day_year_time = torch.tensor([[day_prog, year_prog]], dtype=torch.float32, device=device)  # (1, 2)
    norm_stats_d = metadata.norm_stats.to(device)
    day_year_delta_d = metadata.day_year_delta.to(device)

    # Run forecast
    total_hours = args.steps * preset.step_stride * 6
    print(f"\nRunning {args.steps}-step forecast ({total_hours}h) with {args.members} member(s)...")
    if device == 'cuda':
        torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        trajectory = unroll_fn(
            model=model,
            initial_unnorm_state=initial_state,
            day_year_time=day_year_time,
            day_year_delta=day_year_delta_d,
            norm_stats=norm_stats_d,
            num_unroll_steps=args.steps,
            num_ensemble_members=args.members,
            dtype=DTYPE,
        )

    if device == 'cuda':
        torch.cuda.synchronize()
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f"  Peak GPU memory: {peak_gb:.1f} GB")

    # Extract forecast steps (skip history)
    forecasts = trajectory[0, :, preset.num_history_steps:].cpu().numpy()  # (members, steps, lon, lat, C)
    print(f"  Forecast shape: {forecasts.shape}")

    # Save output
    lead_time_hours = np.arange(1, args.steps + 1) * 6 * preset.step_stride
    np.savez(
        args.output,
        forecasts=forecasts,
        variables=np.array(variables),
        lead_time_hours=lead_time_hours,
        init_time=np.str_(args.init_time),
        longitude=longitude,
        latitude=latitude,
    )
    print(f"\nSaved forecast to: {args.output}")
    print(f"  Shape: forecasts {forecasts.shape} (members, steps, lon=240, lat=121, channels=82)")
    print(f"  Lead times: {lead_time_hours[0]}h to {lead_time_hours[-1]}h")


if __name__ == "__main__":
    main()
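The bookkeeping in `unroll_direct` above (a trajectory buffer that holds history plus forecast steps, with each new step re-reading the last `num_history_steps` entries as the next model input) can be illustrated with a minimal pure-Python sketch. Everything here (`toy_unroll`, `step_fn`) is a hypothetical stand-in for illustration only, not the real model or API:

```python
# Miniature of the sliding-history rollout in unroll_direct. The "model" is a
# toy stand-in: the mean of the history window plus a fixed increment.

def toy_unroll(initial_history, num_unroll_steps, step_fn):
    """initial_history: list of floats, len == num_history_steps."""
    num_history_steps = len(initial_history)
    trajectory = list(initial_history)  # history occupies the first slots
    for t in range(num_unroll_steps):
        # Re-read the most recent num_history_steps entries as the next input,
        # exactly as trajectory[:, :, t+1 : t+1+H] does in unroll_direct.
        window = trajectory[t : t + num_history_steps]
        trajectory.append(step_fn(window))
    return trajectory

def step_fn(window):
    return sum(window) / len(window) + 1.0

print(toy_unroll([0.0, 2.0], num_unroll_steps=3, step_fn=step_fn))
# → [0.0, 2.0, 2.0, 3.0, 3.5]
```

The key point carried over from the real code: the ensemble dimension is only expanded on the first step, after which each member's window slides forward independently.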
mosaic.py ADDED
@@ -0,0 +1,337 @@
"""
Mosaic: U-Net transformer with block-sparse attention for weather forecasting.

Architecture:
- Cross-attention interpolation between lon/lat and HEALPix grids
- Block-sparse attention (local block + compressed + top-k selection branches)
  arranged in a U-Net encoder–bottleneck–decoder
- Probabilistic training with noise injection
"""

import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
from dataclasses import dataclass
from torch.nn import RMSNorm

from utils import get_healpix_grid, rad_to_xyz
from primitives import (
    MosaicBlock as _MosaicBlock,
    CrossAttentionInterpolate,
    NoiseGenerator,
    HEALPixDownsample,
    HEALPixUpsample,
)


@dataclass
class StageConfig:
    """Configuration for a U-Net encoder/decoder stage."""
    nside: int
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    encoder_depth: int
    decoder_depth: int
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class BottleneckConfig:
    """Configuration for the U-Net bottleneck stage."""
    nside: int
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    depth: int
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class ModelConfig:
    """Configuration for the Mosaic model."""
    dim: int
    num_heads: int
    k_neighbors: int
    qk_norm: bool
    rope: bool
    rope_theta: int
    sparse_every: int
    variables: list[str]
    static_variables: list[str]
    qkv_compress_ratio: int
    cg_stage_cfgs: list[StageConfig]
    bottleneck_cfg: BottleneckConfig
    num_history_steps: int = 1
    noise_dim: int = 32
    ortho_init: bool = False
    rmsnorm_elementwise_affine: bool = True
    no_compression: bool = False


@dataclass
class _MergedStageConfig:
    """Merges ModelConfig and StageConfig for compatibility with MosaicBlock."""
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    gqa_ratio: int
    qkv_compress_ratio: int
    rope: bool
    rope_theta: int
    mlp_ratio: float
    noise_dim: int
    rmsnorm_elementwise_affine: bool


def _merge_configs(config: ModelConfig, stage_cfg) -> _MergedStageConfig:
    return _MergedStageConfig(
        dim=stage_cfg.dim,
        num_heads=stage_cfg.num_heads,
        block_attn_size=stage_cfg.block_attn_size,
        sparse_block_size=stage_cfg.sparse_block_size,
        sparse_block_count=stage_cfg.sparse_block_count,
        gqa_ratio=stage_cfg.gqa_ratio,
        qkv_compress_ratio=config.qkv_compress_ratio,
        rope=config.rope,
        rope_theta=config.rope_theta,
        mlp_ratio=stage_cfg.mlp_ratio,
        noise_dim=config.noise_dim,
        rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
    )


def _make_mosaic_block(config: ModelConfig, stage_cfg, block_attn_only: bool) -> _MosaicBlock:
    return _MosaicBlock(_merge_configs(config, stage_cfg), block_attn_only, no_compression=config.no_compression)


class UNetStage(nn.Module):
    def __init__(self, config, stage_cfg, depth):
        super().__init__()
        self.nside = stage_cfg.nside
        self.blocks = nn.ModuleList([
            _make_mosaic_block(
                config=config,
                stage_cfg=stage_cfg,
                # Sparse branch every `sparse_every`-th block; the rest use
                # local block attention only.
                block_attn_only=(config.sparse_every <= 0) or not (i % config.sparse_every == 0),
            )
            for i in range(depth)
        ])

    def forward(self, x, z=None):
        for block in self.blocks:
            x = block(x, z)
        return x


class Transformer(nn.Module):
    """U-Net style Transformer for weather forecasting on HEALPix grids."""

    space_dim = 3
    time_dim = 4

    def __init__(self, config: ModelConfig, seed: int = 42):
        super().__init__()

        self.config = config
        self.nside = config.cg_stage_cfgs[0].nside
        self.noise_dim = config.noise_dim

        initial_dim = config.dim
        feature_dim = (len(config.variables) * config.num_history_steps
                       + len(config.static_variables) + self.space_dim + self.time_dim)

        if self.noise_dim > 0:
            self.noise_generator = NoiseGenerator(self.noise_dim, seed)

        self.preprocess = nn.Sequential(
            nn.Linear(feature_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
            nn.SiLU(),
            nn.Linear(initial_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
        )

        self.interp_to_hp = CrossAttentionInterpolate(config)
        self.interp_to_ll = CrossAttentionInterpolate(config)

        self.encoder_stages = nn.ModuleList()
        self.downsample_layers = nn.ModuleList()

        all_stages = [*config.cg_stage_cfgs, config.bottleneck_cfg]

        for i in range(len(config.cg_stage_cfgs)):
            current_stage = all_stages[i]
            next_stage = all_stages[i + 1]

            self.encoder_stages.append(UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.encoder_depth))
            self.downsample_layers.append(
                HEALPixDownsample(
                    in_dim=current_stage.dim,
                    out_dim=next_stage.dim,
                    nside_before=current_stage.nside,
                    nside_after=next_stage.nside,
                    rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
                )
            )

        self.bottleneck = UNetStage(config=config, stage_cfg=config.bottleneck_cfg, depth=config.bottleneck_cfg.depth)

        self.decoder_stages = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()

        for i in reversed(range(len(config.cg_stage_cfgs))):
            prev_stage = all_stages[i + 1]
            current_stage = all_stages[i]

            self.upsample_layers.append(
                HEALPixUpsample(
                    in_dim=prev_stage.dim,
                    out_dim=current_stage.dim,
                    nside_before=prev_stage.nside,
                    nside_after=current_stage.nside,
                    rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
                )
            )
            self.decoder_stages.append(UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.decoder_depth))

        self.norm_before_interp_ll = RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine)

        self.postprocess = nn.Sequential(
            RMSNorm(initial_dim, elementwise_affine=config.rmsnorm_elementwise_affine),
            nn.Linear(initial_dim, initial_dim, bias=False),
            nn.SiLU(),
            nn.Linear(initial_dim, len(config.variables), bias=False),
        )

        self.apply(self._initialize_weights)
        self._zero_init_residual_layers()
        self.initialize_rope()

    def _initialize_weights(self, module):
        if module is self:
            return
        ortho_init = self.config.ortho_init

        if isinstance(module, nn.Linear):
            fan_in, fan_out = module.weight.size(1), module.weight.size(0)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            if ortho_init:
                nn.init.orthogonal_(module.weight)
                module.weight.data.mul_(std)
            else:
                nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _zero_init_residual_layers(self):
        ortho_init = self.config.ortho_init

        for stage in [*self.encoder_stages, self.bottleneck, *self.decoder_stages]:
            for block in stage.blocks:
                if ortho_init:
                    nn.init.orthogonal_(block.attention.to_o.weight)
                    block.attention.to_o.weight.data.mul_(0.01)
                    nn.init.orthogonal_(block.ffn.w2.weight)
                    block.ffn.w2.weight.data.mul_(0.01)
                else:
                    nn.init.normal_(block.attention.to_o.weight, mean=0.0, std=0.01)
                    nn.init.normal_(block.ffn.w2.weight, mean=0.0, std=0.01)

                if self.noise_dim > 0:
                    nn.init.normal_(block.ffn.noise_bias.weight, mean=0.0, std=0.01)

        for upsample in self.upsample_layers:
            if ortho_init:
                nn.init.orthogonal_(upsample.proj_x.weight)
                upsample.proj_x.weight.data.mul_(0.01)
                nn.init.orthogonal_(upsample.proj_pos.weight)
                upsample.proj_pos.weight.data.mul_(0.01)
            else:
                nn.init.normal_(upsample.proj_x.weight, mean=0.0, std=0.01)
                nn.init.normal_(upsample.proj_pos.weight, mean=0.0, std=0.01)

        if self.noise_dim > 0:
            nn.init.normal_(self.noise_generator.to_noise.weight, mean=0.0, std=0.01)

    def initialize_rope(self):
        if not self.config.rope:
            return
        for stage in [*self.encoder_stages, self.bottleneck, *self.decoder_stages]:
            hp_grid = get_healpix_grid(stage.nside)
            for block in stage.blocks:
                if block.attention.q_rope is not None:
                    block.attention.q_rope.initialize_rope(hp_grid)
                    block.attention.k_rope.initialize_rope(hp_grid)

    def initialize_interpolation(self, longitude: torch.Tensor, latitude: torch.Tensor):
        ll_grid_rad = torch.deg2rad(torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1).reshape(-1, 2))
        hp_grid_rad = torch.deg2rad(get_healpix_grid(self.nside)).to(longitude.device)
        self.interp_to_hp.initialize_interpolation_scheme(ll_grid_rad, hp_grid_rad)
        self.interp_to_ll.initialize_interpolation_scheme(hp_grid_rad, ll_grid_rad)

    @torch.no_grad()
    def initialize_static_vars(self, static_vars: torch.Tensor, longitude: torch.Tensor, latitude: torch.Tensor):
        ll_grid_rad = torch.deg2rad(torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1))
        ll_grid_xyz = rad_to_xyz(ll_grid_rad)
        static_vars = torch.concat([static_vars, ll_grid_xyz], dim=-1)
        static_vars_mean = static_vars.mean(dim=(0, 1), keepdim=True)
        static_vars_std = static_vars.std(dim=(0, 1), keepdim=True) + 1e-6
        static_vars_norm = (static_vars - static_vars_mean) / static_vars_std
        static_vars = rearrange(static_vars_norm, 'lon lat c -> (lon lat) 1 c').contiguous()
        self.register_buffer('static_vars', static_vars, persistent=True)

    @torch.no_grad()
    def time_embedding(self, day_year_time: torch.Tensor):
        day = day_year_time[:, 0:1]
        year = day_year_time[:, 1:2]
        day_sin = torch.sin(2 * math.pi * day)
        day_cos = torch.cos(2 * math.pi * day)
        year_sin = torch.sin(2 * math.pi * year)
        year_cos = torch.cos(2 * math.pi * year)
        return torch.cat([day_sin, day_cos, year_sin, year_cos], dim=-1)

    def forward(self, x: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int):
        b, n, _, lon, lat, _ = x.shape
        batch_size = b * num_noise_samples * n

        if self.noise_dim > 0:
            z = self.noise_generator(batch_size, x.device, x.dtype)
        else:
            z = None

        x = repeat(x, 'b n t lon lat c -> (lon lat) (b s n) (t c)', s=num_noise_samples)
        day_year_time = repeat(day_year_time, 'b n d -> (b s n) d', s=num_noise_samples)

        x = torch.cat([
            x,
            self.static_vars.expand(-1, batch_size, -1),
            self.time_embedding(day_year_time).unsqueeze(0).expand(x.shape[0], -1, -1)
        ], dim=-1)

        x = self.preprocess(x)
        x = self.interp_to_hp(x)

        skip_connections = []
        for encoder_stage, downsample in zip(self.encoder_stages, self.downsample_layers):
            x = encoder_stage(x, z)
            skip_connections.append(x)
            x = downsample(x)

        x = self.bottleneck(x, z)

        for decoder_stage, upsample, skip in zip(self.decoder_stages, self.upsample_layers, reversed(skip_connections)):
            x = upsample(x, skip)
            x = decoder_stage(x, z)

        x = self.norm_before_interp_ll(x)
        x = self.interp_to_ll(x)
        x = self.postprocess(x)

        # Unflatten with the same (b s n) ordering used when the noise samples
        # were folded in above, so members line up with their noise draws.
        x = rearrange(x, '(lon lat) (b s n) c -> b (s n) lon lat c', lon=lon, lat=lat, b=b, s=num_noise_samples)
        return x
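The cyclic time conditioning computed by `Transformer.time_embedding` above maps the day/year progress fractions to (sin, cos) pairs, so that times of day (and dates) wrap smoothly: progress 0.0 and 1.0 yield the same features. A minimal pure-Python check of that property, with `time_embedding` here being a standalone re-statement for illustration, not the model method itself:

```python
import math

def time_embedding(day_progress, year_progress):
    """(sin, cos) pairs of day-of-day and day-of-year progress in [0, 1]."""
    return [
        math.sin(2 * math.pi * day_progress),
        math.cos(2 * math.pi * day_progress),
        math.sin(2 * math.pi * year_progress),
        math.cos(2 * math.pi * year_progress),
    ]

# Quarter of the way through the day, half-way through the year:
print(time_embedding(0.25, 0.5))
```

Because the features are periodic with period 1, the `day_year_delta` increments accumulated during rollout never need to be wrapped back into [0, 1) before embedding.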
ops.py ADDED
@@ -0,0 +1,566 @@
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+
7
+ def get_autotuning_configs(q_tile_sizes: list):
8
+ """Generate autotuning configurations optimized for H100."""
9
+ warps = [4, 8]
10
+ stages = [2, 3]
11
+
12
+ return [
13
+ triton.Config({'q_tile_size': t}, num_warps=w, num_stages=s)
14
+ for t in q_tile_sizes
15
+ for w in warps
16
+ for s in stages
17
+ ]
18
+
19
+
20
+ @triton.autotune(
21
+ configs=get_autotuning_configs([64, 128]),
22
+ key=['seq_len', 'feature_dim'],
23
+ )
24
+ @triton.jit
25
+ def mosaic_attn_fwd_kernel(
26
+ q_ptr, k_ptr, v_ptr, output_ptr, lse_ptr, block_indices_ptr,
27
+ softmax_scale: tl.constexpr,
28
+ seq_len: tl.constexpr,
29
+ num_kv_heads: tl.constexpr,
30
+ num_q_heads: tl.constexpr,
31
+ q_heads_per_kv_head: tl.constexpr,
32
+ feature_dim: tl.constexpr,
33
+ kv_block_size: tl.constexpr,
34
+ num_kv_blocks_per_q_block: tl.constexpr,
35
+ q_tile_size: tl.constexpr,
36
+ ):
37
+ """
38
+ Sparse attention forward kernel:
39
+ for each query tile (i.e. block chunk), for each query head, attend to a subset of key/value blocks.
40
+ """
41
+ LOG2_E: tl.constexpr = 1.44269504089
42
+
43
+ q_tile_id = tl.program_id(0)
44
+ q_head_id = tl.program_id(1)
45
+ batch_kv_head_id = tl.program_id(2)
46
+
47
+ batch_idx = batch_kv_head_id // num_kv_heads
48
+ kv_head_idx = batch_kv_head_id % num_kv_heads
49
+ q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id
50
+
51
+ batch_offset = batch_idx * seq_len
52
+ q_tile_start = q_tile_id * q_tile_size
53
+ num_blocks_in_seq = seq_len // kv_block_size
54
+ tiles_per_block = kv_block_size // q_tile_size
55
+ q_block_id = q_tile_id // tiles_per_block
56
+
57
+ block_indices_offset = (
58
+ batch_idx * num_blocks_in_seq * num_kv_heads * num_kv_blocks_per_q_block +
59
+ q_block_id * num_kv_heads * num_kv_blocks_per_q_block +
60
+ kv_head_idx * num_kv_blocks_per_q_block
61
+ )
62
+
63
+ q_base_ptr = q_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim
64
+ k_base_ptr = k_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim
65
+ v_base_ptr = v_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim
66
+
67
+ q_tile_ptr = tl.make_block_ptr(
68
+ base=q_base_ptr,
69
+ shape=(seq_len, feature_dim),
70
+ strides=(num_q_heads * feature_dim, 1),
71
+ offsets=(q_tile_start, 0),
72
+ block_shape=(q_tile_size, feature_dim),
73
+ order=(1, 0)
74
+ )
75
+
76
+ output_tile_ptr = tl.make_block_ptr(
77
+ base=output_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim,
78
+ shape=(seq_len, feature_dim),
79
+ strides=(num_q_heads * feature_dim, 1),
80
+ offsets=(q_tile_start, 0),
81
+ block_shape=(q_tile_size, feature_dim),
82
+ order=(1, 0)
83
+ )
84
+
85
+ lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads + tl.arange(0, q_tile_size) * num_q_heads + q_head_idx
86
+
87
+ output_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32)
88
+ max_scores = tl.full([q_tile_size], float('-inf'), dtype=tl.float32)
89
+ sum_exp_scores = tl.zeros([q_tile_size], dtype=tl.float32)
90
+
91
+ q_tile = tl.load(q_tile_ptr)
92
+ q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16)
93
+
94
+ for i in range(num_kv_blocks_per_q_block):
95
+ kv_block_start = kv_block_size * tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32)
96
+
97
+ k_block_ptr = tl.make_block_ptr(
98
+ base=k_base_ptr,
99
+ shape=(feature_dim, seq_len),
100
+ strides=(1, num_kv_heads * feature_dim),
101
+ offsets=(0, kv_block_start),
102
+ block_shape=(feature_dim, kv_block_size),
103
+ order=(1, 0)
104
+ )
105
+
106
+ v_block_ptr = tl.make_block_ptr(
107
+ base=v_base_ptr,
108
+ shape=(seq_len, feature_dim),
109
+ strides=(num_kv_heads * feature_dim, 1),
110
+ offsets=(kv_block_start, 0),
111
+ block_shape=(kv_block_size, feature_dim),
112
+ order=(1, 0)
113
+ )
114
+
115
+ k_block = tl.load(k_block_ptr).to(tl.bfloat16)
116
+ v_block = tl.load(v_block_ptr).to(tl.bfloat16)
117
+
118
+ attention_scores = tl.dot(q_tile, k_block)
119
+
120
+ new_max = tl.max(attention_scores, axis=1)
121
+ old_max = max_scores
122
+ max_scores = tl.maximum(max_scores, new_max)
123
+ rescale = tl.exp2(old_max - max_scores)
124
+ attention_probs = tl.exp2(attention_scores - max_scores[:, None])
125
+ sum_exp_scores = sum_exp_scores * rescale + tl.sum(attention_probs, axis=1)
126
+
127
+ output_accum = output_accum * rescale[:, None]
128
+ output_accum += tl.dot(attention_probs.to(tl.bfloat16), v_block)
129
+
130
+ final_output = output_accum / sum_exp_scores[:, None]
131
+ log_sum_exp = (max_scores + tl.log2(sum_exp_scores))
132
+
133
+ tl.store(output_tile_ptr, final_output.to(q_ptr.dtype.element_ty))
134
+ tl.store(lse_base_ptr, log_sum_exp.to(tl.float32))
135
+
136
+
137
+ def mosaic_attn_fwd(
138
+ q: torch.Tensor,
139
+ k: torch.Tensor,
140
+ v: torch.Tensor,
141
+ block_indices: torch.LongTensor,
142
+ block_size: int,
143
+ softmax_scale: float,
144
+ ):
145
+ batch_size, seq_len, num_kv_heads, feature_dim = k.shape
146
+ num_q_heads = q.shape[2]
147
+ num_kv_blocks_per_q_block = block_indices.shape[-1]
148
+ q_heads_per_kv_head = num_q_heads // num_kv_heads
149
+
150
+ output = torch.empty(batch_size, seq_len, num_q_heads, feature_dim, dtype=v.dtype, device=q.device)
151
+ lse = torch.empty(batch_size, seq_len, num_q_heads, dtype=torch.float32, device=q.device)
152
+
153
+ grid = lambda META: (
154
+ triton.cdiv(seq_len, META['q_tile_size']),
155
+ q_heads_per_kv_head,
156
+ batch_size * num_kv_heads
157
+ )
158
+
159
+ mosaic_attn_fwd_kernel[grid](
160
+ q_ptr = q,
161
+ k_ptr = k,
162
+ v_ptr = v,
163
+ output_ptr = output,
164
+ lse_ptr = lse,
165
+ block_indices_ptr = block_indices,
166
+ softmax_scale = softmax_scale,
167
+ seq_len = seq_len,
168
+ num_kv_heads = num_kv_heads,
169
+ num_q_heads = num_q_heads,
170
+ q_heads_per_kv_head = q_heads_per_kv_head,
171
+ feature_dim = feature_dim,
172
+ kv_block_size = block_size,
173
+ num_kv_blocks_per_q_block = num_kv_blocks_per_q_block,
174
+ )
175
+
176
+ return output, lse
177
+
178
+
179
+@triton.autotune(
+    configs=get_autotuning_configs([64, 128]),
+    key=['seq_len', 'feature_dim'],
+)
+@triton.jit
+def mosaic_attn_bwd_q_kernel(
+    q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr, grad_o_ptr, grad_q_ptr, block_indices_ptr,
+    softmax_scale: tl.constexpr,
+    seq_len: tl.constexpr,
+    num_kv_heads: tl.constexpr,
+    num_q_heads: tl.constexpr,
+    q_heads_per_kv_head: tl.constexpr,
+    feature_dim: tl.constexpr,
+    kv_block_size: tl.constexpr,
+    num_kv_blocks_per_q_block: tl.constexpr,
+    q_tile_size: tl.constexpr,
+):
+    LOG2_E: tl.constexpr = 1.44269504089
+    LN_2: tl.constexpr = 0.69314718056
+
+    q_tile_id = tl.program_id(0)
+    q_head_id = tl.program_id(1)
+    batch_kv_head_id = tl.program_id(2)
+
+    batch_idx = batch_kv_head_id // num_kv_heads
+    kv_head_idx = batch_kv_head_id % num_kv_heads
+    q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id
+
+    batch_offset = batch_idx * seq_len
+    q_tile_start = q_tile_id * q_tile_size
+    tiles_per_block = kv_block_size // q_tile_size
+    q_block_id = q_tile_id // tiles_per_block
+    num_q_blocks = seq_len // kv_block_size
+
+    block_indices_offset = (
+        batch_idx * num_q_blocks * num_kv_heads * num_kv_blocks_per_q_block +
+        q_block_id * num_kv_heads * num_kv_blocks_per_q_block +
+        kv_head_idx * num_kv_blocks_per_q_block
+    )
+
+    q_offsets = (
+        tl.arange(0, q_tile_size)[:, None] * num_q_heads * feature_dim +
+        q_head_idx * feature_dim +
+        tl.arange(0, feature_dim)[None, :]
+    )
+
+    lse_offsets = tl.arange(0, q_tile_size) * num_q_heads + q_head_idx
+
+    q_base_ptr = q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim
+    grad_o_base_ptr = grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim
+    delta_base_ptr = delta_ptr + (batch_offset + q_tile_start) * num_q_heads
+    lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads
+    grad_q_base_ptr = grad_q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim
+
+    grad_q_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32)
+
+    q_tile = tl.load(q_base_ptr + q_offsets)
+    q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16)
+
+    grad_o_tile = tl.load(grad_o_base_ptr + q_offsets).to(tl.bfloat16)
+    delta_vals = tl.load(delta_base_ptr + lse_offsets)
+    lse_vals = tl.load(lse_base_ptr + lse_offsets).to(tl.float32)
+
+    for i in range(num_kv_blocks_per_q_block):
+        kv_block_idx = tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32)
+
+        k_block_ptr = tl.make_block_ptr(
+            base=k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+            shape=(feature_dim, seq_len),
+            strides=(1, num_kv_heads * feature_dim),
+            offsets=(0, kv_block_idx * kv_block_size),
+            block_shape=(feature_dim, kv_block_size),
+            order=(0, 1)
+        )
+
+        v_block_ptr = tl.make_block_ptr(
+            base=v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+            shape=(feature_dim, seq_len),
+            strides=(1, num_kv_heads * feature_dim),
+            offsets=(0, kv_block_idx * kv_block_size),
+            block_shape=(feature_dim, kv_block_size),
+            order=(0, 1)
+        )
+
+        k_block = tl.load(k_block_ptr).to(tl.bfloat16)
+        v_block = tl.load(v_block_ptr).to(tl.bfloat16)
+
+        attention_scores = tl.dot(q_tile, k_block)
+        attention_probs = tl.exp2(attention_scores - lse_vals[:, None]) * LN_2
+
+        grad_times_v = tl.dot(grad_o_tile, v_block)
+        grad_scores = attention_probs * (grad_times_v - delta_vals[:, None])
+        grad_q_accum += tl.dot(grad_scores.to(tl.bfloat16), tl.trans(k_block.to(tl.bfloat16)))
+
+    grad_q_accum = grad_q_accum * softmax_scale * LOG2_E
+    tl.store(grad_q_base_ptr + q_offsets, grad_q_accum.to(q_ptr.dtype.element_ty))
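A note on the `LOG2_E` / `LN_2` constants above: the kernels work in base 2 (`tl.exp2` maps to fast hardware instructions), using the identity exp(x) = 2^(x·log2 e), and recompute softmax probabilities from the saved base-2 log-sum-exp instead of re-reducing over keys. A stdlib-only sketch of that trick (all names here are illustrative, not from the repo):

```python
import math

LOG2_E = 1.44269504089  # log2(e), as hard-coded in the kernels
LN_2 = 0.69314718056    # ln(2); note LOG2_E * LN_2 == 1, which is why the
                        # two factors cancel between probs and the final grad

def softmax_probs_via_exp2(scores, lse2):
    # Recompute softmax probabilities the way the backward kernels do:
    # p_i = 2^(s_i * log2(e) - lse2), with lse2 = log2(sum_j exp(s_j))
    # saved by the forward pass.
    return [2.0 ** (s * LOG2_E - lse2) for s in scores]

scores = [0.5, -1.2, 2.0]
lse2 = math.log2(sum(math.exp(s) for s in scores))  # base-2 log-sum-exp
probs = softmax_probs_via_exp2(scores, lse2)
```

This recovers exactly the exp-based softmax without storing the full attention matrix, which is what lets the backward kernels rematerialize probabilities block by block.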
+
+
+@torch.compile
+@torch.no_grad()
+def mosaic_block_mask(
+    block_indices: torch.LongTensor,
+):
+    batch_size, num_blocks, num_heads, _ = block_indices.shape
+
+    block_mask = torch.zeros(
+        batch_size, num_blocks, num_heads, num_blocks,
+        dtype=torch.bool, device=block_indices.device
+    )
+
+    batch_idx = torch.arange(batch_size, device=block_indices.device)[:, None, None, None]
+    q_block_idx = torch.arange(num_blocks, device=block_indices.device)[None, :, None, None]
+    head_idx = torch.arange(num_heads, device=block_indices.device)[None, None, :, None]
+
+    block_mask[batch_idx, q_block_idx, head_idx, block_indices] = True
+
+    block_mask_transposed = block_mask.permute(0, 2, 3, 1).contiguous()
+
+    return block_mask_transposed
+
+
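`mosaic_block_mask` scatters each query block's selected KV indices into a dense boolean mask, then transposes it to `(batch, head, kv_block, q_block)` so the dK/dV kernel can scan query blocks for a fixed KV block. A pure-Python sketch of the same scatter with toy sizes (function and variable names are hypothetical):

```python
def block_mask_from_indices(block_indices, num_blocks):
    # block_indices[b][qb][h] lists the kv blocks selected for that
    # (batch, query-block, head). Returns mask[b][h][kv][qb] — already in
    # the transposed layout the dK/dV kernel iterates over.
    B = len(block_indices)
    H = len(block_indices[0][0])
    mask = [[[[False] * num_blocks for _ in range(num_blocks)]
             for _ in range(H)] for _ in range(B)]
    for b in range(B):
        for qb in range(num_blocks):
            for h in range(H):
                for kv in block_indices[b][qb][h]:
                    mask[b][h][kv][qb] = True  # duplicates collapse, as in scatter-assignment
    return mask

# toy example: 1 batch, 2 query blocks, 1 head, 2 selected kv blocks each
indices = [[[[0, 1]], [[1, 1]]]]
mask = block_mask_from_indices(indices, num_blocks=2)
```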
+@triton.autotune(
+    configs=get_autotuning_configs([16, 32]),
+    key=['seq_len', 'feature_dim'],
+)
+@triton.jit
+def mosaic_attn_bwd_kv_kernel(
+    q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr,
+    grad_o_ptr, grad_k_ptr, grad_v_ptr,
+    block_mask_ptr,
+    softmax_scale: tl.constexpr,
+    seq_len: tl.constexpr,
+    num_kv_heads: tl.constexpr,
+    num_q_heads: tl.constexpr,
+    q_heads_per_kv_head: tl.constexpr,
+    feature_dim: tl.constexpr,
+    kv_block_size: tl.constexpr,
+    q_tile_size: tl.constexpr,
+):
+    LOG2_E: tl.constexpr = 1.44269504089
+    LN_2: tl.constexpr = 0.69314718056
+
+    kv_block_id = tl.program_id(0)
+    batch_kv_head_id = tl.program_id(1)
+
+    batch_idx = batch_kv_head_id // num_kv_heads
+    kv_head_idx = batch_kv_head_id % num_kv_heads
+    batch_offset = batch_idx * seq_len
+
+    num_blocks_in_seq = seq_len // kv_block_size
+    tiles_per_block = kv_block_size // q_tile_size
+
+    fine_mask_start = (
+        batch_idx * num_kv_heads * num_blocks_in_seq * num_blocks_in_seq +
+        kv_head_idx * num_blocks_in_seq * num_blocks_in_seq +
+        kv_block_id * num_blocks_in_seq
+    )
+
+    k_block_ptr = tl.make_block_ptr(
+        k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
+        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
+    )
+
+    v_block_ptr = tl.make_block_ptr(
+        v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
+        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
+    )
+
+    grad_k_ptr = tl.make_block_ptr(
+        grad_k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
+        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
+    )
+
+    grad_v_ptr = tl.make_block_ptr(
+        grad_v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim,
+        (seq_len, feature_dim), (num_kv_heads * feature_dim, 1),
+        (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0)
+    )
+
+    k_block = tl.load(k_block_ptr).to(tl.bfloat16)
+    v_block = tl.load(v_block_ptr).to(tl.bfloat16)
+
+    grad_k_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32)
+    grad_v_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32)
+
+    for q_block_id in range(num_blocks_in_seq):
+        is_connected = tl.load(block_mask_ptr + fine_mask_start + q_block_id)
+
+        if is_connected:
+            for tile_in_block in range(tiles_per_block):
+                tile_idx = q_block_id * tiles_per_block + tile_in_block
+                q_tile_start = tile_idx * q_tile_size
+
+                q_tile_ptr = tl.make_block_ptr(
+                    base=q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim,
+                    shape=(q_tile_size, num_q_heads, feature_dim),
+                    strides=(num_q_heads * feature_dim, feature_dim, 1),
+                    offsets=(0, kv_head_idx * q_heads_per_kv_head, 0),
+                    block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim),
+                    order=(0, 1, 2),
+                )
+
+                grad_o_tile_ptr = tl.make_block_ptr(
+                    base=grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim,
+                    shape=(q_tile_size, num_q_heads, feature_dim),
+                    strides=(num_q_heads * feature_dim, feature_dim, 1),
+                    offsets=(0, kv_head_idx * q_heads_per_kv_head, 0),
+                    block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim),
+                    order=(0, 1, 2),
+                )
+
+                lse_tile_ptr = tl.make_block_ptr(
+                    base=lse_ptr + (batch_offset + q_tile_start) * num_q_heads,
+                    shape=(q_tile_size, num_q_heads),
+                    strides=(num_q_heads, 1),
+                    offsets=(0, kv_head_idx * q_heads_per_kv_head),
+                    block_shape=(q_tile_size, q_heads_per_kv_head),
+                    order=(1, 0),
+                )
+
+                delta_tile_ptr = tl.make_block_ptr(
+                    base=delta_ptr + (batch_offset + q_tile_start) * num_q_heads,
+                    shape=(q_tile_size, num_q_heads),
+                    strides=(num_q_heads, 1),
+                    offsets=(0, kv_head_idx * q_heads_per_kv_head),
+                    block_shape=(q_tile_size, q_heads_per_kv_head),
+                    order=(1, 0),
+                )
+
+                q_tile = tl.load(q_tile_ptr) * softmax_scale * LOG2_E
+                q_tile = tl.reshape(q_tile, (q_tile_size * q_heads_per_kv_head, feature_dim))
+                q_tile = q_tile.to(tl.bfloat16)
+
+                grad_o_block = tl.load(grad_o_tile_ptr)
+                grad_o_block = tl.reshape(grad_o_block, (q_tile_size * q_heads_per_kv_head, feature_dim))
+                grad_o_block = grad_o_block.to(tl.bfloat16)
+
+                lse_vals = tl.load(lse_tile_ptr)
+                lse_vals = tl.reshape(lse_vals, (q_tile_size * q_heads_per_kv_head,))
+
+                delta_vals = tl.load(delta_tile_ptr)
+                delta_vals = tl.reshape(delta_vals, (q_tile_size * q_heads_per_kv_head,))
+
+                attention_scores = tl.dot(k_block, tl.trans(q_tile))
+                attention_probs = tl.exp2(attention_scores - lse_vals[None, :])
+                grad_v_accum += tl.dot(attention_probs.to(tl.bfloat16), grad_o_block)
+                grad_times_v = tl.dot(v_block, tl.trans(grad_o_block))
+                grad_scores = attention_probs * (grad_times_v - delta_vals[None, :]) * LN_2
+                grad_k_accum += tl.dot(grad_scores.to(tl.bfloat16), q_tile)
+
+    tl.store(grad_k_ptr, grad_k_accum.to(grad_k_ptr.dtype.element_ty))
+    tl.store(grad_v_ptr, grad_v_accum.to(grad_v_ptr.dtype.element_ty))
+
+
+def mosaic_attn_bwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    output: torch.Tensor,
+    lse: torch.Tensor,
+    grad_o: torch.Tensor,
+    softmax_scale: float,
+    block_indices: torch.LongTensor,
+    block_size: int,
+):
+    batch_size, seq_len, num_kv_heads, feature_dim = k.shape
+    num_q_heads = q.shape[2]
+    num_kv_blocks_per_q_block = block_indices.shape[-1]
+    q_heads_per_kv_head = num_q_heads // num_kv_heads
+    num_blocks_in_seq = seq_len // block_size
+
+    grad_q = torch.empty_like(q)
+    grad_k = torch.empty_like(k)
+    grad_v = torch.empty_like(v)
+
+    block_mask = mosaic_block_mask(block_indices)
+
+    delta = (output * grad_o).sum(dim=-1)
+
+    grid_dq = lambda META: (
+        triton.cdiv(seq_len, META['q_tile_size']),
+        q_heads_per_kv_head,
+        batch_size * num_kv_heads
+    )
+
+    mosaic_attn_bwd_q_kernel[grid_dq](
+        q_ptr=q,
+        k_ptr=k,
+        v_ptr=v,
+        lse_ptr=lse,
+        delta_ptr=delta,
+        grad_o_ptr=grad_o,
+        grad_q_ptr=grad_q,
+        block_indices_ptr=block_indices,
+        softmax_scale=softmax_scale,
+        seq_len=seq_len,
+        num_kv_heads=num_kv_heads,
+        num_q_heads=num_q_heads,
+        q_heads_per_kv_head=q_heads_per_kv_head,
+        feature_dim=feature_dim,
+        kv_block_size=block_size,
+        num_kv_blocks_per_q_block=num_kv_blocks_per_q_block,
+    )
+
+    grid_dkv = (num_blocks_in_seq, batch_size * num_kv_heads)
+
+    mosaic_attn_bwd_kv_kernel[grid_dkv](
+        q_ptr=q,
+        k_ptr=k,
+        v_ptr=v,
+        lse_ptr=lse,
+        delta_ptr=delta,
+        grad_o_ptr=grad_o,
+        grad_k_ptr=grad_k,
+        grad_v_ptr=grad_v,
+        block_mask_ptr=block_mask,
+        softmax_scale=softmax_scale,
+        seq_len=seq_len,
+        num_kv_heads=num_kv_heads,
+        num_q_heads=num_q_heads,
+        q_heads_per_kv_head=q_heads_per_kv_head,
+        feature_dim=feature_dim,
+        kv_block_size=block_size,
+    )
+
+    return grad_q, grad_k, grad_v
+
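The `delta = (output * grad_o).sum(dim=-1)` precomputation above is the standard FlashAttention backward identity: the gradient of the scores is `dS = P ⊙ (dO·Vᵀ − delta)`. A scalar, stdlib-only finite-difference check of that identity for a single query (toy values; names are illustrative):

```python
import math

def attn_out(s, v):
    # softmax over scores s, then a weighted sum of scalar values v
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    p = [x / z for x in e]
    return sum(pi * vi for pi, vi in zip(p, v)), p

s = [0.3, -0.7, 1.1]
v = [2.0, -1.0, 0.5]
o, p = attn_out(s, v)

# analytic gradient of o w.r.t. the scores: dS_j = p_j * (v_j - delta),
# where delta = sum(O * dO) and dO = 1 for this scalar output, so delta = o
delta = o
grad_s = [pj * (vj - delta) for pj, vj in zip(p, v)]

# central finite differences as a ground truth
eps = 1e-6
fd = []
for j in range(3):
    s_hi = list(s); s_hi[j] += eps
    s_lo = list(s); s_lo[j] -= eps
    fd.append((attn_out(s_hi, v)[0] - attn_out(s_lo, v)[0]) / (2 * eps))
```

This is exactly the per-element form of `grad_scores = attention_probs * (grad_times_v - delta_vals)` used in both backward kernels.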
+
+class MosaicAttnFunction(torch.autograd.Function):
+
+    @staticmethod
+    @torch.amp.custom_fwd(device_type='cuda')
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        block_indices: torch.Tensor,
+        block_size: int,
+        softmax_scale: float
+    ):
+        q, k, v, block_indices = map(lambda x: x.contiguous(), (q, k, v, block_indices))
+
+        ctx.dtype = q.dtype
+
+        output, lse = mosaic_attn_fwd(
+            q=q, k=k, v=v,
+            block_indices=block_indices,
+            block_size=block_size,
+            softmax_scale=softmax_scale,
+        )
+
+        ctx.save_for_backward(q, k, v, output, lse, block_indices)
+        ctx.block_size = block_size
+        ctx.softmax_scale = softmax_scale
+
+        return output.to(q.dtype)
+
+    @staticmethod
+    @torch.amp.custom_bwd(device_type='cuda')
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_o: torch.Tensor
+    ):
+        q, k, v, output, lse, block_indices = ctx.saved_tensors
+        grad_o = grad_o.contiguous()
+        grad_q, grad_k, grad_v = mosaic_attn_bwd(
+            q=q, k=k, v=v, output=output, lse=lse, grad_o=grad_o,
+            softmax_scale=ctx.softmax_scale,
+            block_indices=block_indices,
+            block_size=ctx.block_size,
+        )
+        return grad_q, grad_k, grad_v, None, None, None
+
+
+def mosaic_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    block_indices: torch.LongTensor,
+    block_size: int,
+    softmax_scale: float = None,
+):
+    softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
+    return MosaicAttnFunction.apply(q, k, v, block_indices, block_size, softmax_scale)
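Semantically, `mosaic_sparse_attn` computes ordinary softmax attention in which each query block attends only to its listed KV blocks. A stdlib-only reference for one batch and one head (the real op runs on `(batch, seq, heads, dim)` GPU tensors; shapes and names here are toy assumptions):

```python
import math

def sparse_block_attention(q, k, v, block_indices, block_size):
    # q, k, v: seq_len lists of dim-length vectors.
    # block_indices[qb] lists the kv blocks visible to query block qb;
    # softmax is taken only over the gathered keys, mirroring the kernel.
    dim = len(q[0])
    scale = dim ** -0.5
    out = []
    for i, qi in enumerate(q):
        kv_blocks = block_indices[i // block_size]
        cols = [c for b in kv_blocks
                for c in range(b * block_size, (b + 1) * block_size)]
        scores = [scale * sum(qi[d] * k[c][d] for d in range(dim)) for c in cols]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([
            sum(weights[j] * v[cols[j]][d] for j in range(len(cols))) / z
            for d in range(dim)
        ])
    return out
```

Each output row is a convex combination of the gathered value rows, so properties of the values (e.g. row sums) are preserved.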
primitives.py ADDED
@@ -0,0 +1,359 @@
+"""
+Primitive building blocks for the Mosaic transformer.
+
+Components:
+- Block-sparse attention with learned strategy weighting (local block, compressed,
+  and top-k selection branches combined with a learned gate)
+- Rotary positional embeddings (RoPE) for 2D lon/lat
+- Cross-attention interpolation between grids
+- HEALPix spatial up/downsampling
+- Conditional SwiGLU FFN with noise injection
+"""
+
+import torch
+import torch.nn as nn
+from einops import rearrange, reduce, repeat
+from torch.nn import RMSNorm
+
+try:
+    from flash_attn import flash_attn_func  # FlashAttention v2
+except ImportError:
+    import flash_attn_interface as fa  # FlashAttention v3
+    flash_attn_func = fa.flash_attn_func
+
+from utils import get_healpix_grid, get_neighbors, rad_to_xyz
+from ops import mosaic_sparse_attn
+
+
+def block_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int):
+    batch_size = q.shape[0]
+    q, k, v = map(lambda x: rearrange(x, 'b (nb bs) h d -> (b nb) bs h d', bs=block_size), (q, k, v))
+    o_ba = flash_attn_func(q, k, v)
+    return rearrange(o_ba, '(b nb) bs h d -> b (nb bs) h d', b=batch_size)
+
+
+@torch.no_grad()
+def attn_topk(q: torch.Tensor, k: torch.Tensor, block_count: int):
+    Hq, Hk = q.shape[2], k.shape[2]
+    G = Hq // Hk
+    k = k.repeat_interleave(G, dim=2)
+
+    scores = torch.matmul(
+        rearrange(q, 'b t h d -> b h t d'),
+        rearrange(k, 'b t h d -> b h d t')
+    )
+
+    if Hq != Hk:
+        scores = reduce(scores, 'b (g h) t k -> b h t k', 'mean', g=G)
+
+    scores = rearrange(scores, 'b h t k -> b t h k')
+    top_indices = scores.topk(k=block_count, dim=-1, largest=True)[1]
+    return top_indices
+
+
+def mosaic_attn_func(
+    q, k, v,
+    weight_ba_cmp_slc,
+    block_attn_size, sparse_block_size, sparse_block_count,
+    block_attn_only, no_compression=False,
+):
+    o_ba = block_attention(q, k, v, block_attn_size)
+
+    if block_attn_only:
+        return o_ba
+
+    q_cmp = reduce(q, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
+    k_cmp = reduce(k, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
+
+    if no_compression:
+        block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
+        o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)
+        w_ba = weight_ba_cmp_slc[0]
+        w_slc = weight_ba_cmp_slc[2]
+        w_sum = w_ba + w_slc + 1e-8
+        return o_ba * (w_ba / w_sum) + o_slc * (w_slc / w_sum)
+
+    v_cmp = reduce(v, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
+    o_cmp = flash_attn_func(q_cmp, k_cmp, v_cmp)
+    o_cmp = o_cmp.repeat_interleave(sparse_block_size, dim=1)
+
+    if sparse_block_count == 0:
+        w_ba = weight_ba_cmp_slc[0]
+        w_cmp = weight_ba_cmp_slc[1]
+        w_sum = w_ba + w_cmp + 1e-8
+        return o_ba * (w_ba / w_sum) + o_cmp * (w_cmp / w_sum)
+
+    block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
+    o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)
+
+    return o_ba * weight_ba_cmp_slc[0] + o_cmp * weight_ba_cmp_slc[1] + o_slc * weight_ba_cmp_slc[2]
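The three branches (local block, compressed, top-k selected) are mixed with per-head weights that come from a softmax over learned logits, so the result is always a convex combination of branch outputs. A minimal sketch of that gating for a single position (names are illustrative, not from the repo):

```python
import math

def mix_strategies(logits, o_ba, o_cmp, o_slc):
    # softmax over the three strategy logits, then a convex combination of
    # the block, compressed and selected attention outputs (elementwise)
    m = max(logits)
    e = [math.exp(l - m) for l in logits]
    z = sum(e)
    w = [x / z for x in e]  # w sums to 1
    return [w[0] * a + w[1] * b + w[2] * c for a, b, c in zip(o_ba, o_cmp, o_slc)]

mixed = mix_strategies([0.2, -1.0, 0.8], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5])
```

In the `no_compression` / `sparse_block_count == 0` paths above, the unused branch's weight is dropped and the remaining two are renormalized, which is what the `w_sum = ... + 1e-8` divisions implement.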
+
+
+class cSwiGLU(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, noise_dim: int):
+        super().__init__()
+        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.act_fn = nn.SiLU()
+
+        if noise_dim > 0:
+            self.noise_bias = nn.Linear(noise_dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
+        noise = self.noise_bias(z).unsqueeze(0) if z is not None else 0
+        x1, x3 = self.w13(x).chunk(2, dim=-1)
+        return self.w2(self.act_fn(x1 + noise) * x3)
+
+
+class RoPE(nn.Module):
+    def __init__(self, dim, theta=10000):
+        super().__init__()
+        assert dim % 2 == 0
+        self.dim = dim
+        self.theta = theta
+
+    def initialize_rope(self, positions):
+        base_freqs = 1. / (self.theta ** (torch.arange(0, self.dim // 2, 2).float() / (self.dim // 2)))
+        lon_pos = torch.deg2rad(positions[:, 0:1])
+        lat_pos = torch.deg2rad(positions[:, 1:2])
+        lon_freqs = torch.matmul(lon_pos, base_freqs.unsqueeze(0))
+        lat_freqs = torch.matmul(lat_pos, base_freqs.unsqueeze(0))
+        freqs = torch.cat([lon_freqs, lat_freqs], dim=-1)
+        self.register_buffer('cos_freqs', freqs.cos().contiguous(), persistent=True)
+        self.register_buffer('sin_freqs', freqs.sin().contiguous(), persistent=True)
+
+    @staticmethod
+    def rotate_half(x):
+        x = rearrange(x, '... (d r) -> ... d r', r=2)
+        x1, x2 = x.unbind(dim=-1)
+        x = torch.stack((-x2, x1), dim=-1)
+        return rearrange(x, '... d r -> ... (d r)')
+
+    def forward(self, x):
+        cos = self.cos_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
+        sin = self.sin_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
+        return (x.float() * cos + self.rotate_half(x.float()) * sin).to(x.dtype)
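The `x*cos + rotate_half(x)*sin` form with `rotate_half: (x1, x2) -> (-x2, x1)` is a plane rotation of each consecutive feature pair, so it preserves vector norms (and hence dot-product magnitudes depend only on relative position). A stdlib-only sketch of the per-pair rotation (names here are illustrative; in the module, half the pairs carry longitude angles and half latitude angles, with `repeat_interleave(2, dim=-1)` duplicating each frequency so both members of a pair share the angle):

```python
import math

def rope_rotate(vec, angles):
    # rotate consecutive pairs (x1, x2) of vec by the given angles,
    # mirroring x*cos + rotate_half(x)*sin with rotate_half = (-x2, x1)
    out = []
    for (x1, x2), a in zip(zip(vec[0::2], vec[1::2]), angles):
        out += [x1 * math.cos(a) - x2 * math.sin(a),
                x1 * math.sin(a) + x2 * math.cos(a)]
    return out

vec = [1.0, 2.0, -0.5, 0.25]
rot = rope_rotate(vec, [0.3, 1.2])
```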
+
+
+class MosaicAttention(nn.Module):
+    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
+        super().__init__()
+        self.block_attn_only = block_attn_only
+        self.no_compression = no_compression
+        self.block_attn_size = config.block_attn_size
+        self.sparse_block_size = config.sparse_block_size
+        self.sparse_block_count = config.sparse_block_count
+
+        q_heads = config.num_heads
+        gqa_ratio = config.gqa_ratio
+        dim = config.dim
+        qkv_compress_ratio = config.qkv_compress_ratio
+        rope = config.rope
+        rope_theta = config.rope_theta
+
+        kv_heads = q_heads // gqa_ratio
+        head_dim = int(dim // q_heads // qkv_compress_ratio)
+
+        self.q_heads = q_heads
+        self.kv_heads = kv_heads
+
+        self.to_q = nn.Linear(dim, q_heads * head_dim, bias=False)
+        self.to_k = nn.Linear(dim, kv_heads * head_dim, bias=False)
+        self.to_v = nn.Linear(dim, kv_heads * head_dim, bias=False)
+        self.to_o = nn.Linear(q_heads * head_dim, dim, bias=False)
+
+        self.q_rope = RoPE(head_dim, rope_theta) if rope else None
+        self.k_rope = RoPE(head_dim, rope_theta) if rope else None
+
+        if block_attn_only:
+            self.to_strategy_combine_mlp = None
+        else:
+            self.to_strategy_combine_mlp = nn.Linear(dim, 3 * q_heads, bias=False)
+
+    def generate_strategy_weights(self, x):
+        if self.block_attn_only:
+            return [None, None, None]
+        strategy_logits = self.to_strategy_combine_mlp(x)
+        strategy_logits = rearrange(strategy_logits, 't b (h s) -> s b t h', h=self.q_heads)
+        strategy_weights = torch.softmax(strategy_logits.float(), dim=0).type_as(x)
+        strategy_weights = strategy_weights.unsqueeze(-1)
+        return strategy_weights
+
+    def forward(self, x):
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        strategy_weights = self.generate_strategy_weights(x)
+
+        q = rearrange(q, 's b (h d) -> b s h d', h=self.q_heads)
+        k = rearrange(k, 's b (h d) -> b s h d', h=self.kv_heads)
+        v = rearrange(v, 's b (h d) -> b s h d', h=self.kv_heads)
+
+        if self.q_rope is not None:
+            q = self.q_rope(q)
+            k = self.k_rope(k)
+
+        output = mosaic_attn_func(
+            q=q, k=k, v=v,
+            weight_ba_cmp_slc=strategy_weights,
+            block_attn_size=self.block_attn_size,
+            sparse_block_size=self.sparse_block_size,
+            sparse_block_count=self.sparse_block_count,
+            block_attn_only=self.block_attn_only,
+            no_compression=self.no_compression,
+        )
+
+        output = rearrange(output, 'b s h d -> s b (h d)')
+        output = self.to_o(output)
+        return output
+
+
+class MosaicBlock(nn.Module):
+    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
+        super().__init__()
+        dim = config.dim
+        noise_dim = config.noise_dim
+        mlp_ratio = config.mlp_ratio
+
+        self.attention = MosaicAttention(config, block_attn_only, no_compression)
+        self.norm1 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
+        self.norm2 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
+        self.ffn = cSwiGLU(dim, int(dim * mlp_ratio), noise_dim)
+
+    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
+        x = x + self.attention(self.norm1(x))
+        x = x + self.ffn(self.norm2(x), z)
+        return x
+
+
+class CrossAttentionInterpolate(nn.Module):
+    space_dim = 3
+
+    def __init__(self, config):
+        super().__init__()
+        self.k_neighbors = config.k_neighbors
+
+        dim = config.dim
+        num_heads = config.num_heads
+
+        head_dim = dim // num_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = head_dim ** -0.5
+
+        self.kv_norm = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
+        self.to_q = nn.Linear(self.space_dim, dim, bias=False)
+        self.to_kv = nn.Linear(dim, 2 * dim, bias=False)
+        self.to_o = nn.Linear(dim, dim, bias=False)
+
+    @torch.no_grad()
+    def initialize_interpolation_scheme(self, pos_from_rad, pos_to_rad):
+        neighbors_np = get_neighbors(pos_from_rad.cpu().numpy(), pos_to_rad.cpu().numpy(), k=self.k_neighbors)
+        neighbors = torch.from_numpy(neighbors_np).long().to(pos_from_rad.device).contiguous()
+
+        pos_to_xyz = rad_to_xyz(pos_to_rad)
+        pos_from_xyz = rad_to_xyz(pos_from_rad)
+
+        rel_pos_xyz = (pos_to_xyz.unsqueeze(1) - pos_from_xyz[neighbors]).contiguous()
+        norm_rel_pos_xyz = torch.nn.functional.normalize(rel_pos_xyz, dim=-1).contiguous()
+
+        self.register_buffer('neighbors', neighbors, persistent=True)
+        self.register_buffer('rel_pos', norm_rel_pos_xyz, persistent=True)
+
+    def forward(self, x_from: torch.Tensor):
+        # The buffers only exist after initialize_interpolation_scheme has run,
+        # so guard with getattr instead of a plain attribute access.
+        if getattr(self, 'neighbors', None) is None or getattr(self, 'rel_pos', None) is None:
+            raise ValueError("Interpolation scheme not initialized.")
+
+        q = self.to_q(self.rel_pos)
+        q = rearrange(q, 's k (h d) -> s k 1 h d', h=self.num_heads)
+
+        x = self.kv_norm(x_from)
+
+        kv = self.to_kv(x)
+        kv = rearrange(kv, 's b (n h d) -> n s b h d', h=self.num_heads, n=2)
+
+        k, v = kv[:, self.neighbors]
+
+        attn_scores = (q * k).sum(dim=-1, keepdim=True) * self.scale
+        attn_weights = torch.softmax(attn_scores, dim=1, dtype=torch.float32).type_as(k)
+        out = (attn_weights * v).sum(dim=1)
+
+        out = rearrange(out, 's b h d -> s b (h d)')
+        out = self.to_o(out)
+        return out
+
+
+class NoiseGenerator(nn.Module):
+    def __init__(self, noise_dim: int, seed: int):
+        super().__init__()
+        self.seed = seed
+        self.to_noise = nn.Linear(noise_dim, noise_dim, bias=False)
+        self.generator = None
+
+    def forward(self, num_samples: int, device: torch.device, dtype: torch.dtype):
+        if self.generator is None:
+            self.generator = torch.Generator(device=device)
+            self.generator.manual_seed(self.seed)
+
+        noise = torch.randn((num_samples, self.to_noise.in_features),
+                            generator=self.generator, device=device, dtype=dtype)
+        noise = self.to_noise(noise)
+        return noise
+
+
+class HEALPixDownsample(nn.Module):
+    space_dim: int = 3
+
+    def __init__(self, in_dim, out_dim, nside_before, nside_after,
+                 rmsnorm_elementwise_affine=True):
+        super().__init__()
+        self.factor = (nside_before // nside_after) ** 2
+
+        self.proj_x = nn.Linear(self.factor * in_dim, out_dim, bias=False)
+        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim, bias=False)
+        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)
+
+        hp_grid_fine_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_before)))
+        hp_grid_coarse_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_after)))
+
+        pos = rearrange(hp_grid_fine_xyz, '(n f) d -> n f d', f=self.factor)
+        rel_pos = rearrange(pos - hp_grid_coarse_xyz[:, None], 'n f d -> n (f d)')
+        rel_pos = (rel_pos - rel_pos.mean(dim=0, keepdim=True)) / (rel_pos.std(dim=0, keepdim=True) + 1e-6)
+
+        self.register_buffer('rel_pos', rel_pos.contiguous(), persistent=True)
+
+    def forward(self, x: torch.Tensor):
+        x = rearrange(x, '(n f) b c -> n b (f c)', f=self.factor)
+        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
+        x = self.norm(x)
+        return x
+
+
+class HEALPixUpsample(nn.Module):
+    space_dim: int = 3
+
+    def __init__(self, in_dim, out_dim, nside_before, nside_after,
+                 rmsnorm_elementwise_affine=True):
+        super().__init__()
+        self.factor = (nside_after // nside_before) ** 2
+
+        self.proj_x = nn.Linear(in_dim, out_dim * self.factor, bias=False)
+        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim * self.factor, bias=False)
+        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)
+
+        hp_grid_fine_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_after)))
+        hp_grid_coarse_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_before)))
+
+        children_pos_reshaped = rearrange(hp_grid_fine_xyz, '(n f) d -> n f d', f=self.factor)
+        rel_pos = rearrange(children_pos_reshaped - hp_grid_coarse_xyz[:, None], 'n f d -> n (f d)')
+        rel_pos = (rel_pos - rel_pos.mean(dim=0, keepdim=True)) / (rel_pos.std(dim=0, keepdim=True) + 1e-6)
+
+        self.register_buffer('rel_pos', rel_pos.contiguous(), persistent=True)
+
+    def forward(self, x: torch.Tensor, shortcut: torch.Tensor):
+        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
+        x = rearrange(x, 'n b (f d) -> (n f) b d', f=self.factor)
+        x = x + shortcut
+        x = self.norm(x)
+        return x
requirements.txt ADDED
@@ -0,0 +1,11 @@
+torch>=2.0
+einops>=0.6
+healpy>=1.16
+scikit-learn>=1.0
+numpy>=1.24
+zarr>=2.10
+pandas>=1.5
+triton>=2.0
+flash-attn>=2.0
+# For reading WeatherBench2 zarr stores directly from gs://
+gcsfs>=2023.0
utils.py ADDED
@@ -0,0 +1,40 @@
+import torch
+import numpy as np
+import healpy as hp
+from sklearn.neighbors import BallTree
+
+
+def rad_to_xyz(lonlat: torch.Tensor):
+    """Convert lon-lat (in radians) to unit-sphere xyz."""
+    lon = lonlat[..., 0]
+    lat = lonlat[..., 1]
+
+    x = torch.cos(lat) * torch.cos(lon)
+    y = torch.cos(lat) * torch.sin(lon)
+    z = torch.sin(lat)
+
+    return torch.stack([x, y, z], dim=-1)
+
+
+def get_healpix_grid(nside: int) -> torch.Tensor:
+    """Return HEALPix (nested) grid coordinates in degrees as (lon, lat), shape (npix, 2)."""
+    indices = np.arange(hp.nside2npix(nside))
+    theta, phi = hp.pix2ang(nside, indices, nest=True)
+
+    phi = np.rad2deg(phi)
+    theta = 90. - np.rad2deg(theta)
+
+    phi = torch.from_numpy(phi)
+    theta = torch.from_numpy(theta)
+
+    return torch.stack((phi, theta), dim=-1).float()
+
+
+def get_neighbors(pos_from: np.ndarray, pos_to: np.ndarray, k: int = 8) -> np.ndarray:
+    """Build a BallTree over (lon, lat) positions in radians and return the
+    indices of the k nearest neighbors of each target point (haversine metric,
+    which expects (lat, lon) order, hence the column flip)."""
+    pos_from_rad = pos_from[:, ::-1]
+    pos_to_rad = pos_to[:, ::-1]
+
+    tree = BallTree(pos_from_rad, metric='haversine')
+    _, neighbors = tree.query(pos_to_rad, k=k)
+    return neighbors
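`rad_to_xyz` maps lon/lat to points on the unit sphere, which is what makes the relative-position features in the HEALPix and interpolation modules well behaved (norm-1 coordinates, no wrap-around at the dateline). A stdlib-only sketch of the same mapping for a single point (function name is illustrative):

```python
import math

def lonlat_rad_to_xyz(lon, lat):
    # same mapping as rad_to_xyz, for a single (lon, lat) pair in radians
    return (math.cos(lat) * math.cos(lon),
            math.cos(lat) * math.sin(lon),
            math.sin(lat))

p = lonlat_rad_to_xyz(math.radians(30.0), math.radians(-45.0))
```

Every output lies on the unit sphere, and the poles map to (0, 0, ±1) regardless of longitude.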