QThink-Qwen3-1.7B-GSM8k

QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts

This model replaces explicit chain-of-thought (`<think>...</think>`) with 6 latent forward passes through a learned projection head, achieving 83.2% on GSM8k.

How QThink Works

Instead of generating thousands of reasoning tokens, QThink:

  1. Processes the prompt through the base model
  2. Runs K=6 latent steps: each step applies a ProjectionHead (Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)) to the last hidden state, then feeds the result back through the model via `inputs_embeds` + `past_key_values`
  3. Generates the answer directly from the final latent step's hidden state — no `<think>` block needed
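The loop in steps 1–3 can be sketched with a toy stand-in for the transformer (a single `nn.Linear` here, purely illustrative; the real pipeline calls Qwen3 with `inputs_embeds` and the prompt's KV cache):

```python
import torch
import torch.nn as nn

HIDDEN, K = 2048, 6  # hidden size, number of latent steps

# Stand-in for one forward pass of the base model returning the last hidden
# state (the real model is Qwen3, called with inputs_embeds + past_key_values
# so the prompt's KV cache is reused at every latent step).
toy_model = nn.Linear(HIDDEN, HIDDEN)

# The learned projection head described in step 2.
proj = nn.Sequential(
    nn.Linear(HIDDEN, HIDDEN),
    nn.GELU(),
    nn.Linear(HIDDEN, HIDDEN),
    nn.LayerNorm(HIDDEN),
)

latent = torch.randn(1, HIDDEN)  # final hidden state of the prompt (random here)
with torch.no_grad():
    for _ in range(K):
        latent = toy_model(proj(latent))  # project, then feed back through the model

# `latent` now conditions answer generation directly; no <think> tokens are emitted.
print(latent.shape)  # torch.Size([1, 2048])
```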

Training: Per-Step Distillation from Multiple Rollouts

  1. Generate G=16 chain-of-thought rollouts per training problem using the base model
  2. For each rollout, extract hidden states at K=6 evenly-spaced positions within the `<think>` block
  3. Average these hidden states across all rollouts (uniform — including incorrect ones) at each step
  4. Train each latent step to match the corresponding teacher state via L1 loss, jointly with cross-entropy on the answer
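Under the assumptions above, the training objective can be sketched with random tensors standing in for the real rollout hidden states:

```python
import torch
import torch.nn.functional as F

G, K, H = 16, 6, 2048  # rollouts per problem, latent steps, hidden size

# Teacher states: hidden states at K evenly-spaced <think> positions for
# each of the G rollouts (random stand-ins here).
rollout_states = torch.randn(G, K, H)

# Uniform teacher: average over ALL rollouts, correct and incorrect alike.
teacher = rollout_states.mean(dim=0)          # shape (K, H)

# Student states produced by the K latent steps (stand-in values).
student = torch.randn(K, H, requires_grad=True)

gamma = 2.0                                   # distillation weight
distill = F.l1_loss(student, teacher)         # per-step L1, averaged over steps
answer_ce = torch.tensor(0.7)                 # placeholder for answer cross-entropy
loss = answer_ce + gamma * distill
loss.backward()                               # every latent step receives a gradient
```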

The key innovations:

  • Per-step distillation: Every latent step gets direct supervision, not just the final one
  • Uniform multi-rollout teachers: Averaging over ALL rollouts (correct + incorrect) outperforms using only correct rollouts
  • ans256 training: using longer (256-token) answer targets during training (+3pp improvement)
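For illustration, one way to pick the K=6 evenly-spaced teacher positions is to take the centers of K equal buckets (a hypothetical helper; the exact indexing used in training is not specified here):

```python
def evenly_spaced_positions(start, end, k=6):
    """Return k token indices spread evenly across the half-open span [start, end)."""
    span = end - start
    return [start + round((i + 0.5) * span / k) for i in range(k)]

# e.g. a <think> block occupying token positions 10..69:
print(evenly_spaced_positions(10, 70))  # [15, 25, 35, 45, 55, 65]
```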

Results

| Model | GSM8k Accuracy |
|---|---|
| QThink uniform per-step (ours) | 83.2% |
| QThink RW per-step (ours) | 82.7% |
| QThink RW final-step (CODI) | 80.4% |
| SFT | 80.7% |
| Base Qwen3-1.7B | 77.3% |

Token Efficiency

QThink generates only answer tokens (no `<think>...</think>` block):

| Model | Accuracy | Mean Tokens |
|---|---|---|
| QThink (ours) | 83.2% | 1,025 (answer only) |
| SFT | 80.7% | 1,279 (think + answer) |
| Base Qwen3-1.7B | 77.3% | 1,399 (think + answer) |

QThink achieves better accuracy with 20% fewer total tokens than SFT.
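The 20% figure follows directly from the mean token counts in the table:

```python
qthink_tokens, sft_tokens = 1025, 1279  # mean tokens from the table above
savings = 1 - qthink_tokens / sft_tokens
print(f"{savings:.1%}")  # 19.9%
```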

Cross-Benchmark Results

QThink uniform per-step is the best model on all 3 benchmarks:

| Benchmark | QThink (ours) | SFT | Base | CODI |
|---|---|---|---|---|
| GSM8k | 83.2% | 80.7% | 77.3% | 80.4% |
| MATH-500 | 43.6% | 38.2% | 33.6% | 28.0% |
| Tooluse | 48.5% | 45.6% | 47.1% | 42.6% |

Model Details

| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-1.7B |
| Fine-tuning | LoRA (rank=32, alpha=16) |
| Distillation mode | Uniform (all rollouts) |
| Per-step distillation | Yes (K=6 steps) |
| Distillation weight (γ) | 2.0 |
| Learning rate | 0.0002 |
| Epochs | 3 |
| Batch size × grad accum × GPUs | 1 × 16 × 8 = 128 effective |
| Max answer length | 256 tokens |
| Max prompt length | 512 tokens |
| Rollouts per problem | 16 |
| Dataset | GSM8k (7,473 train problems, 1,319 test problems) |

Architecture

The checkpoint contains:

  • model.safetensors: Full Qwen3-1.7B weights with merged LoRA adapters
  • projection_head.pt: ProjectionHead weights (PyTorch state dict)
    • mlp.0: Linear(2048 → 2048) + bias
    • mlp.2: Linear(2048 → 2048) + bias
    • mlp.3: LayerNorm(2048)

  The GELU at index mlp.1 is parameter-free, so it has no entry in the state dict.

Usage

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )
    def forward(self, x):
        return self.mlp(x)

# Load model and projection head
repo_id = "LakshyAAAgrawal/QThink-Qwen3-1.7B-GSM8k"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(repo_id)
proj = ProjectionHead(2048).to(model.device).to(torch.bfloat16)
proj.load_state_dict(torch.load(
    hf_hub_download(repo_id, "projection_head.pt"), map_location=model.device
))
proj.eval()
model.eval()

# Prepare input
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
messages = [{"role": "user", "content": question}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Step 1: Process prompt
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]

    # Step 2: K=6 latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for k in range(6):
        latent = proj(latent)
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedy decode answer
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    tokens = [next_token]
    eos_id = tokenizer.eos_token_id
    for _ in range(2047):
        if next_token.item() == eos_id:
            break
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        tokens.append(next_token)

print(tokenizer.decode(torch.cat(tokens), skip_special_tokens=True))
```

Citation

```bibtex
@misc{qthink2025,
  title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
  author={Lakshya Agrawal},
  year={2025},
  url={https://huggingface.co/LakshyAAAgrawal/QThink-Qwen3-1.7B-GSM8k}
}
```