#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corporation.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn.functional as F
from typing import Tuple, Union, List

class InputSegment:
    """One segment of model input: text, audio RVQ codes, or both, rendered
    as a [1 + audio_channels, seqlen] grid of token ids."""

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        add_sosp_eosp=True,
    ) -> None:
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.add_sosp_eosp = add_sosp_eosp
    @staticmethod
    def insert_between(tensor, i, value=-1):
        # Spread a [1, n] tensor to [1, n * (i + 1)]: each original token is
        # followed by `i` copies of `value`. Static method: it is invoked as
        # self.insert_between(...), so the instance must not bind to `tensor`.
        return torch.scatter(
            torch.full(
                (1, tensor.shape[1] + (tensor.shape[1] - 1) * i + i),
                value,
                dtype=tensor.dtype,
            ),
            1,
            torch.arange(0, tensor.shape[1], dtype=torch.int64)[None] * (i + 1),
            tensor,
        )
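
    # Example: insert_between(torch.tensor([[1, 2, 3]]), 2) yields
    # tensor([[1, -1, -1, 2, -1, -1, 3, -1, -1]]): one real token at the head
    # of each group of (i + 1) positions, matching the audio frame grouping
    # used in to_input_id below.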
    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        if self.audio is None:
            # Text-only segment: tokenize (or reuse pre-tokenized) text, then
            # pad every audio channel with its zero-embedding index.
            if self.tokenized_text is None:
                tokenized_text = tokenizer(
                    self.text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=999999,
                    padding=False,
                    add_special_tokens=False,
                )["input_ids"].int()
            else:
                tokenized_text = self.tokenized_text.unsqueeze(0)
            if group_size > 1:
                tokenized_text = self.insert_between(
                    tokenized_text, group_size - 1, value=-100
                )
            if isinstance(self.speech_zeroemb_idx, list):
                # Per-channel zero-embedding indices.
                audio_part_input_id = torch.zeros(
                    (audio_channels, tokenized_text.shape[1]), dtype=torch.int
                )
                for i, idx in enumerate(self.speech_zeroemb_idx):
                    audio_part_input_id[i, :] = idx
            else:
                audio_part_input_id = torch.full(
                    (audio_channels, tokenized_text.shape[1]),
                    self.speech_zeroemb_idx,
                    dtype=torch.int,
                )
        else:
            # Audio segment: lay the flat RVQ code stream out channel-major
            # and fill the text row with the text zero-embedding token.
            sosp_token = (
                tokenizer.convert_tokens_to_ids("<|sosp|>")
                if self.add_sosp_eosp
                else None
            )
            eosp_token = (
                tokenizer.convert_tokens_to_ids("<|eosp|>")
                if self.add_sosp_eosp
                else None
            )
            audio_part = self.audio.reshape(-1, audio_channels).T  # [audio_channels, seqlen]
            assert (
                audio_part.shape[1] % group_size == 0
            ), f"Audio shape {audio_part.shape} is not divisible by group_size {group_size}"
            text_len = audio_part.shape[1] // group_size
            empty_token = self.text_zeroemb_idx
            if empty_token is None:
                empty_token = tokenizer.eod
            tokenized_text = torch.full((1, text_len), empty_token, dtype=torch.int)
            tokenized_text = (
                torch.cat(
                    [
                        torch.tensor([[sosp_token]], dtype=torch.int),
                        tokenized_text,
                        torch.tensor([[eosp_token]], dtype=torch.int),
                    ],
                    dim=1,
                )
                if self.add_sosp_eosp
                else tokenized_text
            )
            tokenized_text = self.insert_between(
                tokenized_text, group_size - 1, value=-100
            )
            if self.add_sosp_eosp:
                # Pad the <|sosp|>/<|eosp|> positions on the audio channels
                # with zero-embedding indices so the rows stay aligned.
                if isinstance(self.speech_zeroemb_idx, list):
                    sosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
                    eosp_part = torch.zeros((audio_channels, group_size), dtype=torch.int)
                    for i, idx in enumerate(self.speech_zeroemb_idx):
                        sosp_part[i, :] = idx
                        eosp_part[i, :] = idx
                    audio_part_input_id = torch.cat([sosp_part, audio_part, eosp_part], dim=1)
                else:
                    audio_part_input_id = torch.cat(
                        [
                            torch.full(
                                (audio_channels, group_size),
                                self.speech_zeroemb_idx,
                                dtype=torch.int,
                            ),
                            audio_part,
                            torch.full(
                                (audio_channels, group_size),
                                self.speech_zeroemb_idx,
                                dtype=torch.int,
                            ),
                        ],
                        dim=1,
                    )
            else:
                audio_part_input_id = audio_part

        input_ids = torch.cat(
            [tokenized_text, audio_part_input_id], dim=0
        )  # [n_rvq + 1, seqlen]
        return input_ids
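
# A minimal usage sketch (illustrative only: the checkpoint name is a
# placeholder, and the tokenizer is assumed to already carry the
# <|sosp|>/<|eosp|> special tokens):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("...")
#   seg = InputSegment(text="hello world")
#   ids = seg.to_input_id(tokenizer, group_size=4, audio_channels=8)
#   # ids: [1 + audio_channels, seqlen]; row 0 holds text tokens interleaved
#   # with -100 placeholders, the remaining rows hold zero-embedding indices.
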
class StreamingInputSegment:
    """Interleaves paired text and audio into alternating fixed-size chunks
    for streaming-style input, wrapped in <|sostm|> ... <|eostm|>."""

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        text_segment_size: int = 5,
        audio_segment_size: int = 5,
        tokenizer=None,
        group_size=None,
        audio_channels=None,
    ) -> None:
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.text_segment_size = text_segment_size
        self.audio_segment_size = audio_segment_size
        self.tokenizer = tokenizer
        self.group_size = group_size
        self.audio_channels = audio_channels
    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        if self.tokenized_text is None:
            tokenized_text = tokenizer(
                self.text,
                return_tensors="pt",
                truncation=True,
                max_length=999999,
                padding=False,
                add_special_tokens=False,
            )["input_ids"].int()  # [1, seqlen]
        else:
            tokenized_text = self.tokenized_text.unsqueeze(0)
        tokenized_text = tokenized_text.squeeze(0)

        # Chunk the text by tokens and the flat audio stream by whole groups.
        text_segments = tokenized_text.split(self.text_segment_size, dim=0)
        audio_segments = self.audio.split(
            self.audio_segment_size * group_size * audio_channels, dim=0
        )

        tokenized_segments = []
        tokenized_segments.append(
            InputSegment(
                text="<|sostm|>",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.text_zeroemb_idx,
            ),
        )

        # Append <|eot|> to the last text chunk to mark the end of the text.
        eot_tokens = tokenizer(
            "<|eot|>",
            return_tensors="pt",
            truncation=True,
            max_length=999999,
            padding=False,
            add_special_tokens=False,
        )["input_ids"][0].to(text_segments[-1])
        text_segments = text_segments[:-1] + (
            torch.cat([text_segments[-1], eot_tokens], dim=0),
        )

        # Alternate text and audio chunks while both streams have data ...
        length = min(len(text_segments), len(audio_segments))
        for i in range(length):
            text_segment = text_segments[i]
            audio_segment = audio_segments[i]
            tokenized_segments.append(
                InputSegment(
                    tokenized_text=text_segment,
                    speech_zeroemb_idx=self.speech_zeroemb_idx,
                    text_zeroemb_idx=self.text_zeroemb_idx,
                ),
            )
            tokenized_segments.append(
                InputSegment(
                    audio=audio_segment,
                    add_sosp_eosp=False,
                    speech_zeroemb_idx=self.speech_zeroemb_idx,
                    text_zeroemb_idx=self.text_zeroemb_idx,
                ),
            )
        # ... then flush whichever stream runs longer.
        for j in range(length, len(text_segments)):
            tokenized_segments.append(
                InputSegment(
                    tokenized_text=text_segments[j],
                    speech_zeroemb_idx=self.speech_zeroemb_idx,
                    text_zeroemb_idx=self.text_zeroemb_idx,
                ),
            )
        for j in range(length, len(audio_segments)):
            tokenized_segments.append(
                InputSegment(
                    audio=audio_segments[j],
                    add_sosp_eosp=False,
                    speech_zeroemb_idx=self.speech_zeroemb_idx,
                    text_zeroemb_idx=self.text_zeroemb_idx,
                ),
            )
        tokenized_segments.append(
            InputSegment(
                text="<|eostm|>",
                speech_zeroemb_idx=self.speech_zeroemb_idx,
                text_zeroemb_idx=self.text_zeroemb_idx,
            ),
        )
        # Render every chunk with the arguments passed to this call; the
        # copies stored in __init__ default to None and may be unset.
        input_ids = [
            seg.to_input_id(
                tokenizer,
                group_size,
                audio_channels,
            )
            for seg in tokenized_segments
        ]
        input_ids = torch.cat(input_ids, dim=1).type(torch.int64)  # [n_rvq + 1, seqlen]
        return input_ids
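
# Streaming usage sketch (illustrative only; the code tensor and sizes are
# made-up values, and `tokenizer` is assumed as in the sketch above):
#
#   codes = torch.randint(0, 1024, (40 * 4 * 8,))  # flat RVQ code stream
#   seg = StreamingInputSegment(
#       text="hello world",
#       audio=codes,
#       text_segment_size=5,
#       audio_segment_size=5,
#   )
#   ids = seg.to_input_id(tokenizer, group_size=4, audio_channels=8)
#   # ids alternates text chunks and audio chunks between <|sostm|> and
#   # <|eostm|>, as a [1 + audio_channels, seqlen] int64 tensor.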