Spaces:
Running
Running
| """ | |
| BitNet b1.58 2B4T — CPU-Only Inference Explorer | |
| ================================================ | |
| A Gradio demo showcasing Microsoft's first open-source native 1-bit LLM. | |
| All inference runs on CPU — no GPU required. | |
| Paper: https://arxiv.org/abs/2504.12285 | |
| Model: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T | |
| """ | |
| import os | |
| import time | |
| import threading | |
| import psutil | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# --- Configuration -----------------------------------------------------------
# Hugging Face repo id of the checkpoint, and the total token budget
# (prompt + generated tokens) that the chat handler truncates prompts against.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_CONTEXT = 4096

# --- Load model (CPU-only) ---------------------------------------------------
# Everything below runs once at import time; the elapsed wall-clock time and
# the process RSS are kept so they can be shown in the UI later.
print(f"Loading {MODEL_ID} on CPU...")
t0 = time.time()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
)
model.eval()  # inference only

load_time = time.time() - t0

# RSS is read for the whole current process right after loading, so it covers
# the interpreter and imported libraries as well as the model itself.
proc = psutil.Process(os.getpid())
model_mem_gb = proc.memory_info().rss / 1024 ** 3
print(f"β Model loaded in {load_time:.1f}s | RSS: {model_mem_gb:.2f} GB")
# --- System info -------------------------------------------------------------
# Host facts captured once at import time.
cpu_count = psutil.cpu_count(logical=True)
total_ram = psutil.virtual_memory().total / 1024 ** 3

# Markdown table shown on the "System" tab. Built line by line; the trailing
# empty item yields the final newline.
SYSTEM_INFO = "\n".join(
    (
        "### System",
        "| Metric | Value |",
        "|---|---|",
        f"| CPU cores | {cpu_count} |",
        f"| Total RAM | {total_ram:.1f} GB |",
        f"| Model RSS | {model_mem_gb:.2f} GB |",
        f"| Load time | {load_time:.1f}s |",
        "| Weights | 1.58-bit ternary ({-1, 0, +1}) |",
        "| Activations | 8-bit integer |",
        f"| Context | {MAX_CONTEXT} tokens |",
        "",
    )
)
# --- Paper benchmark table (from Table 1 of the paper) -----------------------
# Static Markdown for the "Paper Results" tab. Blank lines separate the table,
# the caption and the note: without one after the last table row, the caption
# line would be parsed as one more table row.
PAPER_TABLE = """### Published Benchmarks (from the paper)
| Benchmark | LLaMA 3.2 1B | Gemma-3 1B | Qwen2.5 1.5B | SmolLM2 1.7B | **BitNet 2B** |
|---|---|---|---|---|---|
| **Memory** | 2 GB | 1.4 GB | 2.6 GB | 3.2 GB | **0.4 GB** |
| **CPU Latency** | 48ms | 41ms | 65ms | 67ms | **29ms** |
| **Energy/token** | 0.258J | 0.186J | 0.347J | 0.425J | **0.028J** |
| ARC-Challenge | 37.8 | 38.4 | 46.7 | 43.5 | **49.9** |
| WinoGrande | 59.5 | 58.5 | 62.8 | 69.0 | **71.9** |
| GSM8K | 38.2 | 31.2 | 56.8 | 45.1 | **58.4** |
| MMLU | 45.6 | 39.9 | **60.3** | 49.2 | 53.2 |
| HumanEval+ | 31.1 | 37.2 | **50.6** | 28.0 | 38.4 |
| **Average** | 44.9 | 43.7 | **55.2** | 48.7 | 54.2 |

*BitNet uses 5-13× less memory and 6-9× less energy than comparable models.*

> ⚠️ **Note:** This demo uses the `transformers` library, which does **not** include
> the specialized `bitnet.cpp` kernels. For the CPU latency numbers shown above,
> use [bitnet.cpp](https://github.com/microsoft/BitNet) with the GGUF weights.
"""
# --- Architecture explainer --------------------------------------------------
# Static Markdown for the "Architecture" tab. Paragraphs and the numbered list
# are separated by blank lines so they render as distinct blocks, and the
# side-by-side comparison sits in a fenced block with aligned columns.
# NOTE(review): the column separator and rule glyphs were unreadable in the
# version this was restored from; "→" and "─" are assumed — confirm against
# the original file.
ARCHITECTURE_MD = """### How BitNet b1.58 Works

```
Standard Transformer          →  BitNet b1.58
─────────────────────            ─────────────────
FP16/BF16 weights (16 bits)   →  Ternary weights: {-1, 0, +1} (1.58 bits)
FP16 activations              →  INT8 activations (absmax per-token)
nn.Linear                     →  BitLinear (absmean quantization)
SwiGLU activation             →  Squared ReLU (ReLU²)
LayerNorm                     →  SubLN normalization
Standard MatMul               →  Additions only (no multiplications!)
```

**Key Insight:** Since weights are only -1, 0, or +1, matrix multiplication
becomes pure addition/subtraction. This is why CPUs can run BitNet models
so efficiently — you don't need floating-point multiply hardware at all.

**Training:** The model was trained **from scratch** with this quantization,
not post-training quantized. This is crucial — native 1-bit training preserves
quality far better than quantizing a pre-trained FP16 model down to 1-bit.

**3-Stage Training Pipeline:**

1. **Pre-training** on 4T tokens (text, code, synthetic math)
2. **SFT** on instruction-following datasets
3. **DPO** for alignment with human preferences
"""
| # βββ Generation functions ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Leading part of the stats footer appended to every streamed reply. It is also
# the marker used to cut that footer back out of earlier assistant turns.
_STATS_MARKER = "\n\n---\n*⏱ "


def _strip_stats_footer(text):
    """Return *text* with a trailing stats footer (if any) removed."""
    cut = text.rfind(_STATS_MARKER)
    return text if cut == -1 else text[:cut]


def chat_respond(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Stream a chat reply, yielding the text so far plus live throughput stats.

    Args:
        message: The new user message.
        history: Earlier turns; each item is appended to the prompt as-is,
            except that the stats footer is removed from assistant turns whose
            content is a string.
        system_prompt: Content of the leading system message.
        max_new_tokens: Generation budget; coerced to ``int``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Yields:
        The accumulated reply followed by a footer with the chunk count,
        chunks per second and elapsed seconds.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    # Coerce once so the slice arithmetic below never sees a float.
    max_new_tokens = int(max_new_tokens)

    messages = [{"role": "system", "content": system_prompt}]
    for item in history:
        # Every reply is yielded with a stats footer attached; strip it from
        # earlier assistant turns so it is not fed back to the model as if
        # it had been part of the conversation.
        if (
            isinstance(item, dict)
            and item.get("role") == "assistant"
            and isinstance(item.get("content"), str)
        ):
            item = {**item, "content": _strip_stats_footer(item["content"])}
        messages.append(item)
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

    # Keep only the most recent prompt tokens so prompt + generation fits in
    # MAX_CONTEXT. The floor of 1 avoids a `-0` slice, which would keep all.
    prompt_budget = max(MAX_CONTEXT - max_new_tokens, 1)
    if inputs.shape[1] > prompt_budget:
        inputs = inputs[:, -prompt_budget:]

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    sampling = temperature > 0
    gen_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=sampling,
        use_cache=True,
    )
    if sampling:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    errors = []

    def _generate():
        # If generation fails, the streamer would never receive its end
        # signal and the consuming loop below would block forever. Record the
        # error and close the stream explicitly instead.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the consumer
            errors.append(exc)
            streamer.end()

    t0 = time.perf_counter()
    worker = threading.Thread(target=_generate)
    worker.start()

    response = ""
    tok_count = 0  # one streamer chunk is counted as one token
    for chunk in streamer:
        response += chunk
        tok_count += 1
        elapsed = time.perf_counter() - t0
        tps = tok_count / elapsed if elapsed > 0 else 0
        stats = f"{_STATS_MARKER}{tok_count} tokens · {tps:.1f} tok/s · {elapsed:.1f}s*"
        yield response + stats

    worker.join()
    if errors:
        raise errors[0]
def single_benchmark(prompt, max_new_tokens):
    """Run one greedy, non-streaming generation and report timing/memory stats.

    Args:
        prompt: User message; it is sent after a fixed system message.
        max_new_tokens: Generation budget; coerced to ``int``.

    Returns:
        A ``(response, stats_md)`` tuple: the decoded completion (prompt
        tokens excluded) and a Markdown table with token counts, timing and
        process-RSS figures.
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    input_len = inputs.shape[1]

    # RSS of the whole process (via the module-level `proc` handle), in GiB.
    mem_before = proc.memory_info().rss / 1024**3
    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            inputs,
            attention_mask=torch.ones_like(inputs),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    mem_after = proc.memory_info().rss / 1024**3

    # The output holds prompt + completion; everything past the prompt is new.
    n_generated = output.shape[-1] - input_len
    tps = n_generated / elapsed if elapsed > 0 else 0
    # Guard the per-token average so an empty completion cannot divide by zero.
    ms_per_token = elapsed / n_generated * 1000 if n_generated > 0 else 0.0
    # The RSS figures are in GiB (bytes / 1024**3), so the factor to MB is
    # 1024, not 1000.
    mem_delta_mb = (mem_after - mem_before) * 1024

    response = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
    stats_md = f"""### Benchmark Results
| Metric | Value |
|---|---|
| Input tokens | {input_len} |
| Output tokens | {n_generated} |
| Total time | {elapsed:.2f}s |
| **Tokens/sec** | **{tps:.2f}** |
| Avg ms/token | {ms_per_token:.1f}ms |
| Memory before | {mem_before:.2f} GB |
| Memory after | {mem_after:.2f} GB |
| Memory delta | {mem_delta_mb:.1f} MB |
"""
    return response, stats_md
| # βββ Build Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Markdown banner shown at the top of the page (passed to gr.Markdown as the
# first element of the Blocks layout): title, one-line pitch, and a two-column
# table of links (paper, model card, bitnet.cpp) with headline model facts.
# NOTE(review): several glyphs in this literal look mis-encoded ("β", "Β·",
# "π…"); presumably an em dash, a middle dot and emoji — confirm against the
# original file before changing the string.
| HEADER = """# 𧬠BitNet b1.58 2B4T β CPU-Only Inference Explorer | |
| **The first open-source native 1-bit LLM** by Microsoft Research. | |
| All weights are ternary {-1, 0, +1} β no floating-point multiplications needed. | |
| | | | | |
| |---|---| | |
| | π [Paper](https://arxiv.org/abs/2504.12285) | π€ [Model](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) | | |
| | π» [bitnet.cpp](https://github.com/microsoft/BitNet) (38K+ β) | π 2B params Β· 4T training tokens Β· 1.1 GB weights | | |
| """ | |
# Top-level Gradio layout: the HEADER banner followed by five tabs — chat,
# benchmark, paper results, architecture and system info.
# NOTE(review): indentation is not preserved in this listing, so the exact
# nesting of the `with` blocks below (e.g. which container holds
# `bench_output`) must be confirmed against the original file.
| with gr.Blocks( | |
| title="BitNet b1.58 2B4T β CPU Inference Explorer", | |
| ) as demo: | |
| gr.Markdown(HEADER) | |
| with gr.Tabs(): | |
| # ββ Tab 1: Chat ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("π¬ Chat", id="chat"): | |
# Streaming chat backed by chat_respond. The additional inputs are listed in
# the same order as chat_respond's parameters after (message, history):
# system prompt, max new tokens, temperature, top-p.
| chat = gr.ChatInterface( | |
| fn=chat_respond, | |
| description="Chat with BitNet b1.58 on CPU. Token/sec stats shown after each response.", | |
| additional_inputs=[ | |
| gr.Textbox( | |
| value="You are a helpful, concise AI assistant.", | |
| label="System Prompt", | |
| ), | |
| gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"), | |
| gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature (0 = greedy)"), | |
| gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"), | |
| ], | |
| examples=[ | |
| ["Explain what a 1-bit LLM is in 3 sentences."], | |
| ["Write a Python function to find the nth Fibonacci number."], | |
| ["What are the pros and cons of running AI on CPUs vs GPUs?"], | |
| ["Solve: If 3x + 7 = 22, what is x?"], | |
| ], | |
| cache_examples=False, | |
| ) | |
| # ββ Tab 2: Benchmark βββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("π Benchmark", id="bench"): | |
| gr.Markdown("### Run a single-shot benchmark (greedy decoding)") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| bench_prompt = gr.Textbox( | |
| value="Write a detailed explanation of how transformer neural networks work, covering attention mechanisms, positional encoding, and the training process.", | |
| label="Prompt", | |
| lines=3, | |
| ) | |
| bench_tokens = gr.Slider(16, 512, value=128, step=16, label="Max New Tokens") | |
| bench_btn = gr.Button("π Run Benchmark", variant="primary") | |
| with gr.Column(scale=1): | |
| bench_stats = gr.Markdown("*Click 'Run Benchmark' to start*") | |
| bench_output = gr.Textbox(label="Generated Text", lines=10, interactive=False) | |
# single_benchmark returns (response, stats_md); the outputs list maps them,
# in that order, onto the text box and the stats panel.
| bench_btn.click( | |
| fn=single_benchmark, | |
| inputs=[bench_prompt, bench_tokens], | |
| outputs=[bench_output, bench_stats], | |
| ) | |
| # ββ Tab 3: Paper Results βββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("π Paper Results", id="paper"): | |
| gr.Markdown(PAPER_TABLE) | |
| # ββ Tab 4: Architecture ββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("ποΈ Architecture", id="arch"): | |
| gr.Markdown(ARCHITECTURE_MD) | |
| # ββ Tab 5: System Info βββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("βοΈ System", id="sys"): | |
| gr.Markdown(SYSTEM_INFO) | |
if __name__ == "__main__":
    # Only serve when run as a script: listen on every interface at port 7860
    # and apply Gradio's Soft theme.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "theme": gr.themes.Soft(),
    }
    demo.launch(**launch_options)