| from transformers import BatchEncoding, Pipeline |
| import torch |
| from typing import Any, Generator |
|
|
class TextDiffusionPipeline(Pipeline):
    """HuggingFace Pipeline wrapping an iterative masked-diffusion text model."""

    def _sanitize_parameters(
        self,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None,
        **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """
        Split pipeline call kwargs into (preprocess, forward, postprocess) dicts.

        Args:
            num_steps: Number of diffusion denoising steps to run in `_forward`.
            allow_edits: Whether already-generated tokens may be re-sampled.
            use_confidence: Whether to unmask by sampled-token confidence.
            stop_token: Optional token string at which generation stops early.
                NOTE: the previous annotation (`stop_token: None = None`) typed
                this as only-None, contradicting callers that pass a string.
            **kwargs: May contain `max_length`, forwarded to `preprocess`.

        Returns:
            Tuple of kwargs for preprocess, _forward, and postprocess.
        """
        forward_kwargs = {
            "num_steps": num_steps,
            "allow_edits": allow_edits,
            "use_confidence": use_confidence,
            "stop_token": stop_token
        }

        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]

        return preprocess_kwargs, forward_kwargs, {}
| |
| def preprocess(self, input_text, max_length=None) -> BatchEncoding | Any: |
| if self.tokenizer is None: |
| raise ValueError("Tokenizer was not passed to the pipeline!") |
| |
| if max_length is None: |
| |
| max_length = getattr(self.model.config, "seq_length", 512) |
| |
| if input_text is None: |
| input_text = "" |
| |
| tokenized_text = self.tokenizer.encode(input_text) |
| |
| if len(tokenized_text) < max_length: |
| input_ids = torch.full((1, max_length), self.tokenizer.mask_token_id, dtype=torch.long) |
| input_ids[0, :len(tokenized_text)] = torch.tensor(tokenized_text, dtype=torch.long) |
|
|
| return BatchEncoding({ |
| "input_ids": input_ids, |
| "attention_mask": torch.ones_like(input_ids) |
| }) |
|
|
| return self.tokenizer( |
| input_text, |
| return_tensors="pt", |
| padding="max_length", |
| max_length=max_length, |
| truncation=True, |
| ) |
| |
    @torch.no_grad()
    def diffusion_generator(
        self,
        input_ids: torch.Tensor,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False
    ) -> Generator[torch.Tensor, None, None]:
        """
        Run the iterative unmasking (diffusion) loop over a fixed-length sequence.

        Yields the initial state once, then the full sequence state after each
        of the `num_steps` diffusion steps (clones, so callers may keep them).
        Only positions that start out as MASK or PAD tokens are ever rewritten;
        the prompt portion of `input_ids` is frozen for the whole run.

        Args:
            input_ids: (1, seq_len) long tensor of token ids containing the
                prompt plus MASK/PAD fill positions.
                NOTE(review): batch size > 1 looks supported by the masking
                math, but all callers in this file pass a single row — confirm
                before relying on it.
            num_steps: Number of denoising steps.
            allow_edits: If True, previously generated tokens may be re-sampled
                with a probability that decays linearly over the steps.
            use_confidence: If True, unmask the top-k most confident sampled
                tokens per step instead of unmasking at random.

        Yields:
            torch.Tensor: Clone of the current sequence state.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        current_state: torch.Tensor = input_ids.clone()
        yield current_state.clone()

        # Positions eligible for generation: everything that starts as MASK or
        # PAD. Tokens outside this mask (the prompt) are never modified.
        # Assumes the tokenizer defines mask_token_id and pad_token_id — TODO confirm.
        initial_mask = (current_state == self.tokenizer.mask_token_id) | \
            (current_state == self.tokenizer.pad_token_id)

        for step in range(num_steps):
            # Linear time schedule from t=1 down to t=0.
            t_current = 1 - step / num_steps
            t_next = 1 - (step + 1) / num_steps

            output = self.model(input_ids=current_state)
            logits = output.logits

            # Forbid sampling the MASK token itself by clamping its logit to
            # the smallest representable value for the logits dtype.
            logits[:, :, self.tokenizer.mask_token_id] = torch.finfo(logits.dtype).min

            # Draw a candidate token for every position.
            probs = torch.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            sampled_ids = dist.sample()

            # Fraction of the still-masked positions to reveal this step; the
            # final step reveals everything that remains.
            if step < num_steps - 1:
                unmasking_prob = (t_current - t_next) / t_current
            else:
                unmasking_prob = 1.0

            # Positions that are still MASK/PAD at this step.
            remasking_mask: torch.Tensor = (current_state == self.tokenizer.mask_token_id) | \
                (current_state == self.tokenizer.pad_token_id)

            if use_confidence:
                # Probability the model assigned to each sampled token — used
                # as a per-position confidence score.
                sample_probs = probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)

                # How many positions to reveal per row this step (ceil keeps
                # progress even for small counts); reveal all on the last step.
                if step < num_steps - 1:
                    num_masked = remasking_mask.sum(dim=1, keepdim=True)
                    num_to_unmask = (num_masked.float() * unmasking_prob).ceil().long()
                else:
                    num_to_unmask = remasking_mask.sum(dim=1, keepdim=True)

                # Exclude already-revealed positions from the top-k selection.
                candidate_confidences = sample_probs.clone()
                candidate_confidences[~remasking_mask] = -float('inf')

                unmasking_mask = torch.zeros_like(remasking_mask, dtype=torch.bool)

                # Take top-k over the batch max, then trim each row down to its
                # own num_to_unmask via the range comparison before scattering.
                max_k = num_to_unmask.max().item()
                if max_k > 0:
                    _, top_indices = candidate_confidences.topk(k=max_k, dim=1)
                    range_tensor = torch.arange(max_k, device=current_state.device).unsqueeze(0)
                    mask_k = range_tensor < num_to_unmask
                    unmasking_mask.scatter_(1, top_indices, mask_k)

            else:
                # Random unmasking: reveal each masked position independently
                # with probability unmasking_prob.
                unmasking_mask = torch.rand_like(current_state, dtype=torch.float) < unmasking_prob

            # Only reveal positions that are still masked AND were fill targets
            # to begin with (keeps the prompt untouched).
            update_mask = unmasking_mask & remasking_mask & initial_mask

            if allow_edits:
                # Edit rate decays linearly from 0.1 to ~0 across the run.
                alpha_t = 0.1 * (1 - step / num_steps)

                edit_mask = torch.rand_like(current_state, dtype=torch.float) < alpha_t

                # Edits apply only to tokens this generator already produced:
                # visible (not MASK/PAD/EOS) yet inside the initial fill region.
                is_visible = (current_state != self.tokenizer.mask_token_id) & \
                    (current_state != self.tokenizer.pad_token_id) & \
                    (current_state != self.tokenizer.eos_token_id)
                edit_mask = is_visible & edit_mask & initial_mask

                update_mask = update_mask | edit_mask

            # Commit the sampled tokens at the selected positions.
            current_state[update_mask] = sampled_ids[update_mask]

            yield current_state.clone()
| |
| @torch.no_grad() |
| def _forward( |
| self, |
| model_inputs: torch.Tensor, |
| num_steps: int = 50, |
| allow_edits: bool = True, |
| use_confidence: bool = False, |
| stop_token: None = None |
| ) -> dict[str, Any]: |
| if self.tokenizer is None: |
| raise ValueError("Tokenizer was not passed to the pipeline!") |
| |
| input_ids = model_inputs["input_ids"] |
| all_states = list(self.diffusion_generator(input_ids=input_ids, num_steps=num_steps, allow_edits=allow_edits, use_confidence=use_confidence)) |
| final_state = all_states[-1] |
| |
| return {"final_state": final_state, "history": all_states} |
| |
| @torch.no_grad() |
| def stream_generation( |
| self, |
| input_text: str, |
| num_steps: int = 50, |
| allow_edits: bool = True, |
| use_confidence: bool = False, |
| max_length: int | None = None, |
| stop_token: str | None = None |
| ) -> Generator[str, None, None]: |
| """ |
| Public method to stream text generation step-by-step. |
| """ |
| |
| inputs = self.preprocess(input_text, max_length) |
| input_ids = inputs["input_ids"].to(self.model.device) |
| |
| |
| for step_tensor in self.diffusion_generator(input_ids=input_ids, num_steps=num_steps, allow_edits=allow_edits, use_confidence=use_confidence): |
| |
| text = self.tokenizer.decode(step_tensor[0], skip_special_tokens=False) |
| yield text |
| |
| if stop_token is not None and stop_token in text[len(input_text):]: |
| text = input_text + text[len(input_text):].split(stop_token)[0] |
| yield text |
| |
| def postprocess(self, model_outputs) -> list[str] | Any: |
| if self.tokenizer is None: |
| raise ValueError("Tokenizer was not passed to the pipeline!") |
| |
| |
| final_ids = model_outputs["final_state"] |
| return { |
| "decoded_texts": self.tokenizer.batch_decode(final_ids, skip_special_tokens=False), |
| "history": model_outputs["history"], |
| "final_ids": final_ids |
| } |
| |
| @torch.no_grad() |
| def block_diffusion_generator( |
| self, input_ids: torch.Tensor, |
| block_size: int, |
| max_length: int, |
| num_steps: int, |
| allow_edits: bool = True, |
| use_confidence: bool = False, |
| stop_token: str | None = None |
| ) -> Generator[torch.Tensor, None, None]: |
| """ |
| Generator that yields the diffusion states block-by-block. |
| Args: |
| input_ids (torch.Tensor): Initial input IDs with context. |
| block_size (int): Number of tokens to generate in each block. |
| max_length (int): Max length of the generated text. |
| num_steps (int): Number of diffusion steps per block. |
| allow_edits (bool): Whether to allow edits to existing tokens. |
| use_confidence (bool): Whether to use confidence-based unmasking. |
| stop_token (str | None): Token at which to stop generation early. |
| Yields: |
| torch.Tensor: The current state of the full sequence after each diffusion step. |
| """ |
| assert num_steps > 0, "num_steps must be greater than 0" |
| if self.tokenizer is None: |
| raise ValueError("Tokenizer was not passed to the pipeline!") |
| |
| max_seq_length = self.model.config.seq_length if hasattr(self.model.config, "seq_length") else 512 |
| stop_token_id = self.tokenizer.convert_tokens_to_ids(stop_token) if stop_token is not None else None |
| |
| assert block_size > 0 and block_size <= max_seq_length, f"block_size must be in (0, {max_seq_length}]" |
| |
| full_sequence = input_ids.clone() |
| current_length = input_ids.shape[1] |
| while current_length < max_length: |
| remaining = max_length - current_length |
| this_block_len = min(block_size, remaining) |
| if this_block_len <= 0: break |
| |
| |
| mask_block = torch.full( |
| (1, this_block_len), |
| self.tokenizer.mask_token_id, |
| dtype=torch.long, |
| device=self.model.device |
| ) |
| |
| |
| input_ids = torch.cat([full_sequence[:, -(max_seq_length - this_block_len):], mask_block], dim=1) |
| |
| for step_tensor in self.diffusion_generator( |
| input_ids, |
| num_steps=num_steps, |
| allow_edits=allow_edits, |
| use_confidence=use_confidence |
| ): |
| current_generated_tokens = step_tensor[:, -this_block_len:] |
| yield torch.cat([full_sequence, current_generated_tokens], dim=1) |
| |
| |
| if stop_token_id is not None and stop_token_id in current_generated_tokens: |
| |
| eos_index = (current_generated_tokens == stop_token_id).nonzero(as_tuple=True)[1] |
| current_generated_tokens = current_generated_tokens[:, :eos_index[0]] |
| yield torch.cat([full_sequence, current_generated_tokens], dim=1) |
| break |
|
|
| |
| full_sequence = torch.cat([full_sequence, current_generated_tokens], dim=1) |
| current_length = full_sequence.shape[1] |
| |
| |
| @torch.no_grad() |
| def semi_autoregressive_generate( |
| self, |
| input_text: str, |
| block_size: int = 64, |
| max_length: int = 256, |
| num_steps: int = 50, |
| allow_edits: bool = True, |
| use_confidence: bool = False |
| ) -> dict[str, Any]: |
| """ |
| Semi-Autoregressive Generation: |
| Generates text in blocks using the diffusion model. |
| Each block is generated by appending MASK tokens to the current context |
| and running the diffusion process on the combined sequence. |
| Args: |
| input_text (str): The initial prompt text. |
| block_size (int): Number of tokens to generate in each block. |
| max_length (int): Max length of the generated text. |
| num_steps (int): Number of diffusion steps per block. |
| allow_edits (bool): Whether to allow edits to existing tokens. |
| use_confidence (bool): Whether to use confidence-based unmasking. |
| Returns: |
| dict[str, Any]: A dictionary containing the decoded texts, generation history, and final token IDs. |
| """ |
| if self.tokenizer is None: raise ValueError("No tokenizer") |
| |
| input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) |
| all_states = list(self.block_diffusion_generator(input_ids, block_size, max_length, num_steps, allow_edits, use_confidence=use_confidence)) |
| final_state = all_states[-1] |
| return { |
| "decoded_texts": self.tokenizer.batch_decode(final_state, skip_special_tokens=False), |
| "history": all_states, |
| "final_ids": final_state |
| } |
| |
| @torch.no_grad() |
| def stream_semi_autoregressive_generate( |
| self, |
| input_text: str, |
| block_size: int = 64, |
| max_length: int = 256, |
| num_steps: int = 50, |
| allow_edits: bool = True, |
| use_confidence: bool = False, |
| stop_token: str | None = None |
| ) -> Generator[str, None, None]: |
| """ |
| Streams the generation process block-by-block. |
| Yields the full decoded text at every diffusion step of every block. |
| Args: |
| input_text (str): The initial prompt text. |
| block_size (int): Number of tokens to generate in each block. |
| max_length (int): Max length of the generated text. |
| num_steps (int): Number of diffusion steps per block. |
| allow_edits (bool): Whether to allow edits to existing tokens. |
| use_confidence (bool): Whether to use confidence-based unmasking. |
| stop_token (None): Token at which to stop generation early. |
| Yields: |
| str: The current generated text after each diffusion step. |
| """ |
| if self.tokenizer is None: raise ValueError("No tokenizer") |
| |
| input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) |
| |
| for step_tensor in self.block_diffusion_generator(input_ids, block_size, max_length, num_steps, allow_edits, use_confidence=use_confidence, stop_token=stop_token): |
| |
| yield self.tokenizer.decode(step_tensor[0], skip_special_tokens=False) |