| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| import logging |
| import floret |
| import os |
| from huggingface_hub import hf_hub_download |
| from .configuration_lang import ImpressoConfig |
|
|
# Logger named after this module (via ``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
class LangDetectorModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a floret model.

    The floret binary named by the wrapped config is located on disk (or
    downloaded with ``hf_hub_download``) and loaded with
    ``floret.load_model``. ``forward`` delegates to the floret model's
    ``predict`` and returns its top-1 predictions and probabilities.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Build the wrapper and load the floret binary.

        Args:
            config: An ``ImpressoConfig`` whose ``config`` attribute provides
                ``filename`` (the floret binary) and ``_name_or_path`` (the
                local directory or Hub repo id it belongs to).
                NOTE(review): the ``config.config`` nesting appears to come
                from ``from_pretrained`` forwarding the caller's ``config=``
                kwarg into ``ImpressoConfig`` — confirm against that class.
        """
        super().__init__(config)
        self.config = config

        # Placeholder parameter: the floret weights are not torch tensors,
        # but ``device`` below reads ``next(self.parameters())``, which would
        # raise StopIteration on a module with no parameters.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        inner_config = self.config.config
        bin_filename = self._resolve_bin_path(
            inner_config.filename, inner_config._name_or_path
        )
        self.model_floret = floret.load_model(bin_filename)

    @staticmethod
    def _resolve_bin_path(filename, name_or_path):
        """Return a local path for the floret binary *filename*.

        Lookup order:
          1. *filename* as given (absolute, or relative to the current dir);
          2. *filename* inside *name_or_path*, when that is a local directory;
          3. *filename* downloaded from the Hub repo *name_or_path*.
        """
        if os.path.exists(filename):
            return filename
        # A model saved or cloned locally has _name_or_path pointing at a
        # directory; handing that to hf_hub_download as a repo id would fail.
        if name_or_path and os.path.isdir(name_or_path):
            local_path = os.path.join(name_or_path, filename)
            if os.path.exists(local_path):
                return local_path
        return hf_hub_download(repo_id=name_or_path, filename=filename)

    def forward(self, input_ids, **kwargs):
        """Predict the top-1 floret label for one or more texts.

        Args:
            input_ids: A single string, or a list/tuple of strings. Despite
                the name (kept for the transformers calling convention),
                these are raw texts, not token ids.
            **kwargs: Accepted for caller compatibility and ignored.

        Returns:
            ``(predictions, probabilities)`` as returned by the floret
            model's ``predict(texts, k=1)`` on a list of texts.

        Raises:
            ValueError: If ``input_ids`` is neither a string nor a
                list/tuple of strings.
        """
        if isinstance(input_ids, str):
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = list(input_ids)
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # fastText-style predict() raises ValueError on any text containing
        # "\n" (it processes one line at a time), so flatten multi-line
        # inputs into a single line first.
        texts = [t.replace("\n", " ") for t in texts]

        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return (
            predictions,
            probabilities,
        )

    @property
    def device(self):
        """Device of the module's first parameter (``dummy_param`` here)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Construct the model from keyword arguments only.

        Positional arguments (e.g. a model name/path) are ignored. Every
        keyword argument is forwarded to ``ImpressoConfig`` and the model is
        built from the result.
        NOTE(review): ``__init__`` reads ``config.config``, so callers
        presumably must pass ``config=<loaded config>`` here — confirm.
        """
        config = ImpressoConfig(**kwargs)
        return cls(config)
|
|