"""
Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository.
This module keeps the existing ViL-DLM contract:
- `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]`
- `VisionProjector` maps those visual tokens into the LM embedding space
"""
from __future__ import annotations
import sys
from pathlib import Path
import os
import ssl
import certifi
import torch
import torch.nn as nn
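# The upstream NX-AI vision-lstm repository lives under external/vision-lstm and is
# not installed as a package, so its root is added to sys.path before importing.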
REPO_ROOT = Path(__file__).resolve().parents[1]
VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm"
if str(VISION_LSTM_ROOT) not in sys.path:
sys.path.insert(0, str(VISION_LSTM_ROOT))
from vision_lstm import VisionLSTM, VisionLSTM2 # noqa: E402
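# Registry of supported backbones: upstream constructor, pretrained checkpoint URL,
# constructor kwargs, and which state-dict key-remapping scheme ("preprocess") to
# apply when loading the pretrained weights.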
VISION_BACKBONES = {
"vil-small": {
"ctor": VisionLSTM,
"preprocess": "v1",
"url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th",
"kwargs": {
"dim": 384,
"depth": 24,
"legacy_norm": True,
"mode": None,
"pooling": None,
"output_shape": None,
},
},
"vil2-small": {
"ctor": VisionLSTM2,
"preprocess": "v2",
"url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th",
"kwargs": {
"dim": 384,
"depth": 12,
"legacy_norm": True,
"mode": "features",
"pooling": None,
"output_shape": None,
"conv_kind": "2d",
"conv_kernel_size": 3,
"norm_bias": True,
"proj_bias": True,
},
},
}
def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
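    """Remap a pretrained VisionLSTM (v1) checkpoint to this module's key layout.

    Renames the upstream xLSTM block/norm keys, moves `post_blocks_norm` to
    `legacy_norm`, and splits the two-layer head (`head.0` norm, `head.1`
    linear) into separate `norm` and `head` entries.
    """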
state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight")
state_dict["norm.weight"] = state_dict.pop("head.0.weight")
state_dict["norm.bias"] = state_dict.pop("head.0.bias")
state_dict["head.weight"] = state_dict.pop("head.1.weight")
state_dict["head.bias"] = state_dict.pop("head.1.bias")
return state_dict
def _preprocess_v2_state_dict(
state_dict: dict[str, torch.Tensor],
*,
depth: int,
legacy_norm: bool,
) -> dict[str, torch.Tensor]:
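    """Remap a pretrained VisionLSTM2 (v2) checkpoint to this module's key layout.

    Renames the upstream xLSTM block/norm/conv keys, folds the 2 * depth
    sequential upstream blocks into `depth` bidirectional blocks, moves
    `post_blocks_norm` to `norm`, and (when `legacy_norm` is set) splits the
    two-layer head into `legacy_norm` and `head` entries.
    """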
state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()}
for index in range(depth * 2):
if index % 2 == 0:
state_dict = {
key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value
for key, value in state_dict.items()
}
else:
state_dict = {
key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value
for key, value in state_dict.items()
}
state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight")
state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias")
if legacy_norm:
state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight")
state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias")
state_dict["head.weight"] = state_dict.pop("head.1.weight")
state_dict["head.bias"] = state_dict.pop("head.1.bias")
return state_dict
def _load_pretrained_backbone(model: nn.Module, name: str, spec: dict) -> None:
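    """Download the checkpoint from `spec["url"]`, remap its keys, and load it into `model`.

    Keys are remapped according to `spec["preprocess"]`; the classification head
    weights are dropped when `model` was built without a head (e.g. in
    feature-extraction mode).
    """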
    # Some environments lack a usable system CA bundle, which makes the checkpoint
    # download fail with SSL errors; point the stdlib at certifi's bundle instead.
    # The factory accepts any arguments so the override works across stdlib versions.
    os.environ.setdefault("SSL_CERT_FILE", certifi.where())
    ssl._create_default_https_context = lambda *args, **kwargs: ssl.create_default_context(
        cafile=certifi.where()
    )
payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu")
state_dict = payload["state_dict"]
if spec["preprocess"] == "v1":
state_dict = _preprocess_v1_state_dict(state_dict)
elif spec["preprocess"] == "v2":
state_dict = _preprocess_v2_state_dict(
state_dict,
depth=spec["kwargs"]["depth"],
legacy_norm=spec["kwargs"]["legacy_norm"],
)
    else:
        raise ValueError(
            f"Unsupported checkpoint preprocessing mode '{spec['preprocess']}' for backbone '{name}'"
        )
if getattr(model, "head", None) is None:
state_dict.pop("head.weight", None)
state_dict.pop("head.bias", None)
model.load_state_dict(state_dict)
class VisionXLSTM(nn.Module):
"""
Thin adapter over upstream VisionLSTM / VisionLSTM2 models.
The default backbone is `vil2-small`, which matches the requested 384-dim
patch features while using the newer ViL v2 implementation.
"""
def __init__(self, config):
super().__init__()
backbone_name = getattr(config, "vision_backbone", "vil2-small")
pretrained = getattr(config, "pretrained", True)
img_size = getattr(config, "img_size", 224)
patch_size = getattr(config, "patch_size", 16)
in_channels = getattr(config, "in_channels", 3)
if backbone_name not in VISION_BACKBONES:
supported = ", ".join(sorted(VISION_BACKBONES))
raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}")
spec = VISION_BACKBONES[backbone_name]
ctor_kwargs = dict(spec["kwargs"])
ctor_kwargs["input_shape"] = (in_channels, img_size, img_size)
ctor_kwargs["patch_size"] = patch_size
self.config = config
self.backbone_name = backbone_name
self.model = spec["ctor"](**ctor_kwargs)
self.dim = ctor_kwargs["dim"]
self.num_patches = self.model.patch_embed.num_patches
if pretrained:
_load_pretrained_backbone(self.model, backbone_name, spec)
def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.model(pixel_values)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.forward_features(pixel_values)
class VisionProjector(nn.Module):
"""
MLP projector: maps ViL features -> LM embedding space.
"""
def __init__(self, config):
super().__init__()
hidden_dim = config.lm_dim * config.hidden_mult
layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
if config.dropout > 0:
layers.append(nn.Dropout(config.dropout))
for _ in range(config.num_layers - 1):
layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
if config.dropout > 0:
layers.append(nn.Dropout(config.dropout))
layers.append(nn.Linear(hidden_dim, config.lm_dim))
self.mlp = nn.Sequential(*layers)
def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
return self.mlp(vision_features)
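

if __name__ == "__main__":
    # Minimal smoke-test sketch of the ViL-DLM contract described in the module
    # docstring. The config values below are illustrative assumptions (notably
    # lm_dim, hidden_mult, num_layers, dropout); pretrained=False avoids the
    # checkpoint download.
    from types import SimpleNamespace

    vision_config = SimpleNamespace(
        vision_backbone="vil2-small",
        pretrained=False,
        img_size=224,
        patch_size=16,
        in_channels=3,
    )
    projector_config = SimpleNamespace(
        vil_dim=384,   # feature dim of vil2-small
        lm_dim=768,    # assumed LM embedding size, adjust to the actual LM
        hidden_mult=4,
        num_layers=2,
        dropout=0.1,
    )

    encoder = VisionXLSTM(vision_config)
    projector = VisionProjector(projector_config)

    pixel_values = torch.randn(2, 3, 224, 224)
    patch_tokens = encoder.forward_features(pixel_values)  # [B, N, 384]
    lm_tokens = projector(patch_tokens)                    # [B, N, lm_dim]
    print(patch_tokens.shape, lm_tokens.shape)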