""" Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository. This module keeps the existing ViL-DLM contract: - `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]` - `VisionProjector` maps those visual tokens into the LM embedding space """ from __future__ import annotations import sys from pathlib import Path import os import ssl import certifi import torch import torch.nn as nn REPO_ROOT = Path(__file__).resolve().parents[1] VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm" if str(VISION_LSTM_ROOT) not in sys.path: sys.path.insert(0, str(VISION_LSTM_ROOT)) from vision_lstm import VisionLSTM, VisionLSTM2 # noqa: E402 VISION_BACKBONES = { "vil-small": { "ctor": VisionLSTM, "preprocess": "v1", "url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th", "kwargs": { "dim": 384, "depth": 24, "legacy_norm": True, "mode": None, "pooling": None, "output_shape": None, }, }, "vil2-small": { "ctor": VisionLSTM2, "preprocess": "v2", "url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th", "kwargs": { "dim": 384, "depth": 12, "legacy_norm": True, "mode": "features", "pooling": None, "output_shape": None, "conv_kind": "2d", "conv_kernel_size": 3, "norm_bias": True, "proj_bias": True, }, }, } def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()} state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()} state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()} state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight") state_dict["norm.weight"] = state_dict.pop("head.0.weight") state_dict["norm.bias"] = state_dict.pop("head.0.bias") state_dict["head.weight"] = state_dict.pop("head.1.weight") state_dict["head.bias"] = state_dict.pop("head.1.bias") return state_dict def _preprocess_v2_state_dict( state_dict: dict[str, torch.Tensor], *, depth: int, legacy_norm: bool, ) -> dict[str, torch.Tensor]: state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()} state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()} state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()} state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()} for index in range(depth * 2): if index % 2 == 0: state_dict = { key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value for key, value in state_dict.items() } else: state_dict = { key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value for key, value in state_dict.items() } state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight") state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias") if legacy_norm: state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight") state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias") state_dict["head.weight"] = state_dict.pop("head.1.weight") state_dict["head.bias"] = state_dict.pop("head.1.bias") return state_dict def _load_pretrained_backbone(model: nn.Module, name: str, spec: dict) -> None: os.environ.setdefault("SSL_CERT_FILE", certifi.where()) ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where()) payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu") state_dict = payload["state_dict"] if spec["preprocess"] == "v1": state_dict = _preprocess_v1_state_dict(state_dict) elif spec["preprocess"] == "v2": state_dict = _preprocess_v2_state_dict( state_dict, depth=spec["kwargs"]["depth"], legacy_norm=spec["kwargs"]["legacy_norm"], ) else: raise ValueError(f"Unsupported checkpoint preprocessing mode: {spec['preprocess']}") if getattr(model, "head", None) is None: state_dict.pop("head.weight", None) state_dict.pop("head.bias", None) model.load_state_dict(state_dict) class VisionXLSTM(nn.Module): """ Thin adapter over upstream VisionLSTM / VisionLSTM2 models. The default backbone is `vil2-small`, which matches the requested 384-dim patch features while using the newer ViL v2 implementation. """ def __init__(self, config): super().__init__() backbone_name = getattr(config, "vision_backbone", "vil2-small") pretrained = getattr(config, "pretrained", True) img_size = getattr(config, "img_size", 224) patch_size = getattr(config, "patch_size", 16) in_channels = getattr(config, "in_channels", 3) if backbone_name not in VISION_BACKBONES: supported = ", ".join(sorted(VISION_BACKBONES)) raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}") spec = VISION_BACKBONES[backbone_name] ctor_kwargs = dict(spec["kwargs"]) ctor_kwargs["input_shape"] = (in_channels, img_size, img_size) ctor_kwargs["patch_size"] = patch_size self.config = config self.backbone_name = backbone_name self.model = spec["ctor"](**ctor_kwargs) self.dim = ctor_kwargs["dim"] self.num_patches = self.model.patch_embed.num_patches if pretrained: _load_pretrained_backbone(self.model, backbone_name, spec) def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.model(pixel_values) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.forward_features(pixel_values) class VisionProjector(nn.Module): """ MLP projector: maps ViL features -> LM embedding space. """ def __init__(self, config): super().__init__() hidden_dim = config.lm_dim * config.hidden_mult layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()] if config.dropout > 0: layers.append(nn.Dropout(config.dropout)) for _ in range(config.num_layers - 1): layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()]) if config.dropout > 0: layers.append(nn.Dropout(config.dropout)) layers.append(nn.Linear(hidden_dim, config.lm_dim)) self.mlp = nn.Sequential(*layers) def forward(self, vision_features: torch.Tensor) -> torch.Tensor: return self.mlp(vision_features)