| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import warnings |
| from transformers import ( |
| HubertModel, |
| AutoProcessor, |
| AutoTokenizer, |
| AutoModel |
| ) |
| warnings.filterwarnings("ignore") |
| import torchvision.transforms as transforms |
| from PIL import Image |
| |
| |
| |
class AudioEmbedder(nn.Module):
    """
    Wraps a pre-trained HuBERT encoder to turn raw 16 kHz waveforms into a
    sequence of frame embeddings, then linearly projects each frame down to
    ``embedding_dim``.
    """

    def __init__(self, embedding_dim=512, hubert_name="facebook/hubert-base-ls960"):
        super().__init__()
        # NOTE(review): the processor comes from the fine-tuned checkpoint
        # rather than `hubert_name` — presumably because the base checkpoint
        # ships no processor config; confirm this is intentional.
        self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        self.hubert = HubertModel.from_pretrained(hubert_name)
        self.projection = nn.Linear(self.hubert.config.hidden_size, embedding_dim)

        # Full fine-tuning: both the backbone and the projection stay trainable.
        for p in self.hubert.parameters():
            p.requires_grad = True
        for p in self.projection.parameters():
            p.requires_grad = True

    def forward(self, audio_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            audio_input: (B, T) raw audio waveform sampled at 16 kHz.

        Returns:
            (B, Na, D) projected audio embeddings, where Na is the number of
            HuBERT frames (roughly T/320) and D = embedding_dim.
        """
        # Drop an assumed-singleton leading dim, e.g. (1, B, T) -> (B, T).
        if audio_input.dim() == 3:
            audio_input = audio_input.squeeze(0)

        # Normalize/pad via the processor; only input_values are used here
        # (the requested attention mask is discarded).
        processed = self.processor(
            audio_input,
            return_tensors="pt",
            sampling_rate=16000,
            padding=True,
            return_attention_mask=True
        ).input_values.squeeze(0)
        processed = processed.to(next(self.parameters()).device)

        hidden = self.hubert(processed).last_hidden_state
        return self.projection(hidden)
|
|
|
|
| |
| |
| |
class TextEmbedder(nn.Module):
    """
    Wraps a pre-trained BERT-style encoder (ModernBERT by default) and
    projects its token-level hidden states down to ``embedding_dim``.
    """

    def __init__(self, embedding_dim=512, model_name="answerdotai/ModernBERT-base"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.projection = nn.Linear(self.encoder.config.hidden_size, embedding_dim)
        print("Using text model: ", model_name)

        # Full fine-tuning: encoder and projection both stay trainable.
        for p in self.encoder.parameters():
            p.requires_grad = True
        for p in self.projection.parameters():
            p.requires_grad = True

    def forward(self, text_list):
        """
        Args:
            text_list: list[str], a batch of raw text strings.

        Returns:
            text_feats: (B, Nt, D) projected token embeddings.
            attention_mask: (B, Nt) tokenizer mask (1 = real token, 0 = pad).
        """
        # Tokenize without special tokens, padding/truncating to <= 128 tokens.
        encoded = self.tokenizer(
            text_list,
            padding=True,
            truncation=True,
            add_special_tokens=False,
            max_length=128,
            return_tensors="pt"
        )
        target = next(self.parameters()).device
        encoded = {name: tensor.to(target) for name, tensor in encoded.items()}

        hidden = self.encoder(**encoded).last_hidden_state
        return self.projection(hidden), encoded["attention_mask"]
|
|
|
|
| |
| |
| |
class ViTEmbedder(nn.Module):
    """
    DINOv2 backbone producing per-patch embeddings, followed by a linear
    projection to ``embedding_dim`` and dropout.
    """

    def __init__(self, model_name='facebookresearch/dinov2', arch='dinov2_vitb14',
                 embedding_dim=512, dropout_prob=0.1):
        super().__init__()
        self.model = torch.hub.load(model_name, arch)
        print("Using DINOv2 model: ", arch)
        self.projection = nn.Linear(self.model.embed_dim, embedding_dim)
        self.dropout = nn.Dropout(p=dropout_prob)

        # Backbone kept trainable (projection params require grad by default).
        for p in self.model.parameters():
            p.requires_grad = True

    def forward(self, x):
        """
        Args:
            x: image batch, normally (B, 3, H, W); a single (3, H, W) image or
               a (1, B, 3, H, W) stack is normalized to 4-D first.

        Returns:
            (B, Nv, D) projected patch embeddings after dropout, where Nv is
            the number of patch tokens produced by the backbone.
        """
        # Normalize tensor rank before the backbone: squeeze an
        # assumed-singleton leading dim from 5-D input, or add a batch dim
        # to a lone 3-D image. (A 5-D squeeze yields 4-D, so if/elif is
        # equivalent to the two independent checks.)
        if x.dim() == 5:
            x = x.squeeze(0)
        elif x.dim() == 3:
            x = x.unsqueeze(0)

        tokens = self.model.get_intermediate_layers(x, n=1)[0]
        return self.dropout(self.projection(tokens))
|
|
class Triad(nn.Module):
    """
    Tri-modal model combining DINOv2 visual, HuBERT audio and BERT-style text
    embedders in a shared 512-d space, and computing token-level similarity
    matrices between each pair of provided modalities.
    """

    def __init__(
        self,
        audio_model_name="facebook/hubert-base-ls960",
        text_model_name="distilbert/distilbert-base-uncased",
        temperature=2.0,
        patch_sparsity_threshold=0.3,
        patch_sparsity_weight=0.1,
        visual_dropout_prob=0.1
    ):
        super().__init__()

        self.audio_embedder = AudioEmbedder(embedding_dim=512, hubert_name=audio_model_name)
        self.text_embedder = TextEmbedder(embedding_dim=512, model_name=text_model_name)
        self.visual_embedder = ViTEmbedder(arch='dinov2_vitb14',
                                           embedding_dim=512,
                                           dropout_prob=visual_dropout_prob)

        # Learnable similarity temperature; sparsity knobs are plain floats
        # (read by training code outside this class, presumably).
        self.temperature = nn.Parameter(torch.tensor(temperature))
        self.patch_sparsity_threshold = patch_sparsity_threshold
        self.patch_sparsity_weight = patch_sparsity_weight

        # Shared preprocessing for image paths (224x224, ImageNet stats).
        # Built once here instead of per image inside forward().
        self._image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def compute_similarity_matrix(self, feats1, feats2):
        """
        Generic token-level dot-product similarity between feats1 and feats2,
        scaled by the learnable temperature.

        Args:
            feats1: (B, N1, D)
            feats2: (B, N2, D)

        Returns:
            sim: (B, N1, N2)
        """
        sim = torch.bmm(feats1, feats2.transpose(1, 2))
        return sim / self.temperature

    def forward(self, image=None, audio=None, text_list=None):
        """
        Embed whichever modalities are provided and compute the pairwise
        token-level similarity matrices between them.

        Args:
            image: image file path, list of paths, or a preprocessed tensor
                   ((3, H, W) or (B, 3, H, W)).
            audio: (B, T) raw waveform tensor at 16 kHz.
            text_list: list containing exactly one string.

        Returns:
            dict with per-modality features ('visual_feats', 'audio_feats',
            'text_feats') and, for each provided pair of modalities, a
            similarity matrix ('vis_text_sim_matrix', 'vis_audio_sim_matrix',
            'text_audio_sim_matrix').
        """
        assert image is not None or audio is not None or text_list is not None, "At least one modality must be provided"
        if image is not None:
            # BUGFIX: original `assert image is not str` compared identity
            # against the type object and could never fire; check the actual
            # type against the forms handled below instead.
            assert isinstance(image, (str, list, torch.Tensor)), "Frames should be a path to an image"
        if audio is not None:
            assert isinstance(audio, torch.Tensor) and len(audio.shape) == 2, "Audio must be a PyTorch tensor of shape (B, T)"
        if text_list is not None:
            assert isinstance(text_list, list) and len(text_list) == 1, "Text list must be a list of strings of length 1"

        if image is not None:
            device = next(self.parameters()).device

            if isinstance(image, list):
                # Batch of image paths -> (B, 3, 224, 224).
                image = torch.stack(
                    [self._image_transform(Image.open(p).convert('RGB')).to(device) for p in image],
                    dim=0
                )
            elif isinstance(image, str):
                # Single image path -> (1, 3, 224, 224).
                image = self._image_transform(Image.open(image).convert('RGB')).to(device).unsqueeze(0)
            elif isinstance(image, torch.Tensor):
                # Already a tensor; just ensure a batch dim and move to device.
                if image.dim() == 3:
                    image = image.unsqueeze(0)
                image = image.to(device)

        embeddings = {}
        if image is not None:
            embeddings['visual_feats'] = self.visual_embedder(image)
        if audio is not None:
            embeddings['audio_feats'] = self.audio_embedder(audio)
        if text_list is not None:
            embeddings['text_feats'], _ = self.text_embedder(text_list)

        # Pairwise similarities only for the modalities actually present.
        if image is not None and text_list is not None:
            embeddings['vis_text_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['visual_feats'])
        if audio is not None and image is not None:
            embeddings['vis_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['audio_feats'], embeddings['visual_feats'])
        if text_list is not None and audio is not None:
            embeddings['text_audio_sim_matrix'] = self.compute_similarity_matrix(embeddings['text_feats'], embeddings['audio_feats'])
        return embeddings
|
|
|
|