"""
BitNet b1.58 2B4T: CPU-Only Inference Explorer
==============================================
A Gradio demo showcasing Microsoft's first open-source native 1-bit LLM.
All inference runs on CPU; no GPU is required.

Paper: https://arxiv.org/abs/2504.12285
Model: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T
"""

import os
import time
import threading
import psutil
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ─── Configuration ───────────────────────────────────────────────────────────
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_CONTEXT = 4096

# ─── Load Model (CPU-only) ──────────────────────────────────────────────────
print(f"Loading {MODEL_ID} on CPU...")
t0 = time.time()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    low_cpu_mem_usage=True,
)
model.eval()

load_time = time.time() - t0
proc = psutil.Process(os.getpid())
model_mem_gb = proc.memory_info().rss / 1024**3

print(f"✓ Model loaded in {load_time:.1f}s | RSS: {model_mem_gb:.2f} GB")

# ─── System Info ─────────────────────────────────────────────────────────────
cpu_count = psutil.cpu_count(logical=True)
total_ram = psutil.virtual_memory().total / 1024**3

SYSTEM_INFO = f"""### System
| Metric | Value |
|---|---|
| CPU cores | {cpu_count} |
| Total RAM | {total_ram:.1f} GB |
| Model RSS | {model_mem_gb:.2f} GB |
| Load time | {load_time:.1f}s |
| Weights | 1.58-bit ternary ({{-1, 0, +1}}) |
| Activations | 8-bit integer |
| Context | {MAX_CONTEXT} tokens |
"""

# ─── Paper benchmark table (from Table 1 of the paper) ──────────────────────
PAPER_TABLE = """### Published Benchmarks (from the paper)

| Benchmark | LLaMA 3.2 1B | Gemma-3 1B | Qwen2.5 1.5B | SmolLM2 1.7B | **BitNet 2B** |
|---|---|---|---|---|---|
| **Memory** | 2 GB | 1.4 GB | 2.6 GB | 3.2 GB | **0.4 GB** |
| **CPU Latency** | 48ms | 41ms | 65ms | 67ms | **29ms** |
| **Energy/token** | 0.258J | 0.186J | 0.347J | 0.425J | **0.028J** |
| ARC-Challenge | 37.8 | 38.4 | 46.7 | 43.5 | **49.9** |
| WinoGrande | 59.5 | 58.5 | 62.8 | 69.0 | **71.9** |
| GSM8K | 38.2 | 31.2 | 56.8 | 45.1 | **58.4** |
| MMLU | 45.6 | 39.9 | **60.3** | 49.2 | 53.2 |
| HumanEval+ | 31.1 | 37.2 | **50.6** | 28.0 | 38.4 |
| **Average** | 44.9 | 43.7 | **55.2** | 48.7 | 54.2 |

*In this table, BitNet uses 3.5-8× less memory and roughly 6.6-15× less energy per token than the other models.*
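
The memory and energy ratios can be recomputed directly from the table rows. The snippet below is a minimal sketch: the numbers are copied from the Memory and Energy/token rows above, and the variable names are illustrative only.

```python
# Values copied from the Memory (GB) and Energy/token (J) rows of the table.
memory_gb = {
    "LLaMA 3.2 1B": 2.0,
    "Gemma-3 1B": 1.4,
    "Qwen2.5 1.5B": 2.6,
    "SmolLM2 1.7B": 3.2,
}
energy_j = {
    "LLaMA 3.2 1B": 0.258,
    "Gemma-3 1B": 0.186,
    "Qwen2.5 1.5B": 0.347,
    "SmolLM2 1.7B": 0.425,
}
bitnet_memory_gb = 0.4
bitnet_energy_j = 0.028

# How many times more memory / energy each baseline needs than BitNet.
memory_ratios = {name: gb / bitnet_memory_gb for name, gb in memory_gb.items()}
energy_ratios = {name: j / bitnet_energy_j for name, j in energy_j.items()}

for name in memory_gb:
    print(f"{name}: {memory_ratios[name]:.1f}x memory, {energy_ratios[name]:.1f}x energy")
```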

> ⚠️ **Note:** This demo uses the `transformers` library, which does **not** include
> the specialized `bitnet.cpp` kernels. For the CPU latency numbers shown above,
> use [bitnet.cpp](https://github.com/microsoft/BitNet) with the GGUF weights.
"""

# ─── Architecture explainer ──────────────────────────────────────────────────
ARCHITECTURE_MD = """### How BitNet b1.58 Works

```
Standard Transformer          →  BitNet b1.58
────────────────────             ────────────
FP16/BF16 weights (16 bits)   →  Ternary weights: {-1, 0, +1} (1.58 bits)
FP16 activations              →  INT8 activations (absmax per-token)
nn.Linear                     →  BitLinear (absmean quantization)
SwiGLU activation             →  Squared ReLU (ReLU²)
LayerNorm                     →  SubLN normalization
Standard MatMul               →  Additions only (no multiplications!)
```
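
The two quantizers named in the table can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the model's actual code: the function names, the `eps` value, and the toy inputs are all hypothetical.

```python
# Illustrative sketch only: plain-Python versions of the two quantizers.
# Function names, eps, and the sample values are assumptions for this example.

def absmean_ternary(weights, eps=1e-5):
    # Scale by the mean absolute weight, then round and clip to {-1, 0, +1}.
    flat = [w for row in weights for w in row]
    scale = sum(abs(w) for w in flat) / len(flat) + eps
    quantized = [[max(-1, min(1, round(w / scale))) for w in row] for row in weights]
    return quantized, scale

def absmax_int8(activations, eps=1e-5):
    # Per-token scaling: the largest magnitude in the token maps to 127.
    scale = 127.0 / max(max(abs(a) for a in activations), eps)
    quantized = [max(-128, min(127, round(a * scale))) for a in activations]
    return quantized, scale

w_q, w_scale = absmean_ternary([[0.9, -0.05, -1.3], [0.02, 0.6, -0.4]])
x_q, x_scale = absmax_int8([0.5, -2.0, 1.5])
print(w_q)  # [[1, 0, -1], [0, 1, -1]]
print(x_q)  # [32, -127, 95]
```

Small weights collapse to 0 and large ones saturate at -1 or +1, while the scales are kept so the layer output can be mapped back to real values.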

**Key Insight:** Since weights are only -1, 0, or +1, the inner loop of each matrix
multiplication reduces to adding or subtracting INT8 activations, followed by a rescale
of the result. This is why CPUs can run BitNet models so efficiently: the bulk of the
work needs no floating-point multiply hardware.
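
As a rough illustration of that idea, the sketch below computes a ternary matrix-vector product with only additions and subtractions in the inner loop. The function name, the toy inputs, and the two scale values are hypothetical. Optimized kernels such as those in bitnet.cpp operate on packed weights rather than Python loops.

```python
# Illustrative sketch: ternary matrix-vector product without multiplications
# in the inner loop. All names and sample values are assumptions.

def ternary_matvec(w_ternary, x_int8):
    out = []
    for row in w_ternary:
        acc = 0
        for w, x in zip(row, x_int8):
            if w == 1:
                acc += x  # +1 weight: add the activation
            elif w == -1:
                acc -= x  # -1 weight: subtract the activation
            # 0 weight: the activation is skipped
        out.append(acc)
    return out

w = [[1, 0, -1], [0, 1, -1]]  # ternary weights
x = [32, -127, 95]            # INT8 activations
y_int = ternary_matvec(w, x)
print(y_int)  # [-63, -222]

# Assumed example scales; rescaling costs one float operation per output value.
w_scale, x_scale = 0.545, 63.5
y = [v * w_scale / x_scale for v in y_int]
```

The integer accumulator holds the exact result of the quantized product. Only the final rescale touches floating point.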

**Training:** The model was trained **from scratch** with this quantization,
rather than quantized after training. This is crucial: native 1-bit training preserves
quality far better than quantizing a pre-trained FP16 model down to 1-bit.

**3-Stage Training Pipeline:**
1. **Pre-training** on 4T tokens (text, code, synthetic math)
2. **SFT** on instruction-following datasets
3. **DPO** for alignment with human preferences
"""

# ─── Generation functions ────────────────────────────────────────────────────

def chat_respond(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Streaming chat with live token/sec stats."""
    messages = [{"role": "system", "content": system_prompt}]
    for item in history:
        messages.append(item)
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

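    # Keep only the most recent tokens so prompt + generation fits in MAX_CONTEXT.
    # Note: left-truncation can drop the system prompt in long conversations.
    prompt_budget = MAX_CONTEXT - int(max_new_tokens)
    if inputs.shape[1] > prompt_budget:
        inputs = inputs[:, -prompt_budget:]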

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0,
        use_cache=True,
    )
    if temperature > 0:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    t0 = time.perf_counter()
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    tok_count = 0
    for chunk in streamer:
        response += chunk
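        # The streamer yields decoded text chunks (roughly words), not single
        # tokens, so this count and the tok/s figure are approximations.
        tok_count += 1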
        elapsed = time.perf_counter() - t0
        tps = tok_count / elapsed if elapsed > 0 else 0
        stats = f"\n\n---\n*⏱ {tok_count} tokens · {tps:.1f} tok/s · {elapsed:.1f}s*"
        yield response + stats

    thread.join()


def single_benchmark(prompt, max_new_tokens):
    """Run a single non-streaming generation with detailed stats."""
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    input_len = inputs.shape[1]

    mem_before = proc.memory_info().rss / 1024**3

    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            inputs,
            attention_mask=torch.ones_like(inputs),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0

    mem_after = proc.memory_info().rss / 1024**3
    n_generated = output.shape[-1] - input_len
    tps = n_generated / elapsed if elapsed > 0 else 0

    response = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

    stats_md = f"""### Benchmark Results

| Metric | Value |
|---|---|
| Input tokens | {input_len} |
| Output tokens | {n_generated} |
| Total time | {elapsed:.2f}s |
| **Tokens/sec** | **{tps:.2f}** |
| Avg ms/token | {(elapsed / max(n_generated, 1) * 1000):.1f}ms |
| Memory before | {mem_before:.2f} GB |
| Memory after | {mem_after:.2f} GB |
| Memory delta | {(mem_after - mem_before) * 1024:.1f} MB |
"""
    return response, stats_md


# ─── Build Gradio UI ─────────────────────────────────────────────────────────

HEADER = """# 🧬 BitNet b1.58 2B4T: CPU-Only Inference Explorer

**The first open-source native 1-bit LLM** by Microsoft Research.
The linear-layer weights are ternary {-1, 0, +1}, so their matrix products need no floating-point multiplications.

| | |
|---|---|
| 📄 [Paper](https://arxiv.org/abs/2504.12285) | 🤗 [Model](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) |
| 💻 [bitnet.cpp](https://github.com/microsoft/BitNet) (38K+ ⭐) | 📊 2B params · 4T training tokens · 1.1 GB weights |
"""

with gr.Blocks(
    title="BitNet b1.58 2B4T: CPU Inference Explorer",
) as demo:

    gr.Markdown(HEADER)

    with gr.Tabs():
        # ── Tab 1: Chat ──────────────────────────────────────────────────
        with gr.Tab("💬 Chat", id="chat"):
            chat = gr.ChatInterface(
                fn=chat_respond,
                description="Chat with BitNet b1.58 on CPU. Token/sec stats shown after each response.",
                additional_inputs=[
                    gr.Textbox(
                        value="You are a helpful, concise AI assistant.",
                        label="System Prompt",
                    ),
                    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature (0 = greedy)"),
                    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
                ],
                examples=[
                    ["Explain what a 1-bit LLM is in 3 sentences."],
                    ["Write a Python function to find the nth Fibonacci number."],
                    ["What are the pros and cons of running AI on CPUs vs GPUs?"],
                    ["Solve: If 3x + 7 = 22, what is x?"],
                ],
                cache_examples=False,
            )

        # ── Tab 2: Benchmark ─────────────────────────────────────────────
        with gr.Tab("📊 Benchmark", id="bench"):
            gr.Markdown("### Run a single-shot benchmark (greedy decoding)")
            with gr.Row():
                with gr.Column(scale=2):
                    bench_prompt = gr.Textbox(
                        value="Write a detailed explanation of how transformer neural networks work, covering attention mechanisms, positional encoding, and the training process.",
                        label="Prompt",
                        lines=3,
                    )
                    bench_tokens = gr.Slider(16, 512, value=128, step=16, label="Max New Tokens")
                    bench_btn = gr.Button("🚀 Run Benchmark", variant="primary")
                with gr.Column(scale=1):
                    bench_stats = gr.Markdown("*Click 'Run Benchmark' to start*")

            bench_output = gr.Textbox(label="Generated Text", lines=10, interactive=False)
            bench_btn.click(
                fn=single_benchmark,
                inputs=[bench_prompt, bench_tokens],
                outputs=[bench_output, bench_stats],
            )

        # ── Tab 3: Paper Results ─────────────────────────────────────────
        with gr.Tab("📈 Paper Results", id="paper"):
            gr.Markdown(PAPER_TABLE)

        # ── Tab 4: Architecture ──────────────────────────────────────────
        with gr.Tab("🏗️ Architecture", id="arch"):
            gr.Markdown(ARCHITECTURE_MD)

        # ── Tab 5: System Info ───────────────────────────────────────────
        with gr.Tab("⚙️ System", id="sys"):
            gr.Markdown(SYSTEM_INFO)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft())