"""
Mosaic: U-Net transformer with block-sparse attention for weather forecasting.

Architecture:
- Cross-attention interpolation between lon/lat and HEALPix grids
- Block-sparse attention (local block + compressed + top-k selection branches)
  arranged in a U-Net encoder–bottleneck–decoder
- Probabilistic training with noise injection
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.nn import RMSNorm

from utils import get_healpix_grid, rad_to_xyz
from primitives import (
    MosaicBlock as _MosaicBlock,
    CrossAttentionInterpolate,
    NoiseGenerator,
    HEALPixDownsample,
    HEALPixUpsample,
)


@dataclass
class StageConfig:
    """Configuration for a U-Net encoder/decoder stage."""

    nside: int  # HEALPix nside this stage operates at
    dim: int  # channel width of the stage
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    encoder_depth: int  # number of MosaicBlocks in the encoder half of this stage
    decoder_depth: int  # number of MosaicBlocks in the decoder half of this stage
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class BottleneckConfig:
    """Configuration for the U-Net bottleneck stage."""

    nside: int  # HEALPix nside of the bottleneck
    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    depth: int  # number of MosaicBlocks in the bottleneck
    mlp_ratio: float
    gqa_ratio: int


@dataclass
class ModelConfig:
    """Configuration for the Mosaic model."""

    dim: int  # width used by preprocess / postprocess around the U-Net
    num_heads: int
    k_neighbors: int  # not read in this file; presumably consumed by CrossAttentionInterpolate
    qk_norm: bool  # not read in this file; presumably consumed by CrossAttentionInterpolate
    rope: bool  # when False, initialize_rope() is a no-op
    rope_theta: int
    # Block i of a stage gets sparse branches when i % sparse_every == 0;
    # a value <= 0 makes every block block-attention-only.
    sparse_every: int
    variables: list[str]  # predicted (dynamic) variables; also the output channels
    static_variables: list[str]
    qkv_compress_ratio: int
    cg_stage_cfgs: list[StageConfig]  # encoder/decoder stages, in encoder order
    bottleneck_cfg: BottleneckConfig
    num_history_steps: int = 1  # history steps folded into the input channels
    noise_dim: int = 32  # 0 disables the noise generator
    ortho_init: bool = False  # orthogonal (instead of normal) weight init
    rmsnorm_elementwise_affine: bool = True
    no_compression: bool = False  # forwarded to every MosaicBlock


@dataclass
class _MergedStageConfig:
    """Merges ModelConfig and StageConfig for compatibility with MosaicBlock."""

    dim: int
    num_heads: int
    block_attn_size: int
    sparse_block_size: int
    sparse_block_count: int
    gqa_ratio: int
    qkv_compress_ratio: int
    rope: bool
    rope_theta: int
    mlp_ratio: float
    noise_dim: int
    rmsnorm_elementwise_affine: bool


def _merge_configs(config: ModelConfig, stage_cfg) -> _MergedStageConfig:
    """Combine per-stage fields from ``stage_cfg`` with model-wide fields from ``config``.

    ``stage_cfg`` may be a StageConfig or a BottleneckConfig; only fields shared by
    both are read.
    """
    return _MergedStageConfig(
        dim=stage_cfg.dim,
        num_heads=stage_cfg.num_heads,
        block_attn_size=stage_cfg.block_attn_size,
        sparse_block_size=stage_cfg.sparse_block_size,
        sparse_block_count=stage_cfg.sparse_block_count,
        gqa_ratio=stage_cfg.gqa_ratio,
        qkv_compress_ratio=config.qkv_compress_ratio,
        rope=config.rope,
        rope_theta=config.rope_theta,
        mlp_ratio=stage_cfg.mlp_ratio,
        noise_dim=config.noise_dim,
        rmsnorm_elementwise_affine=config.rmsnorm_elementwise_affine,
    )


def _make_mosaic_block(config: ModelConfig, stage_cfg, block_attn_only: bool) -> _MosaicBlock:
    """Build one MosaicBlock for ``stage_cfg``, honouring ``config.no_compression``."""
    return _MosaicBlock(
        _merge_configs(config, stage_cfg),
        block_attn_only,
        no_compression=config.no_compression,
    )


class UNetStage(nn.Module):
    """A sequence of ``depth`` MosaicBlocks sharing one stage configuration.

    Block ``i`` is built with ``block_attn_only=False`` when
    ``i % config.sparse_every == 0``. Every other block is built with
    ``block_attn_only=True``, and so is every block when ``config.sparse_every <= 0``.
    """

    def __init__(self, config, stage_cfg, depth):
        super().__init__()
        self.nside = stage_cfg.nside
        self.blocks = nn.ModuleList([
            _make_mosaic_block(
                config=config,
                stage_cfg=stage_cfg,
                # The short-circuit on sparse_every <= 0 also prevents a modulo by zero.
                block_attn_only=config.sparse_every <= 0 or i % config.sparse_every != 0,
            )
            for i in range(depth)
        ])

    def forward(self, x, z=None):
        """Run the blocks in order, passing the same (optional) noise ``z`` to each."""
        for block in self.blocks:
            x = block(x, z)
        return x


class Transformer(nn.Module):
    """U-Net style Transformer for weather forecasting on HEALPix grids."""

    space_dim = 3  # xyz position channels appended in initialize_static_vars
    time_dim = 4  # sin/cos of day and year phase produced by time_embedding

    def __init__(self, config: ModelConfig, seed: int = 42):
        super().__init__()
        self.config = config
        self.nside = config.cg_stage_cfgs[0].nside
        self.noise_dim = config.noise_dim
        initial_dim = config.dim
        rms_affine = config.rmsnorm_elementwise_affine

        # Per-point input channels: dynamic variables for every history step,
        # static variables, xyz position and the time embedding.
        feature_dim = (
            len(config.variables) * config.num_history_steps
            + len(config.static_variables)
            + self.space_dim
            + self.time_dim
        )

        if self.noise_dim > 0:
            self.noise_generator = NoiseGenerator(self.noise_dim, seed)

        self.preprocess = nn.Sequential(
            nn.Linear(feature_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=rms_affine),
            nn.SiLU(),
            nn.Linear(initial_dim, initial_dim, bias=False),
            RMSNorm(initial_dim, elementwise_affine=rms_affine),
        )
        self.interp_to_hp = CrossAttentionInterpolate(config)
        self.interp_to_ll = CrossAttentionInterpolate(config)

        # Module construction order below is kept stable (encoder stage, then its
        # downsample; upsample, then its decoder stage) so parameter ordering and
        # RNG consumption during construction do not change.
        all_stages = [*config.cg_stage_cfgs, config.bottleneck_cfg]
        # (stage, next stage towards the bottleneck) pairs, in encoder order.
        stage_pairs = list(zip(all_stages[:-1], all_stages[1:]))

        self.encoder_stages = nn.ModuleList()
        self.downsample_layers = nn.ModuleList()
        for current_stage, next_stage in stage_pairs:
            self.encoder_stages.append(
                UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.encoder_depth)
            )
            self.downsample_layers.append(
                HEALPixDownsample(
                    in_dim=current_stage.dim,
                    out_dim=next_stage.dim,
                    nside_before=current_stage.nside,
                    nside_after=next_stage.nside,
                    rmsnorm_elementwise_affine=rms_affine,
                )
            )

        self.bottleneck = UNetStage(
            config=config, stage_cfg=config.bottleneck_cfg, depth=config.bottleneck_cfg.depth
        )

        self.decoder_stages = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()
        for current_stage, prev_stage in reversed(stage_pairs):
            self.upsample_layers.append(
                HEALPixUpsample(
                    in_dim=prev_stage.dim,
                    out_dim=current_stage.dim,
                    nside_before=prev_stage.nside,
                    nside_after=current_stage.nside,
                    rmsnorm_elementwise_affine=rms_affine,
                )
            )
            self.decoder_stages.append(
                UNetStage(config=config, stage_cfg=current_stage, depth=current_stage.decoder_depth)
            )

        self.norm_before_interp_ll = RMSNorm(initial_dim, elementwise_affine=rms_affine)
        self.postprocess = nn.Sequential(
            RMSNorm(initial_dim, elementwise_affine=rms_affine),
            nn.Linear(initial_dim, initial_dim, bias=False),
            nn.SiLU(),
            nn.Linear(initial_dim, len(config.variables), bias=False),
        )

        self.apply(self._initialize_weights)
        self._zero_init_residual_layers()
        self.initialize_rope()

    def _unet_stages(self):
        """Return every UNetStage in encoder -> bottleneck -> decoder order."""
        return [*self.encoder_stages, self.bottleneck, *self.decoder_stages]

    def _initialize_weights(self, module):
        """Initialise every ``nn.Linear`` (called through ``self.apply``).

        The weight scale is ``1/sqrt(fan_in) * min(1, sqrt(fan_out / fan_in))``.
        With ``config.ortho_init`` the weight is an orthogonal matrix multiplied by
        that scale; otherwise it is drawn from N(0, scale^2). Biases are zeroed.
        Other module types are left untouched.
        """
        if module is self:
            return
        ortho_init = self.config.ortho_init
        if isinstance(module, nn.Linear):
            fan_in, fan_out = module.weight.size(1), module.weight.size(0)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            if ortho_init:
                nn.init.orthogonal_(module.weight)
                module.weight.data.mul_(std)
            else:
                nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _zero_init_residual_layers(self):
        """Re-initialise selected output projections to small (not exactly zero) values.

        The following weights are re-initialised:
        - ``attention.to_o`` and ``ffn.w2`` of every block
        - ``proj_x`` and ``proj_pos`` of every upsample layer

        Each gets N(0, 0.01^2), or an orthogonal matrix scaled by 0.01 when
        ``config.ortho_init`` is set.

        With noise enabled, ``ffn.noise_bias`` and ``noise_generator.to_noise``
        always get N(0, 0.01^2). Presumably this starts each residual branch near
        zero; confirm against the block definitions.
        """
        ortho_init = self.config.ortho_init
        for stage in self._unet_stages():
            for block in stage.blocks:
                if ortho_init:
                    nn.init.orthogonal_(block.attention.to_o.weight)
                    block.attention.to_o.weight.data.mul_(0.01)
                    nn.init.orthogonal_(block.ffn.w2.weight)
                    block.ffn.w2.weight.data.mul_(0.01)
                else:
                    nn.init.normal_(block.attention.to_o.weight, mean=0.0, std=0.01)
                    nn.init.normal_(block.ffn.w2.weight, mean=0.0, std=0.01)
                if self.noise_dim > 0:
                    nn.init.normal_(block.ffn.noise_bias.weight, mean=0.0, std=0.01)
        for upsample in self.upsample_layers:
            if ortho_init:
                nn.init.orthogonal_(upsample.proj_x.weight)
                upsample.proj_x.weight.data.mul_(0.01)
                nn.init.orthogonal_(upsample.proj_pos.weight)
                upsample.proj_pos.weight.data.mul_(0.01)
            else:
                nn.init.normal_(upsample.proj_x.weight, mean=0.0, std=0.01)
                nn.init.normal_(upsample.proj_pos.weight, mean=0.0, std=0.01)
        if self.noise_dim > 0:
            nn.init.normal_(self.noise_generator.to_noise.weight, mean=0.0, std=0.01)

    def initialize_rope(self):
        """Give each block's q/k RoPE modules the HEALPix grid of its stage.

        This is a no-op when ``config.rope`` is False. Blocks whose
        ``attention.q_rope`` is None are skipped.
        """
        if not self.config.rope:
            return
        for stage in self._unet_stages():
            # NOTE(review): the grid is passed as returned by get_healpix_grid.
            # initialize_interpolation applies deg2rad to the same function's
            # output, so this may be in degrees. Confirm the rope module expects that.
            hp_grid = get_healpix_grid(stage.nside)
            for block in stage.blocks:
                if block.attention.q_rope is not None:
                    block.attention.q_rope.initialize_rope(hp_grid)
                    block.attention.k_rope.initialize_rope(hp_grid)

    def initialize_interpolation(self, longitude: torch.Tensor, latitude: torch.Tensor):
        """Set up the lon/lat <-> HEALPix cross-attention interpolators.

        Args:
            longitude: 1-D longitudes in degrees (converted with deg2rad).
            latitude: 1-D latitudes in degrees.

        The lon/lat points are flattened in ``(lon lat)`` order, matching the
        token order used in ``forward``.
        """
        ll_grid_rad = torch.deg2rad(
            torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1).reshape(-1, 2)
        )
        hp_grid_rad = torch.deg2rad(get_healpix_grid(self.nside)).to(longitude.device)
        self.interp_to_hp.initialize_interpolation_scheme(ll_grid_rad, hp_grid_rad)
        self.interp_to_ll.initialize_interpolation_scheme(hp_grid_rad, ll_grid_rad)

    @torch.no_grad()
    def initialize_static_vars(
        self, static_vars: torch.Tensor, longitude: torch.Tensor, latitude: torch.Tensor
    ):
        """Store normalised static inputs plus xyz position as the ``static_vars`` buffer.

        Args:
            static_vars: tensor laid out as (lon, lat, c_static).
            longitude: 1-D longitudes in degrees.
            latitude: 1-D latitudes in degrees.

        The position channels come from ``rad_to_xyz`` and are assumed to be 3
        channels (``space_dim``); TODO confirm. Every channel is standardised
        over the (lon, lat) axes, with 1e-6 added to the std.

        The result is registered as a persistent buffer of shape
        ((lon * lat), 1, c_static + xyz).
        """
        ll_grid_rad = torch.deg2rad(torch.stack(torch.meshgrid(longitude, latitude, indexing='ij'), -1))
        ll_grid_xyz = rad_to_xyz(ll_grid_rad)
        static_vars = torch.concat([static_vars, ll_grid_xyz], dim=-1)
        static_vars_mean = static_vars.mean(dim=(0, 1), keepdim=True)
        static_vars_std = static_vars.std(dim=(0, 1), keepdim=True) + 1e-6
        static_vars_norm = (static_vars - static_vars_mean) / static_vars_std
        static_vars = rearrange(static_vars_norm, 'lon lat c -> (lon lat) 1 c').contiguous()
        self.register_buffer('static_vars', static_vars, persistent=True)

    @torch.no_grad()
    def time_embedding(self, day_year_time: torch.Tensor) -> torch.Tensor:
        """Map day/year phase to ``[sin, cos]`` features; returns shape (B, 4).

        Column 0 of ``day_year_time`` is the day value and column 1 the year
        value. Both are multiplied by 2*pi, so they are presumably fractions of
        a cycle.
        """
        day = day_year_time[:, 0:1]
        year = day_year_time[:, 1:2]
        day_sin = torch.sin(2 * math.pi * day)
        day_cos = torch.cos(2 * math.pi * day)
        year_sin = torch.sin(2 * math.pi * year)
        year_cos = torch.cos(2 * math.pi * year)
        return torch.cat([day_sin, day_cos, year_sin, year_cos], dim=-1)

    def forward(self, x: torch.Tensor, day_year_time: torch.Tensor, num_noise_samples: int) -> torch.Tensor:
        """Predict ``num_noise_samples`` stochastic forecasts per input.

        Args:
            x: (b, n, t, lon, lat, c) input fields. ``t * c`` presumably equals
                ``num_history_steps * len(variables)``.
            day_year_time: (b, n, d) with day/year phase in columns 0 and 1.
            num_noise_samples: number of samples ``s`` drawn per (b, n) input.

        Returns:
            Tensor of shape (b, n * s, lon, lat, len(variables)). Output index
            ``i_n * s + i_s`` along dim 1 corresponds to input ``x[:, i_n]``
            and noise sample ``i_s``.
        """
        b, n, _, lon, lat, _ = x.shape
        batch_size = b * num_noise_samples * n
        if self.noise_dim > 0:
            z = self.noise_generator(batch_size, x.device, x.dtype)
        else:
            z = None

        # Tokens are grid points; the batch axis packs (b s n), with the time axis
        # folded into channels.
        x = repeat(x, 'b n t lon lat c -> (lon lat) (b s n) (t c)', s=num_noise_samples)
        day_year_time = repeat(day_year_time, 'b n d -> (b s n) d', s=num_noise_samples)
        x = torch.cat([
            x,
            self.static_vars.expand(-1, batch_size, -1),
            self.time_embedding(day_year_time).unsqueeze(0).expand(x.shape[0], -1, -1),
        ], dim=-1)

        x = self.preprocess(x)
        x = self.interp_to_hp(x)

        skip_connections = []
        for encoder_stage, downsample in zip(self.encoder_stages, self.downsample_layers):
            x = encoder_stage(x, z)
            skip_connections.append(x)
            x = downsample(x)

        x = self.bottleneck(x, z)

        for decoder_stage, upsample, skip in zip(
            self.decoder_stages, self.upsample_layers, reversed(skip_connections)
        ):
            x = upsample(x, skip)
            x = decoder_stage(x, z)

        x = self.norm_before_interp_ll(x)
        x = self.interp_to_ll(x)
        x = self.postprocess(x)

        # Unpack the batch axis with the same (b s n) grouping used when packing.
        # Decoding it as (b n s) would permute samples across n whenever n > 1
        # and s > 1.
        x = rearrange(
            x,
            '(lon lat) (b s n) c -> b (n s) lon lat c',
            lon=lon,
            lat=lat,
            b=b,
            s=num_noise_samples,
        )
        return x