MiMo-V2.5-ASR / src /mimo_audio /process_speechdata.py
MarkDaniel212's picture
Initial Docker-based ASR demo (app.py + src + requirements)
2c4c098 verified
raw
history blame
9.85 kB
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corporation.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
from typing import Tuple, Union, List
class InputSegment:
    """One text-only or audio-only segment of a model input.

    ``to_input_id`` encodes the segment as a 2-D id tensor. Row 0 is the text
    stream and the remaining ``audio_channels`` rows are the audio streams.

    * Text segments fill the audio rows with ``speech_zeroemb_idx``.
    * Audio segments fill the text row with ``text_zeroemb_idx``, optionally
      wrapped in ``<|sosp|>`` / ``<|eosp|>`` markers.
    """

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        add_sosp_eosp=True,
    ) -> None:
        """Store the segment contents; nothing is tokenized until ``to_input_id``.

        Args:
            text: Raw text. It is tokenized in ``to_input_id`` when this is a
                text segment and ``tokenized_text`` is not given.
            audio: Audio token ids. When set, this is an audio segment; the
                tensor is reshaped to ``(-1, audio_channels)`` in ``to_input_id``.
            tokenized_text: Pre-tokenized 1-D text ids, used instead of ``text``.
            speech_zeroemb_idx: Filler id for the audio rows. A list gives one
                id per channel.
            text_zeroemb_idx: Filler id for the text row of audio segments.
                ``None`` means ``tokenizer.eod``.
            add_sosp_eosp: For audio segments, wrap the segment in
                ``<|sosp|>`` / ``<|eosp|>``.
        """
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.add_sosp_eosp = add_sosp_eosp

    @staticmethod
    def insert_between(tensor, i, value=-1):
        """Insert ``i`` copies of ``value`` after every element of a ``[1, L]`` tensor.

        Returns a ``[1, L * (i + 1)]`` tensor where original element ``k`` sits
        at column ``k * (i + 1)``. Example: ``[[a, b]]`` with ``i=1`` becomes
        ``[[a, value, b, value]]``. The result keeps ``tensor``'s dtype and device.
        """
        length = tensor.shape[1]
        out = torch.full(
            (1, length * (i + 1)), value, dtype=tensor.dtype, device=tensor.device
        )
        positions = torch.arange(0, length, dtype=torch.int64, device=tensor.device)[None] * (i + 1)
        return torch.scatter(out, 1, positions, tensor)

    def _speech_filler(self, audio_channels, length, device=None):
        """Return an ``[audio_channels, length]`` int32 tensor of ``speech_zeroemb_idx``.

        A list ``speech_zeroemb_idx`` sets per-channel ids. Channels beyond the
        list's length stay 0, and a list longer than ``audio_channels`` raises
        ``IndexError``.
        """
        if isinstance(self.speech_zeroemb_idx, list):
            filler = torch.zeros((audio_channels, length), dtype=torch.int, device=device)
            for channel, idx in enumerate(self.speech_zeroemb_idx):
                filler[channel, :] = idx
            return filler
        return torch.full(
            (audio_channels, length), self.speech_zeroemb_idx, dtype=torch.int, device=device
        )

    def _text_input_id(self, tokenizer, group_size, audio_channels):
        """Encode a text segment: real text row, filler audio rows."""
        if self.tokenized_text is None:
            text_row = tokenizer(
                self.text,
                return_tensors="pt",
                truncation=True,
                max_length=999999,
                padding=False,
                add_special_tokens=False,
            )["input_ids"].int()
        else:
            text_row = self.tokenized_text.unsqueeze(0)
        if group_size > 1:
            # Follow every text token with group_size - 1 pads (-100).
            # This mirrors the audio layout of group_size frames per text slot.
            text_row = self.insert_between(text_row, group_size - 1, value=-100)
        audio_rows = self._speech_filler(audio_channels, text_row.shape[1], text_row.device)
        return torch.cat([text_row, audio_rows], dim=0)

    def _audio_input_id(self, tokenizer, group_size, audio_channels):
        """Encode an audio segment: filler text row, real audio rows."""
        if self.add_sosp_eosp:
            sosp_token = tokenizer.convert_tokens_to_ids("<|sosp|>")
            eosp_token = tokenizer.convert_tokens_to_ids("<|eosp|>")
        audio_part = self.audio.reshape(-1, audio_channels).T  # [audio_channels, seqlen]
        assert (
            audio_part.shape[1] % group_size == 0
        ), f"Audio shape {audio_part.shape} is not divisible by group_size {group_size}"
        device = audio_part.device
        # One text slot per group of group_size audio frames.
        text_len = audio_part.shape[1] // group_size
        empty_token = self.text_zeroemb_idx
        if empty_token is None:
            empty_token = tokenizer.eod
        text_row = torch.full((1, text_len), empty_token, dtype=torch.int, device=device)
        if self.add_sosp_eosp:
            text_row = torch.cat(
                [
                    torch.tensor([[sosp_token]], dtype=torch.int, device=device),
                    text_row,
                    torch.tensor([[eosp_token]], dtype=torch.int, device=device),
                ],
                dim=1,
            )
            # The marker slots need a full group of filler audio frames on each side.
            marker = self._speech_filler(audio_channels, group_size, device)
            audio_rows = torch.cat([marker, audio_part, marker], dim=1)
        else:
            audio_rows = audio_part
        text_row = self.insert_between(text_row, group_size - 1, value=-100)
        return torch.cat([text_row, audio_rows], dim=0)

    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        """Encode the segment as an ``[audio_channels + 1, seqlen]`` id tensor.

        Args:
            tokenizer: Used to tokenize ``text`` and, for audio segments, to look
                up ``<|sosp|>`` / ``<|eosp|>`` (and ``eod`` when
                ``text_zeroemb_idx`` is None).
            group_size: Audio frames per text slot. For audio segments, the
                frame count must be divisible by it.
            audio_channels: Number of audio rows.

        Returns:
            A single tensor with row 0 holding text ids and rows
            ``1..audio_channels`` holding audio ids.

        Raises:
            AssertionError: If the audio frame count is not divisible by
                ``group_size``.
        """
        if self.audio is None:
            return self._text_input_id(tokenizer, group_size, audio_channels)
        return self._audio_input_id(tokenizer, group_size, audio_channels)
class StreamingInputSegment:
    """Interleave chunks of text and audio into one streaming input-id tensor.

    Chunking:
    * Text is cut into chunks of ``text_segment_size`` tokens.
    * Audio is cut into chunks of ``audio_segment_size * group_size * audio_channels``
      elements.

    Ordering:
    * Chunks alternate text, audio, text, audio, ...
    * Leftovers of the longer stream are appended, remaining text first and
      then remaining audio.
    * ``<|eot|>`` is appended to the last text chunk.
    * The whole sequence is wrapped in ``<|sostm|>`` / ``<|eostm|>`` text segments.

    Every piece is encoded with ``InputSegment.to_input_id`` and the results
    are concatenated along the sequence axis.
    """

    def __init__(
        self,
        text: str = "",
        audio: torch.Tensor = None,
        tokenized_text: torch.Tensor = None,
        speech_zeroemb_idx: Union[int, List[int]] = 1024,
        text_zeroemb_idx: int = 152067,
        text_segment_size: int = 5,
        audio_segment_size: int = 5,
        tokenizer=None,
        group_size=None,
        audio_channels=None,
    ) -> None:
        """Store the stream contents and encoding settings.

        Args:
            text: Raw text. It is tokenized in ``to_input_id`` unless
                ``tokenized_text`` is given.
            audio: Audio token ids, split along dim 0. Presumably a flat tensor
                with channels interleaved per frame, since each chunk is later
                reshaped to ``(-1, audio_channels)``; TODO confirm with callers.
            tokenized_text: Pre-tokenized 1-D text ids, used instead of ``text``.
            speech_zeroemb_idx: Audio filler id, forwarded to every ``InputSegment``.
            text_zeroemb_idx: Text filler id, forwarded to every ``InputSegment``.
            text_segment_size: Number of text tokens per text chunk.
            audio_segment_size: Audio chunk length, in units of ``group_size`` frames.
            tokenizer: Optional encoding setting. When not None, it overrides the
                matching ``to_input_id`` argument.
            group_size: Optional encoding setting, same override rule.
            audio_channels: Optional encoding setting, same override rule.
        """
        has_text = text is not None
        has_tokenized_text = tokenized_text is not None
        assert has_text or has_tokenized_text, "Text or tokenized text must be provided"
        self.audio = audio
        self.text = text
        self.tokenized_text = tokenized_text
        self.speech_zeroemb_idx = speech_zeroemb_idx
        self.text_zeroemb_idx = text_zeroemb_idx
        self.text_segment_size = text_segment_size
        self.audio_segment_size = audio_segment_size
        self.tokenizer = tokenizer
        self.group_size = group_size
        self.audio_channels = audio_channels

    @staticmethod
    def _tokenize(tokenizer, text):
        """Tokenize ``text`` without special tokens; return the 1-D ids of the single row."""
        return tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=999999,
            padding=False,
            add_special_tokens=False,
        )["input_ids"][0]

    def _segment(self, **kwargs):
        """Build an ``InputSegment`` that shares this stream's filler ids."""
        return InputSegment(
            speech_zeroemb_idx=self.speech_zeroemb_idx,
            text_zeroemb_idx=self.text_zeroemb_idx,
            **kwargs,
        )

    def to_input_id(
        self,
        tokenizer,
        group_size: int,
        audio_channels: int = 8,
    ) -> torch.Tensor:
        """Encode the interleaved stream as an int64 ``[audio_channels + 1, seqlen]`` tensor.

        Args:
            tokenizer: Used when the constructor did not set one.
            group_size: Used when the constructor did not set one.
            audio_channels: Used when the constructor did not set one.

        The resolved settings are used both for chunking and for encoding every
        sub-segment. This keeps the two consistent and works when the
        constructor values were left as None.
        """
        if self.tokenizer is not None:
            tokenizer = self.tokenizer
        if self.group_size is not None:
            group_size = self.group_size
        if self.audio_channels is not None:
            audio_channels = self.audio_channels

        if self.tokenized_text is None:
            text_ids = self._tokenize(tokenizer, self.text).int()
        else:
            text_ids = self.tokenized_text
        text_chunks = text_ids.split(self.text_segment_size, dim=0)
        audio_chunks = self.audio.split(
            self.audio_segment_size * group_size * audio_channels, dim=0
        )

        # Terminate the text stream: append <|eot|> to the final text chunk.
        eot_ids = self._tokenize(tokenizer, "<|eot|>").to(text_chunks[-1])
        text_chunks = text_chunks[:-1] + (torch.cat([text_chunks[-1], eot_ids], dim=0),)

        segments = [self._segment(text="<|sostm|>")]
        paired = min(len(text_chunks), len(audio_chunks))
        for text_chunk, audio_chunk in zip(text_chunks, audio_chunks):
            segments.append(self._segment(tokenized_text=text_chunk))
            segments.append(self._segment(audio=audio_chunk, add_sosp_eosp=False))
        segments.extend(self._segment(tokenized_text=chunk) for chunk in text_chunks[paired:])
        segments.extend(
            self._segment(audio=chunk, add_sosp_eosp=False) for chunk in audio_chunks[paired:]
        )
        segments.append(self._segment(text="<|eostm|>"))

        encoded = [seg.to_input_id(tokenizer, group_size, audio_channels) for seg in segments]
        return torch.cat(encoded, dim=1).long()  # [n_rvq + 1, seqlen]