# QThink-Qwen3-1.7B-Tooluse

**QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts**

This model replaces explicit chain-of-thought reasoning (`<think>...</think>`) with K=6 latent forward passes through a learned projection head, achieving 48.5% exact match (EM) on Tooluse.

## How QThink Works

Instead of generating thousands of reasoning tokens, QThink:

1. Processes the prompt through the base model
2. Runs K=6 latent steps: each step applies a ProjectionHead (`Linear(2048, 2048) → GELU → Linear(2048, 2048) → LayerNorm(2048)`) to the last hidden state, then feeds it back through the model via `inputs_embeds` + `past_key_values`
3. Generates the answer directly from the last latent step's hidden state; no `<think>` block is needed
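The loop in steps 1–3 can be sketched in miniature, with the transformer forward pass and the learned projection head replaced by stand-in functions (both are assumptions for illustration; the real versions are in the Usage section below):

```python
K = 6  # number of latent reasoning steps

def projection_head(h):
    # Stand-in for the learned Linear -> GELU -> Linear -> LayerNorm head.
    return [2 * x for x in h]

def model_step(h):
    # Stand-in for one base-model forward pass on inputs_embeds
    # (with the growing KV cache); returns the last hidden state.
    return [x + 1 for x in h]

latent = [0.0, 0.0]  # last hidden state after processing the prompt
for _ in range(K):
    latent = model_step(projection_head(latent))
# The answer is then decoded greedily starting from this state.
print(latent)  # [63.0, 63.0]
```

The point of the sketch is the data flow: each latent step is one projection plus one forward pass, so "reasoning" costs K forward passes instead of thousands of generated tokens.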

## Training: Per-Step Distillation from Multiple Rollouts

1. Generate G=16 chain-of-thought rollouts per training problem using the base model
2. For each rollout, extract hidden states at K=6 evenly spaced positions within the `<think>` block
3. Average these hidden states across all rollouts (uniformly, including incorrect ones) at each step
4. Train each latent step to match the corresponding teacher state via an L1 loss, jointly with cross-entropy on the answer
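Steps 2–3 can be sketched as follows. Each rollout is a list of hidden states (one per `<think>` position); the exact evenly-spaced indexing scheme is an assumption:

```python
def teacher_states(rollouts, K=6):
    """Build K per-step teacher targets from G rollouts.

    For each rollout, pick hidden states at K evenly spaced positions
    inside its <think> block, then average across ALL rollouts
    (correct and incorrect alike) at each step.
    """
    G = len(rollouts)
    dim = len(rollouts[0][0])
    teachers = []
    for k in range(K):
        avg = [0.0] * dim
        for states in rollouts:
            # k-th of K evenly spaced positions in this rollout
            pos = (k * (len(states) - 1)) // (K - 1)
            for d in range(dim):
                avg[d] += states[pos][d] / G
        teachers.append(avg)
    return teachers
```

Because averaging is uniform, rollouts are never filtered by correctness; per the results above, this outperforms averaging only the correct ones.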

The key innovations:

- **Per-step distillation**: every latent step gets direct supervision, not just the final one
- **Uniform multi-rollout teachers**: averaging over ALL rollouts (correct and incorrect) outperforms using only correct rollouts
- **ans256 training**: training with longer answer targets (256 tokens) gives a +3pp improvement
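Putting the objective together: the answer cross-entropy is combined with the per-step L1 distillation term, weighted by γ = 2.0 (from the Model Details table). The mean reduction over steps is an assumption:

```python
def qthink_loss(ce, l1_per_step, gamma=2.0):
    # Total loss = answer cross-entropy + gamma * per-step L1 term.
    # Averaging the K per-step L1 losses is an assumed reduction.
    return ce + gamma * sum(l1_per_step) / len(l1_per_step)

print(qthink_loss(1.0, [0.5, 1.5]))  # 3.0
```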

## Results

| Model | Exact Match | Action Accuracy |
|---|---|---|
| QThink uniform per-step (ours) | 48.5% | 67.6% |
| Base Qwen3-1.7B | 47.1% | 75.0% |
| SFT | 45.6% | 73.5% |
| QThink RW final-step (CODI) | 42.6% | 67.6% |

## Cross-Benchmark Results

QThink uniform per-step is the best model on all three benchmarks:

| Benchmark | QThink (ours) | SFT | Base | CODI |
|---|---|---|---|---|
| GSM8k | 83.2% | 80.7% | 77.3% | 80.4% |
| MATH-500 | 43.6% | 38.2% | 33.6% | 28.0% |
| Tooluse | 48.5% | 45.6% | 47.1% | 42.6% |

## Model Details

| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-1.7B |
| Fine-tuning | LoRA (rank=32, alpha=16) |
| Distillation mode | Uniform (all rollouts) |
| Per-step distillation | Yes (K=6 steps) |
| Distillation weight (γ) | 2.0 |
| Learning rate | 0.0002 |
| Epochs | 3 |
| Batch size × grad accum × GPUs | 1 × 16 × 8 = 128 effective |
| Max answer length | 256 tokens |
| Max prompt length | 1024 tokens |
| Rollouts per problem | 16 |
| Dataset | Tooluse (from SDPO): 4,046 train problems, 68 test problems |

## Architecture

The checkpoint contains:

- `model.safetensors`: full Qwen3-1.7B weights with merged LoRA adapters
- `projection_head.pt`: ProjectionHead weights (PyTorch state dict)
  - `mlp.0`: Linear(2048 → 2048) + bias
  - `mlp.2`: Linear(2048 → 2048) + bias
  - `mlp.3`: LayerNorm(2048)
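Note that the key indices skip `mlp.1`: in an `nn.Sequential`, parameter keys follow module indices, and GELU (index 1) has no parameters. A small pure-Python sketch of the expected key layout:

```python
# Modules of the ProjectionHead's nn.Sequential, in order.
layers = ["Linear", "GELU", "Linear", "LayerNorm"]

# GELU contributes no parameters, so its index never appears in the
# state dict; the parameterized modules keep their Sequential indices.
param_keys = [
    f"mlp.{i}.{p}"
    for i, kind in enumerate(layers)
    if kind != "GELU"
    for p in ("weight", "bias")
]
print(param_keys)
# ['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias',
#  'mlp.3.weight', 'mlp.3.bias']
```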

## Usage

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

# Load model and projection head
repo_id = "LakshyAAAgrawal/QThink-Qwen3-1.7B-Tooluse"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(repo_id)
proj = ProjectionHead(2048).to(model.device).to(torch.bfloat16)
proj.load_state_dict(torch.load(
    hf_hub_download(repo_id, "projection_head.pt"), map_location=model.device
))
proj.eval()
model.eval()

# Prepare input
question = "What's the weather like in San Francisco?"
messages = [{"role": "user", "content": question}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Step 1: Process prompt
    out = model(**inputs, output_hidden_states=True, use_cache=True)
    past_kv = out.past_key_values
    latent = out.hidden_states[-1][:, -1, :]

    # Step 2: K=6 latent reasoning steps
    mask = inputs["attention_mask"].clone()
    for k in range(6):
        latent = proj(latent)
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(inputs_embeds=latent.unsqueeze(1), attention_mask=mask,
                    past_key_values=past_kv, output_hidden_states=True, use_cache=True)
        past_kv = out.past_key_values
        latent = out.hidden_states[-1][:, -1, :]

    # Step 3: Greedy decode answer
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    tokens = [next_token]
    eos_id = tokenizer.eos_token_id
    for _ in range(2047):
        if next_token.item() == eos_id:
            break
        mask = torch.cat([mask, mask.new_ones(1, 1)], dim=1)
        out = model(input_ids=next_token.unsqueeze(0), attention_mask=mask,
                    past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        tokens.append(next_token)

print(tokenizer.decode(torch.cat(tokens), skip_special_tokens=True))
```

## Citation

```bibtex
@misc{qthink2025,
  title={QThink: Parallel Latent Reasoning via Per-Step Distillation of Multiple Rollouts},
  author={Lakshya Agrawal},
  year={2025},
  url={https://huggingface.co/LakshyAAAgrawal/QThink-Qwen3-1.7B-Tooluse}
}
```