# DualTurn
Real-time dual-channel turn-taking model. Given stereo audio (user mic + agent speaker), predicts per-frame:
| Output | Shape | Description |
|---|---|---|
| `vad_probs` | `[B, T, 2]` | P(speaking now); `[:, 0]` = user, `[:, 1]` = agent |
| `fvad_probs` | `[B, T, 8]` | P(future speech) at 240/480/960/2000 ms horizons; user `0:4`, agent `4:8` |
| `eot_probs` | `[B, T, 2]` | P(end of turn) per channel |
| `bot_probs` | `[B, T, 2]` | P(beginning of turn) per channel |
| `hold_probs` | `[B, T, 2]` | P(within-turn hold/pause) per channel |
| `bc_probs` | `[B, T, 2]` | P(backchannel) per channel |
Frame rate: 12.5 Hz (80 ms per frame). Audio is resampled internally to 24 kHz.
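The two conversation sides go in as separate channels (`CH0` = user mic, `CH1` = agent speaker). If you start from two mono recordings, a minimal sketch of assembling the expected stereo input (the file names here are placeholders):

```python
import torch, torchaudio

# Placeholder paths: one mono recording per side of the conversation.
user, sr_u = torchaudio.load("user_mic.wav")    # [1, T_u]
agent, sr_a = torchaudio.load("agent_out.wav")  # [1, T_a]
assert sr_u == sr_a, "both sides must share a sample rate"

# Trim to a common length and stack: CH0 = user, CH1 = agent.
T = min(user.shape[1], agent.shape[1])
wav = torch.cat([user[:, :T], agent[:, :T]], dim=0)  # [2, T]
```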
```bash
pip install transformers torch torchaudio safetensors
```
```python
import torch, torchaudio
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "anyreach-ai/dualturn-qwen2.5-mimi-0.5B",
    trust_remote_code=True,
)
model.eval()

wav, sr = torchaudio.load("conversation.wav")  # [2, T] CH0=user, CH1=agent

with torch.no_grad():
    out = model(wav, sr=sr)

print(out.vad_probs.shape)   # [1, T, 2]
print(out.fvad_probs.shape)  # [1, T, 8]
print(out.eot_probs.shape)   # [1, T, 2]
```
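Turning per-frame probabilities into event timestamps is then just thresholding and dividing by the frame rate. A minimal sketch (the 0.5 threshold is an arbitrary illustration, not a calibrated operating point):

```python
import numpy as np

FRAME_RATE_HZ = 12.5  # one frame every 80 ms

def threshold_events(probs, channel, threshold=0.5):
    """Timestamps (s) where a per-frame probability first rises above threshold.

    probs: [T, 2] array for one batch element, e.g. out.eot_probs[0].numpy();
    the 0.5 threshold is illustrative, not a tuned operating point.
    """
    p = probs[:, channel]
    rising = np.flatnonzero((p[1:] >= threshold) & (p[:-1] < threshold)) + 1
    return rising / FRAME_RATE_HZ

user_eot_times = threshold_events(out.eot_probs[0].numpy(), channel=0)
```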
The ONNX checkpoint runs the Qwen backbone with a rolling KV cache. Pass empty `past_key_values` for full-context batch inference, or maintain the cache across steps for low-latency streaming (~41 ms/step at 5 s context).
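The chunk and context sizes in the example follow directly from the 80 ms frame size; the code below hard-codes the resulting 3 and 62:

```python
FRAME_MS = 80                       # 12.5 Hz frame rate
CHUNK_FRAMES = 240 // FRAME_MS      # = 3 frames per 240 ms streaming step
CONTEXT_FRAMES = 5000 // FRAME_MS   # = 62 frames, i.e. ~5 s of rolling KV cache
```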
```bash
pip install onnxruntime transformers torch torchaudio safetensors huggingface_hub
```
```python
import numpy as np, torch, torch.nn as nn, torchaudio, os
from transformers import MimiModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import onnxruntime as ort

# Qwen2.5-0.5B backbone: layer count, KV heads, head dim, hidden size
N_LAYERS, N_HEADS, HEAD_DIM, D = 24, 2, 64, 896
# Mimi encoder: 24 kHz mono waveform -> [T, 512] features at 12.5 Hz
mimi = MimiModel.from_pretrained("kyutai/mimi").eval()

@torch.no_grad()
def encode(wav_1d):
    x = wav_1d.unsqueeze(0).unsqueeze(0)  # [1, 1, samples]
    e = mimi.encoder(x)
    et = mimi.encoder_transformer(e.transpose(1, 2))
    if hasattr(et, "last_hidden_state"):
        et = et.last_hidden_state
    return mimi.downsample(et.transpose(1, 2)).squeeze(0).T.float().numpy()  # [T, 512]
# Projection + heads, loaded from the repo's safetensors
weights = load_file(hf_hub_download("anyreach-ai/dualturn-qwen2.5-mimi-0.5B", "model.safetensors"))

# Concatenated user+agent features (2 x 512 = 1024) -> backbone width D
proj = nn.Sequential(nn.Linear(1024, D), nn.GELU(), nn.Linear(D, D)).eval()
proj[0].weight.data = weights["mimi_projection.proj.0.weight"]
proj[0].bias.data = weights["mimi_projection.proj.0.bias"]
proj[2].weight.data = weights["mimi_projection.proj.2.weight"]
proj[2].bias.data = weights["mimi_projection.proj.2.bias"]

@torch.no_grad()
def project(f0, f1):
    return proj(torch.cat([torch.from_numpy(f0),
                           torch.from_numpy(f1)], dim=-1)).unsqueeze(0).numpy()  # [1, T, D]

def _head(key):
    h = nn.Linear(D, 1).eval()
    h.weight.data = weights[f"{key}.weight"]
    h.bias.data = weights[f"{key}.bias"]
    return h

vad0, vad1 = _head("vad_head_ch0"), _head("vad_head_ch1")
fvad = nn.Linear(D, 8).eval()
fvad.weight.data = weights["fvad_head.weight"]
fvad.bias.data = weights["fvad_head.bias"]

@torch.no_grad()
def heads(h_np):
    h = torch.from_numpy(h_np)
    return {
        "vad_probs": torch.sigmoid(torch.stack([vad0(h).squeeze(-1),
                                                vad1(h).squeeze(-1)], -1)).numpy(),
        "fvad_probs": torch.sigmoid(fvad(h)).numpy(),
    }
# ONNX needs model.onnx + model.onnx_data in the same directory, so use snapshot_download
repo_dir = snapshot_download(
    "anyreach-ai/dualturn-qwen2.5-mimi-0.5B",
    allow_patterns=["onnx/*"],
)
sess = ort.InferenceSession(
    os.path.join(repo_dir, "onnx", "model.onnx"),
    providers=["CPUExecutionProvider"],
)
def step(embeds, pos, past_kv=None):
    """One backbone pass: returns hidden states [1, T, D] and the new KV cache."""
    T, T_ctx = embeds.shape[1], (past_kv[0][0].shape[2] if past_kv else 0)
    inp = {
        # The exported graph consumes the projected Mimi features as embeddings.
        "inputs_embeds": embeds.astype(np.float32),
        "attention_mask": np.ones((1, T_ctx + T), np.int64),
        "position_ids": np.arange(pos, pos + T, dtype=np.int64).reshape(1, -1),
    }
    for i in range(N_LAYERS):
        inp[f"past_key_values.{i}.key"] = past_kv[i][0] if past_kv else np.zeros((1, N_HEADS, 0, HEAD_DIM), np.float32)
        inp[f"past_key_values.{i}.value"] = past_kv[i][1] if past_kv else np.zeros((1, N_HEADS, 0, HEAD_DIM), np.float32)
    outs = sess.run(None, inp)
    return outs[0], [(outs[1 + 2*i], outs[2 + 2*i]) for i in range(N_LAYERS)]
# Run
wav, sr = torchaudio.load("conversation.wav")
wav = torchaudio.transforms.Resample(sr, 24000)(wav)
f0, f1 = encode(wav[0]), encode(wav[1])
T = min(f0.shape[0], f1.shape[0])  # the two channels can differ by a frame

# Batch (no KV cache)
hidden, _ = step(project(f0[:T], f1[:T]), pos=0)
preds = heads(hidden)
print(preds["vad_probs"].shape)   # (1, T, 2)
print(preds["fvad_probs"].shape)  # (1, T, 8)

# Streaming (KV cache, 240 ms steps = 3 frames)
past, pos = None, 0
for i in range(0, T, 3):
    emb = project(f0[i:i+3], f1[i:i+3])
    hidden, past = step(emb, pos, past)
    chunk_preds = heads(hidden)  # per-chunk probabilities, available in real time
    if past and past[0][0].shape[2] > 62:  # keep a rolling 5 s (62-frame) context
        past = [(k[:, :, -62:, :], v[:, :, -62:, :]) for k, v in past]
    pos += emb.shape[1]
```
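To reproduce the ~41 ms/step figure on your own hardware, a rough timing harness can reuse `step` from above (a sketch; it assumes the file is longer than ~5 s, and numbers vary with CPU and thread settings):

```python
import time

# Warm up to a full 5 s (62-frame) rolling context, then time one 240 ms step.
past, pos = None, 0
for i in range(0, 63, 3):
    _, past = step(project(f0[i:i+3], f1[i:i+3]), pos, past)
    if past[0][0].shape[2] > 62:
        past = [(k[:, :, -62:, :], v[:, :, -62:, :]) for k, v in past]
    pos += 3

emb = project(f0[pos:pos+3], f1[pos:pos+3])
t0 = time.perf_counter()
step(emb, pos, past)
print(f"step latency: {(time.perf_counter() - t0) * 1e3:.1f} ms")
```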
| File | Description |
|---|---|
| `model.safetensors` | FP32 weights: Qwen2.5-0.5B backbone + projection + all heads |
| `config.json` | Model config with `auto_map` for `AutoModel` |
| `modeling_dualturn.py` | `DualTurnModel` + `DualTurnConfig` (self-contained) |
| `onnx/model.onnx` | ONNX Qwen backbone; outputs `hidden_states` `[B, T, 896]` + KV cache |
Trained and evaluated on:
The following will be released soon:
Paper: DualTurn: Learning Turn-Taking from Dual-Channel Generative Speech Pretraining
```bibtex
@misc{rajaa2026dualturnlearningturntakingdualchannel,
      title={DualTurn: Learning Turn-Taking from Dual-Channel Generative Speech Pretraining},
      author={Shangeth Rajaa},
      year={2026},
      eprint={2603.08216},
      archivePrefix={arXiv},
      primaryClass={eess.AS},
      url={https://arxiv.org/abs/2603.08216},
}
```