| from __future__ import annotations |
|
|
| import inspect |
| import logging |
| import os |
| from collections.abc import Callable |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Any |
|
|
| from sentence_transformers.backend import load_onnx_model, load_openvino_model |
|
|
| try: |
| from typing import Self |
| except ImportError: |
| from typing_extensions import Self |
|
|
| import torch |
| from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config |
| from transformers.utils.import_utils import is_peft_available |
| from transformers.utils.peft_utils import find_adapter_config_file |
|
|
| from sentence_transformers.models.InputModule import InputModule |
|
|
| logger = logging.getLogger(__name__) |
|
|
| if TYPE_CHECKING and is_peft_available(): |
| from peft import PeftConfig |
| from sentence_transformers.models import Transformer |
|
|
|
|
class C2LLMTransformer(Transformer):
    # Name of the JSON file that stores this module's configuration.
    config_file_name: str = "sentence_bert_config.json"
    # Attribute names serialized into / restored from the config file.
    config_keys: list[str] = ["max_seq_length", "do_lower_case"]
    # Persist this module's files in the model's root directory.
    save_in_root: bool = True

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Run the wrapped model and attach its sentence embedding to ``features``.

        Args:
            features: Tokenized input tensors. Only the keys listed in
                ``self.model_forward_params`` are forwarded to the wrapped model.
            **kwargs: Extra keyword arguments passed through to the wrapped model.

        Returns:
            The same ``features`` dict, augmented with a ``"sentence_embedding"``
            entry taken directly from the wrapped model's output.
        """
        # Keep only the inputs the wrapped model's forward pass accepts.
        model_inputs: dict[str, torch.Tensor] = {}
        for name, tensor in features.items():
            if name in self.model_forward_params:
                model_inputs[name] = tensor

        # return_dict=True so the output is addressable by key.
        model_output = self.auto_model(**model_inputs, **kwargs, return_dict=True)
        features["sentence_embedding"] = model_output["sentence_embedding"]
        return features