QThink-Qwen3-1.7B-GSM8k
QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts
This model replaces explicit chain-of-thought (`<think>...</think>`) with 6 latent forward passes through a learned projection head, achieving 83.2% on GSM8k.
How QThink Works
Instead of generating thousands of reasoning tokens, QThink:
- Processes the prompt through the base model
- Runs K=6 latent steps: each step applies a ProjectionHead (`Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)`) to the last hidden state, then feeds the result back through the model via `inputs_embeds` + `past_key_values`
- Generates the answer directly from the last latent step's hidden state, with no `<think>` block needed
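The latent loop can be sketched at the shape level as below. `model_step` is a placeholder for the base model's forward pass, not the real model; the full runnable example against the actual checkpoint is in the Usage section.

```python
import torch
import torch.nn as nn

HIDDEN = 2048  # Qwen3-1.7B hidden size
K = 6          # number of latent steps

# Same layer layout as the released ProjectionHead
proj = nn.Sequential(
    nn.Linear(HIDDEN, HIDDEN),
    nn.GELU(),
    nn.Linear(HIDDEN, HIDDEN),
    nn.LayerNorm(HIDDEN),
)

def model_step(embed):
    # Placeholder for the base model's forward pass over one injected
    # embedding (the real code passes inputs_embeds + past_key_values).
    return embed

latent = torch.randn(1, HIDDEN)  # last hidden state of the prompt
for _ in range(K):
    latent = model_step(proj(latent))  # project, then feed back through the model
print(latent.shape)  # torch.Size([1, 2048])
```

Each iteration replaces one "thinking token" with a single extra forward pass, so the reasoning budget is fixed at K passes regardless of problem difficulty.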
Training: Per-Step Distillation from Multiple Rollouts
- Generate G=16 chain-of-thought rollouts per training problem using the base model
- For each rollout, extract hidden states at K=6 evenly spaced positions within the `<think>` block
- Average these hidden states across all rollouts (uniformly, including incorrect ones) at each step
- Train each latent step to match the corresponding teacher state via L1 loss, jointly with cross-entropy on the answer
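The objective above can be sketched as follows. `teacher_states`, `student_latents`, and the answer logits are hypothetical stand-ins for tensors produced by the actual training pipeline; G=16, K=6, and γ=2.0 are taken from this card.

```python
import torch
import torch.nn.functional as F

G, K, H = 16, 6, 2048  # rollouts, latent steps, hidden size
gamma = 2.0            # distillation weight from the Model Details table

# Hypothetical stand-ins: teacher hidden states extracted at K positions
# in each of G rollouts, and the K student latent states for one problem.
teacher_states = torch.randn(G, K, H)
student_latents = torch.randn(K, H, requires_grad=True)

# Uniform multi-rollout teacher: average over ALL rollouts, correct or not
teacher = teacher_states.mean(dim=0)  # shape (K, H)

# Per-step L1 distillation: every latent step gets direct supervision
distill_loss = F.l1_loss(student_latents, teacher)

# Joint objective with cross-entropy on the answer tokens (placeholder logits)
answer_logits = torch.randn(8, 100)           # (answer_len, vocab_size)
answer_labels = torch.randint(0, 100, (8,))
loss = F.cross_entropy(answer_logits, answer_labels) + gamma * distill_loss
loss.backward()  # gradients reach every latent step, not just the final one
```

This is the distinction from final-step-only distillation (CODI): the L1 term is computed over all K steps, so intermediate latent states are anchored to the teacher trajectory rather than only the endpoint.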
The key innovations:
- Per-step distillation: Every latent step gets direct supervision, not just the final one
- Uniform multi-rollout teachers: Averaging over ALL rollouts (correct + incorrect) outperforms using only correct rollouts
- ans256 training: Training with longer answer targets (+3pp improvement)
Results
| Model | GSM8k Accuracy |
|---|---|
| QThink uniform per-step (ours) | 83.2% |
| QThink RW per-step (ours) | 82.7% |
| QThink RW final-step (CODI) | 80.4% |
| SFT | 80.7% |
| Base Qwen3-1.7B | 77.3% |
Token Efficiency
QThink generates only answer tokens (no `<think>...</think>` block):
| Model | Accuracy | Mean Tokens |
|---|---|---|
| QThink (ours) | 83.2% | 1,025 (answer only) |
| SFT | 80.7% | 1,279 (think + answer) |
| Base Qwen3-1.7B | 77.3% | 1,399 (think + answer) |
QThink achieves better accuracy with 20% fewer total tokens than SFT.
Cross-Benchmark Results
QThink uniform per-step is the best model on all 3 benchmarks:
| Benchmark | QThink (ours) | SFT | Base | CODI |
|---|---|---|---|---|
| GSM8k | 83.2% | 80.7% | 77.3% | 80.4% |
| MATH-500 | 43.6% | 38.2% | 33.6% | 28.0% |
| Tooluse | 48.5% | 45.6% | 47.1% | 42.6% |
Model Details
| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-1.7B |
| Fine-tuning | LoRA (rank=32, alpha=16) |
| Distillation mode | Uniform (all rollouts) |
| Per-step distillation | Yes (K=6 steps) |
| Distillation weight (γ) | 2.0 |
| Learning rate | 0.0002 |
| Epochs | 3 |
| Batch size × grad accum × GPUs | 1 × 16 × 8 = 128 effective |
| Max answer length | 256 tokens |
| Max prompt length | 512 tokens |
| Rollouts per problem | 16 |
| Dataset | GSM8k — 7,473 train problems, 1,319 test problems |
Architecture
The checkpoint contains:
- `model.safetensors`: Full Qwen3-1.7B weights with merged LoRA adapters
- `projection_head.pt`: ProjectionHead weights (PyTorch state dict)
  - `mlp.0`: Linear(2048 → 2048) + bias
  - `mlp.2`: Linear(2048 → 2048) + bias
  - `mlp.3`: LayerNorm(2048)
Usage
```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

# Load model and projection head
repo_id = "LakshyAAAgrawal/QThink-Qwen3-1.7B-GSM8k"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(repo_id)
proj = ProjectionHead(2048).to(model.device).to(torch.bfloat16)
proj.load_state_dict(torch.load(
    hf_hub_download(repo_id, "projection_head.pt"), map_location=model.device
))
proj.eval()
model.eval()

# Prepare input
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
messages = [{"role": "user", "content": question}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Step 1: Process the prompt and cache its key/value states
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]

    # Step 2: K=6 latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for k in range(6):
        latent = proj(latent)
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedy-decode the answer
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    tokens = [next_token]
    eos_id = tokenizer.eos_token_id
    for _ in range(2047):
        if next_token.item() == eos_id:
            break
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        tokens.append(next_token)

print(tokenizer.decode(torch.cat(tokens), skip_special_tokens=True))
```
Citation
```bibtex
@misc{qthink2025,
  title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
  author={Lakshya Agrawal},
  year={2025},
  url={https://huggingface.co/LakshyAAAgrawal/QThink-Qwen3-1.7B-GSM8k}
}
```