# DualTurn-Qwen2.5-Mimi-0.5B

A real-time dual-channel turn-taking model. Given stereo audio (channel 0 = user mic, channel 1 = agent speaker), it predicts the following per frame:

| Output | Shape | Description |
|---|---|---|
| `vad_probs` | [B, T, 2] | P(speaking now); [..., 0] = user, [..., 1] = agent |
| `fvad_probs` | [B, T, 8] | P(future speech) at 240/480/960/2000 ms horizons; user = [..., 0:4], agent = [..., 4:8] |
| `eot_probs` | [B, T, 2] | P(end of turn) per channel |
| `bot_probs` | [B, T, 2] | P(beginning of turn) per channel |
| `hold_probs` | [B, T, 2] | P(within-turn hold/pause) per channel |
| `bc_probs` | [B, T, 2] | P(backchannel) per channel |

Frame rate: 12.5 Hz (80 ms per frame). Audio is resampled internally to 24 kHz.
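
At 24 kHz and 12.5 Hz this works out to exactly 1920 samples per frame. A small sketch of the index arithmetic (the per-channel horizon ordering of `fvad_probs` is an assumption read off the table above):

```python
SAMPLE_RATE, FRAME_RATE = 24_000, 12.5
SAMPLES_PER_FRAME = int(SAMPLE_RATE / FRAME_RATE)  # 1920 samples per 80 ms frame

def frame_at(t_sec: float) -> int:
    """Index of the frame covering time t_sec."""
    return int(t_sec * FRAME_RATE)

# Example: P(user speech within 480 ms) at t = 3.0 s, assuming each channel's
# fvad block is ordered 240/480/960/2000 ms -> user block 0:4, horizon index 1:
# p = out.fvad_probs[0, frame_at(3.0), 1]
```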


## FP32 Inference

```bash
pip install transformers torch torchaudio safetensors
```

```python
import torch, torchaudio
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "anyreach-ai/dualturn-qwen2.5-mimi-0.5B",
    trust_remote_code=True,
)
model.eval()

wav, sr = torchaudio.load("conversation.wav")  # [2, T]  CH0=user  CH1=agent

with torch.no_grad():
    out = model(wav, sr=sr)

print(out.vad_probs.shape)    # [1, T, 2]
print(out.fvad_probs.shape)   # [1, T, 8]
print(out.eot_probs.shape)    # [1, T, 2]
```
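
The heads emit raw per-frame probabilities; turning them into discrete events is left to the caller. A minimal sketch (the 0.5 threshold and 3-frame hysteresis are illustrative assumptions, not values from the paper):

```python
def detect_turn_ends(eot_probs, threshold=0.5, min_frames=3):
    """Yield (channel, frame) where P(end of turn) first stays above
    `threshold` for `min_frames` consecutive 80 ms frames."""
    above = eot_probs[0] > threshold          # [T, 2] bool
    for ch in range(above.shape[1]):
        run = 0
        for t in range(above.shape[0]):
            run = run + 1 if above[t, ch] else 0
            if run == min_frames:             # fires once per run
                yield ch, t

for ch, t in detect_turn_ends(out.eot_probs):
    print(f"channel {ch}: end of turn at ~{t * 0.08:.2f} s")
```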

## ONNX Inference (CPU, KV Cache)

The ONNX checkpoint runs the Qwen backbone with a rolling KV cache. Pass empty `past_key_values` tensors for full-context batch inference, or carry the cache across steps for low-latency streaming (~41 ms/step at 5 s of context). Both modes are shown below.

```bash
pip install onnxruntime transformers torch torchaudio safetensors huggingface_hub
```

```python
import numpy as np, torch, torch.nn as nn, torchaudio, os
from transformers import MimiModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import onnxruntime as ort

N_LAYERS, N_HEADS, HEAD_DIM, D = 24, 2, 64, 896

# Mimi encoder
mimi = MimiModel.from_pretrained("kyutai/mimi").eval()

@torch.no_grad()
def encode(wav_1d):
    x  = wav_1d.unsqueeze(0).unsqueeze(0)
    e  = mimi.encoder(x)
    et = mimi.encoder_transformer(e.transpose(1, 2))
    if hasattr(et, "last_hidden_state"): et = et.last_hidden_state
    return mimi.downsample(et.transpose(1, 2)).squeeze(0).T.float().numpy()  # [T, 512]

# Projection + heads from safetensors
weights = load_file(hf_hub_download("anyreach-ai/dualturn-qwen2.5-mimi-0.5B", "model.safetensors"))

proj = nn.Sequential(nn.Linear(1024, D), nn.GELU(), nn.Linear(D, D)).eval()
proj[0].weight.data = weights["mimi_projection.proj.0.weight"]
proj[0].bias.data   = weights["mimi_projection.proj.0.bias"]
proj[2].weight.data = weights["mimi_projection.proj.2.weight"]
proj[2].bias.data   = weights["mimi_projection.proj.2.bias"]

@torch.no_grad()
def project(f0, f1):
    return proj(torch.cat([torch.from_numpy(f0),
                            torch.from_numpy(f1)], dim=-1)).unsqueeze(0).numpy()

def _head(key):
    h = nn.Linear(D, 1).eval()
    h.weight.data = weights[f"{key}.weight"]
    h.bias.data   = weights[f"{key}.bias"]
    return h

vad0, vad1 = _head("vad_head_ch0"), _head("vad_head_ch1")
fvad = nn.Linear(D, 8).eval()
fvad.weight.data = weights["fvad_head.weight"]
fvad.bias.data   = weights["fvad_head.bias"]
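
# The remaining heads (eot/bot/hold/bc) load the same way from `weights`;
# inspect weights.keys() for their exact parameter names.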

@torch.no_grad()
def heads(h_np):
    h = torch.from_numpy(h_np)
    return {
        "vad_probs":  torch.sigmoid(torch.stack([vad0(h).squeeze(-1),
                                                  vad1(h).squeeze(-1)], -1)).numpy(),
        "fvad_probs": torch.sigmoid(fvad(h)).numpy(),
    }

# ONNX needs model.onnx + model.onnx_data in the same directory; use snapshot_download
repo_dir = snapshot_download(
    "anyreach-ai/dualturn-qwen2.5-mimi-0.5B",
    allow_patterns=["onnx/*"],
)
sess = ort.InferenceSession(
    os.path.join(repo_dir, "onnx", "model.onnx"),
    providers=["CPUExecutionProvider"])

# One backbone step: feeds the projected Mimi embeddings through the Qwen
# backbone and returns final hidden states plus the updated per-layer KV cache.
def step(embeds, pos, past_kv=None):
    T, T_ctx = embeds.shape[1], (past_kv[0][0].shape[2] if past_kv else 0)
    inp = {
        # The backbone consumes embeddings, not token ids; "inputs_embeds" is the
        # assumed input name -- verify with [i.name for i in sess.get_inputs()].
        "inputs_embeds":  embeds.astype(np.float32),
        "attention_mask": np.ones((1, T_ctx + T), np.int64),
        "position_ids":   np.arange(pos, pos + T, dtype=np.int64).reshape(1, -1),
    }
    for i in range(N_LAYERS):
        inp[f"past_key_values.{i}.key"]   = past_kv[i][0] if past_kv else np.zeros((1,N_HEADS,0,HEAD_DIM),np.float32)
        inp[f"past_key_values.{i}.value"] = past_kv[i][1] if past_kv else np.zeros((1,N_HEADS,0,HEAD_DIM),np.float32)
    outs = sess.run(None, inp)
    return outs[0], [(outs[1+i*2], outs[2+i*2]) for i in range(N_LAYERS)]

# Run
wav, sr = torchaudio.load("conversation.wav")
wav     = torchaudio.transforms.Resample(sr, 24000)(wav)
f0, f1  = encode(wav[0]), encode(wav[1])
T       = min(f0.shape[0], f1.shape[0])

# Batch (no KV cache)
hidden, _ = step(project(f0, f1), pos=0)
preds     = heads(hidden)
print(preds["vad_probs"].shape)    # (1, T, 2)
print(preds["fvad_probs"].shape)   # (1, T, 8)

# Streaming (KV cache, 240 ms steps)
past, pos = None, 0
for i in range(0, T, 3):
    emb          = project(f0[i:i+3], f1[i:i+3])
    hidden, past = step(emb, pos, past)
    chunk_preds  = heads(hidden)
    if past and past[0][0].shape[2] > 62:      # keep ~5 s of context (62 frames × 80 ms)
        past = [(k[:,:,-62:,:], v[:,:,-62:,:]) for k,v in past]
    pos += emb.shape[1]
```
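
At 12.5 Hz, the 62-frame cache holds just under 5 s of context (62 × 80 ms = 4.96 s), and each 3-frame step consumes 240 ms of new audio, so the quoted ~41 ms/step leaves the backbone running at a real-time factor of roughly 0.17 before Mimi encoding and the heads are accounted for.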

## Files

| File | Description |
|---|---|
| `model.safetensors` | FP32 weights: Qwen2.5-0.5B backbone + projection + all heads |
| `config.json` | Model config with `auto_map` for `AutoModel` |
| `modeling_dualturn.py` | `DualTurnModel` + `DualTurnConfig` (self-contained) |
| `onnx/model.onnx` | ONNX Qwen backbone; outputs hidden_states [B, T, 896] + KV cache |
| `onnx/model.onnx_data` | External weight data; must sit in the same directory as `model.onnx` |

## Training Data

Trained and evaluated on:


## DualTurn Model & Code

The following will be released soon:

- Final trained model checkpoint: this is an intermediate checkpoint; the final model will be released at anyreach-ai
- Training code: model architecture, training loop, and configs
- Evaluation code: benchmarks and metrics used in the paper

## Authors


## Citation

Paper: [DualTurn: Learning Turn-Taking from Dual-Channel Generative Speech Pretraining](https://arxiv.org/abs/2603.08216)

```bibtex
@misc{rajaa2026dualturnlearningturntakingdualchannel,
      title={DualTurn: Learning Turn-Taking from Dual-Channel Generative Speech Pretraining},
      author={Shangeth Rajaa},
      year={2026},
      eprint={2603.08216},
      archivePrefix={arXiv},
      primaryClass={eess.AS},
      url={https://arxiv.org/abs/2603.08216},
}
```