File size: 2,022 Bytes
1a5fa14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations

from typing import Any

from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel

from sentence_transformers.base.modules.transformer import Transformer


class SiglipVisionTransformer(Transformer):
    """Drop-in :class:`Transformer` subclass restricted to the SigLIP vision tower.

    Compared with the base :class:`Transformer` it:

    * keeps only the image processor when the loaded processor exposes an
      ``image_processor`` attribute (so the text tokenizer is not kept);
    * returns the loaded model's ``vision_model`` submodule when it has one;
    * removes the first token from ``token_embeddings`` after the forward pass.
    """

    def __init__(self, model_name_or_path: str, **kwargs: Any) -> None:
        super().__init__(model_name_or_path, **kwargs)
        # Processors without an ``image_processor`` attribute are kept unchanged.
        if not hasattr(self.processor, "image_processor"):
            return
        # Swap the composite processor for its image part so the unused SigLIP
        # text tokenizer (~17 MB) is left out of the saved layout.
        self.processor = self.processor.image_processor

    def _load_model(
        self,
        model_name_or_path: str,
        transformer_task: str,
        config: PretrainedConfig,
        backend: str,
        is_peft_model: bool,
        **model_kwargs: Any,
    ) -> PreTrainedModel:
        """Load the model through the base class and narrow it to its vision tower.

        Returns ``vision_model`` of the loaded model when that attribute exists,
        otherwise the loaded model itself.
        """
        loaded = super()._load_model(
            model_name_or_path,
            transformer_task,
            config,
            backend,
            is_peft_model,
            **model_kwargs,
        )
        # On a fresh init the loaded model carries a ``vision_model`` tower that is
        # pulled out here; a model lacking that attribute is returned as-is.
        # NOTE(review): if a reloaded vision-only model also defines
        # ``vision_model``, it gets unwrapped too — confirm that is intended.
        try:
            return loaded.vision_model
        except AttributeError:
            return loaded

    def forward(self, features: dict[str, Tensor], **kwargs: Any) -> dict[str, Tensor]:
        """Run the base forward pass, then strip the first token of ``token_embeddings``.

        The dict returned by the base forward is updated and returned.
        """
        outputs = super().forward(features, **kwargs)
        # Training-time pooling excluded the first patch token, so do the same here.
        # NOTE(review): other per-token entries (e.g. an ``attention_mask``, if the
        # base forward emits one) are not trimmed — confirm pooling does not need
        # them aligned with ``token_embeddings``.
        token_embeddings = outputs["token_embeddings"]
        outputs["token_embeddings"] = token_embeddings[:, 1:]
        return outputs


class WhisperEncoderTransformer(Transformer):
    """Drop-in :class:`Transformer` subclass that decodes audio file paths/URLs into waveforms."""

    def preprocess(
        self, inputs: list[Any], prompt: str | None = None, **kwargs: Any
    ) -> dict[str, Tensor]:
        """Decode path/URL entries of ``inputs`` to waveforms, then defer to the base class.

        Args:
            inputs: Batch of audio inputs. ``str`` entries (file paths or URLs) and
                :class:`os.PathLike` entries (e.g. :class:`pathlib.Path`) are decoded
                with :func:`transformers.audio_utils.load_audio`; every other entry
                is passed through unchanged.
            prompt: Forwarded unchanged to the base ``preprocess``.
            **kwargs: Forwarded unchanged to the base ``preprocess``.

        Returns:
            The result of the base ``preprocess`` on the decoded batch.
        """
        import os

        # Deferred import: the audio utilities are only needed when this module is used.
        from transformers.audio_utils import load_audio

        # PathLike objects are normalized to ``str`` so they are decoded like string
        # paths instead of slipping through undecoded.
        loaded = [
            load_audio(os.fspath(item)) if isinstance(item, (str, os.PathLike)) else item
            for item in inputs
        ]
        return super().preprocess(loaded, prompt=prompt, **kwargs)