---
license: apache-2.0
library_name: transformers
base_model: Qwen/Qwen3-4B
pipeline_tag: text-generation
tags:
- abstract-cot
- latent-reasoning
- math-reasoning
- qwen3
datasets:
- HuggingFaceH4/MATH-500
- allenai/Dolci-Think-SFT-7B
---
# Qwen3-4B-AbstractCoT-warmup
Qwen3-4B fine-tuned with the **Abstract Chain-of-Thought (Abstract-CoT)** warm-up procedure from "[Thinking Without Words: Efficient Latent Reasoning with Abstract Chain-of-Thought](https://arxiv.org/abs/2604.22709v2)" (Ramji, Naseem, Fernandez Astudillo, IBM Research AI, 2026). The model is taught to compress its reasoning into a short sequence (~16–22 tokens) drawn from a reserved 64-symbol *abstract vocabulary* `V_abs = {<TOKEN_A>, …, <TOKEN_BL>}`, used as a discrete latent scratchpad before emitting the answer.
```
prompt ─► <beginabstract> z_1 ... z_m <endabstract> answer
└─────── z̃ ∈ V_abs^m, m ≤ 128 ───────┘
```
This covers only the SFT (warm-up) half of the paper; there is no RL stage. The matching comparison row is the paper's "Abstract-CoT (Warm-up)" line in Table 1.
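
A generation in this format can be split back into its trace and answer with a few lines of Python. This is an illustrative sketch, not code from the paper or this repo: it assumes the delimiters and V_abs symbols appear literally in the decoded text (e.g. when the trace is decoded with `skip_special_tokens=False`, as in the vLLM example), and `V_ABS` and `split_abstract_output` are hypothetical names.

```python
import re
from string import ascii_uppercase

# V_abs: <TOKEN_A> ... <TOKEN_Z>, then <TOKEN_AA> ... <TOKEN_AZ>, <TOKEN_BA> ... <TOKEN_BL> (64 names)
V_ABS = [f"<TOKEN_{c}>" for c in ascii_uppercase] + [
    f"<TOKEN_{a}{b}>" for a in "AB" for b in ascii_uppercase
][:38]
V_ABS_SET = set(V_ABS)
MAX_ABSTRACT_LEN = 128  # m <= 128 in the diagram

_TRACE_RE = re.compile(r"<beginabstract>(.*?)<endabstract>(.*)", re.DOTALL)
_TOKEN_RE = re.compile(r"<TOKEN_[A-Z]{1,2}>")


def split_abstract_output(text: str) -> tuple[list[str], str]:
    """Split '<beginabstract> z_1 ... z_m <endabstract> answer' into (trace, answer)."""
    match = _TRACE_RE.search(text)
    if match is None:
        raise ValueError("no <beginabstract> ... <endabstract> span found")
    body, answer = match.groups()
    trace = _TOKEN_RE.findall(body)
    # Anything left between the delimiters besides whitespace is not a V_abs symbol.
    leftover = _TOKEN_RE.sub("", body).strip()
    if leftover:
        raise ValueError(f"non-V_abs content inside trace: {leftover!r}")
    unknown = [t for t in trace if t not in V_ABS_SET]
    if unknown:
        raise ValueError(f"symbols outside V_abs: {unknown}")
    if len(trace) > MAX_ABSTRACT_LEN:
        raise ValueError(f"trace length {len(trace)} exceeds {MAX_ABSTRACT_LEN}")
    return trace, answer.strip()
```

For example, `split_abstract_output("<beginabstract><TOKEN_C><TOKEN_AF><endabstract>\n42")` returns `(["<TOKEN_C>", "<TOKEN_AF>"], "42")`, while a name such as `<TOKEN_BM>` (past the 64th symbol) is rejected.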
## Headline result
| Model | MATH-500 accuracy (%) | Mean tokens |
|---|---|---|
| Paper Baseline (Qwen3-4B verbal CoT) | 83.2 | 1087 |
| **Our Baseline** (Qwen3-4B verbal CoT, re-run on our 2× A100 setup) | 84.60 | 1045 |
| Paper Abstract-CoT Warm-up | 86.2 | 168 |
| **This model** (T=3 policy-iteration rounds, N=5k, 1 epoch per phase, LoRA, seq 8k) | **72.00** | **432** |

We attribute the accuracy gap to the paper's 86.2 to three differences: reduced data scale (5k vs 600k examples), LoRA instead of full fine-tuning, and 1 instead of 3 epochs per phase. See `docs/20260511_reader.md` for a full discussion.
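
The trade-off in the table is easier to compare as relative numbers. This short calculation uses only the values reported above; `tradeoff` is a hypothetical helper name.

```python
def tradeoff(base_acc: float, base_tokens: float, acc: float, tokens: float) -> tuple[float, float]:
    """Return (accuracy change in points, % reduction in mean tokens) versus a baseline."""
    return acc - base_acc, 100.0 * (1.0 - tokens / base_tokens)


# Values copied from the headline table (MATH-500 accuracy in %, mean generated tokens).
paper_delta, paper_saving = tradeoff(83.2, 1087, 86.2, 168)
ours_delta, ours_saving = tradeoff(84.60, 1045, 72.00, 432)

print(f"Paper warm-up vs paper baseline: {paper_delta:+.1f} pts, {paper_saving:.1f}% fewer tokens")
print(f"This model vs our baseline:      {ours_delta:+.1f} pts, {ours_saving:.1f}% fewer tokens")
```

With these numbers the paper's warm-up gains 3.0 points while using about 84.5% fewer tokens, and this checkpoint loses 12.6 points while using about 58.7% fewer tokens.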
## Repository layout
```
final/ ← end-of-round-3 merged model (THE warm-up checkpoint)
round2/ ← end-of-round-2 merged model
round1/ ← end-of-round-1 merged model
adapters/ ← all 6 LoRA adapters (pi{1,2,3}_phase{A,B})
results/ ← per-example eval JSONL (baseline + abstract)
teacher_traces/ ← on-policy V_abs traces used as Phase B/A teachers
train_logs/ ← per-phase loss + LR curves (verifies cosine fix)
docs/ ← run reports (technical + reader-oriented)
```
## How it was trained
Three policy-iteration rounds, each with two phases:
- **Phase A — Bottleneck SFT.** Train on `[prompt; verbal-CoT; z̃; answer]` with the answer blocked from attending to the verbal CoT, forcing all CoT→answer signal through `z̃`.
  - Round 1: `z̃` is random V_abs tokens.
  - Rounds 2+: `z̃` is sampled on-policy from the previous round's model.
- **Phase B — Self-distillation.** Train on `[prompt; z̃; answer]` with standard causal attention, where `z̃` is now generated from the prompt alone.
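
The difference between the two phases comes down mostly to the attention mask. The sketch below builds a boolean mask (`mask[i][j]` is True when position `i` may attend to position `j`) for a Phase A example: causal everywhere, except that answer positions cannot see verbal-CoT positions, so the CoT can only reach the answer through `z̃`. Phase B has no CoT segment and uses the plain causal mask. This is an illustration, not the authors' training code: segment lengths, the names `build_mask` and `random_trace`, folding the `<beginabstract>`/`<endabstract>` delimiters into the `z̃` segment, and letting `z̃` attend to the CoT are all assumptions.

```python
import random
from string import ascii_uppercase

V_ABS = [f"<TOKEN_{c}>" for c in ascii_uppercase] + [
    f"<TOKEN_{a}{b}>" for a in "AB" for b in ascii_uppercase
][:38]


def build_mask(seg_lens: dict[str, int], block_answer_from_cot: bool) -> list[list[bool]]:
    """Causal mask over consecutive segments; optionally hide 'cot' keys from 'answer' queries.

    seg_lens is an ordered mapping such as {"prompt": 4, "cot": 6, "z": 3, "answer": 2}.
    mask[i][j] is True when query position i may attend to key position j.
    """
    labels = [name for name, n in seg_lens.items() for _ in range(n)]
    total = len(labels)
    mask = [[j <= i for j in range(total)] for i in range(total)]  # standard causal
    if block_answer_from_cot:
        for i, query_seg in enumerate(labels):
            for j, key_seg in enumerate(labels):
                if query_seg == "answer" and key_seg == "cot":
                    mask[i][j] = False  # CoT -> answer signal must pass through z
    return mask


def random_trace(m: int, rng: random.Random) -> list[str]:
    """Round-1 z: m symbols drawn uniformly at random from V_abs."""
    return [rng.choice(V_ABS) for _ in range(m)]


# Phase A: [prompt; verbal-CoT; z; answer] with the bottleneck mask.
phase_a = build_mask({"prompt": 4, "cot": 6, "z": 3, "answer": 2}, block_answer_from_cot=True)
# Phase B: [prompt; z; answer] with plain causal attention (no CoT segment).
phase_b = build_mask({"prompt": 4, "z": 3, "answer": 2}, block_answer_from_cot=False)
z_round1 = random_trace(3, random.Random(0))
```

In `phase_a`, answer rows still see the prompt, the `z̃` positions, and earlier answer tokens; only the six CoT columns are masked out.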
Training config:
- Base: `Qwen/Qwen3-4B`, extended with V_abs (M=64) + `<beginabstract>` + `<endabstract>` (151 669 → 151 735 tokens).
- LoRA r=32, α=64 on the attention and MLP projections. The embedding table and LM head are trained fully, so the new abstract-vocab rows can move freely. 842.9 M of 4.86 B parameters are trainable (17.3%).
- Data: 5 000 examples from `allenai/Dolci-Think-SFT-7B`, filtered to assistant messages with `<think>` blocks ≥ 200 chars.
- max_len 8192, batch 32, lr 1e-4, cosine schedule, 5% warmup.
- 2× A100-SXM4-80GB, ~11 hours wall-clock.
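
Two of the numbers above can be sanity-checked with plain arithmetic. The sketch assumes the public `Qwen/Qwen3-4B` config values (hidden size 2560, 36 layers, MLP width 9728, 32 query heads and 8 KV heads of dimension 128), LoRA on all seven attention and MLP projections, and the embedding table and LM head counted as two separate full matrices; under those assumptions it reproduces the 842.9 M trainable parameters. `lr_at` is a generic linear-warmup plus cosine-decay schedule (decaying to zero, warmup steps rounded up), an illustration of the configured schedule rather than the exact scheduler used. All function names are hypothetical.

```python
import math

# Assumed Qwen3-4B dimensions (public config.json, not read from this repo).
HIDDEN, LAYERS, MLP, N_Q, N_KV, HEAD_DIM = 2560, 36, 9728, 32, 8, 128
VOCAB = 151_669 + 64 + 2  # original tokenizer + V_abs + <beginabstract>/<endabstract>


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A LoRA pair adds A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


def trainable_params(r: int = 32) -> int:
    q, kv = N_Q * HEAD_DIM, N_KV * HEAD_DIM
    per_layer = (
        lora_params(HIDDEN, q, r)            # q_proj
        + 2 * lora_params(HIDDEN, kv, r)     # k_proj, v_proj
        + lora_params(q, HIDDEN, r)          # o_proj
        + 2 * lora_params(HIDDEN, MLP, r)    # gate_proj, up_proj
        + lora_params(MLP, HIDDEN, r)        # down_proj
    )
    full_matrices = 2 * VOCAB * HIDDEN  # embedding table + LM head, trained fully
    return LAYERS * per_layer + full_matrices


def lr_at(step: int, total_steps: int, peak_lr: float = 1e-4, warmup_frac: float = 0.05) -> float:
    """Linear warmup to peak_lr, then cosine decay to 0 at total_steps."""
    warmup = math.ceil(total_steps * warmup_frac)
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


print(VOCAB, f"{trainable_params() / 1e6:.1f} M")  # 151735 842.9 M
```

For a hypothetical 157-step phase (5 000 examples at an effective batch of 32), `lr_at` warms up over 8 steps, reaches 1e-4 at step 8, and decays to 0 at step 157.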
## Using the model
### Inference (vLLM, recommended)
```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
# Download the final checkpoint
model_path = snapshot_download(
    "leapeto/Qwen3-4B-AbstractCoT-warmup",
    allow_patterns=["final/*"],
)
tok = AutoTokenizer.from_pretrained(f"{model_path}/final", trust_remote_code=True)
# Abstract-vocabulary token ids: <TOKEN_A> ... <TOKEN_Z>, then <TOKEN_AA> ... <TOKEN_BL> (64 total)
abs_tokens = []
for i in range(64):
    if i < 26:
        abs_tokens.append(f"<TOKEN_{chr(ord('A')+i)}>")
    else:
        j = i - 26
        abs_tokens.append(f"<TOKEN_{chr(ord('A')+j//26)}{chr(ord('A')+j%26)}>")
end_id = tok.convert_tokens_to_ids("<endabstract>")
abs_ids = tok.convert_tokens_to_ids(abs_tokens)
# Stage 1 may only emit V_abs symbols or the closing delimiter
allowed = list(set(abs_ids + [end_id]))
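# tensor_parallel_size=2 matches the 2-GPU setup used here; set it to your GPU count
# (a single GPU is typically enough for a 4B model in bf16)
llm = LLM(model=f"{model_path}/final", tensor_parallel_size=2,
          dtype="bfloat16", trust_remote_code=True)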
# Two-stage decode: (1) constrained abstract trace, (2) unconstrained answer
prompt = "What is the integral of x^2 from 0 to 1? Put your final answer in \\boxed{}."
messages = [
    {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
    {"role": "user", "content": prompt},
]
prefix = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
prefix += "<beginabstract>"
# Stage 1: V_abs only, stop at <endabstract>
sp1 = SamplingParams(temperature=0.7, max_tokens=128,
                     allowed_token_ids=allowed, stop_token_ids=[end_id],
                     skip_special_tokens=False)
abstract = llm.generate([prefix], sp1)[0].outputs[0].text
prompt2 = prefix + abstract + "<endabstract>\n"
# Stage 2: unconstrained answer
sp2 = SamplingParams(temperature=0.0, max_tokens=2048)
answer = llm.generate([prompt2], sp2)[0].outputs[0].text
print(answer)
```
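
MATH-500 answers are commonly scored from the last `\boxed{...}` in the output. The helper below is a minimal brace-matching extractor written for illustration; it is an assumption that the `results/` files were scored this way, and answer normalization (e.g. `\frac12` vs `1/2`) is deliberately left out. `last_boxed` is a hypothetical name.

```python
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, honoring nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    content_start = start + len("\\boxed{")
    depth = 1
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # unbalanced braces, e.g. generation cut off by max_tokens
```

Applied to the `answer` string from the two-stage decode, `last_boxed(answer)` gives the final answer text, or `None` if no complete box was produced.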
### Loading the LoRA adapters (peft)
If you want to inspect individual round outputs without downloading the merged models:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM
from huggingface_hub import snapshot_download
# You'll need the extended base model first — produce it locally via scripts/01_extend_model.sh
# OR start from one of our merged checkpoints and load a later adapter on top.
base = AutoModelForCausalLM.from_pretrained("path/to/extended/base", trust_remote_code=True)
adapter_path = snapshot_download(
    "leapeto/Qwen3-4B-AbstractCoT-warmup",
    allow_patterns=["adapters/pi3_phaseB/*"],
)
model = PeftModel.from_pretrained(base, f"{adapter_path}/adapters/pi3_phaseB")
```
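
The six adapter folders follow the `pi{1,2,3}_phase{A,B}` pattern from the repository layout. A small helper can list them in the order the phases run (round by round, Phase A before Phase B) and build matching `allow_patterns` globs for `snapshot_download`. The helper name is hypothetical, and treating folder order as training order is an assumption based on the phase descriptions above.

```python
def adapter_dirs(rounds: int = 3) -> list[str]:
    """Adapter subfolders in phase order: pi1_phaseA, pi1_phaseB, pi2_phaseA, ..."""
    return [f"adapters/pi{t}_phase{p}" for t in range(1, rounds + 1) for p in "AB"]


patterns = [f"{d}/*" for d in adapter_dirs()]
# e.g. snapshot_download("leapeto/Qwen3-4B-AbstractCoT-warmup", allow_patterns=patterns)
```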
## Files of interest
| File | Contents |
|---|---|
| `final/` | End-of-round-3 merged model. **This is the main artifact.** |
| `round1/`, `round2/` | Intermediate merged models for studying T=1 → T=2 → T=3 progression |
| `adapters/pi{1,2,3}_phase{A,B}/` | LoRA-only checkpoints from each phase |
| `results/baseline_math500.jsonl` | Qwen3-4B verbal-CoT eval (84.60% / 1045 tok) |
| `results/abstract_math500_T3_N5000.jsonl` | This model's eval (72.00% / 432 tok) |
| `train_logs/*.json` | Per-step loss + LR curves for each phase |
| `docs/20260511.md` | Technical report (full breakdown) |
| `docs/20260511_reader.md` | Reader-oriented report (concepts + reasoning) |
## Citation
```bibtex
@article{ramji2026thinking,
title={Thinking Without Words: Efficient Latent Reasoning with Abstract Chain-of-Thought},
author={Ramji, Keshav and Naseem, Tahira and Fernandez Astudillo, Ramón},
journal={arXiv preprint arXiv:2604.22709},
year={2026}
}
```