|
|
| import torch |
| import torch.nn as nn |
| import coremltools as ct |
| import numpy as np |
| import typer |
| from pathlib import Path |
| from typing import Tuple, List, Optional |
| import json |
| import shutil |
|
|
| |
| import coremltools as ct |
| import numpy as np |
| import argparse |
| from nemo.collections.asr.models import EncDecRNNTBPEModel |
|
|
# NOTE(review): Typer app is declared but never registered with commands —
# main() parses arguments with argparse directly. Consider consolidating on
# one CLI framework.
app = typer.Typer()
|
|
class LoopbackEncoderWrapper(nn.Module):
    """
    Wraps the entire Parakeet Encoder (PreEncode + Conformer) for CoreML Loopback Streaming.

    The caller feeds each output cache tensor back in as the corresponding
    input on the next chunk ("loopback"), so the exported CoreML model stays
    stateless while the host app carries the streaming state.

    Inputs:
        - audio_signal: [B, D, T] (Mel spectrogram chunk)
        - audio_length: [B]
        - pre_cache: [B, D, pre_cache_size] (Previous audio context)
        - cache_last_channel: [layers, B, cache_size, hidden]
        - cache_last_time: [layers, B, hidden, time_cache]
        - cache_last_channel_len: [B]

    Outputs:
        - encoded_output: [B, D_out, T_out]
        - encoded_length: [B] (int32)
        - new_pre_cache: [B, D, pre_cache_size]
        - new_cache_last_channel
        - new_cache_last_time
        - new_cache_last_channel_len (int32)
    """

    def __init__(self, encoder, pre_cache_size=16):
        """
        Args:
            encoder: NeMo cache-aware streaming encoder exposing
                ``cache_aware_stream_step``.
            pre_cache_size: Number of trailing mel frames carried over to the
                next chunk as left audio context.
        """
        super().__init__()
        self.encoder = encoder
        self.pre_cache_size = pre_cache_size

    def forward(
        self,
        audio_signal: torch.Tensor,
        audio_length: torch.Tensor,
        pre_cache: torch.Tensor,
        cache_last_channel: torch.Tensor,
        cache_last_time: torch.Tensor,
        cache_last_channel_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Prepend the carried-over audio context so the pre-encoder sees a
        # continuous signal across chunk boundaries.
        full_input = torch.cat([pre_cache, audio_signal], dim=2)
        full_length = audio_length + self.pre_cache_size

        # The last `pre_cache_size` frames of the concatenated input become the
        # context for the next chunk.
        new_pre_cache = full_input[:, :, -self.pre_cache_size:]

        encoded, encoded_len, new_cache_channel, new_cache_time, new_cache_len = self.encoder.cache_aware_stream_step(
            processed_signal=full_input,
            processed_signal_length=full_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len
        )

        # CoreML-side consumers expect int32 length tensors.
        encoded_len_32 = encoded_len.to(dtype=torch.int32)
        new_channel_len_32 = new_cache_len.to(dtype=torch.int32)

        return encoded, encoded_len_32, new_pre_cache, new_cache_channel, new_cache_time, new_channel_len_32
|
|
def _coreml_convert(
    traced_model,
    inputs,
    outputs,
    compute_units=ct.ComputeUnit.CPU_ONLY
):
    """Convert a traced model to CoreML with the project's fixed deployment target."""
    conversion_options = {
        "inputs": inputs,
        "outputs": outputs,
        "compute_units": compute_units,
        "minimum_deployment_target": ct.target.macOS14,
    }
    return ct.convert(traced_model, **conversion_options)
|
|
def main():
    """Export the Parakeet streaming encoder as a CoreML mlpackage.

    Loads the pretrained NeMo RNNT model, wraps its encoder for loopback
    streaming, traces the wrapper with dummy tensors, converts it with
    coremltools, and saves the result under ``output_dir``.
    """
    # Parse CLI arguments BEFORE loading the model so that `--help` (or an
    # invalid flag) does not trigger a multi-hundred-MB model download.
    parser = argparse.ArgumentParser()
    parser.add_argument("--chunk-frames", type=int, default=17, help="Number of frames in the input chunk (e.g. 17 for 160ms, 129 for 1.28s)")
    args = parser.parse_args()

    model_id: str = "nvidia/parakeet_realtime_eou_120m-v1"
    output_dir: str = "temp_swift_models/StreamingLoopback"
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {model_id}...")
    asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=model_id)
    asr_model.eval()

    encoder = asr_model.encoder

    chunk_size_in = args.chunk_frames
    mel_dim = 128  # Mel filterbank size the Parakeet model was trained with
    hidden_dim = encoder.d_model
    num_layers = len(encoder.layers)

    # Streaming cache geometry.
    # NOTE(review): these are hard-coded to match the model's cache-aware
    # streaming config — verify against the checkpoint if the model changes.
    cache_channel_size = 70
    cache_time_size = 8
    pre_cache_size = 16

    print(f"Config: Chunk={chunk_size_in}, Mel={mel_dim}, Hidden={hidden_dim}, Layers={num_layers}")
    print(f"Cache: Channel={cache_channel_size}, Time={cache_time_size}, Pre={pre_cache_size}")

    wrapper = LoopbackEncoderWrapper(encoder, pre_cache_size=pre_cache_size)
    wrapper.eval()

    # Dummy tensors used only to drive the trace; values are irrelevant,
    # shapes must match the CoreML input declarations below.
    batch_size = 1
    test_mel = torch.randn(batch_size, mel_dim, chunk_size_in)
    test_mel_len = torch.tensor([chunk_size_in], dtype=torch.int32)
    test_pre_cache = torch.zeros(batch_size, mel_dim, pre_cache_size)

    test_cache_channel = torch.zeros(num_layers, batch_size, cache_channel_size, hidden_dim)
    test_cache_time = torch.zeros(num_layers, batch_size, hidden_dim, cache_time_size)
    test_cache_len = torch.zeros(batch_size, dtype=torch.int32)

    print("Tracing model...")
    traced_model = torch.jit.trace(
        wrapper,
        (test_mel, test_mel_len, test_pre_cache, test_cache_channel, test_cache_time, test_cache_len),
        strict=False
    )

    print("Converting to CoreML...")

    # Use `mel_dim` (not a hard-coded 128) so the declared shapes always match
    # the trace inputs above.
    inputs = [
        ct.TensorType(name="audio_signal", shape=(1, mel_dim, chunk_size_in), dtype=np.float32),
        ct.TensorType(name="audio_length", shape=(1,), dtype=np.int32),
        ct.TensorType(name="pre_cache", shape=(1, mel_dim, pre_cache_size), dtype=np.float32),
        ct.TensorType(name="cache_last_channel", shape=(num_layers, 1, cache_channel_size, hidden_dim), dtype=np.float32),
        ct.TensorType(name="cache_last_time", shape=(num_layers, 1, hidden_dim, cache_time_size), dtype=np.float32),
        ct.TensorType(name="cache_last_channel_len", shape=(1,), dtype=np.int32),
    ]

    outputs = [
        ct.TensorType(name="encoded_output", dtype=np.float32),
        ct.TensorType(name="encoded_length", dtype=np.int32),
        ct.TensorType(name="new_pre_cache", dtype=np.float32),
        ct.TensorType(name="new_cache_last_channel", dtype=np.float32),
        ct.TensorType(name="new_cache_last_time", dtype=np.float32),
        ct.TensorType(name="new_cache_last_channel_len", dtype=np.int32),
    ]

    mlmodel = _coreml_convert(traced_model, inputs, outputs)

    save_path = output_path / "streaming_encoder.mlpackage"
    mlmodel.save(str(save_path))
    print(f"Saved: {save_path}")
| |
| |
| |
| |
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|