from __future__ import annotations

from typing import Any

from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel

from sentence_transformers.base.modules.transformer import Transformer


class SiglipVisionTransformer(Transformer):
    """:class:`Transformer` variant that keeps only the SigLIP vision tower.

    Three hooks are customised: the processor is narrowed to its image
    processor after construction, the loaded model is narrowed to its
    ``vision_model`` sub-module, and the first position of the token
    embeddings is removed after every forward pass.
    """

    def __init__(self, model_name_or_path: str, **kwargs: Any) -> None:
        """Build the base module, then keep only the image processor.

        Args:
            model_name_or_path: Forwarded unchanged to the base constructor.
            **kwargs: Forwarded unchanged to the base constructor.
        """
        super().__init__(model_name_or_path, **kwargs)

        # When the processor bundles an image processor, keep just that part.
        # Per the original author's note this leaves the unused SigLIP text
        # tokenizer (~17 MB) out of the saved layout.
        combined = self.processor
        if hasattr(combined, "image_processor"):
            self.processor = combined.image_processor

    def _load_model(
        self,
        model_name_or_path: str,
        transformer_task: str,
        config: PretrainedConfig,
        backend: str,
        is_peft_model: bool,
        **model_kwargs: Any,
    ) -> PreTrainedModel:
        """Load through the base class and return the vision tower.

        All arguments are passed straight through to the base implementation.

        Returns:
            The ``vision_model`` attribute of the loaded model when it has
            one, otherwise the loaded model itself.
        """
        loaded = super()._load_model(
            model_name_or_path,
            transformer_task,
            config,
            backend,
            is_peft_model,
            **model_kwargs,
        )
        # A freshly initialised model exposes ``vision_model``; according to
        # the original note a reloaded checkpoint is already the vision model,
        # in which case the default makes this a pass-through.
        vision_tower = getattr(loaded, "vision_model", loaded)
        return vision_tower

    def forward(self, features: dict[str, Tensor], **kwargs: Any) -> dict[str, Tensor]:
        """Run the base forward pass and trim the token embeddings.

        Args:
            features: Feature dictionary handed to the base ``forward``.
            **kwargs: Extra keyword arguments handed to the base ``forward``.

        Returns:
            The dictionary produced by the base ``forward`` with
            ``"token_embeddings"`` sliced to skip index 0 along dimension 1.
        """
        outputs = super().forward(features, **kwargs)

        # Skip position 0 on the second axis; the original author describes
        # this as dropping the first patch token to match training-time
        # pooling.
        # NOTE(review): no other entry (e.g. an attention mask) is trimmed
        # here -- confirm downstream pooling does not need matching lengths.
        trimmed = outputs["token_embeddings"][:, 1:]
        outputs["token_embeddings"] = trimmed
        return outputs


class WhisperEncoderTransformer(Transformer):
    """:class:`Transformer` variant that decodes audio paths/URLs into waveforms."""

    def preprocess(
        self, inputs: list[Any], prompt: str | None = None, **kwargs: Any
    ) -> dict[str, Tensor]:
        """Decode string inputs with ``load_audio`` and defer to the base class.

        Args:
            inputs: Items to preprocess. Every ``str`` item is passed to
                ``transformers.audio_utils.load_audio``; anything else is
                forwarded untouched.
            prompt: Forwarded unchanged to the base ``preprocess``.
            **kwargs: Forwarded unchanged to the base ``preprocess``.

        Returns:
            Whatever the base ``preprocess`` returns for the decoded inputs.
        """
        # Imported lazily, as in the original implementation.
        from transformers.audio_utils import load_audio

        decoded: list[Any] = []
        for entry in inputs:
            if isinstance(entry, str):
                decoded.append(load_audio(entry))
            else:
                decoded.append(entry)

        return super().preprocess(decoded, prompt=prompt, **kwargs)