# QThink-Qwen3-1.7B-Tooluse

**QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts**
This model replaces explicit chain-of-thought reasoning (`<think>...</think>`) with K=6 latent forward passes through a learned projection head, achieving 48.5% exact match (EM) on Tooluse.
## How QThink Works
Instead of generating thousands of reasoning tokens, QThink:

- Processes the prompt through the base model
- Runs K=6 latent steps: each step applies a ProjectionHead (`Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)`) to the last hidden state, then feeds it back through the model via `inputs_embeds` + `past_key_values`
- Generates the answer directly from the last latent step's hidden state, with no `<think>` block needed
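The latent loop can be sketched on toy tensors to check the shapes involved. This is a minimal sketch: `toy_model` is an illustrative stand-in for the frozen base LM (in the real model each step is a full Qwen3-1.7B forward pass over `inputs_embeds` with the cached `past_key_values`); only the ProjectionHead architecture and K=6 are taken from the card.

```python
import torch
import torch.nn as nn

HIDDEN = 2048
K = 6  # number of latent steps

class ProjectionHead(nn.Module):
    """The head described above: Linear -> GELU -> Linear -> LayerNorm."""
    def __init__(self, hidden_size=HIDDEN):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

proj = ProjectionHead()
toy_model = nn.Linear(HIDDEN, HIDDEN)  # stand-in for the base model's forward pass
latent = torch.randn(1, HIDDEN)        # last hidden state of the prompt

for _ in range(K):
    latent = proj(latent)       # project the previous hidden state
    latent = toy_model(latent)  # feed it back through the "model"

print(latent.shape)  # stays (1, 2048) throughout the loop
```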
## Training: Per-Step Distillation from Multiple Rollouts

- Generate G=16 chain-of-thought rollouts per training problem using the base model
- For each rollout, extract hidden states at K=6 evenly spaced positions within the `<think>` block
- Average these hidden states across all rollouts (uniformly, including incorrect ones) at each step
- Train each latent step to match the corresponding teacher state via an L1 loss, jointly with cross-entropy on the answer
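On random toy tensors, the objective above looks roughly like this. A minimal sketch: `gamma=2.0` is the distillation weight from the model details table; the toy vocabulary size, sequence lengths, and random data are illustrative, not the actual training pipeline.

```python
import torch
import torch.nn.functional as F

G, K, T, H, V = 16, 6, 120, 2048, 64  # rollouts, latent steps, <think> length, hidden size, toy vocab

# Teacher: hidden states from G rollouts, sampled at K evenly spaced
# positions inside each rollout's <think> block
rollout_hiddens = torch.randn(G, T, H)
positions = torch.linspace(0, T - 1, K).long()          # K evenly spaced indices
teacher = rollout_hiddens[:, positions, :].mean(dim=0)  # uniform average over ALL rollouts -> (K, H)

# Student: hidden states produced by the K latent steps (toy stand-ins)
student = torch.randn(K, H, requires_grad=True)

# Per-step L1 distillation: every latent step gets direct supervision
distill_loss = F.l1_loss(student, teacher)

# Joint cross-entropy on the answer tokens (toy logits and targets)
answer_logits = torch.randn(10, V)
answer_targets = torch.randint(0, V, (10,))
ce_loss = F.cross_entropy(answer_logits, answer_targets)

gamma = 2.0  # distillation weight from the model details
loss = ce_loss + gamma * distill_loss
loss.backward()
```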
The key innovations:

- **Per-step distillation**: every latent step gets direct supervision, not just the final one
- **Uniform multi-rollout teachers**: averaging over ALL rollouts (correct + incorrect) outperforms using only correct rollouts
- **ans256 training**: training with longer answer targets, up to 256 tokens (+3pp improvement)
## Results
| Model | Exact Match | Action Accuracy |
|---|---|---|
| QThink uniform per-step (ours) | 48.5% | 67.6% |
| Base Qwen3-1.7B | 47.1% | 75.0% |
| SFT | 45.6% | 73.5% |
| QThink RW final-step (CODI) | 42.6% | 67.6% |
### Cross-Benchmark Results

QThink uniform per-step is the strongest of the four models on all three benchmarks:
| Benchmark | QThink (ours) | SFT | Base | CODI |
|---|---|---|---|---|
| GSM8k | 83.2% | 80.7% | 77.3% | 80.4% |
| MATH-500 | 43.6% | 38.2% | 33.6% | 28.0% |
| Tooluse | 48.5% | 45.6% | 47.1% | 42.6% |
## Model Details
| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-1.7B |
| Fine-tuning | LoRA (rank=32, alpha=16) |
| Distillation mode | Uniform (all rollouts) |
| Per-step distillation | Yes (K=6 steps) |
| Distillation weight (γ) | 2.0 |
| Learning rate | 0.0002 |
| Epochs | 3 |
| Batch size × grad accum × GPUs | 1 × 16 × 8 = 128 effective |
| Max answer length | 256 tokens |
| Max prompt length | 1024 tokens |
| Rollouts per problem | 16 |
| Dataset | Tooluse (from SDPO) — 4,046 train problems, 68 test problems |
## Architecture

The checkpoint contains:

- `model.safetensors`: full Qwen3-1.7B weights with merged LoRA adapters
- `projection_head.pt`: ProjectionHead weights (PyTorch state dict)
  - `mlp.0`: Linear(2048 → 2048) + bias
  - `mlp.2`: Linear(2048 → 2048) + bias
  - `mlp.3`: LayerNorm(2048)
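Those `mlp.N.*` key names can be checked against the module definition without downloading the checkpoint. A small sketch, using only the ProjectionHead layout described on this card:

```python
import torch.nn as nn

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),  # mlp.0.weight / mlp.0.bias
            nn.GELU(),                            # mlp.1: no parameters
            nn.Linear(hidden_size, hidden_size),  # mlp.2.weight / mlp.2.bias
            nn.LayerNorm(hidden_size),            # mlp.3.weight / mlp.3.bias
        )

# Wrapping the Sequential in an attribute named `mlp` produces the
# "mlp.N.*" key prefix seen in projection_head.pt
keys = sorted(ProjectionHead().state_dict().keys())
print(keys)
# ['mlp.0.bias', 'mlp.0.weight', 'mlp.2.bias', 'mlp.2.weight', 'mlp.3.bias', 'mlp.3.weight']
```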
## Usage

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

# Load model and projection head
repo_id = "LakshyAAAgrawal/QThink-Qwen3-1.7B-Tooluse"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(repo_id)

proj = ProjectionHead(2048).to(model.device).to(torch.bfloat16)
proj.load_state_dict(torch.load(
    hf_hub_download(repo_id, "projection_head.pt"), map_location=model.device
))
proj.eval()
model.eval()

# Prepare input
question = "What's the weather like in San Francisco?"
messages = [{"role": "user", "content": question}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Step 1: Process the prompt and cache its key/value states
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]

    # Step 2: K=6 latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for k in range(6):
        latent = proj(latent)
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedily decode the answer from the last latent step
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    tokens = [next_token]
    eos_id = tokenizer.eos_token_id
    for _ in range(2047):
        if next_token.item() == eos_id:
            break
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        tokens.append(next_token)

print(tokenizer.decode(torch.cat(tokens), skip_special_tokens=True))
```
## Citation

```bibtex
@misc{qthink2025,
    title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
    author={Lakshya Agrawal},
    year={2025},
    url={https://huggingface.co/LakshyAAAgrawal/QThink-Qwen3-1.7B-Tooluse}
}
```