| """MegaLoc: One Retrieval to Place Them All |
| |
| This module implements the MegaLoc model for visual place recognition. |
| The model combines a Vision Transformer backbone with an optimal transport-based |
| feature aggregation module. |
| |
| Paper: https://arxiv.org/abs/2502.17237 |
| License: MIT |
| """ |
|
|
| import math |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms.functional as tfm |
|
|
|
|
| |
| |
def log_otp_solver(
    log_a: torch.Tensor, log_b: torch.Tensor, M: torch.Tensor, num_iters: int = 20, reg: float = 1.0
) -> torch.Tensor:
    r"""Log-domain Sinkhorn matrix scaling for differentiable optimal transport.

    Alternately rescales the rows and columns of ``exp(M / reg)`` so that its row
    sums approach ``exp(log_a)`` and its column sums equal ``exp(log_b)``. It then
    returns the log of the resulting transport plan.

    Args:
        log_a: Log source (row) weights, shape ``[B, m]``.
        log_b: Log target (column) weights, shape ``[B, n]``.
        M: Score matrix (a negated cost: larger values attract more mass),
            shape ``[B, m, n]``. It is divided by ``reg`` before scaling.
        num_iters: Number of Sinkhorn iterations (default: 20).
        reg: Entropic regularization value (default: 1.0).

    Returns:
        Log transport plan of shape ``[B, m, n]``.
    """
    M = M / reg

    # Log-domain scaling potentials for the rows (u) and the columns (v).
    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # logsumexp over one axis already yields [B, m] / [B, n]. Squeezing it
        # would drop any size-1 axis (e.g. n == 1 or B == 1). u / v would then
        # broadcast to the wrong shape against log_a / log_b.
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
|
|
|
|
| |
| |
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Soft-assign features to clusters with a Sinkhorn solve over an extra dustbin row.

    A dustbin row filled with ``dustbin_score`` is appended to the score matrix, so
    features can stay unassigned. The masses are split as follows:

    - each of the ``m`` cluster rows gets mass ``1 / (n + m)``;
    - the dustbin row gets ``(n - m) / (n + m)``;
    - each of the ``n`` feature columns gets ``1 / (n + m)``.

    Rows and columns are therefore balanced. The returned log plan is shifted by
    ``log(n + m)``, which multiplies the plan by ``n + m``.

    Args:
        S: Score tensor of shape ``[B, m, n]`` (m clusters, n features).
        dustbin_score: Value written into every entry of the dustbin row (a float
            or a scalar tensor such as a learnable parameter).
        num_iters: Number of Sinkhorn iterations.
        reg: Entropic regularization forwarded to ``log_otp_solver``.

    Returns:
        Log assignment tensor of shape ``[B, m + 1, n]``; the last row is the dustbin.

    Raises:
        ValueError: If ``n <= m``, since the dustbin mass ``n - m`` must be positive.
    """
    batch_size, m, n = S.size()
    if n <= m:
        raise ValueError(
            f"get_matching_probs requires more features than clusters (got n={n}, m={m}); "
            "the dustbin mass n - m must be positive."
        )

    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    norm = -torch.tensor(math.log(n + m), device=S.device)
    # contiguous() materializes real copies, so the in-place dustbin update
    # below does not write through an expanded (shared-storage) view.
    log_a, log_b = norm.expand(m + 1).contiguous(), norm.expand(n).contiguous()
    log_a[-1] = log_a[-1] + math.log(n - m)
    log_a, log_b = log_a.expand(batch_size, -1), log_b.expand(batch_size, -1)
    log_P = log_otp_solver(log_a, log_b, S_aug, num_iters=num_iters, reg=reg)
    return log_P - norm
|
|
|
|
class FeatureAggregator(nn.Module):
    """Optimal transport-based aggregation of local features into a global descriptor.

    The pipeline has four steps:

    1. Project local patch features to ``cluster_dim``-dimensional descriptors.
    2. Soft-assign them to ``num_clusters`` clusters, plus a learnable dustbin,
       via a Sinkhorn solve.
    3. Sum them per cluster.
    4. Concatenate the normalized per-cluster sums with a projection of the global
       token, then L2-normalize the result.

    Args:
        num_channels: Number of input feature channels (from backbone)
        num_clusters: Number of cluster centers
        cluster_dim: Dimensionality of cluster descriptors
        token_dim: Dimensionality of global scene token
        mlp_dim: Hidden dimension for MLPs
        dropout: Dropout probability (0 to disable)
    """

    def __init__(
        self,
        num_channels=1536,
        num_clusters=64,
        cluster_dim=128,
        token_dim=256,
        mlp_dim=512,
        dropout=0.3,
    ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
        self.mlp_dim = mlp_dim

        # One parameter-free dropout module, shared by both convolutional heads.
        dropout_layer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # MLP on the global token: [B, num_channels] -> [B, token_dim].
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, self.token_dim)
        )

        # 1x1-conv MLP producing per-location descriptors: [B, cluster_dim, H, W].
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout_layer,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.cluster_dim, 1),
        )

        # 1x1-conv MLP producing per-location cluster scores: [B, num_clusters, H, W].
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout_layer,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.num_clusters, 1),
        )

        # Learnable score of the dustbin row used by the Sinkhorn assignment.
        self.dust_bin = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """
        Args:
            x: Tuple of (features, token)
                features: [B, C, H, W] spatial feature map
                token: [B, C] global CLS token

        Returns:
            Global descriptor [B, num_clusters * cluster_dim + token_dim]
        """
        x, t = x

        f = self.cluster_features(x).flatten(2)  # [B, cluster_dim, H*W]
        p = self.score(x).flatten(2)  # [B, num_clusters, H*W]
        t = self.token_features(t)  # [B, token_dim]

        # Soft assignments; the dustbin row is dropped so its mass is discarded.
        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        p = p[:, :-1, :]

        # Per-cluster weighted sum of descriptors -> [B, cluster_dim, num_clusters].
        # Broadcasting unsqueezed views gives the same elementwise products as
        # repeat()-ing both operands, without two extra [B, D, K, H*W] copies.
        clusters = (f.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)

        f = torch.cat(
            [
                F.normalize(t, p=2, dim=-1),
                # Normalize each cluster vector, then flatten in
                # (cluster_dim, num_clusters) order.
                F.normalize(clusters, p=2, dim=1).flatten(1),
            ],
            dim=-1,
        )

        return F.normalize(f, p=2, dim=-1)
|
|
|
|
| |
| |
| |
|
|
|
|
class PatchEmbedding(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    The kernel size and the stride both equal ``patch_size``, so every output
    location is a linear function of exactly one patch.
    """

    def __init__(self, image_size: int = 518, patch_size: int = 14, in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn images ``[B, C, H, W]`` into a token sequence ``[B, H'*W', embed_dim]``."""
        # [B, embed_dim, H', W'] -> [B, embed_dim, H'*W'] -> [B, H'*W', embed_dim]
        return self.proj(x).flatten(2).transpose(1, 2)
|
|
|
|
class LayerScale(nn.Module):
    """Multiply each channel by its own learnable factor (as in CaiT and DINOv2).

    Args:
        dim: Number of channels, i.e. the size of the broadcast scaling vector.
        init_value: Starting value for every scaling factor.
    """

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` channel-wise; ``gamma`` broadcasts over the last dimension."""
        return x.mul(self.gamma)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: Token embedding size; each head works on ``dim // num_heads`` channels.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self, dim: int, num_heads: int = 12, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over tokens ``[B, N, C]``; the output has the same shape."""
        batch, tokens, channels = x.shape

        # Fused projection -> [3, B, heads, N, head_dim], then split into q, k, v.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))

        # Merge the heads back into the channel dimension before projecting.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when they
    are falsy (``None`` or 0).
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, drop: float = 0.0):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to the last dimension of ``x``."""
        # The single dropout module is applied twice: after the activation and after fc2.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm Vision Transformer block with LayerScale on both residual branches.

    The forward pass computes ``x + ls1(attn(norm1(x)))`` and then
    ``x + ls2(mlp(norm2(x)))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: float = 1e-5,
    ):
        super().__init__()
        # Attention branch.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_value=init_values)

        # Feed-forward branch.
        hidden_width = int(dim * mlp_ratio)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_width, drop=drop)
        self.ls2 = LayerScale(dim, init_value=init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP residual updates to tokens ``x``."""
        attn_update = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_update
        mlp_update = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_update
|
|
|
|
class DINOv2(nn.Module):
    """DINOv2 Vision Transformer backbone for feature extraction.

    The defaults describe a ViT-B/14 (embed_dim 768, depth 12, 12 heads), intended
    to be compatible with DINOv2 weights.
    """

    def __init__(
        self,
        image_size: int = 518,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Channel count exposed to downstream heads.
        self.num_channels = embed_dim

        self.patch_embed = PatchEmbedding(
            image_size=image_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
        )

        # Settings used when resampling the positional grid to another input size.
        self.interpolate_offset = 0.1
        self.interpolate_antialias = False

        grid_tokens = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional embedding for the class token plus one per grid patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, grid_tokens + 1, embed_dim))

        self.blocks = nn.ModuleList(
            TransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
        """Return positional embeddings matching the token sequence ``x``.

        ``w`` sets the first spatial axis of the patch grid and ``h`` the second.
        ``forward`` passes the image height as ``w`` and the width as ``h``.
        The stored table is returned unchanged when the patch count already matches
        and ``w == h``. Otherwise the square grid is resampled bicubically in
        float32, and the result is cast back to ``x``'s dtype.
        """
        out_dtype = x.dtype
        stored_patches = self.pos_embed.shape[1] - 1

        if x.shape[1] - 1 == stored_patches and w == h:
            return self.pos_embed

        table = self.pos_embed.float()
        cls_pos = table[:, 0]
        grid_pos = table[:, 1:]

        channels = x.shape[-1]
        rows = w // self.patch_size
        cols = h // self.patch_size
        side = int(math.sqrt(stored_patches))

        # The small offset keeps floor(side * scale) inside F.interpolate from
        # landing one short of the target size; the assert below verifies it.
        scale_rows = float(rows + self.interpolate_offset) / side
        scale_cols = float(cols + self.interpolate_offset) / side

        grid = grid_pos.reshape(1, side, side, channels).permute(0, 3, 1, 2)
        grid = F.interpolate(
            grid,
            scale_factor=(scale_rows, scale_cols),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert (rows, cols) == grid.shape[-2:]
        grid = grid.permute(0, 2, 3, 1).view(1, -1, channels)

        return torch.cat((cls_pos.unsqueeze(0), grid), dim=1).to(out_dtype)

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from images.

        Args:
            images: Input images [B, 3, H, W] where H, W are multiples of 14

        Returns:
            Tuple of (patch_features [B, 768, H//14, W//14], cls_token [B, 768])
        """
        batch, _, height, width = images.shape

        tokens = self.patch_embed(images)
        tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
        tokens = tokens + self.interpolate_pos_encoding(tokens, height, width)

        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        cls_token = tokens[:, 0]
        patch_tokens = tokens[:, 1:]
        grid_rows = height // self.patch_size
        grid_cols = width // self.patch_size
        # Token sequence -> channel-first feature map [B, embed_dim, rows, cols].
        patch_features = patch_tokens.reshape(batch, grid_rows, grid_cols, self.embed_dim).permute(0, 3, 1, 2)

        return patch_features, cls_token
|
|
|
|
| |
| |
| |
|
|
|
|
class L2Norm(nn.Module):
    """Rescale inputs to unit L2 norm along one dimension.

    Args:
        dim: Dimension along which the norm is taken (default: 1).
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` divided by its L2 norm along ``self.dim`` via ``F.normalize``."""
        return F.normalize(x, dim=self.dim, p=2.0)
|
|
|
|
class Aggregator(nn.Module):
    """Feature aggregation followed by a linear projection to the output size.

    Args:
        feat_dim: Output size of the final linear layer.
        agg_config: Keyword arguments forwarded to ``FeatureAggregator``.
        salad_out_dim: Input size of the linear layer; expected to equal the
            aggregator's output size.
    """

    def __init__(self, feat_dim, agg_config, salad_out_dim):
        super().__init__()
        self.agg = FeatureAggregator(**agg_config)
        self.linear = nn.Linear(salad_out_dim, feat_dim)

    def forward(self, x):
        """Aggregate ``x`` with ``self.agg`` and project the result with ``self.linear``."""
        return self.linear(self.agg(x))
|
|
|
|
class MegaLoc(nn.Module):
    """MegaLoc: Unified visual place recognition model.

    Combines a DINOv2 Vision Transformer backbone with optimal transport-based
    feature aggregation to produce compact, discriminative image descriptors
    for place recognition and image retrieval tasks.

    Args:
        feat_dim: Output descriptor dimensionality (default: 8448)
        num_clusters: Number of cluster centers for aggregation (default: 64)
        cluster_dim: Dimensionality of cluster descriptors (default: 256)
        token_dim: Dimensionality of global scene token (default: 256)
        mlp_dim: Hidden dimension for MLPs (default: 512)

    Example:
        >>> model = torch.hub.load("gmberton/MegaLoc", "get_trained_model")
        >>> model.eval()
        >>> descriptor = model(image)  # [B, 8448]
    """

    def __init__(
        self,
        feat_dim: int = 8448,
        num_clusters: int = 64,
        cluster_dim: int = 256,
        token_dim: int = 256,
        mlp_dim: int = 512,
    ):
        super().__init__()

        self.backbone = DINOv2()
        # Size of the aggregator output: concatenated cluster descriptors plus the token.
        self.salad_out_dim = num_clusters * cluster_dim + token_dim
        self.aggregator = Aggregator(
            feat_dim=feat_dim,
            agg_config={
                "num_channels": self.backbone.num_channels,
                "num_clusters": num_clusters,
                "cluster_dim": cluster_dim,
                "token_dim": token_dim,
                "mlp_dim": mlp_dim,
            },
            salad_out_dim=self.salad_out_dim,
        )
        self.feat_dim = feat_dim
        self.l2norm = L2Norm()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Extract global descriptor from images.

        If the height or width is not a multiple of the backbone patch size, the
        images are first resized to the nearest multiple. Each side is kept at
        least one patch long.

        Args:
            images: Input images [B, 3, H, W]

        Returns:
            L2-normalized descriptors [B, feat_dim]
        """
        patch = self.backbone.patch_size
        _, _, h, w = images.shape
        if h % patch != 0 or w % patch != 0:
            # Round to the nearest multiple, but never below one patch. Otherwise a
            # side shorter than half a patch would round to 0 and the resize would fail.
            h = max(patch, round(h / patch) * patch)
            w = max(patch, round(w / patch) * patch)
            images = tfm.resize(images, [h, w], antialias=True)
        features = self.aggregator(self.backbone(images))
        return self.l2norm(features)
|
|