""" BitNet b1.58 2B4T — CPU-Only Inference Explorer ================================================ A Gradio demo showcasing Microsoft's first open-source native 1-bit LLM. All inference runs on CPU — no GPU required. Paper: https://arxiv.org/abs/2504.12285 Model: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T """ import os import time import threading import psutil import torch import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer # ─── Configuration ─────────────────────────────────────────────────────────── MODEL_ID = "microsoft/bitnet-b1.58-2B-4T" MAX_CONTEXT = 4096 # ─── Load Model (CPU-only) ────────────────────────────────────────────────── print(f"Loading {MODEL_ID} on CPU...") t0 = time.time() tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True, ) model.eval() load_time = time.time() - t0 proc = psutil.Process(os.getpid()) model_mem_gb = proc.memory_info().rss / 1024**3 print(f"✓ Model loaded in {load_time:.1f}s | RSS: {model_mem_gb:.2f} GB") # ─── System Info ───────────────────────────────────────────────────────────── cpu_count = psutil.cpu_count(logical=True) total_ram = psutil.virtual_memory().total / 1024**3 SYSTEM_INFO = f"""### System | Metric | Value | |---|---| | CPU cores | {cpu_count} | | Total RAM | {total_ram:.1f} GB | | Model RSS | {model_mem_gb:.2f} GB | | Load time | {load_time:.1f}s | | Weights | 1.58-bit ternary ({{-1, 0, +1}}) | | Activations | 8-bit integer | | Context | {MAX_CONTEXT} tokens | """ # ─── Paper benchmark table (from Table 1 of the paper) ────────────────────── PAPER_TABLE = """### Published Benchmarks (from the paper) | Benchmark | LLaMA 3.2 1B | Gemma-3 1B | Qwen2.5 1.5B | SmolLM2 1.7B | **BitNet 2B** | |---|---|---|---|---|---| | **Memory** | 2 GB | 1.4 GB | 2.6 GB | 3.2 GB | **0.4 GB** | | **CPU Latency** | 48ms | 41ms | 65ms | 67ms | **29ms** | | **Energy/token** | 0.258J | 0.186J | 0.347J | 0.425J | **0.028J** | | ARC-Challenge | 37.8 | 38.4 | 46.7 | 43.5 | **49.9** | | WinoGrande | 59.5 | 58.5 | 62.8 | 69.0 | **71.9** | | GSM8K | 38.2 | 31.2 | 56.8 | 45.1 | **58.4** | | MMLU | 45.6 | 39.9 | **60.3** | 49.2 | 53.2 | | HumanEval+ | 31.1 | 37.2 | **50.6** | 28.0 | 38.4 | | **Average** | 44.9 | 43.7 | **55.2** | 48.7 | 54.2 | *BitNet uses 5-13× less memory and 6-9× less energy than comparable models.* > ⚠️ **Note:** This demo uses the `transformers` library, which does **not** include > the specialized `bitnet.cpp` kernels. For the CPU latency numbers shown above, > use [bitnet.cpp](https://github.com/microsoft/BitNet) with the GGUF weights. """ # ─── Architecture explainer ────────────────────────────────────────────────── ARCHITECTURE_MD = """### How BitNet b1.58 Works ``` Standard Transformer → BitNet b1.58 ───────────────────── ───────────────── FP16/BF16 weights (16 bits) → Ternary weights: {-1, 0, +1} (1.58 bits) FP16 activations → INT8 activations (absmax per-token) nn.Linear → BitLinear (absmean quantization) SwiGLU activation → Squared ReLU (ReLU²) LayerNorm → SubLN normalization Standard MatMul → Additions only (no multiplications!) ``` **Key Insight:** Since weights are only -1, 0, or +1, matrix multiplication becomes pure addition/subtraction. This is why CPUs can run BitNet models so efficiently — you don't need floating-point multiply hardware at all. 
**Training:** The model was trained **from scratch** with this quantization,
not post-training quantized. This is crucial — native 1-bit training preserves
quality far better than quantizing a pre-trained FP16 model down to 1 bit.

**3-Stage Training Pipeline:**
1. **Pre-training** on 4T tokens (text, code, synthetic math)
2. **SFT** on instruction-following datasets
3. **DPO** for alignment with human preferences
"""

# ─── Generation functions ────────────────────────────────────────────────────

def chat_respond(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Streaming chat with live tokens/sec stats."""
    messages = [{"role": "system", "content": system_prompt}]
    for item in history:  # history arrives as role/content dicts (type="messages")
        messages.append(item)
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    # Truncate to max context, keeping the most recent tokens
    if inputs.shape[1] > MAX_CONTEXT - max_new_tokens:
        inputs = inputs[:, -(MAX_CONTEXT - max_new_tokens):]

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0,
        use_cache=True,
    )
    if temperature > 0:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    t0 = time.perf_counter()
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    tok_count = 0
    for chunk in streamer:  # the streamer yields roughly one token per chunk
        response += chunk
        tok_count += 1
        elapsed = time.perf_counter() - t0
        tps = tok_count / elapsed if elapsed > 0 else 0
        stats = f"\n\n---\n*⏱ {tok_count} tokens · {tps:.1f} tok/s · {elapsed:.1f}s*"
        yield response + stats
    thread.join()


def single_benchmark(prompt, max_new_tokens):
    """Run a single non-streaming greedy generation with detailed stats."""
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    input_len = inputs.shape[1]

    mem_before = proc.memory_info().rss / 1024**3
    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            inputs,
            attention_mask=torch.ones_like(inputs),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    mem_after = proc.memory_info().rss / 1024**3

    n_generated = output.shape[-1] - input_len
    tps = n_generated / elapsed if elapsed > 0 else 0
    ms_per_tok = elapsed / max(n_generated, 1) * 1000  # guard against div-by-zero
    response = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

    stats_md = f"""### Benchmark Results

| Metric | Value |
|---|---|
| Input tokens | {input_len} |
| Output tokens | {n_generated} |
| Total time | {elapsed:.2f}s |
| **Tokens/sec** | **{tps:.2f}** |
| Avg ms/token | {ms_per_tok:.1f}ms |
| Memory before | {mem_before:.2f} GB |
| Memory after | {mem_after:.2f} GB |
| Memory delta | {(mem_after - mem_before) * 1000:.1f} MB |
"""
    return response, stats_md
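
# ─── Optional warm-up ────────────────────────────────────────────────────────
# A small optional sketch (safe to delete): the first generation after load
# tends to be slower on CPU (thread-pool spin-up, allocator warm-up), which
# skews the first benchmark run. One tiny greedy generation up front makes the
# tokens/sec numbers more repeatable. `_warmup` is a helper added for this
# demo, not part of the model API.

def _warmup(n_tokens: int = 8) -> None:
    ids = tokenizer("Hello", return_tensors="pt").input_ids
    with torch.no_grad():
        model.generate(
            ids,
            attention_mask=torch.ones_like(ids),
            max_new_tokens=n_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

_warmup()
print("✓ Warm-up generation complete")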
# ─── Build Gradio UI ─────────────────────────────────────────────────────────

HEADER = """# 🧬 BitNet b1.58 2B4T — CPU-Only Inference Explorer

**The first open-source native 1-bit LLM** by Microsoft Research.
All weights are ternary {-1, 0, +1} — no floating-point multiplications needed.

| | |
|---|---|
| 📄 [Paper](https://arxiv.org/abs/2504.12285) | 🤗 [Model](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) |
| 💻 [bitnet.cpp](https://github.com/microsoft/BitNet) (38K+ ⭐) | 📊 2B params · 4T training tokens · 1.1 GB weights |
"""

with gr.Blocks(
    title="BitNet b1.58 2B4T — CPU Inference Explorer",
    theme=gr.themes.Soft(),  # themes belong on Blocks, not launch()
) as demo:
    gr.Markdown(HEADER)

    with gr.Tabs():
        # ── Tab 1: Chat ──────────────────────────────────────────────────
        with gr.Tab("💬 Chat", id="chat"):
            chat = gr.ChatInterface(
                fn=chat_respond,
                type="messages",  # history as role/content dicts, as chat_respond expects
                description="Chat with BitNet b1.58 on CPU. Tokens/sec stats are shown after each response.",
                additional_inputs=[
                    gr.Textbox(
                        value="You are a helpful, concise AI assistant.",
                        label="System Prompt",
                    ),
                    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature (0 = greedy)"),
                    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
                ],
                examples=[
                    ["Explain what a 1-bit LLM is in 3 sentences."],
                    ["Write a Python function to find the nth Fibonacci number."],
                    ["What are the pros and cons of running AI on CPUs vs GPUs?"],
                    ["Solve: If 3x + 7 = 22, what is x?"],
                ],
                cache_examples=False,
            )

        # ── Tab 2: Benchmark ─────────────────────────────────────────────
        with gr.Tab("📊 Benchmark", id="bench"):
            gr.Markdown("### Run a single-shot benchmark (greedy decoding)")
            with gr.Row():
                with gr.Column(scale=2):
                    bench_prompt = gr.Textbox(
                        value=(
                            "Write a detailed explanation of how transformer neural "
                            "networks work, covering attention mechanisms, positional "
                            "encoding, and the training process."
                        ),
                        label="Prompt",
                        lines=3,
                    )
                    bench_tokens = gr.Slider(16, 512, value=128, step=16, label="Max New Tokens")
                    bench_btn = gr.Button("🚀 Run Benchmark", variant="primary")
                with gr.Column(scale=1):
                    bench_stats = gr.Markdown("*Click 'Run Benchmark' to start*")
            bench_output = gr.Textbox(label="Generated Text", lines=10, interactive=False)
            bench_btn.click(
                fn=single_benchmark,
                inputs=[bench_prompt, bench_tokens],
                outputs=[bench_output, bench_stats],
            )

        # ── Tab 3: Paper Results ─────────────────────────────────────────
        with gr.Tab("📈 Paper Results", id="paper"):
            gr.Markdown(PAPER_TABLE)

        # ── Tab 4: Architecture ──────────────────────────────────────────
        with gr.Tab("🏗️ Architecture", id="arch"):
            gr.Markdown(ARCHITECTURE_MD)

        # ── Tab 5: System Info ───────────────────────────────────────────
        with gr.Tab("⚙️ System", id="sys"):
            gr.Markdown(SYSTEM_INFO)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
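
# ─── Driving the demo programmatically (usage sketch) ────────────────────────
# Assumption-laden sketch: with the server above running on localhost:7860, the
# Benchmark tab can be called from another process via `gradio_client`. Gradio
# derives the endpoint name from the function name (`single_benchmark`); verify
# the exact route in the app's "Use via API" page before relying on it.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   text, stats = client.predict(
#       "Explain 1-bit LLMs in two sentences.",  # prompt
#       64,                                      # max new tokens
#       api_name="/single_benchmark",
#   )
#   print(stats)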