| import math |
| from functools import lru_cache |
| from unittest.mock import patch |
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| from torch import Tensor, nn |
| from torch.nn.attention.flex_attention import ( |
| BlockMask, |
| _mask_mod_signature, |
| create_block_mask, |
| flex_attention, |
| ) |
| from torch.nn.utils.parametrizations import _is_orthogonal, orthogonal |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils.backbone_utils import load_backbone |
|
|
| from .configuration import LSPDetrConfig, STAConfig |
|
|
|
|
# Compile flex_attention once at import time; dynamic=True avoids a
# recompilation every time the sequence lengths change between calls.
flex_attention = torch.compile(flex_attention, dynamic=True)
# Permanently patch torch's orthogonality check so that `orthogonal(...)`
# parametrizations can be built on the "meta" device (where no real values
# exist to check). On any other device the original check still runs.
# `.start()` without a matching `.stop()` keeps the patch active for the
# lifetime of the process.
patch(
    "torch.nn.utils.parametrizations._is_orthogonal",
    lambda Q, eps=None: Q.device == torch.device("meta") or _is_orthogonal(Q, eps=eps),
).start()
|
|
|
|
class CayleySTRING(nn.Module):
    """Cayley-STRING positional encoding.

    Based on "Learning the RoPEs: Better 2D and 3D Position Encodings with STRING"
    (https://arxiv.org/abs/2502.02562).

    Applies RoPE followed by multiplication with a learnable orthogonal matrix P
    parameterized by the Cayley transform.

    Args:
        dim (int): The feature dimension of the input tensor. Must be even.
        pos_dim (int): Dimensionality of the position vectors (1 for 1D, 2 for 2D).
        theta (float): Base value for the RoPE frequency calculation.
    """

    def __init__(self, dim: int, pos_dim: int = 2, theta: float = 100.0) -> None:
        super().__init__()
        # Classic RoPE geometric frequency schedule, one learnable row per
        # positional axis (so each axis can adapt its frequencies).
        base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.freqs = nn.Parameter(base_freqs.unsqueeze(0).expand(pos_dim, -1).clone())
        # Learnable orthogonal matrix via the Cayley transform.
        self.P = orthogonal(nn.Linear(dim, dim, bias=False), orthogonal_map="cayley")

    @torch.autocast("cuda", enabled=False)
    def forward(self, x: Tensor, positions: Tensor) -> Tensor:
        """Apply Cayley-STRING positional encoding.

        Args:
            x ([b, h, n, d]): Input tensor.
            positions ([b, n, pos_dim]): Positions tensor.
        """
        projected = self.P(x.float())

        # Per-token rotation angles: [b, n, d/2]; add a head axis to broadcast.
        angles = positions @ self.freqs
        rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(1)

        # Pair adjacent channels into complex numbers, rotate, flatten back.
        pairs = torch.view_as_complex(projected.reshape(*projected.shape[:-1], -1, 2))
        rotated = torch.view_as_real(pairs * rotation).flatten(-2)

        return rotated.type_as(x)
|
|
|
|
class MLP(nn.Sequential):
    """Very simple multi-layer perceptron.

    A stack of ``num_layers`` linear layers; every layer except the last is
    followed by an activation (and, optionally, dropout). The final linear
    projection has no activation.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of every hidden layer.
        output_dim: Size of the output features.
        num_layers: Total number of linear layers; must be greater than 1.
        act_layer: Activation module class inserted after each hidden layer.
        dropout: Dropout probability applied after each activation; 0 disables.

    Raises:
        ValueError: If ``num_layers`` is not greater than 1.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act_layer: type[nn.Module] = nn.GELU,
        dropout: float = 0.0,
    ) -> None:
        # An `assert` here would be silently stripped under `python -O`;
        # validate the argument explicitly instead.
        if num_layers <= 1:
            raise ValueError(f"num_layers must be > 1, got {num_layers}")

        layers: list[nn.Module] = []
        in_dim = input_dim
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(act_layer())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = hidden_dim

        # Final projection: no activation, no dropout.
        layers.append(nn.Linear(hidden_dim, output_dim))
        super().__init__(*layers)
|
|
|
|
class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward block.

    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
        """Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        """
        super().__init__()
        # Shrink to 2/3 (gated-FFN convention), then round UP to a multiple.
        shrunk = int(2 * hidden_dim / 3)
        hidden_dim = ((shrunk + multiple_of - 1) // multiple_of) * multiple_of

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # silu(w1(x)) gates the parallel projection w3(x).
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
| def generate_sta_mask( |
| q_canvas_w: int, |
| kv_canvas_hw: tuple[int, int], |
| kernel: int, |
| q_tile: int, |
| kv_tile: int, |
| ) -> _mask_mod_signature: |
| q_canvas_tile_w = q_canvas_w // q_tile |
| kv_canvas_tile_h = kv_canvas_hw[0] // kv_tile |
| kv_canvas_tile_w = kv_canvas_hw[1] // kv_tile |
|
|
| def q_tile_rescale(x: Tensor): |
| |
| scale_numerator = kv_canvas_tile_w - 1 |
| scale_denominator = q_canvas_tile_w - 1 |
| return (x * scale_numerator + scale_denominator // 2) // scale_denominator |
|
|
| def get_tile_xy( |
| idx: Tensor, tile_size: int, canvas_tile_w: int |
| ) -> tuple[Tensor, Tensor]: |
| tile_id = idx // (tile_size * tile_size) |
| tile_x = tile_id % canvas_tile_w |
| tile_y = tile_id // canvas_tile_w |
| return tile_x, tile_y |
|
|
| def sta_mask_2d(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: |
| q_x_tile, q_y_tile = get_tile_xy(q_idx, q_tile, q_canvas_tile_w) |
| kv_x_tile, kv_y_tile = get_tile_xy(kv_idx, kv_tile, kv_canvas_tile_w) |
|
|
| q_x_tile = q_tile_rescale(q_x_tile) |
| q_y_tile = q_tile_rescale(q_y_tile) |
|
|
| center_x = q_x_tile.clamp(kernel // 2, (kv_canvas_tile_w - 1) - kernel // 2) |
| center_y = q_y_tile.clamp(kernel // 2, (kv_canvas_tile_h - 1) - kernel // 2) |
|
|
| |
| x_mask = torch.abs(center_x - kv_x_tile) <= kernel // 2 |
| y_mask = torch.abs(center_y - kv_y_tile) <= kernel // 2 |
|
|
| return x_mask & y_mask |
|
|
| return sta_mask_2d |
|
|
|
|
@lru_cache
def create_sta_block_mask(
    q_len: int,
    kv_len: int,
    q_width: int,
    kv_width: int,
    kernel: int,
    q_tile: int,
    kv_tile: int,
) -> BlockMask:
    """Create (and memoize) the sliding-tile-attention BlockMask.

    The lru_cache matters: building a block mask is expensive, and the same
    geometry recurs on every forward pass with static shapes.
    """
    kv_height = kv_len // kv_width
    mask_mod = generate_sta_mask(
        q_width, (kv_height, kv_width), kernel, q_tile, kv_tile
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # B=None / H=None: the mask does not vary with batch or head indices.
    return create_block_mask(
        mask_mod,
        B=None,
        H=None,
        device=device,
        Q_LEN=q_len,
        KV_LEN=kv_len,
        _compile=True,
    )
|
|
|
|
@torch.autocast("cuda", enabled=False)
def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
    """Convert per-cell relative offsets into absolute 2D coordinates.

    Grid cell (i, j) has an anchor at (j * step_x, i * step_y); the sigmoid
    of the raw offset places the point within one step of its anchor.

    Args:
        pos ([b, h, w, 2]): Raw (unbounded) offsets, squashed via sigmoid.
        step_x: Horizontal anchor spacing.
        step_y: Vertical anchor spacing.

    Returns:
        [b, h, w, 2] tensor of absolute (x, y) positions.
    """
    offsets = pos.sigmoid()
    height, width = offsets.shape[1:3]

    anchors_x = torch.arange(width, dtype=torch.float32, device=pos.device) * step_x
    anchors_y = torch.arange(height, dtype=torch.float32, device=pos.device) * step_y

    abs_x = offsets[..., 0] * step_x + anchors_x            # broadcasts over rows
    abs_y = offsets[..., 1] * step_y + anchors_y[:, None]   # broadcasts over cols
    return torch.stack((abs_x, abs_y), dim=-1)
|
|
|
|
class STAttention(nn.Module):
    """Sliding-tile attention (STA) over 2D feature maps.

    Queries and key/values are flattened tile-by-tile so that
    ``flex_attention`` with the cached STA block mask restricts each query
    tile to a local window of key tiles. Positional information is injected
    into q and k with Cayley-STRING rotary encodings before tiling.

    Args:
        dim: Embedding dimension of the target (query) stream.
        src_dim: Embedding dimension of the source (key/value) stream.
        num_heads: Number of attention heads; assumed to divide ``dim``.
        kernel: Side length, in key tiles, of the attention window.
        q_tile: Query tile side length.
        kv_tile: Key/value tile side length.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        num_heads: int,
        kernel: int,
        q_tile: int,
        kv_tile: int,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.kernel = kernel
        self.q_tile = q_tile
        self.kv_tile = kv_tile

        # Positional encoding operates per head, hence dim // num_heads.
        self.pe = CayleySTRING(dim // num_heads)
        self.q = nn.Linear(dim, dim, bias=False)
        self.kv = nn.Linear(src_dim, dim * 2, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def maybe_pad(self, x: Tensor, tile: int) -> Tensor:
        """Zero-pad [b, h, w, c] on the bottom/right so h and w divide ``tile``."""
        h, w = x.shape[1:3]
        pad_right = (tile - w % tile) % tile
        pad_bottom = (tile - h % tile) % tile
        return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))

    def tile(self, x: Tensor, height: int, tile: int) -> tuple[Tensor, int, int]:
        """Regroup a flat [b, head, h*w, d] tensor into tile-contiguous order.

        Returns the tiled tensor plus the (padded) canvas height and width;
        the mask needs them to reconstruct tile coordinates.
        """
        x = rearrange(x, "b head (h w) dim -> b h w (head dim)", h=height)
        x = self.maybe_pad(x, tile)
        h, w = x.shape[1:3]
        # Tokens end up ordered tile-by-tile (row-major over tiles, then
        # row-major within each tile), matching generate_sta_mask's layout
        # where tile_id = token_idx // (tile * tile).
        x = rearrange(
            x,
            "b (n_h ts_h) (n_w ts_w) (h d) -> b h (n_h n_w ts_h ts_w) d",
            ts_h=tile,
            ts_w=tile,
            h=self.num_heads,
        )
        return x, h, w

    def forward(
        self, tgt: Tensor, src: Tensor, q_coords: Tensor, k_coords: Tensor
    ) -> Tensor:
        """Attend from ``tgt`` [b, h, w, dim] to ``src`` [b, H, W, src_dim].

        ``q_coords`` / ``k_coords`` are the flattened absolute positions of
        the target / source tokens, fed to the positional encoding. Returns
        a [b, h, w, dim] tensor.
        """
        h, w = tgt.shape[1:3]

        q = rearrange(
            self.q(tgt), "b h w (head d) -> b head (h w) d", head=self.num_heads
        )
        k, v = rearrange(
            self.kv(src),
            "b h w (two head d) -> two b head (h w) d",
            two=2,
            head=self.num_heads,
        )

        # Rotary-style encoding on q and k only; v stays unencoded.
        q = self.pe(q, q_coords)
        k = self.pe(k, k_coords)

        # Tile-contiguous layouts; padded canvas sizes drive mask geometry.
        q, q_h, q_w = self.tile(q, h, self.q_tile)
        k, _, kv_w = self.tile(k, src.shape[1], self.kv_tile)
        v, _, _ = self.tile(v, src.shape[1], self.kv_tile)

        # lru_cache'd, so identical geometries rebuild nothing.
        block_mask = create_sta_block_mask(
            q_len=q.shape[2],
            kv_len=k.shape[2],
            q_width=q_w,
            kv_width=kv_w,
            kernel=self.kernel,
            q_tile=self.q_tile,
            kv_tile=self.kv_tile,
        )
        x = flex_attention(q, k, v, block_mask=block_mask)

        # Undo the tiling back to a [b, h', w', dim] canvas.
        x = rearrange(
            x,
            "b h (n_h n_w ts_h ts_w) d -> b (n_h ts_h) (n_w ts_w) (h d)",
            n_h=q_h // self.q_tile,
            n_w=q_w // self.q_tile,
            ts_h=self.q_tile,
            ts_w=self.q_tile,
        )

        # Crop away the padding added by maybe_pad.
        x = x[:, :h, :w, :].contiguous()

        return self.wo(x)
|
|
|
|
class Layer(nn.Module):
    """One decoder layer: self-STA, cross-STA, then a feed-forward block.

    Post-norm residual wiring: each sublayer's output is added to its input
    and the sum is layer-normalized.
    """

    def __init__(
        self,
        dim: int,
        src_dim: int,
        num_heads: int,
        self_sta_config: STAConfig,
        cross_sta_config: STAConfig,
    ) -> None:
        super().__init__()

        def make_sta(source_dim: int, cfg: STAConfig) -> STAttention:
            # Local factory keeping the two attention blocks visibly parallel.
            return STAttention(
                dim,
                source_dim,
                num_heads,
                kernel=cfg["kernel"],
                q_tile=cfg["q_tile"],
                kv_tile=cfg["kv_tile"],
            )

        self.self_attention = make_sta(dim, self_sta_config)
        self.self_attention_norm = nn.LayerNorm(dim)

        self.cross_attention = make_sta(src_dim, cross_sta_config)
        self.cross_attention_norm = nn.LayerNorm(dim)

        self.ffn = FeedForward(dim, dim * 4)
        self.ffn_norm = nn.LayerNorm(dim)

    def forward(
        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
    ) -> Tensor:
        residual = self.self_attention(tgt, tgt, tgt_coords, tgt_coords)
        tgt = self.self_attention_norm(tgt + residual)

        residual = self.cross_attention(tgt, src, tgt_coords, src_coords)
        tgt = self.cross_attention_norm(tgt + residual)

        return self.ffn_norm(tgt + self.ffn(tgt))
|
|
|
|
class LSPTransformer(nn.Module):
    """Decoder head that iteratively refines query points and radial distances.

    Runs one decoder ``Layer`` per configured feature level. After each
    layer, per-level MLP heads predict residual updates to the reference
    points and radial distances (DETR-style iterative refinement), and a
    shared linear head predicts class logits. Intermediate predictions are
    returned as auxiliary outputs.
    """

    def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
        super().__init__()

        self.query_block_size = config.query_block_size
        self.num_radial_distances = config.num_radial_distances
        self.feature_levels = config.feature_levels
        # +1 — presumably a background / no-object class; confirm in config.
        self.num_classes = config.num_classes + 1

        self.layers = nn.ModuleList()
        for level in config.feature_levels:
            layer = Layer(
                dim=config.dim,
                src_dim=feature_channels[level],
                num_heads=config.num_heads,
                self_sta_config=config.self_sta_config,
                cross_sta_config=config.cross_sta_config[level],
            )
            self.layers.append(layer)

        # The classification head is shared across levels; the point and
        # radial-distance heads are one-per-level.
        self.class_head = nn.Linear(config.dim, self.num_classes)
        self.point_head = nn.ModuleList(
            MLP(config.dim, config.dim, 2, 3) for _ in config.feature_levels
        )
        self.radial_distances_head = nn.ModuleList(
            MLP(config.dim, config.dim, config.num_radial_distances, 3)
            for _ in config.feature_levels
        )

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize prediction heads.

        The class bias is set so sigmoid(bias) == prior_prob, keeping initial
        foreground scores low; the residual heads start at exactly zero so
        the first predictions equal the initial reference state.
        """
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.class_head.bias, bias_value)

        for head in self.point_head:
            nn.init.constant_(head[-1].weight, 0)
            nn.init.constant_(head[-1].bias, 0)

        for head in self.radial_distances_head:
            nn.init.constant_(head[-1].weight, 0)
            nn.init.constant_(head[-1].bias, 0)

    def forward(
        self,
        tgt: Tensor,
        ref_points: Tensor,
        features: list[Tensor],
        height: int,
        width: int,
    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
        """Refine ``ref_points`` over the decoder layers.

        Args:
            tgt: Query embeddings laid out on the query grid (first three
                dims are batch, grid-h, grid-w).
            ref_points: Raw (pre-sigmoid) per-query offsets on that grid.
            features: Backbone feature maps, indexed by ``feature_levels``.
            height / width: Input image size in pixels.
        """
        src = []
        src_coords = []
        for feature in features:
            b, _, h, w = feature.shape
            # zeros -> sigmoid gives 0.5, i.e. cell-center anchor coordinates
            # in (approximate) pixel units via the ceil'd strides.
            coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
            coords = relative_to_absolute_pos(
                coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
            )

            src.append(rearrange(feature, "b c h w -> b h w c"))
            src_coords.append(rearrange(coords, "b h w pos -> b (h w) pos"))

        # Initial radial distances: log1p of half a query block, one value
        # per radial direction (predictions are in log space).
        radial_distances = torch.full(
            (*tgt.shape[:3], self.num_radial_distances),
            math.log1p(self.query_block_size / 2),
            dtype=torch.float32,
            device=tgt.device,
        )

        logits_list: list[Tensor] = []
        ref_points_list: list[Tensor] = []
        radial_distances_list: list[Tensor] = []

        new_ref_points = ref_points.clone()
        new_radial_distances = radial_distances.clone()

        for i, layer in enumerate(self.layers):
            tgt = layer(
                tgt=tgt,
                src=src[self.feature_levels[i]],
                # Query coordinates in the same absolute (pixel) space as
                # src_coords above.
                tgt_coords=relative_to_absolute_pos(
                    ref_points, self.query_block_size, self.query_block_size
                ).flatten(1, 2),
                src_coords=src_coords[self.feature_levels[i]],
            )

            # Residual updates predicted from the refined embeddings.
            delta_point = self.point_head[i](tgt)
            delta_distances = self.radial_distances_head[i](tgt)
            logits = self.class_head(tgt)

            # Points reported here are normalized by image size
            # (step = block_size / width|height), unlike tgt_coords above.
            # NOTE(review): the appended value uses `new_ref_points` (the
            # un-detached state from the previous iteration) while the carried
            # state below detaches — this looks like a deliberate gradient
            # routing choice for the aux losses; confirm before changing.
            ref_points_list.append(
                relative_to_absolute_pos(
                    new_ref_points + delta_point,
                    step_x=self.query_block_size / width,
                    step_y=self.query_block_size / height,
                ).flatten(1, 2)
            )
            logits_list.append(logits.flatten(1, 2))
            radial_distances_list.append(
                torch.flatten(new_radial_distances + delta_distances, 1, 2)
            )

            # Carry the refined state forward; gradients are cut between
            # iterations via detach (DETR-style iterative refinement).
            new_ref_points = ref_points + delta_point
            new_radial_distances = radial_distances + delta_distances
            ref_points = new_ref_points.detach()
            radial_distances = new_radial_distances.detach()

        return {
            "logits": logits_list[-1],
            "points": ref_points_list[-1],
            "radial_distances": radial_distances_list[-1],
            # Final points in absolute (query-block-scaled) coordinates.
            "absolute_points": relative_to_absolute_pos(
                ref_points, self.query_block_size, self.query_block_size
            ).flatten(1, 2),
            "embeddings": tgt.flatten(1, 2),
            # Per-intermediate-layer predictions for auxiliary supervision.
            "aux_outputs": [
                {
                    "logits": a,
                    "points": b,
                    "radial_distances": c,
                }
                for a, b, c in zip(
                    logits_list[:-1],
                    ref_points_list[:-1],
                    radial_distances_list[:-1],
                    strict=True,
                )
            ],
        }
|
|
|
|
class FeatureSampling(nn.Module):
    """Sample a feature map at given point locations.

    A 1x1 convolution projects the features to ``out_dim`` channels,
    ``grid_sample`` reads them at the requested points, and the result is
    layer-normalized in channels-last layout.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.reduction = nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
        """Sample ``feature`` [b, c, H, W] at ``points`` [b, h, w, 2].

        ``points`` are remapped by ``* 2 - 1`` into grid_sample's [-1, 1]
        convention, i.e. they are expected in [0, 1].
        """
        grid = points * 2 - 1
        sampled = F.grid_sample(self.reduction(feature), grid, align_corners=False)
        # [b, c, h, w] -> [b, h, w, c] so LayerNorm normalizes over channels.
        return self.norm(sampled.permute(0, 2, 3, 1))
|
|
|
|
class LSPDetrModel(PreTrainedModel):
    """LSP-DETR model: backbone, query feature sampling, decoder head."""

    config_class = LSPDetrConfig

    def __init__(self, config: LSPDetrConfig) -> None:
        super().__init__(config)
        self.query_block_size = config.query_block_size

        self.backbone = load_backbone(config)
        # First entry of num_features is unused; the last ("neck") level
        # seeds the queries, the middle ones feed the decoder layers.
        channels = self.backbone.num_features
        feature_channels = list(channels[1:-1])
        neck_channels = channels[-1]

        self.feature_sampling = FeatureSampling(neck_channels, config.dim)
        self.decode_head = LSPTransformer(config, feature_channels)

    def forward(self, pixel_values: Tensor) -> dict[str, Tensor]:
        """Run detection on ``pixel_values`` [b, c, h, w]."""
        batch, _, height, width = pixel_values.shape

        feature_maps = self.backbone(pixel_values).feature_maps
        features, neck = list(feature_maps[:-1]), feature_maps[-1]

        # One query per query block, laid out on a coarse grid; zeros become
        # block-center offsets after the sigmoid in relative_to_absolute_pos.
        grid_h = math.ceil(height / self.query_block_size)
        grid_w = math.ceil(width / self.query_block_size)
        ref_points = torch.zeros(
            batch, grid_h, grid_w, 2, dtype=torch.float32, device=neck.device
        )

        # NOTE(review): these anchors are in query_block_size (pixel) units,
        # but FeatureSampling maps points via `* 2 - 1`, which expects
        # [0, 1]-normalized coordinates — verify whether the steps here
        # should be query_block_size / width and / height.
        anchors = relative_to_absolute_pos(
            ref_points, self.query_block_size, self.query_block_size
        )
        tgt = self.feature_sampling(anchors, neck)

        return self.decode_head(tgt, ref_points, features, height, width)
|
|