# Uploaded by HuaminChen: "Upload multi-modal-embed-large final model"
# (commit e21cde3, verified).
from collections import defaultdict
from typing import Any, Dict, List
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoProcessor, WhisperFeatureExtractor, WhisperModel
from .data import PairItem
class MultiModalSentenceEmbedder(nn.Module):
    """Embed text, images and audio into one shared, L2-normalized vector space.

    Each modality has its own pretrained backbone followed by a projection to
    ``embedding_dim``:

    * text  -- a ``SentenceTransformer`` (its ``"sentence_embedding"`` output),
    * image -- a Hugging Face ``AutoModel`` (mean-pooled ``last_hidden_state``),
    * audio -- the encoder of a ``WhisperModel`` (mean-pooled ``last_hidden_state``).

    A projection is an ``nn.Linear`` when the backbone width differs from
    ``embedding_dim`` and an ``nn.Identity`` otherwise.
    """

    # Modalities understood by ``encode_items``.
    _MODALITIES = ("text", "image", "audio")
    # Sample rate used both to load audio files and to compute Whisper features.
    _AUDIO_SAMPLE_RATE = 16000

    def __init__(
        self,
        text_encoder_name: str,
        image_encoder_name: str,
        audio_encoder_name: str,
        embedding_dim: int,
        max_text_length: int,
    ) -> None:
        """Load the three backbones and build the per-modality projections.

        Args:
            text_encoder_name: Name or path passed to ``SentenceTransformer``.
            image_encoder_name: Name or path passed to ``AutoModel`` and
                ``AutoProcessor`` (both with ``trust_remote_code=True``).
            audio_encoder_name: Name or path of a Whisper checkpoint; only its
                encoder and feature extractor are kept on ``self``.
            embedding_dim: Width of the shared output space.
            max_text_length: Assigned to the text model's ``max_seq_length``.
        """
        super().__init__()
        self.text_model = SentenceTransformer(text_encoder_name)
        self.text_model.max_seq_length = max_text_length
        self.image_model = AutoModel.from_pretrained(image_encoder_name, trust_remote_code=True)
        self.image_processor = AutoProcessor.from_pretrained(image_encoder_name, trust_remote_code=True)
        # Only the encoder half of Whisper is stored on self.
        whisper = WhisperModel.from_pretrained(audio_encoder_name)
        self.audio_model = whisper.encoder
        self.audio_processor = WhisperFeatureExtractor.from_pretrained(audio_encoder_name)

        text_dim = self.text_model.get_sentence_embedding_dimension()
        image_dim = self._get_vision_dim(self.image_model)
        audio_dim = whisper.config.d_model
        self.text_proj = nn.Linear(text_dim, embedding_dim) if text_dim != embedding_dim else nn.Identity()
        self.image_proj = nn.Linear(image_dim, embedding_dim) if image_dim != embedding_dim else nn.Identity()
        self.audio_proj = nn.Linear(audio_dim, embedding_dim) if audio_dim != embedding_dim else nn.Identity()

    @staticmethod
    def _get_vision_dim(model: nn.Module) -> int:
        """Return the hidden width of the image backbone's token outputs.

        Dual-tower models (with ``vision_model`` and ``config.vision_config``)
        report the vision tower's width; otherwise ``config.hidden_size`` is used.

        Raises:
            ValueError: if neither config attribute is available.
        """
        if hasattr(model, "vision_model") and hasattr(model.config, "vision_config"):
            return int(model.config.vision_config.hidden_size)
        if hasattr(model.config, "hidden_size"):
            return int(model.config.hidden_size)
        raise ValueError("Could not infer image hidden size")

    @staticmethod
    def _scatter(out: List[Any], positions: Any, rows: Any) -> None:
        """Write ``rows`` into ``out`` at the matching ``positions``."""
        for loc, row in zip(positions, rows):
            out[loc] = row

    def _embed_text_features(self, features: Dict[str, Any]) -> torch.Tensor:
        """Run tokenized features through the text model, project, and L2-normalize."""
        out = self.text_model(features)
        return F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)

    def _encode_text(self, texts: List[Any]) -> torch.Tensor:
        """Encode a batch of texts, returning one normalized row per input, in order.

        Each element is either a dict of pre-tokenized tensors (e.g.
        ``input_ids`` / ``attention_mask``) or raw text that is tokenized here
        with ``self.text_model.tokenize``. The two kinds are batched separately.
        """
        device = next(self.parameters()).device
        normalized: List[Any] = [None] * len(texts)
        dict_positions = [idx for idx, item in enumerate(texts) if isinstance(item, dict)]
        raw_positions = [idx for idx, item in enumerate(texts) if not isinstance(item, dict)]

        if dict_positions:
            dict_items = [texts[idx] for idx in dict_positions]
            # pad_sequence right-pads along each tensor's first dimension
            # (presumably 1-D token sequences -- TODO confirm cache format).
            # 0 is used for every key; attention_mask=0 keeps padding masked out.
            features = {
                key: pad_sequence(
                    [item[key].detach().cpu() for item in dict_items],
                    batch_first=True,
                    padding_value=0,
                ).to(device)
                for key in dict_items[0].keys()
            }
            self._scatter(normalized, dict_positions, self._embed_text_features(features))

        if raw_positions:
            features = self.text_model.tokenize([texts[idx] for idx in raw_positions])
            features = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in features.items()}
            self._scatter(normalized, raw_positions, self._embed_text_features(features))

        # Every position is either a dict or not, so all slots are filled here.
        return torch.stack(normalized, dim=0)

    def _encode_image_paths(self, paths: List[str]) -> torch.Tensor:
        """Load image files as RGB and encode them as one batch."""
        images = []
        for path in paths:
            # convert() returns a new image, so it stays usable after the file closes.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        proc = self.image_processor(images=images, return_tensors="pt")
        # Only pixel_values is consumed; the encoder below handles device placement.
        return self._encode_image_pixel_values(proc["pixel_values"])

    def _encode_image_pixel_values(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode preprocessed ``pixel_values`` into normalized image embeddings."""
        device = next(self.parameters()).device
        pixel_values = pixel_values.to(device)
        if pixel_values.is_floating_point():
            # Match the backbone's weight dtype (as the audio path does) so that
            # half/bfloat16 vision weights don't raise a dtype mismatch.
            pixel_values = pixel_values.to(next(self.image_model.parameters()).dtype)
        backbone = self.image_model.vision_model if hasattr(self.image_model, "vision_model") else self.image_model
        hidden = backbone(pixel_values=pixel_values, output_hidden_states=False).last_hidden_state
        # Token 0 is excluded from the mean -- presumably a [CLS] token.
        # NOTE(review): backbones without one lose a real patch here; confirm per model.
        pooled = hidden[:, 1:].mean(dim=1) if hidden.shape[1] > 1 else hidden.mean(dim=1)
        return F.normalize(self.image_proj(pooled), p=2, dim=-1)

    def _encode_audio_paths(self, paths: List[str]) -> torch.Tensor:
        """Load audio files as mono 16 kHz waveforms and encode them as one batch."""
        sample_rate = self._AUDIO_SAMPLE_RATE
        waves = [librosa.load(path, sr=sample_rate, mono=True)[0] for path in paths]
        proc = self.audio_processor(waves, sampling_rate=sample_rate, return_tensors="pt")
        return self._encode_audio_features(proc["input_features"])

    def _encode_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
        """Encode Whisper ``input_features`` into normalized audio embeddings.

        Features are cast to the encoder's ``conv1`` weight dtype, and the
        encoder output is mean-pooled over the time axis. NOTE(review): this
        includes any padded frames the feature extractor added.
        """
        device = next(self.parameters()).device
        input_features = input_features.to(device)
        input_features = input_features.to(self.audio_model.conv1.weight.dtype)
        out = self.audio_model(input_features=input_features, output_hidden_states=False)
        pooled = out.last_hidden_state.mean(dim=1)
        return F.normalize(self.audio_proj(pooled), p=2, dim=-1)

    @staticmethod
    def _stack_tensor_values(values: List[Any]) -> torch.Tensor:
        """Stack cached tensors on CPU, dropping a leading singleton batch dim.

        Raises:
            TypeError: if any value is not a tensor.
        """
        tensors = []
        for value in values:
            if not torch.is_tensor(value):
                raise TypeError("Expected tensor payload in cached item")
            tensor = value.detach().cpu()
            if tensor.dim() > 0 and tensor.shape[0] == 1:
                tensor = tensor.squeeze(0)
            tensors.append(tensor)
        return torch.stack(tensors, dim=0)

    def _encode_paths_and_tensors(
        self,
        pairs: List[Any],
        encode_paths: Any,
        encode_tensors: Any,
        out: List[Any],
    ) -> None:
        """Encode ``(index, value)`` pairs into ``out``.

        File paths (non-tensor values) form one batch for ``encode_paths``.
        Cached tensors are stacked into a second batch for ``encode_tensors``.
        """
        path_pairs = [(idx, val) for idx, val in pairs if not torch.is_tensor(val)]
        tensor_pairs = [(idx, val) for idx, val in pairs if torch.is_tensor(val)]
        if path_pairs:
            idxs, vals = zip(*path_pairs)
            self._scatter(out, idxs, encode_paths(list(vals)))
        if tensor_pairs:
            idxs, vals = zip(*tensor_pairs)
            self._scatter(out, idxs, encode_tensors(self._stack_tensor_values(list(vals))))

    def encode_items(self, items: List["PairItem"]) -> torch.Tensor:
        """Encode a mixed-modality batch into float32, L2-normalized rows in input order.

        Each item exposes ``.modality`` (``"text"``, ``"image"`` or ``"audio"``)
        and ``.value``. For text, a value is raw text or a dict of tokenized
        tensors. For image and audio, a value is a file path or a cached
        preprocessed tensor. Items of the same modality are batched together.

        Raises:
            ValueError: if any item has an unsupported modality.
        """
        grouped: Dict[Any, List[Any]] = defaultdict(list)
        for idx, item in enumerate(items):
            grouped[item.modality].append((idx, item.value))
        unsupported = [modality for modality in grouped if modality not in self._MODALITIES]
        if unsupported:
            raise ValueError(f"Unsupported modality in encode_items: {unsupported!r}")

        device = next(self.parameters()).device
        out: List[Any] = [None] * len(items)
        if grouped["text"]:
            idxs, vals = zip(*grouped["text"])
            self._scatter(out, idxs, self._encode_text(list(vals)))
        self._encode_paths_and_tensors(
            grouped["image"], self._encode_image_paths, self._encode_image_pixel_values, out
        )
        self._encode_paths_and_tensors(
            grouped["audio"], self._encode_audio_paths, self._encode_audio_features, out
        )
        # Cast each row first so that backbones running in different dtypes stack cleanly.
        stacked = torch.stack([row.to(device=device, dtype=torch.float32) for row in out], dim=0)
        return F.normalize(stacked, p=2, dim=-1)
def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """Symmetric in-batch contrastive loss between row-aligned embedding pairs.

    Row ``i`` of ``anchor`` is paired with row ``i`` of ``positive``. Every other
    row in the batch serves as a negative. Logits are scaled dot products (scaled
    cosine similarities when both inputs are L2-normalized). The result averages
    the anchor->positive and positive->anchor cross-entropy losses.

    Args:
        anchor: ``(batch, dim)`` embeddings.
        positive: ``(batch, dim)`` embeddings aligned row-for-row with ``anchor``.
        scale: Multiplier applied to the similarity logits.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if an input is not 2-D or the two batch sizes differ.
    """
    if anchor.dim() != 2 or positive.dim() != 2:
        raise ValueError(
            f"anchor and positive must be 2-D (batch, dim); got {tuple(anchor.shape)} and {tuple(positive.shape)}"
        )
    if anchor.shape[0] != positive.shape[0]:
        # The symmetric (transposed) term needs one positive per anchor.
        raise ValueError(
            f"anchor and positive must have the same batch size; got {anchor.shape[0]} and {positive.shape[0]}"
        )
    scores = torch.matmul(anchor, positive.T) * scale
    # The matching pair for row i sits on the diagonal.
    labels = torch.arange(scores.shape[0], device=scores.device)
    loss_a = F.cross_entropy(scores, labels)
    loss_b = F.cross_entropy(scores.T, labels)
    return (loss_a + loss_b) * 0.5