File size: 8,681 Bytes
e21cde3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from collections import defaultdict
from typing import Any, Dict, List

import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoProcessor, WhisperFeatureExtractor, WhisperModel

from .data import PairItem


class MultiModalSentenceEmbedder(nn.Module):
    """Encode text, image and audio items into one L2-normalised embedding space.

    Each modality has its own pretrained encoder followed by a projection to
    ``embedding_dim`` (a linear layer, or an identity when the encoder already
    produces that width). ``encode_items`` is the entry point that dispatches
    a mixed batch; the ``_encode_*`` helpers each handle a single modality.
    """

    # Modalities that ``encode_items`` knows how to dispatch.
    _MODALITIES = ("text", "image", "audio")
    # Sampling rate used both when loading audio files and when extracting features.
    _AUDIO_SAMPLE_RATE = 16000

    def __init__(
        self,
        text_encoder_name: str,
        image_encoder_name: str,
        audio_encoder_name: str,
        embedding_dim: int,
        max_text_length: int,
    ) -> None:
        """Load the three encoders and build the per-modality projections.

        Args:
            text_encoder_name: Model name passed to ``SentenceTransformer``.
            image_encoder_name: Checkpoint passed to ``AutoModel`` and
                ``AutoProcessor``.
            audio_encoder_name: Whisper checkpoint; only its encoder is kept.
            embedding_dim: Width of the shared output space.
            max_text_length: Assigned to the text model's ``max_seq_length``.
        """
        super().__init__()
        self.text_model = SentenceTransformer(text_encoder_name)
        self.text_model.max_seq_length = max_text_length

        # NOTE: trust_remote_code=True lets the checkpoint repository run its
        # own Python code at load time; only point this at trusted checkpoints.
        self.image_model = AutoModel.from_pretrained(image_encoder_name, trust_remote_code=True)
        self.image_processor = AutoProcessor.from_pretrained(image_encoder_name, trust_remote_code=True)

        whisper = WhisperModel.from_pretrained(audio_encoder_name)
        self.audio_model = whisper.encoder
        self.audio_processor = WhisperFeatureExtractor.from_pretrained(audio_encoder_name)

        text_dim = self.text_model.get_sentence_embedding_dimension()
        image_dim = self._get_vision_dim(self.image_model)
        audio_dim = whisper.config.d_model

        self.text_proj = self._make_projection(text_dim, embedding_dim)
        self.image_proj = self._make_projection(image_dim, embedding_dim)
        self.audio_proj = self._make_projection(audio_dim, embedding_dim)

    @staticmethod
    def _make_projection(in_dim: int, out_dim: int) -> nn.Module:
        """Return a linear map from ``in_dim`` to ``out_dim``, or identity if they match."""
        if in_dim == out_dim:
            return nn.Identity()
        return nn.Linear(in_dim, out_dim)

    @staticmethod
    def _get_vision_dim(model: nn.Module) -> int:
        """Read the image encoder's hidden size from its config.

        A nested ``vision_config`` is preferred when the model exposes a
        ``vision_model`` sub-module; otherwise the top-level ``hidden_size``
        is used.

        Raises:
            ValueError: If neither location provides a hidden size.
        """
        config = model.config
        if hasattr(model, "vision_model") and hasattr(config, "vision_config"):
            return int(config.vision_config.hidden_size)
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)
        raise ValueError("Could not infer image hidden size")

    def _module_device(self) -> torch.device:
        """Return the device of the first registered parameter."""
        return next(self.parameters()).device

    def _embed_text_features(self, features: Dict[str, Any]) -> torch.Tensor:
        """Run the text model on ``features``, then project and L2-normalise."""
        out = self.text_model(features)
        return F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)

    def _encode_text(self, texts: List[Any]) -> torch.Tensor:
        """Embed a batch of text items, preserving input order.

        Args:
            texts: Items that are either dicts of tensors keyed by feature
                name (presumably pre-tokenised, 1-D per key -- confirm against
                the data pipeline) or values accepted by the text model's
                ``tokenize`` method.

        Returns:
            One L2-normalised embedding row per input item.
        """
        device = self._module_device()
        normalized: List[torch.Tensor | None] = [None] * len(texts)

        dict_positions = [idx for idx, item in enumerate(texts) if isinstance(item, dict)]
        raw_positions = [idx for idx, item in enumerate(texts) if not isinstance(item, dict)]

        if dict_positions:
            dict_items = [texts[idx] for idx in dict_positions]
            # Every key is right-padded with 0, so padded positions carry
            # attention_mask == 0. Keys are taken from the first dict; all
            # dict items are expected to share them.
            features = {
                key: pad_sequence(
                    [item[key].detach().cpu() for item in dict_items],
                    batch_first=True,
                    padding_value=0,
                ).to(device)
                for key in dict_items[0]
            }
            for loc, row in zip(dict_positions, self._embed_text_features(features)):
                normalized[loc] = row

        if raw_positions:
            raw_texts = [texts[idx] for idx in raw_positions]
            features = self.text_model.tokenize(raw_texts)
            # Only values that know how to move (tensors) are sent to the device.
            features = {
                key: (value.to(device) if hasattr(value, "to") else value)
                for key, value in features.items()
            }
            for loc, row in zip(raw_positions, self._embed_text_features(features)):
                normalized[loc] = row

        # Each index belongs to exactly one of the two partitions above, so
        # every slot has been filled.
        return torch.stack(normalized, dim=0)

    def _encode_image_paths(self, paths: List[str]) -> torch.Tensor:
        """Load images from ``paths``, preprocess them and embed them."""
        images = []
        for path in paths:
            # convert() returns an independent image, so the file handle can
            # be released as soon as the pixel data has been read.
            with Image.open(path) as handle:
                images.append(handle.convert("RGB"))
        proc = self.image_processor(images=images, return_tensors="pt")
        # Device placement is handled by _encode_image_pixel_values.
        return self._encode_image_pixel_values(proc["pixel_values"])

    def _encode_image_pixel_values(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed already-preprocessed pixel values.

        Returns:
            L2-normalised projections of the mean-pooled last hidden state.
        """
        pixel_values = pixel_values.to(self._module_device())
        # Use the vision tower when the model has one, else the model itself.
        encoder = getattr(self.image_model, "vision_model", self.image_model)
        hidden = encoder(pixel_values=pixel_values, output_hidden_states=False).last_hidden_state
        # Mean-pool over the sequence, skipping position 0 when there is more
        # than one position.
        # NOTE(review): assumes position 0 is a class token -- confirm for the
        # chosen image encoder.
        pooled = hidden[:, 1:].mean(dim=1) if hidden.shape[1] > 1 else hidden.mean(dim=1)
        return F.normalize(self.image_proj(pooled), p=2, dim=-1)

    def _encode_audio_paths(self, paths: List[str]) -> torch.Tensor:
        """Load mono audio from ``paths``, extract features and embed them."""
        rate = self._AUDIO_SAMPLE_RATE
        waves = [librosa.load(path, sr=rate, mono=True)[0] for path in paths]
        proc = self.audio_processor(waves, sampling_rate=rate, return_tensors="pt")
        return self._encode_audio_features(proc["input_features"])

    def _encode_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
        """Embed already-extracted audio features.

        Returns:
            L2-normalised projections of the mean-pooled last hidden state.
        """
        # Match the device of the module and the dtype of the encoder's first
        # convolution.
        input_features = input_features.to(
            device=self._module_device(),
            dtype=self.audio_model.conv1.weight.dtype,
        )
        out = self.audio_model(input_features=input_features, output_hidden_states=False)
        pooled = out.last_hidden_state.mean(dim=1)
        return F.normalize(self.audio_proj(pooled), p=2, dim=-1)

    @staticmethod
    def _stack_tensor_values(values: List[Any]) -> torch.Tensor:
        """Stack cached per-item tensors into one CPU batch.

        A leading dimension of size 1 is squeezed from each tensor before
        stacking, so items stored with or without that dimension can be mixed.

        Raises:
            TypeError: If any value is not a tensor.
        """
        tensors = []
        for value in values:
            if not torch.is_tensor(value):
                raise TypeError("Expected tensor payload in cached item")
            tensor = value.detach().cpu()
            if tensor.dim() > 0 and tensor.shape[0] == 1:
                tensor = tensor.squeeze(0)
            tensors.append(tensor)
        return torch.stack(tensors, dim=0)

    def _encode_path_or_tensor_group(
        self,
        pairs: List[Any],
        encode_paths: Any,
        encode_tensors: Any,
        out: List[Any],
    ) -> None:
        """Embed one modality's ``(index, value)`` pairs and write rows into ``out``.

        Non-tensor values go through ``encode_paths``; tensor values are
        stacked and go through ``encode_tensors``.
        """
        path_pairs = [(idx, val) for idx, val in pairs if not torch.is_tensor(val)]
        tensor_pairs = [(idx, val) for idx, val in pairs if torch.is_tensor(val)]

        if path_pairs:
            embs = encode_paths([val for _, val in path_pairs])
            for (loc, _), emb in zip(path_pairs, embs):
                out[loc] = emb

        if tensor_pairs:
            batch = self._stack_tensor_values([val for _, val in tensor_pairs])
            for (loc, _), emb in zip(tensor_pairs, encode_tensors(batch)):
                out[loc] = emb

    def encode_items(self, items: List["PairItem"]) -> torch.Tensor:
        """Embed a mixed-modality batch, returning rows in input order.

        Args:
            items: Objects exposing ``modality`` (one of ``"text"``,
                ``"image"``, ``"audio"``) and ``value``. Image and audio
                values may be tensors or anything else, which is treated as a
                file path. An empty list is not supported because
                ``torch.stack`` rejects empty input.

        Returns:
            A float32 tensor on the module's device with one L2-normalised
            row per item.

        Raises:
            ValueError: If any item has an unsupported modality.
        """
        grouped = defaultdict(list)
        for idx, item in enumerate(items):
            grouped[item.modality].append((idx, item.value))

        # Fail early and clearly; otherwise the unfilled slots would surface
        # later as an opaque error from torch.stack.
        unknown = [modality for modality in grouped if modality not in self._MODALITIES]
        if unknown:
            raise ValueError(f"Unsupported modality in items: {unknown!r}")

        device = self._module_device()
        out: List[Any] = [None] * len(items)

        text_pairs = grouped.get("text")
        if text_pairs:
            idxs, vals = zip(*text_pairs)
            for loc, emb in zip(idxs, self._encode_text(list(vals))):
                out[loc] = emb

        image_pairs = grouped.get("image")
        if image_pairs:
            self._encode_path_or_tensor_group(
                image_pairs,
                self._encode_image_paths,
                self._encode_image_pixel_values,
                out,
            )

        audio_pairs = grouped.get("audio")
        if audio_pairs:
            self._encode_path_or_tensor_group(
                audio_pairs,
                self._encode_audio_paths,
                self._encode_audio_features,
                out,
            )

        # Unify device and dtype across modalities, then re-normalise after
        # the cast.
        stacked = torch.stack(out, dim=0).to(device=device, dtype=torch.float32)
        return F.normalize(stacked, p=2, dim=-1)


def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """Symmetric in-batch cross-entropy loss over anchor/positive similarities.

    Row ``i`` of ``anchor`` is paired with row ``i`` of ``positive``; every
    other row in the batch acts as a negative. The loss is the mean of the
    anchor-to-positive and positive-to-anchor cross-entropies.

    Args:
        anchor: 2-D embedding matrix (presumably L2-normalised, one row per
            example -- confirm against the caller).
        positive: 2-D embedding matrix with matching rows.
        scale: Multiplier applied to the dot-product scores before softmax.

    Returns:
        A scalar loss tensor.
    """
    logits = torch.matmul(anchor, positive.T) * scale
    # The matching pair for row i sits on the diagonal, i.e. in column i.
    targets = torch.arange(logits.shape[0], device=logits.device)
    anchor_to_positive = F.cross_entropy(logits, targets)
    positive_to_anchor = F.cross_entropy(logits.T, targets)
    return (anchor_to_positive + positive_to_anchor) * 0.5