Mimi ONNX Streaming Codec

Streaming ONNX models for the Kyutai Mimi neural audio codec, exported with explicit convolutional state buffers and a transformer KV cache for frame-by-frame processing.

Models

Variant              Encoder  Decoder  Total   Bitrate   Codebooks  Precision
streaming-8cb        194 MB   170 MB   364 MB  1.1 kbps  8          FP32
streaming-16cb       242 MB   186 MB   428 MB  2.2 kbps  16         FP32
streaming-8cb-fp16   119 MB   93 MB    212 MB  1.1 kbps  8          FP16 weights
streaming-16cb-fp16  167 MB   109 MB   276 MB  2.2 kbps  16         FP16 weights
  • Sample rate: 24 kHz mono
  • Frame rate: 12.5 Hz (one code frame per 80 ms of audio)
  • Codebook size: 2048 (11 bits per codebook)
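The quoted bitrates follow directly from these numbers; a quick arithmetic check:

```python
import math

frame_rate = 12.5                  # code frames per second
bits_per_code = math.log2(2048)    # 11 bits per codebook entry

# 12.5 frames/s * num_codebooks * 11 bits/codebook
bitrate_8cb = frame_rate * 8 * bits_per_code     # 1100 bps = 1.1 kbps
bitrate_16cb = frame_rate * 16 * bits_per_code   # 2200 bps = 2.2 kbps
```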

FP16 Variants

The -fp16 variants use weight-only FP16: all model weights are stored as float16 for roughly 40% smaller files, while graph I/O remains float32. At runtime, ONNX Runtime casts the weights back to float32 for computation, so there is no quality loss compared to the FP32 models. They are drop-in replacements; no code changes are needed.

Converted from the FP32 models using scripts/weight_fp16.py.

Architecture

Each model (encoder and decoder) carries two kinds of explicit state as tensor I/O:

  1. Conv state buffers (11 per encoder/decoder): causal padding tails from the SEANet convolutional layers. Shapes are defined in state_spec.txt.
  2. KV cache (16 tensors per encoder/decoder): transformer self-attention key/value history. Shape [1, 8, seq_len, 64], growing with each frame.

A causal attention mask is computed inside the ONNX graph, so the transformer processes all tokens in a single call with correct autoregressive behavior.
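A mask of this kind can be sketched in NumPy (a hedged illustration of the masking logic, not the exported graph's exact ops): each of the new tokens attends to every cached position plus the new positions up to and including itself.

```python
import numpy as np

def causal_mask(past_len, new_len):
    """Additive attention mask of shape [new_len, past_len + new_len].

    Row t (the t-th new token) may attend to all cached positions
    and to new positions 0..t; disallowed entries get -inf.
    """
    keys = np.arange(past_len + new_len)[None, :]          # key positions
    last_allowed = (past_len + np.arange(new_len))[:, None]
    return np.where(keys <= last_allowed, 0.0, -np.inf).astype(np.float32)
```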

Encoder

Input: PCM audio chunk + conv states + KV cache

Output: Integer codes [1, num_codebooks, num_frames] + updated conv states + updated KV cache

Decoder

Input: Integer codes + conv states + KV cache

Output: PCM audio [1, 1, num_samples] + updated conv states + updated KV cache
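The frame bookkeeping implied by the I/O shapes above (24 kHz audio at a 12.5 Hz frame rate) works out as follows:

```python
SAMPLE_RATE = 24_000   # Hz
FRAME_RATE = 12.5      # code frames per second

samples_per_frame = int(SAMPLE_RATE / FRAME_RATE)   # 1920 samples = 80 ms
chunk_samples = 7680                                # a 320 ms chunk
num_frames = chunk_samples // samples_per_frame     # 4 code frames
```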

Usage

Python (ONNX Runtime)

import numpy as np
import onnxruntime as ort

# Load models (FP16 variant is a drop-in replacement)
enc = ort.InferenceSession("streaming-8cb-fp16/encoder_model.onnx")
dec = ort.InferenceSession("streaming-8cb-fp16/decoder_model.onnx")

# Initialize state: parse state_spec.txt for conv state shapes
# KV cache starts empty: shape [1, 8, 0, 64] for each of 16 tensors
enc_kv = [np.zeros((1, 8, 0, 64), dtype=np.float32) for _ in range(16)]

# Encode one chunk (e.g. 7680 samples = 320 ms)
audio = np.random.randn(1, 1, 7680).astype(np.float32)
inputs = {"input_values": audio}
# ... add conv states and KV cache to inputs ...
outputs = enc.run(None, inputs)
codes = outputs[0]  # [1, 8, 4]: 4 code frames
# outputs[1:12] = updated conv states
# outputs[12:] = updated KV cache
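One hedged way to fill in the state inputs above without hard-coding names is to zero-initialize every non-audio input the session reports, treating dynamic dimensions as length 0 (which yields the empty [1, 8, 0, 64] KV cache on the first call). The names in the example specs (conv_state_0, past_kv_0) are illustrative, not necessarily the model's actual input names:

```python
import numpy as np

def init_states(input_specs, skip=("input_values", "audio_codes")):
    """Zero-initialize conv states and empty KV caches.

    input_specs: (name, shape) pairs, e.g.
        [(i.name, i.shape) for i in session.get_inputs()]
    Dynamic (non-integer) dims become 0, so KV caches start empty.
    """
    states = {}
    for name, shape in input_specs:
        if name in skip:   # audio inputs are provided separately
            continue
        concrete = tuple(d if isinstance(d, int) else 0 for d in shape)
        states[name] = np.zeros(concrete, dtype=np.float32)
    return states

# Hypothetical specs, shaped like ONNX Runtime would report them:
specs = [("input_values", [1, 1, 7680]),
         ("conv_state_0", [1, 64, 2]),
         ("past_kv_0", [1, 8, "past_seq", 64])]
states = init_states(specs)
```

On subsequent frames you would feed back the updated states from the previous call's outputs instead of re-zeroing.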

state_spec.txt Format

[encoder]
conv enc_0 1 6          # name, channels, temporal_size
conv enc_1_b1 64 2
...

[decoder]
conv_tr us 512 2        # conv_tr = transposed convolution
conv dec_0 512 6
...

Each conv state is a tensor of shape [1, channels, temporal_size], initialized to zeros on the first frame.
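A minimal parser for this format might look like the following (a sketch; it assumes each entry is exactly `kind name channels temporal_size`, with optional `#` comments):

```python
def parse_state_spec(text):
    """Parse state_spec.txt into {section: {name: (1, channels, temporal_size)}}."""
    spec, section = {}, None
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()   # drop trailing comments
        if not line:
            continue
        if line.startswith("[") and line.endswith("]"):
            section = line[1:-1]
            spec[section] = {}
            continue
        kind, name, channels, temporal = line.split()
        spec[section][name] = (1, int(channels), int(temporal))
    return spec
```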

Export

Models were exported from kyutai/mimi using a custom streaming wrapper:

# Export FP32 models
python scripts/export_streaming_onnx.py \
  --num-codebooks 8 \
  --output-dir streaming-8cb

# Convert to weight-only FP16
python scripts/weight_fp16.py \
  --input-dir streaming-8cb \
  --output-dir streaming-8cb-fp16

The export scripts live in mimi-codec (see scripts/).

Quality

Streaming vs batch ONNX baseline on a 3-second test clip:

Metric             Value
SNR                31.9 dB
Cosine similarity  0.9997
Max abs diff       0.16

The gap comes from causal streaming (conv state carry-over), not from ONNX precision; batch ONNX matches PyTorch at 128+ dB SNR. FP16 weight variants produce output identical to FP32 since computation runs in float32.
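The metrics above can be reproduced with straightforward NumPy definitions (a sketch; `ref` would be the batch-baseline waveform and `test` the streaming output):

```python
import numpy as np

def snr_db(ref, test):
    # Signal-to-noise ratio of test against ref, in dB.
    noise = ref - test
    return 10.0 * np.log10(np.sum(ref ** 2) / np.sum(noise ** 2))

def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```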

License

Same as the base model: CC-BY-4.0
