knoxel committed · verified
Commit 541bc33 · Parent(s): aa994b2

Upload app.py

Files changed (1):
  1. app.py +285 -0

app.py ADDED
@@ -0,0 +1,285 @@
"""
BitNet b1.58 2B4T — CPU-Only Inference Explorer
================================================
A Gradio demo showcasing Microsoft's first open-source native 1-bit LLM.
All inference runs on CPU — no GPU required.

Paper: https://arxiv.org/abs/2504.12285
Model: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T
"""

import os
import time
import threading
import psutil
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ─── Configuration ───────────────────────────────────────────────────────────
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_CONTEXT = 4096

# ─── Load Model (CPU-only) ──────────────────────────────────────────────────
print(f"Loading {MODEL_ID} on CPU...")
t0 = time.time()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    low_cpu_mem_usage=True,
)
model.eval()
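# NOTE: the plain `transformers` path runs without the specialized bitnet.cpp
# kernels (see the note in the Paper Results tab), so this demo shows
# correctness rather than the paper's CPU latency/energy numbers.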

load_time = time.time() - t0
proc = psutil.Process(os.getpid())
model_mem_gb = proc.memory_info().rss / 1024**3

print(f"✓ Model loaded in {load_time:.1f}s | RSS: {model_mem_gb:.2f} GB")

# ─── System Info ─────────────────────────────────────────────────────────────
cpu_count = psutil.cpu_count(logical=True)
total_ram = psutil.virtual_memory().total / 1024**3

SYSTEM_INFO = f"""### System
| Metric | Value |
|---|---|
| CPU cores | {cpu_count} |
| Total RAM | {total_ram:.1f} GB |
| Model RSS | {model_mem_gb:.2f} GB |
| Load time | {load_time:.1f}s |
| Weights | 1.58-bit ternary ({{-1, 0, +1}}) |
| Activations | 8-bit integer |
| Context | {MAX_CONTEXT} tokens |
"""

# ─── Paper benchmark table (from Table 1 of the paper) ──────────────────────
PAPER_TABLE = """### Published Benchmarks (from the paper)

| Benchmark | LLaMA 3.2 1B | Gemma-3 1B | Qwen2.5 1.5B | SmolLM2 1.7B | **BitNet 2B** |
|---|---|---|---|---|---|
| **Memory** | 2 GB | 1.4 GB | 2.6 GB | 3.2 GB | **0.4 GB** |
| **CPU Latency** | 48 ms | 41 ms | 65 ms | 67 ms | **29 ms** |
| **Energy/token** | 0.258 J | 0.186 J | 0.347 J | 0.425 J | **0.028 J** |
| ARC-Challenge | 37.8 | 38.4 | 46.7 | 43.5 | **49.9** |
| WinoGrande | 59.5 | 58.5 | 62.8 | 69.0 | **71.9** |
| GSM8K | 38.2 | 31.2 | 56.8 | 45.1 | **58.4** |
| MMLU | 45.6 | 39.9 | **60.3** | 49.2 | 53.2 |
| HumanEval+ | 31.1 | 37.2 | **50.6** | 28.0 | 38.4 |
| **Average** | 44.9 | 43.7 | **55.2** | 48.7 | 54.2 |

*BitNet uses 5-13× less memory and 6-9× less energy than comparable models.*

> ⚠️ **Note:** This demo uses the `transformers` library, which does **not** include
> the specialized `bitnet.cpp` kernels. For the CPU latency numbers shown above,
> use [bitnet.cpp](https://github.com/microsoft/BitNet) with the GGUF weights.
"""

# ─── Architecture explainer ──────────────────────────────────────────────────
ARCHITECTURE_MD = """### How BitNet b1.58 Works

```
Standard Transformer            →  BitNet b1.58
────────────────────               ─────────────
FP16/BF16 weights (16 bits)     →  Ternary weights: {-1, 0, +1} (1.58 bits)
FP16 activations                →  INT8 activations (absmax per-token)
nn.Linear                       →  BitLinear (absmean quantization)
SwiGLU activation               →  Squared ReLU (ReLU²)
LayerNorm                       →  SubLN normalization
Standard MatMul                 →  Additions only (no multiplications!)
```

**Key Insight:** Since weights are only -1, 0, or +1, matrix multiplication
becomes pure addition/subtraction. This is why CPUs can run BitNet models
so efficiently — you don't need floating-point multiply hardware at all.
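
A minimal PyTorch sketch of the two quantizers named above (illustrative
only, not the model's actual kernels):

```python
import torch

def absmean_ternary(w):
    # Weights: scale by the mean absolute value, then round-clip to {-1, 0, +1}.
    scale = w.abs().mean().clamp(min=1e-5)
    return (w / scale).round().clamp(-1, 1), scale

def absmax_int8(x):
    # Activations: per-token absmax scaling into the 8-bit range [-127, 127].
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    return (x / scale).round().clamp(-127, 127), scale
```

With ternary weights, every output element reduces to a signed sum of
activation values, so the inner loop needs no floating-point multiplies.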

**Training:** The model was trained **from scratch** with this quantization,
not post-training quantized. This is crucial — native 1-bit training preserves
quality far better than quantizing a pre-trained FP16 model down to 1-bit.

**3-Stage Training Pipeline:**
1. **Pre-training** on 4T tokens (text, code, synthetic math)
2. **SFT** on instruction-following datasets
3. **DPO** for alignment with human preferences
"""

# ─── Generation functions ────────────────────────────────────────────────────

def chat_respond(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Streaming chat with live token/sec stats."""
    messages = [{"role": "system", "content": system_prompt}]
    for item in history:
        messages.append(item)
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

    # Truncate to max context
    if inputs.shape[1] > MAX_CONTEXT - max_new_tokens:
        inputs = inputs[:, -(MAX_CONTEXT - max_new_tokens):]

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature > 0,
        use_cache=True,
    )
    if temperature > 0:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    t0 = time.perf_counter()
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    tok_count = 0
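    # Streamer chunks correspond roughly one-to-one with generated tokens,
    # so tok_count below is an approximate live token count.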
    for chunk in streamer:
        response += chunk
        tok_count += 1
        elapsed = time.perf_counter() - t0
        tps = tok_count / elapsed if elapsed > 0 else 0
        stats = f"\n\n---\n*⏱ {tok_count} tokens · {tps:.1f} tok/s · {elapsed:.1f}s*"
        yield response + stats

    thread.join()


def single_benchmark(prompt, max_new_tokens):
    """Run a single non-streaming generation with detailed stats."""
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    input_len = inputs.shape[1]

    mem_before = proc.memory_info().rss / 1024**3

    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            inputs,
            attention_mask=torch.ones_like(inputs),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0

    mem_after = proc.memory_info().rss / 1024**3
    n_generated = output.shape[-1] - input_len
    tps = n_generated / elapsed if elapsed > 0 else 0
    # Guard against a zero-length generation (e.g. immediate EOS)
    ms_per_token = elapsed / n_generated * 1000 if n_generated > 0 else 0.0

    response = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

    stats_md = f"""### Benchmark Results

| Metric | Value |
|---|---|
| Input tokens | {input_len} |
| Output tokens | {n_generated} |
| Total time | {elapsed:.2f}s |
| **Tokens/sec** | **{tps:.2f}** |
| Avg ms/token | {ms_per_token:.1f} ms |
| Memory before | {mem_before:.2f} GB |
| Memory after | {mem_after:.2f} GB |
| Memory delta | {(mem_after - mem_before) * 1024:.1f} MB |
"""
    return response, stats_md


# ─── Build Gradio UI ─────────────────────────────────────────────────────────

HEADER = """# 🧬 BitNet b1.58 2B4T — CPU-Only Inference Explorer

**The first open-source native 1-bit LLM** by Microsoft Research.
All weights are ternary {-1, 0, +1} — no floating-point multiplications needed.

| | |
|---|---|
| 📄 [Paper](https://arxiv.org/abs/2504.12285) | 🤗 [Model](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) |
| 💻 [bitnet.cpp](https://github.com/microsoft/BitNet) (38K+ ⭐) | 📊 2B params · 4T training tokens · 1.1 GB weights |
"""

with gr.Blocks(
    title="BitNet b1.58 2B4T — CPU Inference Explorer",
    theme=gr.themes.Soft(),
) as demo:

    gr.Markdown(HEADER)

    with gr.Tabs():
        # ── Tab 1: Chat ──────────────────────────────────────────────────
        with gr.Tab("💬 Chat", id="chat"):
            chat = gr.ChatInterface(
                fn=chat_respond,
                type="messages",
                description="Chat with BitNet b1.58 on CPU. Token/sec stats shown after each response.",
                additional_inputs=[
                    gr.Textbox(
                        value="You are a helpful, concise AI assistant.",
                        label="System Prompt",
                    ),
                    gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature (0 = greedy)"),
                    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
                ],
                examples=[
                    ["Explain what a 1-bit LLM is in 3 sentences."],
                    ["Write a Python function to find the nth Fibonacci number."],
                    ["What are the pros and cons of running AI on CPUs vs GPUs?"],
                    ["Solve: If 3x + 7 = 22, what is x?"],
                ],
                cache_examples=False,
            )

        # ── Tab 2: Benchmark ─────────────────────────────────────────────
        with gr.Tab("📊 Benchmark", id="bench"):
            gr.Markdown("### Run a single-shot benchmark (greedy decoding)")
            with gr.Row():
                with gr.Column(scale=2):
                    bench_prompt = gr.Textbox(
                        value="Write a detailed explanation of how transformer neural networks work, covering attention mechanisms, positional encoding, and the training process.",
                        label="Prompt",
                        lines=3,
                    )
                    bench_tokens = gr.Slider(16, 512, value=128, step=16, label="Max New Tokens")
                    bench_btn = gr.Button("🚀 Run Benchmark", variant="primary")
                with gr.Column(scale=1):
                    bench_stats = gr.Markdown("*Click 'Run Benchmark' to start*")

            bench_output = gr.Textbox(label="Generated Text", lines=10, interactive=False)
            bench_btn.click(
                fn=single_benchmark,
                inputs=[bench_prompt, bench_tokens],
                outputs=[bench_output, bench_stats],
            )

        # ── Tab 3: Paper Results ─────────────────────────────────────────
        with gr.Tab("📈 Paper Results", id="paper"):
            gr.Markdown(PAPER_TABLE)

        # ── Tab 4: Architecture ──────────────────────────────────────────
        with gr.Tab("🏗️ Architecture", id="arch"):
            gr.Markdown(ARCHITECTURE_MD)

        # ── Tab 5: System Info ───────────────────────────────────────────
        with gr.Tab("⚙️ System", id="sys"):
            gr.Markdown(SYSTEM_INFO)

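# NOTE: CPU generation can take tens of seconds per request; calling
# demo.queue() before launch would serialize concurrent users (left at
# Gradio defaults here).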
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)