| """ |
| Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository. |
| |
| This module keeps the existing ViL-DLM contract: |
| - `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]` |
| - `VisionProjector` maps those visual tokens into the LM embedding space |
| """ |
|
|
from __future__ import annotations

import sys
from pathlib import Path
import os
import ssl

import certifi
import torch
import torch.nn as nn

REPO_ROOT = Path(__file__).resolve().parents[1]
VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm"

# Make the vendored NX-AI vision-lstm checkout importable. The `vision_lstm` import
# below intentionally comes after this sys.path update.
if str(VISION_LSTM_ROOT) not in sys.path:
    sys.path.insert(0, str(VISION_LSTM_ROOT))

from vision_lstm import VisionLSTM, VisionLSTM2


# Registry of supported ViL backbones. Each entry provides the upstream constructor,
# the checkpoint-key preprocessing scheme ("v1"/"v2"), the pretrained checkpoint URL,
# and the constructor kwargs used by `VisionXLSTM`.
VISION_BACKBONES = {
    "vil-small": {
        "ctor": VisionLSTM,
        "preprocess": "v1",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 24,
            "legacy_norm": True,
            "mode": None,
            "pooling": None,
            "output_shape": None,
        },
    },
    "vil2-small": {
        "ctor": VisionLSTM2,
        "preprocess": "v2",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 12,
            "legacy_norm": True,
            "mode": "features",
            "pooling": None,
            "output_shape": None,
            "conv_kind": "2d",
            "conv_kernel_size": 3,
            "norm_bias": True,
            "proj_bias": True,
        },
    },
}


def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename keys of an upstream ViL v1 checkpoint to match the `VisionLSTM` module tree."""
    # Rename the xLSTM submodules inside each block, then strip any remaining `xlstm.` prefix.
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    # Remap the checkpoint's `post_blocks_norm` and two-stage head onto the model's
    # `legacy_norm`, `norm`, and `head` parameters.
    state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    state_dict["norm.weight"] = state_dict.pop("head.0.weight")
    state_dict["norm.bias"] = state_dict.pop("head.0.bias")
    state_dict["head.weight"] = state_dict.pop("head.1.weight")
    state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict
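
# Net effect of the v1 remapping above, shown on module paths only (parameter leaf
# names are unchanged; the concrete block sub-keys depend on the upstream checkpoint):
#   "*.xlstm.*"        -> "*.layer.*"   (plus any remaining "xlstm." prefix removed)
#   "*.xlstm_norm.*"   -> "*.norm.*"
#   "post_blocks_norm" -> "legacy_norm"
#   "head.0" / "head.1" -> "norm" / "head"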


def _preprocess_v2_state_dict(
    state_dict: dict[str, torch.Tensor],
    *,
    depth: int,
    legacy_norm: bool,
) -> dict[str, torch.Tensor]:
    """Rename keys of an upstream ViL v2 checkpoint to match the `VisionLSTM2` module tree."""
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()}
    # Fold the checkpoint's flat list of 2 * depth directional layers into `depth`
    # bidirectional blocks (even index -> top-left scan, odd index -> bottom-right scan).
    for index in range(depth * 2):
        if index % 2 == 0:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value
                for key, value in state_dict.items()
            }
        else:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value
                for key, value in state_dict.items()
            }
    state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias")
    if legacy_norm:
        state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight")
        state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias")
    state_dict["head.weight"] = state_dict.pop("head.1.weight")
    state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict
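
# For example, with depth=12 (vil2-small) the loop above folds checkpoint keys
# "blocks.0.*" / "blocks.1.*" into "blocks.0.rowwise_from_top_left.*" /
# "blocks.0.rowwise_from_bot_right.*", and so on up to "blocks.22.*" / "blocks.23.*"
# -> "blocks.11.*". The concrete sub-keys depend on the upstream checkpoint.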


def _load_pretrained_backbone(model: nn.Module, name: str, spec: dict) -> None:
    """Download the pretrained checkpoint described by `spec` and load it into `model`."""
    # torch.hub downloads over HTTPS; point the ssl module at certifi's CA bundle so the
    # download also works in environments without a usable system certificate store.
    # The wrapper accepts (and ignores) whatever arguments the stdlib passes.
    os.environ.setdefault("SSL_CERT_FILE", certifi.where())
    ssl._create_default_https_context = (
        lambda *args, **kwargs: ssl.create_default_context(cafile=certifi.where())
    )
    payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu")
    state_dict = payload["state_dict"]
    if spec["preprocess"] == "v1":
        state_dict = _preprocess_v1_state_dict(state_dict)
    elif spec["preprocess"] == "v2":
        state_dict = _preprocess_v2_state_dict(
            state_dict,
            depth=spec["kwargs"]["depth"],
            legacy_norm=spec["kwargs"]["legacy_norm"],
        )
    else:
        raise ValueError(
            f"Unsupported checkpoint preprocessing mode '{spec['preprocess']}' for backbone '{name}'"
        )
    # If the backbone was built without a classification head, drop the checkpoint's
    # head weights so the strict load succeeds.
    if getattr(model, "head", None) is None:
        state_dict.pop("head.weight", None)
        state_dict.pop("head.bias", None)
    model.load_state_dict(state_dict)


class VisionXLSTM(nn.Module):
    """
    Thin adapter over the upstream VisionLSTM / VisionLSTM2 models.

    The default backbone is `vil2-small`, which provides the requested 384-dim
    patch features while using the newer ViL v2 implementation.
    """

    def __init__(self, config):
        super().__init__()
        backbone_name = getattr(config, "vision_backbone", "vil2-small")
        pretrained = getattr(config, "pretrained", True)
        img_size = getattr(config, "img_size", 224)
        patch_size = getattr(config, "patch_size", 16)
        in_channels = getattr(config, "in_channels", 3)

        if backbone_name not in VISION_BACKBONES:
            supported = ", ".join(sorted(VISION_BACKBONES))
            raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}")

        spec = VISION_BACKBONES[backbone_name]
        ctor_kwargs = dict(spec["kwargs"])
        ctor_kwargs["input_shape"] = (in_channels, img_size, img_size)
        ctor_kwargs["patch_size"] = patch_size

        self.config = config
        self.backbone_name = backbone_name
        self.model = spec["ctor"](**ctor_kwargs)
        self.dim = ctor_kwargs["dim"]
        self.num_patches = self.model.patch_embed.num_patches

        if pretrained:
            _load_pretrained_backbone(self.model, backbone_name, spec)

    def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return patch tokens of shape `[B, N, D]` for a batch of images."""
        return self.model(pixel_values)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.forward_features(pixel_values)


class VisionProjector(nn.Module):
    """
    MLP projector: maps ViL features -> LM embedding space.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.lm_dim * config.hidden_mult

        layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
        if config.dropout > 0:
            layers.append(nn.Dropout(config.dropout))

        for _ in range(config.num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
            if config.dropout > 0:
                layers.append(nn.Dropout(config.dropout))

        layers.append(nn.Linear(hidden_dim, config.lm_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        return self.mlp(vision_features)
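

if __name__ == "__main__":
    # Minimal smoke test of the backbone/projector contract (a sketch, not part of the
    # ViL-DLM training entry points). The config values below are illustrative
    # assumptions: any attribute-style object works, `pretrained=False` skips the
    # checkpoint download, and `lm_dim=768` is an arbitrary stand-in for the language
    # model's hidden size.
    from types import SimpleNamespace

    vision_cfg = SimpleNamespace(vision_backbone="vil2-small", pretrained=False)
    projector_cfg = SimpleNamespace(vil_dim=384, lm_dim=768, hidden_mult=4, num_layers=2, dropout=0.0)

    backbone = VisionXLSTM(vision_cfg)
    projector = VisionProjector(projector_cfg)

    pixels = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        tokens = backbone.forward_features(pixels)  # [B, N, 384] per the module contract
        lm_tokens = projector(tokens)               # [B, N, lm_dim]
    print(tokens.shape, lm_tokens.shape)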