| from transformers import AutoTokenizer |
| from transformers.modeling_outputs import ModelOutput |
| from typing import List, Dict, Optional, Union, Tuple |
| from dataclasses import dataclass |
| import torch |
|
|
| from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification |
| from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx |
|
|
| from .configuration_gigacheck import GigaCheckConfig |
|
|
|
|
@dataclass
class GigaCheckOutput(ModelOutput):
    """
    Output type for the GigaCheck model.

    All fields are optional because `GigaCheckForDetection.forward` only fills the
    ones its configuration enables (e.g. `ai_intervals` requires the DETR head,
    and the classification fields require a trained classification head).

    Args:
        pred_label_ids (torch.Tensor): [Batch] Indices of the predicted classes (Human/AI/Mixed).
        classification_head_probs (torch.Tensor): [Batch, Num_Classes] Softmax probabilities.
        ai_intervals (List[torch.Tensor]): List of length Batch. Each element is a tensor of shape [Num_Intervals, 3]
            containing (start, end, score) for detected AI-generated spans, filtered by the
            confidence threshold.
    """
    # [Batch] argmax class ids; None when no trained classification head is available.
    pred_label_ids: Optional[torch.Tensor] = None
    # [Batch, Num_Classes] float32 softmax probabilities; None in the same case as above.
    classification_head_probs: Optional[torch.Tensor] = None
    # Per-text (start, end, score) interval tensors; None when the DETR head is disabled.
    ai_intervals: Optional[List[torch.Tensor]] = None
|
|
|
|
class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
    """Mistral-based AI-text detector.

    Classifies whole texts (Human/AI/Mixed) via a classification head and, when the
    DETR head is enabled in the config, additionally localizes AI-generated spans as
    absolute character intervals inside each input text.
    """

    config_class = GigaCheckConfig

    def __init__(self, config: GigaCheckConfig):
        """Build the detector from a `GigaCheckConfig`.

        Args:
            config (GigaCheckConfig): Model configuration; `with_detr`, `detr_config`,
                `id2label`, `max_length` and `conf_interval_thresh` are consumed here.
        """
        super().__init__(
            config,
            with_detr = config.with_detr,
            detr_config = config.detr_config,
            # Inference-only wrapper: no class weights and no backbone freezing needed.
            ce_weights = None,
            freeze_backbone = False,
            id2label = config.id2label,
        )
        # Assume a trained classification head by default; from_pretrained() flips this
        # to False when the checkpoint's config says so.
        self.trained_classification_head = True
        self._max_len = self.config.max_length
        # Tokenizer is attached later, in from_pretrained(); _get_inputs asserts on it.
        self.tokenizer = None
        # Default confidence threshold for keeping DETR-predicted intervals.
        self.conf_interval_thresh = config.conf_interval_thresh

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.

        Besides loading weights, this also: optionally casts the backbone and
        classification head to bfloat16 (per `detr_config["extractor_dtype"]`),
        records whether the classification head was trained, loads the matching
        tokenizer, and mirrors the tokenizer's special-token ids onto the config.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            model_args: Additional positional arguments passed to parent class.
            kwargs: Additional keyword arguments passed to parent class.

        Returns:
            GigaCheckForDetection: The initialized model with loaded weights and initialized tokenizer.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

        if model.config.with_detr:
            # Resolve the dtype name from the config (e.g. "bfloat16") to a torch dtype.
            extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
            print(f"Using dtype={extractor_dtype} for {type(model.model)}")
            if extractor_dtype == torch.bfloat16:
                # Only the backbone and classification head are cast; the DETR head is
                # deliberately left in its loaded dtype.
                model.model.to(torch.bfloat16)
                model.classification_head.to(torch.bfloat16)

        # Checkpoints may mark the classification head as untrained; a missing flag
        # defaults to True (trained).
        if model.config.to_dict().get("trained_classification_head", True) is False:
            model.trained_classification_head = False

        model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

        # Fall back to the UNK token for padding when the tokenizer has no pad token,
        # both on the config and on the tokenizer itself (needed by tokenizer.pad()).
        model.config.pad_token_id = model.tokenizer.pad_token_id \
            if model.tokenizer.pad_token_id is not None else model.tokenizer.unk_token_id
        if model.tokenizer.pad_token_id is None:
            model.tokenizer.pad_token_id = model.tokenizer.unk_token_id

        # Keep config special-token ids consistent with the loaded tokenizer.
        model.config.bos_token_id = model.tokenizer.bos_token_id
        model.config.eos_token_id = model.tokenizer.eos_token_id
        model.config.unk_token_id = model.tokenizer.unk_token_id

        return model

    def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """
        Tokenizes a batch of texts handling specific truncation logic to preserve exact text length mapping.

        Tokenization is done without special tokens so that BOS/EOS can be added
        manually after truncation; for truncated texts the character length is
        re-derived by decoding the kept tokens, so interval predictions can later be
        mapped back to character positions of the (possibly truncated) text.

        Args:
            texts (List[str]): Raw input strings.

        Returns:
            Tuple of:
                input_ids (torch.Tensor): [Batch, Seq] padded token ids on `self.device`.
                attention_mask (torch.Tensor): [Batch, Seq] padding mask on `self.device`.
                text_lens (List[int]): Character length of each (possibly truncated) text.
        """
        assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"

        # No special tokens here: BOS/EOS are appended manually below, after truncation.
        raw_encodings = self.tokenizer(texts, add_special_tokens=False)

        batch_features = []
        text_lens = []

        # Reserve two positions for the manually added BOS and EOS tokens.
        content_max_len = self._max_len - 2
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        for i, tokens in enumerate(raw_encodings.input_ids):
            if len(tokens) > content_max_len:
                tokens = tokens[:content_max_len]
                # Decode the kept tokens to learn the character length that actually
                # survived truncation.
                cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
                text_len = len(cur_text)
            else:
                # Untruncated: use the original text's character length directly.
                text_len = len(texts[i])

            final_tokens = [bos_id] + tokens + [eos_id]

            batch_features.append({"input_ids": final_tokens})
            text_lens.append(text_len)

        # Pad to the longest sequence in the batch (uses the pad token set in
        # from_pretrained) and build the attention mask.
        padded_output = self.tokenizer.pad(
            batch_features,
            padding=True,
            return_tensors="pt"
        )

        input_ids = padded_output["input_ids"].to(self.device)
        attention_mask = padded_output["attention_mask"].to(self.device)

        return input_ids, attention_mask, text_lens

    @staticmethod
    def _get_ai_intervals(detr_out: Dict[str, torch.Tensor], text_lens: List[int], conf_interval_thresh: float) -> List[torch.Tensor]:
        """
        Converts DETR outputs to absolute text intervals.

        Args:
            detr_out: DETR head output; uses "pred_spans" (normalized spans, presumably
                (center, width) per query given the span_cxw_to_xx conversion — TODO confirm)
                and "pred_logits" ([Batch, Queries, Classes] query classification logits).
            text_lens: Character length of each text; spans are scaled by these.
            conf_interval_thresh: Minimum score for an interval to be kept.

        Returns:
            List (length Batch) of [Num_Kept, 3] tensors with (start, end, score) rows.
        """
        pred_spans = detr_out["pred_spans"]
        src_logits = detr_out["pred_logits"]
        assert len(text_lens) == pred_spans.shape[0]

        # Keep only class index 0's probability, with the trailing dim preserved for
        # concatenation below. NOTE(review): index 0 is treated as the positive
        # (AI-span) class — confirm against the DETR head's label layout.
        pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]

        final_preds_batch = []

        for i, length in enumerate(text_lens):
            # Scale normalized spans to absolute character positions for this text.
            spans_abs = to_absolute(pred_spans[i], length)

            # [Num_Queries, 3]: (start, end, score) per query.
            scores = pred_probs[i]
            preds_i = torch.cat([spans_abs, scores], dim=1)

            # Drop low-confidence intervals.
            mask = preds_i[:, 2] > conf_interval_thresh
            filtered_preds = preds_i[mask]

            final_preds_batch.append(filtered_preds)

        return final_preds_batch

    def forward(
        self,
        text: Union[str, List[str]],
        return_dict: Optional[bool] = None,
        conf_interval_thresh: Optional[float] = None,
    ) -> Union[Tuple, GigaCheckOutput]:
        """Run detection on raw text(s).

        Args:
            text: A single string or a batch of strings.
            return_dict: Return a `GigaCheckOutput` instead of a tuple; defaults to
                `config.use_return_dict`.
            conf_interval_thresh: Per-call override of the interval confidence
                threshold; defaults to `config.conf_interval_thresh`.

        Returns:
            `GigaCheckOutput` or a `(pred_label_ids, classification_head_probs,
            ai_intervals)` tuple; entries are None when the corresponding head is
            disabled/untrained.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        conf_interval_thresh = conf_interval_thresh if conf_interval_thresh is not None else self.config.conf_interval_thresh

        # Normalize a single string to a batch of one.
        if isinstance(text, str):
            text = [text]

        input_ids, attention_mask, text_lens = self._get_inputs(text)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_detr_output=self.config.with_detr,
        )

        pred_label_ids = None
        classification_head_probs = None
        ai_intervals = None

        # With the DETR head on, `output.logits` is a (cls_logits, detr_out) pair;
        # otherwise it is the classification logits tensor itself.
        if not self.config.with_detr:
            logits = output.logits
        elif self.trained_classification_head:
            logits, _ = output.logits
        else:
            # Untrained classification head: its logits are meaningless, skip them.
            logits = None

        if logits is not None:
            # Softmax in float32 for numerical stability (backbone may be bfloat16).
            probs = logits.to(torch.float32).softmax(dim=-1)
            pred_label_ids = torch.argmax(probs, dim=-1)
            classification_head_probs = probs

        if self.config.with_detr:
            _, detr_out = output.logits
            ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)

        if not return_dict:
            return (pred_label_ids, classification_head_probs, ai_intervals)

        return GigaCheckOutput(
            pred_label_ids=pred_label_ids,
            classification_head_probs=classification_head_probs,
            ai_intervals=ai_intervals,
        )
|
|
|
|
def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
    """Scale normalized spans to absolute character positions within a text.

    Converts the spans from (center, width) form to (start, end) form via
    `span_cxw_to_xx`, scales them by the text's character length, and clips the
    result into the valid range [0, text_len].

    Args:
        pred_spans (torch.Tensor): Normalized span predictions.
        text_len (int): Character length of the target text.

    Returns:
        torch.Tensor: Absolute (start, end) spans, clamped to [0, text_len].
    """
    scaled_spans = span_cxw_to_xx(pred_spans) * text_len
    return scaled_spans.clamp(min=0, max=text_len)
|
|