"""
Vision xLSTM adapter built on the upstream NX-AI vision-lstm repository.

This module keeps the existing ViL-DLM contract:
- `VisionXLSTM.forward_features(pixel_values)` returns patch tokens `[B, N, D]`
- `VisionProjector` maps those visual tokens into the LM embedding space
"""

from __future__ import annotations

import os
import ssl
import sys
from pathlib import Path

import certifi
import torch
import torch.nn as nn


REPO_ROOT = Path(__file__).resolve().parents[1]
VISION_LSTM_ROOT = REPO_ROOT / "external" / "vision-lstm"

if str(VISION_LSTM_ROOT) not in sys.path:
    sys.path.insert(0, str(VISION_LSTM_ROOT))

from vision_lstm import VisionLSTM, VisionLSTM2  # noqa: E402


# Registry of supported ViL backbones: each spec bundles the upstream constructor,
# the checkpoint-key remapping scheme ("v1" for VisionLSTM, "v2" for VisionLSTM2),
# the pretrained checkpoint URL, and the constructor kwargs for that variant.
VISION_BACKBONES = {
    "vil-small": {
        "ctor": VisionLSTM,
        "preprocess": "v1",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 24,
            "legacy_norm": True,
            "mode": None,
            "pooling": None,
            "output_shape": None,
        },
    },
    "vil2-small": {
        "ctor": VisionLSTM2,
        "preprocess": "v2",
        "url": "https://ml.jku.at/research/vision_lstm/download/vil2_small16_e400_in1k.th",
        "kwargs": {
            "dim": 384,
            "depth": 12,
            "legacy_norm": True,
            "mode": "features",
            "pooling": None,
            "output_shape": None,
            "conv_kind": "2d",
            "conv_kernel_size": 3,
            "norm_bias": True,
            "proj_bias": True,
        },
    },
}


def _preprocess_v1_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
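    # Remap upstream checkpoint keys onto this adapter's module names; the
    # examples below illustrate the renames performed (key suffixes representative):
    #   "blocks.0.xlstm.*"        -> "blocks.0.layer.*"
    #   "post_blocks_norm.weight" -> "legacy_norm.weight"
    #   "head.0.{weight,bias}"    -> "norm.{weight,bias}"
    #   "head.1.{weight,bias}"    -> "head.{weight,bias}"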
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    state_dict["legacy_norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    state_dict["norm.weight"] = state_dict.pop("head.0.weight")
    state_dict["norm.bias"] = state_dict.pop("head.0.bias")
    state_dict["head.weight"] = state_dict.pop("head.1.weight")
    state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict


def _preprocess_v2_state_dict(
    state_dict: dict[str, torch.Tensor],
    *,
    depth: int,
    legacy_norm: bool,
) -> dict[str, torch.Tensor]:
    state_dict = {key.replace(".xlstm.", ".layer."): value for key, value in state_dict.items()}
    state_dict = {key.replace("xlstm.", ""): value for key, value in state_dict.items()}
    state_dict = {key.replace(".xlstm_norm.", ".norm."): value for key, value in state_dict.items()}
    state_dict = {key.replace(".conv1d.", ".conv."): value for key, value in state_dict.items()}
    for index in range(depth * 2):
        if index % 2 == 0:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_top_left."): value
                for key, value in state_dict.items()
            }
        else:
            state_dict = {
                key.replace(f"blocks.{index}.", f"blocks.{index // 2}.rowwise_from_bot_right."): value
                for key, value in state_dict.items()
            }
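    # At this point the checkpoint's 2 * depth sequential blocks have been paired
    # into `depth` bidirectional blocks, e.g. (key suffixes representative):
    #   "blocks.0.layer.*" -> "blocks.0.rowwise_from_top_left.layer.*"
    #   "blocks.1.layer.*" -> "blocks.0.rowwise_from_bot_right.layer.*"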
    state_dict["norm.weight"] = state_dict.pop("post_blocks_norm.weight")
    state_dict["norm.bias"] = state_dict.pop("post_blocks_norm.bias")
    if legacy_norm:
        state_dict["legacy_norm.weight"] = state_dict.pop("head.0.weight")
        state_dict["legacy_norm.bias"] = state_dict.pop("head.0.bias")
    state_dict["head.weight"] = state_dict.pop("head.1.weight")
    state_dict["head.bias"] = state_dict.pop("head.1.bias")
    return state_dict


def _load_pretrained_backbone(model: nn.Module, spec: dict) -> None:
    # torch.hub fetches the checkpoint over HTTPS; route Python's default SSL
    # context through certifi's CA bundle so the download succeeds on systems
    # with an incomplete certificate store. The lambda accepts arbitrary args
    # to stay compatible with however the stdlib invokes it.
    os.environ.setdefault("SSL_CERT_FILE", certifi.where())
    ssl._create_default_https_context = lambda *args, **kwargs: ssl.create_default_context(
        cafile=certifi.where()
    )
    payload = torch.hub.load_state_dict_from_url(spec["url"], map_location="cpu")
    state_dict = payload["state_dict"]
    if spec["preprocess"] == "v1":
        state_dict = _preprocess_v1_state_dict(state_dict)
    elif spec["preprocess"] == "v2":
        state_dict = _preprocess_v2_state_dict(
            state_dict,
            depth=spec["kwargs"]["depth"],
            legacy_norm=spec["kwargs"]["legacy_norm"],
        )
    else:
        raise ValueError(f"Unsupported checkpoint preprocessing mode: {spec['preprocess']}")
    if getattr(model, "head", None) is None:
        # The backbone was built without a classification head, so the
        # checkpoint's classifier weights have no destination; drop them
        # before the strict load.
        state_dict.pop("head.weight", None)
        state_dict.pop("head.bias", None)
    model.load_state_dict(state_dict)


class VisionXLSTM(nn.Module):
    """
    Thin adapter over upstream VisionLSTM / VisionLSTM2 models.

    The default backbone is `vil2-small`, which provides 384-dim patch features
    while using the newer ViL v2 implementation.
    """

    def __init__(self, config):
        super().__init__()
        backbone_name = getattr(config, "vision_backbone", "vil2-small")
        pretrained = getattr(config, "pretrained", True)
        img_size = getattr(config, "img_size", 224)
        patch_size = getattr(config, "patch_size", 16)
        in_channels = getattr(config, "in_channels", 3)

        if backbone_name not in VISION_BACKBONES:
            supported = ", ".join(sorted(VISION_BACKBONES))
            raise ValueError(f"Unsupported vision backbone '{backbone_name}'. Supported backbones: {supported}")

        spec = VISION_BACKBONES[backbone_name]
        ctor_kwargs = dict(spec["kwargs"])
        ctor_kwargs["input_shape"] = (in_channels, img_size, img_size)
        ctor_kwargs["patch_size"] = patch_size

        self.config = config
        self.backbone_name = backbone_name
        self.model = spec["ctor"](**ctor_kwargs)
        self.dim = ctor_kwargs["dim"]
        self.num_patches = self.model.patch_embed.num_patches

        if pretrained:
            _load_pretrained_backbone(self.model, spec)

    def forward_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.model(pixel_values)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.forward_features(pixel_values)


class VisionProjector(nn.Module):
    """
    MLP projector: maps ViL patch features `[B, N, vil_dim]` into the LM
    embedding space `[B, N, lm_dim]` through GELU MLP blocks of hidden width
    `lm_dim * hidden_mult`.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.lm_dim * config.hidden_mult

        layers = [nn.Linear(config.vil_dim, hidden_dim), nn.GELU()]
        if config.dropout > 0:
            layers.append(nn.Dropout(config.dropout))

        for _ in range(config.num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.GELU()])
            if config.dropout > 0:
                layers.append(nn.Dropout(config.dropout))

        layers.append(nn.Linear(hidden_dim, config.lm_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        return self.mlp(vision_features)
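

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): the
    # SimpleNamespace configs below are stand-ins for the real config objects,
    # and lm_dim=1024 is an arbitrary example value.
    from types import SimpleNamespace

    vision_cfg = SimpleNamespace(
        vision_backbone="vil2-small",
        pretrained=False,  # skip the checkpoint download for a pure shape check
        img_size=224,
        patch_size=16,
        in_channels=3,
    )
    projector_cfg = SimpleNamespace(
        vil_dim=384, lm_dim=1024, hidden_mult=4, num_layers=2, dropout=0.0
    )

    encoder = VisionXLSTM(vision_cfg)
    projector = VisionProjector(projector_cfg)

    with torch.no_grad():
        tokens = encoder(torch.randn(2, 3, 224, 224))  # expected [2, 196, 384] patch tokens
        embeds = projector(tokens)                     # expected [2, 196, 1024] LM-space tokens
    print(tokens.shape, embeds.shape)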