| """ |
| Processor class for Speech Granite. |
| """ |
|
|
| from collections.abc import Sequence |
| from typing import List, Union |
|
|
| import numpy as np |
| import torch |
|
|
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.tokenization_utils import PreTokenizedInput, TextInput |
| from transformers.utils import logging |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
| |
| |
| import math |
| from typing import List, Optional |
|
|
| from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin |
| from transformers.utils import is_torch_available, is_torchaudio_available, logging |
|
|
| if is_torch_available(): |
| import torch |
|
|
| if is_torchaudio_available(): |
| import torchaudio |
|
|
|
|
class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate=16000,
        n_fft=512,
        win_length=400,
        hop_length=160,
        n_mels=80,
        projector_window_size=15,
        projector_downsample_rate=5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.melspec_kwargs = {
            "sample_rate": sampling_rate,
            "n_fft": n_fft,
            "win_length": win_length,
            "hop_length": hop_length,
            "n_mels": n_mels,
        }
        # The mel spectrogram transform is created lazily on first use so that the
        # feature extractor can be instantiated and serialized without building the
        # torchaudio transform up front.
        self.melspec = None
        self.projector_window_size = projector_window_size
        self.projector_downsample_rate = projector_downsample_rate

    def _ensure_melspec_transform_is_initialized(self):
        if self.melspec is None:
            self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)

    def __call__(
        self,
        x: torch.Tensor,
        device: Optional[str] = "cpu",
    ) -> torch.Tensor:
        # Compute stacked log-mel features for a batch of equal-length waveforms of
        # shape (batch, num_samples), optionally on an accelerator device.
        self._ensure_melspec_transform_is_initialized()
        if device is not None:
            melspec = self.melspec.to(device)
            x = x.to(device)
        else:
            melspec = self.melspec

        B, _ = x.shape
        with torch.no_grad():
            mel = melspec(x.float())
            # Log compression; limit the dynamic range to 8 (log10 units) below the
            # per-sample maximum, then rescale.
            logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
            mx = logmel.amax(dim=(-2, -1), keepdim=True)
            logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
            # Stack every two consecutive frames, dropping an odd trailing frame.
            if logmel.shape[1] % 2 == 1:
                logmel = logmel[:, :-1]
            x = logmel.reshape(B, -1, 2 * logmel.shape[-1])

        # Features are always handed back on CPU so downstream collation is device-agnostic.
        if x.device.type != "cpu":
            return x.detach().cpu()
        return x

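    # With the default settings, a one-second 16 kHz waveform of shape (1, 16000) yields
    # 16000 // 160 + 1 = 101 mel frames (torchaudio's centered STFT); the odd trailing frame
    # is dropped and consecutive frame pairs are stacked, so the returned features have
    # shape (1, 50, 160), i.e. (batch, num_frames // 2, 2 * n_mels).
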
    def _get_num_audio_features(self, audio_lengths: List[int]) -> List[int]:
        """
        Gets the (variable length) number of features (i.e., projector output) for
        the sequences being considered.
        """
        hop_length = self.melspec_kwargs["hop_length"]
        effective_window_size = self.projector_window_size // self.projector_downsample_rate

        projector_lengths = []
        for raw_length in audio_lengths:
            # Mel sequence length computation
            mel_length = raw_length // hop_length + 1
            # Each encoder frame consumes two mel frames
            encoder_length = mel_length // 2
            nblocks = math.ceil(encoder_length / self.projector_window_size)
            # Projector output length
            projector_length = nblocks * effective_window_size
            projector_lengths.append(projector_length)

        return projector_lengths


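# A worked example of GraniteSpeechFeatureExtractor._get_num_audio_features: for a one-second
# clip at 16 kHz (raw_length = 16000) with the default settings, mel_length = 16000 // 160 + 1
# = 101, encoder_length = 101 // 2 = 50, nblocks = ceil(50 / 15) = 4, and the effective window
# size is 15 // 5 = 3, so the projector emits 4 * 3 = 12 features for that clip.
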
# Expose the feature extractor on the top-level transformers namespace so that
# ProcessorMixin can resolve the `feature_extractor_class` name declared below.
import transformers

transformers.GraniteSpeechFeatureExtractor = GraniteSpeechFeatureExtractor


class GraniteSpeechProcessor(ProcessorMixin):
    attributes = ["feature_extractor", "tokenizer"]
    valid_kwargs = ["audio_token"]

    feature_extractor_class = "GraniteSpeechFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        audio_token="<|audio|>",
    ):
        self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
        super().__init__(feature_extractor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        audios: Union[torch.Tensor, List[torch.Tensor]] = None,
        device: str = "cpu",
        **kwargs,
    ) -> BatchFeature:
        speech_inputs = {}
        text_inputs = {}

        text = self._get_validated_text(text)
        expected_num_audios = sum(t.count(self.audio_token) for t in text)

        if audios is not None:
            audios, audio_lengths = self._get_validated_audios(audios)
            if any(t.count(self.audio_token) != 1 for t in text):
                raise ValueError("Only one audio sample is currently supported per input")
            if len(audio_lengths) != expected_num_audios:
                raise ValueError("Text/Audio mismatch. The number of audios and audio tokens do not match")

            # Calculate the mel features and the number of projector features per audio.
            speech_inputs["input_features"] = self.feature_extractor(
                audios,
                device=device,
            )
            num_audio_features = self.feature_extractor._get_num_audio_features(audio_lengths)
            # Mask marking which padded feature positions are real, e.g. for lengths [12, 7],
            # row 0 has 12 True values and row 1 has 7 True values followed by 5 False values.
            speech_inputs["input_features_mask"] = torch.arange(max(num_audio_features)).view(1, -1) < torch.tensor(
                num_audio_features
            ).view(-1, 1)

            # Duplicate the audio placeholders to match the number of audio features.
            text = self._expand_audio_placeholders(text, num_audio_features)
        else:
            assert expected_num_audios == 0, "No audio is provided, expecting no audio tokens"

        text_inputs = self.tokenizer(text, padding=True, **kwargs)
        return BatchFeature(data={**text_inputs, **speech_inputs})

    def _expand_audio_placeholders(self, text: List[str], num_audio_features: List[int]):
        """
        Expands audio placeholders in the formatted text to match the number of
        features of the corresponding embeddings; we can use the resulting text
        to conveniently mask the audio features into the text embeddings.
        """
        prompt_strings = []
        num_replaced = 0
        for sample in text:
            while self.audio_token in sample:
                sample = sample.replace(
                    self.audio_token,
                    "<placeholder>" * num_audio_features[num_replaced],
                    1,
                )
                num_replaced += 1
            prompt_strings.append(sample)

        # Swap the temporary placeholders back to audio tokens; expanding through an
        # intermediate placeholder avoids re-matching tokens that were just inserted.
        prompt_strings = [sample.replace("<placeholder>", self.audio_token) for sample in prompt_strings]
        return prompt_strings

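    # For example, with audio_token "<|audio|>" and num_audio_features == [3], the prompt
    # "<|audio|> can you transcribe this?" expands to
    # "<|audio|><|audio|><|audio|> can you transcribe this?", so one text position lines up
    # with each projected audio feature.
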
    def _get_validated_text(self, text: Union[str, list]) -> List[str]:
        if isinstance(text, str):
            return [text]
        elif isinstance(text, list) and isinstance(text[0], str):
            return text
        raise TypeError("Invalid text provided! Text should be a string or list of strings.")

    def _get_validated_audios(self, audios):
        # Coerce numpy arrays (or lists of numpy arrays) to torch tensors first.
        if isinstance(audios, np.ndarray):
            audios = torch.from_numpy(audios)
        elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray):
            audios = [torch.from_numpy(arr) for arr in audios]

        if isinstance(audios, torch.Tensor):
            if audios.ndim == 1:
                audios = audios.unsqueeze(0)
            if not torch.is_floating_point(audios):
                raise ValueError("Invalid audio provided. Audio should be a floating point tensor")

            if audios.shape[0] > 1:
                logger.warning("Audio samples are already collated; assuming they all have the same length")
            lengths = [audios.shape[-1]] * audios.shape[0]
            return audios, lengths

        elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor):
            if not torch.is_floating_point(audios[0]):
                raise ValueError("Invalid audio provided. Audio should be a floating point tensor")
            lengths = [audio.shape[-1] for audio in audios]
            padding = [max(lengths) - length for length in lengths]
            # Right-pad each waveform to the longest sample and stack into a single batch.
            audios = [audio.view(1, -1) for audio in audios]
            padded = [torch.nn.functional.pad(audio, (0, pad)) for audio, pad in zip(audios, padding)]
            audios = torch.cat(padded, dim=0)
            return audios, lengths

        raise TypeError("Invalid audio provided. Audio should be one or more torch tensors or numpy arrays")


__all__ = ["GraniteSpeechFeatureExtractor", "GraniteSpeechProcessor"]

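# Minimal usage sketch, assuming a Granite Speech checkpoint such as
# "ibm-granite/granite-speech-3.3-8b" is available and its tokenizer treats "<|audio|>" as a
# single token; adjust the checkpoint name and prompt to your setup.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-speech-3.3-8b")
    processor = GraniteSpeechProcessor(GraniteSpeechFeatureExtractor(), tokenizer)

    # One second of silence at 16 kHz stands in for a real waveform.
    waveform = torch.zeros(1, 16000)
    batch = processor("<|audio|> can you transcribe this?", audios=waveform, return_tensors="pt")
    print(batch["input_features"].shape)       # (1, 50, 160): (batch, frames, 2 * n_mels)
    print(batch["input_features_mask"].shape)  # (1, 12): one slot per projector feature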