# Mimi ONNX Streaming Codec
Streaming ONNX models for Kyutai Mimi neural audio codec, exported with explicit convolutional state buffers and transformer KV cache for frame-by-frame processing.
## Models
| Variant | Encoder | Decoder | Total | Bitrate | Codebooks | Precision |
|---|---|---|---|---|---|---|
| streaming-8cb | 194 MB | 170 MB | 364 MB | 1.1 kbps | 8 | FP32 |
| streaming-16cb | 242 MB | 186 MB | 428 MB | 2.2 kbps | 16 | FP32 |
| streaming-8cb-fp16 | 119 MB | 93 MB | 212 MB | 1.1 kbps | 8 | FP16 weights |
| streaming-16cb-fp16 | 167 MB | 109 MB | 276 MB | 2.2 kbps | 16 | FP16 weights |
- Sample rate: 24 kHz mono
- Frame rate: 12.5 Hz (one code frame per 80 ms of audio)
- Codebook size: 2048 (11 bits per codebook)
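The bitrate figures in the table follow directly from these values; a quick arithmetic check:

```python
# Bitrate = frame rate x bits per codebook x number of codebooks.
frame_rate_hz = 12.5      # code frames per second
bits_per_codebook = 11    # log2(2048)

for num_codebooks in (8, 16):
    bps = frame_rate_hz * bits_per_codebook * num_codebooks
    print(f"{num_codebooks} codebooks -> {bps / 1000:.1f} kbps")
# 8 codebooks -> 1.1 kbps
# 16 codebooks -> 2.2 kbps
```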
## FP16 Variants
The `-fp16` variants use weight-only FP16: all model weights are stored as float16 for ~40% smaller files, while graph I/O remains float32. At runtime, ONNX Runtime casts the weights back to float32 for computation, so quality loss relative to the FP32 models is negligible (limited to float16 rounding of the weights). They are drop-in replacements; no code changes are needed.
Converted using `scripts/weight_fp16.py` from the FP32 models.
## Architecture
Each model (encoder and decoder) carries two kinds of explicit state as tensor I/O:
- Conv state buffers (11 per encoder/decoder): causal padding tails from SEANet convolutional layers. Shapes are defined in `state_spec.txt`.
- KV cache (16 tensors per encoder/decoder): transformer self-attention key/value history. Shape `[1, 8, seq_len, 64]`, growing with each frame.
A causal attention mask is computed inside the ONNX graph, so the transformer processes all tokens in a single call with correct autoregressive behavior.
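The KV-cache mechanics can be emulated in plain NumPy; this sketch mirrors the per-frame growth the graph performs internally (`frames` is illustrative):

```python
import numpy as np

NUM_LAYERS, NUM_HEADS, HEAD_DIM = 16, 8, 64

# Each of the 16 KV tensors starts empty and grows along the sequence axis.
kv = [np.zeros((1, NUM_HEADS, 0, HEAD_DIM), dtype=np.float32)
      for _ in range(NUM_LAYERS)]

frames = 4  # e.g. 320 ms of audio at 12.5 Hz
for layer in range(NUM_LAYERS):
    new_kv = np.random.randn(1, NUM_HEADS, frames, HEAD_DIM).astype(np.float32)
    kv[layer] = np.concatenate([kv[layer], new_kv], axis=2)

print(kv[0].shape)  # (1, 8, 4, 64)
```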
### Encoder

- Input: PCM audio chunk + conv states + KV cache
- Output: integer codes `[1, num_codebooks, num_frames]` + updated conv states + updated KV cache
### Decoder

- Input: integer codes + conv states + KV cache
- Output: PCM audio `[1, 1, num_samples]` + updated conv states + updated KV cache
## Usage

### Python (ONNX Runtime)
```python
import numpy as np
import onnxruntime as ort

# Load models (FP16 variant is a drop-in replacement)
enc = ort.InferenceSession("streaming-8cb-fp16/encoder_model.onnx")
dec = ort.InferenceSession("streaming-8cb-fp16/decoder_model.onnx")

# Initialize state: parse state_spec.txt for conv state shapes.
# KV cache starts empty: shape [1, 8, 0, 64] for each of 16 tensors.
enc_kv = [np.zeros((1, 8, 0, 64), dtype=np.float32) for _ in range(16)]

# Encode one chunk (e.g. 7680 samples = 320 ms)
audio = np.random.randn(1, 1, 7680).astype(np.float32)
inputs = {"input_values": audio}
# ... add conv states and KV cache to inputs ...
outputs = enc.run(None, inputs)

codes = outputs[0]  # [1, 8, 4]: 4 code frames
# outputs[1:12] = updated conv states
# outputs[12:]  = updated KV cache
```
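For a multi-chunk loop, the per-call outputs can be split by the positional layout above (payload, then 11 conv states, then 16 KV tensors). A minimal sketch; the helper name is invented here, and the states' input names must be taken from the actual model (inspect them with `enc.get_inputs()`):

```python
NUM_CONV, NUM_KV = 11, 16

def split_outputs(outputs):
    """Split a run() result by positional layout:
    [payload, 11 conv states, 16 KV cache tensors]."""
    payload = outputs[0]
    conv_states = outputs[1:1 + NUM_CONV]
    kv_cache = outputs[1 + NUM_CONV:1 + NUM_CONV + NUM_KV]
    return payload, conv_states, kv_cache

# Streaming loop shape (state input names depend on the model):
#
# for chunk in chunks:
#     inputs = {"input_values": chunk, **state_inputs}
#     codes, conv_states, kv_cache = split_outputs(enc.run(None, inputs))
#     ... rebuild state_inputs from conv_states and kv_cache ...
```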
## `state_spec.txt` Format
```
[encoder]
conv enc_0 1 6        # name, channels, temporal_size
conv enc_1_b1 64 2
...
[decoder]
conv_tr us 512 2      # conv_tr = transposed convolution
conv dec_0 512 6
...
```
Each conv state is a tensor of shape [1, channels, temporal_size], initialized to zeros on the first frame.
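A small parser for this format (a sketch; `load_conv_states` is a name invented here, not part of the repo) that returns zero-initialized states for one section:

```python
import numpy as np

def load_conv_states(spec_path, section):
    """Return {name: zeros of shape [1, channels, temporal_size]} for the
    given section ("encoder" or "decoder") of state_spec.txt."""
    states, current = {}, None
    with open(spec_path) as f:
        for raw in f:
            line = raw.split("#", 1)[0].strip()  # drop trailing comments
            if not line or line == "...":
                continue
            if line.startswith("[") and line.endswith("]"):
                current = line[1:-1]  # section header
            elif current == section:
                _kind, name, channels, temporal = line.split()
                states[name] = np.zeros((1, int(channels), int(temporal)),
                                        dtype=np.float32)
    return states
```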
## Export
Models were exported from kyutai/mimi using a custom streaming wrapper:
```bash
# Export FP32 models
python scripts/export_streaming_onnx.py \
    --num-codebooks 8 \
    --output-dir streaming-8cb

# Convert to weight-only FP16
python scripts/weight_fp16.py \
    --input-dir streaming-8cb \
    --output-dir streaming-8cb-fp16
```
The export scripts live in mimi-codec (see `scripts/`).
## Quality
Streaming vs batch ONNX baseline on a 3-second test clip:
| Metric | Value |
|---|---|
| SNR | 31.9 dB |
| Cosine similarity | 0.9997 |
| Max abs diff | 0.16 |
The gap comes from causal streaming (conv state carry-over), not from ONNX precision: batch ONNX matches PyTorch at 128+ dB SNR. FP16 weight variants produce near-identical output to FP32, since computation runs in float32 (only the stored weights are rounded to float16).
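For reference, the two waveform metrics can be computed as follows (a generic sketch, not the exact evaluation script):

```python
import numpy as np

def snr_db(reference, test):
    """Signal-to-noise ratio of `test` against `reference`, in dB."""
    noise = reference - test
    return 10.0 * np.log10(np.sum(reference ** 2) / np.sum(noise ** 2))

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D waveforms."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```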
## License
Same as the base model: CC-BY-4.0