|
|
| |
| |
| import io |
| import json |
| import matplotlib as mpl |
| import matplotlib.pyplot as plt |
| import mmap |
| import numpy as np |
| import soundfile |
| import torchaudio |
| import torch |
| from pydub import AudioSegment |
| |
| |
| |
| import math |
| from simuleval.data.segments import SpeechSegment, EmptySegment |
| from seamless_communication.streaming.agents.seamless_streaming_s2st import ( |
| SeamlessStreamingS2STVADAgent, |
| ) |
|
|
| from simuleval.utils.arguments import cli_argument_list |
| from simuleval import options |
|
|
|
|
| from typing import Union, List |
| from simuleval.data.segments import Segment, TextSegment |
| from simuleval.agents.pipeline import TreeAgentPipeline |
| from simuleval.agents.states import AgentStates |
| |
| |
| |
| SAMPLE_RATE = 16000 |
|
|
| |
| |
| |
class AudioFrontEnd:
    """Feed a wav file to a simuleval agent pipeline in fixed-size chunks.

    Mimics the front-end logic of simuleval's ``instance.py``: each call to
    ``send_segment()`` yields the next ``segment_size`` milliseconds of audio
    as a ``SpeechSegment``, then an ``EmptySegment`` once audio is exhausted.
    """

    def __init__(self, wav_file, segment_size) -> None:
        """
        :param wav_file: path or file-like object readable by ``soundfile``.
        :param segment_size: chunk duration in milliseconds.
        """
        self.samples, self.sample_rate = soundfile.read(wav_file)
        print(self.sample_rate, "sample rate")
        # Downstream agents are built for 16 kHz input.
        assert self.sample_rate == SAMPLE_RATE
        self.segment_size = segment_size
        # Read cursor, in samples, into self.samples.
        self.step = 0

    def send_segment(self):
        """Return the next chunk of audio as a SpeechSegment.

        This is the front-end logic in simuleval instance.py.  Once all
        samples have been consumed, returns an EmptySegment and resets the
        cursor so the front-end can be reused.
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)

        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            # BUG FIX: the original additionally did
            #     self.samples = self.samples[self.step :]
            # while *also* advancing the absolute cursor below.  From the
            # third call on, that combination skipped every other chunk of
            # audio (the slice indexed into an already-trimmed array).  Only
            # the cursor is advanced now; the sample buffer is left intact.
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Audio exhausted: emit a final empty segment and reset state so
            # a later add_segments() can start a fresh stream.
            segment = EmptySegment(
                finished=True,
            )
            self.step = 0
            self.samples = []
        return segment

    def add_segments(self, wav):
        """Append the samples of another wav file to the pending audio."""
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))
|
|
|
|
class OutputSegments:
    """Uniform wrapper over a pipeline step's output (one or many segments)."""

    def __init__(self, segments: Union[List[Segment], Segment]):
        # Normalize: a bare Segment becomes a one-element list.
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = list(segments)

    @property
    def is_empty(self):
        # Empty only when every wrapped segment is empty.
        return all(seg.is_empty for seg in self.segments)

    @property
    def finished(self):
        # Finished only when every wrapped segment is finished.
        return all(seg.finished for seg in self.segments)
|
|
|
|
def get_audiosegment(samples, sr):
    """Convert raw samples into a pydub AudioSegment via an in-memory wav."""
    buffer = io.BytesIO()
    soundfile.write(buffer, samples, samplerate=sr, format="wav")
    buffer.seek(0)
    return AudioSegment.from_file(buffer)
|
|
|
|
def reset_states(system, states):
    """Reset every agent state in the pipeline, whatever its layout."""
    # A TreeAgentPipeline keeps states in a dict keyed by module; a linear
    # pipeline keeps them in a plain list.
    if isinstance(system, TreeAgentPipeline):
        for agent_states in states.values():
            agent_states.reset()
    else:
        for agent_states in states:
            agent_states.reset()
|
|
|
|
def get_states_root(system, states) -> AgentStates:
    """Return the states object of the pipeline's source (root) agent."""
    if not isinstance(system, TreeAgentPipeline):
        # Linear pipeline: root is the first entry of the system's own list.
        return system.states[0]
    # Tree pipeline: states is a dict keyed by agent module.
    return states[system.source_module]
| |
|
|
def build_streaming_system(model_configs, agent_class):
    """Instantiate a streaming agent pipeline of `agent_class` from a config dict."""
    parser = options.general_parser()
    # Dummy flag so the parser tolerates ipython's injected -f argument.
    parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
    agent_class.add_args(parser)
    known_args, _ = parser.parse_known_args(cli_argument_list(model_configs))
    return agent_class.from_args(known_args)
|
|
|
|
def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    """Stream audio chunks through the system and collect its outputs.

    Repeatedly pulls fixed-size chunks from `audio_frontend`, pushes them
    through `system`, and records, per task ("s2st" speech output, "s2tt"
    text output), every emitted prediction together with the input-side
    delay (ms of source audio consumed) at the moment it was emitted.

    Returns (delays, prediction_lists, speech_durations, target_sample_rate),
    or (None, None, None, None) when the front-end immediately yields an
    EmptySegment (no audio left to stream).
    """
    # Per-task accumulators; keys mirror the two output modalities.
    delays = {"s2st": [], "s2tt": []}
    prediction_lists = {"s2st": [], "s2tt": []}
    speech_durations = []  # duration (ms) of each emitted speech chunk
    curr_delay = 0  # ms of source audio pushed into the system so far
    target_sample_rate = None  # taken from the first speech output seen

    while True:
        input_segment = audio_frontend.send_segment()
        input_segment.tgt_lang = tgt_lang
        # Count the newly pushed source audio towards the running delay.
        # NOTE(review): assumes input_segment.content supports len() — an
        # EmptySegment with content=None would raise here; confirm the
        # front-end always finishes with a non-empty final SpeechSegment.
        curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
        if input_segment.finished:
            # Mark the root agent's states as source-finished so the
            # pipeline can flush its remaining output.
            get_states_root(system, system_states).source_finished = True

        if isinstance(input_segment, EmptySegment):
            # Nothing left to stream; caller treats this as "no inference".
            return None, None, None, None
        output_segments = OutputSegments(system.pushpop(input_segment, system_states))
        if not output_segments.is_empty:
            for segment in output_segments.segments:
                # Tag each emitted segment with the delay at which it
                # became available.
                if isinstance(segment, SpeechSegment):
                    pred_duration = 1000 * len(segment.content) / segment.sample_rate
                    speech_durations.append(pred_duration)
                    delays["s2st"].append(curr_delay)
                    prediction_lists["s2st"].append(segment.content)
                    target_sample_rate = segment.sample_rate
                elif isinstance(segment, TextSegment):
                    delays["s2tt"].append(curr_delay)
                    prediction_lists["s2tt"].append(segment.content)
                    print(curr_delay, segment.content)
        if output_segments.finished:
            # Pipeline declared the utterance done: clear all agent states.
            reset_states(system, system_states)
        if input_segment.finished:
            # Source audio exhausted — stop streaming.
            break
    return delays, prediction_lists, speech_durations, target_sample_rate
|
|
|
|
def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    """Assemble the delayed target speech waveform from streamed predictions.

    Each predicted speech chunk is placed on a timeline at
    max(its emission delay, end of the previous chunk), with silence
    (zeros) filling the lead-in and any gaps between chunks.

    :param delays: dict with key "s2st" -> list of emission delays in ms.
    :param target_sample_rate: sample rate (Hz) of the predicted speech.
    :param prediction_lists: dict with key "s2st" -> list of sample chunks.
    :param speech_durations: duration (ms) of each predicted chunk.
    :return: (target_samples, intervals) — the padded waveform as a flat
        list of samples, and a list of [start_ms, duration_ms] per chunk.
    """
    # Robustness fix: the original raised IndexError on an empty prediction
    # list (delays["s2st"][0]); return an empty timeline instead.
    if not delays["s2st"]:
        return [], []

    intervals = []

    # Lead-in silence up to the first chunk's emission delay.
    prediction_offset = delays["s2st"][0]
    prev_end = prediction_offset
    target_samples = [0.0] * int(target_sample_rate * prediction_offset / 1000)

    for i, delay in enumerate(delays["s2st"]):
        # A chunk cannot start before the previous chunk has finished.
        start = max(prev_end, delay)
        if start > prev_end:
            # Fill the gap between consecutive chunks with silence.
            target_samples += [0.0] * int(
                target_sample_rate * (start - prev_end) / 1000
            )
        target_samples += prediction_lists["s2st"][i]
        duration = speech_durations[i]
        prev_end = start + duration
        intervals.append([start, duration])
    return target_samples, intervals
|
|
|
|