| from transformers import AutoTokenizer |
| from transformers.modeling_outputs import ModelOutput |
| from typing import List, Dict, Optional, Union, Tuple |
| from dataclasses import dataclass |
| import torch |
|
|
| from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification |
| from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx |
|
|
| from .configuration_gigacheck import GigaCheckConfig |
|
|
|
|
@dataclass
class GigaCheckOutput(ModelOutput):
    """Container for the results of a GigaCheck classification pass.

    Attributes:
        pred_label_ids: [Batch] tensor of predicted class indices
            (Human/AI/Mixed).
        classification_head_probs: [Batch, Num_Classes] tensor of softmax
            probabilities produced by the classification head.
    """

    pred_label_ids: Optional[torch.Tensor] = None
    classification_head_probs: Optional[torch.Tensor] = None
|
|
|
|
class GigaCheckForSequenceClassification(MistralAIDetectorForSequenceClassification):
    """Sequence-level AI-text detector (Human / AI / Mixed).

    Thin inference wrapper around the Mistral-based detector backbone: it
    owns the tokenizer, performs truncation/padding, and exposes a
    text-in, probabilities-out ``forward``.
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        # This classification-only variant never uses the DETR interval
        # head, so it is disabled unconditionally here.
        super().__init__(
            config,
            with_detr=False,
            detr_config=None,
            ce_weights=None,
            freeze_backbone=False,
            id2label=config.id2label,
        )
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Populated by ``from_pretrained``; ``forward`` cannot run without it.
        self.tokenizer = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Besides the weights, this also loads the matching tokenizer and
        synchronizes the special-token ids between tokenizer and config.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to the parent class.
            kwargs: Additional keyword arguments passed to the parent class.

        Returns:
            GigaCheckForSequenceClassification: The initialized model with
            loaded weights and an initialized tokenizer.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        # Only an explicit ``False`` stored in the config marks a checkpoint
        # whose classification head was never trained; any other value keeps
        # the default ``True`` set in ``__init__``.
        if model.config.to_dict().get("trained_classification_head", True) is False:
            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Some tokenizers ship without a pad token; fall back to the unk
        # token so batch padding works, then mirror the final id into the
        # config. (Single code path replaces the previous duplicated
        # fallback computation; end state is identical.)
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
        model.config.pad_token_id = model.tokenizer.pad_token_id

        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id

        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Tokenizes a batch of texts, tracking the character length the
        (possibly truncated) tokens actually cover.

        Args:
            texts: Raw input strings.

        Returns:
            Tuple of ``input_ids`` [Batch, Seq], ``attention_mask``
            [Batch, Seq] (both moved to ``self.device``) and a per-sample
            list of character lengths of the text that survived truncation.

        Raises:
            RuntimeError: If the model was created without going through
                ``from_pretrained`` (no tokenizer / max length available).
        """
        # ``assert`` is stripped under ``python -O``; raise explicitly so the
        # misuse is always caught.
        if self._max_len is None or self.tokenizer is None:
            raise RuntimeError("Model must be initialized")

        # Special tokens are added manually below so that truncation applies
        # to content tokens only.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []
        text_lens = []

        # Reserve two slots for BOS/EOS.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Decode the kept tokens to recover the exact character span
                # they represent, so span predictions map back onto the text.
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                text_len = len(texts[i])

            batch_features.append({"input_ids": [bos_id] + tokens + [eos_id]})
            text_lens.append(text_len)

        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt",
        )

        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)

        return input_ids, attention_mask, text_lens

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Classifies raw text as Human / AI / Mixed.

        Args:
            text: A single string or a batch of strings.
            return_dict: Whether to return a ``GigaCheckOutput`` instead of a
                plain tuple; defaults to ``self.config.use_return_dict``.

        Returns:
            ``GigaCheckOutput`` (or a ``(pred_label_ids, probs)`` tuple) with
            per-sample predicted class ids and softmax probabilities.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if isinstance(text, str):
            text = [text]

        # Character lengths are not needed for pure classification.
        input_ids, attention_mask, _ = self._get_inputs(text)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            # NOTE(review): ``__init__`` hard-codes ``with_detr=False`` while
            # this flag comes from the config — confirm they cannot disagree.
            return_detr_output=self.config.with_detr,
        )

        logits = output.logits

        # Softmax in float32 for numerical stability regardless of the
        # backbone's compute dtype.
        probs = logits.to(torch.float32).softmax(dim=-1)
        pred_label_ids = torch.argmax(probs, dim=-1)

        if not return_dict:
            return (pred_label_ids, probs)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=probs,
        )
|
|
|
|
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Converts normalized (center, width) spans into absolute (start, end)
    positions, clipped to the valid range ``[0, text_len]``.
    """
    absolute = span_cxw_to_xx(pred_spans).mul(text_len)
    return absolute.clamp(min=0, max=text_len)
|
|