bonsai-webgpu / src /App.jsx
Xenova's picture
Xenova HF Staff
Upload 468 files
cbb6a01 verified
import { useState, useEffect, useRef, useCallback } from "react";
import { Streamdown } from "streamdown";
import { code } from "@streamdown/code";
import { mermaid } from "@streamdown/mermaid";
import { createMathPlugin } from "@streamdown/math";
import { cjk } from "@streamdown/cjk";
import "streamdown/styles.css";
import "katex/dist/katex.min.css";
// Models offered in the playground picker.
//  - `id` becomes `selectedId`, which is what gets posted to the worker in the
//    { type: "load" } message.
//  - `name`, `params`, `size` and `blurb` are only used as display text.
//  - Entries with `comingSoon: true` are rendered (with a ribbon) but cannot be
//    selected.
const MODELS = [
  {
    id: "1.7b",
    name: "Bonsai 1.7B",
    params: "1.7B",
    size: "290 MB",
    blurb: "Pocket-class. Built for wearables and always-on agents.",
  },
  {
    id: "4b",
    name: "Bonsai 4B",
    params: "4B",
    size: "584 MB",
    blurb: "The sweet spot. Strong reasoning at on-device latency.",
    comingSoon: true,
  },
  {
    id: "8b",
    name: "Bonsai 8B",
    params: "8B",
    size: "1.2 GB",
    blurb: "Datacenter-grade reasoning, in your browser tab.",
    comingSoon: true,
  },
];
/**
 * Render a byte count for the loading progress line.
 *
 * Falsy input (0, undefined, null, NaN) renders as "0 MB". Values of at least
 * 1e9 bytes are shown in decimal gigabytes with two decimals; anything smaller
 * is shown in whole decimal megabytes.
 *
 * @param {number} bytes - Byte count (decimal units, 1 MB = 1e6 bytes).
 * @returns {string} Human-readable size such as "290 MB" or "1.20 GB".
 */
const formatBytes = (bytes) => {
  if (!bytes) {
    return "0 MB";
  }
  const inGigabytes = bytes >= 1e9;
  return inGigabytes
    ? `${(bytes / 1e9).toFixed(2)} GB`
    : `${(bytes / 1e6).toFixed(0)} MB`;
};
// Math plugin for Streamdown. `singleDollarTextMath` presumably also enables
// `$...$` inline math (per the option name) — confirm against @streamdown/math.
const math = createMathPlugin({ singleDollarTextMath: true });
// Plugin set passed to every assistant-message <Streamdown> render. Defined at
// module scope so the same object identity is reused across renders.
const STREAMDOWN_PLUGINS = { code, mermaid, math, cjk };
// Tailwind utility classes for the triangular "prism" glyph: a triangle
// clip-path filled with a highlight plus a vertical gradient over the
// --prism-* CSS variables (defined in a stylesheet not shown here). Shared by
// the loading card and the empty-chat placeholder.
const PRISM_GLYPH_CLASS =
  "h-9 w-9 overflow-hidden opacity-90 [clip-path:polygon(50%_4%,100%_100%,0%_100%)] bg-[radial-gradient(circle_at_50%_18%,rgba(255,255,255,0.3),transparent_28%),linear-gradient(180deg,var(--prism-2)_0%,var(--prism-1)_42%,var(--prism-6)_100%)] drop-shadow-[0_0_18px_rgba(255,184,77,0.18)]";
function SendIcon() {
return (
<svg viewBox="0 0 16 16" fill="none" aria-hidden="true">
<path
d="M4 12L12 4M6 4H12V10"
stroke="currentColor"
strokeWidth="1.7"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
}
function ResetIcon() {
return (
<svg viewBox="0 0 16 16" fill="none" aria-hidden="true">
<path
d="M4.5 6.5H1.75V3.75M2.2 6.2A5.8 5.8 0 1 1 3.6 11.7"
stroke="currentColor"
strokeWidth="1.7"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
}
function StopIcon() {
return (
<svg viewBox="0 0 16 16" fill="none" aria-hidden="true">
<rect
x="4.25"
y="4.25"
width="7.5"
height="7.5"
rx="1.8"
fill="currentColor"
/>
</svg>
);
}
export default function PrismDemo() {
const canvasRef = useRef(null);
const mouseRef = useRef({ x: -9999, y: -9999, inside: false });
const cellsRef = useRef([]);
const wavesRef = useRef([]);
const rafRef = useRef(null);
const playgroundRef = useRef(null);
const inputRef = useRef(null);
const messagesRef = useRef(null);
const autoScrollRef = useRef(true);
const optimizeMessageTimeoutRef = useRef(null);
// Playground stage machine
const [stage, setStage] = useState("select"); // 'select' | 'loading' | 'chat' | 'error'
const [selectedId, setSelectedId] = useState("1.7b");
const [loadProgress, setLoadProgress] = useState(0);
const [loadMessage, setLoadMessage] = useState("Fetching weights");
const [loadLoaded, setLoadLoaded] = useState(0);
const [loadTotal, setLoadTotal] = useState(0);
const [errorMessage, setErrorMessage] = useState(null);
const messagesHistoryRef = useRef([]);
// Chat state
const [messages, setMessages] = useState([]);
const [input, setInput] = useState("");
const [isThinking, setIsThinking] = useState(false);
const [isStreaming, setIsStreaming] = useState(false);
const [tps, setTps] = useState(null);
const workerRef = useRef(null);
const [fontsReady, setFontsReady] = useState(false);
useEffect(() => {
if (!document.fonts?.load) {
setFontsReady(true);
return;
}
Promise.all([
document.fonts.load('400 132px "Instrument Serif"'),
document.fonts.load('italic 400 132px "Instrument Serif"'),
]).finally(() => setFontsReady(true));
}, []);
const selected = MODELS.find((m) => m.id === selectedId) || MODELS[0];
const scrollToPlayground = useCallback(() => {
playgroundRef.current?.scrollIntoView({
behavior: "smooth",
block: "start",
});
}, []);
useEffect(() => {
if (!workerRef.current) {
workerRef.current = new Worker(new URL("./worker.js", import.meta.url), {
type: "module",
});
}
const worker = workerRef.current;
const onMessage = (e) => {
const d = e.data;
switch (d.status) {
case "progress_total":
setLoadProgress(d.progress);
setLoadLoaded(d.loaded);
setLoadTotal(d.total);
if (Number(d.progress) >= 100) {
clearTimeout(optimizeMessageTimeoutRef.current);
optimizeMessageTimeoutRef.current = setTimeout(() => {
setLoadMessage("Optimizing model for 1-bit execution");
}, 100);
}
break;
case "loading":
setLoadMessage(d.data);
break;
case "ready":
clearTimeout(optimizeMessageTimeoutRef.current);
setStage("chat");
setTimeout(() => inputRef.current?.focus(), 200);
break;
case "start":
setIsThinking(false);
setIsStreaming(true);
setTps(null);
setMessages((m) => [...m, { role: "assistant", content: "" }]);
break;
case "update":
if (d.tps != null) setTps(d.tps);
setMessages((m) => {
const copy = [...m];
const last = copy[copy.length - 1];
copy[copy.length - 1] = {
...last,
content: last.content + d.output,
};
return copy;
});
break;
case "complete":
setIsStreaming(false);
break;
case "error":
clearTimeout(optimizeMessageTimeoutRef.current);
setErrorMessage(d.data);
setStage("error");
setIsThinking(false);
setIsStreaming(false);
break;
}
};
worker.addEventListener("message", onMessage);
return () => {
clearTimeout(optimizeMessageTimeoutRef.current);
worker.removeEventListener("message", onMessage);
};
}, []);
const startLoading = () => {
clearTimeout(optimizeMessageTimeoutRef.current);
setStage("loading");
setLoadProgress(0);
setLoadMessage("Fetching weights");
setLoadLoaded(0);
setLoadTotal(0);
workerRef.current?.postMessage({ type: "load", data: selectedId });
};
// --- Background canvas: digits + glistening waves ---
useEffect(() => {
const canvas = canvasRef.current;
if (!canvas) return;
const ctx = canvas.getContext("2d");
const dpr = window.devicePixelRatio || 1;
const CELL_W = 16;
const CELL_H = 22;
const PRISM = [
[255, 122, 92],
[255, 184, 77],
[196, 217, 46],
[77, 208, 196],
[124, 142, 232],
[182, 123, 232],
];
const setupCells = () => {
const w = canvas.clientWidth;
const h = canvas.clientHeight;
canvas.width = w * dpr;
canvas.height = h * dpr;
ctx.scale(dpr, dpr);
const cols = Math.ceil(w / CELL_W);
const rows = Math.ceil(h / CELL_H);
const cells = [];
for (let r = 0; r < rows; r++) {
for (let c = 0; c < cols; c++) {
cells.push({
x: c * CELL_W + CELL_W / 2,
y: r * CELL_H + CELL_H / 2 + 6,
val: Math.random() > 0.5 ? "1" : "0",
nextFlip: performance.now() + Math.random() * 8000 + 3000,
phase: Math.random() * Math.PI * 2,
});
}
}
cellsRef.current = cells;
};
const handleResize = () => {
ctx.setTransform(1, 0, 0, 1, 0, 0);
setupCells();
};
const handleMove = (e) => {
const rect = canvas.getBoundingClientRect();
mouseRef.current.x = e.clientX - rect.left;
mouseRef.current.y = e.clientY - rect.top;
mouseRef.current.inside = true;
};
const handleLeave = () => {
mouseRef.current.inside = false;
mouseRef.current.x = -9999;
mouseRef.current.y = -9999;
};
setupCells();
window.addEventListener("resize", handleResize);
window.addEventListener("mousemove", handleMove);
canvas.addEventListener("mouseleave", handleLeave);
let waveColorIdx = 0;
const spawnWave = () => {
const w = canvas.clientWidth;
const h = canvas.clientHeight;
const angle = (Math.random() - 0.5) * (Math.PI / 3) + Math.PI / 2;
const nx = Math.cos(angle);
const ny = Math.sin(angle);
const dots = [0, w * nx, h * ny, w * nx + h * ny];
const minDot = Math.min(...dots);
const maxDot = Math.max(...dots);
const direction = Math.random() > 0.5 ? 1 : -1;
const startPos = direction > 0 ? minDot - 200 : maxDot + 200;
const endPos = direction > 0 ? maxDot + 200 : minDot - 200;
const speed = 220 + Math.random() * 140;
const color = PRISM[waveColorIdx % PRISM.length];
waveColorIdx++;
wavesRef.current.push({
nx,
ny,
pos: startPos,
endPos,
direction,
speed,
color,
bandWidth: 110,
startedAt: performance.now(),
flippedSet: new Set(),
});
};
let nextWaveAt = performance.now() + 2200;
let lastT = performance.now();
const render = (now) => {
const dt = Math.min(now - lastT, 64);
lastT = now;
const w = canvas.clientWidth;
const h = canvas.clientHeight;
ctx.clearRect(0, 0, w, h);
ctx.font = "11px 'JetBrains Mono', monospace";
ctx.textAlign = "center";
ctx.textBaseline = "alphabetic";
if (now > nextWaveAt) {
spawnWave();
nextWaveAt = now + 5500 + Math.random() * 3500;
}
const activeWaves = [];
for (const wv of wavesRef.current) {
wv.pos += wv.direction * wv.speed * (dt / 1000);
const stillOnscreen =
(wv.direction > 0 && wv.pos < wv.endPos) ||
(wv.direction < 0 && wv.pos > wv.endPos);
if (stillOnscreen) activeWaves.push(wv);
}
wavesRef.current = activeWaves;
const mx = mouseRef.current.x;
const my = mouseRef.current.y;
const LENS_R = 160;
const LENS_R2 = LENS_R * LENS_R;
const cells = cellsRef.current;
for (let i = 0; i < cells.length; i++) {
const cell = cells[i];
if (now > cell.nextFlip) {
cell.val = cell.val === "1" ? "0" : "1";
cell.nextFlip = now + 4000 + Math.random() * 9000;
}
const pulse = 0.5 + 0.5 * Math.sin(now * 0.0006 + cell.phase);
let opacity = 0.02 + pulse * 0.015;
let r = 235,
g = 229,
b = 216;
let waveBoost = 0;
let tintR = 0,
tintG = 0,
tintB = 0,
tintW = 0;
for (const wv of activeWaves) {
const cellPos = cell.x * wv.nx + cell.y * wv.ny;
const d = cellPos - wv.pos;
const ad = Math.abs(d);
if (ad < wv.bandWidth) {
const sigma = 30;
const glow = Math.exp(-(d * d) / (2 * sigma * sigma));
const crest = Math.max(0, 1 - ad / wv.bandWidth);
const intensity = glow * 0.85 + crest * 0.15;
waveBoost += intensity * 0.85;
tintR += wv.color[0] * intensity;
tintG += wv.color[1] * intensity;
tintB += wv.color[2] * intensity;
tintW += intensity;
if (glow > 0.85 && !wv.flippedSet.has(i) && Math.random() < 0.18) {
cell.val = cell.val === "1" ? "0" : "1";
wv.flippedSet.add(i);
}
}
}
let lensFalloff = 0;
if (mouseRef.current.inside) {
const dx = cell.x - mx;
const dy = cell.y - my;
const d2 = dx * dx + dy * dy;
if (d2 < LENS_R2) {
lensFalloff = 1 - Math.sqrt(d2) / LENS_R;
}
}
if (lensFalloff > 0) {
opacity = Math.max(opacity, 0.08 + lensFalloff * 0.4);
}
if (waveBoost > 0) {
opacity = Math.min(0.95, opacity + waveBoost);
if (tintW > 0) {
const tintMix = Math.min(1, waveBoost * 1.2);
const tr = tintR / tintW;
const tg = tintG / tintW;
const tb = tintB / tintW;
r = r * (1 - tintMix) + tr * tintMix;
g = g * (1 - tintMix) + tg * tintMix;
b = b * (1 - tintMix) + tb * tintMix;
}
}
ctx.fillStyle = `rgba(${r | 0}, ${g | 0}, ${b | 0}, ${opacity})`;
ctx.fillText(cell.val, cell.x, cell.y);
}
rafRef.current = requestAnimationFrame(render);
};
rafRef.current = requestAnimationFrame(render);
return () => {
cancelAnimationFrame(rafRef.current);
window.removeEventListener("resize", handleResize);
window.removeEventListener("mousemove", handleMove);
canvas.removeEventListener("mouseleave", handleLeave);
};
}, []);
useEffect(() => {
messagesHistoryRef.current = messages;
}, [messages]);
const sendMessage = (text) => {
const trimmed = text.trim();
if (!trimmed || isStreaming || isThinking) return;
setInput("");
const nextHistory = [
...messagesHistoryRef.current,
{ role: "user", content: trimmed },
];
setMessages(nextHistory);
setIsThinking(true);
workerRef.current?.postMessage({ type: "generate", data: nextHistory });
};
const interruptGeneration = () => {
workerRef.current?.postMessage({ type: "interrupt" });
};
const resetChat = () => {
setInput("");
setMessages([]);
messagesHistoryRef.current = [];
setIsThinking(false);
setIsStreaming(false);
setTps(null);
autoScrollRef.current = true;
workerRef.current?.postMessage({ type: "reset" });
setTimeout(() => inputRef.current?.focus(), 0);
};
const handleMessagesScroll = useCallback(() => {
const container = messagesRef.current;
if (!container) return;
const distanceFromBottom =
container.scrollHeight - container.scrollTop - container.clientHeight;
autoScrollRef.current = distanceFromBottom < 96;
}, []);
useEffect(() => {
if (!autoScrollRef.current) return;
const container = messagesRef.current;
if (!container) return;
const id = requestAnimationFrame(() => {
container.scrollTo({
top: container.scrollHeight,
behavior: isStreaming ? "auto" : "smooth",
});
});
return () => cancelAnimationFrame(id);
}, [messages, isThinking, isStreaming]);
const SUGGESTIONS = [
"Who are you?",
"Write a short poem about AI.",
"What is the capital of France?",
"Solve x^2 - 5x + 6 = 0.",
];
const handleKeyDown = (e) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
sendMessage(input);
}
};
let tpsDisplay = "—";
if (isThinking) tpsDisplay = "···";
else if (tps !== null) tpsDisplay = tps.toFixed(1);
return (
<>
<div className="prism-root">
{/* ============= HERO ============= */}
<section className="hero">
<canvas ref={canvasRef} className="digit-canvas" />
<div className="grain" />
<div className="vignette" />
<div className={`hero-content ${fontsReady ? "is-ready" : ""}`}>
<div className="hero-badge">
14× less memory · 8× faster · 5× less energy
</div>
<h1>
1-bit LLMs,
<br />
in your <span className="accent italic-d">browser.</span>
</h1>
<p className="lede">
Run 1-bit Bonsai LLMs (1.7B, 4B, 8B) entirely locally in your
browser on WebGPU, powered by Transformers.js.
</p>
<button className="btn btn-primary" onClick={scrollToPlayground}>
Try the demo
<span className="arrow"></span>
</button>
</div>
</section>
{/* ============= PLAYGROUND ============= */}
<section className="playground" ref={playgroundRef}>
{stage === "select" && (
<div className="select-view">
<div className="select-head">
<div className="eyebrow">Choose a model</div>
<h2>
Load <span className="em">locally.</span>
</h2>
<p>
Each Bonsai model runs entirely in your browser via WebGPU.
Pick a size — smaller loads faster, larger reasons better.
</p>
</div>
<div className="model-grid">
{MODELS.map((m) => (
<div
key={m.id}
className={`model-card ${selectedId === m.id ? "selected" : ""} ${m.comingSoon ? "disabled" : ""}`}
onClick={() => !m.comingSoon && setSelectedId(m.id)}
>
{m.comingSoon && (
<div className="mc-ribbon">Coming soon</div>
)}
<div className="check">
{selectedId === m.id ? "✓" : ""}
</div>
<div className="mc-size">
{m.params}
<span className="gb">{m.size}</span>
</div>
<div className="mc-name">{m.name}</div>
<div className="mc-blurb">{m.blurb}</div>
</div>
))}
</div>
<div className="select-actions">
<button
className="btn btn-primary"
onClick={startLoading}
disabled={!selectedId}
>
Load {selected.name}
<span className="arrow"></span>
</button>
<div className="note">No data leaves your device</div>
</div>
</div>
)}
{stage === "loading" && (
<div className="load-view">
<div className="load-card">
<div
className={`${PRISM_GLYPH_CLASS} mx-auto mb-7 animate-spin [animation-duration:8s]`}
/>
<div className="label">Initializing</div>
<h3>
{selected.name.split(" ")[0]}{" "}
<span className="em">{selected.name.split(" ")[1]}</span>
</h3>
<div className="progress">
<div className="fill" style={{ width: `${loadProgress}%` }} />
</div>
<div className="load-row">
<span>
{loadTotal > 0 && loadProgress < 100
? `${formatBytes(loadLoaded)} / ${formatBytes(loadTotal)}`
: loadMessage}
</span>
<span className="pct">{Math.round(loadProgress)}%</span>
</div>
</div>
</div>
)}
{stage === "error" && (
<div className="load-view">
<div className="load-card">
<div className="label">Error</div>
<h3>
Could not <span className="em">load.</span>
</h3>
<p
style={{
color: "var(--cream-dim)",
fontSize: 14,
margin: "0 0 18px",
}}
>
{errorMessage}
</p>
<button
className="btn btn-primary"
onClick={() => {
setErrorMessage(null);
setStage("select");
}}
>
Back
</button>
</div>
</div>
)}
{stage === "chat" && (
<div className="chat-view">
<div className="pg-header-bar">
<h2 className="pg-title">
The <span className="em">playground.</span>
</h2>
<div className="pg-mini">
<span className="chip">
<span className="k">model</span>
{selected.name}
</span>
<span className="chip">
<span className="k">size</span>
{selected.size}
</span>
</div>
</div>
<div className="chat">
<div className="chat-bar">
<div className="left">
<span className="dots">
<span />
<span />
<span />
</span>
<span>BONSAI WebGPU</span>
</div>
<div className="right">
<span className="tps-value">{tpsDisplay}</span>
<span className="tps-label">Tokens / sec</span>
</div>
</div>
<div
ref={messagesRef}
className="messages"
onScroll={handleMessagesScroll}
>
{messages.length === 0 && !isThinking ? (
<div className="flex flex-1 flex-col items-center justify-center gap-5 px-0 py-6 text-center">
<div className={PRISM_GLYPH_CLASS} />
<div
className="text-[32px] leading-none tracking-[-0.02em] text-[var(--cream)]"
style={{ fontFamily: '"Instrument Serif", serif' }}
>
How can I help you?
</div>
<div
className="mt-[-6px] text-[12px] leading-[1.25] tracking-[0.06em] text-[var(--muted)] uppercase"
style={{ fontFamily: '"JetBrains Mono", monospace' }}
>
Talk to a 1-bit model
</div>
<div className="mt-1 grid w-full max-w-[540px] grid-cols-1 gap-2.5 sm:grid-cols-2">
{SUGGESTIONS.map((s) => (
<button
key={s}
className="rounded-md border border-[var(--line)] bg-[var(--bg-3)] px-4 py-[13px] text-left text-[13px] leading-[1.3] text-[var(--cream-dim)] transition-all duration-200 hover:translate-y-[-1px] hover:border-[var(--line-2)] hover:bg-[rgba(235,229,216,0.015)] hover:text-[var(--cream)]"
style={{ fontFamily: '"Geist", sans-serif' }}
onClick={() => sendMessage(s)}
>
{s}
</button>
))}
</div>
</div>
) : (
<>
{messages.map((m, i) => {
return (
<div key={i} className={`message ${m.role}`}>
<span className="role">
{m.role === "user" ? "You" : selected.name}
</span>
<div className="bubble">
{m.role === "assistant" ? (
<Streamdown
className="streamdown-content"
plugins={STREAMDOWN_PLUGINS}
controls={false}
parseIncompleteMarkdown
>
{m.content}
</Streamdown>
) : (
m.content
)}
</div>
</div>
);
})}
{isThinking && (
<div className="message assistant">
<span className="role">{selected.name}</span>
<div className="thinking">
<span>thinking</span>
<span className="bits">
<span>1</span>
<span>0</span>
<span>1</span>
<span>1</span>
<span>0</span>
</span>
</div>
</div>
)}
</>
)}
</div>
<div className="composer">
<textarea
ref={inputRef}
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={handleKeyDown}
placeholder="Ask the 1-bit model anything…"
rows={1}
/>
{isStreaming || isThinking ? (
<button className="send" onClick={interruptGeneration}>
Stop
<span className="send-icon">
<StopIcon />
</span>
</button>
) : !input.trim() ? (
<button className="send" onClick={resetChat}>
Reset
<span className="send-icon">
<ResetIcon />
</span>
</button>
) : (
<button
className="send"
onClick={() => sendMessage(input)}
>
Send
<span className="send-icon">
<SendIcon />
</span>
</button>
)}
</div>
</div>
</div>
)}
</section>
</div>
</>
);
}