| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
| import collections |
| from math import sqrt |
|
|
| import scipy.stats |
|
|
| import torch |
| from torch import Tensor |
| from tokenizers import Tokenizer |
| from transformers import LogitsProcessor |
|
|
| from nltk.util import ngrams |
|
|
| from normalizers import normalization_strategy_lookup |
|
|
class WatermarkBase:
    """Shared configuration and green-list sampling for watermarking.

    Holds the watermark hyper-parameters and derives, from the last token of
    a prefix, a pseudo-random subset of vocabulary indices (the "green list").

    Args:
        vocab: Vocabulary token ids. Only its length is used by this class.
        gamma: Fraction of the vocabulary that goes on the green list.
        delta: Bias value; stored on the instance, not used by this class.
        seeding_scheme: Name of the RNG seeding scheme. Only ``"simple_1"``
            (seed from the single previous token) is implemented.
        hash_key: Integer multiplied with the previous token id to form the
            RNG seed.
        select_green_tokens: If True the green list is the head of the
            vocabulary permutation, otherwise its tail.

    Raises:
        TypeError: If ``vocab`` is None.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",
        hash_key: int = 15485863,
        select_green_tokens: bool = True,
    ):
        # ``len(None)`` would raise TypeError anyway; raise the same type
        # with a message that actually names the problem.
        if vocab is None:
            raise TypeError("vocab must be a sequence of token ids, got None")

        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        # Created lazily in ``_seed_rng`` (or assigned by subclasses) so the
        # generator can live on the same device as the incoming token ids.
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed ``self.rng`` from the prefix ``input_ids``.

        Args:
            input_ids: 1-D tensor of prefix token ids; only the last element
                is read by the ``"simple_1"`` scheme.
            seeding_scheme: Overrides ``self.seeding_scheme`` when given.

        Raises:
            NotImplementedError: For any scheme other than ``"simple_1"``.
        """
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme != "simple_1":
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")

        assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"

        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return the green-list vocabulary indices for the given prefix.

        The RNG is re-seeded from the prefix, the vocabulary indices are
        permuted, and ``int(vocab_size * gamma)`` of them are returned as a
        tensor on ``input_ids.device``.
        """
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(
            self.vocab_size, device=input_ids.device, generator=self.rng
        )
        if self.select_green_tokens:
            # Head of the permutation.
            return vocab_permutation[:greenlist_size]
        # Tail of the permutation (legacy selection direction).
        return vocab_permutation[(self.vocab_size - greenlist_size):]
|
|
|
|
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """Logits processor that adds ``delta`` to green-list token scores.

    For every sequence in the batch the green list is derived from that
    sequence's prefix, and the scores at those vocabulary positions are
    increased by ``self.delta``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask shaped like ``scores`` marking green tokens.

        Args:
            scores: Score tensor; row ``b`` is indexed by vocabulary id.
            greenlist_token_ids: One collection of green token ids per row.

        Returns:
            Boolean tensor, True at every (row, green token id) position.
        """
        green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        for b_idx, green_ids in enumerate(greenlist_token_ids):
            green_tokens_mask[b_idx, green_ids] = True
        return green_tokens_mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add ``greenlist_bias`` to the masked entries of ``scores`` in place and return it."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Bias the scores of green-list tokens for every sequence in the batch.

        Args:
            input_ids: Token ids, iterated along the first (batch) dimension.
            scores: Next-token scores; modified in place.
                # NOTE(review): assumed (batch, vocab_size) - confirm with caller.

        Returns:
            The same ``scores`` tensor with ``self.delta`` added at green positions.
        """
        # The generator is created on first use so that it lives on the same
        # device as the incoming token ids.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        # One green list per sequence, each seeded from that sequence's prefix.
        batched_greenlist_ids = [self._get_greenlist_ids(sequence) for sequence in input_ids]

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)

        scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
        return scores
|
|
|
|
class WatermarkDetector(WatermarkBase):
    """Statistical detector for the green-list watermark.

    Counts how many scored tokens fall on the green list implied by their
    prefix and turns that count into a z-score / p-value.

    Args:
        device: Device for the RNG and the token tensors. Required.
        tokenizer: Tokenizer used at generation time. Required.
        z_threshold: Default z-score above which text is predicted as
            watermarked.
        normalizers: Names of normalization strategies resolved through
            ``normalization_strategy_lookup``. ``None`` means ``["unicode"]``.
        ignore_repeated_bigrams: If True, each distinct (prev, next) token
            bigram is scored only once.
        *args, **kwargs: Forwarded to ``WatermarkBase``.

    Raises:
        AssertionError: If ``device`` or ``tokenizer`` is missing.
        NotImplementedError: For seeding schemes other than ``"simple_1"``.
    """

    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] | None = None,
        ignore_repeated_bigrams: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        # ``None`` stands in for the former mutable default ``["unicode"]``.
        if normalizers is None:
            normalizers = ["unicode"]
        self.normalizers = [normalization_strategy_lookup(name) for name in normalizers]

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        """Return the z-score of ``observed_count`` green tokens out of ``T``.

        Under the null hypothesis each token is green with probability
        ``gamma``, so the count has mean ``gamma * T`` and variance
        ``T * gamma * (1 - gamma)``.
        """
        expected_fraction = self.gamma
        numer = observed_count - expected_fraction * T
        denom = sqrt(T * expected_fraction * (1 - expected_fraction))
        return numer / denom

    def _compute_p_value(self, z):
        """Return the one-sided p-value (normal survival function) of ``z``."""
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        """Count green tokens in ``input_ids`` and compute the requested statistics.

        Args:
            input_ids: 1-D tensor of token ids to score.
            return_*: Select which entries appear in the returned dict.

        Returns:
            Dict with any of ``num_tokens_scored``, ``num_green_tokens``,
            ``green_fraction``, ``z_score``, ``p_value``, ``green_token_mask``.

        Raises:
            ValueError: If there is nothing to score.
            AssertionError: If the mask is requested together with
                ``ignore_repeated_bigrams``.
        """
        if self.ignore_repeated_bigrams:
            # Each distinct (prev, next) pair is scored once, so a per-token
            # mask is not defined in this mode.
            assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats."
            green_token_mask = None
            unique_bigrams = set(ngrams(input_ids.cpu().tolist(), 2))
            num_tokens_scored = len(unique_bigrams)
            if num_tokens_scored < 1:
                # Previously this surfaced as a ZeroDivisionError further down.
                raise ValueError("Must have at least 1 bigram to score when ignoring repeated bigrams.")
            green_token_count = 0
            for prev_token, next_token in unique_bigrams:
                prefix = torch.tensor([prev_token], device=self.device)
                if next_token in self._get_greenlist_ids(prefix):
                    green_token_count += 1
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError(
                    "Must have at least 1 token to score after "
                    f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
                )
            # Every token after the minimum prefix is tested against the
            # green list seeded by the tokens preceding it.
            green_token_mask = []
            for idx in range(self.min_prefix_len, len(input_ids)):
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                green_token_mask.append(bool(input_ids[idx] in greenlist_ids))
            green_token_count = sum(green_token_mask)

        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score or return_p_value:
            z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            if return_z_score:
                score_dict["z_score"] = z_score
            if return_p_value:
                score_dict["p_value"] = self._compute_p_value(z_score)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        """Score a text (raw or pre-tokenized) and optionally classify it.

        Args:
            text: Raw string; normalized and tokenized here. Mutually
                exclusive with ``tokenized_text``.
            tokenized_text: Token ids as a list or 1-D tensor.
            return_prediction: Add ``prediction`` (and ``confidence`` when
                positive) to the result.
            return_scores: Include the ``_score_sequence`` statistics.
            z_threshold: Per-call override of ``self.z_threshold``.
            **kwargs: Forwarded to ``_score_sequence``.

        Returns:
            Dict with the requested scores and/or prediction.

        Raises:
            AssertionError: If both or neither of ``text`` /
                ``tokenized_text`` are given.
            ValueError: If there are too few tokens to score.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            # The prediction reads the z-score and the confidence reads the
            # p-value, so both must be computed whatever the caller asked for.
            kwargs["return_p_value"] = True
            kwargs["return_z_score"] = True

        if text is not None:
            # Normalization only makes sense for raw text.
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
        else:
            # Scoring needs a tensor on the generator's device; accept plain
            # lists of ids as the signature advertises.
            tokenized_text = torch.as_tensor(tokenized_text, device=self.device)

        # Drop a leading BOS token so it is not scored.
        if (
            self.tokenizer is not None
            and len(tokenized_text) > 0
            and tokenized_text[0] == self.tokenizer.bos_token_id
        ):
            tokenized_text = tokenized_text[1:]

        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)

        if return_prediction:
            # ``is not None`` so that an explicit threshold of 0 is honoured.
            z_threshold = z_threshold if z_threshold is not None else self.z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict
|
|
|
|