"""
Primitive building blocks for the Mosaic transformer.

Components:
- Block-sparse attention with learned strategy weighting (local block,
  compressed, and top-k selection branches combined with a learned gate)
- Rotary positional embeddings (RoPE) for 2D lon/lat
- Cross-attention interpolation between grids
- HEALPix spatial up/downsampling
- Conditional SwiGLU FFN with noise injection
"""

import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from torch.nn import RMSNorm

try:
    from flash_attn import flash_attn_func  # FlashAttention v2
except ImportError:
    import flash_attn_interface as fa  # FlashAttention v3
    flash_attn_func = fa.flash_attn_func

from utils import get_healpix_grid, get_neighbors, rad_to_xyz
from ops import mosaic_sparse_attn


def block_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int):
    """Run attention independently inside each contiguous block of tokens.

    The inputs use the layout ``b t h d`` (see the rearrange patterns). The
    token axis is split into blocks of ``block_size`` tokens, the blocks are
    folded into the batch axis, ``flash_attn_func`` is applied, and the
    result is unfolded back to ``b t h d``.

    Args:
        q, k, v: Tensors whose token axis length is a multiple of
            ``block_size`` (einops raises otherwise).
        block_size: Number of tokens per attention block.

    Returns:
        The attention output in the same ``b t h d`` layout as ``q``.
    """
    batch_size = q.shape[0]
    q, k, v = (
        rearrange(x, 'b (nb bs) h d -> (b nb) bs h d', bs=block_size)
        for x in (q, k, v)
    )
    o_ba = flash_attn_func(q, k, v)
    return rearrange(o_ba, '(b nb) bs h d -> b (nb bs) h d', b=batch_size)


@torch.no_grad()
def attn_topk(q: torch.Tensor, k: torch.Tensor, block_count: int):
    """Select, per query position and KV head, the highest-scoring key positions.

    Scores are plain dot products between ``q`` and ``k`` (layout
    ``b t h d``). When there are more query heads than key heads, the scores
    of the query heads that share a key head are averaged so that one index
    set is produced per key head.

    Args:
        q: Query tensor, layout ``b t h_q d``.
        k: Key tensor, layout ``b t h_kv d``; ``h_q`` must be a multiple of
            ``h_kv``.
        block_count: Number of indices to keep along the key axis.

    Returns:
        Integer tensor of top-k key indices, layout ``b t h_kv block_count``.

    Raises:
        ValueError: If ``h_q`` is not a multiple of ``h_kv``.
    """
    num_q_heads, num_kv_heads = q.shape[2], k.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"Number of query heads ({num_q_heads}) must be a multiple of "
            f"the number of key heads ({num_kv_heads})."
        )
    group_size = num_q_heads // num_kv_heads

    # repeat_interleave lays the expanded heads out as (h g): query heads
    # h * G .. h * G + G - 1 are all scored against key head h.
    k = k.repeat_interleave(group_size, dim=2)
    scores = torch.matmul(
        rearrange(q, 'b t h d -> b h t d'),
        rearrange(k, 'b t h d -> b h d t'),
    )
    if num_q_heads != num_kv_heads:
        # The grouping pattern must match the (h g) layout produced above;
        # '(g h)' would average query heads that belong to different key heads.
        scores = reduce(scores, 'b (h g) t k -> b h t k', 'mean', g=group_size)
    scores = rearrange(scores, 'b h t k -> b t h k')
    return scores.topk(k=block_count, dim=-1, largest=True)[1]


def _blend_two(o_a, w_a, o_b, w_b):
    """Blend two branch outputs with their weights renormalised to sum to ~1."""
    # NOTE: the 1e-8 guard is below float16 resolution, so it only protects
    # against a zero denominator in wider dtypes.
    w_sum = w_a + w_b + 1e-8
    return o_a * (w_a / w_sum) + o_b * (w_b / w_sum)


def mosaic_attn_func(
    q,
    k,
    v,
    weight_ba_cmp_slc,
    block_attn_size,
    sparse_block_size,
    sparse_block_count,
    block_attn_only,
    no_compression=False,
):
    """Combine block, compressed and selected attention branches.

    Branches (inputs use layout ``b t h d``):
      * block: ``block_attention`` over blocks of ``block_attn_size`` tokens;
      * compressed: attention between mean-pooled blocks of
        ``sparse_block_size`` tokens, repeated back to token resolution;
      * selected: ``mosaic_sparse_attn`` over the blocks chosen by
        ``attn_topk`` on the mean-pooled queries/keys.

    Args:
        q, k, v: Query/key/value tensors.
        weight_ba_cmp_slc: Indexable with ``[0]``, ``[1]``, ``[2]`` giving the
            weights of the block, compressed and selected branches. Unused
            when ``block_attn_only`` is true.
        block_attn_size: Block size of the block branch.
        sparse_block_size: Pooling / selection block size.
        sparse_block_count: Number of blocks selected per query block; ``0``
            disables the selected branch (when compression is enabled).
        block_attn_only: Return only the block branch.
        no_compression: Skip the compressed branch and blend block + selected.

    Returns:
        The blended attention output, layout ``b t h d``. When only two
        branches are active their weights are renormalised; with all three
        branches the weights are applied as given.
    """
    o_ba = block_attention(q, k, v, block_attn_size)
    if block_attn_only:
        return o_ba

    w_ba, w_cmp, w_slc = weight_ba_cmp_slc[0], weight_ba_cmp_slc[1], weight_ba_cmp_slc[2]

    # Mean-pool queries and keys per sparse block; used both for the
    # compressed branch and for scoring blocks in the selected branch.
    q_cmp = reduce(q, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
    k_cmp = reduce(k, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)

    if no_compression:
        block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
        o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)
        return _blend_two(o_ba, w_ba, o_slc, w_slc)

    v_cmp = reduce(v, 'b (nb bs) h d -> b nb h d', 'mean', bs=sparse_block_size)
    o_cmp = flash_attn_func(q_cmp, k_cmp, v_cmp)
    # Broadcast each pooled output back to every token of its block.
    o_cmp = o_cmp.repeat_interleave(sparse_block_size, dim=1)

    if sparse_block_count == 0:
        return _blend_two(o_ba, w_ba, o_cmp, w_cmp)

    block_indices = attn_topk(q_cmp, k_cmp, sparse_block_count)
    o_slc = mosaic_sparse_attn(q, k, v, block_indices, sparse_block_size)
    return o_ba * w_ba + o_cmp * w_cmp + o_slc * w_slc


class cSwiGLU(nn.Module):
    """SwiGLU feed-forward layer with an optional noise-conditioned gate bias.

    Args:
        dim: Input/output feature size.
        hidden_dim: Size of the gated hidden layer.
        noise_dim: Size of the conditioning vector ``z``; when ``<= 0`` no
            conditioning projection is created.
    """

    def __init__(self, dim: int, hidden_dim: int, noise_dim: int):
        super().__init__()
        # Fused projection producing both the gate (x1) and value (x3) halves.
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.act_fn = nn.SiLU()
        # Always define the attribute so forward() can test for it instead of
        # failing with an AttributeError when conditioning is disabled.
        self.noise_bias = (
            nn.Linear(noise_dim, hidden_dim, bias=False) if noise_dim > 0 else None
        )

    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
        """Apply the FFN; ``z`` (if given) is projected and added to the gate.

        The projected noise gets a new leading axis (``unsqueeze(0)``) so it
        broadcasts over the first axis of ``x``.

        Raises:
            ValueError: If ``z`` is given but the layer was built with
                ``noise_dim <= 0``.
        """
        if z is None:
            noise = 0
        elif self.noise_bias is None:
            raise ValueError(
                "cSwiGLU received a noise input but was built with noise_dim <= 0."
            )
        else:
            noise = self.noise_bias(z).unsqueeze(0)
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(self.act_fn(x1 + noise) * x3)


class RoPE(nn.Module):
    """Rotary positional embedding driven by 2D (lon, lat) positions.

    Half of the rotation frequencies are driven by longitude and half by
    latitude. ``initialize_rope`` must be called before ``forward``.

    Args:
        dim: Feature size of the tensors to rotate; must be a multiple of 4
            (two coordinates, each rotating feature pairs).
        theta: Base of the geometric frequency schedule.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        # initialize_rope builds ceil(dim / 4) frequencies per coordinate and
        # forward expands them 4x, so only multiples of 4 reproduce ``dim``.
        if dim % 4 != 0:
            raise ValueError(f"RoPE dim must be a multiple of 4, got {dim}.")
        self.dim = dim
        self.theta = theta
        # Placeholders so forward() can detect a missing initialisation.
        # None buffers are not part of the state dict.
        self.register_buffer('cos_freqs', None, persistent=True)
        self.register_buffer('sin_freqs', None, persistent=True)

    def initialize_rope(self, positions):
        """Precompute cos/sin tables from per-token positions.

        Args:
            positions: Tensor whose column 0 is longitude and column 1 is
                latitude, both in degrees (they are passed through
                ``torch.deg2rad``).
        """
        half = self.dim // 2
        # Build the frequencies on the positions' device so matmul does not
        # fail when positions live on an accelerator.
        exponents = torch.arange(0, half, 2, device=positions.device).float() / half
        base_freqs = 1. / (self.theta ** exponents)

        lon_pos = torch.deg2rad(positions[:, 0:1])
        lat_pos = torch.deg2rad(positions[:, 1:2])

        lon_freqs = torch.matmul(lon_pos, base_freqs.unsqueeze(0))
        lat_freqs = torch.matmul(lat_pos, base_freqs.unsqueeze(0))
        freqs = torch.cat([lon_freqs, lat_freqs], dim=-1)

        self.register_buffer('cos_freqs', freqs.cos().contiguous(), persistent=True)
        self.register_buffer('sin_freqs', freqs.sin().contiguous(), persistent=True)

    @staticmethod
    def rotate_half(x):
        """Map each adjacent feature pair ``(a, b)`` to ``(-b, a)``."""
        x = rearrange(x, '... (d r) -> ... d r', r=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return rearrange(x, '... d r -> ... (d r)')

    def forward(self, x):
        """Rotate ``x`` (layout ``b s h d``) using the precomputed tables.

        The computation is done in float32 and cast back to ``x.dtype``.

        Raises:
            ValueError: If ``initialize_rope`` has not been called.
        """
        if self.cos_freqs is None or self.sin_freqs is None:
            raise ValueError("RoPE not initialized; call initialize_rope first.")
        # (s, d/2) -> (1, s, 1, d): add batch/head axes and duplicate each
        # frequency for both members of a rotated pair.
        cos = self.cos_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
        sin = self.sin_freqs.unsqueeze(0).unsqueeze(2).repeat_interleave(2, dim=-1)
        x_f32 = x.float()
        return (x_f32 * cos + self.rotate_half(x_f32) * sin).to(x.dtype)


class MosaicAttention(nn.Module):
    """Self-attention layer built on ``mosaic_attn_func``.

    Args:
        config: Object providing ``block_attn_size``, ``sparse_block_size``,
            ``sparse_block_count``, ``num_heads``, ``gqa_ratio``, ``dim``,
            ``qkv_compress_ratio``, ``rope`` and ``rope_theta``.
        block_attn_only: Use only the block attention branch (no strategy
            weights are learned).
        no_compression: Skip the compressed branch.
    """

    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
        super().__init__()
        self.block_attn_only = block_attn_only
        self.no_compression = no_compression
        self.block_attn_size = config.block_attn_size
        self.sparse_block_size = config.sparse_block_size
        self.sparse_block_count = config.sparse_block_count

        q_heads = config.num_heads
        gqa_ratio = config.gqa_ratio
        dim = config.dim
        qkv_compress_ratio = config.qkv_compress_ratio
        rope = config.rope
        rope_theta = config.rope_theta

        kv_heads = q_heads // gqa_ratio
        head_dim = int(dim // q_heads // qkv_compress_ratio)

        self.q_heads = q_heads
        self.kv_heads = kv_heads

        self.to_q = nn.Linear(dim, q_heads * head_dim, bias=False)
        self.to_k = nn.Linear(dim, kv_heads * head_dim, bias=False)
        self.to_v = nn.Linear(dim, kv_heads * head_dim, bias=False)
        self.to_o = nn.Linear(q_heads * head_dim, dim, bias=False)

        self.q_rope = RoPE(head_dim, rope_theta) if rope else None
        self.k_rope = RoPE(head_dim, rope_theta) if rope else None

        if block_attn_only:
            self.to_strategy_combine_mlp = None
        else:
            # Three logits (one per branch) for every query head.
            self.to_strategy_combine_mlp = nn.Linear(dim, 3 * q_heads, bias=False)

    def generate_strategy_weights(self, x):
        """Compute per-token, per-head softmax weights over the three branches.

        Args:
            x: Input in layout ``t b c``.

        Returns:
            ``[None, None, None]`` when ``block_attn_only`` is set; otherwise
            a tensor of layout ``s b t h 1`` where ``s`` indexes the branch
            and the softmax (computed in float32) runs over ``s``.
        """
        if self.block_attn_only:
            return [None, None, None]
        strategy_logits = self.to_strategy_combine_mlp(x)
        strategy_logits = rearrange(strategy_logits, 't b (h s) -> s b t h', h=self.q_heads)
        strategy_weights = torch.softmax(strategy_logits.float(), dim=0).type_as(x)
        # Trailing axis lets the weights broadcast over the head dimension.
        return strategy_weights.unsqueeze(-1)

    def forward(self, x):
        """Apply attention to ``x`` (layout ``s b c``) and return the same layout."""
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        strategy_weights = self.generate_strategy_weights(x)

        q = rearrange(q, 's b (h d) -> b s h d', h=self.q_heads)
        k = rearrange(k, 's b (h d) -> b s h d', h=self.kv_heads)
        v = rearrange(v, 's b (h d) -> b s h d', h=self.kv_heads)

        if self.q_rope is not None:
            q = self.q_rope(q)
            k = self.k_rope(k)

        output = mosaic_attn_func(
            q=q,
            k=k,
            v=v,
            weight_ba_cmp_slc=strategy_weights,
            block_attn_size=self.block_attn_size,
            sparse_block_size=self.sparse_block_size,
            sparse_block_count=self.sparse_block_count,
            block_attn_only=self.block_attn_only,
            no_compression=self.no_compression,
        )

        output = rearrange(output, 'b s h d -> s b (h d)')
        return self.to_o(output)


class MosaicBlock(nn.Module):
    """Pre-norm transformer block: MosaicAttention followed by a cSwiGLU FFN.

    Args:
        config: Object providing ``dim``, ``noise_dim``, ``mlp_ratio``,
            ``rmsnorm_elementwise_affine`` and the fields read by
            ``MosaicAttention``.
        block_attn_only: Forwarded to ``MosaicAttention``.
        no_compression: Forwarded to ``MosaicAttention``.
    """

    def __init__(self, config, block_attn_only: bool, no_compression: bool = False):
        super().__init__()
        dim = config.dim
        noise_dim = config.noise_dim
        mlp_ratio = config.mlp_ratio

        self.attention = MosaicAttention(config, block_attn_only, no_compression)
        self.norm1 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.norm2 = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.ffn = cSwiGLU(dim, int(dim * mlp_ratio), noise_dim)

    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
        """Apply residual attention and residual FFN; ``z`` conditions the FFN."""
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x), z)
        return x


class CrossAttentionInterpolate(nn.Module):
    """Interpolate features between grids by attending over nearest neighbours.

    Queries are computed from the normalised relative position (in xyz)
    between each target point and its ``k_neighbors`` source neighbours; keys
    and values come from the source features.
    ``initialize_interpolation_scheme`` must be called before ``forward``.

    Args:
        config: Object providing ``k_neighbors``, ``dim``, ``num_heads`` and
            ``rmsnorm_elementwise_affine``.
    """

    space_dim: int = 3

    def __init__(self, config):
        super().__init__()
        self.k_neighbors = config.k_neighbors
        dim = config.dim
        num_heads = config.num_heads
        head_dim = dim // num_heads

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.kv_norm = RMSNorm(dim, elementwise_affine=config.rmsnorm_elementwise_affine)
        self.to_q = nn.Linear(self.space_dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim, bias=False)
        self.to_o = nn.Linear(dim, dim, bias=False)

        # Placeholders so the "not initialized" check in forward() is
        # reachable; without them attribute access raises AttributeError.
        # None buffers are not part of the state dict.
        self.register_buffer('neighbors', None, persistent=True)
        self.register_buffer('rel_pos', None, persistent=True)

    @torch.no_grad()
    def initialize_interpolation_scheme(self, pos_from_rad, pos_to_rad):
        """Precompute neighbour indices and relative positions.

        Args:
            pos_from_rad: Source grid positions (passed to ``get_neighbors``
                and ``rad_to_xyz``; presumably lon/lat in radians — confirm
                against ``utils``).
            pos_to_rad: Target grid positions, same convention.
        """
        neighbors_np = get_neighbors(
            pos_from_rad.cpu().numpy(), pos_to_rad.cpu().numpy(), k=self.k_neighbors
        )
        neighbors = torch.from_numpy(neighbors_np).long().to(pos_from_rad.device).contiguous()

        pos_to_xyz = rad_to_xyz(pos_to_rad)
        pos_from_xyz = rad_to_xyz(pos_from_rad)
        # Offset from every selected source neighbour to its target point.
        rel_pos_xyz = (pos_to_xyz.unsqueeze(1) - pos_from_xyz[neighbors]).contiguous()
        norm_rel_pos_xyz = torch.nn.functional.normalize(rel_pos_xyz, dim=-1).contiguous()

        self.register_buffer('neighbors', neighbors, persistent=True)
        self.register_buffer('rel_pos', norm_rel_pos_xyz, persistent=True)

    def forward(self, x_from: torch.Tensor):
        """Interpolate ``x_from`` (layout ``s b c``) onto the target grid.

        Returns:
            Tensor in layout ``s_to b c``.

        Raises:
            ValueError: If the interpolation scheme was not initialized.
        """
        if self.neighbors is None or self.rel_pos is None:
            raise ValueError("Interpolation scheme not initialized.")

        q = self.to_q(self.rel_pos)
        # Singleton axis lets the query broadcast over the batch axis of k/v.
        q = rearrange(q, 's k (h d) -> s k 1 h d', h=self.num_heads)

        x = self.kv_norm(x_from)
        kv = self.to_kv(x)
        kv = rearrange(kv, 's b (n h d) -> n s b h d', h=self.num_heads, n=2)
        # Gather the neighbours of every target point: each is (s_to, k, b, h, d).
        k, v = kv[:, self.neighbors]

        attn_scores = (q * k).sum(dim=-1, keepdim=True) * self.scale
        # Softmax over the neighbour axis, computed in float32.
        attn_weights = torch.softmax(attn_scores, dim=1, dtype=torch.float32).type_as(k)

        out = (attn_weights * v).sum(dim=1)
        out = rearrange(out, 's b h d -> s b (h d)')
        return self.to_o(out)


class NoiseGenerator(nn.Module):
    """Seeded Gaussian noise source followed by a learned linear map.

    Args:
        noise_dim: Size of each noise vector.
        seed: Seed applied once to the internal ``torch.Generator``.
    """

    def __init__(self, noise_dim: int, seed: int):
        super().__init__()
        self.seed = seed
        self.to_noise = nn.Linear(noise_dim, noise_dim, bias=False)
        self.generator = None

    def forward(self, num_samples: int, device: torch.device, dtype: torch.dtype):
        """Draw ``num_samples`` noise vectors and project them.

        The generator is created and seeded lazily on the first call, on the
        device given in that call, and then reused for later calls.
        """
        if self.generator is None:
            self.generator = torch.Generator(device=device)
            self.generator.manual_seed(self.seed)
        noise = torch.randn(
            (num_samples, self.to_noise.in_features),
            generator=self.generator,
            device=device,
            dtype=dtype,
        )
        return self.to_noise(noise)


def _healpix_child_offsets(nside_fine: int, nside_coarse: int, factor: int) -> torch.Tensor:
    """Return standardised xyz offsets of fine cells relative to their coarse cell.

    Fine cells are grouped ``factor`` at a time in index order (presumably
    relying on a nested HEALPix ordering — confirm against
    ``get_healpix_grid``). The flattened offsets are standardised per column
    across coarse cells.
    """
    fine_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_fine)))
    coarse_xyz = rad_to_xyz(torch.deg2rad(get_healpix_grid(nside_coarse)))
    children = rearrange(fine_xyz, '(n f) d -> n f d', f=factor)
    rel_pos = rearrange(children - coarse_xyz[:, None], 'n f d -> n (f d)')
    rel_pos = (rel_pos - rel_pos.mean(dim=0, keepdim=True)) / (rel_pos.std(dim=0, keepdim=True) + 1e-6)
    return rel_pos.contiguous()


class HEALPixDownsample(nn.Module):
    """Merge groups of fine HEALPix cells into one coarse cell.

    Args:
        in_dim: Feature size per fine cell.
        out_dim: Feature size per coarse cell.
        nside_before: ``nside`` of the fine (input) grid.
        nside_after: ``nside`` of the coarse (output) grid.
        rmsnorm_elementwise_affine: Forwarded to ``RMSNorm``.
    """

    space_dim: int = 3

    def __init__(self, in_dim, out_dim, nside_before, nside_after, rmsnorm_elementwise_affine=True):
        super().__init__()
        # Number of fine cells merged into each coarse cell.
        self.factor = (nside_before // nside_after) ** 2
        self.proj_x = nn.Linear(self.factor * in_dim, out_dim, bias=False)
        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim, bias=False)
        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)

        rel_pos = _healpix_child_offsets(nside_before, nside_after, self.factor)
        self.register_buffer('rel_pos', rel_pos, persistent=True)

    def forward(self, x: torch.Tensor):
        """Downsample ``x`` from layout ``(n f) b c`` to ``n b out_dim``."""
        x = rearrange(x, '(n f) b c -> n b (f c)', f=self.factor)
        # The positional term is shared across the batch axis.
        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
        return self.norm(x)


class HEALPixUpsample(nn.Module):
    """Expand each coarse HEALPix cell into its group of fine cells.

    Args:
        in_dim: Feature size per coarse cell.
        out_dim: Feature size per fine cell.
        nside_before: ``nside`` of the coarse (input) grid.
        nside_after: ``nside`` of the fine (output) grid.
        rmsnorm_elementwise_affine: Forwarded to ``RMSNorm``.
    """

    space_dim: int = 3

    def __init__(self, in_dim, out_dim, nside_before, nside_after, rmsnorm_elementwise_affine=True):
        super().__init__()
        # Number of fine cells produced from each coarse cell.
        self.factor = (nside_after // nside_before) ** 2
        self.proj_x = nn.Linear(in_dim, out_dim * self.factor, bias=False)
        self.proj_pos = nn.Linear(self.factor * self.space_dim, out_dim * self.factor, bias=False)
        self.norm = RMSNorm(out_dim, elementwise_affine=rmsnorm_elementwise_affine)

        rel_pos = _healpix_child_offsets(nside_after, nside_before, self.factor)
        self.register_buffer('rel_pos', rel_pos, persistent=True)

    def forward(self, x: torch.Tensor, shortcut: torch.Tensor):
        """Upsample ``x`` (layout ``n b c``) to ``(n f) b out_dim`` and add ``shortcut``.

        ``shortcut`` is added to the upsampled features before normalisation,
        so it must broadcast against the ``(n f) b out_dim`` result.
        """
        # The positional term is shared across the batch axis.
        x = self.proj_x(x) + self.proj_pos(self.rel_pos).unsqueeze(1)
        x = rearrange(x, 'n b (f d) -> (n f) b d', f=self.factor)
        x = x + shortcut
        return self.norm(x)