knoxel's picture
fix: Gradio 6.x compat β€” remove type=messages, move theme to launch()
70d10ec verified
"""
BitNet b1.58 2B4T — CPU-Only Inference Explorer
================================================
A Gradio demo showcasing Microsoft's first open-source native 1-bit LLM.
All inference runs on CPU — no GPU required.
Paper: https://arxiv.org/abs/2504.12285
Model: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T
"""
import os
import time
import threading
import psutil
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ─── Configuration ───────────────────────────────────────────────────────────
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# Maximum tokens (prompt + generated reply) per request; also shown on the
# System tab.
MAX_CONTEXT = 4096
# ─── Load Model (CPU-only) ──────────────────────────────────────────────────
print(f"Loading {MODEL_ID} on CPU...")
t0 = time.time()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    low_cpu_mem_usage=True,
)
model.eval()  # inference only: disable training-mode behavior
load_time = time.time() - t0
# `proc` is kept at module level so later code can re-sample memory usage.
# RSS is the resident memory of the whole process (interpreter and imported
# libraries included), not only the model weights.
proc = psutil.Process(os.getpid())
model_mem_gb = proc.memory_info().rss / 1024**3
print(f"✓ Model loaded in {load_time:.1f}s | RSS: {model_mem_gb:.2f} GB")
# ─── System Info ─────────────────────────────────────────────────────────────
# Host and model facts for the "System" tab, computed once at import time and
# rendered as a two-column Markdown table (one "| label | value |" row each).
cpu_count = psutil.cpu_count(logical=True)
total_ram = psutil.virtual_memory().total / 1024**3
SYSTEM_INFO = "### System\n| Metric | Value |\n|---|---|\n" + "".join(
    f"| {label} | {value} |\n"
    for label, value in (
        ("CPU cores", cpu_count),
        ("Total RAM", f"{total_ram:.1f} GB"),
        ("Model RSS", f"{model_mem_gb:.2f} GB"),
        ("Load time", f"{load_time:.1f}s"),
        ("Weights", "1.58-bit ternary ({-1, 0, +1})"),
        ("Activations", "8-bit integer"),
        ("Context", f"{MAX_CONTEXT} tokens"),
    )
)
# ─── Paper benchmark table (from Table 1 of the paper) ──────────────────────
# Markdown for the "Paper Results" tab. The blank lines are required: a
# non-table line directly after a table is otherwise rendered as a table row.
# The ratios in the italic summary are derived from the rows of this table.
PAPER_TABLE = """### Published Benchmarks (from the paper)

| Benchmark | LLaMA 3.2 1B | Gemma-3 1B | Qwen2.5 1.5B | SmolLM2 1.7B | **BitNet 2B** |
|---|---|---|---|---|---|
| **Memory** | 2 GB | 1.4 GB | 2.6 GB | 3.2 GB | **0.4 GB** |
| **CPU Latency** | 48ms | 41ms | 65ms | 67ms | **29ms** |
| **Energy/token** | 0.258J | 0.186J | 0.347J | 0.425J | **0.028J** |
| ARC-Challenge | 37.8 | 38.4 | 46.7 | 43.5 | **49.9** |
| WinoGrande | 59.5 | 58.5 | 62.8 | 69.0 | **71.9** |
| GSM8K | 38.2 | 31.2 | 56.8 | 45.1 | **58.4** |
| MMLU | 45.6 | 39.9 | **60.3** | 49.2 | 53.2 |
| HumanEval+ | 31.1 | 37.2 | **50.6** | 28.0 | 38.4 |
| **Average** | 44.9 | 43.7 | **55.2** | 48.7 | 54.2 |

*Per the table above, BitNet uses 3.5–8× less memory and 6.6–15× less energy per token than the other models listed.*

> ⚠️ **Note:** This demo uses the `transformers` library, which does **not** include
> the specialized `bitnet.cpp` kernels. For the CPU latency numbers shown above,
> use [bitnet.cpp](https://github.com/microsoft/BitNet) with the GGUF weights.
"""
# ─── Architecture explainer ──────────────────────────────────────────────────
# Markdown for the "Architecture" tab. Blank lines separate paragraphs; without
# them consecutive lines are merged into a single paragraph when rendered.
ARCHITECTURE_MD = """### How BitNet b1.58 Works

```
Standard Transformer          →  BitNet b1.58
────────────────────             ────────────
FP16/BF16 weights (16 bits)   →  Ternary weights: {-1, 0, +1} (1.58 bits)
FP16 activations              →  INT8 activations (absmax per-token)
nn.Linear                     →  BitLinear (absmean quantization)
SwiGLU activation             →  Squared ReLU (ReLU²)
LayerNorm                     →  SubLN normalization
Linear-layer MatMul           →  Additions/subtractions only
```

**Key Insight:** Since weights are only -1, 0, or +1, the matrix multiplications
in the linear layers become pure addition/subtraction, followed by a rescale with
the quantization scales. This is why CPUs can run BitNet models so efficiently —
the weight matmuls no longer need floating-point multiplies.

**Training:** The model was trained **from scratch** with this quantization,
not post-training quantized. This is crucial — native 1-bit training preserves
quality far better than quantizing a pre-trained FP16 model down to 1-bit.

**3-Stage Training Pipeline:**

1. **Pre-training** on 4T tokens (text, code, synthetic math)
2. **SFT** on instruction-following datasets
3. **DPO** for alignment with human preferences
"""
# ─── Generation functions ────────────────────────────────────────────────────
def chat_respond(message, history, system_prompt, max_new_tokens, temperature, top_p):
"""Streaming chat with live token/sec stats."""
messages = [{"role": "system", "content": system_prompt}]
for item in history:
messages.append(item)
messages.append({"role": "user", "content": message})
inputs = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
# Truncate to max context
if inputs.shape[1] > MAX_CONTEXT - max_new_tokens:
inputs = inputs[:, -(MAX_CONTEXT - max_new_tokens):]
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(
input_ids=inputs,
attention_mask=torch.ones_like(inputs),
pad_token_id=tokenizer.eos_token_id,
streamer=streamer,
max_new_tokens=int(max_new_tokens),
do_sample=temperature > 0,
use_cache=True,
)
if temperature > 0:
gen_kwargs["temperature"] = float(temperature)
gen_kwargs["top_p"] = float(top_p)
t0 = time.perf_counter()
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
response = ""
tok_count = 0
for chunk in streamer:
response += chunk
tok_count += 1
elapsed = time.perf_counter() - t0
tps = tok_count / elapsed if elapsed > 0 else 0
stats = f"\n\n---\n*⏱ {tok_count} tokens · {tps:.1f} tok/s · {elapsed:.1f}s*"
yield response + stats
thread.join()
def single_benchmark(prompt, max_new_tokens):
    """Run one greedy, non-streaming generation and report timing/memory stats.

    Args:
        prompt: user prompt, wrapped with a fixed system prompt via the chat template.
        max_new_tokens: generation budget (slider value, coerced to ``int``).

    Returns:
        ``(response, stats_md)``: the decoded completion and a Markdown table
        with token counts, throughput and process-RSS measurements.
    """
    max_new_tokens = int(max_new_tokens)
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Keep only the most recent tokens so prompt + reply fit in MAX_CONTEXT.
    budget = max(1, MAX_CONTEXT - max_new_tokens)
    if input_ids.shape[1] > budget:
        input_ids = input_ids[:, -budget:]
        attention_mask = attention_mask[:, -budget:]
    input_len = input_ids.shape[1]

    rss_before = proc.memory_info().rss
    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    rss_after = proc.memory_info().rss

    n_generated = output.shape[-1] - input_len
    tps = n_generated / elapsed if elapsed > 0 else 0
    # The model may stop immediately (e.g. on EOS); avoid dividing by zero.
    ms_per_token = f"{elapsed / n_generated * 1000:.1f}ms" if n_generated > 0 else "n/a"
    response = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

    # GB/MB here are binary units (1024**3 / 1024**2 bytes), computed from bytes
    # so the delta is not a mix of GiB and decimal scaling.
    stats_md = "\n".join([
        "### Benchmark Results",
        "| Metric | Value |",
        "|---|---|",
        f"| Input tokens | {input_len} |",
        f"| Output tokens | {n_generated} |",
        f"| Total time | {elapsed:.2f}s |",
        f"| **Tokens/sec** | **{tps:.2f}** |",
        f"| Avg ms/token | {ms_per_token} |",
        f"| Memory before | {rss_before / 1024**3:.2f} GB |",
        f"| Memory after | {rss_after / 1024**3:.2f} GB |",
        f"| Memory delta | {(rss_after - rss_before) / 1024**2:.1f} MB |",
        "",
    ])
    return response, stats_md
# ─── Build Gradio UI ─────────────────────────────────────────────────────────
# Markdown banner shown above the tabs. Blank lines separate the title, the
# tagline and the link table so each renders as its own block.
HEADER = """# 🧬 BitNet b1.58 2B4T — CPU-Only Inference Explorer

**The first open-source native 1-bit LLM** by Microsoft Research.
Linear-layer weights are ternary {-1, 0, +1}, so their matrix multiplications reduce to additions and subtractions.

| | |
|---|---|
| 📄 [Paper](https://arxiv.org/abs/2504.12285) | 🤗 [Model](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) |
| 💻 [bitnet.cpp](https://github.com/microsoft/BitNet) (38K+ ⭐) | 📊 2B params · 4T training tokens · 1.1 GB weights |
"""
with gr.Blocks(
    title="BitNet b1.58 2B4T — CPU Inference Explorer",
) as demo:
    gr.Markdown(HEADER)
    with gr.Tabs():
        # ── Tab 1: Chat ──────────────────────────────────────────────────
        # additional_inputs are passed to chat_respond after (message, history),
        # in order: system_prompt, max_new_tokens, temperature, top_p.
        with gr.Tab("💬 Chat", id="chat"):
            chat = gr.ChatInterface(
                fn=chat_respond,
                description="Chat with BitNet b1.58 on CPU. Token/sec stats shown after each response.",
                additional_inputs=[
                    gr.Textbox(
                        value="You are a helpful, concise AI assistant.",
                        label="System Prompt",
                    ),
                    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature (0 = greedy)"),
                    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
                ],
                examples=[
                    ["Explain what a 1-bit LLM is in 3 sentences."],
                    ["Write a Python function to find the nth Fibonacci number."],
                    ["What are the pros and cons of running AI on CPUs vs GPUs?"],
                    ["Solve: If 3x + 7 = 22, what is x?"],
                ],
                cache_examples=False,
            )
        # ── Tab 2: Benchmark ─────────────────────────────────────────────
        with gr.Tab("📊 Benchmark", id="bench"):
            gr.Markdown("### Run a single-shot benchmark (greedy decoding)")
            with gr.Row():
                with gr.Column(scale=2):
                    bench_prompt = gr.Textbox(
                        value="Write a detailed explanation of how transformer neural networks work, covering attention mechanisms, positional encoding, and the training process.",
                        label="Prompt",
                        lines=3,
                    )
                    bench_tokens = gr.Slider(16, 512, value=128, step=16, label="Max New Tokens")
                    bench_btn = gr.Button("🚀 Run Benchmark", variant="primary")
                with gr.Column(scale=1):
                    bench_stats = gr.Markdown("*Click 'Run Benchmark' to start*")
            bench_output = gr.Textbox(label="Generated Text", lines=10, interactive=False)
            # single_benchmark returns (response, stats_md) in this order.
            bench_btn.click(
                fn=single_benchmark,
                inputs=[bench_prompt, bench_tokens],
                outputs=[bench_output, bench_stats],
            )
        # ── Tab 3: Paper Results ─────────────────────────────────────────
        with gr.Tab("📈 Paper Results", id="paper"):
            gr.Markdown(PAPER_TABLE)
        # ── Tab 4: Architecture ──────────────────────────────────────────
        with gr.Tab("🏗️ Architecture", id="arch"):
            gr.Markdown(ARCHITECTURE_MD)
        # ── Tab 5: System Info ───────────────────────────────────────────
        with gr.Tab("⚙️ System", id="sys"):
            gr.Markdown(SYSTEM_INFO)
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860; with Gradio 6.x the theme
    # is supplied to launch() instead of gr.Blocks().
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "theme": gr.themes.Soft(),
    }
    demo.launch(**launch_options)