"""
Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository.
This module keeps the existing ViL-DLM contract:
- `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]`
- `VisionProjector` maps those visual tokens into the LM embedding space
"""
from __future__ import annotations

import os
import ssl
import sys
from pathlib import Path

import certifi
import torch
import torch.nn as nn

# Make the vendored NX-AI vision-lstm checkout importable before importing from it.
REPO_ROOT = Path(__file__).resolve().parents[1]
VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm"
if str(VISION_LSTM_ROOT) not in sys.path:
    sys.path.insert(0, str(VISION_LSTM_ROOT))

from vision_lstm import VisionLSTM, VisionLSTM2  # noqa: E402

# Supported backbones: upstream constructor, checkpoint-key preprocessing mode,
# checkpoint URL, and constructor kwargs.
VISION_BACKBONES = {
    "vil-small": {
        "ctor": VisionLSTM,
        "preprocess": "v1",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 24,
            "legacy_norm": True,
            "mode": None,
            "pooling": None,
            "output_shape": None,
        },
    },
    "vil2-small": {
        "ctor": VisionLSTM2,
        "preprocess": "v2",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 12,
            "legacy_norm": True,
            "mode": "features",
            "pooling": None,
            "output_shape": None,
            "conv_kind": "2d",
            "conv_kernel_size": 3,
            "norm_bias": True,
            "proj_bias": True,
        },
    },
}
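
# A registry entry is consumed by VisionXLSTM below roughly like this
# (sketch only; `input_shape` and `patch_size` come from the runtime config,
# everything else from the entry):
#
#   spec = VISION_BACKBONES["vil2-small"]
#   kwargs = dict(spec["kwargs"], input_shape=(3, 224, 224), patch_size=16)
#   backbone = spec["ctor"](**kwargs)  # -> VisionLSTM2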


def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Remap upstream ViL v1 checkpoint keys onto the VisionLSTM module names."""
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    # Split the upstream two-stage head (`head.0` norm, `head.1` linear) into
    # separate `norm` and `head` entries.
    state_dict["norm.weight"] = state_dict.pop("head.0.weight")
    state_dict["norm.bias"] = state_dict.pop("head.0.bias")
    state_dict["head.weight"] = state_dict.pop("head.1.weight")
    state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict


def _preprocess_v2_state_dict(
    state_dict: dict[str, torch.Tensor],
    *,
    depth: int,
    legacy_norm: bool,
) -> dict[str, torch.Tensor]:
    """Remap upstream ViL v2 checkpoint keys onto the VisionLSTM2 module names."""
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()}
    # The upstream checkpoint stores 2 * depth sequential blocks; fold each
    # consecutive pair into one bidirectional block (even index -> top-left
    # scan direction, odd index -> bottom-right scan direction).
    for index in range(depth * 2):
        if index % 2 == 0:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value
                for key, value in state_dict.items()
            }
        else:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value
                for key, value in state_dict.items()
            }
    state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias")
    if legacy_norm:
        state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight")
        state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias")
        state_dict["head.weight"] = state_dict.pop("head.1.weight")
        state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict
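
# The index folding above pairs consecutive upstream blocks into one
# bidirectional block each, e.g. for any parameter suffix <p> (illustrative;
# the actual suffixes depend on the upstream checkpoint layout):
#   "blocks.4.<p>" -> "blocks.2.rowwise_from_top_left.<p>"
#   "blocks.5.<p>" -> "blocks.2.rowwise_from_bot_right.<p>"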


def _load_pretrained_backbone(model: nn.Module, name: str, spec: dict) -> None:
    # torch.hub downloads over HTTPS; point the ssl module at certifi's CA
    # bundle so the download works on systems with an incomplete CA store.
    os.environ.setdefault("SSL_CERT_FILE", certifi.where())
    ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
    payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu")
    state_dict = payload["state_dict"]
    if spec["preprocess"] == "v1":
        state_dict = _preprocess_v1_state_dict(state_dict)
    elif spec["preprocess"] == "v2":
        state_dict = _preprocess_v2_state_dict(
            state_dict,
            depth=spec["kwargs"]["depth"],
            legacy_norm=spec["kwargs"]["legacy_norm"],
        )
    else:
        raise ValueError(f"Unsupported checkpoint preprocessing mode for '{name}': {spec['preprocess']}")
    # Feature-extraction configs build the backbone without a classification
    # head, so drop the head weights before loading.
    if getattr(model, "head", None) is None:
        state_dict.pop("head.weight", None)
        state_dict.pop("head.bias", None)
    model.load_state_dict(state_dict)


class VisionXLSTM(nn.Module):
    """
    Thin adapter over the upstream VisionLSTM / VisionLSTM2 models.

    The default backbone is `vil2-small`, which provides the expected 384-dim
    patch features while using the newer ViL v2 implementation.
    """

    def __init__(self, config):
        super().__init__()
        backbone_name = getattr(config, "vision_backbone", "vil2-small")
        pretrained = getattr(config, "pretrained", True)
        img_size = getattr(config, "img_size", 224)
        patch_size = getattr(config, "patch_size", 16)
        in_channels = getattr(config, "in_channels", 3)
        if backbone_name not in VISION_BACKBONES:
            supported = ", ".join(sorted(VISION_BACKBONES))
            raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}")
        spec = VISION_BACKBONES[backbone_name]
        ctor_kwargs = dict(spec["kwargs"])
        ctor_kwargs["input_shape"] = (in_channels, img_size, img_size)
        ctor_kwargs["patch_size"] = patch_size
        self.config = config
        self.backbone_name = backbone_name
        self.model = spec["ctor"](**ctor_kwargs)
        self.dim = ctor_kwargs["dim"]
        self.num_patches = self.model.patch_embed.num_patches
        if pretrained:
            _load_pretrained_backbone(self.model, backbone_name, spec)

    def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return patch tokens of shape [B, N, D]."""
        return self.model(pixel_values)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.forward_features(pixel_values)


class VisionProjector(nn.Module):
    """
    MLP projector that maps ViL patch features into the LM embedding space.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.lm_dim * config.hidden_mult
        layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
        if config.dropout > 0:
            layers.append(nn.Dropout(config.dropout))
        for _ in range(config.num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
            if config.dropout > 0:
                layers.append(nn.Dropout(config.dropout))
        layers.append(nn.Linear(hidden_dim, config.lm_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        return self.mlp(vision_features)
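

if __name__ == "__main__":
    # Minimal smoke test (a sketch): builds the default backbone without
    # pretrained weights and checks the [B, N, D] token contract. Assumes the
    # external/vision-lstm checkout is present; the projector config values
    # below (lm_dim, hidden_mult, num_layers, dropout) are illustrative.
    from types import SimpleNamespace

    vision_cfg = SimpleNamespace(
        vision_backbone="vil2-small",
        pretrained=False,  # skip the checkpoint download for the smoke test
        img_size=224,
        patch_size=16,
        in_channels=3,
    )
    encoder = VisionXLSTM(vision_cfg)
    tokens = encoder(torch.randn(2, 3, 224, 224))
    print(tokens.shape)  # expected: [2, 196, 384] -> (224 / 16) ** 2 patches, dim 384

    projector_cfg = SimpleNamespace(vil_dim=encoder.dim, lm_dim=1024, hidden_mult=4, num_layers=2, dropout=0.1)
    projector = VisionProjector(projector_cfg)
    print(projector(tokens).shape)  # expected: [2, 196, 1024]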